use super::{Csr, Fence, Imm, Instruction, Register};

use std::{
    io::{self, Read},
    sync::atomic::Ordering,
};

/// Decodes a stream of little-endian 32-bit RISC-V instruction words read from `T`.
///
/// Construct one with [`Decoder::new`] and consume it as an [`Iterator`] of
/// decoded instructions. The `Read` requirement lives on the impls that
/// actually read, not on the type itself.
#[derive(Debug)]
pub struct Decoder<T>(T);

impl<T: Read> Decoder<T> {
    /// Creates a decoder that reads instruction words from `inner`.
    pub fn new(inner: T) -> Self {
        Self(inner)
    }

    /// Reads one little-endian 32-bit instruction word.
    ///
    /// Always consumes exactly four bytes, so 16-bit compressed encodings are
    /// not handled by this decoder.
    ///
    /// # Errors
    ///
    /// Returns the error from [`Read::read_exact`]; in particular
    /// [`io::ErrorKind::UnexpectedEof`] when fewer than four bytes remain.
    fn read_u32(&mut self) -> io::Result<u32> {
        let mut buf = [0u8; 4];
        self.0.read_exact(&mut buf)?;
        Ok(u32::from_le_bytes(buf))
    }

    /// Destination register field `rd` (bits 11:7).
    fn rd(inst: u32) -> Register {
        Register::from(((inst >> 7) & 0b11111) as u8)
    }

    /// First source register field `rs1` (bits 19:15).
    fn rs1(inst: u32) -> Register {
        Register::from(((inst >> 15) & 0b11111) as u8)
    }

    /// Second source register field `rs2` (bits 24:20).
    fn rs2(inst: u32) -> Register {
        Register::from(((inst >> 20) & 0b11111) as u8)
    }

    /// CSR address field (bits 31:20).
    fn csr(inst: u32) -> Csr {
        Csr::from((inst >> 20) as u16)
    }

    /// Decodes an R-type instruction (`rd`, `rs1`, `rs2`) and builds it with `f`.
    fn parse_r(
        inst: u32,
        f: impl FnOnce(Register, Register, Register) -> Instruction,
    ) -> Instruction {
        f(Self::rd(inst), Self::rs1(inst), Self::rs2(inst))
    }

    /// Decodes an RV64 immediate shift (`slli`/`srli`/`srai`).
    ///
    /// RV64 widens the shift amount to six bits (bits 25:20), so bit 25 belongs
    /// to `shamt` rather than to `funct7`.
    fn parse_shift_64(
        inst: u32,
        f: impl FnOnce(Register, Register, u8) -> Instruction,
    ) -> Instruction {
        let shamt = ((inst >> 20) & 0b11_1111) as u8;
        f(Self::rd(inst), Self::rs1(inst), shamt)
    }

    /// Decodes a 32-bit-word immediate shift (`slliw`/`srliw`/`sraiw`) whose
    /// shift amount is the five-bit field in bits 24:20.
    fn parse_shift(
        inst: u32,
        f: impl FnOnce(Register, Register, u8) -> Instruction,
    ) -> Instruction {
        let shamt = ((inst >> 20) & 0b11111) as u8;
        f(Self::rd(inst), Self::rs1(inst), shamt)
    }

    /// Decodes a register-operand CSR instruction: `rd`, `rs1` and the CSR address.
    fn parse_csr_reg(
        inst: u32,
        f: impl FnOnce(Register, Register, Csr) -> Instruction,
    ) -> Instruction {
        f(Self::rd(inst), Self::rs1(inst), Self::csr(inst))
    }

    /// Decodes an immediate-operand CSR instruction: `rd`, the five-bit
    /// zero-extended immediate in bits 19:15, and the CSR address.
    fn parse_csr_imm(inst: u32, f: impl FnOnce(Register, u32, Csr) -> Instruction) -> Instruction {
        let imm = (inst >> 15) & 0b11111;

        f(Self::rd(inst), imm, Self::csr(inst))
    }

    /// Decodes a U-type instruction: `rd` and the upper 20 bits of the word,
    /// left in place with the low 12 bits cleared.
    fn parse_u(inst: u32, f: impl FnOnce(Register, u32) -> Instruction) -> Instruction {
        let imm = inst & !0xfff;
        f(Self::rd(inst), imm)
    }

    /// Decodes an I-type instruction: `rd`, `rs1` and the sign-extended 12-bit
    /// immediate in bits 31:20.
    fn parse_i(inst: u32, f: impl FnOnce(Register, Register, Imm) -> Instruction) -> Instruction {
        let imm = inst >> 20;

        // Sign-extend from bit 11.
        let imm = ((imm as i32) << 20) >> 20;
        f(Self::rd(inst), Self::rs1(inst), imm)
    }

