extern crate alloc;

use alloc::boxed::Box;
use alloc::rc::Rc;
use core::cell::RefCell;
use cortex_m::interrupt::free;
use embedded_hal::digital::v2::OutputPin;
use stm32f1xx_hal_bxcan::pac::{NVIC, SPI2, TIM2};
use stm32f1xx_hal_bxcan::prelude::*;
use stm32f1xx_hal_bxcan::rcc::{Clocks, APB1};
use stm32f1xx_hal_bxcan::spi::{Mode, NoMiso, Phase, Polarity, Spi};
use stm32f1xx_hal_bxcan::stm32::{interrupt, Interrupt};
use stm32f1xx_hal_bxcan::timer::{CountDownTimer, Event, Timer};

use ross_eeprom::DeviceInfo;
use ross_logger::{log_debug, log_error, log_warning, Logger};
use ross_protocol::convert_packet::ConvertPacket;
use ross_protocol::event::bcm::*;
use ross_protocol::event::general::AckEvent;

use crate::helper::type_helper::{BcmClockPin, BcmDataPin, BcmLatchPin, BcmSpi, CanProtocol};
use crate::module::*;

// Number of brightness channels; the TIM2 handler shifts them out eight per SPI byte.
const CHANNEL_COUNT: usize = 24;
// TIM2 frequency used while the least significant brightness bit is displayed. Every
// following bit halves the frequency, i.e. doubles its display time.
const INITIAL_FREQUENCY_HZ: u32 = 65_536;

// Timer for binary code modulation.
// Set by `BcmModule::new` before the TIM2 interrupt is unmasked; restarted by the TIM2 handler.
static mut TIMER: Option<CountDownTimer<TIM2>> = None;
// Current frequency for binary code modulation in Hertz.
// Halved by the TIM2 handler on every bit, reset to `INITIAL_FREQUENCY_HZ` when the bitmask wraps.
static mut CURRENT_FREQUENCY_HZ: u32 = INITIAL_FREQUENCY_HZ;
// Current bitmask for binary code modulation: selects which bit of every channel's
// brightness is being output. Shifted left per interrupt, wraps back to 0x01 after bit 7.
static mut CURRENT_BITMASK: u8 = 0x01;
// Channel brightness values for binary code modulation.
// Written by `BcmModule::change_brightness` (inside a critical section), read by the TIM2 handler.
static mut CHANNEL_BRIGHTNESS: [u8; CHANNEL_COUNT] = [0x00; CHANNEL_COUNT];
// Latch pin, held low while the TIM2 handler shifts data out and driven high afterwards.
// Set by `BcmModule::new`.
static mut LATCH_PIN: Option<BcmLatchPin> = None;
// SPI bus the TIM2 handler writes channel bits to. Set by `BcmModule::new`.
static mut SPI: Option<BcmSpi> = None;

/// Resources and shared services required to construct a [`BcmModule`].
pub struct BcmModuleConfig<'a> {
    /// CAN protocol on which the module registers its packet handler.
    pub protocol: Rc<RefCell<CanProtocol<'a>>>,
    /// Logger used for debug, warning and error messages.
    pub logger: &'a RefCell<Logger>,
    /// Device information; its `device_address` is used as the transmitter address of acks.
    pub device_info: &'a DeviceInfo,
    /// Clock pin handed to SPI2.
    pub clock_pin: BcmClockPin,
    /// Latch pin, held low while channel bits are shifted out and set high afterwards.
    pub latch_pin: BcmLatchPin,
    /// Data-out pin handed to SPI2 (no MISO pin is used).
    pub data_pin: BcmDataPin,
    /// SPI2 peripheral used to shift the channel bits out.
    pub spi2: SPI2,
    /// Clock configuration passed to the SPI2 and TIM2 constructors.
    pub clocks: Clocks,
    /// APB1 bus handle passed to the SPI2 and TIM2 constructors.
    pub apb1: &'a mut APB1,
    /// Timer whose update interrupt paces binary code modulation.
    pub tim2: TIM2,
}

/// Binary code modulation (BCM) driver for `CHANNEL_COUNT` output channels.
///
/// The module only stores brightness values; the outputs themselves are refreshed by
/// the TIM2 interrupt handler, which reads the same static brightness buffer.
pub struct BcmModule<'a> {
    /// Per-channel brightness values; refers to the static buffer shared with the TIM2 handler.
    channel_brightness: &'static mut [u8; CHANNEL_COUNT],
    /// Logger used to report invalid requests.
    logger: &'a RefCell<Logger>,
}

impl<'a> BcmModule<'a> {
    /// Sets the brightness of a single output channel.
    ///
    /// The value is written into the buffer shared with the TIM2 interrupt handler
    /// with interrupts disabled (`cortex_m::interrupt::free`). A `channel` at or above
    /// `CHANNEL_COUNT` is logged as a warning and ignored rather than indexing past the
    /// end of the buffer.
    fn change_brightness(module: Rc<RefCell<Self>>, channel: u8, brightness: u8) {
        let mut this = module.borrow_mut();
        let logger = this.logger;
        let index = usize::from(channel);

        if index >= CHANNEL_COUNT {
            log_warning!(
                logger,
                "Channel index ({}) exceeds channel count ({}).",
                channel,
                CHANNEL_COUNT
            );
            // Without this early return the indexing below would panic on bad input
            // received from the bus.
            return;
        }

        free(|_cs| {
            this.channel_brightness[index] = brightness;
        });
    }
}

