binius_circuits/
blake3.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use anyhow::anyhow;
4use binius_core::oracle::{OracleId, ShiftVariant};
5use binius_field::{BinaryField1b, BinaryField32b, Field, TowerField};
6use binius_macros::arith_expr;
7use binius_utils::checked_arithmetics::log2_ceil_usize;
8
9use crate::{
10	arithmetic::u32::LOG_U32_BITS,
11	builder::{types::F, ConstraintSystemBuilder},
12};
13
14const STATE_SIZE: usize = 32;
15
16// This defines how long state columns should be
17const SINGLE_COMPRESSION_N_VARS: usize = 6;
18
19// Number of initial state mutations until getting output value
20const TEMP_STATE_OUT_INDEX: usize = 56;
21
22// Number of initial state mutations (TEMP_STATE_OUT_INDEX) in "binary" form
23const TEMP_STATE_OUT_INDEX_BINARY: [F; SINGLE_COMPRESSION_N_VARS] = [
24	Field::ZERO,
25	Field::ZERO,
26	Field::ZERO,
27	Field::ONE,
28	Field::ONE,
29	Field::ONE,
30];
31
32// Defines overall N_VARS for state transition columns
33const SINGLE_COMPRESSION_HEIGHT: usize = 2usize.pow(SINGLE_COMPRESSION_N_VARS as u32);
34
35// Deifines N_VARS for so-called 'out' columns used for finalising every compression
36const OUT_HEIGHT: usize = 8;
37
/// Number of 32-bit additions performed by one `G` application (`a` is updated by two
/// chained additions twice, `c` by one addition twice). Each addition gets its own
/// `cout`/`cin` carry column pair.
const ADDITION_OPERATIONS_NUMBER: usize = 6;

/// BLAKE3 initialization vector; its first four words seed state words 8..12.
const IV: [u32; 8] = [
	0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A,
	0x510E527F, 0x9B05688C, 0x1F83D9AB, 0x5BE0CD19,
];

