Skip to content

Commit

Permalink
* add direct browser access to client headers
Browse files Browse the repository at this point in the history
* single shared anthropicJson instance for consistency
* tool test updated
* JsonSchema support
  • Loading branch information
morisil committed Oct 1, 2024
1 parent 5e31fd9 commit c44eaa2
Show file tree
Hide file tree
Showing 7 changed files with 301 additions and 69 deletions.
28 changes: 16 additions & 12 deletions src/commonMain/kotlin/Anthropic.kt
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ import kotlinx.coroutines.flow.filter
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.map
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.JsonBuilder

const val ANTHROPIC_API_BASE: String = "https://api.anthropic.com/"

Expand All @@ -42,6 +41,12 @@ expect val envApiKey: String?

expect val missingApiKeyMessage: String

val anthropicJson: Json = Json {
allowSpecialFloatingPointValues = true
explicitNulls = false
encodeDefaults = true
}

fun Anthropic(
block: Anthropic.Config.() -> Unit = {}
): Anthropic {
Expand All @@ -54,7 +59,8 @@ fun Anthropic(
anthropicVersion = config.anthropicVersion,
anthropicBeta = config.anthropicBeta,
apiBase = config.apiBase,
defaultModel = defaultModel
defaultModel = defaultModel,
directBrowserAccess = config.directBrowserAccess
)
}

