#include <linux/kernel.h>
#include <linux/module.h>
#include <linux/kprobes.h>
#include <linux/semaphore.h>
#include <linux/slab.h>
#include <linux/uaccess.h>

#include <linux/fs.h>
#include <linux/device.h>

MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Suren A. Chilingaryan <csa@suren.me>");
MODULE_DESCRIPTION("NVIDIA Driver Tracer");
MODULE_VERSION("0.0.1");

/*
 * handler_pre() pulls the ioctl arguments straight out of the x86-64
 * register set (regs->dx / regs->cx), so refuse to build anywhere else.
 */
#if !(defined(CONFIG_X86) && defined(CONFIG_X86_64))
# error "Only x86-64 platform is currently supported"
#endif

/* Serializes enabling/disabling of the probe in nv_enable_store(). */
static DEFINE_SEMAPHORE(nv_sem);
/* Initialized in nvtrace_init(). */
static spinlock_t nv_lock;
/* 1 while the kprobe on nvidia_ioctl is registered, 0 otherwise. */
static int nvtrace_enabled;

/* Low byte of the ioctl command seen by the previous handler_pre() call. */
static int nv_state;
/*
 * First and fourth 32-bit words of the argument of a 0x57 ioctl that directly
 * followed a 0x4a ioctl; exported through the "user" / "buffer" attributes.
 */
static unsigned nv_last_user;
static unsigned nv_last_buffer;

// DS: we need to do this per pid (presumably the nv_state tracking in handler_pre, which is currently global).
static int handler_pre(struct kprobe *p, struct pt_regs *regs)
{
	char *logstr;
	int pos, size;
	int cmd = regs->dx&0xFF;

	switch (cmd) {
	 case 0x4a:
	    size = 0xb0;
	    break;
	 case 0x57:
	    size = 0x38;
	    break;
	 case 0x2a:
	    size = 0x20;
	    break;
	 default:
	    size = 4;
	}

	logstr = kmalloc(9 * (size / 4), GFP_KERNEL);
	if (logstr) {
	    for (pos = 0; (pos + 4) <= size; pos += 4)
		sprintf(logstr + 9 * (pos / 4), " %08x", *(uint32_t*)(regs->cx + pos));
	    printk(KERN_INFO " cmd = 0x%lx: %s\n", regs->dx, logstr);
	    kfree(logstr);
	}
	
/*
        printk(KERN_INFO " cmd = 0x%lx:", regs->dx);
        for (pos = 0; (pos + 4) <= size; pos += 4) {
	    printk(KERN_INFO " %08x", *(uint32_t*)(regs->cx + pos));
	}
        printk(KERN_INFO "\n");
*/

/*        printk(KERN_INFO 
		" cmd = 0x%lx: %08x %08x %08x %08x %08x %08x %08x %08x\n",
	    regs->dx, 
	    *(uint32_t*)(regs->cx + 0x00), *(uint32_t*)(regs->cx + 0x04), *(uint32_t*)(regs->cx + 0x08), *(uint32_t*)(regs->cx + 0x0C),
	    *(uint32_t*)(regs->cx + 0x10), *(uint32_t*)(regs->cx + 0x14), *(uint32_t*)(regs->cx + 0x18), *(uint32_t*)(regs->cx + 0x1C)
	);*/
	
	if ((nv_state == 0x4a)&&(cmd == 0x57)) {
	    nv_last_user = *(uint32_t*)regs->cx;
	    nv_last_buffer = *(uint32_t*)(regs->cx + 0x0C);

	    printk(KERN_INFO "cmd = 0x%lx, userid = 0x%x, bufferid = 0x%x\n", regs->dx, nv_last_user, nv_last_buffer);
	}
	nv_state = cmd;

	return 0;
}

/* Template probe on nvidia_ioctl; copied into nv_curprobe on each enable. */
static struct kprobe nv_probe = {
	.pre_handler	= handler_pre,
	.symbol_name	= "nvidia_ioctl",
};

/* The probe instance actually handed to register_kprobe(). */
static struct kprobe nv_curprobe;


/*
 * sysfs "enable" show: prints the current tracer state (0 or 1).
 * Returns the number of characters written to @buf.
 */
