forked from aws/aws-toolkit-vscode
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel.ts
More file actions
371 lines (330 loc) · 13.6 KB
/
model.ts
File metadata and controls
371 lines (330 loc) · 13.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
/*!
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
// Disabled: detached server files cannot import vscode.
/* eslint-disable no-restricted-imports */
import * as vscode from 'vscode'
import { sshAgentSocketVariable, startSshAgent, startVscodeRemote } from '../../shared/extensions/ssh'
import { createBoundProcess, ensureDependencies } from '../../shared/remoteSession'
import { SshConfig } from '../../shared/sshConfig'
import { Result } from '../../shared/utilities/result'
import * as path from 'path'
import { persistLocalCredentials, persistSmusProjectCreds, persistSSMConnection } from './credentialMapping'
import * as os from 'os'
import _ from 'lodash'
import { fs } from '../../shared/fs/fs'
import * as nodefs from 'fs'
import { getSmSsmEnv, spawnDetachedServer } from './utils'
import { getLogger } from '../../shared/logger/logger'
import { DevSettings } from '../../shared/settings'
import { ToolkitError } from '../../shared/errors'
import { SagemakerSpaceNode } from './explorer/sagemakerSpaceNode'
import { sleep } from '../../shared/utilities/timeoutUtils'
import { SagemakerUnifiedStudioSpaceNode } from '../../sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioSpaceNode'
import globals from '../../shared/extensionGlobals'
const logger = getLogger('sagemaker')

/**
 * SSH config manager for SageMaker HyperPod connections.
 *
 * Registers with the base class under the `hp_` host prefix / `hyperpod_connect` name and emits a
 * `Host hp_*` section whose ProxyCommand runs the HyperPod connect script.
 */
class HyperPodSshConfig extends SshConfig {
    /**
     * @param sshPath Path of the ssh executable, forwarded to the base class.
     * @param hyperpodConnectPath Absolute path of the hyperpod_connect script used as the ProxyCommand.
     */
    constructor(
        sshPath: string,
        private readonly hyperpodConnectPath: string
    ) {
        super(sshPath, 'hp_', 'hyperpod_connect')
    }

    /** The quoted ProxyCommand shared by the generated section and the validation check. */
    private hyperpodProxyCommand(): string {
        return `'${this.hyperpodConnectPath}' '%h'`
    }

    /**
     * Builds the ssh config section. The `proxyCommand` argument from the base class is ignored;
     * the HyperPod connect script is always used instead.
     */
    protected override createSSHConfigSection(proxyCommand: string): string {
        return `
# Created by AWS Toolkit for VSCode. https://github.com/aws/aws-toolkit-vscode
Host hp_*
ForwardAgent yes
AddKeysToAgent yes
StrictHostKeyChecking accept-new
ProxyCommand ${this.hyperpodProxyCommand()}
IdentitiesOnly yes
`
    }

    /** Verifies the ssh host entry for the HyperPod ProxyCommand; returns the error result or `Result.ok()`. */
    public override async ensureValid() {
        const check = await this.verifySSHHost(this.hyperpodProxyCommand())
        return check.isErr() ? check : Result.ok()
    }
}
/**
 * Opens a VS Code Remote-SSH session to a SageMaker (or SageMaker Unified Studio) space using a
 * local (`sm_lc`) connection.
 *
 * Failures while preparing the connection propagate to the caller. A failure while launching the
 * remote window itself is logged as an error and not rethrown.
 *
 * @param node Explorer node of the target space.
 * @param ctx Extension context (global storage paths and bundled resources).
 * @param progress Progress reporter updated before the remote session is opened.
 */
