1use std::{array, fmt::Debug, sync::Arc};
4
5use binius_field::{BinaryField128b, Field, TowerField};
6use binius_macros::{DeserializeBytes, SerializeBytes};
7use binius_math::ArithCircuit;
8use binius_utils::{bail, DeserializeBytes, SerializationError, SerializationMode, SerializeBytes};
9use getset::{CopyGetters, Getters};
10
11use crate::{
12 oracle::{CompositePolyOracle, Error},
13 polynomial::{
14 ArithCircuitPoly, Error as PolynomialError, IdentityCompositionPoly, MultivariatePoly,
15 },
16};
17
/// Identifier of an oracle: its index in the owning [`MultilinearOracleSet`], assigned
/// sequentially as oracles are added.
pub type OracleId = usize;
20
21pub struct MultilinearOracleSetAddition<'a, F: TowerField> {
24 name: Option<String>,
25 mut_ref: &'a mut MultilinearOracleSet<F>,
26}
27
28impl<F: TowerField> MultilinearOracleSetAddition<'_, F> {
29 pub fn transparent(self, poly: impl MultivariatePoly<F> + 'static) -> Result<OracleId, Error> {
30 if poly.binary_tower_level() > F::TOWER_LEVEL {
31 bail!(Error::TowerLevelTooHigh {
32 tower_level: poly.binary_tower_level(),
33 });
34 }
35
36 let inner = TransparentPolyOracle::new(Arc::new(poly))?;
37
38 let oracle = |id: OracleId| MultilinearPolyOracle {
39 id,
40 n_vars: inner.poly.n_vars(),
41 tower_level: inner.poly.binary_tower_level(),
42 name: self.name,
43 variant: MultilinearPolyVariant::Transparent(inner),
44 };
45
46 Ok(self.mut_ref.add_to_set(oracle))
47 }
48
49 pub fn committed(mut self, n_vars: usize, tower_level: usize) -> OracleId {
50 let name = self.name.take();
51 self.add_committed_with_name(n_vars, tower_level, name)
52 }
53
54 pub fn committed_multiple<const N: usize>(
55 mut self,
56 n_vars: usize,
57 tower_level: usize,
58 ) -> [OracleId; N] {
59 match &self.name.take() {
60 None => [0; N].map(|_| self.add_committed_with_name(n_vars, tower_level, None)),
61 Some(s) => {
62 let x: [usize; N] = array::from_fn(|i| i);
63 x.map(|i| {
64 self.add_committed_with_name(n_vars, tower_level, Some(format!("{s}_{i}")))
65 })
66 }
67 }
68 }
69
70 pub fn repeating(self, inner_id: OracleId, log_count: usize) -> Result<OracleId, Error> {
71 if inner_id >= self.mut_ref.oracles.len() {
72 bail!(Error::InvalidOracleId(inner_id));
73 }
74
75 let inner = self.mut_ref.get_from_set(inner_id);
76
77 let oracle = |id: OracleId| MultilinearPolyOracle {
78 id,
79 n_vars: inner.n_vars + log_count,
80 tower_level: inner.tower_level,
81 name: self.name,
82 variant: MultilinearPolyVariant::Repeating {
83 id: inner_id,
84 log_count,
85 },
86 };
87
88 Ok(self.mut_ref.add_to_set(oracle))
89 }
90
91 pub fn shifted(
92 self,
93 inner_id: OracleId,
94 offset: usize,
95 block_bits: usize,
96 variant: ShiftVariant,
97 ) -> Result<OracleId, Error> {
98 if inner_id >= self.mut_ref.oracles.len() {
99 bail!(Error::InvalidOracleId(inner_id));
100 }
101
102 let inner = self.mut_ref.get_from_set(inner_id);
103 if block_bits > inner.n_vars {
104 bail!(PolynomialError::InvalidBlockSize {
105 n_vars: inner.n_vars,
106 });
107 }
108
109 if offset == 0 || offset >= 1 << block_bits {
110 bail!(PolynomialError::InvalidShiftOffset {
111 max_shift_offset: (1 << block_bits) - 1,
112 shift_offset: offset,
113 });
114 }
115
116 let shifted = Shifted::new(&inner, offset, block_bits, variant)?;
117
118 let oracle = |id: OracleId| MultilinearPolyOracle {
119 id,
120 n_vars: inner.n_vars,
121 tower_level: inner.tower_level,
122 name: self.name,
123 variant: MultilinearPolyVariant::Shifted(shifted),
124 };
125
126 Ok(self.mut_ref.add_to_set(oracle))
127 }
128
129 pub fn packed(self, inner_id: OracleId, log_degree: usize) -> Result<OracleId, Error> {
130 if inner_id >= self.mut_ref.oracles.len() {
131 bail!(Error::InvalidOracleId(inner_id));
132 }
133
134 let inner_n_vars = self.mut_ref.n_vars(inner_id);
135 if log_degree > inner_n_vars {
136 bail!(Error::NotEnoughVarsForPacking {
137 n_vars: inner_n_vars,
138 log_degree,
139 });
140 }
141
142 let inner_tower_level = self.mut_ref.tower_level(inner_id);
143
144 let packed = Packed {
145 id: inner_id,
146 log_degree,
147 };
148
149 let oracle = |id: OracleId| MultilinearPolyOracle {
150 id,
151 n_vars: inner_n_vars - log_degree,
152 tower_level: inner_tower_level + log_degree,
153 name: self.name,
154 variant: MultilinearPolyVariant::Packed(packed),
155 };
156
157 Ok(self.mut_ref.add_to_set(oracle))
158 }
159
160 pub fn projected(
161 self,
162 inner_id: OracleId,
163 values: Vec<F>,
164 start_index: usize,
165 ) -> Result<OracleId, Error> {
166 let inner_n_vars = self.mut_ref.n_vars(inner_id);
167 let values_len = values.len();
168 if values_len > inner_n_vars {
169 bail!(Error::InvalidProjection {
170 n_vars: inner_n_vars,
171 values_len,
172 });
173 }
174
175 let inner = self.mut_ref.get_from_set(inner_id);
176 let tower_level = inner.binary_tower_level();
178 let projected = Projected::new(&inner, values, start_index)?;
179
180 let oracle = |id: OracleId| MultilinearPolyOracle {
181 id,
182 n_vars: inner_n_vars - values_len,
183 tower_level,
184 name: self.name,
185 variant: MultilinearPolyVariant::Projected(projected),
186 };
187
188 Ok(self.mut_ref.add_to_set(oracle))
189 }
190
191 pub fn projected_last_vars(
192 self,
193 inner_id: OracleId,
194 values: Vec<F>,
195 ) -> Result<OracleId, Error> {
196 let inner_n_vars = self.mut_ref.n_vars(inner_id);
197 let start_index = inner_n_vars - values.len();
198 let values_len = values.len();
199 if values_len > inner_n_vars {
200 bail!(Error::InvalidProjection {
201 n_vars: inner_n_vars,
202 values_len,
203 });
204 }
205
206 let inner = self.mut_ref.get_from_set(inner_id);
207 let tower_level = inner.binary_tower_level();
209 let projected = Projected::new(&inner, values, start_index)?;
210
211 let oracle = |id: OracleId| MultilinearPolyOracle {
212 id,
213 n_vars: inner_n_vars - values_len,
214 tower_level,
215 name: self.name,
216 variant: MultilinearPolyVariant::Projected(projected),
217 };
218
219 Ok(self.mut_ref.add_to_set(oracle))
220 }
221
222 pub fn linear_combination(
223 self,
224 n_vars: usize,
225 inner: impl IntoIterator<Item = (OracleId, F)>,
226 ) -> Result<OracleId, Error> {
227 self.linear_combination_with_offset(n_vars, F::ZERO, inner)
228 }
229
230 pub fn linear_combination_with_offset(
231 self,
232 n_vars: usize,
233 offset: F,
234 inner: impl IntoIterator<Item = (OracleId, F)>,
235 ) -> Result<OracleId, Error> {
236 let inner = inner
237 .into_iter()
238 .map(|(inner_id, coeff)| {
239 if inner_id >= self.mut_ref.oracles.len() {
240 return Err(Error::InvalidOracleId(inner_id));
241 }
242 if self.mut_ref.n_vars(inner_id) != n_vars {
243 return Err(Error::IncorrectNumberOfVariables { expected: n_vars });
244 }
245 Ok((self.mut_ref.get_from_set(inner_id), coeff))
246 })
247 .collect::<Result<Vec<_>, _>>()?;
248
249 let tower_level = inner
250 .iter()
251 .map(|(oracle, coeff)| oracle.binary_tower_level().max(coeff.min_tower_level()))
252 .max()
253 .unwrap_or(0)
254 .max(offset.min_tower_level());
255
256 let linear_combination = LinearCombination::new(n_vars, offset, inner)?;
257
258 let oracle = |id: OracleId| MultilinearPolyOracle {
259 id,
260 n_vars,
261 tower_level,
262 name: self.name,
263 variant: MultilinearPolyVariant::LinearCombination(linear_combination),
264 };
265
266 Ok(self.mut_ref.add_to_set(oracle))
267 }
268
269 pub fn composite_mle(
270 self,
271 n_vars: usize,
272 inner: impl IntoIterator<Item = OracleId>,
273 comp: ArithCircuit<F>,
274 ) -> Result<OracleId, Error> {
275 let inner = inner
276 .into_iter()
277 .map(|inner_id| {
278 if inner_id >= self.mut_ref.oracles.len() {
279 return Err(Error::InvalidOracleId(inner_id));
280 }
281 if self.mut_ref.n_vars(inner_id) != n_vars {
282 return Err(Error::IncorrectNumberOfVariables { expected: n_vars });
283 }
284 Ok(self.mut_ref.get_from_set(inner_id))
285 })
286 .collect::<Result<Vec<_>, _>>()?;
287
288 let tower_level = inner
289 .iter()
290 .map(|oracle| oracle.binary_tower_level())
291 .max()
292 .unwrap_or(0);
293
294 let composite_mle = CompositeMLE::new(n_vars, inner, comp)?;
295
296 let oracle = |id: OracleId| MultilinearPolyOracle {
297 id,
298 n_vars,
299 tower_level,
300 name: self.name,
301 variant: MultilinearPolyVariant::Composite(composite_mle),
302 };
303
304 Ok(self.mut_ref.add_to_set(oracle))
305 }
306
307 pub fn zero_padded(
308 self,
309 inner_id: OracleId,
310 n_pad_vars: usize,
311 nonzero_index: usize,
312 start_index: usize,
313 ) -> Result<OracleId, Error> {
314 let inner_n_vars = self.mut_ref.n_vars(inner_id);
315 if start_index > inner_n_vars {
316 bail!(Error::InvalidStartIndex {
317 expected: inner_n_vars
318 });
319 }
320
321 let inner = self.mut_ref.get_from_set(inner_id);
322 let tower_level = inner.binary_tower_level();
323 let padded = ZeroPadded::new(&inner, n_pad_vars, nonzero_index, start_index)?;
324
325 let oracle = |id: OracleId| MultilinearPolyOracle {
326 id,
327 n_vars: inner_n_vars + n_pad_vars,
328 tower_level,
329 name: self.name,
330 variant: MultilinearPolyVariant::ZeroPadded(padded),
331 };
332
333 Ok(self.mut_ref.add_to_set(oracle))
334 }
335
336 fn add_committed_with_name(
337 &mut self,
338 n_vars: usize,
339 tower_level: usize,
340 name: Option<String>,
341 ) -> OracleId {
342 let oracle = |oracle_id: OracleId| MultilinearPolyOracle {
343 id: oracle_id,
344 n_vars,
345 tower_level,
346 name: name.clone(),
347 variant: MultilinearPolyVariant::Committed,
348 };
349
350 self.mut_ref.add_to_set(oracle)
351 }
352}
353
/// Append-only collection of multilinear polynomial oracles.
///
/// An oracle's [`OracleId`] is its position in `oracles`: ids are handed out sequentially as
/// oracles are added, and the set exposes no way to remove them. Deserialization is provided
/// for `F = BinaryField128b` only.
#[derive(Default, Debug, Clone, SerializeBytes, DeserializeBytes)]
#[deserialize_bytes(eval_generics(F = BinaryField128b))]
pub struct MultilinearOracleSet<F: TowerField> {
    // Oracles in insertion order; `add_to_set` stores each oracle at the index equal to its id.
    oracles: Vec<MultilinearPolyOracle<F>>,
}
367
368impl<F: TowerField> MultilinearOracleSet<F> {
369 pub const fn new() -> Self {
370 Self {
371 oracles: Vec::new(),
372 }
373 }
374
375 pub fn size(&self) -> usize {
376 self.oracles.len()
377 }
378
379 pub fn iter(&self) -> impl Iterator<Item = MultilinearPolyOracle<F>> + '_ {
380 (0..self.oracles.len()).map(|id| self.oracle(id))
381 }
382
383 pub const fn add(&mut self) -> MultilinearOracleSetAddition<F> {
384 MultilinearOracleSetAddition {
385 name: None,
386 mut_ref: self,
387 }
388 }
389
390 pub fn add_named<S: ToString>(&mut self, s: S) -> MultilinearOracleSetAddition<F> {
391 MultilinearOracleSetAddition {
392 name: Some(s.to_string()),
393 mut_ref: self,
394 }
395 }
396
397 pub fn is_valid_oracle_id(&self, id: OracleId) -> bool {
398 id < self.oracles.len()
399 }
400
401 fn add_to_set(
402 &mut self,
403 oracle: impl FnOnce(OracleId) -> MultilinearPolyOracle<F>,
404 ) -> OracleId {
405 let id = self.oracles.len();
406 self.oracles.push(oracle(id));
407 id
408 }
409
410 fn get_from_set(&self, id: OracleId) -> MultilinearPolyOracle<F> {
411 self.oracles[id].clone()
412 }
413
414 pub fn add_transparent(
415 &mut self,
416 poly: impl MultivariatePoly<F> + 'static,
417 ) -> Result<OracleId, Error> {
418 self.add().transparent(poly)
419 }
420
421 pub fn add_committed(&mut self, n_vars: usize, tower_level: usize) -> OracleId {
422 self.add().committed(n_vars, tower_level)
423 }
424
425 pub fn add_committed_multiple<const N: usize>(
426 &mut self,
427 n_vars: usize,
428 tower_level: usize,
429 ) -> [OracleId; N] {
430 self.add().committed_multiple(n_vars, tower_level)
431 }
432
433 pub fn add_repeating(&mut self, id: OracleId, log_count: usize) -> Result<OracleId, Error> {
434 self.add().repeating(id, log_count)
435 }
436
437 pub fn add_shifted(
438 &mut self,
439 id: OracleId,
440 offset: usize,
441 block_bits: usize,
442 variant: ShiftVariant,
443 ) -> Result<OracleId, Error> {
444 self.add().shifted(id, offset, block_bits, variant)
445 }
446
447 pub fn add_packed(&mut self, id: OracleId, log_degree: usize) -> Result<OracleId, Error> {
448 self.add().packed(id, log_degree)
449 }
450
451 pub fn add_projected(
453 &mut self,
454 id: OracleId,
455 values: Vec<F>,
456 start_index: usize,
457 ) -> Result<OracleId, Error> {
458 self.add().projected(id, values, start_index)
459 }
460
461 pub fn add_projected_last_vars(
463 &mut self,
464 id: OracleId,
465 values: Vec<F>,
466 ) -> Result<OracleId, Error> {
467 self.add().projected_last_vars(id, values)
468 }
469
470 pub fn add_linear_combination(
471 &mut self,
472 n_vars: usize,
473 inner: impl IntoIterator<Item = (OracleId, F)>,
474 ) -> Result<OracleId, Error> {
475 self.add().linear_combination(n_vars, inner)
476 }
477
478 pub fn add_linear_combination_with_offset(
479 &mut self,
480 n_vars: usize,
481 offset: F,
482 inner: impl IntoIterator<Item = (OracleId, F)>,
483 ) -> Result<OracleId, Error> {
484 self.add()
485 .linear_combination_with_offset(n_vars, offset, inner)
486 }
487
488 pub fn add_zero_padded(
489 &mut self,
490 id: OracleId,
491 n_pad_vars: usize,
492 nonzero_index: usize,
493 start_index: usize,
494 ) -> Result<OracleId, Error> {
495 self.add()
496 .zero_padded(id, n_pad_vars, nonzero_index, start_index)
497 }
498
499 pub fn add_composite_mle(
500 &mut self,
501 n_vars: usize,
502 inner: impl IntoIterator<Item = OracleId>,
503 comp: ArithCircuit<F>,
504 ) -> Result<OracleId, Error> {
505 self.add().composite_mle(n_vars, inner, comp)
506 }
507
508 pub fn oracle(&self, id: OracleId) -> MultilinearPolyOracle<F> {
509 self.oracles[id].clone()
510 }
511
512 pub fn n_vars(&self, id: OracleId) -> usize {
513 self.oracles[id].n_vars()
514 }
515
516 pub fn label(&self, id: OracleId) -> String {
517 self.oracles[id].label()
518 }
519
520 pub fn tower_level(&self, id: OracleId) -> usize {
522 self.oracles[id].binary_tower_level()
523 }
524}
525
/// A multilinear polynomial oracle: its identity, shape and how it is defined.
///
/// NOTE(review): the derived serialization presumably writes fields in declaration order, so
/// reordering fields would change the encoding — keep the order stable.
#[derive(Debug, Clone, PartialEq, Eq, SerializeBytes, DeserializeBytes)]
#[deserialize_bytes(eval_generics(F = BinaryField128b))]
pub struct MultilinearPolyOracle<F: TowerField> {
    /// Index of this oracle in its owning [`MultilinearOracleSet`].
    pub id: OracleId,
    /// Optional human-readable name, used by [`label`](Self::label).
    pub name: Option<String>,
    /// Number of variables.
    pub n_vars: usize,
    /// Binary tower level of the oracle's values.
    pub tower_level: usize,
    /// How the oracle is defined.
    pub variant: MultilinearPolyVariant<F>,
}
555
/// How a [`MultilinearPolyOracle`] is defined.
///
/// The hand-written `DeserializeBytes` impl decodes tags 0-7 in this declaration order, so
/// variants must not be reordered.
#[derive(Debug, Clone, PartialEq, Eq, SerializeBytes)]
pub enum MultilinearPolyVariant<F: TowerField> {
    /// Oracle created via `committed`; no definition beyond its id, size and tower level.
    Committed,
    /// Explicitly given polynomial, see [`TransparentPolyOracle`].
    Transparent(TransparentPolyOracle<F>),
    /// Oracle `id` extended by `log_count` extra variables.
    Repeating { id: usize, log_count: usize },
    /// Oracle with some of its variables projected away, see [`Projected`].
    Projected(Projected<F>),
    /// Shifted oracle, see [`Shifted`].
    Shifted(Shifted),
    /// Packed oracle, see [`Packed`].
    Packed(Packed),
    /// Affine combination of oracles, see [`LinearCombination`].
    LinearCombination(LinearCombination<F>),
    /// Zero-padded oracle, see [`ZeroPadded`].
    ZeroPadded(ZeroPadded),
    /// Arithmetic-circuit composition of oracles, see [`CompositeMLE`].
    Composite(CompositeMLE<F>),
}
568
569impl DeserializeBytes for MultilinearPolyVariant<BinaryField128b> {
570 fn deserialize(
571 mut buf: impl bytes::Buf,
572 mode: SerializationMode,
573 ) -> Result<Self, SerializationError>
574 where
575 Self: Sized,
576 {
577 Ok(match u8::deserialize(&mut buf, mode)? {
578 0 => Self::Committed,
579 1 => Self::Transparent(DeserializeBytes::deserialize(buf, mode)?),
580 2 => Self::Repeating {
581 id: DeserializeBytes::deserialize(&mut buf, mode)?,
582 log_count: DeserializeBytes::deserialize(buf, mode)?,
583 },
584 3 => Self::Projected(DeserializeBytes::deserialize(buf, mode)?),
585 4 => Self::Shifted(DeserializeBytes::deserialize(buf, mode)?),
586 5 => Self::Packed(DeserializeBytes::deserialize(buf, mode)?),
587 6 => Self::LinearCombination(DeserializeBytes::deserialize(buf, mode)?),
588 7 => Self::ZeroPadded(DeserializeBytes::deserialize(buf, mode)?),
589 variant_index => {
590 return Err(SerializationError::UnknownEnumVariant {
591 name: "MultilinearPolyVariant",
592 index: variant_index,
593 });
594 }
595 })
596 }
597}
598
/// Oracle for a polynomial given explicitly as a [`MultivariatePoly`].
///
/// Equality is pointer identity of the shared polynomial (see the `PartialEq` impl below).
#[derive(Debug, Clone, Getters, CopyGetters)]
pub struct TransparentPolyOracle<F: Field> {
    /// Shared handle to the polynomial; cloning the oracle clones the `Arc`, not the polynomial.
    #[get = "pub"]
    poly: Arc<dyn MultivariatePoly<F>>,
}
607
608impl<F: TowerField> SerializeBytes for TransparentPolyOracle<F> {
609 fn serialize(
610 &self,
611 mut write_buf: impl bytes::BufMut,
612 mode: SerializationMode,
613 ) -> Result<(), SerializationError> {
614 self.poly.erased_serialize(&mut write_buf, mode)
615 }
616}
617
618impl DeserializeBytes for TransparentPolyOracle<BinaryField128b> {
619 fn deserialize(
620 read_buf: impl bytes::Buf,
621 mode: SerializationMode,
622 ) -> Result<Self, SerializationError>
623 where
624 Self: Sized,
625 {
626 Ok(Self {
627 poly: Box::<dyn MultivariatePoly<BinaryField128b>>::deserialize(read_buf, mode)?.into(),
628 })
629 }
630}
631
632impl<F: TowerField> TransparentPolyOracle<F> {
633 fn new(poly: Arc<dyn MultivariatePoly<F>>) -> Result<Self, Error> {
634 if poly.binary_tower_level() > F::TOWER_LEVEL {
635 bail!(Error::TowerLevelTooHigh {
636 tower_level: poly.binary_tower_level(),
637 });
638 }
639 Ok(Self { poly })
640 }
641}
642
643impl<F: Field> TransparentPolyOracle<F> {
644 pub fn binary_tower_level(&self) -> usize {
646 self.poly.binary_tower_level()
647 }
648}
649
650impl<F: Field> PartialEq for TransparentPolyOracle<F> {
651 fn eq(&self, other: &Self) -> bool {
652 Arc::ptr_eq(&self.poly, &other.poly)
653 }
654}
655
656impl<F: Field> Eq for TransparentPolyOracle<F> {}
657
/// Oracle derived from oracle `id` by projecting away `values.len()` of its variables.
///
/// Construction checks that `values.len() + start_index` does not exceed the inner oracle's
/// `n_vars`; the resulting oracle has `values.len()` fewer variables than the inner one.
#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)]
pub struct Projected<F: TowerField> {
    /// Id of the oracle being projected.
    #[get_copy = "pub"]
    id: OracleId,
    /// Values for the projected variables (presumably one per variable, starting at
    /// `start_index`).
    #[get = "pub"]
    values: Vec<F>,
    /// Index of the first projected variable.
    #[get_copy = "pub"]
    start_index: usize,
}
667
668impl<F: TowerField> Projected<F> {
669 fn new(
670 oracle: &MultilinearPolyOracle<F>,
671 values: Vec<F>,
672 start_index: usize,
673 ) -> Result<Self, Error> {
674 if values.len() + start_index > oracle.n_vars() {
675 bail!(Error::InvalidProjection {
676 n_vars: oracle.n_vars(),
677 values_len: values.len()
678 });
679 }
680 Ok(Self {
681 id: oracle.id(),
682 values,
683 start_index,
684 })
685 }
686}
687
/// Oracle that embeds oracle `id` into a space with `n_pad_vars` additional variables.
///
/// The resulting oracle has `n_pad_vars` more variables than the inner one.
#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)]
pub struct ZeroPadded {
    /// Id of the oracle being padded.
    #[get_copy = "pub"]
    id: OracleId,
    /// Number of padding variables added.
    #[get_copy = "pub"]
    n_pad_vars: usize,
    /// Index selecting where the inner oracle is placed among the padded positions.
    #[get_copy = "pub"]
    nonzero_index: usize,
    /// Variable index at which the padding variables are inserted (at most the inner `n_vars`).
    #[get_copy = "pub"]
    start_index: usize,
}
699
700impl ZeroPadded {
701 fn new<F: TowerField>(
702 oracle: &MultilinearPolyOracle<F>,
703 n_pad_vars: usize,
704 nonzero_index: usize,
705 start_index: usize,
706 ) -> Result<Self, Error> {
707 if start_index > oracle.n_vars() {
708 bail!(Error::InvalidStartIndex {
709 expected: oracle.n_vars(),
710 });
711 }
712
713 if nonzero_index > 1 << n_pad_vars {
714 bail!(Error::InvalidNonzeroIndex {
715 expected: 1 << n_pad_vars,
716 });
717 }
718
719 Ok(Self {
720 id: oracle.id(),
721 n_pad_vars,
722 nonzero_index,
723 start_index,
724 })
725 }
726}
727
/// Kind of shift applied by a [`Shifted`] oracle.
///
/// NOTE(review): the derived serialization presumably encodes variants by declaration order;
/// keep the order stable.
#[derive(Debug, Copy, Clone, PartialEq, Eq, SerializeBytes, DeserializeBytes)]
pub enum ShiftVariant {
    CircularLeft,
    LogicalLeft,
    LogicalRight,
}
734
/// Oracle obtained by shifting oracle `id` within blocks of `2^block_size` entries.
///
/// Construction enforces `block_size <= n_vars` and `0 < shift_offset < 2^block_size`.
#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)]
pub struct Shifted {
    /// Id of the oracle being shifted.
    #[get_copy = "pub"]
    id: OracleId,
    /// Shift amount.
    #[get_copy = "pub"]
    shift_offset: usize,
    /// Number of variables spanned by one block (log2 of the block length).
    #[get_copy = "pub"]
    block_size: usize,
    /// Kind of shift.
    #[get_copy = "pub"]
    shift_variant: ShiftVariant,
}
746
747impl Shifted {
748 fn new<F: TowerField>(
749 oracle: &MultilinearPolyOracle<F>,
750 shift_offset: usize,
751 block_size: usize,
752 shift_variant: ShiftVariant,
753 ) -> Result<Self, Error> {
754 if block_size > oracle.n_vars() {
755 bail!(PolynomialError::InvalidBlockSize {
756 n_vars: oracle.n_vars(),
757 });
758 }
759
760 if shift_offset == 0 || shift_offset >= 1 << block_size {
761 bail!(PolynomialError::InvalidShiftOffset {
762 max_shift_offset: (1 << block_size) - 1,
763 shift_offset,
764 });
765 }
766
767 Ok(Self {
768 id: oracle.id(),
769 shift_offset,
770 block_size,
771 shift_variant,
772 })
773 }
774}
775
/// Oracle that packs oracle `id`.
///
/// As built by `MultilinearOracleSetAddition::packed`, the packed oracle has `log_degree`
/// fewer variables and a tower level `log_degree` higher than the inner oracle.
#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)]
pub struct Packed {
    /// Id of the oracle being packed.
    #[get_copy = "pub"]
    id: OracleId,
    /// Number of variables folded into each packed element.
    #[get_copy = "pub"]
    log_degree: usize,
}
789
/// Affine combination `offset + sum(coeff * oracle)` of oracles sharing `n_vars` variables.
#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)]
pub struct LinearCombination<F: TowerField> {
    /// Number of variables of the combination and of every term.
    #[get_copy = "pub"]
    n_vars: usize,
    /// Constant term.
    #[get_copy = "pub"]
    offset: F,
    // `(oracle id, coefficient)` pairs; exposed via `polys()` / `coefficients()`.
    inner: Vec<(OracleId, F)>,
}
798
799impl<F: TowerField> LinearCombination<F> {
800 fn new(
801 n_vars: usize,
802 offset: F,
803 inner: impl IntoIterator<Item = (MultilinearPolyOracle<F>, F)>,
804 ) -> Result<Self, Error> {
805 let inner = inner
806 .into_iter()
807 .map(|(oracle, value)| {
808 if oracle.n_vars() == n_vars {
809 Ok((oracle.id(), value))
810 } else {
811 Err(Error::IncorrectNumberOfVariables { expected: n_vars })
812 }
813 })
814 .collect::<Result<Vec<_>, _>>()?;
815 Ok(Self {
816 n_vars,
817 offset,
818 inner,
819 })
820 }
821
822 pub fn n_polys(&self) -> usize {
823 self.inner.len()
824 }
825
826 pub fn polys(&self) -> impl Iterator<Item = OracleId> + '_ {
827 self.inner.iter().map(|(id, _)| *id)
828 }
829
830 pub fn coefficients(&self) -> impl Iterator<Item = F> + '_ {
831 self.inner.iter().map(|(_, coeff)| *coeff)
832 }
833}
834
/// Oracle defined as an arithmetic-circuit composition `c(inner_0, inner_1, ...)` of oracles
/// that all have `n_vars` variables.
///
/// NOTE(review): unlike the other variant payloads, this type does not derive
/// `DeserializeBytes`, and `MultilinearPolyVariant`'s deserializer has no arm for
/// `Composite` — confirm whether round-tripping composite oracles is intended.
#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes)]
pub struct CompositeMLE<F: TowerField> {
    /// Number of variables of the composition and of every inner oracle.
    #[get_copy = "pub"]
    n_vars: usize,
    /// Ids of the inner oracles, in circuit-input order.
    #[getset(get = "pub")]
    inner: Vec<OracleId>,
    /// Composition polynomial over `inner.len()` inputs.
    #[getset(get = "pub")]
    c: ArithCircuitPoly<F>,
}
854
855impl<F: TowerField> CompositeMLE<F> {
856 pub fn new(
857 n_vars: usize,
858 inner: impl IntoIterator<Item = MultilinearPolyOracle<F>>,
859 c: ArithCircuit<F>,
860 ) -> Result<Self, Error> {
861 let inner = inner
862 .into_iter()
863 .map(|oracle| {
864 if oracle.n_vars() == n_vars {
865 Ok(oracle.id())
866 } else {
867 Err(Error::IncorrectNumberOfVariables { expected: n_vars })
868 }
869 })
870 .collect::<Result<Vec<_>, _>>()?;
871 let c = ArithCircuitPoly::with_n_vars(inner.len(), c)
872 .map_err(|_| Error::CompositionMismatch)?; Ok(Self { n_vars, inner, c })
874 }
875
876 pub fn polys(&self) -> impl Iterator<Item = OracleId> + '_ {
877 self.inner.iter().copied()
878 }
879
880 pub fn n_polys(&self) -> usize {
881 self.inner.len()
882 }
883}
884
885impl<F: TowerField> MultilinearPolyOracle<F> {
886 pub const fn id(&self) -> OracleId {
887 self.id
888 }
889
890 pub fn label(&self) -> String {
891 match self.name() {
892 Some(name) => format!("{}: {}", self.type_str(), name),
893 None => format!("{}: id={}", self.type_str(), self.id()),
894 }
895 }
896
897 pub fn name(&self) -> Option<&str> {
898 self.name.as_deref()
899 }
900
901 const fn type_str(&self) -> &str {
902 match self.variant {
903 MultilinearPolyVariant::Transparent(_) => "Transparent",
904 MultilinearPolyVariant::Committed => "Committed",
905 MultilinearPolyVariant::Repeating { .. } => "Repeating",
906 MultilinearPolyVariant::Projected(_) => "Projected",
907 MultilinearPolyVariant::Shifted(_) => "Shifted",
908 MultilinearPolyVariant::Packed(_) => "Packed",
909 MultilinearPolyVariant::LinearCombination(_) => "LinearCombination",
910 MultilinearPolyVariant::ZeroPadded(_) => "ZeroPadded",
911 MultilinearPolyVariant::Composite(_) => "CompositeMLE",
912 }
913 }
914
915 pub const fn n_vars(&self) -> usize {
916 self.n_vars
917 }
918
919 pub const fn binary_tower_level(&self) -> usize {
921 self.tower_level
922 }
923
924 pub fn into_composite(self) -> CompositePolyOracle<F> {
925 let composite =
926 CompositePolyOracle::new(self.n_vars(), vec![self], IdentityCompositionPoly);
927 composite.expect("Can always apply the identity composition to one variable")
928 }
929}
930
#[cfg(test)]
mod tests {
    use binius_field::{BinaryField128b, BinaryField1b, Field, TowerField};

