Skip to content

Commit c6ca60b

Browse files
Merge master into feature/cloudformation
2 parents b4ea548 + 747f18c commit c6ca60b

File tree

18 files changed

+1413
-50
lines changed

18 files changed

+1413
-50
lines changed

packages/core/resources/hyperpod_connect

Lines changed: 141 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,25 +6,15 @@
66
# HyperPod is AWS's managed service for distributed machine learning training at scale.
77
#
88
# OVERVIEW:
9-
# The script acts as a wrapper around the AWS SSM CLI to create a secure session tunnel to a HyperPod
10-
# compute instance. It validates required parameters, logs the connection attempt, and executes the
11-
# SSM StartSession command with HyperPod-specific connection details.
12-
#
13-
# REQUIRED ENVIRONMENT VARIABLES:
14-
# AWS_REGION - AWS region where the HyperPod cluster is located (e.g., us-west-2)
15-
# AWS_SSM_CLI - Path to the AWS SSM CLI executable
16-
# STREAM_URL - WebSocket stream URL for the HyperPod session connection
17-
# TOKEN - Authentication token for the HyperPod session
18-
# SESSION_ID - Unique identifier for the HyperPod session
9+
# The script always gets fresh credentials from the reconnection manager and uses them to establish
10+
# a secure session tunnel to a HyperPod compute instance.
1911
#
2012
# OPTIONAL ENVIRONMENT VARIABLES:
2113
# LOG_FILE_LOCATION - Path to log file (default: /tmp/hyperpod_connect.log)
2214
# DEBUG_LOG - Enable debug logging (default: 0)
2315
#
2416
# USAGE:
25-
# AWS_REGION=us-west-2 AWS_SSM_CLI=/usr/local/bin/session-manager-plugin \
26-
# STREAM_URL=wss://... TOKEN=abc123... SESSION_ID=session-xyz \
27-
# ./hyperpod_connect
17+
# ./hyperpod_connect hp_demo1
2818
#
2919
# SECURITY NOTE:
3020
# This script handles sensitive authentication tokens. Ensure proper file permissions
@@ -57,6 +47,28 @@ _require() {
5747
_log "$1=$2"
5848
}
5949

