use std::borrow::Borrow;
use std::cmp::min;
use std::hash::Hash;
use std::iter::FromIterator;

use crate::Interval;

/// Storage record for one interval, its payload and its tree augmentation.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde_derive", derive(serde::Deserialize, serde::Serialize))]
struct InternalEntry<N: Ord + Clone, D> {
    /// User payload associated with `interval`.
    data: D,
    /// The stored interval; `IntervalTree::index` sorts entries by its start.
    interval: Interval<N>,
    /// Largest interval end inside the implicit subtree rooted at this entry, recomputed by
    /// `IntervalTree::index_core`; the queries use it to prune subtrees that end too early.
    max: N,
}

/// Read-only view of one stored entry, as returned by the tree's lookup methods.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Entry<'a, N: Ord + Clone, D> {
    /// Position of the entry in the tree's internal (start-sorted) storage.
    index: usize,
    /// Borrowed payload.
    data: &'a D,
    /// Borrowed interval.
    interval: &'a Interval<N>,
}

impl<'a, N: Ord + Clone + 'a, D: 'a> Entry<'a, N, D> {
    /// Position of this entry in the tree's internal, start-sorted storage
    /// (the value accepted by `IntervalTree::get`).
    pub fn index(&self) -> usize {
        let Self { index, .. } = self;
        *index
    }

    /// The payload, borrowed for as long as the tree is borrowed (`'a`).
    pub fn value(&self) -> &'a D {
        let Self { data, .. } = self;
        *data
    }

    /// The stored interval, borrowed for as long as the tree is borrowed (`'a`).
    pub fn interval(&self) -> &'a Interval<N> {
        let Self { interval, .. } = self;
        *interval
    }
}

/// View of one stored entry whose payload may be mutated; the interval stays read-only.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct EntryMut<'a, N: Ord + Clone, D> {
    /// Position of the entry in the tree's internal (start-sorted) storage.
    index: usize,
    /// Exclusively borrowed payload.
    data: &'a mut D,
    /// Shared borrow of the interval. NOTE(review): presumably read-only so callers cannot
    /// break the start ordering / `max` values the tree relies on.
    interval: &'a Interval<N>,
}

impl<'a, N: Ord + Clone + 'a, D: 'a> EntryMut<'a, N, D> {
    /// Position of this entry in the tree's internal, start-sorted storage
    /// (the value accepted by `IntervalTree::get_mut`).
    pub fn index(&self) -> usize {
        self.index
    }

    /// Mutable access to the payload for as long as this `EntryMut` is borrowed.
    ///
    /// The receiver is a plain `&mut self`. A `&'a mut self` receiver would lock the
    /// `EntryMut` for its entire lifetime `'a`. Since `&mut T` is invariant in `T`, it also
    /// could not be called through the `&mut EntryMut` yielded by e.g.
    /// `find_all_mut(..).iter_mut()`. Use [`EntryMut::into_value`] to obtain a reference
    /// that lives for the full `'a`.
    pub fn value(&mut self) -> &mut D {
        &mut *self.data
    }

    /// Consumes the view and returns the payload borrowed for the full lifetime `'a`.
    pub fn into_value(self) -> &'a mut D {
        self.data
    }

    /// The stored interval, borrowed for the full lifetime `'a`.
    pub fn interval(&self) -> &'a Interval<N> {
        self.interval
    }
}

/// Collects the storage indices of entries found by a traversal so they can be removed
/// afterwards, once the traversal no longer borrows the entries.
struct RemovalIndices(Vec<usize>);

impl RemovalIndices {
    /// Empty collector, pre-sized for 512 indices.
    pub fn new() -> Self {
        Self(Vec::with_capacity(512))
    }

    /// Consumes the collector and returns the collected indices in ascending order.
    pub fn into_vec(mut self) -> Vec<usize> {
        // Plain integers: sort stability is unobservable, so the unstable sort gives the same
        // result as `sort` while being faster and allocation-free.
        self.0.sort_unstable();
        self.0
    }
}

trait EntryContainer<'a, N: Ord + Clone, D> {
    fn push_entry(&mut self, entry: Entry<'a, N, D>) -> bool;
}

impl<'a, N: Ord + Clone, D> EntryContainer<'a, N, D> for bool {
    fn push_entry(&mut self, _entry: Entry<'a, N, D>) -> bool {
        *self = true;
        true
    }
}

