Barretenberg
The ZK-SNARK library at the core of Aztec
biggroup_nafs.hpp
// === AUDIT STATUS ===
// internal: { status: not started, auditors: [], date: YYYY-MM-DD }
// external_1: { status: not started, auditors: [], date: YYYY-MM-DD }
// external_2: { status: not started, auditors: [], date: YYYY-MM-DD }
// =====================

#pragma once

namespace bb::stdlib::element_default {

template <typename C, class Fq, class Fr, class G>
template <size_t wnaf_size>
std::pair<uint64_t, bool> element<C, Fq, Fr, G>::get_staggered_wnaf_fragment_value(const uint64_t fragment_u64,
                                                                                    const uint64_t stagger,
                                                                                    bool is_negative,
                                                                                    bool wnaf_skew)
{
    // If there is no stagger then there is no need to change anything
    if (stagger == 0) {
        return std::make_pair(0, wnaf_skew);
    }

    // Sanity check input fragment
    BB_ASSERT_LT(fragment_u64, (1ULL << stagger), "biggroup_nafs: fragment value ≥ 2^{stagger}");

    // Convert the fragment to a signed int for easier manipulation
    int fragment = static_cast<int>(fragment_u64);

    // Negate the fragment if the scalar is negative
    if (is_negative) {
        fragment = -fragment;
    }
    // If the value is positive and there is a skew in wnaf, subtract 2^{stagger}.
    if (!is_negative && wnaf_skew) {
        fragment -= (1 << stagger);
    }

    // If the value is negative and there is a skew in wnaf, add 2^{stagger}.
    if (is_negative && wnaf_skew) {
        fragment += (1 << stagger);
    }

    // If the lowest bit is zero, set the final skew to 1 and make the fragment odd:
    // (i) add 1 to the fragment if the scalar is positive
    // (ii) subtract 1 from the fragment if the scalar is negative
    bool output_skew = (fragment_u64 & 1) == 0;
    if (!is_negative && output_skew) {
        fragment += 1;
    } else if (is_negative && output_skew) {
        fragment -= 1;
    }

    // Compute raw wnaf value: w = 2e + 1 => e = (w - 1) / 2 => e = ⌊w / 2⌋
    const int signed_wnaf_value = (fragment / 2);
    constexpr int wnaf_window_size = (1ULL << (wnaf_size - 1));
    uint64_t output_fragment = 0;
    if (fragment < 0) {
        output_fragment = static_cast<uint64_t>(wnaf_window_size + signed_wnaf_value - 1);
    } else {
        output_fragment = static_cast<uint64_t>(wnaf_window_size + signed_wnaf_value);
    }

    return std::make_pair(output_fragment, output_skew);
}
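
// Illustrative example: with wnaf_size = 4 (window size 2^3 = 8) and stagger = 2, the odd fragment_u64 = 3 = 2 * 1 + 1
// is encoded directly as (8 + 1, false) = (9, false). The even fragment_u64 = 2 is first bumped to 3, so (9, true) is
// returned and the skew of 1 is accounted for separately during reconstruction.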

template <typename C, class Fq, class Fr, class G>
template <size_t wnaf_size>
std::vector<field_t<C>> element<C, Fq, Fr, G>::convert_wnaf_values_to_witnesses(
    C* builder, const uint64_t* wnaf_values, bool is_negative, size_t rounds, const bool range_constrain_wnaf)
{
    constexpr uint64_t wnaf_window_size = (1ULL << (wnaf_size - 1));

    std::vector<field_ct> wnaf_entries;
    for (size_t i = 0; i < rounds; ++i) {
        // Predicate == sign of current wnaf value
        const bool predicate = (wnaf_values[i] >> 31U) & 1U;            // sign bit (32nd bit)
        const uint64_t wnaf_magnitude = (wnaf_values[i] & 0x7fffffffU); // 31-bit magnitude

        // If the signs of the current entry and the whole scalar are the same, then add the magnitude of the
        // wnaf value to the window size to form an entry. Otherwise, subtract the magnitude along with 1.
        // The extra 1 is needed to get a uniform representation of (2e' + 1) as explained in the README.
        uint64_t offset_wnaf_entry = 0;
        if ((!predicate && !is_negative) || (predicate && is_negative)) {
            offset_wnaf_entry = wnaf_window_size + wnaf_magnitude;
        } else {
            offset_wnaf_entry = wnaf_window_size - wnaf_magnitude - 1;
        }
        field_ct wnaf_entry(witness_ct(builder, offset_wnaf_entry));

        // In some cases we may want to skip range constraining the wnaf entries. For example, when we use these
        // entries to look up in a ROM or regular table, the lookup implicitly enforces the range constraint.
        if (range_constrain_wnaf) {
            wnaf_entry.create_range_constraint(wnaf_size, "biggroup_nafs: wnaf_entry is not in range");
        }
        wnaf_entries.emplace_back(wnaf_entry);
    }
    return wnaf_entries;
}
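
// Illustrative example: with wnaf_size = 4 the window size is 8; a wnaf slice with magnitude 3 whose sign matches the
// scalar's sign is stored as 8 + 3 = 11, while a slice whose sign differs from the scalar's is stored as 8 - 3 - 1 = 4.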

template <typename C, class Fq, class Fr, class G>
template <size_t wnaf_size>
Fr element<C, Fq, Fr, G>::reconstruct_bigfield_from_wnaf(C* builder,
                                                         const std::vector<field_t<Builder>>& wnaf,
                                                         const bool_ct& positive_skew,
                                                         const bool_ct& negative_skew,
                                                         const field_t<Builder>& stagger_fragment,
                                                         const size_t stagger,
                                                         const size_t rounds)
{
    // Collect positive wnaf entries for accumulation
    std::vector<field_ct> accumulator;
    for (size_t i = 0; i < rounds; ++i) {
        field_ct entry = wnaf[rounds - 1 - i];
        entry *= field_ct(uint256_t(1) << (i * wnaf_size));
        accumulator.emplace_back(entry);
    }

    // Accumulate entries, shift by the stagger and add the stagger fragment itself
    field_ct sum = field_ct::accumulate(accumulator);
    sum = sum * field_ct(bb::fr(1ULL << stagger));
    sum += (stagger_fragment);
    sum = sum.normalize();

    // Convert this value to a bigfield element
    Fr reconstructed_positive_part =
        Fr(sum, field_ct::from_witness_index(builder, builder->zero_idx()), /*can_overflow*/ false);

    // Double the final value and add the positive skew
    reconstructed_positive_part =
        (reconstructed_positive_part + reconstructed_positive_part)
            .add_to_lower_limb(field_t<Builder>(positive_skew), /*other_maximum_value*/ uint256_t(1));

    // Start reconstructing the negative part: start with the wnaf constant 0xff...ff
    // See the README for an explanation of this constant
    constexpr uint64_t wnaf_window_size = (1ULL << (wnaf_size - 1));
    uint256_t negative_constant_wnaf_offset(0);
    for (size_t i = 0; i < rounds; ++i) {
        negative_constant_wnaf_offset += uint256_t((wnaf_window_size * 2) - 1) * (uint256_t(1) << (i * wnaf_size));
    }

    // Shift by the stagger
    negative_constant_wnaf_offset = negative_constant_wnaf_offset << stagger;

    // Account for the stagger fragment (if any)
    if (stagger > 0) {
        negative_constant_wnaf_offset += ((1ULL << wnaf_size) - 1ULL); // from stagger fragment
    }

    // Add the negative skew to the bigfield constant
    Fr reconstructed_negative_part =
        Fr(nullptr, negative_constant_wnaf_offset).add_to_lower_limb(field_t<Builder>(negative_skew), uint256_t(1));

    // output = x_pos - x_neg (x_pos and x_neg are both non-negative)
    Fr reconstructed = reconstructed_positive_part - reconstructed_negative_part;

    return reconstructed;
}
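
// Illustrative summary: writing n = rounds, the value returned above equals
//   2 * (2^{stagger} * Σ_{i=0}^{n-1} wnaf[n - 1 - i] * 2^{i * wnaf_size} + stagger_fragment) + positive_skew
//     - (Σ_{i=0}^{n-1} (2^{wnaf_size} - 1) * 2^{i * wnaf_size + stagger} + fragment_offset + negative_skew),
// where fragment_offset = 2^{wnaf_size} - 1 if stagger > 0 and 0 otherwise. Each offset entry w encodes an odd
// signed digit d = 2e + 1 as w = 2^{wnaf_size - 1} + e, so 2 * w minus the per-round constant 2^{wnaf_size} - 1
// recovers d.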

template <typename C, class Fq, class Fr, class G>
template <size_t num_bits, size_t wnaf_size, size_t lo_stagger, size_t hi_stagger>
    C* builder,
    const secp256k1::fr& scalar,
    size_t stagger,
    bool is_negative,
    const bool range_constrain_wnaf,
    bool is_lo)
{
    // The number of rounds is the minimum required to cover the whole scalar with wnaf_size windows
    constexpr size_t num_rounds = ((num_bits + wnaf_size - 1) / wnaf_size);

    // The stagger mask is needed to retrieve the lowest bits that will not be used in the Montgomery ladder directly
    const uint64_t stagger_mask = (1ULL << stagger) - 1;

    // The stagger scalar represents the lower "staggered" bits that are not used in the ladder
    const uint64_t stagger_scalar = scalar.data[0] & stagger_mask;

    std::array<uint64_t, num_rounds> wnaf_values = { 0 };
    bool skew_without_stagger = false;
    uint256_t k_u256{ scalar.data[0], scalar.data[1], scalar.data[2], scalar.data[3] };
    k_u256 = k_u256 >> stagger;
    if (is_lo) {
        bb::wnaf::fixed_wnaf<num_bits - lo_stagger, 1, wnaf_size>(
            &k_u256.data[0], &wnaf_values[0], skew_without_stagger, 0);
    } else {
        bb::wnaf::fixed_wnaf<num_bits - hi_stagger, 1, wnaf_size>(
            &k_u256.data[0], &wnaf_values[0], skew_without_stagger, 0);
    }

    // Number of rounds needed to reconstruct the scalar without the staggered bits
    const size_t num_rounds_excluding_stagger_bits = ((num_bits + wnaf_size - 1 - stagger) / wnaf_size);

    // Compute the stagger-related fragment and the final skew due to the same
    const auto [first_fragment, skew] =
        get_staggered_wnaf_fragment_value<wnaf_size>(stagger_scalar, stagger, is_negative, skew_without_stagger);

    // Get wnaf witnesses
    // Note that we only range constrain the wnaf entries if range_constrain_wnaf is set to true.
    std::vector<field_ct> wnaf = convert_wnaf_values_to_witnesses<wnaf_size>(
        builder, &wnaf_values[0], is_negative, num_rounds_excluding_stagger_bits, range_constrain_wnaf);

    // Compute and constrain skews
    bool_ct negative_skew(witness_ct(builder, is_negative ? 0 : skew), /*use_range_constraint*/ true);
    bool_ct positive_skew(witness_ct(builder, is_negative ? skew : 0), /*use_range_constraint*/ true);

    // Enforce that positive_skew and negative_skew are not both set at the same time
    bool_ct both_skews_cannot_be_one = !(positive_skew & negative_skew);
    both_skews_cannot_be_one.assert_equal(
        bool_ct(builder, true), "biggroup_nafs: both positive and negative skews cannot be set at the same time");

    // Initialize the stagger witness
    field_ct stagger_fragment = witness_ct(builder, first_fragment);

    // We only range constrain the stagger fragment if range_constrain_wnaf is set. This is because in some cases we
    // may use the stagger fragment to look up in a ROM/regular table, which implicitly enforces the range constraint.
    if (range_constrain_wnaf) {
        stagger_fragment.create_range_constraint(wnaf_size, "biggroup_nafs: stagger fragment is not in range");
    }

    // Reconstruct the bigfield scalar from the (wnaf + stagger) representation
    Fr reconstructed = reconstruct_bigfield_from_wnaf<wnaf_size>(
        builder, wnaf, positive_skew, negative_skew, stagger_fragment, stagger, num_rounds_excluding_stagger_bits);

    secp256k1_wnaf wnaf_out{ .wnaf = wnaf,
                             .positive_skew = positive_skew,
                             .negative_skew = negative_skew,
                             .least_significant_wnaf_fragment = stagger_fragment,
                             .has_wnaf_fragment = (stagger > 0) };

    return std::make_pair(reconstructed, wnaf_out);
}
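
// Illustrative summary: the scalar handled above is split as scalar = 2^{stagger} * k' + fragment, where
// fragment = scalar mod 2^{stagger}. The shifted part k' goes through the native fixed-window wnaf
// (bb::wnaf::fixed_wnaf), the low fragment is re-encoded by get_staggered_wnaf_fragment_value, and
// reconstruct_bigfield_from_wnaf rebuilds the scalar as a bigfield witness, with the positive/negative skew bits
// absorbing the evenness adjustments.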

template <typename C, class Fq, class Fr, class G>
template <size_t wnaf_size, size_t lo_stagger, size_t hi_stagger>
    const Fr& scalar, const bool range_constrain_wnaf)
{
    C* builder = scalar.get_context();

    constexpr size_t num_bits = 129;

    // Decomposes the scalar k into two 129-bit scalars klo, khi such that
    //   k = klo + ζ * khi (mod n)
    //     = klo - λ * khi (mod n)
    // where ζ is the primitive sixth root of unity mod n, and λ is the primitive cube root of unity mod n
    // (note that ζ = -λ). We know that for any scalar k such a decomposition exists, with klo and khi 128 bits long.
    secp256k1::fr k(uint256_t(scalar.get_value() % Fr::modulus_u512));
    secp256k1::fr klo(0);
    secp256k1::fr khi(0);
    bool klo_negative = false;
    bool khi_negative = false;
    secp256k1::fr::split_into_endomorphism_scalars(k, klo, khi);

    // The low and high scalars must be less than 2^129 in absolute value. In some cases, the khi value
    // is returned as negative, in which case we negate it and set a flag to indicate this. This is because
    // we decompose the scalar as:
    //   k = klo + ζ * khi (mod n)
    //     = klo - λ * khi (mod n)
    // where λ is the cube root of unity. If khi is negative, then -λ * khi is positive, and vice versa.
    if (khi.uint256_t_no_montgomery_conversion().get_msb() >= 129) {
        khi_negative = true;
        khi = -khi;
    }

    BB_ASSERT_LT(klo.uint256_t_no_montgomery_conversion().get_msb(), 129ULL, "biggroup_nafs: klo > 129 bits");
    BB_ASSERT_LT(khi.uint256_t_no_montgomery_conversion().get_msb(), 129ULL, "biggroup_nafs: khi > 129 bits");

    const auto [klo_reconstructed, klo_out] =
        builder, klo, lo_stagger, klo_negative, range_constrain_wnaf, true);

    const auto [khi_reconstructed, khi_out] =
        builder, khi, hi_stagger, khi_negative, range_constrain_wnaf, false);

    uint256_t minus_lambda_val(-secp256k1::fr::cube_root_of_unity());
    Fr minus_lambda(bb::fr(minus_lambda_val.slice(0, 136)), bb::fr(minus_lambda_val.slice(136, 256)), false);
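    // Note (assuming the standard 68-bit bigfield limb layout): the 256-bit value of -λ is split at bit 136, so the
    // first constructor argument carries the low two limbs and the second carries the remaining high bits.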

    Fr reconstructed_scalar = khi_reconstructed.madd(minus_lambda, { klo_reconstructed });

    // Constant scalars are always reduced mod n by design (scalar < n); however,
    // the reconstructed_scalar may be larger than n as it's a witness. So we need to
    // reduce the reconstructed_scalar mod n explicitly to match the original scalar.
    // This is necessary for assert_equal to pass.
    if (scalar.is_constant()) {
        reconstructed_scalar.self_reduce();
    }

    // Validate that the reconstructed scalar matches the original scalar in circuit
    scalar.assert_equal(reconstructed_scalar, "biggroup_nafs: reconstructed scalar does not match reduced input");

    return { .klo = klo_out, .khi = khi_out };
}

template <typename C, class Fq, class Fr, class G>
std::vector<bool_t<C>> element<C, Fq, Fr, G>::compute_naf(const Fr& scalar, const size_t max_num_bits)
{
    // Get the circuit builder
    C* builder = scalar.get_context();

    // To compute the NAF representation, we first reduce the scalar modulo r (the scalar field modulus).
    uint512_t scalar_multiplier_512 = uint512_t(scalar.get_value()) % uint512_t(Fr::modulus);
    uint256_t scalar_multiplier = scalar_multiplier_512.lo;

    // The number of rounds is either the max_num_bits provided, or the full size of the scalar field modulus.
    // If the scalar is zero, we use the full size of the scalar field modulus, as we use scalar = r in this case.
    const size_t num_rounds = (max_num_bits == 0 || scalar_multiplier == 0) ? Fr::modulus.get_msb() + 1 : max_num_bits;

    // NAF can't handle 0, so we set scalar = r in this case.
    if (scalar_multiplier == 0) {
        scalar_multiplier = Fr::modulus;
    }

    // NAF representation consists of num_rounds bits and a skew bit.
    // Given a scalar k, we compute the NAF representation as follows:
    //
    //   k = -skew + \sum_{i=0}^{n-1} (1 - 2 * naf_i) * 2^i
    //
    // where naf_i = (1 - k_{i + 1}) ∈ {0, 1} and k_{i + 1} is the (i + 1)-th bit of the scalar k.
    // If naf_i = 0, then the i-th NAF entry is +1, otherwise it is -1. See the README for more details.
    //
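    // Worked example (illustrative), using 3 rounds:
    //   k = 5 (odd):  skew = 0, NAF digits (msb first) are (+1, +1, -1), and 4 + 2 - 1 - 0 = 5.
    //   k = 6 (even): skew = 1, the scalar is bumped to 7, digits are (+1, +1, +1), and 4 + 2 + 1 - 1 = 6.
    //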
    std::vector<bool_ct> naf_entries(num_rounds + 1);

    // If the scalar is even, we set the skew flag to true and add 1 to the scalar.
    // Sidenote: we apply range constraints to the boolean witnesses instead of full 1-bit range gates.
    const bool skew_value = !scalar_multiplier.get_bit(0);
    scalar_multiplier += uint256_t(static_cast<uint64_t>(skew_value));
    naf_entries[num_rounds] = bool_ct(witness_ct(builder, skew_value), /*use_range_constraint*/ true);

    // We need to manually propagate the origin tag
    naf_entries[num_rounds].set_origin_tag(scalar.get_origin_tag());

    for (size_t i = 0; i < num_rounds - 1; ++i) {
        // The current NAF entry is derived from the next bit of the scalar: naf_entry := (1 - next_bit).
        // Apply a basic range constraint per bool, and not a full 1-bit range gate. Results in ~`num_rounds`/4 gates
        // per scalar.
        const bool next_entry = scalar_multiplier.get_bit(i + 1);
        naf_entries[num_rounds - i - 1] = bool_ct(witness_ct(builder, !next_entry), /*use_range_constraint*/ true);

        // We need to manually propagate the origin tag
        naf_entries[num_rounds - i - 1].set_origin_tag(scalar.get_origin_tag());
    }

    // The most significant NAF entry is always (+1) as we are working with scalars < 2^{max_num_bits}.
    // Recall that true represents (-1) and false represents (+1).
    naf_entries[0] = bool_ct(witness_ct(builder, false), /*use_range_constraint*/ true);
    naf_entries[0].set_origin_tag(scalar.get_origin_tag());

    // validate correctness of NAF
    if constexpr (!Fr::is_composite) {
        std::vector<Fr> accumulators;
        for (size_t i = 0; i < num_rounds; ++i) {
            // bit = 1 - 2 * naf
            Fr entry(naf_entries[num_rounds - i - 1]);
            entry *= -2;
            entry += 1;
            entry *= static_cast<Fr>(uint256_t(1) << (i));
            accumulators.emplace_back(entry);
        }
        accumulators.emplace_back(-Fr(naf_entries[num_rounds])); // -skew
        Fr accumulator_result = Fr::accumulate(accumulators);
        scalar.assert_equal(accumulator_result);
    } else {
        const auto reconstruct_half_naf = [](bool_ct* nafs, const size_t half_round_length) {
            field_ct negative_accumulator(0);
            field_ct positive_accumulator(0);
            for (size_t i = 0; i < half_round_length; ++i) {
                negative_accumulator = negative_accumulator + negative_accumulator + field_ct(nafs[i]);
                positive_accumulator = positive_accumulator + positive_accumulator + field_ct(1) - field_ct(nafs[i]);
            }
            return std::make_pair(positive_accumulator, negative_accumulator);
        };
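        // Illustrative example: for NAF digits (+1, +1, -1) (i.e. nafs = {0, 0, 1}) the positive accumulator
        // collects 0b110 = 6 and the negative accumulator collects 0b001 = 1, so 6 - 1 = 5 recovers the scalar
        // (the skew bit is then added to the negative part below).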

        std::pair<field_ct, field_ct> hi_accumulators;
        std::pair<field_ct, field_ct> lo_accumulators;

        if (num_rounds > Fr::NUM_LIMB_BITS * 2) {
            const size_t midpoint = num_rounds - (Fr::NUM_LIMB_BITS * 2);
            hi_accumulators = reconstruct_half_naf(&naf_entries[0], midpoint);
            lo_accumulators = reconstruct_half_naf(&naf_entries[midpoint], num_rounds - midpoint);
        } else {
            // If the number of rounds is ≤ (2 * Fr::NUM_LIMB_BITS), the high bits of the resulting Fr element are 0.
            const field_ct zero = field_ct::from_witness_index(builder, builder->zero_idx());
            lo_accumulators = reconstruct_half_naf(&naf_entries[0], num_rounds);
            hi_accumulators = std::make_pair(zero, zero);
        }

        // Add the skew bit to the low accumulator's negative part
        lo_accumulators.second = lo_accumulators.second + field_ct(naf_entries[num_rounds]);

        Fr reconstructed_positive = Fr(lo_accumulators.first, hi_accumulators.first);
        Fr reconstructed_negative = Fr(lo_accumulators.second, hi_accumulators.second);
        Fr accumulator = reconstructed_positive - reconstructed_negative;

        // Constant scalars are always reduced mod n by design (scalar < n); however,
        // the reconstructed accumulator may be larger than n as it's a witness. So we need to
        // reduce the reconstructed accumulator mod n explicitly to match the original scalar.
        // This is necessary for assert_equal to pass.
        if (scalar.is_constant()) {
            accumulator.self_reduce();
        }

        // Validate that the reconstructed scalar matches the original scalar in circuit
        accumulator.assert_equal(scalar);
    }

    // Propagate tags to naf
    const auto original_tag = scalar.get_origin_tag();
    for (auto& naf_entry : naf_entries) {
        naf_entry.set_origin_tag(original_tag);
    }
    return naf_entries;
}
} // namespace bb::stdlib::element_default