Skip to content

Commit ab87add

Browse files
committed
Add replay test in CI
1 parent 899397d commit ab87add

File tree

3 files changed

+991
-24
lines changed

3 files changed

+991
-24
lines changed

engine/build.gradle.kts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,7 @@ dependencies {
99
implementation(libs.ktor.client.java)
1010
implementation(libs.ktor.client.content.negotiation)
1111
implementation(libs.ktor.serialization.json)
12+
testImplementation(kotlin("test"))
13+
testImplementation(libs.kotlinx.serialization.json)
14+
testImplementation(libs.kotlinx.coroutines.core)
1215
}

engine/src/test/kotlin/de/tuda/stg/securecoder/engine/workflow/EngineLlmReplayTests.kt

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,25 @@ import de.tuda.stg.securecoder.engine.stream.StreamEvent
66
import de.tuda.stg.securecoder.enricher.PromptEnricher
77
import de.tuda.stg.securecoder.filesystem.InMemoryFileSystem
88
import kotlinx.coroutines.runBlocking
9+
import kotlinx.serialization.ExperimentalSerializationApi
910
import kotlinx.serialization.encodeToString
1011
import kotlinx.serialization.json.Json
12+
import kotlinx.serialization.json.decodeFromStream
1113
import kotlin.test.Test
1214
import java.nio.file.Files
1315
import java.nio.file.Path
16+
import kotlin.test.Ignore
17+
import kotlin.test.assertIs
18+
import kotlin.test.assertTrue
1419

1520
class EngineLlmReplayTests {
1621
private val json = Json { prettyPrint = true; ignoreUnknownKeys = true; encodeDefaults = true }
17-
private val logsPath: Path = Path.of("build", "llm_logs", "log.json")
22+
private val resourceName = "llm_output.json"
1823

1924
@Test
25+
@Ignore
2026
fun generator_collects_real_llm_responses() = runBlocking {
27+
val logsPath: Path = Path.of("src", "test", "resources", resourceName)
2128
Files.createDirectories(logsPath.parent)
2229

2330
val prompts = listOf(
@@ -27,11 +34,11 @@ class EngineLlmReplayTests {
2734

2835
val models = buildList {
2936
val apiKey = System.getenv("API_KEy") ?: "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"
30-
add("OR:llama3.2:latest" to OpenRouterClient(apiKey, "meta-llama/llama-3.2-3b-instruct"))
31-
add("OR:gpt-oss:20b" to OpenRouterClient(apiKey, "openai/gpt-oss-20b"))
32-
//val olBase = System.getenv("OLLAMA_URL") ?: "http://127.0.0.1:11434"
33-
//add("ollama:llama3.2:latest" to OllamaClient("llama3.2:latest", baseUrl = olBase))
34-
//add("ollama:gpt-oss:20b" to OllamaClient("gpt-oss:20b", baseUrl = olBase))
37+
//add("OR:llama3.2:latest" to OpenRouterClient(apiKey, "meta-llama/llama-3.2-3b-instruct"))
38+
//add("OR:gpt-oss:20b" to OpenRouterClient(apiKey, "openai/gpt-oss-20b"))
39+
val olBase = System.getenv("OLLAMA_URL") ?: "http://127.0.0.1:11434"
40+
add("ollama:llama3.2:latest" to OllamaClient("llama3.2:latest", baseUrl = olBase))
41+
add("ollama:gpt-oss:20b" to OllamaClient("gpt-oss:20b", baseUrl = olBase))
3542
}
3643

3744
val runs = mutableListOf<LoggedRun>()
@@ -81,12 +88,11 @@ class EngineLlmReplayTests {
8188
}
8289

8390
@Test
84-
fun replay_test_uses_recorded_responses_and_counts_success() = runBlocking {
85-
if (!Files.exists(logsPath)) {
86-
println("No log file at $logsPath; nothing to replay. Test will be a no-op.")
87-
return@runBlocking
88-
}
89-
val suite = json.decodeFromString<LoggedSuite>(Files.readString(logsPath))
91+
@OptIn(ExperimentalSerializationApi::class)
92+
fun test_replay() = runBlocking {
93+
val resourceStream = this@EngineLlmReplayTests::class.java.classLoader.getResourceAsStream(resourceName)
94+
?: throw IllegalStateException("No $resourceName on classpath")
95+
val suite: LoggedSuite = resourceStream.use { json.decodeFromStream(it) }
9096

9197
data class Group(
9298
val modelName: String,
@@ -95,7 +101,7 @@ class EngineLlmReplayTests {
95101
var successes: Int = 0,
96102
var parseFails: Int = 0
97103
)
98-
104+
assertTrue(suite.runs.isNotEmpty())
99105
suite.runs
100106
.groupBy { it.modelName to it.promptKind }
101107
.map { (key, runs) ->
@@ -112,28 +118,22 @@ class EngineLlmReplayTests {
112118
guardians = emptyList(),
113119
)
114120
group.total++
115-
var l = 0
116121
val result = engine.run(run.enginePrompt, fs, onEvent = {
117122
if (it !is StreamEvent.InvalidLlmOutputWarning) return@run
118123
group.parseFails++
119-
if (l++ >= 2) {
120-
println("=======ERROR=======")
121-
println("=======ERROR=======")
122-
println("=======ERROR=======")
123-
println(it.parseErrors.joinToString("\n"))
124-
println()
125-
println()
126-
println(it.chatExchange.output)
127-
}
124+
//println(it.parseErrors.joinToString("\n"))
125+
//println(it.chatExchange.output)
128126
})
129127
if (result is EngineResult.Success) {
130128
group.successes++
131129
}
130+
assertIs<EngineResult.Success>(result)
132131
}
133132
group
134133
}
135134
.forEach {
136-
println("Group ${it.modelName} / ${it.promptKind}: replayed runs: ${it.total}, successes: ${it.successes} (${it.parseFails} parse failures)")
135+
assertTrue(it.parseFails <= (1.5 * suite.runs.size))
136+
//println("Group ${it.modelName} / ${it.promptKind}: replayed runs: ${it.total}, successes: ${it.successes} (${it.parseFails} parse failures)")
137137
}
138138
}
139139
}

0 commit comments

Comments
 (0)