use crate::{
    error::Error,
    simulations::model::{
        calculations::{add_vecs, mul_vecs, std_deviation, std_max, std_mean, std_min, step_vecs},
        params::GenParams,
    },
    Result,
};
use rand_distr::{Distribution, Normal, Pert};
use serde::{Deserialize, Serialize};

/// Identifies the role a [`Node`] plays in the model tree.
///
/// The name is stored in `Node::name`, is what `Node::search` compares
/// against, and selects how `Node::self_calc` combines the values of a node's
/// two children.
#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)]
pub enum NodeNames {
    /// Derived node: `self_calc` combines its two children with `mul_vecs`.
    Risk,
    /// Derived node: `self_calc` combines its two children with `mul_vecs`.
    LossEventFrequency,
    /// Derived node: `self_calc` combines its two children with `mul_vecs`.
    ThreatEventFrequency,
    /// Derived node: `self_calc` combines its two children with `step_vecs`.
    Vulnerability,
    // The variants below without a combination rule are rejected by
    // `self_calc` (`Error::IncorrectModel`), so they can only obtain values
    // through generation parameters.
    ContactFrequency,
    ProbabilityOfAction,
    ThreatCapability,
    ControlStrength,
    /// Derived node: `self_calc` combines its two children with `add_vecs`.
    LossMagnitude,
    PrimaryLoss,
    /// Derived node: `self_calc` combines its two children with `mul_vecs`.
    SecondaryLoss,
    SecondaryLossEventFrequency,
    SecondaryLossEventMagnitude,
}

impl NodeNames {
    pub fn from(w: &str) -> Result<NodeNames> {
        match w {
            "Loss Event Frequency" => Ok(NodeNames::LossEventFrequency),
            "Threat Event Frequency" => Ok(NodeNames::ThreatEventFrequency),
            "Vulnerability" => Ok(NodeNames::Vulnerability),
            "Contact Frequency" => Ok(NodeNames::ContactFrequency),
            "Probability Of Action" => Ok(NodeNames::ProbabilityOfAction),
            "Threat Capability" => Ok(NodeNames::ThreatCapability),
            "Control Strength" => Ok(NodeNames::ControlStrength),
            "Loss Magnitude" => Ok(NodeNames::LossMagnitude),
            "Primary Loss" => Ok(NodeNames::PrimaryLoss),
            "Secondary Loss" => Ok(NodeNames::SecondaryLoss),
            "Secondary Loss Event Frequency" => Ok(NodeNames::SecondaryLossEventFrequency),
            "Secondary Loss Event Magnitude" => Ok(NodeNames::SecondaryLossEventMagnitude),
            _ => Err(Error::InvalidInputField { e: w.to_string() }),
        }
    }
}

/// A node of the model tree, holding its sampled values and their summary
/// statistics.
///
/// A node obtains its values either from `gen_params` (when set) or, when
/// `gen_params` is `None`, by combining the values of its children
/// (see `Node::calc`).
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Node {
    /// Which model quantity this node represents.
    pub name: NodeNames,
    /// Child nodes; `Node::calc` requires exactly two of them when the node
    /// has no generation parameters.
    pub children: Vec<Node>,
    /// Parameters used to generate `values` directly; `None` means the values
    /// are derived from `children`.
    pub gen_params: Option<GenParams>,
    /// Values produced by the most recent calculation (empty on a new node).
    pub values: Vec<f32>,
    /// Mean of `values`; `None` until a calculation has run.
    pub mean: Option<f32>,
    /// Deviation of `values` as returned by `std_deviation`; `None` until a
    /// calculation has run.
    pub dev: Option<f32>,
    /// Minimum of `values`; `None` until a calculation has run.
    pub min: Option<f32>,
    /// Maximum of `values`; `None` until a calculation has run.
    pub max: Option<f32>,
}

impl Node {
    pub fn new(name: NodeNames, children: Vec<Node>) -> Node {
        Node {
            name,
            children,
            gen_params: None,
            values: vec![],
            mean: None,
            dev: None,
            min: None,
            max: None,
        }
    }

    pub fn set_param(&mut self, gp: GenParams) -> Result<()> {
        self.gen_params = Some(gp);
        Ok(())
    }

