#![doc = include_str!("../README.md")]
#[cfg(feature = "fastly")]
pub mod fastly_transport;
mod region;
#[cfg(feature = "reqwest")]
pub mod reqwest_transport;

use chrono::{DateTime, Utc};
use hmac::{Hmac, Mac, NewMac};
use http::{
    header::{HeaderName, AUTHORIZATION, CONTENT_LENGTH, CONTENT_TYPE, HOST},
    method::Method,
    Request as HttpRequest, Uri,
};
pub use region::Region;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::{collections::HashMap, error::Error, fmt::Display, iter::FromIterator};

/// `chrono` format string for the date-only stamp (`YYYYMMDD`) used in the
/// credential scope and when deriving the signing key.
const SHORT_DATE: &str = "%Y%m%d";
/// `chrono` format string for the full timestamp (`YYYYMMDDTHHMMSSZ`) sent as the
/// `X-Amz-Date` header and embedded in the string to sign.
const LONG_DATETIME: &str = "%Y%m%dT%H%M%SZ";
/// Name of the header that carries the hex encoded SHA-256 digest of the request body.
const X_AMZ_CONTENT_SHA256: &[u8] = b"X-Amz-Content-Sha256";

/// A type alias for `http::Request<Vec<u8>>`
pub type Request = HttpRequest<Vec<u8>>;
/// HMAC keyed with SHA-256, used when signing requests.
type HmacSha256 = Hmac<Sha256>;

/// An AWS access key pair used to authenticate (sign) requests
pub struct Credentials {
    aws_access_key_id: String,
    aws_secret_access_key: String,
}

impl Credentials {
    /// Builds credentials from an access key id and its secret access key.
    ///
    /// Both arguments are copied into owned `String`s.
    pub fn new(
        aws_access_key_id: impl AsRef<str>,
        aws_secret_access_key: impl AsRef<str>,
    ) -> Self {
        let aws_access_key_id = String::from(aws_access_key_id.as_ref());
        let aws_secret_access_key = String::from(aws_secret_access_key.as_ref());
        Self {
            aws_access_key_id,
            aws_secret_access_key,
        }
    }
}

/// Describes the AWS DynamoDB table that requests are sent to
#[non_exhaustive]
pub struct Table {
    /// The name of your DynamoDB table
    pub table_name: String,
    /// The name of the attribute that will store your key
    pub key_name: String,
    /// The name of the attribute that will store your value
    pub value_name: String,
    /// The AWS region the table is hosted in.
    ///
    /// When `endpoint` is defined, the value of this field is somewhat arbitrary,
    /// although its id is still used when signing requests
    pub region: Region,
    /// An optional uri to address the DynamoDB api, often times just for dynamodb local
    pub endpoint: Option<String>,
}

impl Table {
    /// Builds a table description.
    ///
    /// The string arguments are copied into owned `String`s; `endpoint` accepts
    /// either `None` or anything that converts into `Option<String>`.
    pub fn new(
        table_name: impl AsRef<str>,
        key_name: impl AsRef<str>,
        value_name: impl AsRef<str>,
        region: Region,
        endpoint: impl Into<Option<String>>,
    ) -> Self {
        let table_name = table_name.as_ref().to_owned();
        let key_name = key_name.as_ref().to_owned();
        let value_name = value_name.as_ref().to_owned();
        let endpoint: Option<String> = endpoint.into();
        Self {
            table_name,
            key_name,
            value_name,
            region,
            endpoint,
        }
    }
}

/// A trait to implement the behavior for sending requests, often your "IO" layer
///
/// Implementations receive requests that are already signed; they only need to
/// deliver them and report what came back.
pub trait Transport {
    /// Accepts a signed `http::Request<Vec<u8>>` and returns a tuple
    /// representing a response's HTTP status code and body
    ///
    /// # Errors
    ///
    /// Any error returned here is propagated unchanged to the caller of
    /// `DB::get` / `DB::set`.
    fn send(
        &self,
        signed: Request,
    ) -> Result<(u16, String), Box<dyn Error>>;
}

