#![feature(once_cell)]

pub mod proto {
    use anyhow::{ensure, Context, Result};
    use serde::{Deserialize, Serialize};
    use std::env;
    use tokio::io::{AsyncBufRead, AsyncBufReadExt};

    /// Protocol initialization message.
    ///
    /// Built for the current process by [`Init::from_env`] and parsed back from a
    /// NUL-terminated JSON payload by [`Init::from_reader`].
    #[derive(Deserialize, Serialize)]
    pub struct Init {
        /// Protocol name; checked against this crate's `CARGO_PKG_NAME`.
        pub proto: String,
        /// Protocol version; checked against this crate's `CARGO_PKG_VERSION`.
        pub version: String,
        /// Working directory of the process that built the message (UTF-8).
        pub cwd: String,
        /// Command-line arguments, excluding the program name.
        pub args: Vec<String>,
    }

    impl Init {
        /// Builds an `Init` describing the current process: crate name and version,
        /// current working directory, and command-line arguments (program name skipped).
        ///
        /// # Panics
        ///
        /// Panics if the current directory cannot be determined or is not valid UTF-8
        /// (and, via `env::args`, if any argument is not valid Unicode).
        pub fn from_env() -> Init {
            Init {
                proto: env!("CARGO_PKG_NAME").to_owned(),
                version: env!("CARGO_PKG_VERSION").to_owned(),
                cwd: env::current_dir()
                    .expect("cannot access current directory")
                    .to_str()
                    .expect("current directory is not valid UTF-8")
                    .to_owned(),
                args: env::args().skip(1).collect(),
            }
        }

        /// Reads one NUL-terminated JSON `Init` message from `reader` and validates it.
        ///
        /// `buffer` is scratch space: it is cleared first, and on success it holds the
        /// JSON payload without the trailing NUL byte.
        ///
        /// # Errors
        ///
        /// Fails if reading fails, if the stream ends before any data or before the
        /// terminating NUL byte, if the payload is not a valid `Init`, if the protocol
        /// name/version do not match this build, or if `args` is non-empty.
        pub async fn from_reader<R: AsyncBufRead + Unpin>(
            buffer: &mut Vec<u8>,
            mut reader: R,
        ) -> Result<Self> {
            buffer.clear();
            reader
                .read_until(b'\0', &mut *buffer)
                .await
                .context("read proto init")?;

            // `read_until` also returns at EOF, so the delimiter is not guaranteed to be
            // present. Only strip it when it is there; otherwise a truncated message would
            // lose its last real byte and fail later with a misleading parse error.
            ensure!(
                !buffer.is_empty(),
                "connection closed before proto init was received"
            );
            ensure!(
                buffer.last() == Some(&b'\0'),
                "truncated proto init: connection closed before terminating NUL byte"
            );
            buffer.pop(); // remove trailing '\0'

            let proto_init: Self = serde_json::from_slice(buffer).context("invalid proto init")?;
            ensure!(proto_init.check_version(), "invalid protocol version");
            ensure!(
                proto_init.args.is_empty(),
                "unexpected args passed to client"
            );

            Ok(proto_init)
        }

        /// Returns `true` if both `proto` and `version` match this crate's
        /// `CARGO_PKG_NAME` and `CARGO_PKG_VERSION`.
        pub fn check_version(&self) -> bool {
            self.proto == env!("CARGO_PKG_NAME") && self.version == env!("CARGO_PKG_VERSION")
        }
    }
}

pub mod config {
    use anyhow::Context;
    use directories::ProjectDirs;
    use serde::de::{Error, Unexpected};
    use serde::{Deserialize, Deserializer};
    use std::fs;
    use std::lazy::SyncOnceCell;
    use std::net::{IpAddr, Ipv4Addr};

    /// Fallback values for `Config` fields that are absent from the config file.
    mod default {
        use super::*;

        /// Default `instance_timeout`: 5 minutes, expressed in seconds.
        pub fn instance_timeout() -> Option<u32> {
            const FIVE_MINUTES_SECS: u32 = 5 * 60;
            Some(FIVE_MINUTES_SECS)
        }

        /// Default `gc_interval`: 10 seconds.
        pub fn gc_interval() -> u32 {
            const TEN_SECS: u32 = 10;
            TEN_SECS
        }

        /// Default `listen` address: IPv4 loopback (`127.0.0.1`).
        pub fn listen() -> IpAddr {
            IpAddr::V4(Ipv4Addr::LOCALHOST)
        }

        /// Default `port`: an arbitrarily chosen unprivileged port.
        pub fn port() -> u16 {
            const UNPRIVILEGED_PORT: u16 = 27631;
            UNPRIVILEGED_PORT
        }

