// Copyright 2021 System76 <info@system76.com>
// SPDX-License-Identifier: MPL-2.0

#![recursion_limit = "1024"]

#[macro_use]
extern crate derive_new;
#[macro_use]
extern crate derive_setters;
#[macro_use]
extern crate log;
#[macro_use]
extern crate thiserror;

pub mod checksum;
mod range;
mod systems;

pub use self::systems::*;

use std::{
    fmt::Debug,
    future::Future,
    io,
    num::{NonZeroU16, NonZeroU32, NonZeroU64},
    path::Path,
    sync::{
        atomic::{AtomicBool, Ordering},
        Arc,
    },
    time::{Duration, SystemTime, UNIX_EPOCH},
};

use async_fs::{self as fs, File};
use filetime::FileTime;
use futures::{
    channel::mpsc,
    stream::{self, StreamExt},
    AsyncReadExt, AsyncWriteExt,
};
use http_client::native::NativeClient;
use httpdate::HttpDate;
use numtoa::NumToA;
use surf::{Client, Request, Response, StatusCode};

/// Unbounded channel sender through which a [`Fetcher`] reports [`FetchEvent`]s.
///
/// Each event is paired with the destination path of the file it concerns.
pub type EventSender = mpsc::UnboundedSender<(Arc<Path>, FetchEvent)>;
/// A path paired with the result of an operation on that path.
pub type Output<T> = (Arc<Path>, Result<T, Error>);

