binius_math/ntt/
domain_context.rs1use binius_field::BinaryField;
6
7use super::DomainContext;
8use crate::{BinarySubspace, binary_subspace::BinarySubspaceIterator};
9
10fn generate_evals_from_subspace<F: BinaryField>(subspace: &BinarySubspace<F>) -> Vec<Vec<F>> {
20 let l = subspace.dim();
21 let mut evals = Vec::with_capacity(l);
22
23 evals.push(subspace.basis().to_vec());
25 for i in 1..l {
26 evals.push(Vec::with_capacity(l - i));
28 for k in 1..evals[i - 1].len() {
29 let val = evals[i - 1][k] * (evals[i - 1][k] + evals[i - 1][0]);
33 evals[i].push(val);
34 }
35 }
36
37 for evals_i in evals.iter_mut() {
39 let w_i_b_i_inverse = evals_i[0].invert().unwrap();
40 for eval_i_j in evals_i.iter_mut() {
41 *eval_i_j *= w_i_b_i_inverse;
42 }
43 }
44
45 evals
46}
47
48#[derive(Debug)]
55pub struct GenericOnTheFly<F> {
56 evals: Vec<Vec<F>>,
64}
65
66impl<F: BinaryField> GenericOnTheFly<F> {
67 pub fn generate_from_subspace(subspace: &BinarySubspace<F>) -> Self {
71 Self {
72 evals: generate_evals_from_subspace(subspace),
73 }
74 }
75}
76
77impl<F: BinaryField> DomainContext for GenericOnTheFly<F> {
78 type Field = F;
79
80 fn log_domain_size(&self) -> usize {
81 self.evals.len()
82 }
83
84 fn subspace(&self, i: usize) -> BinarySubspace<F> {
85 if i == 0 {
86 return BinarySubspace::with_dim(0).unwrap();
87 }
88 BinarySubspace::new_unchecked(self.evals[self.log_domain_size() - i].clone())
89 }
90
91 fn twiddle(&self, layer: usize, block: usize) -> F {
92 let v = &self.evals[self.log_domain_size() - layer - 1];
93 BinarySubspace::new_unchecked(&v[1..]).get(block)
94 }
95}
96
97#[derive(Debug)]
99pub struct GenericPreExpanded<F> {
100 evals: Vec<Vec<F>>,
102 expanded: Vec<Vec<F>>,
108}
109
110impl<F: BinaryField> GenericPreExpanded<F> {
111 pub fn generate_from_subspace(subspace: &BinarySubspace<F>) -> Self {
115 let evals = generate_evals_from_subspace(subspace);
116
117 let mut expanded = Vec::with_capacity(evals.len());
118 for basis in evals.iter().rev() {
119 let mut expanded_i = Vec::with_capacity(1 << (basis.len() - 1));
120 expanded_i.push(F::ZERO);
121 for i in 1..basis.len() {
122 for j in 0..expanded_i.len() {
123 expanded_i.push(expanded_i[j] + basis[i]);
124 }
125 }
126 assert_eq!(expanded_i.len(), 1 << (basis.len() - 1));
127 expanded.push(expanded_i)
128 }
129 assert_eq!(expanded.len(), evals.len());
130
131 Self { evals, expanded }
132 }
133}
134
135impl<F: BinaryField> DomainContext for GenericPreExpanded<F> {
136 type Field = F;
137
138 fn log_domain_size(&self) -> usize {
139 self.evals.len()
140 }
141
142 fn subspace(&self, i: usize) -> BinarySubspace<F> {
143 if i == 0 {
144 return BinarySubspace::with_dim(0).unwrap();
145 }
146 BinarySubspace::new_unchecked(self.evals[self.log_domain_size() - i].clone())
147 }
148
149 fn twiddle(&self, layer: usize, block: usize) -> F {
150 self.expanded[layer][block]
151 }
152
153 fn iter_twiddles(&self, layer: usize, log_step_by: usize) -> impl Iterator<Item = F> {
154 self.expanded[layer]
155 .iter()
156 .step_by(1 << log_step_by)
157 .copied()
158 }
159}
160
/// Binary fields that expose a distinguished element of absolute trace one.
///
/// `gao_mateer_basis` seeds its construction with this element. Its final
/// `assert_eq!(basis[0], F::ONE)` holds exactly when the element has trace one, because for
/// `N = F::N_BITS` a power of two, `N - 1` applications of `x ↦ x² + x` compute the trace.
pub trait TraceOneElement {
    /// Returns an element `t` of the field with `Tr(t) = 1`.
    fn trace_one_element() -> Self;
}

