use num_bigint::BigInt;
use num_complex::Complex64;
use num_irrational::Quadratic64;
use num_rational::Rational64;
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::prelude::*;
use pyo3::pyclass::CompareOp;
use pyo3::{ffi, AsPyPointer, IntoPy, PyObject};
use std::cmp::Ordering;
use suan_core::{self as core, Subject};

/// This struct is used for representing [Subject] in Python.
///
/// It is a thin newtype around the core [Subject] and is exposed to Python under the
/// class name `Subject`. It is marked `unsendable`, presumably because `Subject` is not
/// `Send` (TODO confirm).
#[derive(Clone)]
#[pyclass(name = "Subject", unsendable)]
struct PySubject(Subject);

impl From<Subject> for PySubject {
    fn from(s: Subject) -> Self {
        PySubject(s)
    }
}

#[pymethods]
impl PySubject {
    /// Return the real part of the number. The number itself will be returned if
    /// it's not complex
    #[getter]
    fn re(&self) -> Option<Self> {
        match &self.0 {
            Subject::Complex(v) => Some(Subject::Real(v.re).into()),
            v => Some(v.clone().into()),
        }
    }

    /// Return the imaginary part of the number. `None` will be returned if it's
    /// not complex
    #[getter]
    fn im(&self) -> Option<Self> {
        match &self.0 {
            Subject::Complex(v) => Some(Subject::Real(v.im).into()),
            _ => None,
        }
    }

    /// Return the numerator of the rational / quad number. Return itself for integers,
    /// otherwise return `None`
    #[getter]
    fn numer(&self) -> Option<Self> {
        match &self.0 {
            Subject::Int(v) => Some(Subject::Int(*v).into()),
            Subject::BInt(v) => Some(Subject::BInt(v.clone()).into()),
            Subject::Rational(v) => Some(Subject::Int(*v.numer()).into()),
            Subject::BRational(v) => Some(Subject::BInt(v.numer().clone()).into()),
            Subject::Quad(v) => {
                // keep the `a + b*sqrt(r)` part and drop the denominator
                let (a, b, _, r) = v.parts();
                Some(Subject::Quad(Quadratic64::new(*a, *b, 1, *r)).into())
            }
            _ => None,
        }
    }

    /// Return the denominator of the rational / quad number. Otherwise return `None`.
    #[getter]
    fn denom(&self) -> Option<Self> {
        match &self.0 {
            Subject::Rational(v) => Some(Subject::Int(*v.denom()).into()),
            Subject::BRational(v) => Some(Subject::BInt(v.denom().clone()).into()),
            Subject::Quad(v) => Some(Subject::Int(*v.parts().2).into()),
            _ => None,
        }
    }

    /// Return the discriminant for the quadratic number, or `None` for other subjects.
    ///
    /// With `r` the root of the surd, the result is `r` when `r ≡ 1 (mod 4)` and `4r`
    /// otherwise (`r ≡ 2, 3 (mod 4)`).
    #[getter]
    fn discr(&self) -> Option<Self> {
        match &self.0 {
            Subject::Quad(v) => {
                let r = *v.parts().3;
                // Rust's `%` keeps the dividend's sign (`-3 % 4 == -3`), which would misclassify
                // negative roots; `rem_euclid` yields the mathematical residue in `0..4`.
                let discriminant = if r.rem_euclid(4) == 1 {
                    Subject::Int(r)
                } else {
                    // `4r` can overflow `i64` for very large roots; fall back to a big integer.
                    r.checked_mul(4).map(Subject::Int).unwrap_or_else(|| {
                        Subject::BInt(BigInt::from(r) * BigInt::from(4i64))
                    })
                };
                Some(discriminant.into())
            }
            _ => None,
        }
    }

    /// Return the modulus of modular numbers.
    ///
    /// No modular subject type is handled yet, so this is always `None`.
    #[getter(mod)]
    fn modulus(&self) -> Option<Self> {
        None
    }

    fn __repr__(&self) -> String {
        format!("{:?}", self.0)
    }

    fn __str__(&self) -> String {
        self.0.to_string()
    }

