use super::{BinaryOp, Coercion, Comparison, Filter, Mir, MirEdge, MirGraph, MirNode};
use crate::viz::{Id, IdGen};
use ::petgraph::graph::NodeIndex;
use ::std::collections::HashMap;

/// Key that identifies a node across every region of a MIR, not just
/// within the region that owns it.
///
/// It pairs the node's index inside its own region graph with a number
/// that stands for that region. Keys are *not* stable under graph
/// mutation.
#[derive(Debug, PartialEq, Eq, Hash)]
struct GlobalNodeKey {
    local_idx: NodeIndex,
    // Address of the region's `MirGraph`, obtained by casting the
    // reference to an integer. Used purely as an identity token.
    global_region_key: usize,
}
impl GlobalNodeKey {
    /// Builds the key for the node at `local_idx` inside `region`.
    fn new(region: &RegionView<'_>, local_idx: NodeIndex) -> Self {
        // The region is identified by where its graph lives in memory.
        let graph_ptr: *const MirGraph = region.graph;
        GlobalNodeKey {
            global_region_key: graph_ptr as usize,
            local_idx,
        }
    }
}

/// Appends the DOT node statement `<id> [label = "<label>"]` plus a
/// trailing newline to `dot`.
///
/// `label` ends up inside a double-quoted DOT string. Labels built from
/// `{:?}` output may contain `"` or `\`, which would otherwise terminate
/// the string early or be read as an escape sequence, so both characters
/// are backslash-escaped. Labels without those characters are written
/// unchanged.
fn push_node(dot: &mut String, id: &Id, label: &str) {
    id.fmt(dot);
    dot.push_str(" [label = \"");
    for ch in label.chars() {
        // Escape the two characters that are special inside a quoted
        // DOT label so they are rendered literally.
        if ch == '"' || ch == '\\' {
            dot.push('\\');
        }
        dot.push(ch);
    }
    dot.push_str("\"]\n");
}

/// The DOT identity recorded for one MIR node.
#[derive(Clone)]
enum RId {
    /// A node rendered as a single DOT node.
    Simple(Id),
    /// A node rendered as a cluster subgraph: `id` names the cluster and
    /// `end_id` is the identity recorded for the region's end node.
    Structural { id: Id, end_id: Box<RId> },
}
impl RId {
    /// Appends this identity's DOT name to `out`: the bare id for a
    /// simple node, the `cluster_`-prefixed id for a structural one.
    fn fmt(&self, out: &mut String) {
        match self {
            Self::Structural { id, .. } => Self::region_fmt(id, out),
            Self::Simple(id) => id.fmt(out),
        }
    }

    /// Appends `cluster_` followed by `id` to `out`.
    fn region_fmt(id: &Id, out: &mut String) {
        out.push_str("cluster_");
        id.fmt(out);
    }
}

/// A cheap, copyable borrowed view of one region of the MIR: the region's
/// graph together with the node index it was given as its `end`
/// (`region.end` for nested regions, `mir.top` for the root).
#[derive(Clone, Copy)]
struct RegionView<'a> {
    // The region's node graph. `GlobalNodeKey::new` also uses the address
    // of this graph as the region's identity.
    graph: &'a MirGraph,
    // Stored when the view is built but not read by any code visible in
    // this file, hence the lint allowance below.
    #[allow(dead_code)]
    end: NodeIndex,
}

