use core::mem::MaybeUninit;
use core::{mem, slice};

use simd_abstraction::generic::hex as sa_hex;
use simd_abstraction::{Bytes16, Bytes32, Load, SIMD256};

use crate::{fallback, uninit_slice_as_mut_ptr, AsciiCase, Error, ERROR};

/// Expands to `check`, `encode` and `decode` functions that are compiled with
/// the given target feature enabled. Each one forwards to the matching
/// `crate::generic` routine and passes it an instruction-set token of the
/// given type.
///
/// Invocation shape: `specialize_for!("feature-name", InstructionSetType);`
macro_rules! specialize_for {
    ($feat:literal, $isa: ty) => {
        use core::mem::MaybeUninit;
        use crate::{AsciiCase, Error};

        /// Feature-specialized entry point for `crate::generic::check`.
        ///
        /// # Safety
        /// The CPU executing this call must support the target feature this
        /// function is compiled with. The instruction-set token is obtained
        /// through `new_unchecked`, not a checked constructor.
        #[inline]
        #[target_feature(enable = $feat)]
        pub unsafe fn check(src: &[u8]) -> bool {
            let token = <$isa as simd_abstraction::InstructionSet>::new_unchecked();
            crate::generic::check(token, src)
        }

        /// Feature-specialized entry point for `crate::generic::encode`.
        ///
        /// # Safety
        /// The CPU executing this call must support the target feature this
        /// function is compiled with. The instruction-set token is obtained
        /// through `new_unchecked`, not a checked constructor.
        #[inline]
        #[target_feature(enable = $feat)]
        pub unsafe fn encode<'s, 'd>(
            src: &'s [u8],
            dst: &'d mut [MaybeUninit<u8>],
            case: AsciiCase,
        ) -> Result<&'d mut [u8], Error> {
            let token = <$isa as simd_abstraction::InstructionSet>::new_unchecked();
            crate::generic::encode(token, src, dst, case)
        }

        /// Feature-specialized entry point for `crate::generic::decode`.
        ///
        /// # Safety
        /// The CPU executing this call must support the target feature this
        /// function is compiled with. The instruction-set token is obtained
        /// through `new_unchecked`, not a checked constructor.
        #[inline]
        #[target_feature(enable = $feat)]
        pub unsafe fn decode<'s, 'd>(
            src: &'s [u8],
            dst: &'d mut [MaybeUninit<u8>],
        ) -> Result<&'d mut [u8], Error> {
            let token = <$isa as simd_abstraction::InstructionSet>::new_unchecked();
            crate::generic::decode(token, src, dst)
        }
    };
}

/// Returns `true` when every part of `src` passes the hex check.
///
/// `src` is split into three parts:
/// - an unaligned head, validated with `fallback::check`;
/// - a run of `Bytes32`-aligned chunks, validated with `sa_hex::check_u8x32`;
/// - an unaligned tail, validated with `fallback::check`.
///
/// Evaluation stops at the first part (or chunk) that fails.
#[inline]
pub fn check<S: SIMD256>(s: S, src: &[u8]) -> bool {
    // SAFETY: `align_to` only reinterprets the aligned middle of `src`.
    // NOTE(review): this assumes `Bytes32` is a plain 32-byte container with no
    // invalid bit patterns — confirm against simd_abstraction.
    let (head, body, tail) = unsafe { src.align_to::<Bytes32>() };
    fallback::check(head)
        && body.iter().all(|block| sa_hex::check_u8x32(s, s.load(block)))
        && fallback::check(tail)
}

