From c44eaa2bd0da3fa328c23dc713cd4764b82b52e8 Mon Sep 17 00:00:00 2001 From: Kazik Pogoda Date: Tue, 1 Oct 2024 23:05:51 +0200 Subject: [PATCH] * add direct browser access to client headers * single shared anthropicJson instance for consistency * tool test updated * JsonSchema support --- src/commonMain/kotlin/Anthropic.kt | 28 ++-- src/commonMain/kotlin/message/Messages.kt | 23 +++- src/commonMain/kotlin/schema/JsonSchema.kt | 31 +++++ .../kotlin/schema/JsonSchemaGenerator.kt | 75 ++++++++++ src/commonTest/kotlin/AnthropicTest.kt | 75 ++++------ src/commonTest/kotlin/message/MessagesTest.kt | 8 +- .../kotlin/schema/JsonSchemaGeneratorTest.kt | 130 ++++++++++++++++++ 7 files changed, 301 insertions(+), 69 deletions(-) create mode 100644 src/commonMain/kotlin/schema/JsonSchema.kt create mode 100644 src/commonMain/kotlin/schema/JsonSchemaGenerator.kt create mode 100644 src/commonTest/kotlin/schema/JsonSchemaGeneratorTest.kt diff --git a/src/commonMain/kotlin/Anthropic.kt b/src/commonMain/kotlin/Anthropic.kt index 7abeb39..e31fd76 100644 --- a/src/commonMain/kotlin/Anthropic.kt +++ b/src/commonMain/kotlin/Anthropic.kt @@ -27,7 +27,6 @@ import kotlinx.coroutines.flow.filter import kotlinx.coroutines.flow.flow import kotlinx.coroutines.flow.map import kotlinx.serialization.json.Json -import kotlinx.serialization.json.JsonBuilder const val ANTHROPIC_API_BASE: String = "https://api.anthropic.com/" @@ -42,6 +41,12 @@ expect val envApiKey: String? expect val missingApiKeyMessage: String +val anthropicJson: Json = Json { + allowSpecialFloatingPointValues = true + explicitNulls = false + encodeDefaults = true +} + fun Anthropic( block: Anthropic.Config.() -> Unit = {} ): Anthropic { @@ -54,7 +59,8 @@ fun Anthropic( anthropicVersion = config.anthropicVersion, anthropicBeta = config.anthropicBeta, apiBase = config.apiBase, - defaultModel = defaultModel + defaultModel = defaultModel, + directBrowserAccess = config.directBrowserAccess ) } @@ -63,7 +69,8 @@ class Anthropic internal constructor( val anthropicVersion: String, val anthropicBeta: String?, val apiBase: String, - val defaultModel: String + val defaultModel: String, + val directBrowserAccess: Boolean ) { class Config { @@ -72,12 +79,12 @@ class Anthropic internal constructor( var anthropicBeta: String? = null var apiBase: String = ANTHROPIC_API_BASE var defaultModel: String? = null + var directBrowserAccess: Boolean = false } - private val json = Json(builderAction = anthropicJsonConfigurer) private val client = HttpClient { install(ContentNegotiation) { - json(json) + json(anthropicJson) } install(SSE) install(Logging) { @@ -90,6 +97,9 @@ class Anthropic internal constructor( if (anthropicBeta != null) { header("anthropic-beta", anthropicBeta) } + if (directBrowserAccess) { + header("anthropic-dangerous-direct-browser-access", true) + } } } @@ -134,7 +144,7 @@ class Anthropic internal constructor( ) { incoming .filter { it.data != null } - .map { json.decodeFromString(it.data!!) } + .map { anthropicJson.decodeFromString(it.data!!) } .collect { emit(it) } @@ -146,9 +156,3 @@ class Anthropic internal constructor( val messages = Messages() } - -val anthropicJsonConfigurer: JsonBuilder.() -> Unit = { - allowSpecialFloatingPointValues = true - explicitNulls = false - encodeDefaults = true -} diff --git a/src/commonMain/kotlin/message/Messages.kt b/src/commonMain/kotlin/message/Messages.kt index e8a2b64..42477a3 100644 --- a/src/commonMain/kotlin/message/Messages.kt +++ b/src/commonMain/kotlin/message/Messages.kt @@ -1,10 +1,14 @@ package com.xemantic.anthropic.message +import com.xemantic.anthropic.anthropicJson +import com.xemantic.anthropic.schema.JsonSchema +import com.xemantic.anthropic.schema.jsonSchemaOf import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable import kotlinx.serialization.json.JsonClassDiscriminator import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.decodeFromJsonElement import kotlin.collections.mutableListOf enum class Role { @@ -180,10 +184,20 @@ data class Tool( val name: String, val description: String, @SerialName("input_schema") - val inputSchema: JsonObject, // soon it will be a generic type + val inputSchema: JsonSchema, val cacheControl: CacheControl? ) +inline fun Tool( + description: String, + cacheControl: CacheControl? = null +): Tool = Tool( + name = T::class.qualifiedName!!, + description = description, + inputSchema = jsonSchemaOf(), + cacheControl = cacheControl +) + @Serializable @JsonClassDiscriminator("type") @OptIn(ExperimentalSerializationApi::class) @@ -242,7 +256,12 @@ data class ToolUse( val id: String, val name: String, val input: JsonObject -) : Content() +) : Content() { + + inline fun input(): T = + anthropicJson.decodeFromJsonElement(input) + +} @Serializable @SerialName("tool_result") diff --git a/src/commonMain/kotlin/schema/JsonSchema.kt b/src/commonMain/kotlin/schema/JsonSchema.kt new file mode 100644 index 0000000..36ee19b --- /dev/null +++ b/src/commonMain/kotlin/schema/JsonSchema.kt @@ -0,0 +1,31 @@ +package com.xemantic.anthropic.schema + +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable + +@Serializable +data class JsonSchema( + val type: String = "object", + val definitions: Map? = null, + val properties: Map? = null, + val required: List? = null, + @SerialName("\$ref") + var ref: String? = null +) + +@Serializable +data class JsonSchemaProperty( + val type: String? = null, + val items: JsonSchemaProperty? = null, + val enum: List? = null, + val ref: String? = null +) { + + companion object { + val STRING = JsonSchemaProperty("string") + val INTEGER = JsonSchemaProperty("integer") + val NUMBER = JsonSchemaProperty("number") + val BOOLEAN = JsonSchemaProperty("boolean") + } + +} \ No newline at end of file diff --git a/src/commonMain/kotlin/schema/JsonSchemaGenerator.kt b/src/commonMain/kotlin/schema/JsonSchemaGenerator.kt new file mode 100644 index 0000000..a66b098 --- /dev/null +++ b/src/commonMain/kotlin/schema/JsonSchemaGenerator.kt @@ -0,0 +1,75 @@ +package com.xemantic.anthropic.schema + +import kotlinx.serialization.* +import kotlinx.serialization.descriptors.* +import kotlin.collections.set + +inline fun jsonSchemaOf(): JsonSchema = generateSchema( + serializer().descriptor +) + +@OptIn(ExperimentalSerializationApi::class) +fun generateSchema(descriptor: SerialDescriptor): JsonSchema { + val properties = mutableMapOf() + val required = mutableListOf() + val definitions = mutableMapOf() + + for (i in 0 until descriptor.elementsCount) { + val name = descriptor.getElementName(i) + val elementDescriptor = descriptor.getElementDescriptor(i) + val property = generateSchemaProperty(elementDescriptor, definitions) + properties[name] = property + if (!descriptor.isElementOptional(i)) { + required.add(name) + } + } + + return JsonSchema( + type = "object", + properties = properties, + required = required, + definitions = if (definitions.isNotEmpty()) definitions else null + ) +} + +@OptIn(ExperimentalSerializationApi::class) +private fun generateSchemaProperty( + descriptor: SerialDescriptor, + definitions: MutableMap +): JsonSchemaProperty { + return when (descriptor.kind) { + PrimitiveKind.STRING -> JsonSchemaProperty.STRING + PrimitiveKind.INT, PrimitiveKind.LONG -> JsonSchemaProperty.INTEGER + PrimitiveKind.FLOAT, PrimitiveKind.DOUBLE -> JsonSchemaProperty.NUMBER + PrimitiveKind.BOOLEAN -> JsonSchemaProperty.BOOLEAN + SerialKind.ENUM -> enumProperty(descriptor) + StructureKind.LIST -> JsonSchemaProperty( + type = "array", + items = generateSchemaProperty( + descriptor.getElementDescriptor(0), + definitions + ) + ) + StructureKind.MAP -> JsonSchemaProperty("object") + StructureKind.CLASS -> { + val refName = descriptor.serialName.trimEnd('?') + definitions[refName] = generateSchema(descriptor) + JsonSchemaProperty("\$ref", ref = "#/definitions/$refName") + } + else -> JsonSchemaProperty("object") // Default case + } +} + +private fun enumProperty( + descriptor: SerialDescriptor +) = JsonSchemaProperty( + enum = descriptor.elementNames() +) + +@OptIn(ExperimentalSerializationApi::class) +private fun SerialDescriptor.elementNames(): List = buildList { + for (i in 0 until elementsCount) { + val name = getElementName(i) + add(name) + } +} diff --git a/src/commonTest/kotlin/AnthropicTest.kt b/src/commonTest/kotlin/AnthropicTest.kt index f0738e1..c65ea0e 100644 --- a/src/commonTest/kotlin/AnthropicTest.kt +++ b/src/commonTest/kotlin/AnthropicTest.kt @@ -11,17 +11,12 @@ import com.xemantic.anthropic.message.Text import com.xemantic.anthropic.message.Tool import com.xemantic.anthropic.message.ToolChoice import com.xemantic.anthropic.message.ToolUse +import com.xemantic.anthropic.schema.jsonSchemaOf import kotlinx.coroutines.flow.filterIsInstance import kotlinx.coroutines.flow.map import kotlinx.coroutines.flow.toList import kotlinx.coroutines.test.runTest -import kotlinx.serialization.json.add -import kotlinx.serialization.json.buildJsonObject -import kotlinx.serialization.json.double -import kotlinx.serialization.json.jsonObject -import kotlinx.serialization.json.jsonPrimitive -import kotlinx.serialization.json.put -import kotlinx.serialization.json.putJsonArray +import kotlinx.serialization.Serializable import kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.assertNull @@ -107,58 +102,35 @@ class AnthropicTest { assertTrue(response == "The quick brown fox jumps over the lazy dog.") } - @Test - fun shouldUseSimpleTool() = runTest { - // given - val client = Anthropic() - - - // when - val response = client.messages.stream { - +Message { - role = Role.USER - +"Say: 'The quick brown fox jumps over the lazy dog'" - } + // given + @Serializable + data class Calculator( + val operation: Operation, + val a: Double, + val b: Double + ) { + + enum class Operation( + val calculate: (a: Double, b: Double) -> Double + ) { + ADD({ a, b -> a + b }), + SUBTRACT({ a, b -> a - b }), + MULTIPLY({ a, b -> a * b }), + DIVIDE({ a, b -> a / b }) } - .filterIsInstance() - .map { (it.delta as TextDelta).text } - .toList() - .joinToString(separator = "") - // then - assertTrue(response == "The quick brown fox jumps over the lazy dog.") + fun calculate() = operation.calculate(a, b) + } @Test fun shouldUseCalculatorTool() = runTest { // given val client = Anthropic() - // soon the Tool will use generic serializable type and the schema - // will be generated automatically val calculatorTool = Tool( name = "calculator", description = "Perform basic arithmetic operations", - inputSchema = buildJsonObject { - put("type", "object") - put("properties", buildJsonObject { - put("operation", buildJsonObject { - put("type", "string") - putJsonArray("enum") { - add("add") - add("subtract") - add("multiply") - add("divide") - } - }) - put("a", buildJsonObject { put("type", "number") }) - put("b", buildJsonObject { put("type", "number") }) - }) - putJsonArray("required") { - add("operation") - add("a") - add("b") - } - }, + inputSchema = jsonSchemaOf(), cacheControl = null ) @@ -177,10 +149,9 @@ class AnthropicTest { assertTrue(content[0] is ToolUse) val toolUse = content[0] as ToolUse assertTrue(toolUse.name == "calculator") - val input = toolUse.input.jsonObject - assertTrue(input["operation"]?.jsonPrimitive?.content == "multiply") - assertTrue(input["a"]?.jsonPrimitive?.double == 15.0) - assertTrue(input["b"]?.jsonPrimitive?.double == 7.0) + val calculator = toolUse.input() + val result = calculator.calculate() + assertTrue(result == 15.0 * 7.0) } } diff --git a/src/commonTest/kotlin/message/MessagesTest.kt b/src/commonTest/kotlin/message/MessagesTest.kt index 3da077c..e5e9444 100644 --- a/src/commonTest/kotlin/message/MessagesTest.kt +++ b/src/commonTest/kotlin/message/MessagesTest.kt @@ -1,7 +1,8 @@ package com.xemantic.anthropic.message -import com.xemantic.anthropic.anthropicJsonConfigurer +import com.xemantic.anthropic.anthropicJson import io.kotest.assertions.json.shouldEqualJson +import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.encodeToString import kotlinx.serialization.json.Json import kotlin.test.Test @@ -14,9 +15,10 @@ class MessagesTest { /** * A pretty JSON printing for testing. */ - private val json = Json { - anthropicJsonConfigurer() + private val json = Json(from = anthropicJson) { prettyPrint = true + @OptIn(ExperimentalSerializationApi::class) + prettyPrintIndent = " " } @Test diff --git a/src/commonTest/kotlin/schema/JsonSchemaGeneratorTest.kt b/src/commonTest/kotlin/schema/JsonSchemaGeneratorTest.kt new file mode 100644 index 0000000..bb5c8e7 --- /dev/null +++ b/src/commonTest/kotlin/schema/JsonSchemaGeneratorTest.kt @@ -0,0 +1,130 @@ +package com.xemantic.anthropic.schema + +import com.xemantic.anthropic.anthropicJson +import io.kotest.assertions.json.shouldEqualJson +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.Serializable +import kotlinx.serialization.encodeToString +import kotlinx.serialization.json.Json +import kotlin.test.Test + +@Serializable +data class Address( + val street: String? = null, + val city: String? = null, + val zipCode: String, + val country: String +) + +@Serializable +data class Person( + val name: String, + val age: Int, + val email: String?, + val hobbies: List = emptyList(), + val address: Address? = null +) + +class JsonSchemaGeneratorTest { + + private val json = Json(from = anthropicJson) { + prettyPrint = true + @OptIn(ExperimentalSerializationApi::class) + prettyPrintIndent = " " + } + + @Test + fun generateJsonSchemaForAddress() { + + // when + val schema = jsonSchemaOf
() + val schemaJson = json.encodeToString(schema) + + // then + schemaJson shouldEqualJson """ + { + "properties": { + "street": { + "type": "string" + }, + "city": { + "type": "string" + }, + "zipCode": { + "type": "string" + }, + "country": { + "type": "string" + } + }, + "required": [ + "zipCode", + "country" + ] + } + """.trimIndent() + } + + @Test + fun generateSchemaForJson() { + // when + val schema = jsonSchemaOf() + val schemaJson = json.encodeToString(schema) + + // then + print(schemaJson) + schemaJson shouldEqualJson """ + { + "definitions": { + "com.xemantic.anthropic.schema.Address": { + "properties": { + "street": { + "type": "string" + }, + "city": { + "type": "string" + }, + "zipCode": { + "type": "string" + }, + "country": { + "type": "string" + } + }, + "required": [ + "zipCode", + "country" + ] + } + }, + "properties": { + "name": { + "type": "string" + }, + "age": { + "type": "integer" + }, + "email": { + "type": "string" + }, + "hobbies": { + "type": "array", + "items": { + "type": "string" + } + }, + "address": { + "type": "${'$'}ref", + "ref": "#/definitions/com.xemantic.anthropic.schema.Address" + } + }, + "required": [ + "name", + "age", + "email" + ] + } + """.trimIndent() + } + +}