use std::net::{Shutdown, SocketAddr, TcpStream, ToSocketAddrs};
use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc;
use std::task::{Context, Poll};

use anyhow::{anyhow, bail, Context as _, Error, Result};
use async_lock::{Barrier, Mutex};
use async_rustls::{TlsStream, webpki::DNSNameRef};
use http::Uri;
use hyper::{Body, Client};
use smol::{io, prelude::*, Async, Task};
use tokio::io::ReadBuf;
use tracing::{trace, warn};

use crate::resolver;

/// Returns a new Hyper HTTP client:
/// - Using smol for the connection and async runtime
/// - Using the provided Originz Resolver for resolving any hostnames
///
/// `get_ipv6` and `udp_size` are passed through unchanged to every `resolve_str` call made on
/// `resolver`. `http2_only` is passed through to the Hyper client builder.
pub fn client_originz(
    mut resolver: resolver::Resolver,
    http2_only: bool,
    get_ipv6: bool,
    udp_size: u16,
) -> Client<SmolConnector> {
    let (resolver_tx, resolver_rx) = async_channel::bounded::<ResolverQuery>(32);

    // All lookups for the SmolConnector are funnelled through this one task, which owns the Resolver.
    // This is mainly to get around locking + async issues across the spawned tasks for each query;
    // the only other alternative would be to create a whole new Resolver with each query.
    // The task handle is stored (behind an Arc) in the returned connector, and the loop below also
    // ends on its own once every clone of resolver_tx has been dropped.
    let resolver_task = Arc::new(smol::spawn(async move {
        trace!("Internal resolver waiting for requests");
        // recv() returns Err once the channel is closed and has no more messages
        while let Ok(msg) = resolver_rx.recv().await {
            trace!("Internal resolver: {}", msg.host);
            let endpoint_str = format!("{}:{}", msg.host, msg.port);
            let lookup_result = match SocketAddr::from_str(endpoint_str.as_str()) {
                // The "host" already parses as an IP, so hand it back as-is instead of resolving it.
                // This effectively mirrors the behavior of the system resolver via client_system().
                // This is not in the user query path, and should only come up if the admin
                // e.g. gives an upstream as "https://127.0.0.1"
                Ok(socket_addr) => {
                    trace!(
                        "Internal resolver IP shortcut: {} = {:?}",
                        endpoint_str,
                        socket_addr
                    );
                    Ok(socket_addr)
                }
                // It doesn't look like a socket address, so do the resolve.
                Err(_) => {
                    let resolved = resolver
                        .resolve_str(&msg.host, msg.port, get_ipv6, udp_size)
                        .await;
                    if let Err(e) = &resolved {
                        warn!("Internal resolver lookup failed: {:?}", e);
                    } else {
                        trace!("Internal resolver: {} = {:?}", endpoint_str, resolved);
                    }
                    resolved
                }
            };
            // Publish the result first, then meet the requestor at the barrier.
            *msg.result.lock().await = Some(lookup_result);
            msg.result_barrier.wait().await;
        }
        trace!("Internal resolver exiting");
    }));

    let connector = SmolConnector {
        _resolver_task: resolver_task,
        resolver_tx,
    };
    Client::builder()
        .executor(SmolExecutor)
        .http2_only(http2_only)
        .build::<_, Body>(connector)
}

/// Returns a new Hyper HTTP client:
/// - Using smol for the connection and async runtime
/// - Using the system resolver for resolving any hostnames
///
/// Only the first address returned by the system resolver is used for each lookup.
/// `http2_only` is passed through to the Hyper client builder.
pub fn client_system(http2_only: bool) -> Client<SmolConnector> {
    let (resolver_tx, resolver_rx) = async_channel::bounded::<ResolverQuery>(32);

    // A dedicated lookup task isn't strictly needed for the system resolver,
    // but it keeps things in line with client_originz().
    let resolver_task = Arc::new(smol::spawn(async move {
        trace!("System resolver waiting for requests");
        // recv() returns Err once the channel is closed and has no more messages
        while let Ok(msg) = resolver_rx.recv().await {
            trace!("System resolver: {}", msg.host);
            // Owned copies for the closure, which is handed to smol::unblock.
            let lookup_host = msg.host.clone();
            let lookup_port = msg.port;
            let lookup_result = smol::unblock(move || {
                (lookup_host.as_str(), lookup_port).to_socket_addrs()
            })
            .await
            .with_context(|| format!("Failed to query for hostname {}", msg.host))
            .and_then(|mut socket_addrs| {
                socket_addrs
                    .next()
                    .ok_or_else(|| anyhow!("No results for hostname {}", msg.host))
            });

            trace!("System resolver: {} = {:?}", msg.host, lookup_result);
            // Publish the result first, then meet the requestor at the barrier.
            *msg.result.lock().await = Some(lookup_result);
            msg.result_barrier.wait().await;
        }
        trace!("System resolver exiting");
    }));

    let connector = SmolConnector {
        _resolver_task: resolver_task,
        resolver_tx,
    };
    Client::builder()
        .executor(SmolExecutor)
        .http2_only(http2_only)
        .build::<_, Body>(connector)
}

