1use binius_core::oracle::{OracleId, ShiftVariant};
4use binius_field::{as_packed_field::PackedType, BinaryField1b, Field, TowerField};
5use binius_macros::arith_expr;
6use itertools::izip;
7
8use crate::{
9 arithmetic,
10 arithmetic::u32::{u32const_repeating, LOG_U32_BITS},
11 builder::{types::U, ConstraintSystemBuilder},
12};
13
14type B1 = BinaryField1b;
15
/// Per-round additive constants of the compression loop in [`sha256`]: entry `i` is turned
/// into a repeating constant column and added into `temp1` during round `i`.
///
/// The values are the standard SHA-256 round-constant table.
pub const ROUND_CONSTS_K: [u32; 64] = [
    // rounds 0..=15
    0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5,
    0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5,
    0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3,
    0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174,
    // rounds 16..=31
    0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc,
    0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da,
    0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7,
    0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967,
    // rounds 32..=47
    0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13,
    0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85,
    0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3,
    0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070,
    // rounds 48..=63
    0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5,
    0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
    0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208,
    0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2,
];
27
/// Initial values of the eight working variables `a..h` used by [`sha256`]; the same words
/// are added back onto the final working variables to form the output state.
///
/// The values are the standard SHA-256 initial hash value.
pub const INIT: [u32; 8] = [
    0x6a09e667, // a
    0xbb67ae85, // b
    0x3c6ef372, // c
    0xa54ff53a, // d
    0x510e527f, // e
    0x9b05688c, // f
    0x1f83d9ab, // g
    0x5be0cd19, // h
];
31
/// Selects how a 32-bit word is moved to the right by [`rotate_and_xor`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RotateRightType {
    /// Bits shifted out on the right re-enter on the left (`u32::rotate_right`).
    Circular,
    /// Bits shifted out on the right are dropped and zeros enter on the left (`>>`).
    Logical,
}
36
37pub fn rotate_and_xor(
38 log_size: usize,
39 builder: &mut ConstraintSystemBuilder,
40 r: &[(OracleId, usize, RotateRightType)],
41) -> Result<OracleId, anyhow::Error> {
42 let shifted_oracle_ids = r
43 .iter()
44 .map(|(oracle_id, shift, t)| {
45 match t {
46 RotateRightType::Circular => builder.add_shifted(
47 format!("RotateRightType::Circular shift:{} oracle_id: {}", shift, oracle_id),
48 *oracle_id,
49 32 - shift,
50 LOG_U32_BITS,
51 ShiftVariant::CircularLeft,
52 ),
53 RotateRightType::Logical => builder.add_shifted(
54 format!("RotateRightType::Logical shift:{} oracle_id: {}", shift, oracle_id),
55 *oracle_id,
56 *shift,
57 LOG_U32_BITS,
58 ShiftVariant::LogicalRight,
59 ),
60 }
61 .map_err(|e| e.into())
62 })
63 .collect::<Result<Vec<_>, anyhow::Error>>()?;
64
65 let result_oracle_id = builder.add_linear_combination(
66 format!("linear combination of {:?}", shifted_oracle_ids),
67 log_size,
68 shifted_oracle_ids.iter().map(|s| (*s, Field::ONE)),
69 )?;
70
71 if let Some(witness) = builder.witness() {
72 let mut result_witness = witness.new_column::<B1>(result_oracle_id);
73 let result_u32 = result_witness.as_mut_slice::<u32>();
74
75 for ((oracle_id, shift, t), shifted_oracle_id) in r.iter().zip(&shifted_oracle_ids) {
76 let values_u32 = witness.get::<B1>(*oracle_id)?.as_slice::<u32>();
77
78 let mut shifted_witness = witness.new_column::<B1>(*shifted_oracle_id);
79 let shifted_u32 = shifted_witness.as_mut_slice::<u32>();
80
81 izip!(shifted_u32.iter_mut(), values_u32, result_u32.iter_mut()).for_each(
82 |(shifted, val, res)| {
83 *shifted = match t {
84 RotateRightType::Circular => val.rotate_right(*shift as u32),
85 RotateRightType::Logical => val >> shift,
86 };
87 *res ^= *shifted;
88 },
89 );
90 }
91 }
92
93 Ok(result_oracle_id)
94}
95
96pub fn sha256(
97 builder: &mut ConstraintSystemBuilder,
98 input: [OracleId; 16],
99 log_size: usize,
100) -> Result<[OracleId; 8], anyhow::Error> {
101 if log_size < <PackedType<U, BinaryField1b>>::LOG_WIDTH {
102 Err(anyhow::Error::msg("log_size too small"))?
103 }
104
105 let mut w = [OracleId::MAX; 64];
106
107 w[0..16].copy_from_slice(&input);
108
109 for i in 16..64 {
110 let s0 = rotate_and_xor(
111 log_size,
112 builder,
113 &[
114 (w[i - 15], 7, RotateRightType::Circular),
115 (w[i - 15], 18, RotateRightType::Circular),
116 (w[i - 15], 3, RotateRightType::Logical),
117 ],
118 )?;
119 let s1 = rotate_and_xor(
120 log_size,
121 builder,
122 &[
123 (w[i - 2], 17, RotateRightType::Circular),
124 (w[i - 2], 19, RotateRightType::Circular),
125 (w[i - 2], 10, RotateRightType::Logical),
126 ],
127 )?;
128 let w_addition = arithmetic::u32::add(
129 builder,
130 "w_addition",
131 w[i - 16],
132 w[i - 7],
133 arithmetic::Flags::Unchecked,
134 )?;
135 let s_addition =
136 arithmetic::u32::add(builder, "s_addition", s0, s1, arithmetic::Flags::Unchecked)?;
137
138 w[i] = arithmetic::u32::add(
139 builder,
140 format!("w[{}]", i),
141 w_addition,
142 s_addition,
143 arithmetic::Flags::Unchecked,
144 )?;
145 }
146
147 let init_oracles = INIT.map(|val| u32const_repeating(log_size, builder, val, "INIT").unwrap());
148
149 let [mut a, mut b, mut c, mut d, mut e, mut f, mut g, mut h] = init_oracles;
150
151 let k = ROUND_CONSTS_K
152 .map(|val| u32const_repeating(log_size, builder, val, "ROUND_CONSTS_K").unwrap());
153
154 let ch: [OracleId; 64] = builder.add_committed_multiple("ch", log_size, B1::TOWER_LEVEL);
155
156 let maj: [OracleId; 64] = builder.add_committed_multiple("maj", log_size, B1::TOWER_LEVEL);
157
158 for i in 0..64 {
159 let sigma1 = rotate_and_xor(
160 log_size,
161 builder,
162 &[
163 (e, 6, RotateRightType::Circular),
164 (e, 11, RotateRightType::Circular),
165 (e, 25, RotateRightType::Circular),
166 ],
167 )?;
168
169 if let Some(witness) = builder.witness() {
170 let mut ch_witness = witness.new_column::<B1>(ch[i]);
171 let ch_u32 = ch_witness.as_mut_slice::<u32>();
172 let e_u32 = witness.get::<B1>(e)?.as_slice::<u32>();
173 let f_u32 = witness.get::<B1>(f)?.as_slice::<u32>();
174 let g_u32 = witness.get::<B1>(g)?.as_slice::<u32>();
175 izip!(ch_u32.iter_mut(), e_u32, f_u32, g_u32).for_each(|(ch, e, f, g)| {
176 *ch = g ^ (e & (f ^ g));
177 });
178 }
179
180 let h_sigma1 =
181 arithmetic::u32::add(builder, "h_sigma1", h, sigma1, arithmetic::Flags::Unchecked)?;
182 let ch_ki =
183 arithmetic::u32::add(builder, "ch_ki", ch[i], k[i], arithmetic::Flags::Unchecked)?;
184 let ch_ki_w_i =
185 arithmetic::u32::add(builder, "ch_ki_w_i", ch_ki, w[i], arithmetic::Flags::Unchecked)?;
186 let temp1 = arithmetic::u32::add(
187 builder,
188 "temp1",
189 h_sigma1,
190 ch_ki_w_i,
191 arithmetic::Flags::Unchecked,
192 )?;
193
194 let sigma0 = rotate_and_xor(
195 log_size,
196 builder,
197 &[
198 (a, 2, RotateRightType::Circular),
199 (a, 13, RotateRightType::Circular),
200 (a, 22, RotateRightType::Circular),
201 ],
202 )?;
203
204 if let Some(witness) = builder.witness() {
205 let mut maj_witness = witness.new_column::<B1>(maj[i]);
206 let maj_u32 = maj_witness.as_mut_slice::<u32>();
207 let a_u32 = witness.get::<B1>(a)?.as_slice::<u32>();
208 let b_u32 = witness.get::<B1>(b)?.as_slice::<u32>();
209 let c_u32 = witness.get::<B1>(c)?.as_slice::<u32>();
210 izip!(maj_u32.iter_mut(), a_u32, b_u32, c_u32).for_each(|(maj, a, b, c)| {
211 *maj = (a & (b ^ c)) ^ (b & c);
212 });
213 }
214
215 let temp2 =
216 arithmetic::u32::add(builder, "temp2", sigma0, maj[i], arithmetic::Flags::Unchecked)?;
217
218 builder.assert_zero(
223 format!("ch_{i}"),
224 [e, f, g, ch[i]],
225 arith_expr!([e, f, g, ch] = (g + e * (f + g)) - ch).convert_field(),
226 );
227
228 builder.assert_zero(
229 format!("maj_{i}"),
230 [a, b, c, maj[i]],
231 arith_expr!([a, b, c, maj] = maj - (a * (b + c)) + b * c).convert_field(),
232 );
233
234 h = g;
235 g = f;
236 f = e;
237 e = arithmetic::u32::add(builder, "e", d, temp1, arithmetic::Flags::Unchecked)?;
238 d = c;
239 c = b;
240 b = a;
241 a = arithmetic::u32::add(builder, "a", temp1, temp2, arithmetic::Flags::Unchecked)?;
242 }
243
244 let abcdefgh = [a, b, c, d, e, f, g, h];
245
246 let output = std::array::from_fn(|i| {
247 arithmetic::u32::add(
248 builder,
249 "output",
250 init_oracles[i],
251 abcdefgh[i],
252 arithmetic::Flags::Unchecked,
253 )
254 .unwrap()
255 });
256
257 Ok(output)
258}
259
#[cfg(test)]
mod tests {
    use binius_core::oracle::OracleId;
    use binius_field::{as_packed_field::PackedType, BinaryField1b};
    use sha2::{compress256, digest::generic_array::GenericArray};

