Skip to content

Commit

Permalink
Merge branch 'refactor/remove-extra-config-from-file' into deploy/dev
Browse files Browse the repository at this point in the history
  • Loading branch information
laipz8200 committed Nov 2, 2024
2 parents fe8d850 + cb5ffd8 commit c1147f8
Show file tree
Hide file tree
Showing 29 changed files with 259 additions and 352 deletions.
42 changes: 23 additions & 19 deletions api/core/agent/base_agent_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
Expand Down Expand Up @@ -65,7 +66,7 @@ def __init__(
prompt_messages: Optional[list[PromptMessage]] = None,
variables_pool: Optional[ToolRuntimeVariablePool] = None,
db_variables: Optional[ToolConversationVariables] = None,
model_instance: ModelInstance = None,
model_instance: ModelInstance | None = None,
) -> None:
self.tenant_id = tenant_id
self.application_generate_entity = application_generate_entity
Expand Down Expand Up @@ -508,24 +509,27 @@ def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[P

def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
if files:
file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())

if file_extra_config:
file_objs = file_factory.build_from_message_files(
message_files=files, tenant_id=self.tenant_id, config=file_extra_config
)
else:
file_objs = []
if not files:
return UserPromptMessage(content=message.query)
file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
if not file_extra_config:
return UserPromptMessage(content=message.query)

if not file_objs:
return UserPromptMessage(content=message.query)
else:
prompt_message_contents: list[PromptMessageContent] = []
prompt_message_contents.append(TextPromptMessageContent(data=message.query))
for file_obj in file_objs:
prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj))
image_detail_config = file_extra_config.image_config.detail if file_extra_config.image_config else None
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW

return UserPromptMessage(content=prompt_message_contents)
else:
file_objs = file_factory.build_from_message_files(
message_files=files, tenant_id=self.tenant_id, config=file_extra_config
)
if not file_objs:
return UserPromptMessage(content=message.query)
prompt_message_contents: list[PromptMessageContent] = []
prompt_message_contents.append(TextPromptMessageContent(data=message.query))
for file in file_objs:
prompt_message_contents.append(
file_manager.to_prompt_message_content(
file,
image_detail_config=image_detail_config,
)
)
return UserPromptMessage(content=prompt_message_contents)
21 changes: 19 additions & 2 deletions api/core/agent/cot_chat_agent_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
TextPromptMessageContent,
UserPromptMessage,
)
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from core.model_runtime.utils.encoders import jsonable_encoder


Expand All @@ -36,8 +37,24 @@ def _organize_user_query(self, query, prompt_messages: list[PromptMessage]) -> l
if self.files:
prompt_message_contents: list[PromptMessageContent] = []
prompt_message_contents.append(TextPromptMessageContent(data=query))
for file_obj in self.files:
prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj))

# get image detail config
image_detail_config = (
self.application_generate_entity.file_upload_config.image_config.detail
if (
self.application_generate_entity.file_upload_config
and self.application_generate_entity.file_upload_config.image_config
)
else None
)
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
for file in self.files:
prompt_message_contents.append(
file_manager.to_prompt_message_content(
file,
image_detail_config=image_detail_config,
)
)

prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:
Expand Down
21 changes: 19 additions & 2 deletions api/core/agent/fc_agent_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
Expand Down Expand Up @@ -397,8 +398,24 @@ def _organize_user_query(self, query, prompt_messages: list[PromptMessage]) -> l
if self.files:
prompt_message_contents: list[PromptMessageContent] = []
prompt_message_contents.append(TextPromptMessageContent(data=query))
for file_obj in self.files:
prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj))

# get image detail config
image_detail_config = (
self.application_generate_entity.file_upload_config.image_config.detail
if (
self.application_generate_entity.file_upload_config
and self.application_generate_entity.file_upload_config.image_config
)
else None
)
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
for file in self.files:
prompt_message_contents.append(
file_manager.to_prompt_message_content(
file,
image_detail_config=image_detail_config,
)
)

prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:
Expand Down
4 changes: 2 additions & 2 deletions api/core/app/app_config/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from pydantic import BaseModel, Field, field_validator

from core.file import FileExtraConfig, FileTransferMethod, FileType
from core.file import FileTransferMethod, FileType, FileUploadConfig
from core.model_runtime.entities.message_entities import PromptMessageRole
from models.model import AppMode

Expand Down Expand Up @@ -211,7 +211,7 @@ class TracingConfigEntity(BaseModel):


class AppAdditionalFeatures(BaseModel):
file_upload: Optional[FileExtraConfig] = None
file_upload: Optional[FileUploadConfig] = None
opening_statement: Optional[str] = None
suggested_questions: list[str] = []
suggested_questions_after_answer: bool = False
Expand Down
8 changes: 3 additions & 5 deletions api/core/app/app_config/features/file_upload/manager.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from collections.abc import Mapping
from typing import Any

from core.file.models import FileExtraConfig
from models import FileUploadConfig
from core.file import FileUploadConfig


class FileUploadConfigManager:
Expand Down Expand Up @@ -30,15 +29,14 @@ def convert(cls, config: Mapping[str, Any], is_vision: bool = True):
if is_vision:
data["image_config"]["detail"] = file_upload_dict.get("image", {}).get("detail", "low")

return FileExtraConfig.model_validate(data)
return FileUploadConfig.model_validate(data)

