use clap::{Arg, Command, ErrorKind};
use futures::stream::StreamExt;
use indymilter::Listener;
use signal_hook::consts::{SIGINT, SIGTERM};
use signal_hook_tokio::{Handle, Signals};
use spamassassin_milter::{Config, MILTER_NAME, VERSION};
use std::{net::IpAddr, os::unix::fs::FileTypeExt, path::Path, process, str::FromStr};
use tokio::{
    fs,
    net::{TcpListener, UnixListener},
    sync::oneshot,
    task::JoinHandle,
};

// Identifiers (clap argument IDs) of all command-line arguments. These are the
// internal keys used to look up matches; the user-facing short/long flag names
// are configured separately on each `Arg` in `main`.

// Boolean flags.
const ARG_AUTH_UNTRUSTED: &str = "AUTH_UNTRUSTED";
const ARG_DRY_RUN: &str = "DRY_RUN";
const ARG_PRESERVE_BODY: &str = "PRESERVE_BODY";
const ARG_PRESERVE_HEADERS: &str = "PRESERVE_HEADERS";
const ARG_REJECT_SPAM: &str = "REJECT_SPAM";
const ARG_VERBOSE: &str = "VERBOSE";

// Options that take a value.
const ARG_MAX_MESSAGE_SIZE: &str = "MAX_MESSAGE_SIZE";
const ARG_REPLY_CODE: &str = "REPLY_CODE";
const ARG_REPLY_STATUS_CODE: &str = "REPLY_STATUS_CODE";
const ARG_REPLY_TEXT: &str = "REPLY_TEXT";
const ARG_TRUSTED_NETWORKS: &str = "TRUSTED_NETWORKS";

// Positional arguments.
const ARG_SOCKET: &str = "SOCKET";
const ARG_SPAMC_ARGS: &str = "SPAMC_ARGS";