/// A DynamoDB attribute value.
///
/// Only the string variant is modelled; it (de)serializes as `{"S": "<value>"}`.
#[derive(Serialize, Deserialize)]
enum Attr {
    /// A string attribute
    S(String),
}

/// JSON payload of a `PutItem` call; fields serialize with PascalCase names
/// (`TableName`, `Item`).
#[derive(Serialize)]
#[serde(rename_all = "PascalCase")]
struct PutItemInput<'a> {
    /// Name of the target table
    table_name: &'a str,
    /// Attribute name to attribute value for the item being written
    item: HashMap<&'a str, Attr>,
}

/// JSON payload of a `GetItem` call; fields serialize with PascalCase names
/// (`TableName`, `Key`, `ProjectionExpression`, `ExpressionAttributeNames`).
#[derive(Serialize)]
#[serde(rename_all = "PascalCase")]
struct GetItemInput<'a> {
    /// Name of the table to read from
    table_name: &'a str,
    /// Key attribute name to key attribute value identifying the item
    key: HashMap<&'a str, Attr>,
    /// Expression selecting which attributes to return
    projection_expression: &'a str,
    /// Maps `#placeholder` names used in the projection to real attribute names
    expression_attribute_names: HashMap<&'a str, &'a str>,
}

/// The subset of a `GetItem` response this crate reads.
///
/// `Item` is required here, so a body without it does not deserialize; `DB::get`
/// checks for the literal `{}` body before parsing.
#[derive(Deserialize)]
#[serde(rename_all = "PascalCase")]
struct GetItemOutput {
    /// Attribute name to attribute value for the returned item
    item: HashMap<String, Attr>,
}

/// An error payload returned by AWS for a non-200 response.
///
/// With `rename_all = "PascalCase"` the fields are read from `Type` and `Message`.
/// AWS services are not consistent about the casing of these members, so the
/// lowercase spellings (`__type`, `message`) are accepted as aliases; without the
/// `message` alias such bodies would fail to deserialize and hide the real error.
#[derive(Deserialize, Debug)]
#[serde(rename_all = "PascalCase")]
struct AWSError {
    /// The error type identifier reported by the service
    #[serde(alias = "__type")]
    __type: String,
    /// Human readable description of the error
    #[serde(alias = "message")]
    message: String,
}

impl Display for AWSError {
    /// Renders the error as `<type>: <message>`
    fn fmt(
        &self,
        f: &mut std::fmt::Formatter<'_>,
    ) -> std::fmt::Result {
        write!(f, "{}: {}", self.__type, self.message)
    }
}

// The default `Error` methods are sufficient; there is no underlying source.
impl Error for AWSError {}

/// A minimal error type that carries nothing but a message
#[derive(Debug)]
struct StrErr(String);

impl Display for StrErr {
    /// Writes the wrapped message verbatim
    fn fmt(
        &self,
        f: &mut std::fmt::Formatter<'_>,
    ) -> std::fmt::Result {
        let StrErr(message) = self;
        f.write_str(message)
    }
}

impl Error for StrErr {}

/// The central client interface applications will work with
///
/// A `DB` bundles the credentials used for signing, the description of the target
/// table and the transport that actually sends requests.
///
/// # Example
///
/// ```rust ,no_run
/// # use std::{env, error::Error};
/// # use tiny_dynamo::{reqwest_transport::Reqwest, Credentials, Table, DB};
/// # fn main() -> Result<(), Box<dyn Error>> {
/// let db = DB::new(
///     Credentials::new(
///         env::var("AWS_ACCESS_KEY_ID")?,
///         env::var("AWS_SECRET_ACCESS_KEY")?,
///     ),
///     Table::new(
///         "table-name",
///         "key-attr-name",
///         "value-attr-name",
///         "us-east-1".parse()?,
///         None
///     ),
///     Reqwest::new(),
/// );
/// # Ok(())
/// # }
/// ```
pub struct DB {
    // key pair used to sign every request
    credentials: Credentials,
    // table, attribute names, region and optional endpoint override
    table_info: Table,
    // boxed so `DB` does not need a type parameter for the transport
    transport: Box<dyn Transport>,
}

