//! Gaussian/Normal distribution over x in (-∞, ∞)
#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};

use once_cell::sync::OnceCell;
use rand::Rng;
use rand_distr::Normal;
use special::Error as _;
use std::f64::consts::SQRT_2;
use std::fmt;

use crate::consts::*;
use crate::data::GaussianSuffStat;
use crate::impl_display;
use crate::traits::*;

/// Gaussian / [Normal distribution](https://en.wikipedia.org/wiki/Normal_distribution),
/// N(μ, σ) over real values.
///
/// The distribution is parameterized by its mean `mu` and its standard
/// deviation `sigma` (not the variance). `ln(sigma)` is computed lazily and
/// cached the first time it is needed.
///
/// # Examples
///
/// Compute the [KL Divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence)
/// between two Gaussians.
///
/// ```
/// use rv::prelude::*;
///
/// let gauss_1 = Gaussian::new(0.1, 2.3).unwrap();
/// let gauss_2 = Gaussian::standard();
///
/// // KL is not symmetric
/// let kl_12 = gauss_1.kl(&gauss_2);
/// let kl_21 = gauss_2.kl(&gauss_1);
///
/// // ... but kl_sym is because it's the sum of KL(P|Q) and KL(Q|P)
/// let kl_sym = gauss_1.kl_sym(&gauss_2);
/// assert!((kl_sym - (kl_12 + kl_21)).abs() < 1E-12);
/// ```
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct Gaussian {
    /// Mean (μ)
    mu: f64,
    /// Standard deviation (σ)
    sigma: f64,
    /// Lazily initialized cache of `ln(sigma)`. It is skipped by serde when
    /// the `serde1` feature is enabled.
    #[cfg_attr(feature = "serde1", serde(skip))]
    ln_sigma: OnceCell<f64>,
}

// Equality looks only at the two parameters; the cached `ln_sigma` is derived
// from `sigma` and is not compared.
impl PartialEq for Gaussian {
    fn eq(&self, other: &Gaussian) -> bool {
        (self.mu, self.sigma) == (other.mu, other.sigma)
    }
}

/// Errors reported by `Gaussian::new`, `Gaussian::set_mu`, and
/// `Gaussian::set_sigma` when a parameter is rejected. Each variant carries
/// the offending value.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub enum GaussianError {
    /// The mu parameter is infinite or NaN
    MuNotFinite { mu: f64 },
    /// The sigma parameter is less than or equal to zero. Negative infinity
    /// is reported here as well, because the sign check runs before the
    /// finiteness check.
    SigmaTooLow { sigma: f64 },
    /// The sigma parameter is positive infinity or NaN
    SigmaNotFinite { sigma: f64 },
}

impl Gaussian {
    /// Build a Gaussian from a mean and a standard deviation, validating both
    ///
    /// # Arguments
    /// - mu: mean; must be finite
    /// - sigma: standard deviation; must be finite and strictly positive
    ///
    /// # Errors
    /// The checks run in this order and the first failure is returned:
    /// - `GaussianError::MuNotFinite` if `mu` is infinite or NaN
    /// - `GaussianError::SigmaTooLow` if `sigma <= 0.0` (this includes
    ///   negative infinity)
    /// - `GaussianError::SigmaNotFinite` if `sigma` is positive infinity or
    ///   NaN
    pub fn new(mu: f64, sigma: f64) -> Result<Self, GaussianError> {
        if !mu.is_finite() {
            return Err(GaussianError::MuNotFinite { mu });
        }
        if sigma <= 0.0 {
            return Err(GaussianError::SigmaTooLow { sigma });
        }
        if !sigma.is_finite() {
            return Err(GaussianError::SigmaNotFinite { sigma });
        }
        Ok(Self::new_unchecked(mu, sigma))
    }

    /// Creates a new Gaussian without checking whether the parameters are
    /// valid. The `ln(sigma)` cache starts out empty.
    #[inline]
    pub fn new_unchecked(mu: f64, sigma: f64) -> Self {
        let ln_sigma = OnceCell::new();
        Self {
            mu,
            sigma,
            ln_sigma,
        }
    }

