1use std::{cmp::max, ops::DerefMut};
4
5use binius_field::{Field, PackedField};
6use binius_utils::bail;
7use bytemuck::zeroed_vec;
8
9use crate::{eq_ind_partial_eval, tensor_prod_eq_ind, Error};
10
11#[derive(Debug)]
19pub struct MultilinearQuery<P, Data = Vec<P>>
20where
21 P: PackedField,
22 Data: DerefMut<Target = [P]>,
23{
24 n_vars: usize,
25 expanded_query: Data,
26}
27
28#[derive(Debug, Clone, Copy)]
30pub struct MultilinearQueryRef<'a, P: PackedField> {
31 n_vars: usize,
32 expanded_query: &'a [P],
33}
34
35impl<'a, P: PackedField, Data: DerefMut<Target = [P]>> From<&'a MultilinearQuery<P, Data>>
36 for MultilinearQueryRef<'a, P>
37{
38 fn from(query: &'a MultilinearQuery<P, Data>) -> Self {
39 Self::new(query)
40 }
41}
42
43impl<'a, P: PackedField> MultilinearQueryRef<'a, P> {
44 pub fn new<Data: DerefMut<Target = [P]>>(query: &'a MultilinearQuery<P, Data>) -> Self {
45 Self {
46 n_vars: query.n_vars,
47 expanded_query: &query.expanded_query,
48 }
49 }
50
51 pub const fn n_vars(&self) -> usize {
52 self.n_vars
53 }
54
55 pub fn expansion(&self) -> &[P] {
59 let expanded_query_len = 1 << self.n_vars.saturating_sub(P::LOG_WIDTH);
60 &self.expanded_query[0..expanded_query_len]
61 }
62}
63
64impl<P: PackedField> MultilinearQuery<P, Vec<P>> {
65 pub fn with_capacity(max_query_vars: usize) -> Self {
66 let len = 1 << max_query_vars.saturating_sub(P::LOG_WIDTH);
67 let mut expanded_query = zeroed_vec::<P>(len);
68 expanded_query[0].set(0, P::Scalar::ONE);
69 Self {
70 expanded_query,
71 n_vars: 0,
72 }
73 }
74
75 pub fn expand(query: &[P::Scalar]) -> Self {
76 let expanded_query = eq_ind_partial_eval(query);
77 Self {
78 expanded_query,
79 n_vars: query.len(),
80 }
81 }
82}
83
84impl<P: PackedField, Data: DerefMut<Target = [P]>> MultilinearQuery<P, Data> {
85 pub fn with_expansion(n_vars: usize, expanded_query: Data) -> Result<Self, Error> {
86 let expected_len = 1 << n_vars.saturating_sub(P::LOG_WIDTH);
87 if expanded_query.len() < expected_len {
88 bail!(Error::IncorrectArgumentLength {
89 arg: "expanded_query".to_string(),
90 expected: expected_len,
91 });
92 }
93 Ok(Self {
94 n_vars,
95 expanded_query,
96 })
97 }
98
99 pub const fn n_vars(&self) -> usize {
100 self.n_vars
101 }
102
103 pub fn expansion(&self) -> &[P] {
107 let expanded_query_len = 1 << self.n_vars.saturating_sub(P::LOG_WIDTH);
108 &self.expanded_query[0..expanded_query_len]
109 }
110
111 pub fn expansion_mut(&mut self) -> &mut [P] {
115 let expanded_query_len = 1 << self.n_vars.saturating_sub(P::LOG_WIDTH);
116 &mut self.expanded_query[0..expanded_query_len]
117 }
118
119 pub fn into_expansion(self) -> Data {
120 self.expanded_query
121 }
122
123 pub fn update(mut self, extra_query_coordinates: &[P::Scalar]) -> Result<Self, Error> {
124 let old_n_vars = self.n_vars;
125 let new_n_vars = old_n_vars + extra_query_coordinates.len();
126 let new_length = max((1 << new_n_vars) / P::WIDTH, 1);
127 if new_length > self.expanded_query.len() {
128 bail!(Error::MultilinearQueryFull {
129 max_query_vars: old_n_vars,
130 });
131 }
132 tensor_prod_eq_ind(
133 old_n_vars,
134 &mut self.expanded_query[..new_length],
135 extra_query_coordinates,
136 )?;
137
138 Ok(Self {
139 n_vars: new_n_vars,
140 expanded_query: self.expanded_query,
141 })
142 }
143
144 pub fn to_ref(&self) -> MultilinearQueryRef<P> {
145 self.into()
146 }
147}
148
#[cfg(test)]
mod tests {
    use binius_field::{Field, PackedBinaryField4x32b, PackedField};
    use binius_utils::felts;
    use itertools::Itertools;

