// SPDX-License-Identifier: GPL-3.0-or-later

use std::convert::TryInto;

use sodiumoxide::crypto::aead::xchacha20poly1305_ietf::{Nonce, NONCEBYTES};
use sodiumoxide::crypto::pwhash::argon2id13::{MemLimit, OpsLimit, Salt, SALTBYTES};
use thiserror::Error;

/// Errors returned by [`Header::deserialize`] when input cannot be decoded as a valid header.
#[derive(Error, Debug, PartialEq, Eq)]
pub enum HeaderSerdesError {
    /// The input holds fewer bytes than a complete header (`expected`) requires.
    #[error("Header length {found:?} is less than {expected:?}")]
    TooShort { found: usize, expected: usize },
    /// The leading magic number differs from the expected constant.
    #[error("Header magic {found:?} does not match {expected:?}")]
    WrongMagic { found: u32, expected: u32 },
    /// The header declares a format version other than the one this code supports.
    #[error("Header version {found:?} does not match supported version {expected:?}")]
    UnsupportedVersion { found: u32, expected: u32 },
}

/// Header carrying the parameters needed to derive a key and decrypt: Argon2id cost
/// limits, the password-hashing salt, and the XChaCha20-Poly1305 nonce.
///
/// See `Header::serialize` / `Header::deserialize` for the byte layout.
#[derive(Debug)]
pub struct Header {
    /// Identifying constant; `Header::deserialize` rejects anything but `Header::MAGIC`.
    pub magic: u32,
    /// Format version; `Header::deserialize` rejects anything but `Header::VERSION`.
    pub version: u32,
    /// Argon2id operations (time cost) limit.
    pub opslimit: OpsLimit,
    /// Argon2id memory limit.
    pub memlimit: MemLimit,
    /// Salt for the Argon2id password hash.
    pub salt: Salt,
    /// Nonce for the XChaCha20-Poly1305 AEAD.
    pub nonce: Nonce,
}

// Implemented by hand because `Salt` and `Nonce` do not implement `PartialEq`;
// their inner byte arrays are compared directly instead.
impl PartialEq for Header {
    /// Headers are equal when every field matches, including the raw salt and nonce bytes.
    fn eq(&self, other: &Self) -> bool {
        let scalars = |h: &Header| (h.magic, h.version, h.opslimit.0, h.memlimit.0);

        scalars(self) == scalars(other)
            && self.salt.0 == other.salt.0
            && self.nonce.0 == other.nonce.0
    }
}

impl Eq for Header {}

impl Header {
    pub const MAGIC: u32 = 0xA8F988BA;
    // TODO bump this to version 1 eventually, maybe
    pub const VERSION: u32 = 0;
    pub const SIZE: usize = 4 + 4 + 8 + 8 + SALTBYTES + NONCEBYTES;

    pub fn default() -> Header {
        // Should really be imported constants, later
        let opslimit = if cfg!(test) { OpsLimit(1) } else { OpsLimit(4) };
        let memlimit = if cfg!(test) { MemLimit(8192) } else { MemLimit(1073741824) };
        Header {
            magic: Header::MAGIC,
            version: Header::VERSION,
            opslimit,
            memlimit,
            // This is of course an invalid salt and nonce, but they are overwritten
            // with randomly generated values as needed.
            salt: Salt([0u8; SALTBYTES]),
            nonce: Nonce([0u8; NONCEBYTES]),
        }
    }

    pub fn serialize(&self) -> [u8; Header::SIZE] {
        let OpsLimit(opslimit) = self.opslimit;
        let MemLimit(memlimit) = self.memlimit;
        let Salt(salt) = self.salt;
        let Nonce(nonce) = self.nonce;

        let mut bytes = [0u8; Header::SIZE];

        // TODO: there has to be a better way
        bytes[0..4].copy_from_slice(&self.magic.to_le_bytes());
        bytes[4..8].copy_from_slice(&self.version.to_le_bytes());
        bytes[8..16].copy_from_slice(&opslimit.to_le_bytes());
        bytes[16..24].copy_from_slice(&memlimit.to_le_bytes());
        bytes[24..(24 + SALTBYTES)].copy_from_slice(&salt);
        bytes[(24 + SALTBYTES)..(24 + SALTBYTES + NONCEBYTES)].copy_from_slice(&nonce);

        bytes
    }

