Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Genai speech to form #3666

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions android/gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ fhir-sdk-engine = "1.1.0-preview2-SNAPSHOT"
fhir-sdk-knowledge = "0.1.0-alpha03-preview5-rc2-SNAPSHOT"
fhir-sdk-workflow = "0.1.0-alpha04-preview10-rc1-SNAPSHOT"
fragment-ktx = "1.8.3"
generativeai = "0.9.0"
glide = "4.16.0"
googleCloudSpeech = "2.5.2"
gradle = "8.3.2"
gson = "2.10.1"
hilt = "1.2.0"
Expand Down Expand Up @@ -128,8 +130,10 @@ fhir-sdk-common = { group = "org.smartregister", name = "common", version.ref =
foundation = { group = "androidx.compose.foundation", name = "foundation", version.ref = "compose-ui" }
fragment-ktx = { group = "androidx.fragment", name = "fragment-ktx", version.ref = "fragment-ktx" }
fragment-testing = { group = "androidx.fragment", name = "fragment-testing", version.ref = "fragment-ktx" }
generativeai = { module = "com.google.ai.client.generativeai:generativeai", version.ref = "generativeai" }
glide = { group = "com.github.bumptech.glide", name = "glide", version.ref = "glide" }
gms-play-services-location = { group = "com.google.android.gms", name = "play-services-location", version.ref = "playServicesLocation" }
google-cloud-speech = { module = "com.google.cloud:google-cloud-speech", version.ref = "googleCloudSpeech" }
gradle = { module = "com.android.tools.build:gradle", version.ref = "gradle" }
gson = { group = "com.google.code.gson", name = "gson", version.ref = "gson" }
hilt-compiler = { group = "androidx.hilt", name = "hilt-compiler", version.ref = "hilt" }
Expand Down
4 changes: 4 additions & 0 deletions android/quest/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,10 @@ dependencies {
implementation(libs.bundles.cameraX)
implementation(libs.log4j)

// AI dependencies
implementation(libs.google.cloud.speech)
implementation(libs.generativeai)

// Annotation processors
kapt(libs.hilt.compiler)
kapt(libs.dagger.hilt.compiler)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Copyright 2021-2024 Ona Systems, Inc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.smartregister.fhircore.quest.ui.speechtoform

import com.google.ai.client.generativeai.GenerativeModel
import com.google.ai.client.generativeai.type.BlockThreshold
import com.google.ai.client.generativeai.type.HarmCategory
import com.google.ai.client.generativeai.type.SafetySetting
import com.google.ai.client.generativeai.type.generationConfig

/**
 * Holds the Gemini [GenerativeModel] instance used by the speech-to-form feature.
 *
 * The model is created eagerly when this class is constructed, using the generation and safety
 * settings listed below.
 *
 * @param apiKey The Google AI API key that the underlying [GenerativeModel] authenticates with.
 *   NOTE(review): the key should be supplied from build/runtime configuration by the caller rather
 *   than hard-coded in source.
 */
class GeminiModel(private val apiKey: String) {
  // model usage
  // https://developer.android.com/ai/google-ai-client-sdk
  val model =
    GenerativeModel(
      modelName = "gemini-1.5-flash-001",
      // Use the key handed to the constructor; previously a placeholder string literal was sent
      // instead, which left the constructor parameter unused.
      apiKey = apiKey,
      generationConfig =
        generationConfig {
          temperature = 0.15f
          topK = 32
          topP = 1f
          maxOutputTokens = 4096
        },
      // Every harm category is blocked at the MEDIUM_AND_ABOVE threshold.
      safetySettings =
        listOf(
          SafetySetting(HarmCategory.HARASSMENT, BlockThreshold.MEDIUM_AND_ABOVE),
          SafetySetting(HarmCategory.HATE_SPEECH, BlockThreshold.MEDIUM_AND_ABOVE),
          SafetySetting(HarmCategory.SEXUALLY_EXPLICIT, BlockThreshold.MEDIUM_AND_ABOVE),
          SafetySetting(HarmCategory.DANGEROUS_CONTENT, BlockThreshold.MEDIUM_AND_ABOVE),
        ),
    )
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Copyright 2021-2024 Ona Systems, Inc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.smartregister.fhircore.quest.ui.speechtoform

import java.io.File
import java.util.logging.Logger
import org.hl7.fhir.r4.model.Questionnaire
import org.hl7.fhir.r4.model.QuestionnaireResponse

/**
 * Orchestrates the speech-to-form pipeline: audio is transcribed by [SpeechToText] and the
 * resulting transcript is converted into a FHIR QuestionnaireResponse by [TextToForm].
 */
class SpeechToForm(
  private val speechToText: SpeechToText,
  private val textToForm: TextToForm,
) {

  private val logger = Logger.getLogger(SpeechToForm::class.java.name)

  /**
   * Reads an audio file, transcribes it, and generates a FHIR QuestionnaireResponse.
   *
   * The intermediate transcript file produced by the transcription step is deleted before this
   * function returns, whether or not the generation step succeeds.
   *
   * @param audioFile The input audio file to process.
   * @param questionnaire The FHIR Questionnaire used to generate the response.
   * @return The generated QuestionnaireResponse, or null if either step yields no result.
   */
  suspend fun processAudioToQuestionnaireResponse(
    audioFile: File,
    questionnaire: Questionnaire,
  ): QuestionnaireResponse? {
    logger.info("Starting audio transcription process...")

    // Step 1: Transcribe audio to text
    val tempTextFile = speechToText.transcribeAudioToText(audioFile)
    if (tempTextFile == null) {
      logger.severe("Failed to transcribe audio.")
      return null
    }
    logger.info("Transcription successful. File path: ${tempTextFile.absolutePath}")

    // Step 2: Generate QuestionnaireResponse from the transcript
    val questionnaireResponse =
      try {
        textToForm.generateQuestionnaireResponse(tempTextFile, questionnaire)
      } finally {
        // The transcript is only an intermediate artifact; remove it so conversation text does
        // not accumulate in the temp directory. Runs on success, failure and cancellation.
        if (!tempTextFile.delete()) {
          logger.warning("Could not delete temporary transcript file.")
        }
      }
    if (questionnaireResponse == null) {
      logger.severe("Failed to generate QuestionnaireResponse.")
      return null
    }

    logger.info("QuestionnaireResponse generated successfully.")
    return questionnaireResponse
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Copyright 2021-2024 Ona Systems, Inc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.smartregister.fhircore.quest.ui.speechtoform

import com.google.cloud.speech.v1.RecognitionAudio
import com.google.cloud.speech.v1.RecognitionConfig
import com.google.cloud.speech.v1.RecognitionConfig.AudioEncoding
import com.google.cloud.speech.v1.SpeechClient
import com.google.cloud.speech.v1.SpeechRecognitionResult
import java.io.File
import java.util.logging.Logger

/** Transcribes recorded audio to text through the Google Cloud Speech-to-Text client. */
class SpeechToText {

  private val logger = Logger.getLogger(SpeechToText::class.java.name)

  /**
   * Transcribes an audio file to text using Google Cloud Speech-to-Text API and writes it to a
   * temporary file.
   *
   * The whole file is read into memory and sent in a single recognize request configured for
   * LINEAR16 encoding, a 16000 Hz sample rate and the "en-US" language code. The first alternative
   * of every result is joined with single spaces; results that carry no alternative are skipped.
   *
   * This call performs blocking I/O on the calling thread. Errors raised while reading the file,
   * calling the service or writing the transcript propagate to the caller as exceptions.
   *
   * @param audioFile The audio file to be transcribed.
   * @return The temporary file containing the transcribed text. The return type is nullable for
   *   callers' benefit, but this implementation either returns a file or throws.
   */
  fun transcribeAudioToText(audioFile: File): File? =
    // `use` closes the client once the request has completed, even when an exception is thrown.
    SpeechClient.create().use { speechClient ->
      val audioBytes = audioFile.readBytes()

      // Build the recognition audio
      val recognitionAudio =
        RecognitionAudio.newBuilder()
          .setContent(com.google.protobuf.ByteString.copyFrom(audioBytes))
          .build()

      // Configure recognition settings
      val config =
        RecognitionConfig.newBuilder()
          .setEncoding(AudioEncoding.LINEAR16)
          .setSampleRateHertz(16000)
          .setLanguageCode("en-US")
          .build()

      // Perform transcription
      val response = speechClient.recognize(config, recognitionAudio)
      val transcription =
        response.resultsList
          // Guard against a result without alternatives instead of indexing [0] blindly.
          .mapNotNull { result: SpeechRecognitionResult ->
            result.alternativesList.firstOrNull()?.transcript
          }
          .joinToString(" ")

      logger.info("Transcription: $transcription")

      // Write transcription to a temporary file
      File.createTempFile("transcription", ".txt").also { transcriptFile ->
        transcriptFile.writeText(transcription)
        logger.info("Transcription written to temporary file. ")
      }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
/*
* Copyright 2021-2024 Ona Systems, Inc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.smartregister.fhircore.quest.ui.speechtoform

import ca.uhn.fhir.interceptor.model.RequestPartitionId.fromJson
import com.google.ai.client.generativeai.GenerativeModel
import java.io.File
import java.util.logging.Logger
import org.hl7.fhir.r4.model.Questionnaire
import org.hl7.fhir.r4.model.QuestionnaireResponse
import org.json.JSONObject

/**
 * Converts a conversation transcript into a FHIR QuestionnaireResponse by prompting a Gemini
 * [GenerativeModel] and parsing the JSON it returns.
 */
class TextToForm(private val generativeModel: GenerativeModel) {

  private val logger = Logger.getLogger(TextToForm::class.java.name)

  /**
   * Generates an HL7 FHIR QuestionnaireResponse from a transcript using the provided Questionnaire.
   *
   * Exceptions thrown while reading the transcript or calling the model propagate to the caller;
   * failures while parsing or validating the model output are logged and reported as null.
   *
   * @param transcriptFile The temporary file containing the transcript text.
   * @param questionnaire The FHIR Questionnaire to base the response on.
   * @return The generated and validated QuestionnaireResponse or null if generation fails.
   */
  suspend fun generateQuestionnaireResponse(
    transcriptFile: File,
    questionnaire: Questionnaire,
  ): QuestionnaireResponse? {
    val transcript = transcriptFile.readText()
    val prompt = promptTemplate(transcript, questionnaire)

    logger.info("Sending request to Gemini...")
    val generatedText = generativeModel.generateContent(prompt).text

    val questionnaireResponseJson = extractJsonBlock(generatedText)
    if (questionnaireResponseJson == null) {
      logger.warning("No JSON found in the generated text.")
      return null
    }

    return try {
      val questionnaireResponse = parseQuestionnaireResponse(questionnaireResponseJson)
      if (validateQuestionnaireResponse(questionnaireResponse)) {
        logger.info("QuestionnaireResponse validated successfully.")
        questionnaireResponse
      } else {
        logger.warning("QuestionnaireResponse validation failed.")
        null
      }
    } catch (e: Exception) {
      // Only non-suspending parse/validate calls run inside this try, so no coroutine
      // cancellation can be swallowed here.
      logger.severe("Error generating QuestionnaireResponse: ${e.message}")
      null
    }
  }

  /**
   * Builds the prompt for the Gemini model.
   *
   * The questionnaire is embedded as FHIR JSON. Interpolating the resource object directly would
   * only insert its `toString()` value, which does not contain the questionnaire content.
   */
  private fun promptTemplate(transcript: String, questionnaire: Questionnaire): String {
    val questionnaireJson = newJsonParser().encodeResourceToString(questionnaire)
    // The fixed instructions are de-indented on their own, before any variable text is appended,
    // so multi-line transcripts cannot affect how `trimIndent` strips the template.
    val instructions =
      """
      You are a scribe created to turn conversational text into structure HL7 FHIR output. Below
      you will see the text Transcript of a conversation between a nurse and a patient within
      <transcript> XML tags and an HL7 FHIR Questionnaire within <questionnaire> XML tags. Your job
      is to convert the text in Transcript into a new HL7 FHIR QuestionnaireResponse as if the
      information in Transcript had been entered directly into the FHIR Questionniare. Only output
      the FHIR QuestionnaireResponse as JSON and nothing else.
      """
        .trimIndent()
    return buildString {
      append(instructions)
      append("\n<transcript>").append(transcript).append("</transcript>")
      append("\n<questionnaire>").append(questionnaireJson).append("</questionnaire>")
    }
  }

  /**
   * Extracts the JSON block from the generated text.
   *
   * A fenced "```json" block is preferred. If the text has no such fence, the span from the first
   * '{' to the last '}' is returned, because the prompt asks for bare JSON and the model may
   * comply without adding a fence.
   *
   * @return The trimmed JSON text, or null when the text is null/blank, a "```json" fence is opened
   *   but never closed, or no JSON object can be located.
   */
  private fun extractJsonBlock(responseText: String?): String? {
    if (responseText.isNullOrBlank()) return null

    val fenceStart = responseText.indexOf(JSON_FENCE_OPEN)
    if (fenceStart != -1) {
      val contentStart = fenceStart + JSON_FENCE_OPEN.length
      val fenceEnd = responseText.indexOf(FENCE, contentStart)
      return if (fenceEnd == -1) null else responseText.substring(contentStart, fenceEnd).trim()
    }

    val objectStart = responseText.indexOf('{')
    val objectEnd = responseText.lastIndexOf('}')
    return if (objectStart != -1 && objectEnd > objectStart) {
      responseText.substring(objectStart, objectEnd + 1).trim()
    } else {
      null
    }
  }

  /**
   * Parses the JSON string into a QuestionnaireResponse object.
   *
   * Uses the HAPI FHIR JSON parser so the returned resource is actually populated from [json].
   * Malformed input makes the parser throw, which the caller handles.
   */
  private fun parseQuestionnaireResponse(json: String): QuestionnaireResponse {
    return newJsonParser().parseResource(QuestionnaireResponse::class.java, json)
  }

  /** Validates the QuestionnaireResponse structure. Currently a placeholder that accepts all. */
  private fun validateQuestionnaireResponse(qr: QuestionnaireResponse): Boolean {
    // todo use SDC validation

    return true
  }

  /** Creates a JSON parser from the cached R4 context; a fresh parser is used for each call. */
  private fun newJsonParser() = ca.uhn.fhir.context.FhirContext.forR4Cached().newJsonParser()

  private companion object {
    const val JSON_FENCE_OPEN = "```json"
    const val FENCE = "```"
  }
}
Loading