use crate::*;
use std::{marker::PhantomData, ptr::slice_from_raw_parts, ptr::slice_from_raw_parts_mut};

/// Trait representing a group of fields in a specific struct, specifically intended for allocating
/// memory (SoA/AoSoA).
///
/// `LENGTH` is the number of fields in the group; it sizes every per-field array in this trait.
/// `'s` is the lifetime carried by the reference and slice types handed out by the group.
///
/// The trait has `private::SealedField` as a supertrait; in this file it is implemented for a
/// single `Field` and for tuples of `Field`s.
pub trait FieldGroup<'s, const LENGTH: usize>: private::SealedField + Sized + 'static {
    /// Offset in bytes of each field in an allocation, one entry per field.
    const OFFSETS: [usize; LENGTH];
    /// `size_of` each field's type, one entry per field.
    const SIZES: [usize; LENGTH];
    /// `align_of` each field's type, one entry per field.
    const ALIGNMENTS: [usize; LENGTH];

    /// The field value type: `T` for a single field, a tuple of the field types for a tuple group.
    type FieldTypes;
    /// Shared reference(s) to one value of every field.
    type RefType: RefTuple<'s>;
    /// Mutable reference(s) to one value of every field.
    type MutRefType: MutTuple<'s>;
    /// Shared slice(s), one per field.
    type SliceType: SliceTuple + 's;
    /// Mutable slice(s), one per field.
    type MutSliceType: MutSliceTuple + 's;

    /// Copies the values referenced by `tuple` to the locations in `dest` (one pointer per field).
    ///
    /// # Safety
    /// The implementations in this file write through the raw pointers without any checks, so
    /// every pointer in `dest` must be valid for writing a value of the matching field type.
    unsafe fn write_to_pointers(tuple: Self::RefType, dest: [*mut u8; LENGTH]);
    /// Reinterprets the pointers in `src` (one per field) as shared references to the field types.
    ///
    /// # Safety
    /// The implementations in this file dereference the raw pointers without any checks, so every
    /// pointer in `src` must point to a valid, aligned value of the matching field type that
    /// lives for `'s`.
    unsafe fn read_from_pointers(src: [*const u8; LENGTH]) -> Self::RefType;
    /// Reinterprets the pointers in `src` (one per field) as mutable references to the field types.
    ///
    /// # Safety
    /// Same requirements as `read_from_pointers`; in addition the returned references must be the
    /// only access to those values for `'s`.
    unsafe fn read_from_pointers_mut(src: [*mut u8; LENGTH]) -> Self::MutRefType;

    /// Builds one shared slice of `length` elements per field from the memory at `shard_ptr`.
    ///
    /// # Safety
    /// The memory reachable from `shard_ptr` must hold `length` valid, aligned elements for every
    /// field, and must stay valid and unmodified for `'s`.
    unsafe fn slice_pointers(shard_ptr: *const u8, length: u16) -> Self::SliceType;
    /// Builds one mutable slice of `length` elements per field from the memory at `shard_ptr`.
    ///
    /// # Safety
    /// Same requirements as `slice_pointers`; in addition the returned slices must be the only
    /// access to that memory for `'s`.
    unsafe fn slice_pointers_mut(shard_ptr: *mut u8, length: u16) -> Self::MutSliceType;

    /// Computes one raw pointer per field for the element at index `offset_in_shard`, starting
    /// from `shard_ptr`.
    ///
    /// # Safety
    /// The implementations in this file use raw pointer offsetting, so the resulting addresses
    /// must stay inside the allocation `shard_ptr` points into.
    unsafe fn offset_pointers(shard_ptr: *const u8, offset_in_shard: u16) -> [*const u8; LENGTH];
    /// Mutable-pointer variant of `offset_pointers`.
    ///
    /// # Safety
    /// Same requirements as `offset_pointers`.
    unsafe fn offset_pointers_mut(shard_ptr: *mut u8, offset_in_shard: u16) -> [*mut u8; LENGTH];
}

/// Type-level description of one field: its value type `F` and its byte offset
/// `OFFSET_IN_SHARD`.
///
/// The struct is a zero-sized marker. It stores no `F`; the `PhantomData<*const F>` only ties the
/// type parameter to the struct (and, being a raw-pointer phantom, keeps the marker `!Send` and
/// `!Sync`).
pub struct Field<F, const OFFSET_IN_SHARD: usize>
where
    F: 'static,
{
    _phantom: PhantomData<*const F>,
}

