/*!
Information spread source location module.

This module implements functions that aid in **information source location**, using the information
provided by [`si::Observers`]. The algorithms construct a [`Ranking`] of suspect nodes based on
different criteria.

## Note
Location algorithms require the subject [`Network`] to be [whole](crate::Network::is_whole).
*/

use crate::{bfs, cn, si, Network, Weight};
use nalgebra::{DMatrix, DVector};
use rayon::prelude::*;
use std::cmp::Ordering;

/// The result of running a source-location algorithm on an infected network.
///
/// It wraps a [`cn::IndexMap`] of (suspect index, score) pairs, ordered by descending score,
/// together with the true source recorded by the observers. The higher a suspect is positioned,
/// the more probable it is that it is the true information source.
#[derive(Debug, Clone)]
pub struct Ranking {
    /// Suspect index -> score, highest score first.
    ranking: cn::IndexMap<usize, f64>,
    /// Index of the true information source, as recorded by the observers.
    true_source: usize,
}

impl Ranking {
    /// Returns an immutable reference to the wrapped [`cn::IndexMap`] of (suspect, score) pairs.
    pub fn as_map(&self) -> &cn::IndexMap<usize, f64> {
        &self.ranking
    }

    /// Returns the recorded true source.
    pub fn true_source(&self) -> usize {
        self.true_source
    }

    /// Returns the top-ranked (suspect, score) pair, `None` if the ranking is empty.
    pub fn top(&self) -> Option<(usize, f64)> {
        self.ranking.first().map(|(&i, &s)| (i, s))
    }

    /// Returns the zero-based position of the true source in the ranking, `None` if it's not in
    /// the ranking.
    pub fn true_source_position(&self) -> Option<usize> {
        self.ranking.get_index_of(&self.true_source)
    }

    /// Returns the score of the true source, `None` if it's not in the ranking.
    pub fn true_source_score(&self) -> Option<f64> {
        self.ranking.get(&self.true_source).copied()
    }

    /// Returns ranking precision.
    ///
    /// Precision is `# of true positives / # of all positives`, where the positives are all
    /// suspects tied for the top score. It is `1 / k` if the true source is one of `k` suspects
    /// tied for the top score, and `0.0` otherwise (also for an empty ranking or when the true
    /// source is not ranked). Two scores are tied when they are equal or closer than
    /// [`f64::EPSILON`]; `NaN` never ties.
    pub fn precision(&self) -> f64 {
        if let (Some((_, top_score)), Some(source_score)) = (self.top(), self.true_source_score())
        {
            if Self::scores_tied(top_score, source_score) {
                // The top score ties with itself, so `tied` is at least 1.
                let tied = self
                    .ranking
                    .values()
                    .filter(|&&s| Self::scores_tied(s, top_score))
                    .count();
                return 1.0 / tied as f64;
            }
        }
        0.0
    }

    /// Returns the distance between the top-ranked suspect and the true source.
    ///
    /// Returns `Ok(Some(0))` without a graph search when the top suspect is the true source,
    /// `Ok(None)` when the ranking is empty, and otherwise the result of [`bfs::distance`]
    /// between the two nodes in `net`.
    ///
    /// # Errors
    /// Propagates any error returned by [`bfs::distance`].
    pub fn distance_error(&self, net: &Network) -> cn::Result<Option<usize>> {
        match self.top() {
            None => Ok(None),
            Some((top, _)) if top == self.true_source => Ok(Some(0)),
            Some((top, _)) => bfs::distance(net, top, self.true_source),
        }
    }

    /// Returns `true` when two scores count as tied: exactly equal (which also covers equal
    /// infinities, whose difference is `NaN`) or closer than [`f64::EPSILON`].
    #[allow(clippy::float_cmp)] // exact equality is intended: it is what matches equal infinities
    fn scores_tied(a: f64, b: f64) -> bool {
        a == b || (a - b).abs() < f64::EPSILON
    }
}

