#![cfg_attr(feature = "nightly", feature(const_fn))]

use lock_api::{GetThreadId, RawMutex, RawReentrantMutex};

use std::cell::{Cell, UnsafeCell};
use std::fmt;
use std::fmt::{Debug, Display};
use std::marker::PhantomData;
use std::mem;
use std::ops::{Deref, DerefMut};

/// Flag recording whether a guard for the owning mutex currently has (un-lent)
/// access to the data on the thread holding the raw lock.
///
/// It is a plain `Cell<bool>` with no synchronisation of its own.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
struct Threadlock {
	locked: Cell<bool>,
}

impl Threadlock {
	/// Returns a flag in the cleared (unlocked) state.
	pub const fn new() -> Threadlock {
		let locked = Cell::new(false);
		Threadlock { locked }
	}

	/// Sets the flag.
	///
	/// # Panics
	///
	/// Panics if the flag is already set.
	#[track_caller]
	pub fn lock(&self) {
		assert!(self.try_lock(), "LendableMutex already locked on this thread");
	}

	/// Sets the flag if it was clear and reports whether this call set it.
	pub fn try_lock(&self) -> bool {
		// `replace` yields the previous state; only a previously clear flag counts as acquired.
		let was_locked = self.locked.replace(true);
		!was_locked
	}

	/// Clears the flag. In debug builds, asserts that the flag was set.
	#[track_caller]
	pub fn unlock(&self) {
		let was_locked = self.locked.replace(false);
		debug_assert!(was_locked, "tried to unlock already unlocked threadlock");
	}
}

pub struct LendableMutex<R, G, T: ?Sized> {
	raw: RawReentrantMutex<R, G>,
	threadlock: Threadlock,
	data: UnsafeCell<T>
}

unsafe impl<R, G, T: ?Sized> Send for LendableMutex<R, G, T> {}
unsafe impl<R, G, T: ?Sized> Sync for LendableMutex<R, G, T> {}

impl<R: RawMutex, G: GetThreadId, T: ?Sized> LendableMutex<R, G, T> {
	/// Creates a new, unlocked mutex wrapping `v`.
	///
	/// With the `nightly` feature enabled this is a `const fn`.
	#[cfg(feature = "nightly")]
	pub const fn new(v: T) -> Self
	where
		T: Sized,
	{
		Self {
			raw: RawReentrantMutex::INIT,
			threadlock: Threadlock::new(),
			data: UnsafeCell::new(v),
		}
	}

	/// Creates a new, unlocked mutex wrapping `v`.
	#[cfg(not(feature = "nightly"))]
	pub fn new(v: T) -> Self
	where
		T: Sized,
	{
		Self {
			raw: RawReentrantMutex::INIT,
			threadlock: Threadlock::new(),
			data: UnsafeCell::new(v),
		}
	}

	/// Consumes the mutex and returns the protected value.
	///
	/// No locking is performed: owning `self` means no guard can be alive.
	pub fn into_inner(self) -> T
	where
		T: Sized,
	{
		self.data.into_inner()
	}

	/// Builds a guard. The caller must have just acquired `raw` on this thread.
	///
	/// # Panics
	///
	/// Panics if this thread already holds an un-lent guard for this mutex. The
	/// reentrant `raw` lock let the new acquisition through, but a second guard
	/// would alias `&mut T`. Before panicking, that extra acquisition of `raw` is
	/// released again. Otherwise its recursion count would stay raised, and the
	/// existing guard's drop could never fully unlock the mutex.
	#[track_caller]
	#[inline]
	fn guard(&self) -> LendableMutexGuard<'_, R, G, T> {
		if !self.threadlock.try_lock() {
			// SAFETY: every caller of `guard` has just acquired `raw` on this
			// thread; this releases exactly that one acquisition.
			unsafe { self.raw.unlock() };
			panic!("LendableMutex already locked on this thread");
		}
		LendableMutexGuard {
			mutex: self,
			marker: PhantomData,
		}
	}

	/// Releases one acquisition of the lock without going through a guard.
	///
	/// # Safety
	///
	/// The current thread must hold the lock, and the guard that owned this
	/// acquisition must never be used or dropped afterwards (for example, it
	/// was leaked with `mem::forget`).
	#[track_caller]
	pub unsafe fn force_unlock(&self) {
		self.threadlock.unlock(); // unlock this first so other threads trying to lock don't get a panic
		self.raw.unlock();
	}

	/// Returns the underlying reentrant raw mutex.
	///
	/// # Safety
	///
	/// Locking or unlocking through the returned reference bypasses the
	/// per-thread guard flag. The caller must not use it to release an
	/// acquisition that a live guard still owns.
	pub unsafe fn raw(&self) -> &RawReentrantMutex<R, G> {
		&self.raw
	}

	/// Acquires the lock, blocking until no other thread holds it.
	///
	/// # Panics
	///
	/// Panics if the current thread already holds an un-lent guard for this mutex.
	#[track_caller]
	pub fn lock<'a>(&'a self) -> LendableMutexGuard<'a, R, G, T> {
		self.raw.lock();
		self.guard()
	}

	/// Attempts to acquire the lock without blocking.
	///
	/// Returns `None` if the raw lock could not be acquired.
	///
	/// # Panics
	///
	/// Like [`lock`](Self::lock), panics if the current thread already holds an
	/// un-lent guard for this mutex.
	#[track_caller]
	pub fn try_lock<'a>(&'a self) -> Option<LendableMutexGuard<'a, R, G, T>> {
		if self.raw.try_lock() {
			Some(self.guard())
		} else {
			None
		}
	}
}

