//! This module corrects the geometric distortion of diffraction data acquired
//! with 2D detectors (mainly CCD-type detectors).
//!
use crate::spline::Spline;
use cryiorust::frame::{Array, FrameError, FrameResult};

/// Main exported structure which holds distortions info: a look-up table that maps
/// every pixel of a corrected image to the source pixels contributing to it.
///
/// Build it with [Distortion::new_init] (or [Distortion::new] followed by
/// [Distortion::init]) and apply it with [Distortion::correct].
#[derive(Debug, Clone, Default)]
pub struct Distortion {
    /// Size of the first (slow, row) dimension of the arrays this table was built for.
    dim1: isize,
    /// Size of the second (fast, column) dimension.
    dim2: isize,
    /// `dim1 + 1`: rows of the pixel-corner grid.
    dim11: isize,
    /// `dim2 + 1`: columns of the pixel-corner grid.
    dim22: isize,
    /// `dim1 * dim2`: number of pixels.
    size: isize,
    /// `dim11 * dim22`: number of pixel corners.
    size12: isize,
    /// Maximum number of table entries per output pixel (row width of the table).
    lut_max: isize,
    /// Rows of the scratch buffer used while integrating one distorted pixel.
    delta0: isize,
    /// Columns (row stride) of that scratch buffer.
    delta1: isize,
    /// Table weights, `size * lut_max` entries; unused slots stay zero.
    lut_coef: Vec<f64>,
    /// Linear source-pixel indices paired with `lut_coef`.
    lut_idx: Vec<usize>,
}

impl Distortion {
    /// Creates an empty [Distortion] with zero dimensions and no look-up table.
    ///
    /// Call [Distortion::init] before using it to correct arrays.
    pub fn new() -> Distortion {
        Distortion {
            lut_coef: Vec::new(),
            lut_idx: Vec::new(),
            dim1: 0,
            dim2: 0,
            dim11: 0,
            dim22: 0,
            size: 0,
            size12: 0,
            lut_max: 0,
            delta0: 0,
            delta1: 0,
        }
    }

    /// Create an initialized [Distortion] with [Spline].
    pub fn new_init(array: &Array, spline: &Spline) -> Distortion {
        let mut d = Distortion::new();
        d.init(array, spline);
        d
    }

    /// Initializes (or re-initializes) the [Distortion] for arrays shaped like `array`,
    /// building the look-up table from the [Spline] coefficients.
    ///
    /// `spline.x()` / `spline.y()` are indexed as one value per pixel corner, i.e.
    /// `(dim1 + 1) * (dim2 + 1)` values; presumably they must have been computed for
    /// this array shape beforehand — TODO confirm against `Spline`.
    pub fn init(&mut self, array: &Array, spline: &Spline) {
        let (rows, cols) = (array.dim1() as isize, array.dim2() as isize);
        self.dim1 = rows;
        self.dim2 = cols;
        self.dim11 = rows + 1;
        self.dim22 = cols + 1;
        self.size = rows * cols;
        self.size12 = self.dim11 * self.dim22;
        // The LUT width is recomputed from scratch below.
        self.lut_max = 0;
        let pos = {
            let corners = self.calc_corners(spline.x(), spline.y());
            self.calc_pos(corners)
        };
        self.calculate_lut_size(&pos);
        self.calc_lut_table(pos);
    }

    /// Returns `true` when this [Distortion] was built for arrays with the same
    /// dimensions as `array`.
    pub fn is_initialized(&self, array: &Array) -> bool {
        let wanted = (array.dim1() as isize, array.dim2() as isize);
        (self.dim1, self.dim2) == wanted
    }