impl<'a, N: Ord + Clone, D> EntryContainer<'a, N, D> for RemovalIndices {
    fn push_entry(&mut self, entry: Entry<'a, N, D>) -> bool {
        self.0.push(entry.index);
        false
    }
}

impl<'a, N: Ord + Clone, D> EntryContainer<'a, N, D> for Option<Entry<'a, N, D>> {
    fn push_entry(&mut self, entry: Entry<'a, N, D>) -> bool {
        *self = Some(entry);
        true
    }
}

impl<'a, N: Ord + Clone, D> EntryContainer<'a, N, D> for Vec<Entry<'a, N, D>> {
    fn push_entry(&mut self, entry: Entry<'a, N, D>) -> bool {
        self.push(entry);
        false
    }
}

/// Sink for the mutable entries reported by `find_mut_aux`.
///
/// `push_entry_mut` returns `true` when the traversal should stop early.
trait EntryMutContainer<'a, N: Ord + Clone, D> {
    fn push_entry_mut(&mut self, entry: EntryMut<'a, N, D>) -> bool;
}

/// First match only: keep the entry and stop.
impl<'a, N: Ord + Clone, D> EntryMutContainer<'a, N, D> for Option<EntryMut<'a, N, D>> {
    fn push_entry_mut(&mut self, entry: EntryMut<'a, N, D>) -> bool {
        self.replace(entry);
        true
    }
}

/// Collect every match; never stops early.
impl<'a, N: Ord + Clone, D> EntryMutContainer<'a, N, D> for Vec<EntryMut<'a, N, D>> {
    fn push_entry_mut(&mut self, entry: EntryMut<'a, N, D>) -> bool {
        Vec::push(self, entry);
        false
    }
}

impl<N: Ord + Clone, D> Default for IntervalTree<N, D> {
    fn default() -> Self {
        IntervalTree {
            entries: vec![],
            max_level: 0,
        }
    }
}

/// Interval tree stored as a flat array sorted by interval start, which the queries treat as
/// an implicit binary tree (layout built by `index_core`).
///
/// Operations that add or remove entries (`insert`, `extend`, `remove_*`, `from_iter`)
/// re-index the tree before returning.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde_derive", derive(serde::Deserialize, serde::Serialize))]
pub struct IntervalTree<N: Ord + Clone, D> {
    /// All entries, kept sorted by interval start by `index`.
    entries: Vec<InternalEntry<N, D>>,
    /// Level of the root node of the implicit tree, as computed by `index_core`.
    max_level: usize,
}

impl<N, D, V> FromIterator<(V, D)> for IntervalTree<N, D>
where
    V: Into<Interval<N>>,
    N: Ord + Clone,
{
    /// Builds a tree from `(interval, data)` pairs.
    ///
    /// All pairs are appended first and the index is built once, keeping construction at a
    /// single sort. Calling `insert` per pair would re-sort and re-index the whole tree after
    /// every element. The resulting order is the same because the sort is stable.
    fn from_iter<T: IntoIterator<Item = (V, D)>>(iter: T) -> Self {
        let mut tree = Self::new();
        tree.extend(iter);
        tree
    }
}

impl<N: Ord + Clone, D> IntervalTree<N, D> {
    pub fn new() -> Self {
        Default::default()
    }

    /// Adds one `(interval, data)` pair and immediately re-indexes the tree so later queries
    /// see it.
    ///
    /// Re-indexing sorts all entries, so prefer [`IntervalTree::extend`] for bulk loads; it
    /// indexes only once.
    pub fn insert<K: Into<Interval<N>>, V: Into<D>>(&mut self, interval: K, data: V) {
        let entry = {
            let interval: Interval<N> = interval.into();
            let data: D = data.into();
            // Provisional; `index` recomputes the subtree maximum.
            let max = interval.end().clone();
            InternalEntry {
                data,
                interval,
                max,
            }
        };
        self.entries.push(entry);
        self.index();
    }

    /// Inserts every `(interval, data)` pair yielded by `intervals`, then re-indexes once.
    ///
    /// Much cheaper than calling [`IntervalTree::insert`] in a loop, which re-indexes after
    /// every element.
    pub fn extend<I, K, V>(&mut self, intervals: I)
    where
        K: Into<Interval<N>>,
        V: Into<D>,
        I: IntoIterator<Item = (K, V)>,
    {
        // Let `Vec::extend` size the buffer. It reserves from the iterator's lower bound (or
        // the exact length for trusted-length iterators). Reserving the *upper* bound can
        // over-allocate, or panic with a capacity overflow for iterators such as
        // `(0..usize::MAX).take_while(..)`.
        self.entries.extend(intervals.into_iter().map(|(interval, data)| {
            let (interval, data): (Interval<N>, D) = (interval.into(), data.into());
            // Provisional; `index` recomputes the subtree maximum.
            let max = interval.end().clone();
            InternalEntry {
                interval,
                data,
                max,
            }
        }));
        self.index();
    }

