1use std::array;
4
5use binius_core::oracle::{OracleId, ShiftVariant};
6use binius_field::{BinaryField1b, Field};
7
8use crate::{
9 arithmetic::{
10 self,
11 u32::{u32const_repeating, LOG_U32_BITS},
12 Flags,
13 },
14 bitwise,
15 builder::{types::F, ConstraintSystemBuilder},
16};
17
18type F1 = BinaryField1b;
19pub const CHAINING_VALUE_LEN: usize = 8;
20pub const BLAKE3_STATE_LEN: usize = 16;
21const MSG_PERMUTATION: [usize; BLAKE3_STATE_LEN] =
22 [2, 6, 3, 10, 7, 0, 4, 13, 1, 11, 12, 5, 9, 14, 15, 8];
23const IV_0_4: [u32; 4] = [0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A];
24
25fn xor_rotate_right(
27 builder: &mut ConstraintSystemBuilder,
28 name: impl ToString,
29 log_size: usize,
30 a: OracleId,
31 b: OracleId,
32 rotate_right_offset: u32,
33) -> Result<OracleId, anyhow::Error> {
34 assert!(rotate_right_offset <= 32);
35
36 builder.push_namespace(name);
37
38 let xor = builder
39 .add_linear_combination("xor", log_size, [(a, F::ONE), (b, F::ONE)])
40 .unwrap();
41
42 let rotate = builder.add_shifted(
43 "rotate",
44 xor,
45 32 - rotate_right_offset as usize,
46 LOG_U32_BITS,
47 ShiftVariant::CircularLeft,
48 )?;
49
50 if let Some(witness) = builder.witness() {
51 let a_value = witness.get::<F1>(a)?.as_slice::<u32>();
52 let b_value = witness.get::<F1>(b)?.as_slice::<u32>();
53
54 let mut xor_witness = witness.new_column::<F1>(xor);
55 let xor_value = xor_witness.as_mut_slice::<u32>();
56
57 for (idx, v) in xor_value.iter_mut().enumerate() {
58 *v = a_value[idx] ^ b_value[idx];
59 }
60
61 let mut rotate_witness = witness.new_column::<F1>(rotate);
62 let rotate_value = rotate_witness.as_mut_slice::<u32>();
63 for (idx, v) in rotate_value.iter_mut().enumerate() {
64 *v = xor_value[idx].rotate_right(rotate_right_offset);
65 }
66 }
67
68 builder.pop_namespace();
69
70 Ok(rotate)
71}
72
73#[allow(clippy::too_many_arguments)]
75pub fn g(
76 builder: &mut ConstraintSystemBuilder,
77 name: impl ToString,
78 a_in: OracleId,
79 b_in: OracleId,
80 c_in: OracleId,
81 d_in: OracleId,
82 mx: OracleId,
83 my: OracleId,
84 log_size: usize,
85) -> Result<[OracleId; 4], anyhow::Error> {
86 builder.push_namespace(name);
87
88 let ab = arithmetic::u32::add(builder, "a_in + b_in", a_in, b_in, Flags::Unchecked)?;
89 let a1 = arithmetic::u32::add(builder, "a_in + b_in + mx", ab, mx, Flags::Unchecked)?;
90
91 let d1 = xor_rotate_right(builder, "(d_in ^ a1).rotate_right(16)", log_size, d_in, a1, 16u32)?;
92
93 let c1 = arithmetic::u32::add(builder, "c_in + d1", c_in, d1, Flags::Unchecked)?;
94
95 let b1 = xor_rotate_right(builder, "(b_in ^ c1).rotate_right(12)", log_size, b_in, c1, 12u32)?;
96
97 let a1b1 = arithmetic::u32::add(builder, "a1 + b1", a1, b1, Flags::Unchecked)?;
98 let a2 = arithmetic::u32::add(builder, "a1 + b1 + my_in", a1b1, my, Flags::Unchecked)?;
99
100 let d2 = xor_rotate_right(builder, "(d1 ^ a2).rotate_right(8)", log_size, d1, a2, 8u32)?;
101
102 let c2 = arithmetic::u32::add(builder, "c1 + d2", c1, d2, Flags::Unchecked)?;
103
104 let b2 = xor_rotate_right(builder, "(b1 ^ c2).rotate_right(7)", log_size, b1, c2, 7u32)?;
105
106 builder.pop_namespace();
107
108 Ok([a2, b2, c2, d2])
109}
110
111pub fn round(
113 builder: &mut ConstraintSystemBuilder,
114 name: impl ToString,
115 state: &[OracleId; BLAKE3_STATE_LEN],
116 m: &[OracleId; BLAKE3_STATE_LEN],
117 log_size: usize,
118) -> Result<[OracleId; BLAKE3_STATE_LEN], anyhow::Error> {
119 builder.push_namespace(name);
120
121 let [s0, s4, s8, s12] =
123 g(builder, "mix-columns-0", state[0], state[4], state[8], state[12], m[0], m[1], log_size)?;
124
125 let [s1, s5, s9, s13] =
126 g(builder, "mix-columns-1", state[1], state[5], state[9], state[13], m[2], m[3], log_size)?;
127 #[rustfmt::skip]
128 let [s2, s6, s10, s14] =
129 g(builder, "mix-columns-2", state[2], state[6], state[10], state[14], m[4], m[5], log_size)?;
130 #[rustfmt::skip]
131 let [s3, s7, s11, s15] =
132 g(builder, "mix-columns-3", state[3], state[7], state[11], state[15], m[6], m[7], log_size)?;
133
134 let [s0, s5, s10, s15] = g(builder, "mix-diagonals-0", s0, s5, s10, s15, m[8], m[9], log_size)?;
136 #[rustfmt::skip]
137 let [s1, s6, s11, s12] = g(builder, "mix-diagonals-1", s1, s6, s11, s12, m[10], m[11], log_size)?;
138
139 let [s2, s7, s8, s13] = g(builder, "mix-diagonals-2", s2, s7, s8, s13, m[12], m[13], log_size)?;
140
141 let [s3, s4, s9, s14] = g(builder, "mix-diagonals-3", s3, s4, s9, s14, m[14], m[15], log_size)?;
142
143 builder.pop_namespace();
144
145 Ok([
146 s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, s14, s15,
147 ])
148}
149
150pub fn permute(block_words: &[OracleId; BLAKE3_STATE_LEN]) -> [OracleId; BLAKE3_STATE_LEN] {
151 array::from_fn(|i| block_words[MSG_PERMUTATION[i]])
152}
153
154#[allow(clippy::too_many_arguments)]
156pub fn compress(
157 builder: &mut ConstraintSystemBuilder,
158 name: impl ToString,
159 chaining_value: &[OracleId; CHAINING_VALUE_LEN],
160 block_words: &[OracleId; BLAKE3_STATE_LEN],
161 counter: u64,
162 block_len: u32,
163 flags: u32,
164 log_size: usize,
165) -> Result<[OracleId; BLAKE3_STATE_LEN], anyhow::Error> {
166 builder.push_namespace(name);
167
168 let counter_low = counter as u32;
169 let counter_high = (counter >> 32) as u32;
170
171 let mut state = [OracleId::MAX; BLAKE3_STATE_LEN];
172 state[0..8].copy_from_slice(chaining_value);
173
174 let iv_oracles =
175 IV_0_4.map(|val| u32const_repeating(log_size, builder, val, "blake3_iv").unwrap());
176
177 state[8..12].copy_from_slice(&iv_oracles);
178
179 state[12] = u32const_repeating(log_size, builder, counter_low, "counter_low")?;
180
181 state[13] = u32const_repeating(log_size, builder, counter_high, "counter_high")?;
182
183 state[14] = u32const_repeating(log_size, builder, block_len, "block_len")?;
184
185 state[15] = u32const_repeating(log_size, builder, flags, "flags")?;
186
187 let new_state = round(builder, "round_1", &state, block_words, log_size)?;
188 let new_block_words = permute(block_words);
189
190 let new_state = round(builder, "round_2", &new_state, &new_block_words, log_size)?;
191 let new_block_words = permute(&new_block_words);
192
193 let new_state = round(builder, "round_3", &new_state, &new_block_words, log_size)?;
194 let new_block_words = permute(&new_block_words);
195
196 let new_state = round(builder, "round_4", &new_state, &new_block_words, log_size)?;
197 let new_block_words = permute(&new_block_words);
198
199 let new_state = round(builder, "round_5", &new_state, &new_block_words, log_size)?;
200 let new_block_words = permute(&new_block_words);
201
202 let new_state = round(builder, "round_6", &new_state, &new_block_words, log_size)?;
203 let new_block_words = permute(&new_block_words);
204
205 let pre_final_state = round(builder, "round_7", &new_state, &new_block_words, log_size)?;
206
207 let final_state_left = (0..8)
208 .map(|idx| {
209 bitwise::xor(builder, "final_state_0_8", pre_final_state[idx], pre_final_state[idx + 8])
210 .unwrap()
211 })
212 .collect::<Vec<OracleId>>();
213
214 let final_state_right = (0..8)
215 .map(|idx| {
216 bitwise::xor(builder, "final_state_8_16", pre_final_state[idx + 8], chaining_value[idx])
217 .unwrap()
218 })
219 .collect::<Vec<OracleId>>();
220
221 builder.pop_namespace();
222
223 Ok([final_state_left, final_state_right]
224 .concat()
225 .try_into()
226 .unwrap())
227}
228
#[cfg(test)]
mod tests {
    use std::array;

