use std::convert::TryFrom;
use std::net::SocketAddr;
use std::time::Duration;

use anyhow::{anyhow, bail, Context, Result};
use async_io::Timer;
use async_net::TcpStream;
use async_trait::async_trait;
use byteorder::{BigEndian, ByteOrder};
use bytes::{BufMut, BytesMut};
use futures_lite::{AsyncReadExt, AsyncWriteExt, FutureExt};
use lazy_static::lazy_static;
use rand::Rng;
use tracing::debug;

use crate::client::DnsClient;
use crate::codec::{decoder::DNSMessageDecoder, encoder::DNSMessageEncoder, message};
use crate::specs::message::Message;

/// Largest DNS message that fits in the 16-bit TCP length prefix (65535 bytes).
///
/// Also advertised as the "UDP payload size" when encoding requests sent over TCP.
const MAX_TCP_BYTES: u16 = u16::MAX;

lazy_static! {
    /// Shared message encoder, built on first use.
    ///
    /// A single instance is reused for every query because the encoder currently holds no state.
    static ref ENCODER: DNSMessageEncoder = DNSMessageEncoder::new();
}

/// DNS client that queries a single upstream server over TCP.
///
/// The TCP connection is opened lazily by the first query and reused by later queries.
pub struct Client {
    /// Upstream DNS server that queries are sent to.
    dns_server: SocketAddr,
    /// Open connection to `dns_server`.
    ///
    /// `None` before the first query connects, and after `query` discards a connection
    /// following an error.
    conn: Option<TcpStream>,
    /// Reusable buffer that receives each response payload (without the 2-byte length prefix).
    response_buffer: BytesMut,
    /// Timeout in milliseconds, applied separately to connecting, writing the request, and
    /// each read of the response.
    timeout_ms: u64,
}

/// Construction and connection management for the TCP DNS [`Client`].
impl Client {
    /// Constructs a new `Client` that will query the specified `dns_server`.
    ///
    /// No connection is made here; it is established by the first query. `timeout_ms`
    /// bounds each network operation (connect, write, read) individually.
    pub fn new(dns_server: SocketAddr, timeout_ms: u64) -> Self {
        Client {
            dns_server,
            conn: None,
            // Sized for the largest message the TCP length prefix can describe, so reading a
            // response never has to grow the buffer.
            response_buffer: BytesMut::with_capacity(usize::from(MAX_TCP_BYTES)),
            timeout_ms,
        }
    }

    /// Opens a TCP connection to `dns_server` and stores it in `self.conn`.
    ///
    /// # Errors
    ///
    /// Fails if the connection cannot be established, or if it does not complete within
    /// `timeout_ms` milliseconds.
    async fn connect(&mut self) -> Result<()> {
        let timeout = Duration::from_millis(self.timeout_ms);
        let conn = TcpStream::connect(self.dns_server)
            // Race the connect against a timer; whichever finishes first wins.
            .or(async {
                Timer::after(timeout).await;
                Err(std::io::Error::new(
                    std::io::ErrorKind::TimedOut,
                    "TCP connect timed out",
                ))
            })
            .await
            .map_err(|ioerr| {
                anyhow!(
                    "Failed to connect to TCP upstream {:?}: {}",
                    self.dns_server,
                    ioerr
                )
            })?;

        self.conn = Some(conn);
        Ok(())
    }
}

