/// Implementations of the `HierCurlBasisFnSpace` Trait (`MaxOrthoShapeFn` can be added as a Feature on the Nightly Toolchain)
pub mod hierarchical_basis_fns;

use super::domain::mesh::{
    elem::Elem,
    space::{M2D, V2D},
};
use crate::fem_problem::integration::glq::{gauss_quadrature_points, scale_gauss_quad_points};

use std::collections::HashMap;
use std::sync::{Arc, Mutex};

/// A Trait to define the functional space used to compose a [HierCurlBasisFn]
///
/// Implementors define the T (tangential) and N (normal) shape-function families over a given set of `points` (which fall within the range [-1, 1]) up to a given expansion order.
///
/// The first derivative of each family must also be defined over the same set of points. The second derivative is only required when it is requested via `compute_d2`.
///
/// Points are addressed by index: the `p` argument of every sampling method below is an index into the `points` slice that was passed to [`with`](HierCurlBasisFnSpace::with).
///
pub trait HierCurlBasisFnSpace: Clone + Sync + Send + std::fmt::Debug {
    /// Define a set of shapes (or functions) over a set of points
    ///
    /// # Arguments
    /// * `max_order` - The order up to which the shape functions must be defined (corresponds to the `n` argument in the methods below)
    /// * `points` - The set of points to evaluate the shape functions over
    /// * `compute_d2` - Whether to compute the second derivative of the shape functions
    ///
    /// If `compute_d2` is `false`, callers must not rely on `norm_d2` or `tang_d2`.
    ///
    fn with(max_order: usize, points: &[f64], compute_d2: bool) -> Self;

    /// Sample the `n`th order normally oriented shape function at point index `p`
    fn norm(&self, n: usize, p: usize) -> f64;
    /// Sample the 1st derivative of the `n`th order normally oriented shape function at point index `p`
    fn norm_d1(&self, n: usize, p: usize) -> f64;
    /// Sample the 2nd derivative of the `n`th order normally oriented shape function at point index `p`.
    /// Implementations are not required to support this call if `compute_d2` was `false`.
    fn norm_d2(&self, n: usize, p: usize) -> f64;

    /// Sample the `n`th order tangentially oriented shape function at point index `p`
    fn tang(&self, n: usize, p: usize) -> f64;
    /// Sample the 1st derivative of the `n`th order tangentially oriented shape function at point index `p`
    fn tang_d1(&self, n: usize, p: usize) -> f64;
    /// Sample the 2nd derivative of the `n`th order tangentially oriented shape function at point index `p`.
    /// Implementations are not required to support this call if `compute_d2` was `false`.
    fn tang_d2(&self, n: usize, p: usize) -> f64;
}

/// A utility structure used to generate and cache [HierBasisFn]s during Integration
///
/// Three important areas of functionality are handled here:
/// * Convert `BasisSpec`s (abstract basis function definitions on an [Elem]) into [HierBasisFn]s (numerical instantiations of basis functions for integration)
/// * Handle interpolation between parametric spaces to set up inter-layer integration
/// * Cache computed [HierBasisFn]s to avoid re-computation.
///
/// Cloning a sampler clones the `Arc` around its cache, so all clones share the same cache.
///
pub struct BasisFnSampler<B: HierBasisFn> {
    /// Maximum u-directed expansion order. [HierBasisFn]s generated by this sampler will be defined up to this order in `i`
    pub i_max: usize,
    /// Maximum v-directed expansion order. [HierBasisFn]s generated by this sampler will be defined up to this order in `j`
    pub j_max: usize,
    /// Whether the 2nd derivatives of the [HierBasisFn]s will be computed. If true, the endpoints [-1, +1] are also defined along both axes!
    pub compute_d2: bool,
    /// Gauss Legendre Quadrature points evaluated along u-direction. Defined from (-1 to +1)
    u_points: Vec<f64>,
    /// Gauss Legendre Quadrature points evaluated along v-direction. Defined from (-1 to +1)
    v_points: Vec<f64>,

