use crate::internal::*;
use ndarray::*;

use tract_linalg::mmm::{FusedSpec, MatMatMul, MatrixStoreSpec, RoundingPolicy, ScratchSpace};

/// Description of an operation fused after the matrix product, stored inside the op.
///
/// Tensor operands are kept as `AttrOrInput` and only turned into a
/// `tract_linalg` `FusedSpec` by [`ProtoFusedSpec::resolve`], once the node inputs
/// are available.
#[derive(PartialEq, Clone, Hash, Debug)]
pub enum ProtoFusedSpec {
    /// Resolves to `FusedSpec::Min`.
    Min(AttrOrInput),
    /// Resolves to `FusedSpec::Max`.
    Max(AttrOrInput),
    /// Resolves to `FusedSpec::PerRowMul`.
    PerRowMul(AttrOrInput),
    /// Resolves to `FusedSpec::PerRowAdd`.
    PerRowAdd(AttrOrInput),
    /// Resolves to `FusedSpec::PerColMul`.
    PerColMul(AttrOrInput),
    /// Resolves to `FusedSpec::PerColAdd`.
    PerColAdd(AttrOrInput),
    /// Resolves to `FusedSpec::AddRowColProducts`; operands are (row, col).
    AddRowColProducts(AttrOrInput, AttrOrInput),
    /// Resolves to `FusedSpec::ScalarMul`.
    ScalarMul(AttrOrInput),
    /// Resolves to `FusedSpec::ScalarAdd`.
    ScalarAdd(AttrOrInput),
    /// Resolves to `FusedSpec::AddUnicast`, applied on a view of the operand tensor.
    AddUnicast(AttrOrInput),
    /// Resolves to `FusedSpec::QScale`; built in `fuse` as (shift, rounding policy, multiplier).
    QScale(usize, RoundingPolicy, i32),
}

impl ProtoFusedSpec {
    /// Builds the kernel-level `FusedSpec` matching this prototype.
    ///
    /// Every `AttrOrInput` operand is looked up through `AttrOrInput::tensor`
    /// against `inputs`; `QScale` carries plain values and is copied as is.
    pub fn resolve<'t>(&'t self, inputs: &'t [Arc<Tensor>]) -> FusedSpec<'t> {
        match self {
            // Plain values, nothing to look up.
            Self::QScale(shift, policy, mult) => FusedSpec::QScale(*shift, *policy, *mult),
            // The unicast addition works on a view of the operand.
            Self::AddUnicast(operand) => FusedSpec::AddUnicast(operand.tensor(inputs).view()),
            Self::AddRowColProducts(rows, cols) => {
                let rows = rows.tensor(inputs);
                let cols = cols.tensor(inputs);
                FusedSpec::AddRowColProducts(rows, cols)
            }
            Self::ScalarAdd(operand) => FusedSpec::ScalarAdd(operand.tensor(inputs)),
            Self::ScalarMul(operand) => FusedSpec::ScalarMul(operand.tensor(inputs)),
            Self::PerRowAdd(operand) => FusedSpec::PerRowAdd(operand.tensor(inputs)),
            Self::PerRowMul(operand) => FusedSpec::PerRowMul(operand.tensor(inputs)),
            Self::PerColAdd(operand) => FusedSpec::PerColAdd(operand.tensor(inputs)),
            Self::PerColMul(operand) => FusedSpec::PerColMul(operand.tensor(inputs)),
            Self::Max(operand) => FusedSpec::Max(operand.tensor(inputs)),
            Self::Min(operand) => FusedSpec::Min(operand.tensor(inputs)),
        }
    }
}

/// Matrix product geometry with every dimension known as a plain `usize`.
///
/// `m`, `k` and `n` are handed as is to the kernel's `run_with_scratch_space`
/// in `eval`; `b_storage` is used there to wrap the B operand.
#[derive(Clone, Debug, Hash)]
pub struct ConcreteMatMulGeometry {
    pub m: usize,
    pub k: usize,
    pub n: usize,
    /// Storage spec for B, obtained from the kernel's `b_packed` when resolving
    /// a `SymbolicMatMulGeometry`.
    pub b_storage: MatrixStoreSpec,
}

