1use std::{array, fmt::Debug, sync::Arc};
4
5use binius_field::{BinaryField128b, Field, TowerField};
6use binius_macros::{DeserializeBytes, SerializeBytes};
7use binius_math::ArithExpr;
8use binius_utils::{bail, DeserializeBytes, SerializationError, SerializationMode, SerializeBytes};
9use bytes::Buf;
10use getset::{CopyGetters, Getters};
11
12use crate::{
13 oracle::{CompositePolyOracle, Error},
14 polynomial::{
15 ArithCircuitPoly, Error as PolynomialError, IdentityCompositionPoly, MultivariatePoly,
16 },
17};
18
/// Identifier of an oracle. It is used as an index into the oracle list owned by
/// [`MultilinearOracleSet`].
pub type OracleId = usize;
21
/// Builder handed out by [`MultilinearOracleSet::add`] and [`MultilinearOracleSet::add_named`].
///
/// It carries an optional name and a mutable borrow of the set. Each consuming method
/// appends one oracle to the set (`committed_multiple` appends `N`) and returns the new id(s).
pub struct MultilinearOracleSetAddition<'a, F: TowerField> {
    /// Optional name attached to the oracle that is created.
    name: Option<String>,
    /// The set the new oracle is appended to.
    mut_ref: &'a mut MultilinearOracleSet<F>,
}
28
29impl<F: TowerField> MultilinearOracleSetAddition<'_, F> {
30 pub fn transparent(self, poly: impl MultivariatePoly<F> + 'static) -> Result<OracleId, Error> {
31 if poly.binary_tower_level() > F::TOWER_LEVEL {
32 bail!(Error::TowerLevelTooHigh {
33 tower_level: poly.binary_tower_level(),
34 });
35 }
36
37 let inner = TransparentPolyOracle::new(Arc::new(poly))?;
38
39 let oracle = |id: OracleId| MultilinearPolyOracle {
40 id,
41 n_vars: inner.poly.n_vars(),
42 tower_level: inner.poly.binary_tower_level(),
43 name: self.name,
44 variant: MultilinearPolyVariant::Transparent(inner),
45 };
46
47 Ok(self.mut_ref.add_to_set(oracle))
48 }
49
50 pub fn committed(mut self, n_vars: usize, tower_level: usize) -> OracleId {
51 let name = self.name.take();
52 self.add_committed_with_name(n_vars, tower_level, name)
53 }
54
55 pub fn committed_multiple<const N: usize>(
56 mut self,
57 n_vars: usize,
58 tower_level: usize,
59 ) -> [OracleId; N] {
60 match &self.name.take() {
61 None => [0; N].map(|_| self.add_committed_with_name(n_vars, tower_level, None)),
62 Some(s) => {
63 let x: [usize; N] = array::from_fn(|i| i);
64 x.map(|i| {
65 self.add_committed_with_name(n_vars, tower_level, Some(format!("{}_{}", s, i)))
66 })
67 }
68 }
69 }
70
71 pub fn repeating(self, inner_id: OracleId, log_count: usize) -> Result<OracleId, Error> {
72 if inner_id >= self.mut_ref.oracles.len() {
73 bail!(Error::InvalidOracleId(inner_id));
74 }
75
76 let inner = self.mut_ref.get_from_set(inner_id);
77
78 let oracle = |id: OracleId| MultilinearPolyOracle {
79 id,
80 n_vars: inner.n_vars + log_count,
81 tower_level: inner.tower_level,
82 name: self.name,
83 variant: MultilinearPolyVariant::Repeating {
84 id: inner_id,
85 log_count,
86 },
87 };
88
89 Ok(self.mut_ref.add_to_set(oracle))
90 }
91
92 pub fn shifted(
93 self,
94 inner_id: OracleId,
95 offset: usize,
96 block_bits: usize,
97 variant: ShiftVariant,
98 ) -> Result<OracleId, Error> {
99 if inner_id >= self.mut_ref.oracles.len() {
100 bail!(Error::InvalidOracleId(inner_id));
101 }
102
103 let inner = self.mut_ref.get_from_set(inner_id);
104 if block_bits > inner.n_vars {
105 bail!(PolynomialError::InvalidBlockSize {
106 n_vars: inner.n_vars,
107 });
108 }
109
110 if offset == 0 || offset >= 1 << block_bits {
111 bail!(PolynomialError::InvalidShiftOffset {
112 max_shift_offset: (1 << block_bits) - 1,
113 shift_offset: offset,
114 });
115 }
116
117 let shifted = Shifted::new(&inner, offset, block_bits, variant)?;
118
119 let oracle = |id: OracleId| MultilinearPolyOracle {
120 id,
121 n_vars: inner.n_vars,
122 tower_level: inner.tower_level,
123 name: self.name,
124 variant: MultilinearPolyVariant::Shifted(shifted),
125 };
126
127 Ok(self.mut_ref.add_to_set(oracle))
128 }
129
130 pub fn packed(self, inner_id: OracleId, log_degree: usize) -> Result<OracleId, Error> {
131 if inner_id >= self.mut_ref.oracles.len() {
132 bail!(Error::InvalidOracleId(inner_id));
133 }
134
135 let inner_n_vars = self.mut_ref.n_vars(inner_id);
136 if log_degree > inner_n_vars {
137 bail!(Error::NotEnoughVarsForPacking {
138 n_vars: inner_n_vars,
139 log_degree,
140 });
141 }
142
143 let inner_tower_level = self.mut_ref.tower_level(inner_id);
144
145 let packed = Packed {
146 id: inner_id,
147 log_degree,
148 };
149
150 let oracle = |id: OracleId| MultilinearPolyOracle {
151 id,
152 n_vars: inner_n_vars - log_degree,
153 tower_level: inner_tower_level + log_degree,
154 name: self.name,
155 variant: MultilinearPolyVariant::Packed(packed),
156 };
157
158 Ok(self.mut_ref.add_to_set(oracle))
159 }
160
161 pub fn projected(
162 self,
163 inner_id: OracleId,
164 values: Vec<F>,
165 start_index: usize,
166 ) -> Result<OracleId, Error> {
167 let inner_n_vars = self.mut_ref.n_vars(inner_id);
168 let values_len = values.len();
169 if values_len > inner_n_vars {
170 bail!(Error::InvalidProjection {
171 n_vars: inner_n_vars,
172 values_len,
173 });
174 }
175
176 let inner = self.mut_ref.get_from_set(inner_id);
177 let tower_level = inner.binary_tower_level();
179 let projected = Projected::new(&inner, values, start_index)?;
180
181 let oracle = |id: OracleId| MultilinearPolyOracle {
182 id,
183 n_vars: inner_n_vars - values_len,
184 tower_level,
185 name: self.name,
186 variant: MultilinearPolyVariant::Projected(projected),
187 };
188
189 Ok(self.mut_ref.add_to_set(oracle))
190 }
191
192 pub fn projected_last_vars(
193 self,
194 inner_id: OracleId,
195 values: Vec<F>,
196 ) -> Result<OracleId, Error> {
197 let inner_n_vars = self.mut_ref.n_vars(inner_id);
198 let start_index = inner_n_vars - values.len();
199 let values_len = values.len();
200 if values_len > inner_n_vars {
201 bail!(Error::InvalidProjection {
202 n_vars: inner_n_vars,
203 values_len,
204 });
205 }
206
207 let inner = self.mut_ref.get_from_set(inner_id);
208 let tower_level = inner.binary_tower_level();
210 let projected = Projected::new(&inner, values, start_index)?;
211
212 let oracle = |id: OracleId| MultilinearPolyOracle {
213 id,
214 n_vars: inner_n_vars - values_len,
215 tower_level,
216 name: self.name,
217 variant: MultilinearPolyVariant::Projected(projected),
218 };
219
220 Ok(self.mut_ref.add_to_set(oracle))
221 }
222
    /// Adds a linear-combination oracle with a zero constant term.
    ///
    /// Delegates to [`Self::linear_combination_with_offset`] with `F::ZERO` as the offset;
    /// see that method for the validation performed and the errors returned.
    pub fn linear_combination(
        self,
        n_vars: usize,
        inner: impl IntoIterator<Item = (OracleId, F)>,
    ) -> Result<OracleId, Error> {
        self.linear_combination_with_offset(n_vars, F::ZERO, inner)
    }
