//
// Copyright (c) 2022 Oleg Lelenkov <o.lelenkov@gmail.com>
// Distributed under terms of the BSD 3-Clause license.
//

use std::borrow::Cow;

use serde::de::{self, Visitor, IntoDeserializer};
use serde::de::value::SeqDeserializer;

use crate::value::{UniNode, Object};
use crate::serialize::error::{UniNodeSerError, Result};
use super::mapkey::MapKeyDeserializer;

/// Presents the items of `iter` to `visitor` as a sequence.
///
/// After the visitor returns, the underlying `SeqDeserializer` is finished
/// with `end()`, so items the visitor left unconsumed produce an error
/// instead of being silently dropped.
fn visit_array<'de, T, I, V>(iter: I, visitor: V) -> Result<V::Value>
where
    V: Visitor<'de>,
    I: Iterator<Item = T>,
    T: IntoDeserializer<'de, UniNodeSerError>,
{
    let mut seq_access = SeqDeserializer::new(iter);
    let value = visitor.visit_seq(&mut seq_access)?;
    // `end` fails if any items remain; on success the visited value is kept.
    seq_access.end().map(|()| value)
}

fn visit_object<'de, V>(object: Object, visitor: V) -> Result<V::Value>
where
    V: Visitor<'de>,
{
    let len = object.len();
    let mut deserializer = MapDeserializer::new(object);
    let map = visitor.visit_map(&mut deserializer)?;
    let remaining = deserializer.iter.len();
    if remaining == 0 {
        Ok(map)
    } else {
        Err(de::Error::invalid_length(len, &"fewer elements in map"))
    }
}

/// Deserializes Rust values out of an owned [`UniNode`].
///
/// The node is consumed, so strings, byte buffers, arrays and objects are
/// moved into the visitor rather than cloned.
impl<'de> de::Deserializer<'de> for UniNode {
    type Error = UniNodeSerError;

    // Scalar requests are answered by `deserialize_any`, which dispatches on
    // the node's variant. `i128`/`u128` are forwarded as well, matching the
    // forwarding list of `MapDeserializer`. serde's default implementations of
    // those two methods reject every input, which made 128-bit targets
    // impossible to deserialize even from small integers.
    serde::forward_to_deserialize_any! {
        i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 bool
    }

