// Copyright (c) 2016-2021 Frank Fischer <frank-fischer@shadow-soft.de>
//
// This program is free software: you can redistribute it and/or
// modify it under the terms of the GNU General Public License as
// published by the Free Software Foundation, either version 3 of the
// License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful, but
// WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
// General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program.  If not, see  <http://www.gnu.org/licenses/>
//

//! Compute a maximum weight branching of a digraph (Edmonds' algorithm).

use crate::builder::{Buildable, Builder};
use crate::linkedlistgraph::LinkedListGraph;
use crate::traits::IndexDigraph;
use crate::vec::{EdgeVec, NodeVec};

use crate::num::traits::NumAssign;

/// Compute a maximum weight branching of a digraph.
///
/// A *branching* is a set of arcs in which every node has at most one
/// incoming arc and which contains no directed cycle. The returned
/// branching maximizes the total weight of its arcs. If several maximum
/// branchings exist, any one of them may be returned; the arcs are in no
/// particular order.
///
/// The implementation follows Edmonds' algorithm:
///
/// 1. every node selects its heaviest incoming arc of positive weight,
/// 2. every cycle formed by the selected arcs is contracted to a single
///    node, adjusting the weights of the arcs entering the cycle,
/// 3. the problem is solved recursively on the contracted graph, and
/// 4. the solution is expanded back, adding the arcs of each contracted
///    cycle except one.
///
/// Only arcs with strictly positive weight are ever part of the result and
/// loops (arcs `u -> u`) are ignored. Unsigned weight types are supported:
/// no intermediate weight is ever negative. Every recursive call works on a
/// graph with strictly fewer nodes, so the recursion depth is bounded by
/// the number of nodes of `g`.
///
/// # Arguments
///
/// * `g` - the digraph.
/// * `weights` - the weight of every arc of `g`.
// The algorithm is one long, tightly coupled sequence of steps sharing
// several node maps; splitting it up would mostly shuffle state around.
#[allow(clippy::cognitive_complexity)]
pub fn max_weight_branching<'a, G, W>(g: &'a G, weights: &EdgeVec<'a, &'a G, W>) -> Vec<G::Edge>
where
    G: IndexDigraph<'a>,
    W: NumAssign + Ord + Copy,
{
    // Processing state of a node of `g`.
    #[derive(Clone, Copy, Debug, PartialEq, Eq)]
    enum Label {
        // Not yet visited by the cycle search.
        Unseen,
        // Visited on the predecessor path that is currently being followed.
        OnPath,
        // Mapped to a node of the contracted graph, but not yet covered while
        // building the final branching.
        Mapped,
        // Handled while building the final branching.
        Covered,
    }

    // Step 1: every node selects its heaviest incoming arc with positive
    // weight (on ties, the arc coming first in `g.edges()`). Loops are
    // skipped: they can never be part of a branching and would otherwise be
    // taken for a cycle whose "contraction" does not shrink the graph.
    let mut inarcs = NodeVec::new(g, None);
    for e in g.edges() {
        let u = g.snk(e);
        if g.src(e) == u {
            continue;
        }
        let w = weights[e];
        if let Some(max_arc) = inarcs[u] {
            if weights[max_arc] < w {
                inarcs[u] = Some(e);
            }
        } else if w > W::zero() {
            inarcs[u] = Some(e);
        }
    }

    // Step 2: follow the selected arcs backwards, contract every cycle found
    // into a single node of `newg` and map every other node to a node of its
    // own.
    //
    // `diffweights[v]` is by how much the cycle arc entering `v` is heavier
    // than the cheapest arc of its cycle (0 for nodes not on a cycle). An arc
    // entering the cycle at `v` replaces the cycle arc into `v`, so its weight
    // in the contracted graph is reduced by this amount.
    let mut newnodes = NodeVec::new(g, None);
    let mut newg = LinkedListGraph::<usize>::new_builder();
    let mut label = NodeVec::new(g, Label::Unseen);
    let mut diffweights = NodeVec::new(g, W::zero());
    for u in g.nodes() {
        if label[u] != Label::Unseen {
            continue; // already handled as part of an earlier path
        }

        // Walk backwards along the selected arcs until reaching a node without
        // selected arc or a node that has been visited before.
        let mut v = u;
        while label[v] == Label::Unseen {
            label[v] = Label::OnPath;
            match inarcs[v] {
                Some(e) => v = g.src(e),
                None => break,
            }
        }

        // If the walk stopped at a node of *this* path that has a selected
        // arc, the path has closed a cycle through `v`.
        if let Some(e) = inarcs[v] {
            if label[v] == Label::OnPath {
                // Find the weight of the cheapest arc on the cycle.
                let mut minweight = weights[e];
                let mut w = g.src(e);
                while w != v {
                    let a = inarcs[w].expect("every node on a cycle has a selected arc");
                    if weights[a] < minweight {
                        minweight = weights[a];
                    }
                    w = g.src(a);
                }

                // Map all nodes of the cycle to one new node and record the
                // weight differences.
                let contracted = newg.add_node();
                newnodes[v] = Some(contracted);
                diffweights[v] = weights[e] - minweight;
                label[v] = Label::Mapped;
                let mut w = g.src(e);
                while w != v {
                    let a = inarcs[w].expect("every node on a cycle has a selected arc");
                    newnodes[w] = Some(contracted);
                    diffweights[w] = weights[a] - minweight;
                    label[w] = Label::Mapped;
                    w = g.src(a);
                }
            }
        }

        // All remaining nodes of the path (those not on the cycle) become
        // single nodes of the contracted graph.
        let mut v = u;
        while label[v] == Label::OnPath {
            newnodes[v] = Some(newg.add_node());
            label[v] = Label::Mapped;
            match inarcs[v] {
                Some(e) => v = g.src(e),
                None => break,
            }
        }
    }

    // Nothing has been contracted: the selected arcs are acyclic and thus
    // already form a maximum weight branching.
    if newg.num_nodes() == g.num_nodes() {
        return inarcs.iter().filter_map(|(_, &e)| e).collect();
    }

    // Step 3: build the arcs of the contracted graph. Arcs inside a contracted
    // cycle disappear. An arc entering `v` gets weight
    // `weights[e] - diffweights[v]`. Arcs whose adjusted weight would not be
    // positive are dropped because they could never be selected; this also
    // prevents the subtraction from underflowing for unsigned weight types.
    let mut newweights = vec![];
    let mut newarcs = vec![];
    for e in g.edges() {
        let v = g.snk(e);
        let newu = newnodes[g.src(e)].expect("every node has been mapped");
        let newv = newnodes[v].expect("every node has been mapped");
        if newu != newv && weights[e] > diffweights[v] {
            newg.add_edge(newu, newv);
            newarcs.push(e);
            newweights.push(weights[e] - diffweights[v]);
        }
    }

    let newg = newg.into_graph();

    // Step 4: recursively determine a branching of the (smaller) contracted
    // graph.
    let newweights = EdgeVec::new_from_vec(&newg, newweights);
    let newbranching = max_weight_branching(&newg, &newweights);
    let mut branching = Vec::with_capacity(g.num_nodes());

    // Step 5: map the contracted branching back to arcs of `g`.
    let newarcs = EdgeVec::new_from_vec(&newg, newarcs);
    for newa in newbranching {
        let e = newarcs[newa];
        branching.push(e);
        let u = g.snk(e);
        label[u] = Label::Covered;
        // `u` lies on a contracted cycle exactly if its selected arc starts in
        // the same contracted node. Only then does entering `u` require adding
        // the rest of the cycle (all cycle arcs except the one into `u`). For a
        // single node, `e` may legitimately differ from its selected arc and
        // simply replaces it.
        if let Some(inarc) = inarcs[u] {
            if newnodes[g.src(inarc)] == newnodes[u] {
                let mut v = g.src(inarc);
                while v != u {
                    label[v] = Label::Covered;
                    let a = inarcs[v].expect("every node on a cycle has a selected arc");
                    branching.push(a);
                    v = g.src(a);
                }
            }
        }
    }

    // Step 6: nodes still uncovered are roots of the branching. If such a node
    // lies on a contracted cycle that no arc enters, add that whole cycle
    // except for its cheapest arc.
    for u in g.nodes() {
        if label[u] != Label::Mapped {
            continue;
        }
        label[u] = Label::Covered;
        if let Some(first) = inarcs[u] {
            if newnodes[g.src(first)] == newnodes[u] {
                // Walk once around the cycle; `minarc` is the cheapest arc seen
                // so far and is the only one held back.
                let mut minarc = first;
                let mut v = g.src(first);
                while v != u {
                    label[v] = Label::Covered;
                    let a = inarcs[v].expect("every node on a cycle has a selected arc");
                    if weights[a] >= weights[minarc] {
                        branching.push(a);
                    } else {
                        branching.push(minarc);
                        minarc = a;
                    }
                    v = g.src(a);
                }
            }
        }
    }

    branching
}

