#[cfg(feature = "miltertest")]
mod miltertest;

#[cfg(feature = "miltertest")]
pub use self::miltertest::*;

use bytes::{Buf, BufMut, Bytes, BytesMut};
use indymilter::{
    message::{
        commands::{
            BodyPayload, Command, CommandKind, ConnInfoPayload, EnvAddrPayload, HeaderPayload,
            HeloPayload, MacroPayload, OptNegPayload,
        },
        replies::{Reply, ReplyKind},
        Message,
    },
    Actions, Callbacks, Config, ProtoOpts,
};
use std::{
    collections::HashMap,
    ffi::{CStr, CString},
    io::{self, ErrorKind},
    net::{Ipv4Addr, SocketAddr},
};
use tokio::{
    io::{AsyncReadExt, AsyncWriteExt, BufStream},
    net::{TcpListener, TcpStream, ToSocketAddrs},
    sync::oneshot,
    task::{self, JoinHandle},
};

/// IPv4 loopback address (`127.0.0.1`) with port 0.
///
/// Binding a listener to port 0 asks the OS to assign a free port; the actual
/// port can then be read back from the bound listener's local address.
pub const LOCALHOST: (Ipv4Addr, u16) = (Ipv4Addr::new(127, 0, 0, 1), 0);

/// A milter under test: `indymilter::run` serving a TCP listener on a spawned
/// tokio task, together with the means to stop it.
pub struct Milter {
    // Task running `indymilter::run`; its output is the milter's final result.
    milter_handle: JoinHandle<io::Result<()>>,
    // Sender half of the channel passed to `indymilter::run` as its shutdown signal.
    shutdown: oneshot::Sender<()>,
    // Address the listener actually bound to (port 0 resolved to the real port).
    addr: SocketAddr,
}

impl Milter {
    /// Binds a TCP listener on `addr` and spawns `indymilter::run` on it with
    /// the given `callbacks` and `config`.
    ///
    /// # Errors
    ///
    /// Fails if the listener cannot be bound or its local address cannot be
    /// queried.
    pub async fn spawn<T: Send + 'static>(
        addr: impl ToSocketAddrs,
        callbacks: Callbacks<T>,
        config: Config,
    ) -> io::Result<Self> {
        let listener = TcpListener::bind(addr).await?;
        let bound_addr = listener.local_addr()?;

        let (signal_tx, signal_rx) = oneshot::channel();
        let task = tokio::spawn(indymilter::run(listener, callbacks, config, signal_rx));

        Ok(Milter {
            milter_handle: task,
            shutdown: signal_tx,
            addr: bound_addr,
        })
    }

    /// Returns the local socket address the milter is listening on.
    pub fn addr(&self) -> SocketAddr {
        self.addr
    }

    /// Sends the shutdown signal to the milter and waits for its task to end.
    ///
    /// # Errors
    ///
    /// Returns the milter's own result, or an error converted from the task's
    /// `JoinError` if the task could not be joined.
    pub async fn shutdown(self) -> io::Result<()> {
        // The milter runs on a separately spawned task, and the test body may
        // keep making progress at every await point without the milter task
        // getting scheduled. Yielding a number of times first gives the milter
        // a chance to advance further (for example, to finish an open
        // session) before it sees the shutdown signal.
        //
        // This is not meant to affect test correctness; it only makes the
        // observable behaviour (such as logging) closer to a real milter's.
        let mut yields_left = 10;
        while yields_left > 0 {
            task::yield_now().await;
            yields_left -= 1;
        }

        let Milter {
            milter_handle,
            shutdown,
            ..
        } = self;

        // A failed send means the receiving side is already gone; there is
        // nothing further to signal in that case, so the result is ignored.
        let _ = shutdown.send(());

        let joined = milter_handle.await;
        joined?
    }
}

/// Test client playing the MTA side of the milter protocol over TCP.
///
/// All traffic goes through a buffered stream, so every write path ends with
/// an explicit flush.
pub struct Client {
    stream: BufStream<TcpStream>,
}

impl Client {
    /// Opens a TCP connection to a milter listening at `addr`.
    ///
    /// # Errors
    ///
    /// Returns any error from establishing the TCP connection.
    pub async fn connect(addr: impl ToSocketAddrs) -> io::Result<Self> {
        let stream = TcpStream::connect(addr).await?;

        Ok(Self {
            stream: BufStream::new(stream),
        })
    }

