import datetime
import os.path
import re
import string
import sys

from rust import RustHelperBackend
from stone import ir
from stone.backends.python_helpers import fmt_class as fmt_py_class


class Permissions(object):
    """Permission set passed to the reference JSON encoder for test values."""

    @property
    def permissions(self):
        # Always request 'internal' when serializing test values, so that
        # internal fields/structs are included whenever internal specs are in
        # use. Without internal specs this has no effect, so there is no need
        # to make it conditional. Only JSON serialization needs this; the
        # struct definitions always carry every field.
        granted = ['internal']
        return granted


class TestBackend(RustHelperBackend):
    """Stone backend that writes Rust (de)serialization tests.

    Values are built with the Stone-generated Python reference code,
    serialized to JSON, and embedded in generated Rust `#[test]` functions
    that deserialize them and assert on the resulting fields.
    """

    def __init__(self, target_folder_path, args):
        super(TestBackend, self).__init__(target_folder_path, args)

        # Don't import other generators until here, otherwise stone.cli will
        # call them with its own arguments, in addition to the TestBackend.
        from stone.backends.python_types import PythonTypesBackend

        reference_dir = os.path.join(target_folder_path, 'reference')
        self.target_path = target_folder_path
        self.ref_path = reference_dir
        self.reference = PythonTypesBackend(reference_dir, args + ["--package", "reference"])
        # Maps namespace name -> imported Python reference module; filled in
        # by generate().
        self.reference_impls = {}

    def make_test_value(self, typ):
        """Return the list of test values to generate for a struct or union type.

        Raises RuntimeError for any other kind of type.
        """
        impls = self.reference_impls

        if ir.is_struct_type(typ):
            if typ.has_enumerated_subtypes():
                # One test value per enumerated subtype.
                return [TestPolymorphicStruct(self, typ, impls, subtype)
                        for subtype in typ.get_enumerated_subtypes()]

            values = [TestStruct(self, typ, impls)]
            if typ.all_optional_fields:
                # If any fields are optional, also emit a test struct that
                # lacks all optional fields. This helps catch backwards-compat
                # issues as well as checking serialization of None.
                values.append(TestStruct(self, typ, impls, no_optional_fields=True))
            return values

        if ir.is_union_type(typ):
            # One test value per union variant.
            return [TestUnion(self, typ, impls, field) for field in typ.all_fields]

        raise RuntimeError(u'ERROR: type {} is neither struct nor union'
                            .format(typ))

    def generate(self, api):
        """Generate the Python reference code, import it, then emit the Rust tests."""
        print(u'Generating Python reference code')
        self.reference.generate(api)
        with self.output_to_relative_path('reference/__init__.py'):
            self.emit(u'# this is the Stone-generated reference Python SDK')

        print(u'Loading reference code:')
        sys.path.insert(0, self.target_path)
        sys.path.insert(1, "stone")
        from stone.backends.python_rsrc.stone_serializers import json_encode
        for ns_key in api.namespaces:
            print('\t' + ns_key)
            # hack to work around 'async' being a Python3 keyword
            module_name = 'async_' if ns_key == 'async' else ns_key
            package = __import__('reference.' + module_name)
            self.reference_impls[ns_key] = package.__dict__[module_name]

        print(u'Generating test code')
        for namespace in api.namespaces.values():
            rust_file = self.namespace_name(namespace) + '.rs'
            with self.output_to_relative_path(rust_file):
                self._emit_header()
                for data_type in namespace.data_types:
                    self._emit_tests(namespace, data_type, json_encode)

                    if self.is_closed_union(data_type):
                        self._emit_closed_union_test(namespace, data_type)

        # The module index: one feature-gated `mod` declaration per namespace.
        with self.output_to_relative_path('mod.rs'):
            self._emit_header()
            for ns_key in api.namespaces:
                self.emit(u'#[cfg(feature = "dbx_{}")]'.format(ns_key))
                self.emit(u'mod {};'.format(self.namespace_name_raw(ns_key)))
                self.emit()

    def _emit_header(self):
        """Emit the generated-file banner and the lint allowances."""
        blank = None  # marker for an empty line
        header_lines = (
            u'// DO NOT EDIT',
            u'// This file was @generated by Stone',
            blank,
            u'#![allow(bad_style)]',
            blank,
            u'#![allow(',
            u'    clippy::float_cmp,',
            u'    clippy::unreadable_literal,',
            u'    clippy::cognitive_complexity,',
            u'    clippy::collapsible_match,',
            u'    clippy::bool_assert_comparison',
            u')]',
            blank,
        )
        for line in header_lines:
            if line is blank:
                self.emit()
            else:
                self.emit(line)

    def _emit_tests(self, ns, typ, json_encode):
        """Emit one Rust test function per test value of `typ`."""
        ns_name = self.namespace_name(ns)
        type_name = self.struct_name(typ)

        # Overall approach: build each type with the reference Python code,
        # fill its fields with data, serialize that to JSON, embed the JSON in
        # the Rust test, have Rust deserialize it, and assert that the fields
        # match. Then Rust re-serializes to JSON and deserializes once more,
        # and the fields of that second value are checked too. The second
        # pass is what verifies Rust's serializer.

        for test_value in self.make_test_value(typ):
            pyname = fmt_py_class(typ.name)
            validator = self.reference_impls[ns.name].__dict__[pyname + '_validator']
            json_text = json_encode(validator, test_value.value, Permissions())

            # "other" is a hardcoded, special-cased tag used by Stone for the
            # catch-all variant of open unions. Rewrite it to a different
            # name so the test proves the unknown-variant logic works. This
            # has to be done by editing the JSON text, because the Python
            # serializer won't accept an arbitrary variant name.
            json_text = json_text.replace(
                '{".tag": "other"',
                '{".tag": "dropbox-sdk-rust-bogus-test-variant"')

            with self._test_fn(type_name + test_value.test_suffix()):
                self.emit(u'let json = r#"{}"#;'.format(json_text))
                self.emit(u'let x = ::serde_json::from_str::<::dropbox_sdk::{}::{}>(json).unwrap();'
                          .format(ns_name, self.struct_name(typ)))
                test_value.emit_asserts(self, 'x')
                self.emit(u'assert_eq!(x, x.clone());')

                if not test_value.is_serializable():
                    # assert that serializing it returns an error
                    self.emit(u'assert!(::serde_json::to_string(&x).is_err());')
                else:
                    # Round trip: serialize back to JSON, deserialize again,
                    # and run the same assertions on the result.
                    self.emit()
                    self.emit(u'let json2 = ::serde_json::to_string(&x).unwrap();')
                    deserialize_expr = (
                        u'::serde_json::from_str::<::dropbox_sdk::{}::{}>(&json2).unwrap()'
                        .format(ns_name, self.struct_name(typ)))

                    if typ.all_fields:
                        self.emit(u'let x2 = {};'.format(deserialize_expr))
                        test_value.emit_asserts(self, 'x2')
                        self.emit(u'assert_eq!(x, x2);')
                    else:
                        # No fields to check; just make sure it deserializes.
                        self.emit(u'{};'.format(deserialize_expr))
            self.emit()

    def _emit_closed_union_test(self, ns, typ):
        """Emit a test whose `match` lists every variant with no wildcard arm."""
        ns_name = self.namespace_name(ns)
        type_name = self.struct_name(typ)
        with self._test_fn("ClosedUnion_" + type_name):
            self.emit(u'// This test ensures that an exhaustive match compiles.')
            self.emit(u'let x: Option<::dropbox_sdk::{}::{}> = None;'.format(
                ns_name, self.enum_name(typ)))
            self.emit(u'match x {')
            with self.indent():
                arms = [u'None']
                for variant in self.get_enum_variants(typ):
                    pattern = u'::dropbox_sdk::{}::{}::{}'.format(
                        ns_name, type_name, self.enum_variant_name(variant))
                    # Variants that carry data need a payload placeholder.
                    if not ir.is_void_type(variant.data_type):
                        pattern += u'(_)'
                    arms.append(u'Some({})'.format(pattern))

                self.generate_multiline_list(
                    arms,
                    sep=' | ',
                    skip_last_sep=True,
                    delim=('', ''),
                    after=' => ()')
            self.emit(u'}')
        self.emit()

    def _test_fn(self, name):
        """Emit `#[test]` and open a Rust function named `test_<name>`."""
        self.emit(u'#[test]')
        fn_name = u'test_' + name
        return self.emit_rust_function_def(fn_name)


