/*-
 * cdns-rs - a simple sync/async DNS query library
 * Copyright (C) 2020  Aleksandr Morozov, RELKOM s.r.o
 * Copyright (C) 2021-2022  Aleksandr Morozov
 * 
 * This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 *  file, You can obtain one at https://mozilla.org/MPL/2.0/.
 */

use std::fmt;
use std::cell::UnsafeCell;
use std::ops::{Deref, DerefMut};
use std::sync::atomic::{AtomicBool, Ordering};

use crossbeam_utils::Backoff;

pub use super::poison;

/// A spin-lock based mutual-exclusion primitive with lock poisoning.
///
/// Its API (`lock`, `try_lock`, `is_poisoned`, `into_inner`, `get_mut`, returning
/// `LockResult`/`TryLockResult`) is modelled on `std::sync::Mutex`, but waiters
/// busy-wait instead of being parked by the OS.
pub struct Mutex<T: ?Sized> 
{
    /// `true` while the lock is held (cleared again by `MutexGuard::drop`).
    c_locked: AtomicBool,
    /// Poison state; consulted via `borrow()`/`get()` and updated through
    /// `done()` when a guard is dropped.
    poison: poison::Flag,
    /// The protected value; only reached through a `MutexGuard` or `&mut self`.
    data: UnsafeCell<T>,
}

// SAFETY: the value is only ever reached by one thread at a time (under the
// lock or via `&mut self`), so the mutex can be moved or shared across threads
// whenever `T` itself may be sent between threads.
unsafe impl<T> Send for Mutex<T> where T: ?Sized + Send {}
unsafe impl<T> Sync for Mutex<T> where T: ?Sized + Send {}

/// Guard returned by `Mutex::lock` / `Mutex::try_lock`.
///
/// The protected value is reachable through `Deref`/`DerefMut`; dropping the
/// guard releases the lock.
pub struct MutexGuard<'a, T: ?Sized + 'a> 
{
    /// The mutex this guard keeps locked.
    lock: &'a Mutex<T>,
    /// Poison bookkeeping obtained from `poison.borrow()` at lock time and
    /// handed back to `poison.done()` when the guard is dropped.
    poison: poison::Guard,
}

// NOTE: unlike `std::sync::MutexGuard`, this guard is not declared `!Send`
// (negative impls are unstable); its `Send`-ness is auto-derived from the
// fields. Unlocking is a single atomic store, so there is no owner-thread
// requirement as with an OS mutex.

// SAFETY: a shared `&MutexGuard` only exposes `&T` (through `Deref`), so it may
// be shared between threads exactly when `T: Sync`.
unsafe impl<T> Sync for MutexGuard<'_, T> where T: ?Sized + Sync {}

impl<'mutex, T: ?Sized> MutexGuard<'mutex, T> 
{
    /// Wraps an already-acquired `lock` in a guard and reports its poison state.
    ///
    /// # Safety
    ///
    /// The caller must have just switched `lock.c_locked` from `false` to `true`,
    /// i.e. it owns the lock. The returned guard clears that flag unconditionally
    /// when dropped.
    unsafe fn new(lock: &'mutex Mutex<T>) -> poison::LockResult<MutexGuard<'mutex, T>> 
    {
        let wrap = |poison_guard: poison::Guard| MutexGuard { lock, poison: poison_guard };

        poison::map_result(lock.poison.borrow(), wrap)
    }
}

impl<T: ?Sized> Deref for MutexGuard<'_, T> 
{
    type Target = T;

    /// Shared access to the value protected by the locked mutex.
    fn deref(&self) -> &T 
    {
        let cell = &self.lock.data;

        // SAFETY: a guard exists only while its mutex is locked, so no other
        // guard can hand out a `&mut T`; the returned borrow is tied to `&self`.
        unsafe { &*cell.get() }
    }
}

impl<T: ?Sized> DerefMut for MutexGuard<'_, T> 
{
    /// Exclusive access to the value protected by the locked mutex.
    fn deref_mut(&mut self) -> &mut T 
    {
        let cell = &self.lock.data;

        // SAFETY: holding the lock makes this guard the only accessor, and the
        // `&mut self` receiver prevents aliasing through the guard itself.
        unsafe { &mut *cell.get() }
    }
}

impl<T: ?Sized> Drop for MutexGuard<'_, T> 
{
    #[inline]
    fn drop(&mut self) 
    {
        self.lock.poison.done(&self.poison);

        // unlock 
        self.lock.c_locked.store(false, Ordering::Release);
    }
}

impl<T: ?Sized + fmt::Debug> fmt::Debug for MutexGuard<'_, T> 
{
    /// Formats the guarded value directly, exactly as `T`'s `Debug` would.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result 
    {
        let inner: &T = self;
        fmt::Debug::fmt(inner, f)
    }
}

impl<T: ?Sized + fmt::Display> fmt::Display for MutexGuard<'_, T> 
{
    /// Formats the guarded value with its own `Display` implementation.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result 
    {
        fmt::Display::fmt(&**self, f)
    }
}

pub fn guard_poison<'a, T: ?Sized>(guard: &MutexGuard<'a, T>) -> &'a poison::Flag 
{
    &guard.lock.poison
}

