use crate::{InstructionSet, SIMD128, SIMD256};

#[cfg(target_arch = "x86")]
use core::arch::x86::*;

#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;

/// Token type standing for the SSE4.1 instruction set.
///
/// The single unit field is private, so code outside this module cannot build
/// the token directly; values are produced by the `InstructionSet` impl
/// (`detect` / `new_unchecked`). The type is zero-sized and `Copy`, which is
/// why the SIMD methods take it by value as `self`.
#[derive(Copy, Clone)]
pub struct SSE41(());

unsafe impl InstructionSet for SSE41 {
    #[inline(always)]
    fn detect() -> Option<Self> {
        #[cfg(target_feature = "sse4.1")]
        {
            Some(Self(()))
        }
        #[cfg(not(target_feature = "sse4.1"))]
        {
            #[cfg(feature = "std")]
            if std::is_x86_feature_detected!("sse4.1") {
                return Some(Self(()));
            }
            None
        }
    }

    #[inline(always)]
    unsafe fn new_unchecked() -> Self {
        Self(())
    }
}

// SAFETY (applies to every `unsafe` block in this impl): each method receives an
// `SSE41` token as `self`. The token's field is private, so it can only come from
// `detect` (after a positive feature check) or from `new_unchecked` (where the caller
// vouches for the feature). CPUs with SSE4.1 also provide the SSE2 and SSSE3
// instructions used below.
unsafe impl SIMD128 for SSE41 {
    type V128 = __m128i;

    /// Reads 16 bytes from `addr` using the aligned load `_mm_load_si128`.
    ///
    /// # Safety
    ///
    /// `addr` must be valid for a 16-byte read and aligned to 16 bytes, as the
    /// aligned-load intrinsic requires.
    #[inline(always)]
    unsafe fn v128_load(self, addr: *const u8) -> Self::V128 {
        let src = addr.cast::<__m128i>();
        _mm_load_si128(src)
    }

    /// Reads 16 bytes from `addr` using the unaligned load `_mm_loadu_si128`.
    ///
    /// # Safety
    ///
    /// `addr` must be valid for a 16-byte read; no alignment is needed.
    #[inline(always)]
    unsafe fn v128_loadu(self, addr: *const u8) -> Self::V128 {
        let src = addr.cast::<__m128i>();
        _mm_loadu_si128(src)
    }

    /// Writes the 16 bytes of `a` to `addr` using the unaligned store.
    ///
    /// # Safety
    ///
    /// `addr` must be valid for a 16-byte write; no alignment is needed.
    #[inline(always)]
    unsafe fn v128_storeu(self, addr: *mut u8, a: Self::V128) {
        let dst = addr.cast::<__m128i>();
        _mm_storeu_si128(dst, a)
    }

    /// Bitwise OR of the two vectors.
    #[inline(always)]
    fn v128_or(self, a: Self::V128, b: Self::V128) -> Self::V128 {
        unsafe { _mm_or_si128(a, b) }
    }

    /// Bitwise AND of the two vectors.
    #[inline(always)]
    fn v128_and(self, a: Self::V128, b: Self::V128) -> Self::V128 {
        unsafe { _mm_and_si128(a, b) }
    }

    /// Reinterprets the vector as its 16 bytes in memory order.
    #[inline(always)]
    fn v128_to_bytes(self, a: Self::V128) -> [u8; 16] {
        // Both types are plain 16-byte values, so the transmute is a pure bit copy.
        unsafe { core::mem::transmute::<__m128i, [u8; 16]>(a) }
    }

    /// Produces a vector whose 16 bytes all equal `x`.
    #[inline(always)]
    fn u8x16_splat(self, x: u8) -> Self::V128 {
        // The intrinsic takes `i8`; the cast keeps the bit pattern unchanged.
        let lane = x as i8;
        unsafe { _mm_set1_epi8(lane) }
    }

    /// Byte shuffle via `_mm_shuffle_epi8`: bytes of `a` are selected by the
    /// control bytes in `b`.
    #[inline(always)]
    fn i8x16_shuffle(self, a: Self::V128, b: Self::V128) -> Self::V128 {
        unsafe { _mm_shuffle_epi8(a, b) }
    }

    /// Shifts every 16-bit lane left by `IMM8` bits, filling with zeros.
    #[inline(always)]
    fn i16x8_sll<const IMM8: i32>(self, a: Self::V128) -> Self::V128 {
        unsafe { _mm_slli_epi16::<IMM8>(a) }
    }

    /// Shifts every 16-bit lane right by `IMM8` bits, filling with zeros
    /// (logical shift).
    #[inline(always)]
    fn i16x8_srl<const IMM8: i32>(self, a: Self::V128) -> Self::V128 {
        unsafe { _mm_srli_epi16::<IMM8>(a) }
    }

