use crate::websocket::{WebSocketHandler, WebSocketMessage, WebSocketSession};

use std::collections::HashMap;
use tokio::sync::broadcast::{channel, Sender};

/// An event forwarded to the handler task registered for a session's path.
///
/// Values are sent over each endpoint's `tokio::sync::broadcast` channel by
/// `Endpoints::on_open` / `on_message` / `on_close`. The task spawned in
/// `Endpoints::insert` turns each variant into the matching `WebSocketHandler` call.
#[derive(Debug, Clone)]
pub(crate) enum Dispatch {
    /// A session was opened; delivered to `WebSocketHandler::on_open`.
    Open(WebSocketSession),
    /// A message arrived on a session; delivered to `WebSocketHandler::on_message`.
    Message(WebSocketSession, WebSocketMessage),
    /// A session is closing; delivered to `WebSocketHandler::on_close`.
    /// NOTE(review): the message is presumably the close frame — confirm with the caller.
    Close(WebSocketSession, WebSocketMessage),
}

/// Registry of WebSocket endpoints, keyed by path.
///
/// Each path maps to the sending half of a broadcast channel whose receiver is
/// owned by the handler task spawned in `insert`. Cloning an `Endpoints` clones
/// the senders, so the clones feed the same handler tasks.
#[derive(Debug, Clone)]
pub(crate) struct Endpoints {
    /// Path -> sender feeding that path's handler task.
    channels: HashMap<&'static str, Sender<Dispatch>>,
}

impl Endpoints {
    pub(crate) fn get_paths(&self) -> Vec<&'static str> {
        #[allow(clippy::map_clone)]
        self.channels.keys().clone().map(|k| *k).collect()
    }

    #[tracing::instrument(level = "trace")]
    pub(crate) async fn contains_path(&self, key: &str) -> bool {
        self.channels.contains_key(key)
    }

    pub(crate) fn insert(
        &mut self,
        key: &'static str,
        mut handler: Box<impl WebSocketHandler + 'static>,
    ) {
        let (tx, mut rx) = channel::<Dispatch>(128);
        let key2 = key.to_string();
        self.channels.insert(key, tx);

        let f = async move {
            loop {
                match rx.recv().await {
                    Ok(message) => match message {
                        Dispatch::Open(session) => handler.on_open(&session).await,
                        Dispatch::Message(session, msg) => handler.on_message(&session, msg).await,
                        Dispatch::Close(session, msg) => handler.on_close(&session, msg).await,
                    },
                    Err(e) => tracing::error!("handler endpoint: {}: {:?}", key2, e),
                }
            }
        };

        tokio::spawn(f);
    }

    #[tracing::instrument(level = "trace")]
    pub(crate) async fn on_open(&self, session: &WebSocketSession) {
        self.channels
            .get(session.context().path().as_str())
            .and_then(|tx| match tx.send(Dispatch::Open(session.clone())) {
                Ok(t) => Some(t),
                Err(e) => {
                    tracing::error!("{:?}", e);
                    None
                }
            });
    }

    #[tracing::instrument(level = "trace")]
    pub(crate) async fn on_message(&self, session: &WebSocketSession, msg: WebSocketMessage) {
        self.channels
            .get(session.context().path().as_str())
            .and_then(
                |tx| match tx.send(Dispatch::Message(session.clone(), msg)) {
                    Ok(t) => Some(t),
                    Err(e) => {
                        tracing::error!("{:?}", e);
                        None
                    }
                },
            );
    }

    #[tracing::instrument(level = "trace")]
    pub(crate) async fn on_close(&self, session: &WebSocketSession, msg: WebSocketMessage) {
        self.channels
            .get(session.context().path().as_str())
            .and_then(|tx| match tx.send(Dispatch::Close(session.clone(), msg)) {
                Ok(t) => Some(t),
                Err(e) => {
                    tracing::error!("{:?}", e);
                    None
                }
            });
    }
}

impl Default for Endpoints {
    fn default() -> Self {
        Self {
            channels: HashMap::new(),
        }
    }
}

#[cfg(test)]
mod tests {
    use crate::websocket::{WebSocketHandler, WebSocketMessage, WebSocketSession};

    /// No-op handler; these tests only exercise path registration, not dispatch.
    #[derive(Debug, Default)]
    struct Handler;

    #[crate::async_trait]
    impl WebSocketHandler for Handler {
        async fn on_open(&mut self, _ws: &WebSocketSession) {}
        async fn on_message(&mut self, _ws: &WebSocketSession, _msg: WebSocketMessage) {}
        async fn on_close(&mut self, _ws: &WebSocketSession, _msg: WebSocketMessage) {}
    }

    // `insert` calls `tokio::spawn`, so these tests need a Tokio runtime.
    #[tokio::test]
    async fn endpoint_contains_key() {
        let mut e = crate::Endpoints::default();
        let key = "ws";

        assert!(!e.contains_path(key).await);

        let h = Handler::default();
        e.insert(key, Box::new(h));

        assert!(e.contains_path(key).await);
        assert!(!e.contains_path("other").await);
    }

    #[tokio::test]
    async fn endpoint_get_paths() {
        let mut e = crate::Endpoints::default();
        assert!(e.get_paths().is_empty());

        e.insert("ws", Box::new(Handler::default()));
        assert_eq!(e.get_paths(), vec!["ws"]);

        // Re-registering a path replaces its handler rather than adding a duplicate.
        e.insert("ws", Box::new(Handler::default()));
        assert_eq!(e.get_paths(), vec!["ws"]);
    }
}
