use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};

use nanorand::tls::TlsWyRand;

use crate::util::random_float;

/// A 3D vector with `f64` components.
///
/// Supports componentwise addition and subtraction (`+`, `-`, `+=`, `-=`), negation (`-v`),
/// and scaling by an `f64` scalar (`v * s`, `s * v`, `v / s`, `*=`, `/=`).
/// In addition, the [dot product](Vec3::dot_product) and the
/// [cross product](Vec3::cross_product) are provided.
#[derive(Copy, Clone, Default, Debug, PartialEq)]
pub struct Vec3 {
    x: f64,
    y: f64,
    z: f64,
}

impl Add<Vec3> for Vec3 {
    type Output = Self;

    /// Componentwise sum of `self` and `rhs`.
    fn add(self, rhs: Vec3) -> Self::Output {
        let Vec3 { x, y, z } = self;
        Self {
            x: x + rhs.x,
            y: y + rhs.y,
            z: z + rhs.z,
        }
    }
}

impl AddAssign<Vec3> for Vec3 {
    /// Adds `rhs` to `self` in place, one component at a time.
    fn add_assign(&mut self, rhs: Vec3) {
        self.x += rhs.x;
        self.y += rhs.y;
        self.z += rhs.z;
    }
}

impl Sub<Vec3> for Vec3 {
    type Output = Self;

    /// Componentwise difference `self - rhs`.
    fn sub(self, rhs: Vec3) -> Self::Output {
        let Vec3 { x, y, z } = self;
        Self {
            x: x - rhs.x,
            y: y - rhs.y,
            z: z - rhs.z,
        }
    }
}

impl SubAssign<Vec3> for Vec3 {
    /// Subtracts `rhs` from `self` in place, one component at a time.
    fn sub_assign(&mut self, rhs: Vec3) {
        self.x -= rhs.x;
        self.y -= rhs.y;
        self.z -= rhs.z;
    }
}

impl Mul<f64> for Vec3 {
    type Output = Self;

    /// Scales every component by `factor`.
    fn mul(self, factor: f64) -> Self::Output {
        let Vec3 { x, y, z } = self;
        Self {
            x: x * factor,
            y: y * factor,
            z: z * factor,
        }
    }
}

impl MulAssign<f64> for Vec3 {
    /// Scales `self` in place by `factor`.
    fn mul_assign(&mut self, factor: f64) {
        self.x *= factor;
        self.y *= factor;
        self.z *= factor;
    }
}

impl Mul<Vec3> for f64 {
    type Output = Vec3;

    /// Scalar-on-the-left multiplication (`s * v`).
    ///
    /// Delegates to `Vec3 * f64` so that both operand orders always agree.
    fn mul(self, vector: Vec3) -> Self::Output {
        <Vec3 as Mul<f64>>::mul(vector, self)
    }
}

impl Div<f64> for Vec3 {
    type Output = Self;

    /// Divides every component by `divisor`.
    fn div(self, divisor: f64) -> Self::Output {
        let Vec3 { x, y, z } = self;
        Self {
            x: x / divisor,
            y: y / divisor,
            z: z / divisor,
        }
    }
}

impl DivAssign<f64> for Vec3 {
    /// Divides `self` in place by `divisor`.
    fn div_assign(&mut self, divisor: f64) {
        self.x /= divisor;
        self.y /= divisor;
        self.z /= divisor;
    }
}

impl Neg for Vec3 {
    type Output = Self;

    /// Flips the sign of every component.
    fn neg(self) -> Self::Output {
        let Vec3 { x, y, z } = self;
        Self {
            x: -x,
            y: -y,
            z: -z,
        }
    }
}

impl Vec3 {
    /// Create a new 3D vector.
    pub fn new(x: f64, y: f64, z: f64) -> Self {
        Vec3 { x, y, z }
    }

    /// Generate a Vec3 where each component is a uniform random number between `min` and `max`.
    pub fn random(mut rng: &mut TlsWyRand, min: f64, max: f64) -> Self {
        Vec3 {
            x: random_float(&mut rng, min, max),
            y: random_float(&mut rng, min, max),
            z: random_float(&mut rng, min, max),
        }
    }

