use super::*;
use crate::key_derivation::*;

use bytes::{Buf, Bytes, BytesMut};
use lazy_static::lazy_static;

/// One SRTCP known-answer vector: a protected packet paired with the plain RTCP packet it
/// corresponds to, plus the sender state needed to reproduce the protected form.
pub struct RTCPTestCase {
    /// SSRC passed to `set_index` before encrypting. In the vectors defined in this file it
    /// matches bytes 4..8 of both `encrypted` and `decrypted`.
    ssrc: u32,
    /// Index passed to `set_index` before encrypting. In the vectors defined in this file the
    /// index word carried by `encrypted` is this value plus one (wrapping at 31 bits).
    index: usize,
    /// The protected packet: expected output of `encrypt_rtcp` and input to `decrypt_rtcp`.
    encrypted: Bytes,
    /// The plain packet: expected output of `decrypt_rtcp` and input to `encrypt_rtcp`.
    decrypted: Bytes,
}

// Shared fixtures for the SRTCP tests below: one master key, one master salt and a table of
// known-answer vectors. Every test builds its contexts from this key/salt pair.
lazy_static! {
    // 16-byte master key.
    static ref RTCP_TEST_MASTER_KEY: Bytes = Bytes::from_static(&[
        0xfd, 0xa6, 0x25, 0x95, 0xd7, 0xf6, 0x92, 0x6f, 0x7d, 0x9c, 0x02, 0x4c, 0xc9, 0x20, 0x9f,
        0x34
    ]);

    // 14-byte master salt.
    static ref RTCP_TEST_MASTER_SALT: Bytes = Bytes::from_static(&[
        0xa9, 0x65, 0x19, 0x85, 0x54, 0x0b, 0x47, 0xbe, 0x2f, 0x27, 0xa8, 0xb8, 0x81, 0x23
    ]);

    // Known-answer vectors. Layout shared by every entry:
    //   * `decrypted` is 56 bytes, `encrypted` is 70 bytes;
    //   * the first 8 bytes (header word + SSRC) are identical in both forms;
    //   * the 14 extra bytes at the end of `encrypted` are a 4-byte index word followed by
    //     10 more bytes (presumably the auth tag of the HMAC-SHA1-80 profile the tests use).
    // The index word has its top bit set, and its low 31 bits equal `index + 1`, wrapping
    // to 0 after 0x7fffffff.
    static ref RTCP_TEST_CASES: Vec<RTCPTestCase> = vec![
        // Baseline vector for SSRC 0x66ef91ff; index word 0x80000001.
        RTCPTestCase {
            ssrc:      0x66ef91ff,
            index:     0,
            encrypted: Bytes::from_static(&[
                0x80, 0xc8, 0x00, 0x06, 0x66, 0xef, 0x91, 0xff, 0xcd, 0x34, 0xc5, 0x78, 0xb2, 0x8b,
                0xe1, 0x6b, 0xc5, 0x09, 0xd5, 0x77, 0xe4, 0xce, 0x5f, 0x20, 0x80, 0x21, 0xbd, 0x66,
                0x74, 0x65, 0xe9, 0x5f, 0x49, 0xe5, 0xf5, 0xc0, 0x68, 0x4e, 0xe5, 0x6a, 0x78, 0x07,
                0x75, 0x46, 0xed, 0x90, 0xf6, 0xdc, 0x9d, 0xef, 0x3b, 0xdf, 0xf2, 0x79, 0xa9, 0xd8,
                0x80, 0x00, 0x00, 0x01, 0x60, 0xc0, 0xae, 0xb5, 0x6f, 0x40, 0x88, 0x0e, 0x28, 0xba
            ]),
            decrypted: Bytes::from_static(&[
                0x80, 0xc8, 0x00, 0x06, 0x66, 0xef, 0x91, 0xff, 0xdf, 0x48, 0x80, 0xdd, 0x61, 0xa6,
                0x2e, 0xd3, 0xd8, 0xbc, 0xde, 0xbe, 0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x16, 0x04,
                0x81, 0xca, 0x00, 0x06, 0x66, 0xef, 0x91, 0xff, 0x01, 0x10, 0x52, 0x6e, 0x54, 0x35,
                0x43, 0x6d, 0x4a, 0x68, 0x7a, 0x79, 0x65, 0x74, 0x41, 0x78, 0x77, 0x2b, 0x00, 0x00
            ]),
        },
        // Same plaintext body as the first vector but sent from SSRC 0x11111111, with the same
        // index word (0x80000001). The separation tests rely on this pair sharing an index
        // while differing in SSRC.
        RTCPTestCase{
            ssrc:      0x11111111,
            index:     0,
            encrypted: Bytes::from_static(&[
                0x80, 0xc8, 0x00, 0x06, 0x11, 0x11, 0x11, 0x11, 0x17, 0x8c, 0x15, 0xf1, 0x4b, 0x11,
                0xda, 0xf5, 0x74, 0x53, 0x86, 0x2b, 0xc9, 0x07, 0x29, 0x40, 0xbf, 0x22, 0xf6, 0x46,
                0x11, 0xa4, 0xc1, 0x3a, 0xff, 0x5a, 0xbd, 0xd0, 0xf8, 0x8b, 0x38, 0xe4, 0x95, 0x38,
                0x5d, 0xcf, 0x1b, 0xf5, 0x27, 0x77, 0xfb, 0xdb, 0x3f, 0x10, 0x68, 0x99, 0xd8, 0xad,
                0x80, 0x00, 0x00, 0x01, 0x34, 0x3c, 0x2e, 0x83, 0x17, 0x13, 0x93, 0x69, 0xcf, 0xc0
            ]),
            decrypted: Bytes::from_static(&[
                0x80, 0xc8, 0x00, 0x06, 0x11, 0x11, 0x11, 0x11, 0xdf, 0x48, 0x80, 0xdd, 0x61, 0xa6,
                0x2e, 0xd3, 0xd8, 0xbc, 0xde, 0xbe, 0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x16, 0x04,
                0x81, 0xca, 0x00, 0x06, 0x66, 0xef, 0x91, 0xff, 0x01, 0x10, 0x52, 0x6e, 0x54, 0x35,
                0x43, 0x6d, 0x4a, 0x68, 0x7a, 0x79, 0x65, 0x74, 0x41, 0x78, 0x77, 0x2b, 0x00, 0x00
            ]),
        },
        // Reuses the previous vector's first 56 encrypted bytes with index word 0xffffffff
        // (low 31 bits = 0x7fffffff); the expected plaintext differs accordingly.
        RTCPTestCase{
            ssrc:      0x11111111,
            index:     0x7ffffffe, // Upper boundary of index
            encrypted: Bytes::from_static(&[
                0x80, 0xc8, 0x00, 0x06, 0x11, 0x11, 0x11, 0x11, 0x17, 0x8c, 0x15, 0xf1, 0x4b, 0x11,
                0xda, 0xf5, 0x74, 0x53, 0x86, 0x2b, 0xc9, 0x07, 0x29, 0x40, 0xbf, 0x22, 0xf6, 0x46,
                0x11, 0xa4, 0xc1, 0x3a, 0xff, 0x5a, 0xbd, 0xd0, 0xf8, 0x8b, 0x38, 0xe4, 0x95, 0x38,
                0x5d, 0xcf, 0x1b, 0xf5, 0x27, 0x77, 0xfb, 0xdb, 0x3f, 0x10, 0x68, 0x99, 0xd8, 0xad,
                0xff, 0xff, 0xff, 0xff, 0x5a, 0x99, 0xce, 0xed, 0x9f, 0x2e, 0x4d, 0x9d, 0xfa, 0x97
            ]),
            decrypted: Bytes::from_static(&[
                0x80, 0xc8, 0x0, 0x6, 0x11, 0x11, 0x11, 0x11, 0x4, 0x99, 0x47, 0x53, 0xc4, 0x1e,
                0xb9, 0xde, 0x52, 0xa3, 0x1d, 0x77, 0x2f, 0xff, 0xcc, 0x75, 0xbb, 0x6a, 0x29, 0xb8,
                0x1, 0xb7, 0x2e, 0x4b, 0x4e, 0xcb, 0xa4, 0x81, 0x2d, 0x46, 0x4, 0x5e, 0x86, 0x90,
                0x17, 0x4f, 0x4d, 0x78, 0x2f, 0x58, 0xb8, 0x67, 0x91, 0x89, 0xe3, 0x61, 0x1, 0x7d
            ]),
        },
        // Same first 56 encrypted bytes again, with index word 0x80000000: the index that
        // follows 0x7fffffff wraps back to 0.
        RTCPTestCase{
            ssrc:      0x11111111,
            index:     0x7fffffff, // Will be wrapped to 0
            encrypted: Bytes::from_static(&[
                0x80, 0xc8, 0x00, 0x06, 0x11, 0x11, 0x11, 0x11, 0x17, 0x8c, 0x15, 0xf1, 0x4b, 0x11,
                0xda, 0xf5, 0x74, 0x53, 0x86, 0x2b, 0xc9, 0x07, 0x29, 0x40, 0xbf, 0x22, 0xf6, 0x46,
                0x11, 0xa4, 0xc1, 0x3a, 0xff, 0x5a, 0xbd, 0xd0, 0xf8, 0x8b, 0x38, 0xe4, 0x95, 0x38,
                0x5d, 0xcf, 0x1b, 0xf5, 0x27, 0x77, 0xfb, 0xdb, 0x3f, 0x10, 0x68, 0x99, 0xd8, 0xad,
                0x80, 0x00, 0x00, 0x00, 0x7d, 0x51, 0xf8, 0x0e, 0x56, 0x40, 0x72, 0x7b, 0x9e, 0x02
            ]),
            decrypted: Bytes::from_static(&[
                0x80, 0xc8, 0x0, 0x6, 0x11, 0x11, 0x11, 0x11, 0xda, 0xb5, 0xe0, 0x56, 0x9a, 0x4a,
                0x74, 0xed, 0x8a, 0x54, 0xc, 0xcf, 0xd5, 0x9, 0xb1, 0x40, 0x1, 0x42, 0xc3, 0x9a,
                0x76, 0x0, 0xa9, 0xd4, 0xf7, 0x29, 0x9e, 0x51, 0xfb, 0x3c, 0xc1, 0x74, 0x72, 0xf9,
                0x52, 0xb1, 0x92, 0x31, 0xca, 0x22, 0xab, 0x3e, 0xc5, 0x5f, 0x83, 0x34, 0xf0, 0x28
            ]),
        },
    ];
}

