1use anyhow::Result;
4use binius_core::oracle::OracleId;
5use binius_field::{
6 as_packed_field::PackedType, packed::packed_from_fn_with_offset, BinaryField16b, BinaryField1b,
7 BinaryField32b, BinaryField4b, PackedField, TowerField,
8};
9
10use super::{lasso::lasso, u32add::SeveralU32add};
11use crate::{
12 arithmetic::u32::u32const_repeating,
13 builder::{
14 types::{F, U},
15 ConstraintSystemBuilder,
16 },
17 pack::pack,
18 sha256::{rotate_and_xor, RotateRightType, INIT, ROUND_CONSTS_K},
19};
20
/// Base-2 logarithm of the number of rows in the bitwise (`ch` / `maj`) lookup table.
///
/// A row is addressed by three 4-bit operands, hence `3 * 4 = 12` index bits.
pub const CH_MAJ_T_LOG_SIZE: usize = 3 * 4;
22
/// 1-bit field: element type of the input and result bit columns.
type B1 = BinaryField1b;
/// 4-bit field: element type of the nibble-packed views of the bit columns.
type B4 = BinaryField4b;
/// 16-bit field: element type of the lookup table and of the looked-up value columns.
type B16 = BinaryField16b;
/// 32-bit field: type parameter handed to `lasso` when the lookups are finalized.
type B32 = BinaryField32b;
27
/// Accumulates many nibble-wise evaluations of one three-operand function `f`
/// so that they can all be checked against a single shared lookup table.
///
/// `new` commits the table, every `calculate` call appends one batch of
/// lookups, and `finalize` hands everything collected here to `lasso`.
struct SeveralBitwise {
    /// Number of lookups in each batch: `1 << (log_size - B4::TOWER_LEVEL)` per `calculate` call.
    n_lookups: Vec<usize>,
    /// Committed oracle holding the lookup table.
    lookup_t: OracleId,
    /// One single-element array per batch, holding that batch's looked-up value oracle.
    lookups_u: Vec<[OracleId; 1]>,
    /// Per batch, the table index of every looked-up value (only filled when a witness exists).
    u_to_t_mappings: Vec<Vec<usize>>,
    /// The function being tabulated; called on nibbles in `new` and on full `u32` words in `calculate`.
    f: fn(u32, u32, u32) -> u32,
}
35
36impl SeveralBitwise {
37 pub fn new(builder: &mut ConstraintSystemBuilder, f: fn(u32, u32, u32) -> u32) -> Result<Self> {
38 let lookup_t =
39 builder.add_committed("bitwise lookup_t", CH_MAJ_T_LOG_SIZE, B16::TOWER_LEVEL);
40
41 if let Some(witness) = builder.witness() {
42 let mut lookup_t_witness = witness.new_column::<B16>(lookup_t);
43
44 let lookup_t = lookup_t_witness.packed();
45 for (i, lookup_t) in lookup_t.iter_mut().enumerate() {
46 *lookup_t = packed_from_fn_with_offset(i, |i| {
47 let x = ((i >> 8) & 15) as u16;
48 let y = ((i >> 4) & 15) as u16;
49 let z = (i & 15) as u16;
50
51 let res = f(x as u32, y as u32, z as u32);
52
53 let lookup_index = (((x << 4) | y) << 4) | z;
54 B16::new((lookup_index << 4) | res as u16)
55 });
56 }
57 }
58 Ok(Self {
59 n_lookups: Vec::new(),
60 lookup_t,
61 lookups_u: Vec::new(),
62 u_to_t_mappings: Vec::new(),
63 f,
64 })
65 }
66
67 pub fn calculate(
68 &mut self,
69 builder: &mut ConstraintSystemBuilder,
70 name: impl ToString,
71 params: [OracleId; 3],
72 ) -> Result<OracleId> {
73 let [xin, yin, zin] = params;
74
75 let log_size = builder.log_rows(params)?;
76
77 let xin_packed = pack::<B1, B4>(xin, builder, "xin_packed")?;
78 let yin_packed = pack::<B1, B4>(yin, builder, "yin_packed")?;
79 let zin_packed = pack::<B1, B4>(zin, builder, "zin_packed")?;
80
81 let res = builder.add_committed(name, log_size, B1::TOWER_LEVEL);
82
83 let res_packed = builder.add_packed("res_packed", res, B4::TOWER_LEVEL)?;
84
85 let lookup_u = builder.add_linear_combination(
86 "ch or maj lookup_u",
87 log_size - B4::TOWER_LEVEL,
88 [
89 (xin_packed, <F as TowerField>::basis(B4::TOWER_LEVEL, 3)?),
90 (yin_packed, <F as TowerField>::basis(B4::TOWER_LEVEL, 2)?),
91 (zin_packed, <F as TowerField>::basis(B4::TOWER_LEVEL, 1)?),
92 (res_packed, <F as TowerField>::basis(B4::TOWER_LEVEL, 0)?),
93 ],
94 )?;
95
96 if let Some(witness) = builder.witness() {
97 let mut lookup_u_witness = witness.new_column::<B16>(lookup_u);
98 let lookup_u_u16 = lookup_u_witness.packed();
99
100 let mut u_to_t_mapping_witness = Vec::with_capacity(1 << (log_size - B4::TOWER_LEVEL));
101
102 let mut res_witness = witness.new_column::<B1>(res);
103 let res_u32 = res_witness.as_mut_slice::<u32>();
104
105 let xin_u32 = witness.get::<B1>(xin)?.as_slice::<u32>();
106 let yin_u32 = witness.get::<B1>(yin)?.as_slice::<u32>();
107 let zin_u32 = witness.get::<B1>(zin)?.as_slice::<u32>();
108
109 if PackedType::<U, B16>::LOG_WIDTH < 3 {
113 return Err(anyhow::anyhow!(
114 "PackedType::<U, B16>::LOG_WIDTH < 3, this is not supported"
115 ));
116 }
117
118 for (i, lookup_u) in lookup_u_u16.iter_mut().enumerate() {
119 let offset = i << (PackedType::<U, B16>::LOG_WIDTH - 3);
120 let scalars = (0..(PackedType::<U, B16>::WIDTH / 8)).flat_map(|j| {
121 let index = offset + j;
122 let x = xin_u32[index];
123 let y = yin_u32[index];
124 let z = zin_u32[index];
125 let res = (self.f)(x, y, z);
126 res_u32[index] = res;
127
128 let scalars = std::array::from_fn::<_, 8, _>(|k| {
129 let x = ((x >> (4 * k)) & 15) as u16;
130 let y = ((y >> (4 * k)) & 15) as u16;
131 let z = ((z >> (4 * k)) & 15) as u16;
132 let res = ((res >> (4 * k)) & 15) as u16;
133 let lookup_index = (((x << 4) | y) << 4) | z;
134
135 u_to_t_mapping_witness.push(lookup_index as usize);
136
137 B16::new((lookup_index << 4) | res)
138 });
139
140 scalars.into_iter()
141 });
142 *lookup_u = PackedType::<U, B16>::from_scalars(scalars);
143 }
144
145 std::mem::drop(res_witness);
146
147 let res_packed_witness = witness.get::<B1>(res)?;
148 witness.set::<B4>(res_packed, res_packed_witness.repacked::<B4>())?;
149
150 self.u_to_t_mappings.push(u_to_t_mapping_witness);
151 }
152
153 self.lookups_u.push([lookup_u]);
154 self.n_lookups.push(1 << (log_size - B4::TOWER_LEVEL));
155 Ok(res)
156 }
157
158 pub fn finalize(
159 self,
160 builder: &mut ConstraintSystemBuilder,
161 name: impl ToString,
162 ) -> Result<()> {
163 let channel = builder.add_channel();
164
165 lasso::<B32>(
166 builder,
167 name,
168 &self.n_lookups,
169 &self.u_to_t_mappings,
170 &self.lookups_u,
171 [self.lookup_t],
172 channel,
173 )
174 }
175}
176
177pub fn sha256(
178 builder: &mut ConstraintSystemBuilder,
179 input: [OracleId; 16],
180 log_size: usize,
181) -> Result<[OracleId; 8], anyhow::Error> {
182 let n_vars = log_size;
183
184 let mut several_u32_add = SeveralU32add::new(builder)?;
185
186 let mut several_ch = SeveralBitwise::new(builder, |e, f, g| (e & f) ^ ((!e) & g))?;
187
188 let mut several_maj = SeveralBitwise::new(builder, |a, b, c| (a & b) ^ (a & c) ^ (b & c))?;
189
190 let mut w = [OracleId::MAX; 64];
191
192 w[0..16].copy_from_slice(&input);
193
194 for i in 16..64 {
195 let s0 = rotate_and_xor(
196 n_vars,
197 builder,
198 &[
199 (w[i - 15], 7, RotateRightType::Circular),
200 (w[i - 15], 18, RotateRightType::Circular),
201 (w[i - 15], 3, RotateRightType::Logical),
202 ],
203 )?;
204 let s1 = rotate_and_xor(
205 n_vars,
206 builder,
207 &[
208 (w[i - 2], 17, RotateRightType::Circular),
209 (w[i - 2], 19, RotateRightType::Circular),
210 (w[i - 2], 10, RotateRightType::Logical),
211 ],
212 )?;
213
214 let w_addition =
215 several_u32_add.u32add::<B1, B1>(builder, "w_addition", w[i - 16], w[i - 7])?;
216
217 let s_addition = several_u32_add.u32add::<B1, B1>(builder, "s_addition", s0, s1)?;
218
219 w[i] = several_u32_add.u32add::<B1, B1>(
220 builder,
221 format!("w[{}]", i),
222 w_addition,
223 s_addition,
224 )?;
225 }
226
227 let init_oracles = INIT.map(|val| u32const_repeating(n_vars, builder, val, "INIT").unwrap());
228
229 let [mut a, mut b, mut c, mut d, mut e, mut f, mut g, mut h] = init_oracles;
230
231 let k = ROUND_CONSTS_K
232 .map(|val| u32const_repeating(n_vars, builder, val, "ROUND_CONSTS_K").unwrap());
233
234 for i in 0..64 {
235 let sigma1 = rotate_and_xor(
236 n_vars,
237 builder,
238 &[
239 (e, 6, RotateRightType::Circular),
240 (e, 11, RotateRightType::Circular),
241 (e, 25, RotateRightType::Circular),
242 ],
243 )?;
244
245 let ch = several_ch.calculate(builder, "ch", [e, f, g])?;
246
247 let h_sigma1 = several_u32_add.u32add::<B1, B1>(builder, "h_sigma1", h, sigma1)?;
248 let ch_ki = several_u32_add.u32add::<B1, B1>(builder, "ch_ki", ch, k[i])?;
249 let ch_ki_w_i = several_u32_add.u32add::<B1, B1>(builder, "ch_ki_w_i", ch_ki, w[i])?;
250 let temp1 = several_u32_add.u32add::<B1, B1>(builder, "temp1", h_sigma1, ch_ki_w_i)?;
251
252 let sigma0 = rotate_and_xor(
253 n_vars,
254 builder,
255 &[
256 (a, 2, RotateRightType::Circular),
257 (a, 13, RotateRightType::Circular),
258 (a, 22, RotateRightType::Circular),
259 ],
260 )?;
261
262 let maj = several_maj.calculate(builder, "maj", [a, b, c])?;
263
264 let temp2 = several_u32_add.u32add::<B1, B1>(builder, "temp2", sigma0, maj)?;
265
266 h = g;
267 g = f;
268 f = e;
269 e = several_u32_add.u32add::<B1, B1>(builder, "ch_ki_w_i", d, temp1)?;
270 d = c;
271 c = b;
272 b = a;
273 a = several_u32_add.u32add::<B1, B1>(builder, "ch_ki_w_i", temp1, temp2)?;
274 }
275
276 let abcdefgh = [a, b, c, d, e, f, g, h];
277
278 let output = std::array::from_fn(|i| {
279 several_u32_add
280 .u32add::<B1, B1>(builder, "output", init_oracles[i], abcdefgh[i])
281 .unwrap()
282 });
283
284 several_u32_add.finalize(builder, "lasso")?;
285
286 several_ch.finalize(builder, "ch")?;
287 several_maj.finalize(builder, "maj")?;
288
289 Ok(output)
290}
291
#[cfg(test)]
mod tests {
    use binius_core::oracle::OracleId;
    use binius_field::{as_packed_field::PackedType, BinaryField1b, BinaryField8b, TowerField};
    use sha2::{compress256, digest::generic_array::GenericArray};

