1use std::{array, sync::Arc};
4
5use binius_fast_compute::arith_circuit::ArithCircuitPoly;
6use binius_field::{BinaryField128b, Field, TowerField};
7use binius_macros::{DeserializeBytes, SerializeBytes};
8use binius_math::ArithCircuit;
9use binius_utils::{
10 DeserializeBytes, SerializationError, SerializationMode, SerializeBytes, bail, ensure,
11};
12use getset::{CopyGetters, Getters};
13
14use crate::{
15 oracle::{CompositePolyOracle, Error, OracleId},
16 polynomial::{Error as PolynomialError, IdentityCompositionPoly, MultivariatePoly},
17};
18
/// Builder returned by [`MultilinearOracleSet::add`] and [`MultilinearOracleSet::add_named`].
///
/// It carries an optional name for the oracle(s) about to be added together with a mutable
/// borrow of the target set. Each consuming method appends one oracle to the set on success
/// (`committed_multiple` appends `N`).
pub struct MultilinearOracleSetAddition<'a, F: TowerField> {
	/// Name given to the oracle(s) being added; `None` for unnamed additions.
	name: Option<String>,
	/// The set that new oracles are appended to.
	mut_ref: &'a mut MultilinearOracleSet<F>,
}
25
26impl<F: TowerField> MultilinearOracleSetAddition<'_, F> {
27 pub fn transparent(self, poly: impl MultivariatePoly<F> + 'static) -> Result<OracleId, Error> {
28 if poly.binary_tower_level() > F::TOWER_LEVEL {
29 bail!(Error::TowerLevelTooHigh {
30 tower_level: poly.binary_tower_level(),
31 });
32 }
33
34 let inner = TransparentPolyOracle::new(Arc::new(poly))?;
35
36 let oracle = |id: OracleId| MultilinearPolyOracle {
37 id,
38 n_vars: inner.poly.n_vars(),
39 tower_level: inner.poly.binary_tower_level(),
40 name: self.name,
41 variant: MultilinearPolyVariant::Transparent(inner),
42 };
43
44 Ok(self.mut_ref.add_to_set(oracle))
45 }
46
47 pub fn structured(self, n_vars: usize, expr: ArithCircuit<F>) -> Result<OracleId, Error> {
48 if expr.binary_tower_level() > F::TOWER_LEVEL {
49 bail!(Error::TowerLevelTooHigh {
50 tower_level: expr.binary_tower_level(),
51 });
52 }
53
54 let oracle = |id: OracleId| MultilinearPolyOracle {
55 id,
56 n_vars,
57 tower_level: expr.binary_tower_level(),
58 name: self.name,
59 variant: MultilinearPolyVariant::Structured(expr),
60 };
61
62 Ok(self.mut_ref.add_to_set(oracle))
63 }
64
65 pub fn committed(mut self, n_vars: usize, tower_level: usize) -> OracleId {
66 let name = self.name.take();
67 self.add_committed_with_name(n_vars, tower_level, name)
68 }
69
70 pub fn committed_multiple<const N: usize>(
71 mut self,
72 n_vars: usize,
73 tower_level: usize,
74 ) -> [OracleId; N] {
75 match &self.name.take() {
76 None => [0; N].map(|_| self.add_committed_with_name(n_vars, tower_level, None)),
77 Some(s) => {
78 let x: [usize; N] = array::from_fn(|i| i);
79 x.map(|i| {
80 self.add_committed_with_name(n_vars, tower_level, Some(format!("{s}_{i}")))
81 })
82 }
83 }
84 }
85
86 pub fn repeating(self, inner_id: OracleId, log_count: usize) -> Result<OracleId, Error> {
87 ensure!(self.mut_ref.is_valid_oracle_id(inner_id), Error::InvalidOracleId(inner_id));
88
89 let inner = &self.mut_ref[inner_id];
90
91 let tower_level = inner.tower_level;
92 let n_vars = inner.n_vars + log_count;
93 let oracle = |id: OracleId| MultilinearPolyOracle {
94 id,
95 n_vars,
96 tower_level,
97 name: self.name,
98 variant: MultilinearPolyVariant::Repeating {
99 id: inner_id,
100 log_count,
101 },
102 };
103
104 Ok(self.mut_ref.add_to_set(oracle))
105 }
106
107 pub fn shifted(
108 self,
109 inner_id: OracleId,
110 offset: usize,
111 block_bits: usize,
112 variant: ShiftVariant,
113 ) -> Result<OracleId, Error> {
114 ensure!(self.mut_ref.is_valid_oracle_id(inner_id), Error::InvalidOracleId(inner_id));
115
116 let n_vars = self.mut_ref.n_vars(inner_id);
117 let tower_level = self.mut_ref.tower_level(inner_id);
118
119 if block_bits > n_vars {
120 bail!(PolynomialError::InvalidBlockSize { n_vars });
121 }
122
123 if offset == 0 || offset >= 1 << block_bits {
124 bail!(PolynomialError::InvalidShiftOffset {
125 max_shift_offset: (1 << block_bits) - 1,
126 shift_offset: offset,
127 });
128 }
129
130 let shifted = Shifted::new(self.mut_ref, inner_id, offset, block_bits, variant)?;
131
132 let oracle = |id: OracleId| MultilinearPolyOracle {
133 id,
134 n_vars,
135 tower_level,
136 name: self.name,
137 variant: MultilinearPolyVariant::Shifted(shifted),
138 };
139
140 Ok(self.mut_ref.add_to_set(oracle))
141 }
142
143 pub fn packed(self, inner_id: OracleId, log_degree: usize) -> Result<OracleId, Error> {
144 ensure!(self.mut_ref.is_valid_oracle_id(inner_id), Error::InvalidOracleId(inner_id));
145
146 let inner_n_vars = self.mut_ref.n_vars(inner_id);
147 if log_degree > inner_n_vars {
148 bail!(Error::NotEnoughVarsForPacking {
149 n_vars: inner_n_vars,
150 log_degree,
151 });
152 }
153
154 let inner_tower_level = self.mut_ref.tower_level(inner_id);
155
156 let packed = Packed {
157 id: inner_id,
158 log_degree,
159 };
160
161 let oracle = |id: OracleId| MultilinearPolyOracle {
162 id,
163 n_vars: inner_n_vars - log_degree,
164 tower_level: inner_tower_level + log_degree,
165 name: self.name,
166 variant: MultilinearPolyVariant::Packed(packed),
167 };
168
169 Ok(self.mut_ref.add_to_set(oracle))
170 }
171
172 pub fn projected(
173 self,
174 inner_id: OracleId,
175 values: Vec<F>,
176 start_index: usize,
177 ) -> Result<OracleId, Error> {
178 ensure!(self.mut_ref.is_valid_oracle_id(inner_id), Error::InvalidOracleId(inner_id));
179
180 let inner_n_vars = self.mut_ref.n_vars(inner_id);
181 let values_len = values.len();
182 if values_len > inner_n_vars {
183 bail!(Error::InvalidProjection {
184 n_vars: inner_n_vars,
185 values_len,
186 });
187 }
188
189 let tower_level = self.mut_ref.tower_level(inner_id);
191 let projected = Projected::new(self.mut_ref, inner_id, values, start_index)?;
192
193 let oracle = |id: OracleId| MultilinearPolyOracle {
194 id,
195 n_vars: inner_n_vars - values_len,
196 tower_level,
197 name: self.name,
198 variant: MultilinearPolyVariant::Projected(projected),
199 };
200
201 Ok(self.mut_ref.add_to_set(oracle))
202 }
203
204 pub fn projected_last_vars(
205 self,
206 inner_id: OracleId,
207 values: Vec<F>,
208 ) -> Result<OracleId, Error> {
209 ensure!(self.mut_ref.is_valid_oracle_id(inner_id), Error::InvalidOracleId(inner_id));
210
211 let inner_n_vars = self.mut_ref.n_vars(inner_id);
212 let start_index = inner_n_vars - values.len();
213 let values_len = values.len();
214 if values_len > inner_n_vars {
215 bail!(Error::InvalidProjection {
216 n_vars: inner_n_vars,
217 values_len,
218 });
219 }
220
221 let tower_level = self.mut_ref.tower_level(inner_id);
223 let projected = Projected::new(self.mut_ref, inner_id, values, start_index)?;
224
225 let oracle = |id: OracleId| MultilinearPolyOracle {
226 id,
227 n_vars: inner_n_vars - values_len,
228 tower_level,
229 name: self.name,
230 variant: MultilinearPolyVariant::Projected(projected),
231 };
232
233 Ok(self.mut_ref.add_to_set(oracle))
234 }
235
236 pub fn linear_combination(
237 self,
238 n_vars: usize,
239 inner: impl IntoIterator<Item = (OracleId, F)>,
240 ) -> Result<OracleId, Error> {
241 self.linear_combination_with_offset(n_vars, F::ZERO, inner)
242 }
243
244 pub fn linear_combination_with_offset(
245 self,
246 n_vars: usize,
247 offset: F,
248 inner: impl IntoIterator<Item = (OracleId, F)>,
249 ) -> Result<OracleId, Error> {
250 let inner = inner.into_iter().collect::<Vec<_>>();
251 let tower_level = inner
252 .iter()
253 .map(|(oracle, coeff)| {
254 self.mut_ref
255 .tower_level(*oracle)
256 .max(coeff.min_tower_level())
257 })
258 .max()
259 .unwrap_or(0)
260 .max(offset.min_tower_level());
261 let linear_combination = LinearCombination::new(self.mut_ref, n_vars, offset, inner)?;
262
263 let oracle = |id: OracleId| MultilinearPolyOracle {
264 id,
265 n_vars,
266 tower_level,
267 name: self.name,
268 variant: MultilinearPolyVariant::LinearCombination(linear_combination),
269 };
270
271 Ok(self.mut_ref.add_to_set(oracle))
272 }
273
274 pub fn composite_mle(
275 self,
276 n_vars: usize,
277 inner: impl IntoIterator<Item = OracleId>,
278 comp: ArithCircuit<F>,
279 ) -> Result<OracleId, Error> {
280 let inner: Vec<OracleId> = inner.into_iter().collect();
281 let tower_level = inner
282 .iter()
283 .map(|&oracle| self.mut_ref[oracle].binary_tower_level())
284 .max()
285 .unwrap_or(0);
286 let composite_mle = CompositeMLE::new(self.mut_ref, n_vars, inner, comp)?;
287
288 let oracle = |id: OracleId| MultilinearPolyOracle {
289 id,
290 n_vars,
291 tower_level,
292 name: self.name,
293 variant: MultilinearPolyVariant::Composite(composite_mle),
294 };
295
296 Ok(self.mut_ref.add_to_set(oracle))
297 }
298
299 pub fn zero_padded(
300 self,
301 inner_id: OracleId,
302 n_pad_vars: usize,
303 nonzero_index: usize,
304 start_index: usize,
305 ) -> Result<OracleId, Error> {
306 let inner_n_vars = self.mut_ref.n_vars(inner_id);
307 let tower_level = self.mut_ref.tower_level(inner_id);
308 if start_index > inner_n_vars {
309 bail!(Error::InvalidStartIndex {
310 expected: inner_n_vars
311 });
312 }
313 let padded =
314 ZeroPadded::new(self.mut_ref, inner_id, n_pad_vars, nonzero_index, start_index)?;
315
316 let oracle = |id: OracleId| MultilinearPolyOracle {
317 id,
318 n_vars: inner_n_vars + n_pad_vars,
319 tower_level,
320 name: self.name,
321 variant: MultilinearPolyVariant::ZeroPadded(padded),
322 };
323
324 Ok(self.mut_ref.add_to_set(oracle))
325 }
326
327 fn add_committed_with_name(
328 &mut self,
329 n_vars: usize,
330 tower_level: usize,
331 name: Option<String>,
332 ) -> OracleId {
333 let oracle = |oracle_id: OracleId| MultilinearPolyOracle {
334 id: oracle_id,
335 n_vars,
336 tower_level,
337 name: name.clone(),
338 variant: MultilinearPolyVariant::Committed,
339 };
340
341 self.mut_ref.add_to_set(oracle)
342 }
343}
344
/// A collection of multilinear polynomial oracles addressed by [`OracleId`].
///
/// Each [`OracleId`] is the index of a slot in `oracles`. In the code visible here, slots
/// are only ever appended: by `add_to_set` (a present oracle) or `add_skip` (an empty slot).
#[derive(Default, Debug, Clone, SerializeBytes, DeserializeBytes)]
// NOTE(review): appears to restrict the derived `DeserializeBytes` to `F = BinaryField128b`,
// matching the manual `DeserializeBytes` impls in this file — confirm against the macro docs.
#[deserialize_bytes(eval_generics(F = BinaryField128b))]
pub struct MultilinearOracleSet<F: TowerField> {
	/// One slot per allocated id; `None` marks a slot reserved with `add_skip`.
	oracles: Vec<Option<MultilinearPolyOracle<F>>>,
	/// Number of oracles added through `add_to_set`, i.e. the number of `Some` slots.
	size: usize,
}
364
365impl<F: TowerField> MultilinearOracleSet<F> {
366 pub const fn new() -> Self {
367 Self {
368 oracles: Vec::new(),
369 size: 0,
370 }
371 }
372
373 pub fn size(&self) -> usize {
374 self.size
375 }
376
377 pub fn polys(&self) -> impl Iterator<Item = &MultilinearPolyOracle<F>> + '_ {
378 self.iter().map(|(_, poly)| poly)
379 }
380
381 pub fn ids(&self) -> impl Iterator<Item = OracleId> {
382 self.iter().map(|(id, _)| id)
383 }
384
385 pub fn iter(&self) -> impl Iterator<Item = (OracleId, &MultilinearPolyOracle<F>)> + '_ {
386 (0..self.oracles.len()).filter_map(|index| match self.oracles[index] {
387 Some(ref oracle) => {
388 let oracle_id = OracleId::from_index(index);
389 Some((oracle_id, oracle))
390 }
391 None => None,
392 })
393 }
394
395 pub const fn add(&mut self) -> MultilinearOracleSetAddition<F> {
396 MultilinearOracleSetAddition {
397 name: None,
398 mut_ref: self,
399 }
400 }
401
402 pub fn add_named<S: ToString>(&mut self, s: S) -> MultilinearOracleSetAddition<F> {
403 MultilinearOracleSetAddition {
404 name: Some(s.to_string()),
405 mut_ref: self,
406 }
407 }
408
409 pub fn is_valid_oracle_id(&self, id: OracleId) -> bool {
410 id.index() < self.oracles.len()
411 }
412
413 pub(crate) fn add_to_set(
414 &mut self,
415 oracle: impl FnOnce(OracleId) -> MultilinearPolyOracle<F>,
416 ) -> OracleId {
417 let id = OracleId::from_index(self.oracles.len());
418 self.oracles.push(Some(oracle(id)));
419 self.size += 1;
420 id
421 }
422
423 pub(crate) fn add_skip(&mut self) {
426 self.oracles.push(None);
427 }
429
430 pub fn add_transparent(
431 &mut self,
432 poly: impl MultivariatePoly<F> + 'static,
433 ) -> Result<OracleId, Error> {
434 self.add().transparent(poly)
435 }
436
437 pub fn add_committed(&mut self, n_vars: usize, tower_level: usize) -> OracleId {
438 self.add().committed(n_vars, tower_level)
439 }
440
441 pub fn add_committed_multiple<const N: usize>(
442 &mut self,
443 n_vars: usize,
444 tower_level: usize,
445 ) -> [OracleId; N] {
446 self.add().committed_multiple(n_vars, tower_level)
447 }
448
449 pub fn add_repeating(&mut self, id: OracleId, log_count: usize) -> Result<OracleId, Error> {
450 self.add().repeating(id, log_count)
451 }
452
453 pub fn add_shifted(
454 &mut self,
455 id: OracleId,
456 offset: usize,
457 block_bits: usize,
458 variant: ShiftVariant,
459 ) -> Result<OracleId, Error> {
460 self.add().shifted(id, offset, block_bits, variant)
461 }
462
463 pub fn add_packed(&mut self, id: OracleId, log_degree: usize) -> Result<OracleId, Error> {
464 self.add().packed(id, log_degree)
465 }
466
467 pub fn add_projected(
469 &mut self,
470 id: OracleId,
471 values: Vec<F>,
472 start_index: usize,
473 ) -> Result<OracleId, Error> {
474 self.add().projected(id, values, start_index)
475 }
476
477 pub fn add_projected_last_vars(
479 &mut self,
480 id: OracleId,
481 values: Vec<F>,
482 ) -> Result<OracleId, Error> {
483 self.add().projected_last_vars(id, values)
484 }
485
486 pub fn add_linear_combination(
487 &mut self,
488 n_vars: usize,
489 inner: impl IntoIterator<Item = (OracleId, F)>,
490 ) -> Result<OracleId, Error> {
491 self.add().linear_combination(n_vars, inner)
492 }
493
494 pub fn add_linear_combination_with_offset(
495 &mut self,
496 n_vars: usize,
497 offset: F,
498 inner: impl IntoIterator<Item = (OracleId, F)>,
499 ) -> Result<OracleId, Error> {
500 self.add()
501 .linear_combination_with_offset(n_vars, offset, inner)
502 }
503
504 pub fn add_zero_padded(
505 &mut self,
506 id: OracleId,
507 n_pad_vars: usize,
508 nonzero_index: usize,
509 start_index: usize,
510 ) -> Result<OracleId, Error> {
511 self.add()
512 .zero_padded(id, n_pad_vars, nonzero_index, start_index)
513 }
514
515 pub fn add_composite_mle(
516 &mut self,
517 n_vars: usize,
518 inner: impl IntoIterator<Item = OracleId>,
519 comp: ArithCircuit<F>,
520 ) -> Result<OracleId, Error> {
521 self.add().composite_mle(n_vars, inner, comp)
522 }
523
524 pub fn n_vars(&self, id: OracleId) -> usize {
525 self[id].n_vars()
526 }
527
528 pub fn label(&self, id: OracleId) -> String {
529 self[id].label()
530 }
531
532 pub fn tower_level(&self, id: OracleId) -> usize {
534 self[id].binary_tower_level()
535 }
536
537 pub fn is_zero_sized(&self, id: OracleId) -> bool {
540 self.oracles[id.index()].is_none()
541 }
542}
543
544impl<F: TowerField> std::ops::Index<OracleId> for MultilinearOracleSet<F> {
545 type Output = MultilinearPolyOracle<F>;
546
547 fn index(&self, id: OracleId) -> &Self::Output {
548 self.oracles[id.index()]
549 .as_ref()
550 .expect("tried to access skipped oracle")
551 }
552}
553
/// A single multilinear polynomial oracle as stored in a [`MultilinearOracleSet`].
///
/// NOTE(review): field order presumably defines the derived serialization layout — avoid
/// reordering fields.
#[derive(Debug, Clone, PartialEq, Eq, SerializeBytes, DeserializeBytes)]
#[deserialize_bytes(eval_generics(F = BinaryField128b))]
pub struct MultilinearPolyOracle<F: TowerField> {
	/// Id of this oracle within its set (assigned by `MultilinearOracleSet::add_to_set`).
	pub id: OracleId,
	/// Optional human-readable name; used by [`MultilinearPolyOracle::label`].
	pub name: Option<String>,
	/// Number of variables of the multilinear polynomial.
	pub n_vars: usize,
	/// Binary tower level returned by `binary_tower_level()`.
	pub tower_level: usize,
	/// How this oracle is defined (committed, transparent, or derived from other oracles).
	pub variant: MultilinearPolyVariant<F>,
}
583
/// Definition of a [`MultilinearPolyOracle`].
///
/// NOTE(review): the manual `DeserializeBytes` impl for this enum decodes tags `0..=9` in
/// exactly this declaration order, and the derived `SerializeBytes` presumably writes the
/// declaration index as that tag — keep the variant order and that impl in sync.
#[derive(Debug, Clone, PartialEq, Eq, SerializeBytes)]
pub enum MultilinearPolyVariant<F: TowerField> {
	/// Committed oracle; carries no definition data of its own.
	Committed,
	/// Backed by an explicit polynomial.
	Transparent(TransparentPolyOracle<F>),
	/// Defined by an arithmetic circuit.
	Structured(ArithCircuit<F>),
	/// Oracle `id` extended by `log_count` additional variables.
	Repeating {
		id: OracleId,
		log_count: usize,
	},
	/// Some variables of another oracle fixed to given values.
	Projected(Projected<F>),
	/// Shifted view of another oracle.
	Shifted(Shifted),
	/// Packed view of another oracle (fewer variables, higher tower level).
	Packed(Packed),
	/// Linear combination of other oracles plus an offset.
	LinearCombination(LinearCombination<F>),
	/// Another oracle embedded with extra padding variables.
	ZeroPadded(ZeroPadded),
	/// Arithmetic circuit applied to other oracles.
	Composite(CompositeMLE<F>),
}
606
607impl<F: TowerField> MultilinearPolyVariant<F> {
608 pub fn is_committed(&self) -> bool {
610 matches!(self, Self::Committed)
611 }
612}
613
614impl DeserializeBytes for MultilinearPolyVariant<BinaryField128b> {
615 fn deserialize(
616 mut buf: impl bytes::Buf,
617 mode: SerializationMode,
618 ) -> Result<Self, SerializationError>
619 where
620 Self: Sized,
621 {
622 Ok(match u8::deserialize(&mut buf, mode)? {
623 0 => Self::Committed,
624 1 => Self::Transparent(DeserializeBytes::deserialize(buf, mode)?),
625 2 => Self::Structured(DeserializeBytes::deserialize(buf, mode)?),
626 3 => Self::Repeating {
627 id: DeserializeBytes::deserialize(&mut buf, mode)?,
628 log_count: DeserializeBytes::deserialize(buf, mode)?,
629 },
630 4 => Self::Projected(DeserializeBytes::deserialize(buf, mode)?),
631 5 => Self::Shifted(DeserializeBytes::deserialize(buf, mode)?),
632 6 => Self::Packed(DeserializeBytes::deserialize(buf, mode)?),
633 7 => Self::LinearCombination(DeserializeBytes::deserialize(buf, mode)?),
634 8 => Self::ZeroPadded(DeserializeBytes::deserialize(buf, mode)?),
635 9 => Self::Composite(DeserializeBytes::deserialize(buf, mode)?),
636 variant_index => {
637 return Err(SerializationError::UnknownEnumVariant {
638 name: "MultilinearPolyVariant",
639 index: variant_index,
640 });
641 }
642 })
643 }
644}
645
/// An oracle backed by an explicit [`MultivariatePoly`] shared behind an [`Arc`].
///
/// Equality between two `TransparentPolyOracle`s is pointer identity of the shared
/// polynomial (`Arc::ptr_eq`), not structural equality.
#[derive(Debug, Clone, Getters, CopyGetters)]
pub struct TransparentPolyOracle<F: Field> {
	/// The wrapped polynomial.
	#[get = "pub"]
	poly: Arc<dyn MultivariatePoly<F>>,
}
654
655impl<F: TowerField> SerializeBytes for TransparentPolyOracle<F> {
656 fn serialize(
657 &self,
658 mut write_buf: impl bytes::BufMut,
659 mode: SerializationMode,
660 ) -> Result<(), SerializationError> {
661 self.poly.erased_serialize(&mut write_buf, mode)
662 }
663}
664
665impl DeserializeBytes for TransparentPolyOracle<BinaryField128b> {
666 fn deserialize(
667 read_buf: impl bytes::Buf,
668 mode: SerializationMode,
669 ) -> Result<Self, SerializationError>
670 where
671 Self: Sized,
672 {
673 Ok(Self {
674 poly: Box::<dyn MultivariatePoly<BinaryField128b>>::deserialize(read_buf, mode)?.into(),
675 })
676 }
677}
678
679impl<F: TowerField> TransparentPolyOracle<F> {
680 pub fn new(poly: Arc<dyn MultivariatePoly<F>>) -> Result<Self, Error> {
681 if poly.binary_tower_level() > F::TOWER_LEVEL {
682 bail!(Error::TowerLevelTooHigh {
683 tower_level: poly.binary_tower_level(),
684 });
685 }
686 Ok(Self { poly })
687 }
688}
689
690impl<F: Field> TransparentPolyOracle<F> {
691 pub fn binary_tower_level(&self) -> usize {
693 self.poly.binary_tower_level()
694 }
695}
696
697impl<F: Field> PartialEq for TransparentPolyOracle<F> {
698 fn eq(&self, other: &Self) -> bool {
699 Arc::ptr_eq(&self.poly, &other.poly)
700 }
701}
702
703impl<F: Field> Eq for TransparentPolyOracle<F> {}
704
/// Derived oracle obtained by fixing `values.len()` variables of oracle `id` to `values`.
///
/// [`Projected::new`] requires `start_index + values.len()` not to exceed the inner
/// oracle's number of variables; when added to a set, the projected oracle has
/// `values.len()` fewer variables than the inner one.
#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)]
pub struct Projected<F: TowerField> {
	/// Id of the inner (projected) oracle.
	#[get_copy = "pub"]
	id: OracleId,
	/// Values assigned to the projected variables.
	#[get = "pub"]
	values: Vec<F>,
	/// Index of the first projected variable.
	#[get_copy = "pub"]
	start_index: usize,
}
714
715impl<F: TowerField> Projected<F> {
716 pub fn new(
717 oracles: &MultilinearOracleSet<F>,
718 id: OracleId,
719 values: Vec<F>,
720 start_index: usize,
721 ) -> Result<Self, Error> {
722 if values.len() + start_index > oracles.n_vars(id) {
723 bail!(Error::InvalidProjection {
724 n_vars: oracles.n_vars(id),
725 values_len: values.len()
726 });
727 }
728 Ok(Self {
729 id,
730 values,
731 start_index,
732 })
733 }
734}
735
/// Derived oracle that embeds oracle `id` into a space with `n_pad_vars` extra variables.
///
/// When added to a set, the padded oracle has `inner n_vars + n_pad_vars` variables and the
/// inner oracle's tower level.
#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)]
pub struct ZeroPadded {
	/// Id of the inner oracle.
	#[get_copy = "pub"]
	id: OracleId,
	/// Number of padding variables added.
	#[get_copy = "pub"]
	n_pad_vars: usize,
	/// NOTE(review): presumably selects which assignment of the padding variables holds the
	/// inner oracle's values (bounded by `1 << n_pad_vars` in `ZeroPadded::new`) — confirm.
	#[get_copy = "pub"]
	nonzero_index: usize,
	/// Presumably the variable position at which the padding variables are inserted; must
	/// not exceed the inner oracle's `n_vars`.
	#[get_copy = "pub"]
	start_index: usize,
}
747
748impl ZeroPadded {
749 pub fn new<F: TowerField>(
750 oracles: &MultilinearOracleSet<F>,
751 id: OracleId,
752 n_pad_vars: usize,
753 nonzero_index: usize,
754 start_index: usize,
755 ) -> Result<Self, Error> {
756 let oracle = &oracles[id];
757 if start_index > oracle.n_vars() {
758 bail!(Error::InvalidStartIndex {
759 expected: oracle.n_vars(),
760 });
761 }
762
763 if nonzero_index > 1 << n_pad_vars {
764 bail!(Error::InvalidNonzeroIndex {
765 expected: 1 << n_pad_vars,
766 });
767 }
768
769 Ok(Self {
770 id,
771 n_pad_vars,
772 nonzero_index,
773 start_index,
774 })
775 }
776}
777
/// Kind of shift applied by a [`Shifted`] oracle.
///
/// NOTE(review): variant order presumably defines the derived serialization tag — avoid
/// reordering.
#[derive(Debug, Copy, Clone, PartialEq, Eq, SerializeBytes, DeserializeBytes)]
pub enum ShiftVariant {
	/// Presumably a circular (wrapping) left shift within each block.
	CircularLeft,
	/// Presumably a left shift within each block that fills vacated positions with zero.
	LogicalLeft,
	/// Presumably a right shift within each block that fills vacated positions with zero.
	LogicalRight,
}
784
/// Derived oracle that shifts the values of oracle `id`.
///
/// [`Shifted::new`] requires `block_size <= inner n_vars` and
/// `0 < shift_offset < 1 << block_size`.
#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)]
pub struct Shifted {
	/// Id of the inner oracle.
	#[get_copy = "pub"]
	id: OracleId,
	/// Shift amount, in `1..(1 << block_size)`.
	#[get_copy = "pub"]
	shift_offset: usize,
	/// Block size in bits: offsets are bounded by `1 << block_size`, and it may not exceed
	/// the inner oracle's `n_vars`.
	#[get_copy = "pub"]
	block_size: usize,
	/// Which kind of shift is applied.
	#[get_copy = "pub"]
	shift_variant: ShiftVariant,
}
796
797impl Shifted {
798 pub fn new<F: TowerField>(
799 oracles: &MultilinearOracleSet<F>,
800 id: OracleId,
801 shift_offset: usize,
802 block_size: usize,
803 shift_variant: ShiftVariant,
804 ) -> Result<Self, Error> {
805 if block_size > oracles.n_vars(id) {
806 bail!(PolynomialError::InvalidBlockSize {
807 n_vars: oracles.n_vars(id),
808 });
809 }
810
811 if shift_offset == 0 || shift_offset >= 1 << block_size {
812 bail!(PolynomialError::InvalidShiftOffset {
813 max_shift_offset: (1 << block_size) - 1,
814 shift_offset,
815 });
816 }
817
818 Ok(Self {
819 id,
820 shift_offset,
821 block_size,
822 shift_variant,
823 })
824 }
825}
826
/// Derived oracle that packs oracle `id` by `log_degree`.
///
/// When added to a set, the packed oracle has `log_degree` fewer variables and a tower
/// level `log_degree` higher than the inner oracle.
#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)]
pub struct Packed {
	/// Id of the inner oracle.
	#[get_copy = "pub"]
	id: OracleId,
	/// Base-2 logarithm of the packing factor.
	#[get_copy = "pub"]
	log_degree: usize,
}
840
841impl Packed {
842 pub fn new(id: OracleId, log_degree: usize) -> Self {
843 Self { id, log_degree }
844 }
845}
846
/// Derived oracle formed from `(oracle, coefficient)` pairs plus a constant `offset`.
///
/// Presumably it evaluates to `offset + Σ coefficient · oracle`. All inner oracles must have
/// `n_vars` variables (checked by [`LinearCombination::new`]).
#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)]
pub struct LinearCombination<F: TowerField> {
	/// Number of variables shared by all inner oracles.
	#[get_copy = "pub"]
	n_vars: usize,
	/// Constant term.
	#[get_copy = "pub"]
	offset: F,
	/// `(oracle id, coefficient)` pairs; exposed via `polys()` and `coefficients()`.
	inner: Vec<(OracleId, F)>,
}
855
856impl<F: TowerField> LinearCombination<F> {
857 pub fn new(
858 oracles: &MultilinearOracleSet<F>,
859 n_vars: usize,
860 offset: F,
861 inner: Vec<(OracleId, F)>,
862 ) -> Result<Self, Error> {
863 for (oracle, _) in &inner {
864 if oracles[*oracle].n_vars() != n_vars {
865 return Err(Error::IncorrectNumberOfVariables { expected: n_vars });
866 }
867 }
868 Ok(Self {
869 n_vars,
870 offset,
871 inner,
872 })
873 }
874
875 pub fn n_polys(&self) -> usize {
876 self.inner.len()
877 }
878
879 pub fn polys(&self) -> impl Iterator<Item = OracleId> + '_ {
880 self.inner.iter().map(|(id, _)| *id)
881 }
882
883 pub fn coefficients(&self) -> impl Iterator<Item = F> + '_ {
884 self.inner.iter().map(|(_, coeff)| *coeff)
885 }
886}
887
/// Derived oracle that applies the arithmetic circuit `c` to the `inner` oracles.
///
/// [`CompositeMLE::new`] checks that every inner oracle has `n_vars` variables and builds
/// `c` with one circuit variable per inner oracle.
#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, DeserializeBytes, SerializeBytes)]
pub struct CompositeMLE<F: TowerField> {
	/// Number of variables shared by all inner oracles.
	#[get_copy = "pub"]
	n_vars: usize,
	/// Ids of the inner oracles, in circuit-variable order.
	#[getset(get = "pub")]
	inner: Vec<OracleId>,
	/// The composition circuit, over `inner.len()` variables.
	#[getset(get = "pub")]
	c: ArithCircuitPoly<F>,
}
907
908impl<F: TowerField> CompositeMLE<F> {
909 pub fn new(
910 oracles: &MultilinearOracleSet<F>,
911 n_vars: usize,
912 inner: Vec<OracleId>,
913 c: ArithCircuit<F>,
914 ) -> Result<Self, Error> {
915 for &oracle in &inner {
916 if oracles.n_vars(oracle) != n_vars {
917 return Err(Error::IncorrectNumberOfVariables { expected: n_vars });
918 }
919 }
920 let c = ArithCircuitPoly::with_n_vars(inner.len(), c)
921 .map_err(|_| Error::CompositionMismatch)?; Ok(Self { n_vars, inner, c })
923 }
924
925 pub fn polys(&self) -> impl Iterator<Item = OracleId> + '_ {
926 self.inner.iter().copied()
927 }
928
929 pub fn n_polys(&self) -> usize {
930 self.inner.len()
931 }
932}
933
934impl<F: TowerField> MultilinearPolyOracle<F> {
935 pub const fn id(&self) -> OracleId {
936 self.id
937 }
938
939 pub fn label(&self) -> String {
940 match self.name() {
941 Some(name) => format!("{}: {}", self.type_str(), name),
942 None => format!("{}: id={}", self.type_str(), self.id()),
943 }
944 }
945
946 pub fn name(&self) -> Option<&str> {
947 self.name.as_deref()
948 }
949
950 const fn type_str(&self) -> &str {
951 match self.variant {
952 MultilinearPolyVariant::Transparent(_) => "Transparent",
953 MultilinearPolyVariant::Structured(_) => "Structured",
954 MultilinearPolyVariant::Committed => "Committed",
955 MultilinearPolyVariant::Repeating { .. } => "Repeating",
956 MultilinearPolyVariant::Projected(_) => "Projected",
957 MultilinearPolyVariant::Shifted(_) => "Shifted",
958 MultilinearPolyVariant::Packed(_) => "Packed",
959 MultilinearPolyVariant::LinearCombination(_) => "LinearCombination",
960 MultilinearPolyVariant::ZeroPadded(_) => "ZeroPadded",
961 MultilinearPolyVariant::Composite(_) => "CompositeMLE",
962 }
963 }
964
965 pub const fn n_vars(&self) -> usize {
966 self.n_vars
967 }
968
969 pub const fn binary_tower_level(&self) -> usize {
971 self.tower_level
972 }
973
974 pub fn into_composite(self) -> CompositePolyOracle<F> {
975 let composite =
976 CompositePolyOracle::new(self.n_vars(), vec![self], IdentityCompositionPoly);
977 composite.expect("Can always apply the identity composition to one variable")
978 }
979}
980
#[cfg(test)]
mod tests {
	use binius_field::{BinaryField1b, BinaryField128b, Field, TowerField};

	use super::MultilinearOracleSet;

	/// Projecting every variable of an oracle (starting at index 0) must be accepted, and
	/// the resulting oracle must be retrievable from the set.
	#[test]
	fn add_projection_with_all_vars() {
		type F = BinaryField128b;
		const N_VARS: usize = 5;

		let mut oracles = MultilinearOracleSet::<F>::new();
		let committed = oracles.add_committed(N_VARS, BinaryField1b::TOWER_LEVEL);
		let projected = oracles
			.add_projected(committed, vec![F::ONE; N_VARS], 0)
			.unwrap();
		let _ = &oracles[projected];
	}
}