#include <stdlib.h>
#include <stdio.h>
#include <string.h>

#include "pns.h"

/* MACROS */

/* Aborts the process with a diagnostic when `result` is a null pointer.
 * Used after every allocation. The do/while(0) wrapper makes it a single
 * statement; every call site terminates it with a semicolon. */
#define CHECK_MEMORY_ERROR(result) do {\
    if (!(result)) {\
        fprintf(stderr, "Out of memory\n");\
        exit(EXIT_FAILURE);\
    }\
} while (0)

/* Grows the dynamic array `value` (capacity `size`) to about 1.2x `count`
 * when `count` exceeds the capacity. Newly gained slots are NOT initialised.
 * Expands to a complete `if` statement because call sites use it without a
 * trailing semicolon. */
#define RESIZE_ARRAY(count, size, value) if ((count) > (size)) {\
    (size) = (count) + (count) / 5;\
    (value) = realloc((value), (size) * sizeof(*(value)));\
    CHECK_MEMORY_ERROR(value);\
}

/* Reserves one more slot at the end of the dynamic array `value`: stores the
 * new slot's index in `name`, bumps `count` and grows the buffer if needed.
 * Expands to a braced block; call sites use it without a trailing semicolon
 * (including directly before an `else`). */
#define ADD_ELEMENT(name, count, size, value) {\
    (name) = (count)++;\
    RESIZE_ARRAY(count, size, value)\
}

/* UNEXPORTED */

/** Node **/

/* Resets `node` to an unconnected node: no predecessor or successor arcs and
 * no owned buffers. Does not free any buffers the node previously held. */
void pnsCreateNode(PnsNode* node) {
    node->prev = NULL;
    node->next = NULL;
    node->prev_count = 0;
    node->next_count = 0;
    node->prev_size = 0;
    node->next_size = 0;
}

/* Deep-copies `node` into `node_clone`. The clone gets its own prev/next
 * buffers with the same counts and capacities. Buffers previously held by
 * `node_clone` are NOT freed.
 *
 * Zero-capacity arrays are cloned as NULL instead of via malloc(0), which may
 * legitimately return NULL and would otherwise be reported as out of memory.
 * memcpy is skipped for empty arrays because a NULL source pointer is
 * undefined behaviour even for a zero-byte copy (fresh nodes hold NULL). */
void pnsCloneNode(PnsNode* node_clone, const PnsNode* node) {
    node_clone->prev_count = node->prev_count;
    node_clone->prev_size = node->prev_size;
    node_clone->prev = NULL;
    if (node->prev_size > 0) {
        node_clone->prev = malloc(node->prev_size * sizeof(*node_clone->prev));
        CHECK_MEMORY_ERROR(node_clone->prev);
        if (node->prev_count > 0)
            memcpy(node_clone->prev, node->prev, node->prev_count * sizeof(*node_clone->prev));
    }

    node_clone->next_count = node->next_count;
    node_clone->next_size = node->next_size;
    node_clone->next = NULL;
    if (node->next_size > 0) {
        node_clone->next = malloc(node->next_size * sizeof(*node_clone->next));
        CHECK_MEMORY_ERROR(node_clone->next);
        if (node->next_count > 0)
            memcpy(node_clone->next, node->next, node->next_count * sizeof(*node_clone->next));
    }
}

/* Releases the node's arc buffers and leaves it in the empty state produced
 * by pnsCreateNode, so deleting twice is harmless. */
void pnsDeleteNode(PnsNode* node) {
    free(node->next);
    free(node->prev);
    pnsCreateNode(node);
}

/** TransitionList **/

/* Prepends a new cell holding `index` to the list whose head pointer is
 * *place. Aborts on allocation failure. */
void PnsIndexList_push(PnsIndexList** place, uint32_t index) {
    PnsIndexList* head = malloc(sizeof(*head));
    CHECK_MEMORY_ERROR(head);
    head->next = *place;
    head->index = index;
    *place = head;
}

/* Appends a deep copy of `list` (same order) at *clone_place. When `list` is
 * empty, *clone_place is left untouched, so callers initialise it first. */
void pnsCloneTransitionList(PnsIndexList** clone_place, PnsIndexList* list) {
    PnsIndexList** tail = clone_place;
    for (const PnsIndexList* source = list; source != NULL; source = source->next) {
        PnsIndexList* copy = malloc(sizeof(*copy));
        CHECK_MEMORY_ERROR(copy);
        copy->index = source->index;
        copy->next = NULL;
        *tail = copy;
        tail = &copy->next;
    }
}

/* Removes the head cell of the list at *place and returns its index.
 * Precondition: the list is non-empty (the head is dereferenced unchecked). */
uint32_t PnsIndexList_pop(PnsIndexList** place) {
    PnsIndexList* head = *place;
    const uint32_t index = head->index;
    *place = head->next;
    free(head);
    return index;
}

/* Appends `tid` at the tail of the list at *place unless it is already
 * present. Returns true when a cell was added and false when `tid` was
 * already in the list. Linear in the list length. */
bool PnsIndexList_appendNew(PnsIndexList** place, uint32_t tid) {
    PnsIndexList** cursor = place;
    while (*cursor != NULL) {
        if ((*cursor)->index == tid) return false;
        cursor = &(*cursor)->next;
    }
    PnsIndexList_push(cursor, tid);
    return true;
}

/* Unlinks and frees the first cell holding `tid` from the list at *place.
 * Returns true when a cell was removed and false when `tid` was absent. */
bool PnsIndexList_remove(PnsIndexList** place, uint32_t tid) {
    for (PnsIndexList** link = place; *link != NULL; link = &(*link)->next) {
        if ((*link)->index != tid) continue;
        PnsIndexList* doomed = *link;
        *link = doomed->next;
        free(doomed);
        return true;
    }
    return false;
}

/* Frees every cell of `list`. The caller's head pointer is left dangling;
 * use pnsClearTransitionList to also reset it. */
void pnsDestroyTransitionList(PnsIndexList* list) {
    while (list != NULL) {
        PnsIndexList* following = list->next;
        free(list);
        list = following;
    }
}

/* Frees the whole list at *list and resets the head pointer to NULL. */
void pnsClearTransitionList(PnsIndexList** list) {
    PnsIndexList* old_head = *list;
    *list = NULL;
    pnsDestroyTransitionList(old_head);
}

/** Fire Changes **/

/* Initialises an empty change tracker:
 * - `active` holds the currently enabled transitions;
 * - `added` / `removed` hold the transitions that entered / left `active`
 *   since the last clean;
 * - each list has a matching counter. */
