//! # sjcl
//! Simple decrypt-only SJCL library.
//!
//! Only AES-CCM is supported so far. OCB2 is not implemented since, as far as
//! the author knows, it is deprecated.
//! To decrypt a secret you only need the JSON output that SJCL produced when
//! encrypting it, plus the passphrase.
//!
//! ## Usage
//!
//! Decrypt a file loaded into a string:
//! ```rust
//! use sjcl::{decrypt_raw, SjclError};
//! 
//! # fn main() -> Result<(), SjclError> {
//! let data = "{\"iv\":\"nJu7KZF2eEqMv403U2oc3w==\", \"v\":1, \"iter\":10000, \"ks\":256, \"ts\":64, \"mode\":\"ccm\", \"adata\":\"\", \"cipher\":\"aes\", \"salt\":\"mMmxX6SipEM=\", \"ct\":\"VwnKwpW1ah5HmdvwuFBthx0=\"}".to_string();
//! let password_phrase = "abcdefghi".to_string();
//! let plaintext = decrypt_raw(data, password_phrase)?;
//! assert_eq!("test\ntest".to_string(), String::from_utf8(plaintext).unwrap());
//! # Ok(())
//! # }
//! ```
//!
extern crate base64;

use aes::{Aes128, Aes192, Aes256};
use ccm::aead::{generic_array::GenericArray, Aead, NewAead};
use ccm::{
    consts::{U13, U16, U24, U32, U8},
    Ccm,
};
use password_hash::{PasswordHasher, SaltString};
use pbkdf2::{Params, Pbkdf2};
use serde::Deserialize;
use serde_json;

use thiserror::Error;
#[derive(Error, Debug)]
pub enum SjclError {
    #[error("Failed to decrypt chunk: {message:?}")]
    DecryptionError { message: String },
    #[error("Method is not implemented")]
    NotImplementedError,
}

/// Deserialized block generated by SJCL.
///
/// Fields are public so callers can build a block directly (for example
/// from a format other than SJCL's JSON) and pass it to [`decrypt`].
#[derive(Debug, Deserialize)]
pub struct SjclBlockJson {
    /// Base64-encoded initialization vector; the CCM nonce is derived from it.
    pub iv: String,
    /// SJCL format version; only `1` is accepted by [`decrypt`].
    pub v: u32,
    /// PBKDF2 iteration count used to derive the key from the passphrase.
    pub iter: u32,
    /// Key size in bits (the derived key is `ks / 8` bytes long).
    pub ks: usize,
    /// Authentication tag size in bits.
    pub ts: usize,
    /// Cipher mode name, e.g. `"ccm"`.
    pub mode: String,
    /// Additional authenticated data; [`decrypt`] requires it to be empty.
    pub adata: String,
    /// Cipher name, e.g. `"aes"`.
    pub cipher: String,
    /// Base64-encoded PBKDF2 salt.
    pub salt: String,
    /// Base64-encoded ciphertext with the authentication tag appended.
    pub ct: String,
}

/// AES-128 in CCM mode with an 8-byte (64-bit) tag and a 13-byte nonce.
type AesCcm128 = Ccm<Aes128, U8, U13>;
/// AES-192 in CCM mode with an 8-byte (64-bit) tag and a 13-byte nonce.
type AesCcm192 = Ccm<Aes192, U8, U13>;
/// AES-256 in CCM mode with an 8-byte (64-bit) tag and a 13-byte nonce.
type AesCcm256 = Ccm<Aes256, U8, U13>;

/// Decrypts a chunk of SJCL encrypted JSON with a given passphrase.
/// ```rust
/// let data = "{\"iv\":\"nJu7KZF2eEqMv403U2oc3w==\", \"v\":1, \"iter\":10000, \"ks\":256, \"ts\":64, \"mode\":\"ccm\", \"adata\":\"\", \"cipher\":\"aes\", \"salt\":\"mMmxX6SipEM=\", \"ct\":\"VwnKwpW1ah5HmdvwuFBthx0=\"}".to_string();
/// let password_phrase = "abcdefghi".to_string();
/// let plaintext = decrypt_raw(data, password_phrase)?;
/// assert_eq!("test\ntest".to_string(), String::from_utf8(plaintext).unwrap());
/// ```
pub fn decrypt_raw(chunk: String, key: String) -> Result<Vec<u8>, SjclError> {
    match serde_json::from_str(&chunk) {
        Ok(chunk) => decrypt(chunk, key),
        Err(_) => {
            return Err(SjclError::DecryptionError {
                message: "Failed to parse JSON".to_string(),
            })
        }
    }
}

