use anyhow::{anyhow, Result};
use clap::Parser;
use petgraph::{algo::toposort, graphmap::DiGraphMap};
use sqlx::{postgres::PgConnection, Connection, Executor};
use std::{collections::HashMap, fs, path::PathBuf};

// Command-line arguments, parsed with clap's derive API.
//
// NOTE: plain `//` comments are used here on purpose. clap's derive picks up `///` doc
// comments as help text, so doc comments would change the generated `--help` output.
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
    // Passed unchanged to `PgConnection::connect` in `main`; set with `-c` / `--conn`.
    // Presumably a PostgreSQL connection URL — confirm against the sqlx documentation.
    #[clap(short, long)]
    conn: String,
}

/// Entry point: turns the SQL files of the current directory into views.
///
/// The command line is parsed first, then the project in `.` is loaded and rendered
/// to SQL statements. When there is nothing to run, no database connection is opened
/// and nothing is printed. Otherwise the statements are executed one after another on
/// a single connection and the number of statements is reported on stdout.
///
/// # Errors
///
/// Returns the first error raised while loading the project, connecting to the
/// database, or executing a statement; statements after a failing one are not run.
#[tokio::main]
async fn main() -> Result<()> {
    let cli = Args::parse();
    let statements = Project::try_from_dir(".")?.to_sql()?;

    // Nothing to create: skip connecting altogether.
    if statements.is_empty() {
        return Ok(());
    }

    let mut db = PgConnection::connect(&cli.conn).await?;
    for sql in &statements {
        db.execute(sql.as_str()).await?;
    }
    println!("created {} models", statements.len());
    Ok(())
}

/// One SQL model: the text of a query plus the tables that query refers to.
#[derive(Debug)]
struct Model {
    /// SQL text of the model.
    query: String,
    /// Names of the tables referenced by `query`.
    refs: Vec<String>,
}

impl Model {
    /// Creates a model from its query text and the table names it references.
    ///
    /// Both arguments are stored as given; no validation is performed.
    fn new(query: String, refs: Vec<String>) -> Self {
        Model { query, refs }
    }
}

/// A collection of SQL models, keyed by model name.
///
/// `Project::try_from_dir` uses the file stem of each SQL file as the key.
#[derive(Debug)]
struct Project {
    /// Maps a model name to its query text and table references.
    models: HashMap<String, Model>,
}

impl Project {
    fn try_from_dir(dir: &str) -> Result<Self> {
        let sql_files = get_sql_files(dir)?;
        let mut models = HashMap::new();
        for file in sql_files {
            let query = fs::read_to_string(&file)?;
            let refs = get_tables(&query)?;
            let refs = refs
                .into_iter()
                .filter_map(|ref_ids| match ref_ids.len() {
                    1 => Some(ref_ids[0].clone()),
                    _ => None,
                })
                .collect();
            let name = file.file_stem().unwrap().to_str().unwrap().to_owned();
            let model = Model::new(query, refs);
            models.insert(name, model);
        }
        Ok(Self { models })
    }

    fn edges(&self) -> Vec<(&str, &str)> {
        self.models
            .iter()
            .flat_map(|(name, model)| {
                model
                    .refs
                    .iter()
                    .map(|ref_id| (ref_id.as_str(), name.as_str()))
            })
            .collect()
    }

    fn execution_order(&self) -> Result<Vec<&str>> {
        let graph: DiGraphMap<&str, &str> = self.edges().into_iter().collect();
        match toposort(&graph, None) {
            Ok(elems) => Ok(elems),
            Err(cycle) => Err(anyhow!("{:?}", cycle)),
        }
    }

    fn to_sql(&self) -> Result<Vec<String>> {
        Ok(self
            .execution_order()?
            .into_iter()
            .map(|model_name| {
                let model = self.models.get(model_name).unwrap();
                format!(
                    "drop view if exists {};\n\
                    create view {} as\n\
                    {};",
                    &model_name, &model_name, &model.query
                )
            })
            .collect())
    }
}

/// Lists the entries directly inside `dir` whose extension is exactly `sql`.
///
/// The returned paths are sorted so that the result does not depend on the order in
/// which the operating system lists the directory. Subdirectories are not searched.
///
/// # Errors
///
/// Fails when `dir` cannot be opened or when reading one of its entries fails.
/// Entry errors are reported rather than skipped, because silently dropping a file
/// would leave the caller with an incomplete set of models.
fn get_sql_files(dir: &str) -> Result<Vec<PathBuf>> {
    let mut sql_files = Vec::new();
    for entry in fs::read_dir(dir)? {
        let path = entry?.path();
        if path.extension().map_or(false, |extension| extension == "sql") {
            sql_files.push(path);
        }
    }
    sql_files.sort();
    Ok(sql_files)
}

/// Extracts the table references of `query`, parsed with the PostgreSQL dialect.
///
/// Every reference is returned as the list of its identifier parts, converted to
/// owned values.
///
/// # Errors
///
/// Fails when the parser cannot be created for `query`; the parser's message is
/// wrapped in the returned error.
fn get_tables(query: &str) -> Result<Vec<Vec<String>>> {
    let parser = metaquery::Parser::try_new(metaquery::SqlDialect::PostgreSql, query)
        .map_err(|msg| anyhow!(msg))?;

    let tables: Vec<Vec<String>> = parser
        .tables()
        .into_iter()
        .map(|idents| {
            idents
                .into_iter()
                .map(|id| id.to_owned())
                .collect::<Vec<String>>()
        })
        .collect();
    Ok(tables)
}
