Skip to content

Commit

Permalink
feat(api/workflow): Add Conversation.dialogue_count
Browse files Browse the repository at this point in the history
  • Loading branch information
laipz8200 committed Aug 14, 2024
1 parent ba79088 commit e19ca0b
Show file tree
Hide file tree
Showing 10 changed files with 147 additions and 86 deletions.
6 changes: 5 additions & 1 deletion api/contexts/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from contextvars import ContextVar

tenant_id: ContextVar[str] = ContextVar('tenant_id')
from core.workflow.entities.variable_pool import VariablePool

tenant_id: ContextVar[str] = ContextVar('tenant_id')

workflow_variable_pool: ContextVar[VariablePool] = ContextVar('workflow_variable_pool')
108 changes: 84 additions & 24 deletions api/core/app/apps/advanced_chat/app_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

from flask import Flask, current_app
from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session

import contexts
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
Expand All @@ -18,15 +20,20 @@
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
InvokeFrom,
)
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
from core.file.message_file_parser import MessageFileParser
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from extensions.ext_database import db
from models.account import Account
from models.model import App, Conversation, EndUser, Message
from models.workflow import Workflow
from models.workflow import ConversationVariable, Workflow

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -120,7 +127,7 @@ def generate(
conversation=conversation,
stream=stream
)