    /// Corrects the given [Array], producing a new [Array].
    ///
    /// Every output pixel is the weighted sum of the source pixels listed in its
    /// look-up-table row (`lut_max` consecutive entries of `lut_idx` / `lut_coef`).
    /// The sum uses Kahan compensation to limit rounding error.
    ///
    /// # Errors
    ///
    /// Returns [FrameError::FormatError] if `array` does not have the dimensions this
    /// [Distortion] was initialized with, or if its data buffer holds fewer than
    /// `dim1 * dim2` values.
    pub fn correct(&self, array: &Array) -> FrameResult<Array> {
        if array.dim1() != self.dim1 as usize || array.dim2() != self.dim2 as usize {
            return Err(FrameError::FormatError(
                format!(
                    "distortion is initialized with dims {}x{} but provided array has dims {}x{}",
                    self.dim1,
                    self.dim2,
                    array.dim1(),
                    array.dim2()
                )
                .into(),
            ));
        }
        let size = self.size as usize;
        let src = array.data();
        // The previous unchecked reads were UB for an undersized buffer; report it instead.
        if src.len() < size {
            return Err(FrameError::FormatError(
                format!(
                    "provided array has dims {}x{} but holds only {} values",
                    array.dim1(),
                    array.dim2(),
                    src.len()
                )
                .into(),
            ));
        }
        let mut data = vec![0.; size];
        let width = self.lut_max as usize;
        // `chunks_exact(0)` would panic; with no table entries every output pixel is 0.
        if width > 0 {
            let rows = self
                .lut_idx
                .chunks_exact(width)
                .zip(self.lut_coef.chunks_exact(width));
            for (out, (indices, coefs)) in data.iter_mut().zip(rows) {
                let mut error = 0.;
                let mut sum = 0.;
                for (&idx, &coef) in indices.iter().zip(coefs) {
                    // Unused table slots carry a zero weight.
                    if coef <= 0. || idx >= size {
                        continue;
                    }
                    // Kahan-compensated accumulation.
                    let y = src[idx] * coef - error;
                    let t = sum + y;
                    error = t - sum - y;
                    sum = t;
                }
                *out = sum;
            }
        }
        Ok(Array::with_data(
            self.dim1 as usize,
            self.dim2 as usize,
            data,
        ))
    }

    /// Collects the distorted pixel-corner grid and sizes the integration buffer.
    ///
    /// `spline_x` / `spline_y` are read row-major over the `dim11 x dim22` corner grid.
    /// The returned vector stores two values per corner: `spline_y` first, then
    /// `spline_x`. As a side effect, `delta0` / `delta1` receive the largest integer span
    /// (`ceil` of a corner minus `floor` of its predecessor) between vertically /
    /// horizontally adjacent corners.
    fn calc_corners(&mut self, spline_x: &[f64], spline_y: &[f64]) -> Vec<f64> {
        let rows = self.dim11 as usize;
        let cols = self.dim22 as usize;
        let mut corners = Vec::with_capacity(2 * self.size12 as usize);
        let mut span0 = 0.;
        let mut span1 = 0.;
        for row in 0..rows {
            for col in 0..cols {
                let here = row * cols + col;
                let (y, x) = (spline_y[here], spline_x[here]);
                corners.push(y);
                corners.push(x);
                if row > 0 {
                    let dy = y.ceil() - spline_y[here - cols].floor();
                    if dy > span0 {
                        span0 = dy;
                    }
                }
                if col > 0 {
                    let dx = x.ceil() - spline_x[here - 1].floor();
                    if dx > span1 {
                        span1 = dx;
                    }
                }
            }
        }
        self.delta0 = span0 as isize;
        self.delta1 = span1 as isize;
        corners
    }

    /// Gathers, for every pixel `(i, j)`, the distorted positions of its four corners
    /// from the `dim11 x dim22` corner grid produced by [Distortion::calc_corners].
    ///
    /// The result holds 8 values per pixel in row-major pixel order. The corners are
    /// stored in the order A = `(i, j)`, B = `(i, j + 1)`, C = `(i + 1, j + 1)`,
    /// D = `(i + 1, j)`, each as its `(dim1, dim2)` coordinate pair.
    fn calc_pos(&self, corners: Vec<f64>) -> Vec<f64> {
        let rows = self.dim1 as usize;
        let cols = self.dim2 as usize;
        // Number of corners in one row of the corner grid.
        let stride = self.dim22 as usize;
        // Filled by `extend_from_slice`, so no element is ever left uninitialized.
        let mut pos = Vec::with_capacity(4 * 2 * rows * cols);
        for i in 0..rows {
            for j in 0..cols {
                for &(ci, cj) in &[(i, j), (i, j + 1), (i + 1, j + 1), (i + 1, j)] {
                    let c = 2 * (ci * stride + cj);
                    pos.extend_from_slice(&corners[c..c + 2]);
                }
            }
        }
        pos
    }

