// SPDX-License-Identifier: MPL-2.0

use rand::rngs::OsRng;
use rand::seq::SliceRandom;

/// A named set of printable ASCII characters that random strings are drawn from.
///
/// The lowercase variant names (`digit`, `lower`, ...) are the command-line
/// spellings accepted by `FromStr` and produced by `Display`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum CharSet {
    /// Decimal digits `0`-`9`.
    Digit,
    /// Lowercase letters `a`-`z`.
    Lower,
    /// Uppercase letters `A`-`Z`.
    Upper,
    /// The 32 ASCII punctuation characters (`!` through `~`, excluding letters and digits).
    Punct,
    /// Lowercase and uppercase letters.
    Alpha,
    /// Letters and digits.
    Alnum,
    /// Letters, digits and punctuation: every printable ASCII character except space.
    Graph,
}

// clap's `arg_enum!` macro could generate this impl, but we write it by hand
// so that the accepted names and the error message are formatted the way we want.
impl std::str::FromStr for CharSet {
    type Err = String;

    /// Parses one of the lowercase names listed by [`CharSet::variants`].
    ///
    /// Matching is exact and case-sensitive. Any other input produces an error
    /// message that lists every accepted name.
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        let charset = match s {
            "digit" => CharSet::Digit,
            "lower" => CharSet::Lower,
            "upper" => CharSet::Upper,
            "punct" => CharSet::Punct,
            "alpha" => CharSet::Alpha,
            "alnum" => CharSet::Alnum,
            "graph" => CharSet::Graph,
            _ => {
                let accepted = CharSet::variants().join(", ");
                return Err(format!("valid values: {}", accepted));
            }
        };
        Ok(charset)
    }
}

impl std::fmt::Display for CharSet {
    /// Writes the lowercase command-line name of the charset, which is the same
    /// string `FromStr` accepts for it.
    ///
    /// The name is emitted through `Formatter::pad`, so width, fill and alignment
    /// flags are honoured: `format!("{:<7}", CharSet::Digit)` yields `"digit  "`.
    /// Previously `write!` ignored those flags.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let name = match self {
            CharSet::Digit => "digit",
            CharSet::Lower => "lower",
            CharSet::Upper => "upper",
            CharSet::Punct => "punct",
            CharSet::Alpha => "alpha",
            CharSet::Alnum => "alnum",
            CharSet::Graph => "graph",
        };
        f.pad(name)
    }
}

impl CharSet {
    /// Returns the lowercase name of every charset, in declaration order.
    ///
    /// These are the strings matched by the `FromStr` impl and written by the
    /// `Display` impl. The parse error message is built from this list.
    pub fn variants() -> [&'static str; 7] {
        const NAMES: [&str; 7] = [
            "digit", "lower", "upper", "punct", "alpha", "alnum", "graph",
        ];
        NAMES
    }
}

// Each basic character class is spelled out exactly once, as a macro that
// expands to a string literal. `concat!` accepts only literals (not `const`
// items), but it does expand nested macro invocations. The composite sets
// below can therefore be built from these single definitions without
// repeating any characters.
macro_rules! digit_chars {
    () => {
        "0123456789"
    };
}

macro_rules! lower_chars {
    () => {
        "abcdefghijklmnopqrstuvwxyz"
    };
}

macro_rules! upper_chars {
    () => {
        "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    };
}

macro_rules! punct_chars {
    () => {
        "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
    };
}

/// ASCII decimal digits.
pub const DIGIT: &str = digit_chars!();
/// Lowercase ASCII letters.
pub const LOWER: &str = lower_chars!();
/// Uppercase ASCII letters.
pub const UPPER: &str = upper_chars!();
/// The 32 ASCII punctuation characters: printable, non-space, non-alphanumeric.
pub const PUNCT: &str = punct_chars!();

/// Lowercase letters followed by uppercase letters.
pub const ALPHA: &str = concat!(lower_chars!(), upper_chars!());

/// Letters followed by digits.
pub const ALNUM: &str = concat!(lower_chars!(), upper_chars!(), digit_chars!());

/// Letters, digits and punctuation: every printable ASCII character except space.
pub const GRAPH: &str = concat!(
    lower_chars!(),
    upper_chars!(),
    digit_chars!(),
    punct_chars!(),
);

impl CharSet {
    /// Returns the characters this charset draws from, in the same order as the
    /// backing constant (`DIGIT`, `LOWER`, `UPPER`, `PUNCT`, `ALPHA`, `ALNUM`
    /// or `GRAPH`).
    fn chars(self) -> Vec<char> {
        let alphabet: &'static str = match self {
            CharSet::Digit => DIGIT,
            CharSet::Lower => LOWER,
            CharSet::Upper => UPPER,
            CharSet::Punct => PUNCT,
            CharSet::Alpha => ALPHA,
            CharSet::Alnum => ALNUM,
            CharSet::Graph => GRAPH,
        };
        alphabet.chars().collect()
    }
}

/// Generates a string of `length` characters, each chosen independently at
/// random from `charset` using the operating system's RNG (`OsRng`).
///
/// Characters are drawn with replacement, so a character may appear more than
/// once. A `length` of 0 yields an empty string.
///
/// # Panics
///
/// Panics only if `charset` contains no characters, which cannot happen for
/// any of the built-in charsets.
pub fn random_string(charset: CharSet, length: u32) -> String {
    let alphabet = charset.chars();
    // One RNG handle for the whole string instead of a fresh temporary per draw.
    let mut rng = OsRng;
    (0..length)
        .map(|_| {
            *alphabet
                .choose(&mut rng)
                .expect("every CharSet has at least one character")
        })
        .collect()
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Every charset, obtained by parsing the advertised variant names.
    fn all_charsets() -> Vec<CharSet> {
        CharSet::variants()
            .iter()
            .map(|name| name.parse().expect("advertised variant name must parse"))
            .collect()
    }

    #[test]
    fn random_string_length_matches_input_length() {
        for charset in all_charsets() {
            let string = random_string(charset, 20);
            // Count characters rather than bytes so the check stays meaningful
            // even if a charset ever gains non-ASCII characters.
            assert_eq!(string.chars().count(), 20, "charset {}", charset);
        }
    }

    #[test]
    fn random_string_zero_length_is_empty() {
        assert!(random_string(CharSet::Graph, 0).is_empty());
    }

    #[test]
    fn random_string_characters_are_taken_from_charset() {
        for charset in all_charsets() {
            // Build the allowed set once rather than re-collecting it per character.
            let allowed = charset.chars();
            let string = random_string(charset, 20);
            assert!(
                string.chars().all(|c| allowed.contains(&c)),
                "charset {}",
                charset
            );
        }
    }

    #[test]
    fn variant_names_round_trip_through_from_str_and_display() {
        for name in CharSet::variants().iter() {
            let charset: CharSet = name.parse().unwrap();
            assert_eq!(charset.to_string(), *name);
        }
    }

    #[test]
    fn unknown_name_is_rejected_with_list_of_valid_values() {
        let err = "nope".parse::<CharSet>().unwrap_err();
        assert_eq!(
            err,
            "valid values: digit, lower, upper, punct, alpha, alnum, graph"
        );
    }
}
