//! MIR based intermediate result formatting.
//!
//! The MIR node types that deal in annotating intermediate results
//! must not be specific to any particular dice operators,
//! as the plan is to build new operators on top of MIR and make the
//! evaluation backend agnostic to how formatting of the structured
//! data it produces ends up being done.
//!
//! Stuff in here doesn't take much care to be fast, since
//! we know it will not be used inside of an especially hot loop.
use ::core::iter::{self, FromIterator, FusedIterator};
use ::std::collections::BTreeMap;

/// A MIR node kind that deals in saving and annotating
/// intermediate results of dice expressions for use in textual
/// output formatting.
///
/// NOTE(review): presumably evaluating these nodes builds the [`OutputNode`]
/// tree consumed by the formatters in this module (`MakeList`/`PushToList`
/// → [`OutputNode::List`], `Record` → [`OutputNode::Value`], `Annotate` →
/// [`OutputNode::Annotated`]) — confirm against the evaluator.
#[derive(Debug, Clone)]
pub(super) enum FmtNode {
    /// Creates an empty list.
    MakeList,
    /// Pushes something to a list.
    PushToList,
    /// Records a value for use in output.
    /// Clones instead of consuming.
    Record,
    /// Annotate the output tree with data known prior to
    /// dice program evaluation.
    Annotate(Annotation),
    /// A region argument node kind specifically for carrying
    /// intermediate results.
    /// Making this distinct from regular region arguments allows us to
    /// avoid scanning edges when lowering.
    /// NOTE(review): the `u8` is presumably the argument's index within its
    /// region — confirm against the lowering code.
    RegionArgument(u8),
}

/// An annotation is a descriptive piece of data whose contents are known
/// prior to dice program evaluation, that can be attached to nodes in the
/// tree structure that keeps important intermediate results of a running
/// dice program.
///
/// The formatters in this module ([`mbot_format_default`] and
/// [`mbot_format_short`]) dispatch on the annotation to decide how the
/// annotated [`OutputNode`] is rendered. Each variant expects a specific
/// shape of annotated node, described below; any other shape is reported
/// as an unknown structure.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum Annotation {
    /// An integer constant. The annotated node is expected to be an integer
    /// value, which is printed as-is.
    // TODO: consider storing value with constant annotation
    Constant,
    /// A roll of `count` dice with `sides` sides each, printed as
    /// `{count}d{sides}`. The annotated node is expected to be the set of
    /// rolled results.
    Roll {
        count: i64,
        sides: i64,
    },
    /// A keep filter applied to a roll, printed with the `k` suffix followed
    /// by `keep_count` (presumably keeping the `keep_count` highest dice).
    /// The annotated node is expected to be a two-element list: an
    /// [`Annotation::Roll`]-annotated set of all rolled dice, then the set of
    /// kept dice.
    KeepHigh {
        keep_count: i64,
    },
    /// Same shape as [`Annotation::KeepHigh`], but printed with the `kl`
    /// suffix (presumably keeping the `keep_count` lowest dice).
    KeepLow {
        keep_count: i64,
    },
    /// An exploding roll of `count` dice with `sides` sides each, printed as
    /// `{count}d{sides}!`. The annotated node is expected to be a list of
    /// sets, one per iteration (the first presumably being the initially
    /// rolled dice).
    Explode {
        count: i64,
        sides: i64,
    },
    /// Binary addition. The annotated node is expected to be a two-element
    /// list `[left, right]`, printed as `left + right`.
    Add,
    /// Binary subtraction. Same shape as [`Annotation::Add`], printed as
    /// `left - right`.
    Subtract,
    /// Negation. The annotated node is the operand, printed with a leading `-`.
    UnarySubtract,
    /// Unary plus. Neither formatter in this module handles it (both report
    /// an unknown structure), and the `dead_code` allowance suggests it is
    /// not constructed yet.
    #[allow(dead_code)]
    UnaryAdd,
}

/// A naive representation of tree structured output by dice programs.
///
/// This is what the formatters in this module ([`mbot_format_default`],
/// [`mbot_format_short`]) render. They only understand the node shapes
/// described on each [`Annotation`] variant.
#[derive(Debug)]
#[non_exhaustive]
pub enum OutputNode {
    /// A recorded runtime value. The formatters read integers (for constants)
    /// and sets of integers (for dice results).
    Value(super::stack::Value),
    /// An ordered list of child nodes.
    List(Vec<OutputNode>),
    /// A child node together with the [`Annotation`] describing what produced it.
    Annotated(Annotation, Box<OutputNode>),
}

