use sqlparser::{
    ast,
    dialect::{self, Dialect as SqlParserDialect},
};
use std::collections::HashSet;
use thiserror::Error;

/// The kind of SQL statement a [`Statement`] was extracted from.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SqlType {
    /// A `CREATE VIEW ... AS <query>` statement.
    CreateView,
    /// A `CREATE TABLE ... AS <query>` statement (only the form carrying a query is reported).
    CreateTable,
    /// An `INSERT INTO ... <query>` statement.
    InsertInto,
    /// A `DELETE FROM ... WHERE ...` statement (only the form carrying a `WHERE` clause is reported).
    DeleteFrom,
    /// An `UPDATE ... SET ...` statement.
    Update,
}

/// Lineage summary of a single SQL statement: which relation it writes to and
/// which relations it reads from.
#[derive(Debug, PartialEq, Eq)]
pub struct Statement {
    /// Name of the relation being written, lower-cased and dot-joined (e.g. `schema.table`).
    pub target: String,
    /// Which kind of statement produced this entry.
    pub kind: SqlType,
    /// Lower-cased, dot-joined names of the relations the statement reads from.
    /// Never empty: statements without any source are not reported.
    pub sources: HashSet<String>,
}

/// SQL dialect used when parsing input; each variant maps onto one of the
/// dialects provided by the `sqlparser` crate.
#[derive(Debug, Clone, Copy)]
pub enum Dialect {
    /// Parsed with `sqlparser`'s `AnsiDialect`.
    Ansi,
    /// Parsed with `sqlparser`'s `ClickHouseDialect`.
    Clickhouse,
    /// Parsed with `sqlparser`'s `GenericDialect`.
    Generic,
    /// Parsed with `sqlparser`'s `HiveDialect`.
    Hive,
    /// Parsed with `sqlparser`'s `MsSqlDialect`.
    Mssql,
    /// Parsed with `sqlparser`'s `MySqlDialect`.
    Mysql,
    /// Parsed with `sqlparser`'s `PostgreSqlDialect`.
    Postgres,
    /// Parsed with `sqlparser`'s `SQLiteDialect`.
    Sqlite,
    /// Parsed with `sqlparser`'s `SnowflakeDialect`.
    Snowflake,
}

impl Dialect {
    fn sqlparser_dialect(&self) -> Box<dyn SqlParserDialect> {
        match self {
            Self::Ansi => Box::new(dialect::AnsiDialect {}),
            Self::Clickhouse => Box::new(dialect::ClickHouseDialect {}),
            Self::Generic => Box::new(dialect::GenericDialect {}),
            Self::Hive => Box::new(dialect::HiveDialect {}),
            Self::Mssql => Box::new(dialect::MsSqlDialect {}),
            Self::Mysql => Box::new(dialect::MySqlDialect {}),
            Self::Postgres => Box::new(dialect::PostgreSqlDialect {}),
            Self::Sqlite => Box::new(dialect::SQLiteDialect {}),
            Self::Snowflake => Box::new(dialect::SnowflakeDialect {}),
        }
    }
}

/// Errors returned by [`parse_statements`].
#[derive(Debug, Error)]
pub enum MetaQueryError {
    /// The SQL text could not be parsed; carries the parser's error rendered
    /// as a string, which is also what `Display` prints.
    #[error("{0}")]
    SqlParserError(String),
}

/// Parses `sql` with the given `dialect` and returns the lineage of every
/// statement that writes to a relation and reads from at least one other.
///
/// Statements for which no lineage can be derived are skipped silently.
///
/// # Errors
///
/// Returns [`MetaQueryError::SqlParserError`] with the parser's message when
/// `sql` cannot be parsed.
pub fn parse_statements(sql: &str, dialect: Dialect) -> Result<Vec<Statement>, MetaQueryError> {
    let parser_dialect = dialect.sqlparser_dialect();
    let ast_statements = sqlparser::parser::Parser::parse_sql(&*parser_dialect, sql)
        .map_err(|err| MetaQueryError::SqlParserError(err.to_string()))?;
    Ok(ast_statements.iter().filter_map(from_statement).collect())
}

