use super::*;
use crate::{
    eval::query::{Query, Resolver},
    name::{self, Name},
    record::{
        DomainSpec, DualCidrLength, Exists, ExplainString, Explanation, Include, Ip4, Ip6,
        Mechanism, Mx, Ptr, Redirect, A,
    },
    resolve::{LookupError, LookupResult},
    trace::Tracepoint,
};
use std::{net::IpAddr, u128, u32};

impl EvaluateMatch for Mechanism {
    fn evaluate_match(&self, query: &mut Query, resolver: &Resolver) -> EvalResult<MatchResult> {
        trace!(query, Tracepoint::EvaluateMechanism(self.clone()));

        use Mechanism::*;
        let mechanism: &dyn EvaluateMatch = match self {
            All => return Ok(MatchResult::Match),
            Include(include) => include,
            A(a) => a,
            Mx(mx) => mx,
            Ptr(ptr) => ptr,
            Ip4(ip4) => ip4,
            Ip6(ip6) => ip6,
            Exists(exists) => exists,
        };

        mechanism.evaluate_match(query, resolver)
    }
}

impl EvaluateMatch for Include {
    fn evaluate_match(&self, query: &mut Query, resolver: &Resolver) -> EvalResult<MatchResult> {
        increment_lookup_count(query)?;

        let target_name = get_target_name(&self.domain_spec, query, resolver)?;
        trace!(query, Tracepoint::TargetName(target_name.clone()));

        let result = execute_recursive_query(query, resolver, target_name, true);

        // See the table in §5.2.
        use SpfResult::*;
        match result {
            Pass => Ok(MatchResult::Match),
            Fail(_) | Softfail | Neutral => Ok(MatchResult::NoMatch),
            Temperror => Err(EvalError::RecursiveTemperror),
            Permerror => Err(EvalError::RecursivePermerror),
            None => Err(EvalError::IncludeNoRecord),
        }
    }
}

impl EvaluateMatch for A {
    /// Evaluates an `a` mechanism: resolves the target name's addresses and
    /// checks whether the client IP falls into any of them (subject to the
    /// optional CIDR prefix length).
    fn evaluate_match(&self, query: &mut Query, resolver: &Resolver) -> EvalResult<MatchResult> {
        increment_lookup_count(query)?;

        let target_name = get_target_name_or_domain(self.domain_spec.as_ref(), query, resolver)?;
        trace!(query, Tracepoint::TargetName(target_name.clone()));

        let client_ip = query.params.ip();

        // §5.3: the address family looked up (A or AAAA) follows the client
        // connection's family.
        let lookup = resolver.lookup_a_or_aaaa(query, &target_name, client_ip);
        let addrs = to_eval_result(lookup)?;
        increment_void_lookup_count_if_void(query, addrs.len())?;

        // Any single matching address is enough; stop at the first one.
        let prefix_len = self.prefix_len;
        let matched = addrs.into_iter().any(|addr| {
            trace!(query, Tracepoint::TryIpAddr(addr));
            is_in_network(addr, prefix_len, client_ip)
        });

        Ok(if matched {
            MatchResult::Match
        } else {
            MatchResult::NoMatch
        })
    }
}

impl EvaluateMatch for Mx {
    /// Evaluates an `mx` mechanism: resolves the target's MX hosts, then each
    /// host's addresses, and matches if the client IP falls into any of them.
    ///
    /// Every MX host counts towards the global DNS lookup limit, and the
    /// number of address records examined across all hosts is capped per
    /// mechanism; exceeding either limit yields an error.
    fn evaluate_match(&self, query: &mut Query, resolver: &Resolver) -> EvalResult<MatchResult> {
        increment_lookup_count(query)?;

        let target_name = get_target_name_or_domain(self.domain_spec.as_ref(), query, resolver)?;
        trace!(query, Tracepoint::TargetName(target_name.clone()));

        let mx_names = to_eval_result(resolver.lookup_mx(query, &target_name))?;
        increment_void_lookup_count_if_void(query, mx_names.len())?;

        let client_ip = query.params.ip();

        // Address records examined so far, across all MX hosts of this term.
        let mut addr_count = 0;

        for mx_name in mx_names {
            trace!(query, Tracepoint::TryMxName(mx_name.clone()));

            // §4.6.4: each MX record queried also counts towards the overall
            // limit of DNS-querying terms.
            increment_lookup_count(query)?;

            let host_addrs = to_eval_result(resolver.lookup_a_or_aaaa(query, &mx_name, client_ip))?;
            increment_void_lookup_count_if_void(query, host_addrs.len())?;

            for addr in host_addrs {
                trace!(query, Tracepoint::TryIpAddr(addr));

                // §4.6.4: querying more than 10 address records for the "mx"
                // mechanism MUST produce a "permerror"; this per-mechanism
                // limit is in addition to the global lookup limit.
                increment_per_mechanism_lookup_count(query, &mut addr_count)?;

                if is_in_network(addr, self.prefix_len, client_ip) {
                    return Ok(MatchResult::Match);
                }
            }
        }

        Ok(MatchResult::NoMatch)
    }
}