    /// Decodes an S-type instruction: `rs1`, `rs2` and the sign-extended 12-bit
    /// immediate split across bits 31:25 (imm[11:5]) and 11:7 (imm[4:0]).
    fn parse_s(inst: u32, f: impl FnOnce(Register, Register, Imm) -> Instruction) -> Instruction {
        let imm = ((inst >> 20) & !0b11111) | ((inst >> 7) & 0b11111);

        // Sign-extend from bit 11.
        let imm = ((imm as i32) << 20) >> 20;

        f(Self::rs1(inst), Self::rs2(inst), imm)
    }

    /// Decodes a B-type instruction: `rs1`, `rs2` and the sign-extended 13-bit
    /// branch offset (bit 0 is always zero).
    fn parse_b(inst: u32, f: impl FnOnce(Register, Register, Imm) -> Instruction) -> Instruction {
        let mut imm = 0;
        imm |= (inst >> 19) & 0xFFFF_F000; // imm[12]   <- inst[31]
        imm |= (inst << 4) & 0x0000_0800; // imm[11]   <- inst[7]
        imm |= (inst >> 20) & 0x0000_07E0; // imm[10:5] <- inst[30:25]
        imm |= (inst >> 7) & 0x0000_001E; // imm[4:1]  <- inst[11:8]

        // Sign-extend from bit 12.
        let imm = ((imm as i32) << 19) >> 19;
        f(Self::rs1(inst), Self::rs2(inst), imm)
    }

    /// Decodes a J-type instruction: `rd` and the sign-extended 21-bit jump
    /// offset (bit 0 is always zero).
    fn parse_j(inst: u32, f: impl FnOnce(Register, Imm) -> Instruction) -> Instruction {
        let mut imm = 0;
        imm |= (inst >> 11) & 0xFFF0_0000; // imm[20]    <- inst[31]
        imm |= inst & 0x000F_F000; // imm[19:12] <- inst[19:12]
        imm |= (inst >> 9) & 0x0000_0800; // imm[11]    <- inst[20]
        imm |= (inst >> 20) & 0x0000_07FE; // imm[10:1]  <- inst[30:21]

        // Sign-extend from bit 20.
        let imm = ((imm as i32) << 11) >> 11;
        f(Self::rd(inst), imm)
    }

    /// Decodes an A-extension instruction: `rd`, `rs1`, `rs2`, and the memory
    /// ordering derived from the `aq` (bit 26) and `rl` (bit 25) flags.
    fn parse_atomic(
        inst: u32,
        f: impl FnOnce(Register, Register, Register, Ordering) -> Instruction,
    ) -> Instruction {
        let aq = ((inst >> 26) & 0b1) == 0b1;
        let rl = ((inst >> 25) & 0b1) == 0b1;

        let ordering = match (aq, rl) {
            (true, true) => Ordering::SeqCst,
            (true, false) => Ordering::Acquire,
            (false, true) => Ordering::Release,
            (false, false) => Ordering::Relaxed,
        };

        f(Self::rd(inst), Self::rs1(inst), Self::rs2(inst), ordering)
    }

