use std::{cmp::Ordering, collections::VecDeque, error::Error, fs::File, io::{BufRead, BufReader}};
use itertools::Itertools;
use rayon::prelude::*;
use rand::{Rng, prelude::SliceRandom};

/// Reads lines from text files in randomly ordered chunks, converts each line
/// with `loading_function`, and hands them out as shuffled batches.
pub struct Dataloader<D, S>
where
    D: Sync + Send + Clone,
    S: Sync + Send,
{
    /// Input files; their lines form the dataset, in file order.
    paths: Vec<String>,
    /// First global line index (inclusive) that belongs to the dataset.
    start_index: usize,
    /// Global line index (exclusive) where the dataset ends.
    end_index: usize,
    /// Maximum number of lines read by a single `load` call.
    chunk_size: usize,
    /// Number of examples in each yielded batch.
    batch_size: usize,
    /// Epoch position counter maintained by `next`.
    current_index: usize,
    /// One flag per chunk; `true` once that chunk has been loaded this epoch.
    chunks_processed: Vec<bool>,
    /// Batches produced by `load` and waiting to be returned by `next`.
    post_buffer: VecDeque<Vec<D>>,
    /// Optional state handed to `loading_function` by reference.
    loading_state: Option<S>,
    /// Turns one raw line into an example.
    loading_function: fn(&String, Option<&S>) -> D,
    /// Optional comparator used to sort each loaded chunk before batching.
    sorting_function: Option<fn(&D, &D) -> Ordering>,
}

impl <D: Sync + Send + Clone, S: Sync + Send> Default for Dataloader<D, S> {
    fn default() -> Self {
        Dataloader {
            paths: vec![String::new()],
            batch_size: 1,
            start_index: 0,
            end_index: 0,
            chunk_size: 10000,
            loading_function: |_, _| {unimplemented!()}, // Default function, does nothing
            sorting_function: None,
            loading_state: None,
            current_index: 0,
            chunks_processed: vec![],
            post_buffer: VecDeque::new(),
        }
    }
}

impl<D: Sync + Send + Clone, S: Sync + Send> Dataloader<D, S> {
    /// Creates a dataloader over the lines of `paths`, read in order as one dataset.
    ///
    /// * `paths` – text files whose lines are the raw examples.
    /// * `batch_size` – number of examples per yielded batch.
    /// * `start_index` / `end_index` – optional global line range `[start, end)`;
    ///   defaults to every line. `end_index` is capped at the total line count and
    ///   `start_index` is capped at `end_index`.
    /// * `chunk_size` – maximum number of lines read per `load` call (capped at the
    ///   dataset length). The last chunk may be shorter.
    /// * `loading_state` – optional state passed to `loading_function`.
    /// * `loading_function` – converts one raw line into an example.
    /// * `sorting_function` – optional ordering applied within each loaded chunk.
    ///
    /// # Panics
    /// Panics if `batch_size` or `chunk_size` is zero, or if a file cannot be opened.
    // The public constructor signature is kept as-is for compatibility.
    #[allow(clippy::too_many_arguments)]
    pub fn new(paths: &[&str], batch_size: usize, start_index: Option<usize>, end_index: Option<usize>, chunk_size: usize, loading_state: Option<S>, loading_function: fn(&String, Option<&S>) -> D, sorting_function: Option<fn(&D, &D) -> Ordering>) -> Self {
        assert!(batch_size > 0, "batch_size must be greater than zero");
        assert!(chunk_size > 0, "chunk_size must be greater than zero");

        // Total number of lines across all files bounds the usable range.
        let total_lines: usize = paths
            .iter()
            .map(|path| {
                let file = File::open(path).expect("Failed to read file!");
                BufReader::new(file).lines().count()
            })
            .sum();
        let end_index = end_index.map_or(total_lines, |e_index| e_index.min(total_lines));
        // Clamp so `end_index - start_index` can never underflow.
        let start_index = start_index.unwrap_or_default().min(end_index);
        let len = end_index - start_index;

        // Never use chunks larger than the dataset. This is 0 only when the dataset is empty.
        let chunk_size = chunk_size.min(len);
        // Round up so that a trailing partial chunk is still loaded.
        let chunk_count = if chunk_size == 0 {
            0
        } else {
            len / chunk_size + usize::from(len % chunk_size != 0)
        };

        Dataloader {
            paths: paths.iter().map(|&path| path.to_owned()).collect(),
            batch_size,
            start_index,
            end_index,
            chunk_size,
            loading_state,
            loading_function,
            sorting_function,
            chunks_processed: vec![false; chunk_count],
            current_index: 0,
            post_buffer: VecDeque::new(),
        }
    }

    /// Number of lines (examples) in the configured `[start_index, end_index)` range.
    pub fn len(&self) -> usize {
        self.end_index - self.start_index
    }