void pnsCreateFireChanges(PnsFireChanges* changes) {
    changes->active = NULL;
    changes->count = 0;
    changes->added = NULL;
    changes->added_count = 0;
    changes->removed = NULL;
    changes->removed_count = 0;
}

/* Frees the three lists owned by `changes`. The fields themselves are not
 * reset. */
void pnsDestroyFireChanges(PnsFireChanges* changes) {
    PnsIndexList* owned[] = { changes->active, changes->added, changes->removed };
    for (size_t i = 0; i < sizeof(owned) / sizeof(owned[0]); ++i)
        pnsDestroyTransitionList(owned[i]);
}

/* Marks `tid` as active. If it was newly activated, the change is recorded:
 * a pending "removed" entry is cancelled, otherwise an "added" entry is
 * pushed. Nothing happens when `tid` is already active. */
void pnsFireChanges_add(PnsFireChanges* changes, uint32_t tid) {
    if (!PnsIndexList_appendNew(&changes->active, tid)) return;
    ++changes->count;

    if (PnsIndexList_remove(&changes->removed, tid)) {
        /* it was removed and re-added since the last clean: net no change */
        --changes->removed_count;
        return;
    }
    ++changes->added_count;
    PnsIndexList_push(&changes->added, tid);
}
/* Marks `tid` as inactive. If it was active, the change is recorded: a
 * pending "added" entry is cancelled, otherwise a "removed" entry is pushed.
 * Nothing happens when `tid` was not active. */
void pnsFireChanges_remove(PnsFireChanges* changes, uint32_t tid) {
    if (!PnsIndexList_remove(&changes->active, tid)) return;
    --changes->count;

    if (PnsIndexList_remove(&changes->added, tid)) {
        /* it was added and removed again since the last clean: net no change */
        --changes->added_count;
        return;
    }
    ++changes->removed_count;
    PnsIndexList_push(&changes->removed, tid);
}

/** helpers **/

/*** Array ***/

/* Insertion step for an ascending array: walks ids[0..count), and wherever an
 * element exceeds the value being carried it is swapped with it. Afterwards
 * `id` sits in its sorted position inside the first `count` slots. Returns
 * the displaced largest value, which the caller stores in slot `count`. */
uint32_t sortIndex(uint32_t id, uint32_t count, uint32_t* ids) {
    uint32_t carried = id;
    for (uint32_t slot = 0; slot != count; ++slot) {
        if (ids[slot] <= carried) continue;
        const uint32_t displaced = ids[slot];
        ids[slot] = carried;
        carried = displaced;
    }
    return carried;
}

/* Removes one occurrence of `id` from ids[0..*count), shifting later
 * elements left to keep the order, and decrements *count.
 * - The last element is checked first: if it matches, only the count shrinks.
 * - Otherwise the first match among the preceding elements is removed.
 * - When `id` is absent, nothing changes.
 * The slot past the new end is not cleared. */
void removeIndex(uint32_t id, uint32_t* count, uint32_t* ids) {
    const uint32_t total = *count;
    if (total == 0) return;

    const uint32_t last = total - 1;
    if (ids[last] != id) {
        uint32_t pos = 0;
        while (pos < last && ids[pos] != id) ++pos;
        if (pos == last) return; /* not present: leave *count untouched */
        memmove(&ids[pos], &ids[pos + 1], (last - pos) * sizeof(*ids));
    }
    *count = last;
}

/*** token calculations ***/

/* Current number of tokens in place `pid`: the net's initial marking plus the
 * state's accumulated change for that place (unsigned, wrapping arithmetic). */
uint32_t token_count(const PnsState* state, const PnsNet* net, uint32_t pid) {
    const uint32_t initial = net->initial_token_counts[pid];
    return initial + state->token_counts[pid];
}

/* Records `tid` as forward-fireable in `state->fire`. */
void addFireable(PnsState* state, uint32_t tid) {
    PnsFireChanges* forward = &state->fire;
    pnsFireChanges_add(forward, tid);
}

/* Records `tid` as no longer forward-fireable in `state->fire`. */
void removeFireable(PnsState* state, uint32_t tid) {
    PnsFireChanges* forward = &state->fire;
    pnsFireChanges_remove(forward, tid);
}

/* Records `tid` as backward-fireable (un-fireable) in `state->unfire`. */
void addUnfireable(PnsState* state, uint32_t tid) {
    PnsFireChanges* backward = &state->unfire;
    pnsFireChanges_add(backward, tid);
}

/* Records `tid` as no longer backward-fireable in `state->unfire`. */
void removeUnfireable(PnsState* state, uint32_t tid) {
    PnsFireChanges* backward = &state->unfire;
    pnsFireChanges_remove(backward, tid);
}

/* Re-evaluates whether transition `tid` can fire forward and updates
 * state->fire. It can fire when every input place (transition->prev) holds a
 * nonzero token count. */
void recalculate_transition_forward(PnsState* state, const PnsNet* net, uint32_t tid) {
    const PnsNode* transition = &net->transitions[tid];

    bool enabled = true;
    for (uint32_t k = 0; enabled && k < transition->prev_count; ++k)
        enabled = token_count(state, net, transition->prev[k]) != 0;

    if (!enabled) {
        removeFireable(state, tid);
        return;
    }
    addFireable(state, tid);
}

/* Re-evaluates whether transition `tid` can be un-fired and updates
 * state->unfire. It qualifies when it has a nonzero call count and every
 * output place (transition->next) holds a nonzero token count. */
void recalculate_transition_backward(PnsState* state, const PnsNet* net, uint32_t tid) {
    const PnsNode* transition = &net->transitions[tid];

    bool enabled = state->call_counts[tid] != 0;
    for (uint32_t k = 0; enabled && k < transition->next_count; ++k)
        enabled = token_count(state, net, transition->next[k]) != 0;

    if (!enabled) {
        removeUnfireable(state, tid);
        return;
    }
    addUnfireable(state, tid);
}

/* Builds the fire / unfire lists from scratch (called by pnsFinalizeState
 * right after both trackers are created empty).
 * - Each transition whose input places all hold tokens is pushed onto
 *   fire.active and fire.added.
 * - Each transition with a nonzero call count whose output places all hold
 *   tokens is pushed onto unfire.active and unfire.added.
 * Pushes do not check for duplicates. */