/// Matrix product geometry whose dimensions are `TDim` expressions.
///
/// It carries what is needed to produce a `ConcreteMatMulGeometry` once symbol
/// values are known: the kernel and the datum type of B are used to compute
/// the B storage spec.
#[derive(Clone, Debug, Hash)]
pub struct SymbolicMatMulGeometry {
    pub m: TDim,
    pub k: TDim,
    pub n: TDim,
    /// Kernel queried (`b_packed`) for the B storage spec at resolution time.
    pub mmm: Box<dyn MatMatMul>,
    /// Datum type of B; only its `size_of()` is used at resolution time.
    pub b_datum_type: DatumType,
}

impl ResolveTo<ConcreteMatMulGeometry> for SymbolicMatMulGeometry {
    type Param = SymbolValues;
    fn resolve(&self, param: &Self::Param) -> TractResult<ConcreteMatMulGeometry> {
        let m = self.m.eval(param).to_usize()?;
        let k = self.k.eval(param).to_usize()?;
        let n = self.n.eval(param).to_usize()?;
        let b_storage = unsafe { self.mmm.b_packed(self.b_datum_type.size_of(), k) };
        Ok(ConcreteMatMulGeometry { m, k, n, b_storage })
    }
}

/// Geometry of the matrix product, either still symbolic (`TDim` dimensions)
/// or already concrete (`usize` dimensions and a B storage spec).
pub type MatMulGeometry = GeometryBound<SymbolicMatMulGeometry, ConcreteMatMulGeometry>;

impl MatMulGeometry {
    /// The `m` dimension as a `TDim`: borrowed from the symbolic geometry, or
    /// converted from the concrete `usize`.
    fn m(&self) -> Cow<TDim> {
        match self {
            Self::Concrete(concrete) => Cow::Owned(concrete.m.to_dim()),
            Self::Symbolic(symbolic) => Cow::Borrowed(&symbolic.m),
        }
    }

    /// The `k` dimension as a `TDim`: borrowed from the symbolic geometry, or
    /// converted from the concrete `usize`.
    fn k(&self) -> Cow<TDim> {
        match self {
            Self::Concrete(concrete) => Cow::Owned(concrete.k.to_dim()),
            Self::Symbolic(symbolic) => Cow::Borrowed(&symbolic.k),
        }
    }
}

/// Low-level matrix product where A comes from a constant table (`micro_ops`)
/// and B is the first input of the node.
#[derive(Clone, Educe, Debug)]
#[educe(Hash)]
pub struct LirMatMulUnary {
    /// Datum type and shape of the output buffer as it is computed, before
    /// `c_final_shape` is applied.
    pub c_fact: TypedFact,
    /// Axis of C holding the `m` dimension (given to `c_view_with_axis`).
    pub c_m_axis: usize,
    /// Axis of C holding the `n` dimension (given to `c_view_with_axis`).
    pub c_n_axis: usize,
    /// Table of (A tensor, fused operations). It has one axis per C axis other
    /// than m and n (checked in `output_facts`).
    /// NOTE(review): the A tensors are presumably already packed for `mmm`; confirm
    /// against the code building this op.
    pub micro_ops: ArrayD<(Arc<Tensor>, Vec<ProtoFusedSpec>)>,
    /// Shape set on the output tensor once computed, and reported by `output_facts`.
    pub c_final_shape: ShapeFact,
    pub geometry: MatMulGeometry,
    /// Matrix multiplication kernel. Hashed through `hash_mmm`.
    #[educe(Hash(method = "hash_mmm"))]
    pub mmm: Box<dyn MatMatMul>,
    /// Shape-only axis operations absorbed by `fuse`, in application order.
    pub reshape_post: Vec<AxisOp>,
}

/// Hash helper used by `educe` for the `mmm` field: feeds the `TypeId` obtained
/// from `mmm.type_id()` into `state`.
///
/// NOTE(review): `type_id()` is called through `&Box<dyn MatMatMul>`; depending
/// on which traits are in scope this may resolve to the `TypeId` of the box
/// rather than of the concrete kernel — confirm the hash discriminates kernels.
fn hash_mmm<H: std::hash::Hasher>(mmm: &Box<dyn MatMatMul>, state: &mut H) {
    let kernel_type = mmm.type_id();
    std::hash::Hash::hash(&kernel_type, state)
}

impl DynHash for LirMatMulUnary {
    /// Hashes the op into a type-erased hasher by delegating to the free
    /// `dyn_hash` helper.
    fn dyn_hash(&self, hasher: &mut dyn std::hash::Hasher) {
        // Hashing a reference hashes the referent, so the extra borrow is not needed.
        dyn_hash(self, hasher)
    }
}

impl Op for LirMatMulUnary {
    fn name(&self) -> Cow<str> {
        "LirMatMulUnary".into()
    }

