use crate::{
    module::{PrimitiveTypeInfo, TypeInfo},
    Result,
};
use std::collections::HashSet;
use std::convert::TryFrom;
use wasm_encoder::{RawSection, SectionId};
use wasmparser::{Chunk, Parser, Payload, SectionReader};

/// Provides module information for future usage during mutation.
///
/// An instance of `ModuleInfo` can be used to determine which mutations can be
/// applied to the parsed Wasm module.
#[derive(Default, Clone, Debug)]
pub struct ModuleInfo<'a> {
    // The `Option<usize>` fields below are indices into the `raw_sections`
    // field; `None` means the corresponding section was not seen while parsing.
    // The main idea is to maintain the order of the sections in the input Wasm.
    pub exports: Option<usize>,
    // Names of all entries read from the export section.
    pub export_names: HashSet<String>,

    // Indices of various sections within `self.raw_sections`.
    pub types: Option<usize>,
    pub imports: Option<usize>,
    pub tables: Option<usize>,
    pub memories: Option<usize>,
    pub globals: Option<usize>,
    pub elements: Option<usize>,
    pub functions: Option<usize>,
    pub data_count: Option<usize>,
    pub data: Option<usize>,
    pub code: Option<usize>,
    pub start: Option<usize>,

    // Entry counts taken from the export, element and data sections.
    pub exports_count: u32,
    elements_count: u32,
    data_segments_count: u32,
    // Function index stored in the start section, if there is one.
    start_function: Option<u32>,
    // Memory and table totals: imported entries plus the section entry counts.
    memory_count: u32,
    table_count: u32,
    // NOTE(review): `new` only increments this for imported tags.
    tag_count: u32,

    // Number of entries of each kind found in the import section.
    imported_functions_count: u32,
    imported_globals_count: u32,
    imported_memories_count: u32,
    imported_tables_count: u32,
    imported_tags_count: u32,

    // One entry per item read from the type section, in section order.
    pub types_map: Vec<TypeInfo>,

    // function idx to type idx; entries are pushed in parse order, for
    // imported functions as well as for the function section.
    pub function_map: Vec<u32>,
    // Value type of every global, imported ones included, in parse order.
    pub global_types: Vec<PrimitiveTypeInfo>,

    // Every section of the input, in the order it was encountered, as a slice
    // of `input_wasm`.
    pub raw_sections: Vec<RawSection<'a>>,
    // The Wasm bytes this information was computed from.
    pub input_wasm: &'a [u8],
}

