// SPDX-License-Identifier: GPL-3.0-or-later

use std::collections::BTreeMap;
use std::convert::TryInto;

use anyhow::{anyhow, Context, Result};
use sodiumoxide::crypto::aead::xchacha20poly1305_ietf::Key;
use sodiumoxide::crypto::aead::xchacha20poly1305_ietf;

use crate::database::Database;
use crate::header::Header;

/// Decrypted database contents: a map from entry name to entry content.
///
/// Because this is a `BTreeMap`, iteration is sorted by name. The serialized
/// plaintext is therefore deterministic for a given set of entries.
pub type Entries = BTreeMap<String, String>;

/// Serializes `entries` and seals them under `key`, producing a [`Database`].
///
/// Every call draws a fresh random nonce and stores it in `header`, which is
/// why `header` is borrowed mutably. The serialized header bytes are passed to
/// `seal` as associated data. They are also stored verbatim in the returned
/// `Database`, so later changes to them cause decryption to fail.
pub fn encrypt(header: &mut Header, entries: &Entries, key: &Key) -> Database {
    let serialized_entries = entries_to_plaintext(entries);

    // A nonce must never be reused with the same key, so pick a new one per seal.
    header.nonce = xchacha20poly1305_ietf::gen_nonce();
    let associated_data = header.serialize();

    let sealed = xchacha20poly1305_ietf::seal(
        &serialized_entries,
        Some(&associated_data),
        &header.nonce,
        key,
    );

    Database {
        header: associated_data,
        ciphertext: sealed,
    }
}

/// Opens `database` with `key` and parses the recovered plaintext into [`Entries`].
///
/// The nonce comes from the already-parsed `header`. The raw `database.header`
/// bytes are supplied as associated data, so they are authenticated together
/// with the ciphertext.
///
/// # Errors
///
/// Returns an error with the message "Decryption failed." when `open` rejects
/// the input, e.g. a wrong key or a modified header/ciphertext.
fn decrypt(database: &Database, header: &Header, key: &Key) -> Result<Entries> {
    let opened = xchacha20poly1305_ietf::open(
        &database.ciphertext,
        Some(&database.header),
        &header.nonce,
        key,
    );

    match opened {
        Ok(plaintext) => Ok(plaintext_to_entries(&plaintext)),
        Err(_) => Err(anyhow!("Decryption failed.")),
    }
}

/// Decrypts `database`. On failure, it adds context suggesting that the
/// passphrase (and thus the key) was wrong.
pub fn decrypt_soft(database: &Database, header: &Header, key: &Key) -> Result<Entries> {
    let outcome = decrypt(database, header, key);
    outcome.with_context(|| "Did you enter the wrong passphrase?")
}

/// Decrypts `database`. On failure, it adds context reporting the database as
/// corrupted.
///
/// NOTE(review): presumably called when the key is already trusted, so a
/// failure points at damaged data rather than a bad passphrase. Confirm with
/// callers.
pub fn decrypt_hard(database: &Database, header: &Header, key: &Key) -> Result<Entries> {
    let outcome = decrypt(database, header, key);
    outcome.with_context(|| "Database has been corrupted!")
}

/// Flattens `entries` into a single byte buffer.
///
/// Each entry is written as its name followed by its content, each encoded by
/// `serialize_string` (little-endian `u32` byte length, then the UTF-8 bytes).
/// Entries are emitted in the map's iteration order, i.e. sorted by name.
fn entries_to_plaintext(entries: &Entries) -> Vec<u8> {
    entries
        .iter()
        .fold(Vec::new(), |mut buffer, (name, content)| {
            serialize_string(&mut buffer, name);
            serialize_string(&mut buffer, content);
            buffer
        })
}

/// Parses a buffer in the `entries_to_plaintext` format back into [`Entries`].
///
/// Alternating name/content strings are read until the buffer is used up. If a
/// name appears twice, the later content replaces the earlier one.
///
/// # Panics
///
/// Panics if the buffer is malformed: a truncated length prefix or body, or
/// non-UTF-8 bytes. These panics come from `deserialize_string`.
fn plaintext_to_entries(plaintext: &[u8]) -> Entries {
    let mut parsed = Entries::new();
    let mut cursor: usize = 0;

    while cursor < plaintext.len() {
        let name = deserialize_string(plaintext, &mut cursor);
        let content = deserialize_string(plaintext, &mut cursor);
        parsed.insert(name, content);
    }

    // Sanity check: the last record must end exactly at the buffer's end.
    assert!(cursor == plaintext.len());

    parsed
}

