//! An implementation of `RocksDB` database.

/// Backup-related stuff for `RocksDB` database.
///
/// This module only re-exports the backup types of the `rocksdb` crate under
/// `RocksDB`-prefixed aliases; it defines no items of its own.
pub mod backup {
    pub use rocksdb::backup::{
        BackupEngine as RocksDBBackupEngine, BackupEngineInfo as RocksDBBackupEngineInfo,
        BackupEngineOptions as RocksDBBackupEngineOptions, RestoreOptions as RocksDBRestoreOptions,
    };
}

use crossbeam::sync::{ShardedLock, ShardedLockReadGuard};
use rocksdb::{
    self, checkpoint::Checkpoint, Cache as RocksDBCache, ColumnFamily, DBIterator,
    Options as RocksDBOptions, WriteBatch, WriteOptions as RocksDBWriteOptions,
};
use smallvec::SmallVec;
use std::{fmt, iter, iter::Peekable, mem, path::Path, sync::Arc};

use crate::{
    db::{check_database, Change},
    DBOptions, Database, Iter, Iterator, Patch, ResolvedAddress, Snapshot,
};

/// Size of a byte representation of an index ID, which is used to prefix index keys
/// in a column family.
///
/// Equal to the size of a `u64`, i.e., 8 bytes.
pub const ID_SIZE: usize = mem::size_of::<u64>();

/// Database implementation on top of [`RocksDB`](https://rocksdb.org)
/// backend.
///
/// `RocksDB` is an embedded database for key-value data, which is optimized for fast storage.
/// This structure is required to potentially adapt the interface to
/// use different databases.
///
/// Cloning is cheap with respect to the database itself: the underlying `rocksdb::DB`
/// sits behind an `Arc`, so clones refer to the same database handle.
#[derive(Clone)]
pub struct RocksDB {
    /// Shared handle to the underlying database. Readers take the read side of the lock;
    /// column family creation takes the write side.
    db: Arc<ShardedLock<rocksdb::DB>>,
    /// Options the database was opened with; reused when creating new column families.
    options: DBOptions,
}

/// Converts owned `DBOptions` by delegating to the by-reference conversion.
impl From<DBOptions> for RocksDBOptions {
    fn from(opts: DBOptions) -> Self {
        let borrowed: &DBOptions = &opts;
        borrowed.into()
    }
}

/// Builds `rocksdb` options from the crate-level `DBOptions`.
impl From<&DBOptions> for RocksDBOptions {
    fn from(opts: &DBOptions) -> Self {
        // Unset limits fall back to these values.
        let max_open_files = opts.max_open_files.unwrap_or(-1);
        let max_total_wal_size = opts.max_total_wal_size.unwrap_or(0);

        let mut rocks_opts = Self::default();
        rocks_opts.create_if_missing(opts.create_if_missing);
        rocks_opts.set_compression_type(opts.compression_type.into());
        rocks_opts.set_max_open_files(max_open_files);
        rocks_opts.set_max_total_wal_size(max_total_wal_size);

        // A row cache is only installed when a capacity has been configured.
        match opts.max_cache_size {
            Some(capacity) => {
                let cache = RocksDBCache::new_lru_cache(capacity)
                    .expect("Failed to instantiate `Cache` for `RocksDB`");
                rocks_opts.set_row_cache(&cache);
            }
            None => {}
        }

        rocks_opts
    }
}

/// A snapshot of a `RocksDB`.
///
/// The snapshot keeps the database alive through its own `Arc` to it.
pub struct RocksDBSnapshot {
    // The `'static` lifetime is produced by a `transmute` in `RocksDB::rocksdb_snapshot`;
    // the real owner of the data is the database referenced by `db` below.
    // NOTE(review): struct fields are dropped in declaration order, so this field is
    // dropped before `db`; presumably the ordering is intentional — keep it when editing.
    snapshot: rocksdb::Snapshot<'static>,
    // Shared handle to the database the snapshot was taken from; also used to look up
    // column family handles on reads.
    db: Arc<ShardedLock<rocksdb::DB>>,
}

/// An iterator over the entries of a `RocksDB`.
struct RocksDBIterator<'a> {
    /// Underlying `rocksdb` iterator; peekable so that `peek` does not consume entries.
    iter: Peekable<DBIterator<'a>>,
    /// Key of the entry most recently returned by `next`. Stored here so that the slice
    /// handed out by `next` can borrow from the iterator itself.
    key: Option<Box<[u8]>>,
    /// Value of the entry most recently returned by `next`; stored for the same reason
    /// as `key`.
    value: Option<Box<[u8]>>,
    /// ID prefix that every yielded key must start with, or `None` if keys are not
    /// prefixed. The prefix is cut off from the keys returned to the caller.
    prefix: Option<[u8; ID_SIZE]>,
    /// Set once a key with a different prefix has been seen; afterwards the iterator
    /// yields nothing.
    ended: bool,
}

