diff --git a/integration-tests/src/test/java/com/linkedin/kafka/clients/producer/LiKafkaInstrumentedProducerIntegrationTest.java b/integration-tests/src/test/java/com/linkedin/kafka/clients/producer/LiKafkaInstrumentedProducerIntegrationTest.java index 9d57fb1..2c17160 100644 --- a/integration-tests/src/test/java/com/linkedin/kafka/clients/producer/LiKafkaInstrumentedProducerIntegrationTest.java +++ b/integration-tests/src/test/java/com/linkedin/kafka/clients/producer/LiKafkaInstrumentedProducerIntegrationTest.java @@ -17,9 +17,12 @@ import java.util.Collections; import java.util.Properties; import java.util.Random; +import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import org.apache.kafka.clients.admin.AdminClient; import org.apache.kafka.clients.admin.NewTopic; +import org.apache.kafka.clients.producer.Callback; import org.apache.kafka.clients.producer.Producer; import org.apache.kafka.clients.producer.ProducerConfig; import org.apache.kafka.clients.producer.ProducerRecord; @@ -102,6 +105,78 @@ record = new ProducerRecord<>(topic, 0, key, value); mario.close(); } + @Test + public void testCloseFromProduceCallbackOnSenderThread() throws Exception { + String topic = "testCloseFromProduceCallbackOnSenderThread"; + createTopic(topic, 1); + + Random random = new Random(666); + Properties extra = new Properties(); + extra.setProperty(ProducerConfig.MAX_REQUEST_SIZE_CONFIG, "" + 50000000); //~50MB (larger than broker-size setting) + extra.setProperty(ProducerConfig.ACKS_CONFIG, "-1"); + extra.setProperty(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, ByteArraySerializer.class.getCanonicalName()); + extra.setProperty(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, ByteArraySerializer.class.getCanonicalName()); + Properties baseProducerConfig = getProducerProperties(extra); + LiKafkaInstrumentedProducerImpl producer = new LiKafkaInstrumentedProducerImpl( + baseProducerConfig, + Collections.emptyMap(), + (baseConfig, overrideConfig) -> new LiKafkaProducerImpl(LiKafkaClientsUtils.getConsolidatedProperties(baseConfig, overrideConfig)), + () -> "bogus", + 10 //dont wait for a mario connection + ); + + byte[] key = new byte[3000]; + byte[] value = new byte[49000000]; + random.nextBytes(key); + random.nextBytes(value); //random data is incompressible, making sure our request is large + ProducerRecord record = new ProducerRecord<>(topic, key, value); + + AtomicReference issueRef = new AtomicReference<>(); + Thread testThread = new Thread(new Runnable() { + @Override + public void run() { + try { + final Thread ourThread = Thread.currentThread(); + Future future = producer.send(record, new Callback() { + @Override + public void onCompletion(RecordMetadata metadata, Exception exception) { + //we expect a RecordTooLargeException. we also expect this to happen + //on the same thread. + if (Thread.currentThread() != ourThread) { + issueRef.compareAndSet(null, + new IllegalStateException("completion did not happen on caller thread by " + Thread.currentThread().getName()) + ); + } + producer.close(1, TimeUnit.SECONDS); + } + }); + RecordMetadata recordMetadata = future.get(1, TimeUnit.MINUTES); + } catch (Throwable anything) { + issueRef.compareAndSet(null, anything); + } + } + }, "testCloseFromProduceCallbackOnSenderThread-thread"); + testThread.setDaemon(true); + testThread.setUncaughtExceptionHandler(new Thread.UncaughtExceptionHandler() { + @Override + public void uncaughtException(Thread t, Throwable e) { + issueRef.compareAndSet(null, e); + } + }); + testThread.start(); + + testThread.join(TimeUnit.MINUTES.toMillis(1)); + Thread.State state = testThread.getState(); + Assert.assertEquals( + state, + Thread.State.TERMINATED, + "thread was expected to finish, instead its " + state + ); + Throwable issue = issueRef.get(); + Throwable root = Throwables.getRootCause(issue); + Assert.assertTrue(root instanceof RecordTooLargeException, root.getMessage()); + } + private void createTopic(String topicName, int numPartitions) throws Exception { try (AdminClient adminClient = createRawAdminClient(null)) { adminClient.createTopics(Collections.singletonList(new NewTopic(topicName, numPartitions, (short) 1))).all().get(1, TimeUnit.MINUTES); diff --git a/li-apache-kafka-clients/src/main/java/com/linkedin/kafka/clients/producer/LiKafkaInstrumentedProducerImpl.java b/li-apache-kafka-clients/src/main/java/com/linkedin/kafka/clients/producer/LiKafkaInstrumentedProducerImpl.java index 1277c0a..38c4087 100644 --- a/li-apache-kafka-clients/src/main/java/com/linkedin/kafka/clients/producer/LiKafkaInstrumentedProducerImpl.java +++ b/li-apache-kafka-clients/src/main/java/com/linkedin/kafka/clients/producer/LiKafkaInstrumentedProducerImpl.java @@ -24,7 +24,6 @@ import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.locks.ReadWriteLock; import java.util.concurrent.locks.ReentrantReadWriteLock; import java.util.function.Supplier; import org.apache.kafka.clients.consumer.OffsetAndMetadata; @@ -54,7 +53,8 @@ public class LiKafkaInstrumentedProducerImpl implements DelegatingProducer private static final String BOUNDED_FLUSH_THREAD_PREFIX = "Bounded-Flush-Thread-"; private final long initialConnectionTimeoutMs; - private final ReadWriteLock delegateLock = new ReentrantReadWriteLock(); + private final ReentrantReadWriteLock delegateLock = new ReentrantReadWriteLock(); + private final Object closeLock = new Object(); private final Properties baseConfig; private final Map libraryVersions; private final ProducerFactory producerFactory; @@ -198,6 +198,10 @@ public Producer getDelegate() { private boolean recreateDelegate(boolean abortIfExists) { delegateLock.writeLock().lock(); try { + if (isClosed()) { + LOG.debug("this producer has been closed, not creating a new delegate"); + return false; + } Producer prevProducer = delegate; if (prevProducer != null) { if (abortIfExists) { @@ -312,6 +316,8 @@ public Future send(ProducerRecord record) { public Future send(ProducerRecord record, Callback callback) { verifyOpen(); + //the callback may try and obtain a write lock (say call producer.close()) + //so we grab an update lock, which is upgradable to a write lock delegateLock.readLock().lock(); try { return delegate.send(record, callback); @@ -444,15 +450,38 @@ private boolean proceedClosing() { if (isClosed()) { return false; } - delegateLock.writeLock().lock(); - try { + synchronized (closeLock) { if (isClosed()) { - return false; + return false; //someone beat us to it + } + int holds = delegateLock.getReadHoldCount(); //this is for our thread + ReentrantReadWriteLock.ReadLock readLock = delegateLock.readLock(); + ReentrantReadWriteLock.WriteLock writeLock = delegateLock.writeLock(); + if (holds > 0) { //do we own a read lock ? + for (int i = 0; i < holds; i++) { + readLock.unlock(); + } + //at this point we no longer hold a read lock, but any number of other + //readers/writers may slip past us + } + try { + writeLock.lock(); //wait for a write lock + try { + if (isClosed()) { + return false; //some other writer may have beaten us again + } + closedAt = System.currentTimeMillis(); + return true; + } finally { + writeLock.unlock(); + } + } finally { + if (holds > 0) { //restore our read lock holds (if we had any) + for (int i = 0; i < holds; i++) { + readLock.lock(); + } + } } - closedAt = System.currentTimeMillis(); - return true; - } finally { - delegateLock.writeLock().unlock(); } } diff --git a/li-apache-kafka-clients/src/main/java/com/linkedin/kafka/clients/producer/LiKafkaProducerImpl.java b/li-apache-kafka-clients/src/main/java/com/linkedin/kafka/clients/producer/LiKafkaProducerImpl.java index 1c5a281..c90802d 100644 --- a/li-apache-kafka-clients/src/main/java/com/linkedin/kafka/clients/producer/LiKafkaProducerImpl.java +++ b/li-apache-kafka-clients/src/main/java/com/linkedin/kafka/clients/producer/LiKafkaProducerImpl.java @@ -6,6 +6,7 @@ import com.linkedin.kafka.clients.auditing.AuditType; import com.linkedin.kafka.clients.auditing.Auditor; +import com.linkedin.kafka.clients.auditing.NoOpAuditor; import com.linkedin.kafka.clients.largemessage.LargeMessageCallback; import com.linkedin.kafka.clients.largemessage.LargeMessageSegment; import com.linkedin.kafka.clients.largemessage.MessageSplitter; @@ -469,19 +470,23 @@ public void close(long timeout, TimeUnit timeUnit) { long deadlineTimeMs = startTimeMs + budgetMs; _closed = true; - synchronized (_numThreadsInSend) { - long remainingMs = deadlineTimeMs - System.currentTimeMillis(); - while (_numThreadsInSend.get() > 0 && remainingMs > 0) { - try { - _numThreadsInSend.wait(remainingMs); - } catch (InterruptedException e) { - LOG.error("Interrupted when there are still {} sender threads.", _numThreadsInSend.get()); - break; + + //wait for all producing threads to clear the auditor + //if there's a meaningful auditor. + if (!(_auditor instanceof NoOpAuditor)) { + synchronized (_numThreadsInSend) { + long remainingMs = deadlineTimeMs - System.currentTimeMillis(); + while (_numThreadsInSend.get() > 0 && remainingMs > 0) { + try { + _numThreadsInSend.wait(remainingMs); + } catch (InterruptedException e) { + LOG.error("Interrupted when there are still {} sender threads.", _numThreadsInSend.get()); + break; + } + remainingMs = deadlineTimeMs - System.currentTimeMillis(); } - remainingMs = deadlineTimeMs - System.currentTimeMillis(); } } - _auditor.close(Math.max(0, deadlineTimeMs - System.currentTimeMillis()), TimeUnit.MILLISECONDS); _producer.close(Math.max(0, deadlineTimeMs - System.currentTimeMillis()), TimeUnit.MILLISECONDS); LOG.info("LiKafkaProducer shutdown complete in {} millis", (System.currentTimeMillis() - startTimeMs)); diff --git a/li-apache-kafka-clients/src/main/java/com/linkedin/kafka/clients/utils/LiKafkaClientsUtils.java b/li-apache-kafka-clients/src/main/java/com/linkedin/kafka/clients/utils/LiKafkaClientsUtils.java index 4bbbae4..ca11fd9 100644 --- a/li-apache-kafka-clients/src/main/java/com/linkedin/kafka/clients/utils/LiKafkaClientsUtils.java +++ b/li-apache-kafka-clients/src/main/java/com/linkedin/kafka/clients/utils/LiKafkaClientsUtils.java @@ -211,8 +211,8 @@ private static String fishForClientId(Map metrics) /** * Special header keys have a "_" prefix and are managed internally by the clients. - * @param headers - * @return + * @param headers kafka headers object + * @return any "special" headers container in the argument map */ public static Map fetchSpecialHeaders(Headers headers) { Map map = new HashMap<>();