/// Utility function to trim the initialization vector to the proper size of
/// the CCM nonce, mirroring SJCL's CCM decryption.
/// (See: [SJCL/core.ccm.js](https://github.com/bitwiseshiftleft/sjcl/blob/master/core/ccm.js#L61))
///
/// The length-field size `L` starts at 2 and grows (up to 4) until it can hold
/// the plaintext length in bytes. It is raised further when the IV is shorter
/// than `15 - L` bytes. The nonce is the first `15 - L` bytes of the IV.
///
/// * `iv` - the decoded initialization vector.
/// * `output_size` - ciphertext length **in bits**, including the tag.
/// * `tag_size` - authentication tag length in bits.
///
/// A ciphertext shorter than its tag is treated as an empty plaintext rather
/// than underflowing.
fn truncate_iv(mut iv: Vec<u8>, output_size: usize, tag_size: usize) -> Vec<u8> {
    // Plaintext length in bytes; saturate so a truncated ciphertext cannot
    // underflow (which would panic in debug builds).
    let plaintext_len = output_size.saturating_sub(tag_size) / 8;

    let mut l = 2;
    while l < 4 && (plaintext_len >> (8 * l)) > 0 {
        l += 1;
    }
    // A short IV forces a larger length field; a long one never shrinks it.
    l = l.max(15usize.saturating_sub(iv.len()));

    // `l` is at most 15 here, and `truncate` is a no-op if the IV is shorter.
    iv.truncate(15 - l);
    iv
}

/// Decrypts a chunk of SJCL encrypted JSON with a given passphrase.
/// ```rust
/// let data = SjclBlockJson {
///   iv: "aDvOWpwgcF0S7YDvu3TrTQ==".to_string(),
///   v: 1,
///   iter: 1000,
///   ks: 128,
///   ts: 64,
///   mode: "ccm".to_string(),
///   adata: "".to_string(),
///   cipher: "aes".to_string(),
///   salt: "qpVeWJh4g1I=".to_string(),
///   ct: "3F6gxac5V5k39iUNHubqEOHrxuZJqoX2zyws9nU=".to_string(),
/// };
/// let plaintext = decrypt(data, "abcdefghi".to_string());
/// assert_eq!("but dogs are the best".to_string(), plaintext);
/// ```
pub fn decrypt(mut chunk: SjclBlockJson, key: String) -> Result<Vec<u8>, SjclError> {
    match chunk.cipher.as_str() {
        "aes" => {
            match chunk.mode.as_str() {
                "ccm" => {
                    if chunk.v != 1 {
                        return Err(SjclError::DecryptionError {
                            message: "Only version 1 is currently supported".to_string(),
                        });
                    }
                    if chunk.adata.len() > 0 {
                        return Err(SjclError::DecryptionError {
                            message: "Expected empty additional data".to_string(),
                        });
                    }

                    let salt = match base64::decode(chunk.salt) {
                        Ok(v) => SaltString::b64_encode(&v),
                        Err(_) => {
                            return Err(SjclError::DecryptionError {
                                message: "Failed to base64 decode salt".to_string(),
                            })
                        }
                    };
                    let salt = match salt {
                        Ok(s) => s,
                        Err(_) => {
                            return Err(SjclError::DecryptionError {
                                message: "Failed to generate salt string".to_string(),
                            })
                        }
                    };
                    let password_hash = Pbkdf2.hash_password(
                        key.as_bytes(),
                        None,
                        None,
                        Params {
                            rounds: chunk.iter,
                            output_length: chunk.ks / 8,
                        },
                        salt.as_salt(),
                    );
                    let password_hash = match password_hash {
                        Ok(pwh) => pwh,
                        Err(_) => {
                            return Err(SjclError::DecryptionError {
                                message: "Failed to generate password hash".to_string(),
                            })
                        }
                    };
                    let password_hash = password_hash.hash.unwrap();

                    // Fix missing padding
                    for _ in 0..(chunk.iv.len() % 4) {
                        chunk.iv.push('=');
                    }
                    for _ in 0..(chunk.ct.len() % 4) {
                        chunk.ct.push('=');
                    }
                    let iv = match base64::decode(chunk.iv) {
                        Ok(v) => v,
                        Err(_) => {
                            return Err(SjclError::DecryptionError {
                                message: "Failed to decode IV".to_string(),
                            })
                        }
                    };
                    let ct = match base64::decode(chunk.ct) {
                        Ok(v) => v,
                        Err(_) => {
                            return Err(SjclError::DecryptionError {
                                message: "Failed to decode ct".to_string(),
                            })
                        }
                    };
                    let iv = truncate_iv(iv, ct.len() * 8, chunk.ts);
                    let nonce = GenericArray::from_slice(iv.as_slice());
                    match chunk.ks {
                        256 => {
                            let key: &GenericArray<u8, U32> =
                                GenericArray::from_slice(password_hash.as_bytes());
                            let cipher = AesCcm256::new(key);
                            let plaintext = match cipher.decrypt(nonce, ct.as_ref()) {
                                Ok(pt) => pt,
                                Err(_) => {
                                    return Err(SjclError::DecryptionError {
                                        message: "Failed to decrypt ciphertext".to_string(),
                                    });
                                }
                            };
                            Ok(plaintext)
                        }
                        192 => {
                            let key: &GenericArray<u8, U24> =
                                GenericArray::from_slice(password_hash.as_bytes());
                            let cipher = AesCcm192::new(key);
                            let plaintext = match cipher.decrypt(nonce, ct.as_ref()) {
                                Ok(pt) => pt,
                                Err(_) => {
                                    return Err(SjclError::DecryptionError {
                                        message: "Failed to decrypt ciphertext".to_string(),
                                    });
                                }
                            };
                            Ok(plaintext)
                        }
                        128 => {
                            let key: &GenericArray<u8, U16> =
                                GenericArray::from_slice(password_hash.as_bytes());
                            let cipher = AesCcm128::new(key);
                            let plaintext = match cipher.decrypt(nonce, ct.as_ref()) {
                                Ok(pt) => pt,
                                Err(_) => {
                                    return Err(SjclError::DecryptionError {
                                        message: "Failed to decrypt ciphertext".to_string(),
                                    });
                                }
                            };
                            Ok(plaintext)
                        }
                        _ => Err(SjclError::NotImplementedError),
                    }
                }
                "ocb2" => Err(SjclError::NotImplementedError),
                _ => Err(SjclError::NotImplementedError),
            }
        }
        _ => Err(SjclError::NotImplementedError),
    }
}

