Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Embed Katago #173

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@
/.idea
.DS_Store
/build
.cxx
CMakeCache.txt
CMakeFiles
CMakeLists.txt
ThirdPartyTlsLibrary
/app/app.iml
/app/.externalNativeBuild
/app/.cxx
Expand Down
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[submodule "katago"]
path = engine/katago/src
url = https://github.com/lightvector/KataGo
5 changes: 5 additions & 0 deletions app/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@ android {
version "3.10.2"
}
}
sourceSets {
main {
assets.srcDirs += file("${project(":engine:katago").buildDir}/bin")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not quite, unfortunately

}
}
namespace 'io.zenandroid.onlinego'
packagingOptions {
resources.excludes.add("META-INF/*")
Expand Down
Binary file modified app/src/main/assets/katago.net
Binary file not shown.
118 changes: 80 additions & 38 deletions app/src/main/java/io/zenandroid/onlinego/ai/KataGoAnalysisEngine.kt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import com.google.firebase.crashlytics.FirebaseCrashlytics
import com.squareup.moshi.Moshi
import com.squareup.moshi.kotlin.reflect.KotlinJsonAdapterFactory
import io.reactivex.Single
import io.reactivex.subjects.PublishSubject
import io.zenandroid.onlinego.OnlineGoApplication
import io.zenandroid.onlinego.data.model.Position
import io.zenandroid.onlinego.data.model.StoneType
Expand All @@ -17,23 +16,35 @@ import io.zenandroid.onlinego.utils.recordException
import java.io.*
import java.util.*
import java.util.concurrent.atomic.AtomicLong
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.flow.SharedFlow
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.asSharedFlow
import kotlinx.coroutines.flow.filterNotNull
import kotlinx.coroutines.flow.filter
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.onCompletion
import kotlinx.coroutines.flow.onSubscription
import kotlinx.coroutines.flow.single
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.rx2.asFlowable

object KataGoAnalysisEngine {
var started = false
private set
var shouldShutDown = false
private set
var version = ""
private set
private var process: Process? = null
private var writer: OutputStreamWriter? = null
private var reader: BufferedReader? = null
private var requestIDX: AtomicLong = AtomicLong(0)
private val queryAdapter =
Moshi.Builder().add(KotlinJsonAdapterFactory()).build().adapter(Query::class.java)
private val responseAdapter =
Moshi.Builder().add(KotlinJsonAdapterFactory()).build().adapter(Response::class.java)
private val errorAdapter =
Moshi.Builder().add(KotlinJsonAdapterFactory()).build().adapter(ErrorResponse::class.java)
private val responseSubject = PublishSubject.create<KataGoResponse>()
private val moshi get() = Moshi.Builder().add(KotlinJsonAdapterFactory()).build()
private val queryAdapter = moshi.adapter(Query::class.java)
private val responseAdapter = moshi.adapter(Response::class.java)
private val errorAdapter = moshi.adapter(ErrorResponse::class.java)
private val responseSubject = MutableSharedFlow<KataGoResponse>()
private val filesDir = OnlineGoApplication.instance.filesDir
private val netFile = File(filesDir, "katagonet.gz")
private val cfgFile = File(filesDir, "katago.cfg")
Expand Down Expand Up @@ -65,6 +76,7 @@ object KataGoAnalysisEngine {
while (true) {
val line = errorReader.readLine() ?: break
if (line.startsWith("KataGo v")) {
version = line
continue
} else if (line == "Started, ready to begin handling requests") {
requestIDX = AtomicLong(0)
Expand All @@ -79,30 +91,32 @@ object KataGoAnalysisEngine {

if (started) {
Thread {
while (true) {
val line = reader?.readLine() ?: break
if (line.startsWith("{\"error\"") || line.startsWith("{\"warning\":\"WARNING_MESSAGE\"")) {
Log.e("KataGoAnalysisEngine", line)
recordException(Exception("Katago: $line"))
errorAdapter.fromJson(line)?.let {
responseSubject.onNext(it)
}
} else {
Log.d("KataGoAnalysisEngine", line)
FirebaseCrashlytics.getInstance().log("KATAGO < $line")
responseAdapter.fromJson(line)?.let {
responseSubject.onNext(it)
runBlocking {
while (true) {
val line = reader?.readLine() ?: break
if (line.startsWith("{\"error\"") || line.startsWith("{\"warning\":\"WARNING_MESSAGE\"")) {
Log.e("KataGoAnalysisEngine", line)
recordException(Exception("Katago: $line"))
errorAdapter.fromJson(line)?.let {
responseSubject.emit(it)
}
} else {
Log.d("KataGoAnalysisEngine", line)
FirebaseCrashlytics.getInstance().log("KATAGO < $line")
responseAdapter.fromJson(line)?.let {
responseSubject.emit(it)
}
}
}
Log.d("KataGoAnalysisEngine", "End of input, killing reader thread")
FirebaseCrashlytics.getInstance().log("KATAGO < End of input, killing reader thread")
started = false
}
Log.d("KataGoAnalysisEngine", "End of input, killing reader thread")
FirebaseCrashlytics.getInstance().log("KATAGO < End of input, killing reader thread")
started = false
}.start()
} else {
Log.e("KataGoAnalysisEngine", "Could not start KataGo")
recordException(Exception("Could not start KataGo $errors"))
throw RuntimeException("Could not start KataGo")
throw RuntimeException("$errors")
}
}
}
Expand Down Expand Up @@ -131,26 +145,44 @@ object KataGoAnalysisEngine {
}.start()
}

fun analyzeMoveSequence(
@Deprecated("rxjava")
fun analyzeMoveSequenceSingle(
sequence: List<Position>,
komi: Float? = null,
maxVisits: Int? = null,
maxVisits: Int? = settingsRepository.maxVisits,
rules: String = "japanese",
includeOwnership: Boolean? = null,
includeMovesOwnership: Boolean? = null,
includePolicy: Boolean? = null
): Single<Response> {
return analyzeMoveSequence(
sequence = sequence,
komi = komi,
maxVisits = maxVisits,
rules = rules,
includeOwnership = includeOwnership,
includeMovesOwnership = includeMovesOwnership,
includePolicy = includePolicy,
)
.filter { !it.isDuringSearch }
.asFlowable()
.firstOrError()
}

fun analyzeMoveSequence(
sequence: List<Position>,
komi: Float? = null,
maxVisits: Int? = settingsRepository.maxVisits,
rules: String = "japanese",
includeOwnership: Boolean? = null,
includeMovesOwnership: Boolean? = null,
includePolicy: Boolean? = null
): Flow<Response> {

val id = generateId()
return responseSubject
.filter { it.id == id }
.firstOrError()
.map {
if (it is ErrorResponse) {
throw RuntimeException(it.error)
} else {
it as Response
}
}.doOnSubscribe {
.asSharedFlow()
.onSubscription {
val initialPosition = mutableSetOf<List<String>>()
val history = Stack<List<String>>()
sequence.map { pos ->
Expand Down Expand Up @@ -185,7 +217,8 @@ object KataGoAnalysisEngine {
komi = komi,
maxVisits = maxVisits,
moves = history,
rules = "japanese"
rules = rules,
reportDuringSearchEvery = 0.5f
)

val stringQuery = queryAdapter.toJson(query)
Expand All @@ -197,6 +230,15 @@ object KataGoAnalysisEngine {
flush()
}
}
.filterNotNull()
.filter { it.id == id }
.map {
if (it is ErrorResponse) {
throw RuntimeException(it.error)
} else {
it as Response
}
}
}

private fun generateId() = requestIDX.incrementAndGet().toString()
Expand All @@ -221,4 +263,4 @@ object KataGoAnalysisEngine {
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,6 @@ data class Query (
val avoidMoves: List<List<String>>? = null,
val allowMoves: List<List<String>>? = null,
val overrideSettings: String? = null,
val priority: Int? = null
)
val reportDuringSearchEvery: Float? = null,
val priority: Int? = null,
)
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ sealed interface KataGoResponse {

data class Response(
override val id: String,
val isDuringSearch: Boolean = false,
val turnNumber: Int,
val moveInfos: List<MoveInfo>,
val rootInfo: RootInfo,
Expand Down Expand Up @@ -49,4 +50,4 @@ data class RootInfo(
data class ResponseAbreviatedJSON(
val rootInfo: RootInfo,
val ownership: List<Float>? = null
)
)
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@ package io.zenandroid.onlinego.ui.screens.localai.middlewares
import io.reactivex.Observable
import io.reactivex.rxkotlin.withLatestFrom
import io.reactivex.schedulers.Schedulers
import kotlinx.coroutines.flow.filter
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.take
import kotlinx.coroutines.rx2.asObservable
import io.zenandroid.onlinego.ai.KataGoAnalysisEngine
import io.zenandroid.onlinego.data.model.StoneType
import io.zenandroid.onlinego.data.model.katago.KataGoResponse.Response
Expand All @@ -22,14 +26,16 @@ class AIMoveMiddleware : Middleware<AiGameState, AiGameAction> {
actions.ofType(GenerateAiMove::class.java)
.withLatestFrom(state)
.filter { (_, state) -> state.engineStarted && !state.stateRestorePending && state.position != null }
.flatMapSingle { (_, state) ->
.flatMap { (_, state) ->
KataGoAnalysisEngine.analyzeMoveSequence(
sequence = state.history,
maxVisits = 20,
//maxVisits = 20,
komi = state.position?.komi ?: 0f,
includeOwnership = false,
includeMovesOwnership = false
)
.filter { !it.isDuringSearch }
.take(1)
.map {
val selectedMove = selectMove(it)
val move = Util.getCoordinatesFromGTP(selectedMove.move, state.position!!.boardHeight)
Expand All @@ -42,11 +48,12 @@ class AIMoveMiddleware : Middleware<AiGameState, AiGameAction> {
AIMove(newPos, it, selectedMove)
}
}
.onErrorReturn { AIError }
.asObservable()
.subscribeOn(Schedulers.io())
.onErrorReturn { AIError }
}

private fun selectMove(analysis: Response): MoveInfo {
return analysis.moveInfos[0]
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,15 @@ import android.util.Log
import io.reactivex.Observable
import io.reactivex.rxkotlin.withLatestFrom
import io.reactivex.schedulers.Schedulers
import kotlinx.coroutines.flow.filter
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.take
import kotlinx.coroutines.rx2.asObservable
import org.koin.core.context.GlobalContext
import io.zenandroid.onlinego.ai.KataGoAnalysisEngine
import io.zenandroid.onlinego.data.model.Cell
import io.zenandroid.onlinego.data.model.StoneType
import io.zenandroid.onlinego.data.repositories.SettingsRepository
import io.zenandroid.onlinego.gamelogic.RulesManager
import io.zenandroid.onlinego.gamelogic.RulesManager.isGameOver
import io.zenandroid.onlinego.mvi.Middleware
Expand All @@ -16,6 +22,8 @@ import io.zenandroid.onlinego.ui.screens.localai.AiGameState
import io.zenandroid.onlinego.utils.recordException

class GameTurnMiddleware : Middleware<AiGameState, AiGameAction> {
val settingsRepository: SettingsRepository = GlobalContext.get().get()

override fun bind(actions: Observable<AiGameAction>, state: Observable<AiGameState>): Observable<AiGameAction> =
Observable.merge(
engineStarted(actions, state),
Expand Down Expand Up @@ -51,13 +59,19 @@ class GameTurnMiddleware : Middleware<AiGameState, AiGameAction> {
actions.filter { it is NewPosition || it is AIMove }
.withLatestFrom(state)
.filter { (_, state) -> state.history.isGameOver() }
.flatMap { (_, state) ->
.switchMap { (_, state) ->
KataGoAnalysisEngine.analyzeMoveSequence(
sequence = state.history,
maxVisits = 10,
//maxVisits = 10,
komi = state.position!!.komi,
includeOwnership = true
)
.let {
if (!settingsRepository.detailedAnalysis) it
.filter { !it.isDuringSearch }
.take(1)
else it
}
.map {
val blackTerritory = mutableSetOf<Cell>()
val whiteTerritory = mutableSetOf<Cell>()
Expand Down Expand Up @@ -93,8 +107,8 @@ class GameTurnMiddleware : Middleware<AiGameState, AiGameAction> {
val aiWon = state.enginePlaysBlack == (blackScore > whiteScore)
ScoreComputed(newPos, whiteScore, blackScore, aiWon, it)
}
.asObservable()
.subscribeOn(Schedulers.io())
.toObservable()
.doOnError(this::onError)
.onErrorResumeNext(Observable.empty())
}
Expand All @@ -103,4 +117,4 @@ class GameTurnMiddleware : Middleware<AiGameState, AiGameAction> {
Log.e("GameTurnMiddleware", throwable.message, throwable)
recordException(throwable)
}
}
}
Loading