use crossbeam_channel::{Receiver, Sender};
use ringbuf::RingBuffer;
use rustfft::{num_complex::Complex, FftPlanner};
use tracing::{debug, warn};

use std::cmp;
use std::env;
use std::f32::consts::PI;
use std::str::FromStr;

/// Hann window for smoothing data before it's sent to the fourier transform.
/// Adapted from window.rs in dsp 0.10.1 - dual licensed MIT/Apache2
/// https://docs.rs/dsp/0.10.1/src/dsp/window.rs.html#151-159
///
/// Returns a vector of `window_length` coefficients. The `width` coefficients starting at
/// index `offset` follow `sin^2(PI * n / (width - 1))` for `n = 0..width`; everything outside
/// that span is zero. If the span runs past `window_length` it is truncated, and if `offset`
/// is at or beyond `window_length` the result is all zeroes.
///
/// A one-sample window (`width == 1`) is defined as a single `1.0`, since the general formula
/// would divide zero by zero and yield NaN.
fn hann(width: usize, offset: usize, window_length: usize) -> Vec<f32> {
    let mut samples = vec![0.0; window_length];
    // saturating_add: a huge width just means "until the end of the buffer".
    let end = cmp::min(offset.saturating_add(width), window_length);
    if offset >= end {
        // Zero width, or the span starts past the end of the buffer: nothing to fill in.
        return samples;
    }
    if width == 1 {
        // Degenerate window: avoid 0/0 below.
        samples[offset] = 1.0;
        return samples;
    }

    let denominator = (width - 1) as f32;
    for (n, sample) in samples[offset..end].iter_mut().enumerate() {
        *sample = (PI * n as f32 / denominator).sin().powi(2);
    }
    samples
}

/// Calculates and returns the ordered top N largest values from the input.
/// The result will have the largest value at index 0, and Nth largest value at index N-1.
/// Duplicate values are not deduplicated in the output.
///
/// If the input holds fewer than `n` values, all of them are returned (still sorted
/// descending). If `n` is 0 the result is empty.
fn topn(vals: &[f32], n: usize) -> Vec<f32> {
    if n == 0 {
        // Nothing can ever be kept; also keeps the "full" check below from looking at an
        // empty vector.
        return Vec::new();
    }

    // Kept sorted in descending order at all times, never longer than n.
    let mut largest: Vec<f32> = Vec::with_capacity(n);
    for &val in vals {
        if largest.len() >= n {
            match largest.last() {
                // val beats the current Nth largest: evict it to make room.
                Some(&smallest) if val > smallest => {
                    largest.pop();
                }
                // Not a top-N value, skip it.
                _ => continue,
            }
        }

        // Insert after every kept value that is >= val, preserving the descending order.
        let insert_at = largest.partition_point(|&kept| kept >= val);
        largest.insert(insert_at, val);
    }
    largest
}

/// Running scale state carried between calls to `scale`.
///
/// We keep track of two scales:
/// - High scale is for the top few values, things like bass lines will typically get captured here.
/// - Low scale is for everything else.
///
/// In practice this is to help with situations where a bass line taking up only a few buckets
/// is technically 2x-5x the amplitude of everything else. In that scenario we want to compress the
/// bass line so that the rest of the amplitude scale isn't being desensitized too much.
/// So we map the "low" (non-TopN) values to [0.0, 0.8] and the "high" values to [0.8, 1.0].
/// In theory a song with lots of equal amplitudes across will end up with the "high" values getting
/// exaggerated, but in practice at any given time there's only a few buckets with peak amplitudes.
struct ScaleState {
    /// Decaying maximum of the Nth largest value (N = HIGH_BUCKET_COUNT) seen in recent frames.
    /// Values at or below this are treated as the "low" range.
    recent_max_low: f32,
    /// Decaying maximum of the single largest value seen in recent frames.
    /// `scale` keeps this at or above `recent_max_low`.
    recent_max_high: f32,
}

