// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.

//! Contains the messages used for OPAQUE

use crate::{
    ciphersuite::CipherSuite,
    envelope::Envelope,
    errors::{
        utils::{check_slice_size, check_slice_size_atleast},
        PakeError, ProtocolError,
    },
    group::Group,
    key_exchange::traits::{KeyExchange, ToBytes},
    keypair::{Key, KeyPair, SizedBytesExt},
};
use generic_array::{typenum::Unsigned, GenericArray};
use generic_bytes::SizedBytes;
use std::convert::TryFrom;

// Messages
// =========

/// The message sent by the client to the server, to initiate registration.
///
/// Its byte form (see `serialize` / `deserialize`) is the encoding of
/// `alpha` alone.
pub struct RegistrationRequest<CS: CipherSuite> {
    /// Blinded password information, held as an element of `CS::Group`
    pub(crate) alpha: CS::Group,
}

impl<CS: CipherSuite> RegistrationRequest<CS> {
    /// Test-only accessor returning a copy of the blinded element `alpha`.
    #[cfg(test)]
    pub fn get_alpha_for_testing(&self) -> CS::Group {
        let Self { alpha } = self;
        *alpha
    }
}

// Implemented by hand: `#[derive(Clone)]` would additionally require
// `CS: Clone`, which is not needed to copy the single group element.
impl<CS: CipherSuite> Clone for RegistrationRequest<CS> {
    fn clone(&self) -> Self {
        let alpha = self.alpha;
        Self { alpha }
    }
}

impl<CS: CipherSuite> RegistrationRequest<CS> {
    /// Serialization into bytes.
    ///
    /// The output is the byte array returned by `alpha.to_arr()`.
    pub fn serialize(&self) -> Vec<u8> {
        self.alpha.to_arr().to_vec()
    }

    /// Deserialization from bytes.
    ///
    /// # Errors
    ///
    /// Fails if `input` does not pass the length check against the group's
    /// element length, if `Group::from_element_slice` rejects the bytes, or
    /// with `PakeError::IdentityGroupElementError` if the decoded element is
    /// the identity.
    pub fn deserialize(input: &[u8]) -> Result<Self, ProtocolError> {
        let elem_len = <CS::Group as Group>::ElemLen::to_usize();
        // `input` is already a `&[u8]`; pass it through without re-borrowing.
        let checked_slice = check_slice_size(input, elem_len, "first_message_bytes")?;

        // Check that the message actually contains an element of the
        // correct subgroup
        let alpha = CS::Group::from_element_slice(GenericArray::from_slice(checked_slice))?;

        // Reject the identity group element
        if alpha.is_identity() {
            return Err(PakeError::IdentityGroupElementError.into());
        }

        Ok(Self { alpha })
    }
}

/// The answer sent by the server to the user, upon reception of the
/// registration attempt.
///
/// Its byte form (see `serialize`) is the encoding of `beta` followed by
/// the bytes of `server_s_pk`.
pub struct RegistrationResponse<CS: CipherSuite> {
    /// The server's OPRF output
    pub(crate) beta: CS::Group,
    /// Server's static public key, stored as raw bytes
    pub(crate) server_s_pk: Vec<u8>,
}

impl<CS: CipherSuite> RegistrationResponse<CS> {
    /// Serialization into bytes.
    ///
    /// Layout: the encoding of `beta`, immediately followed by the bytes of
    /// the server's static public key.
    pub fn serialize(&self) -> Vec<u8> {
        // Append the key in place rather than cloning it into a temporary
        // `Vec` that is only used for concatenation.
        let mut bytes = self.beta.to_arr().to_vec();
        bytes.extend_from_slice(&self.server_s_pk);
        bytes
    }

