binius_circuits/
collatz.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use binius_core::{
4	constraint_system::channel::{Boundary, ChannelId, FlushDirection},
5	oracle::OracleId,
6};
7use binius_field::{
8	as_packed_field::PackScalar, BinaryField1b, BinaryField32b, ExtensionField, TowerField,
9};
10use binius_macros::arith_expr;
11use bytemuck::Pod;
12
13use crate::{
14	arithmetic,
15	builder::{
16		types::{F, U},
17		ConstraintSystemBuilder,
18	},
19	transparent,
20};
21
/// Prover advice for [`Collatz::build`]: `(evens_count, odds_count)`, the number of
/// even and odd values in the orbit of `x0` (the final `1` is not counted).
///
/// Produced by [`Collatz::init_prover`]; a verifier receives it from the prover.
pub type Advice = (usize, usize);
23
/// Collatz-orbit circuit for the starting value `x0`.
///
/// The circuit pushes `x0` into a channel, applies the even step (`x / 2`) and odd
/// step (`3x + 1`) to every orbit value, and pulls the terminal `1`.
///
/// `evens` and `odds` hold the prover's orbit values split by parity. They are empty
/// until [`Collatz::init_prover`] is called.
#[derive(Debug, Clone)]
pub struct Collatz {
	/// Starting value of the orbit.
	x0: u32,
	/// Even values of the orbit (excluding the final `1`), in orbit order.
	evens: Vec<u32>,
	/// Odd values of the orbit (excluding the final `1`), in orbit order.
	odds: Vec<u32>,
}
29
30impl Collatz {
31	pub const fn new(x0: u32) -> Self {
32		Self {
33			x0,
34			evens: vec![],
35			odds: vec![],
36		}
37	}
38
39	pub fn init_prover(&mut self) -> Advice {
40		let (evens, odds) = collatz_orbit(self.x0).into_iter().partition(|x| x % 2 == 0);
41		self.evens = evens;
42		self.odds = odds;
43
44		(self.evens.len(), self.odds.len())
45	}
46
47	pub fn build(
48		self,
49		builder: &mut ConstraintSystemBuilder,
50		advice: Advice,
51	) -> Result<Vec<Boundary<F>>, anyhow::Error>
52	where
53		U: PackScalar<F> + PackScalar<BinaryField1b> + PackScalar<BinaryField32b> + Pod,
54		F: TowerField + ExtensionField<BinaryField32b>,
55	{
56		let (evens_count, odds_count) = advice;
57
58		let channel = builder.add_channel();
59
60		self.even(builder, channel, evens_count)?;
61		self.odd(builder, channel, odds_count)?;
62
63		let boundaries = self.get_boundaries(channel);
64
65		Ok(boundaries)
66	}
67
68	fn even(
69		&self,
70		builder: &mut ConstraintSystemBuilder,
71		channel: ChannelId,
72		count: usize,
73	) -> Result<(), anyhow::Error> {
74		let log_1b_rows = 5 + binius_utils::checked_arithmetics::log2_ceil_usize(count);
75		let even = builder.add_committed("even", log_1b_rows, BinaryField1b::TOWER_LEVEL);
76		if let Some(witness) = builder.witness() {
77			debug_assert_eq!(count, self.evens.len());
78			witness
79				.new_column::<BinaryField1b>(even)
80				.as_mut_slice::<u32>()[..count]
81				.copy_from_slice(&self.evens);
82		}
83
84		// Passing Checked flag here makes sure the number is actually even
85		let half = arithmetic::u32::half(builder, "half", even, arithmetic::Flags::Checked)?;
86
87		let even_packed = arithmetic::u32::packed(builder, "even_packed", even)?;
88		builder.receive(channel, count, [even_packed])?;
89
90		let half_packed = arithmetic::u32::packed(builder, "half_packed", half)?;
91		builder.send(channel, count, [half_packed])?;
92
93		Ok(())
94	}
95
96	fn odd(
97		&self,
98		builder: &mut ConstraintSystemBuilder,
99		channel: ChannelId,
100		count: usize,
101	) -> Result<(), anyhow::Error> {
102		let log_32b_rows = binius_utils::checked_arithmetics::log2_ceil_usize(count);
103		let log_1b_rows = 5 + log_32b_rows;
104
105		let odd = builder.add_committed("odd", log_1b_rows, BinaryField1b::TOWER_LEVEL);
106		if let Some(witness) = builder.witness() {
107			debug_assert_eq!(count, self.odds.len());
108			witness
109				.new_column::<BinaryField1b>(odd)
110				.as_mut_slice::<u32>()[..count]
111				.copy_from_slice(&self.odds);
112		}
113
114		// Ensure the number is odd
115		ensure_odd(builder, odd, count)?;
116
117		let one = arithmetic::u32::constant(builder, "one", log_32b_rows, 1)?;
118		let triple =
119			arithmetic::u32::mul_const(builder, "triple", odd, 3, arithmetic::Flags::Checked)?;
120		let triple_plus_one = arithmetic::u32::add(
121			builder,
122			"triple_plus_one",
123			triple,
124			one,
125			arithmetic::Flags::Checked,
126		)?;
127
128		let odd_packed = arithmetic::u32::packed(builder, "odd_packed", odd)?;
129		builder.receive(channel, count, [odd_packed])?;
130
131		let triple_plus_one_packed =
132			arithmetic::u32::packed(builder, "triple_plus_one_packed", triple_plus_one)?;
133		builder.send(channel, count, [triple_plus_one_packed])?;
134
135		Ok(())
136	}
137
138	fn get_boundaries(&self, channel_id: usize) -> Vec<Boundary<F>> {
139		vec![
140			Boundary {
141				channel_id,
142				direction: FlushDirection::Push,
143				values: vec![BinaryField32b::new(self.x0).into()],
144				multiplicity: 1,
145			},
146			Boundary {
147				channel_id,
148				direction: FlushDirection::Pull,
149				values: vec![BinaryField32b::new(1).into()],
150				multiplicity: 1,
151			},
152		]
153	}
154}
155
/// Returns the Collatz orbit of `x0`, stopping before the terminal `1`.
///
/// The orbit is the sequence `x0, f(x0), f(f(x0)), ...`, where `f(x) = x / 2` for
/// even `x` and `f(x) = 3x + 1` for odd `x`. Because the final `1` is excluded, the
/// orbit of `1` is empty.
///
/// ```
/// assert_eq!(
///     binius_circuits::collatz::collatz_orbit(5),
///     vec![5, 16, 8, 4, 2]
/// )
/// ```
///
/// # Panics
///
/// Panics in two cases:
/// - `x0 == 0`, whose orbit never reaches `1`. Previously this looped forever
///   while growing the result without bound.
/// - A `3x + 1` step does not fit in a `u32`. Previously this panicked only in
///   debug builds and silently wrapped in release builds.
pub fn collatz_orbit(x0: u32) -> Vec<u32> {
	assert_ne!(x0, 0, "the Collatz orbit of 0 never reaches 1");

	let mut orbit = Vec::new();
	let mut x = x0;
	while x != 1 {
		orbit.push(x);
		x = if x % 2 == 0 {
			x / 2
		} else {
			x.checked_mul(3)
				.and_then(|t| t.checked_add(1))
				.unwrap_or_else(|| panic!("Collatz orbit of {} overflows u32 at {}", x0, x))
		};
	}
	orbit
}
177
178pub fn ensure_odd(
179	builder: &mut ConstraintSystemBuilder,
180	input: OracleId,
181	count: usize,
182) -> Result<(), anyhow::Error> {
183	let log_32b_rows = builder.log_rows([input])? - 5;
184	let lsb = arithmetic::u32::select_bit(builder, "lsb", input, 0)?;
185	let selector = transparent::step_down(builder, "count", log_32b_rows, count)?;
186	builder.assert_zero(
187		"is_odd",
188		[lsb, selector],
189		arith_expr!([lsb, selector] = selector * (lsb + 1)).convert_field(),
190	);
191	Ok(())
192}
193
#[cfg(test)]
mod tests {
	use crate::{builder::test_utils::test_circuit, collatz::Collatz};

	/// Builds the Collatz circuit for the starting value 9 999 999 inside
	/// `test_circuit` and unwraps the outcome.
	#[test]
	fn test_collatz() {
		let outcome = test_circuit(|builder| {
			let mut circuit = Collatz::new(9_999_999);
			let advice = circuit.init_prover();
			let boundaries = circuit.build(builder, advice)?;
			Ok(boundaries)
		});
		outcome.unwrap();
	}
}