    pub fn deserialize(bytes: &[u8]) -> Result<Header, HeaderSerdesError> {
        if bytes.len() < Header::SIZE {
            return Err(HeaderSerdesError::TooShort { found: bytes.len(), expected: Header::SIZE });
        }

        // Rust doesn't know that these conversions won't fail, hence the unwrap
        let magic = u32::from_le_bytes(bytes[0..4].try_into().unwrap());
        let version = u32::from_le_bytes(bytes[4..8].try_into().unwrap());
        let opslimit = OpsLimit(
            u64::from_le_bytes(bytes[8..16].try_into().unwrap())
                .try_into()
                .unwrap(),
        );
        let memlimit = MemLimit(
            u64::from_le_bytes(bytes[16..24].try_into().unwrap())
                .try_into()
                .unwrap(),
        );
        let salt = Salt(bytes[24..24 + SALTBYTES].try_into().unwrap());
        let nonce = Nonce(
            bytes[(24 + SALTBYTES)..(24 + SALTBYTES + NONCEBYTES)]
                .try_into()
                .unwrap(),
        );

        if magic != Header::MAGIC {
            return Err(HeaderSerdesError::WrongMagic { found: magic, expected: Header::MAGIC });
        }

        if version != Header::VERSION {
            return Err(HeaderSerdesError::UnsupportedVersion { found: version, expected: Header::VERSION });
        }

        let header = Header {
            magic,
            version,
            opslimit,
            memlimit,
            salt,
            nonce,
        };

        Ok(header)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn serialization_roundtrips() {
        let header = Header::default();
        let bytes = header.serialize();
        let des_header = Header::deserialize(&bytes).unwrap();

        assert_eq!(header, des_header);
    }

    /// Pins the byte layout so accidental format changes are caught.
    #[test]
    fn serialization_uses_fixed_little_endian_layout() {
        let bytes = Header::default().serialize();

        assert_eq!(bytes.len(), Header::SIZE);
        assert_eq!(&bytes[0..4], &Header::MAGIC.to_le_bytes());
        assert_eq!(&bytes[4..8], &Header::VERSION.to_le_bytes());
        // Test builds use opslimit 1 and memlimit 8192, each stored as eight bytes.
        assert_eq!(&bytes[8..16], &1u64.to_le_bytes());
        assert_eq!(&bytes[16..24], &8192u64.to_le_bytes());
    }

    #[test]
    fn deserialization_ignores_trailing_bytes() {
        let header = Header::default();
        let mut bytes = header.serialize().to_vec();
        bytes.extend_from_slice(&[0xFF; 7]);

        assert_eq!(Header::deserialize(&bytes).unwrap(), header);
    }

    #[test]
    fn deserialization_detects_too_short_header() {
        let bytes = [0u8; 10];
        let error = Header::deserialize(&bytes).unwrap_err();
        assert_eq!(error, HeaderSerdesError::TooShort { found: 10, expected: Header::SIZE });
    }

    #[test]
    fn deserialization_rejects_header_one_byte_short() {
        let bytes = Header::default().serialize();
        let error = Header::deserialize(&bytes[..Header::SIZE - 1]).unwrap_err();
        assert_eq!(
            error,
            HeaderSerdesError::TooShort { found: Header::SIZE - 1, expected: Header::SIZE }
        );
    }

    #[test]
    fn deserialization_detects_wrong_magic() {
        let header = Header { magic: 0x1, ..Header::default() };
        let bytes = header.serialize();
        let error = Header::deserialize(&bytes).unwrap_err();
        assert_eq!(error, HeaderSerdesError::WrongMagic { found: 0x1, expected: Header::MAGIC });
    }

    #[test]
    fn deserialization_detects_wrong_version() {
        let header = Header { version: 100, ..Header::default() };
        let bytes = header.serialize();
        let error = Header::deserialize(&bytes).unwrap_err();
        assert_eq!(
            error,
            HeaderSerdesError::UnsupportedVersion { found: 100, expected: Header::VERSION }
        );
    }
}
