1use binius_field::{BinaryField128b, Field, PackedField};
4use binius_macros::{erased_serialize_bytes, DeserializeBytes, SerializeBytes};
5use binius_math::MultilinearExtension;
6use binius_utils::{bail, DeserializeBytes};
7
8use crate::polynomial::{Error, MultivariatePoly};
9
/// A multilinear "step-down" polynomial over `n_vars` variables.
///
/// Viewing a hypercube point as an integer `x` whose least-significant bit is the first query
/// coordinate, the polynomial evaluates to one when `x < index` and to zero otherwise. Its
/// hypercube evaluations are therefore `index` ones followed by zeros.
#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)]
pub struct StepDown {
    /// Number of variables.
    n_vars: usize,
    /// Number of leading hypercube points that evaluate to one. `StepDown::new` enforces
    /// `index <= 2^n_vars`.
    ///
    /// NOTE(review): the derived `DeserializeBytes` impl does not appear to re-check that bound,
    /// and field order presumably defines the serialized layout; confirm before relying on the
    /// bound for untrusted input or reordering fields.
    index: usize,
}
29
30inventory::submit! {
31 <dyn MultivariatePoly<BinaryField128b>>::register_deserializer(
32 "StepDown",
33 |buf, mode| Ok(Box::new(StepDown::deserialize(&mut *buf, mode)?))
34 )
35}
36
37impl StepDown {
38 pub fn new(n_vars: usize, index: usize) -> Result<Self, Error> {
39 if index > 1 << n_vars {
40 bail!(Error::ArgumentRangeError {
41 arg: "index".into(),
42 range: 0..(1 << n_vars) + 1,
43 })
44 }
45 Ok(Self { n_vars, index })
46 }
47
48 pub const fn n_vars(&self) -> usize {
49 self.n_vars
50 }
51
52 pub const fn index(&self) -> usize {
53 self.index
54 }
55
56 pub fn multilinear_extension<P: PackedField>(&self) -> Result<MultilinearExtension<P>, Error> {
57 let log_packed_length = self.n_vars.saturating_sub(P::LOG_WIDTH);
58 let mut data = vec![P::zero(); 1 << log_packed_length];
59 self.populate(&mut data);
60 Ok(MultilinearExtension::new(self.n_vars, data)?)
61 }
62
63 pub fn populate<P: PackedField>(&self, data: &mut [P]) {
64 let packed_index = self.index / P::WIDTH;
65 data[..packed_index].fill(P::one());
66 data[packed_index..].fill(P::zero());
67 for i in 0..(self.index % P::WIDTH) {
68 data[packed_index].set(i, P::Scalar::ONE);
69 }
70 }
71}
72
73#[erased_serialize_bytes]
74impl<F: Field> MultivariatePoly<F> for StepDown {
75 fn degree(&self) -> usize {
76 self.n_vars
77 }
78
79 fn n_vars(&self) -> usize {
80 self.n_vars
81 }
82
83 fn evaluate(&self, query: &[F]) -> Result<F, Error> {
84 let n_vars = MultivariatePoly::<F>::n_vars(self);
85 if query.len() != n_vars {
86 bail!(Error::IncorrectQuerySize { expected: n_vars });
87 }
88 let mut k = self.index;
89
90 if k == 1 << n_vars {
91 return Ok(F::ONE);
92 }
93 let mut result = F::ZERO;
94 for q in query {
95 if k & 1 == 1 {
96 result = (F::ONE - q) + result * q;
98 } else {
99 result *= F::ONE - q;
101 }
102 k >>= 1;
103 }
104
105 Ok(result)
106 }
107
108 fn binary_tower_level(&self) -> usize {
109 0
110 }
111}
112
#[cfg(test)]
mod tests {
    use binius_field::{
        BinaryField1b, PackedBinaryField128x1b, PackedBinaryField256x1b, PackedField,
    };
    use binius_utils::felts;

    use super::StepDown;
    use crate::polynomial::test_utils::{hypercube_evals_from_oracle, packed_slice};

