diff --git a/ballerina/Ballerina.toml b/ballerina/Ballerina.toml index 3a9267535..826b7eee2 100644 --- a/ballerina/Ballerina.toml +++ b/ballerina/Ballerina.toml @@ -1,7 +1,7 @@ [package] org = "ballerina" name = "websocket" -version = "2.14.1" +version = "2.15.0" authors = ["Ballerina"] keywords = ["ws", "network", "bi-directional", "streaming", "service", "client"] repository = "https://github.com/ballerina-platform/module-ballerina-websocket" @@ -15,8 +15,8 @@ graalvmCompatible = true [[platform.java21.dependency]] groupId = "io.ballerina.stdlib" artifactId = "websocket-native" -version = "2.14.1" -path = "../native/build/libs/websocket-native-2.14.1-SNAPSHOT.jar" +version = "2.15.0" +path = "../native/build/libs/websocket-native-2.15.0-SNAPSHOT.jar" [[platform.java21.dependency]] groupId = "io.ballerina.stdlib" @@ -85,5 +85,5 @@ version = "4.1.118.Final" path = "./lib/netty-handler-proxy-4.1.118.Final.jar" [[platform.java21.dependency]] -path = "../test-utils/build/libs/websocket-test-utils-2.14.1-SNAPSHOT.jar" +path = "../test-utils/build/libs/websocket-test-utils-2.15.0-SNAPSHOT.jar" scope = "testOnly" diff --git a/ballerina/CompilerPlugin.toml b/ballerina/CompilerPlugin.toml index f0c2e2e19..35c0f6193 100644 --- a/ballerina/CompilerPlugin.toml +++ b/ballerina/CompilerPlugin.toml @@ -3,4 +3,4 @@ id = "websocket-compiler-plugin" class = "io.ballerina.stdlib.websocket.plugin.WebSocketCompilerPlugin" [[dependency]] -path = "../compiler-plugin/build/libs/websocket-compiler-plugin-2.14.1-SNAPSHOT.jar" +path = "../compiler-plugin/build/libs/websocket-compiler-plugin-2.15.0-SNAPSHOT.jar" diff --git a/ballerina/Dependencies.toml b/ballerina/Dependencies.toml index bf4b4e276..729ac223f 100644 --- a/ballerina/Dependencies.toml +++ b/ballerina/Dependencies.toml @@ -342,7 +342,7 @@ dependencies = [ [[package]] org = "ballerina" name = "websocket" -version = "2.14.1" +version = "2.15.0" dependencies = [ {org = "ballerina", name = "auth"}, {org = "ballerina", name = "constraint"}, diff --git a/ballerina/annotation.bal b/ballerina/annotation.bal index 3286fcc18..1d11832de 100644 --- a/ballerina/annotation.bal +++ b/ballerina/annotation.bal @@ -22,14 +22,19 @@ # # + subProtocols - Negotiable sub protocol by the service # + idleTimeout - Idle timeout for the client connection. Upon timeout, `onIdleTimeout` resource (if defined) -# in the server service will be triggered. Note that this overrides the `timeout` config -# in the `websocket:Listener`, which is applicable only for the initial HTTP upgrade request +# in the server service will be triggered. Note that this overrides the `timeout` config +# in the `websocket:Listener`, which is applicable only for the initial HTTP upgrade request # + maxFrameSize - The maximum payload size of a WebSocket frame in bytes. -# If this is not set or is negative or zero, the default frame size, which is 65536 will be used +# If this is not set or is negative or zero, the default frame size, which is 65536 will be used # + auth - Listener authentication configurations # + validation - Enable/disable constraint validation # + dispatcherKey - The key which is going to be used for dispatching to custom remote functions. # + dispatcherStreamId - The identifier used to distinguish between requests and their corresponding responses in a multiplexing scenario. +# + connectionClosureTimeout - Time to wait (in seconds) for the close frame to be received from the remote endpoint +# before closing the connection. If the timeout exceeds, then the connection is terminated even though a close frame is +# not received from the remote endpoint. If the value is -1, then the connection waits until a close frame is +# received, and any other negative value results in an error. If the WebSocket frame is received from the remote endpoint +# within the waiting period, the connection is terminated immediately. public type WSServiceConfig record {| string[] subProtocols = []; decimal idleTimeout = 0; @@ -38,6 +43,7 @@ public type WSServiceConfig record {| boolean validation = true; string dispatcherKey?; string dispatcherStreamId?; + decimal connectionClosureTimeout = 60; |}; # The annotation which is used to configure a WebSocket service. diff --git a/ballerina/tests/connection_closure_timeout.bal b/ballerina/tests/connection_closure_timeout.bal new file mode 100644 index 000000000..ec3fe4dde --- /dev/null +++ b/ballerina/tests/connection_closure_timeout.bal @@ -0,0 +1,133 @@ +// Copyright (c) 2025, WSO2 LLC. (http://www.wso2.com). +// +// WSO2 LLC. licenses this file to you under the Apache License, +// Version 2.0 (the "License"); you may not use this file except +// in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +import ballerina/lang.runtime; +import ballerina/test; + +map callers = {}; +error negativeTimeoutErrorMessage = error(""); + +@ServiceConfig { + dispatcherKey: "event", + connectionClosureTimeout: 5 +} +service / on new Listener(22100) { + resource function get .() returns Service|UpgradeError { + return new ConnectionClosureTimeoutService(); + } +} + +service class ConnectionClosureTimeoutService { + *Service; + + remote function onSubscribe(Caller caller) returns error? { + callers["onSubscribe"] = caller; + check caller->close(); + } + + remote function onChat(Caller caller) returns NormalClosure? { + callers["onChat"] = caller; + return NORMAL_CLOSURE; + } + + remote function onNegativeTimeout(Caller caller) returns error? { + Error? close = caller->close(timeout = -10); + if close is Error { + negativeTimeoutErrorMessage = close; + } + } + + remote function onIsClosed(record {string event; string name;} data) returns boolean|error { + Caller? caller = callers[data.name]; + if caller is Caller { + return !caller.isOpen(); + } + return error("Caller not found"); + } + + remote function onNegativeTimeoutErrorMessage(string data) returns string { + return negativeTimeoutErrorMessage.message(); + } +} + +@test:Config { + groups: ["connectionClosureTimeout"] +} +public function testConnectionClosureTimeoutCaller() returns error? { + Client wsClient1 = check new ("ws://localhost:22100/"); + check wsClient1->writeMessage({event: "subscribe"}); + runtime:sleep(8); + + // Check if the connection is closed using another client + Client wsClient2 = check new ("ws://localhost:22100/"); + check wsClient2->writeMessage({event: "is_closed", name: "onSubscribe"}); + boolean isClosed = check wsClient2->readMessage(); + test:assertTrue(isClosed); +} + +@test:Config { + groups: ["connectionClosureTimeout"] +} +public function testConnectionClosureTimeoutCloseFrames() returns error? { + Client wsClient1 = check new ("ws://localhost:22100/"); + check wsClient1->writeMessage({event: "chat"}); + runtime:sleep(8); + + // Check if the connection is closed using another client + Client wsClient2 = check new ("ws://localhost:22100/"); + check wsClient2->writeMessage({event: "is_closed", name: "onChat"}); + boolean isClosed = check wsClient2->readMessage(); + test:assertTrue(isClosed); +} + +@test:Config { + groups: ["connectionClosureTimeout"] +} +public function testConnectionClosureTimeoutCallerNegativeTimeout() returns error? { + Client wsClient1 = check new ("ws://localhost:22100/"); + check wsClient1->writeMessage({event: "negative_timeout"}); + runtime:sleep(1); + + // Check error in the service using another client + Client wsClient2 = check new ("ws://localhost:22100/"); + check wsClient2->writeMessage({event: "negative_timeout_error_message"}); + string errorMessage = check wsClient2->readMessage(); + test:assertEquals(errorMessage, "Invalid timeout value: -10"); +} + +@test:Config { + groups: ["connectionClosureTimeout"] +} +public function testConnectionClosureTimeoutNegativeValueInClient() returns error? { + Client wsClient1 = check new ("ws://localhost:22100/"); + Error? close = wsClient1->close(timeout = -20); + test:assertTrue(close is error); + if close is error { + test:assertEquals(close.message(), "Invalid timeout value: -20"); + } +} + +@test:Config { + groups: ["connectionClosureTimeout","test"] +} +public function testInvalidConnectionClosureTimeoutValue() returns error? { + Client wsClient1 = check new ("ws://localhost:22100/"); + Error? close = wsClient1->close(timeout = 200000000000000000); + test:assertTrue(close is error); + if close is error { + test:assertEquals(close.message(), "Error: Invalid timeout value: 200000000000000000"); + } +} diff --git a/ballerina/websocket_caller.bal b/ballerina/websocket_caller.bal index f8cfa5dfb..32bb59ed3 100644 --- a/ballerina/websocket_caller.bal +++ b/ballerina/websocket_caller.bal @@ -83,12 +83,12 @@ public isolated client class Caller { # + reason - Reason for closing the connection # + timeout - Time to wait (in seconds) for the close frame to be received from the remote endpoint before closing the # connection. If the timeout exceeds, then the connection is terminated even though a close frame - # is not received from the remote endpoint. If the value < 0 (e.g., -1), then the connection waits - # until a close frame is received. If the WebSocket frame is received from the remote endpoint - # within the waiting period, the connection is terminated immediately + # is not received from the remote endpoint. If the value is -1, then the connection waits + # until a close frame is received, and any other negative value results in an error. If the WebSocket frame is received + # from the remote endpoint within the waiting period, the connection is terminated immediately # + return - A `websocket:Error` if an error occurs when sending remote isolated function close(int? statusCode = 1000, string? reason = (), - decimal timeout = 60) returns Error? { + decimal? timeout = ()) returns Error? { int code = 1000; if (statusCode is int) { if (statusCode <= 999 || statusCode >= 1004 && statusCode <= 1006 || statusCode >= 1012 && @@ -98,10 +98,14 @@ public isolated client class Caller { } code = statusCode; } + if timeout is decimal && timeout < 0d && timeout != -1d { + string errorMessage = "Invalid timeout value: " + timeout.toString(); + return error Error(errorMessage); + } return self.externClose(code, reason is () ? "" : reason, timeout); } - isolated function externClose(int statusCode, string reason, decimal timeoutInSecs) returns Error? = @java:Method { + isolated function externClose(int statusCode, string reason, decimal? timeoutInSecs = ()) returns Error? = @java:Method { 'class: "io.ballerina.stdlib.websocket.actions.websocketconnector.Close" } external; diff --git a/ballerina/websocket_sync_client.bal b/ballerina/websocket_sync_client.bal index bfec57e71..6e9c40672 100644 --- a/ballerina/websocket_sync_client.bal +++ b/ballerina/websocket_sync_client.bal @@ -105,9 +105,10 @@ public isolated client class Client { # + reason - Reason for closing the connection # + timeout - Time to wait (in seconds) for the close frame to be received from the remote endpoint before closing the # connection. If the timeout exceeds, then the connection is terminated even though a close frame - # is not received from the remote endpoint. If the value is < 0 (e.g., -1), then the connection - # waits until a close frame is received. If the WebSocket frame is received from the remote - # endpoint within the waiting period, the connection is terminated immediately + # is not received from the remote endpoint. If the value is -1, then the connection + # waits until a close frame is received, and any other negative value results in an error. + # If the WebSocket frame is received from the remote endpoint within the waiting period, + # the connection is terminated immediately # + return - A `websocket:Error` if an error occurs while closing the WebSocket connection remote isolated function close(int? statusCode = 1000, string? reason = (), decimal timeout = 60) returns Error? { int code = 1000; @@ -119,6 +120,10 @@ public isolated client class Client { } code = statusCode; } + if timeout < 0d && timeout != -1d { + string errorMessage = "Invalid timeout value: " + timeout.toString(); + return error Error(errorMessage); + } return self.externClose(code, reason is () ? "" : reason, timeout); } @@ -226,7 +231,7 @@ public isolated client class Client { } } - isolated function externClose(int statusCode, string reason, decimal timeoutInSecs) + isolated function externClose(int statusCode, string reason, decimal? timeoutInSecs) returns Error? = @java:Method { 'class: "io.ballerina.stdlib.websocket.actions.websocketconnector.Close" } external; diff --git a/changelog.md b/changelog.md index 22ce616e3..83e095be5 100644 --- a/changelog.md +++ b/changelog.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Added +- [Introduce Service Config Annotation for connectionClosureTimeout in Websocket module](https://github.com/ballerina-platform/ballerina-library/issues/7697) - [Implement websocket close frame support](https://github.com/ballerina-platform/ballerina-library/issues/7578) ### Fixed diff --git a/compiler-plugin-tests/src/test/java/io/ballerina/stdlib/websocket/compiler/WebSocketServiceValidationTest.java b/compiler-plugin-tests/src/test/java/io/ballerina/stdlib/websocket/compiler/WebSocketServiceValidationTest.java index e77ff657a..1c50f48e7 100644 --- a/compiler-plugin-tests/src/test/java/io/ballerina/stdlib/websocket/compiler/WebSocketServiceValidationTest.java +++ b/compiler-plugin-tests/src/test/java/io/ballerina/stdlib/websocket/compiler/WebSocketServiceValidationTest.java @@ -572,6 +572,16 @@ public void testRemoteFunctionWithStreamAndCloseFrameReturnTypes() { Assert.assertEquals(diagnosticResult.errorCount(), 0); } + @Test + public void testConnectionClosureTimeoutInTheServiceConfig() { + Package currentPackage = loadPackage("sample_package_64"); + PackageCompilation compilation = currentPackage.getCompilation(); + DiagnosticResult diagnosticResult = compilation.diagnosticResult(); + Assert.assertEquals(diagnosticResult.errorCount(), 1); + Diagnostic diagnostic = (Diagnostic) diagnosticResult.errors().toArray()[0]; + assertDiagnostic(diagnostic, PluginConstants.CompilationErrors.INVALID_CONNECTION_CLOSURE_TIMEOUT); + } + private void assertDiagnostic(Diagnostic diagnostic, PluginConstants.CompilationErrors error) { Assert.assertEquals(diagnostic.diagnosticInfo().code(), error.getErrorCode()); Assert.assertEquals(diagnostic.diagnosticInfo().messageFormat(), diff --git a/compiler-plugin-tests/src/test/resources/ballerina_sources/sample_package_64/Ballerina.toml b/compiler-plugin-tests/src/test/resources/ballerina_sources/sample_package_64/Ballerina.toml new file mode 100644 index 000000000..eb158bba7 --- /dev/null +++ b/compiler-plugin-tests/src/test/resources/ballerina_sources/sample_package_64/Ballerina.toml @@ -0,0 +1,4 @@ +[package] +org = "websocket_test" +name = "sample_64" +version = "0.1.0" diff --git a/compiler-plugin-tests/src/test/resources/ballerina_sources/sample_package_64/service.bal b/compiler-plugin-tests/src/test/resources/ballerina_sources/sample_package_64/service.bal new file mode 100644 index 000000000..51d98e5b7 --- /dev/null +++ b/compiler-plugin-tests/src/test/resources/ballerina_sources/sample_package_64/service.bal @@ -0,0 +1,51 @@ +// Copyright (c) 2025, WSO2 LLC. (http://www.wso2.com). +// +// WSO2 LLC. licenses this file to you under the Apache License, +// Version 2.0 (the "License"); you may not use this file except +// in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +import ballerina/websocket; + +@websocket:ServiceConfig { + connectionClosureTimeout: -5 +} +service / on new websocket:Listener(9090) { + resource isolated function get .() returns websocket:Service|websocket:UpgradeError { + return new WsService(); + } +} + +service isolated class WsService { + *websocket:Service; +} + +// We ignore the compiler validation for below services since value is provided via a variable +decimal connectionClosureTimeout = 5.0; + +@websocket:ServiceConfig { + connectionClosureTimeout: connectionClosureTimeout +} +service / on new websocket:Listener(9090) { + resource isolated function get .() returns websocket:Service|websocket:UpgradeError { + return new WsService(); + } +} + +@websocket:ServiceConfig { + connectionClosureTimeout +} +service / on new websocket:Listener(9090) { + resource isolated function get .() returns websocket:Service|websocket:UpgradeError { + return new WsService(); + } +} diff --git a/compiler-plugin/src/main/java/io/ballerina/stdlib/websocket/plugin/PluginConstants.java b/compiler-plugin/src/main/java/io/ballerina/stdlib/websocket/plugin/PluginConstants.java index 933bef014..35306477c 100644 --- a/compiler-plugin/src/main/java/io/ballerina/stdlib/websocket/plugin/PluginConstants.java +++ b/compiler-plugin/src/main/java/io/ballerina/stdlib/websocket/plugin/PluginConstants.java @@ -105,7 +105,9 @@ public enum CompilationErrors { CONTRADICTING_RETURN_TYPES("Contradicting return types provided for `{0}` remote function, cannot contain" + " stream type with other types", "WEBSOCKET_119"), DISPATCHER_STREAM_ID_WITHOUT_KEY("The `dispatcherStreamId` annotation is used without `dispatcherKey` " + - "annotation", "WEBSOCKET_120"); + "annotation", "WEBSOCKET_120"), + INVALID_CONNECTION_CLOSURE_TIMEOUT("Invalid connection closure timeout provided for the service", + "WEBSOCKET_121"); private final String error; private final String errorCode; diff --git a/compiler-plugin/src/main/java/io/ballerina/stdlib/websocket/plugin/WebSocketUpgradeServiceValidatorTask.java b/compiler-plugin/src/main/java/io/ballerina/stdlib/websocket/plugin/WebSocketUpgradeServiceValidatorTask.java index eb1bfc49c..285d436b1 100644 --- a/compiler-plugin/src/main/java/io/ballerina/stdlib/websocket/plugin/WebSocketUpgradeServiceValidatorTask.java +++ b/compiler-plugin/src/main/java/io/ballerina/stdlib/websocket/plugin/WebSocketUpgradeServiceValidatorTask.java @@ -19,7 +19,6 @@ package io.ballerina.stdlib.websocket.plugin; import io.ballerina.compiler.api.SemanticModel; -import io.ballerina.compiler.api.symbols.AnnotationSymbol; import io.ballerina.compiler.api.symbols.ModuleSymbol; import io.ballerina.compiler.api.symbols.ServiceDeclarationSymbol; import io.ballerina.compiler.api.symbols.Symbol; @@ -29,14 +28,17 @@ import io.ballerina.compiler.api.symbols.UnionTypeSymbol; import io.ballerina.compiler.syntax.tree.AnnotationNode; import io.ballerina.compiler.syntax.tree.FunctionArgumentNode; +import io.ballerina.compiler.syntax.tree.MappingFieldNode; import io.ballerina.compiler.syntax.tree.MetadataNode; import io.ballerina.compiler.syntax.tree.NamedArgumentNode; +import io.ballerina.compiler.syntax.tree.Node; import io.ballerina.compiler.syntax.tree.NodeList; import io.ballerina.compiler.syntax.tree.NodeLocation; import io.ballerina.compiler.syntax.tree.ParenthesizedArgList; import io.ballerina.compiler.syntax.tree.PositionalArgumentNode; import io.ballerina.compiler.syntax.tree.SeparatedNodeList; import io.ballerina.compiler.syntax.tree.ServiceDeclarationNode; +import io.ballerina.compiler.syntax.tree.SpecificFieldNode; import io.ballerina.compiler.syntax.tree.SyntaxKind; import io.ballerina.projects.plugins.AnalysisTask; import io.ballerina.projects.plugins.SyntaxNodeAnalysisContext; @@ -47,6 +49,7 @@ import java.util.List; import java.util.Optional; +import static io.ballerina.stdlib.websocket.WebSocketConstants.ANNOTATION_ATTR_CONNECTION_CLOSURE_TIMEOUT; import static io.ballerina.stdlib.websocket.plugin.PluginConstants.DISPATCHER_ANNOTATION; import static io.ballerina.stdlib.websocket.plugin.PluginConstants.DISPATCHER_STREAM_ID_ANNOTATION; import static io.ballerina.stdlib.websocket.plugin.PluginConstants.ORG_NAME; @@ -74,6 +77,11 @@ public void perform(SyntaxNodeAnalysisContext ctx) { reportDiagnostics(ctx, PluginConstants.CompilationErrors.DISPATCHER_STREAM_ID_WITHOUT_KEY, serviceDeclarationNode.location()); } + Optional timeoutValue = getConnectionClosureTimeoutValue(serviceDeclarationNode, ctx.semanticModel()); + if (timeoutValue.isPresent() && timeoutValue.get() < 0 && timeoutValue.get().intValue() != -1) { + reportDiagnostics(ctx, PluginConstants.CompilationErrors.INVALID_CONNECTION_CLOSURE_TIMEOUT, + serviceDeclarationNode.metadata().get().location()); + } String modulePrefix = Utils.getPrefix(ctx); Optional serviceDeclarationSymbol = ctx.semanticModel().symbol(serviceDeclarationNode); @@ -92,31 +100,82 @@ public void perform(SyntaxNodeAnalysisContext ctx) { } } - private boolean isAnnotationPresent(AnnotationNode annotation, SemanticModel semanticModel, - String annotationName) { - Optional symbolOpt = semanticModel.symbol(annotation); - if (symbolOpt.isEmpty()) { + private boolean isAnnotationFieldPresent(AnnotationNode annotation, SemanticModel semanticModel, + String annotationName) { + if (annotation.annotValue().isEmpty()) { return false; } - - Symbol symbol = symbolOpt.get(); - if (!(symbol instanceof AnnotationSymbol)) { - return false; + for (MappingFieldNode field : annotation.annotValue().get().fields()) { + if (field.kind() != SyntaxKind.SPECIFIC_FIELD) { + continue; + } + Node fieldNameNode = ((SpecificFieldNode) field).fieldName(); + if (fieldNameNode.kind() != SyntaxKind.IDENTIFIER_TOKEN) { + continue; + } + Optional symbol = semanticModel.symbol(fieldNameNode); + if (symbol.isEmpty()) { + continue; + } + if (symbol.get().getName().isEmpty() || !annotationName.equals(symbol.get().getName().get())) { + continue; + } + return true; } - - return annotation.annotValue().toString().contains(annotationName); + return false; } private boolean getDispatcherConfigAnnotation(ServiceDeclarationNode serviceNode, SemanticModel semanticModel, String annotationName) { - Optional metadata = serviceNode.metadata(); - if (metadata.isEmpty()) { + if (serviceNode.metadata().isEmpty()) { return false; } - MetadataNode metaData = metadata.get(); + MetadataNode metaData = serviceNode.metadata().get(); NodeList annotations = metaData.annotations(); return annotations.stream() - .anyMatch(ann -> isAnnotationPresent(ann, semanticModel, annotationName)); + .anyMatch(ann -> isAnnotationFieldPresent(ann, semanticModel, annotationName)); + } + + private Optional getConnectionClosureTimeoutValue(ServiceDeclarationNode serviceNode, + SemanticModel semanticModel) { + if (serviceNode.metadata().isEmpty()) { + return Optional.empty(); + } + NodeList annotations = serviceNode.metadata().get().annotations(); + return annotations.stream() + .filter(ann -> isAnnotationFieldPresent(ann, semanticModel, ANNOTATION_ATTR_CONNECTION_CLOSURE_TIMEOUT)) + .map(ann -> getAnnotationValue(ann, semanticModel, ANNOTATION_ATTR_CONNECTION_CLOSURE_TIMEOUT)) + .filter(Optional::isPresent) + .map(Optional::get) + .findFirst() + .map(Double::parseDouble); + } + + private Optional getAnnotationValue(AnnotationNode annotation, SemanticModel semanticModel, + String annotationName) { + if (annotation.annotValue().isEmpty()) { + return Optional.empty(); + } + for (MappingFieldNode field : annotation.annotValue().get().fields()) { + if (field.kind() == SyntaxKind.SPECIFIC_FIELD) { + SpecificFieldNode specificFieldNode = (SpecificFieldNode) field; + Optional symbol = semanticModel.symbol(specificFieldNode); + if (symbol.isEmpty()) { + continue; + } + if (symbol.get().getName().isEmpty() || !annotationName.equals(symbol.get().getName().get())) { + continue; + } + if (specificFieldNode.valueExpr().isEmpty()) { + return Optional.empty(); + } + if (specificFieldNode.valueExpr().get().kind() != SyntaxKind.UNARY_EXPRESSION) { + return Optional.empty(); + } + return Optional.of(specificFieldNode.valueExpr().get().toString().strip()); + } + } + return Optional.empty(); } private boolean isListenerBelongsToWebSocketModule(TypeSymbol listenerType) { diff --git a/docs/spec/spec.md b/docs/spec/spec.md index 53fe9f17b..f2a76f10e 100644 --- a/docs/spec/spec.md +++ b/docs/spec/spec.md @@ -217,6 +217,11 @@ When writing the service, following configurations can be provided, # + auth - Listener authenticaton configurations # + dispatcherKey - The key which is going to be used for dispatching to custom remote functions. # + dispatcherStreamId - The identifier used to distinguish between requests and their corresponding responses in a multiplexing scenario. +# + connectionClosureTimeout - Time to wait (in seconds) for the close frame to be received from the remote endpoint +# before closing the connection. If the timeout exceeds, then the connection is terminated even though a close frame is +# not received from the remote endpoint. If the value is -1, then the connection waits until a close frame is +# received, and any other negative value results in an error. If the WebSocket frame is received from the remote endpoint within the waiting period, the connection is +# terminated immediately. public type WSServiceConfig record {| string[] subProtocols = []; decimal idleTimeout = 0; @@ -225,9 +230,12 @@ public type WSServiceConfig record {| boolean validation = true; string dispatcherKey?; string dispatcherStreamId?; + decimal connectionClosureTimeout = 60; |}; ``` +> **Note:** The `connectionClosureTimeout` is validated at compile-time for literal values and at runtime for non-literal values such as variables. + ### 3.2. [WebSocket Service](#32-websocket-service) Once the WebSocket upgrade is accepted by the UpgradeService, it returns a `websocket:Service`. This service has a fixed set of remote methods that do not have any configs. Receiving messages will get dispatched to the relevant remote method. Each remote method is explained below. diff --git a/gradle.properties b/gradle.properties index b6314f160..28023d5be 100644 --- a/gradle.properties +++ b/gradle.properties @@ -1,6 +1,6 @@ org.gradle.caching=true group=io.ballerina.stdlib -version=2.14.1-SNAPSHOT +version=2.15.0-SNAPSHOT ballerinaLangVersion=2201.12.0 ballerinaTomlParserVersion=1.2.2 nettyVersion=4.1.118.Final diff --git a/native/src/main/java/io/ballerina/stdlib/websocket/WebSocketConstants.java b/native/src/main/java/io/ballerina/stdlib/websocket/WebSocketConstants.java index 12d37be18..b06555fa7 100644 --- a/native/src/main/java/io/ballerina/stdlib/websocket/WebSocketConstants.java +++ b/native/src/main/java/io/ballerina/stdlib/websocket/WebSocketConstants.java @@ -45,6 +45,7 @@ public class WebSocketConstants { public static final String WEBSOCKET_ANNOTATION_CONFIGURATION = "ServiceConfig"; public static final BString ANNOTATION_ATTR_SUB_PROTOCOLS = StringUtils.fromString("subProtocols"); public static final BString ANNOTATION_ATTR_IDLE_TIMEOUT = StringUtils.fromString("idleTimeout"); + public static final String ANNOTATION_ATTR_CONNECTION_CLOSURE_TIMEOUT = "connectionClosureTimeout"; public static final BString ANNOTATION_ATTR_READ_IDLE_TIMEOUT = StringUtils.fromString("readTimeout"); public static final BString ANNOTATION_ATTR_TIMEOUT = StringUtils.fromString("timeout"); public static final BString ANNOTATION_ATTR_MAX_FRAME_SIZE = StringUtils.fromString("maxFrameSize"); @@ -145,7 +146,6 @@ public class WebSocketConstants { public static final BString CLOSE_FRAME_REASON = StringUtils.fromString("reason"); public static final String PREDEFINED_CLOSE_FRAME_TYPE = "PredefinedCloseFrameType"; public static final String CUSTOM_CLOSE_FRAME_TYPE = "CustomCloseFrameType"; - public static final int CLOSE_FRAME_DEFAULT_TIMEOUT = 60; private WebSocketConstants() { } diff --git a/native/src/main/java/io/ballerina/stdlib/websocket/WebSocketResourceCallback.java b/native/src/main/java/io/ballerina/stdlib/websocket/WebSocketResourceCallback.java index 14aa45e37..0333708d9 100644 --- a/native/src/main/java/io/ballerina/stdlib/websocket/WebSocketResourceCallback.java +++ b/native/src/main/java/io/ballerina/stdlib/websocket/WebSocketResourceCallback.java @@ -48,11 +48,11 @@ import java.util.concurrent.CountDownLatch; import static io.ballerina.runtime.api.utils.StringUtils.fromString; -import static io.ballerina.stdlib.websocket.WebSocketConstants.CLOSE_FRAME_DEFAULT_TIMEOUT; import static io.ballerina.stdlib.websocket.WebSocketConstants.CLOSE_FRAME_TYPE; import static io.ballerina.stdlib.websocket.WebSocketConstants.PACKAGE_WEBSOCKET; import static io.ballerina.stdlib.websocket.WebSocketConstants.STREAMING_NEXT_FUNCTION; import static io.ballerina.stdlib.websocket.WebSocketResourceDispatcher.dispatchOnError; +import static io.ballerina.stdlib.websocket.actions.websocketconnector.Close.getConnectionClosureTimeout; import static io.ballerina.stdlib.websocket.actions.websocketconnector.Close.initiateConnectionClosure; import static io.ballerina.stdlib.websocket.actions.websocketconnector.Close.waitForTimeout; import static io.ballerina.stdlib.websocket.actions.websocketconnector.WebSocketConnector.fromByteArray; @@ -152,7 +152,8 @@ public static void sendCloseFrame(Object result, WebSocketConnectionInfo connect ChannelFuture closeFuture = initiateConnectionClosure(errors, statusCode, reason, connectionInfo, countDownLatch); connectionInfo.getWebSocketConnection().readNextFrame(); - waitForTimeout(errors, CLOSE_FRAME_DEFAULT_TIMEOUT, countDownLatch, connectionInfo); + int timeoutInSecs = getConnectionClosureTimeout(null, connectionInfo); + waitForTimeout(errors, timeoutInSecs, countDownLatch, connectionInfo); closeFuture.channel().close().addListener(future -> { WebSocketUtil.setListenerOpenField(connectionInfo); }); diff --git a/native/src/main/java/io/ballerina/stdlib/websocket/WebSocketUtil.java b/native/src/main/java/io/ballerina/stdlib/websocket/WebSocketUtil.java index e6d54aa20..393b99777 100644 --- a/native/src/main/java/io/ballerina/stdlib/websocket/WebSocketUtil.java +++ b/native/src/main/java/io/ballerina/stdlib/websocket/WebSocketUtil.java @@ -200,6 +200,20 @@ public static int findMaxFrameSize(BMap configs) { } + public static int findTimeoutInSeconds(BMap config, BString key) { + String value = config.get(key).toString(); + int timeout; + try { + timeout = Integer.parseInt(value); + } catch (NumberFormatException e) { + throw WebSocketUtil.createErrorByType(new Exception("Invalid timeout value: " + value)); + } + if (timeout < 0 && timeout != -1) { + throw WebSocketUtil.createErrorByType(new Exception("Invalid timeout value: " + value)); + } + return timeout; + } + public static int findTimeoutInSeconds(BMap config, BString key, int defaultValue) { try { int timeout = (int) ((BDecimal) config.get(key)).floatValue(); diff --git a/native/src/main/java/io/ballerina/stdlib/websocket/actions/websocketconnector/Close.java b/native/src/main/java/io/ballerina/stdlib/websocket/actions/websocketconnector/Close.java index b70efc651..d4624b0bd 100644 --- a/native/src/main/java/io/ballerina/stdlib/websocket/actions/websocketconnector/Close.java +++ b/native/src/main/java/io/ballerina/stdlib/websocket/actions/websocketconnector/Close.java @@ -28,6 +28,7 @@ import io.ballerina.stdlib.websocket.observability.WebSocketObservabilityConstants; import io.ballerina.stdlib.websocket.observability.WebSocketObservabilityUtil; import io.ballerina.stdlib.websocket.server.WebSocketConnectionInfo; +import io.ballerina.stdlib.websocket.server.WebSocketServerService; import io.netty.channel.ChannelFuture; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -45,7 +46,7 @@ public class Close { private static final Logger log = LoggerFactory.getLogger(Close.class); public static Object externClose(Environment env, BObject wsConnection, long statusCode, BString reason, - BDecimal timeoutInSecs) { + Object bTimeoutInSecs) { return env.yieldAndRun(() -> { CompletableFuture balFuture = new CompletableFuture<>(); WebSocketConnectionInfo connectionInfo = (WebSocketConnectionInfo) wsConnection @@ -53,12 +54,13 @@ public static Object externClose(Environment env, BObject wsConnection, long sta WebSocketObservabilityUtil.observeResourceInvocation(env, connectionInfo, WebSocketConstants.RESOURCE_NAME_CLOSE); try { + int timeoutInSecs = getConnectionClosureTimeout(bTimeoutInSecs, connectionInfo); CountDownLatch countDownLatch = new CountDownLatch(1); List errors = new ArrayList<>(1); ChannelFuture closeFuture = initiateConnectionClosure(errors, (int) statusCode, reason.getValue(), connectionInfo, countDownLatch); connectionInfo.getWebSocketConnection().readNextFrame(); - waitForTimeout(errors, (int) timeoutInSecs.floatValue(), countDownLatch, connectionInfo); + waitForTimeout(errors, timeoutInSecs, countDownLatch, connectionInfo); closeFuture.channel().close().addListener(future -> { WebSocketUtil.setListenerOpenField(connectionInfo); if (errors.isEmpty()) { @@ -81,8 +83,22 @@ public static Object externClose(Environment env, BObject wsConnection, long sta }); } - public static ChannelFuture initiateConnectionClosure(List errors, int statusCode, - String reason, WebSocketConnectionInfo connectionInfo, CountDownLatch latch) throws IllegalAccessException { + public static int getConnectionClosureTimeout(Object bTimeoutInSecs, WebSocketConnectionInfo connectionInfo) { + try { + int timeoutInSecs = 0; + if (bTimeoutInSecs instanceof BDecimal) { + timeoutInSecs = Integer.parseInt(bTimeoutInSecs.toString()); + } else if (connectionInfo.getService() instanceof WebSocketServerService webSocketServerService) { + timeoutInSecs = webSocketServerService.getConnectionClosureTimeout(); + } + return timeoutInSecs; + } catch (Exception e) { + throw new RuntimeException("Invalid timeout value: " + bTimeoutInSecs, e); + } + } + + public static ChannelFuture initiateConnectionClosure(List errors, int statusCode, String reason, + WebSocketConnectionInfo connectionInfo, CountDownLatch latch) throws IllegalAccessException { WebSocketConnection webSocketConnection = connectionInfo.getWebSocketConnection(); ChannelFuture closeFuture; closeFuture = webSocketConnection.initiateConnectionClosure(statusCode, reason); diff --git a/native/src/main/java/io/ballerina/stdlib/websocket/server/WebSocketServerService.java b/native/src/main/java/io/ballerina/stdlib/websocket/server/WebSocketServerService.java index 41a12b787..d6f2d531f 100644 --- a/native/src/main/java/io/ballerina/stdlib/websocket/server/WebSocketServerService.java +++ b/native/src/main/java/io/ballerina/stdlib/websocket/server/WebSocketServerService.java @@ -20,7 +20,6 @@ import io.ballerina.runtime.api.Runtime; import io.ballerina.runtime.api.types.ObjectType; -import io.ballerina.runtime.api.utils.StringUtils; import io.ballerina.runtime.api.utils.TypeUtils; import io.ballerina.runtime.api.values.BMap; import io.ballerina.runtime.api.values.BObject; @@ -30,6 +29,8 @@ import io.ballerina.stdlib.websocket.WebSocketService; import io.ballerina.stdlib.websocket.WebSocketUtil; +import static io.ballerina.runtime.api.utils.StringUtils.fromString; +import static io.ballerina.stdlib.websocket.WebSocketConstants.ANNOTATION_ATTR_CONNECTION_CLOSURE_TIMEOUT; import static io.ballerina.stdlib.websocket.WebSocketConstants.ANNOTATION_ATTR_DISPATCHER_KEY; import static io.ballerina.stdlib.websocket.WebSocketConstants.ANNOTATION_ATTR_VALIDATION_ENABLED; @@ -44,6 +45,7 @@ public class WebSocketServerService extends WebSocketService { private int idleTimeoutInSeconds = 0; private boolean enableValidation = true; private String dispatchingKey = null; + private int connectionClosureTimeout; public WebSocketServerService(BObject service, Runtime runtime, String basePath) { super(service, runtime); @@ -56,6 +58,8 @@ private void populateConfigs(String basePath) { negotiableSubProtocols = WebSocketUtil.findNegotiableSubProtocols(configAnnotation); idleTimeoutInSeconds = WebSocketUtil.findTimeoutInSeconds(configAnnotation, WebSocketConstants.ANNOTATION_ATTR_IDLE_TIMEOUT, 0); + connectionClosureTimeout = WebSocketUtil.findTimeoutInSeconds(configAnnotation, + fromString(ANNOTATION_ATTR_CONNECTION_CLOSURE_TIMEOUT)); maxFrameSize = WebSocketUtil.findMaxFrameSize(configAnnotation); enableValidation = configAnnotation.getBooleanValue(ANNOTATION_ATTR_VALIDATION_ENABLED); if (configAnnotation.getStringValue(ANNOTATION_ATTR_DISPATCHER_KEY) != null) { @@ -70,7 +74,7 @@ private void populateConfigs(String basePath) { @SuppressWarnings(WebSocketConstants.UNCHECKED) private BMap getServiceConfigAnnotation() { ObjectType serviceType = (ObjectType) TypeUtils.getReferredType(TypeUtils.getType(service)); - return (BMap) serviceType.getAnnotation(StringUtils.fromString( + return (BMap) serviceType.getAnnotation(fromString( ModuleUtils.getPackageIdentifier() + ":" + WebSocketConstants.WEBSOCKET_ANNOTATION_CONFIGURATION)); } @@ -101,4 +105,8 @@ public String getDispatchingKey() { public String getBasePath() { return basePath; } + + public int getConnectionClosureTimeout() { + return connectionClosureTimeout; + } }