impl EvaluateMatch for Ptr {
    /// Evaluates a `ptr` mechanism: reverse-resolves the client IP, validates
    /// the returned names by forward lookup, and matches if any validated
    /// name equals the target name or is a subdomain of it.
    ///
    /// A DNS error during the reverse (PTR) lookup makes the mechanism not
    /// match rather than fail.
    fn evaluate_match(&self, query: &mut Query, resolver: &Resolver) -> EvalResult<MatchResult> {
        increment_lookup_count(query)?;

        let target_name = get_target_name_or_domain(self.domain_spec.as_ref(), query, resolver)?;
        trace!(query, Tracepoint::TargetName(target_name.clone()));

        let client_ip = query.params.ip();

        // §5.5: a DNS error while doing the PTR lookup means no match.
        let ptr_names = match to_eval_result(resolver.lookup_ptr(query, client_ip)) {
            Err(e) => {
                trace!(query, Tracepoint::ReverseLookupError(e));
                return Ok(MatchResult::NoMatch);
            }
            Ok(names) => names,
        };
        increment_void_lookup_count_if_void(query, ptr_names.len())?;

        let validated = get_validated_domain_names(query, resolver, client_ip, ptr_names)?;

        // §5.5: match if any validated name is the target or a subdomain of
        // it; no validated names, or none qualifying, means no match.
        let matched = validated.iter().any(|candidate| {
            trace!(query, Tracepoint::TryValidatedName(candidate.clone()));
            candidate == &target_name || candidate.is_subdomain_of(&target_name)
        });

        Ok(if matched {
            MatchResult::Match
        } else {
            MatchResult::NoMatch
        })
    }
}

/// Validates PTR names for `ip` as described in RFC 7208 §5.5 and returns the
/// names (in input order) that resolve back to `ip`.
///
/// - Each name counts towards the overall DNS lookup limit; an error from
///   that limit, or from the void lookup limit, is propagated with `?`.
/// - A DNS error while resolving a name skips that name.
/// - The address records examined are counted with one counter shared by all
///   names. Once `MAX_LOOKUPS_PER_MECHANISM` addresses have been examined,
///   the next address stops validation entirely, and the names validated so
///   far are returned (the limit is not an error here).
pub fn get_validated_domain_names(
    query: &mut Query,
    resolver: &Resolver,
    ip: IpAddr,
    names: Vec<Name>,
) -> EvalResult<Vec<Name>> {
    let mut validated_names = Vec::new();

    // Number of address records examined so far; shared across all names, so
    // the per-mechanism limit applies to the whole validation, not per name.
    let mut i = 0;

    // §5.5: ‘For each record returned, validate the domain name by looking up
    // its IP addresses.’
    'names: for name in names {
        trace!(query, Tracepoint::ValidatePtrName(name.clone()));

        // §4.6.4: ‘When evaluating the "ptr" mechanism or the %{p} macro, the
        // number of "PTR" resource records queried is included in the overall
        // limit of 10 mechanisms/modifiers that cause DNS lookups’
        increment_lookup_count(query)?;

        let addrs = match to_eval_result(resolver.lookup_a_or_aaaa(query, &name, ip)) {
            Ok(addrs) => addrs,
            // §5.5: ‘If a DNS error occurs while doing an A RR lookup, then
            // that domain name is skipped and the search continues.’
            Err(e) => {
                trace!(query, Tracepoint::PtrNameLookupError(e));
                continue;
            }
        };
        increment_void_lookup_count_if_void(query, addrs.len())?;

        for addr in addrs {
            trace!(query, Tracepoint::TryIpAddr(addr));

            // §4.6.4: ‘the evaluation of each "PTR" record MUST NOT result in
            // querying more than 10 address records -- either "A" or "AAAA"
            // resource records. If this limit is exceeded, all records other
            // than the first 10 MUST be ignored.’ And: ‘These limits are per
            // mechanism or macro in the record, and are in addition to the
            // lookup limits specified above.’
            // Hitting the limit ends validation of all remaining names, not
            // just the current one.
            if increment_per_mechanism_lookup_count(query, &mut i).is_err() {
                trace!(query, Tracepoint::PtrAddressLookupLimitExceeded);
                break 'names;
            }

            // §5.5: ‘If <ip> is among the returned IP addresses, then that
            // domain name is validated.’ One hit is enough; move on to the
            // next name.
            if addr == ip {
                trace!(query, Tracepoint::PtrNameValidated);
                validated_names.push(name);
                break;
            }
        }
    }

    Ok(validated_names)
}

