#!/usr/bin/env python3

import functools
import os
import os.path
import re
import subprocess
import sys
import tempfile
import threading
import time
from typing import Callable, List

# Every function decorated with @test, in definition order.
all_tests: List[Callable] = list()

# The names below are only declared here; the `__main__` block assigns them
# before any test runs.
home_path: str  # absolute root of the bitbottle checkout
data_path: str  # directory holding test fixtures (SSH keys, sample bottles)
bitbottle: str  # path of the release `bitbottle` binary
unbottle: str  # path of the release `unbottle` binary


def exec(*args: str, **kw) -> subprocess.CompletedProcess:
    """Echo and run a command; abort the whole test run if it fails.

    ``args`` is the command's argv (passed as a list, no shell involved).
    Other keyword arguments are forwarded to ``subprocess.run``.

    One extra keyword is understood: ``delay`` (seconds). When present, the
    ``input`` text (optional, default empty) is not given to
    ``subprocess.run`` directly. Instead, a helper thread writes it to the
    child's stdin through a pipe after ``delay`` seconds and then closes the
    pipe, so the child sees end-of-file.

    Returns the ``subprocess.CompletedProcess`` on success. On a non-zero exit
    it prints ``*** Fail.`` and calls ``sys.exit`` with the child's return code.
    """
    print("+", " ".join(args))
    writer = None
    read_fd = None
    if "delay" in kw:
        delay = kw.pop("delay")
        # `input` is optional alongside `delay` (a bare `del` raised KeyError).
        payload: str = kw.pop("input", None) or ""
        read_fd, write_fd = os.pipe()
        kw["stdin"] = read_fd

        def feed_stdin() -> None:
            try:
                time.sleep(delay)
                os.write(write_fd, payload.encode("utf-8"))
            finally:
                # Closing the write end delivers EOF; otherwise a child that
                # reads stdin until EOF would block forever.
                os.close(write_fd)

        writer = threading.Thread(target=feed_stdin)
        writer.start()
    try:
        rv = subprocess.run(args, **kw)
    finally:
        # Wait for the writer before closing our read end, so it never writes
        # into a pipe with no readers; then release the parent's copy of the fd.
        if writer is not None:
            writer.join()
        if read_fd is not None:
            os.close(read_fd)
    if rv.returncode != 0:
        print("*** Fail.")
        sys.exit(rv.returncode)
    return rv

def test(f: Callable) -> Callable:
    """Register ``f`` as a test case and return the wrapped version.

    The wrapper runs ``f`` with a fresh temporary directory as the current
    working directory. The directory is deleted afterwards. The working
    directory that was current before the test is restored even if ``f``
    raises (including ``SystemExit`` from ``exec``). The wrapper is appended
    to ``all_tests`` so tests run in definition order.
    """
    @functools.wraps(f)
    def new_f():
        previous_cwd = os.getcwd()
        with tempfile.TemporaryDirectory() as temp_path:
            os.chdir(temp_path)
            try:
                f()
            finally:
                # Step out of the temp dir before TemporaryDirectory removes it,
                # even on failure; removing the cwd fails on some platforms.
                os.chdir(previous_cwd)
    all_tests.append(new_f)
    return new_f

@test
def test_basic():
    "archive the source folder and ensure it expands identically"
    # Round-trip `src` once per compression mode: build, verify, extract, diff.
    modes = ("--no-compress", "--snappy", "--lzma2")
    for index, mode in enumerate(modes, start=1):
        archive = f"./test{index}.bb"
        exec(bitbottle, mode, "-o", archive, "-C", home_path, "src")
        exec(unbottle, "--check", archive)
        exec(unbottle, "-d", f"./test{index}", archive)
        exec("diff", "-r", f"{home_path}/src", f"./test{index}/src")
    # Each mode in `modes` is expected to produce a strictly smaller archive
    # than the one before it.
    sizes = [os.stat(f"./test{n}.bb").st_size for n in (1, 2, 3)]
    assert sizes[0] > sizes[1]
    assert sizes[1] > sizes[2]

@test
def test_hashes():
    "other hash functions are fine"
    # Round-trip `src` with each non-default hash option.
    for n, hash_flag in enumerate(("--blake2", "--sha256"), 1):
        bottle_file = f"./test{n}.bb"
        out_dir = f"./test{n}"
        exec(bitbottle, hash_flag, "-o", bottle_file, "-C", home_path, "src")
        exec(unbottle, "--check", bottle_file)
        exec(unbottle, "-d", out_dir, bottle_file)
        exec("diff", "-r", f"{home_path}/src", f"{out_dir}/src")

@test
def test_encryption():
    "encrypt the source folder and ensure it expands identically"
    src = f"{home_path}/src"
    # `mode` is an extra bitbottle option applied to both encryption styles in this pass.
    for index, mode in enumerate([ "--no-compress", "--aes" ], start=1):
        # Each pass gets its own archives and output directories, so a later
        # pass's diff can never be satisfied by files left by an earlier pass.
        pass_bb = f"./test-pass{index}.bb"
        pass_dir = f"./test-pass{index}"
        ssh_bb = f"./test-ssh{index}.bb"
        ssh_dir = f"./test-ssh{index}"

        # Password-based encryption; the password is fed on stdin.
        exec(bitbottle, mode, "--password", "-o", pass_bb, "-C", home_path, "src", text = True, input = "hello\n")
        exec(unbottle, "--password", "-d", pass_dir, pass_bb, text = True, input = "hello\n")
        exec("diff", "-r", src, f"{pass_dir}/src")

        # Public-key encryption with the fixture SSH key pair.
        exec(bitbottle, mode, "--pub", f"{data_path}/test-key.pub", "-o", ssh_bb, "-C", home_path, "src")
        exec(unbottle, "--secret", f"{data_path}/test-key", "-d", ssh_dir, ssh_bb)
        exec("diff", "-r", src, f"{ssh_dir}/src")