impl<'a> ModuleInfo<'a> {
    /// Parse the given Wasm bytes and fill out a `ModuleInfo` AST for it.
    pub fn new(input_wasm: &[u8]) -> Result<ModuleInfo> {
        let mut parser = Parser::new(0);
        let mut info = ModuleInfo::default();
        let mut wasm = input_wasm;
        info.input_wasm = wasm;

        loop {
            let (payload, consumed) = match parser.parse(wasm, true)? {
                Chunk::NeedMoreData(hint) => {
                    panic!("Invalid Wasm module {:?}", hint);
                }
                Chunk::Parsed { consumed, payload } => (payload, consumed),
            };
            match payload {
                Payload::CodeSectionStart {
                    count: _,
                    range,
                    size: _,
                } => {
                    info.code = Some(info.raw_sections.len());
                    info.section(SectionId::Code.into(), range, input_wasm);
                    parser.skip_section();
                    // update slice, bypass the section
                    wasm = &input_wasm[range.end..];

                    continue;
                }
                Payload::TypeSection(mut reader) => {
                    info.types = Some(info.raw_sections.len());
                    info.section(SectionId::Type.into(), reader.range(), input_wasm);

                    // Save function types
                    for _ in 0..reader.get_count() {
                        reader.read().map(|ty| {
                            let typeinfo = TypeInfo::try_from(ty).unwrap();
                            info.types_map.push(typeinfo);
                        })?;
                    }
                }
                Payload::ImportSection(mut reader) => {
                    info.imports = Some(info.raw_sections.len());
                    info.section(SectionId::Import.into(), reader.range(), input_wasm);

                    for _ in 0..reader.get_count() {
                        let ty = reader.read()?;
                        match ty.ty {
                            wasmparser::ImportSectionEntryType::Function(ty) => {
                                // Save imported functions
                                info.function_map.push(ty);
                                info.imported_functions_count += 1;
                            }
                            wasmparser::ImportSectionEntryType::Global(ty) => {
                                let ty = PrimitiveTypeInfo::try_from(ty.content_type).unwrap();
                                info.global_types.push(ty);
                                info.imported_globals_count += 1;
                            }
                            wasmparser::ImportSectionEntryType::Memory(_ty) => {
                                info.memory_count += 1;
                                info.imported_memories_count += 1;
                            }
                            wasmparser::ImportSectionEntryType::Table(_ty) => {
                                info.table_count += 1;
                                info.imported_tables_count += 1;
                            }
                            wasmparser::ImportSectionEntryType::Tag(_ty) => {
                                info.tag_count += 1;
                                info.imported_tags_count += 1;
                            }
                            _ => {
                                // Do nothing
                            }
                        }
                    }
                }
                Payload::FunctionSection(mut reader) => {
                    info.functions = Some(info.raw_sections.len());
                    info.section(SectionId::Function.into(), reader.range(), input_wasm);

                    for _ in 0..reader.get_count() {
                        reader.read().map(|ty| {
                            info.function_map.push(ty);
                        })?;
                    }
                }
                Payload::TableSection(reader) => {
                    info.tables = Some(info.raw_sections.len());
                    info.table_count += reader.get_count();
                    info.section(SectionId::Table.into(), reader.range(), input_wasm);
                }
                Payload::MemorySection(reader) => {
                    info.memories = Some(info.raw_sections.len());
                    info.memory_count += reader.get_count();
                    info.section(SectionId::Memory.into(), reader.range(), input_wasm);
                }
                Payload::GlobalSection(mut reader) => {
                    info.globals = Some(info.raw_sections.len());
                    info.section(SectionId::Global.into(), reader.range(), input_wasm);

                    for _ in 0..reader.get_count() {
                        let ty = reader.read()?;
                        // We only need the type of the global, not necesarily if is mutable or not
                        let ty = PrimitiveTypeInfo::try_from(ty.ty.content_type).unwrap();
                        info.global_types.push(ty);
                    }
                }
                Payload::ExportSection(mut reader) => {
                    info.exports = Some(info.raw_sections.len());
                    info.exports_count = reader.get_count();

                    for _ in 0..reader.get_count() {
                        let entry = reader.read()?;
                        info.export_names.insert(entry.field.into());
                    }

                    info.section(SectionId::Export.into(), reader.range(), input_wasm);
                }
                Payload::StartSection { func, range } => {
                    info.start = Some(info.raw_sections.len());
                    info.start_function = Some(func);
                    info.section(SectionId::Start.into(), range, input_wasm);
                }
                Payload::ElementSection(reader) => {
                    info.elements = Some(info.raw_sections.len());
                    info.elements_count = reader.get_count();
                    info.section(SectionId::Element.into(), reader.range(), input_wasm);
                }
                Payload::DataSection(reader) => {
                    info.data = Some(info.raw_sections.len());
                    info.data_segments_count = reader.get_count();
                    info.section(SectionId::Data.into(), reader.range(), input_wasm);
                }
                Payload::CustomSection {
                    name: _,
                    data_offset: _,
                    data: _,
                    range,
                } => {
                    info.section(SectionId::Custom.into(), range, input_wasm);
                }
                Payload::AliasSection(reader) => {
                    info.section(SectionId::Alias.into(), reader.range(), input_wasm);
                }
                Payload::UnknownSection {
                    id,
                    contents: _,
                    range,
                } => {
                    info.section(id, range, input_wasm);
                }
                Payload::DataCountSection { count: _, range } => {
                    info.data_count = Some(info.raw_sections.len());
                    info.section(SectionId::DataCount.into(), range, input_wasm);
                }
                Payload::Version { .. } => {}
                Payload::End => {
                    break;
                }
                _ => todo!("{:?} not implemented", payload),
            }
            wasm = &wasm[consumed..];
        }

        Ok(info)
    }

    /// Returns `true` if a code section was recorded for this module.
    pub fn has_code(&self) -> bool {
        self.code.is_some()
    }

    /// Appends a new raw section to `raw_sections`.
    ///
    /// * `id` - the section id stored in the raw section.
    /// * `range` - byte range of the section contents inside `full_wasm`.
    /// * `full_wasm` - the bytes that `range` indexes into.
    ///
    /// # Panics
    ///
    /// Panics if `range` is not a valid range within `full_wasm`.
    pub fn section(&mut self, id: u8, range: wasmparser::Range, full_wasm: &'a [u8]) {
        let data = &full_wasm[range.start..range.end];
        self.raw_sections.push(RawSection { id, data });
    }

