// SPDX-License-Identifier: MPL-2.0

//! **(NOTE: This module is experimental. Applications should not use it yet.)** This module
//! defines the [`Type`] trait, the main building block of the [`prio3`](crate::vdaf::prio3) VDAF.
//! Implementations of this trait for various measurement types can be found in [`types`].
//!
//! This module also implements a fully linear PCP ("Probabilistically Checkable Proof") system
//! based on [[BBCG+19], Theorem 4.3] suitable for implementations of [`Type`]. Implementations of
//! [`Type`] provide the FLP functionality described in [[VDAF], Section 6.1] and used in `prio3`
//! to ensure validity of the recovered output shares. Most applications will not need to use this
//! module directly.
//!
//! # Overview
//!
//! The proof system consists of three algorithms. The first, `prove`, is run by the prover in
//! order to generate a proof of a statement's validity. The second and third, `query` and
//! `decide`, are run by the verifier in order to check the proof. The proof asserts that the input
//! is an element of a language recognized by an arithmetic circuit. If an input is _not_ valid,
//! then the verification step will fail with high probability. For example:
//!
//! ```
//! use prio::pcp::types::Count;
//! use prio::pcp::Type;
//! use prio::field::{random_vector, FieldElement, Field64};
//!
//! // The prover chooses a measurement.
//! let count = Count::new();
//! let input: Vec<Field64> = count.encode(&0).unwrap();
//!
//! // The prover and verifier agree on "joint randomness" used to generate and
//! // check the proof. The application needs to ensure that the prover
//! // "commits" to the input before this point. In the `prio3` VDAF, the joint
//! // randomness is derived from additive shares of the input.
//! let joint_rand = random_vector(count.joint_rand_len()).unwrap();
//!
//! // The prover generates the proof.
//! let prove_rand = random_vector(count.prove_rand_len()).unwrap();
//! let proof = count.prove(&input, &prove_rand, &joint_rand).unwrap();
//!
//! // The verifier checks the proof. In the first step, the verifier "queries"
//! // the input and proof, getting the "verifier message" in response. It then
//! // inspects the verifier message to decide if the input is valid.
//! let query_rand = random_vector(count.query_rand_len()).unwrap();
//! let verifier = count.query(&input, &proof, &query_rand, &joint_rand, 1).unwrap();
//! assert!(count.decide(&verifier).unwrap());
//! ```
//!
//! The proof system implemented here lifts [[BBCG+19], Theorem 4.3] to a 1.5-round, public-coin,
//! interactive oracle proof system (see [[BBCG+19], Definition 3.11]). The main difference is that
//! the arithmetic circuit may include an additional, random input (called the "joint randomness"
//! above). This allows us to express proof systems like the SIMD circuit of [[BBCG+19]] that
//! trade a modest amount of soundness error for a significantly smaller proof.
//!
//! Another improvement made here is to allow for validity circuits with multiple repeating
//! sub-components (called "gadgets" here, see [`gadgets`]) instead of just one. This idea comes
//! from [[BBCG+19], Remark 4.5].
//!
//! WARNING: This proof system has not yet undergone significant security analysis. As such, this
//! proof system should not be considered suitable for production use.
//!
//! [BBCG+19]: https://ia.cr/2019/188
//! [VDAF]: https://datatracker.ietf.org/doc/html/draft-patton-cfrg-vdaf-00
//! [CGB17]: https://crypto.stanford.edu/prio

use crate::fft::{discrete_fourier_transform, discrete_fourier_transform_inv_finish, FftError};
use crate::field::{FieldElement, FieldError};
use crate::fp::log2;
use crate::polynomial::poly_eval;
use std::any::Any;
use std::convert::TryFrom;
use std::fmt::Debug;

pub mod gadgets;
pub mod types;

/// Errors propagated by methods in this module.
///
/// Most variants carry a human-readable description of the failure. [`PcpError::Fft`] and
/// [`PcpError::Field`] wrap lower-level errors; the `#[from]` attributes derive `From` impls for
/// them, so `?` converts [`FftError`] and [`FieldError`] into a `PcpError` automatically.
#[derive(Debug, thiserror::Error)]
pub enum PcpError {
    /// Calling [`Type::prove`] returned an error.
    #[error("prove error: {0}")]
    Prove(String),

    /// Calling [`Type::query`] returned an error.
    #[error("query error: {0}")]
    Query(String),

