binius_circuits/
sha256.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use binius_core::oracle::{OracleId, ShiftVariant};
4use binius_field::{as_packed_field::PackedType, BinaryField1b, Field, TowerField};
5use binius_macros::arith_expr;
6use itertools::izip;
7
8use crate::{
9	arithmetic,
10	arithmetic::u32::{u32const_repeating, LOG_U32_BITS},
11	builder::{types::U, ConstraintSystemBuilder},
12};
13
14type B1 = BinaryField1b;
15
/// The 64 per-round words `K[0..64]` of the SHA-256 compression function.
///
/// Entry `i` is materialised as a repeating constant oracle and added into the
/// working state in round `i` of `sha256`.
#[rustfmt::skip]
pub const ROUND_CONSTS_K: [u32; 64] = [
	// rounds 0..16
	0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5,
	0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5,
	0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3,
	0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174,
	// rounds 16..32
	0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc,
	0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da,
	0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7,
	0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967,
	// rounds 32..48
	0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13,
	0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85,
	0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3,
	0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070,
	// rounds 48..64
	0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5,
	0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
	0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208,
	0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2,
];
27
/// Initial SHA-256 chaining value: the eight 32-bit words that seed the working
/// state `a..h` and are added back to it after the final round.
#[rustfmt::skip]
pub const INIT: [u32; 8] = [
	0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a,
	0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19,
];
31
/// Selects how `rotate_and_xor` moves the bits of each 32-bit word to the right.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RotateRightType {
	/// Rotation: bits shifted out on the right re-enter on the left.
	Circular,
	/// Logical shift: bits shifted out on the right are dropped and the vacated
	/// high bits become zero.
	Logical,
}
36
37pub fn rotate_and_xor(
38	log_size: usize,
39	builder: &mut ConstraintSystemBuilder,
40	r: &[(OracleId, usize, RotateRightType)],
41) -> Result<OracleId, anyhow::Error> {
42	let shifted_oracle_ids = r
43		.iter()
44		.map(|(oracle_id, shift, t)| {
45			match t {
46				RotateRightType::Circular => builder.add_shifted(
47					format!("RotateRightType::Circular shift:{} oracle_id: {}", shift, oracle_id),
48					*oracle_id,
49					32 - shift,
50					LOG_U32_BITS,
51					ShiftVariant::CircularLeft,
52				),
53				RotateRightType::Logical => builder.add_shifted(
54					format!("RotateRightType::Logical shift:{} oracle_id: {}", shift, oracle_id),
55					*oracle_id,
56					*shift,
57					LOG_U32_BITS,
58					ShiftVariant::LogicalRight,
59				),
60			}
61			.map_err(|e| e.into())
62		})
63		.collect::<Result<Vec<_>, anyhow::Error>>()?;
64
65	let result_oracle_id = builder.add_linear_combination(
66		format!("linear combination of {:?}", shifted_oracle_ids),
67		log_size,
68		shifted_oracle_ids.iter().map(|s| (*s, Field::ONE)),
69	)?;
70
71	if let Some(witness) = builder.witness() {
72		let mut result_witness = witness.new_column::<B1>(result_oracle_id);
73		let result_u32 = result_witness.as_mut_slice::<u32>();
74
75		for ((oracle_id, shift, t), shifted_oracle_id) in r.iter().zip(&shifted_oracle_ids) {
76			let values_u32 = witness.get::<B1>(*oracle_id)?.as_slice::<u32>();
77
78			let mut shifted_witness = witness.new_column::<B1>(*shifted_oracle_id);
79			let shifted_u32 = shifted_witness.as_mut_slice::<u32>();
80
81			izip!(shifted_u32.iter_mut(), values_u32, result_u32.iter_mut()).for_each(
82				|(shifted, val, res)| {
83					*shifted = match t {
84						RotateRightType::Circular => val.rotate_right(*shift as u32),
85						RotateRightType::Logical => val >> shift,
86					};
87					*res ^= *shifted;
88				},
89			);
90		}
91	}
92
93	Ok(result_oracle_id)
94}
95
96pub fn sha256(
97	builder: &mut ConstraintSystemBuilder,
98	input: [OracleId; 16],
99	log_size: usize,
100) -> Result<[OracleId; 8], anyhow::Error> {
101	if log_size < <PackedType<U, BinaryField1b>>::LOG_WIDTH {
102		Err(anyhow::Error::msg("log_size too small"))?
103	}
104
105	let mut w = [OracleId::MAX; 64];
106
107	w[0..16].copy_from_slice(&input);
108
109	for i in 16..64 {
110		let s0 = rotate_and_xor(
111			log_size,
112			builder,
113			&[
114				(w[i - 15], 7, RotateRightType::Circular),
115				(w[i - 15], 18, RotateRightType::Circular),
116				(w[i - 15], 3, RotateRightType::Logical),
117			],
118		)?;
119		let s1 = rotate_and_xor(
120			log_size,
121			builder,
122			&[
123				(w[i - 2], 17, RotateRightType::Circular),
124				(w[i - 2], 19, RotateRightType::Circular),
125				(w[i - 2], 10, RotateRightType::Logical),
126			],
127		)?;
128		let w_addition = arithmetic::u32::add(
129			builder,
130			"w_addition",
131			w[i - 16],
132			w[i - 7],
133			arithmetic::Flags::Unchecked,
134		)?;
135		let s_addition =
136			arithmetic::u32::add(builder, "s_addition", s0, s1, arithmetic::Flags::Unchecked)?;
137
138		w[i] = arithmetic::u32::add(
139			builder,
140			format!("w[{}]", i),
141			w_addition,
142			s_addition,
143			arithmetic::Flags::Unchecked,
144		)?;
145	}
146
147	let init_oracles = INIT.map(|val| u32const_repeating(log_size, builder, val, "INIT").unwrap());
148
149	let [mut a, mut b, mut c, mut d, mut e, mut f, mut g, mut h] = init_oracles;
150
151	let k = ROUND_CONSTS_K
152		.map(|val| u32const_repeating(log_size, builder, val, "ROUND_CONSTS_K").unwrap());
153
154	let ch: [OracleId; 64] = builder.add_committed_multiple("ch", log_size, B1::TOWER_LEVEL);
155
156	let maj: [OracleId; 64] = builder.add_committed_multiple("maj", log_size, B1::TOWER_LEVEL);
157
158	for i in 0..64 {
159		let sigma1 = rotate_and_xor(
160			log_size,
161			builder,
162			&[
163				(e, 6, RotateRightType::Circular),
164				(e, 11, RotateRightType::Circular),
165				(e, 25, RotateRightType::Circular),
166			],
167		)?;
168
169		if let Some(witness) = builder.witness() {
170			let mut ch_witness = witness.new_column::<B1>(ch[i]);
171			let ch_u32 = ch_witness.as_mut_slice::<u32>();
172			let e_u32 = witness.get::<B1>(e)?.as_slice::<u32>();
173			let f_u32 = witness.get::<B1>(f)?.as_slice::<u32>();
174			let g_u32 = witness.get::<B1>(g)?.as_slice::<u32>();
175			izip!(ch_u32.iter_mut(), e_u32, f_u32, g_u32).for_each(|(ch, e, f, g)| {
176				*ch = g ^ (e & (f ^ g));
177			});
178		}
179
180		let h_sigma1 =
181			arithmetic::u32::add(builder, "h_sigma1", h, sigma1, arithmetic::Flags::Unchecked)?;
182		let ch_ki =
183			arithmetic::u32::add(builder, "ch_ki", ch[i], k[i], arithmetic::Flags::Unchecked)?;
184		let ch_ki_w_i =
185			arithmetic::u32::add(builder, "ch_ki_w_i", ch_ki, w[i], arithmetic::Flags::Unchecked)?;
186		let temp1 = arithmetic::u32::add(
187			builder,
188			"temp1",
189			h_sigma1,
190			ch_ki_w_i,
191			arithmetic::Flags::Unchecked,
192		)?;
193
194		let sigma0 = rotate_and_xor(
195			log_size,
196			builder,
197			&[
198				(a, 2, RotateRightType::Circular),
199				(a, 13, RotateRightType::Circular),
200				(a, 22, RotateRightType::Circular),
201			],
202		)?;
203
204		if let Some(witness) = builder.witness() {
205			let mut maj_witness = witness.new_column::<B1>(maj[i]);
206			let maj_u32 = maj_witness.as_mut_slice::<u32>();
207			let a_u32 = witness.get::<B1>(a)?.as_slice::<u32>();
208			let b_u32 = witness.get::<B1>(b)?.as_slice::<u32>();
209			let c_u32 = witness.get::<B1>(c)?.as_slice::<u32>();
210			izip!(maj_u32.iter_mut(), a_u32, b_u32, c_u32).for_each(|(maj, a, b, c)| {
211				*maj = (a & (b ^ c)) ^ (b & c);
212			});
213		}
214
215		let temp2 =
216			arithmetic::u32::add(builder, "temp2", sigma0, maj[i], arithmetic::Flags::Unchecked)?;
217
218		// Optimization:
219		// (e * f + (1 - e) * g) can be replaced with (g + e * (f + g))
220		// (a * b + a * c + b * c) can be replaced with (a * (b + c) + b * c)
221		// Reference: https://x.com/bartolomeo_diaz/status/1866788688799080922
222		builder.assert_zero(
223			format!("ch_{i}"),
224			[e, f, g, ch[i]],
225			arith_expr!([e, f, g, ch] = (g + e * (f + g)) - ch).convert_field(),
226		);
227
228		builder.assert_zero(
229			format!("maj_{i}"),
230			[a, b, c, maj[i]],
231			arith_expr!([a, b, c, maj] = maj - (a * (b + c)) + b * c).convert_field(),
232		);
233
234		h = g;
235		g = f;
236		f = e;
237		e = arithmetic::u32::add(builder, "e", d, temp1, arithmetic::Flags::Unchecked)?;
238		d = c;
239		c = b;
240		b = a;
241		a = arithmetic::u32::add(builder, "a", temp1, temp2, arithmetic::Flags::Unchecked)?;
242	}
243
244	let abcdefgh = [a, b, c, d, e, f, g, h];
245
246	let output = std::array::from_fn(|i| {
247		arithmetic::u32::add(
248			builder,
249			"output",
250			init_oracles[i],
251			abcdefgh[i],
252			arithmetic::Flags::Unchecked,
253		)
254		.unwrap()
255	});
256
257	Ok(output)
258}
259
#[cfg(test)]
mod tests {
	use binius_core::oracle::OracleId;
	use binius_field::{as_packed_field::PackedType, BinaryField1b};
	use sha2::{compress256, digest::generic_array::GenericArray};

