use crate::{detail::QuantumSimImpl, parallel_chunk_size, OpBuffer, QuantumSim};

use ndarray::Array2;
use num_complex::Complex64;
use num_traits::{One, Zero};
use rayon::prelude::*;
use rustc_hash::FxHashMap;
use std::mem::MaybeUninit;

/// State representation used by the full state simulator: a dense vector of complex amplitudes,
/// indexed by computational basis state, for the currently allocated qubits.
pub type FullState = Vec<Complex64>;

/// Full state quantum simulation using a vector of all possible quantum states for the currently allocated qubits.
/// This simulator is very performant with states using smaller amounts of qubits, including dense quantum states,
/// but uses exponentially more memory for each additional qubit, limiting total size of simulation.
pub type FullStateQuantumSim = QuantumSim<FullState>;

/// Utility for performing a parallel, nearly in-place Kronecker product of an identity matrix and the
/// given unitary followed by the dot product with the state vector to produce a new state vector.
/// This uses several assumptions. First, it assumes the targets for the operation have all been swapped
/// into the right-most part of the state vector. Second, the given operation is the only matrix being
/// applied to the vector such that the left part of the Kronecker product is known to be identity.
///
/// Concretely, `state` is treated as consecutive sub-vectors of `op.nrows()` entries, and each
/// sub-vector is replaced by `op` multiplied with that sub-vector.
///
/// # Panics
///
/// This function will panic if it is unable to allocate enough space to hold the new state vector,
/// if `op` has no rows, or if the length of `state` is not a multiple of `op.nrows()`.
fn krondot(op: &Array2<Complex64>, state: &[Complex64]) -> Vec<Complex64> {
    let op_size = op.nrows();
    // Every output entry must be written before the buffer is treated as initialized, which is only
    // possible when the state splits evenly into operator-sized sub-vectors.
    assert!(op_size > 0, "operator must have at least one row");
    assert_eq!(
        state.len() % op_size,
        0,
        "state length must be a multiple of the operator size"
    );

    // Round the parallel chunk size up to a whole number of operator-sized sub-vectors so that no
    // chunk (including the trailing, possibly shorter one) cuts a sub-vector in half.
    let base_chunk = parallel_chunk_size(state.len()).max(op_size);
    let chunk_size = match base_chunk % op_size {
        0 => base_chunk,
        rem => base_chunk + (op_size - rem),
    };

    // Uninitialized output buffer: filling with `MaybeUninit::uninit` avoids both a zeroing pass and
    // an unsafe `set_len`.
    let mut new_state: Vec<MaybeUninit<Complex64>> = Vec::with_capacity(state.len());
    new_state.resize_with(state.len(), MaybeUninit::uninit);

    // `par_chunks`/`par_chunks_mut` (unlike the `_exact` variants) also visit the trailing partial
    // chunk, so the whole vector is covered.
    state
        .par_chunks(chunk_size)
        .zip(new_state.par_chunks_mut(chunk_size))
        .for_each(|(state_chunk, new_state_chunk)| {
            state_chunk
                .chunks_exact(op_size)
                .zip(new_state_chunk.chunks_exact_mut(op_size))
                .for_each(|(state_op_chunk, new_state_op_chunk)| {
                    op.rows()
                        .into_iter()
                        .zip(new_state_op_chunk.iter_mut())
                        .for_each(|(row, entry)| {
                            let value = row
                                .iter()
                                .zip(state_op_chunk)
                                .fold(Complex64::zero(), |accum, (c1, c2)| accum + c1 * c2);
                            *entry = MaybeUninit::new(value);
                        });
                });
        });

    // SAFETY: the asserts and chunk rounding above guarantee that every chunk is a whole number of
    // `op_size`-long sub-vectors, and `op.rows()` yields `op_size` rows, one per sub-vector entry, so
    // every element of `new_state` has been written. `MaybeUninit<Complex64>` has the same size and
    // alignment as `Complex64`, so the allocation can be reinterpreted in place. Rebuilding the `Vec`
    // with `from_raw_parts` avoids transmuting the `Vec` itself, whose layout is unspecified.
    let mut new_state = std::mem::ManuallyDrop::new(new_state);
    unsafe {
        Vec::from_raw_parts(
            new_state.as_mut_ptr().cast::<Complex64>(),
            new_state.len(),
            new_state.capacity(),
        )
    }
}