    /// Convert Subject to native python types. If no native corresponding type available,
    /// then convert to decomposed raw representations.
    ///
    /// * integers → `int`, reals → `float`, complex numbers → `complex`
    /// * rationals → `fractions.Fraction(numer, denom)`
    /// * quadratic numbers → the raw 4-tuple produced by `Into<(i64, i64, i64, i64)>`
    ///
    /// Raises `TypeError` for subject types without a conversion yet (instead of panicking,
    /// which Python would surface as an uncatchable-by-`Exception` `PanicException`).
    fn unwrap(&self, py: Python) -> PyResult<PyObject> {
        match &self.0 {
            Subject::Int(v) => Ok(v.into_py(py)),
            Subject::BInt(v) => Ok(v.clone().into_py(py)),
            Subject::Real(v) => Ok(v.into_py(py)),
            Subject::Rational(v) => to_py_fraction(py, *v.numer(), *v.denom()),
            Subject::BRational(v) => to_py_fraction(py, v.numer().clone(), v.denom().clone()),
            // Subject::BReal(v) => (exp, mantissa)
            // Subject::IntMod(v) => (value, mod)
            Subject::Quad(v) => Ok(Into::<(i64, i64, i64, i64)>::into(v.clone()).into_py(py)),
            Subject::Complex(v) => Ok(v.into_py(py)),
            // Subject::BComplex(v) => (re exp, re mantissa, im exp, im mantissa)
            _ => Err(PyTypeError::new_err(
                "this subject cannot be unwrapped into a native python type yet",
            )),
        }
    }

    fn __richcmp__(&self, other: SubjectInput, op: CompareOp) -> PyResult<bool> {
        let cmp = self.0.cmp(&other.0);
        Ok(match op {
            CompareOp::Eq => cmp == Ordering::Equal,
            CompareOp::Ne => cmp != Ordering::Equal,
            CompareOp::Ge => cmp != Ordering::Less,
            CompareOp::Gt => cmp == Ordering::Greater,
            CompareOp::Le => cmp != Ordering::Greater,
            CompareOp::Lt => cmp == Ordering::Less,
        })
    }

    fn __int__(&self, py: Python) -> PyResult<PyObject> {
        match &self.0 {
            Subject::Int(v) => Ok(v.into_py(py)),
            Subject::BInt(v) => Ok(v.clone().into_py(py)),
            _ => Err(PyTypeError::new_err("this subject is not an integer")),
        }
    }

    fn __float__(&self, py: Python) -> PyResult<PyObject> {
        match &self.0 {
            Subject::Real(v) => Ok(v.into_py(py)),
            // Subject::BReal(v) => ..,
            _ => Err(PyTypeError::new_err("this subject is not a float")),
        }
    }

    fn __complex__(&self, py: Python) -> PyResult<PyObject> {
        match &self.0 {
            Subject::Complex(v) => Ok(v.into_py(py)),
            // Subject::BComplex(v) => ..,
            _ => Err(PyTypeError::new_err("this subject is not a complex")),
        }
    }

    ////////////////////////////// Operators //////////////////////////////

    fn __add__(&self, rhs: SubjectInput) -> PyResult<PySubject> {
        Ok((self.0.clone() + rhs.0).into())
    }
    fn __radd__(&self, lhs: SubjectInput) -> PyResult<PySubject> {
        Ok((lhs.0 + self.0.clone()).into())
    }
    fn __truediv__(&self, rhs: SubjectInput) -> PyResult<PySubject> {
        Ok((self.0.clone() / rhs.0).into())
    }
    fn __rtruediv__(&self, lhs: SubjectInput) -> PyResult<PySubject> {
        Ok((lhs.0 / self.0.clone()).into())
    }
}

/// Build the Python object `fractions.Fraction(numer, denom)`.
fn to_py_fraction<T: IntoPy<PyObject>>(py: Python, numer: T, denom: T) -> PyResult<PyObject> {
    let fraction = py.import("fractions")?.getattr("Fraction")?;
    Ok(fraction.call1((numer, denom))?.into_py(py))
}

/// This struct is used for automatically parsing input as Subject.
///
/// Using it as a `#[pyfunction]` / `#[pymethods]` parameter type lets Python callers pass
/// plain `int`, `float`, `complex` or `Subject` values (see its `FromPyObject` impl).
struct SubjectInput(pub Subject);

