1use binius_field::{BinaryField128b, Field, PackedField};
4use binius_macros::{erased_serialize_bytes, DeserializeBytes, SerializeBytes};
5use binius_math::MultilinearExtension;
6use binius_utils::{bail, DeserializeBytes};
7
8use crate::polynomial::{Error, MultivariatePoly};
9
/// The "step-up" multilinear polynomial over `n_vars` variables.
///
/// On the boolean hypercube it evaluates to 0 at every point whose integer index is below
/// `index` and to 1 at every other point. The first variable is the least significant bit of
/// that index. Instances built by [`StepUp::new`] satisfy `index <= 2^n_vars`, and
/// `index == 2^n_vars` yields the zero polynomial.
///
/// NOTE(review): the derived `DeserializeBytes` presumably builds the struct directly rather
/// than through `StepUp::new`, so the `index <= 2^n_vars` check is likely not enforced on
/// deserialized values. Confirm whether callers validate them.
#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)]
pub struct StepUp {
	// Number of boolean variables.
	n_vars: usize,
	// First hypercube index at which the polynomial evaluates to 1.
	index: usize,
}
29
// Registers a deserializer for `dyn MultivariatePoly<BinaryField128b>` under the tag "StepUp".
// The closure decodes the bytes with `StepUp::deserialize` (not `StepUp::new`, so the
// constructor's range checks are not re-run here) and boxes the result.
// NOTE(review): the tag presumably has to match the name emitted by the
// `#[erased_serialize_bytes]` impl below. Confirm before renaming either side.
inventory::submit! {
	<dyn MultivariatePoly<BinaryField128b>>::register_deserializer(
		"StepUp",
		|buf, mode| Ok(Box::new(StepUp::deserialize(&mut *buf, mode)?))
	)
}
36
37impl StepUp {
38 pub fn new(n_vars: usize, index: usize) -> Result<Self, Error> {
39 if index > 1 << n_vars {
40 bail!(Error::ArgumentRangeError {
41 arg: "index".into(),
42 range: 0..(1 << n_vars) + 1,
43 })
44 }
45 Ok(Self { n_vars, index })
46 }
47
48 pub const fn n_vars(&self) -> usize {
49 self.n_vars
50 }
51
52 pub fn multilinear_extension<P: PackedField>(&self) -> Result<MultilinearExtension<P>, Error> {
53 let log_packed_length = self.n_vars.saturating_sub(P::LOG_WIDTH);
54 let mut data = vec![P::one(); 1 << log_packed_length];
55 self.populate(&mut data);
56 Ok(MultilinearExtension::new(self.n_vars, data)?)
57 }
58
59 pub fn populate<P: PackedField>(&self, data: &mut [P]) {
60 let packed_index = self.index / P::WIDTH;
61 data[..packed_index].fill(P::zero());
62 data[packed_index..].fill(P::one());
63 for i in 0..(self.index % P::WIDTH) {
64 data[packed_index].set(i, P::Scalar::ZERO);
65 }
66 }
67}
68
69#[erased_serialize_bytes]
70impl<F: Field> MultivariatePoly<F> for StepUp {
71 fn degree(&self) -> usize {
72 self.n_vars
73 }
74
75 fn n_vars(&self) -> usize {
76 self.n_vars
77 }
78
79 fn evaluate(&self, query: &[F]) -> Result<F, Error> {
80 let n_vars = MultivariatePoly::<F>::n_vars(self);
81 if query.len() != n_vars {
82 bail!(Error::IncorrectQuerySize { expected: n_vars });
83 }
84 let mut k = self.index;
85
86 if k == 1 << n_vars {
87 return Ok(F::ZERO);
88 }
89
90 let mut result = F::ONE;
91 for q in query {
92 if k & 1 == 1 {
93 result *= q;
95 } else {
96 result = result * (F::ONE - q) + q;
98 }
99 k >>= 1;
100 }
101
102 Ok(result)
103 }
104
105 fn binary_tower_level(&self) -> usize {
106 0
107 }
108}
109
#[cfg(test)]
mod tests {
	use binius_field::{
		BinaryField1b, PackedBinaryField128x1b, PackedBinaryField256x1b, PackedField,
	};
	use binius_utils::felts;

	use super::StepUp;
	use crate::polynomial::test_utils::{hypercube_evals_from_oracle, packed_slice};

