// [[file:../lbfgs.note::*imports][imports:1]]
use crate::core::*;

use crate::math::LbfgsMath;
use crate::line::*;
// imports:1 ends here

// [[file:../lbfgs.note::*input/output][input/output:1]]
/// Evaluated function value and gradient.
///
/// Filled in by the user's objective callback for one set of variables.
#[derive(Debug, Clone)]
pub struct Output {
    /// Objective function value `f(x)`.
    pub fx: f64,
    /// Gradient of the objective at `x`, one entry per variable.
    pub gx: Vec<f64>,
}

/// Variables at which the objective function is evaluated.
pub type Input<'a> = &'a [f64];

impl Output {
    /// Create an output buffer for `n` variables.
    ///
    /// `fx` and every entry of `gx` start as NaN. The NaN fill acts as an
    /// "unset" marker: a callback that never writes its results leaves them as
    /// NaN, which the caller can detect.
    pub fn new(n: usize) -> Self {
        // Use the associated constant `f64::NAN` rather than the legacy
        // `std::f64::NAN` module constant.
        Self {
            fx: f64::NAN,
            gx: vec![f64::NAN; n],
        }
    }
}
// input/output:1 ends here

// [[file:../lbfgs.note::*trait][trait:1]]
/// Objective-function callback used by the optimizer.
///
/// Given the variables `input`, an implementation writes the function value
/// and its gradient into `output`, and may hand back extra user data of type
/// `U`.
pub trait EvaluateFunction<U> {
    /// Evaluate the objective at `input`, storing the value and gradient in
    /// `output`. Any error aborts the evaluation and is returned to the caller.
    fn evaluate(&mut self, input: Input, output: &mut Output) -> Result<U>;
}

/// Blanket implementation: any `FnMut(Input, &mut Output) -> Result<U>`
/// closure or function can be used directly as an objective callback.
impl<T, U> EvaluateFunction<U> for T
where
    T: FnMut(Input, &mut Output) -> Result<U>,
{
    fn evaluate(&mut self, input: Input, output: &mut Output) -> Result<U> {
        // Forward straight to the closure; its result (data or error) is
        // returned as-is.
        self(input, output)
    }
}
// trait:1 ends here

// [[file:../lbfgs.note::*problem][problem:1]]
/// An optimization problem together with the optimizer's per-problem state.
///
/// Besides the current point `x`, its objective value `fx` and gradient `gx`,
/// this keeps the previously saved point and gradient (`xp`, `gp`), the search
/// direction `d`, the OWL-QN pseudo-gradient `pg`, the boxed objective callback
/// and some bookkeeping about evaluations.
pub(crate) struct Problem<'a, U> {
    /// Current values of the variables (the base point for the line search).
    pub x: Vec<f64>,

    /// Objective value at `x`. With OWL-QN enabled, `evaluate` adds the L1
    /// penalty to the value reported by the callback.
    pub fx: f64,

    /// Gradient of the objective at `x`, same length as `x`.
    pub gx: Vec<f64>,

    /// Position stored by the most recent `save_state` call.
    pub(crate) xp: Vec<f64>,

    /// Gradient stored by the most recent `save_state` call.
    pub(crate) gp: Vec<f64>,

    /// Search direction.
    d: Vec<f64>,

    /// Pseudo-gradient for the Orthant-Wise Limited-memory Quasi-Newton
    /// (OWL-QN) algorithm.
    pg: Vec<f64>,

    /// OWL-QN parameters; `None` disables every orthant-wise code path.
    owlqn: Option<Orthantwise>,

    /// Set once `evaluate` has completed successfully.
    evaluated: bool,

    /// Number of successful evaluations performed so far.
    neval: usize,

    /// Extra data returned by the user-defined callback on its last successful
    /// evaluation.
    pub data: Option<U>,

    /// Boxed callback that evaluates the objective value and gradient.
    eval_fn: Box<dyn EvaluateFunction<U> + 'a>,
}

