use sqlparser::{
    ast::{
        Expr, Function, FunctionArg, JoinConstraint, JoinOperator, ListAgg, ListAggOnOverflow,
        ObjectName, Query, Select, SelectItem, SetExpr, TableFactor, TableWithJoins, WindowSpec,
        With,
    },
    dialect::{self, Dialect},
    tokenizer::{Token, Tokenizer},
};

/// SQL dialects that can be selected when tokenizing and parsing a query.
///
/// The type is a plain tag without payload, so it is cheap to copy, compare and print.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SqlDialect {
    /// ANSI SQL.
    Ansi,
    /// Generic dialect.
    Generic,
    /// Hive dialect.
    Hive,
    /// MS SQL dialect.
    MsSql,
    /// MySQL dialect.
    MySql,
    /// PostgreSQL dialect.
    PostgreSql,
    /// SQLite dialect.
    SQLite,
    /// Snowflake dialect.
    Snowflake,
}

impl SqlDialect {
    fn sqlparser_dialect(&self) -> Box<dyn Dialect> {
        match self {
            Self::Ansi => Box::new(dialect::AnsiDialect {}),
            Self::Generic => Box::new(dialect::GenericDialect {}),
            Self::Hive => Box::new(dialect::HiveDialect {}),
            Self::MsSql => Box::new(dialect::MsSqlDialect {}),
            Self::MySql => Box::new(dialect::MySqlDialect {}),
            Self::PostgreSql => Box::new(dialect::PostgreSqlDialect {}),
            Self::SQLite => Box::new(dialect::SQLiteDialect {}),
            Self::Snowflake => Box::new(dialect::SnowflakeDialect {}),
        }
    }

    fn tokenize(&self, query: &str) -> Result<Vec<Token>, &'static str> {
        let dialect = self.sqlparser_dialect();
        Tokenizer::new(&*dialect, query)
            .tokenize()
            .or(Err("Failed to tokenize query"))
    }
}

/// A successfully parsed SQL query from which the referenced table names can be extracted.
#[derive(Debug, Clone)]
pub struct Parser {
    /// Syntax tree of the parsed query.
    query: Query,
}

impl Parser {
    /// Tokenizes and parses `query` with the given dialect.
    ///
    /// # Errors
    ///
    /// Returns `"Failed to tokenize query"` if tokenizing fails and
    /// `"Failed to parse query"` if the tokens do not form a query.
    pub fn try_new(sql_dialect: SqlDialect, query: &str) -> Result<Self, &'static str> {
        let tokens = sql_dialect.tokenize(query)?;
        let dialect = sql_dialect.sqlparser_dialect();
        // Bind the result first so the borrowing parser is dropped before `dialect`.
        let parsed = sqlparser::parser::Parser::new(tokens, &*dialect).parse_query();
        parsed
            .map(|query| Self { query })
            .map_err(|_| "Failed to parse query")
    }

    /// Returns the names of the tables referenced by the query, sorted and without
    /// duplicates. Names that refer to CTEs of the query are not reported.
    pub fn tables(&self) -> Vec<&str> {
        let parsed = &self.query;
        from_query(parsed)
    }
}

fn from_object_name(obj: &ObjectName) -> &str {
    &obj.0.iter().last().unwrap().value
}

/// Collects the table names referenced by a single `FROM` factor.
///
/// * A plain table contributes its own name (last identifier only).
/// * A derived table (subquery) contributes the tables of that subquery.
/// * A parenthesised join contributes the tables of all its members.
///
/// Any other kind of factor is not traversed and contributes nothing. This mirrors
/// the catch-all arms of the sibling extractors and keeps valid queries from
/// aborting the process.
fn from_table_factor(table_factor: &TableFactor) -> Vec<&str> {
    match table_factor {
        TableFactor::Table { name, .. } => vec![from_object_name(name)],
        TableFactor::Derived { subquery, .. } => from_query(subquery),
        TableFactor::NestedJoin(table_with_joins) => from_table_with_joins(table_with_joins),
        // Previously `todo!()`: a panic here took down callers on otherwise valid SQL.
        _ => Vec::new(),
    }
}

fn from_table_with_joins(table_with_joins: &TableWithJoins) -> Vec<&str> {
    let mut froms = from_table_factor(&table_with_joins.relation);
    let join_froms: Vec<_> = table_with_joins
        .joins
        .iter()
        .flat_map(|join| from_table_factor(&join.relation))
        .collect();
    froms.extend(join_froms);
    froms
}

