use quote::{format_ident, quote};

use crate::model::{BytewiseModel, WordwiseModel};
use crate::spec::Spec;
use crate::util::ones;

/// Turns an unsigned value into an unsuffixed integer literal token (e.g. `42`),
/// leaving the concrete integer type to be inferred by the generated code.
fn wrap_uint(int: u128) -> proc_macro2::TokenTree {
    let literal = proc_macro2::Literal::u128_unsuffixed(int);
    proc_macro2::TokenTree::from(literal)
}

/// Turns a signed value into an unsuffixed integer literal token (e.g. `-3`),
/// leaving the concrete integer type to be inferred by the generated code.
fn wrap_int(int: i128) -> proc_macro2::TokenTree {
    let literal = proc_macro2::Literal::i128_unsuffixed(int);
    proc_macro2::TokenTree::from(literal)
}

/// Returns a copy of `s` whose first character is upper-cased if it is an
/// ASCII letter; everything after it is copied unchanged, and `""` stays `""`.
fn capitalize(s: &str) -> String {
    match s.chars().next() {
        None => String::new(),
        Some(first) => {
            let mut out = String::with_capacity(s.len());
            out.push(first.to_ascii_uppercase());
            out.push_str(&s[first.len_utf8()..]);
            out
        }
    }
}

/// Derives an identifier-friendly suffix from the spec's name.
///
/// The name is lower-cased. If it starts with `crc`, that prefix is dropped,
/// followed by one optional `-`, any run of digits (the width), and one
/// optional `/`. Remaining `-` characters become `_`.
/// For example `"CRC-16/IBM-3740"` becomes `"ibm_3740"`.
fn get_suffix(spec: &Spec) -> String {
    let lowered = spec.name.to_ascii_lowercase();

    let tail: &str = if let Some(after_crc) = lowered.strip_prefix("crc") {
        let after_dash = after_crc.strip_prefix('-').unwrap_or(after_crc);
        let after_width = after_dash.trim_start_matches(char::is_numeric);
        after_width.strip_prefix('/').unwrap_or(after_width)
    } else {
        &lowered
    };

    tail.replace('-', "_")
}

/// Returns the separator to put between the CRC width and `suffix` in generated
/// identifiers.
///
/// `"_"` is returned when `suffix` starts with a digit, so that e.g. width 16
/// and suffix `"3gpp"` do not run together. `""` is returned otherwise,
/// including for an empty suffix. An empty suffix occurs when the spec name has
/// nothing after the width, e.g. `"CRC-32"`; previously that case panicked.
fn get_underscore(suffix: &str) -> &'static str {
    match suffix.chars().next() {
        Some(first) if first.is_numeric() => "_",
        _ => "",
    }
}

/// Picks the smallest unsigned integer type that can hold a `spec.width`-bit CRC.
///
/// Returns the type's tokens together with its bit width.
///
/// # Panics
/// Panics if `spec.width` exceeds 128.
fn get_ty(spec: &Spec) -> (proc_macro2::TokenStream, u16) {
    let bits: u16 = match spec.width {
        0..=8 => 8,
        9..=16 => 16,
        17..=32 => 32,
        33..=64 => 64,
        65..=128 => 128,
        _ => panic!("Unable to support more bits! (asked for {})", spec.width),
    };

    let ty = match bits {
        8 => quote!(u8),
        16 => quote!(u16),
        32 => quote!(u32),
        64 => quote!(u64),
        _ => quote!(u128),
    };

    (ty, bits)
}
/// Picks the word type used by the generated word-at-a-time loop.
///
/// CRCs wider than 64 bits use `u128`. Otherwise the choice follows the
/// pointer width of the machine running this generator (`u64` on 64-bit
/// hosts, `u32` elsewhere). Returns the type's tokens and its bit width.
fn get_word_ty(spec: &Spec) -> (proc_macro2::TokenStream, u16) {
    let bits: u16 = if spec.width > 64 {
        128
    } else if usize::BITS == 64 {
        64
    } else {
        32
    };

    let ty = match bits {
        128 => quote!(u128),
        64 => quote!(u64),
        _ => quote!(u32),
    };

    (ty, bits)
}