    /// Calling [`Type::decide`] returned an error.
    #[error("decide error: {0}")]
    Decide(String),

    /// Calling a gadget returned an error.
    #[error("gadget error: {0}")]
    Gadget(String),

    /// Calling the validity circuit returned an error.
    #[error("validity circuit error: {0}")]
    Valid(String),

    /// Calling [`Type::encode`] returned an error. Note that the display text of this variant
    /// begins with "value error".
    #[error("value error: {0}")]
    Encode(String),

    /// Calling [`Type::truncate`] returned an error.
    #[error("truncate error: {0}")]
    Truncate(String),

    /// Returned if an FFT operation propagates an error.
    #[error("FFT error: {0}")]
    Fft(#[from] FftError),

    /// Returned if a field operation encountered an error.
    #[error("Field error: {0}")]
    Field(#[from] FieldError),

    /// Unit test error. Only present in test builds.
    #[cfg(test)]
    #[error("test failed: {0}")]
    Test(String),
}

/// A type. Implementations of this trait specify how a particular kind of measurement is encoded
/// as a vector of field elements and how validity of the encoded measurement is determined.
/// Validity is determined via an arithmetic circuit evaluated over the encoded measurement.
///
/// Implementations provide the encoding, the validity circuit, its gadgets and the length
/// parameters. The proof system itself ([`Self::prove`], [`Self::query`] and [`Self::decide`])
/// is provided by default methods, which rely on the length methods agreeing with the gadgets
/// returned by [`Self::gadget`].
pub trait Type: Sized + Eq + Clone + Debug {
    /// The type of raw measurement to be encoded.
    type Measurement;

    /// The finite field used for this type.
    type Field: FieldElement;

    /// Encodes a measurement as a vector of [`Self::input_len`] field elements.
    fn encode(&self, measurement: &Self::Measurement) -> Result<Vec<Self::Field>, PcpError>;

    /// Returns the sequence of gadgets associated with the validity circuit.
    ///
    /// # Notes
    ///
    /// The construction of [[BBCG+19], Theorem 4.3] uses a single gadget rather than many.  The
    /// idea to generalize the proof system to allow multiple gadgets is discussed briefly in
    /// [[BBCG+19], Remark 4.5], but no construction is given. The construction implemented here
    /// requires security analysis.
    ///
    /// [BBCG+19]: https://ia.cr/2019/188
    fn gadget(&self) -> Vec<Box<dyn Gadget<Self::Field>>>;

    /// Evaluates the validity circuit on an input and returns the output.
    ///
    /// # Parameters
    ///
    /// * `gadgets` is the sequence of gadgets, presumably output by [`Self::gadget`].
    /// * `input` is the input to be validated.
    /// * `joint_rand` is the joint randomness shared by the prover and verifier.
    /// * `num_shares` is the number of input shares.
    ///
    /// # Example usage
    ///
    /// Applications typically do not call this method directly. It is used internally by
    /// [`Self::prove`] and [`Self::query`] to generate and verify the proof respectively.
    ///
    /// ```
    /// use prio::pcp::types::Count;
    /// use prio::pcp::Type;
    /// use prio::field::{random_vector, FieldElement, Field64};
    ///
    /// let count = Count::new();
    /// let input: Vec<Field64> = count.encode(&1).unwrap();
    /// let joint_rand = random_vector(count.joint_rand_len()).unwrap();
    /// let v = count.valid(&mut count.gadget(), &input, &joint_rand, 1).unwrap();
    /// assert_eq!(v, Field64::zero());
    /// ```
    fn valid(
        &self,
        gadgets: &mut Vec<Box<dyn Gadget<Self::Field>>>,
        input: &[Self::Field],
        joint_rand: &[Self::Field],
        num_shares: usize,
    ) -> Result<Self::Field, PcpError>;

    /// Constructs an aggregatable output from an encoded input. Calling this method is only safe
    /// once `input` has been validated.
    fn truncate(&self, input: &[Self::Field]) -> Result<Vec<Self::Field>, PcpError>;

    /// The length in field elements of the encoded input returned by [`Self::encode`].
    fn input_len(&self) -> usize;

    /// The length in field elements of the proof generated for this type.
    fn proof_len(&self) -> usize;

    /// The length in field elements of the verifier message constructed by [`Self::query`].
    fn verifier_len(&self) -> usize;

    /// The length of the truncated output (i.e., the output of [`Type::truncate`]).
    fn output_len(&self) -> usize;

    /// The length of the joint random input.
    fn joint_rand_len(&self) -> usize;

    /// The length in field elements of the random input consumed by the prover to generate a
    /// proof. This is the same as the sum of the arity of each gadget in the validity circuit.
    fn prove_rand_len(&self) -> usize;

    /// The length in field elements of the random input consumed by the verifier to make queries
    /// against inputs and proofs. This is the same as the number of gadgets in the validity
    /// circuit.
    fn query_rand_len(&self) -> usize;

    /// Generate a proof of an input's validity. The return value is a sequence of
    /// [`Self::proof_len`] field elements.
    ///
    /// For each gadget, the proof contains the first point of each wire polynomial (taken from
    /// `prove_rand`), followed by the coefficients of the gadget polynomial.
    ///
    /// # Parameters
    ///
    /// * `input` is the input.
    /// * `prove_rand` is the prover's randomness.
    /// * `joint_rand` is the randomness shared by the prover and verifier.
    ///
    /// # Errors
    ///
    /// Returns [`PcpError::Prove`] if any argument has an unexpected length. Errors from the
    /// validity circuit, the gadgets and the FFT are propagated.
    ///
    /// # Panics
    ///
    /// Panics if [`Self::prove_rand_len`] or [`Self::proof_len`] disagree with the gadgets
    /// returned by [`Self::gadget`].
    fn prove(
        &self,
        input: &[Self::Field],
        prove_rand: &[Self::Field],
        joint_rand: &[Self::Field],
    ) -> Result<Vec<Self::Field>, PcpError> {
        if input.len() != self.input_len() {
            return Err(PcpError::Prove(format!(
                "unexpected input length: got {}; want {}",
                input.len(),
                self.input_len()
            )));
        }

        if prove_rand.len() != self.prove_rand_len() {
            return Err(PcpError::Prove(format!(
                "unexpected prove randomness length: got {}; want {}",
                prove_rand.len(),
                self.prove_rand_len()
            )));
        }

        if joint_rand.len() != self.joint_rand_len() {
            return Err(PcpError::Prove(format!(
                "unexpected joint randomness length: got {}; want {}",
                joint_rand.len(),
                self.joint_rand_len()
            )));
        }

        // Wrap each gadget in a shim that records its input wires. Each shim consumes `arity`
        // consecutive elements of `prove_rand`, which become the first point of each of its wire
        // polynomials.
        let mut prove_rand_len = 0;
        let mut shim = self
            .gadget()
            .into_iter()
            .map(|inner| {
                let inner_arity = inner.arity();
                if prove_rand_len + inner_arity > prove_rand.len() {
                    return Err(PcpError::Prove(format!(
                        "short prove randomness: got {}; want {}",
                        prove_rand.len(),
                        self.prove_rand_len()
                    )));
                }

                let gadget = Box::new(ProveShimGadget::new(
                    inner,
                    &prove_rand[prove_rand_len..prove_rand_len + inner_arity],
                )?) as Box<dyn Gadget<Self::Field>>;
                prove_rand_len += inner_arity;

                Ok(gadget)
            })
            .collect::<Result<Vec<_>, PcpError>>()?;
        assert_eq!(prove_rand_len, self.prove_rand_len());

        // Create a buffer for storing the proof. The buffer is longer than the proof itself; the
        // extra length gives each gadget's `call_poly` room to write its full output.
        let data_len = shim
            .iter()
            .map(|gadget| {
                gadget.arity() + gadget.degree() * (1 + gadget.calls()).next_power_of_two()
            })
            .sum();
        let mut proof = vec![Self::Field::zero(); data_len];

        // Run the validity circuit with a sequence of "shim" gadgets that record the value of each
        // input wire of each gadget evaluation. These values are used to construct the wire
        // polynomials for each gadget in the next step.
        let _ = self.valid(&mut shim, input, joint_rand, 1)?;

        // Fill the buffer with the proof. `proof_len` keeps track of the amount of data written to
        // the buffer so far.
        let mut proof_len = 0;
        for shim_gadget in shim.iter_mut() {
            let gadget = shim_gadget
                .as_any()
                .downcast_mut::<ProveShimGadget<Self::Field>>()
                .expect("prove shim has unexpected type");

            // Interpolate the wire polynomials `f[0], ..., f[g_arity-1]` from the input wires of
            // each evaluation of the gadget.
            let m = (1 + gadget.calls()).next_power_of_two();
            let m_inv =
                Self::Field::from(<Self::Field as FieldElement>::Integer::try_from(m).unwrap())
                    .inv();
            let mut f = vec![vec![Self::Field::zero(); m]; gadget.arity()];
            for (wire, poly) in f.iter_mut().enumerate() {
                discrete_fourier_transform(poly, &gadget.f_vals[wire], m)?;
                discrete_fourier_transform_inv_finish(poly, m, m_inv);

                // The first point on each wire polynomial is a random value chosen by the prover.
                // This point is stored in the proof so that the verifier can reconstruct the wire
                // polynomials.
                proof[proof_len + wire] = gadget.f_vals[wire][0];
            }

            // Construct the gadget polynomial `G(f[0], ..., f[g_arity-1])` and append it to
            // `proof`. Only its first `degree * (m - 1) + 1` coefficients belong to the proof.
            gadget.call_poly(&mut proof[proof_len + gadget.arity()..], &f)?;
            proof_len += gadget.arity() + gadget.degree() * (m - 1) + 1;
        }

        // Truncate the buffer to the size of the proof.
        assert_eq!(proof_len, self.proof_len());
        proof.truncate(proof_len);
        Ok(proof)
    }

    /// Query an input and proof and return the verifier message. The return value has length
    /// [`Self::verifier_len`].
    ///
    /// The verifier message consists of the output of the validity circuit followed, for each
    /// gadget, by its wire polynomials evaluated at that gadget's query randomness and by the
    /// gadget polynomial evaluated at the same point.
    ///
    /// # Parameters
    ///
    /// * `input` is the input or input share.
    /// * `proof` is the proof or proof share.
    /// * `query_rand` is the verifier's randomness.
    /// * `joint_rand` is the randomness shared by the prover and verifier.
    /// * `num_shares` is the total number of input shares.
    ///
    /// # Errors
    ///
    /// Returns [`PcpError::Query`] in the following cases:
    ///
    /// * any argument has an unexpected length;
    /// * `query_rand` or `proof` is too short for the gadgets;
    /// * an element of `query_rand` is a root of unity used to interpolate the wire polynomials.
    ///
    /// Errors from the validity circuit and the FFT are propagated.
    fn query(
        &self,
        input: &[Self::Field],
        proof: &[Self::Field],
        query_rand: &[Self::Field],
        joint_rand: &[Self::Field],
        num_shares: usize,
    ) -> Result<Vec<Self::Field>, PcpError> {
        if input.len() != self.input_len() {
            return Err(PcpError::Query(format!(
                "unexpected input length: got {}; want {}",
                input.len(),
                self.input_len()
            )));
        }

        if proof.len() != self.proof_len() {
            return Err(PcpError::Query(format!(
                "unexpected proof length: got {}; want {}",
                proof.len(),
                self.proof_len()
            )));
        }

        if query_rand.len() != self.query_rand_len() {
            return Err(PcpError::Query(format!(
                "unexpected query randomness length: got {}; want {}",
                query_rand.len(),
                self.query_rand_len()
            )));
        }

        if joint_rand.len() != self.joint_rand_len() {
            return Err(PcpError::Query(format!(
                "unexpected joint randomness length: got {}; want {}",
                joint_rand.len(),
                self.joint_rand_len()
            )));
        }

        // Each gadget consumes one element of query randomness. Report an inconsistency between
        // `query_rand_len` and the gadgets as an error rather than panicking on an index.
        let gadgets = self.gadget();
        if gadgets.len() > query_rand.len() {
            return Err(PcpError::Query(format!(
                "short query randomness: got {}; want {}",
                query_rand.len(),
                gadgets.len()
            )));
        }

        let mut proof_len = 0;
        let mut shim = gadgets
            .into_iter()
            .zip(query_rand.iter())
            .map(|(gadget, &r)| {
                let gadget_degree = gadget.degree();
                let gadget_arity = gadget.arity();
                let m = (1 + gadget.calls()).next_power_of_two();

                // Make sure the query randomness isn't an `m`-th root of unity. Evaluating the
                // wire polynomials at any of these points would be a privacy violation, since
                // these points were used by the prover to construct the wire polynomials.
                if r.pow(<Self::Field as FieldElement>::Integer::try_from(m).unwrap())
                    == Self::Field::one()
                {
                    return Err(PcpError::Query(format!(
                        "invalid query randomness: encountered {}-th root of unity",
                        m
                    )));
                }

                // Compute the length of the sub-proof corresponding to this gadget.
                let next_len = gadget_arity + gadget_degree * (m - 1) + 1;
                let end = proof_len + next_len;
                let proof_data = proof.get(proof_len..end).ok_or_else(|| {
                    PcpError::Query(format!(
                        "short proof: got {}; want at least {}",
                        proof.len(),
                        end
                    ))
                })?;
                proof_len = end;

                Ok(Box::new(QueryShimGadget::new(gadget, r, proof_data)?)
                    as Box<dyn Gadget<Self::Field>>)
            })
            .collect::<Result<Vec<_>, _>>()?;

        // Create a buffer for the verifier data. This includes the output of the validity circuit
        // and, for each gadget `shim[idx].inner`, the wire polynomials evaluated at the query
        // randomness `query_rand[idx]` and the gadget polynomial evaluated at `query_rand[idx]`.
        let data_len = 1 + shim.iter().map(|gadget| gadget.arity() + 1).sum::<usize>();
        let mut verifier = Vec::with_capacity(data_len);

        // Run the validity circuit with a sequence of "shim" gadgets that record the inputs to
        // each wire for each gadget call. Record the output of the circuit and append it to the
        // verifier message.
        //
        // NOTE The proof of [BBC+19, Theorem 4.3] assumes that the output of the validity circuit
        // is equal to the output of the last gadget evaluation. Here we relax this assumption.
        // This should be OK, since it's possible to transform any circuit into one for which this
        // is true. (Needs security analysis.)
        let validity = self.valid(&mut shim, input, joint_rand, num_shares)?;
        verifier.push(validity);

        // Fill the buffer with the verifier message.
        for (shim_gadget, &r) in shim.iter_mut().zip(query_rand.iter()) {
            let gadget = shim_gadget
                .as_any()
                .downcast_ref::<QueryShimGadget<Self::Field>>()
                .expect("query shim has unexpected type");

            // Reconstruct the wire polynomials `f[0], ..., f[g_arity-1]` and evaluate each wire
            // polynomial at query randomness `r`. The buffer `f` is reused for every wire.
            let m = (1 + gadget.calls()).next_power_of_two();
            let m_inv =
                Self::Field::from(<Self::Field as FieldElement>::Integer::try_from(m).unwrap())
                    .inv();
            let mut f = vec![Self::Field::zero(); m];
            for wire_vals in gadget.f_vals.iter() {
                discrete_fourier_transform(&mut f, wire_vals, m)?;
                discrete_fourier_transform_inv_finish(&mut f, m, m_inv);
                verifier.push(poly_eval(&f, r));
            }

            // Add the value of the gadget polynomial evaluated at `r`.
            verifier.push(gadget.p_at_r);
        }

        assert_eq!(verifier.len(), self.verifier_len());
        Ok(verifier)
    }

    /// Returns true if the verifier message indicates that the input from which it was generated
    /// is valid.
    ///
    /// The input is accepted when both of the following hold:
    ///
    /// * the validity-circuit output (the first element of `verifier`) is zero;
    /// * for each gadget, evaluating the gadget on its wire-polynomial values yields the
    ///   gadget-polynomial value that follows them in the message.
    ///
    /// # Errors
    ///
    /// Returns [`PcpError::Decide`] if `verifier` does not have length [`Self::verifier_len`].
    /// Errors from evaluating the gadgets are propagated.
    fn decide(&self, verifier: &[Self::Field]) -> Result<bool, PcpError> {
        if verifier.len() != self.verifier_len() {
            return Err(PcpError::Decide(format!(
                "unexpected verifier length: got {}; want {}",
                verifier.len(),
                self.verifier_len()
            )));
        }

        // Check if the output of the circuit is 0.
        if verifier[0] != Self::Field::zero() {
            return Ok(false);
        }

        // Check that each of the proof polynomials are well-formed: each gadget's chunk is
        // `arity` wire values followed by the claimed gadget output.
        let mut gadgets = self.gadget();
        let mut offset = 1;
        for gadget in gadgets.iter_mut() {
            let arity = gadget.arity();

            let e = gadget.call(&verifier[offset..offset + arity])?;
            if e != verifier[offset + arity] {
                return Ok(false);
            }

            offset += arity + 1;
        }

        Ok(true)
    }
}

/// A gadget, a non-affine arithmetic circuit that is called when evaluating a validity circuit.
///
/// [`Type::prove`] and [`Type::query`] wrap each gadget returned by [`Type::gadget`] in a "shim"
/// and pass the shims to [`Type::valid`] in place of the gadgets themselves, so the number of
/// inputs and calls reported here must match how the validity circuit uses the gadget.
pub trait Gadget<F: FieldElement> {
    /// Evaluates the gadget on input `inp` and returns the output.
    fn call(&mut self, inp: &[F]) -> Result<F, PcpError>;

    /// Evaluates the gadget on a sequence of polynomials, one per input wire, and writes the
    /// output polynomial to `outp`.
    fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), PcpError>;

    /// Returns the arity of the gadget. This is the length of `inp` passed to `call` or
    /// `call_poly`.
    fn arity(&self) -> usize;

    /// Returns the circuit's arithmetic degree. This determines the minimum length of the `outp`
    /// buffer passed to `call_poly`.
    fn degree(&self) -> usize;

    /// Returns the number of times the gadget is expected to be called. The proof system sizes
    /// the wire polynomials from this value.
    fn calls(&self) -> usize;

    /// Returns `self` as `&mut dyn Any` so that a `Box<dyn Gadget<F>>` can be downcast to its
    /// concrete type.
    fn as_any(&mut self) -> &mut dyn Any;
}

// A "shim" gadget used during proof generation to record the input wires each time a gadget is
// evaluated. Evaluation itself is delegated to the wrapped gadget.
struct ProveShimGadget<F: FieldElement> {
    /// The wrapped gadget, which computes the output of each call.
    inner: Box<dyn Gadget<F>>,

