//! [Curve25519](https://en.wikipedia.org/wiki/Curve25519) scalar multiplication.
//!
//! This module corresponds to the [`crypto_scalarmult`
//! API](https://doc.libsodium.org/advanced/scalar_multiplication) from Sodium.
//!
//! Note that this API only supports scalar multiplication of individual points (i.e. adding a
//! point to itself `n` times), not the full group operation of adding two arbitrary points.
//!
//! Curve25519 is the Montgomery curve `y^2 = x^3 + 486662x^2 + x` defined over the prime field of
//! order `2^255 - 19`, with generator given by the point on the curve with `x = 9`. In the byte
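//! representation of Curve25519 points, only the `x` coordinate is stored. Any 32-byte string is
//! accepted as a point by this API, although strictly speaking some `x` values correspond to points
//! on the curve's quadratic twist rather than on the curve itself.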
//!
//! Curve25519 is not of prime order: the number of points on the curve is `8 * l`, where `l` is a
//! large prime. Many cryptographic operations require a prime-order group, so points are normally
//! multiplied by the cofactor 8 to map them into the prime-order subgroup. [This blog
//! post](https://neilmadden.blog/2020/05/28/whats-the-curve25519-clamping-all-about/) explains the
//! clamping procedure, which achieves this by making every scalar a multiple of 8, and which is
//! applied as part of the scalar multiplication operation here.

use super::CurveError;
use crate::{assert_not_err, mem, random, require_init, AlkaliError};
use libsodium_sys as sodium;

/// The length of the byte representation of a point on the curve, in bytes.
///
/// Taken directly from Sodium's `crypto_scalarmult_curve25519_BYTES`, so it always matches the
/// Curve25519 primitive used by the functions in this module.
pub const POINT_LENGTH: usize = sodium::crypto_scalarmult_curve25519_BYTES as usize;

/// The length of a secret scalar by which a point on the curve can be multiplied, in bytes.
///
/// Taken directly from Sodium's `crypto_scalarmult_curve25519_SCALARBYTES`; this is the size of
/// the [`Scalar`] buffer.
pub const SCALAR_LENGTH: usize = sodium::crypto_scalarmult_curve25519_SCALARBYTES as usize;

// `Scalar` is declared through the crate's `hardened_buffer!` macro, which produces a fixed-size
// (`SCALAR_LENGTH`-byte) buffer type with the security properties described in the docs below.
mem::hardened_buffer! {
    /// A secret scalar value by which a point on the curve can be scalar-multiplied.
    ///
    /// A scalar such as this generally takes the role of a secret key in elliptic-curve
    /// cryptography. Given `Q = nP`, where `Q, P` are public points on the curve, and `n` is an
    /// unknown scalar, it is computationally infeasible to calculate `n` (this is the elliptic
    /// curve discrete logarithm problem, ECDLP).
    ///
    /// There are no technical constraints on the contents of a scalar for this API. Scalars are
    /// [clamped](https://neilmadden.blog/2020/05/28/whats-the-curve25519-clamping-all-about/) at
    /// the time of usage, so a scalar can just be any sequence of `SCALAR_LENGTH` bytes. However,
    /// if a scalar is intended to be secret, it should be generated randomly using
    /// [`Scalar::generate`].
    ///
    /// This is a [hardened buffer type](https://docs.rs/alkali#hardened-buffer-types), and will be
    /// zeroed on drop. A number of other security measures are taken to protect its contents. This
    /// type in particular can be thought of as roughly equivalent to a `[u8; SCALAR_LENGTH]`, and
    /// implements [`std::ops::Deref`], so it can be used like it is an `&[u8]`. This struct uses
    /// heap memory while in scope, allocated using Sodium's [secure memory
    /// utilities](https://doc.libsodium.org/memory_management).
    pub Scalar(SCALAR_LENGTH);
}

impl Scalar {
    /// Create a new [`Scalar`] filled with random bytes, suitable for use as a Curve25519 secret
    /// key.
    ///
    /// The bytes come from the crate's `random::fill_random`. Any error raised while creating the
    /// empty hardened buffer, or while filling it, is returned to the caller.
    pub fn generate() -> Result<Self, AlkaliError> {
        // Allocate the (zeroed) hardened buffer first, then overwrite every byte with randomness.
        let mut scalar = Self::new_empty()?;

        random::fill_random(scalar.as_mut())?;

        Ok(scalar)
    }
}

/// A point on Curve25519.
///
/// For Curve25519, only the `x` coordinate is stored, as [`POINT_LENGTH`] bytes.
///
/// Points are public values, so `Point` implements equality and hashing. Note that the derived
/// `==` comparison is not guaranteed to run in constant time.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct Point(pub [u8; POINT_LENGTH]);

impl Point {
    /// Scalar multiply this point by the scalar `n`.
    ///
    /// Calculates `Q = nP`, where `P` is this point, `n` is the [`Scalar`] by which `P` should be
    /// multiplied, and `Q` is the return value. That is, `P` is added to itself `n` times.
    ///
    /// Finding `n` given `Q` and `P` is the elliptic curve discrete logarithm problem, so it is
    /// computationally infeasible to find `n`. This can be used to compute a shared secret `Q`, if
    /// `P` is a user's public key and `n` is another user's secret key.
    ///
    /// Note that this function does *not* check that `P` lies in the prime-order subgroup of the
    /// curve. Sodium only rejects inputs for which the product would be the identity (all-zero)
    /// point, i.e. points of low order. Because the clamped scalar is a multiple of the cofactor 8,
    /// any low-order component of `P` is cleared by the multiplication.
    ///
    /// `n` will be
    /// [clamped](https://neilmadden.blog/2020/05/28/whats-the-curve25519-clamping-all-about/)
    /// before multiplying.
    ///
    /// Returns the result of the scalar multiplication (a new point on the curve).
    ///
    /// # Errors
    ///
    /// Returns [`CurveError::ScalarMultUnacceptable`] if `P` is of low order, or any error
    /// returned by `require_init`.
    pub fn scalar_mult(&self, n: &Scalar) -> Result<Self, AlkaliError> {
        require_init()?;

        let mut q = [0u8; POINT_LENGTH];

        let scalarmult_result = unsafe {
            // SAFETY: The first argument to this function is the destination to which the scalar
            // product should be written, a point on Curve25519 in compressed format. We define `q`
            // to be `crypto_scalarmult_curve25519_BYTES`, the length of a compressed Curve25519
            // point representation, so `q` is valid for writes of the required length. The next
            // argument is the scalar by which `p` should be multiplied. The `Scalar` type is
            // defined to allocate `crypto_scalarmult_curve25519_SCALARBYTES` bytes, the length of a
            // scalar for this algorithm, so `n` is valid for reads of the required length. The
            // final argument is the compressed representation of the point on Curve25519 which
            // should be multiplied by the scalar. The `Point` type stores
            // `crypto_scalarmult_curve25519_BYTES` bytes, the length of a compressed Curve25519
            // point representation, so `self.0` is valid for reads of the required length. The
            // `Scalar::inner` method simply returns a pointer to the backing memory of the struct.
            //
            // We call the Curve25519-specific function rather than the generic `crypto_scalarmult`
            // so that the primitive is guaranteed to match the `crypto_scalarmult_curve25519_*`
            // lengths relied on above, even if Sodium's default primitive were ever to change.
            sodium::crypto_scalarmult_curve25519(
                q.as_mut_ptr(),
                n.inner() as *const libc::c_uchar,
                self.0.as_ptr(),
            )
        };

        if scalarmult_result == 0 {
            Ok(Self(q))
        } else {
            Err(CurveError::ScalarMultUnacceptable.into())
        }
    }

    /// Scalar multiply this point by the scalar `n`, in place.
    ///
    /// This function is equivalent to [`Self::scalar_mult`], but modifies `self` in place, rather
    /// than returning the new point.
    ///
    /// # Errors
    ///
    /// Returns the same errors as [`Self::scalar_mult`]. On error, `self` is left unmodified.
    pub fn scalar_mult_in_place(&mut self, n: &Scalar) -> Result<(), AlkaliError> {
        // `?` returns before the assignment on failure, so `self` only changes on success.
        *self = self.scalar_mult(n)?;
        Ok(())
    }
}

/// Multiply the Curve25519 generator by the scalar `n`.
///
/// Computes `Q = nG`, where `G` is the standard generator of the curve (the point with `x = 9`),
/// `n` is the [`Scalar`] multiplier, and `Q` is the returned [`Point`]. In other words, `G` is
/// added to itself `n` times.
///
/// Recovering `n` from `Q` and `G` is an instance of the elliptic curve discrete logarithm problem,
/// so this is computationally infeasible. This is how the public key matching the secret key `n`
/// is derived.
///
/// `n` is
/// [clamped](https://neilmadden.blog/2020/05/28/whats-the-curve25519-clamping-all-about/)
/// before the multiplication takes place.
///
/// Returns the product `nG` as a new point on the curve.
pub fn scalar_mult_base(n: &Scalar) -> Result<Point, AlkaliError> {
    require_init()?;

    // Destination for the compressed (x-only) product.
    let mut product = [0u8; POINT_LENGTH];

    let status = unsafe {
        // SAFETY: `product` is a `crypto_scalarmult_curve25519_BYTES`-byte array, exactly the size
        // of a compressed Curve25519 point, so it may be written to in full by Sodium. `Scalar`
        // allocates `crypto_scalarmult_curve25519_SCALARBYTES` bytes, the scalar length for this
        // primitive, and `Scalar::inner` just exposes a pointer to that backing memory, so `n` is
        // valid for reads of the required length.
        sodium::crypto_scalarmult_curve25519_base(
            product.as_mut_ptr(),
            n.inner() as *const libc::c_uchar,
        )
    };
    assert_not_err!(status, "crypto_scalarmult_curve25519_base");

    Ok(Point(product))
}

#[cfg(test)]
mod tests {
    // Tests for this module: scalar generation, known-answer vectors for multiplication by the
    // generator and by arbitrary points, and rejection of a low-order input point.
    use super::{scalar_mult_base, CurveError, Point, Scalar};
    use crate::AlkaliError;

    // Smoke test: generating a random scalar must succeed.
    #[test]
    fn scalar_generation() -> Result<(), AlkaliError> {
        let _ = Scalar::generate()?;
        Ok(())
    }

    // Known-answer test: multiplying the generator by `scalars[i]` must yield `points[i]`.
    // NOTE(review): these appear to be the two X25519 key pairs from RFC 7748, section 6.1 -
    // confirm against the RFC if the vectors are ever edited.
    #[test]
    fn scalarmult_base_vectors() -> Result<(), AlkaliError> {
        let scalars = [
            [
                0x77, 0x07, 0x6d, 0x0a, 0x73, 0x18, 0xa5, 0x7d, 0x3c, 0x16, 0xc1, 0x72, 0x51, 0xb2,
                0x66, 0x45, 0xdf, 0x4c, 0x2f, 0x87, 0xeb, 0xc0, 0x99, 0x2a, 0xb1, 0x77, 0xfb, 0xa5,
                0x1d, 0xb9, 0x2c, 0x2a,
            ],
            [
                0x5d, 0xab, 0x08, 0x7e, 0x62, 0x4a, 0x8a, 0x4b, 0x79, 0xe1, 0x7f, 0x8b, 0x83, 0x80,
                0x0e, 0xe6, 0x6f, 0x3b, 0xb1, 0x29, 0x26, 0x18, 0xb6, 0xfd, 0x1c, 0x2f, 0x8b, 0x27,
                0xff, 0x88, 0xe0, 0xeb,
            ],
        ];
        let points = [
            [
                0x85, 0x20, 0xf0, 0x09, 0x89, 0x30, 0xa7, 0x54, 0x74, 0x8b, 0x7d, 0xdc, 0xb4, 0x3e,
                0xf7, 0x5a, 0x0d, 0xbf, 0x3a, 0x0d, 0x26, 0x38, 0x1a, 0xf4, 0xeb, 0xa4, 0xa9, 0x8e,
                0xaa, 0x9b, 0x4e, 0x6a,
            ],
            [
                0xde, 0x9e, 0xdb, 0x7d, 0x7b, 0x7d, 0xc1, 0xb4, 0xd3, 0x5b, 0x61, 0xc2, 0xec, 0xe4,
                0x35, 0x37, 0x3f, 0x83, 0x43, 0xc8, 0x5b, 0x78, 0x67, 0x4d, 0xad, 0xfc, 0x7e, 0x14,
                0x6f, 0x88, 0x2b, 0x4f,
            ],
        ];

        // One hardened buffer is reused and overwritten with each test scalar in turn.
        let mut s = Scalar::new_empty()?;

        for (scalar, &point) in scalars.iter().zip(points.iter()) {
            s.copy_from_slice(scalar);
            let actual = scalar_mult_base(&s)?;
            assert_eq!(actual.0, point);
        }

        Ok(())
    }

    // Known-answer test for multiplying arbitrary points: `ns[i] * ps[i]` must equal `qs[i]`.
    // The first two cases pair each scalar from `scalarmult_base_vectors` with the *other*
    // scalar's public point, so `qs[0] == qs[1]` checks that both sides of a Diffie-Hellman
    // exchange derive the same shared value. Cases 2 and 3 reuse the same input point with
    // different scalars.
    #[test]
    fn scalarmult_vectors() -> Result<(), AlkaliError> {
        let ps = [
            [
                0xde, 0x9e, 0xdb, 0x7d, 0x7b, 0x7d, 0xc1, 0xb4, 0xd3, 0x5b, 0x61, 0xc2, 0xec, 0xe4,
                0x35, 0x37, 0x3f, 0x83, 0x43, 0xc8, 0x5b, 0x78, 0x67, 0x4d, 0xad, 0xfc, 0x7e, 0x14,
                0x6f, 0x88, 0x2b, 0x4f,
            ],
            [
                0x85, 0x20, 0xf0, 0x09, 0x89, 0x30, 0xa7, 0x54, 0x74, 0x8b, 0x7d, 0xdc, 0xb4, 0x3e,
                0xf7, 0x5a, 0x0d, 0xbf, 0x3a, 0x0d, 0x26, 0x38, 0x1a, 0xf4, 0xeb, 0xa4, 0xa9, 0x8e,
                0xaa, 0x9b, 0x4e, 0x6a,
            ],
            [
                0x9c, 0x64, 0x7d, 0x9a, 0xe5, 0x89, 0xb9, 0xf5, 0x8f, 0xdc, 0x3c, 0xa4, 0x94, 0x7e,
                0xfb, 0xc9, 0x15, 0xc4, 0xb2, 0xe0, 0x8e, 0x74, 0x4a, 0x0e, 0xdf, 0x46, 0x9d, 0xac,
                0x59, 0xc8, 0xf8, 0x5a,
            ],
            [
                0x9c, 0x64, 0x7d, 0x9a, 0xe5, 0x89, 0xb9, 0xf5, 0x8f, 0xdc, 0x3c, 0xa4, 0x94, 0x7e,
                0xfb, 0xc9, 0x15, 0xc4, 0xb2, 0xe0, 0x8e, 0x74, 0x4a, 0x0e, 0xdf, 0x46, 0x9d, 0xac,
                0x59, 0xc8, 0xf8, 0x5a,
            ],
            [
                0x63, 0xaa, 0x40, 0xc6, 0xe3, 0x83, 0x46, 0xc5, 0xca, 0xf2, 0x3a, 0x6d, 0xf0, 0xa5,
                0xe6, 0xc8, 0x08, 0x89, 0xa0, 0x86, 0x47, 0xe5, 0x51, 0xb3, 0x56, 0x34, 0x49, 0xbe,
                0xfc, 0xfc, 0x97, 0x33,
            ],
            [
                0x0f, 0x83, 0xc3, 0x6f, 0xde, 0xd9, 0xd3, 0x2f, 0xad, 0xf4, 0xef, 0xa3, 0xae, 0x93,
                0xa9, 0x0b, 0xb5, 0xcf, 0xa6, 0x68, 0x93, 0xbc, 0x41, 0x2c, 0x43, 0xfa, 0x72, 0x87,
                0xdb, 0xb9, 0x97, 0x79,
            ],
            [
                0x0b, 0x82, 0x11, 0xa2, 0xb6, 0x04, 0x90, 0x97, 0xf6, 0x87, 0x1c, 0x6c, 0x05, 0x2d,
                0x3c, 0x5f, 0xc1, 0xba, 0x17, 0xda, 0x9e, 0x32, 0xae, 0x45, 0x84, 0x03, 0xb0, 0x5b,
                0xb2, 0x83, 0x09, 0x2a,
            ],
            [
                0x34, 0x3a, 0xc2, 0x0a, 0x3b, 0x9c, 0x6a, 0x27, 0xb1, 0x00, 0x81, 0x76, 0x50, 0x9a,
                0xd3, 0x07, 0x35, 0x85, 0x6e, 0xc1, 0xc8, 0xd8, 0xfc, 0xae, 0x13, 0x91, 0x2d, 0x08,
                0xd1, 0x52, 0xf4, 0x6c,
            ],
            [
                0xfa, 0x69, 0x5f, 0xc7, 0xbe, 0x8d, 0x1b, 0xe5, 0xbf, 0x70, 0x48, 0x98, 0xf3, 0x88,
                0xc4, 0x52, 0xba, 0xfd, 0xd3, 0xb8, 0xea, 0xe8, 0x05, 0xf8, 0x68, 0x1a, 0x8d, 0x15,
                0xc2, 0xd4, 0xe1, 0x42,
            ],
        ];
        let ns = [
            [
                0x77, 0x07, 0x6d, 0x0a, 0x73, 0x18, 0xa5, 0x7d, 0x3c, 0x16, 0xc1, 0x72, 0x51, 0xb2,
                0x66, 0x45, 0xdf, 0x4c, 0x2f, 0x87, 0xeb, 0xc0, 0x99, 0x2a, 0xb1, 0x77, 0xfb, 0xa5,
                0x1d, 0xb9, 0x2c, 0x2a,
            ],
            [
                0x5d, 0xab, 0x08, 0x7e, 0x62, 0x4a, 0x8a, 0x4b, 0x79, 0xe1, 0x7f, 0x8b, 0x83, 0x80,
                0x0e, 0xe6, 0x6f, 0x3b, 0xb1, 0x29, 0x26, 0x18, 0xb6, 0xfd, 0x1c, 0x2f, 0x8b, 0x27,
                0xff, 0x88, 0xe0, 0xeb,
            ],
            [
                0x48, 0x52, 0x83, 0x4d, 0x9d, 0x6b, 0x77, 0xda, 0xde, 0xab, 0xaa, 0xf2, 0xe1, 0x1d,
                0xca, 0x66, 0xd1, 0x9f, 0xe7, 0x49, 0x93, 0xa7, 0xbe, 0xc3, 0x6c, 0x6e, 0x16, 0xa0,
                0x98, 0x3f, 0xea, 0xba,
            ],
            [
                0x10, 0x64, 0xa6, 0x7d, 0xa6, 0x39, 0xa8, 0xf6, 0xdf, 0x4f, 0xbe, 0xa2, 0xd6, 0x33,
                0x58, 0xb6, 0x5b, 0xca, 0x80, 0xa7, 0x70, 0x71, 0x2e, 0x14, 0xea, 0x8a, 0x72, 0xdf,
                0x5a, 0x33, 0x13, 0xae,
            ],
            [
                0x58, 0x8c, 0x06, 0x1a, 0x50, 0x80, 0x4a, 0xc4, 0x88, 0xad, 0x77, 0x4a, 0xc7, 0x16,
                0xc3, 0xf5, 0xba, 0x71, 0x4b, 0x27, 0x12, 0xe0, 0x48, 0x49, 0x13, 0x79, 0xa5, 0x00,
                0x21, 0x19, 0x98, 0xa8,
            ],
            [
                0xb0, 0x5b, 0xfd, 0x32, 0xe5, 0x53, 0x25, 0xd9, 0xfd, 0x64, 0x8c, 0xb3, 0x02, 0x84,
                0x80, 0x39, 0x00, 0x0b, 0x39, 0x0e, 0x44, 0xd5, 0x21, 0xe5, 0x8a, 0xab, 0x3b, 0x29,
                0xa6, 0x96, 0x0b, 0xa8,
            ],
            [
                0x70, 0xe3, 0x4b, 0xcb, 0xe1, 0xf4, 0x7f, 0xbc, 0x0f, 0xdd, 0xfd, 0x7c, 0x1e, 0x1a,
                0xa5, 0x3d, 0x57, 0xbf, 0xe0, 0xf6, 0x6d, 0x24, 0x30, 0x67, 0xb4, 0x24, 0xbb, 0x62,
                0x10, 0xbe, 0xd1, 0x9c,
            ],
            [
                0x68, 0xc1, 0xf3, 0xa6, 0x53, 0xa4, 0xcd, 0xb1, 0xd3, 0x7b, 0xba, 0x94, 0x73, 0x8f,
                0x8b, 0x95, 0x7a, 0x57, 0xbe, 0xb2, 0x4d, 0x64, 0x6e, 0x99, 0x4d, 0xc2, 0x9a, 0x27,
                0x6a, 0xad, 0x45, 0x8d,
            ],
            [
                0xd8, 0x77, 0xb2, 0x6d, 0x06, 0xdf, 0xf9, 0xd9, 0xf7, 0xfd, 0x4c, 0x5b, 0x37, 0x69,
                0xf8, 0xcd, 0xd5, 0xb3, 0x05, 0x16, 0xa5, 0xab, 0x80, 0x6b, 0xe3, 0x24, 0xff, 0x3e,
                0xb6, 0x9e, 0xa0, 0xb2,
            ],
        ];
        let qs = [
            [
                0x4a, 0x5d, 0x9d, 0x5b, 0xa4, 0xce, 0x2d, 0xe1, 0x72, 0x8e, 0x3b, 0xf4, 0x80, 0x35,
                0x0f, 0x25, 0xe0, 0x7e, 0x21, 0xc9, 0x47, 0xd1, 0x9e, 0x33, 0x76, 0xf0, 0x9b, 0x3c,
                0x1e, 0x16, 0x17, 0x42,
            ],
            [
                0x4a, 0x5d, 0x9d, 0x5b, 0xa4, 0xce, 0x2d, 0xe1, 0x72, 0x8e, 0x3b, 0xf4, 0x80, 0x35,
                0x0f, 0x25, 0xe0, 0x7e, 0x21, 0xc9, 0x47, 0xd1, 0x9e, 0x33, 0x76, 0xf0, 0x9b, 0x3c,
                0x1e, 0x16, 0x17, 0x42,
            ],
            [
                0x87, 0xb7, 0xf2, 0x12, 0xb6, 0x27, 0xf7, 0xa5, 0x4c, 0xa5, 0xe0, 0xbc, 0xda, 0xdd,
                0xd5, 0x38, 0x9d, 0x9d, 0xe6, 0x15, 0x6c, 0xdb, 0xcf, 0x8e, 0xbe, 0x14, 0xff, 0xbc,
                0xfb, 0x43, 0x65, 0x51,
            ],
            [
                0x4b, 0x82, 0xbd, 0x86, 0x50, 0xea, 0x9b, 0x81, 0xa4, 0x21, 0x81, 0x84, 0x09, 0x26,
                0xa4, 0xff, 0xa1, 0x64, 0x34, 0xd1, 0xbf, 0x29, 0x8d, 0xe1, 0xdb, 0x87, 0xef, 0xb5,
                0xb0, 0xa9, 0xe3, 0x4e,
            ],
            [
                0xb1, 0xa7, 0x07, 0x51, 0x94, 0x95, 0xff, 0xff, 0xb2, 0x98, 0xff, 0x94, 0x17, 0x16,
                0xb0, 0x6d, 0xfa, 0xb8, 0x7c, 0xf8, 0xd9, 0x11, 0x23, 0xfe, 0x2b, 0xe9, 0xa2, 0x33,
                0xdd, 0xa2, 0x22, 0x12,
            ],
            [
                0x67, 0xdd, 0x4a, 0x6e, 0x16, 0x55, 0x33, 0x53, 0x4c, 0x0e, 0x3f, 0x17, 0x2e, 0x4a,
                0xb8, 0x57, 0x6b, 0xca, 0x92, 0x3a, 0x5f, 0x07, 0xb2, 0xc0, 0x69, 0xb4, 0xc3, 0x10,
                0xff, 0x2e, 0x93, 0x5b,
            ],
            [
                0x4a, 0x06, 0x38, 0xcf, 0xaa, 0x9e, 0xf1, 0x93, 0x3b, 0x47, 0xf8, 0x93, 0x92, 0x96,
                0xa6, 0xb2, 0x5b, 0xe5, 0x41, 0xef, 0x7f, 0x70, 0xe8, 0x44, 0xc0, 0xbc, 0xc0, 0x0b,
                0x13, 0x4d, 0xe6, 0x4a,
            ],
            [
                0x39, 0x94, 0x91, 0xfc, 0xe8, 0xdf, 0xab, 0x73, 0xb4, 0xf9, 0xf6, 0x11, 0xde, 0x8e,
                0xa0, 0xb2, 0x7b, 0x28, 0xf8, 0x59, 0x94, 0x25, 0x0b, 0x0f, 0x47, 0x5d, 0x58, 0x5d,
                0x04, 0x2a, 0xc2, 0x07,
            ],
            [
                0x2c, 0x4f, 0xe1, 0x1d, 0x49, 0x0a, 0x53, 0x86, 0x17, 0x76, 0xb1, 0x3b, 0x43, 0x54,
                0xab, 0xd4, 0xcf, 0x5a, 0x97, 0x69, 0x9d, 0xb6, 0xe6, 0xc6, 0x8c, 0x16, 0x26, 0xd0,
                0x76, 0x62, 0xf7, 0x58,
            ],
        ];

        // One hardened buffer is reused and overwritten with each test scalar in turn.
        let mut s = Scalar::new_empty()?;

        for ((n, &p), &q) in ns.iter().zip(ps.iter()).zip(qs.iter()) {
            s.copy_from_slice(n);
            let p = Point(p);
            let actual = p.scalar_mult(&s)?;
            assert_eq!(actual.0, q);
        }

        Ok(())
    }

    // The point below is expected to be of low order, so multiplying it by any scalar must be
    // rejected with `CurveError::ScalarMultUnacceptable` rather than returning a result.
    #[test]
    fn reject_small_order() -> Result<(), AlkaliError> {
        let p = Point([
            0xe0, 0xeb, 0x7a, 0x7c, 0x3b, 0x41, 0xb8, 0xae, 0x16, 0x56, 0xe3, 0xfa, 0xf1, 0x9f,
            0xc4, 0x6a, 0xda, 0x09, 0x8d, 0xeb, 0x9c, 0x32, 0xb1, 0xfd, 0x86, 0x62, 0x05, 0x16,
            0x5f, 0x49, 0xb8, 0x00,
        ]);
        let n = Scalar::try_from(&[
            0x5d, 0xab, 0x08, 0x7e, 0x62, 0x4a, 0x8a, 0x4b, 0x79, 0xe1, 0x7f, 0x8b, 0x83, 0x80,
            0x0e, 0xe6, 0x6f, 0x3b, 0xb1, 0x29, 0x26, 0x18, 0xb6, 0xfd, 0x1c, 0x2f, 0x8b, 0x27,
            0xff, 0x88, 0xe0, 0xeb,
        ])?;

        assert_eq!(
            p.scalar_mult(&n).unwrap_err(),
            AlkaliError::CurveError(CurveError::ScalarMultUnacceptable)
        );

        Ok(())
    }
}