/// Runs every vector in `RTCP_TEST_CASES` through both directions: the protected packet must
/// decrypt to the expected plaintext, and the plaintext must encrypt back to the exact
/// protected packet once the sender index has been primed with the vector's `index`.
#[test]
fn test_rtcp_lifecycle() -> Result<()> {
    // Both directions use the same key material and no extra options.
    let new_context = || {
        Context::new(
            &RTCP_TEST_MASTER_KEY,
            &RTCP_TEST_MASTER_SALT,
            ProtectionProfile::Aes128CmHmacSha1_80,
            None,
            None,
        )
    };
    let mut encrypt_context = new_context()?;
    let mut decrypt_context = new_context()?;

    for RTCPTestCase { ssrc, index, encrypted, decrypted } in RTCP_TEST_CASES.iter() {
        // Unprotect the known ciphertext and compare with the known plaintext.
        let plaintext = decrypt_context.decrypt_rtcp(encrypted)?;
        assert_eq!(plaintext, *decrypted, "RTCP failed to decrypt");

        // Prime the sender's per-SSRC index so the produced trailer matches the vector.
        encrypt_context.set_index(*ssrc, *index);
        let ciphertext = encrypt_context.encrypt_rtcp(decrypted)?;
        assert_eq!(ciphertext, *encrypted, "RTCP failed to encrypt");
    }

    Ok(())
}

