1use binius_field::{BinaryField128b, Field, PackedField};
4use binius_macros::{DeserializeBytes, SerializeBytes, erased_serialize_bytes};
5use binius_math::MultilinearExtension;
6use binius_utils::{DeserializeBytes, bail};
7
8use crate::polynomial::{Error, MultivariatePoly};
9
10#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)]
25pub struct StepDown {
26 n_vars: usize,
27 index: usize,
28}
29
30inventory::submit! {
31 <dyn MultivariatePoly<BinaryField128b>>::register_deserializer(
32 "StepDown",
33 |buf, mode| Ok(Box::new(StepDown::deserialize(&mut *buf, mode)?))
34 )
35}
36
37impl StepDown {
38 pub fn new(n_vars: usize, index: usize) -> Result<Self, Error> {
39 if index > 1 << n_vars {
40 bail!(Error::ArgumentRangeError {
41 arg: "index".into(),
42 range: 0..(1 << n_vars) + 1,
43 })
44 }
45 Ok(Self { n_vars, index })
46 }
47
48 pub const fn n_vars(&self) -> usize {
49 self.n_vars
50 }
51
52 pub const fn index(&self) -> usize {
53 self.index
54 }
55
56 pub fn multilinear_extension<P: PackedField>(&self) -> Result<MultilinearExtension<P>, Error> {
57 let log_packed_length = self.n_vars.saturating_sub(P::LOG_WIDTH);
58 let mut data = vec![P::zero(); 1 << log_packed_length];
59 self.populate(&mut data);
60 Ok(MultilinearExtension::new(self.n_vars, data)?)
61 }
62
63 pub fn populate<P: PackedField>(&self, data: &mut [P]) {
64 let packed_index = self.index / P::WIDTH;
65 data[..packed_index].fill(P::one());
66 data[packed_index..].fill(P::zero());
67 for i in 0..(self.index % P::WIDTH) {
68 data[packed_index].set(i, P::Scalar::ONE);
69 }
70 }
71}
72
73#[erased_serialize_bytes]
74impl<F: Field> MultivariatePoly<F> for StepDown {
75 fn degree(&self) -> usize {
76 self.n_vars
77 }
78
79 fn n_vars(&self) -> usize {
80 self.n_vars
81 }
82
83 fn evaluate(&self, query: &[F]) -> Result<F, Error> {
84 let n_vars = MultivariatePoly::<F>::n_vars(self);
85 if query.len() != n_vars {
86 bail!(Error::IncorrectQuerySize {
87 expected: n_vars,
88 actual: query.len(),
89 });
90 }
91 let mut k = self.index;
92
93 if k == 1 << n_vars {
94 return Ok(F::ONE);
95 }
96 let mut result = F::ZERO;
97 for q in query {
98 if k & 1 == 1 {
99 result = (F::ONE - q) + result * q;
101 } else {
102 result *= F::ONE - q;
104 }
105 k >>= 1;
106 }
107
108 Ok(result)
109 }
110
111 fn binary_tower_level(&self) -> usize {
112 0
113 }
114}
115
#[cfg(test)]
mod tests {
	use binius_field::{
		BinaryField1b, PackedBinaryField128x1b, PackedBinaryField256x1b, PackedField,
	};
	use binius_utils::felts;

	use super::StepDown;
	use crate::polynomial::test_utils::{hypercube_evals_from_oracle, packed_slice};

	// Checks the shared (n_vars, index) trace cases for the packed type `$p`, comparing the
	// materialized evaluations against explicit `packed_slice` layouts.
	macro_rules! check_common_traces {
		($p:ty) => {
			assert_eq!(
				stepdown_evals::<$p>(9, 314),
				packed_slice::<$p>(&[(0..314, 1), (314..512, 0)])
			);
			assert_eq!(
				stepdown_evals::<$p>(10, 555),
				packed_slice::<$p>(&[(0..555, 1), (555..1024, 0)])
			);
			assert_eq!(stepdown_evals::<$p>(11, 0), packed_slice::<$p>(&[(0..2048, 0)]));
			assert_eq!(
				stepdown_evals::<$p>(11, 1),
				packed_slice::<$p>(&[(0..1, 1), (1..2048, 0)])
			);
			assert_eq!(stepdown_evals::<$p>(11, 2048), packed_slice::<$p>(&[(0..2048, 1)]));
		};
	}

