use crate::jpeg::byteorder::ReadBytesExt;
use crate::error::{Error, Result};
use crate::jpeg::marker::Marker;
use crate::jpeg::marker::Marker::*;
use std::io::{self, Read};
use std::ops::RangeInclusive;

/// A width/height pair.
///
/// Used for the image size, per-component sample sizes, the MCU grid (count of MCUs
/// across/down) and per-component block grids, depending on where it is stored.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct Dimensions {
    /// Horizontal extent.
    pub width: u16,
    /// Vertical extent.
    pub height: u16,
}

/// The coding process selected by a start-of-frame (SOFn) marker.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum CodingProcess {
    /// DCT-based sequential coding (baseline SOF0 and extended sequential frames).
    DctSequential,
    /// DCT-based progressive coding.
    DctProgressive,
    /// Lossless coding.
    Lossless,
}

/// Parameters decoded from a frame header (SOFn marker segment).
#[derive(Clone, Debug)]
pub struct FrameInfo {
    /// `true` only for frames introduced by SOF0.
    pub is_baseline: bool,
    /// `true` for the differential SOF variants (SOF5-7 and SOF13-15).
    pub is_differential: bool,
    /// Coding process implied by the SOF marker.
    pub coding_process: CodingProcess,
    /// Sample precision in bits, as stored in the header.
    pub precision: u8,

    /// Image width and height as stored in the frame header.
    pub image_size: Dimensions,
    /// Number of MCUs needed horizontally and vertically to cover the image.
    pub mcu_size: Dimensions,
    /// Frame components in header order.
    pub components: Vec<Component>,
}

/// Parameters decoded from a scan header (SOS marker segment).
///
/// The three `*_indices` vectors are parallel: entry `i` of each describes the same
/// scan component.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ScanInfo {
    /// Indices into the frame's component list, in frame order.
    pub component_indices: Vec<usize>,
    /// DC entropy-coding table selector for each scan component.
    pub dc_table_indices: Vec<usize>,
    /// AC entropy-coding table selector for each scan component.
    pub ac_table_indices: Vec<usize>,

    /// Spectral selection `Ss..=Se` as read from the scan header.
    pub spectral_selection: RangeInclusive<u8>,
    /// Successive approximation bit position high (`Ah`).
    pub successive_approximation_high: u8,
    /// Successive approximation bit position low (`Al`).
    pub successive_approximation_low: u8,
}

/// One image component as declared in the frame header.
#[derive(Debug, Clone)]
pub struct Component {
    /// Component identifier; `parse_sof` rejects duplicates within a frame.
    pub identifier: u8,

    /// Horizontal sampling factor, validated to 1..=4 by `parse_sof`.
    pub horizontal_sampling_factor: u8,
    /// Vertical sampling factor, validated to 1..=4 by `parse_sof`.
    pub vertical_sampling_factor: u8,

    /// Quantization table selector, validated to 0..=3 by `parse_sof`.
    pub quantization_table_index: usize,

    /// Component size in samples after applying its sampling factors.
    pub size: Dimensions,
    /// Component extent in blocks, padded to a whole number of MCUs.
    pub block_size: Dimensions,
}

fn read_length<R: Read>(reader: &mut R, marker: Marker) -> Result<usize> {
    if !marker.has_length() {
        return Err(Error::Format("unexpected empty marker"));
    }

    // length is including itself.
    let length = reader.read_u16_be()? as usize;

    if length < 2 {
        return Err(Error::Format("encountered invalid length"));
    }

    Ok(length - 2)
}

/// Consumes and discards exactly `length` bytes from `reader`.
///
/// Fails with an `UnexpectedEof` I/O error when the reader ends early; any other
/// read error is propagated unchanged.
fn skip_bytes<R: Read>(reader: &mut R, length: usize) -> Result<()> {
    let expected = length as u64;
    let mut limited = reader.by_ref().take(expected);
    let consumed = io::copy(&mut limited, &mut io::sink())?;

    if consumed >= expected {
        Ok(())
    } else {
        Err(Error::Io(io::ErrorKind::UnexpectedEof.into()))
    }
}