export async function tryRemoteConnection(
    node: SagemakerSpaceNode | SagemakerUnifiedStudioSpaceNode,
    ctx: vscode.ExtensionContext,
    progress: vscode.Progress<{ message?: string; increment?: number }>
) {
    const spaceArn = (await node.getSpaceArn()) as string
    const isSMUS = node instanceof SagemakerUnifiedStudioSpaceNode
    const remoteEnv = await prepareDevEnvConnection(spaceArn, ctx, 'sm_lc', isSMUS, node)
    try {
        progress.report({ message: 'Opening remote session' })
        await startVscodeRemote(
            remoteEnv.SessionProcess,
            remoteEnv.hostname,
            '/home/sagemaker-user',
            remoteEnv.vscPath,
            'sagemaker-user'
        )
    } catch (err) {
        // Best-effort: the failure is not rethrown, but it is a failure, so log it at error severity.
        logger.error(
            `sm:OpenRemoteConnect: Unable to connect to target space with arn: ${await node.getAppArn()} error: ${err}`
        )
    }
}
/**
 * Extracts the AWS region from a stream URL whose host has the form
 * `<service>.<region>.amazonaws.com` or, for the China partition, `<service>.<region>.amazonaws.com.cn`.
 *
 * @param streamUrl Absolute URL (e.g. a `wss://` stream URL).
 * @returns The region label from the host name (e.g. `us-west-2`, `cn-north-1`).
 * @throws Error when the host does not match the expected form (and TypeError if the URL is unparseable).
 */
export function extractRegionFromStreamUrl(streamUrl: string): string {
    const url = new URL(streamUrl)
    // Optional `.cn` suffix so China-partition endpoints resolve instead of throwing.
    const match = url.hostname.match(/^[^.]+\.([^.]+)\.amazonaws\.com(?:\.cn)?$/)
    if (!match) {
        throw new Error(`Unable to get region from stream url: ${streamUrl}`)
    }
    return match[1]
}
/**
 * Prepares everything needed to open a Remote-SSH session to a SageMaker space or HyperPod cluster:
 * ensures dependencies, persists the connection credentials, (re)starts the local server when needed,
 * clears any stale known_hosts entry, validates the ssh config section and builds the session process.
 *
 * @param spaceArn ARN of the target space; used to derive the ssh hostname for non-HyperPod connections.
 * @param ctx Extension context (global storage paths and bundled resources).
 * @param connectionType Connection kind; this function branches on `sm_lc`, `sm_dl` and `sm_hp`.
 * @param isSMUS Whether the target is a SageMaker Unified Studio space.
 * @param node Explorer node of the space; used for SMUS `sm_lc` credential persistence.
 * @param session Session id (HyperPod hostname suffix; passed through for `sm_dl`).
 * @param wsUrl Stream URL; HTML-entity-decoded for HyperPod connections.
 * @param token Session token; HTML-entity-decoded for HyperPod connections.
 * @param domain Domain passed to `persistSSMConnection` for `sm_dl` connections.
 * @param appType App type passed to `persistSSMConnection` for `sm_dl` connections.
 * @returns The ssh hostname, env provider, ssh and VS Code paths, and the bound session process.
 * @throws The ssh config validation error when the config section cannot be added.
 */
