use async_trait::async_trait;
use domain::{
    base::{Dname, Rtype},
    rdata::{Aaaa, Mx, Ptr, Txt, A},
    resolv::{
        stub::conf::{ResolvConf, ResolvOptions},
        StubResolver,
    },
};
use std::{
    error::Error,
    io::{self, ErrorKind},
    net::{IpAddr, Ipv4Addr, Ipv6Addr},
    str::FromStr,
    sync::Arc,
    time::Duration,
};
use viaspf::lookup::{Lookup, LookupError, LookupResult, Name};

/// The resolver backend used to perform lookups.
///
/// Either a live resolver doing real DNS queries, or a shared, reference-counted
/// mock resolver (for use in testing).
pub enum Resolver {
    /// Live DNS queries through a [`DomainResolver`].
    Live(DomainResolver),
    /// Lookups delegated to a shared [`MockResolver`].
    Mock(Arc<MockResolver>),
}

/// A resolver that forwards every lookup to a wrapped `Lookup` implementation.
///
/// For use in testing.
pub struct MockResolver(Box<dyn Lookup>);

impl MockResolver {
    /// Wraps `lookup` (boxed as a trait object) so it can act as a resolver.
    pub fn new(lookup: impl Lookup + 'static) -> Self {
        let inner: Box<dyn Lookup> = Box::new(lookup);
        MockResolver(inner)
    }
}

/// Every lookup is forwarded unchanged to the wrapped `Lookup` implementation.
#[async_trait]
impl Lookup for MockResolver {
    async fn lookup_a(&self, name: &Name) -> LookupResult<Vec<Ipv4Addr>> {
        let inner: &dyn Lookup = &*self.0;
        inner.lookup_a(name).await
    }

    async fn lookup_aaaa(&self, name: &Name) -> LookupResult<Vec<Ipv6Addr>> {
        let inner: &dyn Lookup = &*self.0;
        inner.lookup_aaaa(name).await
    }

    async fn lookup_mx(&self, name: &Name) -> LookupResult<Vec<Name>> {
        let inner: &dyn Lookup = &*self.0;
        inner.lookup_mx(name).await
    }

    async fn lookup_txt(&self, name: &Name) -> LookupResult<Vec<String>> {
        let inner: &dyn Lookup = &*self.0;
        inner.lookup_txt(name).await
    }

    async fn lookup_ptr(&self, ip: IpAddr) -> LookupResult<Vec<Name>> {
        let inner: &dyn Lookup = &*self.0;
        inner.lookup_ptr(ip).await
    }
}

/// A resolver doing live DNS queries.
pub struct DomainResolver {
    resolver: StubResolver,
}

impl DomainResolver {
    pub fn new(timeout: Duration) -> Self {
        let options = ResolvOptions {
            timeout,
            attempts: 1,
            ..Default::default()
        };

        let mut conf = ResolvConf::default();
        conf.options = options;
        conf.finalize();

        let resolver = StubResolver::from_conf(conf);

        Self { resolver }
    }
}

#[async_trait]
impl Lookup for DomainResolver {
    async fn lookup_a(&self, name: &Name) -> LookupResult<Vec<Ipv4Addr>> {
        let name = to_dname(name)?;
        self.resolver
            .query((name, Rtype::A))
            .await
            .map_err(to_lookup_error)?
            .answer()
            .map_err(wrap_error)?
            .limit_to::<A>()
            .map(|record| record.map(|r| r.data().addr()))
            .collect::<Result<Vec<_>, _>>()
            .map_err(wrap_error)
    }

    async fn lookup_aaaa(&self, name: &Name) -> LookupResult<Vec<Ipv6Addr>> {
        let name = to_dname(name)?;
        self.resolver
            .query((name, Rtype::Aaaa))
            .await
            .map_err(to_lookup_error)?
            .answer()
            .map_err(wrap_error)?
            .limit_to::<Aaaa>()
            .map(|record| record.map(|r| r.data().addr()))
            .collect::<Result<Vec<_>, _>>()
            .map_err(wrap_error)
    }

    async fn lookup_mx(&self, name: &Name) -> LookupResult<Vec<Name>> {
        let name = to_dname(name)?;
        let answer = self
            .resolver
            .query((name, Rtype::Mx))
            .await
            .map_err(to_lookup_error)?;
        let mut mxs = answer
            .answer()
            .map_err(wrap_error)?
            .limit_to::<Mx<_>>()
            .map(|record| record.map(|r| r.into_data()))
            .collect::<Result<Vec<_>, _>>()
            .map_err(wrap_error)?;
        mxs.sort_by_key(|mx| mx.preference());
        mxs.into_iter()
            .map(|mx| Name::new(&mx.exchange().to_string()))
            .collect::<Result<Vec<_>, _>>()
            .map_err(wrap_error)
    }

