use crate::{Error, ResponseContent};
use async_trait::async_trait;
use futures_util::Stream;
use sse_agent::SseBody;

use std::{pin::Pin, task::Poll};

impl From<sse_agent::Error<hyper::Error>> for Error {
    fn from(err: sse_agent::Error<hyper::Error>) -> Self {
        match err.kind() {
            sse_agent::ErrorKind::<hyper::Error>::Inner(hyper_err) => Self::Net(hyper_err),
            sse_agent::ErrorKind::<hyper::Error>::Sse(parse_err) => {
                Self::InvalidBody(parse_err.to_string())
            }
        }
    }
}

/// Marker type selecting Server-Sent-Events handling for a response whose
/// event `data` payloads are JSON-encoded `T` values (see the
/// `ResponseContent` impl for `Sse<T>`).
///
/// `Clone`, `Copy` and `Debug` are implemented by hand instead of derived:
/// `#[derive]` would add `T: Clone` / `T: Copy` / `T: Debug` bounds, even though
/// this type only holds a `PhantomData<T>` and never stores a `T`.
pub struct Sse<T>(std::marker::PhantomData<T>);

impl<T> Clone for Sse<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for Sse<T> {}

impl<T> std::fmt::Debug for Sse<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Sse")
    }
}

#[async_trait]
impl<T> ResponseContent for Sse<T>
where
    T: serde::de::DeserializeOwned + Unpin + Send + 'static,
{
    type Data = JsonStream<T>;

    async fn convert_response(
        response: hyper::Response<hyper::Body>,
    ) -> Result<http::Response<Self::Data>, Error> {
        let (parts, body) = response.into_parts();

        if !parts.status.is_success() {
            return Err(Error::non_2xx(parts.status, &[]));
        }

        let body = JsonStream {
            inner: body.into_sse(),
            _marker: std::marker::PhantomData,
            last_event_id: None,
        };

        Ok(http::Response::from_parts(parts, body))
    }
}

/// A single Server-Sent Event whose `data` payload has been decoded as JSON.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Event<T> {
    /// The event type exactly as reported by the SSE parser.
    /// NOTE(review): whether a missing `event:` field shows up here as `""` or as
    /// `"message"` is decided by `sse_agent` — confirm.
    pub event: String,
    /// The event's `data` field, deserialized from JSON.
    pub data: T,
}

/// Stream of Server-Sent Events whose `data` payloads are deserialized as
/// JSON `T`.
pub struct JsonStream<T> {
    /// SSE parser wrapping the hyper response body.
    inner: sse_agent::Body<hyper::Body>,
    /// Ties the stream to its item type without storing a `T`.
    _marker: std::marker::PhantomData<T>,
    /// `last_event_id` reported with the most recently parsed event.
    /// Overwritten (possibly with `None`) every time `inner` yields an event.
    last_event_id: Option<String>,
}

impl<T> JsonStream<T> {
    /// Returns the last event ID that the SSE parser reported for the most
    /// recently received event, or `None` if no event has arrived yet or the
    /// parser reported none.
    ///
    /// Callers can send this back as the `Last-Event-ID` request header when
    /// reconnecting, so the server can resume the stream.
    #[must_use]
    pub fn last_event_id(&self) -> Option<&str> {
        self.last_event_id.as_deref()
    }
}

impl<T> Stream for JsonStream<T>
where
    T: serde::de::DeserializeOwned + Unpin,
{
    type Item = Result<Event<T>, crate::Error>;

    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        ctx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        match Pin::new(&mut self.inner).poll_next(ctx) {
            Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(Error::from(err)))),
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Pending => Poll::Pending,
            Poll::Ready(Some(Ok(sse_agent::Event {
                event,
                data,
                last_event_id,
            }))) => {
                self.last_event_id = last_event_id;

                let res = serde_json::from_str::<T>(&data)
                    .map(|data| Event { event, data })
                    .map_err(|err| Error::deserialization(err, data.as_bytes()));

                Poll::Ready(Some(res))
            }
        }
    }
}
