use crossbeam_channel::{Receiver, Sender};
use ringbuf::RingBuffer;
use rustfft::{num_complex::Complex, FftPlanner};
use tracing::{debug, warn};

use std::cmp;
use std::env;
use std::f32::consts::PI;
use std::str::FromStr;

/// Hann window for smoothing data before it's sent to the fourier transform.
/// Adapted from window.rs in dsp 0.10.1 - dual licensed MIT/Apache2
/// https://docs.rs/dsp/0.10.1/src/dsp/window.rs.html#151-159
///
/// Returns `window_length` samples that are zero everywhere except for a Hann window of
/// `width` samples beginning at index `offset`. Any part of the window that would fall past
/// `window_length` is cut off (an `offset` past the end yields all zeros).
///
/// A window of width 1 is the single value 1.0 (the numpy/scipy convention); the general
/// formula would compute 0 / 0 there and produce NaN.
fn hann(width: usize, offset: usize, window_length: usize) -> Vec<f32> {
    let mut samples = vec![0.0; window_length];
    // Clamp both ends so an offset/width reaching past the buffer truncates instead of
    // panicking (or overflowing `offset + width`).
    let start = cmp::min(offset, window_length);
    let end = cmp::min(offset.saturating_add(width), window_length);
    let span = &mut samples[start..end];
    if width <= 1 {
        // width == 0 leaves `span` empty; width == 1 is a single full-weight sample.
        span.fill(1.0);
    } else {
        let denominator = (width - 1) as f32;
        for (n, sample) in span.iter_mut().enumerate() {
            *sample = (PI * n as f32 / denominator).sin().powi(2);
        }
    }
    samples
}

/// Figures out a reasonable fourier size to use for the given input sample frequency and output bucket count.
/// This is determined by two factors:
/// - Fourier size should always be at least double output_len, because we discard half the fourier result.
/// - At higher sampling rates, the output will have an empty margin at higher frequencies which we can
///   discard by using an even larger fourier calculation.
///
/// For manual experimentation the multiplier can be overridden with the `FOURIER_SCALE`
/// environment variable (an integer; surrounding whitespace is allowed). Values below 2,
/// values that fail to parse, and values whose product with `output_len` would overflow are
/// ignored, falling back to the automatic choice below.
fn calculate_fourier_size(output_len: usize, input_frequency: Option<i32>) -> usize {
    // Undocumented setting to allow manual testing of different scales.
    // A scale below 2 would leave fewer than 2 * output_len fourier outputs, so the mirrored
    // upper half of the spectrum would leak into the display (and 0 would mean an empty FFT).
    let override_size = env::var("FOURIER_SCALE")
        .ok()
        .and_then(|scale| usize::from_str(scale.trim()).ok())
        .filter(|&scalesz| scalesz >= 2)
        .and_then(|scalesz| scalesz.checked_mul(output_len));
    if let Some(size) = override_size {
        return size;
    }

    // At higher sample rates, we end up with a lot of higher frequency output that is
    // effectively unused in source audio. In these cases, we can use a larger fourier transform
    // than necessary, then trim down the output to fit the expected output_len,
    // effectively "zooming in" on the meaty part of the audio for what we display.
    // This also results in more "usable spectrum" in the same output_len, at the cost of:
    // - CPU on the fourier thread
    // - Rate of display updates since we're consuming more samples for the same output
    // The scaling used here doesn't have any mathematical basis and is just from trial and error,
    // but we stick with base-2 multipliers so that the FFT can use the faster Radix4 algorithm.
    let multiplier = match input_frequency {
        Some(f) if f >= 192_000 => 8,
        Some(f) if f >= 64_000 => 4,
        // Lower rates have little margin to trim, and with an unknown sample frequency we
        // play it safe: either way, stick to the minimum 2x size.
        _ => 2,
    };
    multiplier * output_len
}

