From 5196a3b7d2844c04b06f548b7f8eca3bdb644151 Mon Sep 17 00:00:00 2001 From: Hao Xu Date: Fri, 25 Oct 2024 17:09:14 -0700 Subject: [PATCH] Fix unit test --- .../kafka/consumer/ConsumptionTask.java | 4 +- .../kafka/consumer/KafkaConsumerService.java | 85 ++++++---- .../kafka/consumer/SharedKafkaConsumer.java | 5 + .../KafkaConsumerServiceDelegatorTest.java | 156 ++++++++++++++++++ 4 files changed, 217 insertions(+), 33 deletions(-) diff --git a/clients/da-vinci-client/src/main/java/com/linkedin/davinci/kafka/consumer/ConsumptionTask.java b/clients/da-vinci-client/src/main/java/com/linkedin/davinci/kafka/consumer/ConsumptionTask.java index c72d67175d..1bfac05cc3 100644 --- a/clients/da-vinci-client/src/main/java/com/linkedin/davinci/kafka/consumer/ConsumptionTask.java +++ b/clients/da-vinci-client/src/main/java/com/linkedin/davinci/kafka/consumer/ConsumptionTask.java @@ -250,8 +250,8 @@ void setDataReceiver( // Defensive coding. Should never happen except in case of a regression. throw new IllegalStateException( "It is not allowed to set multiple " + ConsumedDataReceiver.class.getSimpleName() + " instances for the same " - + "topic-partition of a given consumer. Previous: " + previousConsumedDataReceiver + ", New: " - + consumedDataReceiver); + + "topic-partition of a given consumer. Previous: " + previousConsumedDataReceiver.destinationIdentifier() + + ", New: " + consumedDataReceiver.destinationIdentifier()); } synchronized (this) { notifyAll(); diff --git a/clients/da-vinci-client/src/main/java/com/linkedin/davinci/kafka/consumer/KafkaConsumerService.java b/clients/da-vinci-client/src/main/java/com/linkedin/davinci/kafka/consumer/KafkaConsumerService.java index cf0bbd58f0..f20e89d439 100644 --- a/clients/da-vinci-client/src/main/java/com/linkedin/davinci/kafka/consumer/KafkaConsumerService.java +++ b/clients/da-vinci-client/src/main/java/com/linkedin/davinci/kafka/consumer/KafkaConsumerService.java @@ -222,8 +222,14 @@ public void unsubscribeAll(PubSubTopic versionTopic) { versionTopicToTopicPartitionToConsumer.compute(versionTopic, (k, topicPartitionToConsumerMap) -> { if (topicPartitionToConsumerMap != null) { topicPartitionToConsumerMap.forEach((topicPartition, sharedConsumer) -> { - sharedConsumer.unSubscribe(topicPartition); - removeTopicPartitionFromConsumptionTask(sharedConsumer, topicPartition); + /** + * Refer {@link KafkaConsumerService#startConsumptionIntoDataReceiver} for avoiding race condition caused by + * setting data receiver and unsubscribing concurrently for the same topic partition on a shared consumer. + */ + synchronized (sharedConsumer) { + sharedConsumer.unSubscribe(topicPartition); + removeTopicPartitionFromConsumptionTask(sharedConsumer, topicPartition); + } }); } return null; @@ -237,8 +243,14 @@ public void unsubscribeAll(PubSubTopic versionTopic) { public void unSubscribe(PubSubTopic versionTopic, PubSubTopicPartition pubSubTopicPartition) { PubSubConsumerAdapter consumer = getConsumerAssignedToVersionTopicPartition(versionTopic, pubSubTopicPartition); if (consumer != null) { - consumer.unSubscribe(pubSubTopicPartition); - consumerToConsumptionTask.get(consumer).removeDataReceiver(pubSubTopicPartition); + /** + * Refer {@link KafkaConsumerService#startConsumptionIntoDataReceiver} for avoiding race condition caused by + * setting data receiver and unsubscribing concurrently for the same topic partition on a shared consumer. + */ + synchronized (consumer) { + consumer.unSubscribe(pubSubTopicPartition); + removeTopicPartitionFromConsumptionTask(consumer, pubSubTopicPartition); + } versionTopicToTopicPartitionToConsumer.compute(versionTopic, (k, topicPartitionToConsumerMap) -> { if (topicPartitionToConsumerMap != null) { topicPartitionToConsumerMap.remove(pubSubTopicPartition); @@ -265,20 +277,25 @@ public void batchUnsubscribe(PubSubTopic versionTopic, Set /** * Leverage {@link PubSubConsumerAdapter#batchUnsubscribe(Set)}. */ - consumerUnSubTopicPartitionSet.forEach((c, tpSet) -> { - c.batchUnsubscribe(tpSet); - ConsumptionTask task = consumerToConsumptionTask.get(c); - tpSet.forEach(tp -> { - task.removeDataReceiver(tp); - versionTopicToTopicPartitionToConsumer.compute(versionTopic, (k, topicPartitionToConsumerMap) -> { - if (topicPartitionToConsumerMap != null) { - topicPartitionToConsumerMap.remove(tp); - return topicPartitionToConsumerMap.isEmpty() ? null : topicPartitionToConsumerMap; - } else { - return null; - } - }); - }); + consumerUnSubTopicPartitionSet.forEach((sharedConsumer, tpSet) -> { + ConsumptionTask task = consumerToConsumptionTask.get(sharedConsumer); + /** + * Refer {@link KafkaConsumerService#startConsumptionIntoDataReceiver} for avoiding race condition caused by + * setting data receiver and unsubscribing concurrently for the same topic partition on a shared consumer. + */ + synchronized (sharedConsumer) { + sharedConsumer.batchUnsubscribe(tpSet); + tpSet.forEach(task::removeDataReceiver); + } + tpSet.forEach( + tp -> versionTopicToTopicPartitionToConsumer.compute(versionTopic, (k, topicPartitionToConsumerMap) -> { + if (topicPartitionToConsumerMap != null) { + topicPartitionToConsumerMap.remove(tp); + return topicPartitionToConsumerMap.isEmpty() ? null : topicPartitionToConsumerMap; + } else { + return null; + } + })); }); } @@ -387,26 +404,32 @@ public void startConsumptionIntoDataReceiver( PubSubTopic versionTopic = consumedDataReceiver.destinationIdentifier(); PubSubTopicPartition topicPartition = partitionReplicaIngestionContext.getPubSubTopicPartition(); SharedKafkaConsumer consumer = assignConsumerFor(versionTopic, topicPartition); - if (consumer == null) { // Defensive code. Shouldn't happen except in case of a regression. throw new VeniceException( "Shared consumer must exist for version topic: " + versionTopic + " in Kafka cluster: " + kafkaUrl); } - - ConsumptionTask consumptionTask = consumerToConsumptionTask.get(consumer); - if (consumptionTask == null) { - // Defensive coding. Should never happen except in case of a regression. - throw new IllegalStateException( - "There should be a " + ConsumptionTask.class.getSimpleName() + " assigned for this " - + SharedKafkaConsumer.class.getSimpleName()); - } /** - * N.B. it's important to set the {@link ConsumedDataReceiver} prior to subscribing, otherwise the - * {@link KafkaConsumerService.ConsumptionTask} will not be able to funnel the messages. + * It is possible that when one {@link StoreIngestionTask} thread finishes unsubscribing a topic partition but not + * finish removing data receiver, but the other {@link StoreIngestionTask} thread is setting data receiver for this + * topic partition before subscription. As {@link ConsumptionTask} does not allow 2 different data receivers for + * the same topic partition, it will throw exception. */ - consumptionTask.setDataReceiver(topicPartition, consumedDataReceiver); - consumer.subscribe(consumedDataReceiver.destinationIdentifier(), topicPartition, lastReadOffset); + synchronized (consumer) { + ConsumptionTask consumptionTask = consumerToConsumptionTask.get(consumer); + if (consumptionTask == null) { + // Defensive coding. Should never happen except in case of a regression. + throw new IllegalStateException( + "There should be a " + ConsumptionTask.class.getSimpleName() + " assigned for this " + + SharedKafkaConsumer.class.getSimpleName()); + } + /** + * N.B. it's important to set the {@link ConsumedDataReceiver} prior to subscribing, otherwise the + * {@link KafkaConsumerService.ConsumptionTask} will not be able to funnel the messages. + */ + consumptionTask.setDataReceiver(topicPartition, consumedDataReceiver); + consumer.subscribe(consumedDataReceiver.destinationIdentifier(), topicPartition, lastReadOffset); + } } interface KCSConstructor { diff --git a/clients/da-vinci-client/src/main/java/com/linkedin/davinci/kafka/consumer/SharedKafkaConsumer.java b/clients/da-vinci-client/src/main/java/com/linkedin/davinci/kafka/consumer/SharedKafkaConsumer.java index 33a12ebdcf..9460ee2209 100644 --- a/clients/da-vinci-client/src/main/java/com/linkedin/davinci/kafka/consumer/SharedKafkaConsumer.java +++ b/clients/da-vinci-client/src/main/java/com/linkedin/davinci/kafka/consumer/SharedKafkaConsumer.java @@ -348,4 +348,9 @@ public Long endOffset(PubSubTopicPartition pubSubTopicPartition) { public List partitionsFor(PubSubTopic topic) { throw new UnsupportedOperationException("partitionsFor is not supported in SharedKafkaConsumer"); } + + // Test only + public void setNextPollTimeOutSeconds(long seconds) { + this.nextPollTimeOutSeconds = seconds; + } } diff --git a/clients/da-vinci-client/src/test/java/com/linkedin/davinci/kafka/consumer/KafkaConsumerServiceDelegatorTest.java b/clients/da-vinci-client/src/test/java/com/linkedin/davinci/kafka/consumer/KafkaConsumerServiceDelegatorTest.java index a6fe08dff1..59b979bf73 100644 --- a/clients/da-vinci-client/src/test/java/com/linkedin/davinci/kafka/consumer/KafkaConsumerServiceDelegatorTest.java +++ b/clients/da-vinci-client/src/test/java/com/linkedin/davinci/kafka/consumer/KafkaConsumerServiceDelegatorTest.java @@ -1,28 +1,51 @@ package com.linkedin.davinci.kafka.consumer; +import static com.linkedin.venice.ConfigKeys.KAFKA_BOOTSTRAP_SERVERS; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.reset; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertTrue; import com.linkedin.davinci.config.VeniceServerConfig; import com.linkedin.davinci.ingestion.consumption.ConsumedDataReceiver; +import com.linkedin.venice.kafka.protocol.KafkaMessageEnvelope; +import com.linkedin.venice.meta.ReadOnlyStoreRepository; +import com.linkedin.venice.meta.Version; +import com.linkedin.venice.pubsub.PubSubConsumerAdapterFactory; import com.linkedin.venice.pubsub.PubSubTopicPartitionImpl; import com.linkedin.venice.pubsub.PubSubTopicRepository; +import com.linkedin.venice.pubsub.adapter.kafka.consumer.ApacheKafkaConsumerAdapter; +import com.linkedin.venice.pubsub.api.PubSubMessageDeserializer; import com.linkedin.venice.pubsub.api.PubSubTopic; import com.linkedin.venice.pubsub.api.PubSubTopicPartition; +import com.linkedin.venice.serialization.avro.OptimizedKafkaValueSerializer; +import com.linkedin.venice.utils.SystemTime; +import com.linkedin.venice.utils.Utils; +import com.linkedin.venice.utils.pools.LandFillObjectPool; +import io.tehuti.metrics.MetricsRepository; +import io.tehuti.metrics.Sensor; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.HashSet; +import java.util.List; import java.util.Map; +import java.util.Properties; import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Function; +import org.testng.Assert; import org.testng.annotations.DataProvider; import org.testng.annotations.Test; @@ -458,4 +481,137 @@ private void verifyConsumerServiceStartConsumptionIntoDataReceiver( partitionReplicaIngestionContext.getVersionTopic(), partitionReplicaIngestionContext.getPubSubTopicPartition()); } + + /** + * This test is to simulate multiple threads resubscribing to the same real-time topic partition for different store + * versions and verify if the lock will protect the handoff for {@link ConsumptionTask} and {@link ConsumedDataReceiver} + * during the re-subscription. + */ + @Test(invocationCount = 9) + public void testKafkaConsumerServiceResubscriptionConcurrency() throws Exception { + ApacheKafkaConsumerAdapter consumer1 = mock(ApacheKafkaConsumerAdapter.class); + when(consumer1.hasAnySubscription()).thenReturn(true); + + PubSubConsumerAdapterFactory factory = mock(PubSubConsumerAdapterFactory.class); + when(factory.create(any(), anyBoolean(), any(), any())).thenReturn(consumer1); + + Properties properties = new Properties(); + String testKafkaUrl = "test_kafka_url"; + properties.put(KAFKA_BOOTSTRAP_SERVERS, testKafkaUrl); + MetricsRepository mockMetricsRepository = mock(MetricsRepository.class); + final Sensor mockSensor = mock(Sensor.class); + doReturn(mockSensor).when(mockMetricsRepository).sensor(anyString(), any()); + + int versionNum = 5; + + PubSubMessageDeserializer pubSubDeserializer = new PubSubMessageDeserializer( + new OptimizedKafkaValueSerializer(), + new LandFillObjectPool<>(KafkaMessageEnvelope::new), + new LandFillObjectPool<>(KafkaMessageEnvelope::new)); + KafkaConsumerService consumerService = new PartitionWiseKafkaConsumerService( + ConsumerPoolType.REGULAR_POOL, + factory, + properties, + 1000l, + versionNum + 1, // Plus 1 to guarantee the consumer pool size is larger than the # of versions. + mock(IngestionThrottler.class), + mock(KafkaClusterBasedRecordThrottler.class), + mockMetricsRepository, + "test_kafka_cluster_alias", + TimeUnit.MINUTES.toMillis(1), + mock(TopicExistenceChecker.class), + false, + pubSubDeserializer, + SystemTime.INSTANCE, + null, + false, + mock(ReadOnlyStoreRepository.class), + false); + String storeName = Utils.getUniqueString("test_consumer_service"); + + Function isAAWCStoreFunc = vt -> true; + KafkaConsumerServiceDelegator.KafkaConsumerServiceBuilder consumerServiceBuilder = + (ignored, poolType) -> consumerService; + VeniceServerConfig mockConfig = mock(VeniceServerConfig.class); + doReturn(false).when(mockConfig).isDedicatedConsumerPoolForAAWCLeaderEnabled(); + doReturn(true).when(mockConfig).isResubscriptionTriggeredByVersionIngestionContextChangeEnabled(); + doReturn(KafkaConsumerServiceDelegator.ConsumerPoolStrategyType.CURRENT_VERSION_PRIORITIZATION).when(mockConfig) + .getConsumerPoolStrategyType(); + KafkaConsumerServiceDelegator delegator = + new KafkaConsumerServiceDelegator(mockConfig, consumerServiceBuilder, isAAWCStoreFunc); + PubSubTopicPartition realTimeTopicPartition = + new PubSubTopicPartitionImpl(TOPIC_REPOSITORY.getTopic(Version.composeRealTimeTopic(storeName)), 0); + + CountDownLatch countDownLatch = new CountDownLatch(1); + List infiniteSubUnSubThreads = new ArrayList<>(); + for (int i = 0; i < versionNum; i++) { + PubSubTopic topicV1ForStoreName3 = TOPIC_REPOSITORY.getTopic(Version.composeKafkaTopic(storeName, i)); + StoreIngestionTask task = mock(StoreIngestionTask.class); + when(task.getVersionTopic()).thenReturn(topicV1ForStoreName3); + when(task.isHybridMode()).thenReturn(true); + + PartitionReplicaIngestionContext partitionReplicaIngestionContext = new PartitionReplicaIngestionContext( + topicV1ForStoreName3, + realTimeTopicPartition, + PartitionReplicaIngestionContext.VersionRole.CURRENT, + PartitionReplicaIngestionContext.WorkloadType.AA_OR_WRITE_COMPUTE); + ConsumedDataReceiver consumedDataReceiver = mock(ConsumedDataReceiver.class); + when(consumedDataReceiver.destinationIdentifier()).thenReturn(topicV1ForStoreName3); + Runnable infiniteSubUnSub = getResubscriptionRunnableFor( + delegator, + partitionReplicaIngestionContext, + consumedDataReceiver, + countDownLatch); + Thread infiniteSubUnSubThread = new Thread(infiniteSubUnSub, "infiniteResubscribe: " + topicV1ForStoreName3); + infiniteSubUnSubThread.start(); + infiniteSubUnSubThreads.add(infiniteSubUnSubThread); + } + + long currentTime = System.currentTimeMillis(); + Boolean raceConditionFound = countDownLatch.await(30, TimeUnit.SECONDS); + long elapsedTime = System.currentTimeMillis() - currentTime; + for (Thread infiniteSubUnSubThread: infiniteSubUnSubThreads) { + infiniteSubUnSubThread.interrupt(); + } + Assert.assertFalse( + raceConditionFound, + "Found race condition in KafkaConsumerService with time passed in milliseconds: " + elapsedTime); + } + + private Runnable getResubscriptionRunnableFor( + KafkaConsumerServiceDelegator consumerServiceDelegator, + PartitionReplicaIngestionContext partitionReplicaIngestionContext, + ConsumedDataReceiver consumedDataReceiver, + CountDownLatch countDownLatch) { + PubSubTopic versionTopic = partitionReplicaIngestionContext.getVersionTopic(); + PubSubTopicPartition pubSubTopicPartition = partitionReplicaIngestionContext.getPubSubTopicPartition(); + return () -> { + try { + while (true) { + if (Thread.currentThread().isInterrupted()) { + consumerServiceDelegator.unSubscribe(versionTopic, pubSubTopicPartition); + break; + } + consumerServiceDelegator + .startConsumptionIntoDataReceiver(partitionReplicaIngestionContext, 0, consumedDataReceiver); + // Avoid wait time here to increase the chance for race condition. + consumerServiceDelegator.assignConsumerFor(versionTopic, pubSubTopicPartition).setNextPollTimeOutSeconds(0); + int versionNum = + Version.parseVersionFromKafkaTopicName(partitionReplicaIngestionContext.getVersionTopic().getName()); + if (versionNum % 3 == 0) { + consumerServiceDelegator.unSubscribe(versionTopic, pubSubTopicPartition); + } else if (versionNum % 3 == 1) { + consumerServiceDelegator.unsubscribeAll(partitionReplicaIngestionContext.getVersionTopic()); + } else { + consumerServiceDelegator.batchUnsubscribe( + partitionReplicaIngestionContext.getVersionTopic(), + Collections.singleton(partitionReplicaIngestionContext.getPubSubTopicPartition())); + } + } + } catch (Exception e) { + e.printStackTrace(); + countDownLatch.countDown(); + } + }; + } }