impl RocksDB {
    /// Opens a database stored at the specified path with the specified options.
    ///
    /// If the database does not exist at the indicated path and the option
    /// `create_if_missing` is switched on in `DBOptions`, a new database will
    /// be created at the indicated path.
    pub fn open<P: AsRef<Path>>(path: P, options: &DBOptions) -> crate::Result<Self> {
        let inner = {
            if let Ok(names) = rocksdb::DB::list_cf(&RocksDBOptions::default(), &path) {
                let cf_names = names.iter().map(String::as_str).collect::<Vec<_>>();
                rocksdb::DB::open_cf(&options.into(), path, cf_names)?
            } else {
                rocksdb::DB::open(&options.into(), path)?
            }
        };
        let mut db = Self {
            db: Arc::new(ShardedLock::new(inner)),
            options: *options,
        };
        check_database(&mut db)?;
        Ok(db)
    }

    /// Creates checkpoint of this database in the given directory. See [`RocksDB` docs] for
    /// details.
    ///
    /// Successfully created checkpoint can be opened using `RocksDB::open`.
    ///
    /// [`RocksDB` docs]: https://github.com/facebook/rocksdb/wiki/Checkpoints
    pub fn create_checkpoint<T: AsRef<Path>>(&self, path: T) -> crate::Result<()> {
        let guard = self.get_db_lock_guard();
        let checkpoint = Checkpoint::new(&*guard)?;
        checkpoint.create_checkpoint(path)?;
        Ok(())
    }

    /// Retrieves read lock guard containing underlying `rocksdb::DB`.
    pub fn get_db_lock_guard(&self) -> ShardedLockReadGuard<'_, rocksdb::DB> {
        self.db.read().expect("Failed to get read lock to DB")
    }

    fn cf_exists(&self, cf_name: &str) -> bool {
        self.get_db_lock_guard().cf_handle(cf_name).is_some()
    }

    fn create_cf(&self, cf_name: &str) -> crate::Result<()> {
        self.db
            .write()
            .expect("Failed to get write lock to DB")
            .create_cf(cf_name, &self.options.into())
            .map_err(Into::into)
    }

    /// Clears the column family completely, removing all keys from it.
    pub(super) fn clear_column_family(&self, batch: &mut WriteBatch, cf: &ColumnFamily) {
        /// Some lexicographically large key.
        const LARGER_KEY: &[u8] = &[u8::max_value(); 1_024];

        let db_reader = self.get_db_lock_guard();
        let mut iter = db_reader.raw_iterator_cf(cf);
        iter.seek_to_last();
        if iter.valid() {
            if let Some(key) = iter.key() {
                // For some reason, removing a range to a very large key is
                // significantly faster than removing the exact range.
                // This is specific to the debug mode, but since `TemporaryDB`
                // is mostly used for testing, this optimization leads to practical
                // performance improvement.
                if key.len() < LARGER_KEY.len() {
                    batch.delete_range_cf(cf, &[][..], LARGER_KEY);
                } else {
                    batch.delete_range_cf(cf, &[][..], key);
                    batch.delete_cf(cf, &key);
                }
            }
        }
    }

    fn do_merge(&self, patch: Patch, w_opts: &RocksDBWriteOptions) -> crate::Result<()> {
        let mut batch = WriteBatch::default();
        for (resolved, changes) in patch.into_changes() {
            if !self.cf_exists(&resolved.name) {
                self.create_cf(&resolved.name)?;
            }

            let db_reader = self.get_db_lock_guard();
            let cf = db_reader.cf_handle(&resolved.name).unwrap();

            if changes.is_cleared() {
                self.clear_prefix(&mut batch, cf, &resolved);
            }

            if let Some(id_bytes) = resolved.id_to_bytes() {
                // Write changes to the column family with each key prefixed by the ID of the
                // resolved address.

                // We assume that typical key sizes are less than `1_024 - ID_SIZE = 1_016` bytes,
                // so that they fit into stack.
                let mut buffer: SmallVec<[u8; 1_024]> = SmallVec::new();
                buffer.extend_from_slice(&id_bytes);

                for (key, change) in changes.into_data() {
                    buffer.truncate(ID_SIZE);
                    buffer.extend_from_slice(&key);
                    match change {
                        Change::Put(ref value) => batch.put_cf(cf, &buffer, value),
                        Change::Delete => batch.delete_cf(cf, &buffer),
                    }
                }
            } else {
                // Write changes to the column family as-is.
                for (key, change) in changes.into_data() {
                    match change {
                        Change::Put(ref value) => batch.put_cf(cf, &key, value),
                        Change::Delete => batch.delete_cf(cf, &key),
                    }
                }
            }
        }

        self.get_db_lock_guard()
            .write_opt(batch, w_opts)
            .map_err(Into::into)
    }

    /// Removes all keys with the specified prefix from a column family.
    fn clear_prefix(&self, batch: &mut WriteBatch, cf: &ColumnFamily, resolved: &ResolvedAddress) {
        if let Some(id_bytes) = resolved.id_to_bytes() {
            let next_bytes = next_id_bytes(id_bytes);
            batch.delete_range_cf(cf, id_bytes, next_bytes);
        } else {
            self.clear_column_family(batch, cf);
        }
    }

    #[allow(unsafe_code)]
    #[allow(clippy::useless_transmute)]
    pub(super) fn rocksdb_snapshot(&self) -> RocksDBSnapshot {
        RocksDBSnapshot {
            // SAFETY:
            // The snapshot carries an `Arc` to the database to make sure that database
            // is not dropped before the snapshot. Additionally, the pointer to `rocksdb::DB`
            // is stable within `Arc<ShardedLock<rocksdb::DB>>` and its part used in dropping
            // the snapshot (`*mut ffi::rocksdb_t`) is never changed, i.e., not affected
            // by potential incoherence if the `ShardedLock` is being concurrently written to.
            // FIXME: Investigate changing `rocksdb::Snapshot` / `DB` to remove `unsafe` (ECR-4273).
            snapshot: unsafe { mem::transmute(self.get_db_lock_guard().snapshot()) },
            db: Arc::clone(&self.db),
        }
    }
}

