use crate::error::Error;
use crate::handler::Action;
use crate::hostname::Hostname;
use crate::packet::{Packet, PacketData};

use kiss_tnc::errors::ReadError;
use kiss_tnc::tnc::Tnc;
use std::collections::VecDeque;
use std::io::{Read, Write};
use std::time::{Duration, Instant};

/// Progress of the outgoing transfer on this connection.
///
/// The two awaiting variants own every chunk of the payload so any of them
/// can be retransmitted. Chunk 0 travels inside the `DataStart` packet and
/// chunk `i` (for `i >= 1`) travels as `DataMid(i - 1)`.
#[derive(Debug, Clone)]
enum SendState {
	/// No transfer in flight; the next queued action may start.
	Idle,
	/// `DataStart` has been sent and the peer's `DataStartAck` is awaited
	/// before the remaining chunks are sent.
	AwaitingDataStartAck(Vec<Vec<u8>>),
	/// All chunks have been sent; awaiting a `DataResend` from the peer
	/// (an empty one signals that the transfer completed).
	AwaitingDataEndAck(Vec<Vec<u8>>),
}

/// Per-connection state for the chunked send protocol.
///
/// Outgoing payloads are queued as actions, started one at a time while no
/// transfer is in flight, and retransmitted on timeout up to `max_retries`.
pub(crate) struct ConnectionState<H> {
	/// Where the current outgoing transfer stands.
	sending: SendState,
	/// Deadline after which `process` retransmits (or gives up).
	send_timeout_time: Instant,
	/// Retransmissions performed for the current step; reset on progress.
	send_retries: usize,
	/// Incoming transfer buffer; not touched in this file section.
	/// NOTE(review): presumably one slot per expected chunk (`None` until
	/// received) — confirm against the receive path.
	pub receiving: Option<Vec<Option<Vec<u8>>>>,
	/// User-supplied handler; not used in this file section — TODO confirm role.
	pub handler: H,
	/// Maximum payload bytes per chunk (passed to `slice::chunks`).
	max_data_size: usize,
	/// Retransmissions allowed before the current transfer is abandoned.
	pub max_retries: usize,
	/// Base wait for a reply; estimated airtime is added on top.
	pub timeout: Duration,
	/// Used for calculating timeouts (treated as bits per second).
	baud_rate: usize,
	/// Pending actions, started in FIFO order whenever sending is idle.
	actions: VecDeque<Action>,
}

impl<H> ConnectionState<H> {
	pub fn new(
		timeout: Duration,
		max_data_size: usize,
		max_retries: usize,
		baud_rate: usize,
		handler: H,
	) -> Self {
		Self {
			sending: SendState::Idle,
			send_timeout_time: Instant::now(),
			send_retries: 0,
			receiving: None,
			handler,
			max_data_size,
			max_retries,
			timeout,
			baud_rate,
			actions: VecDeque::new(),
		}
	}

	fn send_data<T: Read + Write>(
		&mut self,
		tnc: &mut Tnc<T>,
		from: Hostname,
		to: Hostname,
		data: &[u8],
	) -> Result<(), ReadError> {
		let mut data_chunks: Vec<Vec<u8>> = data
			.chunks(self.max_data_size)
			.map(|x| x.to_owned())
			.collect();

		if data_chunks.is_empty() {
			data_chunks.push(vec![])
		}

		// Construct and send DataStart packet
		let mut buf = vec![];
		Packet::new(
			from,
			to,
			PacketData::DataStart((data_chunks.len() - 1) as u16, data_chunks[0].clone()),
		)
		.to_bytes(&mut buf);
		let data_len = tnc.send_frame(&buf)?;

		// Update our state
		self.refresh_timeout_time(data_len);
		self.send_retries = 0;

		if data_chunks.len() == 1 {
			// Waiting for end ack or retransmission
			self.sending = SendState::AwaitingDataEndAck(data_chunks);
		} else {
			// Waiting for start ack
			self.sending = SendState::AwaitingDataStartAck(data_chunks);
		}

		Ok(())
	}

