use std::collections::HashMap;

use crossbeam_utils::atomic::AtomicCell;
use assemblylift_core_io_common::constants::{FUNCTION_INPUT_BUFFER_SIZE, IO_BUFFER_SIZE_BYTES};

use crate::threader::ThreaderEnv;

/// Host-side byte buffer that is loaded wholesale and then edited by byte offset.
pub trait LinearBuffer {
    /// Replaces the current contents with `buffer`, taking ownership of it.
    fn initialize(&mut self, buffer: Vec<u8>);

    /// Writes `bytes` into the buffer starting at byte `at_offset`.
    ///
    /// Returns the number of bytes actually written.
    fn write(&mut self, bytes: &[u8], at_offset: usize) -> usize;

    /// Zeroes up to `len` bytes beginning at byte `offset`.
    ///
    /// Returns the number of bytes actually erased.
    fn erase(&mut self, offset: usize, len: usize) -> usize;

    /// Number of bytes currently held.
    fn len(&self) -> usize;

    /// Allocated capacity of the underlying storage, in bytes.
    fn capacity(&self) -> usize;
}

/// A host buffer whose bytes can be copied into a WASM guest's linear memory.
pub trait WasmBuffer {
    /// Copies the host region described by `src` into the guest region described by `dst`.
    ///
    /// The exact meaning of the `src`/`dst` tuples is defined by each implementor.
    fn copy_to_wasm(
        &self,
        env: &ThreaderEnv,
        src: (usize, usize),
        dst: (usize, usize),
    ) -> Result<(), ()>;
}

/// A [`WasmBuffer`] that is streamed to the guest one fixed-size page at a time.
pub trait PagedWasmBuffer: WasmBuffer {
    /// Copies the first page into guest memory and resets the paging cursor.
    ///
    /// `offset` is an implementor-specific selector and may be ignored.
    /// Returns an implementor-defined status code.
    fn first(&mut self, env: &ThreaderEnv, offset: Option<usize>) -> i32;

    /// Copies the page after the current one into guest memory.
    ///
    /// Returns an implementor-defined status code.
    fn next(&mut self, env: &ThreaderEnv) -> i32;
}

/// Host-side copy of a function's input, paged into the guest in
/// `FUNCTION_INPUT_BUFFER_SIZE`-byte chunks.
#[derive(Debug, Default)]
pub struct FunctionInputBuffer {
    /// The complete input bytes.
    buffer: Vec<u8>,
    /// Index of the page most recently copied into guest memory.
    page_idx: usize,
}

impl FunctionInputBuffer {
    /// Creates an empty buffer positioned at page 0 (same as `Default::default()`).
    pub fn new() -> Self {
        Self::default()
    }
}

impl LinearBuffer for FunctionInputBuffer {
    /// Replaces the stored input with `buffer`.
    fn initialize(&mut self, buffer: Vec<u8>) {
        self.buffer = buffer;
    }

    /// Copies `bytes` into the buffer starting at `at_offset`, without growing it.
    ///
    /// Only the bytes that fit before the current end of the buffer are written. Returns
    /// the number written, which is 0 when `at_offset` is at or past the end.
    fn write(&mut self, bytes: &[u8], at_offset: usize) -> usize {
        let tail = match self.buffer.get_mut(at_offset..) {
            Some(tail) => tail,
            None => return 0,
        };
        let count = tail.len().min(bytes.len());
        tail[..count].copy_from_slice(&bytes[..count]);
        count
    }

    /// Zeroes `len` bytes starting at `offset`, clamped to the end of the buffer.
    ///
    /// Returns the number of bytes actually zeroed.
    fn erase(&mut self, offset: usize, len: usize) -> usize {
        let buffer_len = self.buffer.len();
        let start = offset.min(buffer_len);
        // saturating_add: a huge `len` must clamp rather than overflow.
        let end = offset.saturating_add(len).min(buffer_len);
        self.buffer[start..end].fill(0);
        end - start
    }

    /// Number of input bytes currently stored.
    fn len(&self) -> usize {
        self.buffer.len()
    }

    /// Allocated capacity of the backing `Vec`.
    fn capacity(&self) -> usize {
        self.buffer.capacity()
    }
}