    /// Hands the node to the visitor according to its own variant, without
    /// any coercion.
    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        match self {
            UniNode::Null => visitor.visit_unit(),
            UniNode::Boolean(v) => visitor.visit_bool(v),
            UniNode::Integer(v) => visitor.visit_i64(v),
            UniNode::UInteger(v) => visitor.visit_u64(v),
            UniNode::Float(v) => visitor.visit_f64(v),
            UniNode::String(v) => visitor.visit_string(v),
            UniNode::Bytes(v) => visitor.visit_byte_buf(v),
            UniNode::Array(v) => visit_array(v.into_iter(), visitor),
            UniNode::Object(v) => visit_object(v, visitor),
        }
    }

    /// Characters are read from `String` nodes. The visitor decides whether
    /// the string holds exactly one character.
    fn deserialize_char<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        self.deserialize_string(visitor)
    }

    /// Borrowed strings are served from the owned string, like
    /// `deserialize_string`.
    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        self.deserialize_string(visitor)
    }

    /// Only `String` nodes are accepted; anything else is an invalid type.
    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        match self {
            UniNode::String(v) => visitor.visit_string(v),
            _ => Err(self.invalid_type(&visitor)),
        }
    }

    /// Identifiers (field and variant names) must be `String` nodes.
    fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        self.deserialize_string(visitor)
    }

    /// Byte slices are handled exactly like owned byte buffers.
    fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        self.deserialize_byte_buf(visitor)
    }

    /// Accepts `String` nodes (passed on as a string), `Array` nodes (passed
    /// on as a sequence) and `Bytes` nodes.
    fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        match self {
            UniNode::String(v) => visitor.visit_string(v),
            UniNode::Array(v) => visit_array(v.into_iter(), visitor),
            UniNode::Bytes(v) => visitor.visit_byte_buf(v),
            _ => Err(self.invalid_type(&visitor)),
        }
    }

    /// `Null` becomes `None`; any other node is `Some(node)`.
    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        match self {
            UniNode::Null => visitor.visit_none(),
            _ => visitor.visit_some(self),
        }
    }

    /// Only `Null` is a valid unit value.
    fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        match self {
            UniNode::Null => visitor.visit_unit(),
            _ => Err(self.invalid_type(&visitor)),
        }
    }

    /// Unit structs follow the same rule as `()`: only `Null` is accepted.
    fn deserialize_unit_struct<V>(
        self, _name: &'static str, visitor: V,
    ) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        self.deserialize_unit(visitor)
    }

    /// Newtype structs are transparent: the wrapped value is read from this
    /// same node.
    fn deserialize_newtype_struct<V>(
        self, _name: &'static str, visitor: V,
    ) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        visitor.visit_newtype_struct(self)
    }

    /// `Array` and `Bytes` nodes yield their elements. Any other node,
    /// including `Null`, is treated as a one-element sequence containing
    /// itself.
    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        match self {
            UniNode::Array(v) => visit_array(v.into_iter(), visitor),
            UniNode::Bytes(v) => visit_array(v.into_iter(), visitor),
            _ => visit_array(std::iter::once(self), visitor),
        }
    }

    /// Tuples follow the sequence rules of `deserialize_seq`; the expected
    /// length is not checked here.
    fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        self.deserialize_seq(visitor)
    }

    /// Tuple structs follow the sequence rules of `deserialize_seq`; the
    /// expected length is not checked here.
    fn deserialize_tuple_struct<V>(
        self, _name: &'static str, _len: usize, visitor: V,
    ) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        self.deserialize_seq(visitor)
    }

    /// Only `Object` nodes can be read as maps.
    fn deserialize_map<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        match self {
            UniNode::Object(v) => visit_object(v, visitor),
            _ => Err(self.invalid_type(&visitor)),
        }
    }

    /// Structs are read from an `Object` (by field name) or from an `Array`
    /// (positionally).
    fn deserialize_struct<V>(
        self, _name: &'static str, _fields: &'static [&'static str], visitor: V,
    ) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        match self {
            UniNode::Array(v) => visit_array(v.into_iter(), visitor),
            UniNode::Object(v) => visit_object(v, visitor),
            _ => Err(self.invalid_type(&visitor)),
        }
    }

    /// Enums use the externally tagged form. A `String` names a variant with
    /// no payload; an `Object` with exactly one entry maps the variant name to
    /// its payload. Empty or multi-key objects are rejected.
    fn deserialize_enum<V>(
        self, _name: &str, _variants: &'static [&'static str], visitor: V,
    ) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        let (variant, value) = match self {
            UniNode::Object(value) => {
                let mut iter = value.into_iter();
                let (variant, value) = match iter.next() {
                    Some(v) => v,
                    None => {
                        return Err(de::Error::invalid_value(
                            de::Unexpected::Map,
                            &"map with a single key",
                        ));
                    },
                };
                // A second entry would make the variant ambiguous.
                if iter.next().is_some() {
                    return Err(de::Error::invalid_value(
                        de::Unexpected::Map,
                        &"map with a single key",
                    ));
                }
                (variant, Some(value))
            },
            UniNode::String(variant) => (variant, None),
            other => return Err(other.invalid_type(&"string or map")),
        };

        visitor.visit_enum(EnumDeserializer { variant, value })
    }

    /// Discards the node without inspecting it.
    fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        drop(self);
        visitor.visit_unit()
    }
}

/// A `UniNode` already implements `Deserializer`, so the conversion is the
/// identity. This lets nodes be used as elements of `SeqDeserializer`.
impl<'de> IntoDeserializer<'de, UniNodeSerError> for UniNode {
    type Deserializer = Self;

    fn into_deserializer(self) -> Self {
        self
    }
}

