// -*- coding: utf-8 -*-
//
// Copyright 2021 Michael Büsch <m@bues.ch>
//
// Licensed under the Apache License version 2.0
// or the MIT license, at your option.
// SPDX-License-Identifier: Apache-2.0 OR MIT
//

use crate::util::{
    get_bounds,
    overlaps_any,
};
use std::{
    cell::UnsafeCell,
    collections::HashSet,
    hint::unreachable_unchecked,
    ops::{
        Deref,
        DerefMut,
        Range,
        RangeBounds,
    },
    sync::{
        LockResult,
        Mutex,
        PoisonError,
        TryLockError,
        TryLockResult,
    }
};

/// Multi-thread range lock for `Vec<T>`.
///
/// Several threads can lock pairwise disjoint ranges of the embedded `Vec`
/// at the same time. A successful `try_lock()` returns a guard that grants
/// access to exactly the locked range; the range is unlocked again when the
/// guard is dropped.
///
/// # Example
///
/// ```
/// use range_lock::VecRangeLock;
/// use std::sync::Arc;
///
/// let lock = Arc::new(VecRangeLock::new(vec![1, 2, 3, 4, 5]));
///
/// let mut guard = lock.try_lock(2..4).expect("Failed to lock 2..4");
/// assert_eq!(guard[0], 3);
/// guard[0] = 100;
/// assert_eq!(guard[0], 100);
/// assert_eq!(guard[1], 4);
/// ```
#[derive(Debug)]
pub struct VecRangeLock<T> {
    /// The currently locked ranges. Empty ranges are never stored.
    ranges: Mutex<HashSet<Range<usize>>>,
    /// The protected data. It is never resized or reallocated while `self` exists.
    data:   UnsafeCell<Vec<T>>,
    /// Pointer to the first element of the `data` buffer, taken once in `new()`.
    /// All element access is derived from this pointer, so no reference
    /// covering the whole buffer is ever created while other threads hold
    /// disjoint sub-slices of it.
    base:   *mut T,
}

// SAFETY:
// Sharing a VecRangeLock between threads is sound.
// The `ranges` set is protected by a Mutex. `try_lock()` only hands out guards
// for ranges that are empty or don't overlap any currently locked range.
// Therefore no element is reachable through two guards at the same time.
// Guards may access the elements from a thread other than the one that created
// the lock, so T must be Send-able to other threads.
unsafe impl<T> Sync for VecRangeLock<T>
where
    T: Send
{ }

// SAFETY:
// The only field that is not automatically Send is the raw `base` pointer.
// It points into the heap buffer owned by `data` and moves together with it.
// Sending the lock is therefore equivalent to sending the `Vec<T>`, which
// requires T: Send.
unsafe impl<T> Send for VecRangeLock<T>
where
    T: Send
{ }

impl<T> VecRangeLock<T> {
    /// Construct a new VecRangeLock.
    ///
    /// * `data`: The data Vec to protect.
    pub fn new(mut data: Vec<T>) -> VecRangeLock<T> {
        // Moving the Vec into the UnsafeCell below moves only its header, not
        // its heap buffer, so `base` stays valid for the lifetime of the lock.
        let base = data.as_mut_ptr();
        VecRangeLock {
            ranges: Mutex::new(HashSet::new()),
            data:   UnsafeCell::new(data),
            base,
        }
    }

    /// Get the length (in number of elements) of the embedded Vec.
    #[inline]
    pub fn data_len(&self) -> usize {
        // SAFETY: Only the Vec header is read here. No method modifies the
        // header while `self` exists, so concurrent reads are fine.
        unsafe { (*self.data.get()).len() }
    }

    /// Unwrap the VecRangeLock into the contained data.
    /// This method consumes self.
    #[inline]
    pub fn into_inner(self) -> Vec<T> {
        // Every guard borrows `self`, so no guard can be alive here. A non-empty
        // set would mean that a guard has been leaked (e.g. via `mem::forget`).
        debug_assert!(self.ranges.lock().unwrap_or_else(PoisonError::into_inner).is_empty());
        self.data.into_inner()
    }