    /// Points at which the wire polynomials are interpolated. `f_vals[wire][0]` is the prover's
    /// random value, and `f_vals[wire][k]` is the value of input `wire` on the `k`-th call.
    f_vals: Vec<Vec<F>>,

    /// Index in each `f_vals[wire]` at which the next call's input is recorded. It starts at 1
    /// because index 0 holds the prover's random value, so it is one more than the number of
    /// calls made so far.
    ct: usize,
}

impl<F: FieldElement> ProveShimGadget<F> {
    fn new(inner: Box<dyn Gadget<F>>, prove_rand: &[F]) -> Result<Self, PcpError> {
        let mut f_vals = vec![vec![F::zero(); 1 + inner.calls()]; inner.arity()];

        #[allow(clippy::needless_range_loop)]
        for wire in 0..f_vals.len() {
            // Choose a random field element as the first point on the wire polynomial.
            f_vals[wire][0] = prove_rand[wire];
        }

        Ok(Self {
            inner,
            f_vals,
            ct: 1,
        })
    }
}

impl<F: FieldElement> Gadget<F> for ProveShimGadget<F> {
    fn call(&mut self, inp: &[F]) -> Result<F, PcpError> {
        #[allow(clippy::needless_range_loop)]
        for wire in 0..inp.len() {
            self.f_vals[wire][self.ct] = inp[wire];
        }
        self.ct += 1;
        self.inner.call(inp)
    }

    fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), PcpError> {
        self.inner.call_poly(outp, inp)
    }

    fn arity(&self) -> usize {
        self.inner.arity()
    }

    fn degree(&self) -> usize {
        self.inner.degree()
    }

    fn calls(&self) -> usize {
        self.inner.calls()
    }

    fn as_any(&mut self) -> &mut dyn Any {
        self
    }
}

// A "shim" gadget used during proof verification to record the points at which the intermediate
// proof polynomials are evaluated. Rather than evaluating the wrapped gadget, each call returns
// the corresponding value of the gadget polynomial reconstructed from the proof (or proof share).
struct QueryShimGadget<F: FieldElement> {
    /// The wrapped gadget. Only its arity, degree and number of calls are used.
    inner: Box<dyn Gadget<F>>,

    /// Points at which intermediate proof polynomials are interpolated. `f_vals[wire][0]` is
    /// taken from the proof, and `f_vals[wire][k]` is the value of input `wire` on the `k`-th
    /// call.
    f_vals: Vec<Vec<F>>,

    /// The gadget polynomial evaluated at roots of unity (the DFT of its coefficients).
    p_vals: Vec<F>,

    /// The gadget polynomial evaluated on a random input `r`.
    p_at_r: F,

    /// Stride into `p_vals`: the `k`-th call returns `p_vals[k * step]`.
    step: usize,