/// RAII guard giving exclusive access to the data of a [`LendableMutex`].
///
/// Dropping the guard releases the lock. `LendableMutexGuard::lend` lets code
/// on the same thread re-lock the mutex while this guard is set aside.
#[must_use = "if unused the LendableMutex will immediately unlock"]
pub struct LendableMutexGuard<'a, R: RawMutex, G: GetThreadId, T: ?Sized> {
	mutex: &'a LendableMutex<R, G, T>,
	// `&'a mut T` makes the guard act like a unique borrow of the data;
	// `R::GuardMarker` is lock_api's marker deciding whether the guard may be `Send`.
	marker: PhantomData<(&'a mut T, R::GuardMarker)>,
}

impl<'a, R: RawMutex, G: GetThreadId, T: ?Sized> Deref for LendableMutexGuard<'a, R, G, T> {
	type Target = T;

	/// Shared access to the protected data.
	fn deref(&self) -> &T {
		let cell = &self.mutex.data;
		// SAFETY: this guard holds the lock, so no other thread can access the
		// data, and the per-thread flag stops a second un-lent guard on this thread.
		unsafe { &*cell.get() }
	}
}

impl<'a, R: RawMutex, G: GetThreadId, T: ?Sized> DerefMut for LendableMutexGuard<'a, R, G, T> {
	/// Exclusive access to the protected data.
	fn deref_mut(&mut self) -> &mut T {
		let cell = &self.mutex.data;
		// SAFETY: the guard holds the lock and `&mut self` makes this the only
		// live borrow obtained through it.
		unsafe { &mut *cell.get() }
	}
}

impl<'a, R: RawMutex, G: GetThreadId, T: ?Sized + Debug> Debug for LendableMutexGuard<'a, R, G, T> {
	/// Formats the guarded value with its own `Debug` implementation.
	fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
		let value: &T = self;
		Debug::fmt(value, f)
	}
}

impl<'a, R: RawMutex, G: GetThreadId, T: ?Sized + Display> Display
	for LendableMutexGuard<'a, R, G, T>
{
	/// Formats the guarded value with its own `Display` implementation.
	fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
		let value: &T = self;
		Display::fmt(value, f)
	}
}

impl<'a, R: RawMutex, G: GetThreadId, T: ?Sized> Drop for LendableMutexGuard<'a, R, G, T> {
	/// Releases the acquisition of the lock owned by this guard.
	fn drop(&mut self) {
		let mutex = self.mutex;
		// SAFETY: a live guard means this thread holds the lock; the guard is
		// being destroyed, so its access is never used again.
		unsafe { mutex.force_unlock() };
	}
}

