binius_core/oracle/
multilinear.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{array, fmt::Debug, sync::Arc};
4
5use binius_field::{BinaryField128b, Field, TowerField};
6use binius_macros::{DeserializeBytes, SerializeBytes};
7use binius_math::ArithExpr;
8use binius_utils::{bail, DeserializeBytes, SerializationError, SerializationMode, SerializeBytes};
9use bytes::Buf;
10use getset::{CopyGetters, Getters};
11
12use crate::{
13	oracle::{CompositePolyOracle, Error},
14	polynomial::{
15		ArithCircuitPoly, Error as PolynomialError, IdentityCompositionPoly, MultivariatePoly,
16	},
17};
18
/// Index of a multilinear oracle within a [`MultilinearOracleSet`].
///
/// IDs are assigned sequentially: the set hands out its current length as the ID of each newly
/// added oracle.
pub type OracleId = usize;
21
22/// Meta struct that lets you add optional `name` for the Multilinear before adding to the
23/// [`MultilinearOracleSet`]
24pub struct MultilinearOracleSetAddition<'a, F: TowerField> {
25	name: Option<String>,
26	mut_ref: &'a mut MultilinearOracleSet<F>,
27}
28
29impl<F: TowerField> MultilinearOracleSetAddition<'_, F> {
30	pub fn transparent(self, poly: impl MultivariatePoly<F> + 'static) -> Result<OracleId, Error> {
31		if poly.binary_tower_level() > F::TOWER_LEVEL {
32			bail!(Error::TowerLevelTooHigh {
33				tower_level: poly.binary_tower_level(),
34			});
35		}
36
37		let inner = TransparentPolyOracle::new(Arc::new(poly))?;
38
39		let oracle = |id: OracleId| MultilinearPolyOracle {
40			id,
41			n_vars: inner.poly.n_vars(),
42			tower_level: inner.poly.binary_tower_level(),
43			name: self.name,
44			variant: MultilinearPolyVariant::Transparent(inner),
45		};
46
47		Ok(self.mut_ref.add_to_set(oracle))
48	}
49
50	pub fn committed(mut self, n_vars: usize, tower_level: usize) -> OracleId {
51		let name = self.name.take();
52		self.add_committed_with_name(n_vars, tower_level, name)
53	}
54
55	pub fn committed_multiple<const N: usize>(
56		mut self,
57		n_vars: usize,
58		tower_level: usize,
59	) -> [OracleId; N] {
60		match &self.name.take() {
61			None => [0; N].map(|_| self.add_committed_with_name(n_vars, tower_level, None)),
62			Some(s) => {
63				let x: [usize; N] = array::from_fn(|i| i);
64				x.map(|i| {
65					self.add_committed_with_name(n_vars, tower_level, Some(format!("{}_{}", s, i)))
66				})
67			}
68		}
69	}
70
71	pub fn repeating(self, inner_id: OracleId, log_count: usize) -> Result<OracleId, Error> {
72		if inner_id >= self.mut_ref.oracles.len() {
73			bail!(Error::InvalidOracleId(inner_id));
74		}
75
76		let inner = self.mut_ref.get_from_set(inner_id);
77
78		let oracle = |id: OracleId| MultilinearPolyOracle {
79			id,
80			n_vars: inner.n_vars + log_count,
81			tower_level: inner.tower_level,
82			name: self.name,
83			variant: MultilinearPolyVariant::Repeating {
84				id: inner_id,
85				log_count,
86			},
87		};
88
89		Ok(self.mut_ref.add_to_set(oracle))
90	}
91
92	pub fn shifted(
93		self,
94		inner_id: OracleId,
95		offset: usize,
96		block_bits: usize,
97		variant: ShiftVariant,
98	) -> Result<OracleId, Error> {
99		if inner_id >= self.mut_ref.oracles.len() {
100			bail!(Error::InvalidOracleId(inner_id));
101		}
102
103		let inner = self.mut_ref.get_from_set(inner_id);
104		if block_bits > inner.n_vars {
105			bail!(PolynomialError::InvalidBlockSize {
106				n_vars: inner.n_vars,
107			});
108		}
109
110		if offset == 0 || offset >= 1 << block_bits {
111			bail!(PolynomialError::InvalidShiftOffset {
112				max_shift_offset: (1 << block_bits) - 1,
113				shift_offset: offset,
114			});
115		}
116
117		let shifted = Shifted::new(&inner, offset, block_bits, variant)?;
118
119		let oracle = |id: OracleId| MultilinearPolyOracle {
120			id,
121			n_vars: inner.n_vars,
122			tower_level: inner.tower_level,
123			name: self.name,
124			variant: MultilinearPolyVariant::Shifted(shifted),
125		};
126
127		Ok(self.mut_ref.add_to_set(oracle))
128	}
129
130	pub fn packed(self, inner_id: OracleId, log_degree: usize) -> Result<OracleId, Error> {
131		if inner_id >= self.mut_ref.oracles.len() {
132			bail!(Error::InvalidOracleId(inner_id));
133		}
134
135		let inner_n_vars = self.mut_ref.n_vars(inner_id);
136		if log_degree > inner_n_vars {
137			bail!(Error::NotEnoughVarsForPacking {
138				n_vars: inner_n_vars,
139				log_degree,
140			});
141		}
142
143		let inner_tower_level = self.mut_ref.tower_level(inner_id);
144
145		let packed = Packed {
146			id: inner_id,
147			log_degree,
148		};
149
150		let oracle = |id: OracleId| MultilinearPolyOracle {
151			id,
152			n_vars: inner_n_vars - log_degree,
153			tower_level: inner_tower_level + log_degree,
154			name: self.name,
155			variant: MultilinearPolyVariant::Packed(packed),
156		};
157
158		Ok(self.mut_ref.add_to_set(oracle))
159	}
160
161	pub fn projected(
162		self,
163		inner_id: OracleId,
164		values: Vec<F>,
165		start_index: usize,
166	) -> Result<OracleId, Error> {
167		let inner_n_vars = self.mut_ref.n_vars(inner_id);
168		let values_len = values.len();
169		if values_len > inner_n_vars {
170			bail!(Error::InvalidProjection {
171				n_vars: inner_n_vars,
172				values_len,
173			});
174		}
175
176		let inner = self.mut_ref.get_from_set(inner_id);
177		// TODO: This is wrong, should be F::TOWER_LEVEL
178		let tower_level = inner.binary_tower_level();
179		let projected = Projected::new(&inner, values, start_index)?;
180
181		let oracle = |id: OracleId| MultilinearPolyOracle {
182			id,
183			n_vars: inner_n_vars - values_len,
184			tower_level,
185			name: self.name,
186			variant: MultilinearPolyVariant::Projected(projected),
187		};
188
189		Ok(self.mut_ref.add_to_set(oracle))
190	}
191
192	pub fn projected_last_vars(
193		self,
194		inner_id: OracleId,
195		values: Vec<F>,
196	) -> Result<OracleId, Error> {
197		let inner_n_vars = self.mut_ref.n_vars(inner_id);
198		let start_index = inner_n_vars - values.len();
199		let values_len = values.len();
200		if values_len > inner_n_vars {
201			bail!(Error::InvalidProjection {
202				n_vars: inner_n_vars,
203				values_len,
204			});
205		}
206
207		let inner = self.mut_ref.get_from_set(inner_id);
208		// TODO: This is wrong, should be F::TOWER_LEVEL
209		let tower_level = inner.binary_tower_level();
210		let projected = Projected::new(&inner, values, start_index)?;
211
212		let oracle = |id: OracleId| MultilinearPolyOracle {
213			id,
214			n_vars: inner_n_vars - values_len,
215			tower_level,
216			name: self.name,
217			variant: MultilinearPolyVariant::Projected(projected),
218		};
219
220		Ok(self.mut_ref.add_to_set(oracle))
221	}
222
223	pub fn linear_combination(
224		self,
225		n_vars: usize,
226		inner: impl IntoIterator<Item = (OracleId, F)>,
227	) -> Result<OracleId, Error> {
228		self.linear_combination_with_offset(n_vars, F::ZERO, inner)
229	}
230
231	pub fn linear_combination_with_offset(
232		self,
233		n_vars: usize,
234		offset: F,
235		inner: impl IntoIterator<Item = (OracleId, F)>,
236	) -> Result<OracleId, Error> {
237		let inner = inner
238			.into_iter()
239			.map(|(inner_id, coeff)| {
240				if inner_id >= self.mut_ref.oracles.len() {
241					return Err(Error::InvalidOracleId(inner_id));
242				}
243				if self.mut_ref.n_vars(inner_id) != n_vars {
244					return Err(Error::IncorrectNumberOfVariables { expected: n_vars });
245				}
246				Ok((self.mut_ref.get_from_set(inner_id), coeff))
247			})
248			.collect::<Result<Vec<_>, _>>()?;
249
250		let tower_level = inner
251			.iter()
252			.map(|(oracle, coeff)| oracle.binary_tower_level().max(coeff.min_tower_level()))
253			.max()
254			.unwrap_or(0)
255			.max(offset.min_tower_level());
256
257		let linear_combination = LinearCombination::new(n_vars, offset, inner)?;
258
259		let oracle = |id: OracleId| MultilinearPolyOracle {
260			id,
261			n_vars,
262			tower_level,
263			name: self.name,
264			variant: MultilinearPolyVariant::LinearCombination(linear_combination),
265		};
266
267		Ok(self.mut_ref.add_to_set(oracle))
268	}
269
270	pub fn composite_mle(
271		self,
272		n_vars: usize,
273		inner: impl IntoIterator<Item = OracleId>,
274		comp: ArithExpr<F>,
275	) -> Result<OracleId, Error> {
276		let inner = inner
277			.into_iter()
278			.map(|inner_id| {
279				if inner_id >= self.mut_ref.oracles.len() {
280					return Err(Error::InvalidOracleId(inner_id));
281				}
282				if self.mut_ref.n_vars(inner_id) != n_vars {
283					return Err(Error::IncorrectNumberOfVariables { expected: n_vars });
284				}
285				Ok(self.mut_ref.get_from_set(inner_id))
286			})
287			.collect::<Result<Vec<_>, _>>()?;
288
289		let tower_level = inner
290			.iter()
291			.map(|oracle| oracle.binary_tower_level())
292			.max()
293			.unwrap_or(0);
294
295		let composite_mle = CompositeMLE::new(n_vars, inner, comp)?;
296
297		let oracle = |id: OracleId| MultilinearPolyOracle {
298			id,
299			n_vars,
300			tower_level,
301			name: self.name,
302			variant: MultilinearPolyVariant::Composite(composite_mle),
303		};
304
305		Ok(self.mut_ref.add_to_set(oracle))
306	}
307
308	pub fn zero_padded(self, inner_id: OracleId, n_vars: usize) -> Result<OracleId, Error> {
309		if inner_id >= self.mut_ref.oracles.len() {
310			bail!(Error::InvalidOracleId(inner_id));
311		}
312
313		if self.mut_ref.n_vars(inner_id) > n_vars {
314			bail!(Error::IncorrectNumberOfVariables {
315				expected: self.mut_ref.n_vars(inner_id),
316			});
317		};
318
319		let inner = self.mut_ref.get_from_set(inner_id);
320
321		let oracle = |id: OracleId| MultilinearPolyOracle {
322			id,
323			n_vars,
324			tower_level: inner.tower_level,
325			name: self.name,
326			variant: MultilinearPolyVariant::ZeroPadded(inner_id),
327		};
328
329		Ok(self.mut_ref.add_to_set(oracle))
330	}
331
332	fn add_committed_with_name(
333		&mut self,
334		n_vars: usize,
335		tower_level: usize,
336		name: Option<String>,
337	) -> OracleId {
338		let oracle = |oracle_id: OracleId| MultilinearPolyOracle {
339			id: oracle_id,
340			n_vars,
341			tower_level,
342			name: name.clone(),
343			variant: MultilinearPolyVariant::Committed,
344		};
345
346		self.mut_ref.add_to_set(oracle)
347	}
348}
349
350/// An ordered set of multilinear polynomial oracles.
351///
352/// The multilinear polynomial oracles form a directed acyclic graph, where each multilinear oracle
353/// is either transparent, committed, or derived from one or more others. Each oracle is assigned a
354/// unique `OracleId`.
355///
356/// The oracle set also tracks the committed polynomial in batches where each batch is committed
357/// together with a polynomial commitment scheme.
358#[derive(Default, Debug, Clone, SerializeBytes)]
359pub struct MultilinearOracleSet<F: TowerField> {
360	oracles: Vec<MultilinearPolyOracle<F>>,
361}
362
363impl DeserializeBytes for MultilinearOracleSet<BinaryField128b> {
364	fn deserialize(read_buf: impl Buf, mode: SerializationMode) -> Result<Self, SerializationError>
365	where
366		Self: Sized,
367	{
368		Ok(Self {
369			oracles: DeserializeBytes::deserialize(read_buf, mode)?,
370		})
371	}
372}
373
374impl<F: TowerField> MultilinearOracleSet<F> {
375	pub const fn new() -> Self {
376		Self {
377			oracles: Vec::new(),
378		}
379	}
380
381	pub fn size(&self) -> usize {
382		self.oracles.len()
383	}
384
385	pub fn iter(&self) -> impl Iterator<Item = MultilinearPolyOracle<F>> + '_ {
386		(0..self.oracles.len()).map(|id| self.oracle(id))
387	}
388
389	pub const fn add(&mut self) -> MultilinearOracleSetAddition<F> {
390		MultilinearOracleSetAddition {
391			name: None,
392			mut_ref: self,
393		}
394	}
395
396	pub fn add_named<S: ToString>(&mut self, s: S) -> MultilinearOracleSetAddition<F> {
397		MultilinearOracleSetAddition {
398			name: Some(s.to_string()),
399			mut_ref: self,
400		}
401	}
402
403	pub fn is_valid_oracle_id(&self, id: OracleId) -> bool {
404		id < self.oracles.len()
405	}
406
407	fn add_to_set(
408		&mut self,
409		oracle: impl FnOnce(OracleId) -> MultilinearPolyOracle<F>,
410	) -> OracleId {
411		let id = self.oracles.len();
412		self.oracles.push(oracle(id));
413		id
414	}
415
416	fn get_from_set(&self, id: OracleId) -> MultilinearPolyOracle<F> {
417		self.oracles[id].clone()
418	}
419
420	pub fn add_transparent(
421		&mut self,
422		poly: impl MultivariatePoly<F> + 'static,
423	) -> Result<OracleId, Error> {
424		self.add().transparent(poly)
425	}
426
427	pub fn add_committed(&mut self, n_vars: usize, tower_level: usize) -> OracleId {
428		self.add().committed(n_vars, tower_level)
429	}
430
431	pub fn add_committed_multiple<const N: usize>(
432		&mut self,
433		n_vars: usize,
434		tower_level: usize,
435	) -> [OracleId; N] {
436		self.add().committed_multiple(n_vars, tower_level)
437	}
438
439	pub fn add_repeating(&mut self, id: OracleId, log_count: usize) -> Result<OracleId, Error> {
440		self.add().repeating(id, log_count)
441	}
442
443	pub fn add_shifted(
444		&mut self,
445		id: OracleId,
446		offset: usize,
447		block_bits: usize,
448		variant: ShiftVariant,
449	) -> Result<OracleId, Error> {
450		self.add().shifted(id, offset, block_bits, variant)
451	}
452
453	pub fn add_packed(&mut self, id: OracleId, log_degree: usize) -> Result<OracleId, Error> {
454		self.add().packed(id, log_degree)
455	}
456
457	/// Adds a projection to the variables starting at `start_index`.
458	pub fn add_projected(
459		&mut self,
460		id: OracleId,
461		values: Vec<F>,
462		start_index: usize,
463	) -> Result<OracleId, Error> {
464		self.add().projected(id, values, start_index)
465	}
466
467	/// Adds a projection to the last variables.
468	pub fn add_projected_last_vars(
469		&mut self,
470		id: OracleId,
471		values: Vec<F>,
472	) -> Result<OracleId, Error> {
473		self.add().projected_last_vars(id, values)
474	}
475
476	pub fn add_linear_combination(
477		&mut self,
478		n_vars: usize,
479		inner: impl IntoIterator<Item = (OracleId, F)>,
480	) -> Result<OracleId, Error> {
481		self.add().linear_combination(n_vars, inner)
482	}
483
484	pub fn add_linear_combination_with_offset(
485		&mut self,
486		n_vars: usize,
487		offset: F,
488		inner: impl IntoIterator<Item = (OracleId, F)>,
489	) -> Result<OracleId, Error> {
490		self.add()
491			.linear_combination_with_offset(n_vars, offset, inner)
492	}
493
494	pub fn add_zero_padded(&mut self, id: OracleId, n_vars: usize) -> Result<OracleId, Error> {
495		self.add().zero_padded(id, n_vars)
496	}
497
498	pub fn add_composite_mle(
499		&mut self,
500		n_vars: usize,
501		inner: impl IntoIterator<Item = OracleId>,
502		comp: ArithExpr<F>,
503	) -> Result<OracleId, Error> {
504		self.add().composite_mle(n_vars, inner, comp)
505	}
506
507	pub fn oracle(&self, id: OracleId) -> MultilinearPolyOracle<F> {
508		self.oracles[id].clone()
509	}
510
511	pub fn n_vars(&self, id: OracleId) -> usize {
512		self.oracles[id].n_vars()
513	}
514
515	pub fn label(&self, id: OracleId) -> String {
516		self.oracles[id].label()
517	}
518
519	/// Maximum tower level of the oracle's values over the boolean hypercube.
520	pub fn tower_level(&self, id: OracleId) -> usize {
521		self.oracles[id].binary_tower_level()
522	}
523}
524
525/// A multilinear polynomial oracle in the polynomial IOP model.
526///
527/// In the multilinear polynomial IOP model, a prover sends multilinear polynomials to an oracle,
528/// and the verifier may at the end of the protocol query their evaluations at chosen points. An
529/// oracle is a verifier and prover's shared view of a polynomial that can be queried for
530/// evaluations by the verifier.
531///
532/// There are three fundamental categories of oracles:
533///
534/// 1. *Transparent oracles*. These are multilinear polynomials with a succinct description and
535///    evaluation algorithm that are known to the verifier. When the verifier queries a transparent
536///    oracle, it evaluates the polynomial itself.
537/// 2. *Committed oracles*. These are polynomials actually sent by the prover. When the polynomial
538///    IOP is compiled to an interactive protocol, these polynomial are committed with a polynomial
539///    commitment scheme.
540/// 3. *Virtual oracles*. A virtual multilinear oracle is not actually sent by the prover, but
541///    instead admits an interactive reduction for evaluation queries to evaluation queries to
542///    other oracles. This is formalized in [DP23] Section 4.
543///
544/// [DP23]: <https://eprint.iacr.org/2023/1784>
545#[derive(Debug, Clone, PartialEq, Eq, SerializeBytes)]
546pub struct MultilinearPolyOracle<F: TowerField> {
547	pub id: OracleId,
548	pub name: Option<String>,
549	pub n_vars: usize,
550	pub tower_level: usize,
551	pub variant: MultilinearPolyVariant<F>,
552}
553
554impl DeserializeBytes for MultilinearPolyOracle<BinaryField128b> {
555	fn deserialize(
556		mut read_buf: impl bytes::Buf,
557		mode: SerializationMode,
558	) -> Result<Self, SerializationError>
559	where
560		Self: Sized,
561	{
562		Ok(Self {
563			id: DeserializeBytes::deserialize(&mut read_buf, mode)?,
564			name: DeserializeBytes::deserialize(&mut read_buf, mode)?,
565			n_vars: DeserializeBytes::deserialize(&mut read_buf, mode)?,
566			tower_level: DeserializeBytes::deserialize(&mut read_buf, mode)?,
567			variant: DeserializeBytes::deserialize(&mut read_buf, mode)?,
568		})
569	}
570}
571
572#[derive(Debug, Clone, PartialEq, Eq, SerializeBytes)]
573pub enum MultilinearPolyVariant<F: TowerField> {
574	Committed,
575	Transparent(TransparentPolyOracle<F>),
576	Repeating { id: usize, log_count: usize },
577	Projected(Projected<F>),
578	Shifted(Shifted),
579	Packed(Packed),
580	LinearCombination(LinearCombination<F>),
581	ZeroPadded(OracleId),
582	Composite(CompositeMLE<F>),
583}
584
585impl DeserializeBytes for MultilinearPolyVariant<BinaryField128b> {
586	fn deserialize(
587		mut buf: impl bytes::Buf,
588		mode: SerializationMode,
589	) -> Result<Self, SerializationError>
590	where
591		Self: Sized,
592	{
593		Ok(match u8::deserialize(&mut buf, mode)? {
594			0 => Self::Committed,
595			1 => Self::Transparent(DeserializeBytes::deserialize(buf, mode)?),
596			2 => Self::Repeating {
597				id: DeserializeBytes::deserialize(&mut buf, mode)?,
598				log_count: DeserializeBytes::deserialize(buf, mode)?,
599			},
600			3 => Self::Projected(DeserializeBytes::deserialize(buf, mode)?),
601			4 => Self::Shifted(DeserializeBytes::deserialize(buf, mode)?),
602			5 => Self::Packed(DeserializeBytes::deserialize(buf, mode)?),
603			6 => Self::LinearCombination(DeserializeBytes::deserialize(buf, mode)?),
604			7 => Self::ZeroPadded(DeserializeBytes::deserialize(buf, mode)?),
605			variant_index => {
606				return Err(SerializationError::UnknownEnumVariant {
607					name: "MultilinearPolyVariant",
608					index: variant_index,
609				});
610			}
611		})
612	}
613}
614
615/// A transparent multilinear polynomial oracle.
616///
617/// See the [`MultilinearPolyOracle`] documentation for context.
618#[derive(Debug, Clone, Getters, CopyGetters)]
619pub struct TransparentPolyOracle<F: Field> {
620	#[get = "pub"]
621	poly: Arc<dyn MultivariatePoly<F>>,
622}
623
624impl<F: TowerField> SerializeBytes for TransparentPolyOracle<F> {
625	fn serialize(
626		&self,
627		mut write_buf: impl bytes::BufMut,
628		mode: SerializationMode,
629	) -> Result<(), SerializationError> {
630		self.poly.erased_serialize(&mut write_buf, mode)
631	}
632}
633
634impl DeserializeBytes for TransparentPolyOracle<BinaryField128b> {
635	fn deserialize(
636		read_buf: impl bytes::Buf,
637		mode: SerializationMode,
638	) -> Result<Self, SerializationError>
639	where
640		Self: Sized,
641	{
642		Ok(Self {
643			poly: Box::<dyn MultivariatePoly<BinaryField128b>>::deserialize(read_buf, mode)?.into(),
644		})
645	}
646}
647
648impl<F: TowerField> TransparentPolyOracle<F> {
649	fn new(poly: Arc<dyn MultivariatePoly<F>>) -> Result<Self, Error> {
650		if poly.binary_tower_level() > F::TOWER_LEVEL {
651			bail!(Error::TowerLevelTooHigh {
652				tower_level: poly.binary_tower_level(),
653			});
654		}
655		Ok(Self { poly })
656	}
657}
658
659impl<F: Field> TransparentPolyOracle<F> {
660	/// Maximum tower level of the oracle's values over the boolean hypercube.
661	pub fn binary_tower_level(&self) -> usize {
662		self.poly.binary_tower_level()
663	}
664}
665
666impl<F: Field> PartialEq for TransparentPolyOracle<F> {
667	fn eq(&self, other: &Self) -> bool {
668		Arc::ptr_eq(&self.poly, &other.poly)
669	}
670}
671
672impl<F: Field> Eq for TransparentPolyOracle<F> {}
673
674#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)]
675pub struct Projected<F: TowerField> {
676	#[get_copy = "pub"]
677	id: OracleId,
678	#[get = "pub"]
679	values: Vec<F>,
680	#[get_copy = "pub"]
681	start_index: usize,
682}
683
684impl<F: TowerField> Projected<F> {
685	fn new(
686		oracle: &MultilinearPolyOracle<F>,
687		values: Vec<F>,
688		start_index: usize,
689	) -> Result<Self, Error> {
690		if values.len() + start_index > oracle.n_vars() {
691			bail!(Error::InvalidProjection {
692				n_vars: oracle.n_vars(),
693				values_len: values.len()
694			});
695		}
696		Ok(Self {
697			id: oracle.id(),
698			values,
699			start_index,
700		})
701	}
702}
703
704#[derive(Debug, Copy, Clone, PartialEq, Eq, SerializeBytes, DeserializeBytes)]
705pub enum ShiftVariant {
706	CircularLeft,
707	LogicalLeft,
708	LogicalRight,
709}
710
711#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)]
712pub struct Shifted {
713	#[get_copy = "pub"]
714	id: OracleId,
715	#[get_copy = "pub"]
716	shift_offset: usize,
717	#[get_copy = "pub"]
718	block_size: usize,
719	#[get_copy = "pub"]
720	shift_variant: ShiftVariant,
721}
722
723impl Shifted {
724	fn new<F: TowerField>(
725		oracle: &MultilinearPolyOracle<F>,
726		shift_offset: usize,
727		block_size: usize,
728		shift_variant: ShiftVariant,
729	) -> Result<Self, Error> {
730		if block_size > oracle.n_vars() {
731			bail!(PolynomialError::InvalidBlockSize {
732				n_vars: oracle.n_vars(),
733			});
734		}
735
736		if shift_offset == 0 || shift_offset >= 1 << block_size {
737			bail!(PolynomialError::InvalidShiftOffset {
738				max_shift_offset: (1 << block_size) - 1,
739				shift_offset,
740			});
741		}
742
743		Ok(Self {
744			id: oracle.id(),
745			shift_offset,
746			block_size,
747			shift_variant,
748		})
749	}
750}
751
752#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)]
753pub struct Packed {
754	#[get_copy = "pub"]
755	id: OracleId,
756	/// The number of tower levels increased by the packing operation.
757	///
758	/// This is the base 2 logarithm of the field extension, and is called $\kappa$ in [DP23],
759	/// Section 4.3.
760	///
761	/// [DP23]: https://eprint.iacr.org/2023/1784
762	#[get_copy = "pub"]
763	log_degree: usize,
764}
765
766#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)]
767pub struct LinearCombination<F: TowerField> {
768	#[get_copy = "pub"]
769	n_vars: usize,
770	#[get_copy = "pub"]
771	offset: F,
772	inner: Vec<(OracleId, F)>,
773}
774
775impl<F: TowerField> LinearCombination<F> {
776	fn new(
777		n_vars: usize,
778		offset: F,
779		inner: impl IntoIterator<Item = (MultilinearPolyOracle<F>, F)>,
780	) -> Result<Self, Error> {
781		let inner = inner
782			.into_iter()
783			.map(|(oracle, value)| {
784				if oracle.n_vars() == n_vars {
785					Ok((oracle.id(), value))
786				} else {
787					Err(Error::IncorrectNumberOfVariables { expected: n_vars })
788				}
789			})
790			.collect::<Result<Vec<_>, _>>()?;
791		Ok(Self {
792			n_vars,
793			offset,
794			inner,
795		})
796	}
797
798	pub fn n_polys(&self) -> usize {
799		self.inner.len()
800	}
801
802	pub fn polys(&self) -> impl Iterator<Item = OracleId> + '_ {
803		self.inner.iter().map(|(id, _)| *id)
804	}
805
806	pub fn coefficients(&self) -> impl Iterator<Item = F> + '_ {
807		self.inner.iter().map(|(_, coeff)| *coeff)
808	}
809}
810
811/// MLE of a multivariate polynomial evaluated on multilinear oracles.
812///
813/// i.e. the MLE of the evaluations of $C(M_1, M_2, \ldots, M_n)$ on $\{0, 1\}^\mu$ where:
814/// - $C$ is a arbitrary polynomial in $n$ variables
815/// - $M_1, M_2, \ldots, M_n$ are multilinear oracles in μ variables
816///
817/// ($C$ should be sufficiently lightweight to be evaluated by the verifier)
818#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes)]
819pub struct CompositeMLE<F: TowerField> {
820	/// $\mu$
821	#[get_copy = "pub"]
822	n_vars: usize,
823	/// $M_1, M_2, \ldots, M_n$
824	#[getset(get = "pub")]
825	inner: Vec<OracleId>,
826	/// $C$
827	#[getset(get = "pub")]
828	c: ArithCircuitPoly<F>,
829}
830
831impl<F: TowerField> CompositeMLE<F> {
832	pub fn new(
833		n_vars: usize,
834		inner: impl IntoIterator<Item = MultilinearPolyOracle<F>>,
835		c: ArithExpr<F>,
836	) -> Result<Self, Error> {
837		let inner = inner
838			.into_iter()
839			.map(|oracle| {
840				if oracle.n_vars() == n_vars {
841					Ok(oracle.id())
842				} else {
843					Err(Error::IncorrectNumberOfVariables { expected: n_vars })
844				}
845			})
846			.collect::<Result<Vec<_>, _>>()?;
847		let c = ArithCircuitPoly::with_n_vars(inner.len(), c)
848			.map_err(|_| Error::CompositionMismatch)?; // occurs if `c` has more variables than `inner.len()`
849		Ok(Self { n_vars, inner, c })
850	}
851
852	pub fn polys(&self) -> impl Iterator<Item = OracleId> + '_ {
853		self.inner.iter().copied()
854	}
855
856	pub fn n_polys(&self) -> usize {
857		self.inner.len()
858	}
859}
860
861impl<F: TowerField> MultilinearPolyOracle<F> {
862	pub const fn id(&self) -> OracleId {
863		self.id
864	}
865
866	pub fn label(&self) -> String {
867		match self.name() {
868			Some(name) => format!("{}: {}", self.type_str(), name),
869			None => format!("{}: id={}", self.type_str(), self.id()),
870		}
871	}
872
873	pub fn name(&self) -> Option<&str> {
874		self.name.as_deref()
875	}
876
877	const fn type_str(&self) -> &str {
878		match self.variant {
879			MultilinearPolyVariant::Transparent(_) => "Transparent",
880			MultilinearPolyVariant::Committed => "Committed",
881			MultilinearPolyVariant::Repeating { .. } => "Repeating",
882			MultilinearPolyVariant::Projected(_) => "Projected",
883			MultilinearPolyVariant::Shifted(_) => "Shifted",
884			MultilinearPolyVariant::Packed(_) => "Packed",
885			MultilinearPolyVariant::LinearCombination(_) => "LinearCombination",
886			MultilinearPolyVariant::ZeroPadded(_) => "ZeroPadded",
887			MultilinearPolyVariant::Composite(_) => "CompositeMLE",
888		}
889	}
890
891	pub const fn n_vars(&self) -> usize {
892		self.n_vars
893	}
894
895	/// Maximum tower level of the oracle's values over the boolean hypercube.
896	pub const fn binary_tower_level(&self) -> usize {
897		self.tower_level
898	}
899
900	pub fn into_composite(self) -> CompositePolyOracle<F> {
901		let composite =
902			CompositePolyOracle::new(self.n_vars(), vec![self], IdentityCompositionPoly);
903		composite.expect("Can always apply the identity composition to one variable")
904	}
905}
906
#[cfg(test)]
mod tests {
	use binius_field::{BinaryField128b, BinaryField1b, Field, TowerField};

	use super::MultilinearOracleSet;

	/// Projecting away every variable of a committed oracle must succeed and yield an oracle
	/// that can be looked up in the set.
	#[test]
	fn add_projection_with_all_vars() {
		type F = BinaryField128b;
		const N_VARS: usize = 5;

		let mut oracles = MultilinearOracleSet::<F>::new();
		let data = oracles.add_committed(N_VARS, BinaryField1b::TOWER_LEVEL);
		let projected = oracles
			.add_projected(data, vec![F::ONE; N_VARS], 0)
			.unwrap();
		let _ = oracles.oracle(projected);
	}
}