// Copyright (C) 2021 Robin Krahl <robin.krahl@ireas.org>
// SPDX-License-Identifier: Apache-2.0 or MIT

use std::convert;

use crate::{
    command::Command,
    error::{DeviceError, Error, RequestError, ResponseError},
    hid::Device,
    transaction::Channel,
};

/// Returns how many payload bytes fit into a single initialization packet on `device`.
///
/// This is `device.packet_size()` minus the initialization header
/// ([`InitializationPacket::HEADER_SIZE`] bytes).
///
/// # Panics
///
/// In debug builds, panics on arithmetic underflow if the device's packet size is smaller than
/// the header.
pub fn data_size_init(device: &impl Device) -> usize {
    let header_len = InitializationPacket::HEADER_SIZE;
    device.packet_size() - header_len
}

/// Returns how many payload bytes fit into a single continuation packet on `device`.
///
/// This is `device.packet_size()` minus the continuation header
/// ([`ContinuationPacket::HEADER_SIZE`] bytes).
///
/// # Panics
///
/// In debug builds, panics on arithmetic underflow if the device's packet size is smaller than
/// the header.
pub fn data_size_cont(device: &impl Device) -> usize {
    let header_len = ContinuationPacket::HEADER_SIZE;
    device.packet_size() - header_len
}

/// Sends an initialization packet for `command` on `channel` to `device`.
///
/// `length` is written into the packet header as given; it is not checked against
/// `data.len()`. `buffer` is used as scratch space for the serialized packet.
///
/// # Errors
///
/// Fails if `buffer` is too small for the serialized packet, or if the device send fails.
pub fn send_init(
    device: &impl Device,
    channel: Channel,
    command: Command,
    length: u16,
    data: &[u8],
    buffer: &mut [u8],
) -> Result<(), Error> {
    let packet = InitializationPacket {
        channel,
        command,
        length,
        data,
    };
    Packet::Initialization(packet).send(device, buffer)
}

/// Sends a continuation packet with sequence number `sequence` on `channel` to `device`.
///
/// `buffer` is used as scratch space for the serialized packet.
///
/// # Errors
///
/// Fails if `buffer` is too small for the serialized packet, or if the device send fails.
pub fn send_cont(
    device: &impl Device,
    channel: Channel,
    sequence: u8,
    data: &[u8],
    buffer: &mut [u8],
) -> Result<(), Error> {
    let packet = ContinuationPacket {
        channel,
        sequence,
        data,
    };
    Packet::Continuation(packet).send(device, buffer)
}

pub fn receive_init<'a>(
    device: &impl Device,
    channel: Channel,
    command: Command,
    buffer: &'a mut [u8],
) -> Result<(u16, &'a [u8]), Error> {
    // TODO: handle KeepAlive
    let packet = Packet::receive(device, channel, buffer)?;
    if let Packet::Initialization(packet) = packet {
        if packet.command == command {
            Ok((packet.length, packet.data))
        } else if packet.command == Command::Error {
            if !packet.data.is_empty() {
                Err(Error::from(DeviceError::from(packet.data[0])))
            } else {
                Err(Error::from(ResponseError::MissingErrorCode))
            }
        } else {
            Err(Error::from(ResponseError::UnexpectedCommand {
                expected: command,
                actual: packet.command,
            }))
        }
    } else {
        // TODO: Add actual & expected type
        Err(Error::from(ResponseError::UnexpectedPacketType))
    }
}

pub fn receive_cont<'a>(
    device: &impl Device,
    channel: Channel,
    sequence: u8,
    buffer: &'a mut [u8],
) -> Result<&'a [u8], Error> {
    let packet = Packet::receive(device, channel, buffer)?;
    if let Packet::Continuation(packet) = packet {
        if sequence == packet.sequence {
            Ok(packet.data)
        } else {
            Err(Error::from(ResponseError::UnexpectedSequence {
                expected: sequence,
                actual: packet.sequence,
            }))
        }
    } else {
        // TODO: Add actual & expected type
        Err(Error::from(ResponseError::UnexpectedPacketType))
    }
}

