use std::cell::RefCell;
use std::collections::HashMap;
use std::io::Write;
use std::rc::Rc;

use rand::{thread_rng, SeedableRng};
use rand_xoshiro::Xoshiro256StarStar;

mod func;
mod opcode;
mod operator;
mod value;

use crate::ast::Parser;
use crate::compiler::Compiler;
use crate::error::Result;

pub use opcode::{HeapID, OpCode};
pub use operator::{BinOp, UnOp};
pub use value::{FuncDef, FuncImpl, Table, Value, VarType};
use value::{Numeric, VmValue};

/// Slot-addressed storage for the values behind locals and upvalues.
///
/// Slot indices are handed out monotonically by `alloc` and are never
/// freed or reused.
struct Heap {
    /// Index that the next call to `alloc` will return.
    next: usize,
    /// Values that have been written, keyed by slot index; unwritten slots
    /// are simply absent.
    heap: HashMap<usize, VmValue>,
}

impl Heap {
    fn new() -> Self {
        Self {
            heap: HashMap::new(),
            next: 0,
        }
    }

    fn alloc(&mut self) -> usize {
        let ind = self.next;
        self.next += 1;
        ind
    }

    fn get(&self, ind: &usize) -> VmValue {
        self.heap
            .get(ind)
            .cloned()
            .unwrap_or(VmValue::Single(Value::Nil))
    }

    fn set(&mut self, ind: usize, val: VmValue) {
        self.heap.insert(ind, val);
    }
}

/// One entry on the VM's frame stack: either a function-call frame (pushed
/// by `Call`) or a block scope (pushed by `ScopeEnter`).
struct StackFrame {
    /// Heap slot backing each local of this frame, indexed by local number.
    locals: Vec<usize>,
    /// Surplus arguments of a variadic call; empty for non-variadic calls and
    /// block scopes.
    varargs: Vec<Value>,
    /// Return address `(code block, instruction)` for call frames; `None`
    /// for block scopes.
    ret: Option<(usize, usize)>,
}

/// Stack-based interpreter: compiles source text to opcodes and executes
/// them, using `out` as the program's output sink.
pub struct VM<'a> {
    /// Output sink handed to `create`; flushed by the VM.
    out: &'a mut dyn Write,
    /// Compiled code blocks; a program counter is `(block, instruction)`.
    code: Vec<Vec<OpCode>>,
    /// Global environment table, shared by reference with running code.
    env: Rc<RefCell<Table>>,
    /// Storage backing locals and upvalues.
    heap: Heap,
    /// Active call frames and block scopes, innermost last.
    frames: Vec<StackFrame>,
    /// Operand stack.
    stack: Vec<VmValue>,
    /// Random number generator, seeded from `thread_rng()` in `create`.
    rng: Xoshiro256StarStar,
}

impl<'a> VM<'a> {
    pub fn create(out: &'a mut dyn Write) -> Self {
        let mut env = Table::default();
        func::set_builtin_funcs(&mut env);

        Self {
            code: Vec::new(),
            env: Rc::new(RefCell::new(env)),
            heap: Heap::new(),
            frames: Vec::new(),
            stack: Vec::new(),
            rng: Xoshiro256StarStar::from_rng(thread_rng()).unwrap(),
            out,
        }
    }

    pub fn run_str(&mut self, input: &str) -> Result<()> {
        let ast = Parser::parse(input)?;
        self.code = Compiler::compile(ast);

        self.run_loop();
        self.out.flush().unwrap();

        assert!(self.frames.is_empty());
        assert!(self.stack.is_empty());
        Ok(())
    }