    /// Try to lock the given data `range`.
    ///
    /// * On success: Returns a `VecRangeLockGuard` that can be used to access the locked region.
    ///               Dereferencing `VecRangeLockGuard` yields a slice of the `data`.
    /// * On failure: Returns TryLockError::WouldBlock, if the range is contended.
    ///               The locking attempt may be retried by the caller upon contention.
    ///               Returns TryLockError::Poisoned, if the internal bookkeeping mutex is
    ///               poisoned. Like with `std::sync::Mutex`, the contained guard is
    ///               properly locked and usable.
    ///
    /// An empty range (e.g. `2..2` or `len..len`) grants access to no element.
    /// It never conflicts with other ranges and always succeeds.
    ///
    /// # Panics
    ///
    /// Panics if the range exceeds the length of the data,
    /// or if the start of the range is bigger than its end.
    pub fn try_lock(&self, range: impl RangeBounds<usize>) -> TryLockResult<VecRangeLockGuard<'_, T>> {
        let data_len = self.data_len();
        let (range_start, range_end) = get_bounds(&range, data_len);
        if range_start > data_len || range_end > data_len {
            panic!("Range is out of bounds.");
        }
        if range_start > range_end {
            panic!("Invalid range. Start is bigger than end.");
        }
        let range = range_start..range_end;

        if range.is_empty() {
            // An empty range can't conflict with any other guard,
            // so it is not registered at all.
            return TryLockResult::Ok(VecRangeLockGuard::new(self, range));
        }

        // A poisoned mutex still holds a consistent set (see `unlock()`).
        // Recover it and perform the regular overlap check and registration.
        // Handing out a guard without registering its range would allow
        // overlapping mutable slices.
        let (mut ranges, poisoned) = match self.ranges.lock() {
            LockResult::Ok(ranges) => (ranges, false),
            LockResult::Err(err) => (err.into_inner(), true),
        };
        if overlaps_any(&*ranges, &range) {
            return TryLockResult::Err(TryLockError::WouldBlock);
        }
        ranges.insert(range.clone());
        drop(ranges);

        let guard = VecRangeLockGuard::new(self, range);
        if poisoned {
            TryLockResult::Err(TryLockError::Poisoned(PoisonError::new(guard)))
        } else {
            TryLockResult::Ok(guard)
        }
    }

    /// Unlock a range.
    ///
    /// This runs from the guard's `Drop` implementation, so it must not panic
    /// because of mutex poisoning. A panic here during unwinding would abort
    /// the process. The set is only changed by single `insert()`/`remove()`
    /// calls, so a poisoned set is still consistent and is simply recovered.
    fn unlock(&self, range: &Range<usize>) {
        if range.is_empty() {
            // Empty ranges are never registered (see `try_lock()`).
            return;
        }
        let mut ranges = self.ranges.lock().unwrap_or_else(PoisonError::into_inner);
        ranges.remove(range);
    }

    /// Get an immutable slice to the specified range.
    ///
    /// # SAFETY
    ///
    /// The caller must ensure that `range.start <= range.end <= self.data_len()`,
    /// and must uphold the aliasing rules listed for get_mut_slice().
    #[inline]
    unsafe fn get_slice(&self, range: &Range<usize>) -> &[T] {
        debug_assert!(range.start <= range.end && range.end <= self.data_len());
        // SAFETY: `base` points to `data_len()` initialized elements that stay in
        // place for the lifetime of `self`, and the caller guarantees that `range`
        // lies within them. Building the slice from the raw pointer, instead of
        // indexing the Vec, avoids creating a reference to the whole buffer.
        // Such a reference would alias sub-slices held by other threads.
        std::slice::from_raw_parts(self.base.add(range.start), range.end - range.start)
    }

    /// Get a mutable slice to the specified range.
    ///
    /// # SAFETY
    ///
    /// The caller must ensure that:
    /// * `range.start <= range.end <= self.data_len()`.
    /// * No overlapping slices must coexist on multiple threads.
    /// * Immutable slices to overlapping ranges may only coexist on a single thread.
    /// * Immutable and mutable slices must not coexist.
    #[inline]
    #[allow(clippy::mut_from_ref)] // Slices won't overlap. See SAFETY.
    unsafe fn get_mut_slice(&self, range: &Range<usize>) -> &mut [T] {
        debug_assert!(range.start <= range.end && range.end <= self.data_len());
        // The pointer is derived from `Vec::as_mut_ptr()` (see `new()`), so it
        // carries write permission. A pointer cast from a shared `&[T]` would not.
        let slice_ptr = std::ptr::slice_from_raw_parts_mut(
            self.base.add(range.start),
            range.end - range.start,
        );
        // SAFETY: `Vec::as_mut_ptr()` never returns null (it is dangling but
        // non-null for unallocated Vecs), so `as_mut()` cannot return `None`.
        slice_ptr.as_mut().unwrap_or_else(|| unreachable_unchecked())
    }
}

