1use std::{cmp::max, ops::DerefMut};
4
5use binius_field::{Field, PackedField};
6use binius_utils::bail;
7use bytemuck::zeroed_vec;
8
9use crate::{Error, eq_ind_partial_eval, tensor_prod_eq_ind};
10
11#[derive(Debug)]
19pub struct MultilinearQuery<P, Data = Vec<P>>
20where
21 P: PackedField,
22 Data: DerefMut<Target = [P]>,
23{
24 n_vars: usize,
25 expanded_query: Data,
26}
27
28#[derive(Debug, Clone, Copy)]
30pub struct MultilinearQueryRef<'a, P: PackedField> {
31 n_vars: usize,
32 expanded_query: &'a [P],
33}
34
35impl<'a, P: PackedField, Data: DerefMut<Target = [P]>> From<&'a MultilinearQuery<P, Data>>
36 for MultilinearQueryRef<'a, P>
37{
38 fn from(query: &'a MultilinearQuery<P, Data>) -> Self {
39 Self::new(query)
40 }
41}
42
43impl<'a, P: PackedField> MultilinearQueryRef<'a, P> {
44 pub fn new<Data: DerefMut<Target = [P]>>(query: &'a MultilinearQuery<P, Data>) -> Self {
45 Self {
46 n_vars: query.n_vars,
47 expanded_query: &query.expanded_query,
48 }
49 }
50
51 pub const fn n_vars(&self) -> usize {
52 self.n_vars
53 }
54
55 pub fn expansion(&self) -> &[P] {
60 let expanded_query_len = 1 << self.n_vars.saturating_sub(P::LOG_WIDTH);
61 &self.expanded_query[0..expanded_query_len]
62 }
63}
64
65impl<P: PackedField> MultilinearQuery<P, Vec<P>> {
66 pub fn with_capacity(max_query_vars: usize) -> Self {
67 let len = 1 << max_query_vars.saturating_sub(P::LOG_WIDTH);
68 let mut expanded_query = zeroed_vec::<P>(len);
69 expanded_query[0].set(0, P::Scalar::ONE);
70 Self {
71 expanded_query,
72 n_vars: 0,
73 }
74 }
75
76 pub fn expand(query: &[P::Scalar]) -> Self {
77 let expanded_query = eq_ind_partial_eval(query);
78 Self {
79 expanded_query,
80 n_vars: query.len(),
81 }
82 }
83}
84
85impl<P: PackedField, Data: DerefMut<Target = [P]>> MultilinearQuery<P, Data> {
86 pub fn with_expansion(n_vars: usize, expanded_query: Data) -> Result<Self, Error> {
87 let expected_len = 1 << n_vars.saturating_sub(P::LOG_WIDTH);
88 if expanded_query.len() < expected_len {
89 bail!(Error::IncorrectArgumentLength {
90 arg: "expanded_query".to_string(),
91 expected: expected_len,
92 });
93 }
94 Ok(Self {
95 n_vars,
96 expanded_query,
97 })
98 }
99
100 pub const fn n_vars(&self) -> usize {
101 self.n_vars
102 }
103
104 pub fn expansion(&self) -> &[P] {
109 let expanded_query_len = 1 << self.n_vars.saturating_sub(P::LOG_WIDTH);
110 &self.expanded_query[0..expanded_query_len]
111 }
112
113 pub fn expansion_mut(&mut self) -> &mut [P] {
117 let expanded_query_len = 1 << self.n_vars.saturating_sub(P::LOG_WIDTH);
118 &mut self.expanded_query[0..expanded_query_len]
119 }
120
121 pub fn into_expansion(self) -> Data {
122 self.expanded_query
123 }
124
125 pub fn update(mut self, extra_query_coordinates: &[P::Scalar]) -> Result<Self, Error> {
126 let old_n_vars = self.n_vars;
127 let new_n_vars = old_n_vars + extra_query_coordinates.len();
128 let new_length = max((1 << new_n_vars) / P::WIDTH, 1);
129 if new_length > self.expanded_query.len() {
130 bail!(Error::MultilinearQueryFull {
131 max_query_vars: old_n_vars,
132 });
133 }
134 tensor_prod_eq_ind(
135 old_n_vars,
136 &mut self.expanded_query[..new_length],
137 extra_query_coordinates,
138 )?;
139
140 Ok(Self {
141 n_vars: new_n_vars,
142 expanded_query: self.expanded_query,
143 })
144 }
145
146 pub fn to_ref(&self) -> MultilinearQueryRef<P> {
147 self.into()
148 }
149}
150
#[cfg(test)]
mod tests {
    use binius_field::{Field, PackedBinaryField4x32b, PackedField};
    use binius_utils::felts;
    use itertools::Itertools;