    /// Standard normal, N(0, 1)
    ///
    /// The `ln(sigma)` cache is pre-filled with `0.0`.
    ///
    /// # Example
    ///
    /// ```rust
    /// # use rv::dist::Gaussian;
    /// let gauss = Gaussian::standard();
    ///
    /// assert_eq!(gauss, Gaussian::new(0.0, 1.0).unwrap());
    /// ```
    #[inline]
    pub fn standard() -> Self {
        Self {
            mu: 0.0,
            sigma: 1.0,
            ln_sigma: OnceCell::from(0.0),
        }
    }

    /// Get mu parameter
    ///
    /// # Example
    ///
    /// ```rust
    /// # use rv::dist::Gaussian;
    /// let gauss = Gaussian::new(2.0, 1.5).unwrap();
    ///
    /// assert_eq!(gauss.mu(), 2.0);
    /// ```
    #[inline]
    pub fn mu(&self) -> f64 {
        self.mu
    }

    /// Set the value of mu
    ///
    /// # Errors
    /// Returns `GaussianError::MuNotFinite` and leaves the distribution
    /// untouched if `mu` is infinite or NaN.
    ///
    /// # Example
    ///
    /// ```rust
    /// # use rv::dist::Gaussian;
    /// let mut gauss = Gaussian::new(2.0, 1.5).unwrap();
    /// assert_eq!(gauss.mu(), 2.0);
    ///
    /// gauss.set_mu(1.3).unwrap();
    /// assert_eq!(gauss.mu(), 1.3);
    /// ```
    ///
    /// Will error for invalid values
    ///
    /// ```rust
    /// # use rv::dist::Gaussian;
    /// # let mut gauss = Gaussian::new(2.0, 1.5).unwrap();
    /// assert!(gauss.set_mu(1.3).is_ok());
    /// assert!(gauss.set_mu(std::f64::NEG_INFINITY).is_err());
    /// assert!(gauss.set_mu(std::f64::INFINITY).is_err());
    /// assert!(gauss.set_mu(std::f64::NAN).is_err());
    /// ```
    #[inline]
    pub fn set_mu(&mut self, mu: f64) -> Result<(), GaussianError> {
        if mu.is_finite() {
            self.set_mu_unchecked(mu);
            Ok(())
        } else {
            Err(GaussianError::MuNotFinite { mu })
        }
    }

    /// Set the value of mu without input validation
    #[inline]
    pub fn set_mu_unchecked(&mut self, mu: f64) {
        self.mu = mu;
    }

    /// Get sigma parameter
    ///
    /// # Example
    ///
    /// ```rust
    /// # use rv::dist::Gaussian;
    /// let gauss = Gaussian::new(2.0, 1.5).unwrap();
    ///
    /// assert_eq!(gauss.sigma(), 1.5);
    /// ```
    #[inline]
    pub fn sigma(&self) -> f64 {
        self.sigma
    }

    /// Set the value of sigma
    ///
    /// # Errors
    /// Returns `GaussianError::SigmaTooLow` if `sigma <= 0.0` (checked
    /// first), or `GaussianError::SigmaNotFinite` if `sigma` is positive
    /// infinity or NaN. The distribution is left untouched on error.
    ///
    /// # Example
    ///
    /// ```rust
    /// # use rv::dist::Gaussian;
    /// let mut gauss = Gaussian::standard();
    /// assert_eq!(gauss.sigma(), 1.0);
    ///
    /// gauss.set_sigma(2.3).unwrap();
    /// assert_eq!(gauss.sigma(), 2.3);
    /// ```
    ///
    /// Will error for invalid values
    ///
    /// ```rust
    /// # use rv::dist::Gaussian;
    /// # let mut gauss = Gaussian::standard();
    /// assert!(gauss.set_sigma(2.3).is_ok());
    /// assert!(gauss.set_sigma(0.0).is_err());
    /// assert!(gauss.set_sigma(-1.0).is_err());
    /// assert!(gauss.set_sigma(std::f64::INFINITY).is_err());
    /// assert!(gauss.set_sigma(std::f64::NEG_INFINITY).is_err());
    /// assert!(gauss.set_sigma(std::f64::NAN).is_err());
    /// ```
    #[inline]
    pub fn set_sigma(&mut self, sigma: f64) -> Result<(), GaussianError> {
        if sigma <= 0.0 {
            return Err(GaussianError::SigmaTooLow { sigma });
        }
        if !sigma.is_finite() {
            return Err(GaussianError::SigmaNotFinite { sigma });
        }
        self.set_sigma_unchecked(sigma);
        Ok(())
    }