class TestField(object):
    """A generated value for one field (or unnamed inner value) and its Rust assertion."""

    def __init__(self, name, python_value, test_value, stone_type, option):
        # Rust-side field name; falsy when the value is not reached through a
        # named field (then nothing is appended to the expression path).
        self.name = name
        # The Python value that was placed into the reference object.
        self.value = python_value
        # A TestValue for values that emit their own assertions, else None.
        self.test_value = test_value
        # The Stone data type of the value.
        self.typ = stone_type
        # Truthy when the Rust expression must be unwrapped via
        # `.as_ref().unwrap()` before it can be compared.
        self.option = option

    def emit_assert(self, codegen, expression_path):
        """Emit Rust assertions checking this field against its expected value.

        :param codegen: generator whose ``emit`` method receives the Rust code.
        :param expression_path: Rust expression for the value that contains
            this field.
        :raises RuntimeError: if no assertion is defined for the field's type.
        """
        extra = ('.' + self.name) if self.name else ''
        if self.option:
            expression = '(*' + expression_path + extra + '.as_ref().unwrap())'
        else:
            expression = expression_path + extra

        if isinstance(self.test_value, TestValue):
            # Nested test values know how to assert on themselves.
            self.test_value.emit_asserts(codegen, expression)
        elif ir.is_string_type(self.typ):
            codegen.emit(u'assert_eq!({}.as_str(), r#"{}"#);'.format(
                expression, self.value))
        elif ir.is_numeric_type(self.typ):
            codegen.emit(u'assert_eq!({}, {});'.format(
                expression, self.value))
        elif ir.is_boolean_type(self.typ):
            codegen.emit(u'assert_eq!({}, {});'.format(
                expression, 'true' if self.value else 'false'))
        elif ir.is_timestamp_type(self.typ):
            # Compared as a string, formatted with the type's own format.
            codegen.emit(u'assert_eq!({}.as_str(), "{}");'.format(
                expression, self.value.strftime(self.typ.format)))
        elif ir.is_bytes_type(self.typ):
            codegen.emit(u'assert_eq!(&{}, &[{}]);'.format(
                expression, ",".join(str(x) for x in self.value)))
        else:
            raise RuntimeError(u'Error: assertion unhandled for type {} of field {} with value {}'
                               .format(self.typ, self.name, self.value))


