//! Implementation of remainder and modulo as it is done in the Go standard library.
//!
//! For unknown reasons, Rust's standard operations such as the `%` operator
//! or `.rem_euclid()` produce different results, and the azimuthal integration over the
//! chi angle looks different. Thus we copy Go's `math.Remainder` and `math.Mod` functions
//! as they are, only making them more idiomatic in Rust (i.e. exposing them as methods of `f64`).
//! Since the Go functions internally use `frexp` and `ldexp`, Rust equivalents of these are
//! needed; the trait is therefore bound on `glm::BaseFloat`, which provides them.
//!
//! We also port the Go tests to make sure the results are identical.
//!
use glm::BaseFloat;

/// Trait to be implemented by [f64] and [f32].
/// Floating-point remainder and modulo operations that mirror Go's
/// `math.Remainder` and `math.Mod`.
///
/// Intended for [`f64`] and [`f32`]; this module provides the [`f64`] implementation.
pub trait GoFloatMath: BaseFloat {
    /// IEEE 754 remainder of `self / y` (Go's `math.Remainder`): `self - n * y`, where `n`
    /// is the integer nearest to `self / y` (ties round to the even integer).
    fn remainder(self, y: Self) -> Self;

    /// Floating-point modulo of `self / y` (Go's `math.Mod`): the result's magnitude is
    /// less than `|y|` and its sign agrees with the sign of `self`.
    fn modulo(self, y: Self) -> Self;
}

/// Twice the smallest positive normal `f64`, i.e. 2^-1021 (bit pattern `0x0020000000000000`).
/// For `y` below this, `remainder` avoids halving `y` (the half would be subnormal and could
/// round) and compares a doubled `x` against `y` instead.
const TINY: f64 = 2.0 * f64::MIN_POSITIVE;
/// Half of the largest finite `f64`; for `y` above this, `y + y` would overflow to infinity.
const HALF_MAX: f64 = 0.5 * f64::MAX;

impl GoFloatMath for f64 {
    /// Remainder returns the IEEE 754 floating-point remainder of `self / y`
    /// (port of Go's `math.Remainder`).
    ///
    /// The original C code and the comment below are from FreeBSD's
    /// /usr/src/lib/msun/src/e_remainder.c and came with this notice.
    /// The Go code is a simplified version of the original C.
    ///
    /// ```text
    /// ====================================================
    ///
    /// Copyright (C) 1993 by Sun Microsystems, Inc. All rights reserved.
    ///
    /// Developed at SunPro, a Sun Microsystems, Inc. business.
    /// Permission to use, copy, modify, and distribute this
    /// software is freely granted, provided that this notice
    /// is preserved.
    ///
    /// ====================================================
    ///
    /// __ieee754_remainder(x,y)
    ///
    /// Return :
    ///      returns  x REM y  =  x - [x/y]*y  as if in infinite
    ///      precision arithmetic, where [x/y] is the (infinite bit)
    ///      integer nearest x/y (in half way cases, choose the even one).
    ///
    /// Method :
    ///      Based on Mod() returning  x - [x/y]chopped * y  exactly.
    /// ```
    ///
    /// Special cases are:
    ///
    /// - `Remainder(±Inf, y) = NaN`
    /// - `Remainder(NaN, y) = NaN`
    /// - `Remainder(x, 0) = NaN`
    /// - `Remainder(x, ±Inf) = x`
    /// - `Remainder(x, NaN) = NaN`
    ///
    /// A zero result carries the sign of `self`.
    fn remainder(self, y: f64) -> f64 {
        if self.is_nan() || y.is_nan() || self.is_infinite() || y == 0.0 {
            return f64::NAN;
        }
        if y.is_infinite() {
            return self;
        }
        // Work on magnitudes and restore the sign at the end. `self < 0.0` (not `abs()`)
        // mirrors Go: -0.0 is left untouched, so it comes back out as -0.0.
        let sign = self < 0.0;
        let mut x = if sign { -self } else { self };
        // y == ±0 was rejected above, so abs() equals Go's `if y < 0 { y = -y }`.
        let y = y.abs();
        if x == y {
            // Exact multiple: the zero result takes the sign of the input.
            return if sign { -0.0 } else { 0.0 };
        }
        if y <= HALF_MAX {
            x = x.modulo(y + y); // now x < 2y
        }
        if y < TINY {
            // Halving such a small y could round, so compare doubled x against y instead.
            if x + x > y {
                x -= y;
                if x + x >= y {
                    x -= y;
                }
            }
        } else {
            let y_half = 0.5 * y;
            if x > y_half {
                x -= y;
                if x >= y_half {
                    x -= y;
                }
            }
        }
        if sign {
            -x
        } else {
            x
        }
    }

