diff --git a/README.md b/README.md index 0568979..6c7340a 100644 --- a/README.md +++ b/README.md @@ -3,42 +3,44 @@ Unofficial Kotlin multiplatform variant of the [Antropic SDK](https://docs.anthropic.com/en/api/client-sdks). -[Maven Central Version](https://central.sonatype.com/namespace/com.xemantic.anthropic) -[GitHub Release Date](https://github.com/xemantic/anthropic-sdk-kotlin/releases) -[license](https://github.com/xemantic/anthropic-sdk-kotlin/blob/main/LICENSE) - -[GitHub Actions Workflow Status](https://github.com/xemantic/anthropic-sdk-kotlin/actions/workflows/build-main.yml) -[GitHub branch check runs](https://github.com/xemantic/anthropic-sdk-kotlin/actions/workflows/build-main.yml) -[GitHub commits since latest release](https://github.com/xemantic/anthropic-sdk-kotlin/commits/main/) -[GitHub last commit](https://github.com/xemantic/anthropic-sdk-kotlin/commits/main/) - -[GitHub contributors](https://github.com/xemantic/anthropic-sdk-kotlin/graphs/contributors) -[GitHub commit activity](https://github.com/xemantic/anthropic-sdk-kotlin/commits/main/) -[GitHub code size in bytes]() -[GitHub Created At](https://github.com/xemantic/anthropic-sdk-kotlin/commit/39c1fa4c138d4c671868c973e2ad37b262ae03c2) -[kotlin version](https://kotlinlang.org/docs/releases.html) -[ktor version](https://ktor.io/) - -[discord server](https://discord.gg/vQktqqN2Vn) -[discord users online](https://discord.gg/vQktqqN2Vn) -[X (formerly Twitter) Follow](https://x.com/KazikPogoda) +[Maven Central Version](https://central.sonatype.com/namespace/com.xemantic.anthropic) +[GitHub Release Date](https://github.com/xemantic/anthropic-sdk-kotlin/releases) +[license](https://github.com/xemantic/anthropic-sdk-kotlin/blob/main/LICENSE) + +[GitHub Actions Workflow Status](https://github.com/xemantic/anthropic-sdk-kotlin/actions/workflows/build-main.yml) +[GitHub branch check runs](https://github.com/xemantic/anthropic-sdk-kotlin/actions/workflows/build-main.yml) +[GitHub commits since latest release](https://github.com/xemantic/anthropic-sdk-kotlin/commits/main/) +[GitHub last commit](https://github.com/xemantic/anthropic-sdk-kotlin/commits/main/) + +[GitHub contributors](https://github.com/xemantic/anthropic-sdk-kotlin/graphs/contributors) +[GitHub commit activity](https://github.com/xemantic/anthropic-sdk-kotlin/commits/main/) +[GitHub code size in bytes]() +[GitHub Created At](https://github.com/xemantic/anthropic-sdk-kotlin/commit/39c1fa4c138d4c671868c973e2ad37b262ae03c2) +[kotlin version](https://kotlinlang.org/docs/releases.html) +[ktor version](https://ktor.io/) + +[discord server](https://discord.gg/vQktqqN2Vn) +[discord users online](https://discord.gg/vQktqqN2Vn) +[X (formerly Twitter) Follow](https://x.com/KazikPogoda) ## Why? -I like Kotlin. I like even more the multiplatform aspect of pure Kotlin - that a library code written once -can be utilized as a: - -* regular Java library to be used on backend, desktop, Android, etc. -* Kotlin library to be used on backend, desktop, Android. -* executable native binary (e.g. a command line tool) -* Kotlin app transpiled to JavaScript -* Kotlin app compiled to WebAssembly -* JavaScript library -* TypeScript library -* native library, working also with Swift/iOS - -Having Kotlin multiplatform library for the Anthropic APIs allows -me to write AI code once, and target all the platforms automatically. +Because I believe that coding Agentic AI should be as easy as possible. I am coming from the +[creative coding community](https://creativecode.berlin/), where +we are teaching artists, without prior programming experience, how to express their creations through +code as a medium. I want to give creators of all kinds this extremely powerful tool, so that +**you can turn your own machine into an outside window, through which, the AI system can perceive +your world and your needs, and act upon this information.** + +There is no official Anthropic SDK for Kotlin, a de facto standard for Android development. The one for Java +is also lacking. Even if they will appear one day, we can expect them to be autogenerated by the +[Stainless API bot](https://www.stainlessapi.com/), which is used by both, Anthropic and OpenAI, to automate +their SDK development based on evolving API. While such an approach seem to work with dynamically typed languages, +it might fail short with statically typed languages like Kotlin, sacrificing typical language idioms in favor +of [over-verbose constructs](https://github.com/anthropics/anthropic-sdk-go/blob/main/examples/tools/main.go). +This library is a [Kotlin multiplatform](https://kotlinlang.org/docs/multiplatform.html) +therefore your AI agents developed with it can be seamlessly used in Android, JVM, JavaScript, iOS, WebAssembly, +and many other environments. ## Usage @@ -50,9 +52,7 @@ Add to your `build.gradle.kts`: ```kotlin dependencies { implementation("com.xemantic.anthropic:anthropic-sdk-kotlin:.0.2.2") - // for a JVM project, the client engine will differ per platform - // check ktor doucmentation for details - implementation("io.ktor:ktor-client-java:3.0.0-rc-1") + implementation("io.ktor:ktor-client-java:3.0.0") // or the latest ktor version } ``` @@ -60,11 +60,11 @@ The simplest code look like: ```kotlin fun main() { - val client = Anthropic() + val anthropic = Anthropic() val response = runBlocking { - client.messages.create { + anthropic.messages.create { +Message { - +"Hello World!" + +"Hello, Claude" } } } @@ -72,77 +72,136 @@ fun main() { } ``` +### Response streaming + Streaming is also possible: ```kotlin fun main() { val client = Anthropic() - val pong = runBlocking { + runBlocking { client.messages.stream { - +Message { - role = Role.USER - +"ping!" - } + +Message { +"Write me a poem." } } - .filterIsInstance() - .map { (it.delta as TextDelta).text } - .toList() - .joinToString(separator = "") + .filterIsInstance() + .map { (it.delta as Delta.TextDelta).text } + .collect { delta -> println(delta) } } - println(pong) } ``` -It can also use tools: +### Using tools + +If you want to write AI agents, you need tools, and this is where this library shines: ```kotlin -@Serializable -data class Calculator( - val operation: Operation, - val a: Double, - val b: Double -) { - - @Suppress("unused") // will be used by Anthropic :) - enum class Operation( - val calculate: (a: Double, b: Double) -> Double - ) { - ADD({ a, b -> a + b }), - SUBTRACT({ a, b -> a - b }), - MULTIPLY({ a, b -> a * b }), - DIVIDE({ a, b -> a / b }) +@AnthropicTool( + name = "get_weather", + description = "Get the weather for a specific location" +) +data class WeatherTool(val location: String): UsableTool { + override fun use( + toolUseId: String + ) = ToolResult( + toolUseId, + "The weather is 73f" // it should use some external service + ) +} + +fun main() = runBlocking { + + val client = Anthropic { + tool() + } + + val conversation = mutableListOf() + conversation += Message { +"What is the weather in SF?" } + + val initialResponse = client.messages.create { + messages = conversation + useTools() } + println("Initial response:") + println(initialResponse) - fun calculate() = operation.calculate(a, b) + conversation += initialResponse.asMessage() + val tool = initialResponse.content.filterIsInstance().first() + val toolResult = tool.use() + conversation += Message { +toolResult } + val finalResponse = client.messages.create { + messages = conversation + useTools() + } + println("Final response:") + println(finalResponse) } +``` -fun main() { - val client = Anthropic() +The advantage comes no only from reduced verbosity, but also the class annotated with +the `@AnthropicTool` will have its JSON schema automatically sent to the Anthropic API when +defining the tool to use. For the reference check equivalent examples in the official +Anthropic SDKs: - val calculatorTool = Tool( - description = "Perform basic arithmetic operations" - ) +* [TypeScript](https://github.com/anthropics/anthropic-sdk-typescript/blob/main/examples/tools.ts) +* [Python](https://github.com/anthropics/anthropic-sdk-python/blob/main/examples/tools.py) +* [Go](https://github.com/anthropics/anthropic-sdk-go/blob/main/examples/tools/main.go) - val response = runBlocking { - client.messages.create { - +Message { - +"What's 15 multiplied by 7?" +None of them is taking the advantage of automatic schema generation, which becomes crucial +for maintaining agents expecting more complex and structured input from the LLM. + +### Injecting dependencies to tools + +Tools can be provided with dependencies, for example singleton +services providing some facilities, like HTTP client to connect to the +internet or DB connection pool to access the database. + +```kotlin +@AnthropicTool( + name = "query_database", + description = "Executes SQL on the database" +) +data class DatabaseQueryTool(val sql: String): UsableTool { + + internal lateinit var connection: Connection + + override fun use( + toolUseId: String + ) = ToolResult( + toolUseId, + text = connection.prepareStatement(sql).use { statement -> + statement.resultSet.use { resultSet -> + resultSet.toString() } - tools = listOf(calculatorTool) - toolChoice = ToolChoice.Any() } + ) + +} + +fun main() = runBlocking { + + val client = Anthropic { + tool { + connection = DriverManager.getConnection("jdbc:...") + } + } + + val response = client.messages.create { + +Message { +"Select all the users who never logged in to the the system" } + useTools() } - val toolUse = response.content[0] as ToolUse - val calculator = toolUse.input() - val result = calculator.calculate() // we are doing the job for LLM here - println(result) + val tool = response.content.filterIsInstance().first() + val toolResult = tool.use() + println(toolResult) } ``` -More sophisticated code examples targeting various -platforms will follow in the +After the `DatabaseQueryTool` is decoded from the API response, it can be processed +by the lambda function passed to the tool definition. In case of the example above, +the lambda will inject a JDBC connection to the tool. + +More sophisticated code examples targeting various Kotlin platforms can be found in the [anthropic-sdk-kotlin-demo](https://github.com/xemantic/anthropic-sdk-kotlin-demo) project. diff --git a/build.gradle.kts b/build.gradle.kts index 2d87906..bb2aaf5 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -6,7 +6,7 @@ import org.gradle.api.tasks.testing.logging.TestLogEvent import org.jetbrains.kotlin.gradle.ExperimentalKotlinGradlePluginApi import org.jetbrains.kotlin.gradle.dsl.JvmTarget import org.jetbrains.kotlin.gradle.dsl.KotlinVersion -import org.jetbrains.kotlin.gradle.tasks.KotlinJvmCompile +import org.jetbrains.kotlin.gradle.targets.native.tasks.KotlinNativeTest plugins { alias(libs.plugins.kotlin.multiplatform) @@ -32,6 +32,10 @@ val signingPassword: String? by project val sonatypeUser: String? by project val sonatypePassword: String? by project +// we don't want to risk that a flaky test will crash the release build +// and everything should be tested anyway after merging to the main branch +val skipTests = isReleaseBuild + println(""" Project: ${project.name} Version: ${project.version} @@ -40,12 +44,25 @@ println(""" ) repositories { - mavenCentral() + mavenCentral() } kotlin { - jvm {} + //explicitApi() // check with serialization? + jvm { + testRuns["test"].executionTask.configure { + useJUnitPlatform() + } + // set up according to https://jakewharton.com/gradle-toolchains-are-rarely-a-good-idea/ + compilerOptions { + apiVersion = kotlinTarget + languageVersion = kotlinTarget + jvmTarget = JvmTarget.fromTarget(javaTarget) + freeCompilerArgs.add("-Xjdk-release=$javaTarget") + progressiveMode = true + } + } linuxX64() @@ -64,6 +81,7 @@ kotlin { dependencies { implementation(libs.kotlin.test) implementation(libs.kotlinx.coroutines.test) + implementation(libs.kotest.assertions.core) implementation(libs.kotest.assertions.json) } } @@ -89,17 +107,6 @@ kotlin { } -// set up according to https://jakewharton.com/gradle-toolchains-are-rarely-a-good-idea/ -tasks.withType { - compilerOptions { - apiVersion = kotlinTarget - languageVersion = kotlinTarget - jvmTarget = JvmTarget.fromTarget(javaTarget) - freeCompilerArgs.add("-Xjdk-release=$javaTarget") - progressiveMode = true - } -} - fun isNonStable(version: String): Boolean { val stableKeyword = listOf("RELEASE", "FINAL", "GA").any { version.uppercase().contains(it) } val regex = "^[0-9,.v-]+(-r)?$".toRegex() @@ -113,7 +120,7 @@ tasks.withType { } } -tasks.withType() { +tasks.withType { testLogging { events( TestLogEvent.PASSED, @@ -123,15 +130,16 @@ tasks.withType() { showStackTraces = true exceptionFormat = TestExceptionFormat.FULL } + enabled = !skipTests +} + +tasks.withType { + enabled = !skipTests } -@Suppress("OPT_IN_USAGE") powerAssert { functions = listOf( - "kotlin.assert", - "kotlin.test.assertTrue", - "kotlin.test.assertEquals", - "kotlin.test.assertNull" + "io.kotest.matchers.shouldBe" ) includedSourceSets = listOf("commonTest", "jvmTest", "nativeTest") } diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 173847e..015165c 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -2,11 +2,13 @@ kotlinTarget = "2.0" javaTarget = "17" -kotlin = "2.0.20" +kotlin = "2.0.21" kotlinxCoroutines = "1.9.0" -ktor = "3.0.0-rc-2" -kotest = "5.9.1" +ktor = "3.0.0" +kotest = "6.0.0.M1" +# logging is not used at the moment, might be enabled later +#kotlinLogging = "7.0.0" log4j = "2.24.1" jackson = "2.18.0" @@ -18,10 +20,12 @@ publishPlugin = "2.0.0" kotlin-test = { module = "org.jetbrains.kotlin:kotlin-test", version.ref = "kotlin" } kotlinx-coroutines-test = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-test", version.ref = "kotlinxCoroutines" } -log4j-slf4j2 = { group = "org.apache.logging.log4j", name = "log4j-slf4j2-impl", version.ref = "log4j" } -log4j-core = { group = "org.apache.logging.log4j", name = "log4j-core", version.ref = "log4j" } -jackson-databind = { group = "com.fasterxml.jackson.core", name = "jackson-databind", version.ref = "jackson" } -jackson-dataformat-yaml = { group = "com.fasterxml.jackson.dataformat", name = "jackson-dataformat-yaml", version.ref = "jackson" } +# logging libs +#kotlin-logging = { module = "io.github.oshai:kotlin-logging", version.ref = "kotlinLogging" } +log4j-slf4j2 = { module = "org.apache.logging.log4j:log4j-slf4j2-impl", version.ref = "log4j" } +log4j-core = { module = "org.apache.logging.log4j:log4j-core", version.ref = "log4j" } +jackson-databind = { module = "com.fasterxml.jackson.core:jackson-databind", version.ref = "jackson" } +jackson-dataformat-yaml = { module = "com.fasterxml.jackson.dataformat:jackson-dataformat-yaml", version.ref = "jackson" } ktor-client-core = { module = "io.ktor:ktor-client-core", version.ref = "ktor" } ktor-client-content-negotiation = { module = "io.ktor:ktor-client-content-negotiation", version.ref = "ktor" } @@ -30,6 +34,7 @@ ktor-serialization-kotlinx-json = { module = "io.ktor:ktor-serialization-kotlinx ktor-client-java = { module = "io.ktor:ktor-client-java", version.ref = "ktor" } ktor-client-curl = { module = "io.ktor:ktor-client-curl", version.ref = "ktor" } +kotest-assertions-core = { module = "io.kotest:kotest-assertions-core", version.ref = "kotest" } kotest-assertions-json = { module = "io.kotest:kotest-assertions-json", version.ref = "kotest" } [plugins] diff --git a/src/commonMain/kotlin/Anthropic.kt b/src/commonMain/kotlin/Anthropic.kt index b38f0a4..72d0834 100644 --- a/src/commonMain/kotlin/Anthropic.kt +++ b/src/commonMain/kotlin/Anthropic.kt @@ -5,10 +5,14 @@ import com.xemantic.anthropic.message.Error import com.xemantic.anthropic.message.ErrorResponse import com.xemantic.anthropic.message.MessageRequest import com.xemantic.anthropic.message.MessageResponse +import com.xemantic.anthropic.message.Tool +import com.xemantic.anthropic.message.ToolUse +import com.xemantic.anthropic.tool.UsableTool +import com.xemantic.anthropic.tool.toolOf import io.ktor.client.HttpClient import io.ktor.client.call.body +import io.ktor.client.plugins.* import io.ktor.client.plugins.contentnegotiation.ContentNegotiation -import io.ktor.client.plugins.defaultRequest import io.ktor.client.plugins.logging.LogLevel import io.ktor.client.plugins.logging.Logging import io.ktor.client.plugins.sse.SSE @@ -26,12 +30,25 @@ import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.filter import kotlinx.coroutines.flow.flow import kotlinx.coroutines.flow.map +import kotlinx.serialization.KSerializer import kotlinx.serialization.json.Json +import kotlinx.serialization.serializer +import kotlin.reflect.KType +import kotlin.reflect.typeOf +/** + * The default Anthropic API base. + */ const val ANTHROPIC_API_BASE: String = "https://api.anthropic.com/" +/** + * The default version to be passed to the `anthropic-version` HTTP header of each API request. + */ const val DEFAULT_ANTHROPIC_VERSION: String = "2023-06-01" +/** + * An exception thrown when API requests returns error. + */ class AnthropicException( error: Error, httpStatusCode: HttpStatusCode @@ -41,19 +58,27 @@ expect val envApiKey: String? expect val missingApiKeyMessage: String +/** + * A JSON format suitable for communication with Anthropic API. + */ val anthropicJson: Json = Json { allowSpecialFloatingPointValues = true explicitNulls = false encodeDefaults = true } +/** + * The public constructor function which for the Anthropic API client. + * + * @param block the config block to set up the API access. + */ fun Anthropic( block: Anthropic.Config.() -> Unit = {} ): Anthropic { val config = Anthropic.Config().apply(block) val apiKey = if (config.apiKey != null) config.apiKey else envApiKey requireNotNull(apiKey) { missingApiKeyMessage } - val defaultModel = if (config.defaultModel != null) config.defaultModel!! else "claude-3-opus-20240229" + val defaultModel = if (config.defaultModel != null) config.defaultModel!! else "claude-3-5-sonnet-20240620" return Anthropic( apiKey = apiKey, anthropicVersion = config.anthropicVersion, @@ -61,8 +86,10 @@ fun Anthropic( apiBase = config.apiBase, defaultModel = defaultModel, directBrowserAccess = config.directBrowserAccess - ) -} + ).apply { + toolEntryMap = (config.usableTools as List>).associateBy { it.tool.name } + } +} // TODO this can be a second constructor, then toolMap can be private class Anthropic internal constructor( val apiKey: String, @@ -80,8 +107,28 @@ class Anthropic internal constructor( var apiBase: String = ANTHROPIC_API_BASE var defaultModel: String? = null var directBrowserAccess: Boolean = false + @PublishedApi + internal var usableTools: List> = emptyList() + + inline fun tool( + noinline block: T.() -> Unit = {} + ) { + val entry = ToolEntry(typeOf(), toolOf(), serializer(), block) + usableTools += entry + } + } + @PublishedApi + internal class ToolEntry( + val type: KType, + val tool: Tool, // TODO, no cache control + val serializer: KSerializer, + val initialize: T.() -> Unit = {} + ) + + internal var toolEntryMap = mapOf>() + private val client = HttpClient { install(ContentNegotiation) { json(anthropicJson) @@ -90,6 +137,14 @@ class Anthropic internal constructor( install(Logging) { level = LogLevel.BODY } + install(HttpRequestRetry) { + retryOnServerErrors(maxRetries = 5) + exponentialDelay() + maxRetries = 5 + retryIf { _, response -> + response.status == HttpStatusCode.TooManyRequests + } + } defaultRequest { url(apiBase) header("x-api-key", apiKey) @@ -103,20 +158,29 @@ class Anthropic internal constructor( } } - inner class Messages() { + inner class Messages { suspend fun create( block: MessageRequest.Builder.() -> Unit ): MessageResponse { + val request = MessageRequest.Builder( - defaultModel + defaultModel, + toolEntryMap = toolEntryMap ).apply(block).build() + val response = client.post("/v1/messages") { contentType(ContentType.Application.Json) setBody(request) } if (response.status.isSuccess()) { - return response.body() + return response.body().apply { + content.filterIsInstance() + .forEach { toolUse -> + val entry = toolEntryMap[toolUse.name]!! + toolUse.toolEntry = entry + } + } } else { throw AnthropicException( error = response.body().error, @@ -129,7 +193,10 @@ class Anthropic internal constructor( block: MessageRequest.Builder.() -> Unit ): Flow = flow { - val request = MessageRequest.Builder(defaultModel).apply { + val request = MessageRequest.Builder( + defaultModel, + toolEntryMap = toolEntryMap + ).apply { block(this) stream = true }.build() @@ -157,5 +224,3 @@ class Anthropic internal constructor( } -inline fun anthropicTypeOf(): String = - T::class.qualifiedName!!.replace('.', '_') diff --git a/src/commonMain/kotlin/event/Events.kt b/src/commonMain/kotlin/event/Events.kt index b9d3da8..7c5fa32 100644 --- a/src/commonMain/kotlin/event/Events.kt +++ b/src/commonMain/kotlin/event/Events.kt @@ -7,6 +7,8 @@ import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable import kotlinx.serialization.json.JsonClassDiscriminator +// reference https://docs.spring.io/spring-ai/reference/_images/anthropic-claude3-events-model.jpg + @Serializable @JsonClassDiscriminator("type") @OptIn(ExperimentalSerializationApi::class) @@ -14,14 +16,14 @@ sealed class Event @Serializable @SerialName("message_start") -data class MessageStart( +data class MessageStartEvent( val message: MessageResponse ) : Event() @Serializable @SerialName("message_delta") -data class MessageDelta( - val delta: MessageDelta.Delta, +data class MessageDeltaEvent( + val delta: Delta, val usage: Usage ) : Event() { @@ -37,13 +39,15 @@ data class MessageDelta( @Serializable @SerialName("message_stop") -class MessageStop : Event() { +class MessageStopEvent : Event() { override fun toString(): String = "MessageStop" } +// TODO error event is missing, should we rename all of these to events? + @Serializable @SerialName("content_block_start") -data class ContentBlockStart( +data class ContentBlockStartEvent( val index: Int, @SerialName("content_block") val contentBlock: ContentBlock @@ -51,7 +55,7 @@ data class ContentBlockStart( @Serializable @SerialName("content_block_stop") -data class ContentBlockStop( +data class ContentBlockStopEvent( val index: Int ) : Event() @@ -66,17 +70,23 @@ sealed class ContentBlock { val text: String ) : ContentBlock() + @Serializable + @SerialName("tool_use") + class ToolUse( + val text: String // TODO tool_id + ) : ContentBlock() + // TODO missing tool_use } @Serializable @SerialName("ping") -class Ping: Event() { +class PingEvent: Event() { override fun toString(): String = "Ping" } @Serializable @SerialName("content_block_delta") -data class ContentBlockDelta( +data class ContentBlockDeltaEvent( val index: Int, val delta: Delta ) : Event() diff --git a/src/commonMain/kotlin/message/Messages.kt b/src/commonMain/kotlin/message/Messages.kt index 6044cfb..defa70c 100644 --- a/src/commonMain/kotlin/message/Messages.kt +++ b/src/commonMain/kotlin/message/Messages.kt @@ -1,16 +1,14 @@ package com.xemantic.anthropic.message +import com.xemantic.anthropic.Anthropic import com.xemantic.anthropic.anthropicJson -import com.xemantic.anthropic.anthropicTypeOf import com.xemantic.anthropic.schema.JsonSchema -import com.xemantic.anthropic.schema.jsonSchemaOf -import kotlinx.serialization.ExperimentalSerializationApi -import kotlinx.serialization.SerialName -import kotlinx.serialization.Serializable +import com.xemantic.anthropic.tool.UsableTool +import kotlinx.serialization.* import kotlinx.serialization.json.JsonClassDiscriminator import kotlinx.serialization.json.JsonObject -import kotlinx.serialization.json.decodeFromJsonElement import kotlin.collections.mutableListOf +import kotlin.reflect.typeOf enum class Role { @SerialName("user") @@ -35,32 +33,54 @@ data class MessageRequest( @SerialName("stop_sequences") val stopSequences: List?, val stream: Boolean?, - val system: List?, + val system: List?, val temperature: Double?, @SerialName("tool_choice") val toolChoice: ToolChoice?, val tools: List?, + @SerialName("top_k") val topK: Int?, + @SerialName("top_p") val topP: Int? ) { - class Builder( - val defaultApiModel: String + class Builder internal constructor( + private val defaultModel: String, + @PublishedApi + internal val toolEntryMap: Map> ) { var model: String? = null var maxTokens = 1024 - val messages = mutableListOf() + var messages: List = mutableListOf() var metadata = null val stopSequences = mutableListOf() var stream: Boolean? = null internal set - val systemTexts = mutableListOf() + var system: List? = null var temperature: Double? = null var toolChoice: ToolChoice? = null var tools: List? = null val topK: Int? = null val topP: Int? = null + fun useTools() { + tools = toolEntryMap.values.map { it.tool } + } + + /** + * Sets both, the [tools] list and the [toolChoice] with + * just one tool to use, forcing the API to respond with the [ToolUse]. + */ + inline fun useTool() { + val type = typeOf() + val toolEntry = toolEntryMap.values.find { it.type == type } + requireNotNull(toolEntry) { + "No such tool defined in Anthropic client: ${T::class.qualifiedName}" + } + tools = listOf(toolEntry.tool) + toolChoice = ToolChoice.Tool(name = toolEntry.tool.name) + } + fun messages(vararg messages: Message) { this.messages += messages.toList() } @@ -77,23 +97,20 @@ data class MessageRequest( this.stopSequences += stopSequences.toList() } - var system: String? - get() = if (systemTexts.isEmpty()) null else systemTexts[0].text - set(value) { - systemTexts.clear() - if (value != null) { - systemTexts.add(Text(text = value)) - } - } + fun system( + text: String + ) { + system = listOf(System(text = text)) + } fun build(): MessageRequest = MessageRequest( - model = if (model != null) model!! else defaultApiModel, + model = if (model != null) model!! else defaultModel, maxTokens = maxTokens, messages = messages, metadata = metadata, stopSequences = stopSequences.toNullIfEmpty(), stream = if (stream != null) stream else null, - system = systemTexts.toNullIfEmpty(), + system = system, temperature = temperature, toolChoice = toolChoice, tools = tools, @@ -108,7 +125,9 @@ fun MessageRequest( defaultModel: String, block: MessageRequest.Builder.() -> Unit ): MessageRequest { - val builder = MessageRequest.Builder(defaultModel) + val builder = MessageRequest.Builder( + defaultModel, emptyMap() + ) block(builder) return builder.build() } @@ -131,8 +150,15 @@ data class MessageResponse( @SerialName("message") MESSAGE } + + fun asMessage(): Message = Message { + role = Role.ASSISTANT + content += this@MessageResponse.content + } + } + @Serializable data class ErrorResponse( val type: String, @@ -180,30 +206,37 @@ fun Message(block: Message.Builder.() -> Unit): Message { return builder.build() } +@Serializable +data class System( + @SerialName("cache_control") + val cacheControl: CacheControl? = null, + val type: Type = Type.TEXT, + val text: String? = null, +) { + + enum class Type { + @SerialName("text") + TEXT + } + +} + @Serializable data class Tool( val name: String, val description: String, @SerialName("input_schema") val inputSchema: JsonSchema, + @SerialName("cache_control") val cacheControl: CacheControl? ) -inline fun Tool( - description: String, - cacheControl: CacheControl? = null -): Tool = Tool( - name = anthropicTypeOf(), - description = description, - inputSchema = jsonSchemaOf(), - cacheControl = cacheControl -) - @Serializable @JsonClassDiscriminator("type") @OptIn(ExperimentalSerializationApi::class) sealed class Content { + @SerialName("cache_control") abstract val cacheControl: CacheControl? } @@ -212,6 +245,7 @@ sealed class Content { @SerialName("text") data class Text( val text: String, + @SerialName("cache_control") override val cacheControl: CacheControl? = null, ) : Content() @@ -219,6 +253,7 @@ data class Text( @SerialName("image") data class Image( val source: Source, + @SerialName("cache_control") override val cacheControl: CacheControl? = null ) : Content() { @@ -250,38 +285,70 @@ data class Image( } -@Serializable @SerialName("tool_use") +@Serializable data class ToolUse( + @SerialName("cache_control") override val cacheControl: CacheControl? = null, val id: String, val name: String, val input: JsonObject ) : Content() { - inline fun input(): T = - anthropicJson.decodeFromJsonElement(input) + @Transient + internal lateinit var toolEntry: Anthropic.ToolEntry + + fun use(): ToolResult { + val tool = anthropicJson.decodeFromJsonElement( + deserializer = toolEntry.serializer, + element = input + ) + val result = try { + toolEntry.initialize(tool) + tool.use(toolUseId = id) + } catch (e: Exception) { + ToolResult( + toolUseId = id, + isError = true, + content = listOf( + Text( + text = e.message ?: "Unknown error occurred" + ) + ) + ) + } + return result + } } @Serializable @SerialName("tool_result") data class ToolResult( - override val cacheControl: CacheControl? = null, @SerialName("tool_use_id") val toolUseId: String, + val content: List, // TODO only Text, Image allowed here, should be accessible in gthe builder @SerialName("is_error") val isError: Boolean = false, - val content: List + @SerialName("cache_control") + override val cacheControl: CacheControl? = null ) : Content() +fun ToolResult( + toolUseId: String, + text: String +): ToolResult = ToolResult( + toolUseId, + content = listOf(Text(text)) +) + @Serializable data class CacheControl( val type: Type ) { - @SerialName("ephemeral") enum class Type { + @SerialName("ephemeral") EPHEMERAL } @@ -324,9 +391,9 @@ data class Usage( @SerialName("input_tokens") val inputTokens: Int, @SerialName("cache_creation_input_tokens") - val cacheCreationInputTokens: Int?, + val cacheCreationInputTokens: Int? = null, @SerialName("cache_read_input_tokens") - val cacheReadInputTokens: Int?, + val cacheReadInputTokens: Int? = null, @SerialName("output_tokens") val outputTokens: Int ) diff --git a/src/commonMain/kotlin/schema/JsonSchema.kt b/src/commonMain/kotlin/schema/JsonSchema.kt index 36ee19b..2c15f6b 100644 --- a/src/commonMain/kotlin/schema/JsonSchema.kt +++ b/src/commonMain/kotlin/schema/JsonSchema.kt @@ -18,6 +18,7 @@ data class JsonSchemaProperty( val type: String? = null, val items: JsonSchemaProperty? = null, val enum: List? = null, + @SerialName("\$ref") val ref: String? = null ) { diff --git a/src/commonMain/kotlin/schema/JsonSchemaGenerator.kt b/src/commonMain/kotlin/schema/JsonSchemaGenerator.kt index ef71a22..ad418cd 100644 --- a/src/commonMain/kotlin/schema/JsonSchemaGenerator.kt +++ b/src/commonMain/kotlin/schema/JsonSchemaGenerator.kt @@ -52,9 +52,12 @@ private fun generateSchemaProperty( ) StructureKind.MAP -> JsonSchemaProperty("object") StructureKind.CLASS -> { + // dots are not allowed in JSON Schema name, if the @SerialName was not + // specified, then fully qualified class name will be used, and we need + // to translate it val refName = descriptor.serialName.replace('.', '_').trimEnd('?') definitions[refName] = generateSchema(descriptor) - JsonSchemaProperty("\$ref", ref = "#/definitions/$refName") + JsonSchemaProperty(ref = "#/definitions/$refName") } else -> JsonSchemaProperty("object") // Default case } diff --git a/src/commonMain/kotlin/tool/Tools.kt b/src/commonMain/kotlin/tool/Tools.kt new file mode 100644 index 0000000..1ab4bed --- /dev/null +++ b/src/commonMain/kotlin/tool/Tools.kt @@ -0,0 +1,86 @@ +package com.xemantic.anthropic.tool + +import com.xemantic.anthropic.message.CacheControl +import com.xemantic.anthropic.message.Tool +import com.xemantic.anthropic.message.ToolResult +import com.xemantic.anthropic.schema.jsonSchemaOf +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.MetaSerializable +import kotlinx.serialization.SerializationException +import kotlinx.serialization.serializer + +/** + * Annotation used to mark a class extending the [UsableTool]. + * + * This annotation provides metadata for tools that can be serialized and used in the context + * of the Anthropic API. It includes a name and description for the tool. + * + * @property name The name of the tool. This name is used during serialization and should be a unique identifier for the tool. + * @property description A comprehensive description of what the tool does and how it should be used. + */ +@OptIn(ExperimentalSerializationApi::class) +@MetaSerializable +@Target(AnnotationTarget.CLASS) +annotation class AnthropicTool( + val name: String, + val description: String +) + +/** + * Interface for tools that can be used in the context of the Anthropic API. + * + * Classes implementing this interface represent tools that can be executed + * with a given tool use ID. The implementation of the [use] method should + * contain the logic for executing the tool and returning the [ToolResult]. + */ +interface UsableTool { + + /** + * Executes the tool and returns the result. + * + * @param toolUseId A unique identifier for this particular use of the tool. + * @return A [ToolResult] containing the outcome of executing the tool. + */ + fun use( + toolUseId: String + ): ToolResult + +} + +fun Tool.cacheControl( + cacheControl: CacheControl? = null +): Tool = if (cacheControl == null) this else Tool( + name, + description, + inputSchema, + cacheControl +) + +@OptIn(ExperimentalSerializationApi::class) +inline fun toolOf( + cacheControl: CacheControl? = null // TODO should it be here? +): Tool { + + val serializer = try { + serializer() + } catch (e :SerializationException) { + throw SerializationException( + "The class ${T::class.qualifiedName} must be annotated with @SerializableTool", e + ) + } + + val anthropicTool = serializer + .descriptor + .annotations + .filterIsInstance() + .firstOrNull() ?: throw SerializationException( + "The class ${T::class.qualifiedName} must be annotated with @SerializableTool", + ) + + return Tool( + name = anthropicTool.name, + description = anthropicTool.description, + inputSchema = jsonSchemaOf(), + cacheControl = cacheControl + ) +} diff --git a/src/commonTest/kotlin/AnthropicTest.kt b/src/commonTest/kotlin/AnthropicTest.kt index 605b231..b7f8f24 100644 --- a/src/commonTest/kotlin/AnthropicTest.kt +++ b/src/commonTest/kotlin/AnthropicTest.kt @@ -1,6 +1,6 @@ package com.xemantic.anthropic -import com.xemantic.anthropic.event.ContentBlockDelta +import com.xemantic.anthropic.event.ContentBlockDeltaEvent import com.xemantic.anthropic.event.Delta.TextDelta import com.xemantic.anthropic.message.Image import com.xemantic.anthropic.message.Message @@ -8,18 +8,23 @@ import com.xemantic.anthropic.message.MessageResponse import com.xemantic.anthropic.message.Role import com.xemantic.anthropic.message.StopReason import com.xemantic.anthropic.message.Text -import com.xemantic.anthropic.message.Tool -import com.xemantic.anthropic.message.ToolChoice import com.xemantic.anthropic.message.ToolUse +import com.xemantic.anthropic.test.Calculator +import com.xemantic.anthropic.test.DatabaseQueryTool +import com.xemantic.anthropic.test.FibonacciTool +import com.xemantic.anthropic.test.TestDatabase +import io.kotest.assertions.assertSoftly +import io.kotest.matchers.ints.shouldBeGreaterThan +import io.kotest.matchers.shouldBe +import io.kotest.matchers.shouldNotBe +import io.kotest.matchers.string.shouldContain +import io.kotest.matchers.string.shouldStartWith +import io.kotest.matchers.types.instanceOf import kotlinx.coroutines.flow.filterIsInstance import kotlinx.coroutines.flow.map import kotlinx.coroutines.flow.toList import kotlinx.coroutines.test.runTest -import kotlinx.serialization.Serializable import kotlin.test.Test -import kotlin.test.assertEquals -import kotlin.test.assertNull -import kotlin.test.assertTrue class AnthropicTest { @@ -33,23 +38,22 @@ class AnthropicTest { +Message { +"Hello World! What's your name?" } - model = "claude-3-opus-20240229" maxTokens = 1024 } // then - response.apply { - assertTrue(type == MessageResponse.Type.MESSAGE) - assertTrue(role == Role.ASSISTANT) - assertTrue(model == "claude-3-opus-20240229") - assertTrue(content.size == 1) - assertTrue(content[0] is Text) + assertSoftly(response) { + type shouldBe MessageResponse.Type.MESSAGE + role shouldBe Role.ASSISTANT + model shouldBe "claude-3-5-sonnet-20240620" + stopReason shouldBe StopReason.END_TURN + content.size shouldBe 1 + content[0] shouldBe instanceOf() val text = content[0] as Text - assertTrue(text.text.contains("Claude")) - assertTrue(stopReason == StopReason.END_TURN) - assertNull(stopSequence) - assertEquals(usage.inputTokens, 15) - assertTrue(usage.outputTokens > 0) + text.text shouldContain "Claude" + stopSequence shouldBe null + usage.inputTokens shouldBe 15 + usage.outputTokens shouldBeGreaterThan 0 } } @@ -72,11 +76,12 @@ class AnthropicTest { } // then - response.apply { - assertTrue(1 == content.size) - assertTrue(content[0] is Text) + assertSoftly(response) { + stopReason shouldBe StopReason.END_TURN + content.size shouldBe 1 + content[0] shouldBe instanceOf() val text = content[0] as Text - assertTrue(text.text.lowercase().contains("foo")) + text.text.uppercase() shouldContain "FOO" } } @@ -87,106 +92,152 @@ class AnthropicTest { // when val response = client.messages.stream { - +Message { - role = Role.USER - +"Say: 'The quick brown fox jumps over the lazy dog'" - } + +Message { +"Say: 'The sun slowly dipped below the horizon, painting the sky in a breathtaking array of oranges, pinks, and purples.'" } } - .filterIsInstance() + .filterIsInstance() .map { (it.delta as TextDelta).text } .toList() .joinToString(separator = "") // then - assertTrue(response == "The quick brown fox jumps over the lazy dog.") + response shouldBe "The sun slowly dipped below the horizon, painting the sky in a breathtaking array of oranges, pinks, and purples." } - // given - @Serializable - data class Calculator( - val operation: Operation, - val a: Double, - val b: Double - ) { + @Test + fun shouldUseCalculatorTool() = runTest { + // given + val client = Anthropic { + tool() + } + val conversation = mutableListOf() + conversation += Message { +"What's 15 multiplied by 7?" } + + // when + val initialResponse = client.messages.create { + messages = conversation + useTools() + } + conversation += initialResponse.asMessage() - @Suppress("unused") // it is used, but by Anthropic, so we skip the warning - enum class Operation( - val calculate: (a: Double, b: Double) -> Double - ) { - ADD({ a, b -> a + b }), - SUBTRACT({ a, b -> a - b }), - MULTIPLY({ a, b -> a * b }), - DIVIDE({ a, b -> a / b }) + // then + assertSoftly(initialResponse) { + stopReason shouldBe StopReason.TOOL_USE + content.size shouldBe 2 + content[0] shouldBe instanceOf() + content[1] shouldBe instanceOf() + (content[1] as ToolUse).name shouldBe "Calculator" } - fun calculate() = operation.calculate(a, b) + val toolUse = initialResponse.content[1] as ToolUse + val result = toolUse.use() // here we execute the tool + conversation += Message { +result } + + // when + val resultResponse = client.messages.create { + messages = conversation + useTools() + } + + // then + assertSoftly(resultResponse) { + stopReason shouldBe StopReason.END_TURN + content.size shouldBe 1 + content[0] shouldBe instanceOf() + (content[0] as Text).text shouldContain "105" + } } @Test - fun shouldUseCalculatorTool() = runTest { + fun shouldUseFibonacciTool() = runTest { // given - val client = Anthropic() - val calculatorTool = Tool( - description = "Perform basic arithmetic operations", - ) + val client = Anthropic { + tool() + } // when val response = client.messages.create { - +Message { - +"What's 15 multiplied by 7?" - } - tools = listOf(calculatorTool) - toolChoice = ToolChoice.Any() + +Message { +"What's fibonacci number 42" } + useTools() } // then - response.apply { - assertTrue(content.size == 1) - assertTrue(content[0] is ToolUse) - val toolUse = content[0] as ToolUse - assertTrue(toolUse.name == "com_xemantic_anthropic_AnthropicTest_Calculator") - val calculator = toolUse.input() - val result = calculator.calculate() - assertTrue(result == 15.0 * 7.0) + val toolUse = response.content.filterIsInstance().first() + toolUse.name shouldBe "FibonacciTool" + + val result = toolUse.use() + assertSoftly(result) { + toolUseId shouldBe toolUse.id + isError shouldBe false + content shouldBe listOf(Text(text = "267914296")) } } - @Serializable - data class Fibonacci(val n: Int) + @Test + fun shouldUse2ToolsInSequence() = runTest { + // given + val client = Anthropic { + tool() + tool() + } + + // when + val conversation = mutableListOf() + conversation += Message { +"Calculate Fibonacci number 42 and then divide it by 42" } + + val fibonacciResponse = client.messages.create { + messages = conversation + useTools() + } + conversation += fibonacciResponse.asMessage() - tailrec fun fibonacci( - n: Int, a: Int = 0, b: Int = 1 - ): Int = when (n) { - 0 -> a; 1 -> b; else -> fibonacci(n - 1, b, a + b) + val fibonacciToolUse = fibonacciResponse.content.filterIsInstance().first() + fibonacciToolUse.name shouldBe "FibonacciTool" + val fibonacciResult = fibonacciToolUse.use() + conversation += Message { +fibonacciResult } + + val calculatorResponse = client.messages.create { + messages = conversation + useTools() + } + conversation += calculatorResponse.asMessage() + + val calculatorToolUse = calculatorResponse.content.filterIsInstance().first() + calculatorToolUse.name shouldBe "Calculator" + val calculatorResult = calculatorToolUse.use() + conversation += Message { +calculatorResult } + + val finalResponse = client.messages.create { + messages = conversation + useTools() + } + + finalResponse.content[0] shouldBe instanceOf() + (finalResponse.content[0] as Text).text shouldContain "6,378,911.8" } @Test - fun shouldUseCalculatorToolForFibonacci() = runTest { + fun shouldUseToolWithDependencies() = runTest { // given - val client = Anthropic() - val fibonacciTool = Tool( - description = "Calculates fibonacci number of a given n", - ) + val testDatabase = TestDatabase() + val client = Anthropic { + tool { + database = testDatabase + } + } // when val response = client.messages.create { - +Message { +"What's fibonacci number 42" } - tools = listOf(fibonacciTool) - toolChoice = ToolChoice.Any() + +Message { +"List data in CUSTOMER table" } + useTool() } + val toolUse = response.content.filterIsInstance().first() + toolUse.use() // then - response.apply { - assertTrue(content.size == 1) - assertTrue(content[0] is ToolUse) - val toolUse = content[0] as ToolUse - assertTrue(toolUse.name == "com_xemantic_anthropic_AnthropicTest_Fibonacci") - val n = toolUse.input().n - assertTrue(n == 42) - val fibonacciNumber = fibonacci(n) // doing the job for Anthropic - assertTrue(fibonacciNumber == 267914296) - } + testDatabase.executedQuery shouldNotBe null + testDatabase.executedQuery!!.uppercase() shouldStartWith "SELECT * FROM CUSTOMER" + // depending on the response the statement might end up with semicolon, which we discard } } diff --git a/src/commonTest/kotlin/message/MessagesTest.kt b/src/commonTest/kotlin/message/MessageRequestTest.kt similarity index 59% rename from src/commonTest/kotlin/message/MessagesTest.kt rename to src/commonTest/kotlin/message/MessageRequestTest.kt index e5e9444..a9244f7 100644 --- a/src/commonTest/kotlin/message/MessagesTest.kt +++ b/src/commonTest/kotlin/message/MessageRequestTest.kt @@ -1,31 +1,29 @@ package com.xemantic.anthropic.message -import com.xemantic.anthropic.anthropicJson +import com.xemantic.anthropic.test.testJson import io.kotest.assertions.json.shouldEqualJson -import kotlinx.serialization.ExperimentalSerializationApi +import io.kotest.matchers.shouldBe import kotlinx.serialization.encodeToString -import kotlinx.serialization.json.Json import kotlin.test.Test /** - * Tests the JSON serialization format of created Anthropic API messages. + * Tests the JSON serialization format of created Anthropic API message requests. */ -class MessagesTest { +class MessageRequestTest { - /** - * A pretty JSON printing for testing. - */ - private val json = Json(from = anthropicJson) { - prettyPrint = true - @OptIn(ExperimentalSerializationApi::class) - prettyPrintIndent = " " + @Test + fun defaultMessageShouldHaveRoleUser() { + // given + val message = Message {} + // then + message.role shouldBe Role.USER } @Test fun shouldCreateTheSimplestMessageRequest() { // given val request = MessageRequest( - defaultModel = "claude-3-opus-20240229" + defaultModel = "claude-3-5-sonnet-20240620" ) { +Message { +"Hey Claude!?" @@ -33,12 +31,12 @@ class MessagesTest { } // when - val json = json.encodeToString(request) + val json = testJson.encodeToString(request) // then json shouldEqualJson """ { - "model": "claude-3-opus-20240229", + "model": "claude-3-5-sonnet-20240620", "messages": [ { "role": "user", diff --git a/src/commonTest/kotlin/message/MessageResponseTest.kt b/src/commonTest/kotlin/message/MessageResponseTest.kt new file mode 100644 index 0000000..d26db8d --- /dev/null +++ b/src/commonTest/kotlin/message/MessageResponseTest.kt @@ -0,0 +1,67 @@ +package com.xemantic.anthropic.message + +import com.xemantic.anthropic.test.testJson +import io.kotest.assertions.assertSoftly +import io.kotest.matchers.shouldBe +import io.kotest.matchers.types.instanceOf +import kotlin.test.Test + +/** + * Tests the JSON format of deserialized Anthropic API message responses. + */ +class MessageResponseTest { + + @Test + fun shouldDeserializeToolUseMessageResponse() { + // given + val jsonResponse = """ + { + "id": "msg_01PspkNzNG3nrf5upeTsmWLF", + "type": "message", + "role": "assistant", + "model": "claude-3-5-sonnet-20240620", + "content": [ + { + "type": "tool_use", + "id": "toolu_01YHJK38TBKCRPn7zfjxcKHx", + "name": "Calculator", + "input": { + "operation": "MULTIPLY", + "a": 15, + "b": 7 + } + } + ], + "stop_reason": "tool_use", + "stop_sequence": null, + "usage": { + "input_tokens": 419, + "output_tokens": 86 + } + } + """.trimIndent() + + val response = testJson.decodeFromString(jsonResponse) + assertSoftly(response) { + id shouldBe "msg_01PspkNzNG3nrf5upeTsmWLF" + type shouldBe MessageResponse.Type.MESSAGE + role shouldBe Role.ASSISTANT + model shouldBe "claude-3-5-sonnet-20240620" + content.size shouldBe 1 + content[0] shouldBe instanceOf() + stopReason shouldBe StopReason.TOOL_USE + stopSequence shouldBe null + usage shouldBe Usage( + inputTokens = 419, + outputTokens = 86 + ) + } + val toolUse = response.content[0] as ToolUse + assertSoftly(toolUse) { + id shouldBe "toolu_01YHJK38TBKCRPn7zfjxcKHx" + name shouldBe "Calculator" + // TODO generate JsonObject to assert input + } + } + +} diff --git a/src/commonTest/kotlin/schema/JsonSchemaGeneratorTest.kt b/src/commonTest/kotlin/schema/JsonSchemaGeneratorTest.kt index 4c824c4..411e56e 100644 --- a/src/commonTest/kotlin/schema/JsonSchemaGeneratorTest.kt +++ b/src/commonTest/kotlin/schema/JsonSchemaGeneratorTest.kt @@ -3,12 +3,14 @@ package com.xemantic.anthropic.schema import com.xemantic.anthropic.anthropicJson import io.kotest.assertions.json.shouldEqualJson import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable import kotlinx.serialization.encodeToString import kotlinx.serialization.json.Json import kotlin.test.Test @Serializable +@SerialName("address") data class Address( val street: String? = null, val city: String? = null, @@ -35,7 +37,6 @@ class JsonSchemaGeneratorTest { @Test fun generateJsonSchemaForAddress() { - // when val schema = jsonSchemaOf
() val schemaJson = json.encodeToString(schema) @@ -77,7 +78,7 @@ class JsonSchemaGeneratorTest { { "type": "object", "definitions": { - "com_xemantic_anthropic_schema_Address": { + "address": { "type": "object", "properties": { "street": { @@ -116,8 +117,7 @@ class JsonSchemaGeneratorTest { } }, "address": { - "type": "${'$'}ref", - "ref": "#/definitions/com_xemantic_anthropic_schema_Address" + "${'$'}ref": "#/definitions/address" } }, "required": [ diff --git a/src/commonTest/kotlin/test/AnthropicTestSupport.kt b/src/commonTest/kotlin/test/AnthropicTestSupport.kt new file mode 100644 index 0000000..0b693a8 --- /dev/null +++ b/src/commonTest/kotlin/test/AnthropicTestSupport.kt @@ -0,0 +1,16 @@ +package com.xemantic.anthropic.test + +import com.xemantic.anthropic.anthropicJson +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.json.Json + +/** + * A pretty JSON printing for testing. It's derived from [anthropicJson], + * therefore should use the same rules for serialization/deserialization, but + * it has `prettyPrint` and 2 space tab enabled in addition. + */ +val testJson = Json(from = anthropicJson) { + prettyPrint = true + @OptIn(ExperimentalSerializationApi::class) + prettyPrintIndent = " " +} diff --git a/src/commonTest/kotlin/test/AnthropicTestTools.kt b/src/commonTest/kotlin/test/AnthropicTestTools.kt new file mode 100644 index 0000000..c7b8b78 --- /dev/null +++ b/src/commonTest/kotlin/test/AnthropicTestTools.kt @@ -0,0 +1,85 @@ +package com.xemantic.anthropic.test + +import com.xemantic.anthropic.message.ToolResult +import com.xemantic.anthropic.tool.AnthropicTool +import com.xemantic.anthropic.tool.UsableTool +import kotlinx.serialization.Transient + +@AnthropicTool( + name = "FibonacciTool", + description = "Calculate Fibonacci number n" +) +data class FibonacciTool(val n: Int): UsableTool { + + tailrec fun fibonacci( + n: Int, a: Int = 0, b: Int = 1 + ): Int = when (n) { + 0 -> a; 1 -> b; else -> fibonacci(n - 1, b, a + b) + } + + override fun use( + toolUseId: String, + ) = ToolResult(toolUseId, "${fibonacci(n)}") + +} + +@AnthropicTool( + name = "Calculator", + description = "Calculates the arithmetic outcome of an operation when given the arguments a and b" +) +data class Calculator( + val operation: Operation, + val a: Double, + val b: Double +): UsableTool { + + @Suppress("unused") // it is used, but by Anthropic, so we skip the warning + enum class Operation( + val calculate: (a: Double, b: Double) -> Double + ) { + ADD({ a, b -> a + b }), + SUBTRACT({ a, b -> a - b }), + MULTIPLY({ a, b -> a * b }), + DIVIDE({ a, b -> a / b }) + } + + override fun use(toolUseId: String) = ToolResult( + toolUseId, + operation.calculate(a, b).toString() + ) + +} + +interface Database { + fun execute(query: String): List +} + +class TestDatabase : Database { + var executedQuery: String? = null + override fun execute( + query: String + ): List { + executedQuery = query + return listOf("foo", "bar", "buzz") + } +} + +@AnthropicTool( + name = "DatabaseQuery", + description = "Executes database query" +) +data class DatabaseQueryTool( + val query: String +) : UsableTool { + + @Transient + lateinit var database: Database + + override fun use( + toolUseId: String + ) = ToolResult( + toolUseId, + text = database.execute(query).joinToString() + ) + +} diff --git a/src/commonTest/kotlin/tool/UsableToolTest.kt b/src/commonTest/kotlin/tool/UsableToolTest.kt new file mode 100644 index 0000000..516e77f --- /dev/null +++ b/src/commonTest/kotlin/tool/UsableToolTest.kt @@ -0,0 +1,87 @@ +package com.xemantic.anthropic.tool + +import com.xemantic.anthropic.message.CacheControl +import com.xemantic.anthropic.message.ToolResult +import com.xemantic.anthropic.schema.JsonSchema +import com.xemantic.anthropic.schema.JsonSchemaProperty +import io.kotest.assertions.assertSoftly +import io.kotest.assertions.throwables.shouldThrowWithMessage +import io.kotest.matchers.shouldBe +import kotlinx.serialization.Serializable +import kotlinx.serialization.SerializationException +import kotlin.test.Test + +class UsableToolTest { + + @AnthropicTool( + name = "TestTool", + description = "Test tool receiving a message and outputting it back" + ) + class TestTool( + val message: String + ) : UsableTool { + override fun use(toolUseId: String) = ToolResult(toolUseId, message) + } + + @Test + fun shouldCreateToolFromUsableToolAnnotatedWithAnthropicTool() { + // when + val tool = toolOf() + + assertSoftly(tool) { + name shouldBe "TestTool" + description shouldBe "Test tool receiving a message and outputting it back" + inputSchema shouldBe JsonSchema( + properties = mapOf("message" to JsonSchemaProperty.STRING), + required = listOf("message") + ) + cacheControl shouldBe null + } + } + + @Test + fun shouldCreateToolWithCacheControlFromUsableTool() { + // when + val tool = toolOf( + cacheControl = CacheControl(type = CacheControl.Type.EPHEMERAL) + ) + + assertSoftly(tool) { + name shouldBe "TestTool" + description shouldBe "Test tool receiving a message and outputting it back" + inputSchema shouldBe JsonSchema( + properties = mapOf("message" to JsonSchemaProperty.STRING), + required = listOf("message") + ) + cacheControl shouldBe CacheControl(type = CacheControl.Type.EPHEMERAL) + } + } + + class NoAnnotationTool : UsableTool { + override fun use(toolUseId: String) = ToolResult(toolUseId, "nothing") + } + + @Test + fun shouldFailToCreateToolWithoutAnthropicToolAnnotation() { + shouldThrowWithMessage( + "The class com.xemantic.anthropic.tool.UsableToolTest.NoAnnotationTool must be annotated with @SerializableTool" + ) { + toolOf() + } + } + + @Serializable + class OnlySerializableAnnotationTool : UsableTool { + override fun use(toolUseId: String) = ToolResult(toolUseId, "nothing") + } + + @Test + fun shouldFailToCreateToolWithOnlySerializableAnnotation() { + shouldThrowWithMessage( + "The class com.xemantic.anthropic.tool.UsableToolTest.OnlySerializableAnnotationTool must be annotated with @SerializableTool" + ) { + toolOf() + } + } + +} diff --git a/src/jvmMain/kotlin/JvmAnthropic.kt b/src/jvmMain/kotlin/JvmAnthropic.kt index 84f8a11..036434c 100644 --- a/src/jvmMain/kotlin/JvmAnthropic.kt +++ b/src/jvmMain/kotlin/JvmAnthropic.kt @@ -1,7 +1,40 @@ package com.xemantic.anthropic +import com.xemantic.anthropic.message.MessageRequest +import com.xemantic.anthropic.message.MessageResponse +import kotlinx.coroutines.runBlocking +import java.util.function.Consumer + actual val envApiKey: String? get() = System.getenv("ANTHROPIC_API_KEY") actual val missingApiKeyMessage: String get() = "apiKey is missing, it has to be provided as a parameter or as an ANTHROPIC_API_KEY environment variable." + +// a very early version of Java only SDK, adapting Kotlin idioms and coroutines +// it might change a lot in the future +class JavaAnthropic { + + companion object { + + @JvmStatic + fun create(): Anthropic = Anthropic() + + @JvmStatic + fun create( + configurer: Consumer + ): Anthropic { + return Anthropic { configurer.accept(this) } + } + + } + +} + +fun Anthropic.createMessage( + builder: Consumer +): MessageResponse = runBlocking { + messages.create { + builder.accept(this) + } +}