class TestValue(object):
    """Base class for generated test values.

    Subclasses fill in ``value`` (the Python reference value) and override
    ``emit_asserts`` to write the matching Rust assertions.
    """

    def __init__(self, rust_generator):
        self.value = None
        self.fields = []
        self.rust_generator = rust_generator

    def emit_asserts(self, codegen, expression_path):
        """Emit Rust assertions for this value; must be overridden."""
        message = "you're supposed to implement TestValue.emit_asserts"
        raise NotImplementedError(message)

    def is_serializable(self):
        """Whether the value can be serialized from Rust back to JSON.

        Some types cannot make that round trip; the base class assumes it can.
        """
        return True

    def test_suffix(self):
        """Suffix appended to the generated test function name (none by default)."""
        return ""


class TestStruct(TestValue):
    """Test value for a struct, built by instantiating its Python reference class."""

    def __init__(self, rust_generator: TestBackend, stone_type: ir.Struct, reference_impls, no_optional_fields=False):
        super(TestStruct, self).__init__(rust_generator)

        # A struct with enumerated subtypes is represented by its first subtype.
        if stone_type.has_enumerated_subtypes():
            first_subtype = stone_type.get_enumerated_subtypes()[0]
            stone_type = first_subtype.data_type

        self._stone_type = stone_type
        self._reference_impls = reference_impls
        self._no_optional_fields = no_optional_fields

        py_name = fmt_py_class(stone_type.name)
        try:
            py_class = reference_impls[stone_type.namespace.name].__dict__[py_name]
            self.value = py_class()
        except Exception as e:
            raise RuntimeError(u'Error instantiating value for {}: {}'.format(stone_type.name, e))

        # Either every field, or only the required ones.
        if no_optional_fields:
            wanted_fields = stone_type.all_required_fields
        else:
            wanted_fields = stone_type.all_fields

        for stone_field in wanted_fields:
            generated = make_test_field(
                stone_field.name, stone_field.data_type, rust_generator, reference_impls)
            if generated is None:
                raise RuntimeError(u'Error: incomplete type generated: {}'.format(stone_type.name))
            self.fields.append(generated)
            # Store the generated value on the reference object so that it is
            # part of what gets serialized.
            try:
                setattr(self.value, stone_field.name, generated.value)
            except Exception as e:
                raise RuntimeError(u'Error generating value for {}.{}: {}'
                                   .format(stone_type.name, stone_field.name, e))

    def emit_asserts(self, codegen, expression_path):
        """Emit the assertions of every generated field in turn."""
        for generated in self.fields:
            generated.emit_assert(codegen, expression_path)

    def test_suffix(self):
        """Distinguish the required-fields-only variant in the test name."""
        return "_OnlyRequiredFields" if self._no_optional_fields else ""

