1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// Character n-gram frequency tables for a piece of text.
///
/// Every table maps an n-gram (a single `char` or an array of `char`s) to the number of
/// times it was counted. Instances are built by [`TextData::from`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TextData {
    /// Number of occurrences of each character.
    pub chars: HashMap<char, u64>,
    /// Number of occurrences of each pair of adjacent characters.
    pub bigrams: HashMap<[char; 2], u64>,
    /// Number of occurrences of each run of three consecutive characters.
    pub trigrams: HashMap<[char; 3], u64>,
    /// Skip-1-gram counts, keyed by an ordered pair of characters.
    pub skip_1_grams: HashMap<[char; 2], u64>,
}

impl TextData {
    /// Builds character n-gram frequency tables from `text`.
    ///
    /// Every character is first folded with [`char::to_ascii_lowercase`], so ASCII letters
    /// are counted case-insensitively. All other characters (digits, punctuation,
    /// whitespace, non-ASCII) are counted as-is.
    ///
    /// * `chars` counts every character of the text.
    /// * `bigrams` counts every pair of adjacent characters.
    /// * `trigrams` counts every run of three consecutive characters.
    /// * `skip_1_grams` counts every pair of characters separated by exactly one other
    ///   character, i.e. the first and last character of each trigram.
    ///
    /// A text too short for a given n-gram size yields an empty table for that size.
    /// For example, `"ab"` has character and bigram counts but no trigrams.
    pub fn from(text: String) -> Self {
        let folded: Vec<char> = text.chars().map(|c| c.to_ascii_lowercase()).collect();

        // Initial capacities are only a sizing hint for the maps; they never affect counts.
        let mut chars: HashMap<char, u64> = HashMap::with_capacity(30);
        let mut bigrams: HashMap<[char; 2], u64> = HashMap::with_capacity(30 * 30);
        let mut trigrams: HashMap<[char; 3], u64> = HashMap::with_capacity(30 * 30 * 15);
        let mut skip_1_grams: HashMap<[char; 2], u64> = HashMap::with_capacity(30 * 30);

        // Each table is counted over its own window size. Walking only 3-wide windows
        // would silently drop the last one or two characters (and everything in texts
        // shorter than three characters) from the smaller tables.
        for &c in &folded {
            *chars.entry(c).or_default() += 1;
        }
        for w in folded.windows(2) {
            *bigrams.entry([w[0], w[1]]).or_default() += 1;
        }
        for w in folded.windows(3) {
            *trigrams.entry([w[0], w[1], w[2]]).or_default() += 1;
            // A skip-1-gram skips the middle character of the window.
            *skip_1_grams.entry([w[0], w[2]]).or_default() += 1;
        }

        Self {
            chars,
            bigrams,
            trigrams,
            skip_1_grams,
        }
    }
}

#[cfg(test)]
mod tests {
    // `super` resolves to this file's module wherever it sits in the crate, unlike
    // `crate::TextData`, which requires the type at (or re-exported from) the crate root.
    use super::TextData;

    /// Checks a few known counts for a short mixed-case sentence.
    #[test]
    fn get_text_data() {
        let data = TextData::from("Hello world!".to_string());

        // 'l' appears twice in "Hello" and once in "world".
        assert_eq!(*data.chars.get(&'l').unwrap(), 3);
        // Upper-case ASCII input is folded to lower case before counting.
        assert_eq!(*data.chars.get(&'h').unwrap(), 1);
        assert!(!data.chars.contains_key(&'H'));

        assert_eq!(*data.bigrams.get(&['l', 'd']).unwrap(), 1);
        assert_eq!(*data.trigrams.get(&['w', 'o', 'r']).unwrap(), 1);
        assert_eq!(*data.trigrams.get(&['l', 'l', 'o']).unwrap(), 1);
    }
}