@classmethod
def validate_and_set_defaults(cls, config: dict, is_vision: bool = True) -> tuple[dict, list[str]]:
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
"""
Validate and set defaults for file upload feature
:param config: app model config args
:param is_vision: if True, the feature is vision feature
"""
if not config.get("file_upload"):
config["file_upload"] = {}
Expand Down
4 changes: 1 addition & 3 deletions api/core/app/apps/advanced_chat/app_config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,7 @@ def config_validate(cls, tenant_id: str, config: dict, only_structure_validate:
related_config_keys = []

# file upload validation
config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(
config=config, is_vision=False
)
config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config=config)
related_config_keys.extend(current_related_config_keys)

# opening_statement
Expand Down
7 changes: 2 additions & 5 deletions api/core/app/apps/advanced_chat/app_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from extensions.ext_database import db
from factories import file_factory
from models.account import Account
from models.enums import CreatedByRole
from models.model import App, Conversation, EndUser, Message
from models.workflow import Workflow

Expand Down Expand Up @@ -98,13 +97,10 @@ def generate(
# parse files
files = args["files"] if args.get("files") else []
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
if file_extra_config:
file_objs = file_factory.build_from_mappings(
mappings=files,
tenant_id=app_model.tenant_id,
user_id=user.id,
role=role,
config=file_extra_config,
)
else:
Expand All @@ -127,10 +123,11 @@ def generate(
application_generate_entity = AdvancedChatAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
file_upload_config=file_extra_config,
conversation_id=conversation.id if conversation else None,
inputs=conversation.inputs
if conversation
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config),
query=query,
files=file_objs,
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
Expand Down
8 changes: 2 additions & 6 deletions api/core/app/apps/agent_chat/app_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from extensions.ext_database import db
from factories import file_factory
from models import Account, App, EndUser
from models.enums import CreatedByRole

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -103,17 +102,13 @@ def generate(
# always enable retriever resource in debugger mode
override_model_config_dict["retriever_resource"] = {"enabled": True}

role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER

# parse files
files = args.get("files") or []
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config:
file_objs = file_factory.build_from_mappings(
mappings=files,
tenant_id=app_model.tenant_id,
user_id=user.id,
role=role,
config=file_extra_config,
)
else:
Expand All @@ -135,10 +130,11 @@ def generate(
task_id=str(uuid.uuid4()),
app_config=app_config,
model_conf=ModelConfigConverter.convert(app_config),
file_upload_config=file_extra_config,
conversation_id=conversation.id if conversation else None,
inputs=conversation.inputs
if conversation
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config),
query=query,
files=file_objs,
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
Expand Down
13 changes: 3 additions & 10 deletions api/core/app/apps/base_app_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@
from typing import TYPE_CHECKING, Any, Optional

from core.app.app_config.entities import VariableEntityType
from core.file import File, FileExtraConfig
from core.file import File, FileUploadConfig
from factories import file_factory

if TYPE_CHECKING:
from core.app.app_config.entities import AppConfig, VariableEntity
from models.enums import CreatedByRole


class BaseAppGenerator:
Expand All @@ -16,8 +15,6 @@ def _prepare_user_inputs(
*,
user_inputs: Optional[Mapping[str, Any]],
app_config: "AppConfig",
user_id: str,
role: "CreatedByRole",
) -> Mapping[str, Any]:
user_inputs = user_inputs or {}
# Filter input variables from form configuration, handle required fields, default values, and option values
Expand All @@ -31,9 +28,7 @@ def _prepare_user_inputs(
k: file_factory.build_from_mapping(
mapping=v,
tenant_id=app_config.tenant_id,
user_id=user_id,
role=role,
config=FileExtraConfig(
config=FileUploadConfig(
allowed_file_types=entity_dictionary[k].allowed_file_types,
allowed_extensions=entity_dictionary[k].allowed_file_extensions,
allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
Expand All @@ -47,9 +42,7 @@ def _prepare_user_inputs(
k: file_factory.build_from_mappings(
mappings=v,
tenant_id=app_config.tenant_id,
user_id=user_id,
role=role,
config=FileExtraConfig(
config=FileUploadConfig(
allowed_file_types=entity_dictionary[k].allowed_file_types,
allowed_extensions=entity_dictionary[k].allowed_file_extensions,
allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
Expand Down
8 changes: 2 additions & 6 deletions api/core/app/apps/chat/app_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from extensions.ext_database import db
from factories import file_factory
from models.account import Account
from models.enums import CreatedByRole
from models.model import App, EndUser

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -101,17 +100,13 @@ def generate(
# always enable retriever resource in debugger mode
override_model_config_dict["retriever_resource"] = {"enabled": True}

role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER

# parse files
files = args["files"] if args.get("files") else []
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config:
file_objs = file_factory.build_from_mappings(
mappings=files,
tenant_id=app_model.tenant_id,
user_id=user.id,
role=role,
config=file_extra_config,
)
else:
Expand All @@ -133,10 +128,11 @@ def generate(
task_id=str(uuid.uuid4()),
app_config=app_config,
model_conf=ModelConfigConverter.convert(app_config),
file_upload_config=file_extra_config,
conversation_id=conversation.id if conversation else None,
inputs=conversation.inputs
if conversation
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config),
query=query,
files=file_objs,
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
Expand Down
Loading

0 comments on commit c1147f8

Please sign in to comment.