Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ Run `uv run script.py --help` for the full list. Common options include:
| `--partitions` | Partition training by `deck` or `preset`. | Default: `none` |
| `--data` | Path to `revlogs/*.parquet`. | Default: `../anki-revlogs-10k` |
| `--processes` | Number of worker processes. | Default: `8` |
| `--gpus` | CUDA device IDs to assign to workers (e.g., `0,1` or `all`). | Default: unset |
| `--max-user-id` | Maximum user ID to process (inclusive). | No limit |
| `--n_splits` | Number of TimeSeriesSplit folds. | Default: `5` |
| `--train_equals_test` | Train and test on the same data. | Off |
Expand All @@ -330,4 +331,4 @@ To pretrain LSTM on multiple users, run:

```bash
uv run pretrain.py --algo LSTM
```
```
34 changes: 34 additions & 0 deletions config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import re
import torch
from pathlib import Path
from typing import List, Optional, Literal, get_args
Expand Down Expand Up @@ -46,6 +47,11 @@ def create_parser():
parser.add_argument(
"--processes", default=8, type=int, help="set the number of processes"
)
parser.add_argument(
"--gpus",
default=None,
help="comma/space-separated CUDA device IDs to use (e.g., '0,1' or 'all')",
)
parser.add_argument("--dev", action="store_true", help="for local development")

# Add this line:
Expand Down Expand Up @@ -154,6 +160,33 @@ def create_parser():
return parser


def _parse_cuda_devices(raw: Optional[str]) -> Optional[List[int]]:
if raw is None:
return None
value = raw.strip()
if value == "":
return None
value_lower = value.lower()
if value_lower in {"all", "*"}:
if not torch.cuda.is_available():
return []
return list(range(torch.cuda.device_count()))

parts = [p for p in re.split(r"[,\s]+", value) if p]
device_ids: List[int] = []
for part in parts:
try:
device_id = int(part)
except ValueError as exc:
raise ValueError(
f"Invalid CUDA device id '{part}'. Use comma/space-separated integers."
) from exc
if device_id < 0:
raise ValueError("CUDA device IDs must be >= 0.")
device_ids.append(device_id)
return device_ids


class Config:
"""Holds all application configurations derived from command-line arguments and defaults."""

Expand Down Expand Up @@ -182,6 +215,7 @@ def __init__(self, args: argparse.Namespace):
self.data_path: Path = Path(args.data)
self.use_recency_weighting: bool = args.recency
self.train_equals_test: bool = args.train_equals_test
self.cuda_device_ids: Optional[List[int]] = _parse_cuda_devices(args.gpus)

# Training/data parameters from parser (with defaults)
self.n_splits: int = args.n_splits
Expand Down
47 changes: 45 additions & 2 deletions script.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,10 +198,34 @@ def plot(self):
return fig


def _configure_process_device(device_id: Optional[int]) -> None:
    """Pin the current worker process to the given CUDA device.

    No-op when ``device_id`` is None, CUDA is unavailable, or the configured
    device is not a CUDA device — workers then keep the default device.

    Args:
        device_id: zero-based CUDA device index assigned to this worker.

    Raises:
        ValueError: if ``device_id`` is outside the visible device range.
    """
    if device_id is None:
        return
    if not torch.cuda.is_available():
        return
    if config.device.type != "cuda":
        return
    device_count = torch.cuda.device_count()
    if device_id < 0 or device_id >= device_count:
        raise ValueError(
            f"Invalid CUDA device id {device_id}. Available range: 0..{device_count - 1}"
        )
    torch.cuda.set_device(device_id)
    config.device = torch.device(f"cuda:{device_id}")
    if config.model_name == "LSTM":
        # The LSTM trainer caches the device in a module-level global, so it
        # must be updated to match the per-worker device assignment.
        try:
            import reptile_trainer

            reptile_trainer.DEVICE = config.device
        except (ImportError, AttributeError) as exc:
            # Narrow, logged handling: a blanket `except Exception: pass`
            # would silently leave the trainer on the wrong device.
            print(f"Warning: could not configure device for reptile_trainer: {exc}")
Comment on lines +220 to +221
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Using a broad except Exception: pass is risky as it can hide important errors, such as ImportError if reptile_trainer is not found or AttributeError if DEVICE is not a member. This could lead to the model silently running on the wrong device. It's better to catch specific exceptions and log a warning to make debugging easier.

Suggested change
except Exception:
pass
except (ImportError, AttributeError) as e:
print(f"Warning: Could not configure device for reptile_trainer: {e}")



@catch_exceptions
def process(user_id: int) -> tuple[dict, Optional[dict]]:
def process(user_id: int, device_id: Optional[int] = None) -> tuple[dict, Optional[dict]]:
"""Main processing function for all models."""
plt.close("all")
_configure_process_device(device_id)

# Load data once for all models
data_loader = UserDataLoader(config)
Expand Down Expand Up @@ -372,13 +396,32 @@ def process(user_id: int) -> tuple[dict, Optional[dict]]:

unprocessed_users.sort()

cuda_device_ids = None
if config.cuda_device_ids:
if config.device.type != "cuda":
print("Warning: --gpus ignored because CUDA is not enabled for this model.")
else:
device_count = torch.cuda.device_count()
invalid = [i for i in config.cuda_device_ids if i >= device_count]
if invalid:
raise ValueError(
"Invalid CUDA device IDs "
f"{invalid}; available range is 0..{device_count - 1}"
)
cuda_device_ids = config.cuda_device_ids
if config.num_processes > len(cuda_device_ids):
print(
"Warning: --processes exceeds --gpus; multiple workers will share GPUs."
)

with ProcessPoolExecutor(max_workers=config.num_processes) as executor:
futures = [
executor.submit(
process,
user_id,
cuda_device_ids[idx % len(cuda_device_ids)] if cuda_device_ids else None,
)
for user_id in unprocessed_users
for idx, user_id in enumerate(unprocessed_users)
]
for future in (
pbar := tqdm(as_completed(futures), total=len(futures), smoothing=0.03)
Expand Down
Loading