    /// Deserialization from bytes.
    ///
    /// # Errors
    ///
    /// Fails if `input` does not pass the length check against the sum of
    /// the group element length and the key length, if the leading bytes
    /// are rejected by `Group::from_element_slice`, with
    /// `PakeError::IdentityGroupElementError` if `beta` decodes to the
    /// identity, or if the trailing bytes are not accepted as a public key.
    pub fn deserialize(input: &[u8]) -> Result<Self, ProtocolError> {
        let elem_len = <CS::Group as Group>::ElemLen::to_usize();
        let key_len = <Key as SizedBytes>::Len::to_usize();
        // `input` is already a `&[u8]`; pass it through without re-borrowing.
        let checked_slice =
            check_slice_size(input, elem_len + key_len, "registration_response_bytes")?;
        let (beta_bytes, server_s_pk_bytes) = checked_slice.split_at(elem_len);

        // Check that the message actually contains an element of the
        // correct subgroup
        let beta = CS::Group::from_element_slice(GenericArray::from_slice(beta_bytes))?;

        // Reject the identity group element
        if beta.is_identity() {
            return Err(PakeError::IdentityGroupElementError.into());
        }

        // Ensure that the public key is valid
        let server_s_pk =
            KeyPair::<CS::Group>::check_public_key(Key::from_bytes(server_s_pk_bytes)?)?;

        Ok(Self {
            beta,
            server_s_pk: server_s_pk.to_arr().to_vec(),
        })
    }

    /// Only used for tests, where we can set the beta value to test for the reflection
    /// error case. Returns a new response carrying `new_beta` and a copy of
    /// this response's server public key.
    #[cfg(test)]
    pub fn set_beta_for_testing(&self, new_beta: CS::Group) -> Self {
        Self {
            beta: new_beta,
            server_s_pk: self.server_s_pk.clone(),
        }
    }
}

/// The final message from the client, containing sealed cryptographic
/// identifiers.
///
/// Its byte form (see `serialize`) is the client's public key followed by
/// the serialized envelope.
pub struct RegistrationUpload<CS: CipherSuite> {
    /// The "envelope" generated by the user, containing sealed
    /// cryptographic identifiers
    pub(crate) envelope: Envelope<CS::Hash>,
    /// The user's public key
    pub(crate) client_s_pk: Key,
}

impl<CS: CipherSuite> RegistrationUpload<CS> {
    /// Serialization into bytes.
    ///
    /// Layout: the client's public key bytes, immediately followed by the
    /// serialized envelope.
    pub fn serialize(&self) -> Vec<u8> {
        // Grow a single buffer instead of concatenating an array of
        // temporary `Vec`s.
        let mut bytes = self.client_s_pk.to_arr().to_vec();
        bytes.extend_from_slice(&self.envelope.serialize());
        bytes
    }

    /// Deserialization from bytes.
    ///
    /// # Errors
    ///
    /// Fails if `input` does not pass the minimum-length check against the
    /// key length, if the envelope cannot be deserialized, with
    /// `PakeError::SerializationError` if bytes remain after the envelope,
    /// or if the leading bytes are not accepted as a public key.
    pub fn deserialize(input: &[u8]) -> Result<Self, ProtocolError> {
        let key_len = <Key as SizedBytes>::Len::to_usize();

        // `input` is already a `&[u8]`; pass it through without re-borrowing.
        let checked_slice = check_slice_size_atleast(input, key_len, "registration_upload_bytes")?;
        let (key_bytes, envelope_bytes) = checked_slice.split_at(key_len);

        // The envelope is parsed before the key is validated, so envelope
        // errors take precedence over key errors.
        let (envelope, remainder) = Envelope::<CS::Hash>::deserialize(envelope_bytes)?;

        // The envelope must consume everything that follows the key.
        if !remainder.is_empty() {
            return Err(PakeError::SerializationError.into());
        }

        let client_s_pk = KeyPair::<CS::Group>::check_public_key(Key::from_bytes(key_bytes)?)?;

        Ok(Self {
            envelope,
            client_s_pk,
        })
    }
}