/// Entry point of the milter daemon.
///
/// It runs these steps in order:
/// 1. Parse the command line.
/// 2. Install SIGTERM/SIGINT handling.
/// 3. Bind the listening socket (TCP or UNIX domain).
/// 4. Run the milter until it returns.
/// 5. Clean up.
///
/// Invalid command-line arguments are reported by clap and end the process via
/// `clap::Error::exit`. A failure to bind the socket, or a milter error, ends
/// the process with exit status 1.
#[tokio::main]
async fn main() {
    // Command-line interface. Flags and options are turned into `Config`
    // settings by `build_config`. `SOCKET` is the required positional listening
    // socket. The trailing `SPAMC_ARGS` (`last(true)`, i.e. after `--`) are
    // passed through to spamc.
    let command = Command::new(MILTER_NAME)
        .version(VERSION)
        .arg(Arg::new(ARG_AUTH_UNTRUSTED)
            .short('a')
            .long("auth-untrusted")
            .help("Treat authenticated senders as untrusted"))
        .arg(Arg::new(ARG_DRY_RUN)
            .short('n')
            .long("dry-run")
            .help("Process messages without applying changes"))
        .arg(Arg::new(ARG_MAX_MESSAGE_SIZE)
            .short('s')
            .long("max-message-size")
            .value_name("BYTES")
            .help("Maximum message size to process"))
        .arg(Arg::new(ARG_PRESERVE_BODY)
            .short('B')
            .long("preserve-body")
            .help("Suppress rewriting of message body"))
        .arg(Arg::new(ARG_PRESERVE_HEADERS)
            .short('H')
            .long("preserve-headers")
            .help("Suppress rewriting of Subject/From/To headers"))
        .arg(Arg::new(ARG_REJECT_SPAM)
            .short('r')
            .long("reject-spam")
            .help("Reject messages flagged as spam"))
        .arg(Arg::new(ARG_REPLY_CODE)
            .short('C')
            .long("reply-code")
            .value_name("CODE")
            .help("Reply code when rejecting messages"))
        .arg(Arg::new(ARG_REPLY_STATUS_CODE)
            .short('S')
            .long("reply-status-code")
            .value_name("CODE")
            .help("Status code when rejecting messages"))
        .arg(Arg::new(ARG_REPLY_TEXT)
            .short('R')
            .long("reply-text")
            .value_name("MSG")
            .help("Reply text when rejecting messages"))
        .arg(Arg::new(ARG_TRUSTED_NETWORKS)
            .short('t')
            .long("trusted-networks")
            .value_name("NETS")
            .use_value_delimiter(true)
            .help("Trust connections from these networks"))
        .arg(Arg::new(ARG_VERBOSE)
            .short('v')
            .long("verbose")
            .help("Enable verbose operation logging"))
        .arg(Arg::new(ARG_SOCKET)
            .required(true)
            .help("Listening socket of the milter"))
        .arg(Arg::new(ARG_SPAMC_ARGS)
            .last(true)
            .multiple_occurrences(true)
            .help("Additional arguments to pass to spamc"));

    // Validate the arguments. On error, clap prints its diagnostic and exits.
    let (socket, config) = match build_config(command) {
        Ok(config) => config,
        Err(e) => {
            e.exit();
        }
    };

    // The signal task sends on `shutdown_tx`. The receiving half, `shutdown`,
    // is handed to `spamassassin_milter::run`.
    let (shutdown_tx, shutdown) = oneshot::channel();

    // Forward SIGTERM/SIGINT to the milter as a shutdown request. The handle is
    // kept so that `cleanup` can later close the signal stream.
    let signals = Signals::new(&[SIGTERM, SIGINT]).expect("failed to install signal handler");
    let signals_handle = signals.handle();
    let signals_task = spawn_signals_task(signals, shutdown_tx);

    // `addr` is declared here, outside the `match` below, because
    // `socket_path` borrows from it. That borrow must stay valid until
    // `cleanup` runs at the end of `main`.
    let addr;
    // Only set for UNIX domain sockets; `None` means there is no file to remove.
    let mut socket_path = None;
    let listener = match socket {
        Socket::Inet(socket) => {
            let listener = match TcpListener::bind(socket).await {
                Ok(listener) => listener,
                Err(e) => {
                    eprintln!("error: could not bind TCP socket: {}", e);
                    process::exit(1);
                }
            };

            Listener::Tcp(listener)
        }
        Socket::Unix(socket) => {
            // Before creating the socket file, try removing any existing socket
            // at the target path. This is to clear out a leftover file from a
            // previous, aborted execution.
            try_remove_socket(&socket).await;

            let listener = match UnixListener::bind(socket) {
                Ok(listener) => listener,
                Err(e) => {
                    eprintln!("error: could not create UNIX domain socket: {}", e);
                    process::exit(1);
                }
            };

            // Remember the socket file path, and delete it on shutdown.
            // `as_pathname` yields `None` for an address without a filesystem
            // path; in that case nothing is removed on shutdown.
            // NOTE(review): `unwrap` panics if the local address cannot be queried.
            addr = listener.local_addr().unwrap();
            socket_path = addr.as_pathname();

            Listener::Unix(listener)
        }
    };

    eprintln!("{} {} starting", MILTER_NAME, VERSION);

    // Serve connections. This presumably returns once `shutdown` fires or the
    // milter fails.
    let result = spamassassin_milter::run(listener, config, shutdown).await;

    // Clean up before reporting the result, so that the signal task is joined
    // and the socket file removed even when the milter ended with an error.
    cleanup(signals_handle, signals_task, socket_path).await;

    match result {
        Ok(()) => {
            eprintln!("{} {} shut down", MILTER_NAME, VERSION);
        }
        Err(e) => {
            eprintln!("{} {} terminated with error: {}", MILTER_NAME, VERSION, e);
            process::exit(1);
        }
    }
}

/// Listening socket of the milter, as given on the command line.
#[derive(Clone, Debug, Eq, PartialEq)]
enum Socket {
    /// TCP socket, written `inet:<address>`; the address is passed to `TcpListener::bind`.
    Inet(String),
    /// UNIX domain socket, written `unix:<path>`; the path is passed to `UnixListener::bind`.
    Unix(String),
}

impl FromStr for Socket {
    type Err = ();

    /// Parses `inet:<address>` or `unix:<path>`.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` in two cases:
    /// - the prefix is missing or unknown;
    /// - nothing follows the prefix.
    ///
    /// An empty address or path can never be bound. Rejecting it here turns it
    /// into an argument error instead of a later bind failure.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if let Some(addr) = s.strip_prefix("inet:").filter(|a| !a.is_empty()) {
            Ok(Self::Inet(addr.into()))
        } else if let Some(path) = s.strip_prefix("unix:").filter(|p| !p.is_empty()) {
            Ok(Self::Unix(path.into()))
        } else {
            Err(())
        }
    }
}