fn from_statement(statement: &ast::Statement) -> Option<Statement> {
    match statement {
        ast::Statement::CreateTable {
            name,
            query: Some(query),
            ..
        } => {
            let sources = from_query(query);
            if sources.is_empty() {
                None
            } else {
                let target = from_object_name(name);
                Some(Statement {
                    target,
                    sources: HashSet::from_iter(sources),
                    kind: SqlType::CreateTable,
                })
            }
        }
        ast::Statement::CreateView { name, query, .. } => {
            let sources = from_query(query);
            if sources.is_empty() {
                None
            } else {
                let target = from_object_name(name);
                Some(Statement {
                    target,
                    sources: HashSet::from_iter(sources),
                    kind: SqlType::CreateView,
                })
            }
        }
        ast::Statement::Insert {
            table_name, source, ..
        } => {
            let sources = from_query(source);
            if sources.is_empty() {
                None
            } else {
                let target = from_object_name(table_name);
                Some(Statement {
                    target,
                    sources: HashSet::from_iter(sources),
                    kind: SqlType::InsertInto,
                })
            }
        }
        ast::Statement::Delete {
            table_name,
            selection: Some(selection),
        } => {
            let sources = from_expr(selection);
            if sources.is_empty() {
                None
            } else {
                let target = from_object_name(table_name);
                Some(Statement {
                    target,
                    sources: HashSet::from_iter(sources),
                    kind: SqlType::DeleteFrom,
                })
            }
        }
        ast::Statement::Update {
            table,
            assignments,
            selection,
            ..
        } => {
            let target = from_table_with_joins(table)[0].clone();
            let mut sources: Vec<_> = assignments
                .iter()
                .flat_map(|ast::Assignment { value, .. }| from_expr(value))
                .collect();
            if let Some(selection) = selection {
                sources.extend(from_expr(selection));
            }
            if sources.is_empty() {
                None
            } else {
                Some(Statement {
                    target,
                    sources: HashSet::from_iter(sources),
                    kind: SqlType::Update,
                })
            }
        }
        _ => None,
    }
}

/// Renders a possibly qualified object name as its lower-cased parts joined with `.`.
fn from_object_name(name: &ast::ObjectName) -> String {
    let ast::ObjectName(idents) = name;
    let parts: Vec<String> = idents.iter().map(from_ident).collect();
    parts.join(".")
}

/// Returns the identifier's text with ASCII letters lower-cased.
fn from_ident(ident: &ast::Ident) -> String {
    ident.value.as_str().to_ascii_lowercase()
}

fn from_query(query: &ast::Query) -> Vec<String> {
    let mut body_sources = from_set_expr(&query.body);
    if let Some(with) = &query.with {
        let with_info = WithInfo::new(with);
        body_sources.retain(|source| !with_info.aliases.contains(source));
        body_sources.extend(with_info.sources);
    }
    body_sources
}

/// Collects relation names from a query body: a plain `SELECT`, a nested
/// query, or a set operation (both operands, left first). Any other kind of
/// body contributes nothing.
fn from_set_expr(set_expr: &ast::SetExpr) -> Vec<String> {
    match set_expr {
        ast::SetExpr::Select(select) => from_select(select),
        ast::SetExpr::Query(query) => from_query(query),
        ast::SetExpr::SetOperation { left, right, .. } => {
            let mut names = from_set_expr(left);
            names.extend(from_set_expr(right));
            names
        }
        _ => Vec::new(),
    }
}

fn from_select(select: &ast::Select) -> Vec<String> {
    let mut from_sources: Vec<_> = select.from.iter().flat_map(from_table_with_joins).collect();
    let from_columns: Vec<_> = select
        .projection
        .iter()
        .flat_map(from_select_item)
        .collect();
    from_sources.extend(from_columns);
    from_sources
}

fn from_table_with_joins(table_with_joins: &ast::TableWithJoins) -> Vec<String> {
    let mut from_relations = from_table_factor(&table_with_joins.relation);
    let from_joins: Vec<_> = table_with_joins
        .joins
        .iter()
        .flat_map(|ast::Join { relation, .. }| from_table_factor(relation))
        .collect();
    from_relations.extend(from_joins);
    from_relations
}

/// Collects relation names referenced by one projection item. Only items that
/// carry an expression (with or without an alias) are inspected; other items
/// contribute nothing.
fn from_select_item(select_item: &ast::SelectItem) -> Vec<String> {
    match select_item {
        ast::SelectItem::UnnamedExpr(expr) | ast::SelectItem::ExprWithAlias { expr, .. } => {
            from_expr(expr)
        }
        _ => Vec::new(),
    }
}