    /// Set the value of sigma without input validation
    ///
    /// The cached `ln(sigma)` is discarded so that it is recomputed from the
    /// new value on next use.
    #[inline]
    pub fn set_sigma_unchecked(&mut self, sigma: f64) {
        // Drop the stale cache entry along with the old sigma.
        self.ln_sigma = OnceCell::new();
        self.sigma = sigma;
    }

    /// Evaluate or fetch cached log sigma
    #[inline]
    fn ln_sigma(&self) -> f64 {
        let cached = self.ln_sigma.get_or_init(|| self.sigma.ln());
        *cached
    }
}

impl Default for Gaussian {
    fn default() -> Self {
        Gaussian::standard()
    }
}

// Human-readable rendering of the parameters, e.g. `N(μ: 0.5, σ: 2)`.
impl From<&Gaussian> for String {
    fn from(gauss: &Gaussian) -> String {
        let (mean, std_dev) = (gauss.mu, gauss.sigma);
        format!("N(μ: {}, σ: {})", mean, std_dev)
    }
}

// NOTE(review): `impl_display!` is defined elsewhere in the crate; presumably
// it implements `fmt::Display` for `Gaussian` on top of the
// `From<&Gaussian> for String` conversion — confirm against the macro.
impl_display!(Gaussian);

// Generates every trait impl whose signature depends on the observation type
// (`$kind`): density, sampling, support, CDF / inverse CDF, location
// summaries and the sufficient-statistic hook. `$kind` must convert into
// `f64` via `f64::from` and be a valid target of an `as` cast from `f64`.
macro_rules! impl_traits {
    ($kind:ty) => {
        impl Rv<$kind> for Gaussian {
            // Log density: -ln(σ) - z²/2 - ln(2π)/2 with z = (x - μ) / σ.
            fn ln_f(&self, x: &$kind) -> f64 {
                let z = (f64::from(*x) - self.mu) / self.sigma;
                -self.ln_sigma() - 0.5 * z * z - HALF_LN_2PI
            }

            // Panics if `Normal::new` rejects the stored parameters.
            fn draw<R: Rng>(&self, rng: &mut R) -> $kind {
                let normal = Normal::new(self.mu, self.sigma).unwrap();
                let x: f64 = rng.sample(normal);
                x as $kind
            }

            // Builds the sampler once and reuses it for all `n` draws.
            fn sample<R: Rng>(&self, n: usize, rng: &mut R) -> Vec<$kind> {
                let normal = Normal::new(self.mu, self.sigma).unwrap();
                std::iter::repeat_with(|| {
                    let x: f64 = rng.sample(normal);
                    x as $kind
                })
                .take(n)
                .collect()
            }
        }

        impl ContinuousDistr<$kind> for Gaussian {}

        impl Support<$kind> for Gaussian {
            // Every finite value is supported; NaN and ±∞ are not.
            fn supports(&self, x: &$kind) -> bool {
                x.is_finite()
            }
        }

        impl Cdf<$kind> for Gaussian {
            // Φ(x) = (1 + erf((x - μ) / (σ√2))) / 2
            fn cdf(&self, x: &$kind) -> f64 {
                let z = (f64::from(*x) - self.mu) / (self.sigma * SQRT_2);
                0.5 * (1.0 + z.error())
            }
        }

        impl InverseCdf<$kind> for Gaussian {
            // Panics with "P out of range" when `p <= 0` or `p >= 1`.
            fn invcdf(&self, p: f64) -> $kind {
                let out_of_range = (p <= 0.0) || (1.0 <= p);
                assert!(!out_of_range, "P out of range");

                // x = μ + σ√2 · erf⁻¹(2p - 1), evaluated as one fused
                // multiply-add.
                let scale = self.sigma * SQRT_2;
                let z = (2.0 * p - 1.0).inv_error();
                scale.mul_add(z, self.mu) as $kind
            }
        }

        // Mean, median and mode all coincide with `mu`.
        impl Mean<$kind> for Gaussian {
            fn mean(&self) -> Option<$kind> {
                Some(self.mu as $kind)
            }
        }

        impl Median<$kind> for Gaussian {
            fn median(&self) -> Option<$kind> {
                Some(self.mu as $kind)
            }
        }

        impl Mode<$kind> for Gaussian {
            fn mode(&self) -> Option<$kind> {
                Some(self.mu as $kind)
            }
        }

        impl HasSuffStat<$kind> for Gaussian {
            type Stat = GaussianSuffStat;

            // A fresh statistic from `GaussianSuffStat::new`.
            fn empty_suffstat(&self) -> Self::Stat {
                GaussianSuffStat::new()
            }
        }
    };
}