    use crate::{
        builder::{test_utils::test_circuit, types::U},
        unconstrained::unconstrained,
    };

    /// Runs the circuit on unconstrained input columns and checks every `u32` position of
    /// the output witness against `sha2::compress256` applied to the same block.
    #[test]
    fn test_sha256() {
        let outcome = test_circuit(|builder| {
            let log_size = PackedType::<U, BinaryField1b>::LOG_WIDTH;
            let input: [OracleId; 16] = std::array::from_fn(|i| {
                unconstrained::<BinaryField1b>(builder, i, log_size).unwrap()
            });
            let state_output = super::sha256(builder, input, log_size).unwrap();

            if let Some(witness) = builder.witness() {
                // One `u32` slice per message word / per output word.
                let input_columns: [_; 16] = input.map(|oracle_id| {
                    witness
                        .get::<BinaryField1b>(oracle_id)
                        .unwrap()
                        .as_slice::<u32>()
                });

                let output_columns: [_; 8] = state_output.map(|oracle_id| {
                    witness
                        .get::<BinaryField1b>(oracle_id)
                        .unwrap()
                        .as_slice::<u32>()
                });

                let mut block = GenericArray::<u8, _>::default();

                let n_compressions = input_columns[0].len();

                for j in 0..n_compressions {
                    // Serialise the j-th word of every input column big-endian into the block.
                    for (word_idx, column) in input_columns.iter().enumerate() {
                        let bytes = column[j].to_be_bytes();
                        for (byte_idx, byte) in bytes.into_iter().enumerate() {
                            block[word_idx * 4 + byte_idx] = byte;
                        }
                    }

                    let mut expected_state = crate::sha256::INIT;
                    compress256(&mut expected_state, &[block]);

                    for (expected_word, column) in expected_state.iter().zip(&output_columns) {
                        assert_eq!(*expected_word, column[j]);
                    }
                }
            }

            Ok(vec![])
        });
        outcome.unwrap();
    }
}