use std::borrow::Borrow;
use std::cell::UnsafeCell;
use std::sync::Arc;

use wasm_bindgen::JsCast;
use web_sys::WebGl2RenderingContext as Gl;

use crate::runtime::{Connection, RenderingContext, ShaderCompilationError};
use crate::task::{ContextId, GpuTask, Progress};
use crate::util::JsId;
use std::hash::{Hash, Hasher};

/// The programmable stage in the rendering pipeline that handles the processing of individual
/// vertices.
///
/// A vertex shader receives a single vertex from a graphics pipeline's vertex input stream and
/// outputs a (transformed) vertex to the next pipeline stage.
///
/// See [RenderingContext::create_vertex_shader] for details on how a vertex shader is created.
pub struct VertexShader {
    /// Identifier that the `PartialEq` and `Hash` impls for this type compare and hash.
    object_id: u64,
    /// Reference-counted state shared with the task that allocated the GL shader object.
    data: Arc<VertexShaderData>,
}

impl VertexShader {
    pub(crate) fn data(&self) -> &Arc<VertexShaderData> {
        &self.data
    }
}

impl PartialEq for VertexShader {
    /// Two handles are equal exactly when their `object_id`s match; the shared shader data is
    /// not inspected.
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.object_id, other.object_id);

        lhs == rhs
    }
}

impl Hash for VertexShader {
    /// Hashes only the `object_id`, mirroring what `PartialEq` compares.
    fn hash<H: Hasher>(&self, state: &mut H) {
        Hash::hash(&self.object_id, state);
    }
}

/// The programmable stage in the rendering pipeline that handles the processing of individual
/// fragments generated by rasterization into a set of colors and a single depth value.
///
/// The fragment shader is the pipeline stage after a primitive is rasterized. For each sample of
/// the pixels covered by a primitive, a "fragment" is generated. Each fragment notably has a
/// Window Space position and it contains all of the interpolated per-vertex output values from the
/// last Vertex Processing stage.
///
/// The output of a fragment shader is a depth value, a possible stencil value (unmodified by the
/// fragment shader), and zero or more color values to be potentially written to the buffers in the
/// current framebuffer.
///
/// Fragment shaders take a single fragment as input and produce a single fragment as output.
///
/// See [RenderingContext::create_fragment_shader] for details on how a fragment shader is created.
pub struct FragmentShader {
    /// Identifier that the `PartialEq` and `Hash` impls for this type compare and hash.
    object_id: u64,
    /// Reference-counted state shared with the task that allocated the GL shader object.
    data: Arc<FragmentShaderData>,
}

impl FragmentShader {
    pub(crate) fn data(&self) -> &Arc<FragmentShaderData> {
        &self.data
    }
}

impl PartialEq for FragmentShader {
    /// Two handles are equal exactly when their `object_id`s match; the shared shader data is
    /// not inspected.
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.object_id, other.object_id);

        lhs == rhs
    }
}

impl Hash for FragmentShader {
    /// Hashes only the `object_id`, mirroring what `PartialEq` compares.
    fn hash<H: Hasher>(&self, state: &mut H) {
        Hash::hash(&self.object_id, state);
    }
}

pub(crate) struct VertexShaderData {
    /// JS-side handle of the GL shader object; stays `None` until the allocate task has
    /// compiled the shader successfully.
    id: UnsafeCell<Option<JsId>>,
    /// Identifier of the rendering context this shader data was created with.
    context_id: u64,
    /// Hook used on drop to schedule deletion of the GL shader object.
    dropper: Box<dyn VertexShaderObjectDropper>,
}

impl VertexShaderData {
    pub(crate) fn id(&self) -> Option<JsId> {
        unsafe { *self.id.get() }
    }

    pub(crate) fn context_id(&self) -> u64 {
        self.context_id
    }
}

pub(crate) struct FragmentShaderData {
    /// JS-side handle of the GL shader object; stays `None` until the allocate task has
    /// compiled the shader successfully.
    id: UnsafeCell<Option<JsId>>,
    /// Identifier of the rendering context this shader data was created with.
    context_id: u64,
    /// Hook used on drop to schedule deletion of the GL shader object.
    dropper: Box<dyn FragmentShaderObjectDropper>,
}

impl FragmentShaderData {
    pub(crate) fn id(&self) -> Option<JsId> {
        unsafe { *self.id.get() }
    }

