Skip to content

Commit 7b54e54

Browse files
greenwoodcmaajtodd
andauthored
add RuleBuilder::then_compute_output (#4299)
this allows the developer to compute a mocked output using content from the input. ## Motivation and Context Some richer use cases for stubbing/mocking require computing a stubbed output using data from the input, for instance stubbing a method that returns the item that was put. ## Description Add a new method to `RuleBuilder` that allows for providing a function that computes the output. That function takes in a reference to the input as an argument. I'm not sure how this works with the mutable inputs like streamed bytes, but seems to work for the simple stuff. ## Testing added a simple unit test ---- _By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice._ --------- Co-authored-by: Aaron Todd <aajtodd@users.noreply.github.com>
1 parent 9f495a4 commit 7b54e54

File tree

7 files changed

+164
-21
lines changed

7 files changed

+164
-21
lines changed

.changelog/1757712519.md

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
---
2+
applies_to: ["client"]
3+
authors: ["greenwoodcm"]
4+
references: ["smithy-rs#4299"]
5+
breaking: false
6+
new_feature: true
7+
bug_fix: false
8+
---
9+
Added a new `then_compute_output` to `aws-smithy-mocks` rule builder that allows using the input type when computing a mocked response, e.g.
10+
```rs
11+
// Return a computed output based on the input
12+
let compute_rule = mock!(Client::get_object)
13+
.then_compute_output(|req| {
14+
let key = req.key().unwrap_or("unknown");
15+
GetObjectOutput::builder()
16+
.body(ByteStream::from_static(format!("content for {}", key).as_bytes()))
17+
.build()
18+
});
19+
```

aws/sdk/integration-tests/s3/tests/mocks.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,39 @@ async fn test_mock_client() {
8989
assert_eq!(err.code(), Some("InvalidAccessKey"));
9090
}
9191

92+
#[tokio::test]
93+
async fn test_mock_client_compute() {
94+
let s3_computed = mock!(aws_sdk_s3::Client::get_object)
95+
.match_requests(|inp| {
96+
inp.bucket() == Some("test-bucket") && inp.key() == Some("correct-key")
97+
})
98+
.then_compute_output(|input| {
99+
let content =
100+
format!("{}.{}", input.bucket().unwrap(), input.key().unwrap()).into_bytes();
101+
GetObjectOutput::builder()
102+
.body(ByteStream::from(content))
103+
.build()
104+
});
105+
106+
let s3 = mock_client!(aws_sdk_s3, &[&s3_computed]);
107+
108+
let data = s3
109+
.get_object()
110+
.bucket("test-bucket")
111+
.key("correct-key")
112+
.send()
113+
.await
114+
.expect("success response")
115+
.body
116+
.collect()
117+
.await
118+
.expect("successful read")
119+
.to_vec();
120+
121+
assert_eq!(data, b"test-bucket.correct-key");
122+
assert_eq!(s3_computed.num_calls(), 1);
123+
}
124+
92125
#[tokio::test]
93126
async fn test_mock_client_sequence() {
94127
let rule = mock!(aws_sdk_s3::Client::get_object)

rust-runtime/Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

rust-runtime/aws-smithy-mocks/Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "aws-smithy-mocks"
3-
version = "0.1.1"
3+
version = "0.1.2"
44
authors = ["AWS Rust SDK Team <aws-sdk-rust@amazon.com>"]
55
description = "Testing utilities for smithy-rs generated clients"
66
edition = "2021"
@@ -9,7 +9,7 @@ repository = "https://github.com/smithy-lang/smithy-rs"
99

1010
[dependencies]
1111
aws-smithy-types = { path = "../aws-smithy-types" }
12-
aws-smithy-runtime-api = { path = "../aws-smithy-runtime-api", features = ["client", "http-1x"] }
12+
aws-smithy-runtime-api = { path = "../aws-smithy-runtime-api", features = ["client", "http-1x", "test-util"] }
1313
aws-smithy-http-client = { path = "../aws-smithy-http-client", features = ["test-util"] }
1414
http = "1"
1515

rust-runtime/aws-smithy-mocks/README.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,15 @@ let http_rule = mock!(Client::get_object)
8383
StatusCode::try_from(503).unwrap(),
8484
SdkBody::from("service unavailable")
8585
));
86+
87+
// Return a computed output based on the input
88+
let compute_rule = mock!(Client::get_object)
89+
.then_compute_output(|req| {
90+
let key = req.key().unwrap_or("unknown");
91+
GetObjectOutput::builder()
92+
.body(ByteStream::from_static(format!("content for {}", key).as_bytes()))
93+
.build()
94+
});
8695
```
8796

8897
### Response Sequences

rust-runtime/aws-smithy-mocks/src/interceptor.rs

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use aws_smithy_runtime_api::box_error::BoxError;
99
use aws_smithy_runtime_api::client::http::SharedHttpClient;
1010
use aws_smithy_runtime_api::client::interceptors::context::{
1111
BeforeSerializationInterceptorContextMut, BeforeTransmitInterceptorContextMut, Error,
12-
FinalizerInterceptorContextMut, Output,
12+
FinalizerInterceptorContextMut, Input, Output,
1313
};
1414
use aws_smithy_runtime_api::client::interceptors::Intercept;
1515
use aws_smithy_runtime_api::client::orchestrator::{HttpResponse, OrchestratorError};
@@ -132,7 +132,7 @@ impl Intercept for MockResponseInterceptor {
132132
}
133133

134134
// Rule matches and is not exhausted, get the response
135-
if let Some(response) = rule.next_response() {
135+
if let Some(response) = rule.next_response(input) {
136136
matching_rule = Some(rule.clone());
137137
matching_response = Some(response);
138138
} else {
@@ -154,7 +154,7 @@ impl Intercept for MockResponseInterceptor {
154154
}
155155

156156
if (rule.matcher)(input) {
157-
if let Some(response) = rule.next_response() {
157+
if let Some(response) = rule.next_response(input) {
158158
matching_rule = Some(rule.clone());
159159
matching_response = Some(response);
160160
break;
@@ -198,7 +198,10 @@ impl Intercept for MockResponseInterceptor {
198198
if active_response.is_none() {
199199
// in the case of retries we try to get the next response if it has been consumed
200200
if let Some(active_rule) = cfg.load::<ActiveRule>() {
201-
let next_resp = active_rule.0.next_response();
201+
// During retries, input is not available in modify_before_transmit.
202+
// For HTTP status responses that don't use the input, we can use a dummy input.
203+
let dummy_input = Input::doesnt_matter();
204+
let next_resp = active_rule.0.next_response(&dummy_input);
202205
active_response = next_resp;
203206
}
204207
}
@@ -444,6 +447,43 @@ mod tests {
444447
assert_eq!(rule.num_calls(), 3);
445448
}
446449

450+
#[tokio::test]
451+
async fn test_compute_output() {
452+
// Create a rule that computes its responses based off of input data
453+
let rule = create_rule_builder()
454+
.match_requests(|input| input.bucket == "test-bucket" && input.key == "test-key")
455+
.then_compute_output(|input| TestOutput {
456+
content: format!("{}.{}", input.bucket, input.key),
457+
});
458+
459+
// Create an interceptor with the rule
460+
let interceptor = MockResponseInterceptor::new()
461+
.rule_mode(RuleMode::Sequential)
462+
.with_rule(&rule);
463+
464+
let operation = create_test_operation(interceptor, true);
465+
466+
let result = operation
467+
.invoke(TestInput::new("test-bucket", "test-key"))
468+
.await;
469+
470+
// Should succeed with the output derived from input
471+
assert!(
472+
result.is_ok(),
473+
"Expected success but got error: {:?}",
474+
result.err()
475+
);
476+
assert_eq!(
477+
result.unwrap(),
478+
TestOutput {
479+
content: "test-bucket.test-key".to_string()
480+
}
481+
);
482+
483+
// Verify the rule was used once, no retries
484+
assert_eq!(rule.num_calls(), 1);
485+
}
486+
447487
#[should_panic(
448488
expected = "must_match was enabled but no rules matched or all rules were exhausted for"
449489
)]

rust-runtime/aws-smithy-mocks/src/rule.rs

Lines changed: 56 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ pub(crate) enum MockResponse<O, E> {
3232

3333
/// A function that matches requests.
3434
type MatchFn = Arc<dyn Fn(&Input) -> bool + Send + Sync>;
35-
type ServeFn = Arc<dyn Fn(usize) -> Option<MockResponse<Output, Error>> + Send + Sync>;
35+
type ServeFn = Arc<dyn Fn(usize, &Input) -> Option<MockResponse<Output, Error>> + Send + Sync>;
3636

3737
/// A rule for matching requests and providing mock responses.
3838
///
@@ -67,9 +67,10 @@ impl fmt::Debug for Rule {
6767

6868
impl Rule {
6969
/// Creates a new rule with the given matcher, response handler, and max responses.
70+
#[allow(clippy::type_complexity)]
7071
pub(crate) fn new<O, E>(
7172
matcher: MatchFn,
72-
response_handler: Arc<dyn Fn(usize) -> Option<MockResponse<O, E>> + Send + Sync>,
73+
response_handler: Arc<dyn Fn(usize, &Input) -> Option<MockResponse<O, E>> + Send + Sync>,
7374
max_responses: usize,
7475
is_simple: bool,
7576
) -> Self
@@ -79,9 +80,9 @@ impl Rule {
7980
{
8081
Rule {
8182
matcher,
82-
response_handler: Arc::new(move |idx: usize| {
83+
response_handler: Arc::new(move |idx: usize, input: &Input| {
8384
if idx < max_responses {
84-
response_handler(idx).map(|resp| match resp {
85+
response_handler(idx, input).map(|resp| match resp {
8586
MockResponse::Output(o) => MockResponse::Output(Output::erase(o)),
8687
MockResponse::Error(e) => MockResponse::Error(Error::erase(e)),
8788
MockResponse::Http(http_resp) => MockResponse::Http(http_resp),
@@ -102,9 +103,9 @@ impl Rule {
102103
}
103104

104105
/// Gets the next response.
105-
pub(crate) fn next_response(&self) -> Option<MockResponse<Output, Error>> {
106+
pub(crate) fn next_response(&self, input: &Input) -> Option<MockResponse<Output, Error>> {
106107
let idx = self.call_count.fetch_add(1, Ordering::SeqCst);
107-
(self.response_handler)(idx)
108+
(self.response_handler)(idx, input)
108109
}
109110

110111
/// Returns the number of times this rule has been called.
@@ -227,9 +228,31 @@ where
227228
{
228229
self.sequence().http_response(response_fn).build_simple()
229230
}
231+
232+
/// Creates a rule that computes an output based on the input.
233+
///
234+
/// This allows generating responses based on the input request.
235+
///
236+
/// # Examples
237+
///
238+
/// ```rust,ignore
239+
/// let rule = mock!(Client::get_object)
240+
/// .compute_output(|req| {
241+
/// GetObjectOutput::builder()
242+
/// .body(ByteStream::from_static(format!("content for {}", req.key().unwrap_or("unknown")).as_bytes()))
243+
/// .build()
244+
/// })
245+
/// .build();
246+
/// ```
247+
pub fn then_compute_output<F>(self, compute_fn: F) -> Rule
248+
where
249+
F: Fn(&I) -> O + Send + Sync + 'static,
250+
{
251+
self.sequence().compute_output(compute_fn).build_simple()
252+
}
230253
}
231254

