1use binius_core::{
4 constraint_system::channel::{Boundary, ChannelId, FlushDirection},
5 oracle::OracleId,
6};
7use binius_field::{
8 as_packed_field::PackScalar, BinaryField1b, BinaryField32b, ExtensionField, TowerField,
9};
10use binius_macros::arith_expr;
11use bytemuck::Pod;
12
13use crate::{
14 arithmetic,
15 builder::{
16 types::{F, U},
17 ConstraintSystemBuilder,
18 },
19 transparent,
20};
21
/// Prover-side advice for [`Collatz::build`]: `(evens_count, odds_count)`, the number of
/// even and odd elements in the orbit, as returned by [`Collatz::init_prover`].
pub type Advice = (usize, usize);

/// Circuit gadget proving that the Collatz orbit starting at `x0` reaches 1.
#[derive(Debug, Clone)]
pub struct Collatz {
    /// Starting value of the orbit.
    x0: u32,
    /// Even orbit elements (prover only; empty until `init_prover` runs).
    evens: Vec<u32>,
    /// Odd orbit elements (prover only; empty until `init_prover` runs).
    odds: Vec<u32>,
}
29
30impl Collatz {
31 pub const fn new(x0: u32) -> Self {
32 Self {
33 x0,
34 evens: vec![],
35 odds: vec![],
36 }
37 }
38
39 pub fn init_prover(&mut self) -> Advice {
40 let (evens, odds) = collatz_orbit(self.x0).into_iter().partition(|x| x % 2 == 0);
41 self.evens = evens;
42 self.odds = odds;
43
44 (self.evens.len(), self.odds.len())
45 }
46
47 pub fn build(
48 self,
49 builder: &mut ConstraintSystemBuilder,
50 advice: Advice,
51 ) -> Result<Vec<Boundary<F>>, anyhow::Error>
52 where
53 U: PackScalar<F> + PackScalar<BinaryField1b> + PackScalar<BinaryField32b> + Pod,
54 F: TowerField + ExtensionField<BinaryField32b>,
55 {
56 let (evens_count, odds_count) = advice;
57
58 let channel = builder.add_channel();
59
60 self.even(builder, channel, evens_count)?;
61 self.odd(builder, channel, odds_count)?;
62
63 let boundaries = self.get_boundaries(channel);
64
65 Ok(boundaries)
66 }
67
68 fn even(
69 &self,
70 builder: &mut ConstraintSystemBuilder,
71 channel: ChannelId,
72 count: usize,
73 ) -> Result<(), anyhow::Error> {
74 let log_1b_rows = 5 + binius_utils::checked_arithmetics::log2_ceil_usize(count);
75 let even = builder.add_committed("even", log_1b_rows, BinaryField1b::TOWER_LEVEL);
76 if let Some(witness) = builder.witness() {
77 debug_assert_eq!(count, self.evens.len());
78 witness
79 .new_column::<BinaryField1b>(even)
80 .as_mut_slice::<u32>()[..count]
81 .copy_from_slice(&self.evens);
82 }
83
84 let half = arithmetic::u32::half(builder, "half", even, arithmetic::Flags::Checked)?;
86
87 let even_packed = arithmetic::u32::packed(builder, "even_packed", even)?;
88 builder.receive(channel, count, [even_packed])?;
89
90 let half_packed = arithmetic::u32::packed(builder, "half_packed", half)?;
91 builder.send(channel, count, [half_packed])?;
92
93 Ok(())
94 }
95
96 fn odd(
97 &self,
98 builder: &mut ConstraintSystemBuilder,
99 channel: ChannelId,
100 count: usize,
101 ) -> Result<(), anyhow::Error> {
102 let log_32b_rows = binius_utils::checked_arithmetics::log2_ceil_usize(count);
103 let log_1b_rows = 5 + log_32b_rows;
104
105 let odd = builder.add_committed("odd", log_1b_rows, BinaryField1b::TOWER_LEVEL);
106 if let Some(witness) = builder.witness() {
107 debug_assert_eq!(count, self.odds.len());
108 witness
109 .new_column::<BinaryField1b>(odd)
110 .as_mut_slice::<u32>()[..count]
111 .copy_from_slice(&self.odds);
112 }
113
114 ensure_odd(builder, odd, count)?;
116
117 let one = arithmetic::u32::constant(builder, "one", log_32b_rows, 1)?;
118 let triple =
119 arithmetic::u32::mul_const(builder, "triple", odd, 3, arithmetic::Flags::Checked)?;
120 let triple_plus_one = arithmetic::u32::add(
121 builder,
122 "triple_plus_one",
123 triple,
124 one,
125 arithmetic::Flags::Checked,
126 )?;
127
128 let odd_packed = arithmetic::u32::packed(builder, "odd_packed", odd)?;
129 builder.receive(channel, count, [odd_packed])?;
130
131 let triple_plus_one_packed =
132 arithmetic::u32::packed(builder, "triple_plus_one_packed", triple_plus_one)?;
133 builder.send(channel, count, [triple_plus_one_packed])?;
134
135 Ok(())
136 }
137
138 fn get_boundaries(&self, channel_id: usize) -> Vec<Boundary<F>> {
139 vec![
140 Boundary {
141 channel_id,
142 direction: FlushDirection::Push,
143 values: vec![BinaryField32b::new(self.x0).into()],
144 multiplicity: 1,
145 },
146 Boundary {
147 channel_id,
148 direction: FlushDirection::Pull,
149 values: vec![BinaryField32b::new(1).into()],
150 multiplicity: 1,
151 },
152 ]
153 }
154}
155
/// Returns the Collatz orbit of `x0`, excluding the terminal `1`.
///
/// The first element is `x0` itself. Each following element is `x / 2` for even `x` and
/// `3x + 1` for odd `x`. For `x0 == 1` the result is empty.
///
/// # Panics
///
/// - Panics if `x0 == 0`, whose orbit (`0, 0, ...`) never reaches 1.
/// - Panics if some `3x + 1` step overflows `u32`. Without this check, release builds would
///   wrap silently and produce a wrong orbit.
pub fn collatz_orbit(x0: u32) -> Vec<u32> {
    assert_ne!(x0, 0, "the Collatz orbit of 0 never reaches 1");

    let mut orbit = Vec::new();
    let mut x = x0;
    while x != 1 {
        orbit.push(x);
        x = if x % 2 == 0 {
            x / 2
        } else {
            x.checked_mul(3)
                .and_then(|t| t.checked_add(1))
                .unwrap_or_else(|| panic!("Collatz orbit of {} overflows u32 at 3 * {} + 1", x0, x))
        };
    }
    orbit
}
177
178pub fn ensure_odd(
179 builder: &mut ConstraintSystemBuilder,
180 input: OracleId,
181 count: usize,
182) -> Result<(), anyhow::Error> {
183 let log_32b_rows = builder.log_rows([input])? - 5;
184 let lsb = arithmetic::u32::select_bit(builder, "lsb", input, 0)?;
185 let selector = transparent::step_down(builder, "count", log_32b_rows, count)?;
186 builder.assert_zero(
187 "is_odd",
188 [lsb, selector],
189 arith_expr!([lsb, selector] = selector * (lsb + 1)).convert_field(),
190 );
191 Ok(())
192}
193
#[cfg(test)]
mod tests {
    use crate::{builder::test_utils::test_circuit, collatz::Collatz};

    /// Runs the Collatz gadget for a fixed starting value through `test_circuit`.
    #[test]
    fn test_collatz() {
        const START: u32 = 9999999;

        let outcome = test_circuit(|builder| {
            let mut gadget = Collatz::new(START);
            let counts = gadget.init_prover();
            let boundaries = gadget.build(builder, counts)?;
            Ok(boundaries)
        });
        outcome.unwrap();
    }
}
209}