import hashlib
import logging
from dataclasses import dataclass
from operator import add

import click
import colorama
import psycopg2
from tqdm import tqdm
from upend import UpEnd


class LogFormatter(logging.Formatter):
    """Colorize log records by level using colorama escape codes.

    Each level listed in ``FORMATS`` gets its own wrapped format string.
    Levels not listed there (e.g. custom levels) use the uncolored
    ``format_str`` so they still show the timestamp and level name.
    """

    format_str = "[%(asctime)s] %(levelname)s - %(message)s"

    FORMATS = {
        logging.DEBUG: colorama.Fore.LIGHTBLACK_EX + format_str + colorama.Fore.RESET,
        logging.INFO: format_str,
        logging.WARNING: colorama.Fore.YELLOW + format_str + colorama.Fore.RESET,
        logging.ERROR: colorama.Fore.RED + format_str + colorama.Fore.RESET,
        logging.CRITICAL: colorama.Fore.RED
        + colorama.Style.BRIGHT
        + format_str
        + colorama.Style.RESET_ALL
        + colorama.Fore.RESET,
    }

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Build one formatter per level up front instead of one per record.
        # Reading via ``self`` keeps subclass overrides of FORMATS/format_str working.
        self._formatters = {
            level: logging.Formatter(fmt) for level, fmt in self.FORMATS.items()
        }
        # Previously an unknown level produced Formatter(None), i.e. "%(message)s"
        # only; fall back to the plain uncolored format instead.
        self._fallback = logging.Formatter(self.format_str)

    def format(self, record):
        """Format *record* with the formatter matching its level."""
        formatter = self._formatters.get(record.levelno, self._fallback)
        return formatter.format(record)


@dataclass()
class KSXTrackFile:
    """A single KSX trackfile row together with its mood annotation.

    The field order mirrors the column order of the SELECT issued in
    ``main``, so a result row can be unpacked positionally with
    ``KSXTrackFile(*row)``. Do not reorder the fields.
    """

    # ``file`` column of the query result.
    file: str
    # Content hash; matched against the hex digests stored under UpEnd's SHA256 attribute.
    sha256sum: str
    # Mood values (presumably from ksx_radio_moodsregular; the SELECT does not qualify them).
    energy: int
    seriousness: int
    tint: int
    materials: int


def _setup_logger():
    """Return the ``ksx2upend`` logger, attaching a colorized stream handler once."""
    logger = logging.getLogger("ksx2upend")
    logger.setLevel(logging.DEBUG)
    # Guard against stacking duplicate handlers (and duplicated log lines) if
    # main() is invoked more than once in the same process.
    if not logger.handlers:
        handler = logging.StreamHandler()
        handler.setLevel(logging.DEBUG)
        handler.setFormatter(LogFormatter())
        logger.addHandler(handler)
    return logger


def _fetch_trackfiles(cursor):
    """Return all KSX trackfiles that have a matching row in ``ksx_radio_moodsregular``."""
    cursor.execute(
        "SELECT file, sha256sum, energy, seriousness, tint, materials "
        "FROM ksx_radio_trackfile "
        "INNER JOIN ksx_radio_moodsregular ON ksx_radio_trackfile.track_id = ksx_radio_moodsregular.track_id"
    )
    return [KSXTrackFile(*row) for row in cursor.fetchall()]


def _ensure_sha256(upend, all_files, hashed_files, logger):
    """Insert a SHA256 entry for every file in *all_files* that lacks one.

    Returns a freshly queried collection of SHA256 entries if anything was
    hashed, otherwise *hashed_files* unchanged.
    """
    # Compare by entity rather than by count: SHA256 entries on non-file
    # entities (or several entries per entity) would otherwise mask files
    # that still need hashing.
    hashed_entities = {entry["entity"] for entry in hashed_files}
    unhashed_files = [f for f in all_files if f["entity"] not in hashed_entities]
    if not unhashed_files:
        return hashed_files

    logger.info("Computing SHA256 hashes for UpEnd files...")
    for entry in tqdm(unhashed_files):
        digest = hashlib.sha256()
        # Stream the blob chunk by chunk instead of loading it whole.
        for chunk in upend.get_raw(entry["entity"]):
            digest.update(chunk)
        upend.insert((entry["entity"], "SHA256", digest.hexdigest()))
    return upend.query((None, "SHA256", None)).values()