232-
type SequenceGeneratorFn<O, E> = Arc<dyn Fn() -> MockResponse<O, E> + Send + Sync>;
255+
type SequenceGeneratorFn<O, E> = Arc<dyn Fn(&Input) -> MockResponse<O, E> + Send + Sync>;
233256

234257
/// A builder for creating response sequences
235258
pub struct ResponseSequenceBuilder<I, O, E> {
@@ -281,7 +304,7 @@ where
281304
where
282305
F: Fn() -> O + Send + Sync + 'static,
283306
{
284-
let generator = Arc::new(move || MockResponse::Output(output_fn()));
307+
let generator = Arc::new(move |_input: &Input| MockResponse::Output(output_fn()));
285308
self.generators.push((generator, 1));
286309
self
287310
}
@@ -291,7 +314,7 @@ where
291314
where
292315
F: Fn() -> E + Send + Sync + 'static,
293316
{
294-
let generator = Arc::new(move || MockResponse::Error(error_fn()));
317+
let generator = Arc::new(move |_input: &Input| MockResponse::Error(error_fn()));
295318
self.generators.push((generator, 1));
296319
self
297320
}
@@ -301,10 +324,10 @@ where
301324
let status_code = StatusCode::try_from(status).unwrap();
302325

303326
let generator: SequenceGeneratorFn<O, E> = match body {
304-
Some(body) => Arc::new(move || {
327+
Some(body) => Arc::new(move |_input: &Input| {
305328
MockResponse::Http(HttpResponse::new(status_code, SdkBody::from(body.clone())))
306329
}),
307-
None => Arc::new(move || {
330+
None => Arc::new(move |_input: &Input| {
308331
MockResponse::Http(HttpResponse::new(status_code, SdkBody::empty()))
309332
}),
310333
};
@@ -318,7 +341,26 @@ where
318341
where
319342
F: Fn() -> HttpResponse + Send + Sync + 'static,
320343
{
321-
let generator = Arc::new(move || MockResponse::Http(response_fn()));
344+
let generator = Arc::new(move |_input: &Input| MockResponse::Http(response_fn()));
345+
self.generators.push((generator, 1));
346+
self
347+
}
348+
349+
/// Add a computed output response to the sequence. Note that this is not `pub`
350+
/// because creating computed output rules off of sequenced rules doesn't work,
351+
/// as we can't preserve the input across retries. So we only expose `compute_output`
352+
/// on unsequenced rules above.
353+
fn compute_output<F>(mut self, compute_fn: F) -> Self
354+
where
355+
F: Fn(&I) -> O + Send + Sync + 'static,
356+
{
357+
let generator = Arc::new(move |input: &Input| {
358+
if let Some(typed_input) = input.downcast_ref::<I>() {
359+
MockResponse::Output(compute_fn(typed_input))
360+
} else {
361+
panic!("Input type mismatch in compute_output")
362+
}
363+
});
322364
self.generators.push((generator, 1));
323365
self
324366
}
@@ -415,12 +457,12 @@ where
415457

416458
Rule::new(
417459
self.input_filter,
418-
Arc::new(move |idx| {
460+
Arc::new(move |idx, input| {
419461
// find which generator to use
420462
let mut current_idx = idx;
421463
for (generator, repeat_count) in &generators {
422464
if current_idx < *repeat_count {
423-
return Some(generator());
465+
return Some(generator(input));
424466
}
425467
current_idx -= repeat_count;
426468
}

0 commit comments

Comments
 (0)