use pyo3::prelude::*;
use pyo3::{exceptions, PyResult};
use pyo3::wrap_pyfunction;
use std::ops::{Deref, DerefMut};

use rust_decimal::prelude::Decimal as RustDecimal;
use rust_decimal::prelude::ToPrimitive;

/// Python-visible `Decimal` type: a thin newtype around `rust_decimal::Decimal`.
///
/// The wrapper is `Copy`, so the exposed methods take and return it by value.
/// Equality and hashing delegate to the wrapped `rust_decimal` value.
#[pyclass(name = "Decimal")]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct Decimal(RustDecimal);


#[pymethods]
impl Decimal {
    #[new]
    pub fn new(num: i64, scale: u32) -> Decimal {
        Self(RustDecimal::new(num, scale))
    }

    pub const fn scale(&self) -> u32 {
        self.0.scale()
    }

    pub const fn mantissa(&self) -> i128 {
        self.0.mantissa()
    }

    pub const fn is_zero(&self) -> bool {
        self.0.is_zero()
    }

    pub fn set_sign_positive(&mut self, positive: bool) {
        self.0.set_sign_positive(positive)
    }

    //#[inline(always)]
    pub fn set_sign_negative(&mut self, negative: bool) {
        self.0.set_sign_negative(negative)
    }
    
    pub fn set_scale(&mut self, scale: u32) -> PyResult<()> {
        let result = self.0.set_scale(scale);
        match result {
            Ok(v) => {
                Ok(v)
            },
            Err(_) => {
                Err(exceptions::PyRuntimeError::new_err("set_scale Error"))
            }
        }
    }

    pub fn rescale(&mut self, scale: u32) {
        self.0.rescale(scale)
    }

    pub const fn is_sign_negative(&self) -> bool {
        self.0.is_sign_negative()
    }

    pub const fn is_sign_positive(&self) -> bool {
        self.0.is_sign_positive()
    }

    pub fn trunc(&self) -> Decimal {
        self.0.trunc().into()
    }

    pub fn fract(&self) -> Decimal {
        self.0.fract().into()
    }

    pub fn abs(&self) -> Decimal {
        self.0.abs().into()
    }

    pub fn floor(&self) -> Decimal {
        self.0.floor().into()
    }

    pub fn ceil(&self) -> Decimal {
        self.0.ceil().into()
    }

    pub fn max(&self, other: Decimal) -> Decimal {
        self.0.max(other.0).into()
    }

    pub fn min(&self, other: Decimal) -> Decimal {
        self.0.min(other.0).into()
    }

    pub fn normalize(&self) -> Decimal {
        self.0.normalize().into()
    }

    pub fn normalize_assign(&mut self) {
        self.0.normalize_assign()
    }

    pub fn round(&self) -> Decimal {
        self.0.round().into()
    }

    pub fn round_dp(&self, dp: u32) -> Decimal {
        self.0.round_dp(dp).into()
    }

    pub fn round_sf(&self, digits: u32) -> Option<Decimal> {
        let decimal = self.0.round_sf(digits);
        if decimal.is_some() {
            Some(decimal.unwrap().into())
        } else {
            None
        }
    }

    pub fn to_int(&self) -> i64 {
        self.0.to_i64().unwrap()
    }

    pub fn to_float(&self) -> f64 {
        self.0.to_f64().unwrap()
    }

    fn __add__(&self, other: Decimal) -> PyResult<Decimal> {
        Ok((self.0 + other.0).into())
    }

    fn __sub__(&self, other: Decimal) -> PyResult<Decimal> {
        Ok((self.0 - other.0).into())
    }

    fn __mult__(&self, other: Decimal) -> PyResult<Decimal> {
        Ok((self.0 * other.0).into())
    }

    fn __mod__(&self, other: Decimal) -> PyResult<Decimal> {
        Ok((self.0 / other.0).into())
    }

    fn __divmod__(&self, other: Decimal) -> PyResult<Decimal> {
        Ok((self.0 / other.0).into())
    }

    fn __str__(&self) -> PyResult<String> {
        Ok(self.to_string())
    }

    fn __repr__(&self) -> PyResult<String> {
        Ok(format!("Decimal({})", self.to_string()))
    }
}


impl Deref for Decimal {
    type Target = RustDecimal;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl DerefMut for Decimal {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}

/// Unwraps the Python-facing wrapper back into the underlying `rust_decimal`
/// value. Implementing `From` (rather than `Into`) also provides
/// `Into<RustDecimal> for Decimal` through the standard blanket impl.
impl From<Decimal> for RustDecimal {
    fn from(value: Decimal) -> RustDecimal {
        value.0
    }
}

/// Wraps a `rust_decimal` value in the Python-facing `Decimal` type.
/// Implementing `From` (rather than `Into`) also provides
/// `Into<Decimal> for RustDecimal` through the standard blanket impl, so
/// existing `.into()` calls keep working.
impl From<RustDecimal> for Decimal {
    fn from(value: RustDecimal) -> Decimal {
        Decimal(value)
    }
}

/// Formats a single `Decimal` as its decimal string representation.
#[pyfunction]
fn return_string(a: Decimal) -> PyResult<String> {
    Ok(a.0.to_string())
}

/// Python extension module `rust_binding`, implemented in Rust.
///
/// Registers the `Decimal` class and the `return_string` function.
#[pymodule]
fn rust_binding(_py: Python, m: &PyModule) -> PyResult<()> {
    // The GIL token is not needed here; `_py` silences the unused-variable warning.
    m.add_class::<Decimal>()?;
    m.add_wrapped(wrap_pyfunction!(return_string))?;

    Ok(())
}
