Skip to content

Commit cca4e38

Browse files
committed
fix(amazonq): fix iam credential update logic to use custom comparator and added buffer time in cred validation
1 parent c368527 commit cca4e38

File tree

5 files changed

+86
-5
lines changed

5 files changed

+86
-5
lines changed

packages/amazonq/src/lsp/auth.ts

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ import * as crypto from 'crypto'
1717
import { LanguageClient } from 'vscode-languageclient'
1818
import { AuthUtil } from 'aws-core-vscode/codewhisperer'
1919
import { Writable } from 'stream'
20-
import { onceChanged } from 'aws-core-vscode/utils'
20+
import { onceChanged, onceChangedWithComparator } from 'aws-core-vscode/utils'
2121
import { getLogger, oneMinute, isSageMaker } from 'aws-core-vscode/shared'
2222
import { isSsoConnection, isIamConnection } from 'aws-core-vscode/auth'
2323

@@ -108,7 +108,21 @@ export class AmazonQLspAuth {
108108
this.client.info(`UpdateBearerToken: ${JSON.stringify(request)}`)
109109
}
110110

111-
public updateIamCredentials = onceChanged(this._updateIamCredentials.bind(this))
111+
private areCredentialsEqual(creds1: any, creds2: any): boolean {
112+
if (!creds1 && !creds2) return true
113+
if (!creds1 || !creds2) return false
114+
115+
return (
116+
creds1.accessKeyId === creds2.accessKeyId &&
117+
creds1.secretAccessKey === creds2.secretAccessKey &&
118+
creds1.sessionToken === creds2.sessionToken
119+
)
120+
}
121+
122+
public updateIamCredentials = onceChangedWithComparator(
123+
this._updateIamCredentials.bind(this),
124+
([prevCreds], [currentCreds]) => this.areCredentialsEqual(prevCreds, currentCreds)
125+
)
112126
private async _updateIamCredentials(credentials: any) {
113127
getLogger().info(
114128
`[SageMaker Debug] Updating IAM credentials - credentials received: ${credentials ? 'YES' : 'NO'}`

packages/core/src/auth/auth.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -862,6 +862,7 @@ export class Auth implements AuthService, ConnectionManager {
862862

863863
private async createCachedCredentials(provider: CredentialsProvider) {
864864
const providerId = provider.getCredentialsId()
865+
getLogger().debug(`credentials: create cache credentials for ${provider.getProviderType()}`)
865866
globals.loginManager.store.invalidateCredentials(providerId)
866867
const { credentials, endpointUrl } = await globals.loginManager.store.upsertCredentials(providerId, provider)
867868
await globals.loginManager.validateCredentials(credentials, endpointUrl, provider.getDefaultRegion())
@@ -874,6 +875,7 @@ export class Auth implements AuthService, ConnectionManager {
874875
if (creds !== undefined && creds.credentialsHashCode === provider.getHashCode()) {
875876
return creds.credentials
876877
}
878+
return undefined
877879
}
878880

879881
private readonly getToken = keyedDebounce(this._getToken.bind(this))

packages/core/src/auth/credentials/store.ts

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import globals from '../../shared/extensionGlobals'
88
import { getLogger } from '../../shared/logger/logger'
99
import { asString, CredentialsProvider, CredentialsId } from '../providers/credentials'
1010
import { CredentialsProviderManager } from '../providers/credentialsProviderManager'
11+
// import { get } from 'lodash'
1112

1213
export interface CachedCredentials {
1314
credentials: AWS.Credentials
@@ -31,11 +32,17 @@ export class CredentialsStore {
3132
* If the expiration property does not exist, it is assumed to never expire.
3233
*/
3334
public isValid(key: string): boolean {
35+
// Apply 60-second buffer similar to SSO token expiry logic
36+
const expirationBufferMs = 60000
37+
3438
if (this.credentialsCache[key]) {
3539
const expiration = this.credentialsCache[key].credentials.expiration
36-
return expiration !== undefined ? expiration >= new globals.clock.Date() : true
40+
const now = new globals.clock.Date()
41+
const bufferedNow = new globals.clock.Date(now.getTime() + expirationBufferMs)
42+
const isValid = expiration !== undefined ? expiration >= bufferedNow : true
43+
return isValid
3744
}
38-
45+
getLogger().debug(`credentials: no credentials found for ${key}`)
3946
return false
4047
}
4148

packages/core/src/shared/utilities/functionUtils.ts

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,32 @@ export function onceChanged<T, U extends any[]>(fn: (...args: U) => T): (...args
6363
: ((val = fn(...args)), (ran = true), (prevArgs = args.map(String).join(':')), val)
6464
}
6565

66+
/**
67+
* Creates a function that runs only if the args changed versus the previous invocation,
68+
* using a custom comparator function for argument comparison.
69+
*
70+
* @param fn The function to wrap
71+
* @param comparator Function that returns true if arguments are equal
72+
*/
73+
export function onceChangedWithComparator<T, U extends any[]>(
74+
fn: (...args: U) => T,
75+
comparator: (prev: U, current: U) => boolean
76+
): (...args: U) => T {
77+
let val: T
78+
let ran = false
79+
let prevArgs: U
80+
81+
return (...args) => {
82+
if (ran && comparator(prevArgs, args)) {
83+
return val
84+
}
85+
val = fn(...args)
86+
ran = true
87+
prevArgs = args
88+
return val
89+
}
90+
}
91+
6692
/**
6793
* Creates a new function that stores the result of a call.
6894
*

packages/core/src/test/shared/utilities/functionUtils.test.ts

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,13 @@
44
*/
55

66
import assert from 'assert'
7-
import { once, onceChanged, debounce, oncePerUniqueArg } from '../../../shared/utilities/functionUtils'
7+
import {
8+
once,
9+
onceChanged,
10+
debounce,
11+
oncePerUniqueArg,
12+
onceChangedWithComparator,
13+
} from '../../../shared/utilities/functionUtils'
814
import { installFakeClock } from '../../testUtil'
915

1016
describe('functionUtils', function () {
@@ -49,6 +55,32 @@ describe('functionUtils', function () {
4955
assert.strictEqual(counter, 3)
5056
})
5157

58+
it('onceChangedWithComparator()', function () {
59+
let counter = 0
60+
const credentialsEqual = ([prev]: [any], [current]: [any]) => {
61+
if (!prev && !current) return true
62+
if (!prev || !current) return false
63+
return prev.accessKeyId === current.accessKeyId && prev.secretAccessKey === current.secretAccessKey
64+
}
65+
const fn = onceChangedWithComparator((creds: any) => void counter++, credentialsEqual)
66+
67+
const creds1 = { accessKeyId: 'key1', secretAccessKey: 'secret1' }
68+
const creds2 = { accessKeyId: 'key1', secretAccessKey: 'secret1' }
69+
const creds3 = { accessKeyId: 'key2', secretAccessKey: 'secret2' }
70+
71+
fn(creds1)
72+
assert.strictEqual(counter, 1)
73+
74+
fn(creds2) // Same values, should not execute
75+
assert.strictEqual(counter, 1)
76+
77+
fn(creds3) // Different values, should execute
78+
assert.strictEqual(counter, 2)
79+
80+
fn(creds3) // Same as previous, should not execute
81+
assert.strictEqual(counter, 2)
82+
})
83+
5284
it('oncePerUniqueArg()', function () {
5385
let counter = 0
5486
const fn = oncePerUniqueArg((s: string) => {

0 commit comments

Comments
 (0)