Skip to content

Commit

Permalink
Improves polymorphic serialization + tests (#561)
Browse files Browse the repository at this point in the history
* Improves polymorphic serialization + tests

* Completes the tests and models

* Uses kotlin test runtime

* Fixes compilation errors
  • Loading branch information
fedefernandez authored Dec 5, 2023
1 parent 6a88611 commit 5d4f9e9
Show file tree
Hide file tree
Showing 32 changed files with 1,100 additions and 311 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@ package com.xebia.functional.xef.prompt

import ai.xef.openai.OpenAIModel
import com.xebia.functional.openai.models.ChatCompletionRole
import com.xebia.functional.openai.models.ext.chat.ChatCompletionRequestMessage
import com.xebia.functional.openai.models.ext.chat.ChatCompletionRequestUserMessageContentText
import com.xebia.functional.openai.models.ext.chat.*
import com.xebia.functional.xef.prompt.templates.assistant
import com.xebia.functional.xef.prompt.templates.system
import com.xebia.functional.xef.prompt.templates.user
Expand Down Expand Up @@ -69,25 +68,16 @@ interface PromptBuilder<T> {

fun String.message(role: ChatCompletionRole): ChatCompletionRequestMessage =
when (role) {
ChatCompletionRole.system ->
ChatCompletionRequestMessage.ChatCompletionRequestSystemMessage(this)
ChatCompletionRole.system -> ChatCompletionRequestSystemMessage(this)
ChatCompletionRole.user ->
ChatCompletionRequestMessage.ChatCompletionRequestUserMessage(
listOf(
ChatCompletionRequestUserMessageContentText(
ChatCompletionRequestUserMessageContentText.Type.text,
this
)
)
)
ChatCompletionRole.assistant ->
ChatCompletionRequestMessage.ChatCompletionRequestAssistantMessage(this)
ChatCompletionRequestUserMessage(listOf(ChatCompletionRequestUserMessageContentText(this)))
ChatCompletionRole.assistant -> ChatCompletionRequestAssistantMessage(this)
ChatCompletionRole.tool ->
// TODO - Tool Id?
ChatCompletionRequestMessage.ChatCompletionRequestToolMessage(this, "toolId")
ChatCompletionRequestToolMessage(this, "toolId")
ChatCompletionRole.function ->
// TODO - Function name?
ChatCompletionRequestMessage.ChatCompletionRequestFunctionMessage(this, "functionName")
ChatCompletionRequestFunctionMessage(this, "functionName")
}

// TODO this fails because of the ChatCompletionRequestMessage role fixed to function in the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@ package com.xebia.functional.xef.store

import com.xebia.functional.openai.models.ChatCompletionResponseMessage
import com.xebia.functional.openai.models.ChatCompletionRole
import com.xebia.functional.openai.models.ext.chat.ChatCompletionRequestMessage
import com.xebia.functional.openai.models.ext.chat.ChatCompletionRequestUserMessageContentText
import com.xebia.functional.openai.models.ext.chat.*

sealed class MemorizedMessage {
val role: ChatCompletionRole
Expand All @@ -16,8 +15,7 @@ sealed class MemorizedMessage {
fun asRequestMessage(): ChatCompletionRequestMessage =
when (this) {
is Request -> message
is Response ->
ChatCompletionRequestMessage.ChatCompletionRequestAssistantMessage(message.content)
is Response -> ChatCompletionRequestAssistantMessage(message.content)
}

data class Request(val message: ChatCompletionRequestMessage) : MemorizedMessage()
Expand All @@ -28,18 +26,11 @@ sealed class MemorizedMessage {
fun memorizedMessage(role: ChatCompletionRole, content: String): MemorizedMessage =
when (role) {
ChatCompletionRole.system ->
MemorizedMessage.Request(
ChatCompletionRequestMessage.ChatCompletionRequestSystemMessage(content)
)
MemorizedMessage.Request(ChatCompletionRequestSystemMessage(content))
ChatCompletionRole.user ->
MemorizedMessage.Request(
ChatCompletionRequestMessage.ChatCompletionRequestUserMessage(
listOf(
ChatCompletionRequestUserMessageContentText(
ChatCompletionRequestUserMessageContentText.Type.text,
content
)
)
ChatCompletionRequestUserMessage(
listOf(ChatCompletionRequestUserMessageContentText(content))
)
)
ChatCompletionRole.assistant ->
Expand All @@ -51,14 +42,14 @@ fun memorizedMessage(role: ChatCompletionRole, content: String): MemorizedMessag
)
ChatCompletionRole.tool ->
MemorizedMessage.Request(
ChatCompletionRequestMessage.ChatCompletionRequestToolMessage(
ChatCompletionRequestToolMessage(
content = content,
toolCallId = "fake-tool-call-id" // TODO we are not storing the tool id with the content
)
)
ChatCompletionRole.function ->
MemorizedMessage.Request(
ChatCompletionRequestMessage.ChatCompletionRequestToolMessage(
ChatCompletionRequestToolMessage(
content = content,
toolCallId = "fake-tool-call-id" // TODO we are not storing the tool id with the content
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ package com.xebia.functional.xef.conversation
import ai.xef.openai.StandardModel
import com.xebia.functional.openai.models.ChatCompletionRole
import com.xebia.functional.openai.models.CreateChatCompletionRequestModel
import com.xebia.functional.openai.models.ext.chat.ChatCompletionRequestAssistantMessage
import com.xebia.functional.openai.models.ext.chat.ChatCompletionRequestMessage.*
import com.xebia.functional.openai.models.ext.chat.ChatCompletionRequestUserMessage
import com.xebia.functional.openai.models.ext.chat.ChatCompletionRequestUserMessageContentText
import com.xebia.functional.xef.data.*
import com.xebia.functional.xef.llm.models.modelType
Expand Down Expand Up @@ -81,15 +83,9 @@ class ConversationSpec :
messages.flatMap {
listOf(
ChatCompletionRequestUserMessage(
content =
listOf(
ChatCompletionRequestUserMessageContentText(
ChatCompletionRequestUserMessageContentText.Type.text,
it.key
)
)
listOf(ChatCompletionRequestUserMessageContentText(it.key))
),
ChatCompletionRequestAssistantMessage(content = it.value),
ChatCompletionRequestAssistantMessage(it.value),
)
}
)
Expand Down Expand Up @@ -134,15 +130,9 @@ class ConversationSpec :
messages.flatMap {
listOf(
ChatCompletionRequestUserMessage(
content =
listOf(
ChatCompletionRequestUserMessageContentText(
ChatCompletionRequestUserMessageContentText.Type.text,
it.key
)
)
listOf(ChatCompletionRequestUserMessageContentText(it.key))
),
ChatCompletionRequestAssistantMessage(content = it.value)
ChatCompletionRequestAssistantMessage(it.value)
)
}
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
package com.xebia.functional.xef.store

import arrow.atomic.AtomicInt
import com.xebia.functional.openai.models.ext.chat.ChatCompletionRequestMessage
import com.xebia.functional.openai.models.ext.chat.ChatCompletionRequestAssistantMessage
import com.xebia.functional.openai.models.ext.chat.ChatCompletionRequestUserMessage
import com.xebia.functional.openai.models.ext.chat.ChatCompletionRequestUserMessageContentText

class MemoryData {
Expand All @@ -16,18 +17,14 @@ class MemoryData {
): List<Memory> =
(0 until n).flatMap {
val m1 =
ChatCompletionRequestMessage.ChatCompletionRequestUserMessage(
ChatCompletionRequestUserMessage(
listOf(
ChatCompletionRequestUserMessageContentText(
ChatCompletionRequestUserMessageContentText.Type.text,
"Question $it${append?.let { ": $it" } ?: ""}"
)
)
)
val m2 =
ChatCompletionRequestMessage.ChatCompletionRequestAssistantMessage(
"Response $it${append?.let { ": $it" } ?: ""}"
)
val m2 = ChatCompletionRequestAssistantMessage("Response $it${append?.let { ": $it" } ?: ""}")
listOf(
Memory(conversationId, MemorizedMessage.Request(m1), atomicInt.addAndGet(1)),
Memory(conversationId, MemorizedMessage.Request(m2), atomicInt.addAndGet(1)),
Expand Down
9 changes: 3 additions & 6 deletions integrations/postgresql/src/test/kotlin/xef/MemoryData.kt
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
package xef

import arrow.atomic.AtomicInt
import com.xebia.functional.openai.models.ext.chat.ChatCompletionRequestMessage
import com.xebia.functional.openai.models.ext.chat.ChatCompletionRequestUserMessageContent
import com.xebia.functional.openai.models.ext.chat.ChatCompletionRequestUserMessageContentText
import com.xebia.functional.openai.models.ext.chat.*
import com.xebia.functional.xef.store.ConversationId
import com.xebia.functional.xef.store.MemorizedMessage
import com.xebia.functional.xef.store.Memory
Expand All @@ -19,15 +17,14 @@ class MemoryData {
conversationId: ConversationId = defaultConversationId
): List<Memory> =
(0 until n).flatMap {
val m1 = ChatCompletionRequestMessage.ChatCompletionRequestUserMessage(
val m1 = ChatCompletionRequestUserMessage(
listOf(
ChatCompletionRequestUserMessageContentText(
ChatCompletionRequestUserMessageContentText.Type.text,
"Question $it${append?.let { ": $it" } ?: ""}"
)
)
)
val m2 = ChatCompletionRequestMessage.ChatCompletionRequestAssistantMessage("Response $it${append?.let { ": $it" } ?: ""}")
val m2 = ChatCompletionRequestAssistantMessage("Response $it${append?.let { ": $it" } ?: ""}")
listOf(
Memory(conversationId, MemorizedMessage.Request(m1), atomicInt.addAndGet(1)),
Memory(conversationId, MemorizedMessage.Request(m2), atomicInt.addAndGet(1)),
Expand Down
6 changes: 3 additions & 3 deletions openai-client/client/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ kotlin {
}
val commonTest by getting {
dependencies {
implementation(kotlin("test"))
implementation(libs.kotest.property)
implementation(libs.kotest.framework)
implementation(libs.kotest.assertions)
}
}
Expand All @@ -82,8 +82,8 @@ kotlin {
api(libs.ktor.client.cio)
}
}
val jvmTest by getting { dependencies { implementation(libs.kotest.junit5) } }
val jsMain by getting { dependencies { api(libs.ktor.client.js) } }
val jvmTest by getting { dependencies { implementation(libs.kotest.junit5) } }
val linuxX64Main by getting { dependencies { api(libs.ktor.client.cio) } }
val macosX64Main by getting { dependencies { api(libs.ktor.client.cio) } }
val macosArm64Main by getting { dependencies { api(libs.ktor.client.cio) } }
Expand Down Expand Up @@ -112,7 +112,7 @@ kotlin {
spotless {
kotlin {
target("**/*.kt")
ktfmt().googleStyle()
ktfmt().googleStyle().configure { it.setRemoveUnusedImport(true) }
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
package com.xebia.functional.openai.models.ext.assistant

import com.xebia.functional.openai.models.FunctionObject
import kotlinx.serialization.DeserializationStrategy
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import kotlinx.serialization.*
import kotlinx.serialization.json.*

@Serializable(with = AssistantTools.MyTypeSerializer::class)
Expand All @@ -19,28 +16,3 @@ sealed interface AssistantTools {
}
}
}

@Serializable
data class AssistantToolsCode(val type: Type = Type.code_interpreter) : AssistantTools {
@Serializable
enum class Type(val value: String) {
@SerialName(value = "code_interpreter") code_interpreter("code_interpreter")
}
}

@Serializable
data class AssistantToolsRetrieval(val type: Type = Type.retrieval) : AssistantTools {
@Serializable
enum class Type(val value: String) {
@SerialName(value = "retrieval") retrieval("retrieval")
}
}

@Serializable
data class AssistantToolsFunction(val type: Type = Type.function, val function: FunctionObject) :
AssistantTools {
@Serializable
enum class Type(val value: String) {
@SerialName(value = "function") function("function")
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package com.xebia.functional.openai.models.ext.assistant

import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable

@Serializable
data class AssistantToolsCode(val type: Type) : AssistantTools {
constructor() : this(Type.code_interpreter)

@Serializable
enum class Type(val value: String) {
@SerialName(value = "code_interpreter") code_interpreter("code_interpreter");

override fun toString(): String {
return value
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package com.xebia.functional.openai.models.ext.assistant

import com.xebia.functional.openai.models.FunctionObject
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable

@Serializable
data class AssistantToolsFunction(val function: FunctionObject, val type: Type) : AssistantTools {
constructor(function: FunctionObject) : this(function, Type.function)

@Serializable
enum class Type(val value: String) {
@SerialName(value = "function") function("function")
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package com.xebia.functional.openai.models.ext.assistant

import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable

@Serializable
data class AssistantToolsRetrieval(val type: Type) : AssistantTools {
constructor() : this(Type.retrieval)

@Serializable
enum class Type(val value: String) {
@SerialName(value = "retrieval") retrieval("retrieval")
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package com.xebia.functional.openai.models.ext.assistant

import com.xebia.functional.openai.models.RunStepDetailsMessageCreationObjectMessageCreation
import kotlinx.serialization.Required
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable

@Serializable
data class RunStepDetailsMessageCreationObject(
@SerialName(value = "message_creation")
@Required
val messageCreation: RunStepDetailsMessageCreationObjectMessageCreation,
/* Always `message_creation``. */
@SerialName(value = "type") @Required val type: Type
) : RunStepObjectStepDetails {

constructor(
messageCreation: RunStepDetailsMessageCreationObjectMessageCreation
) : this(messageCreation, Type.message_creation)

/**
* Always `message_creation``.
*
* Values: message_creation
*/
@Serializable
enum class Type(val value: String) {
@SerialName(value = "message_creation") message_creation("message_creation")
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package com.xebia.functional.openai.models.ext.assistant

import com.xebia.functional.openai.models.RunStepDetailsToolCallsObjectToolCallsInner
import kotlinx.serialization.Required
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable

@Serializable
data class RunStepDetailsToolCallsObject(

/* An array of tool calls the run step was involved in. These can be associated with one of three types of tools: `code_interpreter`, `retrieval`, or `function`. */
@SerialName(value = "tool_calls")
@Required
val toolCalls: List<RunStepDetailsToolCallsObjectToolCallsInner>,

/* Always `tool_calls`. */
@SerialName(value = "type") @Required val type: Type = Type.tool_calls
) : RunStepObjectStepDetails {

constructor(
toolCalls: List<RunStepDetailsToolCallsObjectToolCallsInner>
) : this(toolCalls, Type.tool_calls)

/**
* Always `tool_calls`.
*
* Values: tool_calls
*/
@Serializable
enum class Type(val value: String) {
@SerialName(value = "tool_calls") tool_calls("tool_calls")
}
}
Loading

0 comments on commit 5d4f9e9

Please sign in to comment.