class TestUnion(TestValue):
    """Test value for one specific variant of a union."""

    def __init__(self, rust_generator, stone_type, reference_impls, variant):
        super(TestUnion, self).__init__(rust_generator)
        self._stone_type = stone_type
        self._reference_impls = reference_impls
        self._variant = variant

        # Rust-side names used when emitting the match arm for this variant.
        self._rust_name = rust_generator.enum_name(stone_type)
        self._rust_variant_name = rust_generator.enum_variant_name_raw(variant.name)
        self._rust_namespace_name = rust_generator.namespace_name(stone_type.namespace)

        # We can't serialize the catch-all variant.
        self._is_serializable = not variant.catch_all

        inner = make_test_field(
            None, self._variant.data_type, rust_generator, reference_impls)
        self._inner_value = inner

        if inner is None:
            raise RuntimeError(u'Error generating union variant value for {}.{}'
                               .format(stone_type.name, variant.name))

        self.value = self.get_from_inner_value(variant.name, inner)

    def get_from_inner_value(self, variant_name, generated_field):
        """Build the Python reference union object for the given variant."""
        pyname = fmt_py_class(self._stone_type.name)
        try:
            namespace_module = self._reference_impls[self._stone_type.namespace.name]
            union_class = namespace_module.__dict__[pyname]
            return union_class(variant_name, generated_field.value)
        except Exception as e:
            raise RuntimeError(u'Error generating value for {}.{}: {}'
                               .format(self._stone_type.name, variant_name, e))

    def has_other_variants(self):
        """Whether the emitted match needs a wildcard arm."""
        if len(self._stone_type.all_fields) > 1:
            return True
        return not self._stone_type.closed

    def emit_asserts(self, codegen, expression_path):
        """Emit a Rust `match` asserting that the expected variant is present."""
        wrapped = expression_path[0] == '(' and expression_path[-1] == ')'
        if wrapped:
            # strip off superfluous parens
            expression_path = expression_path[1:-1]

        variant_path = (
            self._rust_namespace_name,
            self._rust_name,
            self._rust_variant_name,
        )
        payload_type = self._variant.data_type

        with codegen.block(u'match {}'.format(expression_path)):
            if ir.is_void_type(payload_type):
                codegen.emit(u'::dropbox_sdk::{}::{}::{} => (),'.format(*variant_path))
            elif codegen.is_nullary_struct(payload_type):
                codegen.emit(
                    u'::dropbox_sdk::{}::{}::{}(..) => (), // nullary struct'.format(*variant_path))
            else:
                arm = u'::dropbox_sdk::{}::{}::{}(ref v) =>'.format(*variant_path)
                with codegen.block(arm):
                    self._inner_value.emit_assert(codegen, '(*v)')

            if self.has_other_variants():
                codegen.emit(u'_ => panic!("wrong variant")')

    def is_serializable(self):
        """The catch-all variant cannot be serialized; every other one can."""
        return not self._variant.catch_all

    def test_suffix(self):
        """Name the test after the variant it covers."""
        return "_" + self._rust_variant_name


