|
1 | 1 | import os |
| 2 | +import re |
2 | 3 | import shutil |
3 | 4 | import subprocess |
4 | 5 | import sys |
|
8 | 9 | from huggingface_hub import snapshot_download |
9 | 10 | from huggingface_hub.utils import disable_progress_bars |
10 | 11 |
|
| 12 | +from kernels.compat import tomllib |
| 13 | +from kernels.utils import KNOWN_BACKENDS |
| 14 | + |
11 | 15 |
|
12 | 16 | def run_init(args: Namespace) -> None: |
13 | 17 | kernel_name = args.kernel_name |
| 18 | + if args.backends is None: |
| 19 | + backends = ["metal"] if sys.platform == "darwin" else ["cuda"] |
| 20 | + else: |
| 21 | + backends = [ |
| 22 | + v.strip().lower() |
| 23 | + for item in args.backends |
| 24 | + for v in item.split(",") |
| 25 | + if v.strip() |
| 26 | + ] |
| 27 | + if "all" in backends: |
| 28 | + if len(backends) > 1: |
| 29 | + print( |
| 30 | + "Error: --backends must be either 'all' or a list of backends.", |
| 31 | + file=sys.stderr, |
| 32 | + ) |
| 33 | + sys.exit(1) |
| 34 | + backends = [] |
| 35 | + else: |
| 36 | + valid = set(KNOWN_BACKENDS) |
| 37 | + invalid = sorted(set(backends) - valid) |
| 38 | + if invalid: |
| 39 | + print( |
| 40 | + f"Error: invalid backend(s): {', '.join(invalid)}. Valid values are: {', '.join(sorted(valid))}.", |
| 41 | + file=sys.stderr, |
| 42 | + ) |
| 43 | + sys.exit(1) |
| 44 | + seen: set[str] = set() |
| 45 | + backends = [b for b in backends if not (b in seen or seen.add(b))] |
14 | 46 | # must be fully qualified repo name <owner>/<repo> |
15 | 47 | owner_repo = kernel_name.split("/") |
16 | 48 | if len(owner_repo) != 2: |
@@ -52,6 +84,9 @@ def run_init(args: Namespace) -> None: |
52 | 84 | _init_from_local_template( |
53 | 85 | template_dir, target_dir, kernel_name, kernel_name_normalized, repo_id |
54 | 86 | ) |
| 87 | + if backends: |
| 88 | + _update_build_backends(target_dir / "build.toml", backends) |
| 89 | + _remove_backend_dirs(target_dir, backends) |
55 | 90 |
|
56 | 91 | # Initialize git repo (required for Nix flakes) |
57 | 92 | subprocess.run(["git", "init"], cwd=target_dir, check=True, capture_output=True) |
@@ -140,3 +175,58 @@ def _print_tree(directory: Path, prefix: str = "") -> None: |
140 | 175 | if entry.is_dir(): |
141 | 176 | extension = " " if is_last else "│ " |
142 | 177 | _print_tree(entry, prefix + extension) |
| 178 | + |
| 179 | + |
| 180 | +def _update_build_backends(build_toml_path: Path, backends: list[str]) -> None: |
| 181 | + if not build_toml_path.exists(): |
| 182 | + return |
| 183 | + text = build_toml_path.read_text() |
| 184 | + with open(build_toml_path, "rb") as f: |
| 185 | + data = tomllib.load(f) |
| 186 | + if "general" not in data: |
| 187 | + return |
| 188 | + kernel_table = data.get("kernel", {}) |
| 189 | + if not isinstance(kernel_table, dict): |
| 190 | + kernel_table = {} |
| 191 | + remove_kernels = { |
| 192 | + name |
| 193 | + for name, cfg in kernel_table.items() |
| 194 | + if isinstance(cfg, dict) and cfg.get("backend") not in set(backends) |
| 195 | + } |
| 196 | + backends_list = ", ".join(f'"{b}"' for b in backends) |
| 197 | + new_line = f"backends = [{backends_list}]" |
| 198 | + pattern = r"(\[general\][\s\S]*?)^\s*backends\s*=\s*\[[^\]]*\]" |
| 199 | + new_text, count = re.subn(pattern, r"\1" + new_line, text, count=1, flags=re.M) |
| 200 | + if remove_kernels: |
| 201 | + new_text = _remove_kernel_sections(new_text, remove_kernels) |
| 202 | + if count or remove_kernels: |
| 203 | + build_toml_path.write_text(new_text) |
| 204 | + |
| 205 | + |
| 206 | +def _remove_kernel_sections(text: str, remove_kernels: set[str]) -> str: |
| 207 | + lines = text.splitlines(keepends=True) |
| 208 | + output: list[str] = [] |
| 209 | + skip = False |
| 210 | + for line in lines: |
| 211 | + match = re.match(r"^\s*\[kernel\.([^\]]+)\]\s*$", line) |
| 212 | + if match: |
| 213 | + skip = match.group(1).strip() in remove_kernels |
| 214 | + if skip: |
| 215 | + continue |
| 216 | + if skip and re.match(r"^\s*\[[^\]]+\]\s*$", line): |
| 217 | + skip = False |
| 218 | + if not skip: |
| 219 | + output.append(line) |
| 220 | + return "".join(output) |
| 221 | + |
| 222 | + |
| 223 | +def _remove_backend_dirs(target_dir: Path, backends: list[str]) -> None: |
| 224 | + keep = set(backends) |
| 225 | + known = set(KNOWN_BACKENDS) |
| 226 | + for entry in target_dir.iterdir(): |
| 227 | + if not entry.is_dir(): |
| 228 | + continue |
| 229 | + for backend in known - keep: |
| 230 | + if entry.name.endswith(f"_{backend}"): |
| 231 | + shutil.rmtree(entry) |
| 232 | + break |
0 commit comments