// Copyright 2021 MaidSafe.net limited.
//
// This SAFE Network Software is licensed to you under The General Public License (GPL), version 3.
// Unless required by applicable law or agreed to in writing, the SAFE Network Software distributed
// under the GPL Licence is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. Please review the Licences for the specific language governing
// permissions and limitations relating to use of the SAFE Network Software.

mod item;

use self::item::Item;
use itertools::Itertools;
use std::collections::BTreeMap;
use std::hash::Hash;
use std::time::Duration;
use tokio::sync::RwLock;

/// An async key/value cache whose entries can be bounded by count, by age, or both.
///
/// Entries live in a [`BTreeMap`] behind a tokio [`RwLock`], so every method takes `&self`
/// and mutation goes through the lock. Expired entries are hidden from [`Cache::get`] and
/// purged whenever [`Cache::set`] runs. When `set` leaves more than `capacity` entries
/// behind, it evicts the entries with the largest `elapsed()` first.
#[derive(Debug)]
pub struct Cache<T, V>
where
    T: Hash + Eq,
{
    /// Stored entries keyed by `T`; each `Item` wraps the value (`object`) together with the
    /// timing data it uses for `expired()` / `elapsed()`.
    items: RwLock<BTreeMap<T, Item<V>>>,
    /// Default lifetime used by `set` when no per-call duration is given; `None` means such
    /// entries are inserted without a duration.
    item_duration: Option<Duration>,
    /// Maximum number of entries kept after each `set`.
    capacity: usize,
}

#[allow(clippy::len_without_is_empty)]
impl<T, V> Cache<T, V>
where
    T: Ord + Hash,
{
    /// Creating capacity based `Cache`.
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            items: RwLock::new(BTreeMap::new()),
            item_duration: None,
            capacity,
        }
    }

    /// Creating time based `Cache`.
    pub fn with_expiry_duration(duration: Duration) -> Self {
        Self {
            items: RwLock::new(BTreeMap::new()),
            item_duration: Some(duration),
            capacity: usize::MAX,
        }
    }

    /// Creating dual-feature capacity and time based `Cache`.
    pub fn with_expiry_duration_and_capacity(duration: Duration, capacity: usize) -> Self {
        Self {
            items: RwLock::new(BTreeMap::new()),
            item_duration: Some(duration),
            capacity,
        }
    }

    ///
    pub async fn len(&self) -> usize {
        self.items.read().await.len()
    }

    ///
    pub async fn is_empty(&self) -> bool {
        self.items.read().await.is_empty()
    }

    ///
    pub async fn count<P>(&self, predicate: P) -> usize
    where
        P: FnMut(&(&T, &Item<V>)) -> bool,
    {
        self.items.read().await.iter().filter(predicate).count()
    }

    ///
    pub async fn get(&self, key: &T) -> Option<V>
    where
        T: Eq + Hash,
        V: Clone,
    {
        self.items
            .read()
            .await
            .get(key)
            .filter(|&item| !item.expired())
            .map(|k| k.object.clone())
    }

    ///
    pub async fn set(&self, key: T, value: V, custom_duration: Option<Duration>) -> Option<V>
    where
        T: Eq + Hash + Clone,
    {
        let replaced = self
            .items
            .write()
            .await
            .insert(
                key,
                Item::new(value, custom_duration.or(self.item_duration)),
            )
            .map(|item| item.object);
        self.remove_expired().await;
        self.drop_excess().await;
        replaced
    }

    ///
    pub async fn remove_expired(&self) {
        let read_items = self.items.read().await;
        let expired_keys: Vec<_> = read_items
            .iter()
            .filter(|(_, item)| item.expired())
            .map(|(key, _)| key)
            .collect();

        for key in expired_keys {
            let _ = self.items.write().await.remove(key);
        }
    }

    /// removes keys beyond capacity
    async fn drop_excess(&self) {
        let len = self.len().await;
        if len > self.capacity {
            let excess = len - self.capacity;
            let read_items = self.items.read().await;
            let mut items = read_items.iter().collect_vec();

            // reversed sort
            items.sort_by(|(_, item_a), (_, item_b)| item_b.elapsed().cmp(&item_a.elapsed()));

            // take the excess
            for (key, _) in items.iter().take(excess) {
                let _ = self.items.write().await.remove(key);
            }
        }
    }

    ///
    pub async fn remove(&self, key: &T) -> Option<V>
    where
        T: Eq + Hash,
    {
        self.items.write().await.remove(key).map(|item| item.object)
    }

    ///
    pub async fn clear(&self) {
        self.items.write().await.clear()
    }
}

#[cfg(test)]
mod tests {
    use super::Cache;
    use std::time::Duration;

    const KEY: i8 = 0;
    const VALUE: &str = "VALUE";

