use swc_common::{util::take::Take, Spanned, DUMMY_SP};
use swc_ecma_ast::*;
use swc_ecma_utils::{prepend_stmt, undefined, StmtExt, StmtLike};
use swc_ecma_visit::{noop_visit_type, Visit, VisitWith};

use super::Optimizer;
use crate::{compress::util::is_pure_undefined, debug::dump, mode::Mode, util::ExprOptExt};

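/// Methods related to the option `if_return`. Most of them are noops when
/// `if_return` is false; `merge_nested_if` is gated on `conditionals` / `bools`
/// instead, and `merge_else_if` performs no option check itself.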
impl<M> Optimizer<'_, M>
where
    M: Mode,
{
    /// Folds `if (a) if (b) body;` into `if (a && b) body;`.
    ///
    /// Applies only when neither the outer nor the inner `if` has an `else`
    /// branch, and only when the `conditionals` or `bools` option is enabled.
    pub(super) fn merge_nested_if(&mut self, s: &mut IfStmt) {
        let enabled = self.options.conditionals || self.options.bools;
        if !enabled || s.alt.is_some() {
            return;
        }

        let (inner_test, inner_cons) = match &mut *s.cons {
            Stmt::If(IfStmt {
                test,
                cons,
                alt: None,
                ..
            }) => (test, cons),
            _ => return,
        };

        self.changed = true;
        report_change!("if_return: Merging nested if statements");

        // Read the span before `take()` replaces the outer test with a dummy.
        let span = s.test.span();
        let outer_test = s.test.take();
        let combined = Expr::Bin(BinExpr {
            span,
            op: op!("&&"),
            left: outer_test,
            right: inner_test.take(),
        });
        s.test = Box::new(combined);
        s.cons = inner_cons.take();
    }

    /// Rewrites `if (a) A else if (b) J else C` into
    /// `if (a) A else { if (b) J; C }`.
    ///
    /// `J` must be a `return` or an unlabeled `continue`. `C` must be a block
    /// or an expression statement; an expression statement is wrapped in a
    /// new block first. Any other shape is left untouched.
    pub(super) fn merge_else_if(&mut self, s: &mut IfStmt) {
        let (span_of_alt, test_of_alt, cons_of_alt, alt_of_alt) = match s.alt.as_deref_mut() {
            Some(Stmt::If(IfStmt {
                span,
                test,
                cons,
                alt: Some(alt),
                ..
            })) => (span, test, cons, alt),
            _ => return,
        };

        let cons_is_jump = matches!(
            &**cons_of_alt,
            Stmt::Return(..) | Stmt::Continue(ContinueStmt { label: None, .. })
        );
        if !cons_is_jump {
            return;
        }

        // The innermost `else` must be a block so the inner `if` can be
        // prepended to it.
        if matches!(&**alt_of_alt, Stmt::Expr(..)) {
            let expr_stmt = *alt_of_alt.take();
            *alt_of_alt = Box::new(Stmt::Block(BlockStmt {
                span: DUMMY_SP,
                stmts: vec![expr_stmt],
            }));
        } else if !matches!(&**alt_of_alt, Stmt::Block(..)) {
            return;
        }

        self.changed = true;
        report_change!("if_return: Merging `else if` into `else`");

        let hoisted = Stmt::If(IfStmt {
            span: *span_of_alt,
            test: test_of_alt.take(),
            cons: cons_of_alt.take(),
            alt: None,
        });
        match &mut **alt_of_alt {
            Stmt::Block(block) => prepend_stmt(&mut block.stmts, hoisted),
            _ => unreachable!(),
        }

        s.alt = Some(alt_of_alt.take());
    }

    /// Entry point of `if_return` merging for one statement list.
    ///
    /// Every statement is visited first so that returns inside nested blocks
    /// and `if` branches get merged. The list itself is merged afterwards, but
    /// only when `can_work` is set or the list is a function body
    /// (`is_fn_body`). This is a noop unless the `if_return` option is enabled.
    pub(super) fn merge_if_returns(
        &mut self,
        stmts: &mut Vec<Stmt>,
        can_work: bool,
        is_fn_body: bool,
    ) {
        if self.options.if_return {
            stmts
                .iter_mut()
                .for_each(|stmt| self.merge_nested_if_returns(stmt, can_work));

            let should_merge_list = can_work || is_fn_body;
            if should_merge_list {
                self.merge_if_returns_inner(stmts);
            }
        }
    }

    /// Recursively applies [`Self::merge_if_returns`] to the statement lists
    /// found in blocks and in the branches of `if` statements.
    ///
    /// A block's own statement list is merged only when
    /// `can_merge_as_if_return` judges the block worthwhile. `can_work` is
    /// only forwarded to the `if` branches.
    #[allow(clippy::only_used_in_recursion)]
    fn merge_nested_if_returns(&mut self, s: &mut Stmt, can_work: bool) {
        // Evaluated up front for every statement kind; it may emit trace output.
        let block_is_worth_merging = can_merge_as_if_return(s);

        if let Stmt::Block(block) = &mut *s {
            self.merge_if_returns(&mut block.stmts, block_is_worth_merging, false);
        } else if let Stmt::If(if_stmt) = &mut *s {
            self.merge_nested_if_returns(&mut if_stmt.cons, can_work);

            if let Some(alt) = if_stmt.alt.as_deref_mut() {
                self.merge_nested_if_returns(alt, can_work);
            }
        }
    }

    /// Merge simple return statements in if statements.
    ///
    /// Finds the trailing run of statements accepted by
    /// [`Self::can_merge_stmt_as_if_return`] and folds it into a single
    /// conditional / sequence expression. The result is emitted as one `return`,
    /// or as an expression statement when the folded value ends in a pure
    /// `undefined`. Several heuristics below abort the rewrite when it is
    /// unlikely to pay off. Noop unless `if_return` is enabled.
    ///
    /// # Example
    ///
    /// ## Input
    ///
    /// ```js
    /// function foo() {
    ///     if (a) return foo();
    ///     return bar()
    /// }
    /// ```
    ///
    /// ## Output
    ///
    /// ```js
    /// function foo() {
    ///     return a ? foo() : bar();
    /// }
    /// ```
    fn merge_if_returns_inner(&mut self, stmts: &mut Vec<Stmt>) {
        /// Returns `true` when `s` returns on every path: a `return`, a
        /// single-statement block around such a statement, or an `if`/`else`
        /// whose branches both do so.
        fn always_returns(s: &Stmt) -> bool {
            match s {
                Stmt::Return(..) => true,
                Stmt::Block(block) => {
                    block.stmts.len() == 1 && always_returns(&block.stmts[0])
                }
                Stmt::If(IfStmt {
                    cons,
                    alt: Some(alt),
                    ..
                }) => always_returns(cons) && always_returns(alt),
                _ => false,
            }
        }

        if !self.options.if_return {
            return;
        }

        if stmts.len() <= 1 {
            return;
        }

        // Everything after the last statement that cannot be folded is a merge
        // candidate; `skip` is the index of the first candidate.
        let idx_of_not_mergable =
            stmts
                .iter()
                .enumerate()
                .rposition(|(idx, stmt)| match stmt.as_stmt() {
                    Some(v) => !self.can_merge_stmt_as_if_return(v, stmts.len() - 1 == idx),
                    None => true,
                });
        let skip = idx_of_not_mergable.map(|v| v + 1).unwrap_or(0);
        trace_op!("if_return: Skip = {}", skip);
        let mut last_idx = stmts.len() - 1;

        // Step `last_idx` back over trailing `var` declarations that have no
        // initializer.
        {
            loop {
                let s = stmts.get(last_idx);
                let s = match s {
                    Some(s) => s,
                    _ => break,
                };

                if let Stmt::Decl(Decl::Var(v)) = s {
                    if v.decls.iter().all(|v| v.init.is_none()) {
                        if last_idx == 0 {
                            break;
                        }
                        last_idx -= 1;
                        continue;
                    }
                }

                break;
            }
        }

        if last_idx <= skip {
            log_abort!("if_return: [x] Aborting because of skip");
            return;
        }

        {
            let stmts = &stmts[skip..=last_idx];
            let return_count: usize = stmts.iter().map(count_leaping_returns).sum();

            // There's no return statement so merging requires injecting unnecessary `void
            // 0`
            if return_count == 0 {
                log_abort!("if_return: [x] Aborting because we failed to find return");
                return;
            }

            // If the last statement is a return statement and last - 1 is an if statement
            // is without return, we don't need to fold it as `void 0` is too much for such
            // cases.

            let if_return_count = stmts
                .iter()
                .filter(|s| match s {
                    Stmt::If(IfStmt {
                        cons, alt: None, ..
                    }) => always_terminates_with_return_arg(cons),
                    _ => false,
                })
                .count();

            if stmts.len() >= 2 {
                match (
                    &stmts[stmts.len() - 2].as_stmt(),
                    &stmts[stmts.len() - 1].as_stmt(),
                ) {
                    (_, Some(Stmt::If(IfStmt { alt: None, .. }) | Stmt::Expr(..)))
                        if if_return_count <= 1 =>
                    {
                        log_abort!(
                            "if_return: [x] Aborting because last stmt is a not return stmt"
                        );
                        return;
                    }

                    (
                        Some(Stmt::If(IfStmt {
                            cons, alt: None, ..
                        })),
                        Some(Stmt::Return(..)),
                    ) => match &**cons {
                        Stmt::Return(ReturnStmt { arg: Some(..), .. }) => {}
                        _ => {
                            log_abort!(
                                "if_return: [x] Aborting because stmt before last is an if stmt \
                                 and cons of it is not a return stmt"
                            );
                            return;
                        }
                    },

                    _ => {}
                }
            }
        }

        {
            let stmts = &stmts[..=last_idx];
            let start = stmts
                .iter()
                .enumerate()
                .skip(skip)
                .position(|(idx, stmt)| match stmt.as_stmt() {
                    Some(v) => self.can_merge_stmt_as_if_return(v, stmts.len() - 1 == idx),
                    None => false,
                })
                .unwrap_or(0);

            let ends_with_mergable = stmts
                .last()
                .map(|stmt| match stmt.as_stmt() {
                    Some(Stmt::If(IfStmt { alt: None, .. }))
                        if self.ctx.is_nested_if_return_merging =>
                    {
                        false
                    }
                    Some(s) => self.can_merge_stmt_as_if_return(s, true),
                    _ => false,
                })
                .unwrap();

            if stmts.len() == start + skip + 1 || !ends_with_mergable {
                return;
            }

            let can_merge =
                stmts
                    .iter()
                    .enumerate()
                    .skip(skip)
                    .all(|(idx, stmt)| match stmt.as_stmt() {
                        Some(s) => self.can_merge_stmt_as_if_return(s, stmts.len() - 1 == idx),
                        _ => false,
                    });
            if !can_merge {
                return;
            }
        }

        report_change!("if_return: Merging returns");

        self.changed = true;

        // `cur` accumulates the folded expression. It is always a sequence or a
        // conditional whose rightmost `alt` is where later statements attach.
        let mut cur: Option<Box<Expr>> = None;
        let mut new = Vec::with_capacity(stmts.len());

        let len = stmts.len();

        for (idx, stmt) in stmts.take().into_iter().enumerate() {
            if let Some(not_mergable) = idx_of_not_mergable {
                if idx < not_mergable {
                    new.push(stmt);
                    continue;
                }
            }
            if idx > last_idx {
                new.push(stmt);
                continue;
            }

            let stmt = if !self.can_merge_stmt_as_if_return(&stmt, len - 1 == idx) {
                debug_assert_eq!(cur, None);
                new.push(stmt);
                continue;
            } else {
                stmt
            };
            // A statement that returns on every path ends the reachable part of
            // the list. Anything after it is dead, so it must not be appended to
            // the rightmost `alt` of `cur`. Appending it would evaluate dead code
            // and override the returned value, e.g.
            // `if (a) return 1; else return 3; return 2;` -> `a ? 1 : (3, 2)`.
            let is_nonconditional_return = always_returns(&stmt);
            let new_expr = self.merge_if_returns_to(stmt, vec![]);
            match new_expr {
                Expr::Seq(v) => match &mut cur {
                    Some(cur) => match &mut **cur {
                        Expr::Cond(cur) => {
                            let seq = get_rightmost_alt_of_cond(cur).force_seq();
                            seq.exprs.extend(v.exprs);
                        }
                        Expr::Seq(cur) => {
                            cur.exprs.extend(v.exprs);
                        }
                        _ => {
                            unreachable!(
                                "if_return: cur must be one of None, Expr::Seq or Expr::Cond(with \
                                 alt Expr::Seq)"
                            )
                        }
                    },
                    None => cur = Some(Box::new(Expr::Seq(v))),
                },
                Expr::Cond(v) => match &mut cur {
                    Some(cur) => match &mut **cur {
                        Expr::Cond(cur) => {
                            let alt = get_rightmost_alt_of_cond(cur);

                            let (span, exprs) = {
                                let prev_seq = alt.force_seq();
                                prev_seq.exprs.push(v.test);
                                let exprs = prev_seq.exprs.take();

                                (prev_seq.span, exprs)
                            };

                            *alt = Expr::Cond(CondExpr {
                                span: DUMMY_SP,
                                test: Box::new(Expr::Seq(SeqExpr { span, exprs })),
                                cons: v.cons,
                                alt: v.alt,
                            });
                        }
                        Expr::Seq(prev_seq) => {
                            prev_seq.exprs.push(v.test);
                            let exprs = prev_seq.exprs.take();

                            *cur = Box::new(Expr::Cond(CondExpr {
                                span: DUMMY_SP,
                                test: Box::new(Expr::Seq(SeqExpr {
                                    span: prev_seq.span,
                                    exprs,
                                })),
                                cons: v.cons,
                                alt: v.alt,
                            }));
                        }
                        _ => {
                            unreachable!(
                                "if_return: cur must be one of None, Expr::Seq or Expr::Cond(with \
                                 alt Expr::Seq)"
                            )
                        }
                    },
                    None => cur = Some(Box::new(Expr::Cond(v))),
                },
                _ => {
                    unreachable!(
                        "if_return: merge_if_returns_to should return one of None, Expr::Seq or \
                         Expr::Cond"
                    )
                }
            }

            if is_nonconditional_return {
                break;
            }
        }

        if let Some(mut cur) = cur {
            match &*cur {
                // A sequence ending in a pure `undefined` is emitted as a plain
                // expression statement and the `return` is dropped.
                // NOTE(review): dropping the `return` is only sound when nothing
                // follows this list at runtime (e.g. the end of a function body).
                // Confirm that callers passing `can_work` for nested blocks never
                // reach this branch.
                Expr::Seq(seq)
                    if seq
                        .exprs
                        .last()
                        .map(|v| is_pure_undefined(&self.expr_ctx, v))
                        .unwrap_or(true) =>
                {
                    let expr = self.ignore_return_value(&mut cur);

                    if let Some(cur) = expr {
                        new.push(Stmt::Expr(ExprStmt {
                            span: DUMMY_SP,
                            expr: Box::new(cur),
                        }))
                    } else {
                        trace_op!("if_return: Ignoring return value");
                    }
                }
                _ => {
                    new.push(Stmt::Return(ReturnStmt {
                        span: DUMMY_SP,
                        arg: Some(cur),
                    }));
                }
            }
        }

        *stmts = new;
    }

    /// This method returns [Expr::Seq] or [Expr::Cond].
    ///
    /// `exprs` is a simple optimization.
    fn merge_if_returns_to(&mut self, stmt: Stmt, mut exprs: Vec<Box<Expr>>) -> Expr {
        //
        match stmt {
            Stmt::Block(s) => {
                assert_eq!(s.stmts.len(), 1);
                self.merge_if_returns_to(s.stmts.into_iter().next().unwrap(), exprs)
            }

            Stmt::If(IfStmt {
                span,
                test,
                cons,
                alt,
                ..
            }) => {
                let cons = Box::new(self.merge_if_returns_to(*cons, vec![]));
                let alt = match alt {
                    Some(alt) => Box::new(self.merge_if_returns_to(*alt, vec![])),
                    None => undefined(DUMMY_SP),
                };

                exprs.push(test);

                Expr::Cond(CondExpr {
                    span,
                    test: Box::new(Expr::Seq(SeqExpr {
                        span: DUMMY_SP,
                        exprs,
                    })),
                    cons,
                    alt,
                })
            }
            Stmt::Expr(stmt) => {
                exprs.push(Box::new(Expr::Unary(UnaryExpr {
                    span: DUMMY_SP,
                    op: op!("void"),
                    arg: stmt.expr,
                })));
                Expr::Seq(SeqExpr {
                    span: DUMMY_SP,
                    exprs,
                })
            }
            Stmt::Return(stmt) => {
                let span = stmt.span;
                exprs.push(stmt.arg.unwrap_or_else(|| undefined(span)));
                Expr::Seq(SeqExpr {
                    span: DUMMY_SP,
                    exprs,
                })
            }
            _ => unreachable!(),
        }
    }

    /// Returns `true` when `s` can be converted by `merge_if_returns_to`.
    ///
    /// Accepted statements are:
    /// - expression statements and `return`s;
    /// - a block holding exactly one acceptable statement;
    /// - an `if` whose consequent is a `return` and whose `else` is absent, a
    ///   `return`, or an expression statement.
    ///
    /// `_is_last` is currently ignored.
    fn can_merge_stmt_as_if_return(&self, s: &Stmt, _is_last: bool) -> bool {
        match s {
            Stmt::Expr(..) | Stmt::Return(..) => true,
            Stmt::Block(block) => match block.stmts.as_slice() {
                [only] => self.can_merge_stmt_as_if_return(only, false),
                _ => false,
            },
            Stmt::If(if_stmt) => {
                let cons_returns = matches!(&*if_stmt.cons, Stmt::Return(..));
                let alt_is_simple = match if_stmt.alt.as_deref() {
                    None | Some(Stmt::Return(..)) | Some(Stmt::Expr(..)) => true,
                    Some(_) => false,
                };
                cons_returns && alt_is_simple
            }
            _ => false,
        }
    }
}

/// Follows the `alt` branches of nested conditional expressions
/// (`a ? b : c ? d : e`). Returns the first `alt` that is not itself a
/// conditional.
fn get_rightmost_alt_of_cond(e: &mut CondExpr) -> &mut Expr {
    match &mut *e.alt {
        Expr::Cond(nested) => get_rightmost_alt_of_cond(nested),
        leaf => leaf,
    }
}

fn count_leaping_returns<N>(n: &N) -> usize
where
    N: VisitWith<ReturnFinder>,
{
    let mut v = ReturnFinder::default();
    n.visit_with(&mut v);
    v.count
}

#[derive(Default)]
pub(super) struct ReturnFinder {
    count: usize,
}

impl Visit for ReturnFinder {
    noop_visit_type!();

    fn visit_return_stmt(&mut self, n: &ReturnStmt) {
        n.visit_children_with(self);
        self.count += 1;
    }

    fn visit_function(&mut self, _: &Function) {}

    fn visit_arrow_expr(&mut self, _: &ArrowExpr) {}
}

/// Heuristic check for a statement that ends in `return <value>`.
///
/// Matches `return <value>`, an `if`/`else` whose branches both match, and a
/// block in which any statement matches. An `if` without `else` never matches.
fn always_terminates_with_return_arg(s: &Stmt) -> bool {
    match s {
        Stmt::Return(ret) => ret.arg.is_some(),
        Stmt::If(IfStmt {
            cons,
            alt: Some(alt),
            ..
        }) => always_terminates_with_return_arg(cons) && always_terminates_with_return_arg(alt),
        Stmt::Block(block) => block.stmts.iter().any(always_terminates_with_return_arg),
        _ => false,
    }
}

/// Decides whether merging the returns inside `s` is expected to pay off.
///
/// Each statement gets an optional score:
/// - `return <value>` scores -1.
/// - A bare `return`, `throw`, `break` or `continue` scores 0.
/// - An `if` adds its consequent's score to its `else` score. An `else` that is
///   missing or has no score counts as 0.
/// - A block that terminates sums the scores of its scorable statements.
///
/// Non-terminating blocks, blocks with no scorable statement, and every other
/// statement kind get no score. Merging is worthwhile only for a negative score.
fn can_merge_as_if_return(s: &Stmt) -> bool {
    fn cost(s: &Stmt) -> Option<isize> {
        match s {
            Stmt::Return(ReturnStmt { arg, .. }) => Some(if arg.is_some() { -1 } else { 0 }),
            Stmt::Throw(..) | Stmt::Break(..) | Stmt::Continue(..) => Some(0),
            Stmt::If(IfStmt { cons, alt, .. }) => {
                cost(cons).map(|c| c + alt.as_deref().and_then(cost).unwrap_or(0))
            }
            // A block only gets a score if control never falls out of it.
            Stmt::Block(..) if !s.terminates() => None,
            Stmt::Block(block) => block
                .stmts
                .iter()
                .filter_map(cost)
                .fold(None, |total: Option<isize>, c| Some(total.unwrap_or(0) + c)),
            _ => None,
        }
    }

    let score = cost(s);

    trace_op!("merging cost of `{}` = {:?}", dump(s, false), score);

    matches!(score, Some(v) if v < 0)
}