// NOTE(review): this presumably relies on bit 121 encoding the monomial x^121 in the GHASH
// polynomial basis (modulus x^128 + x^7 + x^2 + x + 1). For that modulus, Newton's identities
// give Tr(x^k) = 0 for 1 <= k <= 120 and Tr(x^121) = 1. Confirm against the field's actual
// bit layout. At runtime, the `β_0 == ONE` assertion in `gao_mateer_basis` is what checks it.
impl TraceOneElement for binius_field::BinaryField128bGhash {
    fn trace_one_element() -> Self {
        Self::new(1 << 121)
    }
}
172
173fn gao_mateer_basis<F: BinaryField + TraceOneElement>(num_basis_elements: usize) -> Vec<F> {
181 assert!(F::N_BITS.is_power_of_two());
182
183 let mut beta = F::trace_one_element();
186
187 for _i in 0..(F::N_BITS - num_basis_elements) {
189 beta = beta.square() + beta;
190 }
191
192 let mut basis = vec![F::ZERO; num_basis_elements];
194 basis[num_basis_elements - 1] = beta;
195 for i in (1..num_basis_elements).rev() {
196 basis[i - 1] = basis[i].square() + basis[i];
197 }
198
199 assert_eq!(basis[0], F::ONE);
202
203 basis
204}
205
206#[derive(Debug)]
234pub struct GaoMateerOnTheFly<F> {
235 basis: Vec<F>,
237}
238
239impl<F: BinaryField + TraceOneElement> GaoMateerOnTheFly<F> {
240 pub fn generate(log_domain_size: usize) -> Self {
250 Self {
251 basis: gao_mateer_basis(log_domain_size),
252 }
253 }
254}
255
256impl<F: BinaryField> DomainContext for GaoMateerOnTheFly<F> {
257 type Field = F;
258
259 fn log_domain_size(&self) -> usize {
260 self.basis.len()
261 }
262
263 fn subspace(&self, i: usize) -> BinarySubspace<F> {
264 BinarySubspace::new_unchecked(self.basis[..i].to_vec())
265 }
266
267 fn twiddle(&self, layer: usize, block: usize) -> F {
268 BinarySubspace::new_unchecked(&self.basis[1..=layer]).get(block)
269 }
270
271 fn iter_twiddles(&self, layer: usize, log_step_by: usize) -> impl Iterator<Item = F> + '_ {
272 BinarySubspaceIterator::new(&self.basis[1 + log_step_by..=layer])
273 }
274}
275
276#[derive(Debug)]
281pub struct GaoMateerPreExpanded<F> {
282 basis: Vec<F>,
284 expanded: Vec<F>,
289}
290
291impl<F: BinaryField + TraceOneElement> GaoMateerPreExpanded<F> {
292 pub fn generate(log_domain_size: usize) -> Self {
302 let basis: Vec<F> = gao_mateer_basis(log_domain_size);
303
304 let mut expanded = Vec::with_capacity(1 << log_domain_size);
305 expanded.push(F::ZERO);
306 for i in 1..log_domain_size {
307 for j in 0..expanded.len() {
308 expanded.push(expanded[j] + basis[i]);
309 }
310 }
311 assert_eq!(expanded.len(), 1usize << (log_domain_size - 1));
312
313 Self { basis, expanded }
314 }
315}
316
317impl<F: BinaryField> DomainContext for GaoMateerPreExpanded<F> {
318 type Field = F;
319
320 fn log_domain_size(&self) -> usize {
321 self.basis.len()
322 }
323
324 fn subspace(&self, i: usize) -> BinarySubspace<F> {
325 BinarySubspace::new_unchecked(self.basis[..i].to_vec())
326 }
327
328 fn twiddle(&self, _layer: usize, block: usize) -> F {
329 self.expanded[block]
330 }
331
332 fn iter_twiddles(&self, layer: usize, log_step_by: usize) -> impl Iterator<Item = F> {
333 self.expanded[..1 << layer]
334 .iter()
335 .step_by(1 << log_step_by)
336 .copied()
337 }
338}
339
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::B128;

    /// Asserts that two domain contexts describe the same domain. This means equal log sizes,
    /// equal subspaces at every dimension, and equal twiddles for every `(layer, block)`.
    fn test_equivalence<F: BinaryField>(
        dc_1: &impl DomainContext<Field = F>,
        dc_2: &impl DomainContext<Field = F>,
        log_domain_size: usize,
    ) {
        assert_eq!(dc_1.log_domain_size(), log_domain_size);
        assert_eq!(dc_2.log_domain_size(), log_domain_size);

        for i in 0..log_domain_size {
            assert_eq!(dc_1.subspace(i), dc_2.subspace(i));

            for block in 0..1 << i {
                assert_eq!(dc_1.twiddle(i, block), dc_2.twiddle(i, block));
            }
        }
        assert_eq!(dc_1.subspace(log_domain_size), dc_2.subspace(log_domain_size))
    }

    /// Checks `iter_twiddles(layer, log_step_by)` against every `2^log_step_by`-th value of
    /// the same context's `twiddle`. Covers every layer and every stride up to the layer size.
    fn check_iter_twiddles<F: BinaryField>(dc: &impl DomainContext<Field = F>, name: &str) {
        for layer in 0..dc.log_domain_size() {
            let all: Vec<F> = (0..1 << layer)
                .map(|block| dc.twiddle(layer, block))
                .collect();
            for log_step_by in 0..=layer {
                let expected: Vec<F> = all.iter().step_by(1 << log_step_by).copied().collect();
                assert_eq!(
                    dc.iter_twiddles(layer, log_step_by).collect::<Vec<_>>(),
                    expected,
                    "{name} iter_twiddles mismatch at layer {layer}, log_step_by {log_step_by}"
                );
            }
        }
    }

    #[test]
    fn test_generic() {
        const LOG_SIZE: usize = 5;

        let subspace = BinarySubspace::with_dim(LOG_SIZE).unwrap();

        let dc_otf = GenericOnTheFly::<B128>::generate_from_subspace(&subspace);
        let dc_pre = GenericPreExpanded::<B128>::generate_from_subspace(&subspace);

        test_equivalence(&dc_otf, &dc_pre, LOG_SIZE);
    }

    #[test]
    fn test_gao_mateer() {
        const LOG_SIZE: usize = 5;

        let dc_gm_otf = GaoMateerOnTheFly::<B128>::generate(LOG_SIZE);
        let dc_gm_pre = GaoMateerPreExpanded::<B128>::generate(LOG_SIZE);
        let dc_generic_otf =
            GenericOnTheFly::<B128>::generate_from_subspace(&dc_gm_otf.subspace(LOG_SIZE));

        test_equivalence(&dc_gm_otf, &dc_gm_pre, LOG_SIZE);
        test_equivalence(&dc_gm_otf, &dc_generic_otf, LOG_SIZE);
    }

    #[test]
    fn test_iter_layer() {
        const LOG_SIZE: usize = 7;

        let dc_gm_otf = GaoMateerOnTheFly::<B128>::generate(LOG_SIZE);
        let dc_gm_pre = GaoMateerPreExpanded::<B128>::generate(LOG_SIZE);
        let subspace = BinarySubspace::with_dim(LOG_SIZE).unwrap();
        let dc_generic_otf = GenericOnTheFly::<B128>::generate_from_subspace(&subspace);
        let dc_generic_pre = GenericPreExpanded::<B128>::generate_from_subspace(&subspace);

        // Cross-check the unstepped twiddle sequences between implementations...
        for layer in 0..LOG_SIZE {
            let expected: Vec<_> = (0..1 << layer)
                .map(|block| dc_gm_pre.twiddle(layer, block))
                .collect();

            assert_eq!(
                dc_gm_otf.iter_twiddles(layer, 0).collect::<Vec<_>>(),
                expected,
                "GaoMateerOnTheFly iter_layer mismatch at layer {}",
                layer
            );
            assert_eq!(
                dc_gm_pre.iter_twiddles(layer, 0).collect::<Vec<_>>(),
                expected,
                "GaoMateerPreExpanded iter_layer mismatch at layer {}",
                layer
            );
            assert_eq!(
                dc_generic_otf.iter_twiddles(layer, 0).collect::<Vec<_>>(),
                dc_generic_pre.iter_twiddles(layer, 0).collect::<Vec<_>>(),
                "Generic iter_layer mismatch at layer {}",
                layer
            );
        }

        // ...and check every implementation's strided iteration against its own `twiddle`.
        // Previously only `log_step_by = 0` was exercised.
        check_iter_twiddles(&dc_gm_otf, "GaoMateerOnTheFly");
        check_iter_twiddles(&dc_gm_pre, "GaoMateerPreExpanded");
        check_iter_twiddles(&dc_generic_otf, "GenericOnTheFly");
        check_iter_twiddles(&dc_generic_pre, "GenericPreExpanded");
    }
}