    fn info(&self) -> TractResult<Vec<String>> {
        let mut infos = vec![format!(
            "c_shape:{:?}, c_m_axis:{} c_n_axis:{} b_storage:{:?}",
            self.c_fact, self.c_m_axis, self.c_n_axis, self.geometry,
        )];
        infos.push(format!("Mult: {}", self.mmm));
        infos.push(format!("Ops: {:?}", self.micro_ops));
        Ok(infos)
    }

    op_core_lir!();
    op_as_typed_op!();
}

/// Evaluation state for `LirMatMulUnary`. It holds no data of its own: shapes
/// are resolved against the session symbols and the kernel scratch space is
/// cached in the session.
#[derive(Clone, Debug)]
struct State;

impl OpState for State {
    /// Resolves the symbolic shapes and geometry with the session symbols,
    /// makes sure the session holds a scratch space usable by the kernel, then
    /// runs the product.
    fn eval(
        &mut self,
        session: &mut SessionState,
        op: &dyn Op,
        inputs: TVec<Arc<Tensor>>,
    ) -> TractResult<TVec<Arc<Tensor>>> {
        let op = op.downcast_ref::<LirMatMulUnary>().unwrap();
        let c_shape = op.c_fact.shape.eval_to_usize(&session.resolved_symbols)?;
        let c_final_shape = op.c_final_shape.eval_to_usize(&session.resolved_symbols)?;
        unsafe {
            let geometry = op.geometry.to_concrete(&session.resolved_symbols)?;
            // Drop a cached scratch space that this kernel can not use.
            let cache_is_unusable = session
                .cached_mmm_scratch_space
                .as_deref()
                .map_or(false, |cached| !op.mmm.can_use_scratch_space(cached));
            if cache_is_unusable {
                session.cached_mmm_scratch_space = None;
            }
            // Reuse the cached scratch space, or allocate one for this kernel.
            let scratch = session
                .cached_mmm_scratch_space
                .get_or_insert_with(|| op.mmm.allocate_scratch_space());
            eval(
                op,
                &geometry,
                scratch.as_mut(),
                &inputs,
                &c_shape,
                op.c_m_axis,
                op.c_n_axis,
                &*c_final_shape,
            )
        }
    }
}

impl EvalOp for LirMatMulUnary {
    /// The op is stateless exactly when its geometry is already concrete.
    fn is_stateless(&self) -> bool {
        self.geometry.is_concrete()
    }

    /// Always provides a `State`, which resolves symbols from the session at
    /// evaluation time.
    fn state(
        &self,
        _session: &mut SessionState,
        _node_id: usize,
    ) -> TractResult<Option<Box<dyn OpState>>> {
        let state: Box<dyn OpState> = Box::new(State);
        Ok(Some(state))
    }

    /// Stateless evaluation: resolves the geometry without any symbol value
    /// and uses a freshly allocated scratch space.
    ///
    /// Panics if `c_fact.shape` or `c_final_shape` is not concrete.
    fn eval(&self, inputs: TVec<Arc<Tensor>>) -> TractResult<TVec<Arc<Tensor>>> {
        let no_symbols = SymbolValues::default();
        let geometry = self.geometry.to_concrete(&no_symbols)?;
        let mut scratch = unsafe { self.mmm.allocate_scratch_space() };
        let c_shape = self.c_fact.shape.as_concrete().unwrap();
        let c_final_shape = self.c_final_shape.as_concrete().unwrap();
        eval(
            self,
            &geometry,
            scratch.as_mut(),
            &*inputs,
            c_shape,
            self.c_m_axis,
            self.c_n_axis,
            c_final_shape,
        )
    }
}