    use binius_core::{constraint_system::validate::validate_witness, oracle::OracleId};
    use binius_field::BinaryField1b;
    use binius_maybe_rayon::prelude::*;

    use crate::{
        blake3::{
            compress, g, round, BLAKE3_STATE_LEN, CHAINING_VALUE_LEN, IV_0_4, MSG_PERMUTATION,
        },
        builder::ConstraintSystemBuilder,
        unconstrained::{fixed_u32, unconstrained},
    };

    type F1 = BinaryField1b;

    /// log2 of the number of rows in every column created by these tests.
    const LOG_SIZE: usize = 10;

    /// Native reference implementation of the BLAKE3 quarter-round `G`.
    const fn g_out_of_circuit(
        a_in: u32,
        b_in: u32,
        c_in: u32,
        d_in: u32,
        mx: u32,
        my: u32,
    ) -> (u32, u32, u32, u32) {
        let a1 = a_in.wrapping_add(b_in).wrapping_add(mx);
        let d1 = (d_in ^ a1).rotate_right(16);
        let c1 = c_in.wrapping_add(d1);
        let b1 = (b_in ^ c1).rotate_right(12);

        let a2 = a1.wrapping_add(b1).wrapping_add(my);
        let d2 = (d1 ^ a2).rotate_right(8);
        let c2 = c1.wrapping_add(d2);
        let b2 = (b1 ^ c2).rotate_right(7);

        (a2, b2, c2, d2)
    }

