extern crate alloc;

use alloc::boxed::Box;
use alloc::collections::BTreeMap;
use alloc::rc::Rc;
use alloc::vec;
use alloc::vec::Vec;
use core::cell::RefCell;
use cortex_m::interrupt::free;
use embedded_hal::digital::v2::OutputPin;
use palette::{FromColor, Hsv, Mix, Srgb};
use stm32f1xx_hal_bxcan::pac::{NVIC, SPI2, TIM2};
use stm32f1xx_hal_bxcan::prelude::*;
use stm32f1xx_hal_bxcan::rcc::{Clocks, APB1};
use stm32f1xx_hal_bxcan::spi::{Mode, NoMiso, Phase, Polarity, Spi};
use stm32f1xx_hal_bxcan::stm32::{interrupt, Interrupt};
use stm32f1xx_hal_bxcan::timer::{CountDownTimer, Event, Timer};

use ross_config::config::Config;
use ross_config::peripheral::{BcmPeripheral, Peripheral};
use ross_logger::{log_debug, log_info, log_warning, Logger};
use ross_protocol::convert_packet::ConvertPacket;
use ross_protocol::event::bcm::*;

use crate::helper::type_helper::{BcmClockPin, BcmDataPin, BcmLatchPin, BcmSpi, CanProtocol};
use crate::module::*;

/// Number of BCM output channels. The TIM2 interrupt shifts them out as
/// `CHANNEL_COUNT / 8` bytes per bit plane, so channels beyond a multiple of 8
/// would never be shifted out.
const CHANNEL_COUNT: usize = 24;
/// TIM2 frequency used for the least-significant bit plane; the interrupt handler
/// halves it for every following bit plane, doubling that plane's on-time.
const INITIAL_FREQUENCY_HZ: u32 = 65_536;

// Timer for binary code modulation (TIM2). Stored by `BcmModule::new` before the TIM2
// interrupt is unmasked; re-armed by the `TIM2` handler for every bit plane.
static mut TIMER: Option<CountDownTimer<TIM2>> = None;
// Timer frequency in Hertz for the bit plane currently displayed. Halved by the `TIM2`
// handler on each step and reset to `INITIAL_FREQUENCY_HZ` after the last bit plane.
static mut CURRENT_FREQUENCY_HZ: u32 = INITIAL_FREQUENCY_HZ;
// Single-bit mask selecting the bit plane currently displayed (0x01 = least significant).
static mut CURRENT_BITMASK: u8 = 0x01;
// Per-channel 8-bit brightness. Written from thread mode through
// `BcmModule::channel_brightness`, read by the `TIM2` handler to build each bit plane.
// NOTE(review): the module keeps a `&'static mut` to this array while the handler creates
// a `&` to it; that aliasing is formally UB — consider a critical-section Mutex instead.
static mut CHANNEL_BRIGHTNESS: [u8; CHANNEL_COUNT] = [0x00; CHANNEL_COUNT];
// Latch pin: driven low before and high after every bit-plane frame by the `TIM2` handler.
static mut LATCH_PIN: Option<BcmLatchPin> = None;
// SPI2 bus that shifts the bit-plane bytes out; stored once by `BcmModule::new`.
static mut SPI: Option<BcmSpi> = None;

/// A brightness animation queued for one BCM peripheral.
///
/// `start_time` and `start_value` are `None` until the first `tick` after the animation
/// is queued; that tick captures the current time and the peripheral's current brightness.
#[derive(Debug)]
struct BcmAnimation {
    /// Index of the animated peripheral (looked up as `index as u32` in the peripheral map).
    pub index: u8,
    /// Time the animation started, in the units of `tick`'s `current_time`.
    pub start_time: Option<u32>,
    /// Length of the animation, in the same units as `start_time`.
    pub duration: u32,
    /// Brightness of the peripheral when the animation started.
    pub start_value: Option<BcmValue>,
    /// Brightness the peripheral reaches at the end of the animation.
    pub target_value: BcmValue,
}