/// Reasons why structured dice program output could not be formatted.
#[derive(Debug)]
#[non_exhaustive]
pub(super) enum FormattingFailure {
    /// Structured dice program output contains something the formatter
    /// doesn't know how to format.
    ///
    /// This is only expected in two situations. One is while new user-facing
    /// operators are being added. The other is when an older deployed version
    /// is combined with MIR loaded from a database that a newer version has
    /// written to. It may also eventually happen with a user-facing scripting
    /// language.
    UnknownStructure,
}

/// A multiset ("bag") of ordered elements, stored as a map from each distinct
/// element to its multiplicity.
///
/// Invariant: every multiplicity stored in `counts` is at least 1.
struct MultiSet<T> {
    counts: BTreeMap<T, usize>,
}

impl<T: Ord> FromIterator<T> for MultiSet<T> {
    fn from_iter<I>(iter: I) -> Self
    where
        I: IntoIterator<Item = T>,
    {
        let mut counts = BTreeMap::new();
        for item in iter {
            *counts.entry(item).or_insert(0) += 1;
        }
        Self { counts }
    }
}

/// Owning iterator over a [`MultiSet`]. It yields every element as many times
/// as its multiplicity, in ascending order.
struct MultiSetIter<T> {
    /// The not-yet-visited `(element, multiplicity)` entries.
    iter: <BTreeMap<T, usize> as IntoIterator>::IntoIter,
    /// The element currently being repeated, paired with the number of copies
    /// of it still to be yielded (always at least 1 while present).
    current: Option<(T, usize)>,
}

impl<T: Copy> Iterator for MultiSetIter<T> {
    type Item = T;
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            match self.current.take() {
                Some((item, remaining)) => {
                    if remaining > 1 {
                        self.current = Some((item, remaining - 1));
                    }
                    return Some(item);
                }
                None => {
                    let (item, count) = self.iter.next()?;
                    // Skip, rather than stop at, an entry with multiplicity 0,
                    // should the `MultiSet` invariant ever be broken.
                    if count > 0 {
                        self.current = Some((item, count));
                    }
                }
            }
        }
    }
}

// `next` only returns `None` when nothing is pending and the underlying
// `BTreeMap` iterator (which is itself fused) is exhausted. It therefore keeps
// returning `None` afterwards.
impl<T: Copy> FusedIterator for MultiSetIter<T> {}

impl<T: Copy> IntoIterator for MultiSet<T> {
    type Item = T;
    type IntoIter = MultiSetIter<T>;
    fn into_iter(self) -> Self::IntoIter {
        MultiSetIter {
            iter: self.counts.into_iter(),
            current: None,
        }
    }
}

impl<T: Ord + Copy> MultiSet<T> {
    /// Multiset difference `self \ other`.
    ///
    /// Each element's multiplicity is reduced by its multiplicity in `other`.
    /// The subtraction saturates at zero, and an element whose count reaches
    /// zero is dropped. Elements of `other` that are absent from `self` are
    /// ignored, so `other` does not need to be a subset of `self`.
    fn subtract(&self, other: &MultiSet<T>) -> MultiSet<T> {
        let counts = self
            .counts
            .iter()
            .filter_map(|(&item, &count)| {
                let removed = other.counts.get(&item).copied().unwrap_or(0);
                let remaining = count.saturating_sub(removed);
                if remaining > 0 {
                    Some((item, remaining))
                } else {
                    None
                }
            })
            .collect();
        Self { counts }
    }
}

/// Split the dice of a keep roll into the kept ones and the discarded ones.
///
/// This is a *multiset difference*: `superset` (every rolled die) minus
/// `subset` (the dice that were kept). Duplicate values are therefore matched
/// up one-for-one rather than removed wholesale.
///
/// Returns `(kept, discarded)`, each sorted in ascending order.
fn diff(subset: &[i64], superset: &[i64]) -> (Box<[i64]>, Box<[i64]>) {
    let kept_bag: MultiSet<i64> = subset.iter().copied().collect();
    let rolled_bag: MultiSet<i64> = superset.iter().copied().collect();
    let discarded_bag = rolled_bag.subtract(&kept_bag);

    let kept: Vec<i64> = kept_bag.into_iter().collect();
    let discarded: Vec<i64> = discarded_bag.into_iter().collect();
    (kept.into_boxed_slice(), discarded.into_boxed_slice())
}

