Skip to content

Commit

Permalink
refactor: Optimize timeout handling in TimeoutSubscriber
Browse files Browse the repository at this point in the history
  • Loading branch information
jeong-yong-shin authored and jeong-yong-shin committed Jun 24, 2024
1 parent 2d6d00f commit 7c52cef
Show file tree
Hide file tree
Showing 2 changed files with 263 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,70 +29,88 @@
import io.netty.util.concurrent.ScheduledFuture;

public class TimeoutSubscriber<T> implements Subscriber<T> {

private static final String TIMEOUT_MESSAGE = "Stream timed out after %d ms (timeout mode: %s)";
private final Subscriber<? super T> delegate;
private final EventExecutor executor;

private final StreamTimeoutMode timeoutMode;
private final Duration timeoutDuration;
private ScheduledFuture<?> timeoutFuture;

private Runnable timeoutTask;
private Subscription subscription;
private final Duration timeoutDuration;
private long timeoutNanos;
private long lastOnNextTimeNanos;

public TimeoutSubscriber(Subscriber<? super T> delegate, EventExecutor executor, Duration timeoutDuration, StreamTimeoutMode timeoutMode) {
this.delegate = requireNonNull(delegate, "delegate");
this.executor = requireNonNull(executor, "executor");
this.timeoutDuration = requireNonNull(timeoutDuration, "timeoutDuration");
timeoutNanos = timeoutDuration.toNanos();
this.timeoutMode = requireNonNull(timeoutMode, "timeoutMode");
timeoutTask = createTimeoutTask();
}

private ScheduledFuture<?> scheduleTimeout() {
return executor.schedule(() -> {
private Runnable createTimeoutTask() {
return () -> {
if(timeoutMode == StreamTimeoutMode.UNTIL_NEXT) {
long currentTimeNanos = System.nanoTime();
long elapsedNanos = currentTimeNanos - lastOnNextTimeNanos;

if(elapsedNanos <= timeoutNanos) {
long delayNanos = timeoutNanos - (currentTimeNanos - lastOnNextTimeNanos);
timeoutFuture = createTimeoutSchedule(delayNanos, TimeUnit.NANOSECONDS);
return;
}
}
subscription.cancel();
delegate.onError(new TimeoutException(
String.format("Stream timed out after %d ms (timeout mode: %s)", timeoutDuration.toMillis(), timeoutMode)));
}, timeoutDuration.toMillis(), TimeUnit.MILLISECONDS);
delegate.onError(new TimeoutException(String.format(TIMEOUT_MESSAGE, timeoutDuration.toMillis(), timeoutMode)));
};
}

private ScheduledFuture<?> createTimeoutSchedule(long delay, TimeUnit unit) {
return executor.schedule(timeoutTask, delay, unit);
}

private void cancelSchedule() {
if(!timeoutFuture.isCancelled()) {
timeoutFuture.cancel(false);
}
}

@Override
public void onSubscribe(Subscription s) {
delegate.onSubscribe(s);
subscription = s;
timeoutFuture = scheduleTimeout();
lastOnNextTimeNanos = System.nanoTime();
timeoutFuture = createTimeoutSchedule(timeoutNanos, TimeUnit.NANOSECONDS);
delegate.onSubscribe(s);
}

@Override
public void onNext(T t) {
delegate.onNext(t);
if (timeoutFuture.isCancelled()) {
return;
}
switch (timeoutMode) {
case UNTIL_NEXT:
timeoutFuture.cancel(false);
timeoutFuture = scheduleTimeout();
lastOnNextTimeNanos = System.nanoTime();
break;
case UNTIL_FIRST:
timeoutFuture.cancel(false);
break;
case UNTIL_EOS:
break;
}
delegate.onNext(t);
}

@Override
public void onError(Throwable throwable) {
if(!timeoutFuture.isCancelled()) {
timeoutFuture.cancel(false);
}
cancelSchedule();
delegate.onError(throwable);
}

@Override
public void onComplete() {
if(!timeoutFuture.isCancelled()) {
timeoutFuture.cancel(false);
}
cancelSchedule();
delegate.onComplete();
}
}
Original file line number Diff line number Diff line change
@@ -1,2 +1,225 @@
package com.linecorp.armeria.common.stream;public class TimeoutStreamMessageTest {
}
package com.linecorp.armeria.common.stream;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

import java.time.Duration;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;

import io.netty.util.concurrent.DefaultEventExecutor;
import io.netty.util.concurrent.EventExecutor;

public class TimeoutStreamMessageTest {
private EventExecutor executor;

@BeforeEach
public void setUp() {
executor = new DefaultEventExecutor();
}

@AfterEach
public void tearDown() {
executor.shutdownGracefully();
}


@Test
public void timeoutNextMode() {
StreamMessage<String> timeoutStreamMessage = StreamMessage.of("message1", "message2").timeout(
Duration.ofSeconds(1), StreamTimeoutMode.UNTIL_NEXT);
CompletableFuture<Void> future = new CompletableFuture<>();

timeoutStreamMessage.subscribe(new Subscriber<String>() {
private Subscription subscription;
@Override
public void onSubscribe(Subscription s) {
subscription = s;
subscription.request(1);
}

@Override
public void onNext(String s) {
executor.schedule(() -> subscription.request(1), 2, TimeUnit.SECONDS);
}

@Override
public void onError(Throwable throwable) {
future.completeExceptionally(throwable);
}

@Override
public void onComplete() {
future.complete(null);
}
}, executor);

assertThatThrownBy(future::get)
.isInstanceOf(ExecutionException.class)
.hasCauseInstanceOf(TimeoutException.class);
}

@Test
public void noTimeoutNextMode() throws Exception {
StreamMessage<String> timeoutStreamMessage = StreamMessage.of("message1", "message2").timeout(Duration.ofSeconds(1), StreamTimeoutMode.UNTIL_NEXT);

CompletableFuture<Void> future = new CompletableFuture<>();

timeoutStreamMessage.subscribe(new Subscriber<String>() {
@Override
public void onSubscribe(Subscription subscription) {
subscription.request(2);
}

@Override
public void onNext(String s) {
}

@Override
public void onError(Throwable throwable) {
future.completeExceptionally(throwable);
}

@Override
public void onComplete() {
future.complete(null);
}
}, executor);

assertThat(future.get()).isNull();
}

@Test
public void timeoutFirstMode() {
StreamMessage<String> timeoutStreamMessage = StreamMessage.of("message1", "message2").timeout(Duration.ofSeconds(1), StreamTimeoutMode.UNTIL_FIRST);
CompletableFuture<Void> future = new CompletableFuture<>();

timeoutStreamMessage.subscribe(new Subscriber<String>() {
private Subscription subscription;
@Override
public void onSubscribe(Subscription s) {
subscription = s;
executor.schedule(() -> subscription.request(1), 2, TimeUnit.SECONDS);
}

@Override
public void onNext(String s) {
subscription.request(1);
}

@Override
public void onError(Throwable throwable) {
future.completeExceptionally(throwable);
}

@Override
public void onComplete() {
future.complete(null);
}
}, executor);

assertThatThrownBy(future::get)
.isInstanceOf(ExecutionException.class)
.hasCauseInstanceOf(TimeoutException.class);
}

@Test
public void noTimeoutModeFirst() throws Exception {
StreamMessage<String> timeoutStreamMessage = StreamMessage.of("message1", "message2").timeout(Duration.ofSeconds(1), StreamTimeoutMode.UNTIL_FIRST);
CompletableFuture<Void> future = new CompletableFuture<>();

timeoutStreamMessage.subscribe(new Subscriber<String>() {
@Override
public void onSubscribe(Subscription subscription) {
subscription.request(2);
}

@Override
public void onNext(String s) {
}

@Override
public void onError(Throwable throwable) {
future.completeExceptionally(throwable);
}

@Override
public void onComplete() {
future.complete(null);
}
}, executor);

assertThat(future.get()).isNull();
}

@Test
public void timeoutEOSMode() {
StreamMessage<String> timeoutStreamMessage = StreamMessage.of("message1", "message2").timeout(Duration.ofSeconds(2), StreamTimeoutMode.UNTIL_EOS);
CompletableFuture<Void> future = new CompletableFuture<>();

timeoutStreamMessage.subscribe(new Subscriber<String>() {
private Subscription subscription;
@Override
public void onSubscribe(Subscription s) {
subscription = s;
executor.schedule(() -> subscription.request(1), 1, TimeUnit.SECONDS);
}

@Override
public void onNext(String s) {
executor.schedule(() -> subscription.request(1), 2, TimeUnit.SECONDS);
}

@Override
public void onError(Throwable throwable) {
future.completeExceptionally(throwable);
}

@Override
public void onComplete() {
future.complete(null);
}
}, executor);

assertThatThrownBy(future::get)
.isInstanceOf(ExecutionException.class)
.hasCauseInstanceOf(TimeoutException.class);
}

@Test
public void noTimeoutEOSMode() throws Exception {
StreamMessage<String> timeoutStreamMessage = StreamMessage.of("message1", "message2").timeout(Duration.ofSeconds(2), StreamTimeoutMode.UNTIL_EOS);
CompletableFuture<Void> future = new CompletableFuture<>();

timeoutStreamMessage.subscribe(new Subscriber<String>() {
@Override
public void onSubscribe(Subscription subscription) {
subscription.request(2);
}

@Override
public void onNext(String s) {
}

@Override
public void onError(Throwable throwable) {
future.completeExceptionally(throwable);
}

@Override
public void onComplete() {
future.complete(null);
}
}, executor);

assertThat(future.get()).isNull();
}
}

0 comments on commit 7c52cef

Please sign in to comment.