/*
 * Copyright (c) 2021 Frank Fischer <frank-fischer@shadow-soft.de>
 *
 * This program is free software: you can redistribute it and/or
 * modify it under the terms of the GNU General Public License as
 * published by the Free Software Foundation, either version 3 of the
 * License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see  <http://www.gnu.org/licenses/>
 */

//! A primal network simplex implementation.

use super::{MinCostFlow, SolutionState};
use crate::traits::{GraphType, IndexDigraph};
use crate::vec::EdgeVec;
//use num_integer::Roots;
use num_traits::{Bounded, FromPrimitive, NumAssign, NumCast, Signed, ToPrimitive, Zero};

/// Index type for nodes and (oriented) edges in the basis structures.
///
/// `ID::max_value()` is used as a sentinel meaning "no node/edge"
/// (e.g. no parent, no child, no sibling).
type ID = u32;

/// Pricing rule used to select the edge that enters the basis.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Pricing {
    /// Take the next edge with negative reduced cost, continuing the
    /// scan where the previous search stopped.
    RoundRobin,
    /// Scan all edges and take the one with the most negative reduced cost.
    Complete,
    /// Scan the edges in blocks (of `round(sqrt(m) / 2)` edges, but at
    /// least 10) and take the best candidate of the first block that
    /// contains one. This is the default.
    Block,
    /// Multiple partial pricing.
    ///
    /// Not implemented yet: solving with this rule panics.
    MultiplePartial,
}

/// Primal network simplex solver for the minimum cost flow problem.
///
/// The basis is kept as a spanning tree rooted at an artificial node
/// with index `num_nodes`. Each node is connected to this root by one
/// artificial edge; artificial edges use the edge indices
/// `num_edges..num_edges + num_nodes`.
///
/// Edges are referred to either by their index `eid` or by an *oriented*
/// id: `2 * eid` denotes the edge in forward direction and `2 * eid + 1`
/// in backward direction.
pub struct NetworkSimplex<'a, G, F> {
    /// The underlying graph.
    graph: &'a G,

    /// Balance of each node (positive: supply, negative: demand).
    balances: Vec<F>,
    /// Node potentials used for the reduced costs; the last entry
    /// belongs to the artificial root.
    potentials: Vec<F>,
    /// Number of nodes in the basis subtree rooted at each node.
    subtrees: Vec<ID>,
    /// Oriented id of the basis edge to the parent, oriented from the
    /// parent towards the node.
    parent_edges: Vec<ID>,
    /// Parent node in the basis tree (`ID::max_value()` for the root).
    parent_nodes: Vec<ID>,
    /// First child in the basis tree (`ID::max_value()` if none).
    first_childs: Vec<ID>,
    /// Previous sibling in the parent's child list (`ID::max_value()` if none).
    prevs: Vec<ID>,
    /// Next sibling in the parent's child list (`ID::max_value()` if none).
    nexts: Vec<ID>,

    /// Source node of each edge (graph edges followed by artificial edges).
    sources: Vec<ID>,
    /// Sink node of each edge (graph edges followed by artificial edges).
    sinks: Vec<ID>,
    /// Lower bound of each graph edge.
    lower: Vec<F>,
    /// Upper bound of each graph edge.
    upper: Vec<F>,
    /// Cost of each edge; artificial edges get the artificial cost.
    costs: Vec<F>,
    /// Capacity `upper - lower` of each edge (`infinite` for artificial edges).
    caps: Vec<F>,
    /// Flow on each edge, stored relative to its lower bound.
    flows: Vec<F>,
    /// Edge state: `1` non-basic at the lower bound, `-1` non-basic at
    /// the upper bound, `0` basic.
    state: Vec<i8>,

    /// The pricing rule used to select the entering edge.
    pub pricing: Pricing,
    /// Edge index at which the next pricing scan starts (round-robin and
    /// block pricing).
    current_edge: ID,
    /// Number of edges per block for block pricing.
    block_size: usize,

    /// Number of iterations of the last `solve` call.
    niter: usize,
    /// Result of the last `solve` call.
    solution_state: SolutionState,
    /// Set when balances or bounds changed, so that `solve` must build a
    /// new initial basis.
    need_new_basis: bool,

    /// The artificial cost value.
    ///
    /// Should be larger than the value of any augmenting cycle. If
    /// `None` (the default) the artificial cost is set to
    /// `(max(max(cost), 0) + 1) * n`, which should be large enough.
    pub artificial_cost: Option<F>,
    /// The infinite flow value.
    ///
    /// Capacities greater than or equal to this are considered
    /// unbounded and flows are considered infinite. The default is
    /// `F::max_value()`. For floating-point types `F::infinity()` can
    /// be used as well.
    ///
    /// NOTE(review): capacities are stored as `upper - lower`, so an edge
    /// with `upper >= infinite` but a nonzero lower bound currently ends up
    /// with a finite capacity — confirm whether this is intended.
    pub infinite: F,
}

