use crate::error::Error;
use crate::error::{BrokenPipeError, MapBrokenPipe};
use crate::hostname::Hostname;
use crate::packet::{Packet, PacketData};
use crate::stream::Stream;

use std::sync::mpsc::{Sender, TryRecvError};
use std::time::{Duration, Instant};

// Advances our (tx) sequence number and evaluates to the new value.
// The counter cycles through 1..=255: the current value is reduced modulo 255
// and then incremented, so 255 wraps to 1 and 0 is never produced.
// It is a macro rather than a method because the expansion only touches the
// `seq_tx` field, so it can be used while other fields of `$self` are borrowed
// (a `&mut self` method call would conflict with those borrows).
macro_rules! seq_inc {
	($self:ident) => {{
		$self.seq_tx = $self.seq_tx % 255 + 1;
		$self.seq_tx
	}};
}

/// Progress of the message currently being transmitted to the peer.
///
/// The payload holds the (compressed) message split into chunks: chunk 0
/// travels in the `DataStart` packet and chunk `i` (for `i >= 1`) travels as
/// `DataMid` frame `i - 1`.
#[derive(Debug, Clone)]
enum SendState {
	/// Nothing in flight; the next outgoing message may be picked up.
	Idle,
	/// `DataStart` was sent for a multi-chunk message; waiting for `DataStartAck`.
	AwaitingDataStartAck(Vec<Vec<u8>>),
	/// All chunks were sent; waiting for an empty `DataResend` (done) or a
	/// list of frames to resend.
	AwaitingDataEndAck(Vec<Vec<u8>>),
}

/// State of one connection between `us` and `them`: the outgoing transfer in
/// progress, the incoming message being reassembled, both sequence numbers,
/// and the channels to the local application (`stream`) and to the link
/// (`packet_sender`).
pub(crate) struct ConnectionState {
	/// Progress of the outgoing message currently in flight.
	sending: SendState,
	/// Deadline checked by `process`; once it has passed, the pending packet
	/// is retransmitted (or the transfer abandoned after `max_retries`).
	send_timeout_time: Instant,
	/// Retransmissions made so far for the pending packet; when it reaches
	/// `max_retries` the transfer is abandoned and sending goes idle.
	send_retries: usize,
	/// Reassembly buffer for a multi-frame incoming message: slot 0 holds the
	/// `DataStart` payload and slot `i + 1` holds `DataMid` frame `i`.
	/// `None` when no multi-frame message is being received.
	pub receiving: Option<Vec<Option<Vec<u8>>>>,
	/// Maximum payload bytes per outgoing chunk.
	max_data_size: usize,
	/// Retransmissions allowed before a pending send is dropped.
	pub max_retries: usize,
	/// Base retransmission timeout, extended by the estimated transmit time.
	pub timeout: Duration,
	/// Link speed; the timeout calculation treats it as bits per second.
	baud_rate: usize, // Used for calculating timeouts
	/// Local application side: outgoing messages are read from it with
	/// `try_recv` and reassembled incoming messages are written to it.
	pub stream: Stream,
	/// Channel that receives every serialized outgoing packet.
	pub packet_sender: Sender<Vec<u8>>,
	/// Our (tx) sequence number; starts at a random value and is advanced by
	/// `seq_inc!` for every packet sent.
	seq_tx: u8,
	/// Last sequence number accepted from the peer (rx); starts at 0.
	seq_rx: u8,
	/// Our hostname; incoming packets must be addressed to it.
	pub us: Hostname,
	/// The peer's hostname; incoming packets must come from it.
	pub them: Hostname,
}

impl ConnectionState {
	#[allow(clippy::too_many_arguments)]
	pub fn new(
		timeout: Duration,
		max_data_size: usize,
		max_retries: usize,
		baud_rate: usize,
		stream: Stream,
		packet_sender: Sender<Vec<u8>>,
		us: Hostname,
		them: Hostname,
	) -> Self {
		Self {
			sending: SendState::Idle,
			send_timeout_time: Instant::now(),
			send_retries: 0,
			receiving: None,
			max_data_size,
			max_retries,
			timeout,
			baud_rate,
			stream,
			packet_sender,
			seq_tx: rand::random(),
			seq_rx: 0,
			us,
			them,
		}
	}