impl<'a> Module<'a> for BcmModule<'a> {
    fn new(config: ModuleConfig<'a>) -> Rc<RefCell<Self>> {
        let config = match config {
            ModuleConfig::BcmModule(config) => config,
            _ => {
                panic!("Wrong config provided for bcm module.");
            }
        };

        let protocol = config.protocol;
        let logger = config.logger;
        let device_info = config.device_info;
        let clock_pin = config.clock_pin;
        let latch_pin = config.latch_pin;
        let data_pin = config.data_pin;
        let spi2 = config.spi2;
        let clocks = config.clocks;
        let apb1 = config.apb1;
        let tim2 = config.tim2;

        let channel_brightness = unsafe { &mut CHANNEL_BRIGHTNESS };

        let module = Rc::new(RefCell::new(Self {
            channel_brightness,
            logger,
        }));

        let spi = {
            let pins = (clock_pin, NoMiso, data_pin);

            let spi_mode = Mode {
                polarity: Polarity::IdleLow,
                phase: Phase::CaptureOnFirstTransition,
            };

            Spi::spi2(spi2, pins, spi_mode, 8_000_000.hz(), clocks, apb1)
        };

        unsafe {
            LATCH_PIN = Some(latch_pin);
            SPI = Some(spi);
        }

        let mut timer =
            Timer::tim2(tim2, &clocks, apb1).start_count_down(INITIAL_FREQUENCY_HZ.hz());
        timer.listen(Event::Update);

        unsafe {
            TIMER = Some(timer);

            // For binary code modulation
            NVIC::unmask(Interrupt::TIM2);
        }

        let module_clone = Rc::clone(&module);
        protocol
            .borrow_mut()
            .add_packet_handler(
                Box::new(move |packet, protocol| {
                    if let Ok(bcm_change_brightness_event) =
                        BcmChangeBrightnessEvent::try_from_packet(&packet)
                    {
                        log_debug!(
                            logger,
                            "Received `bcm_change_brightness_event` ({:?}).",
                            bcm_change_brightness_event
                        );
                        let ack_event = AckEvent {
                            receiver_address: bcm_change_brightness_event.transmitter_address,
                            transmitter_address: device_info.device_address,
                        };

                        Self::change_brightness(
                            Rc::clone(&module_clone),
                            bcm_change_brightness_event.channel,
                            bcm_change_brightness_event.brightness,
                        );

                        if let Err(err) = protocol.send_packet(&ack_event.to_packet()) {
                            log_error!(logger, "Failed to send `ack_event` ({:?}).", err);
                        } else {
                            log_debug!(logger, "Sent `ack_event` ({:?}).", ack_event);
                        }
                    }
                }),
                false,
            )
            .unwrap();

        log_debug!(logger, "BCM module initialized.");

        return Rc::clone(&module);
    }

    fn tick(&mut self) {}
}

// TIM2 update interrupt: outputs the next bit plane of binary code modulation.
//
// Each call does the following:
// - advances to the next more significant brightness bit;
// - shifts that bit of every channel out over SPI while the latch pin is held low,
//   then drives the latch pin high;
// - restarts TIM2 at half the previous frequency, so bit `k` is shown twice as long
//   as bit `k - 1`.
// After bit 7 the cycle restarts at bit 0 and `INITIAL_FREQUENCY_HZ`.
#[interrupt]
fn TIM2() {
    // SAFETY (assumed): `BcmModule::new` stores the latch pin, SPI bus and timer before
    // unmasking TIM2, and in this file `CHANNEL_BRIGHTNESS` is only written elsewhere
    // inside a critical section.
    let (latch_pin, spi, timer) = unsafe {
        (
            LATCH_PIN.as_mut().unwrap(),
            SPI.as_mut().unwrap(),
            TIMER.as_mut().unwrap(),
        )
    };
    let (frequency_hz, bitmask) = unsafe { (&mut CURRENT_FREQUENCY_HZ, &mut CURRENT_BITMASK) };
    let levels = unsafe { &CHANNEL_BRIGHTNESS };

    // Move on to the next, more significant bit. Shifting a `u8` past bit 7 yields
    // zero, which wraps the cycle back to bit 0 at the initial frequency.
    let shifted = *bitmask << 1;
    if shifted == 0 {
        *bitmask = 0x01;
        *frequency_hz = INITIAL_FREQUENCY_HZ;
    } else {
        *bitmask = shifted;
        *frequency_hz >>= 1;
    }
    let bit = *bitmask;

    latch_pin.set_low().unwrap();

    // Channels are sent highest index first, eight per byte, with the highest channel
    // of each group of eight in the most significant bit.
    for group in levels.rchunks_exact(8) {
        let byte = group.iter().rev().fold(0x00, |acc, &level| {
            if (level & bit) != 0 {
                (acc << 1) | 1
            } else {
                acc << 1
            }
        });
        spi.write(&[byte]).unwrap();
    }

    latch_pin.set_high().unwrap();

    timer.start((*frequency_hz).hz());
    timer.wait().unwrap();
}
