use hashbrown::HashMap;
use nalgebra::{DMatrix, DVector, OMatrix};
use rand::{Rng, thread_rng};
use rayon::prelude::*;
use crate::errors::ALSError;

/// Number of training iterations a freshly constructed learner runs by default.
const DEFAULT_ITERATIONS: usize = 10;
/// Epsilon handed to the pseudo-inverse fallback used during training.
const DEFAULT_EPS: f64 = 1.0e-9;
/// Regularization factor a freshly constructed learner starts with.
const DEFAULT_REG: f64 = 1.0;
/// Scalar type the concrete `ALS` implementation in this file is written against.
type T = f64;
/// One entry of R, given as `(row index, column index, value)`.
pub type RTriplet<T> = (usize, usize, T);

/// Alternating-least-squares learner that factorizes a sparse `n x m` matrix R into
/// `k`-dimensional row factors X and column factors Y.
///
/// Note that the struct is generic over its scalar type, while the inherent methods in
/// this file are implemented for the file-local alias `T = f64` only.
pub struct ALS<T> {
    /// Number of rows of R.
    n: usize,
    /// Number of columns of R.
    m: usize,
    /// Number of features per row/column factor vector.
    k: usize,
    /// Entries of R indexed as `row -> (column -> value)`.
    r_row_first: HashMap<usize, HashMap<usize, T>>,
    /// The same entries of R indexed as `column -> (row -> value)`.
    r_col_first: HashMap<usize, HashMap<usize, T>>,
    /// Row factors: one vector of length `k` per row of R.
    x_mat: Vec<DVector<T>>,
    /// Column factors: one vector of length `k` per column of R.
    y_mat: Vec<DVector<T>>,
    /// Iteration count used by `train`.
    default_iters: usize,
    /// Regularization factor used by training and by the cost function.
    default_regularization: T,
}

impl ALS<T> {

    /// Constructs a new ALS learner for an initially empty sparse matrix R of size N x M using
    /// K features for X and Y.
    pub fn new(n : usize, m : usize, k : usize) -> Self {
        let mut als =
        ALS {
            n,
            m,
            k,
            r_row_first : HashMap::new(),
            r_col_first : HashMap::new(),
            x_mat : vec![],
            y_mat : vec![],
            default_iters : DEFAULT_ITERATIONS,
            default_regularization: DEFAULT_REG,
        };
        als.init_y();
        als.init_x();
        als
    }

    /// Adds a value to the sparse matrix R. Will overwrite a previous value if indices coincide.
    ///
    /// The triplet is `(row, column, value)`. The value is recorded in both the row-first
    /// and the column-first index so that either can be traversed during training.
    ///
    /// # Returns
    ///
    /// `Ok(Some(old))` if an entry at `(row, column)` already existed and was replaced,
    /// `Ok(None)` if the position was previously empty.
    ///
    /// # Errors
    ///
    /// Returns `ALSError::InvalidTripletError` if the row index is not below `n` or the
    /// column index is not below `m`; R is left unchanged in that case.
    pub fn add(&mut self, e : RTriplet<T>) -> Result<Option<T>, ALSError<T>> {
        let (row, col, value) = e;
        if row >= self.n {
            return Err(ALSError::InvalidTripletError(e, format!("{} exceeds row index range for R = {}x{}", row, self.n, self.m)));
        }
        if col >= self.m {
            return Err(ALSError::InvalidTripletError(e, format!("{} exceeds column index range of R = {}x{}", col, self.n, self.m)));
        }

        // `or_default` creates the inner map lazily, only when the outer key is missing.
        // Passing a block to `or_insert` instead would evaluate it on every call and
        // clobber the previous value that is supposed to be reported back.
        let previous_entry_val = self.r_row_first
            .entry(row)
            .or_default()
            .insert(col, value);

        self.r_col_first
            .entry(col)
            .or_default()
            .insert(row, value);

        Ok(previous_entry_val)
    }