impl<T> Mutex<T> 
{
    /// Creates a new mutex in the unlocked state, wrapping `t`.
    pub fn new(t: T) -> Mutex<T> 
    {
        let data = UnsafeCell::new(t);

        Self 
        {
            c_locked: AtomicBool::new(false),
            poison: poison::Flag::new(),
            data,
        }
    }

    /// Releases the lock by consuming `guard`; equivalent to dropping it.
    pub fn unlock(guard: MutexGuard<'_, T>) 
    {
        std::mem::drop(guard);
    }
}

impl<T: ?Sized> Mutex<T> 
{
    /// Acquires the mutex, busy-waiting with `Backoff::snooze` until it is free.
    ///
    /// The returned guard releases the lock when dropped.
    ///
    /// # Errors
    ///
    /// Returns `Err` if the mutex is poisoned. The error is produced by
    /// `poison::map_result` and wraps the guard, so the lock is held either way.
    pub fn lock(&self) -> poison::LockResult<MutexGuard<'_, T>> 
    {
        let backoff = Backoff::new();

        // Test-and-test-and-set: attempt the CAS, and while the lock is taken
        // wait on a plain load instead of issuing repeated writes, so waiters do
        // not keep pulling the cache line away from the current holder.
        while self
            .c_locked
            .compare_exchange_weak(false, true, Ordering::Acquire, Ordering::Relaxed)
            .is_err()
        {
            while self.c_locked.load(Ordering::Relaxed)
            {
                backoff.snooze();
            }
        }

        // SAFETY: the successful CAS above moved `c_locked` from `false` to
        // `true`, so this thread owns the lock; the guard releases it on drop.
        unsafe { MutexGuard::new(self) }
    }

    /// Attempts to acquire the mutex without waiting.
    ///
    /// # Errors
    ///
    /// * `TryLockError::WouldBlock` if the mutex is currently held.
    /// * `TryLockError::Poisoned` if the lock was acquired but the mutex is
    ///   poisoned. The error carries the guard, so the lock is held.
    pub fn try_lock(&self) -> poison::TryLockResult<MutexGuard<'_, T>> 
    {
        // A single strong CAS never waits. It either takes a free lock or
        // reports `WouldBlock`. It is strong rather than weak so that a free
        // lock is never spuriously reported as busy.
        if self
            .c_locked
            .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
            .is_err()
        {
            return Err(poison::TryLockError::WouldBlock);
        }

        // SAFETY: the CAS above moved `c_locked` from `false` to `true`, so this
        // thread owns the lock; the guard releases it on drop. `?` converts a
        // poisoned `LockResult` into `TryLockError::Poisoned`.
        let guard = unsafe { MutexGuard::new(self) }?;
        Ok(guard)
    }

    /// Returns the current state of the poison flag.
    pub fn is_poisoned(&self) -> bool 
    {
        self.poison.get()
    }

    /// Consumes the mutex and returns the protected value.
    ///
    /// # Errors
    ///
    /// Returns `Err` wrapping the value if the mutex is poisoned. The error
    /// comes from `poison::map_result`.
    pub fn into_inner(self) -> poison::LockResult<T>
    where T: Sized,
    {
        let data = self.data.into_inner();
        poison::map_result(self.poison.borrow(), |_| data)
    }