    /// Computes `lut_max`, the largest number of source pixels whose bounding box
    /// touches a single output pixel; this becomes the row width of the look-up table.
    ///
    /// `pos` holds 8 values per pixel (corners A, B, C, D as `(dim1, dim2)` pairs) in
    /// row-major pixel order. Each bounding box is extended by one on its upper side and
    /// clipped to the image, so the count is an upper bound on the entries
    /// [Distortion::calc_lut_table] will write.
    fn calculate_lut_size(&mut self, pos: &[f64]) {
        let pixels = self.size as usize;
        debug_assert!(pos.len() >= 8 * pixels, "pos must hold 8 values per pixel");
        let dim2 = self.dim2;
        let mut counts: Vec<isize> = vec![0; pixels];
        let mut lut_max = self.lut_max;
        for quad in pos.chunks_exact(8).take(pixels) {
            let (a0, a1, b0, b1) = (quad[0], quad[1], quad[2], quad[3]);
            let (c0, c1, d0, d1) = (quad[4], quad[5], quad[6], quad[7]);
            let min0 = clip(a0.min(b0).min(c0).min(d0).floor() as isize, 0, self.dim1);
            let min1 = clip(a1.min(b1).min(c1).min(d1).floor() as isize, 0, dim2);
            let max0 = clip(a0.max(b0).max(c0).max(d0).ceil() as isize + 1, 0, self.dim1);
            let max1 = clip(a1.max(b1).max(c1).max(d1).ceil() as isize + 1, 0, dim2);
            for k in min0..max0 {
                let row = k * dim2;
                for l in min1..max1 {
                    let count = &mut counts[(row + l) as usize];
                    *count += 1;
                    lut_max = lut_max.max(*count);
                }
            }
        }
        self.lut_max = lut_max;
    }

    /// Fills the look-up table (`lut_idx` / `lut_coef`), mapping every output pixel to
    /// the source pixels overlapping it.
    ///
    /// For each source pixel, the distorted quadrilateral ABCD taken from `pos` is
    /// shifted into a local box whose origin is the floor of its minimum corner. Its four
    /// edges are integrated into `buffer`, and every cell value is divided by the
    /// quadrilateral's signed area. Each positive, finite fraction becomes one table
    /// entry of the output pixel that the cell falls on.
    ///
    /// Requires `lut_max` to have been computed by [Distortion::calculate_lut_size] from
    /// the same `pos`, so that no output pixel receives more than `lut_max` entries.
    fn calc_lut_table(&mut self, pos: Vec<f64>) {
        let size = (self.size * self.lut_max) as usize;
        self.lut_idx = vec![0; size];
        self.lut_coef = vec![0.; size];
        let mut buffer = self.get_buffer();
        // Number of entries already written in each output pixel's table row.
        let mut out_max: Vec<isize> = vec![0; self.size as usize];
        // Stride of one image row in `pos` (8 values per pixel).
        let dd = self.dim2 * 4 * 2;
        // Stride of one image row in the table (`lut_max` entries per pixel).
        let ddd = self.dim2 * self.lut_max;
        // Linear index (`i * dim2 + j`) of the current source pixel.
        let mut idx = 0;
        for i in 0..self.dim1 {
            let ii = i * dd;
            for j in 0..self.dim2 {
                let jj = (ii + j * 4 * 2) as usize;
                let mut a0 = pos[jj];
                let mut a1 = pos[jj + 1];
                let mut b0 = pos[jj + 2];
                let mut b1 = pos[jj + 3];
                let mut c0 = pos[jj + 4];
                let mut c1 = pos[jj + 5];
                let mut d0 = pos[jj + 6];
                let mut d1 = pos[jj + 7];
                let o0 = a0.min(b0).min(c0).min(d0).floor();
                let o1 = a1.min(b1).min(c1).min(d1).floor();
                let offset0 = o0 as isize;
                let offset1 = o1 as isize;
                let box_size0 = a0.max(b0).max(c0).max(d0).ceil() as isize - offset0;
                let box_size1 = a1.max(b1).max(c1).max(d1).ceil() as isize - offset1;
                // Move the quadrilateral into local box coordinates.
                a0 -= o0;
                a1 -= o1;
                b0 -= o0;
                b1 -= o1;
                c0 -= o0;
                c1 -= o1;
                d0 -= o0;
                d1 -= o1;
                let (pab, cab) = p_diff(a0, b0, a1, b1);
                let (pbc, cbc) = p_diff(b0, c0, b1, c1);
                let (pcd, ccd) = p_diff(c0, d0, c1, d1);
                let (pda, cda) = p_diff(d0, a0, d1, a1);
                buffer = self.clear_buffer(box_size0, box_size1, buffer);
                self.integrate(&mut buffer, b0, a0, pab, cab);
                self.integrate(&mut buffer, a0, d0, pda, cda);
                self.integrate(&mut buffer, d0, c0, pcd, ccd);
                self.integrate(&mut buffer, c0, b0, pbc, cbc);
                // Signed area of ABCD: half the cross product of its diagonals AC and BD.
                let area = 0.5 * ((c0 - a0) * (d1 - b1) - (c1 - a1) * (d0 - b0));
                for ms in 0..box_size0 {
                    let ml = ms + offset0;
                    if ml < 0 || ml >= self.dim1 {
                        continue;
                    }
                    // Read after `clear_buffer`, which may have changed the row stride.
                    let mms = ms * self.delta1;
                    let mml = ml * self.dim2;
                    let mld = ml * ddd;
                    for ns in 0..box_size1 {
                        let nl = ns + offset1;
                        if nl < 0 || nl >= self.dim2 {
                            continue;
                        }
                        let value = buffer[(mms + ns) as usize] / area;
                        // A degenerate (zero-area) quadrilateral yields NaN or infinity;
                        // a plain `value <= 0.` test would let those into the table.
                        if !(value > 0. && value.is_finite()) {
                            continue;
                        }
                        let m = (mml + nl) as usize;
                        let k = out_max[m];
                        debug_assert!(k < self.lut_max, "LUT row overflow for output pixel {}", m);
                        let n = (mld + nl * self.lut_max + k) as usize;
                        self.lut_idx[n] = idx;
                        self.lut_coef[n] = value;
                        out_max[m] = k + 1;
                    }
                }
                idx += 1;
            }
        }
    }