static ssize_t nv_enable_show(struct class *cls, struct class_attribute *attr, char *buf)
{
	/* Use the real length rather than assuming a single character. */
	return sprintf(buf, "%d", nvtrace_enabled);
}

/*
 * sysfs "enable" store: "1" registers the kprobe on nvidia_ioctl, "0"
 * unregisters it; writing the current value is a no-op.
 *
 * Returns @count on success, -EINVAL for anything other than 0 or 1
 * (kstrtoint() tolerates the trailing newline written by echo but rejects
 * other trailing junk), -ERESTARTSYS if interrupted while waiting for nv_sem,
 * or the negative error from register_kprobe().
 */
static ssize_t nv_enable_store(struct class *cls, struct class_attribute *attr,
			       const char *buf, size_t count)
{
	int enable;
	int ret = 0;

	if (kstrtoint(buf, 10, &enable) || enable < 0 || enable > 1)
		return -EINVAL;

	/* A signal arrived: let the write be restarted, it is not bad input. */
	if (down_interruptible(&nv_sem))
		return -ERESTARTSYS;

	if (enable == nvtrace_enabled)
		goto out;

	if (enable) {
		/*
		 * register_kprobe() fills in .addr (printed below); start every
		 * session from the pristine template so a previously used
		 * nv_curprobe is never re-registered as-is.
		 */
		nv_curprobe = nv_probe;
		/* Don't pair a 0x57 with a 0x4a seen before the last disable. */
		nv_state = 0;

		ret = register_kprobe(&nv_curprobe);
		if (ret < 0) {
			printk(KERN_ERR "register_kprobe failed, returned %d\n", ret);
			goto out;
		}
		printk(KERN_INFO "Planted kprobe at %p\n", nv_curprobe.addr);
	} else {
		unregister_kprobe(&nv_curprobe);
		printk(KERN_INFO "kprobe at %p unregistered\n", nv_curprobe.addr);
	}
	nvtrace_enabled = enable;

out:
	up(&nv_sem);
	if (ret < 0)
		return ret;
	return count;
}

/*
 * sysfs "user" show: prints nv_last_user (the user id captured by
 * handler_pre()) as 0x-prefixed hex, without a trailing newline.
 */
static ssize_t nv_user_show(struct class *cls, struct class_attribute *attr, char *buf)
{
	/* sprintf() already reports how many characters it wrote. */
	return sprintf(buf, "0x%x", nv_last_user);
}

/*
 * sysfs "buffer" show: prints nv_last_buffer (the buffer id captured by
 * handler_pre()) as 0x-prefixed hex, without a trailing newline.
 */
static ssize_t nv_buffer_show(struct class *cls, struct class_attribute *attr, char *buf)
{
	/* sprintf() already reports how many characters it wrote. */
	return sprintf(buf, "0x%x", nv_last_buffer);
}

/*
 * Attributes of the nvtrace class:
 *   enable - world-readable, owner-writable; toggles the kprobe
 *   user   - read-only; last captured user id
 *   buffer - read-only; last captured buffer id
 */
static struct class_attribute nv_attrs[] = {
	__ATTR(enable, 0644, nv_enable_show, nv_enable_store),
	__ATTR(user, 0444, nv_user_show, NULL),
	__ATTR(buffer, 0444, nv_buffer_show, NULL),
	__ATTR_NULL,
};

/* The "nvtrace" class; registering it creates the nv_attrs files. */
static struct class nv_class = {
	.name		= "nvtrace",
	.class_attrs	= nv_attrs,
	.owner		= THIS_MODULE,
};

/*
 * Module entry point: initializes nv_lock and registers the "nvtrace" class
 * with its enable/user/buffer attributes. The kprobe itself is only planted
 * once 1 is written to "enable". Returns class_register()'s result.
 */
static int __init nvtrace_init(void)
{
	int err;

	spin_lock_init(&nv_lock);
	err = class_register(&nv_class);

	return err;
}

/*
 * Module exit: removes the sysfs class first (presumably so no further
 * "enable" writes can race with the check below), then drops the kprobe if
 * tracing is still active.
 */
static void __exit nvtrace_exit(void)
{
	class_unregister(&nv_class);

	if (!nvtrace_enabled)
		return;

	unregister_kprobe(&nv_curprobe);
}

module_init(nvtrace_init);
module_exit(nvtrace_exit);