class TestPolymorphicStruct(TestUnion):
    """Test value for one enumerated subtype of a struct, reusing the union machinery."""

    def get_from_inner_value(self, variant_name, generated_field):
        # The generated inner value is used directly as this test's value;
        # unlike a union, no wrapper object is constructed around it.
        return generated_field.value

    def has_other_variants(self):
        """Whether the emitted match needs a wildcard arm."""
        subtypes = self._stone_type.get_enumerated_subtypes()
        if len(subtypes) > 1:
            return True
        return self._stone_type.is_catch_all()


class TestList(TestValue):
    """Test value for a list, represented by a single generated element."""

    def __init__(self, rust_generator, stone_type, reference_impls):
        super(TestList, self).__init__(rust_generator)
        self._stone_type = stone_type
        self._reference_impls = reference_impls

        # `stone_type` here is the element type, as passed by the caller.
        element = make_test_field(None, stone_type, rust_generator, reference_impls)
        self._inner_value = element
        if element is None:
            raise RuntimeError(u'Error generating value for list of {}'.format(stone_type.name))

        self.value = element.value

    def emit_asserts(self, codegen, expression_path):
        """Assert on the first element of the Rust list expression."""
        first_element = expression_path + '[0]'
        self._inner_value.emit_assert(codegen, first_element)


class TestMap(TestValue):
    """Test value for a map, represented by a single generated key/value entry."""

    def __init__(self, rust_generator, stone_type, reference_impls):
        super(TestMap, self).__init__(rust_generator)
        self._stone_type = stone_type
        self._reference_impls = reference_impls

        generated_key = make_test_field(
            None, stone_type.key_data_type, rust_generator, reference_impls)
        self._key_value = generated_key
        generated_val = make_test_field(
            None, stone_type.value_data_type, rust_generator, reference_impls)
        self._val_value = generated_val

        self.value = {generated_key.value: generated_val.value}

    def emit_asserts(self, codegen, expression_path):
        """Assert on the entry stored under the generated key."""
        # The key is emitted as a quoted Rust string index.
        entry_expression = expression_path + u'["{}"]'.format(self._key_value.value)
        self._val_value.emit_assert(codegen, entry_expression)