	fn send_data(&mut self, data: &[u8]) -> Result<(), Error> {
		let mut data_chunks: Vec<Vec<u8>> = data
			.chunks(self.max_data_size)
			.map(|x| x.to_owned())
			.collect();

		if data_chunks.is_empty() {
			data_chunks.push(vec![])
		}

		// Construct and send DataStart packet
		let mut buf = vec![];
		let packet = Packet::new(
			self.us,
			self.them,
			self.seq_inc(),
			PacketData::DataStart((data_chunks.len() - 1) as u16, data_chunks[0].clone()),
		);
		packet.to_bytes(&mut buf);
		let data_len = buf.len();
		self.packet_sender.send(buf).map_broken_pipe()?;

		// Update our state
		self.refresh_timeout_time(data_len);
		self.send_retries = 0;

		if data_chunks.len() == 1 {
			// Waiting for end ack or retransmission
			self.sending = SendState::AwaitingDataEndAck(data_chunks);
		} else {
			// Waiting for start ack
			self.sending = SendState::AwaitingDataStartAck(data_chunks);
		}

		Ok(())
	}

	pub fn incoming_packet(&mut self, packet: Packet) -> Result<(), Error> {
		let them = packet.from;
		let us = packet.to;

		if self.us != us || self.them != them {
			// sanity check
			unreachable!();
		}

		if !self.seq_validate(packet.seq_num) {
			// ignore this frame - it's a duplicate
			return Ok(());
		}

		let data = packet.data;
		match (data, self.sending.clone()) {
			(PacketData::DataStartAck, SendState::AwaitingDataStartAck(data)) => {
				self.data_start_ack(data)?;
			}

			(PacketData::DataResend(missing_frame_ids), SendState::AwaitingDataEndAck(data)) => {
				self.data_resend(missing_frame_ids, data)?;
			}

			(PacketData::DataStart(num_frames, data), _) => {
				self.data_start(num_frames, data)?;
			}

			(PacketData::DataMid(frame_id, data), _) => {
				if let Some(x) = self.data_mid(frame_id, data) {
					x?;
				}
			}

			(PacketData::Disconnect, _) => {
				self.disconnection()?;
			}
			_ => {} // Do nothing
		}

		Ok(())
	}

	pub fn process(&mut self) -> Result<(), Error> {
		if let SendState::Idle = self.sending {
			// Process next action, if available
			match self.stream.try_recv() {
				Ok(data) => self.compress_and_send_data(&data)?,
				Err(TryRecvError::Disconnected) => {
					self.disconnect(self.us, self.them)?;

					return Err(Error::Disconnected);
				}
				Err(TryRecvError::Empty) => {}
			}
		}

		if self.send_timeout_time > Instant::now() {
			return Ok(());
		}

		if self.send_retries >= self.max_retries {
			self.send_retries = 0;
			self.sending = SendState::Idle;
			return Ok(());
		}

		match &self.sending {
			SendState::Idle => {}
			SendState::AwaitingDataStartAck(data) => {
				// Resend first packet
				let mut buf = vec![];
				Packet::new(
					self.us,
					self.them,
					seq_inc!(self),
					PacketData::DataStart((data.len() - 1) as u16, data[0].clone()),
				)
				.to_bytes(&mut buf);
				let data_len = buf.len();
				self.packet_sender.send(buf).map_broken_pipe()?;

				// Increment tries
				self.send_retries += 1;

				// Reset timer
				self.refresh_timeout_time(data_len);
			}
			SendState::AwaitingDataEndAck(data) => {
				// Resend last packet
				let mut buf = vec![];
				let idx = data.len() - 1;
				let data_len = if idx == 0 {
					Packet::new(
						self.us,
						self.them,
						seq_inc!(self),
						PacketData::DataStart(0, data[idx].clone()),
					)
					.to_bytes(&mut buf);
					let len = buf.len();
					self.packet_sender.send(buf).map_broken_pipe()?;

					len
				} else {
					Packet::new(
						self.us,
						self.them,
						seq_inc!(self),
						PacketData::DataMid((idx - 1) as u16, data[idx].clone()),
					)
					.to_bytes(&mut buf);
					let len = buf.len();
					self.packet_sender.send(buf).map_broken_pipe()?;

					len
				};

				// Increment tries
				self.send_retries += 1;

				// Reset timer
				self.refresh_timeout_time(data_len);
			}
		}

		Ok(())
	}