    /// Allocates a zero-filled scratch buffer of `delta0 * delta1` cells, addressed
    /// as `row * delta1 + column`.
    ///
    /// The previous `set_len` on an uninitialized allocation violated `Vec::set_len`'s
    /// safety contract; zero-filling costs nothing meaningful here.
    fn get_buffer(&self) -> Vec<f64> {
        vec![0.; (self.delta1 * self.delta0) as usize]
    }

    /// Returns a zero-filled scratch buffer able to hold a `delta0 x delta1` box.
    ///
    /// The given buffer is reused when the stored dimensions already cover the box.
    /// Otherwise the stored dimensions grow to the component-wise maximum and a fresh
    /// zeroed buffer is allocated. They are never shrunk, so alternating tall and wide
    /// boxes do not force a reallocation each time. `self.delta1` is the row stride
    /// afterwards, so callers must read it only after this call.
    fn clear_buffer(&mut self, delta0: isize, delta1: isize, mut buffer: Vec<f64>) -> Vec<f64> {
        if delta0 > self.delta0 || delta1 > self.delta1 {
            self.delta0 = self.delta0.max(delta0);
            self.delta1 = self.delta1.max(delta1);
            return vec![0.; (self.delta0 * self.delta1) as usize];
        }
        for v in buffer.iter_mut() {
            *v = 0.;
        }
        buffer
    }

    /// Accumulates into `buffer` the signed area between the line
    /// `y = slope * x + intercept` and the `y = 0` axis, for `x` running from `start`
    /// to `stop`.
    ///
    /// `x` is the first (dim1) coordinate and selects the buffer row (stride
    /// `self.delta1`). The area of each unit slice of `x` is spread along that row,
    /// starting at column 0, by [Distortion::deposit]. The contribution's sign follows
    /// `calc_area`, which flips when the edge runs backwards (`start > stop`). Nothing
    /// is added when `start == stop`.
    fn integrate(&self, buffer: &mut [f64], start: f64, stop: f64, slope: f64, intercept: f64) {
        if start < stop {
            let p = start.ceil();
            if p > stop {
                // `start` and `stop` lie within the same unit row.
                let a = calc_area(start, stop, slope, intercept);
                self.deposit(buffer, start.floor() as isize, stop - start, a);
            } else {
                // Partial row from `start` up to the next integer.
                let dp = p - start;
                if dp > 0. {
                    let a = calc_area(start, p, slope, intercept);
                    self.deposit(buffer, p.floor() as isize - 1, dp, a);
                }
                // Whole rows between the two partial ends.
                for i in p.floor() as isize..stop.floor() as isize {
                    let a = calc_area(i as f64, i as f64 + 1., slope, intercept);
                    self.deposit(buffer, i, 1.0, a);
                }
                // Partial row from the last integer up to `stop`.
                let p = stop.floor();
                let dp = stop - p;
                if dp > 0. {
                    let a = calc_area(p, stop, slope, intercept);
                    self.deposit(buffer, p.floor() as isize, dp.abs(), a);
                }
            }
        } else if start > stop {
            let p = start.floor();
            if stop > p {
                // `start` and `stop` lie within the same unit row.
                let a = calc_area(start, stop, slope, intercept);
                self.deposit(buffer, start.floor() as isize, start - stop, a);
            } else {
                // Partial row from `start` down to the previous integer.
                let dp = p - start;
                if dp < 0. {
                    let a = calc_area(start, p, slope, intercept);
                    self.deposit(buffer, p.floor() as isize, dp.abs(), a);
                }
                // Whole rows walked downwards: i = top, top - 1, ..., bottom + 1.
                let top = start as isize;
                let bottom = stop.ceil() as isize;
                for i in (bottom + 1..=top).rev() {
                    let a = calc_area(i as f64, i as f64 - 1., slope, intercept);
                    self.deposit(buffer, i - 1, 1., a);
                }
                // Partial row from the last integer down to `stop`.
                let p = stop.ceil();
                let dp = stop - p;
                if dp < 0. {
                    let a = calc_area(p, stop, slope, intercept);
                    self.deposit(buffer, stop.floor() as isize, dp.abs(), a);
                }
            }
        }
    }

