use clap::{Arg, Command, ErrorKind};
use futures::stream::StreamExt;
use indymilter::Listener;
use signal_hook::consts::{SIGHUP, SIGINT, SIGTERM};
use signal_hook_tokio::{Handle, Signals};
use spf_milter::{CliOptions, Config, Socket, MILTER_NAME, VERSION};
use std::{os::unix::fs::FileTypeExt, path::Path, process};
use tokio::{
    fs,
    net::{TcpListener, UnixListener},
    sync::{mpsc, oneshot},
    task::JoinHandle,
};

/// Argument ID of the `-c`/`--config-file` option.
const ARG_CONFIG_FILE: &str = "CONFIG_FILE";
/// Argument ID of the `-n`/`--dry-run` flag.
const ARG_DRY_RUN: &str = "DRY_RUN";
/// Argument ID of the `-l`/`--log-destination` option.
const ARG_LOG_DESTINATION: &str = "LOG_DESTINATION";
/// Argument ID of the `-L`/`--log-level` option.
const ARG_LOG_LEVEL: &str = "LOG_LEVEL";
/// Argument ID of the `-p`/`--socket` option.
const ARG_SOCKET: &str = "SOCKET";
/// Argument ID of the `-s`/`--syslog-facility` option.
const ARG_SYSLOG_FACILITY: &str = "SYSLOG_FACILITY";

#[tokio::main]
async fn main() {
    let command = Command::new(MILTER_NAME)
        .version(VERSION)
        .arg(Arg::new(ARG_CONFIG_FILE)
            .short('c')
            .long("config-file")
            .value_name("PATH")
            .help("Path to configuration file"))
        .arg(Arg::new(ARG_DRY_RUN)
            .short('n')
            .long("dry-run")
            .help("Process messages without taking action"))
        .arg(Arg::new(ARG_LOG_DESTINATION)
            .short('l')
            .long("log-destination")
            .value_name("TARGET")
            .help("Destination for log messages"))
        .arg(Arg::new(ARG_LOG_LEVEL)
            .short('L')
            .long("log-level")
            .value_name("LEVEL")
            .help("Minimum severity of messages to log"))
        .arg(Arg::new(ARG_SOCKET)
            .short('p')
            .long("socket")
            .value_name("SOCKET")
            .help("Listening socket of the milter"))
        .arg(Arg::new(ARG_SYSLOG_FACILITY)
            .short('s')
            .long("syslog-facility")
            .value_name("NAME")
            .help("Facility to use for syslog messages"));

    let opts = match build_opts(command) {
        Ok(opts) => opts,
        Err(e) => {
            e.exit();
        }
    };

    let config = match Config::read(opts).await {
        Ok(config) => config,
        Err(e) => {
            eprintln!("error: {}", e);
            process::exit(1);
        }
    };

    let (reload_tx, reload) = mpsc::channel(1);
    let (shutdown_tx, shutdown) = oneshot::channel();

    let signals =
        Signals::new(&[SIGHUP, SIGINT, SIGTERM]).expect("failed to install signal handler");
    let signals_handle = signals.handle();
    let signals_task = spawn_signals_task(signals, reload_tx, shutdown_tx);

    let addr;
    let mut socket_path = None;
    let listener = match config.socket() {
        Socket::Inet(socket) => {
            let listener = match TcpListener::bind(socket).await {
                Ok(listener) => listener,
                Err(e) => {
                    eprintln!("error: could not bind TCP socket: {}", e);
                    process::exit(1);
                }
            };

            Listener::Tcp(listener)
        }
        Socket::Unix(socket) => {
            // Before creating the socket file, try removing any existing socket
            // at the target path. This is to clear out a leftover file from a
            // previous, aborted execution.
            try_remove_socket(&socket).await;

            let listener = match UnixListener::bind(socket) {
                Ok(listener) => listener,
                Err(e) => {
                    eprintln!("error: could not create UNIX domain socket: {}", e);
                    process::exit(1);
                }
            };

            // Remember the socket file path, and delete it on shutdown.
            addr = listener.local_addr().unwrap();
            socket_path = addr.as_pathname();

            Listener::Unix(listener)
        }
    };

    let result = spf_milter::run(listener, config, reload, shutdown).await;

    cleanup(signals_handle, signals_task, socket_path).await;

    if let Err(e) = result {
        eprintln!("error: {}", e);
        process::exit(1);
    }
}

