use crate::cnf::CnfFormula;
use crate::VarIdx;
use indexmap::IndexSet;

mod clause;
pub use clause::Clause;

/// A SAT problem described by a collection of high-level [`Clause`]s.
///
/// The structure keeps track of how many variables are in use and of the
/// clauses constraining them; it can be lowered to a [`CnfFormula`] with
/// `to_cnf`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SparseSat {
    /// Number of variables known to the problem. It grows when a variable is
    /// added explicitly or when an inserted clause refers to a larger index.
    number_of_variables: usize,
    /// Clauses inserted so far, kept in an `IndexSet`.
    clauses: IndexSet<Clause>,
}

impl SparseSat {
    /// Creates an empty problem with no variables and no clauses.
    pub fn new() -> Self {
        Self {
            number_of_variables: 0,
            clauses: IndexSet::new(),
        }
    }

    /// Creates an empty problem whose clause storage is pre-allocated for
    /// `capacity` clauses.
    ///
    /// The capacity only concerns clauses; the number of variables still
    /// starts at zero.
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            number_of_variables: 0,
            clauses: IndexSet::with_capacity(capacity),
        }
    }

    /// Registers a new variable and returns its index.
    ///
    /// The returned index is the number of variables before the call.
    pub fn add_variable(&mut self) -> VarIdx {
        let idx = self.number_of_variables;
        self.number_of_variables += 1;
        idx
    }

    /// Inserts `clause` into the problem.
    ///
    /// If the clause refers to a variable index that is not yet known, the
    /// number of variables is raised to `clause.max_index() + 1`.
    ///
    /// # Panics
    ///
    /// Panics if `clause.has_duplicate_indices()` reports that the clause uses
    /// the same variable more than once. The check runs before any state is
    /// modified, so a rejected clause leaves the problem untouched.
    pub fn insert(&mut self, clause: Clause) {
        // Validate first: a rejected clause must not change the variable count.
        if clause.has_duplicate_indices() {
            panic!("variables must all be differents");
        }
        let required_variables = clause.max_index() + 1;
        if required_variables > self.number_of_variables {
            self.number_of_variables = required_variables;
        }
        self.clauses.insert(clause);
    }

    /// Inserts every clause yielded by `clauses`, in iteration order.
    ///
    /// # Panics
    ///
    /// Panics as soon as one clause is rejected by [`SparseSat::insert`];
    /// clauses yielded before it remain inserted.
    pub fn insert_many(&mut self, clauses: impl IntoIterator<Item = Clause>) {
        for clause in clauses {
            self.insert(clause)
        }
    }

    /// Returns the number of variables currently known to the problem.
    pub fn number_of_variables(&self) -> usize {
        self.number_of_variables
    }

    /// Returns the number of clauses currently stored.
    pub fn number_of_clauses(&self) -> usize {
        self.clauses.len()
    }

    /// Returns an iterator over the stored clauses.
    pub fn clauses(&self) -> impl Iterator<Item = &Clause> {
        self.clauses.iter()
    }

    /// Returns `true` if a clause equal to `clause` is stored in the problem.
    pub fn has_clause(&self, clause: &Clause) -> bool {
        self.clauses.contains(clause)
    }

    /// Lowers the problem to a [`CnfFormula`].
    ///
    /// The formula starts with the same number of variables as the problem and
    /// receives, for each stored clause in turn, everything produced by that
    /// clause's own `to_cnf` conversion.
    pub fn to_cnf(&self) -> CnfFormula {
        self.clauses().flat_map(|clause| clause.to_cnf()).fold(
            CnfFormula::with_variables(self.number_of_variables),
            CnfFormula::insert,
        )
    }
}

impl Default for SparseSat {
    /// Returns an empty problem, exactly like [`SparseSat::new`].
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use crate::cnf::CnfClause;
    use crate::cnf::CnfVariable::{Neg, Pos};
    use Clause::{And, Equal, EvenXor, OddXor, Or};

