//! Managing [GPU shaders](https://en.wikipedia.org/wiki/Shader).

use crate::cx::*;
use std::collections::HashMap;
use zaplib_shader_compiler::ty::Ty;
use zaplib_shader_compiler::{Decl, ShaderAst};

/// Pointer to a [`CxShader`]. Get this pointer by calling [`Cx::get_shader`].
///
/// The [`Shader::location_hash`] is required, and is useful to disambiguate e.g.
/// shader draw calls that should not be bundled together, because otherwise
/// render order becomes unpredictable. This could cause draw calls to be pulled
/// "up" in the render tree, i.e. getting called earlier than you would expect
/// them to. For example, say you have a "button" that draws a background using
/// "DrawColor", and a label using "DrawText". Now let's say that you're first
/// drawing a paragraph also using "DrawText", and then you want to draw a
/// button. So you write your calls in this order
/// 1. Paragraph (DrawText)
/// 2. Button (calls DrawColor and then DrawText)
///
/// But the draw calls will look like this:
///
/// 1. DrawText (Paragraph and Button)
/// 2. DrawColor (Button)
///
/// So the button label will be hidden behind its background! You could wrap the
/// button calls in a [`View`], so that you explicitly get a new level in the draw
/// tree:
///
/// 1. DrawText (Paragraph)
/// 2. View (Button)
///      * DrawColor
///      * DrawText
///
/// But this has the downside that if you want to draw 100 buttons, they will
/// all get their own draw calls, which can get quite slow. Instead, if
/// you tag the draw calls in the button with a [`LocationHash`], then they will
/// still be grouped together if you draw 100 buttons, but separate from the
/// other DrawText call:
///
/// 1. DrawText (Paragraph)
/// 2. DrawColor-Button
/// 3. DrawText-Button
///
/// TODO(JP): Even with [`LocationHash`] draw calls can unexpectedly get grouped.
/// For example, if you have a "button" struct that uses [`crate::DrawColor`] for its
/// background, and a "panel" struct that also uses [`crate::DrawColor`] for its
/// background, then if you first draw a button and then a panel that also
/// contains a button, then the draw calls for the buttons will still get
/// grouped, which means that the second button will get overwritten by the
/// panel! In other words, you make your calls in this order:
///
/// 1. Button 1
/// 2. Panel
/// 3. Button 2
///
/// But your calls will get grouped in this order:
///
/// 1. DrawColor-Buttons (1 and 2)
/// 2. DrawColor-Panel
///
/// Which is undesirable. You can work around this by explicitly making a new
/// draw call for the shader using [`Cx::new_draw_call`], or by wrapping the panel
/// in a [`View`]. But I wonder if there is a clearer way of structuring this.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct Shader {
    /// Index of the shader in `Cx::shaders`, as handed out by [`Cx::get_shader`].
    pub shader_id: usize,
    /// Disambiguates draw calls of the same shader; see the type-level docs above.
    pub location_hash: LocationHash,
}

/// Contains information about which instances, geometries, uniforms, textures
/// and so on a [`CxShader`] has. That information can then be used to modify a
/// [`Shader`] or [`DrawCall`].
#[derive(Clone, Debug, Default)]
pub struct CxShaderMapping {
    /// Slot offsets of the special "rect_pos" and "rect_size" instance fields.
    /// See [`RectInstanceProps`].
    pub rect_instance_props: RectInstanceProps,
    /// Laid-out view of the user-level uniforms, for reading/editing them.
    pub user_uniform_props: UniformProps,
    /// Laid-out view of the instance properties, for reading/editing them.
    pub instance_props: InstanceProps,
    /// Laid-out view of the geometry properties, for reading/editing them.
    pub geometry_props: InstanceProps,
    /// Unprocessed definitions of every texture.
    pub textures: Vec<PropDef>,
    /// Unprocessed definitions of every geometry property.
    pub geometries: Vec<PropDef>,
    /// Unprocessed definitions of every instance property.
    pub instances: Vec<PropDef>,
    /// Unprocessed definitions of every user-level uniform.
    pub user_uniforms: Vec<PropDef>,
    /// Unprocessed definitions of the framework-level uniforms set per [`DrawCall`].
    pub draw_uniforms: Vec<PropDef>,
    /// Unprocessed definitions of the framework-level uniforms set per [`View`].
    pub view_uniforms: Vec<PropDef>,
    /// Unprocessed definitions of the framework-level uniforms set per [`Pass`].
    pub pass_uniforms: Vec<PropDef>,
}