    pub(crate) fn context_id(&self) -> u64 {
        self.context_id
    }
}

/// Object-safe hook through which `VertexShaderData` releases its GL shader object when it is
/// dropped, without having to name the concrete rendering context type.
trait VertexShaderObjectDropper {
    fn drop_shader_object(&self, shader_id: JsId);
}

impl<T> VertexShaderObjectDropper for T
where
    T: RenderingContext,
{
    /// Queues a [VertexShaderDropCommand] for `id` on this rendering context.
    fn drop_shader_object(&self, id: JsId) {
        let command = VertexShaderDropCommand { id };

        self.submit(command);
    }
}

impl Drop for VertexShaderData {
    /// Schedules deletion of the GL shader object, if one was ever stored.
    fn drop(&mut self) {
        // The id is only set once the allocate task compiled the shader successfully; without
        // it there is no GL object to release.
        let shader_id = match self.id() {
            Some(shader_id) => shader_id,
            None => return,
        };

        self.dropper.drop_shader_object(shader_id);
    }
}

/// Object-safe hook through which `FragmentShaderData` releases its GL shader object when it is
/// dropped, without having to name the concrete rendering context type.
trait FragmentShaderObjectDropper {
    fn drop_shader_object(&self, shader_id: JsId);
}

impl<T> FragmentShaderObjectDropper for T
where
    T: RenderingContext,
{
    /// Queues a [FragmentShaderDropCommand] for `id` on this rendering context.
    fn drop_shader_object(&self, id: JsId) {
        let command = FragmentShaderDropCommand { id };

        self.submit(command);
    }
}

impl Drop for FragmentShaderData {
    /// Schedules deletion of the GL shader object, if one was ever stored.
    fn drop(&mut self) {
        // The id is only set once the allocate task compiled the shader successfully; without
        // it there is no GL object to release.
        let shader_id = match self.id() {
            Some(shader_id) => shader_id,
            None => return,
        };

        self.dropper.drop_shader_object(shader_id);
    }
}

pub(crate) struct VertexShaderAllocateCommand<S> {
    /// Identifier handed to the resulting [VertexShader] handle.
    object_id: u64,
    /// Shared state that receives the GL shader id once compilation succeeds.
    data: Arc<VertexShaderData>,
    /// Shader source text passed to `shaderSource`.
    source: S,
}

impl<S> VertexShaderAllocateCommand<S>
where
    S: Borrow<str> + 'static,
{
    pub(crate) fn new<Rc>(context: &Rc, object_id: u64, source: S) -> Self
    where
        Rc: RenderingContext + Clone + 'static,
    {
        let data = Arc::new(VertexShaderData {
            id: UnsafeCell::new(None),
            context_id: context.id(),
            dropper: Box::new(context.clone()),
        });

        VertexShaderAllocateCommand {
            object_id,
            data,
            source,
        }
    }
}

unsafe impl<S> GpuTask<Connection> for VertexShaderAllocateCommand<S>
where
    S: Borrow<str>,
{
    type Output = Result<VertexShader, ShaderCompilationError>;

    fn context_id(&self) -> ContextId {
        ContextId::Any
    }

    fn progress(&mut self, connection: &mut Connection) -> Progress<Self::Output> {
        let (gl, _) = unsafe { connection.unpack_mut() };
        let data = &self.data;

        let shader_object = gl.create_shader(Gl::VERTEX_SHADER).unwrap();

        gl.shader_source(&shader_object, self.source.borrow());
        gl.compile_shader(&shader_object);

        if !gl
            .get_shader_parameter(&shader_object, Gl::COMPILE_STATUS)
            .as_bool()
            .unwrap()
        {
            let error = gl.get_shader_info_log(&shader_object).unwrap();

            Progress::Finished(Err(ShaderCompilationError(error)))
        } else {
            unsafe {
                *data.id.get() = Some(JsId::from_value(shader_object.into()));
            }

            Progress::Finished(Ok(VertexShader {
                object_id: self.object_id,
                data: self.data.clone(),
            }))
        }
    }
}

pub(crate) struct FragmentShaderAllocateCommand<S> {
    /// Identifier handed to the resulting [FragmentShader] handle.
    object_id: u64,
    /// Shared state that receives the GL shader id once compilation succeeds.
    data: Arc<FragmentShaderData>,
    /// Shader source text passed to `shaderSource`.
    source: S,
}