    /// Sorts the entries by interval start and rebuilds the augmented `max` values.
    ///
    /// The sort is stable, so entries with equal starts keep their insertion order.
    fn index(&mut self) {
        self.entries
            .sort_by(|lhs, rhs| Ord::cmp(lhs.interval.start(), rhs.interval.start()));
        self.index_core();
    }

    /// Recomputes every entry's `max` and the tree's `max_level`.
    ///
    /// `entries` must already be sorted by start. The array is treated as an implicit
    /// binary tree: even indices are leaves (level 0); a node at level `k` has index
    /// `2^k - 1 (mod 2^(k+1))` and children at `i ± 2^(k-1)`. Each node's `max` becomes the
    /// largest end in its subtree. Right children that fall past the end of the array are
    /// stood in for by the max of the rightmost existing node on that level (`last_value`).
    fn index_core(&mut self) {
        if self.entries.is_empty() {
            // Reset so an emptied tree equals (and hashes like) `IntervalTree::new()`; the
            // derived `PartialEq`/`Hash` include `max_level`.
            self.max_level = 0;
            return;
        }

        let a = &mut self.entries;
        let n = a.len();

        // Level 0: a leaf's max is just its own end.
        for i in (0..n).step_by(2) {
            a[i].max = a[i].interval.end().clone();
        }

        // `last_i` is the rightmost node on the current level (initially the last leaf) and
        // `last_value` the max covering it.
        let mut last_i = (n - 1) & !1;
        let mut last_value = a[last_i].max.clone();

        let mut k = 1;
        while (1 << k) <= n {
            // Process internal nodes bottom-up, one level at a time.
            let x = 1 << (k - 1);
            let i0 = (x << 1) - 1; // first node on level k
            let step = x << 2;
            for i in (i0..n).step_by(step) {
                // Compare by reference and clone only the winner.
                let end_right = if i + x < n { &a[i + x].max } else { &last_value };
                let end = max3(a[i].interval.end(), &a[i - x].max, end_right).clone();
                a[i].max = end;
            }
            // Move `last_i` to its parent and refresh `last_value` if that node exists and
            // covers a larger end.
            last_i = if (last_i >> k) & 1 == 1 {
                last_i - x
            } else {
                last_i + x
            };
            if last_i < n && a[last_i].max > last_value {
                last_value = a[last_i].max.clone();
            }
            k += 1;
        }
        self.max_level = k - 1;
    }

    pub fn get(&self, index: usize) -> Option<Entry<N, D>> {
        self.entries.get(index)
            .map(|e| Entry { index, interval: &e.interval, data: &e.data })
    }

    pub fn get_mut(&mut self, index: usize) -> Option<EntryMut<N, D>> {
        self.entries.get_mut(index)
            .map(|e| EntryMut { index, interval: &e.interval, data: &mut e.data })
    }

    /// `true` if any stored interval overlaps `interval` (both bounds compared inclusively).
    /// The search stops at the first hit.
    pub fn overlaps<M: Borrow<N>, K: Into<Interval<M>>>(&self, interval: K) -> bool {
        let mut any_hit = bool::default();
        self.find_aux(interval, &mut any_hit);
        any_hit
    }

    /// The first overlapping entry reached by the traversal (lowest storage index), or `None`.
    pub fn find<M: Borrow<N>, K: Into<Interval<M>>>(&self, interval: K) -> Option<Entry<N, D>> {
        let mut hit: Option<Entry<N, D>> = Option::None;
        self.find_aux(interval, &mut hit);
        hit
    }

    /// The first entry whose interval has exactly the same start and end as `interval`,
    /// or `None`.
    pub fn find_exact<M: Borrow<N>, K: Into<Interval<M>>>(&self, interval: K) -> Option<Entry<N, D>> {
        let mut hit: Option<Entry<N, D>> = Option::None;
        self.find_exact_aux(interval, &mut hit);
        hit
    }

