use std::net::{SocketAddr, ToSocketAddrs};
use std::time::Duration;

use anyhow::{bail, Context, Result};
use async_io::Timer;
use async_net::UdpSocket;
use async_trait::async_trait;
use bytes::BytesMut;
use futures_lite::FutureExt;
use lazy_static::lazy_static;
use rand::Rng;
use scopeguard;
use tracing::{debug, trace, warn};

use crate::client::DnsClient;
use crate::codec::{decoder::DNSMessageDecoder, encoder::DNSMessageEncoder, message};
use crate::specs::message::Message;

/// DNS Client that queries a server over UDP.
/// The timeout logic uses base2 exponential retry (1s, 2s, 4s, ...).
/// If a UDP response comes back truncated, `query` returns `Ok(None)` so the caller can retry
/// the request over TCP.
pub struct Client {
    /// Address of the DNS server every query is sent to.
    dns_server: SocketAddr,
    /// UDP payload size advertised in outgoing requests; starts at 4096.
    last_udp_size: u16,
    /// Reusable receive buffer for UDP responses; allocated as 4096 zeroed bytes.
    response_buffer: Vec<u8>,
    /// Approximate overall timeout budget for a query, in milliseconds.
    timeout_ms: u64,
}

impl Client {
    /// Creates a `Client` that sends every query to `dns_server`.
    ///
    /// `timeout_ms` is an approximate overall budget for one query. Retries use doubling
    /// per-attempt timeouts (1s, 2s, 4s, ...), so the real worst case is rounded up to a sum
    /// of those steps; for example 10000 -> 15000 (1s + 2s + 4s + 8s).
    pub fn new(dns_server: SocketAddr, timeout_ms: u64) -> Self {
        const INITIAL_UDP_SIZE: u16 = 4096;
        let response_buffer = vec![0u8; usize::from(INITIAL_UDP_SIZE)];
        Self {
            timeout_ms,
            response_buffer,
            last_udp_size: INITIAL_UDP_SIZE,
            dns_server,
        }
    }
}

lazy_static! {
    /// Shared request encoder, initialized lazily on first use and reused by every
    /// `query` call. The encoder currently doesn't have state, so a single instance suffices.
    static ref ENCODER: DNSMessageEncoder = DNSMessageEncoder::new();
}

#[async_trait]
impl DnsClient for Client {
    /// Sends `request` to `self.dns_server` over UDP and decodes the reply.
    ///
    /// The request is encoded into `query_buffer`, advertising `self.last_udp_size` as the
    /// UDP payload size. `send_recv_exponential_backoff` then rewrites the message id with a
    /// fresh random value on every (re)send.
    ///
    /// Returns:
    /// - `Ok(Some(response))` for a non-truncated response whose transaction id matches.
    /// - `Ok(None)` when the response has the truncated flag set or the decoder yields no
    ///   message, so the caller can fall back to TCP.
    ///
    /// # Errors
    /// Fails if encoding fails, the UDP exchange fails or times out, the response transaction
    /// id does not match the one sent, or the response cannot be decoded.
    async fn query(
        &mut self,
        request: &Message,
        query_buffer: &mut BytesMut,
    ) -> Result<Option<Message>> {
        // Length `self.response_buffer` is restored to after every query, and the largest UDP
        // payload size we will advertise. Advertising more than we can receive would let the
        // server send datagrams that `recv_from` truncates or rejects.
        const MAX_UDP_SIZE: u16 = 4096;
        // RFC 6891 §6.2.5: advertised payload sizes below 512 are treated as 512.
        const MIN_UDP_SIZE: u16 = 512;

        // Ensure the request to our server has the correct UDP size in the request.
        ENCODER.encode(request, Some(self.last_udp_size), query_buffer)?;

        // Filled in by the send loop with the id that was actually sent.
        let mut request_id: u16 = 0;

        // Restore response_buffer to its full length when we're done with it, regardless of
        // success, error, or cancellation. The socket read is based on len, not capacity.
        let mut response_buffer = scopeguard::guard(&mut self.response_buffer, |buf| {
            buf.resize(usize::from(MAX_UDP_SIZE), 0);
        });
        let response_size = send_recv_exponential_backoff(
            &self.dns_server,
            query_buffer,
            &mut request_id,
            response_buffer.as_mut_slice(),
            self.timeout_ms,
        )
        .await?;
        // Shorten to actual size received so the decoder only sees real bytes.
        response_buffer.truncate(response_size);

        debug!(
            "Raw response from {:?} ({}b): {:02X?}",
            self.dns_server,
            response_buffer.len(),
            &response_buffer[..]
        );

        match DNSMessageDecoder::new().decode(&response_buffer) {
            Ok(Some(response)) => {
                debug!("Response from {:?}: {}", self.dns_server, response);

                // NOTE(review): the truncated flag is honoured before the id check; a
                // mismatched truncated reply only causes a TCP retry upstream.
                if response.header.truncated {
                    // Message claims to be truncated
                    return Ok(None);
                }
                if response.header.id != request_id {
                    bail!(
                        "Returned transaction id {:?} doesn't match sent {:?}",
                        response.header.id,
                        request_id
                    );
                }

                // After passing validation, update udp_size for the next request to this
                // server. Clamp it so we never advertise more than our receive buffer holds
                // nor less than the EDNS(0) minimum.
                if let Some(opt) = &response.opt {
                    let udp_size = opt.udp_size.clamp(MIN_UDP_SIZE, MAX_UDP_SIZE);
                    trace!(
                        "Using udp_size={} for server={}",
                        udp_size,
                        self.dns_server
                    );
                    self.last_udp_size = udp_size;
                }

                Ok(Some(response))
            }
            Ok(None) => {
                // Message was likely truncated, upstream can fall back to TCP
                debug!(
                    "Unable to parse response from server={} to request={:02X?}: {:02X?}",
                    self.dns_server,
                    &query_buffer[..],
                    &response_buffer[..],
                );
                Ok(None)
            }
            // Other parse error
            Err(e) => Err(e).with_context(|| {
                format!(
                    "Failed to parse response from server={} to request={:02X?}: {:02X?}",
                    self.dns_server,
                    &query_buffer[..],
                    &response_buffer[..],
                )
            }),
        }
    }
}

