use carboncopy::{BoxFuture, Sink};
use std::fmt;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::{stdout, AsyncWriteExt, Stdout};
use tokio::runtime::Runtime;
use tokio::sync::mpsc::{unbounded_channel, UnboundedSender as DropTx};
use tokio::sync::watch::{channel as watch_channel, Receiver as WatchRx};
use tokio::sync::Mutex;
use tokio::time::sleep;

/// A buffered [`Sink`] whose periodic flusher runs on a Tokio runtime.
///
/// Entries are collected in memory and written to the output writer by a background task,
/// which makes the sink a natural fit for binaries already built on the Tokio executor.
/// Callers that are not async can still use it through the blocking half of the [`Sink`]
/// trait.
///
/// All mutable state lives behind an `Arc<Mutex<...>>`.
pub struct BufSink<T>
where
    T: AsyncWriteExt + Unpin + Send + 'static,
{
    /// Runtime that hosts the flusher task and drives `sink_blocking`.
    rt: Arc<Runtime>,
    /// Buffer, output writer and backlog flag, shared with the flusher task.
    interior: Arc<Mutex<Interior<T>>>,
    /// Used by `Drop` to tell the flusher task that this sink is gone.
    drop_chan_tx: DropTx<EmptySignal>,
    /// Latest flush outcome published by the flusher; `None` after a successful flush.
    last_flush_err_chan_rx: WatchRx<Option<Arc<std::io::Error>>>,
}

impl<T: AsyncWriteExt + Unpin + Send + 'static> Sink for BufSink<T> {
    /// Blocking variant of [`Sink::sink`]: drives it to completion on the sink's runtime.
    ///
    /// # Panics
    /// Panics if called from within an asynchronous execution context, because
    /// `tokio::runtime::Runtime::block_on` does.
    fn sink_blocking(&self, entry: String) -> std::io::Result<()> {
        self.rt.block_on(self.sink(entry))
    }

    /// Appends `entry` to the in-memory buffer or, for a bufferless sink, writes all of it
    /// straight to the output writer.
    ///
    /// # Errors
    /// Buffering never fails. A bufferless sink returns the first error reported by the
    /// output writer.
    fn sink(&self, entry: String) -> BoxFuture<std::io::Result<()>> {
        Box::pin(async move {
            let mut guard = self.interior.lock().await;
            let interior = &mut *guard;
            match interior.buf.as_mut() {
                // Growing a Vec cannot produce an I/O error, so buffering always succeeds.
                Some((buf, _)) => {
                    buf.extend_from_slice(entry.as_bytes());
                    Ok(())
                }
                // `write` may accept only part of the entry; `write_all` keeps going until the
                // whole entry is out or the writer reports an error.
                None => interior.output_writer.write_all(entry.as_bytes()).await,
            }
        })
    }
}

impl<T> Drop for BufSink<T>
where
    T: AsyncWriteExt + Unpin + Send + 'static,
{
    /// Notifies the background flusher that the sink has been dropped.
    fn drop(&mut self) {
        // Sending only fails once the flusher has already exited; then nobody needs telling.
        self.drop_chan_tx.send(EmptySignal).ok();
    }
}