230
231 pub fn linear_combination_with_offset(
232 self,
233 n_vars: usize,
234 offset: F,
235 inner: impl IntoIterator<Item = (OracleId, F)>,
236 ) -> Result<OracleId, Error> {
237 let inner = inner
238 .into_iter()
239 .map(|(inner_id, coeff)| {
240 if inner_id >= self.mut_ref.oracles.len() {
241 return Err(Error::InvalidOracleId(inner_id));
242 }
243 if self.mut_ref.n_vars(inner_id) != n_vars {
244 return Err(Error::IncorrectNumberOfVariables { expected: n_vars });
245 }
246 Ok((self.mut_ref.get_from_set(inner_id), coeff))
247 })
248 .collect::<Result<Vec<_>, _>>()?;
249
250 let tower_level = inner
251 .iter()
252 .map(|(oracle, coeff)| oracle.binary_tower_level().max(coeff.min_tower_level()))
253 .max()
254 .unwrap_or(0)
255 .max(offset.min_tower_level());
256
257 let linear_combination = LinearCombination::new(n_vars, offset, inner)?;
258
259 let oracle = |id: OracleId| MultilinearPolyOracle {
260 id,
261 n_vars,
262 tower_level,
263 name: self.name,
264 variant: MultilinearPolyVariant::LinearCombination(linear_combination),
265 };
266
267 Ok(self.mut_ref.add_to_set(oracle))
268 }
269
270 pub fn composite_mle(
271 self,
272 n_vars: usize,
273 inner: impl IntoIterator<Item = OracleId>,
274 comp: ArithExpr<F>,
275 ) -> Result<OracleId, Error> {
276 let inner = inner
277 .into_iter()
278 .map(|inner_id| {
279 if inner_id >= self.mut_ref.oracles.len() {
280 return Err(Error::InvalidOracleId(inner_id));
281 }
282 if self.mut_ref.n_vars(inner_id) != n_vars {
283 return Err(Error::IncorrectNumberOfVariables { expected: n_vars });
284 }
285 Ok(self.mut_ref.get_from_set(inner_id))
286 })
287 .collect::<Result<Vec<_>, _>>()?;
288
289 let tower_level = inner
290 .iter()
291 .map(|oracle| oracle.binary_tower_level())
292 .max()
293 .unwrap_or(0);
294
295 let composite_mle = CompositeMLE::new(n_vars, inner, comp)?;
296
297 let oracle = |id: OracleId| MultilinearPolyOracle {
298 id,
299 n_vars,
300 tower_level,
301 name: self.name,
302 variant: MultilinearPolyVariant::Composite(composite_mle),
303 };
304
305 Ok(self.mut_ref.add_to_set(oracle))
306 }
307
308 pub fn zero_padded(self, inner_id: OracleId, n_vars: usize) -> Result<OracleId, Error> {
309 if inner_id >= self.mut_ref.oracles.len() {
310 bail!(Error::InvalidOracleId(inner_id));
311 }
312
313 if self.mut_ref.n_vars(inner_id) > n_vars {
314 bail!(Error::IncorrectNumberOfVariables {
315 expected: self.mut_ref.n_vars(inner_id),
316 });
317 };
318
319 let inner = self.mut_ref.get_from_set(inner_id);
320
321 let oracle = |id: OracleId| MultilinearPolyOracle {
322 id,
323 n_vars,
324 tower_level: inner.tower_level,
325 name: self.name,
326 variant: MultilinearPolyVariant::ZeroPadded(inner_id),
327 };
328
329 Ok(self.mut_ref.add_to_set(oracle))
330 }
331
332 fn add_committed_with_name(
333 &mut self,
334 n_vars: usize,
335 tower_level: usize,
336 name: Option<String>,
337 ) -> OracleId {
338 let oracle = |oracle_id: OracleId| MultilinearPolyOracle {
339 id: oracle_id,
340 n_vars,
341 tower_level,
342 name: name.clone(),
343 variant: MultilinearPolyVariant::Committed,
344 };
345
346 self.mut_ref.add_to_set(oracle)
347 }
348}
349
/// An ordered collection of multilinear polynomial oracles.
///
/// Oracles are only ever appended; an oracle's [`OracleId`] is its position in the list.
#[derive(Default, Debug, Clone, SerializeBytes)]
pub struct MultilinearOracleSet<F: TowerField> {
    /// All oracles in insertion order; `oracles[id].id == id` for oracles added through
    /// the builder methods in this file.
    oracles: Vec<MultilinearPolyOracle<F>>,
}
362
363impl DeserializeBytes for MultilinearOracleSet<BinaryField128b> {
364 fn deserialize(read_buf: impl Buf, mode: SerializationMode) -> Result<Self, SerializationError>
365 where
366 Self: Sized,
367 {
368 Ok(Self {
369 oracles: DeserializeBytes::deserialize(read_buf, mode)?,
370 })
371 }
372}
373
impl<F: TowerField> MultilinearOracleSet<F> {
    /// Creates an empty oracle set.
    pub const fn new() -> Self {
        Self {
            oracles: Vec::new(),
        }
    }