impl Default for QuantumSim<FullState> {
    fn default() -> Self {
        Self::new()
    }
}

impl QuantumSim<FullState> {
    /// Creates a new full state quantum simulator object with empty initial state (no qubits allocated, no operations buffered).
    /// Using the defined types, this can be abbreviated to:
    ///
    /// ```
    /// use qqs::*;
    ///
    /// let mut sim = FullStateQuantumSim::new();
    /// ```
    #[must_use]
    pub fn new() -> Self {
        // Nothing is buffered yet: no targets and an empty 0x0 operator.
        let op_buffer = OpBuffer {
            targets: Vec::new(),
            ops: Array2::default((0, 0)),
        };

        // With zero qubits the state vector holds a single amplitude of 1.
        Self {
            state: vec![Complex64::one()],
            id_map: FxHashMap::default(),
            op_buffer,
        }
    }
}

impl QuantumSimImpl for QuantumSim<FullState> {
    /// Utility that extends the internal state to make room for a newly allocated qubit.
    ///
    /// The state vector doubles in length and the new upper half is zero-filled, so the new qubit
    /// occupies the most significant index bit and starts out in |0⟩.
    /// # Panics
    ///
    /// The function will panic if it is unable to allocate enough memory to expand the state vector.
    fn extend_state(&mut self) {
        // Double the size of existing state.
        self.state
            .resize_with(self.state.len() * 2, Complex64::zero);
    }

    /// Utility function that will swap states of two qubits throughout the state vector.
    /// # Panics
    ///
    /// The function will panic if the state vector is too big to describe with an isize.
    fn swap_qubits(&mut self, qubit1: usize, qubit2: usize) {
        if qubit1 == qubit2 {
            return;
        }
        let (q1, q2) = if qubit1 > qubit2 {
            (qubit2, qubit1)
        } else {
            (qubit1, qubit2)
        };
        let offset1 = 1_isize << q1;
        let offset2 = 1_isize << q2;
        // `index` below enumerates `len / 4` values. These masks split it into its low `q1` bits (k),
        // the next `q2 - q1 - 1` bits (j) and the remaining high bits (i). Shifting j by one and i by
        // two opens zero bits at positions `q1` and `q2`.
        let mask_k = offset1 - 1;
        let mask_j = ((offset2 >> 1) - 1) ^ mask_k;
        let mask_i = !((offset2 >> 1) - 1);

        // Unsafe trick to allow parallel access to the state vector below.
        let state_ptr = self.state.as_mut_ptr() as usize;

        // In parallel, swap entries in the state vector to correspond to the swapping of two qubits'
        // locations.
        let chunk_size = parallel_chunk_size(self.state.len() / 4).max(1);
        (0_isize..(self.state.len() / 4).try_into().unwrap())
            .into_par_iter()
            .chunks(chunk_size)
            .for_each(|indices| {
                for index in indices {
                    let k = index & mask_k;
                    let j = (index & mask_j) << 1;
                    let i = (index & mask_i) << 2;
                    let ptr = state_ptr as *mut Complex64;
                    // SAFETY: `i + j + k` has zero bits at `q1` and `q2`. Assuming a power-of-two
                    // length and valid qubit locations (`q2 < log2(len)`), adding either offset stays
                    // below `len`. Distinct `index` values produce disjoint pairs, so the parallel
                    // swaps never touch the same element.
                    unsafe {
                        std::ptr::swap(
                            ptr.offset(i + j + k + offset1),
                            ptr.offset(i + j + k + offset2),
                        );
                    };
                }
            });
    }

