use core::time;
use std::{cmp::Ordering, collections::VecDeque, error::Error, fs::File, io::{BufRead, BufReader}, thread};
use itertools::Itertools;
use multiqueue::{BroadcastReceiver, BroadcastSender, broadcast_queue};
use rayon::prelude::*;
use rand::{Rng, prelude::SliceRandom};

/// Iterator-style dataloader that reads lines from a list of files, converts
/// them into items of type `D` on a background thread, and yields them in
/// batches (`Vec<D>`).
///
/// `D` is the loaded item type; `S` is optional caller-supplied state that is
/// passed to every call of the loading function.
pub struct ThreadedDataloader<D: Sync + Send + Clone, S: Sync + Send + Clone + 'static> {
    /// Files to read; their lines are indexed as one concatenated sequence.
    paths: Vec<String>,
    /// First global line index that belongs to the dataset.
    start_index: usize,
    /// Exclusive upper bound on the global line index.
    end_index: usize,
    /// Number of lines handled by a single background load.
    chunk_size: usize,
    /// Number of items per yielded batch.
    batch_size: usize,
    /// One flag per chunk; set once the chunk has been handed to a loader in the
    /// current pass and cleared again when a pass ends.
    chunks_processed: Vec<bool>,
    /// Sending end of the queue the loader thread pushes finished batches into.
    post_buffer_send: Option<BroadcastSender<Vec<D>>>,
    /// Receiving end of that queue, drained by `next`.
    post_buffer_recv: Option<BroadcastReceiver<Vec<D>>>,
    /// Batches already taken off the queue and waiting to be yielded.
    post_buffer: VecDeque<Vec<D>>,
    /// Join handle of the most recently spawned loader thread.
    loading_process: Option<thread::JoinHandle<()>>,
    /// `thread_control` flag paired with the control held by the loader thread;
    /// `next` polls it to decide whether the loader is still running.
    loading_process_flag: Option<thread_control::Flag>,
    /// Optional state handed to `loading_function` on every call.
    loading_state: Option<S>,
    /// Converts one line of text into an item.
    loading_function: fn(&String, Option<&S>) -> D,
    /// Optional comparator used to sort a loaded chunk before it is batched.
    sorting_function: Option<fn(&D, &D) -> Ordering>,
    /// Number of `next` calls made in the current pass.
    current_index: usize,
}

impl <D: Sync + Send + Clone, S: Sync + Send + Clone + 'static> Default for ThreadedDataloader<D, S> {
    fn default() -> Self {
        ThreadedDataloader {
            paths: vec![String::new()],
            batch_size: 1,
            start_index: 0,
            end_index: 0,
            chunk_size: 10000,
            loading_state: None,
            loading_function: |_, _| {unimplemented!()}, // Default function, does nothing
            sorting_function: None,
            current_index: 0,
            chunks_processed: vec![],
            post_buffer_send: None,
            post_buffer_recv: None,
            post_buffer: VecDeque::new(),
            loading_process: None,
            loading_process_flag: None,
        }
    }
}

