use super::pipeline::CuboidsPipeline;
use crate::component::*;

use bevy::{
    prelude::*,
    render::{
        render_resource::{
            BindGroup, BindGroupDescriptor, BindGroupEntry, Buffer, BufferInitDescriptor,
            BufferUsages, BufferVec,
        },
        renderer::{RenderDevice, RenderQueue},
    },
};
use bytemuck::{cast_slice, Pod, Zeroable};
use std::collections::HashMap;

/// Builds GPU buffers for every entity with a non-empty [`Cuboids`] and stores them in the
/// [`BufferCache`].
///
/// For each such entity this:
/// 1. uploads one [`GpuCuboid`] per instance into a storage buffer,
/// 2. wraps that buffer in a bind group for `pipeline.cuboids_layout` (binding 0), and
/// 3. uploads an index buffer produced by `generate_index_buffer_data`.
///
/// Entities whose `Cuboids` has no instances are skipped and receive no cache entry. Entries are
/// inserted through `BufferCache::insert`, which stores them unmarked.
///
/// NOTE(review): despite its name, `new_cuboids` has no change filter, so every matching entity
/// is rebuilt each time this system runs — confirm that is intended.
pub(crate) fn prepare_cuboids(
    pipeline: Res<CuboidsPipeline>,
    render_device: Res<RenderDevice>,
    render_queue: Res<RenderQueue>,
    mut buffer_cache: ResMut<BufferCache>,
    new_cuboids: Query<(Entity, &Cuboids)>,
) {
    let device: &RenderDevice = &render_device;
    let queue: &RenderQueue = &render_queue;

    let non_empty = new_cuboids
        .iter()
        .filter(|(_, cuboids)| !cuboids.instances.is_empty());

    for (entity, cuboids) in non_empty {
        // Per-instance data -> storage buffer.
        let mut staged = BufferVec::<GpuCuboid>::new(BufferUsages::STORAGE);
        for gpu_cuboid in cuboids.instances.iter().map(GpuCuboid::from) {
            staged.push(gpu_cuboid);
        }
        staged.write_buffer(device, queue);
        // Assumed present once a non-empty vec has been written.
        let instance_buffer = staged.buffer().cloned().unwrap();

        let instance_buffer_bind_group = device.create_bind_group(&BindGroupDescriptor {
            label: Some("gpu_cuboids_bind_group"),
            layout: &pipeline.cuboids_layout,
            entries: &[BindGroupEntry {
                binding: 0,
                resource: instance_buffer.as_entire_binding(),
            }],
        });

        // One block of cuboid indices per instance.
        let num_instances: u32 = staged.len().try_into().unwrap();
        let indices = generate_index_buffer_data(num_instances);
        let index_count: u32 = indices.len().try_into().unwrap();
        let index_buffer = device.create_buffer_with_data(&BufferInitDescriptor {
            label: Some("gpu_cuboids_index_buffer"),
            contents: cast_slice(&indices),
            usage: BufferUsages::INDEX,
        });

        let buffers = GpuCuboidBuffers {
            index_buffer,
            index_count,
            _instance_buffer: instance_buffer,
            instance_buffer_bind_group,
        };
        buffer_cache.insert(entity, buffers);
    }
}

/// Per-instance cuboid record in the layout written to the instance storage buffer.
///
/// `#[repr(C)]` together with `Pod`/`Zeroable` fixes the byte layout. Keep field order and types
/// unchanged unless the consumer side is updated too (presumably a matching struct in the
/// shader — confirm).
#[derive(Clone, Copy, Debug, Default, Pod, Zeroable)]
#[repr(C)]
pub(crate) struct GpuCuboid {
    /// Minimum corner, extended to 4 components (`w = 1.0` when built from a `Cuboid`).
    pub min: Vec4,
    /// Maximum corner, extended to 4 components (`w = 1.0` when built from a `Cuboid`).
    pub max: Vec4,
    /// RGBA color copied verbatim from `Cuboid::color_rgba`.
    pub color: [f32; 4],
}

impl From<&Cuboid> for GpuCuboid {
    /// Packs a CPU-side cuboid, padding both corners with a `w` component of `1.0`.
    fn from(source: &Cuboid) -> Self {
        let min = source.minimum.extend(1.0);
        let max = source.maximum.extend(1.0);
        let color = source.color_rgba;
        GpuCuboid { min, max, color }
    }
}

/// GPU resources needed to draw one entity's cuboids, as assembled by `prepare_cuboids`.
///
/// Fields are dropped in declaration order; keep the order unless resource drop order is known
/// not to matter.
#[derive(Clone, Component)]
pub struct GpuCuboidBuffers {
    /// Index buffer filled from `generate_index_buffer_data`.
    pub(crate) index_buffer: Buffer,
    /// Number of `u32` indices stored in `index_buffer`.
    pub(crate) index_count: u32,
    /// Storage buffer holding the `GpuCuboid` instances. Never read directly; presumably kept so
    /// the buffer lives as long as `instance_buffer_bind_group` — TODO confirm.
    pub(crate) _instance_buffer: Buffer,
    /// Bind group exposing the instance buffer at binding 0 of the pipeline's cuboids layout.
    pub(crate) instance_buffer_bind_group: BindGroup,
}