void calculate_transition_list(PnsState* state, const PnsNet* net) {
    for (uint32_t tid = 0; tid < net->transition_count; ++tid) {
        const PnsNode* node = &net->transitions[tid];

        bool can_fire = true;
        for (uint32_t k = 0; can_fire && k < node->prev_count; ++k)
            can_fire = token_count(state, net, node->prev[k]) != 0;

        if (can_fire) {
            state->fire.count += 1;
            PnsIndexList_push(&state->fire.active, tid);
            state->fire.added_count += 1;
            PnsIndexList_push(&state->fire.added, tid);
        }

        /* a transition that has never fired cannot be un-fired */
        if (state->call_counts[tid] == 0) continue;

        bool can_unfire = true;
        for (uint32_t k = 0; can_unfire && k < node->next_count; ++k)
            can_unfire = token_count(state, net, node->next[k]) != 0;

        if (can_unfire) {
            state->unfire.count += 1;
            PnsIndexList_push(&state->unfire.active, tid);
            state->unfire.added_count += 1;
            PnsIndexList_push(&state->unfire.added, tid);
        }
    }
}

/* HEADER */

/** Net **/

/*** default ***/

/* Initialises `net` as an empty net: no places, no transitions, no reusable
 * slots and no pending edits. */
void pnsCreateNet(PnsNet* net) {
    net->places = NULL;
    net->place_count = 0;
    net->places_size = 0;
    net->initial_token_counts = NULL;
    net->reusable_places = NULL;

    net->transitions = NULL;
    net->transition_count = 0;
    net->transitions_size = 0;
    net->reusable_transitions = NULL;

    net->dirt = NULL;
    net->reverseDirt = NULL;
}

/* Deep-copies `net` into `net_clone`: nodes, initial marking and the reusable
 * slot lists. The clone's capacities are trimmed to the current counts and
 * its pending-edit lists (dirt / reverseDirt) start empty.
 *
 * Empty arrays are left NULL instead of going through malloc(0), whose NULL
 * result would be misreported as out of memory and would feed a NULL pointer
 * to memcpy. */
void pnsCloneNet(PnsNet* net_clone, const PnsNet* net) {
    net_clone->transition_count = net->transition_count;
    net_clone->transitions_size = net->transition_count;
    net_clone->transitions = NULL;
    if (net->transition_count > 0) {
        net_clone->transitions = malloc(net->transition_count * sizeof(*net_clone->transitions));
        CHECK_MEMORY_ERROR(net_clone->transitions);
    }
    for (uint32_t tid = 0; tid < net->transition_count; ++tid)
        pnsCloneNode(&net_clone->transitions[tid], &net->transitions[tid]);

    net_clone->place_count = net->place_count;
    net_clone->places_size = net->place_count;
    net_clone->places = NULL;
    net_clone->initial_token_counts = NULL;
    if (net->place_count > 0) {
        net_clone->places = malloc(net->place_count * sizeof(*net_clone->places));
        CHECK_MEMORY_ERROR(net_clone->places);
        net_clone->initial_token_counts = malloc(net->place_count * sizeof(*net_clone->initial_token_counts));
        CHECK_MEMORY_ERROR(net_clone->initial_token_counts);
        memcpy(net_clone->initial_token_counts, net->initial_token_counts,
               net->place_count * sizeof(*net_clone->initial_token_counts));
    }
    for (uint32_t pid = 0; pid < net->place_count; ++pid)
        pnsCloneNode(&net_clone->places[pid], &net->places[pid]);

    net_clone->reusable_transitions = NULL;
    pnsCloneTransitionList(&net_clone->reusable_transitions, net->reusable_transitions);
    net_clone->reusable_places = NULL;
    pnsCloneTransitionList(&net_clone->reusable_places, net->reusable_places);

    net_clone->dirt = NULL;
    net_clone->reverseDirt = NULL;
}

/* Allocates `n` zero-initialised elements of `size` bytes. A zero-length
 * request yields NULL instead of calling calloc(0, ...), whose NULL result
 * would be indistinguishable from out of memory. calloc also rejects
 * n * size overflow. */
static void* pnsLoadNet_alloc(uint32_t n, size_t size) {
    if (n == 0) return NULL;
    void* buffer = calloc(n, size);
    CHECK_MEMORY_ERROR(buffer);
    return buffer;
}

/* Reads one length-prefixed list of place ids starting at values[*index].
 * Stores its length in *ids_count and a freshly allocated copy in *ids, and
 * advances *index past it. Returns false when the input is truncated or an
 * id is not below `limit`; whatever was allocated is still stored in *ids so
 * the caller can free it. */
static bool pnsLoadNet_readIds(uint32_t count, const uint32_t* values, uint32_t* index,
                               uint32_t limit, uint32_t* ids_count, uint32_t** ids) {
    if (*index >= count) return false;
    uint32_t n = values[(*index)++];
    if (n > count - *index) return false;
    *ids_count = n;
    *ids = pnsLoadNet_alloc(n, sizeof(**ids));
    for (uint32_t i = 0; i < n; ++i) {
        uint32_t id = values[(*index)++];
        if (id >= limit) return false;
        (*ids)[i] = id;
    }
    return true;
}

/* Derives every place's arc arrays from the already validated transition
 * arrays:
 * - place->next lists the transitions consuming from the place;
 * - place->prev lists the transitions producing into it;
 * - both are in ascending transition order.
 * Places with no arcs and no initial tokens are pushed onto
 * net->reusable_places. */
static void pnsLoadNet_linkPlaces(PnsNet* net) {
    for (uint32_t pid = 0; pid < net->place_count; ++pid)
        pnsCreateNode(&net->places[pid]);

    /* Pass 1: count the arcs touching each place. */
    for (uint32_t tid = 0; tid < net->transition_count; ++tid) {
        const PnsNode* transition = &net->transitions[tid];
        for (uint32_t i = 0; i < transition->next_count; ++i)
            ++net->places[transition->next[i]].prev_count;
        for (uint32_t i = 0; i < transition->prev_count; ++i)
            ++net->places[transition->prev[i]].next_count;
    }

    /* Allocate exactly-sized arrays, then reuse the counts as fill cursors. */
    for (uint32_t pid = 0; pid < net->place_count; ++pid) {
        PnsNode* place = &net->places[pid];
        place->next_size = place->next_count;
        place->next = pnsLoadNet_alloc(place->next_count, sizeof(*place->next));
        place->next_count = 0;
        place->prev_size = place->prev_count;
        place->prev = pnsLoadNet_alloc(place->prev_count, sizeof(*place->prev));
        place->prev_count = 0;
    }

    /* Pass 2: visiting transitions in ascending order keeps each place's
     * arrays ascending, matching the sorted insertion done by sortIndex and
     * the connect functions. */
    for (uint32_t tid = 0; tid < net->transition_count; ++tid) {
        const PnsNode* transition = &net->transitions[tid];
        for (uint32_t i = 0; i < transition->next_count; ++i) {
            PnsNode* place = &net->places[transition->next[i]];
            place->prev[place->prev_count++] = tid;
        }
        for (uint32_t i = 0; i < transition->prev_count; ++i) {
            PnsNode* place = &net->places[transition->prev[i]];
            place->next[place->next_count++] = tid;
        }
    }

    for (uint32_t pid = 0; pid < net->place_count; ++pid) {
        const PnsNode* place = &net->places[pid];
        if (place->next_count == 0 && place->prev_count == 0 && net->initial_token_counts[pid] == 0)
            PnsIndexList_push(&net->reusable_places, pid);
    }
}

