//! This module implements `solc` and Truffle bytecode output parsing and
//! linking. `Bytecode` is represented as a hex string with special placeholders
//! for libraries that require linking.

use crate::errors::{BytecodeError, LinkError};
use serde::de::{Error as DeError, Visitor};
use serde::{Deserialize, Deserializer, Serialize};
use std::collections::HashSet;
use std::fmt::{Formatter, Result as FmtResult};
use std::mem;
use web3::types::{Address, Bytes};

/// The string representation of the byte code. Note that this must be a
/// `String` since `solc` linking requires string manipulation of the
/// bytecode string representation.
///
/// When built through `Bytecode::from_hex_str` the wrapped string carries no
/// `0x` prefix and consists of hex digits interleaved with 40-byte linker
/// placeholders that start with `__`. A default value wraps the empty string.
#[derive(Clone, Debug, Default, Serialize)]
pub struct Bytecode(String);

impl Bytecode {
    /// Parses the textual (hex) form of a contract's bytecode.
    ///
    /// An optional `0x` prefix is accepted and dropped. Linker placeholders
    /// are skipped during validation; everything in between them has to be
    /// made up of hexadecimal digits. The empty string yields an empty
    /// bytecode.
    ///
    /// # Errors
    ///
    /// * `BytecodeError::InvalidLength` if the input has an odd length.
    /// * `BytecodeError::InvalidHexDigit` carrying the first character found
    ///   outside of a placeholder that is not a hex digit.
    /// * Any error produced while splitting the input around placeholders.
    pub fn from_hex_str(s: &str) -> Result<Self, BytecodeError> {
        // An empty input is accepted as the empty bytecode.
        if s.is_empty() {
            return Ok(Bytecode::default());
        }

        // The parity is checked on the raw input; dropping the two-character
        // prefix afterwards cannot change it.
        if s.len() % 2 == 1 {
            return Err(BytecodeError::InvalidLength);
        }

        let code = match s.strip_prefix("0x") {
            Some(unprefixed) => unprefixed,
            None => s,
        };

        // Only the segments between placeholders need to be valid hex.
        for segment in CodeIter(code) {
            let segment = segment?;
            if let Some(invalid) = segment.chars().find(|c| !c.is_ascii_hexdigit()) {
                return Err(BytecodeError::InvalidHexDigit(invalid));
            }
        }

        Ok(Bytecode(code.to_owned()))
    }

    /// Links a library into the current bytecode.
    ///
    /// Every placeholder that refers to `name` is replaced by the 40 hex
    /// digits of `address`. Placeholders of other libraries are left
    /// untouched.
    ///
    /// # Errors
    ///
    /// Returns `LinkError::NotFound` when the bytecode holds no placeholder
    /// for `name`; the bytecode is not modified in that case.
    ///
    /// # Panics
    ///
    /// Panics if an invalid library name is used (for example if it is more
    /// than 38 characters long).
    pub fn link<S>(&mut self, name: S, address: Address) -> Result<(), LinkError>
    where
        S: AsRef<str>,
    {
        let name = name.as_ref();
        if name.len() > 38 {
            panic!("invalid library name for linking");
        }

        // NOTE(nlordell): solc linking works by replacing the 40 character
        //   placeholder '__$name__..__' with the library address; see generated
        //   bytecode for `LinkedContract` contract for an example of what it
        //   looks like
        let placeholder = format!("__{:_<38}", name);
        let address = to_fixed_hex(&address);

        // Walk the bytecode one placeholder slot at a time instead of doing a
        // plain substring replace: an unaligned search could match a pattern
        // that straddles two neighbouring placeholders and corrupt both.
        let mut linked = String::with_capacity(self.0.len());
        let mut rest = self.0.as_str();
        let mut found = false;
        while let Some(pos) = rest.find("__") {
            let (code, tail) = rest.split_at(pos);
            // This won't panic: the bytecode has been validated to hold a full
            // 40 byte placeholder wherever a `__` marker starts.
            let (slot, after) = tail.split_at(40);

            linked.push_str(code);
            if slot == placeholder.as_str() {
                linked.push_str(&address);
                found = true;
            } else {
                linked.push_str(slot);
            }
            rest = after;
        }
        linked.push_str(rest);

        if !found {
            return Err(LinkError::NotFound(name.to_string()));
        }
        self.0 = linked;

        Ok(())
    }