/// Parses the command line and converts the matches into [`CliOptions`].
///
/// clap itself reports syntactically malformed command lines (see
/// `Command::get_matches_mut`). Values that clap accepts but that fail to
/// parse into their target type are returned as an `ErrorKind::InvalidValue`
/// error for the caller to report.
///
/// # Errors
///
/// Returns an `InvalidValue` error for the first invalid value, checking the
/// socket, log destination, log level and syslog facility in that order.
fn build_opts(mut command: Command) -> clap::Result<CliOptions> {
    let matches = command.get_matches_mut();

    let mut opts = CliOptions::builder();

    if let Some(path) = matches.value_of(ARG_CONFIG_FILE) {
        opts = opts.config_file(path);
    }

    if matches.is_present(ARG_DRY_RUN) {
        opts = opts.dry_run(true);
    }

    if let Some(socket) = parse_arg(&mut command, &matches, ARG_SOCKET, "socket")? {
        opts = opts.socket(socket);
    }

    if let Some(target) =
        parse_arg(&mut command, &matches, ARG_LOG_DESTINATION, "log destination")?
    {
        opts = opts.log_destination(target);
    }

    if let Some(level) = parse_arg(&mut command, &matches, ARG_LOG_LEVEL, "log level")? {
        opts = opts.log_level(level);
    }

    if let Some(name) =
        parse_arg(&mut command, &matches, ARG_SYSLOG_FACILITY, "syslog facility")?
    {
        opts = opts.syslog_facility(name);
    }

    Ok(opts.build())
}

/// Parses the value of argument `arg` with `str::parse`, if it was given.
///
/// Returns `Ok(None)` when the argument is absent. If the value does not parse,
/// returns an `InvalidValue` error reading `Invalid value for <description>:
/// "<value>"`.
fn parse_arg<T: std::str::FromStr>(
    command: &mut Command,
    matches: &clap::ArgMatches,
    arg: &str,
    description: &str,
) -> clap::Result<Option<T>> {
    match matches.value_of(arg) {
        None => Ok(None),
        Some(value) => value.parse().map(Some).map_err(|_| {
            command.error(
                ErrorKind::InvalidValue,
                format!("Invalid value for {}: \"{}\"", description, value),
            )
        }),
    }
}

/// Spawns a task that turns process signals into milter control messages.
///
/// `SIGHUP` requests a configuration reload via `reload_config`. `SIGINT` and
/// `SIGTERM` request shutdown via `shutdown_milter`, after which the task
/// exits. The task also exits when the `signals` stream ends.
///
/// # Panics
///
/// The spawned task panics if it receives any other signal, which cannot
/// happen as long as `signals` was registered for exactly these three.
fn spawn_signals_task(
    mut signals: Signals,
    reload_config: mpsc::Sender<()>,
    shutdown_milter: oneshot::Sender<()>,
) -> JoinHandle<()> {
    tokio::spawn(async move {
        while let Some(signal) = signals.next().await {
            match signal {
                SIGHUP => {
                    // Do not wait for channel capacity. If the channel is full,
                    // a reload is already pending and another is redundant.
                    // Blocking here would also delay a later SIGINT/SIGTERM.
                    // A closed channel (milter already stopped) is ignored too.
                    let _ = reload_config.try_send(());
                }
                SIGINT | SIGTERM => {
                    // The receiver may be gone if the milter already stopped;
                    // there is nothing left to do in that case.
                    let _ = shutdown_milter.send(());
                    break;
                }
                _ => unreachable!("unexpected signal {}", signal),
            }
        }
    })
}

/// Shuts down signal handling and deletes the UNIX domain socket file, if any.
///
/// `signals_handle` is closed first, which should end the signal stream so that
/// `signals_task` can finish. The task is then awaited. If `socket_path` is
/// given, the socket file at that path is removed on a best-effort basis.
///
/// # Panics
///
/// Panics if the signal handler task itself panicked or was cancelled.
async fn cleanup(
    signals_handle: Handle,
    signals_task: JoinHandle<()>,
    socket_path: Option<&Path>,
) {
    // Stop signal delivery, then wait for the handler task to wind down.
    signals_handle.close();
    signals_task
        .await
        .expect("signal handler task failed");

    if let Some(path) = socket_path {
        try_remove_socket(path).await;
    }
}

/// Deletes the file at `path`, but only if it is a socket.
///
/// Best-effort: a missing file, unreadable metadata, or a failed removal are
/// all silently ignored. The file type is determined with `fs::metadata`, so
/// symbolic links are followed for the check.
async fn try_remove_socket(path: impl AsRef<Path>) {
    let is_socket = matches!(
        fs::metadata(&path).await,
        Ok(meta) if meta.file_type().is_socket()
    );

    if is_socket {
        // Removal errors are deliberately ignored.
        let _ = fs::remove_file(path).await;
    }
}