impl<'a, U> Problem<'a, U> {
    /// Create a problem whose starting point is `x`.
    ///
    /// Every work vector (`gx`, `xp`, `gp`, `pg`, `d`) is zero-filled with the
    /// length of `x`. `fx` starts at 0.0 and the problem is marked as not yet
    /// evaluated. `f` is boxed and called by `evaluate`. Passing `Some(owlqn)`
    /// enables the orthant-wise (L1-regularized) code paths.
    pub fn new(x: Vec<f64>, f: impl EvaluateFunction<U> + 'a, owlqn: Option<Orthantwise>) -> Self {
        let n = x.len();
        Problem {
            eval_fn: Box::new(f),
            fx: 0.0,
            gx: vec![0.0; n],
            xp: vec![0.0; n],
            gp: vec![0.0; n],
            pg: vec![0.0; n],
            d: vec![0.0; n],
            evaluated: false,
            data: None,
            neval: 0,
            x,
            owlqn,
        }
    }

    /// Compute the initial gradient in the search direction.
    ///
    /// Without OWL-QN this is `gx · d`. A warning is logged when it is
    /// positive, because moving along `d` would then increase the objective.
    /// With OWL-QN the pseudo-gradient is used instead (`pg · d`) and no sign
    /// check is done.
    ///
    /// # Errors
    /// Never fails at present; the `Result` return type is kept for callers.
    pub fn dginit(&self) -> Result<f64> {
        if self.owlqn.is_some() {
            return Ok(self.pg.vecdot(&self.d));
        }

        let dginit = self.gx.vecdot(&self.d);
        if dginit > 0.0 {
            warn!(
                "The current search direction increases the objective function value. dginit = {:-0.4}",
                dginit
            );
        }
        Ok(dginit)
    }

    /// Reset the search direction from the evaluated gradient.
    ///
    /// The pseudo-gradient `pg` is used under OWL-QN, otherwise `gx`.
    // NOTE(review): `vecncpy` is assumed to store the negated source (d = -g),
    // as liblbfgs's `vecncpy` does; confirm in `LbfgsMath`.
    pub fn update_search_direction(&mut self) {
        if self.owlqn.is_some() {
            self.d.vecncpy(&self.pg);
        } else {
            self.d.vecncpy(&self.gx);
        }
    }

    /// Return a reference to the current search direction vector.
    pub fn search_direction(&self) -> &[f64] {
        &self.d
    }

    /// Return a mutable reference to the current search direction vector.
    pub fn search_direction_mut(&mut self) -> &mut [f64] {
        &mut self.d
    }

    /// Compute the gradient in the search direction (`gx · d`) without any
    /// sign checking. The raw gradient is used even under OWL-QN.
    pub fn dg_unchecked(&self) -> f64 {
        self.gx.vecdot(&self.d)
    }

    /// Evaluate the objective function and its gradient at the current `x`.
    ///
    /// The user callback receives a fresh `Output` whose `fx` and `gx` are
    /// pre-filled with NaN. Its return value is stored in `data` and its
    /// results are copied into `fx`/`gx`. When OWL-QN is enabled, the L1
    /// penalty from `Orthantwise::x1norm` is added to `fx`. The
    /// pseudo-gradient `pg` is *not* recomputed here; that is done separately
    /// by `update_owlqn_gradient`.
    ///
    /// # Errors
    /// Propagates any error returned by the callback. In that case `fx`, `gx`,
    /// `data` and the evaluation counter are left unchanged.
    ///
    /// # Panics
    /// Panics if the callback apparently wrote nothing. That is the case when
    /// `fx` is still NaN and the first gradient component is also NaN, or when
    /// there are no variables.
    pub fn evaluate(&mut self) -> Result<()> {
        let mut out = Output::new(self.x.len());
        self.data = self.eval_fn.evaluate(&self.x, &mut out)?.into();

        // The output is NaN-initialized, so NaN in both slots means the
        // callback most likely never wrote anything. `first()` avoids an
        // out-of-bounds panic when there are no variables.
        let gx_unset = out.gx.first().map_or(true, |g| g.is_nan());
        assert!(!out.fx.is_nan() || !gx_unset, "invalid evaluation: {:?}", out);

        self.fx = out.fx;
        // `out` is a fresh local buffer, so move its gradient instead of
        // cloning it.
        self.gx = out.gx;

        // Compute the L1 norm of the variables and add it to the objective value.
        if let Some(owlqn) = self.owlqn {
            self.fx += owlqn.x1norm(&self.x);
        }

        self.evaluated = true;
        self.neval += 1;

        Ok(())
    }

