1use std::{array, iter};
8
9use anyhow::Result;
10use array_util::ArrayExt;
11use binius_core::oracle::ShiftVariant;
12use binius_field::{
13 ext_basis,
14 linear_transformation::{
15 FieldLinearTransformation, PackedTransformationFactory, Transformation,
16 },
17 packed::{get_packed_slice, len_packed_slice, set_packed_slice},
18 AESTowerField8b, ExtensionField, PackedExtension, PackedField, PackedFieldIndexable,
19 PackedSubfield, TowerField,
20};
21
22use crate::builder::{upcast_col, Col, Expr, TableBuilder, TableWitnessSegment, B1, B128, B8};
23
/// Coefficients (as AES-representation bytes) of the circulant MixBytes matrix.
///
/// Output column `j` of MixBytes is `sum_i shift[i] * MIX_BYTES_VEC[(i - j) mod 8]`, with each byte
/// first converted from `AESTowerField8b` to `B8`.
const MIX_BYTES_VEC: [u8; 8] = [2, 2, 3, 4, 5, 3, 5, 7];
26
27const S_BOX_TOWER_MATRIX: FieldLinearTransformation<B8> =
30 FieldLinearTransformation::new_const(&S_BOX_TOWER_MATRIX_COLS);
31
32const S_BOX_TOWER_MATRIX_COLS: [B8; 8] = [
33 B8::new(0x62),
34 B8::new(0xd2),
35 B8::new(0x79),
36 B8::new(0x41),
37 B8::new(0xf4),
38 B8::new(0xd5),
39 B8::new(0x81),
40 B8::new(0x4e),
41];
42
43const S_BOX_TOWER_OFFSET: B8 = B8::new(0x14);
46
47#[derive(Debug, Clone)]
56pub struct Permutation {
57 rounds: [PermutationRound; 10],
58}
59
60impl Permutation {
61 pub fn new(
62 table: &mut TableBuilder,
63 pq: PermutationVariant,
64 mut state_in: [Col<B8, 8>; 8],
65 ) -> Self {
66 let rounds = array::from_fn(|i| {
67 let round = PermutationRound::new(
68 &mut table.with_namespace(format!("round[{}]", i)),
69 pq,
70 state_in,
71 i,
72 );
73 state_in = round.state_out;
74 round
75 });
76 Self { rounds }
77 }
78
79 pub fn state_in(&self) -> [Col<B8, 8>; 8] {
81 self.rounds[0].state_in
82 }
83
84 pub fn state_out(&self) -> [Col<B8, 8>; 8] {
86 self.rounds[9].state_out
87 }
88
89 pub fn populate<P>(&self, index: &mut TableWitnessSegment<P>) -> Result<()>
90 where
91 P: PackedFieldIndexable<Scalar = B128> + PackedExtension<B1> + PackedExtension<B8>,
92 PackedSubfield<P, B8>: PackedTransformationFactory<PackedSubfield<P, B8>>,
93 {
94 for round in &self.rounds {
95 round.populate(index)?;
96 }
97 Ok(())
98 }
99
100 pub fn populate_state_in<'a, P>(
102 &self,
103 index: &mut TableWitnessSegment<P>,
104 states: impl IntoIterator<Item = &'a [B8; 64]>,
105 ) -> Result<()>
106 where
107 P: PackedExtension<B8>,
108 P::Scalar: TowerField,
109 {
110 let mut state_in = self
111 .state_in()
112 .try_map_ext(|state_in_i| index.get_mut(state_in_i))?;
113 for (k, state_k) in states.into_iter().enumerate() {
114 for (i, state_in_i) in state_in.iter_mut().enumerate() {
115 for j in 0..8 {
116 set_packed_slice(state_in_i, k * 8 + j, state_k[j * 8 + i]);
117 }
118 }
119 }
120 Ok(())
121 }
122
123 pub fn read_state_outs<'a, P>(
127 &'a self,
128 index: &'a mut TableWitnessSegment<'a, P>,
129 ) -> Result<impl Iterator<Item = [B8; 64]> + 'a>
130 where
131 P: PackedExtension<B8>,
132 P::Scalar: TowerField,
133 {
134 let state_out = self
135 .state_out()
136 .try_map_ext(|state_out_i| index.get(state_out_i))?;
137 let iter = (0..index.log_size()).map(move |k| {
138 array::from_fn(|ij| {
139 let i = ij % 8;
140 let j = ij / 8;
141 get_packed_slice(&state_out[i], k * 8 + j)
142 })
143 });
144 Ok(iter)
145 }
146}
147
148#[derive(Debug, Clone, Copy, PartialEq, Eq, derive_more::Display)]
149pub enum PermutationVariant {
150 P,
151 Q,
152}
153
154impl PermutationVariant {
155 fn shift_bytes_offset(self, i: usize) -> usize {
160 const P_SHIFTS: [usize; 8] = [0, 1, 2, 3, 4, 5, 6, 7];
161 const Q_SHIFTS: [usize; 8] = [1, 3, 5, 7, 0, 2, 4, 6];
162 let right_shift = match self {
163 PermutationVariant::P => P_SHIFTS[i],
164 PermutationVariant::Q => Q_SHIFTS[i],
165 };
166 (8 - right_shift) % 8
168 }
169}
170
171fn round_consts(round: usize) -> [B8; 8] {
172 array::from_fn(|i| {
173 let val = (i * 0x10) ^ round;
174 B8::from(AESTowerField8b::new(val as u8))
175 })
176}
177
178#[derive(Debug, Clone)]
180struct PermutationRound {
181 pq: PermutationVariant,
182 round: usize,
183 pub state_in: [Col<B8, 8>; 8],
185 round_const: Col<B8, 8>,
187 sbox: [SBox<8>; 8],
188 shift: [Col<B8, 8>; 8],
189 pub state_out: [Col<B8, 8>; 8],
191}
192
193impl PermutationRound {
194 pub fn new(
195 table: &mut TableBuilder,
196 pq: PermutationVariant,
197 state_in: [Col<B8, 8>; 8],
198 round: usize,
199 ) -> Self {
200 let round_const = table.add_constant("RoundConstant", round_consts(round));
201
202 let sbox = array::from_fn(|i| {
204 let sbox_in = match (i, pq) {
205 (0, PermutationVariant::P) => state_in[0] + round_const,
206 (_, PermutationVariant::P) => state_in[i].into(),
207 (7, PermutationVariant::Q) => {
208 state_in[7] + round_const + B8::from(AESTowerField8b::new(0xFF))
209 }
210 (_, PermutationVariant::Q) => state_in[i] + B8::from(AESTowerField8b::new(0xFF)),
211 };
212 SBox::new(&mut table.with_namespace(format!("SubBytes[{i}]")), sbox_in)
213 });
214
215 let shift = array::from_fn(|i| {
217 let offset = pq.shift_bytes_offset(i);
218 if offset == 0 {
219 sbox[i].output
220 } else {
221 table.add_shifted(
222 format!("ShiftBytes[{i}]"),
223 sbox[i].output,
224 3,
225 offset,
226 ShiftVariant::CircularLeft,
227 )
228 }
229 });
230
231 let mix_bytes_scalars = MIX_BYTES_VEC.map(|byte| B8::from(AESTowerField8b::new(byte)));
233 let state_out = array::from_fn(|j| {
234 let mix_bytes: [_; 8] =
235 array::from_fn(|i| shift[i] * mix_bytes_scalars[(8 + i - j) % 8]);
236 table.add_computed(
237 format!("MixBytes[{j}]"),
238 mix_bytes
239 .into_iter()
240 .reduce(|a, b| a + b)
241 .expect("mix_bytes has length 8"),
242 )
243 });
244
245 Self {
246 pq,
247 round,
248 state_in,
249 round_const,
250 sbox,
251 shift,
252 state_out,
253 }
254 }
255
256 pub fn populate<P>(&self, index: &mut TableWitnessSegment<P>) -> Result<()>
257 where
258 P: PackedFieldIndexable<Scalar = B128> + PackedExtension<B1> + PackedExtension<B8>,
259 PackedSubfield<P, B8>: PackedTransformationFactory<PackedSubfield<P, B8>>,
260 {
261 {
262 let mut round_const = index.get_mut(self.round_const)?;
263 let round_consts = round_consts(self.round);
264 for k in 0..len_packed_slice(&round_const) {
265 set_packed_slice(&mut round_const, k, round_consts[k % 8]);
266 }
267 }
268
269 for sbox in &self.sbox {
271 sbox.populate(index)?;
272 }
273
274 for (i, (sbox, shift)) in iter::zip(&self.sbox, self.shift).enumerate() {
276 if sbox.output == shift {
277 continue;
278 }
279
280 let sbox_out = index.get_as::<u64, _, 8>(sbox.output)?;
281 let mut shift = index.get_mut_as::<u64, _, 8>(shift)?;
282
283 let offset = self.pq.shift_bytes_offset(i);
286 for (sbox_out_j, shift_j) in iter::zip(&*sbox_out, &mut *shift) {
287 *shift_j = sbox_out_j.rotate_left((offset * 8) as u32);
288 }
289 }
290
291 let mix_bytes_scalars = MIX_BYTES_VEC.map(|byte| B8::from(AESTowerField8b::new(byte)));
295 let shift: [_; 8] = array_util::try_from_fn(|i| index.get(self.shift[i]))?;
296 for j in 0..8 {
297 let mut mix_bytes_out = index.get_mut(self.state_out[j])?;
298 for (k, mix_bytes_out_k) in mix_bytes_out.iter_mut().enumerate() {
299 *mix_bytes_out_k = (0..8)
300 .map(|i| shift[i][k] * mix_bytes_scalars[(8 + i - j) % 8])
301 .sum();
302 }
303 }
304
305 Ok(())
306 }
307}
308
309#[derive(Debug, Clone)]
319struct SBox<const V: usize> {
320 input: Expr<B8, V>,
321 inv_bits: [Col<B1, V>; 8],
323 inv: Col<B8, V>,
324 pub output: Col<B8, V>,
325}
326
327impl<const V: usize> SBox<V> {
328 pub fn new(table: &mut TableBuilder, input: Expr<B8, V>) -> Self {
329 let inv_bits = array::from_fn(|i| table.add_committed(format!("inv_bits[{}]", i)));
330 let inv = table.add_computed("inv", pack_b8(inv_bits));
331
332 table.assert_zero("inv_valid_or_inv_zero", input.clone() * Expr::from(inv).pow(2) - inv);
334 table.assert_zero("inv_valid_or_input_zero", input.clone().pow(2) * inv - input.clone());
336
337 let linear_transform_expr = iter::zip(inv_bits, S_BOX_TOWER_MATRIX_COLS)
339 .map(|(inv_bit_i, scalar)| upcast_col(inv_bit_i) * scalar)
340 .reduce(|a, b| a + b)
341 .expect("inv_bits and S_BOX_TOWER_MATRIX_COLS have length 8");
342 let output =
343 table.add_computed("output", linear_transform_expr.clone() + S_BOX_TOWER_OFFSET);
344
345 Self {
346 input,
347 inv_bits,
348 inv,
349 output,
350 }
351 }
352
353 pub fn populate<P>(&self, index: &mut TableWitnessSegment<P>) -> Result<()>
354 where
355 P: PackedField<Scalar = B128> + PackedExtension<B1> + PackedExtension<B8>,
356 PackedSubfield<P, B8>: PackedTransformationFactory<PackedSubfield<P, B8>>,
357 {
358 let mut inv = index.get_mut(self.inv)?;
359
360 for (inv_i, val_i) in iter::zip(&mut *inv, index.eval_expr(&self.input)?) {
362 *inv_i = val_i.invert_or_zero();
363 }
364
365 let mut inv_bits = self
367 .inv_bits
368 .try_map_ext(|inv_bits_i| index.get_mut(inv_bits_i))?;
369 for i in 0..index.size() * V {
370 let inv_val = get_packed_slice(&inv, i);
371 for (j, inv_bit_j) in ExtensionField::<B1>::iter_bases(&inv_val).enumerate() {
372 set_packed_slice(&mut inv_bits[j], i, inv_bit_j);
373 }
374 }
375
376 let mut output = index.get_mut(self.output)?;
378
379 let transform_matrix =
380 <PackedSubfield<P, B8>>::make_packed_transformation(S_BOX_TOWER_MATRIX);
381 let transform_offset = <PackedSubfield<P, B8>>::broadcast(S_BOX_TOWER_OFFSET);
382 for (out_i, inv_i) in iter::zip(&mut *output, &*inv) {
383 *out_i = transform_offset + transform_matrix.transform(inv_i);
384 }
385
386 Ok(())
387 }
388}
389
390fn pack_b8<const V: usize>(bits: [Col<B1, V>; 8]) -> Expr<B8, V> {
391 let b8_basis: [_; 8] = array::from_fn(ext_basis::<B8, B1>);
392 bits.into_iter()
393 .enumerate()
394 .map(|(i, bit)| upcast_col(bit) * b8_basis[i])
395 .reduce(|a, b| a + b)
396 .expect("bits has length 8")
397}
398
#[cfg(test)]
mod tests {
    use std::iter::repeat_with;

