use crate::error::Error;
use crate::handler::Handler;
use crate::hostname::Hostname;
use crate::packet::{Packet, PacketData};
use crate::state::ConnectionState;

use kiss_tnc::errors::ReadError;
use kiss_tnc::tnc::Tnc;
use std::io::ErrorKind;
use std::io::{Read, Write};
use std::net::TcpStream;
use std::net::ToSocketAddrs;
use std::time::{Duration, Instant};

/// A protocol client that exchanges packets with a single remote host through a KISS TNC.
///
/// `T` is the byte stream the TNC is driven over (for example a `TcpStream`).
pub struct Client<T>
where
	T: Read + Write,
{
	/// TNC used to send and receive frames.
	tnc: Tnc<T>,
	/// Our own hostname; used as the source of every packet we send.
	us: Hostname,
	/// The remote host; only packets coming from it are acted upon.
	them: Hostname,
	/// Connection settings and state: timeout, retry budget, and the buffer of a
	/// multi-frame transfer that is currently being received.
	state: ConnectionState<()>,
}

impl Client<TcpStream> {
	/// Connects to a KISS TNC over TCP and creates a client that talks from `from` to `to`.
	///
	/// `max_packet_size` is the largest frame the link can carry, packet header included; the
	/// payload available per frame is `max_packet_size - crate::packet::PACKET_OVERHEAD_BYTES`.
	/// `timeout`, `max_retries` and `baud_rate` are handed to the connection state unchanged.
	///
	/// The TCP socket is given a short read timeout so that reads return periodically instead
	/// of blocking indefinitely.
	///
	/// # Errors
	///
	/// Returns an error if `max_packet_size` is smaller than the packet overhead, if the TCP
	/// connection to the TNC cannot be established, or if the read timeout cannot be set.
	pub fn new<A: ToSocketAddrs>(
		addr: A,
		from: Hostname,
		to: Hostname,
		timeout: Duration,
		max_packet_size: usize,
		max_retries: usize,
		baud_rate: usize,
	) -> Result<Self, ReadError> {
		/// How long a single read may block before giving control back to the caller.
		const READ_POLL_INTERVAL: Duration = Duration::from_millis(100);

		// Validate before opening a connection: a plain subtraction would underflow
		// (panic in debug builds, wrap to a huge payload size in release builds).
		let max_payload_size = max_packet_size
			.checked_sub(crate::packet::PACKET_OVERHEAD_BYTES)
			.ok_or_else(|| {
				std::io::Error::new(
					ErrorKind::InvalidInput,
					"max_packet_size must be at least PACKET_OVERHEAD_BYTES",
				)
			})?;

		let tnc = Tnc::connect(addr)?;
		// Failing to configure the socket is an ordinary I/O error, not a reason to panic.
		tnc.stream.set_read_timeout(Some(READ_POLL_INTERVAL))?;

		Ok(Self {
			tnc,
			us: from,
			them: to,
			state: ConnectionState::new(timeout, max_payload_size, max_retries, baud_rate, ()),
		})
	}
}

impl<T: Read + Write> Client<T> {
	pub fn new_with_stream(
		stream: T,
		from: Hostname,
		to: Hostname,
		timeout: Duration,
		max_packet_size: usize,
		max_retries: usize,
		baud_rate: usize,
	) -> Self {
		let tnc = Tnc::new(stream);

		Self {
			tnc,
			us: from,
			them: to,
			state: ConnectionState::new(
				timeout,
				max_packet_size - crate::packet::PACKET_OVERHEAD_BYTES,
				max_retries,
				baud_rate,
				(),
			),
		}
	}

	pub fn ping(&mut self) -> Result<bool, ReadError> {
		for _ in 0..self.state.max_retries {
			let id = rand::random();

			let mut buf = vec![];
			Packet::new(self.us, self.them, PacketData::Ping(id)).to_bytes(&mut buf);
			self.tnc.send_frame(&buf)?;

			let max_time = Instant::now() + self.state.timeout;
			while Instant::now() < max_time {
				if let Some(packet) = self.read_packet()? {
					if packet.from == self.them
						&& packet.to == self.us && packet.data == PacketData::Pong(id)
					{
						// Ping success
						return Ok(true);
					}
				}
			}
		}

		// No response
		Ok(false)
	}

	/// Connects to the specified server.
	/// Doesn't return until the connection is disconnected/lost.
	/// Use the handler to send and receive data.
	pub fn connect<H: Handler>(&mut self, handler: &mut H) -> Result<(), Error> {
		let mut connected = false;
		for _ in 0..self.state.max_retries {
			if self.try_establish_connection()? {
				connected = true;
				break;
			}
		}

		if !connected {
			return Err(Error::Disconnected(self.them));
		}

		loop {
			if let Some(packet) = self.read_packet()? {
				if packet.to == self.us && packet.from == self.them {
					self.process_packet(packet, handler)?;
				}
			}

			let action = handler.process();
			self.state.enqueue_action(action);
			self.state.process(self.us, self.them, &mut self.tnc)?;
		}
	}