/// Number of indices emitted per cuboid: 3 faces × 3 indices per triangle × 2 triangles per face.
///
/// Only this many leading entries of the per-cube index table in [`generate_index_buffer_data`]
/// are emitted. Those three faces all contain corner vertex 0.
/// NOTE(review): presumably the vertex shader reflects corners so these are the camera-facing
/// faces — confirm against the shader before changing this value.
const NUM_CUBE_INDICES: u32 = 3 * 3 * 2;
/// Number of corner vertices addressed per cuboid; each cuboid's indices are offset by this.
const NUM_CUBE_VERTICES: u32 = 8;

/// Builds the index buffer contents for `num_cuboids` cuboids.
///
/// Cuboid `k` references vertices in `k * NUM_CUBE_VERTICES .. (k + 1) * NUM_CUBE_VERTICES` and
/// contributes `NUM_CUBE_INDICES` consecutive entries. The result therefore has
/// `num_cuboids * NUM_CUBE_INDICES` elements.
///
/// # Panics
///
/// Panics if `num_cuboids * NUM_CUBE_INDICES` does not fit in a `u32`, because such an index
/// count is not representable with 32-bit draw counts anyway.
fn generate_index_buffer_data(num_cuboids: u32) -> Vec<u32> {
    /// The indices for all triangles in a cuboid mesh (given 8 corner vertices).
    /// Only the first `NUM_CUBE_INDICES` entries are emitted.
    #[rustfmt::skip]
    const CUBE_INDICES: [u32; 36] = [
        0, 2, 1, 2, 3, 1,
        5, 4, 1, 1, 4, 0,
        0, 4, 6, 0, 6, 2,
        6, 5, 7, 6, 4, 5,
        2, 6, 3, 6, 7, 3,
        7, 1, 3, 7, 5, 1,
    ];

    let per_cuboid = &CUBE_INDICES[..NUM_CUBE_INDICES as usize];

    // Checked so a huge instance count fails loudly instead of wrapping (in release builds) into
    // a silently truncated, wrong index buffer.
    let num_indices = num_cuboids
        .checked_mul(NUM_CUBE_INDICES)
        .expect("too many cuboids: index count overflows u32");

    let mut indices = Vec::with_capacity(num_indices as usize);
    for cuboid in 0..num_cuboids {
        // Cannot overflow: the largest value is (n - 1) * 8 + 7 < n * 18, which was checked above.
        let base_vertex = cuboid * NUM_CUBE_VERTICES;
        indices.extend(per_cuboid.iter().map(|&local| base_vertex + local));
    }
    indices
}

/// Per-entity cache of the GPU buffers built by `prepare_cuboids`, evicted mark-and-sweep style.
///
/// How entries survive:
/// - Newly inserted entries start unmarked.
/// - [`BufferCache::keep_alive`] marks an entry.
/// - [`BufferCache::cull_entities`] removes every unmarked entry and clears the mark on the rest.
///
/// An entry therefore survives a cull only if it was marked since it was inserted or since the
/// previous cull.
#[derive(Default)]
pub(crate) struct BufferCache {
    entries: HashMap<Entity, BufferCacheEntry>,
}

/// A cached set of buffers plus its liveness mark for the next cull.
struct BufferCacheEntry {
    buffers: GpuCuboidBuffers,
    /// Set by `BufferCache::keep_alive`; cleared by every `BufferCache::cull_entities` call.
    keep_alive: bool,
}

impl BufferCache {
    /// Returns the cached buffers for `entity`, if there are any.
    pub fn get_buffers(&self, entity: Entity) -> Option<&GpuCuboidBuffers> {
        self.entries.get(&entity).map(|e| &e.buffers)
    }

    /// Stores `buffers` for `entity`, replacing (and dropping) any previous entry.
    ///
    /// The new entry is unmarked, so the next cull evicts it unless `keep_alive` is called first.
    pub fn insert(&mut self, entity: Entity, buffers: GpuCuboidBuffers) {
        self.entries.insert(
            entity,
            BufferCacheEntry {
                buffers,
                keep_alive: false,
            },
        );
    }

    /// Marks the entry for `entity` so it survives the next cull.
    ///
    /// Does nothing if `entity` has no entry. For example, `prepare_cuboids` creates no entry for
    /// entities whose `Cuboids` has no instances, so callers need not special-case them.
    pub fn keep_alive(&mut self, entity: Entity) {
        if let Some(entry) = self.entries.get_mut(&entity) {
            entry.keep_alive = true;
        }
    }

    /// Iterates over the entities that currently have cached buffers, in arbitrary order.
    pub fn entities(&self) -> impl Iterator<Item = &Entity> {
        self.entries.keys()
    }

    /// Evicts every entry not marked since the last cull and clears the mark on the survivors.
    pub fn cull_entities(&mut self) {
        // `replace` yields the old mark (deciding retention) and resets it to `false` in one step,
        // so no temporary list of keys is needed.
        self.entries
            .retain(|_, entry| std::mem::replace(&mut entry.keep_alive, false));
    }
}
