Skip to content

Commit

Permalink
the next working version of tools, with simplified API
Browse files Browse the repository at this point in the history
  • Loading branch information
morisil committed Oct 10, 2024
1 parent 6872a5a commit a0b5a9a
Show file tree
Hide file tree
Showing 10 changed files with 252 additions and 214 deletions.
53 changes: 32 additions & 21 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,17 @@ fun main() {
It can also use tools:

```kotlin
@Serializable
@SerializableTool(
name = "Calculator",
description = "Calculates the arithmetic outcome of an operation when given the arguments a and b"
)
data class Calculator(
val operation: Operation,
val a: Double,
val b: Double
) {
): UsableTool {

@Suppress("unused") // will be used by Anthropic :)
@Suppress("unused") // it is used, but by Anthropic, so we skip the warning
enum class Operation(
val calculate: (a: Double, b: Double) -> Double
) {
Expand All @@ -113,31 +116,39 @@ data class Calculator(
DIVIDE({ a, b -> a / b })
}

fun calculate() = operation.calculate(a, b)
override fun use(toolUseId: String) = ToolResult(
toolUseId,
operation.calculate(a, b).toString()
)

}

fun main() {
val client = Anthropic()
fun main() = runBlocking {

val calculatorTool = Tool<Calculator>(
description = "Perform basic arithmetic operations"
)
val client = Anthropic {
tool<FibonacciTool>()
}

val response = runBlocking {
client.messages.create {
+Message {
+"What's 15 multiplied by 7?"
}
tools = listOf(calculatorTool)
toolChoice = ToolChoice.Any()
}
val conversation = mutableListOf<Message>()
conversation += Message { +"What's 15 multiplied by 7?" }

val response1 = client.messages.create {
messages = conversation
useTools()
}
conversation += response1.asMessage()

val toolUse = response.content[0] as ToolUse
val calculator = toolUse.input<Calculator>()
val result = calculator.calculate() // we are doing the job for LLM here
println(result)
println((response1.content[0] as Text).text)
val toolUse = response1.content[1] as ToolUse
val result = toolUse.use() // we are doing the calculation job for Claude here

conversation += Message { +result }

val response2 = client.messages.create {
messages = conversation
useTools()
}
println((response2.content[0] as Text).text)
}
```

Expand Down
5 changes: 2 additions & 3 deletions build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import org.gradle.api.tasks.testing.logging.TestLogEvent
import org.jetbrains.kotlin.gradle.ExperimentalKotlinGradlePluginApi
import org.jetbrains.kotlin.gradle.dsl.JvmTarget
import org.jetbrains.kotlin.gradle.dsl.KotlinVersion
import org.jetbrains.kotlin.gradle.tasks.KotlinJvmCompile

plugins {
alias(libs.plugins.kotlin.multiplatform)
Expand Down Expand Up @@ -47,6 +46,7 @@ kotlin {
//explicitApi() // check with serialization?
jvm {
testRuns["test"].executionTask.configure {
enabled = false
useJUnitPlatform()
}
// set up according to https://jakewharton.com/gradle-toolchains-are-rarely-a-good-idea/
Expand Down Expand Up @@ -129,10 +129,9 @@ tasks.withType<Test> {
enabled = true
}

@Suppress("OPT_IN_USAGE")
powerAssert {
functions = listOf(
"com.xemantic.anthropic.test.shouldBe"
"io.kotest.matchers.shouldBe"
)
includedSourceSets = listOf("commonTest", "jvmTest", "nativeTest")
}
Expand Down
2 changes: 1 addition & 1 deletion gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ javaTarget = "17"
kotlin = "2.0.21"
kotlinxCoroutines = "1.9.0"
ktor = "3.0.0"
kotest = "5.9.1"
kotest = "6.0.0.M1"

log4j = "2.24.1"
jackson = "2.18.0"
Expand Down
59 changes: 34 additions & 25 deletions src/commonMain/kotlin/Anthropic.kt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import com.xemantic.anthropic.message.MessageResponse
import com.xemantic.anthropic.message.Tool
import com.xemantic.anthropic.message.ToolUse
import com.xemantic.anthropic.tool.UsableTool
import com.xemantic.anthropic.tool.toolOf
import io.ktor.client.HttpClient
import io.ktor.client.call.body
import io.ktor.client.plugins.contentnegotiation.ContentNegotiation
Expand Down Expand Up @@ -73,9 +74,9 @@ fun Anthropic(
defaultModel = defaultModel,
directBrowserAccess = config.directBrowserAccess
).apply {
usableTools = config.usableTools
toolEntryMap = (config.usableTools as List<Anthropic.ToolEntry<UsableTool>>).associateBy { it.tool.name }
}
}
} // TODO this can be a second constructor, then toolMap can be private

class Anthropic internal constructor(
val apiKey: String,
Expand All @@ -93,36 +94,35 @@ class Anthropic internal constructor(
var apiBase: String = ANTHROPIC_API_BASE
var defaultModel: String? = null
var directBrowserAccess: Boolean = false
var usableTools: List<KClass<out UsableTool>> = emptyList()
@PublishedApi
internal var usableTools: List<ToolEntry<out UsableTool>> = emptyList()

inline fun <reified T : UsableTool> tool(
block: T.() -> Unit = {}
noinline block: T.() -> Unit = {}
) {
usableTools += T::class
val entry = ToolEntry(toolOf<T>(), serializer<T>(), block)
usableTools += entry
}

}

private class ToolEntry(
val tool: Tool,
@PublishedApi
internal class ToolEntry<T : UsableTool>(
val tool: Tool, // TODO, no cache control
val serializer: KSerializer<T>,
val initializer: T.() -> Unit = {}
)

private var toolSerializerMap = mapOf<String, KSerializer<out UsableTool>>()
internal var toolEntryMap = mapOf<String, ToolEntry<UsableTool>>()

var usableTools: List<KClass<out UsableTool>> = emptyList()
get() = field
set(value) {
value.validate()
field = value
}
// var usableTools: List<KClass<out Tool>> = emptyList()
// set(value) {
// toolMap += mapOf(value)
// field = value
// }

inline fun <reified T : UsableTool> tool() {
usableTools += T::class
}

fun List<KClass<out UsableTool>>.validate() {
forEach { tool ->
//tool.serializer()
}
//usableTools += T::class
}

private val client = HttpClient {
Expand All @@ -146,22 +146,28 @@ class Anthropic internal constructor(
}
}

inner class Messages() {
inner class Messages {

suspend fun create(
block: MessageRequest.Builder.() -> Unit
): MessageResponse {

val request = MessageRequest.Builder(
defaultModel
defaultModel,
toolEntryMap = toolEntryMap
).apply(block).build()

val response = client.post("/v1/messages") {
contentType(ContentType.Application.Json)
setBody(request)
}
if (response.status.isSuccess()) {
return response.body<MessageResponse>().apply {
content.filterIsInstance<ToolUse>()
.forEach { it.toolSerializerMap = toolSerializerMap }
.forEach { toolUse ->
val entry = toolEntryMap[toolUse.name]!!
toolUse.toolEntry = entry
}
}
} else {
throw AnthropicException(
Expand All @@ -175,7 +181,10 @@ class Anthropic internal constructor(
block: MessageRequest.Builder.() -> Unit
): Flow<Event> = flow {

val request = MessageRequest.Builder(defaultModel).apply {
val request = MessageRequest.Builder(
defaultModel,
toolEntryMap = toolEntryMap
).apply {
block(this)
stream = true
}.build()
Expand Down
27 changes: 15 additions & 12 deletions src/commonMain/kotlin/message/Messages.kt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.xemantic.anthropic.message

import com.xemantic.anthropic.Anthropic
import com.xemantic.anthropic.anthropicJson
import com.xemantic.anthropic.schema.JsonSchema
import com.xemantic.anthropic.tool.UsableTool
Expand Down Expand Up @@ -48,8 +49,9 @@ data class MessageRequest(
val topP: Int?
) {

class Builder(
val defaultApiModel: String
class Builder internal constructor(
private val defaultModel: String,
private val toolEntryMap: Map<String, Anthropic.ToolEntry<out UsableTool>>
) {
var model: String? = null
var maxTokens = 1024
Expand All @@ -66,11 +68,7 @@ data class MessageRequest(
val topP: Int? = null

fun useTools() {
//too
}

fun tools(vararg classes: KClass<out UsableTool>) {
// TODO it needs access to Anthropic, therefore either needs a constructor parameter, or needs to be inner class
tools = toolEntryMap.values.map { it.tool }
}

fun messages(vararg messages: Message) {
Expand All @@ -96,7 +94,7 @@ data class MessageRequest(
}

fun build(): MessageRequest = MessageRequest(
model = if (model != null) model!! else defaultApiModel,
model = if (model != null) model!! else defaultModel,
maxTokens = maxTokens,
messages = messages,
metadata = metadata,
Expand All @@ -117,7 +115,9 @@ fun MessageRequest(
defaultModel: String,
block: MessageRequest.Builder.() -> Unit
): MessageRequest {
val builder = MessageRequest.Builder(defaultModel)
val builder = MessageRequest.Builder(
defaultModel, emptyMap()
)
block(builder)
return builder.build()
}
Expand Down Expand Up @@ -286,11 +286,14 @@ data class ToolUse(
) : Content() {

@Transient
internal lateinit var toolSerializerMap: Map<String, KSerializer<out UsableTool>>
internal lateinit var toolEntry: Anthropic.ToolEntry<UsableTool>

fun use(): ToolResult {
val serializer = toolSerializerMap[name]!!
val tool = anthropicJson.decodeFromJsonElement(serializer, input)
val tool = anthropicJson.decodeFromJsonElement(
deserializer = toolEntry.serializer,
element = input
)
toolEntry.initializer(tool)
return tool.use(toolUseId = id)
}

Expand Down
Loading

0 comments on commit a0b5a9a

Please sign in to comment.