use sqlparser::{
    ast::{ObjectName, Query, SetExpr, Statement, TableFactor, TableWithJoins, With},
    dialect::{self, Dialect},
    tokenizer::{Token, Tokenizer},
};

/// SQL dialects supported by this module, each mapped to the corresponding `sqlparser` dialect.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SqlDialect {
    /// Parsed with `sqlparser`'s `AnsiDialect`.
    Ansi,
    /// Parsed with `sqlparser`'s `GenericDialect`.
    Generic,
    /// Parsed with `sqlparser`'s `HiveDialect`.
    Hive,
    /// Parsed with `sqlparser`'s `MsSqlDialect`.
    MsSql,
    /// Parsed with `sqlparser`'s `MySqlDialect`.
    MySql,
    /// Parsed with `sqlparser`'s `PostgreSqlDialect`.
    PostgreSql,
    /// Parsed with `sqlparser`'s `SQLiteDialect`.
    SQLite,
    /// Parsed with `sqlparser`'s `SnowflakeDialect`.
    Snowflake,
}

impl SqlDialect {
    fn sqlparser_dialect(&self) -> Box<dyn Dialect> {
        match self {
            Self::Ansi => Box::new(dialect::AnsiDialect {}),
            Self::Generic => Box::new(dialect::GenericDialect {}),
            Self::Hive => Box::new(dialect::HiveDialect {}),
            Self::MsSql => Box::new(dialect::MsSqlDialect {}),
            Self::MySql => Box::new(dialect::MySqlDialect {}),
            Self::PostgreSql => Box::new(dialect::PostgreSqlDialect {}),
            Self::SQLite => Box::new(dialect::SQLiteDialect {}),
            Self::Snowflake => Box::new(dialect::SnowflakeDialect {}),
        }
    }

    fn tokenize(&self, query: &str) -> Vec<Token> {
        let dialect = self.sqlparser_dialect();
        Tokenizer::new(&*dialect, query).tokenize().unwrap()
    }
}

/// A parsed SQL statement from which the referenced table names can be extracted.
#[derive(Debug, Clone)]
pub struct Parser {
    /// The statement parsed from the query passed to `Parser::new`.
    statement: Statement,
}

impl Parser {
    pub fn new(sql_dialect: SqlDialect, query: &str) -> Self {
        let tokens = sql_dialect.tokenize(query);
        let statement = sqlparser::parser::Parser::new(tokens, &*sql_dialect.sqlparser_dialect())
            .parse_statement()
            .unwrap();
        Self { statement }
    }

    pub fn tables(&self) -> Vec<&str> {
        match &self.statement {
            Statement::Query(query) => from_query(query),
            _ => todo!(),
        }
    }
}

/// Returns the final identifier of a possibly qualified name, e.g. `table` for `db.schema.table`.
///
/// Only the identifier's text (`Ident::value`) is returned; qualifying parts are dropped.
///
/// # Panics
///
/// Panics if `obj` has no parts, which the parser presumably never produces.
fn from_object_name(obj: &ObjectName) -> &str {
    // `slice::last` reads the final part directly; no iterator is needed.
    &obj
        .0
        .last()
        .expect("a parsed object name always has at least one identifier")
        .value
}

/// Collects table names reachable from a single `FROM`/`JOIN` item.
///
/// A plain table contributes its own (unqualified) name. Derived tables (subqueries)
/// and parenthesised joins are searched recursively. Any other kind of table factor
/// is not supported yet and panics via `todo!()`.
fn from_table_factor(table_factor: &TableFactor) -> Vec<&str> {
    match table_factor {
        TableFactor::NestedJoin(nested) => from_table_with_joins(nested),
        TableFactor::Derived { subquery, .. } => from_query(subquery),
        TableFactor::Table { name, .. } => {
            let table_name = from_object_name(name);
            vec![table_name]
        }
        _ => todo!(),
    }
}

fn from_table_with_joins(table_with_joins: &TableWithJoins) -> Vec<&str> {
    let mut froms = from_table_factor(&table_with_joins.relation);
    let join_froms: Vec<_> = table_with_joins
        .joins
        .iter()
        .flat_map(|join| from_table_factor(&join.relation))
        .collect();
    froms.extend(join_froms);
    froms
}

fn from_set_expr(set_expr: &SetExpr) -> Vec<&str> {
    match set_expr {
        SetExpr::Select(select) => select.from.iter().flat_map(from_table_with_joins).collect(),
        SetExpr::Query(query) => from_query(query),
        SetExpr::SetOperation { left, right, .. } => {
            let mut froms = from_set_expr(left);
            let right_froms = from_set_expr(right);
            froms.extend(right_froms);
            froms
        }
        _ => vec![],
    }
}

