1use anyhow::Result;
4use binius_core::oracle::OracleId;
5use binius_field::{
6 as_packed_field::PackedType, BinaryField16b, BinaryField1b, BinaryField32b, BinaryField4b,
7 PackedFieldIndexable, TowerField,
8};
9use itertools::izip;
10
11use super::{lasso::lasso, u32add::SeveralU32add};
12use crate::{
13 arithmetic::u32::u32const_repeating,
14 builder::{
15 types::{F, U},
16 ConstraintSystemBuilder,
17 },
18 pack::pack,
19 sha256::{rotate_and_xor, RotateRightType, INIT, ROUND_CONSTS_K},
20};
21
/// Log2 of the number of rows in the `Ch`/`Maj` lookup table: one row for every
/// combination of three 4-bit input nibbles, i.e. `3 * 4` index bits.
pub const CH_MAJ_T_LOG_SIZE: usize = 3 * 4;
23
24type B1 = BinaryField1b;
25type B4 = BinaryField4b;
26type B16 = BinaryField16b;
27type B32 = BinaryField32b;
28
29struct SeveralBitwise {
30 n_lookups: Vec<usize>,
31 lookup_t: OracleId,
32 lookups_u: Vec<[OracleId; 1]>,
33 u_to_t_mappings: Vec<Vec<usize>>,
34 f: fn(u32, u32, u32) -> u32,
35}
36
37impl SeveralBitwise {
38 pub fn new(builder: &mut ConstraintSystemBuilder, f: fn(u32, u32, u32) -> u32) -> Result<Self> {
39 let lookup_t =
40 builder.add_committed("bitwise lookup_t", CH_MAJ_T_LOG_SIZE, B16::TOWER_LEVEL);
41
42 if let Some(witness) = builder.witness() {
43 let mut lookup_t_witness = witness.new_column::<B16>(lookup_t);
44
45 let lookup_t_scalars =
46 PackedType::<U, B16>::unpack_scalars_mut(lookup_t_witness.packed());
47
48 for (i, lookup_t) in lookup_t_scalars.iter_mut().enumerate() {
49 let x = ((i >> 8) & 15) as u16;
50 let y = ((i >> 4) & 15) as u16;
51 let z = (i & 15) as u16;
52
53 let res = f(x as u32, y as u32, z as u32);
54
55 let lookup_index = (((x << 4) | y) << 4) | z;
56 *lookup_t = B16::new((lookup_index << 4) | res as u16);
57 }
58 }
59 Ok(Self {
60 n_lookups: Vec::new(),
61 lookup_t,
62 lookups_u: Vec::new(),
63 u_to_t_mappings: Vec::new(),
64 f,
65 })
66 }
67
68 pub fn calculate(
69 &mut self,
70 builder: &mut ConstraintSystemBuilder,
71 name: impl ToString,
72 params: [OracleId; 3],
73 ) -> Result<OracleId> {
74 let [xin, yin, zin] = params;
75
76 let log_size = builder.log_rows(params)?;
77
78 let xin_packed = pack::<B1, B4>(xin, builder, "xin_packed")?;
79 let yin_packed = pack::<B1, B4>(yin, builder, "yin_packed")?;
80 let zin_packed = pack::<B1, B4>(zin, builder, "zin_packed")?;
81
82 let res = builder.add_committed(name, log_size, B1::TOWER_LEVEL);
83
84 let res_packed = builder.add_packed("res_packed", res, B4::TOWER_LEVEL)?;
85
86 let lookup_u = builder.add_linear_combination(
87 "ch or maj lookup_u",
88 log_size - B4::TOWER_LEVEL,
89 [
90 (xin_packed, <F as TowerField>::basis(B4::TOWER_LEVEL, 3)?),
91 (yin_packed, <F as TowerField>::basis(B4::TOWER_LEVEL, 2)?),
92 (zin_packed, <F as TowerField>::basis(B4::TOWER_LEVEL, 1)?),
93 (res_packed, <F as TowerField>::basis(B4::TOWER_LEVEL, 0)?),
94 ],
95 )?;
96
97 if let Some(witness) = builder.witness() {
98 let mut lookup_u_witness = witness.new_column::<B16>(lookup_u);
99 let lookup_u_u16 = PackedType::<U, B16>::unpack_scalars_mut(lookup_u_witness.packed());
100
101 let mut u_to_t_mapping_witness = Vec::with_capacity(1 << (log_size - B4::TOWER_LEVEL));
102
103 let mut res_witness = witness.new_column::<B1>(res);
104 let res_u32 = res_witness.as_mut_slice::<u32>();
105
106 let xin_u32 = witness.get::<B1>(xin)?.as_slice::<u32>();
107
108 let yin_u32 = witness.get::<B1>(yin)?.as_slice::<u32>();
109
110 let zin_u32 = witness.get::<B1>(zin)?.as_slice::<u32>();
111
112 for (res, x, y, z, lookup_u) in
113 izip!(res_u32.iter_mut(), xin_u32, yin_u32, zin_u32, lookup_u_u16.chunks_mut(8))
114 {
115 *res = (self.f)(*x, *y, *z);
116
117 #[allow(clippy::needless_range_loop)]
118 for i in 0..8 {
119 let x = ((*x >> (4 * i)) & 15) as u16;
120 let y = ((*y >> (4 * i)) & 15) as u16;
121 let z = ((*z >> (4 * i)) & 15) as u16;
122 let res = ((*res >> (4 * i)) & 15) as u16;
123 let lookup_index = (((x << 4) | y) << 4) | z;
124 lookup_u[i] = B16::new((lookup_index << 4) | res);
125 u_to_t_mapping_witness.push(lookup_index as usize)
126 }
127 }
128
129 std::mem::drop(res_witness);
130
131 let res_packed_witness = witness.get::<B1>(res)?;
132 witness.set::<B4>(res_packed, res_packed_witness.repacked::<B4>())?;
133
134 self.u_to_t_mappings.push(u_to_t_mapping_witness);
135 }
136
137 self.lookups_u.push([lookup_u]);
138 self.n_lookups.push(1 << (log_size - B4::TOWER_LEVEL));
139 Ok(res)
140 }
141
142 pub fn finalize(
143 self,
144 builder: &mut ConstraintSystemBuilder,
145 name: impl ToString,
146 ) -> Result<()> {
147 let channel = builder.add_channel();
148
149 lasso::<B32>(
150 builder,
151 name,
152 &self.n_lookups,
153 &self.u_to_t_mappings,
154 &self.lookups_u,
155 [self.lookup_t],
156 channel,
157 )
158 }
159}
160
161pub fn sha256(
162 builder: &mut ConstraintSystemBuilder,
163 input: [OracleId; 16],
164 log_size: usize,
165) -> Result<[OracleId; 8], anyhow::Error> {
166 let n_vars = log_size;
167
168 let mut several_u32_add = SeveralU32add::new(builder)?;
169
170 let mut several_ch = SeveralBitwise::new(builder, |e, f, g| (e & f) ^ ((!e) & g))?;
171
172 let mut several_maj = SeveralBitwise::new(builder, |a, b, c| (a & b) ^ (a & c) ^ (b & c))?;
173
174 let mut w = [OracleId::MAX; 64];
175
176 w[0..16].copy_from_slice(&input);
177
178 for i in 16..64 {
179 let s0 = rotate_and_xor(
180 n_vars,
181 builder,
182 &[
183 (w[i - 15], 7, RotateRightType::Circular),
184 (w[i - 15], 18, RotateRightType::Circular),
185 (w[i - 15], 3, RotateRightType::Logical),
186 ],
187 )?;
188 let s1 = rotate_and_xor(
189 n_vars,
190 builder,
191 &[
192 (w[i - 2], 17, RotateRightType::Circular),
193 (w[i - 2], 19, RotateRightType::Circular),
194 (w[i - 2], 10, RotateRightType::Logical),
195 ],
196 )?;
197
198 let w_addition =
199 several_u32_add.u32add::<B1, B1>(builder, "w_addition", w[i - 16], w[i - 7])?;
200
201 let s_addition = several_u32_add.u32add::<B1, B1>(builder, "s_addition", s0, s1)?;
202
203 w[i] = several_u32_add.u32add::<B1, B1>(
204 builder,
205 format!("w[{}]", i),
206 w_addition,
207 s_addition,
208 )?;
209 }
210
211 let init_oracles = INIT.map(|val| u32const_repeating(n_vars, builder, val, "INIT").unwrap());
212
213 let [mut a, mut b, mut c, mut d, mut e, mut f, mut g, mut h] = init_oracles;
214
215 let k = ROUND_CONSTS_K
216 .map(|val| u32const_repeating(n_vars, builder, val, "ROUND_CONSTS_K").unwrap());
217
218 for i in 0..64 {
219 let sigma1 = rotate_and_xor(
220 n_vars,
221 builder,
222 &[
223 (e, 6, RotateRightType::Circular),
224 (e, 11, RotateRightType::Circular),
225 (e, 25, RotateRightType::Circular),
226 ],
227 )?;
228
229 let ch = several_ch.calculate(builder, "ch", [e, f, g])?;
230
231 let h_sigma1 = several_u32_add.u32add::<B1, B1>(builder, "h_sigma1", h, sigma1)?;
232 let ch_ki = several_u32_add.u32add::<B1, B1>(builder, "ch_ki", ch, k[i])?;
233 let ch_ki_w_i = several_u32_add.u32add::<B1, B1>(builder, "ch_ki_w_i", ch_ki, w[i])?;
234 let temp1 = several_u32_add.u32add::<B1, B1>(builder, "temp1", h_sigma1, ch_ki_w_i)?;
235
236 let sigma0 = rotate_and_xor(
237 n_vars,
238 builder,
239 &[
240 (a, 2, RotateRightType::Circular),
241 (a, 13, RotateRightType::Circular),
242 (a, 22, RotateRightType::Circular),
243 ],
244 )?;
245
246 let maj = several_maj.calculate(builder, "maj", [a, b, c])?;
247
248 let temp2 = several_u32_add.u32add::<B1, B1>(builder, "temp2", sigma0, maj)?;
249
250 h = g;
251 g = f;
252 f = e;
253 e = several_u32_add.u32add::<B1, B1>(builder, "ch_ki_w_i", d, temp1)?;
254 d = c;
255 c = b;
256 b = a;
257 a = several_u32_add.u32add::<B1, B1>(builder, "ch_ki_w_i", temp1, temp2)?;
258 }
259
260 let abcdefgh = [a, b, c, d, e, f, g, h];
261
262 let output = std::array::from_fn(|i| {
263 several_u32_add
264 .u32add::<B1, B1>(builder, "output", init_oracles[i], abcdefgh[i])
265 .unwrap()
266 });
267
268 several_u32_add.finalize(builder, "lasso")?;
269
270 several_ch.finalize(builder, "ch")?;
271 several_maj.finalize(builder, "maj")?;
272
273 Ok(output)
274}
275
#[cfg(test)]
mod tests {
    use binius_core::oracle::OracleId;
    use binius_field::{as_packed_field::PackedType, BinaryField1b, BinaryField8b, TowerField};
    use sha2::{compress256, digest::generic_array::GenericArray};

