extern crate alloc;

use alloc::format;
use alloc::boxed::Box;
use cortex_m::interrupt::free;
use core::cell::UnsafeCell;
use embedded_hal::digital::v2::OutputPin;
use stm32f1xx_hal_bxcan::prelude::*;
use stm32f1xx_hal_bxcan::pac::{NVIC, SPI2, TIM2};
use stm32f1xx_hal_bxcan::rcc::{Clocks, APB1};
use stm32f1xx_hal_bxcan::spi::{Mode, NoMiso, Phase, Polarity, Spi};
use stm32f1xx_hal_bxcan::timer::{CountDownTimer, Event, Timer};
use stm32f1xx_hal_bxcan::stm32::{interrupt, Interrupt};

use ross_eeprom::DeviceInfo;
use ross_logger::{log_debug, log_warning, log_error};
use ross_protocol::event::bcm_event::*;
use ross_protocol::event::general_event::AckEvent;
use ross_protocol::convert_packet::ConvertPacket;
use ross_protocol::interface::Interface;

use crate::helper::cell_helper::get_from_cell;
use crate::helper::type_helper::{BcmClockPin, BcmLatchPin, BcmDataPin, UnsafeCanProtocol, BcmSpi, UnsafeLogger, UnsafeBcmModule};

// Number of BCM output channels. The TIM2 handler packs channels into bytes of 8
// (`len() / 8`), so this is expected to be a multiple of 8.
const CHANNEL_COUNT: usize = 24;
// Timer frequency (Hz) used while the least significant brightness bit is shown;
// every following bit runs at half the previous frequency (twice the duration).
const INITIAL_FREQUENCY_HZ: u32 = 65_536;

// Timer for binary code modulation (TIM2, stored by `BcmModule::init`)
static mut TIMER: Option<CountDownTimer<TIM2>> = None;
// Current frequency for binary code modulation in Hertz; halved for each bit plane
// and reset to `INITIAL_FREQUENCY_HZ` when the bitmask wraps around
static mut CURRENT_FREQUENCY_HZ: u32 = INITIAL_FREQUENCY_HZ;
// Current bitmask for binary code modulation: selects which brightness bit of every
// channel is currently being shifted out
static mut CURRENT_BITMASK: u8 = 0x01;
// Channel brightness values for binary code modulation (one byte per channel)
static mut CHANNEL_BRIGHTNESS: [u8; CHANNEL_COUNT] = [0x00; CHANNEL_COUNT];
// Pin driven low before and high after each SPI transfer of a bit plane
static mut LATCH_PIN: Option<BcmLatchPin> = None;
// SPI2 bus used by the TIM2 handler to shift the current bit plane out
static mut SPI: Option<BcmSpi> = None;

/// Binary-code-modulation (BCM) output module.
///
/// Holds this module's handle to the static per-channel brightness buffer that the
/// `TIM2` interrupt handler reads and shifts out over SPI.
pub struct BcmModule {
    // Per-channel brightness, one byte per channel; this is a reference to the
    // `CHANNEL_BRIGHTNESS` static (see `BcmModule::new`).
    channel_brightness: &'static mut [u8; CHANNEL_COUNT],
}

impl BcmModule {
    /// Creates the module state, backed by the static `CHANNEL_BRIGHTNESS` buffer.
    ///
    /// Every call hands out a new `&'static mut` to that same static buffer, so this
    /// must be called at most once. The `TIM2` interrupt handler also reads the
    /// buffer directly.
    pub fn new() -> UnsafeBcmModule {
        // SAFETY: sound only if `new` is called once, so this is the sole long-lived
        // `&mut` to `CHANNEL_BRIGHTNESS`. NOTE(review): the TIM2 handler still reads
        // the static while this reference is alive; writes through it are made in a
        // critical section by `change_brightness`.
        let channel_brightness = unsafe { &mut CHANNEL_BRIGHTNESS };

        UnsafeCell::new(BcmModule { channel_brightness })
    }