impl PagedWasmBuffer for FunctionInputBuffer {
    /// Copies page 0 (at most `FUNCTION_INPUT_BUFFER_SIZE` bytes) into the guest's
    /// function-input buffer and resets the page cursor to 0.
    ///
    /// `_offset` is ignored. Always returns 0.
    ///
    /// NOTE(review): when the input is shorter than a page, guest bytes past the copied
    /// range are not cleared.
    ///
    /// # Panics
    ///
    /// Panics if the copy into guest memory fails.
    fn first(&mut self, env: &ThreaderEnv, _offset: Option<usize>) -> i32 {
        let end = self.buffer.len().min(FUNCTION_INPUT_BUFFER_SIZE);
        self.copy_to_wasm(env, (0usize, end), (0usize, FUNCTION_INPUT_BUFFER_SIZE))
            .unwrap();
        self.page_idx = 0usize;
        0
    }

    /// Copies the next page of input into the guest's function-input buffer.
    ///
    /// If no bytes remain after the current page, this is a no-op and the cursor is not
    /// advanced. Paging past the end used to produce an inverted slice range and panic.
    /// Always returns 0.
    ///
    /// # Panics
    ///
    /// Panics if the copy into guest memory fails.
    fn next(&mut self, env: &ThreaderEnv) -> i32 {
        let next_idx = self.page_idx + 1;
        let start = FUNCTION_INPUT_BUFFER_SIZE * next_idx;
        if start < self.buffer.len() {
            let end = (start + FUNCTION_INPUT_BUFFER_SIZE).min(self.buffer.len());
            self.copy_to_wasm(env, (start, end), (0usize, FUNCTION_INPUT_BUFFER_SIZE))
                .unwrap();
            self.page_idx = next_idx;
        }
        0
    }
}

impl WasmBuffer for FunctionInputBuffer {
    /// Copies the host bytes `self.buffer[src.0..src.1]` into the guest's function-input
    /// buffer.
    ///
    /// `src` is a `(start, end)` range of the host buffer. `dst` is forwarded to the guest
    /// pointer's `deref` as `(index, length)`. The returned slice is indexed from 0 here.
    /// NOTE(review): this assumes `deref` already offsets the slice by `dst.0`, as wasmer's
    /// `WasmPtr<_, Array>::deref` does; confirm against the wasmer version in use.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` in any of these cases:
    /// - guest memory or the `get_function_input_buffer` export is unavailable, or the
    ///   export call fails;
    /// - a `dst` component does not fit in `u32`, or `deref` rejects the region;
    /// - `src` is not a valid range of the host buffer;
    /// - the source bytes do not fit in the destination slice.
    fn copy_to_wasm(
        &self,
        env: &ThreaderEnv,
        src: (usize, usize),
        dst: (usize, usize),
    ) -> Result<(), ()> {
        use std::convert::TryFrom;

        let wasm_memory = env.memory_ref().ok_or(())?;
        let input_buffer = env
            .get_function_input_buffer
            .get_ref()
            .ok_or(())?
            .call()
            .map_err(|_| ())?;
        // Reject instead of silently truncating with `as u32`.
        let dst_index = u32::try_from(dst.0).map_err(|_| ())?;
        let dst_len = u32::try_from(dst.1).map_err(|_| ())?;
        let memory_writer: &[AtomicCell<u8>] = input_buffer
            .deref(&wasm_memory, dst_index, dst_len)
            .ok_or(())?;

        let bytes = self.buffer.get(src.0..src.1).ok_or(())?;
        if bytes.len() > memory_writer.len() {
            return Err(());
        }
        for (cell, byte) in memory_writer.iter().zip(bytes) {
            cell.store(*byte);
        }

        Ok(())
    }
}

/// A set of host-side byte buffers keyed by I/O id, paged into the guest in
/// `IO_BUFFER_SIZE_BYTES`-byte chunks.
#[derive(Debug, Default)]
pub struct IoBuffer {
    /// I/O id of the buffer currently being paged into the guest.
    active_buffer: usize,
    /// I/O id -> accumulated bytes.
    buffers: HashMap<usize, Vec<u8>>,
    /// I/O id -> index of the page most recently copied into the guest.
    page_indices: HashMap<usize, usize>,
}

impl IoBuffer {
    /// Creates an empty set of buffers (same as `Default::default()`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the number of bytes stored for `ioid`.
    ///
    /// An id that has never been written has length 0, matching the way `write` treats a
    /// missing buffer as empty.
    pub fn len(&self, ioid: usize) -> usize {
        self.buffers.get(&ioid).map_or(0, Vec::len)
    }