impl<D: Sync + Send + Clone + 'static, S: Sync + Send + Clone + 'static> ThreadedDataloader <D, S>{
    /// Builds a dataloader over the lines of `paths`, indexed as one
    /// concatenated sequence.
    ///
    /// * `paths` - files to read, in order. Each file is read once here to count its lines.
    /// * `batch_size` - number of items per yielded batch.
    /// * `start_index` - first global line index to use (0 when `None`); clamped to the end index.
    /// * `end_index` - exclusive upper bound on the global line index; clamped to the total line count.
    /// * `chunk_size` - number of lines per background load; clamped to the selected range.
    /// * `loading_state` - optional state passed to every `loading_function` call.
    /// * `loading_function` - converts one line into a `D`.
    /// * `sorting_function` - optional comparator applied to each chunk before batching.
    ///
    /// The range is split into `range / chunk_size` whole chunks; lines past the
    /// last whole chunk are never scheduled for loading.
    ///
    /// # Panics
    ///
    /// Panics if any of `paths` cannot be opened.
    #[allow(clippy::too_many_arguments)]
    pub fn new(paths: &[&str], batch_size: usize, start_index: Option<usize>, end_index: Option<usize>, chunk_size: usize, loading_state: Option<S>, loading_function: fn(&String, Option<&S>) -> D, sorting_function: Option<fn(&D, &D) -> Ordering>) -> Self {
        // Total number of lines across all files.
        let total_lines: usize = paths.iter().map(|path| {
            let file = File::open(path).expect("Failed to read file!");
            BufReader::new(file).lines().count()
        }).sum();
        let end_index = end_index.map_or(total_lines, |requested| requested.min(total_lines));
        // Clamping keeps `end_index - start_index` from underflowing below and in `len`.
        let start_index = start_index.unwrap_or_default().min(end_index);
        let span = end_index - start_index;
        // The chunk count must use the clamped chunk size: with the raw argument an
        // oversized chunk size gave zero chunks, and zero divided by zero.
        let chunk_size = chunk_size.min(span);
        let chunk_count = if chunk_size == 0 { 0 } else { span / chunk_size };
        // Queue the loader thread fills with finished batches (capacity 100).
        let (send, recv) = broadcast_queue::<Vec<D>>(100);
        ThreadedDataloader {
            paths: paths.iter().map(|p| p.to_string()).collect(),
            batch_size,
            start_index,
            end_index,
            chunk_size,
            loading_state,
            loading_function,
            sorting_function,
            chunks_processed: vec![false; chunk_count],
            current_index: 0,
            post_buffer_send: Some(send),
            post_buffer_recv: Some(recv),
            post_buffer: VecDeque::new(),
            loading_process: None,
            loading_process_flag: None,
        }
    }

    /// Number of lines in the selected range (`end_index - start_index`).
    ///
    /// Saturates at zero, so a start index past the end index reports an empty
    /// loader instead of underflowing.
    pub fn len(&self) -> usize {
        self.end_index.saturating_sub(self.start_index)
    }

    /// Reports whether the selected line range contains no lines.
    pub fn is_empty(&self) -> bool {
        let line_count = self.len();
        line_count == 0
    }

    pub fn load(&mut self) -> Result<(), Box<dyn Error>> {
        // Load another random chunk
        let unprocessed_chunk: Vec<(usize, &bool)> = self.chunks_processed.iter().enumerate()
            .filter(|(_, processed)| {!**processed}).collect();
        let mut rng = rand::thread_rng();
        let chosen_chunk = unprocessed_chunk[rng.gen_range(0..unprocessed_chunk.len())].0;
        self.chunks_processed[chosen_chunk] = true;

        let paths = self.paths.clone();
        let (start_index, end_index, chunk_size, batch_size) = (self.start_index, self.end_index, self.chunk_size, self.batch_size);
        let loading_function = self.loading_function;
        let sorting_function = self.sorting_function;
        let loading_state = self.loading_state.clone();
        let sender = self.post_buffer_send.as_ref().unwrap().clone();
        let (flag, control) = thread_control::make_pair();
        self.loading_process_flag = Some(flag);

        self.loading_process = Some(thread::spawn(move || {
            // Load data from chunk
            let len = usize::clamp(start_index + ((chosen_chunk + 1) * chunk_size), 0, end_index) - (start_index + chosen_chunk * chunk_size);
            let mut data = Vec::with_capacity(len);
            let mut read_lines = 0;
            for path in &paths {
                // Check to see if we want to use this file
                let file = File::open(path).expect("Failed to open file!");
                let reader = BufReader::new(file);
                let current_line_count = reader.lines().count();
                read_lines += current_line_count;
                if read_lines < start_index + chosen_chunk * chunk_size {continue;} // We need to move to the next file

                // Read data from this file until we reached the required amount or run out of file
                let file = File::open(path).expect("Failed to open file!");
                let reader = BufReader::new(file);
                for line in reader.lines().skip((start_index + chunk_size * chosen_chunk) - (read_lines - current_line_count)).flatten() {
                    data.push(line);
                    if data.len() >= len {break;}
                }
                if data.len() >= len {break;}
            }

            // Run through load function in parallel
            let mut processed_data = Vec::new();
            let loading_state_ref = if loading_state.is_some() {Some(loading_state.as_ref().unwrap())} else {None};
            data.par_iter().map(|string| {(loading_function)(string, loading_state_ref)}).collect_into_vec(&mut processed_data);

            // Run sorting function in parallel
            if let Some(sorting_function) = sorting_function {
                processed_data.par_sort_by(sorting_function);
            }
            // Randomize but retain batches
            let mut batched_data = Vec::with_capacity(processed_data.len() / batch_size);
            for i in (0..processed_data.len()-batch_size).step_by(batch_size) {
                batched_data.push(processed_data[i..i+batch_size].iter().cloned().collect_vec());
            }
            batched_data.shuffle(&mut rand::thread_rng());

            // Put processed data into post buffer
            for batch in batched_data {
                loop {
                    if sender.try_send(batch.clone()).is_ok() {
                        break;
                    }
                }
            }
            control.stop();
        }));
        Ok(())
    }
}