    /// Cache of generated basis functions, keyed by the element (and optional descendant) they were sampled over.
    /// Guarded by a `Mutex` so it can be shared between clones of the sampler.
    computed: Arc<Mutex<HashMap<BSDescription, Arc<B>>>>,
}

impl<B: HierBasisFn> BasisFnSampler<B> {
    /// Construct a Basis Function Sampler with the following parameters:
    ///
    /// * `i_max` : maximum u-directed expansion order
    /// * `j_max` : maximum v-directed expansion order
    /// * `num_u_points` : number of Gauss-Leg-Quad points to define along the u-axis. If `None`: a default derived from `i_max` is used (see `default_ngq`)
    /// * `num_v_points` : number of Gauss-Leg-Quad points to define along the v-axis. If `None`: a default derived from `j_max` is used (see `default_ngq`)
    /// * `compute_2nd_derivs`: whether or not to compute the second derivatives of the [HierBasisFn]s. This flag is also forwarded to `gauss_quadrature_points`.
    ///
    /// # Returns
    /// The sampler, together with the quadrature weights `[u_weights, v_weights]` that accompany its u- and v-points.
    pub fn with(
        i_max: usize,
        j_max: usize,
        num_u_points: Option<usize>,
        num_v_points: Option<usize>,
        compute_2nd_derivs: bool,
    ) -> (Self, [Vec<f64>; 2]) {
        let (u_points, u_weights) = gauss_quadrature_points(
            num_u_points.unwrap_or_else(|| default_ngq(i_max)),
            compute_2nd_derivs,
        );
        let (v_points, v_weights) = gauss_quadrature_points(
            num_v_points.unwrap_or_else(|| default_ngq(j_max)),
            compute_2nd_derivs,
        );

        (
            Self {
                i_max,
                j_max,
                compute_d2: compute_2nd_derivs,
                u_points,
                v_points,
                computed: Arc::new(Mutex::new(HashMap::new())),
            },
            [u_weights, v_weights],
        )
    }

    /// Generate or retrieve a [HierBasisFn] defined over an [Elem]. Can be defined over a subset of the `Elem`.
    ///
    /// Results are cached (keyed on `elem` and `over_desc_elem`) in a cache shared by every clone of
    /// this sampler, so repeated requests return clones of the same `Arc`.
    ///
    /// The cache lock is held while a missing basis function is computed, so concurrent cache misses
    /// are serialized. `&mut self` is not strictly required (the cache uses interior mutability) but
    /// is kept for API compatibility.
    pub fn sample_basis_fn(&mut self, elem: &Elem, over_desc_elem: Option<&Elem>) -> Arc<B> {
        let desc = BSDescription::new(elem, over_desc_elem);

        // A poisoned lock only means another thread panicked while holding it. Entries are inserted
        // only after `B::defined_over` has returned, so the map is still consistent; keep using it
        // instead of silently disabling caching for the rest of the run.
        let mut cache = match self.computed.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };

        // Single hash lookup: the basis function is only computed on a cache miss.
        let basis_fn = cache
            .entry(desc)
            .or_insert_with(|| Arc::new(self.compute_basis_fn(elem, over_desc_elem)));
        Arc::clone(basis_fn)
    }

    /// Build a fresh (uncached) [HierBasisFn] using this sampler's points, orders and `compute_d2` setting.
    fn compute_basis_fn(&self, elem: &Elem, over_desc_elem: Option<&Elem>) -> B {
        B::defined_over(
            elem,
            over_desc_elem,
            [&self.u_points, &self.v_points],
            [self.i_max, self.j_max],
            self.compute_d2,
        )
    }
}

