Skip to content

Commit 1e7dbea

Browse files
authored
support ordered sparse 64 selector (#1243)
left-over of #1213 blocked by #1241
1 parent 66945f6 commit 1e7dbea

File tree

3 files changed

+11
-13
lines changed

3 files changed

+11
-13
lines changed

ceno_recursion/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ path = "src/bin/e2e_aggregate.rs"
5454
[features]
5555
bench-metrics = ["openvm-circuit/metrics"]
5656
default = ["parallel", "nightly-features"]
57-
gpu = ["openvm-circuit/cuda", "openvm-native-circuit/cuda", "dep:openvm-cuda-backend"]
57+
gpu = ["ceno_zkvm/gpu", "openvm-circuit/cuda", "openvm-native-circuit/cuda", "dep:openvm-cuda-backend"]
5858
nightly-features = [
5959
"ceno_zkvm/nightly-features",
6060
"openvm-stark-sdk/nightly-features",

gkr_iop/src/gkr/layer/gpu/utils.rs

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -129,21 +129,18 @@ pub fn build_eq_x_r_with_sel_gpu<E: ExtensionField>(
129129
}
130130

131131
let eq_len = 1 << point.len();
132-
let (num_instances, is_sp32, indices) = match selector {
132+
let (num_instances, sparse_num_var, indices) = match selector {
133133
SelectorType::None => panic!("SelectorType::None"),
134-
SelectorType::Whole(_expr) => (eq_len, false, vec![]),
135-
SelectorType::Prefix(_expr) => (selector_ctx.num_instances, false, vec![]),
136-
SelectorType::OrderedSparse32 { indices, .. } => {
137-
(selector_ctx.num_instances, true, indices.clone())
138-
}
139-
SelectorType::OrderedSparse64 { .. } => {
140-
unimplemented!("OrderedSparse64 is not supported in GPU selector path")
141-
}
134+
SelectorType::Whole(_expr) => (eq_len, 0, vec![]),
135+
SelectorType::Prefix(_expr) => (selector_ctx.num_instances, 0, vec![]),
136+
SelectorType::OrderedSparse {
137+
indices, num_vars, ..
138+
} => (selector_ctx.num_instances, *num_vars, indices.clone()),
142139
SelectorType::QuarkBinaryTreeLessThan(..) => unimplemented!(),
143140
};
144141

145142
// type eq
146-
let eq_mle = if is_sp32 {
143+
let eq_mle = if sparse_num_var > 0 {
147144
assert_eq!(selector_ctx.offset, 0);
148145
let eq = build_eq_x_r_gpu(hal, point);
149146
let mut eq_buf = match eq.mle {
@@ -152,11 +149,12 @@ pub fn build_eq_x_r_with_sel_gpu<E: ExtensionField>(
152149
GpuFieldType::Unreachable => panic!("Unreachable GpuFieldType"),
153150
};
154151
let indices_u32 = indices.iter().map(|x| *x as u32).collect_vec();
155-
ordered_sparse32_selector_gpu::<CudaHalBB31, BB31Ext, BB31Base>(
152+
ordered_sparse_selector_gpu::<CudaHalBB31, BB31Ext, BB31Base>(
156153
&hal.inner,
157154
&mut eq_buf.buf,
158155
&indices_u32,
159156
num_instances,
157+
sparse_num_var,
160158
)
161159
.unwrap();
162160
eq_buf

gkr_iop/src/gpu/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ pub mod gpu_prover {
2929
buffer::BufferImpl,
3030
get_ceno_gpu_device_id,
3131
mle::{
32-
build_mle_as_ceno, ordered_sparse32_selector_gpu, rotation_next_base_mle_gpu,
32+
build_mle_as_ceno, ordered_sparse_selector_gpu, rotation_next_base_mle_gpu,
3333
rotation_selector_gpu,
3434
},
3535
utils::HasUtils,

0 commit comments

Comments
 (0)