use super::event::*;
use crate::aggregate::UncommittedEvent;
use crate::error::{Error, Result as TsResult};
use chrono::{DateTime, TimeZone, Utc};
use futures::stream::{BoxStream, StreamExt, TryStreamExt};
use lru::LruCache;
use serde::{Deserialize, Serialize};
use serde_json::Value as Json;
use sqlx::PgPool;
use std::borrow::Cow;
use std::convert::TryFrom;
use std::fmt::Debug;
use uuid::Uuid;

/// Ordering requirement passed to [`EventStore::commit`].
///
/// For [`TimescaleStore`] each variant is translated into an `orderly` flag and
/// an optional last-known event time, both handed to `queries/event/append.sql`.
#[derive(Debug, Clone)]
pub enum CommitOrder {
    /// No ordering check: `orderly = false`, no last-known time.
    None,
    /// Ordered commit without a previous event time: `orderly = true`, no
    /// last-known time.
    First,
    /// Ordered commit following the given time: `orderly = true`, and the time
    /// is passed to the query as nanoseconds since the Unix epoch.
    Following(DateTime<Utc>),
}

/// Boxed stream of [`Persisted`] events returned by [`EventStore::aggregate_stream`].
/// Every item is a [`TsResult`], so failures are reported per item rather than
/// aborting the whole call.
pub type EventStream<'a, S> = BoxStream<'a, TsResult<Persisted<<S as EventStore>::Event>>>;

/// An Event Store is an append-only, ordered list of
/// [`Event`](super::aggregate::Aggregate::Event)s for a certain "source" --
/// e.g. an [`Aggregate`](super::aggregate::Aggregate).
///
/// Each source is addressed by its `aggregate_id`.
#[async_trait]
pub trait EventStore {
    /// Event payload type kept by this store.
    type Event: Clone;

    /// Appends `events` to the stream of `aggregate_id`, honouring `order`.
    ///
    /// Returns a timestamp reported by the store for this commit. For
    /// [`TimescaleStore`] it is the value returned by the append query (logged
    /// as `next_offset`). Presumably this is what a later
    /// [`CommitOrder::Following`] should carry; confirm with callers.
    ///
    /// # Errors
    /// [`TimescaleStore`] rejects an empty `events` slice with
    /// `Error::InvalidData` and propagates serialization and database errors.
    async fn commit(
        &mut self,
        aggregate_id: Uuid,
        order: CommitOrder,
        events: &[UncommittedEvent<Self::Event>],
    ) -> TsResult<DateTime<Utc>>;

    /// Returns a stream over the persisted events of `aggregate_id`.
    /// Individual items of the stream may be `Err`.
    async fn aggregate_stream(&self, aggregate_id: Uuid) -> TsResult<EventStream<Self>>;

    /// Removes the events stored for `aggregate_id`. For [`TimescaleStore`]
    /// this runs `queries/event/delete_aggregate.sql`.
    async fn remove(&mut self, aggregate_id: Uuid) -> TsResult<()>;
}

/// Row of `queries/event/get_all_name_events.sql`, used by
/// [`EventStoreBuilder::build`] to pre-populate the event name cache.
#[derive(Debug)]
struct EventName {
    /// Database id of the event name (the cache value).
    id: i32,
    /// Event name (the cache key).
    name: String,
}

/// Builder that configures and creates a [`TimescaleStore`] on a Postgres pool.
#[derive(Debug)]
pub struct EventStoreBuilder {
    /// Whether `build` runs the embedded SQL migrations first; `new` sets it
    /// to `true`.
    with_migrations: bool,
    /// Pool used for migrations and moved into the built store.
    pool: PgPool,
}

impl EventStoreBuilder {
    pub fn new(pool: PgPool) -> Self {
        Self {
            pool,
            with_migrations: true,
        }
    }

    pub fn with_migrations(mut self, value: bool) -> Self {
        self.with_migrations = value;
        self
    }

    pub async fn build<Event>(
        self,
        aggregate_type_name: Cow<'static, str>,
    ) -> TsResult<TimescaleStore<Event>>
    where
        Event: Serialize + Send + Sync + Debug + std::marker::Unpin,
        for<'de> Event: Deserialize<'de>,
    {
        let pool = self.pool;

        if self.with_migrations {
            sqlx::migrate!().run(&pool).await?;
        }

        let aggregate_type_id =
            sqlx::query_file_scalar!("queries/aggregate_type/id.sql", &*aggregate_type_name)
                .fetch_one(&pool)
                .await?
                .ok_or_else(|| Error::InvalidData("Unable to get aggregate type id".into()))?;

        let event_name_cache_data = sqlx::query_file_as!(
            EventName,
            "queries/event/get_all_name_events.sql",
            aggregate_type_id,
            100
        )
        .fetch_all(&pool)
        .await?;

        debug!(
            aggregate_type_name = &*aggregate_type_name,
            event_names = ?*event_name_cache_data,
            aggregate_type_id,
            "Aggregate root data received"
        );

        let mut event_name_cache = LruCache::new(100);
        let total = event_name_cache_data.len();

        for EventName { name, id } in event_name_cache_data {
            event_name_cache.put(name, id);
        }

        debug!(total, "Populated event name cache");

        Ok(TimescaleStore {
            event_name_cache,
            payload: std::marker::PhantomData,
            pool,
            aggregate_type_id,
            aggregate_type_name,
        })
    }
}

