Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import software.amazon.smithy.java.core.serde.event.EventEncoderFactory;
import software.amazon.smithy.java.core.serde.event.EventStreamingException;
import software.amazon.smithy.java.core.serde.event.FrameEncoder;
import software.amazon.smithy.java.core.serde.event.FrameTransformer;

/**
* A {@link EventEncoderFactory} for AWS events.
Expand All @@ -24,19 +25,22 @@ public final class AwsEventEncoderFactory implements EventEncoderFactory<AwsEven
private final Schema schema;
private final Codec codec;
private final String payloadMediaType;
private final FrameTransformer<AwsEventFrame> transformer;
private final Function<Throwable, EventStreamingException> exceptionHandler;

private AwsEventEncoderFactory(
InitialEventType initialEventType,
Schema schema,
Codec codec,
String payloadMediaType,
FrameTransformer<AwsEventFrame> transformer,
Function<Throwable, EventStreamingException> exceptionHandler
) {
this.initialEventType = Objects.requireNonNull(initialEventType, "initialEventType");
this.schema = Objects.requireNonNull(schema, "schema").isMember() ? schema.memberTarget() : schema;
this.codec = Objects.requireNonNull(codec, "codec");
this.payloadMediaType = Objects.requireNonNull(payloadMediaType, "payloadMediaType");
this.transformer = Objects.requireNonNull(transformer, "transformer");
this.exceptionHandler = Objects.requireNonNull(exceptionHandler, "exceptionHandler");
}

Expand All @@ -53,12 +57,14 @@ public static AwsEventEncoderFactory forInputStream(
InputEventStreamingApiOperation<?, ?, ?> operation,
Codec codec,
String payloadMediaType,
FrameTransformer<AwsEventFrame> transformer,
Function<Throwable, EventStreamingException> exceptionHandler
) {
return new AwsEventEncoderFactory(InitialEventType.INITIAL_REQUEST,
operation.inputStreamMember(),
codec,
payloadMediaType,
transformer,
exceptionHandler);
}

Expand All @@ -75,18 +81,25 @@ public static AwsEventEncoderFactory forOutputStream(
OutputEventStreamingApiOperation<?, ?, ?> operation,
Codec codec,
String payloadMediaType,
FrameTransformer<AwsEventFrame> transformer,
Function<Throwable, EventStreamingException> exceptionHandler
) {
return new AwsEventEncoderFactory(InitialEventType.INITIAL_RESPONSE,
operation.outputStreamMember(),
codec,
payloadMediaType,
transformer,
exceptionHandler);
}

@Override
public EventEncoder<AwsEventFrame> newEventEncoder() {
return new AwsEventShapeEncoder(initialEventType, schema, codec, payloadMediaType, exceptionHandler);
return new AwsEventShapeEncoder(initialEventType,
schema,
codec,
payloadMediaType,
transformer,
exceptionHandler);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import java.time.Instant;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.Flow;
import java.util.function.Supplier;
import software.amazon.eventstream.HeaderValue;
import software.amazon.eventstream.Message;
Expand All @@ -21,6 +20,7 @@
import software.amazon.smithy.java.core.serde.ShapeDeserializer;
import software.amazon.smithy.java.core.serde.SpecificShapeDeserializer;
import software.amazon.smithy.java.core.serde.event.EventDecoder;
import software.amazon.smithy.java.core.serde.event.EventStream;

/**
* A decoder for AWS events
Expand All @@ -36,7 +36,6 @@ public final class AwsEventShapeDecoder<E extends SerializableStruct, IR extends
private final Supplier<ShapeBuilder<E>> eventBuilder;
private final Schema eventSchema;
private final Codec codec;
private volatile Flow.Publisher<SerializableStruct> publisher;

AwsEventShapeDecoder(
InitialEventType initialEventType,
Expand All @@ -54,19 +53,9 @@ public final class AwsEventShapeDecoder<E extends SerializableStruct, IR extends

@Override
public SerializableStruct decode(AwsEventFrame frame) {
var message = frame.unwrap();
var eventType = getEventType(message);
if (initialEventType.value().equals(eventType)) {
return decodeInitialResponse(frame);
}
return decodeEvent(frame);
}

@Override
public void onPrepare(Flow.Publisher<SerializableStruct> publisher) {
this.publisher = publisher;
}

private E decodeEvent(AwsEventFrame frame) {
var message = frame.unwrap();
var eventType = getEventType(message);
Expand All @@ -85,12 +74,13 @@ private E decodeEvent(AwsEventFrame frame) {
return builder.build();
}

private IR decodeInitialResponse(AwsEventFrame frame) {
@Override
public IR decodeInitialEvent(AwsEventFrame frame, EventStream<?> eventStream) {
var message = frame.unwrap();
var builder = initialEventBuilder.get();
var publisherMember = getPublisherMember(builder.schema());
var publisherMember = getEventStreamMember(builder.schema());
// Set the publisher member
var responseDeserializer = new InitialResponseDeserializer(publisherMember, publisher);
var responseDeserializer = new InitialResponseDeserializer(publisherMember, eventStream);
builder.deserialize(responseDeserializer);
// Deserialize the rest of the members if any
var headers = message.getHeaders();
Expand All @@ -100,7 +90,7 @@ private IR decodeInitialResponse(AwsEventFrame frame) {
return builder.build();
}

private Schema getPublisherMember(Schema schema) {
private Schema getEventStreamMember(Schema schema) {
for (var member : schema.members()) {
if (member.memberTarget().hasTrait(TraitKey.STREAMING_TRAIT)) {
return member;
Expand All @@ -115,16 +105,16 @@ private String getEventType(Message message) {

static class InitialResponseDeserializer extends SpecificShapeDeserializer {
private final Schema publisherMember;
private final Flow.Publisher<? extends SerializableStruct> publisher;
private final EventStream<?> eventStream;

InitialResponseDeserializer(Schema publisherMember, Flow.Publisher<? extends SerializableStruct> publisher) {
InitialResponseDeserializer(Schema publisherMember, EventStream<?> eventStream) {
this.publisherMember = publisherMember;
this.publisher = publisher;
this.eventStream = eventStream;
}

@Override
public Flow.Publisher<? extends SerializableStruct> readEventStream(Schema schema) {
return publisher;
public EventStream<? extends SerializableStruct> readEventStream(Schema schema) {
return eventStream;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import software.amazon.smithy.java.core.serde.SpecificShapeSerializer;
import software.amazon.smithy.java.core.serde.event.EventEncoder;
import software.amazon.smithy.java.core.serde.event.EventStreamingException;
import software.amazon.smithy.java.core.serde.event.FrameTransformer;
import software.amazon.smithy.model.shapes.ShapeId;

public final class AwsEventShapeEncoder implements EventEncoder<AwsEventFrame> {
Expand All @@ -35,13 +36,15 @@ public final class AwsEventShapeEncoder implements EventEncoder<AwsEventFrame> {
private final String payloadMediaType;
private final Map<String, BiFunction<OutputStream, Map<String, HeaderValue>, ShapeSerializer>> possibleTypes;
private final Map<ShapeId, Schema> possibleExceptions;
private final FrameTransformer<AwsEventFrame> frameTransformer;
private final Function<Throwable, EventStreamingException> exceptionHandler;

public AwsEventShapeEncoder(
InitialEventType initialEventType,
Schema eventSchema,
Codec codec,
String payloadMediaType,
FrameTransformer<AwsEventFrame> frameTransformer,
Function<Throwable, EventStreamingException> exceptionHandler
) {
this.initialEventType = Objects.requireNonNull(initialEventType, "initialEventType");
Expand All @@ -51,6 +54,7 @@ public AwsEventShapeEncoder(
codec,
initialEventType.value());
this.possibleExceptions = possibleExceptions(Objects.requireNonNull(eventSchema, "eventSchema"));
this.frameTransformer = Objects.requireNonNull(frameTransformer, "frameTransformer");
this.exceptionHandler = Objects.requireNonNull(exceptionHandler, "exceptionHandler");
}

Expand All @@ -62,7 +66,8 @@ public AwsEventFrame encode(SerializableStruct item) {
headers.put(":message-type", HeaderValue.fromString("event"));
headers.put(":event-type", HeaderValue.fromString(typeHolder.get()));
headers.put(":content-type", HeaderValue.fromString(payloadMediaType));
return new AwsEventFrame(new Message(headers, payload));
var frame = new AwsEventFrame(new Message(headers, payload));
return frameTransformer.apply(frame);
}

private byte[] encodeInput(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public AwsFrameDecoder(FrameTransformer<AwsEventFrame> transformer) {
public List<AwsEventFrame> decode(ByteBuffer buffer) {
decoder.feed(buffer);
var messages = decoder.getDecodedMessages();
var result = new ArrayList<AwsEventFrame>();
var result = new ArrayList<AwsEventFrame>(messages.size());
for (var message : messages) {
var event = new AwsEventFrame(message);
var transformed = transformer.apply(event);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,14 @@

package software.amazon.smithy.java.aws.events;

import java.nio.ByteBuffer;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Flow;
import software.amazon.smithy.java.core.schema.Schema;
import software.amazon.smithy.java.core.schema.SerializableStruct;
import software.amazon.smithy.java.core.schema.TraitKey;
import software.amazon.smithy.java.core.serde.event.EventDecoderFactory;
import software.amazon.smithy.java.core.serde.event.EventEncoderFactory;
import software.amazon.smithy.java.core.serde.event.EventStreamFrameDecodingProcessor;
import software.amazon.smithy.java.core.serde.event.EventStreamFrameEncodingProcessor;
import software.amazon.smithy.java.core.serde.event.EventStream;
import software.amazon.smithy.java.core.serde.event.ProtocolEventStreamReader;
import software.amazon.smithy.java.core.serde.event.ProtocolEventStreamWriter;
import software.amazon.smithy.java.io.datastream.DataStream;

/**
Expand All @@ -24,47 +22,26 @@ public final class RpcEventStreamsUtil {

private RpcEventStreamsUtil() {}

public static Flow.Publisher<ByteBuffer> bodyForEventStreaming(
@SuppressWarnings("unchecked")
public static DataStream bodyForEventStreaming(
EventEncoderFactory<AwsEventFrame> eventStreamEncodingFactory,
SerializableStruct input
) {
Flow.Publisher<SerializableStruct> eventStream = input.getMemberValue(streamingMember(input.schema()));
return EventStreamFrameEncodingProcessor.create(eventStream, eventStreamEncodingFactory, input);
EventStream<SerializableStruct> eventStream = input.getMemberValue(streamingMember(input.schema()));
ProtocolEventStreamWriter<SerializableStruct, SerializableStruct, AwsEventFrame> writer =
ProtocolEventStreamWriter.toInternal(eventStream);
writer.bootstrap(eventStreamEncodingFactory, input);
return writer.toDataStream();
}

// TODO: Make more synchronous
public static <O extends SerializableStruct> O deserializeResponse(
EventDecoderFactory<AwsEventFrame> eventDecoderFactory,
DataStream bodyDataStream
) {
var result = new CompletableFuture<O>();
var processor = EventStreamFrameDecodingProcessor.create(bodyDataStream, eventDecoderFactory);

// A subscriber to serialize the initial event.
processor.subscribe(new Flow.Subscriber<>() {
@Override
public void onSubscribe(Flow.Subscription subscription) {
subscription.request(1);
}

@Override
@SuppressWarnings("unchecked")
public void onNext(SerializableStruct item) {
result.complete((O) item);
}

@Override
public void onError(Throwable throwable) {
result.completeExceptionally(throwable);
}

@Override
public void onComplete() {
result.completeExceptionally(new RuntimeException("Unexpected event stream completion"));
}
});

return result.join();
var reader = ProtocolEventStreamReader.<O, SerializableStruct, AwsEventFrame>newReader(bodyDataStream,
eventDecoderFactory,
true);
return reader.readInitialEvent();
}

private static Schema streamingMember(Schema schema) {
Expand All @@ -75,5 +52,4 @@ private static Schema streamingMember(Schema schema) {
}
throw new IllegalArgumentException("No streaming member found");
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public void testDecodeInitialResponse() {
var frame = new AwsEventFrame(message);

// Act
var struct = createDecoder().decode(frame);
var struct = createDecoder().decodeInitialEvent(frame, null);

// Assert
assertInstanceOf(TestOperationOutput.class, struct);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import software.amazon.smithy.java.aws.events.model.TestOperationInput;
import software.amazon.smithy.java.core.serde.Codec;
import software.amazon.smithy.java.core.serde.event.EventStreamingException;
import software.amazon.smithy.java.core.serde.event.FrameTransformer;
import software.amazon.smithy.java.json.JsonCodec;

class AwsEventShapeEncoderTest {
Expand Down Expand Up @@ -135,6 +136,7 @@ static AwsEventShapeEncoder createEncoder() {
TestOperation.instance().inputStreamMember(), // event schema
createJsonCodec(), // codec
"text/json",
FrameTransformer.identity(),
(e) -> new EventStreamingException("InternalServerException", "Internal Server Error"));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
package software.amazon.smithy.java.aws.events.model;

import java.util.Objects;
import java.util.concurrent.Flow.Publisher;
import software.amazon.smithy.java.core.schema.Schema;
import software.amazon.smithy.java.core.schema.SchemaUtils;
import software.amazon.smithy.java.core.schema.SerializableStruct;
import software.amazon.smithy.java.core.schema.ShapeBuilder;
import software.amazon.smithy.java.core.serde.ShapeDeserializer;
import software.amazon.smithy.java.core.serde.ShapeSerializer;
import software.amazon.smithy.java.core.serde.ToStringSerializer;
import software.amazon.smithy.java.core.serde.event.EventStream;
import software.amazon.smithy.model.shapes.ShapeId;
import software.amazon.smithy.utils.SmithyGenerated;

Expand All @@ -29,7 +29,7 @@ public final class TestOperationInput implements SerializableStruct {

private final transient String headerString;
private final transient String inputStringMember;
private final transient Publisher<TestEventStream> stream;
private final transient EventStream<TestEventStream> stream;

private TestOperationInput(Builder builder) {
this.headerString = builder.headerString;
Expand All @@ -45,7 +45,7 @@ public String getInputStringMember() {
return inputStringMember;
}

public Publisher<TestEventStream> getStream() {
public EventStream<TestEventStream> getStream() {
return stream;
}

Expand Down Expand Up @@ -130,7 +130,7 @@ public static Builder builder() {
public static final class Builder implements ShapeBuilder<TestOperationInput> {
private String headerString;
private String inputStringMember;
private Publisher<TestEventStream> stream;
private EventStream<TestEventStream> stream;

private Builder() {}

Expand Down Expand Up @@ -158,7 +158,7 @@ public Builder inputStringMember(String inputStringMember) {
/**
* @return this builder.
*/
public Builder stream(Publisher<TestEventStream> stream) {
public Builder stream(EventStream<TestEventStream> stream) {
this.stream = stream;
return this;
}
Expand All @@ -176,7 +176,8 @@ public void setMemberValue(Schema member, Object value) {
case 1 -> inputStringMember(
(String) SchemaUtils.validateSameMember($SCHEMA_INPUT_STRING_MEMBER, member, value));
case 2 ->
stream((Publisher<TestEventStream>) SchemaUtils.validateSameMember($SCHEMA_STREAM, member, value));
stream((EventStream<TestEventStream>) SchemaUtils
.validateSameMember($SCHEMA_STREAM, member, value));
default -> ShapeBuilder.super.setMemberValue(member, value);
}
}
Expand All @@ -201,7 +202,7 @@ public void accept(Builder builder, Schema member, ShapeDeserializer de) {
switch (member.memberIndex()) {
case 0 -> builder.headerString(de.readString(member));
case 1 -> builder.inputStringMember(de.readString(member));
case 2 -> builder.stream((Publisher<TestEventStream>) de.readEventStream(member));
case 2 -> builder.stream((EventStream<TestEventStream>) de.readEventStream(member));
default -> throw new IllegalArgumentException("Unexpected member: " + member.memberName());
}
}
Expand Down
Loading