use core::fmt;

use rand_core::{CryptoRng, RngCore};

use crate::{bindings::*, GenerateSecret};

/// Upper bound (inclusive) on key length in bytes, and the capacity of the
/// inline buffer inside `Key`. BLAKE2b itself accepts keys of at most 64 bytes.
const MAX_KEY_SIZE: usize = 64;
/// Lower bound (inclusive) on key length in bytes; `Key::try_from` rejects
/// shorter slices with `InvalidKeySize`.
const MIN_KEY_SIZE: usize = 16;

#[derive(Copy, Clone)]
pub struct Key {
    buf: [u8; MAX_KEY_SIZE],
    len: usize,
}

impl Key {
    /// Borrows the live key bytes: the first `len` bytes of the fixed-size
    /// backing buffer, excluding the unused tail.
    fn expose(&self) -> &[u8] {
        let used = self.len;
        &self.buf[0..used]
    }
}

impl GenerateSecret for Key {
    /// Produces a new 32-byte key filled from the cryptographically secure
    /// generator `rng`.
    fn generate<RNG: RngCore + CryptoRng>(rng: &mut RNG) -> Self {
        let mut bytes = [0u8; 32];
        rng.fill_bytes(&mut bytes);
        // Routed through `From<[u8; 32]>`, which always accepts 32 bytes.
        bytes.into()
    }
}

/// Error returned by `Key::try_from(&[u8])` when the slice is shorter than
/// `MIN_KEY_SIZE` (16) or longer than `MAX_KEY_SIZE` (64) bytes.
///
/// A plain unit value, so it is cheap to copy and can be compared directly
/// (e.g. `assert_eq!(err, InvalidKeySize)`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct InvalidKeySize;

impl fmt::Display for InvalidKeySize {
    /// Writes the fixed human-readable message. Width, fill and alignment
    /// options on the outer formatter are ignored.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "Invalid key size.")
    }
}

impl TryFrom<&[u8]> for Key {
    type Error = InvalidKeySize;

    /// Copies `value` into a new [`Key`].
    ///
    /// # Errors
    ///
    /// Returns [`InvalidKeySize`] if `value` is shorter than `MIN_KEY_SIZE`
    /// (16) or longer than `MAX_KEY_SIZE` (64) bytes.
    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
        let len = value.len();
        if !(MIN_KEY_SIZE..=MAX_KEY_SIZE).contains(&len) {
            return Err(InvalidKeySize);
        }
        // Bytes at index `len` and beyond stay zero; only `buf[..len]` is meaningful.
        let mut buf = [0u8; MAX_KEY_SIZE];
        buf[..len].copy_from_slice(value);
        Ok(Self { buf, len })
    }
}

impl From<[u8; 32]> for Key {
    /// Wraps a 32-byte array as a key.
    ///
    /// Infallible in practice: 32 lies within `MIN_KEY_SIZE..=MAX_KEY_SIZE`,
    /// so the error arm of the slice conversion is never taken.
    fn from(key: [u8; 32]) -> Self {
        match Self::try_from(&key[..]) {
            Ok(wrapped) => wrapped,
            Err(_) => unreachable!(),
        }
    }
}

impl From<[u8; 64]> for Key {
    /// Wraps a 64-byte array as a key, filling the whole backing buffer.
    ///
    /// Infallible in practice: 64 equals `MAX_KEY_SIZE`, so the error arm of
    /// the slice conversion is never taken.
    fn from(key: [u8; 64]) -> Self {
        match Self::try_from(&key[..]) {
            Ok(wrapped) => wrapped,
            Err(_) => unreachable!(),
        }
    }
}

/// Incremental BLAKE2b hasher producing a `HASH_SIZE`-byte digest.
///
/// Wraps the C `crypto_blake2b_ctx` from `bindings`. Create it unkeyed via
/// `Default` or keyed via `From<&Key>`, feed data with
/// [`update`](Self::update), then consume it with [`finish`](Self::finish).
pub struct Context<const HASH_SIZE: usize>(crypto_blake2b_ctx);