    #[tokio::test]
    async fn set_and_get_value_with_default_duration() {
        let cache = Cache::with_expiry_duration(Duration::from_secs(2));
        let _ = cache.set(KEY, VALUE, None).await;
        let value = cache.get(&KEY).await;
        assert_eq!(value, Some(VALUE), "value was not found in cache");
    }

    #[tokio::test]
    async fn set_and_get_value_without_duration() {
        let cache = Cache::with_capacity(usize::MAX);
        let _ = cache.set(KEY, VALUE, None).await;
        let value = cache.get(&KEY).await;
        assert_eq!(value, Some(VALUE), "value was not found in cache");
    }

    #[tokio::test]
    async fn set_and_get_value_with_custom_duration() {
        let cache = Cache::with_expiry_duration(Duration::from_secs(0));
        let _ = cache.set(KEY, VALUE, Some(Duration::from_secs(2))).await;
        let value = cache.get(&KEY).await;
        assert_eq!(value, Some(VALUE), "value was not found in cache");
    }

    #[tokio::test]
    async fn set_do_not_get_expired_value() {
        let cache = Cache::with_expiry_duration(Duration::from_secs(0));
        let _ = cache.set(KEY, VALUE, None).await;
        let value = cache.get(&KEY).await;
        assert!(value.is_none(), "found expired value in cache");
    }

    #[tokio::test]
    async fn set_replace_existing_value() {
        const NEW_VALUE: &str = "NEW_VALUE";
        let cache = Cache::with_expiry_duration(Duration::from_secs(2));
        let _ = cache.set(KEY, VALUE, None).await;
        let _ = cache.set(KEY, NEW_VALUE, None).await;
        let value = cache.get(&KEY).await;
        assert_eq!(value, Some(NEW_VALUE), "value was not found in cache");
    }

    #[tokio::test]
    async fn set_purges_expired_items() {
        const OTHER_KEY: i8 = 1;
        let cache = Cache::with_expiry_duration(Duration::from_secs(2));
        let _ = cache.set(KEY, VALUE, Some(Duration::from_secs(0))).await;
        let _ = cache.set(OTHER_KEY, VALUE, None).await;
        assert_eq!(cache.len().await, 1, "expired item was not purged by set");
        assert_eq!(
            cache.get(&OTHER_KEY).await,
            Some(VALUE),
            "live item was not found in cache"
        );
    }

    #[tokio::test]
    async fn set_evicts_items_beyond_capacity() {
        let cache = Cache::with_capacity(2);
        for key in 0..5_i8 {
            let _ = cache.set(key, VALUE, None).await;
        }
        assert_eq!(cache.len().await, 2, "cache grew beyond its capacity");
    }

    #[tokio::test]
    async fn set_with_zero_capacity_keeps_nothing() {
        let cache = Cache::with_capacity(0);
        let _ = cache.set(KEY, VALUE, None).await;
        assert!(
            cache.is_empty().await,
            "cache with zero capacity retained an item"
        );
    }

    #[tokio::test]
    async fn remove_expired_item() {
        let cache = Cache::with_expiry_duration(Duration::from_secs(0));
        let _ = cache.set(KEY, VALUE, None).await;
        cache.remove_expired().await;
        assert!(
            cache.items.read().await.get(&KEY).is_none(),
            "found expired value in cache"
        );
    }

    #[tokio::test]
    async fn remove_expired_do_not_remove_not_expired_item() {
        let cache = Cache::with_expiry_duration(Duration::from_secs(2));
        let _ = cache.set(KEY, VALUE, None).await;
        cache.remove_expired().await;
        assert!(
            cache.items.read().await.get(&KEY).is_some(),
            "could not find not expired item in cache"
        );
    }

    #[tokio::test]
    async fn clear_not_expired_item() {
        let cache = Cache::with_expiry_duration(Duration::from_secs(2));
        let _ = cache.set(KEY, VALUE, None).await;
        cache.clear().await;
        assert!(
            cache.items.read().await.get(&KEY).is_none(),
            "found item in cache"
        );
    }

    #[tokio::test]
    async fn remove_removes_existing_item() {
        let cache = Cache::with_expiry_duration(Duration::from_secs(2));
        let _ = cache.set(KEY, VALUE, None).await;
        assert!(
            cache.remove(&KEY).await.is_some(),
            "none returned from removing existing value"
        );
        assert!(
            cache.items.read().await.get(&KEY).is_none(),
            "found not expired item in cache"
        );
    }

    #[tokio::test]
    async fn remove_return_none_if_not_found() {
        let cache: Cache<i8, &str> = Cache::with_expiry_duration(Duration::from_secs(2));
        assert!(
            cache.remove(&KEY).await.is_none(),
            "some value was returned from remove"
        );
    }
}