impl<T: AsyncWriteExt + Unpin + Send + 'static> BufSink<T> {
    /// At the same time as the instantiation, a flusher task is spawned in the background
    /// whose job is to flush after the buffer overflows or a set timeout has elapsed
    /// (whichever happens first).
    ///
    /// The flusher task will terminate when the instance is dropped.
    pub fn new(opts: SinkOptions<T>) -> Self {
        let interior = Arc::new(Mutex::new(Interior {
            backlogged: false,
            buf: if opts.buffer.is_none() {
                None
            } else {
                let cap = opts.buffer.as_ref().unwrap();
                Some((Vec::with_capacity(cap.0), cap.0))
            },
            output_writer: opts.output_writer,
        }));

        let (drop_tx, mut drop_rx) = unbounded_channel();
        let (err_tx, err_rx) = watch_channel(None);

        let rt = opts.tokio_runtime.clone();
        let interior_clone = interior.clone();
        let timeout_ms = opts.flush_timeout_ms;
        rt.spawn(async move {
            if interior_clone.lock().await.buf.is_some() {
                loop {
                    let overflow = async {
                        loop {
                            {
                                let interior_check = interior_clone.lock().await;
                                if interior_check.buf.as_ref().unwrap().0.len()
                                    >= interior_check.buf.as_ref().unwrap().1
                                {
                                    return;
                                }
                            }
                            if timeout_ms > 1 {
                                sleep(Duration::from_millis(1)).await;
                            }
                        }
                    };
                    let timeout = async move {
                        sleep(Duration::from_millis(timeout_ms)).await;
                    };
                    tokio::select! {
                        _ = overflow => {
                            if let Err(io_err) = interior_clone.lock().await.flush().await {
                                // error will be returned by send() if the sink has been dropped,
                                // at which point, the error no longer matters
                                let _ = err_tx.send(Some(Arc::new(io_err)));
                            } else {
                                let _ = err_tx.send(None);
                            };
                        }
                        _ = timeout => {
                            if let Err(io_err) = interior_clone.lock().await.flush().await {
                                let _ = err_tx.send(Some(Arc::new(io_err)));
                            } else {
                                let _ = err_tx.send(None);
                            };
                        }
                        _ = drop_rx.recv() => {
                            return; // sink instance dropped, terminate loop/task
                        }
                    }
                }
            } else {
                return; // no need for buffer checks if there is no buffer
            }
        });

        Self {
            rt: rt,
            interior: interior,
            drop_chan_tx: drop_tx,
            last_flush_err_chan_rx: err_rx,
        }
    }

    /// Attempts to manually flush the underlying buffer to Stdout.
    pub async fn flush(&self) -> std::io::Result<usize> {
        self.interior.lock().await.flush().await
    }

    /// Checks if buffer flushing is being backlogged (not necessarily by errors).
    pub async fn backlogged(&self) -> bool {
        self.interior.lock().await.backlogged()
    }

    /// Checks if the flusher has just encountered an error. Only use this function to check
    /// for long running errors. A temporary error could already be cleared by retries by the
    /// time you call this function.
    ///
    /// A bufferless sink will always return None.
    pub fn last_flush_err(&self) -> Option<Arc<std::io::Error>> {
        self.last_flush_err_chan_rx.borrow().clone()
    }
}

/// Mutable state of a `BufSink`, shared behind an `Arc<Mutex<...>>`.
struct Interior<T>
where
    T: AsyncWriteExt + Unpin + Send + 'static,
{
    /// Maintained by `flush`; `true` signals that a flush hit a write error or a partial write.
    backlogged: bool,
    /// In-memory buffer paired with its overflow threshold in bytes; `None` when unbuffered.
    buf: Option<(Vec<u8>, usize)>,
    /// Destination of flushed bytes (and of every entry when the sink is unbuffered).
    output_writer: T,
}

impl<T: AsyncWriteExt + Unpin + Send + 'static> Interior<T> {
    /// Writes as much of the buffer to `output_writer` as the writer accepts.
    ///
    /// Returns the number of bytes written (`Ok(0)` for a bufferless sink or an empty buffer).
    /// Bytes that could not be written stay at the front of the buffer and `backlogged` is set.
    /// `backlogged` is cleared once the buffer has been written out completely.
    ///
    /// # Errors
    /// Returns the first error reported by the writer. Returns `ErrorKind::WriteZero` if the
    /// writer accepts no bytes of a non-empty slice.
    async fn flush(&mut self) -> Result<usize, std::io::Error> {
        let (buf, threshold) = match self.buf.as_mut() {
            Some((buf, threshold)) => (buf, *threshold),
            None => return Ok(0),
        };

        let mut written = 0;
        let mut outcome: std::io::Result<()> = Ok(());
        // Track progress against the *current* remaining length; the writer may take any
        // prefix of what it is offered on each call.
        while written < buf.len() {
            match self.output_writer.write(&buf[written..]).await {
                Ok(0) => {
                    outcome = Err(std::io::Error::new(
                        std::io::ErrorKind::WriteZero,
                        "output writer accepted no bytes",
                    ));
                    break;
                }
                Ok(n) => written += n,
                Err(err) => {
                    outcome = Err(err);
                    break;
                }
            }
        }

        // Discard exactly what reached the writer, in one move instead of one per write.
        buf.drain(..written);
        self.backlogged = !buf.is_empty();
        if !self.backlogged && buf.capacity() > threshold {
            // Give back memory the buffer grew beyond its threshold between flushes.
            *buf = Vec::with_capacity(threshold);
        }
        outcome.map(|()| written)
    }

    /// Whether the last flush left unwritten bytes behind.
    fn backlogged(&self) -> bool {
        self.backlogged
    }
}