    /// Every entry overlapping `interval`, in storage (start-sorted) order.
    pub fn find_all<M: Borrow<N>, K: Into<Interval<M>>>(&self, interval: K) -> Vec<Entry<N, D>> {
        let mut hits: Vec<Entry<N, D>> = Vec::with_capacity(512);
        self.find_aux(interval, &mut hits);
        hits
    }

    /// Like [`IntervalTree::find`], but the returned entry's payload is mutable.
    pub fn find_mut<M: Borrow<N>, K: Into<Interval<M>>>(&mut self, interval: K) -> Option<EntryMut<N, D>> {
        let mut hit: Option<EntryMut<N, D>> = Option::None;
        self.find_mut_aux(interval, &mut hit);
        hit
    }

    /// Like [`IntervalTree::find_all`], but every returned entry's payload is mutable.
    pub fn find_all_mut<M: Borrow<N>, K: Into<Interval<M>>>(&mut self, interval: K) -> Vec<EntryMut<N, D>> {
        let mut hits: Vec<EntryMut<N, D>> = Vec::with_capacity(512);
        self.find_mut_aux(interval, &mut hits);
        hits
    }

    /// Removes every entry whose interval has exactly the same start and end as `interval`,
    /// then re-indexes the tree.
    pub fn remove_exact<M: Borrow<N>, K: Into<Interval<M>>>(&mut self, interval: K) {
        let mut indices = RemovalIndices::new();
        self.find_exact_aux(interval, &mut indices);
        self.remove_positions(indices.into_vec());
    }

    /// Removes every entry overlapping `interval`, then re-indexes the tree.
    pub fn remove_overlaps<M: Borrow<N>, K: Into<Interval<M>>>(&mut self, interval: K) {
        let mut indices = RemovalIndices::new();
        self.find_aux(interval, &mut indices);
        self.remove_positions(indices.into_vec());
    }

    /// Drops the entries at the ascending storage indices `positions` in one O(n) pass and
    /// re-indexes the tree.
    ///
    /// Calling `Vec::remove` once per position shifts the tail every time, which is O(n·k).
    /// `retain` visits each element once, in order, so the sorted positions can be matched
    /// with a single cursor.
    fn remove_positions(&mut self, positions: Vec<usize>) {
        let mut doomed = positions.into_iter().peekable();
        let mut current = 0usize;
        self.entries.retain(|_| {
            let mut remove = false;
            // `while` rather than `if`: a repeated index must not desynchronise the cursor.
            while doomed.peek() == Some(&current) {
                doomed.next();
                remove = true;
            }
            current += 1;
            !remove
        });
        self.index();
    }