// The variance is the square of the standard deviation.
impl Variance<f64> for Gaussian {
    fn variance(&self) -> Option<f64> {
        let std_dev = self.sigma;
        Some(std_dev * std_dev)
    }
}

// Differential entropy: ln(σ) + ln(2πe) / 2, using the cached ln(σ).
impl Entropy for Gaussian {
    fn entropy(&self) -> f64 {
        self.ln_sigma() + HALF_LN_2PI_E
    }
}

// The density is symmetric about `mu`, so the skewness is always zero.
impl Skewness for Gaussian {
    fn skewness(&self) -> Option<f64> {
        const ZERO_SKEW: f64 = 0.0;
        Some(ZERO_SKEW)
    }
}

// Reports zero, i.e. this follows the excess-kurtosis convention (a normal
// distribution's raw kurtosis is 3).
impl Kurtosis for Gaussian {
    fn kurtosis(&self) -> Option<f64> {
        const EXCESS_KURTOSIS: f64 = 0.0;
        Some(EXCESS_KURTOSIS)
    }
}

impl KlDivergence for Gaussian {
    /// KL divergence from `self` (P) to `other` (Q):
    ///
    /// KL(P‖Q) = ln(σ₂/σ₁) + (σ₁² + (μ₁ - μ₂)²) / (2σ₂²) - 1/2
    ///
    /// where (μ₁, σ₁) are the parameters of `self` and (μ₂, σ₂) those of
    /// `other`. The log terms come from each distribution's cached
    /// `ln(sigma)`, so repeated calls do not recompute the logarithms.
    // NOTE(review): the `allow` presumably silences clippy's complaint about
    // the deliberately asymmetric mix of s1 / s2 terms below — confirm.
    #[allow(clippy::suspicious_operation_groupings)]
    fn kl(&self, other: &Self) -> f64 {
        let m1 = self.mu;
        let m2 = other.mu;

        let s1 = self.sigma;
        let s2 = other.sigma;

        // ln(s2 / s1), taken from the per-distribution caches.
        let term1 = other.ln_sigma() - self.ln_sigma();
        // (s1² + (m1 - m2)²) / (2 s2²), with s1² folded into a fused
        // multiply-add.
        let term2 = s1.mul_add(s1, (m1 - m2) * (m1 - m2)) / (2.0 * s2 * s2);

        term1 + term2 - 0.5
    }
}

// Quadrature bounds are whatever `interval` returns for the probability
// below. NOTE(review): presumably the central interval holding that much
// probability mass — confirm against the trait that provides `interval`.
impl QuadBounds for Gaussian {
    fn quad_bounds(&self) -> (f64, f64) {
        const COVERAGE: f64 = 0.999_999_999_999;
        self.interval(COVERAGE)
    }
}

// Instantiate the value-type-dependent trait impls generated by
// `impl_traits!` for both `f32` and `f64` observations.
impl_traits!(f32);
impl_traits!(f64);

// Lets `GaussianError` be used as a `std::error::Error`; no trait methods are
// overridden, so the trait's defaults apply.
impl std::error::Error for GaussianError {}

impl fmt::Display for GaussianError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::MuNotFinite { mu } => write!(f, "non-finite mu: {}", mu),
            Self::SigmaTooLow { sigma } => {
                write!(f, "sigma ({}) must be greater than zero", sigma)
            }
            Self::SigmaNotFinite { sigma } => {
                write!(f, "non-finite sigma: {}", sigma)
            }
        }
    }
}

