use crate::{
    frame::{Frame, FrameError},
    macros::Stage,
    reply::{Reply, ReplyKind},
};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use std::{
    convert::{TryFrom, TryInto},
    io::{self, Cursor, ErrorKind},
    time::Duration,
};
use tokio::{
    io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufWriter},
    sync::{mpsc, oneshot},
    time,
};

/// A request sent from a `Connection` handle to the task that owns the stream.
///
/// Every request carries a oneshot sender on which the task reports the outcome.
enum Message {
    /// Read the next frame; the task answers `Ok(None)` on a clean end of stream.
    ReadFrame {
        response: oneshot::Sender<io::Result<Option<Frame>>>,
    },
    /// Write `frame` to the stream (the task also flushes it).
    WriteFrame {
        frame: Frame,
        response: oneshot::Sender<io::Result<()>>,
    },
}

/// Handle to a milter connection whose stream is owned by a background task.
///
/// A `Connection` can be cloned: every clone holds an `mpsc::Sender` to the same
/// message-processing task, which owns the actual stream. Once the last
/// `Connection` is dropped, so is the last `mpsc::Sender`; the task's receive
/// loop then ends and the task exits.
#[derive(Clone)]
pub struct Connection {
    // Channel to the task that owns the stream; the task handles requests one at a time.
    conn: mpsc::Sender<Message>,
}

impl Connection {
    pub fn new<S>(stream: S, timeout: Duration) -> Self
    where
        S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
    {
        let (messages_tx, messages_rx) = mpsc::channel(1);

        let mut stream = BufWriter::new(stream);
        let mut buffer = BytesMut::with_capacity(8192);

        // Creating a `Connection` spawns a task that processes messages for as
        // long as the connection (or a clone) exists. When the last
        // `mpsc::Sender` held by a connection is dropped the call on `recv`
        // will return and the task exits.

        // The stream, ie the actual connection is handed off to this task, and
        // only written to and read from by the task.

        tokio::spawn(async move {
            let mut conn = messages_rx;

            // TODO move into separate "actor" struct?
            while let Some(msg) = conn.recv().await {
                match msg {
                    Message::ReadFrame { response } => {
                        let task = read_frame(&mut stream, &mut buffer);

                        let result = match time::timeout(timeout, task).await {
                            Ok(r) => r,
                            Err(e) => Err(e.into()),
                        };

                        let _ = response.send(result);
                    }
                    Message::WriteFrame { frame, response } => {
                        let task = write_frame(&mut stream, frame);

                        let result = match time::timeout(timeout, task).await {
                            Ok(r) => r,
                            Err(e) => Err(e.into()),
                        };

                        let _ = response.send(result);
                    }
                }
            }

            // TODO correct? necessary?:
            let _ = stream.shutdown().await;
        });

        Self { conn: messages_tx }
    }

    pub async fn read_frame(&self) -> io::Result<Option<Frame>> {
        let (response_tx, response) = oneshot::channel();

        let msg = Message::ReadFrame {
            response: response_tx,
        };

        if self.conn.send(msg).await.is_err() {
            unreachable!();
        }

        response.await.unwrap()
    }

    pub async fn write_frame(&self, frame: Frame) -> io::Result<()> {
        let (response_tx, response) = oneshot::channel();

        let msg = Message::WriteFrame {
            frame,
            response: response_tx,
        };

        if self.conn.send(msg).await.is_err() {
            unreachable!();
        }

        response.await.unwrap()
    }

    pub async fn write_reply(&self, reply: Reply) -> io::Result<()> {
        // TODO use BytesMut::with_capacity
        let frame = match reply {
            Reply::OptNeg { version, actions, popts, requested_macros } => {
                let mut buf = BytesMut::new();

                buf.put_u32(version);
                buf.put_u32(actions.bits());
                buf.put_u32(popts.bits());

                for stage in Stage::all_stages() {
                    if let Some(macros) = requested_macros.get(&stage) {
                        buf.put_i32(stage.into());
                        buf.put(macros.to_bytes_with_nul());
                    }
                }

                Frame::new(ReplyKind::OptNeg, buf)
            }
            Reply::Accept => Frame::new(ReplyKind::Accept, Bytes::new()),
            Reply::Continue => Frame::new(ReplyKind::Continue, Bytes::new()),
            Reply::Discard => Frame::new(ReplyKind::Discard, Bytes::new()),
            Reply::Skip => Frame::new(ReplyKind::Skip, Bytes::new()),
            Reply::Reject => Frame::new(ReplyKind::Reject, Bytes::new()),
            Reply::Tempfail => Frame::new(ReplyKind::Tempfail, Bytes::new()),
            Reply::ReplyCode { reply } => {
                let bytes = reply.to_bytes_with_nul();

                Frame::new(ReplyKind::ReplyCode, Bytes::copy_from_slice(bytes))
            }
            Reply::AddHeader { name, value } => {
                let mut buf = BytesMut::new();

                buf.put(name.to_bytes_with_nul());
                buf.put(value.to_bytes_with_nul());

                Frame::new(ReplyKind::AddHeader, buf)
            }
            Reply::InsertHeader { index, name, value } => {
                let mut buf = BytesMut::new();

                buf.put_i32(index);
                buf.put(name.to_bytes_with_nul());
                buf.put(value.to_bytes_with_nul());

                Frame::new(ReplyKind::InsertHeader, buf)
            }
            Reply::ChangeHeader { name, index, value } => {
                let mut buf = BytesMut::new();

                buf.put_i32(index);
                buf.put(name.to_bytes_with_nul());
                buf.put(value.to_bytes_with_nul());

                Frame::new(ReplyKind::ChangeHeader, buf)
            }
            Reply::ReplaceBody { chunk } => Frame::new(ReplyKind::ReplaceBody, Bytes::from(chunk)),
            Reply::ChangeSender { mail, args } => {
                let mut buf = BytesMut::new();

                buf.put(mail.to_bytes_with_nul());
                if let Some(args) = args {
                    buf.put(args.to_bytes_with_nul());
                }

                Frame::new(ReplyKind::ChangeSender, buf)
            }
            Reply::AddRcpt { rcpt } => {
                let bytes = rcpt.to_bytes_with_nul();

                Frame::new(ReplyKind::AddRcpt, Bytes::copy_from_slice(bytes))
            }
            Reply::AddRcptExt { rcpt, args } => {
                let mut buf = BytesMut::new();

                buf.put(rcpt.to_bytes_with_nul());
                if let Some(args) = args {
                    buf.put(args.to_bytes_with_nul());
                }

                Frame::new(ReplyKind::AddRcptExt, buf)
            }
            Reply::DeleteRcpt { rcpt } => {
                let bytes = rcpt.to_bytes_with_nul();

                Frame::new(ReplyKind::DeleteRcpt, Bytes::copy_from_slice(bytes))
            }
            Reply::Progress => Frame::new(ReplyKind::Progress, Bytes::new()),
            Reply::Quarantine { reason } => {
                let bytes = reason.to_bytes_with_nul();

                Frame::new(ReplyKind::Quarantine, Bytes::copy_from_slice(bytes))
            }
        };

        self.write_frame(frame).await
    }
}

