binius_math/
multilinear_query.rs

1// Copyright 2023-2025 Irreducible Inc.
2
3use std::{cmp::max, ops::DerefMut};
4
5use binius_field::{Field, PackedField};
6use binius_utils::bail;
7use bytemuck::zeroed_vec;
8
9use crate::{eq_ind_partial_eval, tensor_prod_eq_ind, Error};
10
11/// Tensor product expansion of sumcheck round challenges.
12///
13/// Stores the tensor product expansion $\bigotimes_{i = 0}^{n - 1} (1 - r_i, r_i)$
14/// when `round()` is `n` for the sequence of sumcheck challenges $(r_0, ..., r_{n-1})$.
15/// The tensor product can be updated with a new round challenge in linear time.
16/// This is used in the first several rounds of the sumcheck prover for small-field polynomials,
17/// before it becomes more efficient to switch over to the method that store folded multilinears.
18#[derive(Debug)]
19pub struct MultilinearQuery<P, Data = Vec<P>>
20where
21	P: PackedField,
22	Data: DerefMut<Target = [P]>,
23{
24	n_vars: usize,
25	expanded_query: Data,
26}
27
28/// Wraps `MultilinearQuery` to hide `Data` from the users.
29#[derive(Debug, Clone, Copy)]
30pub struct MultilinearQueryRef<'a, P: PackedField> {
31	n_vars: usize,
32	expanded_query: &'a [P],
33}
34
35impl<'a, P: PackedField, Data: DerefMut<Target = [P]>> From<&'a MultilinearQuery<P, Data>>
36	for MultilinearQueryRef<'a, P>
37{
38	fn from(query: &'a MultilinearQuery<P, Data>) -> Self {
39		Self::new(query)
40	}
41}
42
43impl<'a, P: PackedField> MultilinearQueryRef<'a, P> {
44	pub fn new<Data: DerefMut<Target = [P]>>(query: &'a MultilinearQuery<P, Data>) -> Self {
45		Self {
46			n_vars: query.n_vars,
47			expanded_query: &query.expanded_query,
48		}
49	}
50
51	pub const fn n_vars(&self) -> usize {
52		self.n_vars
53	}
54
55	/// Returns the tensor product expansion of the query
56	///
57	/// If the number of query variables is less than the packing width, return a single packed element.
58	pub fn expansion(&self) -> &[P] {
59		let expanded_query_len = 1 << self.n_vars.saturating_sub(P::LOG_WIDTH);
60		&self.expanded_query[0..expanded_query_len]
61	}
62}
63
64impl<P: PackedField> MultilinearQuery<P, Vec<P>> {
65	pub fn with_capacity(max_query_vars: usize) -> Self {
66		let len = 1 << max_query_vars.saturating_sub(P::LOG_WIDTH);
67		let mut expanded_query = zeroed_vec::<P>(len);
68		expanded_query[0].set(0, P::Scalar::ONE);
69		Self {
70			expanded_query,
71			n_vars: 0,
72		}
73	}
74
75	pub fn expand(query: &[P::Scalar]) -> Self {
76		let expanded_query = eq_ind_partial_eval(query);
77		Self {
78			expanded_query,
79			n_vars: query.len(),
80		}
81	}
82}
83
84impl<P: PackedField, Data: DerefMut<Target = [P]>> MultilinearQuery<P, Data> {
85	pub fn with_expansion(n_vars: usize, expanded_query: Data) -> Result<Self, Error> {
86		let expected_len = 1 << n_vars.saturating_sub(P::LOG_WIDTH);
87		if expanded_query.len() < expected_len {
88			bail!(Error::IncorrectArgumentLength {
89				arg: "expanded_query".to_string(),
90				expected: expected_len,
91			});
92		}
93		Ok(Self {
94			n_vars,
95			expanded_query,
96		})
97	}
98
99	pub const fn n_vars(&self) -> usize {
100		self.n_vars
101	}
102
103	/// Returns the tensor product expansion of the query
104	///
105	/// If the number of query variables is less than the packing width, return a single packed element.
106	pub fn expansion(&self) -> &[P] {
107		let expanded_query_len = 1 << self.n_vars.saturating_sub(P::LOG_WIDTH);
108		&self.expanded_query[0..expanded_query_len]
109	}
110
111	// REVIEW: this method is a temporary hack to allow the
112	// construction of a "multilinear query" which contains Lagrange
113	// coefficient evaluations in UnivariateZerocheck::fold_univariate_round
114	pub fn expansion_mut(&mut self) -> &mut [P] {
115		let expanded_query_len = 1 << self.n_vars.saturating_sub(P::LOG_WIDTH);
116		&mut self.expanded_query[0..expanded_query_len]
117	}
118
119	pub fn into_expansion(self) -> Data {
120		self.expanded_query
121	}
122
123	pub fn update(mut self, extra_query_coordinates: &[P::Scalar]) -> Result<Self, Error> {
124		let old_n_vars = self.n_vars;
125		let new_n_vars = old_n_vars + extra_query_coordinates.len();
126		let new_length = max((1 << new_n_vars) / P::WIDTH, 1);
127		if new_length > self.expanded_query.len() {
128			bail!(Error::MultilinearQueryFull {
129				max_query_vars: old_n_vars,
130			});
131		}
132		tensor_prod_eq_ind(
133			old_n_vars,
134			&mut self.expanded_query[..new_length],
135			extra_query_coordinates,
136		)?;
137
138		Ok(Self {
139			n_vars: new_n_vars,
140			expanded_query: self.expanded_query,
141		})
142	}
143
144	pub fn to_ref(&self) -> MultilinearQueryRef<P> {
145		self.into()
146	}
147}
148
#[cfg(test)]
mod tests {
	use binius_field::{Field, PackedBinaryField4x32b, PackedField};
	use binius_utils::felts;
	use itertools::Itertools;