/// Reads audio signal from recv_audio, runs fourier transform, then sends [0.0, 1.0] results to send_processed.
/// Fourier transform will consume fft_size samples, then produce a frequency spectrum of size=fft_size/2.
pub fn process_audio_loop(
    output_len: usize,
    input_frequency: Option<i32>,
    recv_audio: Receiver<Vec<f32>>,
    send_processed: Sender<Vec<f32>>,
) {
    // As mentioned below, we only take some of the resulting output of the fourier transform to display.
    // In addition to that, for higher frequencies we add some additional "padding" to the input,
    // allowing us to trim off some of the higher frequency data which ends up being zero anyway.
    let input_len = calculate_fourier_size(output_len, input_frequency);
    debug!("Using fourier size: {}", input_len);
    let fft = FftPlanner::new().plan_fft_forward(input_len);
    // Have the edges of the FFT input be zeroes
    let window = hann(input_len - 2, 1, input_len);

    // Buffer for accumulating incoming audio, before it is processed by the FFT
    let (mut audio_buf_in, mut audio_buf_out) = RingBuffer::new(input_len).split();
    let mut fft_buf = Vec::with_capacity(input_len);
    fft_buf.resize(input_len, Complex::new(0.0, 0.0));

    let mut recent_max = 0.;
    loop {
        match recv_audio.recv() {
            Ok(audio) => {
                // Special signal/hack: If the buffer is empty, we've just switched to a new device.
                // In this case we should reset recent_max since the new device may have different levels.
                // However, in practice this may not work anyway if there's an signal spike when switching devices,
                // so in practice this hack is just best-effort.
                if audio.is_empty() {
                    recent_max = 0.;
                    continue;
                }

                // Grow circular buffer capacity to match what's needed (shouldn't occur much after init)
                if audio_buf_in.capacity() < audio_buf_in.len() + audio.len() {
                    let (mut new_audio_buf_in, new_audio_buf_out) =
                        RingBuffer::new(audio_buf_in.capacity() + audio.len()).split();
                    new_audio_buf_in.move_from(&mut audio_buf_out, None);
                    audio_buf_in = new_audio_buf_in;
                    audio_buf_out = new_audio_buf_out;
                }

                // Append audio to circular buffer
                audio_buf_in.push_iter(&mut audio.into_iter());

                // Consume audio from circular buffer
                while audio_buf_out.len() >= input_len {
                    // There's enough buffered input data to run a round of FFT
                    // Copy input_len values from audio_buf to fft_buf, applying hann window along the way
                    // There's a little extra complexity here because the ringbuf exposes the content as two parts
                    audio_buf_out.access(|older, newer| {
                        if older.len() >= input_len {
                            // just read from older
                            for i in 0..input_len {
                                fft_buf[i] = Complex::new(window[i] * older[i], 0.0);
                            }
                        } else {
                            // read from older, then newer
                            for idx in 0..older.len() {
                                fft_buf[idx] = Complex::new(window[idx] * older[idx], 0.0);
                            }
                            for i in 0..(input_len - older.len()) {
                                let buf_idx = older.len() + i;
                                fft_buf[buf_idx] = Complex::new(window[buf_idx] * newer[i], 0.0);
                            }
                        }
                    });

                    // Remove the first HALF of the values from audio_buf, shifting it forward.
                    // The second half is left in-place.
                    // This allows us to reuse parts of the data that were deemphasized by the window filter.
                    // To illustrate:
                    // 1: 0 1 2 3 2 1 0
                    // 2:       0 1 2 3 2 1 0
                    // 3:             0 1 2 3 2 1 0
                    audio_buf_out.discard(input_len / 2);

                    // Process the selected data in fft_buf, then forward the result
                    fft.process(&mut fft_buf);
                    let result = Vec::from_iter(
                        fft_buf
                            .iter()
                            // Only take output_len bytes:
                            // - The output is two "mirrored" halves,
                            //   where result[0] is DC and result[input_len-1] is zero
                            // - For higher sample rates, we also remove some of the high freq space
                            .take(output_len)
                            // Convert complex values into floats
                            .map(|c| c.norm()),
                    );

                    // Normalize the values to fall within [0.0, 1.0], using the recent max value for scaling.
                    // Note: We do this in the fourier thread to take some load off of the main display thread.
                    // Get the max value up-front so that we don't change scaling partway through a sample.
                    let max = result
                        .iter()
                        .max_by(|x, y| {
                            if x == y {
                                std::cmp::Ordering::Equal
                            } else if x < y {
                                std::cmp::Ordering::Less
                            } else {
                                std::cmp::Ordering::Greater
                            }
                        })
                        .unwrap()
                        .clone();
                    // Slowly decrease the max across samples,
                    // but not so quickly that any fade-in is too obvious.
                    recent_max *= 0.995;
                    if max > recent_max {
                        recent_max = max;
                    }

                    // Divide the values by recent_max to get the [0.0, 1.0] scale before sending.
                    if let Err(e) = send_processed
                        .send(result.into_iter().map(|val| val / recent_max).collect())
                    {
                        // Output channel has closed, did something break?
                        warn!("exiting audio processing thread, output error: {}", e);
                        return;
                    }
                }
            }
            Err(_e) => {
                // Input channel has closed, expected when shutting down
                return;
            }
        }
    }
}
