Skip to content

Commit

Permalink
add @description annotation to tool properties (#10)
Browse files Browse the repository at this point in the history
It is possible to use overloaded plusAssign operator to add response messages directly to a conversation (MutableList<Message>)
  • Loading branch information
morisil authored Oct 17, 2024
1 parent 61c4a31 commit 1d29673
Show file tree
Hide file tree
Showing 9 changed files with 117 additions and 33 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ fun main() = runBlocking {
println("Initial response:")
println(initialResponse)

conversation += initialResponse.asMessage()
conversation += initialResponse
val tool = initialResponse.content.filterIsInstance<ToolUse>().first()
val toolResult = tool.use()
conversation += Message { +toolResult }
Expand Down
23 changes: 22 additions & 1 deletion src/commonMain/kotlin/message/Messages.kt
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ data class System(
@Serializable
data class Tool(
val name: String,
val description: String,
val description: String?,
@SerialName("input_schema")
val inputSchema: JsonSchema,
@SerialName("cache_control")
Expand Down Expand Up @@ -354,6 +354,21 @@ fun ToolResult(
content = listOf(Text(text))
)

inline fun <reified T> ToolResult(
toolUseId: String,
value: T
): ToolResult = ToolResult(
toolUseId,
content = listOf(
Text(
anthropicJson.encodeToString(
serializer = serializer<T>(),
value = value
)
)
)
)

@Serializable
data class CacheControl(
val type: Type
Expand Down Expand Up @@ -409,3 +424,9 @@ data class Usage(
@SerialName("output_tokens")
val outputTokens: Int
)

operator fun MutableCollection<in Message>.plusAssign(
response: MessageResponse
) {
this += response.asMessage()
}
21 changes: 11 additions & 10 deletions src/commonMain/kotlin/schema/JsonSchema.kt
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
package com.xemantic.anthropic.schema

import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.MetaSerializable
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable

@OptIn(ExperimentalSerializationApi::class)
@Target(AnnotationTarget.PROPERTY)
@MetaSerializable
annotation class Description(
val value: String
)

@Serializable
data class JsonSchema(
val type: String = "object",
Expand All @@ -16,17 +25,9 @@ data class JsonSchema(
@Serializable
data class JsonSchemaProperty(
val type: String? = null,
val description: String? = null,
val items: JsonSchemaProperty? = null,
val enum: List<String>? = null,
@SerialName("\$ref")
val ref: String? = null
) {

companion object {
val STRING = JsonSchemaProperty("string")
val INTEGER = JsonSchemaProperty("integer")
val NUMBER = JsonSchemaProperty("number")
val BOOLEAN = JsonSchemaProperty("boolean")
}

}
)
39 changes: 27 additions & 12 deletions src/commonMain/kotlin/schema/JsonSchemaGenerator.kt
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,15 @@ fun generateSchema(descriptor: SerialDescriptor): JsonSchema {
for (i in 0 until descriptor.elementsCount) {
val name = descriptor.getElementName(i)
val elementDescriptor = descriptor.getElementDescriptor(i)
val property = generateSchemaProperty(elementDescriptor, definitions)
val elementAnnotations = descriptor.getElementAnnotations(i)
val property = generateSchemaProperty(
elementDescriptor,
description = elementAnnotations
.filterIsInstance<Description>()
.firstOrNull()
?.value,
definitions
)
properties[name] = property
if (!descriptor.isElementOptional(i)) {
required.add(name)
Expand All @@ -35,38 +43,45 @@ fun generateSchema(descriptor: SerialDescriptor): JsonSchema {
@OptIn(ExperimentalSerializationApi::class)
private fun generateSchemaProperty(
descriptor: SerialDescriptor,
description: String?,
definitions: MutableMap<String, JsonSchema>
): JsonSchemaProperty {
return when (descriptor.kind) {
PrimitiveKind.STRING -> JsonSchemaProperty.STRING
PrimitiveKind.INT, PrimitiveKind.LONG -> JsonSchemaProperty.INTEGER
PrimitiveKind.FLOAT, PrimitiveKind.DOUBLE -> JsonSchemaProperty.NUMBER
PrimitiveKind.BOOLEAN -> JsonSchemaProperty.BOOLEAN
SerialKind.ENUM -> enumProperty(descriptor)
PrimitiveKind.STRING -> JsonSchemaProperty("string", description)
PrimitiveKind.INT, PrimitiveKind.LONG -> JsonSchemaProperty("integer", description)
PrimitiveKind.FLOAT, PrimitiveKind.DOUBLE -> JsonSchemaProperty("number", description)
PrimitiveKind.BOOLEAN -> JsonSchemaProperty("boolean", description)
SerialKind.ENUM -> enumProperty(descriptor, description)
StructureKind.LIST -> JsonSchemaProperty(
type = "array",
items = generateSchemaProperty(
descriptor.getElementDescriptor(0),
description,
definitions
)
)
StructureKind.MAP -> JsonSchemaProperty("object")
StructureKind.MAP -> JsonSchemaProperty("object", description)
StructureKind.CLASS -> {
// dots are not allowed in JSON Schema name, if the @SerialName was not
// specified, then fully qualified class name will be used, and we need
// to translate it
val refName = descriptor.serialName.replace('.', '_').trimEnd('?')
definitions[refName] = generateSchema(descriptor)
JsonSchemaProperty(ref = "#/definitions/$refName")
JsonSchemaProperty(
ref = "#/definitions/$refName",
description = description
)
}
else -> JsonSchemaProperty("object") // Default case
else -> JsonSchemaProperty("object", description) // Default case
}
}

private fun enumProperty(
descriptor: SerialDescriptor
) = JsonSchemaProperty(
enum = descriptor.elementNames()
descriptor: SerialDescriptor,
description: String?
) = JsonSchemaProperty( // TODO should it return type enum?
enum = descriptor.elementNames(),
description = description,
)

@OptIn(ExperimentalSerializationApi::class)
Expand Down
5 changes: 3 additions & 2 deletions src/commonMain/kotlin/tool/Tools.kt
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import kotlinx.serialization.serializer
@Target(AnnotationTarget.CLASS)
annotation class AnthropicTool(
val name: String,
val description: String
val description: String = ""
)

/**
Expand Down Expand Up @@ -80,7 +80,8 @@ inline fun <reified T : UsableTool> toolOf(

return Tool(
name = anthropicTool.name,
description = anthropicTool.description,
// annotation description cannot be null, so we allow empty and detect it here
description = if (anthropicTool.description.isNotBlank()) anthropicTool.description else null,
inputSchema = jsonSchemaOf<T>(),
cacheControl = cacheControl
)
Expand Down
7 changes: 4 additions & 3 deletions src/commonTest/kotlin/AnthropicTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import com.xemantic.anthropic.message.Role
import com.xemantic.anthropic.message.StopReason
import com.xemantic.anthropic.message.Text
import com.xemantic.anthropic.message.ToolUse
import com.xemantic.anthropic.message.plusAssign
import com.xemantic.anthropic.test.Calculator
import com.xemantic.anthropic.test.DatabaseQueryTool
import com.xemantic.anthropic.test.FibonacciTool
Expand Down Expand Up @@ -117,7 +118,7 @@ class AnthropicTest {
messages = conversation
useTools()
}
conversation += initialResponse.asMessage()
conversation += initialResponse

// then
assertSoftly(initialResponse) {
Expand Down Expand Up @@ -189,7 +190,7 @@ class AnthropicTest {
messages = conversation
useTools()
}
conversation += fibonacciResponse.asMessage()
conversation += fibonacciResponse

val fibonacciToolUse = fibonacciResponse.content.filterIsInstance<ToolUse>().first()
fibonacciToolUse.name shouldBe "FibonacciTool"
Expand All @@ -200,7 +201,7 @@ class AnthropicTest {
messages = conversation
useTools()
}
conversation += calculatorResponse.asMessage()
conversation += calculatorResponse

val calculatorToolUse = calculatorResponse.content.filterIsInstance<ToolUse>().first()
calculatorToolUse.name shouldBe "Calculator"
Expand Down
34 changes: 34 additions & 0 deletions src/commonTest/kotlin/message/ToolResultTest.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package com.xemantic.anthropic.message

import io.kotest.matchers.shouldBe
import kotlinx.serialization.Serializable
import kotlin.test.Test

class ToolResultTest {

@Test
fun shouldCreateToolResultForSingleString() {
ToolResult(
toolUseId = "42",
"foo"
) shouldBe ToolResult(
toolUseId = "42",
content = listOf(Text(text = "foo"))
)
}

@Serializable
data class Foo(val bar: String)

@Test
fun shouldCreateToolResultForSerializableInstance() {
ToolResult(
toolUseId = "42",
Foo("buzz")
) shouldBe ToolResult(
toolUseId = "42",
content = listOf(Text(text = "{\"bar\":\"buzz\"}"))
)
}

}
4 changes: 3 additions & 1 deletion src/commonTest/kotlin/schema/JsonSchemaGeneratorTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ data class Address(

@Serializable
data class Person(
@Description("The official name")
val name: String,
val age: Int,
val email: String?,
Expand Down Expand Up @@ -102,7 +103,8 @@ class JsonSchemaGeneratorTest {
},
"properties": {
"name": {
"type": "string"
"type": "string",
"description": "The official name"
},
"age": {
"type": "integer"
Expand Down
15 changes: 12 additions & 3 deletions src/commonTest/kotlin/tool/UsableToolTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package com.xemantic.anthropic.tool

import com.xemantic.anthropic.message.CacheControl
import com.xemantic.anthropic.message.ToolResult
import com.xemantic.anthropic.schema.Description
import com.xemantic.anthropic.schema.JsonSchema
import com.xemantic.anthropic.schema.JsonSchemaProperty
import io.kotest.assertions.assertSoftly
Expand All @@ -18,6 +19,7 @@ class UsableToolTest {
description = "Test tool receiving a message and outputting it back"
)
class TestTool(
@Description("the message")
val message: String
) : UsableTool {
override suspend fun use(
Expand All @@ -34,15 +36,19 @@ class UsableToolTest {
name shouldBe "TestTool"
description shouldBe "Test tool receiving a message and outputting it back"
inputSchema shouldBe JsonSchema(
properties = mapOf("message" to JsonSchemaProperty.STRING),
properties = mapOf("message" to JsonSchemaProperty(
type = "string",
description = "the message"
)),
required = listOf("message")
)
cacheControl shouldBe null
}
}

// TODO maybe we need a builder here?
@Test
fun shouldCreateToolWithCacheControlFromUsableTool() {
fun shouldCreateToolWithCacheControlFromUsableToolSuppliedWithCacheControl() {
// when
val tool = toolOf<TestTool>(
cacheControl = CacheControl(type = CacheControl.Type.EPHEMERAL)
Expand All @@ -52,7 +58,10 @@ class UsableToolTest {
name shouldBe "TestTool"
description shouldBe "Test tool receiving a message and outputting it back"
inputSchema shouldBe JsonSchema(
properties = mapOf("message" to JsonSchemaProperty.STRING),
properties = mapOf("message" to JsonSchemaProperty(
type = "string",
description = "the message"
)),
required = listOf("message")
)
cacheControl shouldBe CacheControl(type = CacheControl.Type.EPHEMERAL)
Expand Down

0 comments on commit 1d29673

Please sign in to comment.