#[cfg(test)]
mod tests {
    use crate::branching::max_weight_branching;
    use crate::vec::EdgeVec;
    use crate::{Buildable, Builder, LinkedListGraph};

    /// Builds a graph with `n` nodes and the given arcs `(src, snk, weight)`
    /// (nodes numbered from 1), computes a maximum weight branching and
    /// returns its number of arcs and its total weight.
    fn solve(n: usize, arcs: &[(usize, usize, u32)]) -> (usize, u32) {
        let mut g = LinkedListGraph::<usize>::new_builder();
        let mut weights = vec![];
        let nodes = g.add_nodes(n);
        for &(u, v, c) in arcs {
            g.add_edge(nodes[u - 1], nodes[v - 1]);
            weights.push(c);
        }
        let g = g.into_graph();
        let weights = EdgeVec::new_from_vec(&g, weights);
        let branching = max_weight_branching(&g, &weights);
        (branching.len(), branching.iter().map(|&e| weights[e]).sum())
    }

    #[test]
    fn test_branching1() {
        let mut g = LinkedListGraph::<usize>::new_builder();
        let mut weights = vec![];
        let nodes = g.add_nodes(9);

        for &(u, v, c) in [
            (1, 4, 17u32),
            (1, 5, 5),
            (1, 3, 18),
            (2, 1, 21),
            (2, 6, 17),
            (2, 7, 12),
            (3, 2, 21),
            (3, 8, 15),
            (4, 9, 12),
            (5, 2, 12),
            (5, 4, 12),
            (6, 5, 4),
            (6, 7, 13),
            (7, 3, 14),
            (7, 8, 12),
            (8, 9, 18),
            (9, 1, 19),
            (9, 3, 15),
        ]
        .iter()
        {
            g.add_edge(nodes[u - 1], nodes[v - 1]);
            weights.push(c);
        }

        let g = g.into_graph();

        let weights = EdgeVec::new_from_vec(&g, weights);
        let branching = max_weight_branching(&g, &weights);
        assert_eq!(branching.iter().fold(0, |acc, &e| acc + weights[e]), 131);
    }