Expand All @@ -63,7 +69,8 @@ class Anthropic internal constructor(
val anthropicVersion: String,
val anthropicBeta: String?,
val apiBase: String,
val defaultModel: String
val defaultModel: String,
val directBrowserAccess: Boolean
) {

class Config {
Expand All @@ -72,12 +79,12 @@ class Anthropic internal constructor(
var anthropicBeta: String? = null
var apiBase: String = ANTHROPIC_API_BASE
var defaultModel: String? = null
var directBrowserAccess: Boolean = false
}

private val json = Json(builderAction = anthropicJsonConfigurer)
private val client = HttpClient {
install(ContentNegotiation) {
json(json)
json(anthropicJson)
}
install(SSE)
install(Logging) {
Expand All @@ -90,6 +97,9 @@ class Anthropic internal constructor(
if (anthropicBeta != null) {
header("anthropic-beta", anthropicBeta)
}
if (directBrowserAccess) {
header("anthropic-dangerous-direct-browser-access", true)
}
}
}

Expand Down Expand Up @@ -134,7 +144,7 @@ class Anthropic internal constructor(
) {
incoming
.filter { it.data != null }
.map { json.decodeFromString<Event>(it.data!!) }
.map { anthropicJson.decodeFromString<Event>(it.data!!) }
.collect {
emit(it)
}
Expand All @@ -146,9 +156,3 @@ class Anthropic internal constructor(
val messages = Messages()

}

val anthropicJsonConfigurer: JsonBuilder.() -> Unit = {
allowSpecialFloatingPointValues = true
explicitNulls = false
encodeDefaults = true
}
23 changes: 21 additions & 2 deletions src/commonMain/kotlin/message/Messages.kt
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
package com.xemantic.anthropic.message

import com.xemantic.anthropic.anthropicJson
import com.xemantic.anthropic.schema.JsonSchema
import com.xemantic.anthropic.schema.jsonSchemaOf
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.JsonClassDiscriminator
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.decodeFromJsonElement
import kotlin.collections.mutableListOf

enum class Role {
Expand Down Expand Up @@ -180,10 +184,20 @@ data class Tool(
val name: String,
val description: String,
@SerialName("input_schema")
val inputSchema: JsonObject, // soon it will be a generic type
val inputSchema: JsonSchema,
val cacheControl: CacheControl?
)

inline fun <reified T> Tool(
description: String,
cacheControl: CacheControl? = null
): Tool = Tool(
name = T::class.qualifiedName!!,
description = description,
inputSchema = jsonSchemaOf<T>(),
cacheControl = cacheControl
)

@Serializable
@JsonClassDiscriminator("type")
@OptIn(ExperimentalSerializationApi::class)
Expand Down Expand Up @@ -242,7 +256,12 @@ data class ToolUse(
val id: String,
val name: String,
val input: JsonObject
) : Content()
) : Content() {

inline fun <reified T> input(): T =
anthropicJson.decodeFromJsonElement<T>(input)

}

@Serializable
@SerialName("tool_result")
Expand Down
31 changes: 31 additions & 0 deletions src/commonMain/kotlin/schema/JsonSchema.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package com.xemantic.anthropic.schema

import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable

@Serializable
data class JsonSchema(
val type: String = "object",
val definitions: Map<String, JsonSchema>? = null,
val properties: Map<String, JsonSchemaProperty>? = null,
val required: List<String>? = null,
@SerialName("\$ref")
var ref: String? = null
)

@Serializable
data class JsonSchemaProperty(
val type: String? = null,
val items: JsonSchemaProperty? = null,
val enum: List<String>? = null,
val ref: String? = null
) {

companion object {
val STRING = JsonSchemaProperty("string")
val INTEGER = JsonSchemaProperty("integer")
val NUMBER = JsonSchemaProperty("number")
val BOOLEAN = JsonSchemaProperty("boolean")
}

}
75 changes: 75 additions & 0 deletions src/commonMain/kotlin/schema/JsonSchemaGenerator.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package com.xemantic.anthropic.schema

import kotlinx.serialization.*
import kotlinx.serialization.descriptors.*
import kotlin.collections.set

inline fun <reified T> jsonSchemaOf(): JsonSchema = generateSchema(
serializer<T>().descriptor
)

@OptIn(ExperimentalSerializationApi::class)
fun generateSchema(descriptor: SerialDescriptor): JsonSchema {
val properties = mutableMapOf<String, JsonSchemaProperty>()
val required = mutableListOf<String>()
val definitions = mutableMapOf<String, JsonSchema>()

for (i in 0 until descriptor.elementsCount) {
val name = descriptor.getElementName(i)
val elementDescriptor = descriptor.getElementDescriptor(i)
val property = generateSchemaProperty(elementDescriptor, definitions)
properties[name] = property
if (!descriptor.isElementOptional(i)) {
required.add(name)
}
}

return JsonSchema(
type = "object",
properties = properties,
required = required,
definitions = if (definitions.isNotEmpty()) definitions else null
)
}

@OptIn(ExperimentalSerializationApi::class)
private fun generateSchemaProperty(
descriptor: SerialDescriptor,
definitions: MutableMap<String, JsonSchema>
): JsonSchemaProperty {
return when (descriptor.kind) {
PrimitiveKind.STRING -> JsonSchemaProperty.STRING
PrimitiveKind.INT, PrimitiveKind.LONG -> JsonSchemaProperty.INTEGER
PrimitiveKind.FLOAT, PrimitiveKind.DOUBLE -> JsonSchemaProperty.NUMBER
PrimitiveKind.BOOLEAN -> JsonSchemaProperty.BOOLEAN
SerialKind.ENUM -> enumProperty(descriptor)
StructureKind.LIST -> JsonSchemaProperty(
type = "array",
items = generateSchemaProperty(
descriptor.getElementDescriptor(0),
definitions
)
)
StructureKind.MAP -> JsonSchemaProperty("object")
StructureKind.CLASS -> {
val refName = descriptor.serialName.trimEnd('?')
definitions[refName] = generateSchema(descriptor)
JsonSchemaProperty("\$ref", ref = "#/definitions/$refName")
}
else -> JsonSchemaProperty("object") // Default case
}
}

private fun enumProperty(
descriptor: SerialDescriptor
) = JsonSchemaProperty(
enum = descriptor.elementNames()
)

@OptIn(ExperimentalSerializationApi::class)
private fun SerialDescriptor.elementNames(): List<String> = buildList {
for (i in 0 until elementsCount) {
val name = getElementName(i)
add(name)
}
}
75 changes: 23 additions & 52 deletions src/commonTest/kotlin/AnthropicTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,12 @@ import com.xemantic.anthropic.message.Text
import com.xemantic.anthropic.message.Tool
import com.xemantic.anthropic.message.ToolChoice
import com.xemantic.anthropic.message.ToolUse
import com.xemantic.anthropic.schema.jsonSchemaOf
import kotlinx.coroutines.flow.filterIsInstance
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.toList
import kotlinx.coroutines.test.runTest
import kotlinx.serialization.json.add
import kotlinx.serialization.json.buildJsonObject
import kotlinx.serialization.json.double
import kotlinx.serialization.json.jsonObject
import kotlinx.serialization.json.jsonPrimitive
import kotlinx.serialization.json.put
import kotlinx.serialization.json.putJsonArray
import kotlinx.serialization.Serializable
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertNull
Expand Down Expand Up @@ -107,58 +102,35 @@ class AnthropicTest {
assertTrue(response == "The quick brown fox jumps over the lazy dog.")
}

@Test
fun shouldUseSimpleTool() = runTest {
// given
val client = Anthropic()


// when
val response = client.messages.stream {
+Message {
role = Role.USER
+"Say: 'The quick brown fox jumps over the lazy dog'"
}
// given
@Serializable
data class Calculator(
val operation: Operation,
val a: Double,
val b: Double
) {

enum class Operation(
val calculate: (a: Double, b: Double) -> Double
) {
ADD({ a, b -> a + b }),
SUBTRACT({ a, b -> a - b }),
MULTIPLY({ a, b -> a * b }),
DIVIDE({ a, b -> a / b })
}
.filterIsInstance<ContentBlockDelta>()
.map { (it.delta as TextDelta).text }
.toList()
.joinToString(separator = "")

// then
assertTrue(response == "The quick brown fox jumps over the lazy dog.")
fun calculate() = operation.calculate(a, b)

}

@Test
fun shouldUseCalculatorTool() = runTest {
// given
val client = Anthropic()
// soon the Tool will use generic serializable type and the schema
// will be generated automatically
val calculatorTool = Tool(
name = "calculator",
description = "Perform basic arithmetic operations",
inputSchema = buildJsonObject {
put("type", "object")
put("properties", buildJsonObject {
put("operation", buildJsonObject {
put("type", "string")
putJsonArray("enum") {
add("add")
add("subtract")
add("multiply")
add("divide")
}
})
put("a", buildJsonObject { put("type", "number") })
put("b", buildJsonObject { put("type", "number") })
})
putJsonArray("required") {
add("operation")
add("a")
add("b")
}
},
inputSchema = jsonSchemaOf<Calculator>(),
cacheControl = null
)

Expand All @@ -177,10 +149,9 @@ class AnthropicTest {
assertTrue(content[0] is ToolUse)
val toolUse = content[0] as ToolUse
assertTrue(toolUse.name == "calculator")
val input = toolUse.input.jsonObject
assertTrue(input["operation"]?.jsonPrimitive?.content == "multiply")
assertTrue(input["a"]?.jsonPrimitive?.double == 15.0)
assertTrue(input["b"]?.jsonPrimitive?.double == 7.0)
val calculator = toolUse.input<Calculator>()
val result = calculator.calculate()
assertTrue(result == 15.0 * 7.0)
}
}

Expand Down
8 changes: 5 additions & 3 deletions src/commonTest/kotlin/message/MessagesTest.kt
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
package com.xemantic.anthropic.message

import com.xemantic.anthropic.anthropicJsonConfigurer
import com.xemantic.anthropic.anthropicJson
import io.kotest.assertions.json.shouldEqualJson
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.Json
import kotlin.test.Test
Expand All @@ -14,9 +15,10 @@ class MessagesTest {
/**
* A pretty JSON printing for testing.
*/
private val json = Json {
anthropicJsonConfigurer()
private val json = Json(from = anthropicJson) {
prettyPrint = true
@OptIn(ExperimentalSerializationApi::class)
prettyPrintIndent = " "
}

@Test
Expand Down
Loading

0 comments on commit c44eaa2

Please sign in to comment.