//! Operations involving log probabilities.

use serde::{Deserialize, Serialize};

/// A [`String`] which is guaranteed to not be empty.
///
/// The wrapped field is private and the type does not implement `Deserialize`, so code outside
/// this module can only obtain a value through [`NonEmptyString::new`], which rejects `""`.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize)]
pub struct NonEmptyString(String);

impl NonEmptyString {
    /// Creates a new [`NonEmptyString`] from a [`String`].
    pub fn new(s: String) -> Option<Self> {
        if s.is_empty() {
            None
        } else {
            Some(Self(s))
        }
    }

    /// Get a reference to the inner [`String`].
    pub fn inner(&self) -> &str {
        &self.0
    }

    /// Take the inner [`String`].
    pub fn into_inner(self) -> String {
        self.0
    }
}

/// Parameters of a log-probability query, serialized with the field names below as keys.
///
/// Field order is also the order in which the fields are serialized.
#[derive(Serialize)]
pub(crate) struct LogProbabilitiesRequest {
    /// Text that precedes the continuation; an ordinary [`String`], so it may be empty.
    pub(crate) context: String,
    /// Text whose probability is evaluated after `context`; guaranteed non-empty by its type.
    pub(crate) continuation: NonEmptyString,
}

/// The logarithm of the probability that a continuation is generated after a context. It can be
/// used to answer questions when only a few answers (such as yes/no) are possible. It can also be
/// used to benchmark the models.
// Field names double as the keys expected by the derived `Deserialize`, so renaming a field changes
// the accepted input. The derived `PartialOrd` compares fields lexicographically in declaration
// order, so reordering fields changes how values compare.
#[derive(Debug, Copy, Clone, PartialOrd, PartialEq, Deserialize)]
pub struct LogProbabilities {
    /// Log-probability of the continuation given the context; see
    /// [`LogProbabilities::log_probability`].
    logprob: f64,
    /// Whether greedy sampling from the context would produce the continuation; see
    /// [`LogProbabilities::is_greedy`].
    is_greedy: bool,
    /// Total number of tokens; see [`LogProbabilities::total_tokens`].
    total_tokens: usize,
}

impl LogProbabilities {
    /// Logarithm of the probability of generating the continuation after the context. It is
    /// always <= 0.
    pub const fn log_probability(&self) -> f64 {
        self.logprob
    }

    /// `true` if the continuation would be generated by greedy sampling from the context.
    pub const fn is_greedy(&self) -> bool {
        self.is_greedy
    }

    /// The total number of tokens. It is useful for estimating the amount of compute resources
    /// used by the request.
    pub const fn total_tokens(&self) -> usize {
        self.total_tokens
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils;

    #[test]
    fn test_non_empty_string_new() {
        let empty = String::new();
        let non_empty = String::from("textsynth");

        assert!(NonEmptyString::new(empty).is_none());
        // The accepted string must be stored unchanged, not merely accepted.
        assert_eq!(
            NonEmptyString::new(non_empty).map(NonEmptyString::into_inner),
            Some(String::from("textsynth")),
        );
    }

    #[test]
    fn test_non_empty_string_new_accepts_whitespace() {
        // Only the zero-length string is rejected; whitespace still counts as content.
        let spaces = NonEmptyString::new(String::from(" ")).unwrap();

        assert_eq!(spaces.inner(), " ");
    }

    #[test]
    fn test_non_empty_string_inner() {
        let s = String::from("textsynth");
        let non_empty = NonEmptyString::new(s).unwrap();

        assert_eq!(non_empty.inner(), "textsynth");
    }

    #[test]
    fn test_non_empty_string_into_inner() {
        let s = String::from("textsynth");
        let non_empty = NonEmptyString::new(s).unwrap();

        assert_eq!(non_empty.into_inner(), "textsynth");
    }

    #[test]
    fn test_log_probabilities_log_probability() {
        let log_probability = test_utils::cache::log_probabilities().log_probability();

        // A probability lies in [0, 1], so its logarithm can never be positive.
        assert!(
            log_probability <= 0.0,
            "log probability must be <= 0, got {}",
            log_probability
        );
    }

    #[test]
    fn test_log_probabilities_is_greedy() {
        let _ = test_utils::cache::log_probabilities().is_greedy();
    }

    #[test]
    fn test_log_probabilities_total_tokens() {
        let _ = test_utils::cache::log_probabilities().total_tokens();
    }
}