impl DB {
    /// Creates a `DB` from credentials, a table description and a transport
    pub fn new(
        credentials: Credentials,
        table_info: Table,
        transport: impl Transport + 'static,
    ) -> Self {
        // erase the concrete transport type behind a trait object
        let transport: Box<dyn Transport> = Box::new(transport);
        Self {
            credentials,
            table_info,
            transport,
        }
    }

    /// Looks up the value stored under `key`
    ///
    /// Returns `Ok(None)` when the service answers with an empty object or when the
    /// returned item has no value attribute.
    ///
    /// # Errors
    ///
    /// Fails when the request cannot be built or sent, when a response body cannot be
    /// parsed, or with the service's error payload for any non-200 status.
    pub fn get(
        &self,
        key: impl AsRef<str>,
    ) -> Result<Option<String>, Box<dyn Error>> {
        let request = self.get_item_req(key)?;
        let (status, body) = self.transport.send(request)?;
        if status != 200 {
            return Err(Box::new(serde_json::from_str::<AWSError>(&body)?));
        }
        if body == "{}" {
            // an empty object means no item matched the key
            return Ok(None);
        }
        let output: GetItemOutput = serde_json::from_str(&body)?;
        let value = output
            .item
            .get(&self.table_info.value_name)
            .map(|Attr::S(stored)| stored.clone());
        Ok(value)
    }

    /// Stores `value` under `key`
    ///
    /// # Errors
    ///
    /// Fails when the request cannot be built or sent, or with the service's error
    /// payload for any non-200 status.
    pub fn set(
        &self,
        key: impl AsRef<str>,
        value: impl AsRef<str>,
    ) -> Result<(), Box<dyn Error>> {
        let request = self.put_item_req(key, value)?;
        let (status, body) = self.transport.send(request)?;
        if status == 200 {
            Ok(())
        } else {
            Err(Box::new(serde_json::from_str::<AWSError>(&body)?))
        }
    }

    /// Builds a signed `PutItem` request that stores `value` under `key`
    ///
    /// # Errors
    ///
    /// Fails when the endpoint cannot be parsed as a uri or has no host, when the
    /// payload cannot be serialized, when the request cannot be assembled, or when
    /// signing fails.
    #[doc(hidden)]
    pub fn put_item_req(
        &self,
        key: impl AsRef<str>,
        value: impl AsRef<str>,
    ) -> Result<Request, Box<dyn Error>> {
        // https://docs.aws.amazon.com/amazondynamodb/latest/APIReference/API_PutItem.html
        let Table {
            table_name,
            key_name,
            value_name,
            region,
            endpoint,
            ..
        } = &self.table_info;
        // an explicit endpoint wins over the region's default endpoint
        let uri: Uri = endpoint
            .as_deref()
            .unwrap_or_else(|| region.endpoint())
            .parse()?;
        // a caller supplied endpoint may parse but carry no authority (e.g. a bare
        // path); report that as an error rather than panicking
        let host = uri
            .authority()
            .ok_or("endpoint uri is missing a host")?
            .as_str();
        let body = serde_json::to_vec(&PutItemInput {
            table_name,
            item: HashMap::from_iter([
                (key_name.as_str(), Attr::S(key.as_ref().to_owned())),
                (value_name.as_str(), Attr::S(value.as_ref().to_owned())),
            ]),
        })?;
        self.sign(
            http::Request::builder()
                .method(Method::POST)
                .uri(&uri)
                .header(HOST, host)
                .header(CONTENT_TYPE, "application/x-amz-json-1.0")
                .header("X-Amz-Target", "DynamoDB_20120810.PutItem")
                .body(body)?,
        )
    }

