@@ -124,6 +124,8 @@
import org.apache.kafka.common.PartitionInfo;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.config.SaslConfigs;
import org.apache.kafka.common.header.Header;
import org.apache.kafka.common.header.internals.RecordHeader;
import org.apache.kafka.common.serialization.ByteArrayDeserializer;
import org.apache.kafka.common.serialization.Deserializer;
import org.apache.kafka.common.serialization.Serializer;
@@ -3535,12 +3537,17 @@ public PTransform<PCollection<KV<K, V>>, PDone> buildExternal(
public static class External implements ExternalTransformRegistrar {

public static final String URN = "beam:transform:org.apache.beam:kafka_write:v1";
public static final String URN_WITH_HEADERS =
"beam:transform:org.apache.beam:kafka_write_with_headers:v1";

@Override
public Map<String, Class<? extends ExternalTransformBuilder<?, ?, ?>>> knownBuilders() {
return ImmutableMap.of(
URN,
(Class<KafkaIO.Write.Builder<?, ?>>) (Class<?>) AutoValue_KafkaIO_Write.Builder.class,
URN_WITH_HEADERS,
(Class<? extends ExternalTransformBuilder<?, ?, ?>>)
(Class<?>) WriteWithHeaders.Builder.class);
}

/** Parameters class to expose the Write transform to an external SDK. */
@@ -3825,6 +3832,137 @@ public T decode(InputStream inStream) {
}
}

/**
* A {@link PTransform} to write to Kafka with support for record headers.
*
* <p>This transform accepts {@link Row} elements with the following schema:
*
* <ul>
* <li>key: bytes (required) - The key of the record.
* <li>value: bytes (required) - The value of the record.
* <li>headers: List&lt;Row(key=str, value=bytes)&gt; (optional) - Record headers.
* <li>topic: str (optional) - Per-record topic override.
* <li>partition: int (optional) - Per-record partition.
* <li>timestamp: long (optional) - Per-record timestamp in milliseconds.
* </ul>
*
* <p>This class is primarily used as a cross-language transform.
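*
* <p>For example, a compatible row could be built as follows. This is a minimal sketch; the
* schema literal and the header key/value below are placeholders that simply mirror the field
* list above:
*
* <pre>{@code
* Schema headerSchema =
*     Schema.of(Field.of("key", FieldType.STRING), Field.of("value", FieldType.BYTES));
* Schema recordSchema =
*     Schema.of(
*         Field.of("key", FieldType.BYTES),
*         Field.of("value", FieldType.BYTES),
*         Field.nullable("headers", FieldType.array(FieldType.row(headerSchema))));
* Row record =
*     Row.withSchema(recordSchema)
*         .addValues(
*             "k".getBytes(StandardCharsets.UTF_8),
*             "v".getBytes(StandardCharsets.UTF_8),
*             Collections.singletonList(
*                 Row.withSchema(headerSchema)
*                     .addValues("trace-id", "abc".getBytes(StandardCharsets.UTF_8))
*                     .build()))
*         .build();
* }</pre>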
*/
static class WriteWithHeaders extends PTransform<PCollection<Row>, PDone> {
private static final String FIELD_KEY = "key";
private static final String FIELD_VALUE = "value";
private static final String FIELD_HEADERS = "headers";
private static final String FIELD_TOPIC = "topic";
private static final String FIELD_PARTITION = "partition";
private static final String FIELD_TIMESTAMP = "timestamp";
private static final String HEADER_FIELD_KEY = "key";
private static final String HEADER_FIELD_VALUE = "value";

private final WriteRecords<byte[], byte[]> writeRecords;

WriteWithHeaders(WriteRecords<byte[], byte[]> writeRecords) {
this.writeRecords = writeRecords;
}

static class Builder
implements ExternalTransformBuilder<Write.External.Configuration, PCollection<Row>, PDone> {

@Override
@SuppressWarnings("unchecked")
public PTransform<PCollection<Row>, PDone> buildExternal(
Write.External.Configuration configuration) {
Map<String, Object> producerConfig = new HashMap<>(configuration.producerConfig);
Class<Serializer<byte[]>> keySerializer =
(Class<Serializer<byte[]>>) resolveClass(configuration.keySerializer);
Class<Serializer<byte[]>> valueSerializer =
(Class<Serializer<byte[]>>) resolveClass(configuration.valueSerializer);

WriteRecords<byte[], byte[]> writeRecords =
KafkaIO.<byte[], byte[]>writeRecords()
.withProducerConfigUpdates(producerConfig)
.withKeySerializer(keySerializer)
.withValueSerializer(valueSerializer);

if (configuration.topic != null) {
writeRecords = writeRecords.withTopic(configuration.topic);
}

return new WriteWithHeaders(writeRecords);
}
}

@Override
public PDone expand(PCollection<Row> input) {
final @Nullable String defaultTopic = writeRecords.getTopic();
return input
.apply(
"Row to ProducerRecord",
MapElements.via(
new SimpleFunction<Row, ProducerRecord<byte[], byte[]>>() {
@Override
public ProducerRecord<byte[], byte[]> apply(Row row) {
return toProducerRecord(row, defaultTopic);
}
}))
.setCoder(
ProducerRecordCoder.of(
NullableCoder.of(ByteArrayCoder.of()), NullableCoder.of(ByteArrayCoder.of())))
.apply(writeRecords);
}

@Override
public void populateDisplayData(DisplayData.Builder builder) {
super.populateDisplayData(builder);
writeRecords.populateDisplayData(builder);
}

@SuppressWarnings("argument")
private static ProducerRecord<byte[], byte[]> toProducerRecord(
Row row, @Nullable String defaultTopic) {
String topic = defaultTopic;
if (row.getSchema().hasField(FIELD_TOPIC)) {
String rowTopic = row.getString(FIELD_TOPIC);
if (rowTopic != null) {
topic = rowTopic;
}
}
checkArgument(
topic != null, "Row has no '%s' value and no default topic is configured", FIELD_TOPIC);

byte[] key = row.getBytes(FIELD_KEY);
byte[] value = row.getBytes(FIELD_VALUE);
Integer partition =
row.getSchema().hasField(FIELD_PARTITION) ? row.getInt32(FIELD_PARTITION) : null;
Long timestamp =
row.getSchema().hasField(FIELD_TIMESTAMP) ? row.getInt64(FIELD_TIMESTAMP) : null;

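// Record headers require a Kafka client that supports them (0.11+); if the client does not,
// any headers on the row are dropped and a warning is logged below.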
boolean hasHeaders = ConsumerSpEL.hasHeaders();
Iterable<Header> headers = Collections.emptyList();
if (hasHeaders && row.getSchema().hasField(FIELD_HEADERS)) {
Iterable<Row> headerRows = row.getArray(FIELD_HEADERS);
if (headerRows != null) {
List<Header> headerList = new ArrayList<>();
for (Row headerRow : headerRows) {
String headerKey = headerRow.getString(HEADER_FIELD_KEY);
checkArgument(headerKey != null, "Header key is required");
byte[] headerValue = headerRow.getBytes(HEADER_FIELD_VALUE);
headerList.add(new RecordHeader(headerKey, headerValue));
}
headers = headerList;
}
} else if (!hasHeaders && row.getSchema().hasField(FIELD_HEADERS)) {
// Log warning when headers are present but Kafka client doesn't support them
LOG.warn(
"Dropping headers from Kafka record because the Kafka client version "
+ "does not support headers (requires Kafka 0.11+).");
}

return hasHeaders
? new ProducerRecord<>(topic, partition, timestamp, key, value, headers)
: new ProducerRecord<>(topic, partition, timestamp, key, value);
}
}

private static Class<?> resolveClass(String className) {
try {
return Class.forName(className);
@@ -375,6 +375,92 @@ public void testConstructKafkaWrite() throws Exception {
assertThat(spec.getValueSerializer().getName(), Matchers.is(valueSerializer));
}

@Test
public void testConstructKafkaWriteWithHeaders() throws Exception {
String topic = "topic";
String keySerializer = "org.apache.kafka.common.serialization.ByteArraySerializer";
String valueSerializer = "org.apache.kafka.common.serialization.ByteArraySerializer";
ImmutableMap<String, String> producerConfig =
ImmutableMap.<String, String>builder()
.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "server1:port,server2:port")
.put("retries", "3")
.build();

ExternalTransforms.ExternalConfigurationPayload payload =
encodeRow(
Row.withSchema(
Schema.of(
Field.of("topic", FieldType.STRING),
Field.of(
"producer_config", FieldType.map(FieldType.STRING, FieldType.STRING)),
Field.of("key_serializer", FieldType.STRING),
Field.of("value_serializer", FieldType.STRING)))
.withFieldValue("topic", topic)
.withFieldValue("producer_config", producerConfig)
.withFieldValue("key_serializer", keySerializer)
.withFieldValue("value_serializer", valueSerializer)
.build());

Schema rowSchema =
Schema.of(Field.of("key", FieldType.BYTES), Field.of("value", FieldType.BYTES));
Row inputRow = Row.withSchema(rowSchema).addValues(new byte[0], new byte[0]).build();

Pipeline p = Pipeline.create();
p.apply(Create.of(inputRow).withRowSchema(rowSchema));
RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p);
String inputPCollection =
Iterables.getOnlyElement(
Iterables.getLast(pipelineProto.getComponents().getTransformsMap().values())
.getOutputsMap()
.values());

ExpansionApi.ExpansionRequest request =
ExpansionApi.ExpansionRequest.newBuilder()
.setComponents(pipelineProto.getComponents())
.setTransform(
RunnerApi.PTransform.newBuilder()
.setUniqueName("test")
.putInputs("input", inputPCollection)
.setSpec(
RunnerApi.FunctionSpec.newBuilder()
.setUrn(
org.apache.beam.sdk.io.kafka.KafkaIO.Write.External
.URN_WITH_HEADERS)
.setPayload(payload.toByteString())))
.setNamespace("test_namespace")
.build();

ExpansionService expansionService = new ExpansionService();
TestStreamObserver<ExpansionApi.ExpansionResponse> observer = new TestStreamObserver<>();
expansionService.expand(request, observer);

ExpansionApi.ExpansionResponse result = observer.result;
RunnerApi.PTransform transform = result.getTransform();
assertThat(
transform.getSubtransformsList(),
Matchers.hasItem(MatchesPattern.matchesPattern(".*Row-to-ProducerRecord.*")));
assertThat(
transform.getSubtransformsList(),
Matchers.hasItem(MatchesPattern.matchesPattern(".*KafkaIO-WriteRecords.*")));
assertThat(transform.getInputsCount(), Matchers.is(1));
assertThat(transform.getOutputsCount(), Matchers.is(0));

RunnerApi.PTransform writeComposite =
result.getComponents().getTransformsOrThrow(transform.getSubtransforms(1));
RunnerApi.PTransform writeParDo =
result.getComponents().getTransformsOrThrow(writeComposite.getSubtransforms(0));

RunnerApi.ParDoPayload parDoPayload =
RunnerApi.ParDoPayload.parseFrom(writeParDo.getSpec().getPayload());
KafkaWriter<?, ?> kafkaWriter = (KafkaWriter<?, ?>) ParDoTranslation.getDoFn(parDoPayload);
KafkaIO.WriteRecords<?, ?> spec = kafkaWriter.getSpec();

assertThat(spec.getProducerConfig(), Matchers.is(producerConfig));
assertThat(spec.getTopic(), Matchers.is(topic));
assertThat(spec.getKeySerializer().getName(), Matchers.is(keySerializer));
assertThat(spec.getValueSerializer().getName(), Matchers.is(valueSerializer));
}

private static ExternalConfigurationPayload encodeRow(Row row) {
ByteStringOutputStream outputStream = new ByteStringOutputStream();
try {
@@ -229,7 +229,8 @@ RequestResponseIO<RequestT, ResponseT> withSleeperSupplier(SerializableSupplier<
* need for a {@link SerializableSupplier} instead of setting this directly is that some {@link
* BackOff} implementations, such as {@link FluentBackoff}, are not {@link Serializable}.
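*
* <p>A minimal sketch of wrapping a non-serializable {@link BackOff} construction in a
* serializable supplier; {@code myCaller} and {@code myResponseCoder} are placeholders:
*
* <pre>{@code
* SerializableSupplier<BackOff> backOffSupplier =
*     () -> FluentBackoff.DEFAULT.withMaxRetries(3).backoff();
* RequestResponseIO<MyRequest, MyResponse> io =
*     RequestResponseIO.of(myCaller, myResponseCoder).withBackOffSupplier(backOffSupplier);
* }</pre>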
*/
public RequestResponseIO<RequestT, ResponseT> withBackOffSupplier(
SerializableSupplier<BackOff> value) {
return new RequestResponseIO<>(
rrioConfiguration, callConfiguration.toBuilder().setBackOffSupplier(value).build());
}
@@ -23,6 +23,7 @@
import static org.hamcrest.Matchers.greaterThan;

import com.google.auto.value.AutoValue;
import java.io.Serializable;
import java.util.List;
import org.apache.beam.sdk.PipelineResult;
import org.apache.beam.sdk.coders.Coder;
@@ -40,6 +41,7 @@
import org.apache.beam.sdk.schemas.annotations.DefaultSchema;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Count;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.util.BackOff;
import org.apache.beam.sdk.util.SerializableSupplier;
@@ -333,6 +335,36 @@ public void givenCustomBackoff_thenBackoffBehaviorCustom() {
greaterThan(0L));
}

@Test
public void givenBoundedBackoff_thenRetriesStopAfterLimit() {
int maxRetries = 3;
Caller<Request, Response> caller = new CallerImpl(5);
SerializableSupplier<BackOff> boundedBackoffSupplier = () -> new BoundedBackOff(maxRetries);

Result<Response> result =
requests()
.apply(
"rrio",
RequestResponseIO.of(caller, RESPONSE_CODER)
.withBackOffSupplier(boundedBackoffSupplier)
.withMonitoringConfiguration(
Monitoring.builder().setCountCalls(true).setCountFailures(true).build()));

PAssert.that(result.getResponses()).empty();
PAssert.thatSingleton(result.getFailures().apply("CountFailures", Count.globally()))
.isEqualTo(1L);

PipelineResult pipelineResult = pipeline.run();
MetricResults metrics = pipelineResult.metrics();
pipelineResult.waitUntilFinish();

assertThat(
getCounterResult(metrics, Call.class, Monitoring.callCounterNameOf(caller)),
equalTo((long) maxRetries + 1));
assertThat(
getCounterResult(metrics, Call.class, Monitoring.FAILURES_COUNTER_NAME), equalTo(1L));
}

// TODO(damondouglas): Count metrics of caching after https://github.com/apache/beam/issues/29888
// resolves.
@Ignore
@@ -463,6 +495,29 @@ MetricName getCounterName() {
}
}

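/**
* A test-only {@link BackOff} that permits a fixed number of retries and then returns
* {@link BackOff#STOP}.
*/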
private static class BoundedBackOff implements BackOff, Serializable {
private final int maxRetries;
private int retries = 0;

private BoundedBackOff(int maxRetries) {
this.maxRetries = maxRetries;
}

@Override
public void reset() {
retries = 0;
}

@Override
public long nextBackOffMillis() {
if (retries >= maxRetries) {
return BackOff.STOP;
}
retries++;
return 0L;
}
}

private static class CustomBackOffSupplier implements SerializableSupplier<BackOff> {

private final Counter counter = Metrics.counter(CustomBackOffSupplier.class, "custom_counter");