use super::{sdf, Material, Opcode, Program, SignedDistance};
use macaw::{Quat, Vec3, Vec4};

/// Evaluator for compiled SDF [`Program`]s.
///
/// The type holds no data: `SD` only selects which [`SignedDistance`] implementation an
/// evaluation produces. All mutable per-evaluation state lives in an [`InterpreterContext`].
#[derive(Copy, Clone)]
pub struct Interpreter<SD: SignedDistance> {
    marker: core::marker::PhantomData<SD>,
}

impl<SD: SignedDistance> Default for Interpreter<SD> {
    /// Builds the (zero-sized) interpreter.
    ///
    /// Implemented by hand rather than derived so that `SD` is not required to be `Default`.
    fn default() -> Self {
        Self {
            marker: std::marker::PhantomData,
        }
    }
}

/// Reusable scratch state for evaluating one [`Program`] with an [`Interpreter`].
///
/// Obtain one from `Interpreter::new_context` and hand it to `Interpreter::interpret` for each
/// sample position. Each call rewinds the cursors but keeps the fixed-size storage.
pub struct InterpreterContext<'a, SD: SignedDistance> {
    /// Bytecode plus constant pool being evaluated.
    program: &'a Program,
    /// Read cursor into `program.constants`.
    constant_idx: usize,
    /// Operand stack of signed distances; `stack_ptr` counts the live entries.
    stack: [SD; 64], // TODO: SmallVec ?
    stack_ptr: usize,
    /// Saved sample positions for nested transforms; `position_stack_ptr` counts the live entries.
    position_stack: [Vec3; 64], // TODO: SmallVec ?
    position_stack_ptr: usize,
}

impl<'a, SD: SignedDistance + Copy + Clone> InterpreterContext<'a, SD> {
    /// Creates a context for `program` with empty stacks and the constant cursor at the start.
    ///
    /// Unused stack slots are pre-filled with `SignedDistance::infinity()` / `Vec3::ZERO`;
    /// they are only read after having been pushed.
    fn new(program: &'a Program) -> Self {
        Self {
            program,
            stack: [SignedDistance::infinity(); 64],
            stack_ptr: 0,
            constant_idx: 0,
            position_stack: [Vec3::ZERO; 64],
            position_stack_ptr: 0,
        }
    }

    /// Rewinds both stacks and the constant cursor so the program can be evaluated again.
    ///
    /// Slot contents are left in place; they become unreachable once the pointers are zeroed.
    fn reset(&mut self) {
        self.stack_ptr = 0;
        self.position_stack_ptr = 0;
        self.constant_idx = 0;
    }

    /// Returns the next `count` values of the constant pool and advances the cursor past them.
    ///
    /// # Panics
    ///
    /// If fewer than `count` constants remain.
    fn next_constants(&mut self, count: usize) -> &'a [f32] {
        // Copy the `&'a Program` out of `self` so the returned slice borrows the program,
        // not `self`.
        let program: &'a Program = self.program;
        let start = self.constant_idx;
        let end = start + count;
        let constants = &program.constants[start..end];
        self.constant_idx = end;
        constants
    }

    /// Reads the next constant as a scalar.
    fn float32(&mut self) -> f32 {
        self.next_constants(1)[0]
    }

    /// Reads the next three constants as a [`Vec3`].
    fn vec3(&mut self) -> Vec3 {
        Vec3::from_slice(self.next_constants(3))
    }

    /// Reads the next four constants as a [`Vec4`].
    fn vec4(&mut self) -> Vec4 {
        Vec4::from_slice(self.next_constants(4))
    }

    /// Reads the next four constants as a [`Quat`].
    fn quat(&mut self) -> Quat {
        Quat::from_slice(self.next_constants(4))
    }

    /// Reads the next three constants as an RGB triple and converts it into a [`Material`].
    fn material(&mut self) -> Material {
        self.vec3().into()
    }

    /// Pushes a signed distance onto the operand stack.
    ///
    /// # Panics
    ///
    /// If the stack already holds 64 entries.
    fn push_sd(&mut self, v: SD) {
        self.stack[self.stack_ptr] = v;
        self.stack_ptr += 1;
    }

    /// Pops the top signed distance, or returns `None` if the operand stack is empty.
    fn pop_sd(&mut self) -> Option<SD> {
        // `checked_sub` keeps an empty pop from underflowing `stack_ptr`, which would panic in
        // debug builds and, in release builds, wrap the pointer so later pushes panic.
        let top = self.stack_ptr.checked_sub(1)?;
        self.stack_ptr = top;
        self.stack.get(top).copied()
    }

    /// Saves a sample position so a later pop can restore it.
    ///
    /// # Panics
    ///
    /// If the position stack already holds 64 entries.
    fn push_position(&mut self, pos: Vec3) {
        self.position_stack[self.position_stack_ptr] = pos;
        self.position_stack_ptr += 1;
    }

    /// Pops the most recently saved position, or returns `None` if none is saved.
    fn pop_position(&mut self) -> Option<Vec3> {
        // Same underflow guard as `pop_sd`.
        let top = self.position_stack_ptr.checked_sub(1)?;
        self.position_stack_ptr = top;
        self.position_stack.get(top).copied()
    }

    /// Returns `true` if the operand stack is non-empty and its top has a finite distance.
    // See comment at the end of the `interpret` function below.
    #[allow(dead_code)]
    fn top_is_finite(&self) -> bool {
        match self
            .stack_ptr
            .checked_sub(1)
            .and_then(|top| self.stack.get(top))
        {
            Some(value) => value.is_distance_finite(),
            None => false,
        }
    }
}