    async fn lookup_txt(&self, name: &Name) -> LookupResult<Vec<String>> {
        let name = to_dname(name)?;
        let answer = self
            .resolver
            .query((name, Rtype::Txt))
            .await
            .map_err(to_lookup_error)?;
        let txts = answer
            .answer()
            .map_err(wrap_error)?
            .limit_to::<Txt<_>>()
            .map(|record| record.map(|r| r.into_data()))
            .collect::<Result<Vec<_>, _>>()
            .map_err(wrap_error)?;
        txts.into_iter()
            .map(|txt| {
                txt.text::<Vec<_>>()
                    .map(|bytes| String::from_utf8_lossy(&bytes).into_owned())
            })
            .collect::<Result<Vec<_>, _>>()
            .map_err(wrap_error)
    }

    async fn lookup_ptr(&self, ip: IpAddr) -> LookupResult<Vec<Name>> {
        let name = ip_to_dname(ip)?;
        let answer = self
            .resolver
            .query((name, Rtype::Ptr))
            .await
            .map_err(to_lookup_error)?;
        let ptrs = answer
            .answer()
            .map_err(wrap_error)?
            .limit_to::<Ptr<_>>()
            .map(|record| record.map(|r| r.into_data()))
            .collect::<Result<Vec<_>, _>>()
            .map_err(wrap_error)?;
        ptrs.into_iter()
            .map(|ptr| Name::new(&ptr.to_string()))
            .collect::<Result<Vec<_>, _>>()
            .map_err(wrap_error)
    }
}

/// Converts a viaspf `Name` into a `domain` `Dname` by parsing its text form.
///
/// A parse failure is returned as `LookupError::Dns` carrying the parse error.
fn to_dname(name: &Name) -> LookupResult<Dname<Vec<u8>>> {
    let text = name.as_str();
    text.parse::<Dname<Vec<u8>>>().map_err(wrap_error)
}

/// Builds the reverse-mapping name used for a PTR query of `ip`.
///
/// - IPv4 `a.b.c.d` maps to `d.c.b.a.in-addr.arpa.`.
/// - IPv6 maps to its 32 nibbles in reverse order, each as one lowercase hex
///   label, followed by `ip6.arpa.`.
///
/// A failure to parse the built name is returned as `LookupError::Dns`.
fn ip_to_dname(ip: IpAddr) -> LookupResult<Dname<Vec<u8>>> {
    let reverse_name = match ip {
        IpAddr::V4(v4) => {
            let labels: Vec<String> = v4.octets().iter().rev().map(|o| o.to_string()).collect();
            format!("{}.in-addr.arpa.", labels.join("."))
        }
        IpAddr::V6(v6) => {
            let mut out = String::new();
            // Last octet first; within an octet the low nibble precedes the high one.
            for octet in v6.octets().iter().rev() {
                for &nibble in &[octet & 0xf, octet >> 4] {
                    let digit = char::from_digit(u32::from(nibble), 16)
                        .expect("a nibble is always a valid hex digit");
                    out.push(digit);
                    out.push('.');
                }
            }
            out.push_str("ip6.arpa.");
            out
        }
    };
    Dname::from_str(&reverse_name).map_err(wrap_error)
}

fn to_lookup_error(error: io::Error) -> LookupError {
    match error.kind() {
        ErrorKind::NotFound => LookupError::NoRecords,
        ErrorKind::TimedOut => LookupError::Timeout,
        _ => wrap_error(error),
    }
}

/// Wraps any thread-safe error as `LookupError::Dns`, boxing it as the source.
fn wrap_error(error: impl Error + Send + Sync + 'static) -> LookupError {
    let source: Box<dyn Error + Send + Sync> = Box::new(error);
    LookupError::Dns(Some(source))
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Resolves the A and AAAA records of a known domain, then reverse-resolves
    /// the first address returned for each family.
    ///
    /// Ignored by default because it depends on live DNS records.
    #[ignore]
    #[tokio::test]
    async fn domain_resolver_lookup_ok() {
        let resolver = DomainResolver::new(Duration::from_secs(30));
        let domain = Name::new("gluet.ch").unwrap();

        // IPv4: forward lookup, then PTR lookup of the first address.
        let v4_addrs = resolver.lookup_a(&domain).await;
        assert!(v4_addrs.is_ok());
        let first_v4 = v4_addrs.unwrap().into_iter().next().unwrap();
        let v4_ptr = resolver.lookup_ptr(IpAddr::V4(first_v4)).await;
        assert!(v4_ptr.is_ok());

        // IPv6: forward lookup, then PTR lookup of the first address.
        let v6_addrs = resolver.lookup_aaaa(&domain).await;
        assert!(v6_addrs.is_ok());
        let first_v6 = v6_addrs.unwrap().into_iter().next().unwrap();
        let v6_ptr = resolver.lookup_ptr(IpAddr::V6(first_v6)).await;
        assert!(v6_ptr.is_ok());
    }
}