/// Render `output` into `buf` in the detailed ("default") `mbot` format.
///
/// Every individual die result is shown:
///
/// * a constant is printed as its integer value;
/// * a roll is printed as `(NdM → a + b + …)`;
/// * a keep roll is printed as `(NdMkK → **a + b** + c + …)` (`kl` instead of
///   `k` for [`Annotation::KeepLow`]). The kept dice are in bold, and both the
///   kept and the discarded dice are listed from largest to smallest;
/// * an exploding roll is printed as `(NdM! → a + b + [c] + [[d]] …)`, where
///   the dice of each later iteration get one more pair of brackets;
/// * `+`, binary `-` and unary `-` are printed around their formatted operands.
///
/// A roll with a non-positive dice count, or a keep roll with a non-positive
/// dice or keep count, is printed without results (e.g. `0d6`, `3d6k0`). An
/// exploding roll whose first iteration is empty is printed as just `NdM!`.
///
/// # Errors
///
/// Returns [`FormattingFailure::UnknownStructure`] in two cases: when `output`
/// is shaped in a way this formatter does not recognize, and when a roll or
/// keep roll has positive counts but no dice to display. On error, `buf` may
/// already hold partial output.
pub(super) fn fmt_default_impl(
    buf: &mut String,
    output: &OutputNode,
) -> Result<(), FormattingFailure> {
    use super::stack::Value as SV;
    use OutputNode::*;

    /// Append the decimal form of `n` to `buf`.
    fn push_int(buf: &mut String, n: i64) {
        let mut itoa_buf = itoa::Buffer::new();
        buf.push_str(itoa_buf.format(n));
    }

    /// Append `{count}d{sides}` to `buf`.
    fn push_dice(buf: &mut String, count: i64, sides: i64) {
        push_int(buf, count);
        buf.push('d');
        push_int(buf, sides);
    }

    /// Append `values` to `buf`, separated by ` + `.
    fn push_sum(buf: &mut String, values: &[i64]) {
        for (i, &value) in values.iter().enumerate() {
            if i > 0 {
                buf.push_str(" + ");
            }
            push_int(buf, value);
        }
    }

    match output {
        Annotated(Annotation::Constant, val) => match &**val {
            Value(SV::Integer(int)) => {
                push_int(buf, *int);
                Ok(())
            }
            _ => Err(FormattingFailure::UnknownStructure),
        },
        Annotated(Annotation::Roll { count, sides }, val) => {
            let dice = match &**val {
                Value(SV::Set(dice)) => &**dice,
                _ => return Err(FormattingFailure::UnknownStructure),
            };
            if *count > 0 {
                // A positive dice count must come with at least one result to show.
                if dice.is_empty() {
                    return Err(FormattingFailure::UnknownStructure);
                }
                buf.push('(');
                push_dice(buf, *count, *sides);
                buf.push_str(" → ");
                push_sum(buf, dice);
                buf.push(')');
            } else {
                push_dice(buf, *count, *sides);
            }
            Ok(())
        }
        Annotated(keep, inner)
            if matches!(
                keep,
                Annotation::KeepHigh { .. } | Annotation::KeepLow { .. }
            ) =>
        {
            let (op, keep_count) = match keep {
                Annotation::KeepHigh { keep_count } => ("k", *keep_count),
                Annotation::KeepLow { keep_count } => ("kl", *keep_count),
                _ => unreachable!(),
            };
            // Expected shape:
            // `[Annotated(Roll { .. }, Value(Set(all dice))), Value(Set(kept dice))]`.
            let (count, sides, all_dice, kept_dice) = match &**inner {
                List(list) => match &**list {
                    [Annotated(Annotation::Roll { count, sides }, roll), Value(SV::Set(kept))] => {
                        match &**roll {
                            Value(SV::Set(all)) => (*count, *sides, &**all, &**kept),
                            _ => return Err(FormattingFailure::UnknownStructure),
                        }
                    }
                    _ => return Err(FormattingFailure::UnknownStructure),
                },
                _ => return Err(FormattingFailure::UnknownStructure),
            };
            if count <= 0 || keep_count <= 0 {
                push_dice(buf, count, sides);
                buf.push_str(op);
                push_int(buf, keep_count);
                return Ok(());
            }
            // `diff` returns both groups in ascending order; show them largest first.
            let (mut kept, mut discarded) = diff(kept_dice, all_dice);
            if kept.is_empty() {
                return Err(FormattingFailure::UnknownStructure);
            }
            kept.reverse();
            discarded.reverse();
            buf.push('(');
            push_dice(buf, count, sides);
            buf.push_str(op);
            push_int(buf, keep_count);
            buf.push_str(" → **");
            push_sum(buf, &kept);
            buf.push_str("**");
            for &die in discarded.iter() {
                buf.push_str(" + ");
                push_int(buf, die);
            }
            buf.push(')');
            Ok(())
        }
        Annotated(Annotation::Explode { count, sides }, inner) => {
            let iterations = match &**inner {
                List(iterations) => iterations
                    .iter()
                    .map(|iteration| match iteration {
                        Value(SV::Set(dice)) => Ok(&**dice),
                        _ => Err(FormattingFailure::UnknownStructure),
                    })
                    .collect::<Result<Vec<&[i64]>, _>>()?,
                _ => return Err(FormattingFailure::UnknownStructure),
            };
            match iterations.split_first() {
                Some((first, later)) if !first.is_empty() => {
                    buf.push('(');
                    push_dice(buf, *count, *sides);
                    buf.push_str("! → ");
                    push_sum(buf, first);
                    // Dice from the n-th later iteration (1-based) get n pairs of brackets.
                    for (i, dice) in later.iter().enumerate() {
                        let depth = i + 1;
                        for &die in dice.iter() {
                            buf.push_str(" + ");
                            buf.extend(iter::repeat('[').take(depth));
                            push_int(buf, die);
                            buf.extend(iter::repeat(']').take(depth));
                        }
                    }
                    buf.push(')');
                }
                _ => {
                    push_dice(buf, *count, *sides);
                    buf.push('!');
                }
            }
            Ok(())
        }
        Annotated(op @ (Annotation::Add | Annotation::Subtract), inner) => match &**inner {
            List(list) => match &**list {
                [left, right] => {
                    fmt_default_impl(buf, left)?;
                    buf.push_str(if matches!(op, Annotation::Add) {
                        " + "
                    } else {
                        " - "
                    });
                    fmt_default_impl(buf, right)
                }
                _ => Err(FormattingFailure::UnknownStructure),
            },
            _ => Err(FormattingFailure::UnknownStructure),
        },
        Annotated(Annotation::UnarySubtract, inner) => {
            buf.push('-');
            fmt_default_impl(buf, inner)
        }
        _ => Err(FormattingFailure::UnknownStructure),
    }
}

