use anyhow::{bail, Context, Result};
use byteorder::{NetworkEndian, WriteBytesExt};
use std::cmp::max;
use std::io::{Cursor, Read, Write};
use std::net::{IpAddr, SocketAddr, TcpStream, UdpSocket};
use std::time::{Duration, Instant};
use toluol::{DnsFlags, DnsMessage, DnsOpcode, DnsRecord};

#[cfg(feature = "tls")]
use {anyhow::anyhow, std::convert::TryInto, std::net::ToSocketAddrs, std::sync::Arc};

mod args;
use args::ConnectionType;

// TODO
// - see if we can get nicer error messages
// - add +trace support
// - DoH
// - Colors for output (highlight answer section)
// - add tests for parsing (look at cargo fuzz)
// - more input validation when constructing lib data types (e.g. validate names, or even better, a Name type)
// - DNSSEC verification?
// - new trait for JSON output, implemented by two structs (one for OPT records, one for all other records). then
//   display_result() can convert them and keep a Vec<Box<dyn JsonRecord>>. this would make for nicer JSON output
//   of DnsAnswers

/// Entry point: builds a query from the command-line arguments, sends it to the nameserver over
/// the selected transport, parses the reply and prints it.
fn main() -> Result<()> {
    let args = args::parse_args();
    // passed to the query builder and used as the UDP receive buffer size; seems reasonable
    let bufsize: u16 = 4096;
    let query = prepare_query(&args, bufsize)?;

    let outcome = match args.connection_type {
        ConnectionType::Udp => send_query_udp(&args, bufsize, &query),
        ConnectionType::Tcp => send_query_tcp(&args, bufsize, &query),
        #[cfg(feature = "tls")]
        ConnectionType::Tls => send_query_tls(&args, &query),
    };
    let (reply, reply_size, round_trip) = outcome?;

    let message =
        DnsMessage::parse(&mut Cursor::new(&reply)).context("Could not parse answer.")?;
    display_result(&message, &args, reply_size, &round_trip);
    Ok(())
}

/// Builds the DNS query described by `args` and returns its wire-format encoding.
///
/// `bufsize` is handed to `DnsMessage::new_query` unchanged; `args.use_dnssec` is forwarded as well.
fn prepare_query(args: &args::Args, bufsize: u16) -> Result<Vec<u8>> {
    // the cd flag is set on purpose, see https://tools.ietf.org/html/rfc6840#section-5.9
    let flags = DnsFlags::new(false, false, true, false, true, true);
    DnsMessage::new_query(
        &args.url,
        args.qtype,
        DnsOpcode::QUERY,
        flags,
        true,
        args.use_dnssec,
        bufsize,
    )
    .context("Could not create query.")
    .and_then(|query| query.encode().context("Could not encode query."))
}

/// Sends `data` to the nameserver via UDP and waits for a single reply datagram.
///
/// `bufsize` is the size of the receive buffer; the query carries the same value as its maximum
/// reply size. Returns the received bytes, their count and the round-trip time.
fn send_query_udp(
    args: &args::Args,
    bufsize: u16,
    data: &[u8],
) -> Result<(Vec<u8>, u16, Duration)> {
    // the returned socket is already connected to the nameserver, so send()/recv() below only
    // talk to that peer; connecting again would resolve a hostname a second time and could even
    // pick a different address than the one the socket was set up for
    let socket = create_and_connect_socket_udp(args)?;

    socket
        .set_write_timeout(Some(Duration::new(2, 0)))
        .context("Could not set UDP socket write timeout.")?;
    socket
        .set_read_timeout(Some(Duration::new(10, 0)))
        .context("Could not set UDP socket read timeout.")?;

    // the query sets this as max size
    let mut res = vec![0; bufsize as usize];

    let before = Instant::now();
    socket
        .send(data)
        .context("Could not send data to nameserver")?;

    let bytes_recvd = socket
        .recv(&mut res)
        .context("The nameserver did not reply in time.")?;
    let elapsed = before.elapsed();

    res.truncate(bytes_recvd);

    // recv() never returns more than res.len() == bufsize bytes, so this cast cannot truncate
    Ok((res, bytes_recvd as u16, elapsed))
}