    /// Returns the 16-bit lane of `a` selected by `IMM3`.
    #[inline(always)]
    fn i16x8_extract<const IMM3: i32>(self, a: Self::V128) -> i16 {
        // The intrinsic hands the lane back widened to `i32`; narrowing keeps
        // exactly the 16 lane bits.
        let widened: i32 = unsafe { _mm_extract_epi16::<IMM3>(a) };
        widened as i16
    }

    /// Returns the 32-bit lane of `a` selected by `IMM2`.
    #[inline(always)]
    fn i32x4_extract<const IMM2: i32>(self, a: Self::V128) -> i32 {
        unsafe { _mm_extract_epi32::<IMM2>(a) }
    }
}

// 256-bit operations emulated with a pair of 128-bit registers.
//
// Layout: `.0` holds bytes 0..16 (the low half) and `.1` holds bytes 16..32 (the
// high half). The loads, the stores, `v256_to_bytes` and the low/high accessors
// below all follow that order.
//
// SAFETY (applies to every `unsafe` block in this impl): each method receives an
// `SSE41` token as `self`. The token's field is private, so it can only come from
// `detect` (after a positive feature check) or from `new_unchecked` (where the caller
// vouches for the feature). CPUs with SSE4.1 also provide the SSE2 and SSSE3
// instructions used below.
unsafe impl SIMD256 for SSE41 {
    type V256 = (__m128i, __m128i);

    /// Reads 32 bytes from `addr` as two aligned 16-byte loads.
    ///
    /// # Safety
    ///
    /// `addr` must be valid for a 32-byte read and aligned to 16 bytes, as the
    /// aligned-load intrinsic requires (`addr + 16` then has the same alignment).
    #[inline(always)]
    unsafe fn v256_load(self, addr: *const u8) -> Self::V256 {
        (
            _mm_load_si128(addr.cast()),
            _mm_load_si128(addr.add(16).cast()),
        )
    }

    /// Reads 32 bytes from `addr` as two unaligned 16-byte loads.
    ///
    /// # Safety
    ///
    /// `addr` must be valid for a 32-byte read; no alignment is needed.
    #[inline(always)]
    unsafe fn v256_loadu(self, addr: *const u8) -> Self::V256 {
        (
            _mm_loadu_si128(addr.cast()),
            _mm_loadu_si128(addr.add(16).cast()),
        )
    }

    /// Writes the 32 bytes of `a` to `addr` as two unaligned 16-byte stores.
    ///
    /// # Safety
    ///
    /// `addr` must be valid for a 32-byte write; no alignment is needed.
    #[inline(always)]
    unsafe fn v256_storeu(self, addr: *mut u8, a: Self::V256) {
        _mm_storeu_si128(addr.cast(), a.0);
        _mm_storeu_si128(addr.add(16).cast(), a.1);
    }

    /// Bitwise OR, applied to each half.
    #[inline(always)]
    fn v256_or(self, a: Self::V256, b: Self::V256) -> Self::V256 {
        unsafe { (_mm_or_si128(a.0, b.0), _mm_or_si128(a.1, b.1)) }
    }

    /// Bitwise AND, applied to each half.
    #[inline(always)]
    fn v256_and(self, a: Self::V256, b: Self::V256) -> Self::V256 {
        unsafe { (_mm_and_si128(a.0, b.0), _mm_and_si128(a.1, b.1)) }
    }

    /// Reinterprets the vector as 32 bytes: `a.0` first, then `a.1`.
    #[inline(always)]
    fn v256_to_bytes(self, a: Self::V256) -> [u8; 32] {
        // Going through an array gives the two halves a defined, contiguous order.
        unsafe { core::mem::transmute([a.0, a.1]) }
    }

    /// Returns the all-zero vector.
    #[inline(always)]
    fn v256_zero(self) -> Self::V256 {
        unsafe { (_mm_setzero_si128(), _mm_setzero_si128()) }
    }

    /// Returns the low half (bytes 0..16).
    #[inline(always)]
    fn v128_from_low_v256(self, a: Self::V256) -> Self::V128 {
        a.0
    }

    /// Returns the high half (bytes 16..32).
    #[inline(always)]
    fn v128_from_high_v256(self, a: Self::V256) -> Self::V128 {
        a.1
    }

    /// Produces a vector whose 32 bytes all equal `x`.
    #[inline(always)]
    fn u8x32_splat(self, x: u8) -> Self::V256 {
        // The intrinsic takes `i8`; the cast keeps the bit pattern unchanged.
        unsafe { (_mm_set1_epi8(x as i8), _mm_set1_epi8(x as i8)) }
    }

