1use std::{array, iter};
8
9use anyhow::Result;
10use array_util::ArrayExt;
11use binius_core::oracle::ShiftVariant;
12use binius_field::{
13 AESTowerField8b, ExtensionField, PackedExtension, PackedField, PackedFieldIndexable,
14 PackedSubfield, TowerField, ext_basis,
15 linear_transformation::{
16 FieldLinearTransformation, PackedTransformationFactory, Transformation,
17 },
18 packed::{get_packed_slice, len_packed_slice, set_packed_slice},
19};
20
21use crate::builder::{B1, B8, B128, Col, Expr, TableBuilder, TableWitnessSegment, upcast_col};
22
/// First row of the circulant Grøstl MixBytes matrix, given as AES-basis byte values.
///
/// Row `j` of the matrix is this vector rotated right by `j` positions, i.e. entry `(j, i)` is
/// `MIX_BYTES_VEC[(i - j) mod 8]`.
const MIX_BYTES_VEC: [u8; 8] = [2, 2, 3, 4, 5, 3, 5, 7];
25
/// Linear part of the S-box affine map, expressed in the binary tower basis of [`B8`] instead of
/// the AES polynomial basis.
///
/// `S_BOX_TOWER_MATRIX.transform(inv) + S_BOX_TOWER_OFFSET` agrees with the AES S-box under the
/// AES-to-tower field isomorphism (checked by `test_isomorphic_sbox`).
const S_BOX_TOWER_MATRIX: FieldLinearTransformation<B8> =
    FieldLinearTransformation::new_const(&S_BOX_TOWER_MATRIX_COLS);

/// Columns of [`S_BOX_TOWER_MATRIX`]: entry `i` is the image of the `i`-th `B8`-over-`B1` basis
/// element. `SBox::new` relies on this to express the map as `sum_i inv_bit_i * COLS[i]`.
const S_BOX_TOWER_MATRIX_COLS: [B8; 8] = [
    B8::new(0x62),
    B8::new(0xd2),
    B8::new(0x79),
    B8::new(0x41),
    B8::new(0xf4),
    B8::new(0xd5),
    B8::new(0x81),
    B8::new(0x4e),
];

