use serde_json::json;
use std::{
    fmt,
    fmt::{Display, Formatter},
};

/// A configured AI21 completion client.
///
/// Obtain one through [`AI21::new`] followed by [`AI21Builder::build`]. The public fields
/// hold the sampling parameters that [`AI21::complete`] sends with every request; they may
/// be adjusted directly between calls.
///
/// `Debug` is intentionally not derived so the API token cannot leak through debug output.
pub struct AI21 {
    /// API token sent as a bearer token in the `Authorization` header.
    pub token: String,

    // --- sampling ---
    /// Sent to the API as `topKReturn`.
    pub top_k: u8,
    /// Sent to the API as `topP`.
    pub top_p: f64,
    /// Sent to the API as `temperature`.
    pub temperature: f64,

    // --- penalties (each sent as `{"scale": value}`) ---
    /// Sent as `presencePenalty.scale`.
    pub presence_penalty: f64,
    /// Sent as `countPenalty.scale`.
    pub count_penalty: f64,
    /// Sent as `frequencyPenalty.scale`.
    pub frequency_penalty: f64,

    // --- output limits ---
    /// Sent to the API as `maxTokens`.
    pub max_tokens: u16,
    /// Sent to the API as `stopSequences`.
    pub stop_sequences: Vec<String>,

    /// HTTP client created once in [`AI21Builder::build`] and reused for every request.
    client: reqwest::Client,
}

/// Builder for [`AI21`], returned by [`AI21::new`].
///
/// Every setter consumes the builder and returns it, so calls must be chained (or the
/// result re-bound). `#[must_use]` makes the compiler warn when a setter's result is
/// discarded, which would otherwise silently throw the whole builder away.
#[must_use = "builder methods consume the builder; chain them and finish with `.build()`"]
pub struct AI21Builder {
    /// API token forwarded unchanged to [`AI21::token`].
    token: String,

    /// See [`AI21::top_k`].
    top_k: u8,
    /// See [`AI21::top_p`].
    top_p: f64,
    /// See [`AI21::temperature`].
    temperature: f64,

    /// See [`AI21::presence_penalty`].
    presence_penalty: f64,
    /// See [`AI21::count_penalty`].
    count_penalty: f64,
    /// See [`AI21::frequency_penalty`].
    frequency_penalty: f64,

    /// See [`AI21::max_tokens`].
    max_tokens: u16,
    /// See [`AI21::stop_sequences`].
    stop_sequences: Vec<String>,
}

/// Everything that can go wrong while requesting a completion from [`AI21::complete`].
#[derive(Debug)]
pub enum AI21Error {
    /// An error reported by `reqwest` while performing the HTTP request or reading its body.
    Reqwest(reqwest::Error),
    /// The response body could not be parsed as JSON.
    Serde(serde_json::Error),
    /// The response was valid JSON, but `completions[0].data.text` was missing or not a string.
    InvalidType,
}

impl Display for AI21Error {
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        match self {
            AI21Error::Reqwest(e) => write!(f, "Reqwest error: {}", e),
            AI21Error::Serde(e) => write!(f, "Serde error: {}", e),
            AI21Error::InvalidType => write!(f, "Invalid type"),
        }
    }
}

impl From<reqwest::Error> for AI21Error {
    fn from(err: reqwest::Error) -> Self {
        AI21Error::Reqwest(err)
    }
}

impl From<serde_json::Error> for AI21Error {
    fn from(err: serde_json::Error) -> Self {
        AI21Error::Serde(err)
    }
}

impl AI21 {
    /// Starts configuring a client that authenticates with `token`.
    ///
    /// The returned [`AI21Builder`] starts with these defaults:
    /// - `top_k`: `0`
    /// - `top_p`: `0.9`
    /// - `temperature`: `1.0`
    /// - `presence_penalty`, `count_penalty`, `frequency_penalty`: `0.0`
    /// - `max_tokens`: `16`
    /// - `stop_sequences`: empty
    ///
    /// Override any of them with the builder methods, then call [`AI21Builder::build`].
    pub fn new(token: &str) -> AI21Builder {
        AI21Builder {
            token: token.to_string(),

            top_k: 0,
            top_p: 0.9,
            temperature: 1.0,

            presence_penalty: 0.0,
            count_penalty: 0.0,
            frequency_penalty: 0.0,

            max_tokens: 16,
            stop_sequences: vec![],
        }
    }