// conversion from python object
impl<'source> FromPyObject<'source> for SubjectInput {
    fn extract(ob: &'source PyAny) -> PyResult<SubjectInput> {
        let py = ob.py();

        unsafe {
            let ptr = ob.as_ptr();
            if ffi::PyLong_Check(ptr) > 0 {
                // input is integer
                let mut overflow = 0;
                let v = ffi::PyLong_AsLongLongAndOverflow(ptr, &mut overflow);
                if v == -1 && PyErr::occurred(py) {
                    Err(PyErr::fetch(py))
                } else if overflow != 0 {
                    // some code below is from https://github.com/PyO3/pyo3/blob/main/src/conversions/num_bigint.rs
                    let n_bits = ffi::_PyLong_NumBits(ptr) as usize;
                    let n_bytes = match n_bits {
                        usize::MAX => {
                            return Err(PyErr::fetch(py));
                        }
                        0 => 0,
                        n => (n as usize) / 8 + 1,
                    };
                    let long_ptr = ptr as *mut ffi::PyLongObject;
                    let num_big = if n_bytes <= 64 {
                        let mut buffer = [0; 64];
                        if ffi::_PyLong_AsByteArray(long_ptr, buffer.as_mut_ptr(), n_bytes, 1, 1)
                            == -1
                        {
                            return Err(PyErr::fetch(py));
                        }
                        BigInt::from_signed_bytes_le(&buffer[..n_bytes])
                    } else {
                        let mut buffer = vec![0; n_bytes];
                        if ffi::_PyLong_AsByteArray(long_ptr, buffer.as_mut_ptr(), n_bytes, 1, 1)
                            == -1
                        {
                            return Err(PyErr::fetch(py));
                        }
                        BigInt::from_signed_bytes_le(&buffer)
                    };
                    Ok(SubjectInput(Subject::BInt(num_big)))
                } else {
                    Ok(SubjectInput(Subject::Int(v)))
                }
            } else if ffi::PyFloat_Check(ptr) > 0 {
                // input is float
                let v = ffi::PyFloat_AsDouble(ptr);
                if v == 1. && PyErr::occurred(py) {
                    Err(PyErr::fetch(py))
                } else {
                    Ok(SubjectInput(Subject::Real(v)))
                }
            } else if ffi::PyComplex_Check(ptr) > 0 {
                // input is complex
                let v = ffi::PyComplex_AsCComplex(ptr);
                if v.real == 1. && PyErr::occurred(py) {
                    Err(PyErr::fetch(py))
                } else {
                    Ok(SubjectInput(Subject::Complex(Complex64::new(
                        v.real, v.imag,
                    ))))
                }
            } else {
                // we could support `fractions` and `decimal` package, but there's no stable C API
                // for these types, so it's better to let users use the `new()` function instead

                if let Ok(v) = ob.extract::<PySubject>() {
                    return Ok(SubjectInput(v.0));
                } else {
                    Err(PyValueError::new_err(
                        "Only integer, float, complex objects can be automatically converted to suan subject"
                    ))
                }
            }
        }
    }
}

/// Define a `#[pyfunction]` named `$name` that forwards its argument to
/// `core::pow_log::$name`, then add it to the module `$m`.
///
/// Must be invoked inside a function returning `PyResult`, since it uses `?`.
macro_rules! register_simple_unary {
    ($m:ident, $name:ident) => {
        #[pyfunction]
        fn $name(s: SubjectInput) -> PyResult<PySubject> {
            core::pow_log::$name(s.0)
                .map(PySubject::from)
                .map_err(PyErr::from)
        }
        let wrapped = wrap_pyfunction!($name, $m)?;
        $m.add_function(wrapped)?;
    };
}

mod pow_log {
    use super::*;

    /// Register the `sqrt`, `sqrt_`, `cbrt` and `cbrt_` functions on `module`.
    ///
    /// Each macro call defines a local `#[pyfunction]` that forwards to the function of the
    /// same name in `core::pow_log` and adds it to `module`.
    pub(super) fn register(_python: Python, module: &PyModule) -> PyResult<()> {
        register_simple_unary!(module, sqrt);
        register_simple_unary!(module, sqrt_);

        register_simple_unary!(module, cbrt);
        register_simple_unary!(module, cbrt_);

        Ok(())
    }
}

mod number {
    use super::*;
    use either::Either;
    use num_prime::{FactorizationConfig, PrimalityTestConfig};
    use pyo3::exceptions::PyKeyError;
    use pyo3::types::{IntoPyDict, PyDict};

