Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,9 @@ src/main/resources/messages/secrets.properties
# these file/folders are to be tracked in docs related branches
docs
.nojekyll

# ignore other extension related
competitive-companion

# ignore MacOS related files
*.DS_Store
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,13 @@ import com.intellij.openapi.project.Project
import com.intellij.openapi.startup.ProjectActivity
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.Job
import kotlinx.coroutines.TimeoutCancellationException
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.delay
import kotlinx.coroutines.isActive
import kotlinx.coroutines.launch
import kotlinx.coroutines.withTimeout
import kotlinx.serialization.SerializationException
import kotlinx.serialization.json.Json
import java.net.SocketTimeoutException
Expand Down Expand Up @@ -50,29 +53,47 @@ class ProblemGatheringBridge : Disposable {
// initialize server
serverJob = scope.launch {
try {
val serverSocket = openServerSocketAsync(
val serverSockets = openServerSocketsAsync(
R.others.competitiveCompanionPorts
).await() ?: throw ProblemGatheringErr.AllPortsTakenErr(R.others.competitiveCompanionPorts)

serverSocket.use {
val messages = Channel<String>(Channel.RENDEZVOUS)
try {
serverSockets.forEach { socket ->
// keep accept responsive to cancellation
socket.soTimeout = 1000
launch {
while (isActive) {
try {
val message = listenForMessageAsync(socket, 1000).await() ?: continue
messages.send(message)
} catch (_: SocketTimeoutException) {
// keep looping
}
}
}
}

while (isActive) {
try {
coroutineScope {
val message = listenForMessageAsync(
serverSocket, R.others.problemGatheringTimeoutMillis
).await() ?: return@coroutineScope

val message = withTimeout(R.others.problemGatheringTimeoutMillis.toLong()) {
messages.receive()
}
val json = serializer.decodeFromString<ProblemJson>(message)
BatchProcessor.onJsonReceived(json)
}
} catch (e: SocketTimeoutException) {
} catch (_: TimeoutCancellationException) {
BatchProcessor.interruptBatch(ProblemGatheringErr.TimeoutErr)
} catch (e: SerializationException) {
BatchProcessor.interruptBatch(ProblemGatheringErr.JsonErr)
}

while (BatchProcessor.isCurrentBatchBlocking()) delay(100)
}
} finally {
messages.close()
serverSockets.forEach { kotlin.runCatching { it.close() } }
}
} catch (e: ProblemGatheringErr) {
R.notify.problemGatheringErr(e)
Expand Down
129 changes: 116 additions & 13 deletions src/main/kotlin/com/github/pushpavel/autocp/gather/base/localServer.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,24 @@ import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.async
import java.io.BufferedReader
import java.io.BufferedInputStream
import java.io.ByteArrayOutputStream
import java.io.InputStream
import java.io.InputStreamReader
import java.net.InetAddress
import java.net.ServerSocket
import java.net.SocketException
import java.net.SocketTimeoutException


fun CoroutineScope.openServerSocketAsync(ports: List<Int>) = async(Dispatchers.IO) {
/**
* Try to bind server sockets on loopback addresses for a port from [ports].
*
* Competitive Companion always sends to `http://localhost:<port>/`.
* On some systems `localhost` can resolve to IPv6 first (`::1`), while AutoCp may only be bound to IPv4 (`127.0.0.1`)
* (or vice versa). To make this robust, we try to bind both loopback addresses on the same port (if the OS allows it).
*/
fun CoroutineScope.openServerSocketsAsync(ports: List<Int>) = async(Dispatchers.IO) {
val log = Logger.getInstance("${R.keys.pluginId} openServerSocketAsync")
var portIndex = 0

Expand All @@ -22,13 +32,32 @@ fun CoroutineScope.openServerSocketAsync(ports: List<Int>) = async(Dispatchers.I
if (portIndex != 0)
log.info("Port ${ports[portIndex - 1]} taken. retrying with Port ${ports[portIndex]}")

val port = ports[portIndex]
val sockets = mutableListOf<ServerSocket>()
try {
// successfully starting the server
return@async ServerSocket(ports[portIndex], 50, InetAddress.getByName("localhost"))
} catch (e: SocketException) {
// failed retrying with next port
portIndex++
// Prefer IPv4 loopback to keep current behavior, then also try IPv6 loopback.
runCatching { sockets.add(ServerSocket(port, 50, InetAddress.getByName("127.0.0.1"))) }
runCatching { sockets.add(ServerSocket(port, 50, InetAddress.getByName("::1"))) }

if (sockets.isNotEmpty()) {
runCatching {
val bound = sockets.joinToString(", ") { s -> s.inetAddress.hostAddress + ":" + s.localPort }
log.info("Listening for Competitive Companion on $bound")
}
return@async sockets
}
} catch (_: SocketException) {
// ignore and retry next port
} catch (_: Exception) {
// ignore and retry next port
} finally {
if (sockets.isEmpty()) {
// Ensure we don't leak partially opened sockets when retrying.
sockets.forEach { runCatching { it.close() } }
}
}

portIndex++
}

if (portIndex != 0)
Expand All @@ -42,15 +71,89 @@ fun CoroutineScope.openServerSocketAsync(ports: List<Int>) = async(Dispatchers.I
*/
fun CoroutineScope.listenForMessageAsync(serverSocket: ServerSocket, timeout: Int) = async(Dispatchers.IO) {
serverSocket.soTimeout = timeout
serverSocket.accept().use {
val inputStream = it.getInputStream()
val request = readFromStream(inputStream)
val strings = request.split("\n\n".toPattern(), 2).toTypedArray()
try {
serverSocket.accept().use {
// ServerSocket.soTimeout only affects accept(). Ensure reads don't block forever.
it.soTimeout = timeout

val inputStream = BufferedInputStream(it.getInputStream())
val body = readHttpBody(inputStream)

if (strings.size > 1)
return@async strings[1]
if (!body.isNullOrEmpty()) return@async body
}
return@async null
} catch (_: SocketTimeoutException) {
// Normal: no incoming connection within timeout.
return@async null
} catch (_: SocketException) {
// Can happen on shutdown / dynamic plugin unload.
return@async null
}
return@async null
}

/**
* Reads an HTTP request body without waiting for EOF.
* Supports both CRLF and LF separators, and uses Content-Length when available.
*/
private fun readHttpBody(inputStream: BufferedInputStream): String? {
val headerBytes = ByteArrayOutputStream()
var prev = -1
var curr: Int
var seenLfLf = false
var seenCrlfCrlf = false

// Read headers up to a sane limit
val maxHeaderBytes = 64 * 1024
while (headerBytes.size() < maxHeaderBytes) {
curr = inputStream.read()
if (curr == -1) break
headerBytes.write(curr)

// detect \n\n
if (prev == '\n'.code && curr == '\n'.code) {
seenLfLf = true
break
}
// detect \r\n\r\n
val hb = headerBytes.toByteArray()
val n = hb.size
if (n >= 4 &&
hb[n - 4] == '\r'.code.toByte() &&
hb[n - 3] == '\n'.code.toByte() &&
hb[n - 2] == '\r'.code.toByte() &&
hb[n - 1] == '\n'.code.toByte()
) {
seenCrlfCrlf = true
break
}

prev = curr
}

if (!seenLfLf && !seenCrlfCrlf) {
// couldn't find header terminator; fall back to old behavior
return null
}

val headers = headerBytes.toString(Charsets.ISO_8859_1)
val contentLength = Regex("(?im)^Content-Length:\\s*(\\d+)\\s*$")
.find(headers)
?.groupValues
?.getOrNull(1)
?.toIntOrNull()
?: return null

if (contentLength <= 0) return null

val bodyBytes = ByteArray(contentLength)
var off = 0
while (off < contentLength) {
val read = inputStream.read(bodyBytes, off, contentLength - off)
if (read <= 0) break
off += read
}
if (off <= 0) return null
return bodyBytes.copyOf(off).toString(Charsets.UTF_8)
}

/**
Expand Down
Loading