use super::CnfVariable;
use super::ParsingError;
use crate::VarIdx;
use itertools::Itertools;
use serde::{Deserialize, Serialize};

/// A single clause of a CNF formula, optionally marked as an XOR clause.
///
/// Clauses are rendered with [`CnfClause::to_dimacs`] and parsed with
/// [`CnfClause::from_str`] using DIMACS line syntax (XOR clauses carry a leading `x`).
// NOTE: field names and declaration order are part of the derived serde format and of the
// derived `Hash`, so do not reorder or rename them casually.
#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)]
pub struct CnfClause {
    /// The literals of the clause, in the order they were supplied or parsed.
    variables: Vec<CnfVariable>,
    /// `true` for an XOR clause (rendered with an `x` prefix), `false` for a plain clause.
    xor: bool,
}

impl CnfClause {
    /// Creates a plain (non-XOR) clause from the given literals.
    pub fn new(variables: Vec<CnfVariable>) -> Self {
        Self {
            variables,
            xor: false,
        }
    }

    /// Creates an XOR clause from the given literals.
    pub fn new_xor(variables: Vec<CnfVariable>) -> Self {
        Self {
            variables,
            xor: true,
        }
    }

    /// Parses a single DIMACS clause line such as `"1 -5 0"` or `"x1 -5 0"`.
    ///
    /// A leading `x` marks an XOR clause. The line must end with a terminating `0` that is
    /// its own whitespace-separated token; every token before it is parsed with
    /// [`CnfVariable::from_str`]. A line consisting only of `0` yields an empty clause.
    ///
    /// # Errors
    ///
    /// Returns [`ParsingError`] if the final token is not `0`, or if any literal before it
    /// fails to parse.
    pub fn from_str(clause: &str) -> Result<Self, ParsingError> {
        let clause = clause.trim_start();
        // A leading `x` marks an XOR clause.
        let (body, xor) = match clause.strip_prefix('x') {
            Some(rest) => (rest, true),
            None => (clause, false),
        };
        let mut tokens = body.split_ascii_whitespace();
        // The terminating `0` must be a token of its own. Checking the raw string with
        // `ends_with("0")` would silently accept `"1 20"` as the clause `1 2`, or `"10"`
        // as the clause `1`, even though the terminator is actually missing.
        if tokens.next_back() != Some("0") {
            return Err(ParsingError);
        }
        let variables = tokens
            .map(CnfVariable::from_str)
            .collect::<Result<Vec<_>, _>>()?;
        Ok(if xor {
            Self::new_xor(variables)
        } else {
            Self::new(variables)
        })
    }

    /// Returns the largest variable index occurring in this clause, or `None` if the clause
    /// is empty.
    pub fn max_index(&self) -> Option<VarIdx> {
        // Compare the indices themselves rather than the literals, so the result does not
        // depend on how `CnfVariable`'s `Ord` ranks positive against negative literals.
        self.variables.iter().map(|variable| variable.index()).max()
    }

    /// Renders the clause as a DIMACS line: the literals separated by single spaces,
    /// followed by the terminating `0`, with an `x` prefix for XOR clauses.
    pub fn to_dimacs(&self) -> String {
        let prefix = if self.xor { "x" } else { "" };
        let literals = self
            .variables
            .iter()
            .map(|variable| variable.to_dimacs())
            .chain(std::iter::once(String::from("0")))
            .join(" ");
        format!("{}{}", prefix, literals)
    }

    /// Iterates over the indices of the positive literals, in clause order.
    pub fn pos_variables(&self) -> impl Iterator<Item = VarIdx> + '_ {
        self.variables
            .iter()
            .filter(|var| var.is_pos())
            .map(|var| var.index())
    }

    /// Iterates over the indices of the negative literals, in clause order.
    pub fn neg_variables(&self) -> impl Iterator<Item = VarIdx> + '_ {
        self.variables
            .iter()
            .filter(|var| var.is_neg())
            .map(|var| var.index())
    }
}

impl From<CnfClause> for Vec<CnfVariable> {
    fn from(clause: CnfClause) -> Self {
        clause.variables
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use CnfVariable::{Neg, Pos};

    #[test]
    fn cnf() {
        let dimacs = CnfClause::new(vec![Pos(0), Neg(10), Pos(15), Neg(321)]).to_dimacs();
        assert_eq!(dimacs, String::from("1 -11 16 -322 0"));
    }

    #[test]
    fn xor() {
        let dimacs = CnfClause::new_xor(vec![Pos(0), Neg(10), Pos(15), Neg(321)]).to_dimacs();
        assert_eq!(dimacs, String::from("x1 -11 16 -322 0"));
    }

    #[test]
    fn cnf_from_str() {
        let clause = "1 -5 -30 40 0";
        let expected = CnfClause::new(vec![Pos(0), Neg(4), Neg(29), Pos(39)]);
        assert_eq!(CnfClause::from_str(clause).unwrap(), expected);
    }

    #[test]
    fn xor_from_str() {
        let clause = "x1 -5 -30 40 0";
        let expected = CnfClause::new_xor(vec![Pos(0), Neg(4), Neg(29), Pos(39)]);
        assert_eq!(CnfClause::from_str(clause).unwrap(), expected);
    }

    #[test]
    fn invalid_from_str() {
        let clause = "a1 -5 -30 40 0";
        assert!(CnfClause::from_str(clause).is_err());
    }

    // A bare terminator is an empty clause.
    #[test]
    fn empty_clause_from_str() {
        assert_eq!(
            CnfClause::from_str("0").unwrap(),
            CnfClause::new(Vec::new())
        );
    }

    // A line without the trailing `0` terminator must be rejected.
    #[test]
    fn missing_terminator_is_rejected() {
        assert!(CnfClause::from_str("1 -5").is_err());
    }

    // Parsing the rendered form must give back the original clause.
    #[test]
    fn dimacs_round_trip() {
        let clauses = vec![
            CnfClause::new(vec![Pos(0), Neg(10), Pos(15)]),
            CnfClause::new_xor(vec![Neg(2), Pos(7)]),
            CnfClause::new(Vec::new()),
        ];
        for clause in clauses {
            let parsed = CnfClause::from_str(&clause.to_dimacs()).unwrap();
            assert_eq!(parsed, clause);
        }
    }

    #[test]
    fn max_index() {
        assert_eq!(CnfClause::new(Vec::new()).max_index(), None);
        assert_eq!(
            CnfClause::new(vec![Pos(3), Pos(9), Pos(1)]).max_index(),
            Some(9)
        );
    }

    #[test]
    fn pos_and_neg_variables() {
        let clause = CnfClause::new(vec![Pos(0), Neg(10), Pos(15), Neg(321)]);
        assert_eq!(clause.pos_variables().collect::<Vec<_>>(), vec![0, 15]);
        assert_eq!(clause.neg_variables().collect::<Vec<_>>(), vec![10, 321]);
    }

    #[test]
    fn into_variables() {
        let variables: Vec<CnfVariable> = CnfClause::new_xor(vec![Pos(1), Neg(2)]).into();
        assert_eq!(variables, vec![Pos(1), Neg(2)]);
    }
}
