use approx::assert_relative_eq;
use ndarray::{arr1, arr2, s, Array1};

use ndarray_ndimage::{gaussian_filter, median_filter, Mask};

/// Checks `median_filter` on a 3x3x3 boolean mask as voxels near the origin
/// are switched on step by step. Results verified manually.
#[test]
fn test_median_filter() {
    let mut expected = Mask::from_elem((3, 3, 3), false);
    let mut input = expected.clone();

    // Each of these first three voxels, added one at a time, leaves the output empty.
    let sparse: [(usize, usize, usize); 3] = [(0, 0, 0), (1, 0, 0), (0, 1, 0)];
    for &voxel in &sparse {
        input[voxel] = true;
        assert_eq!(median_filter(&input), expected);
    }

    // Once (0, 0, 1) is also set, the origin voxel is expected in the output.
    expected[(0, 0, 0)] = true;
    input[(0, 0, 1)] = true;
    assert_eq!(median_filter(&input), expected);

    // Setting (1, 1, 0) leaves the expected output unchanged.
    input[(1, 1, 0)] = true;
    assert_eq!(median_filter(&input), expected);

    // After (1, 0, 1) is set, the origin's three axis neighbours are expected too.
    let axis_neighbours: [(usize, usize, usize); 3] = [(1, 0, 0), (0, 1, 0), (0, 0, 1)];
    for &voxel in &axis_neighbours {
        expected[voxel] = true;
    }
    input[(1, 0, 1)] = true;
    assert_eq!(median_filter(&input), expected);

    // Adding (1, 1, 1) extends the output to (2, 0, 0); this call goes through a view.
    expected[(2, 0, 0)] = true;
    input[(1, 1, 1)] = true;
    assert_eq!(median_filter(&input.view()), expected);
}

/// Smooths the ramp 0..=6 with `gaussian_filter`: first as an owned array,
/// then through a view after overwriting the first sample.
#[test]
fn test_gaussian_filter_1d() {
    let mut signal: Array1<f32> = (0u8..7).map(f32::from).collect();

    let expected_owned =
        arr1(&[0.42704096, 1.0679559, 2.0048335, 3.0, 3.9951665, 4.932044, 5.572959]);
    assert_relative_eq!(gaussian_filter(&signal, 1.0, 4.0), expected_owned, epsilon = 1e-5);

    // Overwrite the first sample, then filter a view with different arguments.
    signal[0] = 0.7;
    let expected_view =
        arr1(&[1.4193099, 1.737984, 2.3200142, 3.0642939, 3.8351974, 4.4778357, 4.845365]);
    assert_relative_eq!(
        gaussian_filter(&signal.view(), 2.0, 3.0),
        expected_view,
        epsilon = 1e-5
    );
}

/// Checks 2-D `gaussian_filter` on 7-column grids of the even numbers 0, 2, 4, ...
/// whose first element is overwritten, for three grid heights and argument sets.
#[test]
fn test_gaussian_filter_2d() {
    // Builds the values 0, 2, 4, ... (rows * 7 of them), reshapes them to
    // `(rows, 7)` via `into_shape`, and overwrites element (0, 0) with `corner`.
    let ramp = |rows: usize, corner: f32| {
        let flat: Array1<f32> = (0..rows * 7).map(|v| (2 * v) as f32).collect();
        let mut grid = flat.into_shape((rows, 7)).unwrap();
        grid[(0, 0)] = corner;
        grid
    };

    let five_rows = ramp(5, 17.0);
    assert_relative_eq!(
        gaussian_filter(&five_rows, 1.0, 4.0),
        arr2(&[
            [13.815777, 11.339161, 10.62479, 12.028319, 13.970364, 15.842661, 17.12449],
            [19.028267, 18.574514, 19.253122, 20.97248, 22.940516, 24.813597, 26.095427],
            [29.490631, 30.42986, 32.06769, 34.004536, 35.990467, 37.864086, 39.14592],
            [41.95432, 43.209373, 45.064693, 47.050846, 49.040836, 50.914577, 52.196407],
            [50.876965, 52.158012, 54.031227, 56.02144, 58.01176, 59.885513, 61.167343],
        ]),
        epsilon = 1e-4
    );

    let six_rows = ramp(6, 8.5);
    assert_relative_eq!(
        gaussian_filter(&six_rows, 1.0, 2.0),
        arr2(&[
            [10.078889, 9.458512, 10.006921, 11.707343, 13.707343, 15.598366, 16.892008],
            [17.220367, 17.630152, 18.90118, 20.76284, 22.76284, 24.653864, 25.947506],
            [29.114912, 30.247316, 32.025234, 34.000000, 36.000000, 37.89102, 39.184666],
            [42.815334, 44.10898, 46.000000, 48.000000, 50.000000, 51.89102, 53.184666],
            [56.052494, 57.346134, 59.23716, 61.23716, 63.23716, 65.12818, 66.42182],
            [65.107994, 66.401634, 68.292656, 70.292656, 72.292656, 74.18368, 75.47732],
        ]),
        epsilon = 1e-4
    );

    let eight_rows = ramp(8, 18.2);
    assert_relative_eq!(
        gaussian_filter(&eight_rows, 1.5, 3.5),
        arr2(&[
            [16.712738, 16.30507, 16.362633, 17.34964, 18.918924, 20.453388, 21.402458],
            [22.053278, 22.092232, 22.654442, 23.931578, 25.60057, 27.156698, 28.1087],
            [31.7295, 32.2731, 33.405533, 35.01049, 36.79215, 38.372753, 39.328068],
            [44.08236, 44.91609, 46.376343, 48.169773, 50.0162, 51.61088, 52.5681],
            [57.50711, 58.440548, 60.013466, 61.87167, 63.740356, 65.339874, 66.297745],
            [70.68089, 71.636, 73.2334, 75.10567, 76.979195, 78.579765, 79.53778],
            [81.8913, 82.849335, 84.45004, 86.32423, 88.1984, 89.79911, 90.75715],
            [88.59754, 89.55557, 91.15629, 93.030464, 94.90464, 96.505356, 97.46339],
        ]),
        epsilon = 1e-4
    );
}

