use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};

use crate::rejection::Rejection;
use crate::response::Response;
use tower_layer::Layer;
use tower_service::Service;

/// Tower [`Layer`] that wraps a service in an [`ErrorHandler`].
///
/// `error_handler` is called with the [`Rejection`] whenever the wrapped service's
/// response future resolves to `Err`. It can recover by returning a [`Response`] or
/// return a (possibly different) rejection.
///
/// The layer is `Clone` whenever the handler is, so it can be reused when building
/// several service stacks.
#[derive(Clone)]
pub struct ErrorHandlerLayer<T> {
    error_handler: T,
}

/// Applying the layer hands each wrapped service its own clone of the handler.
impl<S, T> Layer<S> for ErrorHandlerLayer<T>
where
    T: Clone,
{
    type Service = ErrorHandler<S, T>;

    /// Wraps `inner` in an [`ErrorHandler`] that owns a fresh clone of this
    /// layer's error handler.
    fn layer(&self, inner: S) -> Self::Service {
        let handler = T::clone(&self.error_handler);
        ErrorHandler {
            inner,
            error_handler: handler,
        }
    }
}

/// Service produced by [`ErrorHandlerLayer`].
///
/// Forwards requests to `inner` and routes any rejection returned by `inner`'s
/// response future through `error_handler`.
#[derive(Clone)]
pub struct ErrorHandler<S, T> {
    // The wrapped service.
    inner: S,
    // Callback applied to `Err` results of `inner`'s response futures; cloned per request.
    error_handler: T,
}

impl<T> ErrorHandlerLayer<T> {
    pub fn new(error_handler: T) -> Self
    where
        T: Clone + Fn(Rejection) -> Result<Response, Rejection>,
    {
        ErrorHandlerLayer { error_handler }
    }
}

impl<S, F, T> Service<http::Request<hyper::Body>> for ErrorHandler<S, T>
where
    S: Service<http::Request<hyper::Body>, Response = Response, Error = Rejection, Future = F>
        + Send
        + Clone
        + 'static,
    F: Future<Output = Result<Response, Rejection>> + Send + 'static,
    T: Clone + Fn(Rejection) -> Result<Response, Rejection>,
{
    type Response = Response;
    type Error = Rejection;
    type Future = ErrorHandlerFuture<F, T>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: http::Request<hyper::Body>) -> Self::Future {
        ErrorHandlerFuture {
            inner: self.inner.call(req),
            error_handler: self.error_handler.clone(),
        }
    }
}

/// Response future returned by [`ErrorHandler`].
///
/// Polls the inner service's future. If that future resolves to `Err`, the
/// rejection is passed to `error_handler`, and the handler's result is yielded
/// instead.
#[pin_project::pin_project]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct ErrorHandlerFuture<F, T> {
    // The inner service's response future; structurally pinned.
    #[pin]
    inner: F,
    // Handler applied to an `Err` output of `inner`; not pinned.
    error_handler: T,
}

impl<F, T> Future for ErrorHandlerFuture<F, T>
where
    F: Future<Output = Result<Response, Rejection>>,
    T: Fn(Rejection) -> Result<Response, Rejection>,
{
    type Output = Result<Response, Rejection>;

    /// Polls the inner future.
    ///
    /// `Pending` and `Ready(Ok(_))` pass straight through. `Ready(Err(rejection))`
    /// is replaced by whatever the error handler returns for that rejection.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let projected = self.project();
        let handler = projected.error_handler;
        projected
            .inner
            .poll(cx)
            .map(|outcome| outcome.or_else(|rejection| (handler)(rejection)))
    }
}

/// Builds a layer that uses `default_error_handler`.
///
/// That handler logs error-caused rejections through `tracing` before converting
/// them into responses. Only available with the `tracing` feature.
///
/// The handler type is written with the crate's `Response` alias, matching every
/// other signature in this module. It is necessarily the same type as
/// `http::Response<hyper::Body>`, because `ErrorHandlerLayer::new` requires the
/// handler to return `Result<Response, Rejection>`.
#[cfg(feature = "tracing")]
impl Default for ErrorHandlerLayer<fn(Rejection) -> Result<Response, Rejection>> {
    fn default() -> Self {
        ErrorHandlerLayer::new(default_error_handler)
    }
}

/// Error handler installed by `ErrorHandlerLayer::default()`.
///
/// For a `Cause::Err` rejection it:
/// - collects the messages of the error's `source()` chain;
/// - converts the error with `into_response_error()`;
/// - logs at `error` level, using a distinct message for server errors (a 5xx
///   response, or a conversion that itself failed);
/// - returns the conversion result.
///
/// A `Cause::Status` rejection is converted with `into_response()` and not logged.
#[cfg(feature = "tracing")]
fn default_error_handler(err: Rejection) -> Result<Response, Rejection> {
    use crate::rejection::Cause;
    use crate::response::IntoResponse;

    match err.into_cause() {
        Cause::Err(err) => {
            // Walk the `source()` chain so the log shows the underlying causes, not
            // only the top-level message.
            let mut sources = Vec::new();
            let mut source = err.source();
            while let Some(cause) = source {
                sources.push(cause.to_string());
                source = cause.source();
            }
            // Separate the cause messages; plain concatenation ran them together
            // into one unreadable string.
            let sources = sources.join(": ");
            let msg = err.to_string();
            let result = err.into_response_error();
            // A rejection that could not be turned into a response at all is
            // treated as a server error too.
            let is_server_err = match &result {
                Ok(res) => res.status().is_server_error(),
                Err(_) => true,
            };

            if is_server_err {
                tracing::error!(err = %msg, %sources, "rejected with server error: {}", msg);
            } else {
                // TODO: make debug once we are sure all errors are declared correctly
                tracing::error!(err = %msg, %sources, "rejected: {}", msg);
            }

            result
        }
        Cause::Status(status) => status.into_response(),
    }
}
