Skip to content

Commit

Permalink
fix: Improve MCP server stability and test reliability
Browse files Browse the repository at this point in the history
- Replace bounded elastic schedulers with single thread executors
- Add timeouts to MCP server tests
- Enhance error handling in StdioServerTransport
- Comment out unused PaginatedRequest unmarshalling
- Simplify shutdown logic in StdioServerTransport
  • Loading branch information
tzolov committed Dec 30, 2024
1 parent fcf0ff0 commit ad3cdd2
Show file tree
Hide file tree
Showing 9 changed files with 50 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -349,9 +349,9 @@ private DefaultMcpSession.RequestHandler toolsCallRequestHandler() {

private DefaultMcpSession.RequestHandler resourcesListRequestHandler() {
return params -> {
McpSchema.PaginatedRequest request = transport.unmarshalFrom(params,
new TypeReference<McpSchema.PaginatedRequest>() {
});
// McpSchema.PaginatedRequest request = transport.unmarshalFrom(params,
// new TypeReference<McpSchema.PaginatedRequest>() {
// });

var resourceList = this.resources.values().stream().map(ResourceRegistration::resource).toList();

Expand All @@ -361,9 +361,9 @@ private DefaultMcpSession.RequestHandler resourcesListRequestHandler() {

private DefaultMcpSession.RequestHandler resourceTemplateListRequestHandler() {
return params -> {
McpSchema.PaginatedRequest request = transport.unmarshalFrom(params,
new TypeReference<McpSchema.PaginatedRequest>() {
});
// McpSchema.PaginatedRequest request = transport.unmarshalFrom(params,
// new TypeReference<McpSchema.PaginatedRequest>() {
// });

return Mono.just(new McpSchema.ListResourceTemplatesResult(this.resourceTemplates, null));
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.concurrent.Executors;
import java.util.function.Function;

import com.fasterxml.jackson.core.type.TypeReference;
Expand Down Expand Up @@ -97,8 +98,8 @@ public StdioServerTransport(ObjectMapper objectMapper) {
this.outputStream = System.out;

// Use bounded schedulers for better resource management
this.inboundScheduler = Schedulers.newBoundedElastic(1, 1, "inbound");
this.outboundScheduler = Schedulers.newBoundedElastic(1, 1, "outbound");
this.inboundScheduler = Schedulers.fromExecutorService(Executors.newSingleThreadExecutor(), "inbound");
this.outboundScheduler = Schedulers.fromExecutorService(Executors.newSingleThreadExecutor(), "outbound");
}

@Override
Expand Down Expand Up @@ -139,38 +140,43 @@ public Mono<Void> sendMessage(JSONRPCMessage message) {
private void startInboundProcessing() {
this.inboundScheduler.schedule(() -> {
inboundReady.tryEmitValue(null);
try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream))) {
String line;
while (!isClosing && (line = reader.readLine()) != null) {
BufferedReader reader = null;
try {
reader = new BufferedReader(new InputStreamReader(inputStream));
while (!isClosing) {
try {
JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.objectMapper, line);
if (!this.inboundSink.tryEmitNext(message).isSuccess()) {
String line = reader.readLine();
if (line == null || isClosing) {
break;
}

try {
JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.objectMapper, line);
if (!this.inboundSink.tryEmitNext(message).isSuccess()) {
if (!isClosing) {
logger.error("Failed to enqueue message");
}
break;
}
}
catch (Exception e) {
if (!isClosing) {
logger.error("Failed to enqueue message");
logger.error("Error processing inbound message", e);
}
break;
}
}
catch (Exception e) {
catch (IOException e) {
if (!isClosing) {
logger.error("Error processing inbound message", e);
logger.error("Error reading from stdin", e);
}
break;
}
}
}
catch (IOException e) {
// Check isClosing before the error occurs to properly categorize it
boolean wasClosing = isClosing;
isClosing = true;
if (!wasClosing && e.getMessage().equals("Pipe closed")) {
logger.debug("Stream closed during shutdown", e);
}
else if (!wasClosing) {
logger.error("Error reading from stdin", e);
}
else {
logger.debug("Stream error during shutdown", e);
catch (Exception e) {
if (!isClosing) {
logger.error("Error in inbound processing", e);
}
}
finally {
Expand Down Expand Up @@ -234,6 +240,7 @@ else if (isClosing) {

@Override
public Mono<Void> closeGracefully() {

return Mono.fromRunnable(() -> {
isClosing = true;
logger.debug("Initiating graceful shutdown");
Expand All @@ -244,18 +251,10 @@ public Mono<Void> closeGracefully() {
return Mono.delay(Duration.ofMillis(100));
})).then(Mono.fromRunnable(() -> {
try {
// Dispose schedulers first
// Dispose schedulers with longer timeout
inboundScheduler.dispose();
outboundScheduler.dispose();

// Wait for schedulers to terminate
if (!inboundScheduler.isDisposed()) {
inboundScheduler.disposeGracefully().block(Duration.ofSeconds(5));
}
if (!outboundScheduler.isDisposed()) {
outboundScheduler.disposeGracefully().block(Duration.ofSeconds(5));
}

logger.info("Graceful shutdown completed");
}
catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ protected void onClose() {

@BeforeEach
void setUp() {
// onStart();
}

@AfterEach
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
package org.springframework.ai.mcp.server;

import com.fasterxml.jackson.databind.ObjectMapper;
import org.junit.jupiter.api.Timeout;

import org.springframework.http.server.reactive.HttpHandler;
import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter;
import org.springframework.web.reactive.function.server.RouterFunctions;
Expand All @@ -31,6 +33,7 @@
*
* @author Christian Tzolov
*/
@Timeout(15) // Giving extra time beyond the client timeout
class SseMcpAsyncServerTests extends AbstractMcpAsyncServerTests {

private static final int PORT = 8181;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
package org.springframework.ai.mcp.server;

import com.fasterxml.jackson.databind.ObjectMapper;
import org.junit.jupiter.api.Timeout;

import org.springframework.http.server.reactive.HttpHandler;
import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter;
import org.springframework.web.reactive.function.server.RouterFunctions;
Expand All @@ -31,6 +33,7 @@
*
* @author Christian Tzolov
*/
@Timeout(15) // Giving extra time beyond the client timeout
class SseMcpSyncServerTests extends AbstractMcpSyncServerTests {

private static final int PORT = 8182;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package org.springframework.ai.mcp.server;

import org.junit.jupiter.api.Timeout;

import org.springframework.ai.mcp.server.transport.StdioServerTransport;
import org.springframework.ai.mcp.spec.McpTransport;

Expand All @@ -24,6 +26,7 @@
*
* @author Christian Tzolov
*/
@Timeout(15) // Giving extra time beyond the client timeout
class StdioMcpAsyncServerTests extends AbstractMcpAsyncServerTests {

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package org.springframework.ai.mcp.server;

import org.junit.jupiter.api.Timeout;

import org.springframework.ai.mcp.server.transport.StdioServerTransport;
import org.springframework.ai.mcp.spec.McpTransport;

Expand All @@ -24,6 +26,7 @@
*
* @author Christian Tzolov
*/
@Timeout(15) // Giving extra time beyond the client timeout
class StdioMcpSyncServerTests extends AbstractMcpSyncServerTests {

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

package org.springframework.ai.mcp.server.transport;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.InputStream;
import java.io.PrintStream;
Expand Down Expand Up @@ -51,8 +50,6 @@ class StdioServerTransportTests {

private ByteArrayOutputStream testOut;

private ByteArrayInputStream testIn;

private ByteArrayOutputStream testErr;

private PrintStream testOutPrintStream;
Expand Down Expand Up @@ -89,7 +86,6 @@ void tearDown() {
void shouldHandleIncomingMessages() throws Exception {
// Prepare test input
String jsonMessage = "{\"jsonrpc\":\"2.0\",\"method\":\"test\",\"params\":{},\"id\":1}";
testIn = new ByteArrayInputStream((jsonMessage + "\n").getBytes(StandardCharsets.UTF_8));

// Create transport with test streams
transport = new StdioServerTransport(objectMapper);
Expand Down
1 change: 1 addition & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@

<!-- Output test execution times in the logs -->
<redirectTestOutputToFile>false</redirectTestOutputToFile>

</configuration>
</plugin>
<plugin>
Expand Down

0 comments on commit ad3cdd2

Please sign in to comment.