export async function prepareDevEnvConnection(
    spaceArn: string,
    ctx: vscode.ExtensionContext,
    connectionType: string,
    isSMUS: boolean,
    node: SagemakerSpaceNode | SagemakerUnifiedStudioSpaceNode | undefined,
    session?: string,
    wsUrl?: string,
    token?: string,
    domain?: string,
    appType?: string
) {
    const isHyperPod = connectionType === 'sm_hp'
    const remoteLogger = configureRemoteConnectionLogger()
    const { ssm, vsc, ssh } = (await ensureDependencies()).unwrap()

    await ensureRemoteSshConnectTimeout()

    const hostname = isHyperPod
        ? `hp_${session}`
        : `${connectionType}_${spaceArn.replace(/\//g, '__').replace(/:/g, '_._')}`

    // Save the space credential mapping.
    if (connectionType === 'sm_lc') {
        if (isSMUS) {
            await persistSmusProjectCreds(spaceArn, node as SagemakerUnifiedStudioSpaceNode)
        } else {
            await persistLocalCredentials(spaceArn)
        }
    } else if (connectionType === 'sm_dl') {
        await persistSSMConnection(spaceArn, domain ?? '', session, wsUrl, token, appType, isSMUS)
    }

    // HyperPod doesn't need the local server (only SageMaker Studio does).
    if (!isHyperPod) {
        await startLocalServer(ctx)
    }
    await removeKnownHost(hostname)

    const hyperpodConnectPath = path.join(ctx.globalStorageUri.fsPath, 'hyperpod_connect')
    if (isHyperPod) {
        await ensureHyperPodConnectScript(ctx, hyperpodConnectPath)
    }

    const sshConfig = isHyperPod
        ? new HyperPodSshConfig(ssh, hyperpodConnectPath)
        : new SshConfig(ssh, 'sm_', 'sagemaker_connect')
    const config = await sshConfig.ensureValid()
    if (config.isErr()) {
        const err = config.err()
        const logPrefix = isHyperPod ? 'hyperpod' : 'sagemaker'
        logger.error(`${logPrefix}: failed to add ssh config section: ${err.message}`)
        throw err
    }

    // Environment variables handed to the bound session process.
    const vars: NodeJS.ProcessEnv = isHyperPod
        ? await buildHyperPodEnv(ctx, ssm, session, wsUrl, token)
        : getSmSsmEnv(ssm, path.join(ctx.globalStorageUri.fsPath, 'sagemaker-local-server-info.json'))
    logger.info(`connect script logs at ${vars.LOG_FILE_LOCATION}`)

    const envProvider = async () => {
        return { [sshAgentSocketVariable]: await startSshAgent(), ...vars }
    }
    const SessionProcess = createBoundProcess(envProvider).extend({
        onStdout: (data: string) => {
            remoteLogger(data)
            if (isHyperPod) {
                getLogger().info(`[ProxyCommand stdout] ${data}`)
            }
        },
        onStderr: (data: string) => {
            remoteLogger(data)
            if (isHyperPod) {
                getLogger().error(`[ProxyCommand stderr] ${data}`)
            }
        },
        rejectOnErrorCode: true,
    })

    return {
        hostname,
        envProvider,
        sshPath: ssh,
        vscPath: vsc,
        SessionProcess,
    }
}

/** Raises the global `remote.SSH.connectTimeout` to 120 seconds when it is set to a lower number. */
async function ensureRemoteSshConnectTimeout(): Promise<void> {
    const remoteSshConfig = vscode.workspace.getConfiguration('remote.SSH')
    const current = remoteSshConfig.get<number>('connectTimeout')
    if (typeof current === 'number' && current < 120) {
        await remoteSshConfig.update('connectTimeout', 120, vscode.ConfigurationTarget.Global)
        void vscode.window.showInformationMessage(
            'Updated "remote.SSH.connectTimeout" to 120 seconds to improve stability.'
        )
    }
}

/**
 * Copies the bundled `resources/hyperpod_connect` script to `destination` (and marks it executable)
 * if it is not already there. Copy failures are logged, not thrown.
 */
async function ensureHyperPodConnectScript(ctx: vscode.ExtensionContext, destination: string): Promise<void> {
    if (await fs.existsFile(destination)) {
        return
    }
    const sourceScriptPath = ctx.asAbsolutePath('resources/hyperpod_connect')
    try {
        await fs.copy(sourceScriptPath, destination)
        await fs.chmod(destination, 0o755)
        logger.info(`Copied hyperpod_connect script to ${destination}`)
    } catch (err) {
        logger.error(`Failed to copy hyperpod_connect script: ${err}`)
    }
}

/**
 * Decodes the HTML entities `&#39;`, `&quot;` and `&amp;`. `&amp;` is decoded last so an encoded
 * entity such as `&amp;quot;` becomes `&quot;` instead of being double-decoded to `"`.
 * Returns `''` for `undefined`.
 */
