use crossbeam_channel::{Receiver, Sender};
use ringbuf::RingBuffer;
use rustfft::{num_complex::Complex, FftPlanner};
use tracing::{debug, warn};

use std::cmp;
use std::env;
use std::f32::consts::PI;
use std::str::FromStr;

/// Builds Hann window weights for smoothing samples before they are sent to the fourier transform.
///
/// Adapted from window.rs in dsp 0.10.1 - dual licensed MIT/Apache2
/// https://docs.rs/dsp/0.10.1/src/dsp/window.rs.html#151-159
///
/// The returned vector always has `window_length` entries. Entries in `offset..offset + width`
/// (clipped to `window_length`) hold `sin^2(PI * n / (width - 1))`, where `n` is the position
/// inside the window. Every other entry is zero.
///
/// Edge cases:
/// - `width == 0` or `offset >= window_length`: every weight is zero.
/// - `width == 1`: the single window sample gets weight 1.0, since the formula above would
///   otherwise evaluate 0/0 and produce NaN.
///
/// If the `DEMO_DISABLE_INPUT_WEIGHTS` environment variable is set to a non-empty value, every
/// weight is 1.0 and `width`/`offset` are ignored.
fn hann(width: usize, offset: usize, window_length: usize) -> Vec<f32> {
    // Undocumented setting to see the difference between flattened vs unflattened output
    let weights_disabled =
        env::var("DEMO_DISABLE_INPUT_WEIGHTS").map_or(false, |disabled| !disabled.is_empty());
    if weights_disabled {
        // Flat weights
        return vec![1.0; window_length];
    }

    let mut samples = vec![0.0; window_length];

    // Clip the window to the buffer. Saturating add keeps absurdly large widths from overflowing.
    let start = cmp::min(offset, window_length);
    let end = cmp::min(offset.saturating_add(width), window_length);

    if width == 1 {
        // Degenerate one-sample window: avoid the 0/0 in the general formula.
        if let Some(sample) = samples.get_mut(offset) {
            *sample = 1.0;
        }
        return samples;
    }

    // For width == 0 the slice below is empty, so this value is never used.
    let last_index = width.saturating_sub(1) as f32;
    for (n, sample) in samples[start..end].iter_mut().enumerate() {
        *sample = (PI * n as f32 / last_index).sin().powi(2);
    }
    samples
}

/// Calculates and returns the ordered top N largest values from the input.
///
/// The result has the largest value at index 0 and the Nth largest value at index N-1.
/// Duplicate values are not deduplicated in the output. If `vals` holds fewer than `n` entries,
/// all of them are returned (still sorted in descending order). `n == 0` yields an empty vector.
///
/// NaN inputs get no special treatment: every comparison against NaN is false, so the ordering
/// of the result is unspecified if `vals` contains NaN.
fn topn_vals(vals: &[f32], n: usize) -> Vec<f32> {
    if n == 0 {
        // Nothing can be kept; also avoids inspecting the last element of an empty result.
        return Vec::new();
    }

    // Kept sorted in descending order at all times, never longer than n.
    let mut topn_vec: Vec<f32> = Vec::with_capacity(n);
    for &val in vals {
        if topn_vec.len() >= n {
            match topn_vec.last() {
                // This is a topn val: drop the current lowest to make room.
                Some(&lowest) if val > lowest => {
                    topn_vec.pop();
                }
                // Not a topn val, skip.
                _ => continue,
            }
        }

        // Insert after every element that is >= val, which keeps the descending order and
        // places duplicates next to each other.
        let insert_at = topn_vec.partition_point(|&kept| kept >= val);
        topn_vec.insert(insert_at, val);
    }
    topn_vec
}

