/*! Susceptible-infected information spread module.
*/

use crate::{cn, Network};
use indexmap::map::Entry;
use rand::prelude::{IteratorRandom, Rng};

/// `Observers` data structure. Wraps a [`cn::IndexMap`] of (node index, infection time) pairs.
pub struct Observers(cn::IndexMap<usize, Option<usize>>);

impl Observers {
    /// Returns an immutable reference to the wrapped [`cn::IndexMap`].
    pub fn as_map(&self) -> &cn::IndexMap<usize, Option<usize>> {
        &self.0
    }

    /// Sorts the observers by infection time, in ascending order.
    ///
    /// The sort is stable: observers with equal infection times keep their relative order.
    /// Entries without a time (`None`) would sort before all timed ones.
    pub fn sort_by_time(&mut self) {
        self.0.sort_by(|_, t1, _, t2| t1.cmp(t2));
    }

    /// Keep only observers contained in `keep`, removing all of the others. Indices in `keep`
    /// which are not observers are ignored.
    ///
    /// The retained observers keep their relative order.
    pub fn keep(&mut self, keep: &[usize]) {
        // Hash the wanted indices once, so the retain pass is O(n + m) rather than O(n * m).
        let keep: std::collections::HashSet<usize> = keep.iter().copied().collect();
        self.0.retain(|index, _| keep.contains(index));
    }

    /// Keep only `amount` of randomly chosen observers, discarding all of the others.
    ///
    /// Does nothing if `amount` is greater than (or equal to) the number of observers, since
    /// there is nothing to discard. The retained observers keep their relative order.
    pub fn keep_random(&mut self, amount: usize, rng: &mut impl Rng) {
        if amount < self.0.len() {
            let to_keep = self.0.keys().copied().choose_multiple(rng, amount);
            self.keep(&to_keep);
        }
    }
}

/// Initialize the SI information spread and perform a number of synchronous `steps`, starting from
/// `root`.
///
/// In every step each already-infected node tries to infect each of its not-yet-infected
/// neighbors, succeeding when a random `f64` drawn from the network's random number generator is
/// below the weight of the connecting link. Nodes infected during a step only start spreading in
/// the following step.
///
/// Returns the [`Observers`] set, which by design **DOES NOT** contain the `root`. Observers are
/// stored in the order in which they got infected, each paired with `Some(step)`, where `step` is
/// the (1-based) step of its infection.
///
/// # Errors
///
/// Returns [`cn::Err::NoSuchNode`] if `root` does not exist, including when `steps` is zero.
pub fn spread(net: &Network, root: usize, steps: usize) -> cn::Result<Observers> {
    // Validate the root up front: with `steps == 0` the loop below never queries it, and a
    // missing root would otherwise silently produce an empty observer set.
    net.links_of(root)?;
    let mut infected = cn::IndexMap::default();
    infected.insert(root, Some(0));
    // Inclusive range: `steps + 1` would overflow for `steps == usize::MAX`.
    for step in 1..=steps {
        // Snapshot the nodes infected before this step, so nodes infected during the step do not
        // spread until the next one (synchronous update).
        let spreaders: Vec<usize> = infected.keys().copied().collect();
        for i in spreaders {
            for (neighbor, weight) in net.links_of(i)? {
                if let Entry::Vacant(entry) = infected.entry(*neighbor) {
                    if net.rng_lock().unwrap().gen::<f64>() < *weight {
                        entry.insert(Some(step));
                    }
                }
            }
        }
    }
    // Shift (not swap) removal keeps the remaining observers in infection order; swap removal
    // would move the last-infected node into the root's former first slot.
    infected.shift_remove(&root);
    Ok(Observers(infected))
}

/// Same as [`spread`], but with a root chosen randomly using the network's random number generator.
///
/// The randomly chosen root is returned along with the [`Observers`] set.
///
/// # Panics
///
/// Panics if the network has no nodes, since there is no root to choose from.
pub fn spread_random(net: &Network, steps: usize) -> (usize, Observers) {
    // The generator guard is a temporary of this `let` statement and is released at its end,
    // so `spread` below can lock the generator again.
    let root = net
        .nodes()
        .choose(&mut *net.rng_lock().unwrap())
        .expect("cannot choose a random root from an empty network");
    let observers = spread(net, root, steps)
        .expect("a root chosen from the network's own nodes must exist in it");
    (root, observers)
}
