binius_core/oracle/
multilinear.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{array, fmt::Debug, sync::Arc};
4
5use binius_field::{BinaryField128b, Field, TowerField};
6use binius_macros::{DeserializeBytes, SerializeBytes};
7use binius_math::ArithExpr;
8use binius_utils::{bail, DeserializeBytes, SerializationError, SerializationMode, SerializeBytes};
9use bytes::Buf;
10use getset::{CopyGetters, Getters};
11
12use crate::{
13	oracle::{CompositePolyOracle, Error},
14	polynomial::{
15		ArithCircuitPoly, Error as PolynomialError, IdentityCompositionPoly, MultivariatePoly,
16	},
17};
18
/// Identifier for a multilinear oracle in a [`MultilinearOracleSet`].
///
/// Oracles added through the set's builder methods receive consecutive ids starting from zero,
/// so an id is the oracle's index in the set.
pub type OracleId = usize;
21
22/// Meta struct that lets you add optional `name` for the Multilinear before adding to the
23/// [`MultilinearOracleSet`]
24pub struct MultilinearOracleSetAddition<'a, F: TowerField> {
25	name: Option<String>,
26	mut_ref: &'a mut MultilinearOracleSet<F>,
27}
28
29impl<F: TowerField> MultilinearOracleSetAddition<'_, F> {
30	pub fn transparent(self, poly: impl MultivariatePoly<F> + 'static) -> Result<OracleId, Error> {
31		if poly.binary_tower_level() > F::TOWER_LEVEL {
32			bail!(Error::TowerLevelTooHigh {
33				tower_level: poly.binary_tower_level(),
34			});
35		}
36
37		let inner = TransparentPolyOracle::new(Arc::new(poly))?;
38
39		let oracle = |id: OracleId| MultilinearPolyOracle {
40			id,
41			n_vars: inner.poly.n_vars(),
42			tower_level: inner.poly.binary_tower_level(),
43			name: self.name,
44			variant: MultilinearPolyVariant::Transparent(inner),
45		};
46
47		Ok(self.mut_ref.add_to_set(oracle))
48	}
49
50	pub fn committed(mut self, n_vars: usize, tower_level: usize) -> OracleId {
51		let name = self.name.take();
52		self.add_committed_with_name(n_vars, tower_level, name)
53	}
54
55	pub fn committed_multiple<const N: usize>(
56		mut self,
57		n_vars: usize,
58		tower_level: usize,
59	) -> [OracleId; N] {
60		match &self.name.take() {
61			None => [0; N].map(|_| self.add_committed_with_name(n_vars, tower_level, None)),
62			Some(s) => {
63				let x: [usize; N] = array::from_fn(|i| i);
64				x.map(|i| {
65					self.add_committed_with_name(n_vars, tower_level, Some(format!("{}_{}", s, i)))
66				})
67			}
68		}
69	}
70
71	pub fn repeating(self, inner_id: OracleId, log_count: usize) -> Result<OracleId, Error> {
72		if inner_id >= self.mut_ref.oracles.len() {
73			bail!(Error::InvalidOracleId(inner_id));
74		}
75
76		let inner = self.mut_ref.get_from_set(inner_id);
77
78		let oracle = |id: OracleId| MultilinearPolyOracle {
79			id,
80			n_vars: inner.n_vars + log_count,
81			tower_level: inner.tower_level,
82			name: self.name,
83			variant: MultilinearPolyVariant::Repeating {
84				id: inner_id,
85				log_count,
86			},
87		};
88
89		Ok(self.mut_ref.add_to_set(oracle))
90	}
91
92	pub fn shifted(
93		self,
94		inner_id: OracleId,
95		offset: usize,
96		block_bits: usize,
97		variant: ShiftVariant,
98	) -> Result<OracleId, Error> {
99		if inner_id >= self.mut_ref.oracles.len() {
100			bail!(Error::InvalidOracleId(inner_id));
101		}
102
103		let inner = self.mut_ref.get_from_set(inner_id);
104		if block_bits > inner.n_vars {
105			bail!(PolynomialError::InvalidBlockSize {
106				n_vars: inner.n_vars,
107			});
108		}
109
110		if offset == 0 || offset >= 1 << block_bits {
111			bail!(PolynomialError::InvalidShiftOffset {
112				max_shift_offset: (1 << block_bits) - 1,
113				shift_offset: offset,
114			});
115		}
116
117		let shifted = Shifted::new(&inner, offset, block_bits, variant)?;
118
119		let oracle = |id: OracleId| MultilinearPolyOracle {
120			id,
121			n_vars: inner.n_vars,
122			tower_level: inner.tower_level,
123			name: self.name,
124			variant: MultilinearPolyVariant::Shifted(shifted),
125		};
126
127		Ok(self.mut_ref.add_to_set(oracle))
128	}
129
130	pub fn packed(self, inner_id: OracleId, log_degree: usize) -> Result<OracleId, Error> {
131		if inner_id >= self.mut_ref.oracles.len() {
132			bail!(Error::InvalidOracleId(inner_id));
133		}
134
135		let inner_n_vars = self.mut_ref.n_vars(inner_id);
136		if log_degree > inner_n_vars {
137			bail!(Error::NotEnoughVarsForPacking {
138				n_vars: inner_n_vars,
139				log_degree,
140			});
141		}
142
143		let inner_tower_level = self.mut_ref.tower_level(inner_id);
144
145		let packed = Packed {
146			id: inner_id,
147			log_degree,
148		};
149
150		let oracle = |id: OracleId| MultilinearPolyOracle {
151			id,
152			n_vars: inner_n_vars - log_degree,
153			tower_level: inner_tower_level + log_degree,
154			name: self.name,
155			variant: MultilinearPolyVariant::Packed(packed),
156		};
157
158		Ok(self.mut_ref.add_to_set(oracle))
159	}
160
161	pub fn projected(
162		self,
163		inner_id: OracleId,
164		values: Vec<F>,
165		variant: ProjectionVariant,
166	) -> Result<OracleId, Error> {
167		let inner_n_vars = self.mut_ref.n_vars(inner_id);
168		let values_len = values.len();
169		if values_len > inner_n_vars {
170			bail!(Error::InvalidProjection {
171				n_vars: inner_n_vars,
172				values_len,
173			});
174		}
175
176		let inner = self.mut_ref.get_from_set(inner_id);
177		// TODO: This is wrong, should be F::TOWER_LEVEL
178		let tower_level = inner.binary_tower_level();
179		let projected = Projected::new(&inner, values, variant)?;
180
181		let oracle = |id: OracleId| MultilinearPolyOracle {
182			id,
183			n_vars: inner_n_vars - values_len,
184			tower_level,
185			name: self.name,
186			variant: MultilinearPolyVariant::Projected(projected),
187		};
188
189		Ok(self.mut_ref.add_to_set(oracle))
190	}
191
192	pub fn linear_combination(
193		self,
194		n_vars: usize,
195		inner: impl IntoIterator<Item = (OracleId, F)>,
196	) -> Result<OracleId, Error> {
197		self.linear_combination_with_offset(n_vars, F::ZERO, inner)
198	}
199
200	pub fn linear_combination_with_offset(
201		self,
202		n_vars: usize,
203		offset: F,
204		inner: impl IntoIterator<Item = (OracleId, F)>,
205	) -> Result<OracleId, Error> {
206		let inner = inner
207			.into_iter()
208			.map(|(inner_id, coeff)| {
209				if inner_id >= self.mut_ref.oracles.len() {
210					return Err(Error::InvalidOracleId(inner_id));
211				}
212				if self.mut_ref.n_vars(inner_id) != n_vars {
213					return Err(Error::IncorrectNumberOfVariables { expected: n_vars });
214				}
215				Ok((self.mut_ref.get_from_set(inner_id), coeff))
216			})
217			.collect::<Result<Vec<_>, _>>()?;
218
219		let tower_level = inner
220			.iter()
221			.map(|(oracle, _)| oracle.binary_tower_level())
222			.max()
223			.unwrap_or(0);
224
225		let linear_combination = LinearCombination::new(n_vars, offset, inner)?;
226
227		let oracle = |id: OracleId| MultilinearPolyOracle {
228			id,
229			n_vars,
230			tower_level,
231			name: self.name,
232			variant: MultilinearPolyVariant::LinearCombination(linear_combination),
233		};
234
235		Ok(self.mut_ref.add_to_set(oracle))
236	}
237
238	pub fn composite_mle(
239		self,
240		n_vars: usize,
241		inner: impl IntoIterator<Item = OracleId>,
242		comp: ArithExpr<F>,
243	) -> Result<OracleId, Error> {
244		let inner = inner
245			.into_iter()
246			.map(|inner_id| {
247				if inner_id >= self.mut_ref.oracles.len() {
248					return Err(Error::InvalidOracleId(inner_id));
249				}
250				if self.mut_ref.n_vars(inner_id) != n_vars {
251					return Err(Error::IncorrectNumberOfVariables { expected: n_vars });
252				}
253				Ok(self.mut_ref.get_from_set(inner_id))
254			})
255			.collect::<Result<Vec<_>, _>>()?;
256
257		let tower_level = inner
258			.iter()
259			.map(|oracle| oracle.binary_tower_level())
260			.max()
261			.unwrap_or(0);
262
263		let composite_mle = CompositeMLE::new(n_vars, inner, comp)?;
264
265		let oracle = |id: OracleId| MultilinearPolyOracle {
266			id,
267			n_vars,
268			tower_level,
269			name: self.name,
270			variant: MultilinearPolyVariant::Composite(composite_mle),
271		};
272
273		Ok(self.mut_ref.add_to_set(oracle))
274	}
275
276	pub fn zero_padded(self, inner_id: OracleId, n_vars: usize) -> Result<OracleId, Error> {
277		if inner_id >= self.mut_ref.oracles.len() {
278			bail!(Error::InvalidOracleId(inner_id));
279		}
280
281		if self.mut_ref.n_vars(inner_id) > n_vars {
282			bail!(Error::IncorrectNumberOfVariables {
283				expected: self.mut_ref.n_vars(inner_id),
284			});
285		};
286
287		let inner = self.mut_ref.get_from_set(inner_id);
288
289		let oracle = |id: OracleId| MultilinearPolyOracle {
290			id,
291			n_vars,
292			tower_level: inner.tower_level,
293			name: self.name,
294			variant: MultilinearPolyVariant::ZeroPadded(inner_id),
295		};
296
297		Ok(self.mut_ref.add_to_set(oracle))
298	}
299
300	fn add_committed_with_name(
301		&mut self,
302		n_vars: usize,
303		tower_level: usize,
304		name: Option<String>,
305	) -> OracleId {
306		let oracle = |oracle_id: OracleId| MultilinearPolyOracle {
307			id: oracle_id,
308			n_vars,
309			tower_level,
310			name: name.clone(),
311			variant: MultilinearPolyVariant::Committed,
312		};
313
314		self.mut_ref.add_to_set(oracle)
315	}
316}
317
318/// An ordered set of multilinear polynomial oracles.
319///
320/// The multilinear polynomial oracles form a directed acyclic graph, where each multilinear oracle
321/// is either transparent, committed, or derived from one or more others. Each oracle is assigned a
322/// unique `OracleId`.
323///
324/// The oracle set also tracks the committed polynomial in batches where each batch is committed
325/// together with a polynomial commitment scheme.
326#[derive(Default, Debug, Clone, SerializeBytes)]
327pub struct MultilinearOracleSet<F: TowerField> {
328	oracles: Vec<MultilinearPolyOracle<F>>,
329}
330
331impl DeserializeBytes for MultilinearOracleSet<BinaryField128b> {
332	fn deserialize(read_buf: impl Buf, mode: SerializationMode) -> Result<Self, SerializationError>
333	where
334		Self: Sized,
335	{
336		Ok(Self {
337			oracles: DeserializeBytes::deserialize(read_buf, mode)?,
338		})
339	}
340}
341
342impl<F: TowerField> MultilinearOracleSet<F> {
343	pub const fn new() -> Self {
344		Self {
345			oracles: Vec::new(),
346		}
347	}
348
349	pub fn size(&self) -> usize {
350		self.oracles.len()
351	}
352
353	pub fn iter(&self) -> impl Iterator<Item = MultilinearPolyOracle<F>> + '_ {
354		(0..self.oracles.len()).map(|id| self.oracle(id))
355	}
356
357	pub const fn add(&mut self) -> MultilinearOracleSetAddition<F> {
358		MultilinearOracleSetAddition {
359			name: None,
360			mut_ref: self,
361		}
362	}
363
364	pub fn add_named<S: ToString>(&mut self, s: S) -> MultilinearOracleSetAddition<F> {
365		MultilinearOracleSetAddition {
366			name: Some(s.to_string()),
367			mut_ref: self,
368		}
369	}
370
371	pub fn is_valid_oracle_id(&self, id: OracleId) -> bool {
372		id < self.oracles.len()
373	}
374
375	fn add_to_set(
376		&mut self,
377		oracle: impl FnOnce(OracleId) -> MultilinearPolyOracle<F>,
378	) -> OracleId {
379		let id = self.oracles.len();
380		self.oracles.push(oracle(id));
381		id
382	}
383
384	fn get_from_set(&self, id: OracleId) -> MultilinearPolyOracle<F> {
385		self.oracles[id].clone()
386	}
387
388	pub fn add_transparent(
389		&mut self,
390		poly: impl MultivariatePoly<F> + 'static,
391	) -> Result<OracleId, Error> {
392		self.add().transparent(poly)
393	}
394
395	pub fn add_committed(&mut self, n_vars: usize, tower_level: usize) -> OracleId {
396		self.add().committed(n_vars, tower_level)
397	}
398
399	pub fn add_committed_multiple<const N: usize>(
400		&mut self,
401		n_vars: usize,
402		tower_level: usize,
403	) -> [OracleId; N] {
404		self.add().committed_multiple(n_vars, tower_level)
405	}
406
407	pub fn add_repeating(&mut self, id: OracleId, log_count: usize) -> Result<OracleId, Error> {
408		self.add().repeating(id, log_count)
409	}
410
411	pub fn add_shifted(
412		&mut self,
413		id: OracleId,
414		offset: usize,
415		block_bits: usize,
416		variant: ShiftVariant,
417	) -> Result<OracleId, Error> {
418		self.add().shifted(id, offset, block_bits, variant)
419	}
420
421	pub fn add_packed(&mut self, id: OracleId, log_degree: usize) -> Result<OracleId, Error> {
422		self.add().packed(id, log_degree)
423	}
424
425	pub fn add_projected(
426		&mut self,
427		id: OracleId,
428		values: Vec<F>,
429		variant: ProjectionVariant,
430	) -> Result<OracleId, Error> {
431		self.add().projected(id, values, variant)
432	}
433
434	pub fn add_linear_combination(
435		&mut self,
436		n_vars: usize,
437		inner: impl IntoIterator<Item = (OracleId, F)>,
438	) -> Result<OracleId, Error> {
439		self.add().linear_combination(n_vars, inner)
440	}
441
442	pub fn add_linear_combination_with_offset(
443		&mut self,
444		n_vars: usize,
445		offset: F,
446		inner: impl IntoIterator<Item = (OracleId, F)>,
447	) -> Result<OracleId, Error> {
448		self.add()
449			.linear_combination_with_offset(n_vars, offset, inner)
450	}
451
452	pub fn add_zero_padded(&mut self, id: OracleId, n_vars: usize) -> Result<OracleId, Error> {
453		self.add().zero_padded(id, n_vars)
454	}
455
456	pub fn add_composite_mle(
457		&mut self,
458		n_vars: usize,
459		inner: impl IntoIterator<Item = OracleId>,
460		comp: ArithExpr<F>,
461	) -> Result<OracleId, Error> {
462		self.add().composite_mle(n_vars, inner, comp)
463	}
464
465	pub fn oracle(&self, id: OracleId) -> MultilinearPolyOracle<F> {
466		self.oracles[id].clone()
467	}
468
469	pub fn n_vars(&self, id: OracleId) -> usize {
470		self.oracles[id].n_vars()
471	}
472
473	pub fn label(&self, id: OracleId) -> String {
474		self.oracles[id].label()
475	}
476
477	/// Maximum tower level of the oracle's values over the boolean hypercube.
478	pub fn tower_level(&self, id: OracleId) -> usize {
479		self.oracles[id].binary_tower_level()
480	}
481}
482
483/// A multilinear polynomial oracle in the polynomial IOP model.
484///
485/// In the multilinear polynomial IOP model, a prover sends multilinear polynomials to an oracle,
486/// and the verifier may at the end of the protocol query their evaluations at chosen points. An
487/// oracle is a verifier and prover's shared view of a polynomial that can be queried for
488/// evaluations by the verifier.
489///
490/// There are three fundamental categories of oracles:
491///
492/// 1. *Transparent oracles*. These are multilinear polynomials with a succinct description and
493///    evaluation algorithm that are known to the verifier. When the verifier queries a transparent
494///    oracle, it evaluates the polynomial itself.
495/// 2. *Committed oracles*. These are polynomials actually sent by the prover. When the polynomial
496///    IOP is compiled to an interactive protocol, these polynomial are committed with a polynomial
497///    commitment scheme.
498/// 3. *Virtual oracles*. A virtual multilinear oracle is not actually sent by the prover, but
499///    instead admits an interactive reduction for evaluation queries to evaluation queries to
500///    other oracles. This is formalized in [DP23] Section 4.
501///
502/// [DP23]: <https://eprint.iacr.org/2023/1784>
503#[derive(Debug, Clone, PartialEq, Eq, SerializeBytes)]
504pub struct MultilinearPolyOracle<F: TowerField> {
505	pub id: OracleId,
506	pub name: Option<String>,
507	pub n_vars: usize,
508	pub tower_level: usize,
509	pub variant: MultilinearPolyVariant<F>,
510}
511
512impl DeserializeBytes for MultilinearPolyOracle<BinaryField128b> {
513	fn deserialize(
514		mut read_buf: impl bytes::Buf,
515		mode: SerializationMode,
516	) -> Result<Self, SerializationError>
517	where
518		Self: Sized,
519	{
520		Ok(Self {
521			id: DeserializeBytes::deserialize(&mut read_buf, mode)?,
522			name: DeserializeBytes::deserialize(&mut read_buf, mode)?,
523			n_vars: DeserializeBytes::deserialize(&mut read_buf, mode)?,
524			tower_level: DeserializeBytes::deserialize(&mut read_buf, mode)?,
525			variant: DeserializeBytes::deserialize(&mut read_buf, mode)?,
526		})
527	}
528}
529
530#[derive(Debug, Clone, PartialEq, Eq, SerializeBytes)]
531pub enum MultilinearPolyVariant<F: TowerField> {
532	Committed,
533	Transparent(TransparentPolyOracle<F>),
534	Repeating { id: usize, log_count: usize },
535	Projected(Projected<F>),
536	Shifted(Shifted),
537	Packed(Packed),
538	LinearCombination(LinearCombination<F>),
539	ZeroPadded(OracleId),
540	Composite(CompositeMLE<F>),
541}
542
543impl DeserializeBytes for MultilinearPolyVariant<BinaryField128b> {
544	fn deserialize(
545		mut buf: impl bytes::Buf,
546		mode: SerializationMode,
547	) -> Result<Self, SerializationError>
548	where
549		Self: Sized,
550	{
551		Ok(match u8::deserialize(&mut buf, mode)? {
552			0 => Self::Committed,
553			1 => Self::Transparent(DeserializeBytes::deserialize(buf, mode)?),
554			2 => Self::Repeating {
555				id: DeserializeBytes::deserialize(&mut buf, mode)?,
556				log_count: DeserializeBytes::deserialize(buf, mode)?,
557			},
558			3 => Self::Projected(DeserializeBytes::deserialize(buf, mode)?),
559			4 => Self::Shifted(DeserializeBytes::deserialize(buf, mode)?),
560			5 => Self::Packed(DeserializeBytes::deserialize(buf, mode)?),
561			6 => Self::LinearCombination(DeserializeBytes::deserialize(buf, mode)?),
562			7 => Self::ZeroPadded(DeserializeBytes::deserialize(buf, mode)?),
563			variant_index => {
564				return Err(SerializationError::UnknownEnumVariant {
565					name: "MultilinearPolyVariant",
566					index: variant_index,
567				});
568			}
569		})
570	}
571}
572
573/// A transparent multilinear polynomial oracle.
574///
575/// See the [`MultilinearPolyOracle`] documentation for context.
576#[derive(Debug, Clone, Getters, CopyGetters)]
577pub struct TransparentPolyOracle<F: Field> {
578	#[get = "pub"]
579	poly: Arc<dyn MultivariatePoly<F>>,
580}
581
582impl<F: TowerField> SerializeBytes for TransparentPolyOracle<F> {
583	fn serialize(
584		&self,
585		mut write_buf: impl bytes::BufMut,
586		mode: SerializationMode,
587	) -> Result<(), SerializationError> {
588		self.poly.erased_serialize(&mut write_buf, mode)
589	}
590}
591
592impl DeserializeBytes for TransparentPolyOracle<BinaryField128b> {
593	fn deserialize(
594		read_buf: impl bytes::Buf,
595		mode: SerializationMode,
596	) -> Result<Self, SerializationError>
597	where
598		Self: Sized,
599	{
600		Ok(Self {
601			poly: Box::<dyn MultivariatePoly<BinaryField128b>>::deserialize(read_buf, mode)?.into(),
602		})
603	}
604}
605
606impl<F: TowerField> TransparentPolyOracle<F> {
607	fn new(poly: Arc<dyn MultivariatePoly<F>>) -> Result<Self, Error> {
608		if poly.binary_tower_level() > F::TOWER_LEVEL {
609			bail!(Error::TowerLevelTooHigh {
610				tower_level: poly.binary_tower_level(),
611			});
612		}
613		Ok(Self { poly })
614	}
615}
616
617impl<F: Field> TransparentPolyOracle<F> {
618	/// Maximum tower level of the oracle's values over the boolean hypercube.
619	pub fn binary_tower_level(&self) -> usize {
620		self.poly.binary_tower_level()
621	}
622}
623
624impl<F: Field> PartialEq for TransparentPolyOracle<F> {
625	fn eq(&self, other: &Self) -> bool {
626		Arc::ptr_eq(&self.poly, &other.poly)
627	}
628}
629
630impl<F: Field> Eq for TransparentPolyOracle<F> {}
631
632#[derive(Debug, Copy, Clone, PartialEq, Eq, SerializeBytes, DeserializeBytes)]
633pub enum ProjectionVariant {
634	FirstVars,
635	LastVars,
636}
637
638#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)]
639pub struct Projected<F: TowerField> {
640	#[get_copy = "pub"]
641	id: OracleId,
642	#[get = "pub"]
643	values: Vec<F>,
644	#[get_copy = "pub"]
645	projection_variant: ProjectionVariant,
646}
647
648impl<F: TowerField> Projected<F> {
649	fn new(
650		oracle: &MultilinearPolyOracle<F>,
651		values: Vec<F>,
652		projection_variant: ProjectionVariant,
653	) -> Result<Self, Error> {
654		if values.len() > oracle.n_vars() {
655			bail!(Error::InvalidProjection {
656				n_vars: oracle.n_vars(),
657				values_len: values.len()
658			});
659		}
660		Ok(Self {
661			id: oracle.id(),
662			values,
663			projection_variant,
664		})
665	}
666}
667
668#[derive(Debug, Copy, Clone, PartialEq, Eq, SerializeBytes, DeserializeBytes)]
669pub enum ShiftVariant {
670	CircularLeft,
671	LogicalLeft,
672	LogicalRight,
673}
674
675#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)]
676pub struct Shifted {
677	#[get_copy = "pub"]
678	id: OracleId,
679	#[get_copy = "pub"]
680	shift_offset: usize,
681	#[get_copy = "pub"]
682	block_size: usize,
683	#[get_copy = "pub"]
684	shift_variant: ShiftVariant,
685}
686
687impl Shifted {
688	fn new<F: TowerField>(
689		oracle: &MultilinearPolyOracle<F>,
690		shift_offset: usize,
691		block_size: usize,
692		shift_variant: ShiftVariant,
693	) -> Result<Self, Error> {
694		if block_size > oracle.n_vars() {
695			bail!(PolynomialError::InvalidBlockSize {
696				n_vars: oracle.n_vars(),
697			});
698		}
699
700		if shift_offset == 0 || shift_offset >= 1 << block_size {
701			bail!(PolynomialError::InvalidShiftOffset {
702				max_shift_offset: (1 << block_size) - 1,
703				shift_offset,
704			});
705		}
706
707		Ok(Self {
708			id: oracle.id(),
709			shift_offset,
710			block_size,
711			shift_variant,
712		})
713	}
714}
715
/// An oracle that packs the values of another oracle into a higher tower level.
///
/// Compared with the inner oracle, a packed oracle has `log_degree` fewer variables and a tower
/// level that is `log_degree` higher (see [`MultilinearOracleSetAddition::packed`]).
#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)]
pub struct Packed {
	/// Id of the inner oracle being packed.
	#[get_copy = "pub"]
	id: OracleId,
	/// The number of tower levels increased by the packing operation.
	///
	/// This is the base 2 logarithm of the field extension, and is called $\kappa$ in [DP23],
	/// Section 4.3.
	///
	/// [DP23]: https://eprint.iacr.org/2023/1784
	#[get_copy = "pub"]
	log_degree: usize,
}
729
730#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)]
731pub struct LinearCombination<F: TowerField> {
732	#[get_copy = "pub"]
733	n_vars: usize,
734	#[get_copy = "pub"]
735	offset: F,
736	inner: Vec<(OracleId, F)>,
737}
738
739impl<F: TowerField> LinearCombination<F> {
740	fn new(
741		n_vars: usize,
742		offset: F,
743		inner: impl IntoIterator<Item = (MultilinearPolyOracle<F>, F)>,
744	) -> Result<Self, Error> {
745		let inner = inner
746			.into_iter()
747			.map(|(oracle, value)| {
748				if oracle.n_vars() == n_vars {
749					Ok((oracle.id(), value))
750				} else {
751					Err(Error::IncorrectNumberOfVariables { expected: n_vars })
752				}
753			})
754			.collect::<Result<Vec<_>, _>>()?;
755		Ok(Self {
756			n_vars,
757			offset,
758			inner,
759		})
760	}
761
762	pub fn n_polys(&self) -> usize {
763		self.inner.len()
764	}
765
766	pub fn polys(&self) -> impl Iterator<Item = OracleId> + '_ {
767		self.inner.iter().map(|(id, _)| *id)
768	}
769
770	pub fn coefficients(&self) -> impl Iterator<Item = F> + '_ {
771		self.inner.iter().map(|(_, coeff)| *coeff)
772	}
773}
774
775/// MLE of a multivariate polynomial evaluated on multilinear oracles.
776///
777/// i.e. the MLE of the evaluations of $C(M_1, M_2, \ldots, M_n)$ on $\{0, 1\}^\mu$ where:
778/// - $C$ is a arbitrary polynomial in $n$ variables
779/// - $M_1, M_2, \ldots, M_n$ are multilinear oracles in μ variables
780///
781/// ($C$ should be sufficiently lightweight to be evaluated by the verifier)
782#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes)]
783pub struct CompositeMLE<F: TowerField> {
784	/// $\mu$
785	#[get_copy = "pub"]
786	n_vars: usize,
787	/// $M_1, M_2, \ldots, M_n$
788	#[getset(get = "pub")]
789	inner: Vec<OracleId>,
790	/// $C$
791	#[getset(get = "pub")]
792	c: ArithCircuitPoly<F>,
793}
794
795impl<F: TowerField> CompositeMLE<F> {
796	pub fn new(
797		n_vars: usize,
798		inner: impl IntoIterator<Item = MultilinearPolyOracle<F>>,
799		c: ArithExpr<F>,
800	) -> Result<Self, Error> {
801		let inner = inner
802			.into_iter()
803			.map(|oracle| {
804				if oracle.n_vars() == n_vars {
805					Ok(oracle.id())
806				} else {
807					Err(Error::IncorrectNumberOfVariables { expected: n_vars })
808				}
809			})
810			.collect::<Result<Vec<_>, _>>()?;
811		let c = ArithCircuitPoly::with_n_vars(inner.len(), c)
812			.map_err(|_| Error::CompositionMismatch)?; // occurs if `c` has more variables than `inner.len()`
813		Ok(Self { n_vars, inner, c })
814	}
815
816	pub fn polys(&self) -> impl Iterator<Item = OracleId> + '_ {
817		self.inner.iter().copied()
818	}
819
820	pub fn n_polys(&self) -> usize {
821		self.inner.len()
822	}
823}
824
825impl<F: TowerField> MultilinearPolyOracle<F> {
826	pub const fn id(&self) -> OracleId {
827		self.id
828	}
829
830	pub fn label(&self) -> String {
831		match self.name() {
832			Some(name) => format!("{}: {}", self.type_str(), name),
833			None => format!("{}: id={}", self.type_str(), self.id()),
834		}
835	}
836
837	pub fn name(&self) -> Option<&str> {
838		self.name.as_deref()
839	}
840
841	const fn type_str(&self) -> &str {
842		match self.variant {
843			MultilinearPolyVariant::Transparent(_) => "Transparent",
844			MultilinearPolyVariant::Committed => "Committed",
845			MultilinearPolyVariant::Repeating { .. } => "Repeating",
846			MultilinearPolyVariant::Projected(_) => "Projected",
847			MultilinearPolyVariant::Shifted(_) => "Shifted",
848			MultilinearPolyVariant::Packed(_) => "Packed",
849			MultilinearPolyVariant::LinearCombination(_) => "LinearCombination",
850			MultilinearPolyVariant::ZeroPadded(_) => "ZeroPadded",
851			MultilinearPolyVariant::Composite(_) => "CompositeMLE",
852		}
853	}
854
855	pub const fn n_vars(&self) -> usize {
856		self.n_vars
857	}
858
859	/// Maximum tower level of the oracle's values over the boolean hypercube.
860	pub const fn binary_tower_level(&self) -> usize {
861		self.tower_level
862	}
863
864	pub fn into_composite(self) -> CompositePolyOracle<F> {
865		let composite =
866			CompositePolyOracle::new(self.n_vars(), vec![self], IdentityCompositionPoly);
867		composite.expect("Can always apply the identity composition to one variable")
868	}
869}
870
#[cfg(test)]
mod tests {
	use binius_field::{BinaryField128b, BinaryField1b, Field, TowerField};

