From 39e8445d00dd2f0b650dc214ea763f0238700fe4 Mon Sep 17 00:00:00 2001 From: Azeem Muzammil Date: Tue, 15 Jul 2025 11:51:15 +0530 Subject: [PATCH 01/31] [Automated] Update the native jar versions --- ballerina/Dependencies.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ballerina/Dependencies.toml b/ballerina/Dependencies.toml index a8c757c..bccd9b2 100644 --- a/ballerina/Dependencies.toml +++ b/ballerina/Dependencies.toml @@ -5,7 +5,7 @@ [ballerina] dependencies-toml-version = "2" -distribution-version = "2201.12.7" +distribution-version = "2201.12.0" [[package]] org = "ballerina" From 2abc5763d162c29fc7bf489dfeb9ce083f9b5a3a Mon Sep 17 00:00:00 2001 From: Azeem Muzammil Date: Mon, 9 Jun 2025 09:48:41 +0530 Subject: [PATCH 02/31] Implement initial version of MCP server --- ballerina/CompilerPlugin.toml | 9 + ballerina/Dependencies.toml | 3 + ballerina/basic_dispatcher_service.bal | 252 +++++++++++++++++ ballerina/basic_listener.bal | 85 ++++++ ballerina/build.gradle | 9 + ballerina/dispatcher_service.bal | 253 ++++++++++++++++++ ballerina/listener.bal | 84 ++++++ ballerina/main.bal | 89 ++++++ ballerina/native_listener.bal | 17 ++ ballerina/types.bal | 65 ++++- build-config/resources/CompilerPlugin.toml | 9 + compiler-plugin/build.gradle | 93 +++++++ .../stdlib/mcp/plugin/McpCodeModifier.java | 39 +++ .../stdlib/mcp/plugin/McpCompilerPlugin.java | 30 +++ .../stdlib/mcp/plugin/McpSourceModifier.java | 192 +++++++++++++ .../stdlib/mcp/plugin/ModifierContext.java | 52 ++++ .../plugin/RemoteFunctionAnalysisTask.java | 145 ++++++++++ .../stdlib/mcp/plugin/SchemaUtils.java | 192 +++++++++++++ .../io/ballerina/stdlib/mcp/plugin/Utils.java | 69 +++++ .../diagnostics/CompilationDiagnostic.java | 66 +++++ .../plugin/diagnostics/DiagnosticCode.java | 26 ++ .../plugin/diagnostics/DiagnosticMessage.java | 37 +++ .../src/main/java/module-info.java | 26 ++ gradle.properties | 3 +- .../stdlib/mcp/McpServiceMethodHelper.java | 195 ++++++++++++++ settings.gradle | 2 + spotbugs-exclude.xml | 4 + 27 files changed, 2044 insertions(+), 2 deletions(-) create mode 100644 ballerina/CompilerPlugin.toml create mode 100644 ballerina/basic_dispatcher_service.bal create mode 100644 ballerina/basic_listener.bal create mode 100644 ballerina/dispatcher_service.bal create mode 100644 ballerina/listener.bal create mode 100644 ballerina/main.bal create mode 100644 ballerina/native_listener.bal create mode 100644 build-config/resources/CompilerPlugin.toml create mode 100644 compiler-plugin/build.gradle create mode 100644 compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/McpCodeModifier.java create mode 100644 compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/McpCompilerPlugin.java create mode 100644 compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/McpSourceModifier.java create mode 100644 compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/ModifierContext.java create mode 100644 compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/RemoteFunctionAnalysisTask.java create mode 100644 compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/SchemaUtils.java create mode 100644 compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/Utils.java create mode 100644 compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/diagnostics/CompilationDiagnostic.java create mode 100644 compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/diagnostics/DiagnosticCode.java create mode 100644 compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/diagnostics/DiagnosticMessage.java create mode 100644 compiler-plugin/src/main/java/module-info.java create mode 100644 native/src/main/java/io/ballerina/stdlib/mcp/McpServiceMethodHelper.java diff --git a/ballerina/CompilerPlugin.toml b/ballerina/CompilerPlugin.toml new file mode 100644 index 0000000..f770d36 --- /dev/null +++ b/ballerina/CompilerPlugin.toml @@ -0,0 +1,9 @@ +[plugin] +id = "mcp-compiler-plugin" +class = "io.ballerina.stdlib.mcp.plugin.McpCompilerPlugin" + +[[dependency]] +path = "../compiler-plugin/build/libs/mcp-compiler-plugin-0.4.1-SNAPSHOT.jar" + +[[dependency]] +path = "../compiler-plugin/build/libs/ballerina-to-openapi-2.3.0.jar" diff --git a/ballerina/Dependencies.toml b/ballerina/Dependencies.toml index bccd9b2..18521e4 100644 --- a/ballerina/Dependencies.toml +++ b/ballerina/Dependencies.toml @@ -314,4 +314,7 @@ dependencies = [ {org = "ballerina", name = "lang.int"}, {org = "ballerina", name = "time"} ] +modules = [ + {org = "ballerina", packageName = "uuid", moduleName = "uuid"} +] diff --git a/ballerina/basic_dispatcher_service.bal b/ballerina/basic_dispatcher_service.bal new file mode 100644 index 0000000..ca5712b --- /dev/null +++ b/ballerina/basic_dispatcher_service.bal @@ -0,0 +1,252 @@ +// Copyright (c) 2025 WSO2 LLC (http://www.wso2.com). +// +// WSO2 LLC. licenses this file to you under the Apache License, +// Version 2.0 (the "License"); you may not use this file except +// in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +import ballerina/http; +import ballerina/io; +// import ballerina/io; +import ballerina/uuid; + +type BasicDispatcherService distinct service object { + *http:Service; + + isolated function addServiceRef(BasicMcpService basicMcpService); + isolated function removeServiceRef(); + isolated function setServerConfigs(ServerConfiguration serverConfigs); +}; + +BasicDispatcherService basicDispatcherService = isolated service object { + private ServerConfiguration? serverConfigs = (); + private BasicMcpService? basicMcpService = (); + private boolean isInitialized = false; + private string? sessionId = (); + + isolated function addServiceRef(BasicMcpService basicMcpService) { + lock { + self.basicMcpService = basicMcpService; + } + } + + isolated function removeServiceRef() { + lock { + self.basicMcpService = (); + } + } + + isolated function setServerConfigs(ServerConfiguration serverConfigs) { + lock { + self.serverConfigs = serverConfigs.cloneReadOnly(); + } + } + + isolated resource function get .() returns error? { + + } + + isolated resource function post .(@http:Payload JsonRpcMessage request, http:Headers headers) returns http:BadRequest|http:Accepted|http:Ok|error { + lock { + io:println("Received request: ", request.cloneReadOnly()); + string|http:HeaderNotFoundError acceptHeader = headers.getHeader(ACCEPT_HEADER); + if acceptHeader is http:HeaderNotFoundError { + return { + body: { + jsonrpc: JSONRPC_VERSION, + id: 1, + 'error: { + code: -32000, + message: "Not Acceptable: Client must accept both application/json and text/event-stream" + } + } + }; + } + if !acceptHeader.includes(CONTENT_TYPE_JSON) || !acceptHeader.includes(CONTENT_TYPE_SSE) { + return { + body: { + jsonrpc: JSONRPC_VERSION, + id: 1, + 'error: { + code: -32000, + message: "Not Acceptable: Client must accept both application/json and text/event-stream" + } + } + }; + } + + string|http:HeaderNotFoundError contentTypeHeader = headers.getHeader(CONTENT_TYPE_HEADER); + if contentTypeHeader is http:HeaderNotFoundError { + return { + body: { + jsonrpc: JSONRPC_VERSION, + id: null, + 'error: { + code: -32000, + message: "Unsupported Media Type: Content-Type must be application/json" + } + } + }; + } + if !contentTypeHeader.includes(CONTENT_TYPE_JSON) { + return { + body: { + jsonrpc: JSONRPC_VERSION, + id: null, + 'error: { + code: -32000, + message: "Unsupported Media Type: Content-Type must be application/json" + } + } + }; + } + + if request is JsonRpcRequest { + if request.method == "initialize" { + if self.isInitialized && self.sessionId != () { + return { + body: { + jsonrpc: JSONRPC_VERSION, + id: null, + 'error: { + code: -32600, + message: "Invalid Request: Only one initialization request is allowed" + } + } + }; + } + self.isInitialized = true; + self.sessionId = uuid:createRandomUuid(); + + final string requestedVersion = check (request.params["protocolVersion"]).cloneWithType(); + final readonly & ServerCapabilities? capabilities = (self.serverConfigs?.options?.capabilities).cloneReadOnly(); + final readonly & Implementation? serverInfo = (self.serverConfigs?.serverInfo).cloneReadOnly(); + + if serverInfo is () { + return { + body: { + jsonrpc: JSONRPC_VERSION, + id: null, + 'error: { + code: -32000, + message: "Server Info not provided in configuration" + } + } + }; + } + + string protocolVersion = SUPPORTED_PROTOCOL_VERSIONS.some(v => v == requestedVersion) ? requestedVersion + : LATEST_PROTOCOL_VERSION; + + return { + headers: { + [SESSION_ID_HEADER]: self.sessionId ?: "" + }, + body: { + jsonrpc: JSONRPC_VERSION, + id: request.id, + result: { + protocolVersion: protocolVersion, + capabilities: capabilities ?: {}, + serverInfo: serverInfo + } + } + }; + } else if request.method == "tools/list" { + ListToolsResult listToolsResult = check self.executeListTools(); + io:println("ListToolsResult: ", listToolsResult.cloneReadOnly()); + return { + headers: { + [SESSION_ID_HEADER]: self.sessionId ?: "" + }, + body: { + jsonrpc: JSONRPC_VERSION, + id: request.id, + result: listToolsResult.cloneReadOnly() + } + }; + } else if request.method == "tools/call" { + CallToolParams params = check request.cloneReadOnly().params.ensureType(CallToolParams); + CallToolResult callToolResult = check self.executeCallTool(params); + return { + headers: { + [SESSION_ID_HEADER]: self.sessionId ?: "" + }, + body: { + jsonrpc: JSONRPC_VERSION, + id: request.id, + result: callToolResult.cloneReadOnly() + } + }; + } else { + return { + body: { + jsonrpc: JSONRPC_VERSION, + id: request.id, + 'error: { + code: -32601, + message: "Method not found" + } + } + }; + } + } + else if request is JsonRpcNotification { + if request.method == "notifications/initialized" { + return http:ACCEPTED; + } + } + } + // if request is InitializeRequest { + // lock { + // if self.isInitialized && self.sessionId != () { + // return { + // jsonrpc: JSONRPC_VERSION, + // id: null, + // 'error: { + // code: -32600, + // message: "Invalid Request: Only one initialization request is allowed" + // } + // }; + // } + // self.isInitialized = true; + // self.sessionId = uuid:createRandomUuid(); + // io:println("Session initialized with ID: ", self.sessionId); + // } + // } else if request is ListToolsRequest { + // io:println("Received ListToolsRequest"); + // } else if request is CallToolRequest { + // io:println("Received CallToolRequest"); + // } + return error("Unsupported request type"); + } + + private isolated function executeListTools() returns ListToolsResult|error { + lock { + BasicMcpService? chatService = self.basicMcpService; + if chatService is BasicMcpService { + return check listToolsForRemoteFunctions(chatService); + } + return error("MCP Service is not attached"); + } + } + + private isolated function executeCallTool(CallToolParams params) returns CallToolResult|error { + lock { + BasicMcpService? chatService = self.basicMcpService; + if chatService is BasicMcpService { + return check callToolForRemoteFunctions(chatService, params.cloneReadOnly()); + } + return error("MCP Service is not attached"); + } + } +}; diff --git a/ballerina/basic_listener.bal b/ballerina/basic_listener.bal new file mode 100644 index 0000000..42c1028 --- /dev/null +++ b/ballerina/basic_listener.bal @@ -0,0 +1,85 @@ +// Copyright (c) 2025 WSO2 LLC (http://www.wso2.com). +// +// WSO2 LLC. licenses this file to you under the Apache License, +// Version 2.0 (the "License"); you may not use this file except +// in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +import ballerina/http; + +public type McpToolConfig record {| + string description?; + json schema?; +|}; + +public annotation McpToolConfig McpTool on object function; + +# A server listener for handling MCP service requests. +public class BasicListener { + private http:Listener httpListener; + private BasicDispatcherService dispatcherService; + + # Initializes the Listener. + # + # + listenTo - Either a port number (int) or an existing http:Listener. + # + config - Optional listener configuration. + # + return - error? if listener initialization fails. + public function init(int|http:Listener listenTo, ServerConfiguration serverConfigs, *ListenerConfiguration config) returns error? { + if listenTo is http:Listener { + self.httpListener = listenTo; + } else { + self.httpListener = check new (listenTo, config); + } + self.dispatcherService = basicDispatcherService; + self.dispatcherService.setServerConfigs(serverConfigs); + } + + # Attaches an MCP service to the listener under the specified path(s). + # + # + mcpService - Service to attach. + # + name - Path(s) to mount the service on (string or string array). + # + return - error? if attachment fails. + public isolated function attach(BasicMcpService basicMcpService, string[]|string? name = ()) returns error? { + check self.httpListener.attach(self.dispatcherService, name); + self.dispatcherService.addServiceRef(basicMcpService); + } + + # Detaches the MCP service from the listener. + # + # + mcpService - Service to detach. + # + return - error? if detachment fails. + public isolated function detach(BasicMcpService basicMcpService) returns error? { + check self.httpListener.detach(self.dispatcherService); + self.dispatcherService.removeServiceRef(); + } + + # Starts the listener (begin accepting connections). + # + # + return - error? if starting fails. + public isolated function 'start() returns error? { + check self.httpListener.start(); + } + + # Gracefully stops the listener (completes active requests before shutting down). + # + # + return - error? if graceful stop fails. + public isolated function gracefulStop() returns error? { + check self.httpListener.gracefulStop(); + } + + # Immediately stops the listener (terminates all connections). + # + # + return - error? if immediate stop fails. + public isolated function immediateStop() returns error? { + check self.httpListener.immediateStop(); + } +} diff --git a/ballerina/build.gradle b/ballerina/build.gradle index 37951cf..4e1ef4e 100644 --- a/ballerina/build.gradle +++ b/ballerina/build.gradle @@ -27,7 +27,9 @@ def packageName = "mcp" def packageOrg = "ballerina" def tomlVersion = stripBallerinaExtensionVersion("${project.version}") def ballerinaTomlFilePlaceHolder = new File("${project.rootDir}/build-config/resources/Ballerina.toml") +def compilerPluginTomlFilePlaceHolder = new File("${project.rootDir}/build-config/resources/CompilerPlugin.toml") def ballerinaTomlFile = new File("$project.projectDir/Ballerina.toml") +def compilerPluginTomlFile = new File("$project.projectDir/CompilerPlugin.toml") def stripBallerinaExtensionVersion(String extVersion) { if (extVersion.matches(project.ext.timestampedVersionRegex)) { @@ -55,6 +57,11 @@ task updateTomlFiles { def newConfig = ballerinaTomlFilePlaceHolder.text.replace("@project.version@", project.version) newConfig = newConfig.replace("@toml.version@", tomlVersion) ballerinaTomlFile.text = newConfig + + def ballerinaToOpenApiVersion = project.ballerinaToOpenApiVersion + def newCompilerPluginToml = compilerPluginTomlFilePlaceHolder.text.replace("@project.version@", project.version) + newCompilerPluginToml = newCompilerPluginToml.replace("@ballerinaToOpenApiVersion.version@", ballerinaToOpenApiVersion) + compilerPluginTomlFile.text = newCompilerPluginToml } } @@ -94,7 +101,9 @@ updateTomlFiles.dependsOn copyStdlibs build.dependsOn "generatePomFileForMavenPublication" build.dependsOn ":${packageName}-native:build" +build.dependsOn ":${packageName}-compiler-plugin:build" test.dependsOn ":${packageName}-native:build" +test.dependsOn ":${packageName}-compiler-plugin:build" publishToMavenLocal.dependsOn build publish.dependsOn build diff --git a/ballerina/dispatcher_service.bal b/ballerina/dispatcher_service.bal new file mode 100644 index 0000000..67dca26 --- /dev/null +++ b/ballerina/dispatcher_service.bal @@ -0,0 +1,253 @@ +// Copyright (c) 2025 WSO2 LLC (http://www.wso2.com). +// +// WSO2 LLC. licenses this file to you under the Apache License, +// Version 2.0 (the "License"); you may not use this file except +// in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +import ballerina/http; +import ballerina/io; +import ballerina/uuid; + +// import ballerina/io; + +type DispatcherService distinct service object { + *http:Service; + + isolated function addServiceRef(McpService mcpService); + isolated function removeServiceRef(); + isolated function setServerConfigs(ServerConfiguration serverConfigs); +}; + +DispatcherService dispatcherService = isolated service object { + private ServerConfiguration? serverConfigs = (); + private McpService? mcpService = (); + private boolean isInitialized = false; + private string? sessionId = (); + + isolated function addServiceRef(McpService mcpService) { + lock { + self.mcpService = mcpService; + } + } + + isolated function removeServiceRef() { + lock { + self.mcpService = (); + } + } + + isolated function setServerConfigs(ServerConfiguration serverConfigs) { + lock { + self.serverConfigs = serverConfigs.cloneReadOnly(); + } + } + + isolated resource function get .() returns error? { + + } + + isolated resource function post .(@http:Payload JsonRpcMessage request, http:Headers headers) returns http:BadRequest|http:Accepted|http:Ok|error { + lock { + io:println("Received request: ", request.cloneReadOnly()); + string|http:HeaderNotFoundError acceptHeader = headers.getHeader(ACCEPT_HEADER); + if acceptHeader is http:HeaderNotFoundError { + return { + body: { + jsonrpc: JSONRPC_VERSION, + id: 1, + 'error: { + code: -32000, + message: "Not Acceptable: Client must accept both application/json and text/event-stream" + } + } + }; + } + if !acceptHeader.includes(CONTENT_TYPE_JSON) || !acceptHeader.includes(CONTENT_TYPE_SSE) { + return { + body: { + jsonrpc: JSONRPC_VERSION, + id: 1, + 'error: { + code: -32000, + message: "Not Acceptable: Client must accept both application/json and text/event-stream" + } + } + }; + } + + string|http:HeaderNotFoundError contentTypeHeader = headers.getHeader(CONTENT_TYPE_HEADER); + if contentTypeHeader is http:HeaderNotFoundError { + return { + body: { + jsonrpc: JSONRPC_VERSION, + id: null, + 'error: { + code: -32000, + message: "Unsupported Media Type: Content-Type must be application/json" + } + } + }; + } + if !contentTypeHeader.includes(CONTENT_TYPE_JSON) { + return { + body: { + jsonrpc: JSONRPC_VERSION, + id: null, + 'error: { + code: -32000, + message: "Unsupported Media Type: Content-Type must be application/json" + } + } + }; + } + + if request is JsonRpcRequest { + if request.method == "initialize" { + if self.isInitialized && self.sessionId != () { + return { + body: { + jsonrpc: JSONRPC_VERSION, + id: null, + 'error: { + code: -32600, + message: "Invalid Request: Only one initialization request is allowed" + } + } + }; + } + self.isInitialized = true; + self.sessionId = uuid:createRandomUuid(); + + final string requestedVersion = check (request.params["protocolVersion"]).cloneWithType(); + final readonly & ServerCapabilities? capabilities = (self.serverConfigs?.options?.capabilities).cloneReadOnly(); + final readonly & Implementation? serverInfo = (self.serverConfigs?.serverInfo).cloneReadOnly(); + + if serverInfo is () { + return { + body: { + jsonrpc: JSONRPC_VERSION, + id: null, + 'error: { + code: -32000, + message: "Server Info not provided in configuration" + } + } + }; + } + + string protocolVersion = SUPPORTED_PROTOCOL_VERSIONS.some(v => v == requestedVersion) ? requestedVersion + : LATEST_PROTOCOL_VERSION; + + return { + headers: { + [SESSION_ID_HEADER]: self.sessionId ?: "" + }, + body: { + jsonrpc: JSONRPC_VERSION, + id: request.id, + result: { + protocolVersion: protocolVersion, + capabilities: capabilities ?: {}, + serverInfo: serverInfo + } + } + }; + } else if request.method == "tools/list" { + ListToolsResult listToolsResult = check self.executeOnListTools(); + io:println("ListToolsResult: ", listToolsResult.cloneReadOnly()); + return { + headers: { + [SESSION_ID_HEADER]: self.sessionId ?: "" + }, + body: { + jsonrpc: JSONRPC_VERSION, + id: request.id, + result: listToolsResult.cloneReadOnly() + } + }; + } else if request.method == "tools/call" { + CallToolParams params = check request.cloneReadOnly().params.ensureType(CallToolParams); + CallToolResult callToolResult = check self.executeOnCallTool(params); + return { + headers: { + [SESSION_ID_HEADER]: self.sessionId ?: "" + }, + body: { + jsonrpc: JSONRPC_VERSION, + id: request.id, + result: callToolResult.cloneReadOnly() + } + }; + } else { + return { + body: { + jsonrpc: JSONRPC_VERSION, + id: request.id, + 'error: { + code: -32601, + message: "Method not found" + } + } + }; + } + } + else if request is JsonRpcNotification { + if request.method == "notifications/initialized" { + return http:ACCEPTED; + } + } + } + // if request is InitializeRequest { + // lock { + // if self.isInitialized && self.sessionId != () { + // return { + // jsonrpc: JSONRPC_VERSION, + // id: null, + // 'error: { + // code: -32600, + // message: "Invalid Request: Only one initialization request is allowed" + // } + // }; + // } + // self.isInitialized = true; + // self.sessionId = uuid:createRandomUuid(); + // io:println("Session initialized with ID: ", self.sessionId); + // } + // } else if request is ListToolsRequest { + // io:println("Received ListToolsRequest"); + // } else if request is CallToolRequest { + // io:println("Received CallToolRequest"); + // } + return error("Unsupported request type"); + } + + private isolated function executeOnListTools() returns ListToolsResult|error { + lock { + McpService? chatService = self.mcpService; + if chatService is McpService { + return check invokeOnListTools(chatService); + } + return error("MCP Service is not attached"); + } + } + + private isolated function executeOnCallTool(CallToolParams params) returns CallToolResult|error { + lock { + McpService? chatService = self.mcpService; + if chatService is McpService { + return check invokeOnCallTool(chatService, params.cloneReadOnly()); + } + return error("MCP Service is not attached"); + } + } +}; diff --git a/ballerina/listener.bal b/ballerina/listener.bal new file mode 100644 index 0000000..9a065af --- /dev/null +++ b/ballerina/listener.bal @@ -0,0 +1,84 @@ +// Copyright (c) 2025 WSO2 LLC (http://www.wso2.com). +// +// WSO2 LLC. licenses this file to you under the Apache License, +// Version 2.0 (the "License"); you may not use this file except +// in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +import ballerina/http; + +public type ServerOptions record {| + *ProtocolOptions; + ServerCapabilities capabilities?; + string instructions?; +|}; + +# A server listener for handling MCP service requests. +public class Listener { + private http:Listener httpListener; + private DispatcherService dispatcherService; + + # Initializes the Listener. + # + # + listenTo - Either a port number (int) or an existing http:Listener. + # + config - Optional listener configuration. + # + return - error? if listener initialization fails. + public function init(int|http:Listener listenTo, ServerConfiguration serverConfigs, *ListenerConfiguration config) returns error? { + if listenTo is http:Listener { + self.httpListener = listenTo; + } else { + self.httpListener = check new (listenTo, config); + } + self.dispatcherService = dispatcherService; + self.dispatcherService.setServerConfigs(serverConfigs); + } + + # Attaches an MCP service to the listener under the specified path(s). + # + # + mcpService - Service to attach. + # + name - Path(s) to mount the service on (string or string array). + # + return - error? if attachment fails. + public isolated function attach(McpService mcpService, string[]|string? name = ()) returns error? { + check self.httpListener.attach(self.dispatcherService, name); + self.dispatcherService.addServiceRef(mcpService); + } + + # Detaches the MCP service from the listener. + # + # + mcpService - Service to detach. + # + return - error? if detachment fails. + public isolated function detach(McpService mcpService) returns error? { + check self.httpListener.detach(self.dispatcherService); + self.dispatcherService.removeServiceRef(); + } + + # Starts the listener (begin accepting connections). + # + # + return - error? if starting fails. + public isolated function 'start() returns error? { + check self.httpListener.start(); + } + + # Gracefully stops the listener (completes active requests before shutting down). + # + # + return - error? if graceful stop fails. + public isolated function gracefulStop() returns error? { + check self.httpListener.gracefulStop(); + } + + # Immediately stops the listener (terminates all connections). + # + # + return - error? if immediate stop fails. + public isolated function immediateStop() returns error? { + check self.httpListener.immediateStop(); + } +} diff --git a/ballerina/main.bal b/ballerina/main.bal new file mode 100644 index 0000000..863715f --- /dev/null +++ b/ballerina/main.bal @@ -0,0 +1,89 @@ + +// listener Listener mcpListener = check new (9090, serverConfigs = { +// serverInfo: { +// name: "MCP Server", +// version: "1.0.0" +// }, +// options: {capabilities: {}} +// }); + +// service /mcp on mcpListener { +// remote isolated function onListTools() returns ListToolsResult|error { +// return { +// tools: [ +// { +// name: "single-greet", +// description: "Greet the user once", +// inputSchema: { +// 'type: "object", +// properties: { +// "name": {"type": "string", "description": "Name to greet"} +// }, +// required: ["name"] +// } +// }, +// { +// name: "multi-greet", +// description: "Greet the user multiple times with delay in between.", +// inputSchema: { +// 'type: "object", +// properties: { +// "name": {"type": "string", "description": "Name to greet"} +// }, +// required: ["name"] +// } +// } +// ] +// }; +// } + +// remote isolated function onCallTool(CallToolParams params) returns CallToolResult|error { +// string name = check (params.arguments["name"]).cloneWithType(); +// if params.name == "single-greet" { +// // Note: Can do any external function calls here, +// TextContent textContent = { +// 'type: "text", +// text: string `Hey ${name}! Welcome to itsuki's world!` +// }; +// return { +// content: [textContent] +// }; +// } else if params.name == "multi-greet" { +// // Note: Can do any external function calls here, +// TextContent textContent = { +// 'type: "text", +// text: string `Hey ${name}! Hope you enjoy your day!` +// }; +// return { +// content: [textContent] +// }; +// } else { +// return error("Unknown tool: " + params.name); +// } +// } +// } + +listener BasicListener basicListener = check new (9091, serverConfigs = { + serverInfo: { + name: "Basic MCP Server", + version: "1.0.0" + }, + options: {capabilities: {}} +}); + +service /mcp on basicListener { + @McpTool { + description: "Add two numbers", + schema: { + 'type: "object", + properties: { + "a": {"type": "integer", "description": "First number"}, + "b": {"type": "integer", "description": "Second number"} + }, + required: ["a", "b"] + } + } + remote function add(int a, int b) returns int { + return a + b; + } +} diff --git a/ballerina/native_listener.bal b/ballerina/native_listener.bal new file mode 100644 index 0000000..06cc428 --- /dev/null +++ b/ballerina/native_listener.bal @@ -0,0 +1,17 @@ +import ballerina/jballerina.java; + +isolated function invokeOnListTools(McpService 'service) returns ListToolsResult|error = @java:Method { + 'class: "io.ballerina.stdlib.mcp.McpServiceMethodHelper" +} external; + +isolated function invokeOnCallTool(McpService 'service, CallToolParams params) returns CallToolResult|error = @java:Method { + 'class: "io.ballerina.stdlib.mcp.McpServiceMethodHelper" +} external; + +isolated function listToolsForRemoteFunctions(BasicMcpService 'service, typedesc t = <>) returns t|error = @java:Method { + 'class: "io.ballerina.stdlib.mcp.McpServiceMethodHelper" +} external; + +isolated function callToolForRemoteFunctions(BasicMcpService 'service, CallToolParams params, typedesc t = <>) returns t|error = @java:Method { + 'class: "io.ballerina.stdlib.mcp.McpServiceMethodHelper" +} external; diff --git a/ballerina/types.bal b/ballerina/types.bal index ca83cda..5332bfc 100644 --- a/ballerina/types.bal +++ b/ballerina/types.bal @@ -14,6 +14,8 @@ // specific language governing permissions and limitations // under the License. +import ballerina/http; + # Refers to any valid JSON-RPC object that can be decoded off the wire, or encoded to be sent. public type JsonRpcMessage JsonRpcRequest|JsonRpcNotification|JsonRpcResponse; @@ -99,6 +101,48 @@ public type JsonRpcResponse record {| ServerResult result; |}; +// Standard JSON-RPC error codes +public const PARSE_ERROR = -32700; +public const INVALID_REQUEST = -32600; +public const METHOD_NOT_FOUND = -32601; +public const INVALID_PARAMS = -32602; +public const INTERNAL_ERROR = -32603; + +# A response to a request that indicates an error occurred. +public type JsonRpcError record { + # The JSON-RPC protocol version + JSONRPC_VERSION jsonrpc; + # Identifier of the request + RequestId? id; + # The error information + record { + # The error type that occurred + int code; + # A short description of the error. The message SHOULD be limited to a concise single sentence. + string message; + # Additional information about the error. The value of this member is defined by the sender (e.g. detailed error information, nested errors etc.). + anydata data?; + } 'error; +}; + +# A response that indicates success but carries no data. +public type EmptyResult Result; + +# This notification can be sent by either side to indicate that it is cancelling a previously-issued request. +public type CancelledNotification record {| + *Notification; + # The method name for this notification + NOTIFICATION_CANCELLED method; + # The parameters for the cancellation notification + record {| + # The ID of the request to cancel. + # This MUST correspond to the ID of a request previously issued in the same direction. + RequestId requestId; + # An optional string describing the reason for the cancellation. This MAY be logged or presented to the user. + string? reason = (); + |} params; +|}; + # This request is sent from the client to the server when it first connects, asking it to begin initialization. type InitializeRequest record {| *Request; @@ -368,4 +412,23 @@ public type AudioContent record { }; # Represents a result sent from the server to the client. -public type ServerResult InitializeResult|CallToolResult|ListToolsResult; +public type ServerResult InitializeResult|CallToolResult|ListToolsResult|EmptyResult; + +# Defines a mcp service interface that handles incoming mcp requests. +public type McpService distinct isolated service object { + remote isolated function onListTools() returns ListToolsResult|error; + remote isolated function onCallTool(CallToolParams params) returns CallToolResult|error; +}; + +public type BasicMcpService distinct isolated service object { + +}; + +public type ListenerConfiguration record {| + *http:ListenerConfiguration; +|}; + +public type ServerConfiguration record {| + Implementation serverInfo; + ServerOptions options?; +|}; diff --git a/build-config/resources/CompilerPlugin.toml b/build-config/resources/CompilerPlugin.toml new file mode 100644 index 0000000..beecd8a --- /dev/null +++ b/build-config/resources/CompilerPlugin.toml @@ -0,0 +1,9 @@ +[plugin] +id = "mcp-compiler-plugin" +class = "io.ballerina.stdlib.mcp.plugin.McpCompilerPlugin" + +[[dependency]] +path = "../compiler-plugin/build/libs/mcp-compiler-plugin-@project.version@.jar" + +[[dependency]] +path = "../compiler-plugin/build/libs/ballerina-to-openapi-@ballerinaToOpenApiVersion.version@.jar" diff --git a/compiler-plugin/build.gradle b/compiler-plugin/build.gradle new file mode 100644 index 0000000..e67dd9f --- /dev/null +++ b/compiler-plugin/build.gradle @@ -0,0 +1,93 @@ +/* + * Copyright (c) 2025, WSO2 LLC. (http://www.wso2.com). + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +plugins { + id 'java' + id 'checkstyle' + id 'com.github.spotbugs' +} + +description = 'Ballerina - MCP Package Compiler Plugin' + +configurations { + externalJars +} + +dependencies { + checkstyle project(':checkstyle') + checkstyle "com.puppycrawl.tools:checkstyle:${checkstylePluginVersion}" + + implementation group: 'org.ballerinalang', name: 'ballerina-lang', version: "${ballerinaLangVersion}" + implementation group: 'org.ballerinalang', name: 'ballerina-tools-api', version: "${ballerinaLangVersion}" + implementation group: 'org.ballerinalang', name: 'ballerina-parser', version: "${ballerinaLangVersion}" + implementation group: 'io.ballerina.openapi', name: 'ballerina-to-openapi', version: "${ballerinaToOpenApiVersion}" + implementation group: 'io.swagger.core.v3', name: 'swagger-core', version: "${swaggerVersion}" + implementation group: 'io.swagger.core.v3', name: 'swagger-models', version: "${swaggerVersion}" + + implementation project(":mcp-native") + externalJars group: 'io.ballerina.openapi', name: 'ballerina-to-openapi', version: "${ballerinaToOpenApiVersion}" +} + +def excludePattern = '**/module-info.java' +tasks.withType(Checkstyle) { + exclude excludePattern +} + +checkstyle { + toolVersion "${project.checkstylePluginVersion}" + configFile rootProject.file("build-config/checkstyle/build/checkstyle.xml") + configProperties = ["suppressionFile": file("${rootDir}/build-config/checkstyle/build/suppressions.xml")] +} + +checkstyleMain.dependsOn(":checkstyle:downloadCheckstyleRuleFiles") + +spotbugsMain { + def classLoader = plugins["com.github.spotbugs"].class.classLoader + def SpotBugsConfidence = classLoader.findLoadedClass("com.github.spotbugs.snom.Confidence") + def SpotBugsEffort = classLoader.findLoadedClass("com.github.spotbugs.snom.Effort") + effort = SpotBugsEffort.MAX + reportLevel = SpotBugsConfidence.LOW + reportsDir = file("$project.buildDir/reports/spotbugs") + reports { + html.enabled true + text.enabled = true + } + def excludeFile = file("${rootDir}/spotbugs-exclude.xml") + if (excludeFile.exists()) { + excludeFilter = excludeFile + } +} + +compileJava { + doFirst { + options.compilerArgs = [ + '--module-path', classpath.asPath, + ] + classpath = files() + } +} + +build.dependsOn ":mcp-native:build" + +task copyOpenApiJar(type: Copy) { + from { + configurations.externalJars.collect { it } + } + into "${buildDir}/libs" +} + +build.dependsOn copyOpenApiJar diff --git a/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/McpCodeModifier.java b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/McpCodeModifier.java new file mode 100644 index 0000000..bdce349 --- /dev/null +++ b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/McpCodeModifier.java @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2025, WSO2 LLC. (http://www.wso2.com). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.ballerina.stdlib.mcp.plugin; + +import io.ballerina.projects.DocumentId; +import io.ballerina.projects.plugins.CodeModifier; +import io.ballerina.projects.plugins.CodeModifierContext; + +import java.util.HashMap; +import java.util.Map; + +import static io.ballerina.compiler.syntax.tree.SyntaxKind.OBJECT_METHOD_DEFINITION; + +public class McpCodeModifier extends CodeModifier { + private final Map modifierContextMap = new HashMap<>(); + + @Override + public void init(CodeModifierContext codeModifierContext) { + codeModifierContext.addSyntaxNodeAnalysisTask(new RemoteFunctionAnalysisTask(modifierContextMap), + OBJECT_METHOD_DEFINITION); + codeModifierContext.addSourceModifierTask(new McpSourceModifier(modifierContextMap)); + } +} diff --git a/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/McpCompilerPlugin.java b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/McpCompilerPlugin.java new file mode 100644 index 0000000..538c212 --- /dev/null +++ b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/McpCompilerPlugin.java @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2025, WSO2 LLC. (http://www.wso2.com). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.ballerina.stdlib.mcp.plugin; + +import io.ballerina.projects.plugins.CompilerPlugin; +import io.ballerina.projects.plugins.CompilerPluginContext; + +public class McpCompilerPlugin extends CompilerPlugin { + + @Override + public void init(CompilerPluginContext compilerPluginContext) { + compilerPluginContext.addCodeModifier(new McpCodeModifier()); + } +} diff --git a/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/McpSourceModifier.java b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/McpSourceModifier.java new file mode 100644 index 0000000..3baef43 --- /dev/null +++ b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/McpSourceModifier.java @@ -0,0 +1,192 @@ +/* + * Copyright (c) 2025, WSO2 LLC. (http://www.wso2.com). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.ballerina.stdlib.mcp.plugin; + +import io.ballerina.compiler.syntax.tree.AnnotationNode; +import io.ballerina.compiler.syntax.tree.FunctionDefinitionNode; +import io.ballerina.compiler.syntax.tree.MappingConstructorExpressionNode; +import io.ballerina.compiler.syntax.tree.MetadataNode; +import io.ballerina.compiler.syntax.tree.ModuleMemberDeclarationNode; +import io.ballerina.compiler.syntax.tree.ModulePartNode; +import io.ballerina.compiler.syntax.tree.Node; +import io.ballerina.compiler.syntax.tree.NodeFactory; +import io.ballerina.compiler.syntax.tree.NodeList; +import io.ballerina.compiler.syntax.tree.NodeParser; +import io.ballerina.compiler.syntax.tree.QualifiedNameReferenceNode; +import io.ballerina.compiler.syntax.tree.ServiceDeclarationNode; +import io.ballerina.compiler.syntax.tree.SyntaxTree; +import io.ballerina.compiler.syntax.tree.Token; +import io.ballerina.projects.DocumentId; +import io.ballerina.projects.Module; +import io.ballerina.projects.plugins.ModifierTask; +import io.ballerina.projects.plugins.SourceModifierContext; +import io.ballerina.tools.text.TextDocument; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static io.ballerina.compiler.syntax.tree.SyntaxKind.CLOSE_BRACE_TOKEN; +import static io.ballerina.compiler.syntax.tree.SyntaxKind.COLON_TOKEN; +import static io.ballerina.compiler.syntax.tree.SyntaxKind.OBJECT_METHOD_DEFINITION; +import static io.ballerina.compiler.syntax.tree.SyntaxKind.OPEN_BRACE_TOKEN; +import static io.ballerina.compiler.syntax.tree.SyntaxKind.QUALIFIED_NAME_REFERENCE; +import static io.ballerina.compiler.syntax.tree.SyntaxKind.SERVICE_DECLARATION; +import static io.ballerina.stdlib.mcp.plugin.RemoteFunctionAnalysisTask.EMPTY_STRING; + +public class McpSourceModifier implements ModifierTask { + private final Map modifierContextMap; + + McpSourceModifier(Map modifierContextMap) { + this.modifierContextMap = modifierContextMap; + } + + @Override + public void modify(SourceModifierContext context) { + for (Map.Entry entry : modifierContextMap.entrySet()) { + modifyDocumentWithTools(context, entry.getKey(), entry.getValue()); + } + } + + private void modifyDocumentWithTools(SourceModifierContext context, DocumentId documentId, + ModifierContext modifierContext) { + Module module = context.currentPackage().module(documentId.moduleId()); + ModulePartNode rootNode = module.document(documentId).syntaxTree().rootNode(); + ModulePartNode updatedRoot = modifyModulePartRoot(rootNode, modifierContext, documentId); + updateDocument(context, module, documentId, updatedRoot); + } + + private ModulePartNode modifyModulePartRoot(ModulePartNode modulePartNode, + ModifierContext modifierContext, DocumentId documentId) { + List modifiedMembers = getModifiedModuleMembers(modulePartNode.members(), + modifierContext, documentId); + return modulePartNode.modify().withMembers(NodeFactory.createNodeList(modifiedMembers)).apply(); + } + + private List getModifiedModuleMembers(NodeList members, + ModifierContext modifierContext, + DocumentId documentId) { + Map modifiedAnnotations = getModifiedAnnotations(modifierContext); + List modifiedMembers = new ArrayList<>(); + + for (ModuleMemberDeclarationNode member : members) { + modifiedMembers.add(getModifiedModuleMember(member, modifiedAnnotations)); + } + + return modifiedMembers; + } + + private Map getModifiedAnnotations(ModifierContext modifierContext) { + Map updatedAnnotationMap = new HashMap<>(); + for (Map.Entry entry : modifierContext + .getAnnotationConfigMap().entrySet()) { + updatedAnnotationMap.put(entry.getKey(), getModifiedAnnotation(entry.getKey(), entry.getValue())); + } + return updatedAnnotationMap; + } + + private AnnotationNode getModifiedAnnotation(AnnotationNode targetNode, ToolAnnotationConfig config) { + String mappingConstructorExpression = generateConfigMappingConstructor(config); + MappingConstructorExpressionNode mappingConstructorNode = (MappingConstructorExpressionNode) NodeParser + .parseExpression(mappingConstructorExpression); + + Node annotationReference = targetNode.annotReference(); + if (annotationReference.kind() == QUALIFIED_NAME_REFERENCE) { + QualifiedNameReferenceNode qualifiedNameReferenceNode = (QualifiedNameReferenceNode) annotationReference; + String identifier = qualifiedNameReferenceNode.identifier().text().replaceAll("\\R", EMPTY_STRING); + String modulePrefix = qualifiedNameReferenceNode.modulePrefix().text(); + annotationReference = NodeFactory.createQualifiedNameReferenceNode( + NodeFactory.createIdentifierToken(modulePrefix), + NodeFactory.createToken(COLON_TOKEN), + NodeFactory.createIdentifierToken(identifier) + ); + Token closeBraceTokenWithNewLine = NodeFactory.createToken( + CLOSE_BRACE_TOKEN, + NodeFactory.createEmptyMinutiaeList(), + NodeFactory.createMinutiaeList( + NodeFactory.createEndOfLineMinutiae(System.lineSeparator()))); + mappingConstructorNode = mappingConstructorNode.modify().withCloseBrace(closeBraceTokenWithNewLine).apply(); + } + return NodeFactory.createAnnotationNode(targetNode.atToken(), annotationReference, mappingConstructorNode); + } + + private String generateConfigMappingConstructor(ToolAnnotationConfig config) { + return generateConfigMappingConstructor(config, OPEN_BRACE_TOKEN.stringValue(), + CLOSE_BRACE_TOKEN.stringValue()); + } + + private String generateConfigMappingConstructor(ToolAnnotationConfig config, String openBraceSource, + String closeBraceSource) { + return openBraceSource + String.format("description:%s,schema:%s", + config.description() != null ? config.description().replaceAll("\\R", " ") : "", + config.schema()) + closeBraceSource; + } + + private ModuleMemberDeclarationNode getModifiedModuleMember(ModuleMemberDeclarationNode member, + Map modifiedAnnotations + ) { + + if (member.kind() == SERVICE_DECLARATION) { + return modifyServiceDeclaration((ServiceDeclarationNode) member, modifiedAnnotations); + } + return member; + } + + private ModuleMemberDeclarationNode modifyServiceDeclaration(ServiceDeclarationNode classDefinitionNode, + Map modifiedAnnotations) { + NodeList members = classDefinitionNode.members(); + ArrayList modifiedMembers = new ArrayList<>(); + + for (Node member : members) { + if (member.kind() == OBJECT_METHOD_DEFINITION) { + FunctionDefinitionNode methodDeclarationNode = (FunctionDefinitionNode) member; + if (methodDeclarationNode.metadata().isPresent()) { + MetadataNode modifiedMetadata = modifyMetadata(methodDeclarationNode.metadata().get(), + modifiedAnnotations); + methodDeclarationNode = methodDeclarationNode.modify().withMetadata(modifiedMetadata).apply(); + } + modifiedMembers.add(methodDeclarationNode); + } else { + modifiedMembers.add(member); + } + } + return classDefinitionNode.modify().withMembers(NodeFactory.createNodeList(modifiedMembers)).apply(); + } + + private MetadataNode modifyMetadata(MetadataNode metadata, + Map modifiedAnnotations) { + List updatedAnnotations = new ArrayList<>(); + for (AnnotationNode annotation : metadata.annotations()) { + updatedAnnotations.add(modifiedAnnotations.getOrDefault(annotation, annotation)); + } + return metadata.modify().withAnnotations(NodeFactory.createNodeList(updatedAnnotations)).apply(); + } + + private void updateDocument(SourceModifierContext context, Module module, DocumentId documentId, + ModulePartNode updatedRoot) { + SyntaxTree syntaxTree = module.document(documentId).syntaxTree().modifyWith(updatedRoot); + TextDocument textDocument = syntaxTree.textDocument(); + if (module.documentIds().contains(documentId)) { + context.modifySourceFile(textDocument, documentId); + } else { + context.modifyTestSourceFile(textDocument, documentId); + } + } +} diff --git a/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/ModifierContext.java b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/ModifierContext.java new file mode 100644 index 0000000..4b2b222 --- /dev/null +++ b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/ModifierContext.java @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2025, WSO2 LLC. (http://www.wso2.com). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.ballerina.stdlib.mcp.plugin; + +import io.ballerina.compiler.syntax.tree.AnnotationNode; + +import java.util.HashMap; +import java.util.Map; + +public class ModifierContext { + private final Map annotationConfigMap = new HashMap<>(); + + void add(AnnotationNode node, ToolAnnotationConfig config) { + annotationConfigMap.put(node, config); + } + + Map getAnnotationConfigMap() { + return annotationConfigMap; + } +} + +record ToolAnnotationConfig( + String description, + String schema) { + + public static final String DESCRIPTION_FIELD_NAME = "description"; + public static final String SCHEMA_FIELD_NAME = "schema"; + + public String get(String field) { + return switch (field) { + case DESCRIPTION_FIELD_NAME -> description(); + case SCHEMA_FIELD_NAME -> schema(); + default -> null; + }; + } +} diff --git a/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/RemoteFunctionAnalysisTask.java b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/RemoteFunctionAnalysisTask.java new file mode 100644 index 0000000..355b3a8 --- /dev/null +++ b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/RemoteFunctionAnalysisTask.java @@ -0,0 +1,145 @@ +/* + * Copyright (c) 2025, WSO2 LLC. (http://www.wso2.com). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.ballerina.stdlib.mcp.plugin; + +import io.ballerina.compiler.api.symbols.AnnotationSymbol; +import io.ballerina.compiler.api.symbols.FunctionSymbol; +import io.ballerina.compiler.api.symbols.Symbol; +import io.ballerina.compiler.api.symbols.SymbolKind; +import io.ballerina.compiler.syntax.tree.AnnotationNode; +import io.ballerina.compiler.syntax.tree.ExpressionNode; +import io.ballerina.compiler.syntax.tree.FunctionDefinitionNode; +import io.ballerina.compiler.syntax.tree.MappingFieldNode; +import io.ballerina.compiler.syntax.tree.MetadataNode; +import io.ballerina.compiler.syntax.tree.NodeFactory; +import io.ballerina.compiler.syntax.tree.NodeList; +import io.ballerina.compiler.syntax.tree.NodeParser; +import io.ballerina.compiler.syntax.tree.SeparatedNodeList; +import io.ballerina.compiler.syntax.tree.SpecificFieldNode; +import io.ballerina.compiler.syntax.tree.SyntaxKind; +import io.ballerina.projects.DocumentId; +import io.ballerina.projects.plugins.AnalysisTask; +import io.ballerina.projects.plugins.SyntaxNodeAnalysisContext; +import io.ballerina.stdlib.mcp.plugin.diagnostics.CompilationDiagnostic; +import io.ballerina.tools.diagnostics.Diagnostic; +import io.ballerina.tools.diagnostics.Location; + +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.stream.Collectors; + +import static io.ballerina.stdlib.mcp.plugin.ToolAnnotationConfig.DESCRIPTION_FIELD_NAME; +import static io.ballerina.stdlib.mcp.plugin.ToolAnnotationConfig.SCHEMA_FIELD_NAME; +import static io.ballerina.stdlib.mcp.plugin.diagnostics.CompilationDiagnostic.UNABLE_TO_GENERATE_SCHEMA_FOR_FUNCTION; + +public class RemoteFunctionAnalysisTask implements AnalysisTask { + public static final String EMPTY_STRING = ""; + public static final String NIL_EXPRESSION = "()"; + + private final Map modifierContextMap; + private SyntaxNodeAnalysisContext context; + + RemoteFunctionAnalysisTask(Map modifierContextMap) { + this.modifierContextMap = modifierContextMap; + } + + @Override + public void perform(SyntaxNodeAnalysisContext context) { + this.context = context; + + FunctionDefinitionNode functionDefinitionNode = (FunctionDefinitionNode) context.node(); + Optional metadataNode = functionDefinitionNode.metadata(); + if (metadataNode.isEmpty()) { + return; + } + + NodeList annotationNodeList = metadataNode.get().annotations(); + Optional toolAnnotationNode = annotationNodeList.stream() + .filter(annotationNode -> + context.semanticModel().symbol(annotationNode) + .filter(symbol -> symbol.kind() == SymbolKind.ANNOTATION) + .filter(symbol -> Utils.isMcpToolAnnotation((AnnotationSymbol) symbol)) + .isPresent() + ) + .findFirst(); + if (toolAnnotationNode.isEmpty()) { + return; + } + + ToolAnnotationConfig config = createAnnotationConfig(toolAnnotationNode.get(), functionDefinitionNode); + addToModifierContext(context, toolAnnotationNode.get(), config); + } + + private ToolAnnotationConfig createAnnotationConfig(AnnotationNode annotationNode, + FunctionDefinitionNode functionDefinitionNode) { + @SuppressWarnings("OptionalGetWithoutIsPresent") // is present already check in perform method + FunctionSymbol functionSymbol = getFunctionSymbol(functionDefinitionNode).get(); + String functionName = functionSymbol.getName().orElse("unknownFunction"); + SeparatedNodeList fields = annotationNode.annotValue().isEmpty() ? + NodeFactory.createSeparatedNodeList() : annotationNode.annotValue().get().fields(); + Map fieldValues = extractFieldValues(fields); + String description = fieldValues.containsKey(DESCRIPTION_FIELD_NAME) + ? fieldValues.get(DESCRIPTION_FIELD_NAME).toSourceCode() + : Utils.addDoubleQuotes(Objects.requireNonNullElse(Utils.getDescription(functionSymbol), functionName)); + String parameters = fieldValues.containsKey(SCHEMA_FIELD_NAME) + ? fieldValues.get(SCHEMA_FIELD_NAME).toSourceCode() + : getParameterSchema(functionSymbol, functionDefinitionNode.location()); + return new ToolAnnotationConfig(description, parameters); + } + + private Optional getFunctionSymbol(FunctionDefinitionNode functionDefinitionNode) { + Optional functionSymbol = context.semanticModel().symbol(functionDefinitionNode); + return functionSymbol.filter(symbol -> symbol.kind() == SymbolKind.FUNCTION + || symbol.kind() == SymbolKind.METHOD).map(FunctionSymbol.class::cast); + } + + private Map extractFieldValues(SeparatedNodeList fields) { + return fields.stream() + .filter(field -> field.kind() == SyntaxKind.SPECIFIC_FIELD) + .map(field -> (SpecificFieldNode) field) + .filter(field -> field.valueExpr().isPresent()) + .collect(Collectors.toMap( + field -> field.fieldName().toSourceCode().trim(), + field -> field.valueExpr().orElse(NodeParser.parseExpression(NIL_EXPRESSION)) + )); + } + + private String getParameterSchema(FunctionSymbol functionSymbol, Location alternativeFunctionLocation) { + try { + return SchemaUtils.getParameterSchema(functionSymbol, this.context); + } catch (Exception e) { + Diagnostic diagnostic = CompilationDiagnostic.getDiagnostic(UNABLE_TO_GENERATE_SCHEMA_FOR_FUNCTION, + functionSymbol.getLocation().orElse(alternativeFunctionLocation), + functionSymbol.getName().orElse("unknownFunction")); + reportDiagnostic(diagnostic); + return NIL_EXPRESSION; + } + } + + private void reportDiagnostic(Diagnostic diagnostic) { + this.context.reportDiagnostic(diagnostic); + } + + private void addToModifierContext(SyntaxNodeAnalysisContext context, AnnotationNode annotationNode, + ToolAnnotationConfig functionDefinitionNode) { + this.modifierContextMap.computeIfAbsent(context.documentId(), document -> new ModifierContext()) + .add(annotationNode, functionDefinitionNode); + } +} diff --git a/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/SchemaUtils.java b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/SchemaUtils.java new file mode 100644 index 0000000..9e40cae --- /dev/null +++ b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/SchemaUtils.java @@ -0,0 +1,192 @@ +/* + * Copyright (c) 2025, WSO2 LLC. (http://www.wso2.com). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.ballerina.stdlib.mcp.plugin; + +import io.ballerina.compiler.api.symbols.FunctionSymbol; +import io.ballerina.compiler.api.symbols.FunctionTypeSymbol; +import io.ballerina.compiler.api.symbols.ParameterKind; +import io.ballerina.compiler.api.symbols.ParameterSymbol; +import io.ballerina.openapi.service.mapper.type.TypeMapper; +import io.ballerina.openapi.service.mapper.type.TypeMapperImpl; +import io.ballerina.projects.plugins.SyntaxNodeAnalysisContext; +import io.swagger.v3.core.util.Json; +import io.swagger.v3.core.util.OpenAPISchema2JsonSchema; +import io.swagger.v3.oas.models.media.Schema; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import static io.ballerina.stdlib.mcp.plugin.RemoteFunctionAnalysisTask.EMPTY_STRING; +import static io.ballerina.stdlib.mcp.plugin.RemoteFunctionAnalysisTask.NIL_EXPRESSION; + +/** + * Utility class for generating and manipulating function tool parameter schemas. + */ +public class SchemaUtils { + private static final String STRING = "string"; + private static final String BYTE = "byte"; + private static final String NUMBER = "number"; + + private SchemaUtils() { + } + + public static String getParameterSchema(FunctionSymbol functionSymbol, SyntaxNodeAnalysisContext context) + throws Exception { + FunctionTypeSymbol functionTypeSymbol = functionSymbol.typeDescriptor(); + List parameterSymbolList = functionTypeSymbol.params().get(); + if (functionTypeSymbol.params().isEmpty() || parameterSymbolList.isEmpty()) { + return NIL_EXPRESSION; + } + + Map individualParamSchema = new HashMap<>(); + List requiredParams = new ArrayList<>(); + TypeMapper typeMapper = new TypeMapperImpl(context); + for (ParameterSymbol parameterSymbol : parameterSymbolList) { + try { + String parameterName = parameterSymbol.getName().orElseThrow(); + if (parameterSymbol.paramKind() != ParameterKind.DEFAULTABLE) { + requiredParams.add(parameterName); + } + @SuppressWarnings("rawtypes") + Schema schema = typeMapper.getSchema(parameterSymbol.typeDescriptor()); + String parameterDescription = Utils.getParameterDescription(functionSymbol, parameterName); + schema.setDescription(parameterDescription); + String jsonSchema = SchemaUtils.getJsonSchema(schema); + individualParamSchema.put(parameterName, jsonSchema); + } catch (RuntimeException e) { + throw new Exception(e); + } + } + String properties = individualParamSchema.entrySet().stream() + .map(entry -> String.format("\"%s\": %s", entry.getKey(), entry.getValue())) + .collect(Collectors.joining(", ", "{", "}")); + + String required = requiredParams.stream() + .map(paramName -> String.format("\"%s\"", paramName)) + .collect(Collectors.joining(", ", "[", "]")); + return String.format("{\"type\":\"object\",\"required\":%s,\"properties\":%s}", + required, properties); + } + + @SuppressWarnings("rawtypes") + private static String getJsonSchema(Schema schema) { + modifySchema(schema); + OpenAPISchema2JsonSchema openAPISchema2JsonSchema = new OpenAPISchema2JsonSchema(); + openAPISchema2JsonSchema.process(schema); + String newLineRegex = "\\R"; + String jsonCompressionRegex = "\\s*([{}\\[\\]:,])\\s*"; + return Json.pretty(schema.getJsonSchema()) + .replaceAll(newLineRegex, EMPTY_STRING) + .replaceAll(jsonCompressionRegex, "$1"); + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + private static void modifySchema(Schema schema) { + if (schema == null) { + return; + } + modifySchema(schema.getItems()); + modifySchema(schema.getNot()); + + Map properties = schema.getProperties(); + if (properties != null) { + properties.values().forEach(SchemaUtils::modifySchema); + } + + List allOf = schema.getAllOf(); + if (allOf != null) { + schema.setType(null); + allOf.forEach(SchemaUtils::modifySchema); + } + + List anyOf = schema.getAnyOf(); + if (anyOf != null) { + schema.setType(null); + anyOf.forEach(SchemaUtils::modifySchema); + } + + List oneOf = schema.getOneOf(); + if (oneOf != null) { + schema.setType(null); + oneOf.forEach(SchemaUtils::modifySchema); + } + + // Override default ballerina byte to json schema mapping + if (BYTE.equals(schema.getFormat()) && STRING.equals(schema.getType())) { + schema.setFormat(null); + schema.setType(NUMBER); + } + removeUnwantedFields(schema); + } + + @SuppressWarnings({"rawtypes", "unchecked"}) + private static void removeUnwantedFields(Schema schema) { + schema.setSpecVersion(null); + schema.setSpecVersion(null); + schema.setContains(null); + schema.set$id(null); + schema.set$schema(null); + schema.set$anchor(null); + schema.setExclusiveMaximumValue(null); + schema.setExclusiveMinimumValue(null); + schema.setDiscriminator(null); + schema.setTitle(null); + schema.setMaximum(null); + schema.setExclusiveMaximum(null); + schema.setMinimum(null); + schema.setExclusiveMinimum(null); + schema.setMaxLength(null); + schema.setMinLength(null); + schema.setMaxItems(null); + schema.setMinItems(null); + schema.setMaxProperties(null); + schema.setMinProperties(null); + schema.setAdditionalProperties(null); + schema.setAdditionalProperties(null); + schema.set$ref(null); + schema.set$ref(null); + schema.setReadOnly(null); + schema.setWriteOnly(null); + schema.setExample(null); + schema.setExample(null); + schema.setExternalDocs(null); + schema.setDeprecated(null); + schema.setPrefixItems(null); + schema.setContentEncoding(null); + schema.setContentMediaType(null); + schema.setContentSchema(null); + schema.setPropertyNames(null); + schema.setUnevaluatedProperties(null); + schema.setMaxContains(null); + schema.setMinContains(null); + schema.setAdditionalItems(null); + schema.setUnevaluatedItems(null); + schema.setIf(null); + schema.setElse(null); + schema.setThen(null); + schema.setDependentSchemas(null); + schema.set$comment(null); + schema.setExamples(null); + schema.setExtensions(null); + schema.setConst(null); + } +} diff --git a/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/Utils.java b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/Utils.java new file mode 100644 index 0000000..b4fc750 --- /dev/null +++ b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/Utils.java @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2025, WSO2 LLC. (http://www.wso2.com). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.ballerina.stdlib.mcp.plugin; + +import io.ballerina.compiler.api.symbols.AnnotationSymbol; +import io.ballerina.compiler.api.symbols.Documentable; +import io.ballerina.compiler.api.symbols.FunctionSymbol; +import io.ballerina.compiler.api.symbols.Symbol; + +/** + * Util class for the compiler plugin. + */ +public class Utils { + public static final String BALLERINA_ORG = "ballerina"; + private static final String TOOL_ANNOTATION_NAME = "McpTool"; + private static final String MCP_PACKAGE_NAME = "mcp"; + + private Utils() { + } + + public static boolean isMcpToolAnnotation(AnnotationSymbol annotationSymbol) { + return annotationSymbol.getModule().isPresent() + && isMcpModuleSymbol(annotationSymbol.getModule().get()) + && annotationSymbol.getName().isPresent() + && TOOL_ANNOTATION_NAME.equals(annotationSymbol.getName().get()); + } + + public static boolean isMcpModuleSymbol(Symbol symbol) { + return symbol.getModule().isPresent() + && MCP_PACKAGE_NAME.equals(symbol.getModule().get().id().moduleName()) + && BALLERINA_ORG.equals(symbol.getModule().get().id().orgName()); + } + + public static String getParameterDescription(FunctionSymbol functionSymbol, String parameterName) { + if (functionSymbol.documentation().isEmpty() + || functionSymbol.documentation().get().description().isEmpty()) { + return null; + } + return functionSymbol.documentation().get().parameterMap().getOrDefault(parameterName, null); + } + + public static String getDescription(Documentable documentable) { + if (documentable.documentation().isEmpty() + || documentable.documentation().get().description().isEmpty()) { + return null; + } + return documentable.documentation().get().description().get(); + } + + public static String addDoubleQuotes(String functionName) { + return "\"" + functionName + "\""; + } +} diff --git a/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/diagnostics/CompilationDiagnostic.java b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/diagnostics/CompilationDiagnostic.java new file mode 100644 index 0000000..5b172bd --- /dev/null +++ b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/diagnostics/CompilationDiagnostic.java @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2025, WSO2 LLC. (http://www.wso2.com). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.ballerina.stdlib.mcp.plugin.diagnostics; + +import io.ballerina.tools.diagnostics.Diagnostic; +import io.ballerina.tools.diagnostics.DiagnosticFactory; +import io.ballerina.tools.diagnostics.DiagnosticInfo; +import io.ballerina.tools.diagnostics.DiagnosticSeverity; +import io.ballerina.tools.diagnostics.Location; + +import static io.ballerina.tools.diagnostics.DiagnosticSeverity.ERROR; + +/** + * Compilation errors in the Ballerina AI package. + */ +public enum CompilationDiagnostic { + UNABLE_TO_GENERATE_SCHEMA_FOR_FUNCTION(DiagnosticMessage.ERROR_101, DiagnosticCode.AI_101, ERROR); + + private final String diagnostic; + private final String diagnosticCode; + private final DiagnosticSeverity diagnosticSeverity; + + CompilationDiagnostic(DiagnosticMessage message, DiagnosticCode diagnosticCode, + DiagnosticSeverity diagnosticSeverity) { + this.diagnostic = message.getMessage(); + this.diagnosticCode = diagnosticCode.name(); + this.diagnosticSeverity = diagnosticSeverity; + } + + public static Diagnostic getDiagnostic(CompilationDiagnostic compilationDiagnostic, Location location, + Object... args) { + DiagnosticInfo diagnosticInfo = new DiagnosticInfo( + compilationDiagnostic.getDiagnosticCode(), + compilationDiagnostic.getDiagnostic(), + compilationDiagnostic.getDiagnosticSeverity()); + return DiagnosticFactory.createDiagnostic(diagnosticInfo, location, args); + } + + public String getDiagnostic() { + return diagnostic; + } + + public String getDiagnosticCode() { + return diagnosticCode; + } + + public DiagnosticSeverity getDiagnosticSeverity() { + return this.diagnosticSeverity; + } +} diff --git a/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/diagnostics/DiagnosticCode.java b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/diagnostics/DiagnosticCode.java new file mode 100644 index 0000000..a9b8cc5 --- /dev/null +++ b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/diagnostics/DiagnosticCode.java @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2025, WSO2 LLC. (http://www.wso2.com). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.ballerina.stdlib.mcp.plugin.diagnostics; + +/** + * Compilation error codes used in Ballerina AI package compiler plugin. + */ +public enum DiagnosticCode { + AI_101 +} diff --git a/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/diagnostics/DiagnosticMessage.java b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/diagnostics/DiagnosticMessage.java new file mode 100644 index 0000000..61187f5 --- /dev/null +++ b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/diagnostics/DiagnosticMessage.java @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2025, WSO2 LLC. (http://www.wso2.com). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.ballerina.stdlib.mcp.plugin.diagnostics; + +/** + * Compilation error messages used in Ballerina AI package compiler plugin. + */ +public enum DiagnosticMessage { + ERROR_101("failed to generate the parameter schema definition for the function ''{0}''." + + " Specify the parameter schema manually using the `@ai:AgentTool` annotation's parameter field."); + + private final String message; + + DiagnosticMessage(String message) { + this.message = message; + } + + public String getMessage() { + return this.message; + } +} diff --git a/compiler-plugin/src/main/java/module-info.java b/compiler-plugin/src/main/java/module-info.java new file mode 100644 index 0000000..f94df11 --- /dev/null +++ b/compiler-plugin/src/main/java/module-info.java @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2025 WSO2 LLC. (http://www.wso2.com). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +module io.ballerina.stdlib.ai.plugin { + requires io.ballerina.lang; + requires io.ballerina.parser; + requires io.ballerina.tools.api; + requires io.ballerina.openapi.service; + requires io.swagger.v3.core; + requires io.swagger.v3.oas.models; +} diff --git a/gradle.properties b/gradle.properties index 4f28b41..e707a3f 100644 --- a/gradle.properties +++ b/gradle.properties @@ -4,6 +4,7 @@ version=0.4.3-SNAPSHOT ballerinaLangVersion=2201.12.0 ballerinaGradlePluginVersion=2.3.0 +ballerinaToOpenApiVersion=2.3.0 checkstylePluginVersion=10.12.0 spotbugsPluginVersion=6.0.18 @@ -11,7 +12,7 @@ spotbugsPluginVersion=6.0.18 shadowJarPluginVersion=8.1.1 downloadPluginVersion=5.4.0 releasePluginVersion=2.8.0 - +swaggerVersion=2.2.9 # Ballerina Library Dependencies # Level 01 diff --git a/native/src/main/java/io/ballerina/stdlib/mcp/McpServiceMethodHelper.java b/native/src/main/java/io/ballerina/stdlib/mcp/McpServiceMethodHelper.java new file mode 100644 index 0000000..2e8e09d --- /dev/null +++ b/native/src/main/java/io/ballerina/stdlib/mcp/McpServiceMethodHelper.java @@ -0,0 +1,195 @@ +/* + * Copyright (c) 2025, WSO2 LLC. (http://www.wso2.com). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.ballerina.stdlib.mcp; + +import io.ballerina.runtime.api.Environment; +import io.ballerina.runtime.api.creators.ErrorCreator; +import io.ballerina.runtime.api.creators.ValueCreator; +import io.ballerina.runtime.api.types.ArrayType; +import io.ballerina.runtime.api.types.Parameter; +import io.ballerina.runtime.api.types.RecordType; +import io.ballerina.runtime.api.types.ReferenceType; +import io.ballerina.runtime.api.types.RemoteMethodType; +import io.ballerina.runtime.api.types.ServiceType; +import io.ballerina.runtime.api.types.Type; +import io.ballerina.runtime.api.types.UnionType; +import io.ballerina.runtime.api.values.BArray; +import io.ballerina.runtime.api.values.BMap; +import io.ballerina.runtime.api.values.BObject; +import io.ballerina.runtime.api.values.BString; +import io.ballerina.runtime.api.values.BTypedesc; + +import java.util.List; +import java.util.Optional; + +import static io.ballerina.runtime.api.utils.StringUtils.fromString; + +/** + * Utility class for invoking MCP service remote methods from Java via Ballerina interop. + *

