#!/usr/bin/env python

import sys
import time
import numpy as np
import mwa_hyperbeam
# from mwa_pb.primary_beam import MWA_Tile_full_EE


def get_pointings(n):
    """Build ``n`` evenly spaced test pointings.

    Returns a pair ``(az, za)`` of length-``n`` numpy arrays: azimuths
    running linearly from 0 to 0.9*pi, and zenith angles running linearly
    from 0.1 to 0.9*pi/2 (radians).
    """
    azimuths = np.linspace(0, 0.9 * np.pi, n)
    zenith_angles = np.linspace(0.1, 0.9 * np.pi / 2, n)
    return azimuths, zenith_angles


def gen_pointings(num_pointings, hyperbeam=True, parallel=False,
                  max_trials=10, time_budget_s=300):
    """Benchmark FEE beam Jones calculations for several pointing counts.

    For each ``n`` in ``num_pointings``, ``n`` pointings are generated with
    ``get_pointings`` and Jones matrices are computed at 167 MHz with
    all-zero delays, unit amplitudes and zenith normalisation. The
    calculation is repeated up to ``max_trials`` times, stopping early once
    the accumulated time exceeds ``time_budget_s`` seconds. One line with
    the mean and median time is printed per ``n``.

    Parameters
    ----------
    num_pointings : iterable of int
        Pointing counts to benchmark.
    hyperbeam : bool
        If True, time ``mwa_hyperbeam.FEEBeam``; otherwise time mwa_pb's
        ``MWA_Tile_full_EE`` (imported lazily, so mwa_pb is only needed
        when it is actually benchmarked).
    parallel : bool
        If True, time one ``calc_jones_array`` call over all pointings
        instead of one ``calc_jones`` call per pointing. hyperbeam only.
    max_trials : int
        Maximum number of timed repetitions per ``n`` (default 10).
    time_budget_s : float
        Stop repeating once the summed time exceeds this (default 300 s).

    Raises
    ------
    RuntimeError
        If ``parallel`` is requested together with mwa_pb.
    """
    # Fail fast on an unsupported configuration instead of mid-benchmark.
    if parallel and not hyperbeam:
        raise RuntimeError("mwa_pb doesn't have parallel code")

    if hyperbeam:
        label = "(hyperbeam-parallel) " if parallel else "(hyperbeam) "
    else:
        label = "(mwa_pb) "
        # mwa_pb is optional: the module-level import is commented out, so
        # import it here only when its path is actually exercised.
        from mwa_pb.primary_beam import MWA_Tile_full_EE

    # Beam parameters are the same for every run; build them once.
    freq_hz = 167e6
    delays = np.zeros(16, dtype=np.uint32)
    amps = np.ones(16)
    norm_to_zenith = True

    for n in num_pointings:
        az, za = get_pointings(n)
        times = []

        for _ in range(max_trials):
            if np.sum(times) > time_budget_s:
                break

            if hyperbeam:
                # Beam construction is deliberately excluded from the timing.
                beam = mwa_hyperbeam.FEEBeam()
                # perf_counter is monotonic and high-resolution, unlike
                # time.time(), which is what a benchmark needs.
                start_time = time.perf_counter()
                if parallel:
                    beam.calc_jones_array(az, za, freq_hz,
                                          delays, amps, norm_to_zenith)
                else:
                    for az_i, za_i in zip(az, za):
                        beam.calc_jones(az_i, za_i, freq_hz,
                                        delays, amps, norm_to_zenith)
                times.append(time.perf_counter() - start_time)
                # Drop this trial's beam before the next one is constructed.
                del beam
            else:
                start_time = time.perf_counter()
                MWA_Tile_full_EE(za, az, int(freq_hz), delays=delays,
                                 zenithnorm=norm_to_zenith,
                                 interp=False, power=False, jones=True)
                times.append(time.perf_counter() - start_time)

        print("{}Num pointings: {}, mean time: {:.3}s, median time: {:.3}s"
              .format(label, n, np.mean(times), np.median(times)))


# Pointing counts to benchmark: from the command line if given, otherwise
# a default sweep.
if len(sys.argv) == 1:
    n = [500, 100000, 1000000]
else:
    # Materialise as a list: a bare ``map`` is a one-shot iterator, so any
    # additional gen_pointings call below would silently benchmark nothing.
    n = [int(arg) for arg in sys.argv[1:]]

gen_pointings(n, hyperbeam=True)
# Other configurations (hyperbeam=False benchmarks mwa_pb instead):
# gen_pointings(n, hyperbeam=False)
# gen_pointings(n, hyperbeam=True, parallel=True)