def single_iteration_generate(self, app_model: App,
workflow: Workflow,
node_id: str,
Expand All @@ -140,10 +147,10 @@ def single_iteration_generate(self, app_model: App,
"""
if not node_id:
raise ValueError('node_id is required')

if args.get('inputs') is None:
raise ValueError('inputs is required')

extras = {
"auto_generate_conversation_name": False
}
Expand Down Expand Up @@ -209,7 +216,7 @@ def _generate(self, *,
# update conversation features
conversation.override_model_configs = workflow.features
db.session.commit()
db.session.refresh(conversation)
# db.session.refresh(conversation)

# init queue manager
queue_manager = MessageBasedAppQueueManager(
Expand All @@ -221,15 +228,69 @@ def _generate(self, *,
message_id=message.id
)

# Init conversation variables
stmt = select(ConversationVariable).where(
ConversationVariable.app_id == conversation.app_id, ConversationVariable.conversation_id == conversation.id
)
with Session(db.engine) as session:
conversation_variables = session.scalars(stmt).all()
if not conversation_variables:
# Create conversation variables if they don't exist.
conversation_variables = [
ConversationVariable.from_variable(
app_id=conversation.app_id, conversation_id=conversation.id, variable=variable
)
for variable in workflow.conversation_variables
]
session.add_all(conversation_variables)
# Convert database entities to variables.
conversation_variables = [item.to_variable() for item in conversation_variables]

session.commit()

# Increment dialogue count.
conversation.dialogue_count += 1

conversation_id = conversation.id
conversation_dialogue_count = conversation.dialogue_count
db.session.commit()
db.session.refresh(conversation)

inputs = application_generate_entity.inputs
query = application_generate_entity.query
files = application_generate_entity.files

user_id = None
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
else:
user_id = application_generate_entity.user_id

# Create a variable pool.
system_inputs = {
SystemVariable.QUERY: query,
SystemVariable.FILES: files,
SystemVariable.CONVERSATION_ID: conversation_id,
SystemVariable.USER_ID: user_id,
SystemVariable.DIALOGUE_COUNT: conversation_dialogue_count,
}
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
conversation_variables=conversation_variables,
)
contexts.workflow_variable_pool.set(variable_pool)

# new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
'flask_app': current_app._get_current_object(),
'application_generate_entity': application_generate_entity,
'queue_manager': queue_manager,
'conversation_id': conversation.id,
'message_id': message.id,
'user': user,
'context': contextvars.copy_context()
'context': contextvars.copy_context(),
})

worker_thread.start()
Expand All @@ -242,7 +303,7 @@ def _generate(self, *,
conversation=conversation,
message=message,
user=user,
stream=stream
stream=stream,
)

return AdvancedChatAppGenerateResponseConverter.convert(
Expand All @@ -253,9 +314,7 @@ def _generate(self, *,
def _generate_worker(self, flask_app: Flask,
application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation_id: str,
message_id: str,
user: Account,
context: contextvars.Context) -> None:
"""
Generate worker in a new thread.
Expand All @@ -282,16 +341,14 @@ def _generate_worker(self, flask_app: Flask,
user_id=application_generate_entity.user_id
)
else:
# get conversation and message
conversation = self._get_conversation(conversation_id)
# get message
message = self._get_message(message_id)

# chatbot app
runner = AdvancedChatAppRunner()
runner.run(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message
)
except GenerateTaskStoppedException:
Expand All @@ -314,14 +371,17 @@ def _generate_worker(self, flask_app: Flask,
finally:
db.session.close()

def _handle_advanced_chat_response(self, application_generate_entity: AdvancedChatAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
stream: bool = False) \
-> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
def _handle_advanced_chat_response(
self,
*,
application_generate_entity: AdvancedChatAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
stream: bool = False,
) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
"""
Handle response.
:param application_generate_entity: application generate entity
Expand All @@ -341,7 +401,7 @@ def _handle_advanced_chat_response(self, application_generate_entity: AdvancedCh
conversation=conversation,
message=message,
user=user,
stream=stream
stream=stream,
)

try:
Expand Down
53 changes: 2 additions & 51 deletions api/core/app/apps/advanced_chat/app_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@
from collections.abc import Mapping
from typing import Any, Optional, cast

from sqlalchemy import select
from sqlalchemy.orm import Session

from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.advanced_chat.workflow_event_trigger_callback import WorkflowEventTriggerCallback
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
Expand All @@ -19,13 +16,10 @@
from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent
from core.moderation.base import ModerationException
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import UserFrom
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from extensions.ext_database import db
from models.model import App, Conversation, EndUser, Message
from models.workflow import ConversationVariable, Workflow
from models import App, Message, Workflow

logger = logging.getLogger(__name__)

Expand All @@ -39,7 +33,6 @@ def run(
self,
application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
) -> None:
"""
Expand All @@ -63,15 +56,6 @@ def run(

inputs = application_generate_entity.inputs
query = application_generate_entity.query
files = application_generate_entity.files

user_id = None
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
else:
user_id = application_generate_entity.user_id

# moderation
if self.handle_input_moderation(
Expand Down Expand Up @@ -103,38 +87,6 @@ def run(
if bool(os.environ.get('DEBUG', 'False').lower() == 'true'):
workflow_callbacks.append(WorkflowLoggingCallback())

# Init conversation variables
stmt = select(ConversationVariable).where(
ConversationVariable.app_id == conversation.app_id, ConversationVariable.conversation_id == conversation.id
)
with Session(db.engine) as session:
conversation_variables = session.scalars(stmt).all()
if not conversation_variables:
conversation_variables = [
ConversationVariable.from_variable(
app_id=conversation.app_id, conversation_id=conversation.id, variable=variable
)
for variable in workflow.conversation_variables
]
session.add_all(conversation_variables)
session.commit()
# Convert database entities to variables
conversation_variables = [item.to_variable() for item in conversation_variables]

# Create a variable pool.
system_inputs = {
SystemVariable.QUERY: query,
SystemVariable.FILES: files,
SystemVariable.CONVERSATION_ID: conversation.id,
SystemVariable.USER_ID: user_id,
}
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
conversation_variables=conversation_variables,
)

# RUN WORKFLOW
workflow_engine_manager = WorkflowEngineManager()
workflow_engine_manager.run_workflow(
Expand All @@ -146,7 +98,6 @@ def run(
invoke_from=application_generate_entity.invoke_from,
callbacks=workflow_callbacks,
call_depth=application_generate_entity.call_depth,
variable_pool=variable_pool,
)

def single_iteration_run(
Expand All @@ -155,7 +106,7 @@ def single_iteration_run(
"""
Single iteration run
"""
app_record: App = db.session.query(App).filter(App.id == app_id).first()
app_record = db.session.query(App).filter(App.id == app_id).first()
if not app_record:
raise ValueError('App not found')

Expand Down
11 changes: 8 additions & 3 deletions api/core/app/apps/advanced_chat/generate_task_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from collections.abc import Generator
from typing import Any, Optional, Union, cast

import contexts
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
Expand Down Expand Up @@ -71,6 +72,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
_application_generate_entity: AdvancedChatAppGenerateEntity
_workflow: Workflow
_user: Union[Account, EndUser]
# Deprecated
_workflow_system_variables: dict[SystemVariable, Any]
_iteration_nested_relations: dict[str, list[str]]

Expand All @@ -81,7 +83,7 @@ def __init__(
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
stream: bool
stream: bool,
) -> None:
"""
Initialize AdvancedChatAppGenerateTaskPipeline.
Expand All @@ -103,11 +105,12 @@ def __init__(
self._workflow = workflow
self._conversation = conversation
self._message = message
# Deprecated
self._workflow_system_variables = {
SystemVariable.QUERY: message.query,
SystemVariable.FILES: application_generate_entity.files,
SystemVariable.CONVERSATION_ID: conversation.id,
SystemVariable.USER_ID: user_id
SystemVariable.USER_ID: user_id,
}

self._task_state = AdvancedChatTaskState(
Expand Down Expand Up @@ -613,7 +616,9 @@ def _generate_stream_outputs_when_node_finished(self) -> Optional[Generator]:

if route_chunk_node_id == 'sys':
# system variable
value = self._workflow_system_variables.get(SystemVariable.value_of(value_selector[1]))
value = contexts.workflow_variable_pool.get().get(value_selector)
if value:
value = value.text
elif route_chunk_node_id in self._iteration_nested_relations:
# it's a iteration variable
if not self._iteration_state or route_chunk_node_id not in self._iteration_state.current_iterations:
Expand Down
5 changes: 4 additions & 1 deletion api/core/app/apps/message_based_app_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def _get_conversation_introduction(self, application_generate_entity: AppGenerat

return introduction

def _get_conversation(self, conversation_id: str) -> Conversation:
def _get_conversation(self, conversation_id: str):
"""
Get conversation by conversation id
:param conversation_id: conversation id
Expand All @@ -270,6 +270,9 @@ def _get_conversation(self, conversation_id: str) -> Conversation:
.first()
)

if not conversation:
raise ConversationNotExistsError()

return conversation

def _get_message(self, message_id: str) -> Message:
Expand Down
1 change: 1 addition & 0 deletions api/core/workflow/entities/node_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class SystemVariable(Enum):
FILES = 'files'
CONVERSATION_ID = 'conversation_id'
USER_ID = 'user_id'
DIALOGUE_COUNT = 'dialogue_count'

@classmethod
def value_of(cls, value: str) -> 'SystemVariable':
Expand Down
Loading

0 comments on commit e19ca0b

Please sign in to comment.