/*
Copyright (C) 2021 Kunal Mehta <legoktm@debian.org>

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program.  If not, see <https://www.gnu.org/licenses/>.
 */

use crate::error::ApiError;
use crate::{tokens::TokenStore, Error, ErrorFormat};
use log::debug;
use reqwest::{header, Client as HttpClient, Request, Response};
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{RwLock, Semaphore};

/// Shorthand for results whose error type is this crate's [`Error`].
type Result<T> = std::result::Result<T, Error>;

/// HTTP client for a MediaWiki-style `api.php` endpoint.
///
/// The token store and the semaphore sit behind [`Arc`], so clones of a
/// client share them (until `set_concurrency` gives one clone a semaphore
/// of its own).
#[derive(Clone, Debug)]
pub struct Client {
    /// URL that every GET and POST request is sent to.
    api_url: String,
    /// Underlying HTTP client used to build and execute requests.
    http: HttpClient,
    /// Token store consulted by `post_with_token`.
    tokens: Arc<RwLock<TokenStore>>,
    /// Acquired around each request; its permit count bounds how many
    /// requests can be in flight at once.
    semaphore: Arc<Semaphore>,
    /// OAuth2 token sent as a `Bearer` `Authorization` header when set.
    oauth2: Option<String>,
    /// Value sent as the `errorformat` request parameter.
    errorformat: ErrorFormat,
}

impl Client {
    pub fn new(api_url: &str) -> Self {
        Self {
            api_url: api_url.to_string(),
            http: HttpClient::builder().gzip(true).build().unwrap(),
            tokens: Arc::new(RwLock::new(TokenStore::default())),
            semaphore: Arc::new(Semaphore::new(1)),
            oauth2: None,
            errorformat: ErrorFormat::default(),
        }
    }

    pub fn set_oauth2_token(&mut self, token: Option<String>) {
        self.oauth2 = token;
    }

    pub fn set_errorformat(&mut self, format: ErrorFormat) {
        self.errorformat = format;
    }

    pub fn set_concurrency(&mut self, concurrency: usize) {
        self.semaphore = Arc::new(Semaphore::new(concurrency));
    }

    fn headers(&self) -> Result<header::HeaderMap> {
        let mut headers = header::HeaderMap::new();
        if let Some(token) = &self.oauth2 {
            headers.insert(
                header::AUTHORIZATION,
                format!("Bearer {}", token).parse()?,
            );
        }

        Ok(headers)
    }

    pub async fn get<P: AsRef<str>>(&self, params: &[(P, P)]) -> Result<Value> {
        let mut params: HashMap<String, String> = params
            .iter()
            .map(|(key, val)| {
                (key.as_ref().to_string(), val.as_ref().to_string())
            })
            .collect();
        params.insert("format".to_string(), "json".to_string());
        params.insert("formatversion".to_string(), "2".to_string());
        params.insert("errorformat".to_string(), self.errorformat.to_string());
        let req = self
            .http
            .get(&self.api_url)
            .headers(self.headers()?)
            .query(&params)
            .build()?;
        let _lock = self.semaphore.acquire().await?;
        log_request(&req);
        let resp = self.http.execute(req).await?;
        log_response(&resp);
        drop(_lock);
        let value: Value = resp.error_for_status()?.json().await?;
        match value.get("errors") {
            Some(errors) => {
                let errors: Vec<ApiError> =
                    serde_json::from_value(errors.clone())?;
                Err(Error::ApiError(errors[0].clone()))
            }
            None => Ok(value),
        }
    }

    pub async fn post_with_token<P: AsRef<str>>(
        &mut self,
        token: &str,
        params: &[(P, P)],
    ) -> Result<Value> {
        let mut params: HashMap<_, _> = params
            .iter()
            .map(|(key, val)| {
                (key.as_ref().to_string(), val.as_ref().to_string())
            })
            .collect();
        let token = match self.tokens.read().await.get(token) {
            Some(token) => token,
            None => self.tokens.write().await.load(token, &self).await?,
        };
        params.insert("token".to_string(), token.to_string());
        let params: Vec<_> = params.iter().collect();
        self.post(params.as_slice()).await
    }

    pub async fn post<P: AsRef<str>>(
        &self,
        params: &[(P, P)],
    ) -> Result<Value> {
        let mut params: HashMap<String, String> = params
            .iter()
            .map(|(key, val)| {
                (key.as_ref().to_string(), val.as_ref().to_string())
            })
            .collect();
        params.insert("format".to_string(), "json".to_string());
        params.insert("formatversion".to_string(), "2".to_string());
        params.insert("errorformat".to_string(), self.errorformat.to_string());
        let req = self
            .http
            .post(&self.api_url)
            .headers(self.headers()?)
            .form(&params)
            .build()?;
        let _lock = self.semaphore.acquire().await?;
        log_request(&req);
        let resp = self.http.execute(req).await?;
        log_response(&resp);
        drop(_lock);
        let value: Value = resp.error_for_status()?.json().await?;
        match value.get("errors") {
            Some(errors) => {
                let errors: Vec<ApiError> =
                    serde_json::from_value(errors.clone())?;
                Err(Error::ApiError(errors[0].clone()))
            }
            None => Ok(value),
        }
    }
}

/// Write a debug log line with the method and URL of an outgoing request.
///
/// The request body is not logged.
fn log_request(req: &Request) {
    // TODO: form body?
    debug!("Sending: HTTP {}: {}", req.method(), req.url());
}

/// Write a debug log line with the status code, request id and URL of a
/// received response.
///
/// The request id is read from the `x-request-id` header and logged as
/// `unknown` when the header is absent or unreadable.
fn log_response(resp: &Response) {
    let request_id = resp
        .headers()
        .get("x-request-id")
        // Not worth logging an error if the header is invalid utf-8
        .and_then(|val| val.to_str().ok())
        .unwrap_or("unknown");
    debug!(
        "Received: {} (req: {}): {}",
        resp.status().as_u16(),
        request_id,
        resp.url()
    );
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Live endpoint the tests send requests to; they need network access.
    const API_URL: &str = "https://www.mediawiki.org/w/api.php";

    /// A plain site-info query succeeds and reports the site name.
    #[tokio::test]
    async fn test_basic_get() {
        let api = Client::new(API_URL);
        let query = [("action", "query"), ("meta", "siteinfo")];
        let body = api.get(&query).await.unwrap();
        let sitename = body["query"]["general"]["sitename"].as_str().unwrap();
        assert_eq!(sitename, "MediaWiki");
    }

    /// An unknown `action` comes back as an API error.
    #[tokio::test]
    async fn test_basic_errors() {
        let api = Client::new(API_URL);
        let failure = api.get(&[("action", "nonexistent")]).await.unwrap_err();
        let message = failure.to_string();
        assert_eq!(
            message,
            "API error: (code: badvalue): Unrecognized value for parameter \"action\": nonexistent."
        );
    }
}
