binius_core/transparent/
step_down.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use binius_field::{BinaryField128b, Field, PackedField};
4use binius_macros::{DeserializeBytes, SerializeBytes, erased_serialize_bytes};
5use binius_math::MultilinearExtension;
6use binius_utils::{DeserializeBytes, bail};
7
8use crate::polynomial::{Error, MultivariatePoly};
9
10/// Represents a multilinear F2-polynomial whose evaluations over the hypercube are
11/// 1 until a specified index where they change to 0.
12///
13/// If the index is the length of the multilinear, then all coefficients are 1.
14///
15/// ```txt
16///     (1 << n_vars)
17/// <-------------------->
18/// 1,1 .. 1,1,0,0, .. 0,0
19///            ^
20///            index of first 0
21/// ```
22///
23/// This is useful for making constraints that are not enforced at the last rows of the trace
24#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)]
25pub struct StepDown {
26	n_vars: usize,
27	index: usize,
28}
29
30inventory::submit! {
31	<dyn MultivariatePoly<BinaryField128b>>::register_deserializer(
32		"StepDown",
33		|buf, mode| Ok(Box::new(StepDown::deserialize(&mut *buf, mode)?))
34	)
35}
36
37impl StepDown {
38	pub fn new(n_vars: usize, index: usize) -> Result<Self, Error> {
39		if index > 1 << n_vars {
40			bail!(Error::ArgumentRangeError {
41				arg: "index".into(),
42				range: 0..(1 << n_vars) + 1,
43			})
44		}
45		Ok(Self { n_vars, index })
46	}
47
48	pub const fn n_vars(&self) -> usize {
49		self.n_vars
50	}
51
52	pub const fn index(&self) -> usize {
53		self.index
54	}
55
56	pub fn multilinear_extension<P: PackedField>(&self) -> Result<MultilinearExtension<P>, Error> {
57		let log_packed_length = self.n_vars.saturating_sub(P::LOG_WIDTH);
58		let mut data = vec![P::zero(); 1 << log_packed_length];
59		self.populate(&mut data);
60		Ok(MultilinearExtension::new(self.n_vars, data)?)
61	}
62
63	pub fn populate<P: PackedField>(&self, data: &mut [P]) {
64		let packed_index = self.index / P::WIDTH;
65		data[..packed_index].fill(P::one());
66		data[packed_index..].fill(P::zero());
67		for i in 0..(self.index % P::WIDTH) {
68			data[packed_index].set(i, P::Scalar::ONE);
69		}
70	}
71}
72
73#[erased_serialize_bytes]
74impl<F: Field> MultivariatePoly<F> for StepDown {
75	fn degree(&self) -> usize {
76		self.n_vars
77	}
78
79	fn n_vars(&self) -> usize {
80		self.n_vars
81	}
82
83	fn evaluate(&self, query: &[F]) -> Result<F, Error> {
84		let n_vars = MultivariatePoly::<F>::n_vars(self);
85		if query.len() != n_vars {
86			bail!(Error::IncorrectQuerySize {
87				expected: n_vars,
88				actual: query.len(),
89			});
90		}
91		let mut k = self.index;
92
93		if k == 1 << n_vars {
94			return Ok(F::ONE);
95		}
96		let mut result = F::ZERO;
97		for q in query {
98			if k & 1 == 1 {
99				// interpolate a line that is 1 at 0 and `result` at 1, at the point q
100				result = (F::ONE - q) + result * q;
101			} else {
102				// interpolate a line that is `result` at 0 and 0 at 1, and evaluate at q
103				result *= F::ONE - q;
104			}
105			k >>= 1;
106		}
107
108		Ok(result)
109	}
110
111	fn binary_tower_level(&self) -> usize {
112		0
113	}
114}
115
#[cfg(test)]
mod tests {
	use binius_field::{
		BinaryField1b, PackedBinaryField128x1b, PackedBinaryField256x1b, PackedField,
	};
	use binius_utils::felts;

	use super::StepDown;
	use crate::polynomial::test_utils::{hypercube_evals_from_oracle, packed_slice};

	/// Builds `StepDown::new(n_vars, index)` and returns its hypercube evaluations packed as `P`.
	fn step_down_evals<P>(n_vars: usize, index: usize) -> Vec<P>
	where
		P: PackedField<Scalar = BinaryField1b>,
	{
		let step_down = StepDown::new(n_vars, index).unwrap();
		step_down.multilinear_extension::<P>().unwrap().evals().to_vec()
	}