/// Largest payload, in bytes, that `write_frame` will send in a single frame.
///
/// Presumably mirrors libmilter's default `MILTER_MAX_DATA_SIZE` (verify).
// TODO: consider making this configurable; it probably should be higher.
const MILTER_MAX_DATA_SIZE: usize = 0xFFFF;

/// Reads one complete frame from `stream`.
///
/// `buffer` holds bytes that were read earlier but not yet consumed, and is
/// meant to be reused across calls on the same stream. A frame is first decoded
/// from what is already buffered; the stream is read only when that is not
/// enough.
///
/// Returns `Ok(None)` when the stream reaches EOF with nothing buffered. Returns
/// an error of kind `ConnectionReset` when EOF arrives partway through a frame.
async fn read_frame<R: AsyncRead + Unpin>(
    stream: &mut R,
    buffer: &mut BytesMut,
) -> io::Result<Option<Frame>> {
    loop {
        if let Some(frame) = parse_frame(buffer)? {
            break Ok(Some(frame));
        }

        let read = stream.read_buf(buffer).await?;
        if read > 0 {
            continue;
        }

        // EOF: a clean close only if no partial frame is left over.
        break if buffer.is_empty() {
            Ok(None)
        } else {
            Err(io::Error::from(ErrorKind::ConnectionReset))
        };
    }
}

/// Attempts to decode one frame from the front of `buffer`.
///
/// On success, the bytes making up the frame are consumed from `buffer`. If
/// `buffer` does not yet hold a complete frame, `Ok(None)` is returned and
/// `buffer` is left untouched. A frame rejected as too large is reported as an
/// error of kind `InvalidData`.
fn parse_frame(buffer: &mut BytesMut) -> io::Result<Option<Frame>> {
    let mut cursor = Cursor::new(&buffer[..]);

    let frame = match Frame::parse(&mut cursor) {
        Ok(frame) => frame,
        Err(FrameError::Incomplete) => return Ok(None),
        Err(FrameError::TooLarge) => return Err(ErrorKind::InvalidData.into()),
    };

    // The cursor never moves past the end of an in-memory slice, so its
    // position always fits in `usize`.
    let consumed = usize::try_from(cursor.position()).unwrap();
    buffer.advance(consumed);

    Ok(Some(frame))
}

async fn write_frame<W: AsyncWrite + Unpin>(stream: &mut W, frame: Frame) -> io::Result<()> {
    // TODO limit frame size like in Frame::parse
    assert!(frame.buffer.len() <= MILTER_MAX_DATA_SIZE);

    let buflen = u32::try_from(frame.buffer.len()).unwrap();
    let buflen = buflen.checked_add(1).unwrap();

    stream.write_u32(buflen).await?;
    stream.write_u8(frame.kind).await?;
    stream.write_all(frame.buffer.as_ref()).await?;

    stream.flush().await?; // ???

    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;

    /// A complete frame is decoded, a clean EOF then yields `None`, and an EOF
    /// with a partial frame still buffered is an error.
    #[tokio::test]
    async fn read_frame_ok() {
        let mut pending = BytesMut::new();

        let mut input = Cursor::new(b"\0\0\0\x02xy");
        assert_eq!(
            read_frame(&mut input, &mut pending).await.unwrap(),
            Some(Frame::new(b'x', &b"y"[..]))
        );
        assert_eq!(read_frame(&mut input, &mut pending).await.unwrap(), None);

        // A length prefix (7) with no payload behind it: the stream ends mid-frame.
        pending.put_i32(7);
        let mut exhausted = Cursor::new(b"");
        assert!(read_frame(&mut exhausted, &mut pending).await.is_err());
    }

    /// A frame is serialized as a big-endian length (payload + kind), the kind, then the payload.
    #[tokio::test]
    async fn write_frame_ok() {
        let mut output = Vec::new();

        write_frame(&mut output, Frame::new(b'x', &b"y"[..]))
            .await
            .unwrap();

        assert_eq!(output, b"\0\0\0\x02xy");
    }
}