impl<'a, G, F> MinCostFlow<'a> for NetworkSimplex<'a, G, F>
where
    G: IndexDigraph<'a>,
    F: Bounded + NumCast + NumAssign + PartialOrd + Copy + FromPrimitive + Signed,
{
    type Graph = G;

    type Flow = F;

    /// Create a new network simplex solver for the graph `g`.
    ///
    /// All balances, bounds and costs start at zero, the pricing rule is
    /// `Pricing::Block`, the artificial cost is chosen automatically and
    /// `infinite` is `F::max_value()`.
    fn new(g: &'a Self::Graph) -> Self {
        let n = g.num_nodes();
        let m = g.num_edges();
        // Node arrays have one extra slot for the artificial root node `n`;
        // edge arrays have one artificial edge per node after the `m`
        // graph edges.
        let mut spx = NetworkSimplex {
            graph: g,
            balances: vec![F::zero(); n],
            potentials: vec![F::zero(); n + 1],
            subtrees: vec![0; n + 1],
            parent_edges: vec![0; n + 1],
            parent_nodes: vec![0; n + 1],
            first_childs: vec![0; n + 1],
            prevs: vec![0; n + 1],
            nexts: vec![0; n + 1],

            sources: vec![0; m + n],
            sinks: vec![0; m + n],
            lower: vec![F::zero(); m + n],
            upper: vec![F::zero(); m + n],
            costs: vec![F::zero(); m + n],
            caps: vec![F::zero(); m + n],
            flows: vec![F::zero(); m + n],
            state: vec![0; m + n],

            pricing: Pricing::Block,
            current_edge: 0,
            block_size: 0,

            niter: 0,
            solution_state: SolutionState::Unknown,
            need_new_basis: true,

            artificial_cost: None,
            infinite: F::max_value(),
        };
        spx.init();
        spx
    }

    fn as_graph(&self) -> &'a Self::Graph {
        self.graph
    }

    /// The balance (positive: supply, negative: demand) of node `u`.
    fn balance(&self, u: <Self::Graph as GraphType<'a>>::Node) -> Self::Flow {
        self.balances[self.graph.node_id(u)]
    }

    /// Set the balance of node `u`; the next `solve` builds a new basis.
    fn set_balance(&mut self, u: <Self::Graph as GraphType<'a>>::Node, balance: Self::Flow) {
        self.need_new_basis = true;
        self.balances[self.graph.node_id(u)] = balance;
    }

    /// The lower bound of edge `e`.
    fn lower(&self, e: <Self::Graph as GraphType<'a>>::Edge) -> Self::Flow {
        self.lower[self.graph.edge_id(e)]
    }

    /// Set the lower bound of edge `e`; the next `solve` builds a new basis.
    fn set_lower(&mut self, e: <Self::Graph as GraphType<'a>>::Edge, lb: Self::Flow) {
        self.need_new_basis = true;
        self.lower[self.graph.edge_id(e)] = lb;
    }

    /// The upper bound of edge `e`.
    fn upper(&self, e: <Self::Graph as GraphType<'a>>::Edge) -> Self::Flow {
        self.upper[self.graph.edge_id(e)]
    }

    /// Set the upper bound of edge `e`; the next `solve` builds a new basis.
    fn set_upper(&mut self, e: <Self::Graph as GraphType<'a>>::Edge, ub: Self::Flow) {
        self.need_new_basis = true;
        self.upper[self.graph.edge_id(e)] = ub;
    }

    /// The cost of edge `e`.
    fn cost(&self, e: <Self::Graph as GraphType<'a>>::Edge) -> Self::Flow {
        self.costs[self.graph.edge_id(e)]
    }

    /// Set the cost of edge `e`.
    ///
    /// Changing costs does not invalidate the current basis; it is reused
    /// as a warm start by the next `solve`.
    fn set_cost(&mut self, e: <Self::Graph as GraphType<'a>>::Edge, cost: Self::Flow) {
        self.costs[self.graph.edge_id(e)] = cost;
    }

    /// Return the total cost `sum(flow(e) * cost(e))` of the current flow
    /// over all edges of the graph.
    fn value(&self) -> Self::Flow {
        let mut v = F::zero();
        for e in self.graph.edges() {
            v += self.flow(e) * self.costs[self.graph.edge_id(e)];
        }
        v
    }

    /// The flow on edge `a`.
    ///
    /// Internally flows are stored relative to the lower bound.
    fn flow(&self, a: <Self::Graph as GraphType<'a>>::Edge) -> Self::Flow {
        let eid = self.graph.edge_id(a);
        self.flows[eid] + self.lower[eid]
    }

    /// The flow of all edges as a vector.
    fn flow_vec(&self) -> EdgeVec<'a, &'a Self::Graph, Self::Flow> {
        EdgeVec::new_with(self.as_graph(), |e| self.flow(e))
    }

    /// Solve the minimum cost flow problem.
    ///
    /// Computes a flow that satisfies the node balances and the edge
    /// bounds with minimal total cost. The result is recorded and returned:
    ///
    /// - `Optimal` if an optimal flow has been found,
    /// - `Infeasible` if no feasible flow exists (this includes edges with
    ///   `lower > upper`),
    /// - `Unbounded` if an augmenting cycle with negative cost and infinite
    ///   capacity has been found.
    fn solve(&mut self) -> SolutionState {
        self.niter = 0;
        self.solution_state = SolutionState::Unknown;

        // Check trivial cases. They need no basis; clearing
        // `need_new_basis` lets `solution_state()` report the result.
        if self.graph.num_nodes().is_zero() {
            self.solution_state = SolutionState::Optimal;
            self.need_new_basis = false;
            return self.solution_state;
        }

        if self.graph.num_edges().is_zero() {
            // check if all balances are zero, that's the only way to be feasible
            self.solution_state = if self.balances.iter().all(Zero::is_zero) {
                SolutionState::Optimal
            } else {
                SolutionState::Infeasible
            };
            self.need_new_basis = false;
            return self.solution_state;
        }

        self.initialize_pricing();

        // A failed basis construction (some edge with `lower > upper`)
        // keeps `need_new_basis` set, so the next call rebuilds it.
        if self.need_new_basis && !self.prepare_initial_basis() {
            self.solution_state = SolutionState::Infeasible;
            return self.solution_state;
        }

        // Start from the artificial root, which has index `num_nodes`.
        self.compute_node_potentials(self.graph.num_nodes().to_u32().unwrap());
        loop {
            self.niter += 1;
            if let Some(eid) = self.find_entering_edge() {
                if !self.augment_cycle(eid) {
                    self.solution_state = SolutionState::Unbounded;
                    return self.solution_state;
                }
            } else {
                break;
            }
        }

        // Optimal for the extended problem: feasible iff no flow remains
        // on the artificial edges.
        self.solution_state = if self.check_feasiblity() {
            SolutionState::Optimal
        } else {
            SolutionState::Infeasible
        };

        self.solution_state
    }

    /// The state of the last `solve`, or `Unknown` if balances or bounds
    /// were changed since then (or no basis has been built yet, e.g. before
    /// the first `solve` or after a `solve` that failed on `lower > upper`).
    fn solution_state(&self) -> SolutionState {
        if self.need_new_basis {
            SolutionState::Unknown
        } else {
            self.solution_state
        }
    }
}