    /// Builds a signed `GetItem` request that fetches the value attribute for `key`
    ///
    /// # Errors
    ///
    /// Fails when the endpoint cannot be parsed as a uri or has no host, when the
    /// payload cannot be serialized, when the request cannot be assembled, or when
    /// signing fails.
    #[doc(hidden)]
    pub fn get_item_req(
        &self,
        key: impl AsRef<str>,
    ) -> Result<Request, Box<dyn Error>> {
        // https://docs.aws.amazon.com/amazondynamodb/latest/APIReference/API_GetItem.html
        let Table {
            table_name,
            key_name,
            value_name,
            region,
            endpoint,
            ..
        } = &self.table_info;
        // an explicit endpoint wins over the region's default endpoint
        let uri: Uri = endpoint
            .as_deref()
            .unwrap_or_else(|| region.endpoint())
            .parse()?;
        // a caller supplied endpoint may parse but carry no authority (e.g. a bare
        // path); report that as an error rather than panicking
        let host = uri
            .authority()
            .ok_or("endpoint uri is missing a host")?
            .as_str();
        let body = serde_json::to_vec(&GetItemInput {
            table_name,
            key: HashMap::from_iter([(
                key_name.as_str(),
                Attr::S(key.as_ref().to_owned()),
            )]),
            // we use #v because https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/ReservedWords.html
            projection_expression: "#v",
            expression_attribute_names: HashMap::from_iter([("#v", value_name.as_str())]),
        })?;
        self.sign(
            http::Request::builder()
                .method(Method::POST)
                .uri(&uri)
                .header(HOST, host)
                .header(CONTENT_TYPE, "application/x-amz-json-1.0")
                .header("X-Amz-Target", "DynamoDB_20120810.GetItem")
                .body(body)?,
        )
    }

    fn sign(
        &self,
        mut unsigned: Request,
    ) -> Result<Request, Box<dyn Error>> {
        fn hmac(
            key: &[u8],
            data: &[u8],
        ) -> Result<Vec<u8>, Box<dyn Error>> {
            let mut mac = HmacSha256::new_from_slice(key).map_err(|e| StrErr(e.to_string()))?;
            mac.update(data);
            Ok(mac.finalize().into_bytes().to_vec())
        }

        let body_digest = {
            let mut sha = Sha256::default();
            sha.update(unsigned.body());
            hex::encode(sha.finalize().as_slice())
        };

        let now = Utc::now();
        unsigned
            .headers_mut()
            .append("X-Amz-Date", now.format(LONG_DATETIME).to_string().parse()?);

        fn signed_header_string(headers: &http::HeaderMap) -> String {
            let mut keys = headers
                .keys()
                .map(|key| key.as_str().to_lowercase())
                .collect::<Vec<_>>();
            keys.sort();
            keys.join(";")
        }

        fn string_to_sign(
            datetime: &DateTime<Utc>,
            region: &str,
            canonical_req: &str,
        ) -> String {
            let mut hasher = Sha256::default();
            hasher.update(canonical_req.as_bytes());
            format!(
                "AWS4-HMAC-SHA256\n{timestamp}\n{scope}\n{canonical_req_hash}",
                timestamp = datetime.format(LONG_DATETIME),
                scope = scope_string(datetime, region),
                canonical_req_hash = hex::encode(hasher.finalize().as_slice())
            )
        }

        fn signing_key(
            datetime: &DateTime<Utc>,
            secret_key: &str,
            region: &str,
        ) -> Result<Vec<u8>, Box<dyn Error>> {
            [region.as_bytes(), b"dynamodb", b"aws4_request"]
                .iter()
                .try_fold::<_, _, Result<_, Box<dyn Error>>>(
                    hmac(
                        &[b"AWS4", secret_key.as_bytes()].concat(),
                        datetime.format(SHORT_DATE).to_string().as_bytes(),
                    )?,
                    |res, next| hmac(&res, next),
                )
        }

        fn scope_string(
            datetime: &DateTime<Utc>,
            region: &str,
        ) -> String {
            format!(
                "{date}/{region}/dynamodb/aws4_request",
                date = datetime.format(SHORT_DATE),
                region = region
            )
        }

        fn canonical_header_string(headers: &http::HeaderMap) -> String {
            let mut keyvalues = headers
                .iter()
                .map(|(key, value)| {
                    // Values that are not strings are silently dropped (AWS wouldn't
                    // accept them anyway)
                    key.as_str().to_lowercase() + ":" + value.to_str().unwrap().trim()
                })
                .collect::<Vec<_>>();
            keyvalues.sort();
            keyvalues.join("\n")
        }

        fn canonical_request(
            method: &str,
            headers: &http::HeaderMap,
            body_digest: &str,
        ) -> String {
            // note: all dynamodb uris are requests to / with no query string so theres no need
            // to derive those from the request
            format!(
                "{method}\n/\n\n{headers}\n\n{signed_headers}\n{body_digest}",
                method = method,
                headers = canonical_header_string(headers),
                signed_headers = signed_header_string(headers),
                body_digest = body_digest
            )
        }

        let canonical_request = canonical_request(
            unsigned.method().as_str(),
            unsigned.headers(),
            body_digest.as_str(),
        );

        fn authorization_header(
            access_key: &str,
            datetime: &DateTime<Utc>,
            region: &str,
            signed_headers: &str,
            signature: &str,
        ) -> String {
            format!(
                "AWS4-HMAC-SHA256 Credential={access_key}/{scope}, SignedHeaders={signed_headers}, Signature={signature}",
                access_key = access_key,
                scope = scope_string(datetime, region),
                signed_headers = signed_headers,
                signature = signature
            )
        }

        let string_to_sign = string_to_sign(&now, self.table_info.region.id(), &canonical_request);
        let signature = hex::encode(hmac(
            &signing_key(
                &now,
                &self.credentials.aws_secret_access_key,
                self.table_info.region.id(),
            )?,
            string_to_sign.as_bytes(),
        )?);
        let headers_string = signed_header_string(unsigned.headers());
        let content_length = unsigned.body().len();
        unsigned.headers_mut().extend([
            (
                AUTHORIZATION,
                authorization_header(
                    &self.credentials.aws_access_key_id,
                    &Utc::now(),
                    self.table_info.region.id(),
                    &headers_string,
                    &signature,
                )
                .parse()?,
            ),
            (CONTENT_LENGTH, content_length.to_string().parse()?),
            (
                HeaderName::from_bytes(X_AMZ_CONTENT_SHA256)?,
                body_digest.parse()?,
            ),
        ]);

        Ok(unsigned)
    }
}