/// Additive constant of the S-box affine map in the tower basis. It plays the role of the AES
/// S-box constant `0x63` after mapping into the tower field.
const S_BOX_TOWER_OFFSET: B8 = B8::new(0x14);
45
/// A Grøstl `P` or `Q` permutation gadget, built as a chain of 10 rounds.
///
/// Each round's output columns are the next round's input columns. The permutation's input is
/// therefore the first round's `state_in`, and its output is the last round's `state_out`.
#[derive(Debug, Clone)]
pub struct Permutation {
    /// The rounds in application order.
    rounds: [PermutationRound; 10],
}
58
59impl Permutation {
60 pub fn new(
61 table: &mut TableBuilder,
62 pq: PermutationVariant,
63 mut state_in: [Col<B8, 8>; 8],
64 ) -> Self {
65 let rounds = array::from_fn(|i| {
66 let round = PermutationRound::new(
67 &mut table.with_namespace(format!("round[{i}]")),
68 pq,
69 state_in,
70 i,
71 );
72 state_in = round.state_out;
73 round
74 });
75 Self { rounds }
76 }
77
78 pub fn state_in(&self) -> [Col<B8, 8>; 8] {
80 self.rounds[0].state_in
81 }
82
83 pub fn state_out(&self) -> [Col<B8, 8>; 8] {
85 self.rounds[9].state_out
86 }
87
88 pub fn populate<P>(&self, index: &mut TableWitnessSegment<P>) -> Result<()>
89 where
90 P: PackedFieldIndexable<Scalar = B128> + PackedExtension<B1> + PackedExtension<B8>,
91 PackedSubfield<P, B8>: PackedTransformationFactory<PackedSubfield<P, B8>>,
92 {
93 for round in &self.rounds {
94 round.populate(index)?;
95 }
96 Ok(())
97 }
98
99 pub fn populate_state_in<'a, P>(
101 &self,
102 index: &mut TableWitnessSegment<P>,
103 states: impl IntoIterator<Item = &'a [B8; 64]>,
104 ) -> Result<()>
105 where
106 P: PackedExtension<B8>,
107 P::Scalar: TowerField,
108 {
109 let mut state_in = self
110 .state_in()
111 .try_map_ext(|state_in_i| index.get_mut(state_in_i))?;
112 for (k, state_k) in states.into_iter().enumerate() {
113 for (i, state_in_i) in state_in.iter_mut().enumerate() {
114 for j in 0..8 {
115 set_packed_slice(state_in_i, k * 8 + j, state_k[j * 8 + i]);
116 }
117 }
118 }
119 Ok(())
120 }
121
122 pub fn read_state_outs<'a, P>(
126 &'a self,
127 index: &'a mut TableWitnessSegment<'a, P>,
128 ) -> Result<impl Iterator<Item = [B8; 64]> + 'a>
129 where
130 P: PackedExtension<B8>,
131 P::Scalar: TowerField,
132 {
133 let state_out = self
134 .state_out()
135 .try_map_ext(|state_out_i| index.get(state_out_i))?;
136 let iter = (0..index.log_size()).map(move |k| {
137 array::from_fn(|ij| {
138 let i = ij % 8;
139 let j = ij / 8;
140 get_packed_slice(&state_out[i], k * 8 + j)
141 })
142 });
143 Ok(iter)
144 }
145}
146
/// Selects which of the two Grøstl permutations, `P` or `Q`, a gadget implements.
///
/// The variants differ in how round constants are injected (see `PermutationRound::new`) and
/// in the ShiftBytes offsets (see `PermutationVariant::shift_bytes_offset`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, derive_more::Display)]
pub enum PermutationVariant {
    /// The `P` permutation.
    P,
    /// The `Q` permutation.
    Q,
}
152
153impl PermutationVariant {
154 fn shift_bytes_offset(self, i: usize) -> usize {
159 const P_SHIFTS: [usize; 8] = [0, 1, 2, 3, 4, 5, 6, 7];
160 const Q_SHIFTS: [usize; 8] = [1, 3, 5, 7, 0, 2, 4, 6];
161 let right_shift = match self {
162 PermutationVariant::P => P_SHIFTS[i],
163 PermutationVariant::Q => Q_SHIFTS[i],
164 };
165 (8 - right_shift) % 8
167 }
168}
169
170fn round_consts(round: usize) -> [B8; 8] {
171 array::from_fn(|i| {
172 let val = (i * 0x10) ^ round;
173 B8::from(AESTowerField8b::new(val as u8))
174 })
175}
176
/// One round of a Grøstl permutation: AddRoundConstant, SubBytes, ShiftBytes, and MixBytes.
#[derive(Debug, Clone)]
struct PermutationRound {
    /// Which permutation (`P` or `Q`) this round belongs to.
    pq: PermutationVariant,
    /// Round index; selects the round constants via `round_consts`.
    round: usize,
    /// Input state columns, 8 bytes per table row each.
    pub state_in: [Col<B8, 8>; 8],
    /// Constant column holding `round_consts(round)`, repeated on every table row.
    round_const: Col<B8, 8>,
    /// One S-box gadget per state column. Its input already includes the round constants.
    sbox: [SBox<8>; 8],
    /// ShiftBytes output columns. An entry aliases the S-box output column when its shift
    /// offset is zero.
    shift: [Col<B8, 8>; 8],
    /// MixBytes output columns, i.e. the round's output state.
    pub state_out: [Col<B8, 8>; 8],
}
191
192impl PermutationRound {
193 pub fn new(
194 table: &mut TableBuilder,
195 pq: PermutationVariant,
196 state_in: [Col<B8, 8>; 8],
197 round: usize,
198 ) -> Self {
199 let round_const = table.add_constant("RoundConstant", round_consts(round));
200
201 let sbox = array::from_fn(|i| {
203 let sbox_in = match (i, pq) {
204 (0, PermutationVariant::P) => state_in[0] + round_const,
205 (_, PermutationVariant::P) => state_in[i].into(),
206 (7, PermutationVariant::Q) => {
207 state_in[7] + round_const + B8::from(AESTowerField8b::new(0xFF))
208 }
209 (_, PermutationVariant::Q) => state_in[i] + B8::from(AESTowerField8b::new(0xFF)),
210 };
211 SBox::new(&mut table.with_namespace(format!("SubBytes[{i}]")), sbox_in)
212 });
213
214 let shift = array::from_fn(|i| {
216 let offset = pq.shift_bytes_offset(i);
217 if offset == 0 {
218 sbox[i].output
219 } else {
220 table.add_shifted(
221 format!("ShiftBytes[{i}]"),
222 sbox[i].output,
223 3,
224 offset,
225 ShiftVariant::CircularLeft,
226 )
227 }
228 });
229
230 let mix_bytes_scalars = MIX_BYTES_VEC.map(|byte| B8::from(AESTowerField8b::new(byte)));
232 let state_out = array::from_fn(|j| {
233 let mix_bytes: [_; 8] =
234 array::from_fn(|i| shift[i] * mix_bytes_scalars[(8 + i - j) % 8]);
235 table.add_computed(
236 format!("MixBytes[{j}]"),
237 mix_bytes
238 .into_iter()
239 .reduce(|a, b| a + b)
240 .expect("mix_bytes has length 8"),
241 )
242 });
243
244 Self {
245 pq,
246 round,
247 state_in,
248 round_const,
249 sbox,
250 shift,
251 state_out,
252 }
253 }
254
255 pub fn populate<P>(&self, index: &mut TableWitnessSegment<P>) -> Result<()>
256 where
257 P: PackedFieldIndexable<Scalar = B128> + PackedExtension<B1> + PackedExtension<B8>,
258 PackedSubfield<P, B8>: PackedTransformationFactory<PackedSubfield<P, B8>>,
259 {
260 {
261 let mut round_const = index.get_mut(self.round_const)?;
262 let round_consts = round_consts(self.round);
263 for k in 0..len_packed_slice(&round_const) {
264 set_packed_slice(&mut round_const, k, round_consts[k % 8]);
265 }
266 }
267
268 for sbox in &self.sbox {
270 sbox.populate(index)?;
271 }
272
273 for (i, (sbox, shift)) in iter::zip(&self.sbox, self.shift).enumerate() {
275 if sbox.output == shift {
276 continue;
277 }
278
279 let sbox_out = index.get_as::<u64, _, 8>(sbox.output)?;
280 let mut shift = index.get_mut_as::<u64, _, 8>(shift)?;
281
282 let offset = self.pq.shift_bytes_offset(i);
285 for (sbox_out_j, shift_j) in iter::zip(&*sbox_out, &mut *shift) {
286 *shift_j = sbox_out_j.rotate_left((offset * 8) as u32);
287 }
288 }
289
290 let mix_bytes_scalars = MIX_BYTES_VEC.map(|byte| B8::from(AESTowerField8b::new(byte)));
294 let shift: [_; 8] = array_util::try_from_fn(|i| index.get(self.shift[i]))?;
295 for j in 0..8 {
296 let mut mix_bytes_out = index.get_mut(self.state_out[j])?;
297 for (k, mix_bytes_out_k) in mix_bytes_out.iter_mut().enumerate() {
298 *mix_bytes_out_k = (0..8)
299 .map(|i| shift[i][k] * mix_bytes_scalars[(8 + i - j) % 8])
300 .sum();
301 }
302 }
303
304 Ok(())
305 }
306}
307
/// Constraint gadget computing the (AES/Grøstl) S-box on a column with `V` bytes per row.
///
/// The S-box is a field inversion (mapping 0 to 0) followed by an affine map over GF(2). The
/// inverse is committed bit by bit, so the affine map becomes a linear combination of those
/// bits plus a constant.
#[derive(Debug, Clone)]
struct SBox<const V: usize> {
    /// The S-box input expression.
    input: Expr<B8, V>,
    /// Committed bits of `inv`, with respect to the `B8`-over-`B1` basis.
    inv_bits: [Col<B1, V>; 8],
    /// `input^{-1}`, or zero where `input` is zero. Computed by packing `inv_bits`.
    inv: Col<B8, V>,
    /// The S-box output: `S_BOX_TOWER_MATRIX * inv + S_BOX_TOWER_OFFSET`.
    pub output: Col<B8, V>,
}
325
326impl<const V: usize> SBox<V> {
327 pub fn new(table: &mut TableBuilder, input: Expr<B8, V>) -> Self {
328 let inv_bits = array::from_fn(|i| table.add_committed(format!("inv_bits[{i}]")));
329 let inv = table.add_computed("inv", pack_b8(inv_bits));
330
331 table.assert_zero("inv_valid_or_inv_zero", input.clone() * Expr::from(inv).pow(2) - inv);
333 table.assert_zero("inv_valid_or_input_zero", input.clone().pow(2) * inv - input.clone());
335
336 let linear_transform_expr = iter::zip(inv_bits, S_BOX_TOWER_MATRIX_COLS)
338 .map(|(inv_bit_i, scalar)| upcast_col(inv_bit_i) * scalar)
339 .reduce(|a, b| a + b)
340 .expect("inv_bits and S_BOX_TOWER_MATRIX_COLS have length 8");
341 let output =
342 table.add_computed("output", linear_transform_expr.clone() + S_BOX_TOWER_OFFSET);
343
344 Self {
345 input,
346 inv_bits,
347 inv,
348 output,
349 }
350 }
351
352 pub fn populate<P>(&self, index: &mut TableWitnessSegment<P>) -> Result<()>
353 where
354 P: PackedField<Scalar = B128> + PackedExtension<B1> + PackedExtension<B8>,
355 PackedSubfield<P, B8>: PackedTransformationFactory<PackedSubfield<P, B8>>,
356 {
357 let mut inv = index.get_mut(self.inv)?;
358
359 for (inv_i, val_i) in iter::zip(&mut *inv, index.eval_expr(&self.input)?) {
361 *inv_i = val_i.invert_or_zero();
362 }
363
364 let mut inv_bits = self
366 .inv_bits
367 .try_map_ext(|inv_bits_i| index.get_mut(inv_bits_i))?;
368 for i in 0..index.size() * V {
369 let inv_val = get_packed_slice(&inv, i);
370 for (j, inv_bit_j) in ExtensionField::<B1>::iter_bases(&inv_val).enumerate() {
371 set_packed_slice(&mut inv_bits[j], i, inv_bit_j);
372 }
373 }
374
375 let mut output = index.get_mut(self.output)?;
377
378 let transform_matrix =
379 <PackedSubfield<P, B8>>::make_packed_transformation(S_BOX_TOWER_MATRIX);
380 let transform_offset = <PackedSubfield<P, B8>>::broadcast(S_BOX_TOWER_OFFSET);
381 for (out_i, inv_i) in iter::zip(&mut *output, &*inv) {
382 *out_i = transform_offset + transform_matrix.transform(inv_i);
383 }
384
385 Ok(())
386 }
387}
388
389fn pack_b8<const V: usize>(bits: [Col<B1, V>; 8]) -> Expr<B8, V> {
390 let b8_basis: [_; 8] = array::from_fn(ext_basis::<B8, B1>);
391 bits.into_iter()
392 .enumerate()
393 .map(|(i, bit)| upcast_col(bit) * b8_basis[i])
394 .reduce(|a, b| a + b)
395 .expect("bits has length 8")
396}
397
#[cfg(test)]
mod tests {
    use std::iter::repeat_with;

