use crate::hostname::Hostname;

use crc::{Crc, CRC_16_IBM_SDLC};

/// Worst-case number of bytes a packet adds on top of its raw payload data:
///
/// * 8 bytes for the source hostname
/// * 8 bytes for the destination hostname
/// * 1 byte for the packet type tag
/// * up to 2 bytes of packet metadata (frame ID or number of frames)
/// * 2 bytes for the trailing checksum
pub(crate) const PACKET_OVERHEAD_BYTES: usize = 2 * 8 + 1 + 2 + 2;

/// CRC-16/IBM-SDLC (a.k.a. X.25), used to checksum every encoded packet.
const X25: Crc<u16> = Crc::<u16>::new(&CRC_16_IBM_SDLC);

/// A single framed message between two hosts.
///
/// Encoded by [`Packet::to_bytes`] as: source hostname, destination hostname,
/// [`PacketData`] (type tag plus fields), then a 2-byte little-endian CRC-16
/// (X.25) computed over all of the preceding bytes of this packet.
#[derive(Debug, PartialEq)]
pub(crate) struct Packet {
	/// Sending host.
	pub from: Hostname,
	/// Destination host.
	pub to: Hostname,
	/// Type tag and type-specific payload.
	pub data: PacketData,
}

impl Packet {
	/// Creates a packet sent from `from` to `to` carrying `data`.
	pub fn new(from: Hostname, to: Hostname, data: PacketData) -> Self {
		Self { from, to, data }
	}

	/// Appends the encoded packet, including its trailing checksum, to `buf`.
	///
	/// Bytes already in `buf` are left untouched and are not covered by the
	/// checksum.
	pub fn to_bytes(&self, buf: &mut Vec<u8>) {
		let start_idx = buf.len();

		self.from.to_bytes(buf);
		self.to.to_bytes(buf);
		self.data.to_bytes(buf);

		// Checksum only the bytes this packet wrote, then append it.
		let checksum = X25.checksum(&buf[start_idx..]);
		buf.extend(u16_to_u8s(checksum));
	}

	/// Decodes a packet produced by [`Packet::to_bytes`].
	///
	/// Decoding assumes each hostname occupies exactly 8 bytes.
	///
	/// Returns `None` in any of these cases:
	/// * `buf` is shorter than the smallest possible packet;
	/// * the trailing checksum does not match;
	/// * either hostname fails to decode;
	/// * the payload fails to decode.
	pub fn from_bytes(buf: &[u8]) -> Option<Self> {
		// Two 8-byte hostnames + 1 payload type tag + 2 checksum bytes. Accepting
		// anything shorter (e.g. 18 bytes) would hand `PacketData::from_bytes` a
		// payload with no tag byte once the checksum happened to match.
		const MIN_PACKET_BYTES: usize = 8 + 8 + 1 + 2;

		if buf.len() < MIN_PACKET_BYTES {
			return None;
		}

		let (body, checksum_bytes) = buf.split_at(buf.len() - 2);
		if u8s_to_u16(checksum_bytes) != X25.checksum(body) {
			return None;
		}

		let from = Hostname::from_bytes(&body[0..8])?;
		let to = Hostname::from_bytes(&body[8..16])?;
		let data = PacketData::from_bytes(&body[16..])?;

		Some(Self { from, to, data })
	}
}

// Type tags: the first byte of every encoded `PacketData`. Changing any of these
// values changes the wire format.
const PING: u8 = 0;
const PONG: u8 = 1;
const CONNECT: u8 = 2;
const CONNECT_ACK: u8 = 3;
const DISCONNECT: u8 = 4;
const DISCONNECT_ACK: u8 = 5;
const DATA_START: u8 = 6;
const DATA_START_ACK: u8 = 7;
const DATA_MID: u8 = 8;
const DATA_RESEND: u8 = 9;

/// Payload of a packet: a one-byte type tag followed by the variant's fields.
/// Multi-byte integers are encoded little-endian.
#[derive(PartialEq, Eq, Debug)]
pub(crate) enum PacketData {
	/// Ping carrying a one-byte ID.
	Ping(u8),
	/// Pong carrying a one-byte ID (presumably echoing a `Ping`'s ID — confirm with caller).
	Pong(u8),

	/// Tag only; no fields.
	Connect,
	/// Tag only; no fields.
	ConnectAck,

	/// Tag only; no fields.
	Disconnect,
	/// Tag only; no fields.
	DisconnectAck,

	/// (highest frame ID, data). The count is also described as the number of
	/// additional frames, not including this one.
	DataStart(u16, Vec<u8>),
	/// Tag only; no fields.
	DataStartAck,
	/// (frame ID, data)
	DataMid(u16, Vec<u8>),
	/// Frame IDs to resend. An empty list serves as a data acknowledgement.
	DataResend(Vec<u16>),
}