def make_test_field(field_name, stone_type, rust_generator, reference_impls):
    """Build a TestField holding a representative value for `stone_type`.

    :param field_name: Stone field name, or None for an unnamed inner value
        (e.g. a list element or union payload).
    :param stone_type: the Stone data type; a nullable wrapper is unwrapped
        and recorded on the returned field's ``option`` attribute.
    :param rust_generator: the backend, used for Rust naming helpers and
        passed on to nested test values.
    :param reference_impls: mapping from namespace name to the imported
        Python reference module.
    :returns: a TestField. For void types its value and nested test value
        are both None.
    :raises RuntimeError: if the type is not one of the handled kinds.
    """
    rust_name = rust_generator.field_name_raw(field_name) if field_name is not None else None
    typ, option = ir.unwrap_nullable(stone_type)

    inner = None
    value = None
    if ir.is_struct_type(typ):
        if typ.has_enumerated_subtypes():
            variant = typ.get_enumerated_subtypes()[0]
            inner = TestPolymorphicStruct(rust_generator, typ, reference_impls, variant)
        else:
            inner = TestStruct(rust_generator, typ, reference_impls)
        value = inner.value
    elif ir.is_union_type(typ):
        # Pick the first tag.
        # We could generate tests for them all, but it would lead to a huge explosion of tests, and
        # the types themselves are tested elsewhere.
        if typ.fields:
            variant = typ.fields[0]
        else:
            # there must be a parent type; go for it
            variant = typ.all_fields[0]
        inner = TestUnion(rust_generator, typ, reference_impls, variant)
        value = inner.value
    elif ir.is_list_type(typ):
        inner = TestList(rust_generator, typ.data_type, reference_impls)
        value = [inner.value]
    elif ir.is_map_type(typ):
        inner = TestMap(rust_generator, typ, reference_impls)
        value = inner.value
    elif ir.is_string_type(typ):
        if typ.pattern:
            value = Unregex(typ.pattern, typ.min_length).generate()
        elif typ.min_length:
            value = 'a' * typ.min_length
        else:
            value = 'something'
    elif ir.is_numeric_type(typ):
        # Use the largest allowed value. Compare against None explicitly:
        # a declared bound of 0 is falsy but must still be honored, otherwise
        # the generated value would exceed the declared maximum.
        if typ.max_value is not None:
            value = typ.max_value
        elif typ.maximum is not None:
            value = typ.maximum
        else:
            value = 1e307
    elif ir.is_boolean_type(typ):
        value = True
    elif ir.is_timestamp_type(typ):
        # Naive UTC datetime for Unix time 2**33 - 1. Computed from the epoch
        # rather than with the deprecated datetime.utcfromtimestamp(); the
        # resulting value is identical.
        value = datetime.datetime(1970, 1, 1) + datetime.timedelta(seconds=2**33 - 1)
    elif ir.is_bytes_type(typ):
        value = bytes([0, 1, 2, 3, 4, 5])
    elif not ir.is_void_type(typ):
        raise RuntimeError(u'Error: unhandled field type of {}: {}'.format(field_name, typ))
    return TestField(rust_name, value, inner, typ, option)


