extern crate proc_macro;
use proc_macro::TokenStream;
use proc_macro2::Ident;
use quote::quote;
use syn::{ItemEnum, ItemStruct, ItemType};

fn get_name(item: &proc_macro2::TokenStream) -> Option<Ident> {
    let ident = if let Ok(ItemStruct { ident, .. }) = syn::parse2(item.clone()) {
        ident
    } else if let Ok(ItemEnum { ident, .. }) = syn::parse2(item.clone()) {
        ident
    } else if let Ok(ItemType { ident, .. }) = syn::parse2(item.clone()) {
        ident
    } else {
        return None;
    };

    Some(ident)
}

/// Expands `#[jamsocket_wasm]`: re-emits `item` unchanged and appends a private module
/// (`_jamsocket_wasm_macro_autogenerated`) holding the glue between the annotated type and
/// the host.
///
/// The generated module contains:
/// - a `static mut` slot holding the single service instance,
/// - a `JamsocketContext` implementation that forwards to host-provided `extern "C"` imports,
/// - `#[no_mangle]` exports (`initialize`, `connect`, `disconnect`, `timer`, `message`,
///   `binary`) that dispatch to the `SimpleJamsocketService` methods of the annotated type,
/// - `jam_malloc` / `jam_free`, which allocate and release byte buffers via the global
///   allocator.
///
/// # Panics
///
/// Panics if `item` is not a struct, enum, or type alias.
fn jamsocket_wasm_impl(item: proc_macro2::TokenStream) -> proc_macro2::TokenStream {
    let name =
        get_name(&item).expect("Can only use #[jamsocket_wasm] on a struct, enum, or type alias.");

    quote! {
        #item

        mod _jamsocket_wasm_macro_autogenerated {
            extern crate alloc;

            use super::#name;
            use jamsocket_wasm::prelude::{
                MessageRecipient,
                SimpleJamsocketService,
                JamsocketContext,
                ClientId,
            };

            // Instance-global jamsocket service.
            // NOTE(review): access through `static mut` is only sound if the host never
            // invokes the exports below concurrently or re-entrantly -- confirm against the host.
            static mut SERVER_STATE: Option<#name> = None;

            #[no_mangle]
            pub static JAMSOCKET_API_VERSION: i32 = 1;

            #[no_mangle]
            pub static JAMSOCKET_API_PROTOCOL: i32 = 0;

            struct GlobalJamsocketContext;

            impl JamsocketContext for GlobalJamsocketContext {
                fn set_timer(&self, ms_delay: u32) {
                    unsafe {
                        ffi::set_timer(ms_delay);
                    }
                }

                fn send_message(&self, recipient: impl Into<MessageRecipient>, message: &str) {
                    // `as_ptr` is valid for empty strings; indexing element 0 would panic
                    // when `message` is empty.
                    unsafe {
                        ffi::send_message(
                            recipient.into().encode_u32(),
                            message.as_ptr() as u32,
                            message.len() as u32,
                        );
                    }
                }

                fn send_binary(&self, recipient: impl Into<MessageRecipient>, message: &[u8]) {
                    // `as_ptr` is valid for empty slices; indexing element 0 would panic
                    // when `message` is empty.
                    unsafe {
                        ffi::send_binary(
                            recipient.into().encode_u32(),
                            message.as_ptr() as u32,
                            message.len() as u32,
                        );
                    }
                }
            }

            // Functions implemented by the host.
            mod ffi {
                extern "C" {
                    pub fn send_message(client: u32, message: u32, message_len: u32);

                    pub fn send_binary(client: u32, message: u32, message_len: u32);

                    pub fn set_timer(ms_delay: u32);
                }
            }

            // Functions provided to the host.

            // Builds the service from the UTF-8 room id at `room_id_ptr` and stores it in
            // `SERVER_STATE`, replacing any previous instance. Panics on invalid UTF-8.
            #[no_mangle]
            extern "C" fn initialize(room_id_ptr: *const u8, room_id_len: usize) {
                let room_id = unsafe {
                    String::from_utf8(std::slice::from_raw_parts(room_id_ptr, room_id_len).to_vec())
                        .map_err(|e| format!("Error parsing UTF-8 from host {:?}", e))
                        .unwrap()
                };
                let c = #name::new(&room_id, &GlobalJamsocketContext);

                unsafe {
                    SERVER_STATE.replace(c);
                }
            }

            // The dispatchers below do nothing if `initialize` has not stored a service yet.

            #[no_mangle]
            extern "C" fn connect(client_id: ClientId) {
                if let Some(st) = unsafe { SERVER_STATE.as_mut() } {
                    SimpleJamsocketService::connect(st, client_id.into(), &GlobalJamsocketContext);
                }
            }

            #[no_mangle]
            extern "C" fn disconnect(client_id: ClientId) {
                if let Some(st) = unsafe { SERVER_STATE.as_mut() } {
                    SimpleJamsocketService::disconnect(st, client_id.into(), &GlobalJamsocketContext);
                }
            }

            #[no_mangle]
            extern "C" fn timer() {
                if let Some(st) = unsafe { SERVER_STATE.as_mut() } {
                    SimpleJamsocketService::timer(st, &GlobalJamsocketContext);
                }
            }

            // Decodes `len` bytes at `ptr` as UTF-8 (panicking on invalid data) and forwards
            // the text to the service.
            #[no_mangle]
            extern "C" fn message(client_id: ClientId, ptr: *const u8, len: usize) {
                unsafe {
                    let string = String::from_utf8(std::slice::from_raw_parts(ptr, len).to_vec())
                        .map_err(|e| format!("Error parsing UTF-8 from host {:?}", e))
                        .unwrap();

                    if let Some(st) = SERVER_STATE.as_mut() {
                        SimpleJamsocketService::message(st, client_id.into(), &string, &GlobalJamsocketContext);
                    }
                }
            }

            // Forwards the `len` bytes at `ptr` to the service without copying them.
            #[no_mangle]
            extern "C" fn binary(client_id: ClientId, ptr: *const u8, len: usize) {
                unsafe {
                    let data = std::slice::from_raw_parts(ptr, len);

                    if let Some(st) = SERVER_STATE.as_mut() {
                        SimpleJamsocketService::binary(st, client_id.into(), data, &GlobalJamsocketContext);
                    }
                }
            }

            // Allocates `size` bytes with alignment 1 and returns the pointer.
            //
            // A `Layout` alignment must be a non-zero power of two, and the global allocator
            // must not be asked for zero bytes, so:
            // - `size == 0` yields a dangling (non-null, never dereferenced) pointer,
            // - a size that cannot form a valid layout yields a null pointer.
            #[no_mangle]
            pub unsafe extern "C" fn jam_malloc(size: u32) -> *mut u8 {
                if size == 0 {
                    return core::ptr::NonNull::<u8>::dangling().as_ptr();
                }

                match core::alloc::Layout::from_size_align(size as usize, 1) {
                    Ok(layout) => alloc::alloc::alloc(layout),
                    Err(_) => core::ptr::null_mut(),
                }
            }

            // Releases a buffer previously returned by `jam_malloc` for the same `size`.
            //
            // Zero-size and null frees are ignored: `jam_malloc` never reaches the allocator
            // for those cases, so there is nothing to release.
            #[no_mangle]
            pub unsafe extern "C" fn jam_free(ptr: *mut u8, size: u32) {
                if size == 0 || ptr.is_null() {
                    return;
                }

                if let Ok(layout) = core::alloc::Layout::from_size_align(size as usize, 1) {
                    alloc::alloc::dealloc(ptr, layout);
                }
            }
        }
    }
}