/// BLAKE3 message word permutation: after each of the first six rounds, message word `i`
/// takes the value previously held by message word `MSG_PERMUTATION[i]`.
const MSG_PERMUTATION: [usize; 16] = [
	2, 6, 3, 10, 7, 0, 4, 13,
	1, 11, 12, 5, 9, 14, 15, 8,
];
48
/// Inputs of a single BLAKE3 compression, as consumed by `blake3_compress`.
#[derive(Clone, Copy, Debug, Default)]
pub struct Blake3CompressState {
	/// Input chaining value; becomes state words 0..8.
	pub cv: [u32; 8],
	/// The 16 message words of the block; become state words 16..32.
	pub block: [u32; 16],
	/// Low 32 bits of the 64-bit counter; becomes state word 12.
	pub counter_low: u32,
	/// High 32 bits of the 64-bit counter; becomes state word 13.
	pub counter_high: u32,
	/// Block length word; becomes state word 14.
	pub block_len: u32,
	/// Flags word; becomes state word 15.
	pub flags: u32,
}
58
59pub struct Blake3CompressOracles {
60	pub input: [OracleId; STATE_SIZE],
61	pub output: [OracleId; STATE_SIZE],
62}
63
64type F32 = BinaryField32b;
65type F1 = BinaryField1b;
66
67pub fn blake3_compress(
68	builder: &mut ConstraintSystemBuilder,
69	input_witness: &Option<impl AsRef<[Blake3CompressState]>>,
70	states_amount: usize,
71) -> Result<Blake3CompressOracles, anyhow::Error> {
72	// state
73	let state_n_vars = log2_ceil_usize(states_amount * SINGLE_COMPRESSION_HEIGHT);
74	let state_transitions: [OracleId; STATE_SIZE] =
75		builder.add_committed_multiple("state_transitions", state_n_vars, F32::TOWER_LEVEL);
76
77	// input
78	let input: [OracleId; STATE_SIZE] = array_util::try_from_fn(|xy| {
79		builder.add_projected(
80			"input",
81			state_transitions[xy],
82			vec![F::ZERO; SINGLE_COMPRESSION_N_VARS],
83			0,
84		)
85	})?;
86
87	// output
88	let output: [OracleId; STATE_SIZE] = array_util::try_from_fn(|xy| {
89		builder.add_projected(
90			"output",
91			state_transitions[xy],
92			TEMP_STATE_OUT_INDEX_BINARY.to_vec(),
93			0,
94		)
95	})?;
96
97	// columns for enforcing cv computation
98	let out_n_vars = log2_ceil_usize(states_amount * OUT_HEIGHT);
99	let cv: OracleId = builder.add_committed("cv", out_n_vars, F32::TOWER_LEVEL);
100	let state_i = builder.add_committed("state_i", out_n_vars, F32::TOWER_LEVEL);
101	let state_i_8 = builder.add_committed("state_i_8", out_n_vars, F32::TOWER_LEVEL);
102
103	let state_i_xor_state_i_8 = builder.add_linear_combination(
104		"state_i_xor_state_i_8",
105		out_n_vars,
106		[(state_i, F::ONE), (state_i_8, F::ONE)],
107	)?;
108
109	let cv_oracle_xor_state_i_8 = builder.add_linear_combination(
110		"cv_oracle_xor_state_i_8",
111		out_n_vars,
112		[(cv, F::ONE), (state_i_8, F::ONE)],
113	)?;
114
115	// columns for enforcing correct computation of temp variables
116	let a_in: OracleId = builder.add_committed("a_in", state_n_vars + 5, F1::TOWER_LEVEL);
117	let b_in: OracleId = builder.add_committed("b_in", state_n_vars + 5, F1::TOWER_LEVEL);
118	let c_in: OracleId = builder.add_committed("c_in", state_n_vars + 5, F1::TOWER_LEVEL);
119	let d_in: OracleId = builder.add_committed("d_in", state_n_vars + 5, F1::TOWER_LEVEL);
120	let mx_in: OracleId = builder.add_committed("mx_in", state_n_vars + 5, F1::TOWER_LEVEL);
121	let my_in: OracleId = builder.add_committed("my_in", state_n_vars + 5, F1::TOWER_LEVEL);
122	let a_0_tmp: OracleId = builder.add_committed("a_0_tmp", state_n_vars + 5, F1::TOWER_LEVEL);
123	let a_0: OracleId = builder.add_committed("a_0", state_n_vars + 5, F1::TOWER_LEVEL);
124	let c_0: OracleId = builder.add_committed("c_0", state_n_vars + 5, F1::TOWER_LEVEL);
125	let b_in_xor_c_0: OracleId = builder.add_linear_combination(
126		"b_in_xor_c_0",
127		state_n_vars + 5,
128		[(b_in, F::ONE), (c_0, F::ONE)],
129	)?;
130	let b_0: OracleId = builder.add_shifted(
131		"d_1",
132		b_in_xor_c_0,
133		(32 - 12) as usize,
134		LOG_U32_BITS,
135		ShiftVariant::CircularLeft,
136	)?;
137	let d_in_xor_a_0: OracleId = builder.add_linear_combination(
138		"d_in_xor_a_0",
139		state_n_vars + 5,
140		[(d_in, F::ONE), (a_0, F::ONE)],
141	)?;
142	let d_0: OracleId = builder.add_shifted(
143		"d_0",
144		d_in_xor_a_0,
145		(32 - 16) as usize,
146		LOG_U32_BITS,
147		ShiftVariant::CircularLeft,
148	)?;
149	let a_1_tmp: OracleId = builder.add_committed("a_1_tmp", state_n_vars + 5, F1::TOWER_LEVEL);
150	let a_1: OracleId = builder.add_committed("a_1", state_n_vars + 5, F1::TOWER_LEVEL);
151	let d_0_xor_a_1: OracleId = builder.add_linear_combination(
152		"d_0_xor_a_1",
153		state_n_vars + 5,
154		[(d_0, F::ONE), (a_1, F::ONE)],
155	)?;
156	let d_1: OracleId = builder.add_shifted(
157		"d_1",
158		d_0_xor_a_1,
159		(32 - 8) as usize,
160		LOG_U32_BITS,
161		ShiftVariant::CircularLeft,
162	)?;
163	let c_1: OracleId = builder.add_committed("c_1", state_n_vars + 5, F1::TOWER_LEVEL);
164	let b_0_xor_c_1: OracleId = builder.add_linear_combination(
165		"b_0_xor_c_1",
166		state_n_vars + 5,
167		[(b_0, F::ONE), (c_1, F::ONE)],
168	)?;
169	let b_1: OracleId = builder.add_shifted(
170		"b_1",
171		b_0_xor_c_1,
172		(32 - 7) as usize,
173		LOG_U32_BITS,
174		ShiftVariant::CircularLeft,
175	)?;
176
177	let cout: [OracleId; ADDITION_OPERATIONS_NUMBER] =
178		builder.add_committed_multiple("cout", state_n_vars + 5, F1::TOWER_LEVEL);
179	let cin: [OracleId; ADDITION_OPERATIONS_NUMBER] = array_util::try_from_fn(|xy| {
180		builder.add_shifted("cin", cout[xy], 1, 5, ShiftVariant::LogicalLeft)
181	})?;
182
183	// witness population (columns creation and data writing)
184	if let Some(witness) = builder.witness() {
185		let input_witness = input_witness
186			.as_ref()
187			.ok_or_else(|| anyhow!("builder witness available and input witness is not"))?
188			.as_ref();
189
190		// columns creation
191
192		let mut state_cols = state_transitions.map(|id| witness.new_column::<F32>(id));
193		let mut input_cols = input.map(|id| witness.new_column::<F32>(id));
194		let mut output_cols = output.map(|id| witness.new_column::<F32>(id));
195
196		let mut cv_col = witness.new_column::<F32>(cv);
197		let mut state_i_col = witness.new_column::<F32>(state_i);
198		let mut state_i_8_col = witness.new_column::<F32>(state_i_8);
199		let mut state_i_xor_state_i_8_col = witness.new_column::<F32>(state_i_xor_state_i_8);
200		let mut cv_oracle_xor_state_i_8_col = witness.new_column::<F32>(cv_oracle_xor_state_i_8);
201
202		let mut a_in_col = witness.new_column::<F1>(a_in);
203		let mut b_in_col = witness.new_column::<F1>(b_in);
204		let mut c_in_col = witness.new_column::<F1>(c_in);
205		let mut d_in_col = witness.new_column::<F1>(d_in);
206		let mut mx_in_col = witness.new_column::<F1>(mx_in);
207		let mut my_in_col = witness.new_column::<F1>(my_in);
208		let mut a_0_tmp_col = witness.new_column::<F1>(a_0_tmp);
209		let mut a_0_col = witness.new_column::<F1>(a_0);
210		let mut b_in_xor_c_0_col = witness.new_column::<F1>(b_in_xor_c_0);
211		let mut b_0_col = witness.new_column::<F1>(b_0);
212		let mut c_0_col = witness.new_column::<F1>(c_0);
213		let mut d_in_xor_a_0_col = witness.new_column::<F1>(d_in_xor_a_0);
214		let mut d_0_col = witness.new_column::<F1>(d_0);
215		let mut a_1_tmp_col = witness.new_column::<F1>(a_1_tmp);
216		let mut a_1_col = witness.new_column::<F1>(a_1);
217		let mut d_0_xor_a_1_col = witness.new_column::<F1>(d_0_xor_a_1);
218		let mut d_1_col = witness.new_column::<F1>(d_1);
219		let mut c_1_col = witness.new_column::<F1>(c_1);
220		let mut b_0_xor_c_1_col = witness.new_column::<F1>(b_0_xor_c_1);
221		let mut b_1_col = witness.new_column::<F1>(b_1);
222		let mut cout_cols = cout.map(|id| witness.new_column::<F1>(id));
223		let mut cin_cols = cin.map(|id| witness.new_column::<F1>(id));
224
225		// values
226
227		let state_vals = state_cols.each_mut().map(|col| col.as_mut_slice::<u32>());
228		let input_vals = input_cols.each_mut().map(|col| col.as_mut_slice::<u32>());
229		let output_vals = output_cols.each_mut().map(|col| col.as_mut_slice::<u32>());
230
231		let cv_vals = cv_col.as_mut_slice::<u32>();
232		let state_i_vals = state_i_col.as_mut_slice::<u32>();
233		let state_i_8_vals = state_i_8_col.as_mut_slice::<u32>();
234		let state_i_xor_state_i_8_vals = state_i_xor_state_i_8_col.as_mut_slice::<u32>();
235		let cv_oracle_xor_state_i_8_vals = cv_oracle_xor_state_i_8_col.as_mut_slice::<u32>();
236
237		let a_in_vals = a_in_col.as_mut_slice::<u32>();
238		let b_in_vals = b_in_col.as_mut_slice::<u32>();
239		let c_in_vals = c_in_col.as_mut_slice::<u32>();
240		let d_in_vals = d_in_col.as_mut_slice::<u32>();
241		let mx_in_vals = mx_in_col.as_mut_slice::<u32>();
242		let my_in_vals = my_in_col.as_mut_slice::<u32>();
243		let a_0_tmp_vals = a_0_tmp_col.as_mut_slice::<u32>();
244		let a_0_vals = a_0_col.as_mut_slice::<u32>();
245		let b_in_xor_c_0_vals = b_in_xor_c_0_col.as_mut_slice::<u32>();
246		let b_0_vals = b_0_col.as_mut_slice::<u32>();
247		let c_0_vals = c_0_col.as_mut_slice::<u32>();
248		let d_in_xor_a_0_vals = d_in_xor_a_0_col.as_mut_slice::<u32>();
249		let d_0_vals = d_0_col.as_mut_slice::<u32>();
250		let a_1_tmp_vals = a_1_tmp_col.as_mut_slice::<u32>();
251		let a_1_vals = a_1_col.as_mut_slice::<u32>();
252		let d_0_xor_a_1_vals = d_0_xor_a_1_col.as_mut_slice::<u32>();
253		let d_1_vals = d_1_col.as_mut_slice::<u32>();
254		let c_1_vals = c_1_col.as_mut_slice::<u32>();
255		let b_0_xor_c_1_vals = b_0_xor_c_1_col.as_mut_slice::<u32>();
256		let b_1_vals = b_1_col.as_mut_slice::<u32>();
257
258		let cout_vals = cout_cols.each_mut().map(|col| col.as_mut_slice::<u32>());
259		let cin_vals = cin_cols.each_mut().map(|col| col.as_mut_slice::<u32>());
260
261		/* Populating */
262
263		// indices from Blake3 reference:
264		// https://github.com/BLAKE3-team/BLAKE3/blob/master/reference_impl/reference_impl.rs#L53
265		let a = [0, 1, 2, 3, 0, 1, 2, 3];
266		let b = [4, 5, 6, 7, 5, 6, 7, 4];
267		let c = [8, 9, 10, 11, 10, 11, 8, 9];
268		let d = [12, 13, 14, 15, 15, 12, 13, 14];
269
270		// we consider message 'm' as part of the state
271		let mx = [16, 18, 20, 22, 24, 26, 28, 30];
272		let my = [17, 19, 21, 23, 25, 27, 29, 31];
273
274		let mut compression_offset = 0usize;
275		for compression_idx in 0..states_amount {
276			let state = input_witness
277				.get(compression_idx)
278				.copied()
279				.unwrap_or_default();
280
281			let mut state_idx = 0;
282
283			// populate current state
284			for i in 0..state.cv.len() {
285				state_vals[state_idx][compression_offset] = state.cv[i];
286				state_idx += 1;
287			}
288
289			state_vals[state_idx][compression_offset] = IV[0];
290			state_vals[state_idx + 1][compression_offset] = IV[1];
291			state_vals[state_idx + 2][compression_offset] = IV[2];
292			state_vals[state_idx + 3][compression_offset] = IV[3];
293			state_vals[state_idx + 4][compression_offset] = state.counter_low;
294			state_vals[state_idx + 5][compression_offset] = state.counter_high;
295			state_vals[state_idx + 6][compression_offset] = state.block_len;
296			state_vals[state_idx + 7][compression_offset] = state.flags;
297
298			state_idx += 8;
299
300			for i in 0..state.block.len() {
301				state_vals[state_idx][compression_offset] = state.block[i];
302				state_idx += 1;
303			}
304
305			// populate input, which consists from initial values of each state_transition
306			for xy in 0..STATE_SIZE {
307				input_vals[xy][compression_idx] = state_vals[xy][compression_offset];
308			}
309
310			assert_eq!(state_idx, STATE_SIZE);
311
312			// we start from 1, since initial state is at 0
313			let mut state_offset = 1usize;
314			let mut temp_vars_offset = 0usize;
315
316			fn add(a: u32, b: u32) -> (u32, u32, u32) {
317				let zout;
318				let carry;
319
320				(zout, carry) = a.overflowing_add(b);
321				let cin = a ^ b ^ zout;
322				let cout = ((carry as u32) << 31) | (cin >> 1);
323
324				(cin, cout, zout)
325			}
326
327			// state transition
328			for round_idx in 0..7 {
329				for j in 0..8 {
330					let state_transition_idx = state_offset + compression_offset;
331					let var_offset = temp_vars_offset + compression_offset;
332					let mut add_offset = 0usize;
333
334					// column-wise copy of the previous state to the next one
335					#[allow(clippy::needless_range_loop)]
336					for i in 0..STATE_SIZE {
337						state_vals[i][state_transition_idx] =
338							state_vals[i][state_transition_idx - 1];
339					}
340
341					// take input from previous state
342					a_in_vals[var_offset] = state_vals[a[j]][state_transition_idx - 1];
343					b_in_vals[var_offset] = state_vals[b[j]][state_transition_idx - 1];
344					c_in_vals[var_offset] = state_vals[c[j]][state_transition_idx - 1];
345					d_in_vals[var_offset] = state_vals[d[j]][state_transition_idx - 1];
346					mx_in_vals[var_offset] = state_vals[mx[j]][state_transition_idx - 1];
347					my_in_vals[var_offset] = state_vals[my[j]][state_transition_idx - 1];
348
349					// compute values of temp vars
350
351					(
352						cin_vals[add_offset][var_offset],
353						cout_vals[add_offset][var_offset],
354						a_0_tmp_vals[var_offset],
355					) = add(a_in_vals[var_offset], b_in_vals[var_offset]);
356					add_offset += 1;
357
358					(
359						cin_vals[add_offset][var_offset],
360						cout_vals[add_offset][var_offset],
361						a_0_vals[var_offset],
362					) = add(a_0_tmp_vals[var_offset], mx_in_vals[var_offset]);
363					add_offset += 1;
364
365					d_in_xor_a_0_vals[var_offset] = d_in_vals[var_offset] ^ a_0_vals[var_offset];
366
367					d_0_vals[var_offset] = d_in_xor_a_0_vals[var_offset].rotate_right(16);
368
369					(
370						cin_vals[add_offset][var_offset],
371						cout_vals[add_offset][var_offset],
372						c_0_vals[var_offset],
373					) = add(c_in_vals[var_offset], d_0_vals[var_offset]);
374					add_offset += 1;
375
376					b_in_xor_c_0_vals[var_offset] = b_in_vals[var_offset] ^ c_0_vals[var_offset];
377
378					b_0_vals[var_offset] = b_in_xor_c_0_vals[var_offset].rotate_right(12);
379
380					(
381						cin_vals[add_offset][var_offset],
382						cout_vals[add_offset][var_offset],
383						a_1_tmp_vals[var_offset],
384					) = add(a_0_vals[var_offset], b_0_vals[var_offset]);
385					add_offset += 1;
386
387					(
388						cin_vals[add_offset][var_offset],
389						cout_vals[add_offset][var_offset],
390						a_1_vals[var_offset],
391					) = add(a_1_tmp_vals[var_offset], my_in_vals[var_offset]);
392					add_offset += 1;
393
394					d_0_xor_a_1_vals[var_offset] = d_0_vals[var_offset] ^ a_1_vals[var_offset];
395
396					d_1_vals[var_offset] = d_0_xor_a_1_vals[var_offset].rotate_right(8);
397
398					(
399						cin_vals[add_offset][var_offset],
400						cout_vals[add_offset][var_offset],
401						c_1_vals[var_offset],
402					) = add(c_0_vals[var_offset], d_1_vals[var_offset]);
403					add_offset += 1;
404
405					b_0_xor_c_1_vals[var_offset] = b_0_vals[var_offset] ^ c_1_vals[var_offset];
406
407					b_1_vals[var_offset] = b_0_xor_c_1_vals[var_offset].rotate_right(7);
408
409					// mutate state
410					state_vals[a[j]][state_transition_idx] = a_1_vals[var_offset];
411					state_vals[b[j]][state_transition_idx] = b_1_vals[var_offset];
412					state_vals[c[j]][state_transition_idx] = c_1_vals[var_offset];
413					state_vals[d[j]][state_transition_idx] = d_1_vals[var_offset];
414
415					state_offset += 1;
416					temp_vars_offset += 1;
417					assert_eq!(add_offset, ADDITION_OPERATIONS_NUMBER);
418				}
419
420				// permutation (just shuffling the indices - no constraining is required)
421				if round_idx < 6 {
422					let mut permuted = [0u32; 16];
423					for i in 0..16 {
424						permuted[i] = state_vals[16 + MSG_PERMUTATION[i]]
425							[state_offset + compression_offset - 1];
426					}
427
428					for i in 0..16 {
429						state_vals[16 + i][state_offset + compression_offset - 1] = permuted[i];
430					}
431				}
432			}
433
434			assert_eq!(state_offset, TEMP_STATE_OUT_INDEX + 1);
435
436			for i in 0..8 {
437				// populate 'cv', 'state[i]' and 'state[i + 8]' columns
438				cv_vals[i * compression_idx + i] = state_vals[i][compression_offset];
439				state_i_vals[i * compression_idx + i] =
440					state_vals[i][state_offset + compression_offset - 1];
441				state_i_8_vals[i * compression_idx + i] =
442					state_vals[i + 8][state_offset + compression_offset - 1];
443
444				// compute 'state[i]' values
445				state_vals[i][state_offset + compression_offset - 1] ^=
446					state_vals[i + 8][state_offset + compression_offset - 1];
447
448				// populate 'state[i] ^ state[i + 8]' linear combination
449				state_i_xor_state_i_8_vals[i * compression_idx + i] =
450					state_vals[i][state_offset + compression_offset - 1];
451
452				// compute 'state[i + 8]' values
453				state_vals[i + 8][state_offset + compression_offset - 1] ^=
454					state_vals[i][compression_offset];
455
456				// populate 'cv ^ state[i + 8]' linear combination
457				cv_oracle_xor_state_i_8_vals[i * compression_idx + i] =
458					state_vals[i + 8][state_offset + compression_offset - 1];
459			}
460
461			// copy final state transition (of the given compression) to the output
462			for i in 0..STATE_SIZE {
463				output_vals[i][compression_idx] =
464					state_vals[i][state_offset + compression_offset - 1];
465			}
466
467			compression_offset += SINGLE_COMPRESSION_HEIGHT;
468		}
469	}
470
471	/* Constraints */
472
473	// TODO: remove this technical constraint (figure out how to properly constrain the 'state_i_8')
474	//builder.assert_zero("state_i_8", [state_i_8], arith_expr!([x] = x - x).convert_field());
475
476	let xins = [a_in, a_0_tmp, c_in, a_0, a_1_tmp, c_0];
477	let yins = [b_in, mx_in, d_0, b_0, my_in, d_1];
478	let zouts = [a_0_tmp, a_0, c_0, a_1_tmp, a_1, c_1];
479
480	for (idx, (xin, (yin, zout))) in xins
481		.into_iter()
482		.zip(yins.into_iter().zip(zouts.into_iter()))
483		.enumerate()
484	{
485		builder.assert_zero(
486			format!("sum{idx}"),
487			[xin, yin, cin[idx], zout],
488			arith_expr!([xin, yin, cin, zout] = xin + yin + cin - zout).convert_field(),
489		);
490
491		builder.assert_zero(
492			format!("carry{idx}"),
493			[xin, yin, cin[idx], cout[idx]],
494			arith_expr!([xin, yin, cin, cout] = (xin + cin) * (yin + cin) + cin - cout)
495				.convert_field(),
496		);
497	}
498
499	Ok(Blake3CompressOracles { input, output })
500}
501
#[cfg(test)]
mod tests {
	use std::array;