/// Collects the table names referenced by a query body.
///
/// Set operations combine the names of their left operand followed by those of their
/// right operand. Body kinds other than selects, nested queries and set operations
/// contribute nothing.
fn from_set_expr(set_expr: &SetExpr) -> Vec<&str> {
    match set_expr {
        SetExpr::Select(select) => from_select(select),
        SetExpr::Query(nested) => from_query(nested),
        SetExpr::SetOperation { left, right, .. } => from_set_expr(left)
            .into_iter()
            .chain(from_set_expr(right))
            .collect(),
        _ => Vec::new(),
    }
}

/// Collects the table names referenced by subqueries nested anywhere inside `expr`.
///
/// The expression tree is walked recursively. Only `EXISTS`, scalar-subquery and
/// `IN (subquery)` nodes hand over to `from_query`; every other handled node merely
/// forwards to its operands. Expression kinds not matched below contribute nothing.
fn from_expr(expr: &Expr) -> Vec<&str> {
    let mut tables: Vec<&str> = Vec::new();
    match expr {
        // Nodes wrapping exactly one operand.
        Expr::IsNull(inner) | Expr::IsNotNull(inner) | Expr::Nested(inner) => {
            tables.extend(from_expr(inner));
        }
        Expr::UnaryOp { expr: inner, .. }
        | Expr::Cast { expr: inner, .. }
        | Expr::TryCast { expr: inner, .. }
        | Expr::Extract { expr: inner, .. }
        | Expr::Collate { expr: inner, .. } => {
            tables.extend(from_expr(inner));
        }
        Expr::MapAccess { column, .. } => {
            tables.extend(from_expr(column));
        }
        // Nodes that are themselves a subquery.
        Expr::Exists(subquery) | Expr::Subquery(subquery) => {
            tables.extend(from_query(subquery));
        }
        Expr::InList {
            expr: needle, list, ..
        } => {
            tables.extend(from_expr(needle));
            tables.extend(list.iter().flat_map(from_expr));
        }
        Expr::InSubquery {
            expr: needle,
            subquery,
            ..
        } => {
            tables.extend(from_expr(needle));
            tables.extend(from_query(subquery));
        }
        Expr::Between {
            expr: subject,
            low,
            high,
            ..
        } => {
            tables.extend(from_expr(subject));
            tables.extend(from_expr(low));
            tables.extend(from_expr(high));
        }
        Expr::BinaryOp { left, right, .. } => {
            tables.extend(from_expr(left));
            tables.extend(from_expr(right));
        }
        Expr::Substring {
            expr: source,
            substring_from,
            substring_for,
        } => {
            tables.extend(from_expr(source));
            // Optional FROM and FOR bounds, visited in that order.
            for bound in substring_from.iter().chain(substring_for.iter()) {
                tables.extend(from_expr(bound));
            }
        }
        Expr::Trim {
            expr: source,
            trim_where,
        } => {
            tables.extend(from_expr(source));
            if let Some((_, trim_expr)) = trim_where {
                tables.extend(from_expr(trim_expr));
            }
        }
        Expr::Function(Function { args, over, .. }) => {
            for arg in args {
                let value = match arg {
                    FunctionArg::Named { arg, .. } => arg,
                    FunctionArg::Unnamed(unnamed) => unnamed,
                };
                tables.extend(from_expr(value));
            }
            if let Some(WindowSpec {
                partition_by,
                order_by,
                ..
            }) = over
            {
                tables.extend(partition_by.iter().flat_map(from_expr));
                tables.extend(order_by.iter().flat_map(|by| from_expr(&by.expr)));
            }
        }
        Expr::Case {
            operand,
            conditions,
            results,
            else_result,
        } => {
            tables.extend(conditions.iter().flat_map(from_expr));
            tables.extend(results.iter().flat_map(from_expr));
            // Optional operand first, then the optional ELSE branch.
            for branch in operand.iter().chain(else_result.iter()) {
                tables.extend(from_expr(branch));
            }
        }
        Expr::ListAgg(ListAgg {
            expr: subject,
            separator,
            on_overflow,
            within_group,
            ..
        }) => {
            tables.extend(from_expr(subject));
            tables.extend(within_group.iter().flat_map(|by| from_expr(&by.expr)));
            if let Some(sep) = separator {
                tables.extend(from_expr(sep));
            }
            if let Some(ListAggOnOverflow::Truncate {
                filler: Some(filler),
                ..
            }) = on_overflow
            {
                tables.extend(from_expr(filler));
            }
        }
        _ => {}
    }
    tables
}