/// A lone [`Field`] is a field group of length one.
///
/// Every per-field array has a single entry and the reference/slice "tuples" collapse to a plain
/// reference or slice of `T`. The field's array is taken to start `OFFSET_IN_SHARD` bytes after
/// the shard pointer passed to `offset_pointers*` and `slice_pointers*`.
impl<'s, T: Sized + 'static, const OFFSET_IN_SHARD: usize> FieldGroup<'s, 1>
    for Field<T, OFFSET_IN_SHARD>
{
    /// Byte offset of the field's array, taken from the `OFFSET_IN_SHARD` parameter.
    const OFFSETS: [usize; 1] = [OFFSET_IN_SHARD];
    /// `size_of::<T>()`.
    const SIZES: [usize; 1] = [std::mem::size_of::<T>()];
    /// `align_of::<T>()`.
    const ALIGNMENTS: [usize; 1] = [std::mem::align_of::<T>()];

    type FieldTypes = T;
    type RefType = &'s T;
    type MutRefType = &'s mut T;
    type SliceType = &'s [T];
    type MutSliceType = &'s mut [T];

    /// Bitwise-copies the value behind `tuple` to `dest[0]`.
    ///
    /// # Safety
    /// `dest[0]` must be valid for writing one `T`, aligned for `T`, and must not overlap the
    /// source value.
    ///
    /// NOTE(review): the copy duplicates `T` without cloning it; presumably the caller makes sure
    /// only one of the two copies is ever dropped — confirm against the call sites.
    #[inline(always)]
    unsafe fn write_to_pointers(tuple: Self::RefType, dest: [*mut u8; 1]) {
        let source = tuple.as_const_pointer_tuple();
        std::ptr::copy_nonoverlapping(source, dest[0] as *mut T, 1);
    }

    /// Reinterprets `src[0]` as a shared reference to `T`.
    ///
    /// # Safety
    /// `src[0]` must point to a valid, aligned `T` that lives (unmodified) for `'s`.
    #[inline(always)]
    unsafe fn read_from_pointers(src: [*const u8; 1]) -> Self::RefType {
        &*(src[0] as *const T)
    }

    /// Reinterprets `src[0]` as a mutable reference to `T`.
    ///
    /// # Safety
    /// `src[0]` must point to a valid, aligned `T` that is not accessed through any other path
    /// for `'s`.
    #[inline(always)]
    unsafe fn read_from_pointers_mut(src: [*mut u8; 1]) -> Self::MutRefType {
        &mut *(src[0] as *mut T)
    }

    /// Returns the first `length` elements of the field's array as a shared slice.
    ///
    /// # Safety
    /// `shard_ptr + OFFSET_IN_SHARD` must be aligned for `T` and be followed by `length` valid
    /// `T` values that live (unmodified) for `'s`.
    #[inline(always)]
    unsafe fn slice_pointers(shard_ptr: *const u8, length: u16) -> Self::SliceType {
        // Start at the field's own array rather than at the shard base, so that element `i` of
        // the slice is the same location `offset_pointers(shard_ptr, i)` points to.
        let first = shard_ptr.add(Self::OFFSETS[0]) as *const T;
        &*slice_from_raw_parts(first, usize::from(length))
    }

    /// Returns the first `length` elements of the field's array as a mutable slice.
    ///
    /// # Safety
    /// Same requirements as `slice_pointers`; in addition nothing else may access those elements
    /// for `'s`.
    #[inline(always)]
    unsafe fn slice_pointers_mut(shard_ptr: *mut u8, length: u16) -> Self::MutSliceType {
        // Same base as `offset_pointers_mut`: the shard pointer advanced by the field offset.
        let first = shard_ptr.add(Self::OFFSETS[0]) as *mut T;
        &mut *slice_from_raw_parts_mut(first, usize::from(length))
    }

    /// Returns the address of element `offset_in_shard` of the field's array.
    ///
    /// The byte offset `OFFSET_IN_SHARD` is applied first; the element index is then applied in
    /// units of `T`.
    ///
    /// # Safety
    /// The resulting address must stay inside the allocation `shard_ptr` points into.
    #[inline(always)]
    unsafe fn offset_pointers(shard_ptr: *const u8, offset_in_shard: u16) -> [*const u8; 1] {
        let first = shard_ptr.add(Self::OFFSETS[0]) as *const T;
        [first.add(usize::from(offset_in_shard)) as *const u8]
    }

    /// Mutable-pointer variant of `offset_pointers`.
    ///
    /// # Safety
    /// The resulting address must stay inside the allocation `shard_ptr` points into.
    #[inline(always)]
    unsafe fn offset_pointers_mut(shard_ptr: *mut u8, offset_in_shard: u16) -> [*mut u8; 1] {
        let first = shard_ptr.add(Self::OFFSETS[0]) as *mut T;
        [first.add(usize::from(offset_in_shard)) as *mut u8]
    }
}

