use crate::utils;
use proc_macro2::{Span, TokenStream};
use quote::{quote, ToTokens};
use syn::{
    parse::{Error, Parse, ParseStream, Result},
    parse_quote,
    punctuated::Punctuated,
    spanned::Spanned,
    Attribute, FnArg, ImplItem, ItemImpl, Pat, Receiver, Signature, Token, Type,
};

impl ToTokens for Item {
    fn to_tokens(&self, tokens: &mut TokenStream) {
        self.0.to_tokens(tokens);
    }
}

/// An extern enum declared in the impl block as an associated type named
/// `<Enum>_Repr` (e.g. `type Foo_Repr = u32;`).
///
/// Shim arguments whose type's last path segment is `<Enum>` are passed across
/// the FFI boundary as `ffi_ty` and converted back to `<Enum>` with `TryFrom`.
struct ExternType {
    /// Identifier of the `<Enum>_Repr` associated type, including the `_Repr` suffix.
    ident: syn::Ident,
    /// The type used for the FFI function argument: the right-hand side of the
    /// `<Enum>_Repr` associated type.
    ffi_ty: Type,
}

pub fn expand(input: &mut Item) -> Result<()> {
    // Expand types in items.
    expand_types(&mut input.0.items);

    // Collect all extern types.
    let mut extern_enums = Vec::new();
    for inner in &input.0.items {
        if let ImplItem::Type(item_type) = inner {
            if item_type.ident.to_string().ends_with("_Repr") {
                extern_enums.push(ExternType {
                    ident: item_type.ident.clone(),
                    ffi_ty: item_type.ty.clone(),
                });
            }
        }
    }

    let deprecated_infallible_attr: Attribute = parse_quote!(#[deprecated_infallible]);

    // Generate an export function for each module method.
    let mut export_methods = Vec::new();
    for inner in &mut input.0.items {
        if let ImplItem::Method(method) = inner {
            // Parse custom attributes, and remove them if seen.
            let mut deprecated_infallible = false;
            method.attrs.retain(|attr| {
                if utils::has_custom_attribute(attr, &deprecated_infallible_attr) {
                    deprecated_infallible = true;
                    false
                } else {
                    true
                }
            });

            // Generate the export function if it's a shim function.
            let sig = &mut method.sig;
            if let Some(exp) = get_export_function(sig, deprecated_infallible, &extern_enums)? {
                export_methods.push(exp);
            }
        }
    }

    // Append all of the new export functions to the original impl block.
    let import_method = expand_imports(&export_methods);
    input.0.items.push(import_method);

    input.0.items.extend(
        export_methods
            .into_iter()
            .map(|func| syn::ImplItem::Method(func.item)),
    );

    Ok(())
}

pub struct Args;

impl Parse for Args {
    fn parse(input: ParseStream<'_>) -> Result<Self> {
        if input.is_empty() {
            Ok(Args)
        } else {
            Err(Error::new(Span::call_site(), "expected #[host_exports]"))
        }
    }
}

/// A parser that always succeeds without consuming any tokens.
pub struct Nothing;

impl Parse for Nothing {
    fn parse(_: ParseStream<'_>) -> Result<Self> {
        Ok(Self)
    }
}

pub struct Item(ItemImpl);

impl Parse for Item {
    fn parse(input: ParseStream<'_>) -> Result<Self> {
        let attrs = input.call(Attribute::parse_outer)?;
        let lookahead = input.lookahead1();
        if lookahead.peek(Token![impl]) {
            let mut item: ItemImpl = input.parse()?;
            item.attrs = attrs;
            Ok(Self(item))
        } else {
            Err(lookahead.error())
        }
    }
}

/// How a shim's errors are surfaced through its export and FFI functions.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum FallibleMode {
    /// Shim returns `Result<T, E>` with a non-trap `E` (the export function uses
    /// `Self::Err`); the FFI function returns `u32`.
    Fallible,
    /// Shim is marked `#[deprecated_infallible]` and returns `Result<T, WasmTrap>`;
    /// the FFI function still returns `u32`.
    DeprecatedInfallible,
    /// Shim returns `Result<T, WasmTrap>`; the FFI function returns `()`.
    Infallible,
}

