#include "tcpip.hpp"
#include "debug.hpp"
#include "classifier.hpp"
#include <new>
#include <stdio.h>
#include <string.h>

extern "C" {
	#include <netinet/in.h>
	#include <netinet/ip.h>
	#include <netinet/tcp.h>
	#include <netinet/udp.h>
	#include <arpa/inet.h>
}

using namespace std;

static int set_if(char *str, char **p_if, int *len);

// Builds a rule with no protocol, port or interface restriction
// (proto = __IP_HDR_PROTO_ALL__, empty port lists, no interface names)
// and both priorities unset (-1).
// NOTE(review): saddr/daddr default to 0.0.0.0 with an all-ones mask, which
// check() evaluates as "address must equal 0.0.0.0". Presumably callers
// always call set_saddr()/set_daddr(); confirm, or default the masks to 0.
ipv4_packet_classifier::ipv4_packet_classifier()
{
	saddr = daddr = 0;
	smask = dmask = u_int32_t(-1L);
	user_prio = real_prio = -1;
	proto = __IP_HDR_PROTO_ALL__;
	sport_list = dport_list = 0;
	n_sports = n_dports = 0;
	inp_if = out_if = 0;
	// The lengths are only read while inp_if/out_if are set, but they must
	// never be left indeterminate.
	n_inp_if = n_out_if = 0;
}

// Frees the heap arrays owned by the rule. delete[] on a null pointer does
// nothing, so lists and interface names that were never set need no
// special case.
ipv4_packet_classifier::~ipv4_packet_classifier()
{
	delete[] out_if;
	delete[] inp_if;
	delete[] dport_list;
	delete[] sport_list;
}

// data in host order
// 1 -> hit
// 0 -> miss
inline int ipv4_packet_classifier::check_port(int port, port_range *port_list, 
	int n_ports)
{
	if( n_ports==0 ) return 1;
	for( int i = 0 ; i<n_ports ; i++ ) {
		int lo = port_list[i].low;
		int hi = port_list[i].high;
		if( port>=lo && port<=hi ) return 1;
	}
	return 0;
}

// Checks an interface name against a rule's interface pattern.
// rule_if == 0 means "any interface". Otherwise `len` is the number of
// significant characters as computed by set_if(). For a wildcard pattern
// such as "eth+" it excludes the '+', so rule_if[len] == '+' and only the
// prefix is compared. For a plain name, rule_if[len] is the terminating NUL.
// Returns 1 on match, 0 on mismatch (including a packet with no interface).
static int interface_matches(const char *rule_if, int len, const char *pckt_if)
{
	if( rule_if==0 ) return 1;
	if( pckt_if==0 ) return 0;

	// For a plain name, also compare the NUL so that "eth1" does not match
	// "eth10". A strncmp over only `len` bytes is a prefix test.
	size_t cmp_len = size_t(len);
	if( rule_if[len]=='\0' ) cmp_len++;
	return strncmp(rule_if, pckt_if, cmp_len)==0 ? 1 : 0;
}

// Matches one packet against this rule.
// Returns 1 (hit) when every configured criterion matches, 0 (miss)
// otherwise. pckt->payload is read as an IPv4 header, followed (for TCP/UDP)
// by the transport header.
// NOTE(review): no length check is done here. Confirm that the packet source
// guarantees the headers are fully present.
int ipv4_packet_classifier::check(generic_packet *pckt)
{
	struct iphdr *ip_hdr = (struct iphdr *)(pckt->payload);

	// saddr/daddr were pre-masked by set_saddr()/set_daddr(). Everything
	// stays in network byte order.
	if( (ip_hdr->saddr & smask) != saddr ) return 0;
	if( (ip_hdr->daddr & dmask) != daddr ) return 0;

	if( !interface_matches(inp_if, n_inp_if, pckt->inp_if) ) return 0;
	if( !interface_matches(out_if, n_out_if, pckt->out_if) ) return 0;

	// A negative proto (as accepted by set_proto(), presumably
	// __IP_HDR_PROTO_ALL__) accepts every protocol.
	if( proto<0 ) return 1;
	if( proto!=int(ip_hdr->protocol) ) return 0;

	// Port criteria only apply to TCP and UDP.
	if( proto!=__IP_HDR_PROTO_TCP__ && proto!=__IP_HDR_PROTO_UDP__ )
		return 1;
	if( n_sports==0 && n_dports==0 ) return 1;

	// Only the first fragment carries the transport header. In later
	// fragments the bytes after the IP header are payload, so ports cannot
	// be evaluated and the rule misses (same policy as iptables port
	// matches).
	if( ntohs(ip_hdr->frag_off) & IP_OFFMASK ) return 0;

	// ihl counts 32-bit words.
	const char *l4 = (const char *)ip_hdr + (ip_hdr->ihl << 2);
	int sport, dport;
	if( proto==__IP_HDR_PROTO_TCP__ ) {
		const struct tcphdr *tcp_hdr = (const struct tcphdr *)l4;
		sport = ntohs(tcp_hdr->source);
		dport = ntohs(tcp_hdr->dest);
	} else {
		const struct udphdr *udp_hdr = (const struct udphdr *)l4;
		sport = ntohs(udp_hdr->source);
		dport = ntohs(udp_hdr->dest);
	}

	if( !check_port(sport, sport_list, n_sports) ) return 0;
	if( !check_port(dport, dport_list, n_dports) ) return 0;
	return 1;
}