+ * Not instantiable. + */ +public final class McpServiceMethodHelper { + + private static final String FIELD_TOOLS = "tools"; + private static final String FIELD_NAME = "name"; + private static final String FIELD_DESCRIPTION = "description"; + private static final String FIELD_SCHEMA = "schema"; + private static final String FIELD_INPUT_SCHEMA = "inputSchema"; + private static final String FIELD_ARGUMENTS = "arguments"; + private static final String FIELD_CONTENT = "content"; + private static final String FIELD_TYPE = "type"; + private static final String FIELD_TEXT = "text"; + + private static final String ANNOTATION_MCP_TOOL = "McpTool"; + private static final String TYPE_TEXT_CONTENT = "TextContent"; + private static final String VALUE_TEXT = "text"; + + private McpServiceMethodHelper() {} + + /** + * Invoke the 'onListTools' remote method on the given MCP service object. + * + * @param env The Ballerina runtime environment. + * @param mcpService The MCP service object. + * @return Result of remote method invocation. + */ + public static Object invokeOnListTools(Environment env, BObject mcpService) { + return env.getRuntime().callMethod(mcpService, "onListTools", null); + } + + /** + * Invoke the 'onCallTool' remote method on the given MCP service object with parameters. + * + * @param env The Ballerina runtime environment. + * @param mcpService The MCP service object. + * @param params Parameters for the tool invocation. + * @return Result of remote method invocation. + */ + public static Object invokeOnCallTool(Environment env, BObject mcpService, BMap params) { + return env.getRuntime().callMethod(mcpService, "onCallTool", null, params); + } + + /** + * Lists tool metadata for remote functions in the given MCP service. + * + * @param mcpService The MCP service object. + * @param typed The type descriptor for the result. + * @return Record containing the list of tools. + */ + public static Object listToolsForRemoteFunctions(BObject mcpService, BTypedesc typed) { + RecordType resultRecordType = (RecordType) typed.getDescribingType(); + BMap result = ValueCreator.createRecordValue(resultRecordType); + + ArrayType toolsArrayType = (ArrayType) resultRecordType.getFields().get(FIELD_TOOLS).getFieldType(); + BArray tools = ValueCreator.createArrayValue(toolsArrayType); + + for (RemoteMethodType remoteMethod : getRemoteMethods(mcpService)) { + remoteMethod.getAnnotations().entrySet().stream() + .filter(e -> e.getKey().getValue().contains(ANNOTATION_MCP_TOOL)) + .findFirst() + .ifPresent(annotation -> tools.append( + createToolRecord(toolsArrayType, remoteMethod, (BMap) annotation.getValue()) + )); + } + result.put(fromString(FIELD_TOOLS), tools); + return result; + } + + /** + * Invokes a remote function (tool) by name with arguments. + * + * @param env The Ballerina runtime environment. + * @param mcpService The MCP service object. + * @param params The parameters for the tool invocation. + * @param typed The type descriptor for the result. + * @return Record containing the invocation result or an error. + */ + public static Object callToolForRemoteFunctions(Environment env, BObject mcpService, BMap params, + BTypedesc typed) { + BString toolName = (BString) params.get(fromString(FIELD_NAME)); + + RemoteMethodType method = getRemoteMethods(mcpService).stream() + .filter(rmt -> rmt.getName().equals(toolName.getValue())) + .findFirst().orElse(null); + + if (method == null) { + BString errorMessage = + fromString("RemoteMethodType with name '" + toolName.getValue() + "' not found"); + return ErrorCreator.createError(errorMessage); + } + + Object[] args = buildArgsForMethod(method, (BMap) params.get(fromString(FIELD_ARGUMENTS))); + Object result = env.getRuntime().callMethod(mcpService, toolName.getValue(), null, args); + + return createCallToolResult(typed, result); + } + + private static List getRemoteMethods(BObject mcpService) { + ServiceType serviceType = (ServiceType) mcpService.getOriginalType(); + return List.of(serviceType.getRemoteMethods()); + } + + private static BMap createToolRecord(ArrayType toolsArrayType, RemoteMethodType remoteMethod, + BMap annotationValue) { + RecordType toolRecordType = (RecordType) ((ReferenceType) toolsArrayType.getElementType()).getReferredType(); + BMap tool = ValueCreator.createRecordValue(toolRecordType); + + tool.put(fromString(FIELD_NAME), fromString(remoteMethod.getName())); + tool.put(fromString(FIELD_DESCRIPTION), annotationValue.get(fromString(FIELD_DESCRIPTION))); + tool.put(fromString(FIELD_INPUT_SCHEMA), annotationValue.get(fromString(FIELD_SCHEMA))); + return tool; + } + + private static Object[] buildArgsForMethod(RemoteMethodType method, BMap arguments) { + List params = List.of(method.getParameters()); + Object[] args = new Object[params.size()]; + for (int i = 0; i < params.size(); i++) { + String paramName = params.get(i).name; + args[i] = arguments == null ? null : arguments.get(fromString(paramName)); + } + return args; + } + + private static Object createCallToolResult(BTypedesc typed, Object result) { + RecordType resultRecordType = (RecordType) typed.getDescribingType(); + BMap callToolResult = ValueCreator.createRecordValue(resultRecordType); + + ArrayType contentArrayType = (ArrayType) resultRecordType.getFields().get(FIELD_CONTENT).getFieldType(); + BArray contentArray = ValueCreator.createArrayValue(contentArrayType); + + UnionType contentUnionType = (UnionType) contentArrayType.getElementType(); + Optional textContentTypeOpt = contentUnionType.getMemberTypes().stream() + .filter(type -> TYPE_TEXT_CONTENT.equals(type.getName())) + .findFirst(); + if (textContentTypeOpt.isEmpty()) { + BString errorMessage = + fromString("No member type named 'TextContent' found in content union type."); + return ErrorCreator.createError(errorMessage); + } + RecordType textContentRecordType = (RecordType) ((ReferenceType) textContentTypeOpt.get()).getReferredType(); + BMap textContent = ValueCreator.createRecordValue(textContentRecordType); + textContent.put(fromString(FIELD_TYPE), fromString(VALUE_TEXT)); + textContent.put(fromString(FIELD_TEXT), fromString(result == null ? "" : result.toString())); + contentArray.append(textContent); + + callToolResult.put(fromString(FIELD_CONTENT), contentArray); + return callToolResult; + } +} diff --git a/settings.gradle b/settings.gradle index a00d7fb..c5cfee4 100644 --- a/settings.gradle +++ b/settings.gradle @@ -37,10 +37,12 @@ rootProject.name = 'mcp' include ':checkstyle' include ':mcp-ballerina' include ':mcp-native' +include ':mcp-compiler-plugin' project(':checkstyle').projectDir = file("build-config${File.separator}checkstyle") project(':mcp-ballerina').projectDir = file('ballerina') project(':mcp-native').projectDir = file('native') +project(':mcp-compiler-plugin').projectDir = file('compiler-plugin') gradleEnterprise { buildScan { diff --git a/spotbugs-exclude.xml b/spotbugs-exclude.xml index 58da414..1d22b46 100644 --- a/spotbugs-exclude.xml +++ b/spotbugs-exclude.xml @@ -16,4 +16,8 @@ ~ under the License. --> + + + + From ad45fc32a911214371fde047dc915fa1eedb9af4 Mon Sep 17 00:00:00 2001 From: Azeem Muzammil Date: Tue, 24 Jun 2025 13:53:50 +0530 Subject: [PATCH 03/31] [Automated] Update the native jar versions --- ballerina/Dependencies.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ballerina/Dependencies.toml b/ballerina/Dependencies.toml index 18521e4..87e6a30 100644 --- a/ballerina/Dependencies.toml +++ b/ballerina/Dependencies.toml @@ -70,7 +70,7 @@ dependencies = [ [[package]] org = "ballerina" name = "http" -version = "2.14.2" +version = "2.14.1" dependencies = [ {org = "ballerina", name = "auth"}, {org = "ballerina", name = "cache"}, From ea3f627e2a69188b56abff3cc20da8b784128bd0 Mon Sep 17 00:00:00 2001 From: Azeem Muzammil Date: Tue, 24 Jun 2025 14:10:07 +0530 Subject: [PATCH 04/31] Remove unnecessary code --- ballerina/basic_dispatcher_service.bal | 252 ------------------ ballerina/basic_listener.bal | 85 ------ ballerina/dispatcher_service.bal | 50 +--- ballerina/listener.bal | 44 ++- ballerina/main.bal | 52 ++-- ...istener.bal => native_listener_helper.bal} | 4 +- ballerina/types.bal | 24 +- 7 files changed, 90 insertions(+), 421 deletions(-) delete mode 100644 ballerina/basic_dispatcher_service.bal delete mode 100644 ballerina/basic_listener.bal rename ballerina/{native_listener.bal => native_listener_helper.bal} (64%) diff --git a/ballerina/basic_dispatcher_service.bal b/ballerina/basic_dispatcher_service.bal deleted file mode 100644 index ca5712b..0000000 --- a/ballerina/basic_dispatcher_service.bal +++ /dev/null @@ -1,252 +0,0 @@ -// Copyright (c) 2025 WSO2 LLC (http://www.wso2.com). -// -// WSO2 LLC. licenses this file to you under the Apache License, -// Version 2.0 (the "License"); you may not use this file except -// in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -import ballerina/http; -import ballerina/io; -// import ballerina/io; -import ballerina/uuid; - -type BasicDispatcherService distinct service object { - *http:Service; - - isolated function addServiceRef(BasicMcpService basicMcpService); - isolated function removeServiceRef(); - isolated function setServerConfigs(ServerConfiguration serverConfigs); -}; - -BasicDispatcherService basicDispatcherService = isolated service object { - private ServerConfiguration? serverConfigs = (); - private BasicMcpService? basicMcpService = (); - private boolean isInitialized = false; - private string? sessionId = (); - - isolated function addServiceRef(BasicMcpService basicMcpService) { - lock { - self.basicMcpService = basicMcpService; - } - } - - isolated function removeServiceRef() { - lock { - self.basicMcpService = (); - } - } - - isolated function setServerConfigs(ServerConfiguration serverConfigs) { - lock { - self.serverConfigs = serverConfigs.cloneReadOnly(); - } - } - - isolated resource function get .() returns error? { - - } - - isolated resource function post .(@http:Payload JsonRpcMessage request, http:Headers headers) returns http:BadRequest|http:Accepted|http:Ok|error { - lock { - io:println("Received request: ", request.cloneReadOnly()); - string|http:HeaderNotFoundError acceptHeader = headers.getHeader(ACCEPT_HEADER); - if acceptHeader is http:HeaderNotFoundError { - return { - body: { - jsonrpc: JSONRPC_VERSION, - id: 1, - 'error: { - code: -32000, - message: "Not Acceptable: Client must accept both application/json and text/event-stream" - } - } - }; - } - if !acceptHeader.includes(CONTENT_TYPE_JSON) || !acceptHeader.includes(CONTENT_TYPE_SSE) { - return { - body: { - jsonrpc: JSONRPC_VERSION, - id: 1, - 'error: { - code: -32000, - message: "Not Acceptable: Client must accept both application/json and text/event-stream" - } - } - }; - } - - string|http:HeaderNotFoundError contentTypeHeader = headers.getHeader(CONTENT_TYPE_HEADER); - if contentTypeHeader is http:HeaderNotFoundError { - return { - body: { - jsonrpc: JSONRPC_VERSION, - id: null, - 'error: { - code: -32000, - message: "Unsupported Media Type: Content-Type must be application/json" - } - } - }; - } - if !contentTypeHeader.includes(CONTENT_TYPE_JSON) { - return { - body: { - jsonrpc: JSONRPC_VERSION, - id: null, - 'error: { - code: -32000, - message: "Unsupported Media Type: Content-Type must be application/json" - } - } - }; - } - - if request is JsonRpcRequest { - if request.method == "initialize" { - if self.isInitialized && self.sessionId != () { - return { - body: { - jsonrpc: JSONRPC_VERSION, - id: null, - 'error: { - code: -32600, - message: "Invalid Request: Only one initialization request is allowed" - } - } - }; - } - self.isInitialized = true; - self.sessionId = uuid:createRandomUuid(); - - final string requestedVersion = check (request.params["protocolVersion"]).cloneWithType(); - final readonly & ServerCapabilities? capabilities = (self.serverConfigs?.options?.capabilities).cloneReadOnly(); - final readonly & Implementation? serverInfo = (self.serverConfigs?.serverInfo).cloneReadOnly(); - - if serverInfo is () { - return { - body: { - jsonrpc: JSONRPC_VERSION, - id: null, - 'error: { - code: -32000, - message: "Server Info not provided in configuration" - } - } - }; - } - - string protocolVersion = SUPPORTED_PROTOCOL_VERSIONS.some(v => v == requestedVersion) ? requestedVersion - : LATEST_PROTOCOL_VERSION; - - return { - headers: { - [SESSION_ID_HEADER]: self.sessionId ?: "" - }, - body: { - jsonrpc: JSONRPC_VERSION, - id: request.id, - result: { - protocolVersion: protocolVersion, - capabilities: capabilities ?: {}, - serverInfo: serverInfo - } - } - }; - } else if request.method == "tools/list" { - ListToolsResult listToolsResult = check self.executeListTools(); - io:println("ListToolsResult: ", listToolsResult.cloneReadOnly()); - return { - headers: { - [SESSION_ID_HEADER]: self.sessionId ?: "" - }, - body: { - jsonrpc: JSONRPC_VERSION, - id: request.id, - result: listToolsResult.cloneReadOnly() - } - }; - } else if request.method == "tools/call" { - CallToolParams params = check request.cloneReadOnly().params.ensureType(CallToolParams); - CallToolResult callToolResult = check self.executeCallTool(params); - return { - headers: { - [SESSION_ID_HEADER]: self.sessionId ?: "" - }, - body: { - jsonrpc: JSONRPC_VERSION, - id: request.id, - result: callToolResult.cloneReadOnly() - } - }; - } else { - return { - body: { - jsonrpc: JSONRPC_VERSION, - id: request.id, - 'error: { - code: -32601, - message: "Method not found" - } - } - }; - } - } - else if request is JsonRpcNotification { - if request.method == "notifications/initialized" { - return http:ACCEPTED; - } - } - } - // if request is InitializeRequest { - // lock { - // if self.isInitialized && self.sessionId != () { - // return { - // jsonrpc: JSONRPC_VERSION, - // id: null, - // 'error: { - // code: -32600, - // message: "Invalid Request: Only one initialization request is allowed" - // } - // }; - // } - // self.isInitialized = true; - // self.sessionId = uuid:createRandomUuid(); - // io:println("Session initialized with ID: ", self.sessionId); - // } - // } else if request is ListToolsRequest { - // io:println("Received ListToolsRequest"); - // } else if request is CallToolRequest { - // io:println("Received CallToolRequest"); - // } - return error("Unsupported request type"); - } - - private isolated function executeListTools() returns ListToolsResult|error { - lock { - BasicMcpService? chatService = self.basicMcpService; - if chatService is BasicMcpService { - return check listToolsForRemoteFunctions(chatService); - } - return error("MCP Service is not attached"); - } - } - - private isolated function executeCallTool(CallToolParams params) returns CallToolResult|error { - lock { - BasicMcpService? chatService = self.basicMcpService; - if chatService is BasicMcpService { - return check callToolForRemoteFunctions(chatService, params.cloneReadOnly()); - } - return error("MCP Service is not attached"); - } - } -}; diff --git a/ballerina/basic_listener.bal b/ballerina/basic_listener.bal deleted file mode 100644 index 42c1028..0000000 --- a/ballerina/basic_listener.bal +++ /dev/null @@ -1,85 +0,0 @@ -// Copyright (c) 2025 WSO2 LLC (http://www.wso2.com). -// -// WSO2 LLC. licenses this file to you under the Apache License, -// Version 2.0 (the "License"); you may not use this file except -// in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -import ballerina/http; - -public type McpToolConfig record {| - string description?; - json schema?; -|}; - -public annotation McpToolConfig McpTool on object function; - -# A server listener for handling MCP service requests. -public class BasicListener { - private http:Listener httpListener; - private BasicDispatcherService dispatcherService; - - # Initializes the Listener. - # - # + listenTo - Either a port number (int) or an existing http:Listener. - # + config - Optional listener configuration. - # + return - error? if listener initialization fails. - public function init(int|http:Listener listenTo, ServerConfiguration serverConfigs, *ListenerConfiguration config) returns error? { - if listenTo is http:Listener { - self.httpListener = listenTo; - } else { - self.httpListener = check new (listenTo, config); - } - self.dispatcherService = basicDispatcherService; - self.dispatcherService.setServerConfigs(serverConfigs); - } - - # Attaches an MCP service to the listener under the specified path(s). - # - # + mcpService - Service to attach. - # + name - Path(s) to mount the service on (string or string array). - # + return - error? if attachment fails. - public isolated function attach(BasicMcpService basicMcpService, string[]|string? name = ()) returns error? { - check self.httpListener.attach(self.dispatcherService, name); - self.dispatcherService.addServiceRef(basicMcpService); - } - - # Detaches the MCP service from the listener. - # - # + mcpService - Service to detach. - # + return - error? if detachment fails. - public isolated function detach(BasicMcpService basicMcpService) returns error? { - check self.httpListener.detach(self.dispatcherService); - self.dispatcherService.removeServiceRef(); - } - - # Starts the listener (begin accepting connections). - # - # + return - error? if starting fails. - public isolated function 'start() returns error? { - check self.httpListener.start(); - } - - # Gracefully stops the listener (completes active requests before shutting down). - # - # + return - error? if graceful stop fails. - public isolated function gracefulStop() returns error? { - check self.httpListener.gracefulStop(); - } - - # Immediately stops the listener (terminates all connections). - # - # + return - error? if immediate stop fails. - public isolated function immediateStop() returns error? { - check self.httpListener.immediateStop(); - } -} diff --git a/ballerina/dispatcher_service.bal b/ballerina/dispatcher_service.bal index 67dca26..72acde2 100644 --- a/ballerina/dispatcher_service.bal +++ b/ballerina/dispatcher_service.bal @@ -18,23 +18,21 @@ import ballerina/http; import ballerina/io; import ballerina/uuid; -// import ballerina/io; - type DispatcherService distinct service object { *http:Service; - isolated function addServiceRef(McpService mcpService); + isolated function addServiceRef(McpService|McpDeclarativeService mcpService); isolated function removeServiceRef(); isolated function setServerConfigs(ServerConfiguration serverConfigs); }; DispatcherService dispatcherService = isolated service object { private ServerConfiguration? serverConfigs = (); - private McpService? mcpService = (); + private McpService|McpDeclarativeService? mcpService = (); private boolean isInitialized = false; private string? sessionId = (); - isolated function addServiceRef(McpService mcpService) { + isolated function addServiceRef(McpService|McpDeclarativeService mcpService) { lock { self.mcpService = mcpService; } @@ -52,13 +50,8 @@ DispatcherService dispatcherService = isolated service object { } } - isolated resource function get .() returns error? { - - } - isolated resource function post .(@http:Payload JsonRpcMessage request, http:Headers headers) returns http:BadRequest|http:Accepted|http:Ok|error { lock { - io:println("Received request: ", request.cloneReadOnly()); string|http:HeaderNotFoundError acceptHeader = headers.getHeader(ACCEPT_HEADER); if acceptHeader is http:HeaderNotFoundError { return { @@ -207,35 +200,16 @@ DispatcherService dispatcherService = isolated service object { } } } - // if request is InitializeRequest { - // lock { - // if self.isInitialized && self.sessionId != () { - // return { - // jsonrpc: JSONRPC_VERSION, - // id: null, - // 'error: { - // code: -32600, - // message: "Invalid Request: Only one initialization request is allowed" - // } - // }; - // } - // self.isInitialized = true; - // self.sessionId = uuid:createRandomUuid(); - // io:println("Session initialized with ID: ", self.sessionId); - // } - // } else if request is ListToolsRequest { - // io:println("Received ListToolsRequest"); - // } else if request is CallToolRequest { - // io:println("Received CallToolRequest"); - // } return error("Unsupported request type"); } private isolated function executeOnListTools() returns ListToolsResult|error { lock { - McpService? chatService = self.mcpService; - if chatService is McpService { - return check invokeOnListTools(chatService); + McpService|McpDeclarativeService? mcpService = self.mcpService; + if mcpService is McpService { + return check invokeOnListTools(mcpService); + } else if mcpService is McpDeclarativeService { + return check listToolsForRemoteFunctions(mcpService); } return error("MCP Service is not attached"); } @@ -243,9 +217,11 @@ DispatcherService dispatcherService = isolated service object { private isolated function executeOnCallTool(CallToolParams params) returns CallToolResult|error { lock { - McpService? chatService = self.mcpService; - if chatService is McpService { - return check invokeOnCallTool(chatService, params.cloneReadOnly()); + McpService|McpDeclarativeService? mcpService = self.mcpService; + if mcpService is McpService { + return check invokeOnCallTool(mcpService, params.cloneReadOnly()); + } else if mcpService is McpDeclarativeService { + return check callToolForRemoteFunctions(mcpService, params.cloneReadOnly()); } return error("MCP Service is not attached"); } diff --git a/ballerina/listener.bal b/ballerina/listener.bal index 9a065af..6c12f80 100644 --- a/ballerina/listener.bal +++ b/ballerina/listener.bal @@ -16,67 +16,85 @@ import ballerina/http; +# Represents the options for configuring an MCP server. public type ServerOptions record {| *ProtocolOptions; + # Capabilities to advertise as being supported by this server. ServerCapabilities capabilities?; + # Optional instructions describing how to use the server and its features. string instructions?; |}; +# Configuration options for initializing an MCP listener. +public type ListenerConfiguration record {| + *http:ListenerConfiguration; + *ServerConfiguration; +|}; + +type ServerConfiguration record {| + Implementation serverInfo; + ServerOptions options?; +|}; + # A server listener for handling MCP service requests. public class Listener { private http:Listener httpListener; private DispatcherService dispatcherService; # Initializes the Listener. - # + # # + listenTo - Either a port number (int) or an existing http:Listener. - # + config - Optional listener configuration. + # + config - Listener configuration. # + return - error? if listener initialization fails. - public function init(int|http:Listener listenTo, ServerConfiguration serverConfigs, *ListenerConfiguration config) returns error? { + public function init(int|http:Listener listenTo, *ListenerConfiguration config) returns error? { + ListenerConfiguration {serverInfo, options, ...listenerConfig} = config; if listenTo is http:Listener { self.httpListener = listenTo; - } else { - self.httpListener = check new (listenTo, config); + } else { + self.httpListener = check new (listenTo, listenerConfig); } self.dispatcherService = dispatcherService; - self.dispatcherService.setServerConfigs(serverConfigs); + self.dispatcherService.setServerConfigs({ + serverInfo, + options + }); } # Attaches an MCP service to the listener under the specified path(s). - # + # # + mcpService - Service to attach. # + name - Path(s) to mount the service on (string or string array). # + return - error? if attachment fails. - public isolated function attach(McpService mcpService, string[]|string? name = ()) returns error? { + public isolated function attach(McpService|McpDeclarativeService mcpService, string[]|string? name = ()) returns error? { check self.httpListener.attach(self.dispatcherService, name); self.dispatcherService.addServiceRef(mcpService); } # Detaches the MCP service from the listener. - # + # # + mcpService - Service to detach. # + return - error? if detachment fails. - public isolated function detach(McpService mcpService) returns error? { + public isolated function detach(McpService|McpDeclarativeService mcpService) returns error? { check self.httpListener.detach(self.dispatcherService); self.dispatcherService.removeServiceRef(); } # Starts the listener (begin accepting connections). - # + # # + return - error? if starting fails. public isolated function 'start() returns error? { check self.httpListener.start(); } # Gracefully stops the listener (completes active requests before shutting down). - # + # # + return - error? if graceful stop fails. public isolated function gracefulStop() returns error? { check self.httpListener.gracefulStop(); } # Immediately stops the listener (terminates all connections). - # + # # + return - error? if immediate stop fails. public isolated function immediateStop() returns error? { check self.httpListener.immediateStop(); diff --git a/ballerina/main.bal b/ballerina/main.bal index 863715f..5564257 100644 --- a/ballerina/main.bal +++ b/ballerina/main.bal @@ -63,27 +63,39 @@ // } // } -listener BasicListener basicListener = check new (9091, serverConfigs = { - serverInfo: { - name: "Basic MCP Server", - version: "1.0.0" - }, - options: {capabilities: {}} -}); +// listener BasicListener basicListener = check new (9091, serverConfigs = { +// serverInfo: { +// name: "Basic MCP Server", +// version: "1.0.0" +// }, +// options: {capabilities: {}} +// }); + +// service /mcp on basicListener { +// @McpTool { +// description: "Add two numbers", +// schema: { +// 'type: "object", +// properties: { +// "a": {"type": "integer", "description": "First number"}, +// "b": {"type": "integer", "description": "Second number"} +// }, +// required: ["a", "b"] +// } +// } +// remote function add(int a, int b) returns int { +// return a + b; +// } +// } + +listener Listener mcpListener = check new (9090, serverInfo = {name: "", version: ""}); -service /mcp on basicListener { - @McpTool { - description: "Add two numbers", - schema: { - 'type: "object", - properties: { - "a": {"type": "integer", "description": "First number"}, - "b": {"type": "integer", "description": "Second number"} - }, - required: ["a", "b"] - } +isolated service McpService /mcp on mcpListener { + remote isolated function onListTools() returns ListToolsResult|error { + return error("Not implemented"); } - remote function add(int a, int b) returns int { - return a + b; + + remote isolated function onCallTool(CallToolParams params) returns CallToolResult|error { + return error("Not implemented"); } } diff --git a/ballerina/native_listener.bal b/ballerina/native_listener_helper.bal similarity index 64% rename from ballerina/native_listener.bal rename to ballerina/native_listener_helper.bal index 06cc428..a948ff9 100644 --- a/ballerina/native_listener.bal +++ b/ballerina/native_listener_helper.bal @@ -8,10 +8,10 @@ isolated function invokeOnCallTool(McpService 'service, CallToolParams params) r 'class: "io.ballerina.stdlib.mcp.McpServiceMethodHelper" } external; -isolated function listToolsForRemoteFunctions(BasicMcpService 'service, typedesc t = <>) returns t|error = @java:Method { +isolated function listToolsForRemoteFunctions(McpDeclarativeService 'service, typedesc t = <>) returns t|error = @java:Method { 'class: "io.ballerina.stdlib.mcp.McpServiceMethodHelper" } external; -isolated function callToolForRemoteFunctions(BasicMcpService 'service, CallToolParams params, typedesc t = <>) returns t|error = @java:Method { +isolated function callToolForRemoteFunctions(McpDeclarativeService 'service, CallToolParams params, typedesc t = <>) returns t|error = @java:Method { 'class: "io.ballerina.stdlib.mcp.McpServiceMethodHelper" } external; diff --git a/ballerina/types.bal b/ballerina/types.bal index 5332bfc..2d132e1 100644 --- a/ballerina/types.bal +++ b/ballerina/types.bal @@ -14,8 +14,6 @@ // specific language governing permissions and limitations // under the License. -import ballerina/http; - # Refers to any valid JSON-RPC object that can be decoded off the wire, or encoded to be sent. public type JsonRpcMessage JsonRpcRequest|JsonRpcNotification|JsonRpcResponse; @@ -414,21 +412,23 @@ public type AudioContent record { # Represents a result sent from the server to the client. public type ServerResult InitializeResult|CallToolResult|ListToolsResult|EmptyResult; +# Represents a tool configuration that can be used to define tools available in the MCP service. +public type McpToolConfig record {| + # The description of the tool. + string description?; + # The JSON schema for the tool's parameters. + json schema?; +|}; + +# Annotation to mark a function as an MCP tool configuration. +public annotation McpToolConfig McpTool on object function; + # Defines a mcp service interface that handles incoming mcp requests. public type McpService distinct isolated service object { remote isolated function onListTools() returns ListToolsResult|error; remote isolated function onCallTool(CallToolParams params) returns CallToolResult|error; }; -public type BasicMcpService distinct isolated service object { +public type McpDeclarativeService distinct isolated service object { }; - -public type ListenerConfiguration record {| - *http:ListenerConfiguration; -|}; - -public type ServerConfiguration record {| - Implementation serverInfo; - ServerOptions options?; -|}; From 91bc0aff6561ebdd9487004cc60359de1a027ad4 Mon Sep 17 00:00:00 2001 From: Azeem Muzammil Date: Tue, 24 Jun 2025 16:30:11 +0530 Subject: [PATCH 05/31] [Automated] Update the native jar versions --- ballerina/Dependencies.toml | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/ballerina/Dependencies.toml b/ballerina/Dependencies.toml index 87e6a30..2e2263d 100644 --- a/ballerina/Dependencies.toml +++ b/ballerina/Dependencies.toml @@ -70,7 +70,7 @@ dependencies = [ [[package]] org = "ballerina" name = "http" -version = "2.14.1" +version = "2.14.2" dependencies = [ {org = "ballerina", name = "auth"}, {org = "ballerina", name = "cache"}, @@ -108,9 +108,6 @@ dependencies = [ {org = "ballerina", name = "jballerina.java"}, {org = "ballerina", name = "lang.value"} ] -modules = [ - {org = "ballerina", packageName = "io", moduleName = "io"} -] [[package]] org = "ballerina" From 2c9086e2f1fa4f3c3be265fcc4e0949276b4b54f Mon Sep 17 00:00:00 2001 From: Azeem Muzammil Date: Tue, 24 Jun 2025 16:30:40 +0530 Subject: [PATCH 06/31] Refactor dispatcher service --- ballerina/dispatcher_service.bal | 325 ++++++++++++++++++------------- ballerina/error.bal | 3 + ballerina/main.bal | 46 ++--- 3 files changed, 209 insertions(+), 165 deletions(-) diff --git a/ballerina/dispatcher_service.bal b/ballerina/dispatcher_service.bal index 72acde2..0984fae 100644 --- a/ballerina/dispatcher_service.bal +++ b/ballerina/dispatcher_service.bal @@ -15,9 +15,12 @@ // under the License. import ballerina/http; -import ballerina/io; import ballerina/uuid; +# Custom error type for dispatcher service operations. +type DispatcherError distinct error; + +# Represents the dispatcher service type definition. type DispatcherService distinct service object { *http:Service; @@ -50,157 +53,213 @@ DispatcherService dispatcherService = isolated service object { } } - isolated resource function post .(@http:Payload JsonRpcMessage request, http:Headers headers) returns http:BadRequest|http:Accepted|http:Ok|error { + isolated resource function post .(@http:Payload JsonRpcMessage request, http:Headers headers) + returns http:BadRequest|http:Accepted|http:Ok|error { + + http:BadRequest? headerValidationError = self.validateHeaders(headers); + if headerValidationError is http:BadRequest { + return headerValidationError; + } + lock { - string|http:HeaderNotFoundError acceptHeader = headers.getHeader(ACCEPT_HEADER); - if acceptHeader is http:HeaderNotFoundError { - return { - body: { - jsonrpc: JSONRPC_VERSION, - id: 1, - 'error: { - code: -32000, - message: "Not Acceptable: Client must accept both application/json and text/event-stream" - } - } - }; + if request is JsonRpcRequest { + return self.processJsonRpcRequest(request.cloneReadOnly()); } - if !acceptHeader.includes(CONTENT_TYPE_JSON) || !acceptHeader.includes(CONTENT_TYPE_SSE) { - return { - body: { - jsonrpc: JSONRPC_VERSION, - id: 1, - 'error: { - code: -32000, - message: "Not Acceptable: Client must accept both application/json and text/event-stream" - } - } - }; + else if request is JsonRpcNotification { + return self.processJsonRpcNotification(request.cloneReadOnly()); } - string|http:HeaderNotFoundError contentTypeHeader = headers.getHeader(CONTENT_TYPE_HEADER); - if contentTypeHeader is http:HeaderNotFoundError { - return { - body: { - jsonrpc: JSONRPC_VERSION, - id: null, - 'error: { - code: -32000, - message: "Unsupported Media Type: Content-Type must be application/json" - } - } - }; + return self.createErrorResponse(null, INVALID_REQUEST, "Unsupported request type"); + } + } + + private isolated function validateHeaders(http:Headers headers) returns http:BadRequest? { + // Validate Accept header + string|http:HeaderNotFoundError acceptHeader = headers.getHeader(ACCEPT_HEADER); + if acceptHeader is http:HeaderNotFoundError { + return self.createErrorResponse(1, -32000, + "Not Acceptable: Client must accept both application/json and text/event-stream"); + } + + if !acceptHeader.includes(CONTENT_TYPE_JSON) || !acceptHeader.includes(CONTENT_TYPE_SSE) { + return self.createErrorResponse(1, -32000, + "Not Acceptable: Client must accept both application/json and text/event-stream"); + } + + // Validate Content-Type header + string|http:HeaderNotFoundError contentTypeHeader = headers.getHeader(CONTENT_TYPE_HEADER); + if contentTypeHeader is http:HeaderNotFoundError { + return self.createErrorResponse(null, -32000, + "Unsupported Media Type: Content-Type must be application/json"); + } + + if !contentTypeHeader.includes(CONTENT_TYPE_JSON) { + return self.createErrorResponse(null, -32000, + "Unsupported Media Type: Content-Type must be application/json"); + } + + return (); + } + + private isolated function processJsonRpcRequest(JsonRpcRequest request) returns http:BadRequest|http:Ok|error { + match request.method { + "initialize" => { + return self.handleInitializeRequest(request); + } + "tools/list" => { + return self.handleListToolsRequest(request); + } + "tools/call" => { + return self.handleCallToolRequest(request); } - if !contentTypeHeader.includes(CONTENT_TYPE_JSON) { + _ => { + return self.createErrorResponse(request.id, METHOD_NOT_FOUND, "Method not found"); + } + } + } + + private isolated function processJsonRpcNotification(JsonRpcNotification notification) returns http:Accepted|http:BadRequest { + match notification.method { + "notifications/initialized" => { + return http:ACCEPTED; + } + _ => { return { body: { jsonrpc: JSONRPC_VERSION, - id: null, 'error: { - code: -32000, - message: "Unsupported Media Type: Content-Type must be application/json" + code: METHOD_NOT_FOUND, + message: "Unknown notification method" } } }; } + } + } - if request is JsonRpcRequest { - if request.method == "initialize" { - if self.isInitialized && self.sessionId != () { - return { - body: { - jsonrpc: JSONRPC_VERSION, - id: null, - 'error: { - code: -32600, - message: "Invalid Request: Only one initialization request is allowed" - } - } - }; - } - self.isInitialized = true; - self.sessionId = uuid:createRandomUuid(); - - final string requestedVersion = check (request.params["protocolVersion"]).cloneWithType(); - final readonly & ServerCapabilities? capabilities = (self.serverConfigs?.options?.capabilities).cloneReadOnly(); - final readonly & Implementation? serverInfo = (self.serverConfigs?.serverInfo).cloneReadOnly(); - - if serverInfo is () { - return { - body: { - jsonrpc: JSONRPC_VERSION, - id: null, - 'error: { - code: -32000, - message: "Server Info not provided in configuration" - } - } - }; + private isolated function handleInitializeRequest(JsonRpcRequest jsonRpcRequest) returns http:BadRequest|http:Ok { + JsonRpcRequest {jsonrpc, id, ...request} = jsonRpcRequest; + InitializeRequest|error initRequest = request.cloneWithType(InitializeRequest); + if initRequest is error { + return self.createErrorResponse(id, INVALID_REQUEST, + string `Invalid request: ${initRequest.message()}`); + } + + lock { + // If it's a server with session management and the session ID is already set we should reject the request + // to avoid re-initialization. + if self.isInitialized && self.sessionId != () { + return self.createErrorResponse(id, INVALID_REQUEST, + "Invalid Request: Only one initialization request is allowed"); + } + + self.isInitialized = true; + self.sessionId = uuid:createRandomUuid(); + + string requestedVersion = initRequest.params.protocolVersion; + string protocolVersion = self.selectProtocolVersion(requestedVersion); + + return { + headers: { + [SESSION_ID_HEADER]: self.sessionId ?: "" + }, + body: { + jsonrpc: JSONRPC_VERSION, + id: id, + result: { + protocolVersion: protocolVersion, + capabilities: (self.serverConfigs?.options?.capabilities ?: {}).cloneReadOnly(), + serverInfo: (self.serverConfigs?.serverInfo ?: { + name: "MCP Server", + version: "1.0.0" + }).cloneReadOnly() } + } + }; + } + } - string protocolVersion = SUPPORTED_PROTOCOL_VERSIONS.some(v => v == requestedVersion) ? requestedVersion - : LATEST_PROTOCOL_VERSION; - - return { - headers: { - [SESSION_ID_HEADER]: self.sessionId ?: "" - }, - body: { - jsonrpc: JSONRPC_VERSION, - id: request.id, - result: { - protocolVersion: protocolVersion, - capabilities: capabilities ?: {}, - serverInfo: serverInfo - } - } - }; - } else if request.method == "tools/list" { - ListToolsResult listToolsResult = check self.executeOnListTools(); - io:println("ListToolsResult: ", listToolsResult.cloneReadOnly()); - return { - headers: { - [SESSION_ID_HEADER]: self.sessionId ?: "" - }, - body: { - jsonrpc: JSONRPC_VERSION, - id: request.id, - result: listToolsResult.cloneReadOnly() - } - }; - } else if request.method == "tools/call" { - CallToolParams params = check request.cloneReadOnly().params.ensureType(CallToolParams); - CallToolResult callToolResult = check self.executeOnCallTool(params); - return { - headers: { - [SESSION_ID_HEADER]: self.sessionId ?: "" - }, - body: { - jsonrpc: JSONRPC_VERSION, - id: request.id, - result: callToolResult.cloneReadOnly() - } - }; - } else { - return { - body: { - jsonrpc: JSONRPC_VERSION, - id: request.id, - 'error: { - code: -32601, - message: "Method not found" - } - } - }; + private isolated function handleListToolsRequest(JsonRpcRequest request) returns http:BadRequest|http:Ok { + lock { + // Check if initialized + if !self.isInitialized { + return self.createErrorResponse(request.id, INVALID_REQUEST, + "Client must be initialized before making requests"); + } + + ListToolsResult|error listToolsResult = self.executeOnListTools(); + if listToolsResult is error { + return self.createErrorResponse(request.id, INTERNAL_ERROR, + string `Failed to list tools: ${listToolsResult.message()}`); + } + + return { + headers: { + [SESSION_ID_HEADER]: self.sessionId ?: "" + }, + body: { + jsonrpc: JSONRPC_VERSION, + id: request.id, + result: listToolsResult.cloneReadOnly() } + }; + } + } + + private isolated function handleCallToolRequest(JsonRpcRequest request) returns http:BadRequest|http:Ok { + lock { + // Check if initialized + if !self.isInitialized { + return self.createErrorResponse(request.id, INVALID_REQUEST, + "Client must be initialized before making requests"); } - else if request is JsonRpcNotification { - if request.method == "notifications/initialized" { - return http:ACCEPTED; + + // Extract and validate parameters + CallToolParams|error params = request.cloneReadOnly().params.ensureType(CallToolParams); + if params is error { + return self.createErrorResponse(request.id, INVALID_PARAMS, + string `Invalid parameters: ${params.message()}`); + } + + CallToolResult|error callToolResult = self.executeOnCallTool(params); + if callToolResult is error { + return self.createErrorResponse(request.id, INTERNAL_ERROR, + string `Failed to call tool '${params.name}': ${callToolResult.message()}`); + } + + return { + headers: { + [SESSION_ID_HEADER]: self.sessionId ?: "" + }, + body: { + jsonrpc: JSONRPC_VERSION, + id: request.id, + result: callToolResult.cloneReadOnly() } + }; + } + } + + private isolated function selectProtocolVersion(string requestedVersion) returns string { + foreach string supportedVersion in SUPPORTED_PROTOCOL_VERSIONS { + if supportedVersion == requestedVersion { + return requestedVersion; } } - return error("Unsupported request type"); + return LATEST_PROTOCOL_VERSION; + } + + private isolated function createErrorResponse(RequestId? id, int code, string message) returns http:BadRequest { + return { + body: { + jsonrpc: JSONRPC_VERSION, + id: id, + 'error: { + code: code, + message: message + } + } + }; } private isolated function executeOnListTools() returns ListToolsResult|error { @@ -211,7 +270,7 @@ DispatcherService dispatcherService = isolated service object { } else if mcpService is McpDeclarativeService { return check listToolsForRemoteFunctions(mcpService); } - return error("MCP Service is not attached"); + return error DispatcherError("MCP Service is not attached"); } } @@ -223,7 +282,7 @@ DispatcherService dispatcherService = isolated service object { } else if mcpService is McpDeclarativeService { return check callToolForRemoteFunctions(mcpService, params.cloneReadOnly()); } - return error("MCP Service is not attached"); + return error DispatcherError("MCP Service is not attached"); } } }; diff --git a/ballerina/error.bal b/ballerina/error.bal index 07374d2..bd75ca8 100644 --- a/ballerina/error.bal +++ b/ballerina/error.bal @@ -79,3 +79,6 @@ public type ListToolsError distinct ClientError; # Error for failures during tool execution operations. public type ToolCallError distinct ClientError; + +# Errors for failures occuring during server operations. +public type ServerError distinct Error; diff --git a/ballerina/main.bal b/ballerina/main.bal index 5564257..798c894 100644 --- a/ballerina/main.bal +++ b/ballerina/main.bal @@ -63,39 +63,21 @@ // } // } -// listener BasicListener basicListener = check new (9091, serverConfigs = { -// serverInfo: { -// name: "Basic MCP Server", -// version: "1.0.0" -// }, -// options: {capabilities: {}} -// }); - -// service /mcp on basicListener { -// @McpTool { -// description: "Add two numbers", -// schema: { -// 'type: "object", -// properties: { -// "a": {"type": "integer", "description": "First number"}, -// "b": {"type": "integer", "description": "Second number"} -// }, -// required: ["a", "b"] -// } -// } -// remote function add(int a, int b) returns int { -// return a + b; -// } -// } - -listener Listener mcpListener = check new (9090, serverInfo = {name: "", version: ""}); +listener Listener basicListener = check new (9091, serverInfo = {name: "Basic MCP Server", version: "1.0.0"}); -isolated service McpService /mcp on mcpListener { - remote isolated function onListTools() returns ListToolsResult|error { - return error("Not implemented"); +isolated service McpDeclarativeService /mcp on basicListener { + @McpTool { + description: "Add two numbers", + schema: { + 'type: "object", + properties: { + "a": {"type": "integer", "description": "First number"}, + "b": {"type": "integer", "description": "Second number"} + }, + required: ["a", "b"] + } } - - remote isolated function onCallTool(CallToolParams params) returns CallToolResult|error { - return error("Not implemented"); + remote function add(int a, int b) returns int { + return a + b; } } From 223172ae2d723838e79f3d2bd2a39f9ba5fbfebf Mon Sep 17 00:00:00 2001 From: Azeem Muzammil Date: Thu, 26 Jun 2025 12:30:19 +0530 Subject: [PATCH 07/31] Fix mcp service names --- ballerina/dispatcher_service.bal | 18 +-- ballerina/listener.bal | 4 +- ballerina/main.bal | 9 +- ballerina/native_listener_helper.bal | 8 +- ballerina/types.bal | 4 +- .../stdlib/mcp/plugin/McpSourceModifier.java | 140 +++++++++++------- .../stdlib/mcp/plugin/ModifierContext.java | 8 +- .../plugin/RemoteFunctionAnalysisTask.java | 51 +++---- .../io/ballerina/stdlib/mcp/plugin/Utils.java | 39 ++++- 9 files changed, 172 insertions(+), 109 deletions(-) diff --git a/ballerina/dispatcher_service.bal b/ballerina/dispatcher_service.bal index 0984fae..461988d 100644 --- a/ballerina/dispatcher_service.bal +++ b/ballerina/dispatcher_service.bal @@ -24,18 +24,18 @@ type DispatcherError distinct error; type DispatcherService distinct service object { *http:Service; - isolated function addServiceRef(McpService|McpDeclarativeService mcpService); + isolated function addServiceRef(Service|AdvancedService mcpService); isolated function removeServiceRef(); isolated function setServerConfigs(ServerConfiguration serverConfigs); }; DispatcherService dispatcherService = isolated service object { private ServerConfiguration? serverConfigs = (); - private McpService|McpDeclarativeService? mcpService = (); + private Service|AdvancedService? mcpService = (); private boolean isInitialized = false; private string? sessionId = (); - isolated function addServiceRef(McpService|McpDeclarativeService mcpService) { + isolated function addServiceRef(Service|AdvancedService mcpService) { lock { self.mcpService = mcpService; } @@ -264,10 +264,10 @@ DispatcherService dispatcherService = isolated service object { private isolated function executeOnListTools() returns ListToolsResult|error { lock { - McpService|McpDeclarativeService? mcpService = self.mcpService; - if mcpService is McpService { + Service|AdvancedService? mcpService = self.mcpService; + if mcpService is AdvancedService { return check invokeOnListTools(mcpService); - } else if mcpService is McpDeclarativeService { + } else if mcpService is Service { return check listToolsForRemoteFunctions(mcpService); } return error DispatcherError("MCP Service is not attached"); @@ -276,10 +276,10 @@ DispatcherService dispatcherService = isolated service object { private isolated function executeOnCallTool(CallToolParams params) returns CallToolResult|error { lock { - McpService|McpDeclarativeService? mcpService = self.mcpService; - if mcpService is McpService { + Service|AdvancedService? mcpService = self.mcpService; + if mcpService is AdvancedService { return check invokeOnCallTool(mcpService, params.cloneReadOnly()); - } else if mcpService is McpDeclarativeService { + } else if mcpService is Service { return check callToolForRemoteFunctions(mcpService, params.cloneReadOnly()); } return error DispatcherError("MCP Service is not attached"); diff --git a/ballerina/listener.bal b/ballerina/listener.bal index 6c12f80..f96c57a 100644 --- a/ballerina/listener.bal +++ b/ballerina/listener.bal @@ -65,7 +65,7 @@ public class Listener { # + mcpService - Service to attach. # + name - Path(s) to mount the service on (string or string array). # + return - error? if attachment fails. - public isolated function attach(McpService|McpDeclarativeService mcpService, string[]|string? name = ()) returns error? { + public isolated function attach(Service|AdvancedService mcpService, string[]|string? name = ()) returns error? { check self.httpListener.attach(self.dispatcherService, name); self.dispatcherService.addServiceRef(mcpService); } @@ -74,7 +74,7 @@ public class Listener { # # + mcpService - Service to detach. # + return - error? if detachment fails. - public isolated function detach(McpService|McpDeclarativeService mcpService) returns error? { + public isolated function detach(Service|AdvancedService mcpService) returns error? { check self.httpListener.detach(self.dispatcherService); self.dispatcherService.removeServiceRef(); } diff --git a/ballerina/main.bal b/ballerina/main.bal index 798c894..953154b 100644 --- a/ballerina/main.bal +++ b/ballerina/main.bal @@ -63,9 +63,9 @@ // } // } -listener Listener basicListener = check new (9091, serverInfo = {name: "Basic MCP Server", version: "1.0.0"}); +listener Listener basicListener = check new (9092, serverInfo = {name: "Basic MCP Server", version: "1.0.0"}); -isolated service McpDeclarativeService /mcp on basicListener { +isolated service Service /mcp on basicListener { @McpTool { description: "Add two numbers", schema: { @@ -80,4 +80,9 @@ isolated service McpDeclarativeService /mcp on basicListener { remote function add(int a, int b) returns int { return a + b; } + + @McpTool + remote function add1(int a, int b) returns int { + return a + b; + } } diff --git a/ballerina/native_listener_helper.bal b/ballerina/native_listener_helper.bal index a948ff9..614e1cf 100644 --- a/ballerina/native_listener_helper.bal +++ b/ballerina/native_listener_helper.bal @@ -1,17 +1,17 @@ import ballerina/jballerina.java; -isolated function invokeOnListTools(McpService 'service) returns ListToolsResult|error = @java:Method { +isolated function invokeOnListTools(AdvancedService 'service) returns ListToolsResult|error = @java:Method { 'class: "io.ballerina.stdlib.mcp.McpServiceMethodHelper" } external; -isolated function invokeOnCallTool(McpService 'service, CallToolParams params) returns CallToolResult|error = @java:Method { +isolated function invokeOnCallTool(AdvancedService 'service, CallToolParams params) returns CallToolResult|error = @java:Method { 'class: "io.ballerina.stdlib.mcp.McpServiceMethodHelper" } external; -isolated function listToolsForRemoteFunctions(McpDeclarativeService 'service, typedesc t = <>) returns t|error = @java:Method { +isolated function listToolsForRemoteFunctions(Service 'service, typedesc t = <>) returns t|error = @java:Method { 'class: "io.ballerina.stdlib.mcp.McpServiceMethodHelper" } external; -isolated function callToolForRemoteFunctions(McpDeclarativeService 'service, CallToolParams params, typedesc t = <>) returns t|error = @java:Method { +isolated function callToolForRemoteFunctions(Service 'service, CallToolParams params, typedesc t = <>) returns t|error = @java:Method { 'class: "io.ballerina.stdlib.mcp.McpServiceMethodHelper" } external; diff --git a/ballerina/types.bal b/ballerina/types.bal index 2d132e1..0ee5706 100644 --- a/ballerina/types.bal +++ b/ballerina/types.bal @@ -424,11 +424,11 @@ public type McpToolConfig record {| public annotation McpToolConfig McpTool on object function; # Defines a mcp service interface that handles incoming mcp requests. -public type McpService distinct isolated service object { +public type AdvancedService distinct isolated service object { remote isolated function onListTools() returns ListToolsResult|error; remote isolated function onCallTool(CallToolParams params) returns CallToolResult|error; }; -public type McpDeclarativeService distinct isolated service object { +public type Service distinct isolated service object { }; diff --git a/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/McpSourceModifier.java b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/McpSourceModifier.java index 3baef43..ed6ddb8 100644 --- a/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/McpSourceModifier.java +++ b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/McpSourceModifier.java @@ -18,8 +18,10 @@ package io.ballerina.stdlib.mcp.plugin; +import io.ballerina.compiler.api.SemanticModel; import io.ballerina.compiler.syntax.tree.AnnotationNode; import io.ballerina.compiler.syntax.tree.FunctionDefinitionNode; +import io.ballerina.compiler.syntax.tree.IdentifierToken; import io.ballerina.compiler.syntax.tree.MappingConstructorExpressionNode; import io.ballerina.compiler.syntax.tree.MetadataNode; import io.ballerina.compiler.syntax.tree.ModuleMemberDeclarationNode; @@ -42,14 +44,17 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Optional; +import static io.ballerina.compiler.syntax.tree.SyntaxKind.AT_TOKEN; import static io.ballerina.compiler.syntax.tree.SyntaxKind.CLOSE_BRACE_TOKEN; import static io.ballerina.compiler.syntax.tree.SyntaxKind.COLON_TOKEN; import static io.ballerina.compiler.syntax.tree.SyntaxKind.OBJECT_METHOD_DEFINITION; import static io.ballerina.compiler.syntax.tree.SyntaxKind.OPEN_BRACE_TOKEN; -import static io.ballerina.compiler.syntax.tree.SyntaxKind.QUALIFIED_NAME_REFERENCE; import static io.ballerina.compiler.syntax.tree.SyntaxKind.SERVICE_DECLARATION; -import static io.ballerina.stdlib.mcp.plugin.RemoteFunctionAnalysisTask.EMPTY_STRING; +import static io.ballerina.stdlib.mcp.plugin.Utils.MCP_PACKAGE_NAME; +import static io.ballerina.stdlib.mcp.plugin.Utils.TOOL_ANNOTATION_NAME; +import static io.ballerina.stdlib.mcp.plugin.Utils.getToolAnnotationNode; public class McpSourceModifier implements ModifierTask { private final Map modifierContextMap; @@ -68,63 +73,55 @@ public void modify(SourceModifierContext context) { private void modifyDocumentWithTools(SourceModifierContext context, DocumentId documentId, ModifierContext modifierContext) { Module module = context.currentPackage().module(documentId.moduleId()); + SemanticModel semanticModel = context.compilation().getSemanticModel(documentId.moduleId()); ModulePartNode rootNode = module.document(documentId).syntaxTree().rootNode(); - ModulePartNode updatedRoot = modifyModulePartRoot(rootNode, modifierContext, documentId); + ModulePartNode updatedRoot = modifyModulePartRoot(semanticModel, rootNode, modifierContext, documentId); updateDocument(context, module, documentId, updatedRoot); } - private ModulePartNode modifyModulePartRoot(ModulePartNode modulePartNode, + private ModulePartNode modifyModulePartRoot(SemanticModel semanticModel, ModulePartNode modulePartNode, ModifierContext modifierContext, DocumentId documentId) { - List modifiedMembers = getModifiedModuleMembers(modulePartNode.members(), - modifierContext, documentId); + List modifiedMembers = getModifiedModuleMembers(semanticModel, + modulePartNode.members(), modifierContext); return modulePartNode.modify().withMembers(NodeFactory.createNodeList(modifiedMembers)).apply(); } - private List getModifiedModuleMembers(NodeList members, - ModifierContext modifierContext, - DocumentId documentId) { - Map modifiedAnnotations = getModifiedAnnotations(modifierContext); + private List getModifiedModuleMembers(SemanticModel semanticModel, + NodeList members, + ModifierContext modifierContext) { + Map modifiedAnnotations = getModifiedAnnotations(modifierContext); List modifiedMembers = new ArrayList<>(); for (ModuleMemberDeclarationNode member : members) { - modifiedMembers.add(getModifiedModuleMember(member, modifiedAnnotations)); + modifiedMembers.add(getModifiedModuleMember(semanticModel, member, modifiedAnnotations)); } return modifiedMembers; } - private Map getModifiedAnnotations(ModifierContext modifierContext) { - Map updatedAnnotationMap = new HashMap<>(); - for (Map.Entry entry : modifierContext + private Map getModifiedAnnotations(ModifierContext modifierContext) { + Map updatedAnnotationMap = new HashMap<>(); + for (Map.Entry entry : modifierContext .getAnnotationConfigMap().entrySet()) { - updatedAnnotationMap.put(entry.getKey(), getModifiedAnnotation(entry.getKey(), entry.getValue())); + updatedAnnotationMap.put(entry.getKey(), getModifiedAnnotation(entry.getValue())); } return updatedAnnotationMap; } - private AnnotationNode getModifiedAnnotation(AnnotationNode targetNode, ToolAnnotationConfig config) { + private AnnotationNode getModifiedAnnotation(ToolAnnotationConfig config) { + Token atToken = NodeFactory.createToken(AT_TOKEN); + + Token modulePrefix = NodeFactory.createIdentifierToken(MCP_PACKAGE_NAME); + Token colonToken = NodeFactory.createToken(COLON_TOKEN); + IdentifierToken identifier = NodeFactory.createIdentifierToken(TOOL_ANNOTATION_NAME); + QualifiedNameReferenceNode annotationReferenceNode = + NodeFactory.createQualifiedNameReferenceNode(modulePrefix, colonToken, identifier); + String mappingConstructorExpression = generateConfigMappingConstructor(config); MappingConstructorExpressionNode mappingConstructorNode = (MappingConstructorExpressionNode) NodeParser .parseExpression(mappingConstructorExpression); - Node annotationReference = targetNode.annotReference(); - if (annotationReference.kind() == QUALIFIED_NAME_REFERENCE) { - QualifiedNameReferenceNode qualifiedNameReferenceNode = (QualifiedNameReferenceNode) annotationReference; - String identifier = qualifiedNameReferenceNode.identifier().text().replaceAll("\\R", EMPTY_STRING); - String modulePrefix = qualifiedNameReferenceNode.modulePrefix().text(); - annotationReference = NodeFactory.createQualifiedNameReferenceNode( - NodeFactory.createIdentifierToken(modulePrefix), - NodeFactory.createToken(COLON_TOKEN), - NodeFactory.createIdentifierToken(identifier) - ); - Token closeBraceTokenWithNewLine = NodeFactory.createToken( - CLOSE_BRACE_TOKEN, - NodeFactory.createEmptyMinutiaeList(), - NodeFactory.createMinutiaeList( - NodeFactory.createEndOfLineMinutiae(System.lineSeparator()))); - mappingConstructorNode = mappingConstructorNode.modify().withCloseBrace(closeBraceTokenWithNewLine).apply(); - } - return NodeFactory.createAnnotationNode(targetNode.atToken(), annotationReference, mappingConstructorNode); + return NodeFactory.createAnnotationNode(atToken, annotationReferenceNode, mappingConstructorNode); } private String generateConfigMappingConstructor(ToolAnnotationConfig config) { @@ -134,35 +131,56 @@ private String generateConfigMappingConstructor(ToolAnnotationConfig config) { private String generateConfigMappingConstructor(ToolAnnotationConfig config, String openBraceSource, String closeBraceSource) { - return openBraceSource + String.format("description:%s,schema:%s", - config.description() != null ? config.description().replaceAll("\\R", " ") : "", - config.schema()) + closeBraceSource; + StringBuilder sb = new StringBuilder(); + sb.append(openBraceSource); + String desc = config.description().replaceAll("\\R", " "); + sb.append("description:").append(desc).append(","); + sb.append("schema:").append(config.schema()); + sb.append(closeBraceSource); + return sb.toString(); } - private ModuleMemberDeclarationNode getModifiedModuleMember(ModuleMemberDeclarationNode member, - Map modifiedAnnotations - ) { + private ModuleMemberDeclarationNode getModifiedModuleMember( + SemanticModel semanticModel, + ModuleMemberDeclarationNode member, + Map modifiedAnnotations) { if (member.kind() == SERVICE_DECLARATION) { - return modifyServiceDeclaration((ServiceDeclarationNode) member, modifiedAnnotations); + return modifyServiceDeclaration(semanticModel, (ServiceDeclarationNode) member, modifiedAnnotations); } return member; } - private ModuleMemberDeclarationNode modifyServiceDeclaration(ServiceDeclarationNode classDefinitionNode, - Map modifiedAnnotations) { + private ModuleMemberDeclarationNode modifyServiceDeclaration( + SemanticModel semanticModel, + ServiceDeclarationNode classDefinitionNode, + Map modifiedAnnotations) { + NodeList members = classDefinitionNode.members(); ArrayList modifiedMembers = new ArrayList<>(); for (Node member : members) { if (member.kind() == OBJECT_METHOD_DEFINITION) { - FunctionDefinitionNode methodDeclarationNode = (FunctionDefinitionNode) member; - if (methodDeclarationNode.metadata().isPresent()) { - MetadataNode modifiedMetadata = modifyMetadata(methodDeclarationNode.metadata().get(), - modifiedAnnotations); - methodDeclarationNode = methodDeclarationNode.modify().withMetadata(modifiedMetadata).apply(); + FunctionDefinitionNode functionDefinitionNode = (FunctionDefinitionNode) member; + AnnotationNode modifiedAnnotationNode = modifiedAnnotations.get(functionDefinitionNode); + if (functionDefinitionNode.metadata().isPresent()) { + MetadataNode functionMetadata = functionDefinitionNode.metadata().get(); + Optional toolAnnotationNode = + getToolAnnotationNode(semanticModel, functionDefinitionNode); + if (toolAnnotationNode.isPresent()) { + MetadataNode modifiedMetadata = modifyMetadata(functionMetadata, toolAnnotationNode.get(), + modifiedAnnotationNode); + functionDefinitionNode = functionDefinitionNode.modify().withMetadata(modifiedMetadata).apply(); + } else { + functionDefinitionNode = functionDefinitionNode.modify() + .withMetadata(modifyWithToolAnnotation(functionMetadata, modifiedAnnotationNode)) + .apply(); + } + } else { + functionDefinitionNode = functionDefinitionNode.modify() + .withMetadata(createMetadata(modifiedAnnotationNode)).apply(); } - modifiedMembers.add(methodDeclarationNode); + modifiedMembers.add(functionDefinitionNode); } else { modifiedMembers.add(member); } @@ -170,15 +188,33 @@ private ModuleMemberDeclarationNode modifyServiceDeclaration(ServiceDeclarationN return classDefinitionNode.modify().withMembers(NodeFactory.createNodeList(modifiedMembers)).apply(); } - private MetadataNode modifyMetadata(MetadataNode metadata, - Map modifiedAnnotations) { + private MetadataNode modifyWithToolAnnotation(MetadataNode metadata, AnnotationNode annotationNode) { + List updatedAnnotations = new ArrayList<>(); + metadata.annotations().forEach(updatedAnnotations::add); + updatedAnnotations.add(annotationNode); + return metadata.modify() + .withAnnotations(NodeFactory.createNodeList(updatedAnnotations)) + .apply(); + } + + private MetadataNode modifyMetadata(MetadataNode metadata, AnnotationNode toolAnnotationNode, + AnnotationNode modifiedAnnotationNode) { List updatedAnnotations = new ArrayList<>(); for (AnnotationNode annotation : metadata.annotations()) { - updatedAnnotations.add(modifiedAnnotations.getOrDefault(annotation, annotation)); + if (annotation.equals(toolAnnotationNode)) { + updatedAnnotations.add(modifiedAnnotationNode); + } else { + updatedAnnotations.add(annotation); + } } return metadata.modify().withAnnotations(NodeFactory.createNodeList(updatedAnnotations)).apply(); } + private MetadataNode createMetadata(AnnotationNode annotationNode) { + NodeList annotationNodeList = NodeFactory.createNodeList(annotationNode); + return NodeFactory.createMetadataNode(null, annotationNodeList); + } + private void updateDocument(SourceModifierContext context, Module module, DocumentId documentId, ModulePartNode updatedRoot) { SyntaxTree syntaxTree = module.document(documentId).syntaxTree().modifyWith(updatedRoot); diff --git a/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/ModifierContext.java b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/ModifierContext.java index 4b2b222..f06f306 100644 --- a/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/ModifierContext.java +++ b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/ModifierContext.java @@ -18,19 +18,19 @@ package io.ballerina.stdlib.mcp.plugin; -import io.ballerina.compiler.syntax.tree.AnnotationNode; +import io.ballerina.compiler.syntax.tree.FunctionDefinitionNode; import java.util.HashMap; import java.util.Map; public class ModifierContext { - private final Map annotationConfigMap = new HashMap<>(); + private final Map annotationConfigMap = new HashMap<>(); - void add(AnnotationNode node, ToolAnnotationConfig config) { + void add(FunctionDefinitionNode node, ToolAnnotationConfig config) { annotationConfigMap.put(node, config); } - Map getAnnotationConfigMap() { + Map getAnnotationConfigMap() { return annotationConfigMap; } } diff --git a/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/RemoteFunctionAnalysisTask.java b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/RemoteFunctionAnalysisTask.java index 355b3a8..dc74c56 100644 --- a/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/RemoteFunctionAnalysisTask.java +++ b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/RemoteFunctionAnalysisTask.java @@ -18,7 +18,6 @@ package io.ballerina.stdlib.mcp.plugin; -import io.ballerina.compiler.api.symbols.AnnotationSymbol; import io.ballerina.compiler.api.symbols.FunctionSymbol; import io.ballerina.compiler.api.symbols.Symbol; import io.ballerina.compiler.api.symbols.SymbolKind; @@ -26,9 +25,7 @@ import io.ballerina.compiler.syntax.tree.ExpressionNode; import io.ballerina.compiler.syntax.tree.FunctionDefinitionNode; import io.ballerina.compiler.syntax.tree.MappingFieldNode; -import io.ballerina.compiler.syntax.tree.MetadataNode; import io.ballerina.compiler.syntax.tree.NodeFactory; -import io.ballerina.compiler.syntax.tree.NodeList; import io.ballerina.compiler.syntax.tree.NodeParser; import io.ballerina.compiler.syntax.tree.SeparatedNodeList; import io.ballerina.compiler.syntax.tree.SpecificFieldNode; @@ -47,6 +44,7 @@ import static io.ballerina.stdlib.mcp.plugin.ToolAnnotationConfig.DESCRIPTION_FIELD_NAME; import static io.ballerina.stdlib.mcp.plugin.ToolAnnotationConfig.SCHEMA_FIELD_NAME; +import static io.ballerina.stdlib.mcp.plugin.Utils.getToolAnnotationNode; import static io.ballerina.stdlib.mcp.plugin.diagnostics.CompilationDiagnostic.UNABLE_TO_GENERATE_SCHEMA_FOR_FUNCTION; public class RemoteFunctionAnalysisTask implements AnalysisTask { @@ -65,39 +63,32 @@ public void perform(SyntaxNodeAnalysisContext context) { this.context = context; FunctionDefinitionNode functionDefinitionNode = (FunctionDefinitionNode) context.node(); - Optional metadataNode = functionDefinitionNode.metadata(); - if (metadataNode.isEmpty()) { - return; - } - - NodeList annotationNodeList = metadataNode.get().annotations(); - Optional toolAnnotationNode = annotationNodeList.stream() - .filter(annotationNode -> - context.semanticModel().symbol(annotationNode) - .filter(symbol -> symbol.kind() == SymbolKind.ANNOTATION) - .filter(symbol -> Utils.isMcpToolAnnotation((AnnotationSymbol) symbol)) - .isPresent() - ) - .findFirst(); - if (toolAnnotationNode.isEmpty()) { - return; - } + AnnotationNode toolAnnotationNode = getToolAnnotationNode( + context.semanticModel(), functionDefinitionNode + ).orElse(null); - ToolAnnotationConfig config = createAnnotationConfig(toolAnnotationNode.get(), functionDefinitionNode); - addToModifierContext(context, toolAnnotationNode.get(), config); + ToolAnnotationConfig config = createAnnotationConfig(functionDefinitionNode, toolAnnotationNode); + addToModifierContext(context, functionDefinitionNode, config); } - private ToolAnnotationConfig createAnnotationConfig(AnnotationNode annotationNode, - FunctionDefinitionNode functionDefinitionNode) { + private ToolAnnotationConfig createAnnotationConfig(FunctionDefinitionNode functionDefinitionNode, + AnnotationNode annotationNode) { @SuppressWarnings("OptionalGetWithoutIsPresent") // is present already check in perform method FunctionSymbol functionSymbol = getFunctionSymbol(functionDefinitionNode).get(); String functionName = functionSymbol.getName().orElse("unknownFunction"); + String description = Utils.addDoubleQuotes( + Utils.escapeDoubleQuotes( + Objects.requireNonNullElse(Utils.getDescription(functionSymbol), functionName))); + if (annotationNode == null) { + String schema = getParameterSchema(functionSymbol, functionDefinitionNode.location()); + return new ToolAnnotationConfig(description, schema); + } SeparatedNodeList fields = annotationNode.annotValue().isEmpty() ? NodeFactory.createSeparatedNodeList() : annotationNode.annotValue().get().fields(); Map fieldValues = extractFieldValues(fields); - String description = fieldValues.containsKey(DESCRIPTION_FIELD_NAME) - ? fieldValues.get(DESCRIPTION_FIELD_NAME).toSourceCode() - : Utils.addDoubleQuotes(Objects.requireNonNullElse(Utils.getDescription(functionSymbol), functionName)); + if (fieldValues.containsKey(DESCRIPTION_FIELD_NAME)) { + description = fieldValues.get(DESCRIPTION_FIELD_NAME).toSourceCode(); + } String parameters = fieldValues.containsKey(SCHEMA_FIELD_NAME) ? fieldValues.get(SCHEMA_FIELD_NAME).toSourceCode() : getParameterSchema(functionSymbol, functionDefinitionNode.location()); @@ -137,9 +128,9 @@ private void reportDiagnostic(Diagnostic diagnostic) { this.context.reportDiagnostic(diagnostic); } - private void addToModifierContext(SyntaxNodeAnalysisContext context, AnnotationNode annotationNode, - ToolAnnotationConfig functionDefinitionNode) { + private void addToModifierContext(SyntaxNodeAnalysisContext context, FunctionDefinitionNode functionDefinitionNode, + ToolAnnotationConfig toolAnnotationConfig) { this.modifierContextMap.computeIfAbsent(context.documentId(), document -> new ModifierContext()) - .add(annotationNode, functionDefinitionNode); + .add(functionDefinitionNode, toolAnnotationConfig); } } diff --git a/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/Utils.java b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/Utils.java index b4fc750..b1ffc9d 100644 --- a/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/Utils.java +++ b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/Utils.java @@ -18,18 +18,26 @@ package io.ballerina.stdlib.mcp.plugin; +import io.ballerina.compiler.api.SemanticModel; import io.ballerina.compiler.api.symbols.AnnotationSymbol; import io.ballerina.compiler.api.symbols.Documentable; import io.ballerina.compiler.api.symbols.FunctionSymbol; import io.ballerina.compiler.api.symbols.Symbol; +import io.ballerina.compiler.api.symbols.SymbolKind; +import io.ballerina.compiler.syntax.tree.AnnotationNode; +import io.ballerina.compiler.syntax.tree.FunctionDefinitionNode; +import io.ballerina.compiler.syntax.tree.MetadataNode; +import io.ballerina.compiler.syntax.tree.NodeList; + +import java.util.Optional; /** * Util class for the compiler plugin. */ public class Utils { public static final String BALLERINA_ORG = "ballerina"; - private static final String TOOL_ANNOTATION_NAME = "McpTool"; - private static final String MCP_PACKAGE_NAME = "mcp"; + public static final String TOOL_ANNOTATION_NAME = "McpTool"; + public static final String MCP_PACKAGE_NAME = "mcp"; private Utils() { } @@ -63,7 +71,30 @@ public static String getDescription(Documentable documentable) { return documentable.documentation().get().description().get(); } - public static String addDoubleQuotes(String functionName) { - return "\"" + functionName + "\""; + public static String escapeDoubleQuotes(String input) { + return input.replace("\"", "\\\""); + } + + + public static String addDoubleQuotes(String input) { + return "\"" + input + "\""; + } + + public static Optional getToolAnnotationNode(SemanticModel semanticModel, + FunctionDefinitionNode functionDefinitionNode) { + Optional metadataNode = functionDefinitionNode.metadata(); + if (metadataNode.isEmpty()) { + return Optional.empty(); + } + + NodeList annotationNodes = metadataNode.get().annotations(); + return annotationNodes.stream() + .filter(annotationNode -> + semanticModel.symbol(annotationNode) + .filter(symbol -> symbol.kind() == SymbolKind.ANNOTATION) + .filter(symbol -> Utils.isMcpToolAnnotation((AnnotationSymbol) symbol)) + .isPresent() + ) + .findFirst(); } } From 27f499abcda835f3fe536bd6f4f59840d5870db3 Mon Sep 17 00:00:00 2001 From: Azeem Muzammil Date: Thu, 26 Jun 2025 12:32:20 +0530 Subject: [PATCH 08/31] [Automated] Update the native jar versions From f12d383a9bf3d069fa0cf18201496ad93762fd82 Mon Sep 17 00:00:00 2001 From: Azeem Muzammil Date: Thu, 26 Jun 2025 12:34:58 +0530 Subject: [PATCH 09/31] Fix compilation error --- ballerina/CompilerPlugin.toml | 2 +- ballerina/listener.bal | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/ballerina/CompilerPlugin.toml b/ballerina/CompilerPlugin.toml index f770d36..fd94c5c 100644 --- a/ballerina/CompilerPlugin.toml +++ b/ballerina/CompilerPlugin.toml @@ -3,7 +3,7 @@ id = "mcp-compiler-plugin" class = "io.ballerina.stdlib.mcp.plugin.McpCompilerPlugin" [[dependency]] -path = "../compiler-plugin/build/libs/mcp-compiler-plugin-0.4.1-SNAPSHOT.jar" +path = "../compiler-plugin/build/libs/mcp-compiler-plugin-0.4.2-SNAPSHOT.jar" [[dependency]] path = "../compiler-plugin/build/libs/ballerina-to-openapi-2.3.0.jar" diff --git a/ballerina/listener.bal b/ballerina/listener.bal index f96c57a..11e3b10 100644 --- a/ballerina/listener.bal +++ b/ballerina/listener.bal @@ -18,11 +18,12 @@ import ballerina/http; # Represents the options for configuring an MCP server. public type ServerOptions record {| - *ProtocolOptions; # Capabilities to advertise as being supported by this server. ServerCapabilities capabilities?; # Optional instructions describing how to use the server and its features. string instructions?; + # Whether to enforce strict capabilities compliance. + boolean enforceStrictCapabilities?; |}; # Configuration options for initializing an MCP listener. From 877e7d1a5caec2844e36ca60cc440ec563922a55 Mon Sep 17 00:00:00 2001 From: Azeem Muzammil Date: Thu, 26 Jun 2025 12:44:42 +0530 Subject: [PATCH 10/31] Fix naming issue --- ballerina/main.bal | 88 ------------------- ballerina/native_listener_helper.bal | 16 ++++ .../stdlib/mcp/plugin/McpCompilerPlugin.java | 1 - .../diagnostics/CompilationDiagnostic.java | 2 +- .../plugin/diagnostics/DiagnosticCode.java | 4 +- .../plugin/diagnostics/DiagnosticMessage.java | 4 +- 6 files changed, 21 insertions(+), 94 deletions(-) delete mode 100644 ballerina/main.bal diff --git a/ballerina/main.bal b/ballerina/main.bal deleted file mode 100644 index 953154b..0000000 --- a/ballerina/main.bal +++ /dev/null @@ -1,88 +0,0 @@ - -// listener Listener mcpListener = check new (9090, serverConfigs = { -// serverInfo: { -// name: "MCP Server", -// version: "1.0.0" -// }, -// options: {capabilities: {}} -// }); - -// service /mcp on mcpListener { -// remote isolated function onListTools() returns ListToolsResult|error { -// return { -// tools: [ -// { -// name: "single-greet", -// description: "Greet the user once", -// inputSchema: { -// 'type: "object", -// properties: { -// "name": {"type": "string", "description": "Name to greet"} -// }, -// required: ["name"] -// } -// }, -// { -// name: "multi-greet", -// description: "Greet the user multiple times with delay in between.", -// inputSchema: { -// 'type: "object", -// properties: { -// "name": {"type": "string", "description": "Name to greet"} -// }, -// required: ["name"] -// } -// } -// ] -// }; -// } - -// remote isolated function onCallTool(CallToolParams params) returns CallToolResult|error { -// string name = check (params.arguments["name"]).cloneWithType(); -// if params.name == "single-greet" { -// // Note: Can do any external function calls here, -// TextContent textContent = { -// 'type: "text", -// text: string `Hey ${name}! Welcome to itsuki's world!` -// }; -// return { -// content: [textContent] -// }; -// } else if params.name == "multi-greet" { -// // Note: Can do any external function calls here, -// TextContent textContent = { -// 'type: "text", -// text: string `Hey ${name}! Hope you enjoy your day!` -// }; -// return { -// content: [textContent] -// }; -// } else { -// return error("Unknown tool: " + params.name); -// } -// } -// } - -listener Listener basicListener = check new (9092, serverInfo = {name: "Basic MCP Server", version: "1.0.0"}); - -isolated service Service /mcp on basicListener { - @McpTool { - description: "Add two numbers", - schema: { - 'type: "object", - properties: { - "a": {"type": "integer", "description": "First number"}, - "b": {"type": "integer", "description": "Second number"} - }, - required: ["a", "b"] - } - } - remote function add(int a, int b) returns int { - return a + b; - } - - @McpTool - remote function add1(int a, int b) returns int { - return a + b; - } -} diff --git a/ballerina/native_listener_helper.bal b/ballerina/native_listener_helper.bal index 614e1cf..1ab9d78 100644 --- a/ballerina/native_listener_helper.bal +++ b/ballerina/native_listener_helper.bal @@ -1,3 +1,19 @@ +// Copyright (c) 2025 WSO2 LLC (http://www.wso2.com). +// +// WSO2 LLC. licenses this file to you under the Apache License, +// Version 2.0 (the "License"); you may not use this file except +// in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + import ballerina/jballerina.java; isolated function invokeOnListTools(AdvancedService 'service) returns ListToolsResult|error = @java:Method { diff --git a/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/McpCompilerPlugin.java b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/McpCompilerPlugin.java index 538c212..723215b 100644 --- a/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/McpCompilerPlugin.java +++ b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/McpCompilerPlugin.java @@ -22,7 +22,6 @@ import io.ballerina.projects.plugins.CompilerPluginContext; public class McpCompilerPlugin extends CompilerPlugin { - @Override public void init(CompilerPluginContext compilerPluginContext) { compilerPluginContext.addCodeModifier(new McpCodeModifier()); diff --git a/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/diagnostics/CompilationDiagnostic.java b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/diagnostics/CompilationDiagnostic.java index 5b172bd..1fe6701 100644 --- a/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/diagnostics/CompilationDiagnostic.java +++ b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/diagnostics/CompilationDiagnostic.java @@ -30,7 +30,7 @@ * Compilation errors in the Ballerina AI package. */ public enum CompilationDiagnostic { - UNABLE_TO_GENERATE_SCHEMA_FOR_FUNCTION(DiagnosticMessage.ERROR_101, DiagnosticCode.AI_101, ERROR); + UNABLE_TO_GENERATE_SCHEMA_FOR_FUNCTION(DiagnosticMessage.ERROR_101, DiagnosticCode.MCP_101, ERROR); private final String diagnostic; private final String diagnosticCode; diff --git a/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/diagnostics/DiagnosticCode.java b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/diagnostics/DiagnosticCode.java index a9b8cc5..9635a96 100644 --- a/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/diagnostics/DiagnosticCode.java +++ b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/diagnostics/DiagnosticCode.java @@ -19,8 +19,8 @@ package io.ballerina.stdlib.mcp.plugin.diagnostics; /** - * Compilation error codes used in Ballerina AI package compiler plugin. + * Compilation error codes used in Ballerina mcp package compiler plugin. */ public enum DiagnosticCode { - AI_101 + MCP_101 } diff --git a/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/diagnostics/DiagnosticMessage.java b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/diagnostics/DiagnosticMessage.java index 61187f5..7a0fc5e 100644 --- a/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/diagnostics/DiagnosticMessage.java +++ b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/diagnostics/DiagnosticMessage.java @@ -19,11 +19,11 @@ package io.ballerina.stdlib.mcp.plugin.diagnostics; /** - * Compilation error messages used in Ballerina AI package compiler plugin. + * Compilation error messages used in Ballerina mcp package compiler plugin. */ public enum DiagnosticMessage { ERROR_101("failed to generate the parameter schema definition for the function ''{0}''." + - " Specify the parameter schema manually using the `@ai:AgentTool` annotation's parameter field."); + " Specify the parameter schema manually using the `@mcp:McpTool` annotation's parameter field."); private final String message; From 42b72db62439e627eb9cd704452c1c9b5cb909d6 Mon Sep 17 00:00:00 2001 From: Azeem Muzammil Date: Thu, 26 Jun 2025 12:52:56 +0530 Subject: [PATCH 11/31] Fix function symbol nullity --- .../mcp/plugin/RemoteFunctionAnalysisTask.java | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/RemoteFunctionAnalysisTask.java b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/RemoteFunctionAnalysisTask.java index dc74c56..651add2 100644 --- a/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/RemoteFunctionAnalysisTask.java +++ b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/RemoteFunctionAnalysisTask.java @@ -26,6 +26,7 @@ import io.ballerina.compiler.syntax.tree.FunctionDefinitionNode; import io.ballerina.compiler.syntax.tree.MappingFieldNode; import io.ballerina.compiler.syntax.tree.NodeFactory; +import io.ballerina.compiler.syntax.tree.NodeLocation; import io.ballerina.compiler.syntax.tree.NodeParser; import io.ballerina.compiler.syntax.tree.SeparatedNodeList; import io.ballerina.compiler.syntax.tree.SpecificFieldNode; @@ -67,20 +68,25 @@ public void perform(SyntaxNodeAnalysisContext context) { context.semanticModel(), functionDefinitionNode ).orElse(null); - ToolAnnotationConfig config = createAnnotationConfig(functionDefinitionNode, toolAnnotationNode); + NodeLocation functionNodeLocation = functionDefinitionNode.location(); + Optional functionSymbol = getFunctionSymbol(functionDefinitionNode); + if (functionSymbol.isEmpty()) { + return; + } + ToolAnnotationConfig config = createAnnotationConfig(functionSymbol.get(), functionNodeLocation, + toolAnnotationNode); addToModifierContext(context, functionDefinitionNode, config); } - private ToolAnnotationConfig createAnnotationConfig(FunctionDefinitionNode functionDefinitionNode, + private ToolAnnotationConfig createAnnotationConfig(FunctionSymbol functionSymbol, + NodeLocation functionNodeLocation, AnnotationNode annotationNode) { - @SuppressWarnings("OptionalGetWithoutIsPresent") // is present already check in perform method - FunctionSymbol functionSymbol = getFunctionSymbol(functionDefinitionNode).get(); String functionName = functionSymbol.getName().orElse("unknownFunction"); String description = Utils.addDoubleQuotes( Utils.escapeDoubleQuotes( Objects.requireNonNullElse(Utils.getDescription(functionSymbol), functionName))); if (annotationNode == null) { - String schema = getParameterSchema(functionSymbol, functionDefinitionNode.location()); + String schema = getParameterSchema(functionSymbol, functionNodeLocation); return new ToolAnnotationConfig(description, schema); } SeparatedNodeList fields = annotationNode.annotValue().isEmpty() ? @@ -91,7 +97,7 @@ private ToolAnnotationConfig createAnnotationConfig(FunctionDefinitionNode funct } String parameters = fieldValues.containsKey(SCHEMA_FIELD_NAME) ? fieldValues.get(SCHEMA_FIELD_NAME).toSourceCode() - : getParameterSchema(functionSymbol, functionDefinitionNode.location()); + : getParameterSchema(functionSymbol, functionNodeLocation); return new ToolAnnotationConfig(description, parameters); } From 575094901a16e490959ec9701877688dabc77cfc Mon Sep 17 00:00:00 2001 From: Azeem Muzammil Date: Thu, 26 Jun 2025 19:23:24 +0530 Subject: [PATCH 12/31] Fix error codes --- ballerina/dispatcher_service.bal | 106 +++++++++++++++++++++---------- ballerina/error.bal | 5 +- ballerina/types.bal | 4 ++ 3 files changed, 79 insertions(+), 36 deletions(-) diff --git a/ballerina/dispatcher_service.bal b/ballerina/dispatcher_service.bal index 461988d..88cd557 100644 --- a/ballerina/dispatcher_service.bal +++ b/ballerina/dispatcher_service.bal @@ -17,9 +17,6 @@ import ballerina/http; import ballerina/uuid; -# Custom error type for dispatcher service operations. -type DispatcherError distinct error; - # Represents the dispatcher service type definition. type DispatcherService distinct service object { *http:Service; @@ -54,10 +51,10 @@ DispatcherService dispatcherService = isolated service object { } isolated resource function post .(@http:Payload JsonRpcMessage request, http:Headers headers) - returns http:BadRequest|http:Accepted|http:Ok|error { + returns http:BadRequest|http:NotAcceptable|http:UnsupportedMediaType|http:Accepted|http:Ok|ServerError { - http:BadRequest? headerValidationError = self.validateHeaders(headers); - if headerValidationError is http:BadRequest { + http:NotAcceptable|http:UnsupportedMediaType? headerValidationError = self.validateHeaders(headers); + if !(headerValidationError is ()) { return headerValidationError; } @@ -69,39 +66,56 @@ DispatcherService dispatcherService = isolated service object { return self.processJsonRpcNotification(request.cloneReadOnly()); } - return self.createErrorResponse(null, INVALID_REQUEST, "Unsupported request type"); + JsonRpcError jsonRpcError = self.createErrorResponse(null, INVALID_REQUEST, "Unsupported request type"); + return { + body: jsonRpcError.cloneReadOnly() + }; } } - private isolated function validateHeaders(http:Headers headers) returns http:BadRequest? { + private isolated function validateHeaders(http:Headers headers) + returns http:NotAcceptable|http:UnsupportedMediaType? { + // Validate Accept header string|http:HeaderNotFoundError acceptHeader = headers.getHeader(ACCEPT_HEADER); if acceptHeader is http:HeaderNotFoundError { - return self.createErrorResponse(1, -32000, + JsonRpcError jsonRpcError = self.createErrorResponse(1, NOT_ACCEPTABLE, "Not Acceptable: Client must accept both application/json and text/event-stream"); + return { + body: jsonRpcError + }; } if !acceptHeader.includes(CONTENT_TYPE_JSON) || !acceptHeader.includes(CONTENT_TYPE_SSE) { - return self.createErrorResponse(1, -32000, + JsonRpcError jsonRpcError = self.createErrorResponse(1, NOT_ACCEPTABLE, "Not Acceptable: Client must accept both application/json and text/event-stream"); + return { + body: jsonRpcError + }; } // Validate Content-Type header string|http:HeaderNotFoundError contentTypeHeader = headers.getHeader(CONTENT_TYPE_HEADER); if contentTypeHeader is http:HeaderNotFoundError { - return self.createErrorResponse(null, -32000, + JsonRpcError jsonRpcError = self.createErrorResponse(null, UNSUPPORTED_MEDIA_TYPE, "Unsupported Media Type: Content-Type must be application/json"); + return { + body: jsonRpcError + }; } if !contentTypeHeader.includes(CONTENT_TYPE_JSON) { - return self.createErrorResponse(null, -32000, + JsonRpcError jsonRpcError = self.createErrorResponse(null, UNSUPPORTED_MEDIA_TYPE, "Unsupported Media Type: Content-Type must be application/json"); + return { + body: jsonRpcError + }; } return (); } - private isolated function processJsonRpcRequest(JsonRpcRequest request) returns http:BadRequest|http:Ok|error { + private isolated function processJsonRpcRequest(JsonRpcRequest request) returns http:BadRequest|http:Ok { match request.method { "initialize" => { return self.handleInitializeRequest(request); @@ -113,7 +127,10 @@ DispatcherService dispatcherService = isolated service object { return self.handleCallToolRequest(request); } _ => { - return self.createErrorResponse(request.id, METHOD_NOT_FOUND, "Method not found"); + JsonRpcError jsonRpcError = self.createErrorResponse(request.id, METHOD_NOT_FOUND, "Method not found"); + return { + body: jsonRpcError + }; } } } @@ -141,16 +158,22 @@ DispatcherService dispatcherService = isolated service object { JsonRpcRequest {jsonrpc, id, ...request} = jsonRpcRequest; InitializeRequest|error initRequest = request.cloneWithType(InitializeRequest); if initRequest is error { - return self.createErrorResponse(id, INVALID_REQUEST, + JsonRpcError jsonRpcError = self.createErrorResponse(id, INVALID_REQUEST, string `Invalid request: ${initRequest.message()}`); + return { + body: jsonRpcError + }; } lock { // If it's a server with session management and the session ID is already set we should reject the request // to avoid re-initialization. if self.isInitialized && self.sessionId != () { - return self.createErrorResponse(id, INVALID_REQUEST, - "Invalid Request: Only one initialization request is allowed"); + JsonRpcError jsonRpcError = self.createErrorResponse(id, INVALID_REQUEST, + "Invalid Request: Only one initialization request is allowed"); + return { + body: jsonRpcError.cloneReadOnly() + }; } self.isInitialized = true; @@ -183,14 +206,20 @@ DispatcherService dispatcherService = isolated service object { lock { // Check if initialized if !self.isInitialized { - return self.createErrorResponse(request.id, INVALID_REQUEST, - "Client must be initialized before making requests"); + JsonRpcError jsonRpcError = self.createErrorResponse(request.id, INVALID_REQUEST, + "Client must be initialized before making requests"); + return { + body: jsonRpcError.cloneReadOnly() + }; } ListToolsResult|error listToolsResult = self.executeOnListTools(); if listToolsResult is error { - return self.createErrorResponse(request.id, INTERNAL_ERROR, - string `Failed to list tools: ${listToolsResult.message()}`); + JsonRpcError jsonRpcError = self.createErrorResponse(request.id, INTERNAL_ERROR, + string `Failed to list tools: ${listToolsResult.message()}`); + return { + body: jsonRpcError.cloneReadOnly() + }; } return { @@ -210,21 +239,30 @@ DispatcherService dispatcherService = isolated service object { lock { // Check if initialized if !self.isInitialized { - return self.createErrorResponse(request.id, INVALID_REQUEST, - "Client must be initialized before making requests"); + JsonRpcError jsonRpcError = self.createErrorResponse(request.id, INVALID_REQUEST, + "Client must be initialized before making requests"); + return { + body: jsonRpcError.cloneReadOnly() + }; } // Extract and validate parameters CallToolParams|error params = request.cloneReadOnly().params.ensureType(CallToolParams); if params is error { - return self.createErrorResponse(request.id, INVALID_PARAMS, - string `Invalid parameters: ${params.message()}`); + JsonRpcError jsonRpcError = self.createErrorResponse(request.id, INVALID_PARAMS, + string `Invalid parameters: ${params.message()}`); + return { + body: jsonRpcError.cloneReadOnly() + }; } CallToolResult|error callToolResult = self.executeOnCallTool(params); if callToolResult is error { - return self.createErrorResponse(request.id, INTERNAL_ERROR, - string `Failed to call tool '${params.name}': ${callToolResult.message()}`); + JsonRpcError jsonRpcError = self.createErrorResponse(request.id, INTERNAL_ERROR, + string `Failed to call tool '${params.name}': ${callToolResult.message()}`); + return { + body: jsonRpcError.cloneReadOnly() + }; } return { @@ -249,15 +287,13 @@ DispatcherService dispatcherService = isolated service object { return LATEST_PROTOCOL_VERSION; } - private isolated function createErrorResponse(RequestId? id, int code, string message) returns http:BadRequest { + private isolated function createErrorResponse(RequestId? id, int code, string message) returns JsonRpcError { return { - body: { - jsonrpc: JSONRPC_VERSION, - id: id, - 'error: { - code: code, - message: message - } + jsonrpc: JSONRPC_VERSION, + id: id, + 'error: { + code: code, + message: message } }; } diff --git a/ballerina/error.bal b/ballerina/error.bal index bd75ca8..6cb0290 100644 --- a/ballerina/error.bal +++ b/ballerina/error.bal @@ -80,5 +80,8 @@ public type ListToolsError distinct ClientError; # Error for failures during tool execution operations. public type ToolCallError distinct ClientError; -# Errors for failures occuring during server operations. +# Errors for failures occurring during server operations. public type ServerError distinct Error; + +# Custom error type for dispatcher service operations. +public type DispatcherError distinct ServerError; diff --git a/ballerina/types.bal b/ballerina/types.bal index 0ee5706..ce79376 100644 --- a/ballerina/types.bal +++ b/ballerina/types.bal @@ -106,6 +106,10 @@ public const METHOD_NOT_FOUND = -32601; public const INVALID_PARAMS = -32602; public const INTERNAL_ERROR = -32603; +// Library-defined error codes +public const NOT_ACCEPTABLE = -32001; +public const UNSUPPORTED_MEDIA_TYPE = -32002; + # A response to a request that indicates an error occurred. public type JsonRpcError record { # The JSON-RPC protocol version From 5aa1b267432150ac94f2371ce2da1aaf4ee22868 Mon Sep 17 00:00:00 2001 From: Azeem Muzammil Date: Mon, 30 Jun 2025 11:07:14 +0530 Subject: [PATCH 13/31] Remove unnecessary records --- ballerina/types.bal | 20 +------------------- 1 file changed, 1 insertion(+), 19 deletions(-) diff --git a/ballerina/types.bal b/ballerina/types.bal index ce79376..387dedc 100644 --- a/ballerina/types.bal +++ b/ballerina/types.bal @@ -127,24 +127,6 @@ public type JsonRpcError record { } 'error; }; -# A response that indicates success but carries no data. -public type EmptyResult Result; - -# This notification can be sent by either side to indicate that it is cancelling a previously-issued request. -public type CancelledNotification record {| - *Notification; - # The method name for this notification - NOTIFICATION_CANCELLED method; - # The parameters for the cancellation notification - record {| - # The ID of the request to cancel. - # This MUST correspond to the ID of a request previously issued in the same direction. - RequestId requestId; - # An optional string describing the reason for the cancellation. This MAY be logged or presented to the user. - string? reason = (); - |} params; -|}; - # This request is sent from the client to the server when it first connects, asking it to begin initialization. type InitializeRequest record {| *Request; @@ -414,7 +396,7 @@ public type AudioContent record { }; # Represents a result sent from the server to the client. -public type ServerResult InitializeResult|CallToolResult|ListToolsResult|EmptyResult; +public type ServerResult InitializeResult|CallToolResult|ListToolsResult; # Represents a tool configuration that can be used to define tools available in the MCP service. public type McpToolConfig record {| From 98f5de5c71142d3ba24173e28c12cb61def1f8eb Mon Sep 17 00:00:00 2001 From: Azeem Muzammil Date: Mon, 30 Jun 2025 11:54:50 +0530 Subject: [PATCH 14/31] Address review comments --- ballerina/dispatcher_service.bal | 183 +++++++++--------- ballerina/error.bal | 2 +- ballerina/listener.bal | 43 ++-- ballerina/native_listener_helper.bal | 11 +- ballerina/types.bal | 26 ++- .../diagnostics/CompilationDiagnostic.java | 2 +- .../stdlib/mcp/McpServiceMethodHelper.java | 59 +++--- .../io/ballerina/stdlib/mcp/ModuleUtils.java | 8 + .../stdlib/mcp/SseEventStreamHelper.java | 10 +- 9 files changed, 190 insertions(+), 154 deletions(-) diff --git a/ballerina/dispatcher_service.bal b/ballerina/dispatcher_service.bal index 88cd557..ffbbc57 100644 --- a/ballerina/dispatcher_service.bal +++ b/ballerina/dispatcher_service.bal @@ -52,34 +52,30 @@ DispatcherService dispatcherService = isolated service object { isolated resource function post .(@http:Payload JsonRpcMessage request, http:Headers headers) returns http:BadRequest|http:NotAcceptable|http:UnsupportedMediaType|http:Accepted|http:Ok|ServerError { - http:NotAcceptable|http:UnsupportedMediaType? headerValidationError = self.validateHeaders(headers); - if !(headerValidationError is ()) { + if headerValidationError !is () { return headerValidationError; } - lock { - if request is JsonRpcRequest { - return self.processJsonRpcRequest(request.cloneReadOnly()); - } - else if request is JsonRpcNotification { - return self.processJsonRpcNotification(request.cloneReadOnly()); - } - - JsonRpcError jsonRpcError = self.createErrorResponse(null, INVALID_REQUEST, "Unsupported request type"); - return { - body: jsonRpcError.cloneReadOnly() - }; + if request is JsonRpcRequest { + return self.processJsonRpcRequest(request); } + if request is JsonRpcNotification { + return self.processJsonRpcNotification(request); + } + + JsonRpcError & readonly jsonRpcError = self.createJsonRpcError(INVALID_REQUEST, "Unsupported request type"); + return { + body: jsonRpcError + }; } private isolated function validateHeaders(http:Headers headers) returns http:NotAcceptable|http:UnsupportedMediaType? { - // Validate Accept header string|http:HeaderNotFoundError acceptHeader = headers.getHeader(ACCEPT_HEADER); if acceptHeader is http:HeaderNotFoundError { - JsonRpcError jsonRpcError = self.createErrorResponse(1, NOT_ACCEPTABLE, + JsonRpcError jsonRpcError = self.createJsonRpcError(NOT_ACCEPTABLE, "Not Acceptable: Client must accept both application/json and text/event-stream"); return { body: jsonRpcError @@ -87,7 +83,7 @@ DispatcherService dispatcherService = isolated service object { } if !acceptHeader.includes(CONTENT_TYPE_JSON) || !acceptHeader.includes(CONTENT_TYPE_SSE) { - JsonRpcError jsonRpcError = self.createErrorResponse(1, NOT_ACCEPTABLE, + JsonRpcError jsonRpcError = self.createJsonRpcError(NOT_ACCEPTABLE, "Not Acceptable: Client must accept both application/json and text/event-stream"); return { body: jsonRpcError @@ -97,7 +93,7 @@ DispatcherService dispatcherService = isolated service object { // Validate Content-Type header string|http:HeaderNotFoundError contentTypeHeader = headers.getHeader(CONTENT_TYPE_HEADER); if contentTypeHeader is http:HeaderNotFoundError { - JsonRpcError jsonRpcError = self.createErrorResponse(null, UNSUPPORTED_MEDIA_TYPE, + JsonRpcError jsonRpcError = self.createJsonRpcError(UNSUPPORTED_MEDIA_TYPE, "Unsupported Media Type: Content-Type must be application/json"); return { body: jsonRpcError @@ -105,29 +101,29 @@ DispatcherService dispatcherService = isolated service object { } if !contentTypeHeader.includes(CONTENT_TYPE_JSON) { - JsonRpcError jsonRpcError = self.createErrorResponse(null, UNSUPPORTED_MEDIA_TYPE, + JsonRpcError jsonRpcError = self.createJsonRpcError(UNSUPPORTED_MEDIA_TYPE, "Unsupported Media Type: Content-Type must be application/json"); return { body: jsonRpcError }; } - return (); + return; } private isolated function processJsonRpcRequest(JsonRpcRequest request) returns http:BadRequest|http:Ok { match request.method { - "initialize" => { + REQUEST_INITIALIZE => { return self.handleInitializeRequest(request); } - "tools/list" => { + REQUEST_LIST_TOOLS => { return self.handleListToolsRequest(request); } - "tools/call" => { + REQUEST_CALL_TOOL => { return self.handleCallToolRequest(request); } _ => { - JsonRpcError jsonRpcError = self.createErrorResponse(request.id, METHOD_NOT_FOUND, "Method not found"); + JsonRpcError jsonRpcError = self.createJsonRpcError(METHOD_NOT_FOUND, "Method not found", request.id); return { body: jsonRpcError }; @@ -136,30 +132,27 @@ DispatcherService dispatcherService = isolated service object { } private isolated function processJsonRpcNotification(JsonRpcNotification notification) returns http:Accepted|http:BadRequest { - match notification.method { - "notifications/initialized" => { - return http:ACCEPTED; - } - _ => { - return { - body: { - jsonrpc: JSONRPC_VERSION, - 'error: { - code: METHOD_NOT_FOUND, - message: "Unknown notification method" - } - } - }; - } + if notification.method == "notifications/initialized" { + return http:ACCEPTED; } + + return { + body: { + jsonrpc: JSONRPC_VERSION, + 'error: { + code: METHOD_NOT_FOUND, + message: "Unknown notification method" + } + } + }; } private isolated function handleInitializeRequest(JsonRpcRequest jsonRpcRequest) returns http:BadRequest|http:Ok { - JsonRpcRequest {jsonrpc, id, ...request} = jsonRpcRequest; - InitializeRequest|error initRequest = request.cloneWithType(InitializeRequest); + JsonRpcRequest {jsonrpc: _, id, ...request} = jsonRpcRequest; + InitializeRequest|error initRequest = request.cloneWithType(); if initRequest is error { - JsonRpcError jsonRpcError = self.createErrorResponse(id, INVALID_REQUEST, - string `Invalid request: ${initRequest.message()}`); + JsonRpcError jsonRpcError = self.createJsonRpcError(INVALID_REQUEST, + string `Invalid request: ${initRequest.message()}`, id); return { body: jsonRpcError }; @@ -169,10 +162,19 @@ DispatcherService dispatcherService = isolated service object { // If it's a server with session management and the session ID is already set we should reject the request // to avoid re-initialization. if self.isInitialized && self.sessionId != () { - JsonRpcError jsonRpcError = self.createErrorResponse(id, INVALID_REQUEST, - "Invalid Request: Only one initialization request is allowed"); + JsonRpcError & readonly jsonRpcError = self.createJsonRpcError(INVALID_REQUEST, + "Invalid Request: Only one initialization request is allowed", id); return { - body: jsonRpcError.cloneReadOnly() + body: jsonRpcError + }; + } + + ServerConfiguration? serverConfigs = self.serverConfigs; + if serverConfigs is () { + JsonRpcError & readonly jsonRpcError = self.createJsonRpcError(INTERNAL_ERROR, + "Internal Error: Server configuration is not set", id); + return { + body: jsonRpcError }; } @@ -191,11 +193,8 @@ DispatcherService dispatcherService = isolated service object { id: id, result: { protocolVersion: protocolVersion, - capabilities: (self.serverConfigs?.options?.capabilities ?: {}).cloneReadOnly(), - serverInfo: (self.serverConfigs?.serverInfo ?: { - name: "MCP Server", - version: "1.0.0" - }).cloneReadOnly() + capabilities: (serverConfigs.options?.capabilities ?: {}).cloneReadOnly(), + serverInfo: serverConfigs.serverInfo.cloneReadOnly() } } }; @@ -206,22 +205,24 @@ DispatcherService dispatcherService = isolated service object { lock { // Check if initialized if !self.isInitialized { - JsonRpcError jsonRpcError = self.createErrorResponse(request.id, INVALID_REQUEST, - "Client must be initialized before making requests"); + JsonRpcError & readonly jsonRpcError = self.createJsonRpcError(INVALID_REQUEST, + "Client must be initialized before making requests", request.id); return { - body: jsonRpcError.cloneReadOnly() + body: jsonRpcError }; } + } - ListToolsResult|error listToolsResult = self.executeOnListTools(); - if listToolsResult is error { - JsonRpcError jsonRpcError = self.createErrorResponse(request.id, INTERNAL_ERROR, - string `Failed to list tools: ${listToolsResult.message()}`); - return { - body: jsonRpcError.cloneReadOnly() - }; - } + ListToolsResult|error listToolsResult = self.executeOnListTools(); + if listToolsResult is error { + JsonRpcError & readonly jsonRpcError = self.createJsonRpcError(INTERNAL_ERROR, + string `Failed to list tools: ${listToolsResult.message()}`, request.id); + return { + body: jsonRpcError + }; + } + lock { return { headers: { [SESSION_ID_HEADER]: self.sessionId ?: "" @@ -239,32 +240,34 @@ DispatcherService dispatcherService = isolated service object { lock { // Check if initialized if !self.isInitialized { - JsonRpcError jsonRpcError = self.createErrorResponse(request.id, INVALID_REQUEST, - "Client must be initialized before making requests"); + JsonRpcError & readonly jsonRpcError = self.createJsonRpcError(INVALID_REQUEST, + "Client must be initialized before making requests", request.id); return { - body: jsonRpcError.cloneReadOnly() + body: jsonRpcError }; } + } - // Extract and validate parameters - CallToolParams|error params = request.cloneReadOnly().params.ensureType(CallToolParams); - if params is error { - JsonRpcError jsonRpcError = self.createErrorResponse(request.id, INVALID_PARAMS, - string `Invalid parameters: ${params.message()}`); - return { - body: jsonRpcError.cloneReadOnly() - }; - } + // Extract and validate parameters + CallToolParams|error params = request.params.ensureType(CallToolParams); + if params is error { + JsonRpcError jsonRpcError = self.createJsonRpcError(INVALID_PARAMS, + string `Invalid parameters: ${params.message()}`, request.id); + return { + body: jsonRpcError + }; + } - CallToolResult|error callToolResult = self.executeOnCallTool(params); - if callToolResult is error { - JsonRpcError jsonRpcError = self.createErrorResponse(request.id, INTERNAL_ERROR, - string `Failed to call tool '${params.name}': ${callToolResult.message()}`); - return { - body: jsonRpcError.cloneReadOnly() - }; - } + CallToolResult|error callToolResult = self.executeOnCallTool(params); + if callToolResult is error { + JsonRpcError jsonRpcError = self.createJsonRpcError(INTERNAL_ERROR, + string `Failed to call tool '${params.name}': ${callToolResult.message()}`, request.id); + return { + body: jsonRpcError + }; + } + lock { return { headers: { [SESSION_ID_HEADER]: self.sessionId ?: "" @@ -287,7 +290,7 @@ DispatcherService dispatcherService = isolated service object { return LATEST_PROTOCOL_VERSION; } - private isolated function createErrorResponse(RequestId? id, int code, string message) returns JsonRpcError { + private isolated function createJsonRpcError(int code, string message, RequestId? id = ()) returns JsonRpcError & readonly { return { jsonrpc: JSONRPC_VERSION, id: id, @@ -302,9 +305,10 @@ DispatcherService dispatcherService = isolated service object { lock { Service|AdvancedService? mcpService = self.mcpService; if mcpService is AdvancedService { - return check invokeOnListTools(mcpService); - } else if mcpService is Service { - return check listToolsForRemoteFunctions(mcpService); + return invokeOnListTools(mcpService); + } + if mcpService is Service { + return listToolsForRemoteFunctions(mcpService); } return error DispatcherError("MCP Service is not attached"); } @@ -314,9 +318,10 @@ DispatcherService dispatcherService = isolated service object { lock { Service|AdvancedService? mcpService = self.mcpService; if mcpService is AdvancedService { - return check invokeOnCallTool(mcpService, params.cloneReadOnly()); - } else if mcpService is Service { - return check callToolForRemoteFunctions(mcpService, params.cloneReadOnly()); + return invokeOnCallTool(mcpService, params.cloneReadOnly()); + } + if mcpService is Service { + return callToolForRemoteFunctions(mcpService, params.cloneReadOnly()); } return error DispatcherError("MCP Service is not attached"); } diff --git a/ballerina/error.bal b/ballerina/error.bal index 6cb0290..7845795 100644 --- a/ballerina/error.bal +++ b/ballerina/error.bal @@ -84,4 +84,4 @@ public type ToolCallError distinct ClientError; public type ServerError distinct Error; # Custom error type for dispatcher service operations. -public type DispatcherError distinct ServerError; +type DispatcherError distinct ServerError; diff --git a/ballerina/listener.bal b/ballerina/listener.bal index 11e3b10..d360414 100644 --- a/ballerina/listener.bal +++ b/ballerina/listener.bal @@ -47,12 +47,16 @@ public class Listener { # + listenTo - Either a port number (int) or an existing http:Listener. # + config - Listener configuration. # + return - error? if listener initialization fails. - public function init(int|http:Listener listenTo, *ListenerConfiguration config) returns error? { + public function init(int|http:Listener listenTo, *ListenerConfiguration config) returns Error? { ListenerConfiguration {serverInfo, options, ...listenerConfig} = config; if listenTo is http:Listener { self.httpListener = listenTo; } else { - self.httpListener = check new (listenTo, listenerConfig); + http:Listener|error httpListener = new (listenTo, listenerConfig); + if httpListener is error { + return error("Failed to initialize HTTP listener: " + httpListener.message()); + } + self.httpListener = httpListener; } self.dispatcherService = dispatcherService; self.dispatcherService.setServerConfigs({ @@ -66,8 +70,11 @@ public class Listener { # + mcpService - Service to attach. # + name - Path(s) to mount the service on (string or string array). # + return - error? if attachment fails. - public isolated function attach(Service|AdvancedService mcpService, string[]|string? name = ()) returns error? { - check self.httpListener.attach(self.dispatcherService, name); + public isolated function attach(Service|AdvancedService mcpService, string[]|string? name = ()) returns Error? { + error? result = self.httpListener.attach(self.dispatcherService, name); + if result is error { + return error("Failed to attach MCP service: " + result.message()); + } self.dispatcherService.addServiceRef(mcpService); } @@ -75,29 +82,41 @@ public class Listener { # # + mcpService - Service to detach. # + return - error? if detachment fails. - public isolated function detach(Service|AdvancedService mcpService) returns error? { - check self.httpListener.detach(self.dispatcherService); + public isolated function detach(Service|AdvancedService mcpService) returns Error? { + error? result = self.httpListener.detach(self.dispatcherService); + if result is error { + return error("Failed to detach MCP service: " + result.message()); + } self.dispatcherService.removeServiceRef(); } # Starts the listener (begin accepting connections). # # + return - error? if starting fails. - public isolated function 'start() returns error? { - check self.httpListener.start(); + public isolated function 'start() returns Error? { + error? result = self.httpListener.start(); + if result is error { + return error("Failed to start MCP listener: " + result.message()); + } } # Gracefully stops the listener (completes active requests before shutting down). # # + return - error? if graceful stop fails. - public isolated function gracefulStop() returns error? { - check self.httpListener.gracefulStop(); + public isolated function gracefulStop() returns Error? { + error? result = self.httpListener.gracefulStop(); + if result is error { + return error("Failed to gracefully stop MCP listener: " + result.message()); + } } # Immediately stops the listener (terminates all connections). # # + return - error? if immediate stop fails. - public isolated function immediateStop() returns error? { - check self.httpListener.immediateStop(); + public isolated function immediateStop() returns Error? { + error? result = self.httpListener.immediateStop(); + if result is error { + return error("Failed to immediately stop MCP listener: " + result.message()); + } } } diff --git a/ballerina/native_listener_helper.bal b/ballerina/native_listener_helper.bal index 1ab9d78..83553a0 100644 --- a/ballerina/native_listener_helper.bal +++ b/ballerina/native_listener_helper.bal @@ -16,18 +16,21 @@ import ballerina/jballerina.java; -isolated function invokeOnListTools(AdvancedService 'service) returns ListToolsResult|error = @java:Method { +isolated function invokeOnListTools(AdvancedService 'service) returns ListToolsResult|Error = @java:Method { 'class: "io.ballerina.stdlib.mcp.McpServiceMethodHelper" } external; -isolated function invokeOnCallTool(AdvancedService 'service, CallToolParams params) returns CallToolResult|error = @java:Method { +isolated function invokeOnCallTool(AdvancedService 'service, CallToolParams params) + returns CallToolResult|Error = @java:Method { 'class: "io.ballerina.stdlib.mcp.McpServiceMethodHelper" } external; -isolated function listToolsForRemoteFunctions(Service 'service, typedesc t = <>) returns t|error = @java:Method { +isolated function listToolsForRemoteFunctions(Service 'service, typedesc t = <>) + returns t|Error = @java:Method { 'class: "io.ballerina.stdlib.mcp.McpServiceMethodHelper" } external; -isolated function callToolForRemoteFunctions(Service 'service, CallToolParams params, typedesc t = <>) returns t|error = @java:Method { +isolated function callToolForRemoteFunctions(Service 'service, CallToolParams params, typedesc t = <>) + returns t|Error = @java:Method { 'class: "io.ballerina.stdlib.mcp.McpServiceMethodHelper" } external; diff --git a/ballerina/types.bal b/ballerina/types.bal index 387dedc..17cd46e 100644 --- a/ballerina/types.bal +++ b/ballerina/types.bal @@ -26,8 +26,18 @@ public const SUPPORTED_PROTOCOL_VERSIONS = [ public const JSONRPC_VERSION = "2.0"; -// # Notification methods -public const NOTIFICATION_INITIALIZED = "notifications/initialized"; +# Request methods +public enum RequestMethod { + REQUEST_INITIALIZE = "initialize", + REQUEST_LIST_TOOLS = "tools/list", + REQUEST_CALL_TOOL = "tools/call" +}; + +# Notification methods +public enum NotificationMethod { + NOTIFICATION_INITIALIZED = "notifications/initialized", + NOTIFICATION_PROGRESS = "notifications/progress" +}; # A progress token, used to associate progress notifications with the original request. public type ProgressToken string|int; @@ -131,7 +141,7 @@ public type JsonRpcError record { type InitializeRequest record {| *Request; # Method name for the request - "initialize" method; + REQUEST_INITIALIZE method; # Parameters for the initialize request record { *RequestParams; @@ -271,7 +281,7 @@ public type EmbeddedResource record { public type ListToolsRequest record {| *PaginatedRequest; # The method identifier for this request - "tools/list" method; + REQUEST_LIST_TOOLS method; |}; # The server's response to a tools/list request from the client. @@ -293,7 +303,7 @@ public type CallToolResult record { # Used by the client to invoke a tool provided by the server. public type CallToolRequest record {| # The JSON-RPC method name - "tools/call" method; + REQUEST_CALL_TOOL method; # The parameters for the tool call CallToolParams params; |}; @@ -403,7 +413,7 @@ public type McpToolConfig record {| # The description of the tool. string description?; # The JSON schema for the tool's parameters. - json schema?; + map schema?; |}; # Annotation to mark a function as an MCP tool configuration. @@ -411,8 +421,8 @@ public annotation McpToolConfig McpTool on object function; # Defines a mcp service interface that handles incoming mcp requests. public type AdvancedService distinct isolated service object { - remote isolated function onListTools() returns ListToolsResult|error; - remote isolated function onCallTool(CallToolParams params) returns CallToolResult|error; + remote isolated function onListTools() returns ListToolsResult|ServerError; + remote isolated function onCallTool(CallToolParams params) returns CallToolResult|ServerError; }; public type Service distinct isolated service object { diff --git a/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/diagnostics/CompilationDiagnostic.java b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/diagnostics/CompilationDiagnostic.java index 1fe6701..1ab65e6 100644 --- a/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/diagnostics/CompilationDiagnostic.java +++ b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/diagnostics/CompilationDiagnostic.java @@ -27,7 +27,7 @@ import static io.ballerina.tools.diagnostics.DiagnosticSeverity.ERROR; /** - * Compilation errors in the Ballerina AI package. + * Compilation errors in the Ballerina mcp package. */ public enum CompilationDiagnostic { UNABLE_TO_GENERATE_SCHEMA_FOR_FUNCTION(DiagnosticMessage.ERROR_101, DiagnosticCode.MCP_101, ERROR); diff --git a/native/src/main/java/io/ballerina/stdlib/mcp/McpServiceMethodHelper.java b/native/src/main/java/io/ballerina/stdlib/mcp/McpServiceMethodHelper.java index 2e8e09d..2ccb316 100644 --- a/native/src/main/java/io/ballerina/stdlib/mcp/McpServiceMethodHelper.java +++ b/native/src/main/java/io/ballerina/stdlib/mcp/McpServiceMethodHelper.java @@ -19,7 +19,6 @@ package io.ballerina.stdlib.mcp; import io.ballerina.runtime.api.Environment; -import io.ballerina.runtime.api.creators.ErrorCreator; import io.ballerina.runtime.api.creators.ValueCreator; import io.ballerina.runtime.api.types.ArrayType; import io.ballerina.runtime.api.types.Parameter; @@ -47,19 +46,19 @@ */ public final class McpServiceMethodHelper { - private static final String FIELD_TOOLS = "tools"; - private static final String FIELD_NAME = "name"; - private static final String FIELD_DESCRIPTION = "description"; - private static final String FIELD_SCHEMA = "schema"; - private static final String FIELD_INPUT_SCHEMA = "inputSchema"; - private static final String FIELD_ARGUMENTS = "arguments"; - private static final String FIELD_CONTENT = "content"; - private static final String FIELD_TYPE = "type"; - private static final String FIELD_TEXT = "text"; + private static final String TOOLS_FIELD_NAME = "tools"; + private static final String NAME_FIELD_NAME = "name"; + private static final String DESCRIPTION_FIELD_NAME = "description"; + private static final String SCHEMA_FIELD_NAME = "schema"; + private static final String INPUT_SCHEMA_FIELD_NAME = "inputSchema"; + private static final String ARGUMENTS_FIELD_NAME = "arguments"; + private static final String CONTENT_FIELD_NAME = "content"; + private static final String TYPE_FIELD_NAME = "type"; + private static final String TEXT_FIELD_NAME = "text"; private static final String ANNOTATION_MCP_TOOL = "McpTool"; private static final String TYPE_TEXT_CONTENT = "TextContent"; - private static final String VALUE_TEXT = "text"; + private static final String TEXT_VALUE_NAME = "text"; private McpServiceMethodHelper() {} @@ -97,7 +96,7 @@ public static Object listToolsForRemoteFunctions(BObject mcpService, BTypedesc t RecordType resultRecordType = (RecordType) typed.getDescribingType(); BMap result = ValueCreator.createRecordValue(resultRecordType); - ArrayType toolsArrayType = (ArrayType) resultRecordType.getFields().get(FIELD_TOOLS).getFieldType(); + ArrayType toolsArrayType = (ArrayType) resultRecordType.getFields().get(TOOLS_FIELD_NAME).getFieldType(); BArray tools = ValueCreator.createArrayValue(toolsArrayType); for (RemoteMethodType remoteMethod : getRemoteMethods(mcpService)) { @@ -108,7 +107,7 @@ public static Object listToolsForRemoteFunctions(BObject mcpService, BTypedesc t createToolRecord(toolsArrayType, remoteMethod, (BMap) annotation.getValue()) )); } - result.put(fromString(FIELD_TOOLS), tools); + result.put(fromString(TOOLS_FIELD_NAME), tools); return result; } @@ -123,19 +122,18 @@ public static Object listToolsForRemoteFunctions(BObject mcpService, BTypedesc t */ public static Object callToolForRemoteFunctions(Environment env, BObject mcpService, BMap params, BTypedesc typed) { - BString toolName = (BString) params.get(fromString(FIELD_NAME)); + BString toolName = (BString) params.get(fromString(NAME_FIELD_NAME)); - RemoteMethodType method = getRemoteMethods(mcpService).stream() + Optional method = getRemoteMethods(mcpService).stream() .filter(rmt -> rmt.getName().equals(toolName.getValue())) - .findFirst().orElse(null); + .findFirst(); - if (method == null) { - BString errorMessage = - fromString("RemoteMethodType with name '" + toolName.getValue() + "' not found"); - return ErrorCreator.createError(errorMessage); + if (method.isEmpty()) { + return ModuleUtils + .createError("RemoteMethodType with name '" + toolName.getValue() + "' not found"); } - Object[] args = buildArgsForMethod(method, (BMap) params.get(fromString(FIELD_ARGUMENTS))); + Object[] args = buildArgsForMethod(method.get(), (BMap) params.get(fromString(ARGUMENTS_FIELD_NAME))); Object result = env.getRuntime().callMethod(mcpService, toolName.getValue(), null, args); return createCallToolResult(typed, result); @@ -151,9 +149,9 @@ private static BMap createToolRecord(ArrayType toolsArrayType, RecordType toolRecordType = (RecordType) ((ReferenceType) toolsArrayType.getElementType()).getReferredType(); BMap tool = ValueCreator.createRecordValue(toolRecordType); - tool.put(fromString(FIELD_NAME), fromString(remoteMethod.getName())); - tool.put(fromString(FIELD_DESCRIPTION), annotationValue.get(fromString(FIELD_DESCRIPTION))); - tool.put(fromString(FIELD_INPUT_SCHEMA), annotationValue.get(fromString(FIELD_SCHEMA))); + tool.put(fromString(NAME_FIELD_NAME), fromString(remoteMethod.getName())); + tool.put(fromString(DESCRIPTION_FIELD_NAME), annotationValue.get(fromString(DESCRIPTION_FIELD_NAME))); + tool.put(fromString(INPUT_SCHEMA_FIELD_NAME), annotationValue.get(fromString(SCHEMA_FIELD_NAME))); return tool; } @@ -171,7 +169,7 @@ private static Object createCallToolResult(BTypedesc typed, Object result) { RecordType resultRecordType = (RecordType) typed.getDescribingType(); BMap callToolResult = ValueCreator.createRecordValue(resultRecordType); - ArrayType contentArrayType = (ArrayType) resultRecordType.getFields().get(FIELD_CONTENT).getFieldType(); + ArrayType contentArrayType = (ArrayType) resultRecordType.getFields().get(CONTENT_FIELD_NAME).getFieldType(); BArray contentArray = ValueCreator.createArrayValue(contentArrayType); UnionType contentUnionType = (UnionType) contentArrayType.getElementType(); @@ -179,17 +177,16 @@ private static Object createCallToolResult(BTypedesc typed, Object result) { .filter(type -> TYPE_TEXT_CONTENT.equals(type.getName())) .findFirst(); if (textContentTypeOpt.isEmpty()) { - BString errorMessage = - fromString("No member type named 'TextContent' found in content union type."); - return ErrorCreator.createError(errorMessage); + return ModuleUtils + .createError("No member type named 'TextContent' found in content union type."); } RecordType textContentRecordType = (RecordType) ((ReferenceType) textContentTypeOpt.get()).getReferredType(); BMap textContent = ValueCreator.createRecordValue(textContentRecordType); - textContent.put(fromString(FIELD_TYPE), fromString(VALUE_TEXT)); - textContent.put(fromString(FIELD_TEXT), fromString(result == null ? "" : result.toString())); + textContent.put(fromString(TYPE_FIELD_NAME), fromString(TEXT_VALUE_NAME)); + textContent.put(fromString(TEXT_FIELD_NAME), fromString(result == null ? "" : result.toString())); contentArray.append(textContent); - callToolResult.put(fromString(FIELD_CONTENT), contentArray); + callToolResult.put(fromString(CONTENT_FIELD_NAME), contentArray); return callToolResult; } } diff --git a/native/src/main/java/io/ballerina/stdlib/mcp/ModuleUtils.java b/native/src/main/java/io/ballerina/stdlib/mcp/ModuleUtils.java index ba09fe7..4c8ccda 100644 --- a/native/src/main/java/io/ballerina/stdlib/mcp/ModuleUtils.java +++ b/native/src/main/java/io/ballerina/stdlib/mcp/ModuleUtils.java @@ -20,6 +20,10 @@ import io.ballerina.runtime.api.Environment; import io.ballerina.runtime.api.Module; +import io.ballerina.runtime.api.creators.ErrorCreator; +import io.ballerina.runtime.api.values.BError; + +import static io.ballerina.runtime.api.utils.StringUtils.fromString; public final class ModuleUtils { private static Module module; @@ -36,4 +40,8 @@ public static Module getModule() { public static void setModule(Environment env) { module = env.getCurrentModule(); } + + public static BError createError(String errorMessage) { + return ErrorCreator.createError(fromString(errorMessage)); + } } diff --git a/native/src/main/java/io/ballerina/stdlib/mcp/SseEventStreamHelper.java b/native/src/main/java/io/ballerina/stdlib/mcp/SseEventStreamHelper.java index b733df1..bbfd98c 100644 --- a/native/src/main/java/io/ballerina/stdlib/mcp/SseEventStreamHelper.java +++ b/native/src/main/java/io/ballerina/stdlib/mcp/SseEventStreamHelper.java @@ -19,12 +19,8 @@ package io.ballerina.stdlib.mcp; import io.ballerina.runtime.api.Environment; -import io.ballerina.runtime.api.creators.ErrorCreator; import io.ballerina.runtime.api.values.BObject; import io.ballerina.runtime.api.values.BStream; -import io.ballerina.runtime.api.values.BString; - -import static io.ballerina.runtime.api.utils.StringUtils.fromString; /** * Utility class for handling Server-Sent Events (SSE) streams in Ballerina via Java interop. @@ -68,8 +64,7 @@ public static void attachSseStream(BObject object, BStream sseStream) { public static Object getNextSseEvent(Environment env, BObject object) { BStream sseStream = (BStream) object.getNativeData(SSE_STREAM_NATIVE_KEY); if (sseStream == null) { - BString errorMessage = fromString("Unable to obtain elements from stream. SSE stream not found."); - return ErrorCreator.createError(errorMessage); + return ModuleUtils.createError("Unable to obtain elements from stream. SSE stream not found."); } BObject iteratorObject = sseStream.getIteratorObj(); // Use the Ballerina runtime to call the "next" method on the iterator and fetch the next event. @@ -88,8 +83,7 @@ public static Object getNextSseEvent(Environment env, BObject object) { public static Object closeSseEventStream(Environment env, BObject object) { BStream sseStream = (BStream) object.getNativeData(SSE_STREAM_NATIVE_KEY); if (sseStream == null) { - BString errorMessage = fromString("Unable to obtain elements from stream. SSE stream not found."); - return ErrorCreator.createError(errorMessage); + return ModuleUtils.createError("Unable to obtain elements from stream. SSE stream not found."); } BObject iteratorObject = sseStream.getIteratorObj(); // Use the Ballerina runtime to call the "close" method on the iterator and release resources. From 7596adf8c52e7536ab0bd3d4c3924680104bcf24 Mon Sep 17 00:00:00 2001 From: Azeem Muzammil Date: Mon, 30 Jun 2025 13:20:39 +0530 Subject: [PATCH 15/31] [Automated] Update the native jar versions --- ballerina/Dependencies.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ballerina/Dependencies.toml b/ballerina/Dependencies.toml index 2e2263d..18521e4 100644 --- a/ballerina/Dependencies.toml +++ b/ballerina/Dependencies.toml @@ -108,6 +108,9 @@ dependencies = [ {org = "ballerina", name = "jballerina.java"}, {org = "ballerina", name = "lang.value"} ] +modules = [ + {org = "ballerina", packageName = "io", moduleName = "io"} +] [[package]] org = "ballerina" From facbfcb699c47ca0232e0fc45bf01f418583bc31 Mon Sep 17 00:00:00 2001 From: Azeem Muzammil Date: Mon, 30 Jun 2025 13:39:38 +0530 Subject: [PATCH 16/31] [Automated] Update the native jar versions --- ballerina/Dependencies.toml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/ballerina/Dependencies.toml b/ballerina/Dependencies.toml index 18521e4..5c54e49 100644 --- a/ballerina/Dependencies.toml +++ b/ballerina/Dependencies.toml @@ -108,9 +108,6 @@ dependencies = [ {org = "ballerina", name = "jballerina.java"}, {org = "ballerina", name = "lang.value"} ] -modules = [ - {org = "ballerina", packageName = "io", moduleName = "io"} -] [[package]] org = "ballerina" @@ -229,7 +226,6 @@ name = "mcp" version = "0.4.3" dependencies = [ {org = "ballerina", name = "http"}, - {org = "ballerina", name = "io"}, {org = "ballerina", name = "jballerina.java"}, {org = "ballerina", name = "log"} ] From e2cfadfde3f85192d1363303e25434ae14b056fa Mon Sep 17 00:00:00 2001 From: Azeem Muzammil Date: Mon, 30 Jun 2025 13:44:07 +0530 Subject: [PATCH 17/31] Fix error with cloning type --- ballerina/dispatcher_service.bal | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ballerina/dispatcher_service.bal b/ballerina/dispatcher_service.bal index ffbbc57..c9b097d 100644 --- a/ballerina/dispatcher_service.bal +++ b/ballerina/dispatcher_service.bal @@ -249,7 +249,7 @@ DispatcherService dispatcherService = isolated service object { } // Extract and validate parameters - CallToolParams|error params = request.params.ensureType(CallToolParams); + CallToolParams|error params = request.params.cloneWithType(); if params is error { JsonRpcError jsonRpcError = self.createJsonRpcError(INVALID_PARAMS, string `Invalid parameters: ${params.message()}`, request.id); From 565d2d024201501381a1bc41c00311f5f16c4a52 Mon Sep 17 00:00:00 2001 From: Azeem Muzammil Date: Tue, 1 Jul 2025 14:11:12 +0530 Subject: [PATCH 18/31] [Automated] Update the native jar versions --- ballerina/Dependencies.toml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ballerina/Dependencies.toml b/ballerina/Dependencies.toml index 5c54e49..18521e4 100644 --- a/ballerina/Dependencies.toml +++ b/ballerina/Dependencies.toml @@ -108,6 +108,9 @@ dependencies = [ {org = "ballerina", name = "jballerina.java"}, {org = "ballerina", name = "lang.value"} ] +modules = [ + {org = "ballerina", packageName = "io", moduleName = "io"} +] [[package]] org = "ballerina" @@ -226,6 +229,7 @@ name = "mcp" version = "0.4.3" dependencies = [ {org = "ballerina", name = "http"}, + {org = "ballerina", name = "io"}, {org = "ballerina", name = "jballerina.java"}, {org = "ballerina", name = "log"} ] From a4e58549f0728d166955182324d3bef4da362f57 Mon Sep 17 00:00:00 2001 From: Azeem Muzammil Date: Tue, 1 Jul 2025 14:17:42 +0530 Subject: [PATCH 19/31] Move SericeConfigs to annotations --- .gitignore | 3 +++ ballerina/dispatcher_service.bal | 21 ++++++++------------- ballerina/listener.bal | 23 +---------------------- ballerina/types.bal | 25 +++++++++++++++++++++++-- ballerina/utils.bal | 17 +++++++++++++++++ 5 files changed, 52 insertions(+), 37 deletions(-) diff --git a/.gitignore b/.gitignore index 834a0a0..6034716 100644 --- a/.gitignore +++ b/.gitignore @@ -42,3 +42,6 @@ target # Ballerina velocity.log* *Ballerina.lock + +# AI Tools +CLAUDE.md diff --git a/ballerina/dispatcher_service.bal b/ballerina/dispatcher_service.bal index c9b097d..63fb29e 100644 --- a/ballerina/dispatcher_service.bal +++ b/ballerina/dispatcher_service.bal @@ -23,11 +23,9 @@ type DispatcherService distinct service object { isolated function addServiceRef(Service|AdvancedService mcpService); isolated function removeServiceRef(); - isolated function setServerConfigs(ServerConfiguration serverConfigs); }; DispatcherService dispatcherService = isolated service object { - private ServerConfiguration? serverConfigs = (); private Service|AdvancedService? mcpService = (); private boolean isInitialized = false; private string? sessionId = (); @@ -44,11 +42,6 @@ DispatcherService dispatcherService = isolated service object { } } - isolated function setServerConfigs(ServerConfiguration serverConfigs) { - lock { - self.serverConfigs = serverConfigs.cloneReadOnly(); - } - } isolated resource function post .(@http:Payload JsonRpcMessage request, http:Headers headers) returns http:BadRequest|http:NotAcceptable|http:UnsupportedMediaType|http:Accepted|http:Ok|ServerError { @@ -132,7 +125,7 @@ DispatcherService dispatcherService = isolated service object { } private isolated function processJsonRpcNotification(JsonRpcNotification notification) returns http:Accepted|http:BadRequest { - if notification.method == "notifications/initialized" { + if notification.method == NOTIFICATION_INITIALIZED { return http:ACCEPTED; } @@ -169,15 +162,17 @@ DispatcherService dispatcherService = isolated service object { }; } - ServerConfiguration? serverConfigs = self.serverConfigs; - if serverConfigs is () { + Service|AdvancedService? mcpService = self.mcpService; + if mcpService is () { JsonRpcError & readonly jsonRpcError = self.createJsonRpcError(INTERNAL_ERROR, - "Internal Error: Server configuration is not set", id); + "Internal Error: MCP Service is not attached", id); return { body: jsonRpcError }; } + ServiceConfiguration serviceConfig = getServiceConfiguration(mcpService); + self.isInitialized = true; self.sessionId = uuid:createRandomUuid(); @@ -193,8 +188,8 @@ DispatcherService dispatcherService = isolated service object { id: id, result: { protocolVersion: protocolVersion, - capabilities: (serverConfigs.options?.capabilities ?: {}).cloneReadOnly(), - serverInfo: serverConfigs.serverInfo.cloneReadOnly() + capabilities: (serviceConfig.options?.capabilities ?: {}).cloneReadOnly(), + serverInfo: serviceConfig.info.cloneReadOnly() } } }; diff --git a/ballerina/listener.bal b/ballerina/listener.bal index d360414..768f6a6 100644 --- a/ballerina/listener.bal +++ b/ballerina/listener.bal @@ -16,25 +16,9 @@ import ballerina/http; -# Represents the options for configuring an MCP server. -public type ServerOptions record {| - # Capabilities to advertise as being supported by this server. - ServerCapabilities capabilities?; - # Optional instructions describing how to use the server and its features. - string instructions?; - # Whether to enforce strict capabilities compliance. - boolean enforceStrictCapabilities?; -|}; - # Configuration options for initializing an MCP listener. public type ListenerConfiguration record {| *http:ListenerConfiguration; - *ServerConfiguration; -|}; - -type ServerConfiguration record {| - Implementation serverInfo; - ServerOptions options?; |}; # A server listener for handling MCP service requests. @@ -48,21 +32,16 @@ public class Listener { # + config - Listener configuration. # + return - error? if listener initialization fails. public function init(int|http:Listener listenTo, *ListenerConfiguration config) returns Error? { - ListenerConfiguration {serverInfo, options, ...listenerConfig} = config; if listenTo is http:Listener { self.httpListener = listenTo; } else { - http:Listener|error httpListener = new (listenTo, listenerConfig); + http:Listener|error httpListener = new (listenTo, config); if httpListener is error { return error("Failed to initialize HTTP listener: " + httpListener.message()); } self.httpListener = httpListener; } self.dispatcherService = dispatcherService; - self.dispatcherService.setServerConfigs({ - serverInfo, - options - }); } # Attaches an MCP service to the listener under the specified path(s). diff --git a/ballerina/types.bal b/ballerina/types.bal index 17cd46e..9d2a921 100644 --- a/ballerina/types.bal +++ b/ballerina/types.bal @@ -35,8 +35,7 @@ public enum RequestMethod { # Notification methods public enum NotificationMethod { - NOTIFICATION_INITIALIZED = "notifications/initialized", - NOTIFICATION_PROGRESS = "notifications/progress" + NOTIFICATION_INITIALIZED = "notifications/initialized" }; # A progress token, used to associate progress notifications with the original request. @@ -419,12 +418,34 @@ public type McpToolConfig record {| # Annotation to mark a function as an MCP tool configuration. public annotation McpToolConfig McpTool on object function; +# Represents the options for configuring an MCP server. +public type ServerOptions record {| + # Capabilities to advertise as being supported by this server. + ServerCapabilities capabilities?; + # Optional instructions describing how to use the server and its features. + string instructions?; + # Whether to enforce strict capabilities compliance. + boolean enforceStrictCapabilities?; +|}; + +# Configuration for MCP service that defines server capabilities and metadata. +public type ServiceConfiguration record {| + # Server implementation information + Implementation info; + # Optional server configuration options + ServerOptions options?; +|}; + +# Annotation to provide service configuration to MCP services. +public annotation ServiceConfiguration ServiceConfig on service; + # Defines a mcp service interface that handles incoming mcp requests. public type AdvancedService distinct isolated service object { remote isolated function onListTools() returns ListToolsResult|ServerError; remote isolated function onCallTool(CallToolParams params) returns CallToolResult|ServerError; }; +# Defines a basic mcp service interface that handles incoming mcp requests. public type Service distinct isolated service object { }; diff --git a/ballerina/utils.bal b/ballerina/utils.bal index b251534..9e5f2de 100644 --- a/ballerina/utils.bal +++ b/ballerina/utils.bal @@ -73,3 +73,20 @@ isolated function extractResultFromMessage(JsonRpcMessage message) returns Serve return error InvalidMessageTypeError("Received message from server is not a valid JsonRpcResponse."); } +# Retrieves the service configuration from an MCP service. +# +# + mcpService - The MCP service instance +# + return - The service configuration +isolated function getServiceConfiguration(Service|AdvancedService mcpService) returns ServiceConfiguration { + typedesc mcpServiceType = typeof mcpService; + ServiceConfiguration? serviceConfig = mcpServiceType.@ServiceConfig; + if serviceConfig is () { + return { + info: { + name: "MCP Service", + version: "1.0.0" + } + }; + } + return serviceConfig; +} From de0239a0a7bfe14a910683f2d2045c48534a5004 Mon Sep 17 00:00:00 2001 From: Azeem Muzammil Date: Thu, 3 Jul 2025 13:46:20 +0530 Subject: [PATCH 20/31] [Automated] Update the native jar versions --- ballerina/Dependencies.toml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/ballerina/Dependencies.toml b/ballerina/Dependencies.toml index 18521e4..5c54e49 100644 --- a/ballerina/Dependencies.toml +++ b/ballerina/Dependencies.toml @@ -108,9 +108,6 @@ dependencies = [ {org = "ballerina", name = "jballerina.java"}, {org = "ballerina", name = "lang.value"} ] -modules = [ - {org = "ballerina", packageName = "io", moduleName = "io"} -] [[package]] org = "ballerina" @@ -229,7 +226,6 @@ name = "mcp" version = "0.4.3" dependencies = [ {org = "ballerina", name = "http"}, - {org = "ballerina", name = "io"}, {org = "ballerina", name = "jballerina.java"}, {org = "ballerina", name = "log"} ] From 1cde6c3998dd92ee1599f2e72e1d60da81484e56 Mon Sep 17 00:00:00 2001 From: Azeem Muzammil Date: Thu, 3 Jul 2025 14:20:25 +0530 Subject: [PATCH 21/31] Address review comments --- ballerina/dispatcher_service.bal | 111 ++++++++---------- ballerina/listener.bal | 48 ++++---- ballerina/types.bal | 15 ++- ballerina/utils.bal | 15 +-- .../stdlib/mcp/plugin/McpSourceModifier.java | 76 ++++++------ 5 files changed, 130 insertions(+), 135 deletions(-) diff --git a/ballerina/dispatcher_service.bal b/ballerina/dispatcher_service.bal index 63fb29e..24211a6 100644 --- a/ballerina/dispatcher_service.bal +++ b/ballerina/dispatcher_service.bal @@ -17,15 +17,14 @@ import ballerina/http; import ballerina/uuid; -# Represents the dispatcher service type definition. -type DispatcherService distinct service object { +type DispatcherService distinct isolated service object { *http:Service; isolated function addServiceRef(Service|AdvancedService mcpService); isolated function removeServiceRef(); }; -DispatcherService dispatcherService = isolated service object { +final DispatcherService dispatcherService = isolated service object { private Service|AdvancedService? mcpService = (); private boolean isInitialized = false; private string? sessionId = (); @@ -42,9 +41,8 @@ DispatcherService dispatcherService = isolated service object { } } - isolated resource function post .(@http:Payload JsonRpcMessage request, http:Headers headers) - returns http:BadRequest|http:NotAcceptable|http:UnsupportedMediaType|http:Accepted|http:Ok|ServerError { + returns http:BadRequest|http:NotAcceptable|http:UnsupportedMediaType|http:Accepted|http:Ok { http:NotAcceptable|http:UnsupportedMediaType? headerValidationError = self.validateHeaders(headers); if headerValidationError !is () { return headerValidationError; @@ -57,9 +55,8 @@ DispatcherService dispatcherService = isolated service object { return self.processJsonRpcNotification(request); } - JsonRpcError & readonly jsonRpcError = self.createJsonRpcError(INVALID_REQUEST, "Unsupported request type"); return { - body: jsonRpcError + body: self.createJsonRpcError(INVALID_REQUEST, "Unsupported request type") }; } @@ -68,36 +65,32 @@ DispatcherService dispatcherService = isolated service object { // Validate Accept header string|http:HeaderNotFoundError acceptHeader = headers.getHeader(ACCEPT_HEADER); if acceptHeader is http:HeaderNotFoundError { - JsonRpcError jsonRpcError = self.createJsonRpcError(NOT_ACCEPTABLE, - "Not Acceptable: Client must accept both application/json and text/event-stream"); return { - body: jsonRpcError + body: self.createJsonRpcError(NOT_ACCEPTABLE, + "Not Acceptable: Client must accept both application/json and text/event-stream") }; } if !acceptHeader.includes(CONTENT_TYPE_JSON) || !acceptHeader.includes(CONTENT_TYPE_SSE) { - JsonRpcError jsonRpcError = self.createJsonRpcError(NOT_ACCEPTABLE, - "Not Acceptable: Client must accept both application/json and text/event-stream"); return { - body: jsonRpcError + body: self.createJsonRpcError(NOT_ACCEPTABLE, + "Not Acceptable: Client must accept both application/json and text/event-stream") }; } // Validate Content-Type header string|http:HeaderNotFoundError contentTypeHeader = headers.getHeader(CONTENT_TYPE_HEADER); if contentTypeHeader is http:HeaderNotFoundError { - JsonRpcError jsonRpcError = self.createJsonRpcError(UNSUPPORTED_MEDIA_TYPE, - "Unsupported Media Type: Content-Type must be application/json"); return { - body: jsonRpcError + body: self.createJsonRpcError(UNSUPPORTED_MEDIA_TYPE, + "Unsupported Media Type: Content-Type must be application/json") }; } if !contentTypeHeader.includes(CONTENT_TYPE_JSON) { - JsonRpcError jsonRpcError = self.createJsonRpcError(UNSUPPORTED_MEDIA_TYPE, - "Unsupported Media Type: Content-Type must be application/json"); return { - body: jsonRpcError + body: self.createJsonRpcError(UNSUPPORTED_MEDIA_TYPE, + "Unsupported Media Type: Content-Type must be application/json") }; } @@ -116,9 +109,8 @@ DispatcherService dispatcherService = isolated service object { return self.handleCallToolRequest(request); } _ => { - JsonRpcError jsonRpcError = self.createJsonRpcError(METHOD_NOT_FOUND, "Method not found", request.id); return { - body: jsonRpcError + body: self.createJsonRpcError(METHOD_NOT_FOUND, "Method not found", request.id) }; } } @@ -144,10 +136,9 @@ DispatcherService dispatcherService = isolated service object { JsonRpcRequest {jsonrpc: _, id, ...request} = jsonRpcRequest; InitializeRequest|error initRequest = request.cloneWithType(); if initRequest is error { - JsonRpcError jsonRpcError = self.createJsonRpcError(INVALID_REQUEST, - string `Invalid request: ${initRequest.message()}`, id); return { - body: jsonRpcError + body: self.createJsonRpcError(INVALID_REQUEST, + string `Invalid request: ${initRequest.message()}`, id) }; } @@ -155,19 +146,17 @@ DispatcherService dispatcherService = isolated service object { // If it's a server with session management and the session ID is already set we should reject the request // to avoid re-initialization. if self.isInitialized && self.sessionId != () { - JsonRpcError & readonly jsonRpcError = self.createJsonRpcError(INVALID_REQUEST, - "Invalid Request: Only one initialization request is allowed", id); return { - body: jsonRpcError + body: self.createJsonRpcError(INVALID_REQUEST, + "Invalid Request: Only one initialization request is allowed", id) }; } Service|AdvancedService? mcpService = self.mcpService; if mcpService is () { - JsonRpcError & readonly jsonRpcError = self.createJsonRpcError(INTERNAL_ERROR, - "Internal Error: MCP Service is not attached", id); return { - body: jsonRpcError + body: self.createJsonRpcError(INTERNAL_ERROR, + "Internal Error: MCP Service is not attached", id) }; } @@ -180,9 +169,7 @@ DispatcherService dispatcherService = isolated service object { string protocolVersion = self.selectProtocolVersion(requestedVersion); return { - headers: { - [SESSION_ID_HEADER]: self.sessionId ?: "" - }, + headers: self.prepareRequestHeaders(), body: { jsonrpc: JSONRPC_VERSION, id: id, @@ -200,28 +187,24 @@ DispatcherService dispatcherService = isolated service object { lock { // Check if initialized if !self.isInitialized { - JsonRpcError & readonly jsonRpcError = self.createJsonRpcError(INVALID_REQUEST, - "Client must be initialized before making requests", request.id); return { - body: jsonRpcError + body: self.createJsonRpcError(INVALID_REQUEST, + "Client must be initialized before making requests", request.id) }; } } ListToolsResult|error listToolsResult = self.executeOnListTools(); if listToolsResult is error { - JsonRpcError & readonly jsonRpcError = self.createJsonRpcError(INTERNAL_ERROR, - string `Failed to list tools: ${listToolsResult.message()}`, request.id); return { - body: jsonRpcError + body: self.createJsonRpcError(INTERNAL_ERROR, + string `Failed to list tools: ${listToolsResult.message()}`, request.id) }; } lock { return { - headers: { - [SESSION_ID_HEADER]: self.sessionId ?: "" - }, + headers: self.prepareRequestHeaders(), body: { jsonrpc: JSONRPC_VERSION, id: request.id, @@ -235,10 +218,9 @@ DispatcherService dispatcherService = isolated service object { lock { // Check if initialized if !self.isInitialized { - JsonRpcError & readonly jsonRpcError = self.createJsonRpcError(INVALID_REQUEST, - "Client must be initialized before making requests", request.id); return { - body: jsonRpcError + body: self.createJsonRpcError(INVALID_REQUEST, + "Client must be initialized before making requests", request.id) }; } } @@ -246,27 +228,23 @@ DispatcherService dispatcherService = isolated service object { // Extract and validate parameters CallToolParams|error params = request.params.cloneWithType(); if params is error { - JsonRpcError jsonRpcError = self.createJsonRpcError(INVALID_PARAMS, - string `Invalid parameters: ${params.message()}`, request.id); return { - body: jsonRpcError + body: self.createJsonRpcError(INVALID_PARAMS, + string `Invalid parameters: ${params.message()}`, request.id) }; } CallToolResult|error callToolResult = self.executeOnCallTool(params); if callToolResult is error { - JsonRpcError jsonRpcError = self.createJsonRpcError(INTERNAL_ERROR, - string `Failed to call tool '${params.name}': ${callToolResult.message()}`, request.id); return { - body: jsonRpcError + body: self.createJsonRpcError(INTERNAL_ERROR, + string `Failed to call tool '${params.name}': ${callToolResult.message()}`, request.id) }; } lock { return { - headers: { - [SESSION_ID_HEADER]: self.sessionId ?: "" - }, + headers: self.prepareRequestHeaders(), body: { jsonrpc: JSONRPC_VERSION, id: request.id, @@ -276,6 +254,13 @@ DispatcherService dispatcherService = isolated service object { } } + private isolated function prepareRequestHeaders() returns map { + lock { + string? currentSessionId = self.sessionId; + return currentSessionId is string ? {[SESSION_ID_HEADER]: currentSessionId} : {}; + } + } + private isolated function selectProtocolVersion(string requestedVersion) returns string { foreach string supportedVersion in SUPPORTED_PROTOCOL_VERSIONS { if supportedVersion == requestedVersion { @@ -285,16 +270,14 @@ DispatcherService dispatcherService = isolated service object { return LATEST_PROTOCOL_VERSION; } - private isolated function createJsonRpcError(int code, string message, RequestId? id = ()) returns JsonRpcError & readonly { - return { - jsonrpc: JSONRPC_VERSION, - id: id, - 'error: { - code: code, - message: message - } - }; - } + private isolated function createJsonRpcError(int code, string message, RequestId? id = ()) returns JsonRpcError & readonly => { + jsonrpc: JSONRPC_VERSION, + id: id, + 'error: { + code: code, + message: message + } + }; private isolated function executeOnListTools() returns ListToolsResult|error { lock { diff --git a/ballerina/listener.bal b/ballerina/listener.bal index 768f6a6..25ead4b 100644 --- a/ballerina/listener.bal +++ b/ballerina/listener.bal @@ -22,9 +22,8 @@ public type ListenerConfiguration record {| |}; # A server listener for handling MCP service requests. -public class Listener { +public isolated class Listener { private http:Listener httpListener; - private DispatcherService dispatcherService; # Initializes the Listener. # @@ -41,7 +40,6 @@ public class Listener { } self.httpListener = httpListener; } - self.dispatcherService = dispatcherService; } # Attaches an MCP service to the listener under the specified path(s). @@ -50,11 +48,13 @@ public class Listener { # + name - Path(s) to mount the service on (string or string array). # + return - error? if attachment fails. public isolated function attach(Service|AdvancedService mcpService, string[]|string? name = ()) returns Error? { - error? result = self.httpListener.attach(self.dispatcherService, name); - if result is error { - return error("Failed to attach MCP service: " + result.message()); + lock { + error? result = self.httpListener.attach(dispatcherService, name.cloneReadOnly()); + if result is error { + return error("Failed to attach MCP service: " + result.message()); + } + dispatcherService.addServiceRef(mcpService); } - self.dispatcherService.addServiceRef(mcpService); } # Detaches the MCP service from the listener. @@ -62,20 +62,24 @@ public class Listener { # + mcpService - Service to detach. # + return - error? if detachment fails. public isolated function detach(Service|AdvancedService mcpService) returns Error? { - error? result = self.httpListener.detach(self.dispatcherService); - if result is error { - return error("Failed to detach MCP service: " + result.message()); + lock { + error? result = self.httpListener.detach(dispatcherService); + if result is error { + return error("Failed to detach MCP service: " + result.message()); + } + dispatcherService.removeServiceRef(); } - self.dispatcherService.removeServiceRef(); } # Starts the listener (begin accepting connections). # # + return - error? if starting fails. public isolated function 'start() returns Error? { - error? result = self.httpListener.start(); - if result is error { - return error("Failed to start MCP listener: " + result.message()); + lock { + error? result = self.httpListener.start(); + if result is error { + return error("Failed to start MCP listener: " + result.message()); + } } } @@ -83,9 +87,11 @@ public class Listener { # # + return - error? if graceful stop fails. public isolated function gracefulStop() returns Error? { - error? result = self.httpListener.gracefulStop(); - if result is error { - return error("Failed to gracefully stop MCP listener: " + result.message()); + lock { + error? result = self.httpListener.gracefulStop(); + if result is error { + return error("Failed to gracefully stop MCP listener: " + result.message()); + } } } @@ -93,9 +99,11 @@ public class Listener { # # + return - error? if immediate stop fails. public isolated function immediateStop() returns Error? { - error? result = self.httpListener.immediateStop(); - if result is error { - return error("Failed to immediately stop MCP listener: " + result.message()); + lock { + error? result = self.httpListener.immediateStop(); + if result is error { + return error("Failed to immediately stop MCP listener: " + result.message()); + } } } } diff --git a/ballerina/types.bal b/ballerina/types.bal index 9d2a921..f993410 100644 --- a/ballerina/types.bal +++ b/ballerina/types.bal @@ -26,15 +26,19 @@ public const SUPPORTED_PROTOCOL_VERSIONS = [ public const JSONRPC_VERSION = "2.0"; -# Request methods +# Standard JSON-RPC request methods supported by the MCP protocol public enum RequestMethod { + # Initialize request to start a new MCP session REQUEST_INITIALIZE = "initialize", + # Request to list all available tools from the server REQUEST_LIST_TOOLS = "tools/list", + # Request to execute a specific tool with given parameters REQUEST_CALL_TOOL = "tools/call" }; -# Notification methods +# JSON-RPC notification methods used for one-way communication in MCP public enum NotificationMethod { + # Notification sent by client to indicate successful initialization NOTIFICATION_INITIALIZED = "notifications/initialized" }; @@ -109,14 +113,21 @@ public type JsonRpcResponse record {| |}; // Standard JSON-RPC error codes +# Standard JSON-RPC 2.0 error code indicating that the JSON sent is not a valid JSON object. public const PARSE_ERROR = -32700; +# Standard JSON-RPC 2.0 error code indicating that the JSON sent is not a valid request object. public const INVALID_REQUEST = -32600; +# Standard JSON-RPC 2.0 error code indicating that the method does not exist or is not available. public const METHOD_NOT_FOUND = -32601; +# Standard JSON-RPC 2.0 error code indicating that invalid method parameters were provided. public const INVALID_PARAMS = -32602; +# Standard JSON-RPC 2.0 error code indicating that an internal error occurred on the server. public const INTERNAL_ERROR = -32603; // Library-defined error codes +# MCP library-defined error code indicating that the request is not acceptable. public const NOT_ACCEPTABLE = -32001; +# MCP library-defined error code indicating that the media type is not supported. public const UNSUPPORTED_MEDIA_TYPE = -32002; # A response to a request that indicates an error occurred. diff --git a/ballerina/utils.bal b/ballerina/utils.bal index 9e5f2de..d9fcb45 100644 --- a/ballerina/utils.bal +++ b/ballerina/utils.bal @@ -80,13 +80,10 @@ isolated function extractResultFromMessage(JsonRpcMessage message) returns Serve isolated function getServiceConfiguration(Service|AdvancedService mcpService) returns ServiceConfiguration { typedesc mcpServiceType = typeof mcpService; ServiceConfiguration? serviceConfig = mcpServiceType.@ServiceConfig; - if serviceConfig is () { - return { - info: { - name: "MCP Service", - version: "1.0.0" - } - }; - } - return serviceConfig; + return serviceConfig ?: { + info: { + name: "MCP Service", + version: "1.0.0" + } + }; } diff --git a/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/McpSourceModifier.java b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/McpSourceModifier.java index ed6ddb8..1732272 100644 --- a/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/McpSourceModifier.java +++ b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/McpSourceModifier.java @@ -21,8 +21,6 @@ import io.ballerina.compiler.api.SemanticModel; import io.ballerina.compiler.syntax.tree.AnnotationNode; import io.ballerina.compiler.syntax.tree.FunctionDefinitionNode; -import io.ballerina.compiler.syntax.tree.IdentifierToken; -import io.ballerina.compiler.syntax.tree.MappingConstructorExpressionNode; import io.ballerina.compiler.syntax.tree.MetadataNode; import io.ballerina.compiler.syntax.tree.ModuleMemberDeclarationNode; import io.ballerina.compiler.syntax.tree.ModulePartNode; @@ -30,10 +28,8 @@ import io.ballerina.compiler.syntax.tree.NodeFactory; import io.ballerina.compiler.syntax.tree.NodeList; import io.ballerina.compiler.syntax.tree.NodeParser; -import io.ballerina.compiler.syntax.tree.QualifiedNameReferenceNode; import io.ballerina.compiler.syntax.tree.ServiceDeclarationNode; import io.ballerina.compiler.syntax.tree.SyntaxTree; -import io.ballerina.compiler.syntax.tree.Token; import io.ballerina.projects.DocumentId; import io.ballerina.projects.Module; import io.ballerina.projects.plugins.ModifierTask; @@ -46,9 +42,7 @@ import java.util.Map; import java.util.Optional; -import static io.ballerina.compiler.syntax.tree.SyntaxKind.AT_TOKEN; import static io.ballerina.compiler.syntax.tree.SyntaxKind.CLOSE_BRACE_TOKEN; -import static io.ballerina.compiler.syntax.tree.SyntaxKind.COLON_TOKEN; import static io.ballerina.compiler.syntax.tree.SyntaxKind.OBJECT_METHOD_DEFINITION; import static io.ballerina.compiler.syntax.tree.SyntaxKind.OPEN_BRACE_TOKEN; import static io.ballerina.compiler.syntax.tree.SyntaxKind.SERVICE_DECLARATION; @@ -109,19 +103,9 @@ private Map getModifiedAnnotations(Modif } private AnnotationNode getModifiedAnnotation(ToolAnnotationConfig config) { - Token atToken = NodeFactory.createToken(AT_TOKEN); - - Token modulePrefix = NodeFactory.createIdentifierToken(MCP_PACKAGE_NAME); - Token colonToken = NodeFactory.createToken(COLON_TOKEN); - IdentifierToken identifier = NodeFactory.createIdentifierToken(TOOL_ANNOTATION_NAME); - QualifiedNameReferenceNode annotationReferenceNode = - NodeFactory.createQualifiedNameReferenceNode(modulePrefix, colonToken, identifier); - String mappingConstructorExpression = generateConfigMappingConstructor(config); - MappingConstructorExpressionNode mappingConstructorNode = (MappingConstructorExpressionNode) NodeParser - .parseExpression(mappingConstructorExpression); - - return NodeFactory.createAnnotationNode(atToken, annotationReferenceNode, mappingConstructorNode); + String annotationString = "@" + MCP_PACKAGE_NAME + ":" + TOOL_ANNOTATION_NAME + mappingConstructorExpression; + return NodeParser.parseAnnotation(annotationString); } private String generateConfigMappingConstructor(ToolAnnotationConfig config) { @@ -160,34 +144,46 @@ private ModuleMemberDeclarationNode modifyServiceDeclaration( ArrayList modifiedMembers = new ArrayList<>(); for (Node member : members) { - if (member.kind() == OBJECT_METHOD_DEFINITION) { - FunctionDefinitionNode functionDefinitionNode = (FunctionDefinitionNode) member; - AnnotationNode modifiedAnnotationNode = modifiedAnnotations.get(functionDefinitionNode); - if (functionDefinitionNode.metadata().isPresent()) { - MetadataNode functionMetadata = functionDefinitionNode.metadata().get(); - Optional toolAnnotationNode = - getToolAnnotationNode(semanticModel, functionDefinitionNode); - if (toolAnnotationNode.isPresent()) { - MetadataNode modifiedMetadata = modifyMetadata(functionMetadata, toolAnnotationNode.get(), - modifiedAnnotationNode); - functionDefinitionNode = functionDefinitionNode.modify().withMetadata(modifiedMetadata).apply(); - } else { - functionDefinitionNode = functionDefinitionNode.modify() - .withMetadata(modifyWithToolAnnotation(functionMetadata, modifiedAnnotationNode)) - .apply(); - } - } else { - functionDefinitionNode = functionDefinitionNode.modify() - .withMetadata(createMetadata(modifiedAnnotationNode)).apply(); - } - modifiedMembers.add(functionDefinitionNode); - } else { + if (member.kind() != OBJECT_METHOD_DEFINITION) { modifiedMembers.add(member); + continue; } + + FunctionDefinitionNode functionDefinitionNode = (FunctionDefinitionNode) member; + AnnotationNode modifiedAnnotationNode = modifiedAnnotations.get(functionDefinitionNode); + + MetadataNode newMetadata = createOrUpdateMetadata( + semanticModel, functionDefinitionNode, modifiedAnnotationNode); + + FunctionDefinitionNode updatedFunction = functionDefinitionNode.modify() + .withMetadata(newMetadata) + .apply(); + + modifiedMembers.add(updatedFunction); } + return classDefinitionNode.modify().withMembers(NodeFactory.createNodeList(modifiedMembers)).apply(); } + private MetadataNode createOrUpdateMetadata( + SemanticModel semanticModel, + FunctionDefinitionNode functionDefinitionNode, + AnnotationNode modifiedAnnotationNode) { + + if (functionDefinitionNode.metadata().isEmpty()) { + return createMetadata(modifiedAnnotationNode); + } + + MetadataNode existingMetadata = functionDefinitionNode.metadata().get(); + Optional toolAnnotationNode = getToolAnnotationNode(semanticModel, functionDefinitionNode); + + if (toolAnnotationNode.isPresent()) { + return modifyMetadata(existingMetadata, toolAnnotationNode.get(), modifiedAnnotationNode); + } else { + return modifyWithToolAnnotation(existingMetadata, modifiedAnnotationNode); + } + } + private MetadataNode modifyWithToolAnnotation(MetadataNode metadata, AnnotationNode annotationNode) { List updatedAnnotations = new ArrayList<>(); metadata.annotations().forEach(updatedAnnotations::add); From bf8c9a1b8bc5dc9354a2a2c5ca0c1d224c349be0 Mon Sep 17 00:00:00 2001 From: Azeem Muzammil Date: Thu, 3 Jul 2025 14:36:01 +0530 Subject: [PATCH 22/31] Fix the multi session --- ballerina/dispatcher_service.bal | 140 ++++++++++++++++++++----------- 1 file changed, 91 insertions(+), 49 deletions(-) diff --git a/ballerina/dispatcher_service.bal b/ballerina/dispatcher_service.bal index 24211a6..1fb5b60 100644 --- a/ballerina/dispatcher_service.bal +++ b/ballerina/dispatcher_service.bal @@ -26,8 +26,7 @@ type DispatcherService distinct isolated service object { final DispatcherService dispatcherService = isolated service object { private Service|AdvancedService? mcpService = (); - private boolean isInitialized = false; - private string? sessionId = (); + private map sessionMap = {}; isolated function addServiceRef(Service|AdvancedService mcpService) { lock { @@ -41,6 +40,34 @@ final DispatcherService dispatcherService = isolated service object { } } + isolated resource function delete .(http:Headers headers) returns http:BadRequest|http:Ok { + string? sessionId = self.getSessionIdFromHeaders(headers); + if sessionId is () { + return { + body: self.createJsonRpcError(INVALID_REQUEST, "Missing session ID header") + }; + } + + lock { + if !self.sessionMap.hasKey(sessionId) { + return { + body: self.createJsonRpcError(INVALID_REQUEST, string `Invalid session ID: ${sessionId}`) + }; + } + + _ = self.sessionMap.remove(sessionId); + } + + return { + body: { + jsonrpc: JSONRPC_VERSION, + result: { + message: string `Session ${sessionId} deleted successfully` + } + } + }; + } + isolated resource function post .(@http:Payload JsonRpcMessage request, http:Headers headers) returns http:BadRequest|http:NotAcceptable|http:UnsupportedMediaType|http:Accepted|http:Ok { http:NotAcceptable|http:UnsupportedMediaType? headerValidationError = self.validateHeaders(headers); @@ -49,7 +76,7 @@ final DispatcherService dispatcherService = isolated service object { } if request is JsonRpcRequest { - return self.processJsonRpcRequest(request); + return self.processJsonRpcRequest(request, headers); } if request is JsonRpcNotification { return self.processJsonRpcNotification(request); @@ -97,16 +124,16 @@ final DispatcherService dispatcherService = isolated service object { return; } - private isolated function processJsonRpcRequest(JsonRpcRequest request) returns http:BadRequest|http:Ok { + private isolated function processJsonRpcRequest(JsonRpcRequest request, http:Headers headers) returns http:BadRequest|http:Ok { match request.method { REQUEST_INITIALIZE => { - return self.handleInitializeRequest(request); + return self.handleInitializeRequest(request, headers); } REQUEST_LIST_TOOLS => { - return self.handleListToolsRequest(request); + return self.handleListToolsRequest(request, headers); } REQUEST_CALL_TOOL => { - return self.handleCallToolRequest(request); + return self.handleCallToolRequest(request, headers); } _ => { return { @@ -132,7 +159,12 @@ final DispatcherService dispatcherService = isolated service object { }; } - private isolated function handleInitializeRequest(JsonRpcRequest jsonRpcRequest) returns http:BadRequest|http:Ok { + private isolated function getSessionIdFromHeaders(http:Headers headers) returns string? { + string|http:HeaderNotFoundError sessionHeader = headers.getHeader(SESSION_ID_HEADER); + return sessionHeader is string ? sessionHeader : (); + } + + private isolated function handleInitializeRequest(JsonRpcRequest jsonRpcRequest, http:Headers headers) returns http:BadRequest|http:Ok { JsonRpcRequest {jsonrpc: _, id, ...request} = jsonRpcRequest; InitializeRequest|error initRequest = request.cloneWithType(); if initRequest is error { @@ -142,13 +174,15 @@ final DispatcherService dispatcherService = isolated service object { }; } + // Check if there's a session ID in the headers + string? existingSessionId = self.getSessionIdFromHeaders(headers); + lock { - // If it's a server with session management and the session ID is already set we should reject the request - // to avoid re-initialization. - if self.isInitialized && self.sessionId != () { + // If there's an existing session ID and it's already in the map, return error + if existingSessionId is string && self.sessionMap.hasKey(existingSessionId) { return { body: self.createJsonRpcError(INVALID_REQUEST, - "Invalid Request: Only one initialization request is allowed", id) + string `Session already initialized: ${existingSessionId}`, id) }; } @@ -162,14 +196,15 @@ final DispatcherService dispatcherService = isolated service object { ServiceConfiguration serviceConfig = getServiceConfiguration(mcpService); - self.isInitialized = true; - self.sessionId = uuid:createRandomUuid(); + // Create new session ID + string newSessionId = uuid:createRandomUuid(); + self.sessionMap[newSessionId] = "initialized"; string requestedVersion = initRequest.params.protocolVersion; string protocolVersion = self.selectProtocolVersion(requestedVersion); return { - headers: self.prepareRequestHeaders(), + headers: {[SESSION_ID_HEADER]: newSessionId}, body: { jsonrpc: JSONRPC_VERSION, id: id, @@ -183,13 +218,22 @@ final DispatcherService dispatcherService = isolated service object { } } - private isolated function handleListToolsRequest(JsonRpcRequest request) returns http:BadRequest|http:Ok { + private isolated function handleListToolsRequest(JsonRpcRequest request, http:Headers headers) returns http:BadRequest|http:Ok { + // Validate session ID + string? sessionId = self.getSessionIdFromHeaders(headers); + if sessionId is () { + return { + body: self.createJsonRpcError(INVALID_REQUEST, + "Missing session ID header", request.id) + }; + } + lock { - // Check if initialized - if !self.isInitialized { + // Check if session exists + if !self.sessionMap.hasKey(sessionId) { return { body: self.createJsonRpcError(INVALID_REQUEST, - "Client must be initialized before making requests", request.id) + string `Invalid session ID: ${sessionId}`, request.id) }; } } @@ -202,25 +246,32 @@ final DispatcherService dispatcherService = isolated service object { }; } - lock { - return { - headers: self.prepareRequestHeaders(), - body: { - jsonrpc: JSONRPC_VERSION, - id: request.id, - result: listToolsResult.cloneReadOnly() - } + return { + headers: {[SESSION_ID_HEADER]: sessionId}, + body: { + jsonrpc: JSONRPC_VERSION, + id: request.id, + result: listToolsResult.cloneReadOnly() + } + }; + } + + private isolated function handleCallToolRequest(JsonRpcRequest request, http:Headers headers) returns http:BadRequest|http:Ok { + // Validate session ID + string? sessionId = self.getSessionIdFromHeaders(headers); + if sessionId is () { + return { + body: self.createJsonRpcError(INVALID_REQUEST, + "Missing session ID header", request.id) }; } - } - private isolated function handleCallToolRequest(JsonRpcRequest request) returns http:BadRequest|http:Ok { lock { - // Check if initialized - if !self.isInitialized { + // Check if session exists + if !self.sessionMap.hasKey(sessionId) { return { body: self.createJsonRpcError(INVALID_REQUEST, - "Client must be initialized before making requests", request.id) + string `Invalid session ID: ${sessionId}`, request.id) }; } } @@ -242,23 +293,14 @@ final DispatcherService dispatcherService = isolated service object { }; } - lock { - return { - headers: self.prepareRequestHeaders(), - body: { - jsonrpc: JSONRPC_VERSION, - id: request.id, - result: callToolResult.cloneReadOnly() - } - }; - } - } - - private isolated function prepareRequestHeaders() returns map { - lock { - string? currentSessionId = self.sessionId; - return currentSessionId is string ? {[SESSION_ID_HEADER]: currentSessionId} : {}; - } + return { + headers: {[SESSION_ID_HEADER]: sessionId}, + body: { + jsonrpc: JSONRPC_VERSION, + id: request.id, + result: callToolResult.cloneReadOnly() + } + }; } private isolated function selectProtocolVersion(string requestedVersion) returns string { From 35dfa0826f23a8652ac859ca2d7c3be9eba42b47 Mon Sep 17 00:00:00 2001 From: Azeem Muzammil Date: Thu, 3 Jul 2025 17:15:05 +0530 Subject: [PATCH 23/31] Remove isolated from servie type --- ballerina/dispatcher_service.bal | 71 ++++++------------- ballerina/listener.bal | 14 ++-- ballerina/native_listener_helper.bal | 10 +++ ballerina/types.bal | 10 +-- .../io/ballerina/stdlib/mcp/plugin/Utils.java | 2 +- .../stdlib/mcp/McpServiceMethodHelper.java | 37 +++++++++- 6 files changed, 82 insertions(+), 62 deletions(-) diff --git a/ballerina/dispatcher_service.bal b/ballerina/dispatcher_service.bal index 1fb5b60..7cdd0e0 100644 --- a/ballerina/dispatcher_service.bal +++ b/ballerina/dispatcher_service.bal @@ -17,29 +17,11 @@ import ballerina/http; import ballerina/uuid; -type DispatcherService distinct isolated service object { +isolated service class DispatcherService { *http:Service; - isolated function addServiceRef(Service|AdvancedService mcpService); - isolated function removeServiceRef(); -}; - -final DispatcherService dispatcherService = isolated service object { - private Service|AdvancedService? mcpService = (); private map sessionMap = {}; - isolated function addServiceRef(Service|AdvancedService mcpService) { - lock { - self.mcpService = mcpService; - } - } - - isolated function removeServiceRef() { - lock { - self.mcpService = (); - } - } - isolated resource function delete .(http:Headers headers) returns http:BadRequest|http:Ok { string? sessionId = self.getSessionIdFromHeaders(headers); if sessionId is () { @@ -69,7 +51,7 @@ final DispatcherService dispatcherService = isolated service object { } isolated resource function post .(@http:Payload JsonRpcMessage request, http:Headers headers) - returns http:BadRequest|http:NotAcceptable|http:UnsupportedMediaType|http:Accepted|http:Ok { + returns http:BadRequest|http:NotAcceptable|http:UnsupportedMediaType|http:Accepted|http:Ok|Error { http:NotAcceptable|http:UnsupportedMediaType? headerValidationError = self.validateHeaders(headers); if headerValidationError !is () { return headerValidationError; @@ -124,7 +106,7 @@ final DispatcherService dispatcherService = isolated service object { return; } - private isolated function processJsonRpcRequest(JsonRpcRequest request, http:Headers headers) returns http:BadRequest|http:Ok { + private isolated function processJsonRpcRequest(JsonRpcRequest request, http:Headers headers) returns http:BadRequest|http:Ok|Error { match request.method { REQUEST_INITIALIZE => { return self.handleInitializeRequest(request, headers); @@ -164,7 +146,7 @@ final DispatcherService dispatcherService = isolated service object { return sessionHeader is string ? sessionHeader : (); } - private isolated function handleInitializeRequest(JsonRpcRequest jsonRpcRequest, http:Headers headers) returns http:BadRequest|http:Ok { + private isolated function handleInitializeRequest(JsonRpcRequest jsonRpcRequest, http:Headers headers) returns http:BadRequest|http:Ok|Error { JsonRpcRequest {jsonrpc: _, id, ...request} = jsonRpcRequest; InitializeRequest|error initRequest = request.cloneWithType(); if initRequest is error { @@ -186,14 +168,7 @@ final DispatcherService dispatcherService = isolated service object { }; } - Service|AdvancedService? mcpService = self.mcpService; - if mcpService is () { - return { - body: self.createJsonRpcError(INTERNAL_ERROR, - "Internal Error: MCP Service is not attached", id) - }; - } - + Service|AdvancedService mcpService = check getMcpServiceFromDispatcher(self); ServiceConfiguration serviceConfig = getServiceConfiguration(mcpService); // Create new session ID @@ -321,29 +296,25 @@ final DispatcherService dispatcherService = isolated service object { } }; - private isolated function executeOnListTools() returns ListToolsResult|error { - lock { - Service|AdvancedService? mcpService = self.mcpService; - if mcpService is AdvancedService { - return invokeOnListTools(mcpService); - } - if mcpService is Service { - return listToolsForRemoteFunctions(mcpService); - } - return error DispatcherError("MCP Service is not attached"); + private isolated function executeOnListTools() returns ListToolsResult|Error { + Service|AdvancedService mcpService = check getMcpServiceFromDispatcher(self); + if mcpService is AdvancedService { + return invokeOnListTools(mcpService); } + if mcpService is Service { + return listToolsForRemoteFunctions(mcpService); + } + return error DispatcherError("MCP Service is not attached"); } - private isolated function executeOnCallTool(CallToolParams params) returns CallToolResult|error { - lock { - Service|AdvancedService? mcpService = self.mcpService; - if mcpService is AdvancedService { - return invokeOnCallTool(mcpService, params.cloneReadOnly()); - } - if mcpService is Service { - return callToolForRemoteFunctions(mcpService, params.cloneReadOnly()); - } - return error DispatcherError("MCP Service is not attached"); + private isolated function executeOnCallTool(CallToolParams params) returns CallToolResult|Error { + Service|AdvancedService mcpService = check getMcpServiceFromDispatcher(self); + if mcpService is AdvancedService { + return invokeOnCallTool(mcpService, params.cloneReadOnly()); + } + if mcpService is Service { + return callToolForRemoteFunctions(mcpService, params.cloneReadOnly()); } + return error DispatcherError("MCP Service is not attached"); } }; diff --git a/ballerina/listener.bal b/ballerina/listener.bal index 25ead4b..8d9c4b6 100644 --- a/ballerina/listener.bal +++ b/ballerina/listener.bal @@ -24,6 +24,7 @@ public type ListenerConfiguration record {| # A server listener for handling MCP service requests. public isolated class Listener { private http:Listener httpListener; + private DispatcherService[] dispatcherServices = []; # Initializes the Listener. # @@ -48,12 +49,13 @@ public isolated class Listener { # + name - Path(s) to mount the service on (string or string array). # + return - error? if attachment fails. public isolated function attach(Service|AdvancedService mcpService, string[]|string? name = ()) returns Error? { + DispatcherService dispatcherService = check new (); + check addMcpServiceToDispatcher(dispatcherService, mcpService); lock { error? result = self.httpListener.attach(dispatcherService, name.cloneReadOnly()); if result is error { return error("Failed to attach MCP service: " + result.message()); } - dispatcherService.addServiceRef(mcpService); } } @@ -63,11 +65,13 @@ public isolated class Listener { # + return - error? if detachment fails. public isolated function detach(Service|AdvancedService mcpService) returns Error? { lock { - error? result = self.httpListener.detach(dispatcherService); - if result is error { - return error("Failed to detach MCP service: " + result.message()); + foreach DispatcherService dispatcherService in self.dispatcherServices { + error? result = self.httpListener.detach(dispatcherService); + if result is error { + return error("Failed to detach MCP service: " + result.message()); + } } - dispatcherService.removeServiceRef(); + self.dispatcherServices = []; } } diff --git a/ballerina/native_listener_helper.bal b/ballerina/native_listener_helper.bal index 83553a0..80ccf5a 100644 --- a/ballerina/native_listener_helper.bal +++ b/ballerina/native_listener_helper.bal @@ -34,3 +34,13 @@ isolated function callToolForRemoteFunctions(Service 'service, CallToolParams pa returns t|Error = @java:Method { 'class: "io.ballerina.stdlib.mcp.McpServiceMethodHelper" } external; + +isolated function addMcpServiceToDispatcher(DispatcherService dispatcherService, Service|AdvancedService mcpService) + returns Error? = @java:Method { + 'class: "io.ballerina.stdlib.mcp.McpServiceMethodHelper" +} external; + +isolated function getMcpServiceFromDispatcher(DispatcherService dispatcherService) + returns Service|AdvancedService|Error = @java:Method { + 'class: "io.ballerina.stdlib.mcp.McpServiceMethodHelper" +} external; diff --git a/ballerina/types.bal b/ballerina/types.bal index f993410..747267c 100644 --- a/ballerina/types.bal +++ b/ballerina/types.bal @@ -298,7 +298,7 @@ public type ListToolsRequest record {| public type ListToolsResult record { *PaginatedResult; # A list of tools available on the server. - Tool[] tools; + McpTool[] tools; }; # The server's response to a tool call. @@ -354,7 +354,7 @@ public type ToolAnnotations record { }; # Definition for a tool the client can call. -public type Tool record { +public type McpTool record { # The name of the tool string name; # A human-readable description of the tool @@ -427,7 +427,7 @@ public type McpToolConfig record {| |}; # Annotation to mark a function as an MCP tool configuration. -public annotation McpToolConfig McpTool on object function; +public annotation McpToolConfig Tool on object function; # Represents the options for configuring an MCP server. public type ServerOptions record {| @@ -451,12 +451,12 @@ public type ServiceConfiguration record {| public annotation ServiceConfiguration ServiceConfig on service; # Defines a mcp service interface that handles incoming mcp requests. -public type AdvancedService distinct isolated service object { +public type AdvancedService distinct service object { remote isolated function onListTools() returns ListToolsResult|ServerError; remote isolated function onCallTool(CallToolParams params) returns CallToolResult|ServerError; }; # Defines a basic mcp service interface that handles incoming mcp requests. -public type Service distinct isolated service object { +public type Service distinct service object { }; diff --git a/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/Utils.java b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/Utils.java index b1ffc9d..d1542bb 100644 --- a/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/Utils.java +++ b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/Utils.java @@ -36,7 +36,7 @@ */ public class Utils { public static final String BALLERINA_ORG = "ballerina"; - public static final String TOOL_ANNOTATION_NAME = "McpTool"; + public static final String TOOL_ANNOTATION_NAME = "Tool"; public static final String MCP_PACKAGE_NAME = "mcp"; private Utils() { diff --git a/native/src/main/java/io/ballerina/stdlib/mcp/McpServiceMethodHelper.java b/native/src/main/java/io/ballerina/stdlib/mcp/McpServiceMethodHelper.java index 2ccb316..8094af4 100644 --- a/native/src/main/java/io/ballerina/stdlib/mcp/McpServiceMethodHelper.java +++ b/native/src/main/java/io/ballerina/stdlib/mcp/McpServiceMethodHelper.java @@ -56,9 +56,10 @@ public final class McpServiceMethodHelper { private static final String TYPE_FIELD_NAME = "type"; private static final String TEXT_FIELD_NAME = "text"; - private static final String ANNOTATION_MCP_TOOL = "McpTool"; + private static final String ANNOTATION_MCP_TOOL = "Tool"; private static final String TYPE_TEXT_CONTENT = "TextContent"; private static final String TEXT_VALUE_NAME = "text"; + private static final String MCP_SERVICE_FIELD = "mcpService"; private McpServiceMethodHelper() {} @@ -139,6 +140,40 @@ public static Object callToolForRemoteFunctions(Environment env, BObject mcpServ return createCallToolResult(typed, result); } + /** + * Adds an MCP service to the dispatcher service by storing it in a private field. + * + * @param dispatcherService The dispatcher service object. + * @param mcpService The MCP service object to store. + * @return null if successful, error otherwise. + */ + public static Object addMcpServiceToDispatcher(BObject dispatcherService, BObject mcpService) { + try { + dispatcherService.addNativeData(MCP_SERVICE_FIELD, mcpService); + return null; + } catch (Exception e) { + return ModuleUtils.createError("Failed to add MCP service to dispatcher: " + e.getMessage()); + } + } + + /** + * Retrieves the MCP service from the dispatcher service. + * + * @param dispatcherService The dispatcher service object. + * @return The MCP service object or an error if not found. + */ + public static Object getMcpServiceFromDispatcher(BObject dispatcherService) { + try { + Object mcpService = dispatcherService.getNativeData(MCP_SERVICE_FIELD); + if (mcpService == null) { + return ModuleUtils.createError("MCP service not found in dispatcher"); + } + return mcpService; + } catch (Exception e) { + return ModuleUtils.createError("Failed to get MCP service from dispatcher: " + e.getMessage()); + } + } + private static List getRemoteMethods(BObject mcpService) { ServiceType serviceType = (ServiceType) mcpService.getOriginalType(); return List.of(serviceType.getRemoteMethods()); From cf3cc63126bdc38a8ebbc0c495e1e37de71ef7d2 Mon Sep 17 00:00:00 2001 From: Azeem Muzammil Date: Thu, 3 Jul 2025 17:16:14 +0530 Subject: [PATCH 24/31] Fix warnings --- ballerina/listener.bal | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ballerina/listener.bal b/ballerina/listener.bal index 8d9c4b6..26e7321 100644 --- a/ballerina/listener.bal +++ b/ballerina/listener.bal @@ -49,7 +49,7 @@ public isolated class Listener { # + name - Path(s) to mount the service on (string or string array). # + return - error? if attachment fails. public isolated function attach(Service|AdvancedService mcpService, string[]|string? name = ()) returns Error? { - DispatcherService dispatcherService = check new (); + DispatcherService dispatcherService = new (); check addMcpServiceToDispatcher(dispatcherService, mcpService); lock { error? result = self.httpListener.attach(dispatcherService, name.cloneReadOnly()); From 759a7a7c99ecb69ccc38fb00da11a056f35c962f Mon Sep 17 00:00:00 2001 From: Azeem Muzammil Date: Fri, 4 Jul 2025 11:16:44 +0530 Subject: [PATCH 25/31] Move initialization into init --- ballerina/client.bal | 175 ++++++++++++++++--------------------------- 1 file changed, 63 insertions(+), 112 deletions(-) diff --git a/ballerina/client.bal b/ballerina/client.bal index 2cc17fe..a716ec7 100644 --- a/ballerina/client.bal +++ b/ballerina/client.bal @@ -33,15 +33,8 @@ public type ClientCapabilityConfiguration record {| # Represents an MCP client built on top of the Streamable HTTP transport. public distinct isolated client class Client { - # MCP server URL. - private final string serverUrl; - # Client implementation details (e.g., name and version). - private final Implementation clientInfo; - # Capabilities supported by the client. - private final ClientCapabilities clientCapabilities; - # Transport for communication with the MCP server. - private StreamableHttpClientTransport? transport = (); + private StreamableHttpClientTransport transport; # Server capabilities. private ServerCapabilities? serverCapabilities = (); # Server implementation information. @@ -49,71 +42,61 @@ public distinct isolated client class Client { # Request ID generator for tracking requests. private int requestId = 0; - # Initializes a new MCP client with the provided server URL and client details. - # - # + serverUrl - MCP server URL. - # + clientInfo - Client details, such as name and version. - # + config - Optional configuration containing client capabilities. - public isolated function init(string serverUrl, *ClientConfiguration config) { - self.serverUrl = serverUrl; - self.clientInfo = config.info.cloneReadOnly(); - self.clientCapabilities = (config.capabilityConfig?.capabilities).cloneReadOnly() ?: {}; - } - - # Establishes a connection to the MCP server and performs protocol initialization. + # Initializes a new MCP client and establishes connection to the server. + # Performs protocol handshake and capability exchange. Client is ready for use after construction. # - # + return - A ClientError if initialization fails, or nil on success. - isolated remote function initialize() returns ClientError? { - lock { - // Create and initialize transport. - StreamableHttpClientTransport newTransport = check new StreamableHttpClientTransport(self.serverUrl); - self.transport = newTransport; - - string? sessionId = newTransport.getSessionId(); + # + serverUrl - MCP server URL + # + config - Client configuration including info and capabilities + # + return - ClientError if initialization fails, nil on success + public isolated function init(string serverUrl, *ClientConfiguration config) returns ClientError? { + // Create and initialize transport. + StreamableHttpClientTransport newTransport = check new StreamableHttpClientTransport(serverUrl); + self.transport = newTransport; + + string? sessionId = newTransport.getSessionId(); + + // If a session ID exists, assume reconnection and skip initialization. + if sessionId is string { + return; + } - // If a session ID exists, assume reconnection and skip initialization. - if sessionId is string { - return; + // Prepare and send the initialization request. + InitializeRequest initRequest = { + method: "initialize", + params: { + protocolVersion: LATEST_PROTOCOL_VERSION, + capabilities: config.capabilityConfig?.capabilities ?: {}, + clientInfo: config.info } + }; - // Prepare and send the initialization request. - InitializeRequest initRequest = { - method: "initialize", - params: { - protocolVersion: LATEST_PROTOCOL_VERSION, - capabilities: self.clientCapabilities, - clientInfo: self.clientInfo - } - }; + ServerResult response = check self.sendRequestMessage(initRequest); - ServerResult response = check self.sendRequestMessage(initRequest); - - if response is InitializeResult { - final readonly & string protocolVersion = response.protocolVersion; - // Validate protocol compatibility. - if (!SUPPORTED_PROTOCOL_VERSIONS.some(v => v == protocolVersion)) { - return error ProtocolVersionError( - string `Server protocol version '${ - protocolVersion}' is not supported. Supported versions: ${ - SUPPORTED_PROTOCOL_VERSIONS.toString()}.` - ); - } - - // Store server capabilities and info. - self.serverCapabilities = response.capabilities; - self.serverInfo = response.serverInfo; - - // Send notification to complete initialization. - InitializedNotification initNotification = { - method: "notifications/initialized" - }; - check self.sendNotificationMessage(initNotification); - } else { - return error ClientInitializationError( - string `Initialization failed: unexpected response type '${ - (typeof response).toString()}' received from server.` + if response is InitializeResult { + final readonly & string protocolVersion = response.protocolVersion; + // Validate protocol compatibility. + if (!SUPPORTED_PROTOCOL_VERSIONS.some(v => v == protocolVersion)) { + return error ProtocolVersionError( + string `Server protocol version '${ + protocolVersion}' is not supported. Supported versions: ${ + SUPPORTED_PROTOCOL_VERSIONS.toString()}.` ); } + + // Store server capabilities and info. + self.serverCapabilities = response.capabilities.cloneReadOnly(); + self.serverInfo = response.serverInfo.cloneReadOnly(); + + // Send notification to complete initialization. + InitializedNotification initNotification = { + method: "notifications/initialized" + }; + check self.sendNotificationMessage(initNotification); + } else { + return error ClientInitializationError( + string `Initialization failed: unexpected response type '${ + (typeof response).toString()}' received from server.` + ); } } @@ -122,13 +105,7 @@ public distinct isolated client class Client { # + return - Stream of JsonRpcMessages or a ClientError. isolated remote function subscribeToServerMessages() returns stream|ClientError { lock { - StreamableHttpClientTransport? currentTransport = self.transport; - if currentTransport is () { - return error UninitializedTransportError( - "Subscription failed: client transport is not initialized." - ); - } - return currentTransport.establishEventStream(); + return self.transport.establishEventStream(); } } @@ -175,20 +152,10 @@ public distinct isolated client class Client { # + return - A ClientError if closure fails, or nil on success. isolated remote function close() returns ClientError? { lock { - StreamableHttpClientTransport? currentTransport = self.transport; - if currentTransport is () { - return error UninitializedTransportError( - "Closure failed: client transport is not initialized." - ); - } - do { - check currentTransport.terminateSession(); - lock { - self.transport = (); - self.serverCapabilities = (); - self.serverInfo = (); - } + check self.transport.terminateSession(); + self.serverCapabilities = (); + self.serverInfo = (); return; } on fail error e { return error ClientError(string `Failed to disconnect from server: ${e.message()}`, e); @@ -202,26 +169,17 @@ public distinct isolated client class Client { # + return - ServerResult, a stream of results, or a ClientError. private isolated function sendRequestMessage(Request request) returns ServerResult|ClientError { lock { - StreamableHttpClientTransport? currentTransport = self.transport; - if currentTransport is () { - return error UninitializedTransportError( - "Cannot send request: client transport is not initialized." - ); - } - - lock { - self.requestId += 1; + self.requestId += 1; - JsonRpcRequest jsonRpcRequest = { - ...request.cloneReadOnly(), - jsonrpc: JSONRPC_VERSION, - id: self.requestId - }; + JsonRpcRequest jsonRpcRequest = { + ...request.cloneReadOnly(), + jsonrpc: JSONRPC_VERSION, + id: self.requestId + }; - JsonRpcMessage|stream|StreamableHttpTransportError? response = - currentTransport.sendMessage(jsonRpcRequest); - return processServerResponse(response).cloneReadOnly(); - } + JsonRpcMessage|stream|StreamableHttpTransportError? response = + self.transport.sendMessage(jsonRpcRequest); + return processServerResponse(response).cloneReadOnly(); } } @@ -231,19 +189,12 @@ public distinct isolated client class Client { # + return - A ClientError if sending fails, or nil on success. private isolated function sendNotificationMessage(Notification notification) returns ClientError? { lock { - StreamableHttpClientTransport? currentTransport = self.transport; - if currentTransport is () { - return error UninitializedTransportError( - "Cannot send notification: client transport is not initialized." - ); - } - JsonRpcNotification jsonRpcNotification = { ...notification.cloneReadOnly(), jsonrpc: JSONRPC_VERSION }; - _ = check currentTransport.sendMessage(jsonRpcNotification); + _ = check self.transport.sendMessage(jsonRpcNotification); } } } From 1489a0896edc15fc81d23e0ce63c2d1585853a30 Mon Sep 17 00:00:00 2001 From: Azeem Muzammil Date: Fri, 4 Jul 2025 11:22:28 +0530 Subject: [PATCH 26/31] Rename McpTool to ToolDefinition --- ballerina/types.bal | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ballerina/types.bal b/ballerina/types.bal index 747267c..66fa095 100644 --- a/ballerina/types.bal +++ b/ballerina/types.bal @@ -298,7 +298,7 @@ public type ListToolsRequest record {| public type ListToolsResult record { *PaginatedResult; # A list of tools available on the server. - McpTool[] tools; + ToolDefinition[] tools; }; # The server's response to a tool call. @@ -354,7 +354,7 @@ public type ToolAnnotations record { }; # Definition for a tool the client can call. -public type McpTool record { +public type ToolDefinition record { # The name of the tool string name; # A human-readable description of the tool From 50cd39c309092c1af80863c4bd5a35f91f6b56d5 Mon Sep 17 00:00:00 2001 From: Azeem Muzammil Date: Fri, 4 Jul 2025 11:45:52 +0530 Subject: [PATCH 27/31] Fix attach detach logic --- ballerina/client.bal | 2 +- ballerina/listener.bal | 27 ++++++++++++++++----------- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/ballerina/client.bal b/ballerina/client.bal index a716ec7..9df6031 100644 --- a/ballerina/client.bal +++ b/ballerina/client.bal @@ -23,7 +23,7 @@ public type ClientConfiguration record {| ClientCapabilityConfiguration capabilityConfig?; |}; -# Configuration options for initializing an MCP client. +# Configuration for MCP client capabilities. public type ClientCapabilityConfiguration record {| # Capabilities to be advertised by this client. ClientCapabilities capabilities?; diff --git a/ballerina/listener.bal b/ballerina/listener.bal index 26e7321..6c6e3c5 100644 --- a/ballerina/listener.bal +++ b/ballerina/listener.bal @@ -30,7 +30,7 @@ public isolated class Listener { # # + listenTo - Either a port number (int) or an existing http:Listener. # + config - Listener configuration. - # + return - error? if listener initialization fails. + # + return - Error? if listener initialization fails. public function init(int|http:Listener listenTo, *ListenerConfiguration config) returns Error? { if listenTo is http:Listener { self.httpListener = listenTo; @@ -47,7 +47,7 @@ public isolated class Listener { # # + mcpService - Service to attach. # + name - Path(s) to mount the service on (string or string array). - # + return - error? if attachment fails. + # + return - Error? if attachment fails. public isolated function attach(Service|AdvancedService mcpService, string[]|string? name = ()) returns Error? { DispatcherService dispatcherService = new (); check addMcpServiceToDispatcher(dispatcherService, mcpService); @@ -56,28 +56,33 @@ public isolated class Listener { if result is error { return error("Failed to attach MCP service: " + result.message()); } + self.dispatcherServices.push(dispatcherService); } } # Detaches the MCP service from the listener. # # + mcpService - Service to detach. - # + return - error? if detachment fails. + # + return - Error? if detachment fails. public isolated function detach(Service|AdvancedService mcpService) returns Error? { lock { - foreach DispatcherService dispatcherService in self.dispatcherServices { - error? result = self.httpListener.detach(dispatcherService); - if result is error { - return error("Failed to detach MCP service: " + result.message()); + foreach [int, DispatcherService] dispatcherService in self.dispatcherServices.enumerate() { + Service|AdvancedService|Error attachedService = getMcpServiceFromDispatcher(dispatcherService[1]); + if attachedService === mcpService { + error? result = self.httpListener.detach(dispatcherService[1]); + if result is error { + return error("Failed to detach MCP service: " + result.message()); + } + _ = self.dispatcherServices.remove(dispatcherService[0]); + break; } } - self.dispatcherServices = []; } } # Starts the listener (begin accepting connections). # - # + return - error? if starting fails. + # + return - Error? if starting fails. public isolated function 'start() returns Error? { lock { error? result = self.httpListener.start(); @@ -89,7 +94,7 @@ public isolated class Listener { # Gracefully stops the listener (completes active requests before shutting down). # - # + return - error? if graceful stop fails. + # + return - Error? if graceful stop fails. public isolated function gracefulStop() returns Error? { lock { error? result = self.httpListener.gracefulStop(); @@ -101,7 +106,7 @@ public isolated class Listener { # Immediately stops the listener (terminates all connections). # - # + return - error? if immediate stop fails. + # + return - Error? if immediate stop fails. public isolated function immediateStop() returns Error? { lock { error? result = self.httpListener.immediateStop(); From 54b3ed18271c931b9b75c0feead89c2190b0c3d3 Mon Sep 17 00:00:00 2001 From: Azeem Muzammil Date: Tue, 8 Jul 2025 11:57:37 +0530 Subject: [PATCH 28/31] Address review comments --- ballerina/client.bal | 48 ++++++++----------- ballerina/dispatcher_service.bal | 6 +-- ballerina/listener.bal | 8 ++-- ballerina/types.bal | 10 ++-- .../stdlib/mcp/plugin/McpCodeModifier.java | 6 +++ .../stdlib/mcp/plugin/McpSourceModifier.java | 16 +++++++ .../plugin/RemoteFunctionAnalysisTask.java | 16 +++++++ 7 files changed, 71 insertions(+), 39 deletions(-) diff --git a/ballerina/client.bal b/ballerina/client.bal index 9df6031..a9e9574 100644 --- a/ballerina/client.bal +++ b/ballerina/client.bal @@ -50,7 +50,7 @@ public distinct isolated client class Client { # + return - ClientError if initialization fails, nil on success public isolated function init(string serverUrl, *ClientConfiguration config) returns ClientError? { // Create and initialize transport. - StreamableHttpClientTransport newTransport = check new StreamableHttpClientTransport(serverUrl); + StreamableHttpClientTransport newTransport = check new (serverUrl); self.transport = newTransport; string? sessionId = newTransport.getSessionId(); @@ -62,7 +62,6 @@ public distinct isolated client class Client { // Prepare and send the initialization request. InitializeRequest initRequest = { - method: "initialize", params: { protocolVersion: LATEST_PROTOCOL_VERSION, capabilities: config.capabilityConfig?.capabilities ?: {}, @@ -72,32 +71,30 @@ public distinct isolated client class Client { ServerResult response = check self.sendRequestMessage(initRequest); - if response is InitializeResult { - final readonly & string protocolVersion = response.protocolVersion; - // Validate protocol compatibility. - if (!SUPPORTED_PROTOCOL_VERSIONS.some(v => v == protocolVersion)) { - return error ProtocolVersionError( - string `Server protocol version '${ - protocolVersion}' is not supported. Supported versions: ${ - SUPPORTED_PROTOCOL_VERSIONS.toString()}.` - ); - } - - // Store server capabilities and info. - self.serverCapabilities = response.capabilities.cloneReadOnly(); - self.serverInfo = response.serverInfo.cloneReadOnly(); - - // Send notification to complete initialization. - InitializedNotification initNotification = { - method: "notifications/initialized" - }; - check self.sendNotificationMessage(initNotification); - } else { + if !(response is InitializeResult) { return error ClientInitializationError( string `Initialization failed: unexpected response type '${ (typeof response).toString()}' received from server.` ); } + + final string protocolVersion = response.protocolVersion; + // Validate protocol compatibility. + if (!SUPPORTED_PROTOCOL_VERSIONS.some(v => v == protocolVersion)) { + return error ProtocolVersionError( + string `Server protocol version '${ + protocolVersion}' is not supported. Supported versions: ${ + SUPPORTED_PROTOCOL_VERSIONS.toString()}.` + ); + } + + // Store server capabilities and info. + self.serverCapabilities = response.capabilities.cloneReadOnly(); + self.serverInfo = response.serverInfo.cloneReadOnly(); + + // Send notification to complete initialization. + InitializedNotification initNotification = {}; + check self.sendNotificationMessage(initNotification); } # Opens a server-sent events (SSE) stream for asynchronous server-to-client communication. @@ -113,9 +110,7 @@ public distinct isolated client class Client { # # + return - List of available tools or a ClientError. isolated remote function listTools() returns ListToolsResult|ClientError { - ListToolsRequest listToolsRequest = { - method: "tools/list" - }; + ListToolsRequest listToolsRequest = {}; ServerResult result = check self.sendRequestMessage(listToolsRequest); if result is ListToolsResult { @@ -133,7 +128,6 @@ public distinct isolated client class Client { # + return - Result of the tool execution or a ClientError. isolated remote function callTool(CallToolParams params) returns CallToolResult|ClientError { CallToolRequest toolCallRequest = { - method: "tools/call", params: params }; diff --git a/ballerina/dispatcher_service.bal b/ballerina/dispatcher_service.bal index 7cdd0e0..b692cda 100644 --- a/ballerina/dispatcher_service.bal +++ b/ballerina/dispatcher_service.bal @@ -60,6 +60,7 @@ isolated service class DispatcherService { if request is JsonRpcRequest { return self.processJsonRpcRequest(request, headers); } + if request is JsonRpcNotification { return self.processJsonRpcNotification(request); } @@ -71,7 +72,6 @@ isolated service class DispatcherService { private isolated function validateHeaders(http:Headers headers) returns http:NotAcceptable|http:UnsupportedMediaType? { - // Validate Accept header string|http:HeaderNotFoundError acceptHeader = headers.getHeader(ACCEPT_HEADER); if acceptHeader is http:HeaderNotFoundError { return { @@ -87,7 +87,6 @@ isolated service class DispatcherService { }; } - // Validate Content-Type header string|http:HeaderNotFoundError contentTypeHeader = headers.getHeader(CONTENT_TYPE_HEADER); if contentTypeHeader is http:HeaderNotFoundError { return { @@ -106,7 +105,8 @@ isolated service class DispatcherService { return; } - private isolated function processJsonRpcRequest(JsonRpcRequest request, http:Headers headers) returns http:BadRequest|http:Ok|Error { + private isolated function processJsonRpcRequest(JsonRpcRequest request, http:Headers headers) + returns http:BadRequest|http:Ok|Error { match request.method { REQUEST_INITIALIZE => { return self.handleInitializeRequest(request, headers); diff --git a/ballerina/listener.bal b/ballerina/listener.bal index 6c6e3c5..de1b88a 100644 --- a/ballerina/listener.bal +++ b/ballerina/listener.bal @@ -66,14 +66,14 @@ public isolated class Listener { # + return - Error? if detachment fails. public isolated function detach(Service|AdvancedService mcpService) returns Error? { lock { - foreach [int, DispatcherService] dispatcherService in self.dispatcherServices.enumerate() { - Service|AdvancedService|Error attachedService = getMcpServiceFromDispatcher(dispatcherService[1]); + foreach [int, DispatcherService] [index, dispatcherService] in self.dispatcherServices.enumerate() { + Service|AdvancedService|Error attachedService = getMcpServiceFromDispatcher(dispatcherService); if attachedService === mcpService { - error? result = self.httpListener.detach(dispatcherService[1]); + error? result = self.httpListener.detach(dispatcherService); if result is error { return error("Failed to detach MCP service: " + result.message()); } - _ = self.dispatcherServices.remove(dispatcherService[0]); + _ = self.dispatcherServices.remove(index); break; } } diff --git a/ballerina/types.bal b/ballerina/types.bal index 66fa095..7d2c807 100644 --- a/ballerina/types.bal +++ b/ballerina/types.bal @@ -151,7 +151,7 @@ public type JsonRpcError record { type InitializeRequest record {| *Request; # Method name for the request - REQUEST_INITIALIZE method; + REQUEST_INITIALIZE method = REQUEST_INITIALIZE; # Parameters for the initialize request record { *RequestParams; @@ -187,7 +187,7 @@ public type InitializeResult record { public type InitializedNotification record {| *Notification; # The method identifier for the notification, must be "notifications/initialized" - NOTIFICATION_INITIALIZED method; + NOTIFICATION_INITIALIZED method = NOTIFICATION_INITIALIZED; |}; # Capabilities a client may support. Known capabilities are defined here, in this schema, @@ -291,7 +291,7 @@ public type EmbeddedResource record { public type ListToolsRequest record {| *PaginatedRequest; # The method identifier for this request - REQUEST_LIST_TOOLS method; + REQUEST_LIST_TOOLS method = REQUEST_LIST_TOOLS; |}; # The server's response to a tools/list request from the client. @@ -313,7 +313,7 @@ public type CallToolResult record { # Used by the client to invoke a tool provided by the server. public type CallToolRequest record {| # The JSON-RPC method name - REQUEST_CALL_TOOL method; + REQUEST_CALL_TOOL method = REQUEST_CALL_TOOL; # The parameters for the tool call CallToolParams params; |}; @@ -450,7 +450,7 @@ public type ServiceConfiguration record {| # Annotation to provide service configuration to MCP services. public annotation ServiceConfiguration ServiceConfig on service; -# Defines a mcp service interface that handles incoming mcp requests. +# Defines an MCP service interface that handles incoming MCP requests. public type AdvancedService distinct service object { remote isolated function onListTools() returns ListToolsResult|ServerError; remote isolated function onCallTool(CallToolParams params) returns CallToolResult|ServerError; diff --git a/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/McpCodeModifier.java b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/McpCodeModifier.java index bdce349..a6f022d 100644 --- a/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/McpCodeModifier.java +++ b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/McpCodeModifier.java @@ -27,6 +27,12 @@ import static io.ballerina.compiler.syntax.tree.SyntaxKind.OBJECT_METHOD_DEFINITION; +/** + * Code modifier for processing MCP tool annotations on remote functions. + * + *

This modifier analyzes object method definitions and automatically generates + * or updates MCP tool annotations with schema information during compilation.

+ */ public class McpCodeModifier extends CodeModifier { private final Map modifierContextMap = new HashMap<>(); diff --git a/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/McpSourceModifier.java b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/McpSourceModifier.java index 1732272..d723e83 100644 --- a/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/McpSourceModifier.java +++ b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/McpSourceModifier.java @@ -50,13 +50,29 @@ import static io.ballerina.stdlib.mcp.plugin.Utils.TOOL_ANNOTATION_NAME; import static io.ballerina.stdlib.mcp.plugin.Utils.getToolAnnotationNode; +/** + * Source modifier task that updates MCP tool annotations in Ballerina source files. + * + *

This modifier takes the analysis results from {@link RemoteFunctionAnalysisTask} + * and applies the generated tool annotations to the source code by modifying function metadata.

+ */ public class McpSourceModifier implements ModifierTask { private final Map modifierContextMap; + /** + * Creates a new source modifier with the given modifier context map. + * + * @param modifierContextMap map containing analysis results for each document + */ McpSourceModifier(Map modifierContextMap) { this.modifierContextMap = modifierContextMap; } + /** + * Modifies source files by updating MCP tool annotations based on analysis results. + * + * @param context the source modifier context + */ @Override public void modify(SourceModifierContext context) { for (Map.Entry entry : modifierContextMap.entrySet()) { diff --git a/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/RemoteFunctionAnalysisTask.java b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/RemoteFunctionAnalysisTask.java index 651add2..f9fc55b 100644 --- a/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/RemoteFunctionAnalysisTask.java +++ b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/RemoteFunctionAnalysisTask.java @@ -48,6 +48,12 @@ import static io.ballerina.stdlib.mcp.plugin.Utils.getToolAnnotationNode; import static io.ballerina.stdlib.mcp.plugin.diagnostics.CompilationDiagnostic.UNABLE_TO_GENERATE_SCHEMA_FOR_FUNCTION; +/** + * Analysis task that processes remote function definitions to generate MCP tool annotation configurations. + * + *

This task analyzes function signatures, extracts parameter schemas, and creates tool annotation + * configurations that will be used by {@link McpSourceModifier} to update source code.

+ */ public class RemoteFunctionAnalysisTask implements AnalysisTask { public static final String EMPTY_STRING = ""; public static final String NIL_EXPRESSION = "()"; @@ -55,10 +61,20 @@ public class RemoteFunctionAnalysisTask implements AnalysisTask modifierContextMap; private SyntaxNodeAnalysisContext context; + /** + * Creates a new analysis task with the given modifier context map. + * + * @param modifierContextMap map to store analysis results for each document + */ RemoteFunctionAnalysisTask(Map modifierContextMap) { this.modifierContextMap = modifierContextMap; } + /** + * Performs analysis on a function definition node to extract tool annotation information. + * + * @param context the syntax node analysis context containing the function definition + */ @Override public void perform(SyntaxNodeAnalysisContext context) { this.context = context; From 1040605ea3424862b5238f9dba177f0afca6124a Mon Sep 17 00:00:00 2001 From: Azeem Muzammil Date: Tue, 15 Jul 2025 12:08:05 +0530 Subject: [PATCH 29/31] [Automated] Update the native jar versions --- ballerina/Dependencies.toml | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/ballerina/Dependencies.toml b/ballerina/Dependencies.toml index 5c54e49..6ec1e24 100644 --- a/ballerina/Dependencies.toml +++ b/ballerina/Dependencies.toml @@ -216,9 +216,6 @@ dependencies = [ {org = "ballerina", name = "lang.value"}, {org = "ballerina", name = "observe"} ] -modules = [ - {org = "ballerina", packageName = "log", moduleName = "log"} -] [[package]] org = "ballerina" @@ -227,7 +224,7 @@ version = "0.4.3" dependencies = [ {org = "ballerina", name = "http"}, {org = "ballerina", name = "jballerina.java"}, - {org = "ballerina", name = "log"} + {org = "ballerina", name = "uuid"} ] modules = [ {org = "ballerina", packageName = "mcp", moduleName = "mcp"} From 690e3eda6e78e5aaf2bb6804cf98513743f6cc5f Mon Sep 17 00:00:00 2001 From: Azeem Muzammil Date: Tue, 15 Jul 2025 12:08:25 +0530 Subject: [PATCH 30/31] Fix merge conflicts --- ballerina/CompilerPlugin.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ballerina/CompilerPlugin.toml b/ballerina/CompilerPlugin.toml index fd94c5c..677dafc 100644 --- a/ballerina/CompilerPlugin.toml +++ b/ballerina/CompilerPlugin.toml @@ -3,7 +3,7 @@ id = "mcp-compiler-plugin" class = "io.ballerina.stdlib.mcp.plugin.McpCompilerPlugin" [[dependency]] -path = "../compiler-plugin/build/libs/mcp-compiler-plugin-0.4.2-SNAPSHOT.jar" +path = "../compiler-plugin/build/libs/mcp-compiler-plugin-0.4.3-SNAPSHOT.jar" [[dependency]] path = "../compiler-plugin/build/libs/ballerina-to-openapi-2.3.0.jar" From 2fcf652f2b6d6134693b7f5b1709b4c53d6462fd Mon Sep 17 00:00:00 2001 From: Azeem Muzammil Date: Tue, 15 Jul 2025 14:16:40 +0530 Subject: [PATCH 31/31] Address review comments --- ballerina/client.bal | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/ballerina/client.bal b/ballerina/client.bal index a9e9574..0bbbdd3 100644 --- a/ballerina/client.bal +++ b/ballerina/client.bal @@ -71,7 +71,7 @@ public distinct isolated client class Client { ServerResult response = check self.sendRequestMessage(initRequest); - if !(response is InitializeResult) { + if response !is InitializeResult { return error ClientInitializationError( string `Initialization failed: unexpected response type '${ (typeof response).toString()}' received from server.` @@ -93,8 +93,7 @@ public distinct isolated client class Client { self.serverInfo = response.serverInfo.cloneReadOnly(); // Send notification to complete initialization. - InitializedNotification initNotification = {}; - check self.sendNotificationMessage(initNotification); + check self.sendNotificationMessage( {}); } # Opens a server-sent events (SSE) stream for asynchronous server-to-client communication.