1use std::{array, sync::Arc};
7
8use binius_field::{BinaryField128b, TowerField};
9use binius_macros::{DeserializeBytes, SerializeBytes};
10use binius_math::ArithCircuit;
11use binius_utils::{
12 DeserializeBytes, SerializationError, SerializationMode, bail,
13 checked_arithmetics::log2_ceil_usize,
14};
15
16use super::{
17 CompositeMLE, LinearCombination, MultilinearOracleSet, MultilinearPolyOracle,
18 MultilinearPolyVariant, Packed, Projected, ShiftVariant, Shifted, TransparentPolyOracle,
19 ZeroPadded,
20};
21use crate::{
22 constraint_system::TableId,
23 oracle::{Error, OracleId},
24 polynomial::{Error as PolynomialError, MultivariatePoly},
25};
26
/// Description of a multilinear oracle that is not yet bound to concrete table sizes.
///
/// `SymbolicMultilinearOracleSet::instantiate` converts each of these into a
/// `MultilinearPolyOracle` once the table sizes are supplied.
#[derive(Debug, Clone, PartialEq, Eq, SerializeBytes, DeserializeBytes)]
#[deserialize_bytes(eval_generics(F = BinaryField128b))]
pub struct SymbolicMultilinearOracle<F: TowerField> {
    /// Identifier assigned when the oracle was appended to its set.
    pub id: OracleId,
    /// Optional human-readable label.
    pub name: Option<String>,
    /// Index into the `table_sizes` slice passed to `instantiate`.
    pub table_id: TableId,
    /// Added to `log2_ceil(table_size)` to obtain the variable count of every
    /// non-transparent oracle at instantiation time.
    pub log_values_per_row: usize,
    /// Tower level copied unchanged onto the instantiated oracle.
    pub tower_level: usize,
    /// How the oracle is defined.
    pub variant: SymbolicMultilinearPolyVariant<F>,
}
37
/// Selects how the start index of a projected oracle is determined.
#[derive(Debug, Copy, Clone, PartialEq, Eq, SerializeBytes, DeserializeBytes)]
pub enum ProjectionVariant {
    /// The contained value is used directly as the start index.
    Offset(usize),
    /// The start index is computed at instantiation time as `n_vars - values.len()`,
    /// where `n_vars` is the variable count `instantiate` derives for the projected oracle.
    Last,
}
45
/// Size-independent definition of a multilinear oracle.
///
/// The hand-written `DeserializeBytes` impl for this enum maps the `u8` tags 0 through 9
/// to the variants in declaration order, so the order here must not change without
/// updating that impl.
// NOTE(review): presumably the derived `SerializeBytes` writes the same declaration-order
// index as the tag — confirm against the macro.
#[derive(Debug, Clone, PartialEq, Eq, SerializeBytes)]
pub enum SymbolicMultilinearPolyVariant<F: TowerField> {
    /// Tag 0. Carries no further data.
    Committed,
    /// Tag 1. Wraps a transparent polynomial; its `n_vars` comes from the polynomial
    /// itself rather than from the table size.
    Transparent(TransparentPolyOracle<F>),
    /// Tag 2. Defined by an arithmetic circuit.
    Structured(ArithCircuit<F>),
    /// Tag 3. Repeats the oracle `id`; the repetition count is derived at instantiation
    /// as `n_vars - mos.n_vars(id)`.
    Repeating {
        id: OracleId,
    },
    /// Tag 4. Projection of oracle `id` using `values`; `variant` selects the start index.
    Projected {
        id: OracleId,
        values: Vec<F>,
        variant: ProjectionVariant,
    },
    /// Tag 5. Shift of oracle `id`; the fields are forwarded to `Shifted::new`.
    Shifted {
        id: OracleId,
        shift_offset: usize,
        block_size: usize,
        shift_variant: ShiftVariant,
    },
    /// Tag 6. Packing of oracle `id`; the fields are forwarded to `Packed::new`.
    Packed {
        id: OracleId,
        log_degree: usize,
    },
    /// Tag 7. Constant `offset` plus `(oracle, coefficient)` terms, forwarded to
    /// `LinearCombination::new`.
    LinearCombination {
        offset: F,
        inner: Vec<(OracleId, F)>,
    },
    /// Tag 8. Zero padding of oracle `id`; the fields are forwarded to `ZeroPadded::new`.
    ZeroPadded {
        id: OracleId,
        n_pad_vars: usize,
        nonzero_index: usize,
        start_index: usize,
    },
    /// Tag 9. Composition of the `inner` oracles through `circuit`, forwarded to
    /// `CompositeMLE::new`.
    Composite {
        inner: Vec<OracleId>,
        circuit: ArithCircuit<F>,
    },
}
96
97impl DeserializeBytes for SymbolicMultilinearPolyVariant<BinaryField128b> {
98 fn deserialize(
99 mut buf: impl bytes::Buf,
100 mode: SerializationMode,
101 ) -> Result<Self, SerializationError>
102 where
103 Self: Sized,
104 {
105 Ok(match u8::deserialize(&mut buf, mode)? {
106 0 => Self::Committed,
107 1 => Self::Transparent(DeserializeBytes::deserialize(buf, mode)?),
108 2 => Self::Structured(DeserializeBytes::deserialize(buf, mode)?),
109 3 => Self::Repeating {
110 id: DeserializeBytes::deserialize(&mut buf, mode)?,
111 },
112 4 => Self::Projected {
113 id: DeserializeBytes::deserialize(&mut buf, mode)?,
114 values: DeserializeBytes::deserialize(&mut buf, mode)?,
115 variant: DeserializeBytes::deserialize(buf, mode)?,
116 },
117 5 => Self::Shifted {
118 id: DeserializeBytes::deserialize(&mut buf, mode)?,
119 shift_offset: DeserializeBytes::deserialize(&mut buf, mode)?,
120 block_size: DeserializeBytes::deserialize(&mut buf, mode)?,
121 shift_variant: DeserializeBytes::deserialize(buf, mode)?,
122 },
123 6 => Self::Packed {
124 id: DeserializeBytes::deserialize(&mut buf, mode)?,
125 log_degree: DeserializeBytes::deserialize(buf, mode)?,
126 },
127 7 => Self::LinearCombination {
128 offset: DeserializeBytes::deserialize(&mut buf, mode)?,
129 inner: DeserializeBytes::deserialize(buf, mode)?,
130 },
131 8 => Self::ZeroPadded {
132 id: DeserializeBytes::deserialize(&mut buf, mode)?,
133 n_pad_vars: DeserializeBytes::deserialize(&mut buf, mode)?,
134 nonzero_index: DeserializeBytes::deserialize(&mut buf, mode)?,
135 start_index: DeserializeBytes::deserialize(buf, mode)?,
136 },
137 9 => Self::Composite {
138 inner: DeserializeBytes::deserialize(&mut buf, mode)?,
139 circuit: DeserializeBytes::deserialize(buf, mode)?,
140 },
141 variant_index => {
142 return Err(SerializationError::UnknownEnumVariant {
143 name: "SymbolicMultilinearPolyVariant",
144 index: variant_index,
145 });
146 }
147 })
148 }
149}
150
/// Ordered collection of symbolic oracles.
///
/// An oracle's `OracleId` is derived from its position in `oracles`, so entries are only
/// ever appended.
#[derive(Default, Debug, Clone, SerializeBytes, DeserializeBytes)]
#[deserialize_bytes(eval_generics(F = BinaryField128b))]
pub struct SymbolicMultilinearOracleSet<F: TowerField> {
    /// Oracles in insertion order; indexed by `OracleId::index()`.
    oracles: Vec<SymbolicMultilinearOracle<F>>,
}
156
157impl<F: TowerField> SymbolicMultilinearOracleSet<F> {
158 pub fn new() -> Self {
159 Self {
160 oracles: Vec::new(),
161 }
162 }
163
164 pub fn instantiate(&self, table_sizes: &[usize]) -> Result<MultilinearOracleSet<F>, Error> {
166 let mut mos = MultilinearOracleSet::new();
167 for oracle in &self.oracles {
168 let table_size = table_sizes
169 .get(oracle.table_id)
170 .ok_or(Error::TableSizeMissing {
171 table_id: oracle.table_id,
172 })?;
173 if *table_size == 0 {
174 mos.add_skip();
175 continue;
176 }
177 let log_capacity = log2_ceil_usize(*table_size);
178 let n_vars = match oracle.variant {
182 SymbolicMultilinearPolyVariant::Transparent(ref transparent_poly_oracle) => {
183 transparent_poly_oracle.poly().n_vars()
184 }
185 _ => log_capacity + oracle.log_values_per_row,
186 };
187 let tower_level = oracle.tower_level;
188 let variant = instantiate_oracle_variant(&mos, oracle, n_vars)?;
189 mos.add_to_set(|id: OracleId| MultilinearPolyOracle {
190 id,
191 name: oracle.name.clone(),
192 n_vars,
193 tower_level,
194 variant,
195 });
196 }
197 Ok(mos)
198 }
199
200 pub fn add_oracle<S: ToString>(
202 &mut self,
203 table_id: usize,
204 log_values_per_row: usize,
205 s: S,
206 ) -> Builder<'_, F> {
207 Builder {
208 mut_ref: self,
209 name: Some(s.to_string()),
210 table_id,
211 log_values_per_row,
212 }
213 }
214
215 fn add_to_set(
216 &mut self,
217 oracle: impl FnOnce(OracleId) -> SymbolicMultilinearOracle<F>,
218 ) -> OracleId {
219 let id = OracleId::from_index(self.oracles.len());
220 self.oracles.push(oracle(id));
221 id
222 }
223
224 pub fn size(&self) -> usize {
225 self.oracles.len()
226 }
227
228 pub fn polys(&self) -> impl Iterator<Item = &SymbolicMultilinearOracle<F>> + '_ {
229 (0..self.oracles.len()).map(|index| &self[OracleId::from_index(index)])
230 }
231
232 pub fn ids(&self) -> impl Iterator<Item = OracleId> {
233 (0..self.oracles.len()).map(OracleId::from_index)
234 }
235
236 pub fn iter(&self) -> impl Iterator<Item = (OracleId, &SymbolicMultilinearOracle<F>)> + '_ {
237 (0..self.oracles.len()).map(|index| {
238 let oracle_id = OracleId::from_index(index);
239 (oracle_id, &self[oracle_id])
240 })
241 }
242
243 pub fn label(&self, oracle_id: OracleId) -> Option<String> {
244 self[oracle_id].name.clone()
245 }
246}
247
248fn instantiate_oracle_variant<F: TowerField>(
249 mos: &MultilinearOracleSet<F>,
250 oracle: &SymbolicMultilinearOracle<F>,
251 n_vars: usize,
252) -> Result<MultilinearPolyVariant<F>, Error> {
253 use self::{MultilinearPolyVariant as Sized, SymbolicMultilinearPolyVariant as Symbolic};
254
255 let variant = match &oracle.variant {
256 Symbolic::Committed => MultilinearPolyVariant::Committed,
257 Symbolic::Transparent(transparent_poly_oracle) => {
258 Sized::Transparent(transparent_poly_oracle.clone())
259 }
260 Symbolic::Structured(arith_circuit) => Sized::Structured(arith_circuit.clone()),
261 Symbolic::Repeating { id } => {
262 let log_count = n_vars - mos.n_vars(*id);
263 Sized::Repeating { id: *id, log_count }
264 }
265 Symbolic::Projected {
266 id,
267 values,
268 variant,
269 } => {
270 let start_index = match variant {
271 ProjectionVariant::Offset(offset) => *offset,
272 ProjectionVariant::Last => n_vars - values.len(),
273 };
274 let projected = Projected::new(mos, *id, values.clone(), start_index)?;
275 Sized::Projected(projected)
276 }
277 Symbolic::Shifted {
278 id,
279 shift_offset,
280 block_size,
281 shift_variant,
282 } => {
283 let shifted = Shifted::new(mos, *id, *shift_offset, *block_size, *shift_variant)?;
284 MultilinearPolyVariant::Shifted(shifted)
285 }
286 Symbolic::Packed { id, log_degree } => {
287 let packed = Packed::new(*id, *log_degree);
288 MultilinearPolyVariant::Packed(packed)
289 }
290 Symbolic::LinearCombination { offset, inner } => {
291 let linear_combination = LinearCombination::new(mos, n_vars, *offset, inner.clone())?;
292 MultilinearPolyVariant::LinearCombination(linear_combination)
293 }
294 Symbolic::ZeroPadded {
295 id,
296 n_pad_vars,
297 nonzero_index,
298 start_index,
299 } => {
300 let zero_padded = ZeroPadded::new(mos, *id, *n_pad_vars, *nonzero_index, *start_index)?;
301 MultilinearPolyVariant::ZeroPadded(zero_padded)
302 }
303 Symbolic::Composite { inner, circuit } => {
304 let composite_mle = CompositeMLE::new(mos, n_vars, inner.clone(), circuit.clone())?;
305 MultilinearPolyVariant::Composite(composite_mle)
306 }
307 };
308 Ok(variant)
309}
310
311impl<F: TowerField> std::ops::Index<OracleId> for SymbolicMultilinearOracleSet<F> {
312 type Output = SymbolicMultilinearOracle<F>;
313
314 fn index(&self, id: OracleId) -> &Self::Output {
315 &self.oracles[id.index()]
316 }
317}
318
/// In-progress oracle definition returned by `SymbolicMultilinearOracleSet::add_oracle`.
///
/// Each finishing method appends the described oracle to the borrowed set.
pub struct Builder<'a, F: TowerField> {
    /// Set that the finished oracle is appended to.
    mut_ref: &'a mut SymbolicMultilinearOracleSet<F>,
    /// Label given to the new oracle.
    name: Option<String>,
    /// Table the new oracle belongs to.
    table_id: usize,
    /// `log_values_per_row` recorded on the new oracle.
    log_values_per_row: usize,
}
325
326impl<'a, F: TowerField> Builder<'a, F> {
327 pub fn transparent(self, poly: impl MultivariatePoly<F> + 'static) -> Result<OracleId, Error> {
328 if poly.binary_tower_level() > F::TOWER_LEVEL {
329 bail!(Error::TowerLevelTooHigh {
330 tower_level: poly.binary_tower_level(),
331 });
332 }
333
334 let inner = TransparentPolyOracle::new(Arc::new(poly))?;
335
336 let oracle = |id: OracleId| SymbolicMultilinearOracle {
337 id,
338 table_id: self.table_id,
339 log_values_per_row: self.log_values_per_row,
340 tower_level: inner.poly().binary_tower_level(),
341 name: self.name,
342 variant: SymbolicMultilinearPolyVariant::Transparent(inner),
343 };
344
345 Ok(self.mut_ref.add_to_set(oracle))
346 }
347
348 pub fn structured(self, expr: ArithCircuit<F>) -> Result<OracleId, Error> {
349 if expr.binary_tower_level() > F::TOWER_LEVEL {
350 bail!(Error::TowerLevelTooHigh {
351 tower_level: expr.binary_tower_level(),
352 });
353 }
354
355 let oracle = |id: OracleId| SymbolicMultilinearOracle {
356 id,
357 table_id: self.table_id,
358 log_values_per_row: self.log_values_per_row,
359 tower_level: expr.binary_tower_level(),
360 name: self.name,
361 variant: SymbolicMultilinearPolyVariant::Structured(expr),
362 };
363
364 Ok(self.mut_ref.add_to_set(oracle))
365 }
366
367 pub fn committed(mut self, tower_level: usize) -> OracleId {
368 let name = self.name.take();
369 self.add_committed_with_name(tower_level, name)
370 }
371
372 pub fn committed_multiple<const N: usize>(mut self, tower_level: usize) -> [OracleId; N] {
373 match &self.name.take() {
374 None => [0; N].map(|_| self.add_committed_with_name(tower_level, None)),
375 Some(s) => {
376 let x: [usize; N] = array::from_fn(|i| i);
377 x.map(|i| self.add_committed_with_name(tower_level, Some(format!("{s}_{i}"))))
378 }
379 }
380 }
381
382 pub fn repeating(self, inner_id: OracleId) -> Result<OracleId, Error> {
383 let inner = &self.mut_ref[inner_id];
384
385 let tower_level = inner.tower_level;
386 let oracle = |id: OracleId| SymbolicMultilinearOracle {
387 id,
388 table_id: self.table_id,
389 log_values_per_row: self.log_values_per_row,
390 tower_level,
391 name: self.name,
392 variant: SymbolicMultilinearPolyVariant::Repeating { id: inner_id },
393 };
394
395 Ok(self.mut_ref.add_to_set(oracle))
396 }
397
398 pub fn shifted(
399 self,
400 inner_id: OracleId,
401 offset: usize,
402 block_bits: usize,
403 variant: ShiftVariant,
404 ) -> Result<OracleId, Error> {
405 if offset == 0 || offset >= 1 << block_bits {
406 bail!(PolynomialError::InvalidShiftOffset {
407 max_shift_offset: (1 << block_bits) - 1,
408 shift_offset: offset,
409 });
410 }
411
412 let tower_level = self.mut_ref[inner_id].tower_level;
413 let oracle = |id: OracleId| SymbolicMultilinearOracle {
414 id,
415 table_id: self.table_id,
416 log_values_per_row: self.log_values_per_row,
417 tower_level,
418 name: self.name,
419 variant: SymbolicMultilinearPolyVariant::Shifted {
420 id: inner_id,
421 shift_offset: offset,
422 block_size: block_bits,
423 shift_variant: variant,
424 },
425 };
426
427 Ok(self.mut_ref.add_to_set(oracle))
428 }
429
430 pub fn packed(self, inner_id: OracleId, log_degree: usize) -> Result<OracleId, Error> {
431 let inner_tower_level = self.mut_ref[inner_id].tower_level;
432
433 let oracle = |id: OracleId| SymbolicMultilinearOracle {
434 id,
435 table_id: self.table_id,
436 log_values_per_row: self.log_values_per_row,
437 tower_level: inner_tower_level + log_degree,
438 name: self.name,
439 variant: SymbolicMultilinearPolyVariant::Packed {
440 id: inner_id,
441 log_degree,
442 },
443 };
444
445 Ok(self.mut_ref.add_to_set(oracle))
446 }
447
448 pub fn projected(
449 self,
450 inner_id: OracleId,
451 values: Vec<F>,
452 start_index: usize,
453 ) -> Result<OracleId, Error> {
454 let tower_level = self.mut_ref[inner_id].tower_level;
456 let oracle = |id: OracleId| SymbolicMultilinearOracle {
457 id,
458 table_id: self.table_id,
459 log_values_per_row: self.log_values_per_row,
460 tower_level,
461 name: self.name,
462 variant: SymbolicMultilinearPolyVariant::Projected {
463 id: inner_id,
464 values,
465 variant: ProjectionVariant::Offset(start_index),
466 },
467 };
468
469 Ok(self.mut_ref.add_to_set(oracle))
470 }
471
472 pub fn projected_last_vars(
473 self,
474 inner_id: OracleId,
475 values: Vec<F>,
476 ) -> Result<OracleId, Error> {
477 let tower_level = self.mut_ref[inner_id].tower_level;
479 let oracle = |id: OracleId| SymbolicMultilinearOracle {
480 id,
481 table_id: self.table_id,
482 log_values_per_row: self.log_values_per_row,
483 tower_level,
484 name: self.name,
485 variant: SymbolicMultilinearPolyVariant::Projected {
486 id: inner_id,
487 values,
488 variant: ProjectionVariant::Last,
489 },
490 };
491
492 Ok(self.mut_ref.add_to_set(oracle))
493 }
494
495 pub fn linear_combination(
496 self,
497 inner: impl IntoIterator<Item = (OracleId, F)>,
498 ) -> Result<OracleId, Error> {
499 self.linear_combination_with_offset(F::ZERO, inner)
500 }
501
502 pub fn linear_combination_with_offset(
503 self,
504 offset: F,
505 inner: impl IntoIterator<Item = (OracleId, F)>,
506 ) -> Result<OracleId, Error> {
507 let inner = inner.into_iter().collect::<Vec<_>>();
508 let tower_level = inner
509 .iter()
510 .map(|(oracle_id, coeff)| {
511 self.mut_ref[*oracle_id]
512 .tower_level
513 .max(coeff.min_tower_level())
514 })
515 .max()
516 .unwrap_or(0)
517 .max(offset.min_tower_level());
518
519 let oracle = |id: OracleId| SymbolicMultilinearOracle {
520 id,
521 table_id: self.table_id,
522 log_values_per_row: self.log_values_per_row,
523 tower_level,
524 name: self.name,
525 variant: SymbolicMultilinearPolyVariant::LinearCombination { offset, inner },
526 };
527
528 Ok(self.mut_ref.add_to_set(oracle))
529 }
530
531 pub fn composite_mle(
532 self,
533 inner: impl IntoIterator<Item = OracleId>,
534 comp: ArithCircuit<F>,
535 ) -> Result<OracleId, Error> {
536 let inner = inner.into_iter().collect::<Vec<_>>();
537 let tower_level = inner
538 .iter()
539 .map(|oracle_id| self.mut_ref[*oracle_id].tower_level)
540 .max()
541 .unwrap_or(0);
542
543 let oracle = |id: OracleId| SymbolicMultilinearOracle {
544 id,
545 table_id: self.table_id,
546 log_values_per_row: self.log_values_per_row,
547 tower_level,
548 name: self.name,
549 variant: SymbolicMultilinearPolyVariant::Composite {
550 inner,
551 circuit: comp,
552 },
553 };
554
555 Ok(self.mut_ref.add_to_set(oracle))
556 }
557
558 pub fn zero_padded(
559 self,
560 inner_id: OracleId,
561 n_pad_vars: usize,
562 nonzero_index: usize,
563 start_index: usize,
564 ) -> Result<OracleId, Error> {
565 let inner = &self.mut_ref[inner_id];
566 let tower_level = inner.tower_level;
567 let oracle = |id: OracleId| SymbolicMultilinearOracle {
568 id,
569 table_id: self.table_id,
570 log_values_per_row: self.log_values_per_row,
571 tower_level,
572 name: self.name,
573 variant: SymbolicMultilinearPolyVariant::ZeroPadded {
574 id: inner_id,
575 n_pad_vars,
576 nonzero_index,
577 start_index,
578 },
579 };
580
581 Ok(self.mut_ref.add_to_set(oracle))
582 }
583
584 fn add_committed_with_name(&mut self, tower_level: usize, name: Option<String>) -> OracleId {
585 let oracle = |oracle_id: OracleId| SymbolicMultilinearOracle {
586 id: oracle_id,
587 table_id: self.table_id,
588 log_values_per_row: self.log_values_per_row,
589 tower_level,
590 name: name.clone(),
591 variant: SymbolicMultilinearPolyVariant::Committed,
592 };
593
594 self.mut_ref.add_to_set(oracle)
595 }
596}