// stargazer - A Gemini Server
// Copyright (C) 2021 Ben Aaron Goldberg <ben@benaaron.dev>
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program.  If not, see <https://www.gnu.org/licenses/>.

use crate::router::Route;
use crate::CONF;
use anyhow::{anyhow, bail, Context, Result};
use futures_rustls::rustls::client::HandshakeSignatureValid;
use futures_rustls::rustls::internal::msgs::handshake::DigitallySignedStruct;
use futures_rustls::rustls::server::{
    ClientCertVerified, ClientCertVerifier, ClientHello, ResolvesServerCert,
};
use futures_rustls::rustls::{
    sign::{self, CertifiedKey},
    Certificate, DistinguishedNames, Error, PrivateKey, ServerConfig,
};
use log::debug;
use rcgen::{
    Certificate as GenCert, CertificateParams, DistinguishedName, DnType,
};
use rustls_pemfile::{certs, pkcs8_private_keys};
use std::collections::HashMap;
use std::time::{Duration, SystemTime};
use std::{
    fs::{self, File},
    io::BufReader,
    path::{Path, PathBuf},
    sync::Arc,
};
use time::ext::NumericalDuration;
use time::OffsetDateTime;
use x509_parser::parse_x509_certificate;

/// Client-certificate policy for Gemini: client auth is offered but not
/// required, and any certificate chain is accepted without checking it
/// against a CA (Gemini identities are typically self-signed, so there is no
/// trust anchor to check against).
///
/// `verify_tls12_signature` / `verify_tls13_signature` are deliberately NOT
/// overridden. The trait's provided implementations check the client's
/// CertificateVerify signature against the presented certificate's public
/// key. That check proves the client actually owns the certificate's private
/// key. Accepting those signatures unconditionally would let anyone who has
/// seen a user's certificate present it and impersonate that identity.
struct GeminiClientAuth;

impl ClientCertVerifier for GeminiClientAuth {
    /// Always ask the client for a certificate.
    fn offer_client_auth(&self) -> bool {
        true
    }

    /// ...but never require one; anonymous clients are allowed.
    fn client_auth_mandatory(&self) -> Option<bool> {
        Some(false)
    }

    /// Advertise no acceptable CA names, since any issuer is accepted.
    fn client_auth_root_subjects(&self) -> Option<DistinguishedNames> {
        Some(Vec::new())
    }

    /// Accept every certificate chain without path validation.
    fn verify_client_cert(
        &self,
        _end_entity: &Certificate,
        _intermediates: &[Certificate],
        _now: SystemTime,
    ) -> Result<ClientCertVerified, Error> {
        Ok(ClientCertVerified::assertion())
    }
}

/// Certificate resolver that selects a certificate by SNI hostname.
///
/// Hostnames registered through `add` get their own certificate. Any other
/// hostname sent by the client gets `default`. A ClientHello without an SNI
/// hostname resolves to no certificate at all (`None`).
struct CustomSni {
    /// Per-hostname certificates, keyed by the exact SNI name.
    by_name: HashMap<String, Arc<sign::CertifiedKey>>,
    /// Fallback certificate for hostnames missing from `by_name`.
    default: Arc<sign::CertifiedKey>,
}

impl CustomSni {
    /// Create a resolver with no per-host entries that falls back to `default`.
    fn new(default: sign::CertifiedKey) -> Self {
        let default = Arc::new(default);
        let by_name = HashMap::new();
        Self { by_name, default }
    }

    /// Register `ck` as the certificate for `name`, replacing any earlier one.
    fn add(&mut self, name: &str, ck: sign::CertifiedKey) {
        let host = name.to_owned();
        let shared = Arc::new(ck);
        self.by_name.insert(host, shared);
    }

    /// Whether `name` already has its own certificate registered.
    fn is_loaded(&self, name: &str) -> bool {
        self.by_name.contains_key(name)
    }
}