    /// Floating-point mod function (port of Go's `math.Mod`).
    ///
    /// Returns the floating-point remainder of `self / y`. The magnitude of the result
    /// is less than `|y|` and its sign agrees with that of `self`.
    ///
    /// Special cases are:
    ///
    /// - `Mod(±Inf, y) = NaN`
    /// - `Mod(NaN, y) = NaN`
    /// - `Mod(x, 0) = NaN`
    /// - `Mod(x, ±Inf) = x`
    /// - `Mod(x, NaN) = NaN`
    fn modulo(self, y: f64) -> f64 {
        if y == 0.0 || self.is_infinite() || self.is_nan() || y.is_nan() {
            return f64::NAN;
        }
        let y = y.abs();
        let (y_frac, y_exp) = go_frexp(y);
        let mut r = if self < 0.0 { -self } else { self };
        // Each step subtracts the largest y * 2^k that does not exceed r. Since
        // y * 2^k <= r < 2 * y * 2^k, the subtraction is exact (Sterbenz), which is why
        // the frexp/ldexp helpers below must be exact too.
        while r >= y {
            let (r_frac, mut r_exp) = go_frexp(r);
            if r_frac < y_frac {
                r_exp -= 1;
            }
            r -= go_ldexp(y, r_exp - y_exp);
        }
        if self < 0.0 {
            -r
        } else {
            r
        }
    }
}

/// Bit offset of the exponent field in an `f64`.
const F64_EXP_SHIFT: u32 = 52;
/// Mask for the (shifted-down) 11-bit exponent field of an `f64`.
const F64_EXP_MASK: u64 = 0x7ff;
/// Exponent bias of an `f64`.
const F64_EXP_BIAS: i32 = 1023;

/// Port of Go's (unexported) `normalize`: scales a subnormal `x` into the normal range.
/// Returns `(y, e)` such that `x == y * 2^e`.
fn go_normalize(x: f64) -> (f64, i32) {
    if x.abs() < f64::MIN_POSITIVE {
        (x * (1u64 << 52) as f64, -52)
    } else {
        (x, 0)
    }
}

/// Exact bit-level port of Go's `math.Frexp`.
///
/// Splits `f` into `(frac, exp)` with `f == frac * 2^exp` and `0.5 <= |frac| < 1`.
/// Zero (keeping its sign), infinities and NaN are returned unchanged with `exp == 0`.
/// NOTE(review): used instead of `glm::BaseFloat::frexp` so that `modulo` stays exact
/// regardless of how that crate computes it.
fn go_frexp(f: f64) -> (f64, i32) {
    if f == 0.0 || f.is_infinite() || f.is_nan() {
        return (f, 0);
    }
    let (f, shift) = go_normalize(f);
    let mut bits = f.to_bits();
    let exp = shift + ((bits >> F64_EXP_SHIFT) & F64_EXP_MASK) as i32 - F64_EXP_BIAS + 1;
    // Replace the exponent field so the mantissa is read as a value in [0.5, 1).
    bits &= !(F64_EXP_MASK << F64_EXP_SHIFT);
    bits |= ((F64_EXP_BIAS - 1) as u64) << F64_EXP_SHIFT;
    (f64::from_bits(bits), exp)
}

/// Exact bit-level port of Go's `math.Ldexp`: returns `frac * 2^exp`.
///
/// Zero (keeping its sign), infinities and NaN are returned unchanged. Results that
/// underflow become a signed zero, results that overflow become a signed infinity.
/// `exp` is expected to be small (|exp| well below `i32::MAX`), as in `modulo`.
fn go_ldexp(frac: f64, exp: i32) -> f64 {
    if frac == 0.0 || frac.is_infinite() || frac.is_nan() {
        return frac;
    }
    let (frac, shift) = go_normalize(frac);
    let mut bits = frac.to_bits();
    let mut exp =
        exp + shift + ((bits >> F64_EXP_SHIFT) & F64_EXP_MASK) as i32 - F64_EXP_BIAS;
    if exp < -1075 {
        return 0.0_f64.copysign(frac); // underflow
    }
    if exp > 1023 {
        return f64::INFINITY.copysign(frac); // overflow
    }
    let mut scale: f64 = 1.0;
    if exp < -1022 {
        // Subnormal result: build the value 2^53 times larger, then scale it down.
        exp += 53;
        scale = 1.0 / (1u64 << 53) as f64;
    }
    bits &= !(F64_EXP_MASK << F64_EXP_SHIFT);
    bits |= ((exp + F64_EXP_BIAS) as u64) << F64_EXP_SHIFT;
    scale * f64::from_bits(bits)
}