    /// Encodes `cmd` as a milter protocol message and sends it.
    ///
    /// String fields are written NUL-terminated; the `OptNeg` fields are
    /// written as big-endian `u32`s.
    ///
    /// # Panics
    ///
    /// Panics via `todo!()` for encodings this client does not support yet:
    /// `ConnInfo` carrying socket information, and any `Command` variant not
    /// matched explicitly below.
    ///
    /// # Errors
    ///
    /// Same as [`Client::write_message`].
    pub async fn write_command(&mut self, cmd: Command) -> io::Result<()> {
        let msg = match cmd {
            Command::Abort => Message::new(CommandKind::Abort, Bytes::new()),
            Command::BodyChunk(BodyPayload { chunk }) => Message::new(CommandKind::BodyChunk, chunk),
            Command::ConnInfo(ConnInfoPayload { hostname, socket_info }) => {
                let mut buf = BytesMut::new();

                buf.put(hostname.to_bytes_with_nul());

                match socket_info {
                    // `U` marks an unknown socket family; no port or address follows.
                    None => {
                        buf.put_u8(b'U');
                    }
                    _ => todo!(),
                }

                Message::new(CommandKind::ConnInfo, buf)
            }
            Command::DefMacros(MacroPayload { stage, macros }) => {
                let mut buf = BytesMut::new();

                buf.put_u8(stage.into());
                for m in macros {
                    buf.put(m.to_bytes_with_nul());
                }

                Message::new(CommandKind::DefMacros, buf)
            }
            Command::BodyEnd(BodyPayload { chunk }) => Message::new(CommandKind::BodyEnd, chunk),
            Command::Helo(HeloPayload { hostname }) => {
                let hostname = hostname.to_bytes_with_nul();

                Message::new(CommandKind::Helo, Bytes::copy_from_slice(hostname))
            }
            Command::QuitNc => Message::new(CommandKind::QuitNc, Bytes::new()),
            Command::Header(HeaderPayload { name, value }) => {
                let mut buf = BytesMut::new();

                buf.put(name.to_bytes_with_nul());
                buf.put(value.to_bytes_with_nul());

                Message::new(CommandKind::Header, buf)
            }
            Command::Mail(EnvAddrPayload { args }) => {
                let mut buf = BytesMut::new();

                for arg in args {
                    buf.put(arg.to_bytes_with_nul());
                }

                Message::new(CommandKind::Mail, buf)
            }
            Command::Eoh => Message::new(CommandKind::Eoh, Bytes::new()),
            Command::OptNeg(OptNegPayload { version, actions, opts }) => {
                let mut buf = BytesMut::new();

                buf.put_u32(version);
                buf.put_u32(actions.bits());
                buf.put_u32(opts.bits());

                Message::new(CommandKind::OptNeg, buf)
            }
            Command::Quit => Message::new(CommandKind::Quit, Bytes::new()),
            Command::Rcpt(EnvAddrPayload { args }) => {
                let mut buf = BytesMut::new();

                for arg in args {
                    buf.put(arg.to_bytes_with_nul());
                }

                Message::new(CommandKind::Rcpt, buf)
            }
            Command::Data => Message::new(CommandKind::Data, Bytes::new()),
            _ => todo!(),
        };

        self.write_message(msg).await
    }

    /// Frames and sends `msg`: a big-endian `u32` length (payload length plus
    /// one for the kind byte), the kind byte, then the payload. The stream is
    /// flushed afterwards.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::InvalidData` if the payload exceeds
    /// `Message::MAX_BUFFER_LEN` or its framed length does not fit in a `u32`;
    /// otherwise any I/O error from writing or flushing.
    pub async fn write_message(&mut self, msg: Message) -> io::Result<()> {
        let payload_len = msg.buffer.len();

        if payload_len > Message::MAX_BUFFER_LEN {
            return Err(ErrorKind::InvalidData.into());
        }

        // The length prefix also counts the kind byte. Report an oversized
        // length as invalid data rather than panicking.
        let len = u32::try_from(payload_len)
            .ok()
            .and_then(|n| n.checked_add(1))
            .ok_or_else(|| io::Error::from(ErrorKind::InvalidData))?;

        self.stream.write_u32(len).await?;
        self.stream.write_u8(msg.kind).await?;
        self.stream.write_all(msg.buffer.as_ref()).await?;
        self.stream.flush().await?;

        Ok(())
    }