impl<B: HierBasisFn> Clone for BasisFnSampler<B> {
    /// Clone the sampler's configuration and quadrature points.
    ///
    /// The cache is *shared*, not copied: only the `Arc` handle is cloned. This impl is written by
    /// hand (rather than derived) so that `B` itself does not need to implement `Clone`.
    fn clone(&self) -> Self {
        // Exhaustive destructuring: adding a field to the struct forces this impl to be updated.
        let Self {
            i_max,
            j_max,
            compute_d2,
            u_points,
            v_points,
            computed,
        } = self;

        Self {
            i_max: *i_max,
            j_max: *j_max,
            compute_d2: *compute_d2,
            u_points: u_points.clone(),
            v_points: v_points.clone(),
            computed: Arc::clone(computed),
        }
    }
}

#[derive(Hash, PartialEq, Eq, Clone, Debug)]
/// Unique description of a basis sample, used as the cache key in [BasisFnSampler].
// TODO: update this struct to be more robust to curvilinear Elements
struct BSDescription {
    /// Entries 0 and 3 of the base element's `nodes`
    space: [usize; 2],
    /// Entries 0 and 3 of the descendant element's `nodes`, if the sample is over a descendant
    sample: Option<[usize; 2]>,
    /// `id` of the base element
    base_id: usize,
    /// `id` of the descendant element, if the sample is over a descendant
    desc_id: Option<usize>,
}

impl BSDescription {
    /// Build the cache key for `elem`, optionally sub-sampled over the descendant `sampled_over`.
    ///
    /// For each element involved, the key records entries 0 and 3 of its `nodes` plus its `id`.
    pub fn new(elem: &Elem, sampled_over: Option<&Elem>) -> Self {
        let corners = |e: &Elem| [e.nodes[0], e.nodes[3]];
        let ident = |e: &Elem| e.id;

        Self {
            space: corners(elem),
            sample: sampled_over.map(corners),
            base_id: ident(elem),
            desc_id: sampled_over.map(ident),
        }
    }
}

/// Default number of Gauss-Legendre quadrature points for a maximum expansion order:
/// `4 * max_order`, rounded up to the nearest power of 2.
///
/// Integer arithmetic keeps the result exact for every input; a floating-point `log2().ceil()`
/// can round down for large values. A `max_order` of `0` yields a single point, so the
/// quadrature rule is never empty (`log2(0)` would otherwise produce zero points).
fn default_ngq(max_order: usize) -> usize {
    max_order.saturating_mul(4).next_power_of_two()
}

/// A Trait for any Hierarchical Basis Function defined over some [Elem] (and optionally sub-sampled over one of its descendant [Elem]s)
pub trait HierBasisFn {
    /// The shared basis function constructor
    ///
    /// Basis functions must be defined over individual [Elem]s in the `Domain`. Due to the RBS structure of the `Mesh`, we must also be able to sample sub-regions of basis functions over descendant [Elem]s.
    ///
    /// Basis functions are defined over a region in parametric space given by `uv_points`. Both slices in `uv_points` are expected to lie on the region (-1, 1) or [-1, 1]. The points are scaled as necessary if the basis function is to be sampled over a descendant [Elem].
    ///
    /// # Arguments
    /// * `elem` - the [Elem] the basis function is defined over
    /// * `desc_elem` - an optional descendant [Elem] to sub-sample over
    /// * `uv_points` - `[u_points, v_points]`, the sample points along each parametric axis
    /// * `ij_orders` - `[i_max, j_max]`, the maximum expansion orders along u and v
    /// * `compute_d2` - whether to also compute second derivatives
    ///
    fn defined_over(
        elem: &Elem,
        desc_elem: Option<&Elem>,
        uv_points: [&[f64]; 2],
        ij_orders: [usize; 2],
        compute_d2: bool,
    ) -> Self;
}