class Unregex(object):
    """
    Generate a minimal string that passes a regex and optionally is of a given
    minimum length.

    The pattern is parsed with the standard library's regex parser and the
    resulting token tree is walked, emitting one acceptable string. Only the
    constructs handled in ``_generate`` are supported; anything else raises
    NotImplementedError. Look-ahead assertions are not evaluated.
    """

    # Candidate characters for "anything except ..." constructs. Kept sorted
    # so the generated string is the same on every run (iterating a set of
    # strings is not stable across interpreter invocations).
    _CANDIDATES = sorted(set(string.printable).difference(string.whitespace))

    # A sample character for each supported character category.
    _CATEGORY_SAMPLES = {
        'category_digit': '0',
        'category_not_digit': 'a',
        'category_space': ' ',
        'category_not_space': '!',
        'category_word': 'a',
        'category_not_word': '!',
    }

    # Membership tests for the positive categories; the ``not_`` forms are
    # derived by negating these.
    _CATEGORY_TESTS = {
        'category_digit': lambda c: c.isdigit(),
        'category_space': lambda c: c.isspace(),
        'category_word': lambda c: c.isalnum() or c == '_',
    }

    def __init__(self, regex_string, min_len=None):
        """
        :param regex_string: the regular expression to satisfy.
        :param min_len: optional minimum length; used as the repeat count for
            repetition constructs (clamped to the construct's own bounds).
        """
        # The parser lives in ``re._parser`` on Python 3.11+, where
        # ``re.sre_parse`` no longer exists; older versions expose it as the
        # top-level ``sre_parse`` module.
        try:
            from re import _parser as regex_parser
        except ImportError:
            import sre_parse as regex_parser

        self._min_len = min_len
        self._group_refs = {}
        self._tokens = regex_parser.parse(regex_string)

    def generate(self):
        """Return a string generated from the parsed pattern."""
        return self._generate(self._tokens)

    @staticmethod
    def _name(constant):
        # Parser opcodes are named integer constants on Python 3, so they
        # must be converted to their lower-case name before being compared
        # with strings. Plain strings pass through unchanged.
        return str(getattr(constant, 'name', constant)).lower()

    def _category_matches(self, category, char):
        # Whether `char` belongs to the named character category.
        negated = category.startswith('category_not_')
        base = category.replace('category_not_', 'category_', 1) if negated else category
        test = self._CATEGORY_TESTS.get(base)
        if test is None:
            raise NotImplementedError('category {}'.format(category))
        return test(char) != negated

    def _item_matches(self, opcode, argument, char):
        # Whether a single character-set item matches `char`.
        opcode = self._name(opcode)
        if opcode == 'literal':
            return ord(char) == argument
        if opcode == 'range':
            low, high = argument
            return low <= ord(char) <= high  # both range ends are inclusive
        if opcode == 'category':
            return self._category_matches(self._name(argument), char)
        return False

    def _pick_outside(self, items):
        # First candidate character that none of the set items match.
        for candidate in self._CANDIDATES:
            if not any(self._item_matches(op, arg, candidate) for op, arg in items):
                return candidate
        raise NotImplementedError(
            'no printable character is outside the negated character set')

    def _generate(self, tokens):
        result = ''
        for (opcode, argument) in tokens:
            opcode = self._name(opcode)
            if opcode == 'literal':
                result += chr(argument)
            elif opcode == 'not_literal':
                # The parser turns a single-character negated set such as
                # "[^/]" into this opcode.
                result += self._pick_outside([('literal', argument)])
            elif opcode == 'at':
                pass  # start or end anchor; nothing to add
            elif opcode == 'in':
                if self._name(argument[0][0]) == 'negate':
                    result += self._pick_outside(argument[1:])
                else:
                    # Any member will do; take the first item of the set.
                    result += self._generate([argument[0]])
            elif opcode == 'any':
                result += '*'
            elif opcode == 'range':
                result += chr(argument[0])
            elif opcode == 'branch':
                # Take the first alternative.
                result += self._generate(argument[1][0])
            elif opcode == 'subpattern':
                group_number, add_flags, del_flags, sub_tokens = argument
                sub_result = self._generate(sub_tokens)
                # Remember what the group produced for later back-references.
                self._group_refs[group_number] = sub_result
                result += sub_result
            elif opcode == 'groupref':
                result += self._group_refs[argument]
            elif opcode == 'min_repeat' or opcode == 'max_repeat':
                min_repeat, max_repeat, sub_tokens = argument
                if self._min_len:
                    n = max(min_repeat, min(self._min_len, max_repeat))
                else:
                    n = min_repeat
                sub_result = self._generate(sub_tokens) if n != 0 else ''
                result += str(sub_result) * n
            elif opcode == 'category':
                category = self._name(argument)
                if category not in self._CATEGORY_SAMPLES:
                    raise NotImplementedError('category {}'.format(category))
                result += self._CATEGORY_SAMPLES[category]
            elif opcode == 'assert_not':
                # let's just hope for the best...
                pass
            elif opcode == 'assert' or opcode == 'negate':
                # note: 'negate' is handled in the 'in' opcode
                raise NotImplementedError('regex opcode {} not implemented'.format(opcode))
            else:
                raise NotImplementedError('unknown regex opcode: {}'.format(opcode))
        return result