    /// Configures the BCM hardware and registers the brightness packet handler.
    ///
    /// - Builds SPI2 on `clock_pin`/`data_pin` without MISO (idle-low clock,
    ///   capture on first transition, 8 MHz) and stores it, together with
    ///   `latch_pin`, in the statics used by the `TIM2` interrupt handler.
    /// - Starts TIM2 counting down at `INITIAL_FREQUENCY_HZ`, enables its update
    ///   event and unmasks the TIM2 interrupt in the NVIC.
    /// - Adds a CAN packet handler that applies every `BcmChangeBrightnessEvent` to
    ///   `module` and answers it with an `AckEvent` addressed to the sender. The ack
    ///   is sent even when the requested channel was out of range and ignored.
    ///
    /// # Panics
    ///
    /// Panics if `add_packet_handler` returns an error.
    pub fn init<'a>(
        module: &'a UnsafeBcmModule,
        protocol: &'a UnsafeCanProtocol<'a>,
        logger: &'a UnsafeLogger,
        device_info: &'a DeviceInfo,
        clock_pin: BcmClockPin,
        latch_pin: BcmLatchPin,
        data_pin: BcmDataPin,
        spi2: SPI2,
        clocks: Clocks,
        apb1: &mut APB1,
        tim2: TIM2,
    ) {
        let spi = {
            let pins = (clock_pin, NoMiso, data_pin);

            let spi_mode = Mode {
                polarity: Polarity::IdleLow,
                phase: Phase::CaptureOnFirstTransition,
            };

            Spi::spi2(spi2, pins, spi_mode, 8_000_000.hz(), clocks, apb1)
        };

        // SAFETY: the TIM2 interrupt is only unmasked further below, so (assuming it
        // was not unmasked elsewhere) its handler cannot observe these statics while
        // they are being written.
        unsafe {
            LATCH_PIN = Some(latch_pin);
            SPI = Some(spi);
        }

        let mut timer =
            Timer::tim2(tim2, &clocks, apb1).start_count_down(INITIAL_FREQUENCY_HZ.hz());
        timer.listen(Event::Update);

        // SAFETY: same as above; every static the TIM2 handler unwraps is populated
        // before the interrupt is unmasked.
        unsafe {
            TIMER = Some(timer);

            // For binary code modulation
            NVIC::unmask(Interrupt::TIM2);
        }

        get_from_cell(protocol)
            .add_packet_handler(
                Box::new(move |packet, can| {
                    if let Ok(bcm_change_brightness_event) =
                        BcmChangeBrightnessEvent::try_from_packet(&packet)
                    {
                        log_debug!(logger, "Received `bcm_change_brightness_event` ({:?}).", bcm_change_brightness_event);
                        let ack_event = AckEvent {
                            receiver_address: bcm_change_brightness_event.transmitter_address,
                            transmitter_address: device_info.device_address,
                        };

                        get_from_cell(module).change_brightness(
                            bcm_change_brightness_event.channel,
                            bcm_change_brightness_event.brightness,
                            logger,
                        );

                        if let Err(err) = can.try_send_packet(&ack_event.to_packet()) {
                            log_error!(
                                logger,
                                "Failed to send `ack_event` ({:?}).",
                                err
                            );
                        } else {
                            log_debug!(logger, "Sent `ack_event` ({:?}).", ack_event);
                        }
                    }
                }),
                false,
            )
            .unwrap();

        log_debug!(logger, "BCM module initialized.");
    }

    /// Periodic hook; intentionally empty because all BCM output is driven from the
    /// `TIM2` interrupt handler.
    pub fn tick(&mut self) {}

    /// Sets the brightness of the 0-based `channel` to `brightness`.
    ///
    /// A channel index outside `0..CHANNEL_COUNT` is logged as a warning and ignored,
    /// rather than being used to index the buffer, which would panic.
    fn change_brightness(&mut self, channel: u8, brightness: u8, logger: &UnsafeLogger) {
        let index = usize::from(channel);

        if index >= CHANNEL_COUNT {
            log_warning!(logger, "Channel index ({}) exceeds channel count ({}).", channel, CHANNEL_COUNT);
            return;
        }

        // The TIM2 interrupt reads this buffer; writing inside a critical section
        // keeps the update from being interleaved with the handler.
        free(|_cs| {
            (*self.channel_brightness)[index] = brightness;
        });
    }
}

// TIM2 update interrupt: advances binary code modulation by one bit plane.
//
// On every interrupt the handler:
// 1. moves to the next brightness bit;
// 2. shifts that bit of every channel out over SPI, between a low and a high edge
//    on the latch pin;
// 3. restarts the timer at half the previous frequency, so each more significant
//    bit stays on twice as long.
// After bit 7 it wraps back to bit 0 at `INITIAL_FREQUENCY_HZ`.
//
// Panics if the statics were not populated by `BcmModule::init` (the `unwrap`s
// below), or if a latch, SPI or timer call returns an error.
#[interrupt]
fn TIM2() {
    // These statics are set in `BcmModule::init` before it unmasks the TIM2
    // interrupt; the `unwrap`s panic if the handler somehow runs before that.
    let latch_pin = unsafe { LATCH_PIN.as_mut().unwrap() };
    let spi = unsafe { SPI.as_mut().unwrap() };

    let timer = unsafe { TIMER.as_mut().unwrap() };
    let current_frequency_hz = unsafe { &mut CURRENT_FREQUENCY_HZ };
    let current_bitmask = unsafe { &mut CURRENT_BITMASK };

    // NOTE(review): read while `BcmModule` holds a `&mut` to the same static; the
    // writes on that side are done inside `cortex_m::interrupt::free`.
    let channel_brightness = unsafe { &CHANNEL_BRIGHTNESS };

    // Decrease the frequency by a factor of 2 (the next bit plane lasts twice as long)
    *current_frequency_hz >>= 1;
    // Move onto the next bit
    *current_bitmask <<= 1;

    // For a u8, 0x80 << 1 yields 0: all eight bits have been shown, so restart the
    // cycle at bit 0 with the shortest period.
    if *current_bitmask == 0 {
        *current_frequency_hz = INITIAL_FREQUENCY_HZ;
        *current_bitmask = 0x01;
    }

    latch_pin.set_low().unwrap();

    // One byte per group of 8 channels (integer division, so channels beyond a
    // multiple of 8 would not be sent). Channels are walked from the highest index
    // down and packed MSB first. With CHANNEL_COUNT = 24 the first byte therefore
    // carries channels 23..16, with channel 23 in bit 7.
    // Presumably this matches the order of the shift-register chain; confirm
    // against the hardware.
    for i in 0..channel_brightness.len() / 8 {
        let mut value = 0x00;

        for j in 0..=7 {
            let brightness = channel_brightness[channel_brightness.len() - (i * 8 + j) - 1];

            value <<= 1;

            // The output bit is on when this channel's brightness has the current bit set
            if (brightness & (*current_bitmask)) != 0 {
                value |= 1;
            }
        }

        let buffer = [value; 1];
        spi.write(&buffer).unwrap();
    }

    latch_pin.set_high().unwrap();

    // Keep this bit plane on until the next update event at the new frequency.
    timer.start((*current_frequency_hz).hz());
    // NOTE(review): `wait()` is presumably what clears the update flag that raised
    // this interrupt. It relies on that flag still being set here and panics (via
    // `unwrap`) if it is not. Confirm against the HAL's `CountDown::wait`.
    timer.wait().unwrap();
}