	use super::*;
	use crate::tensor_prod_eq_ind;

	type P = PackedBinaryField4x32b;
	type F = <P as PackedField>::Scalar;

	/// Reference expansion of `coords`. It starts from the expansion of the empty query (a
	/// single one in the first scalar slot) and applies `tensor_prod_eq_ind`.
	fn tensor_prod<P: PackedField>(coords: &[P::Scalar]) -> Vec<P> {
		let packed_len = 1 << coords.len().saturating_sub(P::LOG_WIDTH);
		let mut out = vec![P::default(); packed_len];
		out[0] = P::set_single(P::Scalar::ONE);
		tensor_prod_eq_ind(0, &mut out, coords).unwrap();
		out
	}

	/// Wraps the reference expansion of `coords` in a `MultilinearQuery` packed as `P`.
	/// Returns the scalars of its `expansion()`.
	fn expansion_scalars<P: PackedField>(coords: &[P::Scalar]) -> Vec<P::Scalar> {
		let query =
			MultilinearQuery::<P, _>::with_expansion(coords.len(), tensor_prod::<P>(coords))
				.unwrap();
		PackedField::iter_slice(query.expansion()).collect()
	}

	macro_rules! expand_query {
		($f:ident[$($elem:expr),* $(,)?], Packing=$p:ident) => {
			expansion_scalars::<$p>(&[$($f::new($elem)),*])
		};
	}

	#[test]
	fn test_query_no_packing_32b() {
		use binius_field::BinaryField32b;

		// A scalar field is its own packing of width one, so no zero padding appears.
		assert_eq!(
			expand_query!(BinaryField32b[], Packing = BinaryField32b),
			felts!(BinaryField32b[1])
		);
		assert_eq!(
			expand_query!(BinaryField32b[2], Packing = BinaryField32b),
			felts!(BinaryField32b[3, 2])
		);
		assert_eq!(
			expand_query!(BinaryField32b[2, 2], Packing = BinaryField32b),
			felts!(BinaryField32b[2, 1, 1, 3])
		);
		assert_eq!(
			expand_query!(BinaryField32b[2, 2, 2], Packing = BinaryField32b),
			felts!(BinaryField32b[1, 3, 3, 2, 3, 2, 2, 1])
		);
		assert_eq!(
			expand_query!(BinaryField32b[2, 2, 2, 2], Packing = BinaryField32b),
			felts!(BinaryField32b[3, 2, 2, 1, 2, 1, 1, 3, 2, 1, 1, 3, 1, 3, 3, 2])
		);
	}

	#[test]
	fn test_query_packing_4x32b() {
		use binius_field::BinaryField32b;

		// Expansions shorter than four scalars are padded to one full packed element.
		assert_eq!(
			expand_query!(BinaryField32b[], Packing = PackedBinaryField4x32b),
			felts!(BinaryField32b[1, 0, 0, 0])
		);
		assert_eq!(
			expand_query!(BinaryField32b[2], Packing = PackedBinaryField4x32b),
			felts!(BinaryField32b[3, 2, 0, 0])
		);
		assert_eq!(
			expand_query!(BinaryField32b[2, 2], Packing = PackedBinaryField4x32b),
			felts!(BinaryField32b[2, 1, 1, 3])
		);
		assert_eq!(
			expand_query!(BinaryField32b[2, 2, 2], Packing = PackedBinaryField4x32b),
			felts!(BinaryField32b[1, 3, 3, 2, 3, 2, 2, 1])
		);
		assert_eq!(
			expand_query!(BinaryField32b[2, 2, 2, 2], Packing = PackedBinaryField4x32b),
			felts!(BinaryField32b[3, 2, 2, 1, 2, 1, 1, 3, 2, 1, 1, 3, 1, 3, 3, 2])
		);
	}

