binius_math/
multilinear_query.rs

1// Copyright 2023-2025 Irreducible Inc.
2
3use std::{cmp::max, ops::DerefMut};
4
5use binius_field::{Field, PackedField};
6use binius_utils::bail;
7use bytemuck::zeroed_vec;
8
9use crate::{Error, eq_ind_partial_eval, tensor_prod_eq_ind};
10
11/// Tensor product expansion of sumcheck round challenges.
12///
13/// Stores the tensor product expansion $\bigotimes_{i = 0}^{n - 1} (1 - r_i, r_i)$
14/// when `round()` is `n` for the sequence of sumcheck challenges $(r_0, ..., r_{n-1})$.
15/// The tensor product can be updated with a new round challenge in linear time.
16/// This is used in the first several rounds of the sumcheck prover for small-field polynomials,
17/// before it becomes more efficient to switch over to the method that store folded multilinears.
18#[derive(Debug)]
19pub struct MultilinearQuery<P, Data = Vec<P>>
20where
21	P: PackedField,
22	Data: DerefMut<Target = [P]>,
23{
24	n_vars: usize,
25	expanded_query: Data,
26}
27
28/// Wraps `MultilinearQuery` to hide `Data` from the users.
29#[derive(Debug, Clone, Copy)]
30pub struct MultilinearQueryRef<'a, P: PackedField> {
31	n_vars: usize,
32	expanded_query: &'a [P],
33}
34
35impl<'a, P: PackedField, Data: DerefMut<Target = [P]>> From<&'a MultilinearQuery<P, Data>>
36	for MultilinearQueryRef<'a, P>
37{
38	fn from(query: &'a MultilinearQuery<P, Data>) -> Self {
39		Self::new(query)
40	}
41}
42
43impl<'a, P: PackedField> MultilinearQueryRef<'a, P> {
44	pub fn new<Data: DerefMut<Target = [P]>>(query: &'a MultilinearQuery<P, Data>) -> Self {
45		Self {
46			n_vars: query.n_vars,
47			expanded_query: &query.expanded_query,
48		}
49	}
50
51	pub const fn n_vars(&self) -> usize {
52		self.n_vars
53	}
54
55	/// Returns the tensor product expansion of the query
56	///
57	/// If the number of query variables is less than the packing width, return a single packed
58	/// element.
59	pub fn expansion(&self) -> &[P] {
60		let expanded_query_len = 1 << self.n_vars.saturating_sub(P::LOG_WIDTH);
61		&self.expanded_query[0..expanded_query_len]
62	}
63}
64
65impl<P: PackedField> MultilinearQuery<P, Vec<P>> {
66	pub fn with_capacity(max_query_vars: usize) -> Self {
67		let len = 1 << max_query_vars.saturating_sub(P::LOG_WIDTH);
68		let mut expanded_query = zeroed_vec::<P>(len);
69		expanded_query[0].set(0, P::Scalar::ONE);
70		Self {
71			expanded_query,
72			n_vars: 0,
73		}
74	}
75
76	pub fn expand(query: &[P::Scalar]) -> Self {
77		let expanded_query = eq_ind_partial_eval(query);
78		Self {
79			expanded_query,
80			n_vars: query.len(),
81		}
82	}
83}
84
85impl<P: PackedField, Data: DerefMut<Target = [P]>> MultilinearQuery<P, Data> {
86	pub fn with_expansion(n_vars: usize, expanded_query: Data) -> Result<Self, Error> {
87		let expected_len = 1 << n_vars.saturating_sub(P::LOG_WIDTH);
88		if expanded_query.len() < expected_len {
89			bail!(Error::IncorrectArgumentLength {
90				arg: "expanded_query".to_string(),
91				expected: expected_len,
92			});
93		}
94		Ok(Self {
95			n_vars,
96			expanded_query,
97		})
98	}
99
100	pub const fn n_vars(&self) -> usize {
101		self.n_vars
102	}
103
104	/// Returns the tensor product expansion of the query
105	///
106	/// If the number of query variables is less than the packing width, return a single packed
107	/// element.
108	pub fn expansion(&self) -> &[P] {
109		let expanded_query_len = 1 << self.n_vars.saturating_sub(P::LOG_WIDTH);
110		&self.expanded_query[0..expanded_query_len]
111	}
112
113	// REVIEW: this method is a temporary hack to allow the
114	// construction of a "multilinear query" which contains Lagrange
115	// coefficient evaluations in UnivariateZerocheck::fold_univariate_round
116	pub fn expansion_mut(&mut self) -> &mut [P] {
117		let expanded_query_len = 1 << self.n_vars.saturating_sub(P::LOG_WIDTH);
118		&mut self.expanded_query[0..expanded_query_len]
119	}
120
121	pub fn into_expansion(self) -> Data {
122		self.expanded_query
123	}
124
125	pub fn update(mut self, extra_query_coordinates: &[P::Scalar]) -> Result<Self, Error> {
126		let old_n_vars = self.n_vars;
127		let new_n_vars = old_n_vars + extra_query_coordinates.len();
128		let new_length = max((1 << new_n_vars) / P::WIDTH, 1);
129		if new_length > self.expanded_query.len() {
130			bail!(Error::MultilinearQueryFull {
131				max_query_vars: old_n_vars,
132			});
133		}
134		tensor_prod_eq_ind(
135			old_n_vars,
136			&mut self.expanded_query[..new_length],
137			extra_query_coordinates,
138		)?;
139
140		Ok(Self {
141			n_vars: new_n_vars,
142			expanded_query: self.expanded_query,
143		})
144	}
145
146	pub fn to_ref(&self) -> MultilinearQueryRef<P> {
147		self.into()
148	}
149}
150
#[cfg(test)]
mod tests {
	use binius_field::{Field, PackedBinaryField4x32b, PackedField};
	use binius_utils::felts;
	use itertools::Itertools;

