//! Types for verifying requests with Actix Web

use crate::{Config, PrepareVerifyError, SignatureVerify};
use actix_web::{
    dev::{MessageBody, Payload, Service, ServiceRequest, ServiceResponse, Transform},
    http::StatusCode,
    Error, FromRequest, HttpMessage, HttpRequest, HttpResponse, ResponseError,
};
use log::{debug, warn};
use std::{
    future::{ready, Future, Ready},
    pin::Pin,
    task::{Context, Poll},
};

/// A marker type that can be used to guard routes when the signature middleware is set to
/// 'optional'
///
/// Wraps the key ID of the signature that the middleware checked for this request.
#[derive(Clone, Debug)]
pub struct SignatureVerified(String);

impl SignatureVerified {
    /// Return the Key ID used to verify the request
    ///
    /// An application may need to confirm that the payload it is processing really belongs to
    /// the owner of the key that signed the request.
    pub fn key_id(&self) -> &str {
        self.0.as_str()
    }
}

/// The Verify signature middleware
///
/// Build it with [`VerifySignature::new`], optionally switch it to `Authorization` headers or
/// make the header optional, then `wrap` an `App` with it.
///
/// ```rust,ignore
/// let middleware = VerifySignature::new(MyVerifier::new(), Config::default())
///     .authorization()
///     .optional();
///
/// HttpServer::new(move || {
///     App::new()
///         .wrap(middleware.clone())
///         .route("/protected", web::post().to(|_: SignatureVerified| "Verified Authorization Header"))
///         .route("/unprotected", web::post().to(|| "No verification required"))
/// })
/// ```
#[derive(Clone, Debug)]
pub struct VerifySignature<T>(
    // The `SignatureVerify` implementation that checks each signature
    T,
    // Configuration used to prepare requests for verification
    Config,
    // Which header (`Signature` or `Authorization`) is verified
    HeaderKind,
    // Whether a request lacking the header may pass through unverified (set by `optional`)
    bool,
);

/// The service produced by [`VerifySignature`]; wraps an inner service `S`.
#[doc(hidden)]
#[derive(Clone, Debug)]
pub struct VerifyMiddleware<T, S>(
    // The wrapped service that receives requests
    S,
    // Configuration used to prepare requests for verification
    Config,
    // Which header is verified
    HeaderKind,
    // Whether a request lacking the header may pass through unverified
    bool,
    // The verifier, cloned for each signature check
    T,
);

/// Selects which request header carries the signature to verify.
// Variant order is significant: the derived `Ord` sorts `Authorization` before `Signature`.
#[derive(Copy, Clone, Debug, Eq, Ord, PartialEq, PartialOrd)]
enum HeaderKind {
    /// The `Authorization` header
    Authorization,
    /// The `Signature` header
    Signature,
}

/// Error produced when a request cannot be verified; it is turned into an HTTP response
/// through its `ResponseError` implementation.
#[doc(hidden)]
#[derive(Clone, Debug, thiserror::Error)]
#[error("Failed to verify http signature")]
pub struct VerifyError;

impl<T> VerifySignature<T>
where
    T: SignatureVerify,
{
    /// Create a new middleware for verifying HTTP Signatures. A type implementing
    /// [`SignatureVerify`] is required, as well as a [`Config`]
    ///
    /// By default, this middleware expects to verify Signature headers, and requires the presence
    /// of the header
    #[must_use]
    pub fn new(verify_signature: T, config: Config) -> Self {
        VerifySignature(verify_signature, config, HeaderKind::Signature, false)
    }

    /// Verify Authorization headers instead of Signature headers
    ///
    /// This consumes the builder and returns the updated one. Discarding the return value
    /// loses the setting, hence `#[must_use]`.
    #[must_use]
    pub fn authorization(mut self) -> Self {
        self.2 = HeaderKind::Authorization;
        self
    }

    /// Mark the presence of a Signature or Authorization header as optional
    ///
    /// If a header is present, it will be verified, but if there is not one present, the request
    /// is passed through. This can be used to set a global middleware, and then guard each route
    /// handler with the [`SignatureVerified`] type.
    ///
    /// Like [`VerifySignature::authorization`], the returned builder must be used.
    #[must_use]
    pub fn optional(mut self) -> Self {
        self.3 = true;
        self
    }
}