/// Everything `BcmModule::new` needs to set up the BCM driver; passed in wrapped as
/// `ModuleConfig::BcmModule`.
pub struct BcmModuleConfig<'a, 'b> {
    /// CAN protocol on which `new` registers the brightness packet handlers.
    pub protocol: Rc<RefCell<CanProtocol<'a>>>,
    /// Logger shared by the module and its packet handlers.
    pub logger: &'a RefCell<Logger>,
    /// SPI2 clock pin.
    pub clock_pin: BcmClockPin,
    /// Latch pin, moved into the `LATCH_PIN` static and toggled around every bit-plane frame.
    pub latch_pin: BcmLatchPin,
    /// SPI2 data output pin (no MISO pin is used).
    pub data_pin: BcmDataPin,
    /// SPI2 peripheral, configured by `new` at 8 MHz, idle-low clock, capture on first edge.
    pub spi2: SPI2,
    /// Clock configuration used to set up SPI2 and TIM2.
    pub clocks: Clocks,
    /// APB1 bus handle needed to enable SPI2 and TIM2.
    pub apb1: &'b mut APB1,
    /// Timer peripheral whose update interrupt drives the bit-plane timing.
    pub tim2: TIM2,
    /// Device configuration; `new` drains all `Peripheral::Bcm` entries out of its
    /// `peripherals` map.
    pub config: Rc<RefCell<Config>>,
}

/// Binary code modulation (BCM) LED driver module.
///
/// Brightness values are written into the `CHANNEL_BRIGHTNESS` static, which the `TIM2`
/// interrupt handler shifts out one bit plane at a time. Brightness changes and animations
/// arrive as CAN packets; animations are advanced by `tick`.
pub struct BcmModule<'a> {
    /// BCM peripherals taken from the device config, keyed by peripheral index.
    peripherals: BTreeMap<u32, BcmPeripheral>,
    /// Logger shared with the packet handlers.
    logger: &'a RefCell<Logger>,
    /// Reference to the `CHANNEL_BRIGHTNESS` static that the `TIM2` handler reads.
    channel_brightness: &'static mut [u8; CHANNEL_COUNT],
    /// Queued and running animations, in arrival order.
    animations: Vec<BcmAnimation>,
}

impl<'a> BcmModule<'a> {
    /// Reads the brightness currently stored for the peripheral registered under `index`.
    ///
    /// Returns `None` when no BCM peripheral is registered under `index`, or when the
    /// peripheral references a channel outside `channel_brightness` (a configuration
    /// error that must not panic the firmware).
    fn get_brightness_parts(
        index: u8,
        peripherals: &BTreeMap<u32, BcmPeripheral>,
        channel_brightness: &[u8; CHANNEL_COUNT],
    ) -> Option<BcmValue> {
        // Checked lookup: an out-of-range channel yields `None` instead of panicking.
        let read = |channel: u8| channel_brightness.get(channel as usize).copied();

        let value = match peripherals.get(&(index as u32)).copied()? {
            BcmPeripheral::Single(channel) => BcmValue::Single(read(channel)?),
            BcmPeripheral::Rgb(r_channel, g_channel, b_channel) => {
                BcmValue::Rgb(read(r_channel)?, read(g_channel)?, read(b_channel)?)
            }
            BcmPeripheral::Rgbw(r_channel, g_channel, b_channel, w_channel) => BcmValue::Rgbw(
                read(r_channel)?,
                read(g_channel)?,
                read(b_channel)?,
                read(w_channel)?,
            ),
        };

        Some(value)
    }

    /// Applies `value` to the peripheral registered under `index`.
    ///
    /// Mutably borrows `module` for the duration of the call (panics if it is already
    /// borrowed).
    fn change_brightness(module: Rc<RefCell<Self>>, index: u8, value: &BcmValue) {
        let borrowed_module: &mut Self = &mut module.borrow_mut();
        let logger = borrowed_module.logger;
        let channel_brightness = &mut borrowed_module.channel_brightness;
        let peripherals = &borrowed_module.peripherals;

        Self::change_brightness_parts(index, value, logger, channel_brightness, peripherals);
    }