/// Skips over the payload of the marker segment introduced by `marker`.
///
/// Markers without a length field have no payload, so nothing is read for them.
/// Otherwise the length field is read and that many payload bytes are discarded.
pub fn skip_marker<R: Read>(reader: &mut R, marker: Marker) -> Result<()> {
    if !marker.has_length() {
        return Ok(());
    }

    let payload = read_length(reader, marker)?;
    skip_bytes(reader, payload)
}

// Section B.2.2
pub fn parse_sof<R: Read>(reader: &mut R, marker: Marker) -> Result<FrameInfo> {
    let length = read_length(reader, marker)?;

    if length <= 6 {
        return Err(Error::Format("invalid length in SOF"));
    }

    let is_baseline = marker == SOF(0);
    let is_differential = match marker {
        SOF(0..=3) | SOF(9..=11) => false,
        SOF(5..=7) | SOF(13..=15) => true,
        _ => panic!(),
    };
    let coding_process = match marker {
        SOF(0) | SOF(1) | SOF(5) | SOF(9) | SOF(13) => CodingProcess::DctSequential,
        SOF(2) | SOF(6) | SOF(10) | SOF(14) => CodingProcess::DctProgressive,
        SOF(3) | SOF(7) | SOF(11) | SOF(15) => CodingProcess::Lossless,
        _ => panic!(),
    };

    let precision = reader.read_u8()?;

    match precision {
        8 => {},
        12 => {
            if is_baseline {
                return Err(Error::Format("12 bit sample precision is not allowed in baseline"));
            }
        },
        _ => {
            if coding_process != CodingProcess::Lossless {
                return Err(Error::Format("invalid precision in frame header"))
            }
        },
    }

    let height = reader.read_u16_be()?;
    let width = reader.read_u16_be()?;

    if width == 0 || height == 0 {
        return Err(Error::Format("zero size in frame header"));
    }

    let component_count = reader.read_u8()?;

    if component_count == 0 {
        return Err(Error::Format("zero component count in frame header"));
    }
    if coding_process == CodingProcess::DctProgressive && component_count > 4 {
        return Err(Error::Format("progressive frame with more than 4 components"));
    }

    if length != 6 + 3 * component_count as usize {
        return Err(Error::Format("invalid length in SOF"));
    }

    let mut components: Vec<Component> = Vec::with_capacity(component_count as usize);

    for _ in 0..component_count {
        let identifier = reader.read_u8()?;

        // Each component's identifier must be unique.
        if components.iter().any(|c| c.identifier == identifier) {
            return Err(Error::Format("duplicate frame component identifier"));
        }

        let byte = reader.read_u8()?;
        let horizontal_sampling_factor = byte >> 4;
        let vertical_sampling_factor = byte & 0x0f;

        if horizontal_sampling_factor == 0 || horizontal_sampling_factor > 4 {
            return Err(Error::Format("invalid horizontal sampling factor"));
        }
        if vertical_sampling_factor == 0 || vertical_sampling_factor > 4 {
            return Err(Error::Format("invalid vertical sampling factor"));
        }

        let quantization_table_index = reader.read_u8()?;

        if quantization_table_index > 3 || (coding_process == CodingProcess::Lossless && quantization_table_index != 0) {
            return Err(Error::Format("invalid quantization table index"));
        }

        components.push(Component {
            identifier,
            horizontal_sampling_factor,
            vertical_sampling_factor,
            quantization_table_index: quantization_table_index as usize,
            size: Dimensions {width: 0, height: 0},
            block_size: Dimensions {width: 0, height: 0},
        });
    }

    let h_max = components.iter().map(|c| c.horizontal_sampling_factor).max().unwrap();
    let v_max = components.iter().map(|c| c.vertical_sampling_factor).max().unwrap();
    let mcu_size = Dimensions {
        width: (width as f32 / (h_max as f32 * 8.0)).ceil() as u16,
        height: (height as f32 / (v_max as f32 * 8.0)).ceil() as u16,
    };

    for component in &mut components {
        component.size.width = (width as f32 * (component.horizontal_sampling_factor as f32 / h_max as f32)).ceil() as u16;
        component.size.height = (height as f32 * (component.vertical_sampling_factor as f32 / v_max as f32)).ceil() as u16;

        component.block_size.width = mcu_size.width * component.horizontal_sampling_factor as u16;
        component.block_size.height = mcu_size.height * component.vertical_sampling_factor as u16;
    }

    Ok(FrameInfo {
        is_baseline,
        is_differential,
        coding_process,
        precision,
        image_size: Dimensions {width, height},
        mcu_size,
        components,
    })
}