	/// Builds `StepDown::new(n_vars, index)` and returns its hypercube evaluations packed into
	/// `P`.
	fn stepdown_evals<P>(n_vars: usize, index: usize) -> Vec<P>
	where
		P: PackedField<Scalar = BinaryField1b>,
	{
		let step_down = StepDown::new(n_vars, index).unwrap();
		let extension = step_down.multilinear_extension::<P>().unwrap();
		extension.evals().to_vec()
	}

	#[test]
	fn test_step_down_trace_without_packing_simple_cases() {
		let evals = |n_vars, index| stepdown_evals::<BinaryField1b>(n_vars, index);

		// Two variables: every index from "all zeros" up to "all ones".
		assert_eq!(evals(2, 0), felts!(BinaryField1b[0, 0, 0, 0]));
		assert_eq!(evals(2, 1), felts!(BinaryField1b[1, 0, 0, 0]));
		assert_eq!(evals(2, 2), felts!(BinaryField1b[1, 1, 0, 0]));
		assert_eq!(evals(2, 3), felts!(BinaryField1b[1, 1, 1, 0]));
		assert_eq!(evals(2, 4), felts!(BinaryField1b[1, 1, 1, 1]));

		// Three variables.
		assert_eq!(evals(3, 0), felts!(BinaryField1b[0, 0, 0, 0, 0, 0, 0, 0]));
		assert_eq!(evals(3, 1), felts!(BinaryField1b[1, 0, 0, 0, 0, 0, 0, 0]));
		assert_eq!(evals(3, 2), felts!(BinaryField1b[1, 1, 0, 0, 0, 0, 0, 0]));
		assert_eq!(evals(3, 3), felts!(BinaryField1b[1, 1, 1, 0, 0, 0, 0, 0]));
		assert_eq!(evals(3, 4), felts!(BinaryField1b[1, 1, 1, 1, 0, 0, 0, 0]));
		assert_eq!(evals(3, 5), felts!(BinaryField1b[1, 1, 1, 1, 1, 0, 0, 0]));
		assert_eq!(evals(3, 6), felts!(BinaryField1b[1, 1, 1, 1, 1, 1, 0, 0]));
		assert_eq!(evals(3, 7), felts!(BinaryField1b[1, 1, 1, 1, 1, 1, 1, 0]));
		assert_eq!(evals(3, 8), felts!(BinaryField1b[1, 1, 1, 1, 1, 1, 1, 1]));
	}

	#[test]
	fn test_step_down_trace_without_packing() {
		check_common_traces!(BinaryField1b);
	}

	#[test]
	fn test_step_down_trace_with_packing_128() {
		check_common_traces!(PackedBinaryField128x1b);
	}

	#[test]
	fn test_step_down_trace_with_packing_256() {
		check_common_traces!(PackedBinaryField256x1b);
	}

	#[test]
	fn test_consistency_between_multilinear_extension_and_multilinear_poly_oracle() {
		let cases = (1..6usize)
			.flat_map(|n_vars| (0..=(1usize << n_vars)).map(move |index| (n_vars, index)));
		for (n_vars, index) in cases {
			let step_down = StepDown::new(n_vars, index).unwrap();
			let from_oracle = hypercube_evals_from_oracle::<BinaryField1b>(&step_down);
			let from_extension = step_down.multilinear_extension::<BinaryField1b>().unwrap();
			assert_eq!(from_oracle, from_extension.evals());
		}
	}
}