    use super::{MultilinearOracleSet, MultilinearPolyVariant};

    type F = BinaryField128b;

    /// Projecting away every variable of a committed oracle yields a 0-variable oracle at the
    /// inner oracle's tower level.
    #[test]
    fn add_projection_with_all_vars() {
        let mut oracles = MultilinearOracleSet::<F>::new();
        let data = oracles.add_committed(5, BinaryField1b::TOWER_LEVEL);
        let projected = oracles.add_projected(data, vec![F::ONE; 5], 0).unwrap();

        let oracle = oracles.oracle(projected);
        assert_eq!(oracle.n_vars(), 0);
        assert_eq!(oracle.binary_tower_level(), BinaryField1b::TOWER_LEVEL);
    }

    /// Supplying more projection values than the oracle has variables is rejected.
    #[test]
    fn add_projection_with_too_many_values_fails() {
        let mut oracles = MultilinearOracleSet::<F>::new();
        let data = oracles.add_committed(2, BinaryField1b::TOWER_LEVEL);
        assert!(oracles.add_projected(data, vec![F::ONE; 3], 0).is_err());
    }

    /// `add_projected_last_vars` projects the trailing variables.
    #[test]
    fn add_projection_of_last_vars() {
        let mut oracles = MultilinearOracleSet::<F>::new();
        let data = oracles.add_committed(5, BinaryField1b::TOWER_LEVEL);
        let projected = oracles
            .add_projected_last_vars(data, vec![F::ONE; 2])
            .unwrap();

        let oracle = oracles.oracle(projected);
        assert_eq!(oracle.n_vars(), 3);
        let MultilinearPolyVariant::Projected(inner) = &oracle.variant else {
            panic!("expected a Projected oracle");
        };
        assert_eq!(inner.start_index(), 3);
        assert_eq!(inner.id(), data);
    }
}