    pub fn search(&mut self, name: &NodeNames) -> Result<&mut Node> {
        if &self.name == name {
            return Ok(self);
        }
        for child in &mut self.children {
            if let Ok(e) = child.search(name) {
                return Ok(e);
            }
        }
        Err(Error::Generic("Node: Search Error".to_string()))
    }

    fn self_gen_pert(&mut self, n: u32, min: f32, mean: f32, max: f32) -> Result<()> {
        self.values.clear();
        let d = Pert::new(min, max, mean).unwrap();
        for _i in 0..n {
            self.values.push(d.sample(&mut rand::thread_rng()));
        }
        self.min = Some(std_min(&self.values).unwrap());
        self.max = Some(std_max(&self.values).unwrap());
        self.mean = Some(std_mean(&self.values).unwrap());
        self.dev = Some(std_deviation(&self.values, &self.mean.unwrap()).unwrap());
        Ok(())
    }

    fn self_gen_const(&mut self, n: u32, val: f32) -> Result<()> {
        self.values.clear();
        for _i in 0..n {
            self.values.push(val);
        }
        self.min = Some(std_min(&self.values).unwrap());
        self.max = Some(std_max(&self.values).unwrap());
        self.mean = Some(std_mean(&self.values).unwrap());
        self.dev = Some(std_deviation(&self.values, &self.mean.unwrap()).unwrap());
        Ok(())
    }

    fn self_gen_normal(&mut self, n: u32, mean: f32, dev: f32) -> Result<()> {
        self.values.clear();
        let d = Normal::new(mean, dev).unwrap();
        for _i in 0..n {
            self.values.push(d.sample(&mut rand::thread_rng()));
        }
        self.min = Some(std_min(&self.values).unwrap());
        self.max = Some(std_max(&self.values).unwrap());
        self.mean = Some(std_mean(&self.values).unwrap());
        self.dev = Some(std_deviation(&self.values, &self.mean.unwrap()).unwrap());
        Ok(())
    }

    fn self_calc(&mut self) -> Result<()> {
        if self.children.len() != 2 {
            return Err(Error::InvalidModel {
                n: self.name.to_owned(),
                c: self.children.len(),
            });
        }

        let res = match self.name {
            NodeNames::ThreatEventFrequency => mul_vecs(
                &mut self.values,
                &self.children[0].values,
                &self.children[1].values,
            ),
            NodeNames::LossEventFrequency => mul_vecs(
                &mut self.values,
                &self.children[0].values,
                &self.children[1].values,
            ),
            NodeNames::SecondaryLoss => mul_vecs(
                &mut self.values,
                &self.children[0].values,
                &self.children[1].values,
            ),
            NodeNames::Risk => mul_vecs(
                &mut self.values,
                &self.children[0].values,
                &self.children[1].values,
            ),
            NodeNames::LossMagnitude => add_vecs(
                &mut self.values,
                &self.children[0].values,
                &self.children[1].values,
            ),
            NodeNames::Vulnerability => step_vecs(
                &mut self.values,
                &self.children[0].values,
                &self.children[1].values,
            ),
            _ => Err(Error::IncorrectModel {
                n: self.name.to_owned(),
            }),
        };
        self.min = Some(std_min(&self.values).unwrap());
        self.max = Some(std_max(&self.values).unwrap());
        self.mean = Some(std_mean(&self.values).unwrap());
        self.dev = Some(std_deviation(&self.values, &self.mean.unwrap()).unwrap());
        res
    }

    pub fn calc(&mut self, n: u32) -> Result<(f32, f32, f32, f32)> {
        let s = match self.gen_params {
            None => {
                for child in &mut self.children {
                    match child.calc(n) {
                        Ok(_) => continue,
                        Err(e) => return Err(e),
                    }
                }
                self.self_calc()
            }
            Some(GenParams::PertParams { min, mean, max }) => self.self_gen_pert(n, min, mean, max),
            Some(GenParams::ConstParams { val }) => self.self_gen_const(n, val),
            Some(GenParams::NormalParams { mean, dev }) => self.self_gen_normal(n, mean, dev),
        };
        match s {
            Err(e) => Err(e),
            Ok(_) => Ok((
                self.min.unwrap(),
                self.max.unwrap(),
                self.mean.unwrap(),
                self.dev.unwrap(),
            )),
        }
    }
}
