diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml index 8173bee58e8e24..7c632f8a34d56a 100644 --- a/.github/workflows/api-tests.yml +++ b/.github/workflows/api-tests.yml @@ -76,7 +76,7 @@ jobs: - name: Run Workflow run: poetry run -C api bash dev/pytest/pytest_workflow.sh - - name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale) + - name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch) uses: hoverkraft-tech/compose-action@v2.0.0 with: compose-file: | @@ -90,5 +90,6 @@ jobs: pgvecto-rs pgvector chroma + elasticsearch - name: Test Vector Stores run: poetry run -C api bash dev/pytest/pytest_vdb.sh diff --git a/.github/workflows/expose_service_ports.sh b/.github/workflows/expose_service_ports.sh index 3418bf0c6f6688..ae3e0ee69d8cfb 100755 --- a/.github/workflows/expose_service_ports.sh +++ b/.github/workflows/expose_service_ports.sh @@ -6,5 +6,6 @@ yq eval '.services.chroma.ports += ["8000:8000"]' -i docker/docker-compose.yaml yq eval '.services["milvus-standalone"].ports += ["19530:19530"]' -i docker/docker-compose.yaml yq eval '.services.pgvector.ports += ["5433:5432"]' -i docker/docker-compose.yaml yq eval '.services["pgvecto-rs"].ports += ["5431:5432"]' -i docker/docker-compose.yaml +yq eval '.services["elasticsearch"].ports += ["9200:9200"]' -i docker/docker-compose.yaml -echo "Ports exposed for sandbox, weaviate, qdrant, chroma, milvus, pgvector, pgvecto-rs." \ No newline at end of file +echo "Ports exposed for sandbox, weaviate, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch" \ No newline at end of file diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index f6092c86337d85..d681dc66276dd1 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -45,6 +45,10 @@ jobs: if: steps.changed-files.outputs.any_changed == 'true' run: poetry run -C api dotenv-linter ./api/.env.example ./web/.env.example + - name: Ruff formatter check + if: steps.changed-files.outputs.any_changed == 'true' + run: poetry run -C api ruff format --check ./api + - name: Lint hints if: failure() run: echo "Please run 'dev/reformat' to fix the fixable linting errors." diff --git a/api/.env.example b/api/.env.example index cf3a0f302d60cc..f81675fd53810c 100644 --- a/api/.env.example +++ b/api/.env.example @@ -130,6 +130,12 @@ TENCENT_VECTOR_DB_DATABASE=dify TENCENT_VECTOR_DB_SHARD=1 TENCENT_VECTOR_DB_REPLICAS=2 +# ElasticSearch configuration +ELASTICSEARCH_HOST=127.0.0.1 +ELASTICSEARCH_PORT=9200 +ELASTICSEARCH_USERNAME=elastic +ELASTICSEARCH_PASSWORD=elastic + # PGVECTO_RS configuration PGVECTO_RS_HOST=localhost PGVECTO_RS_PORT=5431 @@ -261,4 +267,13 @@ APP_MAX_ACTIVE_REQUESTS=0 # Celery beat configuration -CELERY_BEAT_SCHEDULER_TIME=1 \ No newline at end of file +CELERY_BEAT_SCHEDULER_TIME=1 + +# Position configuration +POSITION_TOOL_PINS= +POSITION_TOOL_INCLUDES= +POSITION_TOOL_EXCLUDES= + +POSITION_PROVIDER_PINS= +POSITION_PROVIDER_INCLUDES= +POSITION_PROVIDER_EXCLUDES= diff --git a/api/app.py b/api/app.py index 50441cb81da1f4..ad219ca0d67459 100644 --- a/api/app.py +++ b/api/app.py @@ -1,6 +1,6 @@ import os -if os.environ.get("DEBUG", "false").lower() != 'true': +if os.environ.get("DEBUG", "false").lower() != "true": from gevent import monkey monkey.patch_all() @@ -57,7 +57,7 @@ if os.name == "nt": os.system('tzutil /s "UTC"') else: - os.environ['TZ'] = 'UTC' + os.environ["TZ"] = "UTC" time.tzset() @@ -70,13 +70,14 @@ class DifyApp(Flask): # ------------- -config_type = os.getenv('EDITION', default='SELF_HOSTED') # ce edition first +config_type = os.getenv("EDITION", default="SELF_HOSTED") # ce edition first # ---------------------------- # Application Factory Function # ---------------------------- + def create_flask_app_with_configs() -> Flask: """ create a raw flask app @@ -92,7 +93,7 @@ def create_flask_app_with_configs() -> Flask: elif isinstance(value, int | float | bool): os.environ[key] = str(value) elif value is None: - os.environ[key] = '' + os.environ[key] = "" return dify_app @@ -100,10 +101,10 @@ def create_flask_app_with_configs() -> Flask: def create_app() -> Flask: app = create_flask_app_with_configs() - app.secret_key = app.config['SECRET_KEY'] + app.secret_key = app.config["SECRET_KEY"] log_handlers = None - log_file = app.config.get('LOG_FILE') + log_file = app.config.get("LOG_FILE") if log_file: log_dir = os.path.dirname(log_file) os.makedirs(log_dir, exist_ok=True) @@ -111,23 +112,24 @@ def create_app() -> Flask: RotatingFileHandler( filename=log_file, maxBytes=1024 * 1024 * 1024, - backupCount=5 + backupCount=5, ), - logging.StreamHandler(sys.stdout) + logging.StreamHandler(sys.stdout), ] logging.basicConfig( - level=app.config.get('LOG_LEVEL'), - format=app.config.get('LOG_FORMAT'), - datefmt=app.config.get('LOG_DATEFORMAT'), + level=app.config.get("LOG_LEVEL"), + format=app.config.get("LOG_FORMAT"), + datefmt=app.config.get("LOG_DATEFORMAT"), handlers=log_handlers, - force=True + force=True, ) - log_tz = app.config.get('LOG_TZ') + log_tz = app.config.get("LOG_TZ") if log_tz: from datetime import datetime import pytz + timezone = pytz.timezone(log_tz) def time_converter(seconds): @@ -162,24 +164,24 @@ def initialize_extensions(app): @login_manager.request_loader def load_user_from_request(request_from_flask_login): """Load user based on the request.""" - if request.blueprint not in ['console', 'inner_api']: + if request.blueprint not in ["console", "inner_api"]: return None # Check if the user_id contains a dot, indicating the old format - auth_header = request.headers.get('Authorization', '') + auth_header = request.headers.get("Authorization", "") if not auth_header: - auth_token = request.args.get('_token') + auth_token = request.args.get("_token") if not auth_token: - raise Unauthorized('Invalid Authorization token.') + raise Unauthorized("Invalid Authorization token.") else: - if ' ' not in auth_header: - raise Unauthorized('Invalid Authorization header format. Expected \'Bearer \' format.') + if " " not in auth_header: + raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") auth_scheme, auth_token = auth_header.split(None, 1) auth_scheme = auth_scheme.lower() - if auth_scheme != 'bearer': - raise Unauthorized('Invalid Authorization header format. Expected \'Bearer \' format.') + if auth_scheme != "bearer": + raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") decoded = PassportService().verify(auth_token) - user_id = decoded.get('user_id') + user_id = decoded.get("user_id") account = AccountService.load_logged_in_account(account_id=user_id, token=auth_token) if account: @@ -190,10 +192,11 @@ def load_user_from_request(request_from_flask_login): @login_manager.unauthorized_handler def unauthorized_handler(): """Handle unauthorized requests.""" - return Response(json.dumps({ - 'code': 'unauthorized', - 'message': "Unauthorized." - }), status=401, content_type="application/json") + return Response( + json.dumps({"code": "unauthorized", "message": "Unauthorized."}), + status=401, + content_type="application/json", + ) # register blueprint routers @@ -204,38 +207,36 @@ def register_blueprints(app): from controllers.service_api import bp as service_api_bp from controllers.web import bp as web_bp - CORS(service_api_bp, - allow_headers=['Content-Type', 'Authorization', 'X-App-Code'], - methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH'] - ) + CORS( + service_api_bp, + allow_headers=["Content-Type", "Authorization", "X-App-Code"], + methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], + ) app.register_blueprint(service_api_bp) - CORS(web_bp, - resources={ - r"/*": {"origins": app.config['WEB_API_CORS_ALLOW_ORIGINS']}}, - supports_credentials=True, - allow_headers=['Content-Type', 'Authorization', 'X-App-Code'], - methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH'], - expose_headers=['X-Version', 'X-Env'] - ) + CORS( + web_bp, + resources={r"/*": {"origins": app.config["WEB_API_CORS_ALLOW_ORIGINS"]}}, + supports_credentials=True, + allow_headers=["Content-Type", "Authorization", "X-App-Code"], + methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], + expose_headers=["X-Version", "X-Env"], + ) app.register_blueprint(web_bp) - CORS(console_app_bp, - resources={ - r"/*": {"origins": app.config['CONSOLE_CORS_ALLOW_ORIGINS']}}, - supports_credentials=True, - allow_headers=['Content-Type', 'Authorization'], - methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH'], - expose_headers=['X-Version', 'X-Env'] - ) + CORS( + console_app_bp, + resources={r"/*": {"origins": app.config["CONSOLE_CORS_ALLOW_ORIGINS"]}}, + supports_credentials=True, + allow_headers=["Content-Type", "Authorization"], + methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], + expose_headers=["X-Version", "X-Env"], + ) app.register_blueprint(console_app_bp) - CORS(files_bp, - allow_headers=['Content-Type'], - methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH'] - ) + CORS(files_bp, allow_headers=["Content-Type"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"]) app.register_blueprint(files_bp) app.register_blueprint(inner_api_bp) @@ -245,29 +246,29 @@ def register_blueprints(app): app = create_app() celery = app.extensions["celery"] -if app.config.get('TESTING'): +if app.config.get("TESTING"): print("App is running in TESTING mode") @app.after_request def after_request(response): """Add Version headers to the response.""" - response.set_cookie('remember_token', '', expires=0) - response.headers.add('X-Version', app.config['CURRENT_VERSION']) - response.headers.add('X-Env', app.config['DEPLOY_ENV']) + response.set_cookie("remember_token", "", expires=0) + response.headers.add("X-Version", app.config["CURRENT_VERSION"]) + response.headers.add("X-Env", app.config["DEPLOY_ENV"]) return response -@app.route('/health') +@app.route("/health") def health(): - return Response(json.dumps({ - 'pid': os.getpid(), - 'status': 'ok', - 'version': app.config['CURRENT_VERSION'] - }), status=200, content_type="application/json") + return Response( + json.dumps({"pid": os.getpid(), "status": "ok", "version": app.config["CURRENT_VERSION"]}), + status=200, + content_type="application/json", + ) -@app.route('/threads') +@app.route("/threads") def threads(): num_threads = threading.active_count() threads = threading.enumerate() @@ -278,32 +279,34 @@ def threads(): thread_id = thread.ident is_alive = thread.is_alive() - thread_list.append({ - 'name': thread_name, - 'id': thread_id, - 'is_alive': is_alive - }) + thread_list.append( + { + "name": thread_name, + "id": thread_id, + "is_alive": is_alive, + } + ) return { - 'pid': os.getpid(), - 'thread_num': num_threads, - 'threads': thread_list + "pid": os.getpid(), + "thread_num": num_threads, + "threads": thread_list, } -@app.route('/db-pool-stat') +@app.route("/db-pool-stat") def pool_stat(): engine = db.engine return { - 'pid': os.getpid(), - 'pool_size': engine.pool.size(), - 'checked_in_connections': engine.pool.checkedin(), - 'checked_out_connections': engine.pool.checkedout(), - 'overflow_connections': engine.pool.overflow(), - 'connection_timeout': engine.pool.timeout(), - 'recycle_time': db.engine.pool._recycle + "pid": os.getpid(), + "pool_size": engine.pool.size(), + "checked_in_connections": engine.pool.checkedin(), + "checked_out_connections": engine.pool.checkedout(), + "overflow_connections": engine.pool.overflow(), + "connection_timeout": engine.pool.timeout(), + "recycle_time": db.engine.pool._recycle, } -if __name__ == '__main__': - app.run(host='0.0.0.0', port=5001) +if __name__ == "__main__": + app.run(host="0.0.0.0", port=5001) diff --git a/api/commands.py b/api/commands.py index c7ffb47b512246..41f1a6444c4581 100644 --- a/api/commands.py +++ b/api/commands.py @@ -27,32 +27,29 @@ from services.account_service import RegisterService, TenantService -@click.command('reset-password', help='Reset the account password.') -@click.option('--email', prompt=True, help='The email address of the account whose password you need to reset') -@click.option('--new-password', prompt=True, help='the new password.') -@click.option('--password-confirm', prompt=True, help='the new password confirm.') +@click.command("reset-password", help="Reset the account password.") +@click.option("--email", prompt=True, help="The email address of the account whose password you need to reset") +@click.option("--new-password", prompt=True, help="the new password.") +@click.option("--password-confirm", prompt=True, help="the new password confirm.") def reset_password(email, new_password, password_confirm): """ Reset password of owner account Only available in SELF_HOSTED mode """ if str(new_password).strip() != str(password_confirm).strip(): - click.echo(click.style('sorry. The two passwords do not match.', fg='red')) + click.echo(click.style("sorry. The two passwords do not match.", fg="red")) return - account = db.session.query(Account). \ - filter(Account.email == email). \ - one_or_none() + account = db.session.query(Account).filter(Account.email == email).one_or_none() if not account: - click.echo(click.style('sorry. the account: [{}] not exist .'.format(email), fg='red')) + click.echo(click.style("sorry. the account: [{}] not exist .".format(email), fg="red")) return try: valid_password(new_password) except: - click.echo( - click.style('sorry. The passwords must match {} '.format(password_pattern), fg='red')) + click.echo(click.style("sorry. The passwords must match {} ".format(password_pattern), fg="red")) return # generate password salt @@ -65,80 +62,87 @@ def reset_password(email, new_password, password_confirm): account.password = base64_password_hashed account.password_salt = base64_salt db.session.commit() - click.echo(click.style('Congratulations! Password has been reset.', fg='green')) + click.echo(click.style("Congratulations! Password has been reset.", fg="green")) -@click.command('reset-email', help='Reset the account email.') -@click.option('--email', prompt=True, help='The old email address of the account whose email you need to reset') -@click.option('--new-email', prompt=True, help='the new email.') -@click.option('--email-confirm', prompt=True, help='the new email confirm.') +@click.command("reset-email", help="Reset the account email.") +@click.option("--email", prompt=True, help="The old email address of the account whose email you need to reset") +@click.option("--new-email", prompt=True, help="the new email.") +@click.option("--email-confirm", prompt=True, help="the new email confirm.") def reset_email(email, new_email, email_confirm): """ Replace account email :return: """ if str(new_email).strip() != str(email_confirm).strip(): - click.echo(click.style('Sorry, new email and confirm email do not match.', fg='red')) + click.echo(click.style("Sorry, new email and confirm email do not match.", fg="red")) return - account = db.session.query(Account). \ - filter(Account.email == email). \ - one_or_none() + account = db.session.query(Account).filter(Account.email == email).one_or_none() if not account: - click.echo(click.style('sorry. the account: [{}] not exist .'.format(email), fg='red')) + click.echo(click.style("sorry. the account: [{}] not exist .".format(email), fg="red")) return try: email_validate(new_email) except: - click.echo( - click.style('sorry. {} is not a valid email. '.format(email), fg='red')) + click.echo(click.style("sorry. {} is not a valid email. ".format(email), fg="red")) return account.email = new_email db.session.commit() - click.echo(click.style('Congratulations!, email has been reset.', fg='green')) - - -@click.command('reset-encrypt-key-pair', help='Reset the asymmetric key pair of workspace for encrypt LLM credentials. ' - 'After the reset, all LLM credentials will become invalid, ' - 'requiring re-entry.' - 'Only support SELF_HOSTED mode.') -@click.confirmation_option(prompt=click.style('Are you sure you want to reset encrypt key pair?' - ' this operation cannot be rolled back!', fg='red')) + click.echo(click.style("Congratulations!, email has been reset.", fg="green")) + + +@click.command( + "reset-encrypt-key-pair", + help="Reset the asymmetric key pair of workspace for encrypt LLM credentials. " + "After the reset, all LLM credentials will become invalid, " + "requiring re-entry." + "Only support SELF_HOSTED mode.", +) +@click.confirmation_option( + prompt=click.style( + "Are you sure you want to reset encrypt key pair?" " this operation cannot be rolled back!", fg="red" + ) +) def reset_encrypt_key_pair(): """ Reset the encrypted key pair of workspace for encrypt LLM credentials. After the reset, all LLM credentials will become invalid, requiring re-entry. Only support SELF_HOSTED mode. """ - if dify_config.EDITION != 'SELF_HOSTED': - click.echo(click.style('Sorry, only support SELF_HOSTED mode.', fg='red')) + if dify_config.EDITION != "SELF_HOSTED": + click.echo(click.style("Sorry, only support SELF_HOSTED mode.", fg="red")) return tenants = db.session.query(Tenant).all() for tenant in tenants: if not tenant: - click.echo(click.style('Sorry, no workspace found. Please enter /install to initialize.', fg='red')) + click.echo(click.style("Sorry, no workspace found. Please enter /install to initialize.", fg="red")) return tenant.encrypt_public_key = generate_key_pair(tenant.id) - db.session.query(Provider).filter(Provider.provider_type == 'custom', Provider.tenant_id == tenant.id).delete() + db.session.query(Provider).filter(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete() db.session.query(ProviderModel).filter(ProviderModel.tenant_id == tenant.id).delete() db.session.commit() - click.echo(click.style('Congratulations! ' - 'the asymmetric key pair of workspace {} has been reset.'.format(tenant.id), fg='green')) + click.echo( + click.style( + "Congratulations! " "the asymmetric key pair of workspace {} has been reset.".format(tenant.id), + fg="green", + ) + ) -@click.command('vdb-migrate', help='migrate vector db.') -@click.option('--scope', default='all', prompt=False, help='The scope of vector database to migrate, Default is All.') +@click.command("vdb-migrate", help="migrate vector db.") +@click.option("--scope", default="all", prompt=False, help="The scope of vector database to migrate, Default is All.") def vdb_migrate(scope: str): - if scope in ['knowledge', 'all']: + if scope in ["knowledge", "all"]: migrate_knowledge_vector_database() - if scope in ['annotation', 'all']: + if scope in ["annotation", "all"]: migrate_annotation_vector_database() @@ -146,7 +150,7 @@ def migrate_annotation_vector_database(): """ Migrate annotation datas to target vector database . """ - click.echo(click.style('Start migrate annotation data.', fg='green')) + click.echo(click.style("Start migrate annotation data.", fg="green")) create_count = 0 skipped_count = 0 total_count = 0 @@ -154,98 +158,103 @@ def migrate_annotation_vector_database(): while True: try: # get apps info - apps = db.session.query(App).filter( - App.status == 'normal' - ).order_by(App.created_at.desc()).paginate(page=page, per_page=50) + apps = ( + db.session.query(App) + .filter(App.status == "normal") + .order_by(App.created_at.desc()) + .paginate(page=page, per_page=50) + ) except NotFound: break page += 1 for app in apps: total_count = total_count + 1 - click.echo(f'Processing the {total_count} app {app.id}. ' - + f'{create_count} created, {skipped_count} skipped.') + click.echo( + f"Processing the {total_count} app {app.id}. " + f"{create_count} created, {skipped_count} skipped." + ) try: - click.echo('Create app annotation index: {}'.format(app.id)) - app_annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == app.id - ).first() + click.echo("Create app annotation index: {}".format(app.id)) + app_annotation_setting = ( + db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app.id).first() + ) if not app_annotation_setting: skipped_count = skipped_count + 1 - click.echo('App annotation setting is disabled: {}'.format(app.id)) + click.echo("App annotation setting is disabled: {}".format(app.id)) continue # get dataset_collection_binding info - dataset_collection_binding = db.session.query(DatasetCollectionBinding).filter( - DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id - ).first() + dataset_collection_binding = ( + db.session.query(DatasetCollectionBinding) + .filter(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id) + .first() + ) if not dataset_collection_binding: - click.echo('App annotation collection binding is not exist: {}'.format(app.id)) + click.echo("App annotation collection binding is not exist: {}".format(app.id)) continue annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app.id).all() dataset = Dataset( id=app.id, tenant_id=app.tenant_id, - indexing_technique='high_quality', + indexing_technique="high_quality", embedding_model_provider=dataset_collection_binding.provider_name, embedding_model=dataset_collection_binding.model_name, - collection_binding_id=dataset_collection_binding.id + collection_binding_id=dataset_collection_binding.id, ) documents = [] if annotations: for annotation in annotations: document = Document( page_content=annotation.question, - metadata={ - "annotation_id": annotation.id, - "app_id": app.id, - "doc_id": annotation.id - } + metadata={"annotation_id": annotation.id, "app_id": app.id, "doc_id": annotation.id}, ) documents.append(document) - vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id']) + vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) click.echo(f"Start to migrate annotation, app_id: {app.id}.") try: vector.delete() - click.echo( - click.style(f'Successfully delete vector index for app: {app.id}.', - fg='green')) + click.echo(click.style(f"Successfully delete vector index for app: {app.id}.", fg="green")) except Exception as e: - click.echo( - click.style(f'Failed to delete vector index for app {app.id}.', - fg='red')) + click.echo(click.style(f"Failed to delete vector index for app {app.id}.", fg="red")) raise e if documents: try: - click.echo(click.style( - f'Start to created vector index with {len(documents)} annotations for app {app.id}.', - fg='green')) - vector.create(documents) click.echo( - click.style(f'Successfully created vector index for app {app.id}.', fg='green')) + click.style( + f"Start to created vector index with {len(documents)} annotations for app {app.id}.", + fg="green", + ) + ) + vector.create(documents) + click.echo(click.style(f"Successfully created vector index for app {app.id}.", fg="green")) except Exception as e: - click.echo(click.style(f'Failed to created vector index for app {app.id}.', fg='red')) + click.echo(click.style(f"Failed to created vector index for app {app.id}.", fg="red")) raise e - click.echo(f'Successfully migrated app annotation {app.id}.') + click.echo(f"Successfully migrated app annotation {app.id}.") create_count += 1 except Exception as e: click.echo( - click.style('Create app annotation index error: {} {}'.format(e.__class__.__name__, str(e)), - fg='red')) + click.style( + "Create app annotation index error: {} {}".format(e.__class__.__name__, str(e)), fg="red" + ) + ) continue click.echo( - click.style(f'Congratulations! Create {create_count} app annotation indexes, and skipped {skipped_count} apps.', - fg='green')) + click.style( + f"Congratulations! Create {create_count} app annotation indexes, and skipped {skipped_count} apps.", + fg="green", + ) + ) def migrate_knowledge_vector_database(): """ Migrate vector database datas to target vector database . """ - click.echo(click.style('Start migrate vector db.', fg='green')) + click.echo(click.style("Start migrate vector db.", fg="green")) create_count = 0 skipped_count = 0 total_count = 0 @@ -253,87 +262,77 @@ def migrate_knowledge_vector_database(): page = 1 while True: try: - datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \ - .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50) + datasets = ( + db.session.query(Dataset) + .filter(Dataset.indexing_technique == "high_quality") + .order_by(Dataset.created_at.desc()) + .paginate(page=page, per_page=50) + ) except NotFound: break page += 1 for dataset in datasets: total_count = total_count + 1 - click.echo(f'Processing the {total_count} dataset {dataset.id}. ' - + f'{create_count} created, {skipped_count} skipped.') + click.echo( + f"Processing the {total_count} dataset {dataset.id}. " + + f"{create_count} created, {skipped_count} skipped." + ) try: - click.echo('Create dataset vdb index: {}'.format(dataset.id)) + click.echo("Create dataset vdb index: {}".format(dataset.id)) if dataset.index_struct_dict: - if dataset.index_struct_dict['type'] == vector_type: + if dataset.index_struct_dict["type"] == vector_type: skipped_count = skipped_count + 1 continue - collection_name = '' + collection_name = "" if vector_type == VectorType.WEAVIATE: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = { - "type": VectorType.WEAVIATE, - "vector_store": {"class_prefix": collection_name} - } + index_struct_dict = {"type": VectorType.WEAVIATE, "vector_store": {"class_prefix": collection_name}} dataset.index_struct = json.dumps(index_struct_dict) elif vector_type == VectorType.QDRANT: if dataset.collection_binding_id: - dataset_collection_binding = db.session.query(DatasetCollectionBinding). \ - filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \ - one_or_none() + dataset_collection_binding = ( + db.session.query(DatasetCollectionBinding) + .filter(DatasetCollectionBinding.id == dataset.collection_binding_id) + .one_or_none() + ) if dataset_collection_binding: collection_name = dataset_collection_binding.collection_name else: - raise ValueError('Dataset Collection Bindings is not exist!') + raise ValueError("Dataset Collection Bindings is not exist!") else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = { - "type": VectorType.QDRANT, - "vector_store": {"class_prefix": collection_name} - } + index_struct_dict = {"type": VectorType.QDRANT, "vector_store": {"class_prefix": collection_name}} dataset.index_struct = json.dumps(index_struct_dict) elif vector_type == VectorType.MILVUS: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = { - "type": VectorType.MILVUS, - "vector_store": {"class_prefix": collection_name} - } + index_struct_dict = {"type": VectorType.MILVUS, "vector_store": {"class_prefix": collection_name}} dataset.index_struct = json.dumps(index_struct_dict) elif vector_type == VectorType.RELYT: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = { - "type": 'relyt', - "vector_store": {"class_prefix": collection_name} - } + index_struct_dict = {"type": "relyt", "vector_store": {"class_prefix": collection_name}} dataset.index_struct = json.dumps(index_struct_dict) elif vector_type == VectorType.TENCENT: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = { - "type": VectorType.TENCENT, - "vector_store": {"class_prefix": collection_name} - } + index_struct_dict = {"type": VectorType.TENCENT, "vector_store": {"class_prefix": collection_name}} dataset.index_struct = json.dumps(index_struct_dict) elif vector_type == VectorType.PGVECTOR: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = { - "type": VectorType.PGVECTOR, - "vector_store": {"class_prefix": collection_name} - } + index_struct_dict = {"type": VectorType.PGVECTOR, "vector_store": {"class_prefix": collection_name}} dataset.index_struct = json.dumps(index_struct_dict) elif vector_type == VectorType.OPENSEARCH: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) index_struct_dict = { "type": VectorType.OPENSEARCH, - "vector_store": {"class_prefix": collection_name} + "vector_store": {"class_prefix": collection_name}, } dataset.index_struct = json.dumps(index_struct_dict) elif vector_type == VectorType.ANALYTICDB: @@ -341,9 +340,14 @@ def migrate_knowledge_vector_database(): collection_name = Dataset.gen_collection_name_by_id(dataset_id) index_struct_dict = { "type": VectorType.ANALYTICDB, - "vector_store": {"class_prefix": collection_name} + "vector_store": {"class_prefix": collection_name}, } dataset.index_struct = json.dumps(index_struct_dict) + elif vector_type == VectorType.ELASTICSEARCH: + dataset_id = dataset.id + index_name = Dataset.gen_collection_name_by_id(dataset_id) + index_struct_dict = {"type": "elasticsearch", "vector_store": {"class_prefix": index_name}} + dataset.index_struct = json.dumps(index_struct_dict) else: raise ValueError(f"Vector store {vector_type} is not supported.") @@ -353,29 +357,41 @@ def migrate_knowledge_vector_database(): try: vector.delete() click.echo( - click.style(f'Successfully delete vector index {collection_name} for dataset {dataset.id}.', - fg='green')) + click.style( + f"Successfully delete vector index {collection_name} for dataset {dataset.id}.", fg="green" + ) + ) except Exception as e: click.echo( - click.style(f'Failed to delete vector index {collection_name} for dataset {dataset.id}.', - fg='red')) + click.style( + f"Failed to delete vector index {collection_name} for dataset {dataset.id}.", fg="red" + ) + ) raise e - dataset_documents = db.session.query(DatasetDocument).filter( - DatasetDocument.dataset_id == dataset.id, - DatasetDocument.indexing_status == 'completed', - DatasetDocument.enabled == True, - DatasetDocument.archived == False, - ).all() + dataset_documents = ( + db.session.query(DatasetDocument) + .filter( + DatasetDocument.dataset_id == dataset.id, + DatasetDocument.indexing_status == "completed", + DatasetDocument.enabled == True, + DatasetDocument.archived == False, + ) + .all() + ) documents = [] segments_count = 0 for dataset_document in dataset_documents: - segments = db.session.query(DocumentSegment).filter( - DocumentSegment.document_id == dataset_document.id, - DocumentSegment.status == 'completed', - DocumentSegment.enabled == True - ).all() + segments = ( + db.session.query(DocumentSegment) + .filter( + DocumentSegment.document_id == dataset_document.id, + DocumentSegment.status == "completed", + DocumentSegment.enabled == True, + ) + .all() + ) for segment in segments: document = Document( @@ -385,7 +401,7 @@ def migrate_knowledge_vector_database(): "doc_hash": segment.index_node_hash, "document_id": segment.document_id, "dataset_id": segment.dataset_id, - } + }, ) documents.append(document) @@ -393,37 +409,43 @@ def migrate_knowledge_vector_database(): if documents: try: - click.echo(click.style( - f'Start to created vector index with {len(documents)} documents of {segments_count} segments for dataset {dataset.id}.', - fg='green')) + click.echo( + click.style( + f"Start to created vector index with {len(documents)} documents of {segments_count} segments for dataset {dataset.id}.", + fg="green", + ) + ) vector.create(documents) click.echo( - click.style(f'Successfully created vector index for dataset {dataset.id}.', fg='green')) + click.style(f"Successfully created vector index for dataset {dataset.id}.", fg="green") + ) except Exception as e: - click.echo(click.style(f'Failed to created vector index for dataset {dataset.id}.', fg='red')) + click.echo(click.style(f"Failed to created vector index for dataset {dataset.id}.", fg="red")) raise e db.session.add(dataset) db.session.commit() - click.echo(f'Successfully migrated dataset {dataset.id}.') + click.echo(f"Successfully migrated dataset {dataset.id}.") create_count += 1 except Exception as e: db.session.rollback() click.echo( - click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)), - fg='red')) + click.style("Create dataset index error: {} {}".format(e.__class__.__name__, str(e)), fg="red") + ) continue click.echo( - click.style(f'Congratulations! Create {create_count} dataset indexes, and skipped {skipped_count} datasets.', - fg='green')) + click.style( + f"Congratulations! Create {create_count} dataset indexes, and skipped {skipped_count} datasets.", fg="green" + ) + ) -@click.command('convert-to-agent-apps', help='Convert Agent Assistant to Agent App.') +@click.command("convert-to-agent-apps", help="Convert Agent Assistant to Agent App.") def convert_to_agent_apps(): """ Convert Agent Assistant to Agent App. """ - click.echo(click.style('Start convert to agent apps.', fg='green')) + click.echo(click.style("Start convert to agent apps.", fg="green")) proceeded_app_ids = [] @@ -458,7 +480,7 @@ def convert_to_agent_apps(): break for app in apps: - click.echo('Converting app: {}'.format(app.id)) + click.echo("Converting app: {}".format(app.id)) try: app.mode = AppMode.AGENT_CHAT.value @@ -470,137 +492,139 @@ def convert_to_agent_apps(): ) db.session.commit() - click.echo(click.style('Converted app: {}'.format(app.id), fg='green')) + click.echo(click.style("Converted app: {}".format(app.id), fg="green")) except Exception as e: - click.echo( - click.style('Convert app error: {} {}'.format(e.__class__.__name__, - str(e)), fg='red')) + click.echo(click.style("Convert app error: {} {}".format(e.__class__.__name__, str(e)), fg="red")) - click.echo(click.style('Congratulations! Converted {} agent apps.'.format(len(proceeded_app_ids)), fg='green')) + click.echo(click.style("Congratulations! Converted {} agent apps.".format(len(proceeded_app_ids)), fg="green")) -@click.command('add-qdrant-doc-id-index', help='add qdrant doc_id index.') -@click.option('--field', default='metadata.doc_id', prompt=False, help='index field , default is metadata.doc_id.') +@click.command("add-qdrant-doc-id-index", help="add qdrant doc_id index.") +@click.option("--field", default="metadata.doc_id", prompt=False, help="index field , default is metadata.doc_id.") def add_qdrant_doc_id_index(field: str): - click.echo(click.style('Start add qdrant doc_id index.', fg='green')) + click.echo(click.style("Start add qdrant doc_id index.", fg="green")) vector_type = dify_config.VECTOR_STORE if vector_type != "qdrant": - click.echo(click.style('Sorry, only support qdrant vector store.', fg='red')) + click.echo(click.style("Sorry, only support qdrant vector store.", fg="red")) return create_count = 0 try: bindings = db.session.query(DatasetCollectionBinding).all() if not bindings: - click.echo(click.style('Sorry, no dataset collection bindings found.', fg='red')) + click.echo(click.style("Sorry, no dataset collection bindings found.", fg="red")) return import qdrant_client from qdrant_client.http.exceptions import UnexpectedResponse from qdrant_client.http.models import PayloadSchemaType from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig + for binding in bindings: if dify_config.QDRANT_URL is None: - raise ValueError('Qdrant url is required.') + raise ValueError("Qdrant url is required.") qdrant_config = QdrantConfig( endpoint=dify_config.QDRANT_URL, api_key=dify_config.QDRANT_API_KEY, root_path=current_app.root_path, timeout=dify_config.QDRANT_CLIENT_TIMEOUT, grpc_port=dify_config.QDRANT_GRPC_PORT, - prefer_grpc=dify_config.QDRANT_GRPC_ENABLED + prefer_grpc=dify_config.QDRANT_GRPC_ENABLED, ) try: client = qdrant_client.QdrantClient(**qdrant_config.to_qdrant_params()) # create payload index - client.create_payload_index(binding.collection_name, field, - field_schema=PayloadSchemaType.KEYWORD) + client.create_payload_index(binding.collection_name, field, field_schema=PayloadSchemaType.KEYWORD) create_count += 1 except UnexpectedResponse as e: # Collection does not exist, so return if e.status_code == 404: - click.echo(click.style(f'Collection not found, collection_name:{binding.collection_name}.', fg='red')) + click.echo( + click.style(f"Collection not found, collection_name:{binding.collection_name}.", fg="red") + ) continue # Some other error occurred, so re-raise the exception else: - click.echo(click.style(f'Failed to create qdrant index, collection_name:{binding.collection_name}.', fg='red')) + click.echo( + click.style( + f"Failed to create qdrant index, collection_name:{binding.collection_name}.", fg="red" + ) + ) except Exception as e: - click.echo(click.style('Failed to create qdrant client.', fg='red')) + click.echo(click.style("Failed to create qdrant client.", fg="red")) - click.echo( - click.style(f'Congratulations! Create {create_count} collection indexes.', - fg='green')) + click.echo(click.style(f"Congratulations! Create {create_count} collection indexes.", fg="green")) -@click.command('create-tenant', help='Create account and tenant.') -@click.option('--email', prompt=True, help='The email address of the tenant account.') -@click.option('--language', prompt=True, help='Account language, default: en-US.') +@click.command("create-tenant", help="Create account and tenant.") +@click.option("--email", prompt=True, help="The email address of the tenant account.") +@click.option("--language", prompt=True, help="Account language, default: en-US.") def create_tenant(email: str, language: Optional[str] = None): """ Create tenant account """ if not email: - click.echo(click.style('Sorry, email is required.', fg='red')) + click.echo(click.style("Sorry, email is required.", fg="red")) return # Create account email = email.strip() - if '@' not in email: - click.echo(click.style('Sorry, invalid email address.', fg='red')) + if "@" not in email: + click.echo(click.style("Sorry, invalid email address.", fg="red")) return - account_name = email.split('@')[0] + account_name = email.split("@")[0] if language not in languages: - language = 'en-US' + language = "en-US" # generate random password new_password = secrets.token_urlsafe(16) # register account - account = RegisterService.register( - email=email, - name=account_name, - password=new_password, - language=language - ) + account = RegisterService.register(email=email, name=account_name, password=new_password, language=language) TenantService.create_owner_tenant_if_not_exist(account) - click.echo(click.style('Congratulations! Account and tenant created.\n' - 'Account: {}\nPassword: {}'.format(email, new_password), fg='green')) + click.echo( + click.style( + "Congratulations! Account and tenant created.\n" "Account: {}\nPassword: {}".format(email, new_password), + fg="green", + ) + ) -@click.command('upgrade-db', help='upgrade the database') +@click.command("upgrade-db", help="upgrade the database") def upgrade_db(): - click.echo('Preparing database migration...') - lock = redis_client.lock(name='db_upgrade_lock', timeout=60) + click.echo("Preparing database migration...") + lock = redis_client.lock(name="db_upgrade_lock", timeout=60) if lock.acquire(blocking=False): try: - click.echo(click.style('Start database migration.', fg='green')) + click.echo(click.style("Start database migration.", fg="green")) # run db migration import flask_migrate + flask_migrate.upgrade() - click.echo(click.style('Database migration successful!', fg='green')) + click.echo(click.style("Database migration successful!", fg="green")) except Exception as e: - logging.exception(f'Database migration failed, error: {e}') + logging.exception(f"Database migration failed, error: {e}") finally: lock.release() else: - click.echo('Database migration skipped') + click.echo("Database migration skipped") -@click.command('fix-app-site-missing', help='Fix app related site missing issue.') +@click.command("fix-app-site-missing", help="Fix app related site missing issue.") def fix_app_site_missing(): """ Fix app related site missing issue. """ - click.echo(click.style('Start fix app related site missing issue.', fg='green')) + click.echo(click.style("Start fix app related site missing issue.", fg="green")) failed_app_ids = [] while True: @@ -631,15 +655,14 @@ def fix_app_site_missing(): app_was_created.send(app, account=account) except Exception as e: failed_app_ids.append(app_id) - click.echo(click.style('Fix app {} related site missing issue failed!'.format(app_id), fg='red')) - logging.exception(f'Fix app related site missing issue failed, error: {e}') + click.echo(click.style("Fix app {} related site missing issue failed!".format(app_id), fg="red")) + logging.exception(f"Fix app related site missing issue failed, error: {e}") continue if not processed_count: break - - click.echo(click.style('Congratulations! Fix app related site missing issue successful!', fg='green')) + click.echo(click.style("Congratulations! Fix app related site missing issue successful!", fg="green")) def register_commands(app): diff --git a/api/configs/app_config.py b/api/configs/app_config.py index a5a4fc788d0d19..b277760edd7b2c 100644 --- a/api/configs/app_config.py +++ b/api/configs/app_config.py @@ -12,19 +12,14 @@ class DifyConfig( # Packaging info PackagingInfo, - # Deployment configs DeploymentConfig, - # Feature configs FeatureConfig, - # Middleware configs MiddlewareConfig, - # Extra service configs ExtraServiceConfig, - # Enterprise feature configs # **Before using, please contact business@dify.ai by email to inquire about licensing matters.** EnterpriseFeatureConfig, @@ -36,7 +31,6 @@ class DifyConfig( env_file='.env', env_file_encoding='utf-8', frozen=True, - # ignore extra attributes extra='ignore', ) @@ -67,3 +61,5 @@ def HTTP_REQUEST_NODE_READABLE_MAX_TEXT_SIZE(self) -> str: SSRF_PROXY_HTTPS_URL: str | None = None MODERATION_BUFFER_SIZE: int = Field(default=300, description='The buffer size for moderation.') + + MAX_VARIABLE_SIZE: int = Field(default=5 * 1024, description='The maximum size of a variable. default is 5KB.') diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 369b25d788a440..ce59a281bcb4c1 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -406,6 +406,7 @@ class DataSetConfig(BaseSettings): default=False, ) + class WorkspaceConfig(BaseSettings): """ Workspace configs @@ -442,6 +443,63 @@ class CeleryBeatConfig(BaseSettings): ) +class PositionConfig(BaseSettings): + + POSITION_PROVIDER_PINS: str = Field( + description='The heads of model providers', + default='', + ) + + POSITION_PROVIDER_INCLUDES: str = Field( + description='The included model providers', + default='', + ) + + POSITION_PROVIDER_EXCLUDES: str = Field( + description='The excluded model providers', + default='', + ) + + POSITION_TOOL_PINS: str = Field( + description='The heads of tools', + default='', + ) + + POSITION_TOOL_INCLUDES: str = Field( + description='The included tools', + default='', + ) + + POSITION_TOOL_EXCLUDES: str = Field( + description='The excluded tools', + default='', + ) + + @computed_field + def POSITION_PROVIDER_PINS_LIST(self) -> list[str]: + return [item.strip() for item in self.POSITION_PROVIDER_PINS.split(',') if item.strip() != ''] + + @computed_field + def POSITION_PROVIDER_INCLUDES_LIST(self) -> list[str]: + return [item.strip() for item in self.POSITION_PROVIDER_INCLUDES.split(',') if item.strip() != ''] + + @computed_field + def POSITION_PROVIDER_EXCLUDES_LIST(self) -> list[str]: + return [item.strip() for item in self.POSITION_PROVIDER_EXCLUDES.split(',') if item.strip() != ''] + + @computed_field + def POSITION_TOOL_PINS_LIST(self) -> list[str]: + return [item.strip() for item in self.POSITION_TOOL_PINS.split(',') if item.strip() != ''] + + @computed_field + def POSITION_TOOL_INCLUDES_LIST(self) -> list[str]: + return [item.strip() for item in self.POSITION_TOOL_INCLUDES.split(',') if item.strip() != ''] + + @computed_field + def POSITION_TOOL_EXCLUDES_LIST(self) -> list[str]: + return [item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(',') if item.strip() != ''] + + class FeatureConfig( # place the configs in alphabet order AppExecutionConfig, @@ -466,6 +524,7 @@ class FeatureConfig( UpdateConfig, WorkflowConfig, WorkspaceConfig, + PositionConfig, # hosted services config HostedServiceConfig, diff --git a/api/configs/packaging/__init__.py b/api/configs/packaging/__init__.py index 1104e298b1e82c..247fcde655a180 100644 --- a/api/configs/packaging/__init__.py +++ b/api/configs/packaging/__init__.py @@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings): CURRENT_VERSION: str = Field( description='Dify version', - default='0.6.16', + default='0.7.0', ) COMMIT_SHA: str = Field( diff --git a/api/constants/__init__.py b/api/constants/__init__.py index e374c04316b274..e22c3268ef428b 100644 --- a/api/constants/__init__.py +++ b/api/constants/__init__.py @@ -1 +1 @@ -HIDDEN_VALUE = '[__HIDDEN__]' +HIDDEN_VALUE = "[__HIDDEN__]" diff --git a/api/constants/languages.py b/api/constants/languages.py index 38e49e0d1e2caf..524dc61b5790a4 100644 --- a/api/constants/languages.py +++ b/api/constants/languages.py @@ -1,22 +1,22 @@ language_timezone_mapping = { - 'en-US': 'America/New_York', - 'zh-Hans': 'Asia/Shanghai', - 'zh-Hant': 'Asia/Taipei', - 'pt-BR': 'America/Sao_Paulo', - 'es-ES': 'Europe/Madrid', - 'fr-FR': 'Europe/Paris', - 'de-DE': 'Europe/Berlin', - 'ja-JP': 'Asia/Tokyo', - 'ko-KR': 'Asia/Seoul', - 'ru-RU': 'Europe/Moscow', - 'it-IT': 'Europe/Rome', - 'uk-UA': 'Europe/Kyiv', - 'vi-VN': 'Asia/Ho_Chi_Minh', - 'ro-RO': 'Europe/Bucharest', - 'pl-PL': 'Europe/Warsaw', - 'hi-IN': 'Asia/Kolkata', - 'tr-TR': 'Europe/Istanbul', - 'fa-IR': 'Asia/Tehran', + "en-US": "America/New_York", + "zh-Hans": "Asia/Shanghai", + "zh-Hant": "Asia/Taipei", + "pt-BR": "America/Sao_Paulo", + "es-ES": "Europe/Madrid", + "fr-FR": "Europe/Paris", + "de-DE": "Europe/Berlin", + "ja-JP": "Asia/Tokyo", + "ko-KR": "Asia/Seoul", + "ru-RU": "Europe/Moscow", + "it-IT": "Europe/Rome", + "uk-UA": "Europe/Kyiv", + "vi-VN": "Asia/Ho_Chi_Minh", + "ro-RO": "Europe/Bucharest", + "pl-PL": "Europe/Warsaw", + "hi-IN": "Asia/Kolkata", + "tr-TR": "Europe/Istanbul", + "fa-IR": "Asia/Tehran", } languages = list(language_timezone_mapping.keys()) @@ -26,6 +26,5 @@ def supported_language(lang): if lang in languages: return lang - error = ('{lang} is not a valid language.' - .format(lang=lang)) + error = "{lang} is not a valid language.".format(lang=lang) raise ValueError(error) diff --git a/api/constants/model_template.py b/api/constants/model_template.py index cc5a37025479fd..7e1a196356c4e2 100644 --- a/api/constants/model_template.py +++ b/api/constants/model_template.py @@ -5,82 +5,79 @@ default_app_templates = { # workflow default mode AppMode.WORKFLOW: { - 'app': { - 'mode': AppMode.WORKFLOW.value, - 'enable_site': True, - 'enable_api': True + "app": { + "mode": AppMode.WORKFLOW.value, + "enable_site": True, + "enable_api": True, } }, - # completion default mode AppMode.COMPLETION: { - 'app': { - 'mode': AppMode.COMPLETION.value, - 'enable_site': True, - 'enable_api': True + "app": { + "mode": AppMode.COMPLETION.value, + "enable_site": True, + "enable_api": True, }, - 'model_config': { - 'model': { + "model_config": { + "model": { "provider": "openai", "name": "gpt-4o", "mode": "chat", - "completion_params": {} + "completion_params": {}, }, - 'user_input_form': json.dumps([ - { - "paragraph": { - "label": "Query", - "variable": "query", - "required": True, - "default": "" - } - } - ]), - 'pre_prompt': '{{query}}' + "user_input_form": json.dumps( + [ + { + "paragraph": { + "label": "Query", + "variable": "query", + "required": True, + "default": "", + }, + }, + ] + ), + "pre_prompt": "{{query}}", }, - }, - # chat default mode AppMode.CHAT: { - 'app': { - 'mode': AppMode.CHAT.value, - 'enable_site': True, - 'enable_api': True + "app": { + "mode": AppMode.CHAT.value, + "enable_site": True, + "enable_api": True, }, - 'model_config': { - 'model': { + "model_config": { + "model": { "provider": "openai", "name": "gpt-4o", "mode": "chat", - "completion_params": {} - } - } + "completion_params": {}, + }, + }, }, - # advanced-chat default mode AppMode.ADVANCED_CHAT: { - 'app': { - 'mode': AppMode.ADVANCED_CHAT.value, - 'enable_site': True, - 'enable_api': True - } + "app": { + "mode": AppMode.ADVANCED_CHAT.value, + "enable_site": True, + "enable_api": True, + }, }, - # agent-chat default mode AppMode.AGENT_CHAT: { - 'app': { - 'mode': AppMode.AGENT_CHAT.value, - 'enable_site': True, - 'enable_api': True + "app": { + "mode": AppMode.AGENT_CHAT.value, + "enable_site": True, + "enable_api": True, }, - 'model_config': { - 'model': { + "model_config": { + "model": { "provider": "openai", "name": "gpt-4o", "mode": "chat", - "completion_params": {} - } - } - } + "completion_params": {}, + }, + }, + }, } diff --git a/api/contexts/__init__.py b/api/contexts/__init__.py index 306fac3a931298..623a1a28eb731e 100644 --- a/api/contexts/__init__.py +++ b/api/contexts/__init__.py @@ -1,3 +1,7 @@ from contextvars import ContextVar -tenant_id: ContextVar[str] = ContextVar('tenant_id') \ No newline at end of file +from core.workflow.entities.variable_pool import VariablePool + +tenant_id: ContextVar[str] = ContextVar("tenant_id") + +workflow_variable_pool: ContextVar[VariablePool] = ContextVar("workflow_variable_pool") diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index bef40bea7eb32e..b2b9d8d4967927 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -17,6 +17,7 @@ audio, completion, conversation, + conversation_variables, generator, message, model_config, diff --git a/api/controllers/console/app/conversation_variables.py b/api/controllers/console/app/conversation_variables.py new file mode 100644 index 00000000000000..aa0722ea355ca2 --- /dev/null +++ b/api/controllers/console/app/conversation_variables.py @@ -0,0 +1,61 @@ +from flask_restful import Resource, marshal_with, reqparse +from sqlalchemy import select +from sqlalchemy.orm import Session + +from controllers.console import api +from controllers.console.app.wraps import get_app_model +from controllers.console.setup import setup_required +from controllers.console.wraps import account_initialization_required +from extensions.ext_database import db +from fields.conversation_variable_fields import paginated_conversation_variable_fields +from libs.login import login_required +from models import ConversationVariable +from models.model import AppMode + + +class ConversationVariablesApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=AppMode.ADVANCED_CHAT) + @marshal_with(paginated_conversation_variable_fields) + def get(self, app_model): + parser = reqparse.RequestParser() + parser.add_argument('conversation_id', type=str, location='args') + args = parser.parse_args() + + stmt = ( + select(ConversationVariable) + .where(ConversationVariable.app_id == app_model.id) + .order_by(ConversationVariable.created_at) + ) + if args['conversation_id']: + stmt = stmt.where(ConversationVariable.conversation_id == args['conversation_id']) + else: + raise ValueError('conversation_id is required') + + # NOTE: This is a temporary solution to avoid performance issues. + page = 1 + page_size = 100 + stmt = stmt.limit(page_size).offset((page - 1) * page_size) + + with Session(db.engine) as session: + rows = session.scalars(stmt).all() + + return { + 'page': page, + 'limit': page_size, + 'total': len(rows), + 'has_more': False, + 'data': [ + { + 'created_at': row.created_at, + 'updated_at': row.updated_at, + **row.to_variable().model_dump(), + } + for row in rows + ], + } + + +api.add_resource(ConversationVariablesApi, '/apps//conversation-variables') diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 8889d6c6002057..a2052b9764e9ff 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -74,6 +74,7 @@ def post(self, app_model: App): parser.add_argument('hash', type=str, required=False, location='json') # TODO: set this to required=True after frontend is updated parser.add_argument('environment_variables', type=list, required=False, location='json') + parser.add_argument('conversation_variables', type=list, required=False, location='json') args = parser.parse_args() elif 'text/plain' in content_type: try: @@ -88,7 +89,8 @@ def post(self, app_model: App): 'graph': data.get('graph'), 'features': data.get('features'), 'hash': data.get('hash'), - 'environment_variables': data.get('environment_variables') + 'environment_variables': data.get('environment_variables'), + 'conversation_variables': data.get('conversation_variables'), } except json.JSONDecodeError: return {'message': 'Invalid JSON data'}, 400 @@ -100,6 +102,8 @@ def post(self, app_model: App): try: environment_variables_list = args.get('environment_variables') or [] environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list] + conversation_variables_list = args.get('conversation_variables') or [] + conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list] workflow = workflow_service.sync_draft_workflow( app_model=app_model, graph=args['graph'], @@ -107,6 +111,7 @@ def post(self, app_model: App): unique_hash=args.get('hash'), account=current_user, environment_variables=environment_variables, + conversation_variables=conversation_variables, ) except WorkflowHashNotEqualError: raise DraftWorkflowNotSync() diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 3e9884328029ce..a5bc2dd86a905d 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -555,7 +555,7 @@ def get(self): RetrievalMethod.SEMANTIC_SEARCH.value ] } - case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE: + case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH: return { 'retrieval_method': [ RetrievalMethod.SEMANTIC_SEARCH.value, @@ -579,7 +579,7 @@ def get(self, vector_type): RetrievalMethod.SEMANTIC_SEARCH.value ] } - case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH| VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE: + case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH| VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH: return { 'retrieval_method': [ RetrievalMethod.SEMANTIC_SEARCH.value, diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index afe0ca7c69b2b7..976b97660ae293 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -178,11 +178,20 @@ def get(self, dataset_id): .subquery() query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id) \ - .order_by(sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0))) + .order_by( + sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0)), + sort_logic(Document.position), + ) elif sort == 'created_at': - query = query.order_by(sort_logic(Document.created_at)) + query = query.order_by( + sort_logic(Document.created_at), + sort_logic(Document.position), + ) else: - query = query.order_by(desc(Document.created_at)) + query = query.order_by( + desc(Document.created_at), + desc(Document.position), + ) paginated_documents = query.paginate( page=page, per_page=limit, max_per_page=100, error_out=False) diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index c8b44cfa38114c..875870e667c8d9 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -131,7 +131,7 @@ def get(self, app_model: App, end_user: EndUser, message_id): except services.errors.message.MessageNotExistsError: raise NotFound("Message Not Exists.") except SuggestedQuestionsAfterAnswerDisabledError: - raise BadRequest("Message Not Exists.") + raise BadRequest("Suggested Questions Is Disabled.") except Exception: logging.exception("internal server error.") raise InternalServerError() diff --git a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py index ec17db5f06a30c..f4e6675bd44435 100644 --- a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py @@ -93,6 +93,7 @@ def convert(cls, config: dict) -> Optional[DatasetEntity]: reranking_model=dataset_configs.get('reranking_model'), weights=dataset_configs.get('weights'), reranking_enabled=dataset_configs.get('reranking_enabled', True), + rerank_mode=dataset_configs["reranking_mode"], ) ) diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index a490ddd67089f4..05a42a898e4af7 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -3,8 +3,9 @@ from pydantic import BaseModel +from core.file.file_obj import FileExtraConfig from core.model_runtime.entities.message_entities import PromptMessageRole -from models.model import AppMode +from models import AppMode class ModelConfigEntity(BaseModel): @@ -200,11 +201,6 @@ class TracingConfigEntity(BaseModel): tracing_provider: str -class FileExtraConfig(BaseModel): - """ - File Upload Entity. - """ - image_config: Optional[dict[str, Any]] = None class AppAdditionalFeatures(BaseModel): diff --git a/api/core/app/app_config/features/file_upload/manager.py b/api/core/app/app_config/features/file_upload/manager.py index 86799fb1abe133..3da3c2eddb83f3 100644 --- a/api/core/app/app_config/features/file_upload/manager.py +++ b/api/core/app/app_config/features/file_upload/manager.py @@ -1,7 +1,7 @@ from collections.abc import Mapping from typing import Any, Optional -from core.app.app_config.entities import FileExtraConfig +from core.file.file_obj import FileExtraConfig class FileUploadConfigManager: diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index e854ea18b099b8..351eb05d8ad41c 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -8,6 +8,8 @@ from flask import Flask, current_app from pydantic import ValidationError +from sqlalchemy import select +from sqlalchemy.orm import Session import contexts from core.app.app_config.features.file_upload.manager import FileUploadConfigManager @@ -18,15 +20,20 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager -from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom +from core.app.entities.app_invoke_entities import ( + AdvancedChatAppGenerateEntity, + InvokeFrom, +) from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse from core.file.message_file_parser import MessageFileParser from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.ops.ops_trace_manager import TraceQueueManager +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariable from extensions.ext_database import db from models.account import Account from models.model import App, Conversation, EndUser, Message -from models.workflow import Workflow +from models.workflow import ConversationVariable, Workflow logger = logging.getLogger(__name__) @@ -113,7 +120,6 @@ def generate( contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) return self._generate( - app_model=app_model, workflow=workflow, user=user, invoke_from=invoke_from, @@ -121,7 +127,7 @@ def generate( conversation=conversation, stream=stream ) - + def single_iteration_generate(self, app_model: App, workflow: Workflow, node_id: str, @@ -141,10 +147,10 @@ def single_iteration_generate(self, app_model: App, """ if not node_id: raise ValueError('node_id is required') - + if args.get('inputs') is None: raise ValueError('inputs is required') - + extras = { "auto_generate_conversation_name": False } @@ -180,7 +186,6 @@ def single_iteration_generate(self, app_model: App, contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) return self._generate( - app_model=app_model, workflow=workflow, user=user, invoke_from=InvokeFrom.DEBUGGER, @@ -189,12 +194,12 @@ def single_iteration_generate(self, app_model: App, stream=stream ) - def _generate(self, app_model: App, + def _generate(self, *, workflow: Workflow, user: Union[Account, EndUser], invoke_from: InvokeFrom, application_generate_entity: AdvancedChatAppGenerateEntity, - conversation: Conversation = None, + conversation: Conversation | None = None, stream: bool = True) \ -> Union[dict, Generator[dict, None, None]]: is_first_conversation = False @@ -211,7 +216,7 @@ def _generate(self, app_model: App, # update conversation features conversation.override_model_configs = workflow.features db.session.commit() - db.session.refresh(conversation) + # db.session.refresh(conversation) # init queue manager queue_manager = MessageBasedAppQueueManager( @@ -223,15 +228,69 @@ def _generate(self, app_model: App, message_id=message.id ) + # Init conversation variables + stmt = select(ConversationVariable).where( + ConversationVariable.app_id == conversation.app_id, ConversationVariable.conversation_id == conversation.id + ) + with Session(db.engine) as session: + conversation_variables = session.scalars(stmt).all() + if not conversation_variables: + # Create conversation variables if they don't exist. + conversation_variables = [ + ConversationVariable.from_variable( + app_id=conversation.app_id, conversation_id=conversation.id, variable=variable + ) + for variable in workflow.conversation_variables + ] + session.add_all(conversation_variables) + # Convert database entities to variables. + conversation_variables = [item.to_variable() for item in conversation_variables] + + session.commit() + + # Increment dialogue count. + conversation.dialogue_count += 1 + + conversation_id = conversation.id + conversation_dialogue_count = conversation.dialogue_count + db.session.commit() + db.session.refresh(conversation) + + inputs = application_generate_entity.inputs + query = application_generate_entity.query + files = application_generate_entity.files + + user_id = None + if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: + end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first() + if end_user: + user_id = end_user.session_id + else: + user_id = application_generate_entity.user_id + + # Create a variable pool. + system_inputs = { + SystemVariable.QUERY: query, + SystemVariable.FILES: files, + SystemVariable.CONVERSATION_ID: conversation_id, + SystemVariable.USER_ID: user_id, + SystemVariable.DIALOGUE_COUNT: conversation_dialogue_count, + } + variable_pool = VariablePool( + system_variables=system_inputs, + user_inputs=inputs, + environment_variables=workflow.environment_variables, + conversation_variables=conversation_variables, + ) + contexts.workflow_variable_pool.set(variable_pool) + # new thread worker_thread = threading.Thread(target=self._generate_worker, kwargs={ 'flask_app': current_app._get_current_object(), 'application_generate_entity': application_generate_entity, 'queue_manager': queue_manager, - 'conversation_id': conversation.id, 'message_id': message.id, - 'user': user, - 'context': contextvars.copy_context() + 'context': contextvars.copy_context(), }) worker_thread.start() @@ -244,7 +303,7 @@ def _generate(self, app_model: App, conversation=conversation, message=message, user=user, - stream=stream + stream=stream, ) return AdvancedChatAppGenerateResponseConverter.convert( @@ -255,9 +314,7 @@ def _generate(self, app_model: App, def _generate_worker(self, flask_app: Flask, application_generate_entity: AdvancedChatAppGenerateEntity, queue_manager: AppQueueManager, - conversation_id: str, message_id: str, - user: Account, context: contextvars.Context) -> None: """ Generate worker in a new thread. @@ -284,8 +341,7 @@ def _generate_worker(self, flask_app: Flask, user_id=application_generate_entity.user_id ) else: - # get conversation and message - conversation = self._get_conversation(conversation_id) + # get message message = self._get_message(message_id) # chatbot app @@ -293,7 +349,6 @@ def _generate_worker(self, flask_app: Flask, runner.run( application_generate_entity=application_generate_entity, queue_manager=queue_manager, - conversation=conversation, message=message ) except GenerateTaskStoppedException: @@ -316,14 +371,17 @@ def _generate_worker(self, flask_app: Flask, finally: db.session.close() - def _handle_advanced_chat_response(self, application_generate_entity: AdvancedChatAppGenerateEntity, - workflow: Workflow, - queue_manager: AppQueueManager, - conversation: Conversation, - message: Message, - user: Union[Account, EndUser], - stream: bool = False) \ - -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]: + def _handle_advanced_chat_response( + self, + *, + application_generate_entity: AdvancedChatAppGenerateEntity, + workflow: Workflow, + queue_manager: AppQueueManager, + conversation: Conversation, + message: Message, + user: Union[Account, EndUser], + stream: bool = False, + ) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]: """ Handle response. :param application_generate_entity: application generate entity @@ -343,7 +401,7 @@ def _handle_advanced_chat_response(self, application_generate_entity: AdvancedCh conversation=conversation, message=message, user=user, - stream=stream + stream=stream, ) try: diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 18db0ab22d4ded..5dc03979cf3b4b 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -16,12 +16,10 @@ from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent from core.moderation.base import ModerationException from core.workflow.callbacks.base_workflow_callback import WorkflowCallback -from core.workflow.entities.node_entities import SystemVariable from core.workflow.nodes.base_node import UserFrom from core.workflow.workflow_engine_manager import WorkflowEngineManager from extensions.ext_database import db -from models.model import App, Conversation, EndUser, Message -from models.workflow import Workflow +from models import App, Message, Workflow logger = logging.getLogger(__name__) @@ -31,10 +29,12 @@ class AdvancedChatAppRunner(AppRunner): AdvancedChat Application Runner """ - def run(self, application_generate_entity: AdvancedChatAppGenerateEntity, - queue_manager: AppQueueManager, - conversation: Conversation, - message: Message) -> None: + def run( + self, + application_generate_entity: AdvancedChatAppGenerateEntity, + queue_manager: AppQueueManager, + message: Message, + ) -> None: """ Run application :param application_generate_entity: application generate entity @@ -48,53 +48,43 @@ def run(self, application_generate_entity: AdvancedChatAppGenerateEntity, app_record = db.session.query(App).filter(App.id == app_config.app_id).first() if not app_record: - raise ValueError("App not found") + raise ValueError('App not found') workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id) if not workflow: - raise ValueError("Workflow not initialized") + raise ValueError('Workflow not initialized') inputs = application_generate_entity.inputs query = application_generate_entity.query - files = application_generate_entity.files - - user_id = None - if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: - end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first() - if end_user: - user_id = end_user.session_id - else: - user_id = application_generate_entity.user_id # moderation if self.handle_input_moderation( - queue_manager=queue_manager, - app_record=app_record, - app_generate_entity=application_generate_entity, - inputs=inputs, - query=query, - message_id=message.id + queue_manager=queue_manager, + app_record=app_record, + app_generate_entity=application_generate_entity, + inputs=inputs, + query=query, + message_id=message.id, ): return # annotation reply if self.handle_annotation_reply( - app_record=app_record, - message=message, - query=query, - queue_manager=queue_manager, - app_generate_entity=application_generate_entity + app_record=app_record, + message=message, + query=query, + queue_manager=queue_manager, + app_generate_entity=application_generate_entity, ): return db.session.close() - workflow_callbacks: list[WorkflowCallback] = [WorkflowEventTriggerCallback( - queue_manager=queue_manager, - workflow=workflow - )] + workflow_callbacks: list[WorkflowCallback] = [ + WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow) + ] - if bool(os.environ.get("DEBUG", 'False').lower() == 'true'): + if bool(os.environ.get('DEBUG', 'False').lower() == 'true'): workflow_callbacks.append(WorkflowLoggingCallback()) # RUN WORKFLOW @@ -106,43 +96,29 @@ def run(self, application_generate_entity: AdvancedChatAppGenerateEntity, if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else UserFrom.END_USER, invoke_from=application_generate_entity.invoke_from, - user_inputs=inputs, - system_inputs={ - SystemVariable.QUERY: query, - SystemVariable.FILES: files, - SystemVariable.CONVERSATION_ID: conversation.id, - SystemVariable.USER_ID: user_id - }, callbacks=workflow_callbacks, - call_depth=application_generate_entity.call_depth + call_depth=application_generate_entity.call_depth, ) - def single_iteration_run(self, app_id: str, workflow_id: str, - queue_manager: AppQueueManager, - inputs: dict, node_id: str, user_id: str) -> None: + def single_iteration_run( + self, app_id: str, workflow_id: str, queue_manager: AppQueueManager, inputs: dict, node_id: str, user_id: str + ) -> None: """ Single iteration run """ - app_record: App = db.session.query(App).filter(App.id == app_id).first() + app_record = db.session.query(App).filter(App.id == app_id).first() if not app_record: - raise ValueError("App not found") - + raise ValueError('App not found') + workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id) if not workflow: - raise ValueError("Workflow not initialized") - - workflow_callbacks = [WorkflowEventTriggerCallback( - queue_manager=queue_manager, - workflow=workflow - )] + raise ValueError('Workflow not initialized') + + workflow_callbacks = [WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)] workflow_engine_manager = WorkflowEngineManager() workflow_engine_manager.single_step_run_iteration_workflow_node( - workflow=workflow, - node_id=node_id, - user_id=user_id, - user_inputs=inputs, - callbacks=workflow_callbacks + workflow=workflow, node_id=node_id, user_id=user_id, user_inputs=inputs, callbacks=workflow_callbacks ) def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]: @@ -150,22 +126,25 @@ def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]: Get workflow """ # fetch workflow by workflow_id - workflow = db.session.query(Workflow).filter( - Workflow.tenant_id == app_model.tenant_id, - Workflow.app_id == app_model.id, - Workflow.id == workflow_id - ).first() + workflow = ( + db.session.query(Workflow) + .filter( + Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id + ) + .first() + ) # return workflow return workflow def handle_input_moderation( - self, queue_manager: AppQueueManager, - app_record: App, - app_generate_entity: AdvancedChatAppGenerateEntity, - inputs: Mapping[str, Any], - query: str, - message_id: str + self, + queue_manager: AppQueueManager, + app_record: App, + app_generate_entity: AdvancedChatAppGenerateEntity, + inputs: Mapping[str, Any], + query: str, + message_id: str, ) -> bool: """ Handle input moderation @@ -192,17 +171,20 @@ def handle_input_moderation( queue_manager=queue_manager, text=str(e), stream=app_generate_entity.stream, - stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION + stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION, ) return True return False - def handle_annotation_reply(self, app_record: App, - message: Message, - query: str, - queue_manager: AppQueueManager, - app_generate_entity: AdvancedChatAppGenerateEntity) -> bool: + def handle_annotation_reply( + self, + app_record: App, + message: Message, + query: str, + queue_manager: AppQueueManager, + app_generate_entity: AdvancedChatAppGenerateEntity, + ) -> bool: """ Handle annotation reply :param app_record: app record @@ -217,29 +199,27 @@ def handle_annotation_reply(self, app_record: App, message=message, query=query, user_id=app_generate_entity.user_id, - invoke_from=app_generate_entity.invoke_from + invoke_from=app_generate_entity.invoke_from, ) if annotation_reply: queue_manager.publish( - QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id), - PublishFrom.APPLICATION_MANAGER + QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id), PublishFrom.APPLICATION_MANAGER ) self._stream_output( queue_manager=queue_manager, text=annotation_reply.content, stream=app_generate_entity.stream, - stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY + stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY, ) return True return False - def _stream_output(self, queue_manager: AppQueueManager, - text: str, - stream: bool, - stopped_by: QueueStopEvent.StopBy) -> None: + def _stream_output( + self, queue_manager: AppQueueManager, text: str, stream: bool, stopped_by: QueueStopEvent.StopBy + ) -> None: """ Direct output :param queue_manager: application queue manager @@ -250,21 +230,10 @@ def _stream_output(self, queue_manager: AppQueueManager, if stream: index = 0 for token in text: - queue_manager.publish( - QueueTextChunkEvent( - text=token - ), PublishFrom.APPLICATION_MANAGER - ) + queue_manager.publish(QueueTextChunkEvent(text=token), PublishFrom.APPLICATION_MANAGER) index += 1 time.sleep(0.01) else: - queue_manager.publish( - QueueTextChunkEvent( - text=text - ), PublishFrom.APPLICATION_MANAGER - ) + queue_manager.publish(QueueTextChunkEvent(text=text), PublishFrom.APPLICATION_MANAGER) - queue_manager.publish( - QueueStopEvent(stopped_by=stopped_by), - PublishFrom.APPLICATION_MANAGER - ) + queue_manager.publish(QueueStopEvent(stopped_by=stopped_by), PublishFrom.APPLICATION_MANAGER) diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 91a43ed4493027..f8efcb59606d08 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -4,6 +4,7 @@ from collections.abc import Generator from typing import Any, Optional, Union, cast +import contexts from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom @@ -47,7 +48,8 @@ from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.utils.encoders import jsonable_encoder from core.ops.ops_trace_manager import TraceQueueManager -from core.workflow.entities.node_entities import NodeType, SystemVariable +from core.workflow.entities.node_entities import NodeType +from core.workflow.enums import SystemVariable from core.workflow.nodes.answer.answer_node import AnswerNode from core.workflow.nodes.answer.entities import TextGenerateRouteChunk, VarGenerateRouteChunk from events.message_event import message_was_created @@ -71,6 +73,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc _application_generate_entity: AdvancedChatAppGenerateEntity _workflow: Workflow _user: Union[Account, EndUser] + # Deprecated _workflow_system_variables: dict[SystemVariable, Any] _iteration_nested_relations: dict[str, list[str]] @@ -81,7 +84,7 @@ def __init__( conversation: Conversation, message: Message, user: Union[Account, EndUser], - stream: bool + stream: bool, ) -> None: """ Initialize AdvancedChatAppGenerateTaskPipeline. @@ -103,11 +106,12 @@ def __init__( self._workflow = workflow self._conversation = conversation self._message = message + # Deprecated self._workflow_system_variables = { SystemVariable.QUERY: message.query, SystemVariable.FILES: application_generate_entity.files, SystemVariable.CONVERSATION_ID: conversation.id, - SystemVariable.USER_ID: user_id + SystemVariable.USER_ID: user_id, } self._task_state = AdvancedChatTaskState( @@ -613,7 +617,9 @@ def _generate_stream_outputs_when_node_finished(self) -> Optional[Generator]: if route_chunk_node_id == 'sys': # system variable - value = self._workflow_system_variables.get(SystemVariable.value_of(value_selector[1])) + value = contexts.workflow_variable_pool.get().get(value_selector) + if value: + value = value.text elif route_chunk_node_id in self._iteration_nested_relations: # it's a iteration variable if not self._iteration_state or route_chunk_node_id not in self._iteration_state.current_iterations: diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index 58c7d04b8348f8..6fb387c15ac566 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -1,6 +1,6 @@ import time from collections.abc import Generator -from typing import Optional, Union +from typing import TYPE_CHECKING, Optional, Union from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom @@ -14,7 +14,6 @@ from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature from core.external_data_tool.external_data_fetch import ExternalDataFetch -from core.file.file_obj import FileVar from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage @@ -27,13 +26,16 @@ from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform from models.model import App, AppMode, Message, MessageAnnotation +if TYPE_CHECKING: + from core.file.file_obj import FileVar + class AppRunner: def get_pre_calculate_rest_tokens(self, app_record: App, model_config: ModelConfigWithCredentialsEntity, prompt_template_entity: PromptTemplateEntity, inputs: dict[str, str], - files: list[FileVar], + files: list["FileVar"], query: Optional[str] = None) -> int: """ Get pre calculate rest tokens @@ -126,7 +128,7 @@ def organize_prompt_messages(self, app_record: App, model_config: ModelConfigWithCredentialsEntity, prompt_template_entity: PromptTemplateEntity, inputs: dict[str, str], - files: list[FileVar], + files: list["FileVar"], query: Optional[str] = None, context: Optional[str] = None, memory: Optional[TokenBufferMemory] = None) \ @@ -366,7 +368,7 @@ def moderation_for_inputs( message_id=message_id, trace_manager=app_generate_entity.trace_manager ) - + def check_hosting_moderation(self, application_generate_entity: EasyUIBasedAppGenerateEntity, queue_manager: AppQueueManager, prompt_messages: list[PromptMessage]) -> bool: @@ -418,7 +420,7 @@ def fill_in_inputs_from_external_data_tools(self, tenant_id: str, inputs=inputs, query=query ) - + def query_app_annotations_to_reply(self, app_record: App, message: Message, query: str, diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index c5cd6864020b33..12f69f1528e241 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -258,7 +258,7 @@ def _get_conversation_introduction(self, application_generate_entity: AppGenerat return introduction - def _get_conversation(self, conversation_id: str) -> Conversation: + def _get_conversation(self, conversation_id: str): """ Get conversation by conversation id :param conversation_id: conversation id @@ -270,6 +270,9 @@ def _get_conversation(self, conversation_id: str) -> Conversation: .first() ) + if not conversation: + raise ConversationNotExistsError() + return conversation def _get_message(self, message_id: str) -> Message: diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index 24f4a83217a239..994919391e7ed5 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -11,7 +11,8 @@ WorkflowAppGenerateEntity, ) from core.workflow.callbacks.base_workflow_callback import WorkflowCallback -from core.workflow.entities.node_entities import SystemVariable +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariable from core.workflow.nodes.base_node import UserFrom from core.workflow.workflow_engine_manager import WorkflowEngineManager from extensions.ext_database import db @@ -26,8 +27,7 @@ class WorkflowAppRunner: Workflow Application Runner """ - def run(self, application_generate_entity: WorkflowAppGenerateEntity, - queue_manager: AppQueueManager) -> None: + def run(self, application_generate_entity: WorkflowAppGenerateEntity, queue_manager: AppQueueManager) -> None: """ Run application :param application_generate_entity: application generate entity @@ -47,25 +47,36 @@ def run(self, application_generate_entity: WorkflowAppGenerateEntity, app_record = db.session.query(App).filter(App.id == app_config.app_id).first() if not app_record: - raise ValueError("App not found") + raise ValueError('App not found') workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id) if not workflow: - raise ValueError("Workflow not initialized") + raise ValueError('Workflow not initialized') inputs = application_generate_entity.inputs files = application_generate_entity.files db.session.close() - workflow_callbacks: list[WorkflowCallback] = [WorkflowEventTriggerCallback( - queue_manager=queue_manager, - workflow=workflow - )] + workflow_callbacks: list[WorkflowCallback] = [ + WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow) + ] - if bool(os.environ.get("DEBUG", 'False').lower() == 'true'): + if bool(os.environ.get('DEBUG', 'False').lower() == 'true'): workflow_callbacks.append(WorkflowLoggingCallback()) + # Create a variable pool. + system_inputs = { + SystemVariable.FILES: files, + SystemVariable.USER_ID: user_id, + } + variable_pool = VariablePool( + system_variables=system_inputs, + user_inputs=inputs, + environment_variables=workflow.environment_variables, + conversation_variables=[], + ) + # RUN WORKFLOW workflow_engine_manager = WorkflowEngineManager() workflow_engine_manager.run_workflow( @@ -75,44 +86,33 @@ def run(self, application_generate_entity: WorkflowAppGenerateEntity, if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else UserFrom.END_USER, invoke_from=application_generate_entity.invoke_from, - user_inputs=inputs, - system_inputs={ - SystemVariable.FILES: files, - SystemVariable.USER_ID: user_id - }, callbacks=workflow_callbacks, - call_depth=application_generate_entity.call_depth + call_depth=application_generate_entity.call_depth, + variable_pool=variable_pool, ) - def single_iteration_run(self, app_id: str, workflow_id: str, - queue_manager: AppQueueManager, - inputs: dict, node_id: str, user_id: str) -> None: + def single_iteration_run( + self, app_id: str, workflow_id: str, queue_manager: AppQueueManager, inputs: dict, node_id: str, user_id: str + ) -> None: """ Single iteration run """ - app_record: App = db.session.query(App).filter(App.id == app_id).first() + app_record = db.session.query(App).filter(App.id == app_id).first() if not app_record: - raise ValueError("App not found") - + raise ValueError('App not found') + if not app_record.workflow_id: - raise ValueError("Workflow not initialized") + raise ValueError('Workflow not initialized') workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id) if not workflow: - raise ValueError("Workflow not initialized") - - workflow_callbacks = [WorkflowEventTriggerCallback( - queue_manager=queue_manager, - workflow=workflow - )] + raise ValueError('Workflow not initialized') + + workflow_callbacks = [WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)] workflow_engine_manager = WorkflowEngineManager() workflow_engine_manager.single_step_run_iteration_workflow_node( - workflow=workflow, - node_id=node_id, - user_id=user_id, - user_inputs=inputs, - callbacks=workflow_callbacks + workflow=workflow, node_id=node_id, user_id=user_id, user_inputs=inputs, callbacks=workflow_callbacks ) def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]: @@ -120,11 +120,13 @@ def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]: Get workflow """ # fetch workflow by workflow_id - workflow = db.session.query(Workflow).filter( - Workflow.tenant_id == app_model.tenant_id, - Workflow.app_id == app_model.id, - Workflow.id == workflow_id - ).first() + workflow = ( + db.session.query(Workflow) + .filter( + Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id + ) + .first() + ) # return workflow return workflow diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 2b4362150fc7e7..5022eb0438d13b 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -42,7 +42,8 @@ from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage from core.ops.ops_trace_manager import TraceQueueManager -from core.workflow.entities.node_entities import NodeType, SystemVariable +from core.workflow.entities.node_entities import NodeType +from core.workflow.enums import SystemVariable from core.workflow.nodes.end.end_node import EndNode from extensions.ext_database import db from models.account import Account @@ -519,7 +520,7 @@ def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]: """ nodes = graph.get('nodes') - iteration_ids = [node.get('id') for node in nodes + iteration_ids = [node.get('id') for node in nodes if node.get('data', {}).get('type') in [ NodeType.ITERATION.value, NodeType.LOOP.value, @@ -530,4 +531,3 @@ def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]: node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id ] for iteration_id in iteration_ids } - \ No newline at end of file diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 9a861c29e2634c..6a1ab230416d0c 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -166,4 +166,4 @@ class SingleIterationRunEntity(BaseModel): node_id: str inputs: dict - single_iteration_run: Optional[SingleIterationRunEntity] = None \ No newline at end of file + single_iteration_run: Optional[SingleIterationRunEntity] = None diff --git a/api/core/app/segments/__init__.py b/api/core/app/segments/__init__.py index d5cd0a589cc38a..7de06dfb9639fd 100644 --- a/api/core/app/segments/__init__.py +++ b/api/core/app/segments/__init__.py @@ -1,7 +1,7 @@ from .segment_group import SegmentGroup from .segments import ( ArrayAnySegment, - FileSegment, + ArraySegment, FloatSegment, IntegerSegment, NoneSegment, @@ -12,11 +12,9 @@ from .types import SegmentType from .variables import ( ArrayAnyVariable, - ArrayFileVariable, ArrayNumberVariable, ArrayObjectVariable, ArrayStringVariable, - FileVariable, FloatVariable, IntegerVariable, NoneVariable, @@ -31,7 +29,6 @@ 'FloatVariable', 'ObjectVariable', 'SecretVariable', - 'FileVariable', 'StringVariable', 'ArrayAnyVariable', 'Variable', @@ -44,10 +41,9 @@ 'FloatSegment', 'ObjectSegment', 'ArrayAnySegment', - 'FileSegment', 'StringSegment', 'ArrayStringVariable', 'ArrayNumberVariable', 'ArrayObjectVariable', - 'ArrayFileVariable', + 'ArraySegment', ] diff --git a/api/core/app/segments/exc.py b/api/core/app/segments/exc.py new file mode 100644 index 00000000000000..d15d6d500ffa4a --- /dev/null +++ b/api/core/app/segments/exc.py @@ -0,0 +1,2 @@ +class VariableError(Exception): + pass diff --git a/api/core/app/segments/factory.py b/api/core/app/segments/factory.py index 1196284b183e68..e6e9ce97747ce1 100644 --- a/api/core/app/segments/factory.py +++ b/api/core/app/segments/factory.py @@ -1,11 +1,11 @@ from collections.abc import Mapping from typing import Any -from core.file.file_obj import FileVar +from configs import dify_config +from .exc import VariableError from .segments import ( ArrayAnySegment, - FileSegment, FloatSegment, IntegerSegment, NoneSegment, @@ -15,11 +15,9 @@ ) from .types import SegmentType from .variables import ( - ArrayFileVariable, ArrayNumberVariable, ArrayObjectVariable, ArrayStringVariable, - FileVariable, FloatVariable, IntegerVariable, ObjectVariable, @@ -29,39 +27,37 @@ ) -def build_variable_from_mapping(m: Mapping[str, Any], /) -> Variable: - if (value_type := m.get('value_type')) is None: - raise ValueError('missing value type') - if not m.get('name'): - raise ValueError('missing name') - if (value := m.get('value')) is None: - raise ValueError('missing value') +def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: + if (value_type := mapping.get('value_type')) is None: + raise VariableError('missing value type') + if not mapping.get('name'): + raise VariableError('missing name') + if (value := mapping.get('value')) is None: + raise VariableError('missing value') match value_type: case SegmentType.STRING: - return StringVariable.model_validate(m) + result = StringVariable.model_validate(mapping) case SegmentType.SECRET: - return SecretVariable.model_validate(m) + result = SecretVariable.model_validate(mapping) case SegmentType.NUMBER if isinstance(value, int): - return IntegerVariable.model_validate(m) + result = IntegerVariable.model_validate(mapping) case SegmentType.NUMBER if isinstance(value, float): - return FloatVariable.model_validate(m) + result = FloatVariable.model_validate(mapping) case SegmentType.NUMBER if not isinstance(value, float | int): - raise ValueError(f'invalid number value {value}') - case SegmentType.FILE: - return FileVariable.model_validate(m) + raise VariableError(f'invalid number value {value}') case SegmentType.OBJECT if isinstance(value, dict): - return ObjectVariable.model_validate( - {**m, 'value': {k: build_variable_from_mapping(v) for k, v in value.items()}} - ) + result = ObjectVariable.model_validate(mapping) case SegmentType.ARRAY_STRING if isinstance(value, list): - return ArrayStringVariable.model_validate({**m, 'value': [build_variable_from_mapping(v) for v in value]}) + result = ArrayStringVariable.model_validate(mapping) case SegmentType.ARRAY_NUMBER if isinstance(value, list): - return ArrayNumberVariable.model_validate({**m, 'value': [build_variable_from_mapping(v) for v in value]}) + result = ArrayNumberVariable.model_validate(mapping) case SegmentType.ARRAY_OBJECT if isinstance(value, list): - return ArrayObjectVariable.model_validate({**m, 'value': [build_variable_from_mapping(v) for v in value]}) - case SegmentType.ARRAY_FILE if isinstance(value, list): - return ArrayFileVariable.model_validate({**m, 'value': [build_variable_from_mapping(v) for v in value]}) - raise ValueError(f'not supported value type {value_type}') + result = ArrayObjectVariable.model_validate(mapping) + case _: + raise VariableError(f'not supported value type {value_type}') + if result.size > dify_config.MAX_VARIABLE_SIZE: + raise VariableError(f'variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}') + return result def build_segment(value: Any, /) -> Segment: @@ -74,12 +70,7 @@ def build_segment(value: Any, /) -> Segment: if isinstance(value, float): return FloatSegment(value=value) if isinstance(value, dict): - # TODO: Limit the depth of the object return ObjectSegment(value=value) if isinstance(value, list): - # TODO: Limit the depth of the array - elements = [build_segment(v) for v in value] - return ArrayAnySegment(value=elements) - if isinstance(value, FileVar): - return FileSegment(value=value) + return ArrayAnySegment(value=value) raise ValueError(f'not supported value {value}') diff --git a/api/core/app/segments/segments.py b/api/core/app/segments/segments.py index 0001c5300fe90b..5c713cac6747f9 100644 --- a/api/core/app/segments/segments.py +++ b/api/core/app/segments/segments.py @@ -1,11 +1,10 @@ import json +import sys from collections.abc import Mapping, Sequence from typing import Any from pydantic import BaseModel, ConfigDict, field_validator -from core.file.file_obj import FileVar - from .types import SegmentType @@ -37,6 +36,10 @@ def log(self) -> str: def markdown(self) -> str: return str(self.value) + @property + def size(self) -> int: + return sys.getsizeof(self.value) + def to_object(self) -> Any: return self.value @@ -73,14 +76,7 @@ class IntegerSegment(Segment): value: int -class FileSegment(Segment): - value_type: SegmentType = SegmentType.FILE - # TODO: embed FileVar in this model. - value: FileVar - @property - def markdown(self) -> str: - return self.value.to_markdown() class ObjectSegment(Segment): @@ -103,32 +99,31 @@ def markdown(self) -> str: class ArraySegment(Segment): @property def markdown(self) -> str: - return '\n'.join(['- ' + item.markdown for item in self.value]) - - def to_object(self): - return [v.to_object() for v in self.value] + items = [] + for item in self.value: + if hasattr(item, 'to_markdown'): + items.append(item.to_markdown()) + else: + items.append(str(item)) + return '\n'.join(items) class ArrayAnySegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_ANY - value: Sequence[Segment] + value: Sequence[Any] class ArrayStringSegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_STRING - value: Sequence[StringSegment] + value: Sequence[str] class ArrayNumberSegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_NUMBER - value: Sequence[FloatSegment | IntegerSegment] + value: Sequence[float | int] class ArrayObjectSegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_OBJECT - value: Sequence[ObjectSegment] - + value: Sequence[Mapping[str, Any]] -class ArrayFileSegment(ArraySegment): - value_type: SegmentType = SegmentType.ARRAY_FILE - value: Sequence[FileSegment] diff --git a/api/core/app/segments/types.py b/api/core/app/segments/types.py index a371058ef52bac..cdd2b0b4b09191 100644 --- a/api/core/app/segments/types.py +++ b/api/core/app/segments/types.py @@ -10,8 +10,6 @@ class SegmentType(str, Enum): ARRAY_STRING = 'array[string]' ARRAY_NUMBER = 'array[number]' ARRAY_OBJECT = 'array[object]' - ARRAY_FILE = 'array[file]' OBJECT = 'object' - FILE = 'file' GROUP = 'group' diff --git a/api/core/app/segments/variables.py b/api/core/app/segments/variables.py index ac26e165425c3a..8fef707fcf298b 100644 --- a/api/core/app/segments/variables.py +++ b/api/core/app/segments/variables.py @@ -4,11 +4,9 @@ from .segments import ( ArrayAnySegment, - ArrayFileSegment, ArrayNumberSegment, ArrayObjectSegment, ArrayStringSegment, - FileSegment, FloatSegment, IntegerSegment, NoneSegment, @@ -44,10 +42,6 @@ class IntegerVariable(IntegerSegment, Variable): pass -class FileVariable(FileSegment, Variable): - pass - - class ObjectVariable(ObjectSegment, Variable): pass @@ -68,9 +62,6 @@ class ArrayObjectVariable(ArrayObjectSegment, Variable): pass -class ArrayFileVariable(ArrayFileSegment, Variable): - pass - class SecretVariable(StringVariable): value_type: SegmentType = SegmentType.SECRET diff --git a/api/core/app/task_pipeline/workflow_cycle_state_manager.py b/api/core/app/task_pipeline/workflow_cycle_state_manager.py index 545f31fddfaedb..8baa8ba09e4b00 100644 --- a/api/core/app/task_pipeline/workflow_cycle_state_manager.py +++ b/api/core/app/task_pipeline/workflow_cycle_state_manager.py @@ -2,7 +2,7 @@ from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity from core.app.entities.task_entities import AdvancedChatTaskState, WorkflowTaskState -from core.workflow.entities.node_entities import SystemVariable +from core.workflow.enums import SystemVariable from models.account import Account from models.model import EndUser from models.workflow import Workflow @@ -13,4 +13,4 @@ class WorkflowCycleStateManager: _workflow: Workflow _user: Union[Account, EndUser] _task_state: Union[AdvancedChatTaskState, WorkflowTaskState] - _workflow_system_variables: dict[SystemVariable, Any] \ No newline at end of file + _workflow_system_variables: dict[SystemVariable, Any] diff --git a/api/core/file/file_obj.py b/api/core/file/file_obj.py index 268ef5df867988..3959f4b4a0bb61 100644 --- a/api/core/file/file_obj.py +++ b/api/core/file/file_obj.py @@ -1,14 +1,19 @@ import enum -from typing import Optional +from typing import Any, Optional from pydantic import BaseModel -from core.app.app_config.entities import FileExtraConfig from core.file.tool_file_parser import ToolFileParser from core.file.upload_file_parser import UploadFileParser from core.model_runtime.entities.message_entities import ImagePromptMessageContent from extensions.ext_database import db -from models.model import UploadFile + + +class FileExtraConfig(BaseModel): + """ + File Upload Entity. + """ + image_config: Optional[dict[str, Any]] = None class FileType(enum.Enum): @@ -114,6 +119,7 @@ def prompt_message_content(self) -> ImagePromptMessageContent: ) def _get_data(self, force_url: bool = False) -> Optional[str]: + from models.model import UploadFile if self.type == FileType.IMAGE: if self.transfer_method == FileTransferMethod.REMOTE_URL: return self.url diff --git a/api/core/file/message_file_parser.py b/api/core/file/message_file_parser.py index ec502b5e062ec6..085ff07cfde921 100644 --- a/api/core/file/message_file_parser.py +++ b/api/core/file/message_file_parser.py @@ -5,8 +5,7 @@ import requests -from core.app.app_config.entities import FileExtraConfig -from core.file.file_obj import FileBelongsTo, FileTransferMethod, FileType, FileVar +from core.file.file_obj import FileBelongsTo, FileExtraConfig, FileTransferMethod, FileType, FileVar from extensions.ext_database import db from models.account import Account from models.model import EndUser, MessageFile, UploadFile @@ -100,7 +99,7 @@ def validate_and_transform_files_arg(self, files: Sequence[Mapping[str, Any]], f # return all file objs return new_files - def transform_message_files(self, files: list[MessageFile], file_extra_config: FileExtraConfig) -> list[FileVar]: + def transform_message_files(self, files: list[MessageFile], file_extra_config: FileExtraConfig): """ transform message files @@ -145,7 +144,7 @@ def _to_file_objs(self, files: list[Union[dict, MessageFile]], return type_file_objs - def _to_file_obj(self, file: Union[dict, MessageFile], file_extra_config: FileExtraConfig) -> FileVar: + def _to_file_obj(self, file: Union[dict, MessageFile], file_extra_config: FileExtraConfig): """ transform file to file obj diff --git a/api/core/helper/encrypter.py b/api/core/helper/encrypter.py index bf87a842c00bf4..5e5deb86b47e54 100644 --- a/api/core/helper/encrypter.py +++ b/api/core/helper/encrypter.py @@ -2,7 +2,6 @@ from extensions.ext_database import db from libs import rsa -from models.account import Tenant def obfuscated_token(token: str): @@ -14,6 +13,7 @@ def obfuscated_token(token: str): def encrypt_token(tenant_id: str, token: str): + from models.account import Tenant if not (tenant := db.session.query(Tenant).filter(Tenant.id == tenant_id).first()): raise ValueError(f'Tenant with id {tenant_id} not found') encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key) diff --git a/api/core/helper/position_helper.py b/api/core/helper/position_helper.py index dd1534c791b313..93e3a87124a889 100644 --- a/api/core/helper/position_helper.py +++ b/api/core/helper/position_helper.py @@ -3,12 +3,13 @@ from collections.abc import Callable from typing import Any +from configs import dify_config from core.tools.utils.yaml_utils import load_yaml_file def get_position_map(folder_path: str, *, file_name: str = "_position.yaml") -> dict[str, int]: """ - Get the mapping from name to index from a YAML file + Get the mapping from name to index from a YAML file. :param folder_path: :param file_name: the YAML file name, default to '_position.yaml' :return: a dict with name as key and index as value @@ -19,6 +20,64 @@ def get_position_map(folder_path: str, *, file_name: str = "_position.yaml") -> return {name: index for index, name in enumerate(positions)} +def get_tool_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]: + """ + Get the mapping for tools from name to index from a YAML file. + :param folder_path: + :param file_name: the YAML file name, default to '_position.yaml' + :return: a dict with name as key and index as value + """ + position_map = get_position_map(folder_path, file_name=file_name) + + return sort_and_filter_position_map( + position_map, + pin_list=dify_config.POSITION_TOOL_PINS_LIST, + include_list=dify_config.POSITION_TOOL_INCLUDES_LIST, + exclude_list=dify_config.POSITION_TOOL_EXCLUDES_LIST + ) + + +def get_provider_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]: + """ + Get the mapping for providers from name to index from a YAML file. + :param folder_path: + :param file_name: the YAML file name, default to '_position.yaml' + :return: a dict with name as key and index as value + """ + position_map = get_position_map(folder_path, file_name=file_name) + return sort_and_filter_position_map( + position_map, + pin_list=dify_config.POSITION_PROVIDER_PINS_LIST, + include_list=dify_config.POSITION_PROVIDER_INCLUDES_LIST, + exclude_list=dify_config.POSITION_PROVIDER_EXCLUDES_LIST + ) + + +def sort_and_filter_position_map(original_position_map: dict[str, int], pin_list: list[str], include_list: list[str], exclude_list: list[str]) -> dict[str, int]: + """ + Sort and filter the positions + :param position_map: the position map to be sorted and filtered + :param pin_list: the list of pins to be put at the beginning + :param include_set: the set of names to be included + :param exclude_set: the set of names to be excluded + :return: the sorted and filtered position map + """ + positions = sorted(original_position_map.keys(), key=lambda x: original_position_map[x]) + include_set = set(include_list) if include_list else set(positions) + exclude_set = set(exclude_list) if exclude_list else set() + + # Add pins to position map + position_map = {name: idx for idx, name in enumerate(pin_list) if name in original_position_map} + + # Add remaining positions to position map, respecting include and exclude lists + start_idx = len(position_map) + for name in positions: + if name in include_set and name not in exclude_set and name not in position_map: + position_map[name] = start_idx + start_idx += 1 + return position_map + + def sort_by_position_map( position_map: dict[str, int], data: list[Any], @@ -35,7 +94,9 @@ def sort_by_position_map( if not position_map or not data: return data - return sorted(data, key=lambda x: position_map.get(name_func(x), float('inf'))) + filtered_data = [item for item in data if name_func(item) in position_map] + + return sorted(filtered_data, key=lambda x: position_map.get(name_func(x), float('inf'))) def sort_to_dict_by_position_map( diff --git a/api/core/model_runtime/entities/defaults.py b/api/core/model_runtime/entities/defaults.py index 87fe4f681ce5c7..d2076bf74a3cde 100644 --- a/api/core/model_runtime/entities/defaults.py +++ b/api/core/model_runtime/entities/defaults.py @@ -1,4 +1,3 @@ - from core.model_runtime.entities.model_entities import DefaultParameterName PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = { @@ -94,5 +93,16 @@ }, 'required': False, 'options': ['JSON', 'XML'], - } -} \ No newline at end of file + }, + DefaultParameterName.JSON_SCHEMA: { + 'label': { + 'en_US': 'JSON Schema', + }, + 'type': 'text', + 'help': { + 'en_US': 'Set a response json schema will ensure LLM to adhere it.', + 'zh_Hans': '设置返回的json schema,llm将按照它返回', + }, + 'required': False, + }, +} diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py index 3d471787bbef8e..c257ce63d27926 100644 --- a/api/core/model_runtime/entities/model_entities.py +++ b/api/core/model_runtime/entities/model_entities.py @@ -95,6 +95,7 @@ class DefaultParameterName(Enum): FREQUENCY_PENALTY = "frequency_penalty" MAX_TOKENS = "max_tokens" RESPONSE_FORMAT = "response_format" + JSON_SCHEMA = "json_schema" @classmethod def value_of(cls, value: Any) -> 'DefaultParameterName': @@ -118,6 +119,7 @@ class ParameterType(Enum): INT = "int" STRING = "string" BOOLEAN = "boolean" + TEXT = "text" class ModelPropertyKey(Enum): diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py b/api/core/model_runtime/model_providers/__base/ai_model.py index 0de216bf896fc2..716bb63566c372 100644 --- a/api/core/model_runtime/model_providers/__base/ai_model.py +++ b/api/core/model_runtime/model_providers/__base/ai_model.py @@ -151,9 +151,9 @@ def predefined_models(self) -> list[AIModelEntity]: os.path.join(provider_model_type_path, model_schema_yaml) for model_schema_yaml in os.listdir(provider_model_type_path) if not model_schema_yaml.startswith('__') - and not model_schema_yaml.startswith('_') - and os.path.isfile(os.path.join(provider_model_type_path, model_schema_yaml)) - and model_schema_yaml.endswith('.yaml') + and not model_schema_yaml.startswith('_') + and os.path.isfile(os.path.join(provider_model_type_path, model_schema_yaml)) + and model_schema_yaml.endswith('.yaml') ] # get _position.yaml file path diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py index b1660afafb12e4..e2d17e32575920 100644 --- a/api/core/model_runtime/model_providers/model_provider_factory.py +++ b/api/core/model_runtime/model_providers/model_provider_factory.py @@ -6,7 +6,7 @@ from pydantic import BaseModel, ConfigDict from core.helper.module_import_helper import load_single_subclass_from_source -from core.helper.position_helper import get_position_map, sort_to_dict_by_position_map +from core.helper.position_helper import get_provider_position_map, sort_to_dict_by_position_map from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity from core.model_runtime.model_providers.__base.model_provider import ModelProvider @@ -234,7 +234,7 @@ def _get_model_provider_map(self) -> dict[str, ModelProviderExtension]: ] # get _position.yaml file path - position_map = get_position_map(model_providers_path) + position_map = get_provider_position_map(model_providers_path) # traverse all model_provider_dir_paths model_providers: list[ModelProviderExtension] = [] diff --git a/api/core/model_runtime/model_providers/openai/llm/_position.yaml b/api/core/model_runtime/model_providers/openai/llm/_position.yaml index 21661b9a2b8aef..ac7313aaa1bf0b 100644 --- a/api/core/model_runtime/model_providers/openai/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/_position.yaml @@ -2,6 +2,7 @@ - gpt-4o - gpt-4o-2024-05-13 - gpt-4o-2024-08-06 +- chatgpt-4o-latest - gpt-4o-mini - gpt-4o-mini-2024-07-18 - gpt-4-turbo diff --git a/api/core/model_runtime/model_providers/openai/llm/chatgpt-4o-latest.yaml b/api/core/model_runtime/model_providers/openai/llm/chatgpt-4o-latest.yaml new file mode 100644 index 00000000000000..98e236650c9e73 --- /dev/null +++ b/api/core/model_runtime/model_providers/openai/llm/chatgpt-4o-latest.yaml @@ -0,0 +1,44 @@ +model: chatgpt-4o-latest +label: + zh_Hans: chatgpt-4o-latest + en_US: chatgpt-4o-latest +model_type: llm +features: + - multi-tool-call + - agent-thought + - stream-tool-call + - vision +model_properties: + mode: chat + context_size: 128000 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: presence_penalty + use_template: presence_penalty + - name: frequency_penalty + use_template: frequency_penalty + - name: max_tokens + use_template: max_tokens + default: 512 + min: 1 + max: 16384 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: response_format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object +pricing: + input: '2.50' + output: '10.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-2024-08-06.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-2024-08-06.yaml index cf2de0f73a0b84..7e430c51a710fc 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-2024-08-06.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-2024-08-06.yaml @@ -37,6 +37,9 @@ parameter_rules: options: - text - json_object + - json_schema + - name: json_schema + use_template: json_schema pricing: input: '2.50' output: '10.00' diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-mini.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-mini.yaml index b97fbf8aabcae4..23dcf85085e123 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-mini.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-mini.yaml @@ -37,6 +37,9 @@ parameter_rules: options: - text - json_object + - json_schema + - name: json_schema + use_template: json_schema pricing: input: '0.15' output: '0.60' diff --git a/api/core/model_runtime/model_providers/openai/llm/llm.py b/api/core/model_runtime/model_providers/openai/llm/llm.py index aae2729bdfb042..06135c958463e8 100644 --- a/api/core/model_runtime/model_providers/openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai/llm/llm.py @@ -1,3 +1,4 @@ +import json import logging from collections.abc import Generator from typing import Optional, Union, cast @@ -544,13 +545,18 @@ def _chat_generate(self, model: str, credentials: dict, response_format = model_parameters.get("response_format") if response_format: - if response_format == "json_object": - response_format = {"type": "json_object"} + if response_format == "json_schema": + json_schema = model_parameters.get("json_schema") + if not json_schema: + raise ValueError("Must define JSON Schema when the response format is json_schema") + try: + schema = json.loads(json_schema) + except: + raise ValueError(f"not currect json_schema format: {json_schema}") + model_parameters.pop("json_schema") + model_parameters["response_format"] = {"type": "json_schema", "json_schema": schema} else: - response_format = {"type": "text"} - - model_parameters["response_format"] = response_format - + model_parameters["response_format"] = {"type": response_format} extra_model_kwargs = {} @@ -922,11 +928,14 @@ def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> int: """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. - Official documentation: https://github.com/openai/openai-cookbook/blob/ - main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb""" + Official documentation: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb""" if model.startswith('ft:'): model = model.split(':')[1] + # Currently, we can use gpt4o to calculate chatgpt-4o-latest's token. + if model == "chatgpt-4o-latest": + model = "gpt-4o" + try: encoding = tiktoken.encoding_for_model(model) except KeyError: @@ -946,7 +955,7 @@ def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], raise NotImplementedError( f"get_num_tokens_from_messages() is not presently implemented " f"for model {model}." - "See https://github.com/openai/openai-python/blob/main/chatml.md for " + "See https://platform.openai.com/docs/advanced-usage/managing-tokens for " "information on how messages are converted to tokens." ) num_tokens = 0 diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml b/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml index 69bed9603902a6..88c76fe16ef733 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml +++ b/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml @@ -7,6 +7,7 @@ description: supported_model_types: - llm - text-embedding + - speech2text configurate_methods: - customizable-model model_credential_schema: @@ -61,6 +62,22 @@ model_credential_schema: zh_Hans: 模型上下文长度 en_US: Model context size required: true + show_on: + - variable: __model_type + value: llm + type: text-input + default: '4096' + placeholder: + zh_Hans: 在此输入您的模型上下文长度 + en_US: Enter your Model context size + - variable: context_size + label: + zh_Hans: 模型上下文长度 + en_US: Model context size + required: true + show_on: + - variable: __model_type + value: text-embedding type: text-input default: '4096' placeholder: diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/__init__.py b/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py b/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py new file mode 100644 index 00000000000000..00702ba9367cf4 --- /dev/null +++ b/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py @@ -0,0 +1,63 @@ +from typing import IO, Optional +from urllib.parse import urljoin + +import requests + +from core.model_runtime.errors.invoke import InvokeBadRequestError +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel +from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat + + +class OAICompatSpeech2TextModel(_CommonOAI_API_Compat, Speech2TextModel): + """ + Model class for OpenAI Compatible Speech to text model. + """ + + def _invoke( + self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None + ) -> str: + """ + Invoke speech2text model + + :param model: model name + :param credentials: model credentials + :param file: audio file + :param user: unique user id + :return: text for given audio file + """ + headers = {} + + api_key = credentials.get("api_key") + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + + endpoint_url = credentials.get("endpoint_url") + if not endpoint_url.endswith("/"): + endpoint_url += "/" + endpoint_url = urljoin(endpoint_url, "audio/transcriptions") + + payload = {"model": model} + files = [("file", file)] + response = requests.post(endpoint_url, headers=headers, data=payload, files=files) + + if response.status_code != 200: + raise InvokeBadRequestError(response.text) + response_data = response.json() + return response_data["text"] + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + audio_file_path = self._get_demo_file_path() + + with open(audio_file_path, "rb") as audio_file: + self._invoke(model, credentials, audio_file) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) diff --git a/api/core/model_runtime/model_providers/siliconflow/siliconflow.py b/api/core/model_runtime/model_providers/siliconflow/siliconflow.py index a53f16c929728e..dd0eea362a5f83 100644 --- a/api/core/model_runtime/model_providers/siliconflow/siliconflow.py +++ b/api/core/model_runtime/model_providers/siliconflow/siliconflow.py @@ -6,6 +6,7 @@ logger = logging.getLogger(__name__) + class SiliconflowProvider(ModelProvider): def validate_provider_credentials(self, credentials: dict) -> None: diff --git a/api/core/model_runtime/model_providers/siliconflow/siliconflow.yaml b/api/core/model_runtime/model_providers/siliconflow/siliconflow.yaml index 3084d3edcd644f..1ebb1e6d8b149c 100644 --- a/api/core/model_runtime/model_providers/siliconflow/siliconflow.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/siliconflow.yaml @@ -16,6 +16,7 @@ help: supported_model_types: - llm - text-embedding + - speech2text configurate_methods: - predefined-model provider_credential_schema: diff --git a/api/core/model_runtime/model_providers/siliconflow/speech2text/__init__.py b/api/core/model_runtime/model_providers/siliconflow/speech2text/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/siliconflow/speech2text/sense-voice-small.yaml b/api/core/model_runtime/model_providers/siliconflow/speech2text/sense-voice-small.yaml new file mode 100644 index 00000000000000..deceaf60f4f017 --- /dev/null +++ b/api/core/model_runtime/model_providers/siliconflow/speech2text/sense-voice-small.yaml @@ -0,0 +1,5 @@ +model: iic/SenseVoiceSmall +model_type: speech2text +model_properties: + file_upload_limit: 1 + supported_file_extensions: mp3,wav diff --git a/api/core/model_runtime/model_providers/siliconflow/speech2text/speech2text.py b/api/core/model_runtime/model_providers/siliconflow/speech2text/speech2text.py new file mode 100644 index 00000000000000..6ad3cab5873c69 --- /dev/null +++ b/api/core/model_runtime/model_providers/siliconflow/speech2text/speech2text.py @@ -0,0 +1,32 @@ +from typing import IO, Optional + +from core.model_runtime.model_providers.openai_api_compatible.speech2text.speech2text import OAICompatSpeech2TextModel + + +class SiliconflowSpeech2TextModel(OAICompatSpeech2TextModel): + """ + Model class for Siliconflow Speech to text model. + """ + + def _invoke( + self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None + ) -> str: + """ + Invoke speech2text model + + :param model: model name + :param credentials: model credentials + :param file: audio file + :param user: unique user id + :return: text for given audio file + """ + self._add_custom_parameters(credentials) + return super()._invoke(model, credentials, file) + + def validate_credentials(self, model: str, credentials: dict) -> None: + self._add_custom_parameters(credentials) + return super().validate_credentials(model, credentials) + + @classmethod + def _add_custom_parameters(cls, credentials: dict) -> None: + credentials["endpoint_url"] = "https://api.siliconflow.cn/v1" diff --git a/api/core/model_runtime/model_providers/siliconflow/text_embedding/bce-embedding-base-v1.yaml b/api/core/model_runtime/model_providers/siliconflow/text_embedding/bce-embedding-base-v1.yaml new file mode 100644 index 00000000000000..710fbc04f6ad12 --- /dev/null +++ b/api/core/model_runtime/model_providers/siliconflow/text_embedding/bce-embedding-base-v1.yaml @@ -0,0 +1,5 @@ +model: netease-youdao/bce-embedding-base_v1 +model_type: text-embedding +model_properties: + context_size: 512 + max_chunks: 1 diff --git a/api/core/model_runtime/model_providers/siliconflow/text_embedding/bge-m3.yaml b/api/core/model_runtime/model_providers/siliconflow/text_embedding/bge-m3.yaml new file mode 100644 index 00000000000000..f0b12dd420ab2b --- /dev/null +++ b/api/core/model_runtime/model_providers/siliconflow/text_embedding/bge-m3.yaml @@ -0,0 +1,5 @@ +model: BAAI/bge-m3 +model_type: text-embedding +model_properties: + context_size: 8192 + max_chunks: 1 diff --git a/api/core/model_runtime/model_providers/tongyi/llm/farui-plus.yaml b/api/core/model_runtime/model_providers/tongyi/llm/farui-plus.yaml new file mode 100644 index 00000000000000..aad07f56736e52 --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/llm/farui-plus.yaml @@ -0,0 +1,81 @@ +model: farui-plus +label: + en_US: farui-plus +model_type: llm +features: + - multi-tool-call + - agent-thought + - stream-tool-call +model_properties: + mode: chat + context_size: 12288 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain. + - name: max_tokens + use_template: max_tokens + type: int + default: 2000 + min: 1 + max: 2000 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: seed + required: false + type: int + default: 1234 + label: + zh_Hans: 随机种子 + en_US: Random seed + help: + zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。 + en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. + - name: enable_search + type: boolean + default: false + help: + zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。 + en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic. + - name: response_format + use_template: response_format +pricing: + input: '0.02' + output: '0.02' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/text_embedding/text-embedding-v1.yaml b/api/core/model_runtime/model_providers/tongyi/text_embedding/text-embedding-v1.yaml index eed09f95dedea7..f4303c53d38b80 100644 --- a/api/core/model_runtime/model_providers/tongyi/text_embedding/text-embedding-v1.yaml +++ b/api/core/model_runtime/model_providers/tongyi/text_embedding/text-embedding-v1.yaml @@ -2,3 +2,8 @@ model: text-embedding-v1 model_type: text-embedding model_properties: context_size: 2048 + max_chunks: 25 +pricing: + input: "0.0007" + unit: "0.001" + currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/text_embedding/text-embedding-v2.yaml b/api/core/model_runtime/model_providers/tongyi/text_embedding/text-embedding-v2.yaml index db2fa861e69f90..f6be3544ed8f65 100644 --- a/api/core/model_runtime/model_providers/tongyi/text_embedding/text-embedding-v2.yaml +++ b/api/core/model_runtime/model_providers/tongyi/text_embedding/text-embedding-v2.yaml @@ -2,3 +2,8 @@ model: text-embedding-v2 model_type: text-embedding model_properties: context_size: 2048 + max_chunks: 25 +pricing: + input: "0.0007" + unit: "0.001" + currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py index c207ffc1e34bbb..e7e1b5c764c093 100644 --- a/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py @@ -2,6 +2,7 @@ from typing import Optional import dashscope +import numpy as np from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import ( @@ -21,11 +22,11 @@ class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel): """ def _invoke( - self, - model: str, - credentials: dict, - texts: list[str], - user: Optional[str] = None, + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -37,16 +38,44 @@ def _invoke( :return: embeddings result """ credentials_kwargs = self._to_credential_kwargs(credentials) - embeddings, embedding_used_tokens = self.embed_documents( - credentials_kwargs=credentials_kwargs, - model=model, - texts=texts - ) + context_size = self._get_context_size(model, credentials) + max_chunks = self._get_max_chunks(model, credentials) + inputs = [] + indices = [] + used_tokens = 0 + + for i, text in enumerate(texts): + + # Here token count is only an approximation based on the GPT2 tokenizer + num_tokens = self._get_num_tokens_by_gpt2(text) + + if num_tokens >= context_size: + cutoff = int(np.floor(len(text) * (context_size / num_tokens))) + # if num tokens is larger than context length, only use the start + inputs.append(text[0:cutoff]) + else: + inputs.append(text) + indices += [i] + + batched_embeddings = [] + _iter = range(0, len(inputs), max_chunks) + + for i in _iter: + embeddings_batch, embedding_used_tokens = self.embed_documents( + credentials_kwargs=credentials_kwargs, + model=model, + texts=inputs[i : i + max_chunks], + ) + used_tokens += embedding_used_tokens + batched_embeddings += embeddings_batch + + # calc usage + usage = self._calc_response_usage( + model=model, credentials=credentials, tokens=used_tokens + ) return TextEmbeddingResult( - embeddings=embeddings, - usage=self._calc_response_usage(model, credentials_kwargs, embedding_used_tokens), - model=model + embeddings=batched_embeddings, usage=usage, model=model ) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: @@ -79,12 +108,16 @@ def validate_credentials(self, model: str, credentials: dict) -> None: credentials_kwargs = self._to_credential_kwargs(credentials) # call embedding model - self.embed_documents(credentials_kwargs=credentials_kwargs, model=model, texts=["ping"]) + self.embed_documents( + credentials_kwargs=credentials_kwargs, model=model, texts=["ping"] + ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @staticmethod - def embed_documents(credentials_kwargs: dict, model: str, texts: list[str]) -> tuple[list[list[float]], int]: + def embed_documents( + credentials_kwargs: dict, model: str, texts: list[str] + ) -> tuple[list[list[float]], int]: """Call out to Tongyi's embedding endpoint. Args: @@ -102,7 +135,7 @@ def embed_documents(credentials_kwargs: dict, model: str, texts: list[str]) -> t api_key=credentials_kwargs["dashscope_api_key"], model=model, input=text, - text_type="document" + text_type="document", ) data = response.output["embeddings"][0] embeddings.append(data["embedding"]) @@ -111,7 +144,7 @@ def embed_documents(credentials_kwargs: dict, model: str, texts: list[str]) -> t return [list(map(float, e)) for e in embeddings], embedding_used_tokens def _calc_response_usage( - self, model: str, credentials: dict, tokens: int + self, model: str, credentials: dict, tokens: int ) -> EmbeddingUsage: """ Calculate response usage @@ -125,7 +158,7 @@ def _calc_response_usage( model=model, credentials=credentials, price_type=PriceType.INPUT, - tokens=tokens + tokens=tokens, ) # transform usage @@ -136,7 +169,7 @@ def _calc_response_usage( price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/upstage/llm/_position.yaml b/api/core/model_runtime/model_providers/upstage/llm/_position.yaml index d4f03e1988f8b8..7992843dcb1d1d 100644 --- a/api/core/model_runtime/model_providers/upstage/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/upstage/llm/_position.yaml @@ -1 +1 @@ -- soloar-1-mini-chat +- solar-1-mini-chat diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index 452b270348b2ff..fd7ed0181be2f2 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -1,11 +1,10 @@ import enum import json import os -from typing import Optional +from typing import TYPE_CHECKING, Optional from core.app.app_config.entities import PromptTemplateEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.file.file_obj import FileVar from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.message_entities import ( PromptMessage, @@ -18,6 +17,9 @@ from core.prompt.utils.prompt_template_parser import PromptTemplateParser from models.model import AppMode +if TYPE_CHECKING: + from core.file.file_obj import FileVar + class ModelMode(enum.Enum): COMPLETION = 'completion' @@ -50,7 +52,7 @@ def get_prompt(self, prompt_template_entity: PromptTemplateEntity, inputs: dict, query: str, - files: list[FileVar], + files: list["FileVar"], context: Optional[str], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity) -> \ @@ -163,7 +165,7 @@ def _get_chat_model_prompt_messages(self, app_mode: AppMode, inputs: dict, query: str, context: Optional[str], - files: list[FileVar], + files: list["FileVar"], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity) \ -> tuple[list[PromptMessage], Optional[list[str]]]: @@ -206,7 +208,7 @@ def _get_completion_model_prompt_messages(self, app_mode: AppMode, inputs: dict, query: str, context: Optional[str], - files: list[FileVar], + files: list["FileVar"], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity) \ -> tuple[list[PromptMessage], Optional[list[str]]]: @@ -255,7 +257,7 @@ def _get_completion_model_prompt_messages(self, app_mode: AppMode, return [self.get_last_user_message(prompt, files)], stops - def get_last_user_message(self, prompt: str, files: list[FileVar]) -> UserPromptMessage: + def get_last_user_message(self, prompt: str, files: list["FileVar"]) -> UserPromptMessage: if files: prompt_message_contents = [TextPromptMessageContent(data=prompt)] for file in files: diff --git a/api/core/rag/datasource/vdb/elasticsearch/__init__.py b/api/core/rag/datasource/vdb/elasticsearch/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py new file mode 100644 index 00000000000000..01ba6fb3248786 --- /dev/null +++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py @@ -0,0 +1,191 @@ +import json +from typing import Any + +import requests +from elasticsearch import Elasticsearch +from flask import current_app +from pydantic import BaseModel, model_validator + +from core.rag.datasource.entity.embedding import Embeddings +from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory +from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.models.document import Document +from models.dataset import Dataset + + +class ElasticSearchConfig(BaseModel): + host: str + port: str + username: str + password: str + + @model_validator(mode='before') + def validate_config(cls, values: dict) -> dict: + if not values['host']: + raise ValueError("config HOST is required") + if not values['port']: + raise ValueError("config PORT is required") + if not values['username']: + raise ValueError("config USERNAME is required") + if not values['password']: + raise ValueError("config PASSWORD is required") + return values + + +class ElasticSearchVector(BaseVector): + def __init__(self, index_name: str, config: ElasticSearchConfig, attributes: list): + super().__init__(index_name.lower()) + self._client = self._init_client(config) + self._attributes = attributes + + def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch: + try: + client = Elasticsearch( + hosts=f'{config.host}:{config.port}', + basic_auth=(config.username, config.password), + request_timeout=100000, + retry_on_timeout=True, + max_retries=10000, + ) + except requests.exceptions.ConnectionError: + raise ConnectionError("Vector database connection error") + + return client + + def get_type(self) -> str: + return 'elasticsearch' + + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + uuids = self._get_uuids(documents) + texts = [d.page_content for d in documents] + metadatas = [d.metadata for d in documents] + + if not self._client.indices.exists(index=self._collection_name): + dim = len(embeddings[0]) + mapping = { + "properties": { + "text": { + "type": "text" + }, + "vector": { + "type": "dense_vector", + "index": True, + "dims": dim, + "similarity": "l2_norm" + }, + } + } + self._client.indices.create(index=self._collection_name, mappings=mapping) + + added_ids = [] + for i, text in enumerate(texts): + self._client.index(index=self._collection_name, + id=uuids[i], + document={ + "text": text, + "vector": embeddings[i] if embeddings[i] else None, + "metadata": metadatas[i] if metadatas[i] else {}, + }) + added_ids.append(uuids[i]) + + self._client.indices.refresh(index=self._collection_name) + return uuids + + def text_exists(self, id: str) -> bool: + return self._client.exists(index=self._collection_name, id=id).__bool__() + + def delete_by_ids(self, ids: list[str]) -> None: + for id in ids: + self._client.delete(index=self._collection_name, id=id) + + def delete_by_metadata_field(self, key: str, value: str) -> None: + query_str = { + 'query': { + 'match': { + f'metadata.{key}': f'{value}' + } + } + } + results = self._client.search(index=self._collection_name, body=query_str) + ids = [hit['_id'] for hit in results['hits']['hits']] + if ids: + self.delete_by_ids(ids) + + def delete(self) -> None: + self._client.indices.delete(index=self._collection_name) + + def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: + query_str = { + "query": { + "script_score": { + "query": { + "match_all": {} + }, + "script": { + "source": "cosineSimilarity(params.query_vector, 'vector') + 1.0", + "params": { + "query_vector": query_vector + } + } + } + } + } + + results = self._client.search(index=self._collection_name, body=query_str) + + docs_and_scores = [] + for hit in results['hits']['hits']: + docs_and_scores.append( + (Document(page_content=hit['_source']['text'], metadata=hit['_source']['metadata']), hit['_score'])) + + docs = [] + for doc, score in docs_and_scores: + score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0 + if score > score_threshold: + doc.metadata['score'] = score + docs.append(doc) + + # Sort the documents by score in descending order + docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True) + + return docs + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: + query_str = { + "match": { + "text": query + } + } + results = self._client.search(index=self._collection_name, query=query_str) + docs = [] + for hit in results['hits']['hits']: + docs.append(Document(page_content=hit['_source']['text'], metadata=hit['_source']['metadata'])) + + return docs + + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): + return self.add_texts(texts, embeddings, **kwargs) + + +class ElasticSearchVectorFactory(AbstractVectorFactory): + def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> ElasticSearchVector: + if dataset.index_struct_dict: + class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + collection_name = class_prefix + else: + dataset_id = dataset.id + collection_name = Dataset.gen_collection_name_by_id(dataset_id) + dataset.index_struct = json.dumps( + self.gen_index_struct_dict(VectorType.ELASTICSEARCH, collection_name)) + + config = current_app.config + return ElasticSearchVector( + index_name=collection_name, + config=ElasticSearchConfig( + host=config.get('ELASTICSEARCH_HOST'), + port=config.get('ELASTICSEARCH_PORT'), + username=config.get('ELASTICSEARCH_USERNAME'), + password=config.get('ELASTICSEARCH_PASSWORD'), + ), + attributes=[] + ) diff --git a/api/core/rag/datasource/vdb/myscale/myscale_vector.py b/api/core/rag/datasource/vdb/myscale/myscale_vector.py index cff9293baa2efe..4ae1a3395b0749 100644 --- a/api/core/rag/datasource/vdb/myscale/myscale_vector.py +++ b/api/core/rag/datasource/vdb/myscale/myscale_vector.py @@ -93,7 +93,7 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** @staticmethod def escape_str(value: Any) -> str: - return "".join(f"\\{c}" if c in ("\\", "'") else c for c in str(value)) + return "".join(" " if c in ("\\", "'") else c for c in str(value)) def text_exists(self, id: str) -> bool: results = self._client.query(f"SELECT id FROM {self._config.database}.{self._collection_name} WHERE id='{id}'") @@ -118,7 +118,7 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc return self._search(f"distance(vector, {str(query_vector)})", self._vec_order, **kwargs) def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: - return self._search(f"TextSearch(text, '{query}')", SortOrder.DESC, **kwargs) + return self._search(f"TextSearch('enable_nlq=false')(text, '{query}')", SortOrder.DESC, **kwargs) def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]: top_k = kwargs.get("top_k", 5) diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index fad60ecf45c151..3e9ca8e1fe7f4a 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -71,6 +71,9 @@ def get_vector_factory(vector_type: str) -> type[AbstractVectorFactory]: case VectorType.RELYT: from core.rag.datasource.vdb.relyt.relyt_vector import RelytVectorFactory return RelytVectorFactory + case VectorType.ELASTICSEARCH: + from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory + return ElasticSearchVectorFactory case VectorType.TIDB_VECTOR: from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory return TiDBVectorFactory diff --git a/api/core/rag/datasource/vdb/vector_type.py b/api/core/rag/datasource/vdb/vector_type.py index 77495044df562c..317ca6abc8c89d 100644 --- a/api/core/rag/datasource/vdb/vector_type.py +++ b/api/core/rag/datasource/vdb/vector_type.py @@ -15,3 +15,4 @@ class VectorType(str, Enum): OPENSEARCH = 'opensearch' TENCENT = 'tencent' ORACLE = 'oracle' + ELASTICSEARCH = 'elasticsearch' diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index 569a1d3238f87f..2e4433d9f6d2b9 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -46,7 +46,7 @@ def value_of(cls, value: str) -> 'ToolProviderType': if mode.value == value: return mode raise ValueError(f'invalid mode value {value}') - + class ApiProviderSchemaType(Enum): """ Enum class for api provider schema type. @@ -68,7 +68,7 @@ def value_of(cls, value: str) -> 'ApiProviderSchemaType': if mode.value == value: return mode raise ValueError(f'invalid mode value {value}') - + class ApiProviderAuthType(Enum): """ Enum class for api provider auth type. @@ -103,8 +103,8 @@ class MessageType(Enum): """ plain text, image url or link url """ - message: Union[str, bytes, dict] = None - meta: dict[str, Any] = None + message: str | bytes | dict | None = None + meta: dict[str, Any] | None = None save_as: str = '' class ToolInvokeMessageBinary(BaseModel): @@ -154,8 +154,8 @@ class ToolParameterForm(Enum): options: Optional[list[ToolParameterOption]] = None @classmethod - def get_simple_instance(cls, - name: str, llm_description: str, type: ToolParameterType, + def get_simple_instance(cls, + name: str, llm_description: str, type: ToolParameterType, required: bool, options: Optional[list[str]] = None) -> 'ToolParameter': """ get a simple tool parameter @@ -222,7 +222,7 @@ def value_of(cls, value: str) -> "ToolProviderCredentials.CredentialsType": if mode.value == value: return mode raise ValueError(f'invalid mode value {value}') - + @staticmethod def default(value: str) -> str: return "" @@ -290,7 +290,7 @@ def dict(self) -> dict: 'tenant_id': self.tenant_id, 'pool': [variable.model_dump() for variable in self.pool], } - + def set_text(self, tool_name: str, name: str, value: str) -> None: """ set a text variable @@ -301,7 +301,7 @@ def set_text(self, tool_name: str, name: str, value: str) -> None: variable = cast(ToolRuntimeTextVariable, variable) variable.value = value return - + variable = ToolRuntimeTextVariable( type=ToolRuntimeVariableType.TEXT, name=name, @@ -334,7 +334,7 @@ def set_file(self, tool_name: str, value: str, name: str = None) -> None: variable = cast(ToolRuntimeImageVariable, variable) variable.value = value return - + variable = ToolRuntimeImageVariable( type=ToolRuntimeVariableType.IMAGE, name=name, @@ -388,21 +388,21 @@ def empty(cls) -> 'ToolInvokeMeta': Get an empty instance of ToolInvokeMeta """ return cls(time_cost=0.0, error=None, tool_config={}) - + @classmethod def error_instance(cls, error: str) -> 'ToolInvokeMeta': """ Get an instance of ToolInvokeMeta with error """ return cls(time_cost=0.0, error=error, tool_config={}) - + def to_dict(self) -> dict: return { 'time_cost': self.time_cost, 'error': self.error, 'tool_config': self.tool_config, } - + class ToolLabel(BaseModel): """ Tool label @@ -416,4 +416,4 @@ class ToolInvokeFrom(Enum): Enum class for tool invoke """ WORKFLOW = "workflow" - AGENT = "agent" \ No newline at end of file + AGENT = "agent" diff --git a/api/core/tools/provider/builtin/_positions.py b/api/core/tools/provider/builtin/_positions.py index ae806eaff4a032..062668fc5bf8bf 100644 --- a/api/core/tools/provider/builtin/_positions.py +++ b/api/core/tools/provider/builtin/_positions.py @@ -1,6 +1,6 @@ import os.path -from core.helper.position_helper import get_position_map, sort_by_position_map +from core.helper.position_helper import get_tool_position_map, sort_by_position_map from core.tools.entities.api_entities import UserToolProvider @@ -10,11 +10,11 @@ class BuiltinToolProviderSort: @classmethod def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]: if not cls._position: - cls._position = get_position_map(os.path.join(os.path.dirname(__file__), '..')) + cls._position = get_tool_position_map(os.path.join(os.path.dirname(__file__), '..')) def name_func(provider: UserToolProvider) -> str: return provider.name sorted_providers = sort_by_position_map(cls._position, providers, name_func) - return sorted_providers \ No newline at end of file + return sorted_providers diff --git a/api/core/tools/provider/builtin/gitlab/_assets/gitlab.svg b/api/core/tools/provider/builtin/gitlab/_assets/gitlab.svg new file mode 100644 index 00000000000000..07734077d5d300 --- /dev/null +++ b/api/core/tools/provider/builtin/gitlab/_assets/gitlab.svg @@ -0,0 +1,2 @@ + + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/gitlab/gitlab.py b/api/core/tools/provider/builtin/gitlab/gitlab.py new file mode 100644 index 00000000000000..fca34ae15f9070 --- /dev/null +++ b/api/core/tools/provider/builtin/gitlab/gitlab.py @@ -0,0 +1,34 @@ +from typing import Any + +import requests + +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class GitlabProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict[str, Any]) -> None: + try: + if 'access_tokens' not in credentials or not credentials.get('access_tokens'): + raise ToolProviderCredentialValidationError("Gitlab Access Tokens is required.") + + if 'site_url' not in credentials or not credentials.get('site_url'): + site_url = 'https://gitlab.com' + else: + site_url = credentials.get('site_url') + + try: + headers = { + "Content-Type": "application/vnd.text+json", + "Authorization": f"Bearer {credentials.get('access_tokens')}", + } + + response = requests.get( + url= f"{site_url}/api/v4/user", + headers=headers) + if response.status_code != 200: + raise ToolProviderCredentialValidationError((response.json()).get('message')) + except Exception as e: + raise ToolProviderCredentialValidationError("Gitlab Access Tokens and Api Version is invalid. {}".format(e)) + except Exception as e: + raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file diff --git a/api/core/tools/provider/builtin/gitlab/gitlab.yaml b/api/core/tools/provider/builtin/gitlab/gitlab.yaml new file mode 100644 index 00000000000000..b5feea23823449 --- /dev/null +++ b/api/core/tools/provider/builtin/gitlab/gitlab.yaml @@ -0,0 +1,38 @@ +identity: + author: Leo.Wang + name: gitlab + label: + en_US: Gitlab + zh_Hans: Gitlab + description: + en_US: Gitlab plugin for commit + zh_Hans: 用于获取Gitlab commit的插件 + icon: gitlab.svg +credentials_for_provider: + access_tokens: + type: secret-input + required: true + label: + en_US: Gitlab access token + zh_Hans: Gitlab access token + placeholder: + en_US: Please input your Gitlab access token + zh_Hans: 请输入你的 Gitlab access token + help: + en_US: Get your Gitlab access token from Gitlab + zh_Hans: 从 Gitlab 获取您的 access token + url: https://docs.gitlab.com/16.9/ee/api/oauth2.html + site_url: + type: text-input + required: false + default: 'https://gitlab.com' + label: + en_US: Gitlab site url + zh_Hans: Gitlab site url + placeholder: + en_US: Please input your Gitlab site url + zh_Hans: 请输入你的 Gitlab site url + help: + en_US: Find your Gitlab url + zh_Hans: 找到你的Gitlab url + url: https://gitlab.com/help diff --git a/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.py b/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.py new file mode 100644 index 00000000000000..212bdb03abaaad --- /dev/null +++ b/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.py @@ -0,0 +1,101 @@ +import json +from datetime import datetime, timedelta +from typing import Any, Union + +import requests + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class GitlabCommitsTool(BuiltinTool): + def _invoke(self, + user_id: str, + tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + + project = tool_parameters.get('project', '') + employee = tool_parameters.get('employee', '') + start_time = tool_parameters.get('start_time', '') + end_time = tool_parameters.get('end_time', '') + + if not project: + return self.create_text_message('Project is required') + + if not start_time: + start_time = (datetime.utcnow() - timedelta(days=1)).isoformat() + if not end_time: + end_time = datetime.utcnow().isoformat() + + access_token = self.runtime.credentials.get('access_tokens') + site_url = self.runtime.credentials.get('site_url') + + if 'access_tokens' not in self.runtime.credentials or not self.runtime.credentials.get('access_tokens'): + return self.create_text_message("Gitlab API Access Tokens is required.") + if 'site_url' not in self.runtime.credentials or not self.runtime.credentials.get('site_url'): + site_url = 'https://gitlab.com' + + # Get commit content + result = self.fetch(user_id, site_url, access_token, project, employee, start_time, end_time) + + return self.create_text_message(json.dumps(result, ensure_ascii=False)) + + def fetch(self,user_id: str, site_url: str, access_token: str, project: str, employee: str = None, start_time: str = '', end_time: str = '') -> list[dict[str, Any]]: + domain = site_url + headers = {"PRIVATE-TOKEN": access_token} + results = [] + + try: + # Get all of projects + url = f"{domain}/api/v4/projects" + response = requests.get(url, headers=headers) + response.raise_for_status() + projects = response.json() + + filtered_projects = [p for p in projects if project == "*" or p['name'] == project] + + for project in filtered_projects: + project_id = project['id'] + project_name = project['name'] + print(f"Project: {project_name}") + + # Get all of proejct commits + commits_url = f"{domain}/api/v4/projects/{project_id}/repository/commits" + params = { + 'since': start_time, + 'until': end_time + } + if employee: + params['author'] = employee + + commits_response = requests.get(commits_url, headers=headers, params=params) + commits_response.raise_for_status() + commits = commits_response.json() + + for commit in commits: + commit_sha = commit['id'] + print(f"\tCommit SHA: {commit_sha}") + + diff_url = f"{domain}/api/v4/projects/{project_id}/repository/commits/{commit_sha}/diff" + diff_response = requests.get(diff_url, headers=headers) + diff_response.raise_for_status() + diffs = diff_response.json() + + for diff in diffs: + # Caculate code lines of changed + added_lines = diff['diff'].count('\n+') + removed_lines = diff['diff'].count('\n-') + total_changes = added_lines + removed_lines + + if total_changes > 1: + final_code = ''.join([line[1:] for line in diff['diff'].split('\n') if line.startswith('+') and not line.startswith('+++')]) + results.append({ + "project": project_name, + "commit_sha": commit_sha, + "diff": final_code + }) + print(f"Commit code:{final_code}") + except requests.RequestException as e: + print(f"Error fetching data from GitLab: {e}") + + return results \ No newline at end of file diff --git a/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.yaml b/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.yaml new file mode 100644 index 00000000000000..fc4e7eb7bb3ed4 --- /dev/null +++ b/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.yaml @@ -0,0 +1,56 @@ +identity: + name: gitlab_commits + author: Leo.Wang + label: + en_US: Gitlab Commits + zh_Hans: Gitlab代码提交内容 +description: + human: + en_US: A tool for query gitlab commits. Input should be a exists username. + zh_Hans: 一个用于查询gitlab代码提交记录的的工具,输入的内容应该是一个已存在的用户名或者项目名。 + llm: A tool for query gitlab commits. Input should be a exists username or project. +parameters: + - name: employee + type: string + required: false + label: + en_US: employee + zh_Hans: 员工用户名 + human_description: + en_US: employee + zh_Hans: 员工用户名 + llm_description: employee for gitlab + form: llm + - name: project + type: string + required: true + label: + en_US: project + zh_Hans: 项目名 + human_description: + en_US: project + zh_Hans: 项目名 + llm_description: project for gitlab + form: llm + - name: start_time + type: string + required: false + label: + en_US: start_time + zh_Hans: 开始时间 + human_description: + en_US: start_time + zh_Hans: 开始时间 + llm_description: start_time for gitlab + form: llm + - name: end_time + type: string + required: false + label: + en_US: end_time + zh_Hans: 结束时间 + human_description: + en_US: end_time + zh_Hans: 结束时间 + llm_description: end_time for gitlab + form: llm diff --git a/api/core/tools/provider/builtin/searxng/docker/settings.yml b/api/core/tools/provider/builtin/searxng/docker/settings.yml new file mode 100644 index 00000000000000..18e18688002cbc --- /dev/null +++ b/api/core/tools/provider/builtin/searxng/docker/settings.yml @@ -0,0 +1,2501 @@ +general: + # Debug mode, only for development. Is overwritten by ${SEARXNG_DEBUG} + debug: false + # displayed name + instance_name: "searxng" + # For example: https://example.com/privacy + privacypolicy_url: false + # use true to use your own donation page written in searx/info/en/donate.md + # use false to disable the donation link + donation_url: false + # mailto:contact@example.com + contact_url: false + # record stats + enable_metrics: true + +brand: + new_issue_url: https://github.com/searxng/searxng/issues/new + docs_url: https://docs.searxng.org/ + public_instances: https://searx.space + wiki_url: https://github.com/searxng/searxng/wiki + issue_url: https://github.com/searxng/searxng/issues + # custom: + # maintainer: "Jon Doe" + # # Custom entries in the footer: [title]: [link] + # links: + # Uptime: https://uptime.searxng.org/history/darmarit-org + # About: "https://searxng.org" + +search: + # Filter results. 0: None, 1: Moderate, 2: Strict + safe_search: 0 + # Existing autocomplete backends: "dbpedia", "duckduckgo", "google", "yandex", "mwmbl", + # "seznam", "startpage", "stract", "swisscows", "qwant", "wikipedia" - leave blank to turn it off + # by default. + autocomplete: "" + # minimun characters to type before autocompleter starts + autocomplete_min: 4 + # Default search language - leave blank to detect from browser information or + # use codes from 'languages.py' + default_lang: "auto" + # max_page: 0 # if engine supports paging, 0 means unlimited numbers of pages + # Available languages + # languages: + # - all + # - en + # - en-US + # - de + # - it-IT + # - fr + # - fr-BE + # ban time in seconds after engine errors + ban_time_on_fail: 5 + # max ban time in seconds after engine errors + max_ban_time_on_fail: 120 + suspended_times: + # Engine suspension time after error (in seconds; set to 0 to disable) + # For error "Access denied" and "HTTP error [402, 403]" + SearxEngineAccessDenied: 86400 + # For error "CAPTCHA" + SearxEngineCaptcha: 86400 + # For error "Too many request" and "HTTP error 429" + SearxEngineTooManyRequests: 3600 + # Cloudflare CAPTCHA + cf_SearxEngineCaptcha: 1296000 + cf_SearxEngineAccessDenied: 86400 + # ReCAPTCHA + recaptcha_SearxEngineCaptcha: 604800 + + # remove format to deny access, use lower case. + # formats: [html, csv, json, rss] + formats: + - html + - json + +server: + # Is overwritten by ${SEARXNG_PORT} and ${SEARXNG_BIND_ADDRESS} + port: 8888 + bind_address: "127.0.0.1" + # public URL of the instance, to ensure correct inbound links. Is overwritten + # by ${SEARXNG_URL}. + base_url: http://0.0.0.0:8081/ # "http://example.com/location" + # rate limit the number of request on the instance, block some bots. + # Is overwritten by ${SEARXNG_LIMITER} + limiter: false + # enable features designed only for public instances. + # Is overwritten by ${SEARXNG_PUBLIC_INSTANCE} + public_instance: false + + # If your instance owns a /etc/searxng/settings.yml file, then set the following + # values there. + + secret_key: "772ba36386fb56d0f8fe818941552dabbe69220d4c0eb4a385a5729cdbc20c2d" # Is overwritten by ${SEARXNG_SECRET} + # Proxy image results through SearXNG. Is overwritten by ${SEARXNG_IMAGE_PROXY} + image_proxy: false + # 1.0 and 1.1 are supported + http_protocol_version: "1.0" + # POST queries are more secure as they don't show up in history but may cause + # problems when using Firefox containers + method: "POST" + default_http_headers: + X-Content-Type-Options: nosniff + X-Download-Options: noopen + X-Robots-Tag: noindex, nofollow + Referrer-Policy: no-referrer + +redis: + # URL to connect redis database. Is overwritten by ${SEARXNG_REDIS_URL}. + # https://docs.searxng.org/admin/settings/settings_redis.html#settings-redis + url: false + +ui: + # Custom static path - leave it blank if you didn't change + static_path: "" + # Is overwritten by ${SEARXNG_STATIC_USE_HASH}. + static_use_hash: false + # Custom templates path - leave it blank if you didn't change + templates_path: "" + # query_in_title: When true, the result page's titles contains the query + # it decreases the privacy, since the browser can records the page titles. + query_in_title: false + # infinite_scroll: When true, automatically loads the next page when scrolling to bottom of the current page. + infinite_scroll: false + # ui theme + default_theme: simple + # center the results ? + center_alignment: false + # URL prefix of the internet archive, don't forget trailing slash (if needed). + # cache_url: "https://webcache.googleusercontent.com/search?q=cache:" + # Default interface locale - leave blank to detect from browser information or + # use codes from the 'locales' config section + default_locale: "" + # Open result links in a new tab by default + # results_on_new_tab: false + theme_args: + # style of simple theme: auto, light, dark + simple_style: auto + # Perform search immediately if a category selected. + # Disable to select multiple categories at once and start the search manually. + search_on_category_select: true + # Hotkeys: default or vim + hotkeys: default + +# Lock arbitrary settings on the preferences page. To find the ID of the user +# setting you want to lock, check the ID of the form on the page "preferences". +# +# preferences: +# lock: +# - language +# - autocomplete +# - method +# - query_in_title + +# searx supports result proxification using an external service: +# https://github.com/asciimoo/morty uncomment below section if you have running +# morty proxy the key is base64 encoded (keep the !!binary notation) +# Note: since commit af77ec3, morty accepts a base64 encoded key. +# +# result_proxy: +# url: http://127.0.0.1:3000/ +# # the key is a base64 encoded string, the YAML !!binary prefix is optional +# key: !!binary "your_morty_proxy_key" +# # [true|false] enable the "proxy" button next to each result +# proxify_results: true + +# communication with search engines +# +outgoing: + # default timeout in seconds, can be override by engine + request_timeout: 3.0 + # the maximum timeout in seconds + # max_request_timeout: 10.0 + # suffix of searx_useragent, could contain information like an email address + # to the administrator + useragent_suffix: "" + # The maximum number of concurrent connections that may be established. + pool_connections: 100 + # Allow the connection pool to maintain keep-alive connections below this + # point. + pool_maxsize: 20 + # See https://www.python-httpx.org/http2/ + enable_http2: true + # uncomment below section if you want to use a custom server certificate + # see https://www.python-httpx.org/advanced/#changing-the-verification-defaults + # and https://www.python-httpx.org/compatibility/#ssl-configuration + # verify: ~/.mitmproxy/mitmproxy-ca-cert.cer + # + # uncomment below section if you want to use a proxyq see: SOCKS proxies + # https://2.python-requests.org/en/latest/user/advanced/#proxies + # are also supported: see + # https://2.python-requests.org/en/latest/user/advanced/#socks + # + # proxies: + # all://: + # - http://host.docker.internal:1080 + # + # using_tor_proxy: true + # + # Extra seconds to add in order to account for the time taken by the proxy + # + # extra_proxy_timeout: 10 + # + # uncomment below section only if you have more than one network interface + # which can be the source of outgoing search requests + # + # source_ips: + # - 1.1.1.1 + # - 1.1.1.2 + # - fe80::/126 + +# External plugin configuration, for more details see +# https://docs.searxng.org/dev/plugins.html +# +# plugins: +# - plugin1 +# - plugin2 +# - ... + +# Comment or un-comment plugin to activate / deactivate by default. +# +# enabled_plugins: +# # these plugins are enabled if nothing is configured .. +# - 'Hash plugin' +# - 'Self Information' +# - 'Tracker URL remover' +# - 'Ahmia blacklist' # activation depends on outgoing.using_tor_proxy +# # these plugins are disabled if nothing is configured .. +# - 'Hostnames plugin' # see 'hostnames' configuration below +# - 'Basic Calculator' +# - 'Open Access DOI rewrite' +# - 'Tor check plugin' +# # Read the docs before activate: auto-detection of the language could be +# # detrimental to users expectations / users can activate the plugin in the +# # preferences if they want. +# - 'Autodetect search language' + +# Configuration of the "Hostnames plugin": +# +# hostnames: +# replace: +# '(.*\.)?youtube\.com$': 'invidious.example.com' +# '(.*\.)?youtu\.be$': 'invidious.example.com' +# '(.*\.)?reddit\.com$': 'teddit.example.com' +# '(.*\.)?redd\.it$': 'teddit.example.com' +# '(www\.)?twitter\.com$': 'nitter.example.com' +# remove: +# - '(.*\.)?facebook.com$' +# low_priority: +# - '(.*\.)?google(\..*)?$' +# high_priority: +# - '(.*\.)?wikipedia.org$' +# +# Alternatively you can use external files for configuring the "Hostnames plugin": +# +# hostnames: +# replace: 'rewrite-hosts.yml' +# +# Content of 'rewrite-hosts.yml' (place the file in the same directory as 'settings.yml'): +# '(.*\.)?youtube\.com$': 'invidious.example.com' +# '(.*\.)?youtu\.be$': 'invidious.example.com' +# + +checker: + # disable checker when in debug mode + off_when_debug: true + + # use "scheduling: false" to disable scheduling + # scheduling: interval or int + + # to activate the scheduler: + # * uncomment "scheduling" section + # * add "cache2 = name=searxngcache,items=2000,blocks=2000,blocksize=4096,bitmap=1" + # to your uwsgi.ini + + # scheduling: + # start_after: [300, 1800] # delay to start the first run of the checker + # every: [86400, 90000] # how often the checker runs + + # additional tests: only for the YAML anchors (see the engines section) + # + additional_tests: + rosebud: &test_rosebud + matrix: + query: rosebud + lang: en + result_container: + - not_empty + - ['one_title_contains', 'citizen kane'] + test: + - unique_results + + android: &test_android + matrix: + query: ['android'] + lang: ['en', 'de', 'fr', 'zh-CN'] + result_container: + - not_empty + - ['one_title_contains', 'google'] + test: + - unique_results + + # tests: only for the YAML anchors (see the engines section) + tests: + infobox: &tests_infobox + infobox: + matrix: + query: ["linux", "new york", "bbc"] + result_container: + - has_infobox + +categories_as_tabs: + general: + images: + videos: + news: + map: + music: + it: + science: + files: + social media: + +engines: + - name: 9gag + engine: 9gag + shortcut: 9g + disabled: true + + - name: alpine linux packages + engine: alpinelinux + disabled: true + shortcut: alp + + - name: annas archive + engine: annas_archive + disabled: true + shortcut: aa + + # - name: annas articles + # engine: annas_archive + # shortcut: aaa + # # https://docs.searxng.org/dev/engines/online/annas_archive.html + # aa_content: 'magazine' # book_fiction, book_unknown, book_nonfiction, book_comic + # aa_ext: 'pdf' # pdf, epub, .. + # aa_sort: oldest' # newest, oldest, largest, smallest + + - name: apk mirror + engine: apkmirror + timeout: 4.0 + shortcut: apkm + disabled: true + + - name: apple app store + engine: apple_app_store + shortcut: aps + disabled: true + + # Requires Tor + - name: ahmia + engine: ahmia + categories: onions + enable_http: true + shortcut: ah + + - name: anaconda + engine: xpath + paging: true + first_page_num: 0 + search_url: https://anaconda.org/search?q={query}&page={pageno} + results_xpath: //tbody/tr + url_xpath: ./td/h5/a[last()]/@href + title_xpath: ./td/h5 + content_xpath: ./td[h5]/text() + categories: it + timeout: 6.0 + shortcut: conda + disabled: true + + - name: arch linux wiki + engine: archlinux + shortcut: al + + - name: artic + engine: artic + shortcut: arc + timeout: 4.0 + + - name: arxiv + engine: arxiv + shortcut: arx + timeout: 4.0 + + - name: ask + engine: ask + shortcut: ask + disabled: true + + # tmp suspended: dh key too small + # - name: base + # engine: base + # shortcut: bs + + - name: bandcamp + engine: bandcamp + shortcut: bc + categories: music + + - name: wikipedia + engine: wikipedia + shortcut: wp + # add "list" to the array to get results in the results list + display_type: ["infobox"] + base_url: 'https://{language}.wikipedia.org/' + categories: [general] + + - name: bilibili + engine: bilibili + shortcut: bil + disabled: true + + - name: bing + engine: bing + shortcut: bi + disabled: false + + - name: bing images + engine: bing_images + shortcut: bii + + - name: bing news + engine: bing_news + shortcut: bin + + - name: bing videos + engine: bing_videos + shortcut: biv + + - name: bitbucket + engine: xpath + paging: true + search_url: https://bitbucket.org/repo/all/{pageno}?name={query} + url_xpath: //article[@class="repo-summary"]//a[@class="repo-link"]/@href + title_xpath: //article[@class="repo-summary"]//a[@class="repo-link"] + content_xpath: //article[@class="repo-summary"]/p + categories: [it, repos] + timeout: 4.0 + disabled: true + shortcut: bb + about: + website: https://bitbucket.org/ + wikidata_id: Q2493781 + official_api_documentation: https://developer.atlassian.com/bitbucket + use_official_api: false + require_api_key: false + results: HTML + + - name: bpb + engine: bpb + shortcut: bpb + disabled: true + + - name: btdigg + engine: btdigg + shortcut: bt + disabled: true + + - name: openverse + engine: openverse + categories: images + shortcut: opv + + - name: media.ccc.de + engine: ccc_media + shortcut: c3tv + # We don't set language: de here because media.ccc.de is not just + # for a German audience. It contains many English videos and many + # German videos have English subtitles. + disabled: true + + - name: chefkoch + engine: chefkoch + shortcut: chef + # to show premium or plus results too: + # skip_premium: false + + # - name: core.ac.uk + # engine: core + # categories: science + # shortcut: cor + # # get your API key from: https://core.ac.uk/api-keys/register/ + # api_key: 'unset' + + - name: cppreference + engine: cppreference + shortcut: cpp + paging: false + disabled: true + + - name: crossref + engine: crossref + shortcut: cr + timeout: 30 + disabled: true + + - name: crowdview + engine: json_engine + shortcut: cv + categories: general + paging: false + search_url: https://crowdview-next-js.onrender.com/api/search-v3?query={query} + results_query: results + url_query: link + title_query: title + content_query: snippet + disabled: true + about: + website: https://crowdview.ai/ + + - name: yep + engine: yep + shortcut: yep + categories: general + search_type: web + timeout: 5 + disabled: true + + - name: yep images + engine: yep + shortcut: yepi + categories: images + search_type: images + disabled: true + + - name: yep news + engine: yep + shortcut: yepn + categories: news + search_type: news + disabled: true + + - name: curlie + engine: xpath + shortcut: cl + categories: general + disabled: true + paging: true + lang_all: '' + search_url: https://curlie.org/search?q={query}&lang={lang}&start={pageno}&stime=92452189 + page_size: 20 + results_xpath: //div[@id="site-list-content"]/div[@class="site-item"] + url_xpath: ./div[@class="title-and-desc"]/a/@href + title_xpath: ./div[@class="title-and-desc"]/a/div + content_xpath: ./div[@class="title-and-desc"]/div[@class="site-descr"] + about: + website: https://curlie.org/ + wikidata_id: Q60715723 + use_official_api: false + require_api_key: false + results: HTML + + - name: currency + engine: currency_convert + categories: general + shortcut: cc + + - name: bahnhof + engine: json_engine + search_url: https://www.bahnhof.de/api/stations/search/{query} + url_prefix: https://www.bahnhof.de/ + url_query: slug + title_query: name + content_query: state + shortcut: bf + disabled: true + about: + website: https://www.bahn.de + wikidata_id: Q22811603 + use_official_api: false + require_api_key: false + results: JSON + language: de + tests: + bahnhof: + matrix: + query: berlin + lang: en + result_container: + - not_empty + - ['one_title_contains', 'Berlin Hauptbahnhof'] + test: + - unique_results + + - name: deezer + engine: deezer + shortcut: dz + disabled: true + + - name: destatis + engine: destatis + shortcut: destat + disabled: true + + - name: deviantart + engine: deviantart + shortcut: da + timeout: 3.0 + + - name: ddg definitions + engine: duckduckgo_definitions + shortcut: ddd + weight: 2 + disabled: true + tests: *tests_infobox + + # cloudflare protected + # - name: digbt + # engine: digbt + # shortcut: dbt + # timeout: 6.0 + # disabled: true + + - name: docker hub + engine: docker_hub + shortcut: dh + categories: [it, packages] + + - name: encyclosearch + engine: json_engine + shortcut: es + categories: general + paging: true + search_url: https://encyclosearch.org/encyclosphere/search?q={query}&page={pageno}&resultsPerPage=15 + results_query: Results + url_query: SourceURL + title_query: Title + content_query: Description + disabled: true + about: + website: https://encyclosearch.org + official_api_documentation: https://encyclosearch.org/docs/#/rest-api + use_official_api: true + require_api_key: false + results: JSON + + - name: erowid + engine: xpath + paging: true + first_page_num: 0 + page_size: 30 + search_url: https://www.erowid.org/search.php?q={query}&s={pageno} + url_xpath: //dl[@class="results-list"]/dt[@class="result-title"]/a/@href + title_xpath: //dl[@class="results-list"]/dt[@class="result-title"]/a/text() + content_xpath: //dl[@class="results-list"]/dd[@class="result-details"] + categories: [] + shortcut: ew + disabled: true + about: + website: https://www.erowid.org/ + wikidata_id: Q1430691 + official_api_documentation: + use_official_api: false + require_api_key: false + results: HTML + + # - name: elasticsearch + # shortcut: es + # engine: elasticsearch + # base_url: http://localhost:9200 + # username: elastic + # password: changeme + # index: my-index + # # available options: match, simple_query_string, term, terms, custom + # query_type: match + # # if query_type is set to custom, provide your query here + # #custom_query_json: {"query":{"match_all": {}}} + # #show_metadata: false + # disabled: true + + - name: wikidata + engine: wikidata + shortcut: wd + timeout: 3.0 + weight: 2 + # add "list" to the array to get results in the results list + display_type: ["infobox"] + tests: *tests_infobox + categories: [general] + + - name: duckduckgo + engine: duckduckgo + shortcut: ddg + + - name: duckduckgo images + engine: duckduckgo_extra + categories: [images, web] + ddg_category: images + shortcut: ddi + disabled: true + + - name: duckduckgo videos + engine: duckduckgo_extra + categories: [videos, web] + ddg_category: videos + shortcut: ddv + disabled: true + + - name: duckduckgo news + engine: duckduckgo_extra + categories: [news, web] + ddg_category: news + shortcut: ddn + disabled: true + + - name: duckduckgo weather + engine: duckduckgo_weather + shortcut: ddw + disabled: true + + - name: apple maps + engine: apple_maps + shortcut: apm + disabled: true + timeout: 5.0 + + - name: emojipedia + engine: emojipedia + timeout: 4.0 + shortcut: em + disabled: true + + - name: tineye + engine: tineye + shortcut: tin + timeout: 9.0 + disabled: true + + - name: etymonline + engine: xpath + paging: true + search_url: https://etymonline.com/search?page={pageno}&q={query} + url_xpath: //a[contains(@class, "word__name--")]/@href + title_xpath: //a[contains(@class, "word__name--")] + content_xpath: //section[contains(@class, "word__defination")] + first_page_num: 1 + shortcut: et + categories: [dictionaries] + about: + website: https://www.etymonline.com/ + wikidata_id: Q1188617 + official_api_documentation: + use_official_api: false + require_api_key: false + results: HTML + + # - name: ebay + # engine: ebay + # shortcut: eb + # base_url: 'https://www.ebay.com' + # disabled: true + # timeout: 5 + + - name: 1x + engine: www1x + shortcut: 1x + timeout: 3.0 + disabled: true + + - name: fdroid + engine: fdroid + shortcut: fd + disabled: true + + - name: findthatmeme + engine: findthatmeme + shortcut: ftm + disabled: true + + - name: flickr + categories: images + shortcut: fl + # You can use the engine using the official stable API, but you need an API + # key, see: https://www.flickr.com/services/apps/create/ + # engine: flickr + # api_key: 'apikey' # required! + # Or you can use the html non-stable engine, activated by default + engine: flickr_noapi + + - name: free software directory + engine: mediawiki + shortcut: fsd + categories: [it, software wikis] + base_url: https://directory.fsf.org/ + search_type: title + timeout: 5.0 + disabled: true + about: + website: https://directory.fsf.org/ + wikidata_id: Q2470288 + + # - name: freesound + # engine: freesound + # shortcut: fnd + # disabled: true + # timeout: 15.0 + # API key required, see: https://freesound.org/docs/api/overview.html + # api_key: MyAPIkey + + - name: frinkiac + engine: frinkiac + shortcut: frk + disabled: true + + - name: fyyd + engine: fyyd + shortcut: fy + timeout: 8.0 + disabled: true + + - name: geizhals + engine: geizhals + shortcut: geiz + disabled: true + + - name: genius + engine: genius + shortcut: gen + + - name: gentoo + engine: mediawiki + shortcut: ge + categories: ["it", "software wikis"] + base_url: "https://wiki.gentoo.org/" + api_path: "api.php" + search_type: text + timeout: 10 + + - name: gitlab + engine: json_engine + paging: true + search_url: https://gitlab.com/api/v4/projects?search={query}&page={pageno} + url_query: web_url + title_query: name_with_namespace + content_query: description + page_size: 20 + categories: [it, repos] + shortcut: gl + timeout: 10.0 + disabled: true + about: + website: https://about.gitlab.com/ + wikidata_id: Q16639197 + official_api_documentation: https://docs.gitlab.com/ee/api/ + use_official_api: false + require_api_key: false + results: JSON + + - name: github + engine: github + shortcut: gh + + - name: codeberg + # https://docs.searxng.org/dev/engines/online/gitea.html + engine: gitea + base_url: https://codeberg.org + shortcut: cb + disabled: true + + - name: gitea.com + engine: gitea + base_url: https://gitea.com + shortcut: gitea + disabled: true + + - name: goodreads + engine: goodreads + shortcut: good + timeout: 4.0 + disabled: true + + - name: google + engine: google + shortcut: go + # additional_tests: + # android: *test_android + + - name: google images + engine: google_images + shortcut: goi + # additional_tests: + # android: *test_android + # dali: + # matrix: + # query: ['Dali Christ'] + # lang: ['en', 'de', 'fr', 'zh-CN'] + # result_container: + # - ['one_title_contains', 'Salvador'] + + - name: google news + engine: google_news + shortcut: gon + # additional_tests: + # android: *test_android + + - name: google videos + engine: google_videos + shortcut: gov + # additional_tests: + # android: *test_android + + - name: google scholar + engine: google_scholar + shortcut: gos + + - name: google play apps + engine: google_play + categories: [files, apps] + shortcut: gpa + play_categ: apps + disabled: true + + - name: google play movies + engine: google_play + categories: videos + shortcut: gpm + play_categ: movies + disabled: true + + - name: material icons + engine: material_icons + categories: images + shortcut: mi + disabled: true + + - name: gpodder + engine: json_engine + shortcut: gpod + timeout: 4.0 + paging: false + search_url: https://gpodder.net/search.json?q={query} + url_query: url + title_query: title + content_query: description + page_size: 19 + categories: music + disabled: true + about: + website: https://gpodder.net + wikidata_id: Q3093354 + official_api_documentation: https://gpoddernet.readthedocs.io/en/latest/api/ + use_official_api: false + requires_api_key: false + results: JSON + + - name: habrahabr + engine: xpath + paging: true + search_url: https://habr.com/en/search/page{pageno}/?q={query} + results_xpath: //article[contains(@class, "tm-articles-list__item")] + url_xpath: .//a[@class="tm-title__link"]/@href + title_xpath: .//a[@class="tm-title__link"] + content_xpath: .//div[contains(@class, "article-formatted-body")] + categories: it + timeout: 4.0 + disabled: true + shortcut: habr + about: + website: https://habr.com/ + wikidata_id: Q4494434 + official_api_documentation: https://habr.com/en/docs/help/api/ + use_official_api: false + require_api_key: false + results: HTML + + - name: hackernews + engine: hackernews + shortcut: hn + disabled: true + + - name: hex + engine: hex + shortcut: hex + disabled: true + # Valid values: name inserted_at updated_at total_downloads recent_downloads + sort_criteria: "recent_downloads" + page_size: 10 + + - name: crates.io + engine: crates + shortcut: crates + disabled: true + timeout: 6.0 + + - name: hoogle + engine: xpath + search_url: https://hoogle.haskell.org/?hoogle={query} + results_xpath: '//div[@class="result"]' + title_xpath: './/div[@class="ans"]//a' + url_xpath: './/div[@class="ans"]//a/@href' + content_xpath: './/div[@class="from"]' + page_size: 20 + categories: [it, packages] + shortcut: ho + about: + website: https://hoogle.haskell.org/ + wikidata_id: Q34010 + official_api_documentation: https://hackage.haskell.org/api + use_official_api: false + require_api_key: false + results: JSON + + - name: imdb + engine: imdb + shortcut: imdb + timeout: 6.0 + disabled: true + + - name: imgur + engine: imgur + shortcut: img + disabled: true + + - name: ina + engine: ina + shortcut: in + timeout: 6.0 + disabled: true + + - name: invidious + engine: invidious + # Instanes will be selected randomly, see https://api.invidious.io/ for + # instances that are stable (good uptime) and close to you. + base_url: + - https://invidious.io.lol + - https://invidious.fdn.fr + - https://yt.artemislena.eu + - https://invidious.tiekoetter.com + - https://invidious.flokinet.to + - https://vid.puffyan.us + - https://invidious.privacydev.net + - https://inv.tux.pizza + shortcut: iv + timeout: 3.0 + disabled: true + + - name: jisho + engine: jisho + shortcut: js + timeout: 3.0 + disabled: true + + - name: kickass + engine: kickass + base_url: + - https://kickasstorrents.to + - https://kickasstorrents.cr + - https://kickasstorrent.cr + - https://kickass.sx + - https://kat.am + shortcut: kc + timeout: 4.0 + disabled: true + + - name: lemmy communities + engine: lemmy + lemmy_type: Communities + shortcut: leco + + - name: lemmy users + engine: lemmy + network: lemmy communities + lemmy_type: Users + shortcut: leus + + - name: lemmy posts + engine: lemmy + network: lemmy communities + lemmy_type: Posts + shortcut: lepo + + - name: lemmy comments + engine: lemmy + network: lemmy communities + lemmy_type: Comments + shortcut: lecom + + - name: library genesis + engine: xpath + # search_url: https://libgen.is/search.php?req={query} + search_url: https://libgen.rs/search.php?req={query} + url_xpath: //a[contains(@href,"book/index.php?md5")]/@href + title_xpath: //a[contains(@href,"book/")]/text()[1] + content_xpath: //td/a[1][contains(@href,"=author")]/text() + categories: files + timeout: 7.0 + disabled: true + shortcut: lg + about: + website: https://libgen.fun/ + wikidata_id: Q22017206 + official_api_documentation: + use_official_api: false + require_api_key: false + results: HTML + + - name: z-library + engine: zlibrary + shortcut: zlib + categories: files + timeout: 7.0 + disabled: true + + - name: library of congress + engine: loc + shortcut: loc + categories: images + + - name: libretranslate + engine: libretranslate + # https://github.com/LibreTranslate/LibreTranslate?tab=readme-ov-file#mirrors + base_url: + - https://translate.terraprint.co + - https://trans.zillyhuhn.com + # api_key: abc123 + shortcut: lt + disabled: true + + - name: lingva + engine: lingva + shortcut: lv + # set lingva instance in url, by default it will use the official instance + # url: https://lingva.thedaviddelta.com + + - name: lobste.rs + engine: xpath + search_url: https://lobste.rs/search?q={query}&what=stories&order=relevance + results_xpath: //li[contains(@class, "story")] + url_xpath: .//a[@class="u-url"]/@href + title_xpath: .//a[@class="u-url"] + content_xpath: .//a[@class="domain"] + categories: it + shortcut: lo + timeout: 5.0 + disabled: true + about: + website: https://lobste.rs/ + wikidata_id: Q60762874 + official_api_documentation: + use_official_api: false + require_api_key: false + results: HTML + + - name: mastodon users + engine: mastodon + mastodon_type: accounts + base_url: https://mastodon.social + shortcut: mau + + - name: mastodon hashtags + engine: mastodon + mastodon_type: hashtags + base_url: https://mastodon.social + shortcut: mah + + # - name: matrixrooms + # engine: mrs + # # https://docs.searxng.org/dev/engines/online/mrs.html + # # base_url: https://mrs-api-host + # shortcut: mtrx + # disabled: true + + - name: mdn + shortcut: mdn + engine: json_engine + categories: [it] + paging: true + search_url: https://developer.mozilla.org/api/v1/search?q={query}&page={pageno} + results_query: documents + url_query: mdn_url + url_prefix: https://developer.mozilla.org + title_query: title + content_query: summary + about: + website: https://developer.mozilla.org + wikidata_id: Q3273508 + official_api_documentation: null + use_official_api: false + require_api_key: false + results: JSON + + - name: metacpan + engine: metacpan + shortcut: cpan + disabled: true + number_of_results: 20 + + # - name: meilisearch + # engine: meilisearch + # shortcut: mes + # enable_http: true + # base_url: http://localhost:7700 + # index: my-index + + - name: mixcloud + engine: mixcloud + shortcut: mc + + # MongoDB engine + # Required dependency: pymongo + # - name: mymongo + # engine: mongodb + # shortcut: md + # exact_match_only: false + # host: '127.0.0.1' + # port: 27017 + # enable_http: true + # results_per_page: 20 + # database: 'business' + # collection: 'reviews' # name of the db collection + # key: 'name' # key in the collection to search for + + - name: mozhi + engine: mozhi + base_url: + - https://mozhi.aryak.me + - https://translate.bus-hit.me + - https://nyc1.mz.ggtyler.dev + # mozhi_engine: google - see https://mozhi.aryak.me for supported engines + timeout: 4.0 + shortcut: mz + disabled: true + + - name: mwmbl + engine: mwmbl + # api_url: https://api.mwmbl.org + shortcut: mwm + disabled: true + + - name: npm + engine: npm + shortcut: npm + timeout: 5.0 + disabled: true + + - name: nyaa + engine: nyaa + shortcut: nt + disabled: true + + - name: mankier + engine: json_engine + search_url: https://www.mankier.com/api/v2/mans/?q={query} + results_query: results + url_query: url + title_query: name + content_query: description + categories: it + shortcut: man + about: + website: https://www.mankier.com/ + official_api_documentation: https://www.mankier.com/api + use_official_api: true + require_api_key: false + results: JSON + + # read https://docs.searxng.org/dev/engines/online/mullvad_leta.html + # - name: mullvadleta + # engine: mullvad_leta + # leta_engine: google # choose one of the following: google, brave + # use_cache: true # Only 100 non-cache searches per day, suggested only for private instances + # search_url: https://leta.mullvad.net + # categories: [general, web] + # shortcut: ml + + - name: odysee + engine: odysee + shortcut: od + disabled: true + + - name: openairedatasets + engine: json_engine + paging: true + search_url: https://api.openaire.eu/search/datasets?format=json&page={pageno}&size=10&title={query} + results_query: response/results/result + url_query: metadata/oaf:entity/oaf:result/children/instance/webresource/url/$ + title_query: metadata/oaf:entity/oaf:result/title/$ + content_query: metadata/oaf:entity/oaf:result/description/$ + content_html_to_text: true + categories: "science" + shortcut: oad + timeout: 5.0 + about: + website: https://www.openaire.eu/ + wikidata_id: Q25106053 + official_api_documentation: https://api.openaire.eu/ + use_official_api: false + require_api_key: false + results: JSON + + - name: openairepublications + engine: json_engine + paging: true + search_url: https://api.openaire.eu/search/publications?format=json&page={pageno}&size=10&title={query} + results_query: response/results/result + url_query: metadata/oaf:entity/oaf:result/children/instance/webresource/url/$ + title_query: metadata/oaf:entity/oaf:result/title/$ + content_query: metadata/oaf:entity/oaf:result/description/$ + content_html_to_text: true + categories: science + shortcut: oap + timeout: 5.0 + about: + website: https://www.openaire.eu/ + wikidata_id: Q25106053 + official_api_documentation: https://api.openaire.eu/ + use_official_api: false + require_api_key: false + results: JSON + + - name: openmeteo + engine: open_meteo + shortcut: om + disabled: true + + # - name: opensemanticsearch + # engine: opensemantic + # shortcut: oss + # base_url: 'http://localhost:8983/solr/opensemanticsearch/' + + - name: openstreetmap + engine: openstreetmap + shortcut: osm + + - name: openrepos + engine: xpath + paging: true + search_url: https://openrepos.net/search/node/{query}?page={pageno} + url_xpath: //li[@class="search-result"]//h3[@class="title"]/a/@href + title_xpath: //li[@class="search-result"]//h3[@class="title"]/a + content_xpath: //li[@class="search-result"]//div[@class="search-snippet-info"]//p[@class="search-snippet"] + categories: files + timeout: 4.0 + disabled: true + shortcut: or + about: + website: https://openrepos.net/ + wikidata_id: + official_api_documentation: + use_official_api: false + require_api_key: false + results: HTML + + - name: packagist + engine: json_engine + paging: true + search_url: https://packagist.org/search.json?q={query}&page={pageno} + results_query: results + url_query: url + title_query: name + content_query: description + categories: [it, packages] + disabled: true + timeout: 5.0 + shortcut: pack + about: + website: https://packagist.org + wikidata_id: Q108311377 + official_api_documentation: https://packagist.org/apidoc + use_official_api: true + require_api_key: false + results: JSON + + - name: pdbe + engine: pdbe + shortcut: pdb + # Hide obsolete PDB entries. Default is not to hide obsolete structures + # hide_obsolete: false + + - name: photon + engine: photon + shortcut: ph + + - name: pinterest + engine: pinterest + shortcut: pin + + - name: piped + engine: piped + shortcut: ppd + categories: videos + piped_filter: videos + timeout: 3.0 + + # URL to use as link and for embeds + frontend_url: https://srv.piped.video + # Instance will be selected randomly, for more see https://piped-instances.kavin.rocks/ + backend_url: + - https://pipedapi.kavin.rocks + - https://pipedapi-libre.kavin.rocks + - https://pipedapi.adminforge.de + + - name: piped.music + engine: piped + network: piped + shortcut: ppdm + categories: music + piped_filter: music_songs + timeout: 3.0 + + - name: piratebay + engine: piratebay + shortcut: tpb + # You may need to change this URL to a proxy if piratebay is blocked in your + # country + url: https://thepiratebay.org/ + timeout: 3.0 + + - name: pixiv + shortcut: pv + engine: pixiv + disabled: true + inactive: true + pixiv_image_proxies: + - https://pximg.example.org + # A proxy is required to load the images. Hosting an image proxy server + # for Pixiv: + # --> https://pixivfe.pages.dev/hosting-image-proxy-server/ + # Proxies from public instances. Ask the public instances owners if they + # agree to receive traffic from SearXNG! + # --> https://codeberg.org/VnPower/PixivFE#instances + # --> https://github.com/searxng/searxng/pull/3192#issuecomment-1941095047 + # image proxy of https://pixiv.cat + # - https://i.pixiv.cat + # image proxy of https://www.pixiv.pics + # - https://pximg.cocomi.eu.org + # image proxy of https://pixivfe.exozy.me + # - https://pximg.exozy.me + # image proxy of https://pixivfe.ducks.party + # - https://pixiv.ducks.party + # image proxy of https://pixiv.perennialte.ch + # - https://pximg.perennialte.ch + + - name: podcastindex + engine: podcastindex + shortcut: podcast + + # Required dependency: psychopg2 + # - name: postgresql + # engine: postgresql + # database: postgres + # username: postgres + # password: postgres + # limit: 10 + # query_str: 'SELECT * from my_table WHERE my_column = %(query)s' + # shortcut : psql + + - name: presearch + engine: presearch + search_type: search + categories: [general, web] + shortcut: ps + timeout: 4.0 + disabled: true + + - name: presearch images + engine: presearch + network: presearch + search_type: images + categories: [images, web] + timeout: 4.0 + shortcut: psimg + disabled: true + + - name: presearch videos + engine: presearch + network: presearch + search_type: videos + categories: [general, web] + timeout: 4.0 + shortcut: psvid + disabled: true + + - name: presearch news + engine: presearch + network: presearch + search_type: news + categories: [news, web] + timeout: 4.0 + shortcut: psnews + disabled: true + + - name: pub.dev + engine: xpath + shortcut: pd + search_url: https://pub.dev/packages?q={query}&page={pageno} + paging: true + results_xpath: //div[contains(@class,"packages-item")] + url_xpath: ./div/h3/a/@href + title_xpath: ./div/h3/a + content_xpath: ./div/div/div[contains(@class,"packages-description")]/span + categories: [packages, it] + timeout: 3.0 + disabled: true + first_page_num: 1 + about: + website: https://pub.dev/ + official_api_documentation: https://pub.dev/help/api + use_official_api: false + require_api_key: false + results: HTML + + - name: pubmed + engine: pubmed + shortcut: pub + timeout: 3.0 + + - name: pypi + shortcut: pypi + engine: pypi + + - name: qwant + qwant_categ: web + engine: qwant + disabled: true + shortcut: qw + categories: [general, web] + additional_tests: + rosebud: *test_rosebud + + - name: qwant news + qwant_categ: news + engine: qwant + shortcut: qwn + categories: news + network: qwant + + - name: qwant images + qwant_categ: images + engine: qwant + shortcut: qwi + categories: [images, web] + network: qwant + + - name: qwant videos + qwant_categ: videos + engine: qwant + shortcut: qwv + categories: [videos, web] + network: qwant + + # - name: library + # engine: recoll + # shortcut: lib + # base_url: 'https://recoll.example.org/' + # search_dir: '' + # mount_prefix: /export + # dl_prefix: 'https://download.example.org' + # timeout: 30.0 + # categories: files + # disabled: true + + # - name: recoll library reference + # engine: recoll + # base_url: 'https://recoll.example.org/' + # search_dir: reference + # mount_prefix: /export + # dl_prefix: 'https://download.example.org' + # shortcut: libr + # timeout: 30.0 + # categories: files + # disabled: true + + - name: radio browser + engine: radio_browser + shortcut: rb + + - name: reddit + engine: reddit + shortcut: re + page_size: 25 + disabled: true + + - name: rottentomatoes + engine: rottentomatoes + shortcut: rt + disabled: true + + # Required dependency: redis + # - name: myredis + # shortcut : rds + # engine: redis_server + # exact_match_only: false + # host: '127.0.0.1' + # port: 6379 + # enable_http: true + # password: '' + # db: 0 + + # tmp suspended: bad certificate + # - name: scanr structures + # shortcut: scs + # engine: scanr_structures + # disabled: true + + - name: searchmysite + engine: xpath + shortcut: sms + categories: general + paging: true + search_url: https://searchmysite.net/search/?q={query}&page={pageno} + results_xpath: //div[contains(@class,'search-result')] + url_xpath: .//a[contains(@class,'result-link')]/@href + title_xpath: .//span[contains(@class,'result-title-txt')]/text() + content_xpath: ./p[@id='result-hightlight'] + disabled: true + about: + website: https://searchmysite.net + + - name: sepiasearch + engine: sepiasearch + shortcut: sep + + - name: soundcloud + engine: soundcloud + shortcut: sc + + - name: stackoverflow + engine: stackexchange + shortcut: st + api_site: 'stackoverflow' + categories: [it, q&a] + + - name: askubuntu + engine: stackexchange + shortcut: ubuntu + api_site: 'askubuntu' + categories: [it, q&a] + + - name: internetarchivescholar + engine: internet_archive_scholar + shortcut: ias + timeout: 15.0 + + - name: superuser + engine: stackexchange + shortcut: su + api_site: 'superuser' + categories: [it, q&a] + + - name: discuss.python + engine: discourse + shortcut: dpy + base_url: 'https://discuss.python.org' + categories: [it, q&a] + disabled: true + + - name: caddy.community + engine: discourse + shortcut: caddy + base_url: 'https://caddy.community' + categories: [it, q&a] + disabled: true + + - name: pi-hole.community + engine: discourse + shortcut: pi + categories: [it, q&a] + base_url: 'https://discourse.pi-hole.net' + disabled: true + + - name: searchcode code + engine: searchcode_code + shortcut: scc + disabled: true + + # - name: searx + # engine: searx_engine + # shortcut: se + # instance_urls : + # - http://127.0.0.1:8888/ + # - ... + # disabled: true + + - name: semantic scholar + engine: semantic_scholar + disabled: true + shortcut: se + + # Spotify needs API credentials + # - name: spotify + # engine: spotify + # shortcut: stf + # api_client_id: ******* + # api_client_secret: ******* + + # - name: solr + # engine: solr + # shortcut: slr + # base_url: http://localhost:8983 + # collection: collection_name + # sort: '' # sorting: asc or desc + # field_list: '' # comma separated list of field names to display on the UI + # default_fields: '' # default field to query + # query_fields: '' # query fields + # enable_http: true + + # - name: springer nature + # engine: springer + # # get your API key from: https://dev.springernature.com/signup + # # working API key, for test & debug: "a69685087d07eca9f13db62f65b8f601" + # api_key: 'unset' + # shortcut: springer + # timeout: 15.0 + + - name: startpage + engine: startpage + shortcut: sp + timeout: 6.0 + disabled: true + additional_tests: + rosebud: *test_rosebud + + - name: tokyotoshokan + engine: tokyotoshokan + shortcut: tt + timeout: 6.0 + disabled: true + + - name: solidtorrents + engine: solidtorrents + shortcut: solid + timeout: 4.0 + base_url: + - https://solidtorrents.to + - https://bitsearch.to + + # For this demo of the sqlite engine download: + # https://liste.mediathekview.de/filmliste-v2.db.bz2 + # and unpack into searx/data/filmliste-v2.db + # Query to test: "!demo concert" + # + # - name: demo + # engine: sqlite + # shortcut: demo + # categories: general + # result_template: default.html + # database: searx/data/filmliste-v2.db + # query_str: >- + # SELECT title || ' (' || time(duration, 'unixepoch') || ')' AS title, + # COALESCE( NULLIF(url_video_hd,''), NULLIF(url_video_sd,''), url_video) AS url, + # description AS content + # FROM film + # WHERE title LIKE :wildcard OR description LIKE :wildcard + # ORDER BY duration DESC + + - name: tagesschau + engine: tagesschau + # when set to false, display URLs from Tagesschau, and not the actual source + # (e.g. NDR, WDR, SWR, HR, ...) + use_source_url: true + shortcut: ts + disabled: true + + - name: tmdb + engine: xpath + paging: true + categories: movies + search_url: https://www.themoviedb.org/search?page={pageno}&query={query} + results_xpath: //div[contains(@class,"movie") or contains(@class,"tv")]//div[contains(@class,"card")] + url_xpath: .//div[contains(@class,"poster")]/a/@href + thumbnail_xpath: .//img/@src + title_xpath: .//div[contains(@class,"title")]//h2 + content_xpath: .//div[contains(@class,"overview")] + shortcut: tm + disabled: true + + # Requires Tor + - name: torch + engine: xpath + paging: true + search_url: + http://xmh57jrknzkhv6y3ls3ubitzfqnkrwxhopf5aygthi7d6rplyvk3noyd.onion/cgi-bin/omega/omega?P={query}&DEFAULTOP=and + results_xpath: //table//tr + url_xpath: ./td[2]/a + title_xpath: ./td[2]/b + content_xpath: ./td[2]/small + categories: onions + enable_http: true + shortcut: tch + + # torznab engine lets you query any torznab compatible indexer. Using this + # engine in combination with Jackett opens the possibility to query a lot of + # public and private indexers directly from SearXNG. More details at: + # https://docs.searxng.org/dev/engines/online/torznab.html + # + # - name: Torznab EZTV + # engine: torznab + # shortcut: eztv + # base_url: http://localhost:9117/api/v2.0/indexers/eztv/results/torznab + # enable_http: true # if using localhost + # api_key: xxxxxxxxxxxxxxx + # show_magnet_links: true + # show_torrent_files: false + # # https://github.com/Jackett/Jackett/wiki/Jackett-Categories + # torznab_categories: # optional + # - 2000 + # - 5000 + + # tmp suspended - too slow, too many errors + # - name: urbandictionary + # engine : xpath + # search_url : https://www.urbandictionary.com/define.php?term={query} + # url_xpath : //*[@class="word"]/@href + # title_xpath : //*[@class="def-header"] + # content_xpath: //*[@class="meaning"] + # shortcut: ud + + - name: unsplash + engine: unsplash + shortcut: us + + - name: yandex music + engine: yandex_music + shortcut: ydm + disabled: true + # https://yandex.com/support/music/access.html + inactive: true + + - name: yahoo + engine: yahoo + shortcut: yh + disabled: true + + - name: yahoo news + engine: yahoo_news + shortcut: yhn + + - name: youtube + shortcut: yt + # You can use the engine using the official stable API, but you need an API + # key See: https://console.developers.google.com/project + # + # engine: youtube_api + # api_key: 'apikey' # required! + # + # Or you can use the html non-stable engine, activated by default + engine: youtube_noapi + + - name: dailymotion + engine: dailymotion + shortcut: dm + + - name: vimeo + engine: vimeo + shortcut: vm + disabled: true + + - name: wiby + engine: json_engine + paging: true + search_url: https://wiby.me/json/?q={query}&p={pageno} + url_query: URL + title_query: Title + content_query: Snippet + categories: [general, web] + shortcut: wib + disabled: true + about: + website: https://wiby.me/ + + - name: alexandria + engine: json_engine + shortcut: alx + categories: general + paging: true + search_url: https://api.alexandria.org/?a=1&q={query}&p={pageno} + results_query: results + title_query: title + url_query: url + content_query: snippet + timeout: 1.5 + disabled: true + about: + website: https://alexandria.org/ + official_api_documentation: https://github.com/alexandria-org/alexandria-api/raw/master/README.md + use_official_api: true + require_api_key: false + results: JSON + + - name: wikibooks + engine: mediawiki + weight: 0.5 + shortcut: wb + categories: [general, wikimedia] + base_url: "https://{language}.wikibooks.org/" + search_type: text + disabled: true + about: + website: https://www.wikibooks.org/ + wikidata_id: Q367 + + - name: wikinews + engine: mediawiki + shortcut: wn + categories: [news, wikimedia] + base_url: "https://{language}.wikinews.org/" + search_type: text + srsort: create_timestamp_desc + about: + website: https://www.wikinews.org/ + wikidata_id: Q964 + + - name: wikiquote + engine: mediawiki + weight: 0.5 + shortcut: wq + categories: [general, wikimedia] + base_url: "https://{language}.wikiquote.org/" + search_type: text + disabled: true + additional_tests: + rosebud: *test_rosebud + about: + website: https://www.wikiquote.org/ + wikidata_id: Q369 + + - name: wikisource + engine: mediawiki + weight: 0.5 + shortcut: ws + categories: [general, wikimedia] + base_url: "https://{language}.wikisource.org/" + search_type: text + disabled: true + about: + website: https://www.wikisource.org/ + wikidata_id: Q263 + + - name: wikispecies + engine: mediawiki + shortcut: wsp + categories: [general, science, wikimedia] + base_url: "https://species.wikimedia.org/" + search_type: text + disabled: true + about: + website: https://species.wikimedia.org/ + wikidata_id: Q13679 + tests: + wikispecies: + matrix: + query: "Campbell, L.I. et al. 2011: MicroRNAs" + lang: en + result_container: + - not_empty + - ['one_title_contains', 'Tardigrada'] + test: + - unique_results + + - name: wiktionary + engine: mediawiki + shortcut: wt + categories: [dictionaries, wikimedia] + base_url: "https://{language}.wiktionary.org/" + search_type: text + about: + website: https://www.wiktionary.org/ + wikidata_id: Q151 + + - name: wikiversity + engine: mediawiki + weight: 0.5 + shortcut: wv + categories: [general, wikimedia] + base_url: "https://{language}.wikiversity.org/" + search_type: text + disabled: true + about: + website: https://www.wikiversity.org/ + wikidata_id: Q370 + + - name: wikivoyage + engine: mediawiki + weight: 0.5 + shortcut: wy + categories: [general, wikimedia] + base_url: "https://{language}.wikivoyage.org/" + search_type: text + disabled: true + about: + website: https://www.wikivoyage.org/ + wikidata_id: Q373 + + - name: wikicommons.images + engine: wikicommons + shortcut: wc + categories: images + search_type: images + number_of_results: 10 + + - name: wikicommons.videos + engine: wikicommons + shortcut: wcv + categories: videos + search_type: videos + number_of_results: 10 + + - name: wikicommons.audio + engine: wikicommons + shortcut: wca + categories: music + search_type: audio + number_of_results: 10 + + - name: wikicommons.files + engine: wikicommons + shortcut: wcf + categories: files + search_type: files + number_of_results: 10 + + - name: wolframalpha + shortcut: wa + # You can use the engine using the official stable API, but you need an API + # key. See: https://products.wolframalpha.com/api/ + # + # engine: wolframalpha_api + # api_key: '' + # + # Or you can use the html non-stable engine, activated by default + engine: wolframalpha_noapi + timeout: 6.0 + categories: general + disabled: true + + - name: dictzone + engine: dictzone + shortcut: dc + + - name: mymemory translated + engine: translated + shortcut: tl + timeout: 5.0 + # You can use without an API key, but you are limited to 1000 words/day + # See: https://mymemory.translated.net/doc/usagelimits.php + # api_key: '' + + # Required dependency: mysql-connector-python + # - name: mysql + # engine: mysql_server + # database: mydatabase + # username: user + # password: pass + # limit: 10 + # query_str: 'SELECT * from mytable WHERE fieldname=%(query)s' + # shortcut: mysql + + - name: 1337x + engine: 1337x + shortcut: 1337x + disabled: true + + - name: duden + engine: duden + shortcut: du + disabled: true + + - name: seznam + shortcut: szn + engine: seznam + disabled: true + + # - name: deepl + # engine: deepl + # shortcut: dpl + # # You can use the engine using the official stable API, but you need an API key + # # See: https://www.deepl.com/pro-api?cta=header-pro-api + # api_key: '' # required! + # timeout: 5.0 + # disabled: true + + - name: mojeek + shortcut: mjk + engine: mojeek + categories: [general, web] + disabled: true + + - name: mojeek images + shortcut: mjkimg + engine: mojeek + categories: [images, web] + search_type: images + paging: false + disabled: true + + - name: mojeek news + shortcut: mjknews + engine: mojeek + categories: [news, web] + search_type: news + paging: false + disabled: true + + - name: moviepilot + engine: moviepilot + shortcut: mp + disabled: true + + - name: naver + shortcut: nvr + categories: [general, web] + engine: xpath + paging: true + search_url: https://search.naver.com/search.naver?where=webkr&sm=osp_hty&ie=UTF-8&query={query}&start={pageno} + url_xpath: //a[@class="link_tit"]/@href + title_xpath: //a[@class="link_tit"] + content_xpath: //div[@class="total_dsc_wrap"]/a + first_page_num: 1 + page_size: 10 + disabled: true + about: + website: https://www.naver.com/ + wikidata_id: Q485639 + official_api_documentation: https://developers.naver.com/docs/nmt/examples/ + use_official_api: false + require_api_key: false + results: HTML + language: ko + + - name: rubygems + shortcut: rbg + engine: xpath + paging: true + search_url: https://rubygems.org/search?page={pageno}&query={query} + results_xpath: /html/body/main/div/a[@class="gems__gem"] + url_xpath: ./@href + title_xpath: ./span/h2 + content_xpath: ./span/p + suggestion_xpath: /html/body/main/div/div[@class="search__suggestions"]/p/a + first_page_num: 1 + categories: [it, packages] + disabled: true + about: + website: https://rubygems.org/ + wikidata_id: Q1853420 + official_api_documentation: https://guides.rubygems.org/rubygems-org-api/ + use_official_api: false + require_api_key: false + results: HTML + + - name: peertube + engine: peertube + shortcut: ptb + paging: true + # alternatives see: https://instances.joinpeertube.org/instances + # base_url: https://tube.4aem.com + categories: videos + disabled: true + timeout: 6.0 + + - name: mediathekviewweb + engine: mediathekviewweb + shortcut: mvw + disabled: true + + - name: yacy + # https://docs.searxng.org/dev/engines/online/yacy.html + engine: yacy + categories: general + search_type: text + base_url: + - https://yacy.searchlab.eu + # see https://github.com/searxng/searxng/pull/3631#issuecomment-2240903027 + # - https://search.kyun.li + # - https://yacy.securecomcorp.eu + # - https://yacy.myserv.ca + # - https://yacy.nsupdate.info + # - https://yacy.electroncash.de + shortcut: ya + disabled: true + # if you aren't using HTTPS for your local yacy instance disable https + # enable_http: false + search_mode: 'global' + # timeout can be reduced in 'local' search mode + timeout: 5.0 + + - name: yacy images + engine: yacy + network: yacy + categories: images + search_type: image + shortcut: yai + disabled: true + # timeout can be reduced in 'local' search mode + timeout: 5.0 + + - name: rumble + engine: rumble + shortcut: ru + base_url: https://rumble.com/ + paging: true + categories: videos + disabled: true + + - name: livespace + engine: livespace + shortcut: ls + categories: videos + disabled: true + timeout: 5.0 + + - name: wordnik + engine: wordnik + shortcut: def + base_url: https://www.wordnik.com/ + categories: [dictionaries] + timeout: 5.0 + + - name: woxikon.de synonyme + engine: xpath + shortcut: woxi + categories: [dictionaries] + timeout: 5.0 + disabled: true + search_url: https://synonyme.woxikon.de/synonyme/{query}.php + url_xpath: //div[@class="upper-synonyms"]/a/@href + content_xpath: //div[@class="synonyms-list-group"] + title_xpath: //div[@class="upper-synonyms"]/a + no_result_for_http_status: [404] + about: + website: https://www.woxikon.de/ + wikidata_id: # No Wikidata ID + use_official_api: false + require_api_key: false + results: HTML + language: de + + - name: seekr news + engine: seekr + shortcut: senews + categories: news + seekr_category: news + disabled: true + + - name: seekr images + engine: seekr + network: seekr news + shortcut: seimg + categories: images + seekr_category: images + disabled: true + + - name: seekr videos + engine: seekr + network: seekr news + shortcut: sevid + categories: videos + seekr_category: videos + disabled: true + + - name: sjp.pwn + engine: sjp + shortcut: sjp + base_url: https://sjp.pwn.pl/ + timeout: 5.0 + disabled: true + + - name: stract + engine: stract + shortcut: str + disabled: true + + - name: svgrepo + engine: svgrepo + shortcut: svg + timeout: 10.0 + disabled: true + + - name: tootfinder + engine: tootfinder + shortcut: toot + + - name: voidlinux + engine: voidlinux + shortcut: void + disabled: true + + - name: wallhaven + engine: wallhaven + # api_key: abcdefghijklmnopqrstuvwxyz + shortcut: wh + + # wikimini: online encyclopedia for children + # The fulltext and title parameter is necessary for Wikimini because + # sometimes it will not show the results and redirect instead + - name: wikimini + engine: xpath + shortcut: wkmn + search_url: https://fr.wikimini.org/w/index.php?search={query}&title=Sp%C3%A9cial%3ASearch&fulltext=Search + url_xpath: //li/div[@class="mw-search-result-heading"]/a/@href + title_xpath: //li//div[@class="mw-search-result-heading"]/a + content_xpath: //li/div[@class="searchresult"] + categories: general + disabled: true + about: + website: https://wikimini.org/ + wikidata_id: Q3568032 + use_official_api: false + require_api_key: false + results: HTML + language: fr + + - name: wttr.in + engine: wttr + shortcut: wttr + timeout: 9.0 + + - name: yummly + engine: yummly + shortcut: yum + disabled: true + + - name: brave + engine: brave + shortcut: br + time_range_support: true + paging: true + categories: [general, web] + brave_category: search + # brave_spellcheck: true + + - name: brave.images + engine: brave + network: brave + shortcut: brimg + categories: [images, web] + brave_category: images + + - name: brave.videos + engine: brave + network: brave + shortcut: brvid + categories: [videos, web] + brave_category: videos + + - name: brave.news + engine: brave + network: brave + shortcut: brnews + categories: news + brave_category: news + + # - name: brave.goggles + # engine: brave + # network: brave + # shortcut: brgog + # time_range_support: true + # paging: true + # categories: [general, web] + # brave_category: goggles + # Goggles: # required! This should be a URL ending in .goggle + + - name: lib.rs + shortcut: lrs + engine: lib_rs + disabled: true + + - name: sourcehut + shortcut: srht + engine: xpath + paging: true + search_url: https://sr.ht/projects?page={pageno}&search={query} + results_xpath: (//div[@class="event-list"])[1]/div[@class="event"] + url_xpath: ./h4/a[2]/@href + title_xpath: ./h4/a[2] + content_xpath: ./p + first_page_num: 1 + categories: [it, repos] + disabled: true + about: + website: https://sr.ht + wikidata_id: Q78514485 + official_api_documentation: https://man.sr.ht/ + use_official_api: false + require_api_key: false + results: HTML + + - name: goo + shortcut: goo + engine: xpath + paging: true + search_url: https://search.goo.ne.jp/web.jsp?MT={query}&FR={pageno}0 + url_xpath: //div[@class="result"]/p[@class='title fsL1']/a/@href + title_xpath: //div[@class="result"]/p[@class='title fsL1']/a + content_xpath: //p[contains(@class,'url fsM')]/following-sibling::p + first_page_num: 0 + categories: [general, web] + disabled: true + timeout: 4.0 + about: + website: https://search.goo.ne.jp + wikidata_id: Q249044 + use_official_api: false + require_api_key: false + results: HTML + language: ja + + - name: bt4g + engine: bt4g + shortcut: bt4g + + - name: pkg.go.dev + engine: pkg_go_dev + shortcut: pgo + disabled: true + +# Doku engine lets you access to any Doku wiki instance: +# A public one or a privete/corporate one. +# - name: ubuntuwiki +# engine: doku +# shortcut: uw +# base_url: 'https://doc.ubuntu-fr.org' + +# Be careful when enabling this engine if you are +# running a public instance. Do not expose any sensitive +# information. You can restrict access by configuring a list +# of access tokens under tokens. +# - name: git grep +# engine: command +# command: ['git', 'grep', '{{QUERY}}'] +# shortcut: gg +# tokens: [] +# disabled: true +# delimiter: +# chars: ':' +# keys: ['filepath', 'code'] + +# Be careful when enabling this engine if you are +# running a public instance. Do not expose any sensitive +# information. You can restrict access by configuring a list +# of access tokens under tokens. +# - name: locate +# engine: command +# command: ['locate', '{{QUERY}}'] +# shortcut: loc +# tokens: [] +# disabled: true +# delimiter: +# chars: ' ' +# keys: ['line'] + +# Be careful when enabling this engine if you are +# running a public instance. Do not expose any sensitive +# information. You can restrict access by configuring a list +# of access tokens under tokens. +# - name: find +# engine: command +# command: ['find', '.', '-name', '{{QUERY}}'] +# query_type: path +# shortcut: fnd +# tokens: [] +# disabled: true +# delimiter: +# chars: ' ' +# keys: ['line'] + +# Be careful when enabling this engine if you are +# running a public instance. Do not expose any sensitive +# information. You can restrict access by configuring a list +# of access tokens under tokens. +# - name: pattern search in files +# engine: command +# command: ['fgrep', '{{QUERY}}'] +# shortcut: fgr +# tokens: [] +# disabled: true +# delimiter: +# chars: ' ' +# keys: ['line'] + +# Be careful when enabling this engine if you are +# running a public instance. Do not expose any sensitive +# information. You can restrict access by configuring a list +# of access tokens under tokens. +# - name: regex search in files +# engine: command +# command: ['grep', '{{QUERY}}'] +# shortcut: gr +# tokens: [] +# disabled: true +# delimiter: +# chars: ' ' +# keys: ['line'] + +doi_resolvers: + oadoi.org: 'https://oadoi.org/' + doi.org: 'https://doi.org/' + doai.io: 'https://dissem.in/' + sci-hub.se: 'https://sci-hub.se/' + sci-hub.st: 'https://sci-hub.st/' + sci-hub.ru: 'https://sci-hub.ru/' + +default_doi_resolver: 'oadoi.org' diff --git a/api/core/tools/provider/builtin/searxng/docker/uwsgi.ini b/api/core/tools/provider/builtin/searxng/docker/uwsgi.ini new file mode 100644 index 00000000000000..9db3d762649fc5 --- /dev/null +++ b/api/core/tools/provider/builtin/searxng/docker/uwsgi.ini @@ -0,0 +1,54 @@ +[uwsgi] +# Who will run the code +uid = searxng +gid = searxng + +# Number of workers (usually CPU count) +# default value: %k (= number of CPU core, see Dockerfile) +workers = %k + +# Number of threads per worker +# default value: 4 (see Dockerfile) +threads = 4 + +# The right granted on the created socket +chmod-socket = 666 + +# Plugin to use and interpreter config +single-interpreter = true +master = true +plugin = python3 +lazy-apps = true +enable-threads = 4 + +# Module to import +module = searx.webapp + +# Virtualenv and python path +pythonpath = /usr/local/searxng/ +chdir = /usr/local/searxng/searx/ + +# automatically set processes name to something meaningful +auto-procname = true + +# Disable request logging for privacy +disable-logging = true +log-5xx = true + +# Set the max size of a request (request-body excluded) +buffer-size = 8192 + +# No keep alive +# See https://github.com/searx/searx-docker/issues/24 +add-header = Connection: close + +# Follow SIGTERM convention +# See https://github.com/searxng/searxng/issues/3427 +die-on-term + +# uwsgi serves the static files +static-map = /static=/usr/local/searxng/searx/static +# expires set to one day +static-expires = /* 86400 +static-gzip-all = True +offload-threads = 4 diff --git a/api/core/tools/provider/builtin/searxng/searxng.py b/api/core/tools/provider/builtin/searxng/searxng.py index 24b94b5ca4a391..ab354003e6f567 100644 --- a/api/core/tools/provider/builtin/searxng/searxng.py +++ b/api/core/tools/provider/builtin/searxng/searxng.py @@ -17,8 +17,7 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: tool_parameters={ "query": "SearXNG", "limit": 1, - "search_type": "page", - "result_type": "link" + "search_type": "general" }, ) except Exception as e: diff --git a/api/core/tools/provider/builtin/searxng/searxng.yaml b/api/core/tools/provider/builtin/searxng/searxng.yaml index 1e8050b035f838..9554c93d5a0c53 100644 --- a/api/core/tools/provider/builtin/searxng/searxng.yaml +++ b/api/core/tools/provider/builtin/searxng/searxng.yaml @@ -6,7 +6,7 @@ identity: zh_Hans: SearXNG description: en_US: A free internet metasearch engine. - zh_Hans: 开源互联网元搜索引擎 + zh_Hans: 开源免费的互联网元搜索引擎 icon: icon.svg tags: - search @@ -18,9 +18,6 @@ credentials_for_provider: label: en_US: SearXNG base URL zh_Hans: SearXNG base URL - help: - en_US: Please input your SearXNG base URL - zh_Hans: 请输入您的 SearXNG base URL placeholder: en_US: Please input your SearXNG base URL zh_Hans: 请输入您的 SearXNG base URL diff --git a/api/core/tools/provider/builtin/searxng/tools/searxng_search.py b/api/core/tools/provider/builtin/searxng/tools/searxng_search.py index 5d12553629abfe..dc835a8e8cbd5b 100644 --- a/api/core/tools/provider/builtin/searxng/tools/searxng_search.py +++ b/api/core/tools/provider/builtin/searxng/tools/searxng_search.py @@ -1,4 +1,3 @@ -import json from typing import Any import requests @@ -7,90 +6,11 @@ from core.tools.tool.builtin_tool import BuiltinTool -class SearXNGSearchResults(dict): - """Wrapper for search results.""" - - def __init__(self, data: str): - super().__init__(json.loads(data)) - self.__dict__ = self - - @property - def results(self) -> Any: - return self.get("results", []) - - class SearXNGSearchTool(BuiltinTool): """ Tool for performing a search using SearXNG engine. """ - SEARCH_TYPE: dict[str, str] = { - "page": "general", - "news": "news", - "image": "images", - # "video": "videos", - # "file": "files" - } - LINK_FILED: dict[str, str] = { - "page": "url", - "news": "url", - "image": "img_src", - # "video": "iframe_src", - # "file": "magnetlink" - } - TEXT_FILED: dict[str, str] = { - "page": "content", - "news": "content", - "image": "img_src", - # "video": "iframe_src", - # "file": "magnetlink" - } - - def _invoke_query(self, user_id: str, host: str, query: str, search_type: str, result_type: str, topK: int = 5) -> list[dict]: - """Run query and return the results.""" - - search_type = search_type.lower() - if search_type not in self.SEARCH_TYPE.keys(): - search_type= "page" - - response = requests.get(host, params={ - "q": query, - "format": "json", - "categories": self.SEARCH_TYPE[search_type] - }) - - if response.status_code != 200: - raise Exception(f'Error {response.status_code}: {response.text}') - - search_results = SearXNGSearchResults(response.text).results[:topK] - - if result_type == 'link': - results = [] - if search_type == "page" or search_type == "news": - for r in search_results: - results.append(self.create_text_message( - text=f'{r["title"]}: {r.get(self.LINK_FILED[search_type], "")}' - )) - elif search_type == "image": - for r in search_results: - results.append(self.create_image_message( - image=r.get(self.LINK_FILED[search_type], "") - )) - else: - for r in search_results: - results.append(self.create_link_message( - link=r.get(self.LINK_FILED[search_type], "") - )) - - return results - else: - text = '' - for i, r in enumerate(search_results): - text += f'{i+1}: {r["title"]} - {r.get(self.TEXT_FILED[search_type], "")}\n' - - return self.create_text_message(text=self.summary(user_id=user_id, content=text)) - - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: """ Invoke the SearXNG search tool. @@ -103,23 +23,21 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMe ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation. """ - host = self.runtime.credentials.get('searxng_base_url', None) + host = self.runtime.credentials.get('searxng_base_url') if not host: raise Exception('SearXNG api is required') - - query = tool_parameters.get('query') - if not query: - return self.create_text_message('Please input query') - - num_results = min(tool_parameters.get('num_results', 5), 20) - search_type = tool_parameters.get('search_type', 'page') or 'page' - result_type = tool_parameters.get('result_type', 'text') or 'text' - return self._invoke_query( - user_id=user_id, - host=host, - query=query, - search_type=search_type, - result_type=result_type, - topK=num_results - ) + response = requests.get(host, params={ + "q": tool_parameters.get('query'), + "format": "json", + "categories": tool_parameters.get('search_type', 'general') + }) + + if response.status_code != 200: + raise Exception(f'Error {response.status_code}: {response.text}') + + res = response.json().get("results", []) + if not res: + return self.create_text_message(f"No results found, get response: {response.content}") + + return [self.create_json_message(item) for item in res] diff --git a/api/core/tools/provider/builtin/searxng/tools/searxng_search.yaml b/api/core/tools/provider/builtin/searxng/tools/searxng_search.yaml index 0edf1744f4b2f4..a5e448a30375b4 100644 --- a/api/core/tools/provider/builtin/searxng/tools/searxng_search.yaml +++ b/api/core/tools/provider/builtin/searxng/tools/searxng_search.yaml @@ -1,13 +1,13 @@ identity: name: searxng_search - author: Tice + author: Junytang label: en_US: SearXNG Search zh_Hans: SearXNG 搜索 description: human: - en_US: Perform searches on SearXNG and get results. - zh_Hans: 在 SearXNG 上进行搜索并获取结果。 + en_US: SearXNG is a free internet metasearch engine which aggregates results from more than 70 search services. + zh_Hans: SearXNG 是一个免费的互联网元搜索引擎,它从70多个不同的搜索服务中聚合搜索结果。 llm: Perform searches on SearXNG and get results. parameters: - name: query @@ -16,9 +16,6 @@ parameters: label: en_US: Query string zh_Hans: 查询语句 - human_description: - en_US: The search query. - zh_Hans: 搜索查询语句。 llm_description: Key words for searching form: llm - name: search_type @@ -27,63 +24,46 @@ parameters: label: en_US: search type zh_Hans: 搜索类型 - pt_BR: search type - human_description: - en_US: search type for page, news or image. - zh_Hans: 选择搜索的类型:网页,新闻,图片。 - pt_BR: search type for page, news or image. - default: Page + default: general options: - - value: Page + - value: general label: - en_US: Page - zh_Hans: 网页 - pt_BR: Page - - value: News + en_US: General + zh_Hans: 综合 + - value: images + label: + en_US: Images + zh_Hans: 图片 + - value: videos + label: + en_US: Videos + zh_Hans: 视频 + - value: news label: en_US: News zh_Hans: 新闻 - pt_BR: News - - value: Image + - value: map label: - en_US: Image - zh_Hans: 图片 - pt_BR: Image - form: form - - name: num_results - type: number - required: true - label: - en_US: Number of query results - zh_Hans: 返回查询数量 - human_description: - en_US: The number of query results. - zh_Hans: 返回查询结果的数量。 - form: form - default: 5 - min: 1 - max: 20 - - name: result_type - type: select - required: true - label: - en_US: result type - zh_Hans: 结果类型 - pt_BR: result type - human_description: - en_US: return a list of links or texts. - zh_Hans: 返回一个连接列表还是纯文本内容。 - pt_BR: return a list of links or texts. - default: text - options: - - value: link + en_US: Map + zh_Hans: 地图 + - value: music + label: + en_US: Music + zh_Hans: 音乐 + - value: it + label: + en_US: It + zh_Hans: 信息技术 + - value: science + label: + en_US: Science + zh_Hans: 科学 + - value: files label: - en_US: Link - zh_Hans: 链接 - pt_BR: Link - - value: text + en_US: Files + zh_Hans: 文件 + - value: social_media label: - en_US: Text - zh_Hans: 文本 - pt_BR: Text + en_US: Social Media + zh_Hans: 社交媒体 form: form diff --git a/api/core/tools/tool/tool.py b/api/core/tools/tool/tool.py index 5d561911d12564..d990131b5fbbfd 100644 --- a/api/core/tools/tool/tool.py +++ b/api/core/tools/tool/tool.py @@ -2,13 +2,12 @@ from collections.abc import Mapping from copy import deepcopy from enum import Enum -from typing import Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union from pydantic import BaseModel, ConfigDict, field_validator from pydantic_core.core_schema import ValidationInfo from core.app.entities.app_invoke_entities import InvokeFrom -from core.file.file_obj import FileVar from core.tools.entities.tool_entities import ( ToolDescription, ToolIdentity, @@ -23,6 +22,9 @@ from core.tools.tool_file_manager import ToolFileManager from core.tools.utils.tool_parameter_converter import ToolParameterConverter +if TYPE_CHECKING: + from core.file.file_obj import FileVar + class Tool(BaseModel, ABC): identity: Optional[ToolIdentity] = None @@ -76,7 +78,7 @@ def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'Tool': description=self.description.model_copy() if self.description else None, runtime=Tool.Runtime(**runtime), ) - + @abstractmethod def tool_provider_type(self) -> ToolProviderType: """ @@ -84,7 +86,7 @@ def tool_provider_type(self) -> ToolProviderType: :return: the tool provider type """ - + def load_variables(self, variables: ToolRuntimeVariablePool): """ load variables from database @@ -99,7 +101,7 @@ def set_image_variable(self, variable_name: str, image_key: str) -> None: """ if not self.variables: return - + self.variables.set_file(self.identity.name, variable_name, image_key) def set_text_variable(self, variable_name: str, text: str) -> None: @@ -108,9 +110,9 @@ def set_text_variable(self, variable_name: str, text: str) -> None: """ if not self.variables: return - + self.variables.set_text(self.identity.name, variable_name, text) - + def get_variable(self, name: Union[str, Enum]) -> Optional[ToolRuntimeVariable]: """ get a variable @@ -120,14 +122,14 @@ def get_variable(self, name: Union[str, Enum]) -> Optional[ToolRuntimeVariable]: """ if not self.variables: return None - + if isinstance(name, Enum): name = name.value - + for variable in self.variables.pool: if variable.name == name: return variable - + return None def get_default_image_variable(self) -> Optional[ToolRuntimeVariable]: @@ -138,9 +140,9 @@ def get_default_image_variable(self) -> Optional[ToolRuntimeVariable]: """ if not self.variables: return None - + return self.get_variable(self.VARIABLE_KEY.IMAGE) - + def get_variable_file(self, name: Union[str, Enum]) -> Optional[bytes]: """ get a variable file @@ -151,7 +153,7 @@ def get_variable_file(self, name: Union[str, Enum]) -> Optional[bytes]: variable = self.get_variable(name) if not variable: return None - + if not isinstance(variable, ToolRuntimeImageVariable): return None @@ -160,9 +162,9 @@ def get_variable_file(self, name: Union[str, Enum]) -> Optional[bytes]: file_binary = ToolFileManager.get_file_binary_by_message_file_id(message_file_id) if not file_binary: return None - + return file_binary[0] - + def list_variables(self) -> list[ToolRuntimeVariable]: """ list all variables @@ -171,9 +173,9 @@ def list_variables(self) -> list[ToolRuntimeVariable]: """ if not self.variables: return [] - + return self.variables.pool - + def list_default_image_variables(self) -> list[ToolRuntimeVariable]: """ list all image variables @@ -182,9 +184,9 @@ def list_default_image_variables(self) -> list[ToolRuntimeVariable]: """ if not self.variables: return [] - + result = [] - + for variable in self.variables.pool: if variable.name.startswith(self.VARIABLE_KEY.IMAGE.value): result.append(variable) @@ -225,7 +227,7 @@ def _transform_tool_parameters_type(self, tool_parameters: Mapping[str, Any]) -> @abstractmethod def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: pass - + def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None: """ validate the credentials @@ -244,7 +246,7 @@ def get_runtime_parameters(self) -> list[ToolParameter]: :return: the runtime parameters """ return self.parameters or [] - + def get_all_runtime_parameters(self) -> list[ToolParameter]: """ get all runtime parameters @@ -278,7 +280,7 @@ def get_all_runtime_parameters(self) -> list[ToolParameter]: parameters.append(parameter) return parameters - + def create_image_message(self, image: str, save_as: str = '') -> ToolInvokeMessage: """ create an image message @@ -286,18 +288,18 @@ def create_image_message(self, image: str, save_as: str = '') -> ToolInvokeMessa :param image: the url of the image :return: the image message """ - return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE, - message=image, + return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE, + message=image, save_as=save_as) - - def create_file_var_message(self, file_var: FileVar) -> ToolInvokeMessage: + + def create_file_var_message(self, file_var: "FileVar") -> ToolInvokeMessage: return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.FILE_VAR, message='', meta={ 'file_var': file_var }, save_as='') - + def create_link_message(self, link: str, save_as: str = '') -> ToolInvokeMessage: """ create a link message @@ -305,10 +307,10 @@ def create_link_message(self, link: str, save_as: str = '') -> ToolInvokeMessage :param link: the url of the link :return: the link message """ - return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.LINK, - message=link, + return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.LINK, + message=link, save_as=save_as) - + def create_text_message(self, text: str, save_as: str = '') -> ToolInvokeMessage: """ create a text message @@ -321,7 +323,7 @@ def create_text_message(self, text: str, save_as: str = '') -> ToolInvokeMessage message=text, save_as=save_as ) - + def create_blob_message(self, blob: bytes, meta: dict = None, save_as: str = '') -> ToolInvokeMessage: """ create a blob message diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index ef9e5b67ae2ab6..564b9d3e14c15e 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -1,7 +1,7 @@ import logging from mimetypes import guess_extension -from core.file.file_obj import FileTransferMethod, FileType, FileVar +from core.file.file_obj import FileTransferMethod, FileType from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool_file_manager import ToolFileManager @@ -27,12 +27,12 @@ def transform_tool_invoke_messages(cls, messages: list[ToolInvokeMessage], # try to download image try: file = ToolFileManager.create_file_by_url( - user_id=user_id, + user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, file_url=message.message ) - + url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".png"}' result.append(ToolInvokeMessage( @@ -55,14 +55,14 @@ def transform_tool_invoke_messages(cls, messages: list[ToolInvokeMessage], # if message is str, encode it to bytes if isinstance(message.message, str): message.message = message.message.encode('utf-8') - + file = ToolFileManager.create_file_by_raw( user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, file_binary=message.message, mimetype=mimetype ) - + url = cls.get_tool_file_url(file.id, guess_extension(file.mimetype)) # check if file is image @@ -81,7 +81,7 @@ def transform_tool_invoke_messages(cls, messages: list[ToolInvokeMessage], meta=message.meta.copy() if message.meta is not None else {}, )) elif message.type == ToolInvokeMessage.MessageType.FILE_VAR: - file_var: FileVar = message.meta.get('file_var') + file_var = message.meta.get('file_var') if file_var: if file_var.transfer_method == FileTransferMethod.TOOL_FILE: url = cls.get_tool_file_url(file_var.related_id, file_var.extension) @@ -103,7 +103,7 @@ def transform_tool_invoke_messages(cls, messages: list[ToolInvokeMessage], result.append(message) return result - + @classmethod def get_tool_file_url(cls, tool_file_id: str, extension: str) -> str: - return f'/files/tools/{tool_file_id}{extension or ".bin"}' \ No newline at end of file + return f'/files/tools/{tool_file_id}{extension or ".bin"}' diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py index 996aae94c20e27..025453567bfc1b 100644 --- a/api/core/workflow/entities/node_entities.py +++ b/api/core/workflow/entities/node_entities.py @@ -4,13 +4,14 @@ from pydantic import BaseModel -from models.workflow import WorkflowNodeExecutionStatus +from models import WorkflowNodeExecutionStatus class NodeType(Enum): """ Node Types. """ + START = 'start' END = 'end' ANSWER = 'answer' @@ -23,10 +24,12 @@ class NodeType(Enum): HTTP_REQUEST = 'http-request' TOOL = 'tool' VARIABLE_AGGREGATOR = 'variable-aggregator' + # TODO: merge this into VARIABLE_AGGREGATOR VARIABLE_ASSIGNER = 'variable-assigner' LOOP = 'loop' ITERATION = 'iteration' PARAMETER_EXTRACTOR = 'parameter-extractor' + CONVERSATION_VARIABLE_ASSIGNER = 'assigner' @classmethod def value_of(cls, value: str) -> 'NodeType': @@ -42,33 +45,11 @@ def value_of(cls, value: str) -> 'NodeType': raise ValueError(f'invalid node type value {value}') -class SystemVariable(Enum): - """ - System Variables. - """ - QUERY = 'query' - FILES = 'files' - CONVERSATION_ID = 'conversation_id' - USER_ID = 'user_id' - - @classmethod - def value_of(cls, value: str) -> 'SystemVariable': - """ - Get value of given system variable. - - :param value: system variable value - :return: system variable - """ - for system_variable in cls: - if system_variable.value == value: - return system_variable - raise ValueError(f'invalid system variable value {value}') - - class NodeRunMetadataKey(Enum): """ Node Run Metadata Key. """ + TOTAL_TOKENS = 'total_tokens' TOTAL_PRICE = 'total_price' CURRENCY = 'currency' @@ -81,6 +62,7 @@ class NodeRunResult(BaseModel): """ Node Run Result. """ + status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING inputs: Optional[Mapping[str, Any]] = None # node inputs diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index a27b4261e486bc..9fe3356faa2ef5 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -6,13 +6,14 @@ from core.app.segments import Segment, Variable, factory from core.file.file_obj import FileVar -from core.workflow.entities.node_entities import SystemVariable +from core.workflow.enums import SystemVariable VariableValue = Union[str, int, float, dict, list, FileVar] SYSTEM_VARIABLE_NODE_ID = 'sys' ENVIRONMENT_VARIABLE_NODE_ID = 'env' +CONVERSATION_VARIABLE_NODE_ID = 'conversation' class VariablePool: @@ -21,6 +22,7 @@ def __init__( system_variables: Mapping[SystemVariable, Any], user_inputs: Mapping[str, Any], environment_variables: Sequence[Variable], + conversation_variables: Sequence[Variable] | None = None, ) -> None: # system variables # for example: @@ -44,9 +46,13 @@ def __init__( self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value) # Add environment variables to the variable pool - for var in environment_variables or []: + for var in environment_variables: self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var) + # Add conversation variables to the variable pool + for var in conversation_variables or []: + self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var) + def add(self, selector: Sequence[str], value: Any, /) -> None: """ Adds a variable to the variable pool. diff --git a/api/core/workflow/enums.py b/api/core/workflow/enums.py new file mode 100644 index 00000000000000..4757cf32f88988 --- /dev/null +++ b/api/core/workflow/enums.py @@ -0,0 +1,25 @@ +from enum import Enum + + +class SystemVariable(str, Enum): + """ + System Variables. + """ + QUERY = 'query' + FILES = 'files' + CONVERSATION_ID = 'conversation_id' + USER_ID = 'user_id' + DIALOGUE_COUNT = 'dialogue_count' + + @classmethod + def value_of(cls, value: str): + """ + Get value of given system variable. + + :param value: system variable value + :return: system variable + """ + for system_variable in cls: + if system_variable.value == value: + return system_variable + raise ValueError(f'invalid system variable value {value}') diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py index d8c812e7ef1244..3d9cf52771e146 100644 --- a/api/core/workflow/nodes/base_node.py +++ b/api/core/workflow/nodes/base_node.py @@ -8,6 +8,7 @@ from core.workflow.entities.base_node_data_entities import BaseIterationState, BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool +from models import WorkflowNodeExecutionStatus class UserFrom(Enum): @@ -91,14 +92,19 @@ def run(self, variable_pool: VariablePool) -> NodeRunResult: :param variable_pool: variable pool :return: """ - result = self._run( - variable_pool=variable_pool - ) - - self.node_run_result = result - return result - - def publish_text_chunk(self, text: str, value_selector: list[str] = None) -> None: + try: + result = self._run( + variable_pool=variable_pool + ) + self.node_run_result = result + return result + except Exception as e: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e), + ) + + def publish_text_chunk(self, text: str, value_selector: list[str] | None = None) -> None: """ Publish text chunk :param text: chunk text diff --git a/api/core/workflow/nodes/http_request/http_request_node.py b/api/core/workflow/nodes/http_request/http_request_node.py index bbe5f9ad43f561..1facf8a4f4a4b5 100644 --- a/api/core/workflow/nodes/http_request/http_request_node.py +++ b/api/core/workflow/nodes/http_request/http_request_node.py @@ -133,9 +133,6 @@ def extract_files(self, url: str, response: HttpExecutorResponse) -> list[FileVa """ files = [] mimetype, file_binary = response.extract_file() - # if not image, return directly - if 'image' not in mimetype: - return files if mimetype: # extract filename from url diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py index 4431259a57543b..c20e0d45062f51 100644 --- a/api/core/workflow/nodes/llm/llm_node.py +++ b/api/core/workflow/nodes/llm/llm_node.py @@ -1,14 +1,13 @@ import json from collections.abc import Generator from copy import deepcopy -from typing import Optional, cast +from typing import TYPE_CHECKING, Optional, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.app.entities.queue_entities import QueueRetrieverResourcesEvent from core.entities.model_entities import ModelStatus from core.entities.provider_entities import QuotaUnit from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from core.file.file_obj import FileVar from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.llm_entities import LLMUsage @@ -23,8 +22,9 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig from core.prompt.utils.prompt_message_util import PromptMessageUtil -from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariable from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.llm.entities import ( LLMNodeChatModelMessage, @@ -38,6 +38,10 @@ from models.provider import Provider, ProviderType from models.workflow import WorkflowNodeExecutionStatus +if TYPE_CHECKING: + from core.file.file_obj import FileVar + + class LLMNode(BaseNode): _node_data_cls = LLMNodeData @@ -70,7 +74,7 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: node_inputs = {} # fetch files - files: list[FileVar] = self._fetch_files(node_data, variable_pool) + files = self._fetch_files(node_data, variable_pool) if files: node_inputs['#files#'] = [file.to_dict() for file in files] @@ -201,8 +205,8 @@ def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage usage = LLMUsage.empty_usage() return full_text, usage - - def _transform_chat_messages(self, + + def _transform_chat_messages(self, messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate: """ @@ -249,13 +253,13 @@ def parse_dict(d: dict) -> str: # check if it's a context structure if 'metadata' in d and '_source' in d['metadata'] and 'content' in d: return d['content'] - + # else, parse the dict try: return json.dumps(d, ensure_ascii=False) except Exception: return str(d) - + if isinstance(value, str): value = value elif isinstance(value, list): @@ -321,7 +325,7 @@ def _fetch_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> return inputs - def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list[FileVar]: + def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list["FileVar"]: """ Fetch files :param node_data: node data @@ -520,7 +524,7 @@ def _fetch_prompt_messages(self, node_data: LLMNodeData, query: Optional[str], query_prompt_template: Optional[str], inputs: dict[str, str], - files: list[FileVar], + files: list["FileVar"], context: Optional[str], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity) \ diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 87bfa5beae880b..554e3b6074ed58 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -2,19 +2,20 @@ from os import path from typing import Any, cast -from core.app.segments import parser +from core.app.segments import ArrayAnyVariable, parser from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler from core.file.file_obj import FileTransferMethod, FileType, FileVar from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter from core.tools.tool_engine import ToolEngine from core.tools.tool_manager import ToolManager from core.tools.utils.message_transformer import ToolFileMessageTransformer -from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariable from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.tool.entities import ToolNodeData from core.workflow.utils.variable_template_parser import VariableTemplateParser -from models.workflow import WorkflowNodeExecutionStatus +from models import WorkflowNodeExecutionStatus class ToolNode(BaseNode): @@ -140,9 +141,9 @@ def _generate_parameters( return result def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]: - # FIXME: ensure this is a ArrayVariable contains FileVariable. variable = variable_pool.get(['sys', SystemVariable.FILES.value]) - return [file_var.value for file_var in variable.value] if variable else [] + assert isinstance(variable, ArrayAnyVariable) + return list(variable.value) if variable else [] def _convert_tool_messages(self, messages: list[ToolInvokeMessage]): """ diff --git a/api/core/workflow/nodes/variable_assigner/__init__.py b/api/core/workflow/nodes/variable_assigner/__init__.py new file mode 100644 index 00000000000000..552cc367f2674f --- /dev/null +++ b/api/core/workflow/nodes/variable_assigner/__init__.py @@ -0,0 +1,109 @@ +from collections.abc import Sequence +from enum import Enum +from typing import Optional, cast + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.app.segments import SegmentType, Variable, factory +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.base_node import BaseNode +from extensions.ext_database import db +from models import ConversationVariable, WorkflowNodeExecutionStatus + + +class VariableAssignerNodeError(Exception): + pass + + +class WriteMode(str, Enum): + OVER_WRITE = 'over-write' + APPEND = 'append' + CLEAR = 'clear' + + +class VariableAssignerData(BaseNodeData): + title: str = 'Variable Assigner' + desc: Optional[str] = 'Assign a value to a variable' + assigned_variable_selector: Sequence[str] + write_mode: WriteMode + input_variable_selector: Sequence[str] + + +class VariableAssignerNode(BaseNode): + _node_data_cls: type[BaseNodeData] = VariableAssignerData + _node_type: NodeType = NodeType.CONVERSATION_VARIABLE_ASSIGNER + + def _run(self, variable_pool: VariablePool) -> NodeRunResult: + data = cast(VariableAssignerData, self.node_data) + + # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject + original_variable = variable_pool.get(data.assigned_variable_selector) + if not isinstance(original_variable, Variable): + raise VariableAssignerNodeError('assigned variable not found') + + match data.write_mode: + case WriteMode.OVER_WRITE: + income_value = variable_pool.get(data.input_variable_selector) + if not income_value: + raise VariableAssignerNodeError('input value not found') + updated_variable = original_variable.model_copy(update={'value': income_value.value}) + + case WriteMode.APPEND: + income_value = variable_pool.get(data.input_variable_selector) + if not income_value: + raise VariableAssignerNodeError('input value not found') + updated_value = original_variable.value + [income_value.value] + updated_variable = original_variable.model_copy(update={'value': updated_value}) + + case WriteMode.CLEAR: + income_value = get_zero_value(original_variable.value_type) + updated_variable = original_variable.model_copy(update={'value': income_value.to_object()}) + + case _: + raise VariableAssignerNodeError(f'unsupported write mode: {data.write_mode}') + + # Over write the variable. + variable_pool.add(data.assigned_variable_selector, updated_variable) + + # Update conversation variable. + # TODO: Find a better way to use the database. + conversation_id = variable_pool.get(['sys', 'conversation_id']) + if not conversation_id: + raise VariableAssignerNodeError('conversation_id not found') + update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable) + + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={ + 'value': income_value.to_object(), + }, + ) + + +def update_conversation_variable(conversation_id: str, variable: Variable): + stmt = select(ConversationVariable).where( + ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id + ) + with Session(db.engine) as session: + row = session.scalar(stmt) + if not row: + raise VariableAssignerNodeError('conversation variable not found in the database') + row.data = variable.model_dump_json() + session.commit() + + +def get_zero_value(t: SegmentType): + match t: + case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER: + return factory.build_segment([]) + case SegmentType.OBJECT: + return factory.build_segment({}) + case SegmentType.STRING: + return factory.build_segment('') + case SegmentType.NUMBER: + return factory.build_segment(0) + case _: + raise VariableAssignerNodeError(f'unsupported variable type: {t}') diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index bd2b3eafa7a8b1..3157eedfee5238 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -3,13 +3,13 @@ from collections.abc import Mapping, Sequence from typing import Any, Optional, cast +import contexts from configs import dify_config -from core.app.app_config.entities import FileExtraConfig from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException from core.app.entities.app_invoke_entities import InvokeFrom -from core.file.file_obj import FileTransferMethod, FileType, FileVar +from core.file.file_obj import FileExtraConfig, FileTransferMethod, FileType, FileVar from core.workflow.callbacks.base_workflow_callback import WorkflowCallback -from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool, VariableValue from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, WorkflowRunState from core.workflow.errors import WorkflowNodeRunFailedError @@ -30,6 +30,7 @@ from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode from core.workflow.nodes.tool.tool_node import ToolNode from core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode +from core.workflow.nodes.variable_assigner import VariableAssignerNode from extensions.ext_database import db from models.workflow import ( Workflow, @@ -51,7 +52,8 @@ NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode, NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, NodeType.ITERATION: IterationNode, - NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode + NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode, + NodeType.CONVERSATION_VARIABLE_ASSIGNER: VariableAssignerNode, } logger = logging.getLogger(__name__) @@ -94,19 +96,18 @@ def run_workflow( user_id: str, user_from: UserFrom, invoke_from: InvokeFrom, - user_inputs: Mapping[str, Any], - system_inputs: Mapping[SystemVariable, Any], callbacks: Sequence[WorkflowCallback], - call_depth: int = 0 + call_depth: int = 0, + variable_pool: VariablePool | None = None, ) -> None: """ :param workflow: Workflow instance :param user_id: user id :param user_from: user from - :param user_inputs: user variables inputs - :param system_inputs: system inputs, like: query, files + :param invoke_from: invoke from :param callbacks: workflow callbacks :param call_depth: call depth + :param variable_pool: variable pool """ # fetch workflow graph graph = workflow.graph_dict @@ -122,18 +123,14 @@ def run_workflow( if not isinstance(graph.get('edges'), list): raise ValueError('edges in workflow graph must be a list') - # init variable pool - variable_pool = VariablePool( - system_variables=system_inputs, - user_inputs=user_inputs, - environment_variables=workflow.environment_variables, - ) workflow_call_max_depth = dify_config.WORKFLOW_CALL_MAX_DEPTH if call_depth > workflow_call_max_depth: raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth)) # init workflow run state + if not variable_pool: + variable_pool = contexts.workflow_variable_pool.get() workflow_run_state = WorkflowRunState( workflow=workflow, start_at=time.perf_counter(), @@ -403,6 +400,7 @@ def single_step_run_workflow_node(self, workflow: Workflow, system_variables={}, user_inputs={}, environment_variables=workflow.environment_variables, + conversation_variables=workflow.conversation_variables, ) if node_cls is None: @@ -468,6 +466,7 @@ def single_step_run_iteration_workflow_node(self, workflow: Workflow, system_variables={}, user_inputs={}, environment_variables=workflow.environment_variables, + conversation_variables=workflow.conversation_variables, ) # variable selector to variable mapping diff --git a/api/events/app_event.py b/api/events/app_event.py index 67a5982527f7b6..f2ce71bbbb3632 100644 --- a/api/events/app_event.py +++ b/api/events/app_event.py @@ -1,13 +1,13 @@ from blinker import signal # sender: app -app_was_created = signal('app-was-created') +app_was_created = signal("app-was-created") # sender: app, kwargs: app_model_config -app_model_config_was_updated = signal('app-model-config-was-updated') +app_model_config_was_updated = signal("app-model-config-was-updated") # sender: app, kwargs: published_workflow -app_published_workflow_was_updated = signal('app-published-workflow-was-updated') +app_published_workflow_was_updated = signal("app-published-workflow-was-updated") # sender: app, kwargs: synced_draft_workflow -app_draft_workflow_was_synced = signal('app-draft-workflow-was-synced') +app_draft_workflow_was_synced = signal("app-draft-workflow-was-synced") diff --git a/api/events/dataset_event.py b/api/events/dataset_event.py index d4a2b6f313c13a..750b7424e2347b 100644 --- a/api/events/dataset_event.py +++ b/api/events/dataset_event.py @@ -1,4 +1,4 @@ from blinker import signal # sender: dataset -dataset_was_deleted = signal('dataset-was-deleted') +dataset_was_deleted = signal("dataset-was-deleted") diff --git a/api/events/document_event.py b/api/events/document_event.py index f95326630b2b7a..2c5a416a5e0c91 100644 --- a/api/events/document_event.py +++ b/api/events/document_event.py @@ -1,4 +1,4 @@ from blinker import signal # sender: document -document_was_deleted = signal('document-was-deleted') +document_was_deleted = signal("document-was-deleted") diff --git a/api/events/event_handlers/clean_when_dataset_deleted.py b/api/events/event_handlers/clean_when_dataset_deleted.py index 42f1c70614c49a..7caa2d1cc9f3f2 100644 --- a/api/events/event_handlers/clean_when_dataset_deleted.py +++ b/api/events/event_handlers/clean_when_dataset_deleted.py @@ -5,5 +5,11 @@ @dataset_was_deleted.connect def handle(sender, **kwargs): dataset = sender - clean_dataset_task.delay(dataset.id, dataset.tenant_id, dataset.indexing_technique, - dataset.index_struct, dataset.collection_binding_id, dataset.doc_form) + clean_dataset_task.delay( + dataset.id, + dataset.tenant_id, + dataset.indexing_technique, + dataset.index_struct, + dataset.collection_binding_id, + dataset.doc_form, + ) diff --git a/api/events/event_handlers/clean_when_document_deleted.py b/api/events/event_handlers/clean_when_document_deleted.py index 24022da15f81ee..00a66f50ad9319 100644 --- a/api/events/event_handlers/clean_when_document_deleted.py +++ b/api/events/event_handlers/clean_when_document_deleted.py @@ -5,7 +5,7 @@ @document_was_deleted.connect def handle(sender, **kwargs): document_id = sender - dataset_id = kwargs.get('dataset_id') - doc_form = kwargs.get('doc_form') - file_id = kwargs.get('file_id') + dataset_id = kwargs.get("dataset_id") + doc_form = kwargs.get("doc_form") + file_id = kwargs.get("file_id") clean_document_task.delay(document_id, dataset_id, doc_form, file_id) diff --git a/api/events/event_handlers/create_document_index.py b/api/events/event_handlers/create_document_index.py index 68dae5a5537cd7..72a135e73d4ca5 100644 --- a/api/events/event_handlers/create_document_index.py +++ b/api/events/event_handlers/create_document_index.py @@ -14,21 +14,25 @@ @document_index_created.connect def handle(sender, **kwargs): dataset_id = sender - document_ids = kwargs.get('document_ids', None) + document_ids = kwargs.get("document_ids", None) documents = [] start_at = time.perf_counter() for document_id in document_ids: - logging.info(click.style('Start process document: {}'.format(document_id), fg='green')) + logging.info(click.style("Start process document: {}".format(document_id), fg="green")) - document = db.session.query(Document).filter( - Document.id == document_id, - Document.dataset_id == dataset_id - ).first() + document = ( + db.session.query(Document) + .filter( + Document.id == document_id, + Document.dataset_id == dataset_id, + ) + .first() + ) if not document: - raise NotFound('Document not found') + raise NotFound("Document not found") - document.indexing_status = 'parsing' + document.indexing_status = "parsing" document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) documents.append(document) db.session.add(document) @@ -38,8 +42,8 @@ def handle(sender, **kwargs): indexing_runner = IndexingRunner() indexing_runner.run(documents) end_at = time.perf_counter() - logging.info(click.style('Processed dataset: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green')) + logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")) except DocumentIsPausedException as ex: - logging.info(click.style(str(ex), fg='yellow')) + logging.info(click.style(str(ex), fg="yellow")) except Exception: pass diff --git a/api/events/event_handlers/create_installed_app_when_app_created.py b/api/events/event_handlers/create_installed_app_when_app_created.py index 31084ce0fe8bdc..57412cc4ad0af2 100644 --- a/api/events/event_handlers/create_installed_app_when_app_created.py +++ b/api/events/event_handlers/create_installed_app_when_app_created.py @@ -10,7 +10,7 @@ def handle(sender, **kwargs): installed_app = InstalledApp( tenant_id=app.tenant_id, app_id=app.id, - app_owner_tenant_id=app.tenant_id + app_owner_tenant_id=app.tenant_id, ) db.session.add(installed_app) db.session.commit() diff --git a/api/events/event_handlers/create_site_record_when_app_created.py b/api/events/event_handlers/create_site_record_when_app_created.py index a3dcda61388826..95b155b3f5560f 100644 --- a/api/events/event_handlers/create_site_record_when_app_created.py +++ b/api/events/event_handlers/create_site_record_when_app_created.py @@ -7,7 +7,7 @@ def handle(sender, **kwargs): """Create site record when an app is created.""" app = sender - account = kwargs.get('account') + account = kwargs.get("account") site = Site( app_id=app.id, title=app.name, @@ -15,8 +15,8 @@ def handle(sender, **kwargs): icon = app.icon, icon_background = app.icon_background, default_language=account.interface_language, - customize_token_strategy='not_allow', - code=Site.generate_code(16) + customize_token_strategy="not_allow", + code=Site.generate_code(16), ) db.session.add(site) diff --git a/api/events/event_handlers/deduct_quota_when_messaeg_created.py b/api/events/event_handlers/deduct_quota_when_messaeg_created.py index 8cf52bf8f5d0b8..843a2320968ced 100644 --- a/api/events/event_handlers/deduct_quota_when_messaeg_created.py +++ b/api/events/event_handlers/deduct_quota_when_messaeg_created.py @@ -8,7 +8,7 @@ @message_was_created.connect def handle(sender, **kwargs): message = sender - application_generate_entity = kwargs.get('application_generate_entity') + application_generate_entity = kwargs.get("application_generate_entity") if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity): return @@ -39,7 +39,7 @@ def handle(sender, **kwargs): elif quota_unit == QuotaUnit.CREDITS: used_quota = 1 - if 'gpt-4' in model_config.model: + if "gpt-4" in model_config.model: used_quota = 20 else: used_quota = 1 @@ -50,6 +50,6 @@ def handle(sender, **kwargs): Provider.provider_name == model_config.provider, Provider.provider_type == ProviderType.SYSTEM.value, Provider.quota_type == system_configuration.current_quota_type.value, - Provider.quota_limit > Provider.quota_used - ).update({'quota_used': Provider.quota_used + used_quota}) + Provider.quota_limit > Provider.quota_used, + ).update({"quota_used": Provider.quota_used + used_quota}) db.session.commit() diff --git a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py index 1f6da34ee24d56..f96bb5ef74b62e 100644 --- a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py +++ b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py @@ -8,8 +8,8 @@ @app_draft_workflow_was_synced.connect def handle(sender, **kwargs): app = sender - for node_data in kwargs.get('synced_draft_workflow').graph_dict.get('nodes', []): - if node_data.get('data', {}).get('type') == NodeType.TOOL.value: + for node_data in kwargs.get("synced_draft_workflow").graph_dict.get("nodes", []): + if node_data.get("data", {}).get("type") == NodeType.TOOL.value: try: tool_entity = ToolEntity(**node_data["data"]) tool_runtime = ToolManager.get_tool_runtime( @@ -23,7 +23,7 @@ def handle(sender, **kwargs): tool_runtime=tool_runtime, provider_name=tool_entity.provider_name, provider_type=tool_entity.provider_type, - identity_id=f'WORKFLOW.{app.id}.{node_data.get("id")}' + identity_id=f'WORKFLOW.{app.id}.{node_data.get("id")}', ) manager.delete_tool_parameters_cache() except: diff --git a/api/events/event_handlers/document_index_event.py b/api/events/event_handlers/document_index_event.py index 9c4e055debdd9e..3d463fe5b35acf 100644 --- a/api/events/event_handlers/document_index_event.py +++ b/api/events/event_handlers/document_index_event.py @@ -1,4 +1,4 @@ from blinker import signal # sender: document -document_index_created = signal('document-index-created') +document_index_created = signal("document-index-created") diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py index 2b202c53d0b883..59375b1a0b1a81 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py @@ -7,13 +7,11 @@ @app_model_config_was_updated.connect def handle(sender, **kwargs): app = sender - app_model_config = kwargs.get('app_model_config') + app_model_config = kwargs.get("app_model_config") dataset_ids = get_dataset_ids_from_model_config(app_model_config) - app_dataset_joins = db.session.query(AppDatasetJoin).filter( - AppDatasetJoin.app_id == app.id - ).all() + app_dataset_joins = db.session.query(AppDatasetJoin).filter(AppDatasetJoin.app_id == app.id).all() removed_dataset_ids = [] if not app_dataset_joins: @@ -29,16 +27,12 @@ def handle(sender, **kwargs): if removed_dataset_ids: for dataset_id in removed_dataset_ids: db.session.query(AppDatasetJoin).filter( - AppDatasetJoin.app_id == app.id, - AppDatasetJoin.dataset_id == dataset_id + AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id ).delete() if added_dataset_ids: for dataset_id in added_dataset_ids: - app_dataset_join = AppDatasetJoin( - app_id=app.id, - dataset_id=dataset_id - ) + app_dataset_join = AppDatasetJoin(app_id=app.id, dataset_id=dataset_id) db.session.add(app_dataset_join) db.session.commit() @@ -51,7 +45,7 @@ def get_dataset_ids_from_model_config(app_model_config: AppModelConfig) -> set: agent_mode = app_model_config.agent_mode_dict - tools = agent_mode.get('tools', []) or [] + tools = agent_mode.get("tools", []) or [] for tool in tools: if len(list(tool.keys())) != 1: continue @@ -63,11 +57,11 @@ def get_dataset_ids_from_model_config(app_model_config: AppModelConfig) -> set: # get dataset from dataset_configs dataset_configs = app_model_config.dataset_configs_dict - datasets = dataset_configs.get('datasets', {}) or {} - for dataset in datasets.get('datasets', []) or []: + datasets = dataset_configs.get("datasets", {}) or {} + for dataset in datasets.get("datasets", []) or []: keys = list(dataset.keys()) - if len(keys) == 1 and keys[0] == 'dataset': - if dataset['dataset'].get('id'): - dataset_ids.add(dataset['dataset'].get('id')) + if len(keys) == 1 and keys[0] == "dataset": + if dataset["dataset"].get("id"): + dataset_ids.add(dataset["dataset"].get("id")) return dataset_ids diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py index 996b1e96910b93..333b85ecb2907a 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py @@ -11,13 +11,11 @@ @app_published_workflow_was_updated.connect def handle(sender, **kwargs): app = sender - published_workflow = kwargs.get('published_workflow') + published_workflow = kwargs.get("published_workflow") published_workflow = cast(Workflow, published_workflow) dataset_ids = get_dataset_ids_from_workflow(published_workflow) - app_dataset_joins = db.session.query(AppDatasetJoin).filter( - AppDatasetJoin.app_id == app.id - ).all() + app_dataset_joins = db.session.query(AppDatasetJoin).filter(AppDatasetJoin.app_id == app.id).all() removed_dataset_ids = [] if not app_dataset_joins: @@ -33,16 +31,12 @@ def handle(sender, **kwargs): if removed_dataset_ids: for dataset_id in removed_dataset_ids: db.session.query(AppDatasetJoin).filter( - AppDatasetJoin.app_id == app.id, - AppDatasetJoin.dataset_id == dataset_id + AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id ).delete() if added_dataset_ids: for dataset_id in added_dataset_ids: - app_dataset_join = AppDatasetJoin( - app_id=app.id, - dataset_id=dataset_id - ) + app_dataset_join = AppDatasetJoin(app_id=app.id, dataset_id=dataset_id) db.session.add(app_dataset_join) db.session.commit() @@ -54,18 +48,19 @@ def get_dataset_ids_from_workflow(published_workflow: Workflow) -> set: if not graph: return dataset_ids - nodes = graph.get('nodes', []) + nodes = graph.get("nodes", []) # fetch all knowledge retrieval nodes - knowledge_retrieval_nodes = [node for node in nodes - if node.get('data', {}).get('type') == NodeType.KNOWLEDGE_RETRIEVAL.value] + knowledge_retrieval_nodes = [ + node for node in nodes if node.get("data", {}).get("type") == NodeType.KNOWLEDGE_RETRIEVAL.value + ] if not knowledge_retrieval_nodes: return dataset_ids for node in knowledge_retrieval_nodes: try: - node_data = KnowledgeRetrievalNodeData(**node.get('data', {})) + node_data = KnowledgeRetrievalNodeData(**node.get("data", {})) dataset_ids.update(node_data.dataset_ids) except Exception as e: continue diff --git a/api/events/event_handlers/update_provider_last_used_at_when_messaeg_created.py b/api/events/event_handlers/update_provider_last_used_at_when_messaeg_created.py index 6188f1a0850ac4..a80572c0debb1a 100644 --- a/api/events/event_handlers/update_provider_last_used_at_when_messaeg_created.py +++ b/api/events/event_handlers/update_provider_last_used_at_when_messaeg_created.py @@ -9,13 +9,13 @@ @message_was_created.connect def handle(sender, **kwargs): message = sender - application_generate_entity = kwargs.get('application_generate_entity') + application_generate_entity = kwargs.get("application_generate_entity") if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity): return db.session.query(Provider).filter( Provider.tenant_id == application_generate_entity.app_config.tenant_id, - Provider.provider_name == application_generate_entity.model_conf.provider - ).update({'last_used': datetime.now(timezone.utc).replace(tzinfo=None)}) + Provider.provider_name == application_generate_entity.model_conf.provider, + ).update({"last_used": datetime.now(timezone.utc).replace(tzinfo=None)}) db.session.commit() diff --git a/api/events/message_event.py b/api/events/message_event.py index 21da83f2496af5..6576c35c453c95 100644 --- a/api/events/message_event.py +++ b/api/events/message_event.py @@ -1,4 +1,4 @@ from blinker import signal # sender: message, kwargs: conversation -message_was_created = signal('message-was-created') +message_was_created = signal("message-was-created") diff --git a/api/events/tenant_event.py b/api/events/tenant_event.py index 942f709917bf61..d99feaac40896d 100644 --- a/api/events/tenant_event.py +++ b/api/events/tenant_event.py @@ -1,7 +1,7 @@ from blinker import signal # sender: tenant -tenant_was_created = signal('tenant-was-created') +tenant_was_created = signal("tenant-was-created") # sender: tenant -tenant_was_updated = signal('tenant-was-updated') +tenant_was_updated = signal("tenant-was-updated") diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index ae9a07534084e7..f5ec7c1759cb9d 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -17,7 +17,7 @@ def __call__(self, *args: object, **kwargs: object) -> object: backend=app.config["CELERY_BACKEND"], task_ignore_result=True, ) - + # Add SSL options to the Celery configuration ssl_options = { "ssl_cert_reqs": None, @@ -35,7 +35,7 @@ def __call__(self, *args: object, **kwargs: object) -> object: celery_app.conf.update( broker_use_ssl=ssl_options, # Add the SSL options to the broker configuration ) - + celery_app.set_default() app.extensions["celery"] = celery_app @@ -45,18 +45,15 @@ def __call__(self, *args: object, **kwargs: object) -> object: ] day = app.config["CELERY_BEAT_SCHEDULER_TIME"] beat_schedule = { - 'clean_embedding_cache_task': { - 'task': 'schedule.clean_embedding_cache_task.clean_embedding_cache_task', - 'schedule': timedelta(days=day), + "clean_embedding_cache_task": { + "task": "schedule.clean_embedding_cache_task.clean_embedding_cache_task", + "schedule": timedelta(days=day), + }, + "clean_unused_datasets_task": { + "task": "schedule.clean_unused_datasets_task.clean_unused_datasets_task", + "schedule": timedelta(days=day), }, - 'clean_unused_datasets_task': { - 'task': 'schedule.clean_unused_datasets_task.clean_unused_datasets_task', - 'schedule': timedelta(days=day), - } } - celery_app.conf.update( - beat_schedule=beat_schedule, - imports=imports - ) + celery_app.conf.update(beat_schedule=beat_schedule, imports=imports) return celery_app diff --git a/api/extensions/ext_compress.py b/api/extensions/ext_compress.py index 1dbaffcfb0dc27..38e67749fcc994 100644 --- a/api/extensions/ext_compress.py +++ b/api/extensions/ext_compress.py @@ -2,15 +2,14 @@ def init_app(app: Flask): - if app.config.get('API_COMPRESSION_ENABLED'): + if app.config.get("API_COMPRESSION_ENABLED"): from flask_compress import Compress - app.config['COMPRESS_MIMETYPES'] = [ - 'application/json', - 'image/svg+xml', - 'text/html', + app.config["COMPRESS_MIMETYPES"] = [ + "application/json", + "image/svg+xml", + "text/html", ] compress = Compress() compress.init_app(app) - diff --git a/api/extensions/ext_database.py b/api/extensions/ext_database.py index c248e173a252c7..f6ffa536343afc 100644 --- a/api/extensions/ext_database.py +++ b/api/extensions/ext_database.py @@ -2,11 +2,11 @@ from sqlalchemy import MetaData POSTGRES_INDEXES_NAMING_CONVENTION = { - 'ix': '%(column_0_label)s_idx', - 'uq': '%(table_name)s_%(column_0_name)s_key', - 'ck': '%(table_name)s_%(constraint_name)s_check', - 'fk': '%(table_name)s_%(column_0_name)s_fkey', - 'pk': '%(table_name)s_pkey', + "ix": "%(column_0_label)s_idx", + "uq": "%(table_name)s_%(column_0_name)s_key", + "ck": "%(table_name)s_%(constraint_name)s_check", + "fk": "%(table_name)s_%(column_0_name)s_fkey", + "pk": "%(table_name)s_pkey", } metadata = MetaData(naming_convention=POSTGRES_INDEXES_NAMING_CONVENTION) diff --git a/api/extensions/ext_mail.py b/api/extensions/ext_mail.py index ec3a5cc112b871..b435294abc23ba 100644 --- a/api/extensions/ext_mail.py +++ b/api/extensions/ext_mail.py @@ -14,67 +14,69 @@ def is_inited(self) -> bool: return self._client is not None def init_app(self, app: Flask): - if app.config.get('MAIL_TYPE'): - if app.config.get('MAIL_DEFAULT_SEND_FROM'): - self._default_send_from = app.config.get('MAIL_DEFAULT_SEND_FROM') - - if app.config.get('MAIL_TYPE') == 'resend': - api_key = app.config.get('RESEND_API_KEY') + if app.config.get("MAIL_TYPE"): + if app.config.get("MAIL_DEFAULT_SEND_FROM"): + self._default_send_from = app.config.get("MAIL_DEFAULT_SEND_FROM") + + if app.config.get("MAIL_TYPE") == "resend": + api_key = app.config.get("RESEND_API_KEY") if not api_key: - raise ValueError('RESEND_API_KEY is not set') + raise ValueError("RESEND_API_KEY is not set") - api_url = app.config.get('RESEND_API_URL') + api_url = app.config.get("RESEND_API_URL") if api_url: resend.api_url = api_url resend.api_key = api_key self._client = resend.Emails - elif app.config.get('MAIL_TYPE') == 'smtp': + elif app.config.get("MAIL_TYPE") == "smtp": from libs.smtp import SMTPClient - if not app.config.get('SMTP_SERVER') or not app.config.get('SMTP_PORT'): - raise ValueError('SMTP_SERVER and SMTP_PORT are required for smtp mail type') - if not app.config.get('SMTP_USE_TLS') and app.config.get('SMTP_OPPORTUNISTIC_TLS'): - raise ValueError('SMTP_OPPORTUNISTIC_TLS is not supported without enabling SMTP_USE_TLS') + + if not app.config.get("SMTP_SERVER") or not app.config.get("SMTP_PORT"): + raise ValueError("SMTP_SERVER and SMTP_PORT are required for smtp mail type") + if not app.config.get("SMTP_USE_TLS") and app.config.get("SMTP_OPPORTUNISTIC_TLS"): + raise ValueError("SMTP_OPPORTUNISTIC_TLS is not supported without enabling SMTP_USE_TLS") self._client = SMTPClient( - server=app.config.get('SMTP_SERVER'), - port=app.config.get('SMTP_PORT'), - username=app.config.get('SMTP_USERNAME'), - password=app.config.get('SMTP_PASSWORD'), - _from=app.config.get('MAIL_DEFAULT_SEND_FROM'), - use_tls=app.config.get('SMTP_USE_TLS'), - opportunistic_tls=app.config.get('SMTP_OPPORTUNISTIC_TLS') + server=app.config.get("SMTP_SERVER"), + port=app.config.get("SMTP_PORT"), + username=app.config.get("SMTP_USERNAME"), + password=app.config.get("SMTP_PASSWORD"), + _from=app.config.get("MAIL_DEFAULT_SEND_FROM"), + use_tls=app.config.get("SMTP_USE_TLS"), + opportunistic_tls=app.config.get("SMTP_OPPORTUNISTIC_TLS"), ) else: - raise ValueError('Unsupported mail type {}'.format(app.config.get('MAIL_TYPE'))) + raise ValueError("Unsupported mail type {}".format(app.config.get("MAIL_TYPE"))) else: - logging.warning('MAIL_TYPE is not set') - + logging.warning("MAIL_TYPE is not set") def send(self, to: str, subject: str, html: str, from_: Optional[str] = None): if not self._client: - raise ValueError('Mail client is not initialized') + raise ValueError("Mail client is not initialized") if not from_ and self._default_send_from: from_ = self._default_send_from if not from_: - raise ValueError('mail from is not set') + raise ValueError("mail from is not set") if not to: - raise ValueError('mail to is not set') + raise ValueError("mail to is not set") if not subject: - raise ValueError('mail subject is not set') + raise ValueError("mail subject is not set") if not html: - raise ValueError('mail html is not set') - - self._client.send({ - "from": from_, - "to": to, - "subject": subject, - "html": html - }) + raise ValueError("mail html is not set") + + self._client.send( + { + "from": from_, + "to": to, + "subject": subject, + "html": html, + } + ) def init_app(app: Flask): diff --git a/api/extensions/ext_redis.py b/api/extensions/ext_redis.py index 23d7768d4d0f5a..d5fb162fd8f2fb 100644 --- a/api/extensions/ext_redis.py +++ b/api/extensions/ext_redis.py @@ -6,18 +6,21 @@ def init_app(app): connection_class = Connection - if app.config.get('REDIS_USE_SSL'): + if app.config.get("REDIS_USE_SSL"): connection_class = SSLConnection - redis_client.connection_pool = redis.ConnectionPool(**{ - 'host': app.config.get('REDIS_HOST'), - 'port': app.config.get('REDIS_PORT'), - 'username': app.config.get('REDIS_USERNAME'), - 'password': app.config.get('REDIS_PASSWORD'), - 'db': app.config.get('REDIS_DB'), - 'encoding': 'utf-8', - 'encoding_errors': 'strict', - 'decode_responses': False - }, connection_class=connection_class) + redis_client.connection_pool = redis.ConnectionPool( + **{ + "host": app.config.get("REDIS_HOST"), + "port": app.config.get("REDIS_PORT"), + "username": app.config.get("REDIS_USERNAME"), + "password": app.config.get("REDIS_PASSWORD"), + "db": app.config.get("REDIS_DB"), + "encoding": "utf-8", + "encoding_errors": "strict", + "decode_responses": False, + }, + connection_class=connection_class, + ) - app.extensions['redis'] = redis_client + app.extensions["redis"] = redis_client diff --git a/api/extensions/ext_sentry.py b/api/extensions/ext_sentry.py index f05c10bc08926e..227c6635f0eb11 100644 --- a/api/extensions/ext_sentry.py +++ b/api/extensions/ext_sentry.py @@ -5,16 +5,13 @@ def init_app(app): - if app.config.get('SENTRY_DSN'): + if app.config.get("SENTRY_DSN"): sentry_sdk.init( - dsn=app.config.get('SENTRY_DSN'), - integrations=[ - FlaskIntegration(), - CeleryIntegration() - ], + dsn=app.config.get("SENTRY_DSN"), + integrations=[FlaskIntegration(), CeleryIntegration()], ignore_errors=[HTTPException, ValueError], - traces_sample_rate=app.config.get('SENTRY_TRACES_SAMPLE_RATE', 1.0), - profiles_sample_rate=app.config.get('SENTRY_PROFILES_SAMPLE_RATE', 1.0), - environment=app.config.get('DEPLOY_ENV'), - release=f"dify-{app.config.get('CURRENT_VERSION')}-{app.config.get('COMMIT_SHA')}" + traces_sample_rate=app.config.get("SENTRY_TRACES_SAMPLE_RATE", 1.0), + profiles_sample_rate=app.config.get("SENTRY_PROFILES_SAMPLE_RATE", 1.0), + environment=app.config.get("DEPLOY_ENV"), + release=f"dify-{app.config.get('CURRENT_VERSION')}-{app.config.get('COMMIT_SHA')}", ) diff --git a/api/extensions/ext_storage.py b/api/extensions/ext_storage.py index 38db1c6ce103af..e6c4352577fc3f 100644 --- a/api/extensions/ext_storage.py +++ b/api/extensions/ext_storage.py @@ -17,31 +17,19 @@ def __init__(self): self.storage_runner = None def init_app(self, app: Flask): - storage_type = app.config.get('STORAGE_TYPE') - if storage_type == 's3': - self.storage_runner = S3Storage( - app=app - ) - elif storage_type == 'azure-blob': - self.storage_runner = AzureStorage( - app=app - ) - elif storage_type == 'aliyun-oss': - self.storage_runner = AliyunStorage( - app=app - ) - elif storage_type == 'google-storage': - self.storage_runner = GoogleStorage( - app=app - ) - elif storage_type == 'tencent-cos': - self.storage_runner = TencentStorage( - app=app - ) - elif storage_type == 'oci-storage': - self.storage_runner = OCIStorage( - app=app - ) + storage_type = app.config.get("STORAGE_TYPE") + if storage_type == "s3": + self.storage_runner = S3Storage(app=app) + elif storage_type == "azure-blob": + self.storage_runner = AzureStorage(app=app) + elif storage_type == "aliyun-oss": + self.storage_runner = AliyunStorage(app=app) + elif storage_type == "google-storage": + self.storage_runner = GoogleStorage(app=app) + elif storage_type == "tencent-cos": + self.storage_runner = TencentStorage(app=app) + elif storage_type == "oci-storage": + self.storage_runner = OCIStorage(app=app) else: self.storage_runner = LocalStorage(app=app) diff --git a/api/extensions/storage/aliyun_storage.py b/api/extensions/storage/aliyun_storage.py index b81a8691f15457..b962cedc55178d 100644 --- a/api/extensions/storage/aliyun_storage.py +++ b/api/extensions/storage/aliyun_storage.py @@ -8,23 +8,22 @@ class AliyunStorage(BaseStorage): - """Implementation for aliyun storage. - """ + """Implementation for aliyun storage.""" def __init__(self, app: Flask): super().__init__(app) app_config = self.app.config - self.bucket_name = app_config.get('ALIYUN_OSS_BUCKET_NAME') + self.bucket_name = app_config.get("ALIYUN_OSS_BUCKET_NAME") oss_auth_method = aliyun_s3.Auth region = None - if app_config.get('ALIYUN_OSS_AUTH_VERSION') == 'v4': + if app_config.get("ALIYUN_OSS_AUTH_VERSION") == "v4": oss_auth_method = aliyun_s3.AuthV4 - region = app_config.get('ALIYUN_OSS_REGION') - oss_auth = oss_auth_method(app_config.get('ALIYUN_OSS_ACCESS_KEY'), app_config.get('ALIYUN_OSS_SECRET_KEY')) + region = app_config.get("ALIYUN_OSS_REGION") + oss_auth = oss_auth_method(app_config.get("ALIYUN_OSS_ACCESS_KEY"), app_config.get("ALIYUN_OSS_SECRET_KEY")) self.client = aliyun_s3.Bucket( oss_auth, - app_config.get('ALIYUN_OSS_ENDPOINT'), + app_config.get("ALIYUN_OSS_ENDPOINT"), self.bucket_name, connect_timeout=30, region=region, diff --git a/api/extensions/storage/azure_storage.py b/api/extensions/storage/azure_storage.py index af3e7ef84911ff..ca8cbb9188b5c9 100644 --- a/api/extensions/storage/azure_storage.py +++ b/api/extensions/storage/azure_storage.py @@ -9,16 +9,15 @@ class AzureStorage(BaseStorage): - """Implementation for azure storage. - """ + """Implementation for azure storage.""" def __init__(self, app: Flask): super().__init__(app) app_config = self.app.config - self.bucket_name = app_config.get('AZURE_BLOB_CONTAINER_NAME') - self.account_url = app_config.get('AZURE_BLOB_ACCOUNT_URL') - self.account_name = app_config.get('AZURE_BLOB_ACCOUNT_NAME') - self.account_key = app_config.get('AZURE_BLOB_ACCOUNT_KEY') + self.bucket_name = app_config.get("AZURE_BLOB_CONTAINER_NAME") + self.account_url = app_config.get("AZURE_BLOB_ACCOUNT_URL") + self.account_name = app_config.get("AZURE_BLOB_ACCOUNT_NAME") + self.account_key = app_config.get("AZURE_BLOB_ACCOUNT_KEY") def save(self, filename, data): client = self._sync_client() @@ -39,6 +38,7 @@ def generate(filename: str = filename) -> Generator: blob = client.get_blob_client(container=self.bucket_name, blob=filename) blob_data = blob.download_blob() yield from blob_data.chunks() + return generate(filename) def download(self, filename, target_filepath): @@ -62,17 +62,17 @@ def delete(self, filename): blob_container.delete_blob(filename) def _sync_client(self): - cache_key = 'azure_blob_sas_token_{}_{}'.format(self.account_name, self.account_key) + cache_key = "azure_blob_sas_token_{}_{}".format(self.account_name, self.account_key) cache_result = redis_client.get(cache_key) if cache_result is not None: - sas_token = cache_result.decode('utf-8') + sas_token = cache_result.decode("utf-8") else: sas_token = generate_account_sas( account_name=self.account_name, account_key=self.account_key, resource_types=ResourceTypes(service=True, container=True, object=True), permission=AccountSasPermissions(read=True, write=True, delete=True, list=True, add=True, create=True), - expiry=datetime.now(timezone.utc).replace(tzinfo=None) + timedelta(hours=1) + expiry=datetime.now(timezone.utc).replace(tzinfo=None) + timedelta(hours=1), ) redis_client.set(cache_key, sas_token, ex=3000) return BlobServiceClient(account_url=self.account_url, credential=sas_token) diff --git a/api/extensions/storage/base_storage.py b/api/extensions/storage/base_storage.py index 13d9c3429044c8..c3fe9ec82a5b41 100644 --- a/api/extensions/storage/base_storage.py +++ b/api/extensions/storage/base_storage.py @@ -1,4 +1,5 @@ """Abstract interface for file storage implementations.""" + from abc import ABC, abstractmethod from collections.abc import Generator @@ -6,8 +7,8 @@ class BaseStorage(ABC): - """Interface for file storage. - """ + """Interface for file storage.""" + app = None def __init__(self, app: Flask): diff --git a/api/extensions/storage/google_storage.py b/api/extensions/storage/google_storage.py index ef6cd69039787c..9ed1fcf0b4e118 100644 --- a/api/extensions/storage/google_storage.py +++ b/api/extensions/storage/google_storage.py @@ -11,16 +11,16 @@ class GoogleStorage(BaseStorage): - """Implementation for google storage. - """ + """Implementation for google storage.""" + def __init__(self, app: Flask): super().__init__(app) app_config = self.app.config - self.bucket_name = app_config.get('GOOGLE_STORAGE_BUCKET_NAME') - service_account_json_str = app_config.get('GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64') + self.bucket_name = app_config.get("GOOGLE_STORAGE_BUCKET_NAME") + service_account_json_str = app_config.get("GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64") # if service_account_json_str is empty, use Application Default Credentials if service_account_json_str: - service_account_json = base64.b64decode(service_account_json_str).decode('utf-8') + service_account_json = base64.b64decode(service_account_json_str).decode("utf-8") # convert str to object service_account_obj = json.loads(service_account_json) self.client = GoogleCloudStorage.Client.from_service_account_info(service_account_obj) @@ -43,9 +43,10 @@ def load_stream(self, filename: str) -> Generator: def generate(filename: str = filename) -> Generator: bucket = self.client.get_bucket(self.bucket_name) blob = bucket.get_blob(filename) - with closing(blob.open(mode='rb')) as blob_stream: + with closing(blob.open(mode="rb")) as blob_stream: while chunk := blob_stream.read(4096): yield chunk + return generate() def download(self, filename, target_filepath): @@ -60,4 +61,4 @@ def exists(self, filename): def delete(self, filename): bucket = self.client.get_bucket(self.bucket_name) - bucket.delete_blob(filename) \ No newline at end of file + bucket.delete_blob(filename) diff --git a/api/extensions/storage/local_storage.py b/api/extensions/storage/local_storage.py index 389ef12f82bcd8..46ee4bf80f8e6d 100644 --- a/api/extensions/storage/local_storage.py +++ b/api/extensions/storage/local_storage.py @@ -8,21 +8,20 @@ class LocalStorage(BaseStorage): - """Implementation for local storage. - """ + """Implementation for local storage.""" def __init__(self, app: Flask): super().__init__(app) - folder = self.app.config.get('STORAGE_LOCAL_PATH') + folder = self.app.config.get("STORAGE_LOCAL_PATH") if not os.path.isabs(folder): folder = os.path.join(app.root_path, folder) self.folder = folder def save(self, filename, data): - if not self.folder or self.folder.endswith('/'): + if not self.folder or self.folder.endswith("/"): filename = self.folder + filename else: - filename = self.folder + '/' + filename + filename = self.folder + "/" + filename folder = os.path.dirname(filename) os.makedirs(folder, exist_ok=True) @@ -31,10 +30,10 @@ def save(self, filename, data): f.write(data) def load_once(self, filename: str) -> bytes: - if not self.folder or self.folder.endswith('/'): + if not self.folder or self.folder.endswith("/"): filename = self.folder + filename else: - filename = self.folder + '/' + filename + filename = self.folder + "/" + filename if not os.path.exists(filename): raise FileNotFoundError("File not found") @@ -46,10 +45,10 @@ def load_once(self, filename: str) -> bytes: def load_stream(self, filename: str) -> Generator: def generate(filename: str = filename) -> Generator: - if not self.folder or self.folder.endswith('/'): + if not self.folder or self.folder.endswith("/"): filename = self.folder + filename else: - filename = self.folder + '/' + filename + filename = self.folder + "/" + filename if not os.path.exists(filename): raise FileNotFoundError("File not found") @@ -61,10 +60,10 @@ def generate(filename: str = filename) -> Generator: return generate() def download(self, filename, target_filepath): - if not self.folder or self.folder.endswith('/'): + if not self.folder or self.folder.endswith("/"): filename = self.folder + filename else: - filename = self.folder + '/' + filename + filename = self.folder + "/" + filename if not os.path.exists(filename): raise FileNotFoundError("File not found") @@ -72,17 +71,17 @@ def download(self, filename, target_filepath): shutil.copyfile(filename, target_filepath) def exists(self, filename): - if not self.folder or self.folder.endswith('/'): + if not self.folder or self.folder.endswith("/"): filename = self.folder + filename else: - filename = self.folder + '/' + filename + filename = self.folder + "/" + filename return os.path.exists(filename) def delete(self, filename): - if not self.folder or self.folder.endswith('/'): + if not self.folder or self.folder.endswith("/"): filename = self.folder + filename else: - filename = self.folder + '/' + filename + filename = self.folder + "/" + filename if os.path.exists(filename): os.remove(filename) diff --git a/api/extensions/storage/oci_storage.py b/api/extensions/storage/oci_storage.py index e78d870950339b..e32fa0a0ae78a9 100644 --- a/api/extensions/storage/oci_storage.py +++ b/api/extensions/storage/oci_storage.py @@ -12,14 +12,14 @@ class OCIStorage(BaseStorage): def __init__(self, app: Flask): super().__init__(app) app_config = self.app.config - self.bucket_name = app_config.get('OCI_BUCKET_NAME') + self.bucket_name = app_config.get("OCI_BUCKET_NAME") self.client = boto3.client( - 's3', - aws_secret_access_key=app_config.get('OCI_SECRET_KEY'), - aws_access_key_id=app_config.get('OCI_ACCESS_KEY'), - endpoint_url=app_config.get('OCI_ENDPOINT'), - region_name=app_config.get('OCI_REGION') - ) + "s3", + aws_secret_access_key=app_config.get("OCI_SECRET_KEY"), + aws_access_key_id=app_config.get("OCI_ACCESS_KEY"), + endpoint_url=app_config.get("OCI_ENDPOINT"), + region_name=app_config.get("OCI_REGION"), + ) def save(self, filename, data): self.client.put_object(Bucket=self.bucket_name, Key=filename, Body=data) @@ -27,9 +27,9 @@ def save(self, filename, data): def load_once(self, filename: str) -> bytes: try: with closing(self.client) as client: - data = client.get_object(Bucket=self.bucket_name, Key=filename)['Body'].read() + data = client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read() except ClientError as ex: - if ex.response['Error']['Code'] == 'NoSuchKey': + if ex.response["Error"]["Code"] == "NoSuchKey": raise FileNotFoundError("File not found") else: raise @@ -40,12 +40,13 @@ def generate(filename: str = filename) -> Generator: try: with closing(self.client) as client: response = client.get_object(Bucket=self.bucket_name, Key=filename) - yield from response['Body'].iter_chunks() + yield from response["Body"].iter_chunks() except ClientError as ex: - if ex.response['Error']['Code'] == 'NoSuchKey': + if ex.response["Error"]["Code"] == "NoSuchKey": raise FileNotFoundError("File not found") else: raise + return generate() def download(self, filename, target_filepath): @@ -61,4 +62,4 @@ def exists(self, filename): return False def delete(self, filename): - self.client.delete_object(Bucket=self.bucket_name, Key=filename) \ No newline at end of file + self.client.delete_object(Bucket=self.bucket_name, Key=filename) diff --git a/api/extensions/storage/s3_storage.py b/api/extensions/storage/s3_storage.py index 787596fa791d4a..022ce5b14a7b88 100644 --- a/api/extensions/storage/s3_storage.py +++ b/api/extensions/storage/s3_storage.py @@ -10,24 +10,24 @@ class S3Storage(BaseStorage): - """Implementation for s3 storage. - """ + """Implementation for s3 storage.""" + def __init__(self, app: Flask): super().__init__(app) app_config = self.app.config - self.bucket_name = app_config.get('S3_BUCKET_NAME') - if app_config.get('S3_USE_AWS_MANAGED_IAM'): + self.bucket_name = app_config.get("S3_BUCKET_NAME") + if app_config.get("S3_USE_AWS_MANAGED_IAM"): session = boto3.Session() - self.client = session.client('s3') + self.client = session.client("s3") else: self.client = boto3.client( - 's3', - aws_secret_access_key=app_config.get('S3_SECRET_KEY'), - aws_access_key_id=app_config.get('S3_ACCESS_KEY'), - endpoint_url=app_config.get('S3_ENDPOINT'), - region_name=app_config.get('S3_REGION'), - config=Config(s3={'addressing_style': app_config.get('S3_ADDRESS_STYLE')}) - ) + "s3", + aws_secret_access_key=app_config.get("S3_SECRET_KEY"), + aws_access_key_id=app_config.get("S3_ACCESS_KEY"), + endpoint_url=app_config.get("S3_ENDPOINT"), + region_name=app_config.get("S3_REGION"), + config=Config(s3={"addressing_style": app_config.get("S3_ADDRESS_STYLE")}), + ) def save(self, filename, data): self.client.put_object(Bucket=self.bucket_name, Key=filename, Body=data) @@ -35,9 +35,9 @@ def save(self, filename, data): def load_once(self, filename: str) -> bytes: try: with closing(self.client) as client: - data = client.get_object(Bucket=self.bucket_name, Key=filename)['Body'].read() + data = client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read() except ClientError as ex: - if ex.response['Error']['Code'] == 'NoSuchKey': + if ex.response["Error"]["Code"] == "NoSuchKey": raise FileNotFoundError("File not found") else: raise @@ -48,12 +48,13 @@ def generate(filename: str = filename) -> Generator: try: with closing(self.client) as client: response = client.get_object(Bucket=self.bucket_name, Key=filename) - yield from response['Body'].iter_chunks() + yield from response["Body"].iter_chunks() except ClientError as ex: - if ex.response['Error']['Code'] == 'NoSuchKey': + if ex.response["Error"]["Code"] == "NoSuchKey": raise FileNotFoundError("File not found") else: raise + return generate() def download(self, filename, target_filepath): diff --git a/api/extensions/storage/tencent_storage.py b/api/extensions/storage/tencent_storage.py index e2c1ca55e3c434..1d499cd3bcea1c 100644 --- a/api/extensions/storage/tencent_storage.py +++ b/api/extensions/storage/tencent_storage.py @@ -7,18 +7,17 @@ class TencentStorage(BaseStorage): - """Implementation for tencent cos storage. - """ + """Implementation for tencent cos storage.""" def __init__(self, app: Flask): super().__init__(app) app_config = self.app.config - self.bucket_name = app_config.get('TENCENT_COS_BUCKET_NAME') + self.bucket_name = app_config.get("TENCENT_COS_BUCKET_NAME") config = CosConfig( - Region=app_config.get('TENCENT_COS_REGION'), - SecretId=app_config.get('TENCENT_COS_SECRET_ID'), - SecretKey=app_config.get('TENCENT_COS_SECRET_KEY'), - Scheme=app_config.get('TENCENT_COS_SCHEME'), + Region=app_config.get("TENCENT_COS_REGION"), + SecretId=app_config.get("TENCENT_COS_SECRET_ID"), + SecretKey=app_config.get("TENCENT_COS_SECRET_KEY"), + Scheme=app_config.get("TENCENT_COS_SCHEME"), ) self.client = CosS3Client(config) @@ -26,19 +25,19 @@ def save(self, filename, data): self.client.put_object(Bucket=self.bucket_name, Body=data, Key=filename) def load_once(self, filename: str) -> bytes: - data = self.client.get_object(Bucket=self.bucket_name, Key=filename)['Body'].get_raw_stream().read() + data = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].get_raw_stream().read() return data def load_stream(self, filename: str) -> Generator: def generate(filename: str = filename) -> Generator: response = self.client.get_object(Bucket=self.bucket_name, Key=filename) - yield from response['Body'].get_stream(chunk_size=4096) + yield from response["Body"].get_stream(chunk_size=4096) return generate() def download(self, filename, target_filepath): response = self.client.get_object(Bucket=self.bucket_name, Key=filename) - response['Body'].get_stream_to_file(target_filepath) + response["Body"].get_stream_to_file(target_filepath) def exists(self, filename): return self.client.object_exists(Bucket=self.bucket_name, Key=filename) diff --git a/api/fields/annotation_fields.py b/api/fields/annotation_fields.py index c77808447519fc..379dcc6d16fe56 100644 --- a/api/fields/annotation_fields.py +++ b/api/fields/annotation_fields.py @@ -5,7 +5,7 @@ annotation_fields = { "id": fields.String, "question": fields.String, - "answer": fields.Raw(attribute='content'), + "answer": fields.Raw(attribute="content"), "hit_count": fields.Integer, "created_at": TimestampField, # 'account': fields.Nested(simple_account_fields, allow_null=True) @@ -21,8 +21,8 @@ "score": fields.Float, "question": fields.String, "created_at": TimestampField, - "match": fields.String(attribute='annotation_question'), - "response": fields.String(attribute='annotation_content') + "match": fields.String(attribute="annotation_question"), + "response": fields.String(attribute="annotation_content"), } annotation_hit_history_list_fields = { diff --git a/api/fields/api_based_extension_fields.py b/api/fields/api_based_extension_fields.py index 749e9900de181d..a85d4a34dbe7b1 100644 --- a/api/fields/api_based_extension_fields.py +++ b/api/fields/api_based_extension_fields.py @@ -8,16 +8,16 @@ def output(self, key, obj): api_key = obj.api_key # If the length of the api_key is less than 8 characters, show the first and last characters if len(api_key) <= 8: - return api_key[0] + '******' + api_key[-1] + return api_key[0] + "******" + api_key[-1] # If the api_key is greater than 8 characters, show the first three and the last three characters else: - return api_key[:3] + '******' + api_key[-3:] + return api_key[:3] + "******" + api_key[-3:] api_based_extension_fields = { - 'id': fields.String, - 'name': fields.String, - 'api_endpoint': fields.String, - 'api_key': HiddenAPIKey, - 'created_at': TimestampField + "id": fields.String, + "name": fields.String, + "api_endpoint": fields.String, + "api_key": HiddenAPIKey, + "created_at": TimestampField, } diff --git a/api/fields/app_fields.py b/api/fields/app_fields.py index a550a161373e73..3aa439520c74eb 100644 --- a/api/fields/app_fields.py +++ b/api/fields/app_fields.py @@ -3,165 +3,161 @@ from libs.helper import AppIconUrlField, TimestampField app_detail_kernel_fields = { - 'id': fields.String, - 'name': fields.String, - 'description': fields.String, - 'mode': fields.String(attribute='mode_compatible_with_agent'), - 'icon_type': fields.String, - 'icon': fields.String, - 'icon_background': fields.String, - 'icon_url': AppIconUrlField, + "id": fields.String, + "name": fields.String, + "description": fields.String, + "mode": fields.String(attribute="mode_compatible_with_agent"), + "icon_type": fields.String, + "icon": fields.String, + "icon_background": fields.String, + "icon_url": AppIconUrlField, } related_app_list = { - 'data': fields.List(fields.Nested(app_detail_kernel_fields)), - 'total': fields.Integer, + "data": fields.List(fields.Nested(app_detail_kernel_fields)), + "total": fields.Integer, } model_config_fields = { - 'opening_statement': fields.String, - 'suggested_questions': fields.Raw(attribute='suggested_questions_list'), - 'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'), - 'speech_to_text': fields.Raw(attribute='speech_to_text_dict'), - 'text_to_speech': fields.Raw(attribute='text_to_speech_dict'), - 'retriever_resource': fields.Raw(attribute='retriever_resource_dict'), - 'annotation_reply': fields.Raw(attribute='annotation_reply_dict'), - 'more_like_this': fields.Raw(attribute='more_like_this_dict'), - 'sensitive_word_avoidance': fields.Raw(attribute='sensitive_word_avoidance_dict'), - 'external_data_tools': fields.Raw(attribute='external_data_tools_list'), - 'model': fields.Raw(attribute='model_dict'), - 'user_input_form': fields.Raw(attribute='user_input_form_list'), - 'dataset_query_variable': fields.String, - 'pre_prompt': fields.String, - 'agent_mode': fields.Raw(attribute='agent_mode_dict'), - 'prompt_type': fields.String, - 'chat_prompt_config': fields.Raw(attribute='chat_prompt_config_dict'), - 'completion_prompt_config': fields.Raw(attribute='completion_prompt_config_dict'), - 'dataset_configs': fields.Raw(attribute='dataset_configs_dict'), - 'file_upload': fields.Raw(attribute='file_upload_dict'), - 'created_at': TimestampField + "opening_statement": fields.String, + "suggested_questions": fields.Raw(attribute="suggested_questions_list"), + "suggested_questions_after_answer": fields.Raw(attribute="suggested_questions_after_answer_dict"), + "speech_to_text": fields.Raw(attribute="speech_to_text_dict"), + "text_to_speech": fields.Raw(attribute="text_to_speech_dict"), + "retriever_resource": fields.Raw(attribute="retriever_resource_dict"), + "annotation_reply": fields.Raw(attribute="annotation_reply_dict"), + "more_like_this": fields.Raw(attribute="more_like_this_dict"), + "sensitive_word_avoidance": fields.Raw(attribute="sensitive_word_avoidance_dict"), + "external_data_tools": fields.Raw(attribute="external_data_tools_list"), + "model": fields.Raw(attribute="model_dict"), + "user_input_form": fields.Raw(attribute="user_input_form_list"), + "dataset_query_variable": fields.String, + "pre_prompt": fields.String, + "agent_mode": fields.Raw(attribute="agent_mode_dict"), + "prompt_type": fields.String, + "chat_prompt_config": fields.Raw(attribute="chat_prompt_config_dict"), + "completion_prompt_config": fields.Raw(attribute="completion_prompt_config_dict"), + "dataset_configs": fields.Raw(attribute="dataset_configs_dict"), + "file_upload": fields.Raw(attribute="file_upload_dict"), + "created_at": TimestampField, } app_detail_fields = { - 'id': fields.String, - 'name': fields.String, - 'description': fields.String, - 'mode': fields.String(attribute='mode_compatible_with_agent'), - 'icon': fields.String, - 'icon_background': fields.String, - 'enable_site': fields.Boolean, - 'enable_api': fields.Boolean, - 'model_config': fields.Nested(model_config_fields, attribute='app_model_config', allow_null=True), - 'tracing': fields.Raw, - 'created_at': TimestampField + "id": fields.String, + "name": fields.String, + "description": fields.String, + "mode": fields.String(attribute="mode_compatible_with_agent"), + "icon": fields.String, + "icon_background": fields.String, + "enable_site": fields.Boolean, + "enable_api": fields.Boolean, + "model_config": fields.Nested(model_config_fields, attribute="app_model_config", allow_null=True), + "tracing": fields.Raw, + "created_at": TimestampField, } prompt_config_fields = { - 'prompt_template': fields.String, + "prompt_template": fields.String, } model_config_partial_fields = { - 'model': fields.Raw(attribute='model_dict'), - 'pre_prompt': fields.String, + "model": fields.Raw(attribute="model_dict"), + "pre_prompt": fields.String, } -tag_fields = { - 'id': fields.String, - 'name': fields.String, - 'type': fields.String -} +tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String} app_partial_fields = { - 'id': fields.String, - 'name': fields.String, - 'max_active_requests': fields.Raw(), - 'description': fields.String(attribute='desc_or_prompt'), - 'mode': fields.String(attribute='mode_compatible_with_agent'), - 'icon_type': fields.String, - 'icon': fields.String, - 'icon_background': fields.String, - 'icon_url': AppIconUrlField, - 'model_config': fields.Nested(model_config_partial_fields, attribute='app_model_config', allow_null=True), - 'created_at': TimestampField, - 'tags': fields.List(fields.Nested(tag_fields)) + "id": fields.String, + "name": fields.String, + "max_active_requests": fields.Raw(), + "description": fields.String(attribute="desc_or_prompt"), + "mode": fields.String(attribute="mode_compatible_with_agent"), + "icon_type": fields.String, + "icon": fields.String, + "icon_background": fields.String, + "icon_url": AppIconUrlField, + "model_config": fields.Nested(model_config_partial_fields, attribute="app_model_config", allow_null=True), + "created_at": TimestampField, + "tags": fields.List(fields.Nested(tag_fields)) } app_pagination_fields = { - 'page': fields.Integer, - 'limit': fields.Integer(attribute='per_page'), - 'total': fields.Integer, - 'has_more': fields.Boolean(attribute='has_next'), - 'data': fields.List(fields.Nested(app_partial_fields), attribute='items') + "page": fields.Integer, + "limit": fields.Integer(attribute="per_page"), + "total": fields.Integer, + "has_more": fields.Boolean(attribute="has_next"), + "data": fields.List(fields.Nested(app_partial_fields), attribute="items"), } template_fields = { - 'name': fields.String, - 'icon': fields.String, - 'icon_background': fields.String, - 'description': fields.String, - 'mode': fields.String, - 'model_config': fields.Nested(model_config_fields), + "name": fields.String, + "icon": fields.String, + "icon_background": fields.String, + "description": fields.String, + "mode": fields.String, + "model_config": fields.Nested(model_config_fields), } template_list_fields = { - 'data': fields.List(fields.Nested(template_fields)), + "data": fields.List(fields.Nested(template_fields)), } site_fields = { - 'access_token': fields.String(attribute='code'), - 'code': fields.String, - 'title': fields.String, - 'icon_type': fields.String, - 'icon': fields.String, - 'icon_background': fields.String, - 'icon_url': AppIconUrlField, - 'description': fields.String, - 'default_language': fields.String, - 'chat_color_theme': fields.String, - 'chat_color_theme_inverted': fields.Boolean, - 'customize_domain': fields.String, - 'copyright': fields.String, - 'privacy_policy': fields.String, - 'custom_disclaimer': fields.String, - 'customize_token_strategy': fields.String, - 'prompt_public': fields.Boolean, - 'app_base_url': fields.String, - 'show_workflow_steps': fields.Boolean, + "access_token": fields.String(attribute="code"), + "code": fields.String, + "title": fields.String, + "icon_type": fields.String, + "icon": fields.String, + "icon_background": fields.String, + "icon_url": AppIconUrlField, + "description": fields.String, + "default_language": fields.String, + "chat_color_theme": fields.String, + "chat_color_theme_inverted": fields.Boolean, + "customize_domain": fields.String, + "copyright": fields.String, + "privacy_policy": fields.String, + "custom_disclaimer": fields.String, + "customize_token_strategy": fields.String, + "prompt_public": fields.Boolean, + "app_base_url": fields.String, + "show_workflow_steps": fields.Boolean, } app_detail_fields_with_site = { - 'id': fields.String, - 'name': fields.String, - 'description': fields.String, - 'mode': fields.String(attribute='mode_compatible_with_agent'), - 'icon_type': fields.String, - 'icon': fields.String, - 'icon_background': fields.String, - 'icon_url': AppIconUrlField, - 'enable_site': fields.Boolean, - 'enable_api': fields.Boolean, - 'model_config': fields.Nested(model_config_fields, attribute='app_model_config', allow_null=True), - 'site': fields.Nested(site_fields), - 'api_base_url': fields.String, - 'created_at': TimestampField, - 'deleted_tools': fields.List(fields.String), + "id": fields.String, + "name": fields.String, + "description": fields.String, + "mode": fields.String(attribute="mode_compatible_with_agent"), + "icon_type": fields.String, + "icon": fields.String, + "icon_background": fields.String, + "icon_url": AppIconUrlField, + "enable_site": fields.Boolean, + "enable_api": fields.Boolean, + "model_config": fields.Nested(model_config_fields, attribute="", allow_null=True), + "site": fields.Nested(site_fields), + "api_base_url": fields.String, + "created_at": TimestampField, + "deleted_tools": fields.List(fields.String), } app_site_fields = { - 'app_id': fields.String, - 'access_token': fields.String(attribute='code'), - 'code': fields.String, - 'title': fields.String, - 'icon': fields.String, - 'icon_background': fields.String, - 'description': fields.String, - 'default_language': fields.String, - 'customize_domain': fields.String, - 'copyright': fields.String, - 'privacy_policy': fields.String, - 'custom_disclaimer': fields.String, - 'customize_token_strategy': fields.String, - 'prompt_public': fields.Boolean, - 'show_workflow_steps': fields.Boolean, + "app_id": fields.String, + "access_token": fields.String(attribute="code"), + "code": fields.String, + "title": fields.String, + "icon": fields.String, + "icon_background": fields.String, + "description": fields.String, + "default_language": fields.String, + "customize_domain": fields.String, + "copyright": fields.String, + "privacy_policy": fields.String, + "custom_disclaimer": fields.String, + "customize_token_strategy": fields.String, + "prompt_public": fields.Boolean, + "show_workflow_steps": fields.Boolean, } diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index 79ceb026852792..1b15fe38800b3e 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -6,205 +6,202 @@ class MessageTextField(fields.Raw): def format(self, value): - return value[0]['text'] if value else '' + return value[0]["text"] if value else "" feedback_fields = { - 'rating': fields.String, - 'content': fields.String, - 'from_source': fields.String, - 'from_end_user_id': fields.String, - 'from_account': fields.Nested(simple_account_fields, allow_null=True), + "rating": fields.String, + "content": fields.String, + "from_source": fields.String, + "from_end_user_id": fields.String, + "from_account": fields.Nested(simple_account_fields, allow_null=True), } annotation_fields = { - 'id': fields.String, - 'question': fields.String, - 'content': fields.String, - 'account': fields.Nested(simple_account_fields, allow_null=True), - 'created_at': TimestampField + "id": fields.String, + "question": fields.String, + "content": fields.String, + "account": fields.Nested(simple_account_fields, allow_null=True), + "created_at": TimestampField, } annotation_hit_history_fields = { - 'annotation_id': fields.String(attribute='id'), - 'annotation_create_account': fields.Nested(simple_account_fields, allow_null=True), - 'created_at': TimestampField + "annotation_id": fields.String(attribute="id"), + "annotation_create_account": fields.Nested(simple_account_fields, allow_null=True), + "created_at": TimestampField, } message_file_fields = { - 'id': fields.String, - 'type': fields.String, - 'url': fields.String, - 'belongs_to': fields.String(default='user'), + "id": fields.String, + "type": fields.String, + "url": fields.String, + "belongs_to": fields.String(default="user"), } agent_thought_fields = { - 'id': fields.String, - 'chain_id': fields.String, - 'message_id': fields.String, - 'position': fields.Integer, - 'thought': fields.String, - 'tool': fields.String, - 'tool_labels': fields.Raw, - 'tool_input': fields.String, - 'created_at': TimestampField, - 'observation': fields.String, - 'files': fields.List(fields.String), + "id": fields.String, + "chain_id": fields.String, + "message_id": fields.String, + "position": fields.Integer, + "thought": fields.String, + "tool": fields.String, + "tool_labels": fields.Raw, + "tool_input": fields.String, + "created_at": TimestampField, + "observation": fields.String, + "files": fields.List(fields.String), } message_detail_fields = { - 'id': fields.String, - 'conversation_id': fields.String, - 'inputs': fields.Raw, - 'query': fields.String, - 'message': fields.Raw, - 'message_tokens': fields.Integer, - 'answer': fields.String(attribute='re_sign_file_url_answer'), - 'answer_tokens': fields.Integer, - 'provider_response_latency': fields.Float, - 'from_source': fields.String, - 'from_end_user_id': fields.String, - 'from_account_id': fields.String, - 'feedbacks': fields.List(fields.Nested(feedback_fields)), - 'workflow_run_id': fields.String, - 'annotation': fields.Nested(annotation_fields, allow_null=True), - 'annotation_hit_history': fields.Nested(annotation_hit_history_fields, allow_null=True), - 'created_at': TimestampField, - 'agent_thoughts': fields.List(fields.Nested(agent_thought_fields)), - 'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'), - 'metadata': fields.Raw(attribute='message_metadata_dict'), - 'status': fields.String, - 'error': fields.String, -} - -feedback_stat_fields = { - 'like': fields.Integer, - 'dislike': fields.Integer -} + "id": fields.String, + "conversation_id": fields.String, + "inputs": fields.Raw, + "query": fields.String, + "message": fields.Raw, + "message_tokens": fields.Integer, + "answer": fields.String(attribute="re_sign_file_url_answer"), + "answer_tokens": fields.Integer, + "provider_response_latency": fields.Float, + "from_source": fields.String, + "from_end_user_id": fields.String, + "from_account_id": fields.String, + "feedbacks": fields.List(fields.Nested(feedback_fields)), + "workflow_run_id": fields.String, + "annotation": fields.Nested(annotation_fields, allow_null=True), + "annotation_hit_history": fields.Nested(annotation_hit_history_fields, allow_null=True), + "created_at": TimestampField, + "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)), + "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"), + "metadata": fields.Raw(attribute="message_metadata_dict"), + "status": fields.String, + "error": fields.String, +} + +feedback_stat_fields = {"like": fields.Integer, "dislike": fields.Integer} model_config_fields = { - 'opening_statement': fields.String, - 'suggested_questions': fields.Raw, - 'model': fields.Raw, - 'user_input_form': fields.Raw, - 'pre_prompt': fields.String, - 'agent_mode': fields.Raw, + "opening_statement": fields.String, + "suggested_questions": fields.Raw, + "model": fields.Raw, + "user_input_form": fields.Raw, + "pre_prompt": fields.String, + "agent_mode": fields.Raw, } simple_configs_fields = { - 'prompt_template': fields.String, + "prompt_template": fields.String, } simple_model_config_fields = { - 'model': fields.Raw(attribute='model_dict'), - 'pre_prompt': fields.String, + "model": fields.Raw(attribute="model_dict"), + "pre_prompt": fields.String, } simple_message_detail_fields = { - 'inputs': fields.Raw, - 'query': fields.String, - 'message': MessageTextField, - 'answer': fields.String, + "inputs": fields.Raw, + "query": fields.String, + "message": MessageTextField, + "answer": fields.String, } conversation_fields = { - 'id': fields.String, - 'status': fields.String, - 'from_source': fields.String, - 'from_end_user_id': fields.String, - 'from_end_user_session_id': fields.String(), - 'from_account_id': fields.String, - 'read_at': TimestampField, - 'created_at': TimestampField, - 'annotation': fields.Nested(annotation_fields, allow_null=True), - 'model_config': fields.Nested(simple_model_config_fields), - 'user_feedback_stats': fields.Nested(feedback_stat_fields), - 'admin_feedback_stats': fields.Nested(feedback_stat_fields), - 'message': fields.Nested(simple_message_detail_fields, attribute='first_message') + "id": fields.String, + "status": fields.String, + "from_source": fields.String, + "from_end_user_id": fields.String, + "from_end_user_session_id": fields.String(), + "from_account_id": fields.String, + "read_at": TimestampField, + "created_at": TimestampField, + "annotation": fields.Nested(annotation_fields, allow_null=True), + "model_config": fields.Nested(simple_model_config_fields), + "user_feedback_stats": fields.Nested(feedback_stat_fields), + "admin_feedback_stats": fields.Nested(feedback_stat_fields), + "message": fields.Nested(simple_message_detail_fields, attribute="first_message"), } conversation_pagination_fields = { - 'page': fields.Integer, - 'limit': fields.Integer(attribute='per_page'), - 'total': fields.Integer, - 'has_more': fields.Boolean(attribute='has_next'), - 'data': fields.List(fields.Nested(conversation_fields), attribute='items') + "page": fields.Integer, + "limit": fields.Integer(attribute="per_page"), + "total": fields.Integer, + "has_more": fields.Boolean(attribute="has_next"), + "data": fields.List(fields.Nested(conversation_fields), attribute="items"), } conversation_message_detail_fields = { - 'id': fields.String, - 'status': fields.String, - 'from_source': fields.String, - 'from_end_user_id': fields.String, - 'from_account_id': fields.String, - 'created_at': TimestampField, - 'model_config': fields.Nested(model_config_fields), - 'message': fields.Nested(message_detail_fields, attribute='first_message'), + "id": fields.String, + "status": fields.String, + "from_source": fields.String, + "from_end_user_id": fields.String, + "from_account_id": fields.String, + "created_at": TimestampField, + "model_config": fields.Nested(model_config_fields), + "message": fields.Nested(message_detail_fields, attribute="first_message"), } conversation_with_summary_fields = { - 'id': fields.String, - 'status': fields.String, - 'from_source': fields.String, - 'from_end_user_id': fields.String, - 'from_end_user_session_id': fields.String, - 'from_account_id': fields.String, - 'name': fields.String, - 'summary': fields.String(attribute='summary_or_query'), - 'read_at': TimestampField, - 'created_at': TimestampField, - 'annotated': fields.Boolean, - 'model_config': fields.Nested(simple_model_config_fields), - 'message_count': fields.Integer, - 'user_feedback_stats': fields.Nested(feedback_stat_fields), - 'admin_feedback_stats': fields.Nested(feedback_stat_fields) + "id": fields.String, + "status": fields.String, + "from_source": fields.String, + "from_end_user_id": fields.String, + "from_end_user_session_id": fields.String, + "from_account_id": fields.String, + "name": fields.String, + "summary": fields.String(attribute="summary_or_query"), + "read_at": TimestampField, + "created_at": TimestampField, + "annotated": fields.Boolean, + "model_config": fields.Nested(simple_model_config_fields), + "message_count": fields.Integer, + "user_feedback_stats": fields.Nested(feedback_stat_fields), + "admin_feedback_stats": fields.Nested(feedback_stat_fields), } conversation_with_summary_pagination_fields = { - 'page': fields.Integer, - 'limit': fields.Integer(attribute='per_page'), - 'total': fields.Integer, - 'has_more': fields.Boolean(attribute='has_next'), - 'data': fields.List(fields.Nested(conversation_with_summary_fields), attribute='items') + "page": fields.Integer, + "limit": fields.Integer(attribute="per_page"), + "total": fields.Integer, + "has_more": fields.Boolean(attribute="has_next"), + "data": fields.List(fields.Nested(conversation_with_summary_fields), attribute="items"), } conversation_detail_fields = { - 'id': fields.String, - 'status': fields.String, - 'from_source': fields.String, - 'from_end_user_id': fields.String, - 'from_account_id': fields.String, - 'created_at': TimestampField, - 'annotated': fields.Boolean, - 'introduction': fields.String, - 'model_config': fields.Nested(model_config_fields), - 'message_count': fields.Integer, - 'user_feedback_stats': fields.Nested(feedback_stat_fields), - 'admin_feedback_stats': fields.Nested(feedback_stat_fields) + "id": fields.String, + "status": fields.String, + "from_source": fields.String, + "from_end_user_id": fields.String, + "from_account_id": fields.String, + "created_at": TimestampField, + "annotated": fields.Boolean, + "introduction": fields.String, + "model_config": fields.Nested(model_config_fields), + "message_count": fields.Integer, + "user_feedback_stats": fields.Nested(feedback_stat_fields), + "admin_feedback_stats": fields.Nested(feedback_stat_fields), } simple_conversation_fields = { - 'id': fields.String, - 'name': fields.String, - 'inputs': fields.Raw, - 'status': fields.String, - 'introduction': fields.String, - 'created_at': TimestampField + "id": fields.String, + "name": fields.String, + "inputs": fields.Raw, + "status": fields.String, + "introduction": fields.String, + "created_at": TimestampField, } conversation_infinite_scroll_pagination_fields = { - 'limit': fields.Integer, - 'has_more': fields.Boolean, - 'data': fields.List(fields.Nested(simple_conversation_fields)) + "limit": fields.Integer, + "has_more": fields.Boolean, + "data": fields.List(fields.Nested(simple_conversation_fields)), } conversation_with_model_config_fields = { **simple_conversation_fields, - 'model_config': fields.Raw, + "model_config": fields.Raw, } conversation_with_model_config_infinite_scroll_pagination_fields = { - 'limit': fields.Integer, - 'has_more': fields.Boolean, - 'data': fields.List(fields.Nested(conversation_with_model_config_fields)) + "limit": fields.Integer, + "has_more": fields.Boolean, + "data": fields.List(fields.Nested(conversation_with_model_config_fields)), } diff --git a/api/fields/conversation_variable_fields.py b/api/fields/conversation_variable_fields.py new file mode 100644 index 00000000000000..983e50e73ceb9f --- /dev/null +++ b/api/fields/conversation_variable_fields.py @@ -0,0 +1,21 @@ +from flask_restful import fields + +from libs.helper import TimestampField + +conversation_variable_fields = { + "id": fields.String, + "name": fields.String, + "value_type": fields.String(attribute="value_type.value"), + "value": fields.String, + "description": fields.String, + "created_at": TimestampField, + "updated_at": TimestampField, +} + +paginated_conversation_variable_fields = { + "page": fields.Integer, + "limit": fields.Integer, + "total": fields.Integer, + "has_more": fields.Boolean, + "data": fields.List(fields.Nested(conversation_variable_fields), attribute="data"), +} diff --git a/api/fields/data_source_fields.py b/api/fields/data_source_fields.py index 6f3c920c85b60f..071071376fe6c8 100644 --- a/api/fields/data_source_fields.py +++ b/api/fields/data_source_fields.py @@ -2,64 +2,56 @@ from libs.helper import TimestampField -integrate_icon_fields = { - 'type': fields.String, - 'url': fields.String, - 'emoji': fields.String -} +integrate_icon_fields = {"type": fields.String, "url": fields.String, "emoji": fields.String} integrate_page_fields = { - 'page_name': fields.String, - 'page_id': fields.String, - 'page_icon': fields.Nested(integrate_icon_fields, allow_null=True), - 'is_bound': fields.Boolean, - 'parent_id': fields.String, - 'type': fields.String + "page_name": fields.String, + "page_id": fields.String, + "page_icon": fields.Nested(integrate_icon_fields, allow_null=True), + "is_bound": fields.Boolean, + "parent_id": fields.String, + "type": fields.String, } integrate_workspace_fields = { - 'workspace_name': fields.String, - 'workspace_id': fields.String, - 'workspace_icon': fields.String, - 'pages': fields.List(fields.Nested(integrate_page_fields)) + "workspace_name": fields.String, + "workspace_id": fields.String, + "workspace_icon": fields.String, + "pages": fields.List(fields.Nested(integrate_page_fields)), } integrate_notion_info_list_fields = { - 'notion_info': fields.List(fields.Nested(integrate_workspace_fields)), + "notion_info": fields.List(fields.Nested(integrate_workspace_fields)), } -integrate_icon_fields = { - 'type': fields.String, - 'url': fields.String, - 'emoji': fields.String -} +integrate_icon_fields = {"type": fields.String, "url": fields.String, "emoji": fields.String} integrate_page_fields = { - 'page_name': fields.String, - 'page_id': fields.String, - 'page_icon': fields.Nested(integrate_icon_fields, allow_null=True), - 'parent_id': fields.String, - 'type': fields.String + "page_name": fields.String, + "page_id": fields.String, + "page_icon": fields.Nested(integrate_icon_fields, allow_null=True), + "parent_id": fields.String, + "type": fields.String, } integrate_workspace_fields = { - 'workspace_name': fields.String, - 'workspace_id': fields.String, - 'workspace_icon': fields.String, - 'pages': fields.List(fields.Nested(integrate_page_fields)), - 'total': fields.Integer + "workspace_name": fields.String, + "workspace_id": fields.String, + "workspace_icon": fields.String, + "pages": fields.List(fields.Nested(integrate_page_fields)), + "total": fields.Integer, } integrate_fields = { - 'id': fields.String, - 'provider': fields.String, - 'created_at': TimestampField, - 'is_bound': fields.Boolean, - 'disabled': fields.Boolean, - 'link': fields.String, - 'source_info': fields.Nested(integrate_workspace_fields) + "id": fields.String, + "provider": fields.String, + "created_at": TimestampField, + "is_bound": fields.Boolean, + "disabled": fields.Boolean, + "link": fields.String, + "source_info": fields.Nested(integrate_workspace_fields), } integrate_list_fields = { - 'data': fields.List(fields.Nested(integrate_fields)), -} \ No newline at end of file + "data": fields.List(fields.Nested(integrate_fields)), +} diff --git a/api/fields/dataset_fields.py b/api/fields/dataset_fields.py index a9f79b5c678e7c..9cf8da7acdc984 100644 --- a/api/fields/dataset_fields.py +++ b/api/fields/dataset_fields.py @@ -3,73 +3,64 @@ from libs.helper import TimestampField dataset_fields = { - 'id': fields.String, - 'name': fields.String, - 'description': fields.String, - 'permission': fields.String, - 'data_source_type': fields.String, - 'indexing_technique': fields.String, - 'created_by': fields.String, - 'created_at': TimestampField, + "id": fields.String, + "name": fields.String, + "description": fields.String, + "permission": fields.String, + "data_source_type": fields.String, + "indexing_technique": fields.String, + "created_by": fields.String, + "created_at": TimestampField, } -reranking_model_fields = { - 'reranking_provider_name': fields.String, - 'reranking_model_name': fields.String -} +reranking_model_fields = {"reranking_provider_name": fields.String, "reranking_model_name": fields.String} -keyword_setting_fields = { - 'keyword_weight': fields.Float -} +keyword_setting_fields = {"keyword_weight": fields.Float} vector_setting_fields = { - 'vector_weight': fields.Float, - 'embedding_model_name': fields.String, - 'embedding_provider_name': fields.String, + "vector_weight": fields.Float, + "embedding_model_name": fields.String, + "embedding_provider_name": fields.String, } weighted_score_fields = { - 'keyword_setting': fields.Nested(keyword_setting_fields), - 'vector_setting': fields.Nested(vector_setting_fields), + "keyword_setting": fields.Nested(keyword_setting_fields), + "vector_setting": fields.Nested(vector_setting_fields), } dataset_retrieval_model_fields = { - 'search_method': fields.String, - 'reranking_enable': fields.Boolean, - 'reranking_mode': fields.String, - 'reranking_model': fields.Nested(reranking_model_fields), - 'weights': fields.Nested(weighted_score_fields, allow_null=True), - 'top_k': fields.Integer, - 'score_threshold_enabled': fields.Boolean, - 'score_threshold': fields.Float + "search_method": fields.String, + "reranking_enable": fields.Boolean, + "reranking_mode": fields.String, + "reranking_model": fields.Nested(reranking_model_fields), + "weights": fields.Nested(weighted_score_fields, allow_null=True), + "top_k": fields.Integer, + "score_threshold_enabled": fields.Boolean, + "score_threshold": fields.Float, } -tag_fields = { - 'id': fields.String, - 'name': fields.String, - 'type': fields.String -} +tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String} dataset_detail_fields = { - 'id': fields.String, - 'name': fields.String, - 'description': fields.String, - 'provider': fields.String, - 'permission': fields.String, - 'data_source_type': fields.String, - 'indexing_technique': fields.String, - 'app_count': fields.Integer, - 'document_count': fields.Integer, - 'word_count': fields.Integer, - 'created_by': fields.String, - 'created_at': TimestampField, - 'updated_by': fields.String, - 'updated_at': TimestampField, - 'embedding_model': fields.String, - 'embedding_model_provider': fields.String, - 'embedding_available': fields.Boolean, - 'retrieval_model_dict': fields.Nested(dataset_retrieval_model_fields), - 'tags': fields.List(fields.Nested(tag_fields)) + "id": fields.String, + "name": fields.String, + "description": fields.String, + "provider": fields.String, + "permission": fields.String, + "data_source_type": fields.String, + "indexing_technique": fields.String, + "app_count": fields.Integer, + "document_count": fields.Integer, + "word_count": fields.Integer, + "created_by": fields.String, + "created_at": TimestampField, + "updated_by": fields.String, + "updated_at": TimestampField, + "embedding_model": fields.String, + "embedding_model_provider": fields.String, + "embedding_available": fields.Boolean, + "retrieval_model_dict": fields.Nested(dataset_retrieval_model_fields), + "tags": fields.List(fields.Nested(tag_fields)), } dataset_query_detail_fields = { @@ -79,7 +70,5 @@ "source_app_id": fields.String, "created_by_role": fields.String, "created_by": fields.String, - "created_at": TimestampField + "created_at": TimestampField, } - - diff --git a/api/fields/document_fields.py b/api/fields/document_fields.py index e8215255b35d5b..a83ec7bc97adee 100644 --- a/api/fields/document_fields.py +++ b/api/fields/document_fields.py @@ -4,75 +4,73 @@ from libs.helper import TimestampField document_fields = { - 'id': fields.String, - 'position': fields.Integer, - 'data_source_type': fields.String, - 'data_source_info': fields.Raw(attribute='data_source_info_dict'), - 'data_source_detail_dict': fields.Raw(attribute='data_source_detail_dict'), - 'dataset_process_rule_id': fields.String, - 'name': fields.String, - 'created_from': fields.String, - 'created_by': fields.String, - 'created_at': TimestampField, - 'tokens': fields.Integer, - 'indexing_status': fields.String, - 'error': fields.String, - 'enabled': fields.Boolean, - 'disabled_at': TimestampField, - 'disabled_by': fields.String, - 'archived': fields.Boolean, - 'display_status': fields.String, - 'word_count': fields.Integer, - 'hit_count': fields.Integer, - 'doc_form': fields.String, + "id": fields.String, + "position": fields.Integer, + "data_source_type": fields.String, + "data_source_info": fields.Raw(attribute="data_source_info_dict"), + "data_source_detail_dict": fields.Raw(attribute="data_source_detail_dict"), + "dataset_process_rule_id": fields.String, + "name": fields.String, + "created_from": fields.String, + "created_by": fields.String, + "created_at": TimestampField, + "tokens": fields.Integer, + "indexing_status": fields.String, + "error": fields.String, + "enabled": fields.Boolean, + "disabled_at": TimestampField, + "disabled_by": fields.String, + "archived": fields.Boolean, + "display_status": fields.String, + "word_count": fields.Integer, + "hit_count": fields.Integer, + "doc_form": fields.String, } document_with_segments_fields = { - 'id': fields.String, - 'position': fields.Integer, - 'data_source_type': fields.String, - 'data_source_info': fields.Raw(attribute='data_source_info_dict'), - 'data_source_detail_dict': fields.Raw(attribute='data_source_detail_dict'), - 'dataset_process_rule_id': fields.String, - 'name': fields.String, - 'created_from': fields.String, - 'created_by': fields.String, - 'created_at': TimestampField, - 'tokens': fields.Integer, - 'indexing_status': fields.String, - 'error': fields.String, - 'enabled': fields.Boolean, - 'disabled_at': TimestampField, - 'disabled_by': fields.String, - 'archived': fields.Boolean, - 'display_status': fields.String, - 'word_count': fields.Integer, - 'hit_count': fields.Integer, - 'completed_segments': fields.Integer, - 'total_segments': fields.Integer + "id": fields.String, + "position": fields.Integer, + "data_source_type": fields.String, + "data_source_info": fields.Raw(attribute="data_source_info_dict"), + "data_source_detail_dict": fields.Raw(attribute="data_source_detail_dict"), + "dataset_process_rule_id": fields.String, + "name": fields.String, + "created_from": fields.String, + "created_by": fields.String, + "created_at": TimestampField, + "tokens": fields.Integer, + "indexing_status": fields.String, + "error": fields.String, + "enabled": fields.Boolean, + "disabled_at": TimestampField, + "disabled_by": fields.String, + "archived": fields.Boolean, + "display_status": fields.String, + "word_count": fields.Integer, + "hit_count": fields.Integer, + "completed_segments": fields.Integer, + "total_segments": fields.Integer, } dataset_and_document_fields = { - 'dataset': fields.Nested(dataset_fields), - 'documents': fields.List(fields.Nested(document_fields)), - 'batch': fields.String + "dataset": fields.Nested(dataset_fields), + "documents": fields.List(fields.Nested(document_fields)), + "batch": fields.String, } document_status_fields = { - 'id': fields.String, - 'indexing_status': fields.String, - 'processing_started_at': TimestampField, - 'parsing_completed_at': TimestampField, - 'cleaning_completed_at': TimestampField, - 'splitting_completed_at': TimestampField, - 'completed_at': TimestampField, - 'paused_at': TimestampField, - 'error': fields.String, - 'stopped_at': TimestampField, - 'completed_segments': fields.Integer, - 'total_segments': fields.Integer, + "id": fields.String, + "indexing_status": fields.String, + "processing_started_at": TimestampField, + "parsing_completed_at": TimestampField, + "cleaning_completed_at": TimestampField, + "splitting_completed_at": TimestampField, + "completed_at": TimestampField, + "paused_at": TimestampField, + "error": fields.String, + "stopped_at": TimestampField, + "completed_segments": fields.Integer, + "total_segments": fields.Integer, } -document_status_fields_list = { - 'data': fields.List(fields.Nested(document_status_fields)) -} \ No newline at end of file +document_status_fields_list = {"data": fields.List(fields.Nested(document_status_fields))} diff --git a/api/fields/end_user_fields.py b/api/fields/end_user_fields.py index ee630c12c2e9aa..99e529f9d1c076 100644 --- a/api/fields/end_user_fields.py +++ b/api/fields/end_user_fields.py @@ -1,8 +1,8 @@ from flask_restful import fields simple_end_user_fields = { - 'id': fields.String, - 'type': fields.String, - 'is_anonymous': fields.Boolean, - 'session_id': fields.String, + "id": fields.String, + "type": fields.String, + "is_anonymous": fields.Boolean, + "session_id": fields.String, } diff --git a/api/fields/file_fields.py b/api/fields/file_fields.py index 2ef379dabc0d08..e5a03ce77ed5f0 100644 --- a/api/fields/file_fields.py +++ b/api/fields/file_fields.py @@ -3,17 +3,17 @@ from libs.helper import TimestampField upload_config_fields = { - 'file_size_limit': fields.Integer, - 'batch_count_limit': fields.Integer, - 'image_file_size_limit': fields.Integer, + "file_size_limit": fields.Integer, + "batch_count_limit": fields.Integer, + "image_file_size_limit": fields.Integer, } file_fields = { - 'id': fields.String, - 'name': fields.String, - 'size': fields.Integer, - 'extension': fields.String, - 'mime_type': fields.String, - 'created_by': fields.String, - 'created_at': TimestampField, -} \ No newline at end of file + "id": fields.String, + "name": fields.String, + "size": fields.Integer, + "extension": fields.String, + "mime_type": fields.String, + "created_by": fields.String, + "created_at": TimestampField, +} diff --git a/api/fields/hit_testing_fields.py b/api/fields/hit_testing_fields.py index 541e56a378dae4..f36e80f8d493d5 100644 --- a/api/fields/hit_testing_fields.py +++ b/api/fields/hit_testing_fields.py @@ -3,39 +3,39 @@ from libs.helper import TimestampField document_fields = { - 'id': fields.String, - 'data_source_type': fields.String, - 'name': fields.String, - 'doc_type': fields.String, + "id": fields.String, + "data_source_type": fields.String, + "name": fields.String, + "doc_type": fields.String, } segment_fields = { - 'id': fields.String, - 'position': fields.Integer, - 'document_id': fields.String, - 'content': fields.String, - 'answer': fields.String, - 'word_count': fields.Integer, - 'tokens': fields.Integer, - 'keywords': fields.List(fields.String), - 'index_node_id': fields.String, - 'index_node_hash': fields.String, - 'hit_count': fields.Integer, - 'enabled': fields.Boolean, - 'disabled_at': TimestampField, - 'disabled_by': fields.String, - 'status': fields.String, - 'created_by': fields.String, - 'created_at': TimestampField, - 'indexing_at': TimestampField, - 'completed_at': TimestampField, - 'error': fields.String, - 'stopped_at': TimestampField, - 'document': fields.Nested(document_fields), + "id": fields.String, + "position": fields.Integer, + "document_id": fields.String, + "content": fields.String, + "answer": fields.String, + "word_count": fields.Integer, + "tokens": fields.Integer, + "keywords": fields.List(fields.String), + "index_node_id": fields.String, + "index_node_hash": fields.String, + "hit_count": fields.Integer, + "enabled": fields.Boolean, + "disabled_at": TimestampField, + "disabled_by": fields.String, + "status": fields.String, + "created_by": fields.String, + "created_at": TimestampField, + "indexing_at": TimestampField, + "completed_at": TimestampField, + "error": fields.String, + "stopped_at": TimestampField, + "document": fields.Nested(document_fields), } hit_testing_record_fields = { - 'segment': fields.Nested(segment_fields), - 'score': fields.Float, - 'tsne_position': fields.Raw -} \ No newline at end of file + "segment": fields.Nested(segment_fields), + "score": fields.Float, + "tsne_position": fields.Raw, +} diff --git a/api/fields/installed_app_fields.py b/api/fields/installed_app_fields.py index 6fc5f6ca4d01fa..9afc1b1a4ad79d 100644 --- a/api/fields/installed_app_fields.py +++ b/api/fields/installed_app_fields.py @@ -3,25 +3,23 @@ from libs.helper import AppIconUrlField, TimestampField app_fields = { - 'id': fields.String, - 'name': fields.String, - 'mode': fields.String, - 'icon_type': fields.String, - 'icon': fields.String, - 'icon_background': fields.String, - 'icon_url': AppIconUrlField, + "id": fields.String, + "name": fields.String, + "mode": fields.String, + "icon_type": fields.String, + "icon": fields.String, + "icon_background": fields.String, + "icon_url": AppIconUrlField, } installed_app_fields = { - 'id': fields.String, - 'app': fields.Nested(app_fields), - 'app_owner_tenant_id': fields.String, - 'is_pinned': fields.Boolean, - 'last_used_at': TimestampField, - 'editable': fields.Boolean, - 'uninstallable': fields.Boolean + "id": fields.String, + "app": fields.Nested(app_fields), + "app_owner_tenant_id": fields.String, + "is_pinned": fields.Boolean, + "last_used_at": TimestampField, + "editable": fields.Boolean, + "uninstallable": fields.Boolean, } -installed_app_list_fields = { - 'installed_apps': fields.List(fields.Nested(installed_app_fields)) -} \ No newline at end of file +installed_app_list_fields = {"installed_apps": fields.List(fields.Nested(installed_app_fields))} diff --git a/api/fields/member_fields.py b/api/fields/member_fields.py index d061b59c347022..1cf8e408d13d32 100644 --- a/api/fields/member_fields.py +++ b/api/fields/member_fields.py @@ -2,38 +2,32 @@ from libs.helper import TimestampField -simple_account_fields = { - 'id': fields.String, - 'name': fields.String, - 'email': fields.String -} +simple_account_fields = {"id": fields.String, "name": fields.String, "email": fields.String} account_fields = { - 'id': fields.String, - 'name': fields.String, - 'avatar': fields.String, - 'email': fields.String, - 'is_password_set': fields.Boolean, - 'interface_language': fields.String, - 'interface_theme': fields.String, - 'timezone': fields.String, - 'last_login_at': TimestampField, - 'last_login_ip': fields.String, - 'created_at': TimestampField + "id": fields.String, + "name": fields.String, + "avatar": fields.String, + "email": fields.String, + "is_password_set": fields.Boolean, + "interface_language": fields.String, + "interface_theme": fields.String, + "timezone": fields.String, + "last_login_at": TimestampField, + "last_login_ip": fields.String, + "created_at": TimestampField, } account_with_role_fields = { - 'id': fields.String, - 'name': fields.String, - 'avatar': fields.String, - 'email': fields.String, - 'last_login_at': TimestampField, - 'last_active_at': TimestampField, - 'created_at': TimestampField, - 'role': fields.String, - 'status': fields.String, + "id": fields.String, + "name": fields.String, + "avatar": fields.String, + "email": fields.String, + "last_login_at": TimestampField, + "last_active_at": TimestampField, + "created_at": TimestampField, + "role": fields.String, + "status": fields.String, } -account_with_role_list_fields = { - 'accounts': fields.List(fields.Nested(account_with_role_fields)) -} +account_with_role_list_fields = {"accounts": fields.List(fields.Nested(account_with_role_fields))} diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py index 31168435892427..3d2df87afb9b19 100644 --- a/api/fields/message_fields.py +++ b/api/fields/message_fields.py @@ -3,83 +3,79 @@ from fields.conversation_fields import message_file_fields from libs.helper import TimestampField -feedback_fields = { - 'rating': fields.String -} +feedback_fields = {"rating": fields.String} retriever_resource_fields = { - 'id': fields.String, - 'message_id': fields.String, - 'position': fields.Integer, - 'dataset_id': fields.String, - 'dataset_name': fields.String, - 'document_id': fields.String, - 'document_name': fields.String, - 'data_source_type': fields.String, - 'segment_id': fields.String, - 'score': fields.Float, - 'hit_count': fields.Integer, - 'word_count': fields.Integer, - 'segment_position': fields.Integer, - 'index_node_hash': fields.String, - 'content': fields.String, - 'created_at': TimestampField + "id": fields.String, + "message_id": fields.String, + "position": fields.Integer, + "dataset_id": fields.String, + "dataset_name": fields.String, + "document_id": fields.String, + "document_name": fields.String, + "data_source_type": fields.String, + "segment_id": fields.String, + "score": fields.Float, + "hit_count": fields.Integer, + "word_count": fields.Integer, + "segment_position": fields.Integer, + "index_node_hash": fields.String, + "content": fields.String, + "created_at": TimestampField, } -feedback_fields = { - 'rating': fields.String -} +feedback_fields = {"rating": fields.String} agent_thought_fields = { - 'id': fields.String, - 'chain_id': fields.String, - 'message_id': fields.String, - 'position': fields.Integer, - 'thought': fields.String, - 'tool': fields.String, - 'tool_labels': fields.Raw, - 'tool_input': fields.String, - 'created_at': TimestampField, - 'observation': fields.String, - 'files': fields.List(fields.String) + "id": fields.String, + "chain_id": fields.String, + "message_id": fields.String, + "position": fields.Integer, + "thought": fields.String, + "tool": fields.String, + "tool_labels": fields.Raw, + "tool_input": fields.String, + "created_at": TimestampField, + "observation": fields.String, + "files": fields.List(fields.String), } retriever_resource_fields = { - 'id': fields.String, - 'message_id': fields.String, - 'position': fields.Integer, - 'dataset_id': fields.String, - 'dataset_name': fields.String, - 'document_id': fields.String, - 'document_name': fields.String, - 'data_source_type': fields.String, - 'segment_id': fields.String, - 'score': fields.Float, - 'hit_count': fields.Integer, - 'word_count': fields.Integer, - 'segment_position': fields.Integer, - 'index_node_hash': fields.String, - 'content': fields.String, - 'created_at': TimestampField + "id": fields.String, + "message_id": fields.String, + "position": fields.Integer, + "dataset_id": fields.String, + "dataset_name": fields.String, + "document_id": fields.String, + "document_name": fields.String, + "data_source_type": fields.String, + "segment_id": fields.String, + "score": fields.Float, + "hit_count": fields.Integer, + "word_count": fields.Integer, + "segment_position": fields.Integer, + "index_node_hash": fields.String, + "content": fields.String, + "created_at": TimestampField, } message_fields = { - 'id': fields.String, - 'conversation_id': fields.String, - 'inputs': fields.Raw, - 'query': fields.String, - 'answer': fields.String(attribute='re_sign_file_url_answer'), - 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), - 'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)), - 'created_at': TimestampField, - 'agent_thoughts': fields.List(fields.Nested(agent_thought_fields)), - 'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'), - 'status': fields.String, - 'error': fields.String, + "id": fields.String, + "conversation_id": fields.String, + "inputs": fields.Raw, + "query": fields.String, + "answer": fields.String(attribute="re_sign_file_url_answer"), + "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), + "retriever_resources": fields.List(fields.Nested(retriever_resource_fields)), + "created_at": TimestampField, + "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)), + "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"), + "status": fields.String, + "error": fields.String, } message_infinite_scroll_pagination_fields = { - 'limit': fields.Integer, - 'has_more': fields.Boolean, - 'data': fields.List(fields.Nested(message_fields)) + "limit": fields.Integer, + "has_more": fields.Boolean, + "data": fields.List(fields.Nested(message_fields)), } diff --git a/api/fields/segment_fields.py b/api/fields/segment_fields.py index e41d1a53dd0234..2dd4cb45be409b 100644 --- a/api/fields/segment_fields.py +++ b/api/fields/segment_fields.py @@ -3,31 +3,31 @@ from libs.helper import TimestampField segment_fields = { - 'id': fields.String, - 'position': fields.Integer, - 'document_id': fields.String, - 'content': fields.String, - 'answer': fields.String, - 'word_count': fields.Integer, - 'tokens': fields.Integer, - 'keywords': fields.List(fields.String), - 'index_node_id': fields.String, - 'index_node_hash': fields.String, - 'hit_count': fields.Integer, - 'enabled': fields.Boolean, - 'disabled_at': TimestampField, - 'disabled_by': fields.String, - 'status': fields.String, - 'created_by': fields.String, - 'created_at': TimestampField, - 'indexing_at': TimestampField, - 'completed_at': TimestampField, - 'error': fields.String, - 'stopped_at': TimestampField + "id": fields.String, + "position": fields.Integer, + "document_id": fields.String, + "content": fields.String, + "answer": fields.String, + "word_count": fields.Integer, + "tokens": fields.Integer, + "keywords": fields.List(fields.String), + "index_node_id": fields.String, + "index_node_hash": fields.String, + "hit_count": fields.Integer, + "enabled": fields.Boolean, + "disabled_at": TimestampField, + "disabled_by": fields.String, + "status": fields.String, + "created_by": fields.String, + "created_at": TimestampField, + "indexing_at": TimestampField, + "completed_at": TimestampField, + "error": fields.String, + "stopped_at": TimestampField, } segment_list_response = { - 'data': fields.List(fields.Nested(segment_fields)), - 'has_more': fields.Boolean, - 'limit': fields.Integer + "data": fields.List(fields.Nested(segment_fields)), + "has_more": fields.Boolean, + "limit": fields.Integer, } diff --git a/api/fields/tag_fields.py b/api/fields/tag_fields.py index f7e030b738e537..9af4fc57dd061c 100644 --- a/api/fields/tag_fields.py +++ b/api/fields/tag_fields.py @@ -1,8 +1,3 @@ from flask_restful import fields -tag_fields = { - 'id': fields.String, - 'name': fields.String, - 'type': fields.String, - 'binding_count': fields.String -} \ No newline at end of file +tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String, "binding_count": fields.String} diff --git a/api/fields/workflow_app_log_fields.py b/api/fields/workflow_app_log_fields.py index e230c159fba59a..a53b54624915c2 100644 --- a/api/fields/workflow_app_log_fields.py +++ b/api/fields/workflow_app_log_fields.py @@ -7,18 +7,18 @@ workflow_app_log_partial_fields = { "id": fields.String, - "workflow_run": fields.Nested(workflow_run_for_log_fields, attribute='workflow_run', allow_null=True), + "workflow_run": fields.Nested(workflow_run_for_log_fields, attribute="workflow_run", allow_null=True), "created_from": fields.String, "created_by_role": fields.String, - "created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True), - "created_by_end_user": fields.Nested(simple_end_user_fields, attribute='created_by_end_user', allow_null=True), - "created_at": TimestampField + "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True), + "created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True), + "created_at": TimestampField, } workflow_app_log_pagination_fields = { - 'page': fields.Integer, - 'limit': fields.Integer(attribute='per_page'), - 'total': fields.Integer, - 'has_more': fields.Boolean(attribute='has_next'), - 'data': fields.List(fields.Nested(workflow_app_log_partial_fields), attribute='items') + "page": fields.Integer, + "limit": fields.Integer(attribute="per_page"), + "total": fields.Integer, + "has_more": fields.Boolean(attribute="has_next"), + "data": fields.List(fields.Nested(workflow_app_log_partial_fields), attribute="items"), } diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index ff33a97ff2a9ab..240b8f2eb03e79 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -13,41 +13,43 @@ def format(self, value): # Mask secret variables values in environment_variables if isinstance(value, SecretVariable): return { - 'id': value.id, - 'name': value.name, - 'value': encrypter.obfuscated_token(value.value), - 'value_type': value.value_type.value, + "id": value.id, + "name": value.name, + "value": encrypter.obfuscated_token(value.value), + "value_type": value.value_type.value, } if isinstance(value, Variable): return { - 'id': value.id, - 'name': value.name, - 'value': value.value, - 'value_type': value.value_type.value, + "id": value.id, + "name": value.name, + "value": value.value, + "value_type": value.value_type.value, } if isinstance(value, dict): - value_type = value.get('value_type') + value_type = value.get("value_type") if value_type not in ENVIRONMENT_VARIABLE_SUPPORTED_TYPES: - raise ValueError(f'Unsupported environment variable value type: {value_type}') + raise ValueError(f"Unsupported environment variable value type: {value_type}") return value -environment_variable_fields = { - 'id': fields.String, - 'name': fields.String, - 'value': fields.Raw, - 'value_type': fields.String(attribute='value_type.value'), +conversation_variable_fields = { + "id": fields.String, + "name": fields.String, + "value_type": fields.String(attribute="value_type.value"), + "value": fields.Raw, + "description": fields.String, } workflow_fields = { - 'id': fields.String, - 'graph': fields.Raw(attribute='graph_dict'), - 'features': fields.Raw(attribute='features_dict'), - 'hash': fields.String(attribute='unique_hash'), - 'created_by': fields.Nested(simple_account_fields, attribute='created_by_account'), - 'created_at': TimestampField, - 'updated_by': fields.Nested(simple_account_fields, attribute='updated_by_account', allow_null=True), - 'updated_at': TimestampField, - 'tool_published': fields.Boolean, - 'environment_variables': fields.List(EnvironmentVariableField()), + "id": fields.String, + "graph": fields.Raw(attribute="graph_dict"), + "features": fields.Raw(attribute="features_dict"), + "hash": fields.String(attribute="unique_hash"), + "created_by": fields.Nested(simple_account_fields, attribute="created_by_account"), + "created_at": TimestampField, + "updated_by": fields.Nested(simple_account_fields, attribute="updated_by_account", allow_null=True), + "updated_at": TimestampField, + "tool_published": fields.Boolean, + "environment_variables": fields.List(EnvironmentVariableField()), + "conversation_variables": fields.List(fields.Nested(conversation_variable_fields)), } diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py index 3e798473cd0481..1413adf7196879 100644 --- a/api/fields/workflow_run_fields.py +++ b/api/fields/workflow_run_fields.py @@ -13,7 +13,7 @@ "total_tokens": fields.Integer, "total_steps": fields.Integer, "created_at": TimestampField, - "finished_at": TimestampField + "finished_at": TimestampField, } workflow_run_for_list_fields = { @@ -24,9 +24,9 @@ "elapsed_time": fields.Float, "total_tokens": fields.Integer, "total_steps": fields.Integer, - "created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True), + "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True), "created_at": TimestampField, - "finished_at": TimestampField + "finished_at": TimestampField, } advanced_chat_workflow_run_for_list_fields = { @@ -39,40 +39,40 @@ "elapsed_time": fields.Float, "total_tokens": fields.Integer, "total_steps": fields.Integer, - "created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True), + "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True), "created_at": TimestampField, - "finished_at": TimestampField + "finished_at": TimestampField, } advanced_chat_workflow_run_pagination_fields = { - 'limit': fields.Integer(attribute='limit'), - 'has_more': fields.Boolean(attribute='has_more'), - 'data': fields.List(fields.Nested(advanced_chat_workflow_run_for_list_fields), attribute='data') + "limit": fields.Integer(attribute="limit"), + "has_more": fields.Boolean(attribute="has_more"), + "data": fields.List(fields.Nested(advanced_chat_workflow_run_for_list_fields), attribute="data"), } workflow_run_pagination_fields = { - 'limit': fields.Integer(attribute='limit'), - 'has_more': fields.Boolean(attribute='has_more'), - 'data': fields.List(fields.Nested(workflow_run_for_list_fields), attribute='data') + "limit": fields.Integer(attribute="limit"), + "has_more": fields.Boolean(attribute="has_more"), + "data": fields.List(fields.Nested(workflow_run_for_list_fields), attribute="data"), } workflow_run_detail_fields = { "id": fields.String, "sequence_number": fields.Integer, "version": fields.String, - "graph": fields.Raw(attribute='graph_dict'), - "inputs": fields.Raw(attribute='inputs_dict'), + "graph": fields.Raw(attribute="graph_dict"), + "inputs": fields.Raw(attribute="inputs_dict"), "status": fields.String, - "outputs": fields.Raw(attribute='outputs_dict'), + "outputs": fields.Raw(attribute="outputs_dict"), "error": fields.String, "elapsed_time": fields.Float, "total_tokens": fields.Integer, "total_steps": fields.Integer, "created_by_role": fields.String, - "created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True), - "created_by_end_user": fields.Nested(simple_end_user_fields, attribute='created_by_end_user', allow_null=True), + "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True), + "created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True), "created_at": TimestampField, - "finished_at": TimestampField + "finished_at": TimestampField, } workflow_run_node_execution_fields = { @@ -82,21 +82,21 @@ "node_id": fields.String, "node_type": fields.String, "title": fields.String, - "inputs": fields.Raw(attribute='inputs_dict'), - "process_data": fields.Raw(attribute='process_data_dict'), - "outputs": fields.Raw(attribute='outputs_dict'), + "inputs": fields.Raw(attribute="inputs_dict"), + "process_data": fields.Raw(attribute="process_data_dict"), + "outputs": fields.Raw(attribute="outputs_dict"), "status": fields.String, "error": fields.String, "elapsed_time": fields.Float, - "execution_metadata": fields.Raw(attribute='execution_metadata_dict'), + "execution_metadata": fields.Raw(attribute="execution_metadata_dict"), "extras": fields.Raw, "created_at": TimestampField, "created_by_role": fields.String, - "created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True), - "created_by_end_user": fields.Nested(simple_end_user_fields, attribute='created_by_end_user', allow_null=True), - "finished_at": TimestampField + "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True), + "created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True), + "finished_at": TimestampField, } workflow_run_node_execution_list_fields = { - 'data': fields.List(fields.Nested(workflow_run_node_execution_fields)), + "data": fields.List(fields.Nested(workflow_run_node_execution_fields)), } diff --git a/api/libs/bearer_data_source.py b/api/libs/bearer_data_source.py index 04de1fb6daefbd..c1aee7b819e411 100644 --- a/api/libs/bearer_data_source.py +++ b/api/libs/bearer_data_source.py @@ -2,10 +2,10 @@ from abc import abstractmethod import requests -from api.models.source import DataSourceBearerBinding from flask_login import current_user from extensions.ext_database import db +from models.source import DataSourceBearerBinding class BearerDataSource: diff --git a/api/libs/oauth_data_source.py b/api/libs/oauth_data_source.py index a5c7814a543bdc..358858ceb1ec4d 100644 --- a/api/libs/oauth_data_source.py +++ b/api/libs/oauth_data_source.py @@ -154,11 +154,11 @@ def get_authorized_pages(self, access_token: str): for page_result in page_results: page_id = page_result['id'] page_name = 'Untitled' - for key in ['Name', 'title', 'Title', 'Page']: - if key in page_result['properties']: - if len(page_result['properties'][key].get('title', [])) > 0: - page_name = page_result['properties'][key]['title'][0]['plain_text'] - break + for key in page_result['properties']: + if 'title' in page_result['properties'][key] and page_result['properties'][key]['title']: + title_list = page_result['properties'][key]['title'] + if len(title_list) > 0 and 'plain_text' in title_list[0]: + page_name = title_list[0]['plain_text'] page_icon = page_result['icon'] if page_icon: icon_type = page_icon['type'] diff --git a/api/migrations/versions/2024_08_13_0633-63a83fcf12ba_support_conversation_variables.py b/api/migrations/versions/2024_08_13_0633-63a83fcf12ba_support_conversation_variables.py new file mode 100644 index 00000000000000..16e1efd4efd4ed --- /dev/null +++ b/api/migrations/versions/2024_08_13_0633-63a83fcf12ba_support_conversation_variables.py @@ -0,0 +1,51 @@ +"""support conversation variables + +Revision ID: 63a83fcf12ba +Revises: 1787fbae959a +Create Date: 2024-08-13 06:33:07.950379 + +""" +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = '63a83fcf12ba' +down_revision = '1787fbae959a' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('workflow__conversation_variables', + sa.Column('id', models.types.StringUUID(), nullable=False), + sa.Column('conversation_id', models.types.StringUUID(), nullable=False), + sa.Column('app_id', models.types.StringUUID(), nullable=False), + sa.Column('data', sa.Text(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', 'conversation_id', name=op.f('workflow__conversation_variables_pkey')) + ) + with op.batch_alter_table('workflow__conversation_variables', schema=None) as batch_op: + batch_op.create_index(batch_op.f('workflow__conversation_variables_app_id_idx'), ['app_id'], unique=False) + batch_op.create_index(batch_op.f('workflow__conversation_variables_created_at_idx'), ['created_at'], unique=False) + + with op.batch_alter_table('workflows', schema=None) as batch_op: + batch_op.add_column(sa.Column('conversation_variables', sa.Text(), server_default='{}', nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('workflows', schema=None) as batch_op: + batch_op.drop_column('conversation_variables') + + with op.batch_alter_table('workflow__conversation_variables', schema=None) as batch_op: + batch_op.drop_index(batch_op.f('workflow__conversation_variables_created_at_idx')) + batch_op.drop_index(batch_op.f('workflow__conversation_variables_app_id_idx')) + + op.drop_table('workflow__conversation_variables') + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_08_14_1354-8782057ff0dc_add_conversations_dialogue_count.py b/api/migrations/versions/2024_08_14_1354-8782057ff0dc_add_conversations_dialogue_count.py new file mode 100644 index 00000000000000..eba78e2e77d5d8 --- /dev/null +++ b/api/migrations/versions/2024_08_14_1354-8782057ff0dc_add_conversations_dialogue_count.py @@ -0,0 +1,33 @@ +"""add conversations.dialogue_count + +Revision ID: 8782057ff0dc +Revises: 63a83fcf12ba +Create Date: 2024-08-14 13:54:25.161324 + +""" +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = '8782057ff0dc' +down_revision = '63a83fcf12ba' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('conversations', schema=None) as batch_op: + batch_op.add_column(sa.Column('dialogue_count', sa.Integer(), server_default='0', nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('conversations', schema=None) as batch_op: + batch_op.drop_column('dialogue_count') + + # ### end Alembic commands ### diff --git a/api/models/__init__.py b/api/models/__init__.py index 3b832cd22d8120..4012611471c337 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -1,15 +1,19 @@ from enum import Enum -from sqlalchemy import CHAR, TypeDecorator -from sqlalchemy.dialects.postgresql import UUID +from .model import App, AppMode, Message +from .types import StringUUID +from .workflow import ConversationVariable, Workflow, WorkflowNodeExecutionStatus + +__all__ = ['ConversationVariable', 'StringUUID', 'AppMode', 'WorkflowNodeExecutionStatus', 'Workflow', 'App', 'Message'] class CreatedByRole(Enum): """ Enum class for createdByRole """ - ACCOUNT = "account" - END_USER = "end_user" + + ACCOUNT = 'account' + END_USER = 'end_user' @classmethod def value_of(cls, value: str) -> 'CreatedByRole': @@ -23,49 +27,3 @@ def value_of(cls, value: str) -> 'CreatedByRole': if role.value == value: return role raise ValueError(f'invalid createdByRole value {value}') - - -class CreatedFrom(Enum): - """ - Enum class for createdFrom - """ - SERVICE_API = "service-api" - WEB_APP = "web-app" - EXPLORE = "explore" - - @classmethod - def value_of(cls, value: str) -> 'CreatedFrom': - """ - Get value of given mode. - - :param value: mode value - :return: mode - """ - for role in cls: - if role.value == value: - return role - raise ValueError(f'invalid createdFrom value {value}') - - -class StringUUID(TypeDecorator): - impl = CHAR - cache_ok = True - - def process_bind_param(self, value, dialect): - if value is None: - return value - elif dialect.name == 'postgresql': - return str(value) - else: - return value.hex - - def load_dialect_impl(self, dialect): - if dialect.name == 'postgresql': - return dialect.type_descriptor(UUID()) - else: - return dialect.type_descriptor(CHAR(36)) - - def process_result_value(self, value, dialect): - if value is None: - return value - return str(value) diff --git a/api/models/account.py b/api/models/account.py index d36b2b9fda3278..67d940b7b7190e 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -4,7 +4,8 @@ from flask_login import UserMixin from extensions.ext_database import db -from models import StringUUID + +from .types import StringUUID class AccountStatus(str, enum.Enum): diff --git a/api/models/api_based_extension.py b/api/models/api_based_extension.py index d1f9cd78a72e45..7f69323628a7cc 100644 --- a/api/models/api_based_extension.py +++ b/api/models/api_based_extension.py @@ -1,7 +1,8 @@ import enum from extensions.ext_database import db -from models import StringUUID + +from .types import StringUUID class APIBasedExtensionPoint(enum.Enum): diff --git a/api/models/dataset.py b/api/models/dataset.py index 40f9f4cf83ae96..0d48177eb60409 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -16,9 +16,10 @@ from core.rag.retrieval.retrival_methods import RetrievalMethod from extensions.ext_database import db from extensions.ext_storage import storage -from models import StringUUID -from models.account import Account -from models.model import App, Tag, TagBinding, UploadFile + +from .account import Account +from .model import App, Tag, TagBinding, UploadFile +from .types import StringUUID class Dataset(db.Model): diff --git a/api/models/model.py b/api/models/model.py index 7f837a313f8d57..94cfa527a7d792 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -7,6 +7,7 @@ from flask import request from flask_login import UserMixin from sqlalchemy import Float, func, text +from sqlalchemy.orm import Mapped, mapped_column from configs import dify_config from core.file.tool_file_parser import ToolFileParser @@ -14,8 +15,8 @@ from extensions.ext_database import db from libs.helper import generate_string -from . import StringUUID from .account import Account, Tenant +from .types import StringUUID class DifySetup(db.Model): @@ -517,12 +518,12 @@ class Conversation(db.Model): from_account_id = db.Column(StringUUID) read_at = db.Column(db.DateTime) read_account_id = db.Column(StringUUID) + dialogue_count: Mapped[int] = mapped_column(default=0) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) messages = db.relationship("Message", backref="conversation", lazy='select', passive_deletes="all") - message_annotations = db.relationship("MessageAnnotation", backref="conversation", lazy='select', - passive_deletes="all") + message_annotations = db.relationship("MessageAnnotation", backref="conversation", lazy='select', passive_deletes="all") is_deleted = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) diff --git a/api/models/provider.py b/api/models/provider.py index 4c14c33f095cee..5d92ee6eb60d18 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -1,7 +1,8 @@ from enum import Enum from extensions.ext_database import db -from models import StringUUID + +from .types import StringUUID class ProviderType(Enum): diff --git a/api/models/source.py b/api/models/source.py index 265e68f014c6c2..adc00028bee43b 100644 --- a/api/models/source.py +++ b/api/models/source.py @@ -3,7 +3,8 @@ from sqlalchemy.dialects.postgresql import JSONB from extensions.ext_database import db -from models import StringUUID + +from .types import StringUUID class DataSourceOauthBinding(db.Model): diff --git a/api/models/tool.py b/api/models/tool.py index f322944f5f0a8e..79a70c6b1f2d22 100644 --- a/api/models/tool.py +++ b/api/models/tool.py @@ -2,7 +2,8 @@ from enum import Enum from extensions.ext_database import db -from models import StringUUID + +from .types import StringUUID class ToolProviderName(Enum): diff --git a/api/models/tools.py b/api/models/tools.py index 695ec26fbf515e..069dc5bad083c8 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -6,8 +6,9 @@ from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration from extensions.ext_database import db -from models import StringUUID -from models.model import Account, App, Tenant + +from .model import Account, App, Tenant +from .types import StringUUID class BuiltinToolProvider(db.Model): diff --git a/api/models/types.py b/api/models/types.py new file mode 100644 index 00000000000000..1614ec20188541 --- /dev/null +++ b/api/models/types.py @@ -0,0 +1,26 @@ +from sqlalchemy import CHAR, TypeDecorator +from sqlalchemy.dialects.postgresql import UUID + + +class StringUUID(TypeDecorator): + impl = CHAR + cache_ok = True + + def process_bind_param(self, value, dialect): + if value is None: + return value + elif dialect.name == 'postgresql': + return str(value) + else: + return value.hex + + def load_dialect_impl(self, dialect): + if dialect.name == 'postgresql': + return dialect.type_descriptor(UUID()) + else: + return dialect.type_descriptor(CHAR(36)) + + def process_result_value(self, value, dialect): + if value is None: + return value + return str(value) \ No newline at end of file diff --git a/api/models/web.py b/api/models/web.py index 6fd27206a972db..0e901d5f842691 100644 --- a/api/models/web.py +++ b/api/models/web.py @@ -1,7 +1,8 @@ from extensions.ext_database import db -from models import StringUUID -from models.model import Message + +from .model import Message +from .types import StringUUID class SavedMessage(db.Model): diff --git a/api/models/workflow.py b/api/models/workflow.py index df2269cd0fb6cc..759e07c7154e0d 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -3,18 +3,18 @@ from enum import Enum from typing import Any, Optional, Union +from sqlalchemy import func +from sqlalchemy.orm import Mapped + import contexts from constants import HIDDEN_VALUE -from core.app.segments import ( - SecretVariable, - Variable, - factory, -) +from core.app.segments import SecretVariable, Variable, factory from core.helper import encrypter from extensions.ext_database import db from libs import helper -from models import StringUUID -from models.account import Account + +from .account import Account +from .types import StringUUID class CreatedByRole(Enum): @@ -122,6 +122,7 @@ class Workflow(db.Model): updated_by = db.Column(StringUUID) updated_at = db.Column(db.DateTime) _environment_variables = db.Column('environment_variables', db.Text, nullable=False, server_default='{}') + _conversation_variables = db.Column('conversation_variables', db.Text, nullable=False, server_default='{}') @property def created_by_account(self): @@ -249,9 +250,27 @@ def to_dict(self, *, include_secret: bool = False) -> Mapping[str, Any]: 'graph': self.graph_dict, 'features': self.features_dict, 'environment_variables': [var.model_dump(mode='json') for var in environment_variables], + 'conversation_variables': [var.model_dump(mode='json') for var in self.conversation_variables], } return result + @property + def conversation_variables(self) -> Sequence[Variable]: + # TODO: find some way to init `self._conversation_variables` when instance created. + if self._conversation_variables is None: + self._conversation_variables = '{}' + + variables_dict: dict[str, Any] = json.loads(self._conversation_variables) + results = [factory.build_variable_from_mapping(v) for v in variables_dict.values()] + return results + + @conversation_variables.setter + def conversation_variables(self, value: Sequence[Variable]) -> None: + self._conversation_variables = json.dumps( + {var.name: var.model_dump() for var in value}, + ensure_ascii=False, + ) + class WorkflowRunTriggeredFrom(Enum): """ @@ -702,3 +721,34 @@ def created_by_end_user(self): created_by_role = CreatedByRole.value_of(self.created_by_role) return db.session.get(EndUser, self.created_by) \ if created_by_role == CreatedByRole.END_USER else None + + +class ConversationVariable(db.Model): + __tablename__ = 'workflow__conversation_variables' + + id: Mapped[str] = db.Column(StringUUID, primary_key=True) + conversation_id: Mapped[str] = db.Column(StringUUID, nullable=False, primary_key=True) + app_id: Mapped[str] = db.Column(StringUUID, nullable=False, index=True) + data = db.Column(db.Text, nullable=False) + created_at = db.Column(db.DateTime, nullable=False, index=True, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()) + + def __init__(self, *, id: str, app_id: str, conversation_id: str, data: str) -> None: + self.id = id + self.app_id = app_id + self.conversation_id = conversation_id + self.data = data + + @classmethod + def from_variable(cls, *, app_id: str, conversation_id: str, variable: Variable) -> 'ConversationVariable': + obj = cls( + id=variable.id, + app_id=app_id, + conversation_id=conversation_id, + data=variable.model_dump_json(), + ) + return obj + + def to_variable(self) -> Variable: + mapping = json.loads(self.data) + return factory.build_variable_from_mapping(mapping) diff --git a/api/poetry.lock b/api/poetry.lock index 89d017f656a2f6..358f9f8510c724 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -2100,6 +2100,44 @@ primp = ">=0.5.5" dev = ["mypy (>=1.11.0)", "pytest (>=8.3.1)", "pytest-asyncio (>=0.23.8)", "ruff (>=0.5.5)"] lxml = ["lxml (>=5.2.2)"] +[[package]] +name = "elastic-transport" +version = "8.15.0" +description = "Transport classes and utilities shared among Python Elastic client libraries" +optional = false +python-versions = ">=3.8" +files = [ + {file = "elastic_transport-8.15.0-py3-none-any.whl", hash = "sha256:d7080d1dada2b4eee69e7574f9c17a76b42f2895eff428e562f94b0360e158c0"}, + {file = "elastic_transport-8.15.0.tar.gz", hash = "sha256:85d62558f9baafb0868c801233a59b235e61d7b4804c28c2fadaa866b6766233"}, +] + +[package.dependencies] +certifi = "*" +urllib3 = ">=1.26.2,<3" + +[package.extras] +develop = ["aiohttp", "furo", "httpx", "opentelemetry-api", "opentelemetry-sdk", "orjson", "pytest", "pytest-asyncio", "pytest-cov", "pytest-httpserver", "pytest-mock", "requests", "respx", "sphinx (>2)", "sphinx-autodoc-typehints", "trustme"] + +[[package]] +name = "elasticsearch" +version = "8.14.0" +description = "Python client for Elasticsearch" +optional = false +python-versions = ">=3.7" +files = [ + {file = "elasticsearch-8.14.0-py3-none-any.whl", hash = "sha256:cef8ef70a81af027f3da74a4f7d9296b390c636903088439087b8262a468c130"}, + {file = "elasticsearch-8.14.0.tar.gz", hash = "sha256:aa2490029dd96f4015b333c1827aa21fd6c0a4d223b00dfb0fe933b8d09a511b"}, +] + +[package.dependencies] +elastic-transport = ">=8.13,<9" + +[package.extras] +async = ["aiohttp (>=3,<4)"] +orjson = ["orjson (>=3)"] +requests = ["requests (>=2.4.0,!=2.32.2,<3.0.0)"] +vectorstore-mmr = ["numpy (>=1)", "simsimd (>=3)"] + [[package]] name = "emoji" version = "2.12.1" @@ -9546,4 +9584,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "2b822039247a445f72e04e967aef84f841781e2789b70071acad022f36ba26a5" +content-hash = "05dfa6b9bce9ed8ac21caf58eff1596f146080ab2ab6987924b189be673c22cf" diff --git a/api/pyproject.toml b/api/pyproject.toml index 058d67c42fe11c..3e107f5e9b0bcc 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -69,7 +69,18 @@ ignore = [ ] [tool.ruff.format] -quote-style = "single" +exclude = [ + "core/**/*.py", + "controllers/**/*.py", + "models/**/*.py", + "utils/**/*.py", + "migrations/**/*", + "services/**/*.py", + "tasks/**/*.py", + "tests/**/*.py", + "libs/**/*.py", + "configs/**/*.py", +] [tool.pytest_env] OPENAI_API_KEY = "sk-IamNotARealKeyJustForMockTestKawaiiiiiiiiii" @@ -181,6 +192,7 @@ zhipuai = "1.0.7" rank-bm25 = "~0.2.2" openpyxl = "^3.1.5" kaleido = "0.2.1" +elasticsearch = "8.14.0" ############################################################ # Tool dependencies required by tool implementations diff --git a/api/schedule/clean_embedding_cache_task.py b/api/schedule/clean_embedding_cache_task.py index ccc1062266a02f..67d070682867bb 100644 --- a/api/schedule/clean_embedding_cache_task.py +++ b/api/schedule/clean_embedding_cache_task.py @@ -11,27 +11,32 @@ from models.dataset import Embedding -@app.celery.task(queue='dataset') +@app.celery.task(queue="dataset") def clean_embedding_cache_task(): - click.echo(click.style('Start clean embedding cache.', fg='green')) + click.echo(click.style("Start clean embedding cache.", fg="green")) clean_days = int(dify_config.CLEAN_DAY_SETTING) start_at = time.perf_counter() thirty_days_ago = datetime.datetime.now() - datetime.timedelta(days=clean_days) while True: try: - embedding_ids = db.session.query(Embedding.id).filter(Embedding.created_at < thirty_days_ago) \ - .order_by(Embedding.created_at.desc()).limit(100).all() + embedding_ids = ( + db.session.query(Embedding.id) + .filter(Embedding.created_at < thirty_days_ago) + .order_by(Embedding.created_at.desc()) + .limit(100) + .all() + ) embedding_ids = [embedding_id[0] for embedding_id in embedding_ids] except NotFound: break if embedding_ids: for embedding_id in embedding_ids: - db.session.execute(text( - "DELETE FROM embeddings WHERE id = :embedding_id" - ), {'embedding_id': embedding_id}) + db.session.execute( + text("DELETE FROM embeddings WHERE id = :embedding_id"), {"embedding_id": embedding_id} + ) db.session.commit() else: break end_at = time.perf_counter() - click.echo(click.style('Cleaned embedding cache from db success latency: {}'.format(end_at - start_at), fg='green')) + click.echo(click.style("Cleaned embedding cache from db success latency: {}".format(end_at - start_at), fg="green")) diff --git a/api/schedule/clean_unused_datasets_task.py b/api/schedule/clean_unused_datasets_task.py index b2b2f82b786f5e..3d799bfd4ef732 100644 --- a/api/schedule/clean_unused_datasets_task.py +++ b/api/schedule/clean_unused_datasets_task.py @@ -12,9 +12,9 @@ from models.dataset import Dataset, DatasetQuery, Document -@app.celery.task(queue='dataset') +@app.celery.task(queue="dataset") def clean_unused_datasets_task(): - click.echo(click.style('Start clean unused datasets indexes.', fg='green')) + click.echo(click.style("Start clean unused datasets indexes.", fg="green")) clean_days = dify_config.CLEAN_DAY_SETTING start_at = time.perf_counter() thirty_days_ago = datetime.datetime.now() - datetime.timedelta(days=clean_days) @@ -22,40 +22,44 @@ def clean_unused_datasets_task(): while True: try: # Subquery for counting new documents - document_subquery_new = db.session.query( - Document.dataset_id, - func.count(Document.id).label('document_count') - ).filter( - Document.indexing_status == 'completed', - Document.enabled == True, - Document.archived == False, - Document.updated_at > thirty_days_ago - ).group_by(Document.dataset_id).subquery() + document_subquery_new = ( + db.session.query(Document.dataset_id, func.count(Document.id).label("document_count")) + .filter( + Document.indexing_status == "completed", + Document.enabled == True, + Document.archived == False, + Document.updated_at > thirty_days_ago, + ) + .group_by(Document.dataset_id) + .subquery() + ) # Subquery for counting old documents - document_subquery_old = db.session.query( - Document.dataset_id, - func.count(Document.id).label('document_count') - ).filter( - Document.indexing_status == 'completed', - Document.enabled == True, - Document.archived == False, - Document.updated_at < thirty_days_ago - ).group_by(Document.dataset_id).subquery() + document_subquery_old = ( + db.session.query(Document.dataset_id, func.count(Document.id).label("document_count")) + .filter( + Document.indexing_status == "completed", + Document.enabled == True, + Document.archived == False, + Document.updated_at < thirty_days_ago, + ) + .group_by(Document.dataset_id) + .subquery() + ) # Main query with join and filter - datasets = (db.session.query(Dataset) - .outerjoin( - document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id - ).outerjoin( - document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id - ).filter( - Dataset.created_at < thirty_days_ago, - func.coalesce(document_subquery_new.c.document_count, 0) == 0, - func.coalesce(document_subquery_old.c.document_count, 0) > 0 - ).order_by( - Dataset.created_at.desc() - ).paginate(page=page, per_page=50)) + datasets = ( + db.session.query(Dataset) + .outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id) + .outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id) + .filter( + Dataset.created_at < thirty_days_ago, + func.coalesce(document_subquery_new.c.document_count, 0) == 0, + func.coalesce(document_subquery_old.c.document_count, 0) > 0, + ) + .order_by(Dataset.created_at.desc()) + .paginate(page=page, per_page=50) + ) except NotFound: break @@ -63,10 +67,11 @@ def clean_unused_datasets_task(): break page += 1 for dataset in datasets: - dataset_query = db.session.query(DatasetQuery).filter( - DatasetQuery.created_at > thirty_days_ago, - DatasetQuery.dataset_id == dataset.id - ).all() + dataset_query = ( + db.session.query(DatasetQuery) + .filter(DatasetQuery.created_at > thirty_days_ago, DatasetQuery.dataset_id == dataset.id) + .all() + ) if not dataset_query or len(dataset_query) == 0: try: # remove index @@ -74,17 +79,14 @@ def clean_unused_datasets_task(): index_processor.clean(dataset, None) # update document - update_params = { - Document.enabled: False - } + update_params = {Document.enabled: False} Document.query.filter_by(dataset_id=dataset.id).update(update_params) db.session.commit() - click.echo(click.style('Cleaned unused dataset {} from db success!'.format(dataset.id), - fg='green')) + click.echo(click.style("Cleaned unused dataset {} from db success!".format(dataset.id), fg="green")) except Exception as e: click.echo( - click.style('clean dataset index error: {} {}'.format(e.__class__.__name__, str(e)), - fg='red')) + click.style("clean dataset index error: {} {}".format(e.__class__.__name__, str(e)), fg="red") + ) end_at = time.perf_counter() - click.echo(click.style('Cleaned unused dataset from db success latency: {}'.format(end_at - start_at), fg='green')) + click.echo(click.style("Cleaned unused dataset from db success latency: {}".format(end_at - start_at), fg="green")) diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index af0da2e87ec525..11ebbbdaf4c74a 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -13,9 +13,9 @@ logger = logging.getLogger(__name__) -current_dsl_version = "0.1.0" +current_dsl_version = "0.1.1" dsl_to_dify_version_mapping: dict[str, str] = { - "0.1.0": "0.6.0", # dsl version -> from dify version + "0.1.1": "0.6.0", # dsl version -> from dify version } @@ -244,6 +244,8 @@ def _import_and_create_new_workflow_based_app(cls, # init draft workflow environment_variables_list = workflow_data.get('environment_variables') or [] environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list] + conversation_variables_list = workflow_data.get('conversation_variables') or [] + conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list] workflow_service = WorkflowService() draft_workflow = workflow_service.sync_draft_workflow( app_model=app, @@ -252,6 +254,7 @@ def _import_and_create_new_workflow_based_app(cls, unique_hash=None, account=account, environment_variables=environment_variables, + conversation_variables=conversation_variables, ) workflow_service.publish_workflow( app_model=app, diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 74c32f5097314d..610330eda55727 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -6,7 +6,6 @@ DatasetRetrieveConfigEntity, EasyUIBasedAppConfig, ExternalDataVariableEntity, - FileExtraConfig, ModelConfigEntity, PromptTemplateEntity, VariableEntity, @@ -14,6 +13,7 @@ from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager from core.app.apps.chat.app_config_manager import ChatAppConfigManager from core.app.apps.completion.app_config_manager import CompletionAppConfigManager +from core.file.file_obj import FileExtraConfig from core.helper import encrypter from core.model_runtime.entities.llm_entities import LLMMode from core.model_runtime.utils.encoders import jsonable_encoder diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 6101ead1d5470b..c593b66f363dc7 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -72,6 +72,7 @@ def sync_draft_workflow( unique_hash: Optional[str], account: Account, environment_variables: Sequence[Variable], + conversation_variables: Sequence[Variable], ) -> Workflow: """ Sync draft workflow @@ -99,7 +100,8 @@ def sync_draft_workflow( graph=json.dumps(graph), features=json.dumps(features), created_by=account.id, - environment_variables=environment_variables + environment_variables=environment_variables, + conversation_variables=conversation_variables, ) db.session.add(workflow) # update draft workflow if found @@ -109,6 +111,7 @@ def sync_draft_workflow( workflow.updated_by = account.id workflow.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) workflow.environment_variables = environment_variables + workflow.conversation_variables = conversation_variables # commit db session changes db.session.commit() @@ -145,7 +148,8 @@ def publish_workflow(self, app_model: App, graph=draft_workflow.graph, features=draft_workflow.features, created_by=account.id, - environment_variables=draft_workflow.environment_variables + environment_variables=draft_workflow.environment_variables, + conversation_variables=draft_workflow.conversation_variables, ) # commit db session changes @@ -337,8 +341,8 @@ def get_elapsed_time(cls, workflow_run_id: str) -> float: ) if not workflow_nodes: return elapsed_time - + for node in workflow_nodes: elapsed_time += node.elapsed_time - return elapsed_time \ No newline at end of file + return elapsed_time diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py index 378756e68c8202..4efe7ee38c0b32 100644 --- a/api/tasks/remove_app_and_related_data_task.py +++ b/api/tasks/remove_app_and_related_data_task.py @@ -1,8 +1,10 @@ import logging import time +from collections.abc import Callable import click from celery import shared_task +from sqlalchemy import delete from sqlalchemy.exc import SQLAlchemyError from extensions.ext_database import db @@ -28,7 +30,7 @@ ) from models.tools import WorkflowToolProvider from models.web import PinnedConversation, SavedMessage -from models.workflow import Workflow, WorkflowAppLog, WorkflowNodeExecution, WorkflowRun +from models.workflow import ConversationVariable, Workflow, WorkflowAppLog, WorkflowNodeExecution, WorkflowRun @shared_task(queue='app_deletion', bind=True, max_retries=3) @@ -54,6 +56,7 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str): _delete_app_tag_bindings(tenant_id, app_id) _delete_end_users(tenant_id, app_id) _delete_trace_app_configs(tenant_id, app_id) + _delete_conversation_variables(app_id=app_id) end_at = time.perf_counter() logging.info(click.style(f'App and related data deleted: {app_id} latency: {end_at - start_at}', fg='green')) @@ -225,6 +228,13 @@ def del_conversation(conversation_id: str): "conversation" ) +def _delete_conversation_variables(*, app_id: str): + stmt = delete(ConversationVariable).where(ConversationVariable.app_id == app_id) + with db.engine.connect() as conn: + conn.execute(stmt) + conn.commit() + logging.info(click.style(f"Deleted conversation variables for app {app_id}", fg='green')) + def _delete_app_messages(tenant_id: str, app_id: str): def del_message(message_id: str): @@ -299,7 +309,7 @@ def del_trace_app_config(trace_app_config_id: str): ) -def _delete_records(query_sql: str, params: dict, delete_func: callable, name: str) -> None: +def _delete_records(query_sql: str, params: dict, delete_func: Callable, name: str) -> None: while True: with db.engine.begin() as conn: rs = conn.execute(db.text(query_sql), params) diff --git a/api/tests/integration_tests/model_runtime/__mock/huggingface_tei.py b/api/tests/integration_tests/model_runtime/__mock/huggingface_tei.py index 2f66d707ca618b..c2fe95974b10f1 100644 --- a/api/tests/integration_tests/model_runtime/__mock/huggingface_tei.py +++ b/api/tests/integration_tests/model_runtime/__mock/huggingface_tei.py @@ -1,5 +1,4 @@ - -from api.core.model_runtime.model_providers.huggingface_tei.tei_helper import TeiModelExtraParameter +from core.model_runtime.model_providers.huggingface_tei.tei_helper import TeiModelExtraParameter class MockTEIClass: @@ -12,7 +11,7 @@ def get_tei_extra_parameter(server_url: str, model_name: str) -> TeiModelExtraPa model_type = 'embedding' return TeiModelExtraParameter(model_type=model_type, max_input_length=512, max_client_batch_size=1) - + @staticmethod def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]: # Use space as token separator, and split the text into tokens diff --git a/api/tests/integration_tests/model_runtime/huggingface_tei/test_embeddings.py b/api/tests/integration_tests/model_runtime/huggingface_tei/test_embeddings.py index da65c7dfc7c92d..ed371fbc07aa8d 100644 --- a/api/tests/integration_tests/model_runtime/huggingface_tei/test_embeddings.py +++ b/api/tests/integration_tests/model_runtime/huggingface_tei/test_embeddings.py @@ -1,12 +1,12 @@ import os import pytest -from api.core.model_runtime.model_providers.huggingface_tei.text_embedding.text_embedding import TeiHelper from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.huggingface_tei.text_embedding.text_embedding import ( HuggingfaceTeiTextEmbeddingModel, + TeiHelper, ) from tests.integration_tests.model_runtime.__mock.huggingface_tei import MockTEIClass diff --git a/api/tests/integration_tests/model_runtime/openai_api_compatible/test_speech2text.py b/api/tests/integration_tests/model_runtime/openai_api_compatible/test_speech2text.py new file mode 100644 index 00000000000000..61079104dcad73 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/openai_api_compatible/test_speech2text.py @@ -0,0 +1,59 @@ +import os + +import pytest + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.openai_api_compatible.speech2text.speech2text import ( + OAICompatSpeech2TextModel, +) + + +def test_validate_credentials(): + model = OAICompatSpeech2TextModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="whisper-1", + credentials={ + "api_key": "invalid_key", + "endpoint_url": "https://api.openai.com/v1/" + }, + ) + + model.validate_credentials( + model="whisper-1", + credentials={ + "api_key": os.environ.get("OPENAI_API_KEY"), + "endpoint_url": "https://api.openai.com/v1/" + }, + ) + + +def test_invoke_model(): + model = OAICompatSpeech2TextModel() + + # Get the directory of the current file + current_dir = os.path.dirname(os.path.abspath(__file__)) + + # Get assets directory + assets_dir = os.path.join(os.path.dirname(current_dir), "assets") + + # Construct the path to the audio file + audio_file_path = os.path.join(assets_dir, "audio.mp3") + + # Open the file and get the file object + with open(audio_file_path, "rb") as audio_file: + file = audio_file + + result = model.invoke( + model="whisper-1", + credentials={ + "api_key": os.environ.get("OPENAI_API_KEY"), + "endpoint_url": "https://api.openai.com/v1/" + }, + file=file, + user="abc-123", + ) + + assert isinstance(result, str) + assert result == '1, 2, 3, 4, 5, 6, 7, 8, 9, 10' diff --git a/api/tests/integration_tests/model_runtime/siliconflow/test_speech2text.py b/api/tests/integration_tests/model_runtime/siliconflow/test_speech2text.py new file mode 100644 index 00000000000000..82b7921c8506f0 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/siliconflow/test_speech2text.py @@ -0,0 +1,53 @@ +import os + +import pytest + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.siliconflow.speech2text.speech2text import SiliconflowSpeech2TextModel + + +def test_validate_credentials(): + model = SiliconflowSpeech2TextModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="iic/SenseVoiceSmall", + credentials={ + "api_key": "invalid_key" + }, + ) + + model.validate_credentials( + model="iic/SenseVoiceSmall", + credentials={ + "api_key": os.environ.get("API_KEY") + }, + ) + + +def test_invoke_model(): + model = SiliconflowSpeech2TextModel() + + # Get the directory of the current file + current_dir = os.path.dirname(os.path.abspath(__file__)) + + # Get assets directory + assets_dir = os.path.join(os.path.dirname(current_dir), "assets") + + # Construct the path to the audio file + audio_file_path = os.path.join(assets_dir, "audio.mp3") + + # Open the file and get the file object + with open(audio_file_path, "rb") as audio_file: + file = audio_file + + result = model.invoke( + model="iic/SenseVoiceSmall", + credentials={ + "api_key": os.environ.get("API_KEY") + }, + file=file + ) + + assert isinstance(result, str) + assert result == '1,2,3,4,5,6,7,8,9,10.' diff --git a/api/tests/integration_tests/vdb/elasticsearch/__init__.py b/api/tests/integration_tests/vdb/elasticsearch/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/tests/integration_tests/vdb/elasticsearch/test_elasticsearch.py b/api/tests/integration_tests/vdb/elasticsearch/test_elasticsearch.py new file mode 100644 index 00000000000000..b1c1cc10d9375d --- /dev/null +++ b/api/tests/integration_tests/vdb/elasticsearch/test_elasticsearch.py @@ -0,0 +1,25 @@ +from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchConfig, ElasticSearchVector +from tests.integration_tests.vdb.test_vector_store import ( + AbstractVectorTest, + setup_mock_redis, +) + + +class ElasticSearchVectorTest(AbstractVectorTest): + def __init__(self): + super().__init__() + self.attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash'] + self.vector = ElasticSearchVector( + index_name=self.collection_name.lower(), + config=ElasticSearchConfig( + host='http://localhost', + port='9200', + username='elastic', + password='elastic' + ), + attributes=self.attributes + ) + + +def test_elasticsearch_vector(setup_mock_redis): + ElasticSearchVectorTest().run_all_tests() diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index ac704e4eaf54df..4686ce06752ed5 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -10,8 +10,8 @@ from core.model_manager import ModelInstance from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.model_providers import ModelProviderFactory -from core.workflow.entities.node_entities import SystemVariable from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariable from core.workflow.nodes.base_node import UserFrom from core.workflow.nodes.llm.llm_node import LLMNode from extensions.ext_database import db @@ -236,4 +236,4 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock): assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert 'sunny' in json.dumps(result.process_data) - assert 'what\'s the weather today?' in json.dumps(result.process_data) \ No newline at end of file + assert 'what\'s the weather today?' in json.dumps(result.process_data) diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index 312ad47026beb5..adf5ffe3cadf77 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -12,8 +12,8 @@ from core.model_manager import ModelInstance from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory -from core.workflow.entities.node_entities import SystemVariable from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariable from core.workflow.nodes.base_node import UserFrom from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode from extensions.ext_database import db @@ -363,7 +363,7 @@ def test_extract_json_response(): { "location": "kawaii" } - hello world. + hello world. """) assert result['location'] == 'kawaii' @@ -445,4 +445,4 @@ def test_chat_parameter_extractor_with_memory(setup_anthropic_mock): assert latest_role != prompt.get('role') if prompt.get('role') in ['user', 'assistant']: - latest_role = prompt.get('role') \ No newline at end of file + latest_role = prompt.get('role') diff --git a/api/tests/unit_tests/core/app/segments/test_factory.py b/api/tests/unit_tests/core/app/segments/test_factory.py index a88dd939bbc3ee..afd0fa50b590f8 100644 --- a/api/tests/unit_tests/core/app/segments/test_factory.py +++ b/api/tests/unit_tests/core/app/segments/test_factory.py @@ -3,19 +3,17 @@ import pytest from core.app.segments import ( - ArrayFileVariable, ArrayNumberVariable, ArrayObjectVariable, ArrayStringVariable, - FileVariable, FloatVariable, IntegerVariable, - NoneSegment, ObjectSegment, SecretVariable, StringVariable, factory, ) +from core.app.segments.exc import VariableError def test_string_variable(): @@ -44,7 +42,7 @@ def test_secret_variable(): def test_invalid_value_type(): test_data = {'value_type': 'unknown', 'name': 'test_invalid', 'value': 'value'} - with pytest.raises(ValueError): + with pytest.raises(VariableError): factory.build_variable_from_mapping(test_data) @@ -77,26 +75,14 @@ def test_object_variable(): 'name': 'test_object', 'description': 'Description of the variable.', 'value': { - 'key1': { - 'id': str(uuid4()), - 'value_type': 'string', - 'name': 'text', - 'value': 'text', - 'description': 'Description of the variable.', - }, - 'key2': { - 'id': str(uuid4()), - 'value_type': 'number', - 'name': 'number', - 'value': 1, - 'description': 'Description of the variable.', - }, + 'key1': 'text', + 'key2': 2, }, } variable = factory.build_variable_from_mapping(mapping) assert isinstance(variable, ObjectSegment) - assert isinstance(variable.value['key1'], StringVariable) - assert isinstance(variable.value['key2'], IntegerVariable) + assert isinstance(variable.value['key1'], str) + assert isinstance(variable.value['key2'], int) def test_array_string_variable(): @@ -106,26 +92,14 @@ def test_array_string_variable(): 'name': 'test_array', 'description': 'Description of the variable.', 'value': [ - { - 'id': str(uuid4()), - 'value_type': 'string', - 'name': 'text', - 'value': 'text', - 'description': 'Description of the variable.', - }, - { - 'id': str(uuid4()), - 'value_type': 'string', - 'name': 'text', - 'value': 'text', - 'description': 'Description of the variable.', - }, + 'text', + 'text', ], } variable = factory.build_variable_from_mapping(mapping) assert isinstance(variable, ArrayStringVariable) - assert isinstance(variable.value[0], StringVariable) - assert isinstance(variable.value[1], StringVariable) + assert isinstance(variable.value[0], str) + assert isinstance(variable.value[1], str) def test_array_number_variable(): @@ -135,26 +109,14 @@ def test_array_number_variable(): 'name': 'test_array', 'description': 'Description of the variable.', 'value': [ - { - 'id': str(uuid4()), - 'value_type': 'number', - 'name': 'number', - 'value': 1, - 'description': 'Description of the variable.', - }, - { - 'id': str(uuid4()), - 'value_type': 'number', - 'name': 'number', - 'value': 2.0, - 'description': 'Description of the variable.', - }, + 1, + 2.0, ], } variable = factory.build_variable_from_mapping(mapping) assert isinstance(variable, ArrayNumberVariable) - assert isinstance(variable.value[0], IntegerVariable) - assert isinstance(variable.value[1], FloatVariable) + assert isinstance(variable.value[0], int) + assert isinstance(variable.value[1], float) def test_array_object_variable(): @@ -165,143 +127,32 @@ def test_array_object_variable(): 'description': 'Description of the variable.', 'value': [ { - 'id': str(uuid4()), - 'value_type': 'object', - 'name': 'object', - 'description': 'Description of the variable.', - 'value': { - 'key1': { - 'id': str(uuid4()), - 'value_type': 'string', - 'name': 'text', - 'value': 'text', - 'description': 'Description of the variable.', - }, - 'key2': { - 'id': str(uuid4()), - 'value_type': 'number', - 'name': 'number', - 'value': 1, - 'description': 'Description of the variable.', - }, - }, + 'key1': 'text', + 'key2': 1, }, { - 'id': str(uuid4()), - 'value_type': 'object', - 'name': 'object', - 'description': 'Description of the variable.', - 'value': { - 'key1': { - 'id': str(uuid4()), - 'value_type': 'string', - 'name': 'text', - 'value': 'text', - 'description': 'Description of the variable.', - }, - 'key2': { - 'id': str(uuid4()), - 'value_type': 'number', - 'name': 'number', - 'value': 1, - 'description': 'Description of the variable.', - }, - }, + 'key1': 'text', + 'key2': 1, }, ], } variable = factory.build_variable_from_mapping(mapping) assert isinstance(variable, ArrayObjectVariable) - assert isinstance(variable.value[0], ObjectSegment) - assert isinstance(variable.value[1], ObjectSegment) - assert isinstance(variable.value[0].value['key1'], StringVariable) - assert isinstance(variable.value[0].value['key2'], IntegerVariable) - assert isinstance(variable.value[1].value['key1'], StringVariable) - assert isinstance(variable.value[1].value['key2'], IntegerVariable) + assert isinstance(variable.value[0], dict) + assert isinstance(variable.value[1], dict) + assert isinstance(variable.value[0]['key1'], str) + assert isinstance(variable.value[0]['key2'], int) + assert isinstance(variable.value[1]['key1'], str) + assert isinstance(variable.value[1]['key2'], int) -def test_file_variable(): - mapping = { - 'id': str(uuid4()), - 'value_type': 'file', - 'name': 'test_file', - 'description': 'Description of the variable.', - 'value': { - 'id': str(uuid4()), - 'tenant_id': 'tenant_id', - 'type': 'image', - 'transfer_method': 'local_file', - 'url': 'url', - 'related_id': 'related_id', - 'extra_config': { - 'image_config': { - 'width': 100, - 'height': 100, - }, - }, - 'filename': 'filename', - 'extension': 'extension', - 'mime_type': 'mime_type', - }, - } - variable = factory.build_variable_from_mapping(mapping) - assert isinstance(variable, FileVariable) - - -def test_array_file_variable(): - mapping = { - 'id': str(uuid4()), - 'value_type': 'array[file]', - 'name': 'test_array_file', - 'description': 'Description of the variable.', - 'value': [ - { - 'id': str(uuid4()), - 'name': 'file', - 'value_type': 'file', - 'value': { - 'id': str(uuid4()), - 'tenant_id': 'tenant_id', - 'type': 'image', - 'transfer_method': 'local_file', - 'url': 'url', - 'related_id': 'related_id', - 'extra_config': { - 'image_config': { - 'width': 100, - 'height': 100, - }, - }, - 'filename': 'filename', - 'extension': 'extension', - 'mime_type': 'mime_type', - }, - }, +def test_variable_cannot_large_than_5_kb(): + with pytest.raises(VariableError): + factory.build_variable_from_mapping( { 'id': str(uuid4()), - 'name': 'file', - 'value_type': 'file', - 'value': { - 'id': str(uuid4()), - 'tenant_id': 'tenant_id', - 'type': 'image', - 'transfer_method': 'local_file', - 'url': 'url', - 'related_id': 'related_id', - 'extra_config': { - 'image_config': { - 'width': 100, - 'height': 100, - }, - }, - 'filename': 'filename', - 'extension': 'extension', - 'mime_type': 'mime_type', - }, - }, - ], - } - variable = factory.build_variable_from_mapping(mapping) - assert isinstance(variable, ArrayFileVariable) - assert isinstance(variable.value[0], FileVariable) - assert isinstance(variable.value[1], FileVariable) + 'value_type': 'string', + 'name': 'test_text', + 'value': 'a' * 1024 * 6, + } + ) diff --git a/api/tests/unit_tests/core/app/segments/test_segment.py b/api/tests/unit_tests/core/app/segments/test_segment.py index 414404b7d0362a..7e3e69ffbfc45d 100644 --- a/api/tests/unit_tests/core/app/segments/test_segment.py +++ b/api/tests/unit_tests/core/app/segments/test_segment.py @@ -1,7 +1,7 @@ from core.app.segments import SecretVariable, StringSegment, parser from core.helper import encrypter -from core.workflow.entities.node_entities import SystemVariable from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariable def test_segment_group_to_text(): diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py index fd284488b548fe..d24cd4aae98ded 100644 --- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py @@ -2,8 +2,8 @@ import pytest -from core.app.app_config.entities import FileExtraConfig, ModelConfigEntity -from core.file.file_obj import FileTransferMethod, FileType, FileVar +from core.app.app_config.entities import ModelConfigEntity +from core.file.file_obj import FileExtraConfig, FileTransferMethod, FileType, FileVar from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessageRole, UserPromptMessage from core.prompt.advanced_prompt_transform import AdvancedPromptTransform diff --git a/api/tests/unit_tests/core/workflow/nodes/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/test_answer.py index 3a32829e373c28..4617b6a42f8ec2 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_answer.py @@ -1,8 +1,8 @@ from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities.node_entities import SystemVariable from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariable from core.workflow.nodes.answer.answer_node import AnswerNode from core.workflow.nodes.base_node import UserFrom from extensions.ext_database import db diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index 4662c5ff2b26d8..d21b7785c4f4a4 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -1,8 +1,8 @@ from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities.node_entities import SystemVariable from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariable from core.workflow.nodes.base_node import UserFrom from core.workflow.nodes.if_else.if_else_node import IfElseNode from extensions.ext_database import db diff --git a/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py b/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py new file mode 100644 index 00000000000000..0b37d06fc069bc --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py @@ -0,0 +1,150 @@ +from unittest import mock +from uuid import uuid4 + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.segments import ArrayStringVariable, StringVariable +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariable +from core.workflow.nodes.base_node import UserFrom +from core.workflow.nodes.variable_assigner import VariableAssignerNode, WriteMode + +DEFAULT_NODE_ID = 'node_id' + + +def test_overwrite_string_variable(): + conversation_variable = StringVariable( + id=str(uuid4()), + name='test_conversation_variable', + value='the first value', + ) + + input_variable = StringVariable( + id=str(uuid4()), + name='test_string_variable', + value='the second value', + ) + + node = VariableAssignerNode( + tenant_id='tenant_id', + app_id='app_id', + workflow_id='workflow_id', + user_id='user_id', + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + config={ + 'id': 'node_id', + 'data': { + 'assigned_variable_selector': ['conversation', conversation_variable.name], + 'write_mode': WriteMode.OVER_WRITE.value, + 'input_variable_selector': [DEFAULT_NODE_ID, input_variable.name], + }, + }, + ) + + variable_pool = VariablePool( + system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'}, + user_inputs={}, + environment_variables=[], + conversation_variables=[conversation_variable], + ) + variable_pool.add( + [DEFAULT_NODE_ID, input_variable.name], + input_variable, + ) + + with mock.patch('core.workflow.nodes.variable_assigner.update_conversation_variable') as mock_run: + node.run(variable_pool) + mock_run.assert_called_once() + + got = variable_pool.get(['conversation', conversation_variable.name]) + assert got is not None + assert got.value == 'the second value' + assert got.to_object() == 'the second value' + + +def test_append_variable_to_array(): + conversation_variable = ArrayStringVariable( + id=str(uuid4()), + name='test_conversation_variable', + value=['the first value'], + ) + + input_variable = StringVariable( + id=str(uuid4()), + name='test_string_variable', + value='the second value', + ) + + node = VariableAssignerNode( + tenant_id='tenant_id', + app_id='app_id', + workflow_id='workflow_id', + user_id='user_id', + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + config={ + 'id': 'node_id', + 'data': { + 'assigned_variable_selector': ['conversation', conversation_variable.name], + 'write_mode': WriteMode.APPEND.value, + 'input_variable_selector': [DEFAULT_NODE_ID, input_variable.name], + }, + }, + ) + + variable_pool = VariablePool( + system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'}, + user_inputs={}, + environment_variables=[], + conversation_variables=[conversation_variable], + ) + variable_pool.add( + [DEFAULT_NODE_ID, input_variable.name], + input_variable, + ) + + with mock.patch('core.workflow.nodes.variable_assigner.update_conversation_variable') as mock_run: + node.run(variable_pool) + mock_run.assert_called_once() + + got = variable_pool.get(['conversation', conversation_variable.name]) + assert got is not None + assert got.to_object() == ['the first value', 'the second value'] + + +def test_clear_array(): + conversation_variable = ArrayStringVariable( + id=str(uuid4()), + name='test_conversation_variable', + value=['the first value'], + ) + + node = VariableAssignerNode( + tenant_id='tenant_id', + app_id='app_id', + workflow_id='workflow_id', + user_id='user_id', + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + config={ + 'id': 'node_id', + 'data': { + 'assigned_variable_selector': ['conversation', conversation_variable.name], + 'write_mode': WriteMode.CLEAR.value, + 'input_variable_selector': [], + }, + }, + ) + + variable_pool = VariablePool( + system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'}, + user_inputs={}, + environment_variables=[], + conversation_variables=[conversation_variable], + ) + + node.run(variable_pool) + + got = variable_pool.get(['conversation', conversation_variable.name]) + assert got is not None + assert got.to_object() == [] diff --git a/api/tests/unit_tests/models/test_conversation_variable.py b/api/tests/unit_tests/models/test_conversation_variable.py new file mode 100644 index 00000000000000..9e16010d7ef5a4 --- /dev/null +++ b/api/tests/unit_tests/models/test_conversation_variable.py @@ -0,0 +1,25 @@ +from uuid import uuid4 + +from core.app.segments import SegmentType, factory +from models import ConversationVariable + + +def test_from_variable_and_to_variable(): + variable = factory.build_variable_from_mapping( + { + 'id': str(uuid4()), + 'name': 'name', + 'value_type': SegmentType.OBJECT, + 'value': { + 'key': { + 'key': 'value', + } + }, + } + ) + + conversation_variable = ConversationVariable.from_variable( + app_id='app_id', conversation_id='conversation_id', variable=variable + ) + + assert conversation_variable.to_variable() == variable diff --git a/api/tests/unit_tests/utils/position_helper/test_position_helper.py b/api/tests/unit_tests/utils/position_helper/test_position_helper.py index 22373199043d6c..eefe374df0762e 100644 --- a/api/tests/unit_tests/utils/position_helper/test_position_helper.py +++ b/api/tests/unit_tests/utils/position_helper/test_position_helper.py @@ -2,7 +2,7 @@ import pytest -from core.helper.position_helper import get_position_map +from core.helper.position_helper import get_position_map, sort_and_filter_position_map @pytest.fixture @@ -53,3 +53,47 @@ def test_position_helper_with_all_commented(prepare_empty_commented_positions_ya folder_path=prepare_empty_commented_positions_yaml, file_name='example_positions_all_commented.yaml') assert position_map == {} + + +def test_excluded_position_map(prepare_example_positions_yaml): + position_map = get_position_map( + folder_path=prepare_example_positions_yaml, + file_name='example_positions.yaml' + ) + pin_list = ['forth', 'first'] + include_list = [] + exclude_list = ['9999999999999'] + sorted_filtered_position_map = sort_and_filter_position_map( + original_position_map=position_map, + pin_list=pin_list, + include_list=include_list, + exclude_list=exclude_list + ) + assert sorted_filtered_position_map == { + 'forth': 0, + 'first': 1, + 'second': 2, + 'third': 3, + } + + +def test_included_position_map(prepare_example_positions_yaml): + position_map = get_position_map( + folder_path=prepare_example_positions_yaml, + file_name='example_positions.yaml' + ) + pin_list = ['second', 'first'] + include_list = ['first', 'second', 'third', 'forth'] + exclude_list = [] + sorted_filtered_position_map = sort_and_filter_position_map( + original_position_map=position_map, + pin_list=pin_list, + include_list=include_list, + exclude_list=exclude_list + ) + assert sorted_filtered_position_map == { + 'second': 0, + 'first': 1, + 'third': 2, + 'forth': 3, + } diff --git a/dev/pytest/pytest_vdb.sh b/dev/pytest/pytest_vdb.sh index c954c528fb2499..0b23200dc33b0e 100755 --- a/dev/pytest/pytest_vdb.sh +++ b/dev/pytest/pytest_vdb.sh @@ -7,4 +7,5 @@ pytest api/tests/integration_tests/vdb/chroma \ api/tests/integration_tests/vdb/pgvector \ api/tests/integration_tests/vdb/qdrant \ api/tests/integration_tests/vdb/weaviate \ + api/tests/integration_tests/vdb/elasticsearch \ api/tests/integration_tests/vdb/test_vector_store.py \ No newline at end of file diff --git a/dev/reformat b/dev/reformat index f50ccb04c44ed1..ad83e897d978bd 100755 --- a/dev/reformat +++ b/dev/reformat @@ -11,5 +11,8 @@ fi # run ruff linter ruff check --fix ./api +# run ruff formatter +ruff format ./api + # run dotenv-linter linter dotenv-linter ./api/.env.example ./web/.env.example diff --git a/docker-legacy/docker-compose.yaml b/docker-legacy/docker-compose.yaml index 807946f3fea820..aed2586053ceb0 100644 --- a/docker-legacy/docker-compose.yaml +++ b/docker-legacy/docker-compose.yaml @@ -2,7 +2,7 @@ version: '3' services: # API service api: - image: langgenius/dify-api:0.6.16 + image: langgenius/dify-api:0.7.0 restart: always environment: # Startup mode, 'api' starts the API server. @@ -169,6 +169,11 @@ services: CHROMA_DATABASE: default_database CHROMA_AUTH_PROVIDER: chromadb.auth.token_authn.TokenAuthClientProvider CHROMA_AUTH_CREDENTIALS: xxxxxx + # ElasticSearch Config + ELASTICSEARCH_HOST: 127.0.0.1 + ELASTICSEARCH_PORT: 9200 + ELASTICSEARCH_USERNAME: elastic + ELASTICSEARCH_PASSWORD: elastic # Mail configuration, support: resend, smtp MAIL_TYPE: '' # default send from email address, if not specified @@ -224,7 +229,7 @@ services: # worker service # The Celery worker for processing the queue. worker: - image: langgenius/dify-api:0.6.16 + image: langgenius/dify-api:0.7.0 restart: always environment: CONSOLE_WEB_URL: '' @@ -371,6 +376,11 @@ services: CHROMA_DATABASE: default_database CHROMA_AUTH_PROVIDER: chromadb.auth.token_authn.TokenAuthClientProvider CHROMA_AUTH_CREDENTIALS: xxxxxx + # ElasticSearch Config + ELASTICSEARCH_HOST: 127.0.0.1 + ELASTICSEARCH_PORT: 9200 + ELASTICSEARCH_USERNAME: elastic + ELASTICSEARCH_PASSWORD: elastic # Notion import configuration, support public and internal NOTION_INTEGRATION_TYPE: public NOTION_CLIENT_SECRET: you-client-secret @@ -390,7 +400,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:0.6.16 + image: langgenius/dify-web:0.7.0 restart: always environment: # The base URL of console application api server, refers to the Console base URL of WEB service if console domain is diff --git a/docker/.env.example b/docker/.env.example index 6fee8b4b3ca5b3..5898d3e62a8d16 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -695,3 +695,22 @@ COMPOSE_PROFILES=${VECTOR_STORE:-weaviate} # ------------------------------ EXPOSE_NGINX_PORT=80 EXPOSE_NGINX_SSL_PORT=443 + +# ---------------------------------------------------------------------------- +# ModelProvider & Tool Position Configuration +# Used to specify the model providers and tools that can be used in the app. +# ---------------------------------------------------------------------------- + +# Pin, include, and exclude tools +# Use comma-separated values with no spaces between items. +# Example: POSITION_TOOL_PINS=bing,google +POSITION_TOOL_PINS= +POSITION_TOOL_INCLUDES= +POSITION_TOOL_EXCLUDES= + +# Pin, include, and exclude model providers +# Use comma-separated values with no spaces between items. +# Example: POSITION_PROVIDER_PINS=openai,openllm +POSITION_PROVIDER_PINS= +POSITION_PROVIDER_INCLUDES= +POSITION_PROVIDER_EXCLUDES= \ No newline at end of file diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index d0d96018d41c62..f3151bbc2ad2af 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -125,6 +125,10 @@ x-shared-env: &shared-api-worker-env CHROMA_DATABASE: ${CHROMA_DATABASE:-default_database} CHROMA_AUTH_PROVIDER: ${CHROMA_AUTH_PROVIDER:-chromadb.auth.token_authn.TokenAuthClientProvider} CHROMA_AUTH_CREDENTIALS: ${CHROMA_AUTH_CREDENTIALS:-} + ELASTICSEARCH_HOST: ${ELASTICSEARCH_HOST:-127.0.0.1} + ELASTICSEARCH_PORT: ${ELASTICSEARCH_PORT:-9200} + ELASTICSEARCH_USERNAME: ${ELASTICSEARCH_USERNAME:-elastic} + ELASTICSEARCH_PASSWORD: ${ELASTICSEARCH_PASSWORD:-elastic} # AnalyticDB configuration ANALYTICDB_KEY_ID: ${ANALYTICDB_KEY_ID:-} ANALYTICDB_KEY_SECRET: ${ANALYTICDB_KEY_SECRET:-} @@ -187,7 +191,7 @@ x-shared-env: &shared-api-worker-env services: # API service api: - image: langgenius/dify-api:0.6.16 + image: langgenius/dify-api:0.7.0 restart: always environment: # Use the shared environment variables. @@ -207,7 +211,7 @@ services: # worker service # The Celery worker for processing the queue. worker: - image: langgenius/dify-api:0.6.16 + image: langgenius/dify-api:0.7.0 restart: always environment: # Use the shared environment variables. @@ -226,12 +230,13 @@ services: # Frontend web application. web: - image: langgenius/dify-web:0.6.16 + image: langgenius/dify-web:0.7.0 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} APP_API_URL: ${APP_API_URL:-} SENTRY_DSN: ${WEB_SENTRY_DSN:-} + NEXT_TELEMETRY_DISABLED: ${NEXT_TELEMETRY_DISABLED:-0} # The postgres database. db: @@ -582,7 +587,7 @@ services: # MyScale vector database myscale: container_name: myscale - image: myscale/myscaledb:1.6 + image: myscale/myscaledb:1.6.4 profiles: - myscale restart: always @@ -594,6 +599,27 @@ services: ports: - "${MYSCALE_PORT:-8123}:${MYSCALE_PORT:-8123}" + elasticsearch: + image: docker.elastic.co/elasticsearch/elasticsearch:8.14.3 + container_name: elasticsearch + profiles: + - elasticsearch + restart: always + environment: + - "ELASTIC_PASSWORD=${ELASTICSEARCH_USERNAME:-elastic}" + - "cluster.name=dify-es-cluster" + - "node.name=dify-es0" + - "discovery.type=single-node" + - "xpack.security.http.ssl.enabled=false" + - "xpack.license.self_generated.type=trial" + ports: + - "${ELASTICSEARCH_PORT:-9200}:${ELASTICSEARCH_PORT:-9200}" + healthcheck: + test: ["CMD", "curl", "-s", "http://localhost:9200/_cluster/health?pretty"] + interval: 30s + timeout: 10s + retries: 50 + # unstructured . # (if used, you need to set ETL_TYPE to Unstructured in the api & worker service.) unstructured: diff --git a/web/.env.example b/web/.env.example index 653913033d8d05..439092c20e0a0e 100644 --- a/web/.env.example +++ b/web/.env.example @@ -13,3 +13,6 @@ NEXT_PUBLIC_PUBLIC_API_PREFIX=http://localhost:5001/api # SENTRY NEXT_PUBLIC_SENTRY_DSN= + +# Disable Next.js Telemetry (https://nextjs.org/telemetry) +NEXT_TELEMETRY_DISABLED=1 \ No newline at end of file diff --git a/web/Dockerfile b/web/Dockerfile index 56957f0927010f..48bdb2301ad206 100644 --- a/web/Dockerfile +++ b/web/Dockerfile @@ -39,6 +39,7 @@ ENV DEPLOY_ENV=PRODUCTION ENV CONSOLE_API_URL=http://127.0.0.1:5001 ENV APP_API_URL=http://127.0.0.1:5001 ENV PORT=3000 +ENV NEXT_TELEMETRY_DISABLED=1 # set timezone ENV TZ=UTC diff --git a/web/app/(commonLayout)/datasets/template/template.en.mdx b/web/app/(commonLayout)/datasets/template/template.en.mdx index 36395d391de1b3..44c5964d77736b 100644 --- a/web/app/(commonLayout)/datasets/template/template.en.mdx +++ b/web/app/(commonLayout)/datasets/template/template.en.mdx @@ -922,6 +922,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from Knowledge ID + + Document ID + Document Segment ID @@ -965,6 +968,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from Knowledge ID + + Document ID + Document Segment ID diff --git a/web/app/(commonLayout)/datasets/template/template.zh.mdx b/web/app/(commonLayout)/datasets/template/template.zh.mdx index a624c0594feffb..9f79b0f900287d 100644 --- a/web/app/(commonLayout)/datasets/template/template.zh.mdx +++ b/web/app/(commonLayout)/datasets/template/template.zh.mdx @@ -922,6 +922,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from 知识库 ID + + 文档 ID + 文档分段ID @@ -965,6 +968,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from 知识库 ID + + 文档 ID + 文档分段ID diff --git a/web/app/components/base/badge.tsx b/web/app/components/base/badge.tsx index 3e5414fa2cba81..c3300a1e67e590 100644 --- a/web/app/components/base/badge.tsx +++ b/web/app/components/base/badge.tsx @@ -4,16 +4,19 @@ import cn from '@/utils/classnames' type BadgeProps = { className?: string text: string + uppercase?: boolean } const Badge = ({ className, text, + uppercase = true, }: BadgeProps) => { return (
diff --git a/web/app/components/base/icons/assets/vender/line/others/bubble-x.svg b/web/app/components/base/icons/assets/vender/line/others/bubble-x.svg new file mode 100644 index 00000000000000..6e4df5b9b843bb --- /dev/null +++ b/web/app/components/base/icons/assets/vender/line/others/bubble-x.svg @@ -0,0 +1,8 @@ + + + + + + + + diff --git a/web/app/components/base/icons/assets/vender/line/others/long-arrow-left.svg b/web/app/components/base/icons/assets/vender/line/others/long-arrow-left.svg new file mode 100644 index 00000000000000..7320664db67618 --- /dev/null +++ b/web/app/components/base/icons/assets/vender/line/others/long-arrow-left.svg @@ -0,0 +1,3 @@ + + + diff --git a/web/app/components/base/icons/assets/vender/line/others/long-arrow-right.svg b/web/app/components/base/icons/assets/vender/line/others/long-arrow-right.svg new file mode 100644 index 00000000000000..733785a276f88c --- /dev/null +++ b/web/app/components/base/icons/assets/vender/line/others/long-arrow-right.svg @@ -0,0 +1,3 @@ + + + diff --git a/web/app/components/base/icons/assets/vender/workflow/assigner.svg b/web/app/components/base/icons/assets/vender/workflow/assigner.svg new file mode 100644 index 00000000000000..b37fbce52672ed --- /dev/null +++ b/web/app/components/base/icons/assets/vender/workflow/assigner.svg @@ -0,0 +1,9 @@ + + + + + + + + + diff --git a/web/app/components/base/icons/src/vender/line/others/BubbleX.json b/web/app/components/base/icons/src/vender/line/others/BubbleX.json new file mode 100644 index 00000000000000..0cb5702c1f606b --- /dev/null +++ b/web/app/components/base/icons/src/vender/line/others/BubbleX.json @@ -0,0 +1,57 @@ +{ + "icon": { + "type": "element", + "isRootNode": true, + "name": "svg", + "attributes": { + "width": "16", + "height": "16", + "viewBox": "0 0 16 16", + "fill": "none", + "xmlns": "http://www.w3.org/2000/svg" + }, + "children": [ + { + "type": "element", + "name": "g", + "attributes": { + "id": "Icon L" + }, + "children": [ + { + "type": "element", + "name": "g", + "attributes": { + "id": "Vector" + }, + "children": [ + { + "type": "element", + "name": "path", + "attributes": { + "fill-rule": "evenodd", + "clip-rule": "evenodd", + "d": "M3.33463 3.33333C2.96643 3.33333 2.66796 3.63181 2.66796 4V10.6667C2.66796 11.0349 2.96643 11.3333 3.33463 11.3333H4.66796C5.03615 11.3333 5.33463 11.6318 5.33463 12V12.8225L7.65833 11.4283C7.76194 11.3662 7.8805 11.3333 8.00132 11.3333H12.0013C12.3695 11.3333 12.668 11.0349 12.668 10.6667C12.668 10.2985 12.9665 10 13.3347 10C13.7028 10 14.0013 10.2985 14.0013 10.6667C14.0013 11.7713 13.1058 12.6667 12.0013 12.6667H8.18598L5.01095 14.5717C4.805 14.6952 4.5485 14.6985 4.33949 14.5801C4.13049 14.4618 4.00129 14.2402 4.00129 14V12.6667H3.33463C2.23006 12.6667 1.33463 11.7713 1.33463 10.6667V4C1.33463 2.89543 2.23006 2 3.33463 2H6.66798C7.03617 2 7.33464 2.29848 7.33464 2.66667C7.33464 3.03486 7.03617 3.33333 6.66798 3.33333H3.33463Z", + "fill": "currentColor" + }, + "children": [] + }, + { + "type": "element", + "name": "path", + "attributes": { + "fill-rule": "evenodd", + "clip-rule": "evenodd", + "d": "M8.74113 2.66667C8.74113 2.29848 9.03961 2 9.4078 2H10.331C10.9721 2 11.5177 2.43571 11.6859 3.04075L11.933 3.93004L12.8986 2.77189C13.3045 2.28508 13.9018 2 14.536 2H14.5954C14.9636 2 15.2621 2.29848 15.2621 2.66667C15.2621 3.03486 14.9636 3.33333 14.5954 3.33333H14.536C14.3048 3.33333 14.08 3.43702 13.9227 3.6257L12.367 5.49165L12.8609 7.2689C12.8746 7.31803 12.9105 7.33333 12.9312 7.33333H13.8543C14.2225 7.33333 14.521 7.63181 14.521 8C14.521 8.36819 14.2225 8.66667 13.8543 8.66667H12.9312C12.29 8.66667 11.7444 8.23095 11.5763 7.62591L11.3291 6.73654L10.3634 7.89478C9.95758 8.38159 9.36022 8.66667 8.72604 8.66667H8.66666C8.29847 8.66667 7.99999 8.36819 7.99999 8C7.99999 7.63181 8.29847 7.33333 8.66666 7.33333H8.72604C8.95723 7.33333 9.18204 7.22965 9.33935 7.04096L10.8951 5.17493L10.4012 3.39777C10.3876 3.34863 10.3516 3.33333 10.331 3.33333H9.4078C9.03961 3.33333 8.74113 3.03486 8.74113 2.66667Z", + "fill": "currentColor" + }, + "children": [] + } + ] + } + ] + } + ] + }, + "name": "BubbleX" +} \ No newline at end of file diff --git a/web/app/components/base/icons/src/vender/line/others/BubbleX.tsx b/web/app/components/base/icons/src/vender/line/others/BubbleX.tsx new file mode 100644 index 00000000000000..7d78bd33c7a92a --- /dev/null +++ b/web/app/components/base/icons/src/vender/line/others/BubbleX.tsx @@ -0,0 +1,16 @@ +// GENERATE BY script +// DON NOT EDIT IT MANUALLY + +import * as React from 'react' +import data from './BubbleX.json' +import IconBase from '@/app/components/base/icons/IconBase' +import type { IconBaseProps, IconData } from '@/app/components/base/icons/IconBase' + +const Icon = React.forwardRef, Omit>(( + props, + ref, +) => ) + +Icon.displayName = 'BubbleX' + +export default Icon diff --git a/web/app/components/base/icons/src/vender/line/others/LongArrowLeft.json b/web/app/components/base/icons/src/vender/line/others/LongArrowLeft.json new file mode 100644 index 00000000000000..d2646b10909f3a --- /dev/null +++ b/web/app/components/base/icons/src/vender/line/others/LongArrowLeft.json @@ -0,0 +1,27 @@ +{ + "icon": { + "type": "element", + "isRootNode": true, + "name": "svg", + "attributes": { + "width": "21", + "height": "8", + "viewBox": "0 0 21 8", + "fill": "none", + "xmlns": "http://www.w3.org/2000/svg" + }, + "children": [ + { + "type": "element", + "name": "path", + "attributes": { + "d": "M0.646446 3.64645C0.451185 3.84171 0.451185 4.15829 0.646446 4.35355L3.82843 7.53553C4.02369 7.7308 4.34027 7.7308 4.53553 7.53553C4.7308 7.34027 4.7308 7.02369 4.53553 6.82843L1.70711 4L4.53553 1.17157C4.7308 0.976311 4.7308 0.659728 4.53553 0.464466C4.34027 0.269204 4.02369 0.269204 3.82843 0.464466L0.646446 3.64645ZM21 3.5L1 3.5V4.5L21 4.5V3.5Z", + "fill": "currentColor", + "fill-opacity": "0.3" + }, + "children": [] + } + ] + }, + "name": "LongArrowLeft" +} \ No newline at end of file diff --git a/web/app/components/base/icons/src/vender/line/others/LongArrowLeft.tsx b/web/app/components/base/icons/src/vender/line/others/LongArrowLeft.tsx new file mode 100644 index 00000000000000..930ced5360d798 --- /dev/null +++ b/web/app/components/base/icons/src/vender/line/others/LongArrowLeft.tsx @@ -0,0 +1,16 @@ +// GENERATE BY script +// DON NOT EDIT IT MANUALLY + +import * as React from 'react' +import data from './LongArrowLeft.json' +import IconBase from '@/app/components/base/icons/IconBase' +import type { IconBaseProps, IconData } from '@/app/components/base/icons/IconBase' + +const Icon = React.forwardRef, Omit>(( + props, + ref, +) => ) + +Icon.displayName = 'LongArrowLeft' + +export default Icon diff --git a/web/app/components/base/icons/src/vender/line/others/LongArrowRight.json b/web/app/components/base/icons/src/vender/line/others/LongArrowRight.json new file mode 100644 index 00000000000000..7582b81568b3fe --- /dev/null +++ b/web/app/components/base/icons/src/vender/line/others/LongArrowRight.json @@ -0,0 +1,27 @@ +{ + "icon": { + "type": "element", + "isRootNode": true, + "name": "svg", + "attributes": { + "width": "26", + "height": "8", + "viewBox": "0 0 26 8", + "fill": "none", + "xmlns": "http://www.w3.org/2000/svg" + }, + "children": [ + { + "type": "element", + "name": "path", + "attributes": { + "d": "M25.3536 4.35355C25.5488 4.15829 25.5488 3.84171 25.3536 3.64644L22.1716 0.464465C21.9763 0.269202 21.6597 0.269202 21.4645 0.464465C21.2692 0.659727 21.2692 0.976309 21.4645 1.17157L24.2929 4L21.4645 6.82843C21.2692 7.02369 21.2692 7.34027 21.4645 7.53553C21.6597 7.73079 21.9763 7.73079 22.1716 7.53553L25.3536 4.35355ZM3.59058e-08 4.5L25 4.5L25 3.5L-3.59058e-08 3.5L3.59058e-08 4.5Z", + "fill": "currentColor", + "fill-opacity": "0.3" + }, + "children": [] + } + ] + }, + "name": "LongArrowRight" +} \ No newline at end of file diff --git a/web/app/components/base/icons/src/vender/line/others/LongArrowRight.tsx b/web/app/components/base/icons/src/vender/line/others/LongArrowRight.tsx new file mode 100644 index 00000000000000..3c9084cada9c4f --- /dev/null +++ b/web/app/components/base/icons/src/vender/line/others/LongArrowRight.tsx @@ -0,0 +1,16 @@ +// GENERATE BY script +// DON NOT EDIT IT MANUALLY + +import * as React from 'react' +import data from './LongArrowRight.json' +import IconBase from '@/app/components/base/icons/IconBase' +import type { IconBaseProps, IconData } from '@/app/components/base/icons/IconBase' + +const Icon = React.forwardRef, Omit>(( + props, + ref, +) => ) + +Icon.displayName = 'LongArrowRight' + +export default Icon diff --git a/web/app/components/base/icons/src/vender/line/others/index.ts b/web/app/components/base/icons/src/vender/line/others/index.ts index 282a39499f3651..d54f31e4a9ce7a 100644 --- a/web/app/components/base/icons/src/vender/line/others/index.ts +++ b/web/app/components/base/icons/src/vender/line/others/index.ts @@ -1,8 +1,11 @@ export { default as Apps02 } from './Apps02' +export { default as BubbleX } from './BubbleX' export { default as Colors } from './Colors' export { default as DragHandle } from './DragHandle' export { default as Env } from './Env' export { default as Exchange02 } from './Exchange02' export { default as FileCode } from './FileCode' export { default as Icon3Dots } from './Icon3Dots' +export { default as LongArrowLeft } from './LongArrowLeft' +export { default as LongArrowRight } from './LongArrowRight' export { default as Tools } from './Tools' diff --git a/web/app/components/base/icons/src/vender/workflow/Assigner.json b/web/app/components/base/icons/src/vender/workflow/Assigner.json new file mode 100644 index 00000000000000..7106e5ad439179 --- /dev/null +++ b/web/app/components/base/icons/src/vender/workflow/Assigner.json @@ -0,0 +1,68 @@ +{ + "icon": { + "type": "element", + "isRootNode": true, + "name": "svg", + "attributes": { + "width": "16", + "height": "16", + "viewBox": "0 0 16 16", + "fill": "none", + "xmlns": "http://www.w3.org/2000/svg" + }, + "children": [ + { + "type": "element", + "name": "g", + "attributes": { + "id": "variable assigner" + }, + "children": [ + { + "type": "element", + "name": "g", + "attributes": { + "id": "Vector" + }, + "children": [ + { + "type": "element", + "name": "path", + "attributes": { + "fill-rule": "evenodd", + "clip-rule": "evenodd", + "d": "M1.71438 4.42875C1.71438 3.22516 2.68954 2.25 3.89313 2.25C4.30734 2.25 4.64313 2.58579 4.64313 3C4.64313 3.41421 4.30734 3.75 3.89313 3.75C3.51796 3.75 3.21438 4.05359 3.21438 4.42875V7.28563C3.21438 7.48454 3.13536 7.6753 2.9947 7.81596L2.81066 8L2.9947 8.18404C3.13536 8.3247 3.21438 8.51546 3.21438 8.71437V11.5713C3.21438 11.9464 3.51796 12.25 3.89313 12.25C4.30734 12.25 4.64313 12.5858 4.64313 13C4.64313 13.4142 4.30734 13.75 3.89313 13.75C2.68954 13.75 1.71438 12.7748 1.71438 11.5713V9.02503L1.21967 8.53033C1.07902 8.38968 1 8.19891 1 8C1 7.80109 1.07902 7.61032 1.21967 7.46967L1.71438 6.97497V4.42875ZM11.3568 3C11.3568 2.58579 11.6925 2.25 12.1068 2.25C13.3103 2.25 14.2855 3.22516 14.2855 4.42875V6.97497L14.7802 7.46967C14.9209 7.61032 14.9999 7.80109 14.9999 8C14.9999 8.19891 14.9209 8.38968 14.7802 8.53033L14.2855 9.02503V11.5713C14.2855 12.7751 13.3095 13.75 12.1068 13.75C11.6925 13.75 11.3568 13.4142 11.3568 13C11.3568 12.5858 11.6925 12.25 12.1068 12.25C12.4815 12.25 12.7855 11.9462 12.7855 11.5713V8.71437C12.7855 8.51546 12.8645 8.3247 13.0052 8.18404L13.1892 8L13.0052 7.81596C12.8645 7.6753 12.7855 7.48454 12.7855 7.28563V4.42875C12.7855 4.05359 12.4819 3.75 12.1068 3.75C11.6925 3.75 11.3568 3.41421 11.3568 3Z", + "fill": "currentColor" + }, + "children": [] + }, + { + "type": "element", + "name": "path", + "attributes": { + "fill-rule": "evenodd", + "clip-rule": "evenodd", + "d": "M5.25 6C5.25 5.58579 5.58579 5.25 6 5.25H10C10.4142 5.25 10.75 5.58579 10.75 6C10.75 6.41421 10.4142 6.75 10 6.75H6C5.58579 6.75 5.25 6.41421 5.25 6Z", + "fill": "currentColor" + }, + "children": [] + }, + { + "type": "element", + "name": "path", + "attributes": { + "fill-rule": "evenodd", + "clip-rule": "evenodd", + "d": "M5.25 10C5.25 9.58579 5.58579 9.25 6 9.25H10C10.4142 9.25 10.75 9.58579 10.75 10C10.75 10.4142 10.4142 10.75 10 10.75H6C5.58579 10.75 5.25 10.4142 5.25 10Z", + "fill": "currentColor" + }, + "children": [] + } + ] + } + ] + } + ] + }, + "name": "Assigner" +} \ No newline at end of file diff --git a/web/app/components/base/icons/src/vender/workflow/Assigner.tsx b/web/app/components/base/icons/src/vender/workflow/Assigner.tsx new file mode 100644 index 00000000000000..1cb7d692dd91a0 --- /dev/null +++ b/web/app/components/base/icons/src/vender/workflow/Assigner.tsx @@ -0,0 +1,16 @@ +// GENERATE BY script +// DON NOT EDIT IT MANUALLY + +import * as React from 'react' +import data from './Assigner.json' +import IconBase from '@/app/components/base/icons/IconBase' +import type { IconBaseProps, IconData } from '@/app/components/base/icons/IconBase' + +const Icon = React.forwardRef, Omit>(( + props, + ref, +) => ) + +Icon.displayName = 'Assigner' + +export default Icon diff --git a/web/app/components/base/icons/src/vender/workflow/index.ts b/web/app/components/base/icons/src/vender/workflow/index.ts index 94e20ae6a9e1f3..a2563a6a36584f 100644 --- a/web/app/components/base/icons/src/vender/workflow/index.ts +++ b/web/app/components/base/icons/src/vender/workflow/index.ts @@ -1,4 +1,5 @@ export { default as Answer } from './Answer' +export { default as Assigner } from './Assigner' export { default as Code } from './Code' export { default as End } from './End' export { default as Home } from './Home' diff --git a/web/app/components/base/input/index.tsx b/web/app/components/base/input/index.tsx index 0fb34de2e8b369..5ab82494463698 100644 --- a/web/app/components/base/input/index.tsx +++ b/web/app/components/base/input/index.tsx @@ -2,7 +2,7 @@ import type { SVGProps } from 'react' import React, { useState } from 'react' import { useTranslation } from 'react-i18next' -import s from './style.module.css' +import cn from 'classnames' type InputProps = { placeholder?: string @@ -27,10 +27,10 @@ const Input = ({ value, defaultValue, onChange, className = '', wrapperClassName const { t } = useTranslation() return (
- {showPrefix && {prefixIcon ?? }} + {showPrefix && {prefixIcon ?? }} { diff --git a/web/app/components/base/input/style.module.css b/web/app/components/base/input/style.module.css deleted file mode 100644 index 5f2782777d6d31..00000000000000 --- a/web/app/components/base/input/style.module.css +++ /dev/null @@ -1,7 +0,0 @@ -.input { - @apply inline-flex h-7 w-full py-1 px-2 rounded-lg text-xs leading-normal; - @apply bg-gray-100 caret-primary-600 hover:bg-gray-100 focus:ring-1 focus:ring-inset focus:ring-gray-200 focus-visible:outline-none focus:bg-white placeholder:text-gray-400; -} -.prefix { - @apply whitespace-nowrap absolute left-2 self-center -} diff --git a/web/app/components/base/markdown.tsx b/web/app/components/base/markdown.tsx index 3adb4d75e1990c..af4b13ff70deaf 100644 --- a/web/app/components/base/markdown.tsx +++ b/web/app/components/base/markdown.tsx @@ -1,4 +1,5 @@ import ReactMarkdown from 'react-markdown' +import ReactEcharts from 'echarts-for-react' import 'katex/dist/katex.min.css' import RemarkMath from 'remark-math' import RemarkBreaks from 'remark-breaks' @@ -13,6 +14,7 @@ import cn from '@/utils/classnames' import CopyBtn from '@/app/components/base/copy-btn' import SVGBtn from '@/app/components/base/svg' import Flowchart from '@/app/components/base/mermaid' +import ImageGallery from '@/app/components/base/image-gallery' // Available language https://github.com/react-syntax-highlighter/react-syntax-highlighter/blob/master/AVAILABLE_LANGUAGES_HLJS.MD const capitalizationLanguageNameMap: Record = { @@ -30,6 +32,7 @@ const capitalizationLanguageNameMap: Record = { mermaid: 'Mermaid', markdown: 'MarkDown', makefile: 'MakeFile', + echarts: 'ECharts', } const getCorrectCapitalizationLanguageName = (language: string) => { if (!language) @@ -44,9 +47,9 @@ const getCorrectCapitalizationLanguageName = (language: string) => { const preprocessLaTeX = (content: string) => { if (typeof content !== 'string') return content - return content.replace(/\\\[(.*?)\\\]/gs, (_, equation) => `$$${equation}$$`) - .replace(/\\\((.*?)\\\)/gs, (_, equation) => `$$${equation}$$`) - .replace(/(^|[^\\])\$(.+?)\$/gs, (_, prefix, equation) => `${prefix}$${equation}$`) + return content.replace(/\\\[(.*?)\\\]/g, (_, equation) => `$$${equation}$$`) + .replace(/\\\((.*?)\\\)/g, (_, equation) => `$$${equation}$$`) + .replace(/(^|[^\\])\$(.+?)\$/g, (_, prefix, equation) => `${prefix}$${equation}$`) } export function PreCode(props: { children: any }) { @@ -56,12 +59,6 @@ export function PreCode(props: { children: any }) {
        {
-          if (ref.current) {
-            const code = ref.current.innerText
-            // copyToClipboard(code);
-          }
-        }}
       >
       {props.children}
     
@@ -107,6 +104,14 @@ const CodeBlock: CodeComponent = memo(({ inline, className, children, ...props } const match = /language-(\w+)/.exec(className || '') const language = match?.[1] const languageShowName = getCorrectCapitalizationLanguageName(language || '') + let chartData = JSON.parse(String('{"title":{"text":"Something went wrong."}}').replace(/\n$/, '')) + if (language === 'echarts') { + try { + chartData = JSON.parse(String(children).replace(/\n$/, '')) + } + catch (error) { + } + } // Use `useMemo` to ensure that `SyntaxHighlighter` only re-renders when necessary return useMemo(() => { @@ -136,19 +141,25 @@ const CodeBlock: CodeComponent = memo(({ inline, className, children, ...props }
{(language === 'mermaid' && isSVG) ? () - : ( - {String(children).replace(/\n$/, '')} - )} + : ( + (language === 'echarts') + ? (
+
) + : ( + {String(children).replace(/\n$/, '')} + ))}
) : ( @@ -156,7 +167,7 @@ const CodeBlock: CodeComponent = memo(({ inline, className, children, ...props } {children} ) - }, [children, className, inline, isSVG, language, languageShowName, match, props]) + }, [chartData, children, className, inline, isSVG, language, languageShowName, match, props]) }) CodeBlock.displayName = 'CodeBlock' @@ -172,17 +183,9 @@ export function Markdown(props: { content: string; className?: string }) { ]} components={{ code: CodeBlock, - img({ src, alt, ...props }) { + img({ src }) { return ( - // eslint-disable-next-line @next/next/no-img-element - {alt} + ) }, p: (paragraph) => { @@ -192,14 +195,7 @@ export function Markdown(props: { content: string; className?: string }) { return ( <> - {/* eslint-disable-next-line @next/next/no-img-element */} - {image.properties.alt} +

{paragraph.children.slice(1)}

) diff --git a/web/app/components/base/prompt-editor/index.tsx b/web/app/components/base/prompt-editor/index.tsx index da70d04ac1b105..deae6833cd69f0 100644 --- a/web/app/components/base/prompt-editor/index.tsx +++ b/web/app/components/base/prompt-editor/index.tsx @@ -144,7 +144,7 @@ const PromptEditor: FC = ({ return ( -
+
} placeholder={} diff --git a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/component.tsx b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/component.tsx index e149f5b75a198a..39193fc31d6ae8 100644 --- a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/component.tsx +++ b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/component.tsx @@ -21,10 +21,10 @@ import { } from './index' import cn from '@/utils/classnames' import { Variable02 } from '@/app/components/base/icons/src/vender/solid/development' -import { Env } from '@/app/components/base/icons/src/vender/line/others' +import { BubbleX, Env } from '@/app/components/base/icons/src/vender/line/others' import { VarBlockIcon } from '@/app/components/workflow/block-icon' import { Line3 } from '@/app/components/base/icons/src/public/common' -import { isENV, isSystemVar } from '@/app/components/workflow/nodes/_base/components/variable/utils' +import { isConversationVar, isENV, isSystemVar } from '@/app/components/workflow/nodes/_base/components/variable/utils' import TooltipPlus from '@/app/components/base/tooltip-plus' type WorkflowVariableBlockComponentProps = { @@ -52,6 +52,7 @@ const WorkflowVariableBlockComponent = ({ const [localWorkflowNodesMap, setLocalWorkflowNodesMap] = useState(workflowNodesMap) const node = localWorkflowNodesMap![variables[0]] const isEnv = isENV(variables) + const isChatVar = isConversationVar(variables) useEffect(() => { if (!editor.hasNodes([WorkflowVariableBlockNode])) @@ -75,11 +76,11 @@ const WorkflowVariableBlockComponent = ({ className={cn( 'mx-0.5 relative group/wrap flex items-center h-[18px] pl-0.5 pr-[3px] rounded-[5px] border select-none', isSelected ? ' border-[#84ADFF] bg-[#F5F8FF]' : ' border-black/5 bg-white', - !node && !isEnv && '!border-[#F04438] !bg-[#FEF3F2]', + !node && !isEnv && !isChatVar && '!border-[#F04438] !bg-[#FEF3F2]', )} ref={ref} > - {!isEnv && ( + {!isEnv && !isChatVar && (
{ node?.type && ( @@ -97,11 +98,12 @@ const WorkflowVariableBlockComponent = ({
)}
- {!isEnv && } + {!isEnv && !isChatVar && } {isEnv && } -
{varName}
+ {isChatVar && } +
{varName}
{ - !node && !isEnv && ( + !node && !isEnv && !isChatVar && ( ) } @@ -109,7 +111,7 @@ const WorkflowVariableBlockComponent = ({
) - if (!node && !isEnv) { + if (!node && !isEnv && !isChatVar) { return ( {Item} diff --git a/web/app/components/header/account-setting/model-provider-page/model-icon/index.tsx b/web/app/components/header/account-setting/model-provider-page/model-icon/index.tsx index 347572c755ae1a..a22ec16c252288 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-icon/index.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-icon/index.tsx @@ -19,7 +19,7 @@ const ModelIcon: FC = ({ }) => { const language = useLanguage() - if (provider?.provider === 'openai' && modelName?.startsWith('gpt-4')) + if (provider?.provider === 'openai' && (modelName?.startsWith('gpt-4') || modelName?.includes('4o'))) return if (provider?.icon_small) { diff --git a/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/parameter-item.tsx b/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/parameter-item.tsx index 57ea4bdd118fed..eced2a8082bb86 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/parameter-item.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/parameter-item.tsx @@ -100,7 +100,7 @@ const ParameterItem: FC = ({ handleInputChange(v === 1) } - const handleStringInputChange = (e: React.ChangeEvent) => { + const handleStringInputChange = (e: React.ChangeEvent) => { handleInputChange(e.target.value) } @@ -190,6 +190,16 @@ const ParameterItem: FC = ({ ) } + if (parameterRule.type === 'text') { + return ( +