    /// Return the total number of successful evaluations.
    pub fn number_of_evaluation(&self) -> usize {
        self.neval
    }

    /// Test whether `evaluate` has completed successfully at least once.
    pub fn evaluated(&self) -> bool {
        self.evaluated
    }

    /// Copy the position `x`, gradient `gx` and value `fx` from `src`.
    ///
    /// Other state (`data`, saved vectors, search direction, counters) is left
    /// untouched.
    ///
    /// # Panics
    /// Panics if `src` has a different number of variables.
    pub fn clone_from(&mut self, src: &Problem<U>) {
        self.x.copy_from_slice(&src.x);
        self.gx.copy_from_slice(&src.gx);
        self.fx = src.fx;
    }

    /// Take a line step along the search direction.
    ///
    /// Starting from the saved position, compute `x <- xp + step * d`. Under
    /// OWL-QN, the new point is then projected onto the orthant of `xp`.
    pub fn take_line_step(&mut self, step: f64) {
        self.x.veccpy(&self.xp);
        self.x.vecadd(&self.d, step);

        // Choose the orthant for the new point: the current point is projected
        // onto the orthant of the previous one.
        if let Some(owlqn) = self.owlqn {
            owlqn.project(&mut self.x, &self.xp, &self.gp);
        }
    }

    /// Return the gradient norm.
    ///
    /// This is `||pg||` (pseudo-gradient) under OWL-QN, otherwise `||gx||`.
    pub fn gnorm(&self) -> f64 {
        if self.owlqn.is_some() {
            self.pg.vec2norm()
        } else {
            self.gx.vec2norm()
        }
    }

    /// Return the position vector norm `||x||`.
    pub fn xnorm(&self) -> f64 {
        self.x.vec2norm()
    }

    /// Return `true` when OWL-QN (L1 regularization) is enabled.
    pub fn orthantwise(&self) -> bool {
        self.owlqn.is_some()
    }

    /// Revert `x` and `gx` to the values stored by `save_state`.
    ///
    /// `fx` is not restored.
    pub fn revert(&mut self) {
        self.x.veccpy(&self.xp);
        self.gx.veccpy(&self.gp);
    }

    /// Store the current position and gradient vectors into `xp` and `gp`.
    pub fn save_state(&mut self) {
        self.xp.veccpy(&self.x);
        self.gp.veccpy(&self.gx);
    }

    /// Constrain the search direction for orthant-wise updates.
    ///
    /// Does nothing when OWL-QN is disabled.
    pub fn constrain_search_direction(&mut self) {
        if let Some(owlqn) = self.owlqn {
            owlqn.constrain(&mut self.d, &self.pg);
        }
    }

    /// Recompute the OWL-QN pseudo-gradient `pg` from the current `x` and `gx`.
    ///
    /// Does nothing when OWL-QN is disabled.
    pub fn update_owlqn_gradient(&mut self) {
        if let Some(owlqn) = self.owlqn {
            owlqn.pseudo_gradient(&mut self.pg, &self.x, &self.gx);
        }
    }
}
// problem:1 ends here