@test
def test_ssh_password_encryption():
    "encrypt using an SSH key protected by a password"
    # Secret key path; the matching public key sits next to it with a `.pub` suffix.
    secret_key = f"{data_path}/test-key-pw"
    archive, out_dir = "./test-ssh.bb", "./test-ssh"
    exec(bitbottle, "--pub", secret_key + ".pub", "-o", archive, "-C", home_path, "src")
    # The key's passphrase is supplied on stdin.
    exec(unbottle, "--secret", secret_key, "-d", out_dir, archive, text=True, input="password\n")
    exec("diff", "-r", f"{home_path}/src", out_dir + "/src")

@test
def test_relative_paths():
    "ensure absolute paths are stored as relative"
    exec(bitbottle, "-o", "test.bb", f"{home_path}/src")
    info = exec(unbottle, "--info", "test.bb", capture_output = True, text = True)
    output: List[str] = info.stderr.split("\n")
    # Plain substring test. The previous regex left `.` unescaped, so it also
    # matched any character there. The leading space presumably separates the
    # path column, meaning the stored path begins with `src/` and has no
    # absolute prefix.
    assert any(" src/file_atlas.rs" in line for line in output)
    exec(unbottle, "-d", "./test-rel", "test.bb")
    exec("diff", "-r", f"{home_path}/src", "./test-rel/src")

@test
def test_defang_unbottle():
    "defang bad paths when unbottling"
    out_dir = "./test-defang"
    exec(unbottle, "-d", out_dir, f"{data_path}/message.bb")
    # The fixture presumably stores an unsafe path (e.g. absolute /tmp/...);
    # it must land inside the output directory instead.
    expected = out_dir + "/tmp/message.jpg"
    assert os.path.exists(expected)

@test
def test_zero_length_file():
    "don't write contents for a zero-length file"
    # `with` closes (and flushes) each file before bitbottle reads it, instead
    # of relying on garbage collection of an unreferenced file object.
    with open("zero.txt", "wb") as fh:
        fh.write(b"")
    with open("one.txt", "wb") as fh:
        fh.write(b"x")
    exec(bitbottle, "-o", "test.bb", "zero.txt", "one.txt")
    dump = exec(unbottle, "--dump", "test.bb", capture_output = True, text = True)
    output: List[str] = dump.stderr.split("\n")
    # only one bottle block: the one.txt file
    assert sum(1 for line in output if "Bitbottle type 3:" in line) == 1
    exec(unbottle, "-d", "test-out", "test.bb")
    # The empty file must still be recreated on extraction.
    exec("test", "-f", "test-out/zero.txt")

@test
def test_bad_symlink():
    "don't archive a bad symlink"
    exec("mkdir", "-p", "inner")
    # Close the file deterministically before bitbottle scans the directory.
    with open("inner/good.txt", "wb") as fh:
        fh.write(b"")
    # Resolves to a path outside the archived `inner` tree.
    exec("ln", "-s", "../../foo", "inner/bad")
    exec(bitbottle, "-v", "-o", "test.bb", "inner")
    exec(unbottle, "-d", "test-out", "test.bb")
    assert(os.path.exists("test-out/inner/good.txt"))
    assert(not os.path.exists("test-out/inner/bad"))

@test
def test_good_symlink():
    "archive a good symlink"
    exec("mkdir", "-p", "inner")
    exec("mkdir", "-p", "inner/deep")
    # Close the file deterministically before bitbottle scans the directory.
    with open("inner/good.txt", "wb") as fh:
        fh.write(b"good")
    # Relative link that stays inside the archived tree (-> inner/good.txt).
    exec("ln", "-s", "../good.txt", "inner/deep/ok.txt")
    exec(bitbottle, "-v", "-o", "test.bb", "inner")
    exec(unbottle, "--info", "test.bb")
    exec(unbottle, "-d", "test-out", "test.bb")
    assert(os.path.exists("test-out/inner/good.txt"))
    assert(os.path.exists("test-out/inner/deep/ok.txt"))
    with open("test-out/inner/good.txt", "rb") as fh:
        assert(fh.read() == b"good")
    # Reading through the extracted link must yield the target's contents.
    with open("test-out/inner/deep/ok.txt", "rb") as fh:
        assert(fh.read() == b"good")


if __name__ == "__main__":
    # move to bitbottle home (the parent of this script's directory), and
    # derive every path from it so the run does not depend on the launch cwd
    home_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
    os.chdir(home_path)
    data_path = os.path.join(home_path, "tests/data")
    bitbottle = os.path.join(home_path, "target/release/bitbottle")
    unbottle = os.path.join(home_path, "target/release/unbottle")
    # `run_test` rather than `test`: reusing that name would rebind the
    # module-level `test` decorator.
    for run_test in all_tests:
        print()
        print(f"\x1b[1m*** {run_test.__doc__ or '?'}\x1b[0m")
        print()
        run_test()
    print()
    print("\x1b[1m*** ALL TESTS PASS. :)\x1b[0m")
    print()