@click.command()
@click.option("--db-name", required=True)
@click.option("--db-user", required=True)
@click.option("--db-password", required=True)
@click.option("--db-host", default="localhost")
@click.option("--db-port", default=5432, type=int)
def main(db_name, db_user, db_password, db_host, db_port):
    """Load KSX database dump into UpEnd.

    Reads mood-annotated trackfiles from PostgreSQL and makes sure every UpEnd
    file has a SHA256 entry. It then tags each UpEnd file whose hash matches a
    trackfile's ``sha256sum`` with the KSX_TRACK_MOODS type and the mood values.
    """
    logger = _setup_logger()

    logger.debug("Connecting to PostgreSQL...")
    connection = psycopg2.connect(
        database=db_name,
        user=db_user,
        password=db_password,
        host=db_host,
        port=db_port,
    )
    try:
        logger.debug("Connecting to UpEnd...")
        upend = UpEnd()
        with connection.cursor() as cur:
            trackfiles = _fetch_trackfiles(cur)
    finally:
        # Only a SELECT is run; release the connection once rows are in memory.
        connection.close()
    logger.info(f"Got {len(trackfiles)} (annotated) trackfiles from database...")

    # TODO: get_invariant() or somesuch?
    blob_entries = list(upend.query((None, "TYPE", 'J"BLOB"')).values())
    if not blob_entries:
        raise click.ClickException(
            "UpEnd has no BLOB type entry; cannot enumerate files."
        )
    blob_addr = blob_entries[0]["entity"]

    all_files = upend.query((None, "IS", f"O{blob_addr}")).values()
    hashed_files = upend.query((None, "SHA256", None)).values()

    logger.info(
        f"Got {len(all_files)} files from UpEnd ({len(hashed_files)} of which are hashed)..."
    )

    hashed_files = _ensure_sha256(upend, all_files, hashed_files, logger)

    # NOTE(review): matching is exact string equality; this assumes the DB stores
    # lowercase hex digests like hashlib's hexdigest() — confirm.
    sha256_trackfiles = {tf.sha256sum: tf for tf in trackfiles}
    sha256_entities = {entry["value"]["c"]: entry["entity"] for entry in hashed_files}

    # Checksums present on both sides (named `digest` to avoid shadowing `sum`).
    common_sums = [digest for digest in sha256_trackfiles if digest in sha256_entities]

    logger.info(
        f"Out of {len(trackfiles)} trackfiles, and out of {len(hashed_files)} files in UpEnd, {len(common_sums)} are present in both."
    )

    logger.info("Inserting types...")
    ksx_type_result = upend.insert((None, "TYPE", "KSX_TRACK_MOODS"))
    ksx_type_addr = list(ksx_type_result.values())[0]["entity"]
    upend.insert((ksx_type_addr, "TYPE_REQUIRES", "KSX_ENERGY"))
    upend.insert((ksx_type_addr, "TYPE_REQUIRES", "KSX_SERIOUSNESS"))
    upend.insert((ksx_type_addr, "TYPE_REQUIRES", "KSX_TINT"))
    upend.insert((ksx_type_addr, "TYPE_REQUIRES", "KSX_MATERIALS"))

    logger.info("Inserting mood data...")
    for digest in tqdm(common_sums):
        tf = sha256_trackfiles[digest]
        address = sha256_entities[digest]

        upend.insert((address, "IS", ksx_type_addr), value_type="Address")
        upend.insert((address, "KSX_ENERGY", tf.energy))
        upend.insert((address, "KSX_SERIOUSNESS", tf.seriousness))
        upend.insert((address, "KSX_TINT", tf.tint))
        upend.insert((address, "KSX_MATERIALS", tf.materials))


# Script entry point: main is a click command, so click parses the
# command-line options when it is invoked without arguments.
if __name__ == "__main__":
    main()