// Implements `FieldGroup` for tuples of `Field`s.
//
// Input is a comma-separated list of `(Type, OffsetConst, Count)` entries, where `Count` is the
// number of entries from that one to the end of the list (so the first entry carries the tuple
// length). One impl is emitted for the whole list, then the macro recurses on the list without
// its first entry. A single remaining entry matches the base case and emits nothing, because a
// lone `Field` has its own hand-written impl.
//
// Every per-field array (`OFFSETS`, `SIZES`, `ALIGNMENTS` and all pointer arrays) is in tuple
// declaration order: the first field is index 0 and a tail field with count `$nt` is index
// `$nf - $nt`. Each field is therefore always paired with its own offset and its own element
// stride.
macro_rules! impl_abstract_fields {
    ($(($ty:ident, $tc:ident, $nf:expr))*) => {}; //base case
    (($first:ident, $fc:ident, $nf:expr), $(($tail:ident, $tc:ident, $nt:expr)), *) => {
        impl<'s, $first, $($tail),* ,const $fc : usize, $(const $tc : usize),*> FieldGroup<'s, $nf> for (Field<$first, $fc>, $(Field<$tail, $tc>), *)
        where $( $tail : Sized + 'static ),*, $first : Sized + 'static
        {
            const OFFSETS : [usize; $nf] = [$fc, $($tc),*];
            const SIZES : [usize; $nf] = [std::mem::size_of::<$first>(), $(std::mem::size_of::<$tail>()), *];
            const ALIGNMENTS: [usize; $nf] = [std::mem::align_of::<$first>(), $(std::mem::align_of::<$tail>()), *];

            type FieldTypes = ($first, $($tail), *);
            type RefType = (&'s $first, $(&'s $tail, )*);
            type MutRefType = (&'s mut $first, $(&'s mut $tail, )*);

            type SliceType = (&'s [$first], $(&'s [$tail], )*);
            type MutSliceType = (&'s mut [$first], $(&'s mut [$tail], )*);

            /// Reinterprets `src[i]` as a shared reference to the i-th field type.
            ///
            /// # Safety
            /// Every `src[i]` must point to a valid, aligned value of the i-th field type that
            /// lives (unmodified) for `'s`.
            #[inline(always)]
            unsafe fn read_from_pointers(src: [*const u8; $nf]) -> Self::RefType {
                (
                    &*(src[0] as *const $first),
                    $(&*(src[$nf - $nt] as *const $tail)), *
                )
            }

            /// Reinterprets `src[i]` as a mutable reference to the i-th field type.
            ///
            /// # Safety
            /// Same requirements as `read_from_pointers`; in addition nothing else may access the
            /// values for `'s`.
            #[inline(always)]
            unsafe fn read_from_pointers_mut(src: [*mut u8; $nf]) -> Self::MutRefType {
                (
                    &mut *(src[0] as *mut $first),
                    $(&mut *(src[$nf - $nt] as *mut $tail)), *
                )
            }

            /// Bitwise-copies the i-th referenced value to `dest[i]`.
            ///
            /// # Safety
            /// Every `dest[i]` must be valid for writing one value of the i-th field type,
            /// aligned for it, and must not overlap the source value.
            ///
            /// NOTE(review): the copies duplicate the values without cloning; presumably the
            /// caller makes sure only one copy of each is ever dropped — confirm.
            #[inline(always)]
            unsafe fn write_to_pointers(tuple: Self::RefType, dest: [*mut u8; $nf]) {
                // The type identifiers double as local variable names for the source pointers.
                #[allow(non_snake_case)]
                let ($first, $($tail), *) = tuple.as_const_pointer_tuple();
                std::ptr::copy_nonoverlapping($first, dest[0] as *mut $first, 1);
                $(std::ptr::copy_nonoverlapping($tail, dest[$nf - $nt] as *mut $tail, 1); )*
            }

            /// Returns, for every field i, the address of element `offset_in_shard` of that
            /// field's array: `shard_ptr + OFFSETS[i]` bytes, advanced by `offset_in_shard`
            /// elements of the i-th field type.
            ///
            /// # Safety
            /// Every resulting address must stay inside the allocation `shard_ptr` points into.
            #[inline(always)]
            unsafe fn offset_pointers(shard_ptr: *const u8, offset_in_shard: u16) -> [*const u8; $nf] {
                let index = usize::from(offset_in_shard);
                [
                    (shard_ptr.add(Self::OFFSETS[0]) as *const $first).add(index) as *const u8,
                    $((shard_ptr.add(Self::OFFSETS[$nf - $nt]) as *const $tail).add(index) as *const u8), *
                ]
            }

            /// Mutable-pointer variant of `offset_pointers`.
            ///
            /// # Safety
            /// Every resulting address must stay inside the allocation `shard_ptr` points into.
            #[inline(always)]
            unsafe fn offset_pointers_mut(shard_ptr: *mut u8, offset_in_shard: u16) -> [*mut u8; $nf] {
                let index = usize::from(offset_in_shard);
                [
                    (shard_ptr.add(Self::OFFSETS[0]) as *mut $first).add(index) as *mut u8,
                    $((shard_ptr.add(Self::OFFSETS[$nf - $nt]) as *mut $tail).add(index) as *mut u8), *
                ]
            }

            /// Returns, for every field i, the first `length` elements of the array that starts
            /// `OFFSETS[i]` bytes after `shard_ptr`, as a shared slice of the i-th field type.
            ///
            /// # Safety
            /// For every field, `shard_ptr + OFFSETS[i]` must be aligned for the field type and
            /// be followed by `length` valid values that live (unmodified) for `'s`.
            #[inline(always)]
            unsafe fn slice_pointers(shard_ptr: *const u8, length: u16) -> Self::SliceType {
                let len = usize::from(length);
                (
                    &*slice_from_raw_parts(shard_ptr.add(Self::OFFSETS[0]) as *const $first, len),
                    $(&*slice_from_raw_parts(shard_ptr.add(Self::OFFSETS[$nf - $nt]) as *const $tail, len), )*
                )
            }

            /// Mutable variant of `slice_pointers`.
            ///
            /// # Safety
            /// Same requirements as `slice_pointers`; in addition nothing else may access the
            /// elements for `'s`.
            #[inline(always)]
            unsafe fn slice_pointers_mut(shard_ptr: *mut u8, length: u16) -> Self::MutSliceType {
                let len = usize::from(length);
                (
                    &mut *slice_from_raw_parts_mut(shard_ptr.add(Self::OFFSETS[0]) as *mut $first, len),
                    $(&mut *slice_from_raw_parts_mut(shard_ptr.add(Self::OFFSETS[$nf - $nt]) as *mut $tail, len), )*
                )
            }
        }

        impl_abstract_fields!($(($tail, $tc, $nt)),*);
    }
}

