From 9270e5a7ef62e9ae2f722b8be944175167b1f402 Mon Sep 17 00:00:00 2001 From: Kazik Pogoda Date: Sat, 2 Nov 2024 20:59:30 +0100 Subject: [PATCH] another refactoring greatly simplifying the use of tools --- README.md | 31 +++++++------ src/commonMain/kotlin/Anthropic.kt | 1 + .../kotlin/content/ContentBuilder.kt | 25 +++++++++++ src/commonMain/kotlin/content/Tool.kt | 24 +++------- src/commonMain/kotlin/message/Messages.kt | 18 ++------ src/commonMain/kotlin/tool/Tools.kt | 21 +++++---- src/commonMain/kotlin/tool/bash/Bash.kt | 9 ++-- .../kotlin/tool/computer/Computer.kt | 25 +++++------ .../kotlin/tool/editor/TextEditor.kt | 24 +++++----- .../kotlin/message/MessageRequestTest.kt | 13 +++--- .../kotlin/tool/AnthropicTestTools.kt | 44 +++++++++---------- src/commonTest/kotlin/tool/ToolInputTest.kt | 23 ++++------ .../kotlin/tool/computer/JvmComputer.kt | 20 ++------- .../kotlin/tool/editor/JvmTextEditor.kt | 19 +------- 14 files changed, 128 insertions(+), 169 deletions(-) create mode 100644 src/commonMain/kotlin/content/ContentBuilder.kt diff --git a/README.md b/README.md index 552e071..e8e60cf 100644 --- a/README.md +++ b/README.md @@ -132,13 +132,13 @@ If you want to write AI agents, you need tools, and this is where this library s ```kotlin @AnthropicTool("get_weather") @Description("Get the weather for a specific location") -data class WeatherTool(val location: String): ToolInput { - override fun use( - toolUseId: String - ) = ToolResult( - toolUseId, - "The weather is 73f" // it should use some external service - ) +data class WeatherTool(val location: String): ToolInput() { + init { + use { + // in the real world it should use some external service + +"The weather is 73f" + } + } } fun main() = runBlocking { @@ -192,21 +192,20 @@ internet or DB connection pool to access the database. ```kotlin @AnthropicTool("query_database") @Description("Executes SQL on the database") -data class QueryDatabase(val sql: String): ToolInput { +data class QueryDatabase(val sql: String): ToolInput() { @Transient internal lateinit var connection: Connection - override fun use( - toolUseId: String - ) = ToolResult( - toolUseId, - text = connection.prepareStatement(sql).use { statement -> - statement.resultSet.use { resultSet -> - resultSet.toString() + init { + use { + +connection.prepareStatement(sql).use { statement -> + statement.executeQuery().use { resultSet -> + resultSet.toString() + } } } - ) + } } diff --git a/src/commonMain/kotlin/Anthropic.kt b/src/commonMain/kotlin/Anthropic.kt index 347f619..65fef1d 100644 --- a/src/commonMain/kotlin/Anthropic.kt +++ b/src/commonMain/kotlin/Anthropic.kt @@ -95,6 +95,7 @@ class Anthropic internal constructor( var tools: List = emptyList() + // TODO in the future this should be rather Tool builder inline fun tool( cacheControl: CacheControl? = null, noinline inputInitializer: T.() -> Unit = {} diff --git a/src/commonMain/kotlin/content/ContentBuilder.kt b/src/commonMain/kotlin/content/ContentBuilder.kt new file mode 100644 index 0000000..777e532 --- /dev/null +++ b/src/commonMain/kotlin/content/ContentBuilder.kt @@ -0,0 +1,25 @@ +package com.xemantic.anthropic.content + +import com.xemantic.anthropic.message.Content + +interface ContentBuilder { + + val content: MutableList + + operator fun Content.unaryPlus() { + content += this + } + + operator fun String.unaryPlus() { + content += Text(this) + } + + operator fun Number.unaryPlus() { + content += Text(this.toString()) + } + + operator fun Collection.unaryPlus() { + content += this + } + +} diff --git a/src/commonMain/kotlin/content/Tool.kt b/src/commonMain/kotlin/content/Tool.kt index a5f823e..5c2e7de 100644 --- a/src/commonMain/kotlin/content/Tool.kt +++ b/src/commonMain/kotlin/content/Tool.kt @@ -10,7 +10,6 @@ import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable import kotlinx.serialization.Transient import kotlinx.serialization.json.JsonObject -import kotlin.collections.plus import kotlin.contracts.ExperimentalContracts import kotlin.contracts.InvocationKind import kotlin.contracts.contract @@ -69,35 +68,24 @@ data class ToolResult( override val cacheControl: CacheControl? = null ) : Content() { - class Builder { + class Builder : ContentBuilder { + + override val content: MutableList = mutableListOf() - var content: List = emptyList() var isError: Boolean? = null var cacheControl: CacheControl? = null - fun content(message: String) { - content = listOf(Text(text = message)) - } - fun error(message: String) { - content(message) + +message isError = true } - operator fun List.unaryPlus() { - content += this - } - - operator fun String.unaryPlus() { - content(this) - } - } } @OptIn(ExperimentalContracts::class) -fun ToolResult( +inline fun ToolResult( toolUseId: String, block: ToolResult.Builder.() -> Unit = {} ): ToolResult { @@ -105,7 +93,7 @@ fun ToolResult( callsInPlace(block, InvocationKind.EXACTLY_ONCE) } val builder = ToolResult.Builder() - builder.apply(block) + block(builder) return ToolResult( toolUseId = toolUseId, content = builder.content.toNullIfEmpty(), diff --git a/src/commonMain/kotlin/message/Messages.kt b/src/commonMain/kotlin/message/Messages.kt index f37343b..d463600 100644 --- a/src/commonMain/kotlin/message/Messages.kt +++ b/src/commonMain/kotlin/message/Messages.kt @@ -3,7 +3,7 @@ package com.xemantic.anthropic.message import com.xemantic.anthropic.Model import com.xemantic.anthropic.Response import com.xemantic.anthropic.cache.CacheControl -import com.xemantic.anthropic.content.Text +import com.xemantic.anthropic.content.ContentBuilder import com.xemantic.anthropic.tool.Tool import com.xemantic.anthropic.tool.ToolChoice import com.xemantic.anthropic.tool.ToolInput @@ -176,21 +176,11 @@ data class Message( val content: List ) { - class Builder { - var role = Role.USER - val content = mutableListOf() - - operator fun Content.unaryPlus() { - content += this - } + class Builder : ContentBuilder { - operator fun List.unaryPlus() { - content += this - } + override val content = mutableListOf() - operator fun String.unaryPlus() { - content += Text(this) - } + var role = Role.USER fun build() = Message( role = role, diff --git a/src/commonMain/kotlin/tool/Tools.kt b/src/commonMain/kotlin/tool/Tools.kt index 2742e86..166fee0 100644 --- a/src/commonMain/kotlin/tool/Tools.kt +++ b/src/commonMain/kotlin/tool/Tools.kt @@ -33,13 +33,6 @@ abstract class Tool { @PublishedApi internal lateinit var inputInitializer: ToolInput.() -> Unit - inline fun initialize( - noinline block: T.() -> Unit - ) { - @Suppress("UNCHECKED_CAST") - inputInitializer = block as ToolInput.() -> Unit - } - } @Serializable @@ -73,7 +66,13 @@ abstract class BuiltInTool( * with a given tool use ID. The implementation of the [use] method should * contain the logic for executing the tool and returning the [ToolResult]. */ -interface ToolInput { +abstract class ToolInput() { + + private lateinit var block: suspend ToolResult.Builder.() -> Unit + + fun use(block: suspend ToolResult.Builder.() -> Unit) { + this.block = block + } /** * Executes the tool and returns the result. @@ -81,7 +80,11 @@ interface ToolInput { * @param toolUseId A unique identifier for this particular use of the tool. * @return A [ToolResult] containing the outcome of executing the tool. */ - suspend fun use(toolUseId: String): ToolResult + suspend fun use(toolUseId: String): ToolResult { + return ToolResult(toolUseId) { + block(this) + } + } } diff --git a/src/commonMain/kotlin/tool/bash/Bash.kt b/src/commonMain/kotlin/tool/bash/Bash.kt index 203840b..45f8387 100644 --- a/src/commonMain/kotlin/tool/bash/Bash.kt +++ b/src/commonMain/kotlin/tool/bash/Bash.kt @@ -1,7 +1,6 @@ package com.xemantic.anthropic.tool.bash import com.xemantic.anthropic.cache.CacheControl -import com.xemantic.anthropic.content.ToolResult import com.xemantic.anthropic.tool.BuiltInTool import com.xemantic.anthropic.tool.ToolInput import kotlinx.serialization.ExperimentalSerializationApi @@ -22,10 +21,12 @@ data class Bash( data class Input( val command: String, val restart: Boolean? = false, - ) : ToolInput { + ) : ToolInput() { - override suspend fun use(toolUseId: String): ToolResult { - TODO("Not yet implemented") + init { + use { + TODO("Not yet implemented") + } } } diff --git a/src/commonMain/kotlin/tool/computer/Computer.kt b/src/commonMain/kotlin/tool/computer/Computer.kt index 9f113a0..8270e6c 100644 --- a/src/commonMain/kotlin/tool/computer/Computer.kt +++ b/src/commonMain/kotlin/tool/computer/Computer.kt @@ -1,7 +1,7 @@ package com.xemantic.anthropic.tool.computer import com.xemantic.anthropic.cache.CacheControl -import com.xemantic.anthropic.content.ToolResult +import com.xemantic.anthropic.content.Image import com.xemantic.anthropic.tool.BuiltInTool import com.xemantic.anthropic.tool.ToolInput import kotlinx.serialization.ExperimentalSerializationApi @@ -27,9 +27,6 @@ data class Computer( init { inputSerializer = Input.serializer() - initialize { - service = computerService - } } @Serializable @@ -37,14 +34,19 @@ data class Computer( val action: Action, val coordinate: Coordinate?, val text: String - ) : ToolInput { + ) : ToolInput() { @Transient lateinit var service: ComputerService - override suspend fun use( - toolUseId: String - ) = service.use(toolUseId, this) + init { + use { + when (action) { + Action.SCREENSHOT -> +service.screenshot() + else -> TODO("Not implemented yet") + } + } + } } @@ -81,11 +83,6 @@ data class Coordinate(val x: Int, val y: Int) interface ComputerService { - suspend fun use( - toolUseId: String, - input: Computer.Input - ): ToolResult + fun screenshot(): Image } - -expect val computerService: ComputerService diff --git a/src/commonMain/kotlin/tool/editor/TextEditor.kt b/src/commonMain/kotlin/tool/editor/TextEditor.kt index f0c616d..dac8d91 100644 --- a/src/commonMain/kotlin/tool/editor/TextEditor.kt +++ b/src/commonMain/kotlin/tool/editor/TextEditor.kt @@ -1,7 +1,6 @@ package com.xemantic.anthropic.tool.editor import com.xemantic.anthropic.cache.CacheControl -import com.xemantic.anthropic.content.ToolResult import com.xemantic.anthropic.tool.BuiltInTool import com.xemantic.anthropic.tool.ToolInput import kotlinx.serialization.ExperimentalSerializationApi @@ -21,9 +20,6 @@ data class TextEditor( init { inputSerializer = Input.serializer() - inputInitializer = { - //service = computerService - } } @Serializable @@ -41,14 +37,19 @@ data class TextEditor( val path: String, @SerialName("view_range") val viewRange: Int? = 0 - ) : ToolInput { + ) : ToolInput() { @Transient lateinit var service: TextEditorService - override suspend fun use( - toolUseId: String - ) = service.use(toolUseId, this) + init { + use { + when (command) { + Command.VIEW -> service.view(path) + else -> TODO("not implemented yet") + } + } + } } @@ -70,11 +71,6 @@ enum class Command { interface TextEditorService { - suspend fun use( - toolUseId: String, - input: TextEditor.Input - ): ToolResult + fun view(path: String) } - -expect val textEditorService: TextEditorService diff --git a/src/commonTest/kotlin/message/MessageRequestTest.kt b/src/commonTest/kotlin/message/MessageRequestTest.kt index dafbff4..3525067 100644 --- a/src/commonTest/kotlin/message/MessageRequestTest.kt +++ b/src/commonTest/kotlin/message/MessageRequestTest.kt @@ -1,6 +1,5 @@ package com.xemantic.anthropic.message -import com.xemantic.anthropic.content.ToolResult import com.xemantic.anthropic.message.MessageRequestTest.TemperatureUnit import com.xemantic.anthropic.schema.Description import com.xemantic.anthropic.test.testJson @@ -25,12 +24,12 @@ data class GetWeather( val location: String, @Description("The unit of temperature, either 'celsius' or 'fahrenheit'") val unit: TemperatureUnit? = null -) : ToolInput { - - override suspend fun use( - toolUseId: String - ) = ToolResult(toolUseId) { +"42" } - +) : ToolInput() { + init { + use { + +"42" + } + } } /** diff --git a/src/commonTest/kotlin/tool/AnthropicTestTools.kt b/src/commonTest/kotlin/tool/AnthropicTestTools.kt index 74d32d1..7dc3c60 100644 --- a/src/commonTest/kotlin/tool/AnthropicTestTools.kt +++ b/src/commonTest/kotlin/tool/AnthropicTestTools.kt @@ -1,25 +1,22 @@ package com.xemantic.anthropic.tool -import com.xemantic.anthropic.content.ToolResult import com.xemantic.anthropic.schema.Description import kotlinx.serialization.Transient +tailrec fun fibonacci( + n: Int, a: Int = 0, b: Int = 1 +): Int = when (n) { + 0 -> a; 1 -> b; else -> fibonacci(n - 1, b, a + b) +} + @AnthropicTool("FibonacciTool") @Description("Calculate Fibonacci number n") -data class FibonacciTool(val n: Int): ToolInput { - - tailrec fun fibonacci( - n: Int, a: Int = 0, b: Int = 1 - ): Int = when (n) { - 0 -> a; 1 -> b; else -> fibonacci(n - 1, b, a + b) - } - - override suspend fun use( - toolUseId: String, - ) = ToolResult(toolUseId) { - +"${fibonacci(n)}" +data class FibonacciTool(val n: Int) : ToolInput() { + init { + use { + +fibonacci(n) + } } - } @AnthropicTool("Calculator") @@ -28,7 +25,7 @@ data class Calculator( val operation: Operation, val a: Double, val b: Double -): ToolInput { +): ToolInput() { @Suppress("unused") // it is used, but by Anthropic, so we skip the warning enum class Operation( @@ -40,8 +37,10 @@ data class Calculator( DIVIDE({ a, b -> a / b }) } - override suspend fun use(toolUseId: String) = ToolResult(toolUseId) { - +operation.calculate(a, b).toString() + init { + use { + +operation.calculate(a, b) + } } } @@ -64,17 +63,14 @@ class TestDatabase : Database { @Description("Executes database query") data class DatabaseQuery( val query: String -) : ToolInput { +) : ToolInput() { @Transient internal lateinit var database: Database - override suspend fun use( - toolUseId: String - ): ToolResult { - val result = database.execute(query).joinToString() - return ToolResult(toolUseId) { - +result + init { + use { + +database.execute(query).joinToString() } } diff --git a/src/commonTest/kotlin/tool/ToolInputTest.kt b/src/commonTest/kotlin/tool/ToolInputTest.kt index c6b290f..3f3a3e9 100644 --- a/src/commonTest/kotlin/tool/ToolInputTest.kt +++ b/src/commonTest/kotlin/tool/ToolInputTest.kt @@ -1,7 +1,6 @@ package com.xemantic.anthropic.tool import com.xemantic.anthropic.cache.CacheControl -import com.xemantic.anthropic.content.ToolResult import com.xemantic.anthropic.schema.Description import com.xemantic.anthropic.schema.JsonSchema import com.xemantic.anthropic.schema.JsonSchemaProperty @@ -19,10 +18,12 @@ class ToolInputTest { class TestToolInput( @Description("the message") val message: String - ) : ToolInput { - override suspend fun use( - toolUseId: String - ) = ToolResult(toolUseId) { +message } + ) : ToolInput() { + init { + use { + +message + } + } } @Test @@ -66,11 +67,7 @@ class ToolInputTest { } } - class NoAnnotationTool : ToolInput { - override suspend fun use( - toolUseId: String - ) = ToolResult(toolUseId) - } + class NoAnnotationTool : ToolInput() @Test fun shouldFailToCreateToolWithoutAnthropicToolAnnotation() { @@ -83,11 +80,7 @@ class ToolInputTest { } @Serializable - class OnlySerializableAnnotationTool : ToolInput { - override suspend fun use( - toolUseId: String - ) = ToolResult(toolUseId) - } + class OnlySerializableAnnotationTool : ToolInput() @Test fun shouldFailToCreateToolWithOnlySerializableAnnotation() { diff --git a/src/jvmMain/kotlin/tool/computer/JvmComputer.kt b/src/jvmMain/kotlin/tool/computer/JvmComputer.kt index e1338cd..60df2c1 100644 --- a/src/jvmMain/kotlin/tool/computer/JvmComputer.kt +++ b/src/jvmMain/kotlin/tool/computer/JvmComputer.kt @@ -1,7 +1,6 @@ package com.xemantic.anthropic.tool.computer import com.xemantic.anthropic.content.Image -import com.xemantic.anthropic.content.ToolResult import java.awt.Rectangle import java.awt.Robot import java.awt.Toolkit @@ -10,20 +9,9 @@ import javax.imageio.ImageIO object JvmComputerService : ComputerService { - override suspend fun use( - toolUseId: String, - input: Computer.Input - ) = when (input.action) { - Action.SCREENSHOT -> ToolResult( - toolUseId = toolUseId, - content = listOf( - Image { - data = takeScreenshot() - mediaType = Image.MediaType.IMAGE_JPEG - } - ) - ) - else -> TODO() + override fun screenshot() = Image { + data = takeScreenshot() + mediaType = Image.MediaType.IMAGE_JPEG } } @@ -38,5 +26,3 @@ fun takeScreenshot(): ByteArray { } return output.toByteArray() } - -actual val computerService: ComputerService get() = JvmComputerService diff --git a/src/jvmMain/kotlin/tool/editor/JvmTextEditor.kt b/src/jvmMain/kotlin/tool/editor/JvmTextEditor.kt index 6690b25..003a65b 100644 --- a/src/jvmMain/kotlin/tool/editor/JvmTextEditor.kt +++ b/src/jvmMain/kotlin/tool/editor/JvmTextEditor.kt @@ -1,24 +1,9 @@ package com.xemantic.anthropic.tool.editor -import com.xemantic.anthropic.content.ToolResult -import java.io.File - object JvmTextEditorService : TextEditorService { - override suspend fun use( - toolUseId: String, - input: TextEditor.Input - ): ToolResult { - val content = if (input.command == Command.VIEW) { - File(input.path).readText() - } else { - TODO("Not implemented yet") - } - return ToolResult(toolUseId) { - content(content) - } + override fun view(path: String) { + TODO("Not yet implemented") } } - -actual val textEditorService: TextEditorService get() = JvmTextEditorService