/// A generated `<base>_export` method, plus the metadata needed to register it
/// in the generated `imports` method.
struct ExportFunction {
    /// The shim's name without its `_shim` suffix; used to build both the
    /// `_export` method name and the import name registered with the linker.
    base_name: String,

    /// Selects the logging helper and the FFI return type used at registration.
    fallible_mode: FallibleMode,

    /// The generated `<base>_export` method.
    item: syn::ImplItemMethod,
}

/// Creates the associated export function based on the associated shim function.
///
/// Input:
///
/// ```ignore
/// fn ident_shim(&mut self, args...) -> Result<T, Self::Err | Self::WasmTrap>;
/// ```
///
/// Output:
///
/// ```ignore
/// fn ident_export(
///      memory: &mut Self::Memory,
///      host_context: &mut Self::Context,
///      args... // Returned `T` comes here as a `&mut T` or a combination of lower-level arguments.
/// ) -> Result<u32 | (), Self::Err | Self::WasmTrap>;
/// ```
fn get_export_function(
    sig: &Signature,
    deprecated_infallible: bool,
    extern_enums: &[ExternType],
) -> Result<Option<ExportFunction>> {
    let shim_ident = &sig.ident;
    let shim_ident_str = format!("{}", shim_ident);
    if !shim_ident_str.ends_with("_shim") {
        return Ok(None);
    }
    let base_name = shim_ident_str.trim_end_matches("_shim");

    // Check the return type of the shim, it needs to be a Result<Type, _>.
    let (res_type, is_unit, infallible) = match &sig.output {
        syn::ReturnType::Default => {
            return Err(Error::new(
                sig.output.span(),
                "unexpected return type: only Result is allowed!",
            ));
        }

        syn::ReturnType::Type(_, retty) => match retty.as_ref() {
            syn::Type::Path(tp) => {
                if utils::type_path_ends_with(tp, "Result") {
                    let err_type = utils::extract_generic_type(1, None, tp)?.unwrap();

                    // XXX very approximate: wasmtime::Trap or Self::WasmTrap
                    let infallible = utils::type_path_ends_with(&err_type, "Trap")
                        || utils::type_path_ends_with(&err_type, "WasmTrap");

                    match utils::extract_first_generic_type(tp)? {
                        Some(tp) => (Some(tp), false, infallible),
                        None => (Some(tp.clone()), true, infallible),
                    }
                } else {
                    return Err(Error::new(
                        retty.span(),
                        "unexpected return type: only Result is allowed!",
                    ));
                }
            }
            _ => return Err(Error::new(retty.span(), "unexpected return type")),
        },
    };

    let fallible_mode = if deprecated_infallible {
        if !infallible {
            return Err(Error::new(sig.output.span(), "functions marked as #[deprecated_infallible] must return Result<T, wasmtime::Trap>"));
        }
        FallibleMode::DeprecatedInfallible
    } else if infallible {
        FallibleMode::Infallible
    } else {
        FallibleMode::Fallible
    };

    // Whether we have a &self, and whether it is mut
    let mut is_self = None;

    let mut args = Vec::new();
    let mut optional_args = Vec::new();
    for arg in &sig.inputs {
        match arg {
            FnArg::Receiver(Receiver {
                reference: Some(_),
                mutability,
                ..
            }) => {
                is_self = Some(mutability.is_some());
            }
            FnArg::Receiver(arg) => {
                return Err(syn::Error::new(arg.span(), "must take self by reference"));
            }
            FnArg::Typed(arg) => {
                if let Pat::Ident(pat) = &*arg.pat {
                    assert!(
                        pat.ident != "host_context",
                        "module context shouldn't be passed to shim function"
                    );

                    // The wasm memory argument shouldn't be converted.
                    if pat.ident == "memory" {
                        optional_args.push(&pat.ident);
                        continue;
                    }
                    args.push((
                        &pat.ident,
                        convert_arg(&pat.ident, arg.ty.as_ref(), extern_enums)?,
                    ));
                } else {
                    return Err(syn::Error::new(
                        arg.span(),
                        "argument does not have an identifier",
                    ));
                }
            }
        }
    }

    let export_name = syn::Ident::new(&format!("{}_export", base_name), shim_ident.span());

    // Create the function signature
    // 1. memory is _always_ passed as first argument
    // 2. host_context is _always_ passed as second argument
    // 3. We _always_ return a u32 error code
    // 4. Add all of the args translated from the original shim function
    let mut export_sig = {
        let err_type = if matches!(fallible_mode, FallibleMode::Fallible) {
            quote!(Self::Err)
        } else {
            quote!(Self::WasmTrap)
        };
        let export_sig: ImplItem = parse_quote!(
            fn #export_name(memory: &mut Self::Memory, host_context: &mut Self::Context) -> Result<(), #err_type> {}
        );

        if let ImplItem::Method(mut method) = export_sig {
            for arg in args.iter().flat_map(|(_, inp)| inp.args.iter()) {
                method.sig.inputs.push(arg.clone());
            }
            method.sig
        } else {
            unreachable!()
        }
    };

    // Generate the parameters that we'll be passing to the shim function
    let params = {
        let mut params: syn::punctuated::Punctuated<syn::Expr, syn::token::Comma> =
            syn::punctuated::Punctuated::new();

        // If present, push the optional arguments first.
        for arg in optional_args {
            params.push(parse_quote!(#arg));
        }

        for (ident, arg) in &args {
            params.push(match &arg.from_wasm {
                Some(block) => parse_quote!(#block),
                None => parse_quote!(#ident),
            });
        }

        params
    };

    let call = if is_self.is_some() {
        quote!(Self::get(host_context)?.#shim_ident(#params))
    } else {
        quote!(Self::#shim_ident(#params))
    };

    let call = match res_type {
        Some(tp) => {
            if is_unit {
                quote!(#call)
            } else {
                // Add the output pointer to the params
                export_sig
                    .inputs
                    .push(parse_quote!(__ark_ffi_output_ptr: u32));

                let mut wrapped_call = None;

                if let Some(last) = tp.path.segments.last() {
                    if last.ident == "Vec" || last.ident == "String" {
                        let ensure_bytes: Option<syn::Stmt> = if last.ident == "Vec" {
                            let elem_tp = utils::extract_single_generic_type(&tp)?;
                            if elem_tp.map_or(false, |tp| !utils::type_path_ends_with(&tp, "u8")) {
                                return Err(Error::new(
                                    tp.span(),
                                    "only Vec of u8 is allowed in return type position",
                                ));
                            }
                            None
                        } else {
                            Some(parse_quote!(let res: Vec<u8> = res.into();))
                        };

                        wrapped_call = Some(quote!(#call.and_then(|res| {
                            #ensure_bytes
                            let output = crate::wasm_util::get_value_mut(memory, __ark_ffi_output_ptr)?;
                            // Return the length of the produced vector through the out-parameter.
                            *output = res.len() as u32;
                            // Buffer the produced vector on the host, where it can be retrieved
                            // through `take_host_return_vec`
                            host_context.core.set_host_return_vec(res);
                            Ok(())
                        })));
                    }
                }

                match wrapped_call {
                    Some(w) => w,
                    None => {
                        // Copy the result in the out-parameter, for scalar types.
                        quote!(#call.and_then(|res| {
                            let output = crate::wasm_util::get_value_mut(memory, __ark_ffi_output_ptr)?;
                            *output = res;
                            Ok(())
                        }))
                    }
                }
            }
        }
        None => quote!(Ok(#call)),
    };

    Ok(Some(ExportFunction {
        item: parse_quote!(
            #export_sig {
                #call
            }
        ),
        fallible_mode,
        base_name: base_name.to_string(),
    }))
}

/// The FFI lowering of a single shim argument.
struct ExportArg {
    /// Expression block that rebuilds the host-side value from the FFI
    /// argument(s); `None` when the argument is forwarded unchanged.
    from_wasm: Option<syn::Block>,
    /// The FFI argument(s) that replace the shim argument. One shim argument may
    /// expand to several, notably a pointer and a length.
    args: Vec<syn::FnArg>,
}

/// Returns `true` when `ty` is the bare path `str`: no leading `::`, a single
/// segment, and no generic arguments.
fn is_str(ty: &syn::Type) -> bool {
    matches!(ty, Type::Path(type_path) if type_path.path.is_ident("str"))
}

fn convert_arg(
    ident: &syn::Ident,
    ty: &syn::Type,
    extern_enums: &[ExternType],
) -> Result<ExportArg> {
    let export_arg = match ty {
        syn::Type::Path(tp) => {
            let mut param: syn::FnArg = parse_quote!(#ident: #ty);
            let mut from_wasm = None;

            let extern_enum = extern_enums.iter().find(|ee| {
                if let Some(enum_str) = ee.ident.to_string().strip_suffix("_Repr") {
                    tp.path.segments.last().unwrap().ident == enum_str
                } else {
                    false
                }
            });

            if let Some(ee) = extern_enum {
                let ffi_type = &ee.ffi_ty;

                param = parse_quote!(#ident: #ffi_type);
                from_wasm = Some(parse_quote!({
                    std::convert::TryFrom::try_from(#ident).map_err(|_e| ApiError::invalid_arguments(""))? // TODO: include error in chain?
                }));
            }

            ExportArg {
                args: vec![param],
                from_wasm,
            }
        }
        syn::Type::Reference(tr) => {
            let is_mut = tr.mutability.is_some();

            // We only support 3 reference types, scalar types, slices of
            // scalar types, str, and cstr
            if is_str(tr.elem.as_ref()) {
                // We don't allow &mut str
                if is_mut {
                    return Err(syn::Error::new(ty.span(), "&mut str is not allowed!"));
                }

                let ident_ptr = syn::Ident::new(&format!("{}_ptr", ident), ident.span());
                let ident_len = syn::Ident::new(&format!("{}_len", ident), ident.span());

                ExportArg {
                    args: vec![
                        // The pointer
                        parse_quote!(#ident_ptr: u32),
                        // The length
                        parse_quote!(#ident_len: u32),
                    ],
                    from_wasm: Some(parse_quote!({
                        crate::wasm_util::memory_string(memory, #ident_ptr, #ident_len)?
                    })),
                }
            } else if let syn::Type::Slice(inner) = tr.elem.as_ref() {
                if let syn::Type::Path(_tp) = inner.elem.as_ref() {
                    let ident_ptr = syn::Ident::new(&format!("{}_ptr", ident), ident.span());
                    let ident_len = syn::Ident::new(&format!("{}_len", ident), ident.span());

                    ExportArg {
                        args: vec![
                            // The pointer
                            parse_quote!(#ident_ptr: u32),
                            // The length
                            parse_quote!(#ident_len: u32),
                        ],
                        from_wasm: Some(if is_mut {
                            parse_quote!({
                                crate::wasm_util::read_slice_mut(memory, #ident_ptr, #ident_len)?
                            })
                        } else {
                            parse_quote!({
                                crate::wasm_util::read_slice(memory, #ident_ptr, #ident_len)?
                            })
                        }),
                    }
                } else {
                    return Err(Error::new(tr.elem.span(), "not a simple type path"));
                }
            } else if let syn::Type::Path(_tp) = tr.elem.as_ref() {
                let ident_ptr = syn::Ident::new(&format!("{}_ptr", ident), ident.span());

                ExportArg {
                    args: vec![
                        // The pointer, since it's only 1 element we don't need a length
                        parse_quote!(#ident_ptr: u32),
                    ],
                    from_wasm: Some(if is_mut {
                        parse_quote!({
                            crate::wasm_util::get_value_mut(memory, #ident_ptr)?
                        })
                    } else {
                        parse_quote!({
                            crate::wasm_util::get_value(memory, #ident_ptr)?
                        })
                    }),
                }
            } else {
                return Err(Error::new(tr.span(), "this type is not supported"));
            }
        }
        _ => return Err(Error::new(ty.span(), "this type is not supported")),
    };

    Ok(export_arg)
}

/// Generates the `imports` method, which registers every export function with the
/// wasmtime linker under the name `"<prefix>__<base_name>"` in `namespace`.
fn expand_imports(export_functions: &[ExportFunction]) -> syn::ImplItem {
    let mut stmts: Vec<syn::Stmt> = vec![
        parse_quote!(let (namespace, prefix) = Self::namespace();),
        parse_quote!(let wasmtime_linker = unsafe { &mut *it };),
    ];
    // Add an import name bound to each export function.
    stmts.extend(export_functions.iter().map(import_binding_stmt));
    stmts.push(parse_quote!(return Ok(());));

    let mut import_method: syn::ImplItemMethod = parse_quote!(
        fn imports(it: Self::ImportTable) -> Result<(), Self::ImportError> {}
    );
    import_method.block.stmts = stmts;
    syn::ImplItem::Method(import_method)
}

/// Builds the `func_wrap` statement that registers a single export function.
fn import_binding_stmt(export_func: &ExportFunction) -> syn::Stmt {
    let export_sig = &export_func.item.sig;
    let export_ident = &export_sig.ident;
    let name = &export_func.base_name;

    // The closure takes the `Caller` in place of the export function's first two
    // arguments (`memory` and `host_context`); the remaining ones are kept as is.
    let caller_arg: Punctuated<FnArg, Token![,]> =
        parse_quote!(mut caller: wasmtime::Caller<'_, ModuleContext>,);
    let closure_inputs: Punctuated<&FnArg, &Token![,]> = caller_arg
        .pairs()
        .chain(export_sig.inputs.pairs().skip(2))
        .collect();

    // Every identifier-bound export argument, `memory` and `host_context` included.
    let forwarded: Punctuated<&syn::Ident, Token![,]> = export_sig
        .inputs
        .iter()
        .filter_map(|input| match input {
            FnArg::Typed(typed) => match &*typed.pat {
                Pat::Ident(pat) => Some(&pat.ident),
                _ => None,
            },
            FnArg::Receiver(_) => None,
        })
        .collect();

    let (log_call, ffi_ok_type) = match export_func.fallible_mode {
        FallibleMode::Fallible => (quote!(Self::log_call(#name, result)), quote!(u32)),
        FallibleMode::DeprecatedInfallible => (
            quote!(Self::log_deprecated_infallible(#name, result)),
            quote!(u32),
        ),
        FallibleMode::Infallible => (
            quote!(Self::log_infallible_call(#name, result)),
            quote!(()),
        ),
    };

    parse_quote!(
        let _ = wasmtime_linker.func_wrap(
            namespace,
            format!("{}__{}", prefix, #name).as_str(),
            move |#closure_inputs| -> Result<#ffi_ok_type, wasmtime::Trap> {
                let (mut memory, host_context) = crate::wasm_util::get_host_context_from_caller(&mut caller)
                    .map_err(|err| wasmtime::Trap::new(err.display().to_string()))?;
                let memory = &mut memory;
                let result = Self::#export_ident(#forwarded);
                #log_call
            },
        )?;
    )
}

fn expand_types(input: &mut Vec<ImplItem>) {
    // Add all type implementations needed by the macro
    // except for 'Err' as that could be API specific.
    let memory_type: syn::ImplItemType = parse_quote!(
        type Memory = crate::wasm_util::WasmMemoryHandle<'t>;
    );
    let context_type: syn::ImplItemType = parse_quote!(
        type Context = ModuleContext;
    );
    let import_type: syn::ImplItemType = parse_quote!(
        type ImportTable = *mut crate::host_api::WasmLinker;
    );
    let import_error_type: syn::ImplItemType = parse_quote!(
        type ImportError = anyhow::Error;
    );
    let wasm_trap_type: syn::ImplItemType = parse_quote!(
        type WasmTrap = wasmtime::Trap;
    );

    input.push(ImplItem::Type(memory_type));
    input.push(ImplItem::Type(context_type));
    input.push(ImplItem::Type(import_type));
    input.push(ImplItem::Type(import_error_type));
    input.push(ImplItem::Type(wasm_trap_type));
}
