/*!
A trait for objects which may be used as a language for congruence closure
*/
use crate::UnionFind;
use core::hash::{BuildHasher, Hash, Hasher};
use std::{rc::Rc, sync::Arc};

/// A term language for congruence closure.
///
/// A value of a type implementing `Language<C, I>` is a single term whose children are indices
/// of type `I`. Those indices are interpreted relative to a context of type `C` (for example a
/// union-find structure over `I`), so visiting, hashing and equality are all defined *modulo*
/// that context: the term enumerates its child indices and a caller-supplied callback decides
/// what to do with each one.
///
/// Every operation comes in two flavours: the plain one borrows the context immutably, while
/// the `_mut` one borrows it mutably so that the implementation and the callback may update the
/// context while consulting it (e.g. to optimize later accesses).
pub trait Language<C, I> {
    /// Calls `visitor` on each child index of this term, together with the context.
    ///
    /// # Errors
    ///
    /// Implementations should stop visiting and return the first error produced by `visitor`.
    fn visit_deps<E>(
        &self,
        visitor: &mut impl FnMut(&C, I) -> Result<(), E>,
        ctx: &C,
    ) -> Result<(), E>;

    /// Like [`Self::visit_deps`], but with mutable access to the context, which the
    /// implementation and `visitor` may update (e.g. to optimize it) during the traversal.
    ///
    /// # Errors
    ///
    /// Implementations should stop visiting and return the first error produced by `visitor`.
    fn visit_deps_mut<E>(
        &self,
        visitor: &mut impl FnMut(&mut C, I) -> Result<(), E>,
        ctx: &mut C,
    ) -> Result<(), E>;

    /// Feeds this term into `hasher`, modulo the context.
    ///
    /// `map` is called with the hasher, the context and a child index, and is responsible for
    /// writing that index into the hasher (for example by hashing its representative in `ctx`),
    /// so that terms whose children are equivalent under the context can hash equally.
    fn hash_mod<H: Hasher>(&self, hasher: &mut H, map: &mut impl FnMut(&mut H, &C, I), ctx: &C);

    /// Like [`Self::hash_mod`], but with mutable access to the context, which the
    /// implementation and `map` may update while hashing.
    fn hash_mod_mut<H: Hasher>(
        &self,
        hasher: &mut H,
        map: &mut impl FnMut(&mut H, &mut C, I),
        ctx: &mut C,
    );

    /// Whether this term equals `other` modulo the context.
    ///
    /// `map(ctx, mine, theirs)` compares a pair of corresponding child indices and should return
    /// `true` when they are considered equal under the context.
    fn eq_mod(&self, other: &Self, map: &mut impl FnMut(&C, I, I) -> bool, ctx: &C) -> bool;

    /// Like [`Self::eq_mod`], but with mutable access to the context, which the implementation
    /// and `map` may update while comparing.
    fn eq_mod_mut(
        &self,
        other: &Self,
        map: &mut impl FnMut(&mut C, I, I) -> bool,
        ctx: &mut C,
    ) -> bool;

    /// Hashes this term modulo the context with a fresh hasher produced by `hasher` (a
    /// [`BuildHasher`]) and returns the finished digest.
    ///
    /// See [`Self::hash_mod`] for the role of `map`.
    fn hash_one_mod<B: BuildHasher>(
        &self,
        hasher: &B,
        map: &mut impl FnMut(&mut B::Hasher, &C, I),
        ctx: &C,
    ) -> u64 {
        let mut state = hasher.build_hasher();
        self.hash_mod(&mut state, map, ctx);
        state.finish()
    }

    /// Like [`Self::hash_one_mod`], but with mutable access to the context.
    ///
    /// See [`Self::hash_mod_mut`] for the role of `map`.
    fn hash_one_mod_mut<B: BuildHasher>(
        &self,
        hasher: &B,
        map: &mut impl FnMut(&mut B::Hasher, &mut C, I),
        ctx: &mut C,
    ) -> u64 {
        let mut state = hasher.build_hasher();
        self.hash_mod_mut(&mut state, map, ctx);
        state.finish()
    }
}

