use crate::error::Error;
use crate::handler::Handler;
use crate::hostname::Hostname;
use crate::packet::{Packet, PacketData};
use crate::state::ConnectionState;

use kiss_tnc::errors::ReadError;
use kiss_tnc::tnc::Tnc;
use std::collections::HashMap;
use std::io::ErrorKind;
use std::io::{Read, Write};
use std::net::{TcpStream, ToSocketAddrs};
use std::time::Duration;

/// A server that accepts connections from peers over a KISS TNC link.
///
/// Every connected peer gets its own connection state and its own handler,
/// built on demand by `constructor`.
pub struct Server<T, H, C>
where
	T: Read + Write,
{
	/// Link to the TNC through which frames are read and sent.
	tnc: Tnc<T>,
	/// This server's own address; only packets addressed to it are processed.
	hostname: Hostname,
	/// Timeout handed to every new connection's state.
	timeout: Duration,
	/// Maximum packet size (presumably in bytes — TODO confirm), including the
	/// per-packet overhead that is subtracted for each connection's payload.
	max_packet_size: usize,
	/// Retry limit handed to every new connection's state.
	max_retries: usize,
	/// Baud rate handed to every new connection's state.
	baud_rate: usize,
	/// Open connections, keyed by the peer's hostname.
	connections: HashMap<Hostname, ConnectionState<H>>,
	/// Builds a handler for a newly connected peer, given that peer's hostname.
	constructor: C,
}

impl<H: Handler, C: Fn(Hostname) -> H> Server<TcpStream, H, C> {
	/// Connects to a TNC over TCP at `addr` and builds a server on top of it.
	///
	/// The socket gets a short read timeout. When that timeout expires,
	/// `serve_forever` stops waiting for a frame and services its open
	/// connections instead.
	///
	/// # Errors
	///
	/// Returns an error if the connection cannot be established or if the read
	/// timeout cannot be applied to the socket.
	pub fn new<A: ToSocketAddrs>(
		addr: A,
		hostname: Hostname,
		timeout: Duration,
		max_packet_size: usize,
		max_retries: usize,
		baud_rate: usize,
		constructor: C,
	) -> Result<Self, ReadError> {
		// Longest a single read may block before the serve loop moves on to
		// polling connections.
		const POLL_INTERVAL: Duration = Duration::from_millis(100);

		let tnc = Tnc::connect(addr)?;

		// Propagate instead of panicking: this constructor is already fallible.
		tnc.stream.set_read_timeout(Some(POLL_INTERVAL))?;

		Ok(Self {
			tnc,
			hostname,
			timeout,
			max_packet_size,
			max_retries,
			baud_rate,
			connections: HashMap::new(),
			constructor,
		})
	}
}

impl<T: Read + Write, H: Handler, C: Fn(Hostname) -> H> Server<T, H, C> {
	pub fn new_with_stream(
		stream: T,
		hostname: Hostname,
		timeout: Duration,
		max_packet_size: usize,
		max_retries: usize,
		baud_rate: usize,
		constructor: C,
	) -> Self {
		let tnc = Tnc::new(stream);

		Self {
			tnc,
			hostname,
			timeout,
			max_packet_size,
			max_retries,
			baud_rate,
			connections: HashMap::new(),
			constructor,
		}
	}

	pub fn hostname(&self) -> Hostname {
		self.hostname
	}

	pub fn serve_forever(&mut self) -> Result<(), ReadError> {
		loop {
			let (_, data) = match self.tnc.read_frame() {
				Ok(x) => x,
				Err(ReadError::StreamError(e))
					if matches!(e.kind(), ErrorKind::WouldBlock | ErrorKind::TimedOut) =>
				{
					// We didn't receive anything
					(0, vec![])
				}

				Err(e) => return Err(e),
			};

			if let Some(packet) = Packet::from_bytes(&data) {
				if packet.to == self.hostname {
					self.process_packet(packet)?;
				}
			}

			let mut disconnected = vec![];
			let mut connections = std::mem::take(&mut self.connections);
			for (client, state) in connections.iter_mut() {
				let action = state.handler.process();
				state.enqueue_action(action);
				match state.process(self.hostname, *client, &mut self.tnc) {
					Ok(_) => {}
					Err(Error::Disconnected(_)) => {
						disconnected.push(*client);
					}
					Err(Error::ReadError(e)) => return Err(e),
				};
			}

			for d in disconnected {
				connections.remove(&d);
			}

			self.connections = connections;
		}
	}

	fn process_packet(&mut self, packet: Packet) -> Result<(), ReadError> {
		match &packet.data {
			PacketData::Connect => self.connection_from(packet.from)?,
			PacketData::Disconnect => self.disconnection_from(packet.from)?,
			PacketData::Ping(id) => self.pong(packet.from, *id)?,
			PacketData::DataStart(num_frames, data) => {
				self.data_start(packet.from, *num_frames, data.clone())?
			}

			PacketData::DataMid(frame_id, data) => {
				if let Some(x) = self.data_mid(packet.from, *frame_id, data.clone()) {
					x?;
				}
			}
			_ => {}
		}

		if let Some(conn) = self.connections.get_mut(&packet.from) {
			conn.incoming_frame(&mut self.tnc, packet)?;
		}

		Ok(())
	}

