//! [![crates.io version](https://img.shields.io/crates/v/safina-threadpool.svg)](https://crates.io/crates/safina-threadpool)
//! [![license: Apache 2.0](https://gitlab.com/leonhard-llc/safina-rs/-/raw/main/license-apache-2.0.svg)](http://www.apache.org/licenses/LICENSE-2.0)
//! [![unsafe forbidden](https://gitlab.com/leonhard-llc/safina-rs/-/raw/main/unsafe-forbidden-success.svg)](https://github.com/rust-secure-code/safety-dance/)
//! [![pipeline status](https://gitlab.com/leonhard-llc/safina-rs/badges/main/pipeline.svg)](https://gitlab.com/leonhard-llc/safina-rs/-/pipelines)
//!
//! A threadpool.
//!
//! You can use it alone or with [`safina`](https://crates.io/crates/safina),
//! a safe async runtime.
//!
//! # Features
//! - Add closures or other `FnOnce` values to the pool
//! - One of the pool's threads will execute it
//! - Automatically restarts panicked threads
//! - Handles hitting the process thread limit
//! - `forbid(unsafe_code)`
//! - Depends only on `std`
//! - 100% test coverage
//!
//! # Limitations
//! - Allocates memory
//! - Not optimized
//!
//! # Documentation
//! <https://docs.rs/safina-threadpool>
//!
//! # Examples
//! ```rust
//! # type ProcessResult = ();
//! # fn process_data(data: (), sender: std::sync::mpsc::Sender<ProcessResult>) -> ProcessResult {
//! #    sender.send(()).unwrap();
//! # }
//! # fn f() {
//! # let data_source = vec![(),()];
//! let pool =
//!     safina_threadpool::ThreadPool::new("worker", 2).unwrap();
//! let receiver = {
//!     let (sender, receiver) =
//!         std::sync::mpsc::channel();
//!     for data in data_source {
//!         let sender_clone = sender.clone();
//!         pool.schedule(
//!             move || process_data(data, sender_clone));
//!     }
//!     receiver
//! };
//! let results: Vec<ProcessResult> =
//!     receiver.iter().collect();
//! // ...
//! # }
//! ```
//!
//! # Alternatives
//! - [`blocking`](https://crates.io/crates/blocking)
//!   - Popular
//!   - A little `unsafe` code
//! - [`threadpool`](https://crates.io/crates/threadpool)
//!   - Popular
//!   - Well maintained
//!   - Dependencies have `unsafe` code
//! - [`futures-executor`](https://crates.io/crates/futures-executor)
//!   - Very popular
//!   - Full of `unsafe`
//! - [`scoped_threadpool`](https://crates.io/crates/scoped_threadpool)
//!   - Popular
//!   - Contains `unsafe` code
//! - [`scheduled-thread-pool`](https://crates.io/crates/scheduled-thread-pool)
//!   - Used by a popular connection pool library
//!   - Dependencies have `unsafe` code
//! - [`workerpool`](https://crates.io/crates/workerpool)
//!   - Dependencies have `unsafe` code
//! - [`threads_pool`](https://crates.io/crates/threads_pool)
//!   - Full of `unsafe`
//! - [`thread-pool`](https://crates.io/crates/thread-pool)
//!   - Old
//!   - Dependencies have `unsafe` code
//! - [`tasque`](https://crates.io/crates/tasque)
//!   - Dependencies have `unsafe` code
//! - [`fast-threadpool`](https://crates.io/crates/fast-threadpool)
//!   - Dependencies have `unsafe` code
//! - [`blocking-permit`](https://crates.io/crates/blocking-permit)
//!   - Full of `unsafe`
//! - [`rayon-core`](https://crates.io/crates/rayon-core)
//!   - Full of `unsafe`
//!
//! # Changelog
//! - v0.2.1 - Improve test coverage.
//! - v0.2.0
//!   - `ThreadPool::new` to return `Result`.
//!   - `ThreadPool::try_schedule` to return an error when it fails to restart panicked threads.
//!   - `ThreadPool::schedule` to handle failure starting replacement threads.
//! - v0.1.4 - Stop threads on drop.
//! - v0.1.3 - Support stable Rust!  Needs 1.51+.
//! - v0.1.2 - Add another example
//! - v0.1.1 - Simplified internals and improved documentation.
//! - v0.1.0 - First release
//!
//! # TO DO
//! - Log a warning when all threads panicked.
//! - Update test coverage.
//! - Add a public `respawn_threads` function.
//! - Add a stress test
//! - Add a benchmark.  See benchmarks in <https://crates.io/crates/executors>
//! - Add a way for a job to schedule another job on the same thread, with stealing.
//!
//! # Release Process
//! 1. Edit `Cargo.toml` and bump version number.
//! 1. Run `./release.sh`
#![forbid(unsafe_code)]

