//
// Copyright (c) 2022 Oleg Lelenkov <o.lelenkov@gmail.com>
// Distributed under terms of the BSD 3-Clause license.
//

use std::borrow::Cow;

use serde::de::{self, Visitor, IntoDeserializer};

use crate::serialize::error::{UniNodeSerError, Result};

/// A [`serde::Deserializer`] for a single map key stored as text.
///
/// The key may be borrowed from the input (`Cow::Borrowed`) or owned
/// (`Cow::Owned`); borrowed keys are passed to visitors without copying.
#[derive(Debug)]
pub struct MapKeyDeserializer<'de> {
    key: Cow<'de, str>,
}

impl<'de> MapKeyDeserializer<'de> {
    /// Wraps `key` so it can be deserialized into any supported key type.
    #[must_use]
    pub fn new(key: Cow<'de, str>) -> Self {
        Self { key }
    }
}

// Expands to a `Deserializer` method named `$method`. The method first tries
// to `parse()` the key into whatever type `$visit` accepts (the target type is
// inferred from the visitor method). If parsing fails, the raw key is handed to
// the visitor as a string instead: borrowed keys via `visit_borrowed_str`,
// owned keys via `visit_string`. This lets the visitor either accept the string
// or report its own type error.
macro_rules! deserialize_integer_key {
    ($method:ident => $visit:ident) => {
        fn $method<V>(self, visitor: V) -> Result<V::Value>
        where
            V: Visitor<'de>,
        {
            if let Ok(parsed) = self.key.parse() {
                return visitor.$visit(parsed);
            }
            match self.key {
                Cow::Owned(owned) => visitor.visit_string(owned),
                Cow::Borrowed(borrowed) => visitor.visit_borrowed_str(borrowed),
            }
        }
    };
}