	use super::{MultilinearOracleSet, MultilinearPolyVariant, ProjectionVariant};

	/// Projecting away every variable yields a zero-variable oracle that records its inputs.
	#[test]
	fn add_projection_with_all_vars() {
		type F = BinaryField128b;
		let mut oracles = MultilinearOracleSet::<F>::new();
		let data = oracles.add_committed(5, BinaryField1b::TOWER_LEVEL);
		let projected = oracles
			.add_projected(data, vec![F::ONE; 5], ProjectionVariant::FirstVars)
			.unwrap();

		let oracle = oracles.oracle(projected);
		assert_eq!(oracle.n_vars(), 0);
		match oracle.variant {
			MultilinearPolyVariant::Projected(ref p) => {
				assert_eq!(p.id(), data);
				assert_eq!(p.values().len(), 5);
				assert_eq!(p.projection_variant(), ProjectionVariant::FirstVars);
			}
			ref other => panic!("expected a projected oracle, got {other:?}"),
		}
	}

	/// Supplying more values than the inner oracle has variables is rejected.
	#[test]
	fn add_projection_with_too_many_values() {
		type F = BinaryField128b;
		let mut oracles = MultilinearOracleSet::<F>::new();
		let data = oracles.add_committed(5, BinaryField1b::TOWER_LEVEL);
		let result = oracles.add_projected(data, vec![F::ONE; 6], ProjectionVariant::LastVars);
		assert!(result.is_err());
		// A rejected addition must not append anything to the set.
		assert_eq!(oracles.size(), 1);
	}
}