/// Collects relation names from a single table factor: a named table yields
/// its own name, a derived table yields the sources of its subquery, and a
/// nested join yields the names of all relations it contains. Any other
/// factor contributes nothing.
fn from_table_factor(table_factor: &ast::TableFactor) -> Vec<String> {
    match table_factor {
        ast::TableFactor::Table {
            name: table_name, ..
        } => Vec::from([from_object_name(table_name)]),
        ast::TableFactor::Derived {
            subquery: derived, ..
        } => from_query(derived),
        ast::TableFactor::NestedJoin(nested) => from_table_with_joins(nested),
        _ => Vec::new(),
    }
}

/// Collects the relation names referenced by subqueries nested anywhere
/// inside `expr`.
///
/// The expression tree is walked recursively; the only nodes that actually
/// produce names are those embedding a query (`IN (subquery)`, `EXISTS`, and
/// scalar subqueries). Expression kinds not listed in the `match` contribute
/// nothing, so subqueries nested inside them are not discovered.
fn from_expr(expr: &ast::Expr) -> Vec<String> {
    match expr {
        ast::Expr::IsNull(expr) => from_expr(expr),
        ast::Expr::IsNotNull(expr) => from_expr(expr),
        ast::Expr::IsDistinctFrom(left, right) => [left, right]
            .iter()
            .flat_map(|&expr| from_expr(expr))
            .collect(),
        ast::Expr::IsNotDistinctFrom(left, right) => [left, right]
            .iter()
            .flat_map(|&expr| from_expr(expr))
            .collect(),
        ast::Expr::InList { expr, list, .. } => {
            let mut expr_sources = from_expr(expr);
            expr_sources.extend(list.iter().flat_map(from_expr));
            expr_sources
        }
        ast::Expr::InSubquery { expr, subquery, .. } => {
            let mut expr_sources = from_expr(expr);
            expr_sources.extend(from_query(subquery));
            expr_sources
        }
        ast::Expr::InUnnest {
            expr, array_expr, ..
        } => [expr, array_expr]
            .iter()
            .flat_map(|&expr| from_expr(expr))
            .collect(),
        ast::Expr::Between {
            expr, low, high, ..
        } => [expr, low, high]
            .iter()
            .flat_map(|&expr| from_expr(expr))
            .collect(),
        ast::Expr::BinaryOp { left, right, .. } => [left, right]
            .iter()
            .flat_map(|&expr| from_expr(expr))
            .collect(),
        ast::Expr::UnaryOp { expr, .. } => from_expr(expr),
        ast::Expr::Cast { expr, .. } => from_expr(expr),
        ast::Expr::TryCast { expr, .. } => from_expr(expr),
        ast::Expr::Extract { expr, .. } => from_expr(expr),
        ast::Expr::Substring {
            expr,
            substring_from,
            substring_for,
        } => {
            let mut exprs = vec![expr];
            if let Some(substring_from) = substring_from {
                exprs.push(substring_from);
            }
            if let Some(substring_for) = substring_for {
                exprs.push(substring_for);
            }
            exprs.iter().flat_map(|&expr| from_expr(expr)).collect()
        }
        ast::Expr::Trim { expr, trim_where } => {
            let mut exprs = vec![expr];
            if let Some((_, trim_where_expr)) = trim_where {
                exprs.push(trim_where_expr);
            }
            exprs.iter().flat_map(|&expr| from_expr(expr)).collect()
        }
        ast::Expr::Collate { expr, .. } => from_expr(expr),
        ast::Expr::Nested(expr) => from_expr(expr),
        // Keys first, then the accessed column; borrowed rather than cloned.
        ast::Expr::MapAccess { column, keys } => keys
            .iter()
            .chain(std::iter::once(&**column))
            .flat_map(from_expr)
            .collect(),
        // Visit order: conditions, results, operand, else-result. All parts
        // are borrowed so no AST subtree is copied just to be inspected.
        ast::Expr::Case {
            operand,
            conditions,
            results,
            else_result,
        } => conditions
            .iter()
            .chain(results.iter())
            .chain(operand.as_deref())
            .chain(else_result.as_deref())
            .flat_map(from_expr)
            .collect(),
        ast::Expr::Exists(query) => from_query(query),
        ast::Expr::Subquery(query) => from_query(query),
        ast::Expr::ListAgg(list_agg) => from_list_agg(list_agg),
        ast::Expr::GroupingSets(exprs) => from_nested_exprs(exprs),
        ast::Expr::Cube(exprs) => from_nested_exprs(exprs),
        ast::Expr::Rollup(exprs) => from_nested_exprs(exprs),
        ast::Expr::Tuple(exprs) => exprs.iter().flat_map(from_expr).collect(),
        // Index expressions first, then the indexed object; borrowed rather than cloned.
        ast::Expr::ArrayIndex { obj, indexs, .. } => indexs
            .iter()
            .chain(std::iter::once(&**obj))
            .flat_map(from_expr)
            .collect(),
        ast::Expr::Array(ast::Array { elem, .. }) => elem.iter().flat_map(from_expr).collect(),
        _ => vec![],
    }
}