impl EvaluateMatch for Ip4 {
    /// Matches when the client IP lies in this `ip4` network. No DNS is
    /// involved, so the resolver is unused.
    fn evaluate_match(&self, query: &mut Query, _: &Resolver) -> EvalResult<MatchResult> {
        let client_ip = query.params.ip();
        if is_in_network(self.addr, self.prefix_len, client_ip) {
            Ok(MatchResult::Match)
        } else {
            Ok(MatchResult::NoMatch)
        }
    }
}

impl EvaluateMatch for Ip6 {
    /// Matches when the client IP lies in this `ip6` network. No DNS is
    /// involved, so the resolver is unused.
    fn evaluate_match(&self, query: &mut Query, _: &Resolver) -> EvalResult<MatchResult> {
        let client_ip = query.params.ip();
        if is_in_network(self.addr, self.prefix_len, client_ip) {
            Ok(MatchResult::Match)
        } else {
            Ok(MatchResult::NoMatch)
        }
    }
}

/// Returns whether `ip` lies in the network given by `network_addr` and the
/// optional CIDR prefix length.
///
/// Addresses of different families (IPv4 vs. IPv6) never match. Without a
/// prefix length for the relevant family, the addresses are compared for
/// equality. A prefix length of 0 (e.g. `ip4:0.0.0.0/0`) matches every
/// address of that family.
fn is_in_network<A, L>(network_addr: A, prefix_len: Option<L>, ip: IpAddr) -> bool
where
    A: Into<IpAddr>,
    L: Into<DualCidrLength>,
{
    match (network_addr.into(), ip) {
        (IpAddr::V4(network_addr), IpAddr::V4(ip)) => {
            match prefix_len.and_then(|l| l.into().ip4()) {
                // §5: ‘If no CIDR prefix length is given in the directive, then
                // <ip> and the IP address are compared for equality.’
                None => network_addr == ip,
                // ‘If a CIDR prefix length is specified, then only the
                // specified number of high-order bits of <ip> and the IP
                // address are compared for equality.’
                Some(len) => {
                    // A /0 prefix compares no bits. Shifting a u32 by 32 would
                    // overflow: it panics in debug builds, and in release
                    // builds the shift amount wraps to 0, giving an all-ones
                    // mask. So the empty mask is produced explicitly.
                    let mask = if len.get() == 0 {
                        0
                    } else {
                        u32::MAX << (32 - len.get())
                    };
                    (u32::from(network_addr) & mask) == (u32::from(ip) & mask)
                }
            }
        }
        (IpAddr::V6(network_addr), IpAddr::V6(ip)) => {
            match prefix_len.and_then(|l| l.into().ip6()) {
                None => network_addr == ip,
                Some(len) => {
                    // Same /0 overflow guard as the IPv4 case (shift by 128).
                    let mask = if len.get() == 0 {
                        0
                    } else {
                        u128::MAX << (128 - len.get())
                    };
                    (u128::from(network_addr) & mask) == (u128::from(ip) & mask)
                }
            }
        }
        _ => false,
    }
}

