use std::pin::Pin;

use connection_utils::Disconnected;
use cs_trace::{Tracer, create_trace};
use tokio::io::{AsyncRead, AsyncWrite};

mod rpc;

pub mod disconnected;
pub mod connected;

/// Connection state wrapping a single async duplex stream.
///
/// Instances are created through [`MultiplexedConnection::new`], which hands
/// the value back as a boxed `Disconnected` trait object.
pub struct MultiplexedConnection<TAsyncDuplex>
where
    TAsyncDuplex: AsyncRead + AsyncWrite + Send + 'static,
{
    /// Tracer used for diagnostics of this connection.
    trace: Box<dyn Tracer>,
    /// Underlying duplex stream, heap-allocated and pinned.
    // NOTE(review): presumably `Option` so the stream can be taken out when the
    // connection changes state — confirm against the `disconnected` module.
    stream: Option<Pin<Box<TAsyncDuplex>>>,
}

impl<TAsyncDuplex> MultiplexedConnection<TAsyncDuplex>
where
    TAsyncDuplex: AsyncRead + AsyncWrite + Send + 'static,
{
    /// Wraps `stream` into a new connection and returns it as a boxed
    /// `Disconnected` trait object.
    ///
    /// The stream is moved onto the heap and pinned; a tracer labelled
    /// `"rpc-connection"` is created for the connection.
    pub fn new(stream: TAsyncDuplex) -> Box<dyn Disconnected> {
        // Struct fields are evaluated in source order, so the tracer is
        // created before the stream is pinned.
        Box::new(Self {
            trace: create_trace!("rpc-connection"),
            stream: Some(Box::pin(stream)),
        })
    }
}

#[cfg(test)]
mod tests {
    pub use cs_trace::{create_trace_listener, TraceListenerOptions, Trace, SubscriberInitExt, create_trace, child};
    use cs_utils::futures::wait;
    use connection_utils::test::test_async_stream;
    use cs_utils::random_str;
    use rstest::rstest;
    use tokio::io::duplex;
    use tokio::try_join;
    use cs_utils::random_str_rg;

    use super::MultiplexedConnection;

    /// Number of random fragments concatenated into one test payload.
    const PAYLOAD_FRAGMENTS: usize = 7;

    /// Creates a channel label with a random 4-character suffix so that every
    /// test case uses its own label.
    fn random_channel_label() -> String {
        format!("channel-label-{}", random_str(4))
    }

    /// Builds the test payload: `PAYLOAD_FRAGMENTS` random strings, each
    /// generated from the `str_min_size..=str_max_size` range, concatenated
    /// without a separator.
    fn random_payload(str_min_size: usize, str_max_size: usize) -> String {
        (0..PAYLOAD_FRAGMENTS)
            .map(|_| random_str_rg(str_min_size..=str_max_size))
            .collect()
    }

    /// Opens a channel from the connecting side, receives it on the listening
    /// side, and runs `test_async_stream` over the resulting channel pair.
    #[rstest]
    #[case::size_8_32(8, 32)]
    #[case::size_128_512(128, 512)]
    #[case::size_2048_4096(2048, 4096)]
    #[case::size_4096_8192(4096, 8192)]
    #[case::size_8192_16384(8192, 16384)]
    #[tokio::test]
    async fn sends_data_from_local_channel(
        #[case] str_min_size: usize,
        #[case] str_max_size: usize,
    ) {
        // In-memory pipe acting as the transport between the two connections.
        let (duplex1, duplex2) = duplex(4096);

        let channel_label = random_channel_label();
        let channel_label1 = channel_label.clone();
        let channel_label2 = channel_label;

        let (channel1, channel2) = try_join!(
            // Connecting side: connects, then opens the labelled channel.
            tokio::spawn(async move {
                let mut connection1 = MultiplexedConnection::new(duplex1)
                    .connect().await
                    .expect("Error while connecting.");

                // Short pause before opening the channel.
                wait(50).await;

                let channel = connection1
                    .channel(channel_label1.clone()).await
                    .unwrap();

                assert_eq!(
                    channel.label(),
                    &channel_label1,
                    "Channel labels must match.",
                );

                channel
            }),
            // Listening side: listens, then polls until the remote channel arrives.
            tokio::spawn(async move {
                let mut connection2 = MultiplexedConnection::new(duplex2)
                    .listen().await
                    .expect("Error while listening.");

                let mut on_remote_channel = connection2.on_remote_channel().unwrap();

                // Poll every 50 units of `wait` until a channel is received.
                // Any `try_recv` error is treated as "not yet" and retried.
                let channel = loop {
                    if let Ok(channel) = on_remote_channel.try_recv() {
                        break channel;
                    }

                    wait(50).await;
                };

                assert_eq!(
                    channel.label(),
                    &channel_label2,
                    "Channel labels must match.",
                );

                channel
            }),
        ).unwrap();

        test_async_stream(
            channel1,
            channel2,
            random_payload(str_min_size, str_max_size),
        ).await;
    }

    /// Same scenario as `sends_data_from_local_channel`.
    // NOTE(review): this body passes the channels to `test_async_stream` in the
    // same order as the "local" test, so both tests exercise the same path.
    // Given the name, the intent may have been to swap `channel1`/`channel2` —
    // confirm against `test_async_stream` before changing.
    #[rstest]
    #[case::size_8_32(8, 32)]
    #[case::size_128_512(128, 512)]
    #[case::size_2048_4096(2048, 4096)]
    #[case::size_4096_8192(4096, 8192)]
    #[case::size_8192_16384(8192, 16384)]
    #[tokio::test]
    async fn sends_data_from_remote_channel(
        #[case] str_min_size: usize,
        #[case] str_max_size: usize,
    ) {
        // In-memory pipe acting as the transport between the two connections.
        let (duplex1, duplex2) = duplex(4096);

        let channel_label = random_channel_label();
        let channel_label1 = channel_label.clone();
        let channel_label2 = channel_label;

        let (channel1, channel2) = try_join!(
            // Connecting side: connects, then opens the labelled channel.
            tokio::spawn(async move {
                let mut connection1 = MultiplexedConnection::new(duplex1)
                    .connect().await
                    .expect("Error while connecting.");

                // Short pause before opening the channel.
                wait(50).await;

                let channel = connection1
                    .channel(channel_label1.clone()).await
                    .unwrap();

                assert_eq!(
                    channel.label(),
                    &channel_label1,
                    "Channel labels must match.",
                );

                channel
            }),
            // Listening side: listens, then polls until the remote channel arrives.
            tokio::spawn(async move {
                let mut connection2 = MultiplexedConnection::new(duplex2)
                    .listen().await
                    .expect("Error while listening.");

                let mut on_remote_channel = connection2.on_remote_channel().unwrap();

                // Poll every 50 units of `wait` until a channel is received.
                // Any `try_recv` error is treated as "not yet" and retried.
                let channel = loop {
                    if let Ok(channel) = on_remote_channel.try_recv() {
                        break channel;
                    }

                    wait(50).await;
                };

                assert_eq!(
                    channel.label(),
                    &channel_label2,
                    "Channel labels must match.",
                );

                channel
            }),
        ).unwrap();

        test_async_stream(
            channel1,
            channel2,
            random_payload(str_min_size, str_max_size),
        ).await;
    }
}
