use anyhow::{bail, Context, Result};

use crate::codec::message;
use crate::specs::message::Message;

/// Decodes a wire DNS message, turning it into a parsed `Message` object.
pub struct DNSMessageDecoder {
    /// The current offset in buf. We do this manually in order to avoid changes to buf's offsets,
    /// as required to support label compression.
    offset: usize,

    /// The expected count of questions, answers, authorities, and additionals to be parsed.
    expected_counts: message::RecordCounts,

    /// The partially decoded message carried over between `decode()` calls.
    /// `None` when no message is currently in progress (no header has been consumed yet).
    message: Option<Message>,
}

impl DNSMessageDecoder {
    /// Returns a `DNSMessageDecoder` for decoding DNS messages.
    pub fn new() -> DNSMessageDecoder {
        DNSMessageDecoder {
            offset: 0,
            expected_counts: message::RecordCounts::new(),
            message: None,
        }
    }

    /// Decodes data from the provided byte slice.
    ///
    /// The read position is tracked inside the decoder rather than by consuming `buf`, so every
    /// call for the same message should pass the buffer starting at the first byte of that
    /// message (label compression needs the untouched buffer).
    ///
    /// - Returns `Ok(Some(message))` once the message has been processed. If the header marks the
    ///   message as truncated, only the header is populated and the caller should check for that.
    /// - Returns `Ok(None)` if more buf data is needed via additional calls to `decode()`.
    ///
    /// # Errors
    ///
    /// Returns an error if the header or any record cannot be parsed. Any partially decoded
    /// message is discarded in that case, so the decoder can be reused for a new message.
    pub fn decode(&mut self, buf: &[u8]) -> Result<Option<Message>> {
        let outcome = self.decode_step(buf);
        match &outcome {
            // More data is needed: keep the offset, counts and pending message for the next call.
            Ok(None) => {}
            // The message is finished (complete, truncated, or failed). Always clear the state so
            // that the next call starts from a fresh header at offset 0.
            Ok(Some(_)) | Err(_) => self.reset(),
        }
        outcome
    }

    /// Runs one decode attempt without resetting any state; `decode()` handles the reset.
    fn decode_step(&mut self, buf: &[u8]) -> Result<Option<Message>> {
        // Note that we avoid calling split_* against buf and instead keep track of our read offset locally.
        // This is because buf needs to be 'untainted' by any offset shifts in case label compression is being used.
        if self.message.is_none() {
            // Haven't consumed a header yet, try for that first.
            let (header, record_counts, truncated) =
                match message::read_header(buf, &mut self.offset)? {
                    Some(parsed) => parsed,
                    None => return Ok(None),
                };
            self.expected_counts = record_counts;
            let message = Message {
                header,
                opt: None,
                question: Vec::new(),
                answer: Vec::new(),
                authority: Vec::new(),
                additional: Vec::new(),
            };
            if truncated {
                // Header content says message is truncated.
                // Don't bother waiting for more data because it won't be coming.
                // Upstream should check for header.truncated in the result.
                return Ok(Some(message));
            }
            // Park the message so that the record parsing below is shared between the first
            // call and any follow-up calls.
            self.message = Some(message);
        }

        let pending = match self.message.as_mut() {
            Some(pending) => pending,
            // Not reachable: the branch above either returned or stored a message.
            None => return Ok(None),
        };
        // Check if there's more data in buf that we can parse into the message.
        // Completing in a single call should be the common case.
        if !decode_into_message(buf, &mut self.offset, &mut self.expected_counts, pending)? {
            // Message is still incomplete - need more data.
            return Ok(None);
        }
        Ok(self.message.take())
    }

    /// Clears all per-message state so the next `decode()` call starts a new message.
    fn reset(&mut self) {
        self.offset = 0;
        self.expected_counts = message::RecordCounts::new();
        self.message = None;
    }
}

/// Parses as many of the still-missing records as `buf` currently allows into `message`.
///
/// Sections are filled in order: questions, answers, authorities, then additionals. Progress is
/// kept in `message` and `offset`, so the function can be called again once more data arrives.
/// An OPT resource found in the additional section is stored in `message.opt` rather than in
/// `message.additional`, but still counts toward the expected number of additionals.
///
/// - Returns `Ok(true)` when every section holds its expected number of records.
/// - Returns `Ok(false)` when `buf` ran out before the message was complete.
///
/// # Errors
///
/// Returns an error if a record fails to parse, if an OPT resource has a name other than `.`,
/// or if more than one OPT resource is present.
fn decode_into_message(
    buf: &[u8],
    offset: &mut usize,
    expected_counts: &mut message::RecordCounts,
    message: &mut Message,
) -> Result<bool> {
    while message.question.len() < expected_counts.question as usize {
        // Expecting more questions, try to get them
        match message::read_question(buf, offset)
            .context("Failed to read question resource")?
        {
            Some(question) => {
                message.question.push(question);
            }
            None => return Ok(false),
        }
    }

    while message.answer.len() < expected_counts.answer as usize {
        // Expecting more answer resources, try to get them
        match message::read_resource_non_opt(buf, offset)
            .context("Failed to read answer resource")?
        {
            Some(answer) => {
                message.answer.push(answer);
            }
            None => return Ok(false),
        }
    }

    while message.authority.len() < expected_counts.authority as usize {
        // Expecting more authority resources, try to get them
        match message::read_resource_non_opt(buf, offset)
            .context("Failed to read authority resource")?
        {
            Some(authority) => {
                message.authority.push(authority);
            }
            None => return Ok(false),
        }
    }

    // Ensure any OPT record is counted against the expected number of additional records
    let mut opt_count = match &message.opt {
        Some(_opt) => 1,
        None => 0,
    };
    while message.additional.len() + opt_count < expected_counts.additional as usize {
        // An additional record is read in two steps (name/type, then the remainder). If the
        // data runs out between the two steps, the offset must go back to the start of the
        // record; otherwise the retry would begin parsing in the middle of it.
        let record_start = *offset;

        // Expecting more additional resources, try to get them
        // Use special handling for OPT resources, when they are encountered
        match message::read_resource_name_type(buf, offset)
            .context("Failed to read additional resource name and type")?
        {
            Some((name, message::OPT_RESOURCE_TYPE)) => {
                if name != "." {
                    bail!("Expected OPT resource to have name '.', but was: {}", name);
                }
                if opt_count != 0 {
                    bail!("Message has multiple OPT resources");
                }
                match message::read_resource_remainder_opt(buf, offset)? {
                    Some(opt) => {
                        message.opt = Some(opt);
                        opt_count += 1;
                    }
                    None => {
                        *offset = record_start;
                        return Ok(false);
                    }
                }
            }
            Some((name, resource_type)) => {
                match message::read_resource_remainder(buf, offset, name, resource_type)? {
                    Some(additional) => message.additional.push(additional),
                    None => {
                        *offset = record_start;
                        return Ok(false);
                    }
                }
            }
            None => {
                // Rewinding here is a no-op if the reader left the offset untouched.
                *offset = record_start;
                return Ok(false);
            }
        }
    }

    // We've filled all the expected records, so the message is complete.
    Ok(true)
}