impl EvaluateMatch for Exists {
    /// Evaluates an `exists` mechanism: matches if an A lookup on the
    /// expanded target name returns at least one record.
    fn evaluate_match(&self, query: &mut Query, resolver: &Resolver) -> EvalResult<MatchResult> {
        increment_lookup_count(query)?;

        let target_name = get_target_name(&self.domain_spec, query, resolver)?;
        trace!(query, Tracepoint::TargetName(target_name.clone()));

        // §5.7: always an A lookup, regardless of the connection's family.
        let record_count = to_eval_result(resolver.lookup_a(query, &target_name))?.len();
        increment_void_lookup_count_if_void(query, record_count)?;

        Ok(match record_count {
            0 => MatchResult::NoMatch,
            _ => MatchResult::Match,
        })
    }
}

impl Evaluate for Redirect {
    /// Evaluates a `redirect` modifier by running a recursive `check_host()`
    /// on the expanded target name and adopting its result.
    ///
    /// Exceeding the lookup limit or an invalid target name is reported via
    /// the error's SPF result, with `query.result_cause` set. A target that
    /// has no SPF record yields `permerror` instead of `none` (§6.1).
    fn evaluate(&self, query: &mut Query, resolver: &Resolver) -> SpfResult {
        trace!(query, Tracepoint::EvaluateRedirect(self.clone()));

        match increment_lookup_count(query) {
            Ok(()) => {}
            Err(e) => {
                trace!(query, Tracepoint::RedirectLookupLimitExceeded);
                query.result_cause = e.to_error_cause().map(From::from);
                return e.to_spf_result();
            }
        }

        // §6.1: a malformed <target-name> yields "permerror", not "none".
        let target_name = match get_target_name(&self.domain_spec, query, resolver) {
            Err(e) => {
                trace!(query, Tracepoint::InvalidRedirectTargetName);
                query.result_cause = e.to_error_cause().map(From::from);
                return e.to_spf_result();
            }
            Ok(name) => name,
        };
        trace!(query, Tracepoint::TargetName(target_name.clone()));

        let result = execute_recursive_query(query, resolver, target_name, false);

        // §6.1: the nested result becomes this evaluation's result, except
        // that "no SPF record found" turns into "permerror".
        if let SpfResult::None = result {
            trace!(query, Tracepoint::RedirectNoSpfRecord);
            query.result_cause = Some(ErrorCause::NoSpfRecord.into());
            return SpfResult::Permerror;
        }

        result
    }
}

/// Runs a nested `check_host()` evaluation for `target_name` and returns its
/// result.
///
/// The query's domain and "included" flag are switched for the nested run
/// and restored afterwards. Once any enclosing evaluation is an include, the
/// nested one stays marked as included, even when reached through a redirect.
fn execute_recursive_query(
    query: &mut Query,
    resolver: &Resolver,
    target_name: Name,
    included: bool,
) -> SpfResult {
    let outer_domain = query.params.replace_domain(target_name);
    let outer_included = query.state.is_included_query();
    let nested_included = outer_included || included;
    query.state.set_included_query(nested_included);

    let nested_result = query.execute(resolver);

    // Put the outer evaluation context back in place.
    query.params.replace_domain(outer_domain);
    query.state.set_included_query(outer_included);

    nested_result
}

impl EvaluateToString for Explanation {
    fn evaluate_to_string(&self, query: &mut Query, resolver: &Resolver) -> EvalResult<String> {
        trace!(query, Tracepoint::EvaluateExplanation(self.clone()));

        let target_name = get_target_name(&self.domain_spec, query, resolver)?;
        trace!(query, Tracepoint::TargetName(target_name.clone()));

        // §6.2: ‘The fetched TXT record's strings are concatenated with no
        // spaces, and then treated as an explain-string, which is
        // macro-expanded.’
        let mut explain_string = match lookup_explain_string(resolver, query, &target_name) {
            Ok(e) => e,
            Err(e) => {
                // ‘If there are any DNS processing errors (any RCODE other than
                // 0), or if no records are returned, or if more than one record
                // is returned, or if there are syntax errors in the explanation
                // string, then proceed as if no "exp" modifier was given.’
                use ExplainStringLookupError::*;
                trace!(
                    query,
                    match e {
                        DnsLookup(e) => Tracepoint::ExplainStringLookupError(e),
                        NoExplainString => Tracepoint::NoExplainString,
                        MultipleExplainStrings(s) => Tracepoint::MultipleExplainStrings(s),
                        Syntax(s) => Tracepoint::InvalidExplainStringSyntax(s),
                    }
                );

                // After the tracing above, may now conflate the error causes:
                return Err(EvalError::Dns(None));
            }
        };

        if let Some(f) = query.config.modify_exp_fn() {
            trace!(query, Tracepoint::ModifyExplainString(explain_string.clone()));
            f(&mut explain_string);
        }

        explain_string.evaluate_to_string(query, resolver)
    }
}