impl<SD: SignedDistance + Copy + Clone> Interpreter<SD> {
    /// Creates a reusable evaluation context for `program`.
    ///
    /// The returned context borrows `program`. Pass it to [`Interpreter::interpret`] once per
    /// sample position instead of building a new context (with its two 64-entry stacks) each
    /// time.
    pub fn new_context(program: &Program) -> InterpreterContext<'_, SD> {
        InterpreterContext::<SD>::new(program)
    }

    /// Evaluates the context's program at `position` and returns the top of the distance stack.
    ///
    /// The program is a stack machine:
    /// - Primitive opcodes read their parameters from the constant pool and push one distance.
    /// - Binary operators pop two distances and push the combined result.
    /// - `Push*` / `Pop*` opcodes save, transform and restore the sample position.
    ///
    /// Constants are consumed strictly sequentially, so the order of the `ctx.float32()` /
    /// `ctx.vec3()` / ... calls below defines the constant layout expected for each opcode.
    /// Evaluation stops at the first `End` opcode or at the end of the opcode list.
    ///
    /// `ctx` is reset first, so one context can be reused across calls.
    ///
    /// # Panics
    ///
    /// If the opcode stream is malformed: an operator or `Pop*` opcode finds its stack empty,
    /// a stack grows past 64 entries, or an opcode reads past the end of the constant pool.
    pub fn interpret(ctx: &mut InterpreterContext<'_, SD>, position: Vec3) -> Option<SD> {
        use Opcode::*;

        // The point being sampled, as transformed by any active `Push*` opcodes.
        let mut current_position = position;

        // Reuse the caller's context: rewind its stacks and constant cursor.
        ctx.reset();

        for opcode in &ctx.program.opcodes {
            match opcode {
                // --- Primitives: read constants, push one distance. ---
                // Constants: vec4.
                Plane => {
                    let sd = sdf::sd_plane(current_position, ctx.vec4());
                    ctx.push_sd(sd);
                }
                // Constants: vec3, then float32 (call arguments evaluate left to right).
                Sphere => {
                    let sd = sdf::sd_sphere(current_position, ctx.vec3(), ctx.float32());
                    ctx.push_sd(sd);
                }
                // Constants: vec3, vec3, then float32.
                Capsule => {
                    let sd =
                        sdf::sd_capsule(current_position, &[ctx.vec3(), ctx.vec3()], ctx.float32());
                    ctx.push_sd(sd);
                }
                // Constants: three float32s.
                RoundedCylinder => {
                    let sd = sdf::sd_rounded_cylinder(
                        current_position,
                        ctx.float32(),
                        ctx.float32(),
                        ctx.float32(),
                    );
                    ctx.push_sd(sd);
                }
                // Constants: p0 (vec3), r0, p1 (vec3), r1 — interleaved, not grouped.
                TaperedCapsule => {
                    let p0 = ctx.vec3();
                    let r0 = ctx.float32();
                    let p1 = ctx.vec3();
                    let r1 = ctx.float32();
                    let sd = sdf::sd_tapered_capsule(current_position, &[p0, p1], [r0, r1]);
                    ctx.push_sd(sd);
                }
                // Constants: r, h.
                Cone => {
                    let r = ctx.float32();
                    let h = ctx.float32();
                    let sd = sdf::sd_cone(current_position, r, h);
                    ctx.push_sd(sd);
                }
                // Constants: half_size (vec3), radius.
                RoundedBox => {
                    let half_size = ctx.vec3();
                    let radius = ctx.float32();
                    let sd = sdf::sd_rounded_box(current_position, half_size, radius);
                    ctx.push_sd(sd);
                }
                // Constants: big_r, small_r.
                Torus => {
                    let big_r = ctx.float32();
                    let small_r = ctx.float32();
                    ctx.push_sd(sdf::sd_torus(current_position, big_r, small_r));
                }
                // Constants: big_r, small_r, then the (sin, cos) pair of the half angle.
                TorusSector => {
                    let big_r = ctx.float32();
                    let small_r = ctx.float32();
                    let sin_cos_half_angle = (ctx.float32(), ctx.float32());
                    ctx.push_sd(sdf::sd_torus_sector(
                        current_position,
                        big_r,
                        small_r,
                        sin_cos_half_angle,
                    ));
                }
                // Constants: lower_sagitta, upper_sagitta, chord.
                BiconvexLens => {
                    let lower_sagitta = ctx.float32();
                    let upper_sagitta = ctx.float32();
                    let chord = ctx.float32();
                    let sd = sdf::sd_biconvex_lens(
                        current_position,
                        lower_sagitta,
                        upper_sagitta,
                        chord,
                    );
                    ctx.push_sd(sd);
                }
                // Replaces the top distance with a copy carrying the material read from the
                // next three constants.
                Material => {
                    let sd = ctx.pop_sd().unwrap();
                    let material = ctx.material();
                    ctx.push_sd(sdf::sd_material(sd, material));
                }
                // --- Binary operators: `sd1` is the most recently pushed operand, `sd2` the
                // one beneath it. Smooth variants read their width constant after both pops. ---
                Union => {
                    let sd1 = ctx.pop_sd().unwrap();
                    let sd2 = ctx.pop_sd().unwrap();
                    ctx.push_sd(sdf::sd_op_union(sd1, sd2));
                }
                UnionSmooth => {
                    let sd1 = ctx.pop_sd().unwrap();
                    let sd2 = ctx.pop_sd().unwrap();
                    let width = ctx.float32();
                    ctx.push_sd(sdf::sd_op_union_smooth(sd1, sd2, width));
                }
                // NOTE(review): operand order matters here; which of sd1/sd2 is subtracted
                // from the other is defined by `sdf::sd_op_subtract` — confirm there.
                Subtract => {
                    let sd1 = ctx.pop_sd().unwrap();
                    let sd2 = ctx.pop_sd().unwrap();
                    ctx.push_sd(sdf::sd_op_subtract(sd1, sd2));
                }
                SubtractSmooth => {
                    let sd1 = ctx.pop_sd().unwrap();
                    let sd2 = ctx.pop_sd().unwrap();
                    let width = ctx.float32();
                    ctx.push_sd(sdf::sd_op_subtract_smooth(sd1, sd2, width));
                }
                Intersect => {
                    let sd1 = ctx.pop_sd().unwrap();
                    let sd2 = ctx.pop_sd().unwrap();
                    ctx.push_sd(sdf::sd_op_intersect(sd1, sd2));
                }
                IntersectSmooth => {
                    let sd1 = ctx.pop_sd().unwrap();
                    let sd2 = ctx.pop_sd().unwrap();
                    let width = ctx.float32();
                    ctx.push_sd(sdf::sd_op_intersect_smooth(sd1, sd2, width));
                }
                // --- Transforms: save the current position, then transform it for the
                // opcodes that follow; a later pop restores the saved position. ---
                PushTranslation => {
                    let translation = ctx.vec3();
                    ctx.push_position(current_position);
                    current_position += translation;
                }
                // Restores the most recently saved position.
                PopTransform => {
                    current_position = ctx.pop_position().unwrap();
                }
                PushRotation => {
                    let rotation = ctx.quat();
                    ctx.push_position(current_position);
                    current_position = rotation * current_position;
                }
                PushScale => {
                    let scale = ctx.float32();
                    ctx.push_position(current_position);
                    current_position *= scale;
                }
                // Restores the last saved position (presumably from the matching `PushScale`),
                // then multiplies the top distance by the `inv_scale` constant. Presumably
                // inv_scale == 1 / scale to undo the position scaling — TODO confirm with
                // the program builder.
                PopScale => {
                    current_position = ctx.pop_position().unwrap();
                    let inv_scale = ctx.float32();
                    let sd = ctx.pop_sd().unwrap();
                    ctx.push_sd(sd.copy_with_distance(inv_scale * sd.distance()));
                }
                // Explicit terminator: any opcodes after it are ignored.
                End => {
                    break;
                }
            }

            // NaN check for debugging! Don't want the overhead by default, so disabled.
            // if !ctx.top_is_finite() {
            //    panic!("Hit infinity at {:?}", opcode);
            // }
        }

        // The result is whatever distance is left on top of the stack.
        ctx.pop_sd()
    }
}