/// Construction options for `BufSink::new`.
///
/// `Default` is implemented for `SinkOptions<Stdout>`, i.e. for sinks writing to Tokio's stdout.
pub struct SinkOptions<T>
where
    T: AsyncWriteExt + Unpin + Send + 'static,
{
    /// Overflow threshold of the in-memory buffer; `None` sends every entry straight to
    /// `output_writer`.
    pub buffer: Option<BufferOverflowThreshold>,
    /// How long, in milliseconds, the flusher waits for an overflow before flushing anyway.
    /// The wait restarts after every flush. Values of 1 or less also make the overflow check
    /// poll without sleeping.
    pub flush_timeout_ms: u64,
    /// Runtime that hosts the flusher task and drives `sink_blocking`.
    pub tokio_runtime: Arc<Runtime>,
    /// Destination for the sunk entries.
    pub output_writer: T,
}

impl Default for SinkOptions<Stdout> {
    /// Buffered stdout sink: 64 KiB overflow threshold, 100 ms flush timeout, and a Tokio
    /// runtime created by this call (each call creates its own).
    ///
    /// # Panics
    /// Panics if the Tokio runtime cannot be created.
    fn default() -> Self {
        Self {
            buffer: Some(
                BufferOverflowThreshold::new(64 * 1024)
                    .expect("64 KiB lies within the permitted 1 KiB..=1 GiB threshold range"),
            ),
            flush_timeout_ms: 100,
            tokio_runtime: Arc::new(
                Runtime::new().expect("failed to create the default Tokio runtime"),
            ),
            output_writer: stdout(),
        }
    }
}

/// A size threshold, in bytes, after which the buffer will be flushed. The size of the buffer
/// itself is unlimited.
#[derive(Debug, PartialEq, Eq, Copy, Clone, Ord, PartialOrd)]
pub struct BufferOverflowThreshold(usize);

impl BufferOverflowThreshold {
    /// Smallest accepted threshold: 1 KiB.
    const MIN: usize = 1024;
    /// Largest accepted threshold: 1 GiB.
    const MAX: usize = 1024 * 1024 * 1024;

    /// Validates `cap` (in bytes) as an overflow threshold.
    ///
    /// `cap` must lie between 1KB (1024) and 1GB (1024 * 1024 * 1024), both bounds inclusive.
    ///
    /// # Errors
    /// Returns [`ThresholdError::LessThan1KB`] when `cap < 1024` and
    /// [`ThresholdError::MoreThan1GB`] when `cap > 1024 * 1024 * 1024`.
    pub fn new(cap: usize) -> Result<Self, ThresholdError> {
        if cap < Self::MIN {
            Err(ThresholdError::LessThan1KB)
        } else if cap > Self::MAX {
            Err(ThresholdError::MoreThan1GB)
        } else {
            Ok(Self(cap))
        }
    }
}

/// Why a value was rejected by [`BufferOverflowThreshold::new`].
#[derive(Debug, PartialEq, Eq, Copy, Clone, Ord, PartialOrd)]
pub enum ThresholdError {
    /// The requested threshold is below 1024 bytes.
    LessThan1KB,
    /// The requested threshold is above 1024 * 1024 * 1024 bytes.
    MoreThan1GB,
}

impl fmt::Display for ThresholdError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Self::LessThan1KB => {
                f.write_str("buffer overflow threshold can't be less than 1024 bytes (1KB)")
            }
            Self::MoreThan1GB => f.write_str(
                "buffer overflow threshold can't be greater than 1024 * 1024 * 1024 bytes (1GB)",
            ),
        }
    }
}

/// Lets callers propagate the error with `?` into `Box<dyn Error>`-style error types.
impl std::error::Error for ThresholdError {}

/// Unit message sent by `BufSink`'s `Drop` impl over the drop channel, telling the background
/// flusher task that the sink has been dropped.
struct EmptySignal;

#[cfg(test)]
mod tests {
    use super::*;

    /// Output expected after sinking `hello world 0\n` through `hello world 4\n`.
    const HELLO_WORLDS: &str =
        "hello world 0\nhello world 1\nhello world 2\nhello world 3\nhello world 4\n";

    /// Builds a sink on `rt` whose output writer is an in-memory `Vec<u8>`.
    fn vec_sink(
        rt: &Arc<Runtime>,
        buffer: Option<BufferOverflowThreshold>,
        flush_timeout_ms: u64,
    ) -> BufSink<Vec<u8>> {
        BufSink::new(SinkOptions {
            buffer,
            flush_timeout_ms,
            tokio_runtime: Arc::clone(rt),
            output_writer: Vec::new(),
        })
    }

    /// Sinks `hello world 0\n` through `hello world 4\n`, asserting that every call succeeds.
    fn sink_hello_worlds(rt: &Runtime, sink: &BufSink<Vec<u8>>) {
        for i in 0..5 {
            assert!(rt.block_on(sink.sink(format!("hello world {}\n", i))).is_ok());
        }
    }