/// Parses the command line into the listening socket and the milter configuration.
///
/// # Errors
///
/// Returns an `InvalidValue` error created from `command` if any of these is
/// invalid:
/// - the socket;
/// - the maximum message size;
/// - a trusted network entry;
/// - the reply code / status code combination.
///
/// Syntax errors on the command line itself are handled by `get_matches_mut`,
/// which exits the process.
fn build_config(mut command: Command) -> clap::Result<(Socket, Config)> {
    let matches = command.get_matches_mut();

    // `SOCKET` is a required argument, so clap guarantees it is present here.
    let socket_arg = matches.value_of(ARG_SOCKET).unwrap();
    let socket: Socket = socket_arg.parse().map_err(|()| {
        command.error(
            ErrorKind::InvalidValue,
            format!("Invalid value for socket: \"{}\"", socket_arg),
        )
    })?;

    let mut builder = Config::builder();

    if let Some(size_arg) = matches.value_of(ARG_MAX_MESSAGE_SIZE) {
        let size = size_arg.parse().map_err(|_| {
            command.error(
                ErrorKind::InvalidValue,
                format!("Invalid value for max message size: \"{}\"", size_arg),
            )
        })?;
        builder = builder.max_message_size(size);
    }

    if let Some(net_args) = matches.values_of(ARG_TRUSTED_NETWORKS) {
        // Trusted networks are enabled even if every entry turns out to be empty.
        builder = builder.use_trusted_networks(true);

        for net_arg in net_args.filter(|n| !n.is_empty()) {
            // Accept network notation first; fall back to a bare IP address.
            let net = net_arg
                .parse()
                .or_else(|_| net_arg.parse::<IpAddr>().map(From::from))
                .map_err(|_| {
                    command.error(
                        ErrorKind::InvalidValue,
                        format!("Invalid value for trusted network address: \"{}\"", net_arg),
                    )
                })?;
            builder = builder.trusted_network(net);
        }
    }

    let reply_code = matches.value_of(ARG_REPLY_CODE);
    let reply_status_code = matches.value_of(ARG_REPLY_STATUS_CODE);
    validate_reply_codes(&mut command, reply_code, reply_status_code)?;

    let is_set = |id: &str| matches.is_present(id);

    if is_set(ARG_AUTH_UNTRUSTED) {
        builder = builder.auth_untrusted(true);
    }
    if is_set(ARG_DRY_RUN) {
        builder = builder.dry_run(true);
    }
    if is_set(ARG_PRESERVE_BODY) {
        builder = builder.preserve_body(true);
    }
    if is_set(ARG_PRESERVE_HEADERS) {
        builder = builder.preserve_headers(true);
    }
    if is_set(ARG_REJECT_SPAM) {
        builder = builder.reject_spam(true);
    }
    if is_set(ARG_VERBOSE) {
        builder = builder.verbose(true);
    }

    if let Some(code) = reply_code {
        builder = builder.reply_code(code.to_owned());
    }
    if let Some(code) = reply_status_code {
        builder = builder.reply_status_code(code.to_owned());
    }
    if let Some(text) = matches.value_of(ARG_REPLY_TEXT) {
        builder = builder.reply_text(text.to_owned());
    }
    if let Some(extra_args) = matches.values_of(ARG_SPAMC_ARGS) {
        builder = builder.spamc_args(extra_args);
    }

    Ok((socket, builder.build()))
}