/// Hex-encodes `src` into `dst`, with digit case selected by `case`.
///
/// On success, returns the first `2 * src.len()` bytes of `dst`, which are
/// now initialized with the encoded output.
///
/// # Errors
/// Returns `ERROR` when `dst` cannot hold two output bytes per input byte.
#[inline]
pub fn encode<'s, 'd, S>(
    s: S,
    src: &'s [u8],
    dst: &'d mut [MaybeUninit<u8>],
    case: AsciiCase,
) -> Result<&'d mut [u8], Error>
where
    S: SIMD256,
{
    // Halving `dst.len()` instead of doubling `src.len()` cannot overflow.
    let fits = src.len() <= dst.len() / 2;
    if !fits {
        return Err(ERROR);
    }

    let lut = match case {
        AsciiCase::Upper => sa_hex::ENCODE_UPPER_LUT,
        AsciiCase::Lower => sa_hex::ENCODE_LOWER_LUT,
    };

    // Cannot overflow: the check above guarantees `2 * src.len() <= dst.len()`.
    let out_len = src.len() * 2;

    // SAFETY: `dst` has room for at least `out_len` bytes (checked above).
    // `encode_unchecked` writes exactly that many bytes starting at `out`, so
    // the returned slice covers only initialized memory.
    unsafe {
        let out = uninit_slice_as_mut_ptr(dst);
        encode_unchecked(s, src, out, lut);
        Ok(slice::from_raw_parts_mut(out, out_len))
    }
}

/// Writes the hex encoding of `src` to `dst`, two output bytes per input byte.
///
/// The unaligned head and tail of `src` go through `fallback::encode_unchecked`.
/// Each `Bytes16`-aligned chunk in between is expanded to 32 output bytes with
/// `sa_hex::encode_u8x16`, using `table` loaded once as a SIMD lookup table.
///
/// # Safety
/// `dst` must be valid for writes of `2 * src.len()` bytes.
#[inline]
unsafe fn encode_unchecked<S: SIMD256>(s: S, src: &[u8], dst: *mut u8, table: &Bytes32) {
    let (head, body, tail) = src.align_to::<Bytes16>();
    let mut out = dst;

    if !head.is_empty() {
        // The fallback receives the same table reinterpreted as whatever table
        // type its signature expects.
        fallback::encode_unchecked(head, out, mem::transmute(table));
        out = out.add(head.len() * 2);
    }

    let vlut = s.load(table);
    for block in body.iter() {
        // 16 input bytes -> 32 hex characters.
        s.v256_storeu(out, sa_hex::encode_u8x16(s, s.load(block), vlut));
        out = out.add(32);
    }

    if !tail.is_empty() {
        fallback::encode_unchecked(tail, out, mem::transmute(table));
    }
}

/// Decodes the hex text in `src` into `dst`.
///
/// On success, returns the first `src.len() / 2` bytes of `dst`, which are now
/// initialized with the decoded data.
///
/// # Errors
/// Returns `ERROR` when `src` has odd length or `dst` is shorter than
/// `src.len() / 2`. Also propagates any error reported by `decode_unchecked`.
#[inline]
pub fn decode<'s, 'd, S>(
    s: S,
    src: &'s [u8],
    dst: &'d mut [MaybeUninit<u8>],
) -> Result<&'d mut [u8], Error>
where
    S: SIMD256,
{
    let len = src.len();
    let out_len = len / 2;
    let odd = len % 2 != 0;
    if odd || dst.len() < out_len {
        return Err(ERROR);
    }

    // SAFETY: `src` holds exactly `2 * out_len` readable bytes and `dst` has
    // room for at least `out_len` bytes (both checked above). The slice is only
    // built after `decode_unchecked` succeeds in filling those bytes.
    unsafe {
        let out = uninit_slice_as_mut_ptr(dst);
        decode_unchecked(s, out_len, src.as_ptr(), out)?;
        Ok(slice::from_raw_parts_mut(out, out_len))
    }
}

#[inline]
unsafe fn decode_unchecked<S: SIMD256>(
    s: S,
    m: usize,
    mut src: *const u8,
    mut dst: *mut u8,
) -> Result<(), Error> {
    let mut cnt = m;
    while cnt >= 16 {
        let chunk = s.v256_loadu(src);
        let ans = sa_hex::decode_u8x32(s, chunk).map_err(|()| ERROR)?;
        s.v128_storeu(dst, ans);
        src = src.add(32);
        dst = dst.add(16);
        cnt -= 16;
    }
    fallback::decode_raw(cnt, src, dst)?;
    Ok(())
}
