// States stored in `ExceptionContext::status`. The status only moves forward:
// NO_EXCEPTION -> THROWING -> THROWN -> MOVED.

/// No exception has been thrown; the exception slot is uninitialized.
const NO_EXCEPTION: u8 = 0;
/// `throw` has claimed the context and is writing the exception; the slot must not be read yet.
const THROWING: u8 = 1;
/// The exception has been written and not taken; the context owns it and drops it on drop.
const THROWN: u8 = 2;
/// The exception has been moved out by `try_take_exception`; the context no longer owns it.
const MOVED: u8 = 3;

/// An exception context that can be shared between threads.
///
/// It stores at most one exception of type `E`, thrown with [`ExceptionContext::throw`] and
/// delivered through the future returned by [`ExceptionContext::catching`].
pub struct ExceptionContext<E> {
    // One of `NO_EXCEPTION`, `THROWING`, `THROWN`, `MOVED`; guards every access to `exception`.
    status: core::sync::atomic::AtomicU8,
    // Holds an initialized, owned value only while `status` is `THROWN`.
    exception: core::cell::UnsafeCell<core::mem::MaybeUninit<E>>,
}

// SAFETY: the context owns its exception value, so sending the context sends the `E`,
// which requires `E: Send`.
unsafe impl<E: Send> Send for ExceptionContext<E> {}
// SAFETY: through `&ExceptionContext` an `E` can only be moved in (`throw`) or moved out
// (`try_take_exception`); no `&E` is ever handed out. Sharing the context therefore only
// transfers ownership of an `E` between threads, which needs `E: Send` but not `E: Sync`
// (the same reasoning that makes `std::sync::Mutex<T>: Sync` for `T: Send`). Access to the
// cell itself is serialized by the atomic `status` state machine.
// If a method that exposes `&E` is ever added, this bound must become `E: Send + Sync`.
unsafe impl<E: Send> Sync for ExceptionContext<E> {}

impl<E> Drop for ExceptionContext<E> {
    /// Drops the stored exception if it was thrown but never taken out of the context.
    fn drop(&mut self) {
        // `&mut self` rules out concurrent access, so a plain (non-atomic) read is enough.
        let still_owned = *self.status.get_mut() == THROWN;
        if !still_owned {
            // No exception is stored, or it has already been moved out.
            return;
        }
        // SAFETY: in the `THROWN` state the slot holds an initialized value that has not been
        // moved out, and nothing reads the slot again after the context is dropped.
        unsafe { core::ptr::drop_in_place(self.exception.get_mut().as_mut_ptr()) }
    }
}

impl<E> Default for ExceptionContext<E> {
    fn default() -> Self {
        Self {
            status: 0.into(),
            exception: core::cell::UnsafeCell::new(core::mem::MaybeUninit::uninit()),
        }
    }
}

impl<E> ExceptionContext<E> {
    /// Create a new exception context.
    ///
    /// Equivalent to [`Default::default`]: no exception has been thrown yet.
    pub fn new() -> Self {
        core::default::Default::default()
    }

    /// Throws an exception. You should always `await` the result.
    ///
    /// The exception is stored in this context, and the [`Catching`] future returned by
    /// [`ExceptionContext::catching`] picks it up after its next poll of the wrapped future.
    /// The returned [`AlwaysPending`] future never completes, so awaiting it suspends the
    /// throwing code at this point.
    ///
    /// # Panics
    ///
    /// Panics if an exception has already been thrown on this context.
    ///
    /// Example:
    ///
    /// ```rust
    /// tokio_test::block_on(async {
    ///     let r = tokio::spawn(async {
    ///         asex::sync::ExceptionContext::<String>::new()
    ///             .catching(|ctx| async move {
    ///                 ctx.throw("failed".to_string()).await;
    ///                 panic!("Won't execute.")
    ///             }).await
    ///      }).await.unwrap();
    ///
    ///     assert_eq!(Err("failed".to_string()), r)
    /// })
    /// ```
    pub fn throw(&self, exception: E) -> AlwaysPending {
        // Strong `compare_exchange`: the weak variant may fail spuriously even when the status
        // really is `NO_EXCEPTION`, which would make a legitimate first `throw` panic.
        if self
            .status
            .compare_exchange(
                NO_EXCEPTION,
                THROWING,
                core::sync::atomic::Ordering::Acquire,
                core::sync::atomic::Ordering::Acquire,
            )
            .is_err()
        {
            panic!("`throw` calls more than once")
        }
        // SAFETY: we compare-exchange from NO_EXCEPTION to THROWING,
        // and the status won't be `NO_EXCEPTION` again.
        // So the compare-exchange will only succeed once, so there is no concurrent write.
        // Also, all reads on `exception` are performed only after status being written `THROWN`.
        // This happens after the exception is written,
        // so there is no concurrent read.
        unsafe { (&mut *self.exception.get()).write(exception) };
        // `Release` pairs with the `Acquire` compare-exchange in `try_take_exception`. Without
        // it, the write above would not happen-before the read on the taking thread, which
        // would be a data race.
        self.status
            .store(THROWN, core::sync::atomic::Ordering::Release);
        AlwaysPending
    }

