1use anyhow::anyhow;
4use binius_core::oracle::{OracleId, ShiftVariant};
5use binius_field::{BinaryField1b, BinaryField32b, Field, TowerField};
6use binius_macros::arith_expr;
7use binius_utils::checked_arithmetics::log2_ceil_usize;
8
9use crate::{
10 arithmetic::u32::LOG_U32_BITS,
11 builder::{types::F, ConstraintSystemBuilder},
12};
13
14const STATE_SIZE: usize = 32;
15
16const SINGLE_COMPRESSION_N_VARS: usize = 6;
18
19const TEMP_STATE_OUT_INDEX: usize = 56;
21
22const TEMP_STATE_OUT_INDEX_BINARY: [F; SINGLE_COMPRESSION_N_VARS] = [
24 Field::ZERO,
25 Field::ZERO,
26 Field::ZERO,
27 Field::ONE,
28 Field::ONE,
29 Field::ONE,
30];
31
32const SINGLE_COMPRESSION_HEIGHT: usize = 2usize.pow(SINGLE_COMPRESSION_N_VARS as u32);
34
35const OUT_HEIGHT: usize = 8;
37
38const ADDITION_OPERATIONS_NUMBER: usize = 6;
40
41const IV: [u32; 8] = [
43 0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A, 0x510E527F, 0x9B05688C, 0x1F83D9AB, 0x5BE0CD19,
44];
45
46const MSG_PERMUTATION: [usize; 16] = [2, 6, 3, 10, 7, 0, 4, 13, 1, 11, 12, 5, 9, 14, 15, 8];
48
/// Inputs of a single BLAKE3 compression, used to populate the witness.
///
/// The `Default` value (all zeroes) is what the circuit uses for rows that have no
/// caller-provided state.
#[derive(Clone, Copy, Debug, Default)]
pub struct Blake3CompressState {
    /// Chaining value; occupies state words 0..8.
    pub cv: [u32; 8],
    /// Message block words; occupy state words 16..32.
    pub block: [u32; 16],
    /// Low 32 bits of the counter (state word 12).
    pub counter_low: u32,
    /// High 32 bits of the counter (state word 13).
    pub counter_high: u32,
    /// Block length parameter (state word 14).
    pub block_len: u32,
    /// Flags parameter (state word 15).
    pub flags: u32,
}
58
/// Oracle handles returned by [`blake3_compress`].
pub struct Blake3CompressOracles {
    /// One oracle per state word, projected from the state-transition columns at the
    /// first row of every compression (the state before any mixing step).
    pub input: [OracleId; STATE_SIZE],
    /// One oracle per state word, projected from the state-transition columns at row
    /// `TEMP_STATE_OUT_INDEX` of every compression (the state after the final step).
    pub output: [OracleId; STATE_SIZE],
}
63
// Field used for columns that hold whole 32-bit state words.
type F32 = BinaryField32b;
// Field used for the bit-level columns consumed by the shift and addition gadgets.
type F1 = BinaryField1b;
66
67pub fn blake3_compress(
68 builder: &mut ConstraintSystemBuilder,
69 input_witness: &Option<impl AsRef<[Blake3CompressState]>>,
70 states_amount: usize,
71) -> Result<Blake3CompressOracles, anyhow::Error> {
72 let state_n_vars = log2_ceil_usize(states_amount * SINGLE_COMPRESSION_HEIGHT);
74 let state_transitions: [OracleId; STATE_SIZE] =
75 builder.add_committed_multiple("state_transitions", state_n_vars, F32::TOWER_LEVEL);
76
77 let input: [OracleId; STATE_SIZE] = array_util::try_from_fn(|xy| {
79 builder.add_projected(
80 "input",
81 state_transitions[xy],
82 vec![F::ZERO; SINGLE_COMPRESSION_N_VARS],
83 0,
84 )
85 })?;
86
87 let output: [OracleId; STATE_SIZE] = array_util::try_from_fn(|xy| {
89 builder.add_projected(
90 "output",
91 state_transitions[xy],
92 TEMP_STATE_OUT_INDEX_BINARY.to_vec(),
93 0,
94 )
95 })?;
96
97 let out_n_vars = log2_ceil_usize(states_amount * OUT_HEIGHT);
99 let cv: OracleId = builder.add_committed("cv", out_n_vars, F32::TOWER_LEVEL);
100 let state_i = builder.add_committed("state_i", out_n_vars, F32::TOWER_LEVEL);
101 let state_i_8 = builder.add_committed("state_i_8", out_n_vars, F32::TOWER_LEVEL);
102
103 let state_i_xor_state_i_8 = builder.add_linear_combination(
104 "state_i_xor_state_i_8",
105 out_n_vars,
106 [(state_i, F::ONE), (state_i_8, F::ONE)],
107 )?;
108
109 let cv_oracle_xor_state_i_8 = builder.add_linear_combination(
110 "cv_oracle_xor_state_i_8",
111 out_n_vars,
112 [(cv, F::ONE), (state_i_8, F::ONE)],
113 )?;
114
115 let a_in: OracleId = builder.add_committed("a_in", state_n_vars + 5, F1::TOWER_LEVEL);
117 let b_in: OracleId = builder.add_committed("b_in", state_n_vars + 5, F1::TOWER_LEVEL);
118 let c_in: OracleId = builder.add_committed("c_in", state_n_vars + 5, F1::TOWER_LEVEL);
119 let d_in: OracleId = builder.add_committed("d_in", state_n_vars + 5, F1::TOWER_LEVEL);
120 let mx_in: OracleId = builder.add_committed("mx_in", state_n_vars + 5, F1::TOWER_LEVEL);
121 let my_in: OracleId = builder.add_committed("my_in", state_n_vars + 5, F1::TOWER_LEVEL);
122 let a_0_tmp: OracleId = builder.add_committed("a_0_tmp", state_n_vars + 5, F1::TOWER_LEVEL);
123 let a_0: OracleId = builder.add_committed("a_0", state_n_vars + 5, F1::TOWER_LEVEL);
124 let c_0: OracleId = builder.add_committed("c_0", state_n_vars + 5, F1::TOWER_LEVEL);
125 let b_in_xor_c_0: OracleId = builder.add_linear_combination(
126 "b_in_xor_c_0",
127 state_n_vars + 5,
128 [(b_in, F::ONE), (c_0, F::ONE)],
129 )?;
130 let b_0: OracleId = builder.add_shifted(
131 "d_1",
132 b_in_xor_c_0,
133 (32 - 12) as usize,
134 LOG_U32_BITS,
135 ShiftVariant::CircularLeft,
136 )?;
137 let d_in_xor_a_0: OracleId = builder.add_linear_combination(
138 "d_in_xor_a_0",
139 state_n_vars + 5,
140 [(d_in, F::ONE), (a_0, F::ONE)],
141 )?;
142 let d_0: OracleId = builder.add_shifted(
143 "d_0",
144 d_in_xor_a_0,
145 (32 - 16) as usize,
146 LOG_U32_BITS,
147 ShiftVariant::CircularLeft,
148 )?;
149 let a_1_tmp: OracleId = builder.add_committed("a_1_tmp", state_n_vars + 5, F1::TOWER_LEVEL);
150 let a_1: OracleId = builder.add_committed("a_1", state_n_vars + 5, F1::TOWER_LEVEL);
151 let d_0_xor_a_1: OracleId = builder.add_linear_combination(
152 "d_0_xor_a_1",
153 state_n_vars + 5,
154 [(d_0, F::ONE), (a_1, F::ONE)],
155 )?;
156 let d_1: OracleId = builder.add_shifted(
157 "d_1",
158 d_0_xor_a_1,
159 (32 - 8) as usize,
160 LOG_U32_BITS,
161 ShiftVariant::CircularLeft,
162 )?;
163 let c_1: OracleId = builder.add_committed("c_1", state_n_vars + 5, F1::TOWER_LEVEL);
164 let b_0_xor_c_1: OracleId = builder.add_linear_combination(
165 "b_0_xor_c_1",
166 state_n_vars + 5,
167 [(b_0, F::ONE), (c_1, F::ONE)],
168 )?;
169 let b_1: OracleId = builder.add_shifted(
170 "b_1",
171 b_0_xor_c_1,
172 (32 - 7) as usize,
173 LOG_U32_BITS,
174 ShiftVariant::CircularLeft,
175 )?;
176
177 let cout: [OracleId; ADDITION_OPERATIONS_NUMBER] =
178 builder.add_committed_multiple("cout", state_n_vars + 5, F1::TOWER_LEVEL);
179 let cin: [OracleId; ADDITION_OPERATIONS_NUMBER] = array_util::try_from_fn(|xy| {
180 builder.add_shifted("cin", cout[xy], 1, 5, ShiftVariant::LogicalLeft)
181 })?;
182
183 if let Some(witness) = builder.witness() {
185 let input_witness = input_witness
186 .as_ref()
187 .ok_or_else(|| anyhow!("builder witness available and input witness is not"))?
188 .as_ref();
189
190 let mut state_cols = state_transitions.map(|id| witness.new_column::<F32>(id));
193 let mut input_cols = input.map(|id| witness.new_column::<F32>(id));
194 let mut output_cols = output.map(|id| witness.new_column::<F32>(id));
195
196 let mut cv_col = witness.new_column::<F32>(cv);
197 let mut state_i_col = witness.new_column::<F32>(state_i);
198 let mut state_i_8_col = witness.new_column::<F32>(state_i_8);
199 let mut state_i_xor_state_i_8_col = witness.new_column::<F32>(state_i_xor_state_i_8);
200 let mut cv_oracle_xor_state_i_8_col = witness.new_column::<F32>(cv_oracle_xor_state_i_8);
201
202 let mut a_in_col = witness.new_column::<F1>(a_in);
203 let mut b_in_col = witness.new_column::<F1>(b_in);
204 let mut c_in_col = witness.new_column::<F1>(c_in);
205 let mut d_in_col = witness.new_column::<F1>(d_in);
206 let mut mx_in_col = witness.new_column::<F1>(mx_in);
207 let mut my_in_col = witness.new_column::<F1>(my_in);
208 let mut a_0_tmp_col = witness.new_column::<F1>(a_0_tmp);
209 let mut a_0_col = witness.new_column::<F1>(a_0);
210 let mut b_in_xor_c_0_col = witness.new_column::<F1>(b_in_xor_c_0);
211 let mut b_0_col = witness.new_column::<F1>(b_0);
212 let mut c_0_col = witness.new_column::<F1>(c_0);
213 let mut d_in_xor_a_0_col = witness.new_column::<F1>(d_in_xor_a_0);
214 let mut d_0_col = witness.new_column::<F1>(d_0);
215 let mut a_1_tmp_col = witness.new_column::<F1>(a_1_tmp);
216 let mut a_1_col = witness.new_column::<F1>(a_1);
217 let mut d_0_xor_a_1_col = witness.new_column::<F1>(d_0_xor_a_1);
218 let mut d_1_col = witness.new_column::<F1>(d_1);
219 let mut c_1_col = witness.new_column::<F1>(c_1);
220 let mut b_0_xor_c_1_col = witness.new_column::<F1>(b_0_xor_c_1);
221 let mut b_1_col = witness.new_column::<F1>(b_1);
222 let mut cout_cols = cout.map(|id| witness.new_column::<F1>(id));
223 let mut cin_cols = cin.map(|id| witness.new_column::<F1>(id));
224
225 let state_vals = state_cols.each_mut().map(|col| col.as_mut_slice::<u32>());
228 let input_vals = input_cols.each_mut().map(|col| col.as_mut_slice::<u32>());
229 let output_vals = output_cols.each_mut().map(|col| col.as_mut_slice::<u32>());
230
231 let cv_vals = cv_col.as_mut_slice::<u32>();
232 let state_i_vals = state_i_col.as_mut_slice::<u32>();
233 let state_i_8_vals = state_i_8_col.as_mut_slice::<u32>();
234 let state_i_xor_state_i_8_vals = state_i_xor_state_i_8_col.as_mut_slice::<u32>();
235 let cv_oracle_xor_state_i_8_vals = cv_oracle_xor_state_i_8_col.as_mut_slice::<u32>();
236
237 let a_in_vals = a_in_col.as_mut_slice::<u32>();
238 let b_in_vals = b_in_col.as_mut_slice::<u32>();
239 let c_in_vals = c_in_col.as_mut_slice::<u32>();
240 let d_in_vals = d_in_col.as_mut_slice::<u32>();
241 let mx_in_vals = mx_in_col.as_mut_slice::<u32>();
242 let my_in_vals = my_in_col.as_mut_slice::<u32>();
243 let a_0_tmp_vals = a_0_tmp_col.as_mut_slice::<u32>();
244 let a_0_vals = a_0_col.as_mut_slice::<u32>();
245 let b_in_xor_c_0_vals = b_in_xor_c_0_col.as_mut_slice::<u32>();
246 let b_0_vals = b_0_col.as_mut_slice::<u32>();
247 let c_0_vals = c_0_col.as_mut_slice::<u32>();
248 let d_in_xor_a_0_vals = d_in_xor_a_0_col.as_mut_slice::<u32>();
249 let d_0_vals = d_0_col.as_mut_slice::<u32>();
250 let a_1_tmp_vals = a_1_tmp_col.as_mut_slice::<u32>();
251 let a_1_vals = a_1_col.as_mut_slice::<u32>();
252 let d_0_xor_a_1_vals = d_0_xor_a_1_col.as_mut_slice::<u32>();
253 let d_1_vals = d_1_col.as_mut_slice::<u32>();
254 let c_1_vals = c_1_col.as_mut_slice::<u32>();
255 let b_0_xor_c_1_vals = b_0_xor_c_1_col.as_mut_slice::<u32>();
256 let b_1_vals = b_1_col.as_mut_slice::<u32>();
257
258 let cout_vals = cout_cols.each_mut().map(|col| col.as_mut_slice::<u32>());
259 let cin_vals = cin_cols.each_mut().map(|col| col.as_mut_slice::<u32>());
260
261 let a = [0, 1, 2, 3, 0, 1, 2, 3];
266 let b = [4, 5, 6, 7, 5, 6, 7, 4];
267 let c = [8, 9, 10, 11, 10, 11, 8, 9];
268 let d = [12, 13, 14, 15, 15, 12, 13, 14];
269
270 let mx = [16, 18, 20, 22, 24, 26, 28, 30];
272 let my = [17, 19, 21, 23, 25, 27, 29, 31];
273
274 let mut compression_offset = 0usize;
275 for compression_idx in 0..states_amount {
276 let state = input_witness
277 .get(compression_idx)
278 .copied()
279 .unwrap_or_default();
280
281 let mut state_idx = 0;
282
283 for i in 0..state.cv.len() {
285 state_vals[state_idx][compression_offset] = state.cv[i];
286 state_idx += 1;
287 }
288
289 state_vals[state_idx][compression_offset] = IV[0];
290 state_vals[state_idx + 1][compression_offset] = IV[1];
291 state_vals[state_idx + 2][compression_offset] = IV[2];
292 state_vals[state_idx + 3][compression_offset] = IV[3];
293 state_vals[state_idx + 4][compression_offset] = state.counter_low;
294 state_vals[state_idx + 5][compression_offset] = state.counter_high;
295 state_vals[state_idx + 6][compression_offset] = state.block_len;
296 state_vals[state_idx + 7][compression_offset] = state.flags;
297
298 state_idx += 8;
299
300 for i in 0..state.block.len() {
301 state_vals[state_idx][compression_offset] = state.block[i];
302 state_idx += 1;
303 }
304
305 for xy in 0..STATE_SIZE {
307 input_vals[xy][compression_idx] = state_vals[xy][compression_offset];
308 }
309
310 assert_eq!(state_idx, STATE_SIZE);
311
312 let mut state_offset = 1usize;
314 let mut temp_vars_offset = 0usize;
315
316 fn add(a: u32, b: u32) -> (u32, u32, u32) {
317 let zout;
318 let carry;
319
320 (zout, carry) = a.overflowing_add(b);
321 let cin = a ^ b ^ zout;
322 let cout = ((carry as u32) << 31) | (cin >> 1);
323
324 (cin, cout, zout)
325 }
326
327 for round_idx in 0..7 {
329 for j in 0..8 {
330 let state_transition_idx = state_offset + compression_offset;
331 let var_offset = temp_vars_offset + compression_offset;
332 let mut add_offset = 0usize;
333
334 #[allow(clippy::needless_range_loop)]
336 for i in 0..STATE_SIZE {
337 state_vals[i][state_transition_idx] =
338 state_vals[i][state_transition_idx - 1];
339 }
340
341 a_in_vals[var_offset] = state_vals[a[j]][state_transition_idx - 1];
343 b_in_vals[var_offset] = state_vals[b[j]][state_transition_idx - 1];
344 c_in_vals[var_offset] = state_vals[c[j]][state_transition_idx - 1];
345 d_in_vals[var_offset] = state_vals[d[j]][state_transition_idx - 1];
346 mx_in_vals[var_offset] = state_vals[mx[j]][state_transition_idx - 1];
347 my_in_vals[var_offset] = state_vals[my[j]][state_transition_idx - 1];
348
349 (
352 cin_vals[add_offset][var_offset],
353 cout_vals[add_offset][var_offset],
354 a_0_tmp_vals[var_offset],
355 ) = add(a_in_vals[var_offset], b_in_vals[var_offset]);
356 add_offset += 1;
357
358 (
359 cin_vals[add_offset][var_offset],
360 cout_vals[add_offset][var_offset],
361 a_0_vals[var_offset],
362 ) = add(a_0_tmp_vals[var_offset], mx_in_vals[var_offset]);
363 add_offset += 1;
364
365 d_in_xor_a_0_vals[var_offset] = d_in_vals[var_offset] ^ a_0_vals[var_offset];
366
367 d_0_vals[var_offset] = d_in_xor_a_0_vals[var_offset].rotate_right(16);
368
369 (
370 cin_vals[add_offset][var_offset],
371 cout_vals[add_offset][var_offset],
372 c_0_vals[var_offset],
373 ) = add(c_in_vals[var_offset], d_0_vals[var_offset]);
374 add_offset += 1;
375
376 b_in_xor_c_0_vals[var_offset] = b_in_vals[var_offset] ^ c_0_vals[var_offset];
377
378 b_0_vals[var_offset] = b_in_xor_c_0_vals[var_offset].rotate_right(12);
379
380 (
381 cin_vals[add_offset][var_offset],
382 cout_vals[add_offset][var_offset],
383 a_1_tmp_vals[var_offset],
384 ) = add(a_0_vals[var_offset], b_0_vals[var_offset]);
385 add_offset += 1;
386
387 (
388 cin_vals[add_offset][var_offset],
389 cout_vals[add_offset][var_offset],
390 a_1_vals[var_offset],
391 ) = add(a_1_tmp_vals[var_offset], my_in_vals[var_offset]);
392 add_offset += 1;
393
394 d_0_xor_a_1_vals[var_offset] = d_0_vals[var_offset] ^ a_1_vals[var_offset];
395
396 d_1_vals[var_offset] = d_0_xor_a_1_vals[var_offset].rotate_right(8);
397
398 (
399 cin_vals[add_offset][var_offset],
400 cout_vals[add_offset][var_offset],
401 c_1_vals[var_offset],
402 ) = add(c_0_vals[var_offset], d_1_vals[var_offset]);
403 add_offset += 1;
404
405 b_0_xor_c_1_vals[var_offset] = b_0_vals[var_offset] ^ c_1_vals[var_offset];
406
407 b_1_vals[var_offset] = b_0_xor_c_1_vals[var_offset].rotate_right(7);
408
409 state_vals[a[j]][state_transition_idx] = a_1_vals[var_offset];
411 state_vals[b[j]][state_transition_idx] = b_1_vals[var_offset];
412 state_vals[c[j]][state_transition_idx] = c_1_vals[var_offset];
413 state_vals[d[j]][state_transition_idx] = d_1_vals[var_offset];
414
415 state_offset += 1;
416 temp_vars_offset += 1;
417 assert_eq!(add_offset, ADDITION_OPERATIONS_NUMBER);
418 }
419
420 if round_idx < 6 {
422 let mut permuted = [0u32; 16];
423 for i in 0..16 {
424 permuted[i] = state_vals[16 + MSG_PERMUTATION[i]]
425 [state_offset + compression_offset - 1];
426 }
427
428 for i in 0..16 {
429 state_vals[16 + i][state_offset + compression_offset - 1] = permuted[i];
430 }
431 }
432 }
433
434 assert_eq!(state_offset, TEMP_STATE_OUT_INDEX + 1);
435
436 for i in 0..8 {
437 cv_vals[i * compression_idx + i] = state_vals[i][compression_offset];
439 state_i_vals[i * compression_idx + i] =
440 state_vals[i][state_offset + compression_offset - 1];
441 state_i_8_vals[i * compression_idx + i] =
442 state_vals[i + 8][state_offset + compression_offset - 1];
443
444 state_vals[i][state_offset + compression_offset - 1] ^=
446 state_vals[i + 8][state_offset + compression_offset - 1];
447
448 state_i_xor_state_i_8_vals[i * compression_idx + i] =
450 state_vals[i][state_offset + compression_offset - 1];
451
452 state_vals[i + 8][state_offset + compression_offset - 1] ^=
454 state_vals[i][compression_offset];
455
456 cv_oracle_xor_state_i_8_vals[i * compression_idx + i] =
458 state_vals[i + 8][state_offset + compression_offset - 1];
459 }
460
461 for i in 0..STATE_SIZE {
463 output_vals[i][compression_idx] =
464 state_vals[i][state_offset + compression_offset - 1];
465 }
466
467 compression_offset += SINGLE_COMPRESSION_HEIGHT;
468 }
469 }
470
471 let xins = [a_in, a_0_tmp, c_in, a_0, a_1_tmp, c_0];
477 let yins = [b_in, mx_in, d_0, b_0, my_in, d_1];
478 let zouts = [a_0_tmp, a_0, c_0, a_1_tmp, a_1, c_1];
479
480 for (idx, (xin, (yin, zout))) in xins
481 .into_iter()
482 .zip(yins.into_iter().zip(zouts.into_iter()))
483 .enumerate()
484 {
485 builder.assert_zero(
486 format!("sum{idx}"),
487 [xin, yin, cin[idx], zout],
488 arith_expr!([xin, yin, cin, zout] = xin + yin + cin - zout).convert_field(),
489 );
490
491 builder.assert_zero(
492 format!("carry{idx}"),
493 [xin, yin, cin[idx], cout[idx]],
494 arith_expr!([xin, yin, cin, cout] = (xin + cin) * (yin + cin) + cin - cout)
495 .convert_field(),
496 );
497 }
498
499 Ok(Blake3CompressOracles { input, output })
500}
501
#[cfg(test)]
mod tests {
    use std::array;