    /// Overwrites every entry of X with a fresh sample drawn uniformly from the half-open
    /// range `[0, 1 / sqrt(K))`.
    pub fn reset_x(&mut self) {
        let bound: T = (self.k as T).sqrt().recip();
        self.x_mat
            .par_iter_mut()
            .for_each(|factors| factors.fill_with(|| thread_rng().gen_range(0.0..bound)));
    }

    /// Overwrites every entry of Y with a fresh sample drawn uniformly from the half-open
    /// range `[0, 1 / sqrt(K))`.
    pub fn reset_y(&mut self) {
        let bound: T = (self.k as T).sqrt().recip();
        self.y_mat
            .par_iter_mut()
            .for_each(|factors| factors.fill_with(|| thread_rng().gen_range(0.0..bound)));
    }

    /// Rebuilds X from scratch as `n` vectors of length `k`, each entry sampled uniformly
    /// from `[0, 1 / sqrt(k))`.
    fn init_x(&mut self) {
        let (rows, features) = (self.n, self.k);
        let bound: T = (features as T).sqrt().recip();
        self.x_mat = (0..rows)
            .into_par_iter()
            .map(|_| {
                DVector::<T>::from_fn(features, |_, _| thread_rng().gen_range(0.0..bound))
            })
            .collect();
    }

    /// Rebuilds Y from scratch as `m` vectors of length `k`, each entry sampled uniformly
    /// from `[0, 1 / sqrt(k))`.
    fn init_y(&mut self) {
        let (cols, features) = (self.m, self.k);
        let bound: T = (features as T).sqrt().recip();
        self.y_mat = (0..cols)
            .into_par_iter()
            .map(|_| {
                DVector::<T>::from_fn(features, |_, _| thread_rng().gen_range(0.0..bound))
            })
            .collect();
    }

    /// Drops every stored entry of R by replacing both the row-first and the column-first
    /// index with empty maps. X and Y are not touched.
    pub fn reset_r(&mut self) {
        self.r_col_first = HashMap::default();
        self.r_row_first = HashMap::default();
    }

    /// Replaces the regularization factor used by subsequent training runs and cost
    /// evaluations with `lambda`.
    pub fn set_regularization(&mut self, lambda: T) {
        self.default_regularization = lambda;
    }

    /// Sets how many iterations a call to `train` runs.
    pub fn set_default_iters(&mut self, iters: usize) {
        self.default_iters = iters;
    }

    /// Trains for a specified amount of iterations.
    pub fn train_for(&mut self, iters: usize) {
       self.ensure_x_y_existence();
        let mut precomp_yyt: HashMap<usize, OMatrix<T, _, _>> = HashMap::with_capacity(self.m);
        let mut precomp_xxt: HashMap<usize, OMatrix<T, _, _>> = HashMap::with_capacity(self.n);
        let reg_diag = DMatrix::<T>::from_diagonal_element(self.k, self.k, self.default_regularization);
        precomp_yyt.par_extend(
            self.r_col_first.par_keys()
                .map(|i_m| {
                    (*i_m, DMatrix::<T>::zeros(self.k, self.k))
                })
        );
        precomp_xxt.par_extend(
            self.r_row_first.par_keys()
                .map(|i_n| {
                    (*i_n, DMatrix::<T>::zeros(self.k, self.k))
                })
        );
        for _ in 0..iters {
            precomp_yyt.par_iter_mut().for_each(|(i_m, kk_term)| {
                let y_i = &self.y_mat[*i_m];
                y_i.mul_to(&y_i.transpose(), kk_term);
            });

            self.x_mat.par_iter_mut().enumerate().for_each(|(i_n, x_row)| {
                if let Some(r_row) = self.r_row_first.get(&i_n) {
                    let mut first_sum = reg_diag.clone();
                    let mut second_sum: DVector<T> = DVector::zeros(self.k);
                    r_row.iter().for_each(|(i_m, r_nm)|{
                        first_sum += precomp_yyt.get(i_m).unwrap();
                        second_sum += &(&self.y_mat[*i_m] * *r_nm);
                    });
                    if !first_sum.try_inverse_mut() {
                        first_sum = first_sum.pseudo_inverse(DEFAULT_EPS).unwrap();
                    }
                    first_sum.mul_to(&second_sum, x_row);
                }
            });

            precomp_xxt.par_iter_mut().for_each(|(i_n, kk_term)| {
                let x_i = &self.x_mat[*i_n];
                x_i.mul_to(&x_i.transpose(), kk_term);
            });

            self.y_mat.par_iter_mut().enumerate().for_each(|(i_m, y_row)| {
                if let Some(r_col) =  self.r_col_first.get(&i_m) {
                    let mut first_sum = reg_diag.clone();
                    let mut second_sum: DVector<T> = DVector::zeros(self.k);
                    r_col.iter().for_each(|(i_n, r_nm)|{
                        first_sum += precomp_xxt.get(i_n).unwrap();
                        second_sum += &(&self.x_mat[*i_n] * *r_nm);
                    });
                    if !first_sum.try_inverse_mut() {
                        first_sum = first_sum.pseudo_inverse(DEFAULT_EPS).unwrap();
                    }
                    first_sum.mul_to(&second_sum, y_row);
                }

            });
        }
    }