impl CxShaderMapping {
    pub fn from_shader_ast(shader_ast: ShaderAst, metal_uniform_packing: bool) -> Self {
        let mut instances = Vec::new();
        let mut geometries = Vec::new();
        let mut user_uniforms = Vec::new();
        let mut draw_uniforms = Vec::new();
        let mut view_uniforms = Vec::new();
        let mut pass_uniforms = Vec::new();
        let mut textures = Vec::new();
        for decl in shader_ast.decls {
            match decl {
                Decl::Geometry(decl) => {
                    let prop_def = PropDef {
                        name: decl.ident.to_string(),
                        name_hash: StringHash::new(&decl.ident.to_string()).hash,
                        ty: decl.ty_expr.ty.borrow().clone().unwrap(),
                    };
                    geometries.push(prop_def);
                }
                Decl::Instance(decl) => {
                    let prop_def = PropDef {
                        name: decl.ident.to_string(),
                        name_hash: StringHash::new(&decl.ident.to_string()).hash,
                        ty: decl.ty_expr.ty.borrow().clone().unwrap(),
                    };
                    instances.push(prop_def);
                }
                Decl::Uniform(decl) => {
                    let prop_def = PropDef {
                        name: decl.ident.to_string(),
                        name_hash: StringHash::new(&decl.ident.to_string()).hash,
                        ty: decl.ty_expr.ty.borrow().clone().unwrap(),
                    };
                    match decl.block_ident {
                        Some(bi) if bi.with(|string| string == "draw") => {
                            draw_uniforms.push(prop_def);
                        }
                        Some(bi) if bi.with(|string| string == "view") => {
                            view_uniforms.push(prop_def);
                        }
                        Some(bi) if bi.with(|string| string == "pass") => {
                            pass_uniforms.push(prop_def);
                        }
                        None => {
                            user_uniforms.push(prop_def);
                        }
                        _ => (),
                    }
                }
                Decl::Texture(decl) => {
                    let prop_def = PropDef {
                        name: decl.ident.to_string(),
                        name_hash: StringHash::new(&decl.ident.to_string()).hash,
                        ty: decl.ty_expr.ty.borrow().clone().unwrap(),
                    };
                    textures.push(prop_def);
                }
                _ => (),
            }
        }

        CxShaderMapping {
            rect_instance_props: RectInstanceProps::construct(&instances),
            user_uniform_props: UniformProps::construct(&user_uniforms, metal_uniform_packing),
            instance_props: InstanceProps::construct(&instances),
            geometry_props: InstanceProps::construct(&geometries),
            textures,
            instances,
            geometries,
            pass_uniforms,
            view_uniforms,
            draw_uniforms,
            user_uniforms,
        }
    }
}

/// The raw definition of a single input property of a [`Shader`]: its name,
/// its type, and the hash of its name.
#[derive(Clone, Debug, Hash, PartialEq)]
pub struct PropDef {
    /// Name of the property as declared in the shader source.
    pub name: String,
    /// Resolved type of the property.
    pub ty: Ty,
    /// `StringHash` of `name`.
    pub name_hash: u64,
}

/// Contains information about the special "rect_pos" and "rect_size" fields.
/// These fields describe the typical rectangles drawn in [`crate::DrawQuad`]. It is
/// useful to have generalized access to them, so we can e.g. move a whole bunch
/// of rectangles at the same time, e.g. for alignment in [`Turtle`].
/// TODO(JP): We might want to consider instead doing bulk moves using [`DrawCall`]-
/// or [`View`]-level uniforms.
#[derive(Debug, Default, Clone)]
pub struct RectInstanceProps {
    /// Slot offset of the "rect_pos" instance property, if the shader has one.
    pub rect_pos: Option<usize>,
    /// Slot offset of the "rect_size" instance property, if the shader has one.
    pub rect_size: Option<usize>,
}
impl RectInstanceProps {
    /// Finds the slot offsets of "rect_pos" and "rect_size" in `instances`.
    ///
    /// The offset of a property is the sum of `ty.size()` over all properties
    /// before it, which matches the offsets computed by [`InstanceProps::construct`].
    /// If a name occurs more than once, the last occurrence wins.
    pub fn construct(instances: &[PropDef]) -> RectInstanceProps {
        let mut rect_pos = None;
        let mut rect_size = None;
        let mut slot = 0;
        for inst in instances {
            match inst.name.as_str() {
                "rect_pos" => rect_pos = Some(slot),
                "rect_size" => rect_size = Some(slot),
                _ => (),
            }
            slot += inst.ty.size();
        }
        RectInstanceProps { rect_pos, rect_size }
    }
}

/// A single "instance" GPU input of a [`Shader`], together with where it lives
/// in the packed property data.
///
/// TODO(JP): Merge this into [`NamedProp`].
#[derive(Clone, Debug)]
pub struct InstanceProp {
    /// Name of the property.
    pub name: String,
    /// Type of the property.
    pub ty: Ty,
    /// Slot offset: the summed `slots` of all properties before this one.
    pub offset: usize,
    /// Number of slots this property takes up (`ty.size()`).
    pub slots: usize,
}

