Barretenberg
The ZK-SNARK library at the core of Aztec
Loading...
Searching...
No Matches
sumcheck.test.cpp
Go to the documentation of this file.
1#include "sumcheck.hpp"
4
7#include <gtest/gtest.h>
8
9using namespace bb;
10
11namespace {
12
26template <typename Flavor> typename Flavor::ProverPolynomials create_satisfiable_trace(size_t circuit_size)
27{
28 using FF = typename Flavor::FF;
31
32 ProverPolynomials full_polynomials;
33
34 // Initialize precomputed polynomials (selectors)
35 for (auto& poly : full_polynomials.get_precomputed()) {
36 poly = Polynomial(circuit_size);
37 }
38
39 // Initialize witness polynomials as shiftable (start_index = 1) to allow shifting
40 for (auto& poly : full_polynomials.get_witness()) {
41 poly = Polynomial::shiftable(circuit_size);
42 }
43
44 // Initialize shifted polynomials (will be populated by set_shifted())
45 for (auto& poly : full_polynomials.get_shifted()) {
46 poly = Polynomial(circuit_size);
47 }
48
49 // Create a simple arithmetic circuit with a few gates
50 // Row 1: Addition gate: w_l + w_r = w_o (1 + 1 = 2)
51 if (circuit_size > 1) {
52 full_polynomials.w_l.at(1) = FF(1);
53 full_polynomials.w_r.at(1) = FF(1);
54 full_polynomials.w_o.at(1) = FF(2);
55 full_polynomials.q_l.at(1) = FF(1);
56 full_polynomials.q_r.at(1) = FF(1);
57 full_polynomials.q_o.at(1) = FF(-1);
58 full_polynomials.q_arith.at(1) = FF(1);
59 }
60
61 // Row 2: Multiplication gate: w_l * w_r = w_o (2 * 2 = 4)
62 if (circuit_size > 2) {
63 full_polynomials.w_l.at(2) = FF(2);
64 full_polynomials.w_r.at(2) = FF(2);
65 full_polynomials.w_o.at(2) = FF(4);
66 full_polynomials.q_m.at(2) = FF(1);
67 full_polynomials.q_o.at(2) = FF(-1);
68 full_polynomials.q_arith.at(2) = FF(1);
69 }
70
71 // For ZK flavors: add randomness to the last rows (which will be masked by row-disabling polynomial)
72 // These rows don't need to satisfy the relation because they're disabled
73 if constexpr (Flavor::HasZK) {
74 constexpr size_t NUM_DISABLED_ROWS = 3; // Matches the number of disabled rows in ZK sumcheck
75 if (circuit_size > NUM_DISABLED_ROWS) {
76 for (size_t i = circuit_size - NUM_DISABLED_ROWS; i < circuit_size; ++i) {
77 full_polynomials.w_l.at(i) = FF::random_element();
78 full_polynomials.w_r.at(i) = FF::random_element();
79 full_polynomials.w_o.at(i) = FF::random_element();
80 full_polynomials.w_4.at(i) = FF::random_element();
81 full_polynomials.w_test_1.at(i) = FF::random_element();
82 full_polynomials.w_test_2.at(i) = FF::random_element();
83 }
84 }
85 }
86
87 // Compute shifted polynomials using the set_shifted() method
88 full_polynomials.set_shifted();
89
90 return full_polynomials;
91}
92
93template <typename Flavor> class SumcheckTests : public ::testing::Test {
94 public:
95 using FF = typename Flavor::FF;
97 using ZKData = ZKSumcheckData<Flavor>;
98
99 const size_t NUM_POLYNOMIALS = Flavor::NUM_ALL_ENTITIES;
100 static void SetUpTestSuite() { bb::srs::init_file_crs_factory(bb::srs::bb_crs_path()); }
101
102 Polynomial<FF> random_poly(size_t size)
103 {
104 auto poly = bb::Polynomial<FF>(size);
105 for (auto& coeff : poly.coeffs()) {
106 coeff = FF::random_element();
107 }
108 return poly;
109 }
110
111 ProverPolynomials construct_ultra_full_polynomials(auto& input_polynomials)
112 {
113 ProverPolynomials full_polynomials;
114 for (auto [full_poly, input_poly] : zip_view(full_polynomials.get_all(), input_polynomials)) {
115 full_poly = input_poly.share();
116 }
117 return full_polynomials;
118 }
119
120 void test_polynomial_normalization()
121 {
122 // TODO(#225)(Cody): We should not use real constants like this in the tests, at least not in so many of them.
123 const size_t multivariate_d(3);
124 const size_t multivariate_n(1 << multivariate_d);
125
126 // Randomly construct the prover polynomials that are input to Sumcheck.
127 // Note: ProverPolynomials are defined as spans so the polynomials they point to need to exist in memory.
128 std::vector<bb::Polynomial<FF>> random_polynomials(NUM_POLYNOMIALS);
129 for (auto& poly : random_polynomials) {
130 poly = random_poly(multivariate_n);
131 }
132 auto full_polynomials = construct_ultra_full_polynomials(random_polynomials);
133
134 auto transcript = Flavor::Transcript::prover_init_empty();
135
136 FF alpha = transcript->template get_challenge<FF>("Sumcheck:alpha");
137
138 std::vector<FF> gate_challenges(multivariate_d);
139 for (size_t idx = 0; idx < multivariate_d; idx++) {
140 gate_challenges[idx] =
141 transcript->template get_challenge<FF>("Sumcheck:gate_challenge_" + std::to_string(idx));
142 }
143
144 SumcheckProver<Flavor> sumcheck(
145 multivariate_n, full_polynomials, transcript, alpha, gate_challenges, {}, multivariate_d);
146
147 auto output = sumcheck.prove();
148
149 FF u_0 = output.challenge[0];
150 FF u_1 = output.challenge[1];
151 FF u_2 = output.challenge[2];
152
153 /* sumcheck.prove() terminates with sumcheck.multivariates.folded_polynoimals as an array such that
154 * sumcheck.multivariates.folded_polynoimals[i][0] is the evaluatioin of the i'th multivariate at the vector of
155 challenges u_i. What does this mean?
156
157 Here we show that if the multivariate is F(X0, X1, X2) defined as above, then what we get is F(u0, u1, u2) and
158 not, say F(u2, u1, u0). This is in accordance with Adrian's thesis (cf page 9).
159 */
160
161 // Get the values of the Lagrange basis polys L_i defined
162 // by: L_i(v) = 1 if i = v, 0 otherwise, for v from 0 to 7.
163 FF one{ 1 };
164 // clang-format off
165 FF l_0 = (one - u_0) * (one - u_1) * (one - u_2);
166 FF l_1 = (u_0) * (one - u_1) * (one - u_2);
167 FF l_2 = (one - u_0) * (u_1) * (one - u_2);
168 FF l_3 = (u_0) * (u_1) * (one - u_2);
169 FF l_4 = (one - u_0) * (one - u_1) * (u_2);
170 FF l_5 = (u_0) * (one - u_1) * (u_2);
171 FF l_6 = (one - u_0) * (u_1) * (u_2);
172 FF l_7 = (u_0) * (u_1) * (u_2);
173 // clang-format on
174 FF hand_computed_value;
175 for (auto [full_poly, partial_eval_poly] :
176 zip_view(full_polynomials.get_all(), sumcheck.partially_evaluated_polynomials.get_all())) {
177 // full_polynomials[0][0] = w_l[0], full_polynomials[1][1] = w_r[1], and so on.
178 hand_computed_value = l_0 * full_poly[0] + l_1 * full_poly[1] + l_2 * full_poly[2] + l_3 * full_poly[3] +
179 l_4 * full_poly[4] + l_5 * full_poly[5] + l_6 * full_poly[6] + l_7 * full_poly[7];
180 EXPECT_EQ(hand_computed_value, partial_eval_poly[0]);
181 }
182
183 // We can also check the correctness of the multilinear evaluations produced by Sumcheck by directly evaluating
184 // the full polynomials at challenge u via the evaluate_mle() function
185 std::vector<FF> u_challenge = { u_0, u_1, u_2 };
186 for (auto [full_poly, claimed_eval] :
187 zip_view(full_polynomials.get_all(), output.claimed_evaluations.get_all())) {
188 Polynomial<FF> poly(full_poly);
189 auto v_expected = poly.evaluate_mle(u_challenge);
190 EXPECT_EQ(v_expected, claimed_eval);
191 }
192 }
193
194 void test_prover()
195 {
196 const size_t multivariate_d(2);
197 const size_t multivariate_n(1 << multivariate_d);
198
199 // Randomly construct the prover polynomials that are input to Sumcheck.
200 // Note: ProverPolynomials are defined as spans so the polynomials they point to need to exist in memory.
201 std::vector<Polynomial<FF>> random_polynomials(NUM_POLYNOMIALS);
202 for (auto& poly : random_polynomials) {
203 poly = random_poly(multivariate_n);
204 }
205 auto full_polynomials = construct_ultra_full_polynomials(random_polynomials);
206
207 auto transcript = Flavor::Transcript::prover_init_empty();
208
209 FF alpha = transcript->template get_challenge<FF>("Sumcheck:alpha");
210
211 std::vector<FF> gate_challenges(multivariate_d);
212 for (size_t idx = 0; idx < gate_challenges.size(); idx++) {
213 gate_challenges[idx] =
214 transcript->template get_challenge<FF>("Sumcheck:gate_challenge_" + std::to_string(idx));
215 }
216
217 SumcheckProver<Flavor> sumcheck(
218 multivariate_n, full_polynomials, transcript, alpha, gate_challenges, {}, CONST_PROOF_SIZE_LOG_N);
219
221
222 if constexpr (Flavor::HasZK) {
223 ZKData zk_sumcheck_data = ZKData(multivariate_d, transcript);
224 output = sumcheck.prove(zk_sumcheck_data);
225 } else {
226 output = sumcheck.prove();
227 }
228 FF u_0 = output.challenge[0];
229 FF u_1 = output.challenge[1];
230 std::vector<FF> expected_values;
231 for (auto& polynomial_ptr : full_polynomials.get_all()) {
232 auto& polynomial = polynomial_ptr;
233 // using knowledge of inputs here to derive the evaluation
234 FF expected_lo = polynomial[0] * (FF(1) - u_0) + polynomial[1] * u_0;
235 expected_lo *= (FF(1) - u_1);
236 FF expected_hi = polynomial[2] * (FF(1) - u_0) + polynomial[3] * u_0;
237 expected_hi *= u_1;
238 expected_values.emplace_back(expected_lo + expected_hi);
239 }
240
241 for (auto [eval, expected] : zip_view(output.claimed_evaluations.get_all(), expected_values)) {
242 eval = expected;
243 }
244 }
245
246 // TODO(#225): make the inputs to this test more interesting, e.g. non-trivial permutations
247 void test_prover_verifier_flow()
248 {
249 const size_t multivariate_d(3);
250 const size_t multivariate_n(1 << multivariate_d);
251
252 const size_t virtual_log_n = 6;
253
254 auto full_polynomials = create_satisfiable_trace<Flavor>(multivariate_n);
255
256 // SumcheckTestFlavor doesn't need complex relation parameters (no permutation, lookup, etc.)
257 RelationParameters<FF> relation_parameters{};
258 auto prover_transcript = Flavor::Transcript::prover_init_empty();
259 FF prover_alpha = prover_transcript->template get_challenge<FF>("Sumcheck:alpha");
260
261 std::vector<FF> prover_gate_challenges(virtual_log_n);
262 prover_gate_challenges =
263 prover_transcript->template get_dyadic_powers_of_challenge<FF>("Sumcheck:gate_challenge", virtual_log_n);
264
265 SumcheckProver<Flavor> sumcheck_prover(multivariate_n,
266 full_polynomials,
267 prover_transcript,
268 prover_alpha,
269 prover_gate_challenges,
270 relation_parameters,
271 virtual_log_n);
272
274 if constexpr (Flavor::HasZK) {
275 ZKData zk_sumcheck_data = ZKData(multivariate_d, prover_transcript);
276 output = sumcheck_prover.prove(zk_sumcheck_data);
277 } else {
278 output = sumcheck_prover.prove();
279 }
280
281 auto verifier_transcript = Flavor::Transcript::verifier_init_empty(prover_transcript);
282
283 FF verifier_alpha = verifier_transcript->template get_challenge<FF>("Sumcheck:alpha");
284
285 auto sumcheck_verifier = SumcheckVerifier<Flavor>(verifier_transcript, verifier_alpha, virtual_log_n);
286
287 std::vector<FF> verifier_gate_challenges(virtual_log_n);
288 verifier_gate_challenges =
289 verifier_transcript->template get_dyadic_powers_of_challenge<FF>("Sumcheck:gate_challenge", virtual_log_n);
290
291 std::vector<FF> padding_indicator_array(virtual_log_n, 1);
292 if constexpr (Flavor::HasZK) {
293 for (size_t idx = 0; idx < virtual_log_n; idx++) {
294 padding_indicator_array[idx] = (idx < multivariate_d) ? FF{ 1 } : FF{ 0 };
295 }
296 }
297
298 auto verifier_output =
299 sumcheck_verifier.verify(relation_parameters, verifier_gate_challenges, padding_indicator_array);
300
301 auto verified = verifier_output.verified;
302
303 EXPECT_EQ(verified, true);
304 };
305
306 void test_failure_prover_verifier_flow()
307 {
308 // Since the last 4 rows in ZK Flavors are disabled, we extend an invalid circuit of size 4 to size 8 by padding
309 // with 0.
310 const size_t multivariate_d(3);
311 const size_t multivariate_n(1 << multivariate_d);
312
313 // Start with a satisfiable trace, then break it
314 auto full_polynomials = create_satisfiable_trace<Flavor>(multivariate_n);
315
316 // Break the circuit by changing w_l[1] from 1 to 0
317 // This makes the arithmetic relation unsatisfied:
318 // q_arith[1] * (q_l[1] * w_l[1] + q_r[1] * w_r[1] + q_o[1] * w_o[1]) = 1 * (1 * 0 + 1 * 1 + (-1) * 2) = -1 ≠
319 // 0
320 full_polynomials.w_l.at(1) = FF(0);
321
322 // SumcheckTestFlavor doesn't need complex relation parameters
323 RelationParameters<FF> relation_parameters{};
324 auto prover_transcript = Flavor::Transcript::prover_init_empty();
325 FF prover_alpha = prover_transcript->template get_challenge<FF>("Sumcheck:alpha");
326
327 auto prover_gate_challenges =
328 prover_transcript->template get_dyadic_powers_of_challenge<FF>("Sumcheck:gate_challenge", multivariate_d);
329
330 SumcheckProver<Flavor> sumcheck_prover(multivariate_n,
331 full_polynomials,
332 prover_transcript,
333 prover_alpha,
334 prover_gate_challenges,
335 relation_parameters,
336 multivariate_d);
337
339 if constexpr (Flavor::HasZK) {
340 // construct libra masking polynomials and compute auxiliary data
341 ZKData zk_sumcheck_data = ZKData(multivariate_d, prover_transcript);
342 output = sumcheck_prover.prove(zk_sumcheck_data);
343 } else {
344 output = sumcheck_prover.prove();
345 }
346
347 auto verifier_transcript = Flavor::Transcript::verifier_init_empty(prover_transcript);
348
349 FF verifier_alpha = verifier_transcript->template get_challenge<FF>("Sumcheck:alpha");
350
351 SumcheckVerifier<Flavor> sumcheck_verifier(verifier_transcript, verifier_alpha, multivariate_d);
352
353 std::vector<FF> verifier_gate_challenges(multivariate_d);
354 for (size_t idx = 0; idx < multivariate_d; idx++) {
355 verifier_gate_challenges[idx] =
356 verifier_transcript->template get_challenge<FF>("Sumcheck:gate_challenge_" + std::to_string(idx));
357 }
358
359 std::vector<FF> padding_indicator_array(multivariate_d);
360 std::ranges::fill(padding_indicator_array, FF{ 1 });
361 auto verifier_output =
362 sumcheck_verifier.verify(relation_parameters, verifier_gate_challenges, padding_indicator_array);
363
364 auto verified = verifier_output.verified;
365
366 EXPECT_EQ(verified, false);
367 };
368};
369
370// Define the FlavorTypes using SumcheckTestFlavor variants
371// Note: Only testing short monomials since full barycentric adds complexity without testing sumcheck-specific logic
372// Note: Grumpkin sumcheck requires ZK mode for commitment-based protocol (used in ECCVM/IVC)
373using FlavorTypes = testing::Types<SumcheckTestFlavor, // BN254, non-ZK, short monomials
374 SumcheckTestFlavorZK, // BN254, ZK, short monomials
375 SumcheckTestFlavorGrumpkinZK>; // Grumpkin, ZK, short monomials
376
377TYPED_TEST_SUITE(SumcheckTests, FlavorTypes);
378
379TYPED_TEST(SumcheckTests, PolynomialNormalization)
380{
381 if constexpr (!TypeParam::HasZK) {
382 this->test_polynomial_normalization();
383 } else {
384 GTEST_SKIP() << "Skipping test for ZK-enabled flavors";
385 }
386}
387// Test the prover
388TYPED_TEST(SumcheckTests, Prover)
389{
390 this->test_prover();
391}
392// Tests the prover-verifier flow
393TYPED_TEST(SumcheckTests, ProverAndVerifierSimple)
394{
395 this->test_prover_verifier_flow();
396}
397// This tests is fed an invalid circuit and checks that the verifier would output false.
398TYPED_TEST(SumcheckTests, ProverAndVerifierSimpleFailure)
399{
400 this->test_failure_prover_verifier_flow();
401}
402
403} // namespace
A container for the prover polynomials.
static constexpr bool HasZK
typename Curve::ScalarField FF
static constexpr size_t NUM_ALL_ENTITIES
Structured polynomial class that represents the coefficients 'a' of a_0 + a_1 x .....
static Polynomial shiftable(size_t virtual_size)
Utility to create a shiftable polynomial of given virtual size.
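The implementation of the sumcheck Prover for statements of the form \(\sum_{\vec\ell \in \{0,1\}^d} pow_{\beta}(\vec\ell) \cdot F\left(P_1(\vec\ell), \ldots, P_N(\vec\ell)\right) = 0\) for multilinear polynomials \(P_1, \ldots, P_N\).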
Definition sumcheck.hpp:289
SumcheckOutput< Flavor > prove()
Non-ZK version: Compute round univariate, place it in transcript, compute challenge,...
Definition sumcheck.hpp:398
A flexible, minimal test flavor for sumcheck testing.
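Implementation of the sumcheck Verifier for statements of the form \(\sum_{\vec\ell \in \{0,1\}^d} pow_{\beta}(\vec\ell) \cdot F\left(P_1(\vec\ell), \ldots, P_N(\vec\ell)\right) = 0\) for multilinear polynomials \(P_1, \ldots, P_N\).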
Definition sumcheck.hpp:786
typename ECCVMFlavor::ProverPolynomials ProverPolynomials
testing::Types< MegaFlavor, UltraFlavor, UltraZKFlavor, UltraRollupFlavor > FlavorTypes
std::filesystem::path bb_crs_path()
void init_file_crs_factory(const std::filesystem::path &path)
Entry point for Barretenberg command-line interface.
Definition api.hpp:5
TYPED_TEST_SUITE(ShpleminiTest, TestSettings)
SumcheckTestFlavor_< curve::BN254, true, true > SumcheckTestFlavorZK
Zero-knowledge variant.
SumcheckTestFlavor_< curve::BN254, false, true > SumcheckTestFlavor
Base test flavor (BN254, non-ZK, short monomials)
TYPED_TEST(ShpleminiTest, CorrectnessOfMultivariateClaimBatching)
constexpr decltype(auto) get(::tuplet::tuple< T... > &&t) noexcept
Definition tuple.hpp:13
std::string to_string(bb::avm2::ValueTag tag)
Container for parameters used by the grand product (permutation, lookup) Honk relations.
Contains the evaluations of multilinear polynomials at the challenge point \(\vec u = (u_0, \ldots, u_{d-1})\). These are computed by S...
This structure is created to contain various polynomials and constants required by ZK Sumcheck.
static field random_element(numeric::RNG *engine=nullptr) noexcept
Minimal test flavors for sumcheck testing without UltraFlavor dependencies.