//! Kademlia node IDs, the XOR distance between them, and prefix-aligned ID ranges.

use crate::serde::serde_byte_array;
use rand::{
    distributions::{
        uniform::{SampleBorrow, SampleRange, SampleUniform, UniformSampler},
        Distribution, Standard,
    },
    Rng, RngCore,
};
use serde::{Deserialize, Serialize};
use std::{
    borrow::Borrow,
    cmp::Ordering,
    fmt::{self, Debug, Display},
    ops::{Bound, RangeBounds},
};

/// A torrent infohash.
///
/// Infohashes live in the same 20-byte space as DHT node IDs, so this is just
/// an alias for [`Id`].
pub type InfoHash = Id;

/// A 160-bit identifier naming either a DHT node or an infohash.
///
/// Serialized transparently (no wrapper) through the crate's
/// `serde_byte_array` helper. The derived `Ord` compares the bytes
/// lexicographically, which is the same as comparing the IDs as big-endian
/// unsigned integers.
#[derive(Clone, Copy, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[serde(transparent)]
pub struct Id(#[serde(with = "serde_byte_array")] [u8; Id::LEN]);

impl Id {
    fn with_bits<R>(&self, range: R, state: bool) -> Id
    where
        R: Iterator<Item = usize> + Debug,
    {
        let mut new = *self;
        for bit in range {
            let byte = bit / 8;
            let bit = (7 - bit % 8) % 8;
            let mask = 1 << bit;
            if state {
                new.0[byte] |= mask;
            } else {
                new.0[byte] &= !mask;
            }
        }
        new
    }

    /// ID with all zeros
    pub const ZERO: Id = Id([0x00u8; Id::LEN]);
    /// ID with all ones
    pub const ONES: Id = Id([0xFFu8; Id::LEN]);
    /// The length of an ID
    pub const LEN: usize = 20;

    const BITS: usize = 160;
}

impl From<[u8; Id::LEN]> for Id {
    fn from(bytes: [u8; Id::LEN]) -> Self {
        Self(bytes)
    }
}

/// Borrows the ID's underlying bytes as a slice.
impl AsRef<[u8]> for Id {
    fn as_ref(&self) -> &[u8] {
        &self.0[..]
    }
}

impl Debug for Id {
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        write!(fmt, "{}", self)
    }
}

impl Display for Id {
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        write!(fmt, "{}", hex::encode(self.0))
    }
}

/// Draws an ID with every byte filled from `rng`, so `rng.gen::<Id>()` works.
impl Distribution<Id> for Standard {
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Id {
        let mut id = Id::ZERO;
        rng.fill(&mut id.0);
        id
    }
}

#[doc(hidden)]
pub struct UniformId {
    low: Id,
    high: Bound<Id>,
}

impl UniformSampler for UniformId {
    type X = Id;

    fn new<B1, B2>(low: B1, high: B2) -> Self
    where
        B1: SampleBorrow<Self::X> + Sized,
        B2: SampleBorrow<Self::X> + Sized,
    {
        Self {
            low: *low.borrow(),
            high: Bound::Excluded(*high.borrow()),
        }
    }

    fn new_inclusive<B1, B2>(low: B1, high: B2) -> Self
    where
        B1: SampleBorrow<Self::X> + Sized,
        B2: SampleBorrow<Self::X> + Sized,
    {
        Self {
            low: *low.borrow(),
            high: Bound::Included(*high.borrow()),
        }
    }

    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X {
        loop {
            let id: Id = rng.gen();
            let low_cmp = id >= self.low;
            let high_cmp = match self.high {
                Bound::Included(high) => id <= high,
                Bound::Excluded(high) => id < high,
                Bound::Unbounded => true,
            };
            if low_cmp && high_cmp {
                break id;
            }
        }
    }
}

impl SampleUniform for Id {
    type Sampler = UniformId;
}

/// Distribution over IDs that share a prefix with a given ID.
///
/// A sample keeps the leading `Id::BITS - spread_bits` bits of `from` and takes
/// the trailing `spread_bits` bits from the RNG. Every sample is therefore
/// within XOR distance below `2^spread_bits` of `from`.
pub struct Neighborhood {
    /// ID whose prefix every sample shares.
    from: Id,
    /// Number of trailing (least significant) bits to randomize. Values above
    /// `Id::BITS` are treated as `Id::BITS`, meaning "any ID".
    spread_bits: usize,
}

impl Neighborhood {
    /// Creates a neighborhood of `from` that randomizes its low `spread_bits`
    /// bits.
    pub fn new(from: Id, spread_bits: usize) -> Self {
        Self { from, spread_bits }
    }
}

