/*! Susceptible-infected information spread module.
*/

use crate::{cn, Network};
use indexmap::map::Entry;
use rand::prelude::{IteratorRandom, Rng};

/// `Observers` data structure. Wraps a [`cn::IndexMap`] of `(node index, infection time)` pairs,
/// together with the recorded true infection source and true spread start time.
///
/// [`spread`] constructs values of this type. The wrapped map can then be filtered and reordered
/// in place with the chaining methods (`by_time`, `keep`, `keep_random`, ...).
#[derive(Debug, Clone)]
pub struct Observers {
    /// Index of the node the spread was started from.
    true_source: usize,
    /// Time step at which the spread started (`0` for values built by [`spread`]).
    true_time: usize,
    /// `(node index, infection time)` pairs; insertion order unless reordered (e.g. by `by_time`).
    observers: cn::IndexMap<usize, usize>,
}

impl Observers {
    /// Returns the recorded true source, i.e. the node index the spread started from.
    pub fn true_source(&self) -> usize {
        self.true_source
    }

    /// Returns the recorded true infection start time.
    pub fn true_time(&self) -> usize {
        self.true_time
    }

    /// Returns an immutable reference to the wrapped [`cn::IndexMap`] of `(node index, infection
    /// time)` pairs.
    pub fn as_map(&self) -> &cn::IndexMap<usize, usize> {
        &self.observers
    }

    /// Sorts the observers by ascending infection time. The sort is stable, so observers with the
    /// same infection time keep their previous relative order.
    ///
    /// Returns mutable reference to `self` for easier method chaining.
    pub fn by_time(&mut self) -> &mut Self {
        self.observers.sort_by(|_, t1, _, t2| t1.cmp(t2));
        self
    }

    /// Keep only observers whose node index is contained in `keep`, removing all of the others.
    /// Indices in `keep` that are not observers are ignored, as are duplicates. The surviving
    /// observers keep their relative order.
    ///
    /// Returns mutable reference to `self` for easier method chaining.
    pub fn keep(&mut self, keep: &[usize]) -> &mut Self {
        // Build the lookup set once: a linear `contains` per observer would make this
        // O(observers * keep), which is costly e.g. for `adjust_to` on large networks.
        let wanted: std::collections::HashSet<usize> = keep.iter().copied().collect();
        self.observers.retain(|index, _| wanted.contains(index));
        self
    }

    /// Keep only `amount` of randomly chosen observers, discarding all of the others. The
    /// surviving observers keep their relative order.
    ///
    /// Edge cases (these intentionally differ from a literal "keep `amount`" reading):
    /// - if `amount` is greater than (or equal to) the number of observers, the inner
    ///   [`cn::IndexMap`] is emptied;
    /// - otherwise, if `amount == 0`, no observers are removed.
    ///
    /// Returns mutable reference to `self` for easier method chaining.
    pub fn keep_random(&mut self, amount: usize, rng: &mut impl Rng) -> &mut Self {
        if amount >= self.observers.len() {
            self.observers.clear();
        } else if amount > 0 {
            let chosen = self
                .observers
                .keys()
                .copied()
                .choose_multiple(rng, amount);
            self.keep(&chosen);
        }
        self
    }

    /// Keep only `portion` of random observers, discarding all of the others.
    ///
    /// The amount is `portion * number_of_observers`, rounded to the nearest integer, and then
    /// handled by [`Observers::keep_random`]. Consequently:
    /// - the inner [`cn::IndexMap`] is emptied when the rounded amount reaches the number of
    ///   observers (in particular whenever `portion >= 1`);
    /// - no observers are removed when the rounded amount is `0` (in particular when
    ///   `portion < 0` or `portion` is NaN, since the float-to-`usize` cast saturates to `0`).
    ///
    /// Returns mutable reference to `self` for easier method chaining.
    pub fn keep_portion(&mut self, portion: f64, rng: &mut impl Rng) -> &mut Self {
        let amount = (portion * self.observers.len() as f64).round() as usize;
        self.keep_random(amount, rng)
    }

    /// Keep only observers whose node index is yielded by `net.nodes()`, discarding all of the
    /// others.
    ///
    /// Returns mutable reference to `self` for easier method chaining.
    pub fn adjust_to(&mut self, net: &Network) -> &mut Self {
        self.keep(&net.nodes().collect::<Vec<usize>>())
    }
}

/// Initialize the SI information spread and perform a number of synchronous `steps`, starting from
/// `source`. Returns the [`Observers`] structure.
///
/// In every step, each node infected *before* that step tries to infect each of its
/// not-yet-infected neighbors. An attempt over a link succeeds when a random `f64` drawn from the
/// network's random number generator is below the link weight. Nodes infected during a step start
/// spreading only in the next step. The returned map lists `(node index, infection step)` pairs in
/// infection order, with `source` at step `0` and a recorded true time of `0`.
///
/// # Errors
///
/// Returns [`cn::Err::NoSuchNode`] if `source` does not exist (also when `steps == 0`).
///
/// # Panics
///
/// Panics if acquiring the network's random number generator lock (`rng_lock`) fails.
pub fn spread(net: &Network, source: usize, steps: usize) -> cn::Result<Observers> {
    // Validate `source` up front. With `steps == 0` the loop below never calls `links_of`, so
    // the documented error for a missing source would otherwise be silently skipped.
    net.links_of(source)?;
    let mut infected = cn::IndexMap::default();
    infected.insert(source, 0);
    // Inclusive range: `1..steps + 1` would overflow for `steps == usize::MAX`.
    for step in 1..=steps {
        // Snapshot of nodes infected before this step, so newly infected nodes do not spread
        // within the same step (synchronous update).
        let spreaders: Vec<usize> = infected.keys().copied().collect();
        for i in spreaders {
            for (neighbor, weight) in net.links_of(i)? {
                if let Entry::Vacant(entry) = infected.entry(*neighbor) {
                    if net.rng_lock().unwrap().gen::<f64>() < *weight {
                        entry.insert(step);
                    }
                }
            }
        }
    }
    Ok(Observers {
        true_source: source,
        true_time: 0,
        observers: infected,
    })
}

/// Same as [`spread`], but with the source chosen randomly from `net.nodes()` using the network's
/// random number generator. Returns the [`Observers`] structure.
///
/// # Panics
///
/// Panics if the network has no nodes, or if acquiring the network's random number generator
/// lock (`rng_lock`) fails.
pub fn spread_random(net: &Network, steps: usize) -> Observers {
    // The RNG guard is a temporary and is released at the end of this statement, before
    // `spread` locks the generator again.
    let root = net
        .nodes()
        .choose(&mut *net.rng_lock().unwrap())
        .expect("cannot start a spread on a network with no nodes");
    // `root` was just drawn from `net.nodes()`, so it exists and `spread` should not fail.
    spread(net, root, steps).expect("spread from an existing source node should not fail")
}