/// Collects relation names from every expression of a two-level expression
/// list (as used by `GROUPING SETS`, `CUBE` and `ROLLUP`), in order.
fn from_nested_exprs(exprs: &[Vec<ast::Expr>]) -> Vec<String> {
    exprs.iter().flatten().flat_map(from_expr).collect()
}

/// Collects relation names referenced by the sub-expressions of a `LISTAGG`
/// call, visited in this order: the aggregated expression, the optional
/// separator, the optional truncation filler of the overflow clause, and the
/// `WITHIN GROUP` ordering expressions.
///
/// All sub-expressions are borrowed; nothing in the AST is cloned.
fn from_list_agg(list_agg: &ast::ListAgg) -> Vec<String> {
    // Only the `TRUNCATE` overflow behaviour with an explicit filler carries
    // an expression.
    let filler = match &list_agg.on_overflow {
        Some(ast::ListAggOnOverflow::Truncate {
            filler: Some(filler),
            ..
        }) => Some(&**filler),
        _ => None,
    };
    std::iter::once(&*list_agg.expr)
        .chain(list_agg.separator.as_deref())
        .chain(filler)
        .chain(
            list_agg
                .within_group
                .iter()
                .map(|ast::OrderByExpr { expr, .. }| expr),
        )
        .flat_map(from_expr)
        .collect()
}

/// Information extracted from a `WITH` clause, used to resolve CTE references
/// to the relations they ultimately read from.
struct WithInfo {
    /// Lower-cased alias of every CTE declared in the clause.
    aliases: Vec<String>,
    /// Relation names read by the CTE queries, excluding names that are
    /// themselves CTE aliases of the same clause.
    sources: Vec<String>,
}