    /// Wrapping addition of the 32 byte lanes.
    #[inline(always)]
    fn u8x32_add(self, a: Self::V256, b: Self::V256) -> Self::V256 {
        unsafe { (_mm_add_epi8(a.0, b.0), _mm_add_epi8(a.1, b.1)) }
    }

    /// Wrapping subtraction (`a - b`) of the 32 byte lanes.
    #[inline(always)]
    fn u8x32_sub(self, a: Self::V256, b: Self::V256) -> Self::V256 {
        unsafe { (_mm_sub_epi8(a.0, b.0), _mm_sub_epi8(a.1, b.1)) }
    }

    /// Byte shuffle performed independently inside each 16-byte half: bytes of
    /// `a.0` are selected by `b.0`, bytes of `a.1` by `b.1`. No byte crosses
    /// from one half to the other.
    #[inline(always)]
    fn u8x32_shuffle(self, a: Self::V256, b: Self::V256) -> Self::V256 {
        unsafe { (_mm_shuffle_epi8(a.0, b.0), _mm_shuffle_epi8(a.1, b.1)) }
    }

    /// Returns `true` when at least one of the 32 bytes is zero.
    #[inline(always)]
    fn u8x32_any_zero(self, a: Self::V256) -> bool {
        unsafe {
            let zero = _mm_setzero_si128();
            // Each compare yields 0xFF in lanes equal to zero; movemask then
            // collapses those lanes to one bit apiece.
            let cmp0 = _mm_movemask_epi8(_mm_cmpeq_epi8(a.0, zero));
            let cmp1 = _mm_movemask_epi8(_mm_cmpeq_epi8(a.1, zero));
            (cmp0 | cmp1) != 0
        }
    }

    /// Collects the most significant bit of every byte into a 32-bit mask.
    ///
    /// Bit `i` of the result is the top bit of byte `i`, using the same byte
    /// order as `v256_to_bytes`: bits 0..16 come from `a.0` (the low half) and
    /// bits 16..32 from `a.1` (the high half).
    #[inline(always)]
    fn u8x32_highest_bits(self, a: Self::V256) -> u32 {
        unsafe {
            // Each 128-bit movemask only populates the low 16 bits of its result.
            let lo = _mm_movemask_epi8(a.0) as u32;
            let hi = _mm_movemask_epi8(a.1) as u32;
            (hi << 16) | lo
        }
    }

    /// Produces a vector whose 32 signed byte lanes all equal `x`.
    #[inline(always)]
    fn i8x32_splat(self, x: i8) -> Self::V256 {
        unsafe { (_mm_set1_epi8(x), _mm_set1_epi8(x)) }
    }

    /// Signed per-byte `a < b`: a lane becomes `0xFF` where the comparison holds
    /// and `0x00` where it does not.
    #[inline(always)]
    fn i8x32_cmplt(self, a: Self::V256, b: Self::V256) -> Self::V256 {
        unsafe { (_mm_cmplt_epi8(a.0, b.0), _mm_cmplt_epi8(a.1, b.1)) }
    }

    /// Shifts every 16-bit lane left by `IMM8` bits, filling with zeros.
    #[inline(always)]
    fn u16x16_sll<const IMM8: i32>(self, a: Self::V256) -> Self::V256 {
        unsafe { (_mm_slli_epi16::<IMM8>(a.0), _mm_slli_epi16::<IMM8>(a.1)) }
    }

    /// Shifts every 16-bit lane right by `IMM8` bits, filling with zeros
    /// (logical shift).
    #[inline(always)]
    fn u16x16_srl<const IMM8: i32>(self, a: Self::V256) -> Self::V256 {
        unsafe { (_mm_srli_epi16::<IMM8>(a.0), _mm_srli_epi16::<IMM8>(a.1)) }
    }

    /// Packs the low 64 bits of each half into one 128-bit vector: the low 64
    /// bits of `a.0` end up in the low lane, the low 64 bits of `a.1` in the
    /// high lane.
    #[inline(always)]
    fn u64x2_from_low_u128x2(self, a: Self::V256) -> Self::V128 {
        unsafe { _mm_unpacklo_epi64(a.0, a.1) }
    }

    /// Zero-extends each of the 16 bytes of `a` into a 16-bit lane.
    #[inline(always)]
    fn i16x16_from_u8x16(self, a: Self::V128) -> Self::V256 {
        unsafe {
            let zero = _mm_setzero_si128();
            // Interleaving with zero puts each source byte in the low byte of a
            // 16-bit lane and a zero byte above it. The low eight source bytes
            // form `.0`, the high eight form `.1`.
            (_mm_unpacklo_epi8(a, zero), _mm_unpackhi_epi8(a, zero))
        }
    }
}