fn create_and_connect_socket_udp(args: &args::Args) -> Result<UdpSocket> {
    if let Ok(ns_addr) = args.nameserver.parse::<IpAddr>() {
        // must match the IP protocol version of ns_addr, else we get an "Address family not
        // supported by protocol" error on connect()
        let bind_addr = if ns_addr.is_ipv6() { "::" } else { "0.0.0.0" };
        let connect_addr: SocketAddr = (ns_addr, args.port).into();
        let socket = UdpSocket::bind((bind_addr, 0)).context("Could not create UDP socket.")?;
        socket
            .connect(connect_addr)
            .context(format!("Could not connect to {} via UDP.", connect_addr))?;
        Ok(socket)
    } else {
        // treat args.nameserver as a hostname. first try to connect via IPv6; if that fails, try
        // IPv4; if that also fails, bail.

        let connect_addr = format!("{}:{}", args.nameserver, args.port);
        // tried in reverse, i.e. "::" is tried first
        let mut bind_addrs = vec!["0.0.0.0", "::"];
        let mut socket = None;
        while socket.is_none() && !bind_addrs.is_empty() {
            let bind_addr = bind_addrs.pop().unwrap();
            let sock = UdpSocket::bind((bind_addr, 0)).context("Could not create UDP socket.")?;
            if sock.connect(&connect_addr).is_ok() {
                socket = Some(sock);
            }
        }

        if let Some(sock) = socket {
            Ok(sock)
        } else {
            Err(anyhow!("Could not connect to {} via UDP.", connect_addr))
        }
    }
}

/// Sends `data` to the nameserver via TCP and returns the reply.
///
/// Over TCP, every DNS message is prefixed with its length as a two-byte big-endian integer
/// (RFC 1035, Section 4.2.2). The reply is read by first reading that prefix and then exactly the
/// announced number of bytes. No read waits for the connection to close or for the read timeout
/// to elapse.
///
/// `_bufsize` is not needed here because the length prefix states the exact reply size (up to
/// 65535 bytes). The receive buffer is sized from that prefix instead. The parameter is kept so
/// the signature matches `send_query_udp`.
///
/// Returns the reply without its length prefix, the reply length and the round-trip time.
fn send_query_tcp(
    args: &args::Args,
    _bufsize: u16,
    data: &[u8],
) -> Result<(Vec<u8>, u16, Duration)> {
    let addr = format!("{}:{}", args.nameserver, args.port);
    let mut socket = TcpStream::connect(&addr).context(format!(
        "Could not connect to {} via TCP, is the server running?",
        addr
    ))?;

    socket
        .set_write_timeout(Some(Duration::new(2, 0)))
        .context("Could not set TCP stream write timeout.")?;
    socket
        .set_read_timeout(Some(Duration::new(10, 0)))
        .context("Could not set TCP stream read timeout.")?;

    // the length prefix is a u16, so longer queries cannot be framed; check instead of silently
    // truncating the length with `as`
    if data.len() > u16::MAX as usize {
        bail!("The query is too large ({} bytes) to be sent via TCP.", data.len());
    }
    let mut msg = Vec::with_capacity(data.len() + 2);
    msg.write_u16::<NetworkEndian>(data.len() as u16)?;
    msg.extend_from_slice(data);

    let before = Instant::now();
    socket
        .write_all(&msg)
        .context("Could not write data to TCP stream.")?;

    // read_exact() keeps reading until the buffer is full and fails on EOF or timeout. A server
    // that closes the connection early therefore produces an error, not an endless loop, and replies
    // of any size fit because the buffer is sized from the prefix.
    let mut len_prefix = [0u8; 2];
    socket
        .read_exact(&mut len_prefix)
        .context("Could not read from TCP stream.")?;
    let bytes_recvd = u16::from_be_bytes(len_prefix);
    let mut res = vec![0; usize::from(bytes_recvd)];
    socket.read_exact(&mut res).with_context(|| {
        format!(
            "Could not read the {}-byte reply from TCP stream.",
            bytes_recvd
        )
    })?;
    let elapsed = before.elapsed();

    // the reply is complete at this point; a failing shutdown does not invalidate it, and the
    // stream is closed when it is dropped anyway
    let _ = socket.shutdown(std::net::Shutdown::Both);

    Ok((res, bytes_recvd, elapsed))
}