impl WithInfo {
    /// Builds the alias list and the combined source list of a `WITH` clause.
    ///
    /// Sources that name another CTE of the same clause are left out, so a
    /// chain of CTEs resolves to the underlying relations only.
    fn new(with: &ast::With) -> Self {
        let mut aliases = Vec::new();
        for cte in &with.cte_tables {
            aliases.push(from_ident(&cte.alias.name));
        }

        let mut sources = Vec::new();
        for cte in &with.cte_tables {
            for name in from_query(&cte.query) {
                if !aliases.contains(&name) {
                    sources.push(name);
                }
            }
        }

        Self { aliases, sources }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn basic_statements() {
        let sql = "
            create view a.foo as
            select * from a.bar;
            create table b.foo as
            select * from b.bar;
            insert into c.foo
            select * from c.bar;
            delete from d.foo
            where exists (
                select 1 from d.bar
                where bar.foo_id = foo.foo_id
            );
            update e.foo
            set foo_id = 1
            where exists (
                select 1 from e.bar
                where bar.foo_id = foo.foo_id
            );
        ";
        let statements = parse_statements(sql, Dialect::Postgres).unwrap();
        assert_eq!(
            statements,
            vec![
                Statement {
                    target: "a.foo".into(),
                    kind: SqlType::CreateView,
                    sources: HashSet::from_iter(vec!["a.bar".into()]),
                },
                Statement {
                    target: "b.foo".into(),
                    kind: SqlType::CreateTable,
                    sources: HashSet::from_iter(vec!["b.bar".into()]),
                },
                Statement {
                    target: "c.foo".into(),
                    kind: SqlType::InsertInto,
                    sources: HashSet::from_iter(vec!["c.bar".into()]),
                },
                Statement {
                    target: "d.foo".into(),
                    kind: SqlType::DeleteFrom,
                    sources: HashSet::from_iter(vec!["d.bar".into()]),
                },
                Statement {
                    target: "e.foo".into(),
                    kind: SqlType::Update,
                    sources: HashSet::from_iter(vec!["e.bar".into()]),
                },
            ]
        );
    }

    #[test]
    fn joins() {
        let sql = "
            create view foo as
            select *
            from bar
            inner join baz using (foo_id);
        ";
        let statements = parse_statements(sql, Dialect::Postgres).unwrap();
        assert_eq!(
            statements,
            vec![Statement {
                target: "foo".into(),
                kind: SqlType::CreateView,
                sources: HashSet::from_iter(vec!["bar".into(), "baz".into()]),
            }],
        );
    }

    #[test]
    fn ctes() {
        let sql = "
            insert into foo
            with cte1 as (
                select * from bar
            ),
            cte2 as (
                select * from baz
            )
            select * from cte1
            union all
            select * from cte2;
        ";
        let statements = parse_statements(sql, Dialect::Postgres).unwrap();
        assert_eq!(
            statements,
            vec![Statement {
                target: "foo".into(),
                kind: SqlType::InsertInto,
                sources: HashSet::from_iter(vec!["bar".into(), "baz".into()]),
            }],
        );
    }

    #[test]
    fn consecutive_ctes() {
        let sql = "
            insert into foo
            with cte1 as (
                select * from bar
            ),
            cte2 as (
                select * from cte1
            )
            select * from cte2;
        ";
        let statements = parse_statements(sql, Dialect::Postgres).unwrap();
        assert_eq!(
            statements,
            vec![Statement {
                target: "foo".into(),
                kind: SqlType::InsertInto,
                sources: HashSet::from_iter(vec!["bar".into()]),
            }],
        );
    }

    #[test]
    fn subquery() {
        let sql = "
            create table foo as
            select * from (
                select * from bar
            ) base;
        ";
        let statements = parse_statements(sql, Dialect::Postgres).unwrap();
        assert_eq!(
            statements,
            vec![Statement {
                target: "foo".into(),
                kind: SqlType::CreateTable,
                sources: HashSet::from_iter(vec!["bar".into()]),
            }],
        );
    }

    #[test]
    fn column_expression() {
        let sql = "
            create view a.foo as
            select *,
                   (select baz_col from c.baz
                    where baz.foo_id = bar.foo_id) as baz_col
            from b.bar;
        ";
        let statements = parse_statements(sql, Dialect::Postgres).unwrap();
        assert_eq!(
            statements,
            vec![Statement {
                target: "a.foo".into(),
                kind: SqlType::CreateView,
                sources: HashSet::from_iter(vec!["b.bar".into(), "c.baz".into()]),
            }],
        );
    }

    #[test]
    fn ignored_statements() {
        let sql = "
            drop table foo;
            create table foo as
            select * from bar;
            delete from foo
            where foo_id = 1;
        ";
        let statements = parse_statements(sql, Dialect::Postgres).unwrap();
        assert_eq!(
            statements,
            vec![Statement {
                target: "foo".into(),
                kind: SqlType::CreateTable,
                sources: HashSet::from_iter(vec!["bar".into()]),
            }],
        );
    }

    #[test]
    fn kitchen_sink() {
        let sql = "
            drop table foo;

            create table foo as
            with cte1 as (
                select *
                from (
                    select * from orig.foo
                ) as cte1_base
            ),
            cte2 as (select 1)
            select * from cte1
            cross join cte2;

            drop view bar cascade;

            create view bar as
            select * from foo
            union all
            select * from orig.bar;

            delete from baz where delete_flag = 1;

            insert into baz
            select bar_alias.*,
                  (select qux_col from qux where qux.bar_id = bar_alias.bar_id) as qux_col
            from bar as bar_alias;

            delete from foo
            where foo_id = (
                select bar_alias.foo_id
                from bar as bar_alias
                where bar_alias.alt_foo_id = foo.alt_foo_id
            );
        ";
        let statements = parse_statements(sql, Dialect::Postgres).unwrap();
        assert_eq!(
            statements,
            vec![
                Statement {
                    target: "foo".into(),
                    kind: SqlType::CreateTable,
                    sources: HashSet::from_iter(vec!["orig.foo".into()]),
                },
                Statement {
                    target: "bar".into(),
                    kind: SqlType::CreateView,
                    sources: HashSet::from_iter(vec!["foo".into(), "orig.bar".into()]),
                },
                Statement {
                    target: "baz".into(),
                    kind: SqlType::InsertInto,
                    sources: HashSet::from_iter(vec!["bar".into(), "qux".into()]),
                },
                Statement {
                    target: "foo".into(),
                    kind: SqlType::DeleteFrom,
                    sources: HashSet::from_iter(vec!["bar".into()]),
                },
            ],
        );
    }
}