	use super::*;
	use crate::tensor_prod_eq_ind;

	type P = PackedBinaryField4x32b;
	type F = <P as PackedField>::Scalar;

	/// Reference expansion computed directly with `tensor_prod_eq_ind`, starting from the
	/// one-element expansion of the empty query.
	fn tensor_prod<P: PackedField>(p: &[P::Scalar]) -> Vec<P> {
		let mut result = vec![P::default(); 1 << p.len().saturating_sub(P::LOG_WIDTH)];
		result[0] = P::set_single(P::Scalar::ONE);
		tensor_prod_eq_ind(0, &mut result, p).unwrap();
		result
	}

	/// Wraps `tensor_prod` of the given literal scalars in a `MultilinearQuery` via
	/// `with_expansion` and returns its expansion flattened to scalars.
	macro_rules! expand_query {
		($f:ident[$($elem:expr),* $(,)?], Packing=$p:ident) => {
			binius_field::PackedField::iter_slice(
				MultilinearQuery::<$p, _>::with_expansion(
					{
						let elems: &[$f] = &[$($f::new($elem)),*];
						elems.len()
					},
					tensor_prod(&[$($f::new($elem)),*])
				)
				.unwrap()
				.expansion(),
			).collect::<Vec<_>>()
		};
	}

	#[test]
	fn test_query_no_packing_32b() {
		use binius_field::BinaryField32b;

		assert_eq!(
			expand_query!(BinaryField32b[], Packing = BinaryField32b),
			felts!(BinaryField32b[1])
		);
		assert_eq!(
			expand_query!(BinaryField32b[2], Packing = BinaryField32b),
			felts!(BinaryField32b[3, 2])
		);
		assert_eq!(
			expand_query!(BinaryField32b[2, 2], Packing = BinaryField32b),
			felts!(BinaryField32b[2, 1, 1, 3])
		);
		assert_eq!(
			expand_query!(BinaryField32b[2, 2, 2], Packing = BinaryField32b),
			felts!(BinaryField32b[1, 3, 3, 2, 3, 2, 2, 1])
		);
		assert_eq!(
			expand_query!(BinaryField32b[2, 2, 2, 2], Packing = BinaryField32b),
			felts!(BinaryField32b[3, 2, 2, 1, 2, 1, 1, 3, 2, 1, 1, 3, 1, 3, 3, 2])
		);
	}