#[derive(Debug)]
pub enum ExplainStringLookupError {
    DnsLookup(LookupError),
    NoExplainString,
    MultipleExplainStrings(Vec<String>),
    Syntax(String),
}

impl Error for ExplainStringLookupError {}

impl Display for ExplainStringLookupError {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        write!(f, "failed to obtain explain string")
    }
}

impl From<LookupError> for ExplainStringLookupError {
    fn from(error: LookupError) -> Self {
        match error {
            LookupError::NoRecords => Self::NoExplainString,
            _ => Self::DnsLookup(error),
        }
    }
}

fn lookup_explain_string(
    resolver: &Resolver,
    query: &mut Query,
    name: &Name,
) -> Result<ExplainString, ExplainStringLookupError> {
    let mut exps = resolver.lookup_txt(query, name)?.into_iter();

    use ExplainStringLookupError::*;
    match exps.next() {
        None => Err(NoExplainString),
        Some(exp) => {
            let mut rest = exps.collect::<Vec<_>>();
            match *rest {
                [] => exp.parse().map_err(|_| Syntax(exp)),
                [..] => {
                    rest.insert(0, exp);
                    Err(MultipleExplainStrings(rest))
                }
            }
        }
    }
}

/// Resolves the target name for mechanisms whose `<domain-spec>` is optional.
///
/// §4.8: without a `<domain-spec>`, the `<domain>` argument of
/// `check_host()` is the `<target-name>`; otherwise the spec is expanded.
fn get_target_name_or_domain(
    domain_spec: Option<&DomainSpec>,
    query: &mut Query,
    resolver: &Resolver,
) -> EvalResult<Name> {
    if let Some(spec) = domain_spec {
        get_target_name(spec, query, resolver)
    } else {
        Ok(query.params.domain().clone())
    }
}

/// Macro-expands `domain_spec` into a domain name.
///
/// §4.8: the expansion is the presentation form of a fully qualified DNS
/// name. It is truncated to the maximum domain length (§7.3) before parsing.
/// A result that does not parse as a name is returned as
/// `EvalError::InvalidName`, carrying the truncated string.
fn get_target_name(
    domain_spec: &DomainSpec,
    query: &mut Query,
    resolver: &Resolver,
) -> EvalResult<Name> {
    let mut expanded = domain_spec.evaluate_to_string(query, resolver)?;
    truncate_target_name_string(&mut expanded, name::MAX_DOMAIN_LENGTH);

    match Name::new(&expanded) {
        Ok(target) => Ok(target),
        Err(_) => Err(EvalError::InvalidName(expanded)),
    }
}

/// Shortens a macro-expanded domain name string to at most `max` bytes by
/// dropping whole labels from the left.
///
/// §7.3: ‘When the result of macro expansion is used in a domain name query, if
/// the expanded domain name exceeds 253 characters (the maximum length of a
/// domain name in this format), the left side is truncated to fit, by removing
/// successive domain labels (and their following dots) until the total length
/// does not exceed 253 characters.’
///
/// A single trailing dot is always removed first. If even the last label on
/// its own is longer than `max`, the string is otherwise left unchanged.
fn truncate_target_name_string(s: &mut String, max: usize) {
    if s.ends_with('.') {
        s.pop();
    }

    let len = s.len();
    if len <= max {
        return;
    }

    // Cutting right after a dot at byte index `i` leaves `len - i - 1` bytes,
    // which fits exactly when `i >= len - max - 1` (no underflow, as
    // `len > max`). The leftmost such dot removes the fewest labels.
    let min_dot_index = len - max - 1;
    let cut = s
        .match_indices('.')
        .map(|(i, _)| i)
        .find(|&i| i >= min_dot_index);

    if let Some(i) = cut {
        s.drain(..=i);
    }
}