    /// Writes `value` into the channels of the peripheral registered under `index`.
    ///
    /// Logs a warning and changes nothing when `index` is unknown or when the value's
    /// colour model (single / RGB / RGBW) differs from the peripheral's.
    fn change_brightness_parts(
        index: u8,
        value: &BcmValue,
        logger: &RefCell<Logger>,
        channel_brightness: &mut [u8; CHANNEL_COUNT],
        peripherals: &BTreeMap<u32, BcmPeripheral>,
    ) {
        let mut set_channel = |channel: u8, brightness: u8| {
            Self::change_channel_brightness_parts(logger, channel_brightness, channel, brightness)
        };

        match (peripherals.get(&(index as u32)).copied(), value) {
            (Some(BcmPeripheral::Single(channel)), BcmValue::Single(brightness)) => {
                set_channel(channel, *brightness);
            }
            (
                Some(BcmPeripheral::Rgb(r_channel, g_channel, b_channel)),
                BcmValue::Rgb(r_value, g_value, b_value),
            ) => {
                for &(channel, brightness) in &[
                    (r_channel, *r_value),
                    (g_channel, *g_value),
                    (b_channel, *b_value),
                ] {
                    set_channel(channel, brightness);
                }
            }
            (
                Some(BcmPeripheral::Rgbw(r_channel, g_channel, b_channel, w_channel)),
                BcmValue::Rgbw(r_value, g_value, b_value, w_value),
            ) => {
                for &(channel, brightness) in &[
                    (r_channel, *r_value),
                    (g_channel, *g_value),
                    (b_channel, *b_value),
                    (w_channel, *w_value),
                ] {
                    set_channel(channel, brightness);
                }
            }
            // Unknown index, or a value whose colour model does not match the peripheral.
            _ => {
                log_warning!(
                    logger,
                    "Received index ({:?}) does not match peripherals.",
                    index
                );
            }
        }
    }

    /// Stores `brightness` for a single `channel`.
    ///
    /// Logs a warning and leaves the array untouched if `channel` is out of range.
    fn change_channel_brightness_parts(
        logger: &RefCell<Logger>,
        channel_brightness: &mut [u8; CHANNEL_COUNT],
        channel: u8,
        brightness: u8,
    ) {
        if channel as usize >= CHANNEL_COUNT {
            log_warning!(
                logger,
                "Channel index ({}) exceeds channel count ({}).",
                channel,
                CHANNEL_COUNT
            );
            // Bail out: indexing with this channel below would panic.
            return;
        }

        // The array is also read by the TIM2 interrupt, so write it with interrupts disabled.
        free(|_cs| {
            channel_brightness[channel as usize] = brightness;
        });
    }
}

