use criterion::{black_box, criterion_group, criterion_main, Bencher, Criterion};
use tensorgraph_sys::{
    device::{cpu::Cpu, cuda::Cuda, DefaultDeviceAllocator},
    vec::{vec_from_host, Vec},
    Share, ShareMut,
};

use tensorgraph_math::{
    blas::{DefaultBLASContext, GEMM},
    tensor::{gemm_ctx, Tensor},
};

/// Performs 1000 chained matrix multiplications on 256x256 matrices.
///
/// `init` is copied from the host onto device `D` and then cloned twice on the
/// device: once for the fixed right-hand operand `b` and once for the scratch
/// output `c`. Each iteration issues a GEMM with `alpha = 1` and `beta = 0`
/// reading `a` and `b` and writing `c`, then swaps `a` and `c` so that the
/// freshly written product becomes the left-hand operand of the next iteration.
///
/// Returns the device buffer that holds the product written by the final
/// multiplication.
///
/// NOTE(review): `init` is presumably expected to contain exactly `256 * 256`
/// elements; how `Tensor::from_shape` reacts to a mismatch is not visible here.
pub fn matmul_1000_256<D: DefaultDeviceAllocator + DefaultBLASContext>(
    init: &[f64],
) -> Vec<f64, D::Alloc>
where
    f64: GEMM<D::Context>,
    D::Alloc: Clone,
    D::Context: Copy,
{
    let a = vec_from_host::<f64, D>(init);
    let b = a.clone();
    let c = b.clone();

    let mut a = Tensor::from_shape([256, 256], a);
    let b = Tensor::from_shape([256, 256], b);
    let mut c = Tensor::from_shape([256, 256], c);

    let ctx = D::Context::default();
    for _ in 0..1000 {
        gemm_ctx(ctx, 1., a.share(), b.share(), 0., c.share_mut());
        // Ping-pong the buffers: the product just written into `c` becomes
        // the next left-hand operand, and the old `a` becomes scratch space.
        std::mem::swap(&mut a, &mut c);
    }

    // Because of the trailing swap the newest product lives in `a`; `c` only
    // holds the result of the previous iteration.
    a.into_inner()
}

/// Registers the `matmul` benchmark group with Criterion.
///
/// The input is a 256x256 identity matrix with one extra small entry at flat
/// index 1. `matmul_1000_256` is benchmarked once for every BLAS backend that
/// is enabled through a cargo feature; the `cublas` variant additionally
/// copies the result back to the host inside the timed closure.
pub fn matmul(c: &mut Criterion) {
    let mut group = c.benchmark_group("matmul");

    let mut init = vec![0.0f64; 256 * 256];
    init[1] = 0.001;
    // In a flat 256x256 buffer consecutive diagonal elements are 256 + 1 apart.
    init.iter_mut()
        .step_by(256 + 1)
        .for_each(|diagonal| *diagonal = 1.0);

    // Shared by all CPU backends; which BLAS implementation actually runs is
    // decided by the enabled cargo feature.
    let cpu_bench =
        |bencher: &mut Bencher| bencher.iter(|| black_box(matmul_1000_256::<Cpu>(&init)));

    #[cfg(feature = "openblas")]
    group.bench_function("openblas", cpu_bench);

    #[cfg(feature = "blis")]
    group.bench_function("blis", cpu_bench);

    #[cfg(feature = "netlib")]
    group.bench_function("netlib", cpu_bench);

    #[cfg(feature = "matrixmultiply")]
    group.bench_function("matrixmultiply", cpu_bench);

    #[cfg(feature = "accelerate")]
    group.bench_function("accelerate", cpu_bench);

    #[cfg(feature = "cublas")]
    {
        use tensorgraph_math::blas::cublas::CublasContext;
        use tensorgraph_sys::device::cuda::{Context, Stream};

        let cuda_ctx = Context::quick_init().unwrap();
        Stream::new(&cuda_ctx).unwrap().global_over(|stream| {
            let cublas = CublasContext::new();
            cublas.with_stream(Some(stream)).global_over(|_cublas| {
                group.bench_function("cublas", |bencher| {
                    bencher.iter(|| {
                        // The device-to-host copy is deliberately part of the
                        // measured time.
                        let mut host_out = vec![0.0f64; 256 * 256];
                        matmul_1000_256::<Cuda>(&init).copy_to_host(&mut host_out);

                        black_box(host_out)
                    });
                });
            })
        });
    }

    group.finish();
}

// Collects `matmul` into a Criterion benchmark group named `benches`.
criterion_group!(benches, matmul);
// Generates the benchmark binary's `main`, which runs the `benches` group.
criterion_main!(benches);