/// Runs the matrix product(s) described by `op` and returns the single output tensor.
///
/// * `geometry`: concrete m, k, n and B storage spec handed to the kernel.
/// * `scratch`: kernel scratch space, reused for every product of this call.
/// * `inputs`: node inputs; `inputs[0]` is B, and all inputs are made available
///   to the fused operations through `ProtoFusedSpec::resolve`.
/// * `c_shape`: concrete shape used to allocate the output buffer.
/// * `c_m_axis`, `c_n_axis`: axes of the output holding m and n.
/// * `c_final_shape`: shape set on the output once it is computed.
///
/// Panics if `op.micro_ops` is empty.
fn eval(
    op: &LirMatMulUnary,
    geometry: &ConcreteMatMulGeometry,
    scratch: &mut dyn ScratchSpace,
    inputs: &[Arc<Tensor>],
    c_shape: &[usize],
    c_m_axis: usize,
    c_n_axis: usize,
    c_final_shape: &[usize],
) -> TractResult<TVec<Arc<Tensor>>> {
    // NOTE(review): the whole body is unsafe (uninitialized output, unchecked
    // view offsets and prefixes, unchecked final reshape); it relies on the
    // shapes, axes and micro_ops table being mutually consistent, which is not
    // verified here.
    unsafe {
        // The datum type of the first A tensor is used for every entry of the table.
        let a_dt = op.micro_ops.iter().next().unwrap().0.datum_type();
        // Output buffer, allocated without initialization.
        let mut c = Tensor::uninitialized_dt(op.c_fact.datum_type, &c_shape)?;
        let c_storage = op.mmm.c_view_with_axis(c_m_axis, c_n_axis);
        // Several products are needed as soon as an axis other than m and n is not 1.
        if op
            .c_fact
            .shape
            .iter()
            .enumerate()
            .any(|(ix, d)| ix != c_m_axis && ix != c_n_axis && d != 1.to_dim())
        {
            // Same shape as C with m and n collapsed to 1: iterating its indices
            // visits each combination of the other axes once.
            let mut looping_shape: TVec<usize> = c_shape.into();
            looping_shape[c_m_axis] = 1;
            looping_shape[c_n_axis] = 1;
            for prefix in indices(&*looping_shape) {
                let mut ops = op.micro_ops.view();
                let mut b_prefix = tvec!();
                let mut c_view = c.view();
                for (ix, &dim) in prefix.slice().iter().enumerate() {
                    if ix != c_m_axis && ix != c_n_axis {
                        // Peel one axis off the micro_ops table. The index is clamped
                        // to the last entry, so a table axis of length 1 is shared by
                        // every position along the matching C axis.
                        ops.index_axis_inplace(Axis(0), dim.min(ops.shape()[0] - 1));
                        b_prefix.push(dim);
                    }
                    // On the m and n axes `dim` is always 0 (looping shape is 1 there).
                    c_view.offset_axis_unchecked(ix, dim as isize);
                }
                // All prefix axes have been indexed away: `ops` holds a single entry.
                let (pa, fused) = ops.iter().next().unwrap();
                let f: Vec<FusedSpec> = fused.iter().map(|f| f.resolve(inputs)).collect::<Vec<_>>();
                op.mmm.run_with_scratch_space(
                    geometry.m,
                    geometry.k,
                    geometry.n,
                    scratch,
                    &op.mmm.a_packed(a_dt.size_of(), geometry.k).wrap(&pa.view()),
                    &geometry
                        .b_storage
                        .wrap(&TensorView::at_prefix_unchecked(&inputs[0], &*b_prefix)),
                    &mut c_storage.wrap(&c_view),
                    &f,
                )?;
            }
        } else {
            // Single product: first entry of the table against the whole of B and C.
            let (pa, fused) = op.micro_ops.iter().next().unwrap();
            let f: Vec<FusedSpec> = fused.iter().map(|f| f.resolve(inputs)).collect::<Vec<_>>();
            op.mmm.run_with_scratch_space(
                geometry.m,
                geometry.k,
                geometry.n,
                scratch,
                &op.mmm.a_packed(a_dt.size_of(), geometry.k).wrap(&pa.view()),
                &geometry.b_storage.wrap(&inputs[0].view()),
                &mut c_storage.wrap(&c.view_mut()),
                &f,
            )?;
        }
        // Apply the final shape without checking it against the buffer.
        c.set_shape_unchecked(c_final_shape);
        Ok(tvec!(c.into_arc_tensor()))
    }
}

impl TypedOp for LirMatMulUnary {
    /// Reports a single output: `c_fact` with its shape replaced by `c_final_shape`.
    ///
    /// Fails if the `micro_ops` table does not have one axis per C axis other
    /// than m and n.
    fn output_facts(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        let table_rank = self.micro_ops.ndim();
        let prefix_rank = self.c_fact.rank() - 2;
        if table_rank != prefix_rank {
            bail!(
                "Constant A table and c_prefix should have the same len. (resp {} and {})",
                table_rank,
                prefix_rank
            );
        }
        let mut output = self.c_fact.clone();
        output.shape = self.c_final_shape.clone();
        Ok(tvec!(output))
    }

