1use binius_core::{
4 oracle::{OracleId, ProjectionVariant, ShiftVariant},
5 transparent::MultilinearExtensionTransparent,
6};
7use binius_field::{
8 as_packed_field::PackedType, packed::set_packed_slice, underlier::WithUnderlier, BinaryField1b,
9 BinaryField32b, Field, PackedField, TowerField,
10};
11use binius_macros::arith_expr;
12use binius_maybe_rayon::prelude::*;
13use binius_utils::checked_arithmetics::checked_log_2;
14use bytemuck::{pod_collect_to_vec, Pod};
15
16use crate::{
17 builder::{
18 types::{F, U},
19 ConstraintSystemBuilder,
20 },
21 transparent,
22};
23
/// Shorthand for the 1-bit binary field; every bit column in this module uses it as its
/// element type.
type B1 = BinaryField1b;
/// Shorthand for the 32-bit binary field; used when 32 bits are packed into one element.
type B32 = BinaryField32b;
26
27pub fn packed(
28 builder: &mut ConstraintSystemBuilder,
29 name: impl ToString,
30 input: OracleId,
31) -> Result<OracleId, anyhow::Error> {
32 let packed = builder.add_packed(name, input, 5)?;
33 if let Some(witness) = builder.witness() {
34 witness.set(packed, witness.get::<B1>(input)?.repacked::<B32>())?;
35 }
36 Ok(packed)
37}
38
39pub fn mul_const(
40 builder: &mut ConstraintSystemBuilder,
41 name: impl ToString,
42 input: OracleId,
43 value: u32,
44 flags: super::Flags,
45) -> Result<OracleId, anyhow::Error> {
46 if value == 0 {
47 let log_rows = builder.log_rows([input])?;
48 return transparent::constant(builder, name, log_rows, B1::ZERO);
49 }
50
51 if value == 1 {
52 return Ok(input);
53 }
54
55 builder.push_namespace(name);
56 let mut tmp = value;
57 let mut offset = 0;
58 let mut result = input;
59 let mut first = true;
60 while tmp != 0 {
61 if tmp & 1 == 1 {
62 let shifted = shl(builder, format!("input_shl{offset}"), input, offset)?;
63 if first {
64 result = shifted;
65 first = false;
66 } else {
67 result = add(builder, format!("add_shl{offset}"), result, shifted, flags)?;
68 }
69 }
70 tmp >>= 1;
71 if tmp != 0 {
72 offset += 1;
73 }
74 }
75
76 if matches!(flags, super::Flags::Checked) {
77 for i in 32 - offset..32 {
79 let x = select_bit(builder, format!("bit{i}"), input, i)?;
80 builder.assert_zero("overflow", [x], arith_expr!([x] = x).convert_field());
81 }
82 }
83
84 builder.pop_namespace();
85 Ok(result)
86}
87
88pub fn add(
89 builder: &mut ConstraintSystemBuilder,
90 name: impl ToString,
91 xin: OracleId,
92 yin: OracleId,
93 flags: super::Flags,
94) -> Result<OracleId, anyhow::Error> {
95 builder.push_namespace(name);
96 let log_rows = builder.log_rows([xin, yin])?;
97 let cout = builder.add_committed("cout", log_rows, B1::TOWER_LEVEL);
98 let cin = builder.add_shifted("cin", cout, 1, 5, ShiftVariant::LogicalLeft)?;
99 let zout = builder.add_committed("zout", log_rows, B1::TOWER_LEVEL);
100
101 if let Some(witness) = builder.witness() {
102 (
103 witness.get::<B1>(xin)?.as_slice::<u32>(),
104 witness.get::<B1>(yin)?.as_slice::<u32>(),
105 witness.new_column::<B1>(zout).as_mut_slice::<u32>(),
106 witness.new_column::<B1>(cout).as_mut_slice::<u32>(),
107 witness.new_column::<B1>(cin).as_mut_slice::<u32>(),
108 )
109 .into_par_iter()
110 .for_each(|(xin, yin, zout, cout, cin)| {
111 let carry;
112 (*zout, carry) = (*xin).overflowing_add(*yin);
113 *cin = (*xin) ^ (*yin) ^ (*zout);
114 *cout = ((carry as u32) << 31) | (*cin >> 1);
115 });
116 }
117
118 builder.assert_zero(
119 "sum",
120 [xin, yin, cin, zout],
121 arith_expr!([xin, yin, cin, zout] = xin + yin + cin - zout).convert_field(),
122 );
123
124 builder.assert_zero(
125 "carry",
126 [xin, yin, cin, cout],
127 arith_expr!([xin, yin, cin, cout] = (xin + cin) * (yin + cin) + cin - cout).convert_field(),
128 );
129
130 if matches!(flags, super::Flags::Checked) {
132 let last_cout = select_bit(builder, "last_cout", cout, 31)?;
133 builder.assert_zero(
134 "overflow",
135 [last_cout],
136 arith_expr!([last_cout] = last_cout).convert_field(),
137 );
138 }
139
140 builder.pop_namespace();
141 Ok(zout)
142}
143
144pub fn sub(
145 builder: &mut ConstraintSystemBuilder,
146 name: impl ToString,
147 zin: OracleId,
148 yin: OracleId,
149 flags: super::Flags,
150) -> Result<OracleId, anyhow::Error> {
151 builder.push_namespace(name);
152 let log_rows = builder.log_rows([zin, yin])?;
153 let cout = builder.add_committed("cout", log_rows, B1::TOWER_LEVEL);
154 let cin = builder.add_shifted("cin", cout, 1, 5, ShiftVariant::LogicalLeft)?;
155 let xout = builder.add_committed("xin", log_rows, B1::TOWER_LEVEL);
156
157 if let Some(witness) = builder.witness() {
158 (
159 witness.get::<B1>(zin)?.as_slice::<u32>(),
160 witness.get::<B1>(yin)?.as_slice::<u32>(),
161 witness.new_column::<B1>(xout).as_mut_slice::<u32>(),
162 witness.new_column::<B1>(cout).as_mut_slice::<u32>(),
163 witness.new_column::<B1>(cin).as_mut_slice::<u32>(),
164 )
165 .into_par_iter()
166 .for_each(|(zout, yin, xin, cout, cin)| {
167 let carry;
168 (*xin, carry) = (*zout).overflowing_sub(*yin);
169 *cin = (*xin) ^ (*yin) ^ (*zout);
170 *cout = ((carry as u32) << 31) | (*cin >> 1);
171 });
172 }
173
174 builder.assert_zero(
175 "sum",
176 [xout, yin, cin, zin],
177 arith_expr!([xout, yin, cin, zin] = xout + yin + cin - zin).convert_field(),
178 );
179
180 builder.assert_zero(
181 "carry",
182 [xout, yin, cin, cout],
183 arith_expr!([xout, yin, cin, cout] = (xout + cin) * (yin + cin) + cin - cout)
184 .convert_field(),
185 );
186
187 if matches!(flags, super::Flags::Checked) {
189 let last_cout = select_bit(builder, "last_cout", cout, 31)?;
190 builder.assert_zero(
191 "underflow",
192 [last_cout],
193 arith_expr!([last_cout] = last_cout).convert_field(),
194 );
195 }
196
197 builder.pop_namespace();
198 Ok(xout)
199}
200
201pub fn half(
202 builder: &mut ConstraintSystemBuilder,
203 name: impl ToString,
204 input: OracleId,
205 flags: super::Flags,
206) -> Result<OracleId, anyhow::Error> {
207 if matches!(flags, super::Flags::Checked) {
208 let lsb = select_bit(builder, "lsb", input, 0)?;
210 builder.assert_zero("is_even", [lsb], arith_expr!([lsb] = lsb).convert_field());
211 }
212 shr(builder, name, input, 1)
213}
214
215pub fn shl(
216 builder: &mut ConstraintSystemBuilder,
217 name: impl ToString,
218 input: OracleId,
219 offset: usize,
220) -> Result<OracleId, anyhow::Error> {
221 if offset == 0 {
222 return Ok(input);
223 }
224
225 let shifted = builder.add_shifted(name, input, offset, 5, ShiftVariant::LogicalLeft)?;
226 if let Some(witness) = builder.witness() {
227 (
228 witness.new_column::<B1>(shifted).as_mut_slice::<u32>(),
229 witness.get::<B1>(input)?.as_slice::<u32>(),
230 )
231 .into_par_iter()
232 .for_each(|(shifted, input)| *shifted = *input << offset);
233 }
234
235 Ok(shifted)
236}
237
238pub fn shr(
239 builder: &mut ConstraintSystemBuilder,
240 name: impl ToString,
241 input: OracleId,
242 offset: usize,
243) -> Result<OracleId, anyhow::Error> {
244 if offset == 0 {
245 return Ok(input);
246 }
247
248 let shifted = builder.add_shifted(name, input, offset, 5, ShiftVariant::LogicalRight)?;
249 if let Some(witness) = builder.witness() {
250 (
251 witness.new_column::<B1>(shifted).as_mut_slice::<u32>(),
252 witness.get::<B1>(input)?.as_slice::<u32>(),
253 )
254 .into_par_iter()
255 .for_each(|(shifted, input)| *shifted = *input >> offset);
256 }
257
258 Ok(shifted)
259}
260
261pub fn select_bit(
262 builder: &mut ConstraintSystemBuilder,
263 name: impl ToString,
264 input: OracleId,
265 index: usize,
266) -> Result<OracleId, anyhow::Error> {
267 let log_rows = builder.log_rows([input])?;
268 anyhow::ensure!(log_rows >= 5, "Polynomial must have n_vars >= 5. Got {log_rows}");
269 anyhow::ensure!(index < 32, "Only index values between 0 and 32 are allowed. Got {index}");
270
271 let query = binius_core::polynomial::test_utils::decompose_index_to_hypercube_point(5, index);
272 let bits = builder.add_projected(name, input, query, ProjectionVariant::FirstVars)?;
273
274 if let Some(witness) = builder.witness() {
275 let mut bits = witness.new_column::<B1>(bits);
276 let bits = bits.packed();
277 let input = witness.get::<B1>(input)?.as_slice::<u32>();
278 input.iter().enumerate().for_each(|(i, &val)| {
279 let value = match (val >> index) & 1 {
280 0 => B1::ZERO,
281 _ => B1::ONE,
282 };
283 set_packed_slice(bits, i, value);
284 });
285 }
286
287 Ok(bits)
288}
289
290pub fn constant(
291 builder: &mut ConstraintSystemBuilder,
292 name: impl ToString,
293 log_count: usize,
294 value: u32,
295) -> Result<OracleId, anyhow::Error> {
296 builder.push_namespace(name);
297 let output = builder.add_committed("output", log_count + 5, B1::TOWER_LEVEL);
299 if let Some(witness) = builder.witness() {
300 witness.new_column::<B1>(output).as_mut_slice().fill(value);
301 }
302
303 let output_packed = builder.add_packed("output_packed", output, 5)?;
304 let transparent = builder.add_transparent(
305 "transparent",
306 binius_core::transparent::constant::Constant::new(log_count, B32::new(value)),
307 )?;
308 if let Some(witness) = builder.witness() {
309 let packed = witness.get::<B1>(output)?.repacked::<B32>();
310 witness.set(output_packed, packed)?;
311 witness.set(transparent, packed)?;
312 }
313 builder.assert_zero(
314 "unpack",
315 [output_packed, transparent],
316 arith_expr!([x, y] = x - y).convert_field(),
317 );
318 builder.pop_namespace();
319 Ok(output)
320}
321
322pub const LOG_U32_BITS: usize = checked_log_2(32);
323
324#[inline]
325fn into_packed_vec<P>(src: &[impl Pod]) -> Vec<P>
326where
327 P: PackedField + WithUnderlier,
328 P::Underlier: Pod,
329{
330 pod_collect_to_vec::<_, P::Underlier>(src)
331 .into_iter()
332 .map(P::from_underlier)
333 .collect()
334}
335
336pub fn u32const_repeating(
337 log_size: usize,
338 builder: &mut ConstraintSystemBuilder,
339 x: u32,
340 name: &str,
341) -> Result<OracleId, anyhow::Error> {
342 let brodcasted = vec![x; 1 << (PackedType::<U, B1>::LOG_WIDTH.saturating_sub(LOG_U32_BITS))];
343
344 let transparent_id = builder.add_transparent(
345 format!("transparent {}", name),
346 MultilinearExtensionTransparent::<_, PackedType<U, F>, _>::from_values(into_packed_vec::<
347 PackedType<U, B1>,
348 >(&brodcasted))?,
349 )?;
350
351 let repeating_id = builder.add_repeating(
352 format!("repeating {}", name),
353 transparent_id,
354 log_size - PackedType::<U, B1>::LOG_WIDTH,
355 )?;
356
357 if let Some(witness) = builder.witness() {
358 let mut transparent_witness = witness.new_column::<B1>(transparent_id);
359 transparent_witness.as_mut_slice::<u32>().fill(x);
360
361 let mut repeating_witness = witness.new_column::<B1>(repeating_id);
362 repeating_witness.as_mut_slice::<u32>().fill(x);
363 }
364
365 Ok(repeating_id)
366}
367
#[cfg(test)]
mod tests {
    use binius_field::{BinaryField1b, TowerField};