use core::fmt::{Debug, Display, Formatter};
use core::sync::atomic::{AtomicUsize, Ordering};
use core::time::Duration;
use std::error::Error;
use std::sync::mpsc::{Receiver, RecvTimeoutError, SyncSender, TrySendError};
use std::sync::{Arc, Mutex};

/// Test-only cap on the number of worker threads.
///
/// In test builds, `Inner::spawn_thread` returns an error instead of spawning once the
/// current live-thread count is at or above this value. Tests lower it to simulate hitting
/// the process thread limit; it defaults to 100.
#[cfg(test)]
pub static TEST_MAX_THREADS: AtomicUsize = AtomicUsize::new(100);

/// A thread-safe source of sequential `usize` values that starts at zero.
struct AtomicCounter {
    // The value that the next call to `next` hands out.
    upcoming: AtomicUsize,
}

impl AtomicCounter {
    /// Makes a counter whose first `next()` call returns 0.
    pub fn new() -> Self {
        let upcoming = AtomicUsize::new(0);
        AtomicCounter { upcoming }
    }

    /// Returns the current value and advances the counter by one.
    ///
    /// The increment is a single atomic read-modify-write, so concurrent callers each
    /// receive a distinct value (until the counter wraps around at `usize::MAX`).
    pub fn next(&self) -> usize {
        self.upcoming.fetch_add(1, Ordering::AcqRel)
    }
}

/// A fresh counter yields 0, 1, 2, ... on successive calls.
#[test]
fn atomic_counter() {
    let counter = AtomicCounter::new();
    for expected in 0_usize..3 {
        assert_eq!(expected, counter.next());
    }
}

/// Ten threads each draw ten values from one shared counter; together they must receive
/// every value in `0..100` exactly once.
#[test]
fn atomic_counter_many_readers() {
    const THREADS: usize = 10;
    const PER_THREAD: usize = 10;
    let counter = Arc::new(AtomicCounter::new());
    let (tx, rx) = std::sync::mpsc::channel();
    for _ in 0..THREADS {
        let counter = Arc::clone(&counter);
        let tx = tx.clone();
        std::thread::spawn(move || {
            (0..PER_THREAD).for_each(|_| tx.send(counter.next()).unwrap());
        });
    }
    // Drop the original sender so `rx.iter()` ends once every worker thread has finished.
    drop(tx);
    let mut seen: Vec<usize> = rx.iter().collect();
    seen.sort_unstable();
    let expected: Vec<usize> = (0..THREADS * PER_THREAD).collect();
    assert_eq!(expected, seen);
}

/// Treats two `std::io::Error`s as equal when both their kind and their `Display` text
/// match. (`std::io::Error` does not implement `PartialEq` itself.)
fn err_eq(a: &std::io::Error, b: &std::io::Error) -> bool {
    if a.kind() != b.kind() {
        return false;
    }
    a.to_string() == b.to_string()
}

#[derive(Debug)]
pub enum StartThreadsError {
    /// The pool has no threads and `std::thread::Builder::spawn` returned the included error.
    NoThreads(std::io::Error),
    /// The pool has at least one thread and `std::thread::Builder::spawn` returned the included error.
    Respawn(std::io::Error),
}
impl Display for StartThreadsError {
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
        match self {
            StartThreadsError::NoThreads(e) => write!(
                f,
                "ThreadPool workers all panicked, failed starting replacement threads: {}",
                e
            ),
            StartThreadsError::Respawn(e) => {
                write!(
                    f,
                    "ThreadPool failed starting threads to replace panicked threads: {}",
                    e
                )
            }
        }
    }
}
impl Error for StartThreadsError {}
impl PartialEq for StartThreadsError {
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (StartThreadsError::NoThreads(a), StartThreadsError::NoThreads(b))
            | (StartThreadsError::Respawn(a), StartThreadsError::Respawn(b)) => err_eq(a, b),
            _ => false,
        }
    }
}
impl Eq for StartThreadsError {}

