use std::io::{self, Write};
use std::time::Duration;

use anyhow::{bail, Context, Result};
use async_io::Timer;
use async_trait::async_trait;
use bytes::{BufMut, BytesMut};
use futures_lite::FutureExt;
use hyper::header;
use hyper::{Body, Client as HttpClient, Method, Uri};
use lazy_static::lazy_static;
use rand::Rng;
use scopeguard;
use tracing::debug;

use crate::client::DnsClient;
use crate::codec::{decoder::DNSMessageDecoder, encoder::DNSMessageEncoder, message};
use crate::fetcher::Fetcher;
use crate::hyper_smol;
use crate::resolver;
use crate::specs::message::Message;

/// Payload size limit, in bytes: handed to the fetcher, used as the response buffer capacity,
/// and advertised to the encoder as the "udp size". This is the largest value a `u16` holds.
static MAX_HTTP_BYTES: u16 = u16::MAX;

lazy_static! {
    /// Process-wide DNS message encoder shared by every query. The encoder currently
    /// carries no state, so one lazily-created instance is enough.
    static ref ENCODER: DNSMessageEncoder = {
        DNSMessageEncoder::new()
    };
}

/// DNS client that queries a server over HTTPS (DoH/RFC8484).
pub struct Client {
    /// DoH endpoint that every query is POSTed to.
    server_url: Uri,
    /// Builds outgoing requests and copies response payloads out of HTTP responses.
    fetcher: Fetcher,
    /// HTTP client the requests are sent through.
    client: HttpClient<hyper_smol::SmolConnector>,
    /// Per-query timeout, in milliseconds.
    timeout_ms: u64,
    /// Scratch buffer reused across queries for response payloads; allocated with
    /// `MAX_HTTP_BYTES` capacity.
    response_buffer: BytesMut,
}

impl Client {
    /// Constructs a new `Client` that sends its queries to `server_url`.
    ///
    /// `resolver` is handed to the HTTP connector (presumably to look up the DoH server's own
    /// hostname — TODO confirm), and `timeout_ms` bounds each individual query.
    /// The current implementation never returns an error.
    pub fn new(server_url: Uri, resolver: resolver::Resolver, timeout_ms: u64) -> Result<Self> {
        let payload_limit = MAX_HTTP_BYTES as usize;

        // hyper rejects these requests with "request has unsupported HTTP version" unless
        // HTTP/2 is enabled here AND "http2_only(true)" is set in the Client builder.
        let fetcher = Fetcher::new(payload_limit, Some("application/dns-message".to_string()))
            .use_http_2();

        // NOTE(review): the meaning of these positional arguments is defined in `hyper_smol`;
        // presumably one of the flags enables the http2_only mode mentioned above — confirm.
        let client = hyper_smol::client_originz(resolver, true, false, 4096);

        Ok(Self {
            server_url,
            fetcher,
            client,
            timeout_ms,
            response_buffer: BytesMut::with_capacity(payload_limit),
        })
    }
}

