#[cfg(feature = "miltertest-tests")]
mod miltertest;

#[cfg(feature = "miltertest-tests")]
pub use self::miltertest::*;

use bytes::{Buf, BufMut, Bytes, BytesMut};
use indymilter::{
    message::{
        self,
        command::{
            Command, CommandKind, ConnInfoPayload, EnvAddrPayload, HeaderPayload, HeloPayload,
            MacroPayload, OptNegPayload, UnknownPayload,
        },
        reply::{Reply, ReplyKind},
        Message,
    },
    Actions, Callbacks, Config, ProtoOpts, SocketInfo,
};
use std::{
    collections::HashMap,
    ffi::{CStr, CString},
    io,
    net::{Ipv4Addr, SocketAddr},
    time::Duration,
};
use tokio::{
    io::{AsyncWriteExt, BufStream},
    net::{TcpListener, TcpStream, ToSocketAddrs},
    sync::oneshot,
    task::{self, JoinHandle},
};

/// Tries to install the `tracing_subscriber` formatting subscriber.
///
/// The outcome of `try_init` is discarded, so an unsuccessful initialisation
/// never fails the calling test.
// NOTE(review): the error case is presumably "a global subscriber is already
// installed" (e.g. by an earlier test in the same process) -- confirm against
// the `tracing_subscriber` documentation.
pub fn init_tracing_subscriber() {
    drop(tracing_subscriber::fmt::try_init());
}

/// The IPv4 loopback address paired with port 0.
///
/// Binding a listener to port 0 asks the operating system to assign a port
/// (see the `TcpListener::bind` documentation).
pub const LOCALHOST: (Ipv4Addr, u16) = (Ipv4Addr::LOCALHOST, 0);

pub fn default_config() -> Config {
    // Override very long default connection timeout for tests.
    Config {
        connection_timeout: Duration::from_secs(30),
        ..Default::default()
    }
}

/// Handle to a milter running in a background task, as created by
/// [`Milter::spawn`].
pub struct Milter {
    /// Join handle of the spawned task that drives `indymilter::run`.
    milter_handle: JoinHandle<io::Result<()>>,
    /// Sending half of the oneshot channel whose receiving half was passed to
    /// `indymilter::run` as the shutdown signal.
    shutdown: oneshot::Sender<()>,
    /// Local address of the listener the milter was started on.
    addr: SocketAddr,
}

impl Milter {
    /// Binds a TCP listener on `addr` and starts a milter on it in a
    /// background task.
    ///
    /// `callbacks` and `config` are handed to `indymilter::run` unchanged. The
    /// returned handle records the listener's local address, available through
    /// [`Milter::addr`], and can stop the milter via [`Milter::shutdown`].
    ///
    /// # Errors
    ///
    /// Returns an error if binding the listener or querying its local address
    /// fails.
    pub async fn spawn<T: Send + 'static>(
        addr: impl ToSocketAddrs,
        callbacks: Callbacks<T>,
        config: Config,
    ) -> io::Result<Self> {
        let listener = TcpListener::bind(addr).await?;
        let bound_addr = listener.local_addr()?;

        let (shutdown_tx, shutdown_rx) = oneshot::channel();
        let milter_handle =
            tokio::spawn(indymilter::run(listener, callbacks, config, shutdown_rx));

        Ok(Self {
            milter_handle,
            shutdown: shutdown_tx,
            addr: bound_addr,
        })
    }

    /// Returns the local address of the listener the milter was started on.
    pub fn addr(&self) -> SocketAddr {
        self.addr
    }

    /// Signals the milter to shut down and waits for its task to finish.
    ///
    /// # Errors
    ///
    /// Returns the error produced by the milter task, or the join error of the
    /// task converted into an `io::Error`.
    pub async fn shutdown(self) -> io::Result<()> {
        // Number of times to yield to the runtime before signalling shutdown.
        const YIELD_ROUNDS: usize = 10;

        let Self {
            milter_handle,
            shutdown: shutdown_tx,
            ..
        } = self;

        // The milter runs in a task created with `tokio::spawn`. The test
        // function may make progress at every await point without the milter
        // task ever getting a chance to run to completion before the shutdown
        // signal arrives. Yielding a few times first hopefully lets the milter
        // get a bit further (e.g., terminate an open session).
        //
        // Note: this is not meant to affect test correctness; it only makes
        // the behaviour resemble that of a real milter (e.g. in terms of
        // logging).
        for _round in 0..YIELD_ROUNDS {
            task::yield_now().await;
        }

        // The result of sending the signal is discarded.
        let _ = shutdown_tx.send(());

        match milter_handle.await {
            Ok(run_result) => run_result,
            Err(join_error) => Err(join_error.into()),
        }
    }
}

