//! Commonly used middleware.

mod add_data;
mod catch_panic;
#[cfg(feature = "compression")]
mod compression;
#[cfg(feature = "cookie")]
mod cookie_jar_manager;
mod cors;
#[cfg(feature = "csrf")]
mod csrf;
mod force_https;
mod normalize_path;
#[cfg(feature = "opentelemetry")]
mod opentelemetry_metrics;
#[cfg(feature = "opentelemetry")]
mod opentelemetry_tracing;
mod propagate_header;
mod sensitive_header;
mod set_header;
mod size_limit;
#[cfg(feature = "tokio-metrics")]
mod tokio_metrics_mw;
#[cfg(feature = "tower-compat")]
mod tower_compat;
mod tracing_mw;

#[cfg(feature = "compression")]
pub use self::compression::{Compression, CompressionEndpoint};
#[cfg(feature = "cookie")]
pub use self::cookie_jar_manager::{CookieJarManager, CookieJarManagerEndpoint};
#[cfg(feature = "csrf")]
pub use self::csrf::{Csrf, CsrfEndpoint};
#[cfg(feature = "opentelemetry")]
pub use self::opentelemetry_metrics::{OpenTelemetryMetrics, OpenTelemetryMetricsEndpoint};
#[cfg(feature = "opentelemetry")]
pub use self::opentelemetry_tracing::{OpenTelemetryTracing, OpenTelemetryTracingEndpoint};
#[cfg(feature = "tokio-metrics")]
pub use self::tokio_metrics_mw::{TokioMetrics, TokioMetricsEndpoint};
#[cfg(feature = "tower-compat")]
pub use self::tower_compat::TowerLayerCompatExt;
pub use self::{
    add_data::{AddData, AddDataEndpoint},
    catch_panic::{CatchPanic, CatchPanicEndpoint, PanicHandler},
    cors::{Cors, CorsEndpoint},
    force_https::ForceHttps,
    normalize_path::{NormalizePath, NormalizePathEndpoint, TrailingSlash},
    propagate_header::{PropagateHeader, PropagateHeaderEndpoint},
    sensitive_header::{SensitiveHeader, SensitiveHeaderEndpoint},
    set_header::{SetHeader, SetHeaderEndpoint},
    size_limit::{SizeLimit, SizeLimitEndpoint},
    tracing_mw::{Tracing, TracingEndpoint},
};
use crate::endpoint::Endpoint;

/// Represents a middleware trait.
///
/// # Create your own middleware
///
/// ```
/// use poem::{
///     handler, test::TestClient, web::Data, Endpoint, EndpointExt, Middleware, Request, Result,
/// };
///
/// /// A middleware that extracts a token from the HTTP headers.
/// struct TokenMiddleware;
///
/// impl<E: Endpoint> Middleware<E> for TokenMiddleware {
///     type Output = TokenMiddlewareImpl<E>;
///
///     fn transform(&self, ep: E) -> Self::Output {
///         TokenMiddlewareImpl { ep }
///     }
/// }
///
/// /// The new endpoint type generated by the TokenMiddleware.
/// struct TokenMiddlewareImpl<E> {
///     ep: E,
/// }
///
/// const TOKEN_HEADER: &str = "X-Token";
///
/// /// Token data
/// struct Token(String);
///
/// #[poem::async_trait]
/// impl<E: Endpoint> Endpoint for TokenMiddlewareImpl<E> {
///     type Output = E::Output;
///
///     async fn call(&self, mut req: Request) -> Result<Self::Output> {
///         if let Some(value) = req
///             .headers()
///             .get(TOKEN_HEADER)
///             .and_then(|value| value.to_str().ok())
///         {
///             // Insert the token data into the request's extensions.
///             let token = value.to_string();
///             req.extensions_mut().insert(Token(token));
///         }
///
///         // call the next endpoint.
///         self.ep.call(req).await
///     }
/// }
///
/// #[handler]
/// async fn index(Data(token): Data<&Token>) -> String {
///     token.0.clone()
/// }
///
/// // Use the `TokenMiddleware` middleware to convert the `index` endpoint.
/// let ep = index.with(TokenMiddleware);
///
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let mut resp = TestClient::new(ep)
///     .get("/")
///     .header(TOKEN_HEADER, "abc")
///     .send()
///     .await;
/// resp.assert_status_is_ok();
/// resp.assert_text("abc").await;
/// # });
/// ```
///
/// # Create middleware with functions
///
/// ```rust
/// use std::sync::Arc;
///
/// use poem::{
///     handler, test::TestClient, web::Data, Endpoint, EndpointExt, IntoResponse, Request, Result,
/// };
/// const TOKEN_HEADER: &str = "X-Token";
///
/// #[handler]
/// async fn index(Data(token): Data<&Token>) -> String {
///     token.0.clone()
/// }
///
/// /// Token data
/// struct Token(String);
///
/// async fn token_middleware<E: Endpoint>(next: E, mut req: Request) -> Result<E::Output> {
///     if let Some(value) = req
///         .headers()
///         .get(TOKEN_HEADER)
///         .and_then(|value| value.to_str().ok())
///     {
///         // Insert the token data into the request's extensions.
///         let token = value.to_string();
///         req.extensions_mut().insert(Token(token));
///     }
///
///     // call the next endpoint.
///     next.call(req).await
/// }
///
/// let ep = index.around(token_middleware);
/// let cli = TestClient::new(ep);
///
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let resp = cli.get("/").header(TOKEN_HEADER, "abc").send().await;
/// resp.assert_status_is_ok();
/// resp.assert_text("abc").await;
/// # });
/// ```
pub trait Middleware<E>
where
    E: Endpoint,
{
    /// The endpoint type produced by [`transform`](Middleware::transform).
    ///
    /// When naming a concrete type is impractical,
    /// [`BoxEndpoint`](crate::endpoint::BoxEndpoint) can be used instead; it
    /// costs a little performance, but the loss is insignificant.
    type Output: Endpoint;

    /// Converts the input [`Endpoint`] into another endpoint.
    fn transform(&self, ep: E) -> Self::Output;
}

