use anyhow::{Context, Result};
use argh::FromArgs;
use gawires_diff::DiffParams;
use gawires_patch;
use comde::{Compressor, Decompressor};
use crossbeam_utils::thread;
use log::*;
use size::Size;
use std::{
    fs::{self, File},
    io::{self, BufReader, BufWriter, Read, Seek, Write},
    path::PathBuf,
    str::FromStr,
    time::Instant,
};

/// Generate and apply binary patches
// Top-level command-line arguments, parsed in `main` with `argh::from_env`.
// Explanatory notes in this struct are plain `//` comments on purpose: `///`
// text on argh types is used for the generated help output, and that output
// should not change.
#[derive(FromArgs, PartialEq, Debug)]
struct Cli {
    // The selected subcommand together with its own arguments.
    #[argh(subcommand)]
    cmd: Command,
}

// The subcommands this tool accepts. `main` matches on the variant and hands
// the payload to the corresponding `do_*` function.
#[derive(FromArgs, PartialEq, Debug)]
#[argh(subcommand)]
enum Command {
    // `diff` subcommand, handled by `do_diff`.
    Diff(Diff),
    // `patch` subcommand, handled by `do_patch`.
    Patch(Patch),
    // `cycle` subcommand, handled by `do_cycle`.
    Cycle(Cycle),
}

/// Write the diff of two files to a patch file
// Arguments of the `diff` subcommand, consumed by `do_diff`.
// The three positional arguments are taken in declaration order:
// `older`, `newer`, `patch`. Do not reorder the fields.
#[derive(FromArgs, PartialEq, Debug)]
#[argh(subcommand, name = "diff")]
struct Diff {
    // Path of the original file; read fully into memory by `do_diff`.
    #[argh(positional)]
    older: PathBuf,
    // Path of the updated file; read fully into memory by `do_diff`.
    #[argh(positional)]
    newer: PathBuf,
    // Path of the patch file to create (the compressed diff is written here).
    #[argh(positional)]
    patch: PathBuf,
    /// number of partitions
    // Forwarded as the first argument of `DiffParams::new`.
    #[argh(option, default = "1")]
    sort_partitions: usize,
    /// compression method to use
    // Parsed through `Method`'s `FromStr` impl; applied to the diff stream
    // before it is written to `patch`.
    #[argh(option, default = "Method::Stored")]
    method: Method,
    /// optionally specify a chunk size
    // Forwarded as the second argument of `DiffParams::new`; `None` when the
    // option is not given.
    #[argh(option)]
    scan_chunk_size: Option<usize>,
}

/// Apply a patch file generated by this tool
// Arguments of the `patch` subcommand, consumed by `do_patch`.
// The three positional arguments are taken in declaration order:
// `older`, `patch`, `output`. Do not reorder the fields.
#[derive(FromArgs, PartialEq, Debug)]
#[argh(subcommand, name = "patch")]
struct Patch {
    // Path of the original file the patch is applied to.
    #[argh(positional)]
    older: PathBuf,
    // Path of the (compressed) patch file to read.
    #[argh(positional)]
    patch: PathBuf,
    // Path of the file to create with the patched result.
    #[argh(positional)]
    output: PathBuf,
    /// compression method to use
    // Used to decompress the patch file. NOTE(review): nothing visible in this
    // file records the method inside the patch, so this presumably has to match
    // the method that was passed to `diff` -- confirm.
    #[argh(option, default = "Method::Stored")]
    method: Method,
}

/// Cycle
// Arguments of the `cycle` subcommand, consumed by `do_cycle`, which diffs
// `older` against `newer` in memory, applies the resulting patch to `older`
// again, checks the result against `newer`, and prints size/timing figures.
// The positional arguments are taken in declaration order: `older`, `newer`.
#[derive(FromArgs, PartialEq, Debug)]
#[argh(subcommand, name = "cycle")]
struct Cycle {
    // Path of the original file; read fully into memory.
    #[argh(positional)]
    older: PathBuf,
    // Path of the updated file; read fully into memory.
    #[argh(positional)]
    newer: PathBuf,
    /// number of partitions
    // Forwarded as the first argument of `DiffParams::new`.
    #[argh(option, default = "1")]
    sort_partitions: usize,
    /// compression method to use
    // Used both to compress the generated patch and to decompress it again.
    #[argh(option, default = "Method::Stored")]
    method: Method,
    /// optionally specify a chunk size
    // Forwarded as the second argument of `DiffParams::new`; `None` when the
    // option is not given.
    #[argh(option)]
    scan_chunk_size: Option<usize>,
}

