//! Compute the WPA-PSK of a Wi-Fi SSID and passphrase.
//!
//! # Example
//!
//! Compute and print the WPA-PSK of a valid SSID and passphrase:
//!
//! ```
//! # use wpa_psk::{Ssid, Passphrase, wpa_psk, bytes_to_hex};
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! let ssid = Ssid::try_from("home")?;
//! let passphrase = Passphrase::try_from("0123-4567-89")?;
//! let psk = wpa_psk(&ssid, &passphrase);
//! assert_eq!(bytes_to_hex(&psk), "150c047b6fad724512a17fa431687048ee503d14c1ea87681d4f241beb04f5ee");
//! # Ok(())
//! # }
//! ```
//!
//! Compute the WPA-PSK of possibly invalid raw bytes:
//!
//! ```
//! # use wpa_psk::{wpa_psk_unchecked, bytes_to_hex};
//! let ssid = "bar".as_bytes();
//! let passphrase = "2short".as_bytes();
//! let psk = wpa_psk_unchecked(&ssid, &passphrase);
//! assert_eq!(bytes_to_hex(&psk), "cb5de4e4d23b2ab0bf5b9ba0fe8132c1e2af3bb52298ec801af8ad520cea3437");
//! ```

use std::fmt::Display;

use hmac::Hmac;
use pbkdf2::pbkdf2;
use sha1::Sha1;

/// A Wi-Fi network name (SSID).
///
/// Holds between 1 and 32 bytes inclusive. Any byte value is allowed, so an
/// SSID need not be valid UTF-8.
#[derive(Debug)]
pub struct Ssid<'a>(&'a [u8]);

impl<'a> TryFrom<&'a [u8]> for Ssid<'a> {
    type Error = &'static str;

    /// Checks the byte length of `value` and, if it is acceptable, borrows it.
    ///
    /// # Errors
    ///
    /// Fails when `value` is empty or longer than 32 bytes.
    fn try_from(value: &'a [u8]) -> Result<Self, Self::Error> {
        match value.len() {
            0 => Err("SSID must have at least one byte"),
            1..=32 => Ok(Self(value)),
            _ => Err("SSID must have at most 32 bytes"),
        }
    }
}

impl<'a> TryFrom<&'a str> for Ssid<'a> {
    type Error = &'static str;

    /// Validates the UTF-8 encoding of `value` exactly as the byte-slice
    /// conversion does.
    fn try_from(value: &'a str) -> Result<Self, Self::Error> {
        let raw: &'a [u8] = value.as_bytes();
        Ssid::try_from(raw)
    }
}

/// A WPA passphrase of 8 to 63 printable ASCII characters (bytes 32..=126).
///
/// The only way to obtain a value is through the validating `TryFrom`
/// conversions. The wrapped bytes are therefore always printable ASCII, which
/// also makes them valid UTF-8.
#[derive(Debug)]
pub struct Passphrase<'a>(&'a [u8]);

impl<'a> TryFrom<&'a [u8]> for Passphrase<'a> {
    type Error = &'static str;

    /// Validates length and character set, then borrows `value`.
    ///
    /// # Errors
    ///
    /// Fails in any of these cases:
    /// - `value` is shorter than 8 bytes.
    /// - `value` is longer than 63 bytes.
    /// - `value` contains any byte outside the printable ASCII range `' '..='~'`.
    fn try_from(value: &'a [u8]) -> Result<Self, Self::Error> {
        if value.len() < 8 {
            Err("passphrase must have at least 8 bytes")
        } else if value.len() > 63 {
            Err("passphrase must have at most 63 bytes")
        } else if !value.iter().all(|b| (b' '..=b'~').contains(b)) {
            // ASCII 32 (space) through 126 (tilde) are the printable characters.
            Err("passphrase must consist of printable ASCII characters")
        } else {
            Ok(Passphrase(value))
        }
    }
}

impl<'a> TryFrom<&'a str> for Passphrase<'a> {
    type Error = &'static str;

    /// Validates the UTF-8 bytes of `value` exactly as the byte-slice
    /// conversion does.
    fn try_from(value: &'a str) -> Result<Self, Self::Error> {
        Self::try_from(value.as_bytes())
    }
}

impl Display for Passphrase<'_> {
    /// Writes the passphrase as text.
    ///
    /// Formatting flags (width, fill, alignment, precision) are honoured in
    /// the same way as for `str`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Construction guarantees printable ASCII, so decoding cannot fail.
        let text = std::str::from_utf8(self.0)
            .expect("Passphrase invariant violated: bytes are not printable ASCII");
        // `pad` respects the caller's formatting flags; `write!(f, "{}", ..)` would drop them.
        f.pad(text)
    }
}

