use std::collections::HashSet;
use std::future::Future;
use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;

use crate::rejection::Rejection;
use crate::response::{IntoResponse, Response};
use headers::{
    AccessControlAllowHeaders, AccessControlAllowMethods, AccessControlExposeHeaders,
    AccessControlRequestHeaders, AccessControlRequestMethod, HeaderMapExt,
};
use http::header::{self, HeaderName};
use http::{HeaderValue, Method, StatusCode, Uri};
use tower_layer::Layer;
use tower_service::Service;

/// Builder for a CORS middleware; implements [`Layer`] to wrap a service in [`Cors`].
///
/// The default value has no allowed origins, methods or headers configured.
#[derive(Default)]
pub struct CorsLayer {
    /// Options accumulated by the builder methods; cloned into each produced [`Cors`] service.
    opts: CorsOptions,
}

impl<S> Layer<S> for CorsLayer {
    type Service = Cors<S>;

    fn layer(&self, inner: S) -> Self::Service {
        Cors {
            inner,
            headers: Arc::new(PrebuildHeaders {
                allow_headers: self.opts.allowed_headers.iter().cloned().collect(),
                allow_methods: self
                    .opts
                    .allowed_methods
                    .iter()
                    .cloned()
                    .map(Method::from)
                    .collect(),
                expose_headers: if self.opts.exposed_headers.is_empty() {
                    None
                } else {
                    Some(self.opts.exposed_headers.iter().cloned().collect())
                },
            }),
            opts: Arc::new(self.opts.clone()),
        }
    }
}

/// Service produced by [`CorsLayer`]; applies the configured CORS policy around `inner`.
#[derive(Clone)]
pub struct Cors<S> {
    /// The wrapped service.
    inner: S,
    /// Shared CORS configuration.
    opts: Arc<CorsOptions>,
    /// Shared, pre-built typed response headers derived from the configuration.
    headers: Arc<PrebuildHeaders>,
}

/// CORS configuration collected by the [`CorsLayer`] builder methods.
#[derive(Default, Clone)]
struct CorsOptions {
    /// When `true`, responses carry `Access-Control-Allow-Credentials: true`.
    allow_credentials: bool,
    /// Header names a preflight request may ask for via `Access-Control-Request-Headers`.
    allowed_headers: HashSet<HeaderName>,
    /// Header names advertised via `Access-Control-Expose-Headers` (omitted when empty).
    exposed_headers: HashSet<HeaderName>,
    /// Value, in whole seconds, sent as `Access-Control-Max-Age` on preflight responses.
    max_age: Option<u64>,
    /// Methods a preflight request may ask for via `Access-Control-Request-Method`.
    allowed_methods: HashSet<AccessControlRequestMethod>,
    /// Origin policy; `None` means no origin is accepted.
    origins: Option<Origins>,
}

/// Typed response headers built once per [`Cors`] service from [`CorsOptions`].
struct PrebuildHeaders {
    /// `Access-Control-Allow-Headers` value built from the allowed header names.
    allow_headers: AccessControlAllowHeaders,
    /// `Access-Control-Allow-Methods` value built from the allowed methods.
    allow_methods: AccessControlAllowMethods,
    /// `Access-Control-Expose-Headers` value, or `None` when no header is exposed.
    expose_headers: Option<AccessControlExposeHeaders>,
}

/// Policy deciding which request origins are accepted and how they are reflected back.
#[derive(Clone, PartialEq)]
enum Origins {
    /// Accept every origin and echo the request's `Origin` value in the response.
    Any,
    /// Accept every origin and answer with the literal `*`.
    Wildcard,
    /// Accept only the listed origins and echo the matching request `Origin` value.
    Restricted(HashSet<HeaderValue>),
}

impl CorsLayer {
    /// Allows cross-origin requests from `origin`.
    ///
    /// Passing `"*"` switches the layer to wildcard mode, in which every origin is accepted
    /// and responses carry `Access-Control-Allow-Origin: *`.
    ///
    /// Any other value must have the form `scheme://authority`, optionally followed by a
    /// single trailing `/` (which is dropped). It is added to the list of explicitly allowed
    /// origins; a previously configured wildcard or "any origin" policy is replaced by that
    /// list.
    ///
    /// # Panics
    ///
    /// Panics if the provided argument is not a valid `Origin`.
    pub fn allow_origin(mut self, origin: &str) -> Self {
        if origin == "*" {
            self.opts.origins = Some(Origins::Wildcard);
            return self;
        }

        let uri = Uri::from_str(origin).expect("invalid origin");
        let scheme = uri.scheme().expect("invalid origin (scheme missing)");
        let authority = uri.authority().expect("invalid origin (host missing)");

        if !matches!(uri.path_and_query().map(|p| p.as_str()), None | Some("/")) {
            panic!("invalid origin (origin must not contain a path/query)");
        }

        // Keep only `scheme://authority`; the `+ 3` accounts for the `://` separator. This
        // strips a trailing `/` if one was given.
        let origin_len = scheme.as_str().len() + authority.as_str().len() + 3;
        let origin = HeaderValue::from_str(&origin[..origin_len]).expect("invalid origin");

        // clippy flags `HeaderValue` as a mutable key type; the values stored here are
        // never mutated after insertion.
        #[allow(clippy::mutable_key_type)]
        let mut origins = match self.opts.origins.take() {
            Some(Origins::Restricted(origins)) => origins,
            // Any other policy (or none at all) starts a fresh restricted list.
            _ => HashSet::new(),
        };
        origins.insert(origin);
        self.opts.origins = Some(Origins::Restricted(origins));

        self
    }