enum Packet<'a> {
    Initialization(InitializationPacket<'a>),
    Continuation(ContinuationPacket<'a>),
}

impl<'a> Packet<'a> {
    pub fn send(&self, device: &impl Device, buffer: &mut [u8]) -> Result<(), Error> {
        buffer[0] = 0;
        let n = self.serialize(&mut buffer[1..])?;
        buffer[n + 1..].fill(0);
        device.send(buffer)?;
        Ok(())
    }

    pub fn receive(
        device: &impl Device,
        channel: Channel,
        buffer: &'a mut [u8],
    ) -> Result<Self, Error> {
        use std::convert::TryFrom as _;
        let data = device.receive(buffer)?;
        let packet = Self::try_from(data)?;
        if packet.channel() == channel {
            Ok(packet)
        } else {
            Err(Error::from(ResponseError::UnexpectedChannel {
                expected: u32::from(channel),
                actual: u32::from(packet.channel()),
            }))
        }
    }

    pub fn channel(&self) -> Channel {
        match self {
            Self::Initialization(packet) => packet.channel,
            Self::Continuation(packet) => packet.channel,
        }
    }

    fn serialize(&self, buffer: &mut [u8]) -> Result<usize, RequestError> {
        match self {
            Self::Initialization(packet) => packet.serialize(buffer),
            Self::Continuation(packet) => packet.serialize(buffer),
        }
    }
}

impl<'a> convert::TryFrom<&'a [u8]> for Packet<'a> {
    type Error = ResponseError;

    /// Parses a received report into a packet that borrows its payload from `data`.
    ///
    /// Bytes 0..4 are the channel, and byte 4 selects the packet kind. If its high bit is set,
    /// the packet is an initialization packet: the low seven bits are the command and bytes
    /// 5..7 are a big-endian length. Otherwise byte 4 is a continuation sequence number. In
    /// both cases the payload is everything after the header.
    ///
    /// # Errors
    ///
    /// [`ResponseError::NotEnoughData`] if `data` is shorter than the header for its packet kind.
    fn try_from(data: &'a [u8]) -> Result<Self, Self::Error> {
        let (channel_bytes, tag, payload) = match data {
            [b0, b1, b2, b3, tag, payload @ ..] => ([*b0, *b1, *b2, *b3], *tag, payload),
            _ => return Err(ResponseError::NotEnoughData),
        };
        let channel = Channel::from(channel_bytes);

        // High bit clear: the tag byte is a continuation sequence number.
        if (tag & 0b1000_0000) == 0 {
            return Ok(Packet::Continuation(ContinuationPacket {
                channel,
                sequence: tag,
                data: payload,
            }));
        }

        match payload {
            [len_hi, len_lo, rest @ ..] => Ok(Packet::Initialization(InitializationPacket {
                channel,
                command: Command::from(tag & 0b0111_1111),
                length: u16::from_be_bytes([*len_hi, *len_lo]),
                data: rest,
            })),
            _ => Err(ResponseError::NotEnoughData),
        }
    }
}

/// The first packet of a message: channel, command, declared length and payload.
struct InitializationPacket<'a> {
    /// Channel the packet belongs to.
    channel: Channel,
    /// Command, stored without the high bit that is set on the wire.
    command: Command,
    /// Declared length, written big-endian into the header.
    length: u16,
    /// Payload following the header. When parsed from a report, this is everything after the
    /// header and is not truncated to `length`.
    data: &'a [u8],
}

impl<'a> InitializationPacket<'a> {
    /// Header size: 4 channel bytes, 1 command byte and 2 length bytes.
    const HEADER_SIZE: usize = 7;

    /// Writes header and payload to the start of `buffer` and returns the number of bytes
    /// written.
    ///
    /// # Errors
    ///
    /// [`RequestError::BufferTooSmall`] if `buffer` is shorter than header plus payload. In
    /// that case nothing is written.
    fn serialize(&self, buffer: &mut [u8]) -> Result<usize, RequestError> {
        let total = Self::HEADER_SIZE + self.data.len();
        let out = buffer
            .get_mut(..total)
            .ok_or(RequestError::BufferTooSmall)?;
        let (header, body) = out.split_at_mut(Self::HEADER_SIZE);

        header[..4].copy_from_slice(&u32::from(self.channel).to_be_bytes());
        // The high bit marks this as an initialization packet.
        header[4] = u8::from(self.command) | 0b1000_0000;
        header[5..].copy_from_slice(&self.length.to_be_bytes());
        body.copy_from_slice(self.data);

        Ok(total)
    }
}

/// A follow-up packet of a message: channel, sequence number and payload.
struct ContinuationPacket<'a> {
    /// Channel the packet belongs to.
    channel: Channel,
    /// Sequence number. When parsed from a received report, its high bit is clear.
    sequence: u8,
    /// Payload following the header.
    data: &'a [u8],
}

impl<'a> ContinuationPacket<'a> {
    /// Header size: 4 channel bytes and 1 sequence byte.
    const HEADER_SIZE: usize = 5;

    /// Writes header and payload to the start of `buffer` and returns the number of bytes
    /// written.
    ///
    /// # Errors
    ///
    /// [`RequestError::BufferTooSmall`] if `buffer` is shorter than header plus payload. In
    /// that case nothing is written.
    fn serialize(&self, buffer: &mut [u8]) -> Result<usize, RequestError> {
        let total = Self::HEADER_SIZE + self.data.len();
        let out = buffer
            .get_mut(..total)
            .ok_or(RequestError::BufferTooSmall)?;
        let (header, body) = out.split_at_mut(Self::HEADER_SIZE);

        header[..4].copy_from_slice(&u32::from(self.channel).to_be_bytes());
        header[4] = self.sequence;
        body.copy_from_slice(self.data);

        Ok(total)
    }
}