    use crate::{arithmetic, builder::test_utils::test_circuit, unconstrained::unconstrained};

    /// Multiplying 0x4000_0000 by 3 fits in 32 bits, so the checked circuit must be satisfied.
    #[test]
    fn test_mul_const() {
        test_circuit(|builder| {
            let a = builder.add_committed("a", 5, BinaryField1b::TOWER_LEVEL);
            if let Some(witness) = builder.witness() {
                witness
                    .new_column::<BinaryField1b>(a)
                    .as_mut_slice::<u32>()
                    .iter_mut()
                    .for_each(|v| *v = 0b01000000_00000000_00000000_00000000u32);
            }
            let _c = arithmetic::u32::mul_const(builder, "mul3", a, 3, arithmetic::Flags::Checked)?;
            Ok(vec![])
        })
        .unwrap();
    }

    /// Wrapping addition of two random columns must satisfy the unchecked adder.
    #[test]
    fn test_add() {
        test_circuit(|builder| {
            let log_size = 14;
            let a = unconstrained::<BinaryField1b>(builder, "a", log_size)?;
            let b = unconstrained::<BinaryField1b>(builder, "b", log_size)?;
            let _c = arithmetic::u32::add(builder, "u32add", a, b, arithmetic::Flags::Unchecked)?;
            Ok(vec![])
        })
        .unwrap();
    }

    /// Wrapping subtraction of two random columns must satisfy the unchecked subtractor.
    #[test]
    fn test_sub() {
        test_circuit(|builder| {
            let a = unconstrained::<BinaryField1b>(builder, "a", 7)?;
            // Previously also named "a"; give the second operand its own name.
            let b = unconstrained::<BinaryField1b>(builder, "b", 7)?;
            let _c = arithmetic::u32::sub(builder, "c", a, b, arithmetic::Flags::Unchecked)?;
            Ok(vec![])
        })
        .unwrap();
    }
}