// Section B.2.3
pub fn parse_sos<R: Read>(reader: &mut R, frame: &FrameInfo) -> Result<ScanInfo> {
    let length = read_length(reader, SOS)?;
    if 0 == length {
        return Err(Error::Format("zero length in SOS"));
    }

    let component_count = reader.read_u8()?;

    if component_count == 0 || component_count > 4 {
        return Err(Error::Format("invalid component count in scan header"));
    }

    if length != 4 + 2 * component_count as usize {
        return Err(Error::Format("invalid length in SOS"));
    }

    let mut component_indices = Vec::with_capacity(component_count as usize);
    let mut dc_table_indices = Vec::with_capacity(component_count as usize);
    let mut ac_table_indices = Vec::with_capacity(component_count as usize);

    for _ in 0..component_count {
        let identifier = reader.read_u8()?;

        let component_index = match frame.components.iter().position(|c| c.identifier == identifier) {
            Some(value) => value,
            None => return Err(Error::Format("scan component identifier does not match any of the component identifiers defined in the frame")),
        };

        // Each of the scan's components must be unique.
        if component_indices.contains(&component_index) {
            return Err(Error::Format("duplicate scan component identifier"));
        }

        // "... the ordering in the scan header shall follow the ordering in the frame header."
        if component_index < *component_indices.iter().max().unwrap_or(&0) {
            return Err(Error::Format("the scan component order does not follow the order in the frame header"));
        }

        let byte = reader.read_u8()?;
        let dc_table_index = byte >> 4;
        let ac_table_index = byte & 0x0f;

        if dc_table_index > 3 || (frame.is_baseline && dc_table_index > 1) {
            return Err(Error::Format("invalid dc table index"));
        }
        if ac_table_index > 3 || (frame.is_baseline && ac_table_index > 1) {
            return Err(Error::Format("invalid ac table index"));
        }

        component_indices.push(component_index);
        dc_table_indices.push(dc_table_index as usize);
        ac_table_indices.push(ac_table_index as usize);
    }

    let blocks_per_mcu = component_indices.iter().map(|&i| {
        frame.components[i].horizontal_sampling_factor as u32 * frame.components[i].vertical_sampling_factor as u32
    }).fold(0, ::std::ops::Add::add);

    if component_count > 1 && blocks_per_mcu > 10 {
        return Err(Error::Format("scan with more than one component and more than 10 blocks per MCU"));
    }

    let spectral_selection_start = reader.read_u8()?;
    let spectral_selection_end = reader.read_u8()?;

    let byte = reader.read_u8()?;
    let successive_approximation_high = byte >> 4;
    let successive_approximation_low = byte & 0x0f;

    Ok(ScanInfo {
        component_indices,
        dc_table_indices,
        ac_table_indices,
        spectral_selection: spectral_selection_start..=spectral_selection_end,
        successive_approximation_high,
        successive_approximation_low,
    })
}