impl<L, I, C> Language<C, I> for &'_ mut L
where
    L: Language<C, I>,
{
    #[inline(always)]
    fn visit_deps<E>(
        &self,
        visitor: &mut impl FnMut(&C, I) -> Result<(), E>,
        ctx: &C,
    ) -> Result<(), E> {
        (**self).visit_deps(visitor, ctx)
    }

    #[inline(always)]
    fn visit_deps_mut<E>(
        &self,
        visitor: &mut impl FnMut(&mut C, I) -> Result<(), E>,
        ctx: &mut C,
    ) -> Result<(), E> {
        (**self).visit_deps_mut(visitor, ctx)
    }

    #[inline(always)]
    fn hash_mod<H: Hasher>(&self, hasher: &mut H, map: &mut impl FnMut(&mut H, &C, I), ctx: &C) {
        (**self).hash_mod(hasher, map, ctx)
    }

    #[inline(always)]
    fn hash_mod_mut<H: Hasher>(
        &self,
        hasher: &mut H,
        map: &mut impl FnMut(&mut H, &mut C, I),
        ctx: &mut C,
    ) {
        (**self).hash_mod_mut(hasher, map, ctx)
    }

    #[inline(always)]
    fn eq_mod(&self, other: &Self, map: &mut impl FnMut(&C, I, I) -> bool, ctx: &C) -> bool {
        (**self).eq_mod(other, map, ctx)
    }

    #[inline(always)]
    fn eq_mod_mut(
        &self,
        other: &Self,
        map: &mut impl FnMut(&mut C, I, I) -> bool,
        ctx: &mut C,
    ) -> bool {
        (**self).eq_mod_mut(other, map, ctx)
    }
}

impl<L, I, C> Language<C, I> for &'_ L
where
    L: Language<C, I>,
{
    #[inline(always)]
    fn visit_deps<E>(
        &self,
        visitor: &mut impl FnMut(&C, I) -> Result<(), E>,
        ctx: &C,
    ) -> Result<(), E> {
        (**self).visit_deps(visitor, ctx)
    }

    #[inline(always)]
    fn visit_deps_mut<E>(
        &self,
        visitor: &mut impl FnMut(&mut C, I) -> Result<(), E>,
        ctx: &mut C,
    ) -> Result<(), E> {
        (**self).visit_deps_mut(visitor, ctx)
    }

    #[inline(always)]
    fn hash_mod<H: Hasher>(&self, hasher: &mut H, map: &mut impl FnMut(&mut H, &C, I), ctx: &C) {
        (**self).hash_mod(hasher, map, ctx)
    }

    #[inline(always)]
    fn hash_mod_mut<H: Hasher>(
        &self,
        hasher: &mut H,
        map: &mut impl FnMut(&mut H, &mut C, I),
        ctx: &mut C,
    ) {
        (**self).hash_mod_mut(hasher, map, ctx)
    }

    #[inline(always)]
    fn eq_mod(&self, other: &Self, map: &mut impl FnMut(&C, I, I) -> bool, ctx: &C) -> bool {
        (**self).eq_mod(other, map, ctx)
    }

    #[inline(always)]
    fn eq_mod_mut(
        &self,
        other: &Self,
        map: &mut impl FnMut(&mut C, I, I) -> bool,
        ctx: &mut C,
    ) -> bool {
        (**self).eq_mod_mut(other, map, ctx)
    }
}

impl<L, I, C> Language<C, I> for Box<L>
where
    L: Language<C, I>,
{
    #[inline(always)]
    fn visit_deps<E>(
        &self,
        visitor: &mut impl FnMut(&C, I) -> Result<(), E>,
        ctx: &C,
    ) -> Result<(), E> {
        (**self).visit_deps(visitor, ctx)
    }

    #[inline(always)]
    fn visit_deps_mut<E>(
        &self,
        visitor: &mut impl FnMut(&mut C, I) -> Result<(), E>,
        ctx: &mut C,
    ) -> Result<(), E> {
        (**self).visit_deps_mut(visitor, ctx)
    }

    #[inline(always)]
    fn hash_mod<H: Hasher>(&self, hasher: &mut H, map: &mut impl FnMut(&mut H, &C, I), ctx: &C) {
        (**self).hash_mod(hasher, map, ctx)
    }

    #[inline(always)]
    fn hash_mod_mut<H: Hasher>(
        &self,
        hasher: &mut H,
        map: &mut impl FnMut(&mut H, &mut C, I),
        ctx: &mut C,
    ) {
        (**self).hash_mod_mut(hasher, map, ctx)
    }

    #[inline(always)]
    fn eq_mod(&self, other: &Self, map: &mut impl FnMut(&C, I, I) -> bool, ctx: &C) -> bool {
        (**self).eq_mod(other, map, ctx)
    }

    #[inline(always)]
    fn eq_mod_mut(
        &self,
        other: &Self,
        map: &mut impl FnMut(&mut C, I, I) -> bool,
        ctx: &mut C,
    ) -> bool {
        (**self).eq_mod_mut(other, map, ctx)
    }
}

