
[server] Validate PubSub address against cluster map in AASIT::consumerSubscribe (#1342)

Ensure the PubSub URL is present in the PubSub cluster map before
subscribing to a topic. This prevents issues caused by attempting to
subscribe to unknown PubSub URLs during message consumption.
sushantmane authored Nov 25, 2024
1 parent 98b980c commit 1735a63
Showing 7 changed files with 92 additions and 14 deletions.
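At its core, the change guards consumerSubscribe with a containsKey lookup on the server's PubSub-URL-to-cluster-ID map, and lets Da Vinci clients bypass the check. Below is a minimal, self-contained sketch of that guard; the standalone class, the sample addresses, and the plain RuntimeException are stand-ins for the real ActiveActiveStoreIngestionTask, VeniceServerConfig, and VeniceException wiring shown in the diff that follows.

    import it.unimi.dsi.fastutil.objects.Object2IntMap;
    import it.unimi.dsi.fastutil.objects.Object2IntMaps;

    public class PubSubAddressGuardSketch {
      // Illustrative cluster map: PubSub URL -> cluster ID. In the real code this
      // comes from VeniceServerConfig#getKafkaClusterUrlToIdMap().
      private static final Object2IntMap<String> CLUSTER_URL_TO_ID =
          Object2IntMaps.singleton("validPubSubAddress", 1);

      static void subscribe(String topicPartition, long startOffset, String pubSubAddress, boolean isDaVinciClient) {
        // Da Vinci clients skip the validation, mirroring the isDaVinciClient() branch in the diff.
        if (isDaVinciClient || CLUSTER_URL_TO_ID.containsKey(pubSubAddress)) {
          System.out.println("Subscribing to " + topicPartition + " at offset " + startOffset);
          return;
        }
        // The real code logs via LOGGER.error and throws VeniceException here.
        throw new RuntimeException(
            String.format(
                "PubSub address: %s is not in the pubsub cluster map. Cannot subscribe to topic-partition: %s",
                pubSubAddress,
                topicPartition));
      }

      public static void main(String[] args) {
        subscribe("test-topic-0", 100L, "validPubSubAddress", false); // succeeds
        try {
          subscribe("test-topic-0", 100L, "invalidPubSubAddress", false); // rejected
        } catch (RuntimeException e) {
          System.out.println("Rejected: " + e.getMessage());
        }
      }
    }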
File 1 of 7
@@ -1103,6 +1103,28 @@ protected void startConsumingAsLeader(PartitionConsumptionState partitionConsump
leaderOffsetByKafkaURL);
}

+  /**
+   * Ensures the PubSub URL is present in the PubSub cluster URL-to-ID map before subscribing to a topic.
+   * Prevents subscription to unknown PubSub URLs, which can cause issues during message consumption.
+   */
+  public void consumerSubscribe(PubSubTopicPartition pubSubTopicPartition, long startOffset, String pubSubAddress) {
+    VeniceServerConfig serverConfig = getServerConfig();
+    if (isDaVinciClient() || serverConfig.getKafkaClusterUrlToIdMap().containsKey(pubSubAddress)) {
+      super.consumerSubscribe(pubSubTopicPartition, startOffset, pubSubAddress);
+      return;
+    }
+    LOGGER.error(
+        "PubSub address: {} is not in the pubsub cluster map: {}. Cannot subscribe to topic-partition: {}",
+        pubSubAddress,
+        serverConfig.getKafkaClusterUrlToIdMap(),
+        pubSubTopicPartition);
+    throw new VeniceException(
+        String.format(
+            "PubSub address: %s is not in the pubsub cluster map. Cannot subscribe to topic-partition: %s",
+            pubSubAddress,
+            pubSubTopicPartition));
+  }
+
private long calculateRewindStartTime(PartitionConsumptionState partitionConsumptionState) {
long rewindStartTime = 0;
long rewindTimeInMs = hybridStoreConfig.get().getRewindTimeInSeconds() * Time.MS_PER_SECOND;
File 2 of 7
@@ -4442,4 +4442,8 @@ public boolean hasAllPartitionReportedCompleted() {
void setVersionRole(PartitionReplicaIngestionContext.VersionRole versionRole) {
this.versionRole = versionRole;
  }
+
+  protected boolean isDaVinciClient() {
+    return isDaVinciClient;
+  }
}
File 3 of 7
@@ -16,9 +16,11 @@
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertFalse;
import static org.testng.Assert.assertNotNull;
import static org.testng.Assert.assertNull;
import static org.testng.Assert.assertTrue;
+import static org.testng.Assert.expectThrows;

import com.github.luben.zstd.Zstd;
import com.linkedin.davinci.config.VeniceServerConfig;
@@ -72,9 +74,7 @@
import com.linkedin.venice.pubsub.api.PubSubProduceResult;
import com.linkedin.venice.pubsub.api.PubSubProducerAdapter;
import com.linkedin.venice.pubsub.api.PubSubProducerCallback;
-import com.linkedin.venice.pubsub.api.PubSubTopic;
import com.linkedin.venice.pubsub.api.PubSubTopicPartition;
-import com.linkedin.venice.pubsub.api.PubSubTopicType;
import com.linkedin.venice.schema.SchemaEntry;
import com.linkedin.venice.serialization.KeyWithChunkingSuffixSerializer;
import com.linkedin.venice.serialization.avro.AvroProtocolDefinition;
@@ -92,6 +92,8 @@
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
import it.unimi.dsi.fastutil.objects.Object2IntArrayMap;
+import it.unimi.dsi.fastutil.objects.Object2IntMap;
+import it.unimi.dsi.fastutil.objects.Object2IntMaps;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
@@ -114,6 +116,7 @@

public class ActiveActiveStoreIngestionTaskTest {
private static final Logger LOGGER = LogManager.getLogger(ActiveActiveStoreIngestionTaskTest.class);
+  private static final PubSubTopicRepository TOPIC_REPOSITORY = new PubSubTopicRepository();
String STORE_NAME = "Thvorusleikir_store";
String PUSH_JOB_ID = "yule";
String BOOTSTRAP_SERVER = "Stekkjastaur";
@@ -193,11 +196,6 @@ public void testGetValueBytesFromTransientRecords(CompressionStrategy strategy)

@Test
public void testisReadyToServeAnnouncedWithRTLag() {
-    // Set up PubSubTopicRepository
-    PubSubTopicRepository pubSubTopicRepository = mock(PubSubTopicRepository.class);
-    PubSubTopic pubSubTopic = new TestPubSubTopic(STORE_NAME + "_v1", STORE_NAME, PubSubTopicType.VERSION_TOPIC);
-    when(pubSubTopicRepository.getTopic("Thvorusleikir_store_v1")).thenReturn(pubSubTopic);
-
// Setup store/schema/storage repository
ReadOnlyStoreRepository readOnlyStoreRepository = mock(ReadOnlyStoreRepository.class);
ReadOnlySchemaRepository readOnlySchemaRepository = mock(ReadOnlySchemaRepository.class);
@@ -214,7 +212,7 @@ public void testisReadyToServeAnnouncedWithRTLag() {

// Set up IngestionTask Builder
StoreIngestionTaskFactory.Builder builder = new StoreIngestionTaskFactory.Builder();
-    builder.setPubSubTopicRepository(pubSubTopicRepository);
+    builder.setPubSubTopicRepository(TOPIC_REPOSITORY);
builder.setHostLevelIngestionStats(mock(AggHostLevelIngestionStats.class));
builder.setAggKafkaConsumerService(mock(AggKafkaConsumerService.class));
builder.setMetadataRepository(readOnlyStoreRepository);
@@ -614,8 +612,7 @@ public void testUnwrapByteBufferFromOldValueProvider() {

@Test
public void testGetUpstreamKafkaUrlFromKafkaValue() {
-    PubSubTopicRepository pubSubTopicRepository = new PubSubTopicRepository();
-    PubSubTopicPartition partition = new PubSubTopicPartitionImpl(pubSubTopicRepository.getTopic("topic"), 0);
+    PubSubTopicPartition partition = new PubSubTopicPartitionImpl(TOPIC_REPOSITORY.getTopic("topic"), 0);
long offset = 100;
long timestamp = System.currentTimeMillis();
int payloadSize = 200;
@@ -740,4 +737,47 @@ public void getKeyLevelLockMaxPoolSizeBasedOnServerConfigTest() {
when(serverConfig.isAAWCWorkloadParallelProcessingEnabled()).thenReturn(true);
assertEquals(ActiveActiveStoreIngestionTask.getKeyLevelLockMaxPoolSizeBasedOnServerConfig(serverConfig, 1000), 721);
  }
+
+  @Test
+  public void testConsumerSubscribeValidatesPubSubAddress() {
+    // Case 1: AA store ingestion task with invalid pubsub address
+    ActiveActiveStoreIngestionTask ingestionTask = mock(ActiveActiveStoreIngestionTask.class);
+    VeniceServerConfig mockServerConfig = mock(VeniceServerConfig.class);
+
+    when(ingestionTask.getServerConfig()).thenReturn(mockServerConfig);
+    when(ingestionTask.isDaVinciClient()).thenReturn(false);
+    Object2IntMap<String> kafkaClusterUrlToIdMap = Object2IntMaps.singleton("validPubSubAddress", 1);
+    when(mockServerConfig.getKafkaClusterUrlToIdMap()).thenReturn(kafkaClusterUrlToIdMap);
+
+    // Set up real method call
+    doCallRealMethod().when(ingestionTask).consumerSubscribe(any(), anyLong(), anyString());
+
+    PubSubTopicPartition pubSubTopicPartition = new PubSubTopicPartitionImpl(TOPIC_REPOSITORY.getTopic("test"), 0);
+    VeniceException exception = expectThrows(
+        VeniceException.class,
+        () -> ingestionTask.consumerSubscribe(pubSubTopicPartition, 100L, "invalidPubSubAddress"));
+    assertNotNull(exception.getMessage(), "Exception message should not be null");
+    assertTrue(
+        exception.getMessage().contains("is not in the pubsub cluster map"),
+        "Exception message should contain the expected message but found: " + exception.getMessage());
+
+    verify(ingestionTask, times(1)).consumerSubscribe(pubSubTopicPartition, 100L, "invalidPubSubAddress");
+
+    // Case 2: DaVinci client
+    ActiveActiveStoreIngestionTask dvcIngestionTask = mock(ActiveActiveStoreIngestionTask.class);
+    doCallRealMethod().when(dvcIngestionTask).consumerSubscribe(any(), anyLong(), anyString());
+    when(dvcIngestionTask.getServerConfig()).thenReturn(mockServerConfig);
+    when(dvcIngestionTask.isDaVinciClient()).thenReturn(true);
+    when(mockServerConfig.getKafkaClusterUrlToIdMap()).thenReturn(Object2IntMaps.emptyMap());
+    try {
+      dvcIngestionTask.consumerSubscribe(pubSubTopicPartition, 100L, "validPubSubAddress");
+    } catch (Exception e) {
+      if (e.getMessage() != null) {
+        assertFalse(
+            e.getMessage().contains("is not in the pubsub cluster map"),
+            "Exception message should not contain the expected message but found: " + e.getMessage());
+      }
+    }
+    verify(dvcIngestionTask, times(1)).consumerSubscribe(pubSubTopicPartition, 100L, "validPubSubAddress");
+  }
}
File 4 of 7
@@ -3734,8 +3734,8 @@ public void testUpdateConsumedUpstreamRTOffsetMapDuringRTSubscription(AAConfig a
VeniceProperties mockVeniceProperties = mock(VeniceProperties.class);
doReturn(true).when(mockVeniceProperties).isEmpty();
doReturn(mockVeniceProperties).when(mockVeniceServerConfig).getKafkaConsumerConfigsForLocalConsumption();
-    doReturn(Object2IntMaps.emptyMap()).when(mockVeniceServerConfig).getKafkaClusterUrlToIdMap();
-    doReturn(Int2ObjectMaps.emptyMap()).when(mockVeniceServerConfig).getKafkaClusterIdToUrlMap();
+    doReturn(Object2IntMaps.singleton("localhost", 0)).when(mockVeniceServerConfig).getKafkaClusterUrlToIdMap();
+    doReturn(Int2ObjectMaps.singleton(0, "localhost")).when(mockVeniceServerConfig).getKafkaClusterIdToUrlMap();

StoreIngestionTaskFactory ingestionTaskFactory = TestUtils.getStoreIngestionTaskBuilder(storeName)
.setTopicManagerRepository(mockTopicManagerRepository)
File 5 of 7
@@ -287,6 +287,7 @@ static ServiceProvider<VeniceClusterWrapper> generateService(VeniceClusterCreate
if (!options.getRegionName().isEmpty() && !options.getClusterName().isEmpty()) {
serverName = options.getRegionName() + ":" + options.getClusterName() + ":sn-" + i;
}
+
VeniceServerWrapper veniceServerWrapper = ServiceFactory.getVeniceServer(
options.getRegionName(),
options.getClusterName(),
@@ -655,6 +656,7 @@ public VeniceServerWrapper addVeniceServer(Properties properties) {
public VeniceServerWrapper addVeniceServer(Properties featureProperties, Properties configProperties) {
Properties mergedProperties = options.getExtraProperties();
mergedProperties.putAll(configProperties);
+
VeniceServerWrapper veniceServerWrapper = ServiceFactory.getVeniceServer(
options.getRegionName(),
getClusterName(),
File 6 of 7
@@ -36,6 +36,7 @@
import static com.linkedin.venice.ConfigKeys.SERVER_SSL_HANDSHAKE_THREAD_POOL_SIZE;
import static com.linkedin.venice.ConfigKeys.SYSTEM_SCHEMA_CLUSTER_NAME;
import static com.linkedin.venice.ConfigKeys.SYSTEM_SCHEMA_INITIALIZATION_AT_START_TIME_ENABLED;
+import static com.linkedin.venice.integration.utils.VeniceTwoLayerMultiRegionMultiClusterWrapper.addKafkaClusterIDMappingToServerConfigs;
import static com.linkedin.venice.meta.PersistenceType.ROCKS_DB;

import com.linkedin.davinci.config.VeniceConfigLoader;
@@ -298,8 +299,17 @@ static StatefulServiceProvider<VeniceServerWrapper> generateService(
List<ServiceDiscoveryAnnouncer> d2Servers =
new ArrayList<>(D2TestUtils.getD2Servers(zkAddress, d2ClusterName, httpURI, httpsURI));

+    Map<String, Map<String, String>> finalKafkaClusterMap = kafkaClusterMap;
+    if (finalKafkaClusterMap == null || finalKafkaClusterMap.isEmpty()) {
+      finalKafkaClusterMap = addKafkaClusterIDMappingToServerConfigs(
+          Optional.ofNullable(serverProps.toProperties()),
+          Collections.singletonList(regionName),
+          Arrays.asList(pubSubBrokerWrapper, pubSubBrokerWrapper));
+      LOGGER.info("PubSub cluster map was not provided. Constructed the following map: {}", finalKafkaClusterMap);
+    }
+
// generate the kafka cluster map in config directory
-    VeniceConfigLoader.storeKafkaClusterMap(configDirectory, kafkaClusterMap);
+    VeniceConfigLoader.storeKafkaClusterMap(configDirectory, finalKafkaClusterMap);

if (!forkServer) {
VeniceConfigLoader veniceConfigLoader =
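For orientation, the kafkaClusterMap handled above is a nested Map<String, Map<String, String>> keyed by PubSub cluster ID, from which the server derives the URL-to-ID map that the new consumerSubscribe validation consults. A hedged sketch of building one entry follows; the inner "name" and "url" keys and the sample values are illustrative assumptions, not taken from this diff.

    import java.util.HashMap;
    import java.util.Map;

    public class KafkaClusterMapSketch {
      public static void main(String[] args) {
        // One entry per PubSub cluster, keyed by its numeric ID as a string.
        Map<String, Map<String, String>> kafkaClusterMap = new HashMap<>();

        Map<String, String> cluster0 = new HashMap<>();
        cluster0.put("name", "dc-0");          // assumed inner key: region/cluster name
        cluster0.put("url", "localhost:9092"); // assumed inner key: broker address
        kafkaClusterMap.put("0", cluster0);

        // A map like this is what the test wrappers persist via
        // VeniceConfigLoader.storeKafkaClusterMap in the hunk above.
        System.out.println(kafkaClusterMap);
      }
    }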
File 7 of 7
@@ -252,7 +252,7 @@ static ServiceProvider<VeniceTwoLayerMultiRegionMultiClusterWrapper> generateSer
}
}

-  private static Map<String, Map<String, String>> addKafkaClusterIDMappingToServerConfigs(
+  public static Map<String, Map<String, String>> addKafkaClusterIDMappingToServerConfigs(
Optional<Properties> serverProperties,
List<String> regionNames,
List<PubSubBrokerWrapper> kafkaBrokers) {
