diff --git a/ballerina/CompilerPlugin.toml b/ballerina/CompilerPlugin.toml index 35c0f6193..fc5ed938f 100644 --- a/ballerina/CompilerPlugin.toml +++ b/ballerina/CompilerPlugin.toml @@ -4,3 +4,9 @@ class = "io.ballerina.stdlib.websocket.plugin.WebSocketCompilerPlugin" [[dependency]] path = "../compiler-plugin/build/libs/websocket-compiler-plugin-2.15.0-SNAPSHOT.jar" + +[[dependency]] +path = "../native/build/libs/websocket-native-2.15.0-SNAPSHOT.jar" + +[[dependency]] +path = "./lib/http-native-2.14.0.jar" diff --git a/ballerina/annotation.bal b/ballerina/annotation.bal index 1d11832de..c1406f8b0 100644 --- a/ballerina/annotation.bal +++ b/ballerina/annotation.bal @@ -14,10 +14,6 @@ // specific language governing permissions and limitations // under the License. -/////////////////////////// -/// Service Annotations /// -/////////////////////////// - # Configurations for a WebSocket service. # # + subProtocols - Negotiable sub protocol by the service @@ -48,3 +44,13 @@ public type WSServiceConfig record {| # The annotation which is used to configure a WebSocket service. public annotation WSServiceConfig ServiceConfig on service; + +# Configurations used to define dispatching rules for remote functions. +# +# + value - The value which is going to be used for dispatching to custom remote functions. +public type WsDispatcherMapping record {| + string value; +|}; + +# The annotation which is used to configure the dispatching rules for WebSocket remote functions. +public const annotation WsDispatcherMapping DispatcherMapping on function; diff --git a/ballerina/build.gradle b/ballerina/build.gradle index fac0dc139..6f3c8ed43 100644 --- a/ballerina/build.gradle +++ b/ballerina/build.gradle @@ -122,6 +122,7 @@ task updateTomlFiles { ballerinaTomlFile.text = newConfig def newPluginConfig = compilerPluginTomlFilePlaceHolder.text.replace("@project.version@", project.version) + newPluginConfig = newPluginConfig.replace("@stdlib.httpnative.version@", stdlibDependentHttpNativeVersion) compilerPluginTomlFile.text = newPluginConfig } } diff --git a/ballerina/tests/test_dispatcher_mapping_annotation.bal b/ballerina/tests/test_dispatcher_mapping_annotation.bal new file mode 100644 index 000000000..f31b37496 --- /dev/null +++ b/ballerina/tests/test_dispatcher_mapping_annotation.bal @@ -0,0 +1,66 @@ +// Copyright (c) 2025, WSO2 LLC. (http://www.wso2.com). +// +// WSO2 LLC. licenses this file to you under the Apache License, +// Version 2.0 (the "License"); you may not use this file except +// in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +import ballerina/test; + +type Subscribe record {| + string event = "subscribe"; + string data; +|}; + +@ServiceConfig { + dispatcherKey: "event" +} +service / on new Listener(22103) { + resource function get .() returns Service|UpgradeError { + return new WsService22103(); + } +} + +service class WsService22103 { + *Service; + + @DispatcherMapping { + value: "subscribe" + } + remote function onSubscribeMessage(Subscribe message) returns string { + return "onSubscribeMessage"; + } + + remote function onSubscribeMessageError(Caller caller, error message) returns error? { + check caller->writeMessage("onSubscribeMessageError"); + } +} + +@test:Config { + groups: ["dispatcherMappingAnnotation"] +} +public function testDispatcherMappingAnnotation() returns error? { + Client wsClient = check new ("ws://localhost:22103/"); + check wsClient->writeMessage({event: "subscribe", data: "test"}); + string res = check wsClient->readMessage(); + test:assertEquals(res, "onSubscribeMessage"); +} + +@test:Config { + groups: ["dispatcherMappingAnnotation"] +} +public function testDispatcherMappingAnnotationWithCustomOnError() returns error? { + Client wsClient = check new ("ws://localhost:22103/"); + check wsClient->writeMessage({event: "subscribe", invalidField: "test"}); + string res = check wsClient->readMessage(); + test:assertEquals(res, "onSubscribeMessageError"); +} diff --git a/build-config/resources/CompilerPlugin.toml b/build-config/resources/CompilerPlugin.toml index b5a86c280..4b06f7c63 100644 --- a/build-config/resources/CompilerPlugin.toml +++ b/build-config/resources/CompilerPlugin.toml @@ -4,3 +4,9 @@ class = "io.ballerina.stdlib.websocket.plugin.WebSocketCompilerPlugin" [[dependency]] path = "../compiler-plugin/build/libs/websocket-compiler-plugin-@project.version@.jar" + +[[dependency]] +path = "../native/build/libs/websocket-native-@project.version@.jar" + +[[dependency]] +path = "./lib/http-native-@stdlib.httpnative.version@.jar" diff --git a/changelog.md b/changelog.md index 83e095be5..18926aecb 100644 --- a/changelog.md +++ b/changelog.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Added +- [Support Custom Remote Function Mapping via Annotation](https://github.com/ballerina-platform/ballerina-library/issues/7733) - [Introduce Service Config Annotation for connectionClosureTimeout in Websocket module](https://github.com/ballerina-platform/ballerina-library/issues/7697) - [Implement websocket close frame support](https://github.com/ballerina-platform/ballerina-library/issues/7578) diff --git a/compiler-plugin-tests/src/test/java/io/ballerina/stdlib/websocket/compiler/WebSocketServiceValidationTest.java b/compiler-plugin-tests/src/test/java/io/ballerina/stdlib/websocket/compiler/WebSocketServiceValidationTest.java index 1c50f48e7..ce1a6c964 100644 --- a/compiler-plugin-tests/src/test/java/io/ballerina/stdlib/websocket/compiler/WebSocketServiceValidationTest.java +++ b/compiler-plugin-tests/src/test/java/io/ballerina/stdlib/websocket/compiler/WebSocketServiceValidationTest.java @@ -572,6 +572,20 @@ public void testRemoteFunctionWithStreamAndCloseFrameReturnTypes() { Assert.assertEquals(diagnosticResult.errorCount(), 0); } + @Test + public void testDispatcherMappingAnnotation() { + Package currentPackage = loadPackage("sample_package_63"); + PackageCompilation compilation = currentPackage.getCompilation(); + DiagnosticResult diagnosticResult = compilation.diagnosticResult(); + Assert.assertEquals(diagnosticResult.errorCount(), 3); + Diagnostic firstDiagnostic = (Diagnostic) diagnosticResult.errors().toArray()[0]; + assertDiagnostic(firstDiagnostic, PluginConstants.CompilationErrors.RE_DECLARED_REMOTE_FUNCTIONS); + Diagnostic secondDiagnostic = (Diagnostic) diagnosticResult.errors().toArray()[1]; + assertDiagnostic(secondDiagnostic, PluginConstants.CompilationErrors.DUPLICATED_DISPATCHER_MAPPING_VALUE); + Diagnostic thirdDiagnostic = (Diagnostic) diagnosticResult.errors().toArray()[2]; + assertDiagnostic(thirdDiagnostic, PluginConstants.CompilationErrors.INVALID_FUNCTION_ANNOTATION); + } + @Test public void testConnectionClosureTimeoutInTheServiceConfig() { Package currentPackage = loadPackage("sample_package_64"); diff --git a/compiler-plugin-tests/src/test/resources/ballerina_sources/sample_package_63/Ballerina.toml b/compiler-plugin-tests/src/test/resources/ballerina_sources/sample_package_63/Ballerina.toml new file mode 100644 index 000000000..5cd60fa22 --- /dev/null +++ b/compiler-plugin-tests/src/test/resources/ballerina_sources/sample_package_63/Ballerina.toml @@ -0,0 +1,4 @@ +[package] +org = "websocket_test" +name = "sample_63" +version = "0.1.0" diff --git a/compiler-plugin-tests/src/test/resources/ballerina_sources/sample_package_63/server.bal b/compiler-plugin-tests/src/test/resources/ballerina_sources/sample_package_63/server.bal new file mode 100644 index 000000000..3f528a2de --- /dev/null +++ b/compiler-plugin-tests/src/test/resources/ballerina_sources/sample_package_63/server.bal @@ -0,0 +1,60 @@ +// Copyright (c) 2025, WSO2 LLC. (http://www.wso2.com). +// +// WSO2 LLC. licenses this file to you under the Apache License, +// Version 2.0 (the "License"); you may not use this file except +// in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +import ballerina/websocket as ws; + +type Subscribe record {| + string event = "subscribe"; + string data; +|}; + +@ws:ServiceConfig { + dispatcherKey: "event" +} +service / on new ws:Listener(9090) { + resource function get .() returns ws:Service|ws:UpgradeError { + return new WsService(); + } +} + +service class WsService { + *ws:Service; + + remote function onSubscribe(Subscribe message) returns string { + return "onSubscribe"; + } + + @ws:DispatcherMapping { + value: "subscribe" + } + remote function onSubscribeMessage(Subscribe message) returns string { + return "onSubscribeMessage"; + } + + @ws:DispatcherMapping { + value: "subscribe" + } + remote function onSubscribeText(Subscribe message) returns string { + return "onSubscribeText"; + } + + @ws:DispatcherMapping { + value: "ping" + } + remote function onPing(Subscribe message) returns string { + return "onPing"; + } +} diff --git a/compiler-plugin/src/main/java/io/ballerina/stdlib/websocket/plugin/PluginConstants.java b/compiler-plugin/src/main/java/io/ballerina/stdlib/websocket/plugin/PluginConstants.java index 35306477c..8e85f25a3 100644 --- a/compiler-plugin/src/main/java/io/ballerina/stdlib/websocket/plugin/PluginConstants.java +++ b/compiler-plugin/src/main/java/io/ballerina/stdlib/websocket/plugin/PluginConstants.java @@ -79,6 +79,12 @@ public enum CompilationErrors { "WEBSOCKET_215"), INVALID_REMOTE_FUNCTIONS("Cannot have `{0}` with `onMessage` remote function", "WEBSOCKET_216"), + RE_DECLARED_REMOTE_FUNCTIONS("Cannot have `{0}` because the message type `{1}` is already " + + "associated with `{2}` remote function", "WEBSOCKET_217"), + DUPLICATED_DISPATCHER_MAPPING_VALUE("DispatcherMapping annotation value `{0}` is already " + + "exists", "WEBSOCKET_218"), + INVALID_FUNCTION_ANNOTATION("Invalid annotation provided for `{0}` remote function. " + + "This annotation can only be used with the custom dispatcher functions", "WEBSOCKET_219"), INVALID_RESOURCE_ERROR("There should be only one `get` resource for the service", "WEBSOCKET_101"), MORE_THAN_ONE_RESOURCE_PARAM_ERROR("There should be only http:Request as a parameter", diff --git a/compiler-plugin/src/main/java/io/ballerina/stdlib/websocket/plugin/WebSocketServiceValidator.java b/compiler-plugin/src/main/java/io/ballerina/stdlib/websocket/plugin/WebSocketServiceValidator.java index a4fc6792e..e283061f1 100644 --- a/compiler-plugin/src/main/java/io/ballerina/stdlib/websocket/plugin/WebSocketServiceValidator.java +++ b/compiler-plugin/src/main/java/io/ballerina/stdlib/websocket/plugin/WebSocketServiceValidator.java @@ -20,30 +20,85 @@ import io.ballerina.compiler.api.symbols.FunctionTypeSymbol; import io.ballerina.compiler.api.symbols.MethodSymbol; import io.ballerina.compiler.api.symbols.Qualifier; +import io.ballerina.compiler.api.symbols.Symbol; +import io.ballerina.compiler.syntax.tree.AnnotationNode; import io.ballerina.compiler.syntax.tree.ClassDefinitionNode; +import io.ballerina.compiler.syntax.tree.ExpressionNode; import io.ballerina.compiler.syntax.tree.FunctionDefinitionNode; +import io.ballerina.compiler.syntax.tree.MappingConstructorExpressionNode; +import io.ballerina.compiler.syntax.tree.Node; import io.ballerina.compiler.syntax.tree.NodeList; +import io.ballerina.compiler.syntax.tree.SpecificFieldNode; import io.ballerina.compiler.syntax.tree.SyntaxKind; import io.ballerina.compiler.syntax.tree.Token; import io.ballerina.projects.plugins.SyntaxNodeAnalysisContext; +import io.ballerina.stdlib.websocket.WebSocketConstants; import io.ballerina.tools.diagnostics.DiagnosticFactory; import io.ballerina.tools.diagnostics.DiagnosticInfo; import io.ballerina.tools.diagnostics.DiagnosticSeverity; +import java.util.HashSet; import java.util.Map; +import java.util.Optional; +import java.util.Set; import java.util.stream.Collectors; +import static io.ballerina.stdlib.websocket.WebSocketConstants.ANNOTATION_ATTR_DISPATCHER_VALUE; +import static io.ballerina.stdlib.websocket.WebSocketResourceDispatcher.createCustomRemoteFunction; +import static io.ballerina.stdlib.websocket.plugin.PluginConstants.CompilationErrors.DUPLICATED_DISPATCHER_MAPPING_VALUE; +import static io.ballerina.stdlib.websocket.plugin.PluginConstants.CompilationErrors.INVALID_FUNCTION_ANNOTATION; +import static io.ballerina.stdlib.websocket.plugin.PluginConstants.CompilationErrors.RE_DECLARED_REMOTE_FUNCTIONS; + /** * A class for validating websocket service. */ public class WebSocketServiceValidator { public static final String GENERIC_FUNCTION = "generic function"; + private final Set specialRemoteMethods = Set.of(PluginConstants.ON_OPEN, PluginConstants.ON_CLOSE, + PluginConstants.ON_ERROR, PluginConstants.ON_IDLE_TIMEOUT, PluginConstants.ON_TEXT_MESSAGE, + PluginConstants.ON_BINARY_MESSAGE, PluginConstants.ON_MESSAGE, PluginConstants.ON_PING_MESSAGE, + PluginConstants.ON_PONG_MESSAGE); private SyntaxNodeAnalysisContext ctx; WebSocketServiceValidator(SyntaxNodeAnalysisContext syntaxNodeAnalysisContext) { this.ctx = syntaxNodeAnalysisContext; } + private static Optional getDispatcherMappingAnnotatedFunctionName(FunctionDefinitionNode node, + SyntaxNodeAnalysisContext ctx) { + if (node.metadata().isEmpty()) { + return Optional.empty(); + } + for (AnnotationNode annotationNode : node.metadata().get().annotations()) { + Optional annotationType = ctx.semanticModel().symbol(annotationNode); + if (annotationType.isEmpty()) { + continue; + } + if (!annotationType.get().getModule().flatMap(Symbol::getName) + .orElse("").equals(WebSocketConstants.PACKAGE_WEBSOCKET) || + !annotationType.get().getName().orElse("") + .equals(WebSocketConstants.WEBSOCKET_DISPATCHER_MAPPING_ANNOTATION)) { + continue; + } + if (annotationNode.annotValue().isEmpty()) { + return Optional.empty(); + } + MappingConstructorExpressionNode annotationValue = annotationNode.annotValue().get(); + for (Node field : annotationValue.fields()) { + if (!field.kind().equals(SyntaxKind.SPECIFIC_FIELD)) { + continue; + } + String fieldName = ((SpecificFieldNode) field).fieldName().toString().strip(); + Optional filedValue = ((SpecificFieldNode) field).valueExpr(); + if (!fieldName.equals(ANNOTATION_ATTR_DISPATCHER_VALUE) || filedValue.isEmpty()) { + continue; + } + return Optional.of(filedValue.get().toString().replaceAll("\"", "").strip()); + } + } + return Optional.empty(); + } + public void validate() { ClassDefinitionNode classDefNode = (ClassDefinitionNode) ctx.node(); Map functionSet = classDefNode.members().stream().filter(child -> @@ -70,7 +125,7 @@ public void validate() { classDefNode.location(), PluginConstants.ON_TEXT_MESSAGE); } if (functionSet.containsKey(PluginConstants.ON_MESSAGE) && - functionSet.containsKey(PluginConstants.ON_BINARY_MESSAGE)) { + functionSet.containsKey(PluginConstants.ON_BINARY_MESSAGE)) { Utils.reportDiagnostics(ctx, PluginConstants.CompilationErrors.INVALID_REMOTE_FUNCTIONS, classDefNode.location(), PluginConstants.ON_BINARY_MESSAGE); } @@ -105,6 +160,45 @@ public void validate() { !functionSet.containsKey(PluginConstants.ON_BINARY_MESSAGE)) { reportDiagnostic(classDefNode, PluginConstants.CompilationErrors.ON_MESSAGE_GENERATION_HINT); } + validateDispatcherMappingAnnotations(classDefNode, functionSet); + } + + private void validateDispatcherMappingAnnotations(ClassDefinitionNode classDefNode, + Map functionSet) { + Set seenAnnotationValues = new HashSet<>(); + for (Node node : classDefNode.members()) { + if (!node.kind().equals(SyntaxKind.OBJECT_METHOD_DEFINITION)) { + continue; + } + FunctionDefinitionNode funcDefinitionNode = (FunctionDefinitionNode) node; + if (funcDefinitionNode.qualifierList().stream() + .noneMatch(token -> token.text().equals(Qualifier.REMOTE.getValue()))) { + continue; + } + Optional funcName = ctx.semanticModel().symbol(funcDefinitionNode).flatMap(Symbol::getName); + Optional annoDispatchingValue = getDispatcherMappingAnnotatedFunctionName(funcDefinitionNode, ctx); + if (funcName.isEmpty() || annoDispatchingValue.isEmpty()) { + continue; + } + if (seenAnnotationValues.contains(annoDispatchingValue.get())) { + Utils.reportDiagnostics(ctx, DUPLICATED_DISPATCHER_MAPPING_VALUE, + funcDefinitionNode.location(), annoDispatchingValue.get()); + continue; + } + seenAnnotationValues.add(annoDispatchingValue.get()); + String customRemoteFunctionName = createCustomRemoteFunction(annoDispatchingValue.get()); + if (this.specialRemoteMethods.contains(funcName.get())) { + Utils.reportDiagnostics(ctx, INVALID_FUNCTION_ANNOTATION, funcDefinitionNode.location(), + funcName.get()); + continue; + } + if (functionSet.containsKey(customRemoteFunctionName) && + !customRemoteFunctionName.equals(funcName.get()) && + !this.specialRemoteMethods.contains(customRemoteFunctionName)) { + Utils.reportDiagnostics(ctx, RE_DECLARED_REMOTE_FUNCTIONS, classDefNode.location(), + customRemoteFunctionName, annoDispatchingValue.get(), funcName.get()); + } + } } private void filterRemoteFunctions(FunctionDefinitionNode functionDefinitionNode) { diff --git a/docs/spec/spec.md b/docs/spec/spec.md index f2a76f10e..3c28abd65 100644 --- a/docs/spec/spec.md +++ b/docs/spec/spec.md @@ -35,7 +35,7 @@ The conforming implementation of the specification is released and included in t * [onClose](#onclose) * [onError](#onerror) * 3.2.2. [Dispatching custom remote methods](#322-dispatching-custom-remote-methods) - * [Dispatching custom error remote methods](#Dispatching custom error remote methods) + * [Dispatching custom error remote methods](#dispatching-custom-error-remote-methods) * 3.2.3. [Return types](#323-return-types) 4. [Client](#4-client) * 4.1. [Client Configurations](#41-client-configurations) @@ -366,45 +366,80 @@ For example, if the message is `{"event": "heartbeat"}` it will get dispatched t 1. The user can configure the field name(key) to identify the messages and the allowed values as message types. -The `dispatcherKey` is used to identify the event type of the incoming message by its value. -The `dispatcherStreamId` is used to distinguish between requests and their corresponding responses in a multiplexing scenario. + The `dispatcherKey` is used to identify the event type of the incoming message by its value. The `dispatcherStreamId` is used to distinguish between requests and their corresponding responses in a multiplexing scenario. -```ballerina + ```ballerina + Ex: + incoming message = ` {"event": "heartbeat", "id": "1"}` + dispatcherKey = "event" + dispatcherStreamId = "id" + event/message type = "heartbeat" + dispatching to remote function = "onHeartbeat" -Ex: -incoming message = ` {"event": "heartbeat", "id": "1"}` -dispatcherKey = "event" -dispatcherStreamId = "id" -event/message type = "heartbeat" -dispatching to remote function = "onHeartbeat" + ```ballerina + @websocket:ServiceConfig { + dispatcherKey: "event", + dispatcherStreamId: "id" + } + service / on new websocket:Listener(9090) {} + ``` -```ballerina -@websocket:ServiceConfig { - dispatcherKey: "event", - dispatcherStreamId: "id" -} -service / on new websocket:Listener(9090) {} -``` +2. Naming of the remote function. + + * If there are spaces and underscores between message types, those will be removed and made camel case("un subscribe" -> "onUnSubscribe"). + * The 'on' word is added as the predecessor and the remote function name is in the camel case("heartbeat" -> "onHeartbeat"). + +3. Custom Dispatching with `@DispatcherMapping` annotation + + The `@DispatcherMapping` annotation allows users to explicitly define the dispatching behavior for remote functions. If an incoming message type matches the value in the annotation, the respective remote function will be invoked. + + ```ballerina + @DispatcherMapping { + value: "subscribe" + } + remote function onSubscribeMessage(Subscribe message) returns string { + return "onSubscribeMessage"; + } + ``` + + In this case, when a message of type "subscribe" is received, the `onSubscribeMessage` remote function is invoked. -##### [Dispatching custom error remote methods](#Dispatching custom error remote methods) +4. If an unmatching message type receives where a matching remote function is not implemented in the WebSocket service by the user, it gets dispatched to the default `onMessage` remote function if it is implemented. Or else it will get ignored. + +##### [Dispatching custom error remote methods](#dispatching-custom-error-remote-methods) If the user has defined a remote function with the name `customRemoteFunction` + `Error` in the WebSocket service, the error messages will get dispatched to that remote function when there is a data binding error. If that is not defined, the generic `onError` remote function gets dispatched. +* Example 1 + ```ballerina -Ex: +remote function onHeartbeat(Heartbeat message) returns error? { +} + +remote function onHeartbeatError(error message) returns error? { +} + incoming message = ` {"event": "heartbeat"}` -dispatcherKey = "event" -event/message type = "heartbeat" dispatching remote function = "onHeartbeat" dispatching error remote function = "onHeartbeatError" ``` -2. Naming of the remote function. +* Example 2 -- If there are spaces and underscores between message types, those will be removed and made camel case("un subscribe" -> "onUnSubscribe"). -- The 'on' word is added as the predecessor and the remote function name is in the camel case("heartbeat" -> "onHeartbeat"). +```ballerina +@websocket:DispatcherMapping { + value: "subscribe" +} +remote function onSubscribeMessage(Subscribe message) returns error? { +} -3. If an unmatching message type receives where a matching remote function is not implemented in the WebSocket service by the user, it gets dispatched to the default `onMessage` remote function if it is implemented. Or else it will get ignored. +remote function onSubscribeMessageError(error message) returns error? { +} + +incoming message = ` {"event": "subscribe"}` +dispatching remote function = "onSubscribeMessage" +dispatching error remote function = "onSubscribeMessageError" +``` #### 3.2.3. Return types diff --git a/native/src/main/java/io/ballerina/stdlib/websocket/WebSocketConstants.java b/native/src/main/java/io/ballerina/stdlib/websocket/WebSocketConstants.java index c9c475610..c51c9c24a 100644 --- a/native/src/main/java/io/ballerina/stdlib/websocket/WebSocketConstants.java +++ b/native/src/main/java/io/ballerina/stdlib/websocket/WebSocketConstants.java @@ -53,6 +53,9 @@ public class WebSocketConstants { public static final BString ANNOTATION_ATTR_VALIDATION_ENABLED = StringUtils.fromString("validation"); public static final BString ANNOTATION_ATTR_DISPATCHER_KEY = StringUtils.fromString("dispatcherKey"); + public static final String WEBSOCKET_DISPATCHER_MAPPING_ANNOTATION = "DispatcherMapping"; + public static final String ANNOTATION_ATTR_DISPATCHER_VALUE = "value"; + public static final BString RETRY_CONFIG = StringUtils.fromString("retryConfig"); public static final String LOG_MESSAGE = "{} {}"; public static final int STATUS_CODE_ABNORMAL_CLOSURE = 1006; diff --git a/native/src/main/java/io/ballerina/stdlib/websocket/WebSocketResourceDispatcher.java b/native/src/main/java/io/ballerina/stdlib/websocket/WebSocketResourceDispatcher.java index 6d07500c0..419db04e3 100644 --- a/native/src/main/java/io/ballerina/stdlib/websocket/WebSocketResourceDispatcher.java +++ b/native/src/main/java/io/ballerina/stdlib/websocket/WebSocketResourceDispatcher.java @@ -28,6 +28,7 @@ import io.ballerina.runtime.api.types.ObjectType; import io.ballerina.runtime.api.types.Parameter; import io.ballerina.runtime.api.types.PredefinedTypes; +import io.ballerina.runtime.api.types.RemoteMethodType; import io.ballerina.runtime.api.types.ResourceMethodType; import io.ballerina.runtime.api.types.ServiceType; import io.ballerina.runtime.api.types.Type; @@ -408,29 +409,33 @@ public static void dispatchOnText(WebSocketConnectionInfo connectionInfo, WebSoc return; } String dispatchingKey = ((WebSocketServerService) wsService).getDispatchingKey(); - Optional customRemoteMethodName = getCustomRemoteMethodName(dispatchingKey, stringAggregator); + Optional dispatchingValue = getDispatchingValue(dispatchingKey, stringAggregator); + Optional customRemoteMethodName = dispatchingValue + .map(WebSocketResourceDispatcher::createCustomRemoteFunction); MethodType onTextMessageResource = null; BObject wsEndpoint = connectionInfo.getWebSocketEndpoint(); Object dispatchingService = wsService.getWsService(connectionInfo.getWebSocketConnection().getChannelId()); - MethodType[] remoteFunctions = ((ServiceType) (((BValue) dispatchingService).getType())).getMethods(); - for (MethodType remoteFunc : remoteFunctions) { - String funcName = remoteFunc.getName(); - if (customRemoteMethodName.isPresent() && funcName.equals(customRemoteMethodName.get())) { - onTextMessageResource = remoteFunc; - break; - } - if (funcName.equals(WebSocketConstants.RESOURCE_NAME_ON_TEXT_MESSAGE) || - funcName.equals(WebSocketConstants.RESOURCE_NAME_ON_MESSAGE)) { - onTextMessageResource = remoteFunc; - } + Map dispatchingFunctions = wsService + .getDispatchingFunctions(connectionInfo.getWebSocketConnection().getChannelId()); + if (dispatchingValue.isPresent() && dispatchingFunctions.containsKey(dispatchingValue.get())) { + onTextMessageResource = dispatchingFunctions.get(dispatchingValue.get()); + } else if (customRemoteMethodName.isPresent() + && dispatchingFunctions.containsKey(customRemoteMethodName.get())) { + onTextMessageResource = dispatchingFunctions.get(customRemoteMethodName.get()); + } else if (dispatchingFunctions.containsKey(WebSocketConstants.RESOURCE_NAME_ON_TEXT_MESSAGE)) { + onTextMessageResource = dispatchingFunctions.get(WebSocketConstants.RESOURCE_NAME_ON_TEXT_MESSAGE); + } else if (dispatchingFunctions.containsKey(WebSocketConstants.RESOURCE_NAME_ON_MESSAGE)) { + onTextMessageResource = dispatchingFunctions.get(WebSocketConstants.RESOURCE_NAME_ON_MESSAGE); } - boolean hasOnError = Arrays.stream(remoteFunctions).anyMatch(remoteFunc -> remoteFunc.getName() - .equals(WebSocketConstants.RESOURCE_NAME_ON_ERROR)); + boolean hasOnError = dispatchingFunctions.containsKey(WebSocketConstants.RESOURCE_NAME_ON_ERROR); String errorMethodName = null; boolean hasOnCustomError = false; - if (customRemoteMethodName.isPresent()) { + if (onTextMessageResource != null) { + errorMethodName = onTextMessageResource.getName() + "Error"; + hasOnCustomError = dispatchingFunctions.containsKey(errorMethodName); + } else if (customRemoteMethodName.isPresent()) { errorMethodName = customRemoteMethodName.get() + "Error"; - hasOnCustomError = hasCustomErrorRemoteFunction(remoteFunctions, errorMethodName); + hasOnCustomError = dispatchingFunctions.containsKey(errorMethodName); } if (onTextMessageResource == null) { stringAggregator.resetAggregateString(); @@ -532,15 +537,6 @@ private static Object getBvaluesForTextMessage(Type param, int typeTag, BObject return bValue; } - private static boolean hasCustomErrorRemoteFunction(MethodType[] remoteFunctions, String errorMethodName) { - for (MethodType remoteFunc : remoteFunctions) { - if (remoteFunc.getName().equals(errorMethodName)) { - return true; - } - } - return false; - } - private static void handleError(WebSocketConnectionInfo connectionInfo, BError error, boolean hasOnCustomError, String errorMethodName, boolean hasOnError) throws IllegalAccessException { if (hasOnCustomError) { @@ -559,23 +555,23 @@ private static void handleDataBindingError(WebSocketConnectionInfo connectionInf } } - private static Optional getCustomRemoteMethodName(String dispatchingKey, + private static Optional getDispatchingValue(String dispatchingKey, WebSocketConnectionInfo.StringAggregator stringAggregator) { return Optional.ofNullable(dispatchingKey) .flatMap(key -> { try { - BString dispatchingValue = ((BMap) FromJsonString.fromJsonString( + String dispatchingValue = ((BMap) FromJsonString.fromJsonString( StringUtils.fromString(stringAggregator.getAggregateString()))) - .getStringValue(StringUtils.fromString(dispatchingKey)); - return Optional.of(createCustomRemoteFunction(dispatchingValue.getValue())); + .getStringValue(StringUtils.fromString(dispatchingKey)).getValue(); + return Optional.of(dispatchingValue); } catch (RuntimeException e) { return Optional.empty(); } }); } - private static String createCustomRemoteFunction(String dispatchingValue) { + public static String createCustomRemoteFunction(String dispatchingValue) { dispatchingValue = "on " + dispatchingValue; StringBuilder builder = new StringBuilder(); String[] words = dispatchingValue.split("[\\W_]+"); diff --git a/native/src/main/java/io/ballerina/stdlib/websocket/WebSocketService.java b/native/src/main/java/io/ballerina/stdlib/websocket/WebSocketService.java index e158e029e..f6ae9a08a 100644 --- a/native/src/main/java/io/ballerina/stdlib/websocket/WebSocketService.java +++ b/native/src/main/java/io/ballerina/stdlib/websocket/WebSocketService.java @@ -19,14 +19,25 @@ package io.ballerina.stdlib.websocket; import io.ballerina.runtime.api.Runtime; +import io.ballerina.runtime.api.flags.SymbolFlags; import io.ballerina.runtime.api.types.MethodType; import io.ballerina.runtime.api.types.ObjectType; +import io.ballerina.runtime.api.types.RemoteMethodType; +import io.ballerina.runtime.api.types.ServiceType; import io.ballerina.runtime.api.utils.TypeUtils; +import io.ballerina.runtime.api.values.BMap; import io.ballerina.runtime.api.values.BObject; +import io.ballerina.runtime.api.values.BString; +import io.ballerina.runtime.api.values.BValue; import java.util.Map; +import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; +import static io.ballerina.runtime.api.utils.StringUtils.fromString; +import static io.ballerina.stdlib.websocket.WebSocketConstants.ANNOTATION_ATTR_DISPATCHER_VALUE; +import static io.ballerina.stdlib.websocket.WebSocketConstants.UNCHECKED; + /** * WebSocket service for service dispatching. */ @@ -36,6 +47,7 @@ public class WebSocketService { protected Runtime runtime; private final Map resourcesMap = new ConcurrentHashMap<>(); private Map wsServices = new ConcurrentHashMap<>(); + private Map> wsServicesDispatchingFunctions = new ConcurrentHashMap<>(); public WebSocketService(Runtime runtime) { this.runtime = runtime; @@ -55,6 +67,35 @@ private void populateResourcesMap(BObject service) { } } + private Map getDispatchingFunctionMap(ServiceType dispatchingService) { + Map dispatchingFunctions = new ConcurrentHashMap<>(); + for (MethodType method : dispatchingService.getMethods()) { + if (!(SymbolFlags.isFlagOn(method.getFlags(), SymbolFlags.REMOTE))) { + continue; + } + RemoteMethodType remoteMethodType = (RemoteMethodType) method; + Optional dispatchingValue = getAnnotationDispatchingValue(remoteMethodType); + if (dispatchingValue.isPresent()) { + dispatchingFunctions.put(dispatchingValue.get(), remoteMethodType); + } else { + dispatchingFunctions.put(remoteMethodType.getName(), remoteMethodType); + } + } + return dispatchingFunctions; + } + + @SuppressWarnings(UNCHECKED) + public static Optional getAnnotationDispatchingValue(RemoteMethodType remoteFunc) { + BMap annotations = (BMap) remoteFunc.getAnnotation(fromString( + ModuleUtils.getPackageIdentifier() + ":" + WebSocketConstants.WEBSOCKET_DISPATCHER_MAPPING_ANNOTATION)); + if (annotations != null && annotations.containsKey(fromString(ANNOTATION_ATTR_DISPATCHER_VALUE))) { + String dispatchingValue = annotations. + getStringValue(fromString(ANNOTATION_ATTR_DISPATCHER_VALUE)).getValue(); + return Optional.of(dispatchingValue); + } + return Optional.empty(); + } + public MethodType getResourceByName(String resourceName) { return resourcesMap.get(resourceName); } @@ -69,9 +110,15 @@ public Runtime getRuntime() { public void addWsService(String channelId, Object dispatchingService) { this.wsServices.put(channelId, dispatchingService); + this.wsServicesDispatchingFunctions.put(channelId, + getDispatchingFunctionMap(((ServiceType) (((BValue) dispatchingService).getType())))); } public Object getWsService(String key) { return this.wsServices.get(key); } + + public Map getDispatchingFunctions(String key) { + return this.wsServicesDispatchingFunctions.get(key); + } }