/// Deserializes a textual map key into scalars, strings, options, newtypes
/// and unit enum variants.
impl<'de> de::Deserializer<'de> for MapKeyDeserializer<'de> {
    type Error = UniNodeSerError;

    // Numeric and boolean hints: the key text is parsed into the type the
    // matching `visit_*` method expects. On a parse failure the key is passed
    // on as a string, so the visitor can accept it or report a type error.
    deserialize_integer_key!(deserialize_bool => visit_bool);

    deserialize_integer_key!(deserialize_i8 => visit_i8);

    deserialize_integer_key!(deserialize_i16 => visit_i16);

    deserialize_integer_key!(deserialize_i32 => visit_i32);

    deserialize_integer_key!(deserialize_i64 => visit_i64);

    // Without these, serde's default `deserialize_i128`/`deserialize_u128`
    // unconditionally return an error, so 128-bit integer keys would fail.
    deserialize_integer_key!(deserialize_i128 => visit_i128);

    deserialize_integer_key!(deserialize_u8 => visit_u8);

    deserialize_integer_key!(deserialize_u16 => visit_u16);

    deserialize_integer_key!(deserialize_u32 => visit_u32);

    deserialize_integer_key!(deserialize_u64 => visit_u64);

    deserialize_integer_key!(deserialize_u128 => visit_u128);

    deserialize_integer_key!(deserialize_f32 => visit_f32);

    deserialize_integer_key!(deserialize_f64 => visit_f64);

    // Every remaining hint simply receives the key as a string.
    serde::forward_to_deserialize_any! {
        char str string bytes byte_buf unit unit_struct seq tuple
        tuple_struct map struct identifier ignored_any
    }

    /// Hands the key to the visitor as a string, without copying when the
    /// key is borrowed.
    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        match self.key {
            Cow::Borrowed(string) => visitor.visit_borrowed_str(string),
            Cow::Owned(string) => visitor.visit_string(string),
        }
    }

    /// A present key is always `Some`; the inner value is deserialized from
    /// the same key.
    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        visitor.visit_some(self)
    }

    /// Newtype wrappers are transparent: the wrapped value is deserialized
    /// from the key itself.
    fn deserialize_newtype_struct<V>(
        self, _name: &'static str, visitor: V,
    ) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        visitor.visit_newtype_struct(self)
    }

    /// The key text names an enum variant; it is resolved through serde's
    /// `IntoDeserializer` implementation for `Cow<str>`.
    fn deserialize_enum<V>(
        self, _name: &'static str, _variants: &'static [&'static str],
        visitor: V,
    ) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        visitor.visit_enum(self.key.into_deserializer())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;

    #[test]
    fn key_integers() {
        /// Asserts that the borrowed key `num` deserializes into `val`.
        fn assert_deserialize_int<'de, T>(num: &'de str, val: T)
        where
            T: Deserialize<'de> + std::fmt::Debug + std::cmp::PartialEq,
        {
            let de = MapKeyDeserializer::new(Cow::Borrowed(num));
            assert_eq!(T::deserialize(de).unwrap(), val)
        }

        assert_deserialize_int("8", 8i8);
        assert_deserialize_int("16", 16i16);
        assert_deserialize_int("32", 32i32);
        assert_deserialize_int("64", 64i64);
        assert_deserialize_int("-64", -64i64);
        assert_deserialize_int("8", 8u8);
        assert_deserialize_int("16", 16u16);
        assert_deserialize_int("32", 32u32);
        assert_deserialize_int("64", 64u64);
    }

    #[test]
    fn key_integer_owned() {
        let de = MapKeyDeserializer::new(Cow::Owned(String::from("42")));
        assert_eq!(u32::deserialize(de).unwrap(), 42);
    }

    #[test]
    fn key_integer_not_a_number() {
        // A non-numeric key falls back to a string, which `u32` rejects.
        let de = MapKeyDeserializer::new(Cow::Borrowed("key"));
        assert!(u32::deserialize(de).is_err());
    }

    #[test]
    fn key_integer_out_of_range() {
        // "300" does not fit in `u8`, so parsing fails and the string
        // fallback is rejected by the `u8` visitor.
        let de = MapKeyDeserializer::new(Cow::Borrowed("300"));
        assert!(u8::deserialize(de).is_err());
    }

    #[test]
    fn key_string() {
        let de = MapKeyDeserializer::new(Cow::Borrowed("key"));
        assert_eq!(String::deserialize(de).unwrap(), String::from("key"));
    }

    #[test]
    fn key_owned_string() {
        let de = MapKeyDeserializer::new(Cow::Owned(String::from("key")));
        assert_eq!(String::deserialize(de).unwrap(), String::from("key"));
    }

    #[test]
    fn key_borrowed_str() {
        // Borrowed keys are visited with `visit_borrowed_str`, so a
        // zero-copy `&str` can be produced.
        let de = MapKeyDeserializer::new(Cow::Borrowed("key"));
        assert_eq!(<&str>::deserialize(de).unwrap(), "key");
    }

    #[test]
    fn key_option() {
        let de = MapKeyDeserializer::new(Cow::Borrowed("key"));
        assert_eq!(
            Option::<String>::deserialize(de).unwrap(),
            Some(String::from("key"))
        );
    }

    #[test]
    fn key_newtype() {
        #[derive(Debug, Deserialize, PartialEq)]
        struct Key(u8);
        let de = MapKeyDeserializer::new(Cow::Borrowed("42"));
        assert_eq!(Key::deserialize(de).unwrap(), Key(42));
    }

    #[test]
    fn key_enum() {
        #[derive(Debug, Deserialize, PartialEq)]
        enum Key {
            One,
            Two,
        }
        let de = MapKeyDeserializer::new(Cow::Borrowed("One"));
        assert_eq!(Key::deserialize(de).unwrap(), Key::One);
        let de = MapKeyDeserializer::new(Cow::Borrowed("Two"));
        assert_eq!(Key::deserialize(de).unwrap(), Key::Two);
    }

    #[test]
    fn key_enum_unknown_variant() {
        #[derive(Debug, Deserialize, PartialEq)]
        enum Key {
            One,
        }
        let de = MapKeyDeserializer::new(Cow::Borrowed("Three"));
        assert!(Key::deserialize(de).is_err());
    }
}