    /// Snapshot, as UTF-8, of everything that has reached the sink's output writer so far.
    fn written(rt: &Runtime, sink: &BufSink<Vec<u8>>) -> String {
        let bytes = rt.block_on(async { sink.interior.lock().await.output_writer.clone() });
        String::from_utf8(bytes).unwrap()
    }

    /// Lets the runtime run for `ms` milliseconds so the background flusher can act.
    /// (`sleep` must be created inside the runtime context, hence the async block.)
    fn wait_ms(rt: &Runtime, ms: u64) {
        rt.block_on(async { sleep(Duration::from_millis(ms)).await });
    }

    #[test]
    fn overflow_threshold() {
        assert_eq!(
            BufferOverflowThreshold::new(1000),
            Err(ThresholdError::LessThan1KB)
        );
        assert_eq!(
            BufferOverflowThreshold::new(1024 * 1024 * 1024 + 1),
            Err(ThresholdError::MoreThan1GB)
        );
    }

    #[test]
    fn default_options_dont_panic() {
        // Not panicking is the real check; comparing a field just makes use of the value.
        assert_eq!(100, SinkOptions::default().flush_timeout_ms);
    }

    #[test]
    fn no_buffer() {
        let rt = Arc::new(Runtime::new().unwrap());
        let sink = vec_sink(&rt, None, 30);

        sink_hello_worlds(&rt, &sink);

        // Without a buffer every entry reaches the writer immediately.
        assert_eq!(HELLO_WORLDS, written(&rt, &sink));
    }

    #[test]
    fn timeout_flush() {
        let rt = Arc::new(Runtime::new().unwrap());
        let sink = vec_sink(&rt, Some(BufferOverflowThreshold::new(64 * 1024).unwrap()), 30);

        sink_hello_worlds(&rt, &sink);
        assert_ne!(HELLO_WORLDS, written(&rt, &sink));

        // Outlast the 30 ms flush timeout.
        wait_ms(&rt, 40);

        assert_eq!(HELLO_WORLDS, written(&rt, &sink));
    }

    #[test]
    fn overflow_flush() {
        let rt = Arc::new(Runtime::new().unwrap());
        let sink = vec_sink(&rt, Some(BufferOverflowThreshold::new(1024).unwrap()), 30);

        for _ in 0..1024 {
            assert!(rt.block_on(sink.sink(String::from("X"))).is_ok());
        }
        let mut expected = "X".repeat(1024);
        assert_ne!(expected, written(&rt, &sink));

        // One more byte pushes the buffer past the threshold.
        assert!(rt.block_on(sink.sink(String::from("X"))).is_ok());
        expected.push('X');
        // Overflow is polled every 1 ms; allow some margin.
        wait_ms(&rt, 1 + 9);

        assert_eq!(expected, written(&rt, &sink));
    }

    #[test]
    fn flush_err() {
        use core::task::{Context, Poll};
        use std::io::{Error, ErrorKind};
        use std::pin::Pin;
        use tokio::io::AsyncWrite;

        /// Writer whose every operation fails.
        struct ProblematicWriter;
        impl AsyncWrite for ProblematicWriter {
            fn poll_write(
                self: Pin<&mut Self>,
                _: &mut Context<'_>,
                _: &[u8],
            ) -> Poll<Result<usize, Error>> {
                Poll::Ready(Err(Error::new(ErrorKind::Other, "kaboom!")))
            }
            fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Error>> {
                Poll::Ready(Err(Error::new(ErrorKind::Other, "kaboom!")))
            }
            fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Error>> {
                Poll::Ready(Err(Error::new(ErrorKind::Other, "kaboom!")))
            }
        }

        let rt = Arc::new(Runtime::new().unwrap());
        let sink = BufSink::new(SinkOptions {
            buffer: Some(BufferOverflowThreshold::new(1024).unwrap()),
            flush_timeout_ms: 20,
            tokio_runtime: Arc::clone(&rt),
            output_writer: ProblematicWriter,
        });

        assert!(rt.block_on(sink.sink(String::from("hello world\n"))).is_ok());
        assert!(sink.last_flush_err().is_none());

        // Outlast the 20 ms flush timeout.
        wait_ms(&rt, 20 + 5);

        let err = sink
            .last_flush_err()
            .expect("the flusher should have reported the write error");
        assert_eq!(ErrorKind::Other, err.kind());
        assert_eq!("kaboom!", err.to_string());
    }
}
