Skip to content

Commit

Permalink
Cost usage aggregator. Initial version.
Browse files Browse the repository at this point in the history
  • Loading branch information
morisil committed Nov 10, 2024
1 parent 7132763 commit 2e7157f
Show file tree
Hide file tree
Showing 8 changed files with 160 additions and 47 deletions.
9 changes: 9 additions & 0 deletions build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import org.jetbrains.kotlin.gradle.targets.native.tasks.KotlinNativeTest
plugins {
alias(libs.plugins.kotlin.multiplatform)
alias(libs.plugins.kotlin.plugin.serialization)
alias(libs.plugins.kotlinx.atomicfu)
alias(libs.plugins.kotlin.plugin.power.assert)
alias(libs.plugins.dokka)
alias(libs.plugins.versions)
Expand Down Expand Up @@ -93,6 +94,14 @@ kotlin {

sourceSets {

all {
languageSettings {
languageVersion = kotlinTarget.version
apiVersion = kotlinTarget.version
progressiveMode = true
}
}

commonMain {
dependencies {
implementation(libs.kotlinx.datetime)
Expand Down
2 changes: 2 additions & 0 deletions gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ javaTarget = "17"
kotlin = "2.0.21"
kotlinxCoroutines = "1.9.0"
kotlinxDatetime = "0.6.1"
kotlinxAtomicFu = "0.26.0"
ktor = "3.0.1"
kotest = "6.0.0.M1"

Expand Down Expand Up @@ -42,6 +43,7 @@ kotest-assertions-json = { module = "io.kotest:kotest-assertions-json", version.
kotlin-multiplatform = { id = "org.jetbrains.kotlin.multiplatform", version.ref = "kotlin" }
kotlin-plugin-serialization = { id = "org.jetbrains.kotlin.plugin.serialization", version.ref = "kotlin" }
kotlin-plugin-power-assert = { id = "org.jetbrains.kotlin.plugin.power-assert", version.ref = "kotlin" }
kotlinx-atomicfu = { id = "org.jetbrains.kotlinx.atomicfu", version.ref = "kotlinxAtomicFu" }
dokka = { id = "org.jetbrains.dokka", version.ref = "dokkaPlugin" }
versions = { id = "com.github.ben-manes.versions", version.ref = "versionsPlugin" }
publish = { id = "io.github.gradle-nexus.publish-plugin", version.ref = "publishPlugin" }
37 changes: 33 additions & 4 deletions src/commonMain/kotlin/Anthropic.kt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import com.xemantic.anthropic.message.MessageResponse
import com.xemantic.anthropic.tool.BuiltInTool
import com.xemantic.anthropic.tool.Tool
import com.xemantic.anthropic.tool.ToolInput
import com.xemantic.anthropic.usage.Cost
import com.xemantic.anthropic.usage.Usage
import io.ktor.client.HttpClient
import io.ktor.client.call.body
import io.ktor.client.plugins.*
Expand All @@ -26,6 +28,8 @@ import io.ktor.http.HttpMethod
import io.ktor.http.HttpStatusCode
import io.ktor.http.contentType
import io.ktor.serialization.kotlinx.json.json
import kotlinx.atomicfu.atomic
import kotlinx.atomicfu.update
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.filterNotNull
import kotlinx.coroutines.flow.flow
Expand Down Expand Up @@ -66,7 +70,8 @@ fun Anthropic(
defaultMaxTokens = config.defaultMaxTokens,
directBrowserAccess = config.directBrowserAccess,
logLevel = if (config.logHttp) LogLevel.ALL else LogLevel.NONE,
toolMap = config.tools.associateBy { it.name }
modelMap = config.modelMap,
toolMap = config.tools.associateBy { it.name },
)
} // TODO this can be a second constructor, then toolMap can be private

Expand All @@ -79,6 +84,7 @@ class Anthropic internal constructor(
val defaultMaxTokens: Int,
val directBrowserAccess: Boolean,
val logLevel: LogLevel,
private val modelMap: Map<String, AnthropicModel>,
private val toolMap: Map<String, Tool>
) {

Expand All @@ -95,6 +101,8 @@ class Anthropic internal constructor(

var tools: List<Tool> = emptyList()

var modelMap: Map<String, AnthropicModel> = Model.entries.associateBy { it.id }

// TODO in the future this should be rather Tool builder
inline fun <reified T : ToolInput> tool(
cacheControl: CacheControl? = null,
Expand Down Expand Up @@ -182,6 +190,7 @@ class Anthropic internal constructor(
println("Error!!! Unexpected tool use: ${toolUse.name}")
}
}
updateTotals()
}
is ErrorResponse -> throw AnthropicException(
error = response.error,
Expand Down Expand Up @@ -217,8 +226,12 @@ class Anthropic internal constructor(
.map { it.data }
.filterNotNull()
.map { anthropicJson.decodeFromString<Event>(it) }
.collect {
emit(it)
.collect { event ->
// TODO we need better way of handling subsequent deltas with usage
if (event is Event.MessageStart) {
event.message.updateTotals()
}
emit(event)
}
}
}
Expand All @@ -227,5 +240,21 @@ class Anthropic internal constructor(

val messages = Messages()

}
private val _totalUsage = atomic(Usage.ZERO)
val totalUsage: Usage get() = _totalUsage.value

private val _totalCost = atomic(Cost.ZERO)
val totalCost: Cost get() = _totalCost.value

private val MessageResponse.anthropicModel: AnthropicModel get() = requireNotNull(
modelMap[model]
) {
"The model returned in the response is not known to Anthropic API client: $id"
}

private fun MessageResponse.updateTotals() {
_totalUsage.update { it + usage }
_totalCost.update { it + (usage.cost(anthropicModel) / Model.PRICE_UNIT) }
}

}
45 changes: 31 additions & 14 deletions src/commonMain/kotlin/Models.kt
Original file line number Diff line number Diff line change
@@ -1,12 +1,33 @@
package com.xemantic.anthropic

enum class Model(
val id: String,
val contextWindow: Int,
val maxOutput: Int,
val messageBatchesApi: Boolean,
import com.xemantic.anthropic.usage.Cost

/**
* The model used by the API.
* E.g., Claude LLM `sonnet`, `opus`, `haiku` family.
*/
interface AnthropicModel {

val id: String
val contextWindow: Int
val maxOutput: Int
val messageBatchesApi: Boolean
val cost: Cost
) {

}

/**
* Predefined models supported by Anthropic API.
*
* It could include Vertex AI (Google Cloud), or Bedrock (AWS) models in the future.
*/
enum class Model(
override val id: String,
override val contextWindow: Int,
override val maxOutput: Int,
override val messageBatchesApi: Boolean,
override val cost: Cost
) : AnthropicModel {

CLAUDE_3_5_SONNET(
id = "claude-3-5-sonnet-latest",
Expand Down Expand Up @@ -107,16 +128,12 @@ enum class Model(
)
);

/**
* Cost per MTok
*/
data class Cost(
val inputTokens: Double,
val outputTokens: Double
)

companion object {

val DEFAULT: Model = CLAUDE_3_5_SONNET

const val PRICE_UNIT: Double = 1000000.0

}

}
1 change: 0 additions & 1 deletion src/commonMain/kotlin/message/Messages.kt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import com.xemantic.anthropic.tool.ToolInput
import com.xemantic.anthropic.tool.toolName
import com.xemantic.anthropic.usage.Usage
import kotlinx.serialization.*
import kotlinx.serialization.json.JsonClassDiscriminator
import kotlin.collections.mutableListOf

/**
Expand Down
10 changes: 6 additions & 4 deletions src/commonMain/kotlin/tool/Tools.kt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.xemantic.anthropic.tool

import com.xemantic.anthropic.anthropicJson
import com.xemantic.anthropic.cache.CacheControl
import com.xemantic.anthropic.content.Content
import com.xemantic.anthropic.schema.Description
Expand All @@ -13,6 +14,7 @@ import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import kotlinx.serialization.SerializationException
import kotlinx.serialization.Transient
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.JsonClassDiscriminator
import kotlinx.serialization.serializer

Expand Down Expand Up @@ -65,8 +67,10 @@ abstract class BuiltInTool(
* with a given tool use ID. The implementation of the [use] method should
* contain the logic for executing the tool and returning the [ToolResult].
*/
@Serializable
abstract class ToolInput {

@Transient
private var block: suspend ToolResult.Builder.() -> Any? = {}

fun use(block: suspend ToolResult.Builder.() -> Any?) {
Expand All @@ -82,12 +86,10 @@ abstract class ToolInput {
suspend fun use(toolUseId: String): ToolResult {
return ToolResult(toolUseId) {
val result = block(this)
if (result != null) {
if ((result != null) && (result !is Unit)) {
when (result) {
is Content -> +result
is Unit -> {} // nothing to do
!is Unit -> +result.toString()
else -> throw IllegalStateException("Tool use {} returned not supported: $this")
else -> +anthropicJson.encodeToString(result)
}
}
}
Expand Down
80 changes: 59 additions & 21 deletions src/commonMain/kotlin/usage/Usage.kt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.xemantic.anthropic.usage

import com.xemantic.anthropic.AnthropicModel
import com.xemantic.anthropic.Model
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
Expand All @@ -14,37 +15,74 @@ data class Usage(
val cacheCreationInputTokens: Int? = null,
@SerialName("cache_read_input_tokens")
val cacheReadInputTokens: Int? = null,
)

fun Usage.add(usage: Usage): Usage = Usage(
inputTokens = inputTokens + usage.inputTokens,
outputTokens = outputTokens + usage.outputTokens,
cacheReadInputTokens = (cacheReadInputTokens ?: 0) + (usage.cacheReadInputTokens ?: 0),
cacheCreationInputTokens = (cacheCreationInputTokens ?: 0) + (usage.cacheCreationInputTokens ?: 0),
)

fun Usage.cost(
model: Model,
isBatch: Boolean = false
): Cost = Cost(
inputTokens = inputTokens * model.cost.inputTokens / 1000000.0 * (if (isBatch) .5 else 1.0),
outputTokens = outputTokens * model.cost.outputTokens / 1000000.0 * (if (isBatch) .5 else 1.0),
cacheReadInputTokens = (cacheReadInputTokens ?: 0) * model.cost.inputTokens * .1 / 1000000.0 * (if (isBatch) .5 else 1.0),
cacheCreationInputTokens = (cacheCreationInputTokens ?: 0) * model.cost.inputTokens * .25 / 1000000.0 * (if (isBatch) .5 else 1.0)
)
) {

companion object {

val ZERO = Usage(
inputTokens = 0,
outputTokens = 0,
cacheCreationInputTokens = 0,
cacheReadInputTokens = 0
)

}

operator fun plus(usage: Usage): Usage = Usage(
inputTokens = inputTokens + usage.inputTokens,
outputTokens = outputTokens + usage.outputTokens,
cacheReadInputTokens = (cacheReadInputTokens ?: 0) + (usage.cacheReadInputTokens ?: 0),
cacheCreationInputTokens = (cacheCreationInputTokens ?: 0) + (usage.cacheCreationInputTokens ?: 0),
)

fun cost(
model: AnthropicModel,
isBatch: Boolean = false
): Cost = Cost(
inputTokens = inputTokens * model.cost.inputTokens / Model.PRICE_UNIT,
outputTokens = outputTokens * model.cost.outputTokens / Model.PRICE_UNIT,
cacheReadInputTokens = (cacheReadInputTokens ?: 0) / Model.PRICE_UNIT,
cacheCreationInputTokens = (cacheCreationInputTokens ?: 0) / Model.PRICE_UNIT
).let { if (isBatch) it * .5 else it }

}

@Serializable
data class Cost(
val inputTokens: Double,
val outputTokens: Double,
val cacheCreationInputTokens: Double,
val cacheReadInputTokens: Double
val cacheCreationInputTokens: Double = inputTokens * .25,
val cacheReadInputTokens: Double = inputTokens * .25
) {

fun add(cost: Cost): Cost = Cost(
operator fun plus(cost: Cost): Cost = Cost(
inputTokens = inputTokens + cost.inputTokens,
outputTokens = outputTokens + cost.outputTokens,
cacheCreationInputTokens = cacheCreationInputTokens + cost.cacheCreationInputTokens,
cacheReadInputTokens = cacheReadInputTokens + cost.cacheReadInputTokens
)

operator fun times(value: Double): Cost = Cost(
inputTokens = inputTokens * value,
outputTokens = outputTokens * value,
cacheCreationInputTokens = cacheCreationInputTokens * value,
cacheReadInputTokens = cacheReadInputTokens * value
)

operator fun div(value: Double): Cost = Cost(
inputTokens = inputTokens / value,
outputTokens = outputTokens / value,
cacheCreationInputTokens = cacheCreationInputTokens / value,
cacheReadInputTokens = cacheReadInputTokens / value
)

val total: Double get() = inputTokens + outputTokens + cacheCreationInputTokens + cacheReadInputTokens

companion object {
val ZERO = Cost(
inputTokens = 0.0,
outputTokens = 0.0
)
}

}
23 changes: 20 additions & 3 deletions src/commonTest/kotlin/AnthropicTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import com.xemantic.anthropic.tool.TestDatabase
import com.xemantic.anthropic.content.Text
import com.xemantic.anthropic.content.ToolUse
import io.kotest.assertions.assertSoftly
import io.kotest.matchers.doubles.shouldBeLessThan
import io.kotest.matchers.ints.shouldBeGreaterThan
import io.kotest.matchers.shouldBe
import io.kotest.matchers.shouldNotBe
Expand All @@ -30,18 +31,18 @@ class AnthropicTest {
@Test
fun shouldReceiveAnIntroductionFromClaude() = runTest {
// given
val client = Anthropic()
val anthropic = Anthropic()

// when
val response = client.messages.create {
val response = anthropic.messages.create {
+Message {
+"Hello World! What's your name?"
}
maxTokens = 1024
}

// then
assertSoftly(response) {
response.apply {
role shouldBe Role.ASSISTANT
model shouldBe "claude-3-5-sonnet-20241022"
stopReason shouldBe StopReason.END_TURN
Expand All @@ -53,6 +54,22 @@ class AnthropicTest {
usage.inputTokens shouldBe 15
usage.outputTokens shouldBeGreaterThan 0
}

anthropic.totalUsage.apply {
inputTokens shouldBe 15
outputTokens shouldBeGreaterThan 0
cacheReadInputTokens shouldBe 0
cacheCreationInputTokens shouldBe 0
}

anthropic.totalCost.apply {
inputTokens shouldBeLessThan 0.0000000001
outputTokens shouldBeLessThan 0.000000001
cacheReadInputTokens shouldBe 0.0
cacheCreationInputTokens shouldBe 0.0
total shouldBeLessThan 0.000000001
}

}

@Test
Expand Down

0 comments on commit 2e7157f

Please sign in to comment.