// binius_core/oracle/multilinear.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{array, sync::Arc};
4
5use binius_fast_compute::arith_circuit::ArithCircuitPoly;
6use binius_field::{BinaryField128b, Field, TowerField};
7use binius_macros::{DeserializeBytes, SerializeBytes};
8use binius_math::ArithCircuit;
9use binius_utils::{
10	DeserializeBytes, SerializationError, SerializationMode, SerializeBytes, bail, ensure,
11};
12use getset::{CopyGetters, Getters};
13
14use crate::{
15	oracle::{CompositePolyOracle, Error, OracleId},
16	polynomial::{Error as PolynomialError, IdentityCompositionPoly, MultivariatePoly},
17};
18
/// Builder, returned by [`MultilinearOracleSet::add`] and [`MultilinearOracleSet::add_named`],
/// that carries an optional `name` for the next multilinear oracle before it is added to the
/// [`MultilinearOracleSet`].
///
/// Every method consumes the builder. On success it appends the new oracle(s) to the borrowed set
/// and returns the assigned id(s).
pub struct MultilinearOracleSetAddition<'a, F: TowerField> {
	/// Optional human-readable name attached to the oracle being added.
	name: Option<String>,
	/// The set the oracle is appended to.
	mut_ref: &'a mut MultilinearOracleSet<F>,
}
25
26impl<F: TowerField> MultilinearOracleSetAddition<'_, F> {
27	pub fn transparent(self, poly: impl MultivariatePoly<F> + 'static) -> Result<OracleId, Error> {
28		if poly.binary_tower_level() > F::TOWER_LEVEL {
29			bail!(Error::TowerLevelTooHigh {
30				tower_level: poly.binary_tower_level(),
31			});
32		}
33
34		let inner = TransparentPolyOracle::new(Arc::new(poly))?;
35
36		let oracle = |id: OracleId| MultilinearPolyOracle {
37			id,
38			n_vars: inner.poly.n_vars(),
39			tower_level: inner.poly.binary_tower_level(),
40			name: self.name,
41			variant: MultilinearPolyVariant::Transparent(inner),
42		};
43
44		Ok(self.mut_ref.add_to_set(oracle))
45	}
46
47	pub fn structured(self, n_vars: usize, expr: ArithCircuit<F>) -> Result<OracleId, Error> {
48		if expr.binary_tower_level() > F::TOWER_LEVEL {
49			bail!(Error::TowerLevelTooHigh {
50				tower_level: expr.binary_tower_level(),
51			});
52		}
53
54		let oracle = |id: OracleId| MultilinearPolyOracle {
55			id,
56			n_vars,
57			tower_level: expr.binary_tower_level(),
58			name: self.name,
59			variant: MultilinearPolyVariant::Structured(expr),
60		};
61
62		Ok(self.mut_ref.add_to_set(oracle))
63	}
64
65	pub fn committed(mut self, n_vars: usize, tower_level: usize) -> OracleId {
66		let name = self.name.take();
67		self.add_committed_with_name(n_vars, tower_level, name)
68	}
69
70	pub fn committed_multiple<const N: usize>(
71		mut self,
72		n_vars: usize,
73		tower_level: usize,
74	) -> [OracleId; N] {
75		match &self.name.take() {
76			None => [0; N].map(|_| self.add_committed_with_name(n_vars, tower_level, None)),
77			Some(s) => {
78				let x: [usize; N] = array::from_fn(|i| i);
79				x.map(|i| {
80					self.add_committed_with_name(n_vars, tower_level, Some(format!("{s}_{i}")))
81				})
82			}
83		}
84	}
85
86	pub fn repeating(self, inner_id: OracleId, log_count: usize) -> Result<OracleId, Error> {
87		ensure!(self.mut_ref.is_valid_oracle_id(inner_id), Error::InvalidOracleId(inner_id));
88
89		let inner = &self.mut_ref[inner_id];
90
91		let tower_level = inner.tower_level;
92		let n_vars = inner.n_vars + log_count;
93		let oracle = |id: OracleId| MultilinearPolyOracle {
94			id,
95			n_vars,
96			tower_level,
97			name: self.name,
98			variant: MultilinearPolyVariant::Repeating {
99				id: inner_id,
100				log_count,
101			},
102		};
103
104		Ok(self.mut_ref.add_to_set(oracle))
105	}
106
107	pub fn shifted(
108		self,
109		inner_id: OracleId,
110		offset: usize,
111		block_bits: usize,
112		variant: ShiftVariant,
113	) -> Result<OracleId, Error> {
114		ensure!(self.mut_ref.is_valid_oracle_id(inner_id), Error::InvalidOracleId(inner_id));
115
116		let inner = &self.mut_ref[inner_id];
117		if block_bits > inner.n_vars {
118			bail!(PolynomialError::InvalidBlockSize {
119				n_vars: inner.n_vars,
120			});
121		}
122
123		if offset == 0 || offset >= 1 << block_bits {
124			bail!(PolynomialError::InvalidShiftOffset {
125				max_shift_offset: (1 << block_bits) - 1,
126				shift_offset: offset,
127			});
128		}
129
130		let shifted = Shifted::new(inner, offset, block_bits, variant)?;
131
132		let tower_level = inner.tower_level;
133		let n_vars = inner.n_vars;
134		let oracle = |id: OracleId| MultilinearPolyOracle {
135			id,
136			n_vars,
137			tower_level,
138			name: self.name,
139			variant: MultilinearPolyVariant::Shifted(shifted),
140		};
141
142		Ok(self.mut_ref.add_to_set(oracle))
143	}
144
145	pub fn packed(self, inner_id: OracleId, log_degree: usize) -> Result<OracleId, Error> {
146		ensure!(self.mut_ref.is_valid_oracle_id(inner_id), Error::InvalidOracleId(inner_id));
147
148		let inner_n_vars = self.mut_ref.n_vars(inner_id);
149		if log_degree > inner_n_vars {
150			bail!(Error::NotEnoughVarsForPacking {
151				n_vars: inner_n_vars,
152				log_degree,
153			});
154		}
155
156		let inner_tower_level = self.mut_ref.tower_level(inner_id);
157
158		let packed = Packed {
159			id: inner_id,
160			log_degree,
161		};
162
163		let oracle = |id: OracleId| MultilinearPolyOracle {
164			id,
165			n_vars: inner_n_vars - log_degree,
166			tower_level: inner_tower_level + log_degree,
167			name: self.name,
168			variant: MultilinearPolyVariant::Packed(packed),
169		};
170
171		Ok(self.mut_ref.add_to_set(oracle))
172	}
173
174	pub fn projected(
175		self,
176		inner_id: OracleId,
177		values: Vec<F>,
178		start_index: usize,
179	) -> Result<OracleId, Error> {
180		ensure!(self.mut_ref.is_valid_oracle_id(inner_id), Error::InvalidOracleId(inner_id));
181
182		let inner_n_vars = self.mut_ref.n_vars(inner_id);
183		let values_len = values.len();
184		if values_len > inner_n_vars {
185			bail!(Error::InvalidProjection {
186				n_vars: inner_n_vars,
187				values_len,
188			});
189		}
190
191		let inner = &self.mut_ref[inner_id];
192		// TODO: This is wrong, should be F::TOWER_LEVEL
193		let tower_level = inner.binary_tower_level();
194		let projected = Projected::new(inner, values, start_index)?;
195
196		let oracle = |id: OracleId| MultilinearPolyOracle {
197			id,
198			n_vars: inner_n_vars - values_len,
199			tower_level,
200			name: self.name,
201			variant: MultilinearPolyVariant::Projected(projected),
202		};
203
204		Ok(self.mut_ref.add_to_set(oracle))
205	}
206
207	pub fn projected_last_vars(
208		self,
209		inner_id: OracleId,
210		values: Vec<F>,
211	) -> Result<OracleId, Error> {
212		ensure!(self.mut_ref.is_valid_oracle_id(inner_id), Error::InvalidOracleId(inner_id));
213
214		let inner_n_vars = self.mut_ref.n_vars(inner_id);
215		let start_index = inner_n_vars - values.len();
216		let values_len = values.len();
217		if values_len > inner_n_vars {
218			bail!(Error::InvalidProjection {
219				n_vars: inner_n_vars,
220				values_len,
221			});
222		}
223
224		let inner = &self.mut_ref[inner_id];
225		// TODO: This is wrong, should be F::TOWER_LEVEL
226		let tower_level = inner.binary_tower_level();
227		let projected = Projected::new(inner, values, start_index)?;
228
229		let oracle = |id: OracleId| MultilinearPolyOracle {
230			id,
231			n_vars: inner_n_vars - values_len,
232			tower_level,
233			name: self.name,
234			variant: MultilinearPolyVariant::Projected(projected),
235		};
236
237		Ok(self.mut_ref.add_to_set(oracle))
238	}
239
240	pub fn linear_combination(
241		self,
242		n_vars: usize,
243		inner: impl IntoIterator<Item = (OracleId, F)>,
244	) -> Result<OracleId, Error> {
245		self.linear_combination_with_offset(n_vars, F::ZERO, inner)
246	}
247
248	pub fn linear_combination_with_offset(
249		self,
250		n_vars: usize,
251		offset: F,
252		inner: impl IntoIterator<Item = (OracleId, F)>,
253	) -> Result<OracleId, Error> {
254		let inner = inner
255			.into_iter()
256			.map(|(inner_id, coeff)| {
257				ensure!(
258					self.mut_ref.is_valid_oracle_id(inner_id),
259					Error::InvalidOracleId(inner_id)
260				);
261				if self.mut_ref.n_vars(inner_id) != n_vars {
262					return Err(Error::IncorrectNumberOfVariables { expected: n_vars });
263				}
264				Ok((self.mut_ref[inner_id].clone(), coeff))
265			})
266			.collect::<Result<Vec<_>, _>>()?;
267
268		let tower_level = inner
269			.iter()
270			.map(|(oracle, coeff)| oracle.binary_tower_level().max(coeff.min_tower_level()))
271			.max()
272			.unwrap_or(0)
273			.max(offset.min_tower_level());
274
275		let linear_combination = LinearCombination::new(n_vars, offset, inner)?;
276
277		let oracle = |id: OracleId| MultilinearPolyOracle {
278			id,
279			n_vars,
280			tower_level,
281			name: self.name,
282			variant: MultilinearPolyVariant::LinearCombination(linear_combination),
283		};
284
285		Ok(self.mut_ref.add_to_set(oracle))
286	}
287
288	pub fn composite_mle(
289		self,
290		n_vars: usize,
291		inner: impl IntoIterator<Item = OracleId>,
292		comp: ArithCircuit<F>,
293	) -> Result<OracleId, Error> {
294		let inner = inner
295			.into_iter()
296			.map(|inner_id| {
297				ensure!(
298					self.mut_ref.is_valid_oracle_id(inner_id),
299					Error::InvalidOracleId(inner_id)
300				);
301				if self.mut_ref.n_vars(inner_id) != n_vars {
302					return Err(Error::IncorrectNumberOfVariables { expected: n_vars });
303				}
304				Ok(self.mut_ref[inner_id].clone())
305			})
306			.collect::<Result<Vec<_>, _>>()?;
307
308		let tower_level = inner
309			.iter()
310			.map(|oracle| oracle.binary_tower_level())
311			.max()
312			.unwrap_or(0);
313
314		let composite_mle = CompositeMLE::new(n_vars, inner, comp)?;
315
316		let oracle = |id: OracleId| MultilinearPolyOracle {
317			id,
318			n_vars,
319			tower_level,
320			name: self.name,
321			variant: MultilinearPolyVariant::Composite(composite_mle),
322		};
323
324		Ok(self.mut_ref.add_to_set(oracle))
325	}
326
327	pub fn zero_padded(
328		self,
329		inner_id: OracleId,
330		n_pad_vars: usize,
331		nonzero_index: usize,
332		start_index: usize,
333	) -> Result<OracleId, Error> {
334		let inner_n_vars = self.mut_ref.n_vars(inner_id);
335		if start_index > inner_n_vars {
336			bail!(Error::InvalidStartIndex {
337				expected: inner_n_vars
338			});
339		}
340
341		let inner = &self.mut_ref[inner_id];
342		let tower_level = inner.binary_tower_level();
343		let padded = ZeroPadded::new(inner, n_pad_vars, nonzero_index, start_index)?;
344
345		let oracle = |id: OracleId| MultilinearPolyOracle {
346			id,
347			n_vars: inner_n_vars + n_pad_vars,
348			tower_level,
349			name: self.name,
350			variant: MultilinearPolyVariant::ZeroPadded(padded),
351		};
352
353		Ok(self.mut_ref.add_to_set(oracle))
354	}
355
356	fn add_committed_with_name(
357		&mut self,
358		n_vars: usize,
359		tower_level: usize,
360		name: Option<String>,
361	) -> OracleId {
362		let oracle = |oracle_id: OracleId| MultilinearPolyOracle {
363			id: oracle_id,
364			n_vars,
365			tower_level,
366			name: name.clone(),
367			variant: MultilinearPolyVariant::Committed,
368		};
369
370		self.mut_ref.add_to_set(oracle)
371	}
372}
373
/// An ordered set of multilinear polynomial oracles.
///
/// The multilinear polynomial oracles form a directed acyclic graph. Each multilinear oracle is
/// either transparent, committed, or derived from one or more others. Each oracle is assigned a
/// unique `OracleId`.
///
/// Oracles are stored in insertion order, and a new oracle's `OracleId` is built from its position
/// in that order (see `add_to_set`).
///
/// Serialization is derived. Deserialization is only instantiated for `F = BinaryField128b`, as
/// declared by the `eval_generics` attribute.
#[derive(Default, Debug, Clone, SerializeBytes, DeserializeBytes)]
#[deserialize_bytes(eval_generics(F = BinaryField128b))]
pub struct MultilinearOracleSet<F: TowerField> {
	/// All oracles, indexed by `OracleId::index()`.
	oracles: Vec<MultilinearPolyOracle<F>>,
}
387
388impl<F: TowerField> MultilinearOracleSet<F> {
389	pub const fn new() -> Self {
390		Self {
391			oracles: Vec::new(),
392		}
393	}
394
395	pub fn size(&self) -> usize {
396		self.oracles.len()
397	}
398
399	pub fn polys(&self) -> impl Iterator<Item = &MultilinearPolyOracle<F>> + '_ {
400		(0..self.oracles.len()).map(|index| &self[OracleId::from_index(index)])
401	}
402
403	pub fn ids(&self) -> impl Iterator<Item = OracleId> {
404		(0..self.oracles.len()).map(OracleId::from_index)
405	}
406
407	pub fn iter(&self) -> impl Iterator<Item = (OracleId, &MultilinearPolyOracle<F>)> + '_ {
408		(0..self.oracles.len()).map(|index| {
409			let oracle_id = OracleId::from_index(index);
410			(oracle_id, &self[oracle_id])
411		})
412	}
413
414	pub const fn add(&mut self) -> MultilinearOracleSetAddition<F> {
415		MultilinearOracleSetAddition {
416			name: None,
417			mut_ref: self,
418		}
419	}
420
421	pub fn add_named<S: ToString>(&mut self, s: S) -> MultilinearOracleSetAddition<F> {
422		MultilinearOracleSetAddition {
423			name: Some(s.to_string()),
424			mut_ref: self,
425		}
426	}
427
428	pub fn is_valid_oracle_id(&self, id: OracleId) -> bool {
429		id.index() < self.oracles.len()
430	}
431
432	fn add_to_set(
433		&mut self,
434		oracle: impl FnOnce(OracleId) -> MultilinearPolyOracle<F>,
435	) -> OracleId {
436		let id = OracleId::from_index(self.oracles.len());
437		self.oracles.push(oracle(id));
438		id
439	}
440
441	pub fn add_transparent(
442		&mut self,
443		poly: impl MultivariatePoly<F> + 'static,
444	) -> Result<OracleId, Error> {
445		self.add().transparent(poly)
446	}
447
448	pub fn add_committed(&mut self, n_vars: usize, tower_level: usize) -> OracleId {
449		self.add().committed(n_vars, tower_level)
450	}
451
452	pub fn add_committed_multiple<const N: usize>(
453		&mut self,
454		n_vars: usize,
455		tower_level: usize,
456	) -> [OracleId; N] {
457		self.add().committed_multiple(n_vars, tower_level)
458	}
459
460	pub fn add_repeating(&mut self, id: OracleId, log_count: usize) -> Result<OracleId, Error> {
461		self.add().repeating(id, log_count)
462	}
463
464	pub fn add_shifted(
465		&mut self,
466		id: OracleId,
467		offset: usize,
468		block_bits: usize,
469		variant: ShiftVariant,
470	) -> Result<OracleId, Error> {
471		self.add().shifted(id, offset, block_bits, variant)
472	}
473
474	pub fn add_packed(&mut self, id: OracleId, log_degree: usize) -> Result<OracleId, Error> {
475		self.add().packed(id, log_degree)
476	}
477
478	/// Adds a projection to the variables starting at `start_index`.
479	pub fn add_projected(
480		&mut self,
481		id: OracleId,
482		values: Vec<F>,
483		start_index: usize,
484	) -> Result<OracleId, Error> {
485		self.add().projected(id, values, start_index)
486	}
487
488	/// Adds a projection to the last variables.
489	pub fn add_projected_last_vars(
490		&mut self,
491		id: OracleId,
492		values: Vec<F>,
493	) -> Result<OracleId, Error> {
494		self.add().projected_last_vars(id, values)
495	}
496
497	pub fn add_linear_combination(
498		&mut self,
499		n_vars: usize,
500		inner: impl IntoIterator<Item = (OracleId, F)>,
501	) -> Result<OracleId, Error> {
502		self.add().linear_combination(n_vars, inner)
503	}
504
505	pub fn add_linear_combination_with_offset(
506		&mut self,
507		n_vars: usize,
508		offset: F,
509		inner: impl IntoIterator<Item = (OracleId, F)>,
510	) -> Result<OracleId, Error> {
511		self.add()
512			.linear_combination_with_offset(n_vars, offset, inner)
513	}
514
515	pub fn add_zero_padded(
516		&mut self,
517		id: OracleId,
518		n_pad_vars: usize,
519		nonzero_index: usize,
520		start_index: usize,
521	) -> Result<OracleId, Error> {
522		self.add()
523			.zero_padded(id, n_pad_vars, nonzero_index, start_index)
524	}
525
526	pub fn add_composite_mle(
527		&mut self,
528		n_vars: usize,
529		inner: impl IntoIterator<Item = OracleId>,
530		comp: ArithCircuit<F>,
531	) -> Result<OracleId, Error> {
532		self.add().composite_mle(n_vars, inner, comp)
533	}
534
535	pub fn n_vars(&self, id: OracleId) -> usize {
536		self[id].n_vars()
537	}
538
539	pub fn label(&self, id: OracleId) -> String {
540		self[id].label()
541	}
542
543	/// Maximum tower level of the oracle's values over the boolean hypercube.
544	pub fn tower_level(&self, id: OracleId) -> usize {
545		self[id].binary_tower_level()
546	}
547}
548
549impl<F: TowerField> std::ops::Index<OracleId> for MultilinearOracleSet<F> {
550	type Output = MultilinearPolyOracle<F>;
551
552	fn index(&self, id: OracleId) -> &Self::Output {
553		&self.oracles[id.index()]
554	}
555}
556
/// A multilinear polynomial oracle in the polynomial IOP model.
///
/// In the multilinear polynomial IOP model, a prover sends multilinear polynomials to an oracle,
/// and the verifier may at the end of the protocol query their evaluations at chosen points. An
/// oracle is a verifier and prover's shared view of a polynomial that can be queried for
/// evaluations by the verifier.
///
/// There are three fundamental categories of oracles:
///
/// 1. *Transparent oracles*. These are multilinear polynomials with a succinct description and
///    evaluation algorithm that are known to the verifier. When the verifier queries a transparent
///    oracle, it evaluates the polynomial itself.
/// 2. *Committed oracles*. These are polynomials actually sent by the prover. When the polynomial
///    IOP is compiled to an interactive protocol, these polynomials are committed with a
///    polynomial commitment scheme.
/// 3. *Virtual oracles*. A virtual multilinear oracle is not actually sent by the prover, but
///    instead admits an interactive reduction for evaluation queries to evaluation queries to other
///    oracles. This is formalized in [DP23] Section 4.
///
/// [DP23]: <https://eprint.iacr.org/2023/1784>
#[derive(Debug, Clone, PartialEq, Eq, SerializeBytes, DeserializeBytes)]
#[deserialize_bytes(eval_generics(F = BinaryField128b))]
pub struct MultilinearPolyOracle<F: TowerField> {
	/// Id of this oracle within its [`MultilinearOracleSet`].
	pub id: OracleId,
	/// Optional human-readable name, used by `label`.
	pub name: Option<String>,
	/// Number of variables.
	pub n_vars: usize,
	/// Maximum tower level of the oracle's values over the boolean hypercube.
	pub tower_level: usize,
	/// How the oracle is defined (committed, transparent, or derived from other oracles).
	pub variant: MultilinearPolyVariant<F>,
}
586
/// How a [`MultilinearPolyOracle`] is defined.
///
/// The hand-written `DeserializeBytes` impl maps tags `0..=8` to these variants in declaration
/// order, so do not reorder variants. It currently has no arm for `Composite`.
#[derive(Debug, Clone, PartialEq, Eq, SerializeBytes)]
pub enum MultilinearPolyVariant<F: TowerField> {
	/// Sent by the prover and committed with a polynomial commitment scheme.
	Committed,
	/// A polynomial with a succinct description that the verifier evaluates itself.
	Transparent(TransparentPolyOracle<F>),
	/// A structured virtual polynomial is one that can be evaluated succinctly by a verifier.
	///
	/// These are referred to as "MLE-structured" tables in [Lasso]. The evaluation algorithm is
	/// expressed as an arithmetic circuit, of polynomial size in the number of variables.
	///
	/// [Lasso]: <https://eprint.iacr.org/2023/1216>
	Structured(ArithCircuit<F>),
	/// Derived from oracle `id`, with `log_count` more variables than it.
	/// NOTE(review): presumably repeats `id`'s values `2^log_count` times — confirm.
	Repeating {
		id: OracleId,
		log_count: usize,
	},
	/// Another oracle with some consecutive variables fixed to given values.
	Projected(Projected<F>),
	/// A shifted view of another oracle.
	Shifted(Shifted),
	/// Another oracle packed into a higher tower level with fewer variables.
	Packed(Packed),
	/// A linear combination of other oracles plus a constant offset.
	LinearCombination(LinearCombination<F>),
	/// Another oracle embedded into additional zero-padding variables.
	ZeroPadded(ZeroPadded),
	/// The MLE of an arithmetic circuit applied to other oracles.
	Composite(CompositeMLE<F>),
}

