use heck::*;
use std::io::{Read, Write};
use std::mem;
use std::path::Path;
use std::process::{Command, Stdio};
use witx::*;

/// Generates Rust bindings for the witx documents at `witx_paths`.
///
/// The output consists of a fixed header, followed by every named type,
/// every module and every constant of the loaded document, and is piped
/// through `rustfmt` before being returned.
///
/// # Panics
///
/// Panics if the witx files cannot be loaded, if `rustfmt` cannot be spawned
/// or talked to, or if `rustfmt` exits unsuccessfully.
pub fn generate<P: AsRef<Path>>(witx_paths: &[P]) -> String {
    let doc = witx::load(witx_paths).unwrap();

    let mut unformatted = String::from(
        "\
// This file is automatically generated, DO NOT EDIT
//
// To regenerate this file run the `crates/witx-bindgen` command

use core::mem::MaybeUninit;
use core::fmt;
",
    );

    for named_type in doc.typenames() {
        named_type.render(&mut unformatted);
        unformatted.push('\n');
    }
    for module in doc.modules() {
        module.render(&mut unformatted);
        unformatted.push('\n');
    }
    for constant in doc.constants() {
        rustdoc(&constant.docs, &mut unformatted);
        let declaration = format!(
            "pub const {}_{}: {} = {};\n",
            constant.ty.as_str().to_shouty_snake_case(),
            constant.name.as_str().to_shouty_snake_case(),
            constant.ty.as_str().to_camel_case(),
            constant.value
        );
        unformatted.push_str(&declaration);
    }

    let mut formatter = Command::new("rustfmt")
        .stdin(Stdio::piped())
        .stdout(Stdio::piped())
        .spawn()
        .unwrap();

    let mut formatter_in = formatter.stdin.take().unwrap();
    formatter_in.write_all(unformatted.as_bytes()).unwrap();
    // Close our end of the pipe before reading so that `rustfmt` sees the end
    // of its input.
    drop(formatter_in);

    let formatted = {
        let mut out = String::new();
        formatter
            .stdout
            .take()
            .unwrap()
            .read_to_string(&mut out)
            .unwrap();
        out
    };

    let status = formatter.wait().unwrap();
    assert!(status.success());
    formatted
}

/// Implemented by everything in this file that can be written out as Rust
/// source text.
trait Render {
    /// Appends the Rust source for `self` to the end of `src`.
    fn render(&self, src: &mut String);
}

impl Render for NamedType {
    /// Emits the Rust definition for a named witx type.
    ///
    /// Records, handles and variants get dedicated renderers; every other
    /// kind of type, as well as a name that merely refers to another named
    /// type, becomes a `pub type` alias.
    fn render(&self, src: &mut String) {
        let name = self.name.as_str();

        let ty = match &self.tref {
            TypeRef::Name(_) => return render_alias(src, name, &self.tref),
            TypeRef::Value(ty) => ty,
        };

        match &**ty {
            Type::Record(record) => render_record(src, name, record),
            Type::Handle(handle) => render_handle(src, name, handle),
            Type::Variant(variant) => render_variant(src, name, variant),
            Type::List { .. }
            | Type::Pointer { .. }
            | Type::ConstPointer { .. }
            | Type::Builtin { .. } => render_alias(src, name, &self.tref),
        }
    }
}