/// Reasons [`ThreadPool::new`] can fail.
#[derive(Debug)]
pub enum NewThreadPoolError {
    /// An argument was invalid; the string is a human-readable explanation.
    Parameter(String),
    /// Starting a worker thread failed; holds the error returned by
    /// `std::thread::Builder::spawn`.
    Spawn(std::io::Error),
}
impl Display for NewThreadPoolError {
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
        match self {
            Self::Parameter(msg) => f.write_str(msg),
            Self::Spawn(io_err) => write!(f, "ThreadPool failed starting threads: {}", io_err),
        }
    }
}
impl Error for NewThreadPoolError {}
impl PartialEq for NewThreadPoolError {
    /// Variants must match; messages compare as strings and I/O errors by kind and text.
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (Self::Parameter(lhs), Self::Parameter(rhs)) => lhs == rhs,
            (Self::Spawn(lhs), Self::Spawn(rhs)) => err_eq(lhs, rhs),
            _ => false,
        }
    }
}
impl Eq for NewThreadPoolError {}
impl From<StartThreadsError> for NewThreadPoolError {
    /// Both start-up failure variants collapse into `Spawn`, keeping the I/O error.
    fn from(err: StartThreadsError) -> Self {
        let io_err = match err {
            StartThreadsError::NoThreads(e) => e,
            StartThreadsError::Respawn(e) => e,
        };
        Self::Spawn(io_err)
    }
}

#[derive(Debug)]
pub enum TryScheduleError {
    QueueFull,
    /// The pool has no threads and `std::thread::Builder::spawn` returned the included error.
    NoThreads(std::io::Error),
    /// The pool has at least one thread and `std::thread::Builder::spawn` returned the included error.
    Respawn(std::io::Error),
}
impl Display for TryScheduleError {
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
        match self {
            TryScheduleError::QueueFull => write!(f, "ThreadPool queue is full"),
            TryScheduleError::NoThreads(e) => write!(
                f,
                "ThreadPool workers all panicked, failed starting replacement threads: {}",
                e
            ),
            TryScheduleError::Respawn(e) => {
                write!(
                    f,
                    "ThreadPool failed starting threads to replace panicked threads: {}",
                    e
                )
            }
        }
    }
}
impl Error for TryScheduleError {}
impl PartialEq for TryScheduleError {
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (TryScheduleError::QueueFull, TryScheduleError::QueueFull) => true,
            (TryScheduleError::NoThreads(a), TryScheduleError::NoThreads(b))
            | (TryScheduleError::Respawn(a), TryScheduleError::Respawn(b)) => err_eq(a, b),
            _ => false,
        }
    }
}
impl Eq for TryScheduleError {}
impl From<StartThreadsError> for TryScheduleError {
    fn from(err: StartThreadsError) -> Self {
        match err {
            StartThreadsError::NoThreads(e) => TryScheduleError::NoThreads(e),
            StartThreadsError::Respawn(e) => TryScheduleError::Respawn(e),
        }
    }
}

/// State shared through an `Arc` by a `ThreadPool` and all of its worker threads.
struct Inner {
    /// Receiving end of the job queue; workers take turns locking it to get the next job.
    receiver: Mutex<Receiver<Box<dyn FnOnce() + Send>>>,
    /// Number of worker threads the pool keeps running.
    size: usize,
    /// Prefix for worker thread names.
    name: &'static str,
    /// Supplies the numeric suffix appended to each new thread's name.
    next_name_num: AtomicCounter,
}

impl Inner {
    /// Returns how many worker threads are currently alive.
    ///
    /// Each worker thread owns one clone of the `Arc<Inner>` and the `ThreadPool` owns one
    /// more, so the answer is the strong count minus one.
    /// NOTE(review): this assumes exactly one non-worker `Arc` exists. Temporary clones
    /// (such as the one `start_thread` makes) or clones held elsewhere inflate the result.
    /// After the `ThreadPool` is dropped, the result is one lower than the real worker count.
    pub fn num_live_threads(self: &Arc<Inner>) -> usize {
        Arc::strong_count(self) - 1
    }

