use std::{collections::HashSet, fmt::Debug, hash::Hash, sync::Arc};

use num::{Num, NumCast, One, Signed, Zero};
use parking_lot::RwLock;

/// Axis-aligned rectangle described by an origin corner and an extent.
///
/// `contains` and `intersects` treat the rectangle as spanning
/// `x..x + width` horizontally and `y..y + height` vertically.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub struct Bounds<TBound: Num> {
    /// Horizontal coordinate of the origin corner.
    pub x: TBound,
    /// Vertical coordinate of the origin corner.
    pub y: TBound,
    /// Extent along the x axis, measured from `x`.
    pub width: TBound,
    /// Extent along the y axis, measured from `y`.
    pub height: TBound,
}

impl<TBound: Num> Bounds<TBound> {
    pub fn new(x: TBound, y: TBound, width: TBound, height: TBound) -> Self {
        Self {
            x,
            y,
            width,
            height,
        }
    }
}

/// Containment and overlap tests for axis-aligned rectangles.
///
/// Background on 2D bounding-box intersection:
/// https://gamedev.stackexchange.com/questions/586/what-is-the-fastest-way-to-work-out-2d-bounding-box-intersection
impl<TBound: Copy + Num + NumCast + PartialOrd + Signed> Bounds<TBound> {
    /// Returns `true` when `other`'s bounds lie entirely inside `self`.
    ///
    /// Edges are inclusive: a rectangle contains itself, and a rectangle that
    /// touches the inside of an edge is still contained.
    pub fn contains<TProvidesBounds: ProvidesBounds<TBound>>(
        &self,
        other: &TProvidesBounds,
    ) -> bool {
        let other = other.bounds();
        let self_right = self.x + self.width;
        let self_bottom = self.y + self.height;
        let other_right = other.x + other.width;
        let other_bottom = other.y + other.height;

        self.x <= other.x
            && self_right >= other_right
            && self.y <= other.y
            && self_bottom >= other_bottom
    }

    /// Returns `true` when `self` and `other`'s bounds overlap with non-zero area.
    ///
    /// The comparison is strict, so rectangles that merely share an edge or a
    /// corner do not intersect. The result is symmetric in `self` and `other`.
    pub fn intersects<TProvidesBounds: ProvidesBounds<TBound>>(
        &self,
        other: &TProvidesBounds,
    ) -> bool {
        let other = other.bounds();

        // Two intervals overlap exactly when each one starts before the other
        // ends. This is algebraically the same as comparing twice the distance
        // between centres against the summed extents, but it needs no division,
        // so it stays exact for integer bounds, and no absolute value, so the
        // sign of the coordinates cannot skew the result.
        let overlaps_x = self.x < other.x + other.width && other.x < self.x + self.width;
        let overlaps_y = self.y < other.y + other.height && other.y < self.y + self.height;

        overlaps_x && overlaps_y
    }
}

impl<TBound: Copy + Num + NumCast + One + Zero> Bounds<TBound> {
    /// Splits the rectangle into four quadrants, each half as wide and half as
    /// tall as `self`.
    ///
    /// Every entry pairs a quadrant index with its bounds:
    ///
    /// | index | x offset     | y offset      |
    /// |-------|--------------|---------------|
    /// | 0     | 0            | 0             |
    /// | 1     | `width / 2`  | 0             |
    /// | 2     | 0            | `height / 2`  |
    /// | 3     | `width / 2`  | `height / 2`  |
    ///
    /// The halves come from `TBound`'s own division, so for integer bounds with
    /// an odd extent the quadrants do not cover the last unit of `self`.
    pub fn quad(&self) -> [(usize, Bounds<TBound>); 4] {
        let zero = TBound::zero();
        let one = TBound::one();
        // Built from `one` so there is no fallible numeric cast to unwrap.
        let two = one + one;
        let half_width = self.width / two;
        let half_height = self.height / two;

        // `column` / `row` are 0 or 1 and select whether the quadrant is shifted
        // by half the width / half the height respectively.
        [
            (0, zero, zero),
            (1, one, zero),
            (2, zero, one),
            (3, one, one),
        ]
        .map(|(index, column, row)| {
            (
                index,
                Bounds::new(
                    self.x + half_width * column,
                    self.y + half_height * row,
                    half_width,
                    half_height,
                ),
            )
        })
    }
}

impl<TBound: Copy + Num + NumCast + One + PartialOrd + Signed + Zero> Bounds<TBound> {
    /// Finds the quadrant of `self` that fully contains `against`.
    ///
    /// Returns the index of the first quadrant (in `quad()` order) whose bounds
    /// contain `against`, together with all four quadrants, or `None` when no
    /// single quadrant contains it.
    fn quad_and_index<TProvidesBounds: ProvidesBounds<TBound>>(
        &self,
        against: &TProvidesBounds,
    ) -> Option<(usize, [(usize, Bounds<TBound>); 4])> {
        let target = against.bounds();
        let quadrants = self.quad();

        // Copy the index out so the borrow of `quadrants` ends before it is moved.
        let containing_index = quadrants
            .iter()
            .find(|(_, quadrant)| quadrant.contains(&target))
            .map(|(index, _)| *index);

        containing_index.map(|index| (index, quadrants))
    }
}