/// Emits the Rust definition for the witx record `s` named `name`.
///
/// A record with a bitflags representation becomes an integer type alias
/// plus one `1 << index` constant per member. Any other record becomes a
/// `#[repr(C)]` struct with one public field per member.
fn render_record(src: &mut String, name: &str, s: &RecordDatatype) {
    let camel = name.to_camel_case();

    if let Some(repr) = s.bitflags_repr() {
        let shouty = name.to_shouty_snake_case();
        src.push_str("pub type ");
        src.push_str(&camel);
        src.push_str(" = ");
        repr.render(src);
        src.push(';');
        for (bit, flag) in s.members.iter().enumerate() {
            rustdoc(&flag.docs, src);
            let constant = format!(
                "pub const {}_{}: {} = 1 << {};\n",
                shouty,
                flag.name.as_str().to_shouty_snake_case(),
                camel,
                bit,
            );
            src.push_str(&constant);
        }
        return;
    }

    // Unions can't automatically derive `Debug`, so it is left out whenever
    // one may end up inside the struct.
    let derives = if record_contains_union(s) {
        "#[derive(Copy, Clone)]\n"
    } else {
        "#[derive(Copy, Clone, Debug)]\n"
    };
    src.push_str("#[repr(C)]\n");
    src.push_str(derives);

    src.push_str("pub struct ");
    src.push_str(&camel);
    src.push_str(" {\n");
    for field in s.members.iter() {
        rustdoc(&field.docs, src);
        src.push_str("pub ");
        field.name.render(src);
        src.push_str(": ");
        field.tref.render(src);
        src.push_str(",\n");
    }
    src.push('}');
}

fn render_variant(src: &mut String, name: &str, v: &Variant) {
    if v.cases.iter().all(|c| c.tref.is_none()) {
        return render_enum_like_variant(src, name, v);
    }
    src.push_str("#[repr(C)]\n");
    src.push_str("#[derive(Copy, Clone)]\n");
    src.push_str(&format!("pub union {}U {{\n", name.to_camel_case()));
    for case in v.cases.iter() {
        if let Some(ref tref) = case.tref {
            rustdoc(&case.docs, src);
            src.push_str("pub ");
            case.name.render(src);
            src.push_str(": ");
            tref.render(src);
            src.push_str(",\n");
        }
    }
    src.push_str("}\n");
    src.push_str("#[repr(C)]\n");
    src.push_str("#[derive(Copy, Clone)]\n");
    src.push_str(&format!("pub struct {} {{\n", name.to_camel_case()));
    src.push_str("pub tag: ");
    v.tag_repr.render(src);
    src.push_str(",\n");
    src.push_str(&format!("pub u: {}U,\n", name.to_camel_case()));
    src.push_str("}\n");
}