	#[test]
	fn test_query_packing_4x32b() {
		use binius_field::{BinaryField32b, PackedBinaryField4x32b};
		assert_eq!(
			expand_query!(BinaryField32b[], Packing = PackedBinaryField4x32b),
			felts!(BinaryField32b[1, 0, 0, 0])
		);
		assert_eq!(
			expand_query!(BinaryField32b[2, 2], Packing = PackedBinaryField4x32b),
			felts!(BinaryField32b[2, 1, 1, 3])
		);
		assert_eq!(
			expand_query!(BinaryField32b[2], Packing = PackedBinaryField4x32b),
			felts!(BinaryField32b[3, 2, 0, 0])
		);
		assert_eq!(
			expand_query!(BinaryField32b[2, 2, 2], Packing = PackedBinaryField4x32b),
			felts!(BinaryField32b[1, 3, 3, 2, 3, 2, 2, 1])
		);
		assert_eq!(
			expand_query!(BinaryField32b[2, 2, 2, 2], Packing = PackedBinaryField4x32b),
			felts!(BinaryField32b[3, 2, 2, 1, 2, 1, 1, 3, 2, 1, 1, 3, 1, 3, 3, 2])
		);
	}

	#[test]
	fn test_query_packing_8x16b() {
		use binius_field::{BinaryField16b, PackedBinaryField8x16b};
		assert_eq!(
			expand_query!(BinaryField16b[], Packing = PackedBinaryField8x16b),
			felts!(BinaryField16b[1, 0, 0, 0, 0, 0, 0, 0])
		);
		assert_eq!(
			expand_query!(BinaryField16b[2], Packing = PackedBinaryField8x16b),
			felts!(BinaryField16b[3, 2, 0, 0, 0, 0, 0, 0])
		);
		assert_eq!(
			expand_query!(BinaryField16b[2, 2], Packing = PackedBinaryField8x16b),
			felts!(BinaryField16b[2, 1, 1, 3, 0, 0, 0, 0])
		);
		assert_eq!(
			expand_query!(BinaryField16b[2, 2, 2], Packing = PackedBinaryField8x16b),
			felts!(BinaryField16b[1, 3, 3, 2, 3, 2, 2, 1])
		);
		assert_eq!(
			expand_query!(BinaryField16b[2, 2, 2, 2], Packing = PackedBinaryField8x16b),
			felts!(BinaryField16b[3, 2, 2, 1, 2, 1, 1, 3, 2, 1, 1, 3, 1, 3, 3, 2])
		);
	}

	#[test]
	fn test_update_single_var() {
		let query = MultilinearQuery::<P>::with_capacity(2);
		let r0 = F::new(2);
		let extra_query = [r0];

		let updated_query = query.update(&extra_query).unwrap();

		assert_eq!(updated_query.n_vars(), 1);

		let expansion = updated_query.into_expansion();
		let expansion = PackedField::iter_slice(&expansion).collect_vec();

		assert_eq!(expansion, vec![(F::ONE - r0), r0, F::ZERO, F::ZERO]);
	}

	#[test]
	fn test_update_two_vars() {
		let query = MultilinearQuery::<P>::with_capacity(3);
		let r0 = F::new(2);
		let r1 = F::new(3);
		let extra_query = [r0, r1];

		let updated_query = query.update(&extra_query).unwrap();
		assert_eq!(updated_query.n_vars(), 2);

		let expansion = updated_query.expansion();
		let expansion = PackedField::iter_slice(expansion).collect_vec();

		assert_eq!(
			expansion,
			vec![
				(F::ONE - r0) * (F::ONE - r1),
				r0 * (F::ONE - r1),
				(F::ONE - r0) * r1,
				r0 * r1,
			]
		);
	}

