// stargazer - A Gemini Server
// Copyright (C) 2021 Ben Aaron Goldberg <ben@benaaron.dev>
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program.  If not, see <https://www.gnu.org/licenses/>.

use crate::router::{
    CGIRoute, RedirectRoute, Route, RoutePath, RouteType, SCGIAddress,
    SCGIRoute, StaticRoute,
};
use anyhow::{bail, Context, Result};
use ini::{Ini, Properties};
use log::warn;
use std::fs;
use std::net::{IpAddr, SocketAddr, ToSocketAddrs};
use std::path::{Path, PathBuf};
use std::str::FromStr;
use users::get_user_by_name;
use x509_parser::prelude::X509Certificate;

/// Fully validated program configuration, as produced by [`load`] or
/// [`dev_config`].
#[derive(Debug, Clone)]
pub struct Config {
    /// Socket addresses the server listens on.
    pub listen: Vec<SocketAddr>,
    /// Directory in which certificates are stored.
    pub store: PathBuf,
    /// Routes to serve.
    pub routes: Vec<Route>,
    /// Number of worker threads (defaults to the number of CPUs).
    pub worker_threads: usize,
    /// Maximum time allowed to read a request (0 disables the limit).
    pub request_timeout: u64,
    /// Maximum time from the start of a request to the end of its response
    /// (0 disables the limit).
    pub response_timeout: u64,
    /// Organization name written into generated certificates.
    pub organization: String,
    /// Log connections and other info to stdout.
    pub conn_logging: bool,
    /// Log status and meta of CGI responses and other info to stdout.
    pub cgi_resp_logging: bool,
    /// Whether certificates should be generated.
    pub generate_certs: bool,
    /// Whether expired certificates should be regenerated.
    pub regen_certs: bool,
    /// Certificate lifetime in days; 0 means certificates never expire.
    pub cert_lifetime: u64,
    /// Whether client IPs are logged in full, partially, or not at all.
    pub ip_log: IpLogAmount,
}

/// How much of a client's IP address is written to the logs.
///
/// Chosen by the `log-ip` and `log-ip-partial` config keys, at most one of
/// which may be turned on.
#[derive(Debug, Clone, Copy)]
pub enum IpLogAmount {
    /// Log the complete address.
    Full,
    /// Log only part of the address.
    Partial,
    /// Do not log addresses.
    None,
}