/// Emits the Rust definition for a variant whose cases carry no payload.
///
/// The generated code consists of:
///
/// * a `#[repr(transparent)]` tuple struct wrapping the tag integer,
/// * one `pub const` per case, whose value is the case's index,
/// * an inherent `impl` with `raw()`, `name()` and `message()`,
/// * a `fmt::Debug` impl printing the code, name and message, and
/// * for types whose witx name contains `errno`, `fmt::Display` and
///   (behind the `std` feature) `std::error::Error` impls.
///
/// NOTE(review): the generated `name()` and `message()` fall through to
/// `core::hint::unreachable_unchecked()` for any tag value that is not the
/// index of a case, so such a value is undefined behaviour in the generated
/// code.
fn render_enum_like_variant(src: &mut String, name: &str, s: &Variant) {
    src.push_str("#[repr(transparent)]\n");
    src.push_str("#[derive(Copy, Clone, Hash, Eq, PartialEq, Ord, PartialOrd)]\n");
    src.push_str(&format!("pub struct {}(", name.to_camel_case()));
    s.tag_repr.render(src);
    src.push_str(");\n");
    for (i, variant) in s.cases.iter().enumerate() {
        rustdoc(&variant.docs, src);
        src.push_str(&format!(
            "pub const {}_{}: {ty} = {ty}({});\n",
            name.to_shouty_snake_case(),
            variant.name.as_str().to_shouty_snake_case(),
            i,
            ty = name.to_camel_case(),
        ));
    }
    let camel_name = name.to_camel_case();

    src.push_str("impl ");
    src.push_str(&camel_name);
    src.push_str("{\n");

    src.push_str("pub const fn raw(&self) -> ");
    s.tag_repr.render(src);
    src.push_str("{ self.0 }\n\n");

    src.push_str("pub fn name(&self) -> &'static str {\n");
    src.push_str("match self.0 {");
    for (i, variant) in s.cases.iter().enumerate() {
        src.push_str(&i.to_string());
        src.push_str(" => \"");
        src.push_str(&variant.name.as_str().to_shouty_snake_case());
        src.push_str("\",");
    }
    src.push_str("_ => unsafe { core::hint::unreachable_unchecked() },");
    src.push_str("}\n");
    src.push_str("}\n");

    src.push_str("pub fn message(&self) -> &'static str {\n");
    src.push_str("match self.0 {");
    for (i, variant) in s.cases.iter().enumerate() {
        src.push_str(&i.to_string());
        src.push_str(" => ");
        // The documentation is free-form text, so it is emitted as an escaped
        // string literal: a `"` or `\` inside it would otherwise terminate or
        // corrupt the literal in the generated source.
        src.push_str(&format!("{:?}", variant.docs.trim()));
        src.push(',');
    }
    src.push_str("_ => unsafe { core::hint::unreachable_unchecked() },");
    src.push_str("}\n");
    src.push_str("}\n");

    src.push_str("}\n");

    src.push_str("impl fmt::Debug for ");
    src.push_str(&camel_name);
    src.push_str("{\nfn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {\n");
    src.push_str("f.debug_struct(\"");
    src.push_str(&camel_name);
    src.push_str("\")");
    src.push_str(".field(\"code\", &self.0)");
    src.push_str(".field(\"name\", &self.name())");
    src.push_str(".field(\"message\", &self.message())");
    src.push_str(".finish()");
    src.push_str("}\n");
    src.push_str("}\n");

    // Auto-synthesize an implementation of the standard `Error` trait for
    // error-looking types based on their name.
    //
    // TODO: should this perhaps be an attribute in the witx file?
    if name.contains("errno") {
        src.push_str("impl fmt::Display for ");
        src.push_str(&camel_name);
        src.push_str("{\nfn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {\n");
        src.push_str("write!(f, \"{} (error {})\", self.name(), self.0)");
        src.push_str("}\n");
        src.push_str("}\n");
        src.push('\n');
        src.push_str("#[cfg(feature = \"std\")]\n");
        src.push_str("extern crate std;\n");
        src.push_str("#[cfg(feature = \"std\")]\n");
        src.push_str("impl std::error::Error for ");
        src.push_str(&camel_name);
        src.push_str("{}\n");
    }
}

impl Render for IntRepr {
    /// Writes the Rust unsigned integer type of the same width.
    fn render(&self, src: &mut String) {
        let rust_ty = match self {
            IntRepr::U8 => "u8",
            IntRepr::U16 => "u16",
            IntRepr::U32 => "u32",
            IntRepr::U64 => "u64",
        };
        src.push_str(rust_ty);
    }
}

/// Emits `pub type <Name> = <dest>;` for the witx type `name`.
///
/// When `dest` resolves to a list the alias is declared with a lifetime
/// parameter `'a`.
fn render_alias(src: &mut String, name: &str, dest: &TypeRef) {
    src.push_str("pub type ");
    src.push_str(&name.to_camel_case());
    if let Type::List(_) = &**dest.type_() {
        src.push_str("<'a>");
    }
    src.push_str(" = ");

    match name {
        // Give `size` special treatment to translate it to `usize` in Rust
        // instead of `u32`. Makes things a bit nicer for client libraries. We
        // can remove this hack once WASI moves to a snapshot that uses
        // BuiltinType::Size.
        "size" => src.push_str("usize"),
        _ => dest.render(src),
    }
    src.push(';');
}

