use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::sync::Arc;

use anyhow::{bail, Context, Result};
use async_lock::{Barrier, Mutex};
use bytes::BytesMut;
use lazy_static::lazy_static;
use tracing::{debug, warn};

use crate::cache;
use crate::client::DnsClient;
use crate::codec::{encoder::DNSMessageEncoder, message::RequestInfo};
use crate::specs::enums_generated::{OpCode, ResourceClass, ResourceType, ResponseCode};
use crate::specs::message::*;

lazy_static! {
    /// Shared message encoder, constructed lazily on first use.
    ///
    /// The encoder currently doesn't have state, so a single instance is reused by
    /// `Resolver::resolve_raw` to re-encode responses.
    static ref ENCODER: DNSMessageEncoder = DNSMessageEncoder::new();
}

/// A DNS resolver that queries from a set of upstream DNS sources.
///
/// Each lookup first asks the cache (through `cache_tx`) and only then queries the
/// upstream `clients`, in the order they were supplied.
pub struct Resolver {
    /// Channel over which cache fetch and cache store messages are sent.
    cache_tx: async_channel::Sender<cache::task::CacheMsg>,
    /// Upstream clients, queried in list order; the first response returned is used.
    clients: Vec<Box<dyn DnsClient + Send>>,
    /// Scratch buffer handed to each upstream query; it is cleared before every query.
    client_buffer: BytesMut,
}

impl Resolver {
    pub fn new(
        cache_tx: async_channel::Sender<cache::task::CacheMsg>,
        clients: Vec<Box<dyn DnsClient + Send>>,
    ) -> Self {
        Resolver {
            cache_tx,
            clients,
            client_buffer: BytesMut::with_capacity(4096),
        }
    }

    /// A high-level query for getting an A/AAAA record for a given hostname.
    ///
    /// `host` may be written with or without a trailing period (`example.com` or
    /// `example.com.`); both forms produce the same query and the same cache key.
    /// When `get_ipv6` is set an AAAA record is requested, otherwise an A record.
    /// `udp_size` is placed into the request as the advertised UDP payload size.
    ///
    /// Returns the first matching address combined with `port`.
    ///
    /// # Errors
    ///
    /// Fails if the lookup itself fails, or if the response contains no record of
    /// the requested type.
    pub async fn resolve_str(
        &mut self,
        host: &String,
        port: u16,
        get_ipv6: bool,
        udp_size: u16,
    ) -> Result<SocketAddr> {
        let resource_type = if get_ipv6 {
            ResourceType::AAAA
        } else {
            ResourceType::A
        };

        // Normalise the name once so that a caller-supplied trailing period neither
        // ends up doubled in the DNS question nor leaks into the cache key.
        let bare_host = host.strip_suffix('.').unwrap_or(host.as_str());
        // Ensure the host in the DNS message has the period:
        let host_with_period = format!("{}.", bare_host);

        let request_info = RequestInfo {
            // The host used for cache/filter lookup meanwhile should not have the trailing period:
            name: bare_host.to_owned(),
            resource_type,
            received_request_id: 0,
            requested_udp_size: udp_size,
        };

        let request = build_request(resource_type, host_with_period, udp_size);
        let response = self
            .resolve(&request, &request_info)
            .await
            .with_context(|| format!("Failed to resolve host {:?}", host))?;

        let results = extract_addresses(response, resource_type);
        debug!(
            "Resolved host={:?} type={:?}: {:?}",
            host, resource_type, results
        );

        // If there are multiple IPs, just return the first one
        let first_ip = results.first().copied().with_context(|| {
            format!("No {:?} results for host: {:?}", resource_type, host)
        })?;
        Ok(SocketAddr::new(first_ip, port))
    }

    /// A low-level query that accepts and returns raw DNS query payloads.
    ///
    /// Resolves `request` and writes the encoded answer into `response_buffer`.
    /// Any error from the lookup or from the encoder is returned to the caller.
    pub async fn resolve_raw(
        &mut self,
        request: &Message,
        request_info: &RequestInfo,
        response_buffer: &mut BytesMut,
    ) -> Result<()> {
        let reply = self.resolve(request, request_info).await?;
        debug!("Response to client: {}", reply);

        // The reply is re-encoded with the UDP size taken from the original request
        // rather than whatever size the resolved message carries.
        let client_udp_size = Some(request_info.requested_udp_size);
        ENCODER.encode(&reply, client_udp_size, response_buffer)
    }