    /// Register the number-theory functions (`is_prime`, `factors`) on module `m`.
    pub(super) fn register(_py: Python, m: &PyModule) -> PyResult<()> {
        m.add_function(wrap_pyfunction!(is_prime, m)?)?;
        m.add_function(wrap_pyfunction!(factors, m)?)?;
        Ok(())
    }

    /// Build a [PrimalityTestConfig] from Python keyword arguments.
    ///
    /// Starts from `preset` (`"bpsw"` or `"strict"`; the default config when absent), then
    /// overrides `sprp_trials`, `sprp_random_trials`, `slprp_test` and `eslprp_test` when
    /// they are given. Other keys are ignored.
    ///
    /// Raises `KeyError` for an unknown preset, or the extraction error when a value has the
    /// wrong Python type.
    fn parse_primality_config(args: Option<&PyDict>) -> PyResult<PrimalityTestConfig> {
        let dict = match args {
            None => return Ok(PrimalityTestConfig::default()),
            Some(dict) => dict,
        };

        let mut config = match dict.get_item("preset") {
            None => PrimalityTestConfig::default(),
            Some(v) => match v.extract()? {
                "bpsw" => PrimalityTestConfig::bpsw(),
                "strict" => PrimalityTestConfig::strict(),
                _ => {
                    return Err(PyKeyError::new_err(
                        "unrecognized primality test config preset",
                    ))
                }
            },
        };
        if let Some(v) = dict.get_item("sprp_trials") {
            config.sprp_trials = v.extract()?;
        }
        if let Some(v) = dict.get_item("sprp_random_trials") {
            config.sprp_random_trials = v.extract()?;
        }
        if let Some(v) = dict.get_item("slprp_test") {
            config.slprp_test = v.extract()?;
        }
        if let Some(v) = dict.get_item("eslprp_test") {
            config.eslprp_test = v.extract()?;
        }
        Ok(config)
    }

    /// Test whether `s` is prime.
    ///
    /// Keyword arguments configure the test (see `parse_primality_config`). The result is
    /// `probably()` of the core primality result, so probable primes also yield `True`.
    #[pyfunction(config = "**")]
    fn is_prime(s: SubjectInput, config: Option<&PyDict>) -> PyResult<bool> {
        Ok(suan_core::number::is_prime(s.0, Some(parse_primality_config(config)?))?.probably())
    }

    /// Build a [FactorizationConfig] from Python keyword arguments.
    ///
    /// Starts from `preset` (`"strict"`; the default config when absent), then overrides:
    /// `primality_config` (a dict parsed like `is_prime`'s keywords), `td_limit` (an int, or
    /// `None` to reset it to `None`) and `rho_trials`. Other keys are ignored.
    ///
    /// Raises `KeyError` for an unknown preset, or the extraction error when a value has the
    /// wrong Python type.
    fn parse_factor_config(args: Option<&PyDict>) -> PyResult<FactorizationConfig> {
        let dict = match args {
            None => return Ok(FactorizationConfig::default()),
            Some(dict) => dict,
        };

        let mut config = match dict.get_item("preset") {
            None => FactorizationConfig::default(),
            Some(v) => match v.extract()? {
                "strict" => FactorizationConfig::strict(),
                _ => {
                    return Err(PyKeyError::new_err(
                        "unrecognized factorization config preset",
                    ))
                }
            },
        };
        if let Some(v) = dict.get_item("primality_config") {
            config.primality_config = parse_primality_config(Some(v.extract()?))?;
        }
        if let Some(v) = dict.get_item("td_limit") {
            // `td_limit` is an `Option`: extracting `Option<_>` maps Python `None` to `None`
            // while integers behave exactly as before (`Some(n)`).
            config.td_limit = v.extract()?;
        }
        if let Some(v) = dict.get_item("rho_trials") {
            config.rho_trials = v.extract()?;
        }
        Ok(config)
    }

    /// Factorize `s`.
    ///
    /// Keyword arguments configure the factorization (see `parse_factor_config`). Depending
    /// on the core result this returns either a dict with `Subject` keys and values
    /// (presumably prime → exponent) or a list of `Subject`s (presumably an incomplete
    /// factorization) — TODO confirm against `suan_core::number::factors`.
    #[pyfunction(config = "**")]
    fn factors(py: Python, s: SubjectInput, config: Option<&PyDict>) -> PyResult<PyObject> {
        match suan_core::number::factors(s.0, Some(parse_factor_config(config)?))? {
            Either::Left(dict) => {
                let pairs = dict.into_iter().map(|(k, v)| {
                    (
                        Into::<PySubject>::into(k).into_py(py),
                        Into::<PySubject>::into(v).into_py(py),
                    )
                });
                Ok(IntoPyDict::into_py_dict(pairs, py).into())
            }
            Either::Right(vec) => {
                let list: Vec<PySubject> = vec.into_iter().map(|v| v.into()).collect();
                Ok(list.into_py(py))
            }
        }
    }
}

