binius_circuits/lasso/
sha256.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use anyhow::Result;
4use binius_core::oracle::OracleId;
5use binius_field::{
6	as_packed_field::PackedType, BinaryField16b, BinaryField1b, BinaryField32b, BinaryField4b,
7	PackedFieldIndexable, TowerField,
8};
9use itertools::izip;
10
11use super::{lasso::lasso, u32add::SeveralU32add};
12use crate::{
13	arithmetic::u32::u32const_repeating,
14	builder::{
15		types::{F, U},
16		ConstraintSystemBuilder,
17	},
18	pack::pack,
19	sha256::{rotate_and_xor, RotateRightType, INIT, ROUND_CONSTS_K},
20};
21
/// Base-2 logarithm of the number of rows in the `ch`/`maj` lookup table: one row per
/// combination of three 4-bit input nibbles, i.e. `3 * 4` index bits.
pub const CH_MAJ_T_LOG_SIZE: usize = 3 * 4;
23
24type B1 = BinaryField1b;
25type B4 = BinaryField4b;
26type B16 = BinaryField16b;
27type B32 = BinaryField32b;
28
29struct SeveralBitwise {
30	n_lookups: Vec<usize>,
31	lookup_t: OracleId,
32	lookups_u: Vec<[OracleId; 1]>,
33	u_to_t_mappings: Vec<Vec<usize>>,
34	f: fn(u32, u32, u32) -> u32,
35}
36
37impl SeveralBitwise {
38	pub fn new(builder: &mut ConstraintSystemBuilder, f: fn(u32, u32, u32) -> u32) -> Result<Self> {
39		let lookup_t =
40			builder.add_committed("bitwise lookup_t", CH_MAJ_T_LOG_SIZE, B16::TOWER_LEVEL);
41
42		if let Some(witness) = builder.witness() {
43			let mut lookup_t_witness = witness.new_column::<B16>(lookup_t);
44
45			let lookup_t_scalars =
46				PackedType::<U, B16>::unpack_scalars_mut(lookup_t_witness.packed());
47
48			for (i, lookup_t) in lookup_t_scalars.iter_mut().enumerate() {
49				let x = ((i >> 8) & 15) as u16;
50				let y = ((i >> 4) & 15) as u16;
51				let z = (i & 15) as u16;
52
53				let res = f(x as u32, y as u32, z as u32);
54
55				let lookup_index = (((x << 4) | y) << 4) | z;
56				*lookup_t = B16::new((lookup_index << 4) | res as u16);
57			}
58		}
59		Ok(Self {
60			n_lookups: Vec::new(),
61			lookup_t,
62			lookups_u: Vec::new(),
63			u_to_t_mappings: Vec::new(),
64			f,
65		})
66	}
67
68	pub fn calculate(
69		&mut self,
70		builder: &mut ConstraintSystemBuilder,
71		name: impl ToString,
72		params: [OracleId; 3],
73	) -> Result<OracleId> {
74		let [xin, yin, zin] = params;
75
76		let log_size = builder.log_rows(params)?;
77
78		let xin_packed = pack::<B1, B4>(xin, builder, "xin_packed")?;
79		let yin_packed = pack::<B1, B4>(yin, builder, "yin_packed")?;
80		let zin_packed = pack::<B1, B4>(zin, builder, "zin_packed")?;
81
82		let res = builder.add_committed(name, log_size, B1::TOWER_LEVEL);
83
84		let res_packed = builder.add_packed("res_packed", res, B4::TOWER_LEVEL)?;
85
86		let lookup_u = builder.add_linear_combination(
87			"ch or maj lookup_u",
88			log_size - B4::TOWER_LEVEL,
89			[
90				(xin_packed, <F as TowerField>::basis(B4::TOWER_LEVEL, 3)?),
91				(yin_packed, <F as TowerField>::basis(B4::TOWER_LEVEL, 2)?),
92				(zin_packed, <F as TowerField>::basis(B4::TOWER_LEVEL, 1)?),
93				(res_packed, <F as TowerField>::basis(B4::TOWER_LEVEL, 0)?),
94			],
95		)?;
96
97		if let Some(witness) = builder.witness() {
98			let mut lookup_u_witness = witness.new_column::<B16>(lookup_u);
99			let lookup_u_u16 = PackedType::<U, B16>::unpack_scalars_mut(lookup_u_witness.packed());
100
101			let mut u_to_t_mapping_witness = Vec::with_capacity(1 << (log_size - B4::TOWER_LEVEL));
102
103			let mut res_witness = witness.new_column::<B1>(res);
104			let res_u32 = res_witness.as_mut_slice::<u32>();
105
106			let xin_u32 = witness.get::<B1>(xin)?.as_slice::<u32>();
107
108			let yin_u32 = witness.get::<B1>(yin)?.as_slice::<u32>();
109
110			let zin_u32 = witness.get::<B1>(zin)?.as_slice::<u32>();
111
112			for (res, x, y, z, lookup_u) in
113				izip!(res_u32.iter_mut(), xin_u32, yin_u32, zin_u32, lookup_u_u16.chunks_mut(8))
114			{
115				*res = (self.f)(*x, *y, *z);
116
117				#[allow(clippy::needless_range_loop)]
118				for i in 0..8 {
119					let x = ((*x >> (4 * i)) & 15) as u16;
120					let y = ((*y >> (4 * i)) & 15) as u16;
121					let z = ((*z >> (4 * i)) & 15) as u16;
122					let res = ((*res >> (4 * i)) & 15) as u16;
123					let lookup_index = (((x << 4) | y) << 4) | z;
124					lookup_u[i] = B16::new((lookup_index << 4) | res);
125					u_to_t_mapping_witness.push(lookup_index as usize)
126				}
127			}
128
129			std::mem::drop(res_witness);
130
131			let res_packed_witness = witness.get::<B1>(res)?;
132			witness.set::<B4>(res_packed, res_packed_witness.repacked::<B4>())?;
133
134			self.u_to_t_mappings.push(u_to_t_mapping_witness);
135		}
136
137		self.lookups_u.push([lookup_u]);
138		self.n_lookups.push(1 << (log_size - B4::TOWER_LEVEL));
139		Ok(res)
140	}
141
142	pub fn finalize(
143		self,
144		builder: &mut ConstraintSystemBuilder,
145		name: impl ToString,
146	) -> Result<()> {
147		let channel = builder.add_channel();
148
149		lasso::<B32>(
150			builder,
151			name,
152			&self.n_lookups,
153			&self.u_to_t_mappings,
154			&self.lookups_u,
155			[self.lookup_t],
156			channel,
157		)
158	}
159}
160
161pub fn sha256(
162	builder: &mut ConstraintSystemBuilder,
163	input: [OracleId; 16],
164	log_size: usize,
165) -> Result<[OracleId; 8], anyhow::Error> {
166	let n_vars = log_size;
167
168	let mut several_u32_add = SeveralU32add::new(builder)?;
169
170	let mut several_ch = SeveralBitwise::new(builder, |e, f, g| (e & f) ^ ((!e) & g))?;
171
172	let mut several_maj = SeveralBitwise::new(builder, |a, b, c| (a & b) ^ (a & c) ^ (b & c))?;
173
174	let mut w = [OracleId::MAX; 64];
175
176	w[0..16].copy_from_slice(&input);
177
178	for i in 16..64 {
179		let s0 = rotate_and_xor(
180			n_vars,
181			builder,
182			&[
183				(w[i - 15], 7, RotateRightType::Circular),
184				(w[i - 15], 18, RotateRightType::Circular),
185				(w[i - 15], 3, RotateRightType::Logical),
186			],
187		)?;
188		let s1 = rotate_and_xor(
189			n_vars,
190			builder,
191			&[
192				(w[i - 2], 17, RotateRightType::Circular),
193				(w[i - 2], 19, RotateRightType::Circular),
194				(w[i - 2], 10, RotateRightType::Logical),
195			],
196		)?;
197
198		let w_addition =
199			several_u32_add.u32add::<B1, B1>(builder, "w_addition", w[i - 16], w[i - 7])?;
200
201		let s_addition = several_u32_add.u32add::<B1, B1>(builder, "s_addition", s0, s1)?;
202
203		w[i] = several_u32_add.u32add::<B1, B1>(
204			builder,
205			format!("w[{}]", i),
206			w_addition,
207			s_addition,
208		)?;
209	}
210
211	let init_oracles = INIT.map(|val| u32const_repeating(n_vars, builder, val, "INIT").unwrap());
212
213	let [mut a, mut b, mut c, mut d, mut e, mut f, mut g, mut h] = init_oracles;
214
215	let k = ROUND_CONSTS_K
216		.map(|val| u32const_repeating(n_vars, builder, val, "ROUND_CONSTS_K").unwrap());
217
218	for i in 0..64 {
219		let sigma1 = rotate_and_xor(
220			n_vars,
221			builder,
222			&[
223				(e, 6, RotateRightType::Circular),
224				(e, 11, RotateRightType::Circular),
225				(e, 25, RotateRightType::Circular),
226			],
227		)?;
228
229		let ch = several_ch.calculate(builder, "ch", [e, f, g])?;
230
231		let h_sigma1 = several_u32_add.u32add::<B1, B1>(builder, "h_sigma1", h, sigma1)?;
232		let ch_ki = several_u32_add.u32add::<B1, B1>(builder, "ch_ki", ch, k[i])?;
233		let ch_ki_w_i = several_u32_add.u32add::<B1, B1>(builder, "ch_ki_w_i", ch_ki, w[i])?;
234		let temp1 = several_u32_add.u32add::<B1, B1>(builder, "temp1", h_sigma1, ch_ki_w_i)?;
235
236		let sigma0 = rotate_and_xor(
237			n_vars,
238			builder,
239			&[
240				(a, 2, RotateRightType::Circular),
241				(a, 13, RotateRightType::Circular),
242				(a, 22, RotateRightType::Circular),
243			],
244		)?;
245
246		let maj = several_maj.calculate(builder, "maj", [a, b, c])?;
247
248		let temp2 = several_u32_add.u32add::<B1, B1>(builder, "temp2", sigma0, maj)?;
249
250		h = g;
251		g = f;
252		f = e;
253		e = several_u32_add.u32add::<B1, B1>(builder, "ch_ki_w_i", d, temp1)?;
254		d = c;
255		c = b;
256		b = a;
257		a = several_u32_add.u32add::<B1, B1>(builder, "ch_ki_w_i", temp1, temp2)?;
258	}
259
260	let abcdefgh = [a, b, c, d, e, f, g, h];
261
262	let output = std::array::from_fn(|i| {
263		several_u32_add
264			.u32add::<B1, B1>(builder, "output", init_oracles[i], abcdefgh[i])
265			.unwrap()
266	});
267
268	several_u32_add.finalize(builder, "lasso")?;
269
270	several_ch.finalize(builder, "ch")?;
271	several_maj.finalize(builder, "maj")?;
272
273	Ok(output)
274}
275
#[cfg(test)]
mod tests {
	use binius_core::oracle::OracleId;
	use binius_field::{as_packed_field::PackedType, BinaryField1b, BinaryField8b, TowerField};
	use sha2::{compress256, digest::generic_array::GenericArray};