/// A Hierarchical-Type Curl-Conforming Vectorial Basis Function
///
/// This basis function has the following vectorial components in the u, v and w directions:
/// * `F_u(u, v, i, j) = N_i(u) * T_j(v) * J^-1_u(u, v)`
/// * `F_v(u, v, i, j) = T_i(u) * N_j(v) * J^-1_v(u, v)`
/// * `F_w(u, v, i, j) = T_i(u) * T_j(v) * J_z(u, v)`   (Not Yet Implemented)
///
/// Here N and T are the normally and tangentially directed function spaces defined by the [HierCurlBasisFnSpace]. This structure is generic over any [HierCurlBasisFnSpace].
///
/// The Jacobian is defined by the [Elem]'s mapping to real space (and by the mapping between the [Elem] and its descendant, in the case of sub-sampling).
///
#[derive(Clone, Debug)]
pub struct HierCurlBasisFn<BSpace: HierCurlBasisFnSpace> {
    /// Transformation matrices (or Jacobians) at each sample point, indexed `[m][n]` (u-point index, v-point index). Describes transformation from real space to sampled parametric space
    pub jac: Vec<Vec<M2D>>,
    /// Inverse of the transformation matrices at each sample point (same `[m][n]` indexing as `jac`)
    pub jac_inv: Vec<Vec<M2D>>,
    /// Determinants of the "Sampling Jacobian" at each point (same `[m][n]` indexing as `jac`)
    pub det_jac: Vec<Vec<f64>>,
    /// Parametric scaling factors `[u, v]` (used to scale derivatives in parametric space as necessary)
    pub para_scale: V2D,
    /// u-axis shape functions, sampled at the (possibly rescaled) u-points
    u_shapes: BSpace,
    /// v-axis shape functions, sampled at the (possibly rescaled) v-points
    v_shapes: BSpace,
}

impl<BSpace: HierCurlBasisFnSpace> HierCurlBasisFn<BSpace> {
    /// Evaluate the u-directed basis function of order `(i, j)` at sample point `(m, n)`
    ///
    /// Computes `jac_inv[m][n].u * N_i(u_m) * T_j(v_n)`.
    pub fn f_u(&self, [i, j]: [usize; 2], [m, n]: [usize; 2]) -> V2D {
        self.jac_inv[m][n].u * self.u_shapes.norm(i, m) * self.v_shapes.tang(j, n)
    }

    /// Evaluate the v-directed basis function of order `(i, j)` at sample point `(m, n)`
    ///
    /// Computes `jac_inv[m][n].v * T_i(u_m) * N_j(v_n)`.
    pub fn f_v(&self, [i, j]: [usize; 2], [m, n]: [usize; 2]) -> V2D {
        self.jac_inv[m][n].v * self.u_shapes.tang(i, m) * self.v_shapes.norm(j, n)
    }

    /// Evaluate the first derivative of the u-directed basis with respect to another `Elem`'s parametric space
    ///
    /// The shape-function part is `[N_i(u) * T_j'(v), N_i'(u) * T_j(v)]`: the v-derivative term is in the
    /// first component and the u-derivative term in the second. It is combined with `jac_inv[m][n].u` and `para_scale`.
    pub fn f_u_d1(&self, [i, j]: [usize; 2], [m, n]: [usize; 2], para_scale: &V2D) -> V2D {
        self.jac_inv[m][n].u
            * V2D::from([
                self.u_shapes.norm(i, m) * self.v_shapes.tang_d1(j, n),
                self.u_shapes.norm_d1(i, m) * self.v_shapes.tang(j, n),
            ])
            * para_scale
    }

    /// Evaluate the first derivative of the v-directed basis with respect to another `Elem`'s parametric space
    ///
    /// The shape-function part is `[T_i(u) * N_j'(v), T_i'(u) * N_j(v)]`: the v-derivative term is in the
    /// first component and the u-derivative term in the second. It is combined with `jac_inv[m][n].v` and `para_scale`.
    pub fn f_v_d1(&self, [i, j]: [usize; 2], [m, n]: [usize; 2], para_scale: &V2D) -> V2D {
        self.jac_inv[m][n].v
            * V2D::from([
                self.u_shapes.tang(i, m) * self.v_shapes.norm_d1(j, n),
                self.u_shapes.tang_d1(i, m) * self.v_shapes.norm(j, n),
            ])
            * para_scale
    }