// [[file:../lbfgs.note::*progress][progress:1]]
/// Snapshot of the optimization progress at one iteration, for progress
/// monitors.
///
/// The slices `x` and `gx` are borrowed for `'a` rather than copied.
// NOTE(review): `#[repr(C)]` gives no FFI guarantee here (`&[f64]` fields are
// not FFI-safe); it only pins the field order. Confirm it is still wanted.
#[repr(C)]
#[derive(Debug, Clone)]
pub(crate) struct Progress<'a> {
    /// The current values of variables
    pub x: &'a [f64],

    /// The current gradient values of variables.
    pub gx: &'a [f64],

    /// The current value of the objective function.
    ///
    /// When built by `Progress::new`, this is the problem's `fx`, which under
    /// OWL-QN already includes the L1 penalty.
    pub fx: f64,

    /// The Euclidean norm of the variables
    pub xnorm: f64,

    /// The Euclidean norm of the gradients.
    ///
    /// When built by `Progress::new` under OWL-QN, this is the norm of the
    /// pseudo-gradient.
    pub gnorm: f64,

    /// The line-search step used for this iteration.
    pub step: f64,

    /// The iteration count.
    pub niter: usize,

    /// The total number of evaluations.
    pub neval: usize,

    /// The number of function evaluation calls in line search procedure
    pub ncall: usize,
}

impl<'a> Progress<'a> {
    pub(crate) fn new<U>(prb: &'a Problem<U>, niter: usize, ncall: usize, step: f64) -> Self {
        Progress {
            x: &prb.x,
            gx: &prb.gx,
            fx: prb.fx,
            xnorm: prb.xnorm(),
            gnorm: prb.gnorm(),
            neval: prb.number_of_evaluation(),
            ncall,
            step,
            niter,
        }
    }
}
// progress:1 ends here

// [[file:../lbfgs.note::*progress/iter][progress/iter:1]]
/// Progress data produced in each minimization iteration, useful for progress
/// monitors.
///
/// `T` is the type of the extra data returned by the user-defined objective
/// closure.
#[derive(Debug, Clone)]
pub struct ProgressIter<T> {
    /// Current gradient vector norm.
    pub gnorm: f64,

    /// Current value of the objective function.
    pub fx: f64,

    /// The number of function calls made
    // NOTE(review): whether this counts calls in the current line search or
    // over the whole run is not visible here; confirm with the producer.
    pub ncalls: usize,

    /// The extra data returned from user defined closure for objective function
    /// evaluation
    pub extra: T,
}
// progress/iter:1 ends here

// [[file:../lbfgs.note::*orthantwise][orthantwise:1]]
/// Orthant-Wise Limited-memory Quasi-Newton (OWL-QN) algorithm
///
/// Holds the parameters of the L1 penalty that OWL-QN adds to the objective,
/// together with the helper routines implementing the orthant-wise
/// modifications: the penalty value, the pseudo-gradient, the orthant
/// projection and the search-direction constraint.
#[derive(Copy, Clone, Debug)]
pub struct Orthantwise {
    /// Coefficient for the L1 norm of variables.
    ///
    /// OWL-QN minimizes the objective function F(x) combined with the L1 norm
    /// |x| of the variables, `F(x) + C |x|`. This parameter is the coefficient
    /// for |x|, i.e., C. Because |x| is not differentiable at zero, the
    /// library adjusts the function value and gradients itself. A client
    /// program therefore only has to return F(x) and its gradient G(x) as
    /// usual. The default value is 1.
    pub c: f64,

    /// Start index (0-based, inclusive) of the variables included in the L1
    /// norm.
    ///
    /// Only variables with index `start <= i < end` are penalized:
    ///
    /// `|x| := |x[start]| + |x[start + 1]| + ... + |x[end - 1]|`
    ///
    /// Setting `start > 0` protects the variables `x[0], ..., x[start - 1]`
    /// from being regularized (e.g., a bias term of logistic regression).
    /// Must satisfy `0 <= start <= n`, where `n` is the number of variables.
    /// The default value is zero.
    pub start: i32,

    /// End index (0-based, exclusive) of the variables included in the L1
    /// norm.
    ///
    /// A negative value means "up to and including the last variable", i.e.
    /// `end = n`. Otherwise it must satisfy `end <= n`. The default value is -1.
    pub end: i32,
}

impl Default for Orthantwise {
    fn default() -> Self {
        Orthantwise {
            c: 1.0,
            start: 0,
            end: -1,
        }
    }
}

