use crate::message::Message;
use std::{
    io::{self, ErrorKind},
    time::Duration,
};
use tokio::{
    io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufStream},
    sync::{mpsc, oneshot},
    time,
};

/// A unit of work for [`StreamHandler::handle_requests`]. Every variant carries
/// the one-shot channel on which the outcome of the request is reported.
pub enum Request {
    /// Read the next message from the stream.
    ReadMsg {
        response: oneshot::Sender<io::Result<Message>>,
    },
    /// Write `msg` to the stream and flush it.
    WriteMsg {
        msg: Message,
        response: oneshot::Sender<io::Result<()>>,
    },
}

pub struct StreamHandler<S>
where
    S: AsyncRead + AsyncWrite + Unpin + Send,
{
    stream: BufStream<S>,
}

impl<S> StreamHandler<S>
where
    S: AsyncRead + AsyncWrite + Unpin + Send,
{
    /// Wraps `stream` in a buffered reader/writer.
    pub fn new(stream: S) -> Self {
        Self {
            stream: BufStream::new(stream),
        }
    }

    /// Consumes this stream handler and serves requests for as long as they
    /// can be received on `conn`. Use with `tokio::spawn` to turn this stream
    /// handler into an *actor*.
    ///
    /// Every read, every write and the final shutdown are bounded by
    /// `timeout`; an expired deadline is reported to the requester as an
    /// `io::Error` of kind [`ErrorKind::TimedOut`].
    pub async fn handle_requests(mut self, mut conn: mpsc::Receiver<Request>, timeout: Duration) {
        // The actor’s task is fail safe: a failed request is reported back to
        // its requester instead of ending the loop, so the `mpsc::Receiver`
        // (and any `oneshot::Sender`) is not dropped unexpectedly. The loop
        // only ends once `conn` is closed and drained.
        //
        // NOTE(review): `read_message`/`write_message` are not cancellation
        // safe. If the timeout fires after part of a frame has already been
        // transferred, later messages on this stream may be misframed.
        while let Some(req) = conn.recv().await {
            match req {
                Request::ReadMsg { response } => {
                    let result = Self::with_timeout(timeout, self.read_message()).await;
                    // The requester may have stopped waiting; nothing to do then.
                    let _ = response.send(result);
                }
                Request::WriteMsg { msg, response } => {
                    let result = Self::with_timeout(timeout, self.write_message(msg)).await;
                    let _ = response.send(result);
                }
            }
        }

        // When this actor exits it also shuts down and drops the wrapped
        // stream. The shutdown is bounded as well, so a stalled peer cannot
        // keep this task alive indefinitely. An error result is no longer of
        // interest.
        let _ = Self::with_timeout(timeout, self.shutdown()).await;
    }

    /// Awaits `fut` for at most `timeout`, turning an expired deadline into an
    /// `io::Error` of kind [`ErrorKind::TimedOut`].
    async fn with_timeout<T>(
        timeout: Duration,
        fut: impl std::future::Future<Output = io::Result<T>>,
    ) -> io::Result<T> {
        time::timeout(timeout, fut)
            .await
            .unwrap_or_else(|elapsed| Err(elapsed.into()))
    }

    /// Builds an [`ErrorKind::InvalidData`] error carrying `reason`.
    fn invalid_data(reason: &'static str) -> io::Error {
        io::Error::new(ErrorKind::InvalidData, reason)
    }

    /// Reads one frame: a `u32` length that counts the kind byte, then the
    /// kind byte, then `length - 1` payload bytes.
    ///
    /// # Errors
    ///
    /// Returns [`ErrorKind::InvalidData`] if the length prefix is `0` (no room
    /// for the kind byte) or the payload would exceed
    /// [`Message::MAX_BUFFER_LEN`]. Otherwise propagates I/O errors from the
    /// stream, e.g. [`ErrorKind::UnexpectedEof`] on a truncated frame.
    async fn read_message(&mut self) -> io::Result<Message> {
        let frame_len = self.stream.read_u32().await?;

        // Validate the prefix before reading the kind byte: a zero length
        // describes a frame that has no kind byte at all.
        let payload_len = frame_len
            .checked_sub(1)
            .ok_or_else(|| Self::invalid_data("frame length does not include the kind byte"))?;
        // A length that does not even fit in `usize` certainly exceeds the limit.
        let payload_len = usize::try_from(payload_len)
            .ok()
            .filter(|&len| len <= Message::MAX_BUFFER_LEN)
            .ok_or_else(|| Self::invalid_data("message payload too large"))?;

        let kind = self.stream.read_u8().await?;
        let mut buffer = vec![0; payload_len];
        self.stream.read_exact(&mut buffer).await?;

        Ok(Message::new(kind, buffer))
    }

    /// Writes `msg` as one frame (layout as described on `read_message`) and
    /// flushes the buffered stream.
    ///
    /// # Errors
    ///
    /// Returns [`ErrorKind::InvalidData`], without writing anything, if the
    /// payload exceeds [`Message::MAX_BUFFER_LEN`] or its length plus the kind
    /// byte does not fit in a `u32`. Otherwise propagates I/O errors.
    async fn write_message(&mut self, msg: Message) -> io::Result<()> {
        let payload_len = msg.buffer.len();

        if payload_len > Message::MAX_BUFFER_LEN {
            return Err(Self::invalid_data("message payload too large"));
        }

        // Checked rather than unwrapped: nothing visible here guarantees that
        // `MAX_BUFFER_LEN` leaves room for the kind byte within a `u32`.
        let frame_len = u32::try_from(payload_len)
            .ok()
            .and_then(|len| len.checked_add(1))
            .ok_or_else(|| Self::invalid_data("message payload too large"))?;

        self.stream.write_u32(frame_len).await?;
        self.stream.write_u8(msg.kind).await?;
        self.stream.write_all(msg.buffer.as_ref()).await?;
        self.stream.flush().await?;

        Ok(())
    }

    /// Consumes this stream handler, shutting down the stream and dropping it.
    async fn shutdown(mut self) -> io::Result<()> {
        self.stream.shutdown().await
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// A complete frame decodes into its kind and payload; the single byte
    /// left over afterwards is a truncated frame and must surface as EOF.
    #[tokio::test]
    async fn read_message() {
        let (mut peer, stream) = tokio::io::duplex(100);
        peer.write_all(b"\0\0\0\x02xyz").await.unwrap();
        // Closing the peer end makes the stream report EOF once drained.
        drop(peer);

        let mut handler = StreamHandler::new(stream);

        assert_eq!(
            handler.read_message().await.unwrap(),
            Message::new(b'x', &b"y"[..])
        );
        assert_eq!(
            handler.read_message().await.unwrap_err().kind(),
            ErrorKind::UnexpectedEof
        );
    }

    /// A written message shows up on the wire as length prefix, kind and
    /// payload, followed only by what is written after it.
    #[tokio::test]
    async fn write_message() {
        let (mut peer, stream) = tokio::io::duplex(100);
        let mut handler = StreamHandler::new(stream);

        handler
            .write_message(Message::new(b'x', &b"abc"[..]))
            .await
            .unwrap();
        // A marker byte after the frame shows the frame ended where expected.
        handler.stream.write_u8(1).await.unwrap();
        handler.shutdown().await.unwrap();

        let mut frame = [0u8; 8];
        peer.read_exact(&mut frame).await.unwrap();
        assert_eq!(&frame, b"\0\0\0\x04xabc");
        assert_eq!(peer.read_u8().await.unwrap(), 1);
        assert_eq!(
            peer.read_u8().await.unwrap_err().kind(),
            ErrorKind::UnexpectedEof
        );
    }
}