    /// Returns a mutable reference to the protected value.
    ///
    /// No locking is needed: `&mut self` already guarantees exclusive access.
    ///
    /// # Errors
    ///
    /// Returns `Err` wrapping the reference if the mutex is poisoned.
    pub fn get_mut(&mut self) -> poison::LockResult<&mut T> 
    {
        let data = self.data.get_mut();
        poison::map_result(self.poison.borrow(), |_| data)
    }
}


impl<T> From<T> for Mutex<T> 
{
    fn from(t: T) -> Self 
    {
        Mutex::new(t)
    }
}

impl<T: ?Sized + Default> Default for Mutex<T> 
{
    fn default() -> Mutex<T> 
    {
        Mutex::new(Default::default())
    }
}

impl<T: ?Sized + fmt::Debug> fmt::Debug for Mutex<T> 
{
    /// Formats as `Mutex { data: .., poisoned: .., .. }`.
    ///
    /// The value is shown only if `try_lock` succeeds or reports poisoning;
    /// otherwise `<locked>` is printed in its place.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result 
    {
        /// Stand-in shown when the lock is currently held elsewhere.
        struct LockedPlaceholder;

        impl fmt::Debug for LockedPlaceholder 
        {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result 
            {
                f.write_str("<locked>")
            }
        }

        let mut builder = f.debug_struct("Mutex");

        // Any guard obtained here is bound inside its arm and dropped at the end
        // of that arm, so the lock is released before the poison flag is read.
        match self.try_lock() 
        {
            Ok(guard) => builder.field("data", &&*guard),
            Err(poison::TryLockError::Poisoned(err)) => builder.field("data", &&**err.get_ref()),
            Err(poison::TryLockError::WouldBlock) => builder.field("data", &LockedPlaceholder),
        };

        builder
            .field("poisoned", &self.poison.get())
            .finish_non_exhaustive()
    }
}

// TODO: expand test coverage for this module's Mutex.

/// Sequential lock/push/unlock smoke test that prints acquisition timings.
///
/// NOTE(review): this exercises `std::sync::Mutex`, not this module's spin
/// `Mutex`. Presumably it is kept as a timing baseline for comparison; confirm.
#[test]
fn test1_rw()
{
    use std::sync::Mutex as StdMutex;
    use std::time::Instant;

    let total = Instant::now();
    let shared: StdMutex<Vec<u32>> = StdMutex::new(vec![1]);

    // Time only the first acquisition; the result is unwrapped afterwards.
    let acquire = Instant::now();
    let first_attempt = shared.lock();
    let waited = acquire.elapsed();
    println!("write: {:?}", waited);

    // Each guard below is a temporary, released at the end of its statement.
    first_attempt.unwrap().push(2);
    shared.lock().unwrap().push(3);

    let reader = shared.lock().unwrap();
    println!("finished: {:?}", total.elapsed());

    reader.iter().for_each(|item| println!("{}", item));
    assert_eq!(*reader, [1, 2, 3]);
}

/// Contention test: the main thread holds the lock while two spawned threads
/// wait to push `2` and `3`. After the main thread releases it, both must get
/// in, and the vector must hold the initial `1` followed by `2` and `3`.
#[test]
fn test2()
{
    use std::thread;
    use std::sync::Arc;
    use std::time::Instant;

    let start1 = Instant::now();

    let rw: Arc<Mutex<Vec<u32>>> = Arc::new(Mutex::new(vec![1]));

    let c_rw = Arc::clone(&rw);

    let c_rw2 = Arc::clone(&rw);
    
    let r = rw.lock().unwrap();

    let handler = thread::spawn(move || {
        println!("waiting to write 2");
        let mut w = c_rw.lock().unwrap();
        w.push(2);

        println!("wrote 2");
        // The lock is held here, so the Debug output shows `<locked>`.
        println!("2:> {:?}", c_rw);
        drop(w);
    });

    // Stagger the spawns slightly. This does NOT decide which waiter wins once
    // the lock is released: the spin lock is not FIFO.
    std::thread::sleep(std::time::Duration::from_millis(1));
    let handler2 = thread::spawn(move || {
        println!("waiting to write 3");
        let mut w = c_rw2.lock().unwrap();
        w.push(3);
        
        println!("wrote 3");

        println!("3:> {:?}", c_rw2);
        drop(w);
    });
    
    std::thread::sleep(std::time::Duration::from_millis(4));

    for i in r.iter()
    {
        println!("{}", i);
    }

    println!("1:> {:?}", rw);

    println!("releasing read!");

    drop(r);

    handler.join().unwrap();
    handler2.join().unwrap();

    let start = Instant::now();
    let r = rw.lock().unwrap();
    let el = start.elapsed();
    println!("read: {:?}", el);
    
    println!("pre 1:> {:?}", rw);

    for i in r.iter()
    {
        println!("{}", i);
    }

    // Either waiter may have acquired the lock first, so only the position of
    // the initial `1` is fixed; compare the pushed tail order-insensitively.
    let mut observed = r.to_vec();
    observed[1..].sort_unstable();
    let v = vec![1, 2, 3];
    assert_eq!(v, observed);

    drop(r);

    let el = start1.elapsed();
    println!("finished: {:?}", el);

    println!("1:> {:?}", rw);
}


