Skip to content

Commit

Permalink
logging, safe tool usage, retry on failed requests
Browse files Browse the repository at this point in the history
  • Loading branch information
morisil committed Oct 11, 2024
1 parent a0b5a9a commit d0c60d7
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 128 deletions.
1 change: 1 addition & 0 deletions build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ kotlin {
implementation(libs.ktor.client.content.negotiation)
implementation(libs.ktor.client.logging)
implementation(libs.ktor.serialization.kotlinx.json)
implementation(libs.kotlin.logging)
}
}

Expand Down
11 changes: 7 additions & 4 deletions gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ kotlinxCoroutines = "1.9.0"
ktor = "3.0.0"
kotest = "6.0.0.M1"

kotlinLogging = "7.0.0"
log4j = "2.24.1"
jackson = "2.18.0"

Expand All @@ -18,10 +19,12 @@ publishPlugin = "2.0.0"
kotlin-test = { module = "org.jetbrains.kotlin:kotlin-test", version.ref = "kotlin" }
kotlinx-coroutines-test = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-test", version.ref = "kotlinxCoroutines" }

log4j-slf4j2 = { group = "org.apache.logging.log4j", name = "log4j-slf4j2-impl", version.ref = "log4j" }
log4j-core = { group = "org.apache.logging.log4j", name = "log4j-core", version.ref = "log4j" }
jackson-databind = { group = "com.fasterxml.jackson.core", name = "jackson-databind", version.ref = "jackson" }
jackson-dataformat-yaml = { group = "com.fasterxml.jackson.dataformat", name = "jackson-dataformat-yaml", version.ref = "jackson" }
# logging libs
kotlin-logging = { module = "io.github.oshai:kotlin-logging-jvm", version.ref = "kotlinLogging" }
log4j-slf4j2 = { module = "org.apache.logging.log4j:log4j-slf4j2-impl", version.ref = "log4j" }
log4j-core = { module = "org.apache.logging.log4j:log4j-core", version.ref = "log4j" }
jackson-databind = { module = "com.fasterxml.jackson.core:jackson-databind", version.ref = "jackson" }
jackson-dataformat-yaml = { module = "com.fasterxml.jackson.dataformat:jackson-dataformat-yaml", version.ref = "jackson" }

ktor-client-core = { module = "io.ktor:ktor-client-core", version.ref = "ktor" }
ktor-client-content-negotiation = { module = "io.ktor:ktor-client-content-negotiation", version.ref = "ktor" }
Expand Down
45 changes: 27 additions & 18 deletions src/commonMain/kotlin/Anthropic.kt
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ import com.xemantic.anthropic.tool.UsableTool
import com.xemantic.anthropic.tool.toolOf
import io.ktor.client.HttpClient
import io.ktor.client.call.body
import io.ktor.client.plugins.*
import io.ktor.client.plugins.contentnegotiation.ContentNegotiation
import io.ktor.client.plugins.defaultRequest
import io.ktor.client.plugins.logging.LogLevel
import io.ktor.client.plugins.logging.Logging
import io.ktor.client.plugins.sse.SSE
Expand All @@ -33,12 +33,20 @@ import kotlinx.coroutines.flow.map
import kotlinx.serialization.KSerializer
import kotlinx.serialization.json.Json
import kotlinx.serialization.serializer
import kotlin.reflect.KClass

/**
* The default Anthropic API base.
*/
const val ANTHROPIC_API_BASE: String = "https://api.anthropic.com/"

/**
* The default version to be passed to the `anthropic-version` HTTP header of each API request.
*/
const val DEFAULT_ANTHROPIC_VERSION: String = "2023-06-01"

