binius_circuits/lasso/
sha256.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use anyhow::Result;
4use binius_core::oracle::OracleId;
5use binius_field::{
6	as_packed_field::PackedType, packed::packed_from_fn_with_offset, BinaryField16b, BinaryField1b,
7	BinaryField32b, BinaryField4b, PackedField, TowerField,
8};
9
10use super::{lasso::lasso, u32add::SeveralU32add};
11use crate::{
12	arithmetic::u32::u32const_repeating,
13	builder::{
14		types::{F, U},
15		ConstraintSystemBuilder,
16	},
17	pack::pack,
18	sha256::{rotate_and_xor, RotateRightType, INIT, ROUND_CONSTS_K},
19};
20
/// Log2 of the number of rows in the shared `Ch`/`Maj` truth table: a row index is built from
/// three 4-bit input nibbles (x, y, z), i.e. `3 * 4` index bits.
pub const CH_MAJ_T_LOG_SIZE: usize = 3 * 4;
22
23type B1 = BinaryField1b;
24type B4 = BinaryField4b;
25type B16 = BinaryField16b;
26type B32 = BinaryField32b;
27
28struct SeveralBitwise {
29	n_lookups: Vec<usize>,
30	lookup_t: OracleId,
31	lookups_u: Vec<[OracleId; 1]>,
32	u_to_t_mappings: Vec<Vec<usize>>,
33	f: fn(u32, u32, u32) -> u32,
34}
35
36impl SeveralBitwise {
37	pub fn new(builder: &mut ConstraintSystemBuilder, f: fn(u32, u32, u32) -> u32) -> Result<Self> {
38		let lookup_t =
39			builder.add_committed("bitwise lookup_t", CH_MAJ_T_LOG_SIZE, B16::TOWER_LEVEL);
40
41		if let Some(witness) = builder.witness() {
42			let mut lookup_t_witness = witness.new_column::<B16>(lookup_t);
43
44			let lookup_t = lookup_t_witness.packed();
45			for (i, lookup_t) in lookup_t.iter_mut().enumerate() {
46				*lookup_t = packed_from_fn_with_offset(i, |i| {
47					let x = ((i >> 8) & 15) as u16;
48					let y = ((i >> 4) & 15) as u16;
49					let z = (i & 15) as u16;
50
51					let res = f(x as u32, y as u32, z as u32);
52
53					let lookup_index = (((x << 4) | y) << 4) | z;
54					B16::new((lookup_index << 4) | res as u16)
55				});
56			}
57		}
58		Ok(Self {
59			n_lookups: Vec::new(),
60			lookup_t,
61			lookups_u: Vec::new(),
62			u_to_t_mappings: Vec::new(),
63			f,
64		})
65	}
66
67	pub fn calculate(
68		&mut self,
69		builder: &mut ConstraintSystemBuilder,
70		name: impl ToString,
71		params: [OracleId; 3],
72	) -> Result<OracleId> {
73		let [xin, yin, zin] = params;
74
75		let log_size = builder.log_rows(params)?;
76
77		let xin_packed = pack::<B1, B4>(xin, builder, "xin_packed")?;
78		let yin_packed = pack::<B1, B4>(yin, builder, "yin_packed")?;
79		let zin_packed = pack::<B1, B4>(zin, builder, "zin_packed")?;
80
81		let res = builder.add_committed(name, log_size, B1::TOWER_LEVEL);
82
83		let res_packed = builder.add_packed("res_packed", res, B4::TOWER_LEVEL)?;
84
85		let lookup_u = builder.add_linear_combination(
86			"ch or maj lookup_u",
87			log_size - B4::TOWER_LEVEL,
88			[
89				(xin_packed, <F as TowerField>::basis(B4::TOWER_LEVEL, 3)?),
90				(yin_packed, <F as TowerField>::basis(B4::TOWER_LEVEL, 2)?),
91				(zin_packed, <F as TowerField>::basis(B4::TOWER_LEVEL, 1)?),
92				(res_packed, <F as TowerField>::basis(B4::TOWER_LEVEL, 0)?),
93			],
94		)?;
95
96		if let Some(witness) = builder.witness() {
97			let mut lookup_u_witness = witness.new_column::<B16>(lookup_u);
98			let lookup_u_u16 = lookup_u_witness.packed();
99
100			let mut u_to_t_mapping_witness = Vec::with_capacity(1 << (log_size - B4::TOWER_LEVEL));
101
102			let mut res_witness = witness.new_column::<B1>(res);
103			let res_u32 = res_witness.as_mut_slice::<u32>();
104
105			let xin_u32 = witness.get::<B1>(xin)?.as_slice::<u32>();
106			let yin_u32 = witness.get::<B1>(yin)?.as_slice::<u32>();
107			let zin_u32 = witness.get::<B1>(zin)?.as_slice::<u32>();
108
109			// The following code works only if packed width is at least 8.
110			// Which should be true in real life cases, becasue we don't use
111			// packed types with less than 128 bits size.
112			if PackedType::<U, B16>::LOG_WIDTH < 3 {
113				return Err(anyhow::anyhow!(
114					"PackedType::<U, B16>::LOG_WIDTH < 3, this is not supported"
115				));
116			}
117
118			for (i, lookup_u) in lookup_u_u16.iter_mut().enumerate() {
119				let offset = i << (PackedType::<U, B16>::LOG_WIDTH - 3);
120				let scalars = (0..(PackedType::<U, B16>::WIDTH / 8)).flat_map(|j| {
121					let index = offset + j;
122					let x = xin_u32[index];
123					let y = yin_u32[index];
124					let z = zin_u32[index];
125					let res = (self.f)(x, y, z);
126					res_u32[index] = res;
127
128					let scalars = std::array::from_fn::<_, 8, _>(|k| {
129						let x = ((x >> (4 * k)) & 15) as u16;
130						let y = ((y >> (4 * k)) & 15) as u16;
131						let z = ((z >> (4 * k)) & 15) as u16;
132						let res = ((res >> (4 * k)) & 15) as u16;
133						let lookup_index = (((x << 4) | y) << 4) | z;
134
135						u_to_t_mapping_witness.push(lookup_index as usize);
136
137						B16::new((lookup_index << 4) | res)
138					});
139
140					scalars.into_iter()
141				});
142				*lookup_u = PackedType::<U, B16>::from_scalars(scalars);
143			}
144
145			std::mem::drop(res_witness);
146
147			let res_packed_witness = witness.get::<B1>(res)?;
148			witness.set::<B4>(res_packed, res_packed_witness.repacked::<B4>())?;
149
150			self.u_to_t_mappings.push(u_to_t_mapping_witness);
151		}
152
153		self.lookups_u.push([lookup_u]);
154		self.n_lookups.push(1 << (log_size - B4::TOWER_LEVEL));
155		Ok(res)
156	}
157
158	pub fn finalize(
159		self,
160		builder: &mut ConstraintSystemBuilder,
161		name: impl ToString,
162	) -> Result<()> {
163		let channel = builder.add_channel();
164
165		lasso::<B32>(
166			builder,
167			name,
168			&self.n_lookups,
169			&self.u_to_t_mappings,
170			&self.lookups_u,
171			[self.lookup_t],
172			channel,
173		)
174	}
175}
176
177pub fn sha256(
178	builder: &mut ConstraintSystemBuilder,
179	input: [OracleId; 16],
180	log_size: usize,
181) -> Result<[OracleId; 8], anyhow::Error> {
182	let n_vars = log_size;
183
184	let mut several_u32_add = SeveralU32add::new(builder)?;
185
186	let mut several_ch = SeveralBitwise::new(builder, |e, f, g| (e & f) ^ ((!e) & g))?;
187
188	let mut several_maj = SeveralBitwise::new(builder, |a, b, c| (a & b) ^ (a & c) ^ (b & c))?;
189
190	let mut w = [OracleId::invalid(); 64];
191
192	w[0..16].copy_from_slice(&input);
193
194	for i in 16..64 {
195		let s0 = rotate_and_xor(
196			n_vars,
197			builder,
198			&[
199				(w[i - 15], 7, RotateRightType::Circular),
200				(w[i - 15], 18, RotateRightType::Circular),
201				(w[i - 15], 3, RotateRightType::Logical),
202			],
203		)?;
204		let s1 = rotate_and_xor(
205			n_vars,
206			builder,
207			&[
208				(w[i - 2], 17, RotateRightType::Circular),
209				(w[i - 2], 19, RotateRightType::Circular),
210				(w[i - 2], 10, RotateRightType::Logical),
211			],
212		)?;
213
214		let w_addition =
215			several_u32_add.u32add::<B1, B1>(builder, "w_addition", w[i - 16], w[i - 7])?;
216
217		let s_addition = several_u32_add.u32add::<B1, B1>(builder, "s_addition", s0, s1)?;
218
219		w[i] =
220			several_u32_add.u32add::<B1, B1>(builder, format!("w[{i}]"), w_addition, s_addition)?;
221	}
222
223	let init_oracles = INIT.map(|val| u32const_repeating(n_vars, builder, val, "INIT").unwrap());
224
225	let [mut a, mut b, mut c, mut d, mut e, mut f, mut g, mut h] = init_oracles;
226
227	let k = ROUND_CONSTS_K
228		.map(|val| u32const_repeating(n_vars, builder, val, "ROUND_CONSTS_K").unwrap());
229
230	for i in 0..64 {
231		let sigma1 = rotate_and_xor(
232			n_vars,
233			builder,
234			&[
235				(e, 6, RotateRightType::Circular),
236				(e, 11, RotateRightType::Circular),
237				(e, 25, RotateRightType::Circular),
238			],
239		)?;
240
241		let ch = several_ch.calculate(builder, "ch", [e, f, g])?;
242
243		let h_sigma1 = several_u32_add.u32add::<B1, B1>(builder, "h_sigma1", h, sigma1)?;
244		let ch_ki = several_u32_add.u32add::<B1, B1>(builder, "ch_ki", ch, k[i])?;
245		let ch_ki_w_i = several_u32_add.u32add::<B1, B1>(builder, "ch_ki_w_i", ch_ki, w[i])?;
246		let temp1 = several_u32_add.u32add::<B1, B1>(builder, "temp1", h_sigma1, ch_ki_w_i)?;
247
248		let sigma0 = rotate_and_xor(
249			n_vars,
250			builder,
251			&[
252				(a, 2, RotateRightType::Circular),
253				(a, 13, RotateRightType::Circular),
254				(a, 22, RotateRightType::Circular),
255			],
256		)?;
257
258		let maj = several_maj.calculate(builder, "maj", [a, b, c])?;
259
260		let temp2 = several_u32_add.u32add::<B1, B1>(builder, "temp2", sigma0, maj)?;
261
262		h = g;
263		g = f;
264		f = e;
265		e = several_u32_add.u32add::<B1, B1>(builder, "ch_ki_w_i", d, temp1)?;
266		d = c;
267		c = b;
268		b = a;
269		a = several_u32_add.u32add::<B1, B1>(builder, "ch_ki_w_i", temp1, temp2)?;
270	}
271
272	let abcdefgh = [a, b, c, d, e, f, g, h];
273
274	let output = std::array::from_fn(|i| {
275		several_u32_add
276			.u32add::<B1, B1>(builder, "output", init_oracles[i], abcdefgh[i])
277			.unwrap()
278	});
279
280	several_u32_add.finalize(builder, "lasso")?;
281
282	several_ch.finalize(builder, "ch")?;
283	several_maj.finalize(builder, "maj")?;
284
285	Ok(output)
286}
287
#[cfg(test)]
mod tests {
	use binius_core::oracle::OracleId;
	use binius_field::{as_packed_field::PackedType, BinaryField1b, BinaryField8b, TowerField};
	use sha2::{compress256, digest::generic_array::GenericArray};