/// Checks that a protected RTCP packet is rejected once its trailing authentication tag has
/// been overwritten, after first confirming that the untouched packet decrypts correctly.
#[test]
fn test_rtcp_invalid_auth_tag() -> Result<()> {
    let auth_tag_len = ProtectionProfile::Aes128CmHmacSha1_80.auth_tag_len();

    let mut decrypt_context = Context::new(
        &RTCP_TEST_MASTER_KEY,
        &RTCP_TEST_MASTER_SALT,
        ProtectionProfile::Aes128CmHmacSha1_80,
        None,
        None,
    )?;

    // Sanity check: the pristine packet must decrypt, so a later failure can only be caused
    // by the tampering below.
    let decrypt_result = decrypt_context.decrypt_rtcp(&RTCP_TEST_CASES[0].encrypted)?;
    assert_eq!(
        decrypt_result, RTCP_TEST_CASES[0].decrypted,
        "RTCP failed to decrypt"
    );

    // Zero out auth tag (the last `auth_tag_len` bytes) in a private, mutable copy.
    let mut rtcp_packet = BytesMut::from(&RTCP_TEST_CASES[0].encrypted[..]);
    let tag_start = rtcp_packet.len() - auth_tag_len;
    rtcp_packet[tag_start..].fill(0);
    let rtcp_packet = rtcp_packet.freeze();

    let decrypt_result = decrypt_context.decrypt_rtcp(&rtcp_packet);
    assert!(
        decrypt_result.is_err(),
        "Was able to decrypt RTCP packet with invalid Auth Tag"
    );

    Ok(())
}

