#[cfg(test)]
mod tests;
use std::{collections::HashMap, fs::File, io::LineWriter};
use serde::{Serialize, Deserialize};

/// Mako vocab object: a bidirectional mapping between token strings and `u32` indexes.
///
/// `index2token[i]` is the token stored at index `i`, and `token2index` maps a token back to its
/// index. The mutating methods keep `num_tokens` equal to `index2token.len()`.
///
/// Note: the derived `Vocab::default()` is an *empty* vocab whose special-token fields are all
/// `0`; use `Vocab::new()` to get `[PAD]`, `[SOS]`, `[EOS]` and `[SEP]` pre-registered.
// The special-token field names are part of the public (and serialized) interface, so they keep
// their original casing rather than being renamed to snake_case.
#[allow(non_snake_case)]
#[derive(Serialize, Deserialize, Default, Debug, Clone)]
pub struct Vocab {
    /// Number of tokens currently stored.
    pub num_tokens: u32,
    /// Token string -> index.
    pub token2index: HashMap<String, u32>,
    /// Index -> token string (position in the vec is the index).
    pub index2token: Vec<String>,
    /// Index of the padding token (`0` when built with `Vocab::new()`).
    pub PAD_token: u32,
    /// Index of the start-of-sequence token (`1` when built with `Vocab::new()`).
    pub SOS_token: u32,
    /// Index of the end-of-sequence token (`2` when built with `Vocab::new()`).
    pub EOS_token: u32,
    /// Index of the separator token (`3` when built with `Vocab::new()`).
    pub SEP_token: u32,
}

impl Vocab {
    /// Make a new vocab pre-populated with the special tokens `[PAD]`, `[SOS]`, `[EOS]` and
    /// `[SEP]` at indexes 0, 1, 2 and 3 (matching the `*_token` fields).
    pub fn new() -> Vocab {
        let mut voc = Vocab {
            num_tokens: 0,
            token2index: HashMap::new(),
            index2token: Vec::new(),
            PAD_token: 0,
            SOS_token: 1,
            EOS_token: 2,
            SEP_token: 3,
        };
        voc.add_tokens(vec![
            "[PAD]".to_string(),
            "[SOS]".to_string(),
            "[EOS]".to_string(),
            "[SEP]".to_string(),
        ]);
        voc
    }

    /// Returns the number of tokens in the vocab (`num_tokens`).
    pub fn len(&self) -> usize {
        self.num_tokens as usize
    }

    /// Returns `true` if the vocab holds no tokens.
    pub fn is_empty(&self) -> bool {
        self.num_tokens == 0
    }

    /// Add a token at the next free index (`num_tokens`).
    ///
    /// Duplicates are not detected: adding a token that is already present appends a second copy
    /// to `index2token` and repoints `token2index` at the new (highest) index.
    pub fn add_token(&mut self, token: String) {
        self.token2index.insert(token.clone(), self.num_tokens);
        self.index2token.push(token);
        self.num_tokens += 1;
    }

    /// Add a vec of tokens in order, assigning consecutive indexes starting at `num_tokens`.
    ///
    /// Equivalent to calling [`Vocab::add_token`] on each token in turn.
    pub fn add_tokens(&mut self, tokens: Vec<String>) {
        self.index2token.reserve(tokens.len());
        self.token2index.reserve(tokens.len());
        for token in tokens {
            self.add_token(token);
        }
    }

    /// Remove a vec of tokens from the vocab; tokens that are not present are skipped.
    ///
    /// Each removal shifts every later token down by one index, so this costs
    /// O(`len()` * `tokens.len()`).
    pub fn remove_tokens(&mut self, tokens: Vec<String>) {
        for token in tokens {
            self.remove_token(token);
        }
    }

    /// Remove a token from the vocab, shifting every token after it down by one index.
    ///
    /// Does nothing (rather than panicking) if the token is not in the vocab. The special-token
    /// fields (`PAD_token` etc.) are not updated, so removing a token at or below one of those
    /// indexes invalidates that field.
    pub fn remove_token(&mut self, token: String) {
        let removed = match self.token2index.remove(&token) {
            Some(index) => index as usize,
            None => return,
        };
        self.index2token.remove(removed);
        // Re-derive the index of every token that moved down. Assigning the new position (rather
        // than decrementing) keeps `token2index` pointing at the latest copy when a token was
        // added more than once, and tolerates copies that no longer have a map entry.
        for (offset, shifted) in self.index2token[removed..].iter().enumerate() {
            if let Some(slot) = self.token2index.get_mut(shifted) {
                *slot = (removed + offset) as u32;
            }
        }
        self.num_tokens -= 1;
    }