pub fn load(config_path: impl AsRef<Path>) -> Result<Config> {
    let res: Result<Config> = (|| {
        let config_path = config_path.as_ref();
        let mut conf = Ini::load_from_file_noescape(&config_path)?;

        let general = conf.general_section_mut();
        let listen = general
            .remove("listen")
            .context("Missing key `listen`")?
            .split(' ')
            .map(parse_interface)
            .collect::<Result<Vec<_>>>()?;
        if listen.is_empty() {
            bail!("At least 1 value must be in the `listen` array");
        }
        let worker_threads = match general.remove("worker-threads") {
            Some(s) => s.parse().context(
                "Value given for `worker-threads` isn't a valid number",
            )?,
            None => num_cpus::get(),
        };
        if worker_threads == 0 {
            bail!("Value of `worker-threads` must be > 0");
        }
        let request_timeout = match general.remove("request-timeout") {
            Some(s) => s.parse().context(
                "Value given for `request-timeout` isn't a valid number",
            )?,
            None => 5,
        };
        let response_timeout = match general.remove("response-timeout") {
            Some(s) => s.parse().context(
                "Value given for `response-timeout` isn't a valid number",
            )?,
            None => 0,
        };
        let conn_logging = general.remove_yn_true("connection-logging");
        // Only turn on if connection-logging is enabled
        let cgi_resp_logging = general.remove_yn_true("cgi-response-logging") && conn_logging;
        let ip_full = general.remove_yn("log-ip");
        let ip_partial = general.remove_yn("log-ip-partial");
        let ip_log = match (ip_full, ip_partial) {
            (false, false) => IpLogAmount::None,
            (true, false) => IpLogAmount::Full,
            (false, true) => IpLogAmount::Partial,
            (true, true) => bail!(
                "log-ip and log-ip-partial can't both be turned on at once"
            ),
        };
        check_section_empty("general", general)?;

        let tls_section = conf
            .section_mut(Some(":tls"))
            .context("Missing key `:tls`")?;
        let store = Path::new(
            &tls_section
                .remove("store")
                .context("Missing key `store` in `:tls`")?,
        )
        .to_owned();
        if !store.exists() {
            fs::create_dir_all(&store).context(
                "`store` directory doesn't exist and couldn't be created",
            )?;
        }
        let organization = tls_section
            .remove("organization")
            .unwrap_or_else(|| "stargazer".to_owned());
        let generate_certs = tls_section.remove_yn_true("gen-certs");
        let regen_certs = tls_section.remove_yn_true("regen-certs");
        let cert_lifetime = match tls_section.remove("cert-lifetime") {
            Some(mut time_str) => {
                let last_char = time_str.pop().ok_or_else(|| anyhow::anyhow!("Invalid time specifier given for cert-lifetime: must end in 'd', 'm', or 'y'"))?;
                let time_val: u64 = time_str.parse().context("Invalid time specifier given for cert-lifetime: all but the last character must be a number")?;
                match last_char {
                    'd' => time_val,
                    'm' => time_val * 30,
                    'y' => time_val * 365,
                    _ => bail!("Invalid time specifier given for cert-lifetime: must end in 'd', 'm', or 'y'"),
                }
            }
            None => 0, // default to never expire
        };
        check_section_empty(":tls", tls_section)?;

        let mut sites = Vec::with_capacity(5);

        // Load router configs
        for (section, props) in conf.iter_mut() {
            let section = match section {
                Some(section) => section,
                None => continue,
            };
            if section == ":tls" {
                continue;
            }
            let (domain, route) =
                if let Some((domain, route)) = split_once(section, ':') {
                    (domain.to_owned(), RoutePath::Prefix(route.to_owned()))
                } else if let Some((domain, route)) = split_once(section, '=') {
                    (domain.to_owned(), RoutePath::Exact(route.to_owned()))
                } else if let Some((domain, route)) = split_once(section, '~') {
                    (
                        domain.to_owned(),
                        RoutePath::Regex(Box::new(
                            regex::Regex::new(route).with_context(|| {
                                format!("Invalid regex `{}`", route)
                            })?,
                        )),
                    )
                } else {
                    (section.to_owned(), RoutePath::All)
                };
            let scgi = props.remove_yn("scgi");
            let cgi = props.remove_yn("cgi");
            let redirect = props.remove("redirect");
            if count_true(&[scgi, cgi, redirect.is_some()]) > 1 {
                bail!("Route {} has more than on of `cgi`, `scgi`, or `redirect` set", section);
            }
            let client_cert = props
                .remove("client-cert")
                .map(|path| load_client_cert(&path, &section))
                .transpose()?;

            let cgi_user = match props.remove("cgi-user") {
                Some(name) => {
                    Some(get_user_by_name(&name).with_context(|| {
                        format!("Invalid user name for `cgi-user`: `{}`", name)
                    })?)
                }
                None => None,
            };
            let cgi_timeout = match props.remove("cgi-timeout") {
                Some(s) => Some(s.parse().with_context(|| {
                    "Value for `cgi-timeout` in invalid, must be an integer"
                })?),
                None => None,
            };

            let cert_path = props.remove("cert-path");
            let key_path = props.remove("key-path");
            let cert_key_path = match (cert_path, key_path) {
                (Some(cert_path), Some(key_path)) => {
                    Some((cert_path.into(), key_path.into()))
                }
                (None, None) => None,
                _ => {
                    bail!("If either `cert-path` or `key-path` are set, both must be set");
                }
            };

            let route_type: RouteType = if scgi {
                let scgi_addr = props.remove("scgi-address").context(
                    "Routes with `scgi` on must include an `scg-address`",
                )?;
                let mut is_sock = false;
                let mut addr = None;
                if let Ok(mut sock_addr) = scgi_addr.to_socket_addrs() {
                    if let Some(sock_addr) = sock_addr.next() {
                        is_sock = true;
                        addr = Some(SCGIAddress::Tcp(sock_addr));
                    }
                }
                if !is_sock {
                    addr = Some(SCGIAddress::Unix(
                        Path::new(&scgi_addr).to_owned(),
                    ));
                }
                SCGIRoute {
                    addr: addr.unwrap(),
                    timeout: cgi_timeout,
                }
                .into()
            } else if redirect.is_some() {
                let permanent = props.remove_yn("permanent");
                RedirectRoute {
                    url: redirect.unwrap(),
                    permanent,
                    redirect_rewrite: props.contains_key("rewrite"),
                }
                .into()
            } else {
                let root =
                    Path::new(&props.remove("root").with_context(|| {
                        format!(
                            "site configs need to have a root. `{}` does not",
                            section
                        )
                    })?)
                    .to_owned();
                if !root.exists() {
                    bail!(
                        "`root` directory for site `{}` doesn't exist",
                        section
                    );
                }
                let index = props
                    .remove("index")
                    .unwrap_or_else(|| "index.gmi".to_owned())
                    .to_owned();
                if index.contains('/') {
                    bail!("`index` in config cannot contain '/'");
                }
                let auto_index_off = props.remove_yn_inv("auto-index");
                if cgi {
                    if auto_index_off {
                        bail!("Routes with `cgi` on cannot use `auto-index`");
                    }
                    CGIRoute {
                        root,
                        index,
                        user: cgi_user,
                        timeout: cgi_timeout,
                    }
                    .into()
                } else {
                    StaticRoute {
                        root,
                        index,
                        auto_index: !auto_index_off,
                    }
                    .into()
                }
            };

            let rt_str = route_type.to_string();
            sites.push(Route {
                domain,
                path: route,
                rewrite: props.remove("rewrite"),
                route_type,
                lang: props.remove("lang"),
                charset: props.remove("charset"),
                cert_key_path,
                client_cert,
            });
            check_route_empty(section, &rt_str, props)?;
        }

        // Reverse order so site is checked first
        sites.reverse();
        if sites.is_empty() {
            bail!(
                "At least one route must be specified in config file. Please \
                  refer to stargazer.ini(5) for details"
            );
        }

        Ok(Config {
            listen,
            store,
            conn_logging,
            cgi_resp_logging,
            routes: sites,
            worker_threads,
            request_timeout,
            response_timeout,
            organization,
            generate_certs,
            regen_certs,
            cert_lifetime,
            ip_log,
        })
    })();
    res.with_context(|| {
        format!(
            "Error loading config file: {}",
            config_path.as_ref().display()
        )
    })
}

