Skip to content

Commit e7d3188

Browse files
committed
fix: detect and default template for platform
1 parent 0645ecb commit e7d3188

File tree

3 files changed

+97
-0
lines changed

3 files changed

+97
-0
lines changed

kernels/src/kernels/cli.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,12 @@ def main():
162162
default="drbh/template",
163163
help="HuggingFace repo ID for the template",
164164
)
165+
init_parser.add_argument(
166+
"--backends",
167+
nargs="+",
168+
default=None,
169+
help="Backends to include ('all' or list like: cpu cuda metal rocm xpu). Defaults: cuda on Linux/Windows, metal on macOS.",
170+
)
165171
init_parser.set_defaults(func=run_init)
166172

167173
args = parser.parse_args()

kernels/src/kernels/init.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import re
23
import shutil
34
import subprocess
45
import sys
@@ -8,9 +9,40 @@
89
from huggingface_hub import snapshot_download
910
from huggingface_hub.utils import disable_progress_bars
1011

12+
from kernels.compat import tomllib
13+
from kernels.utils import KNOWN_BACKENDS
14+
1115

1216
def run_init(args: Namespace) -> None:
1317
kernel_name = args.kernel_name
18+
if args.backends is None:
19+
backends = ["metal"] if sys.platform == "darwin" else ["cuda"]
20+
else:
21+
backends = [
22+
v.strip().lower()
23+
for item in args.backends
24+
for v in item.split(",")
25+
if v.strip()
26+
]
27+
if "all" in backends:
28+
if len(backends) > 1:
29+
print(
30+
"Error: --backends must be either 'all' or a list of backends.",
31+
file=sys.stderr,
32+
)
33+
sys.exit(1)
34+
backends = []
35+
else:
36+
valid = set(KNOWN_BACKENDS)
37+
invalid = sorted(set(backends) - valid)
38+
if invalid:
39+
print(
40+
f"Error: invalid backend(s): {', '.join(invalid)}. Valid values are: {', '.join(sorted(valid))}.",
41+
file=sys.stderr,
42+
)
43+
sys.exit(1)
44+
seen: set[str] = set()
45+
backends = [b for b in backends if not (b in seen or seen.add(b))]
1446
# must be fully qualified repo name <owner>/<repo>
1547
owner_repo = kernel_name.split("/")
1648
if len(owner_repo) != 2:
@@ -52,6 +84,9 @@ def run_init(args: Namespace) -> None:
5284
_init_from_local_template(
5385
template_dir, target_dir, kernel_name, kernel_name_normalized, repo_id
5486
)
87+
if backends:
88+
_update_build_backends(target_dir / "build.toml", backends)
89+
_remove_backend_dirs(target_dir, backends)
5590

5691
# Initialize git repo (required for Nix flakes)
5792
subprocess.run(["git", "init"], cwd=target_dir, check=True, capture_output=True)
@@ -140,3 +175,58 @@ def _print_tree(directory: Path, prefix: str = "") -> None:
140175
if entry.is_dir():
141176
extension = " " if is_last else "│ "
142177
_print_tree(entry, prefix + extension)
178+
179+
180+
def _update_build_backends(build_toml_path: Path, backends: list[str]) -> None:
181+
if not build_toml_path.exists():
182+
return
183+
text = build_toml_path.read_text()
184+
with open(build_toml_path, "rb") as f:
185+
data = tomllib.load(f)
186+
if "general" not in data:
187+
return
188+
kernel_table = data.get("kernel", {})
189+
if not isinstance(kernel_table, dict):
190+
kernel_table = {}
191+
remove_kernels = {
192+
name
193+
for name, cfg in kernel_table.items()
194+
if isinstance(cfg, dict) and cfg.get("backend") not in set(backends)
195+
}
196+
backends_list = ", ".join(f'"{b}"' for b in backends)
197+
new_line = f"backends = [{backends_list}]"
198+
pattern = r"(\[general\][\s\S]*?)^\s*backends\s*=\s*\[[^\]]*\]"
199+
new_text, count = re.subn(pattern, r"\1" + new_line, text, count=1, flags=re.M)
200+
if remove_kernels:
201+
new_text = _remove_kernel_sections(new_text, remove_kernels)
202+
if count or remove_kernels:
203+
build_toml_path.write_text(new_text)
204+
205+
206+
def _remove_kernel_sections(text: str, remove_kernels: set[str]) -> str:
207+
lines = text.splitlines(keepends=True)
208+
output: list[str] = []
209+
skip = False
210+
for line in lines:
211+
match = re.match(r"^\s*\[kernel\.([^\]]+)\]\s*$", line)
212+
if match:
213+
skip = match.group(1).strip() in remove_kernels
214+
if skip:
215+
continue
216+
if skip and re.match(r"^\s*\[[^\]]+\]\s*$", line):
217+
skip = False
218+
if not skip:
219+
output.append(line)
220+
return "".join(output)
221+
222+
223+
def _remove_backend_dirs(target_dir: Path, backends: list[str]) -> None:
224+
keep = set(backends)
225+
known = set(KNOWN_BACKENDS)
226+
for entry in target_dir.iterdir():
227+
if not entry.is_dir():
228+
continue
229+
for backend in known - keep:
230+
if entry.name.endswith(f"_{backend}"):
231+
shutil.rmtree(entry)
232+
break

kernels/src/kernels/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from kernels.metadata import Metadata
2323

2424
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
25+
KNOWN_BACKENDS = ("cpu", "cuda", "metal", "rocm", "xpu", "npu")
2526

2627

2728
def _get_cache_dir() -> str | None:

0 commit comments

Comments
 (0)