use core::convert::Infallible;
use std::prelude::v1::*;

use numpy::{PyArray1, PyReadonlyArray1};
use probability::distribution::Gaussian;
use pyo3::{prelude::*, types::PyTuple};

use crate::{
    stream::{
        model::{DefaultContiguousCategoricalEntropyModel, DefaultLeakyQuantizer},
        Decode, Encode, TryCodingError,
    },
    CoderError, Pos, Seek, UnwrapInfallible,
};

use super::model::{internals::EncoderDecoderModel, Model};

/// Registers the `AnsCoder` Python class on `module`.
pub fn init_module(_py: Python<'_>, module: &PyModule) -> PyResult<()> {
    // `add_class` already reports success or failure as a `PyResult<()>`, so hand it back as is.
    module.add_class::<AnsCoder>()
}

/// An entropy coder based on [Asymmetric Numeral Systems (ANS)] [1].
///
/// This is a wrapper around the Rust type [`constriction::stream::stack::DefaultAnsCoder`]
/// with python bindings.
///
/// Note that this entropy coder is a stack (a "last in first out" data
/// structure). You can push symbols on the stack using the methods
/// `encode_leaky_gaussian_symbols_reverse` or `encode_iid_categorical_symbols_reverse`, and then pop
/// them off *in reverse order* using the methods `decode_leaky_gaussian_symbols` or
/// `decode_iid_categorical_symbols`, respectively.
///
/// To copy out the compressed data that is currently on the stack, call
/// `get_compressed`. You would typically want to write this to a binary file in some
/// well-documented byte order. After reading it back in at a later time, you can
/// decompress it by constructing a `constriction.stream.stack.AnsCoder` where you pass in the
/// compressed data as an argument to the constructor.
///
/// If you're only interested in the compressed file size, calling `num_bits` will
/// be cheaper as it won't actually copy out the compressed data.
///
/// ## Examples
///
/// ### Compression:
///
/// ```python
/// import sys
/// import constriction
/// import numpy as np
///
/// ans = constriction.stream.stack.AnsCoder()  # No arguments => empty ANS coder
///
/// symbols = np.array([2, -1, 0, 2, 3], dtype = np.int32)
/// min_supported_symbol, max_supported_symbol = -10, 10  # both inclusively
/// means = np.array([2.3, -1.7, 0.1, 2.2, -5.1], dtype = np.float64)
/// stds = np.array([1.1, 5.3, 3.8, 1.4, 3.9], dtype = np.float64)
///
/// ans.encode_leaky_gaussian_symbols_reverse(
///     symbols, min_supported_symbol, max_supported_symbol, means, stds)
///
/// print(f"Compressed size: {ans.num_valid_bits()} bits")
///
/// compressed = ans.get_compressed()
/// if sys.byteorder == "big":
///     # Convert native byte order to a consistent one (here: little endian).
///     compressed.byteswap(inplace=True)
/// compressed.tofile("compressed.bin")
/// ```
///
/// ### Decompression:
///
/// ```python
/// import sys
/// import constriction
/// import numpy as np
///
/// compressed = np.fromfile("compressed.bin", dtype=np.uint32)
/// if sys.byteorder == "big":
///     # Convert little endian byte order to native byte order.
///     compressed.byteswap(inplace=True)
///
/// ans = constriction.stream.stack.AnsCoder(compressed)
///
/// min_supported_symbol, max_supported_symbol = -10, 10  # both inclusively
/// means = np.array([2.3, -1.7, 0.1, 2.2, -5.1], dtype = np.float64)
/// stds = np.array([1.1, 5.3, 3.8, 1.4, 3.9], dtype = np.float64)
///
/// reconstructed = ans.decode_leaky_gaussian_symbols(
///     min_supported_symbol, max_supported_symbol, means, stds)
/// assert ans.is_empty()
/// print(reconstructed)  # Should print [2, -1, 0, 2, 3]
/// ```
///
/// ## Constructor
///
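/// AnsCoder([compressed], seal=False)
///
/// Arguments:
/// compressed (optional) -- initial compressed data, as a numpy array with
///     dtype `uint32`.
/// seal (optional) -- if `True`, a single "1" word is appended to `compressed` so that
///     arbitrary bit strings (including ones that end in zero words) can be decoded;
///     requires `compressed`.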
///
/// [Asymmetric Numeral Systems (ANS)]: https://en.wikipedia.org/wiki/Asymmetric_numeral_systems
/// [`constriction::stream::stack::DefaultAnsCoder`]: crate::stream::stack::DefaultAnsCoder
///
/// ## References
///
/// [1] Duda, Jarek, et al. "The use of asymmetric numeral systems as an accurate
/// replacement for Huffman coding." 2015 Picture Coding Symposium (PCS). IEEE, 2015.
#[pyclass]
#[pyo3(text_signature = "([compressed], seal=False)")]
#[derive(Debug)]
pub struct AnsCoder {
    /// The wrapped Rust ANS coder; the Python-facing methods below delegate their work to it.
    inner: crate::stream::stack::DefaultAnsCoder,
}

#[pymethods]
impl AnsCoder {
    /// The constructor has the call signature `AnsCoder([compressed, [seal=False]])`.
    ///
    /// - If you want to encode a message, call the constructor with no arguments.
    /// - If you want to decode a message that was previously encoded with an `AnsCoder`, call the
    ///   constructor with a single argument `compressed`, which must be a rank-1 numpy array with
    ///   `dtype=np.uint32` (as returned by the method
    ///   [`get_compressed`](#constriction.stream.stack.AnsCoder.get_compressed) when invoked with
    ///   no arguments).
    /// - For bits-back related compression techniques, it can sometimes be useful to decode symbols
    ///   from some arbitrary bit string that was *not* generated by ANS. To do so, call the
    ///   constructor with the additional argument `seal=True` (if you don't set `seal` to `True`
    ///   then the constructor rejects `compressed` data that ends in a zero word, since such data
    ///   can never be produced by an `AnsCoder`). Once you've decoded and re-encoded some symbols,
    ///   you can get back the original `compressed` data by calling
    ///   `.get_compressed(unseal=True)`.
    ///
    /// Raises a `ValueError` if `seal=True` is passed without `compressed` data, or if
    /// `compressed` ends in a zero word and `seal` is not `True`.
    #[new]
    pub fn new(
        compressed: Option<PyReadonlyArray1<'_, u32>>,
        seal: Option<bool>,
    ) -> PyResult<Self> {
        // An explicit `seal=False` is equivalent to omitting the argument, so only an actual
        // request to seal requires compressed data.
        let seal = seal.unwrap_or(false);

        let inner = match compressed {
            None if seal => {
                return Err(pyo3::exceptions::PyValueError::new_err(
                    "Need compressed data to seal.",
                ));
            }
            None => crate::stream::stack::AnsCoder::new(),
            Some(compressed) => {
                let compressed = compressed.to_vec()?;
                if seal {
                    // Sealing appends a "1" word, so any bit string is acceptable here.
                    crate::stream::stack::AnsCoder::from_binary(compressed).unwrap_infallible()
                } else {
                    crate::stream::stack::AnsCoder::from_compressed(compressed).map_err(|_| {
                        pyo3::exceptions::PyValueError::new_err(
                            "Invalid compressed data: ANS compressed data never ends in a zero word.",
                        )
                    })?
                }
            }
        };

        Ok(Self { inner })
    }

    /// Records a checkpoint to which you can jump during decoding using
    /// [`seek`](#constriction.stream.stack.AnsCoder.seek).
    ///
    /// Returns a tuple `(position, state)` where `position` is an integer that specifies how many
    /// 32-bit words of compressed data have been produced so far, and `state` is an integer that
    /// defines the `AnsCoder`'s internal state (so that it can be restored upon
    /// [`seek`ing](#constriction.stream.stack.AnsCoder.seek)).
    ///
    /// **Note:** Don't call `pos` if you just want to find out how much compressed data has been
    /// produced so far. Call [`num_words`](#constriction.stream.stack.AnsCoder.num_words)
    /// instead.
    ///
    /// ## Example
    ///
    /// See [`seek`](#constriction.stream.stack.AnsCoder.seek).
    #[pyo3(text_signature = "()")]
    pub fn pos(&mut self) -> (usize, u64) {
        self.inner.pos()
    }

    /// Jumps to a checkpoint recorded with method
    /// [`pos`](#constriction.stream.stack.AnsCoder.pos) during encoding.
    ///
    /// This allows random-access decoding. The arguments `position` and `state` are the two values
    /// returned by the method [`pos`](#constriction.stream.stack.AnsCoder.pos).
    ///
    /// **Note:** in an ANS coder, both decoding and seeking *consume* compressed data. The Python
    /// API of `constriction`'s ANS coder currently supports only seeking forward but not backward
    /// (seeking backward is supported for Range Coding, and for both ANS and Range Coding in
    /// `constriction`'s Rust API).
    ///
    /// Raises an `AttributeError` if the coder cannot seek to the requested checkpoint (e.g.,
    /// because it lies past the end of the remaining compressed data).
    ///
    /// ## Example
    ///
    /// ```python
    /// probabilities = np.array([0.2, 0.4, 0.1, 0.3], dtype=np.float64)
    /// model         = constriction.stream.model.Categorical(probabilities)
    /// message_part1 = np.array([1, 2, 0, 3, 2, 3, 0], dtype=np.int32)
    /// message_part2 = np.array([2, 2, 0, 1, 3], dtype=np.int32)
    ///
    /// # Encode both parts of the message (in reverse order, because ANS
    /// # operates as a stack) and record a checkpoint in-between:
    /// coder = constriction.stream.stack.AnsCoder()
    /// coder.encode_reverse(message_part2, model)
    /// (position, state) = coder.pos() # Records a checkpoint.
    /// coder.encode_reverse(message_part1, model)
    ///
    /// # We could now call `coder.get_compressed()` but we'll just decode
    /// # directly from the original `coder` for simplicity.
    ///
    /// # Decode first symbol:
    /// print(coder.decode(model)) # (prints: 1)
    ///
    /// # Jump to part 2 and decode it:
    /// coder.seek(position, state)
    /// decoded_part2 = coder.decode(model, 5)
    /// assert np.all(decoded_part2 == message_part2)
    /// ```
    #[pyo3(text_signature = "(position, state)")]
    pub fn seek(&mut self, position: usize, state: u64) -> PyResult<()> {
        self.inner.seek((position, state)).map_err(|()| {
            pyo3::exceptions::PyAttributeError::new_err(
                "Tried to seek past end of stream. Note: in an ANS coder,\n\
                both decoding and seeking *consume* compressed data. The Python API of\n\
                `constriction`'s ANS coder currently does not support seeking backward.",
            )
        })
    }

    /// Discards all compressed data, putting the coder back into its empty initial state.
    ///
    /// The effect is the same as replacing the coder with a freshly constructed `AnsCoder()`,
    /// but doing it in place is slightly cheaper.
    #[pyo3(text_signature = "()")]
    pub fn clear(&mut self) {
        let coder = &mut self.inner;
        coder.clear();
    }

    /// Returns the current size of the encapsulated compressed data, in `np.uint32` words.
    ///
    /// Thus, the number returned by this method is the length of the array that you would get if
    /// you called [`get_compressed`](#constriction.stream.stack.AnsCoder.get_compressed)
    /// without arguments.
    #[pyo3(text_signature = "()")]
    pub fn num_words(&self) -> usize {
        self.inner.num_words()
    }

    /// Returns the current size of the compressed data, in bits, rounded up to full words.
    ///
    /// This is 32 times the result of what
    /// [`num_words`](#constriction.stream.stack.AnsCoder.num_words) would return.
    #[pyo3(text_signature = "()")]
    pub fn num_bits(&self) -> usize {
        self.inner.num_bits()
    }

    /// Returns the size of the compressed data in bits, without rounding up to whole words.
    ///
    /// The result is never more than 32 below the value reported by `.num_bits()`.
    #[pyo3(text_signature = "()")]
    pub fn num_valid_bits(&self) -> usize {
        let coder = &self.inner;
        coder.num_valid_bits()
    }

    /// Tells whether the coder is in its default initial state (`True`) or not (`False`).
    ///
    /// A coder is in this state right after being constructed without arguments, and again
    /// after any call to `clear`.
    #[pyo3(text_signature = "()")]
    pub fn is_empty(&self) -> bool {
        let coder = &self.inner;
        coder.is_empty()
    }

    /// Returns a copy of the compressed data.
    ///
    /// You'll almost always want to call this method without arguments (which will default to
    /// `unseal=False`). See below for an explanation of the advanced use case with argument
    /// `unseal=True`.
    ///
    /// You will typically only want to call this method at the very end of your encoding task,
    /// i.e., once you've encoded the *entire* message. There is usually no need to call this method
    /// after encoding each symbol or other portion of your message. The encoders in `constriction`
    /// *accumulate* compressed data in an internal buffer, and encoding (semantically) *appends* to
    /// this buffer.
    ///
    /// That said, calling `get_compressed` has no side effects, so you *can* call `get_compressed`,
    /// then continue to encode more symbols, and then call `get_compressed` again. The first call
    /// of `get_compressed` will have no effect on the return value of the second call of
    /// `get_compressed`.
    ///
    /// The return value is a rank-1 numpy array of `dtype=np.uint32`. You can write it to a file by
    /// calling `tofile` on it, but we recommend to convert it into an architecture-independent
    /// byte order first:
    ///
    /// ```python
    /// import sys
    /// import constriction
    /// import numpy as np
    ///
    /// encoder = constriction.stream.stack.AnsCoder()
    /// # ... encode some message (skipped here) ...
    /// compressed = encoder.get_compressed() # returns a numpy array.
    /// if sys.byteorder != 'little':
    ///     # Let's save data in little-endian byte order by convention.
    ///     compressed.byteswap(inplace=True)
    /// compressed.tofile('compressed-file.bin')
    ///
    /// # At a later point, you might want to read and decode the file:
    /// compressed = np.fromfile('compressed-file.bin', dtype=np.uint32)
    /// if sys.byteorder != 'little':
    ///     # Restore native byte order before passing it to `constriction`.
    ///     compressed.byteswap(inplace=True)
    /// decoder = constriction.stream.stack.AnsCoder(compressed)
    /// # ... decode the message (skipped here) ...
    /// ```
    ///
    /// ## Explanation of the optional argument `unseal`
    ///
    /// The optional argument `unseal` of this method is the counterpart to the optional argument
    /// `seal` of the constructor. Calling `.get_compressed(unseal=True)` tells the ANS coder that
    /// you expect it to be in a "sealed" state and instructs it to reverse the "sealing" operation.
    /// An ANS coder is in a sealed state if its encapsulated compressed data ends in a single "1"
    /// word. Calling the constructor of `AnsCoder` with argument `seal=True` constructs a coder
    /// that is guaranteed to be in a sealed state because the constructor will append a single "1"
    /// word to the provided `compressed` data. This sealing/unsealing operation makes sure that any
    /// trailing zero words are conserved, since an `AnsCoder` would otherwise not accept them.
    ///
    /// Note that calling `.get_compressed(unseal=True)` fails with an `AssertionError` if the
    /// coder is not in a "sealed" state.
    #[pyo3(text_signature = "(unseal=False)")]
    pub fn get_compressed<'p>(
        &mut self,
        py: Python<'p>,
        unseal: Option<bool>,
    ) -> PyResult<&'p PyArray1<u32>> {
        if unseal == Some(true) {
            // Undo the sealing performed by the constructor's `seal=True` option.
            let binary = self.inner.get_binary().map_err(|_|
                pyo3::exceptions::PyAssertionError::new_err(
                    "Cannot unseal compressed data because it doesn't fit into integer number of words. Did you create the encoder with `seal=True` and restore its original state?",
                ))?;
            Ok(PyArray1::from_slice(py, &*binary))
        } else {
            Ok(PyArray1::from_slice(
                py,
                &*self.inner.get_compressed().unwrap_infallible(),
            ))
        }
    }

    /// .. deprecated:: 0.2.0
    ///    Superseded by `.get_compressed(unseal=True)`.
    #[pyo3(text_signature = "(DEPRECATED)")]
    pub fn get_binary<'p>(&mut self, py: Python<'p>) -> PyResult<&'p PyArray1<u32>> {
        // Retained for backward compatibility only; forwards to the unsealing variant.
        let unseal = Some(true);
        self.get_compressed(py, unseal)
    }

    /// Encodes one or more symbols, appending them to the encapsulated compressed data.
    ///
    /// This method can be called in 3 different ways:
    ///
    /// ## Option 1: encode_reverse(symbol, model)
    ///
    /// Encodes a *single* symbol with a concrete (i.e., fully parameterized) entropy model; the
    /// suffix "_reverse" of the method name has no significance when called this way.
    ///
    /// For optimal computational efficiency, don't use this option in a loop if you can instead
    /// use one of the two alternative options below.
    ///
    /// For example:
    ///
    /// ```python
    /// # Define a concrete categorical entropy model over the (implied)
    /// # alphabet {0, 1, 2}:
    /// probabilities = np.array([0.1, 0.6, 0.3], dtype=np.float64)
    /// model = constriction.stream.model.Categorical(probabilities)
    ///
    /// # Encode a single symbol with this entropy model:
    /// coder = constriction.stream.stack.AnsCoder()
    /// coder.encode_reverse(2, model) # Encodes the symbol `2`.
    /// # ... then encode some more symbols ...
    /// ```
    ///
    /// ## Option 2: encode_reverse(symbols, model)
    ///
    /// Encodes multiple i.i.d. symbols, i.e., all symbols in the rank-1 array `symbols` will be
    /// encoded with the same concrete (i.e., fully parameterized) entropy model. The symbols are
    /// encoded in *reverse* order so that subsequent decoding will retrieve them in forward order
    /// (see [module-level example](#example)).
    ///
    /// For example:
    ///
    /// ```python
    /// # Use the same concrete entropy model as in the previous example:
    /// probabilities = np.array([0.1, 0.6, 0.3], dtype=np.float64)
    /// model = constriction.stream.model.Categorical(probabilities)
    ///
    /// # Encode an example message using the above `model` for all symbols:
    /// symbols = np.array([0, 2, 1, 2, 0, 2, 0, 2, 1], dtype=np.int32)
    /// coder = constriction.stream.stack.AnsCoder()
    /// coder.encode_reverse(symbols, model)
    /// print(coder.get_compressed()) # (prints: [1276728145, 172])
    /// ```
    ///
    /// ## Option 3: encode_reverse(symbols, model_family, params1, params2, ...)
    ///
    /// Encodes multiple symbols, using the same *family* of entropy models (e.g., categorical or
    /// quantized Gaussian) for all symbols, but with different model parameters for each symbol;
    /// here, each `paramsX` argument is an array of the same length as `symbols`. The number of
    /// required `paramsX` arguments and their shapes and `dtype`s depend on the model family. The
    /// symbols are encoded in *reverse* order so that subsequent decoding will retrieve them in
    /// forward order (see [module-level example](#example)). But the mapping between symbols and
    /// model parameters is as you'd expect it to be: `symbols[i]` gets encoded with model
    /// parameters `params1[i]`, `params2[i]`, and so on, even though the index `i` is processed
    /// from the last position down to the first.
    ///
    /// For example, the
    /// [`QuantizedGaussian`](model.html#constriction.stream.model.QuantizedGaussian) model family
    /// expects two rank-1 model parameters of dtype `np.float64`, which specify the mean and
    /// standard deviation for each entropy model:
    ///
    /// ```python
    /// # Define a generic quantized Gaussian distribution for all integers
    /// # in the range from -100 to 100 (both ends inclusive):
    /// model_family = constriction.stream.model.QuantizedGaussian(-100, 100)
    ///
    /// # Specify the model parameters for each symbol:
    /// means = np.array([10.3, -4.7, 20.5], dtype=np.float64)
    /// stds  = np.array([ 5.2, 24.2,  3.1], dtype=np.float64)
    ///
    /// # Encode an example message:
    /// # (needs `len(symbols) == len(means) == len(stds)`)
    /// symbols = np.array([12, -13, 25], dtype=np.int32)
    /// coder = constriction.stream.stack.AnsCoder()
    /// coder.encode_reverse(symbols, model_family, means, stds)
    /// print(coder.get_compressed()) # (prints: [597775281, 3])
    /// ```
    ///
    /// By contrast, the [`Categorical`](model.html#constriction.stream.model.Categorical) model
    /// family expects a single rank-2 model parameter where the i'th row lists the
    /// probabilities for each possible value of the i'th symbol:
    ///
    /// ```python
    /// # Define 2 categorical models over the alphabet {0, 1, 2, 3, 4}:
    /// probabilities = np.array(
    ///     [[0.1, 0.2, 0.3, 0.1, 0.3],  # (for symbols[0])
    ///      [0.3, 0.2, 0.2, 0.2, 0.1]], # (for symbols[1])
    ///     dtype=np.float64)
    /// model_family = constriction.stream.model.Categorical()
    ///
    /// # Encode 2 symbols (needs `len(symbols) == probabilities.shape[0]`):
    /// symbols = np.array([3, 1], dtype=np.int32)
    /// coder = constriction.stream.stack.AnsCoder()
    /// coder.encode_reverse(symbols, model_family, probabilities)
    /// print(coder.get_compressed()) # (prints: [45298483])
    /// ```
    #[pyo3(text_signature = "(symbols, model, optional_model_params)")]
    #[args(symbols, model, params = "*")]
    pub fn encode_reverse(
        &mut self,
        py: Python<'_>,
        symbols: &PyAny,
        model: &Model,
        params: &PyTuple,
    ) -> PyResult<()> {
        // Option 1: a scalar symbol, which only makes sense with a concrete model.
        if let Ok(symbol) = symbols.extract::<i32>() {
            if !params.is_empty() {
                return Err(pyo3::exceptions::PyAttributeError::new_err(
                    "To encode a single symbol, use a concrete model, i.e., pass the\n\
                    model parameters directly to the constructor of the model and not to the\n\
                    `encode_reverse` method of the entropy coder. Delaying the specification of\n\
                    model parameters until calling `encode_reverse` is only useful if you want to\n\
                    encode several symbols in a row with individual model parameters for each\n\
                    symbol. If this is what you're trying to do then the `symbols` argument should\n\
                    be a numpy array, not a scalar.",
                ));
            }
            return model.0.as_parameterized(py, &mut |model| {
                self.inner
                    .encode_symbol(symbol, EncoderDecoderModel(model))?;
                Ok(())
            });
        }

        // Don't use an `else` branch here because, if the following `extract` fails, the returned
        // error message is actually pretty user friendly.
        let symbols = symbols.extract::<PyReadonlyArray1<'_, i32>>()?;
        let symbols = symbols.as_slice()?;

        if params.is_empty() {
            // Option 2: i.i.d. symbols, all encoded with the same concrete model.
            model.0.as_parameterized(py, &mut |model| {
                self.inner
                    .encode_iid_symbols_reverse(symbols, EncoderDecoderModel(model))?;
                Ok(())
            })?;
        } else {
            // Option 3: one model per symbol, built from the per-symbol parameters.
            if symbols.len() != model.0.len(&params[0])? {
                return Err(pyo3::exceptions::PyAttributeError::new_err(
                    "`symbols` argument has wrong length.",
                ));
            }
            // Symbols are consumed back to front; the `true` flag presumably makes `parameterize`
            // visit the parameters in the same reverse order (TODO confirm in
            // `model::internals`), so that `symbols[i]` is paired with `params[..][i]`.
            let mut symbol_iter = symbols.iter().rev();
            model.0.parameterize(py, params, true, &mut |model| {
                let symbol = symbol_iter.next().expect(
                    "`parameterize` produced more models than there are symbols \
                    (the lengths were checked to agree above)",
                );
                self.inner
                    .encode_symbol(*symbol, EncoderDecoderModel(model))?;
                Ok(())
            })?;
        }

        Ok(())
    }

    /// .. deprecated:: 0.2.0
    ///    This method has been superseded by the new and more powerful generic
    ///    [`encode_reverse`](#constriction.stream.stack.AnsCoder.encode_reverse) method in conjunction with the
    ///    [`QuantizedGaussian`](model.html#constriction.stream.model.QuantizedGaussian) model.
    ///
    ///    To encode an array of symbols with an individual quantized Gaussian distribution for each
    ///    symbol, do the following now:
    ///
    ///    ```python
    ///    # Define a generic quantized Gaussian distribution for all integers
    ///    # in the range from -100 to 100 (both ends inclusive):
    ///    model_family = constriction.stream.model.QuantizedGaussian(-100, 100)
    ///
    ///    # Specify the model parameters for each symbol:
    ///    means = np.array([10.3, -4.7, 20.5], dtype=np.float64)
    ///    stds  = np.array([ 5.2, 24.2,  3.1], dtype=np.float64)
    ///
    ///    # Encode an example message:
    ///    # (needs `len(symbols) == len(means) == len(stds)`)
    ///    symbols = np.array([12, -13, 25], dtype=np.int32)
    ///    coder = constriction.stream.stack.AnsCoder()
    ///    coder.encode_reverse(symbols, model_family, means, stds)
    ///    print(coder.get_compressed()) # (prints: [597775281, 3])
    ///    ```
    ///
    ///    If all symbols have the same entropy model (i.e., the same mean and standard deviation),
    ///    then you can use the following shortcut, which is also more computationally efficient:
    ///
    ///    ```python
    ///    # Define a *concrete* quantized Gaussian distribution for all
    ///    # integers in the range from -100 to 100 (both ends inclusive),
    ///    # with a fixed mean of 16.7 and a fixed standard deviation of 9.3:
    ///    model = constriction.stream.model.QuantizedGaussian(
    ///        -100, 100, 16.7, 9.3)
    ///
    ///    # Encode an example message using the above `model` for all symbols:
    ///    symbols = np.array([18, 43, 25, 20, 8, 11], dtype=np.int32)
    ///    coder = constriction.stream.stack.AnsCoder()
    ///    coder.encode_reverse(symbols, model)
    ///    print(coder.get_compressed()) # (prints: [4119848034, 921135])
    ///    ```
    ///
    ///    For more information, see [`QuantizedGaussian`](model.html#constriction.stream.model.QuantizedGaussian).
    ///
    /// Raises a `ValueError` if `min_supported_symbol > max_supported_symbol`.
    #[pyo3(text_signature = "(DEPRECATED)")]
    pub fn encode_leaky_gaussian_symbols_reverse(
        &mut self,
        py: Python<'_>,
        symbols: PyReadonlyArray1<'_, i32>,
        min_supported_symbol: i32,
        max_supported_symbol: i32,
        means: PyReadonlyArray1<'_, f64>,
        stds: PyReadonlyArray1<'_, f64>,
    ) -> PyResult<()> {
        // Best effort: failing to print the deprecation notice must not prevent encoding.
        let _ = py.run(
            "print('WARNING: the method `encode_leaky_gaussian_symbols_reverse` is deprecated. Use method\\n\
            \x20        `encode_reverse` instead. For transition instructions with code examples, see:\\n\
            https://bamler-lab.github.io/constriction/apidoc/python/stream/model.html#examples')",
            None,
            None
        );

        let (symbols, means, stds) = (symbols.as_slice()?, means.as_slice()?, stds.as_slice()?);
        if symbols.len() != means.len() || symbols.len() != stds.len() {
            return Err(pyo3::exceptions::PyAttributeError::new_err(
                "`symbols`, `means`, and `stds` must all have the same length.",
            ));
        }

        // An empty support range cannot carry any probability mass. Report it as a regular Python
        // exception instead of passing it to `DefaultLeakyQuantizer::new`, which presumably
        // rejects it with a panic (TODO confirm in its docs).
        if min_supported_symbol > max_supported_symbol {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "`min_supported_symbol` must not be larger than `max_supported_symbol`.",
            ));
        }

        let quantizer = DefaultLeakyQuantizer::new(min_supported_symbol..=max_supported_symbol);
        self.inner.try_encode_symbols_reverse(
            symbols
                .iter()
                .zip(means.iter())
                .zip(stds.iter())
                .map(|((&symbol, &mean), &std)| {
                    // Reject parameters for which no valid Gaussian exists.
                    if std > 0.0 && std.is_finite() && mean.is_finite() {
                        Ok((symbol, quantizer.quantize(Gaussian::new(mean, std))))
                    } else {
                        Err(())
                    }
                }),
        )?;

        Ok(())
    }

    /// .. deprecated:: 0.2.0
    ///    This method has been superseded by the new and more powerful generic
    ///    [`encode_reverse`](#constriction.stream.stack.AnsCoder.encode_reverse) method in conjunction with the
    ///    [`Categorical`](model.html#constriction.stream.model.Categorical) model.
    ///
    ///    To encode an array of i.i.d. symbols with a fixed categorical entropy model, write the
    ///    following instead:
    ///
    ///    ```python
    ///    # Define a categorical model over the (implied) alphabet {0, 1, 2}:
    ///    probabilities = np.array([0.1, 0.6, 0.3], dtype=np.float64)
    ///    model = constriction.stream.model.Categorical(probabilities)
    ///
    ///    # Encode an example message using the above `model` for all symbols:
    ///    symbols = np.array([0, 2, 1, 2, 0, 2, 0, 2, 1], dtype=np.int32)
    ///    coder = constriction.stream.stack.AnsCoder()
    ///    coder.encode_reverse(symbols, model)
    ///    print(coder.get_compressed()) # (prints: [1276728145, 172])
    ///    ```
    ///
    ///    The new API also lets you use an *individual* entropy model for each encoded symbol
    ///    (at some computational cost):
    ///
    ///    ```python
    ///    # Define 2 categorical models over the alphabet {0, 1, 2, 3, 4}:
    ///    probabilities = np.array(
    ///        [[0.1, 0.2, 0.3, 0.1, 0.3],  # (for symbols[0])
    ///         [0.3, 0.2, 0.2, 0.2, 0.1]], # (for symbols[1])
    ///        dtype=np.float64)
    ///    model_family = constriction.stream.model.Categorical()
    ///
    ///    # Encode 2 symbols (needs `len(symbols) == probabilities.shape[0]`):
    ///    symbols = np.array([3, 1], dtype=np.int32)
    ///    coder = constriction.stream.stack.AnsCoder()
    ///    coder.encode_reverse(symbols, model_family, probabilities)
    ///    print(coder.get_compressed()) # (prints: [45298483])
    ///    ```
    #[pyo3(text_signature = "(DEPRECATED)")]
    pub fn encode_iid_categorical_symbols_reverse(
        &mut self,
        py: Python<'_>,
        symbols: PyReadonlyArray1<'_, i32>,
        min_supported_symbol: i32,
        probabilities: PyReadonlyArray1<'_, f64>,
    ) -> PyResult<()> {
        // The notice is best effort; any error from printing it is deliberately ignored.
        let _ = py.run(
            "print('WARNING: the method `encode_iid_categorical_symbols_reverse` is deprecated. Use method\\n\
            \x20        `encode_reverse` instead. For transition instructions with code examples, see:\\n\
            https://bamler-lab.github.io/constriction/apidoc/python/stream/model.html#constriction.stream.model.Categorical')",
            None,
            None
        );

        let probabilities = probabilities.as_slice()?;
        let model =
            DefaultContiguousCategoricalEntropyModel::from_floating_point_probabilities(probabilities)
                .map_err(|()| {
                    pyo3::exceptions::PyValueError::new_err(
                        "Probability model is either degenerate or not normalizable.",
                    )
                })?;

        // Shift each symbol so that `min_supported_symbol` becomes index 0. Symbols below the
        // minimum wrap around to huge indices, which the categorical model presumably rejects as
        // out of range (TODO confirm).
        let symbols = symbols.as_slice()?;
        let indices = symbols
            .iter()
            .map(|&symbol| symbol.wrapping_sub(min_supported_symbol) as usize);
        self.inner.encode_iid_symbols_reverse(indices, &model)?;

        Ok(())
    }

    /// .. deprecated:: 0.2.0
    ///    This method has been superseded by the new and more powerful generic
    ///    [`encode_reverse`](#constriction.stream.stack.AnsCoder.encode_reverse) method in conjunction with the
    ///    [`CustomModel`](model.html#constriction.stream.model.CustomModel) or
    ///    [`ScipyModel`](model.html#constriction.stream.model.ScipyModel) model class.
    ///
    ///    To encode an array of symbols with a custom entropy model, do the following now:
    ///
    ///    ```python
    ///    # Define the cumulative distribution function (CDF) and (approximate)
    ///    # inverse of it (sometimes called the percent point function or PPF):
    ///    def cdf(x, model_param1, model_param2):
    ///        ...  # TODO (note: you may also leave out the `model_param`s)
    ///    def ppf(xi, model_param1, model_param2):
    ///        ...  # TODO
    ///
    ///    # Wrap them in a `CustomModel`:
    ///    model = constriction.stream.model.CustomModel(cdf, ppf, -100, 100)
    ///
    ///    # Encode an example message using the above `model` for all symbols:
    ///    message       = np.array([... TODO ...], dtype=np.int32)
    ///    model_params1 = np.array([... TODO ...], dtype=np.float64)
    ///    model_params2 = np.array([... TODO ...], dtype=np.float64)
    ///    coder = constriction.stream.stack.AnsCoder()
    ///    coder.encode_reverse(message, model, model_params1, model_params2)
    ///    ```
    ///
    ///    **Hint:** the `scipy` python package provides a number of predefined models, and
    ///    `constriction` offers a convenient wrapper around `scipy` models:
    ///
    ///    ```python
    ///    import scipy.stats
    ///
    ///    coder = constriction.stream.stack.AnsCoder()
    ///
    ///    # Encode an example message with an i.i.d. entropy model from scipy:
    ///    scipy_model = scipy.stats.cauchy(10.2, 16.8)
    ///    constriction_model = constriction.stream.model.ScipyModel(
    ///        scipy_model, -100, 100)
    ///    message_part1 = np.array([-4, 41, 30, 23, -15], dtype=np.int32)
    ///    coder.encode_reverse(message_part1, constriction_model)
    ///
    ///    # Append some more symbols with per-symbol model parameters:
    ///    scipy_model_family = scipy.stats.cauchy
    ///    model_family = constriction.stream.model.ScipyModel(
    ///        scipy_model_family, -100, 100)
    ///    message_part2 = np.array([11,    2,   -18  ], dtype=np.int32)
    ///    means         = np.array([13.2, -5.7, -21.2], dtype=np.float64)
    ///    scales        = np.array([ 4.6, 13.4,   5.7], dtype=np.float64)
    ///    coder.encode_reverse(message_part2, model_family, means, scales)
    ///
    ///    print(coder.get_compressed()) # (prints: [609762275, 3776398430])
    ///    ```
    #[pyo3(text_signature = "(DEPRECATED)")]
    pub fn encode_iid_custom_model_reverse<'py>(
        &mut self,
        py: Python<'py>,
        symbols: PyReadonlyArray1<'_, i32>,
        model: &Model,
    ) -> PyResult<()> {
        // Best effort: failing to print the deprecation notice must not prevent encoding.
        let _ = py.run(
            "print('WARNING: the method `encode_iid_custom_model_reverse` is deprecated. Use method\\n\
            \x20        `encode_reverse` instead. For transition instructions with code examples, see:\\n\
            https://bamler-lab.github.io/constriction/apidoc/python/stream/model.html#constriction.stream.model.CustomModel')",
            None,
            None
        );

        // Equivalent to `encode_reverse(symbols, model)` without per-symbol parameters.
        self.encode_reverse(py, &symbols, model, PyTuple::empty(py))
    }

    /// Decodes one or more symbols, consuming them from the encapsulated compressed data.
    ///
    /// This method can be called in 3 different ways:
    ///
    /// ## Option 1: decode(model)
    ///
    /// Decodes a *single* symbol with a concrete (i.e., fully parameterized) entropy model and
    /// returns the decoded symbol; (for optimal computational efficiency, don't use this option in
    /// a loop if you can instead use one of the two alternative options below.)
    ///
    /// For example:
    ///
    /// ```python
    /// # Define a concrete categorical entropy model over the (implied)
    /// # alphabet {0, 1, 2}:
    /// probabilities = np.array([0.1, 0.6, 0.3], dtype=np.float64)
    /// model = constriction.stream.model.Categorical(probabilities)
    ///
    /// # Decode a single symbol from some example compressed data:
    /// compressed = np.array([636697421, 6848946], dtype=np.uint32)
    /// coder = constriction.stream.stack.AnsCoder(compressed)
    /// symbol = coder.decode(model)
    /// print(symbol) # (prints: 2)
    /// # ... then decode some more symbols ...
    /// ```
    ///
    /// ## Option 2: decode(model, amt) [where `amt` is an integer]
    ///
    /// Decodes `amt` i.i.d. symbols using the same concrete (i.e., fully parameterized) entropy
    /// model for each symbol, and returns the decoded symbols as a rank-1 numpy array with
    /// `dtype=np.int32` and length `amt`.
    ///
    /// For example:
    ///
    /// ```python
    /// # Use the same concrete entropy model as in the previous example:
    /// probabilities = np.array([0.1, 0.6, 0.3], dtype=np.float64)
    /// model = constriction.stream.model.Categorical(probabilities)
    ///
    /// # Decode 9 symbols from some example compressed data, using the
    /// # same (fixed) entropy model defined above for all symbols:
    /// compressed = np.array([636697421, 6848946], dtype=np.uint32)
    /// coder = constriction.stream.stack.AnsCoder(compressed)
    /// symbols = coder.decode(model, 9)
    /// print(symbols) # (prints: [2, 0, 0, 1, 2, 2, 1, 2, 2])
    /// ```
    ///
    /// ## Option 3: decode(model_family, params1, params2, ...)
    ///
    /// Decodes multiple symbols, using the same *family* of entropy models (e.g., categorical or
    /// quantized Gaussian) for all symbols, but with different model parameters for each symbol,
    /// and returns the decoded symbols as a rank-1 numpy array with `dtype=np.int32`; here, all
    /// `paramsX` arguments are arrays of equal length (the number of symbols to be decoded). The
    /// number of required `paramsX` arguments and their shapes and `dtype`s depend on the model
    /// family.
    ///
    /// For example, the
    /// [`QuantizedGaussian`](model.html#constriction.stream.model.QuantizedGaussian) model family
    /// expects two rank-1 model parameters of dtype `np.float64`, which specify the mean and
    /// standard deviation for each entropy model:
    ///
    /// ```python
    /// # Define a generic quantized Gaussian distribution for all integers
    /// # in the range from -100 to 100 (both ends inclusive):
    /// model_family = constriction.stream.model.QuantizedGaussian(-100, 100)
    ///
    /// # Specify the model parameters for each symbol:
    /// means = np.array([10.3, -4.7, 20.5], dtype=np.float64)
    /// stds  = np.array([ 5.2, 24.2,  3.1], dtype=np.float64)
    ///
    /// # Decode a message from some example compressed data:
    /// compressed = np.array([597775281, 3], dtype=np.uint32)
    /// coder = constriction.stream.stack.AnsCoder(compressed)
    /// symbols = coder.decode(model_family, means, stds)
    /// print(symbols) # (prints: [12, -13, 25])
    /// ```
    ///
    /// By contrast, the [`Categorical`](model.html#constriction.stream.model.Categorical) model
    /// family expects a single rank-2 model parameter where the i'th row lists the
    /// probabilities for each possible value of the i'th symbol:
    ///
    /// ```python
    /// # Define 2 categorical models over the alphabet {0, 1, 2, 3, 4}:
    /// probabilities = np.array(
    ///     [[0.1, 0.2, 0.3, 0.1, 0.3],  # (for first decoded symbol)
    ///      [0.3, 0.2, 0.2, 0.2, 0.1]], # (for second decoded symbol)
    ///     dtype=np.float64)
    /// model_family = constriction.stream.model.Categorical()
    ///
    /// # Decode 2 symbols:
    /// compressed = np.array([2142112014, 31], dtype=np.uint32)
    /// coder = constriction.stream.stack.AnsCoder(compressed)
    /// symbols = coder.decode(model_family, probabilities)
    /// print(symbols) # (prints: [3, 1])
    /// ```
    #[pyo3(text_signature = "(model, *optional_amt_or_model_params)")]
    #[args(model, params = "*")]
    pub fn decode<'py>(
        &mut self,
        py: Python<'py>,
        model: &Model,
        params: &PyTuple,
    ) -> PyResult<PyObject> {
        match params.len() {
            0 => {
                // Option 1: decode a single symbol with a concrete model.
                let mut symbol = 0;
                model.0.as_parameterized(py, &mut |model| {
                    symbol = self
                        .inner
                        .decode_symbol(EncoderDecoderModel(model))
                        .unwrap_infallible();
                    Ok(())
                })?;
                return Ok(symbol.to_object(py));
            }
            1 => {
                // Option 2 applies only if the single extra argument converts to a `usize`.
                // Otherwise (e.g., it is an array), fall through and treat it as the sole
                // per-symbol model parameter (Option 3).
                if let Ok(amt) = usize::extract(params.as_slice()[0]) {
                    let mut symbols = Vec::with_capacity(amt);
                    model.0.as_parameterized(py, &mut |model| {
                        for symbol in self
                            .inner
                            .decode_iid_symbols(amt, EncoderDecoderModel(model))
                        {
                            symbols.push(symbol.unwrap_infallible());
                        }
                        Ok(())
                    })?;
                    // Hand the owned buffer to numpy directly instead of re-iterating/copying it.
                    return Ok(PyArray1::from_vec(py, symbols).to_object(py));
                }
            }
            _ => {} // Fall through to code below.
        }

        // Option 3: one model per symbol, parameterized by the `paramsX` arrays; the length of
        // the first parameter array determines the number of decoded symbols.
        let mut symbols = Vec::with_capacity(model.0.len(&params[0])?);
        model.0.parameterize(py, params, false, &mut |model| {
            let symbol = self
                .inner
                .decode_symbol(EncoderDecoderModel(model))
                .unwrap_infallible();
            symbols.push(symbol);
            Ok(())
        })?;

        Ok(PyArray1::from_vec(py, symbols).to_object(py))
    }

    /// .. deprecated:: 0.2.0
    ///    This method has been superseded by the new and more powerful generic
    ///    [`decode`](#constriction.stream.stack.AnsCoder.decode) method in conjunction with the
    ///    [`QuantizedGaussian`](model.html#constriction.stream.model.QuantizedGaussian) model.
    ///
    ///    To decode an array of symbols with an individual quantized Gaussian distribution for each
    ///    symbol, do the following now:
    ///
    ///    ```python
    ///    # Define a generic quantized Gaussian distribution for all integers
    ///    # in the range from -100 to 100 (both ends inclusive):
    ///    model_family = constriction.stream.model.QuantizedGaussian(-100, 100)
    ///
    ///    # Specify the model parameters for each symbol:
    ///    means = np.array([10.3, -4.7, 20.5], dtype=np.float64)
    ///    stds  = np.array([ 5.2, 24.2,  3.1], dtype=np.float64)
    ///
    ///    # Decode a message from some example compressed data:
    ///    compressed = np.array([597775281, 3], dtype=np.uint32)
    ///    coder = constriction.stream.stack.AnsCoder(compressed)
    ///    symbols = coder.decode(model_family, means, stds)
    ///    print(symbols) # (prints: [12, -13, 25])
    ///    ```
    ///
    ///    If all symbols have the same entropy model (i.e., the same mean and standard deviation),
    ///    then you can use the following shortcut, which is also more computationally efficient:
    ///
    ///    ```python
    ///    # Define a *concrete* quantized Gaussian distribution for all
    ///    # integers in the range from -100 to 100 (both ends inclusive),
    ///    # with a fixed mean of 16.7 and a fixed standard deviation of 9.3:
    ///    model = constriction.stream.model.QuantizedGaussian(
    ///        -100, 100, 16.7, 9.3)
    ///
    ///    # Decode a message from some example compressed data:
    ///    compressed = np.array([4119848034, 921135], dtype=np.uint32)
    ///    coder = constriction.stream.stack.AnsCoder(compressed)
    ///    symbols = coder.decode(model, 6) # (decodes 6 symbols)
    ///    print(symbols) # (prints: [18, 43, 25, 20, 8, 11])
    ///    ```
    ///
    ///    For more information, see [`QuantizedGaussian`](model.html#constriction.stream.model.QuantizedGaussian).
    #[pyo3(text_signature = "(DEPRECATED)")]
    pub fn decode_leaky_gaussian_symbols<'p>(
        &mut self,
        min_supported_symbol: i32,
        max_supported_symbol: i32,
        means: PyReadonlyArray1<'_, f64>,
        stds: PyReadonlyArray1<'_, f64>,
        py: Python<'p>,
    ) -> PyResult<&'p PyArray1<i32>> {
        // Best-effort deprecation notice; any error from `py.run` is intentionally ignored.
        let _ = py.run(
            "print('WARNING: the method `decode_leaky_gaussian_symbols` is deprecated. Use method\\n\
            \x20        `decode` instead. For transition instructions with code examples, see:\\n\
            https://bamler-lab.github.io/constriction/apidoc/python/stream/model.html#examples')",
            None,
            None
        );

        // Each decoded symbol consumes exactly one (mean, std) pair.
        if means.len() != stds.len() {
            return Err(pyo3::exceptions::PyAttributeError::new_err(
                "`means`, and `stds` must have the same length.",
            ));
        }

        let support = min_supported_symbol..=max_supported_symbol;
        let quantizer = DefaultLeakyQuantizer::new(support);

        // Lazily construct one quantized Gaussian per symbol. Invalid parameters produce
        // `Err(())`, which makes `try_decode_symbols` stop with an error at that position.
        let mean_iter = means.iter()?;
        let std_iter = stds.iter()?;
        let per_symbol_models = mean_iter.zip(std_iter).map(|(&mean, &std)| {
            let parameters_valid = std > 0.0 && std.is_finite() && mean.is_finite();
            if !parameters_valid {
                return Err(());
            }
            Ok(quantizer.quantize(Gaussian::new(mean, std)))
        });

        let decoded: std::result::Result<
            Vec<_>,
            TryCodingError<CoderError<Infallible, Infallible>, ()>,
        > = self.inner.try_decode_symbols(per_symbol_models).collect();

        match decoded {
            Ok(symbols) => Ok(PyArray1::from_vec(py, symbols)),
            Err(_) => Err(pyo3::exceptions::PyValueError::new_err(
                "Invalid model parameters (`std` must be strictly positive and both `std` and `mean` must be finite.).",
            )),
        }
    }

    /// .. deprecated:: 0.2.0
    ///    This method has been superseded by the new and more powerful generic
    ///    [`decode`](#constriction.stream.stack.AnsCoder.decode) method in conjunction with the
    ///    [`Categorical`](model.html#constriction.stream.model.Categorical) model.
    ///
    ///    To decode an array of i.i.d. symbols with a fixed categorical entropy model, do the
    ///    following now:
    ///
    ///    ```python
    ///    # Define a categorical model over the (implied) alphabet {0, 1, 2}:
    ///    probabilities = np.array([0.1, 0.6, 0.3], dtype=np.float64)
    ///    model = constriction.stream.model.Categorical(probabilities)
    ///
    ///    # Decode 9 symbols from some example compressed data, using the
    ///    # same (fixed) entropy model defined above for all symbols:
    ///    compressed = np.array([1276728145, 172], dtype=np.uint32)
    ///    coder = constriction.stream.stack.AnsCoder(compressed)
    ///    symbols = coder.decode(model, 9) # (decodes 9 symbols)
    ///    print(symbols) # (prints: [0, 2, 1, 2, 0, 2, 0, 2, 1])
    ///    ```
    ///
    ///    This new API also allows you to use an *individual* entropy model for each decoded symbol
    ///    (although this is less computationally efficient):
    ///
    ///    ```python
    ///    # Define 2 categorical models over the alphabet {0, 1, 2, 3, 4}:
    ///    probabilities = np.array(
    ///        [[0.1, 0.2, 0.3, 0.1, 0.3],  # (for first decoded symbol)
    ///         [0.3, 0.2, 0.2, 0.2, 0.1]], # (for second decoded symbol)
    ///        dtype=np.float64)
    ///    model_family = constriction.stream.model.Categorical()
    ///
    ///    # Decode 2 symbols:
    ///    compressed = np.array([45298483], dtype=np.uint32)
    ///    coder = constriction.stream.stack.AnsCoder(compressed)
    ///    symbols = coder.decode(model_family, probabilities)
    ///    print(symbols) # (prints: [3, 1])
    ///    ```
    #[pyo3(text_signature = "(DEPRECATED)")]
    pub fn decode_iid_categorical_symbols<'py>(
        &mut self,
        amt: usize,
        min_supported_symbol: i32,
        probabilities: PyReadonlyArray1<'_, f64>,
        py: Python<'py>,
    ) -> PyResult<&'py PyArray1<i32>> {
        // Best-effort deprecation notice; any error from `py.run` is intentionally ignored.
        let _ = py.run(
            "print('WARNING: the method `decode_iid_categorical_symbols` is deprecated. Use method\\n\
            \x20        `decode` instead. For transition instructions with code examples, see:\\n\
            https://bamler-lab.github.io/constriction/apidoc/python/stream/model.html#constriction.stream.model.Categorical')",
            None,
            None
        );

        // Requires a contiguous array; a non-contiguous one is reported as a Python error.
        let probability_slice = probabilities.as_slice()?;
        let model =
            match DefaultContiguousCategoricalEntropyModel::from_floating_point_probabilities(
                probability_slice,
            ) {
                Ok(model) => model,
                Err(()) => {
                    return Err(pyo3::exceptions::PyValueError::new_err(
                        "Probability distribution is either degenerate or not normalizable.",
                    ));
                }
            };

        // The categorical model decodes indices into `probabilities`; shift each index by
        // `min_supported_symbol` (with wrapping arithmetic) to obtain the caller's symbol.
        let shifted_symbols = self
            .inner
            .decode_iid_symbols(amt, &model)
            .map(|decoded| {
                let index = decoded.unwrap_infallible();
                (index as i32).wrapping_add(min_supported_symbol)
            });

        Ok(PyArray1::from_iter(py, shifted_symbols))
    }

    /// .. deprecated:: 0.2.0
    ///    This method has been superseded by the new and more powerful generic
    ///    [`decode`](#constriction.stream.stack.AnsCoder.decode) method in conjunction with the
    ///    [`CustomModel`](model.html#constriction.stream.model.CustomModel) or
    ///    [`ScipyModel`](model.html#constriction.stream.model.ScipyModel) model class.
    ///
    ///    Note that the new API expects the parameters in opposite order. So, to transition,
    ///    replace
    ///
    ///    ```python
    ///    coder.decode_iid_custom_model(amt, model) # DEPRECATED
    ///    ```
    ///
    ///    with
    ///
    ///    ```python
    ///    coder.decode(model, amt) # new API
    ///    ```
    ///
    ///    The new API also allows you to provide additional per-symbol model parameters to the
    ///    `decode` method (instead of an `amt` parameter):
    ///
    ///    ```python
    ///    coder.decode(model, params1, params2, ...) # new API
    ///    ```
    ///
    ///    Here, the `paramsX` arguments must be rank-1 numpy arrays with `dtype=np.float64`. The
    ///    parameters will be passed to your custom model's CDF and PPF as individual additional
    ///    scalar arguments. (This is a breaking change to the pre-1.0 method `decode_custom_model`,
    ///    which served the same purpose but passed additional model parameters to the CDF and PPF
    ///    as a single numpy array, which turned out to be cumbersome to deal with.)
    ///
    ///    For more information and code examples, see
    ///    [`CustomModel`](model.html#constriction.stream.model.CustomModel) and
    ///    [`ScipyModel`](model.html#constriction.stream.model.ScipyModel).
    #[pyo3(text_signature = "(DEPRECATED)")]
    pub fn decode_iid_custom_model<'py>(
        &mut self,
        py: Python<'py>,
        amt: usize,
        model: &Model,
    ) -> PyResult<PyObject> {
        // Best-effort deprecation notice; any error from `py.run` is intentionally ignored.
        // The replacement for this *decoding* method is `decode` (not `encode_reverse`).
        let _ = py.run(
            "print('WARNING: the method `decode_iid_custom_model` is deprecated. Use method\\n\
            \x20        `decode` instead. For transition instructions with code examples, see:\\n\
            https://bamler-lab.github.io/constriction/apidoc/python/stream/model.html#constriction.stream.model.CustomModel')",
            None,
            None
        );

        // The new API takes the arguments in the opposite order: `decode(model, amt)`.
        self.decode(py, model, PyTuple::new(py, [amt]))
    }
}