/// Collects the sorted, de-duplicated names of the real tables referenced by `query`.
///
/// Names declared in the query's own `WITH` clause are CTE aliases, not tables. References
/// to them are therefore dropped, and the tables read by each CTE body are reported instead.
/// This holds for references from the main body and from the CTE bodies themselves.
fn from_query(query: &Query) -> Vec<&str> {
    let mut froms = from_set_expr(&query.body);
    if let Some(with) = &query.with {
        let aliases = alias_with(with);
        // The main body can see every CTE declared in this clause.
        froms.retain(|from| !aliases.contains(from));
        for (position, cte) in with.cte_tables.iter().enumerate() {
            // Without RECURSIVE, a CTE can only see siblings declared before it, so a
            // self-reference still names a real table. With RECURSIVE, a CTE may refer to
            // itself and to any sibling (PostgreSQL semantics).
            let visible = if with.recursive {
                &aliases[..]
            } else {
                &aliases[..position]
            };
            froms.extend(
                from_query(&cte.query)
                    .into_iter()
                    .filter(|from| !visible.contains(from)),
            );
        }
    }
    froms.sort_unstable();
    froms.dedup();
    froms
}

/// Lists the alias names declared by a `WITH` clause, in declaration order.
fn alias_with(with: &With) -> Vec<&str> {
    let mut names = Vec::with_capacity(with.cte_tables.len());
    for cte in &with.cte_tables {
        names.push(cte.alias.name.value.as_str());
    }
    names
}

#[cfg(test)]
mod tests {
    use super::*;

    // A single plain table.
    #[test]
    fn simple_select() {
        let query = "select * from cool_table";
        assert_eq!(
            Parser::new(SqlDialect::PostgreSql, query).tables(),
            vec!["cool_table"]
        );
    }

    // Both sides of a set operation contribute tables.
    #[test]
    fn simple_union() {
        let query = "select * from cool_table union all select * from other_table";
        let parser = Parser::new(SqlDialect::PostgreSql, query);
        let tables = parser.tables();
        assert_eq!(tables, vec!["cool_table", "other_table"]);
    }

    // Joined relations are reported by table name, not by alias.
    #[test]
    fn simple_join() {
        let query = "
            select
                a.col_from_a,
                b.col_from_b
            from table_a as a
            inner join table_b as b using (id_col)
        ";
        let parser = Parser::new(SqlDialect::PostgreSql, query);
        let tables = parser.tables();
        assert_eq!(tables, vec!["table_a", "table_b"]);
    }

    // A CTE alias is replaced by the tables its body reads.
    #[test]
    fn cte() {
        let query = "
            with my_cte as (
                select * from cool_table
            )
            select * from my_cte
        ";
        let parser = Parser::new(SqlDialect::PostgreSql, query);
        let tables = parser.tables();
        assert_eq!(tables, vec!["cool_table"]);
    }

    // A CTE declared inside a subquery only shadows names within that subquery.
    #[test]
    fn scoped_cte() {
        let query = "
            select *
            from actual_table
            cross join (
                with actual_table as (
                    select * from unique_table
                )
                select * from actual_table
            ) as sub
        ";
        let parser = Parser::new(SqlDialect::PostgreSql, query);
        let tables = parser.tables();
        assert_eq!(tables, vec!["actual_table", "unique_table"]);
    }

    // Repeated references are reported once.
    #[test]
    fn repeated_table() {
        let query = "
            select * from cool_table
            union all
            select * from cool_table
        ";
        let parser = Parser::new(SqlDialect::PostgreSql, query);
        let tables = parser.tables();
        assert_eq!(tables, vec!["cool_table"]);
    }

    // Nested CTEs, joins, unions and derived tables combined.
    #[test]
    fn kitchen_sink() {
        let query = "
            with my_cte as (
                select * from my_table
            ),
            my_other_cte as (
                select * from my_other_table
            )
            select * from my_main_table
            inner join my_cte using (my_id)
            union all
            select * from (
                with my_main_table as (
                    select * from my_table
                )
                select * from my_other_cte
                where exists (
                    select * from my_main_table
                )
                union all
                select * from my_side_table
            ) as sub
        ";
        let parser = Parser::new(SqlDialect::PostgreSql, query);
        let tables = parser.tables();
        assert_eq!(
            tables,
            vec![
                "my_main_table",
                "my_other_table",
                "my_side_table",
                "my_table"
            ]
        );
    }
}
