1use std::{array, fmt::Debug, sync::Arc};
4
5use binius_field::{BinaryField128b, Field, TowerField};
6use binius_macros::{DeserializeBytes, SerializeBytes};
7use binius_math::ArithExpr;
8use binius_utils::{bail, DeserializeBytes, SerializationError, SerializationMode, SerializeBytes};
9use bytes::Buf;
10use getset::{CopyGetters, Getters};
11
12use crate::{
13 oracle::{CompositePolyOracle, Error},
14 polynomial::{
15 ArithCircuitPoly, Error as PolynomialError, IdentityCompositionPoly, MultivariatePoly,
16 },
17};
18
/// Identifier of an oracle inside a [`MultilinearOracleSet`].
///
/// Ids are assigned sequentially when an oracle is added and equal the oracle's index in the
/// set's internal vector.
pub type OracleId = usize;
21
/// Single-use handle for adding one oracle (or a batch of committed oracles) to a
/// [`MultilinearOracleSet`].
///
/// Obtained from [`MultilinearOracleSet::add`] or [`MultilinearOracleSet::add_named`]; every
/// public method consumes the handle.
pub struct MultilinearOracleSetAddition<'a, F: TowerField> {
    /// Optional human-readable name given to the oracle being added.
    name: Option<String>,
    /// The set that receives the new oracle.
    mut_ref: &'a mut MultilinearOracleSet<F>,
}
28
29impl<F: TowerField> MultilinearOracleSetAddition<'_, F> {
30 pub fn transparent(self, poly: impl MultivariatePoly<F> + 'static) -> Result<OracleId, Error> {
31 if poly.binary_tower_level() > F::TOWER_LEVEL {
32 bail!(Error::TowerLevelTooHigh {
33 tower_level: poly.binary_tower_level(),
34 });
35 }
36
37 let inner = TransparentPolyOracle::new(Arc::new(poly))?;
38
39 let oracle = |id: OracleId| MultilinearPolyOracle {
40 id,
41 n_vars: inner.poly.n_vars(),
42 tower_level: inner.poly.binary_tower_level(),
43 name: self.name,
44 variant: MultilinearPolyVariant::Transparent(inner),
45 };
46
47 Ok(self.mut_ref.add_to_set(oracle))
48 }
49
50 pub fn committed(mut self, n_vars: usize, tower_level: usize) -> OracleId {
51 let name = self.name.take();
52 self.add_committed_with_name(n_vars, tower_level, name)
53 }
54
55 pub fn committed_multiple<const N: usize>(
56 mut self,
57 n_vars: usize,
58 tower_level: usize,
59 ) -> [OracleId; N] {
60 match &self.name.take() {
61 None => [0; N].map(|_| self.add_committed_with_name(n_vars, tower_level, None)),
62 Some(s) => {
63 let x: [usize; N] = array::from_fn(|i| i);
64 x.map(|i| {
65 self.add_committed_with_name(n_vars, tower_level, Some(format!("{}_{}", s, i)))
66 })
67 }
68 }
69 }
70
71 pub fn repeating(self, inner_id: OracleId, log_count: usize) -> Result<OracleId, Error> {
72 if inner_id >= self.mut_ref.oracles.len() {
73 bail!(Error::InvalidOracleId(inner_id));
74 }
75
76 let inner = self.mut_ref.get_from_set(inner_id);
77
78 let oracle = |id: OracleId| MultilinearPolyOracle {
79 id,
80 n_vars: inner.n_vars + log_count,
81 tower_level: inner.tower_level,
82 name: self.name,
83 variant: MultilinearPolyVariant::Repeating {
84 id: inner_id,
85 log_count,
86 },
87 };
88
89 Ok(self.mut_ref.add_to_set(oracle))
90 }
91
92 pub fn shifted(
93 self,
94 inner_id: OracleId,
95 offset: usize,
96 block_bits: usize,
97 variant: ShiftVariant,
98 ) -> Result<OracleId, Error> {
99 if inner_id >= self.mut_ref.oracles.len() {
100 bail!(Error::InvalidOracleId(inner_id));
101 }
102
103 let inner = self.mut_ref.get_from_set(inner_id);
104 if block_bits > inner.n_vars {
105 bail!(PolynomialError::InvalidBlockSize {
106 n_vars: inner.n_vars,
107 });
108 }
109
110 if offset == 0 || offset >= 1 << block_bits {
111 bail!(PolynomialError::InvalidShiftOffset {
112 max_shift_offset: (1 << block_bits) - 1,
113 shift_offset: offset,
114 });
115 }
116
117 let shifted = Shifted::new(&inner, offset, block_bits, variant)?;
118
119 let oracle = |id: OracleId| MultilinearPolyOracle {
120 id,
121 n_vars: inner.n_vars,
122 tower_level: inner.tower_level,
123 name: self.name,
124 variant: MultilinearPolyVariant::Shifted(shifted),
125 };
126
127 Ok(self.mut_ref.add_to_set(oracle))
128 }
129
130 pub fn packed(self, inner_id: OracleId, log_degree: usize) -> Result<OracleId, Error> {
131 if inner_id >= self.mut_ref.oracles.len() {
132 bail!(Error::InvalidOracleId(inner_id));
133 }
134
135 let inner_n_vars = self.mut_ref.n_vars(inner_id);
136 if log_degree > inner_n_vars {
137 bail!(Error::NotEnoughVarsForPacking {
138 n_vars: inner_n_vars,
139 log_degree,
140 });
141 }
142
143 let inner_tower_level = self.mut_ref.tower_level(inner_id);
144
145 let packed = Packed {
146 id: inner_id,
147 log_degree,
148 };
149
150 let oracle = |id: OracleId| MultilinearPolyOracle {
151 id,
152 n_vars: inner_n_vars - log_degree,
153 tower_level: inner_tower_level + log_degree,
154 name: self.name,
155 variant: MultilinearPolyVariant::Packed(packed),
156 };
157
158 Ok(self.mut_ref.add_to_set(oracle))
159 }
160
161 pub fn projected(
162 self,
163 inner_id: OracleId,
164 values: Vec<F>,
165 variant: ProjectionVariant,
166 ) -> Result<OracleId, Error> {
167 let inner_n_vars = self.mut_ref.n_vars(inner_id);
168 let values_len = values.len();
169 if values_len > inner_n_vars {
170 bail!(Error::InvalidProjection {
171 n_vars: inner_n_vars,
172 values_len,
173 });
174 }
175
176 let inner = self.mut_ref.get_from_set(inner_id);
177 let tower_level = inner.binary_tower_level();
179 let projected = Projected::new(&inner, values, variant)?;
180
181 let oracle = |id: OracleId| MultilinearPolyOracle {
182 id,
183 n_vars: inner_n_vars - values_len,
184 tower_level,
185 name: self.name,
186 variant: MultilinearPolyVariant::Projected(projected),
187 };
188
189 Ok(self.mut_ref.add_to_set(oracle))
190 }
191
192 pub fn linear_combination(
193 self,
194 n_vars: usize,
195 inner: impl IntoIterator<Item = (OracleId, F)>,
196 ) -> Result<OracleId, Error> {
197 self.linear_combination_with_offset(n_vars, F::ZERO, inner)
198 }
199
200 pub fn linear_combination_with_offset(
201 self,
202 n_vars: usize,
203 offset: F,
204 inner: impl IntoIterator<Item = (OracleId, F)>,
205 ) -> Result<OracleId, Error> {
206 let inner = inner
207 .into_iter()
208 .map(|(inner_id, coeff)| {
209 if inner_id >= self.mut_ref.oracles.len() {
210 return Err(Error::InvalidOracleId(inner_id));
211 }
212 if self.mut_ref.n_vars(inner_id) != n_vars {
213 return Err(Error::IncorrectNumberOfVariables { expected: n_vars });
214 }
215 Ok((self.mut_ref.get_from_set(inner_id), coeff))
216 })
217 .collect::<Result<Vec<_>, _>>()?;
218
219 let tower_level = inner
220 .iter()
221 .map(|(oracle, _)| oracle.binary_tower_level())
222 .max()
223 .unwrap_or(0);
224
225 let linear_combination = LinearCombination::new(n_vars, offset, inner)?;
226
227 let oracle = |id: OracleId| MultilinearPolyOracle {
228 id,
229 n_vars,
230 tower_level,
231 name: self.name,
232 variant: MultilinearPolyVariant::LinearCombination(linear_combination),
233 };
234
235 Ok(self.mut_ref.add_to_set(oracle))
236 }
237
238 pub fn composite_mle(
239 self,
240 n_vars: usize,
241 inner: impl IntoIterator<Item = OracleId>,
242 comp: ArithExpr<F>,
243 ) -> Result<OracleId, Error> {
244 let inner = inner
245 .into_iter()
246 .map(|inner_id| {
247 if inner_id >= self.mut_ref.oracles.len() {
248 return Err(Error::InvalidOracleId(inner_id));
249 }
250 if self.mut_ref.n_vars(inner_id) != n_vars {
251 return Err(Error::IncorrectNumberOfVariables { expected: n_vars });
252 }
253 Ok(self.mut_ref.get_from_set(inner_id))
254 })
255 .collect::<Result<Vec<_>, _>>()?;
256
257 let tower_level = inner
258 .iter()
259 .map(|oracle| oracle.binary_tower_level())
260 .max()
261 .unwrap_or(0);
262
263 let composite_mle = CompositeMLE::new(n_vars, inner, comp)?;
264
265 let oracle = |id: OracleId| MultilinearPolyOracle {
266 id,
267 n_vars,
268 tower_level,
269 name: self.name,
270 variant: MultilinearPolyVariant::Composite(composite_mle),
271 };
272
273 Ok(self.mut_ref.add_to_set(oracle))
274 }
275
276 pub fn zero_padded(self, inner_id: OracleId, n_vars: usize) -> Result<OracleId, Error> {
277 if inner_id >= self.mut_ref.oracles.len() {
278 bail!(Error::InvalidOracleId(inner_id));
279 }
280
281 if self.mut_ref.n_vars(inner_id) > n_vars {
282 bail!(Error::IncorrectNumberOfVariables {
283 expected: self.mut_ref.n_vars(inner_id),
284 });
285 };
286
287 let inner = self.mut_ref.get_from_set(inner_id);
288
289 let oracle = |id: OracleId| MultilinearPolyOracle {
290 id,
291 n_vars,
292 tower_level: inner.tower_level,
293 name: self.name,
294 variant: MultilinearPolyVariant::ZeroPadded(inner_id),
295 };
296
297 Ok(self.mut_ref.add_to_set(oracle))
298 }
299
300 fn add_committed_with_name(
301 &mut self,
302 n_vars: usize,
303 tower_level: usize,
304 name: Option<String>,
305 ) -> OracleId {
306 let oracle = |oracle_id: OracleId| MultilinearPolyOracle {
307 id: oracle_id,
308 n_vars,
309 tower_level,
310 name: name.clone(),
311 variant: MultilinearPolyVariant::Committed,
312 };
313
314 self.mut_ref.add_to_set(oracle)
315 }
316}
317
/// An ordered collection of multilinear polynomial oracles.
///
/// Oracles are only ever appended; an oracle's [`OracleId`] is its index in `oracles`.
#[derive(Default, Debug, Clone, SerializeBytes)]
pub struct MultilinearOracleSet<F: TowerField> {
    /// All registered oracles, indexed by their id.
    oracles: Vec<MultilinearPolyOracle<F>>,
}
330
331impl DeserializeBytes for MultilinearOracleSet<BinaryField128b> {
332 fn deserialize(read_buf: impl Buf, mode: SerializationMode) -> Result<Self, SerializationError>
333 where
334 Self: Sized,
335 {
336 Ok(Self {
337 oracles: DeserializeBytes::deserialize(read_buf, mode)?,
338 })
339 }
340}
341
342impl<F: TowerField> MultilinearOracleSet<F> {
343 pub const fn new() -> Self {
344 Self {
345 oracles: Vec::new(),
346 }
347 }
348
349 pub fn size(&self) -> usize {
350 self.oracles.len()
351 }
352
353 pub fn iter(&self) -> impl Iterator<Item = MultilinearPolyOracle<F>> + '_ {
354 (0..self.oracles.len()).map(|id| self.oracle(id))
355 }
356
357 pub const fn add(&mut self) -> MultilinearOracleSetAddition<F> {
358 MultilinearOracleSetAddition {
359 name: None,
360 mut_ref: self,
361 }
362 }
363
364 pub fn add_named<S: ToString>(&mut self, s: S) -> MultilinearOracleSetAddition<F> {
365 MultilinearOracleSetAddition {
366 name: Some(s.to_string()),
367 mut_ref: self,
368 }
369 }
370
371 pub fn is_valid_oracle_id(&self, id: OracleId) -> bool {
372 id < self.oracles.len()
373 }
374
375 fn add_to_set(
376 &mut self,
377 oracle: impl FnOnce(OracleId) -> MultilinearPolyOracle<F>,
378 ) -> OracleId {
379 let id = self.oracles.len();
380 self.oracles.push(oracle(id));
381 id
382 }
383
384 fn get_from_set(&self, id: OracleId) -> MultilinearPolyOracle<F> {
385 self.oracles[id].clone()
386 }
387
388 pub fn add_transparent(
389 &mut self,
390 poly: impl MultivariatePoly<F> + 'static,
391 ) -> Result<OracleId, Error> {
392 self.add().transparent(poly)
393 }
394
395 pub fn add_committed(&mut self, n_vars: usize, tower_level: usize) -> OracleId {
396 self.add().committed(n_vars, tower_level)
397 }
398
399 pub fn add_committed_multiple<const N: usize>(
400 &mut self,
401 n_vars: usize,
402 tower_level: usize,
403 ) -> [OracleId; N] {
404 self.add().committed_multiple(n_vars, tower_level)
405 }
406
407 pub fn add_repeating(&mut self, id: OracleId, log_count: usize) -> Result<OracleId, Error> {
408 self.add().repeating(id, log_count)
409 }
410
411 pub fn add_shifted(
412 &mut self,
413 id: OracleId,
414 offset: usize,
415 block_bits: usize,
416 variant: ShiftVariant,
417 ) -> Result<OracleId, Error> {
418 self.add().shifted(id, offset, block_bits, variant)
419 }
420
421 pub fn add_packed(&mut self, id: OracleId, log_degree: usize) -> Result<OracleId, Error> {
422 self.add().packed(id, log_degree)
423 }
424
425 pub fn add_projected(
426 &mut self,
427 id: OracleId,
428 values: Vec<F>,
429 variant: ProjectionVariant,
430 ) -> Result<OracleId, Error> {
431 self.add().projected(id, values, variant)
432 }
433
434 pub fn add_linear_combination(
435 &mut self,
436 n_vars: usize,
437 inner: impl IntoIterator<Item = (OracleId, F)>,
438 ) -> Result<OracleId, Error> {
439 self.add().linear_combination(n_vars, inner)
440 }
441
442 pub fn add_linear_combination_with_offset(
443 &mut self,
444 n_vars: usize,
445 offset: F,
446 inner: impl IntoIterator<Item = (OracleId, F)>,
447 ) -> Result<OracleId, Error> {
448 self.add()
449 .linear_combination_with_offset(n_vars, offset, inner)
450 }
451
452 pub fn add_zero_padded(&mut self, id: OracleId, n_vars: usize) -> Result<OracleId, Error> {
453 self.add().zero_padded(id, n_vars)
454 }
455
456 pub fn add_composite_mle(
457 &mut self,
458 n_vars: usize,
459 inner: impl IntoIterator<Item = OracleId>,
460 comp: ArithExpr<F>,
461 ) -> Result<OracleId, Error> {
462 self.add().composite_mle(n_vars, inner, comp)
463 }
464
465 pub fn oracle(&self, id: OracleId) -> MultilinearPolyOracle<F> {
466 self.oracles[id].clone()
467 }
468
469 pub fn n_vars(&self, id: OracleId) -> usize {
470 self.oracles[id].n_vars()
471 }
472
473 pub fn label(&self, id: OracleId) -> String {
474 self.oracles[id].label()
475 }
476
477 pub fn tower_level(&self, id: OracleId) -> usize {
479 self.oracles[id].binary_tower_level()
480 }
481}
482
/// Description of a single multilinear polynomial oracle stored in a [`MultilinearOracleSet`].
#[derive(Debug, Clone, PartialEq, Eq, SerializeBytes)]
pub struct MultilinearPolyOracle<F: TowerField> {
    /// Id of this oracle within its set.
    pub id: OracleId,
    /// Optional human-readable name, used by [`MultilinearPolyOracle::label`].
    pub name: Option<String>,
    /// Number of variables of the multilinear polynomial.
    pub n_vars: usize,
    /// Binary tower level recorded for this oracle.
    pub tower_level: usize,
    /// How this oracle is defined (committed, transparent, or derived from other oracles).
    pub variant: MultilinearPolyVariant<F>,
}
511
512impl DeserializeBytes for MultilinearPolyOracle<BinaryField128b> {
513 fn deserialize(
514 mut read_buf: impl bytes::Buf,
515 mode: SerializationMode,
516 ) -> Result<Self, SerializationError>
517 where
518 Self: Sized,
519 {
520 Ok(Self {
521 id: DeserializeBytes::deserialize(&mut read_buf, mode)?,
522 name: DeserializeBytes::deserialize(&mut read_buf, mode)?,
523 n_vars: DeserializeBytes::deserialize(&mut read_buf, mode)?,
524 tower_level: DeserializeBytes::deserialize(&mut read_buf, mode)?,
525 variant: DeserializeBytes::deserialize(&mut read_buf, mode)?,
526 })
527 }
528}
529
/// The ways a [`MultilinearPolyOracle`] can be defined.
#[derive(Debug, Clone, PartialEq, Eq, SerializeBytes)]
pub enum MultilinearPolyVariant<F: TowerField> {
    /// A committed oracle; carries no further data.
    Committed,
    /// An oracle backed by a transparent polynomial.
    Transparent(TransparentPolyOracle<F>),
    /// The oracle `id` repeated, adding `log_count` variables.
    Repeating { id: usize, log_count: usize },
    /// A projection of another oracle onto fixed values.
    Projected(Projected<F>),
    /// A shift of another oracle.
    Shifted(Shifted),
    /// A packing of another oracle.
    Packed(Packed),
    /// A linear combination of other oracles plus an offset.
    LinearCombination(LinearCombination<F>),
    /// The oracle with the given id, zero-padded to more variables.
    ZeroPadded(OracleId),
    /// An arithmetic-circuit composition over other oracles.
    Composite(CompositeMLE<F>),
}
542
543impl DeserializeBytes for MultilinearPolyVariant<BinaryField128b> {
544 fn deserialize(
545 mut buf: impl bytes::Buf,
546 mode: SerializationMode,
547 ) -> Result<Self, SerializationError>
548 where
549 Self: Sized,
550 {
551 Ok(match u8::deserialize(&mut buf, mode)? {
552 0 => Self::Committed,
553 1 => Self::Transparent(DeserializeBytes::deserialize(buf, mode)?),
554 2 => Self::Repeating {
555 id: DeserializeBytes::deserialize(&mut buf, mode)?,
556 log_count: DeserializeBytes::deserialize(buf, mode)?,
557 },
558 3 => Self::Projected(DeserializeBytes::deserialize(buf, mode)?),
559 4 => Self::Shifted(DeserializeBytes::deserialize(buf, mode)?),
560 5 => Self::Packed(DeserializeBytes::deserialize(buf, mode)?),
561 6 => Self::LinearCombination(DeserializeBytes::deserialize(buf, mode)?),
562 7 => Self::ZeroPadded(DeserializeBytes::deserialize(buf, mode)?),
563 variant_index => {
564 return Err(SerializationError::UnknownEnumVariant {
565 name: "MultilinearPolyVariant",
566 index: variant_index,
567 });
568 }
569 })
570 }
571}
572
/// A transparent oracle: a shared, type-erased multivariate polynomial.
#[derive(Debug, Clone, Getters, CopyGetters)]
pub struct TransparentPolyOracle<F: Field> {
    /// The underlying polynomial; exposed through a generated `poly()` getter.
    #[get = "pub"]
    poly: Arc<dyn MultivariatePoly<F>>,
}
581
impl<F: TowerField> SerializeBytes for TransparentPolyOracle<F> {
    /// Serializes the oracle by delegating to the polynomial's `erased_serialize`.
    fn serialize(
        &self,
        mut write_buf: impl bytes::BufMut,
        mode: SerializationMode,
    ) -> Result<(), SerializationError> {
        self.poly.erased_serialize(&mut write_buf, mode)
    }
}
591
592impl DeserializeBytes for TransparentPolyOracle<BinaryField128b> {
593 fn deserialize(
594 read_buf: impl bytes::Buf,
595 mode: SerializationMode,
596 ) -> Result<Self, SerializationError>
597 where
598 Self: Sized,
599 {
600 Ok(Self {
601 poly: Box::<dyn MultivariatePoly<BinaryField128b>>::deserialize(read_buf, mode)?.into(),
602 })
603 }
604}
605
606impl<F: TowerField> TransparentPolyOracle<F> {
607 fn new(poly: Arc<dyn MultivariatePoly<F>>) -> Result<Self, Error> {
608 if poly.binary_tower_level() > F::TOWER_LEVEL {
609 bail!(Error::TowerLevelTooHigh {
610 tower_level: poly.binary_tower_level(),
611 });
612 }
613 Ok(Self { poly })
614 }
615}
616
impl<F: Field> TransparentPolyOracle<F> {
    /// Binary tower level reported by the underlying polynomial.
    pub fn binary_tower_level(&self) -> usize {
        self.poly.binary_tower_level()
    }
}
623
impl<F: Field> PartialEq for TransparentPolyOracle<F> {
    /// Two transparent oracles are equal only when they share the same `Arc` allocation;
    /// the polynomials themselves are not compared.
    fn eq(&self, other: &Self) -> bool {
        Arc::ptr_eq(&self.poly, &other.poly)
    }
}
629
// Pointer identity is reflexive, symmetric and transitive, so `Eq` holds.
impl<F: Field> Eq for TransparentPolyOracle<F> {}
631
/// Selects which variables a [`Projected`] oracle fixes.
#[derive(Debug, Copy, Clone, PartialEq, Eq, SerializeBytes, DeserializeBytes)]
pub enum ProjectionVariant {
    /// Presumably fixes the first variables of the inner oracle — TODO confirm against users.
    FirstVars,
    /// Presumably fixes the last variables of the inner oracle — TODO confirm against users.
    LastVars,
}
637
/// Definition of a projected oracle: an inner oracle with some variables fixed to `values`.
#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)]
pub struct Projected<F: TowerField> {
    /// Id of the inner oracle being projected.
    #[get_copy = "pub"]
    id: OracleId,
    /// The values the projected variables are fixed to.
    #[get = "pub"]
    values: Vec<F>,
    /// Which variables of the inner oracle are fixed.
    #[get_copy = "pub"]
    projection_variant: ProjectionVariant,
}
647
648impl<F: TowerField> Projected<F> {
649 fn new(
650 oracle: &MultilinearPolyOracle<F>,
651 values: Vec<F>,
652 projection_variant: ProjectionVariant,
653 ) -> Result<Self, Error> {
654 if values.len() > oracle.n_vars() {
655 bail!(Error::InvalidProjection {
656 n_vars: oracle.n_vars(),
657 values_len: values.len()
658 });
659 }
660 Ok(Self {
661 id: oracle.id(),
662 values,
663 projection_variant,
664 })
665 }
666}
667
/// Kind of shift applied by a [`Shifted`] oracle.
///
/// The exact semantics of each variant are defined by the code that evaluates shifted
/// oracles, which lives outside this file.
#[derive(Debug, Copy, Clone, PartialEq, Eq, SerializeBytes, DeserializeBytes)]
pub enum ShiftVariant {
    /// Circular shift to the left.
    CircularLeft,
    /// Logical shift to the left.
    LogicalLeft,
    /// Logical shift to the right.
    LogicalRight,
}
674
/// Definition of a shifted oracle.
#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)]
pub struct Shifted {
    /// Id of the inner oracle being shifted.
    #[get_copy = "pub"]
    id: OracleId,
    /// Shift amount; validated to lie in `1..(1 << block_size)`.
    #[get_copy = "pub"]
    shift_offset: usize,
    /// Block size, validated to be at most the inner oracle's number of variables
    /// (used as the exponent in `1 << block_size`).
    #[get_copy = "pub"]
    block_size: usize,
    /// Kind of shift to apply.
    #[get_copy = "pub"]
    shift_variant: ShiftVariant,
}
686
687impl Shifted {
688 fn new<F: TowerField>(
689 oracle: &MultilinearPolyOracle<F>,
690 shift_offset: usize,
691 block_size: usize,
692 shift_variant: ShiftVariant,
693 ) -> Result<Self, Error> {
694 if block_size > oracle.n_vars() {
695 bail!(PolynomialError::InvalidBlockSize {
696 n_vars: oracle.n_vars(),
697 });
698 }
699
700 if shift_offset == 0 || shift_offset >= 1 << block_size {
701 bail!(PolynomialError::InvalidShiftOffset {
702 max_shift_offset: (1 << block_size) - 1,
703 shift_offset,
704 });
705 }
706
707 Ok(Self {
708 id: oracle.id(),
709 shift_offset,
710 block_size,
711 shift_variant,
712 })
713 }
714}
715
/// Definition of a packed oracle.
#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)]
pub struct Packed {
    /// Id of the inner oracle being packed.
    #[get_copy = "pub"]
    id: OracleId,
    /// Number of variables removed by packing; the tower level rises by the same amount.
    #[get_copy = "pub"]
    log_degree: usize,
}
729
/// Definition of an oracle that is a linear combination of other oracles plus an offset.
#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)]
pub struct LinearCombination<F: TowerField> {
    /// Number of variables shared by every inner oracle.
    #[get_copy = "pub"]
    n_vars: usize,
    /// Constant term of the combination.
    #[get_copy = "pub"]
    offset: F,
    /// `(oracle id, coefficient)` pairs, in the order they were supplied.
    inner: Vec<(OracleId, F)>,
}
738
739impl<F: TowerField> LinearCombination<F> {
740 fn new(
741 n_vars: usize,
742 offset: F,
743 inner: impl IntoIterator<Item = (MultilinearPolyOracle<F>, F)>,
744 ) -> Result<Self, Error> {
745 let inner = inner
746 .into_iter()
747 .map(|(oracle, value)| {
748 if oracle.n_vars() == n_vars {
749 Ok((oracle.id(), value))
750 } else {
751 Err(Error::IncorrectNumberOfVariables { expected: n_vars })
752 }
753 })
754 .collect::<Result<Vec<_>, _>>()?;
755 Ok(Self {
756 n_vars,
757 offset,
758 inner,
759 })
760 }
761
762 pub fn n_polys(&self) -> usize {
763 self.inner.len()
764 }
765
766 pub fn polys(&self) -> impl Iterator<Item = OracleId> + '_ {
767 self.inner.iter().map(|(id, _)| *id)
768 }
769
770 pub fn coefficients(&self) -> impl Iterator<Item = F> + '_ {
771 self.inner.iter().map(|(_, coeff)| *coeff)
772 }
773}
774
/// Definition of an oracle given by an arithmetic circuit evaluated over other oracles.
///
/// Only `SerializeBytes` is derived here; no `DeserializeBytes` implementation is visible in
/// this file.
#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes)]
pub struct CompositeMLE<F: TowerField> {
    /// Number of variables shared by every inner oracle.
    #[get_copy = "pub"]
    n_vars: usize,
    /// Ids of the inner oracles, in the order they were supplied.
    #[getset(get = "pub")]
    inner: Vec<OracleId>,
    /// The composition circuit, built over `inner.len()` inputs.
    #[getset(get = "pub")]
    c: ArithCircuitPoly<F>,
}
794
795impl<F: TowerField> CompositeMLE<F> {
796 pub fn new(
797 n_vars: usize,
798 inner: impl IntoIterator<Item = MultilinearPolyOracle<F>>,
799 c: ArithExpr<F>,
800 ) -> Result<Self, Error> {
801 let inner = inner
802 .into_iter()
803 .map(|oracle| {
804 if oracle.n_vars() == n_vars {
805 Ok(oracle.id())
806 } else {
807 Err(Error::IncorrectNumberOfVariables { expected: n_vars })
808 }
809 })
810 .collect::<Result<Vec<_>, _>>()?;
811 let c = ArithCircuitPoly::with_n_vars(inner.len(), c)
812 .map_err(|_| Error::CompositionMismatch)?; Ok(Self { n_vars, inner, c })
814 }
815
816 pub fn polys(&self) -> impl Iterator<Item = OracleId> + '_ {
817 self.inner.iter().copied()
818 }
819
820 pub fn n_polys(&self) -> usize {
821 self.inner.len()
822 }
823}
824
825impl<F: TowerField> MultilinearPolyOracle<F> {
826 pub const fn id(&self) -> OracleId {
827 self.id
828 }
829
830 pub fn label(&self) -> String {
831 match self.name() {
832 Some(name) => format!("{}: {}", self.type_str(), name),
833 None => format!("{}: id={}", self.type_str(), self.id()),
834 }
835 }
836
837 pub fn name(&self) -> Option<&str> {
838 self.name.as_deref()
839 }
840
841 const fn type_str(&self) -> &str {
842 match self.variant {
843 MultilinearPolyVariant::Transparent(_) => "Transparent",
844 MultilinearPolyVariant::Committed => "Committed",
845 MultilinearPolyVariant::Repeating { .. } => "Repeating",
846 MultilinearPolyVariant::Projected(_) => "Projected",
847 MultilinearPolyVariant::Shifted(_) => "Shifted",
848 MultilinearPolyVariant::Packed(_) => "Packed",
849 MultilinearPolyVariant::LinearCombination(_) => "LinearCombination",
850 MultilinearPolyVariant::ZeroPadded(_) => "ZeroPadded",
851 MultilinearPolyVariant::Composite(_) => "CompositeMLE",
852 }
853 }
854
855 pub const fn n_vars(&self) -> usize {
856 self.n_vars
857 }
858
859 pub const fn binary_tower_level(&self) -> usize {
861 self.tower_level
862 }
863
864 pub fn into_composite(self) -> CompositePolyOracle<F> {
865 let composite =
866 CompositePolyOracle::new(self.n_vars(), vec![self], IdentityCompositionPoly);
867 composite.expect("Can always apply the identity composition to one variable")
868 }
869}
870
#[cfg(test)]
mod tests {
    use binius_field::{BinaryField128b, BinaryField1b, Field, TowerField};

    use super::{MultilinearOracleSet, ProjectionVariant};

    /// Projecting a committed oracle onto as many values as it has variables must succeed,
    /// and the resulting oracle must be retrievable from the set.
    #[test]
    fn add_projection_with_all_vars() {
        type F = BinaryField128b;
        const N_VARS: usize = 5;

        let mut oracles = MultilinearOracleSet::<F>::new();
        let data = oracles.add_committed(N_VARS, BinaryField1b::TOWER_LEVEL);

        let values = vec![F::ONE; N_VARS];
        let projected = oracles
            .add_projected(data, values, ProjectionVariant::FirstVars)
            .unwrap();

        let _ = oracles.oracle(projected);
    }
}