    /// Re-initializes X and/or Y whenever the number of stored factor vectors does not
    /// match the corresponding dimension of R.
    fn ensure_x_y_existence(&mut self) {
        let x_is_stale = self.x_mat.len() != self.n;
        if x_is_stale {
            self.init_x();
        }

        let y_is_stale = self.y_mat.len() != self.m;
        if y_is_stale {
            self.init_y();
        }
    }

    /// Runs training for the instance's currently configured default number of iterations.
    pub fn train(&mut self) {
        let iters = self.default_iters;
        self.train_for(iters);
    }

    /// Returns the feature vector of X belonging to `row`, or `None` if no vector is
    /// stored at that index.
    pub fn get_row_factors(&self, row: usize) -> Option<&DVector<T>> {
        self.x_mat.as_slice().get(row)
    }
    /// Returns the feature vector of Y belonging to `col`, or `None` if no vector is
    /// stored at that index.
    pub fn get_col_factors(&self, col: usize) -> Option<&DVector<T>> {
        self.y_mat.as_slice().get(col)
    }

    /// Borrows all row factor vectors (X), indexed by row.
    pub fn get_x(&self) -> &Vec<DVector<T>> {
        let row_factors = &self.x_mat;
        row_factors
    }

    /// Borrows all column factor vectors (Y), indexed by column.
    pub fn get_y(&self) -> &Vec<DVector<T>> {
        let col_factors = &self.y_mat;
        col_factors
    }


    /// Computes the cost function between X^T x Y and R.
    ///
    /// The result is the sum of squared errors `(r_nm - x_n . y_m)^2` over all stored
    /// entries of R, plus the regularization factor times the sum of the squared norms of
    /// every vector in X and Y.
    ///
    /// Takes `&mut self` because X and Y are (re)initialized first if their lengths do not
    /// match the dimensions of R.
    pub fn cost(&mut self) -> T {
        self.ensure_x_y_existence();
        // `dot` computes the inner product directly instead of materializing a transposed
        // copy and a 1x1 product matrix for every stored entry.
        let r_term: T = self.r_row_first.par_iter().map(|(i_n, col)| {
            let x_n = &self.x_mat[*i_n];
            col
                .par_iter()
                .map(|(i_m, val)| (*val - x_n.dot(&self.y_mat[*i_m])).powi(2))
                .sum::<T>()
        }).sum::<T>();

        let x_term: T = self.x_mat
            .par_iter()
            .map(|x_in| x_in.dot(x_in))
            .sum::<T>();

        let y_term: T = self.y_mat
            .par_iter()
            .map(|y_in| y_in.dot(y_in))
            .sum::<T>();

        r_term + self.default_regularization * (x_term + y_term)
    }

    /// Predicts the value of R at some index as the inner product of row factor `n` and
    /// column factor `m`.
    ///
    /// # Panics
    ///
    /// Panics if `n` or `m` is not a valid index into the stored row or column factors.
    pub fn predict_r_val(&self, n :usize, m : usize) -> T {
        // A direct dot product avoids allocating a transposed copy and a 1x1 matrix.
        self.x_mat[n].dot(&self.y_mat[m])
    }
}

