/* SPDX-License-Identifier: MIT */

use crate::config::{Config, SpoolConfig};
use anyhow::{anyhow, bail};
use axum::body::StreamBody;
use axum::extract::Form;
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use axum::Extension;
use serde::Deserialize;
use std::process::Stdio;
use time::format_description::well_known::Rfc3339;
use time::{Duration, OffsetDateTime};
use tokio::io::AsyncBufReadExt;
use tokio::io::AsyncReadExt;
use tokio::io::BufReader;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{error, info};

/// Form parameters accepted by `fetch`.
///
/// On the wire every field name is kebab-case (`start-time`, `duration-before`, ...).
// Only `filter`, `start-time`, `duration` and `spool` are read by `fetch` at the
// moment; the remaining fields are accepted but unused, hence the `dead_code` allow.
#[derive(Deserialize, Debug)]
#[allow(dead_code)]
#[serde(rename_all = "kebab-case")]
pub struct FetchRequest {
    default_timezone_offset: String,
    query_type: String,
    /// Passed through verbatim as the export command's `--filter` argument.
    filter: String,
    /// Parsed by `fetch` as an RFC 3339 timestamp.
    start_time: String,
    /// Parsed by `parse_duration`, e.g. `30s`, `5m`, `1h`.
    duration: String,
    event: Option<String>,
    duration_before: String,
    duration_after: String,
    /// Name of a spool from the server `Config`.
    spool: String,
}

/// Error returned by request handlers when the client sent invalid input.
///
/// Its `IntoResponse` impl turns it into a `400 Bad Request` whose body is `reason`.
#[derive(Debug)]
pub struct BadRequest {
    reason: String,
}

impl BadRequest {
    /// Build a `BadRequest` holding an owned copy of `reason` as the client-facing message.
    fn new<S: AsRef<str>>(reason: S) -> Self {
        let reason = String::from(reason.as_ref());
        BadRequest { reason }
    }
}

impl IntoResponse for BadRequest {
    fn into_response(self) -> Response {
        (StatusCode::BAD_REQUEST, self.reason).into_response()
    }
}

/// Look up a configured spool by name.
///
/// Returns the first entry of `config.spools` whose `name` equals `name`
/// exactly, or `None` when no spool matches.
fn get_spool_config<'a>(config: &'a Config, name: &str) -> Option<&'a SpoolConfig> {
    config.spools.iter().find(|spool| spool.name == name)
}