    /// Decodes the hex representation into raw bytes.
    ///
    /// # Errors
    ///
    /// Returns `LinkError::UndefinedLibrary` naming the first library that
    /// still has a placeholder in the bytecode.
    ///
    /// # Panics
    ///
    /// Panics if the stored string fails to decode as hex even though no
    /// placeholder is left.
    pub fn to_bytes(&self) -> Result<Bytes, LinkError> {
        // Placeholders are not hex, so refuse to decode while any remain.
        if let Some(library) = self.undefined_libraries().next() {
            return Err(LinkError::UndefinedLibrary(library.to_string()));
        }

        let raw = hex::decode(&self.0).expect("valid hex");
        Ok(Bytes(raw))
    }

    /// Iterates over the names of the libraries that still have a placeholder
    /// in the bytecode. Each name is produced once, however many placeholders
    /// refer to it.
    pub fn undefined_libraries(&self) -> LibIter<'_> {
        let seen = HashSet::new();
        let cursor = self.0.as_str();
        LibIter { cursor, seen }
    }

    /// Tells whether at least one library placeholder is left in the bytecode.
    pub fn requires_linking(&self) -> bool {
        let mut pending = self.undefined_libraries();
        pending.next().is_some()
    }

    /// Returns true if the bytecode is an empty bytecode.
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }
}

/// Internal type for iterating through a bytecode's string code blocks skipping
/// the `solc` linker placeholders.
///
/// Each item is the run of characters found before the next `__` marker (or
/// the remainder of the string once no marker is left). The 40 bytes starting
/// at a marker are treated as the placeholder and are never yielded.
///
/// An `Err(BytecodeError::PlaceholderTooShort)` item is produced when a full
/// 40-byte placeholder cannot be taken at a marker, either because fewer than
/// 40 bytes remain or because byte 40 falls inside a multi-byte character.
/// After an error the iterator is exhausted.
struct CodeIter<'a>(&'a str);

impl<'a> Iterator for CodeIter<'a> {
    type Item = Result<&'a str, BytecodeError>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.0.is_empty() {
            return None;
        }

        let pos = match self.0.find("__") {
            Some(pos) => pos,
            // No placeholder left: the remainder is one last code block.
            None => return Some(Ok(mem::take(&mut self.0))),
        };

        let (block, tail) = self.0.split_at(pos);
        // `get` instead of indexing: slicing at byte 40 would panic if that
        // offset is past the end or not on a character boundary, and this
        // iterator runs on input that has not been validated yet.
        match tail.get(40..) {
            Some(rest) => {
                self.0 = rest;
                Some(Ok(block))
            }
            None => {
                // Stop here so that the same error is not reported forever.
                self.0 = "";
                Some(Err(BytecodeError::PlaceholderTooShort))
            }
        }
    }
}

/// An iterator over link placeholders in the bytecode.
///
/// Yields the library name of every placeholder, in order of first
/// appearance, skipping names that have already been yielded.
pub struct LibIter<'a> {
    /// The part of the bytecode string that has not been scanned yet.
    cursor: &'a str,
    /// Library names that have already been yielded.
    seen: HashSet<&'a str>,
}

impl<'a> Iterator for LibIter<'a> {
    type Item = &'a str;

