/* Copyright (C) 2020 Dylan Staatz - All Rights Reserved. */

use std::marker::PhantomData;

use nalgebra::storage::Storage;
use nalgebra::{Const, Norm, SVector, SimdRealField, Vector};
use serde::{de::DeserializeOwned, Deserialize, Serialize};

use crate::trajectories::Trajectory;

/// A trajectory that represents a straight line between two points in euclidean space
///
/// Holds only the norm used to measure the distance between the two end points. The
/// points themselves are stored elsewhere and no cost value is cached here (`cost`
/// returns `None`; the cost is recomputed on demand by `calc_cost`).
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
#[serde(bound(
  serialize = "X: Serialize, NORM: Serialize",
  deserialize = "X: DeserializeOwned, NORM: DeserializeOwned"
))]
pub struct LinearTrajectory<X, NORM, const N: usize>
where
  X: SimdRealField,
  NORM: Norm<X>,
{
  /// Norm used by `calc_cost` to measure the distance between two points
  norm: NORM,
  /// Marker for the scalar type `X`, which is not otherwise stored in this struct
  phantom_data: PhantomData<X>,
}

impl<X, NORM, const N: usize> LinearTrajectory<X, NORM, N>
where
  X: SimdRealField,
  NORM: Norm<X>,
{
  /// Creates a linear trajectory that measures distances with the given `norm`
  pub fn new(norm: NORM) -> Self {
    // The scalar type is only tracked at the type level, so a marker is all that is needed
    let phantom_data = PhantomData;
    Self { norm, phantom_data }
  }
}

impl<X, NORM, const N: usize> Trajectory<X, N> for LinearTrajectory<X, NORM, N>
where
  X: SimdRealField,
  NORM: Norm<X> + Clone,
{
  fn cost(&self) -> Option<X> {
    None
  }

  fn calc_cost<S1, S2>(
    &self,
    start: &Vector<X, Const<N>, S1>,
    end: &Vector<X, Const<N>, S2>,
  ) -> X
  where
    S1: Storage<X, Const<N>>,
    S2: Storage<X, Const<N>>,
  {
    self.norm.metric_distance(start, end)
  }

  fn interpolate<S1, S2>(
    &self,
    start: &Vector<X, Const<N>, S1>,
    end: &Vector<X, Const<N>, S2>,
    t: X,
  ) -> SVector<X, N>
  where
    S1: Storage<X, Const<N>>,
    S2: Storage<X, Const<N>>,
  {
    start.lerp(end, t)
  }
}