/// An error from the asynchronous file fetcher.
///
/// Variants carrying an [`io::Error`] expose it through
/// [`std::error::Error::source`]; the `Display` text only describes the operation
/// that failed.
#[derive(Debug, Error)]
pub enum Error {
    /// The fetcher's cancellation flag was set while a transfer was in progress.
    #[error("task was cancelled")]
    Cancelled,
    /// The `surf` HTTP client failed to send a request or receive a response.
    #[error("http client error")]
    Client(surf::Error),
    /// The fetched parts could not be concatenated.
    #[error("unable to concatenate fetched parts")]
    Concatenate(#[source] io::Error),
    /// A destination or part file could not be created.
    #[error("unable to create file")]
    FileCreate(#[source] io::Error),
    /// The modification timestamp could not be applied to the given path.
    #[error("unable to set timestamp on {:?}", _0)]
    FileTime(Arc<Path>, #[source] io::Error),
    /// The content length does not describe a valid range.
    #[error("content length is an invalid range")]
    InvalidRange(#[source] io::Error),
    /// An existing destination whose metadata could not be read also could not be
    /// removed.
    #[error("unable to remove file with bad metadata")]
    MetadataRemove(#[source] io::Error),
    /// The destination path has no file name component.
    #[error("destination has no file name")]
    Nameless,
    /// A fetched part at the given path could not be opened.
    #[error("unable to open fetched part")]
    OpenPart(Arc<Path>, #[source] io::Error),
    /// The destination path has no parent directory component.
    #[error("destination lacks parent")]
    Parentless,
    /// Sending a request or reading a chunk took longer than the configured timeout.
    #[error("connection timed out")]
    TimedOut,
    /// An I/O error occurred while writing the file.
    ///
    /// In this file the variant is also used for failures when reading the response
    /// body, resizing the file, and reading the file's modification time.
    #[error("error writing to file")]
    Write(#[source] io::Error),
    /// The partial file could not be renamed to the destination.
    #[error("failed to rename partial to destination")]
    Rename(#[source] io::Error),
    /// The server replied with a status that is neither informational nor successful.
    #[error("server responded with an error: {}", _0)]
    Status(StatusCode),
}

impl From<surf::Error> for Error {
    fn from(e: surf::Error) -> Self {
        Self::Client(e)
    }
}

/// Information about a source being fetched.
///
/// Built with [`Source::new`]; the derived setter allows `part` to be supplied
/// afterwards.
#[derive(Debug, Setters)]
pub struct Source {
    /// URLs whereby the file can be found.
    #[setters(skip)]
    pub urls: Arc<[Box<str>]>,

    /// Where the file shall ultimately be fetched to.
    #[setters(skip)]
    pub dest: Arc<Path>,

    /// Optional location to store the partial file.
    ///
    /// Defaults to `None` when constructed through [`Source::new`].
    #[setters(strip_option)]
    #[setters(into)]
    pub part: Option<Arc<Path>>,
}

impl Source {
    /// Creates a source from its mirror `urls` and final destination `dest`.
    ///
    /// No partial-file location is configured; `part` starts out as `None`.
    pub fn new(urls: impl Into<Arc<[Box<str>]>>, dest: impl Into<Arc<Path>>) -> Self {
        let urls = urls.into();
        let dest = dest.into();

        Self {
            urls,
            dest,
            part: None,
        }
    }
}

/// Events which are submitted by the fetcher.
///
/// Events are delivered through an [`EventSender`], paired with the destination path
/// of the file they refer to.
#[derive(Debug)]
pub enum FetchEvent {
    /// Signals that this file was already fetched.
    ///
    /// Sent when the existing destination's size and modification time match what
    /// the server reports.
    AlreadyFetched,
    /// States that we know the length of the file being fetched, in bytes.
    ContentLength(u64),
    /// Notifies that the file has been fetched.
    Fetched,
    /// Notifies that a file is being fetched.
    Fetching,
    /// Reports the number of bytes read from the response in the latest chunk.
    ///
    /// This is a per-chunk amount, not a running total.
    Progress(usize),
    /// Reports that the part with this index is being fetched.
    PartFetching(u64),
    /// Reports that fetching of the part with this index has finished.
    ///
    /// The event is sent whether or not fetching the part succeeded.
    PartFetched(u64),
}

/// An asynchronous file fetcher for clients fetching files.
///
/// The futures generated by the fetcher are compatible with single and multi-threaded
/// runtimes, allowing you to choose between the runtime that works best for your
/// application. A single-threaded runtime is generally recommended for fetching files,
/// as your network connection is unlikely to be faster than a single CPU core.
///
/// The derived `new` constructor takes only the HTTP client; every other field starts
/// at the default noted below and can be changed through the derived setters.
#[derive(new, Setters)]
pub struct Fetcher {
    /// The HTTP client used for every request.
    #[setters(skip)]
    client: Client,

    /// When set, cancels any active operations.
    ///
    /// The flag is checked before each chunk is read from a response. Defaults to
    /// `None`.
    #[new(default)]
    #[setters(strip_option)]
    cancel: Option<Arc<AtomicBool>>,

    /// The number of concurrent connections to sustain per file being fetched.
    ///
    /// When `None` (the default), files are fetched with a single request.
    #[new(default)]
    connections_per_file: Option<NonZeroU16>,

    /// The number of attempts to make when a request fails.
    ///
    /// This is the total number of attempts made by `request`, including the first
    /// one. Defaults to 3.
    // SAFETY: the literal 3 is non-zero.
    #[new(value = "unsafe { NonZeroU16::new_unchecked(3) } ")]
    retries: NonZeroU16,

    /// The maximum size of a part file when downloading in parts.
    ///
    /// Defaults to 2 MiB.
    // SAFETY: 2 * 1024 * 1024 is non-zero.
    #[new(value = "unsafe { NonZeroU32::new_unchecked(2 * 1024 * 1024) }")]
    max_part_size: NonZeroU32,

    /// The time to wait between chunks before giving up.
    ///
    /// The same limit is applied to sending the request itself. Defaults to `None`,
    /// meaning no time limit.
    #[new(default)]
    #[setters(strip_option)]
    timeout: Option<Duration>,

    /// Holds a sender for submitting events to.
    ///
    /// Defaults to `None`, in which case events are discarded.
    #[new(default)]
    #[setters(into)]
    #[setters(strip_option)]
    events: Option<Arc<EventSender>>,
}

impl Default for Fetcher {
    /// Builds a fetcher around a `surf` client backed by the default `NativeClient`.
    fn default() -> Self {
        let backend = NativeClient::default();
        let client = Client::with_http_client(backend);

        Self::new(client)
    }
}

impl Fetcher {
    /// Request a file from one or more URIs.
    ///
    /// At least one URI must be provided as a source for the file. Each additional URI
    /// serves as a mirror for failover and load-balancing purposes.
    pub async fn request(
        self: Arc<Self>,
        uris: Arc<[Box<str>]>,
        to: Arc<Path>,
    ) -> Result<(), Error> {
        match self.clone().inner_request(uris.clone(), to.clone()).await {
            Ok(()) => Ok(()),
            Err(mut why) => {
                for _ in 1..self.retries.get() {
                    match self.clone().inner_request(uris.clone(), to.clone()).await {
                        Ok(()) => return Ok(()),
                        Err(cause) => why = cause,
                    }
                }

                Err(why)
            }
        }
    }

    async fn inner_request(
        self: Arc<Self>,
        uris: Arc<[Box<str>]>,
        to: Arc<Path>,
    ) -> Result<(), Error> {
        let mut modified = None;
        let mut length = None;
        let mut if_modified_since = None;

        // If the file already exists, validate that it is the same.
        if to.exists() {
            if let Some(response) = head(&self.client, &*uris[0]).await? {
                let content_length = response.content_length();
                modified = response.last_modified();

                if let (Some(content_length), Some(last_modified)) = (content_length, modified) {
                    match fs::metadata(to.as_ref()).await {
                        Ok(metadata) => {
                            let modified = metadata.modified().map_err(Error::Write)?;
                            let ts = modified
                                .duration_since(UNIX_EPOCH)
                                .expect("time went backwards");

                            if metadata.len() == content_length
                                && ts.as_secs() == date_as_timestamp(last_modified)
                            {
                                self.send((to, FetchEvent::AlreadyFetched));
                                return Ok(());
                            }

                            if_modified_since = Some(HttpDate::from(modified));
                            length = Some(content_length);
                        }
                        Err(why) => {
                            error!("failed to fetch metadata of {:?}: {}", to, why);
                            fs::remove_file(to.as_ref())
                                .await
                                .map_err(Error::MetadataRemove)?;
                        }
                    }
                }
            }
        }

        // If set, this will use multiple connections to download a file in parts.
        if let Some(connections) = self.connections_per_file {
            if let Some(response) = head(&self.client, &*uris[0]).await? {
                modified = response.last_modified();
                let length = match length {
                    Some(length) => Some(length),
                    None => response.content_length(),
                };

                if let Some(length) = length {
                    if supports_range(&self.client, &*uris[0], length).await? {
                        self.send((to.clone(), FetchEvent::ContentLength(length)));

                        return self
                            .get_many(length, connections.get(), uris, to, modified)
                            .await;
                    }
                }
            }
        }

        let mut request = self.client.get(&*uris[0]).header("Expect", "").build();
        if let Some(modified_since) = if_modified_since {
            request.set_header(
                "if-modified-since",
                httpdate::fmt_http_date(modified_since.into()),
            );
        }

        let path = match self
            .get(&mut modified, request.clone(), to.clone(), to.clone(), None)
            .await
        {
            Ok(path) => path,
            // Server does not support if-modified-since
            Err(Error::Status(StatusCode::NotImplemented)) => {
                let request = self.client.get(&*uris[0]).header("Expect", "").build();
                self.get(&mut modified, request, to.clone(), to, None)
                    .await?
            }
            Err(why) => return Err(why),
        };

        if let Some(modified) = modified {
            let filetime = FileTime::from_unix_time(date_as_timestamp(modified) as i64, 0);
            filetime::set_file_times(&path, filetime, filetime)
                .map_err(move |why| Error::FileTime(path, why))?;
        }

        Ok(())
    }

    async fn get(
        &self,
        modified: &mut Option<HttpDate>,
        request: Request,
        to: Arc<Path>,
        dest: Arc<Path>,
        length: Option<u64>,
    ) -> Result<Arc<Path>, Error> {
        let mut file = File::create(to.as_ref()).await.map_err(Error::FileCreate)?;

        if let Some(length) = length {
            file.set_len(length).await.map_err(Error::Write)?;
        }

        let response = &mut validate(if let Some(duration) = self.timeout {
            timed(
                duration,
                Box::pin(async { self.client.send(request).await.map_err(Error::from) }),
            )
            .await??
        } else {
            self.client.send(request).await?
        })?;

        if modified.is_none() {
            *modified = response.last_modified();
        }

        if response.status() == StatusCode::NotModified {
            return Ok(to);
        }

        let buffer = &mut [0u8; 8 * 1024];
        let mut read;

        loop {
            if self.cancelled() {
                return Err(Error::Cancelled);
            }

            let reader = async { response.read(buffer).await.map_err(Error::Write) };

            read = match self.timeout {
                Some(duration) => timed(duration, Box::pin(reader)).await??,
                None => reader.await?,
            };

            if read != 0 {
                self.send((dest.clone(), FetchEvent::Progress(read)));

                file.write_all(&buffer[..read])
                    .await
                    .map_err(Error::Write)?;
            } else {
                break;
            }
        }

        Ok(to)
    }

    /// Fetches a file of `length` bytes in parts and concatenates them into `to`.
    ///
    /// The byte ranges come from `range::generate(length, max_part_size)`. Each part
    /// is requested with a `range` header from one of `uris`, chosen round-robin by
    /// part index, and written to `<file name>.part<N>` next to `to`. At most
    /// `concurrent` part requests are in flight at once. The resulting stream of
    /// parts is handed to `systems::concatenator` together with the freshly created
    /// destination file.
    ///
    /// If `modified` is known, the destination's timestamps are set to it afterwards.
    ///
    /// # Errors
    ///
    /// Returns `Error::Parentless` or `Error::Nameless` if `to` lacks a parent or a
    /// file name, `Error::FileCreate` if the destination cannot be created,
    /// `Error::FileTime` if the timestamps cannot be applied, and otherwise whatever
    /// `systems::concatenator` reports.
    async fn get_many(
        self: Arc<Self>,
        length: u64,
        concurrent: u16,
        uris: Arc<[Box<str>]>,
        to: Arc<Path>,
        mut modified: Option<HttpDate>,
    ) -> Result<(), Error> {
        let parent = to.parent().ok_or(Error::Parentless)?;
        let filename = to.file_name().ok_or(Error::Nameless)?;

        // Scratch space for formatting the part number into the part file name.
        let mut buf = [0u8; 20];

        // The destination which parts will be concatenated to.
        let concatenated_file = &mut File::create(to.as_ref()).await.map_err(Error::FileCreate)?;

        let max_part_size =
            NonZeroU64::new(self.max_part_size.get() as u64).expect("max part size is 0");

        let to_ = to.clone();
        let parts = stream::iter(range::generate(length, max_part_size).enumerate())
            // Generate a future for fetching each part that a range describes.
            .map(move |(partn, (range_start, range_end))| {
                // Spread the parts across the mirrors in round-robin order.
                let uri = uris[partn % uris.len()].clone();

                // `<file name>.part<N>` in the same directory as the destination.
                let part_path = {
                    let mut new_filename = filename.to_os_string();
                    new_filename.push(&[".part", partn.numtoa_str(10, &mut buf)].concat());
                    parent.join(new_filename)
                };

                let fetcher = self.clone();
                let to = to_.clone();

                // NOTE(review): `modified` is captured by value (it is `Copy`), so an
                // update made by `get` below does not appear to reach the `modified`
                // that is read after concatenation — confirm this is intended.
                async move {
                    let range = range::to_string(range_start, range_end);

                    fetcher.send((to.clone(), FetchEvent::PartFetching(partn as u64)));

                    let request = fetcher
                        .client
                        .get(&*uri)
                        .header("range", range.as_str())
                        .header("Expect", "")
                        .build();

                    let result = fetcher
                        .get(
                            &mut modified,
                            request,
                            part_path.into(),
                            to.clone(),
                            Some(range_end - range_start),
                        )
                        .await;

                    // Sent whether or not the part was fetched successfully.
                    fetcher.send((to, FetchEvent::PartFetched(partn as u64)));

                    result
                }
            })
            // Ensure that only this many connections are happening concurrently at a
            // time
            .buffered(concurrent as usize)
            // This type exploded the stack, and therefore needs to be boxed
            .boxed_local();

        systems::concatenator(concatenated_file, parts).await?;

        // Mirror the server's modification time onto the concatenated file.
        if let Some(modified) = modified {
            let filetime = FileTime::from_unix_time(date_as_timestamp(modified) as i64, 0);
            filetime::set_file_times(&to, filetime, filetime)
                .map_err(|why| Error::FileTime(to, why))?;
        }

        Ok(())
    }

    /// Reports whether a cancellation flag is configured and currently set.
    fn cancelled(&self) -> bool {
        match &self.cancel {
            Some(flag) => flag.load(Ordering::SeqCst),
            None => false,
        }
    }

    /// Forwards `event` to the configured event sender, if there is one.
    ///
    /// Delivery is best-effort: a failure to send is ignored.
    fn send(&self, event: (Arc<Path>, FetchEvent)) {
        let sender = match &self.events {
            Some(sender) => sender,
            None => return,
        };

        let _ = sender.unbounded_send(event);
    }
}

/// Issues a `HEAD` request for `uri`.
///
/// Returns `Ok(None)` when the server answers `501 Not Implemented`, the validated
/// response otherwise, and an error for client failures or any other rejected
/// status.
async fn head(client: &Client, uri: &str) -> Result<Option<Response>, Error> {
    let response = client.head(uri).header("Expect", "").await?;

    match validate(response) {
        Ok(response) => Ok(Some(response)),
        Err(Error::Status(StatusCode::NotImplemented)) => Ok(None),
        Err(other) => Err(other),
    }
}

/// Probes whether the server honours range requests for `uri`.
///
/// A `HEAD` request carrying a `range` header built from `0` and `length` is sent.
/// `206 Partial Content` yields `Ok(true)`; any other acceptable status yields
/// `Ok(false)`, and a rejected status is returned as an error.
async fn supports_range(client: &Client, uri: &str, length: u64) -> Result<bool, Error> {
    let range = range::to_string(0, length);

    let response = client
        .head(uri)
        .header("Expect", "")
        .header("range", range.as_str())
        .await?;

    match response.status() {
        StatusCode::PartialContent => Ok(true),
        _ => validate(response).map(|_| false),
    }
}

async fn timed<F, T>(duration: Duration, future: F) -> Result<T, Error>
where
    F: Future<Output = T> + Unpin,
{
    let timeout = async move {
        async_io::Timer::after(duration).await;
        Err(Error::TimedOut)
    };

    let result = async move { Ok(future.await) };

    futures::pin_mut!(timeout);
    futures::pin_mut!(result);

    futures::future::select(timeout, result)
        .await
        .factor_first()
        .0
}

fn validate(response: Response) -> Result<Response, Error> {
    let status = response.status();

    if status.is_informational() || status.is_success() {
        Ok(response)
    } else {
        Err(Error::Status(status))
    }
}

/// Convenience accessors for response headers used by the fetcher.
trait ResponseExt {
    /// The value of the `content-length` header, if present and parseable as `u64`.
    fn content_length(&self) -> Option<u64>;
    /// The value of the `last-modified` header, if present and a valid HTTP date.
    fn last_modified(&self) -> Option<HttpDate>;
}

impl ResponseExt for Response {
    /// Parses the first `content-length` header value as a `u64`.
    fn content_length(&self) -> Option<u64> {
        self.header("content-length")
            .and_then(|values| values.get(0))
            .and_then(|value| value.as_str().parse::<u64>().ok())
    }

    /// Parses the first `last-modified` header value as an HTTP date.
    fn last_modified(&self) -> Option<HttpDate> {
        let value = self
            .header("last-modified")
            .and_then(|values| values.get(0))?;

        match httpdate::parse_http_date(value.as_str()) {
            Ok(time) => Some(HttpDate::from(time)),
            Err(_) => None,
        }
    }
}

/// Converts an HTTP date into whole seconds since the Unix epoch.
///
/// # Panics
///
/// Panics if the date lies before the Unix epoch.
fn date_as_timestamp(date: HttpDate) -> u64 {
    let time = SystemTime::from(date);
    let since_epoch = time.duration_since(UNIX_EPOCH).expect("time backwards");

    since_epoch.as_secs()
}