    /// Decodes a single 32-bit instruction word.
    ///
    /// # Panics
    ///
    /// Panics (via `todo!`) on any encoding this decoder does not support yet.
    fn decode(inst: u32) -> Instruction {
        let opcode = inst & 0b111_1111;
        let funct3 = (inst >> 12) & 0b111;
        // Only meaningful for AMO instructions (bits 31:27).
        let funct5 = (inst >> 27) & 0b1_1111;
        let funct7 = (inst >> 25) & 0b111_1111;

        match opcode {
            // LUI / AUIPC
            0b0110111 => Self::parse_u(inst, Instruction::Lui),
            0b0010111 => Self::parse_u(inst, Instruction::AuiPC),

            // JAL / JALR
            0b1101111 => Self::parse_j(inst, Instruction::Jal),
            0b1100111 if funct3 == 0b000 => Self::parse_i(inst, Instruction::Jalr),

            // BRANCH
            0b1100011 => match funct3 {
                0b000 => Self::parse_b(inst, Instruction::Beq),
                0b001 => Self::parse_b(inst, Instruction::Bne),
                0b100 => Self::parse_b(inst, Instruction::Blt),
                0b101 => Self::parse_b(inst, Instruction::Bge),
                0b110 => Self::parse_b(inst, Instruction::Bltu),
                0b111 => Self::parse_b(inst, Instruction::Bgeu),
                _ => todo!(),
            },

            // LOAD
            0b0000011 => match funct3 {
                0b000 => Self::parse_i(inst, Instruction::Lb),
                0b001 => Self::parse_i(inst, Instruction::Lh),
                0b010 => Self::parse_i(inst, Instruction::Lw),
                0b011 => Self::parse_i(inst, Instruction::Ld),
                0b100 => Self::parse_i(inst, Instruction::Lbu),
                0b101 => Self::parse_i(inst, Instruction::Lhu),
                0b110 => Self::parse_i(inst, Instruction::Lwu),
                _ => todo!(),
            },

            // STORE
            0b0100011 => match funct3 {
                0b000 => Self::parse_s(inst, Instruction::Sb),
                0b001 => Self::parse_s(inst, Instruction::Sh),
                0b010 => Self::parse_s(inst, Instruction::Sw),
                0b011 => Self::parse_s(inst, Instruction::Sd),
                _ => todo!(),
            },

            // OP-IMM
            0b0010011 => match (funct3, funct7) {
                (0b000, _) => Self::parse_i(inst, Instruction::Addi),
                (0b010, _) => Self::parse_i(inst, Instruction::Slti),
                (0b011, _) => Self::parse_i(inst, Instruction::Sltiu),
                (0b100, _) => Self::parse_i(inst, Instruction::Xori),
                (0b110, _) => Self::parse_i(inst, Instruction::ORI),
                (0b111, _) => Self::parse_i(inst, Instruction::ANDI),

                // RV64 shifts take a six-bit shamt in bits 25:20, so funct7's low
                // bit is shamt[5] and may be either value. SRAI shares funct3
                // 0b101 with SRLI and is distinguished by bit 30.
                (0b001, 0b0000000 | 0b0000001) => Self::parse_shift_64(inst, Instruction::SLLI),
                (0b101, 0b0000000 | 0b0000001) => Self::parse_shift_64(inst, Instruction::SRLI),
                (0b101, 0b0100000 | 0b0100001) => Self::parse_shift_64(inst, Instruction::SRAI),
                _ => todo!(),
            },

            // OP (base integer + M extension)
            0b0110011 => match (funct3, funct7) {
                (0b000, 0b0000000) => Self::parse_r(inst, Instruction::ADD),
                (0b000, 0b0100000) => Self::parse_r(inst, Instruction::SUB),
                (0b001, 0b0000000) => Self::parse_r(inst, Instruction::SLL),
                (0b010, 0b0000000) => Self::parse_r(inst, Instruction::SLT),
                (0b011, 0b0000000) => Self::parse_r(inst, Instruction::SLTU),
                (0b100, 0b0000000) => Self::parse_r(inst, Instruction::XOR),
                (0b101, 0b0000000) => Self::parse_r(inst, Instruction::SRL),
                (0b101, 0b0100000) => Self::parse_r(inst, Instruction::SRA),
                (0b110, 0b0000000) => Self::parse_r(inst, Instruction::OR),
                (0b111, 0b0000000) => Self::parse_r(inst, Instruction::AND),

                (0b000, 0b0000001) => Self::parse_r(inst, Instruction::Mul),
                (0b001, 0b0000001) => Self::parse_r(inst, Instruction::Mulh),
                (0b010, 0b0000001) => Self::parse_r(inst, Instruction::Mulhsu),
                (0b011, 0b0000001) => Self::parse_r(inst, Instruction::Mulhu),
                (0b100, 0b0000001) => Self::parse_r(inst, Instruction::Div),
                (0b101, 0b0000001) => Self::parse_r(inst, Instruction::Divu),
                (0b110, 0b0000001) => Self::parse_r(inst, Instruction::Rem),
                (0b111, 0b0000001) => Self::parse_r(inst, Instruction::Remu),

                _ => todo!(),
            },

            // OP-32 (word-sized base integer + M extension)
            0b0111011 => match (funct3, funct7) {
                (0b000, 0b0000000) => Self::parse_r(inst, Instruction::ADDW),
                (0b000, 0b0100000) => Self::parse_r(inst, Instruction::SUBW),
                (0b001, 0b0000000) => Self::parse_r(inst, Instruction::SLLW),
                (0b101, 0b0000000) => Self::parse_r(inst, Instruction::SRLW),
                (0b101, 0b0100000) => Self::parse_r(inst, Instruction::SRAW),

                (0b000, 0b0000001) => Self::parse_r(inst, Instruction::Mulw),
                (0b100, 0b0000001) => Self::parse_r(inst, Instruction::Divw),
                (0b101, 0b0000001) => Self::parse_r(inst, Instruction::Divuw),
                (0b110, 0b0000001) => Self::parse_r(inst, Instruction::Remw),
                (0b111, 0b0000001) => Self::parse_r(inst, Instruction::Remuw),
                _ => todo!(),
            },

            // OP-IMM-32 (word-sized shifts keep a five-bit shamt)
            0b0011011 => match (funct3, funct7) {
                (0b000, _) => Self::parse_i(inst, Instruction::Addiw),

                (0b001, 0b0000000) => Self::parse_shift(inst, Instruction::SLLIW),
                (0b101, 0b0000000) => Self::parse_shift(inst, Instruction::SRLIW),
                (0b101, 0b0100000) => Self::parse_shift(inst, Instruction::SRAIW),

                _ => todo!(),
            },

            // MISC-MEM
            0b0001111 => Instruction::FENCE(Fence::decode(inst)),

            // SYSTEM
            0b1110011 => match funct3 {
                // Compare the whole 12-bit immediate. Other funct3 == 0 SYSTEM
                // instructions (e.g. MRET, SRET, WFI) must not alias ECALL/EBREAK.
                0b000 => match inst >> 20 {
                    0 => Instruction::ECALL,
                    1 => Instruction::EBREAK,
                    _ => todo!(),
                },

                0b001 => Self::parse_csr_reg(inst, Instruction::CsrRw),
                0b010 => Self::parse_csr_reg(inst, Instruction::CsrRs),
                0b011 => Self::parse_csr_reg(inst, Instruction::CsrRc),

                0b101 => Self::parse_csr_imm(inst, Instruction::CsrRwi),
                0b110 => Self::parse_csr_imm(inst, Instruction::CsrRsi),
                0b111 => Self::parse_csr_imm(inst, Instruction::CsrRci),

                _ => todo!(),
            },

            // AMO: funct3 selects the width (0b010 = W, 0b011 = D), funct5 the operation.
            0b0101111 => match (funct3, funct5) {
                (0b010, 0b00010) => Self::parse_atomic(inst, Instruction::LrW),
                (0b010, 0b00011) => Self::parse_atomic(inst, Instruction::ScW),
                (0b010, 0b00001) => Self::parse_atomic(inst, Instruction::AMOSwapW),
                (0b010, 0b00000) => Self::parse_atomic(inst, Instruction::AMOAddW),
                (0b010, 0b00100) => Self::parse_atomic(inst, Instruction::AMOXorW),
                (0b010, 0b01100) => Self::parse_atomic(inst, Instruction::AMOAndW),
                (0b010, 0b01000) => Self::parse_atomic(inst, Instruction::AMOOrW),
                (0b010, 0b10000) => Self::parse_atomic(inst, Instruction::AMOMinW),
                (0b010, 0b10100) => Self::parse_atomic(inst, Instruction::AMOMaxW),
                (0b010, 0b11000) => Self::parse_atomic(inst, Instruction::AMOMinUW),
                (0b010, 0b11100) => Self::parse_atomic(inst, Instruction::AMOMaxUW),

                (0b011, 0b00010) => Self::parse_atomic(inst, Instruction::LrD),
                (0b011, 0b00011) => Self::parse_atomic(inst, Instruction::ScD),
                (0b011, 0b00001) => Self::parse_atomic(inst, Instruction::AMOSwapD),
                (0b011, 0b00000) => Self::parse_atomic(inst, Instruction::AMOAddD),
                (0b011, 0b00100) => Self::parse_atomic(inst, Instruction::AMOXorD),
                (0b011, 0b01100) => Self::parse_atomic(inst, Instruction::AMOAndD),
                (0b011, 0b01000) => Self::parse_atomic(inst, Instruction::AMOOrD),
                (0b011, 0b10000) => Self::parse_atomic(inst, Instruction::AMOMinD),
                (0b011, 0b10100) => Self::parse_atomic(inst, Instruction::AMOMaxD),
                (0b011, 0b11000) => Self::parse_atomic(inst, Instruction::AMOMinUD),
                (0b011, 0b11100) => Self::parse_atomic(inst, Instruction::AMOMaxUD),

                _ => todo!(),
            },

            // TODO: LOAD-FP (opcode 0b0000111: flw/fld/flh/flq) is not decoded yet.
            // Those loads are I-type encodings, so they should go through `parse_i`.
            _ => todo!("{inst:#x?}: {opcode:07b}, {funct3:03b}, {funct7:07b}"),
        }
    }
}

impl<T: Read> Iterator for Decoder<T> {
    type Item = Instruction;

    /// Reads and decodes the next instruction word.
    ///
    /// Returns `None` once the reader is exhausted. A trailing partial word
    /// (1-3 bytes) is also treated as end of input, because `read_exact`
    /// reports it as `UnexpectedEof`.
    ///
    /// # Panics
    ///
    /// Panics on I/O errors other than end of input. Also panics (via `todo!`)
    /// on encodings the decoder does not support yet.
    fn next(&mut self) -> Option<Self::Item> {
        let inst = match self.read_u32() {
            Ok(inst) => inst,
            Err(err) if err.kind() == io::ErrorKind::UnexpectedEof => return None,
            Err(err) => panic!("failed to read instruction word: {err}"),
        };
        Some(Self::decode(inst))
    }
}