pub fn dev_config() -> Result<Config> {
    let root =
        std::env::current_dir().context("Couldn't access current directory")?;
    if !root.exists() {
        bail!("Current directory doesn't exist");
    }
    let conf = Config {
        listen: vec![SocketAddr::from(([127, 0, 0, 1], 1965))],
        store: dirs::data_dir()
            .context("Cannot locate user data directory")?
            .join("stargazer-dev-store"),
        routes: vec![Route {
            domain: "localhost".to_owned(),
            path: RoutePath::All,
            rewrite: None,
            lang: None,
            charset: None,
            route_type: StaticRoute {
                root,
                index: "index.gmi".to_owned(),
                auto_index: true,
            }
            .into(),
            cert_key_path: None,
            client_cert: None,
        }],
        worker_threads: num_cpus::get(),
        request_timeout: 5,
        response_timeout: 10,
        organization: "stargazer".to_owned(),
        conn_logging: true,
        cgi_resp_logging: true,
        generate_certs: true,
        regen_certs: true,
        cert_lifetime: 0,
        ip_log: IpLogAmount::Full,
    };
    fs::create_dir_all(&conf.store).with_context(|| {
        format!(
            "Error creating dev store directory at {}",
            conf.store.display()
        )
    })?;
    Ok(conf)
}

fn parse_interface(s: &str) -> Result<SocketAddr> {
    Ok(if s.contains(':') {
        if s.starts_with('[') && s.ends_with(']') {
            let addr = &s[1..s.len() - 1];
            let ip = IpAddr::from_str(addr).with_context(|| {
                format!("Error parsing IP address `{}`", addr)
            })?;
            SocketAddr::new(ip, 1965)
        } else {
            s.to_socket_addrs()
                .with_context(|| format!("Error parsing socket address {}", s))?
                .next()
                .with_context(|| {
                    format!("Error parsing socket address {}", s)
                })?
        }
    } else {
        let ip = IpAddr::from_str(s)
            .with_context(|| format!("Error parsing IP address `{}`", s))?;
        SocketAddr::new(ip, 1965)
    })
}

// STABLE: replace with `str::split_once` (stable since Rust 1.52) once the
// minimum supported Rust version allows it.
/// Split `s` at the first occurrence of `delimiter`.
///
/// Returns the text before and after the delimiter, with the delimiter
/// itself dropped. Returns `None` if `delimiter` does not occur in `s`.
#[inline]
fn split_once(s: &str, delimiter: char) -> Option<(&str, &str)> {
    let idx = s.find(delimiter)?;
    // Skip the delimiter's full UTF-8 width. A bare `+ 1` would slice inside
    // a multi-byte delimiter and panic.
    Some((&s[..idx], &s[idx + delimiter.len_utf8()..]))
}