/// Number of buckets to be counted as part of the "high" range and compressed into [0.8, 1.0].
/// Higher = more buckets getting forced into the "high" group to be compressed.
const HIGH_BUCKET_COUNT: usize = 10;
/// The dividing point between "high" values and "low" values.
/// So for example a value of 0.8 means the topN values are mapped to [0.8, 1.0] while all other values
/// are mapped to [0.0, 0.8].
/// Higher = more compression of the "high" values into a smaller range.
const HIGH_BUCKET_COMPRESSION: f32 = 0.8;
/// Multiplier applied to `recent_max_high` on every call to `scale` (i.e. once per frame) to decay it.
/// Needs to be slow enough to avoid excessive fade-in in the voiceprint.
/// High range decay is faster than the low range decay so that things recover faster when a
/// bass line goes away.
const RECENT_MAX_HIGH_MULTIPLIER: f32 = 0.99;
/// Multiplier applied to `recent_max_low` on every call to `scale` (i.e. once per frame) to decay it.
/// Closer to 1.0 than RECENT_MAX_HIGH_MULTIPLIER, so the low range decays more slowly.
const RECENT_MAX_LOW_MULTIPLIER: f32 = 0.995;

/// Returns `vals` rescaled with a two-segment linear mapping, updating the running maxima in `s`.
///
/// On every call the maxima in `s` are first decayed, then raised to this frame's Nth largest
/// value (`recent_max_low`, N = HIGH_BUCKET_COUNT) and largest value (`recent_max_high`).
/// The values are then mapped as follows:
/// - values <= recent_max_low are divided by recent_max_low / 0.8, landing in [0.0, 0.8]
/// - values > recent_max_low are mapped linearly so that recent_max_low => 0.8 and
///   recent_max_high => 1.0 (compressing the peaks)
///
/// For example with recent_max_low=5 and recent_max_high=10:
/// - 2.5 => 2.5 / 6.25 = 0.4
/// - 7.5 => 7.5 * 0.04 + 0.6 = 0.9
///
/// Empty input and all-zero history (recent_max_high == 0) are returned unchanged.
fn scale(vals: Vec<f32>, s: &mut ScaleState) -> Vec<f32> {
    // Get the top N values up-front so that we don't change scaling partway through a sample.
    // The top N-1 values will be compressed if they exceed the Nth value too much.
    let top_bucket_vals = topn(&vals, HIGH_BUCKET_COUNT);
    let (frame_max, frame_nth_largest) = match (top_bucket_vals.first(), top_bucket_vals.last()) {
        (Some(&max), Some(&nth)) => (max, nth),
        // No values at all (empty input): nothing to scale and nothing to learn from.
        _ => return vals,
    };

    // Decay the running maxima, then bump them up to this frame's values if those are larger.
    // recent_max_high is additionally clamped to recent_max_low, since its faster decay rate
    // could otherwise let it drop below the low maximum.
    s.recent_max_low = (s.recent_max_low * RECENT_MAX_LOW_MULTIPLIER).max(frame_nth_largest);
    s.recent_max_high = (s.recent_max_high * RECENT_MAX_HIGH_MULTIPLIER)
        .max(frame_max)
        .max(s.recent_max_low);

    if s.recent_max_high == 0.0 {
        // Shortcut for all-zero data: avoid NaN values
        return vals;
    }

    let recent_max_low = s.recent_max_low;
    let recent_max_high = s.recent_max_high;

    // Map/expand low values [0.0, recent_max_low] => [0.0, 0.8]
    let low_val_divisor = recent_max_low / HIGH_BUCKET_COMPRESSION;

    // Map/compress high values [recent_max_low, recent_max_high] => [0.8, 1.0]
    // Need to find 'multiplier' and 'adder' where:
    //   recent_max_high * multiplier + adder = 1.0
    //   recent_max_low * multiplier + adder = 0.8
    // Or solving for 'multiplier' and 'adder':
    //   multiplier = (1.0 - 0.8) / (recent_max_high - recent_max_low)
    //   adder = 1.0 - (multiplier * recent_max_high)
    // So for e.g. recent_max_high=10, recent_max_low=5:
    //   multiplier = (1.0 - 0.8) / (10 - 5) = .04
    //   adder = 1.0 - (.04 * 10) = .6
    // Or e.g. recent_max_high=100, recent_max_low=10:
    //   multiplier = (1.0 - 0.8) / (100 - 10) = .2 / 90 = 0.00222...
    //   adder = 1.0 - (0.00222... * 100) = 0.777...
    // When both maxima are equal these are infinite, but then no value can exceed
    // recent_max_low (recent_max_high is >= every value in this frame), so they go unused.
    let high_val_multiplier = (1.0 - HIGH_BUCKET_COMPRESSION) / (recent_max_high - recent_max_low);
    let high_val_adder = 1.0 - high_val_multiplier * recent_max_high;

    vals.into_iter()
        .map(|val| {
            if val <= recent_max_low {
                if low_val_divisor == 0.0 {
                    // recent_max_low is zero (e.g. fewer than HIGH_BUCKET_COUNT non-zero
                    // buckets right after a reset): the low range is empty, and dividing
                    // would turn zero buckets into 0/0 = NaN.
                    0.0
                } else {
                    // Map/expand "low" range to [0.0, 0.8]
                    val / low_val_divisor
                }
            } else {
                // Map/compress "high" range to [0.8, 1.0]
                val * high_val_multiplier + high_val_adder
            }
        })
        .collect()
}

