#![allow(non_snake_case)]

//! Julier-Uhlmann duplex 'unscented' state estimation.
//!
//! A discrete Bayesian estimator that uses the [`KalmanState`] linear representation of the system.
//!
//! The 'unscented' transform is used for non-linear state predictions and observations.
//! The transform evaluates the non-linear predict and observe functions at duplex 'sigma' points about the mean to provide an
//! estimate of the distribution. The transforms can be optimised for particular functions by varying the kappa parameter from its usual value of 1.

use nalgebra::{allocator::Allocator, DefaultAllocator, Dim, OVector, RealField, U1};

use crate::linalg::rcond;
use crate::matrix::quadform_tr_x;
use crate::models::{KalmanState, FunctionalPredictor, FunctionalObserver};
use crate::noise::CorrelatedNoise;
use num_traits::FromPrimitive;


/// State of a duplex 'unscented' estimator.
///
/// Pairs a [`KalmanState`] (mean `x`, covariance `X`) with the `kappa` parameter that
/// controls the spread and weighting of the duplex sigma points.
pub struct UnscentedDuplexState<N: RealField, D: Dim>
where
    DefaultAllocator: Allocator<N, D> + Allocator<N, D, D>,
{
    /// Mean and covariance of the state estimate.
    pub kalman: KalmanState<N, D>,
    /// Unscented transform parameter: sigma points are spread by `sqrt(n + kappa)`,
    /// where `n` is the state dimension.
    pub kappa: N,
}

impl<N, D> FunctionalPredictor<N, D> for UnscentedDuplexState<N, D>
where
    N: Copy + FromPrimitive + RealField,
    D: Dim,
    DefaultAllocator: Allocator<N, D, D> + Allocator<N, D> + Allocator<N, U1, D>,
{
    /// Unscented state prediction with a functional prediction model and additive correlated noise.
    ///
    /// Every duplex sigma point of the current state is mapped through `f`. The mean and
    /// covariance recovered from the mapped points replace the state, and the noise
    /// covariance `noise.Q` is then added to `X`.
    ///
    /// # Errors
    /// Propagates the error from `to_unscented_duplex` when the sigma points cannot be built.
    fn predict(
        &mut self,
        f: impl Fn(&OVector<N, D>) -> OVector<N, D>,
        noise: &CorrelatedNoise<N, D>,
    ) -> Result<(), &'static str> {
        // Sigma points are spread by sqrt(n + kappa), n being the state dimension.
        let state_dim = N::from_usize(self.kalman.x.nrows()).unwrap();
        let (mut sigma_points, _) = self.kalman.to_unscented_duplex(state_dim + self.kappa)?;

        // Push each point through the prediction model, overwriting it in place.
        for point in sigma_points.iter_mut() {
            *point = f(point);
        }

        // Mean and covariance of the propagated points, plus additive process noise.
        self.kalman.from_unscented_duplex(&sigma_points, self.kappa);
        self.kalman.X += &noise.Q;

        Ok(())
    }
}

impl<N: Copy + FromPrimitive + RealField, D: Dim, ZD: Dim> FunctionalObserver<N, D, ZD> for UnscentedDuplexState<N, D>
    where
        DefaultAllocator: Allocator<N, D, D> + Allocator<N, D> + Allocator<N, ZD>
        + Allocator<N, U1, ZD> + Allocator<N, ZD, D>  + Allocator<N, D, ZD> + Allocator<N, ZD, ZD> + Allocator<N, U1, D>
{
    /// Unscented state observation with a functional observation model and additive correlated noise.
    ///
    /// The duplex sigma points of the current state are mapped through `h`. The unscented
    /// mean `zp` of the mapped points is the predicted observation. Their covariance plus
    /// `noise.Q` is the innovation covariance `S`. With the state/observation
    /// cross-covariance `XZ`, the gain is `W = XZ * S^-1` and the state is updated as
    /// `x += W (z - zp)` and `X -= W S W'`.
    ///
    /// # Errors
    /// Propagates the error from `to_unscented_duplex`, and returns an error if `S` is not
    /// positive definite.
    fn observe(
        &mut self,
        z: &OVector<N, ZD>,
        h: impl Fn(&OVector<N, D>) -> OVector<N, ZD>,
        noise: &CorrelatedNoise<N, ZD>)
        -> Result<(), &'static str>
    {
        let two = N::from_u32(2).unwrap();

        // Create Unscented distribution
        let x_kappa = N::from_usize(self.kalman.x.nrows()).unwrap() + self.kappa;
        let (UU, _rcond) = self.kalman.to_unscented_duplex(x_kappa)?;

        // Predict points of ZZ using supplied observation model
        let mut ZZ: Vec<OVector<N, ZD>> = UU.iter().map(&h).collect();

        // Mean (zZ.x) and covariance (zZ.X) of observation distribution
        let mut zZ = KalmanState::<N, ZD>::new_zero(z.shape_generic().0);
        zZ.from_unscented_duplex(&ZZ, self.kappa);

        // Innovation about the unscented observation mean. XZ and S are both centred on
        // zZ.x, so the innovation must be too; h(x) differs from zZ.x for non-linear h.
        let s = z - &zZ.x;

        // Centre the observation points about their mean
        for zi in ZZ.iter_mut() {
            *zi -= &zZ.x;
        }

        // Correlation of state with observation: XZ
        // Center point, premult here by 2 for efficiency
        let x = &self.kalman.x;
        let mut XZ = (&UU[0] - x) * ZZ[0].transpose() * two * self.kappa;
        // Remaining Unscented points
        for (u, zi) in UU.iter().zip(&ZZ).skip(1) {
            XZ += &(u - x) * zi.transpose();
        }
        XZ /= two * x_kappa;

        // Innovation covariance
        let S = zZ.X + &noise.Q;

        // Inverse innovation covariance
        let SI = S.clone().cholesky().ok_or("S not PD in observe")?.inverse();

        // Kalman gain, X*Hx'*SI
        let W = &XZ * SI;

        // State update
        self.kalman.x += &W * s;
        // X -= W.S.W'
        self.kalman.X.quadform_tr(-N::one(), &W, &S, N::one());

        Ok(())
    }
}