    use binius_compute::cpu::alloc::CpuComputeAllocator;
    use binius_field::{
        arch::OptimalUnderlier128b, arithmetic_traits::InvertOrZero, as_packed_field::PackedType,
    };
    use binius_hash::groestl::{GroestlShortImpl, GroestlShortInternal};
    use rand::{SeedableRng, prelude::StdRng};

    use super::*;
    use crate::builder::{ConstraintSystem, WitnessIndex};

    #[test]
    fn test_sbox() {
        let mut cs = ConstraintSystem::new();
        let mut table = cs.add_table("sbox test");

        let input = table.add_committed::<B8, 2>("input");
        let sbox = SBox::new(&mut table, input + B8::new(0xFF));

        let table_id = table.id();

        let mut allocator = CpuComputeAllocator::new(1 << 12);
        let allocator = allocator.into_bump_allocator();

        let mut witness =
            WitnessIndex::<PackedType<OptimalUnderlier128b, B128>>::new(&cs, &allocator);

        let table_witness = witness.init_table(table_id, 1 << 8).unwrap();

        let mut rng = StdRng::seed_from_u64(0);
        let mut segment = table_witness.full_segment();
        for in_i in &mut *segment.get_mut(input).unwrap() {
            *in_i = PackedField::random(&mut rng);
        }

        sbox.populate(&mut segment).unwrap();

        let ccs = cs.compile().unwrap();
        let table_sizes = witness.table_sizes();
        let witness = witness.into_multilinear_extension_index();

        binius_core::constraint_system::validate::validate_witness(
            &ccs,
            &[],
            &table_sizes,
            &witness,
        )
        .unwrap();
    }

