binius_core/oracle/
multilinear.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{array, sync::Arc};
4
5use binius_fast_compute::arith_circuit::ArithCircuitPoly;
6use binius_field::{BinaryField128b, Field, TowerField};
7use binius_macros::{DeserializeBytes, SerializeBytes};
8use binius_math::ArithCircuit;
9use binius_utils::{
10	DeserializeBytes, SerializationError, SerializationMode, SerializeBytes, bail, ensure,
11};
12use getset::{CopyGetters, Getters};
13
14use crate::{
15	oracle::{CompositePolyOracle, Error, OracleId},
16	polynomial::{Error as PolynomialError, IdentityCompositionPoly, MultivariatePoly},
17};
18
/// Meta struct that lets you attach an optional `name` to a multilinear oracle before adding it
/// to the [`MultilinearOracleSet`].
///
/// Obtained from [`MultilinearOracleSet::add`] (unnamed) or [`MultilinearOracleSet::add_named`].
/// Each builder method consumes the addition and, on success, inserts one oracle
/// (`committed_multiple` inserts `N`). A method that returns an error inserts nothing.
pub struct MultilinearOracleSetAddition<'a, F: TowerField> {
	/// Optional name recorded on the oracle(s) being added.
	name: Option<String>,
	/// The oracle set the new oracle(s) are inserted into.
	mut_ref: &'a mut MultilinearOracleSet<F>,
}
25
26impl<F: TowerField> MultilinearOracleSetAddition<'_, F> {
27	pub fn transparent(self, poly: impl MultivariatePoly<F> + 'static) -> Result<OracleId, Error> {
28		if poly.binary_tower_level() > F::TOWER_LEVEL {
29			bail!(Error::TowerLevelTooHigh {
30				tower_level: poly.binary_tower_level(),
31			});
32		}
33
34		let inner = TransparentPolyOracle::new(Arc::new(poly))?;
35
36		let oracle = |id: OracleId| MultilinearPolyOracle {
37			id,
38			n_vars: inner.poly.n_vars(),
39			tower_level: inner.poly.binary_tower_level(),
40			name: self.name,
41			variant: MultilinearPolyVariant::Transparent(inner),
42		};
43
44		Ok(self.mut_ref.add_to_set(oracle))
45	}
46
47	pub fn structured(self, n_vars: usize, expr: ArithCircuit<F>) -> Result<OracleId, Error> {
48		if expr.binary_tower_level() > F::TOWER_LEVEL {
49			bail!(Error::TowerLevelTooHigh {
50				tower_level: expr.binary_tower_level(),
51			});
52		}
53
54		let oracle = |id: OracleId| MultilinearPolyOracle {
55			id,
56			n_vars,
57			tower_level: expr.binary_tower_level(),
58			name: self.name,
59			variant: MultilinearPolyVariant::Structured(expr),
60		};
61
62		Ok(self.mut_ref.add_to_set(oracle))
63	}
64
65	pub fn committed(mut self, n_vars: usize, tower_level: usize) -> OracleId {
66		let name = self.name.take();
67		self.add_committed_with_name(n_vars, tower_level, name)
68	}
69
70	pub fn committed_multiple<const N: usize>(
71		mut self,
72		n_vars: usize,
73		tower_level: usize,
74	) -> [OracleId; N] {
75		match &self.name.take() {
76			None => [0; N].map(|_| self.add_committed_with_name(n_vars, tower_level, None)),
77			Some(s) => {
78				let x: [usize; N] = array::from_fn(|i| i);
79				x.map(|i| {
80					self.add_committed_with_name(n_vars, tower_level, Some(format!("{s}_{i}")))
81				})
82			}
83		}
84	}
85
86	pub fn repeating(self, inner_id: OracleId, log_count: usize) -> Result<OracleId, Error> {
87		ensure!(self.mut_ref.is_valid_oracle_id(inner_id), Error::InvalidOracleId(inner_id));
88
89		let inner = &self.mut_ref[inner_id];
90
91		let tower_level = inner.tower_level;
92		let n_vars = inner.n_vars + log_count;
93		let oracle = |id: OracleId| MultilinearPolyOracle {
94			id,
95			n_vars,
96			tower_level,
97			name: self.name,
98			variant: MultilinearPolyVariant::Repeating {
99				id: inner_id,
100				log_count,
101			},
102		};
103
104		Ok(self.mut_ref.add_to_set(oracle))
105	}
106
107	pub fn shifted(
108		self,
109		inner_id: OracleId,
110		offset: usize,
111		block_bits: usize,
112		variant: ShiftVariant,
113	) -> Result<OracleId, Error> {
114		ensure!(self.mut_ref.is_valid_oracle_id(inner_id), Error::InvalidOracleId(inner_id));
115
116		let n_vars = self.mut_ref.n_vars(inner_id);
117		let tower_level = self.mut_ref.tower_level(inner_id);
118
119		if block_bits > n_vars {
120			bail!(PolynomialError::InvalidBlockSize { n_vars });
121		}
122
123		if offset == 0 || offset >= 1 << block_bits {
124			bail!(PolynomialError::InvalidShiftOffset {
125				max_shift_offset: (1 << block_bits) - 1,
126				shift_offset: offset,
127			});
128		}
129
130		let shifted = Shifted::new(self.mut_ref, inner_id, offset, block_bits, variant)?;
131
132		let oracle = |id: OracleId| MultilinearPolyOracle {
133			id,
134			n_vars,
135			tower_level,
136			name: self.name,
137			variant: MultilinearPolyVariant::Shifted(shifted),
138		};
139
140		Ok(self.mut_ref.add_to_set(oracle))
141	}
142
143	pub fn packed(self, inner_id: OracleId, log_degree: usize) -> Result<OracleId, Error> {
144		ensure!(self.mut_ref.is_valid_oracle_id(inner_id), Error::InvalidOracleId(inner_id));
145
146		let inner_n_vars = self.mut_ref.n_vars(inner_id);
147		if log_degree > inner_n_vars {
148			bail!(Error::NotEnoughVarsForPacking {
149				n_vars: inner_n_vars,
150				log_degree,
151			});
152		}
153
154		let inner_tower_level = self.mut_ref.tower_level(inner_id);
155
156		let packed = Packed {
157			id: inner_id,
158			log_degree,
159		};
160
161		let oracle = |id: OracleId| MultilinearPolyOracle {
162			id,
163			n_vars: inner_n_vars - log_degree,
164			tower_level: inner_tower_level + log_degree,
165			name: self.name,
166			variant: MultilinearPolyVariant::Packed(packed),
167		};
168
169		Ok(self.mut_ref.add_to_set(oracle))
170	}
171
172	pub fn projected(
173		self,
174		inner_id: OracleId,
175		values: Vec<F>,
176		start_index: usize,
177	) -> Result<OracleId, Error> {
178		ensure!(self.mut_ref.is_valid_oracle_id(inner_id), Error::InvalidOracleId(inner_id));
179
180		let inner_n_vars = self.mut_ref.n_vars(inner_id);
181		let values_len = values.len();
182		if values_len > inner_n_vars {
183			bail!(Error::InvalidProjection {
184				n_vars: inner_n_vars,
185				values_len,
186			});
187		}
188
189		// TODO: This is wrong, should be F::TOWER_LEVEL
190		let tower_level = self.mut_ref.tower_level(inner_id);
191		let projected = Projected::new(self.mut_ref, inner_id, values, start_index)?;
192
193		let oracle = |id: OracleId| MultilinearPolyOracle {
194			id,
195			n_vars: inner_n_vars - values_len,
196			tower_level,
197			name: self.name,
198			variant: MultilinearPolyVariant::Projected(projected),
199		};
200
201		Ok(self.mut_ref.add_to_set(oracle))
202	}
203
204	pub fn projected_last_vars(
205		self,
206		inner_id: OracleId,
207		values: Vec<F>,
208	) -> Result<OracleId, Error> {
209		ensure!(self.mut_ref.is_valid_oracle_id(inner_id), Error::InvalidOracleId(inner_id));
210
211		let inner_n_vars = self.mut_ref.n_vars(inner_id);
212		let start_index = inner_n_vars - values.len();
213		let values_len = values.len();
214		if values_len > inner_n_vars {
215			bail!(Error::InvalidProjection {
216				n_vars: inner_n_vars,
217				values_len,
218			});
219		}
220
221		// TODO: This is wrong, should be F::TOWER_LEVEL
222		let tower_level = self.mut_ref.tower_level(inner_id);
223		let projected = Projected::new(self.mut_ref, inner_id, values, start_index)?;
224
225		let oracle = |id: OracleId| MultilinearPolyOracle {
226			id,
227			n_vars: inner_n_vars - values_len,
228			tower_level,
229			name: self.name,
230			variant: MultilinearPolyVariant::Projected(projected),
231		};
232
233		Ok(self.mut_ref.add_to_set(oracle))
234	}
235
236	pub fn linear_combination(
237		self,
238		n_vars: usize,
239		inner: impl IntoIterator<Item = (OracleId, F)>,
240	) -> Result<OracleId, Error> {
241		self.linear_combination_with_offset(n_vars, F::ZERO, inner)
242	}
243
244	pub fn linear_combination_with_offset(
245		self,
246		n_vars: usize,
247		offset: F,
248		inner: impl IntoIterator<Item = (OracleId, F)>,
249	) -> Result<OracleId, Error> {
250		let inner = inner.into_iter().collect::<Vec<_>>();
251		let tower_level = inner
252			.iter()
253			.map(|(oracle, coeff)| {
254				self.mut_ref
255					.tower_level(*oracle)
256					.max(coeff.min_tower_level())
257			})
258			.max()
259			.unwrap_or(0)
260			.max(offset.min_tower_level());
261		let linear_combination = LinearCombination::new(self.mut_ref, n_vars, offset, inner)?;
262
263		let oracle = |id: OracleId| MultilinearPolyOracle {
264			id,
265			n_vars,
266			tower_level,
267			name: self.name,
268			variant: MultilinearPolyVariant::LinearCombination(linear_combination),
269		};
270
271		Ok(self.mut_ref.add_to_set(oracle))
272	}
273
274	pub fn composite_mle(
275		self,
276		n_vars: usize,
277		inner: impl IntoIterator<Item = OracleId>,
278		comp: ArithCircuit<F>,
279	) -> Result<OracleId, Error> {
280		let inner: Vec<OracleId> = inner.into_iter().collect();
281		let tower_level = inner
282			.iter()
283			.map(|&oracle| self.mut_ref[oracle].binary_tower_level())
284			.max()
285			.unwrap_or(0);
286		let composite_mle = CompositeMLE::new(self.mut_ref, n_vars, inner, comp)?;
287
288		let oracle = |id: OracleId| MultilinearPolyOracle {
289			id,
290			n_vars,
291			tower_level,
292			name: self.name,
293			variant: MultilinearPolyVariant::Composite(composite_mle),
294		};
295
296		Ok(self.mut_ref.add_to_set(oracle))
297	}
298
299	pub fn zero_padded(
300		self,
301		inner_id: OracleId,
302		n_pad_vars: usize,
303		nonzero_index: usize,
304		start_index: usize,
305	) -> Result<OracleId, Error> {
306		let inner_n_vars = self.mut_ref.n_vars(inner_id);
307		let tower_level = self.mut_ref.tower_level(inner_id);
308		if start_index > inner_n_vars {
309			bail!(Error::InvalidStartIndex {
310				expected: inner_n_vars
311			});
312		}
313		let padded =
314			ZeroPadded::new(self.mut_ref, inner_id, n_pad_vars, nonzero_index, start_index)?;
315
316		let oracle = |id: OracleId| MultilinearPolyOracle {
317			id,
318			n_vars: inner_n_vars + n_pad_vars,
319			tower_level,
320			name: self.name,
321			variant: MultilinearPolyVariant::ZeroPadded(padded),
322		};
323
324		Ok(self.mut_ref.add_to_set(oracle))
325	}
326
327	fn add_committed_with_name(
328		&mut self,
329		n_vars: usize,
330		tower_level: usize,
331		name: Option<String>,
332	) -> OracleId {
333		let oracle = |oracle_id: OracleId| MultilinearPolyOracle {
334			id: oracle_id,
335			n_vars,
336			tower_level,
337			name: name.clone(),
338			variant: MultilinearPolyVariant::Committed,
339		};
340
341		self.mut_ref.add_to_set(oracle)
342	}
343}
344
/// An ordered set of multilinear polynomial oracles.
///
/// The multilinear polynomial oracles form a directed acyclic graph, where each multilinear oracle
/// is either transparent, committed, or derived from one or more others. Each oracle is assigned a
/// unique `OracleId`, which is its index in the internal slot vector.
///
/// Slots may be reserved without an oracle (see `add_skip`). Such slots keep their `OracleId`
/// but are excluded from `size` and from iteration.
///
/// NOTE(review): earlier docs said this set also tracks committed polynomials in batches. This
/// struct stores only `oracles` and `size`, so any batching presumably lives elsewhere — confirm.
#[derive(Default, Debug, Clone, SerializeBytes, DeserializeBytes)]
#[deserialize_bytes(eval_generics(F = BinaryField128b))]
pub struct MultilinearOracleSet<F: TowerField> {
	/// The vector of oracles, indexed by `OracleId::index()`.
	///
	/// During sizing, an oracle could be skipped. In which case, the entry corresponding to the
	/// oracle will be `None`.
	oracles: Vec<Option<MultilinearPolyOracle<F>>>,
	/// The number of non-`None` entries in `oracles`.
	size: usize,
}
364
365impl<F: TowerField> MultilinearOracleSet<F> {
366	pub const fn new() -> Self {
367		Self {
368			oracles: Vec::new(),
369			size: 0,
370		}
371	}
372
373	pub fn size(&self) -> usize {
374		self.size
375	}
376
377	pub fn polys(&self) -> impl Iterator<Item = &MultilinearPolyOracle<F>> + '_ {
378		self.iter().map(|(_, poly)| poly)
379	}
380
381	pub fn ids(&self) -> impl Iterator<Item = OracleId> {
382		self.iter().map(|(id, _)| id)
383	}
384
385	pub fn iter(&self) -> impl Iterator<Item = (OracleId, &MultilinearPolyOracle<F>)> + '_ {
386		(0..self.oracles.len()).filter_map(|index| match self.oracles[index] {
387			Some(ref oracle) => {
388				let oracle_id = OracleId::from_index(index);
389				Some((oracle_id, oracle))
390			}
391			None => None,
392		})
393	}
394
395	pub const fn add(&mut self) -> MultilinearOracleSetAddition<F> {
396		MultilinearOracleSetAddition {
397			name: None,
398			mut_ref: self,
399		}
400	}
401
402	pub fn add_named<S: ToString>(&mut self, s: S) -> MultilinearOracleSetAddition<F> {
403		MultilinearOracleSetAddition {
404			name: Some(s.to_string()),
405			mut_ref: self,
406		}
407	}
408
409	pub fn is_valid_oracle_id(&self, id: OracleId) -> bool {
410		id.index() < self.oracles.len()
411	}
412
413	pub(crate) fn add_to_set(
414		&mut self,
415		oracle: impl FnOnce(OracleId) -> MultilinearPolyOracle<F>,
416	) -> OracleId {
417		let id = OracleId::from_index(self.oracles.len());
418		self.oracles.push(Some(oracle(id)));
419		self.size += 1;
420		id
421	}
422
423	/// Instead of adding a concrete oracle adds a skip mark, essentially just reserving an
424	/// [`OracleId`]. Accessing such oracle, with some exceptions, is an error.
425	pub(crate) fn add_skip(&mut self) {
426		self.oracles.push(None);
427		// don't increment size
428	}
429
430	pub fn add_transparent(
431		&mut self,
432		poly: impl MultivariatePoly<F> + 'static,
433	) -> Result<OracleId, Error> {
434		self.add().transparent(poly)
435	}
436
437	pub fn add_committed(&mut self, n_vars: usize, tower_level: usize) -> OracleId {
438		self.add().committed(n_vars, tower_level)
439	}
440
441	pub fn add_committed_multiple<const N: usize>(
442		&mut self,
443		n_vars: usize,
444		tower_level: usize,
445	) -> [OracleId; N] {
446		self.add().committed_multiple(n_vars, tower_level)
447	}
448
449	pub fn add_repeating(&mut self, id: OracleId, log_count: usize) -> Result<OracleId, Error> {
450		self.add().repeating(id, log_count)
451	}
452
453	pub fn add_shifted(
454		&mut self,
455		id: OracleId,
456		offset: usize,
457		block_bits: usize,
458		variant: ShiftVariant,
459	) -> Result<OracleId, Error> {
460		self.add().shifted(id, offset, block_bits, variant)
461	}
462
463	pub fn add_packed(&mut self, id: OracleId, log_degree: usize) -> Result<OracleId, Error> {
464		self.add().packed(id, log_degree)
465	}
466
467	/// Adds a projection to the variables starting at `start_index`.
468	pub fn add_projected(
469		&mut self,
470		id: OracleId,
471		values: Vec<F>,
472		start_index: usize,
473	) -> Result<OracleId, Error> {
474		self.add().projected(id, values, start_index)
475	}
476
477	/// Adds a projection to the last variables.
478	pub fn add_projected_last_vars(
479		&mut self,
480		id: OracleId,
481		values: Vec<F>,
482	) -> Result<OracleId, Error> {
483		self.add().projected_last_vars(id, values)
484	}
485
486	pub fn add_linear_combination(
487		&mut self,
488		n_vars: usize,
489		inner: impl IntoIterator<Item = (OracleId, F)>,
490	) -> Result<OracleId, Error> {
491		self.add().linear_combination(n_vars, inner)
492	}
493
494	pub fn add_linear_combination_with_offset(
495		&mut self,
496		n_vars: usize,
497		offset: F,
498		inner: impl IntoIterator<Item = (OracleId, F)>,
499	) -> Result<OracleId, Error> {
500		self.add()
501			.linear_combination_with_offset(n_vars, offset, inner)
502	}
503
504	pub fn add_zero_padded(
505		&mut self,
506		id: OracleId,
507		n_pad_vars: usize,
508		nonzero_index: usize,
509		start_index: usize,
510	) -> Result<OracleId, Error> {
511		self.add()
512			.zero_padded(id, n_pad_vars, nonzero_index, start_index)
513	}
514
515	pub fn add_composite_mle(
516		&mut self,
517		n_vars: usize,
518		inner: impl IntoIterator<Item = OracleId>,
519		comp: ArithCircuit<F>,
520	) -> Result<OracleId, Error> {
521		self.add().composite_mle(n_vars, inner, comp)
522	}
523
524	pub fn n_vars(&self, id: OracleId) -> usize {
525		self[id].n_vars()
526	}
527
528	pub fn label(&self, id: OracleId) -> String {
529		self[id].label()
530	}
531
532	/// Maximum tower level of the oracle's values over the boolean hypercube.
533	pub fn tower_level(&self, id: OracleId) -> usize {
534		self[id].binary_tower_level()
535	}
536
537	/// Returns `true` if the given [`OracleId`] refers to an oracle that was skipped during the
538	/// instantiation of the symbolic multilinear oracle set.
539	pub fn is_zero_sized(&self, id: OracleId) -> bool {
540		self.oracles[id.index()].is_none()
541	}
542}
543
544impl<F: TowerField> std::ops::Index<OracleId> for MultilinearOracleSet<F> {
545	type Output = MultilinearPolyOracle<F>;
546
547	fn index(&self, id: OracleId) -> &Self::Output {
548		self.oracles[id.index()]
549			.as_ref()
550			.expect("tried to access skipped oracle")
551	}
552}
553
/// A multilinear polynomial oracle in the polynomial IOP model.
///
/// In the multilinear polynomial IOP model, a prover sends multilinear polynomials to an oracle,
/// and the verifier may at the end of the protocol query their evaluations at chosen points. An
/// oracle is a verifier and prover's shared view of a polynomial that can be queried for
/// evaluations by the verifier.
///
/// There are three fundamental categories of oracles:
///
/// 1. *Transparent oracles*. These are multilinear polynomials with a succinct description and
///    evaluation algorithm that are known to the verifier. When the verifier queries a transparent
///    oracle, it evaluates the polynomial itself.
/// 2. *Committed oracles*. These are polynomials actually sent by the prover. When the polynomial
///    IOP is compiled to an interactive protocol, these polynomials are committed with a
///    polynomial commitment scheme.
/// 3. *Virtual oracles*. A virtual multilinear oracle is not actually sent by the prover, but
///    instead admits an interactive reduction from evaluation queries on it to evaluation queries
///    on other oracles. This is formalized in [DP23] Section 4.
///
/// [DP23]: <https://eprint.iacr.org/2023/1784>
#[derive(Debug, Clone, PartialEq, Eq, SerializeBytes, DeserializeBytes)]
#[deserialize_bytes(eval_generics(F = BinaryField128b))]
pub struct MultilinearPolyOracle<F: TowerField> {
	/// Id assigned by the owning [`MultilinearOracleSet`] when this oracle was added.
	pub id: OracleId,
	/// Optional human-readable name; used by `label()`.
	pub name: Option<String>,
	/// Number of variables.
	pub n_vars: usize,
	/// Maximum tower level of the oracle's values over the boolean hypercube.
	pub tower_level: usize,
	/// How the oracle is defined: committed, transparent, or derived from other oracles.
	pub variant: MultilinearPolyVariant<F>,
}
583
/// The definition of a [`MultilinearPolyOracle`]: committed, transparent, or derived from other
/// oracles in the same set.
///
/// The hand-written `DeserializeBytes` impl for `MultilinearPolyVariant<BinaryField128b>` in this
/// module maps tags `0..=9` to these variants in declaration order. Do not reorder variants
/// without updating it (and, presumably, the derived `SerializeBytes` encoding).
#[derive(Debug, Clone, PartialEq, Eq, SerializeBytes)]
pub enum MultilinearPolyVariant<F: TowerField> {
	/// A polynomial committed by the prover.
	Committed,
	/// A polynomial with a succinct description that the verifier evaluates itself.
	Transparent(TransparentPolyOracle<F>),
	/// A structured virtual polynomial is one that can be evaluated succinctly by a verifier.
	///
	/// These are referred to as "MLE-structured" tables in [Lasso]. The evaluation algorithm is
	/// expressed as an arithmetic circuit, of polynomial size in the number of variables.
	///
	/// [Lasso]: <https://eprint.iacr.org/2023/1216>
	Structured(ArithCircuit<F>),
	/// Oracle `id` repeated over `log_count` additional variables.
	Repeating {
		/// The repeated inner oracle.
		id: OracleId,
		/// Number of variables added by the repetition.
		log_count: usize,
	},
	/// Another oracle with some of its variables fixed to given values.
	Projected(Projected<F>),
	/// A blockwise shift of another oracle.
	Shifted(Shifted),
	/// Another oracle with some variables packed into a field extension.
	Packed(Packed),
	/// A linear combination of other oracles.
	LinearCombination(LinearCombination<F>),
	/// Another oracle extended with zero-padding variables.
	ZeroPadded(ZeroPadded),
	/// The MLE of a composition applied to other oracles.
	Composite(CompositeMLE<F>),
}
606
607impl<F: TowerField> MultilinearPolyVariant<F> {
608	/// Returns true if this multilinear polynomial variant is committed, as opposed to virtual.
609	pub fn is_committed(&self) -> bool {
610		matches!(self, Self::Committed)
611	}
612}
613
614impl DeserializeBytes for MultilinearPolyVariant<BinaryField128b> {
615	fn deserialize(
616		mut buf: impl bytes::Buf,
617		mode: SerializationMode,
618	) -> Result<Self, SerializationError>
619	where
620		Self: Sized,
621	{
622		Ok(match u8::deserialize(&mut buf, mode)? {
623			0 => Self::Committed,
624			1 => Self::Transparent(DeserializeBytes::deserialize(buf, mode)?),
625			2 => Self::Structured(DeserializeBytes::deserialize(buf, mode)?),
626			3 => Self::Repeating {
627				id: DeserializeBytes::deserialize(&mut buf, mode)?,
628				log_count: DeserializeBytes::deserialize(buf, mode)?,
629			},
630			4 => Self::Projected(DeserializeBytes::deserialize(buf, mode)?),
631			5 => Self::Shifted(DeserializeBytes::deserialize(buf, mode)?),
632			6 => Self::Packed(DeserializeBytes::deserialize(buf, mode)?),
633			7 => Self::LinearCombination(DeserializeBytes::deserialize(buf, mode)?),
634			8 => Self::ZeroPadded(DeserializeBytes::deserialize(buf, mode)?),
635			9 => Self::Composite(DeserializeBytes::deserialize(buf, mode)?),
636			variant_index => {
637				return Err(SerializationError::UnknownEnumVariant {
638					name: "MultilinearPolyVariant",
639					index: variant_index,
640				});
641			}
642		})
643	}
644}
645
/// A transparent multilinear polynomial oracle.
///
/// See the [`MultilinearPolyOracle`] documentation for context. Equality between two
/// `TransparentPolyOracle`s is pointer identity of the shared `poly` (see its `PartialEq` impl).
#[derive(Debug, Clone, Getters, CopyGetters)]
pub struct TransparentPolyOracle<F: Field> {
	/// The succinct polynomial, shared behind an `Arc` so clones are cheap.
	#[get = "pub"]
	poly: Arc<dyn MultivariatePoly<F>>,
}
654
655impl<F: TowerField> SerializeBytes for TransparentPolyOracle<F> {
656	fn serialize(
657		&self,
658		mut write_buf: impl bytes::BufMut,
659		mode: SerializationMode,
660	) -> Result<(), SerializationError> {
661		self.poly.erased_serialize(&mut write_buf, mode)
662	}
663}
664
665impl DeserializeBytes for TransparentPolyOracle<BinaryField128b> {
666	fn deserialize(
667		read_buf: impl bytes::Buf,
668		mode: SerializationMode,
669	) -> Result<Self, SerializationError>
670	where
671		Self: Sized,
672	{
673		Ok(Self {
674			poly: Box::<dyn MultivariatePoly<BinaryField128b>>::deserialize(read_buf, mode)?.into(),
675		})
676	}
677}
678
679impl<F: TowerField> TransparentPolyOracle<F> {
680	pub fn new(poly: Arc<dyn MultivariatePoly<F>>) -> Result<Self, Error> {
681		if poly.binary_tower_level() > F::TOWER_LEVEL {
682			bail!(Error::TowerLevelTooHigh {
683				tower_level: poly.binary_tower_level(),
684			});
685		}
686		Ok(Self { poly })
687	}
688}
689
690impl<F: Field> TransparentPolyOracle<F> {
691	/// Maximum tower level of the oracle's values over the boolean hypercube.
692	pub fn binary_tower_level(&self) -> usize {
693		self.poly.binary_tower_level()
694	}
695}
696
697impl<F: Field> PartialEq for TransparentPolyOracle<F> {
698	fn eq(&self, other: &Self) -> bool {
699		Arc::ptr_eq(&self.poly, &other.poly)
700	}
701}
702
703impl<F: Field> Eq for TransparentPolyOracle<F> {}
704
/// A virtual oracle obtained by fixing some variables of another oracle to given values.
///
/// [`Projected::new`] checks that `values.len() + start_index` does not exceed the inner
/// oracle's `n_vars`.
#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)]
pub struct Projected<F: TowerField> {
	/// The inner oracle whose variables are fixed.
	#[get_copy = "pub"]
	id: OracleId,
	/// Values assigned to the projected variables, one per fixed variable.
	#[get = "pub"]
	values: Vec<F>,
	/// Index of the first fixed variable; the `values.len()` variables starting here are fixed.
	#[get_copy = "pub"]
	start_index: usize,
}
714
715impl<F: TowerField> Projected<F> {
716	pub fn new(
717		oracles: &MultilinearOracleSet<F>,
718		id: OracleId,
719		values: Vec<F>,
720		start_index: usize,
721	) -> Result<Self, Error> {
722		if values.len() + start_index > oracles.n_vars(id) {
723			bail!(Error::InvalidProjection {
724				n_vars: oracles.n_vars(id),
725				values_len: values.len()
726			});
727		}
728		Ok(Self {
729			id,
730			values,
731			start_index,
732		})
733	}
734}
735
/// A virtual oracle that extends another oracle with `n_pad_vars` extra variables.
///
/// NOTE(review): presumably the result equals the inner oracle where the padding variables
/// encode `nonzero_index` and is zero elsewhere — confirm against the evaluation code.
#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)]
pub struct ZeroPadded {
	/// The inner oracle being padded.
	#[get_copy = "pub"]
	id: OracleId,
	/// Number of padding variables added.
	#[get_copy = "pub"]
	n_pad_vars: usize,
	/// Which of the `2^n_pad_vars` padding positions holds the inner oracle.
	#[get_copy = "pub"]
	nonzero_index: usize,
	/// Variable index at which the padding variables are inserted (at most the inner `n_vars`).
	#[get_copy = "pub"]
	start_index: usize,
}
747
748impl ZeroPadded {
749	pub fn new<F: TowerField>(
750		oracles: &MultilinearOracleSet<F>,
751		id: OracleId,
752		n_pad_vars: usize,
753		nonzero_index: usize,
754		start_index: usize,
755	) -> Result<Self, Error> {
756		let oracle = &oracles[id];
757		if start_index > oracle.n_vars() {
758			bail!(Error::InvalidStartIndex {
759				expected: oracle.n_vars(),
760			});
761		}
762
763		if nonzero_index > 1 << n_pad_vars {
764			bail!(Error::InvalidNonzeroIndex {
765				expected: 1 << n_pad_vars,
766			});
767		}
768
769		Ok(Self {
770			id,
771			n_pad_vars,
772			nonzero_index,
773			start_index,
774		})
775	}
776}
777
/// The kind of shift applied by a [`Shifted`] oracle within each block.
#[derive(Debug, Copy, Clone, PartialEq, Eq, SerializeBytes, DeserializeBytes)]
pub enum ShiftVariant {
	/// Left rotation; presumably entries shifted out of a block wrap around to its other end.
	CircularLeft,
	/// Left shift; presumably vacated entries are zero-filled.
	LogicalLeft,
	/// Right shift; presumably vacated entries are zero-filled.
	LogicalRight,
}
784
/// A virtual oracle obtained by shifting another oracle within fixed-size blocks.
///
/// [`Shifted::new`] enforces `block_size <= n_vars` of the inner oracle and
/// `1 <= shift_offset < 2^block_size`.
#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)]
pub struct Shifted {
	/// The inner oracle being shifted.
	#[get_copy = "pub"]
	id: OracleId,
	/// Shift amount, in `1..2^block_size`.
	#[get_copy = "pub"]
	shift_offset: usize,
	/// Block size as a number of variables (the shift acts within blocks of `2^block_size`
	/// hypercube points); called `block_bits` by the builder API.
	#[get_copy = "pub"]
	block_size: usize,
	/// Which kind of shift is applied.
	#[get_copy = "pub"]
	shift_variant: ShiftVariant,
}
796
797impl Shifted {
798	pub fn new<F: TowerField>(
799		oracles: &MultilinearOracleSet<F>,
800		id: OracleId,
801		shift_offset: usize,
802		block_size: usize,
803		shift_variant: ShiftVariant,
804	) -> Result<Self, Error> {
805		if block_size > oracles.n_vars(id) {
806			bail!(PolynomialError::InvalidBlockSize {
807				n_vars: oracles.n_vars(id),
808			});
809		}
810
811		if shift_offset == 0 || shift_offset >= 1 << block_size {
812			bail!(PolynomialError::InvalidShiftOffset {
813				max_shift_offset: (1 << block_size) - 1,
814				shift_offset,
815			});
816		}
817
818		Ok(Self {
819			id,
820			shift_offset,
821			block_size,
822			shift_variant,
823		})
824	}
825}
826
/// A virtual oracle that packs `log_degree` variables of another oracle into a field extension.
#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)]
pub struct Packed {
	/// The inner oracle being packed.
	#[get_copy = "pub"]
	id: OracleId,
	/// The number of tower levels increased by the packing operation.
	///
	/// This is the base 2 logarithm of the field extension, and is called $\kappa$ in [DP23],
	/// Section 4.3.
	///
	/// [DP23]: https://eprint.iacr.org/2023/1784
	#[get_copy = "pub"]
	log_degree: usize,
}
840
841impl Packed {
842	pub fn new(id: OracleId, log_degree: usize) -> Self {
843		Self { id, log_degree }
844	}
845}
846
/// A virtual oracle combining other oracles with per-oracle coefficients and a constant `offset`.
///
/// NOTE(review): presumably this evaluates to `offset + Σ coeff_i · M_i` — confirm against the
/// evaluation code.
#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)]
pub struct LinearCombination<F: TowerField> {
	/// Number of variables; every inner oracle must have exactly this many.
	#[get_copy = "pub"]
	n_vars: usize,
	/// Constant term of the combination.
	#[get_copy = "pub"]
	offset: F,
	/// `(oracle, coefficient)` pairs, in insertion order.
	inner: Vec<(OracleId, F)>,
}
855
856impl<F: TowerField> LinearCombination<F> {
857	pub fn new(
858		oracles: &MultilinearOracleSet<F>,
859		n_vars: usize,
860		offset: F,
861		inner: Vec<(OracleId, F)>,
862	) -> Result<Self, Error> {
863		for (oracle, _) in &inner {
864			if oracles[*oracle].n_vars() != n_vars {
865				return Err(Error::IncorrectNumberOfVariables { expected: n_vars });
866			}
867		}
868		Ok(Self {
869			n_vars,
870			offset,
871			inner,
872		})
873	}
874
875	pub fn n_polys(&self) -> usize {
876		self.inner.len()
877	}
878
879	pub fn polys(&self) -> impl Iterator<Item = OracleId> + '_ {
880		self.inner.iter().map(|(id, _)| *id)
881	}
882
883	pub fn coefficients(&self) -> impl Iterator<Item = F> + '_ {
884		self.inner.iter().map(|(_, coeff)| *coeff)
885	}
886}
887
/// MLE of a multivariate polynomial evaluated on multilinear oracles.
///
/// That is, the MLE of the evaluations of $C(M_1, M_2, \ldots, M_n)$ on $\{0, 1\}^\mu$ where:
/// - $C$ is an arbitrary polynomial in $n$ variables
/// - $M_1, M_2, \ldots, M_n$ are multilinear oracles in $\mu$ variables
///
/// ($C$ should be sufficiently lightweight to be evaluated by the verifier)
#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, DeserializeBytes, SerializeBytes)]
pub struct CompositeMLE<F: TowerField> {
	/// $\mu$
	#[get_copy = "pub"]
	n_vars: usize,
	/// $M_1, M_2, \ldots, M_n$
	#[getset(get = "pub")]
	inner: Vec<OracleId>,
	/// $C$
	#[getset(get = "pub")]
	c: ArithCircuitPoly<F>,
}
907
908impl<F: TowerField> CompositeMLE<F> {
909	pub fn new(
910		oracles: &MultilinearOracleSet<F>,
911		n_vars: usize,
912		inner: Vec<OracleId>,
913		c: ArithCircuit<F>,
914	) -> Result<Self, Error> {
915		for &oracle in &inner {
916			if oracles.n_vars(oracle) != n_vars {
917				return Err(Error::IncorrectNumberOfVariables { expected: n_vars });
918			}
919		}
920		let c = ArithCircuitPoly::with_n_vars(inner.len(), c)
921			.map_err(|_| Error::CompositionMismatch)?; // occurs if `c` has more variables than `inner.len()`
922		Ok(Self { n_vars, inner, c })
923	}
924
925	pub fn polys(&self) -> impl Iterator<Item = OracleId> + '_ {
926		self.inner.iter().copied()
927	}
928
929	pub fn n_polys(&self) -> usize {
930		self.inner.len()
931	}
932}
933
934impl<F: TowerField> MultilinearPolyOracle<F> {
935	pub const fn id(&self) -> OracleId {
936		self.id
937	}
938
939	pub fn label(&self) -> String {
940		match self.name() {
941			Some(name) => format!("{}: {}", self.type_str(), name),
942			None => format!("{}: id={}", self.type_str(), self.id()),
943		}
944	}
945
946	pub fn name(&self) -> Option<&str> {
947		self.name.as_deref()
948	}
949
950	const fn type_str(&self) -> &str {
951		match self.variant {
952			MultilinearPolyVariant::Transparent(_) => "Transparent",
953			MultilinearPolyVariant::Structured(_) => "Structured",
954			MultilinearPolyVariant::Committed => "Committed",
955			MultilinearPolyVariant::Repeating { .. } => "Repeating",
956			MultilinearPolyVariant::Projected(_) => "Projected",
957			MultilinearPolyVariant::Shifted(_) => "Shifted",
958			MultilinearPolyVariant::Packed(_) => "Packed",
959			MultilinearPolyVariant::LinearCombination(_) => "LinearCombination",
960			MultilinearPolyVariant::ZeroPadded(_) => "ZeroPadded",
961			MultilinearPolyVariant::Composite(_) => "CompositeMLE",
962		}
963	}
964
965	pub const fn n_vars(&self) -> usize {
966		self.n_vars
967	}
968
969	/// Maximum tower level of the oracle's values over the boolean hypercube.
970	pub const fn binary_tower_level(&self) -> usize {
971		self.tower_level
972	}
973
974	pub fn into_composite(self) -> CompositePolyOracle<F> {
975		let composite =
976			CompositePolyOracle::new(self.n_vars(), vec![self], IdentityCompositionPoly);
977		composite.expect("Can always apply the identity composition to one variable")
978	}
979}
980
#[cfg(test)]
mod tests {
	use binius_field::{BinaryField1b, BinaryField128b, Field, TowerField};

