binius_core/protocols/
test_utils.rs1use std::ops::Deref;
4
5use binius_field::{ExtensionField, Field, PackedField};
6use binius_math::{ArithExpr, CompositionPoly, MLEEmbeddingAdapter, MultilinearExtension};
7use rand::Rng;
8
9use crate::polynomial::Error as PolynomialError;
10
/// Test composition that wraps another composition and offsets its output by the
/// constant one.
#[derive(Debug, Clone)]
pub struct AddOneComposition<C> {
    /// The composition whose result gets shifted by one.
    inner: C,
}

impl<C> AddOneComposition<C> {
    /// Wraps `inner`; the wrapper is usable in `const` contexts.
    pub const fn new(inner: C) -> Self {
        Self { inner }
    }
}
21
22impl<P, Inner> CompositionPoly<P> for AddOneComposition<Inner>
23where
24 P: PackedField,
25 Inner: CompositionPoly<P>,
26{
27 fn n_vars(&self) -> usize {
28 self.inner.n_vars()
29 }
30
31 fn degree(&self) -> usize {
32 self.inner.degree()
33 }
34
35 fn binary_tower_level(&self) -> usize {
36 self.inner.binary_tower_level()
37 }
38
39 fn expression(&self) -> ArithExpr<P::Scalar> {
40 self.inner.expression() + ArithExpr::one()
41 }
42
43 fn evaluate(&self, query: &[P]) -> Result<P, binius_math::Error> {
44 Ok(self.inner.evaluate(query)? + P::one())
45 }
46}
47
/// Test composition equal to the product of all of its `arity` input variables.
#[derive(Debug, Clone)]
pub struct TestProductComposition {
    /// Number of variables, which is also the number of factors in the product.
    arity: usize,
}

impl TestProductComposition {
    /// Builds a product composition over `arity` variables; usable in `const` contexts.
    pub const fn new(arity: usize) -> Self {
        Self { arity }
    }
}
58
59impl<P> CompositionPoly<P> for TestProductComposition
60where
61 P: PackedField,
62{
63 fn n_vars(&self) -> usize {
64 self.arity
65 }
66
67 fn degree(&self) -> usize {
68 self.arity
69 }
70
71 fn binary_tower_level(&self) -> usize {
72 0
73 }
74
75 fn expression(&self) -> ArithExpr<P::Scalar> {
76 (0..self.arity).map(ArithExpr::Var).product()
77 }
78
79 fn evaluate(&self, query: &[P]) -> Result<P, binius_math::Error> {
80 let n_vars = self.arity;
81 assert_eq!(query.len(), n_vars);
82 Ok(query.iter().copied().product())
84 }
85}
86
87pub fn generate_zero_product_multilinears<P, PE>(
88 mut rng: impl Rng,
89 n_vars: usize,
90 n_multilinears: usize,
91) -> Vec<MLEEmbeddingAdapter<P, PE>>
92where
93 P: PackedField,
94 PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
95{
96 (0..n_multilinears)
97 .map(|j| {
98 let values = (0..(1 << n_vars.saturating_sub(P::LOG_WIDTH)))
99 .map(|i| {
103 let mut packed = P::random(&mut rng);
104 for k in 0..P::WIDTH {
105 if (k + i * P::WIDTH) % n_multilinears == j {
106 packed.set(k, P::Scalar::ZERO);
107 }
108 }
109 if n_vars < P::LOG_WIDTH {
110 for k in (1 << n_vars)..P::WIDTH {
111 packed.set(k, P::Scalar::ZERO);
112 }
113 }
114 packed
115 })
116 .collect();
117 MultilinearExtension::new(n_vars, values)
118 .unwrap()
119 .specialize::<PE>()
120 })
121 .collect()
122}
123
124pub fn transform_poly<F, OF, Data>(
125 multilin: &MultilinearExtension<F, Data>,
126) -> Result<MultilinearExtension<OF>, PolynomialError>
127where
128 F: Field,
129 OF: Field + From<F> + Into<F>,
130 Data: Deref<Target = [F]>,
131{
132 let values = multilin.evals().iter().copied().map(OF::from).collect();
133
134 Ok(MultilinearExtension::from_values(values)?)
135}