// stargazer - A Gemini Server
// Copyright (C) 2021 Ben Aaron Goldberg <ben@benaaron.dev>
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program.  If not, see <https://www.gnu.org/licenses/>.

mod cgi;
mod cli;
mod config;
mod dir_list;
mod error;
mod get_file;
mod logger;
mod router;
mod tls;

use crate::error::{bad_req, ErrorConv, GemError, Result};
use async_channel::{Receiver, Sender};
use async_executor::Executor;
use async_io::Timer;
use async_net::{SocketAddrV6, TcpListener, TcpStream};
use cgi::{serve_cgi, serve_scgi};
use futures_lite::*;
use futures_rustls::{server::TlsStream, TlsAcceptor};
use get_file::get_file;
use log::{debug, error};
use once_cell::sync::{Lazy, OnceCell};
use router::{route, Request, Route};
use std::clone::Clone;
use std::convert::TryFrom;
use std::net::SocketAddr;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use std::panic::catch_unwind;
use std::sync::Arc;
use std::{process::exit, str, thread, time::Duration};

/// Shutdown channel shared by the signal handler and the executor workers.
///
/// Worker threads run `EXEC` until a message arrives on the receiver;
/// `exit_on_sig` sends one message on the sender. Capacity is 1.
pub static SHUTDOWN: Lazy<(Sender<()>, Receiver<()>)> = Lazy::new(|| {
    let (tx, rx) = async_channel::bounded(1);
    (tx, rx)
});
/// Global executor
///
/// Initialising this spawns `CONF.worker_threads` OS threads named
/// `stargazer-worker-N`. Each one repeatedly drives `EXEC` until a shutdown
/// message is received; a panic that unwinds out of `run` is caught and the
/// worker simply re-enters the executor.
pub static EXEC: Lazy<Executor> = Lazy::new(|| {
    (1..=CONF.worker_threads).for_each(|n| {
        let spawned = thread::Builder::new()
            .name(format!("stargazer-worker-{}", n))
            .spawn(|| loop {
                // Ignore the outcome: whether `run` finished or panicked,
                // go around and start serving tasks again.
                let _ = catch_unwind(|| {
                    async_io::block_on(EXEC.run(SHUTDOWN.1.recv()))
                });
            });
        spawned.expect("cannot spawn executor thread");
    });

    Executor::new()
});

/// Backing storage for [`CONF`]; filled once by `main` after loading config.
static CONF_BACKING: OnceCell<config::Config> = OnceCell::new();
/// Global server configuration (a clone of the value stored in
/// `CONF_BACKING`).
///
/// # Panics
/// Panics on first access if `CONF_BACKING` has not been set yet.
pub static CONF: Lazy<config::Config> = Lazy::new(|| match CONF_BACKING.get() {
    Some(conf) => conf.clone(),
    None => panic!("CONF accessed before it was set"),
});