    /// Exact-match query over the implicit tree laid out by `index_core`.
    ///
    /// Every entry whose start and end both equal those of `interval` is handed to `results`
    /// through `push_entry`. The walk returns early as soon as `push_entry` returns `true`.
    /// Traversal and pruning are the same as in `find_aux`: skip a left subtree whose `max`
    /// is below `start`, and stop at entries that start after `end`. Only the match test
    /// differs.
    /// NOTE(review): the stack has a fixed 64 slots, which assumes `max_level` stays well
    /// below 64 — confirm for very large trees.
    fn find_exact_aux<'b, 'a: 'b, M: Borrow<N>, I: Into<Interval<M>>, C>(
        &'a self,
        interval: I,
        results: &'b mut C,
    ) where C: EntryContainer<'a, N, D> {
        let interval = interval.into();
        let (start, end) = (interval.start().borrow(), interval.end().borrow());
        let n = self.entries.len() as usize;
        let a = &self.entries;
        // Each cell holds a node `x`, its level `k`, and `w` = "left child already handled".
        let mut stack = [StackCell::empty(); 64];
        // push the root; this is a top down traversal
        stack[0].k = self.max_level;
        stack[0].x = (1 << self.max_level) - 1;
        stack[0].w = false;
        let mut t = 1;
        while t > 0 {
            t -= 1;
            let StackCell { k, x, w } = stack[t];
            if k <= 3 {
                // we are in a small subtree; traverse every node in this subtree
                // The subtree rooted at `x` covers indices i0..i1 (clamped to n).
                let i0 = x >> k << k;
                let i1 = min(i0 + (1 << (k + 1)) - 1, n);
                for (i, node) in a.iter().enumerate().take(i1).skip(i0) {
                    // Entries are sorted by start: nothing later in this run can match.
                    if node.interval.start() > end {
                        break;
                    }
                    if start == node.interval.start() && end == node.interval.end() {
                        // exact match: append to `results` (stop if the container says so)
                        if results.push_entry(Entry {
                            index: i,
                            interval: &self.entries[i].interval,
                            data: &self.entries[i].data,
                        }) {
                            return
                        }
                    }
                }
            } else if !w {
                // if left child not processed
                let y = x - (1 << (k - 1)); // the left child of x; NB: y may be out of range (i.e. y>=n)
                stack[t].k = k;
                stack[t].x = x;
                stack[t].w = true; // re-add node x, but mark the left child having been processed
                t += 1;
                // `a[y].max` is the largest end in y's subtree; below `start` nothing there can match.
                if y >= n || a[y].max >= *start {
                    // push the left child if y is out of range or may overlap with the query
                    stack[t].k = k - 1;
                    stack[t].x = y;
                    stack[t].w = false;
                    t += 1;
                }
            } else if x < n && a[x].interval.start() <= end {
                // need to push the right child
                // Left subtree done: test node x itself, then descend into its right subtree.
                if start == a[x].interval.start() && end == a[x].interval.end() {
                    if results.push_entry(Entry {
                        index: x,
                        interval: &self.entries[x].interval,
                        data: &self.entries[x].data,
                    }) {
                        return
                    }
                }
                stack[t].k = k - 1;
                stack[t].x = x + (1 << (k - 1));
                stack[t].w = false;
                t += 1;
            }
        }
    }

    /// Core overlap query over the implicit tree laid out by `index_core`.
    ///
    /// Every entry whose interval overlaps `interval` is handed to `results` through
    /// `push_entry`. An entry overlaps when `entry.start <= end && start <= entry.end`, so both
    /// bounds are treated as inclusive. The walk returns early as soon as `push_entry` returns
    /// `true`; the `bool` and `Option` containers do this on their first hit.
    ///
    /// The walk is top-down with an explicit stack. A node's left subtree is visited before
    /// the node, and its right subtree after, so matches are reported in increasing index
    /// order. Subtrees of level <= 3 are scanned linearly instead.
    /// NOTE(review): the stack has a fixed 64 slots, which assumes `max_level` stays well
    /// below 64 — confirm for very large trees.
    fn find_aux<'b, 'a: 'b, M: Borrow<N>, I: Into<Interval<M>>, C>(
        &'a self,
        interval: I,
        results: &'b mut C,
    ) where C: EntryContainer<'a, N, D> {
        let interval = interval.into();
        let (start, end) = (interval.start().borrow(), interval.end().borrow());
        let n = self.entries.len() as usize;
        let a = &self.entries;
        // Each cell holds a node `x`, its level `k`, and `w` = "left child already handled".
        let mut stack = [StackCell::empty(); 64];
        // push the root; this is a top down traversal
        stack[0].k = self.max_level;
        stack[0].x = (1 << self.max_level) - 1;
        stack[0].w = false;
        let mut t = 1;
        while t > 0 {
            t -= 1;
            let StackCell { k, x, w } = stack[t];
            if k <= 3 {
                // we are in a small subtree; traverse every node in this subtree
                // The subtree rooted at `x` covers indices i0..i1 (clamped to n).
                let i0 = x >> k << k;
                let i1 = min(i0 + (1 << (k + 1)) - 1, n);
                for (i, node) in a.iter().enumerate().take(i1).skip(i0) {
                    // Entries are sorted by start: nothing later in this run can overlap.
                    if node.interval.start() > end {
                        break;
                    }
                    if start <= node.interval.end() {
                        // if overlap, append to `results`
                        if results.push_entry(Entry {
                            index: i,
                            interval: &self.entries[i].interval,
                            data: &self.entries[i].data,
                        }) {
                            return
                        }
                    }
                }
            } else if !w {
                // if left child not processed
                let y = x - (1 << (k - 1)); // the left child of x; NB: y may be out of range (i.e. y>=n)
                stack[t].k = k;
                stack[t].x = x;
                stack[t].w = true; // re-add node x, but mark the left child having been processed
                t += 1;
                // `a[y].max` is the largest end in y's subtree; below `start` nothing there overlaps.
                if y >= n || a[y].max >= *start {
                    // push the left child if y is out of range or may overlap with the query
                    stack[t].k = k - 1;
                    stack[t].x = y;
                    stack[t].w = false;
                    t += 1;
                }
            } else if x < n && a[x].interval.start() <= end {
                // need to push the right child
                // Left subtree done: test node x itself, then descend into its right subtree.
                // (If x starts after `end`, so does everything to its right, hence the guard.)
                if start <= a[x].interval.end() {
                    if results.push_entry(Entry {
                        index: x,
                        interval: &self.entries[x].interval,
                        data: &self.entries[x].data,
                    }) {
                        return
                    }
                }
                stack[t].k = k - 1;
                stack[t].x = x + (1 << (k - 1));
                stack[t].w = false;
                t += 1;
            }
        }
    }

    fn find_mut_aux<'b, 'a: 'b, M: Borrow<N>, I: Into<Interval<M>>, C>(
        &'a mut self,
        interval: I,
        results: &'b mut C,
    ) where C: EntryMutContainer<'a, N, D> {
        let interval = interval.into();
        let (start, end) = (interval.start().borrow(), interval.end().borrow());
        let n = self.entries.len() as usize;
        let a = &self.entries;

        let mut stack = [StackCell::empty(); 64];
        // push the root; this is a top down traversal
        stack[0].k = self.max_level;
        stack[0].x = (1 << self.max_level) - 1;
        stack[0].w = false;

        let mut t = 1;

        while t > 0 {
            t -= 1;
            let StackCell { k, x, w } = stack[t];
            if k <= 3 {
                // we are in a small subtree; traverse every node in this subtree
                let i0 = x >> k << k;
                let i1 = min(i0 + (1 << (k + 1)) - 1, n);
                for (i, node) in a.iter().enumerate().take(i1).skip(i0) {
                    if node.interval.start() > end {
                        break;
                    }
                    if start <= node.interval.end() {
                        // if overlap, append to `results`
                        if unsafe {
                            let entries = self.entries.as_ptr();
                            let entry = entries.add(i);

                            results.push_entry_mut(EntryMut {
                                index: i,
                                interval: &(*entry).interval,
                                data: &mut (*(entry as *mut InternalEntry<N, D>)).data,
                            })
                        } {
                            return;
                        }
                    }
                }
            } else if !w {
                // if left child not processed
                let y = x - (1 << (k - 1)); // the left child of x; NB: y may be out of range (i.e. y>=n)
                stack[t].k = k;
                stack[t].x = x;
                stack[t].w = true; // re-add node x, but mark the left child having been processed
                t += 1;
                if y >= n || a[y].max >= *start {
                    // push the left child if y is out of range or may overlap with the query
                    stack[t].k = k - 1;
                    stack[t].x = y;
                    stack[t].w = false;
                    t += 1;
                }
            } else if x < n && a[x].interval.start() <= end {
                // need to push the right child
                if start <= a[x].interval.end() {
                    if unsafe {
                        let entries = self.entries.as_ptr();
                        let entry = entries.add(x);

                        results.push_entry_mut(EntryMut {
                            index: x,
                            interval: &(*entry).interval,
                            data: &mut (*(entry as *mut InternalEntry<N, D>)).data,
                        })
                    } {
                        return
                    }
                }
                stack[t].k = k - 1;
                stack[t].x = x + (1 << (k - 1));
                stack[t].w = false;
                t += 1;
            }
        }
    }

    pub fn len(&self) -> usize {
        self.entries.len()
    }

    pub fn iter(&self) -> impl Iterator<Item=(&Interval<N>, &D)> {
        self.entries.iter().map(|e| (&e.interval, &e.data))
    }

    pub fn into_iter(self) -> impl Iterator<Item=(Interval<N>, D)> {
        self.entries.into_iter().map(|e| (e.interval, e.data))
    }

    pub fn values(&self) -> impl Iterator<Item=&D> {
        self.entries.iter().map(|e| &e.data)
    }

    pub fn values_mut(&mut self) -> impl Iterator<Item=&mut D> {
        self.entries.iter_mut().map(|e| &mut e.data)
    }
}

fn max3<T: Ord>(a: T, b: T, c: T) -> T {
    a.max(b.max(c))
}

/// One pending node on the fixed-size explicit stack used by the tree queries.
#[derive(Clone, Copy)]
struct StackCell {
    // index of the node in the implicit tree
    x: usize,
    // level of node `x` (0 = leaf)
    k: usize,
    // false if left child hasn't been processed
    w: bool,
}

impl StackCell {
    /// An all-zero cell, used to pre-fill the stack array.
    fn empty() -> Self {
        StackCell {
            k: 0,
            w: false,
            x: 0,
        }
    }
}