impl<T, S, B> VerifyMiddleware<T, S>
where
    T: SignatureVerify + Clone + 'static,
    T::Future: 'static,
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
    B: MessageBody + 'static,
{
    /// Verify the signature on `req`, then forward the request to the wrapped service.
    ///
    /// The request's method, path/query and headers are passed to `Config::begin_verify`. Any
    /// preparation failure is logged and turned into a [`VerifyError`]. Otherwise:
    /// 1. the verifier's future is created;
    /// 2. the key ID is stored in the request extensions as [`SignatureVerified`];
    /// 3. the inner service is called.
    ///
    /// The inner service's future is only awaited once the verifier reports the signature as
    /// valid.
    fn handle(
        &self,
        req: ServiceRequest,
    ) -> Pin<Box<dyn Future<Output = Result<ServiceResponse<B>, Error>>>> {
        let prepared = self.1.begin_verify(
            req.method(),
            req.uri().path_and_query(),
            req.headers().clone(),
        );

        let unverified = match prepared {
            Ok(unverified) => unverified,
            Err(err) => {
                // Every preparation failure ends the same way; only the log line differs.
                match &err {
                    PrepareVerifyError::Expired => warn!("Header is expired"),
                    PrepareVerifyError::Missing => debug!("Header is missing"),
                    PrepareVerifyError::ParseField(field) => {
                        debug!("Failed to parse field {}", field)
                    }
                    PrepareVerifyError::Header(e) => debug!("Failed to parse header {}", e),
                    PrepareVerifyError::Required(required) => {
                        debug!("Missing required headers, {:?}", required)
                    }
                }
                return Box::pin(ready(Err(VerifyError.into())));
            }
        };

        let algorithm = unverified.algorithm().cloned();
        let key_id = unverified.key_id().to_owned();

        let verified = unverified.verify(|signature, signing_string| {
            self.4.clone().signature_verify(
                algorithm,
                key_id.clone(),
                signature.to_string(),
                signing_string.to_string(),
            )
        });

        // The request is moved into the inner service below, so the marker must be attached now.
        req.extensions_mut().insert(SignatureVerified(key_id));
        let response = self.0.call(req);

        Box::pin(async move {
            let signature_valid = verified.await?;
            if signature_valid {
                response.await
            } else {
                warn!("Signature is invalid");
                Err(VerifyError.into())
            }
        })
    }
}

impl HeaderKind {
    /// Returns `true` when the `Authorization` header is the one to verify.
    pub fn is_authorization(self) -> bool {
        matches!(self, HeaderKind::Authorization)
    }

    /// Returns `true` when the `Signature` header is the one to verify.
    pub fn is_signature(self) -> bool {
        matches!(self, HeaderKind::Signature)
    }
}

impl FromRequest for SignatureVerified {
    type Error = VerifyError;
    type Future = Ready<Result<Self, Self::Error>>;
    type Config = ();

    /// Extract the [`SignatureVerified`] marker from the request extensions.
    ///
    /// Fails with [`VerifyError`] (logged at debug level) when no marker is present.
    fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future {
        let marker = req.extensions().get::<Self>().cloned();

        let outcome = match marker {
            Some(verified) => Ok(verified),
            None => {
                debug!("Failed to fetch SignatureVerified from request");
                Err(VerifyError)
            }
        };

        ready(outcome)
    }
}

impl<T, S, B> Transform<S, ServiceRequest> for VerifySignature<T>
where
    T: SignatureVerify + Clone + 'static,
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
    S::Error: 'static,
    B: MessageBody + 'static,
{
    type Response = ServiceResponse<B>;
    type Error = actix_web::Error;
    type Transform = VerifyMiddleware<T, S>;
    type InitError = ();
    type Future = Ready<Result<Self::Transform, Self::InitError>>;

    fn new_transform(&self, service: S) -> Self::Future {
        ready(Ok(VerifyMiddleware(
            service,
            self.1.clone(),
            self.2,
            self.3,
            self.0.clone(),
        )))
    }
}

type FutResult<T, E> = dyn Future<Output = Result<T, E>>;
impl<T, S, B> Service<ServiceRequest> for VerifyMiddleware<T, S>
where
    T: SignatureVerify + Clone + 'static,
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
    S::Error: 'static,
    B: MessageBody + 'static,
{
    type Response = ServiceResponse<B>;
    type Error = actix_web::Error;
    type Future = Pin<Box<FutResult<Self::Response, Self::Error>>>;

    fn poll_ready(&self, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
        self.0.poll_ready(cx)
    }

    fn call(&self, req: ServiceRequest) -> Self::Future {
        let authorization = req.headers().get("Authorization").is_some();
        let signature = req.headers().get("Signature").is_some();

        if authorization || signature {
            if self.2.is_authorization() && authorization {
                return self.handle(req);
            }

            if self.2.is_signature() && signature {
                return self.handle(req);
            }

            debug!("Authorization or Signature headers are missing");
            Box::pin(ready(Err(VerifyError.into())))
        } else if self.3 {
            debug!("Headers are missing but Optional is true, continuing");
            Box::pin(self.0.call(req))
        } else {
            debug!("Authorization or Signature headers are missing");
            Box::pin(ready(Err(VerifyError.into())))
        }
    }
}

impl ResponseError for VerifyError {
    /// Verification failures are reported to the client as `400 Bad Request`.
    fn status_code(&self) -> StatusCode {
        StatusCode::BAD_REQUEST
    }

    /// Build a response carrying only the status from `status_code`; no body is set here.
    fn error_response(&self) -> HttpResponse {
        let status = self.status_code();
        HttpResponse::new(status)
    }
}