    #[test]
    fn test_branching2() {
        let mut g = LinkedListGraph::<usize>::new_builder();
        let mut weights = vec![];
        let nodes = g.add_nodes(9);

        for &(u, v, c) in [
            (2, 1, 3),
            (1, 3, 4),
            (6, 3, 3),
            (6, 7, 1),
            (7, 4, 3),
            (1, 2, 10),
            (4, 1, 5),
            (3, 4, 5),
            (4, 5, 2),
            (4, 6, 4),
            (5, 6, 2),
        ]
        .iter()
        {
            g.add_edge(nodes[u - 1], nodes[v - 1]);
            weights.push(c);
        }

        let g = g.into_graph();
        let weights = EdgeVec::new_from_vec(&g, weights);
        let branching = max_weight_branching(&g, &weights);
        assert_eq!(branching.iter().fold(0, |acc, &e| acc + weights[e]), 28);
    }

    #[test]
    fn test_branching_ignores_loops() {
        // The loop 1 -> 1 is the heaviest arc but can never be part of a branching.
        assert_eq!(solve(2, &[(1, 1, 5), (1, 2, 3)]), (1, 3));
    }

    #[test]
    fn test_branching_unsigned_no_underflow() {
        // Contracting the cycle 1 <-> 2 reduces the weight of 3 -> 2 below
        // zero; with unsigned weights this must not underflow.
        assert_eq!(solve(3, &[(1, 2, 10), (2, 1, 6), (3, 2, 1)]), (1, 10));
    }

    #[test]
    fn test_branching_nested_contraction() {
        // Node 3 is not on the first-level cycle 1 <-> 2, but it is entered
        // through an arc other than its heaviest incoming one (4 -> 3).
        assert_eq!(
            solve(4, &[(1, 2, 10), (2, 1, 10), (1, 3, 5), (3, 1, 8), (4, 3, 1)]),
            (3, 19)
        );
    }
}