    /// Returns the number of oracles in the set.
    pub fn size(&self) -> usize {
        self.oracles.len()
    }

    /// Iterates over clones of all oracles, in id order.
    pub fn iter(&self) -> impl Iterator<Item = MultilinearPolyOracle<F>> + '_ {
        (0..self.oracles.len()).map(|id| self.oracle(id))
    }

    /// Starts building a new, unnamed oracle.
    pub const fn add(&mut self) -> MultilinearOracleSetAddition<F> {
        MultilinearOracleSetAddition {
            name: None,
            mut_ref: self,
        }
    }

    /// Starts building a new oracle named `s.to_string()`.
    pub fn add_named<S: ToString>(&mut self, s: S) -> MultilinearOracleSetAddition<F> {
        MultilinearOracleSetAddition {
            name: Some(s.to_string()),
            mut_ref: self,
        }
    }

    /// Returns `true` when `id` refers to an oracle currently in the set.
    pub fn is_valid_oracle_id(&self, id: OracleId) -> bool {
        id < self.oracles.len()
    }

    /// Appends the oracle produced by `oracle` and returns its id.
    ///
    /// The closure receives the id the new oracle will have (the current length of the
    /// list) so the oracle can record its own id.
    fn add_to_set(
        &mut self,
        oracle: impl FnOnce(OracleId) -> MultilinearPolyOracle<F>,
    ) -> OracleId {
        let id = self.oracles.len();
        self.oracles.push(oracle(id));
        id
    }

    /// Returns a clone of the oracle at `id`. Panics if `id` is out of range.
    fn get_from_set(&self, id: OracleId) -> MultilinearPolyOracle<F> {
        self.oracles[id].clone()
    }

    /// Shorthand for `self.add().transparent(poly)`.
    pub fn add_transparent(
        &mut self,
        poly: impl MultivariatePoly<F> + 'static,
    ) -> Result<OracleId, Error> {
        self.add().transparent(poly)
    }

    /// Shorthand for `self.add().committed(n_vars, tower_level)`.
    pub fn add_committed(&mut self, n_vars: usize, tower_level: usize) -> OracleId {
        self.add().committed(n_vars, tower_level)
    }

    /// Shorthand for `self.add().committed_multiple(n_vars, tower_level)`.
    pub fn add_committed_multiple<const N: usize>(
        &mut self,
        n_vars: usize,
        tower_level: usize,
    ) -> [OracleId; N] {
        self.add().committed_multiple(n_vars, tower_level)
    }

    /// Shorthand for `self.add().repeating(id, log_count)`.
    pub fn add_repeating(&mut self, id: OracleId, log_count: usize) -> Result<OracleId, Error> {
        self.add().repeating(id, log_count)
    }

    /// Shorthand for `self.add().shifted(id, offset, block_bits, variant)`.
    pub fn add_shifted(
        &mut self,
        id: OracleId,
        offset: usize,
        block_bits: usize,
        variant: ShiftVariant,
    ) -> Result<OracleId, Error> {
        self.add().shifted(id, offset, block_bits, variant)
    }

    /// Shorthand for `self.add().packed(id, log_degree)`.
    pub fn add_packed(&mut self, id: OracleId, log_degree: usize) -> Result<OracleId, Error> {
        self.add().packed(id, log_degree)
    }

    /// Shorthand for `self.add().projected(id, values, start_index)`.
    pub fn add_projected(
        &mut self,
        id: OracleId,
        values: Vec<F>,
        start_index: usize,
    ) -> Result<OracleId, Error> {
        self.add().projected(id, values, start_index)
    }

    /// Shorthand for `self.add().projected_last_vars(id, values)`.
    pub fn add_projected_last_vars(
        &mut self,
        id: OracleId,
        values: Vec<F>,
    ) -> Result<OracleId, Error> {
        self.add().projected_last_vars(id, values)
    }

    /// Shorthand for `self.add().linear_combination(n_vars, inner)`.
    pub fn add_linear_combination(
        &mut self,
        n_vars: usize,
        inner: impl IntoIterator<Item = (OracleId, F)>,
    ) -> Result<OracleId, Error> {
        self.add().linear_combination(n_vars, inner)
    }

    /// Shorthand for `self.add().linear_combination_with_offset(n_vars, offset, inner)`.
    pub fn add_linear_combination_with_offset(
        &mut self,
        n_vars: usize,
        offset: F,
        inner: impl IntoIterator<Item = (OracleId, F)>,
    ) -> Result<OracleId, Error> {
        self.add()
            .linear_combination_with_offset(n_vars, offset, inner)
    }

    /// Shorthand for `self.add().zero_padded(id, n_vars)`.
    pub fn add_zero_padded(&mut self, id: OracleId, n_vars: usize) -> Result<OracleId, Error> {
        self.add().zero_padded(id, n_vars)
    }

    /// Shorthand for `self.add().composite_mle(n_vars, inner, comp)`.
    pub fn add_composite_mle(
        &mut self,
        n_vars: usize,
        inner: impl IntoIterator<Item = OracleId>,
        comp: ArithExpr<F>,
    ) -> Result<OracleId, Error> {
        self.add().composite_mle(n_vars, inner, comp)
    }

    /// Returns a clone of the oracle at `id`. Panics if `id` is out of range.
    pub fn oracle(&self, id: OracleId) -> MultilinearPolyOracle<F> {
        self.oracles[id].clone()
    }

    /// Returns the number of variables of the oracle at `id`. Panics if `id` is out of range.
    pub fn n_vars(&self, id: OracleId) -> usize {
        self.oracles[id].n_vars()
    }

    /// Returns the display label of the oracle at `id`. Panics if `id` is out of range.
    pub fn label(&self, id: OracleId) -> String {
        self.oracles[id].label()
    }

    /// Returns the binary tower level of the oracle at `id`. Panics if `id` is out of range.
    pub fn tower_level(&self, id: OracleId) -> usize {
        self.oracles[id].binary_tower_level()
    }
}
524
/// A single multilinear polynomial oracle: its metadata plus the variant describing how
/// it is defined.
#[derive(Debug, Clone, PartialEq, Eq, SerializeBytes)]
pub struct MultilinearPolyOracle<F: TowerField> {
    /// Id of this oracle within its [`MultilinearOracleSet`].
    pub id: OracleId,
    /// Optional name, used by `label()`.
    pub name: Option<String>,
    /// Number of variables of the multilinear polynomial.
    pub n_vars: usize,
    /// Binary tower level recorded for this oracle.
    pub tower_level: usize,
    /// How the oracle is defined (committed, transparent, or derived from other oracles).
    pub variant: MultilinearPolyVariant<F>,
}
553
554impl DeserializeBytes for MultilinearPolyOracle<BinaryField128b> {
555 fn deserialize(
556 mut read_buf: impl bytes::Buf,
557 mode: SerializationMode,
558 ) -> Result<Self, SerializationError>
559 where
560 Self: Sized,
561 {
562 Ok(Self {
563 id: DeserializeBytes::deserialize(&mut read_buf, mode)?,
564 name: DeserializeBytes::deserialize(&mut read_buf, mode)?,
565 n_vars: DeserializeBytes::deserialize(&mut read_buf, mode)?,
566 tower_level: DeserializeBytes::deserialize(&mut read_buf, mode)?,
567 variant: DeserializeBytes::deserialize(&mut read_buf, mode)?,
568 })
569 }
570}
571
/// The ways an oracle in a [`MultilinearOracleSet`] can be defined.
///
/// The hand-written `DeserializeBytes` impl for this enum maps tags 0 through 7 to the
/// first eight variants in declaration order, so reordering variants affects that mapping.
#[derive(Debug, Clone, PartialEq, Eq, SerializeBytes)]
pub enum MultilinearPolyVariant<F: TowerField> {
    /// A committed polynomial; carries no further data.
    Committed,
    /// A transparent polynomial, stored as a shared `MultivariatePoly` trait object.
    Transparent(TransparentPolyOracle<F>),
    /// Derived from oracle `id`; the builder gives it `log_count` additional variables.
    Repeating { id: usize, log_count: usize },
    /// Derived from another oracle by fixing some of its variables to given values.
    Projected(Projected<F>),
    /// Derived from another oracle by a shift within blocks.
    Shifted(Shifted),
    /// Derived from another oracle by packing; the builder trades `log_degree` variables
    /// for `log_degree` tower levels.
    Packed(Packed),
    /// A linear combination of other oracles plus a constant offset.
    LinearCombination(LinearCombination<F>),
    /// Derived from the referenced oracle with a (possibly) larger number of variables.
    ZeroPadded(OracleId),
    /// An arithmetic composition over other oracles.
    Composite(CompositeMLE<F>),
}
584
585impl DeserializeBytes for MultilinearPolyVariant<BinaryField128b> {
586 fn deserialize(
587 mut buf: impl bytes::Buf,
588 mode: SerializationMode,
589 ) -> Result<Self, SerializationError>
590 where
591 Self: Sized,
592 {
593 Ok(match u8::deserialize(&mut buf, mode)? {
594 0 => Self::Committed,
595 1 => Self::Transparent(DeserializeBytes::deserialize(buf, mode)?),
596 2 => Self::Repeating {
597 id: DeserializeBytes::deserialize(&mut buf, mode)?,
598 log_count: DeserializeBytes::deserialize(buf, mode)?,
599 },
600 3 => Self::Projected(DeserializeBytes::deserialize(buf, mode)?),
601 4 => Self::Shifted(DeserializeBytes::deserialize(buf, mode)?),
602 5 => Self::Packed(DeserializeBytes::deserialize(buf, mode)?),
603 6 => Self::LinearCombination(DeserializeBytes::deserialize(buf, mode)?),
604 7 => Self::ZeroPadded(DeserializeBytes::deserialize(buf, mode)?),
605 variant_index => {
606 return Err(SerializationError::UnknownEnumVariant {
607 name: "MultilinearPolyVariant",
608 index: variant_index,
609 });
610 }
611 })
612 }
613}
614
/// Wrapper around a shared, type-erased [`MultivariatePoly`] used by the `Transparent`
/// oracle variant.
#[derive(Debug, Clone, Getters, CopyGetters)]
pub struct TransparentPolyOracle<F: Field> {
    /// The wrapped polynomial; a public by-reference getter is generated by `getset`.
    #[get = "pub"]
    poly: Arc<dyn MultivariatePoly<F>>,
}
623
impl<F: TowerField> SerializeBytes for TransparentPolyOracle<F> {
    /// Serializes the wrapped polynomial by delegating to its `erased_serialize` method.
    fn serialize(
        &self,
        mut write_buf: impl bytes::BufMut,
        mode: SerializationMode,
    ) -> Result<(), SerializationError> {
        self.poly.erased_serialize(&mut write_buf, mode)
    }
}
633
634impl DeserializeBytes for TransparentPolyOracle<BinaryField128b> {
635 fn deserialize(
636 read_buf: impl bytes::Buf,
637 mode: SerializationMode,
638 ) -> Result<Self, SerializationError>
639 where
640 Self: Sized,
641 {
642 Ok(Self {
643 poly: Box::<dyn MultivariatePoly<BinaryField128b>>::deserialize(read_buf, mode)?.into(),
644 })
645 }
646}
647
648impl<F: TowerField> TransparentPolyOracle<F> {
649 fn new(poly: Arc<dyn MultivariatePoly<F>>) -> Result<Self, Error> {
650 if poly.binary_tower_level() > F::TOWER_LEVEL {
651 bail!(Error::TowerLevelTooHigh {
652 tower_level: poly.binary_tower_level(),
653 });
654 }
655 Ok(Self { poly })
656 }
657}
658
impl<F: Field> TransparentPolyOracle<F> {
    /// Returns the binary tower level reported by the wrapped polynomial.
    pub fn binary_tower_level(&self) -> usize {
        self.poly.binary_tower_level()
    }
}
665
impl<F: Field> PartialEq for TransparentPolyOracle<F> {
    /// Compares by pointer identity (`Arc::ptr_eq`), not by polynomial contents.
    ///
    /// NOTE(review): a deserialized copy builds a fresh `Arc`, so it presumably never
    /// compares equal to the oracle it was serialized from; confirm this is intended.
    fn eq(&self, other: &Self) -> bool {
        Arc::ptr_eq(&self.poly, &other.poly)
    }
}

