Skip to content

Commit aa584ac

Browse files
authored
feat: add "runtime" and "comptime" features to select when slang shaders need to be compiled (#2)
* feat: make stensor compatible with slang-hal’s new comptime feature * Release v0.3.0
1 parent 0239ceb commit aa584ac

File tree

13 files changed

+54
-39
lines changed

13 files changed

+54
-39
lines changed

.github/workflows/ci.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ jobs:
9696
sweep-cache: true
9797

9898
- name: Run clippy lints
99-
run: SLANG_DIR=$SLANG_DIR cargo clippy --locked --workspace --all-targets -- --deny warnings
99+
run: SLANG_DIR=$SLANG_DIR cargo clippy --locked --workspace --all-targets --features runtime -- --deny warnings
100100

101101
# Check documentation.
102102
doc:
@@ -128,7 +128,7 @@ jobs:
128128
sweep-cache: true
129129

130130
- name: Check documentation
131-
run: SLANG_DIR=$SLANG_DIR cargo doc --locked --workspace --document-private-items --no-deps
131+
run: SLANG_DIR=$SLANG_DIR cargo doc --locked --workspace --document-private-items --features runtime --no-deps
132132
# Testing.
133133
test:
134134
needs: setup-slang # Depends on setup-slang
@@ -163,4 +163,4 @@ jobs:
163163
sweep-cache: true
164164
- name: Run Cargo Tests
165165
run: |
166-
SLANG_DIR=$SLANG_DIR LIBGL_ALWAYS_SOFTWARE=1 cargo test --verbose
166+
SLANG_DIR=$SLANG_DIR LIBGL_ALWAYS_SOFTWARE=1 cargo test --verbose --features runtime
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
<component name="ProjectRunConfigurationManager">
2-
<configuration default="false" name="Check" type="CargoCommandRunConfiguration" factoryName="Cargo Command" nameIsGenerated="true">
2+
<configuration default="false" name="Check (runtime)" type="CargoCommandRunConfiguration" factoryName="Cargo Command">
33
<option name="buildProfileId" value="dev" />
4-
<option name="command" value="check" />
4+
<option name="command" value="check --features runtime" />
55
<option name="workingDirectory" value="file://$PROJECT_DIR$" />
66
<envs />
77
<option name="emulateTerminal" value="true" />

Cargo.toml

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,27 +3,33 @@ name = "stensor"
33
authors = ["Sébastien Crozet <sebcrozet@dimforge.com>"]
44
description = "Cross-platform GPU tensor library with Slang and Rust."
55
repository = "https://github.com/dimforge/stensor"
6-
version = "0.2.0"
6+
version = "0.3.0"
77
edition = "2024"
88
license = "Apache-2.0"
99

1010
[features]
11+
comptime = [ "slang-hal/comptime" ]
12+
runtime = [ "slang-hal/runtime" ]
13+
14+
webgpu = [ "slang-hal/webgpu" ]
15+
vulkan = [ "slang-hal/vulkan" ]
16+
metal = [ "slang-hal/metal" ]
17+
cpu = [ "slang-hal/cpu" ]
1118
cuda = [ "cudarc", "slang-hal/cuda" ]
1219
cublas = [ "slang-hal/cublas" ]
1320

1421
[dependencies]
15-
wgpu = "27"
1622
encase = "0.12"
1723
bytemuck = "1"
1824
nalgebra = { version = "0.34", features = ["encase"] }
1925

2026
cudarc = { version = "0.16", optional = true }
2127

22-
minislang = "0.2"
23-
slang-hal = { version = "0.2", features = ["derive"] }
28+
slang-hal = { version = "0.3", features = ["derive"] }
2429
include_dir = "0.7"
2530

2631
[dev-dependencies]
32+
minislang = "0.3"
2733
nalgebra = { version = "0.34", features = ["rand"] }
2834
futures-test = "0.3"
2935
serial_test = "3"
@@ -32,12 +38,16 @@ async-std = { version = "1", features = ["attributes"] }
3238
plotly = "0.12.1"
3339
indexmap = "2"
3440
anyhow = "1"
41+
wgpu = "27"
3542

3643
[build-dependencies]
37-
minislang = "0.1"
44+
minislang = "0.3"
45+
slang-hal-build = "0.3"
46+
include_dir = "0.7"
3847

3948
[patch.crates-io]
4049
#shader-slang = { path = "../slang-rs" }
4150
#minislang = { path = "../slang-hal/crates/minislang" }
51+
#slang-hal-build = { path = "../slang-hal/crates/slang-hal-build" }
4252
#slang-hal-derive = { path = "../slang-hal/crates/slang-hal-derive" }
4353
#slang-hal = { path = "../slang-hal/crates/slang-hal" }

build.rs

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,24 @@
1-
use minislang::{SlangCompiler, shader_slang::CompileTarget};
2-
use std::path::PathBuf;
3-
use std::str::FromStr;
1+
#[cfg(feature = "comptime")]
2+
use std::env;
43

4+
#[cfg(not(feature = "comptime"))]
5+
pub fn main() {}
6+
7+
#[cfg(feature = "comptime")]
58
pub fn main() {
6-
let slang = SlangCompiler::new(vec![PathBuf::from_str("./shaders").unwrap()]);
9+
use slang_hal_build::ShaderCompiler;
10+
11+
const SLANG_SRC_DIR: include_dir::Dir<'_> =
12+
include_dir::include_dir!("$CARGO_MANIFEST_DIR/shaders");
713

8-
let targets = [
9-
CompileTarget::Wgsl,
10-
#[cfg(feature = "cuda")]
11-
CompileTarget::CudaSource,
12-
];
14+
let out_dir = env::var("OUT_DIR").expect("Couldn't determine output directory.");
15+
let mut compiler = ShaderCompiler::new(vec![], &out_dir);
16+
compiler.add_dir(SLANG_SRC_DIR);
1317

14-
for target in targets {
15-
slang.compile_all(target, "../shaders", "./src/autogen", &[]);
16-
}
18+
// Compile all shaders from examples/shaders directory.
19+
// Note: slang-hal-build will automatically detect which backends to compile for
20+
// based on the cargo features enabled during the build.
21+
compiler
22+
.compile_shaders_dir("shaders", &[])
23+
.expect("Failed to compile shaders");
1724
}

examples/gemm_bench.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
use indexmap::IndexMap;
22
use minislang::SlangCompiler;
33
use nalgebra::DMatrix;
4-
use slang_hal::Shader;
54
use slang_hal::backend::WebGpu;
65
use slang_hal::backend::{Backend, Encoder};
6+
use slang_hal::{BufferUsages, Shader};
77
use stensor::linalg::{Gemm, GemmVariant};
88
use stensor::shapes::ViewShapeBuffers;
99
use stensor::tensor::GpuTensor;
10-
use wgpu::{BufferUsages, Features, Limits};
10+
use wgpu::{Features, Limits};
1111

1212
#[async_std::main]
1313
async fn main() -> anyhow::Result<()> {

src/lib.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@
55

66
pub use geometry::*;
77
pub use linalg::*;
8-
9-
use minislang::SlangCompiler;
8+
use slang_hal::SlangCompiler;
109

1110
pub mod geometry;
1211
pub mod linalg;
@@ -15,7 +14,8 @@ pub mod tensor;
1514

1615
// pub mod utils;
1716

18-
const SLANG_SRC_DIR: include_dir::Dir<'_> =
17+
/// Directory of slang shaders from `stensor`.
18+
pub const SLANG_SRC_DIR: include_dir::Dir<'_> =
1919
include_dir::include_dir!("$CARGO_MANIFEST_DIR/shaders");
2020

2121
/// Register all the shaders from this crate (and its dependencies) as modules accessible to the

src/linalg/contiguous.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,10 @@ mod test {
6262
use crate::tensor::GpuTensor;
6363
use minislang::SlangCompiler;
6464
use nalgebra::DMatrix;
65-
use slang_hal::Shader;
6665
use slang_hal::backend::WebGpu;
6766
use slang_hal::backend::{Backend, Encoder};
68-
use wgpu::{BufferUsages, Features, Limits};
67+
use slang_hal::{BufferUsages, Shader};
68+
use wgpu::{Features, Limits};
6969

7070
#[futures_test::test]
7171
#[serial_test::serial]

src/linalg/gemm.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,9 +202,9 @@ mod test {
202202
use approx::relative_eq;
203203
use minislang::SlangCompiler;
204204
use nalgebra::DMatrix;
205-
use slang_hal::Shader;
206205
use slang_hal::backend::{Backend, Encoder, WebGpu};
207-
use wgpu::{BufferUsages, Features, Limits};
206+
use slang_hal::{BufferUsages, Shader};
207+
use wgpu::{Features, Limits};
208208

209209
#[futures_test::test]
210210
#[serial_test::serial]

src/linalg/gemv.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -313,10 +313,10 @@ mod test {
313313
use approx::assert_relative_eq;
314314
use minislang::SlangCompiler;
315315
use nalgebra::{DMatrix, DVector};
316-
use slang_hal::Shader;
317316
use slang_hal::backend::WebGpu;
318317
use slang_hal::backend::{Backend, Encoder};
319-
use wgpu::{BufferUsages, Features, Limits};
318+
use slang_hal::{BufferUsages, Shader};
319+
use wgpu::{Features, Limits};
320320

321321
#[futures_test::test]
322322
#[serial_test::serial]

src/linalg/op_assign.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,10 +159,10 @@ mod test {
159159
use crate::tensor::GpuTensor;
160160
use minislang::SlangCompiler;
161161
use nalgebra::DVector;
162+
use slang_hal::BufferUsages;
162163
use slang_hal::backend::WebGpu;
163164
use slang_hal::backend::{Backend, Buffer, Encoder};
164165
use slang_hal::shader::Shader;
165-
use wgpu::BufferUsages;
166166

167167
#[futures_test::test]
168168
#[serial_test::serial]

0 commit comments

Comments
 (0)