impl<const HASH_SIZE: usize> Context<HASH_SIZE> {
    /// Initialises a context keyed with `key`; an empty slice gives plain
    /// (unkeyed) hashing.
    ///
    /// # Panics
    ///
    /// Panics unless `HASH_SIZE` is in `1..=64`.
    fn new(key: &[u8]) -> Self {
        assert!(HASH_SIZE > 0 && HASH_SIZE <= 64);
        // NOTE(review): nothing here bounds `key.len()`. Callers in this file pass
        // either `Key::expose()` (at most MAX_KEY_SIZE bytes) or an empty slice;
        // confirm the C init's key-length limit before adding other callers.
        let mut raw = core::mem::MaybeUninit::<crypto_blake2b_ctx>::zeroed();
        // SAFETY: `raw` is writable storage for exactly one context and is already
        // zeroed. `key` is readable for `key.len()` bytes. The init call fills in
        // the context before `assume_init` reads it back.
        unsafe {
            crypto_blake2b_general_init(raw.as_mut_ptr(), HASH_SIZE, key.as_ptr(), key.len());
            Self(raw.assume_init())
        }
    }

    /// Absorbs `message` into the running hash; may be called repeatedly.
    pub fn update(&mut self, message: &[u8]) {
        let (data, len) = (message.as_ptr(), message.len());
        // SAFETY: `self.0` was initialised in `new`, and `data` is readable for
        // `len` bytes for the duration of the call.
        unsafe { crypto_blake2b_update(&mut self.0, data, len) };
    }

    /// Consumes the context and returns the `HASH_SIZE`-byte digest.
    ///
    /// NOTE(review): no `Drop` impl in this file wipes a context that is dropped
    /// without calling `finish`; confirm whether that is acceptable for keyed use.
    pub fn finish(mut self) -> [u8; HASH_SIZE] {
        let mut digest = [0u8; HASH_SIZE];
        // SAFETY: the context was initialised with `HASH_SIZE` as its output
        // length, and `digest` provides exactly `HASH_SIZE` writable bytes.
        unsafe { crypto_blake2b_final(&mut self.0, digest.as_mut_ptr()) };
        digest
    }
}

impl<const HASH_SIZE: usize> From<&Key> for Context<HASH_SIZE> {
    /// Starts a keyed BLAKE2b computation that uses the live bytes of `key`.
    fn from(key: &Key) -> Self {
        let secret = key.expose();
        Self::new(secret)
    }
}

impl<const HASH_SIZE: usize> Default for Context<HASH_SIZE> {
    /// Starts an unkeyed BLAKE2b computation (zero-length key).
    fn default() -> Self {
        let no_key: &[u8] = &[];
        Self::new(no_key)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly generated key must go through the keyed path and actually
    /// change the digest.
    #[test]
    fn generate_key() {
        let key = Key::generate(&mut rand_core::OsRng);

        let mut keyed = Context::<32>::from(&key);
        keyed.update(b"test");
        let keyed = keyed.finish();

        let mut unkeyed = Context::<32>::default();
        unkeyed.update(b"test");
        let unkeyed = unkeyed.finish();

        // With a random 32-byte key, equal digests would mean the key was ignored.
        assert_ne!(keyed, unkeyed);
    }

    /// Checks every `blake2b` entry of the BLAKE2 known-answer file.
    #[test]
    fn test_vectors() {
        // Test vectors obtained from https://github.com/BLAKE2/BLAKE2/blob/master/testvectors/blake2-kat.json
        // Anchor the path at the crate root so the test does not depend on the
        // working directory the test binary is launched from.
        let path = concat!(env!("CARGO_MANIFEST_DIR"), "/test_vectors/blake2.json");
        let test_vectors = std::fs::read_to_string(path).unwrap();
        let parsed = json::parse(&test_vectors).unwrap();

        let mut checked = 0usize;
        for tv in parsed.members().filter(|tv| tv["hash"] == "blake2b") {
            let key = hex::decode(tv["key"].as_str().unwrap()).unwrap();
            let input = hex::decode(tv["in"].as_str().unwrap()).unwrap();
            let output = hex::decode(tv["out"].as_str().unwrap()).unwrap();
            let mut ctx: Context<64> = if key.is_empty() {
                Context::default()
            } else {
                let key = Key::try_from(key.as_slice()).unwrap();
                Context::from(&key)
            };
            ctx.update(&input);
            assert_eq!(
                output.as_slice(),
                ctx.finish().as_slice(),
                "blake2b vector #{} mismatched",
                checked
            );
            checked += 1;
        }

        // Without this, an empty or reshaped vector file would make the test pass vacuously.
        assert!(checked > 0, "no blake2b vectors found in {}", path);
    }
}