impl Render for TypeRef {
    /// Writes the Rust spelling of a type as it appears in field, parameter
    /// and return positions.
    ///
    /// # Panics
    ///
    /// Panics on an anonymous variant that is neither a bool nor an
    /// "expected" (result-like) variant, and on any other anonymous type that
    /// has no arm below.
    fn render(&self, src: &mut String) {
        match self {
            // A named type is referred to by its camel-cased name; names that
            // resolve to a list additionally get an elided lifetime argument.
            TypeRef::Name(t) => {
                src.push_str(&t.name.as_str().to_camel_case());
                if let Type::List(_) = &**t.type_() {
                    src.push_str("<'_>");
                }
            }
            TypeRef::Value(v) => match &**v {
                Type::Builtin(t) => t.render(src),
                // A list of `char` is a string slice; any other list is a
                // slice borrowed for `'a`. This impl never declares `'a`;
                // `render_alias` emits `<'a>` for list aliases.
                Type::List(t) => match &**t.type_() {
                    Type::Builtin(BuiltinType::Char) => src.push_str("&str"),
                    _ => {
                        src.push_str("&'a [");
                        t.render(src);
                        src.push(']');
                    }
                },
                Type::Pointer(t) => {
                    src.push_str("*mut ");
                    t.render(src);
                }
                Type::ConstPointer(t) => {
                    src.push_str("*const ");
                    t.render(src);
                }
                // The bool check must come before the general variant arm.
                Type::Variant(v) if v.is_bool() => src.push_str("bool"),
                Type::Variant(v) => match v.as_expected() {
                    // A missing ok or error payload is rendered as `()`.
                    Some((ok, err)) => {
                        src.push_str("Result<");
                        match ok {
                            Some(ty) => ty.render(src),
                            None => src.push_str("()"),
                        }
                        src.push(',');
                        match err {
                            Some(ty) => ty.render(src),
                            None => src.push_str("()"),
                        }
                        src.push('>');
                    }
                    None => {
                        panic!("unsupported anonymous variant")
                    }
                },
                // Every member is followed by a comma, so a one-member tuple
                // is still written as a Rust tuple type.
                Type::Record(r) if r.is_tuple() => {
                    src.push('(');
                    for member in r.members.iter() {
                        member.tref.render(src);
                        src.push(',');
                    }
                    src.push(')');
                }
                t => panic!("reference to anonymous {} not possible!", t.kind()),
            },
        }
    }
}

impl Render for BuiltinType {
    /// Writes the Rust primitive type corresponding to a witx builtin.
    fn render(&self, src: &mut String) {
        let rust_ty = match self {
            // A C `char` in Rust we just interpret always as `u8`. It's
            // technically possible to use `std::os::raw::c_char` but that's
            // overkill for the purposes that we'll be using this type for.
            BuiltinType::U8 { lang_c_char: _ } => "u8",
            BuiltinType::U16 => "u16",
            BuiltinType::U32 {
                lang_ptr_size: false,
            } => "u32",
            // A pointer-sized `u32` is exposed as `usize`.
            BuiltinType::U32 {
                lang_ptr_size: true,
            } => "usize",
            BuiltinType::U64 => "u64",
            BuiltinType::S8 => "i8",
            BuiltinType::S16 => "i16",
            BuiltinType::S32 => "i32",
            BuiltinType::S64 => "i64",
            BuiltinType::F32 => "f32",
            BuiltinType::F64 => "f64",
            BuiltinType::Char => "char",
        };
        src.push_str(rust_ty);
    }
}

impl Render for Module {
    /// Emits the bindings for one witx module: first a high-level wrapper for
    /// every function, then a `pub mod` named after the module that holds the
    /// raw `extern "C"` import declarations.
    fn render(&self, src: &mut String) {
        // High-level wrapper functions.
        for func in self.funcs() {
            render_highlevel(&func, &self.name, src);
            src.push_str("\n\n");
        }

        // Raw module; the wasm import module keeps the name exactly as
        // written in the witx file.
        let wasm_name = self.name.as_str();
        src.push_str(&format!("pub mod {}{{\n", wasm_name.to_snake_case()));
        src.push_str(&format!(
            "#[link(wasm_import_module =\"{}\")]\n",
            wasm_name
        ));
        src.push_str("extern \"C\" {\n");
        for func in self.funcs() {
            func.render(src);
            src.push('\n');
        }
        src.push_str("}}");
    }
}