impl<'a, R: RawMutex, G: GetThreadId, T: ?Sized> LendableMutexGuard<'a, R, G, T> {
	/// Returns the mutex this guard was obtained from.
	pub fn mutex(s: &Self) -> &'a LendableMutex<R, G, T> {
		s.mutex
	}

	/// Returns the reentrant raw mutex backing this guard.
	pub fn raw_mutex(s: &Self) -> &'a RawReentrantMutex<R, G> {
		let owner: &'a LendableMutex<R, G, T> = s.mutex;
		&owner.raw
	}

	/// Runs `f` with this guard's claim on the data temporarily given up.
	///
	/// The raw lock stays held by this thread throughout, but the per-thread
	/// flag is cleared, so `f` may lock the same mutex again. The flag is set
	/// again when `f` returns or unwinds. The `&mut` borrow keeps this guard
	/// unusable meanwhile.
	///
	/// # Panics
	///
	/// Panics when re-taking the flag if `f` left a guard for this mutex alive
	/// on this thread.
	pub fn lend(s: &mut Self, f: impl FnOnce()) {
		let threadlock = &s.mutex.threadlock;
		threadlock.unlock();
		let _relock = defer::defer(move || threadlock.lock());
		f();
	}

	/// Narrows the guard to a component of the data, keeping the lock held.
	#[inline]
	pub fn map<U: ?Sized, F>(s: Self, f: F) -> MappedLendableMutexGuard<'a, R, G, U>
	where
		F: FnOnce(&mut T) -> &mut U,
	{
		let owner: &'a LendableMutex<R, G, T> = s.mutex;
		// SAFETY: `s` holds the lock, giving exclusive access to the data.
		let projected: *mut U = f(unsafe { &mut *owner.data.get() });
		// The lock acquisition now belongs to the mapped guard; `s` must not release it.
		mem::forget(s);
		MappedLendableMutexGuard {
			raw: &owner.raw,
			threadlock: &owner.threadlock,
			data: projected,
			marker: PhantomData,
		}
	}

	/// Like `map`, but `f` may decline, in which case the original guard is returned.
	#[inline]
	pub fn try_map<U: ?Sized, F>(s: Self, f: F) -> Result<MappedLendableMutexGuard<'a, R, G, U>, Self>
	where
		F: FnOnce(&mut T) -> Option<&mut U>,
	{
		let owner: &'a LendableMutex<R, G, T> = s.mutex;
		// SAFETY: `s` holds the lock, giving exclusive access to the data.
		match f(unsafe { &mut *owner.data.get() }) {
			None => Err(s),
			Some(projected) => {
				// The lock acquisition now belongs to the mapped guard.
				mem::forget(s);
				Ok(MappedLendableMutexGuard {
					raw: &owner.raw,
					threadlock: &owner.threadlock,
					data: projected,
					marker: PhantomData,
				})
			}
		}
	}
}

/// Guard for a projection of a [`LendableMutex`]'s data, produced by the
/// `map`/`try_map` functions of the guard types.
///
/// It keeps the lock held like the guard it was created from and releases it on drop.
#[must_use = "if unused the LendableMutex will immediately unlock"]
pub struct MappedLendableMutexGuard<'a, R: RawMutex, G: GetThreadId, T: ?Sized> {
	raw: &'a RawReentrantMutex<R, G>,
	threadlock: &'a Threadlock,
	// Points into the locked data; only dereferenced while this guard holds the lock.
	data: *mut T,
	marker: PhantomData<(&'a mut T, R::GuardMarker)>,
}

impl<'a, R: RawMutex, G: GetThreadId, T: ?Sized> Deref for MappedLendableMutexGuard<'a, R, G, T> {
	type Target = T;

	/// Shared access to the projected data.
	fn deref(&self) -> &T {
		let ptr = self.data;
		// SAFETY: `ptr` came from a `&mut` into the locked data, and this guard
		// still holds the lock.
		unsafe { &*ptr }
	}
}

impl<'a, R: RawMutex, G: GetThreadId, T: ?Sized> DerefMut for MappedLendableMutexGuard<'a, R, G, T> {
	/// Exclusive access to the projected data.
	fn deref_mut(&mut self) -> &mut T {
		let ptr = self.data;
		// SAFETY: the guard holds the lock and `&mut self` makes this the only
		// live borrow obtained through it.
		unsafe { &mut *ptr }
	}
}

impl<'a, R: RawMutex, G: GetThreadId, T: ?Sized + Debug> Debug for MappedLendableMutexGuard<'a, R, G, T> {
	/// Formats the projected value with its own `Debug` implementation.
	fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
		let value: &T = self;
		Debug::fmt(value, f)
	}
}

impl<'a, R: RawMutex, G: GetThreadId, T: ?Sized + Display> Display
	for MappedLendableMutexGuard<'a, R, G, T>
{
	/// Formats the projected value with its own `Display` implementation.
	fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
		let value: &T = self;
		Display::fmt(value, f)
	}
}