    /// Evaluate the second derivative of the u-directed basis with respect to another `Elem`'s parametric space
    ///
    /// Requires the shape spaces to have been built with `compute_d2 == true`.
    pub fn f_u_d2(&self, [i, j]: [usize; 2], [m, n]: [usize; 2], para_scale: &V2D) -> V2D {
        self.jac_inv[m][n].u
            * V2D::from([
                self.u_shapes.norm(i, m) * self.v_shapes.tang_d2(j, n),
                self.u_shapes.norm_d2(i, m) * self.v_shapes.tang(j, n),
            ])
            * para_scale
            * para_scale
    }

    /// Evaluate the second derivative of the v-directed basis with respect to another `Elem`'s parametric space
    ///
    /// Requires the shape spaces to have been built with `compute_d2 == true`.
    pub fn f_v_d2(&self, [i, j]: [usize; 2], [m, n]: [usize; 2], para_scale: &V2D) -> V2D {
        self.jac_inv[m][n].v
            * V2D::from([
                self.u_shapes.tang(i, m) * self.v_shapes.norm_d2(j, n),
                self.u_shapes.tang_d2(i, m) * self.v_shapes.norm(j, n),
            ])
            * para_scale
            * para_scale
    }

    /// Evaluate the mixed second derivative (`d²/du dv`) of the u-directed basis with respect to another `Elem`'s parametric space
    ///
    /// Computes `jac_inv[m][n].u * N_i'(u) * T_j'(v)`, scaled by both components of `para_scale`.
    pub fn f_u_dd(&self, [i, j]: [usize; 2], [m, n]: [usize; 2], para_scale: &V2D) -> V2D {
        self.jac_inv[m][n].u
            * self.u_shapes.norm_d1(i, m)
            * self.v_shapes.tang_d1(j, n)
            * para_scale[0]
            * para_scale[1]
    }

    /// Evaluate the mixed second derivative (`d²/du dv`) of the v-directed basis with respect to another `Elem`'s parametric space
    ///
    /// Computes `jac_inv[m][n].v * T_i'(u) * N_j'(v)`, scaled by both components of `para_scale`.
    pub fn f_v_dd(&self, [i, j]: [usize; 2], [m, n]: [usize; 2], para_scale: &V2D) -> V2D {
        self.jac_inv[m][n].v
            * self.u_shapes.tang_d1(i, m)
            * self.v_shapes.norm_d1(j, n)
            * para_scale[0]
            * para_scale[1]
    }

    /// The size of the parametric area relative to the unit-parametric area
    #[inline]
    pub fn glq_scale(&self) -> f64 {
        self.para_scale[0] * self.para_scale[1]
    }

    /// The size of the parametric space relative to the unit-parametric space (along a single edge)
    ///
    /// Edges `0` and `1` use the u-scale; edges `2` and `3` use the v-scale.
    ///
    /// # Panics
    /// If `edge_idx > 3`.
    #[inline]
    pub fn edge_glq_scale(&self, edge_idx: usize) -> f64 {
        match edge_idx {
            0 | 1 => self.para_scale[0],
            2 | 3 => self.para_scale[1],
            _ => panic!("edge_idx must not exceed 3; cannot get glq scaling factor!"),
        }
    }

    /// The scale of the u-axis relative to the unit parametric space
    #[inline]
    pub fn u_glq_scale(&self) -> f64 {
        self.para_scale[0]
    }

    /// The scale of the v-axis relative to the unit parametric space
    #[inline]
    pub fn v_glq_scale(&self) -> f64 {
        self.para_scale[1]
    }

    /// The scale of both axes relative to the unit parametric space
    #[inline]
    pub fn deriv_scale(&self) -> &V2D {
        &self.para_scale
    }

    /// The determinant of the Jacobian at some point (m, n)
    #[inline]
    pub fn sample_scale(&self, [m, n]: [usize; 2]) -> f64 {
        self.det_jac[m][n]
    }