/// Running state used by compress_amplitudes() across frames.
///
/// We keep track of two scales:
/// - High scale is for the top few values, things like bass lines will typically get captured here.
/// - Low scale is for everything else.
///
/// In practice this is to help with situations where a bass line taking up only a few buckets
/// is technically 2x-5x the amplitude of everything else. In that scenario we want to compress the
/// bass line so that the rest of the amplitude scale isn't being desensitized too much.
/// So we map the "low" (non-TopN) values to [0.0, 0.8] and the "high" values to [0.8, 1.0].
/// In theory a song with lots of equal amplitudes across will end up with the "high" values getting
/// exaggerated, but in practice at any given time there's only a few buckets with peak amplitudes.
struct CompressState {
    /// Decaying running maximum of the Nth-largest (N = HIGH_BUCKET_COUNT) amplitude of each frame.
    /// Marks the top of the "low" range. Only touched by compress_amplitudes() while compression
    /// is enabled; the processing loop resets it to zero on a device switch.
    recent_max_low: f32,
    /// Decaying running maximum of the largest amplitude of each frame.
    /// Marks the top of the "high" range. When compression is disabled this is the only scale
    /// in use, acting as a plain linear normalizer.
    recent_max_high: f32,
    /// When false, compress_amplitudes() only divides by recent_max_high and skips the
    /// two-range compression entirely.
    compression_enabled: bool,
}

/// Number of buckets to be counted as part of the "high" range and compressed into [0.8, 1.0].
/// Larger => more buckets getting forced into the "high" group to be compressed.
const HIGH_BUCKET_COUNT: usize = 10;
/// The dividing point between "high" values and "low" values.
/// So for example a value of 0.8 means the topN values are mapped to [0.8, 1.0] while all other values
/// are mapped to [0.0, 0.8].
/// Closer to 1.0 => more compression of the "high" values into a smaller range.
const HIGH_BUCKET_COMPRESSION: f32 = 0.8;
/// Per-frame decay multiplier applied to recent_max_high while compression is enabled.
/// Needs to be slow enough to avoid excessive fade-in in the voiceprint.
/// High range decay is faster so that things recover faster when a loud noise goes away.
const RECENT_MAX_FAST_MULTIPLIER: f32 = 0.99;
/// Per-frame decay multiplier applied to recent_max_low, and to recent_max_high when
/// compression is disabled.
/// Needs to be slow enough to avoid excessive fade-in in the voiceprint.
const RECENT_MAX_SLOW_MULTIPLIER: f32 = 0.995;