    /// Get the tokens for a slice of indexes, in the same order.
    ///
    /// # Errors
    /// Returns [`TokenNotFoundError`] if any index is `>= num_tokens`. An empty slice yields an
    /// empty vec.
    pub fn tokens_from_indexes(&self, indexes: &[u32]) -> Result<Vec<String>, TokenNotFoundError> {
        indexes
            .iter()
            .map(|&index| {
                if index >= self.num_tokens {
                    return Err(TokenNotFoundError);
                }
                self.index2token
                    .get(index as usize)
                    .cloned()
                    .ok_or(TokenNotFoundError)
            })
            .collect()
    }

    /// Batched version of [`Vocab::tokens_from_indexes`]; fails on the first bad sentence.
    pub fn batch_tokens_from_indexes(
        &self,
        indexes: &[Vec<u32>],
    ) -> Result<Vec<Vec<String>>, TokenNotFoundError> {
        indexes
            .iter()
            .map(|sent| self.tokens_from_indexes(sent))
            .collect()
    }

    /// Get the indexes for a slice of tokens, in the same order.
    ///
    /// # Errors
    /// Returns [`TokenNotFoundError`] if any token is not in the vocab.
    pub fn indexes_from_tokens(&self, tokens: &[String]) -> Result<Vec<u32>, TokenNotFoundError> {
        tokens
            .iter()
            .map(|token| self.token2index.get(token).copied().ok_or(TokenNotFoundError))
            .collect()
    }

    /// Batched version of [`Vocab::indexes_from_tokens`]; fails on the first bad sentence.
    pub fn batch_indexes_from_tokens(
        &self,
        tokens: &[Vec<String>],
    ) -> Result<Vec<Vec<u32>>, TokenNotFoundError> {
        tokens
            .iter()
            .map(|sent| self.indexes_from_tokens(sent))
            .collect()
    }
}

/// Load the BPE vocab from an internal vocab file into a Vocab object
pub fn load_bpe_vocab() -> Vocab {
    use serde_json::{Value};
    use std::io::Write;

    // Open vocab file
    let json: HashMap<String, Value>  = serde_json::from_str(&include_str!("../resources/bpe_vocab.json").replace("/", "").replace("Ġ", "")).expect("Error parsing BPE vocab file!");
    // Build sorted vector of tokens from hashmap
    let mut token_vec: Vec<String> = vec![String::from(""); 50265]; // Happen to know the largest index in the json is 50264, this is a bad system
    for token in json.keys() {
        token_vec[json[token].as_u64().unwrap() as usize] = token.clone();
    }
    // Build vocab
    let mut vocab = Vocab::new();
    let mut temp_vec: Vec<String> = Vec::new();
    for token in token_vec {
        if !token.is_empty() {
            vocab.add_token(token.clone());
            temp_vec.push(token);
        }
    }
    let mut writer = LineWriter::new(File::create("bpe_vocab.txt").expect(""));
    writer.write_all(temp_vec.join("\n").as_bytes()).expect("msg");
    vocab
}

/// Load the WordPiece vocab from an internal vocab file into a Vocab object.
///
/// The vocab file is embedded at compile time and read one token per line. Each line is appended,
/// in file order, after the four special tokens registered by `Vocab::new()`, so the 0-based line
/// `k` of the file gets index `k + 4`.
///
/// NOTE(review): the original author marked this as "CURRENTLY NOT WORKING CORRECTLY, BOTH VOCAB
/// AND TOKENIZER". If the file itself contains entries such as `[PAD]` or `[SEP]`, `add_token`
/// appends a second copy and `token2index` points at the later one — confirm against the file.
pub fn load_wordpiece_vocab() -> Vocab {
    let mut vocab = Vocab::new();
    // `lines()`, unlike `split('\n')`, strips a trailing '\r' from CRLF lines and does not yield a
    // spurious empty token for the file's final newline. Blank lines in the middle still produce
    // "", so the indexes of real tokens are unchanged.
    for token in include_str!("../resources/wordpiece_vocab.txt").lines() {
        vocab.add_token(String::from(token));
    }
    vocab
}

// Custom error types

/// Error returned when a token (or an index) is not present in a vocab.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TokenNotFoundError;

impl std::fmt::Display for TokenNotFoundError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("token or index not found in vocab")
    }
}

// Lets callers propagate this error with `?` into `Box<dyn std::error::Error>` and similar.
impl std::error::Error for TokenNotFoundError {}