fn main() {
    use error::Context;

    // Parse cli args
    let args = cli::Args::from_args();

    // Set log level based on debug cli flag
    let log_level = match args.debug {
        true => log::LevelFilter::Debug,
        false => log::LevelFilter::Warn,
    };

    // start logger
    logger::init(log_level).expect("Error starting logger");

    // Use dev or regular config based on dev cli flag
    let conf = if args.dev {
        println!("Staring dev server at gemini://localhost/");
        config::dev_config()
    } else {
        debug!("Loading config from: {}", args.config_path.display());
        config::load(&args.config_path)
    };
    let conf = match conf {
        Ok(conf) => conf,
        Err(e) => {
            error!("{:#}", e);
            log::logger().flush();
            exit(1);
        }
    };
    debug!("{:#?}", conf);
    if args.check_config {
        exit(0);
    }
    // Set global config
    CONF_BACKING.set(conf).unwrap();

    if CONF.conn_logging && !args.debug {
        log::set_max_level(log::LevelFilter::Info);
    }

    // Configure tls acceptor for all domain names in config
    let acceptor = TlsAcceptor::from(match tls::load_config(&CONF.routes) {
        Ok(conf) => Arc::new(conf),
        Err(e) => {
            error!("{:#}", e);
            log::logger().flush();
            exit(1);
        }
    });

    // Create a listener task for each interface+port listed in conf
    debug!("Server staring");
    for addr in &CONF.listen {
        let acceptor = acceptor.clone();
        EXEC.spawn(async move {
            // Start tcp listener
            // Program exits with an error if any fail to bind
            let listener = match addr {
                SocketAddr::V4(addr) => match TcpListener::bind(addr).await {
                    Ok(l) => l,
                    Err(e) => {
                        error!("{:#}", e);
                        log::logger().flush();
                        exit(1);
                    }
                },
                SocketAddr::V6(addr) => match create_v6_sock(addr) {
                    Ok(l) => l,
                    Err(e) => {
                        error!("Error creating IPv6 listener: {:#}", e);
                        log::logger().flush();
                        exit(1);
                    }
                },
            };

            let server_port = match listener.local_addr() {
                Ok(addr) => addr.port(),
                Err(e) => {
                    error!("{:#}", e);
                    log::logger().flush();
                    exit(1);
                }
            };

            // Accept connections
            loop {
                let (stream, remote_addr) = match listener
                    .accept()
                    .await
                    .context("Error accepting tcp connection")
                {
                    Ok(ret) => ret,
                    Err(e) => {
                        debug!("{:#}", e);
                        continue;
                    }
                };
                let acceptor = acceptor.clone();

                // Start connection handler task
                // This will time out after 10s
                EXEC.spawn(async move {
                    if CONF.response_timeout == 0 {
                        start_task(stream, acceptor, remote_addr, server_port)
                            .await
                    } else {
                        start_task(stream, acceptor, remote_addr, server_port)
                            .or(async {
                                Timer::after(Duration::from_secs(
                                    CONF.response_timeout,
                                ))
                                .await;
                                debug!(
                                "Request timeout: Total request time exceeded"
                            );
                            })
                            .await
                    }
                })
                .detach();
            }
        })
        .detach();
    }

    async_io::block_on(async {
        if let Err(e) = exit_on_sig().await {
            println!("Error exiting !!!: {:#}", e);
        }
    });
}

/// Create an IPv6-only TCP listener bound to `addr`.
///
/// `only_v6(true)` keeps the socket from also accepting IPv4-mapped
/// connections, so an IPv4 listener can share the same port.
///
/// On Unix, `SO_REUSEADDR` is enabled as well. std's `TcpListener::bind`
/// sets it on non-Windows platforms, so this keeps IPv6 listeners
/// consistent with the IPv4 ones and lets the server rebind while old
/// connections are still in `TIME_WAIT`. It is deliberately not set on
/// Windows, where the option has different, port-stealing semantics.
fn create_v6_sock(addr: &SocketAddrV6) -> Result<TcpListener> {
    // Pending-connection queue length passed to listen(2).
    const BACKLOG: i32 = 128;

    let socket_builder = net2::TcpBuilder::new_v6()?;
    socket_builder.only_v6(true)?;
    if cfg!(unix) {
        // Must be set before `bind` to have any effect.
        socket_builder.reuse_address(true)?;
    }
    socket_builder.bind(addr)?;
    let listener = socket_builder.listen(BACKLOG)?;
    Ok(TcpListener::try_from(listener)?)
}

/// Wait for SIGTERM or SIGINT, then shut down.
///
/// After the first signal, one message is sent on `SHUTDOWN`, the logger is
/// flushed, and the process exits with status 0. This only returns if
/// installing the signal handlers or sending on the channel fails.
#[cfg(unix)]
async fn exit_on_sig() -> Result<()> {
    use signal_hook::consts::{SIGINT, SIGTERM};
    use signal_hook_async_std::Signals;

    let mut incoming = Signals::new(&[SIGTERM, SIGINT])?;
    // Which signal arrived doesn't matter; both mean "stop".
    let _ = incoming.next().await;

    SHUTDOWN.0.send(()).await?;
    log::logger().flush();
    std::process::exit(0)
}