	use rand::{rngs::StdRng, Rng, SeedableRng};

	use crate::{
		blake3::{blake3_compress, Blake3CompressState, F32, IV, MSG_PERMUTATION},
		builder::test_utils::test_circuit,
	};

	/// State word indices touched by each of the 8 `G` applications of a round:
	/// `[a, b, c, d, mx, my]`. Columns first, then diagonals, as in the reference
	/// implementation.
	const G_LANES: [[usize; 6]; 8] = [
		[0, 4, 8, 12, 16, 17],
		[1, 5, 9, 13, 18, 19],
		[2, 6, 10, 14, 20, 21],
		[3, 7, 11, 15, 22, 23],
		[0, 5, 10, 15, 24, 25],
		[1, 6, 11, 12, 26, 27],
		[2, 7, 8, 13, 28, 29],
		[3, 4, 9, 14, 30, 31],
	];

	/// Plain-Rust BLAKE3 compression used as the expected value, adapted from the reference
	/// implementation:
	/// https://github.com/BLAKE3-team/BLAKE3/blob/master/reference_impl/reference_impl.rs
	///
	/// Returns the 16 finalised state words.
	fn compress(
		chaining_value: &[u32; 8],
		block_words: &[u32; 16],
		counter: u64,
		block_len: u32,
		flags: u32,
	) -> [u32; 16] {
		// the message words are kept in the upper half of the working state
		let mut state = [0u32; 32];
		state[..8].copy_from_slice(chaining_value);
		state[8..12].copy_from_slice(&IV[..4]);
		state[12] = counter as u32;
		state[13] = (counter >> 32) as u32;
		state[14] = block_len;
		state[15] = flags;
		state[16..].copy_from_slice(block_words);

		// 7 rounds, each applying G to every lane and then permuting the message (except last)
		for round_idx in 0..7 {
			for [ia, ib, ic, id, imx, imy] in G_LANES {
				let (mx_in, my_in) = (state[imx], state[imy]);

				let a_0 = state[ia].wrapping_add(state[ib]).wrapping_add(mx_in);
				let d_0 = (state[id] ^ a_0).rotate_right(16);
				let c_0 = state[ic].wrapping_add(d_0);
				let b_0 = (state[ib] ^ c_0).rotate_right(12);

				let a_1 = a_0.wrapping_add(b_0).wrapping_add(my_in);
				let d_1 = (d_0 ^ a_1).rotate_right(8);
				let c_1 = c_0.wrapping_add(d_1);
				let b_1 = (b_0 ^ c_1).rotate_right(7);

				state[ia] = a_1;
				state[ib] = b_1;
				state[ic] = c_1;
				state[id] = d_1;
			}

			if round_idx < 6 {
				let permuted: [u32; 16] = array::from_fn(|i| state[16 + MSG_PERMUTATION[i]]);
				state[16..].copy_from_slice(&permuted);
			}
		}

		for i in 0..8 {
			state[i] ^= state[i + 8];
			state[i + 8] ^= chaining_value[i];
		}

		array::from_fn(|i| state[i])
	}