    async fn resolve(&mut self, request: &Message, request_info: &RequestInfo) -> Result<Message> {
        // Check if cache has cached result: Send request and wait for response via barrier+arc
        let result_barrier = Arc::new(Barrier::new(2));
        let result = Arc::new(Mutex::new(None));
        self.cache_tx.send(cache::task::CacheMsg::Fetch(cache::task::CacheFetch {
            request_info: (*request_info).clone(),
            result_barrier: result_barrier.clone(),
            result: result.clone(),
        })).await.context("Failed to send cache fetch query")?;

        // Wait on the barrier to complete
        result_barrier.wait().await;
        // Barrier has completed, get the stored result.
        // Do a swap to get the result out without yet another copy.
        match result.lock().await.replace(Ok(None)).expect("Missing fetch result following barrier") {
            Ok(Some(cache_result)) => {
                return Ok(cache_result)
            },
            Ok(None) => {
                // cache miss - continue with upstream queries below
            },
            Err(e) => {
                // cache fail - complain but continue with upstream queries
                warn!("Cache lookup failed for request {:?}: {}", request_info, e)
            },
        }

        // Cache didn't have anything, so query upstream clients.
        for client in &mut self.clients {
            // Mark the client buffer as empty so that we don't append on top of a prior request
            self.client_buffer.clear();
            if let Some(mut response) = client.query(request, &mut self.client_buffer).await? {
                // Store fetched result to cache (no response needed)
                self.cache_tx.send(cache::task::CacheMsg::Store(cache::task::CacheStore{
                    request_info: (*request_info).clone(),
                    response: response.clone(),
                })).await.context("Failed to send cache store query")?;

                // Set the message ID for the response so that it matches the original request.
                // Keeping the message IDs independent reduces the likelihood of cache poisoning.
                response.header.id = request_info.received_request_id;

                return Ok(response);
            }
        }
        bail!("All upstreams failed to return a response");
    }
}

/// Builds a query message containing a single question for `domain`.
///
/// The question asks for `resource_type` in the `INTERNET` class. The header has
/// ID 0 with `recursion_desired` and `authentic_data` set, and the OPT section
/// carries `udp_size` with `dnssec_ok` set. The answer, authority and additional
/// sections are left empty.
fn build_request(resource_type: ResourceType, domain: String, udp_size: u16) -> Message {
    let header = Header {
        id: 0,
        is_response: false,
        op_code: IntEnum::Enum(OpCode::QUERY),
        authoritative: false,
        truncated: false,
        recursion_desired: true,
        recursion_available: false,
        reserved_9: false,
        authentic_data: true,
        checking_disabled: false,
        response_code: IntEnum::Enum(ResponseCode::NOERROR),
    };

    let opt = OPT {
        option: Vec::new(),
        udp_size,
        response_code: 0,
        version: 0,
        dnssec_ok: true,
    };

    // Exactly one question: the requested record type for the given domain.
    let question = vec![Question {
        name: domain,
        resource_type: IntEnum::Enum(resource_type),
        resource_class: IntEnum::Enum(ResourceClass::INTERNET),
    }];

    Message {
        header,
        opt: Some(opt),
        question,
        answer: Vec::new(),
        authority: Vec::new(),
        additional: Vec::new(),
    }
}

/// Collects the IP addresses of every answer record in `message` whose type equals
/// `resource_type`, preserving the order of the answers.
///
/// Matching records whose data cannot be turned into an address are skipped.
/// Note that `extract_address` panics when `resource_type` is neither A nor AAAA
/// and a record of that type is present.
fn extract_addresses(message: Message, resource_type: ResourceType) -> Vec<IpAddr> {
    // Reserve room for the worst case where every answer matches.
    let mut addresses = Vec::with_capacity(message.answer.len());
    addresses.extend(
        message
            .answer
            .iter()
            .filter(|record| record.resource_type == IntEnum::Enum(resource_type))
            .filter_map(|record| extract_address(record, resource_type)),
    );
    addresses
}

/// Converts the data of a single answer record into an IP address.
///
/// For `ResourceType::A` the record must carry A data, and for `ResourceType::AAAA`
/// it must carry AAAA data; if the data is of a different kind, `None` is returned.
///
/// # Panics
///
/// Panics if `resource_type` is neither `A` nor `AAAA`.
fn extract_address(answer: &Resource, resource_type: ResourceType) -> Option<IpAddr> {
    if resource_type == ResourceType::A {
        return match &answer.rdata {
            ResourceData::A(v4) => Some(IpAddr::V4(Ipv4Addr::new(
                v4.address1,
                v4.address2,
                v4.address3,
                v4.address4,
            ))),
            _ => None,
        };
    }

    if resource_type == ResourceType::AAAA {
        return match &answer.rdata {
            ResourceData::AAAA(v6) => Some(IpAddr::V6(Ipv6Addr::new(
                v6.address1,
                v6.address2,
                v6.address3,
                v6.address4,
                v6.address5,
                v6.address6,
                v6.address7,
                v6.address8,
            ))),
            _ => None,
        };
    }

    // Only address record types can be converted; anything else is a caller bug.
    panic!(
        "Unsupported resource type for address extraction: {:?}",
        answer.resource_type
    );
}