    use super::*;
    use crate::tensor_prod_eq_ind;

    type P = PackedBinaryField4x32b;
    type F = <P as PackedField>::Scalar;

    /// Computes the packed expansion of `coords` from scratch: starts from a buffer whose
    /// first scalar is one and lets `tensor_prod_eq_ind` fold in every coordinate.
    fn tensor_prod<Packed: PackedField>(coords: &[Packed::Scalar]) -> Vec<Packed> {
        let packed_len = 1 << coords.len().saturating_sub(Packed::LOG_WIDTH);
        let mut expansion = vec![Packed::default(); packed_len];
        expansion[0] = Packed::set_single(Packed::Scalar::ONE);
        tensor_prod_eq_ind(0, &mut expansion, coords).unwrap();
        expansion
    }

    /// Flattens packed elements into their scalars, in order.
    fn scalars(packed: &[P]) -> Vec<F> {
        P::iter_slice(packed).collect_vec()
    }

    /// Expands the listed coordinates with the given packing and returns the live
    /// expansion as a flat vector of scalars.
    macro_rules! expand_query {
        ($f:ident[$($elem:expr),* $(,)?], Packing=$p:ident) => {{
            let coords: &[$f] = &[$($f::new($elem)),*];
            let query =
                MultilinearQuery::<$p, _>::with_expansion(coords.len(), tensor_prod(coords))
                    .unwrap();
            binius_field::PackedField::iter_slice(query.expansion()).collect::<Vec<_>>()
        }};
    }

    #[test]
    fn test_query_no_packing_32b() {
        use binius_field::BinaryField32b;

        // Zero coordinates: the expansion is the single scalar one.
        assert_eq!(
            expand_query!(BinaryField32b[], Packing = BinaryField32b),
            felts!(BinaryField32b[1])
        );
        // Each extra coordinate doubles the number of scalars.
        assert_eq!(
            expand_query!(BinaryField32b[2], Packing = BinaryField32b),
            felts!(BinaryField32b[3, 2])
        );
        assert_eq!(
            expand_query!(BinaryField32b[2, 2], Packing = BinaryField32b),
            felts!(BinaryField32b[2, 1, 1, 3])
        );
        assert_eq!(
            expand_query!(BinaryField32b[2, 2, 2], Packing = BinaryField32b),
            felts!(BinaryField32b[1, 3, 3, 2, 3, 2, 2, 1])
        );
        assert_eq!(
            expand_query!(BinaryField32b[2, 2, 2, 2], Packing = BinaryField32b),
            felts!(BinaryField32b[3, 2, 2, 1, 2, 1, 1, 3, 2, 1, 1, 3, 1, 3, 3, 2])
        );
    }

    #[test]
    fn test_query_packing_4x32b() {
        use binius_field::{BinaryField32b, PackedBinaryField4x32b};

        // Fewer than two coordinates leave the single packed element partly zero.
        assert_eq!(
            expand_query!(BinaryField32b[], Packing = PackedBinaryField4x32b),
            felts!(BinaryField32b[1, 0, 0, 0])
        );
        assert_eq!(
            expand_query!(BinaryField32b[2], Packing = PackedBinaryField4x32b),
            felts!(BinaryField32b[3, 2, 0, 0])
        );
        assert_eq!(
            expand_query!(BinaryField32b[2, 2], Packing = PackedBinaryField4x32b),
            felts!(BinaryField32b[2, 1, 1, 3])
        );
        assert_eq!(
            expand_query!(BinaryField32b[2, 2, 2], Packing = PackedBinaryField4x32b),
            felts!(BinaryField32b[1, 3, 3, 2, 3, 2, 2, 1])
        );
        assert_eq!(
            expand_query!(BinaryField32b[2, 2, 2, 2], Packing = PackedBinaryField4x32b),
            felts!(BinaryField32b[3, 2, 2, 1, 2, 1, 1, 3, 2, 1, 1, 3, 1, 3, 3, 2])
        );
    }