    use binius_field::{
        arch::OptimalUnderlier128b, arithmetic_traits::InvertOrZero, as_packed_field::PackedType,
    };
    use binius_hash::groestl::{GroestlShortImpl, GroestlShortInternal};
    use bumpalo::Bump;
    use rand::{prelude::StdRng, SeedableRng};

    use super::*;
    use crate::builder::{ConstraintSystem, Statement};

    /// Packed witness type shared by every test in this module.
    type TestPacked = PackedType<OptimalUnderlier128b, B128>;

    /// A standalone S-box on random inputs (offset by `0xFF`) yields a valid witness.
    #[test]
    fn test_sbox() {
        let mut cs = ConstraintSystem::new();
        let mut table = cs.add_table("sbox test");

        let input = table.add_committed::<B8, 2>("input");
        let sbox = SBox::new(&mut table, input + B8::new(0xFF));
        let table_id = table.id();

        let allocator = Bump::new();
        let statement = Statement {
            boundaries: vec![],
            table_sizes: vec![1 << 8],
        };
        let mut witness = cs
            .build_witness::<TestPacked>(&allocator, &statement)
            .unwrap();
        let table_witness = witness.get_table(table_id).unwrap();

        let mut rng = StdRng::seed_from_u64(0);
        let mut segment = table_witness.full_segment();
        for in_i in &mut *segment.get_mut(input).unwrap() {
            *in_i = PackedField::random(&mut rng);
        }

        sbox.populate(&mut segment).unwrap();

        let ccs = cs.compile(&statement).unwrap();
        let witness = witness.into_multilinear_extension_index();
        binius_core::constraint_system::validate::validate_witness(&ccs, &[], &witness).unwrap();
    }