    use super::*;
    use crate::tensor_prod_eq_ind;

    // Packed type and scalar type used by the `update` tests below.
    type P = PackedBinaryField4x32b;
    type F = <P as PackedField>::Scalar;

    /// Reference expansion: starts from a buffer whose first scalar is one and lets
    /// `tensor_prod_eq_ind` fold in every coordinate of `p`.
    fn tensor_prod<P: PackedField>(p: &[P::Scalar]) -> Vec<P> {
        let mut result = vec![P::default(); 1 << p.len().saturating_sub(P::LOG_WIDTH)];
        result[0] = P::set_single(P::Scalar::ONE);
        tensor_prod_eq_ind(0, &mut result, p).unwrap();
        result
    }

    // Builds a `MultilinearQuery` over the given coordinates via `with_expansion` and
    // returns the scalars of its `expansion()` as a `Vec`.
    macro_rules! expand_query {
        ($f:ident[$($elem:expr),* $(,)?], Packing=$p:ident) => {
            binius_field::PackedField::iter_slice(
                MultilinearQuery::<$p, _>::with_expansion(
                    {
                        let elems: &[$f] = &[$($f::new($elem)),*];
                        elems.len()
                    },
                    tensor_prod(&[$($f::new($elem)),*])
                )
                .unwrap()
                .expansion(),
            ).collect::<Vec<_>>()
        };
    }

    /// Expansion with an unpacked (width-1) representation, for 0 to 4 variables.
    #[test]
    fn test_query_no_packing_32b() {
        use binius_field::BinaryField32b;

        assert_eq!(
            expand_query!(BinaryField32b[], Packing = BinaryField32b),
            felts!(BinaryField32b[1])
        );
        assert_eq!(
            expand_query!(BinaryField32b[2], Packing = BinaryField32b),
            felts!(BinaryField32b[3, 2])
        );
        assert_eq!(
            expand_query!(BinaryField32b[2, 2], Packing = BinaryField32b),
            felts!(BinaryField32b[2, 1, 1, 3])
        );
        assert_eq!(
            expand_query!(BinaryField32b[2, 2, 2], Packing = BinaryField32b),
            felts!(BinaryField32b[1, 3, 3, 2, 3, 2, 2, 1])
        );
        assert_eq!(
            expand_query!(BinaryField32b[2, 2, 2, 2], Packing = BinaryField32b),
            felts!(BinaryField32b[3, 2, 2, 1, 2, 1, 1, 3, 2, 1, 1, 3, 1, 3, 3, 2])
        );
    }

    /// Expansion with 4 scalars per packed element; queries with fewer than 2 variables
    /// leave the tail of the single packed element zero.
    #[test]
    fn test_query_packing_4x32b() {
        use binius_field::{BinaryField32b, PackedBinaryField4x32b};
        assert_eq!(
            expand_query!(BinaryField32b[], Packing = PackedBinaryField4x32b),
            felts!(BinaryField32b[1, 0, 0, 0])
        );
        assert_eq!(
            expand_query!(BinaryField32b[2], Packing = PackedBinaryField4x32b),
            felts!(BinaryField32b[3, 2, 0, 0])
        );
        assert_eq!(
            expand_query!(BinaryField32b[2, 2], Packing = PackedBinaryField4x32b),
            felts!(BinaryField32b[2, 1, 1, 3])
        );
        assert_eq!(
            expand_query!(BinaryField32b[2, 2, 2], Packing = PackedBinaryField4x32b),
            felts!(BinaryField32b[1, 3, 3, 2, 3, 2, 2, 1])
        );
        assert_eq!(
            expand_query!(BinaryField32b[2, 2, 2, 2], Packing = PackedBinaryField4x32b),
            felts!(BinaryField32b[3, 2, 2, 1, 2, 1, 1, 3, 2, 1, 1, 3, 1, 3, 3, 2])
        );
    }