/// Returns a copy of 'vals' where the contents have been compressed according to their relative values
/// as well as recent history. The history in 's' is updated in place on every call.
///
/// The scaling is within two ranges, where the vast majority of low values are assigned to [0.0, 0.8],
/// then the top few values are assigned to [0.8, 1.0]. This avoids situations where a small subset of
/// buckets with high amplitude data (e.g. a bass line) end up desensitizing the whole display.
///
/// To implement this, we keep track of two "recent_max" values, one for the #1 max, and the other
/// for the #10 max, where #1-#9 is compressed into [0.8, 1.0], and everything at/under the #10 value
/// is placed within [0.0, 0.8].
///
/// For example if the max value is 10 and the top Nth value is 5:
/// - divide values <= 5 by 5/0.8 = 6.25 to get [0.0,0.8] range
/// - map values > 5 through `val * 0.04 + 0.6` to get [0.8,1.0] range (compress down)
///
/// When `s.compression_enabled` is false the values are only divided by the decaying running
/// maximum, keeping things linear.
///
/// Degenerate inputs are returned without producing NaN:
/// - empty 'vals' is returned as-is and leaves 's' untouched
/// - all-zero history/data is returned as-is in both modes
/// - zero values are mapped to 0.0 even while the "low" scale is still zero
fn compress_amplitudes(vals: Vec<f32>, s: &mut CompressState) -> Vec<f32> {
    // Nothing to scale. This also guarantees that a maximum/top-N value exists below.
    if vals.is_empty() {
        return vals;
    }

    if !s.compression_enabled {
        // Ensure the values fall within the required [0.0, 1.0],
        // but don't apply any compression/keep things linear.
        let max = vals.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        s.recent_max_high *= RECENT_MAX_SLOW_MULTIPLIER;
        if max > s.recent_max_high {
            s.recent_max_high = max;
        }
        if s.recent_max_high == 0.0 {
            // Shortcut for all-zero data: avoid NaN values from 0/0
            return vals;
        }
        let divisor = s.recent_max_high;
        return vals.into_iter().map(|x| x / divisor).collect();
    }

    // Get the top N values up-front so that we don't change scaling partway through a sample.
    // The top N-1 values will be compressed if they exceed the Nth value too much.
    let top_bucket_vals = topn_vals(&vals, HIGH_BUCKET_COUNT);
    let (max, topn) = match (top_bucket_vals.first(), top_bucket_vals.last()) {
        (Some(&max), Some(&topn)) => (max, topn),
        // Only reachable if no top values were produced at all; nothing sensible to scale by.
        _ => return vals,
    };

    // Track two decaying maxes: one for the Nth-largest value (top of the "low" range)
    // and one for the overall largest value (top of the "high" range).
    s.recent_max_low *= RECENT_MAX_SLOW_MULTIPLIER;
    if topn > s.recent_max_low {
        s.recent_max_low = topn;
    }
    s.recent_max_high *= RECENT_MAX_FAST_MULTIPLIER;
    if max > s.recent_max_high {
        s.recent_max_high = max;
    }
    // May happen since max_high's decay rate is faster
    if s.recent_max_low > s.recent_max_high {
        s.recent_max_high = s.recent_max_low;
    }

    if s.recent_max_high == 0.0 {
        // Shortcut for all-zero data: avoid NaN values
        return vals;
    }

    let low_max = s.recent_max_low;
    let high_max = s.recent_max_high;

    // Map/expand low values [0.0, recent_max_low] => [0.0, 0.8]
    let low_val_divisor = low_max / HIGH_BUCKET_COMPRESSION;

    // Map/compress high values [recent_max_low, recent_max_high] => [0.8, 1.0]
    // Need to find 'multiplier' and 'adder' where:
    //   recent_max_high * multiplier + adder = 1.0
    //   recent_max_low * multiplier + adder = 0.8
    // Or solving for 'multiplier' and 'adder':
    //   multiplier = (1.0 - 0.8) / (recent_max_high - recent_max_low)
    //   adder = 1.0 - (multiplier * recent_max_high)
    // So for e.g. recent_max_high=10, recent_max_low=5:
    //   multiplier = (1.0 - 0.8) / (10 - 5) = .04
    //   adder = 1.0 - (.04 * 10) = .6
    // Or e.g. recent_max_high=100, recent_max_low=10:
    //   multiplier = (1.0 - 0.8) / (100 - 10) = .2 / 90 = 0.00222...
    //   adder = 1.0 - (0.00222... * 100) = 0.777...
    //
    // When both maxes are equal these become infinite, but then every value is
    // <= recent_max_low and only the "low" mapping below is used.
    let high_val_multiplier = (1.0 - HIGH_BUCKET_COMPRESSION) / (high_max - low_max);
    let high_val_adder = 1.0 - high_val_multiplier * high_max;

    vals.into_iter()
        .map(|val| {
            if val <= low_max {
                if low_val_divisor > 0.0 {
                    // Map/expand "low" range to [0.0, 0.8]
                    val / low_val_divisor
                } else {
                    // The "low" scale is still zero (fewer than N non-zero buckets so far):
                    // there is no range to expand into, and dividing would give 0/0 = NaN.
                    0.0
                }
            } else {
                // Map/compress "high" range to [0.8, 1.0]
                val * high_val_multiplier + high_val_adder
            }
        })
        .collect()
}

/// Builds one amplitude weight per frequency bucket, to be multiplied against the first
/// `output_len` values of the fourier output.
///
/// The weights equalize the spectrum by boosting higher frequencies and attenuating lower ones,
/// so that the resulting amplitudes look "flat" relative to the perceived loudness of different
/// frequencies. `output_frequency` is the frequency span that the `output_len` buckets cover.
///
/// If the `DEMO_DISABLE_FREQ_WEIGHTS` environment variable is set to a non-empty value, every
/// weight is 1.0.
fn freq_flattening_weights(output_len: usize, output_frequency: i32) -> Vec<f32> {
    // Undocumented setting to see the difference between flattened vs unflattened output
    let flat_requested =
        matches!(env::var("DEMO_DISABLE_FREQ_WEIGHTS"), Ok(ref disabled) if !disabled.is_empty());
    if flat_requested {
        return vec![1.0; output_len];
    }

    let step = output_frequency as f32 / output_len as f32;

    // The first bucket should nominally be DC, but each bucket is weighted by its upper edge
    // (the step is added before use) so that log2(0) is never evaluated.
    let mut bucket_freq: f32 = 0.;
    (0..output_len)
        .map(|_| {
            bucket_freq += step;
            // gain/dB for frequency f = 3.01 * log2(f / fbase), with a 1khz base frequency
            let gain_db = 3.01 * (bucket_freq / 1000.).log2();
            // dB to linear: 10 ^ (db / 20)
            10.0_f32.powf(gain_db / 20.)
        })
        .collect()
}