    /// Generate a random Vec3 on the surface of the unit sphere with uniform (i.e. isotropic) distribution.
    pub fn random_unit_vector(mut rng: &mut TlsWyRand) -> Self {
        Self::random_in_unit_sphere(&mut rng).normalize()
    }

    /// Generate a random Vec3 inside the unit sphere with uniform distribution.
    pub fn random_in_unit_sphere(mut rng: &mut TlsWyRand) -> Self {
        loop {
            let random_vec3 = Self::random(&mut rng, -1.0, 1.0);
            if random_vec3.magnitude() < 1.0 {
                return random_vec3;
            }
        }
    }

    /// Generate a random Vec3 inside the xy unit disk with uniform distribution.
    pub fn random_in_unit_disk(mut rng: &mut TlsWyRand) -> Self {
        loop {
            let random_vec3 = Vec3 {
                x: random_float(&mut rng, -1.0, 1.0),
                y: random_float(&mut rng, -1.0, 1.0),
                z: 0.0,
            };
            if random_vec3.magnitude() < 1.0 {
                return random_vec3;
            }
        }
    }

    /// Reflect the Vec3 with respect to the given normal.
    pub fn reflect(self, normal: Vec3) -> Self {
        self - 2.0 * Self::dot_product(self, normal) * normal
    }

    /// Refract the Vec3 with respect to the given normal using n1 and n2 as the indices of refraction.
    pub fn refract(self, normal: Vec3, n1: f64, n2: f64, mut rng: &mut TlsWyRand) -> Self {
        let incidence_angle = (-Self::dot_product(self, normal)).acos();

        // Total internal reflection.
        if incidence_angle >= (n2 / n1).asin()
            || Self::reflection_coefficient(incidence_angle, n1, n2)
                > random_float(&mut rng, 0.0, 1.0)
        {
            return self.reflect(normal);
        }

        let refraction_angle = (n1 * incidence_angle.sin() / n2).asin();

        self - normal * (n2 * refraction_angle.cos() - n1 * incidence_angle.cos()) / n1
    }

    /// Calculate the reflection coefficient using Schlick's approximation.
    fn reflection_coefficient(incidence_angle: f64, n1: f64, n2: f64) -> f64 {
        let normal_reflectivity = ((n1 - n2) / (n1 + n2)).powi(2);
        normal_reflectivity + (1.0 - normal_reflectivity) * (1.0 - incidence_angle.cos()).powi(5)
    }

    /// Return the x component.
    pub fn x(&self) -> f64 {
        self.x
    }

    /// Return the y component.
    pub fn y(&self) -> f64 {
        self.y
    }

    /// Return the z component.
    pub fn z(&self) -> f64 {
        self.z
    }

    /// Return the magnitude (length).
    pub fn magnitude(&self) -> f64 {
        Self::dot_product(*self, *self).sqrt()
    }

    /// Return the absolute value (synonym [`magnitude`](Self::magnitude)).
    pub fn abs(&self) -> f64 {
        self.magnitude()
    }

    /// Return the unit vector parallel to self.
    /// Panics if self cannot be normalized.
    pub fn normalize(&self) -> Vec3 {
        if self.magnitude() == 0.0 {
            panic!("Can't normalize zero vector");
        }
        *self / self.magnitude()
    }

    /// Calculate the dot product.
    pub fn dot_product(vec_a: Vec3, vec_b: Vec3) -> f64 {
        vec_a.x * vec_b.x + vec_a.y * vec_b.y + vec_a.z * vec_b.z
    }

    /// Calculate the cross product.
    pub fn cross_product(vec_a: Vec3, vec_b: Vec3) -> Self {
        Vec3 {
            x: vec_a.y * vec_b.z - vec_a.z * vec_b.y,
            y: vec_a.z * vec_b.x - vec_a.x * vec_b.z,
            z: vec_a.x * vec_b.y - vec_a.y * vec_b.x,
        }
    }
}