impl<'a, G, F> NetworkSimplex<'a, G, F>
where
    G: IndexDigraph<'a>,
    F: NumCast + NumAssign + Signed + PartialOrd + Copy + FromPrimitive,
{
    /// Copy the end points of all graph edges into `sources` and `sinks`.
    ///
    /// Only the entries `0..num_edges` are set here; the artificial edges
    /// (entries `num_edges..num_edges + num_nodes`) are initialized when the
    /// initial basis is prepared.
    fn init(&mut self) {
        let m = self.graph.num_edges();
        // Initialize edges
        for eid in 0..m {
            self.sources[eid] = NumCast::from(self.graph.node_id(self.graph.src(self.graph.id2edge(eid)))).unwrap();
            self.sinks[eid] = NumCast::from(self.graph.node_id(self.graph.snk(self.graph.id2edge(eid)))).unwrap();
        }
    }

    /// Return the number of iterations performed by the last call to `solve`.
    ///
    /// The count includes the final pricing round that found no entering edge.
    pub fn num_iterations(&self) -> usize {
        self.niter
    }

    /// Reset the pricing state at the start of `solve`.
    ///
    /// For block pricing the block size is `round(sqrt(m) / 2)`, but at
    /// least 10.
    ///
    /// # Panics
    ///
    /// Panics for `Pricing::MultiplePartial`, which is not implemented yet.
    fn initialize_pricing(&mut self) {
        match self.pricing {
            Pricing::RoundRobin => self.current_edge = 0,
            Pricing::Complete => (),
            Pricing::Block => {
                self.current_edge = 0;
                // The following code is analogous to my Pascal implementation.
                // We could also use
                //    self.block_size = (self.graph.num_edges().sqrt() / 2).max(10);
                self.block_size = ((self.graph.num_edges() as f64).sqrt() * 0.5)
                    .round()
                    .to_usize()
                    .unwrap()
                    .max(10);
            }
            Pricing::MultiplePartial => todo!(),
        }
    }

    /// Build the initial basis tree.
    ///
    /// Every graph edge is put at its lower bound (non-negative cost) or at
    /// its upper bound (negative cost). The basis consists of one artificial
    /// edge per node, connecting it to the artificial root node `n`. These
    /// edges absorb the remaining imbalance of each node and get the
    /// artificial cost.
    ///
    /// Returns `false` if some edge has `lower > upper`; the problem is then
    /// infeasible and `need_new_basis` stays set.
    fn prepare_initial_basis(&mut self) -> bool {
        let n = self.graph.num_nodes();
        // Oriented id of the first artificial edge: the graph edges occupy
        // the oriented ids `0..2 * num_edges`.
        let first_artificial = self.graph.num_edges() * 2;
        // The artificial node is always the root of the basis tree
        let uid = self.graph.num_nodes();

        // modified balances of each node
        let mut balances = self.balances.clone();
        balances.push(F::zero());

        // compute the cost value for the artificial edges
        let artificial_cost = self.artificial_cost.unwrap_or_else(|| {
            let mut value = F::zero();
            for &c in &self.costs[0..self.graph.num_edges()] {
                if c > value {
                    value = c;
                }
            }
            F::from(n).unwrap() * (F::one() + value)
        });

        self.subtrees[uid] = n as ID + 1;
        self.parent_edges[uid] = ID::max_value();
        self.parent_nodes[uid] = ID::max_value();
        self.first_childs[uid] = 0; // the first node
        self.prevs[uid] = ID::max_value();
        self.nexts[uid] = ID::max_value();

        // Initial flow on all non-artificial edges is at lower or upper bound depending on the cost
        for e in self.graph.edges() {
            let eid = self.graph.edge_id(e);
            // We set the initial flow on edges with non-negative costs at the lower bound ...
            let cap = self.upper[eid] - self.lower[eid];

            // The current edge is always infeasible.
            if cap < F::zero() {
                return false;
            }

            let flw: F;
            if self.costs[eid] >= F::zero() {
                self.state[eid] = 1;
                flw = F::zero();
            } else {
                // ... and those with negative costs at the upper bound.
                self.state[eid] = -1;
                flw = cap;
            }

            self.flows[eid] = flw;
            self.caps[eid] = cap;

            // Update artificial balances
            let flw = flw + self.lower[eid];
            balances[self.graph.node_id(self.graph.src(e))] -= flw;
            balances[self.graph.node_id(self.graph.snk(e))] += flw;
        }

        // The initial basis consists of the artificial edges only
        let root = ID::from_usize(n).unwrap();
        for vid in 0..n {
            self.subtrees[vid] = 1;
            // Oriented (forward) id and plain index of v's artificial edge.
            let eid = first_artificial + vid * 2;
            let aid = eid / 2;
            let vnode = ID::from_usize(vid).unwrap();
            let fid; // the parent edge, oriented from the artificial node (the root) to v
            let b; // the balance / initial flow on the artificial edge
            if balances[vid] >= F::zero() {
                // Surplus: the edge is oriented from v to the artificial node,
                // so the parent edge (root -> v) is its backward orientation.
                fid = eid ^ 1;
                b = balances[vid];
                self.sources[aid] = vnode;
                self.sinks[aid] = root;
            } else {
                // Demand: the edge is oriented from the artificial node to v.
                fid = eid;
                b = -balances[vid];
                self.sources[aid] = root;
                self.sinks[aid] = vnode;
            }

            self.costs[aid] = artificial_cost;
            self.caps[aid] = self.infinite;
            self.flows[aid] = b;
            self.state[aid] = 0;

            // v becomes a leaf child of the root; all nodes form one sibling list.
            self.parent_nodes[vid] = uid as ID;
            self.parent_edges[vid] = fid as ID;
            self.first_childs[vid] = ID::max_value();
            self.prevs[vid] = if vid > 0 { vid as ID - 1 } else { ID::max_value() };
            self.nexts[vid] = if vid + 1 < n { vid as ID + 1 } else { ID::max_value() };
        }

        self.need_new_basis = false;

        true
    }

    /// Recompute the potentials of all nodes in the basis subtree of `rootid`.
    ///
    /// `rootid` is a node id, or `num_nodes` / `ID::max_value()` for the
    /// artificial root (whose own potential is left unchanged). For any other
    /// root its potential is first derived from its parent. Each node's
    /// potential is its parent's potential plus the cost of its parent edge,
    /// signed by the edge orientation, so that basic edges have reduced cost
    /// zero.
    fn compute_node_potentials(&mut self, rootid: ID) {
        let rootid = if rootid != ID::max_value() {
            rootid as usize
        } else {
            self.graph.num_nodes()
        };
        if rootid != self.graph.num_nodes() {
            let eid = self.parent_edges[rootid] as usize;
            self.potentials[rootid] =
                self.potentials[self.parent_nodes[rootid] as usize] + oriented_flow(eid, self.costs[eid / 2]);
        }

        // traverse tree in pre-order
        let mut uid = self.first_childs[rootid] as usize;
        if uid == ID::max_value() as usize {
            return;
        }
        loop {
            // update the potential of this node
            let eid = self.parent_edges[uid] as usize;
            self.potentials[uid] = self.potentials[self.parent_nodes[uid] as usize] + oriented_flow(eid, self.costs[eid / 2]);
            // go to next node in pre-order
            if self.first_childs[uid] != ID::max_value() {
                // first child if there is one
                uid = self.first_childs[uid] as usize;
            } else {
                // go upwards until we find the first node that has a sibling
                // and go to this sibling
                while self.nexts[uid] == ID::max_value() {
                    uid = self.parent_nodes[uid] as usize;
                    if uid == rootid {
                        return;
                    }
                }
                uid = self.nexts[uid] as usize;
            }
        }
    }

    /// Perform one pivot with the oriented entering edge `eid`.
    ///
    /// `eid` is oriented in the direction in which its flow changes (see
    /// `oriented_edge`). The flow is augmented along the cycle that the edge
    /// closes with the basis tree. If a basic edge becomes blocking first, it
    /// leaves the basis and the tree (parents, child lists, subtree sizes and
    /// potentials) is updated; otherwise the entering edge just switches to
    /// its other bound.
    ///
    /// Returns `false` if the cycle can carry an infinite amount of flow,
    /// i.e. the problem is unbounded.
    fn augment_cycle(&mut self, eid: ID) -> bool {
        let eid = eid as usize;

        // e = (u,v)
        let (mut uleaving, mut vleaving) = if (eid & 1) == 0 {
            (self.sources[eid / 2] as usize, self.sinks[eid / 2] as usize)
        } else {
            (self.sinks[eid / 2] as usize, self.sources[eid / 2] as usize)
        };

        // Obtain free capacity on non-basis edge.
        let mut d = self.caps[eid / 2];

        // Compute maximal flow augmentation value and determine base-leaving-edge.
        // `min_fwd` is `true` if the leaving edge lies on the side of v.
        // The strict / non-strict comparisons select the last blocking edge
        // in cycle order.
        let mut min_nodeid = None;
        let mut min_fwd = true;
        let mut uid = uleaving;
        let mut vid = vleaving;
        while uid != vid {
            if self.subtrees[uid] < self.subtrees[vid] {
                // Edges on the side of u are in forward direction on the cycle
                let f = self.parent_edges[uid] as usize;
                let flw = if (f & 1) == 0 {
                    if self.caps[f / 2] != self.infinite {
                        self.caps[f / 2] - self.flows[f / 2]
                    } else {
                        self.infinite
                    }
                } else {
                    self.flows[f / 2]
                };
                if flw < d {
                    d = flw;
                    min_nodeid = Some(uid);
                    min_fwd = false;
                }
                uid = self.parent_nodes[uid] as usize;
            } else {
                // Edges on the side of v are in backward direction on the cycle
                let f = self.parent_edges[vid] as usize;
                let flw = if (f & 1) == 0 {
                    self.flows[f / 2]
                } else if self.caps[f / 2] != self.infinite {
                    self.caps[f / 2] - self.flows[f / 2]
                } else {
                    self.infinite
                };
                if flw <= d {
                    d = flw;
                    min_nodeid = Some(vid);
                    min_fwd = true;
                }
                vid = self.parent_nodes[vid] as usize;
            }
        }

        if d >= self.infinite {
            return false;
        }

        // vid is the common ancestor, i.e. the "top-most" node on the
        // cycle in the basis tree.
        let ancestorid = vid;

        // Augment the flow on the basis entering edge.
        self.flows[eid / 2] = if self.state[eid / 2] == 1 {
            d
        } else {
            self.caps[eid / 2] - d
        };

        // Check if e stays in non-basis
        let min_nodeid = if let Some(m) = min_nodeid {
            m
        } else {
            // switch bound
            self.state[eid / 2] = -self.state[eid / 2];
            // update flow on cycle
            let mut uid = uleaving;
            let mut vid = vleaving;
            while uid != ancestorid {
                let f = self.parent_edges[uid] as usize;
                self.flows[f / 2] += oriented_flow(f, d);
                uid = self.parent_nodes[uid] as usize;
            }
            while vid != ancestorid {
                let f = self.parent_edges[vid] as usize;
                self.flows[f / 2] -= oriented_flow(f, d);
                vid = self.parent_nodes[vid] as usize;
            }
            // done
            return true;
        };

        // ************************************************************
        // update the basis tree
        // ************************************************************

        self.state[eid / 2] = 0;

        // The basis leaving edge should be on the side of u, so possibly reverse e.
        let mut fparent: usize;
        let mut eid = eid;
        if min_fwd {
            eid ^= 1;
            fparent = self.parent_edges[min_nodeid] as usize;
            self.state[fparent / 2] = if (fparent & 1) == 0 { 1 } else { -1 };
            // swap the roles of u and v; negating d keeps the same flow change
            d = -d;
            std::mem::swap(&mut uleaving, &mut vleaving);
        } else {
            fparent = self.parent_edges[min_nodeid] as usize;
            self.state[fparent / 2] = if (fparent & 1) == 0 { -1 } else { 1 };
        }

        // First make u a child of v using the new edge ...
        let mut uid = uleaving;
        let mut vid = vleaving;

        // ... save the old parent of u ...
        let mut wid = self.parent_nodes[uid] as usize;
        // ... remove u from its old parent ...
        if self.prevs[uid] != ID::max_value() {
            self.nexts[self.prevs[uid] as usize] = self.nexts[uid];
        } else {
            self.first_childs[wid] = self.nexts[uid];
        }
        if self.nexts[uid] != ID::max_value() {
            self.prevs[self.nexts[uid] as usize] = self.prevs[uid];
        }
        // ... and make it a child of v.
        self.parent_nodes[uid] = vid as ID;
        self.nexts[uid] = self.first_childs[vid];
        if self.nexts[uid] != ID::max_value() {
            self.prevs[self.nexts[uid] as usize] = uid as ID;
        }
        self.prevs[uid] = ID::max_value();
        self.first_childs[vid] = uid as ID;
        fparent = self.parent_edges[uid] as usize; // save the original parent edge
        self.parent_edges[uid] = eid as ID ^ 1;

        // All subtree sizes from v up to the ancestor are increased.
        // The flow on these edges is reduced
        let subtreediff = self.subtrees[min_nodeid];
        while vid != ancestorid {
            let f = self.parent_edges[vid] as usize;
            self.flows[f / 2] -= oriented_flow(f, d);
            self.subtrees[vid] += subtreediff;
            vid = self.parent_nodes[vid] as usize;
        }

        // Go up from u up to min_node and rotate all edges along the way.
        // At the beginning of each iteration v is the (original) parent
        // of the current node u and f is the (original) incoming edge of u.
        // Then v is detached from its old parent and made a child of u.
        vid = wid;
        let mut childsubtree = 0;
        while uid != min_nodeid {
            // save v's original parent ...
            wid = self.parent_nodes[vid] as usize;
            // ... remove v from its parent ...
            if self.prevs[vid] != ID::max_value() {
                self.nexts[self.prevs[vid] as usize] = self.nexts[vid]
            } else {
                self.first_childs[wid] = self.nexts[vid];
            }
            if self.nexts[vid] != ID::max_value() {
                self.prevs[self.nexts[vid] as usize] = self.prevs[vid];
            }
            // ... and make it a child of u
            self.parent_nodes[vid] = uid as ID;
            let f = self.parent_edges[vid] as usize;
            self.parent_edges[vid] = fparent as ID ^ 1;
            self.flows[fparent / 2] += oriented_flow(fparent, d);
            fparent = f;
            self.nexts[vid] = self.first_childs[uid];
            if self.nexts[vid] != ID::max_value() {
                self.prevs[self.nexts[vid] as usize] = vid as ID;
            }
            self.prevs[vid] = ID::max_value();
            self.first_childs[uid] = vid as ID;

            // What is currently below u stays below u but is not added again
            // Everything else is added
            let usubtree = self.subtrees[uid];
            self.subtrees[uid] = subtreediff - childsubtree;
            childsubtree = usubtree;

            // Go up one edge.
            uid = vid;
            vid = wid;
        }
        // we reached min_node, set its subtree size and update its flow
        self.subtrees[min_nodeid] = subtreediff - childsubtree;
        self.flows[fparent / 2] += oriented_flow(fparent, d);

        // At this point u is min_node and v its (old) parent, hence (v,u) is the
        // basis leaving edge. It remains to reduce the subtree size of all
        // nodes on the path from v up to the ancestor. Each of these nodes
        // loses everything that is (now) in the subtree of u (formerly in
        // the subtree of min_node).
        while vid != ancestorid {
            let f = self.parent_edges[vid] as usize;
            self.flows[f / 2] += oriented_flow(f, d);
            self.subtrees[vid] -= subtreediff;
            vid = self.parent_nodes[vid] as usize;
        }

        // Finally recompute node potentials for changed subtrees
        self.compute_node_potentials(uleaving as ID);

        true
    }

    /// Return `true` if no flow is left on any artificial edge, i.e. the
    /// current flow is feasible for the original problem.
    fn check_feasiblity(&self) -> bool {
        self.flows[self.graph.num_edges()..].iter().all(F::is_zero)
    }

    /// Select an entering edge with the configured pricing rule.
    ///
    /// Returns its oriented id, or `None` if no edge has negative reduced
    /// cost (the current basis is optimal).
    fn find_entering_edge(&mut self) -> Option<ID> {
        match self.pricing {
            Pricing::RoundRobin => self.round_robin_pricing(),
            Pricing::Complete => self.complete_pricing(),
            Pricing::Block => self.block_pricing(),
            Pricing::MultiplePartial => self.multiple_partial_pricing(),
        }
    }

    /// Return the first edge with negative reduced cost, scanning cyclically
    /// from `current_edge`, which is then set to that edge.
    fn round_robin_pricing(&mut self) -> Option<ID> {
        let mut eid = self.current_edge as usize;
        loop {
            if self.reduced_cost(eid) < F::zero() {
                self.current_edge = eid as ID;
                return Some(self.oriented_edge(eid));
            }
            eid = (eid + 1) % self.graph.num_edges();
            if eid == self.current_edge as usize {
                return None;
            }
        }
    }

    /// Return the edge with the most negative reduced cost over all edges.
    fn complete_pricing(&mut self) -> Option<ID> {
        let mut min_cost = F::zero();
        let mut min_edge = None;
        for eid in 0..self.graph.num_edges() {
            let c = self.reduced_cost(eid);
            if c < min_cost {
                min_cost = c;
                min_edge = Some(eid);
            }
        }

        min_edge.map(|eid| self.oriented_edge(eid))
    }

    /// Block pricing: scan the edges cyclically from `current_edge` in blocks
    /// of `block_size` edges. Return the best edge of the first block that
    /// contains an edge with negative reduced cost; the next search resumes
    /// after that block.
    fn block_pricing(&mut self) -> Option<ID> {
        let mut end = self.graph.num_edges();
        let mut eid = self.current_edge as usize % end;
        let mut min_edge = None;
        let mut min_cost = F::zero();
        // `m` is the end of the current scan range, `cnt` the number of
        // edges still missing in the current block.
        let mut m = (eid + self.block_size).min(end);
        let mut cnt = self.block_size.min(end);

        loop {
            while eid < m {
                let c = self.reduced_cost(eid);
                if c < min_cost {
                    min_cost = c;
                    min_edge = Some(eid);
                }
                cnt -= 1;
                eid += 1;
            }

            if cnt == 0 {
                // reached regular end of the current block, start new block
                m = (eid + self.block_size).min(end);
                cnt = self.block_size.min(end);
            } else if eid != self.current_edge as usize {
                // reached non-regular end of the final block, start
                // from the beginning
                end = self.current_edge as usize;
                eid = 0;
                m = cnt.min(end);
                continue;
            }

            if let Some(enteringid) = min_edge {
                self.current_edge = eid as ID;
                return Some(self.oriented_edge(enteringid));
            }

            if eid == self.current_edge as usize {
                return None;
            }
        }
    }

    /// Multiple partial pricing (not implemented yet).
    fn multiple_partial_pricing(&mut self) -> Option<ID> {
        todo!()
    }

    /// The reduced cost `c(e) - pi(snk(e)) + pi(src(e))` of edge `eid`,
    /// multiplied by the edge state.
    ///
    /// Basic edges (state 0) yield zero. A negative value means that moving
    /// the non-basic edge away from its current bound improves the objective.
    fn reduced_cost(&self, eid: usize) -> F {
        debug_assert!(eid < self.state.len());
        // SAFETY: `state`, `costs`, `sources` and `sinks` all have
        // `num_edges + num_nodes` entries, and the pricing routines only pass
        // `eid < num_edges`. The end points stored in `sources`/`sinks` are
        // node ids `< num_nodes` (graph edges, set in `init`) or `<= num_nodes`
        // (artificial edges, set in `prepare_initial_basis`), and `potentials`
        // has `num_nodes + 1` entries.
        unsafe {
            F::from(*self.state.get_unchecked(eid)).unwrap()
                * (*self.costs.get_unchecked(eid)
                    - *self.potentials.get_unchecked(*self.sinks.get_unchecked(eid) as usize)
                    + *self.potentials.get_unchecked(*self.sources.get_unchecked(eid) as usize))
        }
    }

    /// The oriented id of edge `eid` in the direction in which its flow
    /// changes when it enters the basis: forward (`2 * eid`) if it is at its
    /// lower bound, backward (`2 * eid + 1`) otherwise.
    fn oriented_edge(&self, eid: usize) -> ID {
        let eid = if self.state[eid] == 1 { eid * 2 } else { eid * 2 | 1 };
        eid as ID
    }
}

/// Apply the orientation of the oriented edge id `eid` to the value `d`.
///
/// Even oriented ids denote an edge in forward direction and yield `d`;
/// odd ids denote the backward direction and yield `-d`.
fn oriented_flow<F>(eid: usize, d: F) -> F
where
    F: NumAssign + NumCast,
{
    let backward = eid & 1 == 1;
    let sign = if backward { F::zero() - F::one() } else { F::one() };
    sign * d
}