fn dot_inner(
    mir: RegionView,
    gen: &mut IdGen,
    id_map: &mut HashMap<GlobalNodeKey, RId>,
    dot: &mut String,
) {
    for node in mir.graph.node_indices() {
        macro_rules! push_node {
            ($val:expr) => {{
                let id = gen.next();
                push_node(&mut *dot, &id, &*$val);
                id_map.insert(GlobalNodeKey::new(&mir, node), RId::Simple(id));
            }};
        }
        macro_rules! push_region {
            ($name:expr, $region:expr) => {{
                match $region {
                    region => {
                        let region_id = gen.next();
                        // Note that, yes, "cluster" must be at the start of
                        // the subgraph name for GraphViz to group things and
                        // style the subgraphs how we want.
                        // https://graphviz.org/Gallery/directed/cluster.html
                        dot.push_str("subgraph ");
                        RId::region_fmt(&region_id, &mut *dot);
                        dot.push_str(" {\n");
                        // write nodes in here
                        let view = RegionView {
                            graph: &region.graph,
                            end: region.end,
                        };
                        dot_inner(view, &mut *gen, &mut *id_map, &mut *dot);
                        let end_id = id_map[&GlobalNodeKey::new(&view, region.end)].clone();
                        dot.push_str("label = \"");
                        dot.push_str(&*$name);
                        dot.push_str("\"\n");
                        dot.push_str("}\n");
                        id_map.insert(
                            GlobalNodeKey::new(&mir, node),
                            RId::Structural {
                                id: region_id,
                                end_id: Box::new(end_id),
                            },
                        );
                    }
                }
            }};
        }
        match &mir.graph[node] {
            MirNode::Integer(val) => push_node!(format!("{}", val)),
            MirNode::Coerce(Coercion::FromOutputToInt) => push_node!("Coerce(FromOutputToInt)"),
            MirNode::Roll => push_node!("Roll"),
            MirNode::BinOp(BinaryOp::Add) => push_node!("+"),
            MirNode::BinOp(BinaryOp::Subtract) => push_node!("-"),
            MirNode::BinOp(BinaryOp::LogicalAnd) => push_node!("AND"),
            MirNode::Filter(Filter::Simple(filter)) => push_node!(format!("Filter({:?})", filter)),
            MirNode::Filter(Filter::SatisfiesPredicate) => push_node!("Filter(Satisfies)"),
            MirNode::Apply => push_node!("Apply"),
            MirNode::PartialApply => push_node!("PartialApply"),
            MirNode::Compare(Comparison::Equal) => push_node!("Compare(Equal)"),
            MirNode::Compare(Comparison::GreaterThan) => push_node!("Compare(GreaterThan)"),
            MirNode::Count => push_node!("Count"),
            MirNode::Loop(body, _ty) => push_region!("Loop", body),
            MirNode::Decision(_) => todo!("visualizing decision points"),
            MirNode::FunctionDefinition(body) => push_region!("Function", body),
            MirNode::RecursiveEnvironment(body) => push_region!("Recursive Environment", body),
            MirNode::RegionArgument(_) => push_node!("Region Argument"),
            MirNode::End => push_node!("End"),
            MirNode::Fmt(node) => push_node!(format!("Fmt({:?})", node)),
            MirNode::UseFuel(_) => push_node!("UseFuel"),
        }
    }
    for edge in mir.graph.raw_edges().iter() {
        let source = &id_map[&GlobalNodeKey::new(&mir, edge.source())];
        let target = &id_map[&GlobalNodeKey::new(&mir, edge.target())];

        source.fmt(&mut *dot);
        dot.push_str(" -> ");
        match target {
            RId::Simple(_) => {
                target.fmt(&mut *dot);
            }
            RId::Structural { id: _, end_id } => {
                end_id.fmt(&mut *dot);
            }
        };
        match edge.weight {
            MirEdge::IntermediateResultDependency { .. } => dot.push_str(" [color=cornflowerblue]"),
            _ => (),
        }
        dot.push_str("\n");
    }
}

/// Renders `mir` as a Graphviz DOT document and returns the text.
///
/// The output is a `strict digraph` with `compound=true`, containing
/// everything `dot_inner` emits for the top-level region.
pub fn dot(mir: &Mir) -> String {
    let mut ids = IdGen::new();
    let mut id_map = HashMap::new();
    let mut out = String::from("strict digraph {\ncompound=true\n");

    let root = RegionView {
        graph: &mir.graph,
        end: mir.top,
    };
    dot_inner(root, &mut ids, &mut id_map, &mut out);

    out.push('}');
    out
}