    /// Spreads `area` over consecutive cells of buffer row `row`, starting at column 0.
    ///
    /// Every cell receives `width` except the last, which receives the remainder; all
    /// contributions carry the sign of `area`. Does nothing when `area` is zero or NaN.
    fn deposit(&self, buffer: &mut [f64], row: isize, width: f64, area: f64) {
        if area == 0. {
            return;
        }
        let mut remaining = area.abs();
        let sign = area / remaining;
        let mut step = width;
        let mut h = 0;
        while remaining > 0. {
            if step > remaining {
                step = remaining;
                // Sentinel: the loop ends after this final partial cell.
                remaining = -1.;
            }
            buffer[(row * self.delta1 + h) as usize] += sign * step;
            remaining -= step;
            h += 1;
        }
    }
}

/// Signed area under `y = slope * x + intercept` between `x = start` and `x = stop`
/// (trapezoid rule, exact for a line). It is negative when `stop < start`.
fn calc_area(start: f64, stop: f64, slope: f64, intercept: f64) -> f64 {
    let dx = stop - start;
    let sum_of_heights = slope * (stop + start) + 2.0 * intercept;
    0.5 * dx * sum_of_heights
}

/// Returns `(slope, intercept)` of the line through the points `(x0, x1)` and
/// `(y0, y1)`, with the second coordinate expressed as a function of the first.
///
/// When both points share the first coordinate (`x0 == y0`), `(0, 0)` is returned.
fn p_diff(x0: f64, y0: f64, x1: f64, y1: f64) -> (f64, f64) {
    if y0 == x0 {
        return (0., 0.);
    }
    let slope = (y1 - x1) / (y0 - x0);
    let intercept = x1 - slope * x0;
    (slope, intercept)
}

/// Clamps `value` into `[min, max]`.
///
/// The lower bound is tested first, so if `min > max` a value below `min` yields
/// `min` and any other value yields `max`.
fn clip(value: isize, min: isize, max: isize) -> isize {
    if value < min {
        return min;
    }
    value.min(max)
}

#[cfg(test)]
mod tests {
    use std::fs::File;
    use std::io;
    use std::io::BufRead;

    use cryiorust::frame::{Array, Frame};

    use crate::distortion::Distortion;
    use crate::spline::Spline;
    use cryiorust::cbf::Cbf;

    /// Loads the reference corrected 40x50 image (one value per line).
    fn get_corrected() -> Array {
        let file = File::open("testdata/corrected.dat").unwrap();
        let data: Vec<f64> = io::BufReader::new(file)
            .lines()
            .map(|line| line.unwrap().parse::<f64>().unwrap())
            .collect();
        Array::with_data(40, 50, data)
    }

    /// Builds a synthetic frame whose pixel values equal their linear index.
    fn get_frame(dim1: usize, dim2: usize) -> Cbf {
        let ramp: Vec<f64> = (0..dim1 * dim2).map(|i| i as f64).collect();
        let mut cbf = Cbf::new();
        cbf.set_array(Array::with_data(dim1, dim2, ramp));
        cbf
    }

    #[test]
    fn test_distortion_correction() {
        let (dim1, dim2) = (40, 50);
        let frame = get_frame(dim1, dim2);
        let mut spline = Spline::open("testdata/F21newEO.spline").unwrap();
        spline.calculate(frame.array());
        let mut distortion = Distortion::new();
        distortion.init(frame.array(), &spline);
        assert_eq!(distortion.lut_max, 15);
        let expected = get_corrected();
        let actual = distortion.correct(frame.array()).unwrap();
        assert_eq!(expected.dim1(), actual.dim1());
        assert_eq!(expected.dim2(), actual.dim2());
        assert_eq!(expected.data().len(), actual.data().len());
        for (i, (a, e)) in actual.data().iter().zip(expected.data()).enumerate() {
            let diff = (*a - *e).abs();
            assert!(diff < 1e-6, "image[{}]: expected = {}; actual = {}", i, *e, *a);
        }
    }
}