    /// Executes the function `f` providing the context, then returns a Future that
    /// catches the thrown exception.
    ///
    /// `f` is called exactly once, immediately, with `self`. The returned [`Catching`] future
    /// polls the future produced by `f`. It resolves to `Ok(output)` when that future completes
    /// without a thrown exception, or to `Err(exception)` once one has been thrown on this
    /// context.
    ///
    /// Example:
    ///
    /// ```rust
    /// tokio_test::block_on(async {
    ///     let r = tokio::spawn(async {
    ///         asex::sync::ExceptionContext::<String>::new()
    ///             .catching(|_| async {
    ///                 "success".to_string()
    ///             }).await
    ///     }).await.unwrap();
    ///     assert_eq!(Ok("success".to_string()), r);
    ///
    ///     let r = tokio::spawn(async {
    ///         asex::sync::ExceptionContext::<String>::new()
    ///             .catching(|ctx| async move {
    ///                 ctx.throw("failed".to_string()).await;
    ///                 panic!("Won't execute.")
    ///             }).await
    ///      }).await.unwrap();
    ///
    ///     assert_eq!(Err("failed".to_string()), r)
    /// })
    /// ```
    // `FnOnce` (rather than `Fn + 'a`) because `f` is invoked exactly once, right here; every
    // closure accepted before is still accepted.
    pub fn catching<'a, Fu: core::future::Future, F: FnOnce(&'a Self) -> Fu>(
        &'a self,
        f: F,
    ) -> Catching<'a, E, Fu> {
        Catching {
            ctx: self,
            future: f(self),
        }
    }

    /// Moves the thrown exception out of the context, if there is one.
    ///
    /// Returns `Some` at most once, after a `throw` has finished storing its exception.
    /// Returns `None` if nothing has been thrown, the throw is still writing, or the exception
    /// was already taken.
    fn try_take_exception(&self) -> Option<E> {
        // Strong `compare_exchange`: a spurious failure of the weak variant would report "no
        // exception". The `Catching` future could then return `Pending` with no wake-up
        // scheduled, because `AlwaysPending` ignores its waker.
        match self.status.compare_exchange(
            THROWN,
            MOVED,
            core::sync::atomic::Ordering::Acquire,
            core::sync::atomic::Ordering::Acquire,
        ) {
            // SAFETY: status is changed from THROWN to MOVED,
            // but writes on exception only happens after status changed
            // from NO_EXCEPTION to THROWING, so there is no concurrent write.
            //
            // Because the status was THROWN (published with `Release` after the write, and
            // observed here with `Acquire`), the `exception` has an initialized value.
            // We can move it out because after status becomes MOVED, the value won't be
            // dropped by the context itself.
            Ok(_) => Some(unsafe { self.exception.get().read().assume_init() }),
            Err(_) => None,
        }
    }
}

/// Future returned by [`ExceptionContext::throw`]; it never resolves.
///
/// Its output type is [`core::convert::Infallible`], which has no values, so awaiting it
/// never yields control back to the code that follows the `.await`. Polling it always
/// returns `Pending` and does not register the task's waker.
#[must_use]
pub struct AlwaysPending;

impl core::future::Future for AlwaysPending {
    type Output = core::convert::Infallible;

    /// Always reports `Pending`; the context is intentionally unused.
    fn poll(
        self: core::pin::Pin<&mut Self>,
        _cx: &mut core::task::Context<'_>,
    ) -> core::task::Poll<Self::Output> {
        core::task::Poll::Pending
    }
}

pin_project_lite::pin_project! {
    /// Future returned by [`ExceptionContext::catching`].
    ///
    /// Resolves to `Ok` with the wrapped future's output, or to `Err` with the exception
    /// thrown on the context.
    pub struct Catching<'a, E, F> {
        // Context checked for a thrown exception after each poll of `future`.
        ctx: &'a ExceptionContext<E>,
        // The future produced by the closure passed to `catching`; structurally pinned.
        #[pin]
        future: F,
    }
}

impl<'a, E, F: core::future::Future> core::future::Future for Catching<'a, E, F> {
    type Output = Result<F::Output, E>;

    /// Polls the wrapped future, then checks the context for a thrown exception.
    ///
    /// A thrown exception wins: it is returned as `Ready(Err(..))` even if the wrapped future
    /// produced a value on this same poll. Otherwise the wrapped future's poll result is
    /// forwarded, with any ready value wrapped in `Ok`.
    fn poll(
        self: core::pin::Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
    ) -> core::task::Poll<Self::Output> {
        let projected = self.project();
        // Poll first so that an exception thrown during this poll is observed in the same call.
        let inner = projected.future.poll(cx);
        match projected.ctx.try_take_exception() {
            Some(thrown) => core::task::Poll::Ready(Err(thrown)),
            None => inner.map(Ok),
        }
    }
}