    /// Allows cross-origin requests from every origin, echoing the request's `Origin` value
    /// back in `Access-Control-Allow-Origin`.
    ///
    /// # Warning
    ///
    /// This allows any website to use your endpoint. It is thus highly recommended to
    /// explicitly set a list of allowed origins via [`allow_origin`](Self::allow_origin)
    /// instead, or see if your use-case can be solved with a wildcard origin first
    /// (`allow_origin("*")`).
    pub fn allow_any_origin(mut self) -> Self {
        self.opts.origins = Some(Origins::Any);
        self
    }

    /// Adds `method` to the methods a preflight request may ask for.
    pub fn allow_method(mut self, method: Method) -> Self {
        self.opts.allowed_methods.insert(method.into());
        self
    }

    /// Adds `header` to the request headers a preflight request may ask for.
    pub fn allow_header(mut self, header: impl Into<HeaderName>) -> Self {
        self.opts.allowed_headers.insert(header.into());
        self
    }

    /// Adds `header` to the names advertised via `Access-Control-Expose-Headers`.
    pub fn expose_header(mut self, header: impl Into<HeaderName>) -> Self {
        self.opts.exposed_headers.insert(header.into());
        self
    }

    /// Makes responses carry `Access-Control-Allow-Credentials: true`.
    pub fn allow_credentials(mut self) -> Self {
        self.opts.allow_credentials = true;
        self
    }

    /// Sets the `Access-Control-Max-Age` value sent on preflight responses.
    ///
    /// Only the whole seconds of `age` are used.
    pub fn max_age(mut self, age: Duration) -> Self {
        self.opts.max_age = Some(age.as_secs());
        self
    }
}

impl CorsOptions {
    /// Returns `true` when `origin` is accepted by the configured origin policy.
    ///
    /// With no policy configured nothing is accepted; the "any" and wildcard policies accept
    /// everything; a restricted policy accepts only exact matches from its list.
    fn is_valid_origin(&self, origin: &HeaderValue) -> bool {
        self.origins.as_ref().map_or(false, |policy| match policy {
            Origins::Restricted(allowed) => allowed.contains(origin),
            Origins::Any | Origins::Wildcard => true,
        })
    }
}

impl<S> Cors<S> {
    async fn handle_request<F>(
        mut inner: S,
        opts: Arc<CorsOptions>,
        prebuild: Arc<PrebuildHeaders>,
        req: http::Request<hyper::Body>,
    ) -> Result<Response, Rejection>
    where
        S: Service<http::Request<hyper::Body>, Response = Response, Error = Rejection, Future = F>,
        F: Future<Output = Result<Response, Rejection>> + Send,
    {
        // It is only potentially a CORS request if it has an `Origin` header
        // https://fetch.spec.whatwg.org/#http-requests
        let origin = match req.headers().get(http::header::ORIGIN) {
            Some(origin) => origin,
            None => {
                // not a CORS request
                return inner.call(req).await.or_else(|err| err.into_response());
            }
        };

        if !opts.is_valid_origin(origin) {
            return StatusCode::UNAUTHORIZED.into_response();
        }

        let origin = if opts.origins == Some(Origins::Wildcard) {
            HeaderValue::from_static("*")
        } else {
            origin.to_owned()
        };

        let mut res = if req.method() == Method::OPTIONS {
            let headers = req.headers();
            let request_method = headers.typed_get::<AccessControlRequestMethod>();
            if !request_method
                .as_ref()
                .map(|m| opts.allowed_methods.contains(m))
                .unwrap_or(false)
            {
                return StatusCode::FORBIDDEN.into_response();
            }

            match headers.typed_get::<AccessControlRequestHeaders>() {
                Some(headers) => {
                    for name in headers.iter() {
                        if !opts.allowed_headers.contains(&name) {
                            return StatusCode::FORBIDDEN.into_response();
                        }
                    }
                }
                None => return StatusCode::FORBIDDEN.into_response(),
            }

            let mut res = http::Response::<hyper::Body>::default();
            let headers = res.headers_mut();
            headers.typed_insert(prebuild.allow_methods.clone());
            headers.typed_insert(prebuild.allow_headers.clone());

            if let Some(max_age) = opts.max_age {
                headers.insert(header::ACCESS_CONTROL_MAX_AGE, max_age.into());
            }

            res
        } else {
            inner.call(req).await.or_else(|err| err.into_response())?
        };

        // add CORS header to response
        let headers = res.headers_mut();
        if !headers.contains_key(http::header::ACCESS_CONTROL_ALLOW_ORIGIN) {
            headers.insert(http::header::ACCESS_CONTROL_ALLOW_ORIGIN, origin);
        }

        if opts.allow_credentials {
            headers.insert(
                header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
                HeaderValue::from_static("true"),
            );
        }
        if let Some(expose_headers) = &prebuild.expose_headers {
            headers.typed_insert(expose_headers.clone())
        }

        Ok(res)
    }
}