impl PacketData {
	/// Appends this payload's wire encoding (tag byte, then fields) to `buf`.
	pub fn to_bytes(&self, buf: &mut Vec<u8>) {
		match self {
			Self::Ping(id) => buf.extend([PING, *id]),
			Self::Pong(id) => buf.extend([PONG, *id]),

			Self::Connect => buf.push(CONNECT),
			Self::ConnectAck => buf.push(CONNECT_ACK),

			Self::Disconnect => buf.push(DISCONNECT),
			Self::DisconnectAck => buf.push(DISCONNECT_ACK),

			Self::DataStart(num_frames, data) => {
				buf.push(DATA_START);
				buf.extend(num_frames.to_le_bytes()); // Two bytes for total # of frames
				buf.extend_from_slice(data);
			}

			Self::DataStartAck => buf.push(DATA_START_ACK),

			Self::DataMid(frame_id, data) => {
				buf.push(DATA_MID);
				buf.extend(frame_id.to_le_bytes()); // Two bytes for current frame id
				buf.extend_from_slice(data);
			}

			Self::DataResend(frame_ids) => {
				buf.push(DATA_RESEND);
				// Stream the IDs straight into `buf` instead of collecting a temporary Vec.
				buf.extend(frame_ids.iter().flat_map(|frame_id| frame_id.to_le_bytes()));
			}
		}
	}

	/// Decodes a payload produced by [`PacketData::to_bytes`].
	///
	/// Returns `None` in any of these cases:
	/// * `buf` is empty;
	/// * the tag byte is unknown;
	/// * the bytes after the tag do not match that tag's layout.
	pub fn from_bytes(buf: &[u8]) -> Option<Self> {
		// Matching on the whole slice, tag included, means an empty or truncated
		// buffer simply falls through to `None` instead of panicking on an index.
		match *buf {
			[PING, id] => Some(Self::Ping(id)),
			[PONG, id] => Some(Self::Pong(id)),

			[CONNECT] => Some(Self::Connect),
			[CONNECT_ACK] => Some(Self::ConnectAck),

			[DISCONNECT] => Some(Self::Disconnect),
			[DISCONNECT_ACK] => Some(Self::DisconnectAck),

			// Num additional frames (not including this one), then data.
			[DATA_START, lo, hi, ref data @ ..] => {
				Some(Self::DataStart(u16::from_le_bytes([lo, hi]), data.to_vec()))
			}

			[DATA_START_ACK] => Some(Self::DataStartAck),

			[DATA_MID, lo, hi, ref data @ ..] => {
				Some(Self::DataMid(u16::from_le_bytes([lo, hi]), data.to_vec()))
			}

			// The body must be a whole number of 2-byte frame IDs.
			[DATA_RESEND, ref ids @ ..] if ids.len() % 2 == 0 => Some(Self::DataResend(
				ids.chunks_exact(2)
					.map(|pair| u16::from_le_bytes([pair[0], pair[1]]))
					.collect(),
			)),

			_ => None,
		}
	}
}

/// Splits `i` into its two bytes, low byte first (little-endian).
fn u16_to_u8s(i: u16) -> [u8; 2] {
	i.to_le_bytes()
}

/// Reassembles a little-endian `u16` from the first two bytes of `i`.
///
/// Any bytes past the second are ignored.
///
/// # Panics
///
/// Panics if `i` has fewer than two bytes.
fn u8s_to_u16(i: &[u8]) -> u16 {
	let (low, high) = (i[0], i[1]);
	u16::from_le_bytes([low, high])
}

#[cfg(test)]
mod tests {
	use super::*;

	/// Builds a packet between two fixed test hosts carrying `data`.
	fn sample_packet(data: PacketData) -> Packet {
		Packet::new(
			Hostname::new("N0CALL-0").unwrap(),
			Hostname::new("ABCDEFG-243").unwrap(),
			data,
		)
	}

	/// Serializes `packet` into a fresh buffer.
	fn encode(packet: &Packet) -> Vec<u8> {
		let mut bytes = Vec::new();
		packet.to_bytes(&mut bytes);
		bytes
	}

	/// Asserts that a packet carrying `data` survives an encode/decode round trip.
	fn assert_round_trip(data: PacketData) {
		let original = sample_packet(data);
		let decoded = Packet::from_bytes(&encode(&original)).unwrap();
		assert_eq!(original, decoded);
	}

	#[test]
	fn encode_and_decode_ping() {
		assert_round_trip(PacketData::Ping(37));
	}

	#[test]
	fn encode_and_decode_data_start() {
		assert_round_trip(PacketData::DataStart(
			5,
			vec![
				1, 2, 5, 7, 8, 43, 54, 7, 46, 56, 25, 6, 4, 7, 98, 34, 7, 54, 85, 35,
			],
		));
	}

	#[test]
	fn encode_and_decode_data_start_ack() {
		assert_round_trip(PacketData::DataStartAck);
	}

	#[test]
	fn test_crc_detects_error_properly_all_bitflips() {
		let packet = sample_packet(PacketData::Ping(37));
		let mut bytes = encode(&packet);

		for byte_idx in 0..bytes.len() {
			for bit in 0..8 {
				let mask = 1u8 << bit;

				// A single flipped bit must make decoding fail...
				bytes[byte_idx] ^= mask;
				assert!(Packet::from_bytes(&bytes).is_none());

				// ...and restoring it must make decoding succeed again.
				bytes[byte_idx] ^= mask;
				assert_eq!(Packet::from_bytes(&bytes).unwrap(), packet);
			}
		}
	}
}
