Skip to content

Commit 3f2412b

Browse files
authored
Create install_pytorch.ps1
1 parent daae258 commit 3f2412b

File tree

1 file changed

+30
-0
lines changed

1 file changed

+30
-0
lines changed

.github/tools/install_pytorch.ps1

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Get versions from environment variables, with defaults
2+
$cuda_version = $env:CUDA_VERSION
3+
if (-not $cuda_version) { $cuda_version = "12.1" }
4+
5+
$pytorch_version = $env:PYTORCH_VERSION
6+
if (-not $pytorch_version) { $pytorch_version = "latest" }
7+
8+
# Determine CUDA short version for wheel index
9+
$cuda_short = switch ($cuda_version) {
10+
"11.8" { "cu118" }
11+
"12.1" { "cu121" }
12+
"12.4" { "cu124" }
13+
"12.6" { "cu126" }
14+
"12.8" { "cu128" }
15+
default {
16+
Write-Error "Unsupported CUDA version: $cuda_version"
17+
exit 1
18+
}
19+
}
20+
21+
$index_url = "https://download.pytorch.org/whl/$cuda_short"
22+
Write-Host "PyTorch wheel index: $index_url"
23+
24+
if ($pytorch_version -eq "latest") {
25+
Write-Host "Installing latest PyTorch for CUDA $cuda_version"
26+
pip install torch torchvision --index-url $index_url
27+
} else {
28+
Write-Host "Installing PyTorch $pytorch_version for CUDA $cuda_version"
29+
pip install "torch==$pytorch_version" torchvision --index-url $index_url
30+
}

0 commit comments

Comments
 (0)