binius_circuits/lasso/
sha256.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use anyhow::Result;
4use binius_core::oracle::OracleId;
5use binius_field::{
6	as_packed_field::PackedType, packed::packed_from_fn_with_offset, BinaryField16b, BinaryField1b,
7	BinaryField32b, BinaryField4b, PackedField, TowerField,
8};
9
10use super::{lasso::lasso, u32add::SeveralU32add};
11use crate::{
12	arithmetic::u32::u32const_repeating,
13	builder::{
14		types::{F, U},
15		ConstraintSystemBuilder,
16	},
17	pack::pack,
18	sha256::{rotate_and_xor, RotateRightType, INIT, ROUND_CONSTS_K},
19};
20
21pub const CH_MAJ_T_LOG_SIZE: usize = 12;
22
23type B1 = BinaryField1b;
24type B4 = BinaryField4b;
25type B16 = BinaryField16b;
26type B32 = BinaryField32b;
27
28struct SeveralBitwise {
29	n_lookups: Vec<usize>,
30	lookup_t: OracleId,
31	lookups_u: Vec<[OracleId; 1]>,
32	u_to_t_mappings: Vec<Vec<usize>>,
33	f: fn(u32, u32, u32) -> u32,
34}
35
36impl SeveralBitwise {
37	pub fn new(builder: &mut ConstraintSystemBuilder, f: fn(u32, u32, u32) -> u32) -> Result<Self> {
38		let lookup_t =
39			builder.add_committed("bitwise lookup_t", CH_MAJ_T_LOG_SIZE, B16::TOWER_LEVEL);
40
41		if let Some(witness) = builder.witness() {
42			let mut lookup_t_witness = witness.new_column::<B16>(lookup_t);
43
44			let lookup_t = lookup_t_witness.packed();
45			for (i, lookup_t) in lookup_t.iter_mut().enumerate() {
46				*lookup_t = packed_from_fn_with_offset(i, |i| {
47					let x = ((i >> 8) & 15) as u16;
48					let y = ((i >> 4) & 15) as u16;
49					let z = (i & 15) as u16;
50
51					let res = f(x as u32, y as u32, z as u32);
52
53					let lookup_index = (((x << 4) | y) << 4) | z;
54					B16::new((lookup_index << 4) | res as u16)
55				});
56			}
57		}
58		Ok(Self {
59			n_lookups: Vec::new(),
60			lookup_t,
61			lookups_u: Vec::new(),
62			u_to_t_mappings: Vec::new(),
63			f,
64		})
65	}
66
67	pub fn calculate(
68		&mut self,
69		builder: &mut ConstraintSystemBuilder,
70		name: impl ToString,
71		params: [OracleId; 3],
72	) -> Result<OracleId> {
73		let [xin, yin, zin] = params;
74
75		let log_size = builder.log_rows(params)?;
76
77		let xin_packed = pack::<B1, B4>(xin, builder, "xin_packed")?;
78		let yin_packed = pack::<B1, B4>(yin, builder, "yin_packed")?;
79		let zin_packed = pack::<B1, B4>(zin, builder, "zin_packed")?;
80
81		let res = builder.add_committed(name, log_size, B1::TOWER_LEVEL);
82
83		let res_packed = builder.add_packed("res_packed", res, B4::TOWER_LEVEL)?;
84
85		let lookup_u = builder.add_linear_combination(
86			"ch or maj lookup_u",
87			log_size - B4::TOWER_LEVEL,
88			[
89				(xin_packed, <F as TowerField>::basis(B4::TOWER_LEVEL, 3)?),
90				(yin_packed, <F as TowerField>::basis(B4::TOWER_LEVEL, 2)?),
91				(zin_packed, <F as TowerField>::basis(B4::TOWER_LEVEL, 1)?),
92				(res_packed, <F as TowerField>::basis(B4::TOWER_LEVEL, 0)?),
93			],
94		)?;
95
96		if let Some(witness) = builder.witness() {
97			let mut lookup_u_witness = witness.new_column::<B16>(lookup_u);
98			let lookup_u_u16 = lookup_u_witness.packed();
99
100			let mut u_to_t_mapping_witness = Vec::with_capacity(1 << (log_size - B4::TOWER_LEVEL));
101
102			let mut res_witness = witness.new_column::<B1>(res);
103			let res_u32 = res_witness.as_mut_slice::<u32>();
104
105			let xin_u32 = witness.get::<B1>(xin)?.as_slice::<u32>();
106			let yin_u32 = witness.get::<B1>(yin)?.as_slice::<u32>();
107			let zin_u32 = witness.get::<B1>(zin)?.as_slice::<u32>();
108
109			// The following code works only if packed width is at least 8.
110			// Which should be true in real life cases, becasue we don't use
111			// packed types with less than 128 bits size.
112			if PackedType::<U, B16>::LOG_WIDTH < 3 {
113				return Err(anyhow::anyhow!(
114					"PackedType::<U, B16>::LOG_WIDTH < 3, this is not supported"
115				));
116			}
117
118			for (i, lookup_u) in lookup_u_u16.iter_mut().enumerate() {
119				let offset = i << (PackedType::<U, B16>::LOG_WIDTH - 3);
120				let scalars = (0..(PackedType::<U, B16>::WIDTH / 8)).flat_map(|j| {
121					let index = offset + j;
122					let x = xin_u32[index];
123					let y = yin_u32[index];
124					let z = zin_u32[index];
125					let res = (self.f)(x, y, z);
126					res_u32[index] = res;
127
128					let scalars = std::array::from_fn::<_, 8, _>(|k| {
129						let x = ((x >> (4 * k)) & 15) as u16;
130						let y = ((y >> (4 * k)) & 15) as u16;
131						let z = ((z >> (4 * k)) & 15) as u16;
132						let res = ((res >> (4 * k)) & 15) as u16;
133						let lookup_index = (((x << 4) | y) << 4) | z;
134
135						u_to_t_mapping_witness.push(lookup_index as usize);
136
137						B16::new((lookup_index << 4) | res)
138					});
139
140					scalars.into_iter()
141				});
142				*lookup_u = PackedType::<U, B16>::from_scalars(scalars);
143			}
144
145			std::mem::drop(res_witness);
146
147			let res_packed_witness = witness.get::<B1>(res)?;
148			witness.set::<B4>(res_packed, res_packed_witness.repacked::<B4>())?;
149
150			self.u_to_t_mappings.push(u_to_t_mapping_witness);
151		}
152
153		self.lookups_u.push([lookup_u]);
154		self.n_lookups.push(1 << (log_size - B4::TOWER_LEVEL));
155		Ok(res)
156	}
157
158	pub fn finalize(
159		self,
160		builder: &mut ConstraintSystemBuilder,
161		name: impl ToString,
162	) -> Result<()> {
163		let channel = builder.add_channel();
164
165		lasso::<B32>(
166			builder,
167			name,
168			&self.n_lookups,
169			&self.u_to_t_mappings,
170			&self.lookups_u,
171			[self.lookup_t],
172			channel,
173		)
174	}
175}
176
177pub fn sha256(
178	builder: &mut ConstraintSystemBuilder,
179	input: [OracleId; 16],
180	log_size: usize,
181) -> Result<[OracleId; 8], anyhow::Error> {
182	let n_vars = log_size;
183
184	let mut several_u32_add = SeveralU32add::new(builder)?;
185
186	let mut several_ch = SeveralBitwise::new(builder, |e, f, g| (e & f) ^ ((!e) & g))?;
187
188	let mut several_maj = SeveralBitwise::new(builder, |a, b, c| (a & b) ^ (a & c) ^ (b & c))?;
189
190	let mut w = [OracleId::MAX; 64];
191
192	w[0..16].copy_from_slice(&input);
193
194	for i in 16..64 {
195		let s0 = rotate_and_xor(
196			n_vars,
197			builder,
198			&[
199				(w[i - 15], 7, RotateRightType::Circular),
200				(w[i - 15], 18, RotateRightType::Circular),
201				(w[i - 15], 3, RotateRightType::Logical),
202			],
203		)?;
204		let s1 = rotate_and_xor(
205			n_vars,
206			builder,
207			&[
208				(w[i - 2], 17, RotateRightType::Circular),
209				(w[i - 2], 19, RotateRightType::Circular),
210				(w[i - 2], 10, RotateRightType::Logical),
211			],
212		)?;
213
214		let w_addition =
215			several_u32_add.u32add::<B1, B1>(builder, "w_addition", w[i - 16], w[i - 7])?;
216
217		let s_addition = several_u32_add.u32add::<B1, B1>(builder, "s_addition", s0, s1)?;
218
219		w[i] = several_u32_add.u32add::<B1, B1>(
220			builder,
221			format!("w[{}]", i),
222			w_addition,
223			s_addition,
224		)?;
225	}
226
227	let init_oracles = INIT.map(|val| u32const_repeating(n_vars, builder, val, "INIT").unwrap());
228
229	let [mut a, mut b, mut c, mut d, mut e, mut f, mut g, mut h] = init_oracles;
230
231	let k = ROUND_CONSTS_K
232		.map(|val| u32const_repeating(n_vars, builder, val, "ROUND_CONSTS_K").unwrap());
233
234	for i in 0..64 {
235		let sigma1 = rotate_and_xor(
236			n_vars,
237			builder,
238			&[
239				(e, 6, RotateRightType::Circular),
240				(e, 11, RotateRightType::Circular),
241				(e, 25, RotateRightType::Circular),
242			],
243		)?;
244
245		let ch = several_ch.calculate(builder, "ch", [e, f, g])?;
246
247		let h_sigma1 = several_u32_add.u32add::<B1, B1>(builder, "h_sigma1", h, sigma1)?;
248		let ch_ki = several_u32_add.u32add::<B1, B1>(builder, "ch_ki", ch, k[i])?;
249		let ch_ki_w_i = several_u32_add.u32add::<B1, B1>(builder, "ch_ki_w_i", ch_ki, w[i])?;
250		let temp1 = several_u32_add.u32add::<B1, B1>(builder, "temp1", h_sigma1, ch_ki_w_i)?;
251
252		let sigma0 = rotate_and_xor(
253			n_vars,
254			builder,
255			&[
256				(a, 2, RotateRightType::Circular),
257				(a, 13, RotateRightType::Circular),
258				(a, 22, RotateRightType::Circular),
259			],
260		)?;
261
262		let maj = several_maj.calculate(builder, "maj", [a, b, c])?;
263
264		let temp2 = several_u32_add.u32add::<B1, B1>(builder, "temp2", sigma0, maj)?;
265
266		h = g;
267		g = f;
268		f = e;
269		e = several_u32_add.u32add::<B1, B1>(builder, "ch_ki_w_i", d, temp1)?;
270		d = c;
271		c = b;
272		b = a;
273		a = several_u32_add.u32add::<B1, B1>(builder, "ch_ki_w_i", temp1, temp2)?;
274	}
275
276	let abcdefgh = [a, b, c, d, e, f, g, h];
277
278	let output = std::array::from_fn(|i| {
279		several_u32_add
280			.u32add::<B1, B1>(builder, "output", init_oracles[i], abcdefgh[i])
281			.unwrap()
282	});
283
284	several_u32_add.finalize(builder, "lasso")?;
285
286	several_ch.finalize(builder, "ch")?;
287	several_maj.finalize(builder, "maj")?;
288
289	Ok(output)
290}
291
#[cfg(test)]
mod tests {
	use binius_core::oracle::OracleId;
	use binius_field::{as_packed_field::PackedType, BinaryField1b, BinaryField8b, TowerField};
	use sha2::{compress256, digest::generic_array::GenericArray};

