Skip to content

Commit f76e402

Browse files
committed
Add assume-role functionality to DynamoDB client, refactor BaseModel methods to support it, and adjust tests for compatibility.
1 parent 182160b commit f76e402

File tree

4 files changed

+106
-19
lines changed

4 files changed

+106
-19
lines changed

src/BaseModel.integration.test.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,14 +164,14 @@ describe('save', () => {
164164
it('applies passing conditionExpression', async () => {
165165
const obj = { id: 'test-create-5' };
166166
expect(await Model1.get(obj)).toBeUndefined();
167-
await Model1.save(obj, 101, 'attribute_not_exists(id)');
167+
await Model1.save(obj, 101, { conditionExpression: 'attribute_not_exists(id)'});
168168
expect(await Model1.get(obj)).toEqual({ ...obj, version: 1 });
169169
});
170170
it('applies failing conditionExpression', async () => {
171171
const obj = { id: 'test-put-3' };
172172
expect(await Model1.get(obj)).toBeDefined();
173173
const throws = async () => {
174-
await Model1.save(obj, 101, 'attribute_not_exists(id)');
174+
await Model1.save(obj, 101, { conditionExpression:'attribute_not_exists(id)'});
175175
};
176176
await expect(throws()).rejects.toThrow('The conditional request failed');
177177
expect((await Model1.get(obj)).version).toBeUndefined();

src/BaseModel.test.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,7 @@ describe('save', () => {
332332
expect(ddbMock).toHaveReceivedCommandWith(PutCommand, { Item: testObj });
333333
});
334334
it('sets conditionExpression when provided', async () => {
335-
await instance.save(testObj, 101, 'x = 1');
335+
await instance.save(testObj, 101, {conditionExpression:'x = 1'});
336336

337337
expect(ddbMock).toHaveReceivedCommandWith(PutCommand, { ConditionExpression: 'x = 1 AND attribute_not_exists(version)' });
338338
});

src/BaseModel.ts

Lines changed: 60 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import { chunkArray } from './utils/array';
22
import { dateNow } from './utils/date';
3-
import { dbClient } from './utils/dynamoDbv3';
3+
import { AssumeRoleOpts, getDbClient } from './utils/dynamoDbv3';
44
import {
55
BatchGetCommand,
66
BatchWriteCommand,
@@ -15,6 +15,7 @@ import {
1515
} from '@aws-sdk/lib-dynamodb';
1616
import { Paginator } from '@aws-sdk/types';
1717
import { QueryCommandOutput } from '@aws-sdk/lib-dynamodb/dist-types/commands/QueryCommand';
18+
import { ReturnValue } from '@aws-sdk/client-dynamodb';
1819

1920
const getVersionCondition = (item: BaseObject): string => {
2021
if (item.version) {
@@ -65,7 +66,7 @@ type UpdateOpts = {
6566
ExpressionAttributeValues?: Record<string, any>;
6667
ConditionExpression?: string;
6768
UpdateExpression?: string;
68-
ReturnValues?: string;
69+
ReturnValues?: ReturnValue;
6970
};
7071
skipVersionCondition?: boolean;
7172
};
@@ -147,9 +148,21 @@ export default class BaseModel<T extends BaseObject> {
147148
return { items, lastEvaluatedKey };
148149
};
149150

150-
get = async (keyObj, { consistentRead = true, opts = {} }: { opts?: GetOpts; consistentRead?: boolean } = {}): Promise<T | undefined> => {
151+
get = async (
152+
keyObj,
153+
{
154+
consistentRead = true,
155+
opts = {},
156+
role
157+
}: {
158+
opts?: GetOpts;
159+
consistentRead?: boolean;
160+
role?: AssumeRoleOpts;
161+
} = {}
162+
): Promise<T | undefined> => {
151163
const key = this.createKey(keyObj);
152164
try {
165+
const dbClient = await getDbClient(role);
153166
const { Item } = await dbClient.send(
154167
new GetCommand({
155168
TableName: this.tableName,
@@ -165,7 +178,8 @@ export default class BaseModel<T extends BaseObject> {
165178
}
166179
};
167180

168-
remove = async (keyObj) => {
181+
remove = async (keyObj, { role }: { role?: AssumeRoleOpts } = {}) => {
182+
const dbClient = await getDbClient(role);
169183
await dbClient.send(
170184
new DeleteCommand({
171185
TableName: this.tableName,
@@ -174,8 +188,9 @@ export default class BaseModel<T extends BaseObject> {
174188
);
175189
};
176190

177-
removeBatch = async (keyObjs: any[]) => {
191+
removeBatch = async (keyObjs: any[], { role }: { role?: AssumeRoleOpts } = {}) => {
178192
try {
193+
const dbClient = await getDbClient(role);
179194
const results = await Promise.all(
180195
chunkArray(keyObjs, 25).map((objs) =>
181196
dbClient.send(
@@ -202,7 +217,12 @@ export default class BaseModel<T extends BaseObject> {
202217
}
203218
};
204219

205-
getBatch = async (keyObjs: any[], consistentRead = true, { opts = {} } = {}): Promise<T[]> => {
220+
getBatch = async (
221+
keyObjs: any[],
222+
consistentRead = true,
223+
{ opts = {}, role }: { opts?: any; role?: AssumeRoleOpts } = {}
224+
): Promise<T[]> => {
225+
const dbClient = await getDbClient(role);
206226
const results = await Promise.all(
207227
chunkArray(keyObjs, 100).map((pairs) =>
208228
dbClient.send(
@@ -232,9 +252,17 @@ export default class BaseModel<T extends BaseObject> {
232252
{
233253
opts = { ExpressionAttributeNames: {}, ExpressionAttributeValues: {} },
234254
rangeOp = '=',
235-
maxRequests = 7
236-
}: { opts?: QueryOpts; rangeOp?: string; pageSize?: number; maxRequests?: number } = {}
255+
maxRequests = 7,
256+
role
257+
}: {
258+
opts?: QueryOpts;
259+
rangeOp?: string;
260+
pageSize?: number;
261+
maxRequests?: number;
262+
role?: AssumeRoleOpts;
263+
} = {}
237264
): Promise<{ items: T[]; lastEvaluatedKey?: Partial<T> }> => {
265+
const dbClient = await getDbClient(role);
238266
const ind = this.keys.globalIndexes?.[index];
239267
if (!ind || !ind.hashKey) {
240268
throw new Error(`index "${index}" is not defined in the model`);
@@ -276,9 +304,11 @@ export default class BaseModel<T extends BaseObject> {
276304
opts = { ExpressionAttributeNames: {}, ExpressionAttributeValues: {} },
277305
rangeOp = '=',
278306
consistentRead = true,
279-
maxRequests = 7
280-
}: { opts?: QueryOpts; rangeOp?: string; consistentRead?: boolean; pageSize?: number; maxRequests?: number } = {}
307+
maxRequests = 7,
308+
role
309+
}: { opts?: QueryOpts; rangeOp?: string; consistentRead?: boolean; pageSize?: number; maxRequests?: number; role?: AssumeRoleOpts } = {}
281310
): Promise<{ items: T[]; lastEvaluatedKey?: Partial<T> }> => {
311+
const dbClient = await getDbClient(role);
282312
const hashKey = <string>this.keys.hashKey;
283313
if (!keyObj[hashKey]) {
284314
throw new Error(`hashKey ${hashKey} was not found on keyObj`);
@@ -314,7 +344,8 @@ export default class BaseModel<T extends BaseObject> {
314344
);
315345
};
316346

317-
all = async (): Promise<T[] | undefined> => {
347+
all = async ({ role }: { role?: AssumeRoleOpts } = {}): Promise<T[] | undefined> => {
348+
const dbClient = await getDbClient(role);
318349
const result = await this.getPaginatedResult(paginateScan({ client: dbClient }, { TableName: this.tableName }));
319350
return result.items;
320351
};
@@ -342,12 +373,18 @@ export default class BaseModel<T extends BaseObject> {
342373
};
343374
};
344375

345-
save = async (item: T, userId?, conditionExpression?: string): Promise<T> => {
376+
save = async (
377+
item: T,
378+
userId?,
379+
{ conditionExpression, role }: { conditionExpression?: string; role?: AssumeRoleOpts } = {}
380+
): Promise<T> => {
381+
const dbClient = await getDbClient(role);
346382
const params = this.prepareSave(item, userId, conditionExpression);
347383
return <T>(await dbClient.send(new PutCommand(params))).Attributes;
348384
};
349385

350-
saveBatch = async (items: any[], userId?) => {
386+
saveBatch = async (items: any[], userId?, { role }: { role?: AssumeRoleOpts } = {}) => {
387+
const dbClient = await getDbClient(role);
351388
return await Promise.all(
352389
chunkArray(items, 25).map((itemBatch) =>
353390
dbClient.send(
@@ -383,7 +420,8 @@ export default class BaseModel<T extends BaseObject> {
383420
};
384421
};
385422

386-
update = async (keyObj, opts = {}): Promise<T | undefined> => {
423+
update = async (keyObj, { opts, role }: { opts?: any; role?: AssumeRoleOpts } = {}): Promise<T | undefined> => {
424+
const dbClient = await getDbClient(role);
387425
const data = await dbClient.send(new UpdateCommand(this.prepareUpdate(keyObj, opts)));
388426
return <T>data.Attributes;
389427
};
@@ -399,7 +437,7 @@ export default class BaseModel<T extends BaseObject> {
399437
ExpressionAttributeValues: {},
400438
ConditionExpression: undefined,
401439
UpdateExpression: undefined,
402-
ReturnValues: 'ALL_NEW',
440+
ReturnValues: ReturnValue.ALL_NEW,
403441
...overrideOpts
404442
};
405443
this.validateItemKeys(item, 'updated');
@@ -443,7 +481,13 @@ export default class BaseModel<T extends BaseObject> {
443481
};
444482
};
445483

446-
updateV2 = async (item, changes = {}, updateOpts?: UpdateOpts, userId?: string): Promise<T | undefined> => {
484+
updateV2 = async (
485+
item,
486+
changes = {},
487+
{ role, ...updateOpts }: UpdateOpts & { role?: AssumeRoleOpts } = {},
488+
userId?: string
489+
): Promise<T | undefined> => {
490+
const dbClient = await getDbClient(role);
447491
const result = await dbClient.send(new UpdateCommand(this.prepareUpdateV2(item, changes, updateOpts, userId)));
448492
return <T>result?.Attributes;
449493
};

src/utils/dynamoDbv3.ts

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,52 @@
11
import { DynamoDBClient } from '@aws-sdk/client-dynamodb';
22
import { DynamoDBDocumentClient } from '@aws-sdk/lib-dynamodb';
3+
import { AssumeRoleCommand, STSClient } from '@aws-sdk/client-sts';
34

45
const { AWS_DYNAMODB_REGION: region } = process.env;
56

67
const client = new DynamoDBClient({ ...(region && { region }) });
78
export const dbClient = DynamoDBDocumentClient.from(client, {
89
marshallOptions: { removeUndefinedValues: true }
910
});
11+
12+
export type AssumeRoleOpts =
13+
| {
14+
roleArn: string;
15+
roleSessionName: string;
16+
externalId: string;
17+
region: string;
18+
}
19+
| Record<string, never>;
20+
21+
export async function createAssumedDbClient(params: AssumeRoleOpts): Promise<DynamoDBDocumentClient> {
22+
const sts = new STSClient({ region: params.region });
23+
const assumeRole = await sts.send(
24+
new AssumeRoleCommand({
25+
RoleArn: params.roleArn,
26+
RoleSessionName: params.roleSessionName || 'dynamodb-per-call',
27+
ExternalId: params.externalId
28+
})
29+
);
30+
if (!assumeRole.Credentials) throw new Error('AssumeRole returned no credentials');
31+
32+
const dynamo = new DynamoDBClient({
33+
region: params.region,
34+
credentials: {
35+
accessKeyId: assumeRole.Credentials.AccessKeyId!,
36+
secretAccessKey: assumeRole.Credentials.SecretAccessKey!,
37+
sessionToken: assumeRole.Credentials.SessionToken!
38+
}
39+
});
40+
return DynamoDBDocumentClient.from(dynamo, { marshallOptions: { removeUndefinedValues: true } });
41+
}
42+
43+
// Reusable selector for which DocumentClient to use on a call
44+
export const getDbClient = async (params?: AssumeRoleOpts): Promise<DynamoDBDocumentClient> => {
45+
if (!params || !params?.roleArn) return dbClient;
46+
return await createAssumedDbClient({
47+
roleArn: params.roleArn,
48+
roleSessionName: params.roleSessionName,
49+
externalId: params.externalId,
50+
region: params.region
51+
});
52+
};

0 commit comments

Comments
 (0)