#[async_trait]
impl DnsClient for Client {
    async fn query(&mut self, request: &Message, query_buffer: &mut BytesMut) -> Result<Option<Message>> {
        // Just use our max for the "udp size"
        ENCODER.encode(request, Some(MAX_HTTP_BYTES), query_buffer)?;

        let request_id = rand::thread_rng().gen::<u16>();
        message::update_message_id(request_id, query_buffer, 0)?;

        // Ensure that response_buffer size is reset when we're done with it,
        // regardless of success or error. The socket read is based on len, not capacity.
        let mut response_buffer = scopeguard::guard(&mut self.response_buffer, |buf| {
            buf.clear();
        });

        debug!(
            "Raw request to {} ({}b): {:02X?}",
            self.server_url,
            query_buffer.len(),
            &query_buffer[..]
        );

        // Hyper apparently requires a static copy of the body content? No idea why, and don't care to find out.
        // Maybe we can clean this up by switching to something with a better API someday.
        let request_copy = hyper::body::Bytes::from(query_buffer.clone());
        let request = self
            .fetcher
            .request_builder(&Method::POST, &self.server_url)
            .header(header::CONTENT_TYPE, "application/dns-message")
            .header(header::CONTENT_LENGTH, request_copy.len())
            .body(Body::from(request_copy))
            .context("Failed to build DoH request")?;

        // Copy for async call:
        let timeout_ms = self.timeout_ms;
        let mut response = self
            .client
            .request(request)
            .or(async {
                Timer::after(Duration::from_millis(timeout_ms)).await;
                // hyper keeps error types crate-private. Jump through hoops to produce an Ok response with an error.
                let response = hyper::Response::new(hyper::Body::empty());
                let (mut parts, body) = response.into_parts();
                parts.status = http::StatusCode::GATEWAY_TIMEOUT;
                Ok(hyper::Response::from_parts(parts, body))
            })
            .await
            .context("DoH query failed")?;

        if !response.status().is_success() {
            bail!(
                "HTTP POST to {} returned status: {}",
                self.server_url,
                response.status()
            );
        }

        {
            // Write response payload into response_buffer
            let mut writer = BytesWriter::new(&mut response_buffer);
            self.fetcher
                .write_response(&self.server_url.to_string(), &mut writer, &mut response)
                .await?;
        }

        debug!(
            "Raw response from {} ({}b): {:02X?}",
            self.server_url,
            response_buffer.len(),
            &response_buffer[..]
        );

        match DNSMessageDecoder::new().decode(&response_buffer[..]) {
            Ok(Some(response)) => {
                debug!("Response from {}: {}", self.server_url, response);

                if response.header.truncated {
                    // Message claims to be truncated, shouldn't happen but let's bail anyway
                    return Ok(None);
                }
                if response.header.id != request_id {
                    bail!(
                        "Returned transaction id {:?} doesn't match sent {:?}",
                        response.header.id,
                        request_id
                    );
                }

                Ok(Some(response))
            }
            Ok(None) => {
                // Message was likely corrupted somehow, despite us receiving all the data in the payload
                debug!(
                    "Unable to parse response from server={} to request={:02X?}: {:02X?}",
                    self.server_url,
                    &query_buffer[..],
                    &response_buffer[..],
                );
                Ok(None)
            }
            Err(e) => {
                // Other parse error
                Err(e).context(format!(
                    "Failed to parse response from server={} to request={:02X?}: {:02X?}",
                    self.server_url,
                    &query_buffer[..],
                    &response_buffer[..],
                ))
            }
        }
    }
}

/// `io::Write` adapter that appends into a borrowed `BytesMut` without ever growing it.
///
/// A write that would not fit in the buffer's remaining spare capacity fails with
/// `io::ErrorKind::InvalidInput` instead of reallocating. This bounds a downloaded payload by
/// the capacity the buffer was created with.
struct BytesWriter<'a> {
    inner: &'a mut BytesMut,
}

impl<'a> BytesWriter<'a> {
    /// Wraps `inner`; written bytes are appended after its current contents.
    fn new(inner: &'a mut BytesMut) -> BytesWriter<'a> {
        BytesWriter { inner }
    }
}

impl<'a> Write for BytesWriter<'a> {
    /// Appends all of `buf`, or returns an error without writing anything if `buf` does not
    /// fit in the buffer's spare capacity.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        // Measure the real spare capacity instead of using `BufMut::remaining_mut`.
        // For `BytesMut` in bytes 1.x that method reports `usize::MAX - len`, because the
        // buffer grows on demand, so it would never enforce the bound.
        let remaining = self.inner.capacity() - self.inner.len();
        if buf.len() > remaining {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!(
                    "Unable to write {} bytes into buffer: {}/{} remaining",
                    buf.len(),
                    remaining,
                    self.inner.capacity()
                ),
            ));
        }
        self.inner.put_slice(buf);
        Ok(buf.len())
    }

    /// Nothing is buffered beyond `inner` itself, so flushing is a no-op.
    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}