/// All "instance" GPU inputs of a [`Shader`]. The same layout is also used for
/// geometry inputs.
///
/// TODO(JP): Merge this into [`NamedProps`].
#[derive(Clone, Debug, Default)]
pub struct InstanceProps {
    /// Maps a property's name hash to its index in `props`.
    pub prop_map: HashMap<u64, usize>,
    /// The laid-out properties, in the order they were passed to `construct`.
    pub props: Vec<InstanceProp>,
    /// Sum of the `slots` of all properties.
    pub total_slots: usize,
}

/// A single "uniform" GPU input of a [`Shader`], together with where it lives
/// in the packed uniform data.
///
/// TODO(JP): Merge this into [`NamedProp`].
#[derive(Clone, Debug)]
pub struct UniformProp {
    /// Name of the uniform.
    pub name: String,
    /// Type of the uniform.
    pub ty: Ty,
    /// Slot offset of the uniform, after alignment padding.
    pub offset: usize,
    /// Number of slots the uniform's value occupies (`ty.size()`, unpadded).
    pub slots: usize,
}

/// All "uniform" GPU inputs of a [`Shader`].
///
/// TODO(JP): Merge this into [`NamedProps`].
#[derive(Clone, Debug, Default)]
pub struct UniformProps {
    /// Maps a uniform's name hash to its index in `props`.
    pub prop_map: HashMap<u64, usize>,
    /// The laid-out uniforms, in the order they were passed to `construct`.
    pub props: Vec<UniformProp>,
    /// Total slot count, padded up to a multiple of 4.
    pub total_slots: usize,
}

/// Generic description of one [`Shader`] input, whatever its kind
/// (instance/uniform/geometry).
#[derive(Clone, Debug)]
pub struct NamedProp {
    /// Name of the input.
    pub name: String,
    /// Type of the input.
    pub ty: Ty,
    /// Slot offset of the input within its list.
    pub offset: usize,
    /// Number of slots the input takes up (`ty.size()`).
    pub slots: usize,
}

/// Generic description of a list of [`Shader`] inputs, whatever their kind
/// (instance/uniform/geometry).
#[derive(Clone, Debug, Default)]
pub struct NamedProps {
    /// The laid-out inputs.
    pub props: Vec<NamedProp>,
    /// Sum of the `slots` of all inputs.
    pub total_slots: usize,
}

impl NamedProps {
    /// Lays out `in_props` back to back, without any padding.
    ///
    /// Each property's `offset` is the sum of the `slots` (`ty.size()`) of all
    /// properties before it. `total_slots` is the sum over all of them.
    pub fn construct(in_props: &[PropDef]) -> NamedProps {
        let mut offset = 0;
        let mut out_props = Vec::with_capacity(in_props.len());
        for prop in in_props {
            let slots = prop.ty.size();
            out_props.push(NamedProp { ty: prop.ty.clone(), name: prop.name.clone(), offset, slots });
            offset += slots;
        }
        NamedProps { props: out_props, total_slots: offset }
    }
}

impl InstanceProps {
    /// Lays out `in_props` back to back, without any padding, and indexes them
    /// by name hash.
    ///
    /// Each property's `offset` is the sum of the `slots` (`ty.size()`) of all
    /// properties before it, and `total_slots` is the sum over all of them.
    /// If two properties share a name hash, `prop_map` points at the later one.
    pub fn construct(in_props: &[PropDef]) -> InstanceProps {
        let mut offset = 0;
        let mut out_props = Vec::with_capacity(in_props.len());
        let mut prop_map = HashMap::with_capacity(in_props.len());
        for prop in in_props {
            let slots = prop.ty.size();
            prop_map.insert(prop.name_hash, out_props.len());
            out_props.push(InstanceProp { ty: prop.ty.clone(), name: prop.name.clone(), offset, slots });
            offset += slots;
        }
        InstanceProps { prop_map, props: out_props, total_slots: offset }
    }
}