/// Fail if `section` still has entries after all recognized keys have been
/// removed from it.
///
/// The error names the section and lists every leftover `key=value` pair,
/// one per line.
fn check_section_empty(name: &str, section: &Properties) -> Result<()> {
    if section.is_empty() {
        return Ok(());
    }
    let mut leftovers = String::new();
    for (key, value) in section.iter() {
        leftovers.push_str(&format!("{}={}\n", key, value));
    }
    bail!(
        "Section `{}` contains unknown parameters:\n{}",
        name,
        leftovers
    );
}

/// Fail if a route section still has entries after its recognized keys have
/// been removed.
///
/// Parameters:
/// * `name` — the section header;
/// * `route_type` — the display name of the route's type.
///
/// Each leftover key gets its own line in the error message. A key that
/// belongs to a different route type says which type(s) accept it; any
/// other key is reported as unknown.
fn check_route_empty(
    name: &str,
    route_type: &str,
    section: &Properties,
) -> Result<()> {
    use std::fmt::Write as _;

    if section.is_empty() {
        return Ok(());
    }
    let mut error = String::new();
    for (k, v) in section.iter() {
        // Keys consumed only for some route types. `lang`, `charset` and
        // `rewrite` are accepted by every route, so they are not listed.
        let allowed_in = match k {
            "auto-index" => Some("static"),
            "index" | "root" => Some("static or cgi"),
            "permanent" => Some("redirect"),
            "cgi-user" => Some("cgi"),
            "scgi-address" => Some("scgi"),
            "cgi-timeout" => Some("cgi or scgi"),
            _ => None,
        };
        // Writing into a `String` cannot fail, so the `fmt::Result` is ignored.
        let _ = match allowed_in {
            Some(allowed_in) => write!(error, "\nRoute `{}` is a(n) {} route but it contains the parameter `{}={}` which is only allowed in {} routes", name, route_type, k, v, allowed_in),
            None => write!(error, "\nRoute `{}` contains an unknown parameter {}={}", name, k, v),
        };
    }
    bail!(error);
}

trait RemoveYN {
    fn remove_yn(self, name: &str) -> bool;
    fn remove_yn_true(self, name: &str) -> bool;
    fn remove_yn_inv(self, name: &str) -> bool;
}

impl RemoveYN for &mut Properties {
    fn remove_yn(self, name: &str) -> bool {
        self.remove(name)
            .map(|s| match s.to_lowercase().as_str() {
                "on" | "true" | "yes" => true,
                "off" | "false" | "no" => false,
                s => {
                    warn!("Invalid value for `{}`: '{}'. Turing off", name, s);
                    false
                }
            })
            .unwrap_or(false)
    }
    fn remove_yn_true(self, name: &str) -> bool {
        self.remove(name)
            .map(|s| match s.to_lowercase().as_str() {
                "on" | "true" | "yes" => true,
                "off" | "false" | "no" => false,
                s => {
                    warn!("Invalid value for `{}`: '{}'. Turing on", name, s);
                    true
                }
            })
            .unwrap_or(true)
    }
    fn remove_yn_inv(self, name: &str) -> bool {
        self.remove(name)
            .map(|s| match s.to_lowercase().as_str() {
                "on" | "true" | "yes" => false,
                "off" | "false" | "no" => true,
                s => {
                    warn!("Invalid value for `{}`: '{}'. Turing off", name, s);
                    false
                }
            })
            .unwrap_or(false)
    }
}

/// Return how many entries of `bools` are `true`.
fn count_true(bools: &[bool]) -> usize {
    bools.iter().filter(|&&flag| flag).count()
}

fn load_client_cert(
    path: &str,
    section: &str,
) -> Result<X509Certificate<'static>> {
    let raw_cert = std::fs::read(&path).with_context(|| {
        format!(
            "Error loading client cert at `{}` for route `{}`",
            path, section,
        )
    })?;
    let (_, pem) =
        x509_parser::pem::parse_x509_pem(Box::leak(Box::new(raw_cert)))?;
    let pem = Box::leak(Box::new(pem));
    pem.parse_x509().map_err(anyhow::Error::new)
}
