/*
    Nyx, blazing fast astrodynamics
    Copyright (C) 2021 Christopher Rabotin <christopher.rabotin@gmail.com>

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU Affero General Public License as published
    by the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU Affero General Public License for more details.

    You should have received a copy of the GNU Affero General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.
*/

use super::crossbeam::thread;
use super::error_ctrl::{ErrorCtrl, RSSCartesianStep};
use super::{IntegrationDetails, RK, RK89};
use crate::dimensions::allocator::Allocator;
use crate::dimensions::{DefaultAllocator, VectorN};
use crate::dynamics::Dynamics;
use crate::errors::NyxError;
use crate::md::trajectory::Traj;
use crate::md::EventEvaluator;
use crate::time::{Duration, Epoch, TimeUnit};
use crate::{State, TimeTagged};
use std::f64;
use std::sync::mpsc::{channel, Sender};
use std::sync::Arc;

/// A Propagator holds everything needed to propagate a set of dynamics forward or backward in time:
/// the dynamics themselves, the integration options, and the order, number of stages and
/// coefficients of the Runge-Kutta method selected at construction. The actual propagation is
/// performed by the `PropInstance` returned by `Propagator::with`.
#[derive(Clone, Debug)]
pub struct Propagator<'a, D: Dynamics, E: ErrorCtrl>
where
    DefaultAllocator: Allocator<f64, <D::StateType as State>::Size>
        + Allocator<f64, <D::StateType as State>::PropVecSize>,
{
    /// Stores the dynamics used. *Must* use this to get the latest values
    pub dynamics: Arc<D>,
    /// Stores the integration options (tolerance, min/max step, init step, etc.)
    pub opts: PropOpts<E>,
    /// Order of the integrator (drives the step size adaptation exponent)
    order: u8,
    /// Number of stages, i.e. how many times the derivatives are evaluated per attempted step
    stages: usize,
    /// `a_ij` Butcher tableau coefficients, flattened row by row: the row used to build stage
    /// `i + 1` (0-indexed) holds `i + 1` entries, consumed in order by the integration step.
    a_coeffs: &'a [f64],
    /// The `stages` weights `b_i` of the propagated solution, followed by the `stages` weights
    /// `b*_i` of the embedded solution used for the adaptive step error estimate.
    b_coeffs: &'a [f64],
}