	#[test]
	fn test_consistency_between_multilinear_extension_and_multilinear_poly_oracle() {
		for n_vars in 1..6 {
			for index in 0..=(1 << n_vars) {
				let step_down = StepDown::new(n_vars, index).unwrap();
				let from_oracle = hypercube_evals_from_oracle::<BinaryField1b>(&step_down);
				let extension = step_down.multilinear_extension::<BinaryField1b>().unwrap();
				assert_eq!(from_oracle, extension.evals());
			}
		}
	}

	#[test]
	fn test_step_down_trace_without_packing_simple_cases() {
		let two_var_cases = [
			(0, felts!(BinaryField1b[0, 0, 0, 0])),
			(1, felts!(BinaryField1b[1, 0, 0, 0])),
			(2, felts!(BinaryField1b[1, 1, 0, 0])),
			(3, felts!(BinaryField1b[1, 1, 1, 0])),
			(4, felts!(BinaryField1b[1, 1, 1, 1])),
		];
		for (index, expected) in two_var_cases {
			assert_eq!(step_down_evals::<BinaryField1b>(2, index), expected);
		}

		let three_var_cases = [
			(0, felts!(BinaryField1b[0, 0, 0, 0, 0, 0, 0, 0])),
			(1, felts!(BinaryField1b[1, 0, 0, 0, 0, 0, 0, 0])),
			(2, felts!(BinaryField1b[1, 1, 0, 0, 0, 0, 0, 0])),
			(3, felts!(BinaryField1b[1, 1, 1, 0, 0, 0, 0, 0])),
			(4, felts!(BinaryField1b[1, 1, 1, 1, 0, 0, 0, 0])),
			(5, felts!(BinaryField1b[1, 1, 1, 1, 1, 0, 0, 0])),
			(6, felts!(BinaryField1b[1, 1, 1, 1, 1, 1, 0, 0])),
			(7, felts!(BinaryField1b[1, 1, 1, 1, 1, 1, 1, 0])),
			(8, felts!(BinaryField1b[1, 1, 1, 1, 1, 1, 1, 1])),
		];
		for (index, expected) in three_var_cases {
			assert_eq!(step_down_evals::<BinaryField1b>(3, index), expected);
		}
	}

	#[test]
	fn test_step_down_trace_without_packing() {
		type P = BinaryField1b;
		assert_eq!(step_down_evals::<P>(9, 314), packed_slice::<P>(&[(0..314, 1), (314..512, 0)]));
		assert_eq!(
			step_down_evals::<P>(10, 555),
			packed_slice::<P>(&[(0..555, 1), (555..1024, 0)])
		);
		assert_eq!(step_down_evals::<P>(11, 0), packed_slice::<P>(&[(0..2048, 0)]));
		assert_eq!(step_down_evals::<P>(11, 1), packed_slice::<P>(&[(0..1, 1), (1..2048, 0)]));
		assert_eq!(step_down_evals::<P>(11, 2048), packed_slice::<P>(&[(0..2048, 1)]));
	}

	#[test]
	fn test_step_down_trace_with_packing_128() {
		type P = PackedBinaryField128x1b;
		assert_eq!(step_down_evals::<P>(9, 314), packed_slice::<P>(&[(0..314, 1), (314..512, 0)]));
		assert_eq!(
			step_down_evals::<P>(10, 555),
			packed_slice::<P>(&[(0..555, 1), (555..1024, 0)])
		);
		assert_eq!(step_down_evals::<P>(11, 0), packed_slice::<P>(&[(0..2048, 0)]));
		assert_eq!(step_down_evals::<P>(11, 1), packed_slice::<P>(&[(0..1, 1), (1..2048, 0)]));
		assert_eq!(step_down_evals::<P>(11, 2048), packed_slice::<P>(&[(0..2048, 1)]));
	}

	#[test]
	fn test_step_down_trace_with_packing_256() {
		type P = PackedBinaryField256x1b;
		assert_eq!(step_down_evals::<P>(9, 314), packed_slice::<P>(&[(0..314, 1), (314..512, 0)]));
		assert_eq!(
			step_down_evals::<P>(10, 555),
			packed_slice::<P>(&[(0..555, 1), (555..1024, 0)])
		);
		assert_eq!(step_down_evals::<P>(11, 0), packed_slice::<P>(&[(0..2048, 0)]));
		assert_eq!(step_down_evals::<P>(11, 1), packed_slice::<P>(&[(0..1, 1), (1..2048, 0)]));
		assert_eq!(step_down_evals::<P>(11, 2048), packed_slice::<P>(&[(0..2048, 1)]));
	}
}