async fn send_recv_exponential_backoff(
    dest: &SocketAddr,
    query_buffer: &mut BytesMut,
    request_id: &mut u16,
    mut response_buffer: &mut [u8],
    total_timeout_ms: u64,
) -> Result<usize> {
    // Start at 1s, then 2s, then 4s, ...
    let mut remaining_timeout_ms = total_timeout_ms;
    let mut timeout_ms = 1000;
    loop {
        // NOTE: This assumes that port 0 results in a random port each time.
        // In particular we DONT want it to just increment by 1 or something each time.
        // Apparently this is OS-specific but Linux at least should do what we want.
        let client_addr = "0.0.0.0:0".to_socket_addrs()?.next().unwrap();
        let conn = UdpSocket::bind(client_addr).await?;

        // We regenerate the request ID on every retry. We're changing the client port each time,
        // so scrambling the request ID shouldn't result in "old" mismatched responses anyway.
        // This reduces the likelihood of someone trying to poison our cache by sending a request
        // and then flooding us with responses that match that request's message id.
        *request_id = rand::thread_rng().gen::<u16>();
        message::update_message_id(*request_id, query_buffer, 0)?;

        debug!(
            "Raw request to {:?} ({}b): {:02X?}",
            &dest,
            query_buffer.len(),
            &query_buffer[..]
        );
        // (Re)send request.
        let _sendsize = conn.send_to(query_buffer.as_ref(), dest).await?;

        match conn
            .recv_from(&mut response_buffer)
            .or(async {
                Timer::after(Duration::from_millis(timeout_ms)).await;
                return Err(std::io::Error::new(
                    std::io::ErrorKind::TimedOut,
                    "UDP receive timed out",
                ));
            })
            .await
        {
            // Got a response from somewhere
            Ok((recvsize, recvdest)) => {
                // Before returning, check that the response is from who we're waiting for
                if *dest == recvdest {
                    return Ok(recvsize);
                }
                // If it doesn't match, resend and resume waiting, unless this was the last retry
                warn!(
                    "Response origin {:?} doesn't match request target {:?}",
                    recvdest, dest
                );
            }
            Err(e) => {
                if crate::client::is_timeout(e.kind()) {
                    // Timeout occurred, try again (or exit loop)
                    debug!("UDP request to {} timed out after {}ms", dest, timeout_ms);
                } else {
                    // A different error occurred, give up
                    return Err(e).with_context(|| {
                        format!("Failed to receive DNS response from {}", dest)
                    })?;
                }
            }
        }

        timeout_ms *= 2;
        if remaining_timeout_ms == 0 {
            // No retries left, give up
            bail!("Timed out waiting for response from {:?}", dest);
        } else if remaining_timeout_ms <= timeout_ms {
            // Last retry
            remaining_timeout_ms = 0;
        } else {
            // More retries left after this one
            remaining_timeout_ms -= timeout_ms;
        }
    }
}