/* Loads `net` from the flat layout written by pnsNet_serialize:
 *   place_count, initial_token_counts[place_count],
 *   transition_count, then for every transition:
 *     next_count, next[next_count], prev_count, prev[prev_count]
 * Place arcs are rebuilt from the transitions' arrays. Transitions without
 * arcs are pushed onto net->reusable_transitions. Trailing values are ignored.
 *
 * Returns false when `count` is too small for the declared sizes or a place
 * id is out of range. In that case nothing is leaked and `net` is left empty
 * (as after pnsCreateNet); for count < 2 `net` is not touched at all.
 *
 * NOTE(review): transition id lists are stored in input order and are not
 * re-sorted here, although the connect functions assume ascending order. */
bool pnsLoadNet(PnsNet* net, uint32_t count, const uint32_t* values) {
    uint32_t index = 0;
    uint32_t place_count = 0;
    uint32_t transition_count = 0;
    uint32_t started = 0; /* transitions initialised so far (freed on failure) */

    if (count < 2) return false;
    pnsCreateNet(net);

    place_count = values[index++];
    /* The token counts plus the transition count must fit. Comparing against
     * the remaining length avoids the overflow of a running required total. */
    if (place_count > count - 2) goto fail;
    net->place_count = place_count;
    net->places_size = place_count;
    net->places = pnsLoadNet_alloc(place_count, sizeof(*net->places));
    net->initial_token_counts = pnsLoadNet_alloc(place_count, sizeof(*net->initial_token_counts));
    for (uint32_t pid = 0; pid < place_count; ++pid)
        net->initial_token_counts[pid] = values[index++];

    transition_count = values[index++];
    /* each transition needs at least its two length fields */
    if (transition_count > (count - index) / 2) goto fail;
    net->transition_count = transition_count;
    net->transitions_size = transition_count;
    net->transitions = pnsLoadNet_alloc(transition_count, sizeof(*net->transitions));
    for (uint32_t tid = 0; tid < transition_count; ++tid) {
        PnsNode* transition = &net->transitions[tid];
        pnsCreateNode(transition);
        ++started;
        if (!pnsLoadNet_readIds(count, values, &index, place_count,
                                &transition->next_count, &transition->next))
            goto fail;
        transition->next_size = transition->next_count;
        if (!pnsLoadNet_readIds(count, values, &index, place_count,
                                &transition->prev_count, &transition->prev))
            goto fail;
        transition->prev_size = transition->prev_count;
        if (transition->next_count == 0 && transition->prev_count == 0)
            PnsIndexList_push(&net->reusable_transitions, tid);
    }

    pnsLoadNet_linkPlaces(net);
    return true;

fail:
    for (uint32_t tid = 0; tid < started; ++tid)
        pnsDeleteNode(&net->transitions[tid]);
    free(net->transitions);
    free(net->places);
    free(net->initial_token_counts);
    pnsDestroyTransitionList(net->reusable_transitions);
    pnsCreateNet(net);
    return false;
}

/* Frees everything owned by `net`: every node's arc buffers, the node and
 * marking arrays, the reusable-slot lists and the pending-edit lists. The
 * struct's fields are not reset. */
void pnsDestroyNet(PnsNet* net) {
    for (uint32_t pid = 0; pid < net->place_count; ++pid)
        pnsDeleteNode(&net->places[pid]);
    for (uint32_t tid = 0; tid < net->transition_count; ++tid)
        pnsDeleteNode(&net->transitions[tid]);

    free(net->places);
    free(net->transitions);
    free(net->initial_token_counts);

    pnsDestroyTransitionList(net->reusable_places);
    pnsDestroyTransitionList(net->reusable_transitions);
    pnsDestroyTransitionList(net->reverseDirt);
    pnsDestroyTransitionList(net->dirt);
}

/* Number of uint32_t values pnsNet_serialize writes for `net`:
 * - 2 header counts;
 * - one initial token count per place;
 * - per transition, two length fields plus its arc ids. */
uint32_t pnsNet_serializeSize(const PnsNet* net) {
    uint32_t total = 2 + net->place_count;
    for (uint32_t tid = 0; tid < net->transition_count; ++tid) {
        const PnsNode* t = &net->transitions[tid];
        total += 2 + t->next_count + t->prev_count;
    }
    return total;
}

/* Writes `net` into `values` in the flat layout read by pnsLoadNet:
 *   place_count, initial_token_counts[place_count],
 *   transition_count, then per transition:
 *     next_count, next[...], prev_count, prev[...]
 * `values` must hold pnsNet_serializeSize(net) entries. Place arcs are not
 * written; they are derived from the transitions when loading. */
void pnsNet_serialize(const PnsNet* net, uint32_t* values) {
    uint32_t* out = values;

    *out++ = net->place_count;
    for (uint32_t pid = 0; pid < net->place_count; ++pid)
        *out++ = net->initial_token_counts[pid];

    *out++ = net->transition_count;
    for (uint32_t tid = 0; tid < net->transition_count; ++tid) {
        const PnsNode* t = &net->transitions[tid];
        *out++ = t->next_count;
        for (uint32_t i = 0; i < t->next_count; ++i) *out++ = t->next[i];
        *out++ = t->prev_count;
        for (uint32_t i = 0; i < t->prev_count; ++i) *out++ = t->prev[i];
    }
}

/*** edit ***/

/* Queues every transition whose fireability depends on place `pid`:
 * - consumers (place->next) go onto net->dirt for forward re-evaluation;
 * - producers (place->prev) go onto net->reverseDirt for backward
 *   re-evaluation.
 * Duplicates are suppressed by PnsIndexList_appendNew. */