    /// Main loop of a worker thread: repeatedly takes one job from the queue and runs it.
    ///
    /// Every iteration ends with `start_threads`, and the receive times out after 500 ms.
    /// As a result, even an idle worker periodically replaces workers that have stopped.
    /// A job that panics unwinds out of this function and ends the thread. Dropping that
    /// thread's `Arc` is what lowers `num_live_threads`.
    /// Returns once the channel reports `Disconnected`, i.e. the sender is gone and no
    /// buffered jobs remain.
    fn work(self: &Arc<Inner>) {
        loop {
            // The `MutexGuard` is a temporary of this statement, so the lock is held only
            // while waiting for a job and is released before the job runs.
            // NOTE(review): `unwrap` panics if the mutex is poisoned; the lock is only held
            // across `recv_timeout`, so poisoning is not expected.
            let recv_result = self
                .receiver
                .lock()
                .unwrap()
                .recv_timeout(Duration::from_millis(500));
            match recv_result {
                Ok(f) => {
                    // Top up the pool before running `f`, which may run long or panic.
                    // Errors are ignored here; the call after the match retries.
                    let _ignored = self.start_threads();
                    f();
                }
                Err(RecvTimeoutError::Timeout) => {}
                // ThreadPool was dropped.
                Err(RecvTimeoutError::Disconnected) => return,
            };
            let _ignored = self.start_threads();
        }
    }

    /// Spawns one OS thread named `name` that runs `f`, returning the spawn error if any.
    ///
    /// The `JoinHandle` is dropped, so the thread runs detached.
    /// `num_live_threads` is only read in test builds, where spawning fails once it reaches
    /// `TEST_MAX_THREADS`. The `allow` attributes exist because non-test builds use neither
    /// `self` nor `num_live_threads`.
    #[allow(clippy::unused_self)]
    #[allow(unused_variables)]
    fn spawn_thread(
        &self,
        num_live_threads: usize,
        name: String,
        f: impl FnOnce() + Send + 'static,
    ) -> Result<(), std::io::Error> {
        // I found no way to make std::thread fail reliably on both macOS & Linux.
        #[cfg(test)]
        if num_live_threads >= TEST_MAX_THREADS.load(Ordering::Acquire) {
            return Err(std::io::Error::new(
                std::io::ErrorKind::Other,
                "err1".to_string(),
            ));
        }

        std::thread::Builder::new().name(name).spawn(f)?;
        Ok(())
    }

    /// Spawns one more worker thread if fewer than `size` are alive.
    ///
    /// # Errors
    /// Returns `StartThreadsError::NoThreads` if spawning failed while no worker was alive.
    /// Returns `StartThreadsError::Respawn` if spawning failed while at least one was alive.
    fn start_thread(self: &Arc<Inner>) -> Result<(), StartThreadsError> {
        // Clone first: this is the `Arc` the new worker will own. Because it is counted
        // before the check below, each caller's count includes other callers' pending
        // clones. Concurrent callers therefore err toward starting too few threads, and
        // `start_threads` loops to make up the difference.
        let self_clone = self.clone();
        // Subtract one for `self_clone`, which is not a running worker yet.
        let num_live_threads = self.num_live_threads() - 1;
        if num_live_threads < self.size {
            // Thread name is `name` immediately followed by the counter value, e.g.
            // "worker0". A number is consumed only when a spawn is attempted.
            self.spawn_thread(
                num_live_threads,
                format!("{}{}", self.name, self.next_name_num.next()),
                move || self_clone.work(),
            )
            .map_err(|e| {
                if num_live_threads == 0 {
                    StartThreadsError::NoThreads(e)
                } else {
                    StartThreadsError::Respawn(e)
                }
            })?;
        }
        // If no thread was spawned, `self_clone` is dropped here.
        Ok(())
    }

    /// Spawns worker threads until `num_live_threads()` reaches `size`.
    ///
    /// # Errors
    /// Stops at, and returns, the first failure from `start_thread`.
    fn start_threads(self: &Arc<Inner>) -> Result<(), StartThreadsError> {
        while self.num_live_threads() < self.size {
            self.start_thread()?;
        }
        Ok(())
    }
}

