Skip to content

Commit

Permalink
added article summary with OpenAI integration (#399)
Browse files Browse the repository at this point in the history
* OpenAI Integration

* Make OpenAI section dialog

* Apply suggestions from code review

Co-authored-by: Jonas Kalderstam <[email protected]>

* Fix refresh models uses stored settings, fix ui jumping

* ktlint format

---------

Co-authored-by: Jonas Kalderstam <[email protected]>
  • Loading branch information
anod and spacecowboy authored Nov 16, 2024
1 parent 63288ad commit ca13fd4
Show file tree
Hide file tree
Showing 13 changed files with 842 additions and 3 deletions.
3 changes: 3 additions & 0 deletions app/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -214,13 +214,16 @@ dependencies {
implementation(platform(libs.okhttp.bom))
implementation(platform(libs.coil.bom))
implementation(platform(libs.compose.bom))
implementation(platform(libs.openai.client.bom))

// Dependencies
implementation(libs.bundles.android)
implementation(libs.bundles.compose)
implementation(libs.bundles.jvm)
implementation(libs.bundles.okhttp.android)
implementation(libs.bundles.kotlin)
implementation(libs.openai.client)
implementation(libs.ktor.client.okhttp)

// Only for debug
debugImplementation("com.squareup.leakcanary:leakcanary-android:3.0-alpha-1")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,10 @@ class Repository(override val di: DI) : DIAware {
sessionStore.setResumeTime(value)
}

val openAISettings = settingsStore.openAiSettings

fun setOpenAiSettings(value: OpenAISettings) = settingsStore.setOpenAiSettings(value)

val showTitleUnreadCount = settingsStore.showTitleUnreadCount

fun setShowTitleUnreadCount(value: Boolean) = settingsStore.setShowTitleUnreadCount(value)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,29 @@ class SettingsStore(override val di: DI) : DIAware {
}
}

private val _openAiSettings =
MutableStateFlow(
OpenAISettings(
key = sp.getStringNonNull(PREF_OPENAI_KEY, ""),
modelId = sp.getStringNonNull(PREF_OPENAI_MODEL_ID, "gpt-4o-mini"),
baseUrl = sp.getStringNonNull(PREF_OPENAI_URL, ""),
azureApiVersion = sp.getStringNonNull(PREF_OPENAI_AZURE_VERSION, ""),
azureDeploymentId = sp.getStringNonNull(PREF_OPENAI_AZURE_DEPLOYMENT_ID, ""),
),
)
val openAiSettings = _openAiSettings.asStateFlow()

fun setOpenAiSettings(value: OpenAISettings) {
_openAiSettings.value = value
sp.edit()
.putString(PREF_OPENAI_KEY, value.key)
.putString(PREF_OPENAI_MODEL_ID, value.modelId)
.putString(PREF_OPENAI_URL, value.baseUrl)
.putString(PREF_OPENAI_AZURE_VERSION, value.azureApiVersion)
.putString(PREF_OPENAI_AZURE_DEPLOYMENT_ID, value.azureDeploymentId)
.apply()
}

private val _showTitleUnreadCount = MutableStateFlow(sp.getBoolean(PREF_SHOW_TITLE_UNREAD_COUNT, false))
val showTitleUnreadCount = _showTitleUnreadCount.asStateFlow()

Expand Down Expand Up @@ -586,6 +609,15 @@ const val PREF_LIST_SHOW_READING_TIME = "pref_show_reading_time"
*/
const val PREF_READALOUD_USE_DETECT_LANGUAGE = "pref_readaloud_detect_lang"

/**
* OpenAI integration
*/
const val PREF_OPENAI_KEY = "pref_openai_key"
const val PREF_OPENAI_MODEL_ID = "pref_openai_model_id"
const val PREF_OPENAI_URL = "pref_openai_url"
const val PREF_OPENAI_AZURE_VERSION = "pref_openai_azure_version"
const val PREF_OPENAI_AZURE_DEPLOYMENT_ID = "pref_openai_azure_deployment_id"

/**
* Appearance settings
*/
Expand Down Expand Up @@ -702,6 +734,14 @@ enum class SwipeAsRead(
FROM_ANYWHERE(R.string.from_anywhere),
}

data class OpenAISettings(
val modelId: String = "",
val baseUrl: String = "",
val azureApiVersion: String = "",
val azureDeploymentId: String = "",
val key: String = "",
)

fun String.dropEnds(
starting: Int,
ending: Int,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import com.nononsenseapps.feeder.base.bindWithActivityViewModelScope
import com.nononsenseapps.feeder.base.bindWithComposableViewModelScope
import com.nononsenseapps.feeder.model.OPMLParserHandler
import com.nononsenseapps.feeder.model.opml.OPMLImporter
import com.nononsenseapps.feeder.openai.OpenAIApi
import com.nononsenseapps.feeder.ui.CommonActivityViewModel
import com.nononsenseapps.feeder.ui.MainActivityViewModel
import com.nononsenseapps.feeder.ui.NavigationDeepLinkViewModel
Expand All @@ -22,7 +23,10 @@ import com.nononsenseapps.feeder.ui.compose.searchfeed.SearchFeedViewModel
import com.nononsenseapps.feeder.ui.compose.settings.SettingsViewModel
import org.kodein.di.DI
import org.kodein.di.bind
import org.kodein.di.compose.instance
import org.kodein.di.instance
import org.kodein.di.singleton
import java.util.Locale

val archModelModule =
DI.Module(name = "arch models") {
Expand All @@ -33,6 +37,7 @@ val archModelModule =
bind<FeedItemStore>() with singleton { FeedItemStore(di) }
bind<SyncRemoteStore>() with singleton { SyncRemoteStore(di) }
bind<OPMLParserHandler>() with singleton { OPMLImporter(di) }
bind<OpenAIApi>() with singleton { OpenAIApi(instance(), appLang = Locale.getDefault().getISO3Language()) }

bindWithActivityViewModelScope<MainActivityViewModel>()
bindWithActivityViewModelScope<OpenLinkInDefaultActivityViewModel>()
Expand Down
209 changes: 209 additions & 0 deletions app/src/main/java/com/nononsenseapps/feeder/openai/OpenAIApi.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
package com.nononsenseapps.feeder.openai

import com.aallam.openai.api.chat.ChatCompletionRequest
import com.aallam.openai.api.chat.ChatMessage
import com.aallam.openai.api.chat.ChatResponseFormat
import com.aallam.openai.api.chat.ChatRole
import com.aallam.openai.api.chat.TextContent
import com.aallam.openai.api.logging.LogLevel
import com.aallam.openai.api.model.ModelId
import com.aallam.openai.client.LoggingConfig
import com.aallam.openai.client.OpenAI
import com.aallam.openai.client.OpenAIConfig
import com.aallam.openai.client.OpenAIHost
import com.nononsenseapps.feeder.BuildConfig
import com.nononsenseapps.feeder.archmodel.OpenAISettings
import com.nononsenseapps.feeder.archmodel.Repository
import io.ktor.client.plugins.HttpSend
import io.ktor.client.plugins.plugin
import io.ktor.client.request.url
import io.ktor.http.URLBuilder
import io.ktor.http.appendPathSegments
import io.ktor.http.takeFrom
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.Json

private fun OpenAISettings.toOpenAIConfig(): OpenAIConfig =
OpenAIConfig(
token = key,
logging = LoggingConfig(logLevel = LogLevel.Headers, sanitize = !BuildConfig.DEBUG),
host = toOpenAIHost(withAzureDeploymentId = false),
httpClientConfig = {
if (isAzure) {
install(HttpSend)
install("azure-interceptor") {
plugin(HttpSend).intercept { request ->
request.headers.remove("Authorization")
request.headers.append("api-key", key)
// models path doesn't include azureDeploymentId
val path = request.url.pathSegments.takeLastWhile { it != "openai" || it.isEmpty() }
val url =
toOpenAIHost(withAzureDeploymentId = path.last() != "models")
.toUrl()
.appendPathSegments(path)
.build()
request.url(url)
execute(request)
}
}
}
},
)

class OpenAIApi(
private val repository: Repository,
private val appLang: String,
) {
@Serializable
data class SummaryResponse(val lang: String, val content: String)

sealed interface SummaryResult {
val content: String

data class Success(
val id: String,
val created: Long,
val model: String,
override val content: String,
val promptTokens: Int,
val completeTokens: Int,
val totalTokens: Int,
val detectedLanguage: String,
) : SummaryResult

data class Error(override val content: String) : SummaryResult
}

sealed interface ModelsResult {
data object MissingToken : ModelsResult

data object AzureApiVersionRequired : ModelsResult

data object AzureDeploymentIdRequired : ModelsResult

data class Success(val ids: List<String>) : ModelsResult

data class Error(val message: String?) : ModelsResult
}

private val openAISettings: OpenAISettings
get() = repository.openAISettings.value

private val openAI: OpenAI
get() = OpenAI(config = openAISettings.toOpenAIConfig())

suspend fun listModelIds(settings: OpenAISettings): ModelsResult {
if (settings.key.isEmpty()) {
return ModelsResult.MissingToken
}
if (settings.isAzure) {
if (settings.azureApiVersion.isBlank()) {
return ModelsResult.AzureApiVersionRequired
}
if (settings.azureDeploymentId.isBlank()) {
return ModelsResult.AzureDeploymentIdRequired
}
}
return try {
OpenAI(config = settings.toOpenAIConfig()).models()
.sortedByDescending { it.created }
.map { it.id.id }.let { ModelsResult.Success(it) }
} catch (e: Exception) {
ModelsResult.Error(message = e.message ?: e.cause?.message)
}
}

suspend fun summarize(content: String): SummaryResult {
try {
val response =
openAI.chatCompletion(
request = summaryRequest(content),
requestOptions = null,
)
val summaryResponse: SummaryResponse =
response.choices.firstOrNull()?.message?.content?.let { text ->
Json.decodeFromString(text)
} ?: throw IllegalStateException("Response content is null")

return SummaryResult.Success(
id = response.id,
model = response.model.id,
content = summaryResponse.content,
created = response.created,
promptTokens = response.usage?.promptTokens ?: 0,
completeTokens = response.usage?.completionTokens ?: 0,
totalTokens = response.usage?.completionTokens ?: 0,
detectedLanguage = summaryResponse.lang,
)
} catch (e: Exception) {
return SummaryResult.Error(content = e.message ?: e.cause?.message ?: "")
}
}

private fun summaryRequest(content: String): ChatCompletionRequest {
return ChatCompletionRequest(
model = ModelId(id = openAISettings.modelId),
messages =
listOf(
ChatMessage(
role = ChatRole.System,
messageContent =
TextContent(
listOf(
"You are an assistant in an RSS reader app, summarizing article content.",
"The app language is '$appLang'.",
"Provide summaries in the article's language if 99% recognizable; otherwise, use the app language.",
"Format response as JSON: { \"lang\": \"ISO code\", \"content\": \"summary\" }.",
"Keep summaries up to 100 words, 3 paragraphs, with up to 3 bullet points per paragraph.",
"For readability use bullet points, titles, quotes and new lines using plain text only.",
"Use only single language.",
"Keep full quotes if any.",
).joinToString(separator = " "),
),
),
ChatMessage(
role = ChatRole.User,
messageContent = TextContent("Summarize:\n\n$content"),
),
),
responseFormat = ChatResponseFormat.JsonObject,
)
}
}

val OpenAISettings.isAzure: Boolean
get() = baseUrl.contains("openai.azure.com", ignoreCase = true)

val OpenAISettings.isValid: Boolean
get() =
modelId.isNotEmpty() &&
key.isNotEmpty() &&
if (isAzure) azureApiVersion.isNotBlank() && azureDeploymentId.isNotBlank() else true

fun OpenAISettings.toOpenAIHost(withAzureDeploymentId: Boolean): OpenAIHost =
baseUrl.let { baseUrl ->
if (baseUrl.isEmpty()) {
OpenAIHost.OpenAI
} else {
OpenAIHost(
baseUrl =
URLBuilder()
.takeFrom(baseUrl).also {
it.appendPathSegments("openai")
if (withAzureDeploymentId && azureDeploymentId.isNotBlank()) {
it.appendPathSegments("deployments", azureDeploymentId)
}
}.buildString(),
queryParams =
azureApiVersion.let { apiVersion ->
if (apiVersion.isEmpty()) emptyMap() else mapOf("api-version" to apiVersion)
},
)
}
}

fun OpenAIHost.toUrl(): URLBuilder =
URLBuilder()
.takeFrom(baseUrl).also {
queryParams.forEach { (k, v) -> it.parameters.append(k, v) }
}
Loading

0 comments on commit ca13fd4

Please sign in to comment.