    /// Runs the `variant` permutation gadget on random states, compares the generated outputs
    /// against `GroestlShortImpl`, and validates the whole witness.
    fn check_permutation(variant: PermutationVariant, table_name: &'static str) {
        let mut cs = ConstraintSystem::new();
        let mut table = cs.add_table(table_name);

        let input = table.add_committed_multiple::<B8, 8, 8>("state_in");
        let perm = Permutation::new(&mut table, variant, input);
        let table_id = table.id();

        let allocator = Bump::new();
        let statement = Statement {
            boundaries: vec![],
            table_sizes: vec![1 << 8],
        };
        let mut witness = cs
            .build_witness::<TestPacked>(&allocator, &statement)
            .unwrap();
        let table_witness = witness.get_table(table_id).unwrap();

        let mut rng = StdRng::seed_from_u64(0);
        let in_states = repeat_with(|| array::from_fn::<_, 64, _>(|_| B8::random(&mut rng)))
            .take(1 << 8)
            .collect::<Vec<_>>();

        // Reference permutation, computed on AES-representation bytes and mapped back to B8.
        let reference_perm = |in_state: &[B8; 64]| {
            let in_bytes = in_state.map(|b8| AESTowerField8b::from(b8).val());
            let mut state = GroestlShortImpl::state_from_bytes(&in_bytes);
            match variant {
                PermutationVariant::P => {
                    GroestlShortImpl::p_perm(&mut state);
                }
                PermutationVariant::Q => {
                    GroestlShortImpl::q_perm(&mut state);
                }
            }
            let out_bytes = GroestlShortImpl::state_to_bytes(&state);
            out_bytes.map(|byte| B8::from(AESTowerField8b::new(byte)))
        };
        let out_states = in_states.iter().map(reference_perm).collect::<Vec<_>>();

        let mut segment = table_witness.full_segment();
        perm.populate_state_in(&mut segment, in_states.iter())
            .unwrap();
        perm.populate(&mut segment).unwrap();

        for (expected_out, generated_out) in
            iter::zip(out_states, perm.read_state_outs(&mut segment).unwrap())
        {
            assert_eq!(generated_out, expected_out);
        }

        let ccs = cs.compile(&statement).unwrap();
        let witness = witness.into_multilinear_extension_index();
        binius_core::constraint_system::validate::validate_witness(&ccs, &[], &witness).unwrap();
    }