// NOTE(review): presumably expands to `Middleware` implementations for tuples
// of middlewares (the tests below pass a tuple to `with`) — confirm against
// the `poem_derive` macro definition.
poem_derive::generate_implement_middlewares!();

/// A middleware implemented by a closure.
///
/// Wraps a closure of type `T`. Its [`Middleware`] implementation calls the
/// closure with the endpoint being transformed and yields whatever endpoint
/// the closure returns. The wrapped field is private; use [`make`] to
/// construct a value.
pub struct FnMiddleware<T>(T);

impl<T, E, E2> Middleware<E> for FnMiddleware<T>
where
    T: Fn(E) -> E2,
    E: Endpoint,
    E2: Endpoint,
{
    type Output = E2;

    fn transform(&self, ep: E) -> Self::Output {
        (self.0)(ep)
    }
}

/// Make middleware with a closure.
///
/// The closure receives the endpoint being wrapped and returns the new
/// endpoint. The returned [`FnMiddleware`] implements [`Middleware`] for any
/// closure of the form `Fn(E) -> E2` where both `E` and `E2` are endpoints.
pub fn make<T>(f: T) -> FnMiddleware<T> {
    FnMiddleware(f)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        handler,
        http::{header::HeaderName, HeaderValue},
        test::TestClient,
        web::Data,
        EndpointExt, IntoResponse, Request, Response, Result,
    };

    /// A closure passed to [`make`] must be usable as a middleware via
    /// `EndpointExt::with`.
    #[tokio::test]
    async fn test_make() {
        #[handler(internal)]
        fn index() -> &'static str {
            "abc"
        }

        /// Endpoint wrapper that inserts one fixed header into every
        /// successful response of the inner endpoint.
        struct AddHeader<E> {
            ep: E,
            header: HeaderName,
            value: HeaderValue,
        }

        #[async_trait::async_trait]
        impl<E: Endpoint> Endpoint for AddHeader<E> {
            type Output = Response;

            async fn call(&self, req: Request) -> Result<Self::Output> {
                // Errors from the inner endpoint are propagated unchanged;
                // the header is only added on the success path.
                let mut resp = self.ep.call(req).await?.into_response();
                resp.headers_mut()
                    .insert(self.header.clone(), self.value.clone());
                Ok(resp)
            }
        }

        let ep = index.with(make(|ep| AddHeader {
            ep,
            header: HeaderName::from_static("hello"),
            value: HeaderValue::from_static("world"),
        }));
        let cli = TestClient::new(ep);

        let resp = cli.get("/").send().await;
        // Check the status as well, consistent with the other test in this
        // module, so a failing response cannot slip through unnoticed.
        resp.assert_status_is_ok();
        resp.assert_header("hello", "world");
        resp.assert_text("abc").await;
    }

    /// A tuple of middlewares passed to `with` must apply every element:
    /// the data is injected and both headers are set.
    #[tokio::test]
    async fn test_with_multiple_middlewares() {
        #[handler(internal)]
        fn index(data: Data<&i32>) -> String {
            data.0.to_string()
        }

        let ep = index.with((
            AddData::new(10),
            SetHeader::new().appending("myheader-1", "a"),
            SetHeader::new().appending("myheader-2", "b"),
        ));
        let cli = TestClient::new(ep);

        let resp = cli.get("/").send().await;
        resp.assert_status_is_ok();
        resp.assert_header("myheader-1", "a");
        resp.assert_header("myheader-2", "b");
        resp.assert_text("10").await;
    }
}