/// A `Transport` that ignores the request and always answers with a fixed
/// status code and body.
pub struct Const(pub u16, pub String);

impl Transport for Const {
    /// Returns the stored status code and a copy of the stored body
    fn send(
        &self,
        _: Request,
    ) -> Result<(u16, String), Box<dyn Error>> {
        Ok((self.0, self.1.clone()))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// `GetItemInput` has single-entry maps, so its serialized text is deterministic
    /// and can be compared verbatim.
    #[test]
    fn get_item_input_serilizes_as_expected() -> Result<(), Box<dyn Error>> {
        assert_eq!(
            serde_json::to_string(&GetItemInput {
                table_name: "test-table",
                key: HashMap::from_iter([("key-name", Attr::S("key-value".into()))]),
                projection_expression: "#v",
                expression_attribute_names: HashMap::from_iter([("#v", "value-name")]),
            })?,
            r##"{"TableName":"test-table","Key":{"key-name":{"S":"key-value"}},"ProjectionExpression":"#v","ExpressionAttributeNames":{"#v":"value-name"}}"##
        );
        Ok(())
    }

    /// `Item` holds two entries and `HashMap` iteration order is unspecified, so the
    /// payload is compared as parsed JSON rather than as serialized text.
    #[test]
    fn put_item_input_serilizes_as_expected() -> Result<(), Box<dyn Error>> {
        assert_eq!(
            serde_json::to_value(&PutItemInput {
                table_name: "test-table",
                item: HashMap::from_iter([
                    ("key-name", Attr::S("key-value".into())),
                    ("value-name", Attr::S("value".into())),
                ]),
            })?,
            serde_json::json!({
                "TableName": "test-table",
                "Item": {
                    "key-name": { "S": "key-value" },
                    "value-name": { "S": "value" }
                }
            })
        );
        Ok(())
    }
}