#[cfg(feature = "tls")]
/// Sends `data` to the nameserver via DNS over TLS and returns the reply.
///
/// The server certificate is checked against the `webpki_roots` trust anchors, with
/// `args.nameserver` as the expected server name. As with plain TCP, each DNS message on the
/// stream is prefixed with its length as a two-byte big-endian integer (RFC 1035, Section 4.2.2).
///
/// Returns the reply without its length prefix, the reply length and the elapsed time. The time
/// runs from handing the query to the TLS session until the reply is complete, so it includes the
/// TLS handshake.
fn send_query_tls(args: &args::Args, data: &[u8]) -> Result<(Vec<u8>, u16, Duration)> {
    /// True once `buf` holds the two-byte length prefix plus at least the announced number of bytes.
    fn reply_complete(buf: &[u8]) -> bool {
        buf.len() >= 2 && buf.len() - 2 >= usize::from(u16::from_be_bytes([buf[0], buf[1]]))
    }

    let mut root_store = rustls::RootCertStore::empty();
    root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
        rustls::OwnedTrustAnchor::from_subject_spki_name_constraints(
            ta.subject,
            ta.spki,
            ta.name_constraints,
        )
    }));
    let config = rustls::ClientConfig::builder()
        .with_safe_defaults()
        .with_root_certificates(root_store)
        .with_no_client_auth();

    let nameserver = args
        .nameserver
        .as_str()
        .try_into()
        .map_err(|_| anyhow!("Invalid nameserver hostname"))?;
    let mut session = rustls::ClientConnection::new(Arc::new(config), nameserver)
        .context("Could not create TLS connection.")?;

    let addr = format!("{}:{}", args.nameserver, args.port)
        .as_str()
        .to_socket_addrs()
        .context("Invalid nameserver address")?
        .next()
        .ok_or_else(|| anyhow!("Invalid nameserver address"))?;
    let mut socket = TcpStream::connect_timeout(&addr, Duration::new(10, 0)).context(format!(
        "Failed to connect to {}, is the server configured to use DNS over TLS?",
        addr
    ))?;

    socket
        .set_write_timeout(Some(Duration::new(2, 0)))
        .context("Could not set TLS/TCP stream write timeout.")?;
    socket
        .set_read_timeout(Some(Duration::new(10, 0)))
        .context("Could not set TLS/TCP stream read timeout.")?;

    // the length prefix is a u16, so longer queries cannot be framed; check instead of silently
    // truncating the length with `as`
    if data.len() > u16::MAX as usize {
        bail!("The query is too large ({} bytes) to be sent over TLS.", data.len());
    }
    let mut msg = Vec::with_capacity(data.len() + 2);
    msg.write_u16::<NetworkEndian>(data.len() as u16)?;
    msg.extend_from_slice(data);

    let before = Instant::now();
    session
        .writer()
        .write_all(&msg)
        .context("Could not write to TLS socket.")?;

    // drive the TLS state machine (handshake, sending the query, receiving the reply) until the
    // complete length-prefixed reply has been decrypted into `plaintext`
    let mut plaintext = Vec::new();
    while !reply_complete(&plaintext) {
        if session.wants_write() {
            session
                .write_tls(&mut socket)
                .context("Could not write TLS packets to TCP stream.")?;
        }

        if session.wants_read() {
            let tls_bytes = session
                .read_tls(&mut socket)
                .context("Could not read TLS packets from TCP stream.")?;
            if tls_bytes == 0 {
                // EOF on the TCP stream: every further read_tls() would return 0 as well
                bail!("The nameserver closed the connection before sending a complete reply.");
            }
            session
                .process_new_packets()
                .context("Could not process new TLS packets.")?;
            // Ignore WouldBlock errors: they only mean no (more) plaintext is available yet
            match session.reader().read_to_end(&mut plaintext) {
                Ok(_) => (),
                Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => (),
                Err(e) => return Err(e).context("Could not read from TLS socket."),
            }
        } else if !session.wants_write() {
            // the session wants neither to send nor to receive (e.g. the server already sent
            // close_notify), so no more data can arrive; bail instead of spinning forever
            bail!("The nameserver closed the connection before sending a complete reply.");
        }
    }
    let elapsed = before.elapsed();

    // send_close_notify() only queues the alert; it still has to be written to the socket. This
    // is best effort, since the reply has already been received completely.
    session.send_close_notify();
    while session.wants_write() {
        match session.write_tls(&mut socket) {
            Ok(written) if written > 0 => (),
            _ => break,
        }
    }

    // remove first two bytes (see RFC 1035, Section 4.2.2)
    let bytes_recvd = u16::from_be_bytes([plaintext[0], plaintext[1]]);
    let reply = plaintext.split_off(2);
    if reply.len() != usize::from(bytes_recvd) {
        bail!(
            "Received {} bytes, but TCP message says {} were sent.",
            reply.len(),
            bytes_recvd
        )
    }

    Ok((reply, bytes_recvd, elapsed))
}