	/// Asserts that the unpacked step-up trace for `(n_vars, index)` equals the listed bits.
	macro_rules! assert_step_up_bits {
		($n_vars:expr, $index:expr, [$($bit:tt),* $(,)?]) => {
			assert_eq!(
				stepup_evals::<BinaryField1b>($n_vars, $index),
				felts!(BinaryField1b[$($bit),*])
			);
		};
	}

	/// Checks the shared set of larger step-up traces for one (possibly packed) field type.
	macro_rules! assert_step_up_cases_for {
		($packed:ty) => {
			assert_eq!(
				stepup_evals::<$packed>(9, 314),
				packed_slice::<$packed>(&[(0..314, 0), (314..512, 1)])
			);
			assert_eq!(
				stepup_evals::<$packed>(10, 555),
				packed_slice::<$packed>(&[(0..555, 0), (555..1024, 1)])
			);
			assert_eq!(stepup_evals::<$packed>(11, 0), packed_slice::<$packed>(&[(0..2048, 1)]));
			assert_eq!(
				stepup_evals::<$packed>(11, 1),
				packed_slice::<$packed>(&[(0..1, 0), (1..2048, 1)])
			);
			assert_eq!(
				stepup_evals::<$packed>(11, 2048),
				packed_slice::<$packed>(&[(0..2048, 0)])
			);
		};
	}

	#[test]
	fn test_step_up_trace_without_packing_simple_cases() {
		assert_step_up_bits!(2, 0, [1, 1, 1, 1]);
		assert_step_up_bits!(2, 1, [0, 1, 1, 1]);
		assert_step_up_bits!(2, 2, [0, 0, 1, 1]);
		assert_step_up_bits!(2, 3, [0, 0, 0, 1]);
		assert_step_up_bits!(2, 4, [0, 0, 0, 0]);
		assert_step_up_bits!(3, 0, [1, 1, 1, 1, 1, 1, 1, 1]);
		assert_step_up_bits!(3, 1, [0, 1, 1, 1, 1, 1, 1, 1]);
		assert_step_up_bits!(3, 2, [0, 0, 1, 1, 1, 1, 1, 1]);
		assert_step_up_bits!(3, 3, [0, 0, 0, 1, 1, 1, 1, 1]);
		assert_step_up_bits!(3, 4, [0, 0, 0, 0, 1, 1, 1, 1]);
		assert_step_up_bits!(3, 5, [0, 0, 0, 0, 0, 1, 1, 1]);
		assert_step_up_bits!(3, 6, [0, 0, 0, 0, 0, 0, 1, 1]);
		assert_step_up_bits!(3, 7, [0, 0, 0, 0, 0, 0, 0, 1]);
		assert_step_up_bits!(3, 8, [0, 0, 0, 0, 0, 0, 0, 0]);
	}

	#[test]
	fn test_step_up_trace_without_packing() {
		assert_step_up_cases_for!(BinaryField1b);
	}

	#[test]
	fn test_step_up_trace_with_packing_128() {
		assert_step_up_cases_for!(PackedBinaryField128x1b);
	}

	#[test]
	fn test_step_up_trace_with_packing_256() {
		assert_step_up_cases_for!(PackedBinaryField256x1b);
	}

	#[test]
	fn test_consistency_between_multilinear_extension_and_multilinear_poly_oracle() {
		// Every (n_vars, index) pair with 1 <= n_vars < 6 and 0 <= index <= 2^n_vars.
		let cases = (1usize..6)
			.flat_map(|n_vars| (0..=(1usize << n_vars)).map(move |index| (n_vars, index)));

		for (n_vars, index) in cases {
			let step_up = StepUp::new(n_vars, index).unwrap();
			let from_oracle = hypercube_evals_from_oracle::<BinaryField1b>(&step_up);
			let from_extension = step_up.multilinear_extension::<BinaryField1b>().unwrap();
			assert_eq!(from_oracle, from_extension.evals());
		}
	}

	/// Materialises the packed hypercube evaluations of `StepUp::new(n_vars, index)`.
	fn stepup_evals<P>(n_vars: usize, index: usize) -> Vec<P>
	where
		P: PackedField<Scalar = BinaryField1b>,
	{
		let step_up = StepUp::new(n_vars, index).unwrap();
		let extension = step_up.multilinear_extension::<P>().unwrap();
		extension.evals().to_vec()
	}
}