use super::{Tokenizer, HFTokenizer};

/// Byte-pair-encoding tokenizer backed by an [`HFTokenizer`].
///
/// Instances are created through the `Tokenizer::load` implementation below,
/// which assembles the BPE model from the bundled vocabulary and merges file.
pub struct BPETokenizer {
    /// Wrapped tokenizer that performs the actual encoding.
    hf_tokenizer: HFTokenizer,
}

impl Tokenizer for BPETokenizer {
    /// Builds the tokenizer from the bundled BPE vocabulary and merges file.
    ///
    /// Each line of `../resources/bpe_merges.txt` has every `Ġ` and `##`
    /// occurrence removed. A merge rule is kept only when the cleaned line
    /// contains a space, contains no remaining `#`, and both tokens as well as
    /// their concatenation are present in the vocabulary.
    ///
    /// # Panics
    ///
    /// Panics with "BPE Tokenizer failed to build!" if the BPE builder
    /// returns an error.
    fn load() -> Self {
        use crate::tokenization::hf_tokenizers::models::bpe::BPE;

        let bpe_vocab = super::super::vocab::load_bpe_vocab();

        // `lines()` also strips a trailing `\r`, so a CRLF checkout of the
        // merges file still matches the vocabulary instead of silently
        // discarding every rule.
        let merges: Vec<(String, String)> = include_str!("../resources/bpe_merges.txt")
            .lines()
            .filter_map(|raw| {
                let cleaned = raw.replace("Ġ", "").replace("##", "");
                // Filter out junk: lines without a token pair, or with a stray `#`.
                if !cleaned.contains(' ') || cleaned.contains('#') {
                    return None;
                }

                // Only the first two space-separated fields form the rule.
                let mut fields = cleaned.split(' ');
                let left = fields.next()?.to_string();
                let right = fields.next()?.to_string();
                let merged = format!("{}{}", left, right);

                // Make sure the vocab contains both tokens and the combined token.
                let vocab = &bpe_vocab.token2index;
                if vocab.contains_key(&left)
                    && vocab.contains_key(&right)
                    && vocab.contains_key(&merged)
                {
                    Some((left, right))
                } else {
                    None
                }
            })
            .collect();

        let bpe = BPE::builder()
            .vocab_and_merges(bpe_vocab.token2index, merges)
            .unk_token("[UNK]".into())
            .build()
            .expect("BPE Tokenizer failed to build!");

        BPETokenizer {
            hf_tokenizer: HFTokenizer::new(bpe),
        }
    }

    /// Lowercases `string` and returns the tokens produced by encoding it.
    ///
    /// Calls `set_parallelism(true)` before encoding and encodes with the
    /// second `encode` argument set to `false`.
    ///
    /// # Panics
    ///
    /// Panics with "BPE tokenization failed!" if encoding returns an error.
    fn tokenize(&self, string: &str) -> Vec<String> {
        super::hf_tokenizers::utils::parallelism::set_parallelism(true);

        let encoding = self
            .hf_tokenizer
            .encode(string.to_lowercase(), false)
            .expect("BPE tokenization failed!");

        encoding.get_tokens().to_vec()
    }

    /// Lowercases every input and returns one token list per input, in order.
    ///
    /// Calls `set_parallelism(true)` before encoding the batch.
    ///
    /// # Panics
    ///
    /// Panics with "BPE tokenization failed!" if batch encoding returns an error.
    fn batch_tokenize(&self, strings: Vec<&str>) -> Vec<Vec<String>> {
        super::hf_tokenizers::utils::parallelism::set_parallelism(true);

        let lowered = strings.iter().map(|text| text.to_lowercase()).collect();
        let encodings = self
            .hf_tokenizer
            .encode_batch(lowered, false)
            .expect("BPE tokenization failed!");

        encodings
            .into_iter()
            .map(|encoding| encoding.get_tokens().to_vec())
            .collect()
    }

    /// Concatenates `tokens` into one string with no separator.
    ///
    /// Because `tokenize` lowercases its input, this does not restore the
    /// original casing.
    fn untokenize(&self, tokens: Vec<String>) -> String {
        tokens.concat()
    }

    /// Concatenates each token list into one string with no separator,
    /// preserving the order of the batch.
    fn batch_untokenize(&self, tokens: Vec<Vec<String>>) -> Vec<String> {
        tokens.iter().map(|sequence| sequence.concat()).collect()
    }
}