/// Implemented by anything that can report an axis-aligned bounding rectangle.
///
/// The geometry helpers on [`Bounds`] and the quad-tree node operations accept
/// any implementor and only ever look at the rectangle it returns.
pub trait ProvidesBounds<TBound: Num> {
    /// Returns the bounding rectangle of `self`, by value.
    fn bounds(&self) -> Bounds<TBound>;
}

/// A shared reference reports the bounds of the value it points to, so
/// references can be used wherever a `ProvidesBounds` value is expected.
impl<TBound: Copy + Num, TProvidesBounds: ProvidesBounds<TBound>> ProvidesBounds<TBound>
    for &TProvidesBounds
{
    fn bounds(&self) -> Bounds<TBound> {
        // `self` is `&&TProvidesBounds`; peel one layer and call the referent's
        // implementation explicitly so this can never resolve back to itself.
        <TProvidesBounds as ProvidesBounds<TBound>>::bounds(*self)
    }
}

/// A rectangle is its own bounding rectangle.
impl<TBound: Copy + Num> ProvidesBounds<TBound> for Bounds<TBound> {
    fn bounds(&self) -> Bounds<TBound> {
        // Every field is `Copy`, so this yields an independent copy of `self`.
        let Self {
            x,
            y,
            width,
            height,
        } = *self;
        Self {
            x,
            y,
            width,
            height,
        }
    }
}

/// Shared, lock-protected array holding the four child nodes of a subdivided
/// [`QuadTreeNode`].
pub type BranchValue<TData, TBound> = Arc<RwLock<[QuadTreeNode<TData, TBound>; 4]>>;

/// One node of a quad tree over items that implement [`ProvidesBounds`].
///
/// Every variant carries the rectangle the node covers as its first field.
/// Values and children live behind `Arc`s, so cloning a node yields a handle
/// that shares the same value set and child array rather than a deep copy.
#[derive(Clone, Debug)]
pub enum QuadTreeNode<TData: ProvidesBounds<TBound>, TBound: Num> {
    /// A node that stores no values and has no children.
    None(Bounds<TBound>),
    /// A node that stores no values itself and has four children.
    Branch(Bounds<TBound>, BranchValue<TData, TBound>),
    /// A node that stores values directly and may additionally have four children.
    Leaf(
        Bounds<TBound>,
        Arc<RwLock<HashSet<TData>>>,
        Option<BranchValue<TData, TBound>>,
    ),
}

/// A node's bounds are the rectangle stored as the first field of its variant.
impl<TData: ProvidesBounds<TBound>, TBound: Copy + Num> ProvidesBounds<TBound>
    for QuadTreeNode<TData, TBound>
{
    fn bounds(&self) -> Bounds<TBound> {
        match self {
            Self::None(covered) | Self::Branch(covered, ..) | Self::Leaf(covered, ..) => *covered,
        }
    }
}

impl<
        TData: Clone + Eq + Hash + ProvidesBounds<TBound>,
        TBound: Copy + Debug + Num + NumCast + PartialOrd + Signed,
    > QuadTreeNode<TData, TBound>
{
    fn insert(&mut self, data: TData) -> bool {
        if !self.bounds().contains(&data.bounds()) {
            println!("{:?} contains {:?}", self.bounds(), data.bounds());
            false
        } else {
            *self = match self {
                Self::None(bounds) => {
                    let found_index_quad = bounds.quad_and_index(&data);

                    if let Some((found_index, quad)) = found_index_quad {
                        Self::Branch(*bounds, parse_into_branch_value(data, found_index, quad))
                    } else {
                        let mut set = HashSet::new();
                        set.insert(data);
                        Self::Leaf(*bounds, Arc::new(RwLock::new(set)), None)
                    }
                }
                Self::Leaf(bounds, values, sub_nodes) => {
                    let found_index_quad = bounds.quad_and_index(&data);

                    if let Some((found_index, quad)) = found_index_quad {
                        if let Some(sub_nodes) = sub_nodes {
                            sub_nodes.write()[found_index].insert(data);
                        } else {
                            *sub_nodes = Some(parse_into_branch_value(data, found_index, quad));
                        }
                    } else {
                        values.write().insert(data);
                    }

                    Self::Leaf(*bounds, values.clone(), sub_nodes.clone())
                }
                Self::Branch(bounds, sub_nodes) => {
                    let found_index_quad = bounds.quad_and_index(&data);

                    if let Some((found_index, _quad)) = found_index_quad {
                        sub_nodes.write()[found_index].insert(data);

                        Self::Branch(*bounds, sub_nodes.clone())
                    } else {
                        let mut set = HashSet::new();
                        set.insert(data);
                        QuadTreeNode::Leaf(
                            *bounds,
                            Arc::new(RwLock::new(set)),
                            Some(sub_nodes.clone()),
                        )
                    }
                }
            };

            true
        }
    }

    fn all_bounds_intersects<TProvidesBounds: ProvidesBounds<TBound>>(
        &self,
        target: TProvidesBounds,
        mut values: HashSet<TData>,
    ) -> HashSet<TData> {
        let target = target.bounds();
        match self {
            Self::Branch(bounds, sub_nodes) => {
                if target.intersects(bounds) {
                    for sub_node in sub_nodes.read().iter() {
                        values = sub_node.all_bounds_intersects(target, values);
                    }
                }

                values
            }
            Self::Leaf(bounds, leaf_values, sub_nodes) => {
                if target.intersects(bounds) {
                    for leaf_data in leaf_values.read().iter() {
                        if target.intersects(&leaf_data.bounds()) {
                            values.insert(leaf_data.clone());
                        }
                    }
                }

                if let Some(sub_nodes) = sub_nodes {
                    for sub_node in sub_nodes.read().iter() {
                        values = sub_node.all_bounds_intersects(target, values);
                    }
                }

                values
            }
            Self::None(..) => values,
        }
    }

    fn all_node_intersects<TProvidesBounds: ProvidesBounds<TBound>>(
        &self,
        target: TProvidesBounds,
        mut values: HashSet<TData>,
    ) -> HashSet<TData> {
        let target = target.bounds();
        match self {
            Self::Branch(bounds, sub_nodes) => {
                if target.intersects(bounds) {
                    for sub_node in sub_nodes.read().iter() {
                        values = sub_node.all_node_intersects(target, values);
                    }
                }

                values
            }
            Self::Leaf(bounds, leaf_values, sub_nodes) => {
                if target.intersects(bounds) {
                    for data in leaf_values.read().iter() {
                        values.insert(data.clone());
                    }
                }

                if let Some(sub_nodes) = sub_nodes {
                    for sub_node in sub_nodes.read().iter() {
                        values = sub_node.all_node_intersects(target, values);
                    }
                }

                values
            }
            Self::None(..) => values,
        }
    }
}