    /// Builds a table containing a single `pq` permutation gadget and fills it with 256 random
    /// input states. Checks every output state against the reference Grøstl implementation,
    /// then validates the witness against the compiled constraints.
    fn check_permutation(pq: PermutationVariant) {
        let table_name = match pq {
            PermutationVariant::P => "P-permutation test",
            PermutationVariant::Q => "Q-permutation test",
        };

        let mut cs = ConstraintSystem::new();
        let mut table = cs.add_table(table_name);

        let input = table.add_committed_multiple::<B8, 8, 8>("state_in");
        let perm = Permutation::new(&mut table, pq, input);

        let table_id = table.id();

        let mut allocator = CpuComputeAllocator::new(1 << 16);
        let allocator = allocator.into_bump_allocator();

        let mut witness =
            WitnessIndex::<PackedType<OptimalUnderlier128b, B128>>::new(&cs, &allocator);

        let table_witness = witness.init_table(table_id, 1 << 8).unwrap();

        let mut rng = StdRng::seed_from_u64(0);
        let in_states = repeat_with(|| array::from_fn::<_, 64, _>(|_| B8::random(&mut rng)))
            .take(1 << 8)
            .collect::<Vec<_>>();
        let out_states = in_states
            .iter()
            .map(|in_state| {
                let in_state_bytes = in_state.map(|b8| AESTowerField8b::from(b8).val());
                let mut state = GroestlShortImpl::state_from_bytes(&in_state_bytes);
                match pq {
                    PermutationVariant::P => {
                        GroestlShortImpl::p_perm(&mut state);
                    }
                    PermutationVariant::Q => {
                        GroestlShortImpl::q_perm(&mut state);
                    }
                }
                let out_state_bytes = GroestlShortImpl::state_to_bytes(&state);
                out_state_bytes.map(|byte| B8::from(AESTowerField8b::new(byte)))
            })
            .collect::<Vec<_>>();

        let mut segment = table_witness.full_segment();
        perm.populate_state_in(&mut segment, in_states.iter())
            .unwrap();
        perm.populate(&mut segment).unwrap();

        // Collect rather than zip: `iter::zip` stops at the shorter iterator, so a truncated
        // output iterator would otherwise pass unnoticed.
        let generated_states = perm
            .read_state_outs(&mut segment)
            .unwrap()
            .collect::<Vec<_>>();
        assert_eq!(generated_states.len(), out_states.len());
        for (k, (generated_out, expected_out)) in
            iter::zip(&generated_states, &out_states).enumerate()
        {
            assert_eq!(generated_out, expected_out, "output state {k} differs from reference");
        }

        let ccs = cs.compile().unwrap();
        let table_sizes = witness.table_sizes();
        let witness = witness.into_multilinear_extension_index();

        binius_core::constraint_system::validate::validate_witness(
            &ccs,
            &[],
            &table_sizes,
            &witness,
        )
        .unwrap();
    }

