// [[file:../fire.note::*imports][imports:1]]
use crate::common::*;

use vecfx::*;
// imports:1 ends here

// [[file:../fire.note::*input/output][input/output:1]]
/// Evaluated function value and gradient
#[derive(Debug, Clone)]
pub struct Output {
    /// Function value at the evaluated position.
    pub fx: f64,
    /// Gradient of the function at the evaluated position.
    pub gx: Vec<f64>,
}

/// Position vector handed to the objective function, borrowed as a slice.
pub type Input<'a> = &'a [f64];

impl Output {
    /// Create an output buffer for an `n`-dimensional problem.
    ///
    /// The value and every gradient component start as NaN rather than 0.0, so
    /// any entry the evaluation closure does not overwrite remains NaN.
    fn new(n: usize) -> Self {
        Self {
            // associated constant; the `std::f64::NAN` module constant is legacy
            fx: f64::NAN,
            gx: vec![f64::NAN; n],
        }
    }
}
// input/output:1 ends here

// [[file:../fire.note::*trait][trait:1]]
/// A trait for evaluating value and gradient of objective function
///
/// `evaluate` receives the position as `input` and is expected to fill
/// `output`: after a successful call, `Problem` reads `output.fx` as the
/// function value and `output.gx` as the gradient. The returned `U` is
/// arbitrary user data; an `Err` aborts the evaluation and is propagated.
pub trait EvaluateFunction<U> {
    /// Evaluate the function at `input`, writing value and gradient into `output`.
    fn evaluate(&mut self, input: Input, output: &mut Output) -> Result<U>;
}

// [2020-11-27 Fri] There is deliberately no second blanket impl for plain
// closures of the form `FnMut(&[f64], &mut [f64]) -> f64` (with `U = ()`):
// it overlaps with the blanket impl below and triggers a conflicting
// implementation error.

/// Any closure with the same signature as [`EvaluateFunction::evaluate`] can be
/// used directly as an evaluator.
impl<T, U> EvaluateFunction<U> for T
where
    T: FnMut(Input, &mut Output) -> Result<U>,
{
    fn evaluate(&mut self, input: Input, output: &mut Output) -> Result<U> {
        // The closure already returns `Result<U>`: forward it as-is instead of
        // unwrapping with `?` and re-wrapping in `Ok`.
        (self)(input, output)
    }
}
// trait:1 ends here

// [[file:../fire.note::*problem][problem:1]]
/// An objective function bundled with its current position and cached
/// evaluation results.
///
/// Evaluation is lazy: [`Problem::value`] and [`Problem::gradient`] call the
/// user closure only when no output is cached for the current position, and
/// [`Problem::take_line_step`] discards that cache when the position moves.
pub struct Problem<'a, U> {
    // input position
    x: Vec<f64>,
    // callback function for evaluation
    f: Box<dyn EvaluateFunction<U> + 'a>,

    // evaluated function value and gradient at `x`; `None` until `x` has been
    // evaluated, and again after `x` moves
    out: Option<Output>,

    // `take_line_step` ignores displacements whose length is not above this
    epsilon: f64,
    // number of successful calls to `f`
    neval: usize,

    // position and output of the most recent evaluation; `x_prev` starts as
    // the initial position and `out_prev` as `None`
    x_prev: Option<Vec<f64>>,
    out_prev: Option<Output>,

    /// Data returned by the user defined closure in the most recent successful
    /// evaluation; `None` before the first one.
    pub user_data: Option<U>,
}

impl<'a, U> Problem<'a, U> {
    /// The number of successful function calls made so far.
    pub fn ncalls(&self) -> usize {
        self.neval
    }

    /// Return the output cached for the current position, evaluating the
    /// function first when nothing is cached.
    ///
    /// # Panics
    ///
    /// Panics if the user closure returns an error.
    fn current_output(&mut self) -> &Output {
        if self.out.is_none() {
            self.eval().expect("eval error");
        }
        self.out.as_ref().expect("no out")
    }

    /// Return function value at current position.
    ///
    /// The function will be evaluated when necessary.
    ///
    /// # Panics
    ///
    /// Panics if that evaluation fails.
    pub fn value(&mut self) -> f64 {
        self.current_output().fx
    }