/// Lock guard variable type.
///
/// The Deref and DerefMut traits are implemented for this struct.
/// Dereferencing yields the locked range as a slice, so index 0 of the guard
/// is the first element of the locked range.
/// Dropping the guard unlocks the range again.
/// See the documentation of `VecRangeLock` for usage examples of `VecRangeLockGuard`.
#[derive(Debug)]
pub struct VecRangeLockGuard<'a, T> {
    lock:   &'a VecRangeLock<T>,
    range:  Range<usize>,
    /// Gives the guard the auto traits of `&'a mut [T]`, which is what it hands out.
    /// In particular, the guard is only `Sync` if `T: Sync`. Without this marker,
    /// `&VecRangeLockGuard` could be shared between threads whenever `T: Send`.
    /// That would give several threads `&T` to the same elements, even for
    /// non-`Sync` types such as `Cell`.
    _marker: std::marker::PhantomData<&'a mut [T]>,
}

impl<'a, T> VecRangeLockGuard<'a, T> {
    /// Create a guard for `range` of `lock`.
    ///
    /// `VecRangeLock::try_lock()` is expected to have registered `range` as
    /// locked beforehand (or `range` is empty), because the guard unlocks it on drop.
    #[inline]
    fn new(lock:    &'a VecRangeLock<T>,
           range:   Range<usize>) -> VecRangeLockGuard<'a, T> {
        VecRangeLockGuard {
            lock,
            range,
            _marker: std::marker::PhantomData,
        }
    }
}

impl<'a, T> Drop for VecRangeLockGuard<'a, T> {
    /// Release the locked range, so that other guards may lock it.
    #[inline]
    fn drop(&mut self) {
        self.lock.unlock(&self.range);
    }
}

impl<'a, T> Deref for VecRangeLockGuard<'a, T> {
    type Target = [T];

    #[inline]
    fn deref(&self) -> &Self::Target {
        // SAFETY:
        // The guard owns the lock on `self.range`, so no other guard can access
        // these elements. The returned slice borrows the guard, and the guard
        // borrows the range lock that owns the data, so the slice can't outlive
        // the data. The guard is only Sync if T is Sync (see `_marker`).
        unsafe { self.lock.get_slice(&self.range) }
    }
}

impl<'a, T> DerefMut for VecRangeLockGuard<'a, T> {
    #[inline]
    fn deref_mut(&mut self) -> &mut Self::Target {
        // SAFETY:
        // The lifetime of the slice is bounded by the lifetime of the guard.
        // The lifetime of the guard is bounded by the lifetime of the range lock.
        // The underlying data is owned by the range lock.
        // Therefore the slice cannot outlive the data.
        // The range lock ensures that no overlapping/conflicting guards
        // can be constructed.
        // The compiler ensures that the DerefMut result cannot be used,
        // if there's also an immutable Deref result.
        unsafe { self.lock.get_mut_slice(&self.range) }
    }
}