impl ResolvesServerCert for CustomSni {
    fn resolve(
        &self,
        client_hello: ClientHello,
    ) -> Option<Arc<sign::CertifiedKey>> {
        // No SNI hostname -> no certificate.
        let host = client_hello.server_name()?;
        let chosen = match self.by_name.get(host) {
            Some(ck) => ck,
            None => &self.default,
        };
        Some(Arc::clone(chosen))
    }
}

/// Load the passed certificates file
fn load_cert(path: &Path) -> Result<Vec<Certificate>> {
    match certs(&mut BufReader::new(
        File::open(path).context("error reading certificate file")?,
    ))
    .map_err(|_| anyhow!("invalid certificate"))
    {
        Ok(certs) => {
            if !certs.is_empty() {
                Ok(certs.into_iter().map(Certificate).collect())
            } else {
                Err(anyhow!(
                    "cert file {} doesn't contain a cert",
                    path.display()
                ))
            }
        }
        Err(e) => Err(e),
    }
}

/// Load the passed keys file
fn load_key(path: &Path) -> Result<PrivateKey> {
    match pkcs8_private_keys(&mut BufReader::new(
        File::open(path).context("error reading key file")?,
    ))
    .map_err(|_| anyhow!("invalid key"))
    {
        Ok(keys) => {
            if keys.len() == 1 {
                Ok(PrivateKey(keys[0].clone()))
            } else {
                Err(anyhow!(
                    "key file {} doesn't contain exactly 1 key",
                    path.display()
                ))
            }
        }
        Err(e) => Err(e),
    }
}

/// Configure the server using rustls
/// See [`futures_rustls::rustls::ServerConfig`] for details
///
/// A TLS server needs a certificate and a fitting private key
pub fn load_config(routes: &[Route]) -> Result<ServerConfig> {
    let (default_cert, default_key) = get_default_cert_and_key()?;
    let signing_key = sign::any_supported_type(&default_key)
        .map_err(|_| anyhow!("Invalid default key"))?;
    let default = CertifiedKey::new(default_cert, signing_key);
    let mut resolver = CustomSni::new(default);
    for route in routes {
        if !resolver.is_loaded(&route.domain) {
            let (cert_chain, key) = get_cert_and_key(route)?;

            let signing_key = sign::any_supported_type(&key).map_err(|_| {
                anyhow!("Invalid key for domain {}", route.domain)
            })?;

            log::debug!("Loaded cert+key for domain {}", route.domain);
            resolver
                .add(&route.domain, CertifiedKey::new(cert_chain, signing_key));
        }
    }

    // we don't use client authentication
    // let mut config = ServerConfig::new(Arc::new(GeminiClientAuth));
    // config.cert_resolver = Arc::new(resolver);
    // config.versions = vec![ProtocolVersion::TLSv1_3, ProtocolVersion::TLSv1_2];

    Ok(ServerConfig::builder()
        .with_safe_defaults()
        .with_client_cert_verifier(Arc::new(GeminiClientAuth))
        .with_cert_resolver(Arc::new(resolver)))
}

/// Resolve the certificate chain and private key served for `route.domain`.
///
/// If `route.cert_key_path` is set, those files are used. Otherwise the pair
/// is `<store>/<domain>.crt` and `<store>/<domain>.key`, and it is generated
/// when either file is missing.
///
/// When `CONF.regen_certs` is set and any certificate in the loaded chain is
/// outside its validity period, a fresh pair is generated in the store and
/// returned instead. This also happens when `cert_key_path` points elsewhere.
fn get_cert_and_key(route: &Route) -> Result<(Vec<Certificate>, PrivateKey)> {
    let regenerate = || -> Result<(Vec<Certificate>, PrivateKey)> {
        let (new_cert, new_key) = gen_cert_and_key(&route.domain)?;
        Ok((load_cert(&new_cert)?, load_key(&new_key)?))
    };

    let (cert_path, key_path) = match route.cert_key_path.clone() {
        Some(paths) => paths,
        None => {
            let crt = CONF.store.join(format!("{}.crt", route.domain));
            let key = CONF.store.join(format!("{}.key", route.domain));
            if !(crt.exists() && key.exists()) {
                return regenerate();
            }
            (crt, key)
        }
    };

    let chain = load_cert(&cert_path)?;
    for der in &chain {
        let (_, parsed) =
            parse_x509_certificate(der.as_ref()).with_context(|| {
                format!("Error parsing cert for {}", route.domain)
            })?;
        if CONF.regen_certs && !parsed.validity().is_valid() {
            return regenerate();
        }
    }

    let key = load_key(&key_path)?;
    Ok((chain, key))
}