/// The message sent by the user to the server, to initiate login.
///
/// Its byte form (see `serialize`) is the encoding of `alpha` followed by
/// the bytes of the first key-exchange message.
pub struct CredentialRequest<CS: CipherSuite> {
    /// blinded password information
    pub(crate) alpha: CS::Group,
    /// first message of the key exchange, as defined by `CS::KeyExchange`
    pub(crate) ke1_message: <CS::KeyExchange as KeyExchange<CS::Hash, CS::Group>>::KE1Message,
}

impl<CS: CipherSuite> CredentialRequest<CS> {
    /// Serialization into bytes.
    ///
    /// Layout: the encoding of `alpha`, immediately followed by the bytes of
    /// the first key-exchange message.
    ///
    /// # Errors
    ///
    /// Propagates any error returned while converting `ke1_message` to bytes.
    pub fn serialize(&self) -> Result<Vec<u8>, ProtocolError> {
        let mut bytes = self.alpha.to_arr().to_vec();
        bytes.extend_from_slice(&self.ke1_message.to_bytes()?);
        Ok(bytes)
    }

    /// Deserialization from bytes.
    ///
    /// # Errors
    ///
    /// Fails if `input` does not pass the minimum-length check against the
    /// group's element length, if the leading bytes are rejected by
    /// `Group::from_element_slice`, with
    /// `PakeError::IdentityGroupElementError` if `alpha` decodes to the
    /// identity, or if the remaining bytes cannot be converted into a
    /// first key-exchange message.
    pub fn deserialize(input: &[u8]) -> Result<Self, ProtocolError> {
        let elem_len = <CS::Group as Group>::ElemLen::to_usize();

        // `input` is already a `&[u8]`; pass it through without re-borrowing.
        let checked_slice = check_slice_size_atleast(input, elem_len, "login_first_message_bytes")?;
        let (alpha_bytes, ke1_bytes) = checked_slice.split_at(elem_len);

        // Check that the message actually contains an element of the
        // correct subgroup
        let alpha = CS::Group::from_element_slice(GenericArray::from_slice(alpha_bytes))?;

        // Reject the identity group element
        if alpha.is_identity() {
            return Err(PakeError::IdentityGroupElementError.into());
        }

        // Everything after the group element belongs to the key exchange.
        let ke1_message =
            <CS::KeyExchange as KeyExchange<CS::Hash, CS::Group>>::KE1Message::try_from(
                ke1_bytes,
            )?;

        Ok(Self { alpha, ke1_message })
    }

    /// Only used for testing purposes
    #[cfg(test)]
    pub fn get_alpha_for_testing(&self) -> CS::Group {
        self.alpha
    }
}

/// The answer sent by the server to the user, upon reception of the
/// login attempt.
///
/// Its byte form (see `serialize`) is `beta`, the server's public key, the
/// envelope and the second key-exchange message, in that order.
pub struct CredentialResponse<CS: CipherSuite> {
    /// the server's oprf output
    pub(crate) beta: CS::Group,
    /// the server's static public key
    pub(crate) server_s_pk: Key,
    /// the user's sealed information,
    pub(crate) envelope: Envelope<CS::Hash>,
    /// second message of the key exchange, as defined by `CS::KeyExchange`
    pub(crate) ke2_message: <CS::KeyExchange as KeyExchange<CS::Hash, CS::Group>>::KE2Message,
}

impl<CS: CipherSuite> CredentialResponse<CS> {
    /// Serialization into bytes: the output of `serialize_without_ke`
    /// followed by the bytes of the second key-exchange message.
    pub fn serialize(&self) -> Result<Vec<u8>, ProtocolError> {
        let mut bytes = Self::serialize_without_ke(&self.beta, &self.server_s_pk, &self.envelope);
        bytes.extend_from_slice(&self.ke2_message.to_bytes()?);
        Ok(bytes)
    }

