diff --git a/.gitignore b/.gitignore index 834a0a0..6034716 100644 --- a/.gitignore +++ b/.gitignore @@ -42,3 +42,6 @@ target # Ballerina velocity.log* *Ballerina.lock + +# AI Tools +CLAUDE.md diff --git a/ballerina/CompilerPlugin.toml b/ballerina/CompilerPlugin.toml new file mode 100644 index 0000000..677dafc --- /dev/null +++ b/ballerina/CompilerPlugin.toml @@ -0,0 +1,9 @@ +[plugin] +id = "mcp-compiler-plugin" +class = "io.ballerina.stdlib.mcp.plugin.McpCompilerPlugin" + +[[dependency]] +path = "../compiler-plugin/build/libs/mcp-compiler-plugin-0.4.3-SNAPSHOT.jar" + +[[dependency]] +path = "../compiler-plugin/build/libs/ballerina-to-openapi-2.3.0.jar" diff --git a/ballerina/Dependencies.toml b/ballerina/Dependencies.toml index a8c757c..6ec1e24 100644 --- a/ballerina/Dependencies.toml +++ b/ballerina/Dependencies.toml @@ -5,7 +5,7 @@ [ballerina] dependencies-toml-version = "2" -distribution-version = "2201.12.7" +distribution-version = "2201.12.0" [[package]] org = "ballerina" @@ -108,9 +108,6 @@ dependencies = [ {org = "ballerina", name = "jballerina.java"}, {org = "ballerina", name = "lang.value"} ] -modules = [ - {org = "ballerina", packageName = "io", moduleName = "io"} -] [[package]] org = "ballerina" @@ -219,9 +216,6 @@ dependencies = [ {org = "ballerina", name = "lang.value"}, {org = "ballerina", name = "observe"} ] -modules = [ - {org = "ballerina", packageName = "log", moduleName = "log"} -] [[package]] org = "ballerina" @@ -229,9 +223,8 @@ name = "mcp" version = "0.4.3" dependencies = [ {org = "ballerina", name = "http"}, - {org = "ballerina", name = "io"}, {org = "ballerina", name = "jballerina.java"}, - {org = "ballerina", name = "log"} + {org = "ballerina", name = "uuid"} ] modules = [ {org = "ballerina", packageName = "mcp", moduleName = "mcp"} @@ -314,4 +307,7 @@ dependencies = [ {org = "ballerina", name = "lang.int"}, {org = "ballerina", name = "time"} ] +modules = [ + {org = "ballerina", packageName = "uuid", moduleName = "uuid"} +] diff --git a/ballerina/build.gradle b/ballerina/build.gradle index 37951cf..4e1ef4e 100644 --- a/ballerina/build.gradle +++ b/ballerina/build.gradle @@ -27,7 +27,9 @@ def packageName = "mcp" def packageOrg = "ballerina" def tomlVersion = stripBallerinaExtensionVersion("${project.version}") def ballerinaTomlFilePlaceHolder = new File("${project.rootDir}/build-config/resources/Ballerina.toml") +def compilerPluginTomlFilePlaceHolder = new File("${project.rootDir}/build-config/resources/CompilerPlugin.toml") def ballerinaTomlFile = new File("$project.projectDir/Ballerina.toml") +def compilerPluginTomlFile = new File("$project.projectDir/CompilerPlugin.toml") def stripBallerinaExtensionVersion(String extVersion) { if (extVersion.matches(project.ext.timestampedVersionRegex)) { @@ -55,6 +57,11 @@ task updateTomlFiles { def newConfig = ballerinaTomlFilePlaceHolder.text.replace("@project.version@", project.version) newConfig = newConfig.replace("@toml.version@", tomlVersion) ballerinaTomlFile.text = newConfig + + def ballerinaToOpenApiVersion = project.ballerinaToOpenApiVersion + def newCompilerPluginToml = compilerPluginTomlFilePlaceHolder.text.replace("@project.version@", project.version) + newCompilerPluginToml = newCompilerPluginToml.replace("@ballerinaToOpenApiVersion.version@", ballerinaToOpenApiVersion) + compilerPluginTomlFile.text = newCompilerPluginToml } } @@ -94,7 +101,9 @@ updateTomlFiles.dependsOn copyStdlibs build.dependsOn "generatePomFileForMavenPublication" build.dependsOn ":${packageName}-native:build" +build.dependsOn ":${packageName}-compiler-plugin:build" test.dependsOn ":${packageName}-native:build" +test.dependsOn ":${packageName}-compiler-plugin:build" publishToMavenLocal.dependsOn build publish.dependsOn build diff --git a/ballerina/client.bal b/ballerina/client.bal index 2cc17fe..0bbbdd3 100644 --- a/ballerina/client.bal +++ b/ballerina/client.bal @@ -23,7 +23,7 @@ public type ClientConfiguration record {| ClientCapabilityConfiguration capabilityConfig?; |}; -# Configuration options for initializing an MCP client. +# Configuration for MCP client capabilities. public type ClientCapabilityConfiguration record {| # Capabilities to be advertised by this client. ClientCapabilities capabilities?; @@ -33,15 +33,8 @@ public type ClientCapabilityConfiguration record {| # Represents an MCP client built on top of the Streamable HTTP transport. public distinct isolated client class Client { - # MCP server URL. - private final string serverUrl; - # Client implementation details (e.g., name and version). - private final Implementation clientInfo; - # Capabilities supported by the client. - private final ClientCapabilities clientCapabilities; - # Transport for communication with the MCP server. - private StreamableHttpClientTransport? transport = (); + private StreamableHttpClientTransport transport; # Server capabilities. private ServerCapabilities? serverCapabilities = (); # Server implementation information. @@ -49,72 +42,58 @@ public distinct isolated client class Client { # Request ID generator for tracking requests. private int requestId = 0; - # Initializes a new MCP client with the provided server URL and client details. + # Initializes a new MCP client and establishes connection to the server. + # Performs protocol handshake and capability exchange. Client is ready for use after construction. # - # + serverUrl - MCP server URL. - # + clientInfo - Client details, such as name and version. - # + config - Optional configuration containing client capabilities. - public isolated function init(string serverUrl, *ClientConfiguration config) { - self.serverUrl = serverUrl; - self.clientInfo = config.info.cloneReadOnly(); - self.clientCapabilities = (config.capabilityConfig?.capabilities).cloneReadOnly() ?: {}; - } - - # Establishes a connection to the MCP server and performs protocol initialization. - # - # + return - A ClientError if initialization fails, or nil on success. - isolated remote function initialize() returns ClientError? { - lock { - // Create and initialize transport. - StreamableHttpClientTransport newTransport = check new StreamableHttpClientTransport(self.serverUrl); - self.transport = newTransport; - - string? sessionId = newTransport.getSessionId(); + # + serverUrl - MCP server URL + # + config - Client configuration including info and capabilities + # + return - ClientError if initialization fails, nil on success + public isolated function init(string serverUrl, *ClientConfiguration config) returns ClientError? { + // Create and initialize transport. + StreamableHttpClientTransport newTransport = check new (serverUrl); + self.transport = newTransport; + + string? sessionId = newTransport.getSessionId(); + + // If a session ID exists, assume reconnection and skip initialization. + if sessionId is string { + return; + } - // If a session ID exists, assume reconnection and skip initialization. - if sessionId is string { - return; + // Prepare and send the initialization request. + InitializeRequest initRequest = { + params: { + protocolVersion: LATEST_PROTOCOL_VERSION, + capabilities: config.capabilityConfig?.capabilities ?: {}, + clientInfo: config.info } + }; - // Prepare and send the initialization request. - InitializeRequest initRequest = { - method: "initialize", - params: { - protocolVersion: LATEST_PROTOCOL_VERSION, - capabilities: self.clientCapabilities, - clientInfo: self.clientInfo - } - }; + ServerResult response = check self.sendRequestMessage(initRequest); - ServerResult response = check self.sendRequestMessage(initRequest); - - if response is InitializeResult { - final readonly & string protocolVersion = response.protocolVersion; - // Validate protocol compatibility. - if (!SUPPORTED_PROTOCOL_VERSIONS.some(v => v == protocolVersion)) { - return error ProtocolVersionError( - string `Server protocol version '${ - protocolVersion}' is not supported. Supported versions: ${ - SUPPORTED_PROTOCOL_VERSIONS.toString()}.` - ); - } - - // Store server capabilities and info. - self.serverCapabilities = response.capabilities; - self.serverInfo = response.serverInfo; - - // Send notification to complete initialization. - InitializedNotification initNotification = { - method: "notifications/initialized" - }; - check self.sendNotificationMessage(initNotification); - } else { - return error ClientInitializationError( - string `Initialization failed: unexpected response type '${ - (typeof response).toString()}' received from server.` - ); - } + if response !is InitializeResult { + return error ClientInitializationError( + string `Initialization failed: unexpected response type '${ + (typeof response).toString()}' received from server.` + ); + } + + final string protocolVersion = response.protocolVersion; + // Validate protocol compatibility. + if (!SUPPORTED_PROTOCOL_VERSIONS.some(v => v == protocolVersion)) { + return error ProtocolVersionError( + string `Server protocol version '${ + protocolVersion}' is not supported. Supported versions: ${ + SUPPORTED_PROTOCOL_VERSIONS.toString()}.` + ); } + + // Store server capabilities and info. + self.serverCapabilities = response.capabilities.cloneReadOnly(); + self.serverInfo = response.serverInfo.cloneReadOnly(); + + // Send notification to complete initialization. + check self.sendNotificationMessage( {}); } # Opens a server-sent events (SSE) stream for asynchronous server-to-client communication. @@ -122,13 +101,7 @@ public distinct isolated client class Client { # + return - Stream of JsonRpcMessages or a ClientError. isolated remote function subscribeToServerMessages() returns stream|ClientError { lock { - StreamableHttpClientTransport? currentTransport = self.transport; - if currentTransport is () { - return error UninitializedTransportError( - "Subscription failed: client transport is not initialized." - ); - } - return currentTransport.establishEventStream(); + return self.transport.establishEventStream(); } } @@ -136,9 +109,7 @@ public distinct isolated client class Client { # # + return - List of available tools or a ClientError. isolated remote function listTools() returns ListToolsResult|ClientError { - ListToolsRequest listToolsRequest = { - method: "tools/list" - }; + ListToolsRequest listToolsRequest = {}; ServerResult result = check self.sendRequestMessage(listToolsRequest); if result is ListToolsResult { @@ -156,7 +127,6 @@ public distinct isolated client class Client { # + return - Result of the tool execution or a ClientError. isolated remote function callTool(CallToolParams params) returns CallToolResult|ClientError { CallToolRequest toolCallRequest = { - method: "tools/call", params: params }; @@ -175,20 +145,10 @@ public distinct isolated client class Client { # + return - A ClientError if closure fails, or nil on success. isolated remote function close() returns ClientError? { lock { - StreamableHttpClientTransport? currentTransport = self.transport; - if currentTransport is () { - return error UninitializedTransportError( - "Closure failed: client transport is not initialized." - ); - } - do { - check currentTransport.terminateSession(); - lock { - self.transport = (); - self.serverCapabilities = (); - self.serverInfo = (); - } + check self.transport.terminateSession(); + self.serverCapabilities = (); + self.serverInfo = (); return; } on fail error e { return error ClientError(string `Failed to disconnect from server: ${e.message()}`, e); @@ -202,26 +162,17 @@ public distinct isolated client class Client { # + return - ServerResult, a stream of results, or a ClientError. private isolated function sendRequestMessage(Request request) returns ServerResult|ClientError { lock { - StreamableHttpClientTransport? currentTransport = self.transport; - if currentTransport is () { - return error UninitializedTransportError( - "Cannot send request: client transport is not initialized." - ); - } + self.requestId += 1; - lock { - self.requestId += 1; - - JsonRpcRequest jsonRpcRequest = { - ...request.cloneReadOnly(), - jsonrpc: JSONRPC_VERSION, - id: self.requestId - }; + JsonRpcRequest jsonRpcRequest = { + ...request.cloneReadOnly(), + jsonrpc: JSONRPC_VERSION, + id: self.requestId + }; - JsonRpcMessage|stream|StreamableHttpTransportError? response = - currentTransport.sendMessage(jsonRpcRequest); - return processServerResponse(response).cloneReadOnly(); - } + JsonRpcMessage|stream|StreamableHttpTransportError? response = + self.transport.sendMessage(jsonRpcRequest); + return processServerResponse(response).cloneReadOnly(); } } @@ -231,19 +182,12 @@ public distinct isolated client class Client { # + return - A ClientError if sending fails, or nil on success. private isolated function sendNotificationMessage(Notification notification) returns ClientError? { lock { - StreamableHttpClientTransport? currentTransport = self.transport; - if currentTransport is () { - return error UninitializedTransportError( - "Cannot send notification: client transport is not initialized." - ); - } - JsonRpcNotification jsonRpcNotification = { ...notification.cloneReadOnly(), jsonrpc: JSONRPC_VERSION }; - _ = check currentTransport.sendMessage(jsonRpcNotification); + _ = check self.transport.sendMessage(jsonRpcNotification); } } } diff --git a/ballerina/dispatcher_service.bal b/ballerina/dispatcher_service.bal new file mode 100644 index 0000000..b692cda --- /dev/null +++ b/ballerina/dispatcher_service.bal @@ -0,0 +1,320 @@ +// Copyright (c) 2025 WSO2 LLC (http://www.wso2.com). +// +// WSO2 LLC. licenses this file to you under the Apache License, +// Version 2.0 (the "License"); you may not use this file except +// in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +import ballerina/http; +import ballerina/uuid; + +isolated service class DispatcherService { + *http:Service; + + private map sessionMap = {}; + + isolated resource function delete .(http:Headers headers) returns http:BadRequest|http:Ok { + string? sessionId = self.getSessionIdFromHeaders(headers); + if sessionId is () { + return { + body: self.createJsonRpcError(INVALID_REQUEST, "Missing session ID header") + }; + } + + lock { + if !self.sessionMap.hasKey(sessionId) { + return { + body: self.createJsonRpcError(INVALID_REQUEST, string `Invalid session ID: ${sessionId}`) + }; + } + + _ = self.sessionMap.remove(sessionId); + } + + return { + body: { + jsonrpc: JSONRPC_VERSION, + result: { + message: string `Session ${sessionId} deleted successfully` + } + } + }; + } + + isolated resource function post .(@http:Payload JsonRpcMessage request, http:Headers headers) + returns http:BadRequest|http:NotAcceptable|http:UnsupportedMediaType|http:Accepted|http:Ok|Error { + http:NotAcceptable|http:UnsupportedMediaType? headerValidationError = self.validateHeaders(headers); + if headerValidationError !is () { + return headerValidationError; + } + + if request is JsonRpcRequest { + return self.processJsonRpcRequest(request, headers); + } + + if request is JsonRpcNotification { + return self.processJsonRpcNotification(request); + } + + return { + body: self.createJsonRpcError(INVALID_REQUEST, "Unsupported request type") + }; + } + + private isolated function validateHeaders(http:Headers headers) + returns http:NotAcceptable|http:UnsupportedMediaType? { + string|http:HeaderNotFoundError acceptHeader = headers.getHeader(ACCEPT_HEADER); + if acceptHeader is http:HeaderNotFoundError { + return { + body: self.createJsonRpcError(NOT_ACCEPTABLE, + "Not Acceptable: Client must accept both application/json and text/event-stream") + }; + } + + if !acceptHeader.includes(CONTENT_TYPE_JSON) || !acceptHeader.includes(CONTENT_TYPE_SSE) { + return { + body: self.createJsonRpcError(NOT_ACCEPTABLE, + "Not Acceptable: Client must accept both application/json and text/event-stream") + }; + } + + string|http:HeaderNotFoundError contentTypeHeader = headers.getHeader(CONTENT_TYPE_HEADER); + if contentTypeHeader is http:HeaderNotFoundError { + return { + body: self.createJsonRpcError(UNSUPPORTED_MEDIA_TYPE, + "Unsupported Media Type: Content-Type must be application/json") + }; + } + + if !contentTypeHeader.includes(CONTENT_TYPE_JSON) { + return { + body: self.createJsonRpcError(UNSUPPORTED_MEDIA_TYPE, + "Unsupported Media Type: Content-Type must be application/json") + }; + } + + return; + } + + private isolated function processJsonRpcRequest(JsonRpcRequest request, http:Headers headers) + returns http:BadRequest|http:Ok|Error { + match request.method { + REQUEST_INITIALIZE => { + return self.handleInitializeRequest(request, headers); + } + REQUEST_LIST_TOOLS => { + return self.handleListToolsRequest(request, headers); + } + REQUEST_CALL_TOOL => { + return self.handleCallToolRequest(request, headers); + } + _ => { + return { + body: self.createJsonRpcError(METHOD_NOT_FOUND, "Method not found", request.id) + }; + } + } + } + + private isolated function processJsonRpcNotification(JsonRpcNotification notification) returns http:Accepted|http:BadRequest { + if notification.method == NOTIFICATION_INITIALIZED { + return http:ACCEPTED; + } + + return { + body: { + jsonrpc: JSONRPC_VERSION, + 'error: { + code: METHOD_NOT_FOUND, + message: "Unknown notification method" + } + } + }; + } + + private isolated function getSessionIdFromHeaders(http:Headers headers) returns string? { + string|http:HeaderNotFoundError sessionHeader = headers.getHeader(SESSION_ID_HEADER); + return sessionHeader is string ? sessionHeader : (); + } + + private isolated function handleInitializeRequest(JsonRpcRequest jsonRpcRequest, http:Headers headers) returns http:BadRequest|http:Ok|Error { + JsonRpcRequest {jsonrpc: _, id, ...request} = jsonRpcRequest; + InitializeRequest|error initRequest = request.cloneWithType(); + if initRequest is error { + return { + body: self.createJsonRpcError(INVALID_REQUEST, + string `Invalid request: ${initRequest.message()}`, id) + }; + } + + // Check if there's a session ID in the headers + string? existingSessionId = self.getSessionIdFromHeaders(headers); + + lock { + // If there's an existing session ID and it's already in the map, return error + if existingSessionId is string && self.sessionMap.hasKey(existingSessionId) { + return { + body: self.createJsonRpcError(INVALID_REQUEST, + string `Session already initialized: ${existingSessionId}`, id) + }; + } + + Service|AdvancedService mcpService = check getMcpServiceFromDispatcher(self); + ServiceConfiguration serviceConfig = getServiceConfiguration(mcpService); + + // Create new session ID + string newSessionId = uuid:createRandomUuid(); + self.sessionMap[newSessionId] = "initialized"; + + string requestedVersion = initRequest.params.protocolVersion; + string protocolVersion = self.selectProtocolVersion(requestedVersion); + + return { + headers: {[SESSION_ID_HEADER]: newSessionId}, + body: { + jsonrpc: JSONRPC_VERSION, + id: id, + result: { + protocolVersion: protocolVersion, + capabilities: (serviceConfig.options?.capabilities ?: {}).cloneReadOnly(), + serverInfo: serviceConfig.info.cloneReadOnly() + } + } + }; + } + } + + private isolated function handleListToolsRequest(JsonRpcRequest request, http:Headers headers) returns http:BadRequest|http:Ok { + // Validate session ID + string? sessionId = self.getSessionIdFromHeaders(headers); + if sessionId is () { + return { + body: self.createJsonRpcError(INVALID_REQUEST, + "Missing session ID header", request.id) + }; + } + + lock { + // Check if session exists + if !self.sessionMap.hasKey(sessionId) { + return { + body: self.createJsonRpcError(INVALID_REQUEST, + string `Invalid session ID: ${sessionId}`, request.id) + }; + } + } + + ListToolsResult|error listToolsResult = self.executeOnListTools(); + if listToolsResult is error { + return { + body: self.createJsonRpcError(INTERNAL_ERROR, + string `Failed to list tools: ${listToolsResult.message()}`, request.id) + }; + } + + return { + headers: {[SESSION_ID_HEADER]: sessionId}, + body: { + jsonrpc: JSONRPC_VERSION, + id: request.id, + result: listToolsResult.cloneReadOnly() + } + }; + } + + private isolated function handleCallToolRequest(JsonRpcRequest request, http:Headers headers) returns http:BadRequest|http:Ok { + // Validate session ID + string? sessionId = self.getSessionIdFromHeaders(headers); + if sessionId is () { + return { + body: self.createJsonRpcError(INVALID_REQUEST, + "Missing session ID header", request.id) + }; + } + + lock { + // Check if session exists + if !self.sessionMap.hasKey(sessionId) { + return { + body: self.createJsonRpcError(INVALID_REQUEST, + string `Invalid session ID: ${sessionId}`, request.id) + }; + } + } + + // Extract and validate parameters + CallToolParams|error params = request.params.cloneWithType(); + if params is error { + return { + body: self.createJsonRpcError(INVALID_PARAMS, + string `Invalid parameters: ${params.message()}`, request.id) + }; + } + + CallToolResult|error callToolResult = self.executeOnCallTool(params); + if callToolResult is error { + return { + body: self.createJsonRpcError(INTERNAL_ERROR, + string `Failed to call tool '${params.name}': ${callToolResult.message()}`, request.id) + }; + } + + return { + headers: {[SESSION_ID_HEADER]: sessionId}, + body: { + jsonrpc: JSONRPC_VERSION, + id: request.id, + result: callToolResult.cloneReadOnly() + } + }; + } + + private isolated function selectProtocolVersion(string requestedVersion) returns string { + foreach string supportedVersion in SUPPORTED_PROTOCOL_VERSIONS { + if supportedVersion == requestedVersion { + return requestedVersion; + } + } + return LATEST_PROTOCOL_VERSION; + } + + private isolated function createJsonRpcError(int code, string message, RequestId? id = ()) returns JsonRpcError & readonly => { + jsonrpc: JSONRPC_VERSION, + id: id, + 'error: { + code: code, + message: message + } + }; + + private isolated function executeOnListTools() returns ListToolsResult|Error { + Service|AdvancedService mcpService = check getMcpServiceFromDispatcher(self); + if mcpService is AdvancedService { + return invokeOnListTools(mcpService); + } + if mcpService is Service { + return listToolsForRemoteFunctions(mcpService); + } + return error DispatcherError("MCP Service is not attached"); + } + + private isolated function executeOnCallTool(CallToolParams params) returns CallToolResult|Error { + Service|AdvancedService mcpService = check getMcpServiceFromDispatcher(self); + if mcpService is AdvancedService { + return invokeOnCallTool(mcpService, params.cloneReadOnly()); + } + if mcpService is Service { + return callToolForRemoteFunctions(mcpService, params.cloneReadOnly()); + } + return error DispatcherError("MCP Service is not attached"); + } +}; diff --git a/ballerina/error.bal b/ballerina/error.bal index 07374d2..7845795 100644 --- a/ballerina/error.bal +++ b/ballerina/error.bal @@ -79,3 +79,9 @@ public type ListToolsError distinct ClientError; # Error for failures during tool execution operations. public type ToolCallError distinct ClientError; + +# Errors for failures occurring during server operations. +public type ServerError distinct Error; + +# Custom error type for dispatcher service operations. +type DispatcherError distinct ServerError; diff --git a/ballerina/listener.bal b/ballerina/listener.bal new file mode 100644 index 0000000..de1b88a --- /dev/null +++ b/ballerina/listener.bal @@ -0,0 +1,118 @@ +// Copyright (c) 2025 WSO2 LLC (http://www.wso2.com). +// +// WSO2 LLC. licenses this file to you under the Apache License, +// Version 2.0 (the "License"); you may not use this file except +// in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +import ballerina/http; + +# Configuration options for initializing an MCP listener. +public type ListenerConfiguration record {| + *http:ListenerConfiguration; +|}; + +# A server listener for handling MCP service requests. +public isolated class Listener { + private http:Listener httpListener; + private DispatcherService[] dispatcherServices = []; + + # Initializes the Listener. + # + # + listenTo - Either a port number (int) or an existing http:Listener. + # + config - Listener configuration. + # + return - Error? if listener initialization fails. + public function init(int|http:Listener listenTo, *ListenerConfiguration config) returns Error? { + if listenTo is http:Listener { + self.httpListener = listenTo; + } else { + http:Listener|error httpListener = new (listenTo, config); + if httpListener is error { + return error("Failed to initialize HTTP listener: " + httpListener.message()); + } + self.httpListener = httpListener; + } + } + + # Attaches an MCP service to the listener under the specified path(s). + # + # + mcpService - Service to attach. + # + name - Path(s) to mount the service on (string or string array). + # + return - Error? if attachment fails. + public isolated function attach(Service|AdvancedService mcpService, string[]|string? name = ()) returns Error? { + DispatcherService dispatcherService = new (); + check addMcpServiceToDispatcher(dispatcherService, mcpService); + lock { + error? result = self.httpListener.attach(dispatcherService, name.cloneReadOnly()); + if result is error { + return error("Failed to attach MCP service: " + result.message()); + } + self.dispatcherServices.push(dispatcherService); + } + } + + # Detaches the MCP service from the listener. + # + # + mcpService - Service to detach. + # + return - Error? if detachment fails. + public isolated function detach(Service|AdvancedService mcpService) returns Error? { + lock { + foreach [int, DispatcherService] [index, dispatcherService] in self.dispatcherServices.enumerate() { + Service|AdvancedService|Error attachedService = getMcpServiceFromDispatcher(dispatcherService); + if attachedService === mcpService { + error? result = self.httpListener.detach(dispatcherService); + if result is error { + return error("Failed to detach MCP service: " + result.message()); + } + _ = self.dispatcherServices.remove(index); + break; + } + } + } + } + + # Starts the listener (begin accepting connections). + # + # + return - Error? if starting fails. + public isolated function 'start() returns Error? { + lock { + error? result = self.httpListener.start(); + if result is error { + return error("Failed to start MCP listener: " + result.message()); + } + } + } + + # Gracefully stops the listener (completes active requests before shutting down). + # + # + return - Error? if graceful stop fails. + public isolated function gracefulStop() returns Error? { + lock { + error? result = self.httpListener.gracefulStop(); + if result is error { + return error("Failed to gracefully stop MCP listener: " + result.message()); + } + } + } + + # Immediately stops the listener (terminates all connections). + # + # + return - Error? if immediate stop fails. + public isolated function immediateStop() returns Error? { + lock { + error? result = self.httpListener.immediateStop(); + if result is error { + return error("Failed to immediately stop MCP listener: " + result.message()); + } + } + } +} diff --git a/ballerina/native_listener_helper.bal b/ballerina/native_listener_helper.bal new file mode 100644 index 0000000..80ccf5a --- /dev/null +++ b/ballerina/native_listener_helper.bal @@ -0,0 +1,46 @@ +// Copyright (c) 2025 WSO2 LLC (http://www.wso2.com). +// +// WSO2 LLC. licenses this file to you under the Apache License, +// Version 2.0 (the "License"); you may not use this file except +// in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +import ballerina/jballerina.java; + +isolated function invokeOnListTools(AdvancedService 'service) returns ListToolsResult|Error = @java:Method { + 'class: "io.ballerina.stdlib.mcp.McpServiceMethodHelper" +} external; + +isolated function invokeOnCallTool(AdvancedService 'service, CallToolParams params) + returns CallToolResult|Error = @java:Method { + 'class: "io.ballerina.stdlib.mcp.McpServiceMethodHelper" +} external; + +isolated function listToolsForRemoteFunctions(Service 'service, typedesc t = <>) + returns t|Error = @java:Method { + 'class: "io.ballerina.stdlib.mcp.McpServiceMethodHelper" +} external; + +isolated function callToolForRemoteFunctions(Service 'service, CallToolParams params, typedesc t = <>) + returns t|Error = @java:Method { + 'class: "io.ballerina.stdlib.mcp.McpServiceMethodHelper" +} external; + +isolated function addMcpServiceToDispatcher(DispatcherService dispatcherService, Service|AdvancedService mcpService) + returns Error? = @java:Method { + 'class: "io.ballerina.stdlib.mcp.McpServiceMethodHelper" +} external; + +isolated function getMcpServiceFromDispatcher(DispatcherService dispatcherService) + returns Service|AdvancedService|Error = @java:Method { + 'class: "io.ballerina.stdlib.mcp.McpServiceMethodHelper" +} external; diff --git a/ballerina/types.bal b/ballerina/types.bal index ca83cda..7d2c807 100644 --- a/ballerina/types.bal +++ b/ballerina/types.bal @@ -26,8 +26,21 @@ public const SUPPORTED_PROTOCOL_VERSIONS = [ public const JSONRPC_VERSION = "2.0"; -// # Notification methods -public const NOTIFICATION_INITIALIZED = "notifications/initialized"; +# Standard JSON-RPC request methods supported by the MCP protocol +public enum RequestMethod { + # Initialize request to start a new MCP session + REQUEST_INITIALIZE = "initialize", + # Request to list all available tools from the server + REQUEST_LIST_TOOLS = "tools/list", + # Request to execute a specific tool with given parameters + REQUEST_CALL_TOOL = "tools/call" +}; + +# JSON-RPC notification methods used for one-way communication in MCP +public enum NotificationMethod { + # Notification sent by client to indicate successful initialization + NOTIFICATION_INITIALIZED = "notifications/initialized" +}; # A progress token, used to associate progress notifications with the original request. public type ProgressToken string|int; @@ -99,11 +112,46 @@ public type JsonRpcResponse record {| ServerResult result; |}; +// Standard JSON-RPC error codes +# Standard JSON-RPC 2.0 error code indicating that the JSON sent is not a valid JSON object. +public const PARSE_ERROR = -32700; +# Standard JSON-RPC 2.0 error code indicating that the JSON sent is not a valid request object. +public const INVALID_REQUEST = -32600; +# Standard JSON-RPC 2.0 error code indicating that the method does not exist or is not available. +public const METHOD_NOT_FOUND = -32601; +# Standard JSON-RPC 2.0 error code indicating that invalid method parameters were provided. +public const INVALID_PARAMS = -32602; +# Standard JSON-RPC 2.0 error code indicating that an internal error occurred on the server. +public const INTERNAL_ERROR = -32603; + +// Library-defined error codes +# MCP library-defined error code indicating that the request is not acceptable. +public const NOT_ACCEPTABLE = -32001; +# MCP library-defined error code indicating that the media type is not supported. +public const UNSUPPORTED_MEDIA_TYPE = -32002; + +# A response to a request that indicates an error occurred. +public type JsonRpcError record { + # The JSON-RPC protocol version + JSONRPC_VERSION jsonrpc; + # Identifier of the request + RequestId? id; + # The error information + record { + # The error type that occurred + int code; + # A short description of the error. The message SHOULD be limited to a concise single sentence. + string message; + # Additional information about the error. The value of this member is defined by the sender (e.g. detailed error information, nested errors etc.). + anydata data?; + } 'error; +}; + # This request is sent from the client to the server when it first connects, asking it to begin initialization. type InitializeRequest record {| *Request; # Method name for the request - "initialize" method; + REQUEST_INITIALIZE method = REQUEST_INITIALIZE; # Parameters for the initialize request record { *RequestParams; @@ -139,7 +187,7 @@ public type InitializeResult record { public type InitializedNotification record {| *Notification; # The method identifier for the notification, must be "notifications/initialized" - NOTIFICATION_INITIALIZED method; + NOTIFICATION_INITIALIZED method = NOTIFICATION_INITIALIZED; |}; # Capabilities a client may support. Known capabilities are defined here, in this schema, @@ -243,14 +291,14 @@ public type EmbeddedResource record { public type ListToolsRequest record {| *PaginatedRequest; # The method identifier for this request - "tools/list" method; + REQUEST_LIST_TOOLS method = REQUEST_LIST_TOOLS; |}; # The server's response to a tools/list request from the client. public type ListToolsResult record { *PaginatedResult; # A list of tools available on the server. - Tool[] tools; + ToolDefinition[] tools; }; # The server's response to a tool call. @@ -265,7 +313,7 @@ public type CallToolResult record { # Used by the client to invoke a tool provided by the server. public type CallToolRequest record {| # The JSON-RPC method name - "tools/call" method; + REQUEST_CALL_TOOL method = REQUEST_CALL_TOOL; # The parameters for the tool call CallToolParams params; |}; @@ -306,7 +354,7 @@ public type ToolAnnotations record { }; # Definition for a tool the client can call. -public type Tool record { +public type ToolDefinition record { # The name of the tool string name; # A human-readable description of the tool @@ -369,3 +417,46 @@ public type AudioContent record { # Represents a result sent from the server to the client. public type ServerResult InitializeResult|CallToolResult|ListToolsResult; + +# Represents a tool configuration that can be used to define tools available in the MCP service. +public type McpToolConfig record {| + # The description of the tool. + string description?; + # The JSON schema for the tool's parameters. + map schema?; +|}; + +# Annotation to mark a function as an MCP tool configuration. +public annotation McpToolConfig Tool on object function; + +# Represents the options for configuring an MCP server. +public type ServerOptions record {| + # Capabilities to advertise as being supported by this server. + ServerCapabilities capabilities?; + # Optional instructions describing how to use the server and its features. + string instructions?; + # Whether to enforce strict capabilities compliance. + boolean enforceStrictCapabilities?; +|}; + +# Configuration for MCP service that defines server capabilities and metadata. +public type ServiceConfiguration record {| + # Server implementation information + Implementation info; + # Optional server configuration options + ServerOptions options?; +|}; + +# Annotation to provide service configuration to MCP services. +public annotation ServiceConfiguration ServiceConfig on service; + +# Defines an MCP service interface that handles incoming MCP requests. +public type AdvancedService distinct service object { + remote isolated function onListTools() returns ListToolsResult|ServerError; + remote isolated function onCallTool(CallToolParams params) returns CallToolResult|ServerError; +}; + +# Defines a basic mcp service interface that handles incoming mcp requests. +public type Service distinct service object { + +}; diff --git a/ballerina/utils.bal b/ballerina/utils.bal index b251534..d9fcb45 100644 --- a/ballerina/utils.bal +++ b/ballerina/utils.bal @@ -73,3 +73,17 @@ isolated function extractResultFromMessage(JsonRpcMessage message) returns Serve return error InvalidMessageTypeError("Received message from server is not a valid JsonRpcResponse."); } +# Retrieves the service configuration from an MCP service. +# +# + mcpService - The MCP service instance +# + return - The service configuration +isolated function getServiceConfiguration(Service|AdvancedService mcpService) returns ServiceConfiguration { + typedesc mcpServiceType = typeof mcpService; + ServiceConfiguration? serviceConfig = mcpServiceType.@ServiceConfig; + return serviceConfig ?: { + info: { + name: "MCP Service", + version: "1.0.0" + } + }; +} diff --git a/build-config/resources/CompilerPlugin.toml b/build-config/resources/CompilerPlugin.toml new file mode 100644 index 0000000..beecd8a --- /dev/null +++ b/build-config/resources/CompilerPlugin.toml @@ -0,0 +1,9 @@ +[plugin] +id = "mcp-compiler-plugin" +class = "io.ballerina.stdlib.mcp.plugin.McpCompilerPlugin" + +[[dependency]] +path = "../compiler-plugin/build/libs/mcp-compiler-plugin-@project.version@.jar" + +[[dependency]] +path = "../compiler-plugin/build/libs/ballerina-to-openapi-@ballerinaToOpenApiVersion.version@.jar" diff --git a/compiler-plugin/build.gradle b/compiler-plugin/build.gradle new file mode 100644 index 0000000..e67dd9f --- /dev/null +++ b/compiler-plugin/build.gradle @@ -0,0 +1,93 @@ +/* + * Copyright (c) 2025, WSO2 LLC. (http://www.wso2.com). + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +plugins { + id 'java' + id 'checkstyle' + id 'com.github.spotbugs' +} + +description = 'Ballerina - MCP Package Compiler Plugin' + +configurations { + externalJars +} + +dependencies { + checkstyle project(':checkstyle') + checkstyle "com.puppycrawl.tools:checkstyle:${checkstylePluginVersion}" + + implementation group: 'org.ballerinalang', name: 'ballerina-lang', version: "${ballerinaLangVersion}" + implementation group: 'org.ballerinalang', name: 'ballerina-tools-api', version: "${ballerinaLangVersion}" + implementation group: 'org.ballerinalang', name: 'ballerina-parser', version: "${ballerinaLangVersion}" + implementation group: 'io.ballerina.openapi', name: 'ballerina-to-openapi', version: "${ballerinaToOpenApiVersion}" + implementation group: 'io.swagger.core.v3', name: 'swagger-core', version: "${swaggerVersion}" + implementation group: 'io.swagger.core.v3', name: 'swagger-models', version: "${swaggerVersion}" + + implementation project(":mcp-native") + externalJars group: 'io.ballerina.openapi', name: 'ballerina-to-openapi', version: "${ballerinaToOpenApiVersion}" +} + +def excludePattern = '**/module-info.java' +tasks.withType(Checkstyle) { + exclude excludePattern +} + +checkstyle { + toolVersion "${project.checkstylePluginVersion}" + configFile rootProject.file("build-config/checkstyle/build/checkstyle.xml") + configProperties = ["suppressionFile": file("${rootDir}/build-config/checkstyle/build/suppressions.xml")] +} + +checkstyleMain.dependsOn(":checkstyle:downloadCheckstyleRuleFiles") + +spotbugsMain { + def classLoader = plugins["com.github.spotbugs"].class.classLoader + def SpotBugsConfidence = classLoader.findLoadedClass("com.github.spotbugs.snom.Confidence") + def SpotBugsEffort = classLoader.findLoadedClass("com.github.spotbugs.snom.Effort") + effort = SpotBugsEffort.MAX + reportLevel = SpotBugsConfidence.LOW + reportsDir = file("$project.buildDir/reports/spotbugs") + reports { + html.enabled true + text.enabled = true + } + def excludeFile = file("${rootDir}/spotbugs-exclude.xml") + if (excludeFile.exists()) { + excludeFilter = excludeFile + } +} + +compileJava { + doFirst { + options.compilerArgs = [ + '--module-path', classpath.asPath, + ] + classpath = files() + } +} + +build.dependsOn ":mcp-native:build" + +task copyOpenApiJar(type: Copy) { + from { + configurations.externalJars.collect { it } + } + into "${buildDir}/libs" +} + +build.dependsOn copyOpenApiJar diff --git a/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/McpCodeModifier.java b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/McpCodeModifier.java new file mode 100644 index 0000000..a6f022d --- /dev/null +++ b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/McpCodeModifier.java @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2025, WSO2 LLC. (http://www.wso2.com). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.ballerina.stdlib.mcp.plugin; + +import io.ballerina.projects.DocumentId; +import io.ballerina.projects.plugins.CodeModifier; +import io.ballerina.projects.plugins.CodeModifierContext; + +import java.util.HashMap; +import java.util.Map; + +import static io.ballerina.compiler.syntax.tree.SyntaxKind.OBJECT_METHOD_DEFINITION; + +/** + * Code modifier for processing MCP tool annotations on remote functions. + * + *