impl RocksDBSnapshot {
    /// Retrieves read lock guard containing the underlying `rocksdb::DB`.
    ///
    /// # Panics
    ///
    /// Panics if the read lock cannot be acquired.
    fn get_lock_guard(&self) -> ShardedLockReadGuard<'_, rocksdb::DB> {
        self.db.read().expect("Failed to get read lock to DB")
    }

    /// Creates an iterator over the entries of the index at the address `name`, starting
    /// from the key `from`.
    ///
    /// If the column family of the address does not exist, the returned iterator yields
    /// no entries. This is consistent with `get` / `multi_get`, which report missing values
    /// in this case.
    fn rocksdb_iter(&self, name: &ResolvedAddress, from: &[u8]) -> RocksDBIterator<'_> {
        use rocksdb::{Direction, IteratorMode};

        let from = name.keyed(from);
        let (iter, ended) = match self.get_lock_guard().cf_handle(&name.name) {
            Some(cf) => {
                let mode = IteratorMode::From(from.as_ref(), Direction::Forward);
                (self.snapshot.iterator_cf(cf, mode), false)
            }
            // An iterator object is still needed to populate the struct, but it is over the
            // default column family, whose entries do not belong to the requested index.
            // Marking the iterator as ended ensures that none of them is ever yielded.
            None => (self.snapshot.iterator(IteratorMode::Start), true),
        };
        RocksDBIterator {
            iter: iter.peekable(),
            prefix: name.id_to_bytes(),
            key: None,
            value: None,
            ended,
        }
    }
}

impl Database for RocksDB {
    /// Takes a snapshot of the current database state.
    fn snapshot(&self) -> Box<dyn Snapshot> {
        let snapshot = self.rocksdb_snapshot();
        Box::new(snapshot)
    }

    /// Applies the patch using default write options.
    fn merge(&self, patch: Patch) -> crate::Result<()> {
        self.do_merge(patch, &RocksDBWriteOptions::default())
    }

    /// Applies the patch with the `sync` write option switched on.
    fn merge_sync(&self, patch: Patch) -> crate::Result<()> {
        let mut sync_opts = RocksDBWriteOptions::default();
        sync_opts.set_sync(true);
        self.do_merge(patch, &sync_opts)
    }
}

impl Snapshot for RocksDBSnapshot {
    /// Reads the value stored under `key` at the given address.
    ///
    /// Returns `None` if the column family of the address does not exist.
    /// Panics if the underlying read fails.
    fn get(&self, resolved_addr: &ResolvedAddress, key: &[u8]) -> Option<Vec<u8>> {
        let db = self.get_lock_guard();
        let cf = db.cf_handle(&resolved_addr.name)?;
        let fetched = self.snapshot.get_cf(cf, resolved_addr.keyed(key));
        match fetched {
            Ok(value) => value,
            Err(e) => panic!("{}", e),
        }
    }