/// Collects the table names referenced by subqueries inside one projection item.
///
/// Only items that carry an expression (with or without alias) are inspected; other
/// items contribute nothing.
fn from_select_item(select_item: &SelectItem) -> Vec<&str> {
    match select_item {
        SelectItem::UnnamedExpr(projected) | SelectItem::ExprWithAlias { expr: projected, .. } => {
            from_expr(projected)
        }
        _ => Vec::new(),
    }
}

/// Collects the table names referenced by a `SELECT`.
///
/// The clauses are visited in this order: projection, `FROM`, lateral views,
/// `GROUP BY`, `CLUSTER BY`, `DISTRIBUTE BY`, `SORT BY`, `TOP` quantity, `WHERE`
/// selection and `HAVING`. The result is neither sorted nor de-duplicated.
fn from_select(select: &Select) -> Vec<&str> {
    // The TOP quantity is doubly optional: the clause itself and its quantity.
    let top_quantity = select.top.as_ref().and_then(|top| top.quantity.as_ref());

    select
        .projection
        .iter()
        .flat_map(from_select_item)
        .chain(select.from.iter().flat_map(from_table_with_joins))
        .chain(
            select
                .lateral_views
                .iter()
                .flat_map(|lat| from_expr(&lat.lateral_view)),
        )
        .chain(select.group_by.iter().flat_map(from_expr))
        .chain(select.cluster_by.iter().flat_map(from_expr))
        .chain(select.distribute_by.iter().flat_map(from_expr))
        .chain(select.sort_by.iter().flat_map(from_expr))
        .chain(
            top_quantity
                .into_iter()
                .flat_map(|quantity| from_expr(quantity)),
        )
        .chain(
            select
                .selection
                .iter()
                .flat_map(|selection| from_expr(selection)),
        )
        .chain(select.having.iter().flat_map(|having| from_expr(having)))
        .collect()
}

/// Collects the names of the tables referenced by `query`, sorted and de-duplicated.
///
/// Names are gathered from the query body and from subqueries in `ORDER BY`, `LIMIT`,
/// `OFFSET` and `FETCH`. Names that refer to a CTE declared in this query's `WITH`
/// clause are dropped and replaced by the tables the CTE definitions read from.
///
/// CTE visibility is applied per definition: a CTE sees the CTEs declared before it,
/// so `WITH a AS (SELECT * FROM t), b AS (SELECT * FROM a) ...` reports `t` only.
/// Under `WITH RECURSIVE` every CTE of the clause (including itself) is treated as
/// visible. In a non-recursive clause a CTE's own name inside its definition still
/// denotes a real table and is kept.
fn from_query(query: &Query) -> Vec<&str> {
    let mut froms = from_set_expr(&query.body);
    froms.extend(
        query
            .order_by
            .iter()
            .flat_map(|order_by| from_expr(&order_by.expr)),
    );
    if let Some(limit) = &query.limit {
        froms.extend(from_expr(limit));
    }
    if let Some(offset) = &query.offset {
        froms.extend(from_expr(&offset.value));
    }
    if let Some(fetch) = &query.fetch {
        if let Some(quantity) = &fetch.quantity {
            froms.extend(from_expr(quantity));
        }
    }
    if let Some(with) = &query.with {
        let aliases = alias_with(with);
        // The main body can see every CTE of its own WITH clause.
        froms.retain(|from| !aliases.contains(from));
        for (idx, cte) in with.cte_tables.iter().enumerate() {
            // `aliases` is index-aligned with `cte_tables`, so `..idx` is exactly the
            // set of CTEs declared before this one.
            let visible = if with.recursive {
                &aliases[..]
            } else {
                &aliases[..idx]
            };
            let cte_froms = from_query(&cte.query)
                .into_iter()
                .filter(|from| !visible.contains(from));
            froms.extend(cte_froms);
        }
    }
    // Sorting first makes `dedup` remove every duplicate, not just adjacent ones.
    froms.sort_unstable();
    froms.dedup();
    froms
}

/// Returns the alias of every CTE declared in `with`, in declaration order.
fn alias_with(with: &With) -> Vec<&str> {
    let mut names = Vec::with_capacity(with.cte_tables.len());
    for cte in &with.cte_tables {
        names.push(cte.alias.name.value.as_str());
    }
    names
}

// Unit tests for table-name extraction through `Parser::try_new` and `Parser::tables`.
// All queries are parsed with the PostgreSQL dialect; expected lists are in sorted order.
#[cfg(test)]
mod tests {
    use super::*;

    // A single table in FROM is reported as-is.
    #[test]
    fn simple_select() {
        let query = "select * from cool_table";
        assert_eq!(
            Parser::try_new(SqlDialect::PostgreSql, query)
                .unwrap()
                .tables(),
            vec!["cool_table"]
        );
    }

