use std::{iter, marker::PhantomData};

use binius_field::{BinaryField, ExtensionField, Field, TowerField, util::inner_product_unchecked};
use binius_math::{ArithCircuit, TowerTop, extrapolate_line_scalar};
use binius_ntt::AdditiveNTT;
use binius_utils::checked_arithmetics::{checked_log_2, strict_log_2};
use bytemuck::zeroed_vec;
use itertools::izip;

use super::{memory::CpuMemory, tower_macro::each_tower_subfield};
use crate::{
    ComputeData, ComputeHolder, ComputeLayerExecutor, KernelExecutor,
    alloc::{BumpAllocator, ComputeAllocator, HostBumpAllocator},
    layer::{ComputeLayer, Error, FSlice, FSliceMut, KernelBuffer, KernelMemMap},
    memory::{ComputeMemory, SizedSlice, SlicesBatch, SubfieldSlice},
};

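/// A simple CPU-backed reference implementation of [`ComputeLayer`].
///
/// All "device" memory is ordinary host memory ([`CpuMemory`]), so the copy
/// operations below are plain slice copies and kernels execute serially.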
#[derive(Debug, Default)]
pub struct CpuLayer<F>(PhantomData<F>);

impl<F: TowerTop> ComputeLayer<F> for CpuLayer<F> {
    type Exec<'a> = CpuLayerExecutor<F>;
    type DevMem = CpuMemory;

    fn copy_h2d(&self, src: &[F], dst: &mut FSliceMut<'_, F, Self>) -> Result<(), Error> {
        assert_eq!(
            src.len(),
            dst.len(),
            "precondition: src and dst buffers must have the same length"
        );
        dst.copy_from_slice(src);
        Ok(())
    }

    fn copy_d2h(&self, src: FSlice<'_, F, Self>, dst: &mut [F]) -> Result<(), Error> {
        assert_eq!(
            src.len(),
            dst.len(),
            "precondition: src and dst buffers must have the same length"
        );
        dst.copy_from_slice(src);
        Ok(())
    }

    fn copy_d2d(
        &self,
        src: FSlice<'_, F, Self>,
        dst: &mut FSliceMut<'_, F, Self>,
    ) -> Result<(), Error> {
        assert_eq!(
            src.len(),
            dst.len(),
            "precondition: src and dst buffers must have the same length"
        );
        dst.copy_from_slice(src);
        Ok(())
    }

    fn execute<'a, 'b>(
        &'b self,
        f: impl FnOnce(&mut Self::Exec<'a>) -> Result<Vec<F>, Error>,
    ) -> Result<Vec<F>, Error>
    where
        'b: 'a,
    {
        f(&mut CpuLayerExecutor::<F>::default())
    }

    fn compile_expr(
        &self,
        expr: &ArithCircuit<F>,
    ) -> Result<<Self::Exec<'_> as ComputeLayerExecutor<F>>::ExprEval, Error> {
        Ok(expr.clone())
    }

    fn fill(
        &self,
        slice: &mut <Self::DevMem as ComputeMemory<F>>::FSliceMut<'_>,
        value: F,
    ) -> Result<(), Error> {
        slice.fill(value);
        Ok(())
    }
}
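
// Usage sketch (not compiled here): a host-to-device round trip through the
// CPU layer. The concrete field `BinaryField128b` is an assumption; on this
// backend `FSliceMut` is just `&mut [F]`, so ordinary `Vec`s serve as buffers.
//
//     let layer = CpuLayer::<BinaryField128b>::default();
//     let src = vec![BinaryField128b::ONE; 8];
//     let mut dev_buf = vec![BinaryField128b::ZERO; 8];
//     let mut dev_slice = dev_buf.as_mut_slice();
//     layer.copy_h2d(&src, &mut dev_slice)?;
//     let mut dst = vec![BinaryField128b::ZERO; 8];
//     layer.copy_d2h(dev_slice, &mut dst)?;
//     assert_eq!(src, dst);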

/// Serial kernel executor backing [`CpuLayer`].
#[derive(Debug)]
pub struct CpuLayerExecutor<F>(PhantomData<F>);

impl<F: TowerTop> CpuLayerExecutor<F> {
    /// Slices each memory mapping down to the `i`-th of `2^log_chunks` chunks,
    /// allocating scratch space from `local_buffer_alloc` for `Local` mappings.
    fn map_kernel_mem<'a>(
        mappings: &'a mut [MemMap<'_, Self, F>],
        local_buffer_alloc: &'a BumpAllocator<F, <Self as ComputeLayerExecutor<F>>::DevMem>,
        log_chunks: usize,
        i: usize,
    ) -> Vec<Buffer<'a, Self, F>> {
        mappings
            .iter_mut()
            .map(|mapping| match mapping {
                KernelMemMap::Chunked { data, .. } => {
                    let log_size = checked_log_2(data.len());
                    let log_chunk_size = log_size - log_chunks;
                    KernelBuffer::Ref(<Self as ComputeLayerExecutor<F>>::DevMem::slice(
                        data,
                        (i << log_chunk_size)..((i + 1) << log_chunk_size),
                    ))
                }
                KernelMemMap::ChunkedMut { data, .. } => {
                    let log_size = checked_log_2(data.len());
                    let log_chunk_size = log_size - log_chunks;
                    KernelBuffer::Mut(<Self as ComputeLayerExecutor<F>>::DevMem::slice_mut(
                        data,
                        (i << log_chunk_size)..((i + 1) << log_chunk_size),
                    ))
                }
                KernelMemMap::Local { log_size } => {
                    let log_chunk_size = *log_size - log_chunks;
                    let buffer = local_buffer_alloc.alloc(1 << log_chunk_size).expect(
                        "precondition: allocator must have enough space for all local buffers",
                    );
                    KernelBuffer::Mut(buffer)
                }
            })
            .collect()
    }

    /// Splits the memory mappings into chunks and lazily runs `map` on each chunk.
    fn process_kernels_chunks<R>(
        &self,
        map: impl Sync
            + for<'a> Fn(
                &'a mut CpuKernelBuilder,
                usize,
                Vec<KernelBuffer<'a, F, CpuMemory>>,
            ) -> Result<R, Error>,
        mut mem_maps: Vec<KernelMemMap<'_, F, CpuMemory>>,
    ) -> Result<impl Iterator<Item = Result<R, Error>>, Error> {
        let log_chunks_range = KernelMemMap::log_chunks_range(&mem_maps)
            .expect("precondition: the kernel memory mappings must be non-empty");

        // The reference implementation uses the finest chunking the mappings allow.
        let log_chunks = log_chunks_range.end;
        let total_alloc = count_total_local_buffer_sizes(&mem_maps, log_chunks);
        let mut local_buffer = zeroed_vec(total_alloc);
        let iter = (0..1 << log_chunks).map(move |i| {
            let local_buffer_alloc = BumpAllocator::new(local_buffer.as_mut());
            let kernel_data =
                Self::map_kernel_mem(&mut mem_maps, &local_buffer_alloc, log_chunks, i);
            map(&mut CpuKernelBuilder, log_chunks, kernel_data)
        });

        Ok(iter)
    }
}

impl<F> Default for CpuLayerExecutor<F> {
    fn default() -> Self {
        Self(PhantomData)
    }
}

impl<F: TowerTop> ComputeLayerExecutor<F> for CpuLayerExecutor<F> {
    type OpValue = F;
    type ExprEval = ArithCircuit<F>;
    type KernelExec = CpuKernelBuilder;
    type DevMem = CpuMemory;

    fn accumulate_kernels(
        &mut self,
        map: impl Sync
            + for<'a> Fn(
                &'a mut Self::KernelExec,
                usize,
                Vec<KernelBuffer<'a, F, Self::DevMem>>,
            ) -> Result<Vec<F>, Error>,
        inputs: Vec<KernelMemMap<'_, F, Self::DevMem>>,
    ) -> Result<Vec<Self::OpValue>, Error> {
        // Sum the per-chunk output vectors elementwise, extending with any
        // trailing entries from the longer vector.
        self.process_kernels_chunks(map, inputs)?
            .reduce(|out1, out2| {
                let mut out1 = out1?;
                let mut out2_iter = out2?.into_iter();
                for (out1_i, out2_i) in std::iter::zip(&mut out1, &mut out2_iter) {
                    *out1_i += out2_i;
                }
                out1.extend(out2_iter);
                Ok(out1)
            })
            .expect("range is not empty")
    }

    fn map_kernels(
        &mut self,
        map: impl Sync
            + for<'a> Fn(
                &'a mut Self::KernelExec,
                usize,
                Vec<KernelBuffer<'a, F, Self::DevMem>>,
            ) -> Result<(), Error>,
        mem_maps: Vec<KernelMemMap<'_, F, Self::DevMem>>,
    ) -> Result<(), Error> {
        // Propagate the first chunk error instead of silently discarding it.
        self.process_kernels_chunks(map, mem_maps)?
            .try_for_each(|result| result)
    }

    fn inner_product<'a>(
        &'a mut self,
        a_in: SubfieldSlice<'_, F, Self::DevMem>,
        b_in: &'a [F],
    ) -> Result<F, Error> {
        if a_in.tower_level > F::TOWER_LEVEL
            || a_in.slice.len() << (F::TOWER_LEVEL - a_in.tower_level) != b_in.len()
        {
            return Err(Error::InputValidation(format!(
                "invalid input: a_edeg={} |a|={} |b|={}",
                a_in.tower_level,
                a_in.slice.len(),
                b_in.len()
            )));
        }

        fn inner_product<F, FExt>(a_in: &[FExt], b_in: &[FExt]) -> FExt
        where
            F: Field,
            FExt: ExtensionField<F>,
        {
            inner_product_unchecked(
                b_in.iter().copied(),
                a_in.iter()
                    .flat_map(<FExt as ExtensionField<F>>::iter_bases),
            )
        }

        let result =
            each_tower_subfield!(a_in.tower_level, inner_product::<_, F>(a_in.slice, b_in));
        Ok(result)
    }

    fn fold_left(
        &mut self,
        mat: SubfieldSlice<'_, F, Self::DevMem>,
        vec: <Self::DevMem as ComputeMemory<F>>::FSlice<'_>,
        out: &mut <Self::DevMem as ComputeMemory<F>>::FSliceMut<'_>,
    ) -> Result<(), Error> {
        if mat.tower_level > F::TOWER_LEVEL {
            return Err(Error::InputValidation(format!(
                "invalid evals: tower_level={} > {}",
                mat.tower_level,
                F::TOWER_LEVEL
            )));
        }
        let log_evals_size = mat.slice.len().ilog2() as usize + F::TOWER_LEVEL - mat.tower_level;
        each_tower_subfield!(
            mat.tower_level,
            compute_left_fold::<_, F>(mat.slice, log_evals_size, vec, out)
        )
    }

    fn fold_right(
        &mut self,
        mat: SubfieldSlice<'_, F, Self::DevMem>,
        vec: <Self::DevMem as ComputeMemory<F>>::FSlice<'_>,
        out: &mut <Self::DevMem as ComputeMemory<F>>::FSliceMut<'_>,
    ) -> Result<(), Error> {
        if mat.tower_level > F::TOWER_LEVEL {
            return Err(Error::InputValidation(format!(
                "invalid evals: tower_level={} > {}",
                mat.tower_level,
                F::TOWER_LEVEL
            )));
        }
        let log_evals_size = mat.slice.len().ilog2() as usize + F::TOWER_LEVEL - mat.tower_level;
        each_tower_subfield!(
            mat.tower_level,
            compute_right_fold::<_, F>(mat.slice, log_evals_size, vec, out)
        )
    }

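    /// Expands a tensor-product equality indicator in place. Each round maps
    /// the prefix `x` of length `2^(log_n + i)` to `x * (1 - r_i)` and adds
    /// `x * r_i` into the following block of the same length.
    ///
    /// Worked example with `log_n = 0`, `data = [1, 0, 0, 0]`, and
    /// `coordinates = [r1, r2]`: after the first round the prefix is
    /// `[1 - r1, r1]`, and after the second the full buffer is
    /// `[(1 - r1)(1 - r2), r1(1 - r2), (1 - r1) r2, r1 r2]`.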
    fn tensor_expand(
        &mut self,
        log_n: usize,
        coordinates: &[F],
        data: &mut &mut [F],
    ) -> Result<(), Error> {
        if data.len() != 1 << (log_n + coordinates.len()) {
            return Err(Error::InputValidation(format!("invalid data length: {}", data.len())));
        }

        for (i, r_i) in coordinates.iter().enumerate() {
            let (lhs, rest) = data.split_at_mut(1 << (log_n + i));
            let (rhs, _rest) = rest.split_at_mut(1 << (log_n + i));
            for (x_i, y_i) in std::iter::zip(lhs, rhs) {
                let prod = *x_i * r_i;
                *x_i -= prod;
                *y_i += prod;
            }
        }
        Ok(())
    }

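    /// Folds an interleaved codeword chunk by chunk: the first
    /// `log_batch_size` challenges fold each interleaved batch by repeated
    /// linear extrapolation, and the remaining challenges perform FRI folding
    /// rounds using the NTT subspace evaluations, producing one output
    /// element per chunk.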
    fn fri_fold<FSub>(
        &mut self,
        ntt: &(impl AdditiveNTT<FSub> + Sync),
        log_len: usize,
        log_batch_size: usize,
        challenges: &[F],
        data_in: &[F],
        data_out: &mut &mut [F],
    ) -> Result<(), Error>
    where
        FSub: BinaryField,
        F: ExtensionField<FSub>,
    {
        if data_in.len() != 1 << (log_len + log_batch_size) {
            return Err(Error::InputValidation(format!(
                "invalid data_in length: {}",
                data_in.len()
            )));
        }

        if challenges.len() < log_batch_size {
            return Err(Error::InputValidation(format!(
                "invalid challenges length: {}",
                challenges.len()
            )));
        }

        if challenges.len() > log_batch_size + log_len {
            return Err(Error::InputValidation(format!(
                "challenges length too large: {}",
                challenges.len()
            )));
        }

        if data_out.len() != 1 << (log_len - (challenges.len() - log_batch_size)) {
            return Err(Error::InputValidation(format!(
                "invalid data_out length: {}",
                data_out.len()
            )));
        }

        let (interleave_challenges, fold_challenges) = challenges.split_at(log_batch_size);
        let log_size = fold_challenges.len();

        let mut values = vec![F::ZERO; 1 << challenges.len()];
        for (chunk_index, (chunk, out)) in data_in
            .chunks_exact(1 << challenges.len())
            .zip(data_out.iter_mut())
            .enumerate()
        {
            values[..(1 << challenges.len())].copy_from_slice(chunk);

            // Fold the interleaved batch with the interleave challenges.
            let mut current_values = &mut values[0..1 << challenges.len()];
            for challenge in interleave_challenges {
                let new_num_elements = current_values.len() / 2;
                for out_idx in 0..new_num_elements {
                    current_values[out_idx] = extrapolate_line_scalar(
                        current_values[out_idx * 2],
                        current_values[out_idx * 2 + 1],
                        *challenge,
                    );
                }
                current_values = &mut current_values[0..new_num_elements];
            }

            // FRI-fold the batch-folded values with the fold challenges.
            let mut log_len = log_len;
            let mut log_size = log_size;
            for &challenge in fold_challenges {
                for index_offset in 0..1 << (log_size - 1) {
                    let t = ntt
                        .get_subspace_eval(log_len, (chunk_index << (log_size - 1)) | index_offset);
                    let (mut u, mut v) =
                        (values[index_offset << 1], values[(index_offset << 1) | 1]);
                    v += u;
                    u += v * t;
                    values[index_offset] = extrapolate_line_scalar(u, v, challenge);
                }

                log_len -= 1;
                log_size -= 1;
            }

            *out = values[0];
        }

        Ok(())
    }

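    /// Evaluates the line through `(0, evals_0[i])` and `(1, evals_1[i])` at
    /// `z`, writing the result back into `evals_0`.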
    fn extrapolate_line(
        &mut self,
        evals_0: &mut &mut [F],
        evals_1: &[F],
        z: F,
    ) -> Result<(), Error> {
        if evals_0.len() != evals_1.len() {
            return Err(Error::InputValidation(
                "evals_0 and evals_1 must be the same length".into(),
            ));
        }
        for (x0, x1) in iter::zip(&mut **evals_0, evals_1) {
            *x0 += (*x1 - *x0) * z;
        }
        Ok(())
    }

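    /// Evaluates the compiled composition circuit at every column of
    /// `inputs`, writing one output element per column.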
    fn compute_composite(
        &mut self,
        inputs: &SlicesBatch<<Self::DevMem as ComputeMemory<F>>::FSlice<'_>>,
        output: &mut <Self::DevMem as ComputeMemory<F>>::FSliceMut<'_>,
        composition: &Self::ExprEval,
    ) -> Result<(), Error> {
        if inputs.row_len() != output.len() {
            return Err(Error::InputValidation("inputs and output must be the same length".into()));
        }

        if composition.n_vars() != inputs.n_rows() {
            return Err(Error::InputValidation(
                "composition arity does not match the number of input rows".into(),
            ));
        }

        let mut query = zeroed_vec(inputs.n_rows());

        for (i, output) in output.iter_mut().enumerate() {
            for (j, query) in query.iter_mut().enumerate() {
                *query = inputs.row(j)[i];
            }

            *output = composition.evaluate(&query).expect("evaluation should succeed");
        }

        Ok(())
    }

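    /// Repeatedly reduces `input` by multiplying adjacent pairs, writing each
    /// round's products into the corresponding entry of `round_outputs`.
    ///
    /// For example, `input = [a, b, c, d]` produces
    /// `round_outputs[0] = [a*b, c*d]` and `round_outputs[1] = [a*b*c*d]`.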
    fn pairwise_product_reduce(
        &mut self,
        input: <Self::DevMem as ComputeMemory<F>>::FSlice<'_>,
        round_outputs: &mut [<Self::DevMem as ComputeMemory<F>>::FSliceMut<'_>],
    ) -> Result<(), Error> {
        let log_num_inputs = match strict_log_2(input.len()) {
            None => {
                return Err(Error::InputValidation(format!(
                    "input length must be a power of 2: {}",
                    input.len()
                )));
            }
            Some(0) => {
                return Err(Error::InputValidation(format!(
                    "input length must be at least 2 to perform a reduction: {}",
                    input.len()
                )));
            }
            Some(log_num_inputs) => log_num_inputs,
        };
        let expected_round_outputs_len = log_num_inputs;
        if round_outputs.len() != expected_round_outputs_len as usize {
            return Err(Error::InputValidation(format!(
                "round_outputs.len() does not match the expected length: {} != {expected_round_outputs_len}",
                round_outputs.len()
            )));
        }
        for (round_idx, round_output_data) in round_outputs.iter().enumerate() {
            let expected_output_size = 1usize << (log_num_inputs as usize - round_idx - 1);
            if round_output_data.len() != expected_output_size {
                return Err(Error::InputValidation(format!(
                    "round_outputs[{}].len() = {}, expected {expected_output_size}",
                    round_idx,
                    round_output_data.len()
                )));
            }
        }
        let mut round_data_source = input;
        for round_output_data in round_outputs.iter_mut() {
            for idx in 0..round_output_data.len() {
                round_output_data[idx] =
                    round_data_source[idx * 2] * round_data_source[idx * 2 + 1];
            }
            round_data_source = round_output_data;
        }

        Ok(())
    }
}

/// Kernel-level executor used by [`CpuLayerExecutor`].
#[derive(Debug)]
pub struct CpuKernelBuilder;

impl<F: TowerField> KernelExecutor<F> for CpuKernelBuilder {
    type Mem = CpuMemory;
    type Value = F;
    type ExprEval = ArithCircuit<F>;

    fn decl_value(&mut self, init: F) -> Result<F, Error> {
        Ok(init)
    }

    fn sum_composition_evals(
        &mut self,
        inputs: &SlicesBatch<<Self::Mem as ComputeMemory<F>>::FSlice<'_>>,
        composition: &Self::ExprEval,
        batch_coeff: F,
        accumulator: &mut Self::Value,
    ) -> Result<(), Error> {
        let ret = (0..inputs.row_len())
            .map(|i| {
                let row = inputs.iter().map(|input| input[i]).collect::<Vec<_>>();
                composition.evaluate(&row).expect("evaluation should succeed")
            })
            .sum::<F>();
        *accumulator += ret * batch_coeff;
        Ok(())
    }

    fn add(
        &mut self,
        log_len: usize,
        src1: &'_ [F],
        src2: &'_ [F],
        dst: &mut &'_ mut [F],
    ) -> Result<(), Error> {
        assert_eq!(src1.len(), 1 << log_len);
        assert_eq!(src2.len(), 1 << log_len);
        assert_eq!(dst.len(), 1 << log_len);

        for (dst_i, &src1_i, &src2_i) in izip!(&mut **dst, src1, src2) {
            *dst_i = src1_i + src2_i;
        }

        Ok(())
    }

    fn add_assign(
        &mut self,
        log_len: usize,
        src: &'_ [F],
        dst: &mut &'_ mut [F],
    ) -> Result<(), Error> {
        assert_eq!(src.len(), 1 << log_len);
        assert_eq!(dst.len(), 1 << log_len);

        for (dst_i, &src_i) in iter::zip(&mut **dst, src) {
            *dst_i += src_i;
        }

        Ok(())
    }
}

/// Convenience alias for kernel memory mappings parameterized by an executor.
type MemMap<'a, C, F> = KernelMemMap<'a, F, <C as ComputeLayerExecutor<F>>::DevMem>;
/// Convenience alias for kernel buffers parameterized by an executor.
type Buffer<'a, C, F> = KernelBuffer<'a, F, <C as ComputeLayerExecutor<F>>::DevMem>;

/// Sums the scratch space required by all `Local` mappings when the kernel is
/// split into `2^log_chunks` chunks.
pub fn count_total_local_buffer_sizes<F, Mem: ComputeMemory<F>>(
    mappings: &[KernelMemMap<F, Mem>],
    log_chunks: usize,
) -> usize {
    mappings
        .iter()
        .map(|mapping| match mapping {
            KernelMemMap::Chunked { .. } | KernelMemMap::ChunkedMut { .. } => 0,
            KernelMemMap::Local { log_size } => 1 << log_size.saturating_sub(log_chunks),
        })
        .sum()
}

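/// Computes a left fold of `evals_as_b128`, viewed through its expansion into
/// `EvalType` subfield elements: interpreting the expanded evals as a
/// `2^log_query_size x num_rows` matrix stored row-major, writes the
/// vector-matrix product `out[i] = sum_j query[j] * evals[j * num_rows + i]`.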
fn compute_left_fold<EvalType: TowerField, F: TowerTop + ExtensionField<EvalType>>(
    evals_as_b128: &[F],
    log_evals_size: usize,
    query: &[F],
    out: FSliceMut<'_, F, CpuLayer<F>>,
) -> Result<(), Error> {
    let evals = evals_as_b128
        .iter()
        .flat_map(ExtensionField::<EvalType>::iter_bases)
        .collect::<Vec<_>>();
    let log_query_size = query.len().ilog2() as usize;
    let num_cols = 1 << log_query_size;
    let num_rows = 1 << (log_evals_size - log_query_size);

    if evals.len() != num_cols * num_rows {
        return Err(Error::InputValidation(format!(
            "evals has {} elements, expected {}",
            evals.len(),
            num_cols * num_rows
        )));
    }

    if query.len() != num_cols {
        return Err(Error::InputValidation(format!(
            "query has {} elements, expected {}",
            query.len(),
            num_cols
        )));
    }

    if out.len() != num_rows {
        return Err(Error::InputValidation(format!(
            "output has {} elements, expected {}",
            out.len(),
            num_rows
        )));
    }

    for i in 0..num_rows {
        let mut acc = F::ZERO;
        for j in 0..num_cols {
            acc += query[j] * evals[j * num_rows + i];
        }
        out[i] = acc;
    }

    Ok(())
}

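/// Computes a right fold of `evals_as_b128`, viewed through its expansion
/// into `EvalType` subfield elements: interpreting the expanded evals as a
/// `num_cols x 2^log_query_size` matrix stored row-major, writes the
/// matrix-vector product `out[i] = sum_j query[j] * evals[i * num_rows + j]`.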
fn compute_right_fold<EvalType: TowerField, F: TowerTop + ExtensionField<EvalType>>(
    evals_as_b128: &[F],
    log_evals_size: usize,
    query: &[F],
    out: FSliceMut<'_, F, CpuLayer<F>>,
) -> Result<(), Error> {
    let evals = evals_as_b128
        .iter()
        .flat_map(ExtensionField::<EvalType>::iter_bases)
        .collect::<Vec<_>>();
    let log_query_size = query.len().ilog2() as usize;
    let num_rows = 1 << log_query_size;
    let num_cols = 1 << (log_evals_size - log_query_size);

    if evals.len() != num_cols * num_rows {
        return Err(Error::InputValidation(format!(
            "evals has {} elements, expected {}",
            evals.len(),
            num_cols * num_rows
        )));
    }

    if query.len() != num_rows {
        return Err(Error::InputValidation(format!(
            "query has {} elements, expected {}",
            query.len(),
            num_rows
        )));
    }

    if out.len() != num_cols {
        return Err(Error::InputValidation(format!(
            "output has {} elements, expected {}",
            out.len(),
            num_cols
        )));
    }

    for i in 0..num_cols {
        let mut acc = F::ZERO;
        for j in 0..num_rows {
            acc += query[j] * evals[i * num_rows + j];
        }
        out[i] = acc;
    }

    Ok(())
}

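/// Owns a [`CpuLayer`] together with zero-initialized host and device
/// buffers, handing out bump allocators over them via
/// [`ComputeHolder::to_data`].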
#[derive(Default)]
pub struct CpuLayerHolder<F> {
    layer: CpuLayer<F>,
    host_mem: Vec<F>,
    dev_mem: Vec<F>,
}

impl<F: TowerTop> CpuLayerHolder<F> {
    pub fn new(host_mem_size: usize, dev_mem_size: usize) -> Self {
        let host_mem = zeroed_vec(host_mem_size);
        let dev_mem = zeroed_vec(dev_mem_size);
        Self {
            layer: CpuLayer::default(),
            host_mem,
            dev_mem,
        }
    }
}

impl<F: TowerTop> ComputeHolder<F, CpuLayer<F>> for CpuLayerHolder<F> {
    type HostComputeAllocator<'a> = HostBumpAllocator<'a, F>;
    type DeviceComputeAllocator<'a> =
        BumpAllocator<'a, F, <CpuLayer<F> as ComputeLayer<F>>::DevMem>;

    fn to_data<'a, 'b>(
        &'a mut self,
    ) -> ComputeData<
        'a,
        F,
        CpuLayer<F>,
        Self::HostComputeAllocator<'b>,
        Self::DeviceComputeAllocator<'b>,
    >
    where
        'a: 'b,
    {
        ComputeData::new(
            &self.layer,
            BumpAllocator::new(self.host_mem.as_mut_slice()),
            BumpAllocator::new(self.dev_mem.as_mut_slice()),
        )
    }
}
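
// Usage sketch (not compiled here): allocating scratch space through the
// holder. The concrete field `BinaryField128b` is an assumption; any
// `F: TowerTop` works.
//
//     let mut holder = CpuLayerHolder::<BinaryField128b>::new(1 << 16, 1 << 20);
//     let compute_data = holder.to_data();
//     // `compute_data` bundles the CpuLayer with bump allocators that carve
//     // buffers out of the holder's host and device memory.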