// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
//
// This Source Code Form is "Incompatible With Secondary Licenses", as
// defined by the Mozilla Public License, v. 2.0.

//! Async SCGI Client & Server
//!
//! This library will work with any async runtime that uses the [`futures-io`](https://crates.io/crates/futures-io)
//! library I/O traits.
//!
//! This crate provides two main tools:
//! - The [`ScgiRequest`] type to read & write SCGI requests.
//! - The [`read_request`] function to read an SCGI request from a socket.
//!
//! ## Client Example
//!
//! ```no_run
//! # use std::str::from_utf8;
//! # use futures_lite::{AsyncReadExt, AsyncWriteExt};
//! # use smol::net::TcpStream;
//! use async_scgi::{ScgiHeaders, ScgiRequest};
//!
//! # fn main() -> anyhow::Result<()> {
//! # smol::block_on(async {
//! let mut stream = TcpStream::connect("127.0.0.1:12345").await?;
//! let mut headers = ScgiHeaders::new();
//! headers.insert("PATH_INFO".to_owned(), "/".to_owned());
//! headers.insert("SERVER_NAME".to_owned(), "example.com".to_owned());
//! let body = b"Hello world!";
//! let req = ScgiRequest {
//!     headers,
//!     body: body.to_vec(),
//! };
//! stream.write_all(&req.encode()).await?;
//! let mut resp = vec![];
//! stream.read_to_end(&mut resp).await?;
//! let resp_str = from_utf8(&resp)?;
//! println!("{}", resp_str);
//! # Ok(())
//! # })
//! # }
//! ```
//!
//! ## Server Example
//!
//! ```no_run
//! # use futures_lite::{AsyncWriteExt, StreamExt};
//! # use smol::io::BufReader;
//! # use smol::net::TcpListener;
//! # use std::str::from_utf8;
//! #
//! # fn main() -> anyhow::Result<()> {
//! # smol::block_on(async {
//! let listener = TcpListener::bind("127.0.0.1:12345").await?;
//! let mut incoming = listener.incoming();
//! while let Some(stream) = incoming.next().await {
//!     let mut stream = BufReader::new(stream?);
//!     let req = async_scgi::read_request(&mut stream).await?;
//!     println!("Headers: {:?}", req.headers);
//!     println!("Body: {}", from_utf8(&req.body).unwrap());
//!     stream.write_all(b"Hello Client!").await?;
//! }
//! # Ok(())
//! # })
//! # }
//! ```

#[cfg(test)]
use std::collections::BTreeMap;
use std::collections::HashMap;
use std::str::{from_utf8, Utf8Error};

use futures_lite::{AsyncBufRead, AsyncBufReadExt, AsyncReadExt};
use headers::{encode_headers, ScgiHeaderParseError};
use thiserror::Error;

