use argonautica::Hasher;
use argonautica::config::Variant;
use std::cmp::min;
use std::fmt;
use crate::bottle_error::{BottleError, BottleResult};
use crate::encryption::cprng;


/// Argon2id key-derivation parameters together with the salt they are used with.
///
/// The salt is stored inline in a fixed 32-byte array; only the first `salt_len`
/// bytes are meaningful and the rest are kept zeroed by every constructor, so the
/// derived equality compares exactly the parameters and the used salt.
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct Argon {
    /// Iteration count handed to the hasher.
    pub time_cost: u32,
    /// Base-2 logarithm of the memory size handed to the hasher (`1 << memory_cost_bits`).
    pub memory_cost_bits: u32,
    /// Number of lanes handed to the hasher.
    pub parallelism: u32,
    salt: [u8; 32], // might need to expand this if we add new encryption algorithms
    /// Number of leading bytes of `salt` that are in use (at most 32).
    salt_len: usize,
}

impl Argon {
    pub fn new() -> BottleResult<Argon> {
        let mut rv = Argon {
            time_cost: 3,
            memory_cost_bits: 12,
            parallelism: 1,
            salt: [0u8; 32],
            salt_len: 32,
        };
        cprng(&mut rv.salt)?;
        Ok(rv)
    }

    pub fn from_salt(salt: &[u8]) -> Argon {
        let mut rv = Argon {
            time_cost: 3,
            memory_cost_bits: 12,
            parallelism: 1,
            salt: [0u8; 32],
            salt_len: min(salt.len(), 32),
        };
        rv.salt[0 .. rv.salt_len].copy_from_slice(&salt[0 .. rv.salt_len]);
        rv
    }

    /// Generate key material from a user-supplied password.
    /// Fills `key` (which must be the size of the key material you need),
    /// but also returns `key` as a reference slice for convenience.
    pub fn generate_key<'a>(&self, key: &'a mut [u8], password: &[u8]) -> BottleResult<&'a [u8]> {
        let mut hasher = Hasher::default();
        let hash = hasher
            .configure_hash_len(key.len() as u32)
            .configure_iterations(self.time_cost)
            .configure_memory_size(1 << self.memory_cost_bits)
            .configure_lanes(self.parallelism)
            .configure_variant(Variant::Argon2id)
            .opt_out_of_secret_key(true)
            .with_salt(&self.salt[..self.salt_len])
            .with_password(password)
            .hash_raw().map_err(BottleError::Argon2Error)?;
        key.copy_from_slice(hash.raw_hash_bytes());
        Ok(key)
    }

    // returns the slice that it filled
    pub fn encode<'a>(&self, buffer: &'a mut [u8]) -> &'a [u8] {
        assert!(buffer.len() >= 5 + self.salt_len, "buffer must be 5 bytes + the salt length");
        buffer[0] = 0;
        buffer[1..3].copy_from_slice(&(self.time_cost as u16).to_le_bytes());
        buffer[3] = self.memory_cost_bits as u8;
        buffer[4] = self.parallelism as u8;
        buffer[5 .. self.salt_len + 5].copy_from_slice(&self.salt[0 .. self.salt_len]);
        &buffer[0 .. 5 + self.salt_len]
    }

    pub fn decode(buffer: &[u8]) -> Argon {
        assert!(buffer.len() > 5, "buffer must be 5 bytes + the salt length");
        assert!(buffer[0] == 0, "unsupported argon encoding");
        let time_cost = u16::from_le_bytes((&buffer[1..3]).try_into().unwrap()) as u32;
        let memory_cost_bits = buffer[3] as u32;
        let parallelism = buffer[4] as u32;
        let mut salt = [0u8; 32];
        let salt_len = buffer.len() - 5;
        salt[..salt_len].copy_from_slice(&buffer[5..]);
        Argon { time_cost, memory_cost_bits, parallelism, salt, salt_len }
    }
}

impl fmt::Debug for Argon {
    /// Formats the cost parameters and the hex-encoded salt (only the `salt_len` bytes in use).
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // The closing parenthesis was previously missing from this format string.
        write!(
            f,
            "Argon(time_cost={}, memory_cost_bits={}, parallelism={}, salt={})",
            self.time_cost,
            self.memory_cost_bits,
            self.parallelism,
            hex::encode(&self.salt[..self.salt_len])
        )
    }
}


#[cfg(test)]
mod test {
    use hex::encode;
    use super::Argon;

    /// Known-answer test: a fixed salt and password must always derive the same key.
    #[test]
    fn generate_key() {
        let argon = Argon::from_salt(&[0u8; 16]);
        let mut key = [0u8; 16];
        argon.generate_key(&mut key, b"plus/minus").unwrap();
        assert_eq!(encode(key), "21807ee7d2ba4ce8aba01abf4d8a169e");
    }

    /// The encoded header layout is stable and decoding restores the exact parameters.
    #[test]
    fn encode_decode_round_trip() {
        let mut argon = Argon::from_salt(&[7u8; 20]);
        argon.time_cost = 5;
        argon.memory_cost_bits = 10;
        argon.parallelism = 2;
        let mut buffer = [0u8; 64];
        let encoded = argon.encode(&mut buffer);
        assert_eq!(encoded.len(), 5 + 20);
        // version 0, time_cost 5 (LE u16), memory_cost_bits 10, parallelism 2
        assert_eq!(encode(&encoded[..5]), "0005000a02");
        assert_eq!(Argon::decode(encoded), argon);
    }

    /// A salt longer than the 32-byte inline storage cannot be decoded.
    #[test]
    #[should_panic]
    fn decode_rejects_oversized_salt() {
        Argon::decode(&[0u8; 5 + 33]);
    }
}