    /// Return the function value at the most recently evaluated position.
    ///
    /// This cache is refreshed on every evaluation, so it only lags behind the
    /// current position between a [`Problem::take_line_step`] and the next
    /// evaluation.
    ///
    /// # Panics
    ///
    /// Panics if the function has not been evaluated yet.
    pub fn value_prev(&self) -> f64 {
        self.out_prev.as_ref().expect("not evaluated yet").fx
    }

    /// Return a reference to the gradient at the most recently evaluated
    /// position (see [`Problem::value_prev`] for when this differs from the
    /// current position).
    ///
    /// # Panics
    ///
    /// Panics if the function has not been evaluated yet.
    pub fn gradient_prev(&self) -> &[f64] {
        &self.out_prev.as_ref().expect("not evaluated yet").gx
    }

    /// Return a reference to function gradient at current position.
    ///
    /// The function will be evaluated when necessary.
    ///
    /// # Panics
    ///
    /// Panics if that evaluation fails.
    pub fn gradient(&mut self) -> &[f64] {
        &self.current_output().gx
    }

    /// Return a reference to current position vector.
    pub fn position(&self) -> &[f64] {
        &self.x
    }

    /// Revert to the most recently evaluated position, restoring its cached
    /// output.
    ///
    /// Before any evaluation this resets the position to the initial one given
    /// to [`Problem::new`] and leaves the output uncached.
    pub fn revert(&mut self) {
        let x_prev = self.x_prev.as_ref().expect("not evaluated yet");
        // reuse the existing allocation of `x`
        self.x.clone_from(x_prev);
        self.out = self.out_prev.clone();
    }
}

/// Core input/output methods
impl<'a, U> Problem<'a, U> {
    /// Construct a `Problem`
    ///
    /// # Parameters
    ///
    /// * x: initial position
    /// * f: a closure for function evaluation of value and gradient.
    pub fn new(x: Vec<f64>, f: impl EvaluateFunction<U> + 'a) -> Self {
        Self {
            neval: 0,
            epsilon: 1e-8,
            out: None,
            x_prev: Some(x.clone()),
            out_prev: None,
            user_data: None,

            f: Box::new(f),
            x,
        }
    }

    /// Update position `x` at a prescribed displacement and step size.
    ///
    /// x += step * displ
    ///
    /// The update is skipped when the length of the displacement,
    /// `|step| * ‖displ‖`, does not exceed `epsilon` (set to 1e-8 by `new`).
    pub fn take_line_step(&mut self, displ: &[f64], step: f64) {
        // `abs`: a negative step is a genuine move too; without it the product
        // is negative and such steps would silently be ignored.
        if (step * displ.vec2norm()).abs() > self.epsilon {
            self.x.vecadd(displ, step);

            // The cached output belongs to the old position. `x_prev`/`out_prev`
            // are left untouched and keep describing the last evaluated point,
            // which is what `revert` restores.
            self.out = None;
        }
    }

    /// Evaluate function value and gradient at the current position.
    ///
    /// On success this stores the closure's return value in `user_data`,
    /// caches the output in `out`, copies position and output into
    /// `x_prev`/`out_prev`, and increments the call counter.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the user closure. In that case `out`
    /// is left as `None` and no other state is modified.
    fn eval(&mut self) -> Result<()> {
        let n = self.x.len();
        // `out` is normally `None` here; only build a fresh NaN-filled buffer
        // when it actually is.
        let mut out = self.out.take().unwrap_or_else(|| Output::new(n));
        self.user_data = Some(self.f.evaluate(&self.x, &mut out)?);

        // NOTE(review): the "previous" cache is overwritten with the point just
        // evaluated, so right after this call `x_prev == x` and
        // `out_prev == out`. `revert` therefore returns to the last *evaluated*
        // position, not to the position before the last evaluation.
        self.out_prev = Some(out.clone());
        self.x_prev = Some(self.x.clone());

        self.out = Some(out);
        self.neval += 1;

        Ok(())
    }
}
// problem:1 ends here

// [[file:../fire.note::*progress][progress:1]]
/// Snapshot of the minimization state emitted once per iteration, intended
/// for progress monitoring.
#[derive(Debug, Clone)]
pub struct Progress<T> {
    /// Norm of the gradient vector at the current iterate.
    pub gnorm: f64,
    /// Objective function value at the current iterate.
    pub fx: f64,
    /// Number of objective function calls made so far.
    pub ncalls: usize,
    /// Extra data returned by the user defined closure that evaluates the
    /// objective function.
    pub extra: T,
}
// progress:1 ends here