	fn pong(&mut self, to: Hostname, id: u8) -> Result<(), std::io::Error> {
		let mut buf = vec![];
		Packet::new(self.hostname, to, PacketData::Pong(id)).to_bytes(&mut buf);
		self.tnc.send_frame(&buf)?;

		Ok(())
	}

	fn connection_from(&mut self, from: Hostname) -> Result<(), std::io::Error> {
		// Send ACK
		let mut buf = vec![];
		Packet::new(self.hostname, from, PacketData::ConnectAck).to_bytes(&mut buf);
		self.tnc.send_frame(&buf)?;

		// If we're already connected, reset the connection

		self.connections.insert(
			from,
			ConnectionState::new(
				self.timeout,
				self.max_packet_size - crate::packet::PACKET_OVERHEAD_BYTES,
				self.max_retries,
				self.baud_rate,
				(self.constructor)(from),
			),
		);

		Ok(())
	}

	fn disconnection_from(&mut self, from: Hostname) -> Result<(), std::io::Error> {
		// Send ACK
		let mut buf = vec![];
		Packet::new(self.hostname, from, PacketData::DisconnectAck).to_bytes(&mut buf);
		self.tnc.send_frame(&buf)?;

		self.connections.remove(&from);

		Ok(())
	}

	fn data_start(
		&mut self,
		from: Hostname,
		num_frames: u16,
		data: Vec<u8>,
	) -> Result<(), ReadError> {
		if let Some(state) = self.connections.get_mut(&from) {
			if num_frames == 0 {
				// Reset so that we're ready for the next frame
				state.receiving = None;

				// Only 1 frame means we don't need to do an end ack or resend
				// Since if this frame is corrupted, the client will auto-resend
				let data = match zstd::stream::decode_all(&*data) {
					Ok(x) => x,
					Err(_) => return Ok(()),
				};

				let mut buf = vec![];
				Packet::new(self.hostname, from, PacketData::DataResend(vec![])).to_bytes(&mut buf);
				self.tnc.send_frame(&buf)?;

				let action = state.handler.recv(&data);
				state.enqueue_action(action);
			} else {
				let mut entry = vec![None; num_frames as usize + 1];
				entry[0] = Some(data);
				state.receiving = Some(entry);

				let mut buf = vec![];
				Packet::new(self.hostname, from, PacketData::DataStartAck).to_bytes(&mut buf);
				self.tnc.send_frame(&buf)?;
			}
		}

		Ok(())
	}

	fn data_mid(
		&mut self,
		from: Hostname,
		frame_id: u16,
		data: Vec<u8>,
	) -> Option<Result<(), ReadError>> {
		let state = self.connections.get_mut(&from)?;
		let chunks = state.receiving.as_mut()?;
		// Frame may be out of bounds(misconfigured or malicious sender)
		// So we must use get_mut and handle the none case
		let c = chunks.get_mut(frame_id as usize + 1)?;
		if c.is_none() {
			*c = Some(data);
		}

		if frame_id as usize == chunks.len() - 2 {
			// That was the last frame - attempt to reconstruct message
			let missing_frames: Vec<_> = chunks
				.iter()
				.enumerate()
				.filter(|(_, x)| x.is_none())
				.map(|(i, _)| (i - 1) as u16) // First chunk *must* have been received by DataStart ack so this is safe
				.collect();

			if missing_frames.is_empty() {
				// No missing frames. Reconstruct data and send to handler
				let compressed: Vec<u8> = chunks
					.iter()
					.map(|c| c.as_ref().unwrap())
					.flatten()
					.cloned()
					.collect();

				let data = match zstd::stream::decode_all(&*compressed) {
					Ok(x) => x,
					Err(_) => return None,
				};

				let mut buf = vec![];
				Packet::new(self.hostname, from, PacketData::DataResend(vec![])).to_bytes(&mut buf);
				if let Err(e) = self.tnc.send_frame(&buf) {
					return Some(Err(e.into()));
				}

				let action = state.handler.recv(&data);
				state.enqueue_action(action);

				// Reset so that we're ready for the next frame
				state.receiving = None;
			} else {
				return Some(
					self.request_missing_frames(from, missing_frames)
						.map_err(|e| e.into()),
				);
			}
		}

		Some(Ok(()))
	}

	fn request_missing_frames(
		&mut self,
		to: Hostname,
		missing_frame_ids: Vec<u16>,
	) -> Result<(), std::io::Error> {
		let mut buf = vec![];
		Packet::new(self.hostname, to, PacketData::DataResend(missing_frame_ids))
			.to_bytes(&mut buf);

		self.tnc.send_frame(&buf)?;

		Ok(())
	}
}
