use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use std::iter::once;
use syn::{
	parse_macro_input, parse_quote, Data, DataEnum, DataStruct, DeriveInput, Field, Fields,
	FieldsNamed, FieldsUnnamed,
};

#[proc_macro_derive(Parsable)]
pub fn derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
	let DeriveInput {
		ident,
		generics,
		data,
		..
	} = parse_macro_input!(input);

	let expression = generate_expression(&data);

	let output = quote! {
		impl<'a> ::parst::Parsable<'a> for #ident #generics {
			fn read(bytes: &'a [u8]) -> ::parst::PResult<Self> {
				#![allow(non_snake_case)]
				#expression
			}
		}
	};
	proc_macro::TokenStream::from(output)
}

fn gen_field_reads(fields: &Fields, name: TokenStream) -> (TokenStream, TokenStream) {
	match fields {
		Fields::Named(FieldsNamed { named, .. }) => {
			let assignments = named
				.iter()
				.map(|Field { ident, .. }| {
					let name = ident.as_ref().unwrap();
					quote! {
						let (#name, bytes) = ::parst::Parsable::read(bytes)?;
					}
				})
				.collect();

			let names = named
				.iter()
				.map(|Field { ident, .. }| {
					let name = ident.as_ref().unwrap();
					quote! { #name, }
				})
				.collect::<TokenStream>();

			(assignments, quote! { #name { #names } })
		}
		Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
			let assignments = (0..unnamed.len())
				.map(|n| {
					let name = format_ident!("field_{}", n);
					quote! {
						let (#name, bytes) = ::parst::Parsable::read(bytes)?;
					}
				})
				.collect();
			let names = (0..unnamed.len())
				.map(|n| {
					let name = format_ident!("field_{}", n);
					quote! { #name, }
				})
				.collect::<TokenStream>();

			(assignments, quote! { #name ( #names ) })
		}
		Fields::Unit => (quote! {}, quote! { #name }),
	}
}

fn generate_expression(data: &Data) -> TokenStream {
	match data {
		Data::Struct(DataStruct { fields, .. }) => {
			let (assignments, finished) = gen_field_reads(fields, quote! { Self });
			quote! {
				#assignments
				Ok((#finished, bytes))
			}
		}
		Data::Enum(DataEnum { variants, .. }) => variants
			.iter()
			.map(|variant| {
				let name = &variant.ident;
				let (assignments, finished) =
					gen_field_reads(&variant.fields, quote! { Self::#name });
				let fn_name = format_ident!("decode_{}", name);
				quote! {
					let #fn_name = | bytes: &'a [u8] | -> ::parst::PResult<Self> {
						#assignments
						Ok((#finished, bytes))
					};
					if let Ok(x) = #fn_name(bytes) {
						return Ok(x);
					}
				}
			})
			.chain(once(quote! {
				Err(::parst::error::Error::InvalidInput)
			}))
			.collect(),
		Data::Union(_) => unimplemented!(),
	}
}