/// Prints the parsed reply `res` to stdout.
///
/// Unless `args.short` is set, the whole message is printed. With the `json` feature enabled and
/// `args.json` set, it is printed as pretty JSON. Otherwise it uses its `Display` output, followed
/// by query metadata: the elapsed time `elapsed`, the reply size `bytes_recvd` and the queried server.
///
/// With `args.short`, only the records from `answers`, `authoritative_answers` and
/// `additional_answers` are printed. They appear either as a JSON array, or one per line with the
/// owner-name and type columns padded to a common width. OPT records are skipped in both cases.
fn display_result(res: &DnsMessage, args: &args::Args, bytes_recvd: u16, elapsed: &Duration) {
    if !args.short {
        #[cfg(feature = "json")]
        if args.json {
            println!("{}", serde_json::to_string_pretty(&res).unwrap());
            return;
        }
        println!("{}", res);
        println!();
        println!("Query metadata:");
        println!("\tTime:        {} ms", elapsed.as_millis());
        println!("\tReply size:  {} bytes", bytes_recvd);
        println!("\tServer:      {}#{}", args.nameserver, args.port);
        return;
    }

    // OPT records don't have an `as_padded_string()` implementation. They are only expected in the
    // additional section, but the filter covers every section, so a malformed reply carrying one
    // elsewhere does not reach `as_padded_string()` below.
    let all_answers: Vec<_> = res
        .answers
        .iter()
        .chain(res.authoritative_answers.iter())
        .chain(res.additional_answers.iter())
        .filter(|record| matches!(record, DnsRecord::NONOPT { .. }))
        .collect();

    #[cfg(feature = "json")]
    if args.json {
        println!("{}", serde_json::to_string_pretty(&all_answers).unwrap());
        return;
    }

    // column widths so that owner names and record types line up across all printed records
    let (mut max_owner_len, mut max_type_len) = (0, 0);
    for answer in &all_answers {
        if let DnsRecord::NONOPT { name, atype, .. } = answer {
            max_owner_len = max(max_owner_len, name.len());
            max_type_len = max(max_type_len, atype.to_string().len());
        }
    }
    for answer in &all_answers {
        println!("{}", answer.as_padded_string(max_owner_len, max_type_len));
    }
}
