use crate::cnf::CnfClause;
use crate::cnf::CnfVariable;
use crate::VarIdx;
use itertools::Itertools;

use Clause::{And, Equal, EvenXor, FixNeg, FixPos, Not, NotAllNeg, NotAllPos, OddXor, Or};
use CnfVariable::{Neg, Pos};

/// A boolean constraint over variables identified by [`VarIdx`].
///
/// Each variant is lowered to plain CNF (and, for the XOR variants, XOR clauses)
/// by [`Clause::to_cnf`].
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
pub enum Clause {
    /// `And(a, b, c)`: satisfied when `c == (a AND b)`.
    And(VarIdx, VarIdx, VarIdx),
    /// `Or(a, b, c)`: satisfied when `c == (a OR b)`.
    Or(VarIdx, VarIdx, VarIdx),
    /// `Equal(a, b)`: satisfied when `a == b`.
    Equal(VarIdx, VarIdx),
    /// `Not(a, b)`: satisfied when `a != b`.
    Not(VarIdx, VarIdx),
    /// `FixNeg(a)`: satisfied when `a` is 0 (false).
    FixNeg(VarIdx),
    /// `FixPos(a)`: satisfied when `a` is 1 (true).
    FixPos(VarIdx),
    /// Satisfied when the XOR of all listed variables is 0.
    EvenXor(Vec<VarIdx>),
    /// Satisfied when the XOR of all listed variables is 1.
    OddXor(Vec<VarIdx>),
    /// Satisfied when at least one listed variable is 1 (positive).
    NotAllNeg(Vec<VarIdx>),
    /// Satisfied when at least one listed variable is 0 (negative).
    NotAllPos(Vec<VarIdx>),
}

impl Clause {
    /// Returns `true` if any variable index occurs more than once in this clause.
    ///
    /// Fixed-arity variants compare their operands pairwise. List-based variants
    /// compare the number of distinct indices with the list length.
    pub fn has_duplicate_indices(&self) -> bool {
        match self {
            And(a, b, c) | Or(a, b, c) => a == b || b == c || c == a,
            Equal(a, b) | Not(a, b) => a == b,
            FixNeg(_) | FixPos(_) => false,
            EvenXor(vars) | OddXor(vars) | NotAllNeg(vars) | NotAllPos(vars) => {
                vars.iter().unique().count() < vars.len()
            }
        }
    }

    /// Returns the largest variable index referenced by this clause.
    ///
    /// For the list-based variants with an empty list this returns `0`.
    /// NOTE(review): if `0` is also a valid variable index, callers cannot tell an
    /// empty list from one that only mentions variable `0` — confirm this is intended.
    pub fn max_index(&self) -> VarIdx {
        use std::cmp::max;
        match self {
            And(a, b, c) | Or(a, b, c) => *max(a, max(b, c)),
            Equal(a, b) | Not(a, b) => *max(a, b),
            FixNeg(a) | FixPos(a) => *a,
            EvenXor(vars) | OddXor(vars) | NotAllNeg(vars) | NotAllPos(vars) => {
                vars.iter().max().copied().unwrap_or(0)
            }
        }
    }

    /// Lowers this constraint to a list of clauses that must all hold.
    ///
    /// Plain disjunctions are built with [`CnfClause::new`]. XOR constraints are
    /// emitted as one clause via [`CnfClause::new_xor`]. The `OddXor` arm relies on
    /// that clause being satisfied exactly when the XOR of its literals is 1.
    pub fn to_cnf(&self) -> Vec<CnfClause> {
        match self {
            // c <-> (a AND b). Each clause rules out exactly one bad row of the
            // truth table. The three clauses containing Neg(c) forbid c = 1 unless
            // a = b = 1. The clause containing Pos(c) forbids a = b = 1 with c = 0.
            And(a, b, c) => {
                let (a, b, c) = (*a, *b, *c);
                vec![
                    CnfClause::new(vec![Pos(a), Pos(b), Neg(c)]),
                    CnfClause::new(vec![Neg(a), Neg(b), Pos(c)]),
                    CnfClause::new(vec![Pos(a), Neg(b), Neg(c)]),
                    CnfClause::new(vec![Neg(a), Pos(b), Neg(c)]),
                ]
            }
            // c <-> (a OR b). Again one clause per bad row. The three clauses
            // containing Pos(c) forbid c = 0 when a or b is 1. The clause
            // containing Neg(c) forbids c = 1 when a = b = 0.
            Or(a, b, c) => {
                let (a, b, c) = (*a, *b, *c);
                vec![
                    CnfClause::new(vec![Neg(a), Pos(b), Pos(c)]),
                    CnfClause::new(vec![Pos(a), Neg(b), Pos(c)]),
                    CnfClause::new(vec![Pos(a), Pos(b), Neg(c)]),
                    CnfClause::new(vec![Neg(a), Neg(b), Pos(c)]),
                ]
            }
            // a == b: (b -> a) and (a -> b).
            Equal(a, b) => {
                let (a, b) = (*a, *b);
                vec![
                    CnfClause::new(vec![Pos(a), Neg(b)]),
                    CnfClause::new(vec![Neg(a), Pos(b)]),
                ]
            }
            // a != b: at least one of them is 1 and at least one of them is 0.
            Not(a, b) => {
                let (a, b) = (*a, *b);
                vec![
                    CnfClause::new(vec![Pos(a), Pos(b)]),
                    CnfClause::new(vec![Neg(a), Neg(b)]),
                ]
            }
            FixNeg(a) => vec![CnfClause::new(vec![Neg(*a)])],
            FixPos(a) => vec![CnfClause::new(vec![Pos(*a)])],
            // A single disjunction of the positive literals.
            NotAllNeg(vars) => vec![CnfClause::new(vars.iter().copied().map(Pos).collect())],
            // A single disjunction of the negative literals.
            NotAllPos(vars) => vec![CnfClause::new(vars.iter().copied().map(Neg).collect())],
            // XOR over zero variables is 0, so an empty EvenXor always holds and
            // needs no clauses. Lowering it like the non-empty case would emit
            // `new_xor(vec![])`, i.e. "XOR of nothing = 1", which is unsatisfiable.
            EvenXor(vars) if vars.is_empty() => Vec::new(),
            // x1 ^ x2 ^ ... = 0  <=>  !x1 ^ x2 ^ ... = 1, so flip the first literal
            // and emit an odd-parity XOR clause.
            EvenXor(vars) => {
                let mut lits = vars.iter().copied().map(Pos).collect_vec();
                if let Some(first) = lits.first_mut() {
                    *first = first.flip();
                }
                vec![CnfClause::new_xor(lits)]
            }
            // x1 ^ x2 ^ ... = 1 maps directly onto an odd-parity XOR clause.
            OddXor(vars) => {
                let lits = vars.iter().copied().map(Pos).collect_vec();
                vec![CnfClause::new_xor(lits)]
            }
        }
    }
}
