66# HyperPod is AWS's managed service for distributed machine learning training at scale.
77#
88# OVERVIEW:
9- # The script acts as a wrapper around the AWS SSM CLI to create a secure session tunnel to a HyperPod
10- # compute instance. It validates required parameters, logs the connection attempt, and executes the
11- # SSM StartSession command with HyperPod-specific connection details.
12- #
13- # REQUIRED ENVIRONMENT VARIABLES:
14- # AWS_REGION - AWS region where the HyperPod cluster is located (e.g., us-west-2)
15- # AWS_SSM_CLI - Path to the AWS SSM CLI executable
16- # STREAM_URL - WebSocket stream URL for the HyperPod session connection
17- # TOKEN - Authentication token for the HyperPod session
18- # SESSION_ID - Unique identifier for the HyperPod session
9+ # The script always gets fresh credentials from the reconnection manager and uses them to establish
10+ # a secure session tunnel to a HyperPod compute instance.
1911#
2012# OPTIONAL ENVIRONMENT VARIABLES:
2113# LOG_FILE_LOCATION - Path to log file (default: /tmp/hyperpod_connect.log)
2214# DEBUG_LOG - Enable debug logging (default: 0)
2315#
2416# USAGE:
25- # AWS_REGION=us-west-2 AWS_SSM_CLI=/usr/local/bin/session-manager-plugin \
26- # STREAM_URL=wss://... TOKEN=abc123... SESSION_ID=session-xyz \
27- # ./hyperpod_connect
17+ # ./hyperpod_connect hp_demo1
2818#
2919# SECURITY NOTE:
3020# This script handles sensitive authentication tokens. Ensure proper file permissions
@@ -57,6 +47,28 @@ _require() {
5747 _log " $1 =$2 "
5848}
5949
50+ _url_encode () {
51+ python3 -c " import sys, urllib.parse; print(urllib.parse.quote(sys.argv[1]))" " $1 "
52+ }
53+
54+ _resolve_connection_key () {
55+ local DEVSPACE_NAME=$1
56+ local PROFILES_FILE=" $HOME /.aws/.hyperpod-space-profiles"
57+
58+ [ ! -f " $PROFILES_FILE " ] && echo " $DEVSPACE_NAME " && return
59+
60+ python3 -c "
61+ import json
62+ try:
63+ with open('$PROFILES_FILE ', 'r') as f:
64+ profiles = json.load(f)
65+ matches = sorted([k for k in profiles.keys() if k.endswith(':$DEVSPACE_NAME ')])
66+ print(matches[0] if matches else '$DEVSPACE_NAME ')
67+ except:
68+ print('$DEVSPACE_NAME ')
69+ " 2> /dev/null
70+ }
71+
6072_hyperpod () {
6173 # Function inputs
6274 local AWS_SSM_CLI=$1
@@ -68,19 +80,129 @@ _hyperpod() {
6880 exec " $AWS_SSM_CLI " " {\" streamUrl\" :\" $STREAM_URL \" ,\" tokenValue\" :\" $TOKEN \" ,\" sessionId\" :\" $SESSION_ID \" }" " $AWS_REGION " " StartSession"
6981}
7082
83+ _get_fresh_credentials () {
84+ local CONNECTION_KEY=$1
85+
86+ _log " Getting fresh credentials for connection key: $CONNECTION_KEY "
87+
88+ # Read server info to get port
89+ if [ -z " $SAGEMAKER_LOCAL_SERVER_FILE_PATH " ]; then
90+ _log " Error: SAGEMAKER_LOCAL_SERVER_FILE_PATH environment variable is not set"
91+ exit 1
92+ fi
93+
94+ if [ ! -f " $SAGEMAKER_LOCAL_SERVER_FILE_PATH " ]; then
95+ _log " Error: Server info file not found: $SAGEMAKER_LOCAL_SERVER_FILE_PATH "
96+ exit 1
97+ fi
98+
99+ local PORT=$( jq -r ' .port' " $SAGEMAKER_LOCAL_SERVER_FILE_PATH " )
100+ if [ -z " $PORT " ] || [ " $PORT " == " null" ]; then
101+ _log " Error: 'port' field is missing or invalid in $SAGEMAKER_LOCAL_SERVER_FILE_PATH "
102+ exit 1
103+ fi
104+
105+ # Call API to get fresh credentials using the determined connection key
106+ local API_URL
107+ if [[ " $CONNECTION_KEY " =~ ^[^:]+:[^:]+:[^:]+$ ]]; then
108+ # Use full connection key for precise lookup (preferred)
109+ API_URL=" http://localhost:$PORT /get_hyperpod_session?connection_key=$( _url_encode " $CONNECTION_KEY " ) "
110+ elif [ -n " $CLUSTER_NAME " ] && [ -n " $NAMESPACE " ]; then
111+ # Use individual parameters if available
112+ API_URL=" http://localhost:$PORT /get_hyperpod_session?devspace_name=$( _url_encode " $DEVSPACE_NAME " ) &namespace=$( _url_encode " $NAMESPACE " ) &cluster_name=$( _url_encode " $CLUSTER_NAME " ) "
113+ else
114+ # Fallback for legacy format
115+ API_URL=" http://localhost:$PORT /get_hyperpod_session?connection_key=$( _url_encode " $CONNECTION_KEY " ) "
116+ fi
117+ API_RESPONSE=$( curl -s " $API_URL " )
118+
119+ if [ $? -ne 0 ] || [ -z " $API_RESPONSE " ]; then
120+ _log " Error: Failed to get credentials from API"
121+ exit 1
122+ fi
123+
124+ # Parse JSON once and check for error response
125+ read -r STATUS ERROR_MSG CONNECTION_URL < <( echo " $API_RESPONSE " | python3 -c " import sys, json; data=json.load(sys.stdin); print(data.get('status', 'success'), data.get('message', ''), data.get('connection', {}).get('url', ''))" 2> /dev/null) || {
126+ _log " Error: Failed to parse API response JSON"
127+ exit 1
128+ }
129+ if [ " $STATUS " = " error" ]; then
130+ _log " Error from API: ${ERROR_MSG:- Unknown error} "
131+ exit 1
132+ fi
133+
134+ _log " Fresh credentials obtained from API"
135+ }
136+
71137_main () {
72138 # Set defaults for missing environment variables
73139 DEBUG_LOG=${DEBUG_LOG:- 0}
74140 LOG_FILE_LOCATION=${LOG_FILE_LOCATION:-/ tmp/ hyperpod_connect.log}
75141
142+ if [ $# -ne 1 ]; then
143+ _log " Usage: $0 hp_<devspace_name>"
144+ exit 1
145+ fi
146+
147+ # Extract devspace name, cluster name, and namespace from hostname
148+ # New format: hp_{cluster_name}_{namespace}_{space_name}_{region}_{account_id}
149+ # Old format: hp_{devspace} (for backward compatibility)
150+ if [[ " $1 " =~ ^hp_([^_]+)_([^_]+)_([^_]+)_([^_]+)_([^_]+)$ ]]; then
151+ # New format: cluster_namespace_space_region_account
152+ CLUSTER_NAME=" ${BASH_REMATCH[1]} "
153+ NAMESPACE=" ${BASH_REMATCH[2]} "
154+ DEVSPACE_NAME=" ${BASH_REMATCH[3]} "
155+ # Construct the expected connection key (cluster:namespace:devspace)
156+ CONNECTION_KEY=" ${CLUSTER_NAME} :${NAMESPACE} :${DEVSPACE_NAME} "
157+ else
158+ # Old format or fallback - extract devspace name only
159+ DEVSPACE_NAME=$( echo " $1 " | sed ' s/hp_//' )
160+ CLUSTER_NAME=" "
161+ NAMESPACE=" "
162+ CONNECTION_KEY=" "
163+ fi
164+
165+ # For new format, we already have the connection key constructed from hostname
166+ # For old format, look up the connection key from the hyperpod profiles file
167+ [ -z " $CONNECTION_KEY " ] && CONNECTION_KEY=$( _resolve_connection_key " $DEVSPACE_NAME " )
168+
169+ if [ -z " $CONNECTION_KEY " ]; then
170+ _log " Error: Could not determine connection key for devspace: $DEVSPACE_NAME "
171+ exit 1
172+ fi
173+
76174 _log " =============================================================================="
77- _require AWS_REGION " ${AWS_REGION:- } "
78- _require AWS_SSM_CLI " ${AWS_SSM_CLI:- } "
79- _require SESSION_ID " ${SESSION_ID:- } "
80- _require_nolog STREAM_URL " ${STREAM_URL:- } "
81- _require_nolog TOKEN " ${TOKEN:- } "
175+ _log " Connecting to HyperPod devspace: $DEVSPACE_NAME (connection key: $CONNECTION_KEY )"
176+
177+ # Always get fresh credentials (CONNECTION_URL is set by _get_fresh_credentials)
178+ _get_fresh_credentials " $CONNECTION_KEY "
179+
180+ if [[ " $CONNECTION_URL " =~ ^wss:// ]]; then
181+ # Presigned URL format - extract session ID from path and token from cell-number param
182+ SESSION_ID=$( echo " $CONNECTION_URL " | python3 -c " import sys, urllib.parse; url=sys.stdin.read().strip(); path=urllib.parse.urlparse(url).path; session_id=path.split('/')[-1] if path else ''; print(session_id)" 2> /dev/null)
183+ TOKEN=$( echo " $CONNECTION_URL " | python3 -c " import sys, urllib.parse; url=sys.stdin.read().strip(); params=urllib.parse.parse_qs(urllib.parse.urlparse(url).query); print(params.get('cell-number', [''])[0])" 2> /dev/null)
184+ STREAM_URL=" $CONNECTION_URL "
185+ else
186+ # Kubectl format - extract from query params
187+ SESSION_ID=$( echo " $CONNECTION_URL " | python3 -c " import sys, urllib.parse; url=sys.stdin.read().strip(); params=urllib.parse.parse_qs(urllib.parse.urlparse(url).query); print(params.get('sessionId', [''])[0])" 2> /dev/null)
188+ TOKEN=$( echo " $CONNECTION_URL " | python3 -c " import sys, urllib.parse; url=sys.stdin.read().strip(); params=urllib.parse.parse_qs(urllib.parse.urlparse(url).query); print(params.get('sessionToken', [''])[0])" 2> /dev/null)
189+ STREAM_URL=$( echo " $CONNECTION_URL " | python3 -c " import sys, urllib.parse; url=sys.stdin.read().strip(); params=urllib.parse.parse_qs(urllib.parse.urlparse(url).query); print(params.get('streamUrl', [''])[0])" 2> /dev/null)
190+ fi
191+
192+ # Extract region from stream URL
193+ AWS_REGION=$( echo " $STREAM_URL " | grep -o ' \.[a-z0-9-]*\.amazonaws\.com' | sed ' s/^\.//;s/\.amazonaws\.com$//' )
194+ AWS_SSM_CLI=" ${AWS_SSM_CLI:- session-manager-plugin} "
195+
196+ # Validate required parameters (SESSION_ID, STREAM_URL, TOKEN are fetched from API, not validated here)
197+ _require AWS_REGION " ${AWS_REGION} "
198+ _require AWS_SSM_CLI " ${AWS_SSM_CLI} "
199+
200+ if [ -z " ${SESSION_ID} " ] || [ -z " ${STREAM_URL} " ] || [ -z " ${TOKEN} " ]; then
201+ _log " Error: Failed to retrieve valid session credentials from API"
202+ exit 1
203+ fi
82204
83205 _hyperpod " $AWS_SSM_CLI " " $AWS_REGION " " $STREAM_URL " " $TOKEN " " $SESSION_ID "
84206}
85207
86- _main
208+ _main " $@ "
0 commit comments