/// Exposes a `jamsocket_wasm::SimpleJamsocketService`-implementing trait as a WebAssembly module.
#[proc_macro_attribute]
pub fn jamsocket_wasm(_attr: TokenStream, item: TokenStream) -> TokenStream {
    jamsocket_wasm_impl(item.into()).into()
}

#[cfg(test)]
mod test {
    use super::get_name;
    use quote::quote;

    /// `get_name` returns the declared identifier for every struct form, enums, and type
    /// aliases, and `None` for an item kind it does not recognise.
    #[test]
    fn test_parse_name() {
        // (expected identifier, item declaring it)
        let named_items = vec![
            ("MyStruct", quote! { struct MyStruct {} }),
            ("AnotherStruct", quote! { struct AnotherStruct; }),
            ("ATupleStruct", quote! { struct ATupleStruct(u32, u32, u32); }),
            (
                "AnEnum",
                quote! {
                    enum AnEnum {
                        Option1,
                        Option2(u32),
                    }
                },
            ),
            ("ATypeDecl", quote! { type ATypeDecl = u32; }),
        ];

        for (expected, tokens) in named_items.iter() {
            let ident = get_name(tokens).unwrap();
            assert_eq!(*expected, ident.to_string());
        }

        // An `impl` block is not a struct, enum, or type alias.
        let unsupported = quote! { impl Foo {} };
        assert!(get_name(&unsupported).is_none());
    }
}
