generated from amazon-archives/__template_Apache-2.0
-
Notifications
You must be signed in to change notification settings - Fork 144
Expand file tree
/
Copy path: nightly.py
More file actions
119 lines (90 loc) · 3.56 KB
/
nightly.py
File metadata and controls
119 lines (90 loc) · 3.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from __future__ import annotations
import dataclasses
import datetime
import os
import subprocess
import sys
from datetime import timezone
from pathlib import Path
import tyro
from holosoma.config_values.experiment import AnnotatedExperimentConfig
from holosoma.train_agent import training_context
from holosoma.utils.tyro_utils import TYRO_CONIFG
# Absolute path to the repository root, three levels up from this file
# (presumably <repo>/<pkg>/<subdir>/nightly.py — TODO confirm layout).
# NOTE(review): not referenced anywhere in this file; likely consumed by
# other modules importing this one — verify before removing.
REPO_ROOT = Path(__file__).parent.parent.parent.absolute()
def now_timestamp() -> str:
    """Return the current UTC time as a ``YYYYMMDD_HHMMSS`` string.

    Used to give each nightly wandb run a unique, sortable name.
    """
    utc_now = datetime.datetime.now(timezone.utc)
    return utc_now.strftime("%Y%m%d_%H%M%S")
def validate_wandb_metrics(config: AnnotatedExperimentConfig):
    """Validate the finished wandb run's metrics against configured ranges.

    For every entry in ``config.nightly.metrics`` (metric name -> (min, max)),
    the mean of the last 100 logged values must lie within the inclusive
    range. The run is then tagged ``nightly_test_passed`` or
    ``nightly_test_failed`` so nightly dashboards can filter on the outcome.

    Args:
        config: The experiment config; ``config.nightly.metrics`` must be set.
    """
    # lazy import to avoid conflicts with Isaac
    import wandb

    assert wandb.run is not None, "wandb run failed! wandb.run is `None`"
    api = wandb.Api()
    run = api.run(f"{wandb.run.entity}/{wandb.run.project}/{wandb.run.id}")
    df_hist = run.history()

    failures: list[str] = []
    assert config.nightly is not None  # for type checking
    assert config.nightly.metrics is not None
    for k, v in config.nightly.metrics.items():
        v_min = float(v[0])
        v_max = float(v[1])
        # Average over the tail of the history so early-training transients
        # do not dominate the pass/fail decision.
        v_last_100 = df_hist[k][-100:].mean()
        if not (v_min <= v_last_100 <= v_max):
            msg = f"Metric {k}={v_last_100:0.2f} is not in range ({v_min}, {v_max})"
            print(msg)
            failures.append(msg)

    if len(failures) > 0:
        print(f"Some tests failed! Metrics outside of expected ranges: {failures}")
    # Tag and update once instead of duplicating the update call per branch.
    run.tags += ("nightly_test_failed",) if failures else ("nightly_test_passed",)
    run.update()
def main():
    """Run the nightly regression: train with a fixed seed, then validate
    the logged wandb metrics against the ranges declared in the config."""
    config = tyro.cli(AnnotatedExperimentConfig, config=TYRO_CONIFG)

    # Multi-GPU requested but we are not yet a torchrun worker (no RANK in
    # the environment): re-exec this script under torchrun, forwarding the
    # original CLI arguments, and exit with the child's status.
    if config.training.multigpu and "RANK" not in os.environ:
        completed = subprocess.run(
            ["torchrun", "--nproc_per_node=4", __file__, *sys.argv[1:]],
            env=os.environ.copy(),
            check=False,
        )
        sys.exit(completed.returncode)

    # Pin the seed so nightly runs are comparable across days.
    config = dataclasses.replace(config, training=dataclasses.replace(config.training, seed=42))

    # Derive the experiment name from the config instead of hydra runtime
    # choices, then strip characters wandb forbids in project names.
    raw_name = config.training.name or config.logger.name
    safe_name = raw_name
    for forbidden in ("/", "\\", "#", "?", "%", ":"):
        safe_name = safe_name.replace(forbidden, "-")

    gpu_suffix = "-multigpu" if config.training.multigpu else ""

    config = config.get_nightly_config()

    run_tags: list[str] = []
    if os.getenv("GITHUB_RUN_ID"):
        # Link the wandb run back to the GitHub Actions run that produced it.
        run_tags.append(f"gha-run-id-{os.getenv('GITHUB_RUN_ID')}")

    config = dataclasses.replace(
        config,
        logger=dataclasses.replace(
            config.logger,
            project=f"nightly-{safe_name}{gpu_suffix}",
            name=f"nightly-{safe_name}{gpu_suffix}-{now_timestamp()}",
            tags=tuple(run_tags),
        ),
    )

    with training_context(config) as ctx:
        # 1. Train
        ctx.train()
        # 2. Validate metrics (explicit, linear flow) - only on rank 0
        if os.environ.get("RANK", "0") == "0":
            validate_wandb_metrics(config)
        # simulation_app is closed automatically when the context exits.
# Script entry point: run the nightly training + metric-validation pipeline.
if __name__ == "__main__":
    main()