	use crate::{
		builder::{test_utils::test_circuit, types::U},
		unconstrained::unconstrained,
	};

	/// Runs the Lasso SHA-256 circuit on unconstrained message words. When a witness is
	/// available, it checks every instance against the `sha2` reference `compress256`.
	#[test]
	fn test_sha256_lasso() {
		test_circuit(|builder| {
			let log_size = PackedType::<U, BinaryField1b>::LOG_WIDTH + BinaryField8b::TOWER_LEVEL;
			let input: [OracleId; 16] = std::array::from_fn(|word| {
				unconstrained::<BinaryField1b>(builder, word, log_size).unwrap()
			});
			let state_output = super::sha256(builder, input, log_size).unwrap();

			if let Some(witness) = builder.witness() {
				let input_words: [_; 16] = std::array::from_fn(|word| {
					witness.get::<BinaryField1b>(input[word]).unwrap().as_slice::<u32>()
				});
				let output_words: [_; 8] = std::array::from_fn(|word| {
					witness.get::<BinaryField1b>(state_output[word]).unwrap().as_slice::<u32>()
				});

				let mut block = GenericArray::<u8, _>::default();

				// Each `u32` position across the columns is one independent compression.
				for instance in 0..input_words[0].len() {
					// SHA-256 message words are serialized big-endian.
					for (word, column) in input_words.iter().enumerate() {
						let start = 4 * word;
						block[start..start + 4].copy_from_slice(&column[instance].to_be_bytes());
					}

					let mut expected = crate::sha256::INIT;
					compress256(&mut expected, &[block]);

					for (want, column) in expected.iter().zip(&output_words) {
						assert_eq!(*want, column[instance]);
					}
				}
			}

			Ok(vec![])
		})
		.unwrap();
	}
}