    /// Maximum of [`Self::uv_ratio`] and [`Self::vu_ratio`] at some point (m, n)
    ///
    /// Uses [`f64::max`], so if exactly one of the ratios is NaN the other one is returned (no panic).
    pub fn max_uv_ratio(&self, [m, n]: [usize; 2]) -> f64 {
        // Both orientations must be considered so elongation along either axis is reported.
        self.uv_ratio([m, n]).max(self.vu_ratio([m, n]))
    }

    /// The ratio of the du/dx to dv/dy at some point (m, n)
    pub fn uv_ratio(&self, [m, n]: [usize; 2]) -> f64 {
        self.jac[m][n].u[0] / self.jac[m][n].v[1]
    }

    /// The ratio of the dv/dy to du/dx at some point (m, n)
    pub fn vu_ratio(&self, [m, n]: [usize; 2]) -> f64 {
        self.jac[m][n].v[1] / self.jac[m][n].u[0]
    }
}

impl<BSpace: HierCurlBasisFnSpace> HierBasisFn for HierCurlBasisFn<BSpace> {
    /// Create a Basis Function instance defined over some `Elem` (and optionally mapped over some descendant `Elem`)
    ///
    /// # Arguments
    /// * `elem` : the element the basis function is defined over
    /// * `desc_elem` : an optional descendant element to sub-sample over. If `None` (or if it is `elem` itself), the BasisFn covers the entire `Elem`
    /// * `uv_points` : `[u_points, v_points]`, the glq points defined over (-1.0, 1.0) for the u- and v-axes
    /// * `ij_orders` : `[i_max, j_max]`, the maximum u- and v-directed expansion orders
    /// * `compute_d2` : whether or not to compute the second derivatives of the Basis Functions
    ///
    /// If a distinct descendant `Elem` is provided, the points are rescaled with `scale_gauss_quad_points`
    /// onto the descendant's parametric range relative to `elem`, and the returned scale factors are
    /// stored in `para_scale`. Otherwise `para_scale` is `[1.0, 1.0]`.
    ///
    fn defined_over(
        elem: &Elem,
        desc_elem: Option<&Elem>,
        [u_points, v_points]: [&[f64]; 2],
        [i_max, j_max]: [usize; 2],
        compute_d2: bool,
    ) -> Self {
        // Only a *different* descendant triggers rescaling; otherwise sample the full unit space.
        let [(u_scale, u_pts), (v_scale, v_pts)] = match desc_elem {
            Some(child) if child.id != elem.id => {
                let range = child.relative_parametric_range(elem.id);
                let u_scaled = scale_gauss_quad_points(u_points, range[0][0], range[0][1]);
                let v_scaled = scale_gauss_quad_points(v_points, range[1][0], range[1][1]);
                [u_scaled, v_scaled]
            }
            _ => [(1.0, u_points.to_vec()), (1.0, v_points.to_vec())],
        };

        // Jacobian of the element's parametric mapping, indexed [u-point][v-point].
        let mut jac: Vec<Vec<M2D>> = Vec::with_capacity(u_pts.len());
        for &u in u_pts.iter() {
            let column: Vec<M2D> = v_pts
                .iter()
                .map(|&v| elem.parametric_mapping(V2D::from([u, v]), elem.parametric_range()))
                .collect();
            jac.push(column);
        }

        let jac_inv: Vec<Vec<M2D>> = jac
            .iter()
            .map(|column| column.iter().map(|mat| mat.inverse()).collect())
            .collect();

        let det_jac: Vec<Vec<f64>> = jac
            .iter()
            .map(|column| column.iter().map(|mat| mat.det()).collect())
            .collect();

        Self {
            jac,
            jac_inv,
            det_jac,
            para_scale: V2D::from([u_scale, v_scale]),
            u_shapes: BSpace::with(i_max, &u_pts, compute_d2),
            v_shapes: BSpace::with(j_max, &v_pts, compute_d2),
        }
    }
}