	pub fn disconnect(&mut self, from: Hostname, to: Hostname) -> Result<(), BrokenPipeError> {
		let mut buf = vec![];
		Packet::new(from, to, self.seq_inc(), PacketData::Disconnect).to_bytes(&mut buf);
		self.packet_sender.send(buf).map_broken_pipe()
	}

	// wrapper around our macro
	pub fn seq_inc(&mut self) -> u8 {
		seq_inc!(self)
	}

	// for their(rx) sequence
	fn seq_validate(&mut self, new: u8) -> bool {
		if crate::packet::is_valid_seq_num(self.seq_rx, new) {
			self.seq_rx = new;

			true
		} else {
			false
		}
	}

	fn data_start(&mut self, num_frames: u16, data: Vec<u8>) -> Result<(), Error> {
		if num_frames == 0 {
			// Reset so that we're ready for the next frame
			self.receiving = None;

			// Only 1 frame means we don't need to do an end ack or resend
			// Since if this frame is corrupted, the client will auto-resend
			let data = match zstd::stream::decode_all(&*data) {
				Ok(x) => x,
				Err(_) => return Ok(()),
			};

			let mut buf = vec![];
			Packet::new(
				self.us,
				self.them,
				self.seq_inc(),
				PacketData::DataResend(vec![]),
			)
			.to_bytes(&mut buf);
			self.packet_sender.send(buf).map_broken_pipe()?;

			if self.stream.write(data).is_err() {
				// disconnect and drop state
				self.disconnect(self.us, self.them)?;

				return Err(Error::Disconnected);
			}
		} else {
			let mut entry = vec![None; num_frames as usize + 1];
			entry[0] = Some(data);
			self.receiving = Some(entry);

			let mut buf = vec![];
			Packet::new(self.us, self.them, self.seq_inc(), PacketData::DataStartAck)
				.to_bytes(&mut buf);
			self.packet_sender.send(buf).map_broken_pipe()?;
		}

		Ok(())
	}

	fn data_mid(&mut self, frame_id: u16, data: Vec<u8>) -> Option<Result<(), Error>> {
		let chunks = self.receiving.as_mut()?;
		// Frame may be out of bounds(misconfigured or malicious sender)
		// So we must use get_mut and handle the none case
		let c = chunks.get_mut(frame_id as usize + 1)?;
		if c.is_none() {
			*c = Some(data);
		}

		if frame_id as usize == chunks.len() - 2 {
			// That was the last frame - attempt to reconstruct message
			let missing_frames: Vec<_> = chunks
				.iter()
				.enumerate()
				.filter(|(_, x)| x.is_none())
				.map(|(i, _)| (i - 1) as u16) // First chunk *must* have been received by DataStart ack so this is safe
				.collect();

			if missing_frames.is_empty() {
				// No missing frames. Reconstruct data and send to handler
				let compressed: Vec<u8> = chunks
					.iter()
					.map(|c| c.as_ref().unwrap())
					.flatten()
					.cloned()
					.collect();

				let data = match zstd::stream::decode_all(&*compressed) {
					Ok(x) => x,
					Err(_) => return None,
				};

				let mut buf = vec![];
				Packet::new(
					self.us,
					self.them,
					self.seq_inc(),
					PacketData::DataResend(vec![]),
				)
				.to_bytes(&mut buf);
				if let Err(e) = self.packet_sender.send(buf).map_broken_pipe() {
					return Some(Err(e.into()));
				}

				if self.stream.write(data).is_err() {
					// Stream is broken
					// Disconnect and drop state
					if let Err(BrokenPipeError) = self.disconnect(self.us, self.them) {
						return Some(Err(Error::BrokenPipe));
					}

					return Some(Err(Error::Disconnected));
				}

				// Reset so that we're ready for the next frame
				self.receiving = None;
			} else {
				return Some(self.request_missing_frames(missing_frames));
			}
		}

		Some(Ok(()))
	}