impl Distribution<Id> for Neighborhood {
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Id {
        // A spread wider than the ID itself just means "the whole ID space".
        let spread = self.spread_bits.min(Id::BITS);
        let partial_bits = spread % 8;
        // Index of the first byte that receives any random bits. That is the
        // `spread / 8` fully random trailing bytes, plus one leading byte when
        // the spread is not byte-aligned.
        let start = Id::LEN - (spread + 7) / 8;

        let mut buf = self.from.0;
        rng.fill(&mut buf[start..]);
        if partial_bits != 0 {
            // Only the low `partial_bits` bits of the leading byte are random.
            // Its high bits belong to the shared prefix and come from `from`.
            let random_mask = (1u8 << partial_bits) - 1;
            buf[start] = (buf[start] & random_mask) | (self.from.0[start] & !random_mask);
        }

        Id(buf)
    }
}

/// XOR distance between two IDs.
///
/// The derived `Ord` compares the bytes lexicographically, so distances order
/// as big-endian unsigned integers (smaller means closer).
#[derive(Clone, Copy, Eq, PartialEq, Ord, PartialOrd)]
pub struct Distance([u8; Id::LEN]);

impl Distance {
    /// Computes `a XOR b`, byte by byte.
    pub fn between(a: Id, b: Id) -> Self {
        let mut xor = a.0;
        for (lhs, rhs) in xor.iter_mut().zip(b.0.iter()) {
            *lhs ^= *rhs;
        }
        Distance(xor)
    }
}

/// Formats the distance as lowercase hex digits, two per byte.
impl Debug for Distance {
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        self.0
            .iter()
            .try_for_each(|byte| write!(fmt, "{:02x}", byte))
    }
}

/// A contiguous, prefix-aligned range of IDs, `first..=last`.
///
/// As built by `full` and `split`, the leading `used_bits` bits are shared by
/// every ID in the range. `first` has all remaining bits cleared and `last` has
/// them all set.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct IdRange {
    /// Lowest ID in the range.
    first: Id,
    /// Highest ID in the range (inclusive end bound kept for convenience).
    last: Id,
    /// Number of leading bits fixed by this range.
    used_bits: usize,
}

impl IdRange {
    /// The range covering the entire ID space, `Id::ZERO..=Id::ONES`.
    pub fn full() -> Self {
        IdRange {
            first: Id::ZERO,
            last: Id::ONES,
            used_bits: 0,
        }
    }

    /// Number of trailing bits that still vary within this range.
    pub fn usable_bits(&self) -> usize {
        Id::BITS - self.used_bits
    }

    /// Lowest ID in the range.
    pub fn first(&self) -> Id {
        self.first
    }

    /// Highest ID in the range (inclusive).
    pub fn last(&self) -> Id {
        self.last
    }

    /// Splits the range into its lower and upper halves by fixing one more
    /// prefix bit (0 for the lower half, 1 for the upper half).
    ///
    /// # Panics
    ///
    /// Panics if every bit is already fixed (`used_bits == Id::BITS`).
    pub fn split(&self) -> (IdRange, IdRange) {
        assert!(self.used_bits < Id::BITS);
        let depth = self.used_bits + 1;

        // Lower half: same prefix, split bit 0, every later bit set at the top.
        let lower = IdRange {
            first: self.first,
            last: self.first.with_bits(depth..Id::BITS, true),
            used_bits: depth,
        };
        // Upper half: same prefix with the split bit set, up to the parent's end.
        let upper = IdRange {
            first: self.first.with_bits(self.used_bits..depth, true),
            last: self.last,
            used_bits: depth,
        };

        (lower, upper)
    }
}

/// Defaults to the full ID space.
impl Default for IdRange {
    fn default() -> Self {
        IdRange::full()
    }
}

/// Partial order identical to the total order in `Ord` (by `first`).
impl PartialOrd for IdRange {
    fn partial_cmp(&self, rhs: &IdRange) -> Option<Ordering> {
        // Canonical form: defer to `Ord` so the two impls can never disagree.
        Some(self.cmp(rhs))
    }
}

/// Orders ranges by `first` alone, ignoring `last` and `used_bits`. This keeps
/// it consistent with `Borrow<Id>` below.
///
/// NOTE(review): `PartialEq`/`Eq` on `IdRange` are derived over all fields.
/// Two different ranges that share a `first` (e.g. a range and its lower
/// half from `split`) therefore compare `Equal` here yet are `!=`. Confirm
/// that such ranges are never stored side by side in an ordered collection.
impl Ord for IdRange {
    fn cmp(&self, rhs: &IdRange) -> Ordering {
        self.first.cmp(&rhs.first)
    }
}

/// Borrows a range as its first ID, so ordered collections keyed by `IdRange`
/// can be queried with a plain `Id`. This relies on `Ord` comparing `first`
/// only.
impl Borrow<Id> for IdRange {
    fn borrow(&self) -> &Id {
        &self.first
    }
}

/// Exposes the range as `first..=last`; both endpoints belong to it.
impl RangeBounds<Id> for IdRange {
    fn start_bound(&self) -> Bound<&Id> {
        Bound::Included(&self.first)
    }

