diff --git a/README.md b/README.md
index 0568979..6c7340a 100644
--- a/README.md
+++ b/README.md
@@ -3,42 +3,44 @@
Unofficial Kotlin multiplatform variant of the
[Antropic SDK](https://docs.anthropic.com/en/api/client-sdks).
-[](https://central.sonatype.com/namespace/com.xemantic.anthropic)
-[](https://github.com/xemantic/anthropic-sdk-kotlin/releases)
-[](https://github.com/xemantic/anthropic-sdk-kotlin/blob/main/LICENSE)
-
-[](https://github.com/xemantic/anthropic-sdk-kotlin/actions/workflows/build-main.yml)
-[](https://github.com/xemantic/anthropic-sdk-kotlin/actions/workflows/build-main.yml)
-[](https://github.com/xemantic/anthropic-sdk-kotlin/commits/main/)
-[](https://github.com/xemantic/anthropic-sdk-kotlin/commits/main/)
-
-[](https://github.com/xemantic/anthropic-sdk-kotlin/graphs/contributors)
-[](https://github.com/xemantic/anthropic-sdk-kotlin/commits/main/)
-[]()
-[](https://github.com/xemantic/anthropic-sdk-kotlin/commit/39c1fa4c138d4c671868c973e2ad37b262ae03c2)
-[](https://kotlinlang.org/docs/releases.html)
-[](https://ktor.io/)
-
-[](https://discord.gg/vQktqqN2Vn)
-[](https://discord.gg/vQktqqN2Vn)
-[](https://x.com/KazikPogoda)
+[](https://central.sonatype.com/namespace/com.xemantic.anthropic)
+[](https://github.com/xemantic/anthropic-sdk-kotlin/releases)
+[](https://github.com/xemantic/anthropic-sdk-kotlin/blob/main/LICENSE)
+
+[](https://github.com/xemantic/anthropic-sdk-kotlin/actions/workflows/build-main.yml)
+[](https://github.com/xemantic/anthropic-sdk-kotlin/actions/workflows/build-main.yml)
+[](https://github.com/xemantic/anthropic-sdk-kotlin/commits/main/)
+[](https://github.com/xemantic/anthropic-sdk-kotlin/commits/main/)
+
+[](https://github.com/xemantic/anthropic-sdk-kotlin/graphs/contributors)
+[](https://github.com/xemantic/anthropic-sdk-kotlin/commits/main/)
+[]()
+[](https://github.com/xemantic/anthropic-sdk-kotlin/commit/39c1fa4c138d4c671868c973e2ad37b262ae03c2)
+[](https://kotlinlang.org/docs/releases.html)
+[](https://ktor.io/)
+
+[](https://discord.gg/vQktqqN2Vn)
+[](https://discord.gg/vQktqqN2Vn)
+[](https://x.com/KazikPogoda)
## Why?
-I like Kotlin. I like even more the multiplatform aspect of pure Kotlin - that a library code written once
-can be utilized as a:
-
-* regular Java library to be used on backend, desktop, Android, etc.
-* Kotlin library to be used on backend, desktop, Android.
-* executable native binary (e.g. a command line tool)
-* Kotlin app transpiled to JavaScript
-* Kotlin app compiled to WebAssembly
-* JavaScript library
-* TypeScript library
-* native library, working also with Swift/iOS
-
-Having Kotlin multiplatform library for the Anthropic APIs allows
-me to write AI code once, and target all the platforms automatically.
+Because I believe that coding Agentic AI should be as easy as possible. I am coming from the
+[creative coding community](https://creativecode.berlin/), where
+we are teaching artists, without prior programming experience, how to express their creations through
+code as a medium. I want to give creators of all kinds this extremely powerful tool, so that
+**you can turn your own machine into an outside window, through which, the AI system can perceive
+your world and your needs, and act upon this information.**
+
+There is no official Anthropic SDK for Kotlin, a de facto standard for Android development. The one for Java
+is also lacking. Even if they will appear one day, we can expect them to be autogenerated by the
+[Stainless API bot](https://www.stainlessapi.com/), which is used by both, Anthropic and OpenAI, to automate
+their SDK development based on evolving API. While such an approach seem to work with dynamically typed languages,
+it might fail short with statically typed languages like Kotlin, sacrificing typical language idioms in favor
+of [over-verbose constructs](https://github.com/anthropics/anthropic-sdk-go/blob/main/examples/tools/main.go).
+This library is a [Kotlin multiplatform](https://kotlinlang.org/docs/multiplatform.html)
+therefore your AI agents developed with it can be seamlessly used in Android, JVM, JavaScript, iOS, WebAssembly,
+and many other environments.
## Usage
@@ -50,9 +52,7 @@ Add to your `build.gradle.kts`:
```kotlin
dependencies {
implementation("com.xemantic.anthropic:anthropic-sdk-kotlin:.0.2.2")
- // for a JVM project, the client engine will differ per platform
- // check ktor doucmentation for details
- implementation("io.ktor:ktor-client-java:3.0.0-rc-1")
+ implementation("io.ktor:ktor-client-java:3.0.0") // or the latest ktor version
}
```
@@ -60,11 +60,11 @@ The simplest code look like:
```kotlin
fun main() {
- val client = Anthropic()
+ val anthropic = Anthropic()
val response = runBlocking {
- client.messages.create {
+ anthropic.messages.create {
+Message {
- +"Hello World!"
+ +"Hello, Claude"
}
}
}
@@ -72,77 +72,136 @@ fun main() {
}
```
+### Response streaming
+
Streaming is also possible:
```kotlin
fun main() {
val client = Anthropic()
- val pong = runBlocking {
+ runBlocking {
client.messages.stream {
- +Message {
- role = Role.USER
- +"ping!"
- }
+ +Message { +"Write me a poem." }
}
- .filterIsInstance()
- .map { (it.delta as TextDelta).text }
- .toList()
- .joinToString(separator = "")
+ .filterIsInstance()
+ .map { (it.delta as Delta.TextDelta).text }
+ .collect { delta -> println(delta) }
}
- println(pong)
}
```
-It can also use tools:
+### Using tools
+
+If you want to write AI agents, you need tools, and this is where this library shines:
```kotlin
-@Serializable
-data class Calculator(
- val operation: Operation,
- val a: Double,
- val b: Double
-) {
-
- @Suppress("unused") // will be used by Anthropic :)
- enum class Operation(
- val calculate: (a: Double, b: Double) -> Double
- ) {
- ADD({ a, b -> a + b }),
- SUBTRACT({ a, b -> a - b }),
- MULTIPLY({ a, b -> a * b }),
- DIVIDE({ a, b -> a / b })
+@AnthropicTool(
+ name = "get_weather",
+ description = "Get the weather for a specific location"
+)
+data class WeatherTool(val location: String): UsableTool {
+ override fun use(
+ toolUseId: String
+ ) = ToolResult(
+ toolUseId,
+ "The weather is 73f" // it should use some external service
+ )
+}
+
+fun main() = runBlocking {
+
+ val client = Anthropic {
+ tool()
+ }
+
+ val conversation = mutableListOf()
+ conversation += Message { +"What is the weather in SF?" }
+
+ val initialResponse = client.messages.create {
+ messages = conversation
+ useTools()
}
+ println("Initial response:")
+ println(initialResponse)
- fun calculate() = operation.calculate(a, b)
+ conversation += initialResponse.asMessage()
+ val tool = initialResponse.content.filterIsInstance().first()
+ val toolResult = tool.use()
+ conversation += Message { +toolResult }
+ val finalResponse = client.messages.create {
+ messages = conversation
+ useTools()
+ }
+ println("Final response:")
+ println(finalResponse)
}
+```
-fun main() {
- val client = Anthropic()
+The advantage comes no only from reduced verbosity, but also the class annotated with
+the `@AnthropicTool` will have its JSON schema automatically sent to the Anthropic API when
+defining the tool to use. For the reference check equivalent examples in the official
+Anthropic SDKs:
- val calculatorTool = Tool(
- description = "Perform basic arithmetic operations"
- )
+* [TypeScript](https://github.com/anthropics/anthropic-sdk-typescript/blob/main/examples/tools.ts)
+* [Python](https://github.com/anthropics/anthropic-sdk-python/blob/main/examples/tools.py)
+* [Go](https://github.com/anthropics/anthropic-sdk-go/blob/main/examples/tools/main.go)
- val response = runBlocking {
- client.messages.create {
- +Message {
- +"What's 15 multiplied by 7?"
+None of them is taking the advantage of automatic schema generation, which becomes crucial
+for maintaining agents expecting more complex and structured input from the LLM.
+
+### Injecting dependencies to tools
+
+Tools can be provided with dependencies, for example singleton
+services providing some facilities, like HTTP client to connect to the
+internet or DB connection pool to access the database.
+
+```kotlin
+@AnthropicTool(
+ name = "query_database",
+ description = "Executes SQL on the database"
+)
+data class DatabaseQueryTool(val sql: String): UsableTool {
+
+ internal lateinit var connection: Connection
+
+ override fun use(
+ toolUseId: String
+ ) = ToolResult(
+ toolUseId,
+ text = connection.prepareStatement(sql).use { statement ->
+ statement.resultSet.use { resultSet ->
+ resultSet.toString()
}
- tools = listOf(calculatorTool)
- toolChoice = ToolChoice.Any()
}
+ )
+
+}
+
+fun main() = runBlocking {
+
+ val client = Anthropic {
+ tool {
+ connection = DriverManager.getConnection("jdbc:...")
+ }
+ }
+
+ val response = client.messages.create {
+ +Message { +"Select all the users who never logged in to the the system" }
+ useTools()
}
- val toolUse = response.content[0] as ToolUse
- val calculator = toolUse.input()
- val result = calculator.calculate() // we are doing the job for LLM here
- println(result)
+ val tool = response.content.filterIsInstance().first()
+ val toolResult = tool.use()
+ println(toolResult)
}
```
-More sophisticated code examples targeting various
-platforms will follow in the
+After the `DatabaseQueryTool` is decoded from the API response, it can be processed
+by the lambda function passed to the tool definition. In case of the example above,
+the lambda will inject a JDBC connection to the tool.
+
+More sophisticated code examples targeting various Kotlin platforms can be found in the
[anthropic-sdk-kotlin-demo](https://github.com/xemantic/anthropic-sdk-kotlin-demo)
project.
diff --git a/build.gradle.kts b/build.gradle.kts
index 2d87906..bb2aaf5 100644
--- a/build.gradle.kts
+++ b/build.gradle.kts
@@ -6,7 +6,7 @@ import org.gradle.api.tasks.testing.logging.TestLogEvent
import org.jetbrains.kotlin.gradle.ExperimentalKotlinGradlePluginApi
import org.jetbrains.kotlin.gradle.dsl.JvmTarget
import org.jetbrains.kotlin.gradle.dsl.KotlinVersion
-import org.jetbrains.kotlin.gradle.tasks.KotlinJvmCompile
+import org.jetbrains.kotlin.gradle.targets.native.tasks.KotlinNativeTest
plugins {
alias(libs.plugins.kotlin.multiplatform)
@@ -32,6 +32,10 @@ val signingPassword: String? by project
val sonatypeUser: String? by project
val sonatypePassword: String? by project
+// we don't want to risk that a flaky test will crash the release build
+// and everything should be tested anyway after merging to the main branch
+val skipTests = isReleaseBuild
+
println("""
Project: ${project.name}
Version: ${project.version}
@@ -40,12 +44,25 @@ println("""
)
repositories {
- mavenCentral()
+ mavenCentral()
}
kotlin {
- jvm {}
+ //explicitApi() // check with serialization?
+ jvm {
+ testRuns["test"].executionTask.configure {
+ useJUnitPlatform()
+ }
+ // set up according to https://jakewharton.com/gradle-toolchains-are-rarely-a-good-idea/
+ compilerOptions {
+ apiVersion = kotlinTarget
+ languageVersion = kotlinTarget
+ jvmTarget = JvmTarget.fromTarget(javaTarget)
+ freeCompilerArgs.add("-Xjdk-release=$javaTarget")
+ progressiveMode = true
+ }
+ }
linuxX64()
@@ -64,6 +81,7 @@ kotlin {
dependencies {
implementation(libs.kotlin.test)
implementation(libs.kotlinx.coroutines.test)
+ implementation(libs.kotest.assertions.core)
implementation(libs.kotest.assertions.json)
}
}
@@ -89,17 +107,6 @@ kotlin {
}
-// set up according to https://jakewharton.com/gradle-toolchains-are-rarely-a-good-idea/
-tasks.withType {
- compilerOptions {
- apiVersion = kotlinTarget
- languageVersion = kotlinTarget
- jvmTarget = JvmTarget.fromTarget(javaTarget)
- freeCompilerArgs.add("-Xjdk-release=$javaTarget")
- progressiveMode = true
- }
-}
-
fun isNonStable(version: String): Boolean {
val stableKeyword = listOf("RELEASE", "FINAL", "GA").any { version.uppercase().contains(it) }
val regex = "^[0-9,.v-]+(-r)?$".toRegex()
@@ -113,7 +120,7 @@ tasks.withType {
}
}
-tasks.withType() {
+tasks.withType {
testLogging {
events(
TestLogEvent.PASSED,
@@ -123,15 +130,16 @@ tasks.withType() {
showStackTraces = true
exceptionFormat = TestExceptionFormat.FULL
}
+ enabled = !skipTests
+}
+
+tasks.withType {
+ enabled = !skipTests
}
-@Suppress("OPT_IN_USAGE")
powerAssert {
functions = listOf(
- "kotlin.assert",
- "kotlin.test.assertTrue",
- "kotlin.test.assertEquals",
- "kotlin.test.assertNull"
+ "io.kotest.matchers.shouldBe"
)
includedSourceSets = listOf("commonTest", "jvmTest", "nativeTest")
}
diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml
index 173847e..015165c 100644
--- a/gradle/libs.versions.toml
+++ b/gradle/libs.versions.toml
@@ -2,11 +2,13 @@
kotlinTarget = "2.0"
javaTarget = "17"
-kotlin = "2.0.20"
+kotlin = "2.0.21"
kotlinxCoroutines = "1.9.0"
-ktor = "3.0.0-rc-2"
-kotest = "5.9.1"
+ktor = "3.0.0"
+kotest = "6.0.0.M1"
+# logging is not used at the moment, might be enabled later
+#kotlinLogging = "7.0.0"
log4j = "2.24.1"
jackson = "2.18.0"
@@ -18,10 +20,12 @@ publishPlugin = "2.0.0"
kotlin-test = { module = "org.jetbrains.kotlin:kotlin-test", version.ref = "kotlin" }
kotlinx-coroutines-test = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-test", version.ref = "kotlinxCoroutines" }
-log4j-slf4j2 = { group = "org.apache.logging.log4j", name = "log4j-slf4j2-impl", version.ref = "log4j" }
-log4j-core = { group = "org.apache.logging.log4j", name = "log4j-core", version.ref = "log4j" }
-jackson-databind = { group = "com.fasterxml.jackson.core", name = "jackson-databind", version.ref = "jackson" }
-jackson-dataformat-yaml = { group = "com.fasterxml.jackson.dataformat", name = "jackson-dataformat-yaml", version.ref = "jackson" }
+# logging libs
+#kotlin-logging = { module = "io.github.oshai:kotlin-logging", version.ref = "kotlinLogging" }
+log4j-slf4j2 = { module = "org.apache.logging.log4j:log4j-slf4j2-impl", version.ref = "log4j" }
+log4j-core = { module = "org.apache.logging.log4j:log4j-core", version.ref = "log4j" }
+jackson-databind = { module = "com.fasterxml.jackson.core:jackson-databind", version.ref = "jackson" }
+jackson-dataformat-yaml = { module = "com.fasterxml.jackson.dataformat:jackson-dataformat-yaml", version.ref = "jackson" }
ktor-client-core = { module = "io.ktor:ktor-client-core", version.ref = "ktor" }
ktor-client-content-negotiation = { module = "io.ktor:ktor-client-content-negotiation", version.ref = "ktor" }
@@ -30,6 +34,7 @@ ktor-serialization-kotlinx-json = { module = "io.ktor:ktor-serialization-kotlinx
ktor-client-java = { module = "io.ktor:ktor-client-java", version.ref = "ktor" }
ktor-client-curl = { module = "io.ktor:ktor-client-curl", version.ref = "ktor" }
+kotest-assertions-core = { module = "io.kotest:kotest-assertions-core", version.ref = "kotest" }
kotest-assertions-json = { module = "io.kotest:kotest-assertions-json", version.ref = "kotest" }
[plugins]
diff --git a/src/commonMain/kotlin/Anthropic.kt b/src/commonMain/kotlin/Anthropic.kt
index b38f0a4..72d0834 100644
--- a/src/commonMain/kotlin/Anthropic.kt
+++ b/src/commonMain/kotlin/Anthropic.kt
@@ -5,10 +5,14 @@ import com.xemantic.anthropic.message.Error
import com.xemantic.anthropic.message.ErrorResponse
import com.xemantic.anthropic.message.MessageRequest
import com.xemantic.anthropic.message.MessageResponse
+import com.xemantic.anthropic.message.Tool
+import com.xemantic.anthropic.message.ToolUse
+import com.xemantic.anthropic.tool.UsableTool
+import com.xemantic.anthropic.tool.toolOf
import io.ktor.client.HttpClient
import io.ktor.client.call.body
+import io.ktor.client.plugins.*
import io.ktor.client.plugins.contentnegotiation.ContentNegotiation
-import io.ktor.client.plugins.defaultRequest
import io.ktor.client.plugins.logging.LogLevel
import io.ktor.client.plugins.logging.Logging
import io.ktor.client.plugins.sse.SSE
@@ -26,12 +30,25 @@ import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.filter
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.map
+import kotlinx.serialization.KSerializer
import kotlinx.serialization.json.Json
+import kotlinx.serialization.serializer
+import kotlin.reflect.KType
+import kotlin.reflect.typeOf
+/**
+ * The default Anthropic API base.
+ */
const val ANTHROPIC_API_BASE: String = "https://api.anthropic.com/"
+/**
+ * The default version to be passed to the `anthropic-version` HTTP header of each API request.
+ */
const val DEFAULT_ANTHROPIC_VERSION: String = "2023-06-01"
+/**
+ * An exception thrown when API requests returns error.
+ */
class AnthropicException(
error: Error,
httpStatusCode: HttpStatusCode
@@ -41,19 +58,27 @@ expect val envApiKey: String?
expect val missingApiKeyMessage: String
+/**
+ * A JSON format suitable for communication with Anthropic API.
+ */
val anthropicJson: Json = Json {
allowSpecialFloatingPointValues = true
explicitNulls = false
encodeDefaults = true
}
+/**
+ * The public constructor function which for the Anthropic API client.
+ *
+ * @param block the config block to set up the API access.
+ */
fun Anthropic(
block: Anthropic.Config.() -> Unit = {}
): Anthropic {
val config = Anthropic.Config().apply(block)
val apiKey = if (config.apiKey != null) config.apiKey else envApiKey
requireNotNull(apiKey) { missingApiKeyMessage }
- val defaultModel = if (config.defaultModel != null) config.defaultModel!! else "claude-3-opus-20240229"
+ val defaultModel = if (config.defaultModel != null) config.defaultModel!! else "claude-3-5-sonnet-20240620"
return Anthropic(
apiKey = apiKey,
anthropicVersion = config.anthropicVersion,
@@ -61,8 +86,10 @@ fun Anthropic(
apiBase = config.apiBase,
defaultModel = defaultModel,
directBrowserAccess = config.directBrowserAccess
- )
-}
+ ).apply {
+ toolEntryMap = (config.usableTools as List>).associateBy { it.tool.name }
+ }
+} // TODO this can be a second constructor, then toolMap can be private
class Anthropic internal constructor(
val apiKey: String,
@@ -80,8 +107,28 @@ class Anthropic internal constructor(
var apiBase: String = ANTHROPIC_API_BASE
var defaultModel: String? = null
var directBrowserAccess: Boolean = false
+ @PublishedApi
+ internal var usableTools: List> = emptyList()
+
+ inline fun tool(
+ noinline block: T.() -> Unit = {}
+ ) {
+ val entry = ToolEntry(typeOf(), toolOf(), serializer(), block)
+ usableTools += entry
+ }
+
}
+ @PublishedApi
+ internal class ToolEntry(
+ val type: KType,
+ val tool: Tool, // TODO, no cache control
+ val serializer: KSerializer,
+ val initialize: T.() -> Unit = {}
+ )
+
+ internal var toolEntryMap = mapOf>()
+
private val client = HttpClient {
install(ContentNegotiation) {
json(anthropicJson)
@@ -90,6 +137,14 @@ class Anthropic internal constructor(
install(Logging) {
level = LogLevel.BODY
}
+ install(HttpRequestRetry) {
+ retryOnServerErrors(maxRetries = 5)
+ exponentialDelay()
+ maxRetries = 5
+ retryIf { _, response ->
+ response.status == HttpStatusCode.TooManyRequests
+ }
+ }
defaultRequest {
url(apiBase)
header("x-api-key", apiKey)
@@ -103,20 +158,29 @@ class Anthropic internal constructor(
}
}
- inner class Messages() {
+ inner class Messages {
suspend fun create(
block: MessageRequest.Builder.() -> Unit
): MessageResponse {
+
val request = MessageRequest.Builder(
- defaultModel
+ defaultModel,
+ toolEntryMap = toolEntryMap
).apply(block).build()
+
val response = client.post("/v1/messages") {
contentType(ContentType.Application.Json)
setBody(request)
}
if (response.status.isSuccess()) {
- return response.body()
+ return response.body().apply {
+ content.filterIsInstance()
+ .forEach { toolUse ->
+ val entry = toolEntryMap[toolUse.name]!!
+ toolUse.toolEntry = entry
+ }
+ }
} else {
throw AnthropicException(
error = response.body().error,
@@ -129,7 +193,10 @@ class Anthropic internal constructor(
block: MessageRequest.Builder.() -> Unit
): Flow = flow {
- val request = MessageRequest.Builder(defaultModel).apply {
+ val request = MessageRequest.Builder(
+ defaultModel,
+ toolEntryMap = toolEntryMap
+ ).apply {
block(this)
stream = true
}.build()
@@ -157,5 +224,3 @@ class Anthropic internal constructor(
}
-inline fun anthropicTypeOf(): String =
- T::class.qualifiedName!!.replace('.', '_')
diff --git a/src/commonMain/kotlin/event/Events.kt b/src/commonMain/kotlin/event/Events.kt
index b9d3da8..7c5fa32 100644
--- a/src/commonMain/kotlin/event/Events.kt
+++ b/src/commonMain/kotlin/event/Events.kt
@@ -7,6 +7,8 @@ import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.JsonClassDiscriminator
+// reference https://docs.spring.io/spring-ai/reference/_images/anthropic-claude3-events-model.jpg
+
@Serializable
@JsonClassDiscriminator("type")
@OptIn(ExperimentalSerializationApi::class)
@@ -14,14 +16,14 @@ sealed class Event
@Serializable
@SerialName("message_start")
-data class MessageStart(
+data class MessageStartEvent(
val message: MessageResponse
) : Event()
@Serializable
@SerialName("message_delta")
-data class MessageDelta(
- val delta: MessageDelta.Delta,
+data class MessageDeltaEvent(
+ val delta: Delta,
val usage: Usage
) : Event() {
@@ -37,13 +39,15 @@ data class MessageDelta(
@Serializable
@SerialName("message_stop")
-class MessageStop : Event() {
+class MessageStopEvent : Event() {
override fun toString(): String = "MessageStop"
}
+// TODO error event is missing, should we rename all of these to events?
+
@Serializable
@SerialName("content_block_start")
-data class ContentBlockStart(
+data class ContentBlockStartEvent(
val index: Int,
@SerialName("content_block")
val contentBlock: ContentBlock
@@ -51,7 +55,7 @@ data class ContentBlockStart(
@Serializable
@SerialName("content_block_stop")
-data class ContentBlockStop(
+data class ContentBlockStopEvent(
val index: Int
) : Event()
@@ -66,17 +70,23 @@ sealed class ContentBlock {
val text: String
) : ContentBlock()
+ @Serializable
+ @SerialName("tool_use")
+ class ToolUse(
+ val text: String // TODO tool_id
+ ) : ContentBlock()
+ // TODO missing tool_use
}
@Serializable
@SerialName("ping")
-class Ping: Event() {
+class PingEvent: Event() {
override fun toString(): String = "Ping"
}
@Serializable
@SerialName("content_block_delta")
-data class ContentBlockDelta(
+data class ContentBlockDeltaEvent(
val index: Int,
val delta: Delta
) : Event()
diff --git a/src/commonMain/kotlin/message/Messages.kt b/src/commonMain/kotlin/message/Messages.kt
index 6044cfb..defa70c 100644
--- a/src/commonMain/kotlin/message/Messages.kt
+++ b/src/commonMain/kotlin/message/Messages.kt
@@ -1,16 +1,14 @@
package com.xemantic.anthropic.message
+import com.xemantic.anthropic.Anthropic
import com.xemantic.anthropic.anthropicJson
-import com.xemantic.anthropic.anthropicTypeOf
import com.xemantic.anthropic.schema.JsonSchema
-import com.xemantic.anthropic.schema.jsonSchemaOf
-import kotlinx.serialization.ExperimentalSerializationApi
-import kotlinx.serialization.SerialName
-import kotlinx.serialization.Serializable
+import com.xemantic.anthropic.tool.UsableTool
+import kotlinx.serialization.*
import kotlinx.serialization.json.JsonClassDiscriminator
import kotlinx.serialization.json.JsonObject
-import kotlinx.serialization.json.decodeFromJsonElement
import kotlin.collections.mutableListOf
+import kotlin.reflect.typeOf
enum class Role {
@SerialName("user")
@@ -35,32 +33,54 @@ data class MessageRequest(
@SerialName("stop_sequences")
val stopSequences: List?,
val stream: Boolean?,
- val system: List?,
+ val system: List?,
val temperature: Double?,
@SerialName("tool_choice")
val toolChoice: ToolChoice?,
val tools: List?,
+ @SerialName("top_k")
val topK: Int?,
+ @SerialName("top_p")
val topP: Int?
) {
- class Builder(
- val defaultApiModel: String
+ class Builder internal constructor(
+ private val defaultModel: String,
+ @PublishedApi
+ internal val toolEntryMap: Map>
) {
var model: String? = null
var maxTokens = 1024
- val messages = mutableListOf()
+ var messages: List = mutableListOf()
var metadata = null
val stopSequences = mutableListOf()
var stream: Boolean? = null
internal set
- val systemTexts = mutableListOf()
+ var system: List? = null
var temperature: Double? = null
var toolChoice: ToolChoice? = null
var tools: List? = null
val topK: Int? = null
val topP: Int? = null
+ fun useTools() {
+ tools = toolEntryMap.values.map { it.tool }
+ }
+
+ /**
+ * Sets both, the [tools] list and the [toolChoice] with
+ * just one tool to use, forcing the API to respond with the [ToolUse].
+ */
+ inline fun useTool() {
+ val type = typeOf()
+ val toolEntry = toolEntryMap.values.find { it.type == type }
+ requireNotNull(toolEntry) {
+ "No such tool defined in Anthropic client: ${T::class.qualifiedName}"
+ }
+ tools = listOf(toolEntry.tool)
+ toolChoice = ToolChoice.Tool(name = toolEntry.tool.name)
+ }
+
fun messages(vararg messages: Message) {
this.messages += messages.toList()
}
@@ -77,23 +97,20 @@ data class MessageRequest(
this.stopSequences += stopSequences.toList()
}
- var system: String?
- get() = if (systemTexts.isEmpty()) null else systemTexts[0].text
- set(value) {
- systemTexts.clear()
- if (value != null) {
- systemTexts.add(Text(text = value))
- }
- }
+ fun system(
+ text: String
+ ) {
+ system = listOf(System(text = text))
+ }
fun build(): MessageRequest = MessageRequest(
- model = if (model != null) model!! else defaultApiModel,
+ model = if (model != null) model!! else defaultModel,
maxTokens = maxTokens,
messages = messages,
metadata = metadata,
stopSequences = stopSequences.toNullIfEmpty(),
stream = if (stream != null) stream else null,
- system = systemTexts.toNullIfEmpty(),
+ system = system,
temperature = temperature,
toolChoice = toolChoice,
tools = tools,
@@ -108,7 +125,9 @@ fun MessageRequest(
defaultModel: String,
block: MessageRequest.Builder.() -> Unit
): MessageRequest {
- val builder = MessageRequest.Builder(defaultModel)
+ val builder = MessageRequest.Builder(
+ defaultModel, emptyMap()
+ )
block(builder)
return builder.build()
}
@@ -131,8 +150,15 @@ data class MessageResponse(
@SerialName("message")
MESSAGE
}
+
+ fun asMessage(): Message = Message {
+ role = Role.ASSISTANT
+ content += this@MessageResponse.content
+ }
+
}
+
@Serializable
data class ErrorResponse(
val type: String,
@@ -180,30 +206,37 @@ fun Message(block: Message.Builder.() -> Unit): Message {
return builder.build()
}
+@Serializable
+data class System(
+ @SerialName("cache_control")
+ val cacheControl: CacheControl? = null,
+ val type: Type = Type.TEXT,
+ val text: String? = null,
+) {
+
+ enum class Type {
+ @SerialName("text")
+ TEXT
+ }
+
+}
+
@Serializable
data class Tool(
val name: String,
val description: String,
@SerialName("input_schema")
val inputSchema: JsonSchema,
+ @SerialName("cache_control")
val cacheControl: CacheControl?
)
-inline fun Tool(
- description: String,
- cacheControl: CacheControl? = null
-): Tool = Tool(
- name = anthropicTypeOf(),
- description = description,
- inputSchema = jsonSchemaOf(),
- cacheControl = cacheControl
-)
-
@Serializable
@JsonClassDiscriminator("type")
@OptIn(ExperimentalSerializationApi::class)
sealed class Content {
+ @SerialName("cache_control")
abstract val cacheControl: CacheControl?
}
@@ -212,6 +245,7 @@ sealed class Content {
@SerialName("text")
data class Text(
val text: String,
+ @SerialName("cache_control")
override val cacheControl: CacheControl? = null,
) : Content()
@@ -219,6 +253,7 @@ data class Text(
@SerialName("image")
data class Image(
val source: Source,
+ @SerialName("cache_control")
override val cacheControl: CacheControl? = null
) : Content() {
@@ -250,38 +285,70 @@ data class Image(
}
-@Serializable
@SerialName("tool_use")
+@Serializable
data class ToolUse(
+ @SerialName("cache_control")
override val cacheControl: CacheControl? = null,
val id: String,
val name: String,
val input: JsonObject
) : Content() {
- inline fun input(): T =
- anthropicJson.decodeFromJsonElement(input)
+ @Transient
+ internal lateinit var toolEntry: Anthropic.ToolEntry
+
+ fun use(): ToolResult {
+ val tool = anthropicJson.decodeFromJsonElement(
+ deserializer = toolEntry.serializer,
+ element = input
+ )
+ val result = try {
+ toolEntry.initialize(tool)
+ tool.use(toolUseId = id)
+ } catch (e: Exception) {
+ ToolResult(
+ toolUseId = id,
+ isError = true,
+ content = listOf(
+ Text(
+ text = e.message ?: "Unknown error occurred"
+ )
+ )
+ )
+ }
+ return result
+ }
}
@Serializable
@SerialName("tool_result")
data class ToolResult(
- override val cacheControl: CacheControl? = null,
@SerialName("tool_use_id")
val toolUseId: String,
+ val content: List, // TODO only Text, Image allowed here, should be accessible in gthe builder
@SerialName("is_error")
val isError: Boolean = false,
- val content: List
+ @SerialName("cache_control")
+ override val cacheControl: CacheControl? = null
) : Content()
+fun ToolResult(
+ toolUseId: String,
+ text: String
+): ToolResult = ToolResult(
+ toolUseId,
+ content = listOf(Text(text))
+)
+
@Serializable
data class CacheControl(
val type: Type
) {
- @SerialName("ephemeral")
enum class Type {
+ @SerialName("ephemeral")
EPHEMERAL
}
@@ -324,9 +391,9 @@ data class Usage(
@SerialName("input_tokens")
val inputTokens: Int,
@SerialName("cache_creation_input_tokens")
- val cacheCreationInputTokens: Int?,
+ val cacheCreationInputTokens: Int? = null,
@SerialName("cache_read_input_tokens")
- val cacheReadInputTokens: Int?,
+ val cacheReadInputTokens: Int? = null,
@SerialName("output_tokens")
val outputTokens: Int
)
diff --git a/src/commonMain/kotlin/schema/JsonSchema.kt b/src/commonMain/kotlin/schema/JsonSchema.kt
index 36ee19b..2c15f6b 100644
--- a/src/commonMain/kotlin/schema/JsonSchema.kt
+++ b/src/commonMain/kotlin/schema/JsonSchema.kt
@@ -18,6 +18,7 @@ data class JsonSchemaProperty(
val type: String? = null,
val items: JsonSchemaProperty? = null,
val enum: List? = null,
+ @SerialName("\$ref")
val ref: String? = null
) {
diff --git a/src/commonMain/kotlin/schema/JsonSchemaGenerator.kt b/src/commonMain/kotlin/schema/JsonSchemaGenerator.kt
index ef71a22..ad418cd 100644
--- a/src/commonMain/kotlin/schema/JsonSchemaGenerator.kt
+++ b/src/commonMain/kotlin/schema/JsonSchemaGenerator.kt
@@ -52,9 +52,12 @@ private fun generateSchemaProperty(
)
StructureKind.MAP -> JsonSchemaProperty("object")
StructureKind.CLASS -> {
+ // dots are not allowed in JSON Schema name, if the @SerialName was not
+ // specified, then fully qualified class name will be used, and we need
+ // to translate it
val refName = descriptor.serialName.replace('.', '_').trimEnd('?')
definitions[refName] = generateSchema(descriptor)
- JsonSchemaProperty("\$ref", ref = "#/definitions/$refName")
+ JsonSchemaProperty(ref = "#/definitions/$refName")
}
else -> JsonSchemaProperty("object") // Default case
}
diff --git a/src/commonMain/kotlin/tool/Tools.kt b/src/commonMain/kotlin/tool/Tools.kt
new file mode 100644
index 0000000..1ab4bed
--- /dev/null
+++ b/src/commonMain/kotlin/tool/Tools.kt
@@ -0,0 +1,86 @@
+package com.xemantic.anthropic.tool
+
+import com.xemantic.anthropic.message.CacheControl
+import com.xemantic.anthropic.message.Tool
+import com.xemantic.anthropic.message.ToolResult
+import com.xemantic.anthropic.schema.jsonSchemaOf
+import kotlinx.serialization.ExperimentalSerializationApi
+import kotlinx.serialization.MetaSerializable
+import kotlinx.serialization.SerializationException
+import kotlinx.serialization.serializer
+
+/**
+ * Annotation used to mark a class extending the [UsableTool].
+ *
+ * This annotation provides metadata for tools that can be serialized and used in the context
+ * of the Anthropic API. It includes a name and description for the tool.
+ *
+ * @property name The name of the tool. This name is used during serialization and should be a unique identifier for the tool.
+ * @property description A comprehensive description of what the tool does and how it should be used.
+ */
+@OptIn(ExperimentalSerializationApi::class)
+@MetaSerializable
+@Target(AnnotationTarget.CLASS)
+annotation class AnthropicTool(
+ val name: String,
+ val description: String
+)
+
+/**
+ * Interface for tools that can be used in the context of the Anthropic API.
+ *
+ * Classes implementing this interface represent tools that can be executed
+ * with a given tool use ID. The implementation of the [use] method should
+ * contain the logic for executing the tool and returning the [ToolResult].
+ */
+interface UsableTool {
+
+ /**
+ * Executes the tool and returns the result.
+ *
+ * @param toolUseId A unique identifier for this particular use of the tool.
+ * @return A [ToolResult] containing the outcome of executing the tool.
+ */
+ fun use(
+ toolUseId: String
+ ): ToolResult
+
+}
+
+fun Tool.cacheControl(
+ cacheControl: CacheControl? = null
+): Tool = if (cacheControl == null) this else Tool(
+ name,
+ description,
+ inputSchema,
+ cacheControl
+)
+
+@OptIn(ExperimentalSerializationApi::class)
+inline fun toolOf(
+ cacheControl: CacheControl? = null // TODO should it be here?
+): Tool {
+
+ val serializer = try {
+ serializer()
+ } catch (e :SerializationException) {
+ throw SerializationException(
+ "The class ${T::class.qualifiedName} must be annotated with @SerializableTool", e
+ )
+ }
+
+ val anthropicTool = serializer
+ .descriptor
+ .annotations
+ .filterIsInstance()
+ .firstOrNull() ?: throw SerializationException(
+ "The class ${T::class.qualifiedName} must be annotated with @SerializableTool",
+ )
+
+ return Tool(
+ name = anthropicTool.name,
+ description = anthropicTool.description,
+ inputSchema = jsonSchemaOf(),
+ cacheControl = cacheControl
+ )
+}
diff --git a/src/commonTest/kotlin/AnthropicTest.kt b/src/commonTest/kotlin/AnthropicTest.kt
index 605b231..b7f8f24 100644
--- a/src/commonTest/kotlin/AnthropicTest.kt
+++ b/src/commonTest/kotlin/AnthropicTest.kt
@@ -1,6 +1,6 @@
package com.xemantic.anthropic
-import com.xemantic.anthropic.event.ContentBlockDelta
+import com.xemantic.anthropic.event.ContentBlockDeltaEvent
import com.xemantic.anthropic.event.Delta.TextDelta
import com.xemantic.anthropic.message.Image
import com.xemantic.anthropic.message.Message
@@ -8,18 +8,23 @@ import com.xemantic.anthropic.message.MessageResponse
import com.xemantic.anthropic.message.Role
import com.xemantic.anthropic.message.StopReason
import com.xemantic.anthropic.message.Text
-import com.xemantic.anthropic.message.Tool
-import com.xemantic.anthropic.message.ToolChoice
import com.xemantic.anthropic.message.ToolUse
+import com.xemantic.anthropic.test.Calculator
+import com.xemantic.anthropic.test.DatabaseQueryTool
+import com.xemantic.anthropic.test.FibonacciTool
+import com.xemantic.anthropic.test.TestDatabase
+import io.kotest.assertions.assertSoftly
+import io.kotest.matchers.ints.shouldBeGreaterThan
+import io.kotest.matchers.shouldBe
+import io.kotest.matchers.shouldNotBe
+import io.kotest.matchers.string.shouldContain
+import io.kotest.matchers.string.shouldStartWith
+import io.kotest.matchers.types.instanceOf
import kotlinx.coroutines.flow.filterIsInstance
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.toList
import kotlinx.coroutines.test.runTest
-import kotlinx.serialization.Serializable
import kotlin.test.Test
-import kotlin.test.assertEquals
-import kotlin.test.assertNull
-import kotlin.test.assertTrue
class AnthropicTest {
@@ -33,23 +38,22 @@ class AnthropicTest {
+Message {
+"Hello World! What's your name?"
}
- model = "claude-3-opus-20240229"
maxTokens = 1024
}
// then
- response.apply {
- assertTrue(type == MessageResponse.Type.MESSAGE)
- assertTrue(role == Role.ASSISTANT)
- assertTrue(model == "claude-3-opus-20240229")
- assertTrue(content.size == 1)
- assertTrue(content[0] is Text)
+ assertSoftly(response) {
+ type shouldBe MessageResponse.Type.MESSAGE
+ role shouldBe Role.ASSISTANT
+ model shouldBe "claude-3-5-sonnet-20240620"
+ stopReason shouldBe StopReason.END_TURN
+ content.size shouldBe 1
+ content[0] shouldBe instanceOf()
val text = content[0] as Text
- assertTrue(text.text.contains("Claude"))
- assertTrue(stopReason == StopReason.END_TURN)
- assertNull(stopSequence)
- assertEquals(usage.inputTokens, 15)
- assertTrue(usage.outputTokens > 0)
+ text.text shouldContain "Claude"
+ stopSequence shouldBe null
+ usage.inputTokens shouldBe 15
+ usage.outputTokens shouldBeGreaterThan 0
}
}
@@ -72,11 +76,12 @@ class AnthropicTest {
}
// then
- response.apply {
- assertTrue(1 == content.size)
- assertTrue(content[0] is Text)
+ assertSoftly(response) {
+ stopReason shouldBe StopReason.END_TURN
+ content.size shouldBe 1
+ content[0] shouldBe instanceOf()
val text = content[0] as Text
- assertTrue(text.text.lowercase().contains("foo"))
+ text.text.uppercase() shouldContain "FOO"
}
}
@@ -87,106 +92,152 @@ class AnthropicTest {
// when
val response = client.messages.stream {
- +Message {
- role = Role.USER
- +"Say: 'The quick brown fox jumps over the lazy dog'"
- }
+ +Message { +"Say: 'The sun slowly dipped below the horizon, painting the sky in a breathtaking array of oranges, pinks, and purples.'" }
}
- .filterIsInstance()
+ .filterIsInstance()
.map { (it.delta as TextDelta).text }
.toList()
.joinToString(separator = "")
// then
- assertTrue(response == "The quick brown fox jumps over the lazy dog.")
+ response shouldBe "The sun slowly dipped below the horizon, painting the sky in a breathtaking array of oranges, pinks, and purples."
}
- // given
- @Serializable
- data class Calculator(
- val operation: Operation,
- val a: Double,
- val b: Double
- ) {
+ @Test
+ fun shouldUseCalculatorTool() = runTest {
+ // given
+ val client = Anthropic {
+ tool()
+ }
+ val conversation = mutableListOf()
+ conversation += Message { +"What's 15 multiplied by 7?" }
+
+ // when
+ val initialResponse = client.messages.create {
+ messages = conversation
+ useTools()
+ }
+ conversation += initialResponse.asMessage()
- @Suppress("unused") // it is used, but by Anthropic, so we skip the warning
- enum class Operation(
- val calculate: (a: Double, b: Double) -> Double
- ) {
- ADD({ a, b -> a + b }),
- SUBTRACT({ a, b -> a - b }),
- MULTIPLY({ a, b -> a * b }),
- DIVIDE({ a, b -> a / b })
+ // then
+ assertSoftly(initialResponse) {
+ stopReason shouldBe StopReason.TOOL_USE
+ content.size shouldBe 2
+ content[0] shouldBe instanceOf()
+ content[1] shouldBe instanceOf()
+ (content[1] as ToolUse).name shouldBe "Calculator"
}
- fun calculate() = operation.calculate(a, b)
+ val toolUse = initialResponse.content[1] as ToolUse
+ val result = toolUse.use() // here we execute the tool
+ conversation += Message { +result }
+
+ // when
+ val resultResponse = client.messages.create {
+ messages = conversation
+ useTools()
+ }
+
+ // then
+ assertSoftly(resultResponse) {
+ stopReason shouldBe StopReason.END_TURN
+ content.size shouldBe 1
+ content[0] shouldBe instanceOf()
+ (content[0] as Text).text shouldContain "105"
+ }
}
@Test
- fun shouldUseCalculatorTool() = runTest {
+ fun shouldUseFibonacciTool() = runTest {
// given
- val client = Anthropic()
- val calculatorTool = Tool(
- description = "Perform basic arithmetic operations",
- )
+ val client = Anthropic {
+ tool()
+ }
// when
val response = client.messages.create {
- +Message {
- +"What's 15 multiplied by 7?"
- }
- tools = listOf(calculatorTool)
- toolChoice = ToolChoice.Any()
+ +Message { +"What's fibonacci number 42" }
+ useTools()
}
// then
- response.apply {
- assertTrue(content.size == 1)
- assertTrue(content[0] is ToolUse)
- val toolUse = content[0] as ToolUse
- assertTrue(toolUse.name == "com_xemantic_anthropic_AnthropicTest_Calculator")
- val calculator = toolUse.input()
- val result = calculator.calculate()
- assertTrue(result == 15.0 * 7.0)
+ val toolUse = response.content.filterIsInstance().first()
+ toolUse.name shouldBe "FibonacciTool"
+
+ val result = toolUse.use()
+ assertSoftly(result) {
+ toolUseId shouldBe toolUse.id
+ isError shouldBe false
+ content shouldBe listOf(Text(text = "267914296"))
}
}
- @Serializable
- data class Fibonacci(val n: Int)
+ @Test
+ fun shouldUse2ToolsInSequence() = runTest {
+ // given
+ val client = Anthropic {
+ tool()
+ tool()
+ }
+
+ // when
+ val conversation = mutableListOf()
+ conversation += Message { +"Calculate Fibonacci number 42 and then divide it by 42" }
+
+ val fibonacciResponse = client.messages.create {
+ messages = conversation
+ useTools()
+ }
+ conversation += fibonacciResponse.asMessage()
- tailrec fun fibonacci(
- n: Int, a: Int = 0, b: Int = 1
- ): Int = when (n) {
- 0 -> a; 1 -> b; else -> fibonacci(n - 1, b, a + b)
+ val fibonacciToolUse = fibonacciResponse.content.filterIsInstance().first()
+ fibonacciToolUse.name shouldBe "FibonacciTool"
+ val fibonacciResult = fibonacciToolUse.use()
+ conversation += Message { +fibonacciResult }
+
+ val calculatorResponse = client.messages.create {
+ messages = conversation
+ useTools()
+ }
+ conversation += calculatorResponse.asMessage()
+
+ val calculatorToolUse = calculatorResponse.content.filterIsInstance().first()
+ calculatorToolUse.name shouldBe "Calculator"
+ val calculatorResult = calculatorToolUse.use()
+ conversation += Message { +calculatorResult }
+
+ val finalResponse = client.messages.create {
+ messages = conversation
+ useTools()
+ }
+
+ finalResponse.content[0] shouldBe instanceOf()
+ (finalResponse.content[0] as Text).text shouldContain "6,378,911.8"
}
@Test
- fun shouldUseCalculatorToolForFibonacci() = runTest {
+ fun shouldUseToolWithDependencies() = runTest {
// given
- val client = Anthropic()
- val fibonacciTool = Tool(
- description = "Calculates fibonacci number of a given n",
- )
+ val testDatabase = TestDatabase()
+ val client = Anthropic {
+ tool {
+ database = testDatabase
+ }
+ }
// when
val response = client.messages.create {
- +Message { +"What's fibonacci number 42" }
- tools = listOf(fibonacciTool)
- toolChoice = ToolChoice.Any()
+ +Message { +"List data in CUSTOMER table" }
+ useTool()
}
+ val toolUse = response.content.filterIsInstance().first()
+ toolUse.use()
// then
- response.apply {
- assertTrue(content.size == 1)
- assertTrue(content[0] is ToolUse)
- val toolUse = content[0] as ToolUse
- assertTrue(toolUse.name == "com_xemantic_anthropic_AnthropicTest_Fibonacci")
- val n = toolUse.input().n
- assertTrue(n == 42)
- val fibonacciNumber = fibonacci(n) // doing the job for Anthropic
- assertTrue(fibonacciNumber == 267914296)
- }
+ testDatabase.executedQuery shouldNotBe null
+ testDatabase.executedQuery!!.uppercase() shouldStartWith "SELECT * FROM CUSTOMER"
+ // depending on the response the statement might end up with semicolon, which we discard
}
}
diff --git a/src/commonTest/kotlin/message/MessagesTest.kt b/src/commonTest/kotlin/message/MessageRequestTest.kt
similarity index 59%
rename from src/commonTest/kotlin/message/MessagesTest.kt
rename to src/commonTest/kotlin/message/MessageRequestTest.kt
index e5e9444..a9244f7 100644
--- a/src/commonTest/kotlin/message/MessagesTest.kt
+++ b/src/commonTest/kotlin/message/MessageRequestTest.kt
@@ -1,31 +1,29 @@
package com.xemantic.anthropic.message
-import com.xemantic.anthropic.anthropicJson
+import com.xemantic.anthropic.test.testJson
import io.kotest.assertions.json.shouldEqualJson
-import kotlinx.serialization.ExperimentalSerializationApi
+import io.kotest.matchers.shouldBe
import kotlinx.serialization.encodeToString
-import kotlinx.serialization.json.Json
import kotlin.test.Test
/**
- * Tests the JSON serialization format of created Anthropic API messages.
+ * Tests the JSON serialization format of created Anthropic API message requests.
*/
-class MessagesTest {
+class MessageRequestTest {
- /**
- * A pretty JSON printing for testing.
- */
- private val json = Json(from = anthropicJson) {
- prettyPrint = true
- @OptIn(ExperimentalSerializationApi::class)
- prettyPrintIndent = " "
+ @Test
+ fun defaultMessageShouldHaveRoleUser() {
+ // given
+ val message = Message {}
+ // then
+ message.role shouldBe Role.USER
}
@Test
fun shouldCreateTheSimplestMessageRequest() {
// given
val request = MessageRequest(
- defaultModel = "claude-3-opus-20240229"
+ defaultModel = "claude-3-5-sonnet-20240620"
) {
+Message {
+"Hey Claude!?"
@@ -33,12 +31,12 @@ class MessagesTest {
}
// when
- val json = json.encodeToString(request)
+ val json = testJson.encodeToString(request)
// then
json shouldEqualJson """
{
- "model": "claude-3-opus-20240229",
+ "model": "claude-3-5-sonnet-20240620",
"messages": [
{
"role": "user",
diff --git a/src/commonTest/kotlin/message/MessageResponseTest.kt b/src/commonTest/kotlin/message/MessageResponseTest.kt
new file mode 100644
index 0000000..d26db8d
--- /dev/null
+++ b/src/commonTest/kotlin/message/MessageResponseTest.kt
@@ -0,0 +1,67 @@
+package com.xemantic.anthropic.message
+
+import com.xemantic.anthropic.test.testJson
+import io.kotest.assertions.assertSoftly
+import io.kotest.matchers.shouldBe
+import io.kotest.matchers.types.instanceOf
+import kotlin.test.Test
+
+/**
+ * Tests the JSON format of deserialized Anthropic API message responses.
+ */
+class MessageResponseTest {
+
+ @Test
+ fun shouldDeserializeToolUseMessageResponse() {
+ // given
+ val jsonResponse = """
+ {
+ "id": "msg_01PspkNzNG3nrf5upeTsmWLF",
+ "type": "message",
+ "role": "assistant",
+ "model": "claude-3-5-sonnet-20240620",
+ "content": [
+ {
+ "type": "tool_use",
+ "id": "toolu_01YHJK38TBKCRPn7zfjxcKHx",
+ "name": "Calculator",
+ "input": {
+ "operation": "MULTIPLY",
+ "a": 15,
+ "b": 7
+ }
+ }
+ ],
+ "stop_reason": "tool_use",
+ "stop_sequence": null,
+ "usage": {
+ "input_tokens": 419,
+ "output_tokens": 86
+ }
+ }
+ """.trimIndent()
+
+ val response = testJson.decodeFromString(jsonResponse)
+ assertSoftly(response) {
+ id shouldBe "msg_01PspkNzNG3nrf5upeTsmWLF"
+ type shouldBe MessageResponse.Type.MESSAGE
+ role shouldBe Role.ASSISTANT
+ model shouldBe "claude-3-5-sonnet-20240620"
+ content.size shouldBe 1
+ content[0] shouldBe instanceOf()
+ stopReason shouldBe StopReason.TOOL_USE
+ stopSequence shouldBe null
+ usage shouldBe Usage(
+ inputTokens = 419,
+ outputTokens = 86
+ )
+ }
+ val toolUse = response.content[0] as ToolUse
+ assertSoftly(toolUse) {
+ id shouldBe "toolu_01YHJK38TBKCRPn7zfjxcKHx"
+ name shouldBe "Calculator"
+ // TODO generate JsonObject to assert input
+ }
+ }
+
+}
diff --git a/src/commonTest/kotlin/schema/JsonSchemaGeneratorTest.kt b/src/commonTest/kotlin/schema/JsonSchemaGeneratorTest.kt
index 4c824c4..411e56e 100644
--- a/src/commonTest/kotlin/schema/JsonSchemaGeneratorTest.kt
+++ b/src/commonTest/kotlin/schema/JsonSchemaGeneratorTest.kt
@@ -3,12 +3,14 @@ package com.xemantic.anthropic.schema
import com.xemantic.anthropic.anthropicJson
import io.kotest.assertions.json.shouldEqualJson
import kotlinx.serialization.ExperimentalSerializationApi
+import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.Json
import kotlin.test.Test
@Serializable
+@SerialName("address")
data class Address(
val street: String? = null,
val city: String? = null,
@@ -35,7 +37,6 @@ class JsonSchemaGeneratorTest {
@Test
fun generateJsonSchemaForAddress() {
-
// when
val schema = jsonSchemaOf()
val schemaJson = json.encodeToString(schema)
@@ -77,7 +78,7 @@ class JsonSchemaGeneratorTest {
{
"type": "object",
"definitions": {
- "com_xemantic_anthropic_schema_Address": {
+ "address": {
"type": "object",
"properties": {
"street": {
@@ -116,8 +117,7 @@ class JsonSchemaGeneratorTest {
}
},
"address": {
- "type": "${'$'}ref",
- "ref": "#/definitions/com_xemantic_anthropic_schema_Address"
+ "${'$'}ref": "#/definitions/address"
}
},
"required": [
diff --git a/src/commonTest/kotlin/test/AnthropicTestSupport.kt b/src/commonTest/kotlin/test/AnthropicTestSupport.kt
new file mode 100644
index 0000000..0b693a8
--- /dev/null
+++ b/src/commonTest/kotlin/test/AnthropicTestSupport.kt
@@ -0,0 +1,16 @@
+package com.xemantic.anthropic.test
+
+import com.xemantic.anthropic.anthropicJson
+import kotlinx.serialization.ExperimentalSerializationApi
+import kotlinx.serialization.json.Json
+
+/**
+ * A pretty JSON printing for testing. It's derived from [anthropicJson],
+ * therefore should use the same rules for serialization/deserialization, but
+ * it has `prettyPrint` and 2 space tab enabled in addition.
+ */
+val testJson = Json(from = anthropicJson) {
+ prettyPrint = true
+ @OptIn(ExperimentalSerializationApi::class)
+ prettyPrintIndent = " "
+}
diff --git a/src/commonTest/kotlin/test/AnthropicTestTools.kt b/src/commonTest/kotlin/test/AnthropicTestTools.kt
new file mode 100644
index 0000000..c7b8b78
--- /dev/null
+++ b/src/commonTest/kotlin/test/AnthropicTestTools.kt
@@ -0,0 +1,85 @@
+package com.xemantic.anthropic.test
+
+import com.xemantic.anthropic.message.ToolResult
+import com.xemantic.anthropic.tool.AnthropicTool
+import com.xemantic.anthropic.tool.UsableTool
+import kotlinx.serialization.Transient
+
+@AnthropicTool(
+ name = "FibonacciTool",
+ description = "Calculate Fibonacci number n"
+)
+data class FibonacciTool(val n: Int): UsableTool {
+
+ tailrec fun fibonacci(
+ n: Int, a: Int = 0, b: Int = 1
+ ): Int = when (n) {
+ 0 -> a; 1 -> b; else -> fibonacci(n - 1, b, a + b)
+ }
+
+ override fun use(
+ toolUseId: String,
+ ) = ToolResult(toolUseId, "${fibonacci(n)}")
+
+}
+
+@AnthropicTool(
+ name = "Calculator",
+ description = "Calculates the arithmetic outcome of an operation when given the arguments a and b"
+)
+data class Calculator(
+ val operation: Operation,
+ val a: Double,
+ val b: Double
+): UsableTool {
+
+ @Suppress("unused") // it is used, but by Anthropic, so we skip the warning
+ enum class Operation(
+ val calculate: (a: Double, b: Double) -> Double
+ ) {
+ ADD({ a, b -> a + b }),
+ SUBTRACT({ a, b -> a - b }),
+ MULTIPLY({ a, b -> a * b }),
+ DIVIDE({ a, b -> a / b })
+ }
+
+ override fun use(toolUseId: String) = ToolResult(
+ toolUseId,
+ operation.calculate(a, b).toString()
+ )
+
+}
+
+interface Database {
+ fun execute(query: String): List
+}
+
+class TestDatabase : Database {
+ var executedQuery: String? = null
+ override fun execute(
+ query: String
+ ): List {
+ executedQuery = query
+ return listOf("foo", "bar", "buzz")
+ }
+}
+
+@AnthropicTool(
+ name = "DatabaseQuery",
+ description = "Executes database query"
+)
+data class DatabaseQueryTool(
+ val query: String
+) : UsableTool {
+
+ @Transient
+ lateinit var database: Database
+
+ override fun use(
+ toolUseId: String
+ ) = ToolResult(
+ toolUseId,
+ text = database.execute(query).joinToString()
+ )
+
+}
diff --git a/src/commonTest/kotlin/tool/UsableToolTest.kt b/src/commonTest/kotlin/tool/UsableToolTest.kt
new file mode 100644
index 0000000..516e77f
--- /dev/null
+++ b/src/commonTest/kotlin/tool/UsableToolTest.kt
@@ -0,0 +1,87 @@
+package com.xemantic.anthropic.tool
+
+import com.xemantic.anthropic.message.CacheControl
+import com.xemantic.anthropic.message.ToolResult
+import com.xemantic.anthropic.schema.JsonSchema
+import com.xemantic.anthropic.schema.JsonSchemaProperty
+import io.kotest.assertions.assertSoftly
+import io.kotest.assertions.throwables.shouldThrowWithMessage
+import io.kotest.matchers.shouldBe
+import kotlinx.serialization.Serializable
+import kotlinx.serialization.SerializationException
+import kotlin.test.Test
+
+class UsableToolTest {
+
+ @AnthropicTool(
+ name = "TestTool",
+ description = "Test tool receiving a message and outputting it back"
+ )
+ class TestTool(
+ val message: String
+ ) : UsableTool {
+ override fun use(toolUseId: String) = ToolResult(toolUseId, message)
+ }
+
+ @Test
+ fun shouldCreateToolFromUsableToolAnnotatedWithAnthropicTool() {
+ // when
+ val tool = toolOf()
+
+ assertSoftly(tool) {
+ name shouldBe "TestTool"
+ description shouldBe "Test tool receiving a message and outputting it back"
+ inputSchema shouldBe JsonSchema(
+ properties = mapOf("message" to JsonSchemaProperty.STRING),
+ required = listOf("message")
+ )
+ cacheControl shouldBe null
+ }
+ }
+
+ @Test
+ fun shouldCreateToolWithCacheControlFromUsableTool() {
+ // when
+ val tool = toolOf(
+ cacheControl = CacheControl(type = CacheControl.Type.EPHEMERAL)
+ )
+
+ assertSoftly(tool) {
+ name shouldBe "TestTool"
+ description shouldBe "Test tool receiving a message and outputting it back"
+ inputSchema shouldBe JsonSchema(
+ properties = mapOf("message" to JsonSchemaProperty.STRING),
+ required = listOf("message")
+ )
+ cacheControl shouldBe CacheControl(type = CacheControl.Type.EPHEMERAL)
+ }
+ }
+
+ class NoAnnotationTool : UsableTool {
+ override fun use(toolUseId: String) = ToolResult(toolUseId, "nothing")
+ }
+
+ @Test
+ fun shouldFailToCreateToolWithoutAnthropicToolAnnotation() {
+ shouldThrowWithMessage(
+ "The class com.xemantic.anthropic.tool.UsableToolTest.NoAnnotationTool must be annotated with @SerializableTool"
+ ) {
+ toolOf()
+ }
+ }
+
+ @Serializable
+ class OnlySerializableAnnotationTool : UsableTool {
+ override fun use(toolUseId: String) = ToolResult(toolUseId, "nothing")
+ }
+
+ @Test
+ fun shouldFailToCreateToolWithOnlySerializableAnnotation() {
+ shouldThrowWithMessage(
+ "The class com.xemantic.anthropic.tool.UsableToolTest.OnlySerializableAnnotationTool must be annotated with @SerializableTool"
+ ) {
+ toolOf()
+ }
+ }
+
+}
diff --git a/src/jvmMain/kotlin/JvmAnthropic.kt b/src/jvmMain/kotlin/JvmAnthropic.kt
index 84f8a11..036434c 100644
--- a/src/jvmMain/kotlin/JvmAnthropic.kt
+++ b/src/jvmMain/kotlin/JvmAnthropic.kt
@@ -1,7 +1,40 @@
package com.xemantic.anthropic
+import com.xemantic.anthropic.message.MessageRequest
+import com.xemantic.anthropic.message.MessageResponse
+import kotlinx.coroutines.runBlocking
+import java.util.function.Consumer
+
actual val envApiKey: String?
get() = System.getenv("ANTHROPIC_API_KEY")
actual val missingApiKeyMessage: String
get() = "apiKey is missing, it has to be provided as a parameter or as an ANTHROPIC_API_KEY environment variable."
+
+// a very early version of Java only SDK, adapting Kotlin idioms and coroutines
+// it might change a lot in the future
+class JavaAnthropic {
+
+ companion object {
+
+ @JvmStatic
+ fun create(): Anthropic = Anthropic()
+
+ @JvmStatic
+ fun create(
+ configurer: Consumer
+ ): Anthropic {
+ return Anthropic { configurer.accept(this) }
+ }
+
+ }
+
+}
+
+fun Anthropic.createMessage(
+ builder: Consumer
+): MessageResponse = runBlocking {
+ messages.create {
+ builder.accept(this)
+ }
+}