impl<F: Field> Eq for TransparentPolyOracle<F> {}
673
/// Payload of the `Projected` oracle variant: an inner oracle with a contiguous range of
/// variables fixed to `values`.
#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)]
pub struct Projected<F: TowerField> {
    /// Id of the inner oracle.
    #[get_copy = "pub"]
    id: OracleId,
    /// Values the projected variables are fixed to.
    #[get = "pub"]
    values: Vec<F>,
    /// Index of the first projected variable.
    #[get_copy = "pub"]
    start_index: usize,
}
683
684impl<F: TowerField> Projected<F> {
685 fn new(
686 oracle: &MultilinearPolyOracle<F>,
687 values: Vec<F>,
688 start_index: usize,
689 ) -> Result<Self, Error> {
690 if values.len() + start_index > oracle.n_vars() {
691 bail!(Error::InvalidProjection {
692 n_vars: oracle.n_vars(),
693 values_len: values.len()
694 });
695 }
696 Ok(Self {
697 id: oracle.id(),
698 values,
699 start_index,
700 })
701 }
702}
703
/// Kind of shift applied by a [`Shifted`] oracle.
///
/// NOTE(review): the precise semantics of each variant are not visible in this file; the
/// descriptions below are inferred from the variant names and should be confirmed.
#[derive(Debug, Copy, Clone, PartialEq, Eq, SerializeBytes, DeserializeBytes)]
pub enum ShiftVariant {
    /// Presumably a left rotation within each block.
    CircularLeft,
    /// Presumably a left shift within each block that fills with zeros.
    LogicalLeft,
    /// Presumably a right shift within each block that fills with zeros.
    LogicalRight,
}
710
/// Payload of the `Shifted` oracle variant.
#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)]
pub struct Shifted {
    /// Id of the inner oracle.
    #[get_copy = "pub"]
    id: OracleId,
    /// Shift amount; the constructor requires `0 < shift_offset < 1 << block_size`.
    #[get_copy = "pub"]
    shift_offset: usize,
    /// Block size as a bit count (the constructor compares it to the inner oracle's
    /// number of variables and uses it as a shift exponent).
    #[get_copy = "pub"]
    block_size: usize,
    /// Which kind of shift is applied.
    #[get_copy = "pub"]
    shift_variant: ShiftVariant,
}
722
723impl Shifted {
724 fn new<F: TowerField>(
725 oracle: &MultilinearPolyOracle<F>,
726 shift_offset: usize,
727 block_size: usize,
728 shift_variant: ShiftVariant,
729 ) -> Result<Self, Error> {
730 if block_size > oracle.n_vars() {
731 bail!(PolynomialError::InvalidBlockSize {
732 n_vars: oracle.n_vars(),
733 });
734 }
735
736 if shift_offset == 0 || shift_offset >= 1 << block_size {
737 bail!(PolynomialError::InvalidShiftOffset {
738 max_shift_offset: (1 << block_size) - 1,
739 shift_offset,
740 });
741 }
742
743 Ok(Self {
744 id: oracle.id(),
745 shift_offset,
746 block_size,
747 shift_variant,
748 })
749 }
750}
751
/// Payload of the `Packed` oracle variant.
#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)]
pub struct Packed {
    /// Id of the inner oracle.
    #[get_copy = "pub"]
    id: OracleId,
    /// Packing degree as a base-2 logarithm; the builder subtracts it from the inner
    /// oracle's variable count and adds it to the tower level.
    #[get_copy = "pub"]
    log_degree: usize,
}
765
/// Payload of the `LinearCombination` oracle variant: a constant offset plus
/// `(oracle id, coefficient)` terms.
#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)]
pub struct LinearCombination<F: TowerField> {
    /// Number of variables shared by every inner oracle.
    #[get_copy = "pub"]
    n_vars: usize,
    /// Constant term of the combination.
    #[get_copy = "pub"]
    offset: F,
    /// The `(oracle id, coefficient)` terms, in the order supplied at construction.
    inner: Vec<(OracleId, F)>,
}
774
775impl<F: TowerField> LinearCombination<F> {
776 fn new(
777 n_vars: usize,
778 offset: F,
779 inner: impl IntoIterator<Item = (MultilinearPolyOracle<F>, F)>,
780 ) -> Result<Self, Error> {
781 let inner = inner
782 .into_iter()
783 .map(|(oracle, value)| {
784 if oracle.n_vars() == n_vars {
785 Ok((oracle.id(), value))
786 } else {
787 Err(Error::IncorrectNumberOfVariables { expected: n_vars })
788 }
789 })
790 .collect::<Result<Vec<_>, _>>()?;
791 Ok(Self {
792 n_vars,
793 offset,
794 inner,
795 })
796 }
797
798 pub fn n_polys(&self) -> usize {
799 self.inner.len()
800 }
801
802 pub fn polys(&self) -> impl Iterator<Item = OracleId> + '_ {
803 self.inner.iter().map(|(id, _)| *id)
804 }
805
806 pub fn coefficients(&self) -> impl Iterator<Item = F> + '_ {
807 self.inner.iter().map(|(_, coeff)| *coeff)
808 }
809}
810
/// Payload of the `Composite` oracle variant: an arithmetic circuit applied to a list of
/// inner oracles.
#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes)]
pub struct CompositeMLE<F: TowerField> {
    /// Number of variables shared by every inner oracle.
    #[get_copy = "pub"]
    n_vars: usize,
    /// Ids of the inner oracles, in the order supplied at construction.
    #[getset(get = "pub")]
    inner: Vec<OracleId>,
    /// The composition circuit; it is constructed with `inner.len()` inputs.
    #[getset(get = "pub")]
    c: ArithCircuitPoly<F>,
}
830
831impl<F: TowerField> CompositeMLE<F> {
832 pub fn new(
833 n_vars: usize,
834 inner: impl IntoIterator<Item = MultilinearPolyOracle<F>>,
835 c: ArithExpr<F>,
836 ) -> Result<Self, Error> {
837 let inner = inner
838 .into_iter()
839 .map(|oracle| {
840 if oracle.n_vars() == n_vars {
841 Ok(oracle.id())
842 } else {
843 Err(Error::IncorrectNumberOfVariables { expected: n_vars })
844 }
845 })
846 .collect::<Result<Vec<_>, _>>()?;
847 let c = ArithCircuitPoly::with_n_vars(inner.len(), c)
848 .map_err(|_| Error::CompositionMismatch)?; Ok(Self { n_vars, inner, c })
850 }
851
852 pub fn polys(&self) -> impl Iterator<Item = OracleId> + '_ {
853 self.inner.iter().copied()
854 }
855
856 pub fn n_polys(&self) -> usize {
857 self.inner.len()
858 }
859}
860
861impl<F: TowerField> MultilinearPolyOracle<F> {
862 pub const fn id(&self) -> OracleId {
863 self.id
864 }
865
866 pub fn label(&self) -> String {
867 match self.name() {
868 Some(name) => format!("{}: {}", self.type_str(), name),
869 None => format!("{}: id={}", self.type_str(), self.id()),
870 }
871 }
872
873 pub fn name(&self) -> Option<&str> {
874 self.name.as_deref()
875 }
876
877 const fn type_str(&self) -> &str {
878 match self.variant {
879 MultilinearPolyVariant::Transparent(_) => "Transparent",
880 MultilinearPolyVariant::Committed => "Committed",
881 MultilinearPolyVariant::Repeating { .. } => "Repeating",
882 MultilinearPolyVariant::Projected(_) => "Projected",
883 MultilinearPolyVariant::Shifted(_) => "Shifted",
884 MultilinearPolyVariant::Packed(_) => "Packed",
885 MultilinearPolyVariant::LinearCombination(_) => "LinearCombination",
886 MultilinearPolyVariant::ZeroPadded(_) => "ZeroPadded",
887 MultilinearPolyVariant::Composite(_) => "CompositeMLE",
888 }
889 }
890
891 pub const fn n_vars(&self) -> usize {
892 self.n_vars
893 }
894
895 pub const fn binary_tower_level(&self) -> usize {
897 self.tower_level
898 }
899
900 pub fn into_composite(self) -> CompositePolyOracle<F> {
901 let composite =
902 CompositePolyOracle::new(self.n_vars(), vec![self], IdentityCompositionPoly);
903 composite.expect("Can always apply the identity composition to one variable")
904 }
905}
906
#[cfg(test)]
mod tests {
    use binius_field::{BinaryField128b, BinaryField1b, Field, TowerField};

    use super::MultilinearOracleSet;

    /// Projecting every variable of a 5-variable committed oracle must succeed and the
    /// resulting oracle must be retrievable from the set.
    #[test]
    fn add_projection_with_all_vars() {
        type F = BinaryField128b;

        let mut set = MultilinearOracleSet::<F>::new();
        let committed = set.add_committed(5, BinaryField1b::TOWER_LEVEL);

        let all_ones = vec![F::ONE; 5];
        let projected = set.add_projected(committed, all_ones, 0).unwrap();

        let _ = set.oracle(projected);
    }
}