/// Compression method applied to the patch stream.
///
/// Each variant selects the matching `comde` compressor/decompressor pair and
/// is chosen on the command line by its lowercase name (see the `FromStr`
/// impl).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Method {
    /// Handled by `comde::stored`; command-line name `stored`.
    Stored,
    /// Handled by `comde::deflate`; command-line name `deflate`.
    Deflate,
    /// Handled by `comde::brotli`; command-line name `brotli`.
    Brotli,
    /// Handled by `comde::snappy`; command-line name `snappy`.
    Snappy,
    /// Handled by `comde::zstd`; command-line name `zstd`.
    Zstd,
}

impl Default for Method {
    /// The default method is [`Method::Stored`].
    fn default() -> Method {
        Method::Stored
    }
}

impl Method {
    fn compress<W: Write + Seek, R: Read>(
        self,
        writer: &mut W,
        reader: &mut R,
    ) -> io::Result<comde::ByteCount> {
        match self {
            Self::Stored => comde::stored::StoredCompressor::new().compress(writer, reader),
            Self::Deflate => comde::deflate::DeflateCompressor::new().compress(writer, reader),
            Self::Brotli => comde::brotli::BrotliCompressor::new().compress(writer, reader),
            Self::Snappy => comde::snappy::SnappyCompressor::new().compress(writer, reader),
            Self::Zstd => comde::zstd::ZstdCompressor::new().compress(writer, reader),
        }
    }

    fn decompress<W: Write, R: Read>(self, reader: R, writer: W) -> io::Result<u64> {
        match self {
            Self::Stored => comde::stored::StoredDecompressor::new().copy(reader, writer),
            Self::Deflate => comde::deflate::DeflateDecompressor::new().copy(reader, writer),
            Self::Brotli => comde::brotli::BrotliDecompressor::new().copy(reader, writer),
            Self::Snappy => comde::snappy::SnappyDecompressor::new().copy(reader, writer),
            Self::Zstd => comde::zstd::ZstdDecompressor::new().copy(reader, writer),
        }
    }
}

impl FromStr for Method {
    type Err = String;

    /// Parses the lowercase command-line name of a compression method.
    ///
    /// Accepted names are `stored`, `deflate`, `brotli`, `snappy` and `zstd`;
    /// the comparison is exact (case-sensitive). Any other input yields an
    /// error string that echoes the rejected value.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Name -> variant table; looked up with an exact string comparison.
        const NAMES: [(&str, Method); 5] = [
            ("stored", Method::Stored),
            ("deflate", Method::Deflate),
            ("brotli", Method::Brotli),
            ("snappy", Method::Snappy),
            ("zstd", Method::Zstd),
        ];

        NAMES
            .iter()
            .find(|(name, _)| *name == s)
            .map(|&(_, method)| method)
            .ok_or_else(|| format!("Unknown compression method {}", s))
    }
}

fn main() -> Result<()> {
    #[cfg(debug_assertions)]
    std::env::set_var("RUST_BACKTRACE", "1");

    env_logger::builder().init();

    let Cli { cmd } = argh::from_env();
    match cmd {
        Command::Diff(args) => {
            do_diff(&args)?;
        }
        Command::Patch(args) => {
            do_patch(&args)?;
        }
        Command::Cycle(args) => {
            do_cycle(&args)?;
        }
    }

    Ok(())
}

