//! Tests for locating the source of an information spread.

use cnetworks::*;

/// Builds the fixed 10-node test graph used by the deterministic location tests.
///
/// Every edge gets the constant weight `c`, and `seed` is handed to
/// `Network::with_seed` so runs are reproducible. Each entry of the table below
/// lists a node followed by the nodes it is linked to. Every `link` call is
/// unwrapped, so a rejected link panics.
fn setup_net(c: f64, seed: &str) -> Network {
    // NOTE(review): node 6 lists 3 as a neighbour but node 3 does not list 6;
    // every other pair appears in both directions — confirm this is intended.
    let adjacency = [
        (0, vec![1, 3, 4]),
        (1, vec![0, 2, 4, 5, 9]),
        (2, vec![1, 9, 8]),
        (3, vec![0, 4, 7]),
        (4, vec![0, 3, 6, 7, 5, 1]),
        (5, vec![1, 4, 7, 8, 9]),
        (6, vec![3, 7, 4]),
        (7, vec![4, 3, 6, 5, 8]),
        (8, vec![7, 5, 2]),
        (9, vec![1, 5, 2]),
    ];
    let mut network = Network::with_seed(10, Model::None, Weight::Constant { c }, seed);
    for (source, neighbours) in adjacency {
        neighbours.into_iter().for_each(|target| {
            network.link(source, target).unwrap();
        });
    }
    network
}

/// Checks that `locate::pearson` ranks node 5 first on the fixed test graph.
///
/// Node 5 is the node passed to `si::spread`, presumably as the spread source.
/// The check is repeated for several constant edge weights `c`.
#[test]
fn test_pearson() {
    for c in [0.5, 0.75, 1.0] {
        let net = setup_net(c, "O2jsYh2AQ6BrLWJS");
        // Observer nodes to retain (presumably `keep` drops all other observers).
        let keep = [1, 2, 3];
        let mut observers = si::spread(&net, 5, 10).unwrap();
        observers.keep(&keep);
        let result = locate::pearson(&net, &observers).unwrap();
        let (most_likely, _score) = result[0];
        // assert_eq! reports the node actually chosen and the weight that failed.
        assert_eq!(most_likely, 5, "wrong most likely source for c = {}", c);
    }
}

/// Checks that `locate::lptv` ranks node 5 first on the fixed test graph.
///
/// Node 5 is the node passed to `si::spread`, presumably as the spread source.
/// The check is repeated for several constant edge weights `c`.
#[test]
fn test_pinto() {
    for c in [0.5, 0.75, 1.0] {
        let net = setup_net(c, "O2jsYh2AQ6BrLWJS");
        // Observer nodes to retain (presumably `keep` drops all other observers).
        let keep = [1, 2, 3];
        let mut observers = si::spread(&net, 5, 10).unwrap();
        observers.keep(&keep);
        let result = locate::lptv(&net, &observers).unwrap();
        let (most_likely, _score) = result[0];
        // assert_eq! reports the node actually chosen and the weight that failed.
        assert_eq!(most_likely, 5, "wrong most likely source for c = {}", c);
    }
}

/// Statistical accuracy check for `locate::pearson` on random ER graphs.
///
/// Over `TRIALS` random graphs, node 5 is spread from and 10% of nodes are kept
/// as observers. Credit for a trial is awarded only when node 5 ties the top
/// score of the ranking; `result[0]` is assumed to hold the top score. If `k`
/// nodes share the top score, the trial earns `1/k`, which is the expected hit
/// rate under random tie-breaking. This credit does not depend on where node 5
/// happens to fall within the tied group. The mean credit must exceed 0.95.
#[test]
#[ignore]
fn test_pearson_stats() {
    // Number of independent random trials averaged into the accuracy estimate.
    const TRIALS: usize = 5000;
    let mut acc = 0.0;
    let size = 100.0;
    let p = 8.0 / size;
    for _ in 0..TRIALS {
        let net = Network::new(size as usize, Model::ER { p, whole: true }, Weight::default());
        let mut observers = si::spread(&net, 5, 10).unwrap();
        observers.keep_random((0.1 * size).round() as usize, &mut *net.rng_lock().unwrap());
        let result = locate::pearson(&net, &observers).unwrap();
        let place = result.iter().position(|(node, _score)| node == &5).unwrap();
        let score = result[place].1;
        let max_score = result[0].1;
        // Previously a first-place finish got full credit even when tied with
        // other nodes, while the same tie at a later position got 1/k. The
        // estimate then depended on the locator's tie ordering.
        if (score - max_score).abs() < f64::EPSILON {
            // Node 5 is always counted here, so `tied >= 1`.
            let tied = result
                .iter()
                .filter(|(_node, other_score)| (*other_score - max_score).abs() < f64::EPSILON)
                .count();
            acc += 1.0 / tied as f64;
        }
    }
    acc /= TRIALS as f64;
    assert!(acc > 0.95, "pearson accuracy {} is not above 0.95", acc);
}

/// Statistical accuracy check for `locate::lptv` on random ER graphs.
///
/// Over `TRIALS` random graphs, node 5 is spread from and 10% of nodes are kept
/// as observers. Credit for a trial is awarded only when node 5 ties the top
/// score of the ranking; `result[0]` is assumed to hold the top score. If `k`
/// nodes share the top score, the trial earns `1/k`, which is the expected hit
/// rate under random tie-breaking. This credit does not depend on where node 5
/// happens to fall within the tied group. The mean credit must exceed 0.95.
#[test]
#[ignore]
fn test_pinto_stats() {
    // Number of independent random trials averaged into the accuracy estimate.
    const TRIALS: usize = 5000;
    let mut acc = 0.0;
    let size = 100.0;
    let p = 8.0 / size;
    for _ in 0..TRIALS {
        let net = Network::new(size as usize, Model::ER { p, whole: true }, Weight::default());
        let mut observers = si::spread(&net, 5, 10).unwrap();
        observers.keep_random((0.1 * size).round() as usize, &mut *net.rng_lock().unwrap());
        let result = locate::lptv(&net, &observers).unwrap();
        let place = result.iter().position(|(node, _score)| node == &5).unwrap();
        let score = result[place].1;
        let max_score = result[0].1;
        // Previously a first-place finish got full credit even when tied with
        // other nodes, while the same tie at a later position got 1/k. The
        // estimate then depended on the locator's tie ordering.
        if (score - max_score).abs() < f64::EPSILON {
            // Node 5 is always counted here, so `tied >= 1`.
            let tied = result
                .iter()
                .filter(|(_node, other_score)| (*other_score - max_score).abs() < f64::EPSILON)
                .count();
            acc += 1.0 / tied as f64;
        }
    }
    acc /= TRIALS as f64;
    assert!(acc > 0.95, "lptv accuracy {} is not above 0.95", acc);
}