	use crate::{
		builder::{test_utils::test_circuit, types::U},
		unconstrained::unconstrained,
	};

	/// Checks the lasso SHA-256 circuit against `sha2::compress256`.
	///
	/// Sixteen unconstrained message-word columns are fed to `super::sha256`. When a witness
	/// is present, each `u32` row of the output state is compared with a reference
	/// compression of the corresponding 64-byte block, starting from `INIT`.
	#[test]
	fn test_sha256_lasso() {
		test_circuit(|builder| {
			// Each column spans 2^3 packed B1 elements.
			let log_size = PackedType::<U, BinaryField1b>::LOG_WIDTH + BinaryField8b::TOWER_LEVEL;
			let message: [OracleId; 16] = std::array::from_fn(|word| {
				unconstrained::<BinaryField1b>(builder, word, log_size).unwrap()
			});
			let state = super::sha256(builder, message, log_size).unwrap();

			if let Some(witness) = builder.witness() {
				let message_words: [_; 16] = std::array::from_fn(|word| {
					witness
						.get::<BinaryField1b>(message[word])
						.unwrap()
						.as_slice::<u32>()
				});

				let state_words: [_; 8] = std::array::from_fn(|word| {
					witness
						.get::<BinaryField1b>(state[word])
						.unwrap()
						.as_slice::<u32>()
				});

				let mut block = GenericArray::<u8, _>::default();

				let rows = message_words[0].len();

				for row in 0..rows {
					// Serialize the 16 message words of this row big-endian into the block.
					for (word_idx, words) in message_words.iter().enumerate() {
						for (byte_idx, &byte) in words[row].to_be_bytes().iter().enumerate() {
							block[word_idx * 4 + byte_idx] = byte;
						}
					}

					let mut expected = crate::sha256::INIT;
					compress256(&mut expected, &[block]);

					for (expected_word, words) in expected.iter().zip(state_words.iter()) {
						assert_eq!(*expected_word, words[row]);
					}
				}
			}

			Ok(vec![])
		})
		.unwrap();
	}
}