use std::ops::Index;

use glam::{Mat4, Vec3};
use ivy_base::Position;

use crate::{epa, gjk, util::minkowski_diff, CollisionPrimitive, EntityPayload};

/// One or two points describing where a contact occurs.
///
/// The points are stored in fixed-size arrays, so the value is stored inline
/// (no heap allocation) and stays `Copy`. Use [`ContactPoints::points`] or
/// iterate by reference to treat both variants uniformly.
#[derive(Debug, Copy, Clone, PartialEq)]
pub enum ContactPoints {
    /// Exactly one point.
    Single([Position; 1]),
    /// Exactly two points.
    Double([Position; 2]),
}

impl ContactPoints {
    /// Creates a `ContactPoints` holding exactly one point.
    pub fn single(p: Position) -> Self {
        Self::Single([p])
    }

    /// Creates a `ContactPoints` holding exactly two points, in the given order.
    pub fn double(a: Position, b: Position) -> Self {
        Self::Double([a, b])
    }

    /// Returns the stored points as a slice.
    ///
    /// The slice has length 1 for `Single` and length 2 for `Double`.
    pub fn points(&self) -> &[Position] {
        match self {
            ContactPoints::Single(val) => val,
            ContactPoints::Double(val) => val,
        }
    }

    /// Iterates over the stored points by reference, in storage order.
    // The `'_` makes the borrow of `self` visible in the signature
    // (`elided_lifetimes_in_paths`); the type is unchanged.
    pub fn iter(&self) -> std::slice::Iter<'_, Position> {
        self.points().iter()
    }
}

impl From<Position> for ContactPoints {
    /// Wraps a lone point as [`ContactPoints::Single`].
    fn from(point: Position) -> Self {
        Self::single(point)
    }
}

impl From<[Position; 1]> for ContactPoints {
    fn from(val: [Position; 1]) -> Self {
        Self::Single(val)
    }
}

impl From<[Position; 2]> for ContactPoints {
    fn from(val: [Position; 2]) -> Self {
        Self::Double(val)
    }
}

impl<'a> IntoIterator for &'a ContactPoints {
    type Item = &'a Position;

    type IntoIter = std::slice::Iter<'a, Position>;

    fn into_iter(self) -> Self::IntoIter {
        match self {
            ContactPoints::Single(val) => val.iter(),
            ContactPoints::Double(val) => val.iter(),
        }
    }
}

impl Index<usize> for ContactPoints {
    type Output = Position;

    fn index(&self, index: usize) -> &Self::Output {
        &self.points()[index]
    }
}

/// Contact information for two overlapping colliders.
///
/// [`intersect`] returns one of these when GJK reports an intersection; it is
/// produced by `epa`.
#[derive(Debug, Clone, PartialEq)]
pub struct Contact {
    /// The closest points on the two colliders, respectively.
    ///
    /// NOTE(review): a `ContactPoints::Single` value holds only one point, so
    /// "respectively" can only describe the `Double` case — confirm what a
    /// single point represents.
    pub points: ContactPoints,
    /// Depth of the contact.
    ///
    /// NOTE(review): presumably the penetration depth along `normal`; units and
    /// sign convention are set by `epa` — confirm.
    pub depth: f32,
    /// Contact normal.
    ///
    /// NOTE(review): direction (from `a` towards `b` or the reverse) and whether
    /// it is unit length are not established here — confirm against `epa`.
    pub normal: Vec3,
}

/// Represents a collision between two entities.
///
/// Pairs the payloads of both participants with the [`Contact`] describing
/// their overlap.
#[derive(Debug, Clone)]
pub struct Collision {
    /// Payload of the first entity.
    pub a: EntityPayload,
    /// Payload of the second entity.
    pub b: EntityPayload,
    /// Contact data for the pair.
    ///
    /// NOTE(review): presumably `points`/`normal` follow the same `a`/`b`
    /// ordering as the fields above — confirm where collisions are built.
    pub contact: Contact,
}

/// Tests two collision primitives, each placed by its transform, for intersection.
///
/// GJK runs first to decide whether the shapes overlap. EPA is run on the
/// resulting simplex only when GJK reports an intersection, and it produces
/// the returned [`Contact`].
///
/// Returns `None` when GJK finds no intersection.
pub fn intersect<A: CollisionPrimitive, B: CollisionPrimitive>(
    a_transform: &Mat4,
    b_transform: &Mat4,
    a: &A,
    b: &B,
) -> Option<Contact> {
    // Invert each transform once; GJK and every EPA support query reuse them.
    let inv_a = a_transform.inverse();
    let inv_b = b_transform.inverse();

    let (hit, simplex) = gjk(a_transform, b_transform, &inv_a, &inv_b, a, b);

    // Without an intersection there is no penetration for EPA to expand.
    if !hit {
        return None;
    }

    let contact = epa(
        |dir| minkowski_diff(a_transform, b_transform, &inv_a, &inv_b, a, b, dir),
        simplex,
    );
    Some(contact)
}