impl<L, I, C> Language<C, I> for Arc<L>
where
    L: Language<C, I>,
{
    #[inline(always)]
    fn visit_deps<E>(
        &self,
        visitor: &mut impl FnMut(&C, I) -> Result<(), E>,
        ctx: &C,
    ) -> Result<(), E> {
        (**self).visit_deps(visitor, ctx)
    }

    #[inline(always)]
    fn visit_deps_mut<E>(
        &self,
        visitor: &mut impl FnMut(&mut C, I) -> Result<(), E>,
        ctx: &mut C,
    ) -> Result<(), E> {
        (**self).visit_deps_mut(visitor, ctx)
    }

    #[inline(always)]
    fn hash_mod<H: Hasher>(&self, hasher: &mut H, map: &mut impl FnMut(&mut H, &C, I), ctx: &C) {
        (**self).hash_mod(hasher, map, ctx)
    }

    #[inline(always)]
    fn hash_mod_mut<H: Hasher>(
        &self,
        hasher: &mut H,
        map: &mut impl FnMut(&mut H, &mut C, I),
        ctx: &mut C,
    ) {
        (**self).hash_mod_mut(hasher, map, ctx)
    }

    #[inline(always)]
    fn eq_mod(&self, other: &Self, map: &mut impl FnMut(&C, I, I) -> bool, ctx: &C) -> bool {
        (**self).eq_mod(other, map, ctx)
    }

    #[inline(always)]
    fn eq_mod_mut(
        &self,
        other: &Self,
        map: &mut impl FnMut(&mut C, I, I) -> bool,
        ctx: &mut C,
    ) -> bool {
        (**self).eq_mod_mut(other, map, ctx)
    }
}

impl<L, I, C> Language<C, I> for Rc<L>
where
    L: Language<C, I>,
{
    #[inline(always)]
    fn visit_deps<E>(
        &self,
        visitor: &mut impl FnMut(&C, I) -> Result<(), E>,
        ctx: &C,
    ) -> Result<(), E> {
        (**self).visit_deps(visitor, ctx)
    }

    #[inline(always)]
    fn visit_deps_mut<E>(
        &self,
        visitor: &mut impl FnMut(&mut C, I) -> Result<(), E>,
        ctx: &mut C,
    ) -> Result<(), E> {
        (**self).visit_deps_mut(visitor, ctx)
    }

    #[inline(always)]
    fn hash_mod<H: Hasher>(&self, hasher: &mut H, map: &mut impl FnMut(&mut H, &C, I), ctx: &C) {
        (**self).hash_mod(hasher, map, ctx)
    }

    #[inline(always)]
    fn hash_mod_mut<H: Hasher>(
        &self,
        hasher: &mut H,
        map: &mut impl FnMut(&mut H, &mut C, I),
        ctx: &mut C,
    ) {
        (**self).hash_mod_mut(hasher, map, ctx)
    }

    #[inline(always)]
    fn eq_mod(&self, other: &Self, map: &mut impl FnMut(&C, I, I) -> bool, ctx: &C) -> bool {
        (**self).eq_mod(other, map, ctx)
    }

    #[inline(always)]
    fn eq_mod_mut(
        &self,
        other: &Self,
        map: &mut impl FnMut(&mut C, I, I) -> bool,
        ctx: &mut C,
    ) -> bool {
        (**self).eq_mod_mut(other, map, ctx)
    }
}