void pnsNet_dirtyPlace(PnsNet* net, uint32_t pid) {
    const PnsNode* place = &net->places[pid];
    for (uint32_t k = 0; k < place->next_count; ++k)
        PnsIndexList_appendNew(&net->dirt, place->next[k]);
    for (uint32_t k = 0; k < place->prev_count; ++k)
        PnsIndexList_appendNew(&net->reverseDirt, place->prev[k]);
}

/* Returns a free place slot.
 * - Entries popped from net->reusable_places are used only if the slot is
 *   still unused (no arcs, no initial tokens); stale entries are discarded.
 * - When none remain, a fresh slot is appended, growing `places` and
 *   `initial_token_counts` together.
 * The returned slot is not initialised. The growth previously sized
 * `initial_token_counts` by sizeof(PnsNode) instead of its own element
 * size. */
uint32_t get_pid(PnsNet* net) {
    for (;;) {
        if (net->reusable_places == NULL) {
            uint32_t pid = net->place_count++;
            if (net->place_count > net->places_size) {
                net->places_size = net->place_count + net->place_count / 5;
                net->places = realloc(net->places, net->places_size * sizeof(*net->places));
                CHECK_MEMORY_ERROR(net->places);
                net->initial_token_counts = realloc(net->initial_token_counts,
                                                    net->places_size * sizeof(*net->initial_token_counts));
                CHECK_MEMORY_ERROR(net->initial_token_counts);
            }
            return pid;
        }

        uint32_t pid = PnsIndexList_pop(&net->reusable_places);
        const PnsNode* place = &net->places[pid];
        if (place->next_count == 0 && place->prev_count == 0 && net->initial_token_counts[pid] == 0)
            return pid;
        /* slot was edited after being listed as reusable; try the next one */
    }
}

/* Returns a free transition slot.
 * - Entries popped from net->reusable_transitions are used only if the
 *   transition still has no arcs; stale entries are discarded.
 * - Otherwise a fresh slot is appended, growing `transitions` if needed.
 * The returned slot is not initialised. */
uint32_t get_tid(PnsNet* net) {
    for (;;) {
        if (net->reusable_transitions == NULL) {
            uint32_t fresh;
            ADD_ELEMENT(fresh, net->transition_count, net->transitions_size, net->transitions)
            return fresh;
        }
        uint32_t candidate = PnsIndexList_pop(&net->reusable_transitions);
        const PnsNode* slot = &net->transitions[candidate];
        if (slot->next_count == 0 && slot->prev_count == 0) return candidate;
    }
}

/* Adds an unconnected place with no initial tokens and returns its id, which
 * may be a reused slot. */
uint32_t pnsNet_addPlace(PnsNet* net) {
    const uint32_t pid = get_pid(net);
    net->initial_token_counts[pid] = 0;
    pnsCreateNode(&net->places[pid]);
    return pid;
}

/* Adds an unconnected transition and returns its id, which may be a reused
 * slot. The transition is queued on net->dirt for forward re-evaluation. */
uint32_t pnsNet_addTransition(PnsNet* net) {
    const uint32_t tid = get_tid(net);
    pnsCreateNode(&net->transitions[tid]);
    PnsIndexList_appendNew(&net->dirt, tid);
    return tid;
}

/* Adds a transition consuming from each of the `place_count` places in
 * `pids` and returns its id. Results of the individual connections (false
 * for a repeated pid) are ignored. */
uint32_t pnsNet_addConnectedTransition(PnsNet* net, uint32_t place_count, uint32_t* pids) {
    const uint32_t tid = pnsNet_addTransition(net);
    const uint32_t* cursor = pids;
    uint32_t remaining = place_count;
    while (remaining > 0) {
        pnsNet_connectIn_unsafe(net, tid, *cursor);
        ++cursor;
        --remaining;
    }
    return tid;
}

/* Detaches node `id` of `nodes` from every neighbour in `link_nodes`, then
 * frees its arc buffers.
 * - Upstream neighbours (node->prev) lose `id` from their `next` arrays.
 * - Downstream neighbours (node->next) lose it from their `prev` arrays. */
void removeNode(uint32_t id, PnsNode* nodes, PnsNode* link_nodes) {
    PnsNode* node = &nodes[id];

    for (uint32_t k = 0; k < node->prev_count; ++k) {
        PnsNode* upstream = &link_nodes[node->prev[k]];
        removeIndex(id, &upstream->next_count, upstream->next);
    }
    for (uint32_t k = 0; k < node->next_count; ++k) {
        PnsNode* downstream = &link_nodes[node->next[k]];
        removeIndex(id, &downstream->prev_count, downstream->prev);
    }

    pnsDeleteNode(node);
}

/* Removes place `pid`: drops all its arcs, clears its initial tokens and
 * lists the slot as reusable.
 *
 * Like the other non-`_unsafe` edits (connectOut, disconnectIn), this records
 * the transitions whose fireability changes. Consumers lose an input and
 * producers lose an output. The dirt must be recorded before removeNode
 * clears the place's arc arrays. */
void pnsNet_removePlace(PnsNet* net, uint32_t pid) {
    pnsNet_dirtyPlace(net, pid);
    removeNode(pid, net->places, net->transitions);
    net->initial_token_counts[pid] = 0;
    PnsIndexList_appendNew(&net->reusable_places, pid);
}

/* Removes transition `tid` and lists the slot as reusable. Does not touch
 * net->dirt / net->reverseDirt. */
void pnsNet_removeTransition_unsafe(PnsNet* net, uint32_t tid) {
    PnsIndexList_appendNew(&net->reusable_transitions, tid);
    removeNode(tid, net->transitions, net->places);
}

/* Adds an arc from place `pid` into transition `tid`, inserting into
 * transition->prev and place->next so both stay ascending. Returns false
 * (before modifying anything, for a sorted array) when the arc already
 * exists. Does not record any dirt. */
bool pnsNet_connectIn_unsafe(PnsNet* net, uint32_t tid, uint32_t pid) {
    PnsNode* transition = &net->transitions[tid];
    PnsNode* place = &net->places[pid];

    /* Insertion pass: larger entries are shifted right by carrying them
     * along; whatever is carried at the end is appended below. */
    for (uint32_t k = 0; k < transition->prev_count; ++k) {
        const uint32_t existing = transition->prev[k];
        if (existing == pid) return false;
        if (existing > pid) {
            transition->prev[k] = pid;
            pid = existing;
        }
    }

    for (uint32_t k = 0; k < place->next_count; ++k) {
        const uint32_t existing = place->next[k];
        if (existing >= tid) {
            place->next[k] = tid;
            tid = existing;
        }
    }

    uint32_t prev_slot;
    ADD_ELEMENT(prev_slot, transition->prev_count, transition->prev_size, transition->prev)
    transition->prev[prev_slot] = pid;

    uint32_t next_slot;
    ADD_ELEMENT(next_slot, place->next_count, place->next_size, place->next)
    place->next[next_slot] = tid;

    return true;
}