    #[test]
    fn test_p_permutation() {
        check_permutation(PermutationVariant::P);
    }

    #[test]
    fn test_q_permutation() {
        check_permutation(PermutationVariant::Q);
    }

    #[test]
    fn test_isomorphic_sbox() {
        #[rustfmt::skip]
        const S_BOX: [u8; 256] = [
            0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76,
            0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0,
            0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15,
            0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75,
            0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84,
            0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf,
            0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8,
            0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2,
            0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73,
            0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb,
            0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79,
            0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08,
            0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a,
            0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e,
            0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf,
            0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16,
        ];

        const S_BOX_MATRIX: FieldLinearTransformation<AESTowerField8b> =
            FieldLinearTransformation::new_const(&[
                AESTowerField8b::new(0x1F),
                AESTowerField8b::new(0x3E),
                AESTowerField8b::new(0x7C),
                AESTowerField8b::new(0xF8),
                AESTowerField8b::new(0xF1),
                AESTowerField8b::new(0xE3),
                AESTowerField8b::new(0xC7),
                AESTowerField8b::new(0x8F),
            ]);
        const S_BOX_OFFSET: AESTowerField8b = AESTowerField8b::new(0x63);

        for i in 0u8..=255u8 {
            let sbox_in = AESTowerField8b::new(i);
            let expected_sbox_out = AESTowerField8b::new(S_BOX[i as usize]);

            let sbox_out =
                S_BOX_MATRIX.transform(&InvertOrZero::invert_or_zero(sbox_in)) + S_BOX_OFFSET;
            assert_eq!(sbox_out, expected_sbox_out);

            let sbox_in_b8 = B8::from(sbox_in);
            let sbox_out_b8 = S_BOX_TOWER_MATRIX
                .transform(&InvertOrZero::invert_or_zero(sbox_in_b8))
                + S_BOX_TOWER_OFFSET;
            assert_eq!(AESTowerField8b::from(sbox_out_b8), expected_sbox_out);
        }
    }
}