#[cfg(test)]
mod tests {
    use std::f64::consts::PI;

    use super::GoFloatMath;

    const VF: [f64; 10] = [
        4.9790119248836735e+00,
        7.7388724745781045e+00,
        -2.7688005719200159e-01,
        -5.0106036182710749e+00,
        9.6362937071984173e+00,
        2.9263772392439646e+00,
        5.2290834314593066e+00,
        2.7279399104360102e+00,
        1.8253080916808550e+00,
        -8.6859247685756013e+00,
    ];
    const REMAINDER: [f64; 10] = [
        4.197615023265299782906368e-02,
        2.261127525421895434476482e+00,
        3.231794108794261433104108e-02,
        -2.120723654214984321697556e-02,
        3.637062928015826201999516e-01,
        1.220868282268106064236690e+00,
        -4.581668629186133046005125e-01,
        -9.117596417440410050403443e-01,
        8.734595415957246977711748e-01,
        1.314075231424398637614104e+00,
    ];
    const MODULO: [f64; 10] = [
        4.197615023265299782906368e-02,
        2.261127525421895434476482e+00,
        3.231794108794261433104108e-02,
        4.989396381728925078391512e+00,
        3.637062928015826201999516e-01,
        1.220868282268106064236690e+00,
        4.770916568540693347699744e+00,
        1.816180268691969246219742e+00,
        8.734595415957246977711748e-01,
        1.314075231424398637614104e+00,
    ];
    const TEN: f64 = 10.0;
    const ZERO: f64 = 0.0;

    /// Special-case inputs (Go's `vffmodSC`) and their expected results (Go's `fmodSC`),
    /// shared by the remainder and modulo tests.
    fn sc() -> ([[f64; 2]; 34], [f64; 34]) {
        (
            [
                [f64::NEG_INFINITY, f64::NEG_INFINITY],
                [f64::NEG_INFINITY, -PI],
                [f64::NEG_INFINITY, ZERO],
                [f64::NEG_INFINITY, PI],
                [f64::NEG_INFINITY, f64::INFINITY],
                [f64::NEG_INFINITY, f64::NAN],
                [-PI, f64::NEG_INFINITY],
                [-PI, ZERO],
                [-PI, f64::INFINITY],
                [-PI, f64::NAN],
                [ZERO.copysign(-1.), f64::NEG_INFINITY],
                [ZERO.copysign(-1.), ZERO],
                [ZERO.copysign(-1.), f64::INFINITY],
                [ZERO.copysign(-1.), f64::NAN],
                [ZERO, f64::NEG_INFINITY],
                [ZERO, ZERO],
                [ZERO, f64::INFINITY],
                [ZERO, f64::NAN],
                [PI, f64::NEG_INFINITY],
                [PI, ZERO],
                [PI, f64::INFINITY],
                [PI, f64::NAN],
                [f64::INFINITY, f64::NEG_INFINITY],
                [f64::INFINITY, -PI],
                [f64::INFINITY, ZERO],
                [f64::INFINITY, PI],
                [f64::INFINITY, f64::INFINITY],
                [f64::INFINITY, f64::NAN],
                [f64::NAN, f64::NEG_INFINITY],
                [f64::NAN, -PI],
                [f64::NAN, ZERO],
                [f64::NAN, PI],
                [f64::NAN, f64::INFINITY],
                [f64::NAN, f64::NAN],
            ],
            [
                f64::NAN,           // fmod(-Inf, -Inf)
                f64::NAN,           // fmod(-Inf, -PI)
                f64::NAN,           // fmod(-Inf, 0)
                f64::NAN,           // fmod(-Inf, PI)
                f64::NAN,           // fmod(-Inf, +Inf)
                f64::NAN,           // fmod(-Inf, NaN)
                -PI,                // fmod(-PI, -Inf)
                f64::NAN,           // fmod(-PI, 0)
                -PI,                // fmod(-PI, +Inf)
                f64::NAN,           // fmod(-PI, NaN)
                ZERO.copysign(-1.), // fmod(-0, -Inf)
                f64::NAN,           // fmod(-0, 0)
                ZERO.copysign(-1.), // fmod(-0, Inf)
                f64::NAN,           // fmod(-0, NaN)
                ZERO,               // fmod(0, -Inf)
                f64::NAN,           // fmod(0, 0)
                ZERO,               // fmod(0, +Inf)
                f64::NAN,           // fmod(0, NaN)
                PI,                 // fmod(PI, -Inf)
                f64::NAN,           // fmod(PI, 0)
                PI,                 // fmod(PI, +Inf)
                f64::NAN,           // fmod(PI, NaN)
                f64::NAN,           // fmod(+Inf, -Inf)
                f64::NAN,           // fmod(+Inf, -PI)
                f64::NAN,           // fmod(+Inf, 0)
                f64::NAN,           // fmod(+Inf, PI)
                f64::NAN,           // fmod(+Inf, +Inf)
                f64::NAN,           // fmod(+Inf, NaN)
                f64::NAN,           // fmod(NaN, -Inf)
                f64::NAN,           // fmod(NaN, -PI)
                f64::NAN,           // fmod(NaN, 0)
                f64::NAN,           // fmod(NaN, PI)
                f64::NAN,           // fmod(NaN, +Inf)
                f64::NAN,           // fmod(NaN, NaN)
            ],
        )
    }