/* Adds an arc from transition `tid` into place `pid`, inserting into
 * transition->next and place->prev so both stay ascending. `tid` is queued
 * on net->reverseDirt first, even when the arc already exists. Returns false
 * in that case. */
bool pnsNet_connectOut(PnsNet* net, uint32_t tid, uint32_t pid) {
    PnsIndexList_appendNew(&net->reverseDirt, tid);

    PnsNode* transition = &net->transitions[tid];
    PnsNode* place = &net->places[pid];

    /* Insertion pass: larger entries are shifted right by carrying them
     * along; whatever is carried at the end is appended below. */
    for (uint32_t k = 0; k < transition->next_count; ++k) {
        const uint32_t existing = transition->next[k];
        if (existing == pid) return false;
        if (existing > pid) {
            transition->next[k] = pid;
            pid = existing;
        }
    }

    for (uint32_t k = 0; k < place->prev_count; ++k) {
        const uint32_t existing = place->prev[k];
        if (existing >= tid) {
            place->prev[k] = tid;
            tid = existing;
        }
    }

    uint32_t next_slot;
    ADD_ELEMENT(next_slot, transition->next_count, transition->next_size, transition->next)
    transition->next[next_slot] = pid;

    uint32_t prev_slot;
    ADD_ELEMENT(prev_slot, place->prev_count, place->prev_size, place->prev)
    place->prev[prev_slot] = tid;

    return true;
}

/* Removes the arc first -> second: `second_id` is dropped from the first
 * node's `next` array and `first_id` from the second node's `prev` array.
 * A missing arc leaves both arrays unchanged. */
void disconnectNodes(uint32_t first_id, PnsNode* first_nodes, uint32_t second_id, PnsNode* second_nodes) {
    PnsNode* source = first_nodes + first_id;
    PnsNode* target = second_nodes + second_id;

    removeIndex(second_id, &source->next_count, source->next);
    removeIndex(first_id, &target->prev_count, target->prev);
}

/* Removes the arc from place `pid` into transition `tid` and queues `tid` on
 * net->dirt for forward re-evaluation. */
void pnsNet_disconnectIn(PnsNet* net, uint32_t tid, uint32_t pid) {
    PnsIndexList_appendNew(&net->dirt, tid);
    disconnectNodes(pid, net->places, tid, net->transitions);
}

/* Removes the arc from transition `tid` into place `pid`. Does not record
 * any dirt. */
void pnsNet_disconnectOut_unsafe(PnsNet* net, uint32_t tid, uint32_t pid) {
    PnsNode* transitions = net->transitions;
    PnsNode* places = net->places;
    disconnectNodes(tid, transitions, pid, places);
}


/* Adds a copy of transition `tid` with the same input and output places and
 * returns its id. The clone is inserted (in sorted position) into every
 * connected place's arc array, and queued on both net->dirt and
 * net->reverseDirt. */
uint32_t pnsNet_duplicateTransition(PnsNet* net, uint32_t tid) {
    const uint32_t clone_tid = get_tid(net);

    PnsIndexList_appendNew(&net->dirt, clone_tid);
    PnsIndexList_appendNew(&net->reverseDirt, clone_tid);

    /* get_tid may have reallocated net->transitions: take pointers only now */
    PnsNode* original = &net->transitions[tid];
    pnsCloneNode(&net->transitions[clone_tid], original);

    for (uint32_t i = 0; i < original->prev_count; ++i) {
        PnsNode* input = &net->places[original->prev[i]];
        const uint32_t tail = sortIndex(clone_tid, input->next_count, input->next);
        uint32_t slot;
        ADD_ELEMENT(slot, input->next_count, input->next_size, input->next)
        input->next[slot] = tail;
    }

    for (uint32_t i = 0; i < original->next_count; ++i) {
        PnsNode* output = &net->places[original->next[i]];
        const uint32_t tail = sortIndex(clone_tid, output->prev_count, output->prev);
        uint32_t slot;
        ADD_ELEMENT(slot, output->prev_count, output->prev_size, output->prev)
        output->prev[slot] = tail;
    }

    return clone_tid;
}

/* Adds a copy of place `pid` and returns its id. The copy has the same
 * initial tokens and the same producing/consuming transitions; the clone is
 * inserted (in sorted position) into each of those transitions' arc arrays.
 *
 * Every consumer gains an extra input and every producer an extra output, so
 * their fireability may change. They are queued via pnsNet_dirtyPlace, like
 * pnsNet_duplicateTransition does for its clone. */
uint32_t pnsNet_duplicatePlace(PnsNet* net, uint32_t pid) {
    const uint32_t clone_pid = get_pid(net);

    /* get_pid may have reallocated net->places: take pointers only now */
    PnsNode* original = &net->places[pid];
    pnsCloneNode(&net->places[clone_pid], original);

    net->initial_token_counts[clone_pid] = net->initial_token_counts[pid];

    for (uint32_t i = 0; i < original->prev_count; ++i) {
        PnsNode* producer = &net->transitions[original->prev[i]];
        const uint32_t tail = sortIndex(clone_pid, producer->next_count, producer->next);
        uint32_t slot;
        ADD_ELEMENT(slot, producer->next_count, producer->next_size, producer->next)
        producer->next[slot] = tail;
    }

    for (uint32_t i = 0; i < original->next_count; ++i) {
        PnsNode* consumer = &net->transitions[original->next[i]];
        const uint32_t tail = sortIndex(clone_pid, consumer->prev_count, consumer->prev);
        uint32_t slot;
        ADD_ELEMENT(slot, consumer->prev_count, consumer->prev_size, consumer->prev)
        consumer->prev[slot] = tail;
    }

    pnsNet_dirtyPlace(net, clone_pid);

    return clone_pid;
}

/* Adds `count` tokens to the initial marking of place `pid` and returns the
 * new initial count. The place's neighbouring transitions are queued on the
 * net's dirt lists so pnsState_updateEdits re-evaluates them. The original
 * line `void pnsNet_dirtyPlace(...);` was a block-scope declaration, not a
 * call, so this never happened. */