/// A collection of worker threads and a bounded FIFO queue of jobs (`FnOnce` closures)
/// for them to execute.
///
/// A thread stops when a job it executes panics.
/// Every surviving thread periodically starts threads to replace the ones that stopped.
/// Calls to [`schedule`](#method.schedule) and [`try_schedule`](#method.try_schedule)
/// also start replacement threads.
///
/// If every thread has panicked, only those calls can start replacements. If your load is
/// bursty and you want to recover automatically from that state, you could use
/// [`safina_timer`](https://crates.io/crates/safina-timer) to periodically call
/// [`schedule`](#method.schedule) or [`try_schedule`](#method.try_schedule), as the
/// second example below does.
///
/// After the pool drops, its threads keep taking jobs until the queue is empty and then stop.
///
/// # Example
/// ```rust
/// # type ProcessResult = ();
/// # fn process_data(data: (), sender: std::sync::mpsc::Sender<ProcessResult>) -> ProcessResult {
/// #    sender.send(()).unwrap();
/// # }
/// # fn f() {
/// # let data_source = vec![(),()];
/// let pool =
///     safina_threadpool::ThreadPool::new("worker", 2).unwrap();
/// let receiver = {
///     let (sender, receiver) =
///         std::sync::mpsc::channel();
///     for data in data_source {
///         let sender_clone = sender.clone();
///         pool.schedule(
///             move || process_data(data, sender_clone));
///     }
///     receiver
/// };
/// let results: Vec<ProcessResult> =
///     receiver.iter().collect();
/// // ...
/// # }
/// ```
///
/// ```rust
/// # use core::time::Duration;
/// # use std::sync::Arc;
/// let pool =
///     Arc::new(
///         safina_threadpool::ThreadPool::new("worker", 2).unwrap());
/// let executor = safina_executor::Executor::default();
/// safina_timer::start_timer_thread();
/// let pool_clone = pool.clone();
/// executor.spawn(async move {
///     loop {
///         safina_timer::sleep_for(Duration::from_millis(500)).await;
///         pool_clone.schedule(|| {});
///     }
/// });
/// # assert_eq!(2, pool.num_live_threads());
/// # for _ in 0..2 {
/// #     pool.schedule(|| {
/// #         std::thread::sleep(Duration::from_millis(100));
/// #         panic!("ignore this panic")
/// #     });
/// # }
/// # std::thread::sleep(Duration::from_millis(200));
/// # assert_eq!(0, pool.num_live_threads());
/// # std::thread::sleep(Duration::from_millis(500));
/// # assert_eq!(2, pool.num_live_threads());
/// ```
pub struct ThreadPool {
    // State shared with the worker threads, including the receiving end of the job queue.
    inner: Arc<Inner>,
    // The only sending end of the bounded job queue. Dropping it (when the pool drops)
    // disconnects the channel, so workers exit once the remaining jobs are drained.
    sender: SyncSender<Box<dyn FnOnce() + Send>>,
}
impl ThreadPool {
    /// Creates a new thread pool containing `size` threads.
    /// The threads all start immediately.
    ///
    /// Threads are named `name` followed by a sequence number that starts at zero.
    /// For example, `ThreadPool::new("worker", 2)`
    /// creates threads named "worker0" and "worker1".
    /// If one of those threads panics, the pool creates "worker2".
    ///
    /// The job queue holds at most `size * 200` jobs.
    ///
    /// After the `ThreadPool` struct drops, the threads continue processing
    /// jobs and stop when the queue is empty.
    ///
    /// # Errors
    /// Returns an error when `name` is empty, `size` is zero, or it fails to start the threads.
    pub fn new(name: &'static str, size: usize) -> Result<Self, NewThreadPoolError> {
        if name.is_empty() {
            return Err(NewThreadPoolError::Parameter(
                "ThreadPool::new called with empty name".to_string(),
            ));
        }
        if size == 0 {
            return Err(NewThreadPoolError::Parameter(format!(
                "ThreadPool::new called with invalid size value: {:?}",
                size
            )));
        }
        // Use a channel with bounded size.
        // If the channel was unbounded, the process could OOM when throughput goes down.
        let (sender, receiver) = std::sync::mpsc::sync_channel(size * 200);
        let pool = ThreadPool {
            inner: Arc::new(Inner {
                name,
                next_name_num: AtomicCounter::new(),
                size,
                receiver: Mutex::new(receiver),
            }),
            sender,
        };
        pool.inner.start_threads()?;
        Ok(pool)
    }

    /// Returns the number of threads the pool keeps running (the `size` passed to `new`).
    #[must_use]
    pub fn size(&self) -> usize {
        self.inner.size
    }

    /// Returns the number of threads currently alive.
    ///
    /// The value is derived from the shared `Arc`'s reference count, so it can briefly read
    /// one higher while a thread is being started.
    #[must_use]
    pub fn num_live_threads(&self) -> usize {
        self.inner.num_live_threads()
    }

