use crate::{
    macros::Stage,
    message::{
        replies::{Reply, ReplyKind},
        Message,
    },
    stream::{Request, StreamHandler},
};
use bytes::{BufMut, Bytes, BytesMut};
use std::{io, time::Duration};
use tokio::{
    io::{AsyncRead, AsyncWrite},
    sync::{mpsc, oneshot},
};
use tracing::trace;

/// A cloneable handle to a connection stream.
///
/// Cloned connections all hold a handle to the message-processing task
/// (actor) that owns the actual connection stream. Once the last
/// `Connection` is dropped, so is the last `mpsc::Sender` handle, and the
/// task then exits.
///
/// In a milter session, a connection is only ever read from or written to
/// serially. However, during the eom stage, a clone of the connection is used
/// to write replies from the eom callback.
#[derive(Clone, Debug)]
pub struct Connection {
    /// Request channel to the task that owns the stream.
    conn: mpsc::Sender<Request>,
}

impl Connection {
    pub fn new<S>(stream: S, timeout: Duration) -> Self
    where
        S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
    {
        let (messages_tx, messages_rx) = mpsc::channel(1);

        // Creating a `Connection` spawns a task that processes requests for as
        // long as the connection (or a clone) exists. When the last
        // `mpsc::Sender` held by a connection is dropped the call on `recv`
        // will return and the task exits.

        // The stream handler takes exclusive ownership of the stream, and is
        // handed off directly to `tokio::spawn`, becoming an *actor* solely
        // responsible for reading and writing messages to the stream.

        let handler = StreamHandler::new(stream);

        tokio::spawn(handler.handle_requests(messages_rx, timeout));

        Self { conn: messages_tx }
    }

    pub async fn read_message(&self) -> io::Result<Message> {
        let (response_tx, response) = oneshot::channel();

        let request = Request::ReadMsg {
            response: response_tx,
        };

        self.conn.send(request).await.unwrap_or_else(|_| unreachable!());

        let result = response.await.unwrap();

        if let Ok(msg) = &result {
            trace!(?msg, "message read");
        }

        result
    }

    pub async fn write_message(&self, msg: Message) -> io::Result<()> {
        let (response_tx, response) = oneshot::channel();

        trace!(?msg, "writing message");

        let request = Request::WriteMsg {
            msg,
            response: response_tx,
        };

        self.conn.send(request).await.unwrap_or_else(|_| unreachable!());

        response.await.unwrap()
    }

    pub async fn write_reply(&self, reply: Reply) -> io::Result<()> {
        let msg = match reply {
            Reply::AddRcpt { rcpt } => {
                let rcpt = rcpt.to_bytes_with_nul();

                Message::new(ReplyKind::AddRcpt, Bytes::copy_from_slice(rcpt))
            }
            Reply::DeleteRcpt { rcpt } => {
                let rcpt = rcpt.to_bytes_with_nul();

                Message::new(ReplyKind::DeleteRcpt, Bytes::copy_from_slice(rcpt))
            }
            Reply::AddRcptExt { rcpt, args } => {
                let rcpt = rcpt.to_bytes_with_nul();

                let mut buf = BytesMut::with_capacity(rcpt.len());

                buf.put(rcpt);
                if let Some(args) = args {
                    buf.put(args.to_bytes_with_nul());
                }

                Message::new(ReplyKind::AddRcptExt, buf)
            }
            Reply::OptNeg { version, actions, opts, macros } => {
                let mut buf = BytesMut::with_capacity(12);

                buf.put_u32(version);
                buf.put_u32(actions.bits());
                buf.put_u32(opts.bits());

                for stage in Stage::all_stages() {
                    if let Some(macros) = macros.get(&stage) {
                        buf.put_i32(stage.into());
                        buf.put(macros.to_bytes_with_nul());
                    }
                }

                Message::new(ReplyKind::OptNeg, buf)
            }
            Reply::Accept => Message::new(ReplyKind::Accept, Bytes::new()),
            Reply::ReplaceBody { chunk } => Message::new(ReplyKind::ReplaceBody, chunk),
            Reply::Continue => Message::new(ReplyKind::Continue, Bytes::new()),
            Reply::Discard => Message::new(ReplyKind::Discard, Bytes::new()),
            Reply::ChangeSender { mail, args } => {
                let mail = mail.to_bytes_with_nul();

                let mut buf = BytesMut::with_capacity(mail.len());

                buf.put(mail);
                if let Some(args) = args {
                    buf.put(args.to_bytes_with_nul());
                }

                Message::new(ReplyKind::ChangeSender, buf)
            }
            Reply::AddHeader { name, value } => {
                let name = name.to_bytes_with_nul();
                let value = value.to_bytes_with_nul();

                let mut buf = BytesMut::with_capacity(name.len() + value.len());

                buf.put(name);
                buf.put(value);

                Message::new(ReplyKind::AddHeader, buf)
            }
            Reply::InsertHeader { index, name, value } => {
                let name = name.to_bytes_with_nul();
                let value = value.to_bytes_with_nul();

                let mut buf = BytesMut::with_capacity(name.len() + value.len() + 4);

                buf.put_i32(index);
                buf.put(name);
                buf.put(value);

                Message::new(ReplyKind::InsertHeader, buf)
            }
            Reply::ChangeHeader { name, index, value } => {
                let name = name.to_bytes_with_nul();
                let value = value.to_bytes_with_nul();

                let mut buf = BytesMut::with_capacity(name.len() + value.len() + 4);

                buf.put_i32(index);
                buf.put(name);
                buf.put(value);

                Message::new(ReplyKind::ChangeHeader, buf)
            }
            Reply::Progress => Message::new(ReplyKind::Progress, Bytes::new()),
            Reply::Quarantine { reason } => {
                let reason = reason.to_bytes_with_nul();

                Message::new(ReplyKind::Quarantine, Bytes::copy_from_slice(reason))
            }
            Reply::Reject => Message::new(ReplyKind::Reject, Bytes::new()),
            Reply::Skip => Message::new(ReplyKind::Skip, Bytes::new()),
            Reply::Tempfail => Message::new(ReplyKind::Tempfail, Bytes::new()),
            Reply::ReplyCode { reply } => {
                let reply = reply.to_bytes_with_nul();

                Message::new(ReplyKind::ReplyCode, Bytes::copy_from_slice(reply))
            }
        };

        self.write_message(msg).await
    }
}