/// One element of the event batch passed to `queries/event/append.sql`,
/// encoded as the Postgres composite type `append_event_data`.
///
/// NOTE(review): sqlx encodes composite types positionally. Keep this field
/// order in sync with the attribute order of `append_event_data` in the
/// migrations.
#[derive(sqlx::Type)]
#[sqlx(type_name = "append_event_data")]
struct AppendEventInput {
    /// Value stored under the event's name key in its serialized JSON.
    /// `None` when the serialized event is not a JSON object or lacks that key.
    data: Option<Json>,
    /// Id of the event name, resolved through the event name cache.
    name_id: i32,
    /// Event time in nanoseconds since the Unix epoch.
    time: i64,
}

/// Batch of [`AppendEventInput`] bound as a single query parameter.
/// Postgres prefixes array type names with an underscore, so
/// `_append_event_data` is the array type of `append_event_data`.
#[derive(sqlx::Type)]
#[sqlx(type_name = "_append_event_data")]
struct VecAppendEventInput(Vec<AppendEventInput>);

/// [`EventStore`] for a single aggregate type, backed by a Postgres pool.
/// Created through [`EventStoreBuilder::build`].
#[derive(Debug)]
pub struct TimescaleStore<Event> {
    /// LRU cache mapping event names to their database ids.
    event_name_cache: LruCache<String, i32>,
    /// Ties the store to its event payload type; no `Event` value is held.
    payload: std::marker::PhantomData<Event>,
    /// Connection pool used for every query.
    pool: PgPool,
    /// Name of the aggregate type this store serves.
    pub aggregate_type_name: Cow<'static, str>,
    /// Database id of the aggregate type, passed to the event queries.
    pub aggregate_type_id: i32,
}

impl<Event> TimescaleStore<Event>
where
    Event: Serialize + Send + Sync + TimescaleEventPayload + Debug,
    for<'de> Event: Deserialize<'de>,
{
    #[tracing::instrument]
    async fn get_event_name_id(&mut self, name: String) -> TsResult<i32> {
        match self.event_name_cache.get(&name) {
            Some(id) => Ok(*id),
            None => {
                let name_id = sqlx::query_file_scalar!(
                    "queries/event/upsert_event_name.sql",
                    self.aggregate_type_id,
                    name
                )
                .fetch_one(&self.pool)
                .await?
                // upsert_event_name should always return i32
                .unwrap();

                self.event_name_cache.put(name, name_id);

                Ok(name_id)
            }
        }
    }
}

#[async_trait]
impl<Event> EventStore for TimescaleStore<Event>
where
    Event: Serialize + Send + Sync + TimescaleEventPayload + Debug + Clone,
    for<'de> Event: Deserialize<'de>,
{
    type Event = Event;

    #[tracing::instrument]
    async fn commit(
        &mut self,
        aggregate_id: Uuid,
        order: CommitOrder,
        events: &[UncommittedEvent<Self::Event>],
    ) -> TsResult<DateTime<Utc>> {
        if events.is_empty() {
            return Err(Error::InvalidData(
                "list of events can't be empty".to_string(),
            ));
        }

        let mut event_data: Vec<AppendEventInput> = Vec::with_capacity(events.len());

        for event in events {
            let time = event.utc.timestamp_nanos();
            let event = &event.data;
            let name = event.name().to_string();
            let mut value = serde_json::to_value(&event)?;

            let data = if let Json::Object(value) = &mut value {
                // Only store JSON value as if serde Untagged
                value.remove(&name)
            } else {
                // Don't store any JSON for other JSON types
                None
            };

            event_data.push(AppendEventInput {
                data,
                name_id: self.get_event_name_id(name).await?,
                time,
            });
        }

        let (orderly, last_known_time) = match order {
            CommitOrder::None => (false, None),
            CommitOrder::First => (true, None),
            CommitOrder::Following(utc) => (true, Some(utc.timestamp_nanos())),
        };

        let output = sqlx::query_file_scalar!(
            "queries/event/append.sql",
            self.aggregate_type_id,
            &aggregate_id,
            &VecAppendEventInput(event_data) as _,
            orderly,
            last_known_time
        )
        .fetch_one(&self.pool)
        .await?;

        let next_offset = Utc.timestamp_nanos(output.unwrap_or_default());

        debug!(?next_offset, "Committed events");

        Ok(next_offset)
    }

    #[tracing::instrument]
    async fn aggregate_stream(&self, aggregate_id: Uuid) -> TsResult<EventStream<Self>> {
        Ok(sqlx::query_file_as!(
            EventRow,
            "queries/event/get_aggregate_all.sql",
            self.aggregate_type_id,
            aggregate_id,
        )
        .fetch(&self.pool)
        .map_err(Error::from)
        .map(move |x| x.and_then(|x| Persisted::try_from(x).map_err(Error::from)))
        .boxed())
    }

    #[tracing::instrument]
    async fn remove(&mut self, aggregate_id: Uuid) -> TsResult<()> {
        sqlx::query_file!(
            "queries/event/delete_aggregate.sql",
            self.aggregate_type_id,
            &aggregate_id
        )
        .execute(&self.pool)
        .await
        .map_err(Error::from)
        .map(|_| ())
    }
}