    use crate::{
        builder::{test_utils::test_circuit, types::U},
        unconstrained::unconstrained,
    };

    /// Runs the circuit on unconstrained inputs and compares every row of the
    /// output witness with `sha2::compress256` applied to the same 16 words.
    #[test]
    fn test_sha256_lasso() {
        test_circuit(|builder| {
            let log_size = PackedType::<U, BinaryField1b>::LOG_WIDTH + BinaryField8b::TOWER_LEVEL;
            let input: [OracleId; 16] = std::array::from_fn(|i| {
                unconstrained::<BinaryField1b>(builder, i, log_size).unwrap()
            });
            let state_output = super::sha256(builder, input, log_size).unwrap();

            if let Some(witness) = builder.witness() {
                // View each message-word and state-word column as a slice of u32 words.
                let input_words: [_; 16] = std::array::from_fn(|i| {
                    witness
                        .get::<BinaryField1b>(input[i])
                        .unwrap()
                        .as_slice::<u32>()
                });

                let output_words: [_; 8] = std::array::from_fn(|i| {
                    witness
                        .get::<BinaryField1b>(state_output[i])
                        .unwrap()
                        .as_slice::<u32>()
                });

                let mut block = GenericArray::<u8, _>::default();

                let n_compressions = input_words[0].len();

                for row in 0..n_compressions {
                    // The reference implementation takes the block as big-endian bytes.
                    for (word_idx, words) in input_words.iter().enumerate() {
                        block[word_idx * 4..(word_idx + 1) * 4]
                            .copy_from_slice(&words[row].to_be_bytes());
                    }

                    let mut expected = crate::sha256::INIT;
                    compress256(&mut expected, &[block]);

                    for (expected_word, words) in expected.iter().zip(output_words.iter()) {
                        assert_eq!(*expected_word, words[row]);
                    }
                }
            }

            Ok(vec![])
        })
        .unwrap();
    }
}
351}