        /// Default `log_filters` directive.
        pub fn log_filters() -> String {
            String::from("info")
        }
    }

    /// Deserializer for `Config::instance_timeout`.
    ///
    /// Accepts either an unsigned 32-bit integer (yielding `Some(value)`) or the boolean
    /// `false` (yielding `None`). `true` and values of any other type are rejected.
    fn de_instance_timeout<'de, D>(deserializer: D) -> Result<Option<u32>, D::Error>
    where
        D: Deserializer<'de>,
    {
        // Untagged: serde tries the variants in declaration order.
        #[derive(Deserialize)]
        #[serde(untagged)]
        enum BoolOrU32 {
            Bool(bool),
            U32(u32),
        }

        // serde's generic untagged-enum error does not say what is accepted, so it is
        // replaced with a message naming the valid forms.
        let raw = BoolOrU32::deserialize(deserializer).map_err(|_| {
            D::Error::custom("invalid type: expected a non-negative integer or false")
        })?;

        match raw {
            BoolOrU32::Bool(false) => Ok(None),
            BoolOrU32::Bool(true) => Err(D::Error::invalid_value(
                Unexpected::Bool(true),
                &"a non-negative integer or false",
            )),
            BoolOrU32::U32(seconds) => Ok(Some(seconds)),
        }
    }

    /// Deserializer for `Config::gc_interval` that rejects `0`, so users get explicit
    /// feedback on an invalid configuration value.
    fn de_gc_interval<'de, D>(deserializer: D) -> Result<u32, D::Error>
    where
        D: Deserializer<'de>,
    {
        let interval = u32::deserialize(deserializer)?;
        if interval == 0 {
            return Err(D::Error::invalid_value(
                Unexpected::Unsigned(0),
                &"an integer 1 or greater",
            ));
        }
        Ok(interval)
    }

    /// Application configuration as read from `config.toml`.
    ///
    /// Every key is optional; a missing key takes the value from the matching function in
    /// the `default` module. Unknown keys are rejected.
    #[derive(Deserialize)]
    #[serde(deny_unknown_fields)]
    pub struct Config {
        /// Instance timeout (default 5 minutes); `false` in the file deserializes to `None`.
        #[serde(default = "default::instance_timeout", deserialize_with = "de_instance_timeout")]
        pub instance_timeout: Option<u32>,

        /// GC interval (default 10 seconds); must be at least 1.
        #[serde(default = "default::gc_interval", deserialize_with = "de_gc_interval")]
        pub gc_interval: u32,

        /// Address to listen on (default `127.0.0.1`).
        #[serde(default = "default::listen")]
        pub listen: IpAddr,

        /// Port to listen on.
        #[serde(default = "default::port")]
        pub port: u16,

        /// Default filter directive passed to the logger (default `"info"`).
        #[serde(default = "default::log_filters")]
        pub log_filters: String,
    }

    impl Config {
        fn default_values() -> Self {
            Config {
                instance_timeout: default::instance_timeout(),
                gc_interval: default::gc_interval(),
                listen: default::listen(),
                port: default::port(),
                log_filters: default::log_filters(),
            }
        }

        fn try_load() -> anyhow::Result<Self> {
            let pkg_name = env!("CARGO_PKG_NAME");
            let config_path = ProjectDirs::from("", "", pkg_name)
                .context("project config directory not found")?
                .config_dir()
                .join("config.toml");
            let path = config_path.display();
            let config_data = fs::read(&config_path)
                .with_context(|| format!("cannot read config file `{path}`"))?;
            toml::from_slice(&config_data)
                .with_context(|| format!("cannot parse config file `{path}`"))
        }

        /// panics if called multiple times
        fn init_logger(&self) {
            env_logger::Builder::from_env(
                env_logger::Env::new().default_filter_or(&self.log_filters),
            )
            .format_timestamp(None)
            .format_module_path(false)
            .format_target(false)
            .init();
        }

        /// tries to load configuration file from the standard location, if it fails it constructs
        /// a configuration with default values
        ///
        /// initializes a global logger based on the configuration
        pub fn load_or_default() -> &'static Self {
            static GLOBAL: SyncOnceCell<&'static Config> = SyncOnceCell::new();
            GLOBAL.get_or_init(|| {
                let (config, load_err) = match Self::try_load() {
                    Ok(config) => (config, None),
                    Err(err) => (Self::default_values(), Some(err)),
                };
                let global_config = Box::leak(Box::new(config));
                global_config.init_logger();
                if let Some(load_err) = load_err {
                    // log only after the logger has been initialized
                    log::warn!("cannot load config: {load_err:?}");
                }
                global_config
            })
        }
    }
}