    /// Port of Go's `alike` test helper: `a` and `b` are alike if both are NaN, or if they
    /// compare equal and have the same sign bit (so `0.0` and `-0.0` are told apart).
    fn alike(a: f64, b: f64) -> bool {
        (a.is_nan() && b.is_nan()) || (a == b && a.is_sign_negative() == b.is_sign_negative())
    }

    /// Panics if `x.remainder(y)` is a zero whose sign differs from the sign of `x`.
    fn test_remainder_sign(x: f64, y: f64) {
        let r = x.remainder(y);
        if r == 0. && r.is_sign_negative() != x.is_sign_negative() {
            panic!(
                "{}.remainder({}) = {}, sign of (zero) result should agree with sign of x",
                x, y, r
            );
        }
    }

    #[test]
    fn test_remainder() {
        for (i, (val, exp)) in VF.iter().zip(REMAINDER.iter()).enumerate() {
            let calc = TEN.remainder(*val);
            assert_eq!(
                calc, *exp,
                "{}: {}.remainder({}) = {}, expected {}",
                i, TEN, *val, calc, exp
            );
        }
        let (vffmod, fmod) = sc();
        for (i, (v1, v2)) in vffmod.iter().zip(fmod.iter()).enumerate() {
            let f = v1[0].remainder(v1[1]);
            if !alike(f, *v2) {
                panic!(
                    "{}: {}.remainder({}) = {}, expected {}",
                    i, v1[0], v1[1], f, *v2
                );
            }
        }
        // verify precision of result for extreme inputs
        // (explicit `f64`: calling a trait method on an untyped float literal binding is E0689)
        let v1: f64 = 5.9790119248836734e+200;
        let v2: f64 = 1.1258465975523544;
        let r: f64 = -0.4810497673014966;
        let f = v1.remainder(v2);
        assert_eq!(f, r, "{}.remainder({}) = {}, expected {}", v1, v2, f, r);
        // verify that sign is correct when r == 0 (x starts at 0, as in Go, to cover -0.0)
        for x in 0..4 {
            for y in 1..4 {
                let x = x as f64;
                let y = y as f64;
                test_remainder_sign(x, y);
                test_remainder_sign(x, -y);
                test_remainder_sign(-x, y);
                test_remainder_sign(-x, -y);
            }
        }
    }

    #[test]
    fn test_modulo() {
        for (i, (val, exp)) in VF.iter().zip(MODULO.iter()).enumerate() {
            let calc = TEN.modulo(*val);
            assert_eq!(
                calc, *exp,
                "{}: {}.modulo({}) = {}, expected {}",
                i, TEN, *val, calc, exp
            );
        }
        let (vffmod, fmod) = sc();
        for (i, (v1, v2)) in vffmod.iter().zip(fmod.iter()).enumerate() {
            let f = v1[0].modulo(v1[1]);
            if !alike(f, *v2) {
                panic!(
                    "{}: {}.modulo({}) = {}, expected {}",
                    i, v1[0], v1[1], f, *v2
                );
            }
        }
        // verify precision of result for extreme inputs
        let v1: f64 = 5.9790119248836734e+200;
        let v2: f64 = 1.1258465975523544;
        let r: f64 = 0.6447968302508578;
        let f = v1.modulo(v2);
        assert_eq!(f, r, "{}.modulo({}) = {}, expected {}", v1, v2, f, r);
    }
}
