diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 507455c17622a7..860ec5de0c8ece 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -30,6 +30,7 @@ ToolPromptMessage, UserPromptMessage, ) +from core.model_runtime.entities.message_entities import ImagePromptMessageContent from core.model_runtime.entities.model_entities import ModelFeature from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.utils.encoders import jsonable_encoder @@ -65,7 +66,7 @@ def __init__( prompt_messages: Optional[list[PromptMessage]] = None, variables_pool: Optional[ToolRuntimeVariablePool] = None, db_variables: Optional[ToolConversationVariables] = None, - model_instance: ModelInstance = None, + model_instance: ModelInstance | None = None, ) -> None: self.tenant_id = tenant_id self.application_generate_entity = application_generate_entity @@ -508,24 +509,27 @@ def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[P def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage: files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all() - if files: - file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict()) - - if file_extra_config: - file_objs = file_factory.build_from_message_files( - message_files=files, tenant_id=self.tenant_id, config=file_extra_config - ) - else: - file_objs = [] + if not files: + return UserPromptMessage(content=message.query) + file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict()) + if not file_extra_config: + return UserPromptMessage(content=message.query) - if not file_objs: - return UserPromptMessage(content=message.query) - else: - prompt_message_contents: list[PromptMessageContent] = [] - prompt_message_contents.append(TextPromptMessageContent(data=message.query)) - for file_obj in file_objs: - prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj)) + image_detail_config = file_extra_config.image_config.detail if file_extra_config.image_config else None + image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW - return UserPromptMessage(content=prompt_message_contents) - else: + file_objs = file_factory.build_from_message_files( + message_files=files, tenant_id=self.tenant_id, config=file_extra_config + ) + if not file_objs: return UserPromptMessage(content=message.query) + prompt_message_contents: list[PromptMessageContent] = [] + prompt_message_contents.append(TextPromptMessageContent(data=message.query)) + for file in file_objs: + prompt_message_contents.append( + file_manager.to_prompt_message_content( + file, + image_detail_config=image_detail_config, + ) + ) + return UserPromptMessage(content=prompt_message_contents) diff --git a/api/core/agent/cot_chat_agent_runner.py b/api/core/agent/cot_chat_agent_runner.py index 6261a9b12c6400..d8d047fe91cdbd 100644 --- a/api/core/agent/cot_chat_agent_runner.py +++ b/api/core/agent/cot_chat_agent_runner.py @@ -10,6 +10,7 @@ TextPromptMessageContent, UserPromptMessage, ) +from core.model_runtime.entities.message_entities import ImagePromptMessageContent from core.model_runtime.utils.encoders import jsonable_encoder @@ -36,8 +37,24 @@ def _organize_user_query(self, query, prompt_messages: list[PromptMessage]) -> l if self.files: prompt_message_contents: list[PromptMessageContent] = [] prompt_message_contents.append(TextPromptMessageContent(data=query)) - for file_obj in self.files: - prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj)) + + # get image detail config + image_detail_config = ( + self.application_generate_entity.file_upload_config.image_config.detail + if ( + self.application_generate_entity.file_upload_config + and self.application_generate_entity.file_upload_config.image_config + ) + else None + ) + image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW + for file in self.files: + prompt_message_contents.append( + file_manager.to_prompt_message_content( + file, + image_detail_config=image_detail_config, + ) + ) prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) else: diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index 9083b4e85ff38a..cd546dee124147 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -22,6 +22,7 @@ ToolPromptMessage, UserPromptMessage, ) +from core.model_runtime.entities.message_entities import ImagePromptMessageContent from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform from core.tools.entities.tool_entities import ToolInvokeMeta from core.tools.tool_engine import ToolEngine @@ -397,8 +398,24 @@ def _organize_user_query(self, query, prompt_messages: list[PromptMessage]) -> l if self.files: prompt_message_contents: list[PromptMessageContent] = [] prompt_message_contents.append(TextPromptMessageContent(data=query)) - for file_obj in self.files: - prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj)) + + # get image detail config + image_detail_config = ( + self.application_generate_entity.file_upload_config.image_config.detail + if ( + self.application_generate_entity.file_upload_config + and self.application_generate_entity.file_upload_config.image_config + ) + else None + ) + image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW + for file in self.files: + prompt_message_contents.append( + file_manager.to_prompt_message_content( + file, + image_detail_config=image_detail_config, + ) + ) prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) else: diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index 6c6e342a073aa2..9b72452d7a1a9e 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -4,7 +4,7 @@ from pydantic import BaseModel, Field, field_validator -from core.file import FileExtraConfig, FileTransferMethod, FileType +from core.file import FileTransferMethod, FileType, FileUploadConfig from core.model_runtime.entities.message_entities import PromptMessageRole from models.model import AppMode @@ -211,7 +211,7 @@ class TracingConfigEntity(BaseModel): class AppAdditionalFeatures(BaseModel): - file_upload: Optional[FileExtraConfig] = None + file_upload: Optional[FileUploadConfig] = None opening_statement: Optional[str] = None suggested_questions: list[str] = [] suggested_questions_after_answer: bool = False diff --git a/api/core/app/app_config/features/file_upload/manager.py b/api/core/app/app_config/features/file_upload/manager.py index 42beec2535483b..2043ea0e41795f 100644 --- a/api/core/app/app_config/features/file_upload/manager.py +++ b/api/core/app/app_config/features/file_upload/manager.py @@ -1,8 +1,7 @@ from collections.abc import Mapping from typing import Any -from core.file.models import FileExtraConfig -from models import FileUploadConfig +from core.file import FileUploadConfig class FileUploadConfigManager: @@ -30,15 +29,14 @@ def convert(cls, config: Mapping[str, Any], is_vision: bool = True): if is_vision: data["image_config"]["detail"] = file_upload_dict.get("image", {}).get("detail", "low") - return FileExtraConfig.model_validate(data) + return FileUploadConfig.model_validate(data) @classmethod - def validate_and_set_defaults(cls, config: dict, is_vision: bool = True) -> tuple[dict, list[str]]: + def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: """ Validate and set defaults for file upload feature :param config: app model config args - :param is_vision: if True, the feature is vision feature """ if not config.get("file_upload"): config["file_upload"] = {} diff --git a/api/core/app/apps/advanced_chat/app_config_manager.py b/api/core/app/apps/advanced_chat/app_config_manager.py index b52f235849f665..cb606953cd7967 100644 --- a/api/core/app/apps/advanced_chat/app_config_manager.py +++ b/api/core/app/apps/advanced_chat/app_config_manager.py @@ -52,9 +52,7 @@ def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: related_config_keys = [] # file upload validation - config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults( - config=config, is_vision=False - ) + config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config=config) related_config_keys.extend(current_related_config_keys) # opening_statement diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 39ab87c9142b0a..0dd0ad1fd84c21 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -26,7 +26,6 @@ from extensions.ext_database import db from factories import file_factory from models.account import Account -from models.enums import CreatedByRole from models.model import App, Conversation, EndUser, Message from models.workflow import Workflow @@ -98,13 +97,10 @@ def generate( # parse files files = args["files"] if args.get("files") else [] file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) - role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER if file_extra_config: file_objs = file_factory.build_from_mappings( mappings=files, tenant_id=app_model.tenant_id, - user_id=user.id, - role=role, config=file_extra_config, ) else: @@ -127,10 +123,11 @@ def generate( application_generate_entity = AdvancedChatAppGenerateEntity( task_id=str(uuid.uuid4()), app_config=app_config, + file_upload_config=file_extra_config, conversation_id=conversation.id if conversation else None, inputs=conversation.inputs if conversation - else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role), + else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config), query=query, files=file_objs, parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index de12f5a441d7d1..5faaf04fbfaa4e 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -23,7 +23,6 @@ from extensions.ext_database import db from factories import file_factory from models import Account, App, EndUser -from models.enums import CreatedByRole logger = logging.getLogger(__name__) @@ -103,8 +102,6 @@ def generate( # always enable retriever resource in debugger mode override_model_config_dict["retriever_resource"] = {"enabled": True} - role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER - # parse files files = args.get("files") or [] file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) @@ -112,8 +109,6 @@ def generate( file_objs = file_factory.build_from_mappings( mappings=files, tenant_id=app_model.tenant_id, - user_id=user.id, - role=role, config=file_extra_config, ) else: @@ -135,10 +130,11 @@ def generate( task_id=str(uuid.uuid4()), app_config=app_config, model_conf=ModelConfigConverter.convert(app_config), + file_upload_config=file_extra_config, conversation_id=conversation.id if conversation else None, inputs=conversation.inputs if conversation - else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role), + else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config), query=query, files=file_objs, parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py index 2707ada6cb9118..bd751e25e5407d 100644 --- a/api/core/app/apps/base_app_generator.py +++ b/api/core/app/apps/base_app_generator.py @@ -2,12 +2,11 @@ from typing import TYPE_CHECKING, Any, Optional from core.app.app_config.entities import VariableEntityType -from core.file import File, FileExtraConfig +from core.file import File, FileUploadConfig from factories import file_factory if TYPE_CHECKING: from core.app.app_config.entities import AppConfig, VariableEntity - from models.enums import CreatedByRole class BaseAppGenerator: @@ -16,8 +15,6 @@ def _prepare_user_inputs( *, user_inputs: Optional[Mapping[str, Any]], app_config: "AppConfig", - user_id: str, - role: "CreatedByRole", ) -> Mapping[str, Any]: user_inputs = user_inputs or {} # Filter input variables from form configuration, handle required fields, default values, and option values @@ -31,9 +28,7 @@ def _prepare_user_inputs( k: file_factory.build_from_mapping( mapping=v, tenant_id=app_config.tenant_id, - user_id=user_id, - role=role, - config=FileExtraConfig( + config=FileUploadConfig( allowed_file_types=entity_dictionary[k].allowed_file_types, allowed_extensions=entity_dictionary[k].allowed_file_extensions, allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods, @@ -47,9 +42,7 @@ def _prepare_user_inputs( k: file_factory.build_from_mappings( mappings=v, tenant_id=app_config.tenant_id, - user_id=user_id, - role=role, - config=FileExtraConfig( + config=FileUploadConfig( allowed_file_types=entity_dictionary[k].allowed_file_types, allowed_extensions=entity_dictionary[k].allowed_file_extensions, allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods, diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index 5c074f5306e4c9..0e71f380f7c94c 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -23,7 +23,6 @@ from extensions.ext_database import db from factories import file_factory from models.account import Account -from models.enums import CreatedByRole from models.model import App, EndUser logger = logging.getLogger(__name__) @@ -101,8 +100,6 @@ def generate( # always enable retriever resource in debugger mode override_model_config_dict["retriever_resource"] = {"enabled": True} - role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER - # parse files files = args["files"] if args.get("files") else [] file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) @@ -110,8 +107,6 @@ def generate( file_objs = file_factory.build_from_mappings( mappings=files, tenant_id=app_model.tenant_id, - user_id=user.id, - role=role, config=file_extra_config, ) else: @@ -133,10 +128,11 @@ def generate( task_id=str(uuid.uuid4()), app_config=app_config, model_conf=ModelConfigConverter.convert(app_config), + file_upload_config=file_extra_config, conversation_id=conversation.id if conversation else None, inputs=conversation.inputs if conversation - else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role), + else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config), query=query, files=file_objs, parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index 46450d39c0d6e1..9b4db3902c8285 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -22,7 +22,6 @@ from extensions.ext_database import db from factories import file_factory from models import Account, App, EndUser, Message -from models.enums import CreatedByRole from services.errors.app import MoreLikeThisDisabledError from services.errors.message import MessageNotExistsError @@ -88,8 +87,6 @@ def generate( tenant_id=app_model.tenant_id, config=args.get("model_config") ) - role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER - # parse files files = args["files"] if args.get("files") else [] file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) @@ -97,8 +94,6 @@ def generate( file_objs = file_factory.build_from_mappings( mappings=files, tenant_id=app_model.tenant_id, - user_id=user.id, - role=role, config=file_extra_config, ) else: @@ -110,7 +105,6 @@ def generate( ) # get tracing instance - user_id = user.id if isinstance(user, Account) else user.session_id trace_manager = TraceQueueManager(app_model.id) # init application generate entity @@ -118,7 +112,8 @@ def generate( task_id=str(uuid.uuid4()), app_config=app_config, model_conf=ModelConfigConverter.convert(app_config), - inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role), + file_upload_config=file_extra_config, + inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config), query=query, files=file_objs, user_id=user.id, @@ -259,14 +254,11 @@ def generate_more_like_this( override_model_config_dict["model"] = model_dict # parse files - role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER file_extra_config = FileUploadConfigManager.convert(override_model_config_dict) if file_extra_config: file_objs = file_factory.build_from_mappings( mappings=message.message_files, tenant_id=app_model.tenant_id, - user_id=user.id, - role=role, config=file_extra_config, ) else: diff --git a/api/core/app/apps/workflow/app_config_manager.py b/api/core/app/apps/workflow/app_config_manager.py index 8b98e74b85969b..b0aa21c7317b65 100644 --- a/api/core/app/apps/workflow/app_config_manager.py +++ b/api/core/app/apps/workflow/app_config_manager.py @@ -46,9 +46,7 @@ def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: related_config_keys = [] # file upload validation - config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults( - config=config, is_vision=False - ) + config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config=config) related_config_keys.extend(current_related_config_keys) # text_to_speech diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index a865c8a68b3aa9..b68afdf2125d03 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -25,7 +25,6 @@ from extensions.ext_database import db from factories import file_factory from models import Account, App, EndUser, Workflow -from models.enums import CreatedByRole logger = logging.getLogger(__name__) @@ -70,15 +69,11 @@ def generate( ): files: Sequence[Mapping[str, Any]] = args.get("files") or [] - role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER - # parse files file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) system_files = file_factory.build_from_mappings( mappings=files, tenant_id=app_model.tenant_id, - user_id=user.id, - role=role, config=file_extra_config, ) @@ -100,7 +95,8 @@ def generate( application_generate_entity = WorkflowAppGenerateEntity( task_id=str(uuid.uuid4()), app_config=app_config, - inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role), + file_upload_config=file_extra_config, + inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config), files=system_files, user_id=user.id, stream=stream, diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index f2eba293236466..31c3a996e19286 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -7,7 +7,7 @@ from constants import UUID_NIL from core.app.app_config.entities import AppConfig, EasyUIBasedAppConfig, WorkflowUIBasedAppConfig from core.entities.provider_configuration import ProviderModelBundle -from core.file.models import File +from core.file import File, FileUploadConfig from core.model_runtime.entities.model_entities import AIModelEntity from core.ops.ops_trace_manager import TraceQueueManager @@ -80,6 +80,7 @@ class AppGenerateEntity(BaseModel): # app config app_config: AppConfig + file_upload_config: Optional[FileUploadConfig] = None inputs: Mapping[str, Any] files: Sequence[File] diff --git a/api/core/file/__init__.py b/api/core/file/__init__.py index bdaf8793fa10ff..fe9e52258ac046 100644 --- a/api/core/file/__init__.py +++ b/api/core/file/__init__.py @@ -2,13 +2,13 @@ from .enums import ArrayFileAttribute, FileAttribute, FileBelongsTo, FileTransferMethod, FileType from .models import ( File, - FileExtraConfig, + FileUploadConfig, ImageConfig, ) __all__ = [ "FileType", - "FileExtraConfig", + "FileUploadConfig", "FileTransferMethod", "FileBelongsTo", "File", diff --git a/api/core/file/file_manager.py b/api/core/file/file_manager.py index b69d7a74c098a3..f0aae6fa5dd8c9 100644 --- a/api/core/file/file_manager.py +++ b/api/core/file/file_manager.py @@ -33,25 +33,28 @@ def get_attr(*, file: File, attr: FileAttribute): raise ValueError(f"Invalid file attribute: {attr}") -def to_prompt_message_content(f: File, /): +def to_prompt_message_content( + f: File, + /, + *, + image_detail_config: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.LOW, +): """ - Convert a File object to an ImagePromptMessageContent object. + Convert a File object to an ImagePromptMessageContent or AudioPromptMessageContent object. - This function takes a File object and converts it to an ImagePromptMessageContent - object, which can be used as a prompt for image-based AI models. + This function takes a File object and converts it to an appropriate PromptMessageContent + object, which can be used as a prompt for image or audio-based AI models. Args: - file (File): The File object to convert. Must be of type FileType.IMAGE. + f (File): The File object to convert. + detail (Optional[ImagePromptMessageContent.DETAIL]): The detail level for image prompts. + If not provided, defaults to ImagePromptMessageContent.DETAIL.LOW. Returns: - ImagePromptMessageContent: An object containing the image data and detail level. + Union[ImagePromptMessageContent, AudioPromptMessageContent]: An object containing the file data and detail level Raises: - ValueError: If the file is not an image or if the file data is missing. - - Note: - The detail level of the image prompt is determined by the file's extra_config. - If not specified, it defaults to ImagePromptMessageContent.DETAIL.LOW. + ValueError: If the file type is not supported or if required data is missing. """ match f.type: case FileType.IMAGE: @@ -60,19 +63,14 @@ def to_prompt_message_content(f: File, /): else: data = _to_base64_data_string(f) - if f._extra_config and f._extra_config.image_config and f._extra_config.image_config.detail: - detail = f._extra_config.image_config.detail - else: - detail = ImagePromptMessageContent.DETAIL.LOW - - return ImagePromptMessageContent(data=data, detail=detail) + return ImagePromptMessageContent(data=data, detail=image_detail_config) case FileType.AUDIO: encoded_string = _file_to_encoded_string(f) if f.extension is None: raise ValueError("Missing file extension") return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip(".")) case _: - raise ValueError(f"file type {f.type} is not supported") + raise ValueError("file type f.type is not supported") def download(f: File, /): diff --git a/api/core/file/models.py b/api/core/file/models.py index 866ff3155b7df5..0142893787e073 100644 --- a/api/core/file/models.py +++ b/api/core/file/models.py @@ -21,7 +21,7 @@ class ImageConfig(BaseModel): detail: ImagePromptMessageContent.DETAIL | None = None -class FileExtraConfig(BaseModel): +class FileUploadConfig(BaseModel): """ File Upload Entity. """ @@ -46,7 +46,6 @@ class File(BaseModel): extension: Optional[str] = Field(default=None, description="File extension, should contains dot") mime_type: Optional[str] = None size: int = -1 - _extra_config: FileExtraConfig | None = None def to_dict(self) -> Mapping[str, str | int | None]: data = self.model_dump(mode="json") @@ -107,34 +106,4 @@ def validate_after(self): case FileTransferMethod.TOOL_FILE: if not self.related_id: raise ValueError("Missing file related_id") - - # Validate the extra config. - if not self._extra_config: - return self - - if self._extra_config.allowed_file_types: - if self.type not in self._extra_config.allowed_file_types and self.type != FileType.CUSTOM: - raise ValueError(f"Invalid file type: {self.type}") - - if self._extra_config.allowed_extensions and self.extension not in self._extra_config.allowed_extensions: - raise ValueError(f"Invalid file extension: {self.extension}") - - if ( - self._extra_config.allowed_upload_methods - and self.transfer_method not in self._extra_config.allowed_upload_methods - ): - raise ValueError(f"Invalid transfer method: {self.transfer_method}") - - match self.type: - case FileType.IMAGE: - # NOTE: This part of validation is deprecated, but still used in app features "Image Upload". - if not self._extra_config.image_config: - return self - # TODO: skip check if transfer_methods is empty, because many test cases are not setting this field - if ( - self._extra_config.image_config.transfer_methods - and self.transfer_method not in self._extra_config.image_config.transfer_methods - ): - raise ValueError(f"Invalid transfer method: {self.transfer_method}") - return self diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index d92c36a2df9024..688fb4776a86e1 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -81,15 +81,18 @@ def get_history_prompt_messages( db.session.query(WorkflowRun).filter(WorkflowRun.id == message.workflow_run_id).first() ) - if workflow_run: + if workflow_run and workflow_run.workflow: file_extra_config = FileUploadConfigManager.convert( workflow_run.workflow.features_dict, is_vision=False ) + detail = ImagePromptMessageContent.DETAIL.LOW if file_extra_config and app_record: file_objs = file_factory.build_from_message_files( message_files=files, tenant_id=app_record.tenant_id, config=file_extra_config ) + if file_extra_config.image_config and file_extra_config.image_config.detail: + detail = file_extra_config.image_config.detail else: file_objs = [] @@ -98,12 +101,16 @@ def get_history_prompt_messages( else: prompt_message_contents: list[PromptMessageContent] = [] prompt_message_contents.append(TextPromptMessageContent(data=message.query)) - for file_obj in file_objs: - if file_obj.type in {FileType.IMAGE, FileType.AUDIO}: - prompt_message = file_manager.to_prompt_message_content(file_obj) + for file in file_objs: + if file.type in {FileType.IMAGE, FileType.AUDIO}: + prompt_message = file_manager.to_prompt_message_content( + file, + image_detail_config=detail, + ) prompt_message_contents.append(prompt_message) prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) + else: prompt_messages.append(UserPromptMessage(content=message.query)) diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index bbd9531b192596..0f3f8249661bf0 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -15,6 +15,7 @@ TextPromptMessageContent, UserPromptMessage, ) +from core.model_runtime.entities.message_entities import ImagePromptMessageContent from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig from core.prompt.prompt_transform import PromptTransform from core.prompt.utils.prompt_template_parser import PromptTemplateParser @@ -26,8 +27,13 @@ class AdvancedPromptTransform(PromptTransform): Advanced Prompt Transform for Workflow LLM Node. """ - def __init__(self, with_variable_tmpl: bool = False) -> None: + def __init__( + self, + with_variable_tmpl: bool = False, + image_detail_config: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.LOW, + ) -> None: self.with_variable_tmpl = with_variable_tmpl + self.image_detail_config = image_detail_config def get_prompt( self, diff --git a/api/core/tools/provider/builtin/vectorizer/vectorizer.py b/api/core/tools/provider/builtin/vectorizer/vectorizer.py index 81403487235e34..211ec78f4d6a58 100644 --- a/api/core/tools/provider/builtin/vectorizer/vectorizer.py +++ b/api/core/tools/provider/builtin/vectorizer/vectorizer.py @@ -1,19 +1,23 @@ from typing import Any -from core.file import File -from core.file.enums import FileTransferMethod, FileType +from core.file import FileTransferMethod, FileType from core.tools.errors import ToolProviderCredentialValidationError from core.tools.provider.builtin.vectorizer.tools.vectorizer import VectorizerTool from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController +from factories import file_factory class VectorizerProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: - test_img = File( + mapping = { + "transfer_method": FileTransferMethod.TOOL_FILE, + "type": FileType.IMAGE, + "id": "test_id", + "url": "https://cloud.dify.ai/logo/logo-site.png", + } + test_img = file_factory.build_from_mapping( + mapping=mapping, tenant_id="__test_123", - remote_url="https://cloud.dify.ai/logo/logo-site.png", - type=FileType.IMAGE, - transfer_method=FileTransferMethod.REMOTE_URL, ) try: VectorizerTool().fork_tool_runtime( diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py index a037bee665b019..c51151e7ada15a 100644 --- a/api/core/workflow/nodes/http_request/node.py +++ b/api/core/workflow/nodes/http_request/node.py @@ -13,6 +13,7 @@ from core.workflow.nodes.enums import NodeType from core.workflow.nodes.http_request.executor import Executor from core.workflow.utils import variable_template_parser +from factories import file_factory from models.workflow import WorkflowNodeExecutionStatus from .entities import ( @@ -160,16 +161,15 @@ def extract_files(self, url: str, response: Response) -> list[File]: mimetype=content_type, ) - files.append( - File( - tenant_id=self.tenant_id, - type=FileType.IMAGE, - transfer_method=FileTransferMethod.TOOL_FILE, - related_id=tool_file.id, - filename=filename, - extension=extension, - mime_type=content_type, - ) + mapping = { + "tool_file_id": tool_file.id, + "type": FileType.IMAGE.value, + "transfer_method": FileTransferMethod.TOOL_FILE.value, + } + file = file_factory.build_from_mapping( + mapping=mapping, + tenant_id=self.tenant_id, ) + files.append(file) return files diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index df22130d6955e0..a9acc63f433f32 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -18,6 +18,7 @@ from core.workflow.nodes.tool.entities import ToolNodeData from core.workflow.utils.variable_template_parser import VariableTemplateParser from extensions.ext_database import db +from factories import file_factory from models import ToolFile from models.workflow import WorkflowNodeExecutionStatus @@ -183,19 +184,17 @@ def _extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) if tool_file is None: raise ValueError(f"tool file {tool_file_id} not exists") - result.append( - File( - tenant_id=self.tenant_id, - type=FileType.IMAGE, - transfer_method=transfer_method, - remote_url=url, - related_id=tool_file.id, - filename=tool_file.name, - extension=ext, - mime_type=tool_file.mimetype, - size=tool_file.size, - ) + mapping = { + "tool_file_id": tool_file_id, + "type": FileType.IMAGE, + "transfer_method": transfer_method, + "url": url, + } + file = file_factory.build_from_mapping( + mapping=mapping, + tenant_id=self.tenant_id, ) + result.append(file) elif response.type == ToolInvokeMessage.MessageType.BLOB: # get tool file id tool_file_id = str(response.message).split("/")[-1].split(".")[0] @@ -204,18 +203,16 @@ def _extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) tool_file = session.scalar(stmt) if tool_file is None: raise ValueError(f"tool file {tool_file_id} not exists") - result.append( - File( - tenant_id=self.tenant_id, - type=FileType.IMAGE, - transfer_method=FileTransferMethod.TOOL_FILE, - related_id=tool_file.id, - filename=tool_file.name, - extension=path.splitext(response.save_as)[1], - mime_type=tool_file.mimetype, - size=tool_file.size, - ) + mapping = { + "tool_file_id": tool_file_id, + "type": FileType.IMAGE, + "transfer_method": FileTransferMethod.TOOL_FILE, + } + file = file_factory.build_from_mapping( + mapping=mapping, + tenant_id=self.tenant_id, ) + result.append(file) elif response.type == ToolInvokeMessage.MessageType.LINK: url = str(response.message) transfer_method = FileTransferMethod.TOOL_FILE @@ -229,16 +226,15 @@ def _extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) extension = "." + url.split("/")[-1].split(".")[1] else: extension = ".bin" - file = File( + mapping = { + "tool_file_id": tool_file_id, + "type": FileType.IMAGE, + "transfer_method": transfer_method, + "url": url, + } + file = file_factory.build_from_mapping( + mapping=mapping, tenant_id=self.tenant_id, - type=FileType(response.save_as), - transfer_method=transfer_method, - remote_url=url, - filename=tool_file.name, - related_id=tool_file.id, - extension=extension, - mime_type=tool_file.mimetype, - size=tool_file.size, ) result.append(file) diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index eb812bad21b6a1..84b251223f96f1 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -5,10 +5,10 @@ from typing import Any, Optional, cast from configs import dify_config -from core.app.app_config.entities import FileExtraConfig +from core.app.app_config.entities import FileUploadConfig from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom -from core.file.models import File, FileTransferMethod, FileType, ImageConfig +from core.file.models import File, FileTransferMethod, ImageConfig from core.workflow.callbacks import WorkflowCallback from core.workflow.entities.variable_pool import VariablePool from core.workflow.errors import WorkflowNodeRunFailedError @@ -22,6 +22,7 @@ from core.workflow.nodes.event import NodeEvent from core.workflow.nodes.llm import LLMNodeData from core.workflow.nodes.node_mapping import node_type_classes_mapping +from factories import file_factory from models.enums import UserFrom from models.workflow import ( Workflow, @@ -271,19 +272,17 @@ def mapping_user_inputs_to_variable_pool( for item in input_value: if isinstance(item, dict) and "type" in item and item["type"] == "image": transfer_method = FileTransferMethod.value_of(item.get("transfer_method")) - file = File( + mapping = { + "id": item.get("id"), + "transfer_method": transfer_method, + "upload_file_id": item.get("upload_file_id"), + "url": item.get("url"), + } + config = FileUploadConfig(image_config=ImageConfig(detail=detail) if detail else None) + file = file_factory.build_from_mapping( + mapping=mapping, tenant_id=tenant_id, - type=FileType.IMAGE, - transfer_method=transfer_method, - remote_url=item.get("url") - if transfer_method == FileTransferMethod.REMOTE_URL - else None, - related_id=item.get("upload_file_id") - if transfer_method == FileTransferMethod.LOCAL_FILE - else None, - _extra_config=FileExtraConfig( - image_config=ImageConfig(detail=detail) if detail else None - ), + config=config, ) new_value.append(file) diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index 1066dc8862baa6..738b2b3478f46a 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -1,23 +1,21 @@ import mimetypes -from collections.abc import Mapping, Sequence +from collections.abc import Callable, Mapping, Sequence from typing import Any import httpx from sqlalchemy import select -from constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS -from core.file import File, FileBelongsTo, FileExtraConfig, FileTransferMethod, FileType +from core.file import File, FileBelongsTo, FileTransferMethod, FileType, FileUploadConfig from core.helper import ssrf_proxy from extensions.ext_database import db from models import MessageFile, ToolFile, UploadFile -from models.enums import CreatedByRole def build_from_message_files( *, message_files: Sequence["MessageFile"], tenant_id: str, - config: FileExtraConfig, + config: FileUploadConfig, ) -> Sequence[File]: results = [ build_from_message_file(message_file=file, tenant_id=tenant_id, config=config) @@ -31,7 +29,7 @@ def build_from_message_file( *, message_file: "MessageFile", tenant_id: str, - config: FileExtraConfig, + config: FileUploadConfig, ): mapping = { "transfer_method": message_file.transfer_method, @@ -43,8 +41,6 @@ def build_from_message_file( return build_from_mapping( mapping=mapping, tenant_id=tenant_id, - user_id=message_file.created_by, - role=CreatedByRole(message_file.created_by_role), config=config, ) @@ -53,38 +49,30 @@ def build_from_mapping( *, mapping: Mapping[str, Any], tenant_id: str, - user_id: str, - role: "CreatedByRole", - config: FileExtraConfig, -): + config: FileUploadConfig | None = None, +) -> File: + config = config or FileUploadConfig() + transfer_method = FileTransferMethod.value_of(mapping.get("transfer_method")) - match transfer_method: - case FileTransferMethod.REMOTE_URL: - file = _build_from_remote_url( - mapping=mapping, - tenant_id=tenant_id, - config=config, - transfer_method=transfer_method, - ) - case FileTransferMethod.LOCAL_FILE: - file = _build_from_local_file( - mapping=mapping, - tenant_id=tenant_id, - user_id=user_id, - role=role, - config=config, - transfer_method=transfer_method, - ) - case FileTransferMethod.TOOL_FILE: - file = _build_from_tool_file( - mapping=mapping, - tenant_id=tenant_id, - user_id=user_id, - config=config, - transfer_method=transfer_method, - ) - case _: - raise ValueError(f"Invalid file transfer method: {transfer_method}") + + build_functions: dict[FileTransferMethod, Callable] = { + FileTransferMethod.LOCAL_FILE: _build_from_local_file, + FileTransferMethod.REMOTE_URL: _build_from_remote_url, + FileTransferMethod.TOOL_FILE: _build_from_tool_file, + } + + build_func = build_functions.get(transfer_method) + if not build_func: + raise ValueError(f"Invalid file transfer method: {transfer_method}") + + file = build_func( + mapping=mapping, + tenant_id=tenant_id, + transfer_method=transfer_method, + ) + + if not _is_file_valid_with_config(file=file, config=config): + raise ValueError(f"File validation failed for file: {file.filename}") return file @@ -92,10 +80,8 @@ def build_from_mapping( def build_from_mappings( *, mappings: Sequence[Mapping[str, Any]], - config: FileExtraConfig | None, + config: FileUploadConfig | None, tenant_id: str, - user_id: str, - role: "CreatedByRole", ) -> Sequence[File]: if not config: return [] @@ -104,8 +90,6 @@ def build_from_mappings( build_from_mapping( mapping=mapping, tenant_id=tenant_id, - user_id=user_id, - role=role, config=config, ) for mapping in mappings @@ -128,31 +112,20 @@ def _build_from_local_file( *, mapping: Mapping[str, Any], tenant_id: str, - user_id: str, - role: "CreatedByRole", - config: FileExtraConfig, transfer_method: FileTransferMethod, -): - # check if the upload file exists. +) -> File: file_type = FileType.value_of(mapping.get("type")) stmt = select(UploadFile).where( UploadFile.id == mapping.get("upload_file_id"), UploadFile.tenant_id == tenant_id, - UploadFile.created_by == user_id, - UploadFile.created_by_role == role, ) - if file_type == FileType.IMAGE: - stmt = stmt.where(UploadFile.extension.in_(IMAGE_EXTENSIONS)) - elif file_type == FileType.VIDEO: - stmt = stmt.where(UploadFile.extension.in_(VIDEO_EXTENSIONS)) - elif file_type == FileType.AUDIO: - stmt = stmt.where(UploadFile.extension.in_(AUDIO_EXTENSIONS)) - elif file_type == FileType.DOCUMENT: - stmt = stmt.where(UploadFile.extension.in_(DOCUMENT_EXTENSIONS)) + row = db.session.scalar(stmt) + if row is None: raise ValueError("Invalid upload file") - file = File( + + return File( id=mapping.get("id"), filename=row.name, extension="." + row.extension, @@ -162,80 +135,72 @@ def _build_from_local_file( transfer_method=transfer_method, remote_url=row.source_url, related_id=mapping.get("upload_file_id"), - _extra_config=config, size=row.size, ) - return file def _build_from_remote_url( *, mapping: Mapping[str, Any], tenant_id: str, - config: FileExtraConfig, transfer_method: FileTransferMethod, -): +) -> File: url = mapping.get("url") if not url: raise ValueError("Invalid file url") - mime_type = mimetypes.guess_type(url)[0] or "" - file_size = -1 - filename = url.split("/")[-1].split("?")[0] or "unknown_file" - - resp = ssrf_proxy.head(url, follow_redirects=True) - if resp.status_code == httpx.codes.OK: - if content_disposition := resp.headers.get("Content-Disposition"): - filename = content_disposition.split("filename=")[-1].strip('"') - file_size = int(resp.headers.get("Content-Length", file_size)) - mime_type = mime_type or str(resp.headers.get("Content-Type", "")) - - # Determine file extension + mime_type, filename, file_size = _get_remote_file_info(url) extension = mimetypes.guess_extension(mime_type) or "." + filename.split(".")[-1] if "." in filename else ".bin" - if not mime_type: - mime_type, _ = mimetypes.guess_type(url) - file = File( + return File( id=mapping.get("id"), filename=filename, tenant_id=tenant_id, type=FileType.value_of(mapping.get("type")), transfer_method=transfer_method, remote_url=url, - _extra_config=config, mime_type=mime_type, extension=extension, size=file_size, ) - return file + + +def _get_remote_file_info(url: str): + mime_type = mimetypes.guess_type(url)[0] or "" + file_size = -1 + filename = url.split("/")[-1].split("?")[0] or "unknown_file" + + resp = ssrf_proxy.head(url, follow_redirects=True) + if resp.status_code == httpx.codes.OK: + if content_disposition := resp.headers.get("Content-Disposition"): + filename = str(content_disposition.split("filename=")[-1].strip('"')) + file_size = int(resp.headers.get("Content-Length", file_size)) + mime_type = mime_type or str(resp.headers.get("Content-Type", "")) + + return mime_type, filename, file_size def _build_from_tool_file( *, mapping: Mapping[str, Any], tenant_id: str, - user_id: str, - config: FileExtraConfig, transfer_method: FileTransferMethod, -): +) -> File: tool_file = ( db.session.query(ToolFile) .filter( ToolFile.id == mapping.get("tool_file_id"), ToolFile.tenant_id == tenant_id, - ToolFile.user_id == user_id, ) .first() ) + if tool_file is None: raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found") - path = tool_file.file_key - if "." in path: - extension = "." + path.split("/")[-1].split(".")[-1] - else: - extension = ".bin" - file = File( + extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin" + + return File( id=mapping.get("id"), tenant_id=tenant_id, filename=tool_file.name, @@ -246,6 +211,21 @@ def _build_from_tool_file( extension=extension, mime_type=tool_file.mimetype, size=tool_file.size, - _extra_config=config, ) - return file + + +def _is_file_valid_with_config(*, file: File, config: FileUploadConfig) -> bool: + if config.allowed_file_types and file.type not in config.allowed_file_types and file.type != FileType.CUSTOM: + return False + + if config.allowed_extensions and file.extension not in config.allowed_extensions: + return False + + if config.allowed_upload_methods and file.transfer_method not in config.allowed_upload_methods: + return False + + if file.type == FileType.IMAGE and config.image_config: + if config.image_config.transfer_methods and file.transfer_method not in config.image_config.transfer_methods: + return False + + return True diff --git a/api/models/__init__.py b/api/models/__init__.py index 1d8bae6cfaaed3..cd6c7674da0847 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -6,7 +6,6 @@ AppMode, Conversation, EndUser, - FileUploadConfig, InstalledApp, Message, MessageAnnotation, @@ -50,6 +49,5 @@ "Tenant", "Conversation", "MessageAnnotation", - "FileUploadConfig", "ToolFile", ] diff --git a/api/models/model.py b/api/models/model.py index e9c6b6732fe165..7ee5b0c5797475 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -1,7 +1,7 @@ import json import re import uuid -from collections.abc import Mapping, Sequence +from collections.abc import Mapping from datetime import datetime from enum import Enum from typing import Any, Literal, Optional @@ -9,12 +9,11 @@ import sqlalchemy as sa from flask import request from flask_login import UserMixin -from pydantic import BaseModel, Field from sqlalchemy import Float, func, text from sqlalchemy.orm import Mapped, mapped_column from configs import dify_config -from core.file import FILE_MODEL_IDENTITY, File, FileExtraConfig, FileTransferMethod, FileType +from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType from core.file import helpers as file_helpers from core.file.tool_file_parser import ToolFileParser from extensions.ext_database import db @@ -25,14 +24,6 @@ from .types import StringUUID -class FileUploadConfig(BaseModel): - enabled: bool = Field(default=False) - allowed_file_types: Sequence[FileType] = Field(default_factory=list) - allowed_extensions: Sequence[str] = Field(default_factory=list) - allowed_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list) - number_limits: int = Field(default=0, gt=0, le=10) - - class DifySetup(db.Model): __tablename__ = "dify_setups" __table_args__ = (db.PrimaryKeyConstraint("version", name="dify_setup_pkey"),) @@ -958,9 +949,6 @@ def message_files(self): "type": message_file.type, }, tenant_id=current_app.tenant_id, - user_id=self.from_account_id or self.from_end_user_id or "", - role=CreatedByRole(message_file.created_by_role), - config=FileExtraConfig(), ) elif message_file.transfer_method == "remote_url": if message_file.url is None: @@ -973,9 +961,6 @@ def message_files(self): "url": message_file.url, }, tenant_id=current_app.tenant_id, - user_id=self.from_account_id or self.from_end_user_id or "", - role=CreatedByRole(message_file.created_by_role), - config=FileExtraConfig(), ) elif message_file.transfer_method == "tool_file": if message_file.upload_file_id is None: @@ -990,9 +975,6 @@ def message_files(self): file = file_factory.build_from_mapping( mapping=mapping, tenant_id=current_app.tenant_id, - user_id=self.from_account_id or self.from_end_user_id or "", - role=CreatedByRole(message_file.created_by_role), - config=FileExtraConfig(), ) else: raise ValueError( diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 75c11afa945db1..90b5cc48362f3b 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -13,7 +13,7 @@ from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager from core.app.apps.chat.app_config_manager import ChatAppConfigManager from core.app.apps.completion.app_config_manager import CompletionAppConfigManager -from core.file.models import FileExtraConfig +from core.file.models import FileUploadConfig from core.helper import encrypter from core.model_runtime.entities.llm_entities import LLMMode from core.model_runtime.utils.encoders import jsonable_encoder @@ -381,7 +381,7 @@ def _convert_to_llm_node( graph: dict, model_config: ModelConfigEntity, prompt_template: PromptTemplateEntity, - file_upload: Optional[FileExtraConfig] = None, + file_upload: Optional[FileUploadConfig] = None, external_data_variable_node_mapping: dict[str, str] | None = None, ) -> dict: """ diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index 0da6622658e844..9eea63f722e51f 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -430,37 +430,3 @@ def test_multi_colons_parse(setup_http_mock): assert urlencode({"Redirect": "http://example2.com"}) in result.process_data.get("request", "") assert 'form-data; name="Redirect"\r\n\r\nhttp://example6.com' in result.process_data.get("request", "") # assert "http://example3.com" == resp.get("headers", {}).get("referer") - - -def test_image_file(monkeypatch): - from types import SimpleNamespace - - monkeypatch.setattr( - "core.tools.tool_file_manager.ToolFileManager.create_file_by_raw", - lambda *args, **kwargs: SimpleNamespace(id="1"), - ) - - node = init_http_node( - config={ - "id": "1", - "data": { - "title": "http", - "desc": "", - "method": "get", - "url": "https://cloud.dify.ai/logo/logo-site.png", - "authorization": { - "type": "no-auth", - "config": None, - }, - "params": "", - "headers": "", - "body": None, - }, - } - ) - - result = node._run() - assert result.process_data is not None - assert result.outputs is not None - resp = result.outputs - assert len(resp.get("files", [])) == 1 diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py index ece2173090233f..7d19cff3e8ece6 100644 --- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py @@ -3,7 +3,7 @@ import pytest from core.app.app_config.entities import ModelConfigEntity -from core.file import File, FileExtraConfig, FileTransferMethod, FileType, ImageConfig +from core.file import File, FileTransferMethod, FileType, FileUploadConfig, ImageConfig from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, @@ -134,7 +134,6 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/image1.jpg", - _extra_config=FileExtraConfig(image_config=ImageConfig(detail=ImagePromptMessageContent.DETAIL.HIGH)), ) ]