    /// Adds a job to the queue.  The next idle thread will execute it.
    /// Jobs are started in FIFO order.
    ///
    /// Blocks when the queue is full or no threads are running, retrying every 10 ms.
    /// See [`try_schedule`](#method.try_schedule).
    ///
    /// Recreates any threads that panicked.
    /// Retries on failure to start a new thread.
    ///
    /// Puts `f` in a [`Box`](https://doc.rust-lang.org/stable/std/boxed/struct.Box.html) before
    /// adding it to the queue.
    #[allow(clippy::missing_panics_doc)]
    pub fn schedule<F: FnOnce() + Send + 'static>(&self, f: F) {
        let mut job: Box<dyn FnOnce() + Send + 'static> = Box::new(f);
        loop {
            match self.inner.start_threads() {
                // At least one thread is running, so the job can be queued.
                Ok(()) | Err(StartThreadsError::Respawn(_)) => match self.sender.try_send(job) {
                    Ok(()) => return,
                    // `self.inner` owns the `Receiver` and `self` keeps `inner` alive, so
                    // the channel cannot be disconnected while `self` exists.
                    Err(TrySendError::Disconnected(_)) => unreachable!(),
                    // Queue is full: take the job back and retry after the sleep.
                    Err(TrySendError::Full(returned)) => job = returned,
                },
                // No thread could be started; wait and retry.
                Err(StartThreadsError::NoThreads(_)) => {}
            }
            std::thread::sleep(Duration::from_millis(10));
        }
    }

    /// Adds a job to the queue and then starts threads to replace any panicked threads.
    /// The next idle thread will execute the job.
    /// Starts jobs in FIFO order.
    ///
    /// Puts `f` in a [`Box`](https://doc.rust-lang.org/stable/std/boxed/struct.Box.html) before
    /// adding it to the queue.
    ///
    /// # Errors
    /// Returns an error when the queue is full or it fails to start a thread.
    /// If the return value is not `TryScheduleError::QueueFull` then it added the job to the queue.
    #[allow(clippy::missing_panics_doc)]
    pub fn try_schedule(&self, f: impl FnOnce() + Send + 'static) -> Result<(), TryScheduleError> {
        match self.sender.try_send(Box::new(f)) {
            Ok(()) => {}
            // `self.inner` owns the `Receiver`, so the channel cannot be disconnected here.
            Err(TrySendError::Disconnected(_)) => unreachable!(),
            Err(TrySendError::Full(_)) => return Err(TryScheduleError::QueueFull),
        }
        self.inner.start_threads().map_err(TryScheduleError::from)
    }
}
impl Debug for ThreadPool {
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), core::fmt::Error> {
        write!(
            f,
            "ThreadPool{{{:?},size={:?}}}",
            self.inner.name, self.inner.size
        )
    }
}

#[cfg(test)]
mod test {
    // Unit tests for the pool. Thread-spawn failures are simulated through the test-only
    // `TEST_MAX_THREADS` limit. Every test below holds `LOCK` for its whole body so that
    // tests cannot change that shared limit under each other.
    use super::*;
    use safe_lock::SafeLock;
    use std::ops::Range;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::time::Instant;
    static LOCK: SafeLock = SafeLock::new();

    /// Asserts that the time elapsed since `before` lies in the half-open range `range_ms`,
    /// given in milliseconds. Prints the elapsed time to help diagnose failures.
    fn assert_elapsed(before: Instant, range_ms: Range<u64>) {
        assert!(!range_ms.is_empty(), "invalid range {:?}", range_ms);
        let elapsed = before.elapsed();
        println!("elapsed = {} ms", elapsed.as_millis());
        let duration_range =
            Duration::from_millis(range_ms.start)..Duration::from_millis(range_ms.end);
        assert!(
            duration_range.contains(&elapsed),
            "{:?} elapsed, out of range {:?}",
            elapsed,
            duration_range
        );
    }

    /// Sets the live-thread count at which test builds of `Inner::spawn_thread` start
    /// returning errors.
    fn set_max_threads(n: usize) {
        println!("set_max_threads({})", n);
        TEST_MAX_THREADS.store(n, Ordering::Release);
    }

    /// Blocks the current thread for `ms` milliseconds.
    fn sleep_ms(ms: u64) {
        std::thread::sleep(Duration::from_millis(ms));
    }