50+
_url_encode() {
51+
python3 -c "import sys, urllib.parse; print(urllib.parse.quote(sys.argv[1]))" "$1"
52+
}
53+
54+
_resolve_connection_key() {
55+
local DEVSPACE_NAME=$1
56+
local PROFILES_FILE="$HOME/.aws/.hyperpod-space-profiles"
57+
58+
[ ! -f "$PROFILES_FILE" ] && echo "$DEVSPACE_NAME" && return
59+
60+
python3 -c "
61+
import json
62+
try:
63+
with open('$PROFILES_FILE', 'r') as f:
64+
profiles = json.load(f)
65+
matches = sorted([k for k in profiles.keys() if k.endswith(':$DEVSPACE_NAME')])
66+
print(matches[0] if matches else '$DEVSPACE_NAME')
67+
except:
68+
print('$DEVSPACE_NAME')
69+
" 2>/dev/null
70+
}
71+
6072
_hyperpod() {
6173
# Function inputs
6274
local AWS_SSM_CLI=$1
@@ -68,19 +80,129 @@ _hyperpod() {
6880
exec "$AWS_SSM_CLI" "{\"streamUrl\":\"$STREAM_URL\",\"tokenValue\":\"$TOKEN\",\"sessionId\":\"$SESSION_ID\"}" "$AWS_REGION" "StartSession"
6981
}
7082

83+
_get_fresh_credentials() {
84+
local CONNECTION_KEY=$1
85+
86+
_log "Getting fresh credentials for connection key: $CONNECTION_KEY"
87+
88+
# Read server info to get port
89+
if [ -z "$SAGEMAKER_LOCAL_SERVER_FILE_PATH" ]; then
90+
_log "Error: SAGEMAKER_LOCAL_SERVER_FILE_PATH environment variable is not set"
91+
exit 1
92+
fi
93+
94+
if [ ! -f "$SAGEMAKER_LOCAL_SERVER_FILE_PATH" ]; then
95+
_log "Error: Server info file not found: $SAGEMAKER_LOCAL_SERVER_FILE_PATH"
96+
exit 1
97+
fi
98+
99+
local PORT=$(jq -r '.port' "$SAGEMAKER_LOCAL_SERVER_FILE_PATH")
100+
if [ -z "$PORT" ] || [ "$PORT" == "null" ]; then
101+
_log "Error: 'port' field is missing or invalid in $SAGEMAKER_LOCAL_SERVER_FILE_PATH"
102+
exit 1
103+
fi
104+
105+
# Call API to get fresh credentials using the determined connection key
106+
local API_URL
107+
if [[ "$CONNECTION_KEY" =~ ^[^:]+:[^:]+:[^:]+$ ]]; then
108+
# Use full connection key for precise lookup (preferred)
109+
API_URL="http://localhost:$PORT/get_hyperpod_session?connection_key=$(_url_encode "$CONNECTION_KEY")"
110+
elif [ -n "$CLUSTER_NAME" ] && [ -n "$NAMESPACE" ]; then
111+
# Use individual parameters if available
112+
API_URL="http://localhost:$PORT/get_hyperpod_session?devspace_name=$(_url_encode "$DEVSPACE_NAME")&namespace=$(_url_encode "$NAMESPACE")&cluster_name=$(_url_encode "$CLUSTER_NAME")"
113+
else
114+
# Fallback for legacy format
115+
API_URL="http://localhost:$PORT/get_hyperpod_session?connection_key=$(_url_encode "$CONNECTION_KEY")"
116+
fi
117+
API_RESPONSE=$(curl -s "$API_URL")
118+
119+
if [ $? -ne 0 ] || [ -z "$API_RESPONSE" ]; then
120+
_log "Error: Failed to get credentials from API"
121+
exit 1
122+
fi
123+
124+
# Parse JSON once and check for error response
125+
read -r STATUS ERROR_MSG CONNECTION_URL < <(echo "$API_RESPONSE" | python3 -c "import sys, json; data=json.load(sys.stdin); print(data.get('status', 'success'), data.get('message', ''), data.get('connection', {}).get('url', ''))" 2>/dev/null) || {
126+
_log "Error: Failed to parse API response JSON"
127+
exit 1
128+
}
129+
if [ "$STATUS" = "error" ]; then
130+
_log "Error from API: ${ERROR_MSG:-Unknown error}"
131+
exit 1
132+
fi
133+
134+
_log "Fresh credentials obtained from API"
135+
}
136+
71137
_main() {
72138
# Set defaults for missing environment variables
73139
DEBUG_LOG=${DEBUG_LOG:-0}
74140
LOG_FILE_LOCATION=${LOG_FILE_LOCATION:-/tmp/hyperpod_connect.log}
75141

142+
if [ $# -ne 1 ]; then
143+
_log "Usage: $0 hp_<devspace_name>"
144+
exit 1
145+
fi
146+
147+
# Extract devspace name, cluster name, and namespace from hostname
148+
# New format: hp_{cluster_name}_{namespace}_{space_name}_{region}_{account_id}
149+
# Old format: hp_{devspace} (for backward compatibility)
150+
if [[ "$1" =~ ^hp_([^_]+)_([^_]+)_([^_]+)_([^_]+)_([^_]+)$ ]]; then
151+
# New format: cluster_namespace_space_region_account
152+
CLUSTER_NAME="${BASH_REMATCH[1]}"
153+
NAMESPACE="${BASH_REMATCH[2]}"
154+
DEVSPACE_NAME="${BASH_REMATCH[3]}"
155+
# Construct the expected connection key (cluster:namespace:devspace)
156+
CONNECTION_KEY="${CLUSTER_NAME}:${NAMESPACE}:${DEVSPACE_NAME}"
157+
else
158+
# Old format or fallback - extract devspace name only
159+
DEVSPACE_NAME=$(echo "$1" | sed 's/hp_//')
160+
CLUSTER_NAME=""
161+
NAMESPACE=""
162+
CONNECTION_KEY=""
163+
fi
164+
165+
# For new format, we already have the connection key constructed from hostname
166+
# For old format, look up the connection key from the hyperpod profiles file
167+
[ -z "$CONNECTION_KEY" ] && CONNECTION_KEY=$(_resolve_connection_key "$DEVSPACE_NAME")
168+
169+
if [ -z "$CONNECTION_KEY" ]; then
170+
_log "Error: Could not determine connection key for devspace: $DEVSPACE_NAME"
171+
exit 1
172+
fi
173+
76174
_log "=============================================================================="
77-
_require AWS_REGION "${AWS_REGION:-}"
78-
_require AWS_SSM_CLI "${AWS_SSM_CLI:-}"
79-
_require SESSION_ID "${SESSION_ID:-}"
80-
_require_nolog STREAM_URL "${STREAM_URL:-}"
81-
_require_nolog TOKEN "${TOKEN:-}"
175+
_log "Connecting to HyperPod devspace: $DEVSPACE_NAME (connection key: $CONNECTION_KEY)"
176+
177+
# Always get fresh credentials (CONNECTION_URL is set by _get_fresh_credentials)
178+
_get_fresh_credentials "$CONNECTION_KEY"
179+
180+
if [[ "$CONNECTION_URL" =~ ^wss:// ]]; then
181+
# Presigned URL format - extract session ID from path and token from cell-number param
182+
SESSION_ID=$(echo "$CONNECTION_URL" | python3 -c "import sys, urllib.parse; url=sys.stdin.read().strip(); path=urllib.parse.urlparse(url).path; session_id=path.split('/')[-1] if path else ''; print(session_id)" 2>/dev/null)
183+
TOKEN=$(echo "$CONNECTION_URL" | python3 -c "import sys, urllib.parse; url=sys.stdin.read().strip(); params=urllib.parse.parse_qs(urllib.parse.urlparse(url).query); print(params.get('cell-number', [''])[0])" 2>/dev/null)
184+
STREAM_URL="$CONNECTION_URL"
185+
else
186+
# Kubectl format - extract from query params
187+
SESSION_ID=$(echo "$CONNECTION_URL" | python3 -c "import sys, urllib.parse; url=sys.stdin.read().strip(); params=urllib.parse.parse_qs(urllib.parse.urlparse(url).query); print(params.get('sessionId', [''])[0])" 2>/dev/null)
188+
TOKEN=$(echo "$CONNECTION_URL" | python3 -c "import sys, urllib.parse; url=sys.stdin.read().strip(); params=urllib.parse.parse_qs(urllib.parse.urlparse(url).query); print(params.get('sessionToken', [''])[0])" 2>/dev/null)
189+
STREAM_URL=$(echo "$CONNECTION_URL" | python3 -c "import sys, urllib.parse; url=sys.stdin.read().strip(); params=urllib.parse.parse_qs(urllib.parse.urlparse(url).query); print(params.get('streamUrl', [''])[0])" 2>/dev/null)
190+
fi
191+
192+
# Extract region from stream URL
193+
AWS_REGION=$(echo "$STREAM_URL" | grep -o '\.[a-z0-9-]*\.amazonaws\.com' | sed 's/^\.//;s/\.amazonaws\.com$//')
194+
AWS_SSM_CLI="${AWS_SSM_CLI:-session-manager-plugin}"
195+
196+
# Validate required parameters (SESSION_ID, STREAM_URL, TOKEN are fetched from API, not validated here)
197+
_require AWS_REGION "${AWS_REGION}"
198+
_require AWS_SSM_CLI "${AWS_SSM_CLI}"
199+
200+
if [ -z "${SESSION_ID}" ] || [ -z "${STREAM_URL}" ] || [ -z "${TOKEN}" ]; then
201+
_log "Error: Failed to retrieve valid session credentials from API"
202+
exit 1
203+
fi
82204

83205
_hyperpod "$AWS_SSM_CLI" "$AWS_REGION" "$STREAM_URL" "$TOKEN" "$SESSION_ID"
84206
}
85207

86-
_main
208+
_main "$@"
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
# HyperPod Connection Script (PowerShell)
2+
#
3+
# This script establishes a connection to an AWS SageMaker HyperPod instance using AWS Systems Manager (SSM).
4+
5+
param(
6+
[Parameter(Mandatory=$true)]
7+
[string]$HostName
8+
)
9+
10+
$ErrorActionPreference = "Stop"
11+
12+
$DebugLog = $env:DEBUG_LOG -eq "1"
13+
$LogFileLocation = if ($env:LOG_FILE_LOCATION) { $env:LOG_FILE_LOCATION } else { "$env:TEMP\hyperpod_connect.log" }
14+
15+
function Write-Log {
16+
param([string]$Message)
17+
$timestamp = Get-Date -Format "yyyy/MM/dd HH:mm:ss"
18+
"$timestamp $Message" | Out-File -FilePath $LogFileLocation -Append -Encoding utf8
19+
}
20+
21+
function Get-FreshCredentials {
22+
param([string]$ConnectionKey)
23+
24+
Write-Log "Getting fresh credentials for connection key: $ConnectionKey"
25+
26+
# Read server info to get port
27+
$serverInfoFile = "$env:APPDATA\Code\User\globalStorage\amazonwebservices.aws-toolkit-vscode\sagemaker-local-server-info.json"
28+
if (-not (Test-Path $serverInfoFile)) {
29+
Write-Log "Error: Server info file not found: $serverInfoFile"
30+
exit 1
31+
}
32+
33+
$serverInfo = Get-Content $serverInfoFile | ConvertFrom-Json
34+
$port = $serverInfo.port
35+
36+
if (-not $port) {
37+
Write-Log "Error: Could not extract port from server info file"
38+
exit 1
39+
}
40+
41+
# Call API to get fresh credentials
42+
$apiUrl = "http://localhost:$port/get_hyperpod_session?connection_key=$ConnectionKey"
43+
44+
try {
45+
$response = Invoke-RestMethod -Uri $apiUrl -Method Get
46+
Write-Log "Fresh credentials obtained from API"
47+
return $response
48+
} catch {
49+
Write-Log "Error: Failed to get credentials from API: $_"
50+
exit 1
51+
}
52+
}
53+
54+
function Main {
55+
Write-Log "=============================================================================="
56+
57+
# Parse hostname format: hp_{cluster_name}_{namespace}_{space_name}_{region}_{account_id}
58+
if ($HostName -match '^hp_([^_]+)_([^_]+)_([^_]+)_([^_]+)_([^_]+)$') {
59+
$clusterName = $Matches[1]
60+
$namespace = $Matches[2]
61+
$devspaceName = $Matches[3]
62+
$connectionKey = "${clusterName}:${namespace}:${devspaceName}"
63+
} else {
64+
# Old format fallback
65+
$devspaceName = $HostName -replace '^hp_', ''
66+
$profilesFile = "$env:USERPROFILE\.aws\.hyperpod-space-profiles"
67+
68+
if (Test-Path $profilesFile) {
69+
$profiles = Get-Content $profilesFile | ConvertFrom-Json
70+
$matches = $profiles.PSObject.Properties.Name | Where-Object { $_ -match ":$devspaceName$" } | Sort-Object
71+
$connectionKey = if ($matches) { $matches[0] } else { $devspaceName }
72+
} else {
73+
$connectionKey = $devspaceName
74+
}
75+
}
76+
77+
if (-not $connectionKey) {
78+
Write-Log "Error: Could not determine connection key for devspace: $devspaceName"
79+
exit 1
80+
}
81+
82+
Write-Log "Connecting to HyperPod devspace: $devspaceName (connection key: $connectionKey)"
83+
84+
# Get fresh credentials
85+
$apiResponse = Get-FreshCredentials -ConnectionKey $connectionKey
86+
87+
# Parse connection URL
88+
$connectionUrl = [System.Web.HttpUtility]::HtmlDecode($apiResponse.connection.url)
89+
$uri = [System.Uri]$connectionUrl
90+
$queryParams = [System.Web.HttpUtility]::ParseQueryString($uri.Query)
91+
92+
$sessionId = $queryParams['sessionId']
93+
$token = $queryParams['sessionToken'] -replace ' ', '+'
94+
$streamUrl = [System.Web.HttpUtility]::UrlDecode($queryParams['streamUrl']) -replace ' ', '+'
95+
96+
# Add cell-number if present (and fix spaces)
97+
$cellNumber = $queryParams['cell-number']
98+
if ($cellNumber) {
99+
$cellNumberDecoded = [System.Web.HttpUtility]::UrlDecode($cellNumber) -replace ' ', '+'
100+
$streamUrl += "&cell-number=$cellNumberDecoded"
101+
}
102+
103+
# Extract region from stream URL
104+
if ($streamUrl -match '\.([a-z0-9-]+)\.amazonaws\.com') {
105+
$awsRegion = $Matches[1]
106+
} else {
107+
Write-Log "Error: Could not extract region from stream URL"
108+
exit 1
109+
}
110+
111+
# Find session-manager-plugin
112+
$awsSsmCli = $env:AWS_SSM_CLI
113+
if (-not $awsSsmCli) {
114+
# Try bundled version first
115+
$bundledPath = "$env:APPDATA\Code\User\globalStorage\amazonwebservices.aws-toolkit-vscode\tools\Amazon\sessionmanagerplugin\bin\session-manager-plugin.exe"
116+
if (Test-Path $bundledPath) {
117+
$awsSsmCli = $bundledPath
118+
} else {
119+
# Fallback to PATH
120+
$awsSsmCli = "session-manager-plugin"
121+
}
122+
}
123+
124+
Write-Log "AWS_REGION=$awsRegion"
125+
Write-Log "AWS_SSM_CLI=$awsSsmCli"
126+
Write-Log "SESSION_ID=$sessionId"
127+
128+
# Execute session-manager-plugin with proper JSON escaping (same as Studio script)
129+
& $awsSsmCli "{\`"streamUrl\`":\`"${streamUrl}\`",\`"tokenValue\`":\`"${token}\`",\`"sessionId\`":\`"${sessionId}\`"}" "$awsRegion" "StartSession"
130+
}
131+
132+
# Load required assembly for URL decoding
133+
Add-Type -AssemblyName System.Web
134+
135+
Main

packages/core/src/awsService/sagemaker/commands.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,10 @@ export async function deeplinkConnect(
146146
wsUrl,
147147
token,
148148
domain,
149-
appType
149+
appType,
150+
workspaceName,
151+
undefined,
152+
namespace
150153
)
151154

152155
try {

0 commit comments

Comments
 (0)