diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
index ad5535517646..a14eb5836006 100644
--- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
@@ -124,6 +124,8 @@
 import org.apache.kafka.common.PartitionInfo;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.config.SaslConfigs;
+import org.apache.kafka.common.header.Header;
+import org.apache.kafka.common.header.internals.RecordHeader;
 import org.apache.kafka.common.serialization.ByteArrayDeserializer;
 import org.apache.kafka.common.serialization.Deserializer;
 import org.apache.kafka.common.serialization.Serializer;
@@ -3535,12 +3537,17 @@ public PTransform<PCollection<KV<K, V>>, PDone> buildExternal(
     public static class External implements ExternalTransformRegistrar {
 
       public static final String URN = "beam:transform:org.apache.beam:kafka_write:v1";
+      public static final String URN_WITH_HEADERS =
+          "beam:transform:org.apache.beam:kafka_write_with_headers:v1";
 
       @Override
       public Map<String, Class<? extends ExternalTransformBuilder<?, ?, ?>>> knownBuilders() {
         return ImmutableMap.of(
             URN,
-            (Class<? extends ExternalTransformBuilder<?, ?, ?>>) (Class) AutoValue_KafkaIO_Write.Builder.class);
+            (Class<? extends ExternalTransformBuilder<?, ?, ?>>) (Class) AutoValue_KafkaIO_Write.Builder.class,
+            URN_WITH_HEADERS,
+            (Class<? extends ExternalTransformBuilder<?, ?, ?>>)
+                (Class) WriteWithHeaders.Builder.class);
       }
 
       /** Parameters class to expose the Write transform to an external SDK. */
@@ -3825,6 +3832,137 @@ public T decode(InputStream inStream) {
     }
   }
 
+  /**
+   * A {@link PTransform} to write to Kafka with support for record headers.
+   *
+   * <p>This transform accepts {@link Row} elements with the following schema:
+   *
+   * <ul>
+   *   <li>key: bytes (required) - The key of the record.
+   *   <li>value: bytes (required) - The value of the record.
+   *   <li>headers: List&lt;Row(key=str, value=bytes)&gt; (optional) - Record headers.
+   *   <li>topic: str (optional) - Per-record topic override.
+   *   <li>partition: int (optional) - Per-record partition.
+   *   <li>timestamp: long (optional) - Per-record timestamp in milliseconds.
+   * </ul>
+   *
+   * <p>This class is primarily used as a cross-language transform.
+   */
+  static class WriteWithHeaders extends PTransform<PCollection<Row>, PDone> {
+    private static final String FIELD_KEY = "key";
+    private static final String FIELD_VALUE = "value";
+    private static final String FIELD_HEADERS = "headers";
+    private static final String FIELD_TOPIC = "topic";
+    private static final String FIELD_PARTITION = "partition";
+    private static final String FIELD_TIMESTAMP = "timestamp";
+    private static final String HEADER_FIELD_KEY = "key";
+    private static final String HEADER_FIELD_VALUE = "value";
+
+    private final WriteRecords<byte[], byte[]> writeRecords;
+
+    WriteWithHeaders(WriteRecords<byte[], byte[]> writeRecords) {
+      this.writeRecords = writeRecords;
+    }
+
+    static class Builder
+        implements ExternalTransformBuilder<Write.External.Configuration, PCollection<Row>, PDone> {
+
+      @Override
+      @SuppressWarnings("unchecked")
+      public PTransform<PCollection<Row>, PDone> buildExternal(
+          Write.External.Configuration configuration) {
+        Map<String, Object> producerConfig = new HashMap<>(configuration.producerConfig);
+        Class<? extends Serializer<byte[]>> keySerializer =
+            (Class<? extends Serializer<byte[]>>) resolveClass(configuration.keySerializer);
+        Class<? extends Serializer<byte[]>> valueSerializer =
+            (Class<? extends Serializer<byte[]>>) resolveClass(configuration.valueSerializer);
+
+        WriteRecords<byte[], byte[]> writeRecords =
+            KafkaIO.<byte[], byte[]>writeRecords()
+                .withProducerConfigUpdates(producerConfig)
+                .withKeySerializer(keySerializer)
+                .withValueSerializer(valueSerializer);
+
+        if (configuration.topic != null) {
+          writeRecords = writeRecords.withTopic(configuration.topic);
+        }
+
+        return new WriteWithHeaders(writeRecords);
+      }
+    }
+
+    @Override
+    public PDone expand(PCollection<Row> input) {
+      final @Nullable String defaultTopic = writeRecords.getTopic();
+      return input
+          .apply(
+              "Row to ProducerRecord",
+              MapElements.via(
+                  new SimpleFunction<Row, ProducerRecord<byte[], byte[]>>() {
+                    @Override
+                    public ProducerRecord<byte[], byte[]> apply(Row row) {
+                      return toProducerRecord(row, defaultTopic);
+                    }
+                  }))
+          .setCoder(
+              ProducerRecordCoder.of(
+                  NullableCoder.of(ByteArrayCoder.of()), NullableCoder.of(ByteArrayCoder.of())))
+          .apply(writeRecords);
+    }
+
+    @Override
+    public void populateDisplayData(DisplayData.Builder builder) {
+      super.populateDisplayData(builder);
+      writeRecords.populateDisplayData(builder);
+    }
+
+    @SuppressWarnings("argument")
+    private static ProducerRecord<byte[], byte[]> toProducerRecord(
+        Row row, @Nullable String defaultTopic) {
+      String topic = defaultTopic;
+      if (row.getSchema().hasField(FIELD_TOPIC)) {
+        String rowTopic = row.getString(FIELD_TOPIC);
+        if (rowTopic != null) {
+          topic = rowTopic;
+        }
+      }
+      checkArgument(
+          topic != null, "Row is missing field '%s' and no default topic configured", FIELD_TOPIC);
+
+      byte[] key = row.getBytes(FIELD_KEY);
+      byte[] value = row.getBytes(FIELD_VALUE);
+      Integer partition =
+          row.getSchema().hasField(FIELD_PARTITION) ? row.getInt32(FIELD_PARTITION) : null;
+      Long timestamp =
+          row.getSchema().hasField(FIELD_TIMESTAMP) ? row.getInt64(FIELD_TIMESTAMP) : null;
+
+      boolean hasHeaders = ConsumerSpEL.hasHeaders();
+      Iterable<Header> headers = Collections.emptyList();
+      if (hasHeaders && row.getSchema().hasField(FIELD_HEADERS)) {
+        Iterable<Row> headerRows = row.getArray(FIELD_HEADERS);
+        if (headerRows != null) {
+          List<Header> headerList = new ArrayList<>();
+          for (Row headerRow : headerRows) {
+            String headerKey = headerRow.getString(HEADER_FIELD_KEY);
+            checkArgument(headerKey != null, "Header key is required");
+            byte[] headerValue = headerRow.getBytes(HEADER_FIELD_VALUE);
+            headerList.add(new RecordHeader(headerKey, headerValue));
+          }
+          headers = headerList;
+        }
+      } else if (!hasHeaders && row.getSchema().hasField(FIELD_HEADERS)) {
+        // Log warning when headers are present but Kafka client doesn't support them
+        LOG.warn(
+            "Dropping headers from Kafka record because the Kafka client version "
+                + "does not support headers (requires Kafka 0.11+).");
+      }
+
+      return hasHeaders
+          ? new ProducerRecord<>(topic, partition, timestamp, key, value, headers)
+          : new ProducerRecord<>(topic, partition, timestamp, key, value);
+    }
+  }
+
   private static Class<?> resolveClass(String className) {
     try {
       return Class.forName(className);
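The Javadoc above fixes the element schema that WriteWithHeaders expects. As an illustration only (this snippet is not part of the patch; the class name, topic, and literal values are invented), a Row matching that schema can be assembled with Beam's Schema/Row API roughly as follows:

    import java.nio.charset.StandardCharsets;
    import java.util.Collections;
    import org.apache.beam.sdk.schemas.Schema;
    import org.apache.beam.sdk.values.Row;

    public class WriteWithHeadersRowExample {
      public static void main(String[] args) {
        // Schema of a single header entry: (key=str, value=bytes).
        Schema headerSchema =
            Schema.builder().addStringField("key").addByteArrayField("value").build();

        // Element schema described in the WriteWithHeaders Javadoc; only "key" and "value"
        // are required, the remaining fields are optional.
        Schema elementSchema =
            Schema.builder()
                .addByteArrayField("key")
                .addByteArrayField("value")
                .addNullableField(
                    "headers", Schema.FieldType.array(Schema.FieldType.row(headerSchema)))
                .addNullableField("topic", Schema.FieldType.STRING)
                .build();

        Row header =
            Row.withSchema(headerSchema)
                .withFieldValue("key", "trace-id")
                .withFieldValue("value", "abc123".getBytes(StandardCharsets.UTF_8))
                .build();

        Row element =
            Row.withSchema(elementSchema)
                .withFieldValue("key", "k1".getBytes(StandardCharsets.UTF_8))
                .withFieldValue("value", "v1".getBytes(StandardCharsets.UTF_8))
                .withFieldValue("headers", Collections.singletonList(header))
                .withFieldValue("topic", "my_topic")
                .build();

        // toProducerRecord(element, null) above would turn this Row into a ProducerRecord
        // for "my_topic" carrying a single RecordHeader("trace-id", ...).
        System.out.println(element);
      }
    }

Because toProducerRecord only reads the optional fields via hasField checks, the "partition" and "timestamp" fields can simply be left out of the element schema, as they are here.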
diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOExternalTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOExternalTest.java
index 1973f95ddc25..dbdcb468a217 100644
--- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOExternalTest.java
+++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOExternalTest.java
@@ -375,6 +375,92 @@ public void testConstructKafkaWrite() throws Exception {
     assertThat(spec.getValueSerializer().getName(), Matchers.is(valueSerializer));
   }
 
+  @Test
+  public void testConstructKafkaWriteWithHeaders() throws Exception {
+    String topic = "topic";
+    String keySerializer = "org.apache.kafka.common.serialization.ByteArraySerializer";
+    String valueSerializer = "org.apache.kafka.common.serialization.ByteArraySerializer";
+    ImmutableMap<String, String> producerConfig =
+        ImmutableMap.<String, String>builder()
+            .put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "server1:port,server2:port")
+            .put("retries", "3")
+            .build();
+
+    ExternalTransforms.ExternalConfigurationPayload payload =
+        encodeRow(
+            Row.withSchema(
+                    Schema.of(
+                        Field.of("topic", FieldType.STRING),
+                        Field.of(
+                            "producer_config", FieldType.map(FieldType.STRING, FieldType.STRING)),
+                        Field.of("key_serializer", FieldType.STRING),
+                        Field.of("value_serializer", FieldType.STRING)))
+                .withFieldValue("topic", topic)
+                .withFieldValue("producer_config", producerConfig)
+                .withFieldValue("key_serializer", keySerializer)
+                .withFieldValue("value_serializer", valueSerializer)
+                .build());
+
+    Schema rowSchema =
+        Schema.of(Field.of("key", FieldType.BYTES), Field.of("value", FieldType.BYTES));
+    Row inputRow = Row.withSchema(rowSchema).addValues(new byte[0], new byte[0]).build();
+
+    Pipeline p = Pipeline.create();
+    p.apply(Create.of(inputRow).withRowSchema(rowSchema));
+    RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p);
+    String inputPCollection =
+        Iterables.getOnlyElement(
+            Iterables.getLast(pipelineProto.getComponents().getTransformsMap().values())
+                .getOutputsMap()
+                .values());
+
+    ExpansionApi.ExpansionRequest request =
+        ExpansionApi.ExpansionRequest.newBuilder()
+            .setComponents(pipelineProto.getComponents())
+            .setTransform(
+                RunnerApi.PTransform.newBuilder()
+                    .setUniqueName("test")
+                    .putInputs("input", inputPCollection)
+                    .setSpec(
+                        RunnerApi.FunctionSpec.newBuilder()
+                            .setUrn(
+                                org.apache.beam.sdk.io.kafka.KafkaIO.Write.External
+                                    .URN_WITH_HEADERS)
+                            .setPayload(payload.toByteString())))
+            .setNamespace("test_namespace")
+            .build();
+
+    ExpansionService expansionService = new ExpansionService();
+    TestStreamObserver<ExpansionApi.ExpansionResponse> observer = new TestStreamObserver<>();
+    expansionService.expand(request, observer);
+
+    ExpansionApi.ExpansionResponse result = observer.result;
+    RunnerApi.PTransform transform = result.getTransform();
+    assertThat(
+        transform.getSubtransformsList(),
+        Matchers.hasItem(MatchesPattern.matchesPattern(".*Row-to-ProducerRecord.*")));
+    assertThat(
+        transform.getSubtransformsList(),
+        Matchers.hasItem(MatchesPattern.matchesPattern(".*KafkaIO-WriteRecords.*")));
+    assertThat(transform.getInputsCount(), Matchers.is(1));
+    assertThat(transform.getOutputsCount(), Matchers.is(0));
+
+    RunnerApi.PTransform writeComposite =
+        result.getComponents().getTransformsOrThrow(transform.getSubtransforms(1));
+    RunnerApi.PTransform writeParDo =
+        result.getComponents().getTransformsOrThrow(writeComposite.getSubtransforms(0));
+
+    RunnerApi.ParDoPayload parDoPayload =
+        RunnerApi.ParDoPayload.parseFrom(writeParDo.getSpec().getPayload());
+    KafkaWriter kafkaWriter = (KafkaWriter) ParDoTranslation.getDoFn(parDoPayload);
+    KafkaIO.WriteRecords spec = kafkaWriter.getSpec();
+
+    assertThat(spec.getProducerConfig(), Matchers.is(producerConfig));
+    assertThat(spec.getTopic(), Matchers.is(topic));
+    assertThat(spec.getKeySerializer().getName(), Matchers.is(keySerializer));
+    assertThat(spec.getValueSerializer().getName(), Matchers.is(valueSerializer));
+  }
+
   private static ExternalConfigurationPayload encodeRow(Row row) {
     ByteStringOutputStream outputStream = new ByteStringOutputStream();
     try {
diff --git a/sdks/java/io/rrio/src/main/java/org/apache/beam/io/requestresponse/RequestResponseIO.java b/sdks/java/io/rrio/src/main/java/org/apache/beam/io/requestresponse/RequestResponseIO.java
index 8a04f76b6829..7033bdf40111 100644
--- a/sdks/java/io/rrio/src/main/java/org/apache/beam/io/requestresponse/RequestResponseIO.java
+++ b/sdks/java/io/rrio/src/main/java/org/apache/beam/io/requestresponse/RequestResponseIO.java
@@ -229,7 +229,8 @@ RequestResponseIO<RequestT, ResponseT> withSleeperSupplier(SerializableSupplier<
    * need for a {@link SerializableSupplier} instead of setting this directly is that some {@link
    * BackOff} implementations, such as {@link FluentBackoff} are not {@link Serializable}.
    */
-  RequestResponseIO<RequestT, ResponseT> withBackOffSupplier(SerializableSupplier<BackOff> value) {
+  public RequestResponseIO<RequestT, ResponseT> withBackOffSupplier(
+      SerializableSupplier<BackOff> value) {
     return new RequestResponseIO<>(
         rrioConfiguration, callConfiguration.toBuilder().setBackOffSupplier(value).build());
   }
diff --git a/sdks/java/io/rrio/src/test/java/org/apache/beam/io/requestresponse/RequestResponseIOTest.java b/sdks/java/io/rrio/src/test/java/org/apache/beam/io/requestresponse/RequestResponseIOTest.java
index cd0b29bab661..5a199225f396 100644
--- a/sdks/java/io/rrio/src/test/java/org/apache/beam/io/requestresponse/RequestResponseIOTest.java
+++ b/sdks/java/io/rrio/src/test/java/org/apache/beam/io/requestresponse/RequestResponseIOTest.java
@@ -23,6 +23,7 @@
 import static org.hamcrest.Matchers.greaterThan;
 
 import com.google.auto.value.AutoValue;
+import java.io.Serializable;
 import java.util.List;
 import org.apache.beam.sdk.PipelineResult;
 import org.apache.beam.sdk.coders.Coder;
@@ -40,6 +41,7 @@
 import org.apache.beam.sdk.schemas.annotations.DefaultSchema;
 import org.apache.beam.sdk.testing.PAssert;
 import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Count;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.util.BackOff;
 import org.apache.beam.sdk.util.SerializableSupplier;
@@ -333,6 +335,36 @@ public void givenCustomBackoff_thenBackoffBehaviorCustom() {
         greaterThan(0L));
   }
 
+  @Test
+  public void givenBoundedBackoff_thenRetriesStopAfterLimit() {
+    int maxRetries = 3;
+    Caller<Request, Response> caller = new CallerImpl(5);
+    SerializableSupplier<BackOff> boundedBackoffSupplier = () -> new BoundedBackOff(maxRetries);
+
+    Result<Response> result =
+        requests()
+            .apply(
+                "rrio",
+                RequestResponseIO.of(caller, RESPONSE_CODER)
+                    .withBackOffSupplier(boundedBackoffSupplier)
+                    .withMonitoringConfiguration(
+                        Monitoring.builder().setCountCalls(true).setCountFailures(true).build()));
+
+    PAssert.that(result.getResponses()).empty();
+    PAssert.thatSingleton(result.getFailures().apply("CountFailures", Count.globally()))
+        .isEqualTo(1L);
+
+    PipelineResult pipelineResult = pipeline.run();
+    MetricResults metrics = pipelineResult.metrics();
+    pipelineResult.waitUntilFinish();
+
+    assertThat(
+        getCounterResult(metrics, Call.class, Monitoring.callCounterNameOf(caller)),
+        equalTo((long) maxRetries + 1));
+    assertThat(
+        getCounterResult(metrics, Call.class, Monitoring.FAILURES_COUNTER_NAME), equalTo(1L));
+  }
+
   // TODO(damondouglas): Count metrics of caching after https://github.com/apache/beam/issues/29888
   // resolves.
   @Ignore
@@ -463,6 +495,29 @@ MetricName getCounterName() {
     }
   }
 
+  private static class BoundedBackOff implements BackOff, Serializable {
+    private final int maxRetries;
+    private int retries = 0;
+
+    private BoundedBackOff(int maxRetries) {
+      this.maxRetries = maxRetries;
+    }
+
+    @Override
+    public void reset() {
+      retries = 0;
+    }
+
+    @Override
+    public long nextBackOffMillis() {
+      if (retries >= maxRetries) {
+        return BackOff.STOP;
+      }
+      retries++;
+      return 0L;
+    }
+  }
+
   private static class CustomBackOffSupplier implements SerializableSupplier<BackOff> {
 
     private final Counter counter = Metrics.counter(CustomBackOffSupplier.class, "custom_counter");
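With withBackOffSupplier now public, pipeline authors can plug a custom retry policy into RequestResponseIO directly. A minimal sketch, not part of the patch, assuming FluentBackoff's fluent builder in org.apache.beam.sdk.util (the helper method name and the retry values are invented for illustration):

    import org.apache.beam.io.requestresponse.Caller;
    import org.apache.beam.io.requestresponse.RequestResponseIO;
    import org.apache.beam.sdk.coders.Coder;
    import org.apache.beam.sdk.util.BackOff;
    import org.apache.beam.sdk.util.FluentBackoff;
    import org.apache.beam.sdk.util.SerializableSupplier;
    import org.joda.time.Duration;

    class BoundedRetryExample {
      // Configures a RequestResponseIO that retries a failing request at most three times.
      static <RequestT, ResponseT> RequestResponseIO<RequestT, ResponseT> withBoundedRetries(
          Caller<RequestT, ResponseT> caller, Coder<ResponseT> responseCoder) {
        // The supplier, not the BackOff, is what gets serialized; the non-Serializable
        // FluentBackoff is only built when get() is invoked on the worker.
        SerializableSupplier<BackOff> backOffSupplier =
            () ->
                FluentBackoff.DEFAULT
                    .withMaxRetries(3)
                    .withInitialBackoff(Duration.millis(100))
                    .backoff();
        return RequestResponseIO.of(caller, responseCoder).withBackOffSupplier(backOffSupplier);
      }
    }

The new givenBoundedBackoff_thenRetriesStopAfterLimit test above exercises the same hook with a hand-rolled BoundedBackOff instead of FluentBackoff.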
diff --git a/sdks/python/apache_beam/io/kafka.py b/sdks/python/apache_beam/io/kafka.py
index b63366393252..b1847544d395 100644
--- a/sdks/python/apache_beam/io/kafka.py
+++ b/sdks/python/apache_beam/io/kafka.py
@@ -274,6 +274,16 @@ class WriteToKafka(ExternalTransform):
   assumed to be byte arrays.
 
   Experimental; no backwards compatibility guarantees.
+
+  When with_headers=True, the input PCollection elements must be beam.Row
+  objects with the following schema:
+
+  - key: bytes (required) - The key of the record.
+  - value: bytes (required) - The value of the record.
+  - headers: List[Row(key=str, value=bytes)] (optional) - Record headers.
+  - topic: str (optional) - Per-record topic override.
+  - partition: int (optional) - Per-record partition.
+  - timestamp: int (optional) - Per-record timestamp in milliseconds.
   """
 
   # Default serializer which passes raw bytes to Kafka
@@ -281,6 +291,8 @@ class WriteToKafka(ExternalTransform):
       'org.apache.kafka.common.serialization.ByteArraySerializer')
 
   URN = 'beam:transform:org.apache.beam:kafka_write:v1'
+  URN_WITH_HEADERS = (
+      'beam:transform:org.apache.beam:kafka_write_with_headers:v1')
 
   def __init__(
       self,
@@ -288,6 +300,7 @@ def __init__(
       topic,
       key_serializer=byte_array_serializer,
       value_serializer=byte_array_serializer,
+      with_headers=False,
      expansion_service=None):
     """
     Initializes a write operation to Kafka.
@@ -302,10 +315,20 @@ def __init__(
         Serializer for the topic's value, e.g.
         'org.apache.kafka.common.serialization.LongSerializer'.
         Default: 'org.apache.kafka.common.serialization.ByteArraySerializer'.
+    :param with_headers: If True, input elements must be beam.Row objects
+        containing 'key', 'value', and optional 'headers' fields.
+        Only ByteArraySerializer is supported when with_headers=True.
     :param expansion_service: The address (host:port) of the ExpansionService.
     """
+    if with_headers and (key_serializer != self.byte_array_serializer or
+                         value_serializer != self.byte_array_serializer):
+      raise ValueError(
+          'WriteToKafka(with_headers=True) only supports '
+          'ByteArraySerializer for key and value.')
+
+    urn = self.URN_WITH_HEADERS if with_headers else self.URN
     super().__init__(
-        self.URN,
+        urn,
         NamedTupleBasedPayloadBuilder(
             WriteToKafkaSchema(
                 producer_config=producer_config,