/// Emits the high-level `pub unsafe fn` wrapper for `func`.
///
/// The wrapper is documented with the function's own docs plus "Parameters"
/// and "Return" sections, takes the witx-level parameter types, returns the
/// witx-level result (a tuple when there is more than one), and has its body
/// produced by `func.call_wasm` driving the `Rust` code generator.
fn render_highlevel(func: &InterfaceFunc, module: &Id, src: &mut String) {
    let mut rust_name = String::new();
    func.name.render(&mut rust_name);
    let rust_name = rust_name.to_snake_case();
    rustdoc(&func.docs, src);
    rustdoc_params(&func.params, "Parameters", src);
    rustdoc_params(&func.results, "Return", src);

    // Render the function and its arguments, and note that the arguments here
    // are the exact type name arguments as opposed to the pointer/length pair
    // ones. These functions are unsafe because they work with integer file
    // descriptors, which are effectively forgeable and danglable raw pointers
    // into the file descriptor address space.
    src.push_str("pub unsafe fn ");

    // TODO work out how to handle wasi-ephemeral which introduces multiple
    // WASI modules into the picture. For now, feature-gate it, and if we're
    // compiling ephemeral bindings, prefix wrapper syscall with module name.
    if cfg!(feature = "multi-module") {
        src.push_str(&[module.as_str().to_snake_case().as_str(), &rust_name].join("_"));
    } else {
        src.push_str(to_rust_ident(&rust_name));
    }

    // Parameter list: `name: Type,` for every witx parameter.
    src.push('(');
    for param in func.params.iter() {
        param.name.render(src);
        src.push_str(": ");
        param.tref.render(src);
        src.push(',');
    }
    src.push(')');

    // Return type: nothing, the single result type, or a tuple of all of them.
    match func.results.len() {
        0 => {}
        1 => {
            src.push_str(" -> ");
            func.results[0].tref.render(src);
        }
        _ => {
            src.push_str(" -> (");
            for result in func.results.iter() {
                result.tref.render(src);
                src.push_str(", ");
            }
            src.push(')');
        }
    }
    src.push('{');

    // The body is appended to `src` by the `Bindgen` implementation below.
    func.call_wasm(
        module,
        &mut Rust {
            src,
            params: &func.params,
            block_storage: Vec::new(),
            blocks: Vec::new(),
        },
    );

    src.push('}');
}

/// State of the `Bindgen` implementation that writes the body of one
/// high-level wrapper function.
struct Rust<'a> {
    /// Buffer the generated statements are appended to.
    src: &'a mut String,
    /// Parameters of the function being generated; `GetArg` looks up argument
    /// names here.
    params: &'a [InterfaceFuncParam],
    /// Contents of `src` saved by `push_block` and restored by `finish_block`.
    block_storage: Vec<String>,
    /// Expressions of the blocks finished so far; `ResultLift` pops them.
    blocks: Vec<String>,
}