    /// Makes `num` worker threads panic at roughly the same time.
    ///
    /// Each queued job spins until `pause` is cleared, so a worker that picks one up cannot
    /// take another. After 100 ms, assumed long enough for `num` workers to pick up the
    /// jobs, all jobs are released and panic. The final 100 ms sleep gives the panicking
    /// threads time to exit.
    fn panic_threads(pool: &ThreadPool, num: usize) {
        let pause = Arc::new(AtomicBool::new(true));
        for _ in 0..num {
            let pause_clone = pause.clone();
            pool.try_schedule(move || {
                println!(
                    "thread {:?} waiting",
                    std::thread::current().name().unwrap_or("")
                );
                while pause_clone.load(Ordering::Acquire) {
                    sleep_ms(10);
                }
                println!(
                    "panicking thread {:?}",
                    std::thread::current().name().unwrap_or("")
                );
                panic!("ignore this panic");
            })
            .unwrap();
        }
        sleep_ms(100);
        pause.store(false, Ordering::Release);
        sleep_ms(100);
    }

    // A pool whose size equals the simulated limit starts all of its threads.
    #[test]
    fn new_thread_pool_at_max() {
        let _guard = LOCK.lock().unwrap();
        set_max_threads(2);
        ThreadPool::new("test", 2).unwrap();
    }

    // Asking for more threads than the limit allows makes `new` fail with `Spawn`.
    #[test]
    fn new_thread_pool_error_spawn() {
        let _guard = LOCK.lock().unwrap();
        set_max_threads(2);
        let err = ThreadPool::new("test", 3).unwrap_err();
        assert!(matches!(err, NewThreadPoolError::Spawn(_)));
    }

    // With every worker dead and spawning blocked (limit 0), `schedule` keeps retrying.
    // A helper thread raises the limit after ~100 ms, so the job should run 100-200 ms
    // after scheduling.
    #[test]
    fn schedule_retries_thread_start() {
        let _guard = LOCK.lock().unwrap();
        set_max_threads(3);
        let pool = ThreadPool::new("test", 3).unwrap();
        panic_threads(&pool, 3);
        set_max_threads(0);
        let before = Instant::now();
        std::thread::spawn(|| {
            sleep_ms(100);
            set_max_threads(1);
        });
        let (sender, receiver) = std::sync::mpsc::channel();
        pool.schedule(move || {
            println!("sending ()");
            sender.send(()).unwrap();
        });
        receiver.recv_timeout(Duration::from_millis(500)).unwrap();
        assert_elapsed(before, 100..200);
    }

    // schedule_retries_when_queue_full is in tests/test.rs .

    // try_schedule_queue_full is in tests/test.rs .

    // With every worker dead and spawning blocked, `try_schedule` queues the job but
    // reports `NoThreads`.
    #[test]
    fn try_schedule_no_threads() {
        let _guard = LOCK.lock().unwrap();
        set_max_threads(2);
        let pool = ThreadPool::new("test", 2).unwrap();
        panic_threads(&pool, 2);
        set_max_threads(0);
        let result = pool.try_schedule(|| {});
        assert!(
            matches!(result, Err(TryScheduleError::NoThreads(_))),
            "{:?}",
            result
        );
    }

    // With one of two workers dead and the limit lowered to 1, replacing the dead worker
    // fails while one worker survives, so `try_schedule` reports `Respawn`.
    #[test]
    fn try_schedule_respawn() {
        let _guard = LOCK.lock().unwrap();
        set_max_threads(2);
        let pool = ThreadPool::new("test", 2).unwrap();
        panic_threads(&pool, 1);
        set_max_threads(1);
        let result = pool.try_schedule(|| {});
        assert!(
            matches!(result, Err(TryScheduleError::Respawn(_))),
            "{:?}",
            result
        );
    }

    // Dropping the pool disconnects the job channel, so idle workers exit. The extra
    // `inner` clone stands in for the dropped pool's own `Arc`, so `num_live_threads`
    // still counts only workers.
    #[test]
    fn threads_stop_after_pool_drops() {
        let _guard = LOCK.lock().unwrap();
        set_max_threads(2);
        let pool = ThreadPool::new("test", 2).unwrap();
        let inner = pool.inner.clone();
        drop(pool);
        std::thread::sleep(Duration::from_millis(100));
        assert_eq!(0, inner.num_live_threads());
    }
}
