Barretenberg
The ZK-SNARK library at the core of Aztec
Loading...
Searching...
No Matches
keccak.cpp
Go to the documentation of this file.
1// === AUDIT STATUS ===
2// internal: { status: not started, auditors: [], date: YYYY-MM-DD }
3// external_1: { status: not started, auditors: [], date: YYYY-MM-DD }
4// external_2: { status: not started, auditors: [], date: YYYY-MM-DD }
5// =====================
6
7#include "keccak.hpp"
15namespace bb::stdlib {
16
17using namespace bb::plookup;
18
35template <typename Builder>
36template <size_t lane_index>
38{
39 // left_bits = the number of bits that wrap around 11^64 (left_bits)
40 constexpr size_t left_bits = ROTATIONS[lane_index];
41
42 // right_bits = the number of bits that don't wrap
43 constexpr size_t right_bits = 64 - ROTATIONS[lane_index];
44
45 // TODO read from same source as plookup table code
46 constexpr size_t max_bits_per_table = plookup::keccak_tables::Rho<>::MAXIMUM_MULTITABLE_BITS;
47
48 // compute the number of lookups required for our left and right bit slices
49 constexpr size_t num_left_tables = left_bits / max_bits_per_table + (left_bits % max_bits_per_table > 0 ? 1 : 0);
50 constexpr size_t num_right_tables = right_bits / max_bits_per_table + (right_bits % max_bits_per_table > 0 ? 1 : 0);
51
52 // get the numerical value of the left and right bit slices
53 // (lookup table input values derived from left / right)
54 uint256_t input = limb.get_value();
55 constexpr uint256_t slice_divisor = BASE.pow(right_bits);
56 const auto [left, right] = input.divmod(slice_divisor);
57
58 // compute the normalized values for the left and right bit slices
59 // (lookup table output values derived from left_normalised / right_normalized)
60 uint256_t left_normalized = normalize_sparse(left);
61 uint256_t right_normalized = normalize_sparse(right);
62
103
104 // compute plookup witness values for a given slice
105 // (same lambda can be used to compute witnesses for left and right slices)
106 auto compute_lookup_witnesses_for_limb = [&]<size_t limb_bits, size_t num_lookups>(uint256_t& normalized) {
107 // (use a constexpr loop to make some pow and div operations compile-time)
108 bb::constexpr_for<0, num_lookups, 1>([&]<size_t i> {
109 constexpr size_t num_bits_processed = i * max_bits_per_table;
110
111 // How many bits can this slice contain?
112 // We want to implicitly range-constrain `normalized < 11^{limb_bits}`,
113 // which means potentially using a lookup table that is not of size 11^{max_bits_per_table}
114 // for the most-significant slice
115 constexpr size_t bit_slice = (num_bits_processed + max_bits_per_table > limb_bits)
116 ? limb_bits % max_bits_per_table
117 : max_bits_per_table;
118
119 // current column values are tracked via 'input' and 'normalized'
120 lookup[ColumnIdx::C1].push_back(input);
121 lookup[ColumnIdx::C2].push_back(normalized);
122
123 constexpr uint64_t divisor = numeric::pow64(static_cast<uint64_t>(BASE), bit_slice);
124 constexpr uint64_t msb_divisor = divisor / static_cast<uint64_t>(BASE);
125
126 // compute the value of the most significant bit of this slice and store in C3
127 const auto [normalized_quotient, normalized_slice] = normalized.divmod(divisor);
128
129 // 256-bit divisions are expensive! cast to u64s when we don't need the extra bits
130 const uint64_t normalized_msb = (static_cast<uint64_t>(normalized_slice) / msb_divisor);
131 lookup[ColumnIdx::C3].push_back(normalized_msb);
132
133 // We need to provide a key/value object for this lookup in order for the Builder
134 // to compute the plookup sorted list commitment
135 const auto [input_quotient, input_slice] = input.divmod(divisor);
136 lookup.lookup_entries.push_back(
137 { { static_cast<uint64_t>(input_slice), 0 }, { normalized_slice, normalized_msb } });
138
139 // reduce the input and output by 11^{bit_slice}
140 input = input_quotient;
141 normalized = normalized_quotient;
142 });
143 };
144
145 // template lambda syntax is a little funky.
146 // Need to explicitly write `.template operator()` (instead of just `()`).
147 // Otherwise compiler cannot distinguish between `>` symbol referring to closing the template parameter list,
148 // OR `>` being a greater-than operator :/
149 compute_lookup_witnesses_for_limb.template operator()<right_bits, num_right_tables>(right_normalized);
150 compute_lookup_witnesses_for_limb.template operator()<left_bits, num_left_tables>(left_normalized);
151
152 // Call builder method to create plookup constraints.
153 // The MultiTable table index can be derived from `lane_idx`
154 // Each lane_idx has a different rotation amount, which changes sizes of left/right slices
155 // and therefore the selector constants required (i.e. the Q1, Q2, Q3 values in the earlier example)
156 const auto accumulator_witnesses = limb.context->create_gates_from_plookup_accumulators(
157 (plookup::MultiTableId)((size_t)KECCAK_NORMALIZE_AND_ROTATE + lane_index), lookup, limb.get_witness_index());
158
159 // extract the most significant bit of the normalized output from the final lookup entry in column C3
161 accumulator_witnesses[ColumnIdx::C3][num_left_tables + num_right_tables - 1]);
162
163 // Extract the witness that maps to the normalized right slice
164 const field_t<Builder> right_output =
165 field_t<Builder>::from_witness_index(limb.get_context(), accumulator_witnesses[ColumnIdx::C2][0]);
166
167 if (num_left_tables == 0) {
168 // if the left slice size is 0 bits (i.e. no rotation), return `right_output`
169 return right_output;
170 } else {
171 // Extract the normalized left slice
173 limb.get_context(), accumulator_witnesses[ColumnIdx::C2][num_right_tables]);
174
175 // Stitch the right/left slices together to create our rotated output
176 constexpr uint256_t shift = BASE.pow(ROTATIONS[lane_index]);
177 return (left_output + right_output * shift);
178 }
179}
180
197template <typename Builder> void keccak<Builder>::compute_twisted_state(keccak_state& internal)
198{
199 for (size_t i = 0; i < NUM_KECCAK_LANES; ++i) {
200 internal.twisted_state[i] = ((internal.state[i] * 11) + internal.state_msb[i]).normalize();
201 }
202}
203
251template <typename Builder> void keccak<Builder>::theta(keccak_state& internal)
252{
    // THETA step. For each column i: C[i] = column sum of the twisted lanes,
    // D[i] = C[i - 1] XOR rotl(C[i + 1], 1), then every lane of column i is XORed with D[i].
    // "XOR" is carried out as addition of sparse base-11 values; lookups/normalization resolve it later.
255
256 auto& state = internal.state;
257 const auto& twisted_state = internal.twisted_state;
258 for (size_t i = 0; i < 5; ++i) {
259
        // C[i] = sum of the five twisted lanes in column i (lane indices i, 5 + i, 10 + i, 15 + i, 20 + i).
        // Each base-11 digit is at most 5 here, so no carries occur between digits.
268 C[i] = field_ct::accumulate({ twisted_state[i],
269 twisted_state[5 + i],
270 twisted_state[10 + i],
271 twisted_state[15 + i],
272 twisted_state[20 + i] });
273 }
274
    // D[i] = C[i - 1] + 11 * C[i + 1] (indices mod 5).
    // With twisted lanes built as 11 * lane + msb(lane) (compute_twisted_state), base-11 digits 1..64 of D[i]
    // hold C[i - 1] + rotl(C[i + 1], 1) digit-wise (each digit <= 10, still carry-free).
    // Digit 0 and digit 65 are spill-over; the loop below splits them off.
279 for (size_t i = 0; i < 5; ++i) {
280 const auto non_shifted_equivalent = (C[(i + 4) % 5]);
281 const auto shifted_equivalent = C[(i + 1) % 5] * BASE;
282 D[i] = (non_shifted_equivalent + shifted_equivalent);
283 }
284
    // Split each D[i] as D = hi * 11^65 + mid * 11 + lo, where lo and hi are single base-11 digits (the spill-over)
    // and mid (< 11^64) holds the 64 digits of interest. The split is computed natively, witnessed,
    // and then tied back to D[i] with assert_equal.
301 static constexpr uint256_t divisor = BASE.pow(64);
302 static constexpr uint256_t multiplicand = BASE.pow(65);
303 for (size_t i = 0; i < 5; ++i) {
304 uint256_t D_native = D[i].get_value();
305 const auto [D_quotient, lo_native] = D_native.divmod(BASE);
306 const uint256_t hi_native = D_quotient / divisor;
307 const uint256_t mid_native = D_quotient - hi_native * divisor;
308
309 field_ct hi(witness_ct(internal.context, hi_native));
310 field_ct mid(witness_ct(internal.context, mid_native));
311 field_ct lo(witness_ct(internal.context, lo_native));
312
313 // assert equal should cost 1 gate (multipliers are all constants)
314 D[i].assert_equal((hi * multiplicand).add_two(mid * 11, lo));
        // NOTE(review): confirm whether create_new_range_constraint's bound is inclusive of BASE;
        // the intent is that hi and lo are single base-11 digits (0..10).
315 internal.context->create_new_range_constraint(hi.get_witness_index(), static_cast<uint64_t>(BASE));
316 internal.context->create_new_range_constraint(lo.get_witness_index(), static_cast<uint64_t>(BASE));
317
318 // If number of bits in KECCAK_THETA_OUTPUT table does NOT cleanly divide 64,
319 // we need an additional range constraint to ensure that mid < 11^64
        // (that case is the `else` branch below; the `if` branch is the cleanly-dividing case).
320 if constexpr (64 % plookup::keccak_tables::Theta::TABLE_BITS == 0) {
321 // N.B. we could optimize out 5 gates per round here but it's very fiddly...
322 // In previous section, D[i] = X + Y (non shifted equiv and shifted equiv)
323 // We also want to validate D[i] == hi' + mid' + lo (where hi', mid' are hi, mid scaled by constants)
324 // We *could* create a big addition gate to validate the previous logic w. following structure:
325 // | w1 | w2 | w3 | w4 |
326 // | -- | --- | -- | -- |
327 // | hi | mid | lo | X |
328 // | P0 | P1 | P2 | Y |
329 // To save a gate, we would need to place the wires for the first KECCAK_THETA_OUTPUT plookup gate
330 // at P0, P1, P2. This is fiddly builder logic that is circuit-width-dependent
331 // (this would save 120 gates per hash block... not worth making the code less readable for that)
333 } else {
            // C2[0] of the KECCAK_THETA_OUTPUT lookup accumulators is the normalized output for the full input.
335 D[i] = accumulators[ColumnIdx::C2][0];
336
337 // Ensure input to lookup is < 11^64,
338 // by validating most significant input slice is < 11^{64 mod slice_bits}
339 const field_ct most_significant_slice = accumulators[ColumnIdx::C1][accumulators[ColumnIdx::C1].size() - 1];
340
341 // N.B. cheaper to validate (11^{64 mod slice_bits} - slice < 2^14) as this
342 // prevents an extra range table from being created
343 constexpr uint256_t maximum = BASE.pow(64 % plookup::keccak_tables::Theta::TABLE_BITS);
344 const field_ct target = -most_significant_slice + maximum;
            // sanity check: `maximum` itself must be representable in the default plookup range
345 BB_ASSERT_GT((uint256_t(1) << Builder::DEFAULT_PLOOKUP_RANGE_BITNUM) - 1, maximum);
346 target.create_range_constraint(Builder::DEFAULT_PLOOKUP_RANGE_BITNUM,
347 "input to KECCAK_THETA_OUTPUT too large!");
348 }
349 }
350
351 // compute state[j * 5 + i] XOR D[i] in base-11 representation
    // (an un-normalized digit-wise addition; rho() normalizes every lane right after theta in keccakf1600)
352 for (size_t i = 0; i < 5; ++i) {
353 for (size_t j = 0; j < 5; ++j) {
354 state[j * 5 + i] = state[j * 5 + i] + D[i];
355 }
356 }
357}
358
385template <typename Builder> void keccak<Builder>::rho(keccak_state& internal)
386{
387 constexpr_for<0, NUM_KECCAK_LANES, 1>(
388 [&]<size_t i>() { internal.state[i] = normalize_and_rotate<i>(internal.state[i], internal.state_msb[i]); });
389}
390
400template <typename Builder> void keccak<Builder>::pi(keccak_state& internal)
401{
    // PI step: permutes lane positions only; lane contents are untouched (no constraints are added).
    // Lanes are indexed as 5 * y + x. Lane (x, y) moves to (y, 2x + 3y mod 5).
403
    // Snapshot the current lanes so the permutation reads from the pre-PI state.
404 for (size_t j = 0; j < 5; ++j) {
405 for (size_t i = 0; i < 5; ++i) {
406 B[j * 5 + i] = internal.state[j * 5 + i];
407 }
408 }
409
410 for (size_t y = 0; y < 5; ++y) {
411 for (size_t x = 0; x < 5; ++x) {
            // (u, v) = [[0, 1], [2, 3]] * (x, y) mod 5, written out in matrix-row form
412 size_t u = (0 * x + 1 * y) % 5;
413 size_t v = (2 * x + 3 * y) % 5;
414
415 internal.state[v * 5 + u] = B[5 * y + x];
416 }
417 }
418}
419
436template <typename Builder> void keccak<Builder>::chi(keccak_state& internal)
437{
438 // (cost = 12 * 25 = 300?)
439 auto& state = internal.state;
440
441 for (size_t y = 0; y < 5; ++y) {
442 std::array<field_ct, 5> lane_outputs;
443 for (size_t x = 0; x < 5; ++x) {
444 const auto A = state[y * 5 + x];
445 const auto B = state[y * 5 + ((x + 1) % 5)];
446 const auto C = state[y * 5 + ((x + 2) % 5)];
447
448 // vv should cost 1 gate
449 lane_outputs[x] = (A + A + CHI_OFFSET).add_two(-B, C);
450 }
451 for (size_t x = 0; x < 5; ++x) {
452 // Normalize lane outputs and assign to internal.state
453 auto accumulators = plookup_read<Builder>::get_lookup_accumulators(KECCAK_CHI_OUTPUT, lane_outputs[x]);
454 internal.state[y * 5 + x] = accumulators[ColumnIdx::C2][0];
455 internal.state_msb[y * 5 + x] = accumulators[ColumnIdx::C3][accumulators[ColumnIdx::C3].size() - 1];
456 }
457 }
458}
459
469template <typename Builder> void keccak<Builder>::iota(keccak_state& internal, size_t round)
470{
471 const field_ct xor_result = internal.state[0] + SPARSE_RC[round];
472
473 // normalize lane value so that we don't overflow our base11 modulus boundary in the next round
474 internal.state[0] = normalize_and_rotate<0>(xor_result, internal.state_msb[0]);
475
476 // No need to add constraints to compute twisted repr if this is the last round
477 if (round != NUM_KECCAK_ROUNDS - 1) {
478 compute_twisted_state(internal);
479 }
480}
481
482template <typename Builder> void keccak<Builder>::keccakf1600(keccak_state& internal)
483{
484 for (size_t i = 0; i < NUM_KECCAK_ROUNDS; ++i) {
485 theta(internal);
486 rho(internal);
487 pi(internal);
488 chi(internal);
489 iota(internal, i);
490 }
491}
492
493// Returns the Keccak-f[1600] permutation of the input state.
494// We first convert the state into the 'extended' representation, along with the 'twisted' state,
495// and then call keccakf1600() on this Keccak 'internal state'.
496// Finally, we convert the state back from the extended representation.
497template <typename Builder>
499 std::array<field_t<Builder>, NUM_KECCAK_LANES> state, Builder* ctx)
500{
501 std::vector<field_t<Builder>> converted_buffer(NUM_KECCAK_LANES);
502 std::vector<field_t<Builder>> msb_buffer(NUM_KECCAK_LANES);
503 // populate keccak_state, convert our 64-bit lanes into an extended base-11 representation
504 keccak_state internal;
505 internal.context = ctx;
506 for (size_t i = 0; i < state.size(); ++i) {
507 const auto accumulators = plookup_read<Builder>::get_lookup_accumulators(KECCAK_FORMAT_INPUT, state[i]);
508 internal.state[i] = accumulators[ColumnIdx::C2][0];
509 internal.state_msb[i] = accumulators[ColumnIdx::C3][accumulators[ColumnIdx::C3].size() - 1];
510 }
511 compute_twisted_state(internal);
512 keccakf1600(internal);
513 // we convert back to the normal lanes
514 return extended_2_normal(internal);
515}
516
517// Convert the 'extended' representation of the internal Keccak state into the usual array of 64-bit lanes
518template <typename Builder>
520 keccak_state& internal)
521{
522 std::array<field_t<Builder>, NUM_KECCAK_LANES> conversion;
523
524 // Each hash limb represents a little-endian integer. Need to reverse bytes before we write into the output array
525 for (size_t i = 0; i < internal.state.size(); ++i) {
527 conversion[i] = output_limb;
528 }
529
530 return conversion;
531}
532
// Explicit instantiation: emits the keccak<Builder> member definitions in this translation unit for MegaCircuitBuilder.
534template class keccak<bb::MegaCircuitBuilder>;
535
536} // namespace bb::stdlib
#define BB_ASSERT_GT(left, right,...)
Definition assert.hpp:107
constexpr uint256_t pow(const uint256_t &exponent) const
constexpr std::pair< uint256_t, uint256_t > divmod(const uint256_t &b) const
Container type for lookup table reads.
Definition types.hpp:341
std::vector< BasicTable::LookupEntry > lookup_entries
Definition types.hpp:347
Generate the plookup tables used for the RHO round of the Keccak hash algorithm.
static constexpr size_t TABLE_BITS
static field_t from_witness_index(Builder *ctx, uint32_t witness_index)
Definition field.cpp:62
static field_t accumulate(const std::vector< field_t > &input)
Efficiently compute the sum of vector entries. Using big_add_gate we reduce the number of gates neede...
Definition field.cpp:1167
void create_range_constraint(size_t num_bits, std::string const &msg="field_t::range_constraint") const
Let x = *this.normalize(), constrain x.v < 2^{num_bits}.
Definition field.cpp:909
Builder * context
Definition field.hpp:56
Builder * get_context() const
Definition field.hpp:419
bb::fr get_value() const
Given a := *this, compute its value given by a.v * a.mul + a.add.
Definition field.cpp:828
uint32_t get_witness_index() const
Get the witness index of the current field element.
Definition field.hpp:506
KECCAAAAAAAAAAK.
Definition keccak.hpp:25
static void rho(keccak_state &state)
RHO round.
Definition keccak.cpp:385
static void pi(keccak_state &state)
PI.
Definition keccak.cpp:400
static void theta(keccak_state &state)
THETA round.
Definition keccak.cpp:251
static void compute_twisted_state(keccak_state &internal)
Compute twisted representation of hash lane.
Definition keccak.cpp:197
static void chi(keccak_state &state)
CHI.
Definition keccak.cpp:436
static field_t< Builder > normalize_and_rotate(const field_ct &limb, field_ct &msb)
Normalize a base-11 limb and left-rotate by keccak::ROTATIONS[lane_index] bits. This method also extr...
Definition keccak.cpp:37
static std::array< field_ct, NUM_KECCAK_LANES > permutation_opcode(std::array< field_ct, NUM_KECCAK_LANES > state, Builder *context)
Definition keccak.cpp:498
static std::array< field_ct, NUM_KECCAK_LANES > extended_2_normal(keccak_state &internal)
Definition keccak.cpp:519
static void keccakf1600(keccak_state &state)
Definition keccak.cpp:482
static void iota(keccak_state &state, size_t round)
IOTA.
Definition keccak.cpp:469
bb::avm2::Column C
bn254::witness_ct witness_ct
constexpr uint64_t pow64(const uint64_t input, const uint64_t exponent)
Definition pow.hpp:13
@ KECCAK_FORMAT_INPUT
Definition types.hpp:119
@ KECCAK_FORMAT_OUTPUT
Definition types.hpp:120
@ KECCAK_NORMALIZE_AND_ROTATE
Definition types.hpp:121
@ KECCAK_CHI_OUTPUT
Definition types.hpp:118
@ KECCAK_THETA_OUTPUT
Definition types.hpp:117
constexpr decltype(auto) get(::tuplet::tuple< T... > &&t) noexcept
Definition tuple.hpp:13
std::array< field_ct, NUM_KECCAK_LANES > state
Definition keccak.hpp:147
std::array< field_ct, NUM_KECCAK_LANES > twisted_state
Definition keccak.hpp:149
std::array< field_ct, NUM_KECCAK_LANES > state_msb
Definition keccak.hpp:148