Skip to content

Commit

Permalink
Merge pull request cdapio#15391 from cdapio/CDAP-20868-task-worker-co…
Browse files Browse the repository at this point in the history
…ncurrent-requests

[CDAP-20868] Allow task workers to run concurrent requests when configured
  • Loading branch information
arjan-bal authored Nov 6, 2023
2 parents f878680 + 8b6c4d6 commit 9bd05f5
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@
import io.cdap.cdap.api.service.worker.RunnableTaskRequest;
import io.cdap.cdap.common.conf.CConfiguration;
import io.cdap.cdap.common.conf.Constants;
import io.cdap.cdap.common.conf.Constants.TaskWorker;
import io.cdap.cdap.common.conf.SConfiguration;
import io.cdap.cdap.common.http.CommonNettyHttpServiceFactory;
import io.cdap.cdap.common.http.DefaultHttpRequestConfig;
import io.cdap.cdap.common.metrics.NoOpMetricsCollectionService;
import io.cdap.cdap.common.utils.Tasks;
import io.cdap.cdap.proto.BasicThrowable;
import io.cdap.common.http.HttpRequest;
import io.cdap.common.http.HttpRequests;
Expand All @@ -44,15 +46,15 @@
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import org.apache.twill.discovery.InMemoryDiscoveryService;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.ClassRule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* Unit test for {@link TaskWorkerService}.
Expand All @@ -61,7 +63,6 @@ public class TaskWorkerServiceTest {
@ClassRule
public static final TemporaryFolder TEMP_FOLDER = new TemporaryFolder();

private static final Logger LOG = LoggerFactory.getLogger(TaskWorkerServiceTest.class);
private static final Gson GSON = new Gson();
private static final MetricsCollectionService metricsCollectionService = new NoOpMetricsCollectionService();

Expand Down Expand Up @@ -284,12 +285,14 @@ public void testStartAndStopWithInvalidRequest() throws Exception {
}

@Test
public void testConcurrentRequests() throws Exception {
public void testConcurrentRequestsWithIsolationEnabled() throws Exception {
InetSocketAddress addr = taskWorkerService.getBindAddress();
URI uri = URI.create(String.format("http://%s:%s", addr.getHostName(), addr.getPort()));
URI uri = URI.create(
String.format("http://%s:%s", addr.getHostName(), addr.getPort()));

RunnableTaskRequest request = RunnableTaskRequest.getBuilder(TestRunnableClass.class.getName())
.withParam("1000").withNamespace("testNamespace").build();
RunnableTaskRequest request = RunnableTaskRequest.getBuilder(
TestRunnableClass.class.getName())
.withParam("1000").withNamespace("testNamespace").build();

String reqBody = GSON.toJson(request);
List<Callable<HttpResponse>> calls = new ArrayList<>();
Expand All @@ -310,7 +313,8 @@ public void testConcurrentRequests() throws Exception {
for (int i = 0; i < concurrentRequests; i++) {
if (responses.get(i).get().getResponseCode() == HttpResponseStatus.OK.code()) {
okResponse++;
} else if (responses.get(i).get().getResponseCode() == HttpResponseStatus.TOO_MANY_REQUESTS.code()) {
} else if (responses.get(i).get().getResponseCode()
== HttpResponseStatus.TOO_MANY_REQUESTS.code()) {
conflictResponse++;
}
}
Expand All @@ -320,7 +324,67 @@ public void testConcurrentRequests() throws Exception {
Assert.assertEquals(Service.State.TERMINATED, taskWorkerService.state());
}

/**
 * Verifies that, with user code isolation disabled, the task worker accepts up to
 * {@code TaskWorker.REQUEST_LIMIT} concurrent requests, rejects the excess with
 * {@code 429 TOO_MANY_REQUESTS}, and keeps running after the requests complete.
 *
 * <p>The service and the client thread pool created by this test are always released in a
 * {@code finally} block so a failing assertion does not leak threads or a bound port into
 * other tests.
 */
@Test
public void testConcurrentRequestsWithIsolationDisabled() throws Exception {
  final int requestLimit = 2;
  final int concurrentRequests = 3;

  CConfiguration cConf = createCConf();
  cConf.setInt(TaskWorker.REQUEST_LIMIT, requestLimit);
  cConf.setBoolean(TaskWorker.USER_CODE_ISOLATION_ENABLED, false);
  InMemoryDiscoveryService discoveryService = new InMemoryDiscoveryService();
  // Named differently from the test-class field "taskWorkerService" to avoid shadowing it.
  TaskWorkerService isolationDisabledService = new TaskWorkerService(cConf,
      createSConf(), discoveryService, discoveryService,
      metricsCollectionService,
      new CommonNettyHttpServiceFactory(cConf, metricsCollectionService));
  isolationDisabledService.startAndWait();

  // Fully qualified so that this method does not require an additional import.
  java.util.concurrent.ExecutorService executor =
      Executors.newFixedThreadPool(concurrentRequests);
  try {
    InetSocketAddress addr = isolationDisabledService.getBindAddress();
    URI uri = URI.create(
        String.format("http://%s:%s", addr.getHostName(), addr.getPort()));

    RunnableTaskRequest request = RunnableTaskRequest.getBuilder(
            TestRunnableClass.class.getName())
        .withParam("1000").withNamespace("testNamespace").build();
    String reqBody = GSON.toJson(request);

    List<Callable<HttpResponse>> calls = new ArrayList<>();
    for (int i = 0; i < concurrentRequests; i++) {
      calls.add(
          () -> HttpRequests.execute(
              HttpRequest.post(uri.resolve("/v3Internal/worker/run").toURL())
                  .withBody(reqBody).build(),
              new DefaultHttpRequestConfig(false))
      );
    }

    int okResponse = 0;
    int conflictResponse = 0;
    for (Future<HttpResponse> response : executor.invokeAll(calls)) {
      // Resolve each future once and classify its status code.
      int responseCode = response.get().getResponseCode();
      if (responseCode == HttpResponseStatus.OK.code()) {
        okResponse++;
      } else if (responseCode == HttpResponseStatus.TOO_MANY_REQUESTS.code()) {
        conflictResponse++;
      }
    }

    // Verify that the task worker service doesn't stop automatically: waiting for
    // isRunning() to become false is expected to time out.
    try {
      Tasks.waitFor(false, () -> isolationDisabledService.isRunning(), 1,
          TimeUnit.SECONDS);
      Assert.fail("Task worker service stopped although user code isolation is disabled");
    } catch (TimeoutException e) {
      // Expected: the service is still running.
    }

    Assert.assertEquals(requestLimit, okResponse);
    Assert.assertEquals(concurrentRequests, okResponse + conflictResponse);
  } finally {
    executor.shutdownNow();
    isolationDisabledService.stopAndWait();
  }
  Assert.assertEquals(Service.State.TERMINATED, isolationDisabledService.state());
}

public static class TestRunnableClass implements RunnableTask {

@Override
public void run(RunnableTaskContext context) throws Exception {
if (!context.getParam().equals("")) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,8 @@ public static final class TaskWorker {
"task.worker.container.kill.after.request.count";
public static final String CONTAINER_KILL_AFTER_DURATION_SECOND =
"task.worker.container.kill.after.duration.second";
public static final String REQUEST_LIMIT = "task.worker.request.limit";
public static final String USER_CODE_ISOLATION_ENABLED = "task.worker.request.userCodeIsolation.enabled";
public static final String CONTAINER_RUN_AS_USER = "task.worker.container.run.as.user";
public static final String CONTAINER_RUN_AS_GROUP = "task.worker.container.run.as.group";
public static final String CONTAINER_DISK_READONLY = "task.worker.container.disk.readonly";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import io.cdap.cdap.api.service.worker.RunnableTaskRequest;
import io.cdap.cdap.common.conf.CConfiguration;
import io.cdap.cdap.common.conf.Constants;
import io.cdap.cdap.common.conf.Constants.TaskWorker;
import io.cdap.cdap.common.utils.GcpMetadataTaskContextUtil;
import io.cdap.cdap.proto.BasicThrowable;
import io.cdap.cdap.proto.codec.BasicThrowableCodec;
Expand Down Expand Up @@ -82,12 +83,11 @@ public class TaskWorkerHttpHandlerInternal extends AbstractHttpHandler {
private final RunnableTaskLauncher runnableTaskLauncher;
private final BiConsumer<Boolean, TaskDetails> taskCompletionConsumer;

private final AtomicBoolean hasInflightRequest = new AtomicBoolean(false);

/**
* Holds the total number of requests that have been executed by this handler
* that should count toward max allowed.
*/
private final AtomicInteger runningRequestCount = new AtomicInteger(0);
private final AtomicInteger requestProcessedCount = new AtomicInteger(0);

private final String metadataServiceEndpoint;
Expand All @@ -98,6 +98,7 @@ public class TaskWorkerHttpHandlerInternal extends AbstractHttpHandler {
* If true, the pod will restart once an operation finishes its execution.
*/
private final AtomicBoolean mustRestart = new AtomicBoolean(false);
private final int requestLimit;

/**
* Constructs the {@link TaskWorkerHttpHandlerInternal}.
Expand All @@ -110,37 +111,48 @@ public TaskWorkerHttpHandlerInternal(CConfiguration cConf,
final int killAfterRequestCount = cConf.getInt(
Constants.TaskWorker.CONTAINER_KILL_AFTER_REQUEST_COUNT, 0);
this.runnableTaskLauncher = new RunnableTaskLauncher(cConf,
discoveryService,
discoveryServiceClient, metricsCollectionService);
discoveryService, discoveryServiceClient, metricsCollectionService);
this.metricsCollectionService = metricsCollectionService;
this.metadataServiceEndpoint = cConf.get(
Constants.TaskWorker.METADATA_SERVICE_END_POINT);
this.taskCompletionConsumer = (succeeded, taskDetails) -> {
taskDetails.emitMetrics(succeeded);
boolean enableUserCodeIsolationEnabled = cConf.getBoolean(
TaskWorker.USER_CODE_ISOLATION_ENABLED);
if (enableUserCodeIsolationEnabled) {
// Run only one request at a time in user code isolation mode.
this.requestLimit = 1;
// Restart the service to clean up and re-claim resources after user code
// execution.
this.taskCompletionConsumer = (succeeded, taskDetails) -> {
taskDetails.emitMetrics(succeeded);
runningRequestCount.decrementAndGet();
requestProcessedCount.incrementAndGet();

String className = taskDetails.getClassName();
String className = taskDetails.getClassName();

if (mustRestart.get()) {
stopper.accept(className);
return;
}
if (mustRestart.get()) {
stopper.accept(className);
return;
}

if (!taskDetails.isTerminateOnComplete() || className == null
|| killAfterRequestCount <= 0) {
// No need to restart.
requestProcessedCount.decrementAndGet();
hasInflightRequest.set(false);
return;
}
if (!taskDetails.isTerminateOnComplete() || className == null
|| killAfterRequestCount <= 0) {
// No need to restart.
return;
}

if (requestProcessedCount.get() >= killAfterRequestCount) {
stopper.accept(className);
} else {
hasInflightRequest.set(false);
}
};
if (requestProcessedCount.get() >= killAfterRequestCount) {
stopper.accept(className);
}
};

enablePeriodicRestart(cConf, stopper);
enablePeriodicRestart(cConf, stopper);
} else {
this.requestLimit = cConf.getInt(TaskWorker.REQUEST_LIMIT);
this.taskCompletionConsumer = (succeeded, taskDetails) -> {
taskDetails.emitMetrics(succeeded);
runningRequestCount.decrementAndGet();
};
}
}

/**
Expand Down Expand Up @@ -170,10 +182,10 @@ private void enablePeriodicRestart(CConfiguration cConf,
stopper.accept("");
return;
}
// we restart once ongoing request (which has set hasInflightRequest to true)
// we restart once ongoing request (which has set runningRequestCount to 1)
// finishes.
mustRestart.set(true);
if (hasInflightRequest.compareAndSet(false, true)) {
if (runningRequestCount.compareAndSet(0, 1)) {
// there is no ongoing request. pod gets restarted.
stopper.accept("");
}
Expand All @@ -193,11 +205,11 @@ private void enablePeriodicRestart(CConfiguration cConf,
@POST
@Path("/run")
public void run(FullHttpRequest request, HttpResponder responder) {
if (!hasInflightRequest.compareAndSet(false, true)) {
if (runningRequestCount.incrementAndGet() > requestLimit) {
responder.sendStatus(HttpResponseStatus.TOO_MANY_REQUESTS);
runningRequestCount.decrementAndGet();
return;
}
requestProcessedCount.incrementAndGet();

long startTime = System.currentTimeMillis();
try {
Expand Down Expand Up @@ -303,15 +315,15 @@ private String exceptionToJson(Exception ex) {
private static class RunnableTaskBodyProducer extends BodyProducer {

private final RunnableTaskContext context;
private final BiConsumer<Boolean, TaskDetails> stopper;
private final BiConsumer<Boolean, TaskDetails> taskCompletionConsumer;
private final TaskDetails taskDetails;
private boolean done;

RunnableTaskBodyProducer(RunnableTaskContext context,
BiConsumer<Boolean, TaskDetails> stopper,
BiConsumer<Boolean, TaskDetails> taskCompletionConsumer,
TaskDetails taskDetails) {
this.context = context;
this.stopper = stopper;
this.taskCompletionConsumer = taskCompletionConsumer;
this.taskDetails = taskDetails;
}

Expand All @@ -328,14 +340,14 @@ public ByteBuf nextChunk() {
@Override
public void finished() {
context.executeCleanupTask();
stopper.accept(true, taskDetails);
taskCompletionConsumer.accept(true, taskDetails);
}

@Override
public void handleError(@Nullable Throwable cause) {
LOG.error("Error when sending chunks", cause);
context.executeCleanupTask();
stopper.accept(false, taskDetails);
taskCompletionConsumer.accept(false, taskDetails);
}
}
}
17 changes: 17 additions & 0 deletions cdap-common/src/main/resources/cdap-default.xml
Original file line number Diff line number Diff line change
Expand Up @@ -5251,6 +5251,23 @@
</description>
</property>

<property>
<name>task.worker.request.limit</name>
<value>10</value>
<description>
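Maximum number of concurrent requests accepted by a task worker pod. Only applies when user code isolation is disabled; with isolation enabled, one request runs at a time.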
</description>
</property>

<property>
<name>task.worker.request.userCodeIsolation.enabled</name>
<value>true</value>
<description>
Whether user code isolation is enabled in task worker. When enabled, task workers will ensure
multiple requests that run user code are not executed concurrently.
</description>
</property>

<property>
<name>task.worker.bind.address</name>
<value>0.0.0.0</value>
Expand Down

0 comments on commit 9bd05f5

Please sign in to comment.