impl<F: TowerField> MultilinearPolyVariant<F> {
	/// Returns true if this multilinear polynomial variant is committed, as opposed to virtual.
	pub fn is_committed(&self) -> bool {
		matches!(self, Self::Committed)
	}
}
616
617impl DeserializeBytes for MultilinearPolyVariant<BinaryField128b> {
618	fn deserialize(
619		mut buf: impl bytes::Buf,
620		mode: SerializationMode,
621	) -> Result<Self, SerializationError>
622	where
623		Self: Sized,
624	{
625		Ok(match u8::deserialize(&mut buf, mode)? {
626			0 => Self::Committed,
627			1 => Self::Transparent(DeserializeBytes::deserialize(buf, mode)?),
628			2 => Self::Structured(DeserializeBytes::deserialize(buf, mode)?),
629			3 => Self::Repeating {
630				id: DeserializeBytes::deserialize(&mut buf, mode)?,
631				log_count: DeserializeBytes::deserialize(buf, mode)?,
632			},
633			4 => Self::Projected(DeserializeBytes::deserialize(buf, mode)?),
634			5 => Self::Shifted(DeserializeBytes::deserialize(buf, mode)?),
635			6 => Self::Packed(DeserializeBytes::deserialize(buf, mode)?),
636			7 => Self::LinearCombination(DeserializeBytes::deserialize(buf, mode)?),
637			8 => Self::ZeroPadded(DeserializeBytes::deserialize(buf, mode)?),
638			variant_index => {
639				return Err(SerializationError::UnknownEnumVariant {
640					name: "MultilinearPolyVariant",
641					index: variant_index,
642				});
643			}
644		})
645	}
646}
647
648/// A transparent multilinear polynomial oracle.
649///
650/// See the [`MultilinearPolyOracle`] documentation for context.
651#[derive(Debug, Clone, Getters, CopyGetters)]
652pub struct TransparentPolyOracle<F: Field> {
653	#[get = "pub"]
654	poly: Arc<dyn MultivariatePoly<F>>,
655}
656
657impl<F: TowerField> SerializeBytes for TransparentPolyOracle<F> {
658	fn serialize(
659		&self,
660		mut write_buf: impl bytes::BufMut,
661		mode: SerializationMode,
662	) -> Result<(), SerializationError> {
663		self.poly.erased_serialize(&mut write_buf, mode)
664	}
665}
666
667impl DeserializeBytes for TransparentPolyOracle<BinaryField128b> {
668	fn deserialize(
669		read_buf: impl bytes::Buf,
670		mode: SerializationMode,
671	) -> Result<Self, SerializationError>
672	where
673		Self: Sized,
674	{
675		Ok(Self {
676			poly: Box::<dyn MultivariatePoly<BinaryField128b>>::deserialize(read_buf, mode)?.into(),
677		})
678	}
679}
680
681impl<F: TowerField> TransparentPolyOracle<F> {
682	fn new(poly: Arc<dyn MultivariatePoly<F>>) -> Result<Self, Error> {
683		if poly.binary_tower_level() > F::TOWER_LEVEL {
684			bail!(Error::TowerLevelTooHigh {
685				tower_level: poly.binary_tower_level(),
686			});
687		}
688		Ok(Self { poly })
689	}
690}
691
692impl<F: Field> TransparentPolyOracle<F> {
693	/// Maximum tower level of the oracle's values over the boolean hypercube.
694	pub fn binary_tower_level(&self) -> usize {
695		self.poly.binary_tower_level()
696	}
697}
698
699impl<F: Field> PartialEq for TransparentPolyOracle<F> {
700	fn eq(&self, other: &Self) -> bool {
701		Arc::ptr_eq(&self.poly, &other.poly)
702	}
703}
704
705impl<F: Field> Eq for TransparentPolyOracle<F> {}
706
707#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)]
708pub struct Projected<F: TowerField> {
709	#[get_copy = "pub"]
710	id: OracleId,
711	#[get = "pub"]
712	values: Vec<F>,
713	#[get_copy = "pub"]
714	start_index: usize,
715}
716
717impl<F: TowerField> Projected<F> {
718	fn new(
719		oracle: &MultilinearPolyOracle<F>,
720		values: Vec<F>,
721		start_index: usize,
722	) -> Result<Self, Error> {
723		if values.len() + start_index > oracle.n_vars() {
724			bail!(Error::InvalidProjection {
725				n_vars: oracle.n_vars(),
726				values_len: values.len()
727			});
728		}
729		Ok(Self {
730			id: oracle.id(),
731			values,
732			start_index,
733		})
734	}
735}
736
737#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)]
738pub struct ZeroPadded {
739	#[get_copy = "pub"]
740	id: OracleId,
741	#[get_copy = "pub"]
742	n_pad_vars: usize,
743	#[get_copy = "pub"]
744	nonzero_index: usize,
745	#[get_copy = "pub"]
746	start_index: usize,
747}
748
749impl ZeroPadded {
750	fn new<F: TowerField>(
751		oracle: &MultilinearPolyOracle<F>,
752		n_pad_vars: usize,
753		nonzero_index: usize,
754		start_index: usize,
755	) -> Result<Self, Error> {
756		if start_index > oracle.n_vars() {
757			bail!(Error::InvalidStartIndex {
758				expected: oracle.n_vars(),
759			});
760		}
761
762		if nonzero_index > 1 << n_pad_vars {
763			bail!(Error::InvalidNonzeroIndex {
764				expected: 1 << n_pad_vars,
765			});
766		}
767
768		Ok(Self {
769			id: oracle.id(),
770			n_pad_vars,
771			nonzero_index,
772			start_index,
773		})
774	}
775}
776
/// Kind of shift applied by a [`Shifted`] oracle.
///
/// NOTE(review): the exact boundary semantics live in the evaluation code. The descriptions below
/// are inferred from the variant names.
#[derive(Debug, Copy, Clone, PartialEq, Eq, SerializeBytes, DeserializeBytes)]
pub enum ShiftVariant {
	/// Left rotation: presumably values shifted out of a block wrap around.
	CircularLeft,
	/// Left shift: presumably vacated positions are filled with zero.
	LogicalLeft,
	/// Right shift: presumably vacated positions are filled with zero.
	LogicalRight,
}
783
784#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)]
785pub struct Shifted {
786	#[get_copy = "pub"]
787	id: OracleId,
788	#[get_copy = "pub"]
789	shift_offset: usize,
790	#[get_copy = "pub"]
791	block_size: usize,
792	#[get_copy = "pub"]
793	shift_variant: ShiftVariant,
794}
795
796impl Shifted {
797	fn new<F: TowerField>(
798		oracle: &MultilinearPolyOracle<F>,
799		shift_offset: usize,
800		block_size: usize,
801		shift_variant: ShiftVariant,
802	) -> Result<Self, Error> {
803		if block_size > oracle.n_vars() {
804			bail!(PolynomialError::InvalidBlockSize {
805				n_vars: oracle.n_vars(),
806			});
807		}
808
809		if shift_offset == 0 || shift_offset >= 1 << block_size {
810			bail!(PolynomialError::InvalidShiftOffset {
811				max_shift_offset: (1 << block_size) - 1,
812				shift_offset,
813			});
814		}
815
816		Ok(Self {
817			id: oracle.id(),
818			shift_offset,
819			block_size,
820			shift_variant,
821		})
822	}
823}
824
/// Another oracle packed into a larger field.
///
/// `MultilinearOracleSetAddition::packed` gives the packed oracle `log_degree` fewer variables and
/// a tower level `log_degree` higher than the inner oracle.
#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)]
pub struct Packed {
	/// Id of the oracle being packed.
	#[get_copy = "pub"]
	id: OracleId,
	/// The number of tower levels increased by the packing operation.
	///
	/// This is the base 2 logarithm of the field extension, and is called $\kappa$ in [DP23],
	/// Section 4.3.
	///
	/// [DP23]: https://eprint.iacr.org/2023/1784
	#[get_copy = "pub"]
	log_degree: usize,
}
838
839#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)]
840pub struct LinearCombination<F: TowerField> {
841	#[get_copy = "pub"]
842	n_vars: usize,
843	#[get_copy = "pub"]
844	offset: F,
845	inner: Vec<(OracleId, F)>,
846}
847
848impl<F: TowerField> LinearCombination<F> {
849	fn new(
850		n_vars: usize,
851		offset: F,
852		inner: impl IntoIterator<Item = (MultilinearPolyOracle<F>, F)>,
853	) -> Result<Self, Error> {
854		let inner = inner
855			.into_iter()
856			.map(|(oracle, value)| {
857				if oracle.n_vars() == n_vars {
858					Ok((oracle.id(), value))
859				} else {
860					Err(Error::IncorrectNumberOfVariables { expected: n_vars })
861				}
862			})
863			.collect::<Result<Vec<_>, _>>()?;
864		Ok(Self {
865			n_vars,
866			offset,
867			inner,
868		})
869	}
870
871	pub fn n_polys(&self) -> usize {
872		self.inner.len()
873	}
874
875	pub fn polys(&self) -> impl Iterator<Item = OracleId> + '_ {
876		self.inner.iter().map(|(id, _)| *id)
877	}
878
879	pub fn coefficients(&self) -> impl Iterator<Item = F> + '_ {
880		self.inner.iter().map(|(_, coeff)| *coeff)
881	}
882}
883
884/// MLE of a multivariate polynomial evaluated on multilinear oracles.
885///
886/// i.e. the MLE of the evaluations of $C(M_1, M_2, \ldots, M_n)$ on $\{0, 1\}^\mu$ where:
887/// - $C$ is a arbitrary polynomial in $n$ variables
888/// - $M_1, M_2, \ldots, M_n$ are multilinear oracles in μ variables
889///
890/// ($C$ should be sufficiently lightweight to be evaluated by the verifier)
891#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes)]
892pub struct CompositeMLE<F: TowerField> {
893	/// $\mu$
894	#[get_copy = "pub"]
895	n_vars: usize,
896	/// $M_1, M_2, \ldots, M_n$
897	#[getset(get = "pub")]
898	inner: Vec<OracleId>,
899	/// $C$
900	#[getset(get = "pub")]
901	c: ArithCircuitPoly<F>,
902}
903
904impl<F: TowerField> CompositeMLE<F> {
905	pub fn new(
906		n_vars: usize,
907		inner: impl IntoIterator<Item = MultilinearPolyOracle<F>>,
908		c: ArithCircuit<F>,
909	) -> Result<Self, Error> {
910		let inner = inner
911			.into_iter()
912			.map(|oracle| {
913				if oracle.n_vars() == n_vars {
914					Ok(oracle.id())
915				} else {
916					Err(Error::IncorrectNumberOfVariables { expected: n_vars })
917				}
918			})
919			.collect::<Result<Vec<_>, _>>()?;
920		let c = ArithCircuitPoly::with_n_vars(inner.len(), c)
921			.map_err(|_| Error::CompositionMismatch)?; // occurs if `c` has more variables than `inner.len()`
922		Ok(Self { n_vars, inner, c })
923	}
924
925	pub fn polys(&self) -> impl Iterator<Item = OracleId> + '_ {
926		self.inner.iter().copied()
927	}
928
929	pub fn n_polys(&self) -> usize {
930		self.inner.len()
931	}
932}
933
934impl<F: TowerField> MultilinearPolyOracle<F> {
935	pub const fn id(&self) -> OracleId {
936		self.id
937	}
938
939	pub fn label(&self) -> String {
940		match self.name() {
941			Some(name) => format!("{}: {}", self.type_str(), name),
942			None => format!("{}: id={}", self.type_str(), self.id()),
943		}
944	}
945
946	pub fn name(&self) -> Option<&str> {
947		self.name.as_deref()
948	}
949
950	const fn type_str(&self) -> &str {
951		match self.variant {
952			MultilinearPolyVariant::Transparent(_) => "Transparent",
953			MultilinearPolyVariant::Structured(_) => "Structured",
954			MultilinearPolyVariant::Committed => "Committed",
955			MultilinearPolyVariant::Repeating { .. } => "Repeating",
956			MultilinearPolyVariant::Projected(_) => "Projected",
957			MultilinearPolyVariant::Shifted(_) => "Shifted",
958			MultilinearPolyVariant::Packed(_) => "Packed",
959			MultilinearPolyVariant::LinearCombination(_) => "LinearCombination",
960			MultilinearPolyVariant::ZeroPadded(_) => "ZeroPadded",
961			MultilinearPolyVariant::Composite(_) => "CompositeMLE",
962		}
963	}
964
965	pub const fn n_vars(&self) -> usize {
966		self.n_vars
967	}
968
969	/// Maximum tower level of the oracle's values over the boolean hypercube.
970	pub const fn binary_tower_level(&self) -> usize {
971		self.tower_level
972	}
973
974	pub fn into_composite(self) -> CompositePolyOracle<F> {
975		let composite =
976			CompositePolyOracle::new(self.n_vars(), vec![self], IdentityCompositionPoly);
977		composite.expect("Can always apply the identity composition to one variable")
978	}
979}
980
#[cfg(test)]
mod tests {
	use binius_field::{BinaryField1b, BinaryField128b, Field, TowerField};