    /// Utility that cleans up state at the last location based on the given boolean value.
    ///
    /// Keeps the half of the state vector whose most significant index bit matches `res` (the upper
    /// half for `true`, the lower half for `false`) and discards the other half.
    fn cleanup_state(&mut self, res: bool) {
        self.state = if res {
            self.state[self.state.len() / 2..].to_vec()
        } else {
            self.state[0..self.state.len() / 2].to_vec()
        };
    }

    /// Utility function that performs the actual output of state (and optionally map) to stdout. Can
    /// be called internally from other functions to aid in debugging and does not perform any sorting of
    /// the state vector.
    fn dump_impl(&self, print_id_map: bool) {
        if print_id_map {
            println!("MAP: {:?}", self.id_map);
        };
        print!("STATE: [ ");
        // Only non-zero amplitudes are printed, labelled by their basis-state index.
        self.state.iter().enumerate().for_each(|(index, val)| {
            if !val.is_zero() {
                print!("|{}\u{27e9}: {}, ", index, val);
            }
        });
        println!("]");
    }

    /// Utility that actually performs the application of the buffered unitary to the targets within the
    /// state vector.
    fn apply_impl(&mut self) {
        // Perform the Kronecker product and dot product to produce the new state vector.
        self.state = krondot(&self.op_buffer.ops, &self.state);
    }

    /// Utility to get the sum of all probabilities in the state vector where the given location bit is set.
    /// This corresponds to the probability that the qubit at that location is measured as |1⟩ in the
    /// computational basis.
    fn check_probability(&self, loc: usize) -> f64 {
        let offset = 1_usize << loc;
        let mask = !(offset - 1);
        let chunk_size = parallel_chunk_size(self.state.len() / 2).max(1);
        // `index` enumerates half of the state vector. Inserting a set bit at position `loc` maps each
        // value to a distinct state-vector entry whose `loc` bit is 1.
        (0..(self.state.len() / 2))
            .into_par_iter()
            .chunks(chunk_size)
            .fold(
                || 0.0_f64,
                |prob, indices| {
                    prob + indices.iter().fold(0.0_f64, |accum, index| {
                        let entry = (index % offset) + (offset + ((index & mask) << 1));
                        // SAFETY: for `index < len / 2`, and assuming `loc` is a valid location in a
                        // power-of-two length state (`1 << loc < len`), `entry` is below `len`.
                        accum + unsafe { self.state.get_unchecked(entry).norm_sqr() }
                    })
                },
            )
            .sum()
    }

    /// Utility to get the sum of all probabilities where an odd number of the bits at the given locations
    /// are set. This corresponds to the probability of jointly measuring those qubits in the computational
    /// basis.
    fn check_joint_probability(&self, locs: &[usize]) -> f64 {
        let mask = locs.iter().fold(0_usize, |accum, loc| accum | (1 << loc));
        let chunk_size = parallel_chunk_size(self.state.len()).max(1);
        (0..self.state.len())
            .into_par_iter()
            .chunks(chunk_size)
            .fold(
                || 0.0_f64,
                |prob, indices| {
                    prob + indices.iter().fold(0.0_f64, |accum, &index| {
                        accum
                            + if (index & mask).count_ones() & 1 > 0 {
                                // SAFETY: `index` comes from `0..self.state.len()`.
                                unsafe { self.state.get_unchecked(index).norm_sqr() }
                            } else {
                                0.0
                            }
                    })
                },
            )
            .sum()
    }