/// Figures out a reasonable fourier size to use for the given input sample frequency and output bucket count.
/// This is determined by two factors:
/// - Fourier size should always be at least 2x output_len, because we discard the imaginary
///   "mirror half" from the fourier output.
/// - At higher sampling rates, the output will have an empty margin at higher frequencies which we can
///   discard by using an even larger fourier calculation, which is then truncated to match output_len.
///
/// The undocumented `DEMO_FOURIER_SCALE` environment variable overrides the multiplier when the
/// sample frequency is known. Values that are not an integer >= 2, or that would overflow the
/// resulting size, are ignored in favor of the automatic choice.
///
/// Returns a tuple containing:
/// 1. The size in buckets of the fourier transform to use (where the output is truncated after output_len)
/// 2. Weights of length output_len which can be used to "flatten" the truncated fourier output.
fn calculate_fourier_size(output_len: usize, input_freq: Option<i32>) -> (usize, Vec<f32>) {
    let f = match input_freq {
        Some(f) => f,
        None => {
            // Don't know sample frequency, play it safe and use the minimum 2x size for 1/2 of the FFT
            // Base weights on a 48khz input (just a guess)
            return (2 * output_len, freq_flattening_weights(output_len, 48000));
        }
    };

    // Undocumented setting to allow manual testing of different scales
    let manual_scale = env::var("DEMO_FOURIER_SCALE")
        .ok()
        .and_then(|scale| usize::from_str(&scale).ok());
    if let Some(scalesz) = manual_scale {
        // The retained frequency range is divided by scale/2, so the scale must be >= 2 to
        // avoid dividing by zero, and scale/2 must still fit the i32 frequency arithmetic.
        let freq_divisor = scalesz / 2;
        if freq_divisor >= 1 && freq_divisor <= i32::MAX as usize {
            if let Some(fourier_len) = scalesz.checked_mul(output_len) {
                return (
                    fourier_len,
                    freq_flattening_weights(output_len, f / freq_divisor as i32),
                );
            }
        }
        // Unusable override: fall through to the automatic selection below.
    }

    // At higher sample rates, we end up with a lot of higher frequency data that is
    // effectively unused for audio. In these cases, we can use a larger fourier transform
    // than necessary, then trim down the output to fit the expected output_len,
    // effectively "zooming in" on the meaty part of the audio for what we display.
    // This also results in more "usable spectrum" in the same output_len, at the cost of:
    // - CPU on the fourier thread to do the larger FFT then discard some of it
    // - Rate of display updates since we're consuming more samples for the same output
    // The scaling used here doesn't have any mathematical basis and is just from trial and error,
    // but we stick with base-2 multipliers so that the FFT can use the faster Radix4 algorithm.
    let (size_multiplier, freq_divisor) = if f >= 192000 {
        // Retain bottom 1/4 of the frequency range (or 1/8 of the overall FFT)
        (8, 4)
    } else if f >= 64000 {
        // Retain bottom 1/2 of the frequency range (or 1/4 of the overall FFT)
        (4, 2)
    } else {
        // Keep the full frequency range (or 1/2 of the overall FFT)
        (2, 1)
    };
    (
        size_multiplier * output_len,
        freq_flattening_weights(output_len, f / freq_divisor),
    )
}