    /// Expansion with 8 scalars per packed element; queries with fewer than 3 variables
    /// leave the tail of the single packed element zero.
    #[test]
    fn test_query_packing_8x16b() {
        use binius_field::{BinaryField16b, PackedBinaryField8x16b};
        assert_eq!(
            expand_query!(BinaryField16b[], Packing = PackedBinaryField8x16b),
            felts!(BinaryField16b[1, 0, 0, 0, 0, 0, 0, 0])
        );
        assert_eq!(
            expand_query!(BinaryField16b[2], Packing = PackedBinaryField8x16b),
            felts!(BinaryField16b[3, 2, 0, 0, 0, 0, 0, 0])
        );
        assert_eq!(
            expand_query!(BinaryField16b[2, 2], Packing = PackedBinaryField8x16b),
            felts!(BinaryField16b[2, 1, 1, 3, 0, 0, 0, 0])
        );
        assert_eq!(
            expand_query!(BinaryField16b[2, 2, 2], Packing = PackedBinaryField8x16b),
            felts!(BinaryField16b[1, 3, 3, 2, 3, 2, 2, 1])
        );
        assert_eq!(
            expand_query!(BinaryField16b[2, 2, 2, 2], Packing = PackedBinaryField8x16b),
            felts!(BinaryField16b[3, 2, 2, 1, 2, 1, 1, 3, 2, 1, 1, 3, 1, 3, 3, 2])
        );
    }

    /// Updating an empty query with one coordinate yields `[1 - r0, r0]` followed by zeros.
    #[test]
    fn test_update_single_var() {
        let query = MultilinearQuery::<P>::with_capacity(2);
        let r0 = F::new(2);
        let extra_query = [r0];

        let updated_query = query.update(&extra_query).unwrap();

        assert_eq!(updated_query.n_vars(), 1);

        // `into_expansion` returns the whole buffer, which here is one packed element.
        let expansion = updated_query.into_expansion();
        let expansion = PackedField::iter_slice(&expansion).collect_vec();

        assert_eq!(expansion, vec![(F::ONE - r0), r0, F::ZERO, F::ZERO]);
    }

    /// Updating an empty query with two coordinates yields all four products, with the
    /// first coordinate varying fastest.
    #[test]
    fn test_update_two_vars() {
        let query = MultilinearQuery::<P>::with_capacity(3);
        let r0 = F::new(2);
        let r1 = F::new(3);
        let extra_query = [r0, r1];

        let updated_query = query.update(&extra_query).unwrap();
        assert_eq!(updated_query.n_vars(), 2);

        let expansion = updated_query.expansion();
        let expansion = PackedField::iter_slice(expansion).collect_vec();

        assert_eq!(
            expansion,
            vec![
                (F::ONE - r0) * (F::ONE - r1),
                r0 * (F::ONE - r1),
                (F::ONE - r0) * r1,
                r0 * r1,
            ]
        );
    }

    /// Updating an empty query with three coordinates yields all eight products, with the
    /// first coordinate varying fastest.
    #[test]
    fn test_update_three_vars() {
        let query = MultilinearQuery::<P>::with_capacity(4);
        let r0 = F::new(2);
        let r1 = F::new(3);
        let r2 = F::new(5);
        let extra_query = [r0, r1, r2];

        let updated_query = query.update(&extra_query).unwrap();
        assert_eq!(updated_query.n_vars(), 3);

        let expansion = updated_query.expansion();
        let expansion = PackedField::iter_slice(expansion).collect_vec();

        assert_eq!(
            expansion,
            vec![
                (F::ONE - r0) * (F::ONE - r1) * (F::ONE - r2),
                r0 * (F::ONE - r1) * (F::ONE - r2),
                (F::ONE - r0) * r1 * (F::ONE - r2),
                r0 * r1 * (F::ONE - r2),
                (F::ONE - r0) * (F::ONE - r1) * r2,
                r0 * (F::ONE - r1) * r2,
                (F::ONE - r0) * r1 * r2,
                r0 * r1 * r2,
            ]
        );
    }

    /// A query sized for two variables must reject an update that would bring it to three,
    /// and must do so with the dedicated "query full" error.
    #[test]
    fn test_update_exceeds_capacity() {
        let query = MultilinearQuery::<P>::with_capacity(2);
        let extra_query = [F::new(2), F::new(3), F::new(5)];

        let result = query.update(&extra_query);
        // Pin the variant so an unrelated failure cannot make this test pass.
        assert!(matches!(
            result,
            Err(Error::MultilinearQueryFull { .. })
        ));
    }

    /// Updating with no coordinates leaves the query unchanged.
    #[test]
    fn test_update_empty() {
        let query = MultilinearQuery::<P>::with_capacity(2);
        let updated_query = query.update(&[]).unwrap();

        assert_eq!(updated_query.n_vars(), 0);

        let expansion = updated_query.expansion();
        let expansion = PackedField::iter_slice(expansion).collect_vec();

        assert_eq!(expansion, vec![F::ONE, F::ZERO, F::ZERO, F::ZERO]);
    }
}