// Test vectors were generated with the SJCL demo page:
// https://bitwiseshiftleft.github.io/sjcl/demo/
#[cfg(test)]
mod tests {
    use crate::{decrypt, decrypt_raw, SjclBlockJson};

    /// Passphrase used for every test vector below.
    const PASSPHRASE: &str = "abcdefghi";

    /// Decrypts a JSON blob via `decrypt_raw` and checks the UTF-8 plaintext.
    fn check_raw(json: &str, expected: &str) {
        let bytes = decrypt_raw(json.to_string(), PASSPHRASE.to_string()).unwrap();
        assert_eq!(String::from_utf8(bytes).unwrap(), expected.to_string());
    }

    /// Decrypts a prebuilt block via `decrypt` and checks the UTF-8 plaintext.
    fn check_struct(block: SjclBlockJson, expected: &str) {
        let bytes = decrypt(block, PASSPHRASE.to_string()).unwrap();
        assert_eq!(String::from_utf8(bytes).unwrap(), expected.to_string());
    }

    /// Builds an AES-CCM, version-1, 64-bit-tag block with empty adata.
    fn ccm_block(iv: &str, iter: u32, ks: usize, salt: &str, ct: &str) -> SjclBlockJson {
        SjclBlockJson {
            iv: iv.to_string(),
            v: 1,
            iter,
            ks,
            ts: 64,
            mode: "ccm".to_string(),
            adata: String::new(),
            cipher: "aes".to_string(),
            salt: salt.to_string(),
            ct: ct.to_string(),
        }
    }

    #[test]
    fn test_256bit_end_to_end() {
        check_raw(
            r#"{"iv":"nJu7KZF2eEqMv403U2oc3w==", "v":1, "iter":10000, "ks":256, "ts":64, "mode":"ccm", "adata":"", "cipher":"aes", "salt":"mMmxX6SipEM=", "ct":"VwnKwpW1ah5HmdvwuFBthx0="}"#,
            "test\ntest",
        );
    }

    #[test]
    fn test_256bit_with_struct() {
        // IV and salt intentionally lack their base64 padding.
        let block = ccm_block(
            "nJu7KZF2eEqMv403U2oc3w",
            10000,
            256,
            "mMmxX6SipEM",
            "VwnKwpW1ah5HmdvwuFBthx0=",
        );
        check_struct(block, "test\ntest");
    }

    #[test]
    fn test_192bit_end_to_end() {
        check_raw(
            r#"{"iv":"rUeOzcoSOAmbJIZ4o7wZzA==", "v":1, "iter":1000, "ks":192, "ts":64, "mode":"ccm", "adata":"", "cipher":"aes", "salt":"qpVeWJh4g1I=", "ct":"QJx31ojP+TW25eYZSFnjrG85dOZY"}"#,
            "cats are cute",
        );
    }

    #[test]
    fn test_192bit_with_struct() {
        let block = ccm_block(
            "rUeOzcoSOAmbJIZ4o7wZzA==",
            1000,
            192,
            "qpVeWJh4g1I=",
            "QJx31ojP+TW25eYZSFnjrG85dOZY",
        );
        check_struct(block, "cats are cute");
    }

    #[test]
    fn test_128bit_end_to_end() {
        check_raw(
            r#"{"iv":"aDvOWpwgcF0S7YDvu3TrTQ==", "v":1, "iter":1000, "ks":128, "ts":64, "mode":"ccm", "adata":"", "cipher":"aes", "salt":"qpVeWJh4g1I=", "ct":"3F6gxac5V5k39iUNHubqEOHrxuZJqoX2zyws9nU="}"#,
            "but dogs are the best",
        );
    }

    #[test]
    fn test_128bit_with_struct() {
        let block = ccm_block(
            "aDvOWpwgcF0S7YDvu3TrTQ==",
            1000,
            128,
            "qpVeWJh4g1I=",
            "3F6gxac5V5k39iUNHubqEOHrxuZJqoX2zyws9nU=",
        );
        check_struct(block, "but dogs are the best");
    }
}