// Sets the source-address criterion from dotted strings, e.g.
// ("10.0.0.0", "255.0.0.0"). The address is stored already masked, so
// check() compares with a single AND.
// Returns 0 on success, or -1 if either string is rejected by inet_aton()
// (the rule is then left unchanged).
int ipv4_packet_classifier::set_saddr(const char *addr, const char *mask)
{
	struct in_addr a, m;

	if( inet_aton(addr, &a)==0 || inet_aton(mask, &m)==0 )
		return -1;

	smask = m.s_addr;
	saddr = a.s_addr & smask;
	return 0;
}

// Sets the destination-address criterion from dotted strings. The stored
// address is pre-masked, and values stay in network byte order.
// Returns 0 on success, or -1 (rule unchanged) if inet_aton() rejects
// either string.
int ipv4_packet_classifier::set_daddr(const char *addr, const char *mask)
{
	struct in_addr parsed_addr;
	struct in_addr parsed_mask;

	if( !inet_aton(addr, &parsed_addr) ) return -1;
	if( !inet_aton(mask, &parsed_mask) ) return -1;

	dmask = parsed_mask.s_addr;
	daddr = parsed_addr.s_addr & parsed_mask.s_addr;
	return 0;
}

// Restricts the rule to IP protocol number n (0..255). -1 is also
// accepted, and check() treats any negative value as "every protocol".
// Returns 0 on success, or -1 when n is out of range (proto is unchanged).
int ipv4_packet_classifier::set_proto(int n)
{
	const bool valid = n>=-1 && n<=255;
	if( valid ) proto = n;
	return valid ? 0 : -1;
}

int ipv4_packet_classifier::add_src_port(int lo, int hi)
{
	if( lo<0 || lo>65535 ) return -1;
	if( hi<0 || hi>65535 ) return -1;
	if( lo>hi ) return -1;

	struct port_range *new_list = new(nothrow) port_range[n_sports+1];
	if( new_list==0 ) return -1;

	if( n_sports>0 ) {
		memcpy(new_list, sport_list, n_sports*sizeof(port_range));
		delete[] sport_list;
	}
	sport_list = new_list;

	sport_list[n_sports].low = lo;
	sport_list[n_sports].high = hi;
	n_sports++;
	return 0;
}

int ipv4_packet_classifier::add_dst_port(int lo, int hi)
{
	if( lo<0 || lo>65535 ) return -1;
	if( hi<0 || hi>65535 ) return -1;
	if( lo>hi ) return -1;

	struct port_range *new_list = new(nothrow) port_range[n_dports+1];
	if( new_list==0 ) return -1;

	if( n_dports>0 ) {
		memcpy(new_list, dport_list, n_dports*sizeof(port_range));
		delete[] dport_list;
	}
	dport_list = new_list;

	dport_list[n_dports].low = lo;
	dport_list[n_dports].high = hi;
	n_dports++;
	return 0;
}