/// Constructs a suspect ranking based on the Pearson correlation coefficient.
///
/// This algorithm calculates the [Pearson correlation
/// coefficient](https://en.wikipedia.org/wiki/Pearson_correlation_coefficient) between the vector
/// of **infection times** and the vector of **distances** from a given suspect. It is assumed that
/// the higher the correlation, the more probable the fact that the suspect is the true source.
///
/// It should be noted that this version uses **only** infection times, disregarding the originally
/// proposed additional suspect elimination based on their distance to the observer's **infectant**
/// (the node which infected the observer).
///
/// Every node of `net` (observers included) is scored. Suspects for which the BFS distance
/// computation fails are left out of the ranking. If all observers share the same infection time,
/// the coefficient is undefined (`NaN`); such scores are ranked last.
///
/// # Errors
/// - [`cn::Err::NotWhole`] if `net` is not whole.
/// - [`cn::Err::NoSuchNode`] if an observer is not a node of `net`.
/// - [`cn::Err::NoTarget`] if there are no observers.
///
/// # References
/// > Xu, S., Teng, C., Zhou, Y., Peng, J., Zhang, Y., & Zhang, Z. K. (2019). [*Identifying the
/// > diffusion source in complex networks with limited
/// > observers.*](https://doi.org/10.1016/j.physa.2019.121267) Physica A: Statistical Mechanics
/// > and Its Applications, 527, 121267.
pub fn pearson(net: &Network, observers: &si::Observers) -> cn::Result<Ranking> {
    if !net.is_whole() {
        return Err(cn::Err::NotWhole);
    }
    let observed = observers.as_map();
    if let Some(o) = observed.keys().find(|&o| !net.exists(*o)) {
        return Err(cn::Err::NoSuchNode(*o));
    }
    // BFS targets: the observers only (infectants are not used by this variant).
    let targets: Vec<usize> = observed.keys().copied().collect();
    if targets.is_empty() {
        // Without observers the mean infection time below would be 0 / 0.
        return Err(cn::Err::NoTarget);
    }

    // Infection times centered on their mean; independent of the suspect, so computed once.
    let avg_t = observed.values().sum::<usize>() as f64 / observed.len() as f64;
    let times_reduced: Vec<f64> = observed.values().map(|&t| t as f64 - avg_t).collect();
    let t_sq_sum: f64 = times_reduced.iter().map(|t| t * t).sum();

    // Every node of the network is a suspect, observers included.
    let suspects: Vec<usize> = net.nodes().collect();
    let mut scores: Vec<(usize, f64)> = suspects
        .par_iter()
        .filter_map(|&suspect| {
            // Suspects whose BFS fails are left out of the ranking, as in `lptv`.
            let table = bfs::distance_many(net, suspect, &targets).ok()?;
            let mut distances: Vec<f64> = targets.iter().map(|&o| table[o] as f64).collect();
            let avg_d = distances.iter().sum::<f64>() / distances.len() as f64;
            // Center the distances. Entries that land (almost) exactly on the mean are nudged to
            // EPSILON so the denominator cannot vanish when all distances are equal (NaN guard).
            for d in distances.iter_mut() {
                let centered = *d - avg_d;
                *d = if centered.abs() < f64::EPSILON { f64::EPSILON } else { centered };
            }
            let cross: f64 = distances.iter().zip(&times_reduced).map(|(d, t)| d * t).sum();
            let d_sq_sum: f64 = distances.iter().map(|d| d * d).sum();
            Some((suspect, cross / (d_sq_sum * t_sq_sum).sqrt()))
        })
        .collect();

    // Highest correlation first. NaN scores are mapped to -inf so they sort last and the
    // comparator stays a total order (`sort_unstable_by` may panic on inconsistent orderings).
    let rank_key = |s: f64| if s.is_nan() { f64::NEG_INFINITY } else { s };
    scores.sort_unstable_by(|(_, a), (_, b)| rank_key(*b).total_cmp(&rank_key(*a)));
    Ok(Ranking {
        ranking: scores.into_iter().collect(),
        true_source: observers.true_source(),
    })
}
/// Constructs the ranking using the LPTV (Limited Pinto-Thiran-Vetterli) algorithm.
///
/// This algorithm orders suspect nodes using the **maximum likelihood** estimator, as described in
/// the original paper. This "limited" variant does not utilize the information about observers'
/// **infectants** (nodes which infected the observers).
///
/// The first observer is used as the reference; delays of the other observers are measured
/// relative to it. Suspects for which the BFS tree/distances/paths cannot be computed, or whose
/// covariance matrix is not invertible, are left out of the ranking. `NaN` scores are ranked last.
///
/// **Warning:** this method is currently **not implemented** for networks with [`Weight::Uniform`]
/// and will simply return an empty ranking.
///
/// # Errors
/// - [`cn::Err::NotWhole`] if `net` is not whole.
/// - [`cn::Err::NoSuchNode`] if an observer is not a node of `net`.
/// - [`cn::Err::NoTarget`] if there are no observers.
///
/// # References
/// > Pinto, P. C., Thiran, P., & Vetterli, M. (2012). [*Locating the source of diffusion in
/// > large-scale networks.*](https://doi.org/10.1103/physrevlett.109.068702) Physical Review
/// > Letters, 109(6), 1–5.
pub fn lptv(net: &Network, observers: &si::Observers) -> cn::Result<Ranking> {
    if !net.is_whole() {
        return Err(cn::Err::NotWhole);
    }
    let observed = observers.as_map();
    if let Some(o) = observed.keys().find(|&o| !net.exists(*o)) {
        return Err(cn::Err::NoSuchNode(*o));
    }
    let need_distance: Vec<usize> = observed.keys().copied().collect();
    let (ref_o, other_o) = need_distance.split_first().ok_or(cn::Err::NoTarget)?;
    let (mean, variance) = match net.weight() {
        Weight::Constant { c } => (1.0 / c, (1.0 - c) / (c * c)),
        Weight::Uniform => {
            return Ok(Ranking {
                ranking: cn::IndexMap::default(),
                true_source: observers.true_source(),
            })
        }
    };
    // The variance is 0 for c == 1; clamp it so the covariance can be inverted and ln() is finite.
    let variance = if variance.abs() < f64::EPSILON {
        f64::EPSILON
    } else {
        variance
    };
    let n_other = other_o.len();
    let suspects: Vec<usize> = net.nodes().collect();

    // Observed delays relative to the reference observer; independent of the suspect.
    let ref_time = observed[ref_o] as f64;
    let delays: DVector<f64> =
        DVector::from_vec(other_o.iter().map(|o| observed[o] as f64 - ref_time).collect());

    let mut scores: Vec<(usize, f64)> = suspects
        .par_iter()
        .filter_map(|&suspect| {
            // The algorithm operates on the bfs tree rooted at the suspect.
            let tree = bfs::tree_active(net, suspect, &need_distance).ok()?;
            // Distances from the suspect to observers
            let distances = bfs::distance_many(&tree, suspect, &need_distance).ok()?;
            // Paths from the reference observer to the others
            let paths = bfs::path_many(&tree, *ref_o, &need_distance).ok()?;
            // Unscaled covariance matrix: path lengths on the diagonal, shared path lengths
            // below it; the upper triangle is mirrored afterwards.
            let mut lambda_matrix: DMatrix<f64> =
                DMatrix::from_fn(n_other, n_other, |i: usize, j: usize| {
                    let o_i = other_o[i]; // Observer i index
                    let o_j = other_o[j]; // Observer j index
                    match i.cmp(&j) {
                        Ordering::Equal => paths[o_i].as_vec().unwrap().len() as f64,
                        Ordering::Greater => paths[o_i].common_len(&paths[o_j]).unwrap() as f64,
                        Ordering::Less => 0.0,
                    }
                });
            for j in 0..n_other {
                for i in 0..j {
                    *lambda_matrix.index_mut((i, j)) = *lambda_matrix.index((j, i))
                }
            }
            // The covariance is variance * L, so ln|cov| = ln|L| + n_other * ln(variance).
            // Taking |L| before scaling and adding the variance term in log space below keeps
            // the determinant from becoming vanishingly tiny.
            let det = lambda_matrix.determinant().abs();
            lambda_matrix *= variance;
            // A singular covariance gives no likelihood for this suspect; leave it out.
            if !lambda_matrix.try_inverse_mut() {
                return None;
            }
            let mu = DVector::from_vec(
                other_o
                    .iter()
                    .map(|o| mean * (distances[*o] as i64 - distances[*ref_o] as i64) as f64)
                    .collect(),
            );
            let d_red = &delays - &mu;
            // The product is a 1x1 matrix; [0] extracts its single entry.
            let quad = (d_red.transpose() * lambda_matrix * d_red)[0];
            let score = -quad - det.ln() - n_other as f64 * variance.ln();
            Some((suspect, score))
        })
        .collect();

    // Highest likelihood first. NaN scores are mapped to -inf so they sort last and the
    // comparator stays a total order (a plain `partial_cmp().unwrap()` panics on NaN).
    let rank_key = |s: f64| if s.is_nan() { f64::NEG_INFINITY } else { s };
    scores.sort_unstable_by(|(_, a), (_, b)| rank_key(*b).total_cmp(&rank_key(*a)));
    Ok(Ranking {
        ranking: scores.into_iter().collect(),
        true_source: observers.true_source(),
    })
}