	use super::MultilinearOracleSet;

	type F = BinaryField128b;

	#[test]
	fn add_projection_with_all_vars() {
		let mut oracles = MultilinearOracleSet::<F>::new();
		let data = oracles.add_committed(5, BinaryField1b::TOWER_LEVEL);
		let start_index = 0;
		let projected = oracles
			.add_projected(data, vec![F::ONE; 5], start_index)
			.unwrap();
		// Fixing every variable leaves a 0-variate oracle that keeps the inner tower level.
		assert_eq!(oracles.n_vars(projected), 0);
		assert_eq!(oracles.tower_level(projected), BinaryField1b::TOWER_LEVEL);
	}

	#[test]
	fn add_projection_with_too_many_values_fails() {
		let mut oracles = MultilinearOracleSet::<F>::new();
		let data = oracles.add_committed(3, BinaryField1b::TOWER_LEVEL);
		assert!(oracles.add_projected(data, vec![F::ONE; 4], 0).is_err());
		// A rejected projection must not append an oracle.
		assert_eq!(oracles.size(), 1);
	}

	#[test]
	fn add_projection_of_last_vars() {
		let mut oracles = MultilinearOracleSet::<F>::new();
		let data = oracles.add_committed(3, BinaryField1b::TOWER_LEVEL);
		let projected = oracles
			.add_projected_last_vars(data, vec![F::ONE, F::ONE])
			.unwrap();
		assert_eq!(oracles.n_vars(projected), 1);
	}
}