	#[test]
	fn test_query_packing_8x16b() {
		use binius_field::{BinaryField16b, PackedBinaryField8x16b};

		// Expansions shorter than eight scalars are padded to one full packed element.
		assert_eq!(
			expand_query!(BinaryField16b[], Packing = PackedBinaryField8x16b),
			felts!(BinaryField16b[1, 0, 0, 0, 0, 0, 0, 0])
		);
		assert_eq!(
			expand_query!(BinaryField16b[2], Packing = PackedBinaryField8x16b),
			felts!(BinaryField16b[3, 2, 0, 0, 0, 0, 0, 0])
		);
		assert_eq!(
			expand_query!(BinaryField16b[2, 2], Packing = PackedBinaryField8x16b),
			felts!(BinaryField16b[2, 1, 1, 3, 0, 0, 0, 0])
		);
		assert_eq!(
			expand_query!(BinaryField16b[2, 2, 2], Packing = PackedBinaryField8x16b),
			felts!(BinaryField16b[1, 3, 3, 2, 3, 2, 2, 1])
		);
		assert_eq!(
			expand_query!(BinaryField16b[2, 2, 2, 2], Packing = PackedBinaryField8x16b),
			felts!(BinaryField16b[3, 2, 2, 1, 2, 1, 1, 3, 2, 1, 1, 3, 1, 3, 3, 2])
		);
	}

	#[test]
	fn test_update_single_var() {
		let r0 = F::new(2);

		let query = MultilinearQuery::<P>::with_capacity(2)
			.update(&[r0])
			.unwrap();
		assert_eq!(query.n_vars(), 1);

		// `into_expansion` hands back the whole buffer, including the unused trailing scalars.
		let buffer = query.into_expansion();
		let scalars = PackedField::iter_slice(&buffer).collect_vec();
		assert_eq!(scalars, vec![F::ONE - r0, r0, F::ZERO, F::ZERO]);
	}

	#[test]
	fn test_update_two_vars() {
		let (r0, r1) = (F::new(2), F::new(3));

		let query = MultilinearQuery::<P>::with_capacity(3)
			.update(&[r0, r1])
			.unwrap();
		assert_eq!(query.n_vars(), 2);

		let scalars = PackedField::iter_slice(query.expansion()).collect_vec();
		let (not_r0, not_r1) = (F::ONE - r0, F::ONE - r1);
		assert_eq!(scalars, vec![not_r0 * not_r1, r0 * not_r1, not_r0 * r1, r0 * r1]);
	}

	#[test]
	fn test_update_three_vars() {
		let coords = [F::new(2), F::new(3), F::new(5)];

		let query = MultilinearQuery::<P>::with_capacity(4)
			.update(&coords)
			.unwrap();
		assert_eq!(query.n_vars(), 3);

		let scalars = PackedField::iter_slice(query.expansion()).collect_vec();

		// Entry `index` is the product over `i` of `r_i` when bit `i` of `index` is set,
		// and of `1 - r_i` otherwise.
		let expected = (0..1usize << coords.len())
			.map(|index| {
				coords
					.iter()
					.enumerate()
					.map(|(bit, &r)| if (index >> bit) & 1 == 1 { r } else { F::ONE - r })
					.fold(F::ONE, |acc, factor| acc * factor)
			})
			.collect_vec();
		assert_eq!(scalars, expected);
	}

	#[test]
	fn test_update_exceeds_capacity() {
		// The capacity is two variables, but three coordinates are supplied.
		let too_many = [F::new(2), F::new(3), F::new(5)];

		let outcome = MultilinearQuery::<P>::with_capacity(2).update(&too_many);
		assert!(outcome.is_err());
	}

	#[test]
	fn test_update_empty() {
		// Supplying no coordinates succeeds and leaves the query at zero variables.
		let query = MultilinearQuery::<P>::with_capacity(2).update(&[]).unwrap();
		assert_eq!(query.n_vars(), 0);

		let scalars = PackedField::iter_slice(query.expansion()).collect_vec();
		assert_eq!(scalars, vec![F::ONE, F::ZERO, F::ZERO, F::ZERO]);
	}
}