/// Counts one DNS-querying term against the global lookup limit.
///
/// §4.6.4: ‘The following terms cause DNS queries: the "include", "a", "mx",
/// "ptr", and "exists" mechanisms, and the "redirect" modifier. SPF
/// implementations MUST limit the total number of those terms to 10 during SPF
/// evaluation’
///
/// Emits an `IncrementLookupCount` trace event, then delegates to the query
/// state. Its error (presumably raised once the limit is exceeded; confirm
/// against the state type) is returned unchanged.
fn increment_lookup_count(query: &mut Query) -> EvalResult<()> {
    trace!(query, Tracepoint::IncrementLookupCount);
    query.state.increment_lookup_count()
}

/// Records a void lookup when a DNS query produced no records.
///
/// §4.6.4: ‘there may be cases where it is useful to limit the number of "terms"
/// for which DNS queries return either a positive answer (RCODE 0) with an
/// answer count of 0, or a "Name Error" (RCODE 3) answer. These are sometimes
/// collectively referred to as "void lookups".’
///
/// `count` is the number of records the lookup returned; only `0` is void.
/// For a void lookup, the configured maximum is passed to the query state,
/// and its result is returned.
pub fn increment_void_lookup_count_if_void(query: &mut Query, count: usize) -> EvalResult<()> {
    if count > 0 {
        return Ok(());
    }

    trace!(query, Tracepoint::IncrementVoidLookupCount);
    let max_void_lookups = query.config.max_void_lookups();
    query.state.increment_void_lookup_count(max_void_lookups)
}

/// Maximum number of address records a single "mx" or "ptr" evaluation may
/// examine (RFC 7208 §4.6.4).
const MAX_LOOKUPS_PER_MECHANISM: usize = 10;

/// Bumps the per-mechanism address counter `i`.
///
/// Fails with `EvalError::PerMechanismLookupLimitExceeded`, leaving `i`
/// unchanged, once `MAX_LOOKUPS_PER_MECHANISM` addresses were already
/// counted. A trace event is emitted on every call, including failing ones.
fn increment_per_mechanism_lookup_count(query: &mut Query, i: &mut usize) -> EvalResult<()> {
    trace!(query, Tracepoint::IncrementPerMechanismLookupCount);

    if *i >= MAX_LOOKUPS_PER_MECHANISM {
        return Err(EvalError::PerMechanismLookupLimitExceeded);
    }

    *i += 1;
    Ok(())
}

pub fn to_eval_result<T>(result: LookupResult<Vec<T>>) -> EvalResult<Vec<T>> {
    match result {
        Ok(r) => Ok(r),
        Err(e) => {
            use LookupError::*;
            match e {
                Timeout => Err(EvalError::Timeout),
                // §5: ‘If the server returns "Name Error" (RCODE 3), then
                // evaluation of the mechanism continues as if the server
                // returned no error (RCODE 0) and zero answer records.’
                NoRecords => Ok(Vec::new()),
                Dns(e) => Err(EvalError::Dns(e)),
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::record::Ip4CidrLength;

    #[test]
    fn is_in_network_ok() {
        let network = IpAddr::from([123, 12, 12, 12]);
        let prefix = Ip4CidrLength::new(24).unwrap();
        let client = IpAddr::from([123, 12, 12, 98]);

        assert!(is_in_network(network, Some(prefix), client));
    }

    #[test]
    fn truncate_target_name_string_ok() {
        // (input, max, expected)
        let cases: &[(&str, usize, &str)] = &[
            // Pathological case where final label longer than limit → no-op:
            ("ab.cd.ef", 1, "ab.cd.ef"),
            ("ab.cd.ef.", 1, "ab.cd.ef"),
            // Truncating:
            ("ab.cd.ef", 2, "ef"),
            ("ab.cd.ef.", 2, "ef"),
            ("ab.cd.ef", 3, "ef"),
            ("ab.cd.ef", 4, "ef"),
            ("ab.cd.ef", 5, "cd.ef"),
            ("ab.cd.ef", 6, "cd.ef"),
            ("ab.cd.ef", 7, "cd.ef"),
            ("ab.cd.ef.", 7, "cd.ef"),
            // Not longer than limit → no-op:
            ("ab.cd.ef", 8, "ab.cd.ef"),
            ("ab.cd.ef.", 8, "ab.cd.ef"),
            ("ab.cd.ef", 9, "ab.cd.ef"),
        ];

        for &(input, max, expected) in cases {
            let mut s = String::from(input);
            truncate_target_name_string(&mut s, max);
            assert_eq!(s, expected, "truncating {:?} to {}", input, max);
        }
    }
}
