binius_core/oracle/
multilinear.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{array, fmt::Debug, sync::Arc};
4
5use binius_field::{BinaryField128b, Field, TowerField};
6use binius_macros::{DeserializeBytes, SerializeBytes};
7use binius_math::ArithCircuit;
8use binius_utils::{bail, DeserializeBytes, SerializationError, SerializationMode, SerializeBytes};
9use getset::{CopyGetters, Getters};
10
11use crate::{
12	oracle::{CompositePolyOracle, Error},
13	polynomial::{
14		ArithCircuitPoly, Error as PolynomialError, IdentityCompositionPoly, MultivariatePoly,
15	},
16};
17
18/// Identifier for a multilinear oracle in a [`MultilinearOracleSet`].
19pub type OracleId = usize;
20
21/// Meta struct that lets you add optional `name` for the Multilinear before adding to the
22/// [`MultilinearOracleSet`]
23pub struct MultilinearOracleSetAddition<'a, F: TowerField> {
24	name: Option<String>,
25	mut_ref: &'a mut MultilinearOracleSet<F>,
26}
27
28impl<F: TowerField> MultilinearOracleSetAddition<'_, F> {
29	pub fn transparent(self, poly: impl MultivariatePoly<F> + 'static) -> Result<OracleId, Error> {
30		if poly.binary_tower_level() > F::TOWER_LEVEL {
31			bail!(Error::TowerLevelTooHigh {
32				tower_level: poly.binary_tower_level(),
33			});
34		}
35
36		let inner = TransparentPolyOracle::new(Arc::new(poly))?;
37
38		let oracle = |id: OracleId| MultilinearPolyOracle {
39			id,
40			n_vars: inner.poly.n_vars(),
41			tower_level: inner.poly.binary_tower_level(),
42			name: self.name,
43			variant: MultilinearPolyVariant::Transparent(inner),
44		};
45
46		Ok(self.mut_ref.add_to_set(oracle))
47	}
48
49	pub fn committed(mut self, n_vars: usize, tower_level: usize) -> OracleId {
50		let name = self.name.take();
51		self.add_committed_with_name(n_vars, tower_level, name)
52	}
53
54	pub fn committed_multiple<const N: usize>(
55		mut self,
56		n_vars: usize,
57		tower_level: usize,
58	) -> [OracleId; N] {
59		match &self.name.take() {
60			None => [0; N].map(|_| self.add_committed_with_name(n_vars, tower_level, None)),
61			Some(s) => {
62				let x: [usize; N] = array::from_fn(|i| i);
63				x.map(|i| {
64					self.add_committed_with_name(n_vars, tower_level, Some(format!("{s}_{i}")))
65				})
66			}
67		}
68	}
69
70	pub fn repeating(self, inner_id: OracleId, log_count: usize) -> Result<OracleId, Error> {
71		if inner_id >= self.mut_ref.oracles.len() {
72			bail!(Error::InvalidOracleId(inner_id));
73		}
74
75		let inner = self.mut_ref.get_from_set(inner_id);
76
77		let oracle = |id: OracleId| MultilinearPolyOracle {
78			id,
79			n_vars: inner.n_vars + log_count,
80			tower_level: inner.tower_level,
81			name: self.name,
82			variant: MultilinearPolyVariant::Repeating {
83				id: inner_id,
84				log_count,
85			},
86		};
87
88		Ok(self.mut_ref.add_to_set(oracle))
89	}
90
91	pub fn shifted(
92		self,
93		inner_id: OracleId,
94		offset: usize,
95		block_bits: usize,
96		variant: ShiftVariant,
97	) -> Result<OracleId, Error> {
98		if inner_id >= self.mut_ref.oracles.len() {
99			bail!(Error::InvalidOracleId(inner_id));
100		}
101
102		let inner = self.mut_ref.get_from_set(inner_id);
103		if block_bits > inner.n_vars {
104			bail!(PolynomialError::InvalidBlockSize {
105				n_vars: inner.n_vars,
106			});
107		}
108
109		if offset == 0 || offset >= 1 << block_bits {
110			bail!(PolynomialError::InvalidShiftOffset {
111				max_shift_offset: (1 << block_bits) - 1,
112				shift_offset: offset,
113			});
114		}
115
116		let shifted = Shifted::new(&inner, offset, block_bits, variant)?;
117
118		let oracle = |id: OracleId| MultilinearPolyOracle {
119			id,
120			n_vars: inner.n_vars,
121			tower_level: inner.tower_level,
122			name: self.name,
123			variant: MultilinearPolyVariant::Shifted(shifted),
124		};
125
126		Ok(self.mut_ref.add_to_set(oracle))
127	}
128
129	pub fn packed(self, inner_id: OracleId, log_degree: usize) -> Result<OracleId, Error> {
130		if inner_id >= self.mut_ref.oracles.len() {
131			bail!(Error::InvalidOracleId(inner_id));
132		}
133
134		let inner_n_vars = self.mut_ref.n_vars(inner_id);
135		if log_degree > inner_n_vars {
136			bail!(Error::NotEnoughVarsForPacking {
137				n_vars: inner_n_vars,
138				log_degree,
139			});
140		}
141
142		let inner_tower_level = self.mut_ref.tower_level(inner_id);
143
144		let packed = Packed {
145			id: inner_id,
146			log_degree,
147		};
148
149		let oracle = |id: OracleId| MultilinearPolyOracle {
150			id,
151			n_vars: inner_n_vars - log_degree,
152			tower_level: inner_tower_level + log_degree,
153			name: self.name,
154			variant: MultilinearPolyVariant::Packed(packed),
155		};
156
157		Ok(self.mut_ref.add_to_set(oracle))
158	}
159
160	pub fn projected(
161		self,
162		inner_id: OracleId,
163		values: Vec<F>,
164		start_index: usize,
165	) -> Result<OracleId, Error> {
166		let inner_n_vars = self.mut_ref.n_vars(inner_id);
167		let values_len = values.len();
168		if values_len > inner_n_vars {
169			bail!(Error::InvalidProjection {
170				n_vars: inner_n_vars,
171				values_len,
172			});
173		}
174
175		let inner = self.mut_ref.get_from_set(inner_id);
176		// TODO: This is wrong, should be F::TOWER_LEVEL
177		let tower_level = inner.binary_tower_level();
178		let projected = Projected::new(&inner, values, start_index)?;
179
180		let oracle = |id: OracleId| MultilinearPolyOracle {
181			id,
182			n_vars: inner_n_vars - values_len,
183			tower_level,
184			name: self.name,
185			variant: MultilinearPolyVariant::Projected(projected),
186		};
187
188		Ok(self.mut_ref.add_to_set(oracle))
189	}
190
191	pub fn projected_last_vars(
192		self,
193		inner_id: OracleId,
194		values: Vec<F>,
195	) -> Result<OracleId, Error> {
196		let inner_n_vars = self.mut_ref.n_vars(inner_id);
197		let start_index = inner_n_vars - values.len();
198		let values_len = values.len();
199		if values_len > inner_n_vars {
200			bail!(Error::InvalidProjection {
201				n_vars: inner_n_vars,
202				values_len,
203			});
204		}
205
206		let inner = self.mut_ref.get_from_set(inner_id);
207		// TODO: This is wrong, should be F::TOWER_LEVEL
208		let tower_level = inner.binary_tower_level();
209		let projected = Projected::new(&inner, values, start_index)?;
210
211		let oracle = |id: OracleId| MultilinearPolyOracle {
212			id,
213			n_vars: inner_n_vars - values_len,
214			tower_level,
215			name: self.name,
216			variant: MultilinearPolyVariant::Projected(projected),
217		};
218
219		Ok(self.mut_ref.add_to_set(oracle))
220	}
221
222	pub fn linear_combination(
223		self,
224		n_vars: usize,
225		inner: impl IntoIterator<Item = (OracleId, F)>,
226	) -> Result<OracleId, Error> {
227		self.linear_combination_with_offset(n_vars, F::ZERO, inner)
228	}
229
230	pub fn linear_combination_with_offset(
231		self,
232		n_vars: usize,
233		offset: F,
234		inner: impl IntoIterator<Item = (OracleId, F)>,
235	) -> Result<OracleId, Error> {
236		let inner = inner
237			.into_iter()
238			.map(|(inner_id, coeff)| {
239				if inner_id >= self.mut_ref.oracles.len() {
240					return Err(Error::InvalidOracleId(inner_id));
241				}
242				if self.mut_ref.n_vars(inner_id) != n_vars {
243					return Err(Error::IncorrectNumberOfVariables { expected: n_vars });
244				}
245				Ok((self.mut_ref.get_from_set(inner_id), coeff))
246			})
247			.collect::<Result<Vec<_>, _>>()?;
248
249		let tower_level = inner
250			.iter()
251			.map(|(oracle, coeff)| oracle.binary_tower_level().max(coeff.min_tower_level()))
252			.max()
253			.unwrap_or(0)
254			.max(offset.min_tower_level());
255
256		let linear_combination = LinearCombination::new(n_vars, offset, inner)?;
257
258		let oracle = |id: OracleId| MultilinearPolyOracle {
259			id,
260			n_vars,
261			tower_level,
262			name: self.name,
263			variant: MultilinearPolyVariant::LinearCombination(linear_combination),
264		};
265
266		Ok(self.mut_ref.add_to_set(oracle))
267	}
268
269	pub fn composite_mle(
270		self,
271		n_vars: usize,
272		inner: impl IntoIterator<Item = OracleId>,
273		comp: ArithCircuit<F>,
274	) -> Result<OracleId, Error> {
275		let inner = inner
276			.into_iter()
277			.map(|inner_id| {
278				if inner_id >= self.mut_ref.oracles.len() {
279					return Err(Error::InvalidOracleId(inner_id));
280				}
281				if self.mut_ref.n_vars(inner_id) != n_vars {
282					return Err(Error::IncorrectNumberOfVariables { expected: n_vars });
283				}
284				Ok(self.mut_ref.get_from_set(inner_id))
285			})
286			.collect::<Result<Vec<_>, _>>()?;
287
288		let tower_level = inner
289			.iter()
290			.map(|oracle| oracle.binary_tower_level())
291			.max()
292			.unwrap_or(0);
293
294		let composite_mle = CompositeMLE::new(n_vars, inner, comp)?;
295
296		let oracle = |id: OracleId| MultilinearPolyOracle {
297			id,
298			n_vars,
299			tower_level,
300			name: self.name,
301			variant: MultilinearPolyVariant::Composite(composite_mle),
302		};
303
304		Ok(self.mut_ref.add_to_set(oracle))
305	}
306
307	pub fn zero_padded(
308		self,
309		inner_id: OracleId,
310		n_pad_vars: usize,
311		nonzero_index: usize,
312		start_index: usize,
313	) -> Result<OracleId, Error> {
314		let inner_n_vars = self.mut_ref.n_vars(inner_id);
315		if start_index > inner_n_vars {
316			bail!(Error::InvalidStartIndex {
317				expected: inner_n_vars
318			});
319		}
320
321		let inner = self.mut_ref.get_from_set(inner_id);
322		let tower_level = inner.binary_tower_level();
323		let padded = ZeroPadded::new(&inner, n_pad_vars, nonzero_index, start_index)?;
324
325		let oracle = |id: OracleId| MultilinearPolyOracle {
326			id,
327			n_vars: inner_n_vars + n_pad_vars,
328			tower_level,
329			name: self.name,
330			variant: MultilinearPolyVariant::ZeroPadded(padded),
331		};
332
333		Ok(self.mut_ref.add_to_set(oracle))
334	}
335
336	fn add_committed_with_name(
337		&mut self,
338		n_vars: usize,
339		tower_level: usize,
340		name: Option<String>,
341	) -> OracleId {
342		let oracle = |oracle_id: OracleId| MultilinearPolyOracle {
343			id: oracle_id,
344			n_vars,
345			tower_level,
346			name: name.clone(),
347			variant: MultilinearPolyVariant::Committed,
348		};
349
350		self.mut_ref.add_to_set(oracle)
351	}
352}
353
354/// An ordered set of multilinear polynomial oracles.
355///
356/// The multilinear polynomial oracles form a directed acyclic graph, where each multilinear oracle
357/// is either transparent, committed, or derived from one or more others. Each oracle is assigned a
358/// unique `OracleId`.
359///
360/// The oracle set also tracks the committed polynomial in batches where each batch is committed
361/// together with a polynomial commitment scheme.
362#[derive(Default, Debug, Clone, SerializeBytes, DeserializeBytes)]
363#[deserialize_bytes(eval_generics(F = BinaryField128b))]
364pub struct MultilinearOracleSet<F: TowerField> {
365	oracles: Vec<MultilinearPolyOracle<F>>,
366}
367
368impl<F: TowerField> MultilinearOracleSet<F> {
369	pub const fn new() -> Self {
370		Self {
371			oracles: Vec::new(),
372		}
373	}
374
375	pub fn size(&self) -> usize {
376		self.oracles.len()
377	}
378
379	pub fn iter(&self) -> impl Iterator<Item = MultilinearPolyOracle<F>> + '_ {
380		(0..self.oracles.len()).map(|id| self.oracle(id))
381	}
382
383	pub const fn add(&mut self) -> MultilinearOracleSetAddition<F> {
384		MultilinearOracleSetAddition {
385			name: None,
386			mut_ref: self,
387		}
388	}
389
390	pub fn add_named<S: ToString>(&mut self, s: S) -> MultilinearOracleSetAddition<F> {
391		MultilinearOracleSetAddition {
392			name: Some(s.to_string()),
393			mut_ref: self,
394		}
395	}
396
397	pub fn is_valid_oracle_id(&self, id: OracleId) -> bool {
398		id < self.oracles.len()
399	}
400
401	fn add_to_set(
402		&mut self,
403		oracle: impl FnOnce(OracleId) -> MultilinearPolyOracle<F>,
404	) -> OracleId {
405		let id = self.oracles.len();
406		self.oracles.push(oracle(id));
407		id
408	}
409
410	fn get_from_set(&self, id: OracleId) -> MultilinearPolyOracle<F> {
411		self.oracles[id].clone()
412	}
413
414	pub fn add_transparent(
415		&mut self,
416		poly: impl MultivariatePoly<F> + 'static,
417	) -> Result<OracleId, Error> {
418		self.add().transparent(poly)
419	}
420
421	pub fn add_committed(&mut self, n_vars: usize, tower_level: usize) -> OracleId {
422		self.add().committed(n_vars, tower_level)
423	}
424
425	pub fn add_committed_multiple<const N: usize>(
426		&mut self,
427		n_vars: usize,
428		tower_level: usize,
429	) -> [OracleId; N] {
430		self.add().committed_multiple(n_vars, tower_level)
431	}
432
433	pub fn add_repeating(&mut self, id: OracleId, log_count: usize) -> Result<OracleId, Error> {
434		self.add().repeating(id, log_count)
435	}
436
437	pub fn add_shifted(
438		&mut self,
439		id: OracleId,
440		offset: usize,
441		block_bits: usize,
442		variant: ShiftVariant,
443	) -> Result<OracleId, Error> {
444		self.add().shifted(id, offset, block_bits, variant)
445	}
446
447	pub fn add_packed(&mut self, id: OracleId, log_degree: usize) -> Result<OracleId, Error> {
448		self.add().packed(id, log_degree)
449	}
450
451	/// Adds a projection to the variables starting at `start_index`.
452	pub fn add_projected(
453		&mut self,
454		id: OracleId,
455		values: Vec<F>,
456		start_index: usize,
457	) -> Result<OracleId, Error> {
458		self.add().projected(id, values, start_index)
459	}
460
461	/// Adds a projection to the last variables.
462	pub fn add_projected_last_vars(
463		&mut self,
464		id: OracleId,
465		values: Vec<F>,
466	) -> Result<OracleId, Error> {
467		self.add().projected_last_vars(id, values)
468	}
469
470	pub fn add_linear_combination(
471		&mut self,
472		n_vars: usize,
473		inner: impl IntoIterator<Item = (OracleId, F)>,
474	) -> Result<OracleId, Error> {
475		self.add().linear_combination(n_vars, inner)
476	}
477
478	pub fn add_linear_combination_with_offset(
479		&mut self,
480		n_vars: usize,
481		offset: F,
482		inner: impl IntoIterator<Item = (OracleId, F)>,
483	) -> Result<OracleId, Error> {
484		self.add()
485			.linear_combination_with_offset(n_vars, offset, inner)
486	}
487
488	pub fn add_zero_padded(
489		&mut self,
490		id: OracleId,
491		n_pad_vars: usize,
492		nonzero_index: usize,
493		start_index: usize,
494	) -> Result<OracleId, Error> {
495		self.add()
496			.zero_padded(id, n_pad_vars, nonzero_index, start_index)
497	}
498
499	pub fn add_composite_mle(
500		&mut self,
501		n_vars: usize,
502		inner: impl IntoIterator<Item = OracleId>,
503		comp: ArithCircuit<F>,
504	) -> Result<OracleId, Error> {
505		self.add().composite_mle(n_vars, inner, comp)
506	}
507
508	pub fn oracle(&self, id: OracleId) -> MultilinearPolyOracle<F> {
509		self.oracles[id].clone()
510	}
511
512	pub fn n_vars(&self, id: OracleId) -> usize {
513		self.oracles[id].n_vars()
514	}
515
516	pub fn label(&self, id: OracleId) -> String {
517		self.oracles[id].label()
518	}
519
520	/// Maximum tower level of the oracle's values over the boolean hypercube.
521	pub fn tower_level(&self, id: OracleId) -> usize {
522		self.oracles[id].binary_tower_level()
523	}
524}
525
526/// A multilinear polynomial oracle in the polynomial IOP model.
527///
528/// In the multilinear polynomial IOP model, a prover sends multilinear polynomials to an oracle,
529/// and the verifier may at the end of the protocol query their evaluations at chosen points. An
530/// oracle is a verifier and prover's shared view of a polynomial that can be queried for
531/// evaluations by the verifier.
532///
533/// There are three fundamental categories of oracles:
534///
535/// 1. *Transparent oracles*. These are multilinear polynomials with a succinct description and
536///    evaluation algorithm that are known to the verifier. When the verifier queries a transparent
537///    oracle, it evaluates the polynomial itself.
538/// 2. *Committed oracles*. These are polynomials actually sent by the prover. When the polynomial
539///    IOP is compiled to an interactive protocol, these polynomial are committed with a polynomial
540///    commitment scheme.
541/// 3. *Virtual oracles*. A virtual multilinear oracle is not actually sent by the prover, but
542///    instead admits an interactive reduction for evaluation queries to evaluation queries to
543///    other oracles. This is formalized in [DP23] Section 4.
544///
545/// [DP23]: <https://eprint.iacr.org/2023/1784>
546#[derive(Debug, Clone, PartialEq, Eq, SerializeBytes, DeserializeBytes)]
547#[deserialize_bytes(eval_generics(F = BinaryField128b))]
548pub struct MultilinearPolyOracle<F: TowerField> {
549	pub id: OracleId,
550	pub name: Option<String>,
551	pub n_vars: usize,
552	pub tower_level: usize,
553	pub variant: MultilinearPolyVariant<F>,
554}
555
556#[derive(Debug, Clone, PartialEq, Eq, SerializeBytes)]
557pub enum MultilinearPolyVariant<F: TowerField> {
558	Committed,
559	Transparent(TransparentPolyOracle<F>),
560	Repeating { id: usize, log_count: usize },
561	Projected(Projected<F>),
562	Shifted(Shifted),
563	Packed(Packed),
564	LinearCombination(LinearCombination<F>),
565	ZeroPadded(ZeroPadded),
566	Composite(CompositeMLE<F>),
567}
568
569impl DeserializeBytes for MultilinearPolyVariant<BinaryField128b> {
570	fn deserialize(
571		mut buf: impl bytes::Buf,
572		mode: SerializationMode,
573	) -> Result<Self, SerializationError>
574	where
575		Self: Sized,
576	{
577		Ok(match u8::deserialize(&mut buf, mode)? {
578			0 => Self::Committed,
579			1 => Self::Transparent(DeserializeBytes::deserialize(buf, mode)?),
580			2 => Self::Repeating {
581				id: DeserializeBytes::deserialize(&mut buf, mode)?,
582				log_count: DeserializeBytes::deserialize(buf, mode)?,
583			},
584			3 => Self::Projected(DeserializeBytes::deserialize(buf, mode)?),
585			4 => Self::Shifted(DeserializeBytes::deserialize(buf, mode)?),
586			5 => Self::Packed(DeserializeBytes::deserialize(buf, mode)?),
587			6 => Self::LinearCombination(DeserializeBytes::deserialize(buf, mode)?),
588			7 => Self::ZeroPadded(DeserializeBytes::deserialize(buf, mode)?),
589			variant_index => {
590				return Err(SerializationError::UnknownEnumVariant {
591					name: "MultilinearPolyVariant",
592					index: variant_index,
593				});
594			}
595		})
596	}
597}
598
599/// A transparent multilinear polynomial oracle.
600///
601/// See the [`MultilinearPolyOracle`] documentation for context.
602#[derive(Debug, Clone, Getters, CopyGetters)]
603pub struct TransparentPolyOracle<F: Field> {
604	#[get = "pub"]
605	poly: Arc<dyn MultivariatePoly<F>>,
606}
607
608impl<F: TowerField> SerializeBytes for TransparentPolyOracle<F> {
609	fn serialize(
610		&self,
611		mut write_buf: impl bytes::BufMut,
612		mode: SerializationMode,
613	) -> Result<(), SerializationError> {
614		self.poly.erased_serialize(&mut write_buf, mode)
615	}
616}
617
618impl DeserializeBytes for TransparentPolyOracle<BinaryField128b> {
619	fn deserialize(
620		read_buf: impl bytes::Buf,
621		mode: SerializationMode,
622	) -> Result<Self, SerializationError>
623	where
624		Self: Sized,
625	{
626		Ok(Self {
627			poly: Box::<dyn MultivariatePoly<BinaryField128b>>::deserialize(read_buf, mode)?.into(),
628		})
629	}
630}
631
632impl<F: TowerField> TransparentPolyOracle<F> {
633	fn new(poly: Arc<dyn MultivariatePoly<F>>) -> Result<Self, Error> {
634		if poly.binary_tower_level() > F::TOWER_LEVEL {
635			bail!(Error::TowerLevelTooHigh {
636				tower_level: poly.binary_tower_level(),
637			});
638		}
639		Ok(Self { poly })
640	}
641}
642
643impl<F: Field> TransparentPolyOracle<F> {
644	/// Maximum tower level of the oracle's values over the boolean hypercube.
645	pub fn binary_tower_level(&self) -> usize {
646		self.poly.binary_tower_level()
647	}
648}
649
650impl<F: Field> PartialEq for TransparentPolyOracle<F> {
651	fn eq(&self, other: &Self) -> bool {
652		Arc::ptr_eq(&self.poly, &other.poly)
653	}
654}
655
656impl<F: Field> Eq for TransparentPolyOracle<F> {}
657
658#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)]
659pub struct Projected<F: TowerField> {
660	#[get_copy = "pub"]
661	id: OracleId,
662	#[get = "pub"]
663	values: Vec<F>,
664	#[get_copy = "pub"]
665	start_index: usize,
666}
667
668impl<F: TowerField> Projected<F> {
669	fn new(
670		oracle: &MultilinearPolyOracle<F>,
671		values: Vec<F>,
672		start_index: usize,
673	) -> Result<Self, Error> {
674		if values.len() + start_index > oracle.n_vars() {
675			bail!(Error::InvalidProjection {
676				n_vars: oracle.n_vars(),
677				values_len: values.len()
678			});
679		}
680		Ok(Self {
681			id: oracle.id(),
682			values,
683			start_index,
684		})
685	}
686}
687
688#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)]
689pub struct ZeroPadded {
690	#[get_copy = "pub"]
691	id: OracleId,
692	#[get_copy = "pub"]
693	n_pad_vars: usize,
694	#[get_copy = "pub"]
695	nonzero_index: usize,
696	#[get_copy = "pub"]
697	start_index: usize,
698}
699
700impl ZeroPadded {
701	fn new<F: TowerField>(
702		oracle: &MultilinearPolyOracle<F>,
703		n_pad_vars: usize,
704		nonzero_index: usize,
705		start_index: usize,
706	) -> Result<Self, Error> {
707		if start_index > oracle.n_vars() {
708			bail!(Error::InvalidStartIndex {
709				expected: oracle.n_vars(),
710			});
711		}
712
713		if nonzero_index > 1 << n_pad_vars {
714			bail!(Error::InvalidNonzeroIndex {
715				expected: 1 << n_pad_vars,
716			});
717		}
718
719		Ok(Self {
720			id: oracle.id(),
721			n_pad_vars,
722			nonzero_index,
723			start_index,
724		})
725	}
726}
727
728#[derive(Debug, Copy, Clone, PartialEq, Eq, SerializeBytes, DeserializeBytes)]
729pub enum ShiftVariant {
730	CircularLeft,
731	LogicalLeft,
732	LogicalRight,
733}
734
735#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)]
736pub struct Shifted {
737	#[get_copy = "pub"]
738	id: OracleId,
739	#[get_copy = "pub"]
740	shift_offset: usize,
741	#[get_copy = "pub"]
742	block_size: usize,
743	#[get_copy = "pub"]
744	shift_variant: ShiftVariant,
745}
746
747impl Shifted {
748	fn new<F: TowerField>(
749		oracle: &MultilinearPolyOracle<F>,
750		shift_offset: usize,
751		block_size: usize,
752		shift_variant: ShiftVariant,
753	) -> Result<Self, Error> {
754		if block_size > oracle.n_vars() {
755			bail!(PolynomialError::InvalidBlockSize {
756				n_vars: oracle.n_vars(),
757			});
758		}
759
760		if shift_offset == 0 || shift_offset >= 1 << block_size {
761			bail!(PolynomialError::InvalidShiftOffset {
762				max_shift_offset: (1 << block_size) - 1,
763				shift_offset,
764			});
765		}
766
767		Ok(Self {
768			id: oracle.id(),
769			shift_offset,
770			block_size,
771			shift_variant,
772		})
773	}
774}
775
776#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)]
777pub struct Packed {
778	#[get_copy = "pub"]
779	id: OracleId,
780	/// The number of tower levels increased by the packing operation.
781	///
782	/// This is the base 2 logarithm of the field extension, and is called $\kappa$ in [DP23],
783	/// Section 4.3.
784	///
785	/// [DP23]: https://eprint.iacr.org/2023/1784
786	#[get_copy = "pub"]
787	log_degree: usize,
788}
789
790#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)]
791pub struct LinearCombination<F: TowerField> {
792	#[get_copy = "pub"]
793	n_vars: usize,
794	#[get_copy = "pub"]
795	offset: F,
796	inner: Vec<(OracleId, F)>,
797}
798
799impl<F: TowerField> LinearCombination<F> {
800	fn new(
801		n_vars: usize,
802		offset: F,
803		inner: impl IntoIterator<Item = (MultilinearPolyOracle<F>, F)>,
804	) -> Result<Self, Error> {
805		let inner = inner
806			.into_iter()
807			.map(|(oracle, value)| {
808				if oracle.n_vars() == n_vars {
809					Ok((oracle.id(), value))
810				} else {
811					Err(Error::IncorrectNumberOfVariables { expected: n_vars })
812				}
813			})
814			.collect::<Result<Vec<_>, _>>()?;
815		Ok(Self {
816			n_vars,
817			offset,
818			inner,
819		})
820	}
821
822	pub fn n_polys(&self) -> usize {
823		self.inner.len()
824	}
825
826	pub fn polys(&self) -> impl Iterator<Item = OracleId> + '_ {
827		self.inner.iter().map(|(id, _)| *id)
828	}
829
830	pub fn coefficients(&self) -> impl Iterator<Item = F> + '_ {
831		self.inner.iter().map(|(_, coeff)| *coeff)
832	}
833}
834
835/// MLE of a multivariate polynomial evaluated on multilinear oracles.
836///
837/// i.e. the MLE of the evaluations of $C(M_1, M_2, \ldots, M_n)$ on $\{0, 1\}^\mu$ where:
838/// - $C$ is a arbitrary polynomial in $n$ variables
839/// - $M_1, M_2, \ldots, M_n$ are multilinear oracles in μ variables
840///
841/// ($C$ should be sufficiently lightweight to be evaluated by the verifier)
842#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes)]
843pub struct CompositeMLE<F: TowerField> {
844	/// $\mu$
845	#[get_copy = "pub"]
846	n_vars: usize,
847	/// $M_1, M_2, \ldots, M_n$
848	#[getset(get = "pub")]
849	inner: Vec<OracleId>,
850	/// $C$
851	#[getset(get = "pub")]
852	c: ArithCircuitPoly<F>,
853}
854
855impl<F: TowerField> CompositeMLE<F> {
856	pub fn new(
857		n_vars: usize,
858		inner: impl IntoIterator<Item = MultilinearPolyOracle<F>>,
859		c: ArithCircuit<F>,
860	) -> Result<Self, Error> {
861		let inner = inner
862			.into_iter()
863			.map(|oracle| {
864				if oracle.n_vars() == n_vars {
865					Ok(oracle.id())
866				} else {
867					Err(Error::IncorrectNumberOfVariables { expected: n_vars })
868				}
869			})
870			.collect::<Result<Vec<_>, _>>()?;
871		let c = ArithCircuitPoly::with_n_vars(inner.len(), c)
872			.map_err(|_| Error::CompositionMismatch)?; // occurs if `c` has more variables than `inner.len()`
873		Ok(Self { n_vars, inner, c })
874	}
875
876	pub fn polys(&self) -> impl Iterator<Item = OracleId> + '_ {
877		self.inner.iter().copied()
878	}
879
880	pub fn n_polys(&self) -> usize {
881		self.inner.len()
882	}
883}
884
885impl<F: TowerField> MultilinearPolyOracle<F> {
886	pub const fn id(&self) -> OracleId {
887		self.id
888	}
889
890	pub fn label(&self) -> String {
891		match self.name() {
892			Some(name) => format!("{}: {}", self.type_str(), name),
893			None => format!("{}: id={}", self.type_str(), self.id()),
894		}
895	}
896
897	pub fn name(&self) -> Option<&str> {
898		self.name.as_deref()
899	}
900
901	const fn type_str(&self) -> &str {
902		match self.variant {
903			MultilinearPolyVariant::Transparent(_) => "Transparent",
904			MultilinearPolyVariant::Committed => "Committed",
905			MultilinearPolyVariant::Repeating { .. } => "Repeating",
906			MultilinearPolyVariant::Projected(_) => "Projected",
907			MultilinearPolyVariant::Shifted(_) => "Shifted",
908			MultilinearPolyVariant::Packed(_) => "Packed",
909			MultilinearPolyVariant::LinearCombination(_) => "LinearCombination",
910			MultilinearPolyVariant::ZeroPadded(_) => "ZeroPadded",
911			MultilinearPolyVariant::Composite(_) => "CompositeMLE",
912		}
913	}
914
915	pub const fn n_vars(&self) -> usize {
916		self.n_vars
917	}
918
919	/// Maximum tower level of the oracle's values over the boolean hypercube.
920	pub const fn binary_tower_level(&self) -> usize {
921		self.tower_level
922	}
923
924	pub fn into_composite(self) -> CompositePolyOracle<F> {
925		let composite =
926			CompositePolyOracle::new(self.n_vars(), vec![self], IdentityCompositionPoly);
927		composite.expect("Can always apply the identity composition to one variable")
928	}
929}
930
#[cfg(test)]
mod tests {
	use binius_field::{BinaryField128b, BinaryField1b, Field, TowerField};

	use super::MultilinearOracleSet;

	/// Projecting every variable of a committed oracle, starting at index 0, must succeed.
	#[test]
	fn add_projection_with_all_vars() {
		type F = BinaryField128b;
		const N_VARS: usize = 5;

		let mut oracles = MultilinearOracleSet::<F>::new();
		let committed = oracles.add_committed(N_VARS, BinaryField1b::TOWER_LEVEL);

		let projected = oracles
			.add_projected(committed, vec![F::ONE; N_VARS], 0)
			.unwrap();
		let _ = oracles.oracle(projected);
	}
}