    /// Builds the expected formula: starts from a formula with `variables`
    /// variables and inserts `clauses` one after the other, in order.
    fn expected(variables: usize, clauses: Vec<CnfClause>) -> CnfFormula {
        clauses.into_iter().fold(
            CnfFormula::with_variables(variables),
            |formula, clause| formula.insert(clause),
        )
    }

    #[test]
    fn single_even_xor() {
        let mut sat = SparseSat::new();
        sat.insert(EvenXor(vec![0, 1, 3]));

        let expected_cnf = expected(
            4,
            vec![CnfClause::new_xor(vec![Neg(0), Pos(1), Pos(3)])],
        );
        assert_eq!(sat.to_cnf(), expected_cnf);
    }

    #[test]
    fn single_odd_xor() {
        let mut sat = SparseSat::new();
        sat.insert(OddXor(vec![0, 1, 2]));

        let expected_cnf = expected(
            3,
            vec![CnfClause::new_xor(vec![Pos(0), Pos(1), Pos(2)])],
        );
        assert_eq!(sat.to_cnf(), expected_cnf);
    }

    #[test]
    fn single_and() {
        let mut sat = SparseSat::new();
        sat.insert(And(0, 1, 2));

        let expected_cnf = expected(
            3,
            vec![
                CnfClause::new(vec![Pos(0), Pos(1), Neg(2)]),
                CnfClause::new(vec![Neg(0), Neg(1), Pos(2)]),
                CnfClause::new(vec![Pos(0), Neg(1), Neg(2)]),
                CnfClause::new(vec![Neg(0), Pos(1), Neg(2)]),
            ],
        );
        assert_eq!(sat.to_cnf(), expected_cnf);
    }

    #[test]
    fn single_or() {
        let mut sat = SparseSat::new();
        sat.insert(Or(0, 1, 2));

        let expected_cnf = expected(
            3,
            vec![
                CnfClause::new(vec![Neg(0), Pos(1), Pos(2)]),
                CnfClause::new(vec![Pos(0), Neg(1), Pos(2)]),
                CnfClause::new(vec![Pos(0), Pos(1), Neg(2)]),
                CnfClause::new(vec![Neg(0), Neg(1), Pos(2)]),
            ],
        );
        assert_eq!(sat.to_cnf(), expected_cnf);
    }

    #[test]
    fn single_equal() {
        let mut sat = SparseSat::new();
        sat.insert(Equal(0, 1));

        let expected_cnf = expected(
            2,
            vec![
                CnfClause::new(vec![Neg(0), Pos(1)]),
                CnfClause::new(vec![Pos(0), Neg(1)]),
            ],
        );
        assert_eq!(sat.to_cnf(), expected_cnf);
    }

    #[test]
    fn one_of_each() {
        let mut sat = SparseSat::new();
        sat.insert_many([
            EvenXor(vec![0, 1, 5]),
            OddXor(vec![1, 2, 4]),
            And(2, 3, 6),
            Or(4, 6, 7),
            Equal(5, 7),
        ]);

        let expected_cnf = expected(
            8,
            vec![
                // Even xor
                CnfClause::new_xor(vec![Neg(0), Pos(1), Pos(5)]),
                // Odd xor
                CnfClause::new_xor(vec![Pos(1), Pos(2), Pos(4)]),
                // And
                CnfClause::new(vec![Pos(2), Pos(3), Neg(6)]),
                CnfClause::new(vec![Neg(2), Pos(3), Neg(6)]),
                CnfClause::new(vec![Pos(2), Neg(3), Neg(6)]),
                CnfClause::new(vec![Neg(2), Neg(3), Pos(6)]),
                // Or
                CnfClause::new(vec![Neg(4), Pos(6), Pos(7)]),
                CnfClause::new(vec![Pos(4), Neg(6), Pos(7)]),
                CnfClause::new(vec![Pos(4), Pos(6), Neg(7)]),
                CnfClause::new(vec![Neg(4), Neg(6), Pos(7)]),
                // Equal
                CnfClause::new(vec![Neg(5), Pos(7)]),
                CnfClause::new(vec![Pos(5), Neg(7)]),
            ],
        );
        assert_eq!(sat.to_cnf(), expected_cnf);
    }
}