    /// Index of the next call. It starts at 1 because index 0 of each `f_vals[wire]` is taken
    /// from the proof, so it is one more than the number of calls made so far.
    ct: usize,
}

impl<F: FieldElement> QueryShimGadget<F> {
    /// Builds a query shim for `inner` from its portion of the proof (or proof share).
    ///
    /// `proof_data` starts with one element per input wire: the first point of each wire
    /// polynomial, which the prover chose at random. These are followed by the coefficients of
    /// the gadget polynomial.
    ///
    /// The gadget polynomial is evaluated at two kinds of points:
    ///
    /// * at roots of unity, so that calls to the shim can be answered;
    /// * at the query randomness `r`.
    fn new(inner: Box<dyn Gadget<F>>, r: F, proof_data: &[F]) -> Result<Self, PcpError> {
        let num_wires = inner.arity();
        let num_calls = inner.calls();
        let m = (1 + num_calls).next_power_of_two();
        let p = m * inner.degree();

        // Separate the prover's per-wire values from the gadget polynomial's coefficients.
        let (wire_seeds, gadget_poly) = proof_data.split_at(num_wires);

        // Each wire polynomial's first point comes from the proof; the remaining points are
        // recorded as the gadget is called.
        let f_vals = wire_seeds
            .iter()
            .map(|&seed| {
                let mut points = vec![F::zero(); 1 + num_calls];
                points[0] = seed;
                points
            })
            .collect();

        // Evaluate the gadget polynomial at the `size`-th roots of unity.
        let size = p.next_power_of_two();
        let mut p_vals = vec![F::zero(); size];
        discrete_fourier_transform(&mut p_vals, gadget_poly, size)?;

        Ok(Self {
            inner,
            f_vals,
            p_vals,
            // The gadget polynomial evaluated at the query randomness.
            p_at_r: poly_eval(gadget_poly, r),
            // Stride that maps the `k`-th call to its entry in `p_vals`.
            step: (1 << (log2(p as u128) - log2(m as u128))) as usize,
            ct: 1,
        })
    }
}