    // Generates a test named `$test_name`. It checks the step-down trace for a fixed set of
    // `(n_vars, index)` pairs when the evaluations are packed into `$packed`.
    macro_rules! step_down_trace_test {
        ($test_name:ident, $packed:ty) => {
            #[test]
            fn $test_name() {
                assert_eq!(
                    stepdown_evals::<$packed>(9, 314),
                    packed_slice::<$packed>(&[(0..314, 1), (314..512, 0)])
                );
                assert_eq!(
                    stepdown_evals::<$packed>(10, 555),
                    packed_slice::<$packed>(&[(0..555, 1), (555..1024, 0)])
                );
                assert_eq!(
                    stepdown_evals::<$packed>(11, 0),
                    packed_slice::<$packed>(&[(0..2048, 0)])
                );
                assert_eq!(
                    stepdown_evals::<$packed>(11, 1),
                    packed_slice::<$packed>(&[(0..1, 1), (1..2048, 0)])
                );
                assert_eq!(
                    stepdown_evals::<$packed>(11, 2048),
                    packed_slice::<$packed>(&[(0..2048, 1)])
                );
            }
        };
    }

    #[test]
    fn test_step_down_trace_without_packing_simple_cases() {
        fn check(n_vars: usize, index: usize, expected: &[BinaryField1b]) {
            assert_eq!(stepdown_evals::<BinaryField1b>(n_vars, index), expected);
        }

        // Two variables: every possible step position.
        check(2, 0, &felts!(BinaryField1b[0, 0, 0, 0]));
        check(2, 1, &felts!(BinaryField1b[1, 0, 0, 0]));
        check(2, 2, &felts!(BinaryField1b[1, 1, 0, 0]));
        check(2, 3, &felts!(BinaryField1b[1, 1, 1, 0]));
        check(2, 4, &felts!(BinaryField1b[1, 1, 1, 1]));

        // Three variables: every possible step position.
        check(3, 0, &felts!(BinaryField1b[0, 0, 0, 0, 0, 0, 0, 0]));
        check(3, 1, &felts!(BinaryField1b[1, 0, 0, 0, 0, 0, 0, 0]));
        check(3, 2, &felts!(BinaryField1b[1, 1, 0, 0, 0, 0, 0, 0]));
        check(3, 3, &felts!(BinaryField1b[1, 1, 1, 0, 0, 0, 0, 0]));
        check(3, 4, &felts!(BinaryField1b[1, 1, 1, 1, 0, 0, 0, 0]));
        check(3, 5, &felts!(BinaryField1b[1, 1, 1, 1, 1, 0, 0, 0]));
        check(3, 6, &felts!(BinaryField1b[1, 1, 1, 1, 1, 1, 0, 0]));
        check(3, 7, &felts!(BinaryField1b[1, 1, 1, 1, 1, 1, 1, 0]));
        check(3, 8, &felts!(BinaryField1b[1, 1, 1, 1, 1, 1, 1, 1]));
    }

    step_down_trace_test!(test_step_down_trace_without_packing, BinaryField1b);
    step_down_trace_test!(test_step_down_trace_with_packing_128, PackedBinaryField128x1b);
    step_down_trace_test!(test_step_down_trace_with_packing_256, PackedBinaryField256x1b);

    #[test]
    fn test_consistency_between_multilinear_extension_and_multilinear_poly_oracle() {
        let cases = (1..6usize)
            .flat_map(|n_vars| (0..=(1usize << n_vars)).map(move |index| (n_vars, index)));
        for (n_vars, index) in cases {
            let step_down = StepDown::new(n_vars, index).unwrap();
            let mle = step_down.multilinear_extension::<BinaryField1b>().unwrap();
            assert_eq!(hypercube_evals_from_oracle::<BinaryField1b>(&step_down), mle.evals());
        }
    }

    // Builds a `StepDown` and returns its packed hypercube evaluations as an owned vector.
    fn stepdown_evals<P>(n_vars: usize, index: usize) -> Vec<P>
    where
        P: PackedField<Scalar = BinaryField1b>,
    {
        let step_down = StepDown::new(n_vars, index).unwrap();
        let mle = step_down.multilinear_extension::<P>().unwrap();
        mle.evals().to_vec()
    }
}