impl<'a> Module<'a> for BcmModule<'a> {
    fn new<'b>(config: ModuleConfig<'a, 'b>) -> Rc<RefCell<Self>> {
        let config = match config {
            ModuleConfig::BcmModule(config) => config,
            _ => {
                panic!("Wrong config provided for bcm module.");
            }
        };

        let protocol = config.protocol;
        let logger = config.logger;
        let clock_pin = config.clock_pin;
        let latch_pin = config.latch_pin;
        let data_pin = config.data_pin;
        let spi2 = config.spi2;
        let clocks = config.clocks;
        let apb1 = config.apb1;
        let tim2 = config.tim2;

        let peripherals = config
            .config
            .borrow_mut()
            .peripherals
            .drain_filter(|_, peripheral| matches!(peripheral, Peripheral::Bcm(_)))
            .filter_map(|peripheral| {
                if let (index, Peripheral::Bcm(peripheral)) = peripheral {
                    Some((index, peripheral))
                } else {
                    None
                }
            })
            .collect();

        let channel_brightness = unsafe { &mut CHANNEL_BRIGHTNESS };

        let module = Rc::new(RefCell::new(Self {
            peripherals,
            logger,
            channel_brightness,
            animations: vec![],
        }));

        let spi = {
            let pins = (clock_pin, NoMiso, data_pin);

            let spi_mode = Mode {
                polarity: Polarity::IdleLow,
                phase: Phase::CaptureOnFirstTransition,
            };

            Spi::spi2(spi2, pins, spi_mode, 8_000_000.hz(), clocks, apb1)
        };

        unsafe {
            LATCH_PIN = Some(latch_pin);
            SPI = Some(spi);
        }

        let mut timer =
            Timer::tim2(tim2, &clocks, apb1).start_count_down(INITIAL_FREQUENCY_HZ.hz());
        timer.listen(Event::Update);

        unsafe {
            TIMER = Some(timer);

            // For binary code modulation
            NVIC::unmask(Interrupt::TIM2);
        }

        let module_clone = Rc::clone(&module);
        protocol
            .borrow_mut()
            .add_packet_handler(
                Box::new(move |packet, _protocol| {
                    if let Ok(event) = BcmChangeBrightnessEvent::try_from_packet(&packet) {
                        log_debug!(
                            logger,
                            "Received `bcm_change_brightness_event` ({:?}).",
                            event
                        );

                        module_clone.borrow_mut().animations.retain(|animation| {
                            if animation.index == event.index {
                                log_info!(
                                    logger,
                                    "Stopping animation for peripheral with index ({:?}).",
                                    animation.index
                                );
                                false
                            } else {
                                true
                            }
                        });

                        Self::change_brightness(
                            Rc::clone(&module_clone),
                            event.index,
                            &event.value,
                        );
                    }
                }),
                false,
            )
            .unwrap();

        let module_clone = Rc::clone(&module);
        protocol
            .borrow_mut()
            .add_packet_handler(
                Box::new(move |packet, _protocol| {
                    if let Ok(event) = BcmAnimateBrightnessEvent::try_from_packet(&packet) {
                        log_debug!(
                            logger,
                            "Received `bcm_animate_brightness_event` ({:?}).",
                            event
                        );

                        module_clone.borrow_mut().animations.retain(|animation| {
                            if animation.index == event.index {
                                log_info!(
                                    logger,
                                    "Stopping animation for peripheral with index ({:?}).",
                                    animation.index
                                );
                                false
                            } else {
                                true
                            }
                        });

                        module_clone.borrow_mut().animations.push(BcmAnimation {
                            index: event.index,
                            start_time: None,
                            duration: event.duration,
                            start_value: None,
                            target_value: event.target_value,
                        });
                    }
                }),
                false,
            )
            .unwrap();

        log_debug!(logger, "BCM module initialized.");

        return Rc::clone(&module);
    }

    fn tick(module: Rc<RefCell<Self>>, current_time: &mut u32) {
        let borrowed_module: &mut Self = &mut module.borrow_mut();
        let animations = &mut borrowed_module.animations;
        let logger = borrowed_module.logger;
        let channel_brightness = &mut borrowed_module.channel_brightness;
        let peripherals = &borrowed_module.peripherals;
        let mut done_animations = vec![];

        for (animation_index, animation) in animations.iter_mut().enumerate() {
            if let (Some(start_time), Some(start_value)) =
                (animation.start_time, animation.start_value)
            {
                let current_position =
                    (*current_time - start_time) as f32 / animation.duration as f32;

                let current_position = if current_position > 1.0 {
                    1.0
                } else {
                    current_position
                };

                match (start_value, animation.target_value) {
                    (BcmValue::Single(start_value), BcmValue::Single(target_value)) => {
                        let current_value = (start_value as f32
                            + current_position * (target_value as f32 - start_value as f32))
                            as u8;

                        Self::change_brightness_parts(
                            animation.index,
                            &BcmValue::Single(current_value),
                            logger,
                            channel_brightness,
                            peripherals,
                        );
                    }
                    (
                        BcmValue::Rgb(start_r, start_g, start_b),
                        BcmValue::Rgb(target_r, target_g, target_b),
                    ) => {
                        let start_hsv = Hsv::from_color(Srgb::from_components((
                            start_r as f32 / 255.0,
                            start_g as f32 / 255.0,
                            start_b as f32 / 255.0,
                        )));
                        let target_hsv = Hsv::from_color(Srgb::from_components((
                            target_r as f32,
                            target_g as f32,
                            target_b as f32,
                        )));
                        let current_hsv = start_hsv.mix(&target_hsv, current_position);
                        let current_srgb = Srgb::from_color(current_hsv);

                        Self::change_brightness_parts(
                            animation.index,
                            &BcmValue::Rgb(
                                (current_srgb.red * 255.0) as u8,
                                (current_srgb.green * 255.0) as u8,
                                (current_srgb.blue * 255.0) as u8,
                            ),
                            logger,
                            channel_brightness,
                            peripherals,
                        );
                    }
                    (
                        BcmValue::Rgbw(start_r, start_g, start_b, start_w),
                        BcmValue::Rgbw(target_r, target_g, target_b, target_w),
                    ) => {
                        let start_hsv = Hsv::from_color(Srgb::from_components((
                            start_r as f32 / 255.0,
                            start_g as f32 / 255.0,
                            start_b as f32 / 255.0,
                        )));
                        let target_hsv = Hsv::from_color(Srgb::from_components((
                            target_r as f32,
                            target_g as f32,
                            target_b as f32,
                        )));
                        let current_hsv = start_hsv.mix(&target_hsv, current_position);
                        let current_srgb = Srgb::from_color(current_hsv);
                        let current_w =
                            start_w + ((target_w - start_w) as f32 * current_position) as u8;

                        Self::change_brightness_parts(
                            animation.index,
                            &BcmValue::Rgbw(
                                (current_srgb.red * 255.0) as u8,
                                (current_srgb.green * 255.0) as u8,
                                (current_srgb.blue * 255.0) as u8,
                                current_w,
                            ),
                            logger,
                            channel_brightness,
                            peripherals,
                        );
                    }
                    (start_value, target_value) => {
                        log_warning!(
                            logger,
                            "Tried to interpolate between different value types ({:?}, {:?}).",
                            start_value,
                            target_value
                        );
                    }
                }

                if current_position == 1.0 {
                    done_animations.push(animation_index);
                }
            } else {
                log_info!(
                    logger,
                    "Started animation for peripheral with index ({:?}).",
                    animation.index
                );

                animation.start_time = Some(*current_time);
                if let Some(value) =
                    Self::get_brightness_parts(animation.index, peripherals, channel_brightness)
                {
                    animation.start_value = Some(value);
                } else {
                    log_warning!(
                        logger,
                        "Received index ({:?}) does not match peripherals.",
                        animation.index,
                    );
                    done_animations.push(animation_index);
                }
            }
        }

        // Remove last animations first
        done_animations.reverse();

        for done_animation in done_animations.iter() {
            log_info!(
                logger,
                "Done with animation for peripheral with index ({:?}).",
                animations[*done_animation].index
            );
            animations.remove(*done_animation);
        }
    }
}