	fn process_packet<H: Handler>(&mut self, packet: Packet, handler: &mut H) -> Result<(), Error> {
		match &packet.data {
			PacketData::Disconnect => {
				self.disconnection()?;
				return Err(Error::Disconnected(self.them));
			}
			PacketData::Ping(id) => self.pong(*id)?,
			PacketData::DataStart(num_frames, data) => {
				self.data_start(*num_frames, data.clone(), handler)?
			}
			PacketData::DataMid(frame_id, data) => {
				if let Some(x) = self.data_mid(*frame_id, data.clone(), handler) {
					x?;
				}
			}
			_ => {}
		}

		self.state.incoming_frame(&mut self.tnc, packet)?;

		Ok(())
	}

	fn disconnection(&mut self) -> Result<(), ReadError> {
		let mut buf = vec![];
		Packet::new(self.us, self.them, PacketData::DisconnectAck).to_bytes(&mut buf);
		self.tnc.send_frame(&buf)?;

		Ok(())
	}

	fn pong(&mut self, id: u8) -> Result<(), ReadError> {
		let mut buf = vec![];
		Packet::new(self.us, self.them, PacketData::Ping(id)).to_bytes(&mut buf);
		self.tnc.send_frame(&buf)?;

		Ok(())
	}

	fn data_start<H: Handler>(
		&mut self,
		num_frames: u16,
		data: Vec<u8>,
		handler: &mut H,
	) -> Result<(), ReadError> {
		if num_frames == 0 {
			// Reset so that we're ready for the next frame
			self.state.receiving = None;

			let data = match zstd::stream::decode_all(&*data) {
				Ok(x) => x,
				Err(_) => return Ok(()),
			};

			let mut buf = vec![];
			Packet::new(self.us, self.them, PacketData::DataResend(vec![])).to_bytes(&mut buf);
			self.tnc.send_frame(&buf)?;

			let action = handler.recv(&data);
			self.state.enqueue_action(action);
		} else {
			let mut entry = vec![None; num_frames as usize + 1];
			entry[0] = Some(data);
			self.state.receiving = Some(entry);

			let mut buf = vec![];
			Packet::new(self.us, self.them, PacketData::DataStartAck).to_bytes(&mut buf);
			self.tnc.send_frame(&buf)?;
		}

		Ok(())
	}

	fn data_mid<H: Handler>(
		&mut self,
		frame_id: u16,
		data: Vec<u8>,
		handler: &mut H,
	) -> Option<Result<(), ReadError>> {
		let chunks = self.state.receiving.as_mut()?;
		// Frame may be out of bounds(misconfigured or malicious sender)
		// So we must use get_mut and handle the none case
		let c = chunks.get_mut(frame_id as usize + 1)?;
		if c.is_none() {
			*c = Some(data);
		}

		if frame_id as usize == chunks.len() - 2 {
			// That was the last frame - attempt to reconstruct message
			let missing_frames: Vec<_> = chunks
				.iter()
				.enumerate()
				.filter(|(_, x)| x.is_none())
				.map(|(i, _)| (i - 1) as u16) // First chunk *must* have been received by DataStart ack so this is safe
				.collect();

			if missing_frames.is_empty() {
				// No missing frames. Reconstruct data and send to handler
				let compressed: Vec<u8> = chunks
					.iter()
					.map(|c| c.as_ref().unwrap())
					.flatten()
					.cloned()
					.collect();

				let data = match zstd::stream::decode_all(&*compressed) {
					Ok(x) => x,
					Err(_) => return None,
				};

				let mut buf = vec![];
				Packet::new(self.us, self.them, PacketData::DataResend(vec![])).to_bytes(&mut buf);
				if let Err(e) = self.tnc.send_frame(&buf) {
					return Some(Err(e.into()));
				}

				let action = handler.recv(&data);
				self.state.enqueue_action(action);

				// Reset so that we're ready for the next frame
				self.state.receiving = None;
			} else {
				return Some(
					self.request_missing_frames(missing_frames)
						.map_err(|e| e.into()),
				);
			}
		}

		Some(Ok(()))
	}

	fn request_missing_frames(
		&mut self,
		missing_frame_ids: Vec<u16>,
	) -> Result<(), std::io::Error> {
		let mut buf = vec![];
		Packet::new(
			self.us,
			self.them,
			PacketData::DataResend(missing_frame_ids),
		)
		.to_bytes(&mut buf);

		self.tnc.send_frame(&buf)?;

		Ok(())
	}

	fn try_establish_connection(&mut self) -> Result<bool, ReadError> {
		// Send connect packet
		let mut buf = vec![];
		Packet::new(self.us, self.them, PacketData::Connect).to_bytes(&mut buf);
		self.tnc.send_frame(&buf)?;

		let max_time = Instant::now() + self.state.timeout;
		while Instant::now() < max_time {
			if let Some(packet) = self.read_packet()? {
				if packet.from == self.them
					&& packet.to == self.us
					&& packet.data == PacketData::ConnectAck
				{
					// Connection established
					return Ok(true);
				}
			}
		}

		Ok(false)
	}

	fn read_packet(&mut self) -> Result<Option<Packet>, ReadError> {
		let (_, data) = match self.tnc.read_frame() {
			Ok(x) => x,
			Err(ReadError::StreamError(e))
				if matches!(e.kind(), ErrorKind::WouldBlock | ErrorKind::TimedOut) =>
			{
				// We didn't receive anything
				return Ok(None);
			}

			Err(e) => return Err(e),
		};

		let packet = Packet::from_bytes(&data);

		Ok(packet)
	}
}
