From 8ba0cb4054e1fb4fcd1f3e6340dc4cd795f5171b Mon Sep 17 00:00:00 2001 From: Kazik Pogoda Date: Fri, 13 Dec 2024 21:11:49 +0100 Subject: [PATCH] Feature: Cost calculation of API usage (#23) --- .github/FUNDING.yml | 2 +- .gitignore | 1 + README.md | 6 +- build.gradle.kts | 20 +- gradle/libs.versions.toml | 7 +- src/commonMain/kotlin/Anthropic.kt | 47 +++- src/commonMain/kotlin/AnthropicJson.kt | 10 + src/commonMain/kotlin/Models.kt | 86 ++++--- src/commonMain/kotlin/content/Tool.kt | 75 ++++-- src/commonMain/kotlin/message/Messages.kt | 14 +- src/commonMain/kotlin/tool/Tools.kt | 27 +- src/commonMain/kotlin/usage/Usage.kt | 87 +++++-- src/commonMain/kotlin/usage/UsageCollector.kt | 54 ++++ src/commonTest/kotlin/AnthropicTest.kt | 238 ++++++++++++------ src/commonTest/kotlin/content/DocumentTest.kt | 22 +- src/commonTest/kotlin/content/ImageTest.kt | 22 +- .../kotlin/content/ToolResultTest.kt | 64 +++++ .../kotlin/error/ErrorResponseTest.kt | 24 +- .../kotlin/message/MessageRequestTest.kt | 101 ++++---- .../kotlin/message/MessageResponseTest.kt | 53 ++-- src/commonTest/kotlin/message/MessageTest.kt | 16 ++ .../kotlin/message/ToolResultTest.kt | 20 -- .../kotlin/test/AnthropicTestSupport.kt | 17 -- src/commonTest/kotlin/tool/ToolChoiceTest.kt | 22 +- src/commonTest/kotlin/tool/ToolInputTest.kt | 45 ++-- src/commonTest/kotlin/usage/CostTest.kt | 120 +++++++++ .../kotlin/usage/UsageCollectorTest.kt | 136 ++++++++++ src/jvmTest/kotlin/content/MagicNumberTest.kt | 42 ++-- 28 files changed, 1012 insertions(+), 366 deletions(-) create mode 100644 src/commonMain/kotlin/usage/UsageCollector.kt create mode 100644 src/commonTest/kotlin/content/ToolResultTest.kt create mode 100644 src/commonTest/kotlin/message/MessageTest.kt delete mode 100644 src/commonTest/kotlin/message/ToolResultTest.kt create mode 100644 src/commonTest/kotlin/usage/CostTest.kt create mode 100644 src/commonTest/kotlin/usage/UsageCollectorTest.kt diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml index bcb35c2..a2edaa0 100644 --- a/.github/FUNDING.yml +++ b/.github/FUNDING.yml @@ -1,3 +1,3 @@ # These are supported funding model platforms -github:xemantic # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] +github: xemantic diff --git a/.gitignore b/.gitignore index c52b98f..2e6fa73 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ build/ ### IntelliJ IDEA ### /.idea/ +!/.idea/copyright/ *.iws *.iml *.ipr diff --git a/README.md b/README.md index 4fb2af0..844568e 100644 --- a/README.md +++ b/README.md @@ -158,13 +158,11 @@ fun main() = runBlocking { println(initialResponse) conversation += initialResponse - val tool = initialResponse.content.filterIsInstance().first() - val toolResult = tool.use() - conversation += Message { +toolResult } + conversation += initialResonse.useTools() val finalResponse = client.messages.create { messages = conversation - useTools() + allTools() } println("Final response:") println(finalResponse) diff --git a/build.gradle.kts b/build.gradle.kts index 879593c..89cbd0f 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -11,6 +11,7 @@ import org.jetbrains.kotlin.gradle.dsl.KotlinVersion plugins { alias(libs.plugins.kotlin.multiplatform) alias(libs.plugins.kotlin.plugin.serialization) + alias(libs.plugins.kotlinx.atomicfu) alias(libs.plugins.kotlin.plugin.power.assert) alias(libs.plugins.dokka) alias(libs.plugins.versions) @@ -139,6 +140,14 @@ kotlin { sourceSets { + all { + languageSettings { + languageVersion = kotlinTarget.version + apiVersion = kotlinTarget.version + progressiveMode = true + } + } + commonMain { dependencies { implementation(libs.kotlinx.datetime) @@ -147,6 +156,7 @@ kotlin { implementation(libs.ktor.client.logging) implementation(libs.ktor.serialization.kotlinx.json) implementation(libs.xemantic.ai.tool.schema) + api(libs.xemantic.ai.money) } } @@ -154,7 +164,7 @@ kotlin { dependencies { implementation(libs.kotlin.test) implementation(libs.kotlinx.coroutines.test) - implementation(libs.kotest.assertions.core) + implementation(libs.xemantic.kotlin.test) implementation(libs.kotest.assertions.json) } } @@ -233,10 +243,10 @@ tasks.withType { } powerAssert { -// functions = listOf( -// "io.kotest.matchers.shouldBe" -// ) -// includedSourceSets = listOf("commonTest", "jvmTest", "nativeTest") + functions = listOf( + "com.xemantic.kotlin.test.assert", + "com.xemantic.kotlin.test.have" + ) } // maybe this one is not necessary? diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 639822e..ff99ef2 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -5,10 +5,13 @@ javaTarget = "17" kotlin = "2.1.0" kotlinxCoroutines = "1.9.0" kotlinxDatetime = "0.6.1" +kotlinxAtomicFu = "0.26.0" ktor = "3.0.1" kotest = "6.0.0.M1" +xemanticKotlinTest = "1.0" xemanticAiToolSchema = "0.1.1" +xemanticAiMoney = "0.2" # logging is not used at the moment, might be enabled later log4j = "2.24.2" @@ -24,7 +27,9 @@ kotlinx-coroutines-test = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-t kotlinx-datetime = { module = "org.jetbrains.kotlinx:kotlinx-datetime", version.ref = "kotlinxDatetime" } # xemantic +xemantic-kotlin-test = { module = "com.xemantic.kotlin:xemantic-kotlin-test", version.ref = "xemanticKotlinTest"} xemantic-ai-tool-schema = { module = "com.xemantic.ai:xemantic-ai-tool-schema", version.ref = "xemanticAiToolSchema"} +xemantic-ai-money = { module = "com.xemantic.ai:xemantic-ai-money", version.ref = "xemanticAiMoney"} # logging libs log4j-slf4j2 = { module = "org.apache.logging.log4j:log4j-slf4j2-impl", version.ref = "log4j" } @@ -40,13 +45,13 @@ ktor-client-java = { module = "io.ktor:ktor-client-java", version.ref = "ktor" } ktor-client-curl = { module = "io.ktor:ktor-client-curl", version.ref = "ktor" } ktor-client-darwin = { module = "io.ktor:ktor-client-darwin", version.ref = "ktor" } -kotest-assertions-core = { module = "io.kotest:kotest-assertions-core", version.ref = "kotest" } kotest-assertions-json = { module = "io.kotest:kotest-assertions-json", version.ref = "kotest" } [plugins] kotlin-multiplatform = { id = "org.jetbrains.kotlin.multiplatform", version.ref = "kotlin" } kotlin-plugin-serialization = { id = "org.jetbrains.kotlin.plugin.serialization", version.ref = "kotlin" } kotlin-plugin-power-assert = { id = "org.jetbrains.kotlin.plugin.power-assert", version.ref = "kotlin" } +kotlinx-atomicfu = { id = "org.jetbrains.kotlinx.atomicfu", version.ref = "kotlinxAtomicFu" } dokka = { id = "org.jetbrains.dokka", version.ref = "dokkaPlugin" } versions = { id = "com.github.ben-manes.versions", version.ref = "versionsPlugin" } publish = { id = "io.github.gradle-nexus.publish-plugin", version.ref = "publishPlugin" } diff --git a/src/commonMain/kotlin/Anthropic.kt b/src/commonMain/kotlin/Anthropic.kt index c77561a..f58e9c9 100644 --- a/src/commonMain/kotlin/Anthropic.kt +++ b/src/commonMain/kotlin/Anthropic.kt @@ -10,6 +10,9 @@ import com.xemantic.anthropic.message.MessageResponse import com.xemantic.anthropic.tool.BuiltInTool import com.xemantic.anthropic.tool.Tool import com.xemantic.anthropic.tool.ToolInput +import com.xemantic.anthropic.usage.Cost +import com.xemantic.anthropic.usage.Usage +import com.xemantic.anthropic.usage.UsageCollector import io.ktor.client.HttpClient import io.ktor.client.call.body import io.ktor.client.plugins.* @@ -66,7 +69,8 @@ fun Anthropic( defaultMaxTokens = config.defaultMaxTokens, directBrowserAccess = config.directBrowserAccess, logLevel = if (config.logHttp) LogLevel.ALL else LogLevel.NONE, - toolMap = config.tools.associateBy { it.name } + modelMap = config.modelMap, + toolMap = config.tools.associateBy { it.name }, ) } // TODO this can be a second constructor, then toolMap can be private @@ -79,6 +83,7 @@ class Anthropic internal constructor( val defaultMaxTokens: Int, val directBrowserAccess: Boolean, val logLevel: LogLevel, + private val modelMap: Map, private val toolMap: Map ) { @@ -87,7 +92,7 @@ class Anthropic internal constructor( var anthropicVersion: String = DEFAULT_ANTHROPIC_VERSION var anthropicBeta: String? = null var apiBase: String = ANTHROPIC_API_BASE - var defaultModel: Model = Model.DEFAULT + var defaultModel: AnthropicModel = Model.DEFAULT var defaultMaxTokens: Int = defaultModel.maxOutput var directBrowserAccess: Boolean = false @@ -95,6 +100,8 @@ class Anthropic internal constructor( var tools: List = emptyList() + var modelMap: Map = Model.entries.associateBy { it.id } + // TODO in the future this should be rather Tool builder inline fun tool( cacheControl: CacheControl? = null, @@ -176,6 +183,7 @@ class Anthropic internal constructor( val response = apiResponse.body() when (response) { is MessageResponse -> response.apply { + updateUsage(response) content.filterIsInstance() .forEach { toolUse -> val tool = toolMap[toolUse.name] @@ -192,7 +200,9 @@ class Anthropic internal constructor( error = response.error, httpStatusCode = apiResponse.status ) - else -> throw RuntimeException("Unsupported response: $response") // should never happen + else -> throw RuntimeException( + "Unsupported response: $response" + ) // should never happen } return response } @@ -222,8 +232,13 @@ class Anthropic internal constructor( .map { it.data } .filterNotNull() .map { anthropicJson.decodeFromString(it) } - .collect { - emit(it) + .collect { event -> + // TODO we need better way of handling subsequent deltas with usage + if (event is Event.MessageStart) { + // TODO more rules are needed here + updateUsage(event.message) + } + emit(event) } } } @@ -232,5 +247,25 @@ class Anthropic internal constructor( val messages = Messages() -} + private val usageCollector = UsageCollector() + + val usage: Usage get() = usageCollector.usage + + val cost: Cost get() = usageCollector.cost + + override fun toString(): String = "Anthropic($usage, $cost)" + private val MessageResponse.anthropicModel: AnthropicModel get() = requireNotNull( + modelMap[model] + ) { + "The model returned in the response is not known to Anthropic API client: $id" + } + + private fun updateUsage(response: MessageResponse) { + usageCollector.update( + modelCost = response.anthropicModel.cost, + usage = response.usage + ) + } + +} diff --git a/src/commonMain/kotlin/AnthropicJson.kt b/src/commonMain/kotlin/AnthropicJson.kt index e32bc73..651cff9 100644 --- a/src/commonMain/kotlin/AnthropicJson.kt +++ b/src/commonMain/kotlin/AnthropicJson.kt @@ -22,6 +22,7 @@ import kotlinx.serialization.SerializationException import kotlinx.serialization.descriptors.SerialDescriptor import kotlinx.serialization.descriptors.SerialKind import kotlinx.serialization.descriptors.buildSerialDescriptor +import kotlinx.serialization.encodeToString import kotlinx.serialization.encoding.Decoder import kotlinx.serialization.encoding.Encoder import kotlinx.serialization.json.Json @@ -65,6 +66,15 @@ val anthropicJson: Json = Json { encodeDefaults = true } +@OptIn(ExperimentalSerializationApi::class) +@PublishedApi +internal val prettyAnthropicJson: Json = Json(from = anthropicJson) { + prettyPrint = true + prettyPrintIndent = " " +} + +inline fun T.toPrettyJson(): String = prettyAnthropicJson.encodeToString(this) + private object ResponseSerializer : JsonContentPolymorphicSerializer( baseClass = Response::class ) { diff --git a/src/commonMain/kotlin/Models.kt b/src/commonMain/kotlin/Models.kt index ee0b794..11d412f 100644 --- a/src/commonMain/kotlin/Models.kt +++ b/src/commonMain/kotlin/Models.kt @@ -1,12 +1,40 @@ package com.xemantic.anthropic -enum class Model( - val id: String, - val contextWindow: Int, - val maxOutput: Int, - val messageBatchesApi: Boolean, +import com.xemantic.ai.money.Money +import com.xemantic.ai.money.Ratio +import com.xemantic.anthropic.usage.Cost + +/** + * The model used by the API. + * E.g., Claude LLM `sonnet`, `opus`, `haiku` family. + */ +interface AnthropicModel { + + val id: String + val contextWindow: Int + val maxOutput: Int + val messageBatchesApi: Boolean val cost: Cost -) { + +} + +val ANTHROPIC_TOKEN_COST_RATIO = Money.Ratio("0.000001") + +val String.dollarsPerMillion: Money get() = Money(this) * ANTHROPIC_TOKEN_COST_RATIO + +/** + * Predefined models supported by Anthropic API. + * + * It could include Vertex AI (Google Cloud), or Bedrock (AWS) models in the future. + */ +// TODO model should be interface AnthropicApi models should be enum +enum class Model( + override val id: String, + override val contextWindow: Int, + override val maxOutput: Int, + override val messageBatchesApi: Boolean, + override val cost: Cost +) : AnthropicModel { CLAUDE_3_5_SONNET( id = "claude-3-5-sonnet-latest", @@ -14,8 +42,8 @@ enum class Model( maxOutput = 8182, messageBatchesApi = true, cost = Cost( - inputTokens = 3.0, - outputTokens = 15.0 + inputTokens = "3".dollarsPerMillion, + outputTokens = "15".dollarsPerMillion ) ), @@ -25,8 +53,8 @@ enum class Model( maxOutput = 8182, messageBatchesApi = true, cost = Cost( - inputTokens = 3.0, - outputTokens = 15.0 + inputTokens = "3".dollarsPerMillion, + outputTokens = "15".dollarsPerMillion ) ), @@ -36,8 +64,8 @@ enum class Model( maxOutput = 8182, messageBatchesApi = true, cost = Cost( - inputTokens = 1.0, - outputTokens = 5.0 + inputTokens = "0.80".dollarsPerMillion, + outputTokens = "4".dollarsPerMillion ) ), @@ -47,8 +75,8 @@ enum class Model( maxOutput = 8182, messageBatchesApi = true, cost = Cost( - inputTokens = 1.0, - outputTokens = 5.0 + inputTokens = "0.80".dollarsPerMillion, + outputTokens = "4".dollarsPerMillion ) ), @@ -58,8 +86,8 @@ enum class Model( maxOutput = 8182, messageBatchesApi = true, cost = Cost( - inputTokens = 3.0, - outputTokens = 15.0 + inputTokens = "3".dollarsPerMillion, + outputTokens = "15".dollarsPerMillion ) ), @@ -69,8 +97,8 @@ enum class Model( maxOutput = 4096, messageBatchesApi = true, cost = Cost( - inputTokens = 15.0, - outputTokens = 75.0 + inputTokens = "15".dollarsPerMillion, + outputTokens = "75".dollarsPerMillion ) ), @@ -80,8 +108,8 @@ enum class Model( maxOutput = 4096, messageBatchesApi = true, cost = Cost( - inputTokens = 15.0, - outputTokens = 75.0 + inputTokens = "15".dollarsPerMillion, + outputTokens = "75".dollarsPerMillion ) ), @@ -91,8 +119,8 @@ enum class Model( maxOutput = 4096, messageBatchesApi = true, cost = Cost( - inputTokens = 3.0, - outputTokens = 15.0 + inputTokens = "3".dollarsPerMillion, + outputTokens = "15".dollarsPerMillion ) ), @@ -102,21 +130,15 @@ enum class Model( maxOutput = 4096, messageBatchesApi = true, cost = Cost( - inputTokens = .25, - outputTokens = 1.25 + inputTokens = "0.25".dollarsPerMillion, + outputTokens = "1.25".dollarsPerMillion ) ); - /** - * Cost per MTok - */ - data class Cost( - val inputTokens: Double, - val outputTokens: Double - ) - companion object { + val DEFAULT: Model = CLAUDE_3_5_SONNET + } } diff --git a/src/commonMain/kotlin/content/Tool.kt b/src/commonMain/kotlin/content/Tool.kt index 519a2a7..0acc720 100644 --- a/src/commonMain/kotlin/content/Tool.kt +++ b/src/commonMain/kotlin/content/Tool.kt @@ -2,7 +2,6 @@ package com.xemantic.anthropic.content import com.xemantic.anthropic.anthropicJson import com.xemantic.anthropic.cache.CacheControl -import com.xemantic.anthropic.message.toNullIfEmpty import com.xemantic.anthropic.tool.Tool import com.xemantic.anthropic.tool.ToolInput import kotlinx.serialization.SerialName @@ -41,13 +40,16 @@ data class ToolUse( val toolInput = decodeInput() toolInput.use(toolUseId = id) } else { - ToolResult(toolUseId = id) { + ToolResult { + toolUseId = id error("Cannot use unknown tool: $name") } } } catch (e: Exception) { + // TODO a better way to log this exception e.printStackTrace() - ToolResult(toolUseId = id) { + ToolResult { + toolUseId = id error(e.message ?: "Unknown error occurred") } } @@ -55,9 +57,10 @@ data class ToolUse( } +@ConsistentCopyVisibility @Serializable @SerialName("tool_result") -data class ToolResult( +data class ToolResult private constructor( @SerialName("tool_use_id") val toolUseId: String, val content: List? = null, @@ -69,7 +72,42 @@ data class ToolResult( class Builder : ContentBuilder { - override val content: MutableList = mutableListOf() + private class ToolResultList( + private val list: MutableList = mutableListOf() + ) : MutableList by list { + + override fun add(element: Content): Boolean { + require(element is Image || element is Text) { + "Only Image and Text content element is allowed" + } + return list.add(element) + } + + override fun add(index: Int, element: Content) { + require(element is Image || element is Text) { + "Only Image and Text content element is allowed" + } + return list.add(index, element) + } + + override fun addAll(elements: Collection): Boolean { + require(elements.all { it is Image || it is Text}) { + "Only Image and Text content elements are allowed" + } + return list.addAll(elements) + } + + override fun set(index: Int, element: Content): Content { + require(element is Image || element is Text) { + "Only Image and Text content element is allowed" + } + return list.set(index, element) + } + } + + var toolUseId: String? = null + + override val content: MutableList = ToolResultList() var isError: Boolean? = null var cacheControl: CacheControl? = null @@ -79,24 +117,33 @@ data class ToolResult( isError = true } + operator fun plus(text: Text) { + content += text + } + + operator fun plus(image: Image) { + content += image + } + + fun build(): ToolResult = ToolResult( + toolUseId = requireNotNull(toolUseId) { + "toolUseId cannot be null" + }, + content = buildList { addAll(content) }, + isError = isError, + cacheControl = cacheControl + ) + } } @OptIn(ExperimentalContracts::class) inline fun ToolResult( - toolUseId: String, block: ToolResult.Builder.() -> Unit = {} ): ToolResult { contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } - val builder = ToolResult.Builder() - block(builder) - return ToolResult( - toolUseId = toolUseId, - content = builder.content.toNullIfEmpty(), - isError = if (builder.isError == null) false else null, - cacheControl = builder.cacheControl - ) + return ToolResult.Builder().also(block).build() } diff --git a/src/commonMain/kotlin/message/Messages.kt b/src/commonMain/kotlin/message/Messages.kt index 4fd46ca..93fe7e0 100644 --- a/src/commonMain/kotlin/message/Messages.kt +++ b/src/commonMain/kotlin/message/Messages.kt @@ -5,13 +5,14 @@ import com.xemantic.anthropic.Response import com.xemantic.anthropic.cache.CacheControl import com.xemantic.anthropic.content.Content import com.xemantic.anthropic.content.ContentBuilder +import com.xemantic.anthropic.content.ToolUse +import com.xemantic.anthropic.toPrettyJson import com.xemantic.anthropic.tool.Tool import com.xemantic.anthropic.tool.ToolChoice import com.xemantic.anthropic.tool.ToolInput import com.xemantic.anthropic.tool.toolName import com.xemantic.anthropic.usage.Usage import kotlinx.serialization.* -import kotlinx.serialization.json.JsonClassDiscriminator import kotlin.collections.mutableListOf /** @@ -152,6 +153,8 @@ data class MessageRequest( ) } + override fun toString(): String = toPrettyJson() + } /** @@ -248,4 +251,13 @@ data class MessageResponse( content += this@MessageResponse.content } + suspend fun useTools(): Message { + val toolResults = content.filterIsInstance().map { + it.use() + } + return Message { + this@Message.content += toolResults + } + } + } diff --git a/src/commonMain/kotlin/tool/Tools.kt b/src/commonMain/kotlin/tool/Tools.kt index ab4687d..f7ec54c 100644 --- a/src/commonMain/kotlin/tool/Tools.kt +++ b/src/commonMain/kotlin/tool/Tools.kt @@ -3,10 +3,12 @@ package com.xemantic.anthropic.tool import com.xemantic.ai.tool.schema.JsonSchema import com.xemantic.ai.tool.schema.generator.jsonSchemaOf import com.xemantic.ai.tool.schema.meta.Description +import com.xemantic.anthropic.anthropicJson import com.xemantic.anthropic.cache.CacheControl import com.xemantic.anthropic.content.Content import com.xemantic.anthropic.content.ToolResult import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.InternalSerializationApi import kotlinx.serialization.KSerializer import kotlinx.serialization.MetaSerializable import kotlinx.serialization.SerialName @@ -15,6 +17,7 @@ import kotlinx.serialization.SerializationException import kotlinx.serialization.Transient import kotlinx.serialization.json.JsonClassDiscriminator import kotlinx.serialization.serializer +import kotlinx.serialization.serializerOrNull @Serializable @JsonClassDiscriminator("name") @@ -65,8 +68,10 @@ abstract class BuiltInTool( * with a given tool use ID. The implementation of the [use] method should * contain the logic for executing the tool and returning the [ToolResult]. */ +@Serializable abstract class ToolInput { + @Transient private var block: suspend ToolResult.Builder.() -> Any? = {} fun use(block: suspend ToolResult.Builder.() -> Any?) { @@ -79,15 +84,23 @@ abstract class ToolInput { * @param toolUseId A unique identifier for this particular use of the tool. * @return A [ToolResult] containing the outcome of executing the tool. */ + // TODO this needs a big test coverage suspend fun use(toolUseId: String): ToolResult { - return ToolResult(toolUseId) { + return ToolResult { + this.toolUseId = toolUseId val result = block(this) - if (result != null) { - when (result) { - is Content -> +result - is Unit -> {} // nothing to do - !is Unit -> +result.toString() - else -> throw IllegalStateException("Tool use {} returned not supported: $this") + if ((result != null) && (result !is Unit)) { + if (result is Content) { + +result + } else { + @OptIn(InternalSerializationApi::class) + val serializer = result::class.serializerOrNull() as KSerializer? + val value = if (serializer != null) { + anthropicJson.encodeToString(serializer, result) + } else { + result.toString() + } + +value } } } diff --git a/src/commonMain/kotlin/usage/Usage.kt b/src/commonMain/kotlin/usage/Usage.kt index 6a2d6cc..1f18892 100644 --- a/src/commonMain/kotlin/usage/Usage.kt +++ b/src/commonMain/kotlin/usage/Usage.kt @@ -1,6 +1,10 @@ package com.xemantic.anthropic.usage -import com.xemantic.anthropic.Model +import com.xemantic.ai.money.Money +import com.xemantic.ai.money.ONE +import com.xemantic.ai.money.Ratio +import com.xemantic.ai.money.times +import com.xemantic.ai.money.ZERO import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable @@ -14,37 +18,72 @@ data class Usage( val cacheCreationInputTokens: Int? = null, @SerialName("cache_read_input_tokens") val cacheReadInputTokens: Int? = null, -) - -fun Usage.add(usage: Usage): Usage = Usage( - inputTokens = inputTokens + usage.inputTokens, - outputTokens = outputTokens + usage.outputTokens, - cacheReadInputTokens = (cacheReadInputTokens ?: 0) + (usage.cacheReadInputTokens ?: 0), - cacheCreationInputTokens = (cacheCreationInputTokens ?: 0) + (usage.cacheCreationInputTokens ?: 0), -) - -fun Usage.cost( - model: Model, - isBatch: Boolean = false -): Cost = Cost( - inputTokens = inputTokens * model.cost.inputTokens / 1000000.0 * (if (isBatch) .5 else 1.0), - outputTokens = outputTokens * model.cost.outputTokens / 1000000.0 * (if (isBatch) .5 else 1.0), - cacheReadInputTokens = (cacheReadInputTokens ?: 0) * model.cost.inputTokens * .1 / 1000000.0 * (if (isBatch) .5 else 1.0), - cacheCreationInputTokens = (cacheCreationInputTokens ?: 0) * model.cost.inputTokens * .25 / 1000000.0 * (if (isBatch) .5 else 1.0) -) +) { + + companion object { + + val ZERO = Usage( + inputTokens = 0, + outputTokens = 0, + cacheCreationInputTokens = 0, + cacheReadInputTokens = 0 + ) + + } + + operator fun plus(usage: Usage): Usage = Usage( + inputTokens = inputTokens + usage.inputTokens, + outputTokens = outputTokens + usage.outputTokens, + cacheCreationInputTokens = (cacheCreationInputTokens ?: 0) + (usage.cacheCreationInputTokens ?: 0), + cacheReadInputTokens = (cacheReadInputTokens ?: 0) + (usage.cacheReadInputTokens ?: 0) + ) + + fun cost( + modelCost: Cost, + costRatio: Money.Ratio = Money.Ratio.ONE + ): Cost = Cost( + inputTokens = inputTokens * modelCost.inputTokens * costRatio, + outputTokens = outputTokens * modelCost.outputTokens * costRatio, + // how cacheCreation and batch are playing together? + cacheCreationInputTokens = (cacheCreationInputTokens ?: 0) * modelCost.cacheCreationInputTokens * costRatio, + cacheReadInputTokens = (cacheReadInputTokens ?: 0) * modelCost.cacheReadInputTokens * costRatio + ) +} + +@Serializable data class Cost( - val inputTokens: Double, - val outputTokens: Double, - val cacheCreationInputTokens: Double, - val cacheReadInputTokens: Double + val inputTokens: Money, + val outputTokens: Money, + val cacheCreationInputTokens: Money = inputTokens * Money.Ratio("1.25"), + val cacheReadInputTokens: Money = inputTokens * Money.Ratio("0.1"), ) { - fun add(cost: Cost): Cost = Cost( + operator fun plus(cost: Cost): Cost = Cost( inputTokens = inputTokens + cost.inputTokens, outputTokens = outputTokens + cost.outputTokens, cacheCreationInputTokens = cacheCreationInputTokens + cost.cacheCreationInputTokens, cacheReadInputTokens = cacheReadInputTokens + cost.cacheReadInputTokens ) + operator fun times(amount: Money): Cost = Cost( + inputTokens = inputTokens * amount, + outputTokens = outputTokens * amount, + cacheCreationInputTokens = cacheCreationInputTokens * amount, + cacheReadInputTokens = cacheReadInputTokens * amount + ) + + val total: Money get() = + inputTokens + + outputTokens + + cacheCreationInputTokens + + cacheReadInputTokens + + companion object { + val ZERO = Cost( + inputTokens = Money.ZERO, + outputTokens = Money.ZERO + ) + } + } diff --git a/src/commonMain/kotlin/usage/UsageCollector.kt b/src/commonMain/kotlin/usage/UsageCollector.kt new file mode 100644 index 0000000..805ac54 --- /dev/null +++ b/src/commonMain/kotlin/usage/UsageCollector.kt @@ -0,0 +1,54 @@ +package com.xemantic.anthropic.usage + +import com.xemantic.ai.money.Money +import com.xemantic.ai.money.ONE +import kotlinx.atomicfu.atomic +import kotlinx.atomicfu.update + +/** + * Collects overall [Usage] and calculates [Cost] information + * based on [com.xemantic.anthropic.message.MessageResponse]s returned + * by API calls. + */ +class UsageCollector { + + // Atomic in case of several threads updating this data concurrently + private val _usage = atomic(Usage.ZERO) + + /** + * The current accumulated usage. + */ + val usage: Usage get() = _usage.value + + // Atomic in case of several threads updating this data concurrently + private val _cost = atomic(Cost.ZERO) + + /** + * The current accumulated cost. + */ + val cost: Cost get() = _cost.value + + /** + * Updates the usage and cost based on the provided parameters. + * + * @param usage The usage to add. + * @param modelCost The cost of the used model. + * @param costRatio The cost ratio to apply, defaults to 1, but might be different for batch requests, etc. + */ + fun update( + usage: Usage, + modelCost: Cost, + costRatio: Money.Ratio = Money.Ratio.ONE, + ) { + _usage.update { it + usage } + _cost.update { it + usage.cost(modelCost, costRatio) } + } + + /** + * Returns a string representation of the UsageCollector. + * + * @return A string containing the current usage and cost. + */ + override fun toString(): String = "UsageCollector(usage=$usage, cost=$cost)" + +} diff --git a/src/commonTest/kotlin/AnthropicTest.kt b/src/commonTest/kotlin/AnthropicTest.kt index 94a3bd8..392df28 100644 --- a/src/commonTest/kotlin/AnthropicTest.kt +++ b/src/commonTest/kotlin/AnthropicTest.kt @@ -1,5 +1,7 @@ package com.xemantic.anthropic +import com.xemantic.ai.money.Money +import com.xemantic.ai.money.ZERO import com.xemantic.anthropic.event.Delta.TextDelta import com.xemantic.anthropic.event.Event import com.xemantic.anthropic.message.Message @@ -12,29 +14,110 @@ import com.xemantic.anthropic.tool.FibonacciTool import com.xemantic.anthropic.tool.TestDatabase import com.xemantic.anthropic.content.Text import com.xemantic.anthropic.content.ToolUse -import com.xemantic.anthropic.test.assert -import io.kotest.matchers.ints.shouldBeGreaterThan -import io.kotest.matchers.shouldBe -import io.kotest.matchers.shouldNotBe -import io.kotest.matchers.string.shouldContain -import io.kotest.matchers.string.shouldStartWith -import io.kotest.matchers.types.instanceOf +import com.xemantic.anthropic.usage.Cost +import com.xemantic.anthropic.usage.Usage +import com.xemantic.kotlin.test.be +import com.xemantic.kotlin.test.have +import com.xemantic.kotlin.test.should import kotlinx.coroutines.flow.filterIsInstance import kotlinx.coroutines.flow.map import kotlinx.coroutines.flow.toList import kotlinx.coroutines.test.runTest -import kotlin.test.Ignore +import kotlin.collections.any import kotlin.test.Test class AnthropicTest { + @Test + fun `Should create Anthropic instance with 0 Usage and Cost`() { + Anthropic() should { + have(usage == Usage.ZERO) + have(cost == Cost.ZERO) + } + } + @Test fun `Should receive an introduction from Claude`() = runTest { // given - val client = Anthropic() + val anthropic = Anthropic() // when - val response = client.messages.create { + val response = anthropic.messages.create { + +Message { + +"Hello World! What's your name?" + } + } + + // then + response should { + have(role == Role.ASSISTANT) + have("claude" in model) + have(stopReason == StopReason.END_TURN) + have(content.size == 1) + content[0] should { + be() + have("Claude" in text) + } + have(stopSequence == null) + usage should { + have(inputTokens == 15) + have(outputTokens > 0) + } + } + } + + @Test + fun `Should receive Usage and update Cost calculation`() = runTest { + // given + val anthropic = Anthropic() + + // when + val response = anthropic.messages.create { + +Message { + +"Hello Claude! I am testing the amount of input and output tokens." + } + } + + // then + response should { + have(role == Role.ASSISTANT) + have("claude" in model) + have(stopReason == StopReason.END_TURN) + have(content.size == 1) + have(stopSequence == null) + usage should { + have(inputTokens == 21) + have(outputTokens > 0) + have(cacheCreationInputTokens == null) + have(cacheReadInputTokens == null) + } + } + + anthropic should { + usage should { + have(inputTokens == 21) + have(inputTokens > 0) + have(cacheCreationInputTokens == 0) + have(cacheReadInputTokens == 0) + } + cost should { + have(inputTokens >= Money.ZERO && inputTokens == Money("0.000063")) + have(outputTokens >= Money.ZERO && inputTokens <= Money("0.0005")) + have(cacheCreationInputTokens == Money.ZERO) + have(cacheReadInputTokens == Money.ZERO) + } + } + + } + + @Test + fun `Should use system prompt`() = runTest { + // given + val anthropic = Anthropic() + + // when + val response = anthropic.messages.create { + system("Whatever the human says, answer \"HAHAHA\"") +Message { +"Hello World! What's your name?" } @@ -42,17 +125,12 @@ class AnthropicTest { } // then - response.assert { - role shouldBe Role.ASSISTANT - model shouldBe "claude-3-5-sonnet-20241022" - stopReason shouldBe StopReason.END_TURN - content.size shouldBe 1 - content[0] shouldBe instanceOf() - val text = content[0] as Text - text.text shouldContain "Claude" - stopSequence shouldBe null - usage.inputTokens shouldBe 15 - usage.outputTokens shouldBeGreaterThan 0 + response should { + have(content.size == 1) + content[0] should { + be() + have(text == "HAHAHA") + } } } @@ -71,7 +149,7 @@ class AnthropicTest { .joinToString(separator = "") // then - response shouldBe "The sun slowly dipped below the horizon, painting the sky in a breathtaking array of oranges, pinks, and purples." + assert(response == "The sun slowly dipped below the horizon, painting the sky in a breathtaking array of oranges, pinks, and purples.") } @Test @@ -91,17 +169,16 @@ class AnthropicTest { conversation += initialResponse // then - initialResponse.assert { - stopReason shouldBe StopReason.TOOL_USE - content.size shouldBe 1 // and therefore there is only ToolUse without commentary - content[0] shouldBe instanceOf() - (content[0] as ToolUse).name shouldBe "Calculator" + initialResponse should { + have(stopReason == StopReason.TOOL_USE) + have(content.size == 1) // and therefore there is only ToolUse without commentary + content[0] should { + be() + have(name == "Calculator") + } } - val toolUse = initialResponse.content[0] as ToolUse - val result = toolUse.use() // here we execute the tool - - conversation += Message { +result } + conversation += initialResponse.useTools() // when val resultResponse = client.messages.create { @@ -110,11 +187,13 @@ class AnthropicTest { } // then - resultResponse.assert { - stopReason shouldBe StopReason.END_TURN - content.size shouldBe 1 - content[0] shouldBe instanceOf() - (content[0] as Text).text shouldContain "105" + resultResponse should { + have(stopReason == StopReason.END_TURN) + have(content.size == 1) + content[0] should { + be() + have("105" in text) + } } } @@ -133,59 +212,71 @@ class AnthropicTest { // then val toolUse = response.content.filterIsInstance().first() - toolUse.name shouldBe "FibonacciTool" + toolUse should { + have(name == "FibonacciTool") + } val result = toolUse.use() - result.assert { - toolUseId shouldBe toolUse.id - isError shouldBe false - content shouldBe listOf(Text(text = "267914296")) + result should { + have(toolUseId == toolUse.id) + have(content == listOf(Text(text = "267914296"))) } } @Test - @Ignore // this test is flaky because it has wrong sometimes claude will decide to use both tools at once. fun `Should use 2 tools in sequence`() = runTest { // given val client = Anthropic { tool() tool() } - - // when + val systemPrompt = "Always use tools to perform calculations. Never calculate on your own, even if you know the answer." + val prompt = "Calculate Fibonacci number 42 and then divide it by 42" val conversation = mutableListOf() - conversation += Message { +"Calculate Fibonacci number 42 and then divide it by 42" } + conversation += Message { +prompt } + // when val fibonacciResponse = client.messages.create { + system(systemPrompt) messages = conversation singleTool() } conversation += fibonacciResponse - val fibonacciToolUse = fibonacciResponse.content.filterIsInstance().first() - fibonacciToolUse.name shouldBe "FibonacciTool" - val fibonacciResult = fibonacciToolUse.use() - conversation += Message { +fibonacciResult } + // then + fibonacciResponse should { + have(stopReason == StopReason.TOOL_USE) + have(content.any { it is ToolUse && it.name == "FibonacciTool" }) + } + conversation += fibonacciResponse.useTools() + // when val calculatorResponse = client.messages.create { messages = conversation singleTool() } conversation += calculatorResponse - val calculatorToolUse = calculatorResponse.content.filterIsInstance().first() - calculatorToolUse.name shouldBe "Calculator" - val calculatorResult = calculatorToolUse.use() - conversation += Message { +calculatorResult } + // then + calculatorResponse should { + have(stopReason == StopReason.TOOL_USE) + have(content.any { it is ToolUse && it.name == "Calculator" }) + } + conversation += calculatorResponse.useTools() + // when val finalResponse = client.messages.create { messages = conversation allTools() } - - finalResponse.content[0] shouldBe instanceOf() - // the result might be in the format: 6,378,911.8.... - (finalResponse.content[0] as Text).text.replace(",", "") shouldContain "6378911.8" + finalResponse should { + have(content.isNotEmpty()) + content[0] should { + be() + // the result might be in the format: 6,378,911.8.... + have(text.replace(",", "").contains("6378911.8")) + } + } } @Test @@ -204,36 +295,23 @@ class AnthropicTest { singleTool() // we are forcing the use of this tool // could be also just tool() if we are confident that LLM will use this one } - val toolUse = response.content.filterIsInstance().first() - toolUse.use() // then - testDatabase.executedQuery shouldNotBe null - testDatabase.executedQuery!!.uppercase() shouldStartWith "SELECT * FROM CUSTOMER" - // depending on the response the statement might end up with semicolon, which we discard - } - - @Test - fun `Should use system prompt`() = runTest { - // given - val anthropic = Anthropic() + response should { + have(stopReason == StopReason.TOOL_USE) + have(content.any { it is ToolUse && it.name == "DatabaseQuery" }) + } // when - val response = anthropic.messages.create { - system("Whatever the human says, answer \"HAHAHA\"") - +Message { - +"Hello World! What's your name?" - } - maxTokens = 1024 - } + response.useTools() // then - response.assert { - content.size shouldBe 1 - content[0] shouldBe instanceOf() - val text = content[0] as Text - text.text shouldBe "HAHAHA" + testDatabase should { + have(executedQuery != null) + have(executedQuery!!.uppercase().startsWith("SELECT * FROM CUSTOMER")) } + + // depending on the response the statement might end up with semicolon, which we discard } } diff --git a/src/commonTest/kotlin/content/DocumentTest.kt b/src/commonTest/kotlin/content/DocumentTest.kt index f60944f..b99a14d 100644 --- a/src/commonTest/kotlin/content/DocumentTest.kt +++ b/src/commonTest/kotlin/content/DocumentTest.kt @@ -3,10 +3,9 @@ package com.xemantic.anthropic.content import com.xemantic.anthropic.Anthropic import com.xemantic.anthropic.message.Message import com.xemantic.anthropic.message.StopReason -import com.xemantic.anthropic.test.assert -import io.kotest.matchers.shouldBe -import io.kotest.matchers.string.shouldContain -import io.kotest.matchers.types.instanceOf +import com.xemantic.kotlin.test.be +import com.xemantic.kotlin.test.have +import com.xemantic.kotlin.test.should import kotlinx.coroutines.test.runTest import kotlin.test.Test @@ -24,7 +23,7 @@ const val testPdf = "JVBERi0xLjEKJcKlwrHDqwoKMSAwIG9iagogIDw8IC9UeXBlIC9DYXRhbG9 class DocumentTest { @Test - fun shouldReadTextFooFromTestImage() = runTest { + fun `Should read text FOO from test PDF`() = runTest { // given val client = Anthropic { anthropicBeta = "pdfs-2024-09-25" @@ -44,12 +43,13 @@ class DocumentTest { } // then - response.assert { - stopReason shouldBe StopReason.END_TURN - content.size shouldBe 1 - content[0] shouldBe instanceOf() - val text = content[0] as Text - text.text.uppercase() shouldContain "HELLO WORLD" + response should { + have(stopReason == StopReason.END_TURN) + have(content.size == 1) + content[0] should { + be() + assert("HELLO WORLD" in text.uppercase()) + } } } diff --git a/src/commonTest/kotlin/content/ImageTest.kt b/src/commonTest/kotlin/content/ImageTest.kt index 3673579..a4a60d2 100644 --- a/src/commonTest/kotlin/content/ImageTest.kt +++ b/src/commonTest/kotlin/content/ImageTest.kt @@ -3,10 +3,9 @@ package com.xemantic.anthropic.content import com.xemantic.anthropic.Anthropic import com.xemantic.anthropic.message.Message import com.xemantic.anthropic.message.StopReason -import com.xemantic.anthropic.test.assert -import io.kotest.matchers.shouldBe -import io.kotest.matchers.string.shouldContain -import io.kotest.matchers.types.instanceOf +import com.xemantic.kotlin.test.be +import com.xemantic.kotlin.test.have +import com.xemantic.kotlin.test.should import kotlinx.coroutines.test.runTest import kotlin.test.Test @@ -27,7 +26,7 @@ const val testImage = "iVBORw0KGgoAAAANSUhEUgAAAEAAAABACAYAAACqaXHeAAAACXBIWXMAA class ImageTest { @Test - fun shouldReadTextFooFromTestImage() = runTest { + fun `Should read text FOO from test image`() = runTest { // given val client = Anthropic() @@ -45,12 +44,13 @@ class ImageTest { } // then - response.assert { - stopReason shouldBe StopReason.END_TURN - content.size shouldBe 1 - content[0] shouldBe instanceOf() - val text = content[0] as Text - text.text.uppercase() shouldContain "FOO" + response should { + have(stopReason == StopReason.END_TURN) + have(content.size == 1) + content[0] should { + be() + have("FOO" in text.uppercase()) + } } } diff --git a/src/commonTest/kotlin/content/ToolResultTest.kt b/src/commonTest/kotlin/content/ToolResultTest.kt new file mode 100644 index 0000000..9dbbcb4 --- /dev/null +++ b/src/commonTest/kotlin/content/ToolResultTest.kt @@ -0,0 +1,64 @@ +package com.xemantic.anthropic.content + +import com.xemantic.kotlin.test.be +import com.xemantic.kotlin.test.have +import com.xemantic.kotlin.test.should +import org.junit.Test + +class ToolResultTest { + + @Test + fun `Should create ToolResult for a single String representing Text content`() { + ToolResult { + toolUseId = "42" + +"foo" + } should { + be() + have(toolUseId == "42") + have(content!!.size == 1) + content[0] should { + be() + have(text == "foo") + } + have(isError == null) + have(cacheControl == null) + } + } + + @Test + fun `Should create ToolResult for Text element representing content`() { + ToolResult { + toolUseId = "42" + +Text(text = "foo") + } should { + be() + have(toolUseId == "42") + have(content!!.size == 1) + content[0] should { + be() + have(text == "foo") + } + have(isError == null) + have(cacheControl == null) + } + } + + @Test + fun `Should create error ToolResult`() { + ToolResult { + toolUseId = "42" + error("Error message") + } should { + be() + have(toolUseId == "42") + have(content!!.size == 1) + content[0] should { + be() + have(text == "Error message") + } + have(isError == true) + have(cacheControl == null) + } + } + +} \ No newline at end of file diff --git a/src/commonTest/kotlin/error/ErrorResponseTest.kt b/src/commonTest/kotlin/error/ErrorResponseTest.kt index 199c689..ded653a 100644 --- a/src/commonTest/kotlin/error/ErrorResponseTest.kt +++ b/src/commonTest/kotlin/error/ErrorResponseTest.kt @@ -1,10 +1,10 @@ package com.xemantic.anthropic.error import com.xemantic.anthropic.Response -import com.xemantic.anthropic.test.assert import com.xemantic.anthropic.test.testJson -import io.kotest.matchers.shouldBe -import io.kotest.matchers.types.instanceOf +import com.xemantic.kotlin.test.be +import com.xemantic.kotlin.test.have +import com.xemantic.kotlin.test.should import kotlin.test.Test /** @@ -13,9 +13,8 @@ import kotlin.test.Test class ErrorResponseTest { @Test - fun shouldDeserializeToolUseMessageResponse() { - // given - val jsonResponse = """ + fun `Should deserialize ErrorResponse`() { + testJson.decodeFromString(/* language=json */ """ { "type": "error", "error": { @@ -23,16 +22,13 @@ class ErrorResponseTest { "message": "The requested resource could not be found." } } - """.trimIndent() - - val response = testJson.decodeFromString(jsonResponse) - response shouldBe instanceOf() - (response as ErrorResponse).assert { - error shouldBe Error( + """) should { + be() + have(error == Error( type = "not_found_error", message = "The requested resource could not be found." - ) + )) } } -} \ No newline at end of file +} diff --git a/src/commonTest/kotlin/message/MessageRequestTest.kt b/src/commonTest/kotlin/message/MessageRequestTest.kt index e6db1b1..333ed5a 100644 --- a/src/commonTest/kotlin/message/MessageRequestTest.kt +++ b/src/commonTest/kotlin/message/MessageRequestTest.kt @@ -10,50 +10,18 @@ import com.xemantic.anthropic.tool.bash.Bash import com.xemantic.anthropic.tool.computer.Computer import com.xemantic.anthropic.tool.editor.TextEditor import io.kotest.assertions.json.shouldEqualJson -import io.kotest.matchers.shouldBe import kotlinx.serialization.SerialName import kotlinx.serialization.encodeToString +import kotlin.test.Ignore import kotlin.test.Test -// given -@AnthropicTool("get_weather") -@Description("Get the current weather in a given location") -data class GetWeather( - @Description("The city and state, e.g. San Francisco, CA") - val location: String, - val unit: TemperatureUnit? = null -) : ToolInput() { - init { - use { - "42" - } - } -} - -@Description("The unit of temperature, either 'celsius' or 'fahrenheit'") -@Suppress("unused") // it is used by the serializer -enum class TemperatureUnit { - @SerialName("celsius") - CELSIUS, - @SerialName("fahrenheit") - FAHRENHEIT -} - /** * Tests the JSON serialization format of created Anthropic API message requests. */ class MessageRequestTest { @Test - fun defaultMessageShouldHaveRoleUser() { - // given - val message = Message {} - // then - message.role shouldBe Role.USER - } - - @Test - fun shouldCreateTheSimplestMessageRequest() { + fun `Should create simple MessageRequest`() { // given val request = MessageRequest { +Message { @@ -65,7 +33,7 @@ class MessageRequestTest { val json = testJson.encodeToString(request) // then - json shouldEqualJson """ + json shouldEqualJson /* language=json */ """ { "model": "claude-3-5-sonnet-latest", "messages": [ @@ -81,17 +49,44 @@ class MessageRequestTest { ], "max_tokens": 8182 } - """.trimIndent() + """ + // Note: max_tokens value will default to the max for a given model + // claude-3-5-sonnet-latest ist the default model + } + + // now we need some test tool + @AnthropicTool("get_weather") + @Description("Get the weather for a specific location") + data class GetWeather( + @Description("The city and state, e.g. San Francisco, CA") + val location: String, + val unit: TemperatureUnit? = null + ) : ToolInput() { + init { + use { + "42" + } + } + } + + @Description("The unit of temperature, either 'celsius' or 'fahrenheit'") + @Suppress("unused") // it is used by the serializer + enum class TemperatureUnit { + @SerialName("celsius") + CELSIUS, + @SerialName("fahrenheit") + FAHRENHEIT } @Test - fun shouldCreateMessageRequestWithMultipleTools() { + fun `Should create MessageRequest with multiple tools`() { // given val request = MessageRequest { +Message { +"Hey Claude!?" } tools = listOf( + // built in tools Computer( displayWidthPx = 1024, displayHeightPx = 768, @@ -99,6 +94,7 @@ class MessageRequestTest { ), TextEditor(), Bash(), + // custom tool Tool() ) } @@ -106,7 +102,8 @@ class MessageRequestTest { // when val json = testJson.encodeToString(request) - json shouldEqualJson """ + // then + json shouldEqualJson /* language=json */ """ { "model": "claude-3-5-sonnet-latest", "messages": [ @@ -139,7 +136,7 @@ class MessageRequestTest { }, { "name": "get_weather", - "description": "Get the current weather in a given location", + "description": "Get the weather for a specific location", "input_schema": { "type": "object", "properties": { @@ -158,12 +155,11 @@ class MessageRequestTest { } ] } - """.trimIndent() - // then + """ } @Test - fun shouldCreateMessageRequestWithExplicitToolChoice() { + fun `Should create MessageRequest with explicit ToolChoice`() { // given val request = MessageRequest { +Message { @@ -182,7 +178,7 @@ class MessageRequestTest { val json = testJson.encodeToString(request) // then - json shouldEqualJson """ + json shouldEqualJson /* language=json */ """ { "model": "claude-3-5-sonnet-latest", "messages": [ @@ -205,7 +201,7 @@ class MessageRequestTest { "tools": [ { "name": "get_weather", - "description": "Get the current weather in a given location", + "description": "Get the weather for a specific location", "input_schema": { "type": "object", "properties": { @@ -224,13 +220,13 @@ class MessageRequestTest { } ] } - """.trimIndent() + """ } @Test - fun shouldDeserializeMessageRequestForExampleStoredOnDisk() { + fun `Should deserialize MessageRequest - for example a JSON stored on disk`() { // given - val request = """ + val request = /* language=json */ """ { "model": "claude-3-5-sonnet-latest", "messages": [ @@ -282,14 +278,19 @@ class MessageRequestTest { } ] } - """.trimIndent() + """ // when val messageRequest = testJson.decodeFromString(request) // then - // TODO assertions - println(messageRequest) + messageRequest.toString() shouldEqualJson request + } + + @Test + @Ignore // TODO this test can be fixed only when the model is refactored to be configurable + fun `Should fail to create a MessageRequest instance for unknown model`() { + //val messageRequest = MessageRequest { } } } diff --git a/src/commonTest/kotlin/message/MessageResponseTest.kt b/src/commonTest/kotlin/message/MessageResponseTest.kt index 6aa51f5..d3b87fd 100644 --- a/src/commonTest/kotlin/message/MessageResponseTest.kt +++ b/src/commonTest/kotlin/message/MessageResponseTest.kt @@ -2,11 +2,13 @@ package com.xemantic.anthropic.message import com.xemantic.anthropic.Response import com.xemantic.anthropic.content.ToolUse -import com.xemantic.anthropic.test.assert import com.xemantic.anthropic.test.testJson import com.xemantic.anthropic.usage.Usage -import io.kotest.matchers.shouldBe -import io.kotest.matchers.types.instanceOf +import com.xemantic.kotlin.test.be +import com.xemantic.kotlin.test.have +import com.xemantic.kotlin.test.should +import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.buildJsonObject import kotlin.test.Test /** @@ -15,8 +17,9 @@ import kotlin.test.Test class MessageResponseTest { @Test - fun shouldDeserializeToolUseMessageResponse() { + fun `Should deserialize ToolUse message response`() { // given + /* language=json */ val jsonResponse = """ { "id": "msg_01PspkNzNG3nrf5upeTsmWLF", @@ -42,28 +45,34 @@ class MessageResponseTest { "output_tokens": 86 } } - """.trimIndent() + """ + // when val response = testJson.decodeFromString(jsonResponse) - response shouldBe instanceOf() - (response as MessageResponse).assert { - id shouldBe "msg_01PspkNzNG3nrf5upeTsmWLF" - role shouldBe Role.ASSISTANT - model shouldBe "claude-3-5-sonnet-20241022" - content.size shouldBe 1 - content[0] shouldBe instanceOf() - stopReason shouldBe StopReason.TOOL_USE - stopSequence shouldBe null - usage shouldBe Usage( + + // then + response should { + be() + have(id == "msg_01PspkNzNG3nrf5upeTsmWLF") + have(role == Role.ASSISTANT) + have(model == "claude-3-5-sonnet-20241022") + have(content.size == 1) + have(stopReason == StopReason.TOOL_USE) + have(stopSequence == null) + have(usage == Usage( inputTokens = 419, outputTokens = 86 - ) - } - val toolUse = response.content[0] as ToolUse - toolUse.assert { - id shouldBe "toolu_01YHJK38TBKCRPn7zfjxcKHx" - name shouldBe "Calculator" - // TODO generate JsonObject to assert input + )) + content[0] should { + be() + have(id == "toolu_01YHJK38TBKCRPn7zfjxcKHx") + have(name == "Calculator") + have(input == buildJsonObject { + put("operation", JsonPrimitive("MULTIPLY")) + put("a", JsonPrimitive(15)) + put("b", JsonPrimitive(7)) + }) + } } } diff --git a/src/commonTest/kotlin/message/MessageTest.kt b/src/commonTest/kotlin/message/MessageTest.kt new file mode 100644 index 0000000..46fb761 --- /dev/null +++ b/src/commonTest/kotlin/message/MessageTest.kt @@ -0,0 +1,16 @@ +package com.xemantic.anthropic.message + +import com.xemantic.kotlin.test.have +import com.xemantic.kotlin.test.should +import org.junit.Test + +class MessageTest { + + @Test + fun `Default Message should have Role USER`() { + Message {} should { + have(role == Role.USER) + } + } + +} diff --git a/src/commonTest/kotlin/message/ToolResultTest.kt b/src/commonTest/kotlin/message/ToolResultTest.kt deleted file mode 100644 index cb24162..0000000 --- a/src/commonTest/kotlin/message/ToolResultTest.kt +++ /dev/null @@ -1,20 +0,0 @@ -package com.xemantic.anthropic.message - -import com.xemantic.anthropic.content.Text -import com.xemantic.anthropic.content.ToolResult -import io.kotest.matchers.shouldBe -import kotlin.test.Test - -class ToolResultTest { - - @Test - fun shouldCreateToolResultForSingleContentString() { - ToolResult(toolUseId = "42") { - +"foo" - } shouldBe ToolResult( - toolUseId = "42", - content = listOf(Text(text = "foo")) - ) - } - -} diff --git a/src/commonTest/kotlin/test/AnthropicTestSupport.kt b/src/commonTest/kotlin/test/AnthropicTestSupport.kt index e1eb244..0b693a8 100644 --- a/src/commonTest/kotlin/test/AnthropicTestSupport.kt +++ b/src/commonTest/kotlin/test/AnthropicTestSupport.kt @@ -3,9 +3,6 @@ package com.xemantic.anthropic.test import com.xemantic.anthropic.anthropicJson import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.json.Json -import kotlin.contracts.ExperimentalContracts -import kotlin.contracts.InvocationKind -import kotlin.contracts.contract /** * A pretty JSON printing for testing. It's derived from [anthropicJson], @@ -17,17 +14,3 @@ val testJson = Json(from = anthropicJson) { @OptIn(ExperimentalSerializationApi::class) prettyPrintIndent = " " } - -/** - * Asserts certain conditions on an object of type [T]. - * - * @param T The type of the object to assert against. - * @param block A lambda with receiver that defines the assertions to be performed on the object. - */ -@OptIn(ExperimentalContracts::class) -inline fun T.assert(block: T.() -> Unit) { - contract { - callsInPlace(block, InvocationKind.EXACTLY_ONCE) - } - block(this) -} diff --git a/src/commonTest/kotlin/tool/ToolChoiceTest.kt b/src/commonTest/kotlin/tool/ToolChoiceTest.kt index 9817e37..f66dea7 100644 --- a/src/commonTest/kotlin/tool/ToolChoiceTest.kt +++ b/src/commonTest/kotlin/tool/ToolChoiceTest.kt @@ -16,11 +16,11 @@ class ToolChoiceTest { @Test fun shouldSerializeToolChoiceAuto() { - ToolChoice.Auto().json shouldEqualJson """ + ToolChoice.Auto().json shouldEqualJson /* language=json */ """ { "type": "auto" } - """.trimIndent() + """ ToolChoice.Auto( disableParallelToolUse = true @@ -29,27 +29,27 @@ class ToolChoiceTest { "type": "auto", "disable_parallel_tool_use": true } - """.trimIndent() + """ } @Test fun shouldSerializeToolChoiceAny() { - ToolChoice.Any().json shouldEqualJson """ + ToolChoice.Any().json shouldEqualJson /* language=json */ """ { "type": "any" } - """.trimIndent() + """ ToolChoice.Any( disableParallelToolUse = true - ).json shouldEqualJson """ + ).json shouldEqualJson /* language=json */ """ { "type": "any", "disable_parallel_tool_use": true } - """.trimIndent() + """ } @@ -58,23 +58,23 @@ class ToolChoiceTest { ToolChoice.Tool( name = "foo" - ).json shouldEqualJson """ + ).json shouldEqualJson /* language=json */ """ { "type": "tool", "name": "foo" } - """.trimIndent() + """ ToolChoice.Tool( name = "foo", disableParallelToolUse = true - ).json shouldEqualJson """ + ).json shouldEqualJson /* language=json */ """ { "type": "tool", "name": "foo", "disable_parallel_tool_use": true } - """.trimIndent() + """ } diff --git a/src/commonTest/kotlin/tool/ToolInputTest.kt b/src/commonTest/kotlin/tool/ToolInputTest.kt index dc35828..64c7692 100644 --- a/src/commonTest/kotlin/tool/ToolInputTest.kt +++ b/src/commonTest/kotlin/tool/ToolInputTest.kt @@ -2,14 +2,13 @@ package com.xemantic.anthropic.tool import com.xemantic.ai.tool.schema.meta.Description import com.xemantic.anthropic.cache.CacheControl -import com.xemantic.anthropic.test.assert +import com.xemantic.kotlin.test.have +import com.xemantic.kotlin.test.should import io.kotest.assertions.json.shouldEqualJson -import io.kotest.assertions.throwables.shouldThrow -import io.kotest.matchers.shouldBe -import io.kotest.matchers.string.shouldMatch import kotlinx.serialization.Serializable import kotlinx.serialization.SerializationException import kotlin.test.Test +import kotlin.test.assertFailsWith class ToolInputTest { @@ -34,10 +33,11 @@ class ToolInputTest { // when val tool = Tool() - tool.assert { - name shouldBe "TestTool" - description shouldBe "A test tool receiving a message and outputting it back" - inputSchema.toString() shouldEqualJson """ + tool should { + have(name == "TestTool") + have(description == "A test tool receiving a message and outputting it back") + have(cacheControl == null) + inputSchema.toString() shouldEqualJson /* language=json */ """ { "type": "object", "properties": { @@ -51,7 +51,6 @@ class ToolInputTest { ] } """ - cacheControl shouldBe null } } @@ -63,10 +62,11 @@ class ToolInputTest { cacheControl = CacheControl(type = CacheControl.Type.EPHEMERAL) ) - tool.assert { - name shouldBe "TestTool" - description shouldBe "A test tool receiving a message and outputting it back" - inputSchema.toString() shouldEqualJson """ + tool should { + have(name == "TestTool") + have(description == "A test tool receiving a message and outputting it back") + have(cacheControl == CacheControl(type = CacheControl.Type.EPHEMERAL)) + inputSchema.toString() shouldEqualJson /* language=json */ """ { "type": "object", "properties": { @@ -80,7 +80,6 @@ class ToolInputTest { ] } """ - cacheControl shouldBe CacheControl(type = CacheControl.Type.EPHEMERAL) } } @@ -88,10 +87,14 @@ class ToolInputTest { @Test fun `Should fail to create a Tool without AnthropicTool annotation`() { - shouldThrow { + assertFailsWith { Tool() - }.message shouldMatch "Cannot find serializer for class .*NoAnnotationTool, " + - "make sure that it is annotated with @AnthropicTool and kotlin.serialization plugin is enabled for the project" + } should { + have(message!!.matches(Regex( + "Cannot find serializer for class .*NoAnnotationTool, " + + "make sure that it is annotated with @AnthropicTool and kotlin.serialization plugin is enabled for the project" + ))) + } } @Serializable @@ -99,9 +102,13 @@ class ToolInputTest { @Test fun `Should fail to create a Tool with only Serializable annotation`() { - shouldThrow { + assertFailsWith { Tool() - }.message shouldMatch "The class .*OnlySerializableAnnotationTool must be annotated with @AnthropicTool" + } should { + have(message!!.matches(Regex( + "The class .*OnlySerializableAnnotationTool must be annotated with @AnthropicTool" + ))) + } } } diff --git a/src/commonTest/kotlin/usage/CostTest.kt b/src/commonTest/kotlin/usage/CostTest.kt new file mode 100644 index 0000000..253cea4 --- /dev/null +++ b/src/commonTest/kotlin/usage/CostTest.kt @@ -0,0 +1,120 @@ +package com.xemantic.anthropic.usage + +import com.xemantic.ai.money.Money +import com.xemantic.ai.money.ZERO +import com.xemantic.kotlin.test.have +import com.xemantic.kotlin.test.should +import kotlin.test.Test + +class CostTest { + + @Test + fun `Should create a Cost instance with correct values`() { + Cost( + inputTokens = Money("0.001"), + outputTokens = Money("0.002"), + cacheCreationInputTokens = Money("0.00025"), + cacheReadInputTokens = Money("0.0005") + ) should { + have(inputTokens == Money("0.001")) + have(outputTokens == Money("0.002")) + have(cacheCreationInputTokens == Money("0.00025")) + have(cacheReadInputTokens == Money("0.0005")) + } + } + + /** + * This case is used when costs per model are being defined. + */ + @Test + fun `Should create a Cost instance with correct values when cache costs are not specified`() { + Cost( + inputTokens = Money("0.001"), + outputTokens = Money("0.002") + ) should { + have(inputTokens == Money("0.001")) + have(outputTokens == Money("0.002")) + have(cacheCreationInputTokens == Money("0.00125")) + have(cacheReadInputTokens == Money("0.0001")) + } + } + + @Test + fun `Should add two Cost instances without cache`() { + // given + val cost1 = Cost( + inputTokens = Money("0.001"), + outputTokens = Money("0.002"), + cacheCreationInputTokens = Money.ZERO, + cacheReadInputTokens = Money.ZERO + ) + val cost2 = Cost( + inputTokens = Money("0.003"), + outputTokens = Money("0.004"), + cacheCreationInputTokens = Money.ZERO, + cacheReadInputTokens = Money.ZERO + ) + + // when + val result = cost1 + cost2 + + // then + result should { + have(inputTokens == Money("0.004")) + have(outputTokens == Money("0.006")) + have(cacheCreationInputTokens == Money.ZERO) + have(cacheReadInputTokens == Money.ZERO) + } + } + + @Test + fun `Should add two Cost instances with cache`() { + // given + val cost1 = Cost( + inputTokens = Money("0.001"), + outputTokens = Money("0.002"), + cacheCreationInputTokens = Money("0.0001"), + cacheReadInputTokens = Money("0.0002"), + ) + val cost2 = Cost( + inputTokens = Money("0.003"), + outputTokens = Money("0.004"), + cacheCreationInputTokens = Money("0.0003"), + cacheReadInputTokens = Money("0.0004") + ) + + // when + val result = cost1 + cost2 + + // then + result should { + have(inputTokens == Money("0.004")) + have(outputTokens == Money("0.006")) + have(cacheCreationInputTokens == Money("0.0004")) + have(cacheReadInputTokens == Money("0.0006")) + } + } + + @Test + fun `Should calculate total cost`() { + Cost( + inputTokens = Money("0.001"), + outputTokens = Money("0.002"), + cacheCreationInputTokens = Money("0.0005"), + cacheReadInputTokens = Money("0.0007") + ) should { + have(total == Money("0.0042")) + } + } + + @Test + fun `Should create ZERO Cost instance`() { + Cost.ZERO should { + have(inputTokens == Money.ZERO) + have(outputTokens == Money.ZERO) + have(cacheCreationInputTokens == Money.ZERO) + have(cacheReadInputTokens == Money.ZERO) + } + } + +} diff --git a/src/commonTest/kotlin/usage/UsageCollectorTest.kt b/src/commonTest/kotlin/usage/UsageCollectorTest.kt new file mode 100644 index 0000000..0611b53 --- /dev/null +++ b/src/commonTest/kotlin/usage/UsageCollectorTest.kt @@ -0,0 +1,136 @@ +package com.xemantic.anthropic.usage + +import com.xemantic.ai.money.Money +import com.xemantic.ai.money.Ratio +import com.xemantic.ai.money.ZERO +import com.xemantic.anthropic.Model +import com.xemantic.kotlin.test.assert +import com.xemantic.kotlin.test.have +import com.xemantic.kotlin.test.should +import kotlin.test.Test + +class UsageCollectorTest { + + @Test + fun `Should initialize UsageCollector with zero usage`() { + UsageCollector() should { + have(usage == Usage.ZERO) + have(cost == Cost.ZERO) + } + } + + @Test + fun `toString should return String representation of UsageCollector`() { + assert(UsageCollector().toString() == + "UsageCollector(usage=" + + "Usage(inputTokens=0, outputTokens=0, cacheCreationInputTokens=0, cacheReadInputTokens=0), cost=" + + "Cost(inputTokens=0, outputTokens=0, cacheCreationInputTokens=0, cacheReadInputTokens=0))" + ) + } + + @Test + fun `Should update usage and cost`() { + // given + val collector = UsageCollector() + + // when + collector.update( + modelCost = Model.DEFAULT.cost, + usage = Usage( + inputTokens = 1000, + outputTokens = 1000 + ) + ) + + // then + collector should { + have(usage == Usage( + inputTokens = 1000, + outputTokens = 1000, + cacheCreationInputTokens = 0, + cacheReadInputTokens = 0 + )) + have(cost == Cost( + inputTokens = Money(".003"), + outputTokens = Money(".015"), + cacheCreationInputTokens = Money.ZERO, + cacheReadInputTokens = Money.ZERO + )) + } + } + + @Test + fun `Should update usage and cost for batch`() { + // given + val collector = UsageCollector() + + // when + collector.update( + modelCost = Model.DEFAULT.cost, + usage = Usage( + inputTokens = 1000, + outputTokens = 1000 + ), + costRatio = Money.Ratio("0.5") + ) + + // then + collector should { + have(usage == Usage( + inputTokens = 1000, + outputTokens = 1000, + cacheCreationInputTokens = 0, + cacheReadInputTokens = 0 + )) + have(cost == Cost( + inputTokens = Money(".0015"), + outputTokens = Money(".0075"), + cacheCreationInputTokens = Money.ZERO, + cacheReadInputTokens = Money.ZERO + )) + } + } + + @Test + fun `Should accumulate multiple usage updates`() { + // given + val collector = UsageCollector() + val testUsage = Usage( + inputTokens = 1000, + outputTokens = 1000, + cacheCreationInputTokens = 1000, + cacheReadInputTokens = 1000 + ) + + // when + collector.update( + modelCost = Model.CLAUDE_3_5_SONNET.cost, + usage = testUsage + ) + collector.update( + modelCost = Model.CLAUDE_3_5_HAIKU.cost, + usage = testUsage + ) + collector.update( + modelCost = Model.CLAUDE_3_OPUS.cost, + usage = testUsage + ) + + // then + collector should { + usage should { + have(inputTokens == 3000) + have(outputTokens == 3000) + have(cacheCreationInputTokens == 3000) + have(cacheReadInputTokens == 3000) + } + cost should { + have(inputTokens == Money("0.0188")) // 0.003 + 0.0008 + 0.015 + have(outputTokens == Money("0.094")) // 0.015 + 0.004 + 0.075 + have(cacheCreationInputTokens == Money("0.0235")) // 0.00375 + 0.001 + 0.01875 + have(cacheReadInputTokens == Money("0.00188")) // 0.0003 + 0.00008 + 0.0015 + } + } + } + +} diff --git a/src/jvmTest/kotlin/content/MagicNumberTest.kt b/src/jvmTest/kotlin/content/MagicNumberTest.kt index c493ed2..7783e33 100644 --- a/src/jvmTest/kotlin/content/MagicNumberTest.kt +++ b/src/jvmTest/kotlin/content/MagicNumberTest.kt @@ -1,28 +1,38 @@ package com.xemantic.anthropic.content -import io.kotest.matchers.shouldBe import java.io.File import kotlin.test.Test +import com.xemantic.kotlin.test.assert class MagicNumberTest { @Test fun `Should detect file Magic Number`() { - File( - "test-data/minimal.pdf" - ).readBytes().findMagicNumber() shouldBe MagicNumber.PDF - File( - "test-data/minimal.jpg" - ).readBytes().findMagicNumber() shouldBe MagicNumber.JPEG - File( - "test-data/minimal.png" - ).readBytes().findMagicNumber() shouldBe MagicNumber.PNG - File( - "test-data/minimal.gif" - ).readBytes().findMagicNumber() shouldBe MagicNumber.GIF - File( - "test-data/minimal.webp" - ).readBytes().findMagicNumber() shouldBe MagicNumber.WEBP + assert( + File( + "test-data/minimal.pdf" + ).readBytes().findMagicNumber() == MagicNumber.PDF + ) + assert( + File( + "test-data/minimal.jpg" + ).readBytes().findMagicNumber() == MagicNumber.JPEG + ) + assert( + File( + "test-data/minimal.png" + ).readBytes().findMagicNumber() == MagicNumber.PNG + ) + assert( + File( + "test-data/minimal.gif" + ).readBytes().findMagicNumber() == MagicNumber.GIF + ) + assert( + File( + "test-data/minimal.webp" + ).readBytes().findMagicNumber() == MagicNumber.WEBP + ) } } \ No newline at end of file