pub async fn fetch(
    Extension(config): Extension<Config>,
    Form(params): Form<FetchRequest>,
) -> Result<impl IntoResponse, BadRequest> {
    let spool_config = get_spool_config(&config, &params.spool).ok_or_else(|| {
        error!("No spool with the name: {}", &params.spool);
        BadRequest::new(&format!("invalid spool name: {}", &params.spool))
    })?;

    dbg!(&params);
    let process_name = std::env::args().next().as_ref().unwrap().to_string();

    let start_time = OffsetDateTime::parse(&params.start_time, &Rfc3339).map_err(|err| {
        error!("Invalid start-time: {} -- {:?}", &params.start_time, err);
        BadRequest::new("invalid start-time")
    })?;

    let duration = parse_duration(&params.duration).map_err(|err| {
        error!("Invalid duration: {} -- {:?}", &params.duration, err);
        BadRequest::new("invalid duration")
    })?;

    let mut command = tokio::process::Command::new(process_name);
    command
        .arg("export")
        .arg("--json")
        .arg("-v")
        .arg("--directory")
        .arg(&spool_config.directory)
        .arg("--filter")
        .arg(&params.filter)
        .arg("--start-time")
        .arg(format!("{}", start_time.unix_timestamp()))
        .arg("--duration")
        .arg(format!("{}", duration.as_seconds_f64() as i64))
        .arg("--output")
        .arg("-");
    if let Some(prefix) = &spool_config.prefix {
        command.arg("--prefix").arg(prefix);
    }
    command.stdout(Stdio::piped());
    command.stderr(Stdio::piped());
    let mut child = command.spawn().unwrap();
    let mut stdout = child.stdout.take().unwrap();
    let stderr = child.stderr.take().unwrap();

    let (tx, rx) =
        tokio::sync::mpsc::unbounded_channel::<std::result::Result<Vec<u8>, std::io::Error>>();
    let rx = UnboundedReceiverStream::new(rx);
    let body = StreamBody::new(rx);
    let mut stderr_reader = BufReader::new(stderr).lines();
    let (wait_tx, wait_rx) = tokio::sync::oneshot::channel::<&str>();

    tokio::spawn(async move {
        let mut bytes = 0;
        let mut wait_tx = Some(wait_tx);
        let mut client_closed = false;
        loop {
            let mut buf = Vec::with_capacity(8192);
            tokio::select! {
                _ = child.wait() => {
                    break;
                }
                _ = tx.closed() => {
                    client_closed = true;
                    break;
                }
                closed = read_stderr(&mut stderr_reader) => {
                    if closed {
                        break;
                    }
                }
                x = stdout.read_buf(&mut buf) => {
                    match x {
                        Ok(n) => {
                            if n == 0 {
                                break;
                            } else {
                                if let Some(wait_tx) = wait_tx.take() {
                                    wait_tx.send("ok").unwrap();
                                }
                                if tx.send(Ok(buf)).is_err() {
                                    error!("Failed to write to body stream, client must have closed connection");
                                    client_closed = true;
                                    break;
                                }
                                bytes += n;
                            }
                        }
                        Err(err) => {
                            error!("Error reading export process stdout: {:?}", err);
                            break;
                        }
                    }
                }
            }
        }

        if client_closed {
            let _ = child.start_kill();
        }
        let status = child.wait().await.unwrap();
        if status.success() {
            info!(
                "Export process exited successfully, bytes written: {}",
                bytes
            );
        } else {
            error!("Export process exited with error code: {:?}", status);
        }
        if let Some(wait_tx) = wait_tx.take() {
            if status.success() {
                wait_tx.send("nopkt").unwrap();
            } else {
                wait_tx.send("err").unwrap();
            }
        }
    });

    Ok(match wait_rx.await {
        Err(err) => {
            error!("Error on wait channel: {:?}", err);
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                "An error occurred. See server logs for details",
            )
                .into_response()
        }
        Ok(status) => {
            if status == "ok" {
                body.into_response()
            } else if status == "nopkt" {
                (StatusCode::NOT_FOUND, "No packets found.").into_response()
            } else {
                (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    "An error occurred. See server logs for details",
                )
                    .into_response()
            }
        }
    })
}

async fn read_stderr(
    reader: &mut tokio::io::Lines<tokio::io::BufReader<tokio::process::ChildStderr>>,
) -> bool {
    match reader.next_line().await {
        Ok(Some(next)) => {
            info!("export process: {}", next);
            false
        }
        Ok(None) => true,
        Err(err) => {
            error!("Failed to export process stderr: {:?}", err);
            true
        }
    }
}

fn parse_duration(s: &str) -> anyhow::Result<Duration> {
    let re = regex::Regex::new(r"^(\d+)(.+)$")?;
    let captures = re
        .captures(s)
        .ok_or_else(|| anyhow!("invalid duration string: {}", s))?;
    let value = captures.get(1).unwrap().as_str().parse::<i64>()?;
    let unit = captures.get(2).unwrap().as_str();

    match unit {
        "s" => Ok(time::Duration::seconds(value)),
        "m" => Ok(time::Duration::minutes(value)),
        "h" => Ok(time::Duration::hours(value)),
        _ => bail!("invalid duration unit: {}", unit),
    }
}

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_parse_duration() {
        // Well-formed inputs and the durations they must produce.
        let valid = [
            ("1s", Duration::seconds(1)),
            ("1m", Duration::seconds(60)),
            ("1h", Duration::seconds(3600)),
        ];
        for (input, expected) in valid {
            assert_eq!(parse_duration(input).unwrap(), expected, "input: {}", input);
        }

        // Missing unit, unknown unit, missing value, value too large for i64.
        let invalid = ["1", "1z", "am", "99999999999999999999999999999m"];
        for input in invalid {
            assert!(parse_duration(input).is_err(), "expected error for {:?}", input);
        }
    }
}