// Unit tests for `Gaussian`: construction, moments, sampling, density, CDF /
// quantile round trips, entropy, KL divergence, and cache behavior after the
// setters are used.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_basic_impls;
    use std::f64;

    // Absolute tolerance for comparisons against closed-form values.
    const TOL: f64 = 1E-12;

    test_basic_impls!([continuous] Gaussian::standard());

    #[test]
    fn new() {
        let gauss = Gaussian::new(1.2, 3.0).unwrap();
        assert::close(gauss.mu, 1.2, TOL);
        assert::close(gauss.sigma, 3.0, TOL);
    }

    #[test]
    fn standard() {
        let gauss = Gaussian::standard();
        assert::close(gauss.mu, 0.0, TOL);
        assert::close(gauss.sigma, 1.0, TOL);
    }

    #[test]
    fn standard_gaussian_mean_should_be_zero() {
        let mean: f64 = Gaussian::standard().mean().unwrap();
        assert::close(mean, 0.0, TOL);
    }

    #[test]
    fn standard_gaussian_variance_should_be_one() {
        assert::close(Gaussian::standard().variance().unwrap(), 1.0, TOL);
    }

    #[test]
    fn mean_should_be_mu() {
        let mu = 3.4;
        let mean: f64 = Gaussian::new(mu, 0.5).unwrap().mean().unwrap();
        assert::close(mean, mu, TOL);
    }

    #[test]
    fn median_should_be_mu() {
        let mu = 3.4;
        let median: f64 = Gaussian::new(mu, 0.5).unwrap().median().unwrap();
        assert::close(median, mu, TOL);
    }

    #[test]
    fn mode_should_be_mu() {
        let mu = 3.4;
        let mode: f64 = Gaussian::new(mu, 0.5).unwrap().mode().unwrap();
        assert::close(mode, mu, TOL);
    }

    #[test]
    fn variance_should_be_sigma_squared() {
        let sigma = 0.5;
        let gauss = Gaussian::new(3.4, sigma).unwrap();
        assert::close(gauss.variance().unwrap(), sigma * sigma, TOL);
    }

    #[test]
    fn draws_should_be_finite() {
        let mut rng = rand::thread_rng();
        let gauss = Gaussian::standard();
        for _ in 0..100 {
            let x: f64 = gauss.draw(&mut rng);
            assert!(x.is_finite())
        }
    }

    #[test]
    fn sample_length() {
        let mut rng = rand::thread_rng();
        let gauss = Gaussian::standard();
        let xs: Vec<f64> = gauss.sample(10, &mut rng);
        assert_eq!(xs.len(), 10);
    }

    // For N(0, 1) the log density at 0 is -ln(2π)/2.
    #[test]
    fn standard_ln_pdf_at_zero() {
        let gauss = Gaussian::standard();
        assert::close(gauss.ln_pdf(&0.0_f64), -0.918_938_533_204_672_7, TOL);
    }

    #[test]
    fn standard_ln_pdf_off_zero() {
        let gauss = Gaussian::standard();
        assert::close(gauss.ln_pdf(&2.1_f64), -3.123_938_533_204_672_7, TOL);
    }

    #[test]
    fn nonstandard_ln_pdf_on_mean() {
        let gauss = Gaussian::new(-1.2, 0.33).unwrap();
        assert::close(gauss.ln_pdf(&-1.2_f64), 0.189_724_091_316_938_46, TOL);
    }

    // Exercises the `f32` impl generated by `impl_traits!`.
    #[test]
    fn nonstandard_ln_pdf_off_mean() {
        let gauss = Gaussian::new(-1.2, 0.33).unwrap();
        assert::close(gauss.ln_pdf(&0.0_f32), -6.421_846_156_616_945, TOL);
    }

    #[test]
    fn should_contain_finite_values() {
        let gauss = Gaussian::standard();
        assert!(gauss.supports(&0.0_f32));
        assert!(gauss.supports(&10E8_f64));
        assert!(gauss.supports(&-10E8_f64));
    }

    #[test]
    fn should_not_contain_nan() {
        let gauss = Gaussian::standard();
        assert!(!gauss.supports(&f64::NAN));
    }

    #[test]
    fn should_not_contain_positive_or_negative_infinity() {
        let gauss = Gaussian::standard();
        assert!(!gauss.supports(&f64::INFINITY));
        assert!(!gauss.supports(&f64::NEG_INFINITY));
    }

    #[test]
    fn skewness_should_be_zero() {
        let gauss = Gaussian::new(-12.3, 45.6).unwrap();
        assert::close(gauss.skewness().unwrap(), 0.0, TOL);
    }

    // This test must call `kurtosis()`; checking `skewness()` here would
    // leave the `Kurtosis` impl entirely untested.
    #[test]
    fn kurtosis_should_be_zero() {
        let gauss = Gaussian::new(-12.3, 45.6).unwrap();
        assert::close(gauss.kurtosis().unwrap(), 0.0, TOL);
    }

    #[test]
    fn cdf_at_mean_should_be_one_half() {
        let mu1: f64 = 2.3;
        let gauss1 = Gaussian::new(mu1, 0.2).unwrap();
        assert::close(gauss1.cdf(&mu1), 0.5, TOL);

        let mu2: f32 = -8.0;
        let gauss2 = Gaussian::new(mu2.into(), 100.0).unwrap();
        assert::close(gauss2.cdf(&mu2), 0.5, TOL);
    }

    #[test]
    fn cdf_value_at_one() {
        let gauss = Gaussian::standard();
        assert::close(gauss.cdf(&1.0_f64), 0.841_344_746_068_542_9, TOL);
    }

    #[test]
    fn cdf_value_at_neg_two() {
        let gauss = Gaussian::standard();
        assert::close(gauss.cdf(&-2.0_f64), 0.022_750_131_948_179_195, TOL);
    }

    #[test]
    fn quantile_at_one_half_should_be_mu() {
        let mu = 1.2315;
        let gauss = Gaussian::new(mu, 1.0).unwrap();
        let x: f64 = gauss.quantile(0.5);
        assert::close(x, mu, TOL);
    }

    // Round trip: quantile(cdf(x)) should recover x for random draws.
    #[test]
    fn quantile_agree_with_cdf() {
        let mut rng = rand::thread_rng();
        let gauss = Gaussian::standard();
        let xs: Vec<f64> = gauss.sample(100, &mut rng);

        xs.iter().for_each(|x| {
            let p = gauss.cdf(x);
            let y: f64 = gauss.quantile(p);
            assert::close(y, *x, TOL);
        })
    }

    // Numerically integrating the pdf from a far-left bound up to x should
    // match the closed-form CDF at x.
    #[test]
    fn quad_on_pdf_agrees_with_cdf_x() {
        use crate::misc::quad_eps;
        let ig = Gaussian::new(-2.3, 0.5).unwrap();
        let pdf = |x: f64| ig.f(&x);
        let mut rng = rand::thread_rng();
        for _ in 0..100 {
            let x: f64 = ig.draw(&mut rng);
            let res = quad_eps(pdf, -10.0, x, Some(1e-13));
            let cdf = ig.cdf(&x);
            assert::close(res, cdf, 1e-7);
        }
    }

    // For N(0, 1) the entropy is ln(2πe)/2.
    #[test]
    fn standard_gaussian_entropy() {
        let gauss = Gaussian::standard();
        assert::close(gauss.entropy(), 1.418_938_533_204_672_7, TOL);
    }

    #[test]
    fn entropy() {
        let gauss = Gaussian::new(3.0, 12.3).unwrap();
        assert::close(gauss.entropy(), 3.928_537_795_583_044_7, TOL);
    }

    #[test]
    fn kl_of_idential_dsitrbutions_should_be_zero() {
        let gauss = Gaussian::new(1.2, 3.4).unwrap();
        assert::close(gauss.kl(&gauss), 0.0, TOL);
    }

    // Expected value is ln(s2/s1) + (s1² + (m1 - m2)²) / (2 s2²) - 1/2 with
    // (m1, s1) = (1, 2) and (m2, s2) = (2, 1).
    #[test]
    fn kl() {
        let g1 = Gaussian::new(1.0, 2.0).unwrap();
        let g2 = Gaussian::new(2.0, 1.0).unwrap();
        let kl = 0.5_f64.ln() + 5.0 / 2.0 - 0.5;
        assert::close(g1.kl(&g2), kl, TOL);
    }

    #[test]
    fn ln_f_after_set_mu_works() {
        let mut gauss = Gaussian::standard();
        assert::close(gauss.ln_pdf(&0.0_f64), -0.918_938_533_204_672_7, TOL);

        gauss.set_mu(1.0).unwrap();
        assert::close(gauss.ln_pdf(&1.0_f64), -0.918_938_533_204_672_7, TOL);
    }

    // Guards against a stale `ln_sigma` cache surviving a change of sigma.
    #[test]
    fn ln_f_after_set_sigm_works() {
        let mut gauss = Gaussian::new(-1.2, 5.0).unwrap();

        gauss.set_sigma(0.33).unwrap();
        assert::close(gauss.ln_pdf(&-1.2_f64), 0.189_724_091_316_938_46, TOL);
        assert::close(gauss.ln_pdf(&0.0_f32), -6.421_846_156_616_945, TOL);
    }
}