/// Checks 3-D `gaussian_filter` on a (10, 9, 8) volume holding i / 50 with two
/// voxels overwritten, comparing index planes 0 and 9 of the first axis.
#[test]
fn test_gaussian_filter_3d() {
    let flat: Array1<f32> = (0u16..720).map(|i| f32::from(i) / 50.0).collect();
    let mut volume = flat.into_shape((10, 9, 8)).unwrap();
    volume[(0, 0, 0)] = 0.2;
    volume[(3, 3, 3)] = 1.0;

    let smoothed = gaussian_filter(&volume, 1.8, 4.0);

    let first_plane = arr2(&[
        [1.647472, 1.651181, 1.659609, 1.673325, 1.691082, 1.709747, 1.725337, 1.734229],
        [1.708805, 1.712651, 1.721257, 1.735376, 1.754014, 1.773838, 1.790377, 1.799745],
        [1.818189, 1.822212, 1.831044, 1.845692, 1.865495, 1.886855, 1.904654, 1.914653],
        [1.95729, 1.961716, 1.971077, 1.986287, 2.006792, 2.028921, 2.04732, 2.057615],
        [2.110379, 2.115686, 2.126213, 2.142124, 2.16256, 2.184116, 2.201956, 2.211958],
        [2.265391, 2.271859, 2.283923, 2.300605, 2.320466, 2.340645, 2.357214, 2.366559],
        [2.409767, 2.417196, 2.430533, 2.447822, 2.467118, 2.486012, 2.501415, 2.51016],
        [2.525863, 2.53382, 2.54786, 2.565478, 2.584446, 2.602611, 2.617354, 2.625759],
        [2.591995, 2.600145, 2.614439, 2.632176, 2.651024, 2.668921, 2.683421, 2.691702],
    ]);
    assert_relative_eq!(smoothed.slice(s![0, .., ..]), first_plane, epsilon = 1e-4);

    let last_plane = arr2(&[
        [11.68823, 11.69645, 11.71083, 11.72861, 11.74741, 11.76522, 11.77964, 11.78788],
        [11.75407, 11.76228, 11.77665, 11.79442, 11.81323, 11.83105, 11.84548, 11.85373],
        [11.86941, 11.8776, 11.89196, 11.90972, 11.92854, 11.94638, 11.96082, 11.96907],
        [12.01257, 12.02076, 12.03511, 12.05287, 12.07169, 12.08953, 12.10399, 12.11224],
        [12.16671, 12.17491, 12.18926, 12.20702, 12.22584, 12.24368, 12.25812, 12.26638],
        [12.32086, 12.32907, 12.34344, 12.36121, 12.38002, 12.39784, 12.41227, 12.42051],
        [12.46405, 12.47227, 12.48665, 12.50443, 12.52324, 12.54104, 12.55545, 12.56369],
        [12.57941, 12.58764, 12.60204, 12.61982, 12.63862, 12.65641, 12.67081, 12.67905],
        [12.64527, 12.6535, 12.6679, 12.68568, 12.70448, 12.72227, 12.73667, 12.7449],
    ]);
    assert_relative_eq!(smoothed.slice(s![9, .., ..]), last_plane, epsilon = 1e-4);
}

/// `gaussian_filter` must panic when called with arguments `(2.0, 4.0)` on a
/// length-7 signal. Presumably this is because the kernel radius implied by these
/// arguments exceeds the input length — TODO confirm against the implementation.
#[test]
#[should_panic]
fn test_gaussian_filter_panic() {
    let signal: Array1<f32> = (0u8..7).map(f32::from).collect();

    let _smoothed = gaussian_filter(&signal, 2.0, 4.0);
}