/// Create a new `Subject` instance from Python types.
#[pyfunction]
fn new(s: SubjectInput) -> PyResult<PySubject> {
    Ok(s.0.into())
    // TODO: also support converting from python fractions.Fraction and decimal.Decimal object
}

// Type codes selecting a number type in `parse()`, `zero()` and `one()`.
// They are plain `u8` values so Python callers can pass them directly.

/// Type code: infinity.
const TVAR_INFINITY: u8 = 0;
/// Type code: integer (machine-sized or big).
const TVAR_INT: u8 = 1;
/// Type code: real (floating point) number.
const TVAR_REAL: u8 = 2;
/// Type code: rational number.
const TVAR_RATIONAL: u8 = 3;
/// Type code: complex number.
const TVAR_COMPLEX: u8 = 4;
/// Type code: quadratic number.
const TVAR_QUAD: u8 = 5;

/// Create a new `Subject` instance from string. If argument `t` is specified,
/// then the string will be treated as the representation for the specific type.
#[pyfunction]
fn parse(s: &str, ty: Option<u8>) -> PyResult<PySubject> {
    if let Some(t) = ty {
        match t {
            _ => Err(PyValueError::new_err("invalid type argument")),
        }
    } else {
        Err(PyValueError::new_err("failed to parse the input string"))
    }
}

/// Create a `0` in given number type
#[pyfunction]
fn zero(ty: Option<u8>) -> PyResult<PySubject> {
    match ty.unwrap_or(TVAR_INT) {
        TVAR_INFINITY => Err(PyValueError::new_err("there is no 'zero' for infinity")),
        TVAR_INT => Ok(Subject::Int(0).into()),
        TVAR_RATIONAL => Ok(Subject::Rational(Rational64::from(0)).into()),
        TVAR_REAL => Ok(Subject::Real(0.).into()),
        TVAR_QUAD => Ok(Subject::Quad(Quadratic64::from(0)).into()),
        _ => Err(PyValueError::new_err("invalid type argument")),
    }
}

/// Create a `1` in given number type
#[pyfunction]
fn one(ty: Option<u8>) -> PyResult<PySubject> {
    match ty.unwrap_or(TVAR_INT) {
        TVAR_INFINITY => Err(PyValueError::new_err("there is no 'one' for infinity")),
        TVAR_INT => Ok(Subject::Int(1).into()),
        TVAR_RATIONAL => Ok(Subject::Rational(Rational64::from(1)).into()),
        TVAR_REAL => Ok(Subject::Real(1.).into()),
        TVAR_QUAD => Ok(Subject::Quad(Quadratic64::from(1)).into()),
        _ => Err(PyValueError::new_err("invalid type argument")),
    }
}

/// Calculator for advanced math in rust
#[pymodule]
fn suan(py: Python, m: &PyModule) -> PyResult<()> {
    m.add_class::<PySubject>()?;

    // Type codes taken by `parse()`, `zero()` and `one()`, exported as module constants
    // (registered in this order).
    let type_codes: [(&str, u8); 6] = [
        ("INFINITY", TVAR_INFINITY),
        ("INT", TVAR_INT),
        ("REAL", TVAR_REAL),
        ("RATIONAL", TVAR_RATIONAL),
        ("COMPLEX", TVAR_COMPLEX),
        ("QUAD", TVAR_QUAD),
    ];
    for &(name, code) in type_codes.iter() {
        m.add(name, code)?;
    }

    // Number creation functions.
    m.add_function(wrap_pyfunction!(new, m)?)?;
    m.add_function(wrap_pyfunction!(parse, m)?)?;
    m.add_function(wrap_pyfunction!(zero, m)?)?;
    m.add_function(wrap_pyfunction!(one, m)?)?;

    // The functions defined in the `pow_log` and `number` submodules are added to this
    // same Python module.
    pow_log::register(py, m)?;
    number::register(py, m)?;

    Ok(())
}