#[cfg(test)]
mod tests {
    use assert_approx_eq::assert_approx_eq;
    use nanorand::tls_rng;

    use super::Vec3;

    const EPSILON: f64 = f64::EPSILON * 100.0;

    #[test]
    fn vec3_add() {
        assert_approx_eq!(
            Vec3::new(2.0, 1.0, 0.0) + Vec3::new(1.0, 1.0, 1.0),
            Vec3::new(3.0, 2.0, 1.0),
            EPSILON
        );
        assert_approx_eq!(
            Vec3::new(5.72, 2.5, 8.824) + Vec3::new(8.7, 5.987, 0.12),
            Vec3::new(14.42, 8.487, 8.944),
            EPSILON
        );
        let mut vec_a = Vec3::new(7.0, 2.5, 3.2);
        vec_a += Vec3::new(1.2, 9.23, 6.2);
        assert_approx_eq!(vec_a, Vec3::new(8.2, 11.73, 9.4), EPSILON)
    }

    #[test]
    fn vec3_sub() {
        assert_approx_eq!(
            Vec3::new(2.0, 1.0, 0.0) - Vec3::new(1.0, 1.0, 1.0),
            Vec3::new(1.0, 0.0, -1.0),
            EPSILON
        );
        assert_approx_eq!(
            Vec3::new(5.72, 2.5, 8.824) - Vec3::new(8.7, 5.987, 0.12),
            Vec3::new(-2.98, -3.487, 8.704),
            EPSILON
        );
        let mut vec_a = Vec3::new(7.0, 2.5, 3.2);
        vec_a -= Vec3::new(1.2, 9.23, 6.2);
        assert_approx_eq!(vec_a, Vec3::new(5.8, -6.73, -3.0), EPSILON)
    }

    #[test]
    fn vec3_mul() {
        assert_approx_eq!(
            Vec3::new(2.0, 1.0, 0.0) * 2.0,
            Vec3::new(4.0, 2.0, 0.0),
            EPSILON
        );
        assert_approx_eq!(
            2.5 * Vec3::new(8.7, 5.987, 0.12),
            Vec3::new(21.75, 14.9675, 0.3),
            EPSILON
        );
        let mut vec_a = Vec3::new(7.0, 2.5, 3.2);
        vec_a *= -2.0;
        assert_approx_eq!(vec_a, Vec3::new(-14.0, -5.0, -6.4), EPSILON);
    }

    #[test]
    fn vec3_div() {
        assert_approx_eq!(
            Vec3::new(2.0, 1.0, 0.0) / 2.0,
            Vec3::new(1.0, 0.5, 0.0),
            EPSILON
        );
        assert_approx_eq!(
            Vec3::new(8.7, 5.987, 0.12) / 2.5,
            Vec3::new(3.48, 2.3948, 0.048),
            EPSILON
        );
        let mut vec_a = Vec3::new(7.0, 2.5, 3.2);
        vec_a /= -2.0;
        assert_approx_eq!(vec_a, Vec3::new(-3.5, -1.25, -1.6), EPSILON);
    }

    #[test]
    fn vec3_neg() {
        assert_approx_eq!(
            -Vec3::new(2.0, 1.0, 0.0),
            Vec3::new(-2.0, -1.0, 0.0),
            EPSILON
        );
        assert_approx_eq!(
            -Vec3::new(8.7, 5.987, 0.12),
            Vec3::new(-8.7, -5.987, -0.12),
            EPSILON
        );
    }

    #[test]
    fn vec3_random() {
        let mut rng = tls_rng();
        assert_approx_eq!(Vec3::random_unit_vector(&mut rng).magnitude(), 1.0, EPSILON);
        assert!(Vec3::random_in_unit_disk(&mut rng).magnitude() < 1.0);
        assert_eq!(Vec3::random_in_unit_disk(&mut rng).z(), 0.0);
        assert!(Vec3::random_in_unit_sphere(&mut rng).magnitude() < 1.0);
    }