    #[test]
    fn test_p_permutation() {
        check_permutation(PermutationVariant::P, "P-permutation test");
    }

    #[test]
    fn test_q_permutation() {
        check_permutation(PermutationVariant::Q, "Q-permutation test");
    }

    /// The tower-basis S-box constants agree with the AES S-box on all 256 inputs.
    #[test]
    fn test_isomorphic_sbox() {
        #[rustfmt::skip]
        const S_BOX: [u8; 256] = [
            0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76,
            0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0,
            0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15,
            0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75,
            0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84,
            0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf,
            0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8,
            0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2,
            0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73,
            0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb,
            0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79,
            0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08,
            0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a,
            0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e,
            0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf,
            0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16,
        ];

        const S_BOX_MATRIX: FieldLinearTransformation<AESTowerField8b> =
            FieldLinearTransformation::new_const(&[
                AESTowerField8b::new(0x1F),
                AESTowerField8b::new(0x3E),
                AESTowerField8b::new(0x7C),
                AESTowerField8b::new(0xF8),
                AESTowerField8b::new(0xF1),
                AESTowerField8b::new(0xE3),
                AESTowerField8b::new(0xC7),
                AESTowerField8b::new(0x8F),
            ]);
        const S_BOX_OFFSET: AESTowerField8b = AESTowerField8b::new(0x63);

        for i in 0u8..=255u8 {
            let sbox_in = AESTowerField8b::new(i);
            let expected_sbox_out = AESTowerField8b::new(S_BOX[i as usize]);

            // AES-basis affine map reproduces the S-box table.
            let sbox_out =
                S_BOX_MATRIX.transform(&InvertOrZero::invert_or_zero(sbox_in)) + S_BOX_OFFSET;
            assert_eq!(sbox_out, expected_sbox_out);

            // Tower-basis affine map agrees after converting back to the AES basis.
            let sbox_in_b8 = B8::from(sbox_in);
            let sbox_out_b8 = S_BOX_TOWER_MATRIX
                .transform(&InvertOrZero::invert_or_zero(sbox_in_b8))
                + S_BOX_TOWER_OFFSET;
            assert_eq!(AESTowerField8b::from(sbox_out_b8), expected_sbox_out);
        }
    }
}