uint32_t pnsNet_start(PnsNet* net, uint32_t pid, uint32_t count) {
    pnsNet_dirtyPlace(net, pid);
    return net->initial_token_counts[pid] += count;
}

/* Discards the pending-edit lists (net->dirt and net->reverseDirt). */
void pnsNet_clearEdits(PnsNet* net) {
    pnsClearTransitionList(&net->reverseDirt);
    pnsClearTransitionList(&net->dirt);
}

/** State **/

/*** helpers ***/

/* Creates empty fire / unfire trackers for `state` and fills them from the
 * current token and call counts. */
void pnsFinalizeState(PnsState* state, const PnsNet* net) {
    pnsCreateFireChanges(&state->unfire);
    pnsCreateFireChanges(&state->fire);
    calculate_transition_list(state, net);
}

/*** default ***/

/* Initialises `state` at the net's initial marking:
 * - all token deltas and call counts start at zero;
 * - the fire / unfire lists are computed from scratch.
 *
 * The array capacities (places_size / transitions_size) are recorded here
 * because pnsState_updateEdits compares against them; they were previously
 * left uninitialised. calloc(0, ...) may legitimately return NULL, so the
 * OOM check only applies to non-empty arrays. */
void pnsCreateState(PnsState* state, const PnsNet* net) {
    state->token_counts = calloc(net->place_count, sizeof(*state->token_counts));
    if (net->place_count > 0) {
        CHECK_MEMORY_ERROR(state->token_counts);
    }
    state->places_size = net->place_count;

    state->call_counts = calloc(net->transition_count, sizeof(*state->call_counts));
    if (net->transition_count > 0) {
        CHECK_MEMORY_ERROR(state->call_counts);
    }
    state->transitions_size = net->transition_count;

    pnsFinalizeState(state, net);
}

/* Initialises `state_clone` as a copy of `state` for `net`:
 * - token deltas and call counts are copied;
 * - the fire / unfire lists are recomputed rather than copied, so every
 *   currently enabled transition starts in the clone's `added` lists.
 *
 * The array capacities are recorded for pnsState_updateEdits (previously
 * uninitialised). Empty arrays stay NULL rather than going through
 * malloc(0) and a memcpy from a possibly-NULL source. */
void pnsCloneState(PnsState* state_clone, const PnsState* state, const PnsNet* net) {
    state_clone->token_counts = NULL;
    if (net->place_count > 0) {
        state_clone->token_counts = malloc(net->place_count * sizeof(*state_clone->token_counts));
        CHECK_MEMORY_ERROR(state_clone->token_counts);
        memcpy(state_clone->token_counts, state->token_counts,
               net->place_count * sizeof(*state_clone->token_counts));
    }
    state_clone->places_size = net->place_count;

    state_clone->call_counts = NULL;
    if (net->transition_count > 0) {
        state_clone->call_counts = malloc(net->transition_count * sizeof(*state_clone->call_counts));
        CHECK_MEMORY_ERROR(state_clone->call_counts);
        memcpy(state_clone->call_counts, state->call_counts,
               net->transition_count * sizeof(*state_clone->call_counts));
    }
    state_clone->transitions_size = net->transition_count;

    pnsFinalizeState(state_clone, net);
}

/* Initialises `state` from per-transition call counts (`values`, one entry
 * per transition). Token deltas are reconstructed by replaying each
 * transition's calls: inputs are debited and outputs credited, using
 * unsigned wrapping arithmetic.
 *
 * Returns false, without touching `state`, when `count` differs from the
 * net's transition count. The array capacities are recorded for
 * pnsState_updateEdits (previously uninitialised), and zero-length
 * allocations are not treated as out of memory. */
bool pnsLoadState(PnsState* state, const PnsNet* net, uint32_t count, const uint32_t* values) {
    if (count != net->transition_count) return false;

    state->token_counts = calloc(net->place_count, sizeof(*state->token_counts));
    if (net->place_count > 0) {
        CHECK_MEMORY_ERROR(state->token_counts);
    }
    state->places_size = net->place_count;

    state->call_counts = calloc(net->transition_count, sizeof(*state->call_counts));
    if (net->transition_count > 0) {
        CHECK_MEMORY_ERROR(state->call_counts);
    }
    state->transitions_size = net->transition_count;

    for (uint32_t tid = 0; tid < net->transition_count; ++tid) {
        state->call_counts[tid] = values[tid];

        const PnsNode* transition = &net->transitions[tid];
        for (uint32_t i = 0; i < transition->prev_count; ++i)
            state->token_counts[transition->prev[i]] -= state->call_counts[tid];
        for (uint32_t i = 0; i < transition->next_count; ++i)
            state->token_counts[transition->next[i]] += state->call_counts[tid];
    }

    pnsFinalizeState(state, net);

    return true;
}

/* Frees everything owned by `state`: both change trackers and the token and
 * call count arrays. */
void pnsDestroyState(PnsState* state) {
    pnsDestroyFireChanges(&state->unfire);
    pnsDestroyFireChanges(&state->fire);
    free(state->call_counts);
    free(state->token_counts);
}

/*** simulate ***/

/* Two-call query helper.
 * - With `transitions` == NULL, stores the list length (`current_count`) in
 *   *count.
 * - Otherwise copies up to *count ids from `list` into `transitions`. If the
 *   whole list fit, *count is set to `current_count`; if not, *count keeps
 *   the requested value.
 * The list itself is not modified. */
void getTransitions(uint32_t current_count, PnsIndexList* list, uint32_t* count, uint32_t* transitions) {
    if (transitions != NULL) {
        const PnsIndexList* cell = list;
        uint32_t written = 0;
        while (written < *count && cell != NULL) {
            transitions[written++] = cell->index;
            cell = cell->next;
        }
        if (cell != NULL) return; /* buffer filled before the list ended */
    }
    *count = current_count;
}

/* Queries the transitions that can currently fire forward (see
 * getTransitions for the two-call protocol). */
void pnsState_transitions(PnsState* state, uint32_t* count, uint32_t* transitions) {
    const PnsFireChanges* forward = &state->fire;
    getTransitions(forward->count, forward->active, count, transitions);
}

/* Queries the transitions that can currently be un-fired (see
 * getTransitions for the two-call protocol). */
void pnsState_transitions_backwards(PnsState* state, uint32_t* count, uint32_t* transitions) {
    const PnsFireChanges* backward = &state->unfire;
    getTransitions(backward->count, backward->active, count, transitions);
}

/* Forgets the recorded forward added/removed transitions and resets their
 * counters; the active list is kept. */