impl Orthantwise {
    /// Resolve `start`/`end` into a validated half-open index range for `x`.
    ///
    /// A negative `end` selects `x.len()`. `start > end` is accepted and
    /// yields an empty L1 range.
    ///
    /// # Panics
    /// Panics if `start` is negative, or if `start` or `end` exceeds `x.len()`.
    fn start_end(&self, x: &[f64]) -> (usize, usize) {
        use std::convert::TryFrom;

        let n = x.len();
        // Checked conversion: `as usize` would wrap a negative start to a huge
        // index, silently skipping the penalty in some routines while causing
        // an out-of-bounds panic in others.
        let start = usize::try_from(self.start).unwrap_or_else(|_| {
            panic!("orthantwise start index must be non-negative, got {}", self.start)
        });
        // A negative `end` is the "until the last variable" sentinel.
        let end = usize::try_from(self.end).unwrap_or(n);
        assert!(
            start <= n,
            "orthantwise start index {} exceeds the number of variables {}",
            start,
            n
        );
        assert!(
            end <= n,
            "orthantwise end index {} exceeds the number of variables {}",
            end,
            n
        );

        (start, end)
    }

    /// Compute the L1 penalty `c * (|x[start]| + ... + |x[end - 1]|)`.
    pub(crate) fn x1norm(&self, x: &[f64]) -> f64 {
        let (start, end) = self.start_end(x);

        // `take`/`skip` (rather than slicing) keeps `start > end` an empty
        // range. The summation order is kept, so results are bit-identical.
        x.iter()
            .take(end)
            .skip(start)
            .fold(0.0, |s, xi| s + self.c * xi.abs())
    }

    /// Compute the pseudo-gradient `pg` of `F(x) + c * |x|` from the raw
    /// gradient `g`.
    ///
    /// Outside `start..end` the pseudo-gradient equals the raw gradient.
    /// Inside, the penalty is differentiable for `x[i] != 0`, and its slope
    /// `±c` is added. At `x[i] == 0`, the right partial derivative `g + c` is
    /// used when it is negative, the left one `g - c` when it is positive, and
    /// otherwise the component is 0.
    pub(crate) fn pseudo_gradient(&self, pg: &mut [f64], x: &[f64], g: &[f64]) {
        let (start, end) = self.start_end(x);
        let c = self.c;

        // Variables before the L1 range: plain gradient.
        for i in 0..start {
            pg[i] = g[i];
        }

        // Compute the pseudo-gradients.
        for i in start..end {
            if x[i] < 0.0 {
                // Differentiable.
                pg[i] = g[i] - c;
            } else if 0.0 < x[i] {
                pg[i] = g[i] + c;
            } else if g[i] < -c {
                // Take the right partial derivative.
                pg[i] = g[i] + c;
            } else if c < g[i] {
                // Take the left partial derivative.
                pg[i] = g[i] - c;
            } else {
                pg[i] = 0.;
            }
        }

        // Variables after the L1 range: plain gradient.
        for i in end..g.len() {
            pg[i] = g[i];
        }
    }

    /// Choose the orthant for the new point.
    ///
    /// During the line search, each search point is projected onto the orthant
    /// of the previous point. Within `start..end`, the reference sign is
    /// `xp[i]`, or `-gp[i]` when `xp[i] == 0`. Any `x[i]` whose product with
    /// that sign is not positive is set to 0.
    pub(crate) fn project(&self, x: &mut [f64], xp: &[f64], gp: &[f64]) {
        let (start, end) = self.start_end(xp);

        for i in start..end {
            let sign = if xp[i] == 0.0 { -gp[i] } else { xp[i] };
            if x[i] * sign <= 0.0 {
                x[i] = 0.0
            }
        }
    }

    /// Constrain the search direction `d`.
    ///
    /// Within `start..end`, every component that does not point strictly
    /// opposite to the pseudo-gradient (`d[i] * pg[i] >= 0`) is set to 0.
    pub(crate) fn constrain(&self, d: &mut [f64], pg: &[f64]) {
        let (start, end) = self.start_end(pg);

        for i in start..end {
            if d[i] * pg[i] >= 0.0 {
                d[i] = 0.0;
            }
        }
    }
}
// orthantwise:1 ends here