/// Non-Unix fallback: there are no signals to wait for, so never complete.
/// This keeps `main` blocked while the executor serves connections.
#[cfg(not(unix))]
async fn exit_on_sig() -> Result<()> {
    future::pending().await
}

/// Serve a single TCP connection from start to finish.
///
/// This is a separate function for ease of error handling. It performs the
/// TLS handshake, handles the request, and logs the result. If handling
/// failed, it writes a `<status> <meta>\r\n` error header. It always
/// attempts to close the TLS stream. Every failure here is only logged;
/// nothing is returned to the caller.
async fn start_task(
    stream: TcpStream,
    acceptor: TlsAcceptor,
    remote_addr: SocketAddr,
    server_port: u16,
) {
    use anyhow::Context;

    // Accept tls connection; a failed handshake only warrants a debug line.
    let mut tls = match acceptor
        .accept(stream)
        .await
        .context("Error accepting tls connection")
    {
        Ok(tls) => tls,
        Err(e) => {
            debug!("{:#}", e);
            return;
        }
    };

    match handle_connection(&mut tls, remote_addr, server_port).await {
        Ok(log_info) => log_info.log(),
        Err(ErrorLogInfo { log_info, error }) => {
            error.log();
            if let Some(info) = log_info {
                info.log();
            }
            // Gemini error header: "<status> <meta>\r\n"
            let status = error.status().to_string();
            let mut header = Vec::with_capacity(1040);
            header.extend_from_slice(status.as_bytes());
            header.push(b' ');
            header.extend_from_slice(error.meta().as_bytes());
            header.extend_from_slice(b"\r\n");
            if let Err(e) = tls
                .write_all(&header)
                .await
                .context("Error writing response")
            {
                debug!("{:#}", e);
            }
        }
    }

    if let Err(e) = tls.close().await.context("Error closing stream") {
        debug!("{:#}", e);
    }
}

/// Parse the request, route it, call the response handler
async fn handle_connection(
    stream: &mut TlsStream<TcpStream>,
    remote_addr: SocketAddr,
    server_port: u16,
) -> std::result::Result<FullLogInfo, ErrorLogInfo> {
    use anyhow::Context;

    // Read the request
    let (input_buf, read_len) = if CONF.request_timeout == 0 {
        read_request(stream).await
    } else {
        read_request(stream)
            .or(async {
                Timer::after(Duration::from_secs(CONF.request_timeout)).await;
                Err(GemError::TimeOut)
            })
            .await
    }?;

    let (route, req) = route(
        str::from_utf8(&input_buf[..read_len])
            .context("url is not valid utf8")
            .map_err(GemError::BadReq)?,
    )?;
    let log_ip = match CONF.ip_log {
        config::IpLogAmount::Full => remote_addr.ip(),
        config::IpLogAmount::Partial => anon_ip(remote_addr.ip()),
        config::IpLogAmount::None => Ipv4Addr::new(0, 0, 0, 0).into(),
    };
    match handle_response(stream, remote_addr, server_port, route, req.clone()).await {
        Ok(log_info) => Ok(FullLogInfo {
            log_info,
            req,
            log_ip,
        }),
        Err(error) => {
            let status = error.status();
            let meta = error.meta().as_bytes().to_owned();
            Err(ErrorLogInfo {
                error,
                log_info: Some(FullLogInfo {
                    log_info: LogInfo {
                        size: 0,
                        status,
                        meta,
                    },
                    req,
                    log_ip,
                }),
            })
        }
    }
}