    /// Pre-creates buffers for ids `0..num_buffers`, each with `buffer_capacity` bytes
    /// reserved and its page index set to 0.
    pub fn with_capacity(num_buffers: usize, buffer_capacity: usize) -> Self {
        let buffers = (0..num_buffers)
            .map(|idx| (idx, Vec::with_capacity(buffer_capacity)))
            .collect();
        let page_indices = (0..num_buffers).map(|idx| (idx, 0)).collect();
        Self {
            active_buffer: 0usize,
            buffers,
            page_indices,
        }
    }

    /// Appends `bytes` to the buffer for `ioid`, creating the buffer if needed.
    ///
    /// Returns the number of bytes appended (always `bytes.len()`).
    pub fn write(&mut self, ioid: usize, bytes: &[u8]) -> usize {
        self.buffers
            .entry(ioid)
            .or_default()
            .extend_from_slice(bytes);
        bytes.len()
    }
}

impl PagedWasmBuffer for IoBuffer {
    /// Selects buffer `offset` (default 0) as the active I/O buffer, resets its page
    /// index to 0, and copies its first page into the guest's I/O buffer.
    ///
    /// Always returns 0.
    ///
    /// # Panics
    ///
    /// Panics if the copy into guest memory fails, for example when no buffer exists
    /// for the selected id.
    fn first(&mut self, env: &ThreaderEnv, offset: Option<usize>) -> i32 {
        self.active_buffer = offset.unwrap_or(0);
        self.page_indices.insert(self.active_buffer, 0usize);

        self.copy_to_wasm(
            env,
            (self.active_buffer, 0usize),
            (0usize, IO_BUFFER_SIZE_BYTES),
        )
        .unwrap();
        0
    }

    /// Copies the next page of the active buffer into the guest's I/O buffer.
    ///
    /// A missing page index is treated as page 0. If the next page would start at or past
    /// the end of the buffer, this is a no-op and the page index is not advanced.
    /// Always returns 0.
    ///
    /// # Panics
    ///
    /// Panics if the copy into guest memory fails.
    fn next(&mut self, env: &ThreaderEnv) -> i32 {
        let ioid = self.active_buffer;
        let page_idx = self.page_indices.get(&ioid).copied().unwrap_or(0) + 1;
        let page_offset = page_idx * IO_BUFFER_SIZE_BYTES;
        let buffer_len = self.buffers.get(&ioid).map_or(0, Vec::len);

        if page_offset < buffer_len {
            self.copy_to_wasm(env, (ioid, page_offset), (0usize, IO_BUFFER_SIZE_BYTES))
                .unwrap();
            self.page_indices.insert(ioid, page_idx);
        }
        0
    }
}

impl WasmBuffer for IoBuffer {
    /// Copies up to `IO_BUFFER_SIZE_BYTES` bytes of buffer `src.0`, starting at byte
    /// `src.1`, into the guest's I/O buffer.
    ///
    /// `src` is `(ioid, byte_offset)`. `dst` is forwarded to the guest pointer's `deref` as
    /// `(index, length)`, and the returned slice is written from its first cell.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` in any of these cases:
    /// - guest memory or the `get_io_buffer` export is unavailable, or the export call
    ///   fails;
    /// - a `dst` component does not fit in `u32`, or `deref` rejects the region;
    /// - no buffer exists for `src.0`, or `src.1` is past its end;
    /// - the chunk does not fit in the destination slice.
    fn copy_to_wasm(
        &self,
        env: &ThreaderEnv,
        src: (usize, usize),
        dst: (usize, usize),
    ) -> Result<(), ()> {
        use std::convert::TryFrom;

        let wasm_memory = env.memory_ref().ok_or(())?;
        let io_buffer = env
            .get_io_buffer
            .get_ref()
            .ok_or(())?
            .call()
            .map_err(|_| ())?;
        // Reject instead of silently truncating with `as u32`.
        let dst_index = u32::try_from(dst.0).map_err(|_| ())?;
        let dst_len = u32::try_from(dst.1).map_err(|_| ())?;
        let memory_writer: &[AtomicCell<u8>] = io_buffer
            .deref(&wasm_memory, dst_index, dst_len)
            .ok_or(())?;

        let (ioid, offset) = src;
        let remaining = self
            .buffers
            .get(&ioid)
            .and_then(|buffer| buffer.get(offset..))
            .ok_or(())?;
        let chunk = &remaining[..remaining.len().min(IO_BUFFER_SIZE_BYTES)];
        if chunk.len() > memory_writer.len() {
            return Err(());
        }
        for (cell, byte) in memory_writer.iter().zip(chunk) {
            cell.store(*byte);
        }

        Ok(())
    }
}