	#[test]
	fn test_blake3_compression() {
		test_circuit(|builder| {
			let compressions = 8;
			let mut rng = StdRng::seed_from_u64(0);

			let mut expected_rows = Vec::with_capacity(compressions);
			let mut states = Vec::with_capacity(compressions);
			for _ in 0..compressions {
				// random draws happen in a fixed order: cv, block, counter, block_len, flags
				let cv: [u32; 8] = array::from_fn(|_| rng.gen::<u32>());
				let block: [u32; 16] = array::from_fn(|_| rng.gen::<u32>());
				let counter = rng.gen::<u64>();
				let block_len = rng.gen::<u32>();
				let flags = rng.gen::<u32>();

				// remember the reference result to compare against the witness later
				expected_rows.push(compress(&cv, &block, counter, block_len, flags).to_vec());
				states.push(Blake3CompressState {
					cv,
					block,
					counter_low: counter as u32,
					counter_high: (counter >> 32) as u32,
					block_len,
					flags,
				});
			}

			// expected[w] lists output word `w` of every compression, matching column layout
			let expected = transpose(expected_rows);

			let states_len = states.len();
			let state_out = blake3_compress(builder, &Some(states), states_len)?;
			if let Some(witness) = builder.witness() {
				for (oracle, expected_words) in state_out.output.iter().zip(expected) {
					let actual = witness.get::<F32>(*oracle).unwrap().as_slice::<u32>();
					let len = expected_words.len();
					assert_eq!(actual[..len], expected_words);
				}
			}
			Ok(vec![])
		})
		.unwrap();
	}

	/// Swaps rows and columns of a non-empty, rectangular matrix.
	fn transpose<T: Clone>(rows: Vec<Vec<T>>) -> Vec<Vec<T>> {
		assert!(!rows.is_empty());
		let width = rows[0].len();
		(0..width)
			.map(|col| rows.iter().map(|row| row[col].clone()).collect())
			.collect()
	}
}