/* Copyright (C) 2020 Dylan Staatz - All Rights Reserved. */

use nalgebra::constraint::{
  SameNumberOfColumns, SameNumberOfRows, ShapeConstraint,
};
use nalgebra::storage::Storage;
use nalgebra::{
  Dim, EuclideanNorm, LpNorm, Matrix, Norm, SimdComplexField, SimdRealField,
  UniformNorm, Vector,
};
use num_traits::Zero;
use serde::{Deserialize, Serialize};

/// Squared Euclidean norm: the sum of the squared moduli of the entries,
/// with no final square root taken.
#[derive(Debug, Clone, Copy)]
pub struct EuclideanNormSquared;

/// [`Norm`] implementation reporting squared Euclidean quantities.
///
/// Because the square root is skipped, results preserve the ordering of the
/// true Euclidean norm while avoiding the `sqrt`.
impl<X: SimdComplexField> Norm<X> for EuclideanNormSquared {
  /// Returns the sum of the squared moduli of every entry of `m`.
  #[inline]
  fn norm<R, C, S>(&self, m: &Matrix<X, R, C, S>) -> X::SimdRealField
  where
    R: Dim,
    C: Dim,
    S: Storage<X, R, C>,
  {
    m.norm_squared()
  }

  /// Returns the sum over all entries of `|m1[i] - m2[i]|^2`.
  ///
  /// Both matrices are walked in lockstep with `zip_fold`, so no temporary
  /// difference matrix is built.
  #[inline]
  fn metric_distance<R1, C1, S1, R2, C2, S2>(
    &self,
    m1: &Matrix<X, R1, C1, S1>,
    m2: &Matrix<X, R2, C2, S2>,
  ) -> X::SimdRealField
  where
    R1: Dim,
    C1: Dim,
    S1: Storage<X, R1, C1>,
    R2: Dim,
    C2: Dim,
    S2: Storage<X, R2, C2>,
    ShapeConstraint: SameNumberOfRows<R1, R2> + SameNumberOfColumns<C1, C2>,
  {
    let start = X::SimdRealField::zero();
    m1.zip_fold(m2, start, |total, lhs, rhs| {
      total + (lhs - rhs).simd_modulus_squared()
    })
  }
}

/// Selectable norm that implements [`nalgebra::Norm`] by forwarding to one of
/// nalgebra's norm types.
///
/// Defaults to [`NormCost::TwoNorm`].
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum NormCost {
  /// L1 norm (sum of entry moduli); forwards to nalgebra's `LpNorm(1)`.
  OneNorm,
  /// L2 / Euclidean norm; forwards to nalgebra's `EuclideanNorm`.
  TwoNorm,
  /// General Lp norm with exponent `p`; forwards to nalgebra's `LpNorm(p)`.
  ///
  /// NOTE(review): only `p >= 1` yields a true norm. Zero or negative
  /// exponents are accepted here without validation and should not be relied
  /// on for meaningful distances.
  LpNorm(i8),
  /// L-infinity / uniform norm (largest entry modulus); forwards to
  /// nalgebra's `UniformNorm`.
  InfNorm,
}

impl NormCost {
  /// Computes the distance between vectors `a` and `b` under the selected norm.
  ///
  /// This is a convenience wrapper over the [`Norm`] trait's `metric_distance`
  /// for real scalar types, where the resulting real field is `X` itself.
  pub fn cost<X, R1, S1, R2, S2>(
    &self,
    a: &Vector<X, R1, S1>,
    b: &Vector<X, R2, S2>,
  ) -> X
  where
    X: SimdRealField,
    R1: Dim,
    R2: Dim,
    S1: Storage<X, R1>,
    S2: Storage<X, R2>,
    ShapeConstraint: SameNumberOfRows<R1, R2>,
  {
    Self::metric_distance(self, a, b)
  }
}

impl Default for NormCost {
  fn default() -> Self {
    Self::TwoNorm
  }
}

/// Dispatches each [`Norm`] operation to the nalgebra norm selected by the
/// variant.
impl<X: SimdComplexField> Norm<X> for NormCost {
  /// Computes the norm of `m` using the norm selected by `self`.
  #[inline]
  fn norm<R, C, S>(&self, m: &Matrix<X, R, C, S>) -> X::SimdRealField
  where
    R: Dim,
    C: Dim,
    S: Storage<X, R, C>,
  {
    match self {
      Self::OneNorm => LpNorm(1).norm(m),
      Self::TwoNorm => EuclideanNorm.norm(m),
      // Lossless widening: nalgebra's `LpNorm` stores its exponent as `i32`.
      Self::LpNorm(p) => LpNorm(i32::from(*p)).norm(m),
      Self::InfNorm => UniformNorm.norm(m),
    }
  }

  /// Computes the distance between `m1` and `m2` (the norm of their
  /// difference) using the norm selected by `self`.
  #[inline]
  fn metric_distance<R1, C1, S1, R2, C2, S2>(
    &self,
    m1: &Matrix<X, R1, C1, S1>,
    m2: &Matrix<X, R2, C2, S2>,
  ) -> X::SimdRealField
  where
    R1: Dim,
    C1: Dim,
    S1: Storage<X, R1, C1>,
    R2: Dim,
    C2: Dim,
    S2: Storage<X, R2, C2>,
    ShapeConstraint: SameNumberOfRows<R1, R2> + SameNumberOfColumns<C1, C2>,
  {
    match self {
      Self::OneNorm => LpNorm(1).metric_distance(m1, m2),
      Self::TwoNorm => EuclideanNorm.metric_distance(m1, m2),
      // Lossless widening: nalgebra's `LpNorm` stores its exponent as `i32`.
      Self::LpNorm(p) => LpNorm(i32::from(*p)).metric_distance(m1, m2),
      Self::InfNorm => UniformNorm.metric_distance(m1, m2),
    }
  }
}