impl<'a, R: RawMutex, G: GetThreadId, T: ?Sized> Drop for MappedLendableMutexGuard<'a, R, G, T> {
	/// Releases the acquisition of the lock owned by this guard.
	fn drop(&mut self) {
		let (raw, threadlock) = (self.raw, self.threadlock);
		// Clear the per-thread flag while `raw` is still held. If `raw` were
		// released first, another thread could acquire it and find the flag set.
		threadlock.unlock();
		// SAFETY: this guard owns one acquisition of `raw` made by the current thread.
		unsafe { raw.unlock() };
	}
}

impl<'a, R: RawMutex, G: GetThreadId, T: ?Sized> MappedLendableMutexGuard<'a, R, G, T> {
	pub fn raw_mutex(s: &Self) -> &'a RawReentrantMutex<R, G> {
		s.raw
	}
	pub fn lend(s: &mut Self, f: impl FnOnce()) {
		s.threadlock.unlock();
		let _defer = defer::defer(|| {
			s.threadlock.lock();
		});
		f();
	}

	#[inline]
	pub fn map<U: ?Sized, F>(s: Self, f: F) -> MappedLendableMutexGuard<'a, R, G, U>
	where
		F: FnOnce(&mut T) -> &mut U,
	{
		let raw = s.raw;
		let threadlock = s.threadlock;
		let data = f(unsafe { &mut *s.data });
		mem::forget(s);
		MappedLendableMutexGuard {
			raw,
			threadlock,
			data,
			marker: PhantomData,
		}
	}

	#[inline]
	pub fn try_map<U: ?Sized, F>(s: Self, f: F) -> Result<MappedLendableMutexGuard<'a, R, G, U>, Self>
	where
		F: FnOnce(&mut T) -> Option<&mut U>,
	{
		let raw = s.raw;
		let threadlock = s.threadlock;
		let data = f(unsafe { &mut *s.data });
		if let Some(x) = data {
			mem::forget(s);
			Ok(MappedLendableMutexGuard {
				raw,
				threadlock,
				data: x,
				marker: PhantomData,
			})
		} else {
			Err(s)
		}
	}
}

/// [`LendableMutex`] backed by parking_lot's raw mutex and thread-id source.
pub type PlLendableMutex<T> =
	LendableMutex<parking_lot::RawMutex, parking_lot::RawThreadId, T>;
/// Guard type returned by locking a [`PlLendableMutex`].
pub type PlLendableMutexGuard<'g, T> =
	LendableMutexGuard<'g, parking_lot::RawMutex, parking_lot::RawThreadId, T>;
/// Mapped guard type produced by mapping a [`PlLendableMutexGuard`].
pub type PlMappedLendableMutexGuard<'g, T> =
	MappedLendableMutexGuard<'g, parking_lot::RawMutex, parking_lot::RawThreadId, T>;

#[cfg(test)]
mod tests {
	use super::*;
	use std::sync::Arc;
	use std::thread;

	/// Many threads incrementing through the mutex must not lose updates.
	#[test]
	fn basic_mutex() {
		let counter = Arc::new(PlLendableMutex::new(0));
		let workers: Vec<_> = (0..100)
			.map(|_| {
				let counter = Arc::clone(&counter);
				thread::spawn(move || {
					let mut guard = counter.lock();
					*guard += 1;
				})
			})
			.collect();
		workers.into_iter().for_each(|w| w.join().unwrap());
		assert_eq!(*counter.lock(), 100);
	}

	/// While a guard is lent, the same thread can re-lock the mutex, and other
	/// threads still cannot get in between.
	#[test]
	fn stays_locked() {
		let counter = Arc::new(PlLendableMutex::new(0));
		let workers: Vec<_> = (0..100)
			.map(|_| {
				let counter = Arc::clone(&counter);
				thread::spawn(move || {
					let mut outer = counter.lock();
					let before = *outer;
					*outer += 1;
					PlLendableMutexGuard::lend(&mut outer, || {
						// Give other threads a chance to (wrongly) grab the lock.
						thread::yield_now();
						let mut inner = counter.lock();
						assert_eq!(*inner, before + 1);
						*inner += 1;
					});
					assert_eq!(*outer, before + 2);
				})
			})
			.collect();
		for worker in workers {
			worker.join().unwrap();
		}
		assert_eq!(*counter.lock(), 200);
	}
}
