'''
Helper class for loading CellMap JSON files and transforming points from the map frame to its parent frame.
'''

import numpy as np
import json
import typing

class CellMap:
    '''
    A multi-layer 2D cell map loaded from a JSON file.

    Each layer is stored in ``data`` as a 2D array whose shape is
    ``num_cells``. ``from_parent`` and ``to_parent`` are 3x3 homogeneous
    matrices relating the parent frame and the map frame.
    '''

    # Layer name -> 2D array of shape ``num_cells``.
    data: typing.Dict[str, np.ndarray]
    # Layer names, in the same order as the file's ``data`` list.
    layers: typing.List[str]
    # Cell size as read from the file.
    cell_size: np.ndarray
    # (2, 2) array: row 0 holds the x bounds, row 1 the y bounds, each as
    # [lower, upper].
    cell_bounds: np.ndarray
    # Shape of every layer array: (y cell count, x cell count).
    num_cells: np.ndarray
    # (4, 2) corners of the cell bounds, expressed in the parent frame.
    extents: np.ndarray
    # Read from the file as-is; not used within this class.
    cell_boundary_precision: float
    # 3x3 matrix as stored in the file, and its inverse.
    from_parent: np.ndarray
    to_parent: np.ndarray
    # Path given to ``load``.
    path: str

    @classmethod
    def load(cls, path):
        '''
        Loads a CellMap from the given path, expects a JSON file.

        Parameters:
            path: path to the JSON cell map file.

        Returns:
            The loaded CellMap.

        Raises:
            RuntimeError: if the number of layer data entries differs from the
                number of layer names, or a layer's ``dim`` does not match the
                size implied by ``cell_bounds``.
        '''

        # Read the data (JSON is UTF-8 by specification)
        with open(path, 'r', encoding='utf-8') as f:
            raw = json.load(f)

        cm = cls()

        # Load metadata
        cm.path = path
        cm.layers = raw['layers']
        cm.cell_size = np.array(raw['cell_size'])
        cm.cell_bounds = np.array([raw['cell_bounds']['x'], raw['cell_bounds']['y']])
        # Layer arrays are indexed [y, x], so the y count comes first.
        cm.num_cells = np.array([
            cm.cell_bounds[1][1] - cm.cell_bounds[1][0],
            cm.cell_bounds[0][1] - cm.cell_bounds[0][0],
        ])
        cm.cell_boundary_precision = np.array(raw['cell_boundary_precision'])
        cm.from_parent = np.array(raw['from_parent_matrix']).reshape((3, 3))
        cm.to_parent = np.linalg.inv(cm.from_parent)

        # Calculate extents of map: the four corners of the cell bounds,
        # transformed into the parent frame.
        x_lo, x_hi = cm.cell_bounds[0]
        y_lo, y_hi = cm.cell_bounds[1]
        corners = np.array([
            [x_lo, y_lo],
            [x_hi, y_lo],
            [x_lo, y_hi],
            [x_hi, y_hi],
        ])
        cm.extents = cm.transform_to_parent(corners)

        # Load each layer in turn, reshaping as needed. zip() would silently
        # drop layers if the file is short on data entries, so check first.
        layer_data = raw['data']
        if len(layer_data) != len(cm.layers):
            raise RuntimeError(
                f'Cell map file lists {len(cm.layers)} layers but contains '
                f'{len(layer_data)} data entries'
            )
        cm.data = {}
        for layer, data in zip(cm.layers, layer_data):
            if data['dim'][0] != cm.num_cells[0] or data['dim'][1] != cm.num_cells[1]:
                raise RuntimeError(f'Data in cell map file is of wrong shape. Expected {cm.num_cells} but got {data["dim"]}')
            cm.data[layer] = np.array(data['data']).reshape(cm.num_cells)

        return cm

    def transform_to_parent(self, points: np.ndarray) -> np.ndarray:
        '''
        Converts the given point(s) from the map frame to the parent frame.

        Parameters:
            points: an (N, 2) array of points, or a single point of shape (2,).

        Returns:
            A float array of the same shape as ``points``: (N, 2) for an
            (N, 2) input (including N == 0), or (2,) for a single point.
        '''
        pts = np.asarray(points, dtype=float)
        single = pts.ndim == 1
        # Promote a lone point to a 1-row batch so it is not broadcast into
        # several rows below.
        pts = np.atleast_2d(pts)

        homog = np.ones((pts.shape[0], 3))
        homog[:, :-1] = pts
        # Points are row vectors, so the matrix is applied by
        # right-multiplication. NOTE(review): this is equivalent to the usual
        # column-vector ``M @ p`` only if ``from_parent_matrix`` is stored
        # column-major in the file (the row-major reshape in ``load`` then
        # yields the transpose). Confirm against the file writer.
        homog = homog @ self.to_parent
        out = homog[:, :-1] / homog[:, -1:]
        return out[0] if single else out
