diff --git a/.gitignore b/.gitignore index 8582ee9..d93c72d 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,9 @@ src/main/resources/messages/secrets.properties # these file/folders are to be tracked in docs related branches docs .nojekyll + +# ignore other extension related +competitive-companion + +# ignore MacOS related files +*.DS_Store \ No newline at end of file diff --git a/src/main/kotlin/com/github/pushpavel/autocp/gather/base/ProblemGatheringBridge.kt b/src/main/kotlin/com/github/pushpavel/autocp/gather/base/ProblemGatheringBridge.kt index 5b24bc1..567ea8f 100644 --- a/src/main/kotlin/com/github/pushpavel/autocp/gather/base/ProblemGatheringBridge.kt +++ b/src/main/kotlin/com/github/pushpavel/autocp/gather/base/ProblemGatheringBridge.kt @@ -11,10 +11,13 @@ import com.intellij.openapi.project.Project import com.intellij.openapi.startup.ProjectActivity import kotlinx.coroutines.CancellationException import kotlinx.coroutines.Job +import kotlinx.coroutines.TimeoutCancellationException +import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.coroutineScope import kotlinx.coroutines.delay import kotlinx.coroutines.isActive import kotlinx.coroutines.launch +import kotlinx.coroutines.withTimeout import kotlinx.serialization.SerializationException import kotlinx.serialization.json.Json import java.net.SocketTimeoutException @@ -50,22 +53,37 @@ class ProblemGatheringBridge : Disposable { // initialize server serverJob = scope.launch { try { - val serverSocket = openServerSocketAsync( + val serverSockets = openServerSocketsAsync( R.others.competitiveCompanionPorts ).await() ?: throw ProblemGatheringErr.AllPortsTakenErr(R.others.competitiveCompanionPorts) - serverSocket.use { + val messages = Channel(Channel.RENDEZVOUS) + try { + serverSockets.forEach { socket -> + // keep accept responsive to cancellation + socket.soTimeout = 1000 + launch { + while (isActive) { + try { + val message = listenForMessageAsync(socket, 1000).await() ?: continue + messages.send(message) + } catch (_: SocketTimeoutException) { + // keep looping + } + } + } + } + while (isActive) { try { coroutineScope { - val message = listenForMessageAsync( - serverSocket, R.others.problemGatheringTimeoutMillis - ).await() ?: return@coroutineScope - + val message = withTimeout(R.others.problemGatheringTimeoutMillis.toLong()) { + messages.receive() + } val json = serializer.decodeFromString(message) BatchProcessor.onJsonReceived(json) } - } catch (e: SocketTimeoutException) { + } catch (_: TimeoutCancellationException) { BatchProcessor.interruptBatch(ProblemGatheringErr.TimeoutErr) } catch (e: SerializationException) { BatchProcessor.interruptBatch(ProblemGatheringErr.JsonErr) @@ -73,6 +91,9 @@ class ProblemGatheringBridge : Disposable { while (BatchProcessor.isCurrentBatchBlocking()) delay(100) } + } finally { + messages.close() + serverSockets.forEach { kotlin.runCatching { it.close() } } } } catch (e: ProblemGatheringErr) { R.notify.problemGatheringErr(e) diff --git a/src/main/kotlin/com/github/pushpavel/autocp/gather/base/localServer.kt b/src/main/kotlin/com/github/pushpavel/autocp/gather/base/localServer.kt index 3716381..f4ec8e5 100644 --- a/src/main/kotlin/com/github/pushpavel/autocp/gather/base/localServer.kt +++ b/src/main/kotlin/com/github/pushpavel/autocp/gather/base/localServer.kt @@ -6,14 +6,24 @@ import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.async import java.io.BufferedReader +import java.io.BufferedInputStream +import java.io.ByteArrayOutputStream import java.io.InputStream import java.io.InputStreamReader import java.net.InetAddress import java.net.ServerSocket import java.net.SocketException +import java.net.SocketTimeoutException -fun CoroutineScope.openServerSocketAsync(ports: List) = async(Dispatchers.IO) { +/** + * Try to bind server sockets on loopback addresses for a port from [ports]. + * + * Competitive Companion always sends to `http://localhost:/`. + * On some systems `localhost` can resolve to IPv6 first (`::1`), while AutoCp may only be bound to IPv4 (`127.0.0.1`) + * (or vice versa). To make this robust, we try to bind both loopback addresses on the same port (if the OS allows it). + */ +fun CoroutineScope.openServerSocketsAsync(ports: List) = async(Dispatchers.IO) { val log = Logger.getInstance("${R.keys.pluginId} openServerSocketAsync") var portIndex = 0 @@ -22,13 +32,32 @@ fun CoroutineScope.openServerSocketAsync(ports: List) = async(Dispatchers.I if (portIndex != 0) log.info("Port ${ports[portIndex - 1]} taken. retrying with Port ${ports[portIndex]}") + val port = ports[portIndex] + val sockets = mutableListOf() try { - // successfully starting the server - return@async ServerSocket(ports[portIndex], 50, InetAddress.getByName("localhost")) - } catch (e: SocketException) { - // failed retrying with next port - portIndex++ + // Prefer IPv4 loopback to keep current behavior, then also try IPv6 loopback. + runCatching { sockets.add(ServerSocket(port, 50, InetAddress.getByName("127.0.0.1"))) } + runCatching { sockets.add(ServerSocket(port, 50, InetAddress.getByName("::1"))) } + + if (sockets.isNotEmpty()) { + runCatching { + val bound = sockets.joinToString(", ") { s -> s.inetAddress.hostAddress + ":" + s.localPort } + log.info("Listening for Competitive Companion on $bound") + } + return@async sockets + } + } catch (_: SocketException) { + // ignore and retry next port + } catch (_: Exception) { + // ignore and retry next port + } finally { + if (sockets.isEmpty()) { + // Ensure we don't leak partially opened sockets when retrying. + sockets.forEach { runCatching { it.close() } } + } } + + portIndex++ } if (portIndex != 0) @@ -42,15 +71,89 @@ fun CoroutineScope.openServerSocketAsync(ports: List) = async(Dispatchers.I */ fun CoroutineScope.listenForMessageAsync(serverSocket: ServerSocket, timeout: Int) = async(Dispatchers.IO) { serverSocket.soTimeout = timeout - serverSocket.accept().use { - val inputStream = it.getInputStream() - val request = readFromStream(inputStream) - val strings = request.split("\n\n".toPattern(), 2).toTypedArray() + try { + serverSocket.accept().use { + // ServerSocket.soTimeout only affects accept(). Ensure reads don't block forever. + it.soTimeout = timeout + + val inputStream = BufferedInputStream(it.getInputStream()) + val body = readHttpBody(inputStream) - if (strings.size > 1) - return@async strings[1] + if (!body.isNullOrEmpty()) return@async body + } + return@async null + } catch (_: SocketTimeoutException) { + // Normal: no incoming connection within timeout. + return@async null + } catch (_: SocketException) { + // Can happen on shutdown / dynamic plugin unload. + return@async null } - return@async null +} + +/** + * Reads an HTTP request body without waiting for EOF. + * Supports both CRLF and LF separators, and uses Content-Length when available. + */ +private fun readHttpBody(inputStream: BufferedInputStream): String? { + val headerBytes = ByteArrayOutputStream() + var prev = -1 + var curr: Int + var seenLfLf = false + var seenCrlfCrlf = false + + // Read headers up to a sane limit + val maxHeaderBytes = 64 * 1024 + while (headerBytes.size() < maxHeaderBytes) { + curr = inputStream.read() + if (curr == -1) break + headerBytes.write(curr) + + // detect \n\n + if (prev == '\n'.code && curr == '\n'.code) { + seenLfLf = true + break + } + // detect \r\n\r\n + val hb = headerBytes.toByteArray() + val n = hb.size + if (n >= 4 && + hb[n - 4] == '\r'.code.toByte() && + hb[n - 3] == '\n'.code.toByte() && + hb[n - 2] == '\r'.code.toByte() && + hb[n - 1] == '\n'.code.toByte() + ) { + seenCrlfCrlf = true + break + } + + prev = curr + } + + if (!seenLfLf && !seenCrlfCrlf) { + // couldn't find header terminator; fall back to old behavior + return null + } + + val headers = headerBytes.toString(Charsets.ISO_8859_1) + val contentLength = Regex("(?im)^Content-Length:\\s*(\\d+)\\s*$") + .find(headers) + ?.groupValues + ?.getOrNull(1) + ?.toIntOrNull() + ?: return null + + if (contentLength <= 0) return null + + val bodyBytes = ByteArray(contentLength) + var off = 0 + while (off < contentLength) { + val read = inputStream.read(bodyBytes, off, contentLength - off) + if (read <= 0) break + off += read + } + if (off <= 0) return null + return bodyBytes.copyOf(off).toString(Charsets.UTF_8) } /**