// SPDX-License-Identifier: MPL-2.0

use thiserror::Error;

use crate::crypto::argon2::{Iterations, Memory, Threads, Salt, SALTBYTES};
use crate::crypto::chacha20::{Nonce, NONCEBYTES};

#[derive(Error, Debug, PartialEq, Eq)]
pub enum HeaderSerdesError {
    #[error("Header length {found:?} is less than {expected:?}")]
    TooShort {
        found: usize,
        expected: usize,
    },
    #[error("Header magic {found:?} does not match {expected:?}")]
    WrongMagic {
        found: u32,
        expected: u32,
    },
    #[error("Header version {found:?} does not match supported version {expected:?}")]
    UnsupportedVersion {
        found: u32,
        expected: u32,
    }
}

/// In-memory form of the fixed-size header: format identification (magic and
/// version), the Argon2 parameters, and the salt and nonce.
#[derive(PartialEq, Eq, Debug)]
pub struct Header {
    /// Format identifier; `Header::deserialize` rejects anything other than `Header::MAGIC`.
    pub magic: u32,
    /// Format version; `Header::deserialize` rejects anything other than `Header::VERSION`.
    pub version: u32,
    /// Argon2 iteration parameter.
    pub iterations: Iterations,
    /// Argon2 memory parameter.
    pub memory: Memory,
    /// Argon2 thread/parallelism parameter.
    pub threads: Threads,
    /// Argon2 salt.
    pub salt: Salt,
    /// Nonce type from the `chacha20` module (presumably the cipher nonce — confirm).
    pub nonce: Nonce,
}

impl Header {
    pub const MAGIC: u32 = 0xA8F988BA;
    // TODO bump this to version 1 eventually, maybe
    pub const VERSION: u32 = 0;
    pub const ITERATIONS: Iterations = if cfg!(test) { Iterations(1) } else { Iterations(4) };
    pub const MEMORY: Memory = if cfg!(test) { Memory(8) } else { Memory(1048576) };
    pub const THREADS: Threads = Threads(1);

    pub const SIZE: usize = 4 + 4 + 4 + 4 + 4 + SALTBYTES + NONCEBYTES;

    pub fn default() -> Header {
        Header {
            magic: Header::MAGIC,
            version: Header::VERSION,
            iterations: Header::ITERATIONS,
            memory: Header::MEMORY,
            threads: Header::THREADS,
            // This is of course an invalid salt and nonce, but they are overwritten
            // with randomly generated values as needed.
            salt: Salt([0u8; SALTBYTES]),
            nonce: Nonce([0u8; NONCEBYTES]),
        }
    }

    pub fn serialize(&self) -> [u8; Header::SIZE] {
        let mut bytes = [0u8; Header::SIZE];

        bytes[0..4].copy_from_slice(&self.magic.to_le_bytes());
        bytes[4..8].copy_from_slice(&self.version.to_le_bytes());
        bytes[8..12].copy_from_slice(&self.iterations.0.to_le_bytes());
        bytes[12..16].copy_from_slice(&self.memory.0.to_le_bytes());
        bytes[16..20].copy_from_slice(&self.threads.0.to_le_bytes());
        bytes[20..(20 + SALTBYTES)].copy_from_slice(&self.salt.0);
        bytes[(20 + SALTBYTES)..(20 + SALTBYTES + NONCEBYTES)].copy_from_slice(&self.nonce.0);

        bytes
    }

    pub fn deserialize(bytes: &[u8]) -> Result<Header, HeaderSerdesError> {
        if bytes.len() < Header::SIZE {
            return Err(HeaderSerdesError::TooShort { found: bytes.len(), expected: Header::SIZE });
        }

        // Indexing into the bytes slice won't fail since we already checked the length, but
        // Rust doesn't know that, hence the unwrap
        let magic = u32::from_le_bytes(bytes[0..4].try_into().unwrap());
        let version = u32::from_le_bytes(bytes[4..8].try_into().unwrap());
        let iterations = u32::from_le_bytes(bytes[8..12].try_into().unwrap());
        let memory = u32::from_le_bytes(bytes[12..16].try_into().unwrap());
        let threads = u32::from_le_bytes(bytes[16..20].try_into().unwrap());
        let salt = bytes[20..20 + SALTBYTES].try_into().unwrap();
        let nonce = bytes[(20 + SALTBYTES)..(20 + SALTBYTES + NONCEBYTES)].try_into().unwrap();

        if magic != Header::MAGIC {
            return Err(HeaderSerdesError::WrongMagic { found: magic, expected: Header::MAGIC });
        }

        if version != Header::VERSION {
            return Err(HeaderSerdesError::UnsupportedVersion { found: version, expected: Header::VERSION });
        }

        let header = Header {
            magic,
            version,
            iterations: Iterations(iterations),
            memory: Memory(memory),
            threads: Threads(threads),
            salt: Salt(salt),
            nonce: Nonce(nonce),
        };

        Ok(header)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn serialization_roundtrips() {
        let header = Header::default();
        let bytes = header.serialize();
        let des_header = Header::deserialize(&bytes).unwrap();

        assert_eq!(header, des_header);
    }

    #[test]
    fn serialization_roundtrips_non_default_values() {
        // Non-zero, distinct salt/nonce patterns catch swapped or misaligned field ranges.
        let header = Header {
            iterations: Iterations(3),
            memory: Memory(64),
            threads: Threads(2),
            salt: Salt([0x5a; SALTBYTES]),
            nonce: Nonce([0xa5; NONCEBYTES]),
            ..Header::default()
        };
        let bytes = header.serialize();
        assert_eq!(Header::deserialize(&bytes).unwrap(), header);
    }

    #[test]
    fn serialization_writes_little_endian_magic_first() {
        let bytes = Header::default().serialize();
        assert_eq!(bytes.len(), Header::SIZE);
        assert_eq!(&bytes[0..4], &Header::MAGIC.to_le_bytes()[..]);
    }

    #[test]
    fn deserialization_ignores_trailing_bytes() {
        let header = Header::default();
        let mut bytes = header.serialize().to_vec();
        bytes.extend_from_slice(&[0xff; 16]);
        assert_eq!(Header::deserialize(&bytes).unwrap(), header);
    }

    #[test]
    fn deserialization_detects_too_short_header() {
        let bytes = [0u8; 10];
        let error = Header::deserialize(&bytes).unwrap_err();
        assert_eq!(error, HeaderSerdesError::TooShort { found: 10, expected: Header::SIZE });
    }

    #[test]
    fn deserialization_detects_header_one_byte_short() {
        let bytes = Header::default().serialize();
        let error = Header::deserialize(&bytes[..Header::SIZE - 1]).unwrap_err();
        assert_eq!(
            error,
            HeaderSerdesError::TooShort { found: Header::SIZE - 1, expected: Header::SIZE }
        );
    }

    #[test]
    fn deserialization_detects_wrong_magic() {
        let header = Header { magic: 0x1, ..Header::default() };
        let bytes = header.serialize();
        let error = Header::deserialize(&bytes).unwrap_err();
        assert_eq!(error, HeaderSerdesError::WrongMagic { found: 0x1, expected: Header::MAGIC });
    }

    #[test]
    fn deserialization_detects_wrong_version() {
        let header = Header { version: 100, ..Header::default() };
        let bytes = header.serialize();
        let error = Header::deserialize(&bytes).unwrap_err();
        assert_eq!(
            error,
            HeaderSerdesError::UnsupportedVersion { found: 100, expected: Header::VERSION }
        );
    }
}