    /// Returns the code section bytes as a `RawSection` instance.
    ///
    /// # Panics
    ///
    /// Panics if no code section was recorded (`self.code` is `None`).
    pub fn get_code_section(&self) -> RawSection<'a> {
        let code_idx = self.code.unwrap();
        self.raw_sections[code_idx]
    }

    /// Returns the export section bytes as a `RawSection` instance.
    ///
    /// # Panics
    ///
    /// Panics if no export section was recorded (`self.exports` is `None`).
    pub fn get_exports_section(&self) -> RawSection<'a> {
        let exports_idx = self.exports.unwrap();
        self.raw_sections[exports_idx]
    }

    /// Returns `true` if an export section was recorded for this module.
    pub fn has_exports(&self) -> bool {
        self.exports.is_some()
    }

    /// Returns the type of the function with index `idx`, that is
    /// `types_map[function_map[idx]]`.
    ///
    /// # Panics
    ///
    /// Panics if `idx` is out of bounds for `function_map`, or if the type
    /// index stored there is out of bounds for `types_map`.
    pub fn get_functype_idx(&self, idx: u32) -> &TypeInfo {
        let type_idx = self.function_map[idx as usize];
        &self.types_map[type_idx as usize]
    }

    /// Returns the number of globals used by the Wasm binary, including
    /// imported globals (the length of `global_types`).
    pub fn get_global_count(&self) -> usize {
        self.global_types.len()
    }

    /// Returns the global section bytes as a `RawSection` instance
    pub fn get_global_section(&self) -> RawSection {
        self.raw_sections[self.globals.unwrap()]
    }

    /// Builds a new module from the recorded raw sections, writing
    /// `new_section` in place of the raw section at position `i`.
    ///
    /// All other sections are copied unchanged and in their original order.
    /// If `i` does not match any position in `raw_sections`, nothing is
    /// replaced.
    pub fn replace_section(
        &self,
        i: usize,
        new_section: &impl wasm_encoder::Section,
    ) -> wasm_encoder::Module {
        let mut module = wasm_encoder::Module::new();
        for (position, raw) in self.raw_sections.iter().enumerate() {
            if position == i {
                module.section(new_section);
            } else {
                module.section(raw);
            }
        }
        module
    }

    /// Builds a new module, giving a callback the chance to replace each raw
    /// section.
    ///
    /// This method is helpful when more than one section has to be rewritten.
    /// For example, some code mutations might need to add a few globals; the
    /// callback can then write a new or custom global section before the code
    /// section.
    ///
    /// * `section_writer` - called once per raw section, in order, with the
    ///   section position, its id and the module under construction. It should
    ///   write its custom section and return `true`; if it returns `false` the
    ///   original raw section is written to the module instead.
    pub fn replace_multiple_sections<P>(&self, section_writer: P) -> wasm_encoder::Module
    where
        P: Fn(usize, u8, &mut wasm_encoder::Module) -> bool,
    {
        let mut module = wasm_encoder::Module::new();
        for (position, raw) in self.raw_sections.iter().enumerate() {
            let handled = section_writer(position, raw.id, &mut module);
            // Fall back to the original bytes when the callback wrote nothing
            if !handled {
                module.section(raw);
            }
        }
        module
    }

    /// Returns the total number of functions, imported ones included (the
    /// length of `function_map`, cast to `u32`).
    pub fn num_functions(&self) -> u32 {
        self.function_map.len() as u32
    }

    /// Returns the number of functions that are not imported, i.e. the total
    /// number of functions minus the number of imported functions.
    pub fn num_local_functions(&self) -> u32 {
        let total = self.num_functions();
        let imported = self.num_imported_functions();
        total - imported
    }

    /// Returns the number of functions found in the import section.
    pub fn num_imported_functions(&self) -> u32 {
        self.imported_functions_count
    }

    /// Returns the total number of tables: imported tables plus the entries of
    /// the table section.
    pub fn num_tables(&self) -> u32 {
        self.table_count
    }

    /// Returns the number of tables found in the import section.
    pub fn num_imported_tables(&self) -> u32 {
        self.imported_tables_count
    }

    /// Returns the total number of memories: imported memories plus the
    /// entries of the memory section.
    pub fn num_memories(&self) -> u32 {
        self.memory_count
    }

    /// Returns the number of memories found in the import section.
    pub fn num_imported_memories(&self) -> u32 {
        self.imported_memories_count
    }

    /// Returns the total number of globals, imported ones included (the
    /// length of `global_types`, cast to `u32`).
    pub fn num_globals(&self) -> u32 {
        self.global_types.len() as u32
    }

    /// Returns the number of globals found in the import section.
    pub fn num_imported_globals(&self) -> u32 {
        self.imported_globals_count
    }

    /// Returns the recorded number of tags.
    ///
    /// NOTE(review): `new` only increments this counter for imported tags, so
    /// tags defined by the module itself do not appear to be counted - confirm.
    pub fn num_tags(&self) -> u32 {
        self.tag_count
    }

    /// Returns the number of tags found in the import section.
    pub fn num_imported_tags(&self) -> u32 {
        self.imported_tags_count
    }

    /// Returns the number of entries in the data section (zero if there is no
    /// data section).
    pub fn num_data(&self) -> u32 {
        self.data_segments_count
    }

    /// Returns the number of entries in the element section (zero if there is
    /// no element section).
    pub fn num_elements(&self) -> u32 {
        self.elements_count
    }

    /// Returns the number of types read from the type section (the length of
    /// `types_map`, cast to `u32`).
    pub fn num_types(&self) -> u32 {
        self.types_map.len() as u32
    }
}