/**
* An exception thrown when API requests returns error.
*/
class AnthropicException(
error: Error,
httpStatusCode: HttpStatusCode
Expand All @@ -48,17 +56,20 @@ expect val envApiKey: String?

expect val missingApiKeyMessage: String

/**
* A JSON format suitable for communication with Anthropic API.
*/
val anthropicJson: Json = Json {
allowSpecialFloatingPointValues = true
explicitNulls = false
encodeDefaults = true
// serializersModule = SerializersModule {
// //contextual(UsableTool::class, UsableToolSerializer::class)
//// polymorphic(UsableTool::class) {
//// }
// }
}

/**
* The public constructor function which for the Anthropic API client.
*
* @param block the config block to set up the API access.
*/
fun Anthropic(
block: Anthropic.Config.() -> Unit = {}
): Anthropic {
Expand Down Expand Up @@ -110,21 +121,11 @@ class Anthropic internal constructor(
internal class ToolEntry<T : UsableTool>(
val tool: Tool, // TODO, no cache control
val serializer: KSerializer<T>,
val initializer: T.() -> Unit = {}
val initialize: T.() -> Unit = {}
)

internal var toolEntryMap = mapOf<String, ToolEntry<UsableTool>>()

// var usableTools: List<KClass<out Tool>> = emptyList()
// set(value) {
// toolMap += mapOf(value)
// field = value
// }

inline fun <reified T : UsableTool> tool() {
//usableTools += T::class
}

private val client = HttpClient {
install(ContentNegotiation) {
json(anthropicJson)
Expand All @@ -133,6 +134,14 @@ class Anthropic internal constructor(
install(Logging) {
level = LogLevel.BODY
}
install(HttpRequestRetry) {
retryOnServerErrors(maxRetries = 5)
exponentialDelay()
maxRetries = 5
retryIf { _, response ->
response.status == HttpStatusCode.TooManyRequests
}
}
defaultRequest {
url(apiBase)
header("x-api-key", apiKey)
Expand Down
128 changes: 22 additions & 106 deletions src/commonMain/kotlin/message/Messages.kt
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,11 @@ import com.xemantic.anthropic.Anthropic
import com.xemantic.anthropic.anthropicJson
import com.xemantic.anthropic.schema.JsonSchema
import com.xemantic.anthropic.tool.UsableTool
import io.github.oshai.kotlinlogging.KotlinLogging
import kotlinx.serialization.*
import kotlinx.serialization.builtins.serializer
import kotlinx.serialization.descriptors.PolymorphicKind
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.descriptors.buildSerialDescriptor
import kotlinx.serialization.encoding.Decoder
import kotlinx.serialization.encoding.Encoder
import kotlinx.serialization.json.JsonClassDiscriminator
import kotlinx.serialization.json.JsonDecoder
import kotlinx.serialization.json.JsonObject
import kotlin.collections.mutableListOf
import kotlin.reflect.KClass

enum class Role {
@SerialName("user")
Expand Down Expand Up @@ -45,7 +38,9 @@ data class MessageRequest(
@SerialName("tool_choice")
val toolChoice: ToolChoice?,
val tools: List<Tool>?,
@SerialName("top_k")
val topK: Int?,
@SerialName("top_p")
val topP: Int?
) {

Expand Down Expand Up @@ -285,6 +280,8 @@ data class ToolUse(
val input: JsonObject
) : Content() {

private val logger = KotlinLogging.logger {}

@Transient
internal lateinit var toolEntry: Anthropic.ToolEntry<UsableTool>

Expand All @@ -293,8 +290,23 @@ data class ToolUse(
deserializer = toolEntry.serializer,
element = input
)
toolEntry.initializer(tool)
return tool.use(toolUseId = id)
val result = try {
toolEntry.initialize(tool)
logger.debug { "[$name:$id] Using tool" }
tool.use(toolUseId = id)
} catch (e: Exception) {
logger.error(e) { "[$name:$id] Tool use error: ${e.message}" }
ToolResult(
toolUseId = id,
isError = true,
content = listOf(
Text(
text = e.message ?: "Unknown error occurred"
)
)
)
}
return result
}

}
Expand Down Expand Up @@ -374,99 +386,3 @@ data class Usage(
@SerialName("output_tokens")
val outputTokens: Int
)


interface CacheableBuilder {

var cacheControl: CacheControl?

var cache: Boolean
get() = cacheControl != null
set(value) {
if (value) {
cacheControl = CacheControl(type = CacheControl.Type.EPHEMERAL)
} else {
cacheControl = null
}
}

}

//class UsableToolSerializer : JsonContentPolymorphicSerializer2<UsableTool>(UsableTool::class) {
//
//// override val descriptor: SerialDescriptor = buildClassSerialDescriptor("UsableTool") {
//// element<String>("type")
//// element<JsonElement>("data")
//// }
////
//// override fun serialize(
//// encoder: Encoder,
//// value: UsableTool
//// ) {
////// val polymorphic: SerializationStrategy<String> = serializersModule.getPolymorphic(UsableTool::class, "foo")
//// PolymorphicSerializer(UsableTool::class)
////// encoder.encodeString(value)
////// encoder.encodeString(value.name)
////// polymorphic.seri
////// encoder.encodeSerializableValue(polymorphic)
//// }
////
//// override fun deserialize(decoder: Decoder): UsableTool {
//// require(decoder is JsonDecoder) { "This serializer can be used only with Json format" }
//// val name = decoder.decodeString()
//// val polymorphic = decoder.serializersModule.getPolymorphic(UsableTool::class, name)
//// val id = decoder.decodeString()
//// return DummyUsableTool()
//// }
//
// override fun selectDeserializer(element: JsonElement): DeserializationStrategy<UsableTool> {
// println(element)
// TODO("dupa dupa Not yet implemented")
// }
//
//}


//class UsableToolSerializer : JsonContentPolymorphicSerializer2<UsableTool>(
// UsableTool::class
//)

@OptIn(InternalSerializationApi::class, ExperimentalSerializationApi::class)
open class JsonContentPolymorphicSerializer2<T : Any>(private val baseClass: KClass<T>) : KSerializer<T> {
/**
* A descriptor for this set of content-based serializers.
* By default, it uses the name composed of [baseClass] simple name,
* kind is set to [PolymorphicKind.SEALED] and contains 0 elements.
*
* However, this descriptor can be overridden to achieve better representation of custom transformed JSON shape
* for schema generating/introspection purposes.
*/
override val descriptor: SerialDescriptor =
buildSerialDescriptor("JsonContentPolymorphicSerializer<${baseClass.simpleName}>", PolymorphicKind.SEALED)

final override fun serialize(encoder: Encoder, value: T) {
val actualSerializer =
encoder.serializersModule.getPolymorphic(baseClass, value)
?: value::class.serializerOrNull()
?: throw SerializationException("fiu fiu")
@Suppress("UNCHECKED_CAST")
(actualSerializer as KSerializer<T>).serialize(encoder, value)
}

final override fun deserialize(decoder: Decoder): T {
val input = decoder.asJsonDecoder()
input.json.serializersModule.getPolymorphic(UsableTool::class, "foo")
val tree = input.decodeJsonElement()

@Suppress("UNCHECKED_CAST")
val actualSerializer = String.serializer() as KSerializer<T>
return input.json.decodeFromJsonElement(actualSerializer, tree)
}

}

internal fun Decoder.asJsonDecoder(): JsonDecoder = this as? JsonDecoder
?: throw IllegalStateException(
"This serializer can be used only with Json format." +
"Expected Decoder to be JsonDecoder, got ${this::class}"
)

0 comments on commit d0c60d7

Please sign in to comment.