/// Hyper executor that runs every future it is given as a detached smol task.
#[derive(Clone)]
struct SmolExecutor;

impl<F> hyper::rt::Executor<F> for SmolExecutor
where
    F: Future + Send + 'static,
{
    /// Spawns `fut` on the smol executor and detaches it; the future's output is discarded.
    fn execute(&self, fut: F) {
        let task = smol::spawn(async move {
            drop(fut.await);
        });
        task.detach();
    }
}

/// The request for a host to be resolved, along with an output for returning the response.
///
/// The requestor builds one of these, sends it over the resolver channel, and then waits on
/// `result_barrier`. The resolver task stores its answer in `result` before waiting on the same
/// barrier, so the answer is in place by the time the requestor's wait returns.
#[derive(Debug)]
struct ResolverQuery {
    /// The hostname to look up
    host: String,
    /// The port to include in the resolved result
    port: u16,
    /// Barrier to wait for the result to appear. The requestor should wait on this before accessing result.
    /// Shared between the requestor and the resolver task, each of which waits on it once per query.
    result_barrier: Arc<Barrier>,
    /// Where the result should go. Should be an error if the hostname could not be resolved (e.g. not found).
    /// Starts out as `None` and is filled in by the resolver task.
    result: Arc<Mutex<Option<Result<SocketAddr>>>>,
}

/// Connects to URLs.
///
/// Hostnames are resolved by sending a `ResolverQuery` to a background resolver task.
/// Clones share the same resolver task handle and the same query channel.
#[derive(Clone)]
pub struct SmolConnector {
    /// Handle to keep the resolver task from dying prematurely
    _resolver_task: Arc<Task<()>>,
    /// Channel for sending requests to the resolver task
    resolver_tx: async_channel::Sender<ResolverQuery>,
}

impl hyper::service::Service<Uri> for SmolConnector {
    type Response = SmolStream;
    type Error = Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    fn call(&mut self, uri: Uri) -> Self::Future {
        // Get copy for async move:
        let resolver_tx_copy = self.resolver_tx.clone();

        Box::pin(async move {
            let host = uri
                .host()
                .with_context(|| format!("Cannot parse host: {:?}", uri))?;
            // Release when both the requestor and requestee have called wait
            let result_barrier = Arc::new(Barrier::new(2));
            // Where the requestee will store the result before calling result_barrier.wait
            let result = Arc::new(Mutex::new(None));
            match uri.scheme_str() {
                Some("http") => {
                    // Send lookup, with place to give us back the result:
                    trace!("HTTP lookup: {}", host);
                    let query = ResolverQuery {
                        host: host.to_string(),
                        port: uri.port_u16().unwrap_or(80),
                        result_barrier: result_barrier.clone(),
                        result: result.clone(),
                    };
                    resolver_tx_copy
                        .send(query)
                        .await
                        .context("Failed to send HTTP resolver query")?;
                    trace!("HTTP lookup sent");
                    // Wait on the barrier to complete
                    result_barrier.wait().await;
                    // Barrier has completed, get the stored result.
                    // Jump through weird reference hoops to get the SocketAddr value out of the mutex.
                    match result
                        .lock()
                        .await
                        .as_ref()
                        .expect("Missing resolve result following barrier")
                    {
                        Ok(socket_addr) => {
                            let stream = Async::<TcpStream>::connect(socket_addr.clone()).await?;
                            Ok(SmolStream::Plain(stream))
                        }
                        // e is a reference, and anyhow doesn't like using with_context with it. So just give up and create a new error.
                        Err(e) => Err(anyhow!("Failed to resolve host {:?}: {}", uri, e)),
                    }
                }
                Some("https") => {
                    // Send lookup, with place to give us back the result:
                    trace!("HTTPS lookup: {}", host);
                    let query = ResolverQuery {
                        host: host.to_string(),
                        port: uri.port_u16().unwrap_or(443),
                        result_barrier: result_barrier.clone(),
                        result: result.clone(),
                    };
                    resolver_tx_copy
                        .send(query)
                        .await
                        .context("Failed to send HTTPS resolver query")?;
                    trace!("HTTPS lookup sent");
                    // Wait on the barrier to complete
                    result_barrier.wait().await;
                    // Barrier has completed, get the stored result.
                    // Jump through weird reference hoops to get the SocketAddr value out of the mutex.
                    match result
                        .lock()
                        .await
                        .as_ref()
                        .expect("Missing resolve result following barrier")
                    {
                        Ok(socket_addr) => {
                            let stream = Async::<TcpStream>::connect(socket_addr.clone()).await?;
                            let mut client_config = async_rustls::rustls::ClientConfig::new();
                            let mut root_certs = async_rustls::rustls::RootCertStore::empty();
                            for cert in rustls_native_certs::load_native_certs().expect("Failed to load native TLS certs") {
                                root_certs
                                    .add(&async_rustls::rustls::Certificate(cert.0))
                                    .expect("Failed to add native TLS cert");
                            }
                            // Required for http2/DoH, otherwise we get 'http2 error: protocol error: frame with invalid size'
                            // In particular, we're stuck with rustls for now because rust-native-tls doesn't support configuring this.
                            client_config.alpn_protocols = vec![b"h2".to_vec()];
                            // Disabled for now: Was previously using ct-logs package, but bizarre API mismatches broke it.
                            client_config.ct_logs = None;
                            let connector =
                                async_rustls::TlsConnector::from(Arc::new(client_config));
                            if let Ok(dns_name) = webpki::DnsNameRef::try_from_ascii_str(host) {
                                // Convert new webpki::DnsNameRef (webpki 0.22) to old webpki::DNSNameRef (webpki 0.21) used by async-rustls
                                // See also https://github.com/smol-rs/async-rustls/blob/master/Cargo.toml
                                let dns_name_old: DNSNameRef = DNSNameRef::try_from_ascii(dns_name.as_ref())
                                    .expect("Converting new DnsNameRef to old DNSNameRef failed");
                                let stream = connector.connect(dns_name_old, stream).await?;
                                Ok(SmolStream::Tls(async_rustls::TlsStream::Client(stream)))
                            } else {
                                // Uh-oh, looks like we're trying to connect to an IP. Explain the issue with the underlying library and how to work around it.
                                bail!(
                                    "Unable to parse TLS endpoint: {}
rustls/webpki still don't support IP endpoints. See also: https://github.com/briansmith/webpki/issues/54 and https://github.com/ctz/rustls/issues/184
Try using a hostname instead, e.g. 9.9.9.9 => dns.quad9.net, 8.8.8.8 => dns.google, or 1.1.1.1 => cloudflare-dns.com
If this hostname is a DoH or DoT upstream, you will also need to include at least one IP-based 'regular' UDP/TCP upstream as a fallback so that the DoH or DoT hostname can itself be resolved.",
                                    host
                                );
                            }
                        }
                        // e is a reference, and anyhow doesn't like using with_context with it. So just give up and create a new error.
                        Err(e) => Err(anyhow!("Failed to resolve host {:?}: {}", uri, e)),
                    }
                }
                scheme => bail!("Unsupported scheme: {:?}", scheme),
            }
        })
    }
}