/// Construction and configuration of a `Propagator`, and creation of the `PropInstance`s that
/// perform the actual propagation.
impl<'a, D: Dynamics, E: ErrorCtrl> Propagator<'a, D, E>
where
    DefaultAllocator: Allocator<f64, <D::StateType as State>::Size>
        + Allocator<f64, <D::StateType as State>::PropVecSize>,
{
    /// Builds a propagator for `dynamics` which integrates with the Runge-Kutta method `T`
    /// (its order, number of stages and Butcher tableau coefficients) using the options `opts`.
    pub fn new<T: RK>(dynamics: Arc<D>, opts: PropOpts<E>) -> Self {
        Self {
            dynamics,
            opts,
            stages: T::stages(),
            order: T::order(),
            a_coeffs: T::a_coeffs(),
            b_coeffs: T::b_coeffs(),
        }
    }

    /// Set the tolerance for the propagator
    pub fn set_tolerance(&mut self, tol: f64) {
        self.opts.tolerance = tol;
    }

    /// Set the maximum step size for the propagator
    pub fn set_max_step(&mut self, step: Duration) {
        self.opts.max_step = step;
    }

    /// An RK89 propagator (the default) with custom propagator options.
    pub fn rk89(dynamics: Arc<D>, opts: PropOpts<E>) -> Self {
        Self::new::<RK89>(dynamics, opts)
    }

    /// Creates a new propagation instance which starts from `state`.
    ///
    /// The instance starts with the initial step size and step kind (fixed or adaptive) of this
    /// propagator's options, and without any output channel.
    pub fn with(&'a self, state: D::StateType) -> PropInstance<'a, D, E> {
        // Pre-allocate one k_i vector per integrator stage: they are overwritten on every step.
        let k = vec![VectorN::<f64, <D::StateType as State>::PropVecSize>::zeros(); self.stages];
        PropInstance {
            state,
            prop: Arc::new(self),
            tx_chan: None,
            prevent_tx: false,
            details: IntegrationDetails {
                step: self.opts.init_step,
                error: 0.0,
                attempts: 1,
            },
            step_size: self.opts.init_step,
            fixed_step: self.opts.fixed_step,
            k,
        }
    }
}

impl<'a, D: Dynamics> Propagator<'a, D, RSSCartesianStep>
where
    DefaultAllocator: Allocator<f64, <D::StateType as State>::Size>
        + Allocator<f64, <D::StateType as State>::PropVecSize>,
{
    /// Builds the default propagator for `dynamics`: an RK89 integrator configured with the
    /// default `PropOpts` (RSS Cartesian step error control).
    pub fn default(dynamics: Arc<D>) -> Self {
        let default_opts = PropOpts::default();
        Self::rk89(dynamics, default_opts)
    }
}

/// A `PropInstance` is a single propagation run created from a `Propagator` (see
/// `Propagator::with`). It holds the state being propagated, a reference to the propagator setup,
/// the details of the latest integration step, and an optional channel on which every newly
/// computed state is published.
#[derive(Debug)]
pub struct PropInstance<'a, D: Dynamics, E: ErrorCtrl>
where
    DefaultAllocator: Allocator<f64, <D::StateType as State>::Size>
        + Allocator<f64, <D::StateType as State>::PropVecSize>,
{
    /// The state of this propagator instance
    pub state: D::StateType,
    /// The propagator setup (kind, stages, etc.)
    pub prop: Arc<&'a Propagator<'a, D, E>>,
    /// An output channel for all of the states computed by this propagator instance
    pub tx_chan: Option<Sender<D::StateType>>,
    /// Stores the details of the previous integration step
    pub details: IntegrationDetails,
    /// When true, computed states are not published even if `tx_chan` is set
    prevent_tx: bool,
    /// Step size to attempt on the _next_ integration step (updated by the adaptive step control)
    step_size: Duration,
    /// When true, `step_size` is used as-is, without any error control
    fixed_step: bool,
    /// Pre-allocated k_i vectors (one per integrator stage), reused across integration steps
    k: Vec<VectorN<f64, <D::StateType as State>::PropVecSize>>,
}

impl<'a, D: Dynamics, E: ErrorCtrl> PropInstance<'a, D, E>
where
    DefaultAllocator: Allocator<f64, <D::StateType as State>::Size>
        + Allocator<f64, <D::StateType as State>::PropVecSize>
        + Allocator<f64, <D::StateType as State>::Size>,
{
    /// Sets the step size used for the next integration step.
    ///
    /// When `fixed` is true, `step_size` is used as-is for every step (no error control);
    /// otherwise it is only the first step attempted and it will be adapted by the error control.
    pub fn set_step(&mut self, step_size: Duration, fixed: bool) {
        self.step_size = step_size;
        self.fixed_step = fixed;
    }

    /// Set the output channel of the propagator: every state computed during the propagation is
    /// sent on it. For example use this to generate an interpolated trajectory.
    pub fn with_tx(mut self, tx: Sender<D::StateType>) -> Self {
        self.tx_chan = Some(tx);
        self
    }

    /// Returns the state of the propagation as a vector.
    ///
    /// WARNING: Do not use the dynamics to get the state, it will be the initial value!
    ///
    /// # Panics
    /// Panics if the state cannot be converted into a vector (i.e. if `as_vector` fails).
    pub fn state_vector(&self) -> VectorN<f64, <D::StateType as State>::PropVecSize> {
        self.state.as_vector().unwrap()
    }

    /// This method propagates the provided Dynamics for the provided duration and returns the
    /// final state. A negative duration propagates backward in time.
    ///
    /// The last step is shortened so that the propagation ends exactly at the requested epoch.
    /// The step size of this instance is left as it was before the call on every exit path,
    /// including when an error interrupts the propagation.
    ///
    /// # Errors
    /// Returns any error raised by the dynamics (`eom`, `finally`) or when updating the state.
    #[allow(clippy::erasing_op)] // `0 * TimeUnit::Second` is the intended zero duration
    pub fn for_duration(&mut self, duration: Duration) -> Result<D::StateType, NyxError> {
        if duration == 0 * TimeUnit::Second {
            debug!("No propagation necessary");
            return Ok(self.state);
        }
        if duration > 2 * TimeUnit::Minute || duration < -2 * TimeUnit::Minute {
            // Prevent the print spam for EKF orbit determination cases
            info!("Propagating for {}", duration);
        }
        let backprop = duration < TimeUnit::Nanosecond;
        if backprop {
            self.step_size = -self.step_size; // Invert the step size
        }
        let stop_time = self.state.epoch() + duration;
        let outcome = self.propagate_until(stop_time, backprop);
        if backprop {
            // Restore to a positive step size, whether or not the propagation succeeded
            self.step_size = -self.step_size;
        }
        outcome
    }

    /// Integrates from the current state until `stop_time`, taking a final fixed step of exactly
    /// the remaining duration so that the last state lands on `stop_time`.
    ///
    /// `backprop` must be true when `stop_time` is before the current epoch (the step size is then
    /// expected to be negative).
    fn propagate_until(
        &mut self,
        stop_time: Epoch,
        backprop: bool,
    ) -> Result<D::StateType, NyxError> {
        // Take regular steps until the next one would reach or pass the stop time.
        let dt = loop {
            let dt = self.state.epoch();
            let step_reaches_stop = if backprop {
                dt + self.step_size <= stop_time
            } else {
                dt + self.step_size > stop_time
            };
            if step_reaches_stop {
                break dt;
            }
            self.advance_one_step()?;
        };

        if stop_time == dt {
            // No propagation necessary
            return Ok(self.state);
        }

        // Take one final step of exactly the needed duration until the stop time
        let prev_step_size = self.step_size;
        let prev_step_kind = self.fixed_step;
        self.set_step(stop_time - dt, true);
        let final_step = self.advance_one_step();
        // Restore the step size for subsequent calls, even if the final step failed
        self.set_step(prev_step_size, prev_step_kind);
        final_step?;
        Ok(self.state)
    }

    /// Performs a single integration step from the current state, applies the dynamics'
    /// `finally` hook to the new state, stores it and publishes it on the output channel.
    fn advance_one_step(&mut self) -> Result<(), NyxError> {
        let (t, state_vec) = self.derive()?;
        self.state.set(self.state.epoch() + t, &state_vec)?;
        self.state = self.prop.dynamics.finally(self.state)?;
        self.publish_state();
        Ok(())
    }

    /// Sends the current state on the output channel, if one is set and publishing is allowed.
    /// A failed send is only logged: publishing is best-effort and never stops the propagation.
    fn publish_state(&self) {
        if self.prevent_tx {
            return;
        }
        if let Some(chan) = &self.tx_chan {
            if let Err(e) = chan.send(self.state) {
                warn!("could not publish to channel: {}", e)
            }
        }
    }

    /// Propagates the provided Dynamics until the provided epoch. Returns the end state.
    pub fn until_epoch(&mut self, end_time: Epoch) -> Result<D::StateType, NyxError> {
        let duration: Duration = end_time - self.state.epoch();
        self.for_duration(duration)
    }

    /// Propagates the provided Dynamics for the provided duration and generate the trajectory of these dynamics on its own thread.
    /// Returns the end state and the trajectory.
    ///
    /// The output channel set with `with_tx` (if any) does not receive the states computed during
    /// this call, and is restored once the propagation is over.
    ///
    /// Known bug #190: Cannot generate a valid trajectory when propagating backward
    pub fn for_duration_with_traj(
        &mut self,
        duration: Duration,
    ) -> Result<(D::StateType, Traj<D::StateType>), NyxError> {
        thread::scope(|s| {
            let (tx, rx) = channel();
            // Route the computed states to the trajectory builder for the duration of this call.
            let prev_chan = self.tx_chan.replace(tx);
            let start_state = self.state;
            // The trajectory must always be generated on its own thread.
            let traj_thread = s.spawn(move |_| Traj::new(start_state, rx));
            let prop_result = self.for_duration(duration);
            // Drop our sender whatever the outcome: this closes the channel so the trajectory
            // thread can complete, and prevents later propagations from publishing into a channel
            // whose receiver is gone.
            self.tx_chan = prev_chan;
            let end_state = prop_result?;

            let traj = traj_thread.join().unwrap_or_else(|_| {
                Err(NyxError::NoInterpolationData(
                    "Could not generate trajectory".to_string(),
                ))
            })?;

            Ok((end_state, traj))
        })
        .unwrap()
    }

    /// Propagates the provided Dynamics until the provided epoch and generate the trajectory of these dynamics on its own thread.
    /// Returns the end state and the trajectory.
    /// Known bug #190: Cannot generate a valid trajectory when propagating backward
    pub fn until_epoch_with_traj(
        &mut self,
        end_time: Epoch,
    ) -> Result<(D::StateType, Traj<D::StateType>), NyxError> {
        let duration: Duration = end_time - self.state.epoch();
        self.for_duration_with_traj(duration)
    }

    /// Propagate for `max_duration` and search the resulting trajectory for `event`.
    ///
    /// Returns the state of the occurrence at index `trigger` (zero-based) among all the found
    /// occurrences, and the trajectory until `max_duration`.
    ///
    /// # Errors
    /// Returns `NyxError::UnsufficientTriggers` if fewer than `trigger + 1` occurrences are found,
    /// or any error from the propagation or the event search.
    pub fn until_event<F: EventEvaluator<D::StateType>>(
        &mut self,
        max_duration: Duration,
        event: &F,
        trigger: usize,
    ) -> Result<(D::StateType, Traj<D::StateType>), NyxError> {
        info!("Searching for {}", event);

        let (_, traj) = self.for_duration_with_traj(max_duration)?;
        // Now, find the requested event
        let events = traj.find_all(event)?;
        match events.get(trigger) {
            Some(event_state) => Ok((*event_state, traj)),
            None => Err(NyxError::UnsufficientTriggers(trigger, events.len())),
        }
    }

    /// This method integrates whichever function is provided as `d_xdt`. Everything passed to this function is in **seconds**.
    ///
    /// This function returns the step sized used (as a Duration) and the new state as y_{n+1} = y_n + \frac{dy_n}{dt}.
    /// To get the integration details, check `self.latest_details`.
    ///
    /// The step size may be negative (backward propagation): the adaptive error control works on
    /// the magnitude of the step and preserves its direction.
    fn derive(
        &mut self,
    ) -> Result<(Duration, VectorN<f64, <D::StateType as State>::PropVecSize>), NyxError> {
        let state = &self.state_vector();
        let ctx = &self.state;
        // Reset the number of attempts used (we don't reset the error because it's set before it's read)
        self.details.attempts = 1;
        // Signed step size in seconds -- it's mutable because the error control may change it below
        let mut step_size = self.step_size.in_seconds();
        // The step adaptation works on the magnitude of the step and re-applies this direction.
        // Comparing the signed step directly would make every backward step look smaller than
        // `min_step` (hence always accepted, whatever its error), skip the `max_step` clamp, and
        // turn a shrunk backward step into a forward one.
        let direction = if step_size < 0.0 { -1.0 } else { 1.0 };
        let min_step = self.prop.opts.min_step.in_seconds();
        let max_step = self.prop.opts.max_step.in_seconds();
        let tolerance = self.prop.opts.tolerance;
        loop {
            let ki = self.prop.dynamics.eom(0.0, state, ctx)?;
            self.k[0] = ki;
            let mut a_idx: usize = 0;
            for i in 0..(self.prop.stages - 1) {
                // Let's compute the c_i by summing the relevant items from the list of coefficients.
                // \sum_{j=1}^{i-1} a_ij  ∀ i ∈ [2, s]
                let mut ci: f64 = 0.0;
                // The wi stores the a_{s1} * k_1 + a_{s2} * k_2 + ... + a_{s, s-1} * k_{s-1} +
                let mut wi =
                    VectorN::<f64, <D::StateType as State>::PropVecSize>::from_element(0.0);
                for kj in &self.k[0..i + 1] {
                    let a_ij = self.prop.a_coeffs[a_idx];
                    ci += a_ij;
                    wi += a_ij * kj;
                    a_idx += 1;
                }

                let ki = self
                    .prop
                    .dynamics
                    .eom(ci * step_size, &(state + step_size * wi), ctx)?;
                self.k[i + 1] = ki;
            }
            // Compute the next state and the error
            let mut next_state = state.clone();
            // State error estimation from https://en.wikipedia.org/wiki/Runge%E2%80%93Kutta_methods#Adaptive_Runge%E2%80%93Kutta_methods
            // This is consistent with GMAT https://github.com/ChristopherRabotin/GMAT/blob/37201a6290e7f7b941bc98ee973a527a5857104b/src/base/propagator/RungeKutta.cpp#L537
            let mut error_est =
                VectorN::<f64, <D::StateType as State>::PropVecSize>::from_element(0.0);
            for (i, ki) in self.k.iter().enumerate() {
                let b_i = self.prop.b_coeffs[i];
                if !self.fixed_step {
                    let b_i_star = self.prop.b_coeffs[i + self.prop.stages];
                    error_est += step_size * (b_i - b_i_star) * ki;
                }
                next_state += step_size * b_i * ki;
            }

            if self.fixed_step {
                // Using a fixed step, no adaptive step necessary
                self.details.step = self.step_size;
                return Ok((self.details.step, next_state));
            }

            // Compute the error estimate.
            self.details.error = E::estimate(&error_est, &next_state, &state);
            let step_magnitude = step_size.abs();
            if self.details.error <= tolerance
                || step_magnitude <= min_step
                || self.details.attempts >= self.prop.opts.attempts
            {
                if self.details.attempts >= self.prop.opts.attempts {
                    warn!(
                        "Could not further decrease step size: maximum number of attempts reached ({})",
                        self.details.attempts
                    );
                }

                self.details.step = step_size * TimeUnit::Second;
                if self.details.error < tolerance {
                    // Error is less than tolerance, let's attempt to increase the step for the next iteration.
                    let proposed_step = 0.9
                        * step_magnitude
                        * (tolerance / self.details.error)
                            .powf(1.0 / f64::from(self.prop.order));
                    step_size = direction * proposed_step.min(max_step);
                }
                // In all cases, let's update the step size to whatever was the adapted step size
                self.step_size = step_size * TimeUnit::Second;
                return Ok((self.details.step, next_state));
            }

            // Error is too high and we aren't using the smallest step, and we haven't hit the max number of attempts.
            // So let's adapt the step size.
            self.details.attempts += 1;
            let proposed_step = 0.9
                * step_magnitude
                * (tolerance / self.details.error).powf(1.0 / f64::from(self.prop.order - 1));
            step_size = direction * proposed_step.max(min_step);
            // Note that we don't set self.step_size, that will be updated right before we return
        }
    }

    /// Borrow the details of the latest integration step.
    pub fn latest_details(&self) -> &IntegrationDetails {
        &self.details
    }
}

/// PropOpts stores the integrator options, including the minimum and maximum step sizes, and the
/// maximum acceptable error (tolerance).
///
/// Note that different step sizes and tolerances are only used by adaptive
/// methods. To use a fixed step integrator, initialize the options using `with_fixed_step`, and
/// use whichever adaptive step integrator is desired.  For example, initializing an RK45 with
/// fixed step options will lead to an RK4 being used instead of an RK45.
#[derive(Clone, Copy, Debug)]
pub struct PropOpts<E: ErrorCtrl> {
    // Step size of the first integration step of a new propagation instance.
    init_step: Duration,
    // Smallest step the adaptive step control may shrink to; a step this small is always accepted.
    min_step: Duration,
    // Largest step the adaptive step control may grow to.
    max_step: Duration,
    // Maximum error estimate for a step to be accepted (unused with a fixed step).
    tolerance: f64,
    // Maximum number of attempts per step before the step is accepted regardless of its error.
    attempts: u8,
    // When true, the step size is never adapted.
    fixed_step: bool,
    // Error control method. NOTE(review): only the type `E` (its `estimate` function) appears to
    // be used by the propagator; the stored value itself is not read in this module.
    errctrl: E,
}

impl<E: ErrorCtrl> PropOpts<E> {
    /// Builds options for an adaptive step integration bounded by `min_step` and `max_step`,
    /// with error tolerance `tolerance` evaluated by the error control `errctrl`.
    ///
    /// The integration starts with the largest allowed step, and each step is attempted at most
    /// 50 times (as in GMAT).
    pub fn with_adaptive_step(
        min_step: Duration,
        max_step: Duration,
        tolerance: f64,
        errctrl: E,
    ) -> Self {
        let init_step = max_step;
        Self {
            errctrl,
            fixed_step: false,
            attempts: 50,
            tolerance,
            max_step,
            min_step,
            init_step,
        }
    }

    /// Same as `with_adaptive_step`, with the step sizes expressed in seconds.
    pub fn with_adaptive_step_s(min_step: f64, max_step: f64, tolerance: f64, errctrl: E) -> Self {
        let seconds = |value: f64| value * TimeUnit::Second;
        Self::with_adaptive_step(seconds(min_step), seconds(max_step), tolerance, errctrl)
    }

    /// Returns a one-line, human readable summary of the step bounds, tolerance and attempts.
    pub fn info(&self) -> String {
        let (min, max, tol, attempts) =
            (self.min_step, self.max_step, self.tolerance, self.attempts);
        format!(
            "[min_step: {:.e}, max_step: {:.e}, tol: {:.e}, attempts: {}]",
            min, max, tol, attempts
        )
    }
}

impl PropOpts<RSSCartesianStep> {
    /// `with_fixed_step` initializes a `PropOpts` such that the integrator is used with a fixed
    /// step size of `step`: the initial, minimum and maximum steps are all `step`, and no error
    /// control is performed (zero tolerance, zero attempts).
    pub fn with_fixed_step(step: Duration) -> Self {
        PropOpts {
            init_step: step,
            min_step: step,
            max_step: step,
            tolerance: 0.0,
            fixed_step: true,
            attempts: 0,
            errctrl: RSSCartesianStep {},
        }
    }

    /// Same as `with_fixed_step`, with the step size expressed in seconds.
    pub fn with_fixed_step_s(step: f64) -> Self {
        Self::with_fixed_step(step * TimeUnit::Second)
    }

    /// Returns the default options with a specific tolerance.
    pub fn with_tolerance(tolerance: f64) -> Self {
        Self {
            tolerance,
            ..Self::default()
        }
    }
}

impl Default for PropOpts<RSSCartesianStep> {
    /// Returns the same default options as GMAT: an adaptive step starting at 60 s, bounded
    /// between 1 ms and 2700 s, with a 1e-12 tolerance and at most 50 attempts per step, using the
    /// RSS Cartesian step error control.
    fn default() -> Self {
        Self {
            errctrl: RSSCartesianStep {},
            fixed_step: false,
            attempts: 50,
            tolerance: 1e-12,
            max_step: 2700.0 * TimeUnit::Second,
            min_step: 0.001 * TimeUnit::Second,
            init_step: 60.0 * TimeUnit::Second,
        }
    }
}

#[test]
fn test_options() {
    use super::error_ctrl::RSSStep;

    // Fixed step: all step sizes are the requested step, and there is no error control.
    let opts = PropOpts::with_fixed_step_s(1e-1);
    assert_eq!(opts.init_step, 1e-1 * TimeUnit::Second);
    assert_eq!(opts.min_step, 1e-1 * TimeUnit::Second);
    assert_eq!(opts.max_step, 1e-1 * TimeUnit::Second);
    assert!(opts.tolerance.abs() < std::f64::EPSILON);
    assert_eq!(opts.attempts, 0);
    assert!(opts.fixed_step);

    // Adaptive step: starts at the maximum step, with 50 attempts per step.
    let opts = PropOpts::with_adaptive_step_s(1e-2, 10.0, 1e-12, RSSStep {});
    assert_eq!(opts.init_step, 10.0 * TimeUnit::Second);
    assert_eq!(opts.min_step, 1e-2 * TimeUnit::Second);
    assert_eq!(opts.max_step, 10.0 * TimeUnit::Second);
    assert!((opts.tolerance - 1e-12).abs() < std::f64::EPSILON);
    assert_eq!(opts.attempts, 50);
    assert!(!opts.fixed_step);

    // GMAT-like defaults.
    let opts: PropOpts<RSSCartesianStep> = Default::default();
    assert_eq!(opts.init_step, 60.0 * TimeUnit::Second);
    assert_eq!(opts.min_step, 0.001 * TimeUnit::Second);
    assert_eq!(opts.max_step, 2700.0 * TimeUnit::Second);
    assert!((opts.tolerance - 1e-12).abs() < std::f64::EPSILON);
    assert_eq!(opts.attempts, 50);
    assert!(!opts.fixed_step);

    // `with_tolerance` only overrides the tolerance of the defaults.
    let opts = PropOpts::with_tolerance(1e-9);
    assert!((opts.tolerance - 1e-9).abs() < std::f64::EPSILON);
    assert_eq!(opts.init_step, 60.0 * TimeUnit::Second);
    assert_eq!(opts.min_step, 0.001 * TimeUnit::Second);
    assert_eq!(opts.max_step, 2700.0 * TimeUnit::Second);
    assert_eq!(opts.attempts, 50);
    assert!(!opts.fixed_step);
}