    /// Utility to perform the normalize of the state.
    ///
    /// Every amplitude is divided by the Euclidean norm of the whole state vector.
    fn normalize(&mut self) {
        // `par_chunks`/`par_chunks_mut` (unlike the `_exact` variants) also visit the trailing
        // partial chunk. Every entry therefore contributes to the norm and is rescaled, even when
        // the length is not a multiple of the chunk size or is smaller than it.
        let chunk_size = parallel_chunk_size(self.state.len()).max(1);
        let norm = self
            .state
            .par_chunks(chunk_size)
            .map(|chunk| chunk.iter().map(|val| val.norm_sqr()).sum::<f64>())
            .sum::<f64>()
            .sqrt();
        let scale = 1.0 / norm;
        self.state
            .par_chunks_mut(chunk_size)
            .for_each(|chunk| chunk.iter_mut().for_each(|c| *c *= scale));
    }

    /// Utility to collapse the probability at the given location based on the boolean value. This means
    /// that if the given value is 'true' then all indices in the state vector where the given location
    /// has a zero bit will be reduced to zero. Then the state vector is normalized.
    fn collapse(&mut self, loc: usize, val: bool) {
        let mask = 1_usize << loc;
        let state_val = if val { mask } else { 0 };
        let chunk_size = parallel_chunk_size(self.state.len()).max(1);

        // Set all entries that don't match the mask to zero. Each task owns a disjoint mutable chunk,
        // so no raw-pointer sharing is needed; an entry's index in the full vector is recovered from
        // the chunk number.
        self.state
            .par_chunks_mut(chunk_size)
            .enumerate()
            .for_each(|(chunk_index, chunk)| {
                let base = chunk_index * chunk_size;
                for (offset, entry) in chunk.iter_mut().enumerate() {
                    if ((base + offset) & mask) != state_val {
                        *entry = Complex64::zero();
                    }
                }
            });

        // Re-normalize the state vector.
        self.normalize();
    }

    /// Utility to collapse the joint probability of a particular set of locations in the state vector.
    /// The entries that do not correspond to the given boolean value are set to zero, and then the whole
    /// state is normalized.
    fn joint_collapse(&mut self, locs: &[usize], val: bool) {
        let mask = locs.iter().fold(0_usize, |accum, loc| accum | (1 << loc));
        let chunk_size = parallel_chunk_size(self.state.len()).max(1);

        // Zero every entry whose parity over the selected bits disagrees with `val`. Each task owns a
        // disjoint mutable chunk; the global index is recovered from the chunk number.
        self.state
            .par_chunks_mut(chunk_size)
            .enumerate()
            .for_each(|(chunk_index, chunk)| {
                let base = chunk_index * chunk_size;
                for (offset, entry) in chunk.iter_mut().enumerate() {
                    let odd_parity = ((base + offset) & mask).count_ones() & 1 > 0;
                    if odd_parity != val {
                        *entry = Complex64::zero();
                    }
                }
            });

        // Re-normalize the state vector.
        self.normalize();
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::common_matrices::{h, x};
    use ndarray::array;
    use num_traits::One;

    /// Checks `krondot` against a reference built from `ndarray::linalg::kron` followed by a
    /// matrix-vector `dot`.
    #[test]
    fn test_krondot() {
        // One qubit: the identity factor is 1x1, so this is X applied directly to |0⟩.
        let single_qubit = [Complex64::one(), Complex64::zero()];
        let reference = ndarray::linalg::kron(&Array2::eye(2_usize.pow(0)), &x())
            .dot(&(array![Complex64::one(), Complex64::zero()]).view())
            .to_vec();
        assert_eq!(reference, krondot(&x(), &single_qubit));

        // Two qubits: H on the low-order qubit with identity on the other, applied to |00⟩.
        let two_qubits = [
            Complex64::one(),
            Complex64::zero(),
            Complex64::zero(),
            Complex64::zero(),
        ];
        let reference = ndarray::linalg::kron(&Array2::eye(2_usize.pow(1)), &h())
            .dot(
                &(array![
                    Complex64::one(),
                    Complex64::zero(),
                    Complex64::zero(),
                    Complex64::zero()
                ])
                .view(),
            )
            .to_vec();
        assert_eq!(reference, krondot(&h(), &two_qubits));
    }
}
