Skip to content

Commit

Permalink
Feature: Cost calculation of API usage (#23)
Browse files Browse the repository at this point in the history
  • Loading branch information
morisil authored Dec 13, 2024
1 parent 17747e4 commit 8ba0cb4
Show file tree
Hide file tree
Showing 28 changed files with 1,012 additions and 366 deletions.
2 changes: 1 addition & 1 deletion .github/FUNDING.yml
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# These are supported funding model platforms

github:xemantic # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
github: xemantic
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ build/

### IntelliJ IDEA ###
/.idea/
!/.idea/copyright/
*.iws
*.iml
*.ipr
Expand Down
6 changes: 2 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -158,13 +158,11 @@ fun main() = runBlocking {
println(initialResponse)

conversation += initialResponse
val tool = initialResponse.content.filterIsInstance<ToolUse>().first()
val toolResult = tool.use()
conversation += Message { +toolResult }
conversation += initialResonse.useTools()

val finalResponse = client.messages.create {
messages = conversation
useTools()
allTools()
}
println("Final response:")
println(finalResponse)
Expand Down
20 changes: 15 additions & 5 deletions build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import org.jetbrains.kotlin.gradle.dsl.KotlinVersion
plugins {
alias(libs.plugins.kotlin.multiplatform)
alias(libs.plugins.kotlin.plugin.serialization)
alias(libs.plugins.kotlinx.atomicfu)
alias(libs.plugins.kotlin.plugin.power.assert)
alias(libs.plugins.dokka)
alias(libs.plugins.versions)
Expand Down Expand Up @@ -139,6 +140,14 @@ kotlin {

sourceSets {

all {
languageSettings {
languageVersion = kotlinTarget.version
apiVersion = kotlinTarget.version
progressiveMode = true
}
}

commonMain {
dependencies {
implementation(libs.kotlinx.datetime)
Expand All @@ -147,14 +156,15 @@ kotlin {
implementation(libs.ktor.client.logging)
implementation(libs.ktor.serialization.kotlinx.json)
implementation(libs.xemantic.ai.tool.schema)
api(libs.xemantic.ai.money)
}
}

commonTest {
dependencies {
implementation(libs.kotlin.test)
implementation(libs.kotlinx.coroutines.test)
implementation(libs.kotest.assertions.core)
implementation(libs.xemantic.kotlin.test)
implementation(libs.kotest.assertions.json)
}
}
Expand Down Expand Up @@ -233,10 +243,10 @@ tasks.withType<Test> {
}

powerAssert {
// functions = listOf(
// "io.kotest.matchers.shouldBe"
// )
// includedSourceSets = listOf("commonTest", "jvmTest", "nativeTest")
functions = listOf(
"com.xemantic.kotlin.test.assert",
"com.xemantic.kotlin.test.have"
)
}

// maybe this one is not necessary?
Expand Down
7 changes: 6 additions & 1 deletion gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@ javaTarget = "17"
kotlin = "2.1.0"
kotlinxCoroutines = "1.9.0"
kotlinxDatetime = "0.6.1"
kotlinxAtomicFu = "0.26.0"
ktor = "3.0.1"
kotest = "6.0.0.M1"

xemanticKotlinTest = "1.0"
xemanticAiToolSchema = "0.1.1"
xemanticAiMoney = "0.2"

# logging is not used at the moment, might be enabled later
log4j = "2.24.2"
Expand All @@ -24,7 +27,9 @@ kotlinx-coroutines-test = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-t
kotlinx-datetime = { module = "org.jetbrains.kotlinx:kotlinx-datetime", version.ref = "kotlinxDatetime" }

# xemantic
xemantic-kotlin-test = { module = "com.xemantic.kotlin:xemantic-kotlin-test", version.ref = "xemanticKotlinTest"}
xemantic-ai-tool-schema = { module = "com.xemantic.ai:xemantic-ai-tool-schema", version.ref = "xemanticAiToolSchema"}
xemantic-ai-money = { module = "com.xemantic.ai:xemantic-ai-money", version.ref = "xemanticAiMoney"}

# logging libs
log4j-slf4j2 = { module = "org.apache.logging.log4j:log4j-slf4j2-impl", version.ref = "log4j" }
Expand All @@ -40,13 +45,13 @@ ktor-client-java = { module = "io.ktor:ktor-client-java", version.ref = "ktor" }
ktor-client-curl = { module = "io.ktor:ktor-client-curl", version.ref = "ktor" }
ktor-client-darwin = { module = "io.ktor:ktor-client-darwin", version.ref = "ktor" }

kotest-assertions-core = { module = "io.kotest:kotest-assertions-core", version.ref = "kotest" }
kotest-assertions-json = { module = "io.kotest:kotest-assertions-json", version.ref = "kotest" }

[plugins]
kotlin-multiplatform = { id = "org.jetbrains.kotlin.multiplatform", version.ref = "kotlin" }
kotlin-plugin-serialization = { id = "org.jetbrains.kotlin.plugin.serialization", version.ref = "kotlin" }
kotlin-plugin-power-assert = { id = "org.jetbrains.kotlin.plugin.power-assert", version.ref = "kotlin" }
kotlinx-atomicfu = { id = "org.jetbrains.kotlinx.atomicfu", version.ref = "kotlinxAtomicFu" }
dokka = { id = "org.jetbrains.dokka", version.ref = "dokkaPlugin" }
versions = { id = "com.github.ben-manes.versions", version.ref = "versionsPlugin" }
publish = { id = "io.github.gradle-nexus.publish-plugin", version.ref = "publishPlugin" }
47 changes: 41 additions & 6 deletions src/commonMain/kotlin/Anthropic.kt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ import com.xemantic.anthropic.message.MessageResponse
import com.xemantic.anthropic.tool.BuiltInTool
import com.xemantic.anthropic.tool.Tool
import com.xemantic.anthropic.tool.ToolInput
import com.xemantic.anthropic.usage.Cost
import com.xemantic.anthropic.usage.Usage
import com.xemantic.anthropic.usage.UsageCollector
import io.ktor.client.HttpClient
import io.ktor.client.call.body
import io.ktor.client.plugins.*
Expand Down Expand Up @@ -66,7 +69,8 @@ fun Anthropic(
defaultMaxTokens = config.defaultMaxTokens,
directBrowserAccess = config.directBrowserAccess,
logLevel = if (config.logHttp) LogLevel.ALL else LogLevel.NONE,
toolMap = config.tools.associateBy { it.name }
modelMap = config.modelMap,
toolMap = config.tools.associateBy { it.name },
)
} // TODO this can be a second constructor, then toolMap can be private

Expand All @@ -79,6 +83,7 @@ class Anthropic internal constructor(
val defaultMaxTokens: Int,
val directBrowserAccess: Boolean,
val logLevel: LogLevel,
private val modelMap: Map<String, AnthropicModel>,
private val toolMap: Map<String, Tool>
) {

Expand All @@ -87,14 +92,16 @@ class Anthropic internal constructor(
var anthropicVersion: String = DEFAULT_ANTHROPIC_VERSION
var anthropicBeta: String? = null
var apiBase: String = ANTHROPIC_API_BASE
var defaultModel: Model = Model.DEFAULT
var defaultModel: AnthropicModel = Model.DEFAULT
var defaultMaxTokens: Int = defaultModel.maxOutput

var directBrowserAccess: Boolean = false
var logHttp: Boolean = false

var tools: List<Tool> = emptyList()

var modelMap: Map<String, AnthropicModel> = Model.entries.associateBy { it.id }

// TODO in the future this should be rather Tool builder
inline fun <reified T : ToolInput> tool(
cacheControl: CacheControl? = null,
Expand Down Expand Up @@ -176,6 +183,7 @@ class Anthropic internal constructor(
val response = apiResponse.body<Response>()
when (response) {
is MessageResponse -> response.apply {
updateUsage(response)
content.filterIsInstance<ToolUse>()
.forEach { toolUse ->
val tool = toolMap[toolUse.name]
Expand All @@ -192,7 +200,9 @@ class Anthropic internal constructor(
error = response.error,
httpStatusCode = apiResponse.status
)
else -> throw RuntimeException("Unsupported response: $response") // should never happen
else -> throw RuntimeException(
"Unsupported response: $response"
) // should never happen
}
return response
}
Expand Down Expand Up @@ -222,8 +232,13 @@ class Anthropic internal constructor(
.map { it.data }
.filterNotNull()
.map { anthropicJson.decodeFromString<Event>(it) }
.collect {
emit(it)
.collect { event ->
// TODO we need better way of handling subsequent deltas with usage
if (event is Event.MessageStart) {
// TODO more rules are needed here
updateUsage(event.message)
}
emit(event)
}
}
}
Expand All @@ -232,5 +247,25 @@ class Anthropic internal constructor(

val messages = Messages()

}
private val usageCollector = UsageCollector()

val usage: Usage get() = usageCollector.usage

val cost: Cost get() = usageCollector.cost

override fun toString(): String = "Anthropic($usage, $cost)"

private val MessageResponse.anthropicModel: AnthropicModel get() = requireNotNull(
modelMap[model]
) {
"The model returned in the response is not known to Anthropic API client: $id"
}

private fun updateUsage(response: MessageResponse) {
usageCollector.update(
modelCost = response.anthropicModel.cost,
usage = response.usage
)
}

}
10 changes: 10 additions & 0 deletions src/commonMain/kotlin/AnthropicJson.kt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import kotlinx.serialization.SerializationException
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.descriptors.SerialKind
import kotlinx.serialization.descriptors.buildSerialDescriptor
import kotlinx.serialization.encodeToString
import kotlinx.serialization.encoding.Decoder
import kotlinx.serialization.encoding.Encoder
import kotlinx.serialization.json.Json
Expand Down Expand Up @@ -65,6 +66,15 @@ val anthropicJson: Json = Json {
encodeDefaults = true
}

@OptIn(ExperimentalSerializationApi::class)
@PublishedApi
internal val prettyAnthropicJson: Json = Json(from = anthropicJson) {
prettyPrint = true
prettyPrintIndent = " "
}

inline fun <reified T> T.toPrettyJson(): String = prettyAnthropicJson.encodeToString<T>(this)

private object ResponseSerializer : JsonContentPolymorphicSerializer<Response>(
baseClass = Response::class
) {
Expand Down
Loading

0 comments on commit 8ba0cb4

Please sign in to comment.