    #[test]
    fn vec3_reflect() {
        assert_approx_eq!(
            Vec3::new(2.0, 1.0, 0.0).reflect(Vec3::new(0.0, 1.0, 0.0)),
            Vec3::new(2.0, -1.0, 0.0),
            EPSILON
        );
        assert_approx_eq!(
            Vec3::new(8.7, 5.987, 0.12).reflect(Vec3::new(1.0, 1.0, 0.0)),
            Vec3::new(-20.674, -23.387, 0.12),
            EPSILON
        );
    }

    #[test]
    fn vec3_refract() {
        let mut rng = tls_rng();
        let normal = Vec3::new(0.0, 1.0, 0.0);
        let incoming = Vec3::new(1.0, -1.0, 0.0).normalize();
        let reflected = incoming.reflect(normal);
        // Snell's law from n1 = 1.0 into n2 = 1.5: sin(theta_t) = sin(theta_i) / 1.5.
        let sin_refraction = incoming.x() / 1.5;
        let refracted = Vec3::new(sin_refraction, -(1.0 - sin_refraction.powi(2)).sqrt(), 0.0);
        // Fresnel reflection is chosen at random, so every draw may yield either outcome.
        // Only the direction of the refracted ray is compared here.
        for _ in 0..100 {
            let out = incoming.refract(normal, 1.0, 1.5, &mut rng);
            assert!(
                (out - reflected).abs() < 1e-12 || (out.normalize() - refracted).abs() < 1e-12,
                "unexpected refraction result {:?}",
                out
            );
        }
    }

    #[test]
    fn vec3_refract_total_internal_reflection() {
        let mut rng = tls_rng();
        let normal = Vec3::new(0.0, 1.0, 0.0);
        // 45 degrees of incidence exceeds the critical angle asin(1.0 / 1.5) (about 41.8 degrees).
        let incoming = Vec3::new(1.0, -1.0, 0.0).normalize();
        for _ in 0..10 {
            assert_approx_eq!(
                incoming.refract(normal, 1.5, 1.0, &mut rng),
                incoming.reflect(normal),
                EPSILON
            );
        }
    }

    #[test]
    fn vec3_normalize() {
        assert_approx_eq!(
            (-Vec3::new(2.0, 1.0, 0.0)).normalize().magnitude(),
            1.0,
            EPSILON
        );
        assert_approx_eq!(
            Vec3::new(8.7, 5.987, 0.12).normalize().magnitude(),
            1.0,
            EPSILON
        );
    }

    #[test]
    #[should_panic(expected = "Can't normalize zero vector")]
    fn vec3_normalize_panic() {
        Vec3::new(0.0, 0.0, 0.0).normalize();
    }

    #[test]
    fn vec3_magnitude() {
        assert_approx_eq!(Vec3::new(42.0, 0.0, 0.0).magnitude(), 42.0, EPSILON);
        assert_approx_eq!(Vec3::new(-3.0, -4.0, 0.0).magnitude(), 5.0, EPSILON);
        assert_approx_eq!(Vec3::new(2.0, -2.0, 1.0).abs(), 3.0, EPSILON);
    }

    #[test]
    fn vec3_dot() {
        assert_approx_eq!(
            Vec3::dot_product(Vec3::new(2.0, 1.0, 0.0), Vec3::new(1.0, 1.0, 1.0)),
            3.0,
            EPSILON
        );
        assert_approx_eq!(
            Vec3::dot_product(Vec3::new(5.72, 2.5, 8.824), Vec3::new(8.7, 5.987, 0.12)),
            65.79038,
            EPSILON
        );
    }

    #[test]
    fn vec3_cross() {
        assert_approx_eq!(
            Vec3::cross_product(Vec3::new(2.0, 1.0, 0.0), Vec3::new(1.0, 1.0, 1.0)),
            Vec3::new(1.0, -2.0, 1.0),
            EPSILON
        );
        assert_approx_eq!(
            Vec3::cross_product(Vec3::new(5.72, 2.5, 8.824), Vec3::new(2.0, 1.0, 0.0)),
            Vec3::new(-8.824, 17.648, 0.72),
            EPSILON
        );
    }
}
