Skip to content

Commit 4c634c7

Browse files
committed
Support nullable types in KxJsonSchemaFormat
1 parent d8195eb commit 4c634c7

File tree

1 file changed

+23
-3
lines changed

1 file changed

+23
-3
lines changed

engine/src/main/kotlin/de/tuda/stg/securecoder/engine/llm/KxJsonSchemaFormat.kt

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import kotlinx.serialization.json.JsonObjectBuilder
1313
import kotlinx.serialization.json.JsonPrimitive
1414
import kotlinx.serialization.json.buildJsonArray
1515
import kotlinx.serialization.json.buildJsonObject
16+
import kotlinx.serialization.json.putJsonArray
1617

1718
@OptIn(ExperimentalSerializationApi::class)
1819
class KxJsonSchemaFormat {
@@ -24,7 +25,7 @@ class KxJsonSchemaFormat {
2425
if (!seen.add(key)) {
2526
throw IllegalStateException("Recursive type detected: $key")
2627
}
27-
val jsonType = when (desc.kind) {
28+
var jsonType = when (desc.kind) {
2829
PrimitiveKind.BOOLEAN -> type("boolean")
2930
PrimitiveKind.BYTE, PrimitiveKind.SHORT, PrimitiveKind.INT, PrimitiveKind.LONG -> type("integer")
3031
PrimitiveKind.FLOAT, PrimitiveKind.DOUBLE -> type("number")
@@ -65,12 +66,31 @@ class KxJsonSchemaFormat {
6566
}
6667
seen.remove(key)
6768
if (desc.isNullable) {
68-
throw IllegalStateException("Nullable types are not supported")
69+
jsonType = makeNullable(jsonType)
6970
}
7071
val selfDesc = getDescription(desc.annotations)
7172
return if (selfDesc != null) addDescription(jsonType, selfDesc) else jsonType
7273
}
7374

75+
private fun makeNullable(schema: JsonObject): JsonObject {
76+
val type = schema["type"]
77+
if (type is JsonPrimitive && type.isString) {
78+
return buildJsonObject {
79+
schema.forEach(::put)
80+
putJsonArray("type") {
81+
add(type)
82+
add(JsonPrimitive("null"))
83+
}
84+
}
85+
}
86+
return buildJsonObject {
87+
put("anyOf", buildJsonArray {
88+
add(schema)
89+
add(type("null"))
90+
})
91+
}
92+
}
93+
7494
private fun type(name: String, builderAction: JsonObjectBuilder.() -> Unit = {}): JsonObject =
7595
buildJsonObject {
7696
put("type", JsonPrimitive(name))
@@ -88,7 +108,7 @@ class KxJsonSchemaFormat {
88108

89109
private fun addDescription(obj: JsonObject, text: String): JsonObject =
90110
buildJsonObject {
91-
obj.forEach { (k, v) -> put(k, v) }
111+
obj.forEach(::put)
92112
put("description", JsonPrimitive(text))
93113
}
94114
}

0 commit comments

Comments
 (0)