char *ipv4_packet_classifier::log_attr(char *line, int bytes)
{
	int i, n;
	char *proto_str, *inp_if_str, *out_if_str;
	struct in_addr in_saddr, in_smask, in_daddr, in_dmask;
	char saddr_str[64], daddr_str[64], smask_str[64], dmask_str[64];

	switch(proto) {
		case __IP_HDR_PROTO_TCP__ : proto_str = "tcp";  break;
		case __IP_HDR_PROTO_UDP__ : proto_str = "udp";  break;
		case __IP_HDR_PROTO_ICMP__: proto_str = "icmp"; break;
		case __IP_HDR_PROTO_ALL__ : proto_str = "all";  break;
		default: proto_str = "?"; break;
	}

	inp_if_str = inp_if ? inp_if : ((char*)"all");
	out_if_str = out_if ? out_if : ((char*)"all");

	in_saddr.s_addr = saddr;
	in_smask.s_addr = smask;
	in_daddr.s_addr = daddr;
	in_dmask.s_addr = dmask;
	snprintf(saddr_str, 63, "%s", inet_ntoa(in_saddr));
	snprintf(smask_str, 63, "%s", inet_ntoa(in_smask));
	snprintf(daddr_str, 63, "%s", inet_ntoa(in_daddr));
	snprintf(dmask_str, 63, "%s", inet_ntoa(in_dmask));

	n = snprintf(line, bytes-1, "prio=%d(%d) proto=%s "
		"inp_if=\"%s\" out_if=\"%s\" saddr=%s/%s sport=", 
		user_prio, real_prio, proto_str, inp_if_str, out_if_str, 
		saddr_str, smask_str);
	if( n>=bytes-2 )
		return "buffer overflow!";

	if( n_sports==0 ) {
		n += snprintf(line+n, bytes-n-1, "all ");
		if( n>=bytes-2) 
			return "buffer overflow!";
	} else {
		for( i = 0 ; i<n_sports ; i++ ) {
			int lo = sport_list[i].low;
			int hi = sport_list[i].high;
			if( lo==hi ) {
				n += snprintf(line+n, bytes-n-1, 
					i==n_sports-1 ? "%d " : "%d,", 
					lo);
			} else {
				n += snprintf(line+n, bytes-n-1, 
					i==n_sports-1 ? "%d-%d " : "%d-%d,", 
					lo, hi);
			}
			if( n>=bytes-2 ) 
				return "buffer overflow!";
		}
	}

	n += snprintf(line+n, bytes-n-1, "daddr=%s/%s dport=",
		daddr_str, dmask_str);
	if( n>=bytes-2 ) 
		return "buffer overflow!";

	if( n_dports==0 ) {
		n += snprintf(line+n, bytes-n-1, "all");
		if( n>=bytes-2 ) 
			return "buffer overflow!";
	} else {
		for( i = 0 ; i<n_dports ; i++ ) {
			int lo = dport_list[i].low;
			int hi = dport_list[i].high;
			if( lo==hi ) {
				n += snprintf(line+n, bytes-n-1, 
					i==n_dports-1 ? "%d " : "%d,", 
					lo);
			} else {
				n += snprintf(line+n, bytes-n-1, 
					i==n_dports-1 ? "%d-%d " : "%d-%d,",
					lo, hi);
			}
			if( n>=bytes-2 ) 
				return "buffer overflow!";
		}
	}

	return line;
}

// Returns a description of this rule prefixed with "ipv4 classifier ".
// The text lives in a function-static buffer that is overwritten on every
// call. Any overflow reported by log_attr() is ignored, and the buffer then
// holds the truncated text.
char *ipv4_packet_classifier::log()
{
	static char line[1024];
	const int cap = int(sizeof(line));

	int prefix_len = snprintf(line, cap, "%s", "ipv4 classifier ");
	log_attr(line + prefix_len, cap - prefix_len);
	return line;
}

// Sets user_prio. Returns 0 on success, or -1 (value unchanged) for a
// negative priority.
int ipv4_packet_classifier::set_user_prio(int pr)
{
	if( pr>=0 ) {
		user_prio = pr;
		return 0;
	}
	return -1;
}

// Sets real_prio. Returns 0 on success, or -1 (value unchanged) for a
// negative priority.
int ipv4_packet_classifier::set_real_prio(int pr)
{
	if( pr>=0 ) {
		real_prio = pr;
		return 0;
	}
	return -1;
}

int ipv4_packet_classifier::get_user_prio()
{
	return user_prio;
}

int ipv4_packet_classifier::get_real_prio()
{
	return real_prio;
}