/// Checks that replay detection is tracked separately per SSRC.
///
/// The first two vectors carry the same index word but different SSRCs, so both must be
/// accepted once; feeding either of them a second time must then be rejected.
#[test]
fn test_rtcp_replay_detector_separation() -> Result<()> {
    let mut decrypt_context = Context::new(
        &RTCP_TEST_MASTER_KEY,
        &RTCP_TEST_MASTER_SALT,
        ProtectionProfile::Aes128CmHmacSha1_80,
        None,
        Some(srtcp_replay_protection(10)),
    )?;

    let first_two = &RTCP_TEST_CASES[..2];

    // First pass: each packet is new for its SSRC and must decrypt to its plaintext.
    for case in first_two {
        let plaintext = decrypt_context.decrypt_rtcp(&case.encrypted)?;
        assert_eq!(plaintext, case.decrypted, "RTCP failed to decrypt");
    }

    // Second pass: the very same packets are now duplicates and must be refused.
    for case in first_two {
        let replayed = decrypt_context.decrypt_rtcp(&case.encrypted);
        assert!(
            replayed.is_err(),
            "Was able to decrypt duplicated RTCP packet"
        );
    }

    Ok(())
}

/// Extracts the SRTCP index from the trailer of a protected RTCP packet.
///
/// The packet is expected to end with an index word of `SRTCP_INDEX_SIZE` bytes followed by
/// `auth_tag_len` bytes of authentication tag. The index word is read as a big-endian `u32`
/// and its most significant bit (the encryption flag of the SRTCP trailer, see RFC 3711) is
/// masked off, leaving the 31-bit index.
///
/// # Panics
///
/// Panics if `encrypted` is shorter than `auth_tag_len + SRTCP_INDEX_SIZE` bytes, or if the
/// index word holds fewer than four bytes.
fn get_rtcp_index(encrypted: &Bytes, auth_tag_len: usize) -> u32 {
    let tail_offset = encrypted
        .len()
        .checked_sub(auth_tag_len + SRTCP_INDEX_SIZE)
        .expect("packet is too short to contain an SRTCP index and auth tag");

    // Borrow the index word in place; `Buf` is implemented for `&[u8]`, so no copy is needed.
    let mut index_word = &encrypted[tail_offset..tail_offset + SRTCP_INDEX_SIZE];
    index_word.get_u32() & 0x7FFF_FFFF
}

/// Checks that the sender keeps an independent SRTCP index per SSRC.
///
/// Plaintexts from two different SSRCs are encrypted alternately (A, B, A, B). The indices
/// found in the resulting packets must be 1, 1, 2, 2, and every packet must decrypt back to
/// the plaintext it was produced from.
#[test]
fn test_encrypt_rtcp_separation() -> Result<()> {
    let auth_tag_len = ProtectionProfile::Aes128CmHmacSha1_80.auth_tag_len();

    let mut encrypt_context = Context::new(
        &RTCP_TEST_MASTER_KEY,
        &RTCP_TEST_MASTER_SALT,
        ProtectionProfile::Aes128CmHmacSha1_80,
        None,
        None,
    )?;

    let mut decrypt_context = Context::new(
        &RTCP_TEST_MASTER_KEY,
        &RTCP_TEST_MASTER_SALT,
        ProtectionProfile::Aes128CmHmacSha1_80,
        None,
        Some(srtcp_replay_protection(10)),
    )?;

    // Alternate between the two SSRCs of the first two vectors.
    let inputs = vec![
        RTCP_TEST_CASES[0].decrypted.clone(),
        RTCP_TEST_CASES[1].decrypted.clone(),
        RTCP_TEST_CASES[0].decrypted.clone(),
        RTCP_TEST_CASES[1].decrypted.clone(),
    ];

    // Encrypt everything up front, in order, so the index checks see the final packets.
    let mut encrypted_rtcps = vec![];
    for plaintext in inputs.iter() {
        encrypted_rtcps.push(encrypt_context.encrypt_rtcp(plaintext)?);
    }

    // Each SSRC counts on its own: 1 and 2 for the first, 1 and 2 for the second.
    for (packet, expected_index) in encrypted_rtcps.iter().zip([1, 1, 2, 2].iter()) {
        assert_eq!(
            *expected_index,
            get_rtcp_index(packet, auth_tag_len),
            "RTCP index does not match"
        );
    }

    // Round trip: every produced packet must decrypt to the plaintext it came from.
    for (plaintext, packet) in inputs.iter().zip(encrypted_rtcps.iter()) {
        let decrypted = decrypt_context.decrypt_rtcp(packet)?;
        assert_eq!(*plaintext, decrypted);
    }

    Ok(())
}
