// impl_trait - Rust proc macro that significantly reduces boilerplate
// Copyright (C) 2021  Soni L.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program.  If not, see <https://www.gnu.org/licenses/>.

extern crate proc_macro;
use proc_macro::{TokenStream, TokenTree, Delimiter, Span};
//use syn::parse::{Parse, ParseStream, Result as ParseResult};
use syn::{Generics, GenericParam};
use std::cmp::Ordering;
use quote::ToTokens;

#[proc_macro]
#[allow(unreachable_code)]
pub fn impl_trait(item: TokenStream) -> TokenStream {
    //eprintln!("INPUT: {:#?}", item);
    let mut output: Vec<_> = item.into_iter().collect();
    let attributes: Vec<TokenTree> = {
        let mut pos = 0;
        let mut len = 0;
        let mut in_attr = false;
        while pos != output.len() {
            let tt = &output[pos];
            pos += 1;
            match tt {
                &TokenTree::Punct(ref punct) => {
                    if punct.as_char() == '#' && !in_attr {
                        in_attr = true;
                        continue;
                    }
                }
                &TokenTree::Group(ref group) => {
                    if group.delimiter() == Delimiter::Bracket && in_attr {
                        in_attr = false;
                        len = pos;
                        continue;
                    }
                }
                _ => {}
            }
            break;
        }
        output.drain(0..len).collect()
    };
    //eprintln!("attributes: {:#?}", attributes);
    // check for impl.
    // unsafe impls are only available for traits and are automatically rejected.
    'check_impl: loop { break {
        if let &TokenTree::Ident(ref ident) = &output[0] {
            if format!("{}", ident) == "impl" {
                break 'check_impl;
            }
        }
        panic!("impl_trait! may only be applied to inherent impls");
    } }
    let mut has_where: Option<&TokenTree> = None;
    'check_no_for_before_where: loop { break {
        for tt in &output {
            if let &TokenTree::Ident(ref ident) = tt {
                let formatted = format!("{}", ident);
                if formatted == "where" {
                    has_where = Some(tt);
                    break 'check_no_for_before_where;
                } else if formatted == "for" {
                    panic!("impl_trait! may only be applied to inherent impls");
                }
            }
        }
    } }
    // this is the "where [...]" part, including the "where".
    let mut where_bounds = Vec::new();
    if let Some(where_in) = has_where {
        where_bounds = output.split_last().unwrap().1.into_iter().skip_while(|&tt| {
            !std::ptr::eq(tt, where_in)
        }).cloned().collect();
    }
    let where_bounds = where_bounds;
    drop(has_where);
    let mut count = 0;
    // this is the "<...>" part, immediately after the "impl", and including the "<>".
    let generics = output.split_first().unwrap().1.into_iter().take_while(|&tt| {
        let mut result = count > 0;
        if let &TokenTree::Punct(ref punct) = tt {
            let c = punct.as_char();
            if c == '<' {
                count += 1;
                result = true;
            } else if c == '>' {
                count -= 1;
            }
        }
        result
    }).cloned().collect::<Vec<_>>();
    // so how do you find the target? well...
    // "impl" + [_; generics.len()] + [_; target.len()] + [_; where_bounds.len()] + "{}"
    // we have generics and where_bounds, and the total, so we can easily find target!
    let target_start = 1 + generics.len();
    let target_end = output.len() - 1 - where_bounds.len();
    let target_range = target_start..target_end;
    let target = (&output[target_range]).into_iter().cloned().collect::<Vec<_>>();
    //eprintln!("generics: {:#?}", generics);
    //eprintln!("target: {:#?}", target);
    //eprintln!("where_bounds: {:#?}", where_bounds);
    let items = output.last_mut();
    if let &mut TokenTree::Group(ref mut group) = items.unwrap() {
        // TODO: parse "[unsafe] impl trait" somehow. use syn for it maybe (after swallowing the "trait")
        // luckily for us, there's only one thing that can come after an impl trait: a path
        // (and optional generics).
        // but we can't figure out how to parse the `where`.
        //todo!();
        let span = group.span();
        let mut items = group.stream().into_iter().collect::<Vec<_>>();
        let mut in_unsafe = false;
        let mut in_impl = false;
        let mut in_path = false;
        let mut in_generic = false;
        let mut in_attr = false;
        let mut in_attr_cont = false;
        let mut has_injected_generics = false;
        let mut in_where = false;
        let mut start = 0;
        let mut found: Vec<Vec<TokenTree>> = Vec::new();
        let mut to_remove: Vec<std::ops::Range<usize>> = Vec::new();
        let mut generics_scratchpad = Vec::new();
        let mut count = 0;
        let mut trait_span: Option<Span> = None;
        'main_loop: for (pos, tt) in (&items).into_iter().enumerate() {
            if in_generic {
                // collect the generics
                let mut result = count > 0;
                if let &TokenTree::Punct(ref punct) = tt {
                    let c = punct.as_char();
                    if c == '<' {
                        count += 1;
                        result = true;
                    } else if c == '>' {
                        count -= 1;
                        if count == 0 {
                            in_generic = false;
                            in_path = true;
                        }
                    }
                }
                if result {
                    generics_scratchpad.push(tt.clone());
                    continue;
                }
            }
            if in_path {
                // inject the generics
                if !has_injected_generics {
                    has_injected_generics = true;
                    if generics_scratchpad.is_empty() {
                        found.last_mut().unwrap().extend(generics.clone());
                    } else if generics.is_empty() {
                        found.last_mut().unwrap().extend(generics_scratchpad.clone());
                    } else {
                        // need to *combine* generics. this is not exactly trivial.
                        // thankfully we don't need to worry about defaults on impls.
                        let mut this_generics: Generics = syn::parse(generics_scratchpad.drain(..).collect()).unwrap();
                        let parent_generics: Generics = syn::parse(generics.clone().into_iter().collect()).unwrap();
                        let mut target = parent_generics.params.into_pairs().chain(this_generics.params.clone().into_pairs()).collect::<Vec<_>>();
                        target.sort_by(|a, b| {
                            match (a.value(), b.value()) {
                                (&GenericParam::Lifetime(_), &GenericParam::Const(_)) => Ordering::Less,
                                (&GenericParam::Type(_), &GenericParam::Const(_)) => Ordering::Less,
                                (&GenericParam::Lifetime(_), &GenericParam::Type(_)) => Ordering::Less,
                                (&GenericParam::Lifetime(_), &GenericParam::Lifetime(_)) => Ordering::Equal,
                                (&GenericParam::Type(_), &GenericParam::Type(_)) => Ordering::Equal,
                                (&GenericParam::Const(_), &GenericParam::Const(_)) => Ordering::Equal,
                                (&GenericParam::Type(_), &GenericParam::Lifetime(_)) => Ordering::Greater,
                                (&GenericParam::Const(_), &GenericParam::Type(_)) => Ordering::Greater,
                                (&GenericParam::Const(_), &GenericParam::Lifetime(_)) => Ordering::Greater,
                            }
                        });
                        // just need to fix the one Pair::End in the middle of the thing.
                        for item in &mut target {
                            if matches!(item, syn::punctuated::Pair::End(_)) {
                                let value = item.value().clone();
                                *item = syn::punctuated::Pair::Punctuated(value, syn::token::Comma { spans: [trait_span.unwrap().into()] });
                                break;
                            }
                        }
                        this_generics.params = target.into_iter().collect();
                        let new_generics = TokenStream::from(this_generics.into_token_stream());
                        found.last_mut().unwrap().extend(new_generics);
                    }
                }
                in_generic = false;
                if let &TokenTree::Ident(ref ident) = tt {
                    let formatted = format!("{}", ident);
                    if count == 0 && formatted == "where" {
                        in_path = false;
                        in_where = true;
                        // add "for"
                        found.last_mut().unwrap().push(proc_macro::Ident::new("for", trait_span.unwrap()).into());
                        // add Target
                        found.last_mut().unwrap().extend(target.clone());
                        // *then* add the "where" (from the impl-trait)
                        found.last_mut().unwrap().push(tt.clone());
                        // and the parent bounds (except the "where")
                        found.last_mut().unwrap().extend((&where_bounds).into_iter().skip(1).cloned());
                        // also make sure that there's an ',' at the correct place
                        if let Some(&TokenTree::Punct(ref x)) = where_bounds.last() {
                            if x.as_char() == ',' {
                                continue 'main_loop;
                            }
                        }
                        found.last_mut().unwrap().push(proc_macro::Punct::new(',', proc_macro::Spacing::Alone).into());
                        continue 'main_loop;
                    }
                }
                if let &TokenTree::Punct(ref punct) = tt {
                    let c = punct.as_char();
                    if c == '<' {
                        count += 1;
                    } else if c == '>' {
                        // this is broken so just give up
                        // FIXME better error handling
                        if count == 0 {
                            in_path = false;
                            continue 'main_loop;
                        }
                        count -= 1;
                    }
                }
                if let &TokenTree::Group(ref group) = tt {
                    if group.delimiter() == Delimiter::Brace && count == 0 {
                        to_remove.push(start..pos+1);
                        // add "for"
                        found.last_mut().unwrap().push(proc_macro::Ident::new("for", tt.span()).into());
                        // add Target
                        found.last_mut().unwrap().extend(target.clone());
                        // and the parent bounds (including the "where")
                        found.last_mut().unwrap().extend(where_bounds.clone());
                        in_path = false;
                        in_where = false;
                        // fall through to add the block
                    }
                }
                found.last_mut().unwrap().push(tt.clone());
                continue 'main_loop;
            }
            if in_where {
                // just try to find the block, and add all the stuff.
                if let &TokenTree::Punct(ref punct) = tt {
                    let c = punct.as_char();
                    if c == '<' {
                        count += 1;
                    } else if c == '>' {
                        // this is broken so just give up
                        // FIXME better error handling
                        if count == 0 {
                            in_where = false;
                            continue 'main_loop;
                        }
                        count -= 1;
                    }
                }
                if let &TokenTree::Group(ref group) = tt {
                    if group.delimiter() == Delimiter::Brace && count == 0 {
                        // call it done!
                        to_remove.push(start..pos+1);
                        in_where = false;
                    }
                }
                found.last_mut().unwrap().push(tt.clone());
                continue 'main_loop;
            }
            if found.len() == to_remove.len() {
                found.push(Vec::new());
                in_unsafe = false;
                in_impl = false;
                in_where = false;
                in_path = false;
                in_attr_cont = false;
                in_generic = false;
                has_injected_generics = false;
                count = 0;
            }
            match tt {
                &TokenTree::Ident(ref ident) => {
                    let formatted = format!("{}", ident);
                    if formatted == "unsafe" && !in_impl {
                        found.last_mut().unwrap().push(tt.clone());
                        if !in_attr_cont {
                            start = pos;
                        }
                        in_attr = false;
                        in_unsafe = true;
                        continue;
                    } else if formatted == "impl" && !in_impl {
                        if !in_attr_cont && !in_unsafe {
                            start = pos;
                        }
                        found.last_mut().unwrap().push(tt.clone());
                        in_unsafe = false;
                        in_attr = false;
                        in_impl = true;
                        continue;
                    } else if formatted == "trait" && in_impl {
                        // swallowed. doesn't go into found.
                        trait_span = Some(tt.span());
                        in_generic = true;
                        in_path = true;
                        in_impl = false;
                        has_injected_generics = false;
                        continue;
                    }
                },
                &TokenTree::Punct(ref punct) => {
                    if punct.as_char() == '#' && !in_attr {
                        found.last_mut().unwrap().push(tt.clone());
                        if !in_attr_cont {
                            start = pos;
                        }
                        in_attr = true;
                        continue;
                    }
                }
                &TokenTree::Group(ref group) => {
                    if group.delimiter() == Delimiter::Bracket && in_attr {
                        found.last_mut().unwrap().push(tt.clone());
                        in_attr = false;
                        in_attr_cont = true;
                        continue;
                    }
                }
                _ => {}
            }
            found.truncate(to_remove.len());
            in_unsafe = false;
            in_impl = false;
            in_where = false;
            in_path = false;
            in_attr_cont = false;
            in_generic = false;
            has_injected_generics = false;
            count = 0;
        }
        // must be iterated backwards
        for range in to_remove.into_iter().rev() {
            items.drain(range);
        }
        *group = proc_macro::Group::new(group.delimiter(), items.into_iter().collect());
        group.set_span(span);
        output.extend(found.into_iter().flatten());
    }
    drop(generics);
    drop(target);
    drop(where_bounds);
    //eprintln!("attributes: {:#?}", attributes);
    //eprintln!("OUTPUT: {:#?}", output);
    //eprintln!("OUTPUT: {}", (&output).into_iter().cloned().collect::<TokenStream>());
    attributes.into_iter().chain(output.into_iter()).collect()
}