    #[test]
    fn test_vector_g() {
        let a = 0xaaaaaaaau32;
        let b = 0xbbbbbbbbu32;
        let c = 0xccccccccu32;
        let d = 0xddddddddu32;
        let mx = 0xffff00ffu32;
        let my = 0xff00ffffu32;

        let (expected_0, expected_1, expected_2, expected_3) = g_out_of_circuit(a, b, c, d, mx, my);

        let size = 1 << LOG_SIZE;

        let allocator = bumpalo::Bump::new();
        let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator);

        let a_in = fixed_u32::<F1>(&mut builder, "a", LOG_SIZE, vec![a; size]).unwrap();
        let b_in = fixed_u32::<F1>(&mut builder, "b", LOG_SIZE, vec![b; size]).unwrap();
        let c_in = fixed_u32::<F1>(&mut builder, "c", LOG_SIZE, vec![c; size]).unwrap();
        let d_in = fixed_u32::<F1>(&mut builder, "d", LOG_SIZE, vec![d; size]).unwrap();
        let mx_in = fixed_u32::<F1>(&mut builder, "mx", LOG_SIZE, vec![mx; size]).unwrap();
        let my_in = fixed_u32::<F1>(&mut builder, "my", LOG_SIZE, vec![my; size]).unwrap();