impl<S> FragmentShaderAllocateCommand<S>
where
    S: Borrow<str> + 'static,
{
    pub(crate) fn new<Rc>(context: &Rc, object_id: u64, source: S) -> Self
    where
        Rc: RenderingContext + Clone + 'static,
    {
        let data = Arc::new(FragmentShaderData {
            id: UnsafeCell::new(None),
            context_id: context.id(),
            dropper: Box::new(context.clone()),
        });

        FragmentShaderAllocateCommand {
            object_id,
            data,
            source,
        }
    }
}

unsafe impl<S> GpuTask<Connection> for FragmentShaderAllocateCommand<S>
where
    S: Borrow<str>,
{
    type Output = Result<FragmentShader, ShaderCompilationError>;

    fn context_id(&self) -> ContextId {
        ContextId::Any
    }

    fn progress(&mut self, connection: &mut Connection) -> Progress<Self::Output> {
        let (gl, _) = unsafe { connection.unpack_mut() };
        let data = &self.data;

        let shader_object = gl.create_shader(Gl::FRAGMENT_SHADER).unwrap();

        gl.shader_source(&shader_object, self.source.borrow());
        gl.compile_shader(&shader_object);

        if !gl
            .get_shader_parameter(&shader_object, Gl::COMPILE_STATUS)
            .as_bool()
            .unwrap()
        {
            let error = gl.get_shader_info_log(&shader_object).unwrap();

            Progress::Finished(Err(ShaderCompilationError(error)))
        } else {
            unsafe {
                *data.id.get() = Some(JsId::from_value(shader_object.into()));
            }

            Progress::Finished(Ok(FragmentShader {
                object_id: self.object_id,
                data: self.data.clone(),
            }))
        }
    }
}

/// Task that deletes a vertex shader's GL object and evicts cached programs depending on it.
struct VertexShaderDropCommand {
    /// JS-side handle of the shader object to delete.
    id: JsId,
}

unsafe impl GpuTask<Connection> for VertexShaderDropCommand {
    type Output = ();

    fn context_id(&self) -> ContextId {
        ContextId::Any
    }

    /// Releases the vertex shader identified by `self.id` and finishes immediately.
    ///
    /// Before deleting the GL shader object, this asks the connection's program cache to remove
    /// entries dependent on this vertex shader.
    fn progress(&mut self, connection: &mut Connection) -> Progress<Self::Output> {
        let (gl, state) = unsafe { connection.unpack_mut() };
        // NOTE(review): `into_value` is unsafe; presumably it must be called at most once per id
        // created with `JsId::from_value` — confirm against `JsId`'s contract.
        let value = unsafe { JsId::into_value(self.id) };

        state
            .program_cache_mut()
            .remove_vertex_shader_dependent(self.id);
        // The id was produced from a `WebGlShader` in the allocate task, which is why the
        // unchecked cast back to the shader type is expected to hold.
        gl.delete_shader(Some(&value.unchecked_into()));

        Progress::Finished(())
    }
}

/// Task that deletes a fragment shader's GL object and evicts cached programs depending on it.
struct FragmentShaderDropCommand {
    /// JS-side handle of the shader object to delete.
    id: JsId,
}

unsafe impl GpuTask<Connection> for FragmentShaderDropCommand {
    type Output = ();

    fn context_id(&self) -> ContextId {
        ContextId::Any
    }

    /// Releases the fragment shader identified by `self.id` and finishes immediately.
    ///
    /// Before deleting the GL shader object, this asks the connection's program cache to remove
    /// entries dependent on this fragment shader.
    fn progress(&mut self, connection: &mut Connection) -> Progress<Self::Output> {
        let (gl, state) = unsafe { connection.unpack_mut() };
        // NOTE(review): `into_value` is unsafe; presumably it must be called at most once per id
        // created with `JsId::from_value` — confirm against `JsId`'s contract.
        let value = unsafe { JsId::into_value(self.id) };

        state
            .program_cache_mut()
            .remove_fragment_shader_dependent(self.id);
        // The id was produced from a `WebGlShader` in the allocate task, which is why the
        // unchecked cast back to the shader type is expected to hold.
        gl.delete_shader(Some(&value.unchecked_into()));

        Progress::Finished(())
    }
}