impl Bindgen for Rust<'_> {
    type Operand = String;

    fn push_block(&mut self) {
        let prev = std::mem::take(self.src);
        self.block_storage.push(prev);
    }

    fn finish_block(&mut self, operand: Option<String>) {
        let to_restore = self.block_storage.pop().unwrap();
        let src = mem::replace(self.src, to_restore);
        match operand {
            None => {
                assert!(src.is_empty());
                self.blocks.push("()".to_string());
            }
            Some(s) => {
                if src.is_empty() {
                    self.blocks.push(s);
                } else {
                    self.blocks.push(format!("{{ {}; {} }}", src, s));
                }
            }
        }
    }

    fn allocate_space(&mut self, n: usize, ty: &witx::NamedType) {
        self.src
            .push_str(&format!("let mut rp{} = MaybeUninit::<", n));
        self.src.push_str(&ty.name.as_str().to_camel_case());
        self.src.push_str(">::uninit();");
    }

    fn emit(
        &mut self,
        inst: &Instruction<'_>,
        operands: &mut Vec<String>,
        results: &mut Vec<String>,
    ) {
        let mut top_as = |cvt: &str| {
            let mut s = operands.pop().unwrap();
            s.push_str(" as ");
            s.push_str(cvt);
            results.push(s);
        };

        match inst {
            Instruction::GetArg { nth } => {
                let mut s = String::new();
                self.params[*nth].name.render(&mut s);
                results.push(s);
            }
            Instruction::AddrOf => {
                results.push(format!("&{} as *const _ as i32", operands[0]));
            }
            Instruction::I64FromBitflags { .. } | Instruction::I64FromU64 => top_as("i64"),
            Instruction::I32FromPointer
            | Instruction::I32FromConstPointer
            | Instruction::I32FromHandle { .. }
            | Instruction::I32FromUsize
            | Instruction::I32FromChar
            | Instruction::I32FromU8
            | Instruction::I32FromS8
            | Instruction::I32FromChar8
            | Instruction::I32FromU16
            | Instruction::I32FromS16
            | Instruction::I32FromU32
            | Instruction::I32FromBitflags { .. } => top_as("i32"),

            Instruction::EnumLower { .. } => {
                results.push(format!("{}.0 as i32", operands[0]));
            }

            Instruction::F32FromIf32
            | Instruction::F64FromIf64
            | Instruction::If32FromF32
            | Instruction::If64FromF64
            | Instruction::I64FromS64
            | Instruction::I32FromS32 => {
                results.push(operands.pop().unwrap());
            }
            Instruction::ListPointerLength => {
                let list = operands.pop().unwrap();
                results.push(format!("{}.as_ptr() as i32", list));
                results.push(format!("{}.len() as i32", list));
            }
            Instruction::S8FromI32 => top_as("i8"),
            Instruction::Char8FromI32 | Instruction::U8FromI32 => top_as("u8"),
            Instruction::S16FromI32 => top_as("i16"),
            Instruction::U16FromI32 => top_as("u16"),
            Instruction::S32FromI32 => {}
            Instruction::U32FromI32 => top_as("u32"),
            Instruction::S64FromI64 => {}
            Instruction::U64FromI64 => top_as("u64"),
            Instruction::UsizeFromI32 => top_as("usize"),
            Instruction::HandleFromI32 { .. } => top_as("u32"),
            Instruction::PointerFromI32 { .. } => top_as("*mut _"),
            Instruction::ConstPointerFromI32 { .. } => top_as("*const _"),
            Instruction::BitflagsFromI32 { .. } => unimplemented!(),
            Instruction::BitflagsFromI64 { .. } => unimplemented!(),

            Instruction::ReturnPointerGet { n } => {
                results.push(format!("rp{}.as_mut_ptr() as i32", n));
            }

            Instruction::Load { ty } => {
                let mut s = format!("core::ptr::read({} as *const ", &operands[0]);
                s.push_str(&ty.name.as_str().to_camel_case());
                s.push(')');
                results.push(s);
            }

            Instruction::ReuseReturn => {
                results.push("ret".to_string());
            }

            Instruction::TupleLift { .. } => {
                let value = format!("({})", operands.join(", "));
                results.push(value);
            }

            Instruction::ResultLift => {
                let err = self.blocks.pop().unwrap();
                let ok = self.blocks.pop().unwrap();
                let mut result = format!("match {} {{", operands[0]);
                result.push_str("0 => Ok(");
                result.push_str(&ok);
                result.push_str("),");
                result.push_str("_ => Err(");
                result.push_str(&err);
                result.push_str("),");
                result.push('}');
                results.push(result);
            }

            Instruction::EnumLift { ty } => {
                let mut result = ty.name.as_str().to_camel_case();
                result.push('(');
                result.push_str(&operands[0]);
                result.push_str(" as ");
                match &**ty.type_() {
                    Type::Variant(v) => v.tag_repr.render(&mut result),
                    _ => unreachable!(),
                }
                result.push(')');
                results.push(result);
            }

            Instruction::CharFromI32 => unimplemented!(),

            Instruction::CallWasm {
                module,
                name,
                params: _,
                results: func_results,
            } => {
                assert!(func_results.len() < 2);
                if !func_results.is_empty() {
                    self.src.push_str("let ret = ");
                    results.push("ret".to_string());
                }
                self.src.push_str(&module.to_snake_case());
                self.src.push_str("::");
                self.src.push_str(to_rust_ident(&name.to_snake_case()));
                self.src.push('(');
                self.src.push_str(&operands.join(", "));
                self.src.push_str(");");
            }

            Instruction::Return { amt: 0 } => {}
            Instruction::Return { amt: 1 } => {
                self.src.push_str(&operands[0]);
            }
            Instruction::Return { .. } => {
                self.src.push('(');
                self.src.push_str(&operands.join(", "));
                self.src.push(')');
            }

            Instruction::Store { .. }
            | Instruction::ListFromPointerLength { .. }
            | Instruction::CallInterface { .. }
            | Instruction::ResultLower { .. }
            | Instruction::TupleLower { .. }
            | Instruction::VariantPayload => unimplemented!(),
        }
    }
}