    fn run_loop(&mut self) {
        let mut pc = (0, 0);

        loop {
            let current = &self.code[pc.0][pc.1];
            pc = (pc.0, pc.1 + 1);

            match current {
                OpCode::PushEnv => {
                    self.stack
                        .push(VmValue::Single(Value::Table(self.env.clone())));
                }
                OpCode::PushLit(val) => {
                    self.stack.push(VmValue::Single(val.clone()));
                }
                OpCode::PushLocal(ind) => {
                    let frame = self.frames.last().unwrap();
                    let local = frame.locals.get(*ind).unwrap();
                    let val = self.heap.get(local);
                    self.stack.push(val);
                }
                OpCode::PushLocalMult(local, ind) => {
                    let frame = self.frames.last().unwrap();
                    let heap = frame.locals.get(*local).unwrap();
                    let varargs = self.heap.get(heap);

                    match varargs {
                        VmValue::Single(..) => unreachable!(),
                        VmValue::Multiple(vec) => {
                            self.stack.push(VmValue::Single(match vec.get(*ind) {
                                Some(val) => val.clone(),
                                None => Value::Nil,
                            }))
                        }
                    }
                }
                OpCode::PushVarargs => {
                    let mut frames = self.frames.iter().rev();

                    loop {
                        match frames.next() {
                            Some(frame) => {
                                if frame.ret.is_some() {
                                    self.stack.push(VmValue::Multiple(frame.varargs.clone()));
                                    break;
                                }
                            }
                            None => panic!("not in function"),
                        }
                    }
                }
                OpCode::PopDisc => {
                    self.stack.pop().unwrap();
                }
                OpCode::PopLocal(ind) => {
                    let val = self.stack.pop().unwrap();
                    let frame = self.frames.last().unwrap();
                    let local = frame.locals.get(*ind).unwrap();
                    self.heap.set(*local, val);
                }
                OpCode::Copy => {
                    let val = self.stack.last().unwrap().clone();
                    self.stack.push(val);
                }
                OpCode::BinOp(op) => {
                    let rhs = self.stack.pop().unwrap().single();
                    let lhs = self.stack.pop().unwrap().single();

                    self.stack.push(VmValue::Single(lhs.bin_op(*op, rhs)));
                }
                OpCode::UnOp(op) => {
                    let val = self.stack.pop().unwrap().single();
                    self.stack.push(VmValue::Single(val.un_op(*op)));
                }
                OpCode::Single => {
                    let val = self.stack.pop().unwrap().single();
                    self.stack.push(VmValue::Single(val));
                }
                OpCode::Combine(count) => {
                    let vals = if *count > 0 {
                        // If last value (first on top of stack) is varargs, extend list
                        let mut vals = match self.stack.pop().unwrap() {
                            VmValue::Single(val) => vec![val],
                            VmValue::Multiple(vec) => vec,
                        };
                        for _ in 1..*count {
                            vals.insert(0, self.stack.pop().unwrap().single());
                        }
                        vals
                    } else {
                        Vec::new()
                    };
                    self.stack.push(VmValue::Multiple(vals));
                }
                OpCode::GetTbl => {
                    let ind = self.stack.pop().unwrap().single();
                    let tbl = self.stack.pop().unwrap().single();

                    match tbl {
                        Value::Table(tbl) => {
                            self.stack.push(VmValue::Single(tbl.borrow().get(ind)))
                        }
                        _ => panic!("could not find value in table"),
                    }
                }
                OpCode::SetTbl => {
                    let tbl = self.stack.pop().unwrap().single();
                    let ind = self.stack.pop().unwrap().single();
                    let val = self.stack.pop().unwrap().single();

                    match tbl {
                        Value::Table(tbl) => tbl.borrow_mut().set(ind, val),
                        _ => panic!(),
                    }
                }
                OpCode::BuildTbl(len) => {
                    let mut tbl = Table::default();

                    for _ in 0..*len {
                        let ind = self.stack.pop().unwrap().single();
                        let val = self.stack.pop().unwrap().single();

                        tbl.set(ind, val);
                    }

                    self.stack
                        .push(VmValue::Single(Value::Table(Rc::new(RefCell::new(tbl)))));
                }
                OpCode::Closure => {
                    let last_frame = self.frames.last().unwrap();
                    let mut func = match self.stack.pop().unwrap().single() {
                        Value::Func(func) => (*func).clone(),
                        _ => panic!("value not a function"),
                    };

                    let mut locals = Vec::new();
                    for local in func.locals.into_iter() {
                        let mapped = match local {
                            VarType::Up(ind) => {
                                let heap = last_frame.locals.get(ind).unwrap();
                                VarType::Up(*heap)
                            }
                            local => local,
                        };
                        locals.push(mapped);
                    }
                    func.locals = locals;

                    self.stack.push(VmValue::Single(Value::Func(Rc::new(func))));
                }
                OpCode::Jump(offset) => {
                    if *offset >= 0 {
                        pc = (pc.0, pc.1 + *offset as usize);
                    } else {
                        let offset = offset.abs() as usize;
                        assert!(pc.1 >= offset);
                        pc = (pc.0, pc.1 - offset);
                    }
                }
                OpCode::JumpIf(offset) => {
                    let val = self.stack.pop().unwrap().single();
                    if val.is_truthy() {
                        if *offset >= 0 {
                            pc = (pc.0, pc.1 + *offset as usize);
                        } else {
                            let offset = offset.abs() as usize;
                            assert!(pc.1 >= offset);
                            pc = (pc.0, pc.1 - offset);
                        }
                    }
                }
                OpCode::JumpIfNot(offset) => {
                    let val = self.stack.pop().unwrap().single();
                    if val.is_falsy() {
                        if *offset >= 0 {
                            pc = (pc.0, pc.1 + *offset as usize);
                        } else {
                            let offset = offset.abs() as usize;
                            assert!(pc.1 >= offset);
                            pc = (pc.0, pc.1 - offset);
                        }
                    }
                }
                OpCode::Goto(count, offset) => {
                    for _ in 0..*count {
                        self.frames.pop().unwrap();
                    }
                    if *offset >= 0 {
                        pc = (pc.0, pc.1 + *offset as usize);
                    } else {
                        let offset = offset.abs() as usize;
                        assert!(pc.1 >= offset);
                        pc = (pc.0, pc.1 - offset);
                    }
                }
                OpCode::ScopeEnter(local_defs) => {
                    let last_frame = self.frames.last();
                    let mut locals = Vec::new();

                    for local_type in local_defs.iter() {
                        let heap = match local_type {
                            VarType::Local => self.heap.alloc(),
                            VarType::Param => unreachable!(),
                            // Always evaluated while upvalues are in scope,
                            // no need for closing here
                            VarType::Up(up) => match last_frame {
                                Some(last_frame) => last_frame.locals[*up],
                                None => panic!(),
                            },
                        };
                        locals.push(heap);
                    }

                    self.frames.push(StackFrame {
                        locals,
                        varargs: vec![],
                        ret: None,
                    });
                }
                OpCode::ScopeLeave => {
                    self.frames.pop().unwrap();
                    if self.frames.is_empty() {
                        return;
                    }
                }
                OpCode::Call(argc) => {
                    let func = match self.stack.pop().unwrap().single() {
                        Value::Func(func) => func,
                        _ => panic!("value not a function"),
                    };

                    // Args list is taken from stack, so in reverse order
                    let mut args_raw = self.stack.split_off(self.stack.len() - argc).into_iter();
                    // Extend argument list if final argument is varargs
                    let mut args = match args_raw.next() {
                        Some(last) => match last {
                            VmValue::Single(val) => vec![val],
                            VmValue::Multiple(mut vec) => {
                                vec.reverse();
                                vec
                            }
                        },
                        None => vec![],
                    };
                    // Take rest of arguments
                    args.extend(args_raw.map(VmValue::single));

                    let mut locals = Vec::new();
                    for local_type in func.locals.iter() {
                        let heap = match local_type {
                            VarType::Local => self.heap.alloc(),
                            VarType::Param => {
                                let heap = self.heap.alloc();
                                if let Some(val) = args.pop() {
                                    self.heap.set(heap, VmValue::Single(val));
                                }
                                heap
                            }
                            VarType::Up(heap) => *heap,
                        };
                        locals.push(heap);
                    }

                    self.frames.push(StackFrame {
                        locals,
                        varargs: if func.varargs {
                            args.into_iter().rev().collect()
                        } else {
                            Vec::new()
                        },
                        ret: Some(pc),
                    });

                    match &func.r#impl {
                        FuncImpl::Builtin { func, .. } => {
                            let ret = func(self).unwrap();
                            self.stack.push(ret);
                            // Popping because there is no return opcode in case of built in function
                            self.frames.pop();
                        }
                        FuncImpl::Defined(label) => {
                            pc = (*label, 0);
                        }
                    }
                }
                OpCode::Return => {
                    // Unroll call stack
                    loop {
                        match self.frames.pop() {
                            Some(frame) => {
                                if let Some(ret) = frame.ret {
                                    pc = ret;
                                    break;
                                }
                            }
                            None => return,
                        }
                    }
                }
            }

            self.out.flush().unwrap();
        }
    }
}