    fn end_bound(&self) -> Bound<&Id> {
        // `last` is the inclusive end (e.g. `Id::ONES` for the full range).
        // Excluding it would make `range.contains(&range.last())` false.
        Bound::Included(&self.last)
    }
}

/// Lets `rng.gen_range(range)` draw an ID from the range. It keeps the
/// range's fixed prefix and randomizes its `usable_bits()` trailing bits.
impl SampleRange<Id> for IdRange {
    fn sample_single<R: RngCore + ?Sized>(self, rng: &mut R) -> Id {
        if self.used_bits == Id::BITS {
            // Every bit is fixed: the range holds exactly one ID.
            return self.first;
        }
        rng.sample(Neighborhood::new(self.first, self.usable_bits()))
    }

    fn is_empty(&self) -> bool {
        // `first <= last` holds by construction (`full`/`split`), so a range
        // always contains at least one ID. It holds exactly one when all bits
        // are fixed, and that one ID is still a valid sample.
        false
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use hex_literal::hex;

    /// IDs sort byte-lexicographically (as big-endian integers): the first
    /// differing byte decides the order.
    #[test]
    fn ids_sort_correctly() {
        let id1 = Id::from(hex!("0123456789 0123456789 0123456789 0123456789"));
        let id2 = Id::from(hex!("0123456789 0123456789 0123456789 012345678A"));
        let id3 = Id::from(hex!("0123456789 0123456789 0987654321 0123456789"));
        let id4 = Id::from(hex!("0123456789 0000000000 0123456789 0123456789"));
        let id5 = Id::from(hex!("1111111111 0123456789 0123456789 0123456789"));
        let id6 = Id::from(hex!("0123456789 2222222222 0123456789 0123456789"));
        let id7 = Id::from(hex!("0123456789 0123456789 4444444444 0123456789"));
        let id8 = Id::from(hex!("0123456789 0123456789 0123456789 8888888888"));

        // All IDs are distinct, so an unstable sort yields a unique order.
        let mut sorted = [id8, id7, id6, id5, id4, id3, id2, id1];
        sorted.sort_unstable();

        assert_eq!(sorted, [id4, id1, id2, id8, id3, id7, id6, id5]);
    }

    /// Distance is the byte-wise XOR of the two IDs.
    #[test]
    fn distances_are_computed_correctly() {
        let origin = Id::from(hex!("0123456789 0123456789 0123456789 0123456789"));
        let others = [
            Id::from(hex!("0123456789 0123456789 0123456789 0123456789")),
            Id::from(hex!("0123456789 0123456789 0123456789 012345678A")),
            Id::from(hex!("0123456789 0123456789 0987654321 0123456789")),
            Id::from(hex!("0123456789 0000000000 0123456789 0123456789")),
            Id::from(hex!("1111111111 0123456789 0123456789 0123456789")),
            Id::from(hex!("0123456789 2222222222 0123456789 0123456789")),
            Id::from(hex!("0123456789 0123456789 4444444444 0123456789")),
            Id::from(hex!("0123456789 0123456789 0123456789 8888888888")),
        ];

        let distances: Vec<Distance> = others
            .iter()
            .map(|&other| Distance::between(origin, other))
            .collect();

        assert_eq!(
            distances,
            [
                Distance(hex!("0000000000 0000000000 0000000000 0000000000")),
                Distance(hex!("0000000000 0000000000 0000000000 0000000003")),
                Distance(hex!("0000000000 0000000000 08A42024A8 0000000000")),
                Distance(hex!("0000000000 0123456789 0000000000 0000000000")),
                Distance(hex!("1032547698 0000000000 0000000000 0000000000")),
                Distance(hex!("0000000000 23016745AB 0000000000 0000000000")),
                Distance(hex!("0000000000 0000000000 45670123CD 0000000000")),
                Distance(hex!("0000000000 0000000000 0000000000 89ABCDEF01")),
            ]
        );
    }

    /// Distances sort as big-endian integers, smallest (closest) first.
    #[test]
    fn distances_sort_correctly() {
        let d1 = Distance(hex!("0000000000 0000000000 0000000000 0000000000"));
        let d2 = Distance(hex!("0000000000 0000000000 0000000000 0000000003"));
        let d3 = Distance(hex!("0000000000 0000000000 08A42024A8 0000000000"));
        let d4 = Distance(hex!("0000000000 0123456789 0000000000 0000000000"));
        let d5 = Distance(hex!("1032547698 0000000000 0000000000 0000000000"));
        let d6 = Distance(hex!("0000000000 23016745AB 0000000000 0000000000"));
        let d7 = Distance(hex!("0000000000 0000000000 45670123CD 0000000000"));
        let d8 = Distance(hex!("0000000000 0000000000 0000000000 89ABCDEF01"));

        // All distances are distinct, so an unstable sort yields a unique order.
        let mut sorted = [d5, d3, d8, d1, d6, d2, d7, d4];
        sorted.sort_unstable();

        assert_eq!(sorted, [d1, d2, d8, d3, d7, d4, d6, d5]);
    }
}