/// Yields batches produced by the background loader, one pass over the
/// dataset at a time. `None` marks the end of a pass and resets the loader so
/// that it can be iterated again.
impl <D: Sync + Send + Clone + 'static, S: Sync + Send + Clone + 'static> Iterator for ThreadedDataloader <D, S>{
    type Item = Vec<D>;

    /// Returns the next batch, starting a background load whenever the local
    /// buffer runs low and blocking (polling every 100 ms) while a load is in
    /// progress and no batch is available yet.
    ///
    /// # Panics
    ///
    /// Panics if `batch_size` is zero or if starting a load fails.
    fn next(&mut self) -> Option<Self::Item> {
        self.current_index += 1;
        let total_batches = (self.end_index - self.start_index) / self.batch_size;
        // Start loading again once fewer than one chunk's worth of batches is buffered.
        let prefetch_threshold = self.chunk_size / self.batch_size;

        // `>` lets all `total_batches` batches through before the pass ends. A zero
        // threshold means no chunk is large enough to form a single batch.
        let mut pass_finished = self.current_index > total_batches || prefetch_threshold == 0;

        while !pass_finished {
            // Sample the loader state BEFORE draining: the loader sends all of its
            // batches before it signals completion, so a drain that follows a
            // "finished" reading cannot miss any of them.
            let loader_idle = self.loading_process_flag.as_ref().map_or(true, |flag| !flag.is_alive());

            // Get available batches
            if let Some(receiver) = self.post_buffer_recv.as_ref() {
                self.post_buffer.extend(receiver.try_iter());
            }

            let all_chunks_started = self.chunks_processed.iter().all(|started| *started);

            // Trigger loading if the buffer is low, no loader is running and chunks remain
            if self.post_buffer.len() < prefetch_threshold && loader_idle && !all_chunks_started {
                self.load().expect("Failed to load!");
            }

            if let Some(batch) = self.post_buffer.pop_front() {
                return Some(batch);
            }

            if loader_idle && all_chunks_started {
                // Nothing buffered, nothing running, nothing left to load.
                pass_finished = true;
            } else {
                // A loader is (or was just) started; wait for it to produce something.
                // NOTE(review): if the loader thread ends without signalling (e.g. it
                // panics), this wait may never finish - confirm thread_control's behaviour.
                thread::sleep(time::Duration::from_millis(100));
            }
        }

        // End of pass: reset so the next call starts a fresh pass.
        self.current_index = 0;
        self.chunks_processed = vec![false; self.chunks_processed.len()];
        None
    }
}

impl<D: Sync + Send + Clone + 'static, S: Sync + Send + Clone + 'static> lentrait::Len for ThreadedDataloader<D, S> {
    fn len(&self) -> usize {
        self.len()
    }
}