/* Copyright (C) 2020 Dylan Staatz - All Rights Reserved. */

use nalgebra::SVector;
use rand::distributions::uniform::{SampleUniform, Uniform};
use rand::distributions::Distribution;
use rand::Rng;

use crate::scalar::Scalar;

/// Samples a leader/follower pair of points inside an axis-aligned box.
///
/// Every sample is a vector of `N` coordinates: the first `D` entries are the
/// leader and the last `D` entries are the follower (the constructor checks
/// that `N == 2 * D`).
///
/// * The leader is drawn uniformly from the box, one independent draw per
///   axis from the half-open range `[mins[i], maxes[i])`.
/// * The follower is drawn by rejection sampling. An offset is drawn
///   uniformly from the cube `[-max_radius, max_radius)^D`. It is accepted only
///   when the resulting point lies strictly inside the box on every axis and
///   the offset's squared Euclidean length is strictly less than
///   `max_radius^2`. Conditional on the leader, the follower is therefore
///   uniform over the open ball of radius `max_radius` around the leader,
///   intersected with the open box.
///
/// `X` is the scalar type, `D` the dimension of the space and `N` the length
/// of one sampled vector.
pub struct LeaderFollowerCoordinates<
  X: SampleUniform,
  const D: usize,
  const N: usize,
> {
  /// Lower corner of the bounding box (one bound per axis).
  mins: SVector<X, D>,
  /// Upper corner of the bounding box (one bound per axis).
  maxes: SVector<X, D>,
  /// One uniform sampler per axis for the leader's coordinates.
  leader_samplers: Vec<Uniform<X>>,
  /// One uniform sampler per axis for the follower's offset from the leader.
  follower_samplers: Vec<Uniform<X>>,
  /// Square of the maximum leader-to-follower distance. It is compared
  /// against `norm_squared` so that no square root is needed per draw.
  max_radius_squared: X,
}

impl<X: Scalar + SampleUniform, const D: usize, const N: usize>
  LeaderFollowerCoordinates<X, D, N>
{
  pub fn new(mins: SVector<X, D>, maxes: SVector<X, D>, max_radius: X) -> Self {
    assert_eq!(D * 2, N);

    let mut leader_samplers = Vec::with_capacity(D);
    for (a, b) in mins.iter().zip(maxes.iter()) {
      assert!(a < b);
      leader_samplers.push(Uniform::new(a, b));
    }

    let mut follower_samplers = Vec::with_capacity(D);
    for _ in 0..D {
      follower_samplers.push(Uniform::new(-max_radius, max_radius));
    }

    Self {
      mins,
      maxes,
      leader_samplers,
      follower_samplers,
      max_radius_squared: max_radius * max_radius,
    }
  }
}

impl<X: Scalar + SampleUniform, const D: usize, const N: usize>
  Distribution<SVector<X, N>> for LeaderFollowerCoordinates<X, D, N>
{
  /// Draws one vector laid out as `[leader..., follower...]`.
  ///
  /// The leader's coordinates are drawn once, axis by axis. The follower is
  /// found by redrawing an offset from the leader until two conditions hold:
  /// the offset point lies strictly inside the bounding box, and the offset's
  /// squared length is strictly below the cached squared radius. There is no
  /// cap on the number of attempts.
  fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> SVector<X, N> {
    let leader = SVector::<X, D>::from_iterator(
      self.leader_samplers.iter().map(|axis| axis.sample(rng)),
    );

    let follower = loop {
      let offset = SVector::<X, D>::from_iterator(
        self.follower_samplers.iter().map(|axis| axis.sample(rng)),
      );
      let candidate = &leader + &offset;

      // Every coordinate must lie in the open interval (min, max).
      let in_box = candidate
        .iter()
        .zip(self.mins.iter().zip(self.maxes.iter()))
        .all(|(v, (lo, hi))| lo < v && v < hi);

      // Accept only on a true `<`, so an incomparable (NaN) norm is rejected.
      if in_box && offset.norm_squared() < self.max_radius_squared {
        break candidate;
      }
    };

    log::debug!("leader: {:?}, follower: {:?}", leader, follower);
    let coords = leader.iter().chain(follower.iter()).cloned();
    SVector::<X, N>::from_iterator(coords)
  }
}