        let output = g(&mut builder, "g", a_in, b_in, c_in, d_in, mx_in, my_in, LOG_SIZE).unwrap();

        {
            // The builder was created with a witness; fail loudly rather than silently
            // skipping the comparison if it is ever missing.
            let witness = builder
                .witness()
                .expect("builder created with new_with_witness must carry a witness");
            (
                witness.get::<F1>(output[0]).unwrap().as_slice::<u32>(),
                witness.get::<F1>(output[1]).unwrap().as_slice::<u32>(),
                witness.get::<F1>(output[2]).unwrap().as_slice::<u32>(),
                witness.get::<F1>(output[3]).unwrap().as_slice::<u32>(),
            )
                .into_par_iter()
                .for_each(|(actual_0, actual_1, actual_2, actual_3)| {
                    assert_eq!(*actual_0, expected_0);
                    assert_eq!(*actual_1, expected_1);
                    assert_eq!(*actual_2, expected_2);
                    assert_eq!(*actual_3, expected_3);
                });
        }

        let witness = builder.take_witness().unwrap();
        let constraints_system = builder.build().unwrap();

        validate_witness(&constraints_system, &[], &witness).unwrap();
    }

    #[test]
    fn test_random_input_g() {
        let allocator = bumpalo::Bump::new();
        let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator);

        let a_in = unconstrained::<F1>(&mut builder, "a", LOG_SIZE).unwrap();
        let b_in = unconstrained::<F1>(&mut builder, "b", LOG_SIZE).unwrap();
        let c_in = unconstrained::<F1>(&mut builder, "c", LOG_SIZE).unwrap();
        let d_in = unconstrained::<F1>(&mut builder, "d", LOG_SIZE).unwrap();
        let mx_in = unconstrained::<F1>(&mut builder, "mx", LOG_SIZE).unwrap();
        let my_in = unconstrained::<F1>(&mut builder, "my", LOG_SIZE).unwrap();

        g(&mut builder, "g", a_in, b_in, c_in, d_in, mx_in, my_in, LOG_SIZE).unwrap();

        let witness = builder.take_witness().unwrap();
        let constraints_system = builder.build().unwrap();

        validate_witness(&constraints_system, &[], &witness).unwrap();
    }

    /// Native reference implementation of one BLAKE3 round (columns, then diagonals).
    const fn round_out_of_circuit(
        state: &[u32; BLAKE3_STATE_LEN],
        m: &[u32; BLAKE3_STATE_LEN],
    ) -> [u32; BLAKE3_STATE_LEN] {
        let (s0, s4, s8, s12) =
            g_out_of_circuit(state[0], state[4], state[8], state[12], m[0], m[1]);
        let (s1, s5, s9, s13) =
            g_out_of_circuit(state[1], state[5], state[9], state[13], m[2], m[3]);
        let (s2, s6, s10, s14) =
            g_out_of_circuit(state[2], state[6], state[10], state[14], m[4], m[5]);
        let (s3, s7, s11, s15) =
            g_out_of_circuit(state[3], state[7], state[11], state[15], m[6], m[7]);

        let (s0, s5, s10, s15) = g_out_of_circuit(s0, s5, s10, s15, m[8], m[9]);
        let (s1, s6, s11, s12) = g_out_of_circuit(s1, s6, s11, s12, m[10], m[11]);
        let (s2, s7, s8, s13) = g_out_of_circuit(s2, s7, s8, s13, m[12], m[13]);
        let (s3, s4, s9, s14) = g_out_of_circuit(s3, s4, s9, s14, m[14], m[15]);

        [
            s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, s14, s15,
        ]
    }

    #[test]
    fn test_vector_round() {
        let state = [
            0xfffffff0, 0xfffffff1, 0xfffffff2, 0xfffffff3, 0xfffffff4, 0xfffffff5, 0xfffffff6,
            0xfffffff7, 0xfffffff8, 0xfffffff9, 0xfffffffa, 0xfffffffb, 0xfffffffc, 0xfffffffd,
            0xfffffffe, 0xffffffff,
        ];

        let m = [
            0x09ffffff, 0x08ffffff, 0x07ffffff, 0x06ffffff, 0x05ffffff, 0x04ffffff, 0x03ffffff,
            0x02ffffff, 0x01ffffff, 0x00ffffff, 0x0fffffff, 0x0effffff, 0x0dffffff, 0x0cffffff,
            0x0bffffff, 0x0affffff,
        ];

        assert_eq!(state.len(), BLAKE3_STATE_LEN);
        assert_eq!(state.len(), m.len());

        let expected = round_out_of_circuit(&state, &m);

        let size = 1 << LOG_SIZE;

        let allocator = bumpalo::Bump::new();
        let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator);

        let state: [OracleId; BLAKE3_STATE_LEN] = array::from_fn(|idx| {
            fixed_u32::<F1>(&mut builder, format!("state{}", idx), LOG_SIZE, vec![state[idx]; size])
                .unwrap()
        });

        let m: [OracleId; BLAKE3_STATE_LEN] = array::from_fn(|idx| {
            fixed_u32::<F1>(&mut builder, format!("m{}", idx), LOG_SIZE, vec![m[idx]; size])
                .unwrap()
        });

        let actual = round(&mut builder, "round", &state, &m, LOG_SIZE).unwrap();

        {
            let witness = builder
                .witness()
                .expect("builder created with new_with_witness must carry a witness");
            for (i, expected_i) in expected.into_iter().enumerate() {
                let values = witness.get::<F1>(actual[i]).unwrap().as_slice::<u32>();
                assert!(
                    values.iter().all(|v| *v == expected_i),
                    "round output word {i} does not match the reference"
                );
            }
        }

        let witness = builder.take_witness().unwrap();
        let constraints_system = builder.build().unwrap();

        validate_witness(&constraints_system, &[], &witness).unwrap();
    }

    #[test]
    fn test_random_input_round() {
        let allocator = bumpalo::Bump::new();
        let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator);

        let state: [OracleId; BLAKE3_STATE_LEN] = array::from_fn(|idx| {
            unconstrained::<F1>(&mut builder, format!("state{}", idx), LOG_SIZE).unwrap()
        });

        let m: [OracleId; BLAKE3_STATE_LEN] = array::from_fn(|idx| {
            unconstrained::<F1>(&mut builder, format!("m{}", idx), LOG_SIZE).unwrap()
        });

        round(&mut builder, "round", &state, &m, LOG_SIZE).unwrap();

        let witness = builder.take_witness().unwrap();
        let constraints_system = builder.build().unwrap();

        validate_witness(&constraints_system, &[], &witness).unwrap();
    }

    /// Native reference implementation of the BLAKE3 compression function.
    fn compress_out_of_circuit(
        chaining_value: &[u32; 8],
        block_words: &[u32; 16],
        counter: u64,
        block_len: u32,
        flags: u32,
    ) -> [u32; 16] {
        fn permute(m: &mut [u32; 16]) {
            *m = MSG_PERMUTATION.map(|src| m[src]);
        }

        let counter_low = counter as u32;
        let counter_high = (counter >> 32) as u32;

        let mut state = [
            chaining_value[0],
            chaining_value[1],
            chaining_value[2],
            chaining_value[3],
            chaining_value[4],
            chaining_value[5],
            chaining_value[6],
            chaining_value[7],
            IV_0_4[0],
            IV_0_4[1],
            IV_0_4[2],
            IV_0_4[3],
            counter_low,
            counter_high,
            block_len,
            flags,
        ];
        let mut block = *block_words;

        // Seven rounds; the message is permuted between consecutive rounds.
        for round_idx in 0..7 {
            if round_idx > 0 {
                permute(&mut block);
            }
            state = round_out_of_circuit(&state, &block);
        }

        for i in 0..8 {
            state[i] ^= state[i + 8];
            state[i + 8] ^= chaining_value[i];
        }
        state
    }

    #[test]
    fn test_vector_compress() {
        let chaining_value = [
            0xfffffff0, 0xfffffff1, 0xfffffff2, 0xfffffff3, 0xfffffff4, 0xfffffff5, 0xfffffff6,
            0xfffffff7,
        ];

        let m = [
            0x09ffffff, 0x08ffffff, 0x07ffffff, 0x06ffffff, 0x05ffffff, 0x04ffffff, 0x03ffffff,
            0x02ffffff, 0x01ffffff, 0x00ffffff, 0x0fffffff, 0x0effffff, 0x0dffffff, 0x0cffffff,
            0x0bffffff, 0x0affffff,
        ];

        let counter = u64::MAX;
        let block_len = u32::MAX;
        let flags = u32::MAX;

        let expected = compress_out_of_circuit(&chaining_value, &m, counter, block_len, flags);

        let size = 1 << LOG_SIZE;
        let allocator = bumpalo::Bump::new();
        let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator);

        let chaining_value: [OracleId; CHAINING_VALUE_LEN] = array::from_fn(|idx| {
            fixed_u32::<F1>(
                &mut builder,
                format!("cv{}", idx),
                LOG_SIZE,
                vec![chaining_value[idx]; size],
            )
            .unwrap()
        });

        let block_words: [OracleId; BLAKE3_STATE_LEN] = array::from_fn(|idx| {
            fixed_u32::<F1>(&mut builder, format!("block{}", idx), LOG_SIZE, vec![m[idx]; size])
                .unwrap()
        });

        let actual = compress(
            &mut builder,
            "compress",
            &chaining_value,
            &block_words,
            counter,
            block_len,
            flags,
            LOG_SIZE,
        )
        .unwrap();

        {
            let witness = builder
                .witness()
                .expect("builder created with new_with_witness must carry a witness");
            for (i, expected_i) in expected.into_iter().enumerate() {
                let values = witness.get::<F1>(actual[i]).unwrap().as_slice::<u32>();
                assert!(
                    values.iter().all(|v| *v == expected_i),
                    "compress output word {i} does not match the reference"
                );
            }
        }

        let witness = builder.take_witness().unwrap();
        let constraints_system = builder.build().unwrap();

        validate_witness(&constraints_system, &[], &witness).unwrap();
    }

    #[test]
    fn test_random_input_compress() {
        let allocator = bumpalo::Bump::new();
        let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator);

        let chaining_values: [OracleId; CHAINING_VALUE_LEN] = array::from_fn(|idx| {
            unconstrained::<F1>(&mut builder, format!("cv{}", idx), LOG_SIZE).unwrap()
        });

        let block_words: [OracleId; BLAKE3_STATE_LEN] = array::from_fn(|idx| {
            unconstrained::<F1>(&mut builder, format!("block{}", idx), LOG_SIZE).unwrap()
        });

        let counter = u64::MAX;
        let block_len = u32::MAX;
        let flags = u32::MAX;

        compress(
            &mut builder,
            "compress",
            &chaining_values,
            &block_words,
            counter,
            block_len,
            flags,
            LOG_SIZE,
        )
        .unwrap();

        let witness = builder.take_witness().unwrap();
        let constraints_system = builder.build().unwrap();

        validate_witness(&constraints_system, &[], &witness).unwrap();
    }
}