use crate::ast::{For, ForGen, ForNum};
use crate::compiler::{offset, Compiler, ScopeType};
use crate::vm::{BinOp, OpCode, Oper, Value};

impl<'a> Compiler<'a> {
    /// Compiles a `for` statement, dispatching on its numeric or generic form.
    pub(super) fn compile_for(&mut self, r#for: For) {
        match r#for {
            For::Numeric(f) => self.compile_for_num(f),
            For::Generic(f) => self.compile_for_gen(f),
        }
    }

    /// Compiles a numeric `for` loop (`for name = init, limit [, step] do ... end`).
    ///
    /// `init`, `limit` and `step` (default `1`) are evaluated once, in that
    /// order, before the loop variable comes into scope. The body runs while
    /// `name <= limit` when `step` is non-negative, and while `name >= limit`
    /// when `step` is negative. `step` is added to `name` after each pass.
    ///
    /// A zero step is treated as ascending; unlike reference Lua, no runtime
    /// error is raised for it.
    fn compile_for_num(&mut self, r#for: ForNum) {
        self.scope_enter(ScopeType::Loop);

        // Evaluate init, limit and step BEFORE declaring the control variable.
        // In `for i = i, n` the `i` on the right must name the enclosing
        // binding, not the new (still uninitialized) loop local.
        let ctrl_reg = self.compile_exp(*r#for.init);
        let limit_reg = self.compile_exp(*r#for.limit);
        let step_reg = match r#for.step {
            Some(step) => self.compile_exp(*step),
            None => {
                let dst_reg = self.scopes.reg_reserve();
                self.code.emit(OpCode::Lit {
                    val: Value::int(1),
                    dst_reg,
                });
                dst_reg
            }
        };

        // Declare the control variable and store the initial value in it.
        let ctrl = self.scopes.declare_local(r#for.name, None);
        self.code.emit(OpCode::LocalSet {
            src_reg: ctrl_reg,
            dst_loc: ctrl,
        });

        // The loop direction depends only on the sign of `step`, which never
        // changes, so compute `asc = 0 <= step` once, outside the loop.
        // NOTE(review): assumes `BinOp::Leq` compares an int literal with a
        // float step numerically - confirm against the VM.
        let asc_reg = self.scopes.reg_reserve();
        self.code.emit(OpCode::Lit {
            val: Value::int(0),
            dst_reg: asc_reg,
        });
        self.code.emit(OpCode::BinOp {
            lhs: Oper::Reg(asc_reg),
            rhs: Oper::Reg(step_reg),
            op: BinOp::Leq,
            dst_reg: asc_reg,
        });

        // Loop header: leave `ctrl <= limit` (ascending) or `limit <= ctrl`
        // (descending) in `ctrl_reg`.
        let start = self.code.pos();
        let jump_desc = start;
        self.code.emit(OpCode::Jump { off: 0 }); // Placeholder: -> descending test
        self.code.emit(OpCode::BinOp {
            lhs: Oper::Local(ctrl),
            rhs: Oper::Reg(limit_reg),
            op: BinOp::Leq,
            dst_reg: ctrl_reg,
        });
        let jump_check = self.code.pos();
        self.code.emit(OpCode::Jump { off: 0 }); // Placeholder: skip descending test
        let desc = self.code.pos();
        self.code.emit(OpCode::BinOp {
            lhs: Oper::Reg(limit_reg),
            rhs: Oper::Local(ctrl),
            op: BinOp::Leq,
            dst_reg: ctrl_reg,
        });
        let check = self.code.pos();

        let off = offset(jump_desc, desc);
        self.code.set(
            jump_desc,
            OpCode::JumpIfNot {
                cmp_reg: asc_reg,
                off,
            },
        );
        let off = offset(jump_check, check);
        self.code.set(jump_check, OpCode::Jump { off });

        // Jump to end if the comparison failed; placeholder, patched below.
        let jump_end = self.code.pos();
        self.code.emit(OpCode::Jump { off: 0 });

        // Block
        self.compile_block(*r#for.block);

        // Step: ctrl = ctrl + step
        self.code.emit(OpCode::BinOp {
            lhs: Oper::Local(ctrl),
            rhs: Oper::Reg(step_reg),
            op: BinOp::Add,
            dst_reg: ctrl_reg,
        });
        self.code.emit(OpCode::LocalSet {
            src_reg: ctrl_reg,
            dst_loc: ctrl,
        });

        let off = offset(self.code.pos(), start);
        self.code.emit(OpCode::Jump { off });

        // Fill in jump to end
        let off = offset(jump_end, self.code.pos());
        self.code.set(
            jump_end,
            OpCode::JumpIfNot {
                cmp_reg: ctrl_reg, // At the placeholder, ctrl_reg holds the comparison result
                off,
            },
        );

        self.scopes.reg_free(ctrl_reg);
        self.scopes.reg_free(limit_reg);
        self.scopes.reg_free(step_reg);
        self.scopes.reg_free(asc_reg);
        self.scope_leave(ScopeType::Loop);
    }

    /// Compiles a generic `for` loop (`for v1, v2, ... in explist do ... end`).
    ///
    /// The expression list is evaluated once. Its values at indices 0, 1 and 2
    /// are stored as the iterator function, the state, and the initial control
    /// value (kept in the first loop variable). Each iteration calls
    /// `iter(state, ctrl)` and assigns the results to the loop variables. The
    /// loop exits when the first variable is `nil`.
    fn compile_for_gen(&mut self, r#for: ForGen) {
        self.scope_enter(ScopeType::Loop);

        // Evaluate the expression list BEFORE declaring the loop variables, so
        // a name in it that matches a loop variable resolves to the enclosing
        // binding rather than the not-yet-initialized loop local.
        let exp_reg = self.compile_exp(*r#for.exp);
        let args_reg = self.scopes.reg_reserve();

        // Declare the user-visible loop variables first, then the hidden locals
        // for the iterator function and state. The "--" prefix presumably keeps
        // them from clashing with user identifiers.
        let vars: Vec<_> = r#for
            .names
            .into_iter()
            .map(|name| self.scopes.declare_local(name, None))
            .collect();
        let ctrl = *vars
            .first()
            .expect("a generic for loop declares at least one variable");
        let iter = self.scopes.declare_local("--iter".to_string(), None);
        let state = self.scopes.declare_local("--state".to_string(), None);

        // Unpack iterator function, state and initial control value.
        for (ind, dst_loc) in [(0, iter), (1, state), (2, ctrl)] {
            self.code.emit(OpCode::MovMult {
                src_reg: exp_reg,
                ind,
                dst_reg: args_reg,
            });
            self.code.emit(OpCode::LocalSet {
                src_reg: args_reg,
                dst_loc,
            });
        }

        // Start of loop
        let start = self.code.pos();

        // Build the argument list (state, ctrl) and call the iterator.
        self.code.emit(OpCode::Lit {
            val: Value::empty(),
            dst_reg: args_reg,
        });
        for src_loc in [state, ctrl] {
            self.code.emit(OpCode::LocalGet {
                src_loc,
                dst_reg: exp_reg,
            });
            self.code.emit(OpCode::Append {
                src_reg: exp_reg,
                dst_reg: args_reg,
            });
        }
        self.code.emit(OpCode::LocalGet {
            src_loc: iter,
            dst_reg: exp_reg,
        });
        self.code.emit(OpCode::Call {
            pos: r#for.pos,
            func_reg: exp_reg,
            args_reg,
            ret_reg: exp_reg,
        });

        // Assign the returned values to the loop variables, in order.
        for (ind, var) in vars.into_iter().enumerate() {
            self.code.emit(OpCode::MovMult {
                src_reg: exp_reg,
                ind,
                dst_reg: args_reg,
            });
            self.code.emit(OpCode::LocalSet {
                src_reg: args_reg,
                dst_loc: var,
            });
        }

        // Compute `ctrl == nil` into exp_reg.
        self.code.emit(OpCode::LocalGet {
            src_loc: ctrl,
            dst_reg: exp_reg,
        });
        self.code.emit(OpCode::Lit {
            val: Value::Nil,
            dst_reg: args_reg,
        });
        self.code.emit(OpCode::BinOp {
            lhs: Oper::Reg(exp_reg),
            rhs: Oper::Reg(args_reg),
            op: BinOp::Eq,
            dst_reg: exp_reg,
        });

        // Jump to end if the control variable is nil; placeholder, patched below.
        let jump_end = self.code.pos();
        self.code.emit(OpCode::Jump { off: 0 });

        // Block
        self.compile_block(*r#for.block);

        // Jump back to start
        let off = offset(self.code.pos(), start);
        self.code.emit(OpCode::Jump { off });

        // Fill in jump to end; at the placeholder, exp_reg holds `ctrl == nil`.
        let off = offset(jump_end, self.code.pos());
        self.code.set(
            jump_end,
            OpCode::JumpIf {
                cmp_reg: exp_reg,
                off,
            },
        );

        self.scopes.reg_free(exp_reg);
        self.scopes.reg_free(args_reg);
        self.scope_leave(ScopeType::Loop);
    }
}