// Tests drive a `Connection` over an in-memory `tokio::io::duplex` pipe, with
// the `client` end playing the role of the peer.
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::ErrorKind;
    use tokio::{
        io::{AsyncReadExt, AsyncWriteExt},
        join, time,
    };

    // Clones share one underlying stream. After the first handle is dropped,
    // the second still works. Dropping the last handle closes the stream.
    #[tokio::test]
    async fn multiple_connections() {
        let (mut client, stream) = tokio::io::duplex(100);

        let conn1 = Connection::new(stream, Duration::from_secs(30));
        let conn2 = conn1.clone();

        // First, read from and write to the first connection.

        // Wire frame as exercised here: a 4-byte big-endian length (3) that
        // counts the kind byte `x` plus the payload `yz`.
        client.write_all(b"\0\0\0\x03xyz").await.unwrap();

        let msg = conn1.read_message().await.unwrap();
        assert_eq!(msg, Message::new(b'x', &b"yz"[..]));
        // Echo the message back and check that the same frame comes out.
        conn1.write_message(msg).await.unwrap();

        let mut buffer = vec![0; 7];
        client.read_exact(&mut buffer).await.unwrap();
        assert_eq!(buffer, b"\0\0\0\x03xyz");

        // Drop the first connection and continue using the second connection.

        drop(conn1);

        let msg = Message::new(b'x', &b"abc"[..]);
        conn2.write_message(msg).await.unwrap();

        // Length 4 = kind byte `x` + 3 payload bytes.
        let mut buffer = vec![0; 8];
        client.read_exact(&mut buffer).await.unwrap();
        assert_eq!(buffer, b"\0\0\0\x04xabc");

        // Drop the second and last remaining connection, closing the stream.

        drop(conn2);

        // With the stream closed, the client sees end-of-file.
        let e = client.read_u8().await.unwrap_err();
        assert_eq!(e.kind(), ErrorKind::UnexpectedEof);
    }

    // A read that stalls longer than the configured timeout fails with
    // `TimedOut`, and the stream is closed afterwards.
    #[tokio::test]
    async fn connection_timeout() {
        let timeout = Duration::from_secs(30);

        let (mut client, stream) = tokio::io::duplex(100);
        let conn = Connection::new(stream, timeout);

        // Both the `Connection` and the client end of the duplex stream are
        // moved into the futures. The client sends only the length prefix of
        // a 5-byte frame, so the read stalls and times out after 30 seconds.
        // Its future then completes, dropping the `Connection` and closing
        // the stream. The client's second write, attempted 5 seconds later,
        // therefore fails.

        // Paused time lets Tokio auto-advance the clock while the runtime is
        // idle, so the test does not really wait 30+ seconds.
        time::pause();

        let (stream_result, client_result) = join!(
            async move { conn.read_message().await },
            async move {
                client.write_all(b"\0\0\0\x05").await.unwrap();
                time::sleep(timeout + Duration::from_secs(5)).await;
                client.write_all(b"Xyzabc").await
            },
        );

        time::resume();

        let e = stream_result.unwrap_err();
        assert_eq!(e.kind(), ErrorKind::TimedOut);

        let e = client_result.unwrap_err();
        assert_eq!(e.kind(), ErrorKind::BrokenPipe);
    }
}