/// Test client that talks to a milter over a TCP connection.
pub struct Client {
    /// Buffered stream wrapping the TCP connection to the milter.
    stream: BufStream<TcpStream>,
}

impl Client {
    /// Opens a TCP connection to `addr` and wraps it in a buffered stream.
    ///
    /// # Errors
    ///
    /// Returns an error if the connection cannot be established.
    pub async fn connect(addr: impl ToSocketAddrs) -> io::Result<Self> {
        let tcp = TcpStream::connect(addr).await?;
        let stream = BufStream::new(tcp);

        Ok(Self { stream })
    }

    /// Converts `cmd` into a message and writes it to the connection.
    pub async fn write_command(&mut self, cmd: Command) -> io::Result<()> {
        message::write(&mut self.stream, command_into_message(cmd)).await?;

        Ok(())
    }

    /// Writes `bytes` verbatim to the connection, then flushes the stream.
    pub async fn write_bytes(&mut self, bytes: &[u8]) -> io::Result<()> {
        let stream = &mut self.stream;

        stream.write_all(bytes).await?;
        stream.flush().await
    }

    /// Reads the next message from the connection and parses it as a reply.
    ///
    /// # Panics
    ///
    /// Panics if the message that was read cannot be parsed as a reply.
    pub async fn read_reply(&mut self) -> io::Result<Reply> {
        let msg = message::read(&mut self.stream).await?;

        Ok(parse_reply(msg))
    }

    /// Shuts down the underlying stream.
    ///
    /// Note: this consumes, and therefore drops, the client and its
    /// connection.
    pub async fn disconnect(mut self) -> io::Result<()> {
        self.stream.shutdown().await
    }
}

fn command_into_message(cmd: Command) -> Message {
    match cmd {
        Command::Abort => Message::new(CommandKind::Abort, Bytes::new()),
        Command::BodyChunk(chunk) => Message::new(CommandKind::BodyChunk, chunk),
        Command::ConnInfo(ConnInfoPayload { hostname, socket_info }) => {
            let mut buf = BytesMut::new();

            buf.put(hostname.to_bytes_with_nul());

            match socket_info {
                SocketInfo::Unknown => buf.put_u8(b'U'),
                SocketInfo::Inet(addr) => {
                    buf.put_u8(match addr {
                        SocketAddr::V4(_) => b'4',
                        SocketAddr::V6(_) => b'6',
                    });

                    buf.put_u16(addr.port());

                    let ip = CString::new(addr.ip().to_string()).unwrap();
                    buf.put(ip.to_bytes_with_nul());
                }
                SocketInfo::Unix(path) => {
                    buf.put_u8(b'L');
                    buf.put_u16(u16::MAX);
                    buf.put(path.to_bytes_with_nul());
                }
            }

            Message::new(CommandKind::ConnInfo, buf)
        }
        Command::DefMacros(MacroPayload { stage, macros }) => {
            let mut buf = BytesMut::new();

            buf.put_u8(stage.into());
            for m in macros {
                buf.put(m.to_bytes_with_nul());
            }

            Message::new(CommandKind::DefMacros, buf)
        }
        Command::BodyEnd(chunk) => Message::new(CommandKind::BodyEnd, chunk),
        Command::Helo(HeloPayload { hostname }) => {
            let hostname = hostname.to_bytes_with_nul();

            Message::new(CommandKind::Helo, Bytes::copy_from_slice(hostname))
        }
        Command::QuitNc => Message::new(CommandKind::QuitNc, Bytes::new()),
        Command::Header(HeaderPayload { name, value }) => {
            let mut buf = BytesMut::new();

            buf.put(name.to_bytes_with_nul());
            buf.put(value.to_bytes_with_nul());

            Message::new(CommandKind::Header, buf)
        }
        Command::Mail(EnvAddrPayload { args }) => {
            let mut buf = BytesMut::new();

            for arg in args {
                buf.put(arg.to_bytes_with_nul());
            }

            Message::new(CommandKind::Mail, buf)
        }
        Command::Eoh => Message::new(CommandKind::Eoh, Bytes::new()),
        Command::OptNeg(OptNegPayload { version, actions, opts }) => {
            let mut buf = BytesMut::new();

            buf.put_u32(version);
            buf.put_u32(actions.bits());
            buf.put_u32(opts.bits());

            Message::new(CommandKind::OptNeg, buf)
        }
        Command::Quit => Message::new(CommandKind::Quit, Bytes::new()),
        Command::Rcpt(EnvAddrPayload { args }) => {
            let mut buf = BytesMut::new();

            for arg in args {
                buf.put(arg.to_bytes_with_nul());
            }

            Message::new(CommandKind::Rcpt, buf)
        }
        Command::Data => Message::new(CommandKind::Data, Bytes::new()),
        Command::Unknown(UnknownPayload { arg }) => {
            let arg = arg.to_bytes_with_nul();

            Message::new(CommandKind::Unknown, Bytes::copy_from_slice(arg))
        }
    }
}