/// An error that occurred while reading an SCGI request with [`read_request`].
#[derive(Error, Debug)]
pub enum ScgiReadError {
    /// The netstring length prefix (the digits before the `:`) can't be
    /// decoded to an unsigned integer.
    #[error("Length can't be decoded to an integer")]
    BadLength,
    /// The netstring length prefix is not valid UTF-8.
    ///
    /// UTF-8 errors inside the header block itself are reported as
    /// [`ScgiReadError::BadHeaders`] instead.
    #[error("The length or the headers are not in UTF-8")]
    Utf8(#[from] Utf8Error),
    /// Netstring framing is malformed: the length prefix is not terminated by
    /// `:`, or the header block is not followed by `,`.
    #[error("Netstring sanity checks fail")]
    BadNetstring,
    /// The header block could not be parsed; see the wrapped
    /// [`ScgiHeaderParseError`] for the reason.
    #[error("Error parsing SCGI headers")]
    BadHeaders(#[from] ScgiHeaderParseError),
    /// An I/O error from the underlying stream, e.g.
    /// [`std::io::ErrorKind::UnexpectedEof`] if the stream ends before the
    /// declared headers or body have been read. The display message does not
    /// include the inner error; it is available via `Error::source`.
    #[error("IO Error")]
    IO(#[from] std::io::Error),
}

/// An ScgiRequest header map, from header name to header value.
#[cfg(not(test))]
pub type ScgiHeaders = HashMap<String, String>;
// This crate's own test builds use an ordered map so that header iteration
// order (and therefore the bytes produced by `ScgiRequest::encode`) is
// deterministic and can be compared byte-for-byte against a fixture.
// `HashMap`'s iteration order is randomized, so it cannot be used there.
#[cfg(test)]
pub type ScgiHeaders = BTreeMap<String, String>;

/// An SCGI request.
///
/// The `SCGI` and `CONTENT_LENGTH` headers are added automatically when
/// [`ScgiRequest::encode`] is called, and are stripped from
/// [`ScgiRequest::headers`] when requests are read with [`read_request`].
#[derive(Debug, Default, PartialEq, Eq)]
pub struct ScgiRequest {
    /// The request header name, value pairs.
    pub headers: ScgiHeaders,
    /// The request body. Its length is sent as the `CONTENT_LENGTH` header.
    pub body: Vec<u8>,
}

impl ScgiRequest {
    /// Create an empty `ScgiRequest`: no headers and an empty body.
    pub fn new() -> Self {
        Self::default()
    }

    /// Create an `ScgiRequest` with the given headers and an empty body.
    pub fn from_headers(headers: ScgiHeaders) -> Self {
        Self {
            headers,
            body: Vec::new(),
        }
    }

    /// Encode an `ScgiRequest` to be sent over the wire.
    ///
    /// The output is the header block framed as a netstring
    /// (`<len>:<headers>,`) followed immediately by the raw body. The `SCGI`
    /// and `CONTENT_LENGTH` headers are generated by
    /// [`headers::encode_headers`].
    pub fn encode(&self) -> Vec<u8> {
        let headers = encode_headers(&self.headers, self.body.len());
        let len_prefix = headers.len().to_string();
        // Reserve room for the whole message up front: length prefix, ':',
        // headers, ',' and the body. The body is usually the largest part, so
        // leaving it out of the estimate would force a reallocation.
        let mut buf =
            Vec::with_capacity(len_prefix.len() + 1 + headers.len() + 1 + self.body.len());
        buf.extend_from_slice(len_prefix.as_bytes());
        buf.push(b':');
        buf.extend_from_slice(&headers);
        buf.push(b',');
        buf.extend_from_slice(&self.body);
        buf
    }
}

/// Read an SCGI request.
pub async fn read_request<S: AsyncBufRead + Unpin>(
    stream: &mut S,
) -> Result<ScgiRequest, ScgiReadError> {
    let mut len_part = Vec::with_capacity(10);
    let read = stream.read_until(b':', &mut len_part).await?;
    if len_part[read - 1] != b':' {
        return Err(ScgiReadError::BadNetstring);
    }
    let length = from_utf8(&len_part[..read - 1])?
        .parse::<usize>()
        .map_err(|_| ScgiReadError::BadLength)?;
    let mut headers = vec![0; length];
    stream.read_exact(&mut headers).await?;
    let mut end_delim = 0;
    stream
        .read_exact(std::slice::from_mut(&mut end_delim))
        .await?;
    if end_delim != b',' {
        return Err(ScgiReadError::BadNetstring);
    }
    let (headers, content_length) = headers::header_string_map(&headers)?;
    let mut body = vec![0; content_length];
    stream.read_exact(&mut body).await?;
    Ok(ScgiRequest { headers, body })
}

/// Functions for working with SCGI headers directly.
///
/// On the wire, the header block is a sequence of `name\0value\0` entries.
pub mod headers {
    use super::*;
    use memchr::memchr;

    /// An error that occurred while parsing SCGI headers.
    #[derive(Error, Debug)]
    pub enum ScgiHeaderParseError {
        /// A header name or value is not valid UTF-8.
        #[error("The length or the headers are not in UTF-8")]
        Utf8(#[from] Utf8Error),
        /// A header name or value is missing its terminating NUL byte.
        #[error("Error parsing the null-terminated headers")]
        BadHeaderVals,
        /// CONTENT_LENGTH can't be decoded to an integer.
        #[error("CONTENT_LENGTH can't be decoded to an integer")]
        BadLength,
        /// CONTENT_LENGTH header was missing.
        #[error("CONTENT_LENGTH header was missing")]
        NoLength,
    }

    /// Encode headers to be sent over the wire.
    ///
    /// The `SCGI` and `CONTENT_LENGTH` headers are added automatically. Any
    /// entries with those names in `headers` are ignored, so the output never
    /// contains them twice.
    ///
    /// Names and values are written verbatim, so they must not contain NUL
    /// bytes. A NUL would split the entry and corrupt the header block.
    pub fn encode_headers(headers: &ScgiHeaders, content_length: usize) -> Vec<u8> {
        let mut buf = Vec::new();

        // Required protocol version header.
        push_header(&mut buf, "SCGI", "1");
        // Required body length header.
        // NOTE(review): the SCGI protocol spec says CONTENT_LENGTH must be the
        // *first* header. The crate's `encode_scgi_request` test compares the
        // output against a byte fixture, so swapping these two calls also
        // requires regenerating that fixture.
        push_header(&mut buf, "CONTENT_LENGTH", &content_length.to_string());

        for (name, value) in headers.iter() {
            // When parsing, the last CONTENT_LENGTH wins, so a caller-supplied
            // copy would override the real body length and desynchronise the
            // body read. A second SCGI header would merely be redundant.
            if matches!(name.as_str(), "SCGI" | "CONTENT_LENGTH") {
                continue;
            }
            push_header(&mut buf, name, value);
        }
        buf
    }

    /// Append one `name\0value\0` entry to `buf`.
    fn push_header(buf: &mut Vec<u8>, name: &str, value: &str) {
        buf.extend_from_slice(name.as_bytes());
        buf.push(0);
        buf.extend_from_slice(value.as_bytes());
        buf.push(0);
    }

    /// Parse the headers, invoking `headers_fn` with each `(name, value)`
    /// pair in wire order.
    ///
    /// Parsing stops at the first error, whether it comes from the header
    /// block itself or is returned by `headers_fn`.
    ///
    /// # Errors
    ///
    /// * [`ScgiHeaderParseError::BadHeaderVals`] if a name or value is not
    ///   terminated by a NUL byte.
    /// * [`ScgiHeaderParseError::Utf8`] if a name or value is not valid UTF-8.
    /// * Any error returned by `headers_fn`.
    pub fn parse_headers<'h>(
        raw_headers: &'h [u8],
        mut headers_fn: impl FnMut(&'h str, &'h str) -> Result<(), ScgiHeaderParseError>,
    ) -> Result<(), ScgiHeaderParseError> {
        let mut rest = raw_headers;
        while !rest.is_empty() {
            let (name, after_name) = split_nul_terminated(rest)?;
            let (value, after_value) = split_nul_terminated(after_name)?;
            headers_fn(name, value)?;
            rest = after_value;
        }
        Ok(())
    }

    /// Split off the NUL-terminated UTF-8 string at the start of `bytes`,
    /// returning it together with the bytes after the NUL.
    fn split_nul_terminated(bytes: &[u8]) -> Result<(&str, &[u8]), ScgiHeaderParseError> {
        let nul = memchr(0, bytes).ok_or(ScgiHeaderParseError::BadHeaderVals)?;
        Ok((from_utf8(&bytes[..nul])?, &bytes[nul + 1..]))
    }

    /// Parse `raw_headers`, returning the `CONTENT_LENGTH` value and passing
    /// every other header except `SCGI` to `insert`.
    ///
    /// Shared by [`header_string_map`] and [`header_str_map`], which differ
    /// only in how they store the headers.
    fn parse_non_reserved<'h>(
        raw_headers: &'h [u8],
        mut insert: impl FnMut(&'h str, &'h str),
    ) -> Result<usize, ScgiHeaderParseError> {
        let mut content_length: Option<usize> = None;
        parse_headers(raw_headers, |name, value| {
            match name {
                // The protocol version header carries no request data.
                "SCGI" => {}
                // If CONTENT_LENGTH appears more than once, the last one wins.
                "CONTENT_LENGTH" => {
                    content_length =
                        Some(value.parse().map_err(|_| ScgiHeaderParseError::BadLength)?);
                }
                _ => insert(name, value),
            }
            Ok(())
        })?;
        content_length.ok_or(ScgiHeaderParseError::NoLength)
    }

    /// Parse the headers and pack them as owned strings into a map.
    ///
    /// The value of the `CONTENT_LENGTH` header is returned in addition to
    /// the header map. The `SCGI` and `CONTENT_LENGTH` headers are not
    /// included in the map. If a header name repeats, the last value wins.
    ///
    /// # Errors
    ///
    /// Any error from [`parse_headers`], plus
    /// [`ScgiHeaderParseError::BadLength`] if `CONTENT_LENGTH` is not an
    /// unsigned integer and [`ScgiHeaderParseError::NoLength`] if it is absent.
    pub fn header_string_map(
        raw_headers: &[u8],
    ) -> Result<(ScgiHeaders, usize), ScgiHeaderParseError> {
        let mut headers_map = ScgiHeaders::new();
        let content_length = parse_non_reserved(raw_headers, |name, value| {
            headers_map.insert(name.to_owned(), value.to_owned());
        })?;
        Ok((headers_map, content_length))
    }

    /// Parse the headers and pack them as slices of `raw_headers` into a map.
    ///
    /// The value of the `CONTENT_LENGTH` header is returned in addition to
    /// the header map. The `SCGI` and `CONTENT_LENGTH` headers are not
    /// included in the map. If a header name repeats, the last value wins.
    ///
    /// # Errors
    ///
    /// Same as [`header_string_map`].
    pub fn header_str_map<'h>(
        raw_headers: &'h [u8],
    ) -> Result<(HashMap<&'h str, &'h str>, usize), ScgiHeaderParseError> {
        let mut headers_map = HashMap::new();
        let content_length = parse_non_reserved(raw_headers, |name, value| {
            headers_map.insert(name, value);
        })?;
        Ok((headers_map, content_length))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use futures_lite::io::BufReader;

    /// Wire-format encoding of [`sample_request`], checked in as a fixture.
    const TEST_DATA: &[u8] = include_bytes!("../test_data/dump");

    /// Build the request that `TEST_DATA` is expected to encode.
    fn sample_request() -> ScgiRequest {
        let headers = [("hello", "world"), ("foo", "bar")]
            .iter()
            .map(|&(name, value)| (name.to_owned(), value.to_owned()))
            .collect::<ScgiHeaders>();
        ScgiRequest {
            headers,
            body: b"here is some data".to_vec(),
        }
    }

    #[test]
    fn encode_scgi_request() {
        let wire = sample_request().encode();
        assert_eq!(TEST_DATA, &wire);
    }

    #[test]
    fn read_scgi_request() {
        let parsed = futures_lite::future::block_on(async {
            let mut reader = BufReader::new(TEST_DATA);
            read_request(&mut reader).await.unwrap()
        });
        assert_eq!(sample_request(), parsed);
    }
}