/// A pair of terms is itself a term: its children are those of the first component followed
/// by those of the second.
///
/// Every method handles the first component before the second. The `visit_deps*` methods stop
/// after the first component if it reports an error, and the `eq_mod*` methods skip the second
/// comparison once the first one yields `false`.
impl<A, B, C, I> Language<C, I> for (A, B)
where
    A: Language<C, I>,
    B: Language<C, I>,
{
    #[inline(always)]
    fn visit_deps<E>(
        &self,
        visitor: &mut impl FnMut(&C, I) -> Result<(), E>,
        ctx: &C,
    ) -> Result<(), E> {
        let (head, tail) = self;
        head.visit_deps(visitor, ctx)?;
        tail.visit_deps(visitor, ctx)
    }

    #[inline]
    fn visit_deps_mut<E>(
        &self,
        visitor: &mut impl FnMut(&mut C, I) -> Result<(), E>,
        ctx: &mut C,
    ) -> Result<(), E> {
        let (head, tail) = self;
        head.visit_deps_mut(visitor, ctx)?;
        tail.visit_deps_mut(visitor, ctx)
    }

    #[inline(always)]
    fn hash_mod<H: Hasher>(&self, hasher: &mut H, map: &mut impl FnMut(&mut H, &C, I), ctx: &C) {
        let (head, tail) = self;
        head.hash_mod(hasher, map, ctx);
        tail.hash_mod(hasher, map, ctx);
    }

    #[inline(always)]
    fn hash_mod_mut<H: Hasher>(
        &self,
        hasher: &mut H,
        map: &mut impl FnMut(&mut H, &mut C, I),
        ctx: &mut C,
    ) {
        let (head, tail) = self;
        head.hash_mod_mut(hasher, map, ctx);
        tail.hash_mod_mut(hasher, map, ctx);
    }

    #[inline(always)]
    fn eq_mod(&self, other: &Self, map: &mut impl FnMut(&C, I, I) -> bool, ctx: &C) -> bool {
        let ((lhs_head, lhs_tail), (rhs_head, rhs_tail)) = (self, other);
        lhs_head.eq_mod(rhs_head, map, ctx) && lhs_tail.eq_mod(rhs_tail, map, ctx)
    }

    #[inline(always)]
    fn eq_mod_mut(
        &self,
        other: &Self,
        map: &mut impl FnMut(&mut C, I, I) -> bool,
        ctx: &mut C,
    ) -> bool {
        let ((lhs_head, lhs_tail), (rhs_head, rhs_tail)) = (self, other);
        lhs_head.eq_mod_mut(rhs_head, map, ctx) && lhs_tail.eq_mod_mut(rhs_tail, map, ctx)
    }
}

/// A bare `usize` is the simplest term: its single child is the index itself, interpreted in a
/// union-find context `U`.
///
/// `visit_deps`/`visit_deps_mut` resolve the index to its representative (via `find` /
/// `find_mut`) before visiting it, whereas `hash_mod*` and `eq_mod*` hand the raw index to `map`
/// and leave any canonicalization to the callback.
impl<U: UnionFind<usize>> Language<U, usize> for usize {
    /// Visits the representative of this index, looked up with `find_mut`. Because that lookup
    /// takes `ctx` mutably, it may update the union-find.
    #[inline]
    fn visit_deps_mut<E>(
        &self,
        visitor: &mut impl FnMut(&mut U, usize) -> Result<(), E>,
        ctx: &mut U,
    ) -> Result<(), E> {
        let ix = ctx.find_mut(*self);
        visitor(ctx, ix)
    }

    /// Visits the representative of this index, looked up with `find`.
    #[inline]
    fn visit_deps<E>(
        &self,
        visitor: &mut impl FnMut(&U, usize) -> Result<(), E>,
        ctx: &U,
    ) -> Result<(), E> {
        let ix = ctx.find(*self);
        visitor(ctx, ix)
    }

    /// Lets `map` hash the raw index.
    #[inline]
    fn hash_mod<H: Hasher>(
        &self,
        hasher: &mut H,
        map: &mut impl FnMut(&mut H, &U, usize),
        ctx: &U,
    ) {
        // `map` writes into `hasher` itself and returns `()`. Hashing that `()` result as well
        // would be a no-op (`()` writes nothing) and is rejected by clippy's `unit_hash` lint,
        // so the callback's effect on `hasher` is the whole contribution, as in `hash_mod_mut`.
        map(hasher, ctx, *self)
    }

    /// Lets `map` hash the raw index, with mutable access to the union-find.
    #[inline]
    fn hash_mod_mut<H: Hasher>(
        &self,
        hasher: &mut H,
        map: &mut impl FnMut(&mut H, &mut U, usize),
        ctx: &mut U,
    ) {
        map(hasher, ctx, *self)
    }

    /// Delegates the comparison of the two raw indices to `map`.
    #[inline]
    fn eq_mod(
        &self,
        other: &Self,
        map: &mut impl FnMut(&U, usize, usize) -> bool,
        ctx: &U,
    ) -> bool {
        map(ctx, *self, *other)
    }

    /// Delegates the comparison of the two raw indices to `map`, with mutable access to the
    /// union-find.
    #[inline]
    fn eq_mod_mut(
        &self,
        other: &Self,
        map: &mut impl FnMut(&mut U, usize, usize) -> bool,
        ctx: &mut U,
    ) -> bool {
        map(ctx, *self, *other)
    }
}