    fn cost(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
        let sums: TDim = self.c_fact.shape.iter().product();
        Ok(tvec!(
            (Cost::FMA(self.mmm.internal_type()), sums * self.geometry.k().as_ref()),
            (
                Cost::Params(self.micro_ops.as_slice().unwrap()[0].0.datum_type()),
                self.micro_ops.iter().fold(0.to_dim(), |sum, a| sum + a.0.len())
            )
        ))
    }

    /// Tries to absorb the single successor of `node` into this op.
    ///
    /// Returns `Ok(Some(patch))` with a model patch when the successor is one of
    /// the supported cases below, `Ok(None)` when nothing can be fused.
    ///
    /// Supported successors:
    /// * a shape-only `AxisOp`, recorded in `reshape_post`;
    /// * a `Cast` to i8 when C is i32, which switches to a kernel producing i8;
    /// * a `QScale` element-wise op;
    /// * a `UnaryOp` with a one-element constant (Max, Min, Mul) or a per-row
    ///   constant (Mul, Add);
    /// * a `MergeOpUnicast` Add, whose other operand becomes an extra input.
    fn fuse(&self, model: &TypedModel, node: &TypedNode) -> TractResult<Option<TypedModelPatch>> {
        use crate::ops;
        // Only fuse when the unique output feeds exactly one node and is not a
        // model output.
        if node.outputs.len() != 1
            || node.outputs[0].successors.len() != 1
            || model.output_outlets()?.iter().any(|outlet| outlet.node == node.id)
        {
            return Ok(None);
        }
        let succ = model.node(node.outputs[0].successors[0].node);
        if let Some(op) = succ.op_as::<ops::AxisOp>() {
            if op.only_shape() {
                // Remember the axis op and take over the successor's output shape.
                let mut reshape_post = self.reshape_post.clone();
                reshape_post.push(op.clone());
                let mut patch = TypedModelPatch::fuse_with_next(
                    model,
                    &node,
                    Self {
                        c_final_shape: succ.outputs[0].fact.shape.clone(),
                        reshape_post,
                        ..self.clone()
                    },
                )?;
                patch.dont_apply_twice = Some(format!("Fuse {} into {}", succ, node));
                return Ok(Some(patch));
            }
        }

        // Builds a patch replacing `node` and `succ` by a copy of this op whose
        // micro_ops entries get the given fused specs appended. `additional_inputs`
        // are wired after the node's current inputs.
        let merge = |fused_micro_op: &ArrayD<Vec<ProtoFusedSpec>>,
                     additional_inputs: &[OutletId]|
         -> TractResult<Option<TypedModelPatch>> {
            let mut new_op = self.clone();
            new_op
                .micro_ops
                .zip_mut_with(fused_micro_op, |lhs, rhs| lhs.1.extend(rhs.iter().cloned()));
            let mut patch = TypedModelPatch::new(format!("fusing {}", succ));
            patch.dont_apply_twice = Some(format!("Fuse {} into {}", succ.name, node.name));
            let inputs = node
                .inputs
                .iter()
                .chain(additional_inputs.iter())
                .map(|i| patch.tap_model(model, *i))
                .collect::<TractResult<TVec<OutletId>>>()?;
            let output = patch.wire_node(&node.name, new_op, &inputs)?;
            patch.shunt_outside(model, succ.id.into(), output[0])?;
            Ok(Some(patch))
        };

        // Same as `merge`, with one list of specs wrapped in a 0-dimensional array
        // so that it applies to every entry of the table.
        let merge_broadcast = |spec: &[ProtoFusedSpec], additional_inputs: &[OutletId]| {
            let array = arr0(spec.to_vec()).into_dyn();
            merge(&array, additional_inputs)
        };

        if let Some(op) = succ.op_as::<ops::element_wise::ElementWiseOp>().map(|ew| ew.0.as_ref()) {
            if let Some(cast) = op.downcast_ref::<ops::cast::Cast>().map(|cast| cast.to) {
                // i32 result cast to i8: pick a kernel that outputs i8 directly.
                if cast == i8::datum_type() && self.c_fact.datum_type == i32::datum_type() {
                    let at = self.micro_ops.iter().nth(0).unwrap().0.datum_type();
                    let bt = model.outlet_fact(node.inputs[0])?.datum_type;
                    let mmm = tract_linalg::ops()
                        .mmm(
                            at,
                            bt,
                            i8::datum_type(),
                            self.c_fact.shape[self.c_m_axis].to_usize().ok(),
                            None,
                            self.c_fact.shape[self.c_n_axis].to_usize().ok(),
                        )
                        .unwrap();

                    let c_fact = TypedFact::dt_shape(i8::datum_type(), self.c_fact.shape.clone());
                    let mut patch = TypedModelPatch::fuse_with_next(
                        model,
                        &node,
                        Self { mmm, c_fact, ..self.clone() },
                    )?;
                    patch.dont_apply_twice = Some(format!("Fuse {} into {}", succ, node));
                    return Ok(Some(patch));
                }
            } else if let Some(op) = op.downcast_ref::<ops::math::QScale>() {
                return merge_broadcast(
                    &[ProtoFusedSpec::QScale(op.shift, op.policy, op.mult)],
                    &[],
                );
            }
        } else if let Some(op) = succ.op_as::<ops::binary::UnaryOp>() {
            // One-element constant operand: only Max, Min and Mul are fused here.
            if op.a.len() == 1 {
                if op.mini_op.is::<ops::math::Max>() {
                    return merge_broadcast(&[ProtoFusedSpec::Max((&op.a).into())], &[]);
                } else if op.mini_op.is::<ops::math::Min>() {
                    return merge_broadcast(&[ProtoFusedSpec::Min((&op.a).into())], &[]);
                } else if op.mini_op.is::<ops::math::Mul>() {
                    return merge_broadcast(&[ProtoFusedSpec::ScalarMul((&op.a).into())], &[]);
                } else {
                    return Ok(None);
                }
            }
            // Bring the constant back to the layout of `c_fact` by undoing the
            // already fused axis operations, last one first.
            let mut arg = op.a.clone().into_tensor();
            for axis_change in self.reshape_post.iter().rev() {
                axis_change.recip().change_tensor(&mut arg, true)?;
            }
            // Per-row operand: length 1 along n, full length along m.
            if arg.shape()[self.c_n_axis] == 1
                && arg.shape()[self.c_m_axis].to_dim() == self.c_fact.shape[self.c_m_axis]
                && (op.mini_op.is::<ops::math::Mul>() || op.mini_op.is::<ops::math::Add>())
            {
                // Shape of the operand without its m and n axes.
                let prefix = arg
                    .shape()
                    .iter()
                    .enumerate()
                    .filter_map(|(ix, d)| {
                        if ix != self.c_n_axis && ix != self.c_m_axis {
                            Some(*d)
                        } else {
                            None
                        }
                    })
                    .collect::<TVec<_>>();

                assert_eq!(
                    prefix.iter().cloned().product::<usize>()
                        * self.geometry.m().to_usize().unwrap(),
                    arg.len()
                );

                // Flatten the operand and cut it in consecutive chunks of m values,
                // one chunk per entry of the prefix-shaped array.
                let arg_len = arg.len();
                let arg = arg.into_shape(&[arg_len])?;
                let mut i = 0;
                let arg = ArrayD::from_shape_simple_fn(&*prefix, || {
                    let t = arg
                        .slice(
                            0,
                            i * self.geometry.m().to_usize().unwrap(),
                            (i + 1) * self.geometry.m().to_usize().unwrap(),
                        )
                        .unwrap();
                    i += 1;
                    if op.mini_op.is::<ops::math::Mul>() {
                        vec![ProtoFusedSpec::PerRowMul(t.into())]
                    } else if op.mini_op.is::<ops::math::Add>() {
                        vec![ProtoFusedSpec::PerRowAdd(t.into())]
                    } else {
                        // The enclosing condition only lets Mul and Add through.
                        unreachable!()
                    }
                });
                return merge(&arg, &[]);
            }
        } else if let Some(op) = succ.op_as::<ops::binary::MergeOpUnicast>() {
            // Only when n and m are the two last axes of the final shape (in that
            // order) and the table holds a single entry.
            if self.c_n_axis == self.c_final_shape.rank() - 2
                && self.c_m_axis == self.c_final_shape.rank() - 1
                && self.micro_ops.len() == 1
            {
                // The successor input that is not fed by this node.
                let other_slot = 1 - node.outputs[0].successors[0].slot;
                let other_input = succ.inputs[other_slot];

                if op.0.is::<ops::math::Add>() {
                    // The other operand is appended after the current inputs, so its
                    // input index is the current input count.
                    return merge_broadcast(
                        &[ProtoFusedSpec::AddUnicast(node.inputs.len().into())],
                        &[other_input],
                    );
                }
            }
        };
        Ok(None)
    }

    as_op!();
}