	pub fn incoming_frame<T: Read + Write>(
		&mut self,
		tnc: &mut Tnc<T>,
		packet: Packet,
	) -> Result<(), ReadError> {
		let them = packet.from;
		let us = packet.to;
		let data = packet.data;
		match (data, self.sending.clone()) {
			(PacketData::DataStartAck, SendState::AwaitingDataStartAck(data)) => {
				// send all of the packets
				let data_len = data
					.clone()
					.into_iter()
					.enumerate()
					.skip(1) // first packet was sent in DataStart
					.map(|(frame_id, frame)| {
						let mut buf = vec![];
						Packet::new(us, them, PacketData::DataMid(frame_id as u16 - 1, frame))
							.to_bytes(&mut buf);
						tnc.send_frame(&buf)
					})
					.fold(Ok(0), |acc, x| match (acc, x) {
						(Ok(acc), Ok(x)) => Ok(acc + x),
						(Ok(_), Err(e)) => Err(e),
						(Err(e), _) => Err(e),
					})?;

				// Reset state
				self.refresh_timeout_time(data_len);
				self.send_retries = 0;
				self.sending = SendState::AwaitingDataEndAck(data);
			}

			(PacketData::DataResend(missing_frame_ids), SendState::AwaitingDataEndAck(data)) => {
				if missing_frame_ids.is_empty() {
					// Transmission finished - success
					self.sending = SendState::Idle;
					return Ok(());
				}

				// Check that they didn't request anything invalid
				let mut max_frame_id = 0;
				for frame_id in &missing_frame_ids {
					if *frame_id as usize >= data.len() {
						// OOB - ignore
						return Ok(());
					}

					if *frame_id > max_frame_id {
						max_frame_id = *frame_id;
					}
				}

				// Send frames they requested
				let data_len = missing_frame_ids
					.iter()
					.map(|frame_id| {
						let mut buf = vec![];
						Packet::new(
							us,
							them,
							PacketData::DataMid(*frame_id, data[*frame_id as usize + 1].clone()),
						)
						.to_bytes(&mut buf);
						tnc.send_frame(&buf)
					})
					.fold(Ok(0), |acc, x| match (acc, x) {
						(Ok(acc), Ok(x)) => Ok(acc + x),
						(Ok(_), Err(e)) => Err(e),
						(Err(e), _) => Err(e),
					})?;

				// If they didn't request the last frame, resend it w/ empty data
				// This is used so the other side can detect that we've finished
				if max_frame_id as usize != data.len() - 2 {
					let mut buf = vec![];
					Packet::new(
						us,
						them,
						PacketData::DataMid((data.len() - 2) as u16, vec![]),
					)
					.to_bytes(&mut buf);
					tnc.send_frame(&buf)?;
				}

				// Reset state
				self.refresh_timeout_time(data_len);
				self.send_retries = 0;
			}
			_ => {} // Do nothing
		}

		Ok(())
	}

	pub fn process<T: Read + Write>(
		&mut self,
		from: Hostname,
		to: Hostname,
		tnc: &mut Tnc<T>,
	) -> Result<(), Error> {
		if let SendState::Idle = self.sending {
			// Process next action, if available
			if let Some(action) = self.actions.pop_front() {
				self.handle_action(from, to, tnc, action)?;
			}
		}

		if self.send_timeout_time > Instant::now() {
			return Ok(());
		}

		if self.send_retries >= self.max_retries {
			self.send_retries = 0;
			self.sending = SendState::Idle;
			return Ok(());
		}

		match &self.sending {
			SendState::Idle => {}
			SendState::AwaitingDataStartAck(data) => {
				// Resend first packet
				let mut buf = vec![];
				Packet::new(
					from,
					to,
					PacketData::DataStart((data.len() - 1) as u16, data[0].clone()),
				)
				.to_bytes(&mut buf);
				let data_len = tnc
					.send_frame(&buf)
					.map_err(|e| Error::ReadError(e.into()))?;

				// Increment tries
				self.send_retries += 1;

				// Reset timer
				self.refresh_timeout_time(data_len);
			}
			SendState::AwaitingDataEndAck(data) => {
				// Resend last packet
				let mut buf = vec![];
				let idx = data.len() - 1;
				let data_len = if idx == 0 {
					Packet::new(from, to, PacketData::DataStart(0, data[idx].clone()))
						.to_bytes(&mut buf);
					tnc.send_frame(&buf)
						.map_err(|e| Error::ReadError(e.into()))?
				} else {
					Packet::new(
						from,
						to,
						PacketData::DataMid((idx - 1) as u16, data[idx].clone()),
					)
					.to_bytes(&mut buf);
					tnc.send_frame(&buf)
						.map_err(|e| Error::ReadError(e.into()))?
				};

				// Increment tries
				self.send_retries += 1;

				// Reset timer
				self.refresh_timeout_time(data_len);
			}
		}

		Ok(())
	}

	pub fn enqueue_action(&mut self, action: Action) {
		if !matches!(action, Action::Nothing) {
			self.actions.push_back(action);
		}
	}

	fn refresh_timeout_time(&mut self, bytes: usize) {
		self.send_timeout_time = Instant::now()
			+ self.timeout
			+ Duration::from_millis((bytes * 8 * 1000 / self.baud_rate) as u64);
	}

	fn handle_action<T: Read + Write>(
		&mut self,
		from: Hostname,
		to: Hostname,
		tnc: &mut Tnc<T>,
		action: Action,
	) -> Result<(), Error> {
		match action {
			Action::Disconnect => {
				let mut buf = vec![];
				Packet::new(from, to, PacketData::Disconnect).to_bytes(&mut buf);
				tnc.send_frame(&buf)
					.map_err(|e| Error::ReadError(e.into()))?;

				return Err(Error::Disconnected(to));
			}
			Action::SendData(data) => Self::compress_and_send_data(self, from, to, tnc, &data)?,
			Action::Nothing => {}
		}

		Ok(())
	}

	fn compress_and_send_data<T: Read + Write>(
		&mut self,
		from: Hostname,
		to: Hostname,
		tnc: &mut Tnc<T>,
		data: &[u8],
	) -> Result<(), ReadError> {
		// Safe to unwrap because read from a vector will never fail
		let compressed = zstd::stream::encode_all(data, 9).unwrap();

		self.send_data(tnc, from, to, &compressed)?;
		Ok(())
	}
}
