diff --git a/cdap-app-fabric/src/main/java/io/cdap/cdap/internal/app/worker/system/SystemWorkerHttpHandlerInternal.java b/cdap-app-fabric/src/main/java/io/cdap/cdap/internal/app/worker/system/SystemWorkerHttpHandlerInternal.java index e00480769dae..21cb75aae0b8 100644 --- a/cdap-app-fabric/src/main/java/io/cdap/cdap/internal/app/worker/system/SystemWorkerHttpHandlerInternal.java +++ b/cdap-app-fabric/src/main/java/io/cdap/cdap/internal/app/worker/system/SystemWorkerHttpHandlerInternal.java @@ -100,6 +100,7 @@ public void run(FullHttpRequest request, HttpResponder responder) { if (requestProcessedCount.incrementAndGet() > requestLimit) { responder.sendStatus(HttpResponseStatus.TOO_MANY_REQUESTS); + requestProcessedCount.decrementAndGet(); return; } diff --git a/cdap-app-fabric/src/test/java/io/cdap/cdap/internal/app/worker/system/SystemWorkerServiceTest.java b/cdap-app-fabric/src/test/java/io/cdap/cdap/internal/app/worker/system/SystemWorkerServiceTest.java index 73d3612372b0..7238db9e104a 100644 --- a/cdap-app-fabric/src/test/java/io/cdap/cdap/internal/app/worker/system/SystemWorkerServiceTest.java +++ b/cdap-app-fabric/src/test/java/io/cdap/cdap/internal/app/worker/system/SystemWorkerServiceTest.java @@ -198,7 +198,7 @@ public void testValidConcurrentRequests() throws Exception { == HttpResponseStatus.OK.code()) { okResponse++; } else if (responses.get(i).get().getResponseCode() - == HttpResponseStatus.TOO_MANY_REQUESTS.code()) { + == HttpResponseStatus.TOO_MANY_REQUESTS.code()) { conflictResponse++; } } @@ -206,6 +206,50 @@ public void testValidConcurrentRequests() throws Exception { Assert.assertEquals(concurrentRequests, okResponse + conflictResponse); } + @Test + public void testRepeatedConcurrentRequests() throws Exception { + InetSocketAddress addr = systemWorkerService.getBindAddress(); + URI uri = URI.create( + String.format("http://%s:%s", addr.getHostName(), addr.getPort())); + + RunnableTaskRequest request = RunnableTaskRequest.getBuilder( + SystemWorkerServiceTest.TestRunnableClass.class.getName()) + .withParam("500").build(); + + String reqBody = GSON.toJson(request); + List> calls = new ArrayList<>(); + int concurrentRequests = 6; + + for (int i = 0; i < concurrentRequests; i++) { + calls.add(() -> HttpRequests.execute( + HttpRequest.post(uri.resolve("/v3Internal/system/run").toURL()) + .withBody(reqBody).build(), new DefaultHttpRequestConfig(false))); + } + + Executors.newFixedThreadPool(concurrentRequests).invokeAll(calls); + + // Wait for requests to complete and retry them. + Thread.sleep(1000); + + int okResponse = 0; + int conflictResponse = 0; + List> responses = Executors.newFixedThreadPool( + concurrentRequests).invokeAll(calls); + for (int i = 0; i < concurrentRequests; i++) { + if (responses.get(i).get().getResponseCode() + == HttpResponseStatus.OK.code()) { + okResponse++; + } else if (responses.get(i).get().getResponseCode() + == HttpResponseStatus.TOO_MANY_REQUESTS.code()) { + conflictResponse++; + } + } + + Assert.assertEquals(5, okResponse); + Assert.assertEquals(concurrentRequests, okResponse + conflictResponse); + } + + @Test public void testInvalidConcurrentRequests() throws Exception { InetSocketAddress addr = systemWorkerService.getBindAddress();