/// Appends `string` to `bytes` as a length-prefixed record: the byte length
/// as a little-endian `u32`, followed by the UTF-8 bytes themselves.
///
/// # Panics
///
/// Panics if `string` is longer than `u32::MAX` bytes.
fn serialize_string(bytes: &mut Vec<u8>, string: &str) {
    let raw = string.as_bytes();
    let prefix: u32 = raw.len().try_into().unwrap();
    bytes.extend_from_slice(&prefix.to_le_bytes());
    bytes.extend_from_slice(raw);
}

/// Reads one length-prefixed string from `bytes` at offset `*i` and advances
/// `*i` to just past it.
///
/// The layout matches `serialize_string`: a little-endian `u32` byte length,
/// followed by that many UTF-8 bytes.
///
/// # Panics
///
/// Panics if fewer than 4 bytes remain for the prefix, if the body extends
/// past the end of `bytes`, or if the body is not valid UTF-8.
fn deserialize_string(bytes: &[u8], i: &mut usize) -> String {
    let prefix_end = *i + 4;
    let prefix: [u8; 4] = bytes[*i..prefix_end].try_into().unwrap();
    let size: usize = u32::from_le_bytes(prefix).try_into().unwrap();
    *i = prefix_end;

    let body_end = *i + size;
    let text = String::from_utf8(bytes[*i..body_end].to_vec()).unwrap();
    *i = body_end;
    text
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn string_serialization_roundtrips() {
        let strings = ["", "abc", "Hello world!", "平仮名", "😀😁"];
        for string in strings.iter() {
            let mut bytes = Vec::new();
            let mut i = 0;
            serialize_string(&mut bytes, string);
            // `deserialize_string` only needs a shared borrow of the buffer.
            assert_eq!(*string, deserialize_string(&bytes, &mut i));
            // The cursor ends after the 4-byte length prefix plus the UTF-8 body.
            assert_eq!(i, 4 + string.len());
        }
    }

    #[test]
    fn entries_serialization_roundtrips() {
        let mut entries = Entries::new();
        assert_eq!(entries, plaintext_to_entries(&entries_to_plaintext(&entries)));

        entries.insert("".to_string(), "".to_string());
        assert_eq!(entries, plaintext_to_entries(&entries_to_plaintext(&entries)));

        entries.insert("Hello".to_string(), "World".to_string());
        assert_eq!(entries, plaintext_to_entries(&entries_to_plaintext(&entries)));

        entries.insert("平仮名".to_string(), "😀😁".to_string());
        assert_eq!(entries, plaintext_to_entries(&entries_to_plaintext(&entries)));
    }

    #[test]
    fn encryption_roundtrips() {
        let key = xchacha20poly1305_ietf::gen_key();
        let mut header = Header::default();

        let mut entries = Entries::new();
        assert_eq!(entries, decrypt(&encrypt(&mut header, &entries, &key), &header, &key).unwrap());

        entries.insert("".to_string(), "".to_string());
        assert_eq!(entries, decrypt(&encrypt(&mut header, &entries, &key), &header, &key).unwrap());

        entries.insert("Hello".to_string(), "World".to_string());
        assert_eq!(entries, decrypt(&encrypt(&mut header, &entries, &key), &header, &key).unwrap());

        entries.insert("平仮名".to_string(), "😀😁".to_string());
        assert_eq!(entries, decrypt(&encrypt(&mut header, &entries, &key), &header, &key).unwrap());
    }

    #[test]
    fn database_header_corruption_is_detected() {
        let key = xchacha20poly1305_ietf::gen_key();
        let mut header = Header::default();

        let entries = Entries::new();
        let mut database = encrypt(&mut header, &entries, &key);

        // Flip a bit instead of `+= 1`. An increment panics with
        // "attempt to add with overflow" in debug builds when the byte is 255.
        database.header[0] ^= 0x01;

        assert!(decrypt(&database, &header, &key).is_err());
    }

    #[test]
    fn database_ciphertext_corruption_is_detected() {
        let key = xchacha20poly1305_ietf::gen_key();
        let mut header = Header::default();

        let entries = Entries::new();
        let mut database = encrypt(&mut header, &entries, &key);

        // Flip a bit instead of `+= 1`: byte 0 here is effectively random, so
        // an increment would overflow-panic in debug builds ~1/256 of the time.
        database.ciphertext[0] ^= 0x01;

        assert!(decrypt(&database, &header, &key).is_err());
    }

    #[test]
    fn decryption_with_wrong_key_fails() {
        let key = xchacha20poly1305_ietf::gen_key();
        let other_key = xchacha20poly1305_ietf::gen_key();
        let mut header = Header::default();

        let mut entries = Entries::new();
        entries.insert("Hello".to_string(), "World".to_string());
        let database = encrypt(&mut header, &entries, &key);

        assert!(decrypt(&database, &header, &other_key).is_err());
        assert!(decrypt_soft(&database, &header, &other_key).is_err());
    }
}