/// Handle responding to the request
///
/// Depending on the path requested this could mean:
/// - Return file contents with correct mime type
/// - Return a gemtext page list directory contents
/// - Run a CGI script
/// - Send a request to an SCGI script
/// - Send a redirect
/// - Return an error
async fn handle_response(
    stream: &mut TlsStream<TcpStream>,
    remote_addr: SocketAddr,
    server_port: u16,
    route: &'static Route,
    req: Request,
) -> Result<LogInfo> {
    use crate::router::RouteType;
    use anyhow::Context;

    // Check client cert if one if configured
    if let Some(config_client_cert) = &route.client_cert {
        let (_, session) = stream.get_ref();
        match session.peer_certificates().and_then(|list| list.get(0)) {
            Some(session_client_cert) => {
                match x509_parser::parse_x509_certificate(
                    session_client_cert.as_ref(),
                ) {
                    Ok((_, cert)) => {
                        if cert != *config_client_cert {
                            return Err(GemError::ClientCertNoAuth);
                        }
                    }
                    Err(e) => {
                        log::debug!("Bad client cert: {:#}", e);
                        return Err(GemError::BadReq(anyhow::anyhow!(
                            "Bad client cert"
                        )));
                    }
                }
            }
            None => return Err(GemError::ClientCertRequired),
        }
    }

    let log_info = match &route.route_type {
        RouteType::Cgi(cgi_route) => {
            let fut = serve_cgi(
                cgi_route,
                &route.domain,
                &req,
                remote_addr.ip(),
                server_port,
                stream,
            );
            let res = match cgi_route.timeout {
                Some(timeout) => {
                    fut.or(async {
                        Timer::after(Duration::from_secs(timeout)).await;
                        Err(GemError::CGIError(anyhow::anyhow!(
                            "CGI Process timed out"
                        )))
                    })
                    .await
                }
                None => fut.await,
            };
            // Convert ServerErrors from serve_cgi to CGIErrors
            match res {
                Ok(log_info) => Ok(log_info),
                Err(GemError::ServerError(e)) => Err(GemError::CGIError(e)),
                Err(e) => Err(e),
            }?
        }
        RouteType::Scgi(scgi_route) => {
            let fut = serve_scgi(
                scgi_route,
                &route.domain,
                &req,
                &route.path,
                remote_addr.ip(),
                server_port,
                stream,
            );
            let res = match scgi_route.timeout {
                Some(timeout) => {
                    fut.or(async {
                        Timer::after(Duration::from_secs(timeout)).await;
                        Err(GemError::CGIError(anyhow::anyhow!(
                            "SCGI Request timed out"
                        )))
                    })
                    .await
                }
                None => fut.await,
            };
            // Convert ServerErrors from serve_cgi to CGIErrors
            match res {
                Ok(log_info) => Ok(log_info),
                Err(GemError::ServerError(e)) => Err(GemError::CGIError(e)),
                Err(e) => Err(e),
            }?
        }
        RouteType::Redirect(redirect_route) => {
            let mut output = Vec::with_capacity(1024);
            if redirect_route.permanent {
                output.extend_from_slice(b"31 ");
            } else {
                output.extend_from_slice(b"30 ");
            }
            let meta = if redirect_route.redirect_rewrite {
                let mut uri = uriparse::URIReference::try_from(
                    redirect_route.url.as_str(),
                )
                .with_context(|| {
                    format!(
                        "Redirect url`{}` is not a valid uri",
                        redirect_route.url
                    )
                })
                .into_server_error()?;
                uri.set_path(req.path.as_str())
                    .with_context(|| format!("Path for redirect with `redirect-rewrite` set has an invalid path: {}", req.path))
                    .into_server_error()?;
                uri.to_string()
            } else {
                redirect_route.url.clone()
            };
            output.extend_from_slice(meta.as_bytes());
            output.extend_from_slice(b"\r\n");
            stream.write_all(&output).await.into_io_error()?;
            LogInfo {
                size: output.len(),
                meta: output,
                status: if redirect_route.permanent { 31 } else { 30 },
            }
        }
        RouteType::Static(static_route) => {
            get_file(
                static_route,
                &req.path,
                &route.lang,
                &route.charset,
                stream,
            )
            .await?
        }
    };

    stream
        .flush()
        .await
        .context("Error flushing socket")
        .into_io_error()?;
    Ok(log_info)
}
/// Read the request
///
/// Reads from `stream` until a CRLF is seen. Returns the buffer and the
/// length of the request URL, which excludes the CRLF. Bytes that arrived
/// after the CRLF in the same read stay in the buffer but are not counted.
/// A CRLF split across two reads is still recognised.
///
/// # Errors
/// - `bad_req("Request too long")` if the bytes received so far, including
///   the chunk being examined, exceed 1026 (a 1024-byte URL plus CRLF).
/// - `bad_req(...)` if the client closes the stream before sending CRLF.
///   Without the terminator there is no well-formed request to route.
/// - An IO error if reading from the stream fails.
pub async fn read_request(
    stream: &mut TlsStream<TcpStream>,
) -> Result<([u8; 1027], usize)> {
    use error::Context;
    // buffer size is 1024 max url len + 2 for \r\n + 1 byte to detect too long requests
    let mut input_buf = [0; 1027];
    let mut total_read = 0;
    let mut found_cr = false;

    loop {
        let read = stream
            .read(&mut input_buf[total_read..])
            .await
            .context("Error reading input from client")
            .into_io_error()?;
        if read + total_read > 1026 {
            return Err(bad_req("Request too long"));
        }
        if read == 0 {
            // EOF before CRLF: reject rather than guess where the URL ends
            // (stripping two bytes here would route a truncated URL).
            return Err(bad_req("Request is not terminated by CRLF"));
        }
        for (i, byte) in
            input_buf[total_read..total_read + read].iter().enumerate()
        {
            if found_cr && *byte == b'\n' {
                // `end` counts the CRLF; `found_cr` guarantees a preceding
                // '\r' was consumed, so `end >= 2` and the subtraction is safe.
                let end = total_read + i + 1;
                return Ok((input_buf, end - 2));
            }
            found_cr = *byte == b'\r';
        }
        total_read += read;
    }
}