    fn next(&mut self) -> Option<Self::Item> {
        while let Some(pos) = self.cursor.find("__") {
            // NOTE(nlordell): this won't panic since we only construct this iterator
            //   on valid Bytecode instances where this has been verified
            let (placeholder, tail) = self.cursor[pos..].split_at(40);
            self.cursor = tail;

            // A placeholder is `__` followed by the name padded with `_`, so
            // drop exactly the two marker characters and then the padding.
            // Trimming underscores on both sides would also eat underscores
            // that the library name itself starts with.
            let lib = placeholder[2..].trim_end_matches('_');
            if self.seen.insert(lib) {
                return Some(lib);
            }
        }
        None
    }
}

impl<'de> Deserialize<'de> for Bytecode {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        deserializer.deserialize_str(BytecodeVisitor)
    }
}

/// Serde visitor that turns a string into a [`Bytecode`] by running it through
/// `Bytecode::from_hex_str`.
struct BytecodeVisitor;

impl<'de> Visitor<'de> for BytecodeVisitor {
    type Value = Bytecode;

    fn expecting(&self, f: &mut Formatter) -> FmtResult {
        f.write_str("valid EVM bytecode string representation")
    }

    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
    where
        E: DeError,
    {
        // Parsing failures are reported as custom deserialization errors.
        let parsed = Bytecode::from_hex_str(v);
        parsed.map_err(E::custom)
    }
}

/// Formats an address as lower-case hex without a `0x` prefix, requesting zero
/// padding to a width of 40 characters, which is the size of a linker
/// placeholder.
fn to_fixed_hex(address: &Address) -> String {
    format!("{:040x}", address)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the linker placeholder for the library called `name`.
    fn placeholder_for(name: &str) -> String {
        format!("__{:_<38}", name)
    }

    /// Parses a prefixed bytecode made of a `0x61` byte, two placeholders,
    /// another `0x61` byte and a final placeholder.
    fn bytecode_with(placeholder: &str) -> Bytecode {
        let code = format!("0x61{0}{0}61{0}", placeholder);
        Bytecode::from_hex_str(&code).unwrap()
    }

    #[test]
    fn default_bytecode_is_empty() {
        let bytecode = Bytecode::default();
        assert!(bytecode.is_empty());
    }

    #[test]
    fn empty_hex_bytecode_is_empty() {
        let bytecode = Bytecode::from_hex_str("0x").unwrap();
        assert!(bytecode.is_empty());
    }

    #[test]
    fn unprefixed_hex_bytecode_is_not_empty() {
        let bytecode = Bytecode::from_hex_str("feedface").unwrap();
        assert!(!bytecode.is_empty());
    }

    #[test]
    fn to_fixed_hex_() {
        // Addresses are parsed from their unprefixed hex form and must format
        // back to exactly the same 40 characters.
        let cases = [
            "0000000000000000000000000000000000000000",
            "0102030405060708091020304050607080900001",
            "9fac3b52be975567103c4695a2835bba40076da1",
        ];
        for expected in cases.iter() {
            let value: Address = expected.parse().unwrap();
            assert_eq!(to_fixed_hex(&value), *expected);
        }
    }

    #[test]
    fn bytecode_link_success() {
        let name = "name";
        let mut bytecode = bytecode_with(&placeholder_for(name));

        bytecode.link(name, Address::zero()).unwrap();
        let bytes = bytecode.to_bytes().unwrap();

        // Every placeholder turns into the 20 zero bytes of the address.
        let mut expected = vec![0x61u8];
        expected.extend_from_slice(&[0u8; 40]);
        expected.push(0x61);
        expected.extend_from_slice(&[0u8; 20]);
        assert_eq!(bytes.0, expected);
    }

    #[test]
    fn bytecode_link_fail() {
        let mut bytecode = bytecode_with(&placeholder_for("name0"));

        // name does not match
        let result = bytecode.link("name1", Address::zero());
        assert!(
            matches!(result, Err(LinkError::NotFound(_))),
            "should fail with not found error"
        );
    }
}