    use rand::{rngs::StdRng, Rng, SeedableRng};

    use crate::{
        blake3::{blake3_compress, Blake3CompressState, F32, IV, MSG_PERMUTATION},
        builder::test_utils::test_circuit,
    };

    /// State word indices `[a, b, c, d]` mixed by each of the 8 steps of a round
    /// (four column steps followed by four diagonal steps).
    const MIXING_LANES: [[usize; 4]; 8] = [
        [0, 4, 8, 12],
        [1, 5, 9, 13],
        [2, 6, 10, 14],
        [3, 7, 11, 15],
        [0, 5, 10, 15],
        [1, 6, 11, 12],
        [2, 7, 8, 13],
        [3, 4, 9, 14],
    ];

    /// Plain-Rust reference of the compression computed by the circuit.
    ///
    /// Keeps the message words in state words 16..32 and returns the first 16 state words
    /// after the final feed-forward.
    fn compress(
        chaining_value: &[u32; 8],
        block_words: &[u32; 16],
        counter: u64,
        block_len: u32,
        flags: u32,
    ) -> [u32; 16] {
        let mut state = [0u32; 32];
        state[..8].copy_from_slice(chaining_value);
        state[8..12].copy_from_slice(&IV[..4]);
        // The 64-bit counter is split into its low and high halves.
        state[12] = counter as u32;
        state[13] = (counter >> 32) as u32;
        state[14] = block_len;
        state[15] = flags;
        state[16..].copy_from_slice(block_words);

        for round in 0..7 {
            for (step, &[a, b, c, d]) in MIXING_LANES.iter().enumerate() {
                // Step `step` consumes message words 2 * step and 2 * step + 1.
                let mx = state[16 + 2 * step];
                let my = state[17 + 2 * step];

                state[a] = state[a].wrapping_add(state[b]).wrapping_add(mx);
                state[d] = (state[d] ^ state[a]).rotate_right(16);
                state[c] = state[c].wrapping_add(state[d]);
                state[b] = (state[b] ^ state[c]).rotate_right(12);

                state[a] = state[a].wrapping_add(state[b]).wrapping_add(my);
                state[d] = (state[d] ^ state[a]).rotate_right(8);
                state[c] = state[c].wrapping_add(state[d]);
                state[b] = (state[b] ^ state[c]).rotate_right(7);
            }

            // The message words are permuted between rounds, not after the last one.
            if round < 6 {
                let permuted: [u32; 16] = array::from_fn(|i| state[16 + MSG_PERMUTATION[i]]);
                state[16..].copy_from_slice(&permuted);
            }
        }

        for i in 0..8 {
            state[i] ^= state[i + 8];
            state[i + 8] ^= chaining_value[i];
        }

        let mut state_out = [0u32; 16];
        state_out.copy_from_slice(&state[..16]);
        state_out
    }

