use core::mem::MaybeUninit;
use core::slice;

use crate::{uninit_slice_as_mut_ptr, AsciiCase, Error, ERROR};

use simd_abstraction::generic::hex::unhex;

/// Moves the low nibble of `x` into the high nibble; bits shifted past the
/// top of the byte are discarded.
#[inline(always)]
fn shl4(x: u8) -> u8 {
    // A shift by 4 is always below the 8-bit width, so the plain operator
    // can never trip the shift-overflow check.
    x << 4
}

/// Loads the byte located `offset` bytes past `base`.
///
/// # Safety
///
/// `base.add(offset)` must stay inside a single allocation and point to a
/// byte that is valid for reads.
#[inline(always)]
unsafe fn read(base: *const u8, offset: usize) -> u8 {
    let target = base.add(offset);
    *target
}

/// Stores `value` at the address `offset` bytes past `base`.
///
/// The destination does not have to be aligned for `T`.
///
/// # Safety
///
/// `base.add(offset)` must stay inside a single allocation and be valid for
/// writes of `size_of::<T>()` bytes. Whatever was stored there before is
/// overwritten without being dropped.
#[inline(always)]
unsafe fn write<T>(base: *mut u8, offset: usize, value: T) {
    // `base` is a byte pointer, so the computed address carries no alignment
    // guarantee for `T`. An aligned `write` would be undefined behaviour for
    // any `T` wider than one byte; the unaligned store is identical for `u8`.
    base.add(offset).cast::<T>().write_unaligned(value)
}

/// Reports whether `src` passes the `unhex` validity test.
///
/// Bytes are examined in groups of four, then the 0..=3 leftover bytes as one
/// final group. The `unhex` results of a group are OR-ed together, and the
/// function returns `false` as soon as a combined value equals `0xff`.
/// An empty slice yields `true`.
///
/// NOTE(review): this presumably relies on `unhex` returning `0xff` for a
/// non-hex byte and a value below 16 otherwise — confirm against
/// `simd_abstraction`.
#[inline]
pub fn check(src: &[u8]) -> bool {
    let mut quads = src.chunks_exact(4);

    // Main part: one comparison per four input bytes, bailing out early.
    for quad in &mut quads {
        let merged = unhex(quad[0]) | unhex(quad[1]) | unhex(quad[2]) | unhex(quad[3]);
        if merged == 0xff {
            return false;
        }
    }

    // Tail: fewer than four bytes remain; fold them into a single value.
    let tail = quads
        .remainder()
        .iter()
        .fold(0u8, |acc, &byte| acc | unhex(byte));
    tail != 0xff
}

/// Uppercase digit table: entry `i` is the ASCII hex digit for nibble value `i`.
const UPPER_TABLE: &[u8; 16] = b"0123456789ABCDEF";
/// Lowercase digit table: entry `i` is the ASCII hex digit for nibble value `i`.
const LOWER_TABLE: &[u8; 16] = b"0123456789abcdef";

/// Hex-encodes `src` into the front of `dst` using the digit case selected
/// by `case`.
///
/// On success the returned slice is the first `2 * src.len()` bytes of
/// `dst`.
///
/// # Errors
///
/// Returns `ERROR` when `dst` cannot hold two output bytes per input byte.
#[inline]
pub fn encode<'s, 'd>(
    src: &'s [u8],
    dst: &'d mut [MaybeUninit<u8>],
    case: AsciiCase,
) -> Result<&'d mut [u8], Error> {
    let src_len = src.len();
    // Comparing against half the capacity avoids computing `src_len * 2`
    // before it is known to fit.
    if src_len > dst.len() / 2 {
        return Err(ERROR);
    }

    let alphabet = match case {
        AsciiCase::Upper => UPPER_TABLE,
        AsciiCase::Lower => LOWER_TABLE,
    };

    unsafe {
        let out = uninit_slice_as_mut_ptr(dst);
        encode_unchecked(src, out, alphabet);
        // The capacity check above means `src_len * 2` cannot overflow and
        // the slice stays within `dst`.
        let written = src_len * 2;
        Ok(slice::from_raw_parts_mut(out, written))
    }
}

/// Writes two digits from `table` to `dst` for every byte of `src`: first the
/// digit for the high nibble, then the digit for the low nibble.
///
/// # Safety
///
/// `dst` must be valid for writes of `2 * src.len()` bytes.
#[inline]
pub(crate) unsafe fn encode_unchecked(src: &[u8], dst: *mut u8, table: &[u8; 16]) {
    let lut = table.as_ptr();
    for (pos, &byte) in src.iter().enumerate() {
        // Both indices are at most 15, so the lookups stay inside the
        // 16-entry table.
        let high = read(lut, usize::from(byte >> 4));
        let low = read(lut, usize::from(byte & 0x0f));

        let out = pos * 2;
        write(dst, out, high);
        write(dst, out + 1, low);
    }
}

/// Decodes the hex text in `src` into the front of `dst`.
///
/// On success the returned slice is the first `src.len() / 2` bytes of
/// `dst`.
///
/// # Errors
///
/// Returns `ERROR` when `src` has odd length, when `dst` is shorter than
/// `src.len() / 2`, or when `decode_raw` rejects the input.
#[inline]
pub fn decode<'s, 'd>(
    src: &'s [u8],
    dst: &'d mut [MaybeUninit<u8>],
) -> Result<&'d mut [u8], Error> {
    let src_len = src.len();
    // Each output byte consumes exactly two input bytes.
    if src_len % 2 != 0 {
        return Err(ERROR);
    }

    let out_len = src_len / 2;
    if dst.len() < out_len {
        return Err(ERROR);
    }

    unsafe {
        let out = uninit_slice_as_mut_ptr(dst);
        decode_raw(out_len, src.as_ptr(), out)?;
        Ok(slice::from_raw_parts_mut(out, out_len))
    }
}

/// Decodes `m` pairs of input bytes from `src` into `m` output bytes at `dst`.
///
/// Each pair is mapped through `unhex`; the first result becomes the high
/// nibble and the second the low nibble of the output byte.
///
/// # Errors
///
/// Returns `ERROR` at the first pair whose OR-ed `unhex` results equal
/// `0xff`. Output bytes for earlier pairs have already been written by then.
///
/// NOTE(review): this presumably relies on `unhex` returning `0xff` for a
/// non-hex byte and a value below 16 otherwise — confirm against
/// `simd_abstraction`.
///
/// # Safety
///
/// `src` must be valid for reads of `2 * m` bytes and `dst` must be valid for
/// writes of `m` bytes.
#[inline]
pub(crate) unsafe fn decode_raw(m: usize, src: *const u8, dst: *mut u8) -> Result<(), Error> {
    for idx in 0..m {
        let pair = idx * 2;
        let high = unhex(read(src, pair));
        let low = unhex(read(src, pair + 1));

        // One comparison covers both digits of the pair.
        if (high | low) == 0xff {
            return Err(ERROR);
        }

        write(dst, idx, shl4(high) | low);
    }
    Ok(())
}

// Hands this module's `check`, `decode` and `encode` to `crate::tests::test`,
// which is defined elsewhere in the crate.
#[test]
fn test() {
    crate::tests::test(check, decode, encode);
}
