Skip to content

Commit 4873dca

Browse files
feat[cuda]: constant array for numeric types (#6248)
Signed-off-by: Joe Isaacs <joe.isaacs@live.co.uk>
1 parent f83fff8 commit 4873dca

File tree

12 files changed

+432
-73
lines changed

12 files changed

+432
-73
lines changed

vortex-array/src/validity.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -461,7 +461,14 @@ impl FromIterator<bool> for Validity {
461461
impl From<Nullability> for Validity {
462462
#[inline]
463463
fn from(value: Nullability) -> Self {
464-
match value {
464+
Validity::from(&value)
465+
}
466+
}
467+
468+
impl From<&Nullability> for Validity {
469+
#[inline]
470+
fn from(value: &Nullability) -> Self {
471+
match *value {
465472
Nullability::NonNullable => Validity::NonNullable,
466473
Nullability::Nullable => Validity::AllValid,
467474
}
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
#include "config.cuh"
5+
#include "types.cuh"
6+
#include <cuda_fp16.h>
7+
8+
// Fill an output buffer with a constant value.
9+
template<typename T>
10+
__device__ void constant_fill(
11+
T *__restrict output,
12+
T value,
13+
uint64_t array_len
14+
) {
15+
const uint64_t worker = blockIdx.x * blockDim.x + threadIdx.x;
16+
const uint64_t startElem = START_ELEM(worker, array_len);
17+
const uint64_t stopElem = STOP_ELEM(worker, array_len);
18+
19+
if (startElem >= array_len) {
20+
return;
21+
}
22+
23+
for (uint64_t idx = startElem; idx < stopElem; idx++) {
24+
output[idx] = value;
25+
}
26+
}
27+
28+
#define GENERATE_CONSTANT_NUMERIC_KERNEL(suffix, Type) \
29+
extern "C" __global__ void constant_numeric_##suffix( \
30+
Type *__restrict output, \
31+
Type value, \
32+
uint64_t array_len \
33+
) { \
34+
constant_fill(output, value, array_len); \
35+
}
36+
37+
FOR_EACH_NUMERIC(GENERATE_CONSTANT_NUMERIC_KERNEL)

vortex-cuda/kernels/src/dict.cu

Lines changed: 4 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

44
#include <cuda.h>
5-
#include <cuda_fp16.h>
65
#include <cuda_runtime.h>
76
#include <stdint.h>
87

@@ -39,30 +38,13 @@ extern "C" __global__ void dict_##value_suffix##_##index_suffix( \
3938
dict_kernel<ValueType, IndexType>(codes, codes_len, values, output); \
4039
}
4140

42-
// Generate for all combinations of value types and index types
43-
// Value types: u8, i8, u16, i16, u32, i32, u64, i64
44-
// Index types: u8, u16, u32, u64 (codes are typically unsigned)
45-
46-
#define GENERATE_DICT_KERNELS_FOR_VALUE(value_suffix, ValueType) \
41+
// Generate dict kernel for all index types (unsigned integers) for a given value type
42+
#define GENERATE_DICT_FOR_ALL_INDICES(value_suffix, ValueType) \
4743
GENERATE_DICT_KERNEL(value_suffix, ValueType, u8, uint8_t) \
4844
GENERATE_DICT_KERNEL(value_suffix, ValueType, u16, uint16_t) \
4945
GENERATE_DICT_KERNEL(value_suffix, ValueType, u32, uint32_t) \
5046
GENERATE_DICT_KERNEL(value_suffix, ValueType, u64, uint64_t)
5147

52-
GENERATE_DICT_KERNELS_FOR_VALUE(u8, uint8_t)
53-
GENERATE_DICT_KERNELS_FOR_VALUE(i8, int8_t)
54-
GENERATE_DICT_KERNELS_FOR_VALUE(u16, uint16_t)
55-
GENERATE_DICT_KERNELS_FOR_VALUE(i16, int16_t)
56-
GENERATE_DICT_KERNELS_FOR_VALUE(u32, uint32_t)
57-
GENERATE_DICT_KERNELS_FOR_VALUE(i32, int32_t)
58-
GENERATE_DICT_KERNELS_FOR_VALUE(u64, uint64_t)
59-
GENERATE_DICT_KERNELS_FOR_VALUE(i64, int64_t)
60-
61-
// Float types
62-
GENERATE_DICT_KERNELS_FOR_VALUE(f16, __half)
63-
GENERATE_DICT_KERNELS_FOR_VALUE(f32, float)
64-
GENERATE_DICT_KERNELS_FOR_VALUE(f64, double)
48+
// Generate for all native ptypes & decimal values
49+
FOR_EACH_NUMERIC(GENERATE_DICT_FOR_ALL_INDICES)
6550

66-
// Decimal types (128-bit and 256-bit)
67-
GENERATE_DICT_KERNELS_FOR_VALUE(i128, int128_t)
68-
GENERATE_DICT_KERNELS_FOR_VALUE(i256, int256_t)

vortex-cuda/kernels/src/for.cu

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

44
#include "scalar_kernel.cuh"
5+
#include "types.cuh"
56

67
// Frame-of-Reference operation: adds a reference value to each element.
78
template<typename T>
@@ -34,22 +35,8 @@ extern "C" __global__ void for_in_out_##suffix( \
3435
scalar_kernel(input, output, array_len, ForOp<Type>{reference}); \
3536
}
3637

37-
// In-place variants (modifies input buffer)
38-
GENERATE_FOR_KERNEL(u8, uint8_t)
39-
GENERATE_FOR_KERNEL(i8, int8_t)
40-
GENERATE_FOR_KERNEL(u16, uint16_t)
41-
GENERATE_FOR_KERNEL(i16, int16_t)
42-
GENERATE_FOR_KERNEL(u32, uint32_t)
43-
GENERATE_FOR_KERNEL(i32, int32_t)
44-
GENERATE_FOR_KERNEL(u64, uint64_t)
45-
GENERATE_FOR_KERNEL(i64, int64_t)
38+
// In-place variants (modifies input buffer) - FoR is only used for integers
39+
FOR_EACH_INTEGER(GENERATE_FOR_KERNEL)
4640

4741
// Separate input/output variants (preserves input buffer)
48-
GENERATE_FOR_IN_OUT_KERNEL(u8, uint8_t)
49-
GENERATE_FOR_IN_OUT_KERNEL(i8, int8_t)
50-
GENERATE_FOR_IN_OUT_KERNEL(u16, uint16_t)
51-
GENERATE_FOR_IN_OUT_KERNEL(i16, int16_t)
52-
GENERATE_FOR_IN_OUT_KERNEL(u32, uint32_t)
53-
GENERATE_FOR_IN_OUT_KERNEL(i32, int32_t)
54-
GENERATE_FOR_IN_OUT_KERNEL(u64, uint64_t)
55-
GENERATE_FOR_IN_OUT_KERNEL(i64, int64_t)
42+
FOR_EACH_INTEGER(GENERATE_FOR_IN_OUT_KERNEL)

vortex-cuda/kernels/src/patches.cu

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

44
#include "config.cuh"
5+
#include "types.cuh"
56

67
// TODO(aduffy): this is very naive. In the future we need to
78
// transpose the patches, see G-ALP paper.
@@ -40,22 +41,12 @@ extern "C" __global__ void patches_##value_suffix##_##index_suffix( \
4041
patches(values, patchIndices, patchValues, patchesLen); \
4142
}
4243

43-
#define GENERATE_PATCHES_KERNEL_FOR_VALUE(ValueT, value_suffix) \
44-
GENERATE_PATCHES_KERNEL(ValueT, value_suffix, uint8_t, u8) \
45-
GENERATE_PATCHES_KERNEL(ValueT, value_suffix, uint16_t, u16) \
46-
GENERATE_PATCHES_KERNEL(ValueT, value_suffix, uint32_t, u32) \
47-
GENERATE_PATCHES_KERNEL(ValueT, value_suffix, uint64_t, u64)
44+
// Generate patches kernel for all index types (unsigned integers) for a given value type
45+
#define GENERATE_PATCHES_FOR_ALL_INDICES(value_suffix, ValueT) \
46+
GENERATE_PATCHES_KERNEL(ValueT, value_suffix, uint8_t, u8) \
47+
GENERATE_PATCHES_KERNEL(ValueT, value_suffix, uint16_t, u16) \
48+
GENERATE_PATCHES_KERNEL(ValueT, value_suffix, uint32_t, u32) \
49+
GENERATE_PATCHES_KERNEL(ValueT, value_suffix, uint64_t, u64)
4850

49-
50-
GENERATE_PATCHES_KERNEL_FOR_VALUE(uint8_t, u8)
51-
GENERATE_PATCHES_KERNEL_FOR_VALUE(uint16_t, u16)
52-
GENERATE_PATCHES_KERNEL_FOR_VALUE(uint32_t, u32)
53-
GENERATE_PATCHES_KERNEL_FOR_VALUE(uint64_t, u64)
54-
55-
GENERATE_PATCHES_KERNEL_FOR_VALUE(int8_t, i8)
56-
GENERATE_PATCHES_KERNEL_FOR_VALUE(int16_t, i16)
57-
GENERATE_PATCHES_KERNEL_FOR_VALUE(int32_t, i32)
58-
GENERATE_PATCHES_KERNEL_FOR_VALUE(int64_t, i64)
59-
60-
GENERATE_PATCHES_KERNEL_FOR_VALUE(float, f32)
61-
GENERATE_PATCHES_KERNEL_FOR_VALUE(double, f64)
51+
// Generate for all native SIMD ptypes
52+
FOR_EACH_NATIVE_SIMD_PTYPE(GENERATE_PATCHES_FOR_ALL_INDICES)

vortex-cuda/kernels/src/types.cuh

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4-
#ifndef VORTEX_CUDA_TYPES_CUH
5-
#define VORTEX_CUDA_TYPES_CUH
4+
#pragma once
65

6+
#include <cuda_fp16.h>
77
#include <stdint.h>
88

99
// 128-bit signed integer type for decimal values
@@ -17,4 +17,52 @@ struct __align__(32) int256_t {
1717
int64_t parts[4];
1818
};
1919

20-
#endif // VORTEX_CUDA_TYPES_CUH
20+
// Type iteration macros - call MACRO(suffix, Type) for each type in category.
21+
// These mirror the Rust match_each_*_ptype macros.
22+
23+
// Unsigned integers
24+
#define FOR_EACH_UNSIGNED_INT(MACRO) \
25+
MACRO(u8, uint8_t) \
26+
MACRO(u16, uint16_t) \
27+
MACRO(u32, uint32_t) \
28+
MACRO(u64, uint64_t)
29+
30+
// Signed integers
31+
#define FOR_EACH_SIGNED_INT(MACRO) \
32+
MACRO(i8, int8_t) \
33+
MACRO(i16, int16_t) \
34+
MACRO(i32, int32_t) \
35+
MACRO(i64, int64_t)
36+
37+
// All integers (signed + unsigned)
38+
#define FOR_EACH_INTEGER(MACRO) \
39+
FOR_EACH_UNSIGNED_INT(MACRO) \
40+
FOR_EACH_SIGNED_INT(MACRO)
41+
42+
// All floating point types (requires #include <cuda_fp16.h>)
43+
#define FOR_EACH_FLOAT(MACRO) \
44+
MACRO(f16, __half) \
45+
MACRO(f32, float) \
46+
MACRO(f64, double)
47+
48+
// Native SIMD types (integers + f32/f64, matches match_each_native_simd_ptype)
49+
#define FOR_EACH_NATIVE_SIMD_PTYPE(MACRO) \
50+
FOR_EACH_INTEGER(MACRO) \
51+
MACRO(f32, float) \
52+
MACRO(f64, double)
53+
54+
// All native ptypes (requires #include <cuda_fp16.h>, matches match_each_native_ptype)
55+
#define FOR_EACH_NATIVE_PTYPE(MACRO) \
56+
FOR_EACH_INTEGER(MACRO) \
57+
FOR_EACH_FLOAT(MACRO)
58+
59+
// Large decimal types (128-bit and 256-bit integers for decimal representation).
60+
// Use alongside FOR_EACH_NATIVE_PTYPE for full type coverage.
61+
#define FOR_EACH_LARGE_DECIMAL(MACRO) \
62+
MACRO(i128, int128_t) \
63+
MACRO(i256, int256_t)
64+
65+
// All numeric types: native ptypes + large decimals (requires #include <cuda_fp16.h>)
66+
#define FOR_EACH_NUMERIC(MACRO) \
67+
FOR_EACH_NATIVE_PTYPE(MACRO) \
68+
FOR_EACH_LARGE_DECIMAL(MACRO)

0 commit comments

Comments
 (0)