    #[test]
    fn test_query_packing_8x16b() {
        use binius_field::{BinaryField16b, PackedBinaryField8x16b};

        // Fewer than three coordinates leave the single packed element partly zero.
        assert_eq!(
            expand_query!(BinaryField16b[], Packing = PackedBinaryField8x16b),
            felts!(BinaryField16b[1, 0, 0, 0, 0, 0, 0, 0])
        );
        assert_eq!(
            expand_query!(BinaryField16b[2], Packing = PackedBinaryField8x16b),
            felts!(BinaryField16b[3, 2, 0, 0, 0, 0, 0, 0])
        );
        assert_eq!(
            expand_query!(BinaryField16b[2, 2], Packing = PackedBinaryField8x16b),
            felts!(BinaryField16b[2, 1, 1, 3, 0, 0, 0, 0])
        );
        assert_eq!(
            expand_query!(BinaryField16b[2, 2, 2], Packing = PackedBinaryField8x16b),
            felts!(BinaryField16b[1, 3, 3, 2, 3, 2, 2, 1])
        );
        assert_eq!(
            expand_query!(BinaryField16b[2, 2, 2, 2], Packing = PackedBinaryField8x16b),
            felts!(BinaryField16b[3, 2, 2, 1, 2, 1, 1, 3, 2, 1, 1, 3, 1, 3, 3, 2])
        );
    }

    #[test]
    fn test_update_single_var() {
        let r0 = F::new(2);
        let coords = [r0];

        let updated = MultilinearQuery::<P>::with_capacity(2)
            .update(&coords)
            .unwrap();
        assert_eq!(updated.n_vars(), 1);

        // The whole buffer is inspected here, not just the live prefix.
        let buffer = updated.into_expansion();
        assert_eq!(scalars(&buffer), vec![(F::ONE - r0), r0, F::ZERO, F::ZERO]);
    }

    #[test]
    fn test_update_two_vars() {
        let r0 = F::new(2);
        let r1 = F::new(3);
        let coords = [r0, r1];

        let updated = MultilinearQuery::<P>::with_capacity(3)
            .update(&coords)
            .unwrap();
        assert_eq!(updated.n_vars(), 2);

        let expected = vec![
            (F::ONE - r0) * (F::ONE - r1),
            r0 * (F::ONE - r1),
            (F::ONE - r0) * r1,
            r0 * r1,
        ];
        assert_eq!(scalars(updated.expansion()), expected);
    }

    #[test]
    fn test_update_three_vars() {
        let r0 = F::new(2);
        let r1 = F::new(3);
        let r2 = F::new(5);
        let coords = [r0, r1, r2];

        let updated = MultilinearQuery::<P>::with_capacity(4)
            .update(&coords)
            .unwrap();
        assert_eq!(updated.n_vars(), 3);

        // Index bit j selects `r_j` when set and `1 - r_j` when clear.
        let expected = vec![
            (F::ONE - r0) * (F::ONE - r1) * (F::ONE - r2),
            r0 * (F::ONE - r1) * (F::ONE - r2),
            (F::ONE - r0) * r1 * (F::ONE - r2),
            r0 * r1 * (F::ONE - r2),
            (F::ONE - r0) * (F::ONE - r1) * r2,
            r0 * (F::ONE - r1) * r2,
            (F::ONE - r0) * r1 * r2,
            r0 * r1 * r2,
        ];
        assert_eq!(scalars(updated.expansion()), expected);
    }

    #[test]
    fn test_update_exceeds_capacity() {
        // Capacity for two variables cannot absorb three coordinates.
        let coords = [F::new(2), F::new(3), F::new(5)];

        let outcome = MultilinearQuery::<P>::with_capacity(2).update(&coords);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_update_empty() {
        // Updating with no coordinates leaves the fresh query untouched.
        let updated = MultilinearQuery::<P>::with_capacity(2)
            .update(&[])
            .unwrap();
        assert_eq!(updated.n_vars(), 0);

        assert_eq!(
            scalars(updated.expansion()),
            vec![F::ONE, F::ZERO, F::ZERO, F::ZERO]
        );
    }
}