    // Both sides of a UNION ALL contribute their tables.
    #[test]
    fn simple_union() {
        let query = "select * from cool_table union all select * from other_table";
        let parser = Parser::try_new(SqlDialect::PostgreSql, query).unwrap();
        let tables = parser.tables();
        assert_eq!(tables, vec!["cool_table", "other_table"]);
    }

    // Joined relations are reported by table name, not by alias.
    // NOTE(review): `b.col_from b` looks like it parses as a column with alias `b`;
    // possibly `b.col_from_b` was intended. It does not affect the asserted tables.
    #[test]
    fn simple_join() {
        let query = "
            select 
                a.col_from_a,
                b.col_from b
            from table_a as a
            inner join table_b as b using (id_col)
        ";
        let parser = Parser::try_new(SqlDialect::PostgreSql, query).unwrap();
        let tables = parser.tables();
        assert_eq!(tables, vec!["table_a", "table_b"]);
    }

    // A CTE name is replaced by the table its definition reads from.
    #[test]
    fn cte() {
        let query = "
            with my_cte as (
                select * from cool_table
            )
            select * from my_cte
        ";
        let parser = Parser::try_new(SqlDialect::PostgreSql, query).unwrap();
        let tables = parser.tables();
        assert_eq!(tables, vec!["cool_table"])
    }

    // A CTE declared inside a subquery must not hide the same-named table that the
    // outer query references.
    #[test]
    fn scoped_cte() {
        let query = "
            select *
            from actual_table
            cross join (
                with actual_table as (
                    select * from unique_table
                )
                select * from actual_table
            ) as sub
        ";
        let parser = Parser::try_new(SqlDialect::PostgreSql, query).unwrap();
        let tables = parser.tables();
        assert_eq!(tables, vec!["actual_table", "unique_table"]);
    }

    // A table referenced more than once is reported only once.
    #[test]
    fn repeated_table() {
        let query = "
            select * from cool_table
            union all
            select * from cool_table
        ";
        let parser = Parser::try_new(SqlDialect::PostgreSql, query).unwrap();
        let tables = parser.tables();
        assert_eq!(tables, vec!["cool_table"])
    }

    // Tables referenced only inside a scalar subquery in the projection or inside an
    // EXISTS in the WHERE clause are reported too.
    #[test]
    fn correlated_subqueries() {
        let query = "
            select
                *,
                (
                    select sum(amount)
                    from sales
                    where sales.customer_id = customers.customer_id
                ) as sales
            from customers
            where exists (
                select * from active_customers
                where active_customers.customer_id = customers.customer_id
            )
        ";
        let parser = Parser::try_new(SqlDialect::PostgreSql, query).unwrap();
        let tables = parser.tables();
        assert_eq!(tables, vec!["active_customers", "customers", "sales"]);
    }

    // Combines top-level CTEs, a correlated scalar subquery, a join, a set operation,
    // a derived table with its own WITH clause and EXISTS filters. No CTE name may
    // appear in the result, while `my_main_table` (a real table at the top level and
    // a CTE only inside the derived table) must.
    #[test]
    fn kitchen_sink() {
        let query = "
            with my_cte as (
                select * from my_table
            ),
            my_other_cte as (
                select * from my_other_table
            )
            select
                *,
                (
                    select max(my_amount)
                    from my_amount_table
                    where my_amount_table.my_amount_id = my_main_table.my_amount_id
                ) as amount
            from my_main_table
            inner join my_cte using (my_id)
            union all
            select * from (
                with my_main_table as (
                    select * from my_table
                )
                select * from my_other_cte
                where exists (
                    select * from my_main_table
                )
                union all
                select * from my_side_table
            ) as sub
            where exists (
                select * from my_lookup_table
                where my_lookup_table.my_lookup_id = sub.my_lookup_id
            )
        ";
        let parser = Parser::try_new(SqlDialect::PostgreSql, query).unwrap();
        let tables = parser.tables();
        assert_eq!(
            tables,
            vec![
                "my_amount_table",
                "my_lookup_table",
                "my_main_table",
                "my_other_table",
                "my_side_table",
                "my_table"
            ]
        );
    }

    // Input that is not a query (misspelled `select`) yields an error instead of a parser.
    #[test]
    fn invalid_query() {
        let query = "selext column_a, column_b from my_table";
        let parser = Parser::try_new(SqlDialect::PostgreSql, query);
        assert!(parser.is_err());
    }
}