/// Load the fallback certificate chain and key, stored as `default.crt` and
/// `default.key` in `CONF.store`.
///
/// A fresh pair is generated with `gen_default_pair` in two cases: when
/// either file is missing, or when `CONF.regen_certs` is set and some
/// certificate in the stored chain is outside its validity period.
fn get_default_cert_and_key() -> Result<(Vec<Certificate>, PrivateKey)> {
    let cert_path = CONF.store.join("default.crt");
    let key_path = CONF.store.join("default.key");

    let fresh_pair = || -> Result<(Vec<Certificate>, PrivateKey)> {
        let (new_cert, new_key) = gen_default_pair()?;
        Ok((load_cert(&new_cert)?, load_key(&new_key)?))
    };

    if !(cert_path.exists() && key_path.exists()) {
        return fresh_pair();
    }

    let cert_chain = load_cert(&cert_path)?;
    // Every cert is parsed (so malformed ones still error) until the first
    // expired one is found.
    let mut expired = false;
    for der in cert_chain.iter() {
        let (_, parsed) = parse_x509_certificate(der.as_ref())
            .context("Error parsing default cert")?;
        if CONF.regen_certs && !parsed.validity().is_valid() {
            expired = true;
            break;
        }
    }
    if expired {
        return fresh_pair();
    }

    let key = load_key(&key_path)?;
    Ok((cert_chain, key))
}

fn gen_cert_and_key(domain: &str) -> Result<(PathBuf, PathBuf)> {
    if !CONF.generate_certs {
        bail!(
            "Cert not found for domain {} and cert generation is disabled",
            domain
        );
    }
    debug!("Generating cert+key for {}", domain);
    let mut params = CertificateParams::new(vec![domain.to_owned()]);
    let mut distinguished_name = DistinguishedName::new();
    distinguished_name.push(DnType::CommonName, domain);
    distinguished_name.push(DnType::OrganizationName, &CONF.organization);
    params.distinguished_name = distinguished_name;
    params.not_before = OffsetDateTime::now_local()
        .unwrap_or(OffsetDateTime::now_utc() - 12.hours());
    // Plus 5 years
    if CONF.cert_lifetime > 0 {
        params.not_after = params.not_before
            + Duration::from_secs(CONF.cert_lifetime * 24 * 60 * 60);
    }
    let cert = GenCert::from_params(params).with_context(|| {
        format!("Error generating cert for domain: `{}`", domain)
    })?;
    let crt = cert.serialize_pem().with_context(|| {
        format!("Error serializing cert for domain: `{}`", domain)
    })?;
    let key = cert.serialize_private_key_pem();

    let crt_path = CONF.store.join(format!("{}.crt", domain));
    let key_path = CONF.store.join(format!("{}.key", domain));

    fs::write(&crt_path, crt).with_context(|| {
        format!("Error writing cert for domain: `{}`", domain)
    })?;
    fs::write(&key_path, key)
        .with_context(|| format!("Error write key for domain: `{}`", domain))?;
    Ok((crt_path, key_path))
}

fn gen_default_pair() -> Result<(PathBuf, PathBuf)> {
    let params = CertificateParams::default();
    let cert = GenCert::from_params(params)
        .context("Error generating default cert")?;
    let crt = cert
        .serialize_pem()
        .context("Error serializing default cert")?;
    let key = cert.serialize_private_key_pem();

    let crt_path = CONF.store.join("default.crt");
    let key_path = CONF.store.join("default.key");

    fs::write(&crt_path, crt).context("Error writing default cert")?;
    fs::write(&key_path, key).context("Error writing default key")?;
    Ok((crt_path, key_path))
}