impl Render for InterfaceFunc {
    /// Emits the raw import declaration of this function, using the types of
    /// its wasm-level signature and positional `arg<N>` parameter names.
    ///
    /// # Panics
    ///
    /// Panics if the wasm-level signature has more than one result.
    fn render(&self, src: &mut String) {
        rustdoc(&self.docs, src);

        // Pin the import name whenever snake-casing would change it.
        let wasm_name = self.name.as_str();
        if wasm_name != wasm_name.to_snake_case() {
            src.push_str(&format!("#[link_name = \"{}\"]\n", wasm_name));
        }

        let mut ident = String::new();
        self.name.render(&mut ident);
        src.push_str("pub fn ");
        src.push_str(to_rust_ident(&ident.to_snake_case()));

        let (params, results) = self.wasm_signature();
        assert!(results.len() <= 1);

        src.push('(');
        for (index, param) in params.iter().enumerate() {
            src.push_str("arg");
            src.push_str(&index.to_string());
            src.push_str(": ");
            param.render(src);
            src.push(',');
        }
        src.push(')');

        match results.get(0) {
            // `noreturn` takes precedence over any declared result.
            _ if self.noreturn => src.push_str(" -> !"),
            Some(result) => {
                src.push_str(" -> ");
                result.render(src);
            }
            None => {}
        }
        src.push(';');
    }
}

/// Maps `name` to a string that is usable as a Rust identifier.
///
/// A name that is a Rust keyword gets a trailing underscore (`in` becomes
/// `in_`); any other name is returned unchanged. Only words that are reserved
/// in every Rust edition are escaped, so edition-dependent keywords such as
/// `async`, `await`, `dyn` and `try` are returned unchanged.
fn to_rust_ident(name: &str) -> &str {
    // Strict and reserved keywords of all editions, paired with their
    // escaped spelling.
    const ESCAPED: &[(&str, &str)] = &[
        ("Self", "Self_"),
        ("abstract", "abstract_"),
        ("as", "as_"),
        ("become", "become_"),
        ("box", "box_"),
        ("break", "break_"),
        ("const", "const_"),
        ("continue", "continue_"),
        ("crate", "crate_"),
        ("do", "do_"),
        ("else", "else_"),
        ("enum", "enum_"),
        ("extern", "extern_"),
        ("false", "false_"),
        ("final", "final_"),
        ("fn", "fn_"),
        ("for", "for_"),
        ("if", "if_"),
        ("impl", "impl_"),
        ("in", "in_"),
        ("let", "let_"),
        ("loop", "loop_"),
        ("macro", "macro_"),
        ("match", "match_"),
        ("mod", "mod_"),
        ("move", "move_"),
        ("mut", "mut_"),
        ("override", "override_"),
        ("priv", "priv_"),
        ("pub", "pub_"),
        ("ref", "ref_"),
        ("return", "return_"),
        ("self", "self_"),
        ("static", "static_"),
        ("struct", "struct_"),
        ("super", "super_"),
        ("trait", "trait_"),
        ("true", "true_"),
        ("type", "type_"),
        ("typeof", "typeof_"),
        ("unsafe", "unsafe_"),
        ("unsized", "unsized_"),
        ("use", "use_"),
        ("virtual", "virtual_"),
        ("where", "where_"),
        ("while", "while_"),
        ("yield", "yield_"),
    ];

    ESCAPED
        .iter()
        .find(|&&(keyword, _)| keyword == name)
        .map_or(name, |&(_, escaped)| escaped)
}