// Assuming that indymilter only answers with well-formed replies, we can assume
// that parsing the reply does not fail; hence the return type `Reply` instead
// of `Result`.
fn parse_reply(msg: Message) -> Reply {
    let kind = msg.kind.try_into().unwrap();

    let mut buf = msg.buffer;

    match kind {
        ReplyKind::AddRcpt => Reply::AddRcpt {
            rcpt: get_c_string(&mut buf),
        },
        ReplyKind::DeleteRcpt => Reply::DeleteRcpt {
            rcpt: get_c_string(&mut buf),
        },
        ReplyKind::AddRcptExt => {
            let rcpt = get_c_string(&mut buf);
            let args = if buf.has_remaining() {
                Some(get_c_string(&mut buf))
            } else {
                None
            };

            Reply::AddRcptExt { rcpt, args }
        }
        ReplyKind::OptNeg => {
            let version = buf.get_u32();
            let actions = Actions::from_bits(buf.get_u32()).unwrap();
            let opts = ProtoOpts::from_bits(buf.get_u32()).unwrap();
            let macros = {
                let mut m = HashMap::new();
                while buf.has_remaining() {
                    let stage = buf.get_i32().try_into().unwrap();
                    let macros = get_c_string(&mut buf);
                    m.insert(stage, macros);
                }
                m
            };

            Reply::OptNeg { version, actions, opts, macros }
        }
        ReplyKind::Accept => Reply::Accept,
        ReplyKind::ReplaceBody => Reply::ReplaceBody { chunk: buf },
        ReplyKind::Continue => Reply::Continue,
        ReplyKind::Discard => Reply::Discard,
        ReplyKind::ChangeSender => {
            let mail = get_c_string(&mut buf);
            let args = if buf.has_remaining() {
                Some(get_c_string(&mut buf))
            } else {
                None
            };

            Reply::ChangeSender { mail, args }
        }
        ReplyKind::AddHeader => {
            let name = get_c_string(&mut buf);
            let value = get_c_string(&mut buf);

            Reply::AddHeader { name, value }
        }
        ReplyKind::InsertHeader => {
            let index = buf.get_i32();
            let name = get_c_string(&mut buf);
            let value = get_c_string(&mut buf);

            Reply::InsertHeader { index, name, value }
        }
        ReplyKind::ChangeHeader => {
            let index = buf.get_i32();
            let name = get_c_string(&mut buf);
            let value = get_c_string(&mut buf);

            Reply::ChangeHeader { name, index, value }
        }
        ReplyKind::Progress => Reply::Progress,
        ReplyKind::Quarantine => Reply::Quarantine {
            reason: get_c_string(&mut buf),
        },
        ReplyKind::Reject => Reply::Reject,
        ReplyKind::Skip => Reply::Skip,
        ReplyKind::Tempfail => Reply::Tempfail,
        ReplyKind::ReplyCode => Reply::ReplyCode {
            reply: get_c_string(&mut buf),
        },
    }
}

/// Splits the leading NUL-terminated string off the front of `buf` and
/// returns it as an owned `CString`.
///
/// `buf` is advanced past the terminating NUL byte.
///
/// # Panics
///
/// Panics with "no C string" if `buf` contains no NUL byte.
fn get_c_string(buf: &mut Bytes) -> CString {
    // Length of the string including its terminating NUL byte.
    let len_with_nul = buf
        .iter()
        .position(|&byte| byte == 0)
        .map(|nul_index| nul_index + 1)
        .expect("no C string");

    let c_bytes = buf.split_to(len_with_nul);

    CStr::from_bytes_with_nul(&c_bytes).unwrap().to_owned()
}