    use crate::{
        builder::{test_utils::test_circuit, types::U},
        unconstrained::unconstrained,
    };

    /// Compares the circuit's output words against `sha2::compress256` for every
    /// 32-bit lane of random (unconstrained) input columns.
    #[test]
    fn test_sha256_lasso() {
        test_circuit(|builder| {
            let log_size = PackedType::<U, BinaryField1b>::LOG_WIDTH + BinaryField8b::TOWER_LEVEL;
            let input: [OracleId; 16] = std::array::from_fn(|i| {
                unconstrained::<BinaryField1b>(builder, i, log_size).unwrap()
            });
            let state_output = super::sha256(builder, input, log_size).unwrap();

            if let Some(witness) = builder.witness() {
                // Each u32 of a column is the corresponding word of one independent block.
                let input_words: [_; 16] = std::array::from_fn(|i| {
                    witness.get::<BinaryField1b>(input[i]).unwrap().as_slice::<u32>()
                });
                let output_words: [_; 8] = std::array::from_fn(|i| {
                    witness.get::<BinaryField1b>(state_output[i]).unwrap().as_slice::<u32>()
                });

                let mut block = GenericArray::<u8, _>::default();
                let lane_count = input_words[0].len();

                for lane in 0..lane_count {
                    // SHA-256 reads message words big-endian.
                    for (word_index, words) in input_words.iter().enumerate() {
                        let bytes = words[lane].to_be_bytes();
                        for (byte_index, &byte) in bytes.iter().enumerate() {
                            block[word_index * 4 + byte_index] = byte;
                        }
                    }

                    let mut expected = crate::sha256::INIT;
                    compress256(&mut expected, &[block]);

                    for (expected_word, words) in expected.iter().zip(output_words.iter()) {
                        assert_eq!(*expected_word, words[lane]);
                    }
                }
            }

            Ok(vec![])
        })
        .unwrap();
    }
}