/// Reads audio signal from recv_audio, runs fourier transform, then sends [0.0, 1.0] results to send_processed.
/// Fourier transform will consume fft_size samples, then produce a frequency spectrum of size=fft_size/2.
///
/// Runs until either channel is closed:
/// - recv_audio closing is the expected shutdown path and returns silently.
/// - send_processed closing is logged as a warning before returning.
///
/// An empty Vec received on recv_audio is a signal that the input device changed, and resets the
/// amplitude compression history.
///
/// Setting the `DEMO_DISABLE_COMPRESSION` environment variable to a non-empty value disables the
/// two-range amplitude compression, matching how the other `DEMO_DISABLE_*` settings are read.
pub fn process_audio_loop(
    output_len: usize,
    input_frequency: Option<i32>,
    recv_audio: Receiver<Vec<f32>>,
    send_processed: Sender<Vec<f32>>,
) {
    // As mentioned below, we only take some of the resulting output of the fourier transform to display.
    // In addition to that, for higher frequencies we add some additional "padding" to the input,
    // allowing us to trim off some of the higher frequency data which ends up being zero anyway.
    let (input_len, output_freq_weights) = calculate_fourier_size(output_len, input_frequency);
    debug!("Using fourier size: {}", input_len);
    if input_len < 2 {
        // The window width below would underflow, and discarding input_len / 2 = 0 samples per
        // round would never drain the buffer, so there is nothing useful this thread can do.
        warn!(
            "exiting audio processing thread, fourier size too small: {}",
            input_len
        );
        return;
    }
    let fft = FftPlanner::new().plan_fft_forward(input_len);
    // Have the edges of the FFT input be zeroes
    let input_window_scales = hann(input_len - 2, 1, input_len);

    // Buffer for accumulating incoming audio, before it is processed by the FFT
    let (mut audio_buf_in, mut audio_buf_out) = RingBuffer::new(input_len).split();
    let mut fft_buf = vec![Complex::new(0.0, 0.0); input_len];

    let mut compress_state = CompressState {
        recent_max_low: 0.,
        recent_max_high: 0.,
        // Unset or empty means "not disabled"
        compression_enabled: env::var("DEMO_DISABLE_COMPRESSION")
            .map_or(true, |disabled| disabled.is_empty()),
    };

    // recv() only fails once the input channel has closed, which is expected when shutting down
    while let Ok(audio) = recv_audio.recv() {
        // Special signal/hack: If the buffer is empty, we've just switched to a new device.
        // In this case we should reset recent_max since the new device may have different levels.
        // However, this may not work anyway if there's a signal spike when switching devices,
        // so in practice this hack is just best-effort.
        if audio.is_empty() {
            compress_state.recent_max_low = 0.;
            compress_state.recent_max_high = 0.;
            continue;
        }

        // Grow circular buffer capacity to match what's needed (shouldn't occur much after init)
        if audio_buf_in.capacity() < audio_buf_in.len() + audio.len() {
            let (mut new_audio_buf_in, new_audio_buf_out) =
                RingBuffer::new(audio_buf_in.capacity() + audio.len()).split();
            new_audio_buf_in.move_from(&mut audio_buf_out, None);
            audio_buf_in = new_audio_buf_in;
            audio_buf_out = new_audio_buf_out;
        }

        // Append audio to circular buffer
        audio_buf_in.push_iter(&mut audio.into_iter());

        // Consume audio from circular buffer
        while audio_buf_out.len() >= input_len {
            // There's enough buffered input data to run a round of FFT
            // Copy input_len values from audio_buf to fft_buf, applying hann window along the way.
            // The ringbuf exposes its content as two parts (older, then newer); chaining them
            // yields the samples in order, and zipping against fft_buf stops after input_len.
            audio_buf_out.access(|older, newer| {
                let samples = older.iter().chain(newer.iter());
                let windowed_slots = fft_buf.iter_mut().zip(input_window_scales.iter());
                for ((slot, scale), sample) in windowed_slots.zip(samples) {
                    *slot = Complex::new(scale * sample, 0.0);
                }
            });

            // Remove the first HALF of the values from audio_buf, shifting it forward.
            // The second half is left in-place.
            // This allows us to reuse parts of the data that were deemphasized by the window filter.
            // To illustrate:
            // 1: 0 1 2 3 2 1 0
            // 2:       0 1 2 3 2 1 0
            // 3:             0 1 2 3 2 1 0
            audio_buf_out.discard(input_len / 2);

            // Run the FFT on the input data
            fft.process(&mut fft_buf);

            // Before sending the data to the rendering thread, preprocess the output.
            // This reduces load on the rendering thread.

            // Pre-process the fourier output:
            // - Truncate result to output_len (implicitly via zip() against output_freq_weights)
            // - Convert complex value to normalized value
            // - Apply per-freq "flattening" weights against outputs,
            //   decreasing low freqs/increasing higher freqs to match human perception
            let flattened_result: Vec<f32> = fft_buf
                .iter()
                .zip(output_freq_weights.iter())
                .map(|(val, weight)| val.norm() * weight)
                .collect();

            // Compress the amplitudes to fall within [0.0, 1.0] before sending.
            let compressed_result = compress_amplitudes(flattened_result, &mut compress_state);

            if let Err(e) = send_processed.send(compressed_result) {
                // Output channel has closed, did something break?
                warn!("exiting audio processing thread, output error: {}", e);
                return;
            }
        }
    }
}