/// A TCP or TCP+TLS connection.
///
/// This is the stream type handed to Hyper by `SmolConnector`; the tokio I/O traits implemented
/// for it forward to whichever variant is active.
pub enum SmolStream {
    /// A plain TCP connection.
    Plain(Async<TcpStream>),

    /// A TCP connection secured by TLS.
    Tls(TlsStream<Async<TcpStream>>),
}

impl hyper::client::connect::Connection for SmolStream {
    /// Returns default connection metadata; nothing specific to the stream variant is reported.
    fn connected(&self) -> hyper::client::connect::Connected {
        use hyper::client::connect::Connected;

        Connected::new()
    }
}

impl tokio::io::AsyncRead for SmolStream {
    /// Reads into the unfilled part of `buf` by delegating to the wrapped stream's slice-based
    /// `poll_read`, then advances `buf` by the number of bytes that call reported.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        // The inner streams read into a plain byte slice, so expose the unfilled region as one.
        let unfilled = buf.initialize_unfilled();
        let polled = match self.get_mut() {
            SmolStream::Plain(plain) => Pin::new(plain).poll_read(cx, unfilled),
            SmolStream::Tls(tls) => Pin::new(tls).poll_read(cx, unfilled),
        };
        match polled {
            Poll::Ready(Ok(size)) => {
                buf.advance(size);
                Poll::Ready(Ok(()))
            }
            Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
            Poll::Pending => Poll::Pending,
        }
    }
}

impl tokio::io::AsyncWrite for SmolStream {
    /// Forwards the write to the wrapped stream.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        match self.get_mut() {
            SmolStream::Plain(plain) => Pin::new(plain).poll_write(cx, buf),
            SmolStream::Tls(tls) => Pin::new(tls).poll_write(cx, buf),
        }
    }

    /// Forwards the flush to the wrapped stream.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        match self.get_mut() {
            SmolStream::Plain(plain) => Pin::new(plain).poll_flush(cx),
            SmolStream::Tls(tls) => Pin::new(tls).poll_flush(cx),
        }
    }

    /// Shuts down the write side: a plain connection shuts down the write half of the underlying
    /// socket immediately, while a TLS connection is closed via the TLS stream's `poll_close`.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        match self.get_mut() {
            SmolStream::Plain(plain) => Poll::Ready(plain.get_ref().shutdown(Shutdown::Write)),
            SmolStream::Tls(tls) => Pin::new(tls).poll_close(cx),
        }
    }
}