    /// Reads the values for several keys at the given address in one call.
    ///
    /// If the column family of the address does not exist, a `None` is returned for
    /// every key. Panics if any of the underlying reads fails.
    fn multi_get<'a>(
        &self,
        resolved_addr: &ResolvedAddress,
        keys: &'a mut dyn iter::Iterator<Item = &'a [u8]>,
    ) -> Vec<Option<Vec<u8>>> {
        let db = self.get_lock_guard();
        let cf = match db.cf_handle(&resolved_addr.name) {
            Some(cf) => cf,
            None => return vec![None; keys.count()],
        };

        let lookups = keys.map(|key| (cf, resolved_addr.keyed(key)));
        let fetched: Result<Vec<Option<Vec<u8>>>, _> = self
            .snapshot
            .multi_get_cf(lookups)
            .into_iter()
            .collect();
        match fetched {
            Ok(values) => values,
            Err(e) => panic!("{}", e),
        }
    }

    /// Returns a boxed iterator over the entries at the given address, starting from `from`.
    fn iter(&self, name: &ResolvedAddress, from: &[u8]) -> Iter<'_> {
        let entries = self.rocksdb_iter(name, from);
        Box::new(entries)
    }
}

impl<'a> Iterator for RocksDBIterator<'a> {
    /// Advances the iterator and returns the next `(key, value)` pair.
    ///
    /// If the iterator has an ID prefix, the prefix is cut off from the returned key, and
    /// the iteration ends for good as soon as a key without the prefix is encountered.
    /// A key that is too short to contain the prefix is treated as not having it.
    fn next(&mut self) -> Option<(&[u8], &[u8])> {
        if self.ended {
            return None;
        }

        let (key, value) = self.iter.next()?;
        if let Some(ref prefix) = self.prefix {
            // `starts_with` (unlike slicing `key[..ID_SIZE]`) does not panic on short keys.
            if !key.starts_with(&prefix[..]) {
                self.ended = true;
                return None;
            }
        }

        // The entry is moved into the iterator so that the returned slices can borrow
        // from `self`.
        self.key = Some(key);
        let key = if self.prefix.is_some() {
            // The key is known to start with the `ID_SIZE`-byte prefix at this point.
            &self.key.as_ref()?[ID_SIZE..]
        } else {
            &self.key.as_ref()?[..]
        };
        self.value = Some(value);
        Some((key, self.value.as_ref()?))
    }

    /// Returns the next `(key, value)` pair without advancing the iterator.
    ///
    /// The prefix handling is the same as in `next`: the prefix is cut off from the key,
    /// and a key without the prefix (including a too short one) ends the iteration.
    fn peek(&mut self) -> Option<(&[u8], &[u8])> {
        if self.ended {
            return None;
        }

        let (key, value) = self.iter.peek()?;
        let key = if let Some(prefix) = self.prefix {
            // `starts_with` (unlike slicing `key[..ID_SIZE]`) does not panic on short keys.
            if !key.starts_with(&prefix[..]) {
                self.ended = true;
                return None;
            }
            &key[ID_SIZE..]
        } else {
            &key[..]
        };
        Some((key, &value[..]))
    }
}

/// Wraps the database into a reference-counted `Database` trait object.
impl From<RocksDB> for Arc<dyn Database> {
    fn from(db: RocksDB) -> Self {
        let boxed: Box<dyn Database> = Box::new(db);
        Arc::from(boxed)
    }
}

/// Opaque debug representation: only the type name is printed, no fields.
impl fmt::Debug for RocksDB {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut repr = f.debug_struct("RocksDB");
        repr.finish()
    }
}

/// Opaque debug representation: only the type name is printed, no fields.
impl fmt::Debug for RocksDBSnapshot {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut repr = f.debug_struct("RocksDBSnapshot");
        repr.finish()
    }
}

/// Generates the sequence of bytes lexicographically following the provided one. Assumes that
/// the provided sequence is less than `[u8::max_value(); ID_SIZE]`.
///
/// The bytes are treated as a big-endian integer, which is incremented by one. If the
/// assumption above is violated, the result wraps around to all zeros.
pub fn next_id_bytes(id_bytes: [u8; ID_SIZE]) -> [u8; ID_SIZE] {
    let current = u64::from_be_bytes(id_bytes);
    current.wrapping_add(1).to_be_bytes()
}

#[test]
fn test_next_id_bytes() {
    // Pairs of `(input, expected successor)`; the last two cases involve a carry.
    let cases: [([u8; ID_SIZE], [u8; ID_SIZE]); 5] = [
        ([1, 0, 0, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0, 1]),
        ([1, 2, 3, 4, 5, 6, 7, 8], [1, 2, 3, 4, 5, 6, 7, 9]),
        ([1, 0, 0, 0, 0, 0, 0, 254], [1, 0, 0, 0, 0, 0, 0, 255]),
        ([1, 0, 0, 0, 0, 0, 41, 255], [1, 0, 0, 0, 0, 0, 42, 0]),
        ([1, 2, 3, 4, 5, 255, 255, 255], [1, 2, 3, 4, 6, 0, 0, 0]),
    ];

    for (input, expected) in cases.iter() {
        assert_eq!(next_id_bytes(*input), *expected);
    }
}