impl<F: FieldElement> Gadget<F> for QueryShimGadget<F> {
    fn call(&mut self, inp: &[F]) -> Result<F, PcpError> {
        #[allow(clippy::needless_range_loop)]
        for wire in 0..inp.len() {
            self.f_vals[wire][self.ct] = inp[wire];
        }
        let outp = self.p_vals[self.ct * self.step];
        self.ct += 1;
        Ok(outp)
    }

    fn call_poly(&mut self, _outp: &mut [F], _inp: &[Vec<F>]) -> Result<(), PcpError> {
        panic!("no-op");
    }

    fn arity(&self) -> usize {
        self.inner.arity()
    }

    fn degree(&self) -> usize {
        self.inner.degree()
    }

    fn calls(&self) -> usize {
        self.inner.calls()
    }

    fn as_any(&mut self) -> &mut dyn Any {
        self
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::field::{random_vector, split_vector, Field126};
    use crate::pcp::gadgets::{Mul, PolyEval};
    use crate::polynomial::poly_range_check;

    use std::marker::PhantomData;

    // Simple integration test for the core PCP logic. You'll find more extensive unit tests for
    // each implemented data type in the `types` submodule.
    #[test]
    fn test_pcp() {
        const NUM_SHARES: usize = 2;

        let typ: TestType<Field126> = TestType::new();

        // Valid: `x = 3` is in range and the encoding sets `y = x^3`.
        let input = typ.encode(&3).unwrap();
        assert!(run_pcp(&typ, &input, NUM_SHARES));

        // Invalid: `x = 5` is out of range, even though `y = x^3` holds. Rejection fails only if
        // the joint randomness happens to be zero, which has negligible probability.
        let input = typ.encode(&5).unwrap();
        assert!(!run_pcp(&typ, &input, NUM_SHARES));

        // Invalid: `x = 3` is in range, but `y != x^3`.
        let mut input = typ.encode(&3).unwrap();
        input[1] += Field126::one();
        assert!(!run_pcp(&typ, &input, NUM_SHARES));
    }

    /// Runs the full proof system on `input` and returns the decision. The steps are:
    ///
    /// 1. The prover generates a proof.
    /// 2. The input and proof are split into `num_shares` additive shares.
    /// 3. Each share is queried separately.
    /// 4. The summed verifier message is passed to `decide`.
    fn run_pcp(typ: &TestType<Field126>, input: &[Field126], num_shares: usize) -> bool {
        assert_eq!(input.len(), typ.input_len());

        let joint_rand = random_vector(typ.joint_rand_len()).unwrap();
        let prove_rand = random_vector(typ.prove_rand_len()).unwrap();
        let query_rand = random_vector(typ.query_rand_len()).unwrap();

        let proof = typ.prove(input, &prove_rand, &joint_rand).unwrap();
        assert_eq!(proof.len(), typ.proof_len());

        let input_shares: Vec<Vec<Field126>> = split_vector(input, num_shares)
            .unwrap()
            .into_iter()
            .collect();
        let proof_shares: Vec<Vec<Field126>> = split_vector(&proof, num_shares)
            .unwrap()
            .into_iter()
            .collect();

        let verifier: Vec<Field126> = input_shares
            .iter()
            .zip(proof_shares.iter())
            .map(|(input_share, proof_share)| {
                typ.query(
                    input_share,
                    proof_share,
                    &query_rand,
                    &joint_rand,
                    num_shares,
                )
                .unwrap()
            })
            .reduce(|mut left, right| {
                for (x, y) in left.iter_mut().zip(right.iter()) {
                    *x += *y;
                }
                left
            })
            .unwrap();
        assert_eq!(verifier.len(), typ.verifier_len());

        typ.decide(&verifier).unwrap()
    }

    /// A toy type used for testing the functionality in this module. Valid inputs of this type
    /// consist of a pair of field elements `(x, y)` where `2 <= x < 5` and `x^3 == y`.
    #[derive(Clone, Debug, PartialEq, Eq)]
    struct TestType<F>(PhantomData<F>);

    impl<F> TestType<F> {
        fn new() -> Self {
            Self(PhantomData)
        }
    }

    impl<F: FieldElement> Type for TestType<F> {
        type Measurement = F::Integer;
        type Field = F;

        fn valid(
            &self,
            g: &mut Vec<Box<dyn Gadget<F>>>,
            input: &[F],
            joint_rand: &[F],
            _num_shares: usize,
        ) -> Result<F, PcpError> {
            let r = joint_rand[0];
            let mut res = F::zero();

            // Check that `input[0]^3 == input[1]`.
            let mut inp = [input[0], input[0]];
            inp[0] = g[0].call(&inp)?;
            inp[0] = g[0].call(&inp)?;
            let x3_diff = inp[0] - input[1];
            res += r * x3_diff;

            // Check that `input[0]` is in the correct range.
            let x_checked = g[1].call(&[input[0]])?;
            res += (r * r) * x_checked;

            Ok(res)
        }

        fn input_len(&self) -> usize {
            2
        }

        fn proof_len(&self) -> usize {
            // First chunk
            let mul = 2 /* gadget arity */ + 2 /* gadget degree */ * (
                (1 + 2 /* gadget calls */ as usize).next_power_of_two() - 1) + 1;

            // Second chunk
            let poly = 1 /* gadget arity */ + 3 /* gadget degree */ * (
                (1 + 1 /* gadget calls */ as usize).next_power_of_two() - 1) + 1;

            mul + poly
        }

        fn verifier_len(&self) -> usize {
            // First chunk
            let mul = 1 + 2 /* gadget arity */;

            // Second chunk
            let poly = 1 + 1 /* gadget arity */;

            1 + mul + poly
        }

        fn output_len(&self) -> usize {
            self.input_len()
        }

        fn joint_rand_len(&self) -> usize {
            1
        }

        fn prove_rand_len(&self) -> usize {
            3
        }

        fn query_rand_len(&self) -> usize {
            2
        }

        fn gadget(&self) -> Vec<Box<dyn Gadget<F>>> {
            vec![
                Box::new(Mul::new(2)),
                Box::new(PolyEval::new(poly_range_check(2, 5), 1)),
            ]
        }

        fn encode(&self, measurement: &F::Integer) -> Result<Vec<F>, PcpError> {
            Ok(vec![
                F::from(*measurement),
                F::from(*measurement).pow(F::Integer::try_from(3).unwrap()),
            ])
        }

        fn truncate(&self, input: &[F]) -> Result<Vec<F>, PcpError> {
            Ok(input.to_vec())
        }
    }
}