void pnsState_cleanChanges(PnsState* state) {
    PnsFireChanges* forward = &state->fire;
    pnsClearTransitionList(&forward->added);
    pnsClearTransitionList(&forward->removed);
    forward->added_count = 0;
    forward->removed_count = 0;
}

/* Forgets the recorded backward added/removed transitions and resets their
 * counters; the active list is kept. */
void pnsState_cleanChanges_backwards(PnsState* state) {
    PnsFireChanges* backward = &state->unfire;
    pnsClearTransitionList(&backward->added);
    pnsClearTransitionList(&backward->removed);
    backward->added_count = 0;
    backward->removed_count = 0;
}

/* Two-call drain helper.
 * - With `transitions` == NULL, stores *current_count in *count and changes
 *   nothing else.
 * - Otherwise pops up to *count ids from the list at *list into
 *   `transitions`, sets *count to the number actually popped, and subtracts
 *   that from *current_count.
 *
 * Fix: the loop previously tested the `PnsIndexList**` itself (never NULL)
 * instead of the list head. Asking for more entries than were queued
 * dereferenced an empty list, and *current_count could underflow. */
void popTransitions(uint32_t* current_count, PnsIndexList** list, uint32_t* count, uint32_t* transitions) {
    if (!transitions) {
        *count = *current_count;
        return;
    }
    uint32_t popped = 0;
    while (popped < *count && *list != NULL)
        transitions[popped++] = PnsIndexList_pop(list);
    *count = popped;
    *current_count -= popped;
}

/* Drains the transitions that became forward-fireable since the last clean
 * (see popTransitions for the protocol). */
void pnsState_addedTransitions(PnsState* state, uint32_t* count, uint32_t* transitions) {
    PnsFireChanges* forward = &state->fire;
    popTransitions(&forward->added_count, &forward->added, count, transitions);
}

/* Drains the transitions that became un-fireable since the last clean. */
void pnsState_addedTransitions_backwards(PnsState* state, uint32_t* count, uint32_t* transitions) {
    PnsFireChanges* backward = &state->unfire;
    popTransitions(&backward->added_count, &backward->added, count, transitions);
}

/* Drains the transitions that stopped being forward-fireable since the last
 * clean. */
void pnsState_removedTransitions(PnsState* state, uint32_t* count, uint32_t* transitions) {
    PnsFireChanges* forward = &state->fire;
    popTransitions(&forward->removed_count, &forward->removed, count, transitions);
}

/* Drains the transitions that stopped being un-fireable since the last
 * clean. */
void pnsState_removedTransitions_backwards(PnsState* state, uint32_t* count, uint32_t* transitions) {
    PnsFireChanges* backward = &state->unfire;
    popTransitions(&backward->removed_count, &backward->removed, count, transitions);
}

/* Re-evaluates every transition adjacent to place `pid`: consumers
 * (place->next) for forward fireability, producers (place->prev) for
 * backward fireability. */
static void recalculate_place_neighbours(PnsState* state, const PnsNet* net, uint32_t pid) {
    const PnsNode* place = &net->places[pid];
    for (uint32_t j = 0; j < place->next_count; ++j)
        recalculate_transition_forward(state, net, place->next[j]);
    for (uint32_t j = 0; j < place->prev_count; ++j)
        recalculate_transition_backward(state, net, place->prev[j]);
}

/* Fires transition `tid` `times` times: bumps its call count, debits each
 * input place and credits each output place, then re-evaluates `tid` and all
 * transitions around the touched places. Arithmetic is unsigned and
 * wrapping, so times = (uint32_t)-1 undoes one firing. */
void pnsState_fireTimes(PnsState* state, const PnsNet* net, uint32_t tid, uint32_t times) {
    const PnsNode* transition = &net->transitions[tid];

    state->call_counts[tid] += times;
    for (uint32_t i = 0; i < transition->prev_count; ++i)
        state->token_counts[transition->prev[i]] -= times;
    for (uint32_t i = 0; i < transition->next_count; ++i)
        state->token_counts[transition->next[i]] += times;

    recalculate_transition_backward(state, net, tid);

    for (uint32_t i = 0; i < transition->prev_count; ++i)
        recalculate_place_neighbours(state, net, transition->prev[i]);
    for (uint32_t i = 0; i < transition->next_count; ++i)
        recalculate_place_neighbours(state, net, transition->next[i]);
}

/* Fires transition `tid` once. */
void pnsState_fire(PnsState* state, const PnsNet* net, uint32_t tid) {
    const uint32_t once = 1;
    pnsState_fireTimes(state, net, tid, once);
}

/* Undoes one firing of transition `tid`: -1 converts to UINT32_MAX, which
 * the wrapping arithmetic in pnsState_fireTimes treats as a decrement. */
void pnsState_fire_backwards(PnsState* state, const PnsNet* net, uint32_t tid) {
    const uint32_t undo_once = (uint32_t)-1;
    pnsState_fireTimes(state, net, tid, undo_once);
}

/* Brings `state` in line with edits made to `net`.
 * - Grows the call/token arrays when the net gained transitions or places.
 * - Re-evaluates every transition on net->dirt (forward) and
 *   net->reverseDirt (backward).
 * The net's dirt lists are only read; pnsNet_clearEdits resets them.
 *
 * Slots gained by growing are zero-filled: a new transition has never fired
 * and a new place has no token delta. Previously they were left as
 * uninitialised realloc memory and then read by token_count and the
 * backward check. */
void pnsState_updateEdits(PnsState* state, const PnsNet* net) {
    if (state->transitions_size < net->transition_count) {
        size_t old_size = state->transitions_size;
        RESIZE_ARRAY(net->transition_count, state->transitions_size, state->call_counts)
        memset(state->call_counts + old_size, 0,
               (state->transitions_size - old_size) * sizeof(*state->call_counts));
    }
    if (state->places_size < net->place_count) {
        size_t old_size = state->places_size;
        RESIZE_ARRAY(net->place_count, state->places_size, state->token_counts)
        memset(state->token_counts + old_size, 0,
               (state->places_size - old_size) * sizeof(*state->token_counts));
    }

    for (const PnsIndexList* dirt = net->dirt; dirt != NULL; dirt = dirt->next)
        recalculate_transition_forward(state, net, dirt->index);

    for (const PnsIndexList* reverseDirt = net->reverseDirt; reverseDirt != NULL; reverseDirt = reverseDirt->next)
        recalculate_transition_backward(state, net, reverseDirt->index);
}