/// Builds the four children for a node that is being subdivided.
///
/// The child whose quadrant index equals `found_index` becomes a `Leaf` holding
/// a clone of `data`; every other child becomes an empty `None` node covering
/// its quadrant.
fn parse_into_branch_value<TData: Clone + Eq + Hash + ProvidesBounds<TBound>, TBound: Num>(
    data: TData,
    found_index: usize,
    quad: [(usize, Bounds<TBound>); 4],
) -> Arc<parking_lot::lock_api::RwLock<parking_lot::RawRwLock, [QuadTreeNode<TData, TBound>; 4]>> {
    let children = quad.map(|(quad_index, quad_bounds)| {
        if quad_index != found_index {
            return QuadTreeNode::None(quad_bounds);
        }

        let occupants = HashSet::from([data.clone()]);
        QuadTreeNode::Leaf(quad_bounds, Arc::new(RwLock::new(occupants)), None)
    });

    Arc::new(RwLock::new(children))
}

#[cfg(test)]
mod tests {
    use std::collections::{HashMap, HashSet};

    use super::*;
    use noisy_float::prelude::{n64, N64};
    use rand::{thread_rng, Rng};

    /// Cross-checks quad-tree queries against a brute-force pairwise scan over
    /// randomly generated rectangles.
    #[test]
    fn it_works() {
        const BOX_COUNT: usize = 200;

        let mut rng = thread_rng();
        let rects: Vec<Bounds<N64>> = (0..BOX_COUNT)
            .map(|_| {
                Bounds::new(
                    n64(rng.gen_range(-100.0..=96.0)),
                    n64(rng.gen_range(-100.0..=96.0)),
                    n64(rng.gen_range(0.5..=4.0)),
                    n64(rng.gen_range(0.5..=4.0)),
                )
            })
            .collect();

        // Brute force: every ordered pair that intersects.
        let mut baseline_matches: HashMap<&Bounds<N64>, HashSet<&Bounds<N64>>> = HashMap::new();
        for lhs in &rects {
            for rhs in rects.iter().filter(|rhs| lhs.intersects(*rhs)) {
                baseline_matches.entry(lhs).or_default().insert(rhs);
            }
        }

        let mut quad_tree = QuadTreeNode::None(Bounds::new(
            n64(-100.0),
            n64(-100.0),
            n64(200.0),
            n64(200.0),
        ));
        for rect in &rects {
            quad_tree.insert(rect);
        }

        // Same question, answered through the tree.
        let mut check_matches: HashMap<&Bounds<N64>, HashSet<&Bounds<N64>>> = HashMap::new();
        for lhs in &rects {
            let found = quad_tree.all_bounds_intersects(lhs, HashSet::new());
            if !found.is_empty() {
                check_matches.entry(lhs).or_default().extend(found);
            }
        }

        assert_eq!(baseline_matches.len(), check_matches.len());
        println!("{} == {}", baseline_matches.len(), check_matches.len());

        for (key, values) in check_matches {
            let expected = &baseline_matches[key];
            println!("inner: {} == {}", values.len(), expected.len());
            // Every brute-force match must also have been found through the tree.
            assert!(expected.iter().all(|rect| values.contains(*rect)));
        }
    }
}