/// Render `output` into `buf` in the condensed ("short") `mbot` format.
///
/// Individual die results are not listed. Instead, each roll is summarised by
/// the total of the dice that count towards the result:
///
/// * a constant is printed as its integer value;
/// * a roll is printed as `(NdM → total)`;
/// * a keep roll is printed as `(NdMkK → total of the kept dice)` (`kl` for
///   [`Annotation::KeepLow`]);
/// * an exploding roll is printed as `(NdM! → total over every iteration)`;
/// * `+`, binary `-` and unary `-` are printed around their formatted operands.
///
/// Three cases are printed without the `( … → total)` wrapper, e.g. `0d6`:
/// a roll with a non-positive dice count, a keep roll with a non-positive
/// dice or keep count, and an exploding roll whose first iteration is empty.
///
/// # Errors
///
/// Returns [`FormattingFailure::UnknownStructure`] if `output` is shaped in a
/// way this formatter does not recognize. On error, `buf` may already hold
/// partial output.
pub(super) fn fmt_short_impl(
    buf: &mut String,
    output: &OutputNode,
) -> Result<(), FormattingFailure> {
    use super::stack::Value as SV;
    use OutputNode::*;

    /// Append the decimal form of `n` to `out`.
    fn push_int(out: &mut String, n: i64) {
        let mut digits = itoa::Buffer::new();
        out.push_str(digits.format(n));
    }

    /// Build the `{count}d{sides}` text naming a roll.
    fn dice_label(count: i64, sides: i64) -> String {
        let mut label = String::new();
        push_int(&mut label, count);
        label.push('d');
        push_int(&mut label, sides);
        label
    }

    /// Append `(label → total)` when a total is given, otherwise just `label`.
    fn push_summary(out: &mut String, label: &str, total: Option<i64>) {
        if let Some(total) = total {
            out.push('(');
            out.push_str(label);
            out.push_str(" → ");
            push_int(out, total);
            out.push(')');
        } else {
            out.push_str(label);
        }
    }

    match output {
        Annotated(Annotation::Constant, val) => match &**val {
            Value(SV::Integer(int)) => {
                push_int(buf, *int);
                Ok(())
            }
            _ => Err(FormattingFailure::UnknownStructure),
        },
        Annotated(Annotation::Roll { count, sides }, val) => match &**val {
            Value(SV::Set(dice)) => {
                let total = if *count > 0 {
                    Some(dice.iter().sum::<i64>())
                } else {
                    None
                };
                push_summary(buf, &dice_label(*count, *sides), total);
                Ok(())
            }
            _ => Err(FormattingFailure::UnknownStructure),
        },
        Annotated(keep @ Annotation::KeepHigh { .. }, inner)
        | Annotated(keep @ Annotation::KeepLow { .. }, inner) => {
            let (op, keep_count) = match keep {
                Annotation::KeepHigh { keep_count } => ("k", *keep_count),
                Annotation::KeepLow { keep_count } => ("kl", *keep_count),
                _ => unreachable!(),
            };
            // Expected shape:
            // `[Annotated(Roll { .. }, Value(Set(all dice))), Value(Set(kept dice))]`.
            let (count, sides, kept) = match &**inner {
                List(list) => match &**list {
                    [Annotated(Annotation::Roll { count, sides }, roll), Value(SV::Set(kept))]
                        if matches!(&**roll, Value(SV::Set(_))) =>
                    {
                        (*count, *sides, kept)
                    }
                    _ => return Err(FormattingFailure::UnknownStructure),
                },
                _ => return Err(FormattingFailure::UnknownStructure),
            };
            let mut label = dice_label(count, sides);
            label.push_str(op);
            push_int(&mut label, keep_count);
            let total = if count > 0 && keep_count > 0 {
                Some(kept.iter().sum::<i64>())
            } else {
                None
            };
            push_summary(buf, &label, total);
            Ok(())
        }
        Annotated(Annotation::Explode { count, sides }, inner) => {
            let iterations: Vec<&[i64]> = match &**inner {
                List(nodes) => {
                    let mut sets = Vec::with_capacity(nodes.len());
                    for node in nodes.iter() {
                        match node {
                            Value(SV::Set(set)) => sets.push(&**set),
                            _ => return Err(FormattingFailure::UnknownStructure),
                        }
                    }
                    sets
                }
                _ => return Err(FormattingFailure::UnknownStructure),
            };
            let has_dice = iterations.first().map_or(false, |first| !first.is_empty());
            let mut label = dice_label(*count, *sides);
            label.push('!');
            let total = if has_dice {
                Some(iterations.iter().flat_map(|set| set.iter()).sum::<i64>())
            } else {
                None
            };
            push_summary(buf, &label, total);
            Ok(())
        }
        Annotated(op @ (Annotation::Add | Annotation::Subtract), inner) => {
            let operands = match &**inner {
                List(list) => &**list,
                _ => return Err(FormattingFailure::UnknownStructure),
            };
            if let [left, right] = operands {
                fmt_short_impl(buf, left)?;
                buf.push_str(match op {
                    Annotation::Subtract => " - ",
                    _ => " + ",
                });
                fmt_short_impl(buf, right)
            } else {
                Err(FormattingFailure::UnknownStructure)
            }
        }
        Annotated(Annotation::UnarySubtract, inner) => {
            buf.push('-');
            fmt_short_impl(buf, inner)
        }
        _ => Err(FormattingFailure::UnknownStructure),
    }
}

