Skip to content

Commit

Permalink
Use SplitEqPolynomial in Spartan cubic sumcheck (#500)
Browse files Browse the repository at this point in the history
* Use SplitEqPolynomial in Spartan cubic sumcheck

* (hopefully) fix weird behavior where layer_output sometimes takes a really long time?
  • Loading branch information
moodlezoup authored Nov 8, 2024
1 parent d5f9b87 commit 9f0b9e6
Show file tree
Hide file tree
Showing 4 changed files with 219 additions and 110 deletions.
42 changes: 30 additions & 12 deletions jolt-core/src/poly/sparse_interleaved_poly.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,19 +221,37 @@ impl<F: JoltField> SparseInterleavedPolynomial<F> {
.coeffs
.par_iter()
.map(|segment| {
segment
.chunk_by(|x, y| x.index / 2 == y.index / 2)
.map(|sparse_block| {
let mut dense_block = [F::one(); 2];
for coeff in sparse_block {
dense_block[coeff.index % 2] = coeff.value;
let mut output_segment: Vec<SparseCoefficient<F>> =
Vec::with_capacity(segment.len());
let mut next_index_to_process = 0usize;
for (j, coeff) in segment.iter().enumerate() {
if coeff.index < next_index_to_process {
// Node was already multiplied with its sibling in a previous iteration
continue;
}
if coeff.index % 2 == 0 {
// Left node; try to find correspoding right node
let right = segment
.get(j + 1)
.cloned()
.unwrap_or((coeff.index + 1, F::one()).into());
if right.index == coeff.index + 1 {
// Corresponding right node was found; multiply them together
output_segment
.push((coeff.index / 2, right.value * coeff.value).into());
} else {
// Corresponding right node not found, so it must be 1
output_segment.push((coeff.index / 2, coeff.value).into());
}

let output_index = sparse_block[0].index / 2;
let output_value = dense_block[0].mul_1_optimized(dense_block[1]);
(output_index, output_value).into()
})
.collect()
next_index_to_process = coeff.index + 2;
} else {
// Right node; corresponding left node was not encountered in
// previous iteration, so it must have value 1
output_segment.push((coeff.index / 2, coeff.value).into());
next_index_to_process = coeff.index + 1;
}
}
output_segment
})
.collect();

Expand Down
36 changes: 9 additions & 27 deletions jolt-core/src/r1cs/spartan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use crate::jolt::vm::JoltPolynomials;
use crate::poly::commitment::commitment_scheme::CommitmentScheme;
use crate::poly::opening_proof::ProverOpeningAccumulator;
use crate::poly::opening_proof::VerifierOpeningAccumulator;
use crate::poly::split_eq_poly::SplitEqPolynomial;
use crate::r1cs::key::UniformSpartanKey;
use crate::utils::math::Math;
use crate::utils::thread::drop_in_background_thread;
Expand Down Expand Up @@ -121,51 +122,32 @@ where
let tau = (0..num_rounds_x)
.map(|_i| transcript.challenge_scalar())
.collect::<Vec<F>>();
let mut poly_tau = DensePolynomial::new(EqPolynomial::evals(&tau));
let mut eq_tau = SplitEqPolynomial::new(&tau);

let (mut az, mut bz, mut cz) =
constraint_builder.compute_spartan_Az_Bz_Cz::<PCS, ProofTranscript>(&flattened_polys);

let comb_func_outer = |eq: &F, az: &F, bz: &F, cz: &F| -> F {
// Below is an optimized form of: eq * (Az * Bz - Cz)
if az.is_zero() || bz.is_zero() {
if cz.is_zero() {
F::zero()
} else {
*eq * (-(*cz))
}
} else {
let inner = *az * *bz - *cz;
if inner.is_zero() {
F::zero()
} else {
*eq * inner
}
}
};

let (outer_sumcheck_proof, outer_sumcheck_r, outer_sumcheck_claims) =
SumcheckInstanceProof::prove_spartan_cubic::<_>(
SumcheckInstanceProof::prove_spartan_cubic(
&F::zero(), // claim is zero
num_rounds_x,
&mut poly_tau,
&mut eq_tau,
&mut az,
&mut bz,
&mut cz,
comb_func_outer,
transcript,
);
let outer_sumcheck_r: Vec<F> = outer_sumcheck_r.into_iter().rev().collect();
drop_in_background_thread((az, bz, cz, poly_tau));
drop_in_background_thread((az, bz, cz, eq_tau));

ProofTranscript::append_scalars(transcript, &outer_sumcheck_claims);
// claims from the end of sum-check
// claim_Az is the (scalar) value v_A = \sum_y A(r_x, y) * z(r_x) where r_x is the sumcheck randomness
let (claim_Az, claim_Bz, claim_Cz): (F, F, F) = (
outer_sumcheck_claims[0],
outer_sumcheck_claims[1],
outer_sumcheck_claims[2],
outer_sumcheck_claims[3],
);
ProofTranscript::append_scalars(transcript, [claim_Az, claim_Bz, claim_Cz].as_slice());

// inner sum-check
let r_inner_sumcheck_RLC: F = transcript.challenge_scalar();
Expand Down Expand Up @@ -211,11 +193,11 @@ where
transcript,
);

// Outer sumcheck claims: [eq(r_x), A(r_x), B(r_x), C(r_x)]
// Outer sumcheck claims: [A(r_x), B(r_x), C(r_x)]
let outer_sumcheck_claims = (
outer_sumcheck_claims[0],
outer_sumcheck_claims[1],
outer_sumcheck_claims[2],
outer_sumcheck_claims[3],
);
Ok(UniformSpartanProof {
_inputs: PhantomData,
Expand Down
2 changes: 1 addition & 1 deletion jolt-core/src/subprotocols/sparse_grand_product.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ impl<F: JoltField> BatchedGrandProductToggleLayer<F> {
.enumerate()
.map(|(batch_index, fingerprints)| {
let flag_indices = &self.flag_indices[batch_index / 2];
let mut sparse_coeffs = vec![];
let mut sparse_coeffs = Vec::with_capacity(self.layer_len);
for i in flag_indices {
sparse_coeffs
.push((batch_index * self.layer_len / 2 + i, fingerprints[*i]).into());
Expand Down
Loading

0 comments on commit 9f0b9e6

Please sign in to comment.