function decodeHtmlEntities(value: string | undefined): string {
    return (
        value
            ?.replace(/&#39;/g, "'")
            .replace(/&quot;/g, '"')
            .replace(/&amp;/g, '&') ?? ''
    )
}

/**
 * Builds the environment for a HyperPod session: region (derived from the decoded stream URL),
 * session id, decoded stream URL and token, SSM CLI path, debug logging settings, and the current
 * AWS credentials when available, layered over `process.env`.
 */
async function buildHyperPodEnv(
    ctx: vscode.ExtensionContext,
    ssm: string,
    session?: string,
    wsUrl?: string,
    token?: string
): Promise<NodeJS.ProcessEnv> {
    const decodedWsUrl = decodeHtmlEntities(wsUrl)
    const hyperPodEnv: NodeJS.ProcessEnv = {
        AWS_REGION: decodedWsUrl ? extractRegionFromStreamUrl(decodedWsUrl) : '',
        SESSION_ID: session || '',
        STREAM_URL: decodedWsUrl,
        TOKEN: decodeHtmlEntities(token),
        AWS_SSM_CLI: ssm,
        DEBUG_LOG: '1',
        LOG_FILE_LOCATION: path.join(ctx.globalStorageUri.fsPath, 'hyperpod-connection.log'),
    }

    // Add AWS credentials (best-effort: a failure only produces a warning).
    try {
        const creds = await globals.awsContext.getCredentials()
        if (creds) {
            hyperPodEnv.AWS_ACCESS_KEY_ID = creds.accessKeyId
            hyperPodEnv.AWS_SECRET_ACCESS_KEY = creds.secretAccessKey
            if (creds.sessionToken) {
                hyperPodEnv.AWS_SESSION_TOKEN = creds.sessionToken
            }
            logger.info('Added AWS credentials to environment')
        } else {
            logger.warn('No AWS credentials available for HyperPod connection')
        }
    } catch (err) {
        logger.warn(`Failed to get AWS credentials: ${err}`)
    }
    return { ...process.env, ...hyperPodEnv }
}
/**
 * Returns a callback that writes remote-connection process output to the default logger at info
 * level, prefixed with `sagemaker: `.
 */
export function configureRemoteConnectionLogger() {
    // The separator is added by the template below; a trailing colon here produced `sagemaker:: `.
    const logPrefix = 'sagemaker'
    return (data: string) => getLogger().info(`${logPrefix}: ${data}`)
}
/**
 * (Re)starts the detached SageMaker local server and waits for it to write its info file.
 *
 * Any server recorded in the existing info file is stopped first. The new server's stdout/stderr are
 * appended to `sagemaker-local-server.{out,err}.log` in global storage, and the `sagemaker` entry of
 * the `endpoints` dev setting is passed to it as `SAGEMAKER_ENDPOINT`.
 *
 * @param ctx Extension context (global storage path and bundled server script).
 * @throws ToolkitError if the info file does not appear within 20 polls of 500 ms (about 10 s).
 */
export async function startLocalServer(ctx: vscode.ExtensionContext) {
    const storagePath = ctx.globalStorageUri.fsPath
    const serverPath = ctx.asAbsolutePath(path.join('dist/src/awsService/sagemaker/detached-server/', 'server.js'))
    const outLog = path.join(storagePath, 'sagemaker-local-server.out.log')
    const errLog = path.join(storagePath, 'sagemaker-local-server.err.log')
    const infoFilePath = path.join(storagePath, 'sagemaker-local-server-info.json')
    logger.info(`sagemaker-local-server.*.log at ${storagePath}`)
    const customEndpoint = DevSettings.instance.get('endpoints', {})['sagemaker']

    await stopLocalServer(ctx)

    // VS Code does not create the global storage directory itself; the log files below live in it.
    nodefs.mkdirSync(storagePath, { recursive: true })

    // The child receives its own copies of these descriptors when it is spawned, so the parent's
    // copies are closed afterwards; otherwise every (re)start leaks two fds in the extension host.
    // NOTE(review): assumes spawnDetachedServer spawns synchronously like child_process.spawn.
    const outFd = nodefs.openSync(outLog, 'a')
    try {
        const errFd = nodefs.openSync(errLog, 'a')
        try {
            const child = spawnDetachedServer(process.execPath, [serverPath], {
                cwd: path.dirname(serverPath),
                detached: true,
                stdio: ['ignore', outFd, errFd],
                env: {
                    ...process.env,
                    SAGEMAKER_ENDPOINT: customEndpoint,
                    SAGEMAKER_LOCAL_SERVER_FILE_PATH: infoFilePath,
                },
            })
            child.unref()
        } finally {
            nodefs.closeSync(errFd)
        }
    } finally {
        nodefs.closeSync(outFd)
    }

    // Wait for the info file to appear (timeout after 10 seconds).
    const maxRetries = 20
    const delayMs = 500
    for (let i = 0; i < maxRetries; i++) {
        if (await fs.existsFile(infoFilePath)) {
            logger.debug('Detected server info file.')
            return
        }
        await sleep(delayMs)
    }
    throw new ToolkitError(`Timed out waiting for local server info file: ${infoFilePath}`)
}
/** Contents of `sagemaker-local-server-info.json`, as read by `stopLocalServer`. */
interface LocalServerInfo {
    /** Stored as a string; presumably the port the local server listens on — confirm against the server. */
    port: string
    /** Process id that `stopLocalServer` signals to stop the server. */
    pid: number
}
/**
 * Stops the local server recorded in `sagemaker-local-server-info.json` (if any) and deletes the file.
 *
 * - A missing info file is a no-op.
 * - A missing process (`ESRCH`) is only logged.
 * - Failing to delete the info file is only logged.
 *
 * @param ctx Extension context whose global storage holds the info file.
 * @throws ToolkitError if the info file cannot be read/parsed, or the kill fails for a reason other than `ESRCH`.
 */
export async function stopLocalServer(ctx: vscode.ExtensionContext): Promise<void> {
    const infoFilePath = path.join(ctx.globalStorageUri.fsPath, 'sagemaker-local-server-info.json')
    if (!(await fs.existsFile(infoFilePath))) {
        logger.debug('no server info file found. nothing to stop.')
        return
    }

    let pid: number | undefined
    try {
        const content = await fs.readFileText(infoFilePath)
        const infoJson = JSON.parse(content) as LocalServerInfo
        pid = infoJson.pid
    } catch (err: any) {
        throw ToolkitError.chain(err, 'failed to parse server info file')
    }

    // Only signal a single real process. kill(2) interprets 0 as "our process group" and negative
    // values as process groups (-1: every process we may signal), so a corrupted info file must
    // never get that far.
    if (typeof pid === 'number' && Number.isInteger(pid) && pid > 0) {
        try {
            process.kill(pid)
            logger.debug(`stopped local server with PID ${pid}`)
        } catch (err: any) {
            if (err.code === 'ESRCH') {
                logger.warn(`no process found with PID ${pid}. It may have already exited.`)
            } else {
                throw ToolkitError.chain(err, 'failed to stop local server')
            }
        }
    } else {
        logger.warn('no valid PID found in info file.')
    }

    try {
        await fs.delete(infoFilePath)
        logger.debug('removed server info file.')
    } catch (err: any) {
        logger.warn(`could not delete info file: ${err.message ?? err}`)
    }
}
/**
 * Removes every line of `~/.ssh/known_hosts` whose first space-separated field lists `hostname`
 * (matched as given or lowercased) among its comma-separated host names.
 *
 * The file is rewritten atomically, and only when at least one line was removed. A missing file
 * is logged and ignored.
 *
 * @param hostname Host alias to forget.
 * @throws ToolkitError if the file cannot be read or rewritten.
 */
export async function removeKnownHost(hostname: string): Promise<void> {
    const knownHostsPath = path.join(os.homedir(), '.ssh', 'known_hosts')
    if (!(await fs.existsFile(knownHostsPath))) {
        logger.warn(`known_hosts not found at ${knownHostsPath}`)
        return
    }

    let originalLines: string[]
    try {
        originalLines = (await fs.readFileText(knownHostsPath)).split('\n')
    } catch (err: any) {
        throw ToolkitError.chain(err, 'Failed to read known_hosts file')
    }

    // known_hosts entries appear to be stored lowercase, but the exact spelling is checked as well.
    // Matching only the exact (mixed-case) name previously left stale entries behind, so users hit
    // a host identification error when reconnecting to a restarted Space.
    const candidates = [hostname, hostname.toLowerCase()]
    const keptLines = originalLines.filter((line) => {
        const listedHosts = line.split(' ')[0].split(',')
        return !candidates.some((candidate) => listedHosts.includes(candidate))
    })

    if (keptLines.length === originalLines.length) {
        return
    }

    try {
        await fs.writeFile(knownHostsPath, keptLines.join('\n'), { atomic: true })
        logger.debug(`Removed '${hostname}' from known_hosts`)
    } catch (err: any) {
        throw ToolkitError.chain(err, 'Failed to write updated known_hosts file')
    }
}