#[cfg(test)]
mod tests {
    // Unit tests for VecRangeLock and VecRangeLockGuard.
    // As a child module they can inspect the private `ranges` set directly,
    // to check which ranges are currently registered as locked.
    use std::cell::RefCell;
    use std::sync::{Arc, Barrier};
    use std::thread;
    use super::*;

    // Every supported RangeBounds flavor locks the expected elements,
    // and guard index 0 maps to the first element of the locked range.
    // The first case also checks that dropping the guard unregisters the range.
    #[test]
    fn test_base() {
        {
            // Range
            let a = VecRangeLock::new(vec![1_i32, 2, 3, 4, 5, 6]);
            {
                let mut g = a.try_lock(2..4).unwrap();
                assert!(!a.ranges.lock().unwrap().is_empty());
                assert_eq!(g[0..2], [3, 4]);
                g[1] = 10;
                assert_eq!(g[0..2], [3, 10]);
            }
            // The guard went out of scope, so its range must be unlocked again.
            assert!(a.ranges.lock().unwrap().is_empty());
        }
        {
            // RangeInclusive
            let a = VecRangeLock::new(vec![1_i32, 2, 3, 4, 5, 6]);
            let g = a.try_lock(2..=4).unwrap();
            assert_eq!(g[0..3], [3, 4, 5]);
        }
        {
            // RangeTo
            let a = VecRangeLock::new(vec![1_i32, 2, 3, 4, 5, 6]);
            let g = a.try_lock(..4).unwrap();
            assert_eq!(g[0..4], [1, 2, 3, 4]);
        }
        {
            // RangeToInclusive
            let a = VecRangeLock::new(vec![1_i32, 2, 3, 4, 5, 6]);
            let g = a.try_lock(..=4).unwrap();
            assert_eq!(g[0..5], [1, 2, 3, 4, 5]);
        }
        {
            // RangeFrom
            let a = VecRangeLock::new(vec![1_i32, 2, 3, 4, 5, 6]);
            let g = a.try_lock(2..).unwrap();
            assert_eq!(g[0..4], [3, 4, 5, 6]);
        }
        {
            // RangeFull
            let a = VecRangeLock::new(vec![1_i32, 2, 3, 4, 5, 6]);
            let g = a.try_lock(..).unwrap();
            assert_eq!(g[0..6], [1, 2, 3, 4, 5, 6]);
        }
    }

    #[test]
    fn test_empty_range() {
        // Empty range doesn't cause conflicts.
        // Empty ranges are not registered in the `ranges` set,
        // so two guards for the same empty range can coexist.
        let a = VecRangeLock::new(vec![1_i32, 2, 3, 4, 5, 6]);
        let g0 = a.try_lock(2..2).unwrap();
        assert!(a.ranges.lock().unwrap().is_empty());
        assert_eq!(g0[0..0], []);
        let g1 = a.try_lock(2..2).unwrap();
        assert!(a.ranges.lock().unwrap().is_empty());
        assert_eq!(g1[0..0], []);
    }

    // Reading past the end of the locked range (2..4 has two elements) must
    // panic through the slice bounds check instead of reaching neighboring data.
    #[test]
    #[should_panic(expected="index out of bounds")]
    fn test_base_oob_read() {
        let a = VecRangeLock::new(vec![1_i32, 2, 3, 4, 5, 6]);
        let g = a.try_lock(2..4).unwrap();
        let _ = g[2];
    }

    // Same as test_base_oob_read, but for a write access.
    #[test]
    #[should_panic(expected="index out of bounds")]
    fn test_base_oob_write() {
        let a = VecRangeLock::new(vec![1_i32, 2, 3, 4, 5, 6]);
        let mut g = a.try_lock(2..4).unwrap();
        g[2] = 10;
    }

