1use std::{array, sync::Arc};
4
5use binius_fast_compute::arith_circuit::ArithCircuitPoly;
6use binius_field::{BinaryField128b, Field, TowerField};
7use binius_macros::{DeserializeBytes, SerializeBytes};
8use binius_math::ArithCircuit;
9use binius_utils::{
10 DeserializeBytes, SerializationError, SerializationMode, SerializeBytes, bail, ensure,
11};
12use getset::{CopyGetters, Getters};
13
14use crate::{
15 oracle::{CompositePolyOracle, Error, OracleId},
16 polynomial::{Error as PolynomialError, IdentityCompositionPoly, MultivariatePoly},
17};
18
19pub struct MultilinearOracleSetAddition<'a, F: TowerField> {
22 name: Option<String>,
23 mut_ref: &'a mut MultilinearOracleSet<F>,
24}
25
26impl<F: TowerField> MultilinearOracleSetAddition<'_, F> {
27 pub fn transparent(self, poly: impl MultivariatePoly<F> + 'static) -> Result<OracleId, Error> {
28 if poly.binary_tower_level() > F::TOWER_LEVEL {
29 bail!(Error::TowerLevelTooHigh {
30 tower_level: poly.binary_tower_level(),
31 });
32 }
33
34 let inner = TransparentPolyOracle::new(Arc::new(poly))?;
35
36 let oracle = |id: OracleId| MultilinearPolyOracle {
37 id,
38 n_vars: inner.poly.n_vars(),
39 tower_level: inner.poly.binary_tower_level(),
40 name: self.name,
41 variant: MultilinearPolyVariant::Transparent(inner),
42 };
43
44 Ok(self.mut_ref.add_to_set(oracle))
45 }
46
47 pub fn structured(self, n_vars: usize, expr: ArithCircuit<F>) -> Result<OracleId, Error> {
48 if expr.binary_tower_level() > F::TOWER_LEVEL {
49 bail!(Error::TowerLevelTooHigh {
50 tower_level: expr.binary_tower_level(),
51 });
52 }
53
54 let oracle = |id: OracleId| MultilinearPolyOracle {
55 id,
56 n_vars,
57 tower_level: expr.binary_tower_level(),
58 name: self.name,
59 variant: MultilinearPolyVariant::Structured(expr),
60 };
61
62 Ok(self.mut_ref.add_to_set(oracle))
63 }
64
65 pub fn committed(mut self, n_vars: usize, tower_level: usize) -> OracleId {
66 let name = self.name.take();
67 self.add_committed_with_name(n_vars, tower_level, name)
68 }
69
70 pub fn committed_multiple<const N: usize>(
71 mut self,
72 n_vars: usize,
73 tower_level: usize,
74 ) -> [OracleId; N] {
75 match &self.name.take() {
76 None => [0; N].map(|_| self.add_committed_with_name(n_vars, tower_level, None)),
77 Some(s) => {
78 let x: [usize; N] = array::from_fn(|i| i);
79 x.map(|i| {
80 self.add_committed_with_name(n_vars, tower_level, Some(format!("{s}_{i}")))
81 })
82 }
83 }
84 }
85
86 pub fn repeating(self, inner_id: OracleId, log_count: usize) -> Result<OracleId, Error> {
87 ensure!(self.mut_ref.is_valid_oracle_id(inner_id), Error::InvalidOracleId(inner_id));
88
89 let inner = &self.mut_ref[inner_id];
90
91 let tower_level = inner.tower_level;
92 let n_vars = inner.n_vars + log_count;
93 let oracle = |id: OracleId| MultilinearPolyOracle {
94 id,
95 n_vars,
96 tower_level,
97 name: self.name,
98 variant: MultilinearPolyVariant::Repeating {
99 id: inner_id,
100 log_count,
101 },
102 };
103
104 Ok(self.mut_ref.add_to_set(oracle))
105 }
106
107 pub fn shifted(
108 self,
109 inner_id: OracleId,
110 offset: usize,
111 block_bits: usize,
112 variant: ShiftVariant,
113 ) -> Result<OracleId, Error> {
114 ensure!(self.mut_ref.is_valid_oracle_id(inner_id), Error::InvalidOracleId(inner_id));
115
116 let inner = &self.mut_ref[inner_id];
117 if block_bits > inner.n_vars {
118 bail!(PolynomialError::InvalidBlockSize {
119 n_vars: inner.n_vars,
120 });
121 }
122
123 if offset == 0 || offset >= 1 << block_bits {
124 bail!(PolynomialError::InvalidShiftOffset {
125 max_shift_offset: (1 << block_bits) - 1,
126 shift_offset: offset,
127 });
128 }
129
130 let shifted = Shifted::new(inner, offset, block_bits, variant)?;
131
132 let tower_level = inner.tower_level;
133 let n_vars = inner.n_vars;
134 let oracle = |id: OracleId| MultilinearPolyOracle {
135 id,
136 n_vars,
137 tower_level,
138 name: self.name,
139 variant: MultilinearPolyVariant::Shifted(shifted),
140 };
141
142 Ok(self.mut_ref.add_to_set(oracle))
143 }
144
145 pub fn packed(self, inner_id: OracleId, log_degree: usize) -> Result<OracleId, Error> {
146 ensure!(self.mut_ref.is_valid_oracle_id(inner_id), Error::InvalidOracleId(inner_id));
147
148 let inner_n_vars = self.mut_ref.n_vars(inner_id);
149 if log_degree > inner_n_vars {
150 bail!(Error::NotEnoughVarsForPacking {
151 n_vars: inner_n_vars,
152 log_degree,
153 });
154 }
155
156 let inner_tower_level = self.mut_ref.tower_level(inner_id);
157
158 let packed = Packed {
159 id: inner_id,
160 log_degree,
161 };
162
163 let oracle = |id: OracleId| MultilinearPolyOracle {
164 id,
165 n_vars: inner_n_vars - log_degree,
166 tower_level: inner_tower_level + log_degree,
167 name: self.name,
168 variant: MultilinearPolyVariant::Packed(packed),
169 };
170
171 Ok(self.mut_ref.add_to_set(oracle))
172 }
173
174 pub fn projected(
175 self,
176 inner_id: OracleId,
177 values: Vec<F>,
178 start_index: usize,
179 ) -> Result<OracleId, Error> {
180 ensure!(self.mut_ref.is_valid_oracle_id(inner_id), Error::InvalidOracleId(inner_id));
181
182 let inner_n_vars = self.mut_ref.n_vars(inner_id);
183 let values_len = values.len();
184 if values_len > inner_n_vars {
185 bail!(Error::InvalidProjection {
186 n_vars: inner_n_vars,
187 values_len,
188 });
189 }
190
191 let inner = &self.mut_ref[inner_id];
192 let tower_level = inner.binary_tower_level();
194 let projected = Projected::new(inner, values, start_index)?;
195
196 let oracle = |id: OracleId| MultilinearPolyOracle {
197 id,
198 n_vars: inner_n_vars - values_len,
199 tower_level,
200 name: self.name,
201 variant: MultilinearPolyVariant::Projected(projected),
202 };
203
204 Ok(self.mut_ref.add_to_set(oracle))
205 }
206
207 pub fn projected_last_vars(
208 self,
209 inner_id: OracleId,
210 values: Vec<F>,
211 ) -> Result<OracleId, Error> {
212 ensure!(self.mut_ref.is_valid_oracle_id(inner_id), Error::InvalidOracleId(inner_id));
213
214 let inner_n_vars = self.mut_ref.n_vars(inner_id);
215 let start_index = inner_n_vars - values.len();
216 let values_len = values.len();
217 if values_len > inner_n_vars {
218 bail!(Error::InvalidProjection {
219 n_vars: inner_n_vars,
220 values_len,
221 });
222 }
223
224 let inner = &self.mut_ref[inner_id];
225 let tower_level = inner.binary_tower_level();
227 let projected = Projected::new(inner, values, start_index)?;
228
229 let oracle = |id: OracleId| MultilinearPolyOracle {
230 id,
231 n_vars: inner_n_vars - values_len,
232 tower_level,
233 name: self.name,
234 variant: MultilinearPolyVariant::Projected(projected),
235 };
236
237 Ok(self.mut_ref.add_to_set(oracle))
238 }
239
240 pub fn linear_combination(
241 self,
242 n_vars: usize,
243 inner: impl IntoIterator<Item = (OracleId, F)>,
244 ) -> Result<OracleId, Error> {
245 self.linear_combination_with_offset(n_vars, F::ZERO, inner)
246 }
247
248 pub fn linear_combination_with_offset(
249 self,
250 n_vars: usize,
251 offset: F,
252 inner: impl IntoIterator<Item = (OracleId, F)>,
253 ) -> Result<OracleId, Error> {
254 let inner = inner
255 .into_iter()
256 .map(|(inner_id, coeff)| {
257 ensure!(
258 self.mut_ref.is_valid_oracle_id(inner_id),
259 Error::InvalidOracleId(inner_id)
260 );
261 if self.mut_ref.n_vars(inner_id) != n_vars {
262 return Err(Error::IncorrectNumberOfVariables { expected: n_vars });
263 }
264 Ok((self.mut_ref[inner_id].clone(), coeff))
265 })
266 .collect::<Result<Vec<_>, _>>()?;
267
268 let tower_level = inner
269 .iter()
270 .map(|(oracle, coeff)| oracle.binary_tower_level().max(coeff.min_tower_level()))
271 .max()
272 .unwrap_or(0)
273 .max(offset.min_tower_level());
274
275 let linear_combination = LinearCombination::new(n_vars, offset, inner)?;
276
277 let oracle = |id: OracleId| MultilinearPolyOracle {
278 id,
279 n_vars,
280 tower_level,
281 name: self.name,
282 variant: MultilinearPolyVariant::LinearCombination(linear_combination),
283 };
284
285 Ok(self.mut_ref.add_to_set(oracle))
286 }
287
288 pub fn composite_mle(
289 self,
290 n_vars: usize,
291 inner: impl IntoIterator<Item = OracleId>,
292 comp: ArithCircuit<F>,
293 ) -> Result<OracleId, Error> {
294 let inner = inner
295 .into_iter()
296 .map(|inner_id| {
297 ensure!(
298 self.mut_ref.is_valid_oracle_id(inner_id),
299 Error::InvalidOracleId(inner_id)
300 );
301 if self.mut_ref.n_vars(inner_id) != n_vars {
302 return Err(Error::IncorrectNumberOfVariables { expected: n_vars });
303 }
304 Ok(self.mut_ref[inner_id].clone())
305 })
306 .collect::<Result<Vec<_>, _>>()?;
307
308 let tower_level = inner
309 .iter()
310 .map(|oracle| oracle.binary_tower_level())
311 .max()
312 .unwrap_or(0);
313
314 let composite_mle = CompositeMLE::new(n_vars, inner, comp)?;
315
316 let oracle = |id: OracleId| MultilinearPolyOracle {
317 id,
318 n_vars,
319 tower_level,
320 name: self.name,
321 variant: MultilinearPolyVariant::Composite(composite_mle),
322 };
323
324 Ok(self.mut_ref.add_to_set(oracle))
325 }
326
327 pub fn zero_padded(
328 self,
329 inner_id: OracleId,
330 n_pad_vars: usize,
331 nonzero_index: usize,
332 start_index: usize,
333 ) -> Result<OracleId, Error> {
334 let inner_n_vars = self.mut_ref.n_vars(inner_id);
335 if start_index > inner_n_vars {
336 bail!(Error::InvalidStartIndex {
337 expected: inner_n_vars
338 });
339 }
340
341 let inner = &self.mut_ref[inner_id];
342 let tower_level = inner.binary_tower_level();
343 let padded = ZeroPadded::new(inner, n_pad_vars, nonzero_index, start_index)?;
344
345 let oracle = |id: OracleId| MultilinearPolyOracle {
346 id,
347 n_vars: inner_n_vars + n_pad_vars,
348 tower_level,
349 name: self.name,
350 variant: MultilinearPolyVariant::ZeroPadded(padded),
351 };
352
353 Ok(self.mut_ref.add_to_set(oracle))
354 }
355
356 fn add_committed_with_name(
357 &mut self,
358 n_vars: usize,
359 tower_level: usize,
360 name: Option<String>,
361 ) -> OracleId {
362 let oracle = |oracle_id: OracleId| MultilinearPolyOracle {
363 id: oracle_id,
364 n_vars,
365 tower_level,
366 name: name.clone(),
367 variant: MultilinearPolyVariant::Committed,
368 };
369
370 self.mut_ref.add_to_set(oracle)
371 }
372}
373
/// Collection of multilinear polynomial oracles, addressed by [`OracleId`].
///
/// Oracles are only ever appended (through `add_to_set`), and an oracle's id is its
/// position in `oracles` (see the `Index` impl).
///
/// The `DeserializeBytes` derive is instantiated with `F = BinaryField128b`
/// (`eval_generics`).
#[derive(Default, Debug, Clone, SerializeBytes, DeserializeBytes)]
#[deserialize_bytes(eval_generics(F = BinaryField128b))]
pub struct MultilinearOracleSet<F: TowerField> {
    // Oracles in insertion order; index `i` holds the oracle with id `OracleId::from_index(i)`.
    oracles: Vec<MultilinearPolyOracle<F>>,
}
387
388impl<F: TowerField> MultilinearOracleSet<F> {
389 pub const fn new() -> Self {
390 Self {
391 oracles: Vec::new(),
392 }
393 }
394
395 pub fn size(&self) -> usize {
396 self.oracles.len()
397 }
398
399 pub fn polys(&self) -> impl Iterator<Item = &MultilinearPolyOracle<F>> + '_ {
400 (0..self.oracles.len()).map(|index| &self[OracleId::from_index(index)])
401 }
402
403 pub fn ids(&self) -> impl Iterator<Item = OracleId> {
404 (0..self.oracles.len()).map(OracleId::from_index)
405 }
406
407 pub fn iter(&self) -> impl Iterator<Item = (OracleId, &MultilinearPolyOracle<F>)> + '_ {
408 (0..self.oracles.len()).map(|index| {
409 let oracle_id = OracleId::from_index(index);
410 (oracle_id, &self[oracle_id])
411 })
412 }
413
414 pub const fn add(&mut self) -> MultilinearOracleSetAddition<F> {
415 MultilinearOracleSetAddition {
416 name: None,
417 mut_ref: self,
418 }
419 }
420
421 pub fn add_named<S: ToString>(&mut self, s: S) -> MultilinearOracleSetAddition<F> {
422 MultilinearOracleSetAddition {
423 name: Some(s.to_string()),
424 mut_ref: self,
425 }
426 }
427
428 pub fn is_valid_oracle_id(&self, id: OracleId) -> bool {
429 id.index() < self.oracles.len()
430 }
431
432 fn add_to_set(
433 &mut self,
434 oracle: impl FnOnce(OracleId) -> MultilinearPolyOracle<F>,
435 ) -> OracleId {
436 let id = OracleId::from_index(self.oracles.len());
437 self.oracles.push(oracle(id));
438 id
439 }
440
441 pub fn add_transparent(
442 &mut self,
443 poly: impl MultivariatePoly<F> + 'static,
444 ) -> Result<OracleId, Error> {
445 self.add().transparent(poly)
446 }
447
448 pub fn add_committed(&mut self, n_vars: usize, tower_level: usize) -> OracleId {
449 self.add().committed(n_vars, tower_level)
450 }
451
452 pub fn add_committed_multiple<const N: usize>(
453 &mut self,
454 n_vars: usize,
455 tower_level: usize,
456 ) -> [OracleId; N] {
457 self.add().committed_multiple(n_vars, tower_level)
458 }
459
460 pub fn add_repeating(&mut self, id: OracleId, log_count: usize) -> Result<OracleId, Error> {
461 self.add().repeating(id, log_count)
462 }
463
464 pub fn add_shifted(
465 &mut self,
466 id: OracleId,
467 offset: usize,
468 block_bits: usize,
469 variant: ShiftVariant,
470 ) -> Result<OracleId, Error> {
471 self.add().shifted(id, offset, block_bits, variant)
472 }
473
474 pub fn add_packed(&mut self, id: OracleId, log_degree: usize) -> Result<OracleId, Error> {
475 self.add().packed(id, log_degree)
476 }
477
478 pub fn add_projected(
480 &mut self,
481 id: OracleId,
482 values: Vec<F>,
483 start_index: usize,
484 ) -> Result<OracleId, Error> {
485 self.add().projected(id, values, start_index)
486 }
487
488 pub fn add_projected_last_vars(
490 &mut self,
491 id: OracleId,
492 values: Vec<F>,
493 ) -> Result<OracleId, Error> {
494 self.add().projected_last_vars(id, values)
495 }
496
497 pub fn add_linear_combination(
498 &mut self,
499 n_vars: usize,
500 inner: impl IntoIterator<Item = (OracleId, F)>,
501 ) -> Result<OracleId, Error> {
502 self.add().linear_combination(n_vars, inner)
503 }
504
505 pub fn add_linear_combination_with_offset(
506 &mut self,
507 n_vars: usize,
508 offset: F,
509 inner: impl IntoIterator<Item = (OracleId, F)>,
510 ) -> Result<OracleId, Error> {
511 self.add()
512 .linear_combination_with_offset(n_vars, offset, inner)
513 }
514
515 pub fn add_zero_padded(
516 &mut self,
517 id: OracleId,
518 n_pad_vars: usize,
519 nonzero_index: usize,
520 start_index: usize,
521 ) -> Result<OracleId, Error> {
522 self.add()
523 .zero_padded(id, n_pad_vars, nonzero_index, start_index)
524 }
525
526 pub fn add_composite_mle(
527 &mut self,
528 n_vars: usize,
529 inner: impl IntoIterator<Item = OracleId>,
530 comp: ArithCircuit<F>,
531 ) -> Result<OracleId, Error> {
532 self.add().composite_mle(n_vars, inner, comp)
533 }
534
535 pub fn n_vars(&self, id: OracleId) -> usize {
536 self[id].n_vars()
537 }
538
539 pub fn label(&self, id: OracleId) -> String {
540 self[id].label()
541 }
542
543 pub fn tower_level(&self, id: OracleId) -> usize {
545 self[id].binary_tower_level()
546 }
547}
548
549impl<F: TowerField> std::ops::Index<OracleId> for MultilinearOracleSet<F> {
550 type Output = MultilinearPolyOracle<F>;
551
552 fn index(&self, id: OracleId) -> &Self::Output {
553 &self.oracles[id.index()]
554 }
555}
556
/// A multilinear polynomial oracle: its id, optional name, shape and definition.
///
/// Field order defines the derived serialization layout.
#[derive(Debug, Clone, PartialEq, Eq, SerializeBytes, DeserializeBytes)]
#[deserialize_bytes(eval_generics(F = BinaryField128b))]
pub struct MultilinearPolyOracle<F: TowerField> {
    /// Position of this oracle in its [`MultilinearOracleSet`].
    pub id: OracleId,
    /// Optional human-readable name; used by `label`.
    pub name: Option<String>,
    /// Number of variables.
    pub n_vars: usize,
    /// Binary tower level of the oracle (compared against `F::TOWER_LEVEL` when adding
    /// transparent/structured oracles).
    pub tower_level: usize,
    /// How the oracle is defined.
    pub variant: MultilinearPolyVariant<F>,
}
586
/// How a [`MultilinearPolyOracle`] is defined.
///
/// Variant order matters: the hand-written `DeserializeBytes` impl decodes tags `0..=8`
/// in this declaration order. This presumably mirrors the tags emitted by the
/// `SerializeBytes` derive, so keep the two in sync. `Composite` currently has no
/// deserialization arm.
#[derive(Debug, Clone, PartialEq, Eq, SerializeBytes)]
pub enum MultilinearPolyVariant<F: TowerField> {
    /// Created by `MultilinearOracleSetAddition::committed`; carries no payload.
    Committed,
    /// Backed by an explicit [`MultivariatePoly`].
    Transparent(TransparentPolyOracle<F>),
    /// Defined by an arithmetic circuit (see `MultilinearOracleSetAddition::structured`).
    Structured(ArithCircuit<F>),
    /// Repeats oracle `id`, adding `log_count` variables.
    Repeating {
        id: OracleId,
        log_count: usize,
    },
    /// Projection of another oracle (see [`Projected`]).
    Projected(Projected<F>),
    /// Shift of another oracle (see [`Shifted`]).
    Shifted(Shifted),
    /// Packing of another oracle (see [`Packed`]).
    Packed(Packed),
    /// Linear combination of other oracles (see [`LinearCombination`]).
    LinearCombination(LinearCombination<F>),
    /// Zero-padded view of another oracle (see [`ZeroPadded`]).
    ZeroPadded(ZeroPadded),
    /// Circuit applied to other oracles (see [`CompositeMLE`]).
    Composite(CompositeMLE<F>),
}
609
610impl<F: TowerField> MultilinearPolyVariant<F> {
611 pub fn is_committed(&self) -> bool {
613 matches!(self, Self::Committed)
614 }
615}
616
/// Decodes a [`MultilinearPolyVariant`] from a leading `u8` tag followed by the variant's
/// payload.
///
/// Tags `0..=8` map to `Committed` through `ZeroPadded` in declaration order. This
/// presumably mirrors the derived `SerializeBytes` encoding, so keep both in sync when
/// adding variants.
///
/// NOTE(review): there is no arm for `Composite` (presumably tag 9 on the serialize side).
/// A serialized `Composite` is therefore rejected with
/// `SerializationError::UnknownEnumVariant`. `CompositeMLE` does not derive
/// `DeserializeBytes`, which is likely why.
impl DeserializeBytes for MultilinearPolyVariant<BinaryField128b> {
    fn deserialize(
        mut buf: impl bytes::Buf,
        mode: SerializationMode,
    ) -> Result<Self, SerializationError>
    where
        Self: Sized,
    {
        // The variant tag is read first; the payload (if any) follows.
        Ok(match u8::deserialize(&mut buf, mode)? {
            0 => Self::Committed,
            1 => Self::Transparent(DeserializeBytes::deserialize(buf, mode)?),
            2 => Self::Structured(DeserializeBytes::deserialize(buf, mode)?),
            // Fields are read in the order written: `id`, then `log_count`.
            3 => Self::Repeating {
                id: DeserializeBytes::deserialize(&mut buf, mode)?,
                log_count: DeserializeBytes::deserialize(buf, mode)?,
            },
            4 => Self::Projected(DeserializeBytes::deserialize(buf, mode)?),
            5 => Self::Shifted(DeserializeBytes::deserialize(buf, mode)?),
            6 => Self::Packed(DeserializeBytes::deserialize(buf, mode)?),
            7 => Self::LinearCombination(DeserializeBytes::deserialize(buf, mode)?),
            8 => Self::ZeroPadded(DeserializeBytes::deserialize(buf, mode)?),
            // Any other tag, including `Composite`'s, is rejected.
            variant_index => {
                return Err(SerializationError::UnknownEnumVariant {
                    name: "MultilinearPolyVariant",
                    index: variant_index,
                });
            }
        })
    }
}
647
648#[derive(Debug, Clone, Getters, CopyGetters)]
652pub struct TransparentPolyOracle<F: Field> {
653 #[get = "pub"]
654 poly: Arc<dyn MultivariatePoly<F>>,
655}
656
657impl<F: TowerField> SerializeBytes for TransparentPolyOracle<F> {
658 fn serialize(
659 &self,
660 mut write_buf: impl bytes::BufMut,
661 mode: SerializationMode,
662 ) -> Result<(), SerializationError> {
663 self.poly.erased_serialize(&mut write_buf, mode)
664 }
665}
666
667impl DeserializeBytes for TransparentPolyOracle<BinaryField128b> {
668 fn deserialize(
669 read_buf: impl bytes::Buf,
670 mode: SerializationMode,
671 ) -> Result<Self, SerializationError>
672 where
673 Self: Sized,
674 {
675 Ok(Self {
676 poly: Box::<dyn MultivariatePoly<BinaryField128b>>::deserialize(read_buf, mode)?.into(),
677 })
678 }
679}
680
681impl<F: TowerField> TransparentPolyOracle<F> {
682 fn new(poly: Arc<dyn MultivariatePoly<F>>) -> Result<Self, Error> {
683 if poly.binary_tower_level() > F::TOWER_LEVEL {
684 bail!(Error::TowerLevelTooHigh {
685 tower_level: poly.binary_tower_level(),
686 });
687 }
688 Ok(Self { poly })
689 }
690}
691
692impl<F: Field> TransparentPolyOracle<F> {
693 pub fn binary_tower_level(&self) -> usize {
695 self.poly.binary_tower_level()
696 }
697}
698
699impl<F: Field> PartialEq for TransparentPolyOracle<F> {
700 fn eq(&self, other: &Self) -> bool {
701 Arc::ptr_eq(&self.poly, &other.poly)
702 }
703}
704
705impl<F: Field> Eq for TransparentPolyOracle<F> {}
706
/// Payload of [`MultilinearPolyVariant::Projected`].
///
/// The projected oracle has `values.len()` fewer variables than oracle `id`. `values`
/// replace the variables starting at `start_index`; `projected_last_vars` uses
/// `start_index = n_vars - values.len()`. Field order defines the serialization layout.
#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)]
pub struct Projected<F: TowerField> {
    /// Oracle being projected.
    #[get_copy = "pub"]
    id: OracleId,
    /// Values substituted for the projected variables.
    #[get = "pub"]
    values: Vec<F>,
    /// Index of the first projected variable.
    #[get_copy = "pub"]
    start_index: usize,
}
716
717impl<F: TowerField> Projected<F> {
718 fn new(
719 oracle: &MultilinearPolyOracle<F>,
720 values: Vec<F>,
721 start_index: usize,
722 ) -> Result<Self, Error> {
723 if values.len() + start_index > oracle.n_vars() {
724 bail!(Error::InvalidProjection {
725 n_vars: oracle.n_vars(),
726 values_len: values.len()
727 });
728 }
729 Ok(Self {
730 id: oracle.id(),
731 values,
732 start_index,
733 })
734 }
735}
736
/// Payload of [`MultilinearPolyVariant::ZeroPadded`].
///
/// The padded oracle has `n_pad_vars` more variables than oracle `id`. Field order
/// defines the serialization layout.
#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)]
pub struct ZeroPadded {
    /// Oracle being padded.
    #[get_copy = "pub"]
    id: OracleId,
    /// Number of padding variables added.
    #[get_copy = "pub"]
    n_pad_vars: usize,
    /// Presumably the assignment of the padding variables at which the inner values
    /// appear (all other assignments being zero). Confirm against the evaluator.
    #[get_copy = "pub"]
    nonzero_index: usize,
    /// Variable position at which the padding variables are inserted; at most the inner
    /// `n_vars`.
    #[get_copy = "pub"]
    start_index: usize,
}
748
749impl ZeroPadded {
750 fn new<F: TowerField>(
751 oracle: &MultilinearPolyOracle<F>,
752 n_pad_vars: usize,
753 nonzero_index: usize,
754 start_index: usize,
755 ) -> Result<Self, Error> {
756 if start_index > oracle.n_vars() {
757 bail!(Error::InvalidStartIndex {
758 expected: oracle.n_vars(),
759 });
760 }
761
762 if nonzero_index > 1 << n_pad_vars {
763 bail!(Error::InvalidNonzeroIndex {
764 expected: 1 << n_pad_vars,
765 });
766 }
767
768 Ok(Self {
769 id: oracle.id(),
770 n_pad_vars,
771 nonzero_index,
772 start_index,
773 })
774 }
775}
776
/// Kind of shift applied by a [`Shifted`] oracle.
///
/// Variant order defines the derived serialization tags.
#[derive(Debug, Copy, Clone, PartialEq, Eq, SerializeBytes, DeserializeBytes)]
pub enum ShiftVariant {
    /// Circular (rotating) shift to the left.
    CircularLeft,
    /// Logical shift to the left.
    LogicalLeft,
    /// Logical shift to the right.
    LogicalRight,
}
783
/// Payload of [`MultilinearPolyVariant::Shifted`].
///
/// Field order defines the serialization layout.
#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)]
pub struct Shifted {
    /// Oracle being shifted.
    #[get_copy = "pub"]
    id: OracleId,
    /// Shift amount; `Shifted::new` requires `1 <= shift_offset < 2^block_size`.
    #[get_copy = "pub"]
    shift_offset: usize,
    /// Block size in bits (`block_bits` at the call site); at most the inner `n_vars`.
    #[get_copy = "pub"]
    block_size: usize,
    /// Kind of shift.
    #[get_copy = "pub"]
    shift_variant: ShiftVariant,
}
795
796impl Shifted {
797 fn new<F: TowerField>(
798 oracle: &MultilinearPolyOracle<F>,
799 shift_offset: usize,
800 block_size: usize,
801 shift_variant: ShiftVariant,
802 ) -> Result<Self, Error> {
803 if block_size > oracle.n_vars() {
804 bail!(PolynomialError::InvalidBlockSize {
805 n_vars: oracle.n_vars(),
806 });
807 }
808
809 if shift_offset == 0 || shift_offset >= 1 << block_size {
810 bail!(PolynomialError::InvalidShiftOffset {
811 max_shift_offset: (1 << block_size) - 1,
812 shift_offset,
813 });
814 }
815
816 Ok(Self {
817 id: oracle.id(),
818 shift_offset,
819 block_size,
820 shift_variant,
821 })
822 }
823}
824
/// Payload of [`MultilinearPolyVariant::Packed`].
///
/// Relative to oracle `id`, the packed oracle has `log_degree` fewer variables and a
/// tower level raised by `log_degree` (see `MultilinearOracleSetAddition::packed`).
#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)]
pub struct Packed {
    /// Oracle being packed.
    #[get_copy = "pub"]
    id: OracleId,
    /// Base-2 log of the packing factor; at most the inner oracle's `n_vars`.
    #[get_copy = "pub"]
    log_degree: usize,
}
838
/// Payload of [`MultilinearPolyVariant::LinearCombination`]: `offset` plus the sum of
/// `coefficient * oracle` over the `inner` terms.
///
/// Field order defines the serialization layout.
#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)]
pub struct LinearCombination<F: TowerField> {
    /// Number of variables shared by every term.
    #[get_copy = "pub"]
    n_vars: usize,
    /// Constant term.
    #[get_copy = "pub"]
    offset: F,
    // `(oracle id, coefficient)` pairs; exposed through `polys` / `coefficients`.
    inner: Vec<(OracleId, F)>,
}
847
848impl<F: TowerField> LinearCombination<F> {
849 fn new(
850 n_vars: usize,
851 offset: F,
852 inner: impl IntoIterator<Item = (MultilinearPolyOracle<F>, F)>,
853 ) -> Result<Self, Error> {
854 let inner = inner
855 .into_iter()
856 .map(|(oracle, value)| {
857 if oracle.n_vars() == n_vars {
858 Ok((oracle.id(), value))
859 } else {
860 Err(Error::IncorrectNumberOfVariables { expected: n_vars })
861 }
862 })
863 .collect::<Result<Vec<_>, _>>()?;
864 Ok(Self {
865 n_vars,
866 offset,
867 inner,
868 })
869 }
870
871 pub fn n_polys(&self) -> usize {
872 self.inner.len()
873 }
874
875 pub fn polys(&self) -> impl Iterator<Item = OracleId> + '_ {
876 self.inner.iter().map(|(id, _)| *id)
877 }
878
879 pub fn coefficients(&self) -> impl Iterator<Item = F> + '_ {
880 self.inner.iter().map(|(_, coeff)| *coeff)
881 }
882}
883
/// Payload of [`MultilinearPolyVariant::Composite`]: the circuit `c` applied to the
/// `inner` oracles.
///
/// `c` is instantiated with one variable per inner oracle (`CompositeMLE::new`).
/// Presumably circuit variable `i` is bound to `inner[i]`; confirm against the evaluator.
/// Only `SerializeBytes` is derived; there is no `DeserializeBytes` impl.
#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes)]
pub struct CompositeMLE<F: TowerField> {
    /// Number of variables shared by every inner oracle.
    #[get_copy = "pub"]
    n_vars: usize,
    /// Ids of the inner oracles.
    #[getset(get = "pub")]
    inner: Vec<OracleId>,
    /// Composition circuit over `inner.len()` variables.
    #[getset(get = "pub")]
    c: ArithCircuitPoly<F>,
}
903
904impl<F: TowerField> CompositeMLE<F> {
905 pub fn new(
906 n_vars: usize,
907 inner: impl IntoIterator<Item = MultilinearPolyOracle<F>>,
908 c: ArithCircuit<F>,
909 ) -> Result<Self, Error> {
910 let inner = inner
911 .into_iter()
912 .map(|oracle| {
913 if oracle.n_vars() == n_vars {
914 Ok(oracle.id())
915 } else {
916 Err(Error::IncorrectNumberOfVariables { expected: n_vars })
917 }
918 })
919 .collect::<Result<Vec<_>, _>>()?;
920 let c = ArithCircuitPoly::with_n_vars(inner.len(), c)
921 .map_err(|_| Error::CompositionMismatch)?; Ok(Self { n_vars, inner, c })
923 }
924
925 pub fn polys(&self) -> impl Iterator<Item = OracleId> + '_ {
926 self.inner.iter().copied()
927 }
928
929 pub fn n_polys(&self) -> usize {
930 self.inner.len()
931 }
932}
933
934impl<F: TowerField> MultilinearPolyOracle<F> {
935 pub const fn id(&self) -> OracleId {
936 self.id
937 }
938
939 pub fn label(&self) -> String {
940 match self.name() {
941 Some(name) => format!("{}: {}", self.type_str(), name),
942 None => format!("{}: id={}", self.type_str(), self.id()),
943 }
944 }
945
946 pub fn name(&self) -> Option<&str> {
947 self.name.as_deref()
948 }
949
950 const fn type_str(&self) -> &str {
951 match self.variant {
952 MultilinearPolyVariant::Transparent(_) => "Transparent",
953 MultilinearPolyVariant::Structured(_) => "Structured",
954 MultilinearPolyVariant::Committed => "Committed",
955 MultilinearPolyVariant::Repeating { .. } => "Repeating",
956 MultilinearPolyVariant::Projected(_) => "Projected",
957 MultilinearPolyVariant::Shifted(_) => "Shifted",
958 MultilinearPolyVariant::Packed(_) => "Packed",
959 MultilinearPolyVariant::LinearCombination(_) => "LinearCombination",
960 MultilinearPolyVariant::ZeroPadded(_) => "ZeroPadded",
961 MultilinearPolyVariant::Composite(_) => "CompositeMLE",
962 }
963 }
964
965 pub const fn n_vars(&self) -> usize {
966 self.n_vars
967 }
968
969 pub const fn binary_tower_level(&self) -> usize {
971 self.tower_level
972 }
973
974 pub fn into_composite(self) -> CompositePolyOracle<F> {
975 let composite =
976 CompositePolyOracle::new(self.n_vars(), vec![self], IdentityCompositionPoly);
977 composite.expect("Can always apply the identity composition to one variable")
978 }
979}
980
#[cfg(test)]
mod tests {
    use binius_field::{BinaryField1b, BinaryField128b, Field, TowerField};

    use super::MultilinearOracleSet;

    /// Projecting away every variable of a committed oracle must succeed, and the
    /// resulting oracle must be retrievable from the set.
    #[test]
    fn add_projection_with_all_vars() {
        type F = BinaryField128b;
        const N_VARS: usize = 5;

        let mut oracles = MultilinearOracleSet::<F>::new();
        let committed = oracles.add_committed(N_VARS, BinaryField1b::TOWER_LEVEL);
        let projected = oracles
            .add_projected(committed, vec![F::ONE; N_VARS], 0)
            .unwrap();
        let _ = &oracles[projected];
    }
}