    /// Serializes every part of the response except the key-exchange
    /// message: `beta`, then the server's public key, then the envelope.
    pub(crate) fn serialize_without_ke(
        beta: &CS::Group,
        server_s_pk: &Key,
        envelope: &Envelope<CS::Hash>,
    ) -> Vec<u8> {
        let mut bytes = beta.to_arr().to_vec();
        bytes.extend_from_slice(&server_s_pk.to_arr()[..]);
        bytes.extend_from_slice(&envelope.to_bytes());
        bytes
    }

    /// Deserialization from bytes.
    ///
    /// Fields are parsed in wire order: `beta`, server public key,
    /// envelope, then the second key-exchange message.
    pub fn deserialize(input: &[u8]) -> Result<Self, ProtocolError> {
        let elem_len = <CS::Group as Group>::ElemLen::to_usize();
        let key_len = <Key as SizedBytes>::Len::to_usize();
        let checked_slice =
            check_slice_size_atleast(input, elem_len + key_len, "login_second_message_bytes")?;

        // Peel the fixed-size prefix off the front of the message.
        let (beta_bytes, after_beta) = checked_slice.split_at(elem_len);

        // Check that the message actually contains an element of the
        // correct subgroup
        let beta = CS::Group::from_element_slice(GenericArray::from_slice(beta_bytes))?;

        // Reject the identity group element
        if beta.is_identity() {
            return Err(PakeError::IdentityGroupElementError.into());
        }

        // The public key is validated before it is stored.
        let (server_s_pk_bytes, envelope_bytes) = after_beta.split_at(key_len);
        let server_s_pk =
            KeyPair::<CS::Group>::check_public_key(Key::from_bytes(server_s_pk_bytes)?)?;

        // Whatever the envelope does not consume is handed to the key
        // exchange.
        let (envelope, remainder) = Envelope::<CS::Hash>::deserialize(envelope_bytes)?;

        let ke2_message_size = CS::KeyExchange::ke2_message_size();
        let checked_remainder =
            check_slice_size_atleast(&remainder, ke2_message_size, "login_second_message_bytes")?;
        let ke2_message =
            <CS::KeyExchange as KeyExchange<CS::Hash, CS::Group>>::KE2Message::try_from(
                &checked_remainder,
            )?;

        Ok(Self {
            beta,
            server_s_pk,
            envelope,
            ke2_message,
        })
    }

    /// Only used for tests, where we can set the beta value to test for the reflection
    /// error case. All other fields are cloned from `self`.
    #[cfg(test)]
    pub fn set_beta_for_testing(&self, new_beta: CS::Group) -> Self {
        let server_s_pk = self.server_s_pk.clone();
        let envelope = self.envelope.clone();
        let ke2_message = self.ke2_message.clone();
        Self {
            beta: new_beta,
            server_s_pk,
            envelope,
            ke2_message,
        }
    }
}

/// The answer sent by the client to the server, upon reception of the
/// sealed envelope.
///
/// Its byte form is exactly the bytes of the third key-exchange message.
pub struct CredentialFinalization<CS: CipherSuite> {
    /// third message of the key exchange, as defined by `CS::KeyExchange`
    pub(crate) ke3_message: <CS::KeyExchange as KeyExchange<CS::Hash, CS::Group>>::KE3Message,
}

impl<CS: CipherSuite> CredentialFinalization<CS> {
    /// Serialization into bytes: delegates to the key-exchange message's
    /// own byte conversion.
    pub fn serialize(&self) -> Result<Vec<u8>, ProtocolError> {
        let Self { ke3_message } = self;
        ke3_message.to_bytes()
    }

    /// Deserialization from bytes: the whole of `input` is converted into
    /// the third key-exchange message.
    pub fn deserialize(input: &[u8]) -> Result<Self, ProtocolError> {
        Ok(Self {
            ke3_message:
                <CS::KeyExchange as KeyExchange<CS::Hash, CS::Group>>::KE3Message::try_from(
                    input,
                )?,
        })
    }
}