impl Render for Id {
    /// Writes the identifier as is, except that names handled by
    /// `to_rust_ident` are written in their escaped form.
    fn render(&self, src: &mut String) {
        src.push_str(to_rust_ident(self.as_str()))
    }
}

impl Render for WasmType {
    /// Writes the Rust primitive type of the same name as the wasm type.
    fn render(&self, src: &mut String) {
        let rust_ty = match self {
            WasmType::I32 => "i32",
            WasmType::I64 => "i64",
            WasmType::F32 => "f32",
            WasmType::F64 => "f64",
        };
        src.push_str(rust_ty);
    }
}

/// Emits a handle type as an alias for `u32`; the handle's own description
/// is not consulted.
fn render_handle(src: &mut String, name: &str, _h: &HandleDatatype) {
    src.push_str("pub type ");
    src.push_str(&name.to_camel_case());
    src.push_str(" = u32;");
}

/// Appends `docs` to `dst` as `///` doc comments, one comment line per line
/// of `docs`. Nothing is written when `docs` consists of whitespace only.
fn rustdoc(docs: &str, dst: &mut String) {
    let only_whitespace = docs.trim().is_empty();
    if !only_whitespace {
        docs.lines().for_each(|line| {
            dst.push_str("/// ");
            dst.push_str(line);
            dst.push('\n');
        });
    }
}

/// Appends a `## <header>` doc-comment section to `dst` that describes the
/// entries of `docs` which have documentation.
///
/// Nothing is written when no entry is documented. Under any header other
/// than `"Return"`, each entry becomes a bullet introduced by its name, with
/// continuation lines indented by two spaces.
fn rustdoc_params(docs: &[InterfaceFuncParam], header: &str, dst: &mut String) {
    let mut documented = docs
        .iter()
        .filter(|param| !param.docs.trim().is_empty())
        .peekable();
    if documented.peek().is_none() {
        return;
    }

    dst.push_str("///\n/// ## ");
    dst.push_str(header);
    dst.push_str("\n///\n");

    // Currently wasi only has at most one return value, so there's no need to
    // indent it or name it.
    let is_return = header == "Return";

    for param in documented {
        for (index, line) in param.docs.lines().enumerate() {
            dst.push_str("/// ");
            if !is_return {
                if index == 0 {
                    dst.push_str("* `");
                    param.name.render(dst);
                    dst.push_str("` - ");
                } else {
                    dst.push_str("  ");
                }
            }
            dst.push_str(line);
            dst.push('\n');
        }
    }
}

/// Returns `true` when the type of any member of `s` contains a union
/// according to `type_contains_union`.
fn record_contains_union(s: &RecordDatatype) -> bool {
    for member in s.members.iter() {
        if type_contains_union(member.tref.type_()) {
            return true;
        }
    }
    false
}

/// Returns `true` when `ty` is a variant with at least one payload-carrying
/// case, or a list or record that contains such a variant.
fn type_contains_union(ty: &Type) -> bool {
    match ty {
        Type::Record(record) => record_contains_union(record),
        Type::List(element) => type_contains_union(element.type_()),
        Type::Variant(variant) => variant.cases.iter().any(|case| case.tref.is_some()),
        _ => false,
    }
}