/// `MapAccess` adapter over the entries of an owned [`Object`].
///
/// Each `next_key_seed` call pulls one entry from `iter` and parks its value
/// in `value` until `next_value_seed` claims it.
struct MapDeserializer {
    /// Remaining, not yet visited entries.
    iter: <Object as IntoIterator>::IntoIter,
    /// Value of the most recently yielded key, if not yet consumed.
    value: Option<UniNode>,
}

impl MapDeserializer {
    /// Creates an adapter that yields the entries of `map` in its iteration
    /// order, with no pending value.
    fn new(map: Object) -> Self {
        let iter = map.into_iter();
        MapDeserializer { iter, value: None }
    }
}

impl<'de> de::Deserializer<'de> for MapDeserializer {
    type Error = UniNodeSerError;

    serde::forward_to_deserialize_any! {
        bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
        bytes byte_buf option unit unit_struct newtype_struct seq tuple
        tuple_struct map struct enum identifier ignored_any
    }

    #[inline]
    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        visitor.visit_map(self)
    }
}

impl<'de> de::MapAccess<'de> for MapDeserializer {
    type Error = UniNodeSerError;

    fn next_key_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
    where
        T: de::DeserializeSeed<'de>,
    {
        match self.iter.next() {
            Some((key, value)) => {
                self.value = Some(value);
                let key_de = MapKeyDeserializer::new(Cow::Owned(key));
                seed.deserialize(key_de).map(Some)
            },
            None => Ok(None),
        }
    }

    fn next_value_seed<T>(&mut self, seed: T) -> Result<T::Value>
    where
        T: de::DeserializeSeed<'de>,
    {
        match self.value.take() {
            Some(value) => seed.deserialize(value),
            None => Err(de::Error::custom("value is missing")),
        }
    }

    fn size_hint(&self) -> Option<usize> {
        match self.iter.size_hint() {
            (lower, Some(upper)) if lower == upper => Some(upper),
            _ => None,
        }
    }
}

/// `EnumAccess` for an externally tagged enum: the variant name, plus the
/// payload that was attached to it (absent for the bare-string form).
struct EnumDeserializer {
    variant: String,
    value: Option<UniNode>,
}

impl<'de> de::EnumAccess<'de> for EnumDeserializer {
    type Error = UniNodeSerError;
    type Variant = VariantDeserializer;

    /// Deserializes the variant name with `seed` and passes the payload on
    /// to a `VariantDeserializer`.
    fn variant_seed<V>(self, seed: V) -> Result<(V::Value, VariantDeserializer)>
    where
        V: de::DeserializeSeed<'de>,
    {
        let EnumDeserializer { variant, value } = self;
        let payload = VariantDeserializer { value };
        seed.deserialize(variant.into_deserializer())
            .map(move |tag| (tag, payload))
    }
}

struct VariantDeserializer {
    value: Option<UniNode>,
}

impl<'de> de::VariantAccess<'de> for VariantDeserializer {
    type Error = UniNodeSerError;

    fn unit_variant(self) -> Result<()> {
        match self.value {
            Some(value) => de::Deserialize::deserialize(value),
            None => Ok(()),
        }
    }

    fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value>
    where
        T: de::DeserializeSeed<'de>,
    {
        match self.value {
            Some(value) => seed.deserialize(value),
            None => Err(de::Error::invalid_type(
                de::Unexpected::UnitVariant,
                &"newtype variant",
            )),
        }
    }

    fn tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        match self.value {
            Some(UniNode::Array(v)) => serde::Deserializer::deserialize_any(
                SeqDeserializer::new(v.into_iter()),
                visitor,
            ),
            Some(other) => Err(other.invalid_type(&"tuple variant")),
            None => Err(de::Error::invalid_type(
                de::Unexpected::UnitVariant,
                &"tuple variant",
            )),
        }
    }

    fn struct_variant<V>(
        self, _fields: &'static [&'static str], visitor: V,
    ) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        match self.value {
            Some(UniNode::Object(v)) => serde::Deserializer::deserialize_any(
                MapDeserializer::new(v),
                visitor,
            ),
            Some(other) => Err(other.invalid_type(&"struct variant")),
            None => Err(de::Error::invalid_type(
                de::Unexpected::UnitVariant,
                &"struct variant",
            )),
        }
    }
}