    /// Runs the circuit on 8 random compressions and compares the first 16 output columns
    /// against the reference implementation.
    #[test]
    fn test_blake3_compression() {
        test_circuit(|builder| {
            const COMPRESSIONS: usize = 8;
            let mut rng = StdRng::seed_from_u64(0);

            let mut states: Vec<Blake3CompressState> = Vec::with_capacity(COMPRESSIONS);
            for _ in 0..COMPRESSIONS {
                let cv: [u32; 8] = array::from_fn(|_| rng.gen::<u32>());
                let block: [u32; 16] = array::from_fn(|_| rng.gen::<u32>());
                let counter = rng.gen::<u64>();
                let block_len = rng.gen::<u32>();
                let flags = rng.gen::<u32>();

                states.push(Blake3CompressState {
                    cv,
                    block,
                    counter_low: counter as u32,
                    counter_high: (counter >> 32) as u32,
                    block_len,
                    flags,
                });
            }

            // One row of 16 expected words per compression ...
            let expected_rows: Vec<Vec<u32>> = states
                .iter()
                .map(|state| {
                    let counter =
                        (u64::from(state.counter_high) << 32) | u64::from(state.counter_low);
                    compress(&state.cv, &state.block, counter, state.block_len, state.flags)
                        .to_vec()
                })
                .collect();
            // ... turned into one column per state word, matching the witness layout.
            let expected_columns = transpose(expected_rows);

            let states_len = states.len();
            let state_out = blake3_compress(builder, &Some(states), states_len)?;
            if let Some(witness) = builder.witness() {
                for (word_idx, expected_column) in expected_columns.iter().enumerate() {
                    let actual = witness
                        .get::<F32>(state_out.output[word_idx])
                        .unwrap()
                        .as_slice::<u32>();
                    assert_eq!(actual[..expected_column.len()], *expected_column);
                }
            }
            Ok(vec![])
        })
        .unwrap();
    }

    /// Swaps rows and columns; the column count is taken from the first row.
    ///
    /// Panics if `v` is empty or if any row is shorter than the first one.
    fn transpose<T>(v: Vec<Vec<T>>) -> Vec<Vec<T>>
    where
        T: Clone,
    {
        assert!(!v.is_empty());
        let width = v[0].len();

        let mut columns: Vec<Vec<T>> = Vec::with_capacity(width);
        for col_idx in 0..width {
            let mut column = Vec::with_capacity(v.len());
            for row in &v {
                column.push(row[col_idx].clone());
            }
            columns.push(column);
        }
        columns
    }
}