Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import software.amazon.smithy.java.mcp.model.JsonRpcErrorResponse;
import software.amazon.smithy.java.mcp.model.JsonRpcRequest;
import software.amazon.smithy.java.mcp.model.JsonRpcResponse;
import software.amazon.smithy.model.shapes.ShapeType;
import software.amazon.smithy.utils.SmithyUnstableApi;

@SmithyUnstableApi
Expand Down Expand Up @@ -208,7 +207,7 @@ private JsonRpcResponse parseSseResponse(HttpResponse response, JsonRpcRequest r
// This is a notification - convert Document to JsonRpcRequest and forward
JsonRpcRequest notification = jsonDocument.asShape(JsonRpcRequest.builder());
LOG.debug("Received notification from SSE stream: method={}", notification.getMethod());
notifyRequest(notification);
notify(notification);
} else {
// This is a response - convert Document to JsonRpcResponse
finalResponse = jsonDocument.asShape(JsonRpcResponse.builder());
Expand Down Expand Up @@ -236,14 +235,14 @@ private JsonRpcResponse parseSseResponse(HttpResponse response, JsonRpcRequest r
.build();
LOG.debug("Received notification from remaining SSE buffer: method={}",
notification.getMethod());
notifyRequest(notification);
notify(notification);
} else {
JsonRpcResponse message = JsonRpcResponse.builder()
.deserialize(jsonDocument.createDeserializer())
.build();

if (message.getId() == null) {
notifyRequest(JsonRpcRequest.builder()
notify(JsonRpcRequest.builder()
.jsonrpc("2.0")
.method("notifications/unknown")
.build());
Expand Down Expand Up @@ -282,27 +281,6 @@ private JsonRpcResponse parseSseResponse(HttpResponse response, JsonRpcRequest r
}
}

/**
* Determines if a Document represents a notification (has "method" but no "id")
* rather than a response (has "id").
*
* - Responses have an "id" field at the top level
* - Notifications have a "method" field but no "id" field at the top level
*/
private boolean isNotification(Document doc) {
try {
if (!doc.isType(ShapeType.STRUCTURE) && !doc.isType(ShapeType.MAP)) {
return false;
}

// If it has a "method" field but no "id", it's a notification
return doc.getMember("id") == null && doc.getMember("method") != null;
} catch (Exception e) {
LOG.warn("Failed to determine if notification from Document", e);
return false;
}
}

private JsonRpcResponse handleErrorResponse(HttpResponse response) {
long contentLength = response.body().contentLength();
String errorMessage = "HTTP " + response.statusCode();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import software.amazon.smithy.java.mcp.model.ListToolsResult;
import software.amazon.smithy.java.mcp.model.PromptInfo;
import software.amazon.smithy.java.mcp.model.ToolInfo;
import software.amazon.smithy.model.shapes.ShapeType;

public abstract class McpServerProxy {

Expand Down Expand Up @@ -81,6 +82,14 @@ public void initialize(
if (result.getError() != null) {
throw new RuntimeException("Error during initialization: " + result.getError().getMessage());
}

// Send the initialized notification per MCP protocol spec
JsonRpcRequest initializedNotification = JsonRpcRequest.builder()
.method("notifications/initialized")
.jsonrpc("2.0")
.build();
rpc(initializedNotification);

this.notificationConsumer.set(notificationConsumer);
this.requestNotificationConsumer.set(requestNotificationConsumer);
this.protocolVersion.set(protocolVersion);
Expand Down Expand Up @@ -127,7 +136,7 @@ protected void notify(JsonRpcResponse response) {
* Forwards a notification request by converting it to a response format.
* Notifications have a method field but no id.
*/
protected void notifyRequest(JsonRpcRequest notification) {
protected void notify(JsonRpcRequest notification) {
var rnc = requestNotificationConsumer.get();
if (rnc != null) {
LOG.debug("Forwarding notification to consumer: method={}", notification.getMethod());
Expand All @@ -138,5 +147,26 @@ protected void notifyRequest(JsonRpcRequest notification) {
}
}

/**
* Determines if a Document represents a notification (has "method" but no "id")
* rather than a response (has "id").
*
* - Responses have an "id" field at the top level
* - Notifications have a "method" field but no "id" field at the top level
*/
protected static boolean isNotification(Document doc) {
try {
if (!doc.isType(ShapeType.STRUCTURE) && !doc.isType(ShapeType.MAP)) {
return false;
}

// If it has a "method" field but no "id", it's a notification
return doc.getMember("id") == null && doc.getMember("method") != null;
} catch (Exception e) {
LOG.warn("Failed to determine if notification from Document", e);
return false;
}
}

public abstract String name();
}
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,25 @@ public CompletableFuture<JsonRpcResponse> rpc(JsonRpcRequest request) {
return future;
}

// Notifications don't have an ID and don't expect a response
if (request.getId() == null) {
try {
writeLock.lock();
String serializedRequest = JSON_CODEC.serializeToString(request);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would do serialization outside of the lock

LOG.debug("Sending notification: {}", serializedRequest);
writer.write(serializedRequest);
writer.newLine();
writer.flush();
} catch (IOException e) {
LOG.error("Error sending notification to MCP server", e);
return CompletableFuture.failedFuture(
new RuntimeException("Failed to send notification to MCP server: " + e.getMessage(), e));
} finally {
writeLock.unlock();
}
return CompletableFuture.completedFuture(null);
}

String requestId = getStringRequestId(request.getId());
CompletableFuture<JsonRpcResponse> responseFuture = new CompletableFuture<>();
pendingRequests.put(requestId, responseFuture);
Expand Down Expand Up @@ -185,20 +204,23 @@ public synchronized void start() {
}

LOG.debug("Received response: {}", responseLine);
JsonRpcResponse response = JsonRpcResponse.builder()
.deserialize(
JSON_CODEC.createDeserializer(
responseLine.getBytes(StandardCharsets.UTF_8)))
.build();

String responseId = getStringRequestId(response.getId());
LOG.debug("Processing response ID: {}", responseId);

CompletableFuture<JsonRpcResponse> future = pendingRequests.remove(responseId);
if (future != null) {
future.complete(response);
var output =
JSON_CODEC.createDeserializer(responseLine.getBytes(StandardCharsets.UTF_8))
.readDocument();
if (isNotification(output)) {
notify(output.asShape(JsonRpcRequest.builder()));
} else {
notify(response);
JsonRpcResponse response = output.asShape(JsonRpcResponse.builder());

String responseId = getStringRequestId(response.getId());
LOG.debug("Processing response ID: {}", responseId);

CompletableFuture<JsonRpcResponse> future = pendingRequests.remove(responseId);
if (future != null) {
future.complete(response);
} else {
notify(response);
}
}
} catch (IOException e) {
if (running) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1431,6 +1431,10 @@ void testToolsListChangedNotificationInvalidatesCache() {
.build();
service.handleRequest(initRequest, r -> {}, ProtocolVersion.defaultVersion());

// Verify notifications/initialized was sent during initialization
assertTrue(mockProxy.getSentNotifications().contains("notifications/initialized"),
"notifications/initialized should be sent during initialization");

// First tools/list - fetches from proxy
var toolsRequest = JsonRpcRequest.builder()
.method("tools/list")
Expand Down Expand Up @@ -1488,6 +1492,10 @@ void testOtherNotificationsDoNotInvalidateCache() {
.build();
service.handleRequest(initRequest, r -> {}, ProtocolVersion.defaultVersion());

// Verify notifications/initialized was sent during initialization
assertTrue(mockProxy.getSentNotifications().contains("notifications/initialized"),
"notifications/initialized should be sent during initialization");

// First tools/list
var toolsRequest = JsonRpcRequest.builder()
.method("tools/list")
Expand Down Expand Up @@ -1516,6 +1524,7 @@ void testOtherNotificationsDoNotInvalidateCache() {

private static class CacheTestProxy extends McpServerProxy {
private final AtomicInteger callCounter;
private final List<String> sentNotifications = new ArrayList<>();

CacheTestProxy(AtomicInteger callCounter) {
this.callCounter = callCounter;
Expand All @@ -1539,6 +1548,11 @@ public List<software.amazon.smithy.java.mcp.model.PromptInfo> listPrompts() {

@Override
CompletableFuture<JsonRpcResponse> rpc(JsonRpcRequest request) {
// Notifications have no ID
if (request.getId() == null) {
sentNotifications.add(request.getMethod());
return CompletableFuture.completedFuture(null);
}
return CompletableFuture.completedFuture(
JsonRpcResponse.builder()
.id(request.getId())
Expand All @@ -1547,6 +1561,10 @@ CompletableFuture<JsonRpcResponse> rpc(JsonRpcRequest request) {
.build());
}

List<String> getSentNotifications() {
return sentNotifications;
}

@Override
void start() {}

Expand All @@ -1561,7 +1579,7 @@ public String name() {
}

void sendNotification(JsonRpcRequest notification) {
notifyRequest(notification);
notify(notification);
}
}
}