impl UniformProps {
    /// Lays out `in_props` as uniforms, in rows of 4 slots.
    ///
    /// A uniform is never started part-way through a row if it would not fit
    /// in the rest of that row. In that case it is moved to the start of the
    /// next row. When `metal_uniform_packing` is set, 3-slot values take up 4
    /// slots, while their `slots` field still reports 3. `total_slots` is
    /// padded up to a multiple of 4.
    pub fn construct(in_props: &[PropDef], metal_uniform_packing: bool) -> UniformProps {
        let mut out_props = Vec::with_capacity(in_props.len());
        let mut prop_map = HashMap::with_capacity(in_props.len());
        let mut offset = 0;

        for prop in in_props {
            let slots = prop.ty.size();

            // metal+webgl
            let aligned_slots = if metal_uniform_packing && slots == 3 { 4 } else { slots };
            let used_in_row = offset & 3;
            // Jump to the next row only if we are part-way through one and the
            // value does not fit in the remainder. When `offset` is already
            // aligned, `4 - used_in_row` would be 4, which would insert a whole
            // empty row before every value wider than 4 slots (e.g. matrices).
            if used_in_row != 0 && used_in_row + aligned_slots > 4 {
                offset += 4 - used_in_row;
            }

            prop_map.insert(prop.name_hash, out_props.len());
            out_props.push(UniformProp { ty: prop.ty.clone(), name: prop.name.clone(), offset, slots });
            offset += aligned_slots;
        }
        // Pad the total up to a whole number of rows.
        if offset & 3 != 0 {
            offset += 4 - (offset & 3);
        }
        UniformProps { prop_map, props: out_props, total_slots: offset }
    }

    /// Returns the slot offset of the first uniform named "zbias", if any.
    pub fn find_zbias_uniform_prop(&self) -> Option<usize> {
        self.props.iter().find(|prop| prop.name == "zbias").map(|prop| prop.offset)
    }
}

/// The actual shader information, which gets stored on [`Cx`]. Once the shader
/// is compiled, its [`ShaderAst`] is removed and the [`CxPlatformShader`] (the
/// platform-specific part of the compiled shader) gets set.
#[derive(Clone, Default)]
pub struct CxShader {
    /// Name the shader was registered with (see [`Cx::register_shader`]).
    pub name: String,
    /// Geometry passed at registration, used for instancing (e.g. a quad or a cube).
    pub default_geometry: Option<Geometry>,
    /// Platform-specific compiled shader; `None` until the shader is compiled.
    pub(crate) platform: Option<CxPlatformShader>,
    /// Layout of the shader's inputs, derived from its AST.
    pub mapping: CxShaderMapping,
    /// Parsed shader source; removed once the shader has been compiled.
    pub shader_ast: Option<ShaderAst>,
}

impl Cx {
    /// Get an individual [`Shader`], by looking it up using its name. For more
    /// information on what [`LocationHash`] is used for here, see [`Shader`].
    ///
    /// # Panics
    ///
    /// Panics if no shader with this name has been registered through
    /// [`Cx::register_shader`].
    pub fn get_shader(&self, name_hash: StringHash, location_hash: LocationHash) -> Shader {
        match self.shader_id_by_hashed_name.get(&name_hash.hash) {
            Some(&shader_id) => Shader { shader_id, location_hash },
            None => panic!("Shader not found: {}", name_hash.string),
        }
    }

    /// Register a new shader. Has to be done during program initialization (in
    /// "style" call), since all shaders are compiled once after that.
    ///
    /// Pass in a [`Geometry`] which gets used for instancing (e.g. a quad or a
    /// cube).
    ///
    /// The different [`CodeFragment`]s are appended together (but preserving their
    /// filename/line/column information for error messages). They are split out
    /// into `base_code_fragments` and `main_code_fragment` purely for
    /// convenience. (We could instead have used a single [`slice`] but they are
    /// annoying to get through concatenation..)
    ///
    /// TODO(JP): Would be good to instead compile shaders beforehand, ie. during
    /// compile time. Should look into that at some point.
    ///
    /// # Panics
    ///
    /// Panics if a shader with the same name was already registered, or if
    /// building the [`CxShaderMapping`] from the generated AST panics.
    pub fn register_shader(
        &mut self,
        shader_name: &str,
        default_geometry: Option<Geometry>,
        base_code_fragments: &[CodeFragment],
        main_code_fragment: &CodeFragment,
    ) {
        // Reject duplicates first, so we fail fast instead of generating an AST
        // that would be thrown away anyway.
        let hashed_shader_name = StringHash::new(&shader_name).hash;
        if self.shader_id_by_hashed_name.contains_key(&hashed_shader_name) {
            panic!("Duplicate shader: {}", shader_name);
        }

        let code_fragments: Vec<&CodeFragment> = base_code_fragments.iter().chain([main_code_fragment]).collect();
        let shader_ast = self.shader_ast_generator.generate_shader_ast(&shader_name, code_fragments);
        // Build the mapping before updating any bookkeeping. Then a panic here
        // cannot leave a name registered that points at a missing shader.
        let mapping = CxShaderMapping::from_shader_ast(shader_ast.clone(), Self::use_metal_uniform_packing());

        let shader_id = self.shaders.len();
        self.shader_id_by_hashed_name.insert(hashed_shader_name, shader_id);
        self.shaders.push(CxShader {
            name: shader_name.to_string(),
            default_geometry,
            mapping,
            platform: None,
            shader_ast: Some(shader_ast),
        });
    }
}