/// Anonymize client IP address
///
/// IPv4 keeps the first two octets (a /16), IPv6 keeps the first three
/// 16-bit segments (a /48); every remaining bit is zeroed.
fn anon_ip(addr: IpAddr) -> IpAddr {
    // Leading 16 bits of an IPv4 address.
    const V4_KEEP: u32 = 0xFFFF_0000;
    // Leading 48 bits of an IPv6 address.
    const V6_KEEP: u128 = 0xFFFF_FFFF_FFFF_u128 << 80;

    match addr {
        IpAddr::V4(v4) => IpAddr::V4(Ipv4Addr::from(u32::from(v4) & V4_KEEP)),
        IpAddr::V6(v6) => IpAddr::V6(Ipv6Addr::from(u128::from(v6) & V6_KEEP)),
    }
}

/// Summary of a response, as recorded in the connection log.
pub struct LogInfo {
    /// Gemini status code of the response (e.g. 30/31 for redirects).
    status: u8,
    /// Meta field bytes; logged as `INVALID_UNICODE` if not valid UTF-8.
    meta: Vec<u8>,
    /// Response size in bytes, as reported by the handler that produced it.
    size: usize,
}

/// A [`LogInfo`] plus the request it answered and the client IP in the form
/// chosen for logging.
struct FullLogInfo {
    log_info: LogInfo,
    log_ip: IpAddr,
    req: Request,
}

impl FullLogInfo {
    /// Emit one tab-separated line at `info` level with these fields:
    /// client IP, host, path, size, status, meta.
    fn log(&self) {
        let FullLogInfo {
            log_info,
            log_ip,
            req,
        } = self;
        let meta =
            std::str::from_utf8(&log_info.meta).unwrap_or("INVALID_UNICODE");
        log::info!(
            "{}\t{}\t{}\t{}\t{}\t{}",
            log_ip,
            req.host,
            req.path,
            log_info.size,
            log_info.status,
            meta
        );
    }
}

/// Failure outcome of a connection: the error to report to the client, and
/// optionally the log info for the request it belongs to.
struct ErrorLogInfo {
    log_info: Option<FullLogInfo>,
    error: GemError,
}

impl From<GemError> for ErrorLogInfo {
    /// Wrap a bare error with no log info attached; this lets `?` be used
    /// on `GemError` results where `ErrorLogInfo` is the error type.
    fn from(error: GemError) -> Self {
        Self {
            error,
            log_info: None,
        }
    }
}