	use super::MultilinearOracleSet;

	/// Projecting every variable of a 5-variate committed oracle yields a 0-variate oracle.
	#[test]
	fn add_projection_with_all_vars() {
		type F = BinaryField128b;
		let mut oracles = MultilinearOracleSet::<F>::new();
		let data = oracles.add_committed(5, BinaryField1b::TOWER_LEVEL);
		let start_index = 0;
		let projected = oracles
			.add_projected(data, vec![F::ONE; 5], start_index)
			.unwrap();
		let _ = &oracles[projected];
		assert_eq!(oracles.n_vars(projected), 0);
		assert_eq!(oracles.size(), 2);
	}

	/// Projections that do not fit in the inner oracle's variables are rejected, not panicked on.
	#[test]
	fn add_projection_out_of_range_is_rejected() {
		type F = BinaryField128b;
		let mut oracles = MultilinearOracleSet::<F>::new();
		let data = oracles.add_committed(5, BinaryField1b::TOWER_LEVEL);

		// More values than variables.
		assert!(oracles.add_projected(data, vec![F::ONE; 6], 0).is_err());
		// Fits by length, but `start_index + values.len()` overruns the variables.
		assert!(oracles.add_projected(data, vec![F::ONE; 2], 4).is_err());
		// Failed additions must not insert anything.
		assert_eq!(oracles.size(), 1);
	}
}