/// Returns the WPA-PSK of the given SSID and passphrase.
pub fn wpa_psk(ssid: &Ssid, passphrase: &Passphrase) -> [u8; 32] {
    wpa_psk_unchecked(ssid.0, passphrase.0)
}

/// Derives the 32-byte WPA pre-shared key from raw bytes, skipping validation.
///
/// The key is derived with PBKDF2-HMAC-SHA1, using:
/// - the passphrase as the password,
/// - the SSID as the salt,
/// - 4096 iterations.
///
/// Unlike [`wpa_psk`], neither input is checked for length or character set.
pub fn wpa_psk_unchecked(ssid: &[u8], passphrase: &[u8]) -> [u8; 32] {
    // The derived key fills this 32-byte output buffer.
    let mut derived_key = [0_u8; 32];
    pbkdf2::<Hmac<Sha1>>(passphrase, ssid, 4096, &mut derived_key);
    derived_key
}

/// Returns the lowercase hexadecimal representation of the given bytes.
///
/// Each byte becomes exactly two characters, so the result has
/// `2 * bytes.len()` characters. An empty slice yields an empty string.
pub fn bytes_to_hex(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    // One pre-sized buffer instead of a temporary `String` per byte via `format!`.
    let mut hex = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        hex.push(char::from(DIGITS[usize::from(byte >> 4)]));
        hex.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    hex
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn special_characters() {
        let ssid = Ssid::try_from("123abcABC.,-").unwrap();
        let passphrase = Passphrase::try_from("456defDEF *<:D").unwrap();
        assert_eq!(
            bytes_to_hex(&wpa_psk(&ssid, &passphrase)),
            "8a366e5bc51cd5d8fbbeffacc5f1af23fac30e3ac93cdcc368fafbbf63a1085c"
        );
    }

    #[test]
    fn ieee_802_11i_test_vector() {
        // IEEE 802.11i-2004, Annex H.4: passphrase "password", SSID "IEEE".
        let ssid = Ssid::try_from("IEEE").unwrap();
        let passphrase = Passphrase::try_from("password").unwrap();
        assert_eq!(
            bytes_to_hex(&wpa_psk(&ssid, &passphrase)),
            "f42c6fc52df0ebef9ebb4b90b38a5f902e83fe1b135a70e23aed762e9710a12e"
        );
    }

    #[test]
    fn checked_and_unchecked_agree() {
        let ssid = Ssid::try_from("home").unwrap();
        let passphrase = Passphrase::try_from("0123-4567-89").unwrap();
        assert_eq!(
            wpa_psk(&ssid, &passphrase),
            wpa_psk_unchecked(b"home", b"0123-4567-89")
        );
    }

    #[test]
    fn ssid_length_bounds() {
        assert!(Ssid::try_from("").is_err());
        assert!(Ssid::try_from("a").is_ok());
        assert!(Ssid::try_from("a".repeat(32).as_str()).is_ok());
        assert!(Ssid::try_from("a".repeat(33).as_str()).is_err());
    }

    #[test]
    fn passphrase_too_short() {
        Passphrase::try_from("foobar").unwrap_err();
    }

    #[test]
    fn passphrase_length_bounds() {
        assert!(Passphrase::try_from("a".repeat(7).as_str()).is_err());
        assert!(Passphrase::try_from("a".repeat(8).as_str()).is_ok());
        assert!(Passphrase::try_from("a".repeat(63).as_str()).is_ok());
        assert!(Passphrase::try_from("a".repeat(64).as_str()).is_err());
    }

    #[test]
    fn passphrase_rejects_non_printable_and_non_ascii() {
        // Tab (0x09) is ASCII but not printable.
        assert!(Passphrase::try_from("tab\there!").is_err());
        // DEL (0x7f) lies just above the printable range.
        assert!(Passphrase::try_from(&b"abcdefg\x7f"[..]).is_err());
        // Multi-byte UTF-8 sequences are outside ASCII.
        assert!(Passphrase::try_from("pässwörd").is_err());
    }

    #[test]
    fn display_passphrase() {
        assert_eq!(
            format!("{}", Passphrase::try_from("foobarbuzz").unwrap()),
            "foobarbuzz"
        );
    }
}
