binius_core/oracle/
symbolic.rs

1// Copyright 2025 Irreducible Inc.
2
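//! A symbolic representation of the multilinear oracle set. The main difference from the concrete
//! `MultilinearOracleSet` is that it is sizeless, i.e. it carries no information about `n_vars`;
//! concrete sizes are derived from per-table row counts when the set is instantiated.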
5
6use std::{array, sync::Arc};
7
8use binius_field::{BinaryField128b, TowerField};
9use binius_macros::{DeserializeBytes, SerializeBytes};
10use binius_math::ArithCircuit;
11use binius_utils::{
12	DeserializeBytes, SerializationError, SerializationMode, bail,
13	checked_arithmetics::log2_ceil_usize,
14};
15
16use super::{
17	CompositeMLE, LinearCombination, MultilinearOracleSet, MultilinearPolyOracle,
18	MultilinearPolyVariant, Packed, Projected, ShiftVariant, Shifted, TransparentPolyOracle,
19	ZeroPadded,
20};
21use crate::{
22	constraint_system::TableId,
23	oracle::{Error, OracleId},
24	polynomial::{Error as PolynomialError, MultivariatePoly},
25};
26
/// A single oracle in a [`SymbolicMultilinearOracleSet`].
///
/// Unlike the concrete `MultilinearPolyOracle`, it does not store `n_vars`. Instead it records the
/// table it belongs to and how many values it holds per table row. `n_vars` is computed from these
/// when the set is instantiated; transparent oracles are the exception and take `n_vars` from
/// their polynomial.
///
/// The `eval_generics` attribute instantiates the derived `DeserializeBytes` impl at
/// `F = BinaryField128b`.
#[derive(Debug, Clone, PartialEq, Eq, SerializeBytes, DeserializeBytes)]
#[deserialize_bytes(eval_generics(F = BinaryField128b))]
pub struct SymbolicMultilinearOracle<F: TowerField> {
	/// Id of this oracle within its set (its insertion position).
	pub id: OracleId,
	/// Optional human-readable label.
	pub name: Option<String>,
	/// Table whose row count determines this oracle's size.
	pub table_id: TableId,
	/// Base-2 logarithm of the number of values this oracle holds per table row.
	pub log_values_per_row: usize,
	/// Binary tower level of the field the oracle's values live in.
	pub tower_level: usize,
	/// How the oracle is defined.
	pub variant: SymbolicMultilinearPolyVariant<F>,
}
37
/// Selects which variables of the inner oracle a symbolic projection applies to.
#[derive(Debug, Copy, Clone, PartialEq, Eq, SerializeBytes, DeserializeBytes)]
pub enum ProjectionVariant {
	/// Project values starting at the given index (used directly as the projection's
	/// `start_index`).
	Offset(usize),
	/// Add projection to the last variables. The concrete start index depends on `n_vars`, so it
	/// is resolved only when the oracle set is instantiated.
	Last,
}
45
/// The definition of a [`SymbolicMultilinearOracle`].
///
/// Each variant mirrors a concrete `MultilinearPolyVariant`, minus any data that depends on
/// `n_vars`; that data is computed when the oracle set is instantiated.
///
/// `DeserializeBytes` is implemented manually for `SymbolicMultilinearPolyVariant<BinaryField128b>`
/// and expects the variants' tags in declaration order, so the variant order here must not change.
#[derive(Debug, Clone, PartialEq, Eq, SerializeBytes)]
pub enum SymbolicMultilinearPolyVariant<F: TowerField> {
	/// A committed oracle.
	Committed,
	/// A transparent polynomial. Its `n_vars` comes from the wrapped polynomial rather than from
	/// the table size.
	Transparent(TransparentPolyOracle<F>),
	/// A structured virtual polynomial is one that can be evaluated succinctly by a verifier.
	///
	/// These are referred to as "MLE-structured" tables in [Lasso]. The evaluation algorithm is
	/// expressed as an arithmetic circuit, of polynomial size in the number of variables.
	///
	/// [Lasso]: <https://eprint.iacr.org/2023/1216>
	Structured(ArithCircuit<F>),
	/// Repeats oracle `id`. The repetition count (`log_count`) is derived at instantiation as this
	/// oracle's `n_vars` minus the inner oracle's `n_vars`.
	Repeating {
		id: OracleId,
	},
	/// Projection of oracle `id` onto `values`, at the position described by `variant`.
	Projected {
		id: OracleId,
		values: Vec<F>,
		variant: ProjectionVariant,
	},
	/// Shift of oracle `id` by `shift_offset`.
	Shifted {
		id: OracleId,
		shift_offset: usize,
		/// Log2 of the shift block length; the builder fills this from its `block_bits` argument.
		block_size: usize,
		shift_variant: ShiftVariant,
	},
	/// Packing of oracle `id` into a larger tower field.
	Packed {
		id: OracleId,
		/// The number of tower levels increased by the packing operation.
		///
		/// This is the base 2 logarithm of the field extension, and is called $\kappa$ in [DP23],
		/// Section 4.3.
		///
		/// [DP23]: https://eprint.iacr.org/2023/1784
		log_degree: usize,
	},
	/// Linear combination of the `(oracle, coefficient)` pairs in `inner`, plus the constant
	/// `offset`.
	LinearCombination {
		offset: F,
		inner: Vec<(OracleId, F)>,
	},
	/// Zero-padded version of oracle `id` with `n_pad_vars` extra variables. `nonzero_index` and
	/// `start_index` are forwarded unchanged to the concrete `ZeroPadded`.
	ZeroPadded {
		id: OracleId,
		n_pad_vars: usize,
		nonzero_index: usize,
		start_index: usize,
	},
	/// Composition of the oracles in `inner` through the arithmetic circuit `circuit`.
	Composite {
		inner: Vec<OracleId>,
		circuit: ArithCircuit<F>,
	},
}
96
97impl DeserializeBytes for SymbolicMultilinearPolyVariant<BinaryField128b> {
98	fn deserialize(
99		mut buf: impl bytes::Buf,
100		mode: SerializationMode,
101	) -> Result<Self, SerializationError>
102	where
103		Self: Sized,
104	{
105		Ok(match u8::deserialize(&mut buf, mode)? {
106			0 => Self::Committed,
107			1 => Self::Transparent(DeserializeBytes::deserialize(buf, mode)?),
108			2 => Self::Structured(DeserializeBytes::deserialize(buf, mode)?),
109			3 => Self::Repeating {
110				id: DeserializeBytes::deserialize(&mut buf, mode)?,
111			},
112			4 => Self::Projected {
113				id: DeserializeBytes::deserialize(&mut buf, mode)?,
114				values: DeserializeBytes::deserialize(&mut buf, mode)?,
115				variant: DeserializeBytes::deserialize(buf, mode)?,
116			},
117			5 => Self::Shifted {
118				id: DeserializeBytes::deserialize(&mut buf, mode)?,
119				shift_offset: DeserializeBytes::deserialize(&mut buf, mode)?,
120				block_size: DeserializeBytes::deserialize(&mut buf, mode)?,
121				shift_variant: DeserializeBytes::deserialize(buf, mode)?,
122			},
123			6 => Self::Packed {
124				id: DeserializeBytes::deserialize(&mut buf, mode)?,
125				log_degree: DeserializeBytes::deserialize(buf, mode)?,
126			},
127			7 => Self::LinearCombination {
128				offset: DeserializeBytes::deserialize(&mut buf, mode)?,
129				inner: DeserializeBytes::deserialize(buf, mode)?,
130			},
131			8 => Self::ZeroPadded {
132				id: DeserializeBytes::deserialize(&mut buf, mode)?,
133				n_pad_vars: DeserializeBytes::deserialize(&mut buf, mode)?,
134				nonzero_index: DeserializeBytes::deserialize(&mut buf, mode)?,
135				start_index: DeserializeBytes::deserialize(buf, mode)?,
136			},
137			9 => Self::Composite {
138				inner: DeserializeBytes::deserialize(&mut buf, mode)?,
139				circuit: DeserializeBytes::deserialize(buf, mode)?,
140			},
141			variant_index => {
142				return Err(SerializationError::UnknownEnumVariant {
143					name: "SymbolicMultilinearPolyVariant",
144					index: variant_index,
145				});
146			}
147		})
148	}
149}
150
/// A sizeless set of multilinear oracles, indexed by [`OracleId`].
///
/// Oracles are appended through the builder returned by `add_oracle`. `instantiate` turns the set
/// into a concrete `MultilinearOracleSet` once the table sizes are known.
#[derive(Default, Debug, Clone, SerializeBytes, DeserializeBytes)]
#[deserialize_bytes(eval_generics(F = BinaryField128b))]
pub struct SymbolicMultilinearOracleSet<F: TowerField> {
	// Oracles in insertion order; an oracle's id is its index in this vector.
	oracles: Vec<SymbolicMultilinearOracle<F>>,
}
156
157impl<F: TowerField> SymbolicMultilinearOracleSet<F> {
158	pub fn new() -> Self {
159		Self {
160			oracles: Vec::new(),
161		}
162	}
163
164	/// Instantiate the given symbolic multilinear oracle set into the concrete one.
165	pub fn instantiate(&self, table_sizes: &[usize]) -> Result<MultilinearOracleSet<F>, Error> {
166		let mut mos = MultilinearOracleSet::new();
167		for oracle in &self.oracles {
168			let table_size = table_sizes
169				.get(oracle.table_id)
170				.ok_or(Error::TableSizeMissing {
171					table_id: oracle.table_id,
172				})?;
173			if *table_size == 0 {
174				mos.add_skip();
175				continue;
176			}
177			let log_capacity = log2_ceil_usize(*table_size);
178			// Come up with the n_vars for the multilinear oracle being created. This is typically
179			// `log_capacity + log_values_per_row`. However, there is an exception for `Transparent`
180			// which has a n_vars for just a single row.
181			let n_vars = match oracle.variant {
182				SymbolicMultilinearPolyVariant::Transparent(ref transparent_poly_oracle) => {
183					transparent_poly_oracle.poly().n_vars()
184				}
185				_ => log_capacity + oracle.log_values_per_row,
186			};
187			let tower_level = oracle.tower_level;
188			let variant = instantiate_oracle_variant(&mos, oracle, n_vars)?;
189			mos.add_to_set(|id: OracleId| MultilinearPolyOracle {
190				id,
191				name: oracle.name.clone(),
192				n_vars,
193				tower_level,
194				variant,
195			});
196		}
197		Ok(mos)
198	}
199
200	/// Adds a new oracle to the set.
201	pub fn add_oracle<S: ToString>(
202		&mut self,
203		table_id: usize,
204		log_values_per_row: usize,
205		s: S,
206	) -> Builder<'_, F> {
207		Builder {
208			mut_ref: self,
209			name: Some(s.to_string()),
210			table_id,
211			log_values_per_row,
212		}
213	}
214
215	fn add_to_set(
216		&mut self,
217		oracle: impl FnOnce(OracleId) -> SymbolicMultilinearOracle<F>,
218	) -> OracleId {
219		let id = OracleId::from_index(self.oracles.len());
220		self.oracles.push(oracle(id));
221		id
222	}
223
224	pub fn size(&self) -> usize {
225		self.oracles.len()
226	}
227
228	pub fn polys(&self) -> impl Iterator<Item = &SymbolicMultilinearOracle<F>> + '_ {
229		(0..self.oracles.len()).map(|index| &self[OracleId::from_index(index)])
230	}
231
232	pub fn ids(&self) -> impl Iterator<Item = OracleId> {
233		(0..self.oracles.len()).map(OracleId::from_index)
234	}
235
236	pub fn iter(&self) -> impl Iterator<Item = (OracleId, &SymbolicMultilinearOracle<F>)> + '_ {
237		(0..self.oracles.len()).map(|index| {
238			let oracle_id = OracleId::from_index(index);
239			(oracle_id, &self[oracle_id])
240		})
241	}
242
243	pub fn label(&self, oracle_id: OracleId) -> Option<String> {
244		self[oracle_id].name.clone()
245	}
246}
247
248fn instantiate_oracle_variant<F: TowerField>(
249	mos: &MultilinearOracleSet<F>,
250	oracle: &SymbolicMultilinearOracle<F>,
251	n_vars: usize,
252) -> Result<MultilinearPolyVariant<F>, Error> {
253	use self::{MultilinearPolyVariant as Sized, SymbolicMultilinearPolyVariant as Symbolic};
254
255	let variant = match &oracle.variant {
256		Symbolic::Committed => MultilinearPolyVariant::Committed,
257		Symbolic::Transparent(transparent_poly_oracle) => {
258			Sized::Transparent(transparent_poly_oracle.clone())
259		}
260		Symbolic::Structured(arith_circuit) => Sized::Structured(arith_circuit.clone()),
261		Symbolic::Repeating { id } => {
262			let log_count = n_vars - mos.n_vars(*id);
263			Sized::Repeating { id: *id, log_count }
264		}
265		Symbolic::Projected {
266			id,
267			values,
268			variant,
269		} => {
270			let start_index = match variant {
271				ProjectionVariant::Offset(offset) => *offset,
272				ProjectionVariant::Last => n_vars - values.len(),
273			};
274			let projected = Projected::new(mos, *id, values.clone(), start_index)?;
275			Sized::Projected(projected)
276		}
277		Symbolic::Shifted {
278			id,
279			shift_offset,
280			block_size,
281			shift_variant,
282		} => {
283			let shifted = Shifted::new(mos, *id, *shift_offset, *block_size, *shift_variant)?;
284			MultilinearPolyVariant::Shifted(shifted)
285		}
286		Symbolic::Packed { id, log_degree } => {
287			let packed = Packed::new(*id, *log_degree);
288			MultilinearPolyVariant::Packed(packed)
289		}
290		Symbolic::LinearCombination { offset, inner } => {
291			let linear_combination = LinearCombination::new(mos, n_vars, *offset, inner.clone())?;
292			MultilinearPolyVariant::LinearCombination(linear_combination)
293		}
294		Symbolic::ZeroPadded {
295			id,
296			n_pad_vars,
297			nonzero_index,
298			start_index,
299		} => {
300			let zero_padded = ZeroPadded::new(mos, *id, *n_pad_vars, *nonzero_index, *start_index)?;
301			MultilinearPolyVariant::ZeroPadded(zero_padded)
302		}
303		Symbolic::Composite { inner, circuit } => {
304			let composite_mle = CompositeMLE::new(mos, n_vars, inner.clone(), circuit.clone())?;
305			MultilinearPolyVariant::Composite(composite_mle)
306		}
307	};
308	Ok(variant)
309}
310
311impl<F: TowerField> std::ops::Index<OracleId> for SymbolicMultilinearOracleSet<F> {
312	type Output = SymbolicMultilinearOracle<F>;
313
314	fn index(&self, id: OracleId) -> &Self::Output {
315		&self.oracles[id.index()]
316	}
317}
318
/// Builder for a single oracle, returned by `SymbolicMultilinearOracleSet::add_oracle`.
///
/// It holds the metadata shared by every oracle kind. Calling one of its methods (e.g.
/// `committed`, `shifted`) consumes the builder and appends the oracle to the set.
pub struct Builder<'a, F: TowerField> {
	// The set the new oracle is appended to.
	mut_ref: &'a mut SymbolicMultilinearOracleSet<F>,
	// Label for the new oracle.
	name: Option<String>,
	// Table the new oracle belongs to.
	table_id: usize,
	// Log2 of the number of values the new oracle holds per table row.
	log_values_per_row: usize,
}
325
326impl<'a, F: TowerField> Builder<'a, F> {
327	pub fn transparent(self, poly: impl MultivariatePoly<F> + 'static) -> Result<OracleId, Error> {
328		if poly.binary_tower_level() > F::TOWER_LEVEL {
329			bail!(Error::TowerLevelTooHigh {
330				tower_level: poly.binary_tower_level(),
331			});
332		}
333
334		let inner = TransparentPolyOracle::new(Arc::new(poly))?;
335
336		let oracle = |id: OracleId| SymbolicMultilinearOracle {
337			id,
338			table_id: self.table_id,
339			log_values_per_row: self.log_values_per_row,
340			tower_level: inner.poly().binary_tower_level(),
341			name: self.name,
342			variant: SymbolicMultilinearPolyVariant::Transparent(inner),
343		};
344
345		Ok(self.mut_ref.add_to_set(oracle))
346	}
347
348	pub fn structured(self, expr: ArithCircuit<F>) -> Result<OracleId, Error> {
349		if expr.binary_tower_level() > F::TOWER_LEVEL {
350			bail!(Error::TowerLevelTooHigh {
351				tower_level: expr.binary_tower_level(),
352			});
353		}
354
355		let oracle = |id: OracleId| SymbolicMultilinearOracle {
356			id,
357			table_id: self.table_id,
358			log_values_per_row: self.log_values_per_row,
359			tower_level: expr.binary_tower_level(),
360			name: self.name,
361			variant: SymbolicMultilinearPolyVariant::Structured(expr),
362		};
363
364		Ok(self.mut_ref.add_to_set(oracle))
365	}
366
367	pub fn committed(mut self, tower_level: usize) -> OracleId {
368		let name = self.name.take();
369		self.add_committed_with_name(tower_level, name)
370	}
371
372	pub fn committed_multiple<const N: usize>(mut self, tower_level: usize) -> [OracleId; N] {
373		match &self.name.take() {
374			None => [0; N].map(|_| self.add_committed_with_name(tower_level, None)),
375			Some(s) => {
376				let x: [usize; N] = array::from_fn(|i| i);
377				x.map(|i| self.add_committed_with_name(tower_level, Some(format!("{s}_{i}"))))
378			}
379		}
380	}
381
382	pub fn repeating(self, inner_id: OracleId) -> Result<OracleId, Error> {
383		let inner = &self.mut_ref[inner_id];
384
385		let tower_level = inner.tower_level;
386		let oracle = |id: OracleId| SymbolicMultilinearOracle {
387			id,
388			table_id: self.table_id,
389			log_values_per_row: self.log_values_per_row,
390			tower_level,
391			name: self.name,
392			variant: SymbolicMultilinearPolyVariant::Repeating { id: inner_id },
393		};
394
395		Ok(self.mut_ref.add_to_set(oracle))
396	}
397
398	pub fn shifted(
399		self,
400		inner_id: OracleId,
401		offset: usize,
402		block_bits: usize,
403		variant: ShiftVariant,
404	) -> Result<OracleId, Error> {
405		if offset == 0 || offset >= 1 << block_bits {
406			bail!(PolynomialError::InvalidShiftOffset {
407				max_shift_offset: (1 << block_bits) - 1,
408				shift_offset: offset,
409			});
410		}
411
412		let tower_level = self.mut_ref[inner_id].tower_level;
413		let oracle = |id: OracleId| SymbolicMultilinearOracle {
414			id,
415			table_id: self.table_id,
416			log_values_per_row: self.log_values_per_row,
417			tower_level,
418			name: self.name,
419			variant: SymbolicMultilinearPolyVariant::Shifted {
420				id: inner_id,
421				shift_offset: offset,
422				block_size: block_bits,
423				shift_variant: variant,
424			},
425		};
426
427		Ok(self.mut_ref.add_to_set(oracle))
428	}
429
430	pub fn packed(self, inner_id: OracleId, log_degree: usize) -> Result<OracleId, Error> {
431		let inner_tower_level = self.mut_ref[inner_id].tower_level;
432
433		let oracle = |id: OracleId| SymbolicMultilinearOracle {
434			id,
435			table_id: self.table_id,
436			log_values_per_row: self.log_values_per_row,
437			tower_level: inner_tower_level + log_degree,
438			name: self.name,
439			variant: SymbolicMultilinearPolyVariant::Packed {
440				id: inner_id,
441				log_degree,
442			},
443		};
444
445		Ok(self.mut_ref.add_to_set(oracle))
446	}
447
448	pub fn projected(
449		self,
450		inner_id: OracleId,
451		values: Vec<F>,
452		start_index: usize,
453	) -> Result<OracleId, Error> {
454		// TODO: This is wrong, should be F::TOWER_LEVEL
455		let tower_level = self.mut_ref[inner_id].tower_level;
456		let oracle = |id: OracleId| SymbolicMultilinearOracle {
457			id,
458			table_id: self.table_id,
459			log_values_per_row: self.log_values_per_row,
460			tower_level,
461			name: self.name,
462			variant: SymbolicMultilinearPolyVariant::Projected {
463				id: inner_id,
464				values,
465				variant: ProjectionVariant::Offset(start_index),
466			},
467		};
468
469		Ok(self.mut_ref.add_to_set(oracle))
470	}
471
472	pub fn projected_last_vars(
473		self,
474		inner_id: OracleId,
475		values: Vec<F>,
476	) -> Result<OracleId, Error> {
477		// TODO: This is wrong, should be F::TOWER_LEVEL
478		let tower_level = self.mut_ref[inner_id].tower_level;
479		let oracle = |id: OracleId| SymbolicMultilinearOracle {
480			id,
481			table_id: self.table_id,
482			log_values_per_row: self.log_values_per_row,
483			tower_level,
484			name: self.name,
485			variant: SymbolicMultilinearPolyVariant::Projected {
486				id: inner_id,
487				values,
488				variant: ProjectionVariant::Last,
489			},
490		};
491
492		Ok(self.mut_ref.add_to_set(oracle))
493	}
494
495	pub fn linear_combination(
496		self,
497		inner: impl IntoIterator<Item = (OracleId, F)>,
498	) -> Result<OracleId, Error> {
499		self.linear_combination_with_offset(F::ZERO, inner)
500	}
501
502	pub fn linear_combination_with_offset(
503		self,
504		offset: F,
505		inner: impl IntoIterator<Item = (OracleId, F)>,
506	) -> Result<OracleId, Error> {
507		let inner = inner.into_iter().collect::<Vec<_>>();
508		let tower_level = inner
509			.iter()
510			.map(|(oracle_id, coeff)| {
511				self.mut_ref[*oracle_id]
512					.tower_level
513					.max(coeff.min_tower_level())
514			})
515			.max()
516			.unwrap_or(0)
517			.max(offset.min_tower_level());
518
519		let oracle = |id: OracleId| SymbolicMultilinearOracle {
520			id,
521			table_id: self.table_id,
522			log_values_per_row: self.log_values_per_row,
523			tower_level,
524			name: self.name,
525			variant: SymbolicMultilinearPolyVariant::LinearCombination { offset, inner },
526		};
527
528		Ok(self.mut_ref.add_to_set(oracle))
529	}
530
531	pub fn composite_mle(
532		self,
533		inner: impl IntoIterator<Item = OracleId>,
534		comp: ArithCircuit<F>,
535	) -> Result<OracleId, Error> {
536		let inner = inner.into_iter().collect::<Vec<_>>();
537		let tower_level = inner
538			.iter()
539			.map(|oracle_id| self.mut_ref[*oracle_id].tower_level)
540			.max()
541			.unwrap_or(0);
542
543		let oracle = |id: OracleId| SymbolicMultilinearOracle {
544			id,
545			table_id: self.table_id,
546			log_values_per_row: self.log_values_per_row,
547			tower_level,
548			name: self.name,
549			variant: SymbolicMultilinearPolyVariant::Composite {
550				inner,
551				circuit: comp,
552			},
553		};
554
555		Ok(self.mut_ref.add_to_set(oracle))
556	}
557
558	pub fn zero_padded(
559		self,
560		inner_id: OracleId,
561		n_pad_vars: usize,
562		nonzero_index: usize,
563		start_index: usize,
564	) -> Result<OracleId, Error> {
565		let inner = &self.mut_ref[inner_id];
566		let tower_level = inner.tower_level;
567		let oracle = |id: OracleId| SymbolicMultilinearOracle {
568			id,
569			table_id: self.table_id,
570			log_values_per_row: self.log_values_per_row,
571			tower_level,
572			name: self.name,
573			variant: SymbolicMultilinearPolyVariant::ZeroPadded {
574				id: inner_id,
575				n_pad_vars,
576				nonzero_index,
577				start_index,
578			},
579		};
580
581		Ok(self.mut_ref.add_to_set(oracle))
582	}
583
584	fn add_committed_with_name(&mut self, tower_level: usize, name: Option<String>) -> OracleId {
585		let oracle = |oracle_id: OracleId| SymbolicMultilinearOracle {
586			id: oracle_id,
587			table_id: self.table_id,
588			log_values_per_row: self.log_values_per_row,
589			tower_level,
590			name: name.clone(),
591			variant: SymbolicMultilinearPolyVariant::Committed,
592		};
593
594		self.mut_ref.add_to_set(oracle)
595	}
596}