    /// Returns `true` when the configured range contains no lines.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Loads one randomly chosen, not-yet-loaded chunk into the batch buffer.
    ///
    /// Loading proceeds in these steps:
    /// 1. The chunk's lines are read. Files are treated as one concatenated
    ///    sequence, so a chunk may span several files.
    /// 2. The lines are converted with `loading_function` in parallel.
    /// 3. They are optionally sorted with `sorting_function`.
    /// 4. They are grouped into full batches of `batch_size`. A trailing partial
    ///    batch is dropped.
    /// 5. The batches are shuffled and appended to the buffer.
    ///
    /// The chosen chunk is marked as processed.
    ///
    /// # Errors
    /// Returns an error in any of these cases:
    /// - every chunk has already been loaded;
    /// - a file cannot be opened;
    /// - a line inside the chunk cannot be read.
    pub fn load(&mut self) -> Result<(), Box<dyn Error>> {
        // Pick a random chunk that has not been loaded yet.
        let unprocessed: Vec<usize> = self
            .chunks_processed
            .iter()
            .enumerate()
            .filter(|(_, processed)| !**processed)
            .map(|(index, _)| index)
            .collect();
        if unprocessed.is_empty() {
            return Err("No unprocessed chunks left to load".into());
        }
        let mut rng = rand::thread_rng();
        let chosen_chunk = unprocessed[rng.gen_range(0..unprocessed.len())];
        self.chunks_processed[chosen_chunk] = true;

        // Global line range [chunk_start, chunk_end) covered by this chunk.
        let chunk_start = self.start_index + chosen_chunk * self.chunk_size;
        let chunk_end = usize::min(chunk_start + self.chunk_size, self.end_index);
        let len = chunk_end.saturating_sub(chunk_start);

        // Read the chunk's lines in a single pass. `line_number` is the global
        // index of the line about to be read, so chunks crossing a file boundary
        // are handled without any per-file offset arithmetic.
        let mut data = Vec::with_capacity(len);
        let mut line_number = 0;
        'files: for path in &self.paths {
            if data.len() >= len {
                break;
            }
            let reader = BufReader::new(File::open(path)?);
            for line in reader.lines() {
                if line_number >= chunk_start {
                    data.push(line?);
                    if data.len() >= len {
                        break 'files;
                    }
                }
                line_number += 1;
            }
        }

        // Convert raw lines into examples in parallel.
        let loading_function = self.loading_function;
        let state = self.loading_state.as_ref();
        let mut processed_data: Vec<D> = data
            .par_iter()
            .map(|line| loading_function(line, state))
            .collect();

        // Optionally sort the examples before batching.
        if let Some(sorting_function) = self.sorting_function {
            processed_data.par_sort_by(sorting_function);
        }

        // Group into full batches, moving the examples instead of cloning them.
        // A trailing partial batch is dropped.
        let batch_count = processed_data.len() / self.batch_size;
        processed_data.truncate(batch_count * self.batch_size);
        let mut examples = processed_data.into_iter();
        let mut batches = (0..batch_count)
            .map(|_| examples.by_ref().take(self.batch_size).collect_vec())
            .collect_vec();

        // Shuffle the batch order while keeping each batch's contents together.
        batches.shuffle(&mut rng);
        self.post_buffer.extend(batches);
        Ok(())
    }
}

impl<D: Sync + Send + Clone, S: Sync + Send> Iterator for Dataloader<D, S> {
    type Item = Vec<D>;

    /// Returns the next batch of the current epoch, or `None` when the epoch ends.
    ///
    /// An epoch yields at most `(end_index - start_index) / batch_size` batches.
    /// It can end earlier when every chunk has been loaded and the buffer is
    /// empty, e.g. because `load` discards examples that do not fill a whole batch.
    ///
    /// Returning `None` resets the loader, so the following call starts a new epoch.
    /// The reset covers the epoch counter, the chunk flags and any leftover
    /// buffered batches.
    ///
    /// # Panics
    /// Panics with "Failed to load data!" if loading a chunk fails.
    fn next(&mut self) -> Option<Self::Item> {
        let batches_per_epoch = (self.end_index - self.start_index) / self.batch_size;
        if self.current_index < batches_per_epoch {
            // A chunk can produce zero full batches, so keep loading until a batch
            // is available or every chunk has been consumed.
            while self.post_buffer.is_empty() && self.chunks_processed.contains(&false) {
                self.load().expect("Failed to load data!");
            }
            if let Some(batch) = self.post_buffer.pop_front() {
                self.current_index += 1;
                return Some(batch);
            }
        }

        // End of epoch. Leftover batches are discarded because their chunks are
        // reloaded next epoch; keeping them would repeat examples.
        self.current_index = 0;
        self.chunks_processed.fill(false);
        self.post_buffer.clear();
        None
    }
}

impl<D: Sync + Send + Clone, S: Sync + Send> lentrait::Len for Dataloader<D, S> {
    fn len(&self) -> usize {
        self.len()
    }
}