    /// Writes `bytes` verbatim (no framing) and flushes the stream.
    ///
    /// Presumably intended for tests that need hand-crafted or malformed
    /// input — TODO confirm against callers.
    pub async fn write_bytes(&mut self, bytes: &[u8]) -> io::Result<()> {
        self.stream.write_all(bytes).await?;
        self.stream.flush().await?;

        Ok(())
    }

    /// Reads one message from the milter and decodes it as a [`Reply`].
    ///
    /// # Errors
    ///
    /// Returns I/O errors from reading the message. Returns an
    /// `ErrorKind::InvalidData` error if the message cannot be decoded; the
    /// decoder's error is attached as the source.
    pub async fn read_reply(&mut self) -> io::Result<Reply> {
        let msg = self.read_message().await?;

        parse_reply(msg).map_err(|e| io::Error::new(ErrorKind::InvalidData, e))
    }

    /// Reads one framed message: a big-endian `u32` length (which includes the
    /// kind byte), the kind byte, and the payload.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::InvalidData` if the length prefix is zero or the
    /// payload exceeds `Message::MAX_BUFFER_LEN`; otherwise any I/O error
    /// (including unexpected EOF) from reading.
    async fn read_message(&mut self) -> io::Result<Message> {
        let len = self.stream.read_u32().await?;
        let len = usize::try_from(len).expect("unsupported pointer size");

        // The prefix must at least cover the kind byte. If a zero prefix were
        // accepted, the next packet's first byte would be read as this
        // message's kind, desynchronising the stream.
        let len = match len.checked_sub(1) {
            Some(len) => len,
            None => return Err(ErrorKind::InvalidData.into()),
        };

        if len > Message::MAX_BUFFER_LEN {
            return Err(ErrorKind::InvalidData.into());
        }

        let kind = self.stream.read_u8().await?;

        let mut buffer = vec![0; len];
        self.stream.read_exact(&mut buffer).await?;

        Ok(Message {
            kind,
            buffer: buffer.into(),
        })
    }

    /// Shuts down the write side of the connection.
    ///
    /// Note: consumes and therefore drops this client and connection.
    pub async fn disconnect(mut self) -> io::Result<()> {
        self.stream.shutdown().await
    }
}

// TODO implement properly, no unwrap
use std::error::Error;
fn parse_reply(mut msg: Message) -> Result<Reply, Box<dyn Error + Send + Sync>> {
    let kind = msg.kind.try_into().unwrap();
    match kind {
        ReplyKind::Continue => Ok(Reply::Continue),
        ReplyKind::OptNeg => {
            let mut buf = msg.buffer;

            let version = buf.get_u32();
            let actions = Actions::from_bits(buf.get_u32()).unwrap();
            let opts = ProtoOpts::from_bits(buf.get_u32()).unwrap();
            let macros = {
                // TODO macros
                let mut m = HashMap::new();
                while buf.has_remaining() {
                    let stage = buf.get_i32().try_into().unwrap();
                    let macros = get_c_string(&mut buf).unwrap();
                    m.insert(stage, macros);
                }
                m
            };

            Ok(Reply::OptNeg {
                version,
                actions,
                opts,
                macros,
            })
        }
        ReplyKind::AddHeader => {
            let mut buf = msg.buffer;

            let name = get_c_string(&mut buf).unwrap();
            let value = get_c_string(&mut buf).unwrap();

            Ok(Reply::AddHeader {
                name,
                value,
            })
        }
        ReplyKind::ReplyCode => {
            let reply = get_c_string(&mut msg.buffer).unwrap();

            Ok(Reply::ReplyCode { reply })
        }
        _ => todo!(),
    }
}

/// Splits the leading NUL-terminated string off the front of `buf`.
///
/// On success, `buf` is advanced past the string and its terminating NUL.
///
/// # Errors
///
/// Returns an error, leaving `buf` untouched, if `buf` contains no NUL byte.
fn get_c_string(buf: &mut Bytes) -> Result<CString, Box<dyn Error + Send + Sync>> {
    match buf.iter().position(|&byte| byte == 0) {
        Some(nul_index) => {
            let raw = buf.split_to(nul_index + 1);
            // `raw` ends at the first NUL in `buf`, so it holds exactly one NUL
            // and it is the last byte. The conversion therefore cannot fail.
            let c_str = CStr::from_bytes_with_nul(&raw).unwrap();
            Ok(CString::from(c_str))
        }
        None => Err("no c string".into()),
    }
}