#[interrupt]
fn TIM2() {
    let latch_pin = unsafe { LATCH_PIN.as_mut().unwrap() };
    let spi = unsafe { SPI.as_mut().unwrap() };

    let timer = unsafe { TIMER.as_mut().unwrap() };
    let current_frequency_hz = unsafe { &mut CURRENT_FREQUENCY_HZ };
    let current_bitmask = unsafe { &mut CURRENT_BITMASK };

    let channel_brightness = unsafe { &CHANNEL_BRIGHTNESS };

    // Decrease the frequency by a factor of 2
    *current_frequency_hz >>= 1;
    // Move onto the next bit
    *current_bitmask <<= 1;

    if *current_bitmask == 0 {
        *current_frequency_hz = INITIAL_FREQUENCY_HZ;
        *current_bitmask = 0x01;
    }

    latch_pin.set_low().unwrap();

    for i in 0..channel_brightness.len() / 8 {
        let mut value = 0x00;

        for j in 0..=7 {
            let brightness = channel_brightness[channel_brightness.len() - (i * 8 + j) - 1];

            value <<= 1;

            if (brightness & (*current_bitmask)) != 0 {
                value |= 1;
            }
        }

        let buffer = [value; 1];
        spi.write(&buffer).unwrap();
    }

    latch_pin.set_high().unwrap();

    timer.start((*current_frequency_hz).hz());
    timer.wait().unwrap();
}