fn do_cycle(
    Cycle {
        older,
        newer,
        method,
        sort_partitions,
        scan_chunk_size,
    }: &Cycle,
) -> Result<()> {
    info!("Reading older and newer in memory...");
    let (older, newer) = (fs::read(older)?, fs::read(newer)?);

    info!(
        "Before {}, After {}",
        Size::Bytes(older.len()),
        Size::Bytes(newer.len()),
    );

    let mut compatch = Vec::new();
    let before_diff = Instant::now();

    {
        let mut compatch_w = io::Cursor::new(&mut compatch);

        let (mut patch_r, mut patch_w) = pipe::pipe();
        thread::scope(|s| {
            s.spawn(|_| {
                gawires_diff::simple_diff_with_params(
                    &older[..],
                    &newer[..],
                    &mut patch_w,
                    &DiffParams::new(*sort_partitions, *scan_chunk_size).unwrap(),
                )
                .context("simple diff with params")
                .unwrap();
                // this is important for `.compress()` to finish.
                // since we're using scoped threads, it's never dropped
                // otherwise.
                drop(patch_w);
            });
            method
                .compress(&mut compatch_w, &mut patch_r)
                .context("compress")
                .unwrap();
        })
        .unwrap();
    }

    let diff_duration = before_diff.elapsed();

    let ratio = (compatch.len() as f64) / (newer.len() as f64);

    let mut fresh = Vec::new();
    let before_patch = Instant::now();
    {
        let mut older = io::Cursor::new(&older[..]);

        let (patch_r, patch_w) = pipe::pipe();

        thread::scope(|s| {
            s.spawn(|_| {
                method
                    .decompress(&compatch[..], patch_w)
                    .context("decompress")
                    .unwrap();
            });

            let mut r = gawires_patch::Reader::new(patch_r, &mut older)
                .context("read patch")
                .unwrap();
            let fresh_size = io::copy(&mut r, &mut fresh).unwrap();

            assert_eq!(fresh_size as usize, newer.len());
        })
        .unwrap();
    }
    let patch_duration = before_patch.elapsed();

    let newer_hash = hmac_sha256::Hash::hash(&newer[..]);
    let fresh_hash = hmac_sha256::Hash::hash(&fresh[..]);

    anyhow::ensure!(newer_hash == fresh_hash, "Hash mismatch!");

    let cm = format!("{:?}", method);
    let cp = format!("patch {}", Size::Bytes(compatch.len()));
    let cr = format!("{:03.3}% of {}", ratio * 100.0, Size::Bytes(newer.len()));
    let cdd = format!("dtime {:?}", diff_duration);
    let cpd = format!("ptime {:?}", patch_duration);
    println!("{:12} {:20} {:27} {:20} {:20}", cm, cp, cr, cdd, cpd);

    Ok(())
}

fn do_patch(
    Patch {
        older,
        patch,
        output,
        method,
    }: &Patch,
) -> Result<()> {
    println!("Using method {:?}", method);
    let start = Instant::now();

    let compatch_r = BufReader::new(File::open(patch).context("open patch file")?);
    let (patch_r, patch_w) = pipe::pipe();
    let method = *method;

    std::thread::spawn(move || {
        method
            .decompress(compatch_r, patch_w)
            .context("decompress")
            .unwrap();
    });

    let older_r = File::open(older)?;
    let mut fresh_r = gawires_patch::Reader::new(patch_r, older_r).context("read patch")?;
    let mut output_w = BufWriter::new(File::create(output).context("create patch file")?);
    io::copy(&mut fresh_r, &mut output_w).context("write output file")?;

    info!("Completed in {:?}", start.elapsed());

    Ok(())
}

fn do_diff(
    Diff {
        older,
        newer,
        patch,
        method,
        sort_partitions,
        scan_chunk_size,
    }: &Diff,
) -> Result<()> {
    println!("Using method {:?}", method);
    let start = Instant::now();

    let older_contents = fs::read(older).context("read old file")?;
    let newer_contents = fs::read(newer).context("read new file")?;

    let (mut patch_r, mut patch_w) = pipe::pipe();
    let diff_params = DiffParams::new(*sort_partitions, *scan_chunk_size).unwrap();
    std::thread::spawn(move || {
        gawires_diff::simple_diff_with_params(
            &older_contents[..],
            &newer_contents[..],
            &mut patch_w,
            &diff_params,
        )
        .context("simple diff with params")
        .unwrap();
    });

    let mut compatch_w = BufWriter::new(File::create(patch).context("create patch file")?);
    method
        .compress(&mut compatch_w, &mut patch_r)
        .context("write output file")?;
    compatch_w.flush().context("finish writing output file")?;

    info!("Completed in {:?}", start.elapsed());

    Ok(())
}