// Generates the `FieldGroup` impls for tuples of two up to eight `Field`s.
//
// Each entry is `(type parameter, offset const parameter, count)`. The macro emits an impl for the
// whole list and then recurses on the list without its first entry, so the count of the first
// entry of each (sub)list becomes the `LENGTH` of the impl generated for it. The last entry on its
// own matches the macro's base case and produces no impl.
impl_abstract_fields!(
    (T1, C1, 8),
    (T2, C2, 7),
    (T3, C3, 6),
    (T4, C4, 5),
    (T5, C5, 4),
    (T6, C6, 3),
    (T7, C7, 2),
    (T8, C8, 1)
);

mod private {
    use super::*;

    /// Supertrait of `FieldGroup`, implemented here for `Field` and for tuples of two up to
    /// eight `Field`s.
    pub trait SealedField {}

    impl<F: 'static, const C: usize> SealedField for Field<F, C> {}

    // Emits a `SealedField` impl for the tuple made of all listed `(Type, OffsetConst)` entries,
    // then recurses on the list without its first entry.
    macro_rules! impl_sealed_fields {
        // One entry left: a lone `Field` is covered by the impl above, nothing to emit.
        (($ty:ident, $offset:ident)) => {};
        (($head:ident, $head_offset:ident), $(($rest:ident, $rest_offset:ident)),+) => {
            impl<$head, $($rest,)+ const $head_offset: usize, $(const $rest_offset: usize),+> SealedField
                for (Field<$head, $head_offset>, $(Field<$rest, $rest_offset>),+)
            where
                $head: 'static,
                $($rest: 'static,)+
            {
            }

            impl_sealed_fields!($(($rest, $rest_offset)),+);
        };
    }

    impl_sealed_fields!(
        (T1, C1),
        (T2, C2),
        (T3, C3),
        (T4, C4),
        (T5, C5),
        (T6, C6),
        (T7, C7),
        (T8, C8)
    );
}