	fn disconnection(&mut self) -> Result<(), Error> {
		let mut buf = vec![];
		Packet::new(
			self.us,
			self.them,
			self.seq_inc(),
			PacketData::DisconnectAck,
		)
		.to_bytes(&mut buf);
		self.packet_sender.send(buf).map_broken_pipe()?;

		Err(Error::Disconnected)
	}

	fn data_start_ack(&mut self, data: Vec<Vec<u8>>) -> Result<(), Error> {
		// send all of the packets
		let data_len = data
			.clone()
			.into_iter()
			.enumerate()
			.skip(1) // first packet was sent in DataStart
			.map(|(frame_id, frame)| {
				let mut buf = vec![];
				Packet::new(
					self.us,
					self.them,
					self.seq_inc(),
					PacketData::DataMid(frame_id as u16 - 1, frame),
				)
				.to_bytes(&mut buf);
				let len = buf.len();
				self.packet_sender.send(buf).map_broken_pipe()?;

				Ok(len)
			})
			.fold(Ok(0), |acc: Result<usize, Error>, x| match (acc, x) {
				(Ok(acc), Ok(x)) => Ok(acc + x),
				(Ok(_), Err(e)) => Err(e),
				(Err(e), _) => Err(e),
			})?;

		// Reset state
		self.refresh_timeout_time(data_len);
		self.send_retries = 0;
		self.sending = SendState::AwaitingDataEndAck(data);

		Ok(())
	}

	fn data_resend(
		&mut self,
		missing_frame_ids: Vec<u16>,
		data: Vec<Vec<u8>>,
	) -> Result<(), Error> {
		if missing_frame_ids.is_empty() {
			// Transmission finished - success
			self.sending = SendState::Idle;
			return Ok(());
		}

		// Check that they didn't request anything invalid
		let mut max_frame_id = 0;
		for frame_id in &missing_frame_ids {
			if *frame_id as usize >= data.len() {
				// OOB - ignore
				return Ok(());
			}

			if *frame_id > max_frame_id {
				max_frame_id = *frame_id;
			}
		}

		// Send frames they requested
		let data_len = missing_frame_ids
			.iter()
			.map(|frame_id| {
				let mut buf = vec![];
				Packet::new(
					self.us,
					self.them,
					self.seq_inc(),
					PacketData::DataMid(*frame_id, data[*frame_id as usize + 1].clone()),
				)
				.to_bytes(&mut buf);
				let len = buf.len();
				self.packet_sender.send(buf).map_broken_pipe()?;

				Ok(len)
			})
			.fold(Ok(0), |acc: Result<usize, Error>, x| match (acc, x) {
				(Ok(acc), Ok(x)) => Ok(acc + x),
				(Ok(_), Err(e)) => Err(e),
				(Err(e), _) => Err(e),
			})?;

		// If they didn't request the last frame, resend it w/ empty data
		// This is used so the other side can detect that we've finished
		if max_frame_id as usize != data.len() - 2 {
			let mut buf = vec![];
			Packet::new(
				self.us,
				self.them,
				self.seq_inc(),
				PacketData::DataMid((data.len() - 2) as u16, vec![]),
			)
			.to_bytes(&mut buf);
			self.packet_sender.send(buf).map_broken_pipe()?;
		}

		// Reset state
		self.refresh_timeout_time(data_len);
		self.send_retries = 0;

		Ok(())
	}

	fn refresh_timeout_time(&mut self, bytes: usize) {
		self.send_timeout_time = Instant::now()
			+ self.timeout
			+ Duration::from_millis((bytes * 8 * 1000 / self.baud_rate) as u64);
	}

	fn compress_and_send_data(&mut self, data: &[u8]) -> Result<(), Error> {
		// Safe to unwrap because read from a vector will never fail
		let compressed = zstd::stream::encode_all(data, 9).unwrap();

		self.send_data(&compressed)
	}

	fn request_missing_frames(&mut self, missing_frame_ids: Vec<u16>) -> Result<(), Error> {
		let mut buf = vec![];
		Packet::new(
			self.us,
			self.them,
			self.seq_inc(),
			PacketData::DataResend(missing_frame_ids),
		)
		.to_bytes(&mut buf);

		self.packet_sender.send(buf).map_broken_pipe()?;

		Ok(())
	}
}