/// Figures out a reasonable fourier size to use for the given input sample frequency and output bucket count.
/// This is determined by two factors:
/// - Fourier size should always be at least double output_len, because we discard half the fourier result.
/// - At higher sampling rates, the output will have an empty margin at higher frequencies which we can
///   discard by using an even larger fourier calculation.
///
/// The result is always `output_len` times a multiplier:
/// - `FOURIER_SCALE` from the environment, if it parses as an integer >= 2 and the product
///   does not overflow (undocumented setting for manual testing); otherwise
/// - 8 for sample rates >= 192000, 4 for sample rates >= 64000, and 2 for anything lower or
///   when the sample rate is unknown.
fn calculate_fourier_size(output_len: usize, input_frequency: Option<i32>) -> usize {
    // Anything below 2x would feed the mirrored half of the FFT output (or nothing at all)
    // to the display, so smaller overrides are ignored rather than honored.
    const MIN_SCALE: usize = 2;

    // Undocumented setting to allow manual testing of different scales.
    // Unparseable, too-small or overflowing values fall back to the automatic sizing below.
    let override_size = env::var("FOURIER_SCALE")
        .ok()
        .and_then(|raw| usize::from_str(&raw).ok())
        .filter(|&scale| scale >= MIN_SCALE)
        .and_then(|scale| scale.checked_mul(output_len));
    if let Some(size) = override_size {
        return size;
    }

    // At higher sample rates, we end up with a lot of higher frequency output that is
    // effectively unused in source audio. In these cases, we can use a larger fourier transform
    // than necessary, then trim down the output to fit the expected output_len,
    // effectively "zooming in" on the meaty part of the audio for what we display.
    // This also results in more "usable spectrum" in the same output_len, at the cost of:
    // - CPU on the fourier thread
    // - Rate of display updates since we're consuming more samples for the same output
    // The scaling used here doesn't have any mathematical basis and is just from trial and error,
    // but we stick with base-2 multipliers so that the FFT can use the faster Radix4 algorithm.
    let scale = match input_frequency {
        Some(f) if f >= 192_000 => 8,
        Some(f) if f >= 64_000 => 4,
        // Either not much margin to trim, or the sample frequency is unknown:
        // play it safe and use the minimum 2x size.
        _ => MIN_SCALE,
    };
    scale * output_len
}