	use crate::{
		builder::{test_utils::test_circuit, types::U},
		unconstrained::unconstrained,
	};

	/// Runs the circuit on unconstrained input columns and checks every lane of
	/// the produced witness against the `sha2` crate's `compress256`.
	#[test]
	fn test_sha256() {
		test_circuit(|builder| {
			let log_size = PackedType::<U, BinaryField1b>::LOG_WIDTH;
			let input: [OracleId; 16] = std::array::from_fn(|idx| {
				unconstrained::<BinaryField1b>(builder, idx, log_size).unwrap()
			});
			let digest_oracles = super::sha256(builder, input, log_size).unwrap();

			if let Some(witness) = builder.witness() {
				// View each input / output column as a slice of u32 words.
				let message_words: [_; 16] = std::array::from_fn(|idx| {
					witness
						.get::<BinaryField1b>(input[idx])
						.unwrap()
						.as_slice::<u32>()
				});
				let digest_words: [_; 8] = std::array::from_fn(|idx| {
					witness
						.get::<BinaryField1b>(digest_oracles[idx])
						.unwrap()
						.as_slice::<u32>()
				});

				let mut block = GenericArray::<u8, _>::default();

				// Lane `lane` of every column forms one independent compression.
				for lane in 0..message_words[0].len() {
					// Serialise the 16 words big-endian into the 64-byte block.
					for (word_idx, words) in message_words.iter().enumerate() {
						for (byte_idx, byte) in words[lane].to_be_bytes().iter().enumerate() {
							block[word_idx * 4 + byte_idx] = *byte;
						}
					}

					let mut expected = super::INIT;
					compress256(&mut expected, &[block]);

					for (expected_word, actual_words) in expected.iter().zip(&digest_words) {
						assert_eq!(*expected_word, actual_words[lane]);
					}
				}
			}

			Ok(vec![])
		})
		.unwrap();
	}
}