/// Validates the reply code (`-C`) and the status code (`-S`) used when
/// rejecting messages.
///
/// Accepted formats:
/// - A reply code is three ASCII digits (SMTP reply code).
/// - A status code is an RFC 3463 enhanced status code `class.subject.detail`:
///   a one-digit class, and a subject and detail of one to three digits each.
///
/// Combination rules:
/// - If both are given, the reply code must start with `4` or `5`, and the
///   status code must start with the same digit.
/// - If only one is given, it must start with `5` (`5XX` / `5.X.X`). This
///   presumably matches a permanent-failure default for the omitted one
///   (TODO confirm against `Config`).
///
/// # Errors
///
/// Returns an `InvalidValue` error, created from `command`, naming the
/// offending value(s).
fn validate_reply_codes(
    command: &mut Command,
    reply_code: Option<&str>,
    reply_status_code: Option<&str>,
) -> clap::Result<()> {
    /// Whether `s` consists of between 1 and `max_len` ASCII digits.
    fn is_digits(s: &str, max_len: usize) -> bool {
        (1..=max_len).contains(&s.len()) && s.bytes().all(|b| b.is_ascii_digit())
    }

    /// Whether `code` is an SMTP reply code: exactly three ASCII digits.
    fn is_reply_code(code: &str) -> bool {
        code.len() == 3 && is_digits(code, 3)
    }

    /// Whether `code` is an enhanced status code `class.subject.detail`.
    fn is_status_code(code: &str) -> bool {
        let mut parts = code.split('.');
        match (parts.next(), parts.next(), parts.next(), parts.next()) {
            (Some(class), Some(subject), Some(detail), None) => {
                is_digits(class, 1) && is_digits(subject, 3) && is_digits(detail, 3)
            }
            _ => false,
        }
    }

    match (reply_code, reply_status_code) {
        (Some(c1), Some(c2))
            if !(is_reply_code(c1)
                && is_status_code(c2)
                && (c1.starts_with('4') || c1.starts_with('5'))
                && c1.get(..1) == c2.get(..1)) =>
        {
            Err(command.error(
                ErrorKind::InvalidValue,
                format!(
                    "Invalid or incompatible values for reply code and status code: \"{}\", \"{}\"",
                    c1, c2
                ),
            ))
        }
        (Some(c), None) if !(is_reply_code(c) && c.starts_with('5')) => Err(command.error(
            ErrorKind::InvalidValue,
            format!("Invalid value for reply code (5XX): \"{}\"", c),
        )),
        (None, Some(c)) if !(is_status_code(c) && c.starts_with('5')) => Err(command.error(
            ErrorKind::InvalidValue,
            format!("Invalid value for reply status code (5.X.X): \"{}\"", c),
        )),
        _ => Ok(()),
    }
}

/// Spawns a task that waits for SIGINT or SIGTERM and then requests milter
/// shutdown by sending on `shutdown_milter`.
///
/// The task acts on at most one signal. It also finishes without sending if
/// the signal stream ends first, e.g. after its `Handle` has been closed.
///
/// # Panics
///
/// The spawned task panics if the stream yields any other signal.
fn spawn_signals_task(
    mut signals: Signals,
    shutdown_milter: oneshot::Sender<()>,
) -> JoinHandle<()> {
    tokio::spawn(async move {
        // Only the first delivered signal is ever handled.
        if let Some(received) = signals.next().await {
            if matches!(received, SIGINT | SIGTERM) {
                // `send` fails only if the receiver was already dropped; the
                // result is deliberately ignored.
                let _ = shutdown_milter.send(());
            } else {
                panic!("unexpected signal");
            }
        }
    })
}

/// Stops signal handling and removes the UNIX domain socket file, if any.
///
/// Closing `signals_handle` ends the signal stream. That lets `signals_task`
/// finish, so it can be awaited here.
///
/// # Panics
///
/// Panics if the signal task panicked or was cancelled.
async fn cleanup(signals_handle: Handle, signals_task: JoinHandle<()>, socket_path: Option<&Path>) {
    signals_handle.close();

    let join_result = signals_task.await;
    join_result.expect("signal handler task failed");

    // Only UNIX domain sockets leave a file behind.
    if let Some(path) = socket_path {
        try_remove_socket(path).await;
    }
}

/// Removes the file at `path`, but only if it exists and is a UNIX domain socket.
///
/// This is best effort: errors reading the metadata or removing the file are
/// ignored. `fs::metadata` follows symlinks, so the socket check applies to
/// the link target.
async fn try_remove_socket(path: impl AsRef<Path>) {
    let is_socket = match fs::metadata(&path).await {
        Ok(meta) => meta.file_type().is_socket(),
        Err(_) => false,
    };

    if is_socket {
        // Best effort: any removal error is ignored.
        let _ = fs::remove_file(path).await;
    }
}