/// Reads audio signal from recv_audio, runs fourier transform, then sends [0.0, 1.0] results to send_processed.
///
/// Each round of the transform consumes a window of samples whose size is chosen by
/// `calculate_fourier_size`, and forwards the first `output_len` magnitudes of the result after
/// passing them through `scale`. Consecutive windows overlap by half.
///
/// An empty input vector is treated as a "device switched" signal and resets the scaling state.
/// The function returns when either channel is closed.
/// Assumes the chosen fourier size is at least 2 (the window is built with two zeroed edge samples).
pub fn process_audio_loop(
    output_len: usize,
    input_frequency: Option<i32>,
    recv_audio: Receiver<Vec<f32>>,
    send_processed: Sender<Vec<f32>>,
) {
    // The transform is run over more samples than we display: only the leading part of its
    // output is kept (see below), and at higher sample rates the size is padded further so the
    // unused high frequency range can be trimmed off.
    let fft_size = calculate_fourier_size(output_len, input_frequency);
    debug!("Using fourier size: {}", fft_size);
    let fft = FftPlanner::new().plan_fft_forward(fft_size);
    // Window coefficients, with the first and last entries left at zero
    let window = hann(fft_size - 2, 1, fft_size);

    // Ring buffer where incoming audio accumulates until there's enough for a round of FFT
    let (mut pending_in, mut pending_out) = RingBuffer::new(fft_size).split();
    // Scratch buffer handed to the FFT, reused for every round
    let mut fft_buf = vec![Complex::new(0.0, 0.0); fft_size];

    let mut scale_state = ScaleState {
        recent_max_low: 0.,
        recent_max_high: 0.,
    };

    // recv() only fails once the input channel has closed, which is expected when shutting down
    while let Ok(audio) = recv_audio.recv() {
        // Special signal/hack: an empty buffer means we've just switched to a new device.
        // Reset the recent maxima since the new device may have different levels.
        // This is best-effort only: a signal spike during the switch would undo it.
        if audio.is_empty() {
            scale_state = ScaleState {
                recent_max_low: 0.,
                recent_max_high: 0.,
            };
            continue;
        }

        // Make sure the ring buffer can hold what's already pending plus the new audio,
        // replacing it with a larger one if needed (shouldn't occur much after init)
        let required = pending_in.len() + audio.len();
        if pending_in.capacity() < required {
            let (mut grown_in, grown_out) =
                RingBuffer::new(pending_in.capacity() + audio.len()).split();
            grown_in.move_from(&mut pending_out, None);
            pending_in = grown_in;
            pending_out = grown_out;
        }

        // Queue up the new audio
        pending_in.push_iter(&mut audio.into_iter());

        // Run as many rounds of FFT as the pending audio allows
        while pending_out.len() >= fft_size {
            // Fill fft_buf with the oldest fft_size samples, weighted by the window.
            // The ring buffer exposes its content as two slices, so walk them back to back;
            // zipping against fft_buf/window stops after exactly fft_size samples.
            pending_out.access(|older, newer| {
                let samples = older.iter().chain(newer.iter());
                for ((slot, weight), sample) in
                    fft_buf.iter_mut().zip(window.iter()).zip(samples)
                {
                    *slot = Complex::new(weight * sample, 0.0);
                }
            });

            // Drop only the first HALF of the window from the ring buffer, so that the next
            // round starts halfway through this one. This lets us reuse samples that the
            // window deemphasized at the edges. To illustrate:
            // 1: 0 1 2 3 2 1 0
            // 2:       0 1 2 3 2 1 0
            // 3:             0 1 2 3 2 1 0
            pending_out.discard(fft_size / 2);

            // Run the transform in place, then reduce the complex output to magnitudes.
            // Only the first output_len values are kept:
            // - for real input the second half of the output mirrors the first
            // - for higher sample rates, this also trims off some of the high freq space
            fft.process(&mut fft_buf);
            let spectrum: Vec<f32> = fft_buf
                .iter()
                .take(output_len)
                .map(|c| c.norm())
                .collect();

            // Scale/compress the values to fall within [0.0, 1.0] before sending.
            // Note: done here in the fourier thread to take some load off the main display thread.
            let scaled = scale(spectrum, &mut scale_state);

            if let Err(e) = send_processed.send(scaled) {
                // Output channel has closed, did something break?
                warn!("exiting audio processing thread, output error: {}", e);
                return;
            }
        }
    }
}