    // Locking a range that overlaps an already locked range fails.
    // The expect() message identifies which of the two attempts failed.
    #[test]
    #[should_panic(expected="guard 1 panicked")]
    fn test_overlap0() {
        let a = VecRangeLock::new(vec![1_i32, 2, 3, 4, 5, 6]);
        let _g0 = a.try_lock(2..4).expect("guard 0 panicked");
        let _g1 = a.try_lock(3..5).expect("guard 1 panicked");
    }

    // Same as test_overlap0, with the two lock attempts in reverse order.
    #[test]
    #[should_panic(expected="guard 0 panicked")]
    fn test_overlap1() {
        let a = VecRangeLock::new(vec![1_i32, 2, 3, 4, 5, 6]);
        let _g1 = a.try_lock(3..5).expect("guard 1 panicked");
        let _g0 = a.try_lock(2..4).expect("guard 0 panicked");
    }

    // Two threads lock the disjoint ranges 2..4 and 4..6 concurrently.
    // Both release their guards before the barrier. After the barrier,
    // thread 1 locks 3..5 and must observe thread 0's write to index 3.
    #[test]
    fn test_thread_no_overlap() {
        let a = Arc::new(VecRangeLock::new(vec![1_i32, 2, 3, 4, 5, 6]));
        let b = Arc::clone(&a);
        let c = Arc::clone(&a);
        let ba0 = Arc::new(Barrier::new(2));
        let ba1 = Arc::clone(&ba0);
        let j0 = thread::spawn(move || {
            {
                let mut g = b.try_lock(2..4).unwrap();
                assert!(!b.ranges.lock().unwrap().is_empty());
                assert_eq!(g[0..2], [3, 4]);
                g[1] = 10;
                assert_eq!(g[0..2], [3, 10]);
            }
            ba0.wait();
        });
        let j1 = thread::spawn(move || {
            {
                let g = c.try_lock(4..6).unwrap();
                assert!(!c.ranges.lock().unwrap().is_empty());
                assert_eq!(g[0..2], [5, 6]);
            }
            ba1.wait();
            let g = c.try_lock(3..5).unwrap();
            assert_eq!(g[0..2], [10, 5]);
        });
        j1.join().expect("Thread 1 panicked.");
        j0.join().expect("Thread 0 panicked.");
        assert!(a.ranges.lock().unwrap().is_empty());
    }

    // RefCell makes this type Send, but not Sync.
    struct NoSyncStruct(RefCell<u32>); // No Sync auto-trait.

    // VecRangeLock<T> must be shareable between threads (via Arc) even when
    // T is only Send. Both threads hold their guards until the barrier,
    // so the disjoint ranges 0..1 and 1..2 are locked at the same time.
    #[test]
    fn test_nosync() {
        let a = Arc::new(VecRangeLock::new(vec![
            NoSyncStruct(RefCell::new(1)),
            NoSyncStruct(RefCell::new(2)),
            NoSyncStruct(RefCell::new(3)),
            NoSyncStruct(RefCell::new(4)),
        ]));
        let b = Arc::clone(&a);
        let c = Arc::clone(&a);
        let ba0 = Arc::new(Barrier::new(2));
        let ba1 = Arc::clone(&ba0);
        let j0 = thread::spawn(move || {
            let _g = b.try_lock(0..1).unwrap();
            assert!(!b.ranges.lock().unwrap().is_empty());
            ba0.wait();
        });
        let j1 = thread::spawn(move || {
            let _g = c.try_lock(1..2).unwrap();
            assert!(!c.ranges.lock().unwrap().is_empty());
            ba1.wait();
        });
        j1.join().expect("Thread 1 panicked.");
        j0.join().expect("Thread 0 panicked.");
        assert!(a.ranges.lock().unwrap().is_empty());
    }
}

// vim: ts=4 sw=4 expandtab