/// Format the output of a dice expression as usual in `mbot`.
///
/// The result is the detailed rendering of `output`, with every die shown,
/// followed by ` = ` and `total`, e.g. `(2d6 → 3 + 4) + 1 = 8`.
///
/// # Panics
///
/// Panics if `output` contains a structure the default formatter does not
/// recognize (`FormattingFailure::UnknownStructure`).
pub fn mbot_format_default(output: &OutputNode, total: i64) -> String {
    let mut buf = String::with_capacity(2000);
    // The only way formatting can fail is an output shape the formatter does not
    // know. That indicates an unfinished operator or a version mismatch with
    // stored MIR rather than bad user input, so it is treated as a bug.
    fmt_default_impl(&mut buf, output)
        .expect("dice program output has a structure the default formatter cannot format");
    buf.push_str(" = ");
    let mut itoa_buf = itoa::Buffer::new();
    buf.push_str(itoa_buf.format(total));
    buf
}

/// Like [`mbot_format_default`], but the shorter format.
///
/// Each roll is summarised by its total instead of listing every die, e.g.
/// `(2d6 → 7) + 1 = 8`.
///
/// # Panics
///
/// Panics if `output` contains a structure the short formatter does not
/// recognize (`FormattingFailure::UnknownStructure`).
pub fn mbot_format_short(output: &OutputNode, total: i64) -> String {
    let mut buf = String::with_capacity(2000);
    // The short formatter re-adds dice results. Those sums were already computed
    // while evaluating `total`, so re-adding them should not overflow.
    // Formatting itself only fails on an output shape the formatter does not
    // know (an unfinished operator or a version mismatch), which is a bug.
    fmt_short_impl(&mut buf, output)
        .expect("dice program output has a structure the short formatter cannot format");
    buf.push_str(" = ");
    let mut itoa_buf = itoa::Buffer::new();
    buf.push_str(itoa_buf.format(total));
    buf
}