/// Byte-order-specific code generation for the word-at-a-time (`wordwise`)
/// CRC implementation.
trait Endianness {
    /// Returns the statements that consume `data` one machine word at a time
    /// inside the generated `add_bytes`, updating its local `crc`.
    fn generate_wordwise_body(spec: &Spec) -> proc_macro2::TokenStream;
    /// Returns the item(s) defining any helper (such as `swapper`) that the code
    /// produced by `generate_wordwise_body` calls; may be empty.
    fn generate_swapper(
        spec: &Spec,
        crc_type: &proc_macro2::TokenStream,
        word_type: &proc_macro2::TokenStream,
    ) -> proc_macro2::TokenStream;
}

/// Little-endian word loads: each word of input is read with `from_le_bytes`,
/// so the first byte in memory lands in the least significant byte of the word.
impl Endianness for byteorder::LittleEndian {
    /// Builds the word-at-a-time loop of the generated `add_bytes`.
    ///
    /// The generated code uses the locals `crc` and `data` (a `&[u8]`) and
    /// indexes `super::TABLE_WORD` with rows `0..word_bytes`, where `word_bytes`
    /// comes from `get_word_ty`. That table must therefore have at least that
    /// many rows. Bytes left over after the last full word stay in `data` for
    /// the caller's tail loop.
    fn generate_wordwise_body(spec: &Spec) -> proc_macro2::TokenStream {
        // Non-reflected CRCs wider than 8 bits are shifted left by this many bits
        // (0..=7), i.e. (-width) mod 8. That makes the CRC's most significant bit
        // the top bit of a byte.
        let top = if spec.width > 8 {
            0u16.wrapping_sub(spec.width) & 7
        } else {
            0
        };
        let top_tok = wrap_uint(top as _);

        let (word_type, word_bits) = get_word_ty(spec);
        let word_bytes = word_bits >> 3;
        // Note: emitted via the signed helper, but the literal is unsuffixed so it
        // is inferred as `usize` where it is compared with `data.len()`.
        let word_bytes_tok = wrap_int(word_bytes as _);
        let word_bytes_less_one_tok = wrap_uint((word_bytes - 1) as _);

        let mut body = quote!();

        // A non-reflected CRC keeps its most significant byte at the top, but a
        // little-endian load puts the first data byte at the bottom of the word.
        // So the CRC is byte-aligned, then byte-swapped, so the two line up.
        if !spec.reflect() {
            if top != 0 {
                body = quote!(#body crc <<= #top_tok;);
            }
            if spec.width > 8 {
                body = quote!(#body crc = swapper(crc););
            }
        }

        // Lookups for the middle bytes of the word. Byte k (k = 1..word_bytes-2,
        // counted from the low end) indexes row `word_bytes - k - 1`.
        // NOTE(review): presumably row r holds a byte's contribution when r more
        // bytes of the word follow it; confirm against `WordwiseModel`.
        let segments = (1..(word_bytes - 1))
            .map(|k| {
                let byte_index = wrap_uint((word_bytes - k - 1) as _);
                let shift = wrap_uint((k << 3) as _);
                quote! {
                    super::TABLE_WORD[#byte_index][((word >> #shift) as u8) as usize]
                }
            })
            .collect::<Vec<_>>();

        // Shift that brings the highest byte of the word (last in memory) down.
        let last_shift_tok = wrap_uint(((word_bytes - 1) << 3) as _);

        body = quote! {
            #body

            while data.len() >= #word_bytes_tok {
                let (int_bytes, rest) = data.split_at(std::mem::size_of::<#word_type>());

                let word = (crc as #word_type) ^ #word_type::from_le_bytes(int_bytes.try_into().unwrap());

                crc = super::TABLE_WORD[#word_bytes_less_one_tok][(word as u8) as usize]
                    ^ #(#segments ^)*
                    super::TABLE_WORD[0][(word >> #last_shift_tok) as usize];

                data = rest;
            }
        };

        // Undo, in reverse order, the swap and alignment shift applied before the loop.
        if !spec.reflect() {
            if spec.width > 8 {
                body = quote!(#body crc = swapper(crc););
            }
            if top != 0 {
                body = quote!(#body crc >>= #top_tok;);
            }
        }

        body
    }

    /// Returns the `swapper` helper called by the non-reflected word loop.
    /// Returns nothing when the CRC is reflected or at most 8 bits wide, because
    /// the loop does not call `swapper` in those cases.
    fn generate_swapper(
        spec: &Spec,
        crc_type: &proc_macro2::TokenStream,
        _word_type: &proc_macro2::TokenStream,
    ) -> proc_macro2::TokenStream {
        if spec.reflect() || spec.width <= 8 {
            return quote!();
        }

        // `(width - 1) & !7` is the bit offset of the CRC's most significant byte.
        generate_swapper_impl(crc_type, (spec.width - 1) as u128 & !7)
    }
}

/// Generates `fn swapper(crc: ty) -> ty`, which reverses the order of the low
/// `mid / 8 + 1` bytes of `crc`.
///
/// `mid` is the bit offset of the highest byte to swap (a multiple of 8). Byte
/// `i` (mask `0xff << 8*i`) moves by `mid - 16*i` bits: left while that is
/// positive, not at all when it is zero (the middle byte of an odd byte count),
/// and right once it is negative. The terms are combined with `+`, which
/// cannot carry because the masked fields are disjoint.
///
/// Callers in this file pass `mid >= 8`.
/// NOTE(review): with `mid == 0` the final term would still be emitted with
/// mask `0xff00`.
fn generate_swapper_impl(ty: &proc_macro2::TokenStream, mid: u128) -> proc_macro2::TokenStream {
    let mut pick = 0xff_u128;
    let mut mid = mid as i128;
    // The final (highest) byte moves right by the original `mid`.
    let last = -mid;

    let mut body = quote!();

    // Low bytes: shift left into the high positions.
    loop {
        let pick_tok = wrap_uint(pick);
        let mid_tok = wrap_int(mid);

        body = quote!(#body ((crc & #pick_tok) << #mid_tok) +);

        mid -= 16;
        pick <<= 8;

        if mid <= 0 {
            break;
        }
    }

    // Odd number of bytes: the middle byte stays where it is.
    if mid == 0 {
        let pick_tok = wrap_uint(pick);
        body = quote!(#body (crc & #pick_tok) +);

        mid -= 16;
        pick <<= 8;
    }

    // High bytes (except the last): shift right into the low positions.
    while mid > last {
        let pick_tok = wrap_uint(pick);
        let mid_neg_tok = wrap_int(-mid);
        body = quote!(#body ((crc & #pick_tok) >> #mid_neg_tok) +);

        mid -= 16;
        pick <<= 8;
    }

    // Last term, emitted without a trailing `+`.
    {
        let pick_tok = wrap_uint(pick);
        let mid_neg_tok = wrap_int(-mid);
        body = quote!(#body ((crc & #pick_tok) >> #mid_neg_tok));
    }

    quote! {
        fn swapper(crc: #ty) -> #ty {
            #body
        }
    }
}

/// Generates the complete `crc<width>[_]<suffix>` module for `spec`.
///
/// The module contains:
/// - the byte-wise and word-wise lookup tables;
/// - the `revlow<width>` bit reverser, only when `spec.reverse()` is set;
/// - the `bitwise`, `bytewise` and `wordwise` submodules, which refer to the
///   items above through `super::`.
pub fn gen_mod(spec: &Spec) -> proc_macro2::TokenStream {
    let suffix = get_suffix(spec);
    let maybe_underscore = get_underscore(&suffix);
    let mod_name = format_ident!("crc{}{}{}", spec.width, maybe_underscore, suffix);

    let bitwise = gen_bitwise(spec);
    let bytewise_table = gen_bytewise_table(spec);
    let bytewise = gen_bytewise(spec);
    let wordwise_table = gen_wordwise_table(spec);
    let wordwise = gen_wordwise(spec);

    let revgen = if spec.reverse() {
        gen_revlow(spec)
    } else {
        quote!()
    };

    quote! {
        pub mod #mod_name {
            #bytewise_table
            #wordwise_table
            #revgen
            #bitwise
            #bytewise
            #wordwise
        }
    }
}

/// Generates `fn revlow<width>(mut val) -> val`, which reverses the order of the
/// low `spec.width` bits of `val`.
///
/// Bits are exchanged in successively halved groups. When a group has an odd
/// number of bits, its middle bit cannot be exchanged, so it is kept in a local
/// `mid` and OR-ed back in at the end. For the widths special-cased below, the
/// reversal is instead performed over the whole 32- or 64-bit word and the
/// result is shifted down by `down`.
fn gen_revlow(spec: &Spec) -> proc_macro2::TokenStream {
    let (ty, _) = get_ty(spec);
    let reverser = format_ident!("revlow{}", spec.width);

    let (down, bits) = match spec.width {
        31 => (1, 32),
        47 | 55 | 59 | 61 | 62 | 63 => (64 - spec.width, 64),
        _ => (0, spec.width),
    };

    // All-ones over `bits` bits. The double shift avoids shifting by 128, and
    // `wrapping_sub` avoids the `0 - 1` underflow (a debug-build panic) that
    // occurs when `bits == 128`.
    let all = (1u128 << (bits - 1) << 1).wrapping_sub(1);

    let mut kept = 0u128;
    let mut mask = all;
    let mut body = quote!();
    let mut bits = bits;

    loop {
        let mid = bits & 1;
        bits >>= 1;

        if mid != 0 {
            let keep = (mask >> bits) ^ (mask >> (bits + 1));
            let keep_tok = wrap_uint(keep);

            // Additional middle bits are folded in by shadowing. A plain
            // `mid |= ...` would assign to the immutable first binding and fail
            // to compile. A `let mut` would trigger an unused-mut warning
            // whenever there is only one middle bit.
            body = if kept != 0 {
                quote! {
                    #body
                    let mid = mid | (val & #keep_tok);
                }
            } else {
                quote! {
                    #body
                    let mid = val & #keep_tok;
                }
            };

            kept |= keep;
        }

        let shift = bits + mid;

        mask ^= mask >> shift;

        let left = mask & !kept;
        let left_tok = wrap_uint(left);
        let right_tok = wrap_uint(all ^ kept ^ left);
        let shift_tok = wrap_uint(shift as _);

        body = quote! {
            #body
            val = ((val >> #shift_tok) & #right_tok)
                | ((val << #shift_tok) & #left_tok);
        };

        if bits <= 1 {
            break;
        }
    }

    let return_value = if down != 0 {
        let down_tok = wrap_uint(down as _);
        quote! {
            val >> #down_tok
        }
    } else if kept != 0 {
        quote! {
            val | mid
        }
    } else {
        quote! {
            val
        }
    };

    quote! {
        fn #reverser(mut val: #ty) -> #ty {
            #body
            #return_value
        }
    }
}

pub fn gen_bitwise(spec: &Spec) -> proc_macro2::TokenStream {
    let suffix = get_suffix(spec);
    let suffix_capitalized = capitalize(&suffix);
    let maybe_underscore = get_underscore(&suffix);
    let name = format_ident!(
        "Crc{}{}{}",
        spec.width,
        maybe_underscore,
        suffix_capitalized
    );
    let simple_test_name = format_ident!("crc{}{}{}_simple", spec.width, maybe_underscore, suffix);
    let batch_test_name = format_ident!("crc{}{}{}_batch", spec.width, maybe_underscore, suffix);
    let empty_test_name = format_ident!("crc{}{}{}_empty", spec.width, maybe_underscore, suffix);

    let (ty, bits) = get_ty(spec);

    let init = wrap_uint(spec.init());
    let poly = wrap_uint(spec.poly());
    let check = wrap_uint(spec.check);
    let mask = wrap_uint(ones(spec.width));
    let reverser = format_ident!("revlow{}", spec.width);
    let cast = if bits != 8 {
        quote!(*byte as #ty)
    } else {
        quote!(*byte)
    };

    let xorin = if spec.xorout == 0 {
        quote!()
    } else if spec.xorout == ones(spec.width) {
        quote!(crc = !crc;)
    } else {
        let value = wrap_uint(spec.xorout);
        quote!(crc ^= #value;)
    };

    let revin = if !spec.reverse() {
        quote!()
    } else {
        quote!(crc = super::#reverser(crc);)
    };

    let revout = if !spec.reverse() {
        quote!()
    } else {
        quote!(crc = super::#reverser(crc);)
    };

    let body = if spec.reflect() {
        let prefix = if spec.width == bits || spec.reverse() {
            quote!()
        } else {
            quote!(crc &= #mask;)
        };
        let body = quote! {
            for byte in data {
                crc ^= #cast;
                for _ in 0..8 {
                    crc = if crc & 1 != 0 { (crc >> 1) ^ #poly } else { crc >> 1 };
                }
            }
        };
        let xorout = if spec.xorout == 0 {
            quote!()
        } else if spec.xorout == ones(spec.width) && bits == spec.width {
            quote!(crc = !crc;)
        } else {
            let value = wrap_uint(spec.xorout);
            quote!(crc ^= #value;)
        };

        quote! {
            #prefix
            #body
            #revout
            #xorout
        }
    } else if spec.width <= 8 {
        let poly = wrap_uint(spec.poly() << (8 - spec.width));
        let shift = wrap_uint((8 - spec.width) as _);
        let shift_left = if spec.width == 8 {
            quote!()
        } else {
            quote!(crc <<= #shift;)
        };
        let shift_right = if spec.width == 8 {
            quote!()
        } else {
            quote!(crc >>= #shift;)
        };
        let body = quote! {
            for byte in data {
                crc ^= #cast;
                for _ in 0..8 {
                    crc = if crc & 0x80 != 0 { (crc << 1) ^ #poly } else { crc << 1 };
                }
            }
        };
        let xorout = if spec.xorout == 0 {
            quote!()
        } else if spec.xorout == ones(spec.width) && !spec.reverse() {
            quote!(crc = !crc;)
        } else {
            let value = wrap_uint(spec.xorout << (8 - spec.width));
            quote!(crc ^= #value;)
        };

        quote! {
            #shift_left
            #body
            #xorout
            #shift_right
            #revout
        }
    } else {
        let shift = wrap_uint((spec.width - 8) as _);
        let top = wrap_uint(1 << (spec.width - 1));

        let body = quote! {
            for byte in data {
                crc ^= (#cast) << #shift;
                for _ in 0..8 {
                    crc = if crc & #top != 0 { (crc << 1) ^ #poly } else { crc << 1 };
                }
            }
        };
        let xorout = if spec.xorout == 0 {
            quote!()
        } else if spec.xorout == ones(spec.width) && !spec.reverse() {
            quote!(crc = !crc;)
        } else {
            let value = wrap_uint(spec.xorout);
            quote!(crc ^= #value;)
        };

        let mask = if spec.width == bits || spec.reverse() {
            quote!()
        } else {
            quote!(crc &= #mask;)
        };

        quote! {
            #body
            #revout
            #xorout
            #mask
        }
    };

    quote! {
        pub mod bitwise {
            pub struct #name;

            impl #name {
                pub fn new() -> ::crcany::crc::Computer<Self> {
                    ::crcany::crc::Computer::new(#name)
                }
            }

            impl ::crcany::crc::Crc for #name {
                type Int = #ty;

                fn init(&self) -> Self::Int {
                    #init
                }

                fn add_bytes(&self, mut crc: Self::Int, mut data: &[u8]) -> Self::Int {
                    if data.is_empty() {
                        return crc;
                    }

                    #xorin
                    #revin
                    #body

                    crc
                }
            }

            #[cfg(test)]
            mod test {
                use super::*;

                #[test]
                fn #simple_test_name() {
                    let mut crc = #name::new();
                    crc.add_bytes(b"123456789");
                    assert_eq!(#check, crc.into_inner());
                }

                #[test]
                fn #batch_test_name() {
                    let mut crc = #name::new();
                    crc.add_bytes(b"123");
                    crc.add_bytes(b"4");
                    crc.add_bytes(b"");
                    crc.add_bytes(b"567");
                    crc.add_bytes(b"89");
                    assert_eq!(#check, crc.into_inner());
                }

                #[test]
                fn #empty_test_name() {
                    let mut crc = #name::new();
                    crc.add_bytes(b"");
                    assert_eq!(#init, crc.into_inner());
                }
            }
        }
    }
}

pub fn gen_bytewise_table(spec: &Spec) -> proc_macro2::TokenStream {
    let bytewise = BytewiseModel::from_spec(spec.clone());

    let (ty, _) = get_ty(spec);

    let bytes = bytewise.table.iter().copied().map(wrap_uint);
    quote! {
        const TABLE_BYTE: [#ty; 256] = [
            #( #bytes ),*
        ];
    }
}

/// Generates the `bytewise` submodule: a table-driven implementation of `spec`
/// that processes one input byte per lookup in `super::TABLE_BYTE`.
///
/// The generated `add_bytes` applies `super::revlow<width>` on entry and on
/// exit when `spec.reverse()` is set. Unlike the bit-wise version, it never
/// XORs with `xorout`.
/// NOTE(review): presumably `TABLE_BYTE` (built by `BytewiseModel`) accounts for
/// `xorout`; confirm there.
///
/// Unit tests comparing against `spec.check` and `spec.init()` are emitted in
/// a `#[cfg(test)]` module.
pub fn gen_bytewise(spec: &Spec) -> proc_macro2::TokenStream {
    let suffix = get_suffix(spec);
    let suffix_capitalized = capitalize(&suffix);
    let maybe_underscore = get_underscore(&suffix);
    let name = format_ident!(
        "Crc{}{}{}",
        spec.width,
        maybe_underscore,
        suffix_capitalized
    );
    let simple_test_name = format_ident!("crc{}{}{}_simple", spec.width, maybe_underscore, suffix);
    let batch_test_name = format_ident!("crc{}{}{}_batch", spec.width, maybe_underscore, suffix);
    let empty_test_name = format_ident!("crc{}{}{}_empty", spec.width, maybe_underscore, suffix);

    let (ty, bits) = get_ty(spec);

    let init = wrap_uint(spec.init());
    let check = wrap_uint(spec.check);
    let reverser = format_ident!("revlow{}", spec.width);

    let rev = if !spec.reverse() {
        quote!()
    } else {
        quote!(crc = super::#reverser(crc);)
    };

    // Clears bits above `width` when the integer type is wider than the CRC. It is
    // applied before the reflected loop and after the wide non-reflected loop.
    let maybe_mask = if spec.width != bits && !spec.reverse() {
        let mask = wrap_uint(ones(spec.width));
        quote!(crc &= #mask;)
    } else {
        quote!()
    };

    let body = if spec.reflect() {
        let steps = if spec.width > 8 {
            quote! {
                for byte in data {
                    crc = (crc >> 8) ^ super::TABLE_BYTE[((crc as u8) ^ *byte) as usize];
                }
            }
        } else {
            quote! {
                for byte in data {
                    crc = super::TABLE_BYTE[((crc as u8) ^ *byte) as usize];
                }
            }
        };
        quote!(#maybe_mask #steps)
    } else if spec.width <= 8 {
        // Narrow, non-reflected: work with the CRC left-aligned in the u8.
        let shift_tok = wrap_uint((8 - spec.width) as _);
        let (shift_left, shift_right) = if spec.width < 8 {
            (quote!(crc <<= #shift_tok;), quote!(crc >>= #shift_tok;))
        } else {
            (quote!(), quote!())
        };

        quote! {
            #shift_left
            for byte in data {
                crc = super::TABLE_BYTE[(crc ^ *byte) as usize];
            }
            #shift_right
        }
    } else {
        // Wide, non-reflected: index with the CRC's top byte.
        let shift_tok = wrap_uint((spec.width - 8) as _);

        quote! {
            for byte in data {
                crc = (crc << 8) ^ super::TABLE_BYTE[(((crc >> #shift_tok) as u8) ^ *byte) as usize];
            }
            #maybe_mask
        }
    };

    quote! {
        pub mod bytewise {
            pub struct #name;

            impl #name {
                pub fn new() -> ::crcany::crc::Computer<Self> {
                    ::crcany::crc::Computer::new(#name)
                }
            }

            impl ::crcany::crc::Crc for #name {
                type Int = #ty;

                fn init(&self) -> Self::Int {
                    #init
                }

                fn add_bytes(&self, mut crc: Self::Int, data: &[u8]) -> Self::Int {
                    if data.is_empty() {
                        return crc;
                    }

                    #rev
                    #body
                    #rev

                    crc
                }
            }

            #[cfg(test)]
            mod test {
                use super::*;

                #[test]
                fn #simple_test_name() {
                    let mut crc = #name::new();
                    crc.add_bytes(b"123456789");
                    assert_eq!(#check, crc.into_inner());
                }

                #[test]
                fn #batch_test_name() {
                    let mut crc = #name::new();
                    crc.add_bytes(b"123");
                    crc.add_bytes(b"4");
                    crc.add_bytes(b"");
                    crc.add_bytes(b"567");
                    crc.add_bytes(b"89");
                    assert_eq!(#check, crc.into_inner());
                }

                #[test]
                fn #empty_test_name() {
                    let mut crc = #name::new();
                    crc.add_bytes(b"");
                    assert_eq!(#init, crc.into_inner());
                }
            }
        }
    }
}

pub fn gen_wordwise_table(spec: &Spec) -> proc_macro2::TokenStream {
    let wordwise =
        WordwiseModel::<byteorder::NativeEndian, { usize::BITS as _ }>::from_spec(spec.clone());

    let (ty, _) = get_ty(spec);

    let tables = wordwise.table.iter().map(|bytes| {
        let bytes = bytes.iter().copied().map(wrap_uint);
        quote! {
            [ #( #bytes ),* ]
        }
    });

    quote! {
        const TABLE_WORD: [[#ty; 256]; 8] = [
            #( #tables ),*
        ];
    }
}

/// Generates the `wordwise` submodule: a table-driven implementation of `spec`
/// that processes input a machine word at a time where possible.
///
/// The generated `add_bytes` proceeds in four stages:
/// 1. Optionally applies `super::revlow<width>` (when `spec.reverse()` is set).
/// 2. Consumes single bytes with `super::TABLE_BYTE` until `data`'s address is
///    a multiple of the word size.
/// 3. Runs the word loop from `Endianness::generate_wordwise_body`, which uses
///    `super::TABLE_WORD`.
/// 4. Finishes the remaining tail bytes with `TABLE_BYTE`, then applies the
///    optional reversal again.
///
/// Like the byte-wise version, it never XORs with `xorout` itself.
///
/// The word loop and `swapper` come from `byteorder::NativeEndian` of the
/// machine running this generator. In this file `Endianness` is only
/// implemented for `byteorder::LittleEndian`.
/// NOTE(review): a big-endian host would therefore not compile this generator.
///
/// Unit tests are emitted in a `#[cfg(test)]` module. The first test feeds the
/// check string from nine different offsets of `INPUT` to exercise the
/// alignment prologue.
pub fn gen_wordwise(spec: &Spec) -> proc_macro2::TokenStream {
    let suffix = get_suffix(spec);
    let suffix_capitalized = capitalize(&suffix);
    let maybe_underscore = get_underscore(&suffix);
    let name = format_ident!(
        "Crc{}{}{}",
        spec.width,
        maybe_underscore,
        suffix_capitalized
    );
    let simple_test_name = format_ident!("crc{}{}{}_simple", spec.width, maybe_underscore, suffix);
    let batch_test_name = format_ident!("crc{}{}{}_batch", spec.width, maybe_underscore, suffix);
    let empty_test_name = format_ident!("crc{}{}{}_empty", spec.width, maybe_underscore, suffix);

    let init = wrap_uint(spec.init());
    let check = wrap_uint(spec.check);
    let reverser = format_ident!("revlow{}", spec.width);

    let maybe_rev = if !spec.reverse() {
        quote!()
    } else {
        quote!(crc = super::#reverser(crc);)
    };

    let (ty, bits) = get_ty(spec);
    let (word_type, word_bits) = get_word_ty(spec);
    let word_bytes = word_bits >> 3;

    // Narrow CRCs are left-aligned in a u8 by `8 - width`. Wide non-reflected
    // CRCs index the table with their top byte, found `width - 8` bits up.
    let shift = if spec.width <= 8 {
        8 - spec.width
    } else {
        spec.width - 8
    };
    let shift_tok = wrap_uint(shift as _);

    // `word_bytes` is a power of two, so `addr & (word_bytes - 1)` tests alignment.
    let low_mask_tok = wrap_uint((word_bytes - 1) as _);

    // Byte-at-a-time prologue that runs until `data` is word-aligned.
    let initial_bytes = if spec.reflect() {
        let maybe_mask = if spec.width != bits && !spec.reverse() {
            let mask_tok = wrap_uint(ones(spec.width));
            quote!(crc &= #mask_tok;)
        } else {
            quote!()
        };

        let loop_body = if spec.width > 8 {
            quote! {
                crc = (crc >> 8) ^ super::TABLE_BYTE[((crc as u8) ^ data[0]) as usize];
                data = &data[1..];
            }
        } else {
            quote! {
                crc = super::TABLE_BYTE[((crc as u8) ^ data[0]) as usize];
                data = &data[1..];
            }
        };

        quote! {
            #maybe_mask
            while !data.is_empty() && (data.as_ptr() as usize) & #low_mask_tok != 0 {
                #loop_body
            }
        }
    } else if spec.width <= 8 {
        // Left-align the narrow CRC here. The matching right shift happens in
        // `final_bytes`.
        let maybe_shift = if spec.width < 8 {
            quote!(crc <<= #shift_tok;)
        } else {
            quote!()
        };

        quote! {
            #maybe_shift
            while !data.is_empty() && (data.as_ptr() as usize) & #low_mask_tok != 0 {
                crc = super::TABLE_BYTE[((crc as u8) ^ data[0]) as usize];
                data = &data[1..];
            }
        }
    } else {
        quote! {
            while !data.is_empty() && (data.as_ptr() as usize) & #low_mask_tok != 0 {
                crc = (crc << 8) ^ super::TABLE_BYTE[(((crc >> #shift_tok) as u8) ^ data[0]) as usize];
                data = &data[1..];
            }
        }
    };

    let word_body = byteorder::NativeEndian::generate_wordwise_body(spec);

    // Byte-at-a-time epilogue for the tail that is shorter than a word.
    let final_bytes = if spec.reflect() {
        if spec.width > 8 {
            quote! {
                for byte in data {
                    crc = (crc >> 8) ^ super::TABLE_BYTE[((crc as u8) ^ *byte) as usize];
                }
            }
        } else {
            quote! {
                for byte in data {
                    crc = super::TABLE_BYTE[((crc as u8) ^ *byte) as usize];
                }
            }
        }
    } else if spec.width <= 8 {
        // Undo the left-alignment done in `initial_bytes`.
        let maybe_shift = if spec.width < 8 {
            quote!(crc >>= #shift_tok;)
        } else {
            quote!()
        };

        quote! {
            for byte in data {
                crc = super::TABLE_BYTE[((crc as u8) ^ *byte) as usize];
            }
            #maybe_shift
        }
    } else {
        // `crc << 8` can leave bits above `width` when the type is wider than the CRC.
        let maybe_mask = if spec.width != bits && !spec.reverse() {
            let mask_tok = wrap_uint(ones(spec.width));
            quote!(crc &= #mask_tok;)
        } else {
            quote!()
        };

        quote! {
            for byte in data {
                crc = (crc << 8) ^ super::TABLE_BYTE[(((crc >> #shift_tok) as u8) ^ *byte) as usize];
            }
            #maybe_mask
        }
    };

    // Possibly empty; defines `swapper` when `word_body` calls it.
    let swapper = byteorder::NativeEndian::generate_swapper(spec, &ty, &word_type);

    quote! {
        pub mod wordwise {
            pub struct #name;

            impl #name {
                pub fn new() -> ::crcany::crc::Computer<Self> {
                    ::crcany::crc::Computer::new(#name)
                }
            }

            impl ::crcany::crc::Crc for #name {
                type Int = #ty;

                fn init(&self) -> Self::Int {
                    #init
                }

                fn add_bytes(&self, mut crc: Self::Int, mut data: &[u8]) -> Self::Int {
                    if data.is_empty() {
                        return crc;
                    }

                    #maybe_rev
                    #initial_bytes
                    #word_body
                    #final_bytes
                    #maybe_rev

                    crc
                }
            }

            #swapper

            #[cfg(test)]
            mod test {
                use super::*;

                const INPUT: &[u8] = b"123456789123456789123456789123456789123456789123456789123456789123456789123456789";

                #[test]
                fn #simple_test_name() {
                    for i in 0..9 {
                        let start = 9 * i;
                        let end = start + 9;
                        let mut crc = #name::new();
                        crc.add_bytes(&INPUT[start..end]);
                        let res = crc.into_inner();
                        assert_eq!(#check, res, "failure on permutation {} - 0b{:b} vs 0b{:b}", i, res, #check);
                    }
                }

                #[test]
                fn #batch_test_name() {
                    let mut crc = #name::new();
                    crc.add_bytes(b"123");
                    crc.add_bytes(b"4");
                    crc.add_bytes(b"");
                    crc.add_bytes(b"567");
                    crc.add_bytes(b"89");
                    assert_eq!(#check, crc.into_inner());
                }

                #[test]
                fn #empty_test_name() {
                    let mut crc = #name::new();
                    crc.add_bytes(b"");
                    assert_eq!(#init, crc.into_inner());
                }
            }
        }
    }
}