    /// Sends `prompt` to the `j1-jumbo` completion endpoint and returns the text of the
    /// first completion.
    ///
    /// The request uses the sampling parameters currently stored on `self`.
    ///
    /// # Errors
    ///
    /// - [`AI21Error::Reqwest`] in three cases:
    ///   - the request cannot be sent;
    ///   - the server answers with a non-success (4xx/5xx) status, such as an invalid token;
    ///   - the response body cannot be read.
    /// - [`AI21Error::Serde`] if the response body is not valid JSON.
    /// - [`AI21Error::InvalidType`] if the JSON has no string at `completions[0].data.text`.
    pub async fn complete(&self, prompt: &str) -> Result<String, AI21Error> {
        let parameters = json!({
            "prompt": prompt,

            // NOTE(review): `topKReturn` presumably controls how many alternative top tokens
            // the API reports per position rather than top-k sampling — confirm against
            // the AI21 API reference.
            "topKReturn": self.top_k,
            "topP": self.top_p,
            "temperature": self.temperature,

            "presencePenalty": {"scale": self.presence_penalty},
            "countPenalty": {"scale": self.count_penalty},
            "frequencyPenalty": {"scale": self.frequency_penalty},

            "maxTokens": self.max_tokens,
            "stopSequences": self.stop_sequences
        })
        .to_string();

        let response = self
            .client
            .post("https://api.ai21.com/studio/v1/j1-jumbo/complete")
            .bearer_auth(&self.token)
            .header("Content-Type", "application/json")
            .body(parameters)
            .send()
            .await?
            // Turn 4xx/5xx responses into an HTTP error. Otherwise the API's JSON error
            // payload would parse successfully and be misreported as `InvalidType` below.
            .error_for_status()?
            .text()
            .await?;

        let json: serde_json::Value = serde_json::from_str(&response)?;

        // Indexing a `Value` with a missing key or index yields `Null` rather than panicking,
        // so any unexpected shape ends up in the `InvalidType` error.
        json["completions"][0]["data"]["text"]
            .as_str()
            .map(str::to_owned)
            .ok_or(AI21Error::InvalidType)
    }
}

impl AI21Builder {
    pub fn top_k(mut self, top_k: u8) -> Self {
        self.top_k = top_k;
        self
    }

    pub fn top_p(mut self, top_p: f64) -> Self {
        self.top_p = top_p;
        self
    }

    pub fn temperature(mut self, temperature: f64) -> Self {
        self.temperature = temperature;
        self
    }

    pub fn presence_penalty(mut self, presence_penalty: f64) -> Self {
        self.presence_penalty = presence_penalty;
        self
    }

    pub fn count_penalty(mut self, count_penalty: f64) -> Self {
        self.count_penalty = count_penalty;
        self
    }

    pub fn frequency_penalty(mut self, frequency_penalty: f64) -> Self {
        self.frequency_penalty = frequency_penalty;
        self
    }

    pub fn max_tokens(mut self, max_tokens: u16) -> Self {
        self.max_tokens = max_tokens;
        self
    }

    pub fn stop_sequences(mut self, stop_sequences: Vec<String>) -> Self {
        self.stop_sequences = stop_sequences;
        self
    }

    pub fn build(self) -> AI21 {
        AI21 {
            token: self.token,

            top_k: self.top_k,
            top_p: self.top_p,
            temperature: self.temperature,

            presence_penalty: self.presence_penalty,
            count_penalty: self.count_penalty,
            frequency_penalty: self.frequency_penalty,

            max_tokens: self.max_tokens,
            stop_sequences: self.stop_sequences,

            client: reqwest::Client::new(),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Offline check that builder overrides reach the built client and untouched fields
    /// keep the defaults set in `AI21::new`.
    #[test]
    fn builder_applies_overrides_and_keeps_defaults() {
        let ai21 = AI21::new("token")
            .top_k(3)
            .temperature(0.5)
            .stop_sequences(vec!["\n".to_string()])
            .build();

        assert_eq!(ai21.token, "token");
        assert_eq!(ai21.top_k, 3);
        assert_eq!(ai21.temperature, 0.5);
        assert_eq!(ai21.stop_sequences, vec!["\n".to_string()]);

        assert_eq!(ai21.top_p, 0.9);
        assert_eq!(ai21.presence_penalty, 0.0);
        assert_eq!(ai21.count_penalty, 0.0);
        assert_eq!(ai21.frequency_penalty, 0.0);
        assert_eq!(ai21.max_tokens, 16);
    }

    /// Live end-to-end call against the real API.
    ///
    /// Ignored by default because it needs network access and a valid token. Run it with
    /// `AI21_TOKEN=... cargo test -- --ignored`. The expected string reflects the model's
    /// output at temperature 0 and may need updating if the upstream model changes.
    #[tokio::test]
    #[ignore = "requires AI21_TOKEN and network access to api.ai21.com"]
    async fn test() {
        use std::env;

        let token = env::var("AI21_TOKEN").expect("AI21_TOKEN must be set to run this test");
        let ai21 = AI21::new(&token).temperature(0.).build();
        let output = ai21
            .complete("lol")
            .await
            .expect("completion request failed");
        assert_eq!(output, ", u r right, i shudnt have posted that.");
    }
}