/// Heap-allocated, pinned, `Send` future resolving to `Result<R, E>`; used as the
/// [`Service::Future`] type of [`Cors`].
type BoxTrySendFuture<R, E> = Pin<Box<dyn Future<Output = Result<R, E>> + Send>>;

impl<S, F> Service<http::Request<hyper::Body>> for Cors<S>
where
    S: Service<http::Request<hyper::Body>, Response = Response, Error = Rejection, Future = F>
        + Send
        + Clone
        + 'static,
    F: Future<Output = Result<Response, Rejection>> + Send + 'static,
{
    type Response = Response;
    type Error = Rejection;
    type Future = BoxTrySendFuture<Self::Response, Self::Error>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: http::Request<hyper::Body>) -> Self::Future {
        Box::pin(Self::handle_request(
            self.inner.clone(),
            self.opts.clone(),
            self.headers.clone(),
            req,
        ))
    }
}

#[cfg(test)]
mod tests {
    use std::future::{ready, Ready};
    use std::sync::Mutex;

    use super::*;
    use tower::ServiceExt;

    /// Inner service that hands out one canned response.
    struct MockService {
        res: Arc<Mutex<Option<Response>>>,
    }

    impl Clone for MockService {
        /// Moves the canned response into the clone, leaving the original empty.
        fn clone(&self) -> Self {
            let taken = self.res.lock().unwrap().take();
            MockService {
                res: Arc::new(Mutex::new(taken)),
            }
        }
    }

    impl Service<http::Request<hyper::Body>> for MockService {
        type Response = Response;
        type Error = Rejection;
        type Future = Ready<Result<Self::Response, Self::Error>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _req: http::Request<hyper::Body>) -> Self::Future {
            let response = self
                .res
                .lock()
                .unwrap()
                .take()
                .expect("response already taken");
            ready(Ok(response))
        }
    }

    /// Sends a request carrying `origin` through `cors` wrapped around a [`MockService`]
    /// that answers with an empty default response.
    async fn call_service(cors: &CorsLayer, origin: &str) -> crate::response::Response {
        let canned = http::Response::builder()
            .body(hyper::Body::empty())
            .unwrap();
        let mut cors = cors.layer(MockService {
            res: Arc::new(Mutex::new(Some(canned))),
        });

        let req = http::Request::builder()
            .header("Origin", origin)
            .body(hyper::Body::empty())
            .unwrap();

        cors.ready().await.unwrap();
        cors.call(req).await.unwrap()
    }

    /// Returns the `Access-Control-Allow-Origin` value of `res` as a string, if present.
    fn allow_origin_header(res: &Response) -> Option<&str> {
        res.headers()
            .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
            .and_then(|v| v.to_str().ok())
    }

    #[tokio::test]
    async fn test_origin_not_allowed() {
        let c = CorsLayer::default();

        let res = call_service(&c, "https://solarsail.dev").await;
        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);

        let res = call_service(&c, "https://something.else").await;
        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_origin_allowed() {
        let c = CorsLayer::default()
            .allow_origin("https://something.else")
            .allow_origin("https://solarsail.dev");

        let res = call_service(&c, "https://solarsail.dev").await;
        assert_eq!(allow_origin_header(&res), Some("https://solarsail.dev"));
    }

    #[tokio::test]
    async fn test_any_origin() {
        let c = CorsLayer::default().allow_any_origin();

        // The request origin is echoed back verbatim.
        let res = call_service(&c, "https://solarsail.dev").await;
        assert_eq!(allow_origin_header(&res), Some("https://solarsail.dev"));

        let res = call_service(&c, "https://www.rust-lang.org").await;
        assert_eq!(allow_origin_header(&res), Some("https://www.rust-lang.org"));
    }

    #[tokio::test]
    async fn test_wildcard_origin() {
        let c = CorsLayer::default().allow_origin("*");

        // Wildcard mode always answers `*`, whatever the request origin.
        let res = call_service(&c, "https://solarsail.dev").await;
        assert_eq!(allow_origin_header(&res), Some("*"));

        let res = call_service(&c, "https://www.rust-lang.org").await;
        assert_eq!(allow_origin_header(&res), Some("*"));
    }
}