	use crate::{
		builder::{test_utils::test_circuit, types::U},
		unconstrained::unconstrained,
	};

	/// Compares the lasso SHA-256 circuit with `sha2::compress256`.
	///
	/// Every `u32` lane of the unconstrained message columns is compressed from the standard
	/// initial state. The result must equal the same lane of the circuit's output columns.
	#[test]
	fn test_sha256_lasso() {
		test_circuit(|builder| {
			let log_size = PackedType::<U, BinaryField1b>::LOG_WIDTH + BinaryField8b::TOWER_LEVEL;

			let mut message = [OracleId::MAX; 16];
			for (i, oracle) in message.iter_mut().enumerate() {
				*oracle = unconstrained::<BinaryField1b>(builder, i, log_size).unwrap();
			}

			let digest = super::sha256(builder, message, log_size).unwrap();

			if let Some(witness) = builder.witness() {
				let message_words: [_; 16] = std::array::from_fn(|i| {
					witness
						.get::<BinaryField1b>(message[i])
						.unwrap()
						.as_slice::<u32>()
				});
				let digest_words: [_; 8] = std::array::from_fn(|i| {
					witness
						.get::<BinaryField1b>(digest[i])
						.unwrap()
						.as_slice::<u32>()
				});

				let mut block = GenericArray::<u8, _>::default();

				for lane in 0..message_words[0].len() {
					// Serialise the 16 message words of this lane big-endian into the block.
					for (word_idx, words) in message_words.iter().enumerate() {
						for (byte_idx, &byte) in words[lane].to_be_bytes().iter().enumerate() {
							block[4 * word_idx + byte_idx] = byte;
						}
					}

					let mut expected = crate::sha256::INIT;
					compress256(&mut expected, &[block]);

					for (word_idx, &want) in expected.iter().enumerate() {
						assert_eq!(want, digest_words[word_idx][lane]);
					}
				}
			}

			Ok(vec![])
		})
		.unwrap();
	}
}
335}