impl<N: Copy + FromPrimitive + RealField, D: Dim> KalmanState<N, D>
    where
        DefaultAllocator: Allocator<N, D, D> + Allocator<N, D> + Allocator<N, U1, D>
{
    /// Calculates the Kalman State from the 'UU' unscented duplex sigma points.
    ///
    /// `UU` is expected in the layout produced by `to_unscented_duplex`: the centre point
    /// first, followed by `n` `+`/`-` pairs (`2n + 1` points in total). The mean weights the
    /// centre point by `kappa / (n + kappa)` and every other point by `1 / (2 (n + kappa))`.
    /// The covariance `X` uses the same weights.
    ///
    /// # Panics
    /// Panics if `UU` is empty. In debug builds it also panics if `UU` does not hold an odd
    /// number of points.
    pub fn from_unscented_duplex(
        &mut self,
        UU: &[OVector<N, D>], kappa: N)
    {
        let (centre, rest) = UU
            .split_first()
            .expect("from_unscented_duplex requires at least the centre sigma point");
        // Duplex points come in +/- pairs about the centre.
        debug_assert_eq!(rest.len() % 2, 0, "from_unscented_duplex expects 2n+1 sigma points");

        let two = N::from_u32(2).unwrap();
        let x_scale = N::from_usize(rest.len() / 2).unwrap() + kappa;

        // Mean of predicted distribution: x
        self.x = centre * two * kappa;
        for u in rest {
            self.x += u;
        }
        self.x /= two * x_scale;

        // Covariance of distribution: X
        // Center point, premult here by 2 for efficiency
        quadform_tr_x(&mut self.X, two * kappa, &(centre - &self.x), N::zero());
        // Remaining Unscented points
        for u in rest {
            quadform_tr_x(&mut self.X, N::one(), &(u - &self.x), N::one());
        }
        self.X /= two * x_scale;
    }

    /// Calculates the unscented duplex sigma points from a Kalman State.
    ///
    /// Returns `2n + 1` points together with `rcond::rcond_symetric(&X)`. The first point is
    /// the mean `x`. It is followed by `x + s_c` and `x - s_c` for each column `s_c` of
    /// `L * sqrt(scale)`, where `L` is the lower Cholesky factor of `X`.
    ///
    /// # Errors
    /// Returns an error if `scale` is not positive, because `sqrt(scale)` would be NaN or zero.
    /// Also returns an error if the Cholesky factorisation of `X` fails, i.e. `X` is not
    /// positive definite.
    pub fn to_unscented_duplex(
        &self,
        scale: N)
        -> Result<(Vec<OVector<N, D>>, N), &'static str>
    {
        if scale <= N::zero() {
            return Err("to_unscented_duplex, scale not positive");
        }
        let sigma = self.X.clone().cholesky().ok_or("to_unscented_duplex, X not PSD")?.l() * scale.sqrt();

        // Generate UU with the same sample Mean and Covariance
        let n = self.x.nrows();
        let mut UU: Vec<OVector<N, D>> = Vec::with_capacity(2 * n + 1);
        UU.push(self.x.clone());

        for c in 0..n {
            let sigma_col = sigma.column(c);
            UU.push(&self.x + &sigma_col);
            UU.push(&self.x - &sigma_col);
        }

        Ok((UU, rcond::rcond_symetric(&self.X)))
    }
}