	#[test]
	fn test_update_three_vars() {
		let query = MultilinearQuery::<P>::with_capacity(4);
		let r0 = F::new(2);
		let r1 = F::new(3);
		let r2 = F::new(5);
		let extra_query = [r0, r1, r2];

		let updated_query = query.update(&extra_query).unwrap();
		assert_eq!(updated_query.n_vars(), 3);

		let expansion = updated_query.expansion();
		let expansion = PackedField::iter_slice(expansion).collect_vec();

		assert_eq!(
			expansion,
			vec![
				(F::ONE - r0) * (F::ONE - r1) * (F::ONE - r2),
				r0 * (F::ONE - r1) * (F::ONE - r2),
				(F::ONE - r0) * r1 * (F::ONE - r2),
				r0 * r1 * (F::ONE - r2),
				(F::ONE - r0) * (F::ONE - r1) * r2,
				r0 * (F::ONE - r1) * r2,
				(F::ONE - r0) * r1 * r2,
				r0 * r1 * r2,
			]
		);
	}

	/// Feeding challenges one round at a time must match a single batched expansion.
	#[test]
	fn test_update_incremental_matches_batched() {
		let r = [F::new(2), F::new(3), F::new(5)];

		let query = r
			.iter()
			.try_fold(MultilinearQuery::<P>::with_capacity(3), |query, &r_i| query.update(&[r_i]))
			.unwrap();
		assert_eq!(query.n_vars(), 3);

		let expansion = PackedField::iter_slice(query.expansion()).collect_vec();
		let expected = PackedField::iter_slice(&tensor_prod::<P>(&r)).collect_vec();
		assert_eq!(expansion, expected);
	}

	/// `expand` and `with_capacity` + `update` must produce the same expansion.
	#[test]
	fn test_expand_matches_update() {
		let r = [F::new(2), F::new(3), F::new(5)];

		let expanded = MultilinearQuery::<P>::expand(&r);
		let updated = MultilinearQuery::<P>::with_capacity(3).update(&r).unwrap();

		assert_eq!(expanded.n_vars(), updated.n_vars());
		assert_eq!(
			PackedField::iter_slice(expanded.expansion()).collect_vec(),
			PackedField::iter_slice(updated.expansion()).collect_vec()
		);
	}

	/// The borrowed view exposes the same variables and trimmed expansion as the owner.
	#[test]
	fn test_to_ref_exposes_same_expansion() {
		let query = MultilinearQuery::<P>::with_capacity(4)
			.update(&[F::new(2), F::new(3), F::new(5)])
			.unwrap();
		let query_ref = query.to_ref();

		assert_eq!(query_ref.n_vars(), 3);
		// The buffer has room for 4 variables, but 3 variables fit in 2 packed elements of width 4.
		assert_eq!(query_ref.expansion().len(), 2);
		assert_eq!(
			PackedField::iter_slice(query_ref.expansion()).collect_vec(),
			PackedField::iter_slice(query.expansion()).collect_vec()
		);
	}

	#[test]
	fn test_with_expansion_rejects_short_buffer() {
		// Three variables need 8 scalars, i.e. two packed elements of width 4.
		let result = MultilinearQuery::<P>::with_expansion(3, vec![P::default(); 1]);
		assert!(matches!(result, Err(Error::IncorrectArgumentLength { .. })));
	}

	#[test]
	fn test_update_exceeds_capacity() {
		let query = MultilinearQuery::<P>::with_capacity(2);
		// More than allowed capacity
		let extra_query = [F::new(2), F::new(3), F::new(5)];

		let result = query.update(&extra_query);
		// Expecting an error due to exceeding max_query_vars
		assert!(matches!(result, Err(Error::MultilinearQueryFull { .. })));
	}

	#[test]
	fn test_update_empty() {
		let query = MultilinearQuery::<P>::with_capacity(2);
		// Updating with no new coordinates should be fine
		let updated_query = query.update(&[]).unwrap();

		assert_eq!(updated_query.n_vars(), 0);

		let expansion = updated_query.expansion();
		let expansion = PackedField::iter_slice(expansion).collect_vec();

		assert_eq!(expansion, vec![F::ONE, F::ZERO, F::ZERO, F::ZERO]);
	}
}