// Shared helper for set_inp_if()/set_out_if().
// Copies the interface name `str` into a new heap buffer stored in *p_if,
// and stores in *len how many leading characters must be compared:
//   "eth0" -> *len = 4 (the whole name)
//   "eth+" -> *len = 3 (a trailing '+' is a wildcard: prefix "eth")
// Any buffer previously stored in *p_if is freed first.
// Returns 0 on success, or -1 for a null/empty name or allocation failure
// (*p_if and *len are then untouched).
static int set_if(char *str, char **p_if, int *len)
{
	if( str==0 ) return -1;

	int n = strlen(str);
	assert(n>1);
	// assert() disappears under NDEBUG. Without this guard, an empty name
	// would make p[n-1] read before the buffer.
	if( n<1 ) return -1;

	char *p = new(std::nothrow) char[n+1];
	if( p==0 ) return -1;
	memcpy(p, str, n+1);

	// Releasing the old name avoids a leak when a slot is set twice.
	delete[] *p_if;
	*p_if = p;
	*len = p[n-1]=='+' ? n-1 : n;

	return 0;
}

// Sets the input-interface pattern (a trailing '+' means prefix match).
// It is intended to be called at most once per rule, which is asserted.
// Returns set_if()'s status: 0 on success, -1 on failure.
int ipv4_packet_classifier::set_inp_if(char *str)
{
	assert(inp_if==0);
	int rc = set_if(str, &inp_if, &n_inp_if);
	return rc;
}

// Sets the output-interface pattern (a trailing '+' means prefix match).
// It is intended to be called at most once per rule, which is asserted.
// Returns set_if()'s status: 0 on success, -1 on failure.
int ipv4_packet_classifier::set_out_if(char *str)
{
	assert(out_if==0);
	int rc = set_if(str, &out_if, &n_out_if);
	return rc;
}

#ifdef WITH_IPQ
// Builds an nf rule with no firewall-mark restriction. With fwmask == 0,
// check()'s (mark & fwmask) == fwmark test holds for every mark.
nf_packet_classifier::nf_packet_classifier()
{
	fwmask = 0;
	fwmark = 0;
}

// Owns nothing beyond the base class. The base destructor runs afterwards
// and frees the port lists and interface names.
nf_packet_classifier::~nf_packet_classifier()
{
}

// Matches a packet: all IPv4 criteria of the base class must hit, and the
// packet's meta->mark, masked with fwmask, must equal fwmark.
// Returns 1 on hit, 0 on miss.
int nf_packet_classifier::check(generic_packet *pckt)
{
	if( ipv4_packet_classifier::check(pckt)!=1 )
		return 0;

	// Unchecked downcast. Per the original author, ipq_packet is the only
	// packet source this classifier sees, so RTTI is not used.
	ipq_packet *ipq_pckt = (ipq_packet*)(pckt);
	return ((ipq_pckt->meta->mark & fwmask)==fwmark) ? 1 : 0;
}

// Writes the base-class description followed by " fwmark=0xMARK/0xMASK"
// into `line` (capacity `bytes`).
// Returns `line` on success, or the overflow message when the text does not
// fit (either the base's message or this function's own).
char *nf_packet_classifier::log_attr(char *line, int bytes)
{
	// A writable array, so no string literal is converted to char*.
	static char overflow_msg[] = "buffer overflow!";

	char *pc = ipv4_packet_classifier::log_attr(line, bytes);
	if( pc!=line )
		return pc;

	int len = strlen(line);
	int room = bytes-len;
	// The casts make the arguments match %lx whatever integer type the
	// header declares for fwmark/fwmask.
	int n = snprintf(line+len, room, " fwmark=0x%lx/0x%lx",
		(unsigned long)fwmark, (unsigned long)fwmask);
	if( n<0 || n>=room )
		return overflow_msg;

	return line;
}

// Returns a description of this rule prefixed with "nf classifier ".
// The text lives in a function-static buffer that is overwritten on every
// call.
char *nf_packet_classifier::log()
{
	static char line[1024];

	int n = snprintf(line, sizeof(line), "%s",
		"nf classifier ");
	// Pass only the space left after the prefix. Passing the full 1024
	// allowed log_attr to write up to n bytes past the end of `line`.
	log_attr(line+n, int(sizeof(line))-n);
	return line;
}
#endif