/// Races the I/O future `io` against a `timeout_ms` timer.
///
/// Resolves to `io`'s result if it finishes first. Otherwise it resolves to an error of kind
/// [`std::io::ErrorKind::TimedOut`] whose message is `what`. The timer covers the whole
/// future (e.g. all the internal iterations of `write_all`/`read_exact`), not each syscall.
/// So an upstream that trickles data one byte at a time cannot stretch the operation past
/// the timeout.
async fn io_with_timeout<T, F>(io: F, timeout_ms: u64, what: &'static str) -> std::io::Result<T>
where
    F: std::future::Future<Output = std::io::Result<T>>,
{
    io.or(async {
        Timer::after(Duration::from_millis(timeout_ms)).await;
        Err(std::io::Error::new(std::io::ErrorKind::TimedOut, what))
    })
    .await
}

#[async_trait]
impl DnsClient for Client {
    /// Sends `request` to the upstream server over TCP and waits for its response.
    ///
    /// `query_buffer` is cleared and then filled with the wire-format request: a 2-byte
    /// big-endian length prefix followed by the encoded message with a fresh random
    /// transaction id. The connection is opened on first use and reused across queries.
    ///
    /// Returns `Ok(Some(response))` when a response with the matching transaction id is
    /// received. Returns `Ok(None)` when the response header is marked truncated, or when the
    /// decoder yields no message.
    ///
    /// # Errors
    ///
    /// Fails in these cases:
    /// - the encoded request exceeds the 16-bit length limit;
    /// - connecting, writing, or reading fails or exceeds `timeout_ms` (each step is timed
    ///   separately);
    /// - the response cannot be decoded;
    /// - the response's transaction id differs from the request's.
    ///
    /// After any write/read failure (including a timeout) or an id mismatch, the connection
    /// is discarded so the next query starts on a fresh, correctly framed stream.
    async fn query(&mut self, request: &Message, query_buffer: &mut BytesMut) -> Result<Option<Message>> {
        // The framing below assumes the request starts at offset 0. Discard anything a caller
        // left in a reused buffer; the capacity is kept.
        query_buffer.clear();

        // Reserve 2 bytes for the TCP-specific length prefix; filled in once the size is known.
        query_buffer.reserve(2);
        query_buffer.put_u16(0);

        // Just use our max for the "udp size"
        ENCODER.encode(request, Some(MAX_TCP_BYTES), query_buffer)?;

        // Insert the resulting encoded size of the message into the two reserved leading bytes
        let message_len = u16::try_from(query_buffer.len() - 2).with_context(|| {
            format!(
                "Encoded request size {} exceeds {} limit: {}",
                query_buffer.len() - 2,
                MAX_TCP_BYTES,
                request
            )
        })?;
        BigEndian::write_u16(&mut query_buffer[..2], message_len);

        // Query is constructed, now let's do the request.
        if self.conn.is_none() {
            self.connect().await?;
        }

        let timeout_ms = self.timeout_ms;
        let request_id = rand::thread_rng().gen::<u16>();
        // For TCP, the size header means that the message actually starts at byte 2
        message::update_message_id(request_id, query_buffer, 2)?;

        debug!(
            "Raw request to {:?} ({}b): {:02X?}",
            self.dns_server,
            query_buffer.len(),
            &query_buffer[..]
        );

        // Own the stream for the duration of the exchange and put it back only once a complete
        // response frame has been read. Any failure part-way through then drops (closes) the
        // stream on return. This includes a timeout, after which a late response would
        // otherwise be read as the answer to the *next* query.
        let mut conn = self
            .conn
            .take()
            .expect("connect() stores the stream on success");

        let written = io_with_timeout(
            conn.write_all(&query_buffer[..]),
            timeout_ms,
            "TCP write timed out",
        )
        .await;
        if let Err(ioerr) = written {
            bail!(
                "Failed to write to TCP upstream {:?}: {}",
                self.dns_server,
                ioerr
            );
        }

        // Read first two bytes to get expected response size
        let mut response_size_bytes = [0u8; 2];
        let header_read = io_with_timeout(
            conn.read_exact(&mut response_size_bytes),
            timeout_ms,
            "TCP header read timed out",
        )
        .await;
        if let Err(ioerr) = header_read {
            bail!(
                "Failed to read header from TCP upstream {:?}: {}",
                self.dns_server,
                ioerr
            );
        }
        // The length prefix is big-endian; a u16 always fits in usize.
        let response_size = usize::from(BigEndian::read_u16(&response_size_bytes));

        // Read remaining bytes to get response
        self.response_buffer.resize(response_size, 0);
        let payload_read = io_with_timeout(
            conn.read_exact(&mut self.response_buffer),
            timeout_ms,
            "TCP payload read timed out",
        )
        .await;
        if let Err(ioerr) = payload_read {
            bail!(
                "Failed to read payload from TCP upstream {:?}: {}",
                self.dns_server,
                ioerr
            );
        }

        // A complete frame was consumed, so the stream is aligned for the next query.
        self.conn = Some(conn);

        debug!(
            "Raw response from {:?} ({}b): {:02X?}",
            self.dns_server,
            self.response_buffer.len(),
            &self.response_buffer[..]
        );

        match DNSMessageDecoder::new().decode(&self.response_buffer[..]) {
            Ok(Some(response)) => {
                debug!("Response from {:?}: {}", self.dns_server, response);

                if response.header.truncated {
                    // Message claims to be truncated, shouldn't happen but let's bail anyway
                    return Ok(None);
                }
                if response.header.id != request_id {
                    // The reply belongs to a different request, so responses on this stream can
                    // no longer be paired with our queries; reconnect on the next query.
                    self.conn = None;
                    bail!(
                        "Returned transaction id {:?} doesn't match sent {:?}",
                        response.header.id,
                        request_id
                    );
                }

                Ok(Some(response))
            }
            Ok(None) => {
                // Message was likely truncated, despite us receiving all the data in the payload
                debug!(
                    "Unable to parse response from server={} to request={:02X?}: {:02X?}",
                    self.dns_server,
                    &query_buffer[..],
                    &self.response_buffer[..],
                );
                Ok(None)
            }
            Err(e) => {
                // Other parse error; the length framing was still honoured, so the connection
                // stays usable.
                Err(e).context(format!(
                    "Failed to parse response from server={} to request={:02X?}: {:02X?}",
                    self.dns_server,
                    &query_buffer[..],
                    &self.response_buffer[..],
                ))
            }
        }
    }
}