This modifier analyzes object method definitions and automatically generates + * or updates MCP tool annotations with schema information during compilation.

+ */ +public class McpCodeModifier extends CodeModifier { + private final Map modifierContextMap = new HashMap<>(); + + @Override + public void init(CodeModifierContext codeModifierContext) { + codeModifierContext.addSyntaxNodeAnalysisTask(new RemoteFunctionAnalysisTask(modifierContextMap), + OBJECT_METHOD_DEFINITION); + codeModifierContext.addSourceModifierTask(new McpSourceModifier(modifierContextMap)); + } +} diff --git a/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/McpCompilerPlugin.java b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/McpCompilerPlugin.java new file mode 100644 index 0000000..723215b --- /dev/null +++ b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/McpCompilerPlugin.java @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2025, WSO2 LLC. (http://www.wso2.com). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.ballerina.stdlib.mcp.plugin; + +import io.ballerina.projects.plugins.CompilerPlugin; +import io.ballerina.projects.plugins.CompilerPluginContext; + +public class McpCompilerPlugin extends CompilerPlugin { + @Override + public void init(CompilerPluginContext compilerPluginContext) { + compilerPluginContext.addCodeModifier(new McpCodeModifier()); + } +} diff --git a/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/McpSourceModifier.java b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/McpSourceModifier.java new file mode 100644 index 0000000..d723e83 --- /dev/null +++ b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/McpSourceModifier.java @@ -0,0 +1,240 @@ +/* + * Copyright (c) 2025, WSO2 LLC. (http://www.wso2.com). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.ballerina.stdlib.mcp.plugin; + +import io.ballerina.compiler.api.SemanticModel; +import io.ballerina.compiler.syntax.tree.AnnotationNode; +import io.ballerina.compiler.syntax.tree.FunctionDefinitionNode; +import io.ballerina.compiler.syntax.tree.MetadataNode; +import io.ballerina.compiler.syntax.tree.ModuleMemberDeclarationNode; +import io.ballerina.compiler.syntax.tree.ModulePartNode; +import io.ballerina.compiler.syntax.tree.Node; +import io.ballerina.compiler.syntax.tree.NodeFactory; +import io.ballerina.compiler.syntax.tree.NodeList; +import io.ballerina.compiler.syntax.tree.NodeParser; +import io.ballerina.compiler.syntax.tree.ServiceDeclarationNode; +import io.ballerina.compiler.syntax.tree.SyntaxTree; +import io.ballerina.projects.DocumentId; +import io.ballerina.projects.Module; +import io.ballerina.projects.plugins.ModifierTask; +import io.ballerina.projects.plugins.SourceModifierContext; +import io.ballerina.tools.text.TextDocument; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static io.ballerina.compiler.syntax.tree.SyntaxKind.CLOSE_BRACE_TOKEN; +import static io.ballerina.compiler.syntax.tree.SyntaxKind.OBJECT_METHOD_DEFINITION; +import static io.ballerina.compiler.syntax.tree.SyntaxKind.OPEN_BRACE_TOKEN; +import static io.ballerina.compiler.syntax.tree.SyntaxKind.SERVICE_DECLARATION; +import static io.ballerina.stdlib.mcp.plugin.Utils.MCP_PACKAGE_NAME; +import static io.ballerina.stdlib.mcp.plugin.Utils.TOOL_ANNOTATION_NAME; +import static io.ballerina.stdlib.mcp.plugin.Utils.getToolAnnotationNode; + +/** + * Source modifier task that updates MCP tool annotations in Ballerina source files. + * + *

This modifier takes the analysis results from {@link RemoteFunctionAnalysisTask} + * and applies the generated tool annotations to the source code by modifying function metadata.

+ */ +public class McpSourceModifier implements ModifierTask { + private final Map modifierContextMap; + + /** + * Creates a new source modifier with the given modifier context map. + * + * @param modifierContextMap map containing analysis results for each document + */ + McpSourceModifier(Map modifierContextMap) { + this.modifierContextMap = modifierContextMap; + } + + /** + * Modifies source files by updating MCP tool annotations based on analysis results. + * + * @param context the source modifier context + */ + @Override + public void modify(SourceModifierContext context) { + for (Map.Entry entry : modifierContextMap.entrySet()) { + modifyDocumentWithTools(context, entry.getKey(), entry.getValue()); + } + } + + private void modifyDocumentWithTools(SourceModifierContext context, DocumentId documentId, + ModifierContext modifierContext) { + Module module = context.currentPackage().module(documentId.moduleId()); + SemanticModel semanticModel = context.compilation().getSemanticModel(documentId.moduleId()); + ModulePartNode rootNode = module.document(documentId).syntaxTree().rootNode(); + ModulePartNode updatedRoot = modifyModulePartRoot(semanticModel, rootNode, modifierContext, documentId); + updateDocument(context, module, documentId, updatedRoot); + } + + private ModulePartNode modifyModulePartRoot(SemanticModel semanticModel, ModulePartNode modulePartNode, + ModifierContext modifierContext, DocumentId documentId) { + List modifiedMembers = getModifiedModuleMembers(semanticModel, + modulePartNode.members(), modifierContext); + return modulePartNode.modify().withMembers(NodeFactory.createNodeList(modifiedMembers)).apply(); + } + + private List getModifiedModuleMembers(SemanticModel semanticModel, + NodeList members, + ModifierContext modifierContext) { + Map modifiedAnnotations = getModifiedAnnotations(modifierContext); + List modifiedMembers = new ArrayList<>(); + + for (ModuleMemberDeclarationNode member : members) { + modifiedMembers.add(getModifiedModuleMember(semanticModel, member, modifiedAnnotations)); + } + + return modifiedMembers; + } + + private Map getModifiedAnnotations(ModifierContext modifierContext) { + Map updatedAnnotationMap = new HashMap<>(); + for (Map.Entry entry : modifierContext + .getAnnotationConfigMap().entrySet()) { + updatedAnnotationMap.put(entry.getKey(), getModifiedAnnotation(entry.getValue())); + } + return updatedAnnotationMap; + } + + private AnnotationNode getModifiedAnnotation(ToolAnnotationConfig config) { + String mappingConstructorExpression = generateConfigMappingConstructor(config); + String annotationString = "@" + MCP_PACKAGE_NAME + ":" + TOOL_ANNOTATION_NAME + mappingConstructorExpression; + return NodeParser.parseAnnotation(annotationString); + } + + private String generateConfigMappingConstructor(ToolAnnotationConfig config) { + return generateConfigMappingConstructor(config, OPEN_BRACE_TOKEN.stringValue(), + CLOSE_BRACE_TOKEN.stringValue()); + } + + private String generateConfigMappingConstructor(ToolAnnotationConfig config, String openBraceSource, + String closeBraceSource) { + StringBuilder sb = new StringBuilder(); + sb.append(openBraceSource); + String desc = config.description().replaceAll("\\R", " "); + sb.append("description:").append(desc).append(","); + sb.append("schema:").append(config.schema()); + sb.append(closeBraceSource); + return sb.toString(); + } + + private ModuleMemberDeclarationNode getModifiedModuleMember( + SemanticModel semanticModel, + ModuleMemberDeclarationNode member, + Map modifiedAnnotations) { + + if (member.kind() == SERVICE_DECLARATION) { + return modifyServiceDeclaration(semanticModel, (ServiceDeclarationNode) member, modifiedAnnotations); + } + return member; + } + + private ModuleMemberDeclarationNode modifyServiceDeclaration( + SemanticModel semanticModel, + ServiceDeclarationNode classDefinitionNode, + Map modifiedAnnotations) { + + NodeList members = classDefinitionNode.members(); + ArrayList modifiedMembers = new ArrayList<>(); + + for (Node member : members) { + if (member.kind() != OBJECT_METHOD_DEFINITION) { + modifiedMembers.add(member); + continue; + } + + FunctionDefinitionNode functionDefinitionNode = (FunctionDefinitionNode) member; + AnnotationNode modifiedAnnotationNode = modifiedAnnotations.get(functionDefinitionNode); + + MetadataNode newMetadata = createOrUpdateMetadata( + semanticModel, functionDefinitionNode, modifiedAnnotationNode); + + FunctionDefinitionNode updatedFunction = functionDefinitionNode.modify() + .withMetadata(newMetadata) + .apply(); + + modifiedMembers.add(updatedFunction); + } + + return classDefinitionNode.modify().withMembers(NodeFactory.createNodeList(modifiedMembers)).apply(); + } + + private MetadataNode createOrUpdateMetadata( + SemanticModel semanticModel, + FunctionDefinitionNode functionDefinitionNode, + AnnotationNode modifiedAnnotationNode) { + + if (functionDefinitionNode.metadata().isEmpty()) { + return createMetadata(modifiedAnnotationNode); + } + + MetadataNode existingMetadata = functionDefinitionNode.metadata().get(); + Optional toolAnnotationNode = getToolAnnotationNode(semanticModel, functionDefinitionNode); + + if (toolAnnotationNode.isPresent()) { + return modifyMetadata(existingMetadata, toolAnnotationNode.get(), modifiedAnnotationNode); + } else { + return modifyWithToolAnnotation(existingMetadata, modifiedAnnotationNode); + } + } + + private MetadataNode modifyWithToolAnnotation(MetadataNode metadata, AnnotationNode annotationNode) { + List updatedAnnotations = new ArrayList<>(); + metadata.annotations().forEach(updatedAnnotations::add); + updatedAnnotations.add(annotationNode); + return metadata.modify() + .withAnnotations(NodeFactory.createNodeList(updatedAnnotations)) + .apply(); + } + + private MetadataNode modifyMetadata(MetadataNode metadata, AnnotationNode toolAnnotationNode, + AnnotationNode modifiedAnnotationNode) { + List updatedAnnotations = new ArrayList<>(); + for (AnnotationNode annotation : metadata.annotations()) { + if (annotation.equals(toolAnnotationNode)) { + updatedAnnotations.add(modifiedAnnotationNode); + } else { + updatedAnnotations.add(annotation); + } + } + return metadata.modify().withAnnotations(NodeFactory.createNodeList(updatedAnnotations)).apply(); + } + + private MetadataNode createMetadata(AnnotationNode annotationNode) { + NodeList annotationNodeList = NodeFactory.createNodeList(annotationNode); + return NodeFactory.createMetadataNode(null, annotationNodeList); + } + + private void updateDocument(SourceModifierContext context, Module module, DocumentId documentId, + ModulePartNode updatedRoot) { + SyntaxTree syntaxTree = module.document(documentId).syntaxTree().modifyWith(updatedRoot); + TextDocument textDocument = syntaxTree.textDocument(); + if (module.documentIds().contains(documentId)) { + context.modifySourceFile(textDocument, documentId); + } else { + context.modifyTestSourceFile(textDocument, documentId); + } + } +} diff --git a/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/ModifierContext.java b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/ModifierContext.java new file mode 100644 index 0000000..f06f306 --- /dev/null +++ b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/ModifierContext.java @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2025, WSO2 LLC. (http://www.wso2.com). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.ballerina.stdlib.mcp.plugin; + +import io.ballerina.compiler.syntax.tree.FunctionDefinitionNode; + +import java.util.HashMap; +import java.util.Map; + +public class ModifierContext { + private final Map annotationConfigMap = new HashMap<>(); + + void add(FunctionDefinitionNode node, ToolAnnotationConfig config) { + annotationConfigMap.put(node, config); + } + + Map getAnnotationConfigMap() { + return annotationConfigMap; + } +} + +record ToolAnnotationConfig( + String description, + String schema) { + + public static final String DESCRIPTION_FIELD_NAME = "description"; + public static final String SCHEMA_FIELD_NAME = "schema"; + + public String get(String field) { + return switch (field) { + case DESCRIPTION_FIELD_NAME -> description(); + case SCHEMA_FIELD_NAME -> schema(); + default -> null; + }; + } +} diff --git a/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/RemoteFunctionAnalysisTask.java b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/RemoteFunctionAnalysisTask.java new file mode 100644 index 0000000..f9fc55b --- /dev/null +++ b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/RemoteFunctionAnalysisTask.java @@ -0,0 +1,158 @@ +/* + * Copyright (c) 2025, WSO2 LLC. (http://www.wso2.com). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.ballerina.stdlib.mcp.plugin; + +import io.ballerina.compiler.api.symbols.FunctionSymbol; +import io.ballerina.compiler.api.symbols.Symbol; +import io.ballerina.compiler.api.symbols.SymbolKind; +import io.ballerina.compiler.syntax.tree.AnnotationNode; +import io.ballerina.compiler.syntax.tree.ExpressionNode; +import io.ballerina.compiler.syntax.tree.FunctionDefinitionNode; +import io.ballerina.compiler.syntax.tree.MappingFieldNode; +import io.ballerina.compiler.syntax.tree.NodeFactory; +import io.ballerina.compiler.syntax.tree.NodeLocation; +import io.ballerina.compiler.syntax.tree.NodeParser; +import io.ballerina.compiler.syntax.tree.SeparatedNodeList; +import io.ballerina.compiler.syntax.tree.SpecificFieldNode; +import io.ballerina.compiler.syntax.tree.SyntaxKind; +import io.ballerina.projects.DocumentId; +import io.ballerina.projects.plugins.AnalysisTask; +import io.ballerina.projects.plugins.SyntaxNodeAnalysisContext; +import io.ballerina.stdlib.mcp.plugin.diagnostics.CompilationDiagnostic; +import io.ballerina.tools.diagnostics.Diagnostic; +import io.ballerina.tools.diagnostics.Location; + +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.stream.Collectors; + +import static io.ballerina.stdlib.mcp.plugin.ToolAnnotationConfig.DESCRIPTION_FIELD_NAME; +import static io.ballerina.stdlib.mcp.plugin.ToolAnnotationConfig.SCHEMA_FIELD_NAME; +import static io.ballerina.stdlib.mcp.plugin.Utils.getToolAnnotationNode; +import static io.ballerina.stdlib.mcp.plugin.diagnostics.CompilationDiagnostic.UNABLE_TO_GENERATE_SCHEMA_FOR_FUNCTION; + +/** + * Analysis task that processes remote function definitions to generate MCP tool annotation configurations. + * + *

This task analyzes function signatures, extracts parameter schemas, and creates tool annotation + * configurations that will be used by {@link McpSourceModifier} to update source code.

+ */ +public class RemoteFunctionAnalysisTask implements AnalysisTask { + public static final String EMPTY_STRING = ""; + public static final String NIL_EXPRESSION = "()"; + + private final Map modifierContextMap; + private SyntaxNodeAnalysisContext context; + + /** + * Creates a new analysis task with the given modifier context map. + * + * @param modifierContextMap map to store analysis results for each document + */ + RemoteFunctionAnalysisTask(Map modifierContextMap) { + this.modifierContextMap = modifierContextMap; + } + + /** + * Performs analysis on a function definition node to extract tool annotation information. + * + * @param context the syntax node analysis context containing the function definition + */ + @Override + public void perform(SyntaxNodeAnalysisContext context) { + this.context = context; + + FunctionDefinitionNode functionDefinitionNode = (FunctionDefinitionNode) context.node(); + AnnotationNode toolAnnotationNode = getToolAnnotationNode( + context.semanticModel(), functionDefinitionNode + ).orElse(null); + + NodeLocation functionNodeLocation = functionDefinitionNode.location(); + Optional functionSymbol = getFunctionSymbol(functionDefinitionNode); + if (functionSymbol.isEmpty()) { + return; + } + ToolAnnotationConfig config = createAnnotationConfig(functionSymbol.get(), functionNodeLocation, + toolAnnotationNode); + addToModifierContext(context, functionDefinitionNode, config); + } + + private ToolAnnotationConfig createAnnotationConfig(FunctionSymbol functionSymbol, + NodeLocation functionNodeLocation, + AnnotationNode annotationNode) { + String functionName = functionSymbol.getName().orElse("unknownFunction"); + String description = Utils.addDoubleQuotes( + Utils.escapeDoubleQuotes( + Objects.requireNonNullElse(Utils.getDescription(functionSymbol), functionName))); + if (annotationNode == null) { + String schema = getParameterSchema(functionSymbol, functionNodeLocation); + return new ToolAnnotationConfig(description, schema); + } + SeparatedNodeList fields = annotationNode.annotValue().isEmpty() ? + NodeFactory.createSeparatedNodeList() : annotationNode.annotValue().get().fields(); + Map fieldValues = extractFieldValues(fields); + if (fieldValues.containsKey(DESCRIPTION_FIELD_NAME)) { + description = fieldValues.get(DESCRIPTION_FIELD_NAME).toSourceCode(); + } + String parameters = fieldValues.containsKey(SCHEMA_FIELD_NAME) + ? fieldValues.get(SCHEMA_FIELD_NAME).toSourceCode() + : getParameterSchema(functionSymbol, functionNodeLocation); + return new ToolAnnotationConfig(description, parameters); + } + + private Optional getFunctionSymbol(FunctionDefinitionNode functionDefinitionNode) { + Optional functionSymbol = context.semanticModel().symbol(functionDefinitionNode); + return functionSymbol.filter(symbol -> symbol.kind() == SymbolKind.FUNCTION + || symbol.kind() == SymbolKind.METHOD).map(FunctionSymbol.class::cast); + } + + private Map extractFieldValues(SeparatedNodeList fields) { + return fields.stream() + .filter(field -> field.kind() == SyntaxKind.SPECIFIC_FIELD) + .map(field -> (SpecificFieldNode) field) + .filter(field -> field.valueExpr().isPresent()) + .collect(Collectors.toMap( + field -> field.fieldName().toSourceCode().trim(), + field -> field.valueExpr().orElse(NodeParser.parseExpression(NIL_EXPRESSION)) + )); + } + + private String getParameterSchema(FunctionSymbol functionSymbol, Location alternativeFunctionLocation) { + try { + return SchemaUtils.getParameterSchema(functionSymbol, this.context); + } catch (Exception e) { + Diagnostic diagnostic = CompilationDiagnostic.getDiagnostic(UNABLE_TO_GENERATE_SCHEMA_FOR_FUNCTION, + functionSymbol.getLocation().orElse(alternativeFunctionLocation), + functionSymbol.getName().orElse("unknownFunction")); + reportDiagnostic(diagnostic); + return NIL_EXPRESSION; + } + } + + private void reportDiagnostic(Diagnostic diagnostic) { + this.context.reportDiagnostic(diagnostic); + } + + private void addToModifierContext(SyntaxNodeAnalysisContext context, FunctionDefinitionNode functionDefinitionNode, + ToolAnnotationConfig toolAnnotationConfig) { + this.modifierContextMap.computeIfAbsent(context.documentId(), document -> new ModifierContext()) + .add(functionDefinitionNode, toolAnnotationConfig); + } +} diff --git a/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/SchemaUtils.java b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/SchemaUtils.java new file mode 100644 index 0000000..9e40cae --- /dev/null +++ b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/SchemaUtils.java @@ -0,0 +1,192 @@ +/* + * Copyright (c) 2025, WSO2 LLC. (http://www.wso2.com). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.ballerina.stdlib.mcp.plugin; + +import io.ballerina.compiler.api.symbols.FunctionSymbol; +import io.ballerina.compiler.api.symbols.FunctionTypeSymbol; +import io.ballerina.compiler.api.symbols.ParameterKind; +import io.ballerina.compiler.api.symbols.ParameterSymbol; +import io.ballerina.openapi.service.mapper.type.TypeMapper; +import io.ballerina.openapi.service.mapper.type.TypeMapperImpl; +import io.ballerina.projects.plugins.SyntaxNodeAnalysisContext; +import io.swagger.v3.core.util.Json; +import io.swagger.v3.core.util.OpenAPISchema2JsonSchema; +import io.swagger.v3.oas.models.media.Schema; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import static io.ballerina.stdlib.mcp.plugin.RemoteFunctionAnalysisTask.EMPTY_STRING; +import static io.ballerina.stdlib.mcp.plugin.RemoteFunctionAnalysisTask.NIL_EXPRESSION; + +/** + * Utility class for generating and manipulating function tool parameter schemas. + */ +public class SchemaUtils { + private static final String STRING = "string"; + private static final String BYTE = "byte"; + private static final String NUMBER = "number"; + + private SchemaUtils() { + } + + public static String getParameterSchema(FunctionSymbol functionSymbol, SyntaxNodeAnalysisContext context) + throws Exception { + FunctionTypeSymbol functionTypeSymbol = functionSymbol.typeDescriptor(); + List parameterSymbolList = functionTypeSymbol.params().get(); + if (functionTypeSymbol.params().isEmpty() || parameterSymbolList.isEmpty()) { + return NIL_EXPRESSION; + } + + Map individualParamSchema = new HashMap<>(); + List requiredParams = new ArrayList<>(); + TypeMapper typeMapper = new TypeMapperImpl(context); + for (ParameterSymbol parameterSymbol : parameterSymbolList) { + try { + String parameterName = parameterSymbol.getName().orElseThrow(); + if (parameterSymbol.paramKind() != ParameterKind.DEFAULTABLE) { + requiredParams.add(parameterName); + } + @SuppressWarnings("rawtypes") + Schema schema = typeMapper.getSchema(parameterSymbol.typeDescriptor()); + String parameterDescription = Utils.getParameterDescription(functionSymbol, parameterName); + schema.setDescription(parameterDescription); + String jsonSchema = SchemaUtils.getJsonSchema(schema); + individualParamSchema.put(parameterName, jsonSchema); + } catch (RuntimeException e) { + throw new Exception(e); + } + } + String properties = individualParamSchema.entrySet().stream() + .map(entry -> String.format("\"%s\": %s", entry.getKey(), entry.getValue())) + .collect(Collectors.joining(", ", "{", "}")); + + String required = requiredParams.stream() + .map(paramName -> String.format("\"%s\"", paramName)) + .collect(Collectors.joining(", ", "[", "]")); + return String.format("{\"type\":\"object\",\"required\":%s,\"properties\":%s}", + required, properties); + } + + @SuppressWarnings("rawtypes") + private static String getJsonSchema(Schema schema) { + modifySchema(schema); + OpenAPISchema2JsonSchema openAPISchema2JsonSchema = new OpenAPISchema2JsonSchema(); + openAPISchema2JsonSchema.process(schema); + String newLineRegex = "\\R"; + String jsonCompressionRegex = "\\s*([{}\\[\\]:,])\\s*"; + return Json.pretty(schema.getJsonSchema()) + .replaceAll(newLineRegex, EMPTY_STRING) + .replaceAll(jsonCompressionRegex, "$1"); + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + private static void modifySchema(Schema schema) { + if (schema == null) { + return; + } + modifySchema(schema.getItems()); + modifySchema(schema.getNot()); + + Map properties = schema.getProperties(); + if (properties != null) { + properties.values().forEach(SchemaUtils::modifySchema); + } + + List allOf = schema.getAllOf(); + if (allOf != null) { + schema.setType(null); + allOf.forEach(SchemaUtils::modifySchema); + } + + List anyOf = schema.getAnyOf(); + if (anyOf != null) { + schema.setType(null); + anyOf.forEach(SchemaUtils::modifySchema); + } + + List oneOf = schema.getOneOf(); + if (oneOf != null) { + schema.setType(null); + oneOf.forEach(SchemaUtils::modifySchema); + } + + // Override default ballerina byte to json schema mapping + if (BYTE.equals(schema.getFormat()) && STRING.equals(schema.getType())) { + schema.setFormat(null); + schema.setType(NUMBER); + } + removeUnwantedFields(schema); + } + + @SuppressWarnings({"rawtypes", "unchecked"}) + private static void removeUnwantedFields(Schema schema) { + schema.setSpecVersion(null); + schema.setSpecVersion(null); + schema.setContains(null); + schema.set$id(null); + schema.set$schema(null); + schema.set$anchor(null); + schema.setExclusiveMaximumValue(null); + schema.setExclusiveMinimumValue(null); + schema.setDiscriminator(null); + schema.setTitle(null); + schema.setMaximum(null); + schema.setExclusiveMaximum(null); + schema.setMinimum(null); + schema.setExclusiveMinimum(null); + schema.setMaxLength(null); + schema.setMinLength(null); + schema.setMaxItems(null); + schema.setMinItems(null); + schema.setMaxProperties(null); + schema.setMinProperties(null); + schema.setAdditionalProperties(null); + schema.setAdditionalProperties(null); + schema.set$ref(null); + schema.set$ref(null); + schema.setReadOnly(null); + schema.setWriteOnly(null); + schema.setExample(null); + schema.setExample(null); + schema.setExternalDocs(null); + schema.setDeprecated(null); + schema.setPrefixItems(null); + schema.setContentEncoding(null); + schema.setContentMediaType(null); + schema.setContentSchema(null); + schema.setPropertyNames(null); + schema.setUnevaluatedProperties(null); + schema.setMaxContains(null); + schema.setMinContains(null); + schema.setAdditionalItems(null); + schema.setUnevaluatedItems(null); + schema.setIf(null); + schema.setElse(null); + schema.setThen(null); + schema.setDependentSchemas(null); + schema.set$comment(null); + schema.setExamples(null); + schema.setExtensions(null); + schema.setConst(null); + } +} diff --git a/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/Utils.java b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/Utils.java new file mode 100644 index 0000000..d1542bb --- /dev/null +++ b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/Utils.java @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2025, WSO2 LLC. (http://www.wso2.com). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.ballerina.stdlib.mcp.plugin; + +import io.ballerina.compiler.api.SemanticModel; +import io.ballerina.compiler.api.symbols.AnnotationSymbol; +import io.ballerina.compiler.api.symbols.Documentable; +import io.ballerina.compiler.api.symbols.FunctionSymbol; +import io.ballerina.compiler.api.symbols.Symbol; +import io.ballerina.compiler.api.symbols.SymbolKind; +import io.ballerina.compiler.syntax.tree.AnnotationNode; +import io.ballerina.compiler.syntax.tree.FunctionDefinitionNode; +import io.ballerina.compiler.syntax.tree.MetadataNode; +import io.ballerina.compiler.syntax.tree.NodeList; + +import java.util.Optional; + +/** + * Util class for the compiler plugin. + */ +public class Utils { + public static final String BALLERINA_ORG = "ballerina"; + public static final String TOOL_ANNOTATION_NAME = "Tool"; + public static final String MCP_PACKAGE_NAME = "mcp"; + + private Utils() { + } + + public static boolean isMcpToolAnnotation(AnnotationSymbol annotationSymbol) { + return annotationSymbol.getModule().isPresent() + && isMcpModuleSymbol(annotationSymbol.getModule().get()) + && annotationSymbol.getName().isPresent() + && TOOL_ANNOTATION_NAME.equals(annotationSymbol.getName().get()); + } + + public static boolean isMcpModuleSymbol(Symbol symbol) { + return symbol.getModule().isPresent() + && MCP_PACKAGE_NAME.equals(symbol.getModule().get().id().moduleName()) + && BALLERINA_ORG.equals(symbol.getModule().get().id().orgName()); + } + + public static String getParameterDescription(FunctionSymbol functionSymbol, String parameterName) { + if (functionSymbol.documentation().isEmpty() + || functionSymbol.documentation().get().description().isEmpty()) { + return null; + } + return functionSymbol.documentation().get().parameterMap().getOrDefault(parameterName, null); + } + + public static String getDescription(Documentable documentable) { + if (documentable.documentation().isEmpty() + || documentable.documentation().get().description().isEmpty()) { + return null; + } + return documentable.documentation().get().description().get(); + } + + public static String escapeDoubleQuotes(String input) { + return input.replace("\"", "\\\""); + } + + + public static String addDoubleQuotes(String input) { + return "\"" + input + "\""; + } + + public static Optional getToolAnnotationNode(SemanticModel semanticModel, + FunctionDefinitionNode functionDefinitionNode) { + Optional metadataNode = functionDefinitionNode.metadata(); + if (metadataNode.isEmpty()) { + return Optional.empty(); + } + + NodeList annotationNodes = metadataNode.get().annotations(); + return annotationNodes.stream() + .filter(annotationNode -> + semanticModel.symbol(annotationNode) + .filter(symbol -> symbol.kind() == SymbolKind.ANNOTATION) + .filter(symbol -> Utils.isMcpToolAnnotation((AnnotationSymbol) symbol)) + .isPresent() + ) + .findFirst(); + } +} diff --git a/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/diagnostics/CompilationDiagnostic.java b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/diagnostics/CompilationDiagnostic.java new file mode 100644 index 0000000..1ab65e6 --- /dev/null +++ b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/diagnostics/CompilationDiagnostic.java @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2025, WSO2 LLC. (http://www.wso2.com). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.ballerina.stdlib.mcp.plugin.diagnostics; + +import io.ballerina.tools.diagnostics.Diagnostic; +import io.ballerina.tools.diagnostics.DiagnosticFactory; +import io.ballerina.tools.diagnostics.DiagnosticInfo; +import io.ballerina.tools.diagnostics.DiagnosticSeverity; +import io.ballerina.tools.diagnostics.Location; + +import static io.ballerina.tools.diagnostics.DiagnosticSeverity.ERROR; + +/** + * Compilation errors in the Ballerina mcp package. + */ +public enum CompilationDiagnostic { + UNABLE_TO_GENERATE_SCHEMA_FOR_FUNCTION(DiagnosticMessage.ERROR_101, DiagnosticCode.MCP_101, ERROR); + + private final String diagnostic; + private final String diagnosticCode; + private final DiagnosticSeverity diagnosticSeverity; + + CompilationDiagnostic(DiagnosticMessage message, DiagnosticCode diagnosticCode, + DiagnosticSeverity diagnosticSeverity) { + this.diagnostic = message.getMessage(); + this.diagnosticCode = diagnosticCode.name(); + this.diagnosticSeverity = diagnosticSeverity; + } + + public static Diagnostic getDiagnostic(CompilationDiagnostic compilationDiagnostic, Location location, + Object... args) { + DiagnosticInfo diagnosticInfo = new DiagnosticInfo( + compilationDiagnostic.getDiagnosticCode(), + compilationDiagnostic.getDiagnostic(), + compilationDiagnostic.getDiagnosticSeverity()); + return DiagnosticFactory.createDiagnostic(diagnosticInfo, location, args); + } + + public String getDiagnostic() { + return diagnostic; + } + + public String getDiagnosticCode() { + return diagnosticCode; + } + + public DiagnosticSeverity getDiagnosticSeverity() { + return this.diagnosticSeverity; + } +} diff --git a/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/diagnostics/DiagnosticCode.java b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/diagnostics/DiagnosticCode.java new file mode 100644 index 0000000..9635a96 --- /dev/null +++ b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/diagnostics/DiagnosticCode.java @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2025, WSO2 LLC. (http://www.wso2.com). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.ballerina.stdlib.mcp.plugin.diagnostics; + +/** + * Compilation error codes used in Ballerina mcp package compiler plugin. + */ +public enum DiagnosticCode { + MCP_101 +} diff --git a/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/diagnostics/DiagnosticMessage.java b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/diagnostics/DiagnosticMessage.java new file mode 100644 index 0000000..7a0fc5e --- /dev/null +++ b/compiler-plugin/src/main/java/io/ballerina/stdlib/mcp/plugin/diagnostics/DiagnosticMessage.java @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2025, WSO2 LLC. (http://www.wso2.com). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.ballerina.stdlib.mcp.plugin.diagnostics; + +/** + * Compilation error messages used in Ballerina mcp package compiler plugin. + */ +public enum DiagnosticMessage { + ERROR_101("failed to generate the parameter schema definition for the function ''{0}''." + + " Specify the parameter schema manually using the `@mcp:McpTool` annotation's parameter field."); + + private final String message; + + DiagnosticMessage(String message) { + this.message = message; + } + + public String getMessage() { + return this.message; + } +} diff --git a/compiler-plugin/src/main/java/module-info.java b/compiler-plugin/src/main/java/module-info.java new file mode 100644 index 0000000..f94df11 --- /dev/null +++ b/compiler-plugin/src/main/java/module-info.java @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2025 WSO2 LLC. (http://www.wso2.com). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +module io.ballerina.stdlib.ai.plugin { + requires io.ballerina.lang; + requires io.ballerina.parser; + requires io.ballerina.tools.api; + requires io.ballerina.openapi.service; + requires io.swagger.v3.core; + requires io.swagger.v3.oas.models; +} diff --git a/gradle.properties b/gradle.properties index 4f28b41..e707a3f 100644 --- a/gradle.properties +++ b/gradle.properties @@ -4,6 +4,7 @@ version=0.4.3-SNAPSHOT ballerinaLangVersion=2201.12.0 ballerinaGradlePluginVersion=2.3.0 +ballerinaToOpenApiVersion=2.3.0 checkstylePluginVersion=10.12.0 spotbugsPluginVersion=6.0.18 @@ -11,7 +12,7 @@ spotbugsPluginVersion=6.0.18 shadowJarPluginVersion=8.1.1 downloadPluginVersion=5.4.0 releasePluginVersion=2.8.0 - +swaggerVersion=2.2.9 # Ballerina Library Dependencies # Level 01 diff --git a/native/src/main/java/io/ballerina/stdlib/mcp/McpServiceMethodHelper.java b/native/src/main/java/io/ballerina/stdlib/mcp/McpServiceMethodHelper.java new file mode 100644 index 0000000..8094af4 --- /dev/null +++ b/native/src/main/java/io/ballerina/stdlib/mcp/McpServiceMethodHelper.java @@ -0,0 +1,227 @@ +/* + * Copyright (c) 2025, WSO2 LLC. (http://www.wso2.com). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.ballerina.stdlib.mcp; + +import io.ballerina.runtime.api.Environment; +import io.ballerina.runtime.api.creators.ValueCreator; +import io.ballerina.runtime.api.types.ArrayType; +import io.ballerina.runtime.api.types.Parameter; +import io.ballerina.runtime.api.types.RecordType; +import io.ballerina.runtime.api.types.ReferenceType; +import io.ballerina.runtime.api.types.RemoteMethodType; +import io.ballerina.runtime.api.types.ServiceType; +import io.ballerina.runtime.api.types.Type; +import io.ballerina.runtime.api.types.UnionType; +import io.ballerina.runtime.api.values.BArray; +import io.ballerina.runtime.api.values.BMap; +import io.ballerina.runtime.api.values.BObject; +import io.ballerina.runtime.api.values.BString; +import io.ballerina.runtime.api.values.BTypedesc; + +import java.util.List; +import java.util.Optional; + +import static io.ballerina.runtime.api.utils.StringUtils.fromString; + +/** + * Utility class for invoking MCP service remote methods from Java via Ballerina interop. + *

+ * Not instantiable. + */ +public final class McpServiceMethodHelper { + + private static final String TOOLS_FIELD_NAME = "tools"; + private static final String NAME_FIELD_NAME = "name"; + private static final String DESCRIPTION_FIELD_NAME = "description"; + private static final String SCHEMA_FIELD_NAME = "schema"; + private static final String INPUT_SCHEMA_FIELD_NAME = "inputSchema"; + private static final String ARGUMENTS_FIELD_NAME = "arguments"; + private static final String CONTENT_FIELD_NAME = "content"; + private static final String TYPE_FIELD_NAME = "type"; + private static final String TEXT_FIELD_NAME = "text"; + + private static final String ANNOTATION_MCP_TOOL = "Tool"; + private static final String TYPE_TEXT_CONTENT = "TextContent"; + private static final String TEXT_VALUE_NAME = "text"; + private static final String MCP_SERVICE_FIELD = "mcpService"; + + private McpServiceMethodHelper() {} + + /** + * Invoke the 'onListTools' remote method on the given MCP service object. + * + * @param env The Ballerina runtime environment. + * @param mcpService The MCP service object. + * @return Result of remote method invocation. + */ + public static Object invokeOnListTools(Environment env, BObject mcpService) { + return env.getRuntime().callMethod(mcpService, "onListTools", null); + } + + /** + * Invoke the 'onCallTool' remote method on the given MCP service object with parameters. + * + * @param env The Ballerina runtime environment. + * @param mcpService The MCP service object. + * @param params Parameters for the tool invocation. + * @return Result of remote method invocation. + */ + public static Object invokeOnCallTool(Environment env, BObject mcpService, BMap params) { + return env.getRuntime().callMethod(mcpService, "onCallTool", null, params); + } + + /** + * Lists tool metadata for remote functions in the given MCP service. + * + * @param mcpService The MCP service object. + * @param typed The type descriptor for the result. + * @return Record containing the list of tools. + */ + public static Object listToolsForRemoteFunctions(BObject mcpService, BTypedesc typed) { + RecordType resultRecordType = (RecordType) typed.getDescribingType(); + BMap result = ValueCreator.createRecordValue(resultRecordType); + + ArrayType toolsArrayType = (ArrayType) resultRecordType.getFields().get(TOOLS_FIELD_NAME).getFieldType(); + BArray tools = ValueCreator.createArrayValue(toolsArrayType); + + for (RemoteMethodType remoteMethod : getRemoteMethods(mcpService)) { + remoteMethod.getAnnotations().entrySet().stream() + .filter(e -> e.getKey().getValue().contains(ANNOTATION_MCP_TOOL)) + .findFirst() + .ifPresent(annotation -> tools.append( + createToolRecord(toolsArrayType, remoteMethod, (BMap) annotation.getValue()) + )); + } + result.put(fromString(TOOLS_FIELD_NAME), tools); + return result; + } + + /** + * Invokes a remote function (tool) by name with arguments. + * + * @param env The Ballerina runtime environment. + * @param mcpService The MCP service object. + * @param params The parameters for the tool invocation. + * @param typed The type descriptor for the result. + * @return Record containing the invocation result or an error. + */ + public static Object callToolForRemoteFunctions(Environment env, BObject mcpService, BMap params, + BTypedesc typed) { + BString toolName = (BString) params.get(fromString(NAME_FIELD_NAME)); + + Optional method = getRemoteMethods(mcpService).stream() + .filter(rmt -> rmt.getName().equals(toolName.getValue())) + .findFirst(); + + if (method.isEmpty()) { + return ModuleUtils + .createError("RemoteMethodType with name '" + toolName.getValue() + "' not found"); + } + + Object[] args = buildArgsForMethod(method.get(), (BMap) params.get(fromString(ARGUMENTS_FIELD_NAME))); + Object result = env.getRuntime().callMethod(mcpService, toolName.getValue(), null, args); + + return createCallToolResult(typed, result); + } + + /** + * Adds an MCP service to the dispatcher service by storing it in a private field. + * + * @param dispatcherService The dispatcher service object. + * @param mcpService The MCP service object to store. + * @return null if successful, error otherwise. + */ + public static Object addMcpServiceToDispatcher(BObject dispatcherService, BObject mcpService) { + try { + dispatcherService.addNativeData(MCP_SERVICE_FIELD, mcpService); + return null; + } catch (Exception e) { + return ModuleUtils.createError("Failed to add MCP service to dispatcher: " + e.getMessage()); + } + } + + /** + * Retrieves the MCP service from the dispatcher service. + * + * @param dispatcherService The dispatcher service object. + * @return The MCP service object or an error if not found. + */ + public static Object getMcpServiceFromDispatcher(BObject dispatcherService) { + try { + Object mcpService = dispatcherService.getNativeData(MCP_SERVICE_FIELD); + if (mcpService == null) { + return ModuleUtils.createError("MCP service not found in dispatcher"); + } + return mcpService; + } catch (Exception e) { + return ModuleUtils.createError("Failed to get MCP service from dispatcher: " + e.getMessage()); + } + } + + private static List getRemoteMethods(BObject mcpService) { + ServiceType serviceType = (ServiceType) mcpService.getOriginalType(); + return List.of(serviceType.getRemoteMethods()); + } + + private static BMap createToolRecord(ArrayType toolsArrayType, RemoteMethodType remoteMethod, + BMap annotationValue) { + RecordType toolRecordType = (RecordType) ((ReferenceType) toolsArrayType.getElementType()).getReferredType(); + BMap tool = ValueCreator.createRecordValue(toolRecordType); + + tool.put(fromString(NAME_FIELD_NAME), fromString(remoteMethod.getName())); + tool.put(fromString(DESCRIPTION_FIELD_NAME), annotationValue.get(fromString(DESCRIPTION_FIELD_NAME))); + tool.put(fromString(INPUT_SCHEMA_FIELD_NAME), annotationValue.get(fromString(SCHEMA_FIELD_NAME))); + return tool; + } + + private static Object[] buildArgsForMethod(RemoteMethodType method, BMap arguments) { + List params = List.of(method.getParameters()); + Object[] args = new Object[params.size()]; + for (int i = 0; i < params.size(); i++) { + String paramName = params.get(i).name; + args[i] = arguments == null ? null : arguments.get(fromString(paramName)); + } + return args; + } + + private static Object createCallToolResult(BTypedesc typed, Object result) { + RecordType resultRecordType = (RecordType) typed.getDescribingType(); + BMap callToolResult = ValueCreator.createRecordValue(resultRecordType); + + ArrayType contentArrayType = (ArrayType) resultRecordType.getFields().get(CONTENT_FIELD_NAME).getFieldType(); + BArray contentArray = ValueCreator.createArrayValue(contentArrayType); + + UnionType contentUnionType = (UnionType) contentArrayType.getElementType(); + Optional textContentTypeOpt = contentUnionType.getMemberTypes().stream() + .filter(type -> TYPE_TEXT_CONTENT.equals(type.getName())) + .findFirst(); + if (textContentTypeOpt.isEmpty()) { + return ModuleUtils + .createError("No member type named 'TextContent' found in content union type."); + } + RecordType textContentRecordType = (RecordType) ((ReferenceType) textContentTypeOpt.get()).getReferredType(); + BMap textContent = ValueCreator.createRecordValue(textContentRecordType); + textContent.put(fromString(TYPE_FIELD_NAME), fromString(TEXT_VALUE_NAME)); + textContent.put(fromString(TEXT_FIELD_NAME), fromString(result == null ? "" : result.toString())); + contentArray.append(textContent); + + callToolResult.put(fromString(CONTENT_FIELD_NAME), contentArray); + return callToolResult; + } +} diff --git a/native/src/main/java/io/ballerina/stdlib/mcp/ModuleUtils.java b/native/src/main/java/io/ballerina/stdlib/mcp/ModuleUtils.java index ba09fe7..4c8ccda 100644 --- a/native/src/main/java/io/ballerina/stdlib/mcp/ModuleUtils.java +++ b/native/src/main/java/io/ballerina/stdlib/mcp/ModuleUtils.java @@ -20,6 +20,10 @@ import io.ballerina.runtime.api.Environment; import io.ballerina.runtime.api.Module; +import io.ballerina.runtime.api.creators.ErrorCreator; +import io.ballerina.runtime.api.values.BError; + +import static io.ballerina.runtime.api.utils.StringUtils.fromString; public final class ModuleUtils { private static Module module; @@ -36,4 +40,8 @@ public static Module getModule() { public static void setModule(Environment env) { module = env.getCurrentModule(); } + + public static BError createError(String errorMessage) { + return ErrorCreator.createError(fromString(errorMessage)); + } } diff --git a/native/src/main/java/io/ballerina/stdlib/mcp/SseEventStreamHelper.java b/native/src/main/java/io/ballerina/stdlib/mcp/SseEventStreamHelper.java index b733df1..bbfd98c 100644 --- a/native/src/main/java/io/ballerina/stdlib/mcp/SseEventStreamHelper.java +++ b/native/src/main/java/io/ballerina/stdlib/mcp/SseEventStreamHelper.java @@ -19,12 +19,8 @@ package io.ballerina.stdlib.mcp; import io.ballerina.runtime.api.Environment; -import io.ballerina.runtime.api.creators.ErrorCreator; import io.ballerina.runtime.api.values.BObject; import io.ballerina.runtime.api.values.BStream; -import io.ballerina.runtime.api.values.BString; - -import static io.ballerina.runtime.api.utils.StringUtils.fromString; /** * Utility class for handling Server-Sent Events (SSE) streams in Ballerina via Java interop. @@ -68,8 +64,7 @@ public static void attachSseStream(BObject object, BStream sseStream) { public static Object getNextSseEvent(Environment env, BObject object) { BStream sseStream = (BStream) object.getNativeData(SSE_STREAM_NATIVE_KEY); if (sseStream == null) { - BString errorMessage = fromString("Unable to obtain elements from stream. SSE stream not found."); - return ErrorCreator.createError(errorMessage); + return ModuleUtils.createError("Unable to obtain elements from stream. SSE stream not found."); } BObject iteratorObject = sseStream.getIteratorObj(); // Use the Ballerina runtime to call the "next" method on the iterator and fetch the next event. @@ -88,8 +83,7 @@ public static Object getNextSseEvent(Environment env, BObject object) { public static Object closeSseEventStream(Environment env, BObject object) { BStream sseStream = (BStream) object.getNativeData(SSE_STREAM_NATIVE_KEY); if (sseStream == null) { - BString errorMessage = fromString("Unable to obtain elements from stream. SSE stream not found."); - return ErrorCreator.createError(errorMessage); + return ModuleUtils.createError("Unable to obtain elements from stream. SSE stream not found."); } BObject iteratorObject = sseStream.getIteratorObj(); // Use the Ballerina runtime to call the "close" method on the iterator and release resources. diff --git a/settings.gradle b/settings.gradle index a00d7fb..c5cfee4 100644 --- a/settings.gradle +++ b/settings.gradle @@ -37,10 +37,12 @@ rootProject.name = 'mcp' include ':checkstyle' include ':mcp-ballerina' include ':mcp-native' +include ':mcp-compiler-plugin' project(':checkstyle').projectDir = file("build-config${File.separator}checkstyle") project(':mcp-ballerina').projectDir = file('ballerina') project(':mcp-native').projectDir = file('native') +project(':mcp-compiler-plugin').projectDir = file('compiler-plugin') gradleEnterprise { buildScan { diff --git a/spotbugs-exclude.xml b/spotbugs-exclude.xml index 58da414..1d22b46 100644 --- a/spotbugs-exclude.xml +++ b/spotbugs-exclude.xml @@ -16,4 +16,8 @@ ~ under the License. --> + + + +