diff --git a/.gitignore b/.gitignore index 29374d13dc..27cf8a4ba3 100644 --- a/.gitignore +++ b/.gitignore @@ -175,6 +175,8 @@ docker/volumes/pgvector/data/* docker/volumes/pgvecto_rs/data/* docker/nginx/conf.d/default.conf +docker/nginx/ssl/* +!docker/nginx/ssl/.gitkeep docker/middleware.env sdks/python-client/build diff --git a/LICENSE b/LICENSE index 06b0fa1d12..d7b8373839 100644 --- a/LICENSE +++ b/LICENSE @@ -6,8 +6,9 @@ Dify is licensed under the Apache License 2.0, with the following additional con a. Multi-tenant service: Unless explicitly authorized by Dify in writing, you may not use the Dify source code to operate a multi-tenant environment. - Tenant Definition: Within the context of Dify, one tenant corresponds to one workspace. The workspace provides a separated area for each tenant's data and configurations. - -b. LOGO and copyright information: In the process of using Dify's frontend components, you may not remove or modify the LOGO or copyright information in the Dify console or applications. This restriction is inapplicable to uses of Dify that do not involve its frontend components. + +b. LOGO and copyright information: In the process of using Dify's frontend, you may not remove or modify the LOGO or copyright information in the Dify console or applications. This restriction is inapplicable to uses of Dify that do not involve its frontend. + - Frontend Definition: For the purposes of this license, the "frontend" of Dify includes all components located in the `web/` directory when running Dify from the raw source code, or the "web" image when running Dify with Docker. Please contact business@dify.ai by email to inquire about licensing matters. diff --git a/api/.env.example b/api/.env.example index b78f6c612e..f2b8a19dda 100644 --- a/api/.env.example +++ b/api/.env.example @@ -42,7 +42,7 @@ DB_DATABASE=dify # Storage configuration # use for store upload files, private keys... 
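Before the storage-type list below, it is worth noting how these `.env` values reach the application: the API's config classes (several appear later in this diff, e.g. `StorageConfig` and `SecurityConfig`) are pydantic-settings models, so each variable becomes a typed field. A minimal sketch of that pattern — the class name and the explicit `env_file` wiring here are illustrative assumptions, not Dify's exact setup:

```python
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict


class ExampleStorageConfig(BaseSettings):
    # Illustrative stand-in for the BaseSettings classes under api/configs/.
    model_config = SettingsConfigDict(env_file=".env", extra="ignore")

    STORAGE_TYPE: str = Field(
        description="storage backend, e.g. 'local' or 's3'",
        default="local",
    )
    STORAGE_LOCAL_PATH: str = Field(
        description="directory used when STORAGE_TYPE is 'local'",
        default="storage",
    )


config = ExampleStorageConfig()
print(config.STORAGE_TYPE)  # "local" unless overridden by .env or the process environment
```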
-# storage type: local, s3, azure-blob, google-storage, tencent-cos, huawei-obs, volcengine-tos, baidu-obs, supabase +# storage type: local, s3, aliyun-oss, azure-blob, baidu-obs, google-storage, huawei-obs, oci-storage, tencent-cos, volcengine-tos, supabase STORAGE_TYPE=local STORAGE_LOCAL_PATH=storage S3_USE_AWS_MANAGED_IAM=false @@ -233,6 +233,8 @@ VIKINGDB_SOCKET_TIMEOUT=30 UPLOAD_FILE_SIZE_LIMIT=15 UPLOAD_FILE_BATCH_LIMIT=5 UPLOAD_IMAGE_FILE_SIZE_LIMIT=10 +UPLOAD_VIDEO_FILE_SIZE_LIMIT=100 +UPLOAD_AUDIO_FILE_SIZE_LIMIT=50 # Model Configuration MULTIMODAL_SEND_IMAGE_FORMAT=base64 @@ -310,6 +312,7 @@ INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=1000 WORKFLOW_MAX_EXECUTION_STEPS=500 WORKFLOW_MAX_EXECUTION_TIME=1200 WORKFLOW_CALL_MAX_DEPTH=5 +MAX_VARIABLE_SIZE=204800 # App configuration APP_MAX_EXECUTION_TIME=1200 @@ -338,3 +341,6 @@ INNER_API_KEY=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1 # Marketplace configuration MARKETPLACE_ENABLED=true MARKETPLACE_API_URL=https://marketplace.dify.ai + +# Reset password token expiry minutes +RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5 diff --git a/api/.vscode/launch.json.example b/api/.vscode/launch.json.example index e9f8e42dd5..b9e32e2511 100644 --- a/api/.vscode/launch.json.example +++ b/api/.vscode/launch.json.example @@ -1,8 +1,15 @@ { "version": "0.2.0", + "compounds": [ + { + "name": "Launch Flask and Celery", + "configurations": ["Python: Flask", "Python: Celery"] + } + ], "configurations": [ { "name": "Python: Flask", + "consoleName": "Flask", "type": "debugpy", "request": "launch", "python": "${workspaceFolder}/.venv/bin/python", @@ -17,12 +24,12 @@ }, "args": [ "run", - "--host=0.0.0.0", "--port=5001" ] }, { "name": "Python: Celery", + "consoleName": "Celery", "type": "debugpy", "request": "launch", "python": "${workspaceFolder}/.venv/bin/python", @@ -45,10 +52,10 @@ "-c", "1", "--loglevel", - "info", + "DEBUG", "-Q", "dataset,generation,mail,ops_trace,app_deletion" ] - }, + } ] -} \ No newline at end of file +} diff --git a/api/Dockerfile b/api/Dockerfile index d32f70321d..466e6df578 100644 --- a/api/Dockerfile +++ b/api/Dockerfile @@ -55,7 +55,7 @@ RUN apt-get update \ && echo "deb http://deb.debian.org/debian testing main" > /etc/apt/sources.list \ && apt-get update \ # For Security - && apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1 expat=2.6.3-1 libldap-2.5-0=2.5.18+dfsg-3 perl=5.38.2-5 libsqlite3-0=3.46.0-1 \ + && apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1 expat=2.6.3-1 libldap-2.5-0=2.5.18+dfsg-3 perl=5.38.2-5 libsqlite3-0=3.46.1-1 \ && apt-get autoremove -y \ && rm -rf /var/lib/apt/lists/* diff --git a/api/app.py b/api/app.py index 52dd492225..a3efabf06c 100644 --- a/api/app.py +++ b/api/app.py @@ -10,44 +10,20 @@ if os.environ.get("DEBUG", "false").lower() != "true": grpc.experimental.gevent.init_gevent() import json -import logging -import sys import threading import time import warnings -from logging.handlers import RotatingFileHandler -from flask import Flask, Response, request -from flask_cors import CORS -from werkzeug.exceptions import Unauthorized +from flask import Response -import contexts -from commands import register_commands -from configs import dify_config +from app_factory import create_app # DO NOT REMOVE BELOW from events import event_handlers # noqa: F401 -from extensions import ( - ext_celery, - ext_code_based_extension, - ext_compress, - ext_database, - ext_hosting_provider, - ext_login, - ext_mail, - ext_migrate, - ext_proxy_fix, - ext_redis, - 
ext_sentry, - ext_storage, -) from extensions.ext_database import db -from extensions.ext_login import login_manager -from libs.passport import PassportService # TODO: Find a way to avoid importing models here from models import account, dataset, model, source, task, tool, tools, web # noqa: F401 -from services.account_service import AccountService # DO NOT REMOVE ABOVE @@ -60,188 +36,12 @@ if hasattr(time, "tzset"): time.tzset() -class DifyApp(Flask): - pass - - # ------------- # Configuration # ------------- - - config_type = os.getenv("EDITION", default="SELF_HOSTED") # ce edition first -# ---------------------------- -# Application Factory Function -# ---------------------------- - - -def create_flask_app_with_configs() -> Flask: - """ - create a raw flask app - with configs loaded from .env file - """ - dify_app = DifyApp(__name__) - dify_app.config.from_mapping(dify_config.model_dump()) - - # populate configs into system environment variables - for key, value in dify_app.config.items(): - if isinstance(value, str): - os.environ[key] = value - elif isinstance(value, int | float | bool): - os.environ[key] = str(value) - elif value is None: - os.environ[key] = "" - - return dify_app - - -def create_app() -> Flask: - app = create_flask_app_with_configs() - - app.secret_key = app.config["SECRET_KEY"] - - log_handlers = None - log_file = app.config.get("LOG_FILE") - if log_file: - log_dir = os.path.dirname(log_file) - os.makedirs(log_dir, exist_ok=True) - log_handlers = [ - RotatingFileHandler( - filename=log_file, - maxBytes=1024 * 1024 * 1024, - backupCount=5, - ), - logging.StreamHandler(sys.stdout), - ] - - logging.basicConfig( - level=app.config.get("LOG_LEVEL"), - format=app.config.get("LOG_FORMAT"), - datefmt=app.config.get("LOG_DATEFORMAT"), - handlers=log_handlers, - force=True, - ) - log_tz = app.config.get("LOG_TZ") - if log_tz: - from datetime import datetime - - import pytz - - timezone = pytz.timezone(log_tz) - - def time_converter(seconds): - return datetime.utcfromtimestamp(seconds).astimezone(timezone).timetuple() - - for handler in logging.root.handlers: - handler.formatter.converter = time_converter - initialize_extensions(app) - register_blueprints(app) - register_commands(app) - - return app - - -def initialize_extensions(app): - # Since the application instance is now created, pass it to each Flask - # extension instance to bind it to the Flask application instance (app) - ext_compress.init_app(app) - ext_code_based_extension.init() - ext_database.init_app(app) - ext_migrate.init(app, db) - ext_redis.init_app(app) - ext_storage.init_app(app) - ext_celery.init_app(app) - ext_login.init_app(app) - ext_mail.init_app(app) - ext_hosting_provider.init_app(app) - ext_sentry.init_app(app) - ext_proxy_fix.init_app(app) - - -# Flask-Login configuration -@login_manager.request_loader -def load_user_from_request(request_from_flask_login): - """Load user based on the request.""" - if request.blueprint not in {"console", "inner_api"}: - return None - # Check if the user_id contains a dot, indicating the old format - auth_header = request.headers.get("Authorization", "") - if not auth_header: - auth_token = request.args.get("_token") - if not auth_token: - raise Unauthorized("Invalid Authorization token.") - else: - if " " not in auth_header: - raise Unauthorized("Invalid Authorization header format. 
Expected 'Bearer ' format.") - auth_scheme, auth_token = auth_header.split(None, 1) - auth_scheme = auth_scheme.lower() - if auth_scheme != "bearer": - raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") - - decoded = PassportService().verify(auth_token) - user_id = decoded.get("user_id") - - logged_in_account = AccountService.load_logged_in_account(account_id=user_id) - if logged_in_account: - contexts.tenant_id.set(logged_in_account.current_tenant_id) - return logged_in_account - - -@login_manager.unauthorized_handler -def unauthorized_handler(): - """Handle unauthorized requests.""" - return Response( - json.dumps({"code": "unauthorized", "message": "Unauthorized."}), - status=401, - content_type="application/json", - ) - - -# register blueprint routers -def register_blueprints(app): - from controllers.console import bp as console_app_bp - from controllers.files import bp as files_bp - from controllers.inner_api import bp as inner_api_bp - from controllers.service_api import bp as service_api_bp - from controllers.web import bp as web_bp - - CORS( - service_api_bp, - allow_headers=["Content-Type", "Authorization", "X-App-Code"], - methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], - ) - app.register_blueprint(service_api_bp) - - CORS( - web_bp, - resources={r"/*": {"origins": app.config["WEB_API_CORS_ALLOW_ORIGINS"]}}, - supports_credentials=True, - allow_headers=["Content-Type", "Authorization", "X-App-Code"], - methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], - expose_headers=["X-Version", "X-Env"], - ) - - app.register_blueprint(web_bp) - - CORS( - console_app_bp, - resources={r"/*": {"origins": app.config["CONSOLE_CORS_ALLOW_ORIGINS"]}}, - supports_credentials=True, - allow_headers=["Content-Type", "Authorization"], - methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], - expose_headers=["X-Version", "X-Env"], - ) - - app.register_blueprint(console_app_bp) - - CORS(files_bp, allow_headers=["Content-Type"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"]) - app.register_blueprint(files_bp) - - app.register_blueprint(inner_api_bp) - - # create app app = create_app() celery = app.extensions["celery"] diff --git a/api/app_factory.py b/api/app_factory.py new file mode 100644 index 0000000000..04654c2699 --- /dev/null +++ b/api/app_factory.py @@ -0,0 +1,213 @@ +import os + +if os.environ.get("DEBUG", "false").lower() != "true": + from gevent import monkey + + monkey.patch_all() + + import grpc.experimental.gevent + + grpc.experimental.gevent.init_gevent() + +import json +import logging +import sys +from logging.handlers import RotatingFileHandler + +from flask import Flask, Response, request +from flask_cors import CORS +from werkzeug.exceptions import Unauthorized + +import contexts +from commands import register_commands +from configs import dify_config +from extensions import ( + ext_celery, + ext_code_based_extension, + ext_compress, + ext_database, + ext_hosting_provider, + ext_login, + ext_mail, + ext_migrate, + ext_proxy_fix, + ext_redis, + ext_sentry, + ext_storage, +) +from extensions.ext_database import db +from extensions.ext_login import login_manager +from libs.passport import PassportService +from services.account_service import AccountService + + +class DifyApp(Flask): + pass + + +# ---------------------------- +# Application Factory Function +# ---------------------------- +def create_flask_app_with_configs() -> Flask: + """ + create a raw flask app + with configs loaded from .env file + """ + 
dify_app = DifyApp(__name__) + dify_app.config.from_mapping(dify_config.model_dump()) + + # populate configs into system environment variables + for key, value in dify_app.config.items(): + if isinstance(value, str): + os.environ[key] = value + elif isinstance(value, int | float | bool): + os.environ[key] = str(value) + elif value is None: + os.environ[key] = "" + + return dify_app + + +def create_app() -> Flask: + app = create_flask_app_with_configs() + + app.secret_key = app.config["SECRET_KEY"] + + log_handlers = None + log_file = app.config.get("LOG_FILE") + if log_file: + log_dir = os.path.dirname(log_file) + os.makedirs(log_dir, exist_ok=True) + log_handlers = [ + RotatingFileHandler( + filename=log_file, + maxBytes=1024 * 1024 * 1024, + backupCount=5, + ), + logging.StreamHandler(sys.stdout), + ] + + logging.basicConfig( + level=app.config.get("LOG_LEVEL"), + format=app.config.get("LOG_FORMAT"), + datefmt=app.config.get("LOG_DATEFORMAT"), + handlers=log_handlers, + force=True, + ) + log_tz = app.config.get("LOG_TZ") + if log_tz: + from datetime import datetime + + import pytz + + timezone = pytz.timezone(log_tz) + + def time_converter(seconds): + return datetime.utcfromtimestamp(seconds).astimezone(timezone).timetuple() + + for handler in logging.root.handlers: + handler.formatter.converter = time_converter + initialize_extensions(app) + register_blueprints(app) + register_commands(app) + + return app + + +def initialize_extensions(app): + # Since the application instance is now created, pass it to each Flask + # extension instance to bind it to the Flask application instance (app) + ext_compress.init_app(app) + ext_code_based_extension.init() + ext_database.init_app(app) + ext_migrate.init(app, db) + ext_redis.init_app(app) + ext_storage.init_app(app) + ext_celery.init_app(app) + ext_login.init_app(app) + ext_mail.init_app(app) + ext_hosting_provider.init_app(app) + ext_sentry.init_app(app) + ext_proxy_fix.init_app(app) + + +# Flask-Login configuration +@login_manager.request_loader +def load_user_from_request(request_from_flask_login): + """Load user based on the request.""" + if request.blueprint not in {"console", "inner_api"}: + return None + # Check if the user_id contains a dot, indicating the old format + auth_header = request.headers.get("Authorization", "") + if not auth_header: + auth_token = request.args.get("_token") + if not auth_token: + raise Unauthorized("Invalid Authorization token.") + else: + if " " not in auth_header: + raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") + auth_scheme, auth_token = auth_header.split(None, 1) + auth_scheme = auth_scheme.lower() + if auth_scheme != "bearer": + raise Unauthorized("Invalid Authorization header format. 
Expected 'Bearer ' format.") + + decoded = PassportService().verify(auth_token) + user_id = decoded.get("user_id") + + logged_in_account = AccountService.load_logged_in_account(account_id=user_id) + if logged_in_account: + contexts.tenant_id.set(logged_in_account.current_tenant_id) + return logged_in_account + + +@login_manager.unauthorized_handler +def unauthorized_handler(): + """Handle unauthorized requests.""" + return Response( + json.dumps({"code": "unauthorized", "message": "Unauthorized."}), + status=401, + content_type="application/json", + ) + + +# register blueprint routers +def register_blueprints(app): + from controllers.console import bp as console_app_bp + from controllers.files import bp as files_bp + from controllers.inner_api import bp as inner_api_bp + from controllers.service_api import bp as service_api_bp + from controllers.web import bp as web_bp + + CORS( + service_api_bp, + allow_headers=["Content-Type", "Authorization", "X-App-Code"], + methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], + ) + app.register_blueprint(service_api_bp) + + CORS( + web_bp, + resources={r"/*": {"origins": app.config["WEB_API_CORS_ALLOW_ORIGINS"]}}, + supports_credentials=True, + allow_headers=["Content-Type", "Authorization", "X-App-Code"], + methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], + expose_headers=["X-Version", "X-Env"], + ) + + app.register_blueprint(web_bp) + + CORS( + console_app_bp, + resources={r"/*": {"origins": app.config["CONSOLE_CORS_ALLOW_ORIGINS"]}}, + supports_credentials=True, + allow_headers=["Content-Type", "Authorization"], + methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], + expose_headers=["X-Version", "X-Env"], + ) + + app.register_blueprint(console_app_bp) + + CORS(files_bp, allow_headers=["Content-Type"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"]) + app.register_blueprint(files_bp) + + app.register_blueprint(inner_api_bp) diff --git a/api/commands.py b/api/commands.py index dbcd8a744d..f2809be8e7 100644 --- a/api/commands.py +++ b/api/commands.py @@ -19,7 +19,7 @@ from extensions.ext_redis import redis_client from libs.helper import email as email_validate from libs.password import hash_password, password_pattern, valid_password from libs.rsa import generate_key_pair -from models.account import Tenant +from models import Tenant from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment from models.dataset import Document as DatasetDocument from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation @@ -259,6 +259,25 @@ def migrate_knowledge_vector_database(): skipped_count = 0 total_count = 0 vector_type = dify_config.VECTOR_STORE + upper_colletion_vector_types = { + VectorType.MILVUS, + VectorType.PGVECTOR, + VectorType.RELYT, + VectorType.WEAVIATE, + VectorType.ORACLE, + VectorType.ELASTICSEARCH, + } + lower_colletion_vector_types = { + VectorType.ANALYTICDB, + VectorType.CHROMA, + VectorType.MYSCALE, + VectorType.PGVECTO_RS, + VectorType.TIDB_VECTOR, + VectorType.OPENSEARCH, + VectorType.TENCENT, + VectorType.BAIDU, + VectorType.VIKINGDB, + } page = 1 while True: try: @@ -284,11 +303,9 @@ def migrate_knowledge_vector_database(): skipped_count = skipped_count + 1 continue collection_name = "" - if vector_type == VectorType.WEAVIATE: - dataset_id = dataset.id + dataset_id = dataset.id + if vector_type in upper_colletion_vector_types: collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = {"type": 
VectorType.WEAVIATE, "vector_store": {"class_prefix": collection_name}} - dataset.index_struct = json.dumps(index_struct_dict) elif vector_type == VectorType.QDRANT: if dataset.collection_binding_id: dataset_collection_binding = ( @@ -301,63 +318,15 @@ def migrate_knowledge_vector_database(): else: raise ValueError("Dataset Collection Binding not found") else: - dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = {"type": VectorType.QDRANT, "vector_store": {"class_prefix": collection_name}} - dataset.index_struct = json.dumps(index_struct_dict) - elif vector_type == VectorType.MILVUS: - dataset_id = dataset.id - collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = {"type": VectorType.MILVUS, "vector_store": {"class_prefix": collection_name}} - dataset.index_struct = json.dumps(index_struct_dict) - elif vector_type == VectorType.RELYT: - dataset_id = dataset.id - collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = {"type": "relyt", "vector_store": {"class_prefix": collection_name}} - dataset.index_struct = json.dumps(index_struct_dict) - elif vector_type == VectorType.TENCENT: - dataset_id = dataset.id - collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = {"type": VectorType.TENCENT, "vector_store": {"class_prefix": collection_name}} - dataset.index_struct = json.dumps(index_struct_dict) - elif vector_type == VectorType.PGVECTOR: - dataset_id = dataset.id - collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = {"type": VectorType.PGVECTOR, "vector_store": {"class_prefix": collection_name}} - dataset.index_struct = json.dumps(index_struct_dict) - elif vector_type == VectorType.OPENSEARCH: - dataset_id = dataset.id - collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = { - "type": VectorType.OPENSEARCH, - "vector_store": {"class_prefix": collection_name}, - } - dataset.index_struct = json.dumps(index_struct_dict) - elif vector_type == VectorType.ANALYTICDB: - dataset_id = dataset.id - collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = { - "type": VectorType.ANALYTICDB, - "vector_store": {"class_prefix": collection_name}, - } - dataset.index_struct = json.dumps(index_struct_dict) - elif vector_type == VectorType.ELASTICSEARCH: - dataset_id = dataset.id - index_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = {"type": "elasticsearch", "vector_store": {"class_prefix": index_name}} - dataset.index_struct = json.dumps(index_struct_dict) - elif vector_type == VectorType.BAIDU: - dataset_id = dataset.id - collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = { - "type": VectorType.BAIDU, - "vector_store": {"class_prefix": collection_name}, - } - dataset.index_struct = json.dumps(index_struct_dict) + elif vector_type in lower_colletion_vector_types: + collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() else: raise ValueError(f"Vector store {vector_type} is not supported.") + index_struct_dict = {"type": vector_type, "vector_store": {"class_prefix": collection_name}} + dataset.index_struct = json.dumps(index_struct_dict) vector = Vector(dataset) click.echo(f"Migrating dataset {dataset.id}.") @@ -457,14 +426,14 @@ def convert_to_agent_apps(): # fetch first 1000 apps sql_query = """SELECT a.id AS id FROM apps a INNER JOIN app_model_configs am ON a.app_model_config_id=am.id 
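The hunk above collapses a long per-vector-store `elif` chain into two sets: stores whose collection names keep their original casing and stores that need a lowercased name, with Qdrant still handled separately through its collection binding. A rough, self-contained sketch of the consolidated logic — the string-valued store names and the helper are stand-ins, since the real code uses the `VectorType` enum and `Dataset.gen_collection_name_by_id`:

```python
import json

# Stand-ins for the sets built in api/commands.py (the real members are
# VectorType enum values such as MILVUS, PGVECTOR, WEAVIATE, ...).
UPPER_COLLECTION_VECTOR_TYPES = {"milvus", "pgvector", "weaviate", "oracle"}
LOWER_COLLECTION_VECTOR_TYPES = {"chroma", "opensearch", "tencent", "baidu"}


def gen_collection_name(dataset_id: str) -> str:
    # Mirrors Dataset.gen_collection_name_by_id in spirit only.
    return f"Vector_index_{dataset_id.replace('-', '_')}_Node"


def build_index_struct(vector_type: str, dataset_id: str) -> str:
    if vector_type in UPPER_COLLECTION_VECTOR_TYPES:
        collection_name = gen_collection_name(dataset_id)
    elif vector_type in LOWER_COLLECTION_VECTOR_TYPES:
        collection_name = gen_collection_name(dataset_id).lower()
    else:
        raise ValueError(f"Vector store {vector_type} is not supported.")
    # One shared index_struct payload replaces the per-store duplicates.
    return json.dumps({"type": vector_type, "vector_store": {"class_prefix": collection_name}})
```

The payoff of the refactor is that `index_struct_dict` is built once after the branch instead of being duplicated in every `elif` arm.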
- WHERE a.mode = 'chat' - AND am.agent_mode is not null + WHERE a.mode = 'chat' + AND am.agent_mode is not null AND ( - am.agent_mode like '%"strategy": "function_call"%' + am.agent_mode like '%"strategy": "function_call"%' OR am.agent_mode like '%"strategy": "react"%' - ) + ) AND ( - am.agent_mode like '{"enabled": true%' + am.agent_mode like '{"enabled": true%' OR am.agent_mode like '{"max_iteration": %' ) ORDER BY a.created_at DESC LIMIT 1000 """ diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index b453779348..a941c8b673 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -1,6 +1,15 @@ -from typing import Annotated, Optional +from typing import Annotated, Literal, Optional -from pydantic import AliasChoices, Field, HttpUrl, NegativeInt, NonNegativeInt, PositiveInt, computed_field +from pydantic import ( + AliasChoices, + Field, + HttpUrl, + NegativeInt, + NonNegativeInt, + PositiveFloat, + PositiveInt, + computed_field, +) from pydantic_settings import BaseSettings from configs.feature.hosted_service import HostedServiceConfig @@ -11,16 +20,16 @@ class SecurityConfig(BaseSettings): Security-related configurations for the application """ - SECRET_KEY: Optional[str] = Field( + SECRET_KEY: str = Field( description="Secret key for secure session cookie signing." "Make sure you are changing this key for your deployment with a strong key." "Generate a strong key using `openssl rand -base64 42` or set via the `SECRET_KEY` environment variable.", - default=None, + default="", ) - RESET_PASSWORD_TOKEN_EXPIRY_HOURS: PositiveInt = Field( - description="Duration in hours for which a password reset token remains valid", - default=24, + RESET_PASSWORD_TOKEN_EXPIRY_MINUTES: PositiveInt = Field( + description="Duration in minutes for which a password reset token remains valid", + default=5, ) @@ -230,6 +239,16 @@ class FileUploadConfig(BaseSettings): default=10, ) + UPLOAD_VIDEO_FILE_SIZE_LIMIT: NonNegativeInt = Field( + description="video file size limit in Megabytes for uploading files", + default=100, + ) + + UPLOAD_AUDIO_FILE_SIZE_LIMIT: NonNegativeInt = Field( + description="audio file size limit in Megabytes for uploading files", + default=50, + ) + BATCH_UPLOAD_LIMIT: NonNegativeInt = Field( description="Maximum number of files allowed in a batch upload operation", default=20, @@ -408,8 +427,8 @@ class WorkflowConfig(BaseSettings): ) MAX_VARIABLE_SIZE: PositiveInt = Field( - description="Maximum size in bytes for a single variable in workflows. Default to 5KB.", - default=5 * 1024, + description="Maximum size in bytes for a single variable in workflows. 
Default to 200 KB.", + default=200 * 1024, ) @@ -526,12 +545,18 @@ class MailConfig(BaseSettings): default=False, ) + EMAIL_SEND_IP_LIMIT_PER_MINUTE: PositiveInt = Field( + description="Maximum number of emails allowed to be sent from the same IP address in a minute", + default=50, + ) + class RagEtlConfig(BaseSettings): """ Configuration for RAG ETL processes """ + # TODO: This config is not only for rag etl, it is also for file upload, we should move it to file upload config ETL_TYPE: str = Field( description="RAG ETL type ('dify' or 'Unstructured'), default to 'dify'", default="dify", @@ -598,7 +623,7 @@ class IndexingConfig(BaseSettings): class ImageFormatConfig(BaseSettings): - MULTIMODAL_SEND_IMAGE_FORMAT: str = Field( + MULTIMODAL_SEND_IMAGE_FORMAT: Literal["base64", "url"] = Field( description="Format for sending images in multimodal contexts ('base64' or 'url'), default is base64", default="base64", ) @@ -667,6 +692,33 @@ class PositionConfig(BaseSettings): return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(",") if item.strip() != ""} +class LoginConfig(BaseSettings): + ENABLE_EMAIL_CODE_LOGIN: bool = Field( + description="whether to enable email code login", + default=False, + ) + ENABLE_EMAIL_PASSWORD_LOGIN: bool = Field( + description="whether to enable email password login", + default=True, + ) + ENABLE_SOCIAL_OAUTH_LOGIN: bool = Field( + description="whether to enable github/google oauth login", + default=False, + ) + EMAIL_CODE_LOGIN_TOKEN_EXPIRY_MINUTES: PositiveInt = Field( + description="expiry time in minutes for email code login token", + default=5, + ) + ALLOW_REGISTER: bool = Field( + description="whether to enable register", + default=False, + ) + ALLOW_CREATE_WORKSPACE: bool = Field( + description="whether to enable create workspace", + default=False, + ) + + class FeatureConfig( # place the configs in alphabet order AppExecutionConfig, @@ -694,6 +746,7 @@ class FeatureConfig( UpdateConfig, WorkflowConfig, WorkspaceConfig, + LoginConfig, # hosted services config HostedServiceConfig, CeleryBeatConfig, diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py index fa7f41d630..84d03e2f45 100644 --- a/api/configs/middleware/__init__.py +++ b/api/configs/middleware/__init__.py @@ -35,7 +35,8 @@ from configs.middleware.vdb.weaviate_config import WeaviateConfig class StorageConfig(BaseSettings): STORAGE_TYPE: str = Field( description="Type of storage to use." - " Options: 'local', 's3', 'azure-blob', 'aliyun-oss', 'google-storage'. Default is 'local'.", + " Options: 'local', 's3', 'aliyun-oss', 'azure-blob', 'baidu-obs', 'google-storage', 'huawei-obs', " + "'oci-storage', 'tencent-cos', 'volcengine-tos', 'supabase'. 
Default is 'local'.", default="local", ) diff --git a/api/configs/packaging/__init__.py b/api/configs/packaging/__init__.py index c832fb671e..635d12fc55 100644 --- a/api/configs/packaging/__init__.py +++ b/api/configs/packaging/__init__.py @@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings): CURRENT_VERSION: str = Field( description="Dify version", - default="0.9.2", + default="0.10.0", ) COMMIT_SHA: str = Field( diff --git a/api/constants/__init__.py b/api/constants/__init__.py index 75eaf81638..66b9c0b632 100644 --- a/api/constants/__init__.py +++ b/api/constants/__init__.py @@ -1,2 +1,22 @@ +from configs import dify_config + HIDDEN_VALUE = "[__HIDDEN__]" UUID_NIL = "00000000-0000-0000-0000-000000000000" + +IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"] +IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS]) + +VIDEO_EXTENSIONS = ["mp4", "mov", "mpeg", "mpga"] +VIDEO_EXTENSIONS.extend([ext.upper() for ext in VIDEO_EXTENSIONS]) + +AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "webm", "amr"] +AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS]) + + +if dify_config.ETL_TYPE == "Unstructured": + DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls"] + DOCUMENT_EXTENSIONS.extend(("docx", "csv", "eml", "msg", "pptx", "ppt", "xml", "epub")) + DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS]) +else: + DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls", "docx", "csv"] + DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS]) diff --git a/api/contexts/__init__.py b/api/contexts/__init__.py index 623a1a28eb..85380b7330 100644 --- a/api/contexts/__init__.py +++ b/api/contexts/__init__.py @@ -1,7 +1,9 @@ from contextvars import ContextVar +from typing import TYPE_CHECKING -from core.workflow.entities.variable_pool import VariablePool +if TYPE_CHECKING: + from core.workflow.entities.variable_pool import VariablePool tenant_id: ContextVar[str] = ContextVar("tenant_id") -workflow_variable_pool: ContextVar[VariablePool] = ContextVar("workflow_variable_pool") +workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool") diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index c1e16b3b9b..b60a424d98 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -22,7 +22,8 @@ from fields.conversation_fields import ( ) from libs.helper import DatetimeString from libs.login import login_required -from models.model import AppMode, Conversation, EndUser, Message, MessageAnnotation +from models import Conversation, EndUser, Message, MessageAnnotation +from models.model import AppMode class CompletionConversationApi(Resource): diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index 26da1ef26d..115a832da9 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -12,7 +12,7 @@ from controllers.console.wraps import account_initialization_required from extensions.ext_database import db from fields.app_fields import app_site_fields from libs.login import login_required -from models.model import Site +from models import Site def parse_app_site_args(): diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 22a1fbb563..1ffdceb2c8 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -13,15 
+13,15 @@ from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom -from core.app.segments import factory -from core.errors.error import AppInvokeQuotaExceededError +from factories import variable_factory from fields.workflow_fields import workflow_fields from fields.workflow_run_fields import workflow_run_node_execution_fields from libs import helper from libs.helper import TimestampField, uuid_value from libs.login import current_user, login_required +from models import App from models.account import Account -from models.model import App, AppMode +from models.model import AppMode from services.app_dsl_service import AppDslService from services.app_generate_service import AppGenerateService from services.errors.app import WorkflowHashNotEqualError @@ -105,9 +105,13 @@ class DraftWorkflowApi(Resource): try: environment_variables_list = args.get("environment_variables") or [] - environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list] + environment_variables = [ + variable_factory.build_variable_from_mapping(obj) for obj in environment_variables_list + ] conversation_variables_list = args.get("conversation_variables") or [] - conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list] + conversation_variables = [ + variable_factory.build_variable_from_mapping(obj) for obj in conversation_variables_list + ] workflow = workflow_service.sync_draft_workflow( app_model=app_model, graph=args["graph"], @@ -292,17 +296,15 @@ class DraftWorkflowRunApi(Resource): parser.add_argument("files", type=list, required=False, location="json") args = parser.parse_args() - try: - response = AppGenerateService.generate( - app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=True - ) + response = AppGenerateService.generate( + app_model=app_model, + user=current_user, + args=args, + invoke_from=InvokeFrom.DEBUGGER, + streaming=True, + ) - return helper.compact_generate_response(response) - except (ValueError, AppInvokeQuotaExceededError) as e: - raise e - except Exception as e: - logging.exception("internal server error.") - raise InternalServerError() + return helper.compact_generate_response(response) class WorkflowTaskStopApi(Resource): diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py index dc962409cc..629b7a8bf4 100644 --- a/api/controllers/console/app/workflow_app_log.py +++ b/api/controllers/console/app/workflow_app_log.py @@ -7,7 +7,8 @@ from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from fields.workflow_app_log_fields import workflow_app_log_pagination_fields from libs.login import login_required -from models.model import App, AppMode +from models import App +from models.model import AppMode from services.workflow_app_service import WorkflowAppService diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py index a055d03deb..5824ead9c3 100644 --- a/api/controllers/console/app/workflow_run.py +++ b/api/controllers/console/app/workflow_run.py @@ -13,7 +13,8 @@ from fields.workflow_run_fields import ( ) from libs.helper import uuid_value from libs.login import login_required -from models.model import App, 
AppMode +from models import App +from models.model import AppMode from services.workflow_run_service import WorkflowRunService diff --git a/api/controllers/console/app/workflow_statistic.py b/api/controllers/console/app/workflow_statistic.py index c7e54f2be0..f46af0f1ca 100644 --- a/api/controllers/console/app/workflow_statistic.py +++ b/api/controllers/console/app/workflow_statistic.py @@ -13,8 +13,8 @@ from controllers.console.wraps import account_initialization_required from extensions.ext_database import db from libs.helper import DatetimeString from libs.login import login_required +from models.enums import WorkflowRunTriggeredFrom from models.model import AppMode -from models.workflow import WorkflowRunTriggeredFrom class WorkflowDailyRunsStatistic(Resource): diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py index 8a743d6be9..f84c592bba 100644 --- a/api/controllers/console/app/wraps.py +++ b/api/controllers/console/app/wraps.py @@ -5,7 +5,8 @@ from typing import Optional, Union from controllers.console.app.error import AppNotFoundError from extensions.ext_database import db from libs.login import current_user -from models.model import App, AppMode +from models import App +from models.model import AppMode def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode], None] = None): diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py index f3198dfc1d..be353cefac 100644 --- a/api/controllers/console/auth/activate.py +++ b/api/controllers/console/auth/activate.py @@ -1,17 +1,15 @@ -import base64 import datetime -import secrets +from flask import request from flask_restful import Resource, reqparse from constants.languages import supported_language from controllers.console import api from controllers.console.error import AlreadyActivateError from extensions.ext_database import db -from libs.helper import StrLen, email, timezone -from libs.password import hash_password, valid_password -from models.account import AccountStatus -from services.account_service import RegisterService +from libs.helper import StrLen, email, extract_remote_ip, timezone +from models.account import AccountStatus, Tenant +from services.account_service import AccountService, RegisterService class ActivateCheckApi(Resource): @@ -27,8 +25,18 @@ class ActivateCheckApi(Resource): token = args["token"] invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token) - - return {"is_valid": invitation is not None, "workspace_name": invitation["tenant"].name if invitation else None} + if invitation: + data = invitation.get("data", {}) + tenant: Tenant = invitation.get("tenant", None) + workspace_name = tenant.name if tenant else None + workspace_id = tenant.id if tenant else None + invitee_email = data.get("email") if data else None + return { + "is_valid": invitation is not None, + "data": {"workspace_name": workspace_name, "workspace_id": workspace_id, "email": invitee_email}, + } + else: + return {"is_valid": False} class ActivateApi(Resource): @@ -38,7 +46,6 @@ class ActivateApi(Resource): parser.add_argument("email", type=email, required=False, nullable=True, location="json") parser.add_argument("token", type=str, required=True, nullable=False, location="json") parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") - parser.add_argument("password", type=valid_password, required=True, nullable=False, location="json") parser.add_argument( 
"interface_language", type=supported_language, required=True, nullable=False, location="json" ) @@ -54,15 +61,6 @@ class ActivateApi(Resource): account = invitation["account"] account.name = args["name"] - # generate password salt - salt = secrets.token_bytes(16) - base64_salt = base64.b64encode(salt).decode() - - # encrypt password with salt - password_hashed = hash_password(args["password"], salt) - base64_password_hashed = base64.b64encode(password_hashed).decode() - account.password = base64_password_hashed - account.password_salt = base64_salt account.interface_language = args["interface_language"] account.timezone = args["timezone"] account.interface_theme = "light" @@ -70,7 +68,9 @@ class ActivateApi(Resource): account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() - return {"result": "success"} + token_pair = AccountService.login(account, ip_address=extract_remote_ip(request)) + + return {"result": "success", "data": token_pair.model_dump()} api.add_resource(ActivateCheckApi, "/activate/check") diff --git a/api/controllers/console/auth/error.py b/api/controllers/console/auth/error.py index ea23e097d0..e6e30c3c0b 100644 --- a/api/controllers/console/auth/error.py +++ b/api/controllers/console/auth/error.py @@ -27,5 +27,29 @@ class InvalidTokenError(BaseHTTPException): class PasswordResetRateLimitExceededError(BaseHTTPException): error_code = "password_reset_rate_limit_exceeded" - description = "Password reset rate limit exceeded. Try again later." + description = "Too many password reset emails have been sent. Please try again in 1 minutes." + code = 429 + + +class EmailCodeError(BaseHTTPException): + error_code = "email_code_error" + description = "Email code is invalid or expired." + code = 400 + + +class EmailOrPasswordMismatchError(BaseHTTPException): + error_code = "email_or_password_mismatch" + description = "The email or password is mismatched." + code = 400 + + +class EmailPasswordLoginLimitError(BaseHTTPException): + error_code = "email_code_login_limit" + description = "Too many incorrect password attempts. Please try again later." + code = 429 + + +class EmailCodeLoginRateLimitExceededError(BaseHTTPException): + error_code = "email_code_login_rate_limit_exceeded" + description = "Too many login emails have been sent. Please try again in 5 minutes." 
code = 429 diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index 0b01a4906a..7fea610610 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -1,65 +1,82 @@ import base64 -import logging import secrets +from flask import request from flask_restful import Resource, reqparse +from constants.languages import languages from controllers.console import api from controllers.console.auth.error import ( + EmailCodeError, InvalidEmailError, InvalidTokenError, PasswordMismatchError, - PasswordResetRateLimitExceededError, ) +from controllers.console.error import EmailSendIpLimitError, NotAllowedRegister from controllers.console.setup import setup_required +from events.tenant_event import tenant_was_created from extensions.ext_database import db -from libs.helper import email as email_validate +from libs.helper import email, extract_remote_ip from libs.password import hash_password, valid_password from models.account import Account -from services.account_service import AccountService -from services.errors.account import RateLimitExceededError +from services.account_service import AccountService, TenantService +from services.errors.workspace import WorkSpaceNotAllowedCreateError +from services.feature_service import FeatureService class ForgotPasswordSendEmailApi(Resource): @setup_required def post(self): parser = reqparse.RequestParser() - parser.add_argument("email", type=str, required=True, location="json") + parser.add_argument("email", type=email, required=True, location="json") + parser.add_argument("language", type=str, required=False, location="json") args = parser.parse_args() - email = args["email"] + ip_address = extract_remote_ip(request) + if AccountService.is_email_send_ip_limit(ip_address): + raise EmailSendIpLimitError() - if not email_validate(email): - raise InvalidEmailError() - - account = Account.query.filter_by(email=email).first() - - if account: - try: - AccountService.send_reset_password_email(account=account) - except RateLimitExceededError: - logging.warning(f"Rate limit exceeded for email: {account.email}") - raise PasswordResetRateLimitExceededError() + if args["language"] is not None and args["language"] == "zh-Hans": + language = "zh-Hans" else: - # Return success to avoid revealing email registration status - logging.warning(f"Attempt to reset password for unregistered email: {email}") + language = "en-US" - return {"result": "success"} + account = Account.query.filter_by(email=args["email"]).first() + token = None + if account is None: + if FeatureService.get_system_features().is_allow_register: + token = AccountService.send_reset_password_email(email=args["email"], language=language) + return {"result": "fail", "data": token, "code": "account_not_found"} + else: + raise NotAllowedRegister() + else: + token = AccountService.send_reset_password_email(account=account, email=args["email"], language=language) + + return {"result": "success", "data": token} class ForgotPasswordCheckApi(Resource): @setup_required def post(self): parser = reqparse.RequestParser() + parser.add_argument("email", type=str, required=True, location="json") + parser.add_argument("code", type=str, required=True, location="json") parser.add_argument("token", type=str, required=True, nullable=False, location="json") args = parser.parse_args() - token = args["token"] - reset_data = AccountService.get_reset_password_data(token) + user_email = args["email"] - if reset_data is None: 
- return {"is_valid": False, "email": None} - return {"is_valid": True, "email": reset_data.get("email")} + token_data = AccountService.get_reset_password_data(args["token"]) + if token_data is None: + raise InvalidTokenError() + + if user_email != token_data.get("email"): + raise InvalidEmailError() + + if args["code"] != token_data.get("code"): + raise EmailCodeError() + + return {"is_valid": True, "email": token_data.get("email")} class ForgotPasswordResetApi(Resource): @@ -92,9 +109,26 @@ class ForgotPasswordResetApi(Resource): base64_password_hashed = base64.b64encode(password_hashed).decode() account = Account.query.filter_by(email=reset_data.get("email")).first() - account.password = base64_password_hashed - account.password_salt = base64_salt - db.session.commit() + if account: + account.password = base64_password_hashed + account.password_salt = base64_salt + db.session.commit() + tenant = TenantService.get_join_tenants(account) + if not tenant and not FeatureService.get_system_features().is_allow_create_workspace: + tenant = TenantService.create_tenant(f"{account.name}'s Workspace") + TenantService.create_tenant_member(tenant, account, role="owner") + account.current_tenant = tenant + tenant_was_created.send(tenant) + else: + try: + account = AccountService.create_account_and_tenant( + email=reset_data.get("email"), + name=reset_data.get("email"), + password=password_confirm, + interface_language=languages[0], + ) + except WorkSpaceNotAllowedCreateError: + pass return {"result": "success"} diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index 18a7b23166..4821c543b7 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -1,16 +1,34 @@ from typing import cast import flask_login -from flask import request +from flask import redirect, request from flask_restful import Resource, reqparse import services +from configs import dify_config +from constants.languages import languages from controllers.console import api +from controllers.console.auth.error import ( + EmailCodeError, + EmailOrPasswordMismatchError, + EmailPasswordLoginLimitError, + InvalidEmailError, + InvalidTokenError, +) +from controllers.console.error import ( + AccountBannedError, + EmailSendIpLimitError, + NotAllowedCreateWorkspace, + NotAllowedRegister, +) from controllers.console.setup import setup_required +from events.tenant_event import tenant_was_created from libs.helper import email, extract_remote_ip from libs.password import valid_password from models.account import Account -from services.account_service import AccountService, TenantService +from services.account_service import AccountService, RegisterService, TenantService +from services.errors.workspace import WorkSpaceNotAllowedCreateError +from services.feature_service import FeatureService class LoginApi(Resource): @@ -23,15 +41,43 @@ class LoginApi(Resource): parser.add_argument("email", type=email, required=True, location="json") parser.add_argument("password", type=valid_password, required=True, location="json") parser.add_argument("remember_me", type=bool, required=False, default=False, location="json") + parser.add_argument("invite_token", type=str, required=False, default=None, location="json") + parser.add_argument("language", type=str, required=False, default="en-US", location="json") args = parser.parse_args() - # todo: Verify the recaptcha + is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args["email"]) + if is_login_error_rate_limit: + 
raise EmailPasswordLoginLimitError() + + invitation = args["invite_token"] + if invitation: + invitation = RegisterService.get_invitation_if_token_valid(None, args["email"], invitation) + + if args["language"] is not None and args["language"] == "zh-Hans": + language = "zh-Hans" + else: + language = "en-US" try: - account = AccountService.authenticate(args["email"], args["password"]) - except services.errors.account.AccountLoginError as e: - return {"code": "unauthorized", "message": str(e)}, 401 - + if invitation: + data = invitation.get("data", {}) + invitee_email = data.get("email") if data else None + if invitee_email != args["email"]: + raise InvalidEmailError() + account = AccountService.authenticate(args["email"], args["password"], args["invite_token"]) + else: + account = AccountService.authenticate(args["email"], args["password"]) + except services.errors.account.AccountLoginError: + raise AccountBannedError() + except services.errors.account.AccountPasswordError: + AccountService.add_login_error_rate_limit(args["email"]) + raise EmailOrPasswordMismatchError() + except services.errors.account.AccountNotFoundError: + if FeatureService.get_system_features().is_allow_register: + token = AccountService.send_reset_password_email(email=args["email"], language=language) + return {"result": "fail", "data": token, "code": "account_not_found"} + else: + raise NotAllowedRegister() # SELF_HOSTED only have one workspace tenants = TenantService.get_join_tenants(account) if len(tenants) == 0: @@ -41,7 +87,7 @@ class LoginApi(Resource): } token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request)) - + AccountService.reset_login_error_rate_limit(args["email"]) return {"result": "success", "data": token_pair.model_dump()} @@ -49,60 +95,114 @@ class LogoutApi(Resource): @setup_required def get(self): account = cast(Account, flask_login.current_user) + if isinstance(account, flask_login.AnonymousUserMixin): + return {"result": "success"} AccountService.logout(account=account) flask_login.logout_user() return {"result": "success"} -class ResetPasswordApi(Resource): +class ResetPasswordSendEmailApi(Resource): @setup_required - def get(self): - # parser = reqparse.RequestParser() - # parser.add_argument('email', type=email, required=True, location='json') - # args = parser.parse_args() + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("email", type=email, required=True, location="json") + parser.add_argument("language", type=str, required=False, location="json") + args = parser.parse_args() - # import mailchimp_transactional as MailchimpTransactional - # from mailchimp_transactional.api_client import ApiClientError + if args["language"] is not None and args["language"] == "zh-Hans": + language = "zh-Hans" + else: + language = "en-US" - # account = {'email': args['email']} - # account = AccountService.get_by_email(args['email']) - # if account is None: - # raise ValueError('Email not found') - # new_password = AccountService.generate_password() - # AccountService.update_password(account, new_password) + account = AccountService.get_user_through_email(args["email"]) + if account is None: + if FeatureService.get_system_features().is_allow_register: + token = AccountService.send_reset_password_email(email=args["email"], language=language) + else: + raise NotAllowedRegister() + else: + token = AccountService.send_reset_password_email(account=account, language=language) - # todo: Send email - # MAILCHIMP_API_KEY = 
dify_config.MAILCHIMP_TRANSACTIONAL_API_KEY - # mailchimp = MailchimpTransactional(MAILCHIMP_API_KEY) + return {"result": "success", "data": token} - # message = { - # 'from_email': 'noreply@example.com', - # 'to': [{'email': account['email']}], - # 'subject': 'Reset your Dify password', - # 'html': """ - #

<p>Dear User,</p>
- # <p>The Dify team has generated a new password for you, details as follows:</p>
- # <p><strong>{new_password}</strong></p>
- # <p>Please change your password to log in as soon as possible.</p>
- # <p>Regards,</p>
- # <p>The Dify Team</p>
- # """ - # } - # response = mailchimp.messages.send({ - # 'message': message, - # # required for transactional email - # ' settings': { - # 'sandbox_mode': dify_config.MAILCHIMP_SANDBOX_MODE, - # }, - # }) +class EmailCodeLoginSendEmailApi(Resource): + @setup_required + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("email", type=email, required=True, location="json") + parser.add_argument("language", type=str, required=False, location="json") + args = parser.parse_args() - # Check if MSG was sent - # if response.status_code != 200: - # # handle error - # pass + ip_address = extract_remote_ip(request) + if AccountService.is_email_send_ip_limit(ip_address): + raise EmailSendIpLimitError() - return {"result": "success"} + if args["language"] is not None and args["language"] == "zh-Hans": + language = "zh-Hans" + else: + language = "en-US" + + account = AccountService.get_user_through_email(args["email"]) + if account is None: + if FeatureService.get_system_features().is_allow_register: + token = AccountService.send_email_code_login_email(email=args["email"], language=language) + else: + raise NotAllowedRegister() + else: + token = AccountService.send_email_code_login_email(account=account, language=language) + + return {"result": "success", "data": token} + + +class EmailCodeLoginApi(Resource): + @setup_required + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("email", type=str, required=True, location="json") + parser.add_argument("code", type=str, required=True, location="json") + parser.add_argument("token", type=str, required=True, location="json") + args = parser.parse_args() + + user_email = args["email"] + + token_data = AccountService.get_email_code_login_data(args["token"]) + if token_data is None: + raise InvalidTokenError() + + if token_data["email"] != args["email"]: + raise InvalidEmailError() + + if token_data["code"] != args["code"]: + raise EmailCodeError() + + AccountService.revoke_email_code_login_token(args["token"]) + account = AccountService.get_user_through_email(user_email) + if account: + tenant = TenantService.get_join_tenants(account) + if not tenant: + if not FeatureService.get_system_features().is_allow_create_workspace: + raise NotAllowedCreateWorkspace() + else: + tenant = TenantService.create_tenant(f"{account.name}'s Workspace") + TenantService.create_tenant_member(tenant, account, role="owner") + account.current_tenant = tenant + tenant_was_created.send(tenant) + + if account is None: + try: + account = AccountService.create_account_and_tenant( + email=user_email, name=user_email, interface_language=languages[0] + ) + except WorkSpaceNotAllowedCreateError: + return redirect( + f"{dify_config.CONSOLE_WEB_URL}/signin" + "?message=Workspace not found, please contact system admin to invite you to join in a workspace." 
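The `EmailCodeLoginApi` flow above boils down to a three-field token check before the token is revoked and the session is issued. A compact sketch of just the validation step — plain dicts and `ValueError`s stand in for Dify's token storage and the error classes added in `controllers/console/auth/error.py`:

```python
def verify_email_code_login(token_data: dict | None, email: str, code: str) -> None:
    """Raise if the submitted email/code pair does not match the issued token."""
    if token_data is None:
        raise ValueError("invalid or expired token")    # InvalidTokenError in the diff
    if token_data["email"] != email:
        raise ValueError("email does not match token")  # InvalidEmailError
    if token_data["code"] != code:
        raise ValueError("wrong verification code")     # EmailCodeError


# Example: a token issued for alice@example.com with code "123456"
verify_email_code_login(
    {"email": "alice@example.com", "code": "123456"},
    "alice@example.com",
    "123456",
)  # passes silently
```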
+ ) + token_pair = AccountService.login(account, ip_address=extract_remote_ip(request)) + AccountService.reset_login_error_rate_limit(args["email"]) + return {"result": "success", "data": token_pair.model_dump()} class RefreshTokenApi(Resource): @@ -120,4 +220,7 @@ class RefreshTokenApi(Resource): api.add_resource(LoginApi, "/login") api.add_resource(LogoutApi, "/logout") +api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login") +api.add_resource(EmailCodeLoginApi, "/email-code-login/validity") +api.add_resource(ResetPasswordSendEmailApi, "/reset-password") api.add_resource(RefreshTokenApi, "/refresh-token") diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index c5909b8c10..282e69448e 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -5,14 +5,20 @@ from typing import Optional import requests from flask import current_app, redirect, request from flask_restful import Resource +from werkzeug.exceptions import Unauthorized from configs import dify_config from constants.languages import languages +from events.tenant_event import tenant_was_created from extensions.ext_database import db from libs.helper import extract_remote_ip from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo -from models.account import Account, AccountStatus +from models import Account +from models.account import AccountStatus from services.account_service import AccountService, RegisterService, TenantService +from services.errors.account import AccountNotFoundError +from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkSpaceNotFoundError +from services.feature_service import FeatureService from .. import api @@ -42,6 +48,7 @@ def get_oauth_providers(): class OAuthLogin(Resource): def get(self, provider: str): + invite_token = request.args.get("invite_token") or None OAUTH_PROVIDERS = get_oauth_providers() with current_app.app_context(): oauth_provider = OAUTH_PROVIDERS.get(provider) @@ -49,7 +56,7 @@ class OAuthLogin(Resource): if not oauth_provider: return {"error": "Invalid provider"}, 400 - auth_url = oauth_provider.get_authorization_url() + auth_url = oauth_provider.get_authorization_url(invite_token=invite_token) return redirect(auth_url) @@ -62,6 +69,11 @@ class OAuthCallback(Resource): return {"error": "Invalid provider"}, 400 code = request.args.get("code") + state = request.args.get("state") + invite_token = None + if state: + invite_token = state + try: token = oauth_provider.get_access_token(code) user_info = oauth_provider.get_user_info(token) @@ -69,7 +81,27 @@ class OAuthCallback(Resource): logging.exception(f"An error occurred during the OAuth process with {provider}: {e.response.text}") return {"error": "OAuth process failed"}, 400 - account = _generate_account(provider, user_info) + if invite_token and RegisterService.is_valid_invite_token(invite_token): + invitation = RegisterService._get_invitation_by_token(token=invite_token) + if invitation: + invitation_email = invitation.get("email", None) + if invitation_email != user_info.email: + return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Invalid invitation token.") + + return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}") + + try: + account = _generate_account(provider, user_info) + except AccountNotFoundError: + return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account not found.") + except WorkSpaceNotFoundError: + return 
redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Workspace not found.") + except WorkSpaceNotAllowedCreateError: + return redirect( + f"{dify_config.CONSOLE_WEB_URL}/signin" + "?message=Workspace not found, please contact system admin to invite you to join in a workspace." + ) + # Check account status if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}: return {"error": "Account is banned or closed."}, 403 @@ -79,7 +111,15 @@ class OAuthCallback(Resource): account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() - TenantService.create_owner_tenant_if_not_exist(account) + try: + TenantService.create_owner_tenant_if_not_exist(account) + except Unauthorized: + return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Workspace not found.") + except WorkSpaceNotAllowedCreateError: + return redirect( + f"{dify_config.CONSOLE_WEB_URL}/signin" + "?message=Workspace not found, please contact system admin to invite you to join in a workspace." + ) token_pair = AccountService.login( account=account, @@ -104,8 +144,20 @@ def _generate_account(provider: str, user_info: OAuthUserInfo): # Get account by openid or email. account = _get_account_by_openid_or_email(provider, user_info) + if account: + tenant = TenantService.get_join_tenants(account) + if not tenant: + if not FeatureService.get_system_features().is_allow_create_workspace: + raise WorkSpaceNotAllowedCreateError() + else: + tenant = TenantService.create_tenant(f"{account.name}'s Workspace") + TenantService.create_tenant_member(tenant, account, role="owner") + account.current_tenant = tenant + tenant_was_created.send(tenant) + if not account: - # Create account + if not FeatureService.get_system_features().is_allow_register: + raise AccountNotFoundError() account_name = user_info.name or "Dify" account = RegisterService.register( email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index 0e1acab946..a2c9760782 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -15,8 +15,7 @@ from core.rag.extractor.notion_extractor import NotionExtractor from extensions.ext_database import db from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields from libs.login import login_required -from models.dataset import Document -from models.source import DataSourceOauthBinding +from models import DataSourceOauthBinding, Document from services.dataset_service import DatasetService, DocumentService from tasks.document_indexing_sync_task import document_indexing_sync_task diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 6583356d23..16a77ed880 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -24,8 +24,8 @@ from fields.app_fields import related_app_list from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields from fields.document_fields import document_status_fields from libs.login import login_required -from models.dataset import Dataset, DatasetPermissionEnum, Document, DocumentSegment -from models.model import ApiToken, UploadFile +from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile +from models.dataset import DatasetPermissionEnum from services.dataset_service import 
DatasetPermissionService, DatasetService, DocumentService diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index ca6c571727..31b4f7b741 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -45,8 +45,7 @@ from fields.document_fields import ( document_with_segments_fields, ) from libs.login import login_required -from models.dataset import Dataset, DatasetProcessRule, Document, DocumentSegment -from models.model import UploadFile +from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile from services.dataset_service import DatasetService, DocumentService from tasks.add_document_to_index_task import add_document_to_index_task from tasks.remove_document_from_index_task import remove_document_from_index_task diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 2405649387..08ea414288 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -24,7 +24,7 @@ from extensions.ext_database import db from extensions.ext_redis import redis_client from fields.segment_fields import segment_fields from libs.login import login_required -from models.dataset import DocumentSegment +from models import DocumentSegment from services.dataset_service import DatasetService, DocumentService, SegmentService from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task from tasks.disable_segment_from_index_task import disable_segment_from_index_task diff --git a/api/controllers/console/datasets/file.py b/api/controllers/console/datasets/file.py index 846aa70e86..51be7e7a7d 100644 --- a/api/controllers/console/datasets/file.py +++ b/api/controllers/console/datasets/file.py @@ -1,9 +1,12 @@ +import urllib.parse + from flask import request from flask_login import current_user -from flask_restful import Resource, marshal_with +from flask_restful import Resource, marshal_with, reqparse import services from configs import dify_config +from constants import DOCUMENT_EXTENSIONS from controllers.console import api from controllers.console.datasets.error import ( FileTooLargeError, @@ -13,9 +16,10 @@ from controllers.console.datasets.error import ( ) from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check -from fields.file_fields import file_fields, upload_config_fields +from core.helper import ssrf_proxy +from fields.file_fields import file_fields, remote_file_info_fields, upload_config_fields from libs.login import login_required -from services.file_service import ALLOWED_EXTENSIONS, UNSTRUCTURED_ALLOWED_EXTENSIONS, FileService +from services.file_service import FileService PREVIEW_WORDS_LIMIT = 3000 @@ -44,6 +48,10 @@ class FileApi(Resource): # get file from request file = request.files["file"] + parser = reqparse.RequestParser() + parser.add_argument("source", type=str, required=False, location="args") + source = parser.parse_args().get("source") + # check file if "file" not in request.files: raise NoFileUploadedError() @@ -51,7 +59,7 @@ class FileApi(Resource): if len(request.files) > 1: raise TooManyFilesError() try: - upload_file = FileService.upload_file(file, current_user) + upload_file = FileService.upload_file(file=file, user=current_user, source=source) except 
services.errors.file.FileTooLargeError as file_too_large_error:
             raise FileTooLargeError(file_too_large_error.description)
         except services.errors.file.UnsupportedFileTypeError:
@@ -75,11 +83,24 @@ class FileSupportTypeApi(Resource):
     @login_required
     @account_initialization_required
     def get(self):
-        etl_type = dify_config.ETL_TYPE
-        allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == "Unstructured" else ALLOWED_EXTENSIONS
-        return {"allowed_extensions": allowed_extensions}
+        return {"allowed_extensions": DOCUMENT_EXTENSIONS}
+
+
+class RemoteFileInfoApi(Resource):
+    @marshal_with(remote_file_info_fields)
+    def get(self, url):
+        decoded_url = urllib.parse.unquote(url)
+        try:
+            response = ssrf_proxy.head(decoded_url)
+            return {
+                "file_type": response.headers.get("Content-Type", "application/octet-stream"),
+                "file_length": int(response.headers.get("Content-Length", 0)),
+            }
+        except Exception as e:
+            return {"error": str(e)}, 400
 
 
 api.add_resource(FileApi, "/files/upload")
 api.add_resource(FilePreviewApi, "/files/<uuid:file_id>/preview")
 api.add_resource(FileSupportTypeApi, "/files/support-type")
+api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")
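`RemoteFileInfoApi` only issues a HEAD request through the SSRF-safe proxy and reports the remote headers; no file body is fetched. A hedged client sketch (assumes a console session is already authenticated and that the target URL is percent-encoded, since it travels inside the path):

```python
# Probe a remote file's type and size via the new endpoint (illustrative).
import urllib.parse
import requests

CONSOLE_API_URL = "http://localhost:5001"  # hypothetical deployment
remote_url = urllib.parse.quote("https://example.com/report.pdf", safe="")

resp = requests.get(f"{CONSOLE_API_URL}/console/api/remote-files/{remote_url}")
info = resp.json()
print(info["file_type"], info["file_length"])  # e.g. "application/pdf" 102400
```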
diff --git a/api/controllers/console/error.py b/api/controllers/console/error.py
index 870e547728..a6d4c8e8ec 100644
--- a/api/controllers/console/error.py
+++ b/api/controllers/console/error.py
@@ -38,3 +38,27 @@ class AlreadyActivateError(BaseHTTPException):
     error_code = "already_activate"
     description = "Auth Token is invalid or account already activated, please check again."
     code = 403
+
+
+class NotAllowedCreateWorkspace(BaseHTTPException):
+    error_code = "unauthorized"
+    description = "Workspace not found, please contact system admin to invite you to join in a workspace."
+    code = 400
+
+
+class AccountBannedError(BaseHTTPException):
+    error_code = "account_banned"
+    description = "Account is banned."
+    code = 400
+
+
+class NotAllowedRegister(BaseHTTPException):
+    error_code = "unauthorized"
+    description = "Account not found."
+    code = 400
+
+
+class EmailSendIpLimitError(BaseHTTPException):
+    error_code = "email_send_ip_limit"
+    description = "Too many emails have been sent from this IP address recently. Please try again later."
+    code = 429
diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py
index 408afc33a0..d72715a38c 100644
--- a/api/controllers/console/explore/installed_app.py
+++ b/api/controllers/console/explore/installed_app.py
@@ -11,7 +11,7 @@ from controllers.console.wraps import account_initialization_required, cloud_edi
 from extensions.ext_database import db
 from fields.installed_app_fields import installed_app_list_fields
 from libs.login import login_required
-from models.model import App, InstalledApp, RecommendedApp
+from models import App, InstalledApp, RecommendedApp
 from services.account_service import TenantService
diff --git a/api/controllers/console/explore/saved_message.py b/api/controllers/console/explore/saved_message.py
index a7ccf737a8..0fc9637479 100644
--- a/api/controllers/console/explore/saved_message.py
+++ b/api/controllers/console/explore/saved_message.py
@@ -18,7 +18,7 @@ message_fields = {
     "inputs": fields.Raw,
     "query": fields.String,
     "answer": fields.String,
-    "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"),
+    "message_files": fields.List(fields.Nested(message_file_fields)),
     "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
     "created_at": TimestampField,
 }
diff --git a/api/controllers/console/explore/wraps.py b/api/controllers/console/explore/wraps.py
index 3c9317847b..49ea81a8a0 100644
--- a/api/controllers/console/explore/wraps.py
+++ b/api/controllers/console/explore/wraps.py
@@ -7,7 +7,7 @@ from werkzeug.exceptions import NotFound
 from controllers.console.wraps import account_initialization_required
 from extensions.ext_database import db
 from libs.login import login_required
-from models.model import InstalledApp
+from models import InstalledApp
 
 
 def installed_app_required(view=None):
diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py
index dec426128f..97f5625726 100644
--- a/api/controllers/console/workspace/account.py
+++ b/api/controllers/console/workspace/account.py
@@ -20,7 +20,7 @@ from extensions.ext_database import db
 from fields.member_fields import account_fields
 from libs.helper import TimestampField, timezone
 from libs.login import login_required
-from models.account import AccountIntegrate, InvitationCode
+from models import AccountIntegrate, InvitationCode
 from services.account_service import AccountService
 from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py
index 14edc9ac13..b799c6380e 100644
--- a/api/controllers/console/workspace/tool_providers.py
+++ b/api/controllers/console/workspace/tool_providers.py
@@ -397,16 +397,15 @@ class ToolWorkflowProviderCreateApi(Resource):
         args = reqparser.parse_args()
 
         return WorkflowToolManageService.create_workflow_tool(
-            user_id,
-            tenant_id,
-            args["workflow_app_id"],
-            args["name"],
-            args["label"],
-            args["icon"],
-            args["description"],
-            args["parameters"],
-            args["privacy_policy"],
-            args.get("labels", []),
+            user_id=user_id,
+            tenant_id=tenant_id,
+            workflow_app_id=args["workflow_app_id"],
+            name=args["name"],
+            label=args["label"],
+            icon=args["icon"],
+            description=args["description"],
+            parameters=args["parameters"],
+            privacy_policy=args["privacy_policy"],
        )
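The `create_workflow_tool` call above moves from ten positional arguments to keyword arguments (note that the trailing `labels` value is no longer forwarded). A toy sketch of why the keyword form is safer; `create_tool` is a hypothetical stand-in, not the real service signature:

```python
# Hypothetical stand-in: keyword-only parameters make long service calls
# order-independent and self-documenting.
def create_tool(*, user_id: str, tenant_id: str, name: str, label: str) -> None:
    print(f"{tenant_id}/{user_id}: {name} ({label})")

# Positional style (old): swapping name and label would fail silently.
# Keyword style (new): arguments may be reordered, and a misspelled
# parameter name raises TypeError instead of corrupting data.
create_tool(user_id="u-1", tenant_id="t-1", label="Lookup", name="lookup")
```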
diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py
index af3ebc099b..96f866fca2 100644
--- a/api/controllers/console/workspace/workspace.py
+++ b/api/controllers/console/workspace/workspace.py
@@ -198,7 +198,7 @@ class WebappLogoWorkspaceApi(Resource):
             raise UnsupportedFileTypeError()
 
         try:
-            upload_file = FileService.upload_file(file, current_user, True)
+            upload_file = FileService.upload_file(file=file, user=current_user)
         except services.errors.file.FileTooLargeError as file_too_large_error:
             raise FileTooLargeError(file_too_large_error.description)
diff --git a/api/controllers/files/image_preview.py b/api/controllers/files/image_preview.py
index a56c1c332d..4b2d61e7c3 100644
--- a/api/controllers/files/image_preview.py
+++ b/api/controllers/files/image_preview.py
@@ -10,6 +10,10 @@ from services.file_service import FileService
 
 class ImagePreviewApi(Resource):
+    """
+    Deprecated
+    """
+
     def get(self, file_id):
         file_id = str(file_id)
 
@@ -21,7 +25,36 @@ class ImagePreviewApi(Resource):
             return {"content": "Invalid request."}, 400
 
         try:
-            generator, mimetype = FileService.get_image_preview(file_id, timestamp, nonce, sign)
+            generator, mimetype = FileService.get_image_preview(
+                file_id=file_id,
+                timestamp=timestamp,
+                nonce=nonce,
+                sign=sign,
+            )
+        except services.errors.file.UnsupportedFileTypeError:
+            raise UnsupportedFileTypeError()
+
+        return Response(generator, mimetype=mimetype)
+
+
+class FilePreviewApi(Resource):
+    def get(self, file_id):
+        file_id = str(file_id)
+
+        timestamp = request.args.get("timestamp")
+        nonce = request.args.get("nonce")
+        sign = request.args.get("sign")
+
+        if not timestamp or not nonce or not sign:
+            return {"content": "Invalid request."}, 400
+
+        try:
+            generator, mimetype = FileService.get_signed_file_preview(
+                file_id=file_id,
+                timestamp=timestamp,
+                nonce=nonce,
+                sign=sign,
+            )
         except services.errors.file.UnsupportedFileTypeError:
             raise UnsupportedFileTypeError()
 
@@ -49,4 +82,5 @@ class WorkspaceWebappLogoApi(Resource):
 
 api.add_resource(ImagePreviewApi, "/files/<uuid:file_id>/image-preview")
+api.add_resource(FilePreviewApi, "/files/<uuid:file_id>/file-preview")
 api.add_resource(WorkspaceWebappLogoApi, "/files/workspaces/<uuid:workspace_id>/webapp-logo")
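`FilePreviewApi` trusts its `timestamp`/`nonce`/`sign` query parameters and delegates verification to `FileService.get_signed_file_preview`. A hedged sketch of the general shape such a check takes; the actual message layout, key management, and expiry live in `FileService`, so everything below is an assumption:

```python
# Illustrative signed-URL verification, NOT the real FileService logic.
import base64
import hashlib
import hmac
import time

SECRET_KEY = b"server-side-secret"  # assumption: some server-held secret
EXPIRY_SECONDS = 300                # assumption: a fixed validity window

def verify(file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
    message = f"file-preview|{file_id}|{timestamp}|{nonce}"  # hypothetical layout
    expected = base64.urlsafe_b64encode(
        hmac.new(SECRET_KEY, message.encode(), hashlib.sha256).digest()
    ).decode()
    # Constant-time compare, then reject stale links.
    if not hmac.compare_digest(expected, sign):
        return False
    return time.time() - int(timestamp) <= EXPIRY_SECONDS
```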
diff --git a/api/controllers/files/tool_files.py b/api/controllers/files/tool_files.py
index 406cd42214..104b7cd9bb 100644
--- a/api/controllers/files/tool_files.py
+++ b/api/controllers/files/tool_files.py
@@ -16,6 +16,7 @@ class ToolFilePreviewApi(Resource):
         parser.add_argument("timestamp", type=str, required=True, location="args")
         parser.add_argument("nonce", type=str, required=True, location="args")
         parser.add_argument("sign", type=str, required=True, location="args")
+        parser.add_argument("as_attachment", type=bool, required=False, default=False, location="args")
 
         args = parser.parse_args()
 
@@ -28,18 +29,27 @@ class ToolFilePreviewApi(Resource):
             raise Forbidden("Invalid request.")
 
         try:
-            result = ToolFileManager.get_file_generator_by_tool_file_id(
+            stream, tool_file = ToolFileManager.get_file_generator_by_tool_file_id(
                 file_id,
             )
 
-            if not result:
+            if not stream or not tool_file:
                 raise NotFound("file is not found")
-
-            generator, mimetype = result
         except Exception:
             raise UnsupportedFileTypeError()
 
-        return Response(generator, mimetype=mimetype)
+        response = Response(
+            stream,
+            mimetype=tool_file.mimetype,
+            direct_passthrough=True,
+            headers={
+                "Content-Length": str(tool_file.size),
+            },
+        )
+        if args["as_attachment"]:
+            response.headers["Content-Disposition"] = f"attachment; filename={tool_file.name}"
+
+        return response
 
 
 api.add_resource(ToolFilePreviewApi, "/files/tools/<uuid:file_id>.<string:extension>")
diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py
index a70ee89b5e..d9a9fad13c 100644
--- a/api/controllers/service_api/app/message.py
+++ b/api/controllers/service_api/app/message.py
@@ -48,7 +48,7 @@ class MessageListApi(Resource):
         "tool_input": fields.String,
         "created_at": TimestampField,
         "observation": fields.String,
-        "message_files": fields.List(fields.String, attribute="files"),
+        "message_files": fields.List(fields.String),
     }
 
     message_fields = {
@@ -58,7 +58,7 @@ class MessageListApi(Resource):
         "inputs": fields.Raw,
         "query": fields.String,
         "answer": fields.String(attribute="re_sign_file_url_answer"),
-        "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"),
+        "message_files": fields.List(fields.Nested(message_file_fields)),
         "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
         "retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
         "created_at": TimestampField,
diff --git a/api/controllers/web/file.py b/api/controllers/web/file.py
index 253b1d511c..c029a07707 100644
--- a/api/controllers/web/file.py
+++ b/api/controllers/web/file.py
@@ -1,11 +1,14 @@
+import urllib.parse
+
 from flask import request
-from flask_restful import marshal_with
+from flask_restful import marshal_with, reqparse
 
 import services
 from controllers.web import api
 from controllers.web.error import FileTooLargeError, NoFileUploadedError, TooManyFilesError, UnsupportedFileTypeError
 from controllers.web.wraps import WebApiResource
-from fields.file_fields import file_fields
+from core.helper import ssrf_proxy
+from fields.file_fields import file_fields, remote_file_info_fields
 from services.file_service import FileService
 
@@ -15,6 +18,10 @@ class FileApi(WebApiResource):
         # get file from request
         file = request.files["file"]
 
+        parser = reqparse.RequestParser()
+        parser.add_argument("source", type=str, required=False, location="args")
+        source = parser.parse_args().get("source")
+
         # check file
         if "file" not in request.files:
             raise NoFileUploadedError()
@@ -22,7 +29,7 @@
         if len(request.files) > 1:
             raise TooManyFilesError()
         try:
-            upload_file = FileService.upload_file(file, end_user)
+            upload_file = FileService.upload_file(file=file, user=end_user, source=source)
         except services.errors.file.FileTooLargeError as file_too_large_error:
             raise FileTooLargeError(file_too_large_error.description)
         except services.errors.file.UnsupportedFileTypeError:
@@ -31,4 +38,19 @@
         return upload_file, 201
 
 
+class RemoteFileInfoApi(WebApiResource):
+    @marshal_with(remote_file_info_fields)
+    def get(self, url):
+        decoded_url = urllib.parse.unquote(url)
+        try:
+            response = ssrf_proxy.head(decoded_url)
+            return {
+                "file_type": response.headers.get("Content-Type", "application/octet-stream"),
+                "file_length": int(response.headers.get("Content-Length", 0)),
+            }
+        except Exception as e:
+            return {"error": str(e)}, 400
+
+
 api.add_resource(FileApi, "/files/upload")
+api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")
diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py
index 2d2a5866c8..98891f5d00 100644
--- a/api/controllers/web/message.py
+++ b/api/controllers/web/message.py
@@ -22,6 +22,7 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
 from core.model_runtime.errors.invoke import InvokeError
 from fields.conversation_fields import
message_file_fields from fields.message_fields import agent_thought_fields +from fields.raws import FilesContainedField from libs import helper from libs.helper import TimestampField, uuid_value from models.model import AppMode @@ -58,10 +59,10 @@ class MessageListApi(WebApiResource): "id": fields.String, "conversation_id": fields.String, "parent_message_id": fields.String, - "inputs": fields.Raw, + "inputs": FilesContainedField, "query": fields.String, "answer": fields.String(attribute="re_sign_file_url_answer"), - "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"), + "message_files": fields.List(fields.Nested(message_file_fields)), "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), "retriever_resources": fields.List(fields.Nested(retriever_resource_fields)), "created_at": TimestampField, diff --git a/api/controllers/web/saved_message.py b/api/controllers/web/saved_message.py index 8253f5fc57..b0492e6b6f 100644 --- a/api/controllers/web/saved_message.py +++ b/api/controllers/web/saved_message.py @@ -17,7 +17,7 @@ message_fields = { "inputs": fields.Raw, "query": fields.String, "answer": fields.String, - "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"), + "message_files": fields.List(fields.Nested(message_file_fields)), "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), "created_at": TimestampField, } diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index eb33c62bce..b271048839 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -15,12 +15,12 @@ from core.app.entities.app_invoke_entities import ( ) from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler -from core.file.message_file_parser import MessageFileParser +from core.file import file_manager from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance -from core.model_runtime.entities.llm_entities import LLMUsage -from core.model_runtime.entities.message_entities import ( +from core.model_runtime.entities import ( AssistantPromptMessage, + LLMUsage, PromptMessage, PromptMessageContent, PromptMessageTool, @@ -38,9 +38,9 @@ from core.tools.entities.tool_entities import ( ) from core.tools.tool_manager import ToolManager from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool -from core.tools.utils.tool_parameter_converter import ToolParameterConverter from extensions.ext_database import db -from models.model import Conversation, Message, MessageAgentThought +from factories import file_factory +from models.model import Conversation, Message, MessageAgentThought, MessageFile logger = logging.getLogger(__name__) @@ -61,23 +61,6 @@ class BaseAgentRunner(AppRunner): memory: Optional[TokenBufferMemory] = None, prompt_messages: Optional[list[PromptMessage]] = None, ) -> None: - """ - Agent runner - :param tenant_id: tenant id - :param application_generate_entity: application generate entity - :param conversation: conversation - :param app_config: app generate entity - :param model_config: model config - :param config: dataset config - :param queue_manager: queue manager - :param message: message - :param user_id: user id - :param memory: memory - :param prompt_messages: prompt messages - :param variables_pool: variables pool - :param 
db_variables: db variables - :param model_instance: model instance - """ self.tenant_id = tenant_id self.application_generate_entity = application_generate_entity self.conversation = conversation @@ -172,7 +155,7 @@ class BaseAgentRunner(AppRunner): if parameter.form != ToolParameter.ToolParameterForm.LLM: continue - parameter_type = ToolParameterConverter.get_parameter_type(parameter.type) + parameter_type = parameter.type.as_normal_type() enum = [] if parameter.type == ToolParameter.ToolParameterType.SELECT: enum = [option.value for option in parameter.options] @@ -259,7 +242,7 @@ class BaseAgentRunner(AppRunner): if parameter.form != ToolParameter.ToolParameterForm.LLM: continue - parameter_type = ToolParameterConverter.get_parameter_type(parameter.type) + parameter_type = parameter.type.as_normal_type() enum = [] if parameter.type == ToolParameter.ToolParameterType.SELECT: enum = [option.value for option in parameter.options] @@ -492,18 +475,15 @@ class BaseAgentRunner(AppRunner): return result def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage: - message_file_parser = MessageFileParser( - tenant_id=self.tenant_id, - app_id=self.app_config.app_id, - ) - - files = message.message_files + files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all() if files: assert message.app_model_config file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict()) if file_extra_config: - file_objs = message_file_parser.transform_message_files(files, file_extra_config) + file_objs = file_factory.build_from_message_files( + message_files=files, tenant_id=self.tenant_id, config=file_extra_config + ) else: file_objs = [] @@ -512,7 +492,7 @@ class BaseAgentRunner(AppRunner): else: prompt_message_contents: list[PromptMessageContent] = [TextPromptMessageContent(data=message.query)] for file_obj in file_objs: - prompt_message_contents.append(file_obj.prompt_message_content) + prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj)) return UserPromptMessage(content=prompt_message_contents) else: diff --git a/api/core/agent/cot_chat_agent_runner.py b/api/core/agent/cot_chat_agent_runner.py index 095f8775ae..e5ac50d76b 100644 --- a/api/core/agent/cot_chat_agent_runner.py +++ b/api/core/agent/cot_chat_agent_runner.py @@ -1,7 +1,8 @@ import json from core.agent.cot_agent_runner import CotAgentRunner -from core.model_runtime.entities.message_entities import ( +from core.file import file_manager +from core.model_runtime.entities import ( AssistantPromptMessage, PromptMessage, PromptMessageContent, @@ -38,7 +39,7 @@ class CotChatAgentRunner(CotAgentRunner): if self.files: prompt_message_contents: list[PromptMessageContent] = [TextPromptMessageContent(data=query)] for file_obj in self.files: - prompt_message_contents.append(file_obj.prompt_message_content) + prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj)) prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) else: diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index cc4b1961ad..368424f593 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -7,9 +7,13 @@ from typing import Any, Optional, Union from core.agent.base_agent_runner import BaseAgentRunner from core.app.apps.base_app_queue_manager import PublishFrom from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent -from 
core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from core.model_runtime.entities.message_entities import ( +from core.file import file_manager +from core.model_runtime.entities import ( AssistantPromptMessage, + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, + LLMUsage, PromptMessage, PromptMessageContent, PromptMessageContentType, @@ -390,7 +394,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): if self.files: prompt_message_contents: list[PromptMessageContent] = [TextPromptMessageContent(data=query)] for file_obj in self.files: - prompt_message_contents.append(file_obj.prompt_message_content) + prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj)) prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) else: diff --git a/api/core/app/app_config/easy_ui_based_app/variables/manager.py b/api/core/app/app_config/easy_ui_based_app/variables/manager.py index a1bfde3208..126eb0b41e 100644 --- a/api/core/app/app_config/easy_ui_based_app/variables/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/variables/manager.py @@ -53,12 +53,11 @@ class BasicVariablesConfigManager: VariableEntity( type=variable_type, variable=variable.get("variable"), - description=variable.get("description"), + description=variable.get("description", ""), label=variable.get("label"), required=variable.get("required", False), max_length=variable.get("max_length"), - options=variable.get("options"), - default=variable.get("default"), + options=variable.get("options", []), ) ) diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index 7e5899bafa..d8fa08c0a3 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -1,11 +1,12 @@ +from collections.abc import Sequence from enum import Enum from typing import Any, Optional -from pydantic import BaseModel +from pydantic import BaseModel, Field -from core.file.file_obj import FileExtraConfig +from core.file import FileExtraConfig, FileTransferMethod, FileType from core.model_runtime.entities.message_entities import PromptMessageRole -from models import AppMode +from models.model import AppMode class ModelConfigEntity(BaseModel): @@ -69,7 +70,7 @@ class PromptTemplateEntity(BaseModel): ADVANCED = "advanced" @classmethod - def value_of(cls, value: str) -> "PromptType": + def value_of(cls, value: str): """ Get value of given mode. @@ -93,6 +94,8 @@ class VariableEntityType(str, Enum): PARAGRAPH = "paragraph" NUMBER = "number" EXTERNAL_DATA_TOOL = "external_data_tool" + FILE = "file" + FILE_LIST = "file-list" class VariableEntity(BaseModel): @@ -102,13 +105,14 @@ class VariableEntity(BaseModel): variable: str label: str - description: Optional[str] = None + description: str = "" type: VariableEntityType required: bool = False max_length: Optional[int] = None - options: Optional[list[str]] = None - default: Optional[str] = None - hint: Optional[str] = None + options: Sequence[str] = Field(default_factory=list) + allowed_file_types: Sequence[FileType] = Field(default_factory=list) + allowed_file_extensions: Sequence[str] = Field(default_factory=list) + allowed_file_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list) class ExternalDataVariableEntity(BaseModel): @@ -136,7 +140,7 @@ class DatasetRetrieveConfigEntity(BaseModel): MULTIPLE = "multiple" @classmethod - def value_of(cls, value: str) -> "RetrieveStrategy": + def value_of(cls, value: str): """ Get value of given mode. 
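The new `FILE` and `FILE_LIST` variable types let a form field carry files with per-field constraints. A hedged sketch of a form entry that would validate against the updated `VariableEntity` (field names follow the class above; the enum string values are assumptions based on the `FileType`/`FileTransferMethod` imports):

```python
# Illustrative only: the concrete values below are assumptions.
from core.app.app_config.entities import VariableEntity, VariableEntityType

resume_var = VariableEntity.model_validate(
    {
        "variable": "resume",
        "label": "Resume",
        "type": VariableEntityType.FILE,
        "required": True,
        "allowed_file_types": ["document"],          # assumed FileType value
        "allowed_file_extensions": [".pdf", ".docx"],
        "allowed_file_upload_methods": ["local_file", "remote_url"],
    }
)
assert resume_var.type == VariableEntityType.FILE
```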
diff --git a/api/core/app/app_config/features/file_upload/manager.py b/api/core/app/app_config/features/file_upload/manager.py index 7a275cb532..6d301f6ea7 100644 --- a/api/core/app/app_config/features/file_upload/manager.py +++ b/api/core/app/app_config/features/file_upload/manager.py @@ -1,12 +1,13 @@ from collections.abc import Mapping -from typing import Any, Optional +from typing import Any -from core.file.file_obj import FileExtraConfig +from core.file.models import FileExtraConfig +from models import FileUploadConfig class FileUploadConfigManager: @classmethod - def convert(cls, config: Mapping[str, Any], is_vision: bool = True) -> Optional[FileExtraConfig]: + def convert(cls, config: Mapping[str, Any], is_vision: bool = True): """ Convert model config to model config @@ -15,19 +16,18 @@ class FileUploadConfigManager: """ file_upload_dict = config.get("file_upload") if file_upload_dict: - if file_upload_dict.get("image"): - if "enabled" in file_upload_dict["image"] and file_upload_dict["image"]["enabled"]: - image_config = { - "number_limits": file_upload_dict["image"]["number_limits"], - "transfer_methods": file_upload_dict["image"]["transfer_methods"], + if file_upload_dict.get("enabled"): + data = { + "image_config": { + "number_limits": file_upload_dict["number_limits"], + "transfer_methods": file_upload_dict["allowed_file_upload_methods"], } + } - if is_vision: - image_config["detail"] = file_upload_dict["image"]["detail"] + if is_vision: + data["image_config"]["detail"] = file_upload_dict.get("image", {}).get("detail", "low") - return FileExtraConfig(image_config=image_config) - - return None + return FileExtraConfig.model_validate(data) @classmethod def validate_and_set_defaults(cls, config: dict, is_vision: bool = True) -> tuple[dict, list[str]]: @@ -39,29 +39,7 @@ class FileUploadConfigManager: """ if not config.get("file_upload"): config["file_upload"] = {} - - if not isinstance(config["file_upload"], dict): - raise ValueError("file_upload must be of dict type") - - # check image config - if not config["file_upload"].get("image"): - config["file_upload"]["image"] = {"enabled": False} - - if config["file_upload"]["image"]["enabled"]: - number_limits = config["file_upload"]["image"]["number_limits"] - if number_limits < 1 or number_limits > 6: - raise ValueError("number_limits must be in [1, 6]") - - if is_vision: - detail = config["file_upload"]["image"]["detail"] - if detail not in {"high", "low"}: - raise ValueError("detail must be in ['high', 'low']") - - transfer_methods = config["file_upload"]["image"]["transfer_methods"] - if not isinstance(transfer_methods, list): - raise ValueError("transfer_methods must be of list type") - for method in transfer_methods: - if method not in {"remote_url", "local_file"}: - raise ValueError("transfer_methods must be in ['remote_url', 'local_file']") + else: + FileUploadConfig.model_validate(config["file_upload"]) return config, ["file_upload"] diff --git a/api/core/app/app_config/workflow_ui_based_app/variables/manager.py b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py index 4b117d87f8..2f1da38082 100644 --- a/api/core/app/app_config/workflow_ui_based_app/variables/manager.py +++ b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py @@ -17,6 +17,6 @@ class WorkflowVariablesConfigManager: # variables for variable in user_input_form: - variables.append(VariableEntity(**variable)) + variables.append(VariableEntity.model_validate(variable)) return variables diff --git 
a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py
index dc1aeeecff..18c5526c53 100644
--- a/api/core/app/apps/advanced_chat/app_generator.py
+++ b/api/core/app/apps/advanced_chat/app_generator.py
@@ -21,11 +21,12 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
 from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
 from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
 from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
-from core.file.message_file_parser import MessageFileParser
 from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
 from core.ops.ops_trace_manager import TraceQueueManager
 from extensions.ext_database import db
+from factories import file_factory
 from models.account import Account
+from models.enums import CreatedByRole
 from models.model import App, Conversation, EndUser, Message
 from models.workflow import Workflow
 
@@ -107,10 +108,16 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
 
         # parse files
         files = args["files"] if args.get("files") else []
-        message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
         file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
+        role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
         if file_extra_config:
-            file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
+            file_objs = file_factory.build_from_mappings(
+                mappings=files,
+                tenant_id=app_model.tenant_id,
+                user_id=user.id,
+                role=role,
+                config=file_extra_config,
+            )
         else:
             file_objs = []
 
@@ -118,8 +125,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
 
         # get tracing instance
-        user_id = user.id if isinstance(user, Account) else user.session_id
-        trace_manager = TraceQueueManager(app_model.id, user_id)
+        trace_manager = TraceQueueManager(
+            app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id
+        )
 
         if invoke_from == InvokeFrom.DEBUGGER:
             # always enable retriever resource in debugger mode
@@ -131,7 +139,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             task_id=str(uuid.uuid4()),
             app_config=app_config,
             conversation_id=conversation.id if conversation else None,
-            inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
+            inputs=conversation.inputs
+            if conversation
+            else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
             query=query,
             files=file_objs,
             parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
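`file_factory.build_from_mappings` replaces `MessageFileParser.validate_and_transform_files_arg` as the path from raw request payloads to `File` objects. A hedged sketch of the `files` mappings it consumes; the key names are inferred from fields used elsewhere in this diff, and the exact accepted shape lives in `factories/file_factory.py`:

```python
# Hypothetical payload: one previously uploaded file and one remote URL.
files = [
    {
        "type": "image",
        "transfer_method": "local_file",
        "upload_file_id": "6c90cdb7-...",  # id returned by POST /files/upload
    },
    {
        "type": "image",
        "transfer_method": "remote_url",
        "url": "https://example.com/photo.png",
    },
]

# Converted once per request, mirroring the call added above:
# file_objs = file_factory.build_from_mappings(
#     mappings=files,
#     tenant_id=app_model.tenant_id,
#     user_id=user.id,
#     role=role,
#     config=file_extra_config,
# )
```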
diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py
index 1dcd051d15..65d744eddf 100644
--- a/api/core/app/apps/advanced_chat/app_runner.py
+++ b/api/core/app/apps/advanced_chat/app_runner.py
@@ -1,31 +1,27 @@
 import logging
-import os
 from collections.abc import Mapping
 from typing import Any, cast
 
 from sqlalchemy import select
 from sqlalchemy.orm import Session
 
+from configs import dify_config
 from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
-from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback
-from core.app.entities.app_invoke_entities import (
-    AdvancedChatAppGenerateEntity,
-    InvokeFrom,
-)
+from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
 from core.app.entities.queue_entities import (
     QueueAnnotationReplyEvent,
     QueueStopEvent,
     QueueTextChunkEvent,
 )
 from core.moderation.base import ModerationError
-from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
-from core.workflow.entities.node_entities import UserFrom
+from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.enums import SystemVariableKey
 from core.workflow.workflow_entry import WorkflowEntry
 from extensions.ext_database import db
+from models.enums import UserFrom
 from models.model import App, Conversation, EndUser, Message
 from models.workflow import ConversationVariable, WorkflowType
 
@@ -44,12 +40,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
         conversation: Conversation,
         message: Message,
     ) -> None:
-        """
-        :param application_generate_entity: application generate entity
-        :param queue_manager: application queue manager
-        :param conversation: conversation
-        :param message: message
-        """
         super().__init__(queue_manager)
 
         self.application_generate_entity = application_generate_entity
@@ -57,10 +47,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
         self.message = message
 
     def run(self) -> None:
-        """
-        Run application
-        :return:
-        """
         app_config = self.application_generate_entity.app_config
         app_config = cast(AdvancedChatAppConfig, app_config)
 
@@ -81,7 +67,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
         user_id = self.application_generate_entity.user_id
 
         workflow_callbacks: list[WorkflowCallback] = []
-        if bool(os.environ.get("DEBUG", "False").lower() == "true"):
+        if dify_config.DEBUG:
             workflow_callbacks.append(WorkflowLoggingCallback())
 
         if self.application_generate_entity.single_iteration_run:
@@ -201,15 +187,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
         query: str,
         message_id: str,
     ) -> bool:
-        """
-        Handle input moderation
-        :param app_record: app record
-        :param app_generate_entity: application generate entity
-        :param inputs: inputs
-        :param query: query
-        :param message_id: message id
-        :return:
-        """
         try:
             # process sensitive_word_avoidance
             _, inputs, query = self.moderation_for_inputs(
@@ -229,14 +206,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
     def handle_annotation_reply(
         self, app_record: App, message: Message, query: str, app_generate_entity: AdvancedChatAppGenerateEntity
     ) -> bool:
-        """
-        Handle annotation reply
-        :param app_record: app record
-        :param message: message
-        :param query: query
-        :param app_generate_entity: application generate entity
-        """
-        # annotation reply
         annotation_reply = self.query_app_annotations_to_reply(
             app_record=app_record,
             message=message,
@@ -258,8 +227,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
     def _complete_with_stream_output(self, text: str, stopped_by: QueueStopEvent.StopBy) -> None:
         """
         Direct output
-        :param text: text
-        :return:
         """
         self._publish_event(QueueTextChunkEvent(text=text))
diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py
index fd63c7787f..e4cb3f8527 100644
--- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py
+++
b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -1,7 +1,7 @@ import json import logging import time -from collections.abc import Generator +from collections.abc import Generator, Mapping from typing import Any, Optional, Union from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME @@ -9,6 +9,7 @@ from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGenerator from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.app_invoke_entities import ( AdvancedChatAppGenerateEntity, + InvokeFrom, ) from core.app.entities.queue_entities import ( QueueAdvancedChatMessageEndEvent, @@ -50,10 +51,12 @@ from core.model_runtime.utils.encoders import jsonable_encoder from core.ops.ops_trace_manager import TraceQueueManager from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.nodes import NodeType from events.message_event import message_was_created from extensions.ext_database import db +from models import Conversation, EndUser, Message, MessageFile from models.account import Account -from models.model import Conversation, EndUser, Message +from models.enums import CreatedByRole from models.workflow import ( Workflow, WorkflowNodeExecution, @@ -120,6 +123,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc self._wip_workflow_node_executions = {} self._conversation_name_generate_thread = None + self._recorded_files: list[Mapping[str, Any]] = [] def process(self): """ @@ -298,6 +302,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc elif isinstance(event, QueueNodeSucceededEvent): workflow_node_execution = self._handle_workflow_node_execution_success(event) + # Record files if it's an answer node or end node + if event.node_type in [NodeType.ANSWER, NodeType.END]: + self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {})) + response = self._workflow_node_finish_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, @@ -364,7 +372,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc start_at=graph_runtime_state.start_at, total_tokens=graph_runtime_state.total_tokens, total_steps=graph_runtime_state.node_run_steps, - outputs=json.dumps(event.outputs) if event.outputs else None, + outputs=event.outputs, conversation_id=self._conversation.id, trace_manager=trace_manager, ) @@ -490,10 +498,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc self._conversation_name_generate_thread.join() def _save_message(self, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None: - """ - Save message. 
- :return: - """ self._refetch_message() self._message.answer = self._task_state.answer @@ -501,6 +505,22 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc self._message.message_metadata = ( json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None ) + message_files = [ + MessageFile( + message_id=self._message.id, + type=file["type"], + transfer_method=file["transfer_method"], + url=file["remote_url"], + belongs_to="assistant", + upload_file_id=file["related_id"], + created_by_role=CreatedByRole.ACCOUNT + if self._message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} + else CreatedByRole.END_USER, + created_by=self._message.from_account_id or self._message.from_end_user_id or "", + ) + for file in self._recorded_files + ] + db.session.add_all(message_files) if graph_runtime_state and graph_runtime_state.llm_usage: usage = graph_runtime_state.llm_usage @@ -540,7 +560,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc del extras["metadata"]["annotation_reply"] return MessageEndStreamResponse( - task_id=self._application_generate_entity.task_id, id=self._message.id, **extras + task_id=self._application_generate_entity.task_id, id=self._message.id, files=self._recorded_files, **extras ) def _handle_output_moderation_chunk(self, text: str) -> bool: diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index 5379f14e73..b2b161cdca 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -18,12 +18,12 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskSt from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom -from core.file.message_file_parser import MessageFileParser from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.ops.ops_trace_manager import TraceQueueManager from extensions.ext_database import db -from models.account import Account -from models.model import App, EndUser +from factories import file_factory +from models import Account, App, EndUser +from models.enums import CreatedByRole logger = logging.getLogger(__name__) @@ -108,12 +108,19 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): # always enable retriever resource in debugger mode override_model_config_dict["retriever_resource"] = {"enabled": True} + role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER + # parse files - files = args["files"] if args.get("files") else [] - message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) + files = args.get("files") or [] file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) if file_extra_config: - file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user) + file_objs = file_factory.build_from_mappings( + mappings=files, + tenant_id=app_model.tenant_id, + user_id=user.id, + role=role, + config=file_extra_config, + ) else: file_objs = [] @@ -126,8 +133,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): ) # get tracing instance - user_id = user.id if isinstance(user, Account) else user.session_id - trace_manager = 
TraceQueueManager(app_model.id, user_id) + trace_manager = TraceQueueManager(app_model.id, user.id if isinstance(user, Account) else user.session_id) # init application generate entity application_generate_entity = AgentChatAppGenerateEntity( @@ -135,7 +141,9 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): app_config=app_config, model_conf=ModelConfigConverter.convert(app_config), conversation_id=conversation.id if conversation else None, - inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config), + inputs=conversation.inputs + if conversation + else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role), query=query, files=file_objs, parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py index 06a18864b7..993f8e904d 100644 --- a/api/core/app/apps/base_app_generator.py +++ b/api/core/app/apps/base_app_generator.py @@ -1,36 +1,93 @@ import json from collections.abc import Generator, Mapping -from typing import Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union -from core.app.app_config.entities import AppConfig, VariableEntity, VariableEntityType +from core.app.app_config.entities import VariableEntityType +from core.file import File, FileExtraConfig +from factories import file_factory + +if TYPE_CHECKING: + from core.app.app_config.entities import AppConfig, VariableEntity + from models.enums import CreatedByRole class BaseAppGenerator: - def _get_cleaned_inputs(self, user_inputs: Optional[Mapping[str, Any]], app_config: AppConfig) -> Mapping[str, Any]: + def _prepare_user_inputs( + self, + *, + user_inputs: Optional[Mapping[str, Any]], + app_config: "AppConfig", + user_id: str, + role: "CreatedByRole", + ) -> Mapping[str, Any]: user_inputs = user_inputs or {} # Filter input variables from form configuration, handle required fields, default values, and option values variables = app_config.variables - filtered_inputs = {var.variable: self._validate_input(inputs=user_inputs, var=var) for var in variables} - filtered_inputs = {k: self._sanitize_value(v) for k, v in filtered_inputs.items()} - return filtered_inputs + user_inputs = {var.variable: self._validate_input(inputs=user_inputs, var=var) for var in variables} + user_inputs = {k: self._sanitize_value(v) for k, v in user_inputs.items()} + # Convert files in inputs to File + entity_dictionary = {item.variable: item for item in app_config.variables} + # Convert single file to File + files_inputs = { + k: file_factory.build_from_mapping( + mapping=v, + tenant_id=app_config.tenant_id, + user_id=user_id, + role=role, + config=FileExtraConfig( + allowed_file_types=entity_dictionary[k].allowed_file_types, + allowed_extensions=entity_dictionary[k].allowed_file_extensions, + allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods, + ), + ) + for k, v in user_inputs.items() + if isinstance(v, dict) and entity_dictionary[k].type == VariableEntityType.FILE + } + # Convert list of files to File + file_list_inputs = { + k: file_factory.build_from_mappings( + mappings=v, + tenant_id=app_config.tenant_id, + user_id=user_id, + role=role, + config=FileExtraConfig( + allowed_file_types=entity_dictionary[k].allowed_file_types, + allowed_extensions=entity_dictionary[k].allowed_file_extensions, + allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods, + ), + ) + for k, v in 
user_inputs.items() + if isinstance(v, list) + # Ensure skip List + and all(isinstance(item, dict) for item in v) + and entity_dictionary[k].type == VariableEntityType.FILE_LIST + } + # Merge all inputs + user_inputs = {**user_inputs, **files_inputs, **file_list_inputs} - def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity): - user_input_value = inputs.get(var.variable) - if var.required and not user_input_value: - raise ValueError(f"{var.variable} is required in input form") - if not var.required and not user_input_value: - # TODO: should we return None here if the default value is None? - return var.default or "" - if ( - var.type - in { - VariableEntityType.TEXT_INPUT, - VariableEntityType.SELECT, - VariableEntityType.PARAGRAPH, - } - and user_input_value - and not isinstance(user_input_value, str) + # Check if all files are converted to File + if any(filter(lambda v: isinstance(v, dict), user_inputs.values())): + raise ValueError("Invalid input type") + if any( + filter(lambda v: isinstance(v, dict), filter(lambda item: isinstance(item, list), user_inputs.values())) ): + raise ValueError("Invalid input type") + + return user_inputs + + def _validate_input(self, *, inputs: Mapping[str, Any], var: "VariableEntity"): + user_input_value = inputs.get(var.variable) + if not user_input_value: + if var.required: + raise ValueError(f"{var.variable} is required in input form") + else: + return None + + if var.type in { + VariableEntityType.TEXT_INPUT, + VariableEntityType.SELECT, + VariableEntityType.PARAGRAPH, + } and not isinstance(user_input_value, str): raise ValueError(f"(type '{var.type}') {var.variable} in input form must be a string") if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str): # may raise ValueError if user_input_value is not a valid number @@ -42,12 +99,24 @@ class BaseAppGenerator: except ValueError: raise ValueError(f"{var.variable} in input form must be a valid number") if var.type == VariableEntityType.SELECT: - options = var.options or [] + options = var.options if user_input_value not in options: raise ValueError(f"{var.variable} in input form must be one of the following: {options}") elif var.type in {VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH}: - if var.max_length and user_input_value and len(user_input_value) > var.max_length: + if var.max_length and len(user_input_value) > var.max_length: raise ValueError(f"{var.variable} in input form must be less than {var.max_length} characters") + elif var.type == VariableEntityType.FILE: + if not isinstance(user_input_value, dict) and not isinstance(user_input_value, File): + raise ValueError(f"{var.variable} in input form must be a file") + elif var.type == VariableEntityType.FILE_LIST: + if not ( + isinstance(user_input_value, list) + and ( + all(isinstance(item, dict) for item in user_input_value) + or all(isinstance(item, File) for item in user_input_value) + ) + ): + raise ValueError(f"{var.variable} in input form must be a list of files") return user_input_value diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index 203aca3384..609fd03f22 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -27,7 +27,7 @@ from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform from models.model import App, AppMode, Message, MessageAnnotation if TYPE_CHECKING: - from core.file.file_obj import FileVar + from core.file.models import File class AppRunner: @@ -37,7 +37,7 
@@ class AppRunner: model_config: ModelConfigWithCredentialsEntity, prompt_template_entity: PromptTemplateEntity, inputs: dict[str, str], - files: list["FileVar"], + files: list["File"], query: Optional[str] = None, ) -> int: """ @@ -137,7 +137,7 @@ class AppRunner: model_config: ModelConfigWithCredentialsEntity, prompt_template_entity: PromptTemplateEntity, inputs: dict[str, str], - files: list["FileVar"], + files: list["File"], query: Optional[str] = None, context: Optional[str] = None, memory: Optional[TokenBufferMemory] = None, diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index 46a4855508..12bcb5a777 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -18,11 +18,12 @@ from core.app.apps.chat.generate_response_converter import ChatAppGenerateRespon from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom -from core.file.message_file_parser import MessageFileParser from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.ops.ops_trace_manager import TraceQueueManager from extensions.ext_database import db +from factories import file_factory from models.account import Account +from models.enums import CreatedByRole from models.model import App, EndUser logger = logging.getLogger(__name__) @@ -110,12 +111,19 @@ class ChatAppGenerator(MessageBasedAppGenerator): # always enable retriever resource in debugger mode override_model_config_dict["retriever_resource"] = {"enabled": True} + role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER + # parse files files = args["files"] if args.get("files") else [] - message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) if file_extra_config: - file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user) + file_objs = file_factory.build_from_mappings( + mappings=files, + tenant_id=app_model.tenant_id, + user_id=user.id, + role=role, + config=file_extra_config, + ) else: file_objs = [] @@ -128,7 +136,7 @@ class ChatAppGenerator(MessageBasedAppGenerator): ) # get tracing instance - trace_manager = TraceQueueManager(app_model.id) + trace_manager = TraceQueueManager(app_id=app_model.id) # init application generate entity application_generate_entity = ChatAppGenerateEntity( @@ -136,15 +144,17 @@ class ChatAppGenerator(MessageBasedAppGenerator): app_config=app_config, model_conf=ModelConfigConverter.convert(app_config), conversation_id=conversation.id if conversation else None, - inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config), + inputs=conversation.inputs + if conversation + else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role), query=query, files=file_objs, parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, user_id=user.id, - stream=stream, invoke_from=invoke_from, extras=extras, trace_manager=trace_manager, + stream=stream, ) # init generate records diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index 729ff1e1e0..7fb05192c7 
100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -17,12 +17,12 @@ from core.app.apps.completion.generate_response_converter import CompletionAppGe from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom -from core.file.message_file_parser import MessageFileParser from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.ops.ops_trace_manager import TraceQueueManager from extensions.ext_database import db -from models.account import Account -from models.model import App, EndUser, Message +from factories import file_factory +from models import Account, App, EndUser, Message +from models.enums import CreatedByRole from services.errors.app import MoreLikeThisDisabledError from services.errors.message import MessageNotExistsError @@ -98,12 +98,19 @@ class CompletionAppGenerator(MessageBasedAppGenerator): tenant_id=app_model.tenant_id, config=args.get("model_config") ) + role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER + # parse files files = args["files"] if args.get("files") else [] - message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) if file_extra_config: - file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user) + file_objs = file_factory.build_from_mappings( + mappings=files, + tenant_id=app_model.tenant_id, + user_id=user.id, + role=role, + config=file_extra_config, + ) else: file_objs = [] @@ -113,6 +120,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator): ) # get tracing instance + user_id = user.id if isinstance(user, Account) else user.session_id trace_manager = TraceQueueManager(app_model.id) # init application generate entity @@ -120,7 +128,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator): task_id=str(uuid.uuid4()), app_config=app_config, model_conf=ModelConfigConverter.convert(app_config), - inputs=self._get_cleaned_inputs(inputs, app_config), + inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role), query=query, files=file_objs, user_id=user.id, @@ -261,10 +269,16 @@ class CompletionAppGenerator(MessageBasedAppGenerator): override_model_config_dict["model"] = model_dict # parse files - message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) - file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) + role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER + file_extra_config = FileUploadConfigManager.convert(override_model_config_dict) if file_extra_config: - file_objs = message_file_parser.validate_and_transform_files_arg(message.files, file_extra_config, user) + file_objs = file_factory.build_from_mappings( + mappings=message.message_files, + tenant_id=app_model.tenant_id, + user_id=user.id, + role=role, + config=file_extra_config, + ) else: file_objs = [] diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index 65b759acf5..2b5597e055 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ 
b/api/core/app/apps/message_based_app_generator.py @@ -26,7 +26,7 @@ from core.app.entities.task_entities import ( from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline from core.prompt.utils.prompt_template_parser import PromptTemplateParser from extensions.ext_database import db -from models.account import Account +from models import Account from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile from services.errors.app_model_config import AppModelConfigBrokenError from services.errors.conversation import ConversationCompletedError, ConversationNotExistsError @@ -235,13 +235,13 @@ class MessageBasedAppGenerator(BaseAppGenerator): for file in application_generate_entity.files: message_file = MessageFile( message_id=message.id, - type=file.type.value, - transfer_method=file.transfer_method.value, + type=file.type, + transfer_method=file.transfer_method, belongs_to="user", - url=file.url, + url=file.remote_url, upload_file_id=file.related_id, created_by_role=("account" if account_id else "end_user"), - created_by=account_id or end_user_id, + created_by=account_id or end_user_id or "", ) db.session.add(message_file) db.session.commit() diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index dd5f821869..9e7591545d 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -3,7 +3,7 @@ import logging import os import threading import uuid -from collections.abc import Generator +from collections.abc import Generator, Mapping, Sequence from typing import Any, Literal, Optional, Union, overload from flask import Flask, current_app @@ -20,13 +20,12 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse -from core.file.message_file_parser import MessageFileParser from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.ops.ops_trace_manager import TraceQueueManager from extensions.ext_database import db -from models.account import Account -from models.model import App, EndUser -from models.workflow import Workflow +from factories import file_factory +from models import Account, App, EndUser, Workflow +from models.enums import CreatedByRole logger = logging.getLogger(__name__) @@ -75,49 +74,46 @@ class WorkflowAppGenerator(BaseAppGenerator): app_model: App, workflow: Workflow, user: Union[Account, EndUser], - args: dict, + args: Mapping[str, Any], invoke_from: InvokeFrom, stream: bool = True, call_depth: int = 0, workflow_thread_pool_id: Optional[str] = None, ): - """ - Generate App response. 
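The chat, completion, and workflow generators in this diff all converge on one entry point for turning raw file mappings into File objects. A minimal sketch of the shared pattern (not part of the diff itself), assuming args, user, app_model, and file_extra_config are bound as in the surrounding handlers:

# Sketch of the common pattern shared by the three generators; the variable
# bindings are assumptions matching the handler bodies above.
from factories import file_factory
from models import Account
from models.enums import CreatedByRole

role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
file_objs = file_factory.build_from_mappings(
    mappings=args.get("files") or [],
    tenant_id=app_model.tenant_id,
    user_id=user.id,
    role=role,
    config=file_extra_config,
)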
+ files: Sequence[Mapping[str, Any]] = args.get("files") or [] - :param app_model: App - :param workflow: Workflow - :param user: account or end user - :param args: request args - :param invoke_from: invoke from source - :param stream: is stream - :param call_depth: call depth - :param workflow_thread_pool_id: workflow thread pool id - """ - inputs = args["inputs"] + role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER # parse files - files = args["files"] if args.get("files") else [] - message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) - if file_extra_config: - file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user) - else: - file_objs = [] + system_files = file_factory.build_from_mappings( + mappings=files, + tenant_id=app_model.tenant_id, + user_id=user.id, + role=role, + config=file_extra_config, + ) # convert to app config - app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow) + app_config = WorkflowAppConfigManager.get_app_config( + app_model=app_model, + workflow=workflow, + ) # get tracing instance - user_id = user.id if isinstance(user, Account) else user.session_id - trace_manager = TraceQueueManager(app_model.id, user_id) + trace_manager = TraceQueueManager( + app_id=app_model.id, + user_id=user.id if isinstance(user, Account) else user.session_id, + ) + inputs: Mapping[str, Any] = args["inputs"] workflow_run_id = str(uuid.uuid4()) # init application generate entity application_generate_entity = WorkflowAppGenerateEntity( task_id=str(uuid.uuid4()), app_config=app_config, - inputs=self._get_cleaned_inputs(inputs, app_config), - files=file_objs, + inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role), + files=system_files, user_id=user.id, stream=stream, invoke_from=invoke_from, diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index 378a4bb8bc..faefcb0ed5 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -1,21 +1,20 @@ import logging -import os from typing import Optional, cast +from configs import dify_config from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfig from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner -from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback from core.app.entities.app_invoke_entities import ( InvokeFrom, WorkflowAppGenerateEntity, ) -from core.workflow.callbacks.base_workflow_callback import WorkflowCallback -from core.workflow.entities.node_entities import UserFrom +from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db +from models.enums import UserFrom from models.model import App, EndUser from models.workflow import WorkflowType @@ -71,7 +70,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): db.session.close() workflow_callbacks: list[WorkflowCallback] = [] - if bool(os.environ.get("DEBUG", "False").lower() == "true"): + if dify_config.DEBUG: workflow_callbacks.append(WorkflowLoggingCallback()) # if only single 
iteration run is requested diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 7c53556e43..419a5da806 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -1,4 +1,3 @@ -import json import logging import time from collections.abc import Generator @@ -334,9 +333,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa start_at=graph_runtime_state.start_at, total_tokens=graph_runtime_state.total_tokens, total_steps=graph_runtime_state.node_run_steps, - outputs=json.dumps(event.outputs) - if isinstance(event, QueueWorkflowSucceededEvent) and event.outputs - else None, + outputs=event.outputs, conversation_id=None, trace_manager=trace_manager, ) diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index ce266116a7..ca23bbdd47 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -20,7 +20,6 @@ from core.app.entities.queue_entities import ( QueueWorkflowStartedEvent, QueueWorkflowSucceededEvent, ) -from core.workflow.entities.node_entities import NodeType from core.workflow.entities.variable_pool import VariablePool from core.workflow.graph_engine.entities.event import ( GraphEngineEvent, @@ -41,9 +40,9 @@ from core.workflow.graph_engine.entities.event import ( ParallelBranchRunSucceededEvent, ) from core.workflow.graph_engine.entities.graph import Graph -from core.workflow.nodes.base_node import BaseNode -from core.workflow.nodes.iteration.entities import IterationNodeData -from core.workflow.nodes.node_mapping import node_classes +from core.workflow.nodes import NodeType +from core.workflow.nodes.iteration import IterationNodeData +from core.workflow.nodes.node_mapping import node_type_classes_mapping from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db from models.model import App @@ -137,9 +136,8 @@ class WorkflowBasedAppRunner(AppRunner): raise ValueError("iteration node id not found in workflow graph") # Get node class - node_type = NodeType.value_of(iteration_node_config.get("data", {}).get("type")) - node_cls = node_classes.get(node_type) - node_cls = cast(type[BaseNode], node_cls) + node_type = NodeType(iteration_node_config.get("data", {}).get("type")) + node_cls = node_type_classes_mapping[node_type] # init variable pool variable_pool = VariablePool( diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 98685513a3..f2eba29323 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -1,4 +1,4 @@ -from collections.abc import Mapping +from collections.abc import Mapping, Sequence from enum import Enum from typing import Any, Optional @@ -7,7 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validat from constants import UUID_NIL from core.app.app_config.entities import AppConfig, EasyUIBasedAppConfig, WorkflowUIBasedAppConfig from core.entities.provider_configuration import ProviderModelBundle -from core.file.file_obj import FileVar +from core.file.models import File from core.model_runtime.entities.model_entities import AIModelEntity from core.ops.ops_trace_manager import TraceQueueManager @@ -23,7 +23,7 @@ class InvokeFrom(Enum): DEBUGGER = "debugger" @classmethod - def value_of(cls, value: str) -> "InvokeFrom": + def value_of(cls, 
value: str): """ Get value of given mode. @@ -82,7 +82,7 @@ class AppGenerateEntity(BaseModel): app_config: AppConfig inputs: Mapping[str, Any] - files: list[FileVar] = [] + files: Sequence[File] user_id: str # extras diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index 4577e28535..bc43baf8a5 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -5,9 +5,10 @@ from typing import Any, Optional from pydantic import BaseModel, field_validator from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk -from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType +from core.workflow.entities.node_entities import NodeRunMetadataKey from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.nodes import NodeType +from core.workflow.nodes.base import BaseNodeData class QueueEvent(str, Enum): diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 49e5f55ebc..4b5f4716ed 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -1,3 +1,4 @@ +from collections.abc import Mapping, Sequence from enum import Enum from typing import Any, Optional @@ -119,6 +120,7 @@ class MessageEndStreamResponse(StreamResponse): event: StreamEvent = StreamEvent.MESSAGE_END id: str metadata: dict = {} + files: Optional[Sequence[Mapping[str, Any]]] = None class MessageFileStreamResponse(StreamResponse): @@ -211,7 +213,7 @@ class WorkflowFinishStreamResponse(StreamResponse): created_by: Optional[dict] = None created_at: int finished_at: int - files: Optional[list[dict]] = [] + files: Optional[Sequence[Mapping[str, Any]]] = [] event: StreamEvent = StreamEvent.WORKFLOW_FINISHED workflow_run_id: str @@ -296,7 +298,7 @@ class NodeFinishStreamResponse(StreamResponse): execution_metadata: Optional[dict] = None created_at: int finished_at: int - files: Optional[list[dict]] = [] + files: Optional[Sequence[Mapping[str, Any]]] = [] parallel_id: Optional[str] = None parallel_start_node_id: Optional[str] = None parent_parallel_id: Optional[str] = None diff --git a/api/core/app/segments/parser.py b/api/core/app/segments/parser.py deleted file mode 100644 index 3c4d7046f4..0000000000 --- a/api/core/app/segments/parser.py +++ /dev/null @@ -1,18 +0,0 @@ -import re - -from core.workflow.entities.variable_pool import VariablePool - -from . import SegmentGroup, factory - -VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}") - - -def convert_template(*, template: str, variable_pool: VariablePool): - parts = re.split(VARIABLE_PATTERN, template) - segments = [] - for part in filter(lambda x: x, parts): - if "." 
in part and (value := variable_pool.get(part.split("."))): - segments.append(value) - else: - segments.append(factory.build_segment(part)) - return SegmentGroup(value=segments) diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index b8f5ac2603..138503d404 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -1,5 +1,6 @@ import json import time +from collections.abc import Mapping, Sequence from datetime import datetime, timezone from typing import Any, Optional, Union, cast @@ -27,27 +28,26 @@ from core.app.entities.task_entities import ( WorkflowStartStreamResponse, WorkflowTaskState, ) -from core.file.file_obj import FileVar +from core.file import FILE_MODEL_IDENTITY, File from core.model_runtime.utils.encoders import jsonable_encoder from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.tools.tool_manager import ToolManager -from core.workflow.entities.node_entities import NodeType from core.workflow.enums import SystemVariableKey +from core.workflow.nodes import NodeType from core.workflow.nodes.tool.entities import ToolNodeData from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db from models.account import Account +from models.enums import CreatedByRole, WorkflowRunTriggeredFrom from models.model import EndUser from models.workflow import ( - CreatedByRole, Workflow, WorkflowNodeExecution, WorkflowNodeExecutionStatus, WorkflowNodeExecutionTriggeredFrom, WorkflowRun, WorkflowRunStatus, - WorkflowRunTriggeredFrom, ) @@ -117,7 +117,7 @@ class WorkflowCycleManage: start_at: float, total_tokens: int, total_steps: int, - outputs: Optional[str] = None, + outputs: Mapping[str, Any] | None = None, conversation_id: Optional[str] = None, trace_manager: Optional[TraceQueueManager] = None, ) -> WorkflowRun: @@ -133,8 +133,10 @@ class WorkflowCycleManage: """ workflow_run = self._refetch_workflow_run(workflow_run.id) + outputs = WorkflowEntry.handle_special_values(outputs) + workflow_run.status = WorkflowRunStatus.SUCCEEDED.value - workflow_run.outputs = outputs + workflow_run.outputs = json.dumps(outputs or {}) workflow_run.elapsed_time = time.perf_counter() - start_at workflow_run.total_tokens = total_tokens workflow_run.total_steps = total_steps @@ -265,6 +267,7 @@ class WorkflowCycleManage: workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id) inputs = WorkflowEntry.handle_special_values(event.inputs) + process_data = WorkflowEntry.handle_special_values(event.process_data) outputs = WorkflowEntry.handle_special_values(event.outputs) execution_metadata = ( json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None @@ -276,7 +279,7 @@ class WorkflowCycleManage: { WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.SUCCEEDED.value, WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None, - WorkflowNodeExecution.process_data: json.dumps(event.process_data) if event.process_data else None, + WorkflowNodeExecution.process_data: json.dumps(process_data) if event.process_data else None, WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None, WorkflowNodeExecution.execution_metadata: execution_metadata, WorkflowNodeExecution.finished_at: finished_at, @@ -286,10 +289,11 @@ class WorkflowCycleManage: db.session.commit() db.session.close() + 
process_data = WorkflowEntry.handle_special_values(event.process_data) workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value workflow_node_execution.inputs = json.dumps(inputs) if inputs else None - workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None + workflow_node_execution.process_data = json.dumps(process_data) if process_data else None workflow_node_execution.outputs = json.dumps(outputs) if outputs else None workflow_node_execution.execution_metadata = execution_metadata workflow_node_execution.finished_at = finished_at @@ -308,6 +312,7 @@ class WorkflowCycleManage: workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id) inputs = WorkflowEntry.handle_special_values(event.inputs) + process_data = WorkflowEntry.handle_special_values(event.process_data) outputs = WorkflowEntry.handle_special_values(event.outputs) finished_at = datetime.now(timezone.utc).replace(tzinfo=None) elapsed_time = (finished_at - event.start_at).total_seconds() @@ -317,7 +322,7 @@ class WorkflowCycleManage: WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.FAILED.value, WorkflowNodeExecution.error: event.error, WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None, - WorkflowNodeExecution.process_data: json.dumps(event.process_data) if event.process_data else None, + WorkflowNodeExecution.process_data: json.dumps(process_data) if event.process_data else None, WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None, WorkflowNodeExecution.finished_at: finished_at, WorkflowNodeExecution.elapsed_time: elapsed_time, @@ -326,11 +331,12 @@ class WorkflowCycleManage: db.session.commit() db.session.close() + process_data = WorkflowEntry.handle_special_values(event.process_data) workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value workflow_node_execution.error = event.error workflow_node_execution.inputs = json.dumps(inputs) if inputs else None - workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None + workflow_node_execution.process_data = json.dumps(process_data) if process_data else None workflow_node_execution.outputs = json.dumps(outputs) if outputs else None workflow_node_execution.finished_at = finished_at workflow_node_execution.elapsed_time = elapsed_time @@ -637,7 +643,7 @@ class WorkflowCycleManage: ), ) - def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> list[dict]: + def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> Sequence[Mapping[str, Any]]: """ Fetch files from node outputs :param outputs_dict: node outputs dict @@ -646,15 +652,15 @@ class WorkflowCycleManage: if not outputs_dict: return [] - files = [] - for output_var, output_value in outputs_dict.items(): - file_vars = self._fetch_files_from_variable_value(output_value) - if file_vars: - files.extend(file_vars) + files = [self._fetch_files_from_variable_value(output_value) for output_value in outputs_dict.values()] + # Remove None + files = [file for file in files if file] + # Flatten list + files = [file for sublist in files for file in sublist] return files - def _fetch_files_from_variable_value(self, value: Union[dict, list]) -> list[dict]: + def _fetch_files_from_variable_value(self, value: Union[dict, list]) -> Sequence[Mapping[str, Any]]: """ Fetch files from variable value :param value: variable value @@ -666,17 +672,17 @@ class WorkflowCycleManage: files = [] if isinstance(value, list): for item in value: - 
file_var = self._get_file_var_from_value(item) - if file_var: - files.append(file_var) + file = self._get_file_var_from_value(item) + if file: + files.append(file) elif isinstance(value, dict): - file_var = self._get_file_var_from_value(value) - if file_var: - files.append(file_var) + file = self._get_file_var_from_value(value) + if file: + files.append(file) return files - def _get_file_var_from_value(self, value: Union[dict, list]) -> Optional[dict]: + def _get_file_var_from_value(self, value: Union[dict, list]) -> Mapping[str, Any] | None: """ Get file var from value :param value: variable value @@ -685,14 +691,11 @@ class WorkflowCycleManage: if not value: return None - if isinstance(value, dict): - if "__variant" in value and value["__variant"] == FileVar.__name__: - return value - elif isinstance(value, FileVar): + if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY: + return value + elif isinstance(value, File): return value.to_dict() - return None - def _refetch_workflow_run(self, workflow_run_id: str) -> WorkflowRun: """ Refetch workflow run diff --git a/api/core/embedding/embedding_constant.py b/api/core/entities/embedding_type.py similarity index 100% rename from api/core/embedding/embedding_constant.py rename to api/core/entities/embedding_type.py diff --git a/api/core/entities/message_entities.py b/api/core/entities/message_entities.py deleted file mode 100644 index 10bc9f6ed7..0000000000 --- a/api/core/entities/message_entities.py +++ /dev/null @@ -1,29 +0,0 @@ -import enum -from typing import Any - -from pydantic import BaseModel - - -class PromptMessageFileType(enum.Enum): - IMAGE = "image" - - @staticmethod - def value_of(value): - for member in PromptMessageFileType: - if member.value == value: - return member - raise ValueError(f"No matching enum found for value '{value}'") - - -class PromptMessageFile(BaseModel): - type: PromptMessageFileType - data: Any = None - - -class ImagePromptMessageFile(PromptMessageFile): - class DETAIL(enum.Enum): - LOW = "low" - HIGH = "high" - - type: PromptMessageFileType = PromptMessageFileType.IMAGE - detail: DETAIL = DETAIL.LOW diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py index 27ac0e3455..596a841e74 100644 --- a/api/core/entities/provider_entities.py +++ b/api/core/entities/provider_entities.py @@ -6,7 +6,24 @@ from pydantic import BaseModel, ConfigDict, Field from core.entities.parameter_entities import AppSelectorScope, CommonParameterType, ModelConfigScope from core.model_runtime.entities.model_entities import ModelType from core.tools.entities.common_entities import I18nObject -from models.provider import ProviderQuotaType + + +class ProviderQuotaType(Enum): + PAID = "paid" + """hosted paid quota""" + + FREE = "free" + """third-party free quota""" + + TRIAL = "trial" + """hosted trial quota""" + + @staticmethod + def value_of(value): + for member in ProviderQuotaType: + if member.value == value: + return member + raise ValueError(f"No matching enum found for value '{value}'") class QuotaUnit(Enum): diff --git a/api/core/file/__init__.py b/api/core/file/__init__.py index e69de29bb2..bdaf8793fa 100644 --- a/api/core/file/__init__.py +++ b/api/core/file/__init__.py @@ -0,0 +1,19 @@ +from .constants import FILE_MODEL_IDENTITY +from .enums import ArrayFileAttribute, FileAttribute, FileBelongsTo, FileTransferMethod, FileType +from .models import ( + File, + FileExtraConfig, + ImageConfig, +) + +__all__ = [ + "FileType", + "FileExtraConfig", + 
"FileTransferMethod", + "FileBelongsTo", + "File", + "ImageConfig", + "FileAttribute", + "ArrayFileAttribute", + "FILE_MODEL_IDENTITY", +] diff --git a/api/core/file/constants.py b/api/core/file/constants.py new file mode 100644 index 0000000000..ce1d238e93 --- /dev/null +++ b/api/core/file/constants.py @@ -0,0 +1 @@ +FILE_MODEL_IDENTITY = "__dify__file__" diff --git a/api/core/file/enums.py b/api/core/file/enums.py new file mode 100644 index 0000000000..f4153f1676 --- /dev/null +++ b/api/core/file/enums.py @@ -0,0 +1,55 @@ +from enum import Enum + + +class FileType(str, Enum): + IMAGE = "image" + DOCUMENT = "document" + AUDIO = "audio" + VIDEO = "video" + CUSTOM = "custom" + + @staticmethod + def value_of(value): + for member in FileType: + if member.value == value: + return member + raise ValueError(f"No matching enum found for value '{value}'") + + +class FileTransferMethod(str, Enum): + REMOTE_URL = "remote_url" + LOCAL_FILE = "local_file" + TOOL_FILE = "tool_file" + + @staticmethod + def value_of(value): + for member in FileTransferMethod: + if member.value == value: + return member + raise ValueError(f"No matching enum found for value '{value}'") + + +class FileBelongsTo(str, Enum): + USER = "user" + ASSISTANT = "assistant" + + @staticmethod + def value_of(value): + for member in FileBelongsTo: + if member.value == value: + return member + raise ValueError(f"No matching enum found for value '{value}'") + + +class FileAttribute(str, Enum): + TYPE = "type" + SIZE = "size" + NAME = "name" + MIME_TYPE = "mime_type" + TRANSFER_METHOD = "transfer_method" + URL = "url" + EXTENSION = "extension" + + +class ArrayFileAttribute(str, Enum): + LENGTH = "length" diff --git a/api/core/file/file_manager.py b/api/core/file/file_manager.py new file mode 100644 index 0000000000..0c6ce8ce75 --- /dev/null +++ b/api/core/file/file_manager.py @@ -0,0 +1,156 @@ +import base64 + +from configs import dify_config +from core.file import file_repository +from core.helper import ssrf_proxy +from core.model_runtime.entities import AudioPromptMessageContent, ImagePromptMessageContent +from extensions.ext_database import db +from extensions.ext_storage import storage + +from . import helpers +from .enums import FileAttribute +from .models import File, FileTransferMethod, FileType +from .tool_file_parser import ToolFileParser + + +def get_attr(*, file: File, attr: FileAttribute): + match attr: + case FileAttribute.TYPE: + return file.type.value + case FileAttribute.SIZE: + return file.size + case FileAttribute.NAME: + return file.filename + case FileAttribute.MIME_TYPE: + return file.mime_type + case FileAttribute.TRANSFER_METHOD: + return file.transfer_method.value + case FileAttribute.URL: + return file.remote_url + case FileAttribute.EXTENSION: + return file.extension + case _: + raise ValueError(f"Invalid file attribute: {attr}") + + +def to_prompt_message_content(f: File, /): + """ + Convert a File object to an ImagePromptMessageContent object. + + This function takes a File object and converts it to an ImagePromptMessageContent + object, which can be used as a prompt for image-based AI models. + + Args: + file (File): The File object to convert. Must be of type FileType.IMAGE. + + Returns: + ImagePromptMessageContent: An object containing the image data and detail level. + + Raises: + ValueError: If the file is not an image or if the file data is missing. + + Note: + The detail level of the image prompt is determined by the file's extra_config. 
+ If not specified, it defaults to ImagePromptMessageContent.DETAIL.LOW. + """ + match f.type: + case FileType.IMAGE: + if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url": + data = _to_url(f) + else: + data = _to_base64_data_string(f) + + if f._extra_config and f._extra_config.image_config and f._extra_config.image_config.detail: + detail = f._extra_config.image_config.detail + else: + detail = ImagePromptMessageContent.DETAIL.LOW + + return ImagePromptMessageContent(data=data, detail=detail) + case FileType.AUDIO: + encoded_string = _file_to_encoded_string(f) + if f.extension is None: + raise ValueError("Missing file extension") + return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip(".")) + case _: + raise ValueError(f"file type {f.type} is not supported") + + +def download(f: File, /): + upload_file = file_repository.get_upload_file(session=db.session(), file=f) + return _download_file_content(upload_file.key) + + +def _download_file_content(path: str, /): + """ + Download and return the contents of a file as bytes. + + This function loads the file from storage and ensures it's in bytes format. + + Args: + path (str): The path to the file in storage. + + Returns: + bytes: The contents of the file as a bytes object. + + Raises: + ValueError: If the loaded file is not a bytes object. + """ + data = storage.load(path, stream=False) + if not isinstance(data, bytes): + raise ValueError(f"file {path} is not a bytes object") + return data + + +def _get_encoded_string(f: File, /): + match f.transfer_method: + case FileTransferMethod.REMOTE_URL: + response = ssrf_proxy.get(f.remote_url) + response.raise_for_status() + content = response.content + encoded_string = base64.b64encode(content).decode("utf-8") + return encoded_string + case FileTransferMethod.LOCAL_FILE: + upload_file = file_repository.get_upload_file(session=db.session(), file=f) + data = _download_file_content(upload_file.key) + encoded_string = base64.b64encode(data).decode("utf-8") + return encoded_string + case FileTransferMethod.TOOL_FILE: + tool_file = file_repository.get_tool_file(session=db.session(), file=f) + data = _download_file_content(tool_file.file_key) + encoded_string = base64.b64encode(data).decode("utf-8") + return encoded_string + case _: + raise ValueError(f"Unsupported transfer method: {f.transfer_method}") + + +def _to_base64_data_string(f: File, /): + encoded_string = _get_encoded_string(f) + return f"data:{f.mime_type};base64,{encoded_string}" + + +def _file_to_encoded_string(f: File, /): + match f.type: + case FileType.IMAGE: + return _to_base64_data_string(f) + case FileType.AUDIO: + return _get_encoded_string(f) + case _: + raise ValueError(f"file type {f.type} is not supported") + + +def _to_url(f: File, /): + if f.transfer_method == FileTransferMethod.REMOTE_URL: + if f.remote_url is None: + raise ValueError("Missing file remote_url") + return f.remote_url + elif f.transfer_method == FileTransferMethod.LOCAL_FILE: + if f.related_id is None: + raise ValueError("Missing file related_id") + return helpers.get_signed_file_url(upload_file_id=f.related_id) + elif f.transfer_method == FileTransferMethod.TOOL_FILE: + # add sign url + if f.related_id is None or f.extension is None: + raise ValueError("Missing file related_id or extension") + return ToolFileParser.get_tool_file_manager().sign_file(tool_file_id=f.related_id, extension=f.extension) + else: + raise ValueError(f"Unsupported transfer method: {f.transfer_method}") diff --git a/api/core/file/file_obj.py 
b/api/core/file/file_obj.py deleted file mode 100644 index 5c4e694025..0000000000 --- a/api/core/file/file_obj.py +++ /dev/null @@ -1,145 +0,0 @@ -import enum -from typing import Any, Optional - -from pydantic import BaseModel - -from core.file.tool_file_parser import ToolFileParser -from core.file.upload_file_parser import UploadFileParser -from core.model_runtime.entities.message_entities import ImagePromptMessageContent -from extensions.ext_database import db - - -class FileExtraConfig(BaseModel): - """ - File Upload Entity. - """ - - image_config: Optional[dict[str, Any]] = None - - -class FileType(enum.Enum): - IMAGE = "image" - - @staticmethod - def value_of(value): - for member in FileType: - if member.value == value: - return member - raise ValueError(f"No matching enum found for value '{value}'") - - -class FileTransferMethod(enum.Enum): - REMOTE_URL = "remote_url" - LOCAL_FILE = "local_file" - TOOL_FILE = "tool_file" - - @staticmethod - def value_of(value): - for member in FileTransferMethod: - if member.value == value: - return member - raise ValueError(f"No matching enum found for value '{value}'") - - -class FileBelongsTo(enum.Enum): - USER = "user" - ASSISTANT = "assistant" - - @staticmethod - def value_of(value): - for member in FileBelongsTo: - if member.value == value: - return member - raise ValueError(f"No matching enum found for value '{value}'") - - -class FileVar(BaseModel): - id: Optional[str] = None # message file id - tenant_id: str - type: FileType - transfer_method: FileTransferMethod - url: Optional[str] = None # remote url - related_id: Optional[str] = None - extra_config: Optional[FileExtraConfig] = None - filename: Optional[str] = None - extension: Optional[str] = None - mime_type: Optional[str] = None - - def to_dict(self) -> dict: - return { - "__variant": self.__class__.__name__, - "tenant_id": self.tenant_id, - "type": self.type.value, - "transfer_method": self.transfer_method.value, - "url": self.preview_url, - "remote_url": self.url, - "related_id": self.related_id, - "filename": self.filename, - "extension": self.extension, - "mime_type": self.mime_type, - } - - def to_markdown(self) -> str: - """ - Convert file to markdown - :return: - """ - preview_url = self.preview_url - if self.type == FileType.IMAGE: - text = f'![{self.filename or ""}]({preview_url})' - else: - text = f"[{self.filename or preview_url}]({preview_url})" - - return text - - @property - def data(self) -> Optional[str]: - """ - Get image data, file signed url or base64 data - depending on config MULTIMODAL_SEND_IMAGE_FORMAT - :return: - """ - return self._get_data() - - @property - def preview_url(self) -> Optional[str]: - """ - Get signed preview url - :return: - """ - return self._get_data(force_url=True) - - @property - def prompt_message_content(self) -> ImagePromptMessageContent: - if self.type == FileType.IMAGE: - image_config = self.extra_config.image_config - - return ImagePromptMessageContent( - data=self.data, - detail=ImagePromptMessageContent.DETAIL.HIGH - if image_config.get("detail") == "high" - else ImagePromptMessageContent.DETAIL.LOW, - ) - - def _get_data(self, force_url: bool = False) -> Optional[str]: - from models.model import UploadFile - - if self.type == FileType.IMAGE: - if self.transfer_method == FileTransferMethod.REMOTE_URL: - return self.url - elif self.transfer_method == FileTransferMethod.LOCAL_FILE: - upload_file = ( - db.session.query(UploadFile) - .filter(UploadFile.id == self.related_id, UploadFile.tenant_id == self.tenant_id) - .first() - ) - - 
return UploadFileParser.get_image_data(upload_file=upload_file, force_url=force_url) - elif self.transfer_method == FileTransferMethod.TOOL_FILE: - extension = self.extension - # add sign url - return ToolFileParser.get_tool_file_manager().sign_file( - tool_file_id=self.related_id, extension=extension - ) - - return None diff --git a/api/core/file/file_repository.py b/api/core/file/file_repository.py new file mode 100644 index 0000000000..975e1e72db --- /dev/null +++ b/api/core/file/file_repository.py @@ -0,0 +1,32 @@ +from sqlalchemy import select +from sqlalchemy.orm import Session + +from models import ToolFile, UploadFile + +from .models import File + + +def get_upload_file(*, session: Session, file: File): + if file.related_id is None: + raise ValueError("Missing file related_id") + stmt = select(UploadFile).filter( + UploadFile.id == file.related_id, + UploadFile.tenant_id == file.tenant_id, + ) + record = session.scalar(stmt) + if not record: + raise ValueError(f"upload file {file.related_id} not found") + return record + + +def get_tool_file(*, session: Session, file: File): + if file.related_id is None: + raise ValueError("Missing file related_id") + stmt = select(ToolFile).filter( + ToolFile.id == file.related_id, + ToolFile.tenant_id == file.tenant_id, + ) + record = session.scalar(stmt) + if not record: + raise ValueError(f"tool file {file.related_id} not found") + return record diff --git a/api/core/file/helpers.py b/api/core/file/helpers.py new file mode 100644 index 0000000000..12123cf3f7 --- /dev/null +++ b/api/core/file/helpers.py @@ -0,0 +1,48 @@ +import base64 +import hashlib +import hmac +import os +import time + +from configs import dify_config + + +def get_signed_file_url(upload_file_id: str) -> str: + url = f"{dify_config.FILES_URL}/files/{upload_file_id}/file-preview" + + timestamp = str(int(time.time())) + nonce = os.urandom(16).hex() + key = dify_config.SECRET_KEY.encode() + msg = f"file-preview|{upload_file_id}|{timestamp}|{nonce}" + sign = hmac.new(key, msg.encode(), hashlib.sha256).digest() + encoded_sign = base64.urlsafe_b64encode(sign).decode() + + return f"{url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" + + +def verify_image_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool: + data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() + recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() + + # verify signature + if sign != recalculated_encoded_sign: + return False + + current_time = int(time.time()) + return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT + + +def verify_file_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool: + data_to_sign = f"file-preview|{upload_file_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() + recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() + + # verify signature + if sign != recalculated_encoded_sign: + return False + + current_time = int(time.time()) + return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT diff --git a/api/core/file/message_file_parser.py b/api/core/file/message_file_parser.py deleted file mode 100644 index 641686bd7c..0000000000 --- a/api/core/file/message_file_parser.py +++ 
/dev/null @@ -1,243 +0,0 @@ -import re -from collections.abc import Mapping, Sequence -from typing import Any, Union -from urllib.parse import parse_qs, urlparse - -import requests - -from core.file.file_obj import FileBelongsTo, FileExtraConfig, FileTransferMethod, FileType, FileVar -from extensions.ext_database import db -from models.account import Account -from models.model import EndUser, MessageFile, UploadFile -from services.file_service import IMAGE_EXTENSIONS - - -class MessageFileParser: - def __init__(self, tenant_id: str, app_id: str) -> None: - self.tenant_id = tenant_id - self.app_id = app_id - - def validate_and_transform_files_arg( - self, files: Sequence[Mapping[str, Any]], file_extra_config: FileExtraConfig, user: Union[Account, EndUser] - ) -> list[FileVar]: - """ - validate and transform files arg - - :param files: - :param file_extra_config: - :param user: - :return: - """ - for file in files: - if not isinstance(file, dict): - raise ValueError("Invalid file format, must be dict") - if not file.get("type"): - raise ValueError("Missing file type") - FileType.value_of(file.get("type")) - if not file.get("transfer_method"): - raise ValueError("Missing file transfer method") - FileTransferMethod.value_of(file.get("transfer_method")) - if file.get("transfer_method") == FileTransferMethod.REMOTE_URL.value: - if not file.get("url"): - raise ValueError("Missing file url") - if not file.get("url").startswith("http"): - raise ValueError("Invalid file url") - if file.get("transfer_method") == FileTransferMethod.LOCAL_FILE.value and not file.get("upload_file_id"): - raise ValueError("Missing file upload_file_id") - if file.get("transform_method") == FileTransferMethod.TOOL_FILE.value and not file.get("tool_file_id"): - raise ValueError("Missing file tool_file_id") - - # transform files to file objs - type_file_objs = self._to_file_objs(files, file_extra_config) - - # validate files - new_files = [] - for file_type, file_objs in type_file_objs.items(): - if file_type == FileType.IMAGE: - # parse and validate files - image_config = file_extra_config.image_config - - # check if image file feature is enabled - if not image_config: - continue - - # Validate number of files - if len(files) > image_config["number_limits"]: - raise ValueError(f"Number of image files exceeds the maximum limit {image_config['number_limits']}") - - for file_obj in file_objs: - # Validate transfer method - if file_obj.transfer_method.value not in image_config["transfer_methods"]: - raise ValueError(f"Invalid transfer method: {file_obj.transfer_method.value}") - - # Validate file type - if file_obj.type != FileType.IMAGE: - raise ValueError(f"Invalid file type: {file_obj.type}") - - if file_obj.transfer_method == FileTransferMethod.REMOTE_URL: - # check remote url valid and is image - result, error = self._check_image_remote_url(file_obj.url) - if result is False: - raise ValueError(error) - elif file_obj.transfer_method == FileTransferMethod.LOCAL_FILE: - # get upload file from upload_file_id - upload_file = ( - db.session.query(UploadFile) - .filter( - UploadFile.id == file_obj.related_id, - UploadFile.tenant_id == self.tenant_id, - UploadFile.created_by == user.id, - UploadFile.created_by_role == ("account" if isinstance(user, Account) else "end_user"), - UploadFile.extension.in_(IMAGE_EXTENSIONS), - ) - .first() - ) - - # check upload file is belong to tenant and user - if not upload_file: - raise ValueError("Invalid upload file") - - new_files.append(file_obj) - - # return all file objs - return new_files 
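The imperative checks in the parser being deleted here now live on the pydantic File model introduced later in this diff (api/core/file/models.py), so invalid mappings fail at construction time. A small sketch of that behavior, with a deliberately bad URL:

from core.file import File, FileTransferMethod, FileType

# Hypothetical invalid input: remote_url must start with "http", so the
# model validator raises during construction (pydantic's ValidationError
# is a ValueError subclass).
try:
    File(
        tenant_id="tenant-id",
        type=FileType.IMAGE,
        transfer_method=FileTransferMethod.REMOTE_URL,
        remote_url="ftp://example.com/example.png",
    )
except ValueError as e:
    print(e)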
- - def transform_message_files(self, files: list[MessageFile], file_extra_config: FileExtraConfig): - """ - transform message files - - :param files: - :param file_extra_config: - :return: - """ - # transform files to file objs - type_file_objs = self._to_file_objs(files, file_extra_config) - - # return all file objs - return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs] - - def _to_file_objs( - self, files: list[Union[dict, MessageFile]], file_extra_config: FileExtraConfig - ) -> dict[FileType, list[FileVar]]: - """ - transform files to file objs - - :param files: - :param file_extra_config: - :return: - """ - type_file_objs: dict[FileType, list[FileVar]] = { - # Currently only support image - FileType.IMAGE: [] - } - - if not files: - return type_file_objs - - # group by file type and convert file args or message files to FileObj - for file in files: - if isinstance(file, MessageFile): - if file.belongs_to == FileBelongsTo.ASSISTANT.value: - continue - - file_obj = self._to_file_obj(file, file_extra_config) - if file_obj.type not in type_file_objs: - continue - - type_file_objs[file_obj.type].append(file_obj) - - return type_file_objs - - def _to_file_obj(self, file: Union[dict, MessageFile], file_extra_config: FileExtraConfig): - """ - transform file to file obj - - :param file: - :return: - """ - if isinstance(file, dict): - transfer_method = FileTransferMethod.value_of(file.get("transfer_method")) - if transfer_method != FileTransferMethod.TOOL_FILE: - return FileVar( - tenant_id=self.tenant_id, - type=FileType.value_of(file.get("type")), - transfer_method=transfer_method, - url=file.get("url") if transfer_method == FileTransferMethod.REMOTE_URL else None, - related_id=file.get("upload_file_id") if transfer_method == FileTransferMethod.LOCAL_FILE else None, - extra_config=file_extra_config, - ) - return FileVar( - tenant_id=self.tenant_id, - type=FileType.value_of(file.get("type")), - transfer_method=transfer_method, - url=None, - related_id=file.get("tool_file_id"), - extra_config=file_extra_config, - ) - else: - return FileVar( - id=file.id, - tenant_id=self.tenant_id, - type=FileType.value_of(file.type), - transfer_method=FileTransferMethod.value_of(file.transfer_method), - url=file.url, - related_id=file.upload_file_id or None, - extra_config=file_extra_config, - ) - - def _check_image_remote_url(self, url): - try: - headers = { - "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)" - " Chrome/91.0.4472.124 Safari/537.36" - } - - def is_s3_presigned_url(url): - try: - parsed_url = urlparse(url) - if "amazonaws.com" not in parsed_url.netloc: - return False - query_params = parse_qs(parsed_url.query) - - def check_presign_v2(query_params): - required_params = ["Signature", "Expires"] - for param in required_params: - if param not in query_params: - return False - if not query_params["Expires"][0].isdigit(): - return False - signature = query_params["Signature"][0] - if not re.match(r"^[A-Za-z0-9+/]+={0,2}$", signature): - return False - - return True - - def check_presign_v4(query_params): - required_params = ["X-Amz-Signature", "X-Amz-Expires"] - for param in required_params: - if param not in query_params: - return False - if not query_params["X-Amz-Expires"][0].isdigit(): - return False - signature = query_params["X-Amz-Signature"][0] - if not re.match(r"^[A-Za-z0-9+/]+={0,2}$", signature): - return False - - return True - - return check_presign_v4(query_params) or check_presign_v2(query_params) - except 
Exception: - return False - - if is_s3_presigned_url(url): - response = requests.get(url, headers=headers, allow_redirects=True) - if response.status_code in {200, 304}: - return True, "" - - response = requests.head(url, headers=headers, allow_redirects=True) - if response.status_code in {200, 304}: - return True, "" - else: - return False, "URL does not exist." - except requests.RequestException as e: - return False, f"Error checking URL: {e}" diff --git a/api/core/file/models.py b/api/core/file/models.py new file mode 100644 index 0000000000..866ff3155b --- /dev/null +++ b/api/core/file/models.py @@ -0,0 +1,140 @@ +from collections.abc import Mapping, Sequence +from typing import Optional + +from pydantic import BaseModel, Field, model_validator + +from core.model_runtime.entities.message_entities import ImagePromptMessageContent + +from . import helpers +from .constants import FILE_MODEL_IDENTITY +from .enums import FileTransferMethod, FileType +from .tool_file_parser import ToolFileParser + + +class ImageConfig(BaseModel): + """ + NOTE: This part of validation is deprecated, but still used in app features "Image Upload". + """ + + number_limits: int = 0 + transfer_methods: Sequence[FileTransferMethod] = Field(default_factory=list) + detail: ImagePromptMessageContent.DETAIL | None = None + + +class FileExtraConfig(BaseModel): + """ + File Upload Entity. + """ + + image_config: Optional[ImageConfig] = None + allowed_file_types: Sequence[FileType] = Field(default_factory=list) + allowed_extensions: Sequence[str] = Field(default_factory=list) + allowed_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list) + number_limits: int = 0 + + +class File(BaseModel): + dify_model_identity: str = FILE_MODEL_IDENTITY + + id: Optional[str] = None # message file id + tenant_id: str + type: FileType + transfer_method: FileTransferMethod + remote_url: Optional[str] = None # remote url + related_id: Optional[str] = None + filename: Optional[str] = None + extension: Optional[str] = Field(default=None, description="File extension, should contains dot") + mime_type: Optional[str] = None + size: int = -1 + _extra_config: FileExtraConfig | None = None + + def to_dict(self) -> Mapping[str, str | int | None]: + data = self.model_dump(mode="json") + return { + **data, + "url": self.generate_url(), + } + + @property + def markdown(self) -> str: + url = self.generate_url() + if self.type == FileType.IMAGE: + text = f'![{self.filename or ""}]({url})' + else: + text = f"[{self.filename or url}]({url})" + + return text + + def generate_url(self) -> Optional[str]: + if self.type == FileType.IMAGE: + if self.transfer_method == FileTransferMethod.REMOTE_URL: + return self.remote_url + elif self.transfer_method == FileTransferMethod.LOCAL_FILE: + if self.related_id is None: + raise ValueError("Missing file related_id") + return helpers.get_signed_file_url(upload_file_id=self.related_id) + elif self.transfer_method == FileTransferMethod.TOOL_FILE: + assert self.related_id is not None + assert self.extension is not None + return ToolFileParser.get_tool_file_manager().sign_file( + tool_file_id=self.related_id, extension=self.extension + ) + else: + if self.transfer_method == FileTransferMethod.REMOTE_URL: + return self.remote_url + elif self.transfer_method == FileTransferMethod.LOCAL_FILE: + if self.related_id is None: + raise ValueError("Missing file related_id") + return helpers.get_signed_file_url(upload_file_id=self.related_id) + elif self.transfer_method == FileTransferMethod.TOOL_FILE: + assert 
self.related_id is not None + assert self.extension is not None + return ToolFileParser.get_tool_file_manager().sign_file( + tool_file_id=self.related_id, extension=self.extension + ) + + @model_validator(mode="after") + def validate_after(self): + match self.transfer_method: + case FileTransferMethod.REMOTE_URL: + if not self.remote_url: + raise ValueError("Missing file url") + if not isinstance(self.remote_url, str) or not self.remote_url.startswith("http"): + raise ValueError("Invalid file url") + case FileTransferMethod.LOCAL_FILE: + if not self.related_id: + raise ValueError("Missing file related_id") + case FileTransferMethod.TOOL_FILE: + if not self.related_id: + raise ValueError("Missing file related_id") + + # Validate the extra config. + if not self._extra_config: + return self + + if self._extra_config.allowed_file_types: + if self.type not in self._extra_config.allowed_file_types and self.type != FileType.CUSTOM: + raise ValueError(f"Invalid file type: {self.type}") + + if self._extra_config.allowed_extensions and self.extension not in self._extra_config.allowed_extensions: + raise ValueError(f"Invalid file extension: {self.extension}") + + if ( + self._extra_config.allowed_upload_methods + and self.transfer_method not in self._extra_config.allowed_upload_methods + ): + raise ValueError(f"Invalid transfer method: {self.transfer_method}") + + match self.type: + case FileType.IMAGE: + # NOTE: This part of validation is deprecated, but still used in app features "Image Upload". + if not self._extra_config.image_config: + return self + # TODO: skip check if transfer_methods is empty, because many test cases are not setting this field + if ( + self._extra_config.image_config.transfer_methods + and self.transfer_method not in self._extra_config.image_config.transfer_methods + ): + raise ValueError(f"Invalid transfer method: {self.transfer_method}") + + return self diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index 4e6d58904e..6793e41978 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -13,8 +13,11 @@ SSRF_PROXY_HTTP_URL = os.getenv("SSRF_PROXY_HTTP_URL", "") SSRF_PROXY_HTTPS_URL = os.getenv("SSRF_PROXY_HTTPS_URL", "") SSRF_DEFAULT_MAX_RETRIES = int(os.getenv("SSRF_DEFAULT_MAX_RETRIES", "3")) -proxies = ( - {"http://": SSRF_PROXY_HTTP_URL, "https://": SSRF_PROXY_HTTPS_URL} +proxy_mounts = ( + { + "http://": httpx.HTTPTransport(proxy=SSRF_PROXY_HTTP_URL), + "https://": httpx.HTTPTransport(proxy=SSRF_PROXY_HTTPS_URL), + } if SSRF_PROXY_HTTP_URL and SSRF_PROXY_HTTPS_URL else None ) @@ -33,11 +36,14 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): while retries <= max_retries: try: if SSRF_PROXY_ALL_URL: - response = httpx.request(method=method, url=url, proxy=SSRF_PROXY_ALL_URL, **kwargs) - elif proxies: - response = httpx.request(method=method, url=url, proxies=proxies, **kwargs) + with httpx.Client(proxy=SSRF_PROXY_ALL_URL) as client: + response = client.request(method=method, url=url, **kwargs) + elif proxy_mounts: + with httpx.Client(mounts=proxy_mounts) as client: + response = client.request(method=method, url=url, **kwargs) else: - response = httpx.request(method=method, url=url, **kwargs) + with httpx.Client() as client: + response = client.request(method=method, url=url, **kwargs) if response.status_code not in STATUS_FORCELIST: return response diff --git a/api/core/hosting_configuration.py b/api/core/hosting_configuration.py index eeeccc2349..9b4080aef1 100644 --- 
a/api/core/hosting_configuration.py +++ b/api/core/hosting_configuration.py @@ -3,9 +3,8 @@ from typing import Optional from flask import Config, Flask from pydantic import BaseModel -from core.entities.provider_entities import QuotaUnit, RestrictModel +from core.entities.provider_entities import ProviderQuotaType, QuotaUnit, RestrictModel from core.model_runtime.entities.model_entities import ModelType -from models.provider import ProviderQuotaType class HostingQuota(BaseModel): diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index bc94912c1e..189d94e290 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -1,18 +1,20 @@ from typing import Optional from core.app.app_config.features.file_upload.manager import FileUploadConfigManager -from core.file.message_file_parser import MessageFileParser +from core.file import file_manager from core.model_manager import ModelInstance -from core.model_runtime.entities.message_entities import ( +from core.model_runtime.entities import ( AssistantPromptMessage, ImagePromptMessageContent, PromptMessage, + PromptMessageContent, PromptMessageRole, TextPromptMessageContent, UserPromptMessage, ) from core.prompt.utils.extract_thread_messages import extract_thread_messages from extensions.ext_database import db +from factories import file_factory from models.model import AppMode, Conversation, Message, MessageFile from models.workflow import WorkflowRun @@ -65,13 +67,12 @@ class TokenBufferMemory: messages = list(reversed(thread_messages)) - message_file_parser = MessageFileParser(tenant_id=app_record.tenant_id, app_id=app_record.id) prompt_messages = [] for message in messages: files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all() if files: file_extra_config = None - if self.conversation.mode not in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: + if self.conversation.mode not in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config) else: if message.workflow_run_id: @@ -84,17 +85,21 @@ class TokenBufferMemory: workflow_run.workflow.features_dict, is_vision=False ) - if file_extra_config: - file_objs = message_file_parser.transform_message_files(files, file_extra_config) + if file_extra_config and app_record: + file_objs = file_factory.build_from_message_files( + message_files=files, tenant_id=app_record.tenant_id, config=file_extra_config + ) else: file_objs = [] if not file_objs: prompt_messages.append(UserPromptMessage(content=message.query)) else: - prompt_message_contents = [TextPromptMessageContent(data=message.query)] + prompt_message_contents: list[PromptMessageContent] = [] + prompt_message_contents.append(TextPromptMessageContent(data=message.query)) for file_obj in file_objs: - prompt_message_contents.append(file_obj.prompt_message_content) + prompt_message = file_manager.to_prompt_message_content(file_obj) + prompt_message_contents.append(prompt_message) prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) else: diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 1a4a03e277..dc95e4b509 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -1,9 +1,9 @@ import logging import os -from collections.abc import Callable, Generator, Sequence -from typing import IO, Literal, Optional, Union, cast, overload +from collections.abc import Callable, Generator, Iterable, Sequence +from typing 
import IO, Any, Literal, Optional, Union, cast, overload -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_entities import ModelLoadBalancingConfiguration from core.errors.error import ProviderTokenNotInitError @@ -310,9 +312,7 @@ class ModelInstance: user=user, ) - def invoke_tts( - self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None - ) -> Generator[bytes, None, None]: + def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) -> Iterable[bytes]: """ Invoke large language tts model @@ -336,7 +334,7 @@ class ModelInstance: voice=voice, ) - def _round_robin_invoke(self, function: Callable, *args, **kwargs): + def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs): """ Round-robin invoke :param function: function to invoke diff --git a/api/core/model_runtime/entities/__init__.py b/api/core/model_runtime/entities/__init__.py index e69de29bb2..b3eb4d4dfe 100644 --- a/api/core/model_runtime/entities/__init__.py +++ b/api/core/model_runtime/entities/__init__.py @@ -0,0 +1,38 @@ +from .llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from .message_entities import ( + AssistantPromptMessage, + AudioPromptMessageContent, + ImagePromptMessageContent, + PromptMessage, + PromptMessageContent, + PromptMessageContentType, + PromptMessageRole, + PromptMessageTool, + SystemPromptMessage, + TextPromptMessageContent, + ToolPromptMessage, + UserPromptMessage, +) +from .model_entities import ModelPropertyKey + +__all__ = [ + "ImagePromptMessageContent", + "PromptMessage", + "PromptMessageRole", + "LLMUsage", + "ModelPropertyKey", + "AssistantPromptMessage", + "PromptMessage", + "PromptMessageContent", + "PromptMessageRole", + "SystemPromptMessage", + "TextPromptMessageContent", + "UserPromptMessage", + "PromptMessageTool", + "ToolPromptMessage", + "PromptMessageContentType", + "LLMResult", + "LLMResultChunk", + "LLMResultChunkDelta", + "AudioPromptMessageContent", +] diff --git a/api/core/model_runtime/entities/message_entities.py b/api/core/model_runtime/entities/message_entities.py index e51bb18deb..cda1639661 100644 --- a/api/core/model_runtime/entities/message_entities.py +++ b/api/core/model_runtime/entities/message_entities.py @@ -2,7 +2,7 @@ from abc import ABC from enum import Enum from typing import Optional -from pydantic import BaseModel, field_validator +from pydantic import BaseModel, Field, field_validator class PromptMessageRole(Enum): @@ -55,6 +55,7 @@ class PromptMessageContentType(Enum): TEXT = "text" IMAGE = "image" + AUDIO = "audio" class PromptMessageContent(BaseModel): @@ -74,12 +75,18 @@ class TextPromptMessageContent(PromptMessageContent): type: PromptMessageContentType = PromptMessageContentType.TEXT +class AudioPromptMessageContent(PromptMessageContent): + type: PromptMessageContentType = PromptMessageContentType.AUDIO + data: str = Field(..., description="Base64 encoded audio data") + format: str = Field(..., description="Audio format") + +
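With `AUDIO` added to `PromptMessageContentType` and re-exported from `core.model_runtime.entities`, a user message can now mix text and base64-encoded audio parts. A hedged sketch of assembling such a message (the file path and format are placeholders, not anything this patch prescribes):

```python
import base64

from core.model_runtime.entities import (
    AudioPromptMessageContent,
    TextPromptMessageContent,
    UserPromptMessage,
)

# Placeholder clip; any small audio file would do for illustration.
with open("sample.mp3", "rb") as f:
    encoded_audio = base64.b64encode(f.read()).decode()

# A multimodal user message: one text part followed by one audio part.
message = UserPromptMessage(
    content=[
        TextPromptMessageContent(data="Transcribe this clip."),
        AudioPromptMessageContent(data=encoded_audio, format="mp3"),
    ]
)
```

class ImagePromptMessageContent(PromptMessageContent): """ Model class for image prompt message content.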
""" - class DETAIL(Enum): + class DETAIL(str, Enum): LOW = "low" HIGH = "high" diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py index 33dbce37c4..18a5e8be34 100644 --- a/api/core/model_runtime/model_providers/__base/large_language_model.py +++ b/api/core/model_runtime/model_providers/__base/large_language_model.py @@ -1,11 +1,11 @@ import logging -import os import time from collections.abc import Generator from typing import Optional, Union from pydantic import ConfigDict +from configs import dify_config from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.callbacks.logging_callback import LoggingCallback from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMUsage @@ -68,7 +68,7 @@ class LargeLanguageModel(AIModel): callbacks = callbacks or [] - if bool(os.environ.get("DEBUG", "False").lower() == "true"): + if dify_config.DEBUG: callbacks.append(LoggingCallback()) # trigger before invoke callbacks diff --git a/api/core/model_runtime/model_providers/__base/text_embedding_model.py b/api/core/model_runtime/model_providers/__base/text_embedding_model.py index beade74362..6da5db3883 100644 --- a/api/core/model_runtime/model_providers/__base/text_embedding_model.py +++ b/api/core/model_runtime/model_providers/__base/text_embedding_model.py @@ -2,7 +2,7 @@ from typing import Optional from pydantic import ConfigDict -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult from core.model_runtime.model_providers.__base.ai_model import AIModel diff --git a/api/core/model_runtime/model_providers/__base/tts_model.py b/api/core/model_runtime/model_providers/__base/tts_model.py index 8cefa63ebf..4feaa6f042 100644 --- a/api/core/model_runtime/model_providers/__base/tts_model.py +++ b/api/core/model_runtime/model_providers/__base/tts_model.py @@ -1,4 +1,5 @@ import logging +from collections.abc import Iterable from typing import Optional from pydantic import ConfigDict @@ -21,8 +22,14 @@ class TTSModel(AIModel): model_config = ConfigDict(protected_namespaces=()) def invoke( - self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None - ): + self, + model: str, + tenant_id: str, + credentials: dict, + content_text: str, + voice: str, + user: Optional[str] = None, + ) -> Iterable[bytes]: """ Invoke large language model @@ -52,12 +59,12 @@ class TTSModel(AIModel): def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list[dict]: """ - Get voice for given tts model voices + Retrieves the list of voices supported by a given text-to-speech (TTS) model. - :param language: tts language - :param model: model name - :param credentials: model credentials - :return: voices lists + :param language: The language for which the voices are requested. + :param model: The name of the TTS model. + :param credentials: The credentials required to access the TTS model. + :return: A list of voices supported by the TTS model. 
""" plugin_model_manager = PluginModelManager() return plugin_model_manager.get_tts_model_voices( diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index 0200f4a32d..764944f799 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -358,8 +358,8 @@ class TraceTask: workflow_run_id = workflow_run.id workflow_run_elapsed_time = workflow_run.elapsed_time workflow_run_status = workflow_run.status - workflow_run_inputs = json.loads(workflow_run.inputs) if workflow_run.inputs else {} - workflow_run_outputs = json.loads(workflow_run.outputs) if workflow_run.outputs else {} + workflow_run_inputs = workflow_run.inputs_dict + workflow_run_outputs = workflow_run.outputs_dict workflow_run_version = workflow_run.version error = workflow_run.error or "" diff --git a/api/core/plugin/backwards_invocation/model.py b/api/core/plugin/backwards_invocation/model.py index 377512886a..8894f8eef5 100644 --- a/api/core/plugin/backwards_invocation/model.py +++ b/api/core/plugin/backwards_invocation/model.py @@ -21,7 +21,7 @@ from core.plugin.entities.request import ( ) from core.tools.entities.tool_entities import ToolProviderType from core.tools.utils.model_invocation_utils import ModelInvocationUtils -from core.workflow.nodes.llm.llm_node import LLMNode +from core.workflow.nodes.llm.node import LLMNode from models.account import Tenant diff --git a/api/core/plugin/backwards_invocation/node.py b/api/core/plugin/backwards_invocation/node.py index 1bd5d84e4c..f402da030f 100644 --- a/api/core/plugin/backwards_invocation/node.py +++ b/api/core/plugin/backwards_invocation/node.py @@ -1,5 +1,5 @@ from core.plugin.backwards_invocation.base import BaseBackwardsInvocation -from core.workflow.entities.node_entities import NodeType +from core.workflow.nodes.enums import NodeType from core.workflow.nodes.parameter_extractor.entities import ( ModelConfig as ParameterExtractorModelConfig, ) diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index ce8038d14e..bbd9531b19 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -1,12 +1,15 @@ -from typing import Optional, Union +from collections.abc import Sequence +from typing import Optional from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.file.file_obj import FileVar +from core.file import file_manager +from core.file.models import File from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter from core.memory.token_buffer_memory import TokenBufferMemory -from core.model_runtime.entities.message_entities import ( +from core.model_runtime.entities import ( AssistantPromptMessage, PromptMessage, + PromptMessageContent, PromptMessageRole, SystemPromptMessage, TextPromptMessageContent, @@ -14,8 +17,8 @@ from core.model_runtime.entities.message_entities import ( ) from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig from core.prompt.prompt_transform import PromptTransform -from core.prompt.simple_prompt_transform import ModelMode from core.prompt.utils.prompt_template_parser import PromptTemplateParser +from core.workflow.entities.variable_pool import VariablePool class AdvancedPromptTransform(PromptTransform): @@ -28,22 +31,19 @@ class AdvancedPromptTransform(PromptTransform): def get_prompt( self, - prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate], 
- inputs: dict, + *, + prompt_template: Sequence[ChatModelMessage] | CompletionModelPromptTemplate, + inputs: dict[str, str], query: str, - files: list[FileVar], + files: Sequence[File], context: Optional[str], memory_config: Optional[MemoryConfig], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity, - query_prompt_template: Optional[str] = None, ) -> list[PromptMessage]: - inputs = {key: str(value) for key, value in inputs.items()} - prompt_messages = [] - model_mode = ModelMode.value_of(model_config.mode) - if model_mode == ModelMode.COMPLETION: + if isinstance(prompt_template, CompletionModelPromptTemplate): prompt_messages = self._get_completion_model_prompt_messages( prompt_template=prompt_template, inputs=inputs, @@ -54,12 +54,11 @@ class AdvancedPromptTransform(PromptTransform): memory=memory, model_config=model_config, ) - elif model_mode == ModelMode.CHAT: + elif isinstance(prompt_template, list) and all(isinstance(item, ChatModelMessage) for item in prompt_template): prompt_messages = self._get_chat_model_prompt_messages( prompt_template=prompt_template, inputs=inputs, query=query, - query_prompt_template=query_prompt_template, files=files, context=context, memory_config=memory_config, @@ -74,7 +73,7 @@ class AdvancedPromptTransform(PromptTransform): prompt_template: CompletionModelPromptTemplate, inputs: dict, query: Optional[str], - files: list[FileVar], + files: Sequence[File], context: Optional[str], memory_config: Optional[MemoryConfig], memory: Optional[TokenBufferMemory], @@ -88,10 +87,10 @@ class AdvancedPromptTransform(PromptTransform): prompt_messages = [] if prompt_template.edition_type == "basic" or not prompt_template.edition_type: - prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) - prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} + parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) + prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs} - prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs) + prompt_inputs = self._set_context_variable(context, parser, prompt_inputs) if memory and memory_config: role_prefix = memory_config.role_prefix @@ -100,15 +99,15 @@ class AdvancedPromptTransform(PromptTransform): memory_config=memory_config, raw_prompt=raw_prompt, role_prefix=role_prefix, - prompt_template=prompt_template, + parser=parser, prompt_inputs=prompt_inputs, model_config=model_config, ) if query: - prompt_inputs = self._set_query_variable(query, prompt_template, prompt_inputs) + prompt_inputs = self._set_query_variable(query, parser, prompt_inputs) - prompt = prompt_template.format(prompt_inputs) + prompt = parser.format(prompt_inputs) else: prompt = raw_prompt prompt_inputs = inputs @@ -116,9 +115,10 @@ class AdvancedPromptTransform(PromptTransform): prompt = Jinja2Formatter.format(prompt, prompt_inputs) if files: - prompt_message_contents = [TextPromptMessageContent(data=prompt)] + prompt_message_contents: list[PromptMessageContent] = [] + prompt_message_contents.append(TextPromptMessageContent(data=prompt)) for file in files: - prompt_message_contents.append(file.prompt_message_content) + prompt_message_contents.append(file_manager.to_prompt_message_content(file)) prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) else: @@ -131,35 +131,38 @@ class AdvancedPromptTransform(PromptTransform): prompt_template: 
list[ChatModelMessage], inputs: dict, query: Optional[str], - files: list[FileVar], + files: Sequence[File], context: Optional[str], memory_config: Optional[MemoryConfig], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity, - query_prompt_template: Optional[str] = None, ) -> list[PromptMessage]: """ Get chat model prompt messages. """ - raw_prompt_list = prompt_template - prompt_messages = [] - - for prompt_item in raw_prompt_list: + for prompt_item in prompt_template: raw_prompt = prompt_item.text if prompt_item.edition_type == "basic" or not prompt_item.edition_type: - prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) - prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} - - prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs) - - prompt = prompt_template.format(prompt_inputs) + if self.with_variable_tmpl: + vp = VariablePool() + for k, v in inputs.items(): + if k.startswith("#"): + vp.add(k[1:-1].split("."), v) + raw_prompt = raw_prompt.replace("{{#context#}}", context or "") + prompt = vp.convert_template(raw_prompt).text + else: + parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) + prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs} + prompt_inputs = self._set_context_variable( + context=context, parser=parser, prompt_inputs=prompt_inputs + ) + prompt = parser.format(prompt_inputs) elif prompt_item.edition_type == "jinja2": prompt = raw_prompt prompt_inputs = inputs - - prompt = Jinja2Formatter.format(prompt, prompt_inputs) + prompt = Jinja2Formatter.format(template=prompt, inputs=prompt_inputs) else: raise ValueError(f"Invalid edition type: {prompt_item.edition_type}") @@ -170,25 +173,25 @@ class AdvancedPromptTransform(PromptTransform): elif prompt_item.role == PromptMessageRole.ASSISTANT: prompt_messages.append(AssistantPromptMessage(content=prompt)) - if query and query_prompt_template: - prompt_template = PromptTemplateParser( - template=query_prompt_template, with_variable_tmpl=self.with_variable_tmpl + if query and memory_config and memory_config.query_prompt_template: + parser = PromptTemplateParser( + template=memory_config.query_prompt_template, with_variable_tmpl=self.with_variable_tmpl ) - prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} + prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs} prompt_inputs["#sys.query#"] = query - prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs) + prompt_inputs = self._set_context_variable(context, parser, prompt_inputs) - query = prompt_template.format(prompt_inputs) + query = parser.format(prompt_inputs) if memory and memory_config: prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config) - if files: - prompt_message_contents = [TextPromptMessageContent(data=query)] + if files and query is not None: + prompt_message_contents: list[PromptMessageContent] = [] + prompt_message_contents.append(TextPromptMessageContent(data=query)) for file in files: - prompt_message_contents.append(file.prompt_message_content) - + prompt_message_contents.append(file_manager.to_prompt_message_content(file)) prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) else: prompt_messages.append(UserPromptMessage(content=query))
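The hunks in this file repeat one pattern: when files accompany a query, the plain string content is replaced by a typed list built through `file_manager.to_prompt_message_content`. A condensed sketch of that pattern, assuming `files` holds `core.file.models.File` objects (the helper name `build_user_message` is ours, not the patch's):

```python
from collections.abc import Sequence

from core.file import file_manager
from core.file.models import File
from core.model_runtime.entities import (
    PromptMessageContent,
    TextPromptMessageContent,
    UserPromptMessage,
)


def build_user_message(query: str, files: Sequence[File]) -> UserPromptMessage:
    # No files: keep the plain-string content used elsewhere in this file.
    if not files:
        return UserPromptMessage(content=query)
    # Files present: the text part comes first, then one typed part per file.
    contents: list[PromptMessageContent] = [TextPromptMessageContent(data=query)]
    for file in files:
        contents.append(file_manager.to_prompt_message_content(file))
    return UserPromptMessage(content=contents)
```

@@ -200,19 +203,19 @@ class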
AdvancedPromptTransform(PromptTransform): # get last user message content and add files prompt_message_contents = [TextPromptMessageContent(data=last_message.content)] for file in files: - prompt_message_contents.append(file.prompt_message_content) + prompt_message_contents.append(file_manager.to_prompt_message_content(file)) last_message.content = prompt_message_contents else: prompt_message_contents = [TextPromptMessageContent(data="")] # not for query for file in files: - prompt_message_contents.append(file.prompt_message_content) + prompt_message_contents.append(file_manager.to_prompt_message_content(file)) prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) else: prompt_message_contents = [TextPromptMessageContent(data=query)] for file in files: - prompt_message_contents.append(file.prompt_message_content) + prompt_message_contents.append(file_manager.to_prompt_message_content(file)) prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) elif query: @@ -220,8 +223,8 @@ class AdvancedPromptTransform(PromptTransform): return prompt_messages - def _set_context_variable(self, context: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> dict: - if "#context#" in prompt_template.variable_keys: + def _set_context_variable(self, context: str | None, parser: PromptTemplateParser, prompt_inputs: dict) -> dict: + if "#context#" in parser.variable_keys: if context: prompt_inputs["#context#"] = context else: @@ -229,8 +232,8 @@ class AdvancedPromptTransform(PromptTransform): return prompt_inputs - def _set_query_variable(self, query: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> dict: - if "#query#" in prompt_template.variable_keys: + def _set_query_variable(self, query: str, parser: PromptTemplateParser, prompt_inputs: dict) -> dict: + if "#query#" in parser.variable_keys: if query: prompt_inputs["#query#"] = query else: @@ -244,16 +247,16 @@ class AdvancedPromptTransform(PromptTransform): memory_config: MemoryConfig, raw_prompt: str, role_prefix: MemoryConfig.RolePrefix, - prompt_template: PromptTemplateParser, + parser: PromptTemplateParser, prompt_inputs: dict, model_config: ModelConfigWithCredentialsEntity, ) -> dict: - if "#histories#" in prompt_template.variable_keys: + if "#histories#" in parser.variable_keys: if memory: inputs = {"#histories#": "", **prompt_inputs} - prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) - prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} - tmp_human_message = UserPromptMessage(content=prompt_template.format(prompt_inputs)) + parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) + prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs} + tmp_human_message = UserPromptMessage(content=parser.format(prompt_inputs)) rest_tokens = self._calculate_rest_token([tmp_human_message], model_config) diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index 7479560520..5a3481b963 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -5,9 +5,11 @@ from typing import TYPE_CHECKING, Optional from core.app.app_config.entities import PromptTemplateEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.file import file_manager from core.memory.token_buffer_memory import TokenBufferMemory from 
core.model_runtime.entities.message_entities import ( PromptMessage, + PromptMessageContent, SystemPromptMessage, TextPromptMessageContent, UserPromptMessage, @@ -18,10 +20,10 @@ from core.prompt.utils.prompt_template_parser import PromptTemplateParser from models.model import AppMode if TYPE_CHECKING: - from core.file.file_obj import FileVar + from core.file.models import File -class ModelMode(enum.Enum): +class ModelMode(str, enum.Enum): COMPLETION = "completion" CHAT = "chat" @@ -53,7 +55,7 @@ class SimplePromptTransform(PromptTransform): prompt_template_entity: PromptTemplateEntity, inputs: dict, query: str, - files: list["FileVar"], + files: list["File"], context: Optional[str], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity, @@ -169,7 +171,7 @@ class SimplePromptTransform(PromptTransform): inputs: dict, query: str, context: Optional[str], - files: list["FileVar"], + files: list["File"], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity, ) -> tuple[list[PromptMessage], Optional[list[str]]]: @@ -214,7 +216,7 @@ class SimplePromptTransform(PromptTransform): inputs: dict, query: str, context: Optional[str], - files: list["FileVar"], + files: list["File"], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity, ) -> tuple[list[PromptMessage], Optional[list[str]]]: @@ -261,11 +263,12 @@ class SimplePromptTransform(PromptTransform): return [self.get_last_user_message(prompt, files)], stops - def get_last_user_message(self, prompt: str, files: list["FileVar"]) -> UserPromptMessage: + def get_last_user_message(self, prompt: str, files: list["File"]) -> UserPromptMessage: if files: - prompt_message_contents = [TextPromptMessageContent(data=prompt)] + prompt_message_contents: list[PromptMessageContent] = [] + prompt_message_contents.append(TextPromptMessageContent(data=prompt)) for file in files: - prompt_message_contents.append(file.prompt_message_content) + prompt_message_contents.append(file_manager.to_prompt_message_content(file)) prompt_message = UserPromptMessage(content=prompt_message_contents) else: diff --git a/api/core/prompt/utils/extract_thread_messages.py b/api/core/prompt/utils/extract_thread_messages.py index e8b626499f..f7aef76c87 100644 --- a/api/core/prompt/utils/extract_thread_messages.py +++ b/api/core/prompt/utils/extract_thread_messages.py @@ -1,7 +1,9 @@ +from typing import Any + from constants import UUID_NIL -def extract_thread_messages(messages: list[dict]) -> list[dict]: +def extract_thread_messages(messages: list[Any]): thread_messages = [] next_message = None diff --git a/api/core/prompt/utils/prompt_message_util.py b/api/core/prompt/utils/prompt_message_util.py index 29494db221..5eec5e3c99 100644 --- a/api/core/prompt/utils/prompt_message_util.py +++ b/api/core/prompt/utils/prompt_message_util.py @@ -1,7 +1,8 @@ from typing import cast -from core.model_runtime.entities.message_entities import ( +from core.model_runtime.entities import ( AssistantPromptMessage, + AudioPromptMessageContent, ImagePromptMessageContent, PromptMessage, PromptMessageContentType, @@ -21,7 +22,7 @@ class PromptMessageUtil: :return: """ prompts = [] - if model_mode == ModelMode.CHAT.value: + if model_mode == ModelMode.CHAT: tool_calls = [] for prompt_message in prompt_messages: if prompt_message.role == PromptMessageRole.USER: @@ -51,11 +52,9 @@ class PromptMessageUtil: files = [] if isinstance(prompt_message.content, list): for content in prompt_message.content: - if content.type == 
PromptMessageContentType.TEXT: - content = cast(TextPromptMessageContent, content) + if isinstance(content, TextPromptMessageContent): text += content.data - else: - content = cast(ImagePromptMessageContent, content) + elif isinstance(content, ImagePromptMessageContent): files.append( { "type": "image", @@ -63,6 +62,14 @@ class PromptMessageUtil: "detail": content.detail.value, } ) + elif isinstance(content, AudioPromptMessageContent): + files.append( + { + "type": "audio", + "data": content.data[:10] + "...[TRUNCATED]..." + content.data[-10:], + "format": content.format, + } + ) else: text = prompt_message.content diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index ae3934327e..2224aaab80 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -14,6 +14,7 @@ from core.entities.provider_entities import ( CustomProviderConfiguration, ModelLoadBalancingConfiguration, ModelSettings, + ProviderQuotaType, QuotaConfiguration, SystemConfiguration, ) @@ -31,7 +32,6 @@ from models.provider import ( Provider, ProviderModel, ProviderModelSetting, - ProviderQuotaType, ProviderType, TenantDefaultModel, TenantPreferredModelProvider, diff --git a/api/core/rag/data_post_processor/data_post_processor.py b/api/core/rag/data_post_processor/data_post_processor.py index b1d6f93cff..992415657e 100644 --- a/api/core/rag/data_post_processor/data_post_processor.py +++ b/api/core/rag/data_post_processor/data_post_processor.py @@ -1,14 +1,14 @@ from typing import Optional -from core.model_manager import ModelManager +from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.rag.data_post_processor.reorder import ReorderRunner from core.rag.models.document import Document -from core.rag.rerank.constants.rerank_mode import RerankMode from core.rag.rerank.entity.weight import KeywordSetting, VectorSetting, Weights -from core.rag.rerank.rerank_model import RerankModelRunner -from core.rag.rerank.weight_rerank import WeightRerankRunner +from core.rag.rerank.rerank_base import BaseRerankRunner +from core.rag.rerank.rerank_factory import RerankRunnerFactory +from core.rag.rerank.rerank_type import RerankMode class DataPostProcessor: @@ -47,11 +47,12 @@ class DataPostProcessor: tenant_id: str, reranking_model: Optional[dict] = None, weights: Optional[dict] = None, - ) -> Optional[RerankModelRunner | WeightRerankRunner]: + ) -> Optional[BaseRerankRunner]: if reranking_mode == RerankMode.WEIGHTED_SCORE.value and weights: - return WeightRerankRunner( - tenant_id, - Weights( + runner = RerankRunnerFactory.create_rerank_runner( + runner_type=reranking_mode, + tenant_id=tenant_id, + weights=Weights( vector_setting=VectorSetting( vector_weight=weights["vector_setting"]["vector_weight"], embedding_provider_name=weights["vector_setting"]["embedding_provider_name"], @@ -62,23 +63,33 @@ class DataPostProcessor: ), ), ) + return runner elif reranking_mode == RerankMode.RERANKING_MODEL.value: - if reranking_model: - try: - model_manager = ModelManager() - rerank_model_instance = model_manager.get_model_instance( - tenant_id=tenant_id, - provider=reranking_model["reranking_provider_name"], - model_type=ModelType.RERANK, - model=reranking_model["reranking_model_name"], - ) - except InvokeAuthorizationError: - return None - return RerankModelRunner(rerank_model_instance) - return None + rerank_model_instance = 
self._get_rerank_model_instance(tenant_id, reranking_model) + if rerank_model_instance is None: + return None + runner = RerankRunnerFactory.create_rerank_runner( + runner_type=reranking_mode, rerank_model_instance=rerank_model_instance + ) + return runner return None def _get_reorder_runner(self, reorder_enabled) -> Optional[ReorderRunner]: if reorder_enabled: return ReorderRunner() return None + + def _get_rerank_model_instance(self, tenant_id: str, reranking_model: Optional[dict]) -> ModelInstance | None: + if reranking_model: + try: + model_manager = ModelManager() + rerank_model_instance = model_manager.get_model_instance( + tenant_id=tenant_id, + provider=reranking_model["reranking_provider_name"], + model_type=ModelType.RERANK, + model=reranking_model["reranking_model_name"], + ) + return rerank_model_instance + except InvokeAuthorizationError: + return None + return None diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index d3fd0c672a..3affbd2d0a 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -6,7 +6,7 @@ from flask import Flask, current_app from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.vdb.vector_factory import Vector -from core.rag.rerank.constants.rerank_mode import RerankMode +from core.rag.rerank.rerank_type import RerankMode from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py index 6dcd98dcfd..c77cb87376 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py @@ -9,10 +9,10 @@ _import_err_msg = ( ) from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/baidu/baidu_vector.py b/api/core/rag/datasource/vdb/baidu/baidu_vector.py index 543cfa67b3..1d4bfef76d 100644 --- a/api/core/rag/datasource/vdb/baidu/baidu_vector.py +++ b/api/core/rag/datasource/vdb/baidu/baidu_vector.py @@ -12,10 +12,10 @@ from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/chroma/chroma_vector.py b/api/core/rag/datasource/vdb/chroma/chroma_vector.py index 610aa498ab..a9e1486edd 100644 --- 
a/api/core/rag/datasource/vdb/chroma/chroma_vector.py +++ b/api/core/rag/datasource/vdb/chroma/chroma_vector.py @@ -6,10 +6,10 @@ from chromadb import QueryResult, Settings from pydantic import BaseModel from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py index f420373d5b..052a187225 100644 --- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py @@ -9,11 +9,11 @@ from elasticsearch import Elasticsearch from flask import current_app from pydantic import BaseModel, model_validator -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/milvus/milvus_vector.py b/api/core/rag/datasource/vdb/milvus/milvus_vector.py index bdca59f869..080a1ef567 100644 --- a/api/core/rag/datasource/vdb/milvus/milvus_vector.py +++ b/api/core/rag/datasource/vdb/milvus/milvus_vector.py @@ -7,11 +7,11 @@ from pymilvus import MilvusClient, MilvusException from pymilvus.milvus_client import IndexParams from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/myscale/myscale_vector.py b/api/core/rag/datasource/vdb/myscale/myscale_vector.py index b30aa7ca22..1fca926a2d 100644 --- a/api/core/rag/datasource/vdb/myscale/myscale_vector.py +++ b/api/core/rag/datasource/vdb/myscale/myscale_vector.py @@ -8,10 +8,10 @@ from clickhouse_connect import get_client from pydantic import BaseModel from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py index 8d2e0a86ab..0e0f107268 100644 --- a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py 
+++ b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py @@ -9,11 +9,11 @@ from opensearchpy.helpers import BulkIndexError from pydantic import BaseModel, model_validator from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/oracle/oraclevector.py b/api/core/rag/datasource/vdb/oracle/oraclevector.py index 84a4381cd1..4ced5d61e5 100644 --- a/api/core/rag/datasource/vdb/oracle/oraclevector.py +++ b/api/core/rag/datasource/vdb/oracle/oraclevector.py @@ -13,10 +13,10 @@ from nltk.corpus import stopwords from pydantic import BaseModel, model_validator from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py index a82a9b96dd..7cbbdcc81f 100644 --- a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py +++ b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py @@ -12,11 +12,11 @@ from sqlalchemy.dialects import postgresql from sqlalchemy.orm import Mapped, Session, mapped_column from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.pgvecto_rs.collection import CollectionORM from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset @@ -216,7 +216,7 @@ class PGVectoRSFactory(AbstractVectorFactory): else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() - dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.PGVECTO_RS, collection_name)) dim = len(embeddings.embed_query("pgvecto_rs")) return PGVectoRS( diff --git a/api/core/rag/datasource/vdb/pgvector/pgvector.py b/api/core/rag/datasource/vdb/pgvector/pgvector.py index 6f336d27e7..40a9cdd136 100644 --- a/api/core/rag/datasource/vdb/pgvector/pgvector.py +++ b/api/core/rag/datasource/vdb/pgvector/pgvector.py @@ -8,10 +8,10 @@ import psycopg2.pool from pydantic import BaseModel, model_validator from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from 
core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py index f418e3ca05..69d2aa4f76 100644 --- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py +++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py @@ -20,11 +20,11 @@ from qdrant_client.http.models import ( from qdrant_client.local.qdrant_local import QdrantLocal from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_database import db from extensions.ext_redis import redis_client diff --git a/api/core/rag/datasource/vdb/relyt/relyt_vector.py b/api/core/rag/datasource/vdb/relyt/relyt_vector.py index 13a63784be..f373dcfeab 100644 --- a/api/core/rag/datasource/vdb/relyt/relyt_vector.py +++ b/api/core/rag/datasource/vdb/relyt/relyt_vector.py @@ -8,9 +8,9 @@ from sqlalchemy import text as sql_text from sqlalchemy.dialects.postgresql import JSON, TEXT from sqlalchemy.orm import Session -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from models.dataset import Dataset try: diff --git a/api/core/rag/datasource/vdb/tencent/tencent_vector.py b/api/core/rag/datasource/vdb/tencent/tencent_vector.py index 39e3a7f6cf..f971a9c5eb 100644 --- a/api/core/rag/datasource/vdb/tencent/tencent_vector.py +++ b/api/core/rag/datasource/vdb/tencent/tencent_vector.py @@ -8,10 +8,10 @@ from tcvectordb.model import index as vdb_index from tcvectordb.model.document import Filter from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py index 7837c5a4aa..1147e35ce8 100644 --- a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py +++ b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py @@ -9,10 +9,10 @@ from sqlalchemy import text as sql_text from sqlalchemy.orm import Session, declarative_base from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git 
a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index 873b289027..fb956a16ed 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -2,12 +2,12 @@ from abc import ABC, abstractmethod from typing import Any, Optional from configs import dify_config -from core.embedding.cached_embedding import CacheEmbedding from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.cached_embedding import CacheEmbedding +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py b/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py index 5f60f10acb..4f927f2899 100644 --- a/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py +++ b/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py @@ -14,11 +14,11 @@ from volcengine.viking_db import ( ) from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.field import Field as vdb_Field from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py index 4009efe7a7..649cfbfea8 100644 --- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py +++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py @@ -7,11 +7,11 @@ import weaviate from pydantic import BaseModel, model_validator from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/embedding/__init__.py b/api/core/rag/embedding/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py similarity index 97% rename from api/core/embedding/cached_embedding.py rename to api/core/rag/embedding/cached_embedding.py index 31d2171e72..b3e93ce760 100644 --- a/api/core/embedding/cached_embedding.py +++ b/api/core/rag/embedding/cached_embedding.py @@ -6,11 +6,11 @@ import numpy as np from sqlalchemy.exc import IntegrityError from configs import dify_config -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_manager import ModelInstance from core.model_runtime.entities.model_entities import ModelPropertyKey 
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from core.rag.datasource.entity.embedding import Embeddings +from core.rag.embedding.embedding_base import Embeddings from extensions.ext_database import db from extensions.ext_redis import redis_client from libs import helper diff --git a/api/core/rag/datasource/entity/embedding.py b/api/core/rag/embedding/embedding_base.py similarity index 90% rename from api/core/rag/datasource/entity/embedding.py rename to api/core/rag/embedding/embedding_base.py index 126c1a3723..9f232ab910 100644 --- a/api/core/rag/datasource/entity/embedding.py +++ b/api/core/rag/embedding/embedding_base.py @@ -7,10 +7,12 @@ class Embeddings(ABC): @abstractmethod def embed_documents(self, texts: list[str]) -> list[list[float]]: """Embed search docs.""" + raise NotImplementedError @abstractmethod def embed_query(self, text: str) -> list[float]: """Embed query text.""" + raise NotImplementedError async def aembed_documents(self, texts: list[str]) -> list[list[float]]: """Asynchronous Embed search docs.""" diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index 7352ef378b..2b6e048652 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -121,7 +121,7 @@ class WordExtractor(BaseExtractor): db.session.add(upload_file) db.session.commit() image_map[rel.target_part] = ( - f"![image]({dify_config.CONSOLE_API_URL}/files/{upload_file.id}/image-preview)" + f"![image]({dify_config.CONSOLE_API_URL}/files/{upload_file.id}/file-preview)" ) return image_map diff --git a/api/core/rag/rerank/rerank_base.py b/api/core/rag/rerank/rerank_base.py new file mode 100644 index 0000000000..818b04b2ff --- /dev/null +++ b/api/core/rag/rerank/rerank_base.py @@ -0,0 +1,26 @@ +from abc import ABC, abstractmethod +from typing import Optional + +from core.rag.models.document import Document + + +class BaseRerankRunner(ABC): + @abstractmethod + def run( + self, + query: str, + documents: list[Document], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> list[Document]: + """ + Run rerank model + :param query: search query + :param documents: documents for reranking + :param score_threshold: score threshold + :param top_n: top n + :param user: unique user id if needed + :return: + """ + raise NotImplementedError diff --git a/api/core/rag/rerank/rerank_factory.py b/api/core/rag/rerank/rerank_factory.py new file mode 100644 index 0000000000..1a3cf85736 --- /dev/null +++ b/api/core/rag/rerank/rerank_factory.py @@ -0,0 +1,16 @@ +from core.rag.rerank.rerank_base import BaseRerankRunner +from core.rag.rerank.rerank_model import RerankModelRunner +from core.rag.rerank.rerank_type import RerankMode +from core.rag.rerank.weight_rerank import WeightRerankRunner + + +class RerankRunnerFactory: + @staticmethod + def create_rerank_runner(runner_type: str, *args, **kwargs) -> BaseRerankRunner: + match runner_type: + case RerankMode.RERANKING_MODEL.value: + return RerankModelRunner(*args, **kwargs) + case RerankMode.WEIGHTED_SCORE.value: + return WeightRerankRunner(*args, **kwargs) + case _: + raise ValueError(f"Unknown runner type: {runner_type}")
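`RerankRunnerFactory` dispatches on the `RerankMode` string and forwards the remaining arguments to the chosen runner's constructor. A usage sketch for the model-based path (the `rerank` wrapper and its arguments are illustrative; the model instance is assumed to come from `ModelManager.get_model_instance`):

```python
from typing import Optional

from core.model_manager import ModelInstance
from core.rag.models.document import Document
from core.rag.rerank.rerank_base import BaseRerankRunner
from core.rag.rerank.rerank_factory import RerankRunnerFactory
from core.rag.rerank.rerank_type import RerankMode


def rerank(
    query: str,
    documents: list[Document],
    rerank_model_instance: ModelInstance,
    top_n: Optional[int] = None,
) -> list[Document]:
    # Extra kwargs pass straight through to RerankModelRunner.__init__.
    runner: BaseRerankRunner = RerankRunnerFactory.create_rerank_runner(
        runner_type=RerankMode.RERANKING_MODEL.value,
        rerank_model_instance=rerank_model_instance,
    )
    return runner.run(query=query, documents=documents, top_n=top_n)
```

diff --git a/api/core/rag/rerank/rerank_model.py b/api/core/rag/rerank/rerank_model.py index 27f86aed34..40ebf0befd 100644 --- a/api/core/rag/rerank/rerank_model.py +++ b/api/core/rag/rerank/rerank_model.py @@ -2,9 +2,10 @@ from typing import Optional from core.model_manager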
import ModelInstance from core.rag.models.document import Document +from core.rag.rerank.rerank_base import BaseRerankRunner -class RerankModelRunner: +class RerankModelRunner(BaseRerankRunner): def __init__(self, rerank_model_instance: ModelInstance) -> None: self.rerank_model_instance = rerank_model_instance diff --git a/api/core/rag/rerank/constants/rerank_mode.py b/api/core/rag/rerank/rerank_type.py similarity index 100% rename from api/core/rag/rerank/constants/rerank_mode.py rename to api/core/rag/rerank/rerank_type.py diff --git a/api/core/rag/rerank/weight_rerank.py b/api/core/rag/rerank/weight_rerank.py index 16d6b879a4..2e3fbe04e2 100644 --- a/api/core/rag/rerank/weight_rerank.py +++ b/api/core/rag/rerank/weight_rerank.py @@ -4,15 +4,16 @@ from typing import Optional import numpy as np -from core.embedding.cached_embedding import CacheEmbedding from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler +from core.rag.embedding.cached_embedding import CacheEmbedding from core.rag.models.document import Document from core.rag.rerank.entity.weight import VectorSetting, Weights +from core.rag.rerank.rerank_base import BaseRerankRunner -class WeightRerankRunner: +class WeightRerankRunner(BaseRerankRunner): def __init__(self, tenant_id: str, weights: Weights) -> None: self.tenant_id = tenant_id self.weights = weights diff --git a/api/core/rag/retrieval/router/multi_dataset_react_route.py b/api/core/rag/retrieval/router/multi_dataset_react_route.py index a0494adc60..68fab0c127 100644 --- a/api/core/rag/retrieval/router/multi_dataset_react_route.py +++ b/api/core/rag/retrieval/router/multi_dataset_react_route.py @@ -9,7 +9,7 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate from core.rag.retrieval.output_parser.react_output import ReactAction from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser -from core.workflow.nodes.llm.llm_node import LLMNode +from core.workflow.nodes.llm import LLMNode PREFIX = """Respond to the human as helpfully and accurately as possible. 
You have access to the following tools:""" diff --git a/api/core/tools/__base/tool.py b/api/core/tools/__base/tool.py index 8a96cd45a2..e08f4f64cf 100644 --- a/api/core/tools/__base/tool.py +++ b/api/core/tools/__base/tool.py @@ -3,6 +3,9 @@ from collections.abc import Generator from copy import deepcopy from typing import TYPE_CHECKING, Any, Optional +if TYPE_CHECKING: + from models.model import File + from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_entities import ( ToolEntity, @@ -10,10 +13,6 @@ from core.tools.entities.tool_entities import ( ToolParameter, ToolProviderType, ) -from core.tools.utils.tool_parameter_converter import ToolParameterConverter - -if TYPE_CHECKING: - from core.file.file_obj import FileVar class Tool(ABC): @@ -91,11 +90,9 @@ class Tool(ABC): """ # Temp fix for the issue that the tool parameters will be converted to empty while validating the credentials result = deepcopy(tool_parameters) - for parameter in self.entity.parameters: + for parameter in self.entity.parameters or []: if parameter.name in tool_parameters: - result[parameter.name] = ToolParameterConverter.cast_parameter_by_type( - tool_parameters[parameter.name], parameter.type - ) + result[parameter.name] = parameter.type.cast_value(tool_parameters[parameter.name]) return result @@ -171,9 +168,12 @@ class Tool(ABC): type=ToolInvokeMessage.MessageType.IMAGE, message=ToolInvokeMessage.TextMessage(text=image), save_as=save_as ) - def create_file_var_message(self, file_var: "FileVar") -> ToolInvokeMessage: + def create_file_message(self, file: "File") -> ToolInvokeMessage: return ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.FILE_VAR, message=None, meta={"file_var": file_var}, save_as="" + type=ToolInvokeMessage.MessageType.FILE, + message=ToolInvokeMessage.FileMessage(), + meta={"file": file}, + save_as="", ) def create_link_message(self, link: str, save_as: str = "") -> ToolInvokeMessage: diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index 504f7e012d..b58a0a2eb9 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -44,7 +44,7 @@ class ToolProviderApiEntity(BaseModel): for tool in tools: if tool.get("parameters"): for parameter in tool.get("parameters"): - if parameter.get("type") == ToolParameter.ToolParameterType.FILE.value: + if parameter.get("type") == ToolParameter.ToolParameterType.SYSTEM_FILES.value: parameter["type"] = "files" # ------------- diff --git a/api/core/tools/entities/tool_bundle.py b/api/core/tools/entities/tool_bundle.py index 0c15b2a371..ffeeabbc1c 100644 --- a/api/core/tools/entities/tool_bundle.py +++ b/api/core/tools/entities/tool_bundle.py @@ -18,7 +18,7 @@ class ApiToolBundle(BaseModel): # summary summary: Optional[str] = None # operation_id - operation_id: str = None + operation_id: Optional[str] = None # parameters parameters: Optional[list[ToolParameter]] = None # author diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index b98ee28fb4..5f6c593cc0 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -111,6 +111,9 @@ class ToolInvokeMessage(BaseModel): class BlobMessage(BaseModel): blob: bytes + class FileMessage(BaseModel): + pass + class VariableMessage(BaseModel): variable_name: str = Field(..., description="The name of the variable") variable_value: str = Field(..., description="The value of the variable")
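The renamed `create_file_message` pairs the new `FILE` message type with an intentionally empty `FileMessage` payload and carries the actual `File` object in `meta`. A condensed sketch of the message a tool now emits for a file result (`file` stands in for a `models.model.File`; inside a `Tool` subclass one would call `self.create_file_message(file)` rather than building this by hand):

```python
from core.tools.entities.tool_entities import ToolInvokeMessage


def file_result(file) -> ToolInvokeMessage:
    return ToolInvokeMessage(
        type=ToolInvokeMessage.MessageType.FILE,
        message=ToolInvokeMessage.FileMessage(),  # the payload class is deliberately empty
        meta={"file": file},  # the File object itself travels in meta
        save_as="",
    )
```

@@ -149,14 +152,14 @@ class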
ToolInvokeMessage(BaseModel): BLOB = "blob" JSON = "json" IMAGE_LINK = "image_link" - FILE_VAR = "file_var" VARIABLE = "variable" + FILE = "file" type: MessageType = MessageType.TEXT """ plain text, image url or link url """ - message: JsonMessage | TextMessage | BlobMessage | VariableMessage | None + message: JsonMessage | TextMessage | BlobMessage | VariableMessage | FileMessage | None meta: dict[str, Any] | None = None save_as: str = "" @@ -205,6 +208,67 @@ class ToolParameter(BaseModel): SELECT = CommonParameterType.SELECT.value SECRET_INPUT = CommonParameterType.SECRET_INPUT.value FILE = CommonParameterType.FILE.value + FILES = "files" + + # deprecated, should not use. + SYSTEM_FILES = "systme-files" + + def as_normal_type(self): + if self in { + ToolParameter.ToolParameterType.SECRET_INPUT, + ToolParameter.ToolParameterType.SELECT, + }: + return "string" + return self.value + + def cast_value(self, value: Any, /): + try: + match self: + case ( + ToolParameter.ToolParameterType.STRING + | ToolParameter.ToolParameterType.SECRET_INPUT + | ToolParameter.ToolParameterType.SELECT + ): + if value is None: + return "" + else: + return value if isinstance(value, str) else str(value) + + case ToolParameter.ToolParameterType.BOOLEAN: + if value is None: + return False + elif isinstance(value, str): + # Allowed YAML boolean value strings: https://yaml.org/type/bool.html + # and also '0' for False and '1' for True + match value.lower(): + case "true" | "yes" | "y" | "1": + return True + case "false" | "no" | "n" | "0": + return False + case _: + return bool(value) + else: + return value if isinstance(value, bool) else bool(value) + + case ToolParameter.ToolParameterType.NUMBER: + if isinstance(value, int | float): + return value + elif isinstance(value, str) and value: + if "." 
in value: + return float(value) + else: + return int(value) + case ( + ToolParameter.ToolParameterType.SYSTEM_FILES + | ToolParameter.ToolParameterType.FILE + | ToolParameter.ToolParameterType.FILES + ): + return value + case _: + return str(value) + + except Exception: + raise ValueError(f"The tool parameter value {value} is not of the correct type {self.as_normal_type()}.") class ToolParameterForm(Enum): SCHEMA = "schema" # should be set while adding tool
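A few illustrative conversions that follow directly from the `cast_value` implementation above:

```python
from core.tools.entities.tool_entities import ToolParameter

T = ToolParameter.ToolParameterType

# YAML-style boolean strings are normalized.
assert T.BOOLEAN.cast_value("yes") is True
assert T.BOOLEAN.cast_value("0") is False
# Numeric strings become float or int depending on a decimal point.
assert T.NUMBER.cast_value("3.5") == 3.5
assert T.NUMBER.cast_value("42") == 42
# String-like types stringify, with None mapped to "".
assert T.STRING.cast_value(None) == ""
```

diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index ab15970500..1b7c63193f 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -10,7 +10,8 @@ from yarl import URL from core.app.entities.app_invoke_entities import InvokeFrom from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler -from core.file.file_obj import FileTransferMethod +from core.file import FileType +from core.file.models import FileTransferMethod from core.ops.ops_trace_manager import TraceQueueManager from core.tools.__base.tool import Tool from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMessageBinary, ToolInvokeMeta, ToolParameter @@ -26,6 +27,7 @@ from core.tools.errors import ( from core.tools.utils.message_transformer import ToolFileMessageTransformer from core.tools.workflow_as_tool.tool import WorkflowTool from extensions.ext_database import db +from models.enums import CreatedByRole from models.model import Message, MessageFile @@ -295,7 +297,10 @@ class ToolEngine: @staticmethod def _create_message_files( - tool_messages: Iterable[ToolInvokeMessageBinary], agent_message: Message, invoke_from: InvokeFrom, user_id: str + tool_messages: Iterable[ToolInvokeMessageBinary], + agent_message: Message, + invoke_from: InvokeFrom, + user_id: str, ) -> list[tuple[MessageFile, str]]: """ Create message file @@ -306,29 +311,31 @@ class ToolEngine: result = [] for message in tool_messages: - file_type = "bin" if "image" in message.mimetype: - file_type = "image" + file_type = FileType.IMAGE elif "video" in message.mimetype: - file_type = "video" + file_type = FileType.VIDEO elif "audio" in message.mimetype: - file_type = "audio" - elif "text" in message.mimetype: - file_type = "text" - elif "pdf" in message.mimetype: - file_type = "pdf" - elif "zip" in message.mimetype: - file_type = "archive" - # ...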
+ file_type = FileType.AUDIO + elif "text" in message.mimetype or "pdf" in message.mimetype: + file_type = FileType.DOCUMENT + else: + file_type = FileType.CUSTOM + # extract tool file id from url + tool_file_id = message.url.split("/")[-1].split(".")[0] message_file = MessageFile( message_id=agent_message.id, type=file_type, - transfer_method=FileTransferMethod.TOOL_FILE.value, + transfer_method=FileTransferMethod.TOOL_FILE, belongs_to="assistant", url=message.url, - upload_file_id=None, - created_by_role=("account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"), + upload_file_id=tool_file_id, + created_by_role=( + CreatedByRole.ACCOUNT + if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} + else CreatedByRole.END_USER + ), created_by=user_id, ) diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py index f123e69c19..b1249a0ff5 100644 --- a/api/core/tools/tool_file_manager.py +++ b/api/core/tools/tool_file_manager.py @@ -4,7 +4,6 @@ import hmac import logging import os import time -from collections.abc import Generator from mimetypes import guess_extension, guess_type from typing import Optional, Union from uuid import uuid4 @@ -57,22 +56,32 @@ class ToolFileManager: @staticmethod def create_file_by_raw( - user_id: str, tenant_id: str, conversation_id: Optional[str], file_binary: bytes, mimetype: str + *, + user_id: str, + tenant_id: str, + conversation_id: Optional[str], + file_binary: bytes, + mimetype: str, ) -> ToolFile: - """ - create file - """ extension = guess_extension(mimetype) or ".bin" unique_name = uuid4().hex - filename = f"tools/{tenant_id}/{unique_name}{extension}" - storage.save(filename, file_binary) + filename = f"{unique_name}{extension}" + filepath = f"tools/{tenant_id}/{filename}" + storage.save(filepath, file_binary) tool_file = ToolFile( - user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, file_key=filename, mimetype=mimetype + user_id=user_id, + tenant_id=tenant_id, + conversation_id=conversation_id, + file_key=filepath, + mimetype=mimetype, + name=filename, + size=len(file_binary), ) db.session.add(tool_file) db.session.commit() + db.session.refresh(tool_file) return tool_file @@ -83,26 +92,31 @@ class ToolFileManager: file_url: str, conversation_id: Optional[str] = None, ) -> ToolFile: - """ - create file - """ # try to download image - response = get(file_url) - response.raise_for_status() - blob = response.content + try: + response = get(file_url) + response.raise_for_status() + blob = response.content + except Exception as e: + logger.error(f"Failed to download file from {file_url}: {e}") + raise + mimetype = guess_type(file_url)[0] or "octet/stream" extension = guess_extension(mimetype) or ".bin" unique_name = uuid4().hex - filename = f"tools/{tenant_id}/{unique_name}{extension}" - storage.save(filename, blob) + filename = f"{unique_name}{extension}" + filepath = f"tools/{tenant_id}/{filename}" + storage.save(filepath, blob) tool_file = ToolFile( user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, - file_key=filename, + file_key=filepath, mimetype=mimetype, original_url=file_url, + name=filename, + size=len(blob), ) db.session.add(tool_file) @@ -110,18 +124,6 @@ class ToolFileManager: return tool_file - @staticmethod - def create_file_by_key( - user_id: str, tenant_id: str, conversation_id: str, file_key: str, mimetype: str - ) -> ToolFile: - """ - create file - """ - tool_file = ToolFile( - user_id=user_id, tenant_id=tenant_id, 
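
The mimetype dispatch above now folds text and PDF into FileType.DOCUMENT and falls back to FileType.CUSTOM instead of ad-hoc strings; a standalone sketch of the same branch order (FileType imported from core.file, as in the diff):

from core.file import FileType

def guess_file_type(mimetype: str) -> FileType:
    # mirrors ToolEngine._create_message_files after this change
    if "image" in mimetype:
        return FileType.IMAGE
    if "video" in mimetype:
        return FileType.VIDEO
    if "audio" in mimetype:
        return FileType.AUDIO
    if "text" in mimetype or "pdf" in mimetype:
        return FileType.DOCUMENT
    return FileType.CUSTOM

assert guess_file_type("application/pdf") == FileType.DOCUMENT
assert guess_file_type("application/zip") == FileType.CUSTOM  # archives no longer get a dedicated type
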
conversation_id=conversation_id, file_key=file_key, mimetype=mimetype - ) - return tool_file - @staticmethod def get_file_binary(id: str) -> Union[tuple[bytes, str], None]: """ @@ -166,9 +168,12 @@ class ToolFileManager: # Check if message_file is not None if message_file is not None: # get tool file id - tool_file_id = message_file.url.split("/")[-1] - # trim extension - tool_file_id = tool_file_id.split(".")[0] + if message_file.url is not None: + tool_file_id = message_file.url.split("/")[-1] + # trim extension + tool_file_id = tool_file_id.split(".")[0] + else: + tool_file_id = None else: tool_file_id = None @@ -188,7 +193,7 @@ class ToolFileManager: return blob, tool_file.mimetype @staticmethod - def get_file_generator_by_tool_file_id(tool_file_id: str) -> Union[tuple[Generator, str], None]: + def get_file_generator_by_tool_file_id(tool_file_id: str): """ get file binary @@ -205,11 +210,11 @@ class ToolFileManager: ) if not tool_file: - return None + return None, None - generator = storage.load_stream(tool_file.file_key) + stream = storage.load_stream(tool_file.file_key) - return generator, tool_file.mimetype + return stream, tool_file # init tool_file_parser diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 1d83934130..089c970b13 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -38,8 +38,10 @@ from core.tools.entities.tool_entities import ( ) from core.tools.errors import ToolProviderNotFoundError from core.tools.tool_label_manager import ToolLabelManager -from core.tools.utils.configuration import ProviderConfigEncrypter, ToolParameterConfigurationManager -from core.tools.utils.tool_parameter_converter import ToolParameterConverter +from core.tools.utils.configuration import ( + ProviderConfigEncrypter, + ToolParameterConfigurationManager, +) from core.tools.workflow_as_tool.tool import WorkflowTool from extensions.ext_database import db from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider @@ -251,7 +253,7 @@ class ToolManager: raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found") @classmethod - def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict) -> Union[str, int, float, bool]: + def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict): """ init runtime parameter """ @@ -270,7 +272,7 @@ class ToolManager: f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}" ) - return ToolParameterConverter.cast_parameter_by_type(parameter_value, parameter_rule.type) + return parameter_rule.type.cast_value(parameter_value) @classmethod def get_agent_tool_runtime( @@ -295,7 +297,11 @@ class ToolManager: parameters = tool_entity.get_merged_runtime_parameters() for parameter in parameters: # check file types - if parameter.type == ToolParameter.ToolParameterType.FILE: + if parameter.type in { + ToolParameter.ToolParameterType.SYSTEM_FILES, + ToolParameter.ToolParameterType.FILE, + ToolParameter.ToolParameterType.FILES, + }: raise ValueError(f"file type parameter {parameter.name} not supported in agent") if parameter.form == ToolParameter.ToolParameterForm.FORM: diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index 73c7ef44b3..5bcd2ec61b 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -3,7 +3,7 @@ from collections.abc import Generator from mimetypes import guess_extension 
from typing import Optional -from core.file.file_obj import FileTransferMethod, FileType, FileVar +from core.file import File, FileTransferMethod, FileType from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool_file_manager import ToolFileManager @@ -25,7 +25,7 @@ class ToolFileMessageTransformer: for message in messages: if message.type in {ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.LINK}: yield message - elif message.type == ToolInvokeMessage.MessageType.IMAGE: + elif message.type == ToolInvokeMessage.MessageType.IMAGE and isinstance(message.message, str): # try to download image try: if not conversation_id: @@ -68,6 +68,8 @@ class ToolFileMessageTransformer: if not isinstance(message.message, ToolInvokeMessage.BlobMessage): raise ValueError("unexpected message type") + # FIXME: should do a type check here. + assert isinstance(message.message, bytes) file = ToolFileManager.create_file_by_raw( user_id=user_id, tenant_id=tenant_id, @@ -76,8 +78,7 @@ class ToolFileMessageTransformer: mimetype=mimetype, ) - extension = guess_extension(file.mimetype) or ".bin" - url = cls.get_tool_file_url(file.id, extension) + url = cls.get_tool_file_url(tool_file_id=file.id, extension=guess_extension(file.mimetype)) # check if file is image if "image" in mimetype: @@ -94,17 +95,14 @@ class ToolFileMessageTransformer: save_as=message.save_as, meta=message.meta.copy() if message.meta is not None else {}, ) - elif message.type == ToolInvokeMessage.MessageType.FILE_VAR: - assert message.meta - - file_var: FileVar | None = message.meta.get("file_var") - if file_var: - if file_var.transfer_method == FileTransferMethod.TOOL_FILE: - assert file_var.related_id - assert file_var.extension - - url = cls.get_tool_file_url(file_var.related_id, file_var.extension) - if file_var.type == FileType.IMAGE: + elif message.type == ToolInvokeMessage.MessageType.FILE: + assert message.meta is not None + file = message.meta.get("file") + if isinstance(file, File): + if file.transfer_method == FileTransferMethod.TOOL_FILE: + assert file.related_id is not None + url = cls.get_tool_file_url(tool_file_id=file.related_id, extension=file.extension) + if file.type == FileType.IMAGE: yield ToolInvokeMessage( type=ToolInvokeMessage.MessageType.IMAGE_LINK, message=ToolInvokeMessage.TextMessage(text=url), @@ -118,9 +116,11 @@ class ToolFileMessageTransformer: save_as=message.save_as, meta=message.meta.copy() if message.meta is not None else {}, ) + else: + yield message else: yield message @classmethod - def get_tool_file_url(cls, tool_file_id: str, extension: str) -> str: + def get_tool_file_url(cls, tool_file_id: str, extension: Optional[str]) -> str: return f'/files/tools/{tool_file_id}{extension or ".bin"}' diff --git a/api/core/tools/utils/tool_parameter_converter.py b/api/core/tools/utils/tool_parameter_converter.py deleted file mode 100644 index 6f7610651c..0000000000 --- a/api/core/tools/utils/tool_parameter_converter.py +++ /dev/null @@ -1,71 +0,0 @@ -from typing import Any - -from core.tools.entities.tool_entities import ToolParameter - - -class ToolParameterConverter: - @staticmethod - def get_parameter_type(parameter_type: str | ToolParameter.ToolParameterType) -> str: - match parameter_type: - case ( - ToolParameter.ToolParameterType.STRING - | ToolParameter.ToolParameterType.SECRET_INPUT - | ToolParameter.ToolParameterType.SELECT - ): - return "string" - - case ToolParameter.ToolParameterType.BOOLEAN: - return "boolean" - - case ToolParameter.ToolParameterType.NUMBER: - 
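
The tool-file id embedded by get_tool_file_url above is exactly what ToolEngine._create_message_files now splits back out of the URL, so the two changes round-trip; a small sketch:

from core.tools.utils.message_transformer import ToolFileMessageTransformer

url = ToolFileMessageTransformer.get_tool_file_url(tool_file_id="abc123", extension=".png")
assert url == "/files/tools/abc123.png"

# the reverse operation, used when recording MessageFile.upload_file_id
tool_file_id = url.split("/")[-1].split(".")[0]
assert tool_file_id == "abc123"
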
return "number" - - case _: - raise ValueError(f"Unsupported parameter type {parameter_type}") - - @staticmethod - def cast_parameter_by_type(value: Any, parameter_type: str) -> Any: - # convert tool parameter config to correct type - try: - match parameter_type: - case ( - ToolParameter.ToolParameterType.STRING - | ToolParameter.ToolParameterType.SECRET_INPUT - | ToolParameter.ToolParameterType.SELECT - ): - if value is None: - return "" - else: - return value if isinstance(value, str) else str(value) - - case ToolParameter.ToolParameterType.BOOLEAN: - if value is None: - return False - elif isinstance(value, str): - # Allowed YAML boolean value strings: https://yaml.org/type/bool.html - # and also '0' for False and '1' for True - match value.lower(): - case "true" | "yes" | "y" | "1": - return True - case "false" | "no" | "n" | "0": - return False - case _: - return bool(value) - else: - return value if isinstance(value, bool) else bool(value) - - case ToolParameter.ToolParameterType.NUMBER: - if isinstance(value, int) | isinstance(value, float): - return value - elif isinstance(value, str) and value != "": - if "." in value: - return float(value) - else: - return int(value) - case ToolParameter.ToolParameterType.FILE: - return value - case _: - return str(value) - - except Exception: - raise ValueError(f"The tool parameter value {value} is not in correct type of {parameter_type}.") diff --git a/api/core/tools/utils/workflow_configuration_sync.py b/api/core/tools/utils/workflow_configuration_sync.py index df09609402..bc9bc89161 100644 --- a/api/core/tools/utils/workflow_configuration_sync.py +++ b/api/core/tools/utils/workflow_configuration_sync.py @@ -1,4 +1,5 @@ -from collections.abc import Mapping +from collections.abc import Mapping, Sequence +from typing import Any from core.app.app_config.entities import VariableEntity from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration @@ -6,16 +7,12 @@ from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration class WorkflowToolConfigurationUtils: @classmethod - def check_parameter_configurations(cls, configurations: list[dict]): - """ - check parameter configurations - """ + def check_parameter_configurations(cls, configurations: Mapping[str, Any]): for configuration in configurations: - if not WorkflowToolParameterConfiguration(**configuration): - raise ValueError("invalid parameter configuration") + WorkflowToolParameterConfiguration.model_validate(configuration) @classmethod - def get_workflow_graph_variables(cls, graph: Mapping) -> list[VariableEntity]: + def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]: """ get workflow graph variables """ diff --git a/api/core/tools/utils/yaml_utils.py b/api/core/tools/utils/yaml_utils.py index 99b9f80499..42c7f85bc6 100644 --- a/api/core/tools/utils/yaml_utils.py +++ b/api/core/tools/utils/yaml_utils.py @@ -1,4 +1,5 @@ import logging +from pathlib import Path from typing import Any import yaml @@ -17,15 +18,18 @@ def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any :param default_value: the value returned when errors ignored :return: an object of the YAML content """ - try: - with open(file_path, encoding="utf-8") as yaml_file: - try: - yaml_content = yaml.safe_load(yaml_file) - return yaml_content or default_value - except Exception as e: - raise YAMLError(f"Failed to load YAML file {file_path}: {e}") - except Exception as e: + if not file_path or not Path(file_path).exists(): if 
ignore_error: return default_value else: - raise e + raise FileNotFoundError(f"File not found: {file_path}") + + with open(file_path, encoding="utf-8") as yaml_file: + try: + yaml_content = yaml.safe_load(yaml_file) + return yaml_content or default_value + except Exception as e: + if ignore_error: + return default_value + else: + raise YAMLError(f"Failed to load YAML file {file_path}: {e}") from e diff --git a/api/core/tools/workflow_as_tool/provider.py b/api/core/tools/workflow_as_tool/provider.py index dec353ec93..aa8af1b16e 100644 --- a/api/core/tools/workflow_as_tool/provider.py +++ b/api/core/tools/workflow_as_tool/provider.py @@ -14,6 +14,8 @@ from core.tools.entities.tool_entities import ( ToolIdentity, ToolParameter, ToolParameterOption, + ToolProviderEntity, + ToolProviderIdentity, ToolProviderType, ) from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils @@ -28,6 +30,8 @@ VARIABLE_TO_PARAMETER_TYPE_MAPPING = { VariableEntityType.PARAGRAPH: ToolParameter.ToolParameterType.STRING, VariableEntityType.SELECT: ToolParameter.ToolParameterType.SELECT, VariableEntityType.NUMBER: ToolParameter.ToolParameterType.NUMBER, + VariableEntityType.FILE: ToolParameter.ToolParameterType.FILE, + VariableEntityType.FILE_LIST: ToolParameter.ToolParameterType.FILES, } @@ -35,6 +39,10 @@ class WorkflowToolProviderController(ToolProviderController): provider_id: str tools: list[WorkflowTool] = Field(default_factory=list) + def __init__(self, entity: ToolProviderEntity, provider_id: str): + super().__init__(entity=entity) + self.provider_id = provider_id + @classmethod def from_db(cls, db_provider: WorkflowToolProvider) -> "WorkflowToolProviderController": app = db_provider.app @@ -43,17 +51,17 @@ class WorkflowToolProviderController(ToolProviderController): raise ValueError("app not found") controller = WorkflowToolProviderController( - **{ - "identity": { - "author": db_provider.user.name if db_provider.user_id and db_provider.user else "", - "name": db_provider.label, - "label": {"en_US": db_provider.label, "zh_Hans": db_provider.label}, - "description": {"en_US": db_provider.description, "zh_Hans": db_provider.description}, - "icon": db_provider.icon, - }, - "credentials_schema": {}, - "provider_id": db_provider.id or "", - } + entity=ToolProviderEntity( + identity=ToolProviderIdentity( + author=db_provider.user.name if db_provider.user_id and db_provider.user else "", + name=db_provider.label, + label=I18nObject(en_US=db_provider.label, zh_Hans=db_provider.label), + description=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description), + icon=db_provider.icon, + ), + credentials_schema=[], + ), + provider_id=db_provider.id, ) # init tools @@ -121,7 +129,6 @@ class WorkflowToolProviderController(ToolProviderController): llm_description=parameter.description, required=variable.required, options=options, - default=variable.default, ) ) elif features.file_upload: @@ -130,7 +137,7 @@ class WorkflowToolProviderController(ToolProviderController): name=parameter.name, label=I18nObject(en_US=parameter.name, zh_Hans=parameter.name), human_description=I18nObject(en_US=parameter.description, zh_Hans=parameter.description), - type=ToolParameter.ToolParameterType.FILE, + type=ToolParameter.ToolParameterType.SYSTEM_FILES, llm_description=parameter.description, required=False, form=parameter.form, diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index 677e52b5ba..d7838eae57 100644 --- 
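
A sketch of the reworked load_yaml_file contract above: a missing path is now reported as FileNotFoundError instead of a generic re-raise, and ignore_error guards both the existence check and the parse:

from core.tools.utils.yaml_utils import load_yaml_file

# missing file, errors ignored (the default): the fallback value comes back
assert load_yaml_file("/no/such/provider.yaml", default_value={}) == {}

# missing file, errors surfaced
try:
    load_yaml_file("/no/such/provider.yaml", ignore_error=False)
except FileNotFoundError:
    pass
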
a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -3,7 +3,7 @@ import logging from collections.abc import Generator from typing import Any, Optional, Union -from core.file.file_obj import FileTransferMethod, FileVar +from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType @@ -71,14 +71,13 @@ class WorkflowTool(Tool): workflow = self._get_workflow(app_id=self.workflow_app_id, version=self.version) # transform the tool parameters - tool_parameters, files = self._transform_args(tool_parameters) + tool_parameters, files = self._transform_args(tool_parameters=tool_parameters) from core.app.apps.workflow.app_generator import WorkflowAppGenerator generator = WorkflowAppGenerator() - - assert self.runtime - assert self.runtime.invoke_from + assert self.runtime is not None + assert self.runtime.invoke_from is not None result = generator.generate( app_model=app, @@ -105,7 +104,7 @@ class WorkflowTool(Tool): else: outputs, files = self._extract_files(outputs) for file in files: - yield self.create_file_var_message(file) + yield self.create_file_message(file) yield self.create_text_message(json.dumps(outputs, ensure_ascii=False)) yield self.create_json_message(outputs) @@ -181,22 +180,22 @@ class WorkflowTool(Tool): parameters_result = {} files = [] for parameter in parameter_rules: - if parameter.type == ToolParameter.ToolParameterType.FILE: + if parameter.type == ToolParameter.ToolParameterType.SYSTEM_FILES: file = tool_parameters.get(parameter.name) if file: try: - file_var_list = [FileVar(**f) for f in file] - for file_var in file_var_list: - file_dict: dict[str, Any] = { - "transfer_method": file_var.transfer_method.value, - "type": file_var.type.value, + file_var_list = [File.model_validate(f) for f in file] + for file in file_var_list: + file_dict: dict[str, str | None] = { + "transfer_method": file.transfer_method.value, + "type": file.type.value, } - if file_var.transfer_method == FileTransferMethod.TOOL_FILE: - file_dict["tool_file_id"] = file_var.related_id - elif file_var.transfer_method == FileTransferMethod.LOCAL_FILE: - file_dict["upload_file_id"] = file_var.related_id - elif file_var.transfer_method == FileTransferMethod.REMOTE_URL: - file_dict["url"] = file_var.preview_url + if file.transfer_method == FileTransferMethod.TOOL_FILE: + file_dict["tool_file_id"] = file.related_id + elif file.transfer_method == FileTransferMethod.LOCAL_FILE: + file_dict["upload_file_id"] = file.related_id + elif file.transfer_method == FileTransferMethod.REMOTE_URL: + file_dict["url"] = file.generate_url() files.append(file_dict) except Exception as e: @@ -206,7 +205,7 @@ class WorkflowTool(Tool): return parameters_result, files - def _extract_files(self, outputs: dict) -> tuple[dict, list[FileVar]]: + def _extract_files(self, outputs: dict) -> tuple[dict, list[File]]: """ extract files from the result @@ -217,17 +216,13 @@ class WorkflowTool(Tool): result = {} for key, value in outputs.items(): if isinstance(value, list): - has_file = False for item in value: - if isinstance(item, dict) and item.get("__variant") == "FileVar": - try: - files.append(FileVar(**item)) - has_file = True - except Exception as e: - pass - if has_file: - continue + if isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY: + file = 
File.model_validate(item) + files.append(file) + elif isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY: + file = File.model_validate(value) + files.append(file) result[key] = value - return result, files diff --git a/api/core/app/segments/__init__.py b/api/core/variables/__init__.py similarity index 78% rename from api/core/app/segments/__init__.py rename to api/core/variables/__init__.py index 652ef243b4..87f9e3ed45 100644 --- a/api/core/app/segments/__init__.py +++ b/api/core/variables/__init__.py @@ -1,7 +1,12 @@ from .segment_group import SegmentGroup from .segments import ( ArrayAnySegment, + ArrayFileSegment, + ArrayNumberSegment, + ArrayObjectSegment, ArraySegment, + ArrayStringSegment, + FileSegment, FloatSegment, IntegerSegment, NoneSegment, @@ -15,6 +20,7 @@ from .variables import ( ArrayNumberVariable, ArrayObjectVariable, ArrayStringVariable, + FileVariable, FloatVariable, IntegerVariable, NoneVariable, @@ -46,4 +52,10 @@ __all__ = [ "ArrayNumberVariable", "ArrayObjectVariable", "ArraySegment", + "ArrayFileSegment", + "ArrayNumberSegment", + "ArrayObjectSegment", + "ArrayStringSegment", + "FileSegment", + "FileVariable", ] diff --git a/api/core/app/segments/exc.py b/api/core/variables/exc.py similarity index 100% rename from api/core/app/segments/exc.py rename to api/core/variables/exc.py diff --git a/api/core/app/segments/segment_group.py b/api/core/variables/segment_group.py similarity index 100% rename from api/core/app/segments/segment_group.py rename to api/core/variables/segment_group.py diff --git a/api/core/app/segments/segments.py b/api/core/variables/segments.py similarity index 79% rename from api/core/app/segments/segments.py rename to api/core/variables/segments.py index b26b3c8291..782798411e 100644 --- a/api/core/app/segments/segments.py +++ b/api/core/variables/segments.py @@ -5,6 +5,8 @@ from typing import Any from pydantic import BaseModel, ConfigDict, field_validator +from core.file import File + from .types import SegmentType @@ -39,6 +41,9 @@ class Segment(BaseModel): @property def size(self) -> int: + """ + Return the size of the value in bytes. 
+ """ return sys.getsizeof(self.value) def to_object(self) -> Any: @@ -99,13 +104,27 @@ class ArraySegment(Segment): def markdown(self) -> str: items = [] for item in self.value: - if hasattr(item, "to_markdown"): - items.append(item.to_markdown()) - else: - items.append(str(item)) + items.append(str(item)) return "\n".join(items) +class FileSegment(Segment): + value_type: SegmentType = SegmentType.FILE + value: File + + @property + def markdown(self) -> str: + return self.value.markdown + + @property + def log(self) -> str: + return str(self.value) + + @property + def text(self) -> str: + return str(self.value) + + class ArrayAnySegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_ANY value: Sequence[Any] @@ -124,3 +143,15 @@ class ArrayNumberSegment(ArraySegment): class ArrayObjectSegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_OBJECT value: Sequence[Mapping[str, Any]] + + +class ArrayFileSegment(ArraySegment): + value_type: SegmentType = SegmentType.ARRAY_FILE + value: Sequence[File] + + @property + def markdown(self) -> str: + items = [] + for item in self.value: + items.append(item.markdown) + return "\n".join(items) diff --git a/api/core/app/segments/types.py b/api/core/variables/types.py similarity index 86% rename from api/core/app/segments/types.py rename to api/core/variables/types.py index 9cf0856df5..53c2e8a3aa 100644 --- a/api/core/app/segments/types.py +++ b/api/core/variables/types.py @@ -11,5 +11,7 @@ class SegmentType(str, Enum): ARRAY_NUMBER = "array[number]" ARRAY_OBJECT = "array[object]" OBJECT = "object" + FILE = "file" + ARRAY_FILE = "array[file]" GROUP = "group" diff --git a/api/core/app/segments/variables.py b/api/core/variables/variables.py similarity index 95% rename from api/core/app/segments/variables.py rename to api/core/variables/variables.py index f0e403ab8d..ddc6914192 100644 --- a/api/core/app/segments/variables.py +++ b/api/core/variables/variables.py @@ -7,6 +7,7 @@ from .segments import ( ArrayNumberSegment, ArrayObjectSegment, ArrayStringSegment, + FileSegment, FloatSegment, IntegerSegment, NoneSegment, @@ -73,3 +74,7 @@ class SecretVariable(StringVariable): class NoneVariable(NoneSegment, Variable): value_type: SegmentType = SegmentType.NONE value: None = None + + +class FileVariable(FileSegment, Variable): + pass diff --git a/api/core/workflow/callbacks/__init__.py b/api/core/workflow/callbacks/__init__.py index e69de29bb2..403fbbaa2f 100644 --- a/api/core/workflow/callbacks/__init__.py +++ b/api/core/workflow/callbacks/__init__.py @@ -0,0 +1,7 @@ +from .base_workflow_callback import WorkflowCallback +from .workflow_logging_callback import WorkflowLoggingCallback + +__all__ = [ + "WorkflowLoggingCallback", + "WorkflowCallback", +] diff --git a/api/core/app/apps/workflow_logging_callback.py b/api/core/workflow/callbacks/workflow_logging_callback.py similarity index 99% rename from api/core/app/apps/workflow_logging_callback.py rename to api/core/workflow/callbacks/workflow_logging_callback.py index 60683b0f21..17913de7b0 100644 --- a/api/core/app/apps/workflow_logging_callback.py +++ b/api/core/workflow/callbacks/workflow_logging_callback.py @@ -1,7 +1,6 @@ from typing import Optional from core.model_runtime.utils.encoders import jsonable_encoder -from core.workflow.callbacks.base_workflow_callback import WorkflowCallback from core.workflow.graph_engine.entities.event import ( GraphEngineEvent, GraphRunFailedEvent, @@ -20,6 +19,8 @@ from core.workflow.graph_engine.entities.event import ( 
ParallelBranchRunSucceededEvent, ) +from .base_workflow_callback import WorkflowCallback + _TEXT_COLOR_MAPPING = { "blue": "36;1", "yellow": "33;1", diff --git a/api/core/workflow/constants.py b/api/core/workflow/constants.py new file mode 100644 index 0000000000..e3fe17c284 --- /dev/null +++ b/api/core/workflow/constants.py @@ -0,0 +1,3 @@ +SYSTEM_VARIABLE_NODE_ID = "sys" +ENVIRONMENT_VARIABLE_NODE_ID = "env" +CONVERSATION_VARIABLE_NODE_ID = "conversation" diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py index 5353b99ed3..0131bb342b 100644 --- a/api/core/workflow/entities/node_entities.py +++ b/api/core/workflow/entities/node_entities.py @@ -1,52 +1,14 @@ +from collections.abc import Mapping from enum import Enum from typing import Any, Optional from pydantic import BaseModel from core.model_runtime.entities.llm_entities import LLMUsage -from models import WorkflowNodeExecutionStatus +from models.workflow import WorkflowNodeExecutionStatus -class NodeType(Enum): - """ - Node Types. - """ - - START = "start" - END = "end" - ANSWER = "answer" - LLM = "llm" - KNOWLEDGE_RETRIEVAL = "knowledge-retrieval" - IF_ELSE = "if-else" - CODE = "code" - TEMPLATE_TRANSFORM = "template-transform" - QUESTION_CLASSIFIER = "question-classifier" - HTTP_REQUEST = "http-request" - TOOL = "tool" - VARIABLE_AGGREGATOR = "variable-aggregator" - # TODO: merge this into VARIABLE_AGGREGATOR - VARIABLE_ASSIGNER = "variable-assigner" - LOOP = "loop" - ITERATION = "iteration" - ITERATION_START = "iteration-start" # fake start node for iteration - PARAMETER_EXTRACTOR = "parameter-extractor" - CONVERSATION_VARIABLE_ASSIGNER = "assigner" - - @classmethod - def value_of(cls, value: str) -> "NodeType": - """ - Get value of given node type. - - :param value: node type value - :return: node type - """ - for node_type in cls: - if node_type.value == value: - return node_type - raise ValueError(f"invalid node type value {value}") - - -class NodeRunMetadataKey(Enum): +class NodeRunMetadataKey(str, Enum): """ Node Run Metadata Key. 
""" @@ -70,7 +32,7 @@ class NodeRunResult(BaseModel): status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING - inputs: Optional[dict[str, Any]] = None # node inputs + inputs: Optional[Mapping[str, Any]] = None # node inputs process_data: Optional[dict[str, Any]] = None # process data outputs: Optional[dict[str, Any]] = None # node outputs metadata: Optional[dict[NodeRunMetadataKey, Any]] = None # node metadata @@ -79,24 +41,3 @@ class NodeRunResult(BaseModel): edge_source_handle: Optional[str] = None # source handle id of node with multiple branches error: Optional[str] = None # error message if status is failed - - -class UserFrom(Enum): - """ - User from - """ - - ACCOUNT = "account" - END_USER = "end-user" - - @classmethod - def value_of(cls, value: str) -> "UserFrom": - """ - Value of - :param value: value - :return: - """ - for item in cls: - if item.value == value: - return item - raise ValueError(f"Invalid value: {value}") diff --git a/api/core/workflow/entities/variable_entities.py b/api/core/workflow/entities/variable_entities.py index 1dfb1852f8..8f4c2d7975 100644 --- a/api/core/workflow/entities/variable_entities.py +++ b/api/core/workflow/entities/variable_entities.py @@ -1,3 +1,5 @@ +from collections.abc import Sequence + from pydantic import BaseModel @@ -7,4 +9,4 @@ class VariableSelector(BaseModel): """ variable: str - value_selector: list[str] + value_selector: Sequence[str] diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index b94b7f7198..5f932c0a8e 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -1,20 +1,23 @@ +import re from collections import defaultdict from collections.abc import Mapping, Sequence from typing import Any, Union -from pydantic import BaseModel, Field, model_validator +from pydantic import BaseModel, Field from typing_extensions import deprecated -from core.app.segments import Segment, Variable, factory -from core.file.file_obj import FileVar -from core.workflow.enums import SystemVariableKey +from core.file import File, FileAttribute, file_manager +from core.variables import Segment, SegmentGroup, Variable +from core.variables.segments import FileSegment +from factories import variable_factory -VariableValue = Union[str, int, float, dict, list, FileVar] +from ..constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from ..enums import SystemVariableKey + +VariableValue = Union[str, int, float, dict, list, File] -SYSTEM_VARIABLE_NODE_ID = "sys" -ENVIRONMENT_VARIABLE_NODE_ID = "env" -CONVERSATION_VARIABLE_NODE_ID = "conversation" +VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}") class VariablePool(BaseModel): @@ -23,46 +26,63 @@ class VariablePool(BaseModel): # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the # elements of the selector except the first one. variable_dictionary: dict[str, dict[int, Segment]] = Field( - description="Variables mapping", default=defaultdict(dict) + description="Variables mapping", + default=defaultdict(dict), ) - # TODO: This user inputs is not used for pool. 
user_inputs: Mapping[str, Any] = Field( description="User inputs", ) - system_variables: Mapping[SystemVariableKey, Any] = Field( description="System variables", ) + environment_variables: Sequence[Variable] = Field( + description="Environment variables.", + default_factory=list, + ) + conversation_variables: Sequence[Variable] = Field( + description="Conversation variables.", + default_factory=list, + ) - environment_variables: Sequence[Variable] = Field(description="Environment variables.", default_factory=list) + def __init__( + self, + *, + system_variables: Mapping[SystemVariableKey, Any] | None = None, + user_inputs: Mapping[str, Any] | None = None, + environment_variables: Sequence[Variable] | None = None, + conversation_variables: Sequence[Variable] | None = None, + **kwargs, + ): + environment_variables = environment_variables or [] + conversation_variables = conversation_variables or [] + user_inputs = user_inputs or {} + system_variables = system_variables or {} - conversation_variables: Sequence[Variable] | None = None + super().__init__( + system_variables=system_variables, + user_inputs=user_inputs, + environment_variables=environment_variables, + conversation_variables=conversation_variables, + **kwargs, + ) - @model_validator(mode="after") - def val_model_after(self): - """ - Append system variables - :return: - """ - # Add system variables to the variable pool for key, value in self.system_variables.items(): self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value) - # Add environment variables to the variable pool - for var in self.environment_variables or []: + for var in self.environment_variables: self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var) - # Add conversation variables to the variable pool - for var in self.conversation_variables or []: + for var in self.conversation_variables: self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var) - return self - def add(self, selector: Sequence[str], value: Any, /) -> None: """ Adds a variable to the variable pool. + NOTE: You should not add a non-Segment value to the variable pool + even if it is allowed now. + Args: selector (Sequence[str]): The selector for the variable. value (VariableValue): The value of the variable. @@ -82,7 +102,7 @@ class VariablePool(BaseModel): if isinstance(value, Segment): v = value else: - v = factory.build_segment(value) + v = variable_factory.build_segment(value) hash_key = hash(tuple(selector[1:])) self.variable_dictionary[selector[0]][hash_key] = v @@ -101,10 +121,19 @@ class VariablePool(BaseModel): ValueError: If the selector is invalid. """ if len(selector) < 2: - raise ValueError("Invalid selector") + return None + hash_key = hash(tuple(selector[1:])) value = self.variable_dictionary[selector[0]].get(hash_key) + if value is None: + selector, attr = selector[:-1], selector[-1] + value = self.get(selector) + if isinstance(value, FileSegment): + attr = FileAttribute(attr) + attr_value = file_manager.get_attr(file=value.value, attr=attr) + return variable_factory.build_segment(attr_value) + return value @deprecated("This method is deprecated, use `get` instead.") @@ -145,14 +174,18 @@ class VariablePool(BaseModel): hash_key = hash(tuple(selector[1:])) self.variable_dictionary[selector[0]].pop(hash_key, None) - def remove_node(self, node_id: str, /): - """ - Remove all variables associated with a given node id. + def convert_template(self, template: str, /): + parts = VARIABLE_PATTERN.split(template) + segments = [] + for part in filter(lambda x: x, parts): + if "." 
in part and (variable := self.get(part.split("."))): + segments.append(variable) + else: + segments.append(variable_factory.build_segment(part)) + return SegmentGroup(value=segments) - Args: - node_id (str): The node id to remove. - - Returns: - None - """ - self.variable_dictionary.pop(node_id, None) + def get_file(self, selector: Sequence[str], /) -> FileSegment | None: + segment = self.get(selector) + if isinstance(segment, FileSegment): + return segment + return None diff --git a/api/core/workflow/entities/workflow_entities.py b/api/core/workflow/entities/workflow_entities.py index 0a1eb57de4..da56af1407 100644 --- a/api/core/workflow/entities/workflow_entities.py +++ b/api/core/workflow/entities/workflow_entities.py @@ -3,12 +3,13 @@ from typing import Optional from pydantic import BaseModel from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities.base_node_data_entities import BaseIterationState -from core.workflow.entities.node_entities import NodeRunResult -from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.base_node import BaseNode, UserFrom +from core.workflow.nodes.base import BaseIterationState, BaseNode +from models.enums import UserFrom from models.workflow import Workflow, WorkflowType +from .node_entities import NodeRunResult +from .variable_pool import VariablePool + class WorkflowNodeAndResult: node: BaseNode diff --git a/api/core/workflow/errors.py b/api/core/workflow/errors.py index 07cbcd981e..bd4ccc1072 100644 --- a/api/core/workflow/errors.py +++ b/api/core/workflow/errors.py @@ -1,4 +1,4 @@ -from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.base import BaseNode class WorkflowNodeRunFailedError(Exception): diff --git a/api/core/workflow/graph_engine/__init__.py b/api/core/workflow/graph_engine/__init__.py index e69de29bb2..2fee3d7fad 100644 --- a/api/core/workflow/graph_engine/__init__.py +++ b/api/core/workflow/graph_engine/__init__.py @@ -0,0 +1,3 @@ +from .entities import Graph, GraphInitParams, GraphRuntimeState, RuntimeRouteState + +__all__ = ["Graph", "GraphInitParams", "GraphRuntimeState", "RuntimeRouteState"] diff --git a/api/core/workflow/graph_engine/condition_handlers/condition_handler.py b/api/core/workflow/graph_engine/condition_handlers/condition_handler.py index eda5fe079c..bc3a15bd00 100644 --- a/api/core/workflow/graph_engine/condition_handlers/condition_handler.py +++ b/api/core/workflow/graph_engine/condition_handlers/condition_handler.py @@ -18,11 +18,10 @@ class ConditionRunConditionHandlerHandler(RunConditionHandler): # process condition condition_processor = ConditionProcessor() - input_conditions, group_result = condition_processor.process_conditions( - variable_pool=graph_runtime_state.variable_pool, conditions=self.condition.conditions + _, _, final_result = condition_processor.process_conditions( + variable_pool=graph_runtime_state.variable_pool, + conditions=self.condition.conditions, + operator="and", ) - # Apply the logical operator for the current case - compare_result = all(group_result) - - return compare_result + return final_result diff --git a/api/core/workflow/graph_engine/entities/__init__.py b/api/core/workflow/graph_engine/entities/__init__.py index e69de29bb2..6331a0b723 100644 --- a/api/core/workflow/graph_engine/entities/__init__.py +++ b/api/core/workflow/graph_engine/entities/__init__.py @@ -0,0 +1,6 @@ +from .graph import Graph +from .graph_init_params import GraphInitParams +from .graph_runtime_state import GraphRuntimeState 
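
Two pieces of the VariablePool rework above in one sketch: the new constructor defaults every field group, convert_template resolves {{#node_id.variable#}} references against the pool while passing plain text through, and too-short selectors now return None instead of raising (SegmentGroup.text is assumed to join the per-segment text):

from core.workflow.entities.variable_pool import VariablePool

pool = VariablePool()  # system/user/environment/conversation groups all default to empty
pool.add(("node1", "greeting"), "hello")

group = pool.convert_template("say {{#node1.greeting#}}!")
assert group.text == "say hello!"

assert pool.get(("node1",)) is None  # invalid selector, no longer a ValueError
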
+from .runtime_route_state import RuntimeRouteState + +__all__ = ["Graph", "GraphInitParams", "GraphRuntimeState", "RuntimeRouteState"] diff --git a/api/core/workflow/graph_engine/entities/event.py b/api/core/workflow/graph_engine/entities/event.py index 06dc4cb8f4..86d89e0a32 100644 --- a/api/core/workflow/graph_engine/entities/event.py +++ b/api/core/workflow/graph_engine/entities/event.py @@ -3,9 +3,9 @@ from typing import Any, Optional from pydantic import BaseModel, Field -from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeType from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState +from core.workflow.nodes import NodeType +from core.workflow.nodes.base import BaseNodeData class GraphEngineEvent(BaseModel): diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py index 1175f4af2a..d87c039409 100644 --- a/api/core/workflow/graph_engine/entities/graph.py +++ b/api/core/workflow/graph_engine/entities/graph.py @@ -4,8 +4,8 @@ from typing import Any, Optional, cast from pydantic import BaseModel, Field -from core.workflow.entities.node_entities import NodeType from core.workflow.graph_engine.entities.run_condition import RunCondition +from core.workflow.nodes import NodeType from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter from core.workflow.nodes.answer.entities import AnswerStreamGenerateRoute from core.workflow.nodes.end.end_stream_generate_router import EndStreamGeneratorRouter diff --git a/api/core/workflow/graph_engine/entities/graph_init_params.py b/api/core/workflow/graph_engine/entities/graph_init_params.py index 1a403f3e49..a0ecd824f4 100644 --- a/api/core/workflow/graph_engine/entities/graph_init_params.py +++ b/api/core/workflow/graph_engine/entities/graph_init_params.py @@ -4,7 +4,7 @@ from typing import Any from pydantic import BaseModel, Field from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities.node_entities import UserFrom +from models.enums import UserFrom from models.workflow import WorkflowType diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 8342dbd13d..ada0b14ce4 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -10,11 +10,7 @@ from flask import Flask, current_app from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities.node_entities import ( - NodeRunMetadataKey, - NodeType, - UserFrom, -) +from core.workflow.entities.node_entities import NodeRunMetadataKey from core.workflow.entities.variable_pool import VariablePool, VariableValue from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager from core.workflow.graph_engine.entities.event import ( @@ -36,12 +32,14 @@ from core.workflow.graph_engine.entities.graph import Graph, GraphEdge from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState +from core.workflow.nodes import NodeType from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor -from core.workflow.nodes.base_node import BaseNode 
+from core.workflow.nodes.base import BaseNode from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent -from core.workflow.nodes.node_mapping import node_classes +from core.workflow.nodes.node_mapping import node_type_classes_mapping from extensions.ext_database import db +from models.enums import UserFrom from models.workflow import WorkflowNodeExecutionStatus, WorkflowType logger = logging.getLogger(__name__) @@ -229,10 +227,8 @@ class GraphEngine: raise GraphRunFailedError(f"Node {node_id} config not found.") # convert to specific node - node_type = NodeType.value_of(node_config.get("data", {}).get("type")) - node_cls = node_classes.get(node_type) - if not node_cls: - raise GraphRunFailedError(f"Node {node_id} type {node_type} not found.") + node_type = NodeType(node_config.get("data", {}).get("type")) + node_cls = node_type_classes_mapping[node_type] previous_node_id = previous_route_node_state.node_id if previous_route_node_state else None diff --git a/api/core/workflow/nodes/__init__.py b/api/core/workflow/nodes/__init__.py index e69de29bb2..6101fcf9af 100644 --- a/api/core/workflow/nodes/__init__.py +++ b/api/core/workflow/nodes/__init__.py @@ -0,0 +1,3 @@ +from .enums import NodeType + +__all__ = ["NodeType"] diff --git a/api/core/workflow/nodes/answer/__init__.py b/api/core/workflow/nodes/answer/__init__.py index e69de29bb2..7a10f47eed 100644 --- a/api/core/workflow/nodes/answer/__init__.py +++ b/api/core/workflow/nodes/answer/__init__.py @@ -0,0 +1,4 @@ +from .answer_node import AnswerNode +from .entities import AnswerStreamGenerateRoute + +__all__ = ["AnswerStreamGenerateRoute", "AnswerNode"] diff --git a/api/core/workflow/nodes/answer/answer_node.py b/api/core/workflow/nodes/answer/answer_node.py index deacbbbbb0..520cbdbb60 100644 --- a/api/core/workflow/nodes/answer/answer_node.py +++ b/api/core/workflow/nodes/answer/answer_node.py @@ -1,7 +1,8 @@ from collections.abc import Mapping, Sequence from typing import Any, cast -from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.variables import ArrayFileSegment, FileSegment +from core.workflow.entities.node_entities import NodeRunResult from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter from core.workflow.nodes.answer.entities import ( AnswerNodeData, @@ -9,12 +10,13 @@ from core.workflow.nodes.answer.entities import ( TextGenerateRouteChunk, VarGenerateRouteChunk, ) -from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.enums import NodeType from core.workflow.utils.variable_template_parser import VariableTemplateParser from models.workflow import WorkflowNodeExecutionStatus -class AnswerNode(BaseNode): +class AnswerNode(BaseNode[AnswerNodeData]): _node_data_cls = AnswerNodeData _node_type: NodeType = NodeType.ANSWER @@ -23,30 +25,35 @@ class AnswerNode(BaseNode): Run node :return: """ - node_data = self.node_data - node_data = cast(AnswerNodeData, node_data) - # generate routes - generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(node_data) + generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(self.node_data) answer = "" + files = [] for part in generate_routes: if part.type == GenerateRouteChunk.ChunkType.VAR: part = cast(VarGenerateRouteChunk, part) value_selector = part.value_selector - value = 
self.graph_runtime_state.variable_pool.get(value_selector) - - if value: - answer += value.markdown + variable = self.graph_runtime_state.variable_pool.get(value_selector) + if variable: + if isinstance(variable, FileSegment): + files.append(variable.value) + elif isinstance(variable, ArrayFileSegment): + files.extend(variable.value) + answer += variable.markdown else: part = cast(TextGenerateRouteChunk, part) answer += part.text - return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"answer": answer}) + return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"answer": answer, "files": files}) @classmethod def _extract_variable_selector_to_variable_mapping( - cls, graph_config: Mapping[str, Any], node_id: str, node_data: AnswerNodeData + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: AnswerNodeData, ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -55,9 +62,6 @@ class AnswerNode(BaseNode): :param node_data: node data :return: """ - node_data = node_data - node_data = cast(AnswerNodeData, node_data) - variable_template_parser = VariableTemplateParser(template=node_data.answer) variable_selectors = variable_template_parser.extract_variable_selectors() diff --git a/api/core/workflow/nodes/answer/answer_stream_generate_router.py b/api/core/workflow/nodes/answer/answer_stream_generate_router.py index bbd1f88867..bce28c5fcb 100644 --- a/api/core/workflow/nodes/answer/answer_stream_generate_router.py +++ b/api/core/workflow/nodes/answer/answer_stream_generate_router.py @@ -1,5 +1,4 @@ from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from core.workflow.entities.node_entities import NodeType from core.workflow.nodes.answer.entities import ( AnswerNodeData, AnswerStreamGenerateRoute, @@ -7,6 +6,7 @@ from core.workflow.nodes.answer.entities import ( TextGenerateRouteChunk, VarGenerateRouteChunk, ) +from core.workflow.nodes.enums import NodeType from core.workflow.utils.variable_template_parser import VariableTemplateParser diff --git a/api/core/workflow/nodes/answer/answer_stream_processor.py b/api/core/workflow/nodes/answer/answer_stream_processor.py index 32dbf436ec..e3889941ca 100644 --- a/api/core/workflow/nodes/answer/answer_stream_processor.py +++ b/api/core/workflow/nodes/answer/answer_stream_processor.py @@ -1,8 +1,8 @@ import logging from collections.abc import Generator -from typing import Optional, cast +from typing import cast -from core.file.file_obj import FileVar +from core.file import FILE_MODEL_IDENTITY, File from core.workflow.entities.variable_pool import VariablePool from core.workflow.graph_engine.entities.event import ( GraphEngineEvent, @@ -203,7 +203,7 @@ class AnswerStreamProcessor(StreamProcessor): return files @classmethod - def _get_file_var_from_value(cls, value: dict | list) -> Optional[dict]: + def _get_file_var_from_value(cls, value: dict | list): """ Get file var from value :param value: variable value @@ -213,9 +213,9 @@ class AnswerStreamProcessor(StreamProcessor): return None if isinstance(value, dict): - if "__variant" in value and value["__variant"] == FileVar.__name__: + if "dify_model_identity" in value and value["dify_model_identity"] == FILE_MODEL_IDENTITY: return value - elif isinstance(value, FileVar): + elif isinstance(value, File): return value.to_dict() return None diff --git a/api/core/workflow/nodes/answer/entities.py b/api/core/workflow/nodes/answer/entities.py index e356e7fd70..e543d02dd7 100644 --- 
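
The stream processor above now recognizes serialized files by the dify_model_identity marker instead of the old __variant field; a sketch of the dict fast-path (calling the private helper purely for illustration):

from core.file import FILE_MODEL_IDENTITY
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor

value = {"dify_model_identity": FILE_MODEL_IDENTITY, "id": "file-1"}
assert AnswerStreamProcessor._get_file_var_from_value(value) == value  # marked dicts pass through as-is
assert AnswerStreamProcessor._get_file_var_from_value({"other": 1}) is None
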
a/api/core/workflow/nodes/answer/entities.py +++ b/api/core/workflow/nodes/answer/entities.py @@ -2,7 +2,7 @@ from enum import Enum from pydantic import BaseModel, Field -from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.nodes.base import BaseNodeData class AnswerNodeData(BaseNodeData): diff --git a/api/core/workflow/nodes/base/__init__.py b/api/core/workflow/nodes/base/__init__.py new file mode 100644 index 0000000000..61f727740c --- /dev/null +++ b/api/core/workflow/nodes/base/__init__.py @@ -0,0 +1,4 @@ +from .entities import BaseIterationNodeData, BaseIterationState, BaseNodeData +from .node import BaseNode + +__all__ = ["BaseNode", "BaseNodeData", "BaseIterationNodeData", "BaseIterationState"] diff --git a/api/core/workflow/entities/base_node_data_entities.py b/api/core/workflow/nodes/base/entities.py similarity index 100% rename from api/core/workflow/entities/base_node_data_entities.py rename to api/core/workflow/nodes/base/entities.py diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base/node.py similarity index 60% rename from api/core/workflow/nodes/base_node.py rename to api/core/workflow/nodes/base/node.py index 7bfe45a13c..053a339ba7 100644 --- a/api/core/workflow/nodes/base_node.py +++ b/api/core/workflow/nodes/base/node.py @@ -1,17 +1,27 @@ -from abc import ABC, abstractmethod +import logging +from abc import abstractmethod from collections.abc import Generator, Mapping, Sequence -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast -from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.graph_engine.entities.event import InNodeEvent -from core.workflow.graph_engine.entities.graph import Graph -from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams -from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState -from core.workflow.nodes.event import RunCompletedEvent, RunEvent +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.event import NodeEvent, RunCompletedEvent +from models.workflow import WorkflowNodeExecutionStatus + +from .entities import BaseNodeData + +if TYPE_CHECKING: + from core.workflow.graph_engine.entities.event import InNodeEvent + from core.workflow.graph_engine.entities.graph import Graph + from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams + from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState + +logger = logging.getLogger(__name__) + +GenericNodeData = TypeVar("GenericNodeData", bound=BaseNodeData) -class BaseNode(ABC): +class BaseNode(Generic[GenericNodeData]): _node_data_cls: type[BaseNodeData] _node_type: NodeType @@ -19,9 +29,9 @@ class BaseNode(ABC): self, id: str, config: Mapping[str, Any], - graph_init_params: GraphInitParams, - graph: Graph, - graph_runtime_state: GraphRuntimeState, + graph_init_params: "GraphInitParams", + graph: "Graph", + graph_runtime_state: "GraphRuntimeState", previous_node_id: Optional[str] = None, thread_pool_id: Optional[str] = None, ) -> None: @@ -45,22 +55,25 @@ class BaseNode(ABC): raise ValueError("Node ID is required.") self.node_id = node_id - self.node_data = self._node_data_cls(**config.get("data", {})) + self.node_data: GenericNodeData = cast(GenericNodeData, 
self._node_data_cls(**config.get("data", {}))) @abstractmethod - def _run(self) -> NodeRunResult | Generator[RunEvent | InNodeEvent, None, None]: + def _run(self) -> NodeRunResult | Generator[Union[NodeEvent, "InNodeEvent"], None, None]: """ Run node :return: """ raise NotImplementedError - def run(self) -> Generator[RunEvent | InNodeEvent, None, None]: - """ - Run node entry - :return: - """ - result = self._run() + def run(self) -> Generator[Union[NodeEvent, "InNodeEvent"], None, None]: + try: + result = self._run() + except Exception as e: + logger.error(f"Node {self.node_id} failed to run: {e}") + result = NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e), + ) if isinstance(result, NodeRunResult): yield RunCompletedEvent(run_result=result) @@ -69,7 +82,10 @@ class BaseNode(ABC): @classmethod def extract_variable_selector_to_variable_mapping( - cls, graph_config: Mapping[str, Any], config: dict + cls, + *, + graph_config: Mapping[str, Any], + config: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -83,12 +99,16 @@ class BaseNode(ABC): node_data = cls._node_data_cls(**config.get("data", {})) return cls._extract_variable_selector_to_variable_mapping( - graph_config=graph_config, node_id=node_id, node_data=node_data + graph_config=graph_config, node_id=node_id, node_data=cast(GenericNodeData, node_data) ) @classmethod def _extract_variable_selector_to_variable_mapping( - cls, graph_config: Mapping[str, Any], node_id: str, node_data: BaseNodeData + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: GenericNodeData, ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping diff --git a/api/core/workflow/nodes/code/__init__.py b/api/core/workflow/nodes/code/__init__.py index e69de29bb2..8c6dcc7fcc 100644 --- a/api/core/workflow/nodes/code/__init__.py +++ b/api/core/workflow/nodes/code/__init__.py @@ -0,0 +1,3 @@ +from .code_node import CodeNode + +__all__ = ["CodeNode"] diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 9da7ad99f3..dd533ffc4c 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -1,18 +1,19 @@ from collections.abc import Mapping, Sequence -from typing import Any, Optional, Union, cast +from typing import Any, Optional, Union from configs import dify_config from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage from core.helper.code_executor.code_node_provider import CodeNodeProvider from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider -from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.nodes.base_node import BaseNode +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.nodes.base import BaseNode from core.workflow.nodes.code.entities import CodeNodeData +from core.workflow.nodes.enums import NodeType from models.workflow import WorkflowNodeExecutionStatus -class CodeNode(BaseNode): +class CodeNode(BaseNode[CodeNodeData]): _node_data_cls = CodeNodeData _node_type = NodeType.CODE @@ -33,20 +34,13 @@ class CodeNode(BaseNode): return code_provider.get_default_config() def _run(self) -> NodeRunResult: - """ - Run code - :return: - """ - node_data = self.node_data - node_data = cast(CodeNodeData, 
node_data) - # Get code language - code_language = node_data.code_language - code = node_data.code + code_language = self.node_data.code_language + code = self.node_data.code # Get variables variables = {} - for variable_selector in node_data.variables: + for variable_selector in self.node_data.variables: variable = variable_selector.variable value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector) @@ -60,7 +54,7 @@ class CodeNode(BaseNode): ) # Transform result - result = self._transform_result(result, node_data.outputs) + result = self._transform_result(result, self.node_data.outputs) except (CodeExecutionError, ValueError) as e: return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e)) @@ -316,7 +310,11 @@ class CodeNode(BaseNode): @classmethod def _extract_variable_selector_to_variable_mapping( - cls, graph_config: Mapping[str, Any], node_id: str, node_data: CodeNodeData + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: CodeNodeData, ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping diff --git a/api/core/workflow/nodes/code/entities.py b/api/core/workflow/nodes/code/entities.py index 5eb0e0f63f..e78183baf1 100644 --- a/api/core/workflow/nodes/code/entities.py +++ b/api/core/workflow/nodes/code/entities.py @@ -3,8 +3,8 @@ from typing import Literal, Optional from pydantic import BaseModel from core.helper.code_executor.code_executor import CodeLanguage -from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.variable_entities import VariableSelector +from core.workflow.nodes.base import BaseNodeData class CodeNodeData(BaseNodeData): diff --git a/api/core/workflow/nodes/document_extractor/__init__.py b/api/core/workflow/nodes/document_extractor/__init__.py new file mode 100644 index 0000000000..3cc5fae187 --- /dev/null +++ b/api/core/workflow/nodes/document_extractor/__init__.py @@ -0,0 +1,4 @@ +from .entities import DocumentExtractorNodeData +from .node import DocumentExtractorNode + +__all__ = ["DocumentExtractorNode", "DocumentExtractorNodeData"] diff --git a/api/core/workflow/nodes/document_extractor/entities.py b/api/core/workflow/nodes/document_extractor/entities.py new file mode 100644 index 0000000000..7e9ffaa889 --- /dev/null +++ b/api/core/workflow/nodes/document_extractor/entities.py @@ -0,0 +1,7 @@ +from collections.abc import Sequence + +from core.workflow.nodes.base import BaseNodeData + + +class DocumentExtractorNodeData(BaseNodeData): + variable_selector: Sequence[str] diff --git a/api/core/workflow/nodes/document_extractor/exc.py b/api/core/workflow/nodes/document_extractor/exc.py new file mode 100644 index 0000000000..c9d4bb8ef6 --- /dev/null +++ b/api/core/workflow/nodes/document_extractor/exc.py @@ -0,0 +1,14 @@ +class DocumentExtractorError(Exception): + """Base exception for errors related to the DocumentExtractorNode.""" + + +class FileDownloadError(DocumentExtractorError): + """Exception raised when there's an error downloading a file.""" + + +class UnsupportedFileTypeError(DocumentExtractorError): + """Exception raised when trying to extract text from an unsupported file type.""" + + +class TextExtractionError(DocumentExtractorError): + """Exception raised when there's an error during text extraction from a file.""" diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py new file mode 100644 index 0000000000..3efcc373b1 --- 
/dev/null +++ b/api/core/workflow/nodes/document_extractor/node.py @@ -0,0 +1,244 @@ +import csv +import io + +import docx +import pandas as pd +import pypdfium2 +from unstructured.partition.email import partition_email +from unstructured.partition.epub import partition_epub +from unstructured.partition.msg import partition_msg +from unstructured.partition.ppt import partition_ppt +from unstructured.partition.pptx import partition_pptx + +from core.file import File, FileTransferMethod, file_manager +from core.helper import ssrf_proxy +from core.variables import ArrayFileSegment +from core.variables.segments import FileSegment +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.enums import NodeType +from models.workflow import WorkflowNodeExecutionStatus + +from .entities import DocumentExtractorNodeData +from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError + + +class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]): + """ + Extracts text content from various file types. + Supports plain text, HTML/Markdown/XML, PDF, DOC/DOCX, CSV, Excel, PPT/PPTX, EPUB, EML, and MSG files. + """ + + _node_data_cls = DocumentExtractorNodeData + _node_type = NodeType.DOCUMENT_EXTRACTOR + + def _run(self): + variable_selector = self.node_data.variable_selector + variable = self.graph_runtime_state.variable_pool.get(variable_selector) + + if variable is None: + error_message = f"File variable not found for selector: {variable_selector}" + return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, error=error_message) + if variable.value and not isinstance(variable, ArrayFileSegment | FileSegment): + error_message = f"Variable {variable_selector} is not a FileSegment or ArrayFileSegment" + return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, error=error_message) + + value = variable.value + inputs = {"variable_selector": variable_selector} + process_data = {"documents": value if isinstance(value, list) else [value]} + + try: + if isinstance(value, list): + extracted_text_list = list(map(_extract_text_from_file, value)) + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=inputs, + process_data=process_data, + outputs={"text": extracted_text_list}, + ) + elif isinstance(value, File): + extracted_text = _extract_text_from_file(value) + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=inputs, + process_data=process_data, + outputs={"text": extracted_text}, + ) + else: + raise DocumentExtractorError(f"Unsupported variable type: {type(value)}") + except DocumentExtractorError as e: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e), + inputs=inputs, + process_data=process_data, + ) + + +def _extract_text(*, file_content: bytes, mime_type: str) -> str: + """Extract text from a file based on its MIME type.""" + if mime_type.startswith("text/plain") or mime_type in {"text/html", "text/htm", "text/markdown", "text/xml"}: + return _extract_text_from_plain_text(file_content) + elif mime_type == "application/pdf": + return _extract_text_from_pdf(file_content) + elif mime_type in { + "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + "application/msword", + }: + return _extract_text_from_doc(file_content) + elif mime_type == "text/csv": + return _extract_text_from_csv(file_content) + elif mime_type in { + "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + "application/vnd.ms-excel", + }: + return
_extract_text_from_excel(file_content) + elif mime_type == "application/vnd.ms-powerpoint": + return _extract_text_from_ppt(file_content) + elif mime_type == "application/vnd.openxmlformats-officedocument.presentationml.presentation": + return _extract_text_from_pptx(file_content) + elif mime_type == "application/epub+zip": + return _extract_text_from_epub(file_content) + elif mime_type == "message/rfc822": + return _extract_text_from_eml(file_content) + elif mime_type == "application/vnd.ms-outlook": + return _extract_text_from_msg(file_content) + else: + raise UnsupportedFileTypeError(f"Unsupported MIME type: {mime_type}") + + +def _extract_text_from_plain_text(file_content: bytes) -> str: + try: + return file_content.decode("utf-8") + except UnicodeDecodeError as e: + raise TextExtractionError("Failed to decode plain text file") from e + + +def _extract_text_from_pdf(file_content: bytes) -> str: + try: + pdf_file = io.BytesIO(file_content) + pdf_document = pypdfium2.PdfDocument(pdf_file, autoclose=True) + text = "" + for page in pdf_document: + text_page = page.get_textpage() + text += text_page.get_text_range() + text_page.close() + page.close() + return text + except Exception as e: + raise TextExtractionError(f"Failed to extract text from PDF: {str(e)}") from e + + +def _extract_text_from_doc(file_content: bytes) -> str: + try: + doc_file = io.BytesIO(file_content) + doc = docx.Document(doc_file) + return "\n".join([paragraph.text for paragraph in doc.paragraphs]) + except Exception as e: + raise TextExtractionError(f"Failed to extract text from DOC/DOCX: {str(e)}") from e + + +def _download_file_content(file: File) -> bytes: + """Download the content of a file based on its transfer method.""" + try: + if file.transfer_method == FileTransferMethod.REMOTE_URL: + if file.remote_url is None: + raise FileDownloadError("Missing URL for remote file") + response = ssrf_proxy.get(file.remote_url) + response.raise_for_status() + return response.content + elif file.transfer_method == FileTransferMethod.LOCAL_FILE: + return file_manager.download(file) + else: + raise ValueError(f"Unsupported transfer method: {file.transfer_method}") + except Exception as e: + raise FileDownloadError(f"Error downloading file: {str(e)}") from e + + +def _extract_text_from_file(file: File): + if file.mime_type is None: + raise UnsupportedFileTypeError("Unable to determine file type: MIME type is missing") + file_content = _download_file_content(file) + extracted_text = _extract_text(file_content=file_content, mime_type=file.mime_type) + return extracted_text + + +def _extract_text_from_csv(file_content: bytes) -> str: + try: + csv_file = io.StringIO(file_content.decode("utf-8")) + csv_reader = csv.reader(csv_file) + rows = list(csv_reader) + + if not rows: + return "" + + # Create markdown table + markdown_table = "| " + " | ".join(rows[0]) + " |\n" + markdown_table += "| " + " | ".join(["---"] * len(rows[0])) + " |\n" + for row in rows[1:]: + markdown_table += "| " + " | ".join(row) + " |\n" + + return markdown_table.strip() + except Exception as e: + raise TextExtractionError(f"Failed to extract text from CSV: {str(e)}") from e + + +def _extract_text_from_excel(file_content: bytes) -> str: + """Extract text from an Excel file using pandas.""" + + try: + df = pd.read_excel(io.BytesIO(file_content)) + + # Drop rows where all elements are NaN + df.dropna(how="all", inplace=True) + + # Convert DataFrame to markdown table + markdown_table = df.to_markdown(index=False) + return markdown_table + except Exception as e: 
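
The CSV extractor above renders rows as a GitHub-style markdown table so the tabular structure survives into downstream prompts. A standalone sketch of that conversion, assuming UTF-8 input and taking the column count from the header row:

```python
import csv
import io


def csv_bytes_to_markdown(file_content: bytes) -> str:
    """Render CSV bytes as a markdown table, mirroring _extract_text_from_csv."""
    rows = list(csv.reader(io.StringIO(file_content.decode("utf-8"))))
    if not rows:
        return ""
    header, *body = rows
    lines = ["| " + " | ".join(header) + " |",
             "| " + " | ".join(["---"] * len(header)) + " |"]
    lines += ["| " + " | ".join(row) + " |" for row in body]
    return "\n".join(lines)


print(csv_bytes_to_markdown(b"name,qty\napple,3\npear,5"))
# | name | qty |
# | --- | --- |
# | apple | 3 |
# | pear | 5 |
```
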
+ raise TextExtractionError(f"Failed to extract text from Excel file: {str(e)}") from e + + +def _extract_text_from_ppt(file_content: bytes) -> str: + try: + with io.BytesIO(file_content) as file: + elements = partition_ppt(file=file) + return "\n".join([getattr(element, "text", "") for element in elements]) + except Exception as e: + raise TextExtractionError(f"Failed to extract text from PPT: {str(e)}") from e + + +def _extract_text_from_pptx(file_content: bytes) -> str: + try: + with io.BytesIO(file_content) as file: + elements = partition_pptx(file=file) + return "\n".join([getattr(element, "text", "") for element in elements]) + except Exception as e: + raise TextExtractionError(f"Failed to extract text from PPTX: {str(e)}") from e + + +def _extract_text_from_epub(file_content: bytes) -> str: + try: + with io.BytesIO(file_content) as file: + elements = partition_epub(file=file) + return "\n".join([str(element) for element in elements]) + except Exception as e: + raise TextExtractionError(f"Failed to extract text from EPUB: {str(e)}") from e + + +def _extract_text_from_eml(file_content: bytes) -> str: + try: + with io.BytesIO(file_content) as file: + elements = partition_email(file=file) + return "\n".join([str(element) for element in elements]) + except Exception as e: + raise TextExtractionError(f"Failed to extract text from EML: {str(e)}") from e + + +def _extract_text_from_msg(file_content: bytes) -> str: + try: + with io.BytesIO(file_content) as file: + elements = partition_msg(file=file) + return "\n".join([str(element) for element in elements]) + except Exception as e: + raise TextExtractionError(f"Failed to extract text from MSG: {str(e)}") from e diff --git a/api/core/workflow/nodes/end/__init__.py b/api/core/workflow/nodes/end/__init__.py index e69de29bb2..adb381701c 100644 --- a/api/core/workflow/nodes/end/__init__.py +++ b/api/core/workflow/nodes/end/__init__.py @@ -0,0 +1,4 @@ +from .end_node import EndNode +from .entities import EndStreamParam + +__all__ = ["EndStreamParam", "EndNode"] diff --git a/api/core/workflow/nodes/end/end_node.py b/api/core/workflow/nodes/end/end_node.py index 7b78d67be8..2398e4e89d 100644 --- a/api/core/workflow/nodes/end/end_node.py +++ b/api/core/workflow/nodes/end/end_node.py @@ -1,13 +1,14 @@ from collections.abc import Mapping, Sequence -from typing import Any, cast +from typing import Any -from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.nodes.base_node import BaseNode +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.nodes.base import BaseNode from core.workflow.nodes.end.entities import EndNodeData +from core.workflow.nodes.enums import NodeType from models.workflow import WorkflowNodeExecutionStatus -class EndNode(BaseNode): +class EndNode(BaseNode[EndNodeData]): _node_data_cls = EndNodeData _node_type = NodeType.END @@ -16,20 +17,27 @@ class EndNode(BaseNode): Run node :return: """ - node_data = self.node_data - node_data = cast(EndNodeData, node_data) - output_variables = node_data.outputs + output_variables = self.node_data.outputs outputs = {} for variable_selector in output_variables: - value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector) + variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) + value = variable.to_object() if variable is not None else None outputs[variable_selector.variable] = value - return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=outputs, 
outputs=outputs) + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=outputs, + outputs=outputs, + ) @classmethod def _extract_variable_selector_to_variable_mapping( - cls, graph_config: Mapping[str, Any], node_id: str, node_data: EndNodeData + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: EndNodeData, ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping diff --git a/api/core/workflow/nodes/end/end_stream_generate_router.py b/api/core/workflow/nodes/end/end_stream_generate_router.py index 9a7d2ecde3..ea8b6b5042 100644 --- a/api/core/workflow/nodes/end/end_stream_generate_router.py +++ b/api/core/workflow/nodes/end/end_stream_generate_router.py @@ -1,5 +1,5 @@ -from core.workflow.entities.node_entities import NodeType from core.workflow.nodes.end.entities import EndNodeData, EndStreamParam +from core.workflow.nodes.enums import NodeType class EndStreamGeneratorRouter: diff --git a/api/core/workflow/nodes/end/entities.py b/api/core/workflow/nodes/end/entities.py index c3270ac22a..c16e85b0eb 100644 --- a/api/core/workflow/nodes/end/entities.py +++ b/api/core/workflow/nodes/end/entities.py @@ -1,7 +1,7 @@ from pydantic import BaseModel, Field -from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.variable_entities import VariableSelector +from core.workflow.nodes.base import BaseNodeData class EndNodeData(BaseNodeData): diff --git a/api/core/workflow/nodes/enums.py b/api/core/workflow/nodes/enums.py new file mode 100644 index 0000000000..208144655b --- /dev/null +++ b/api/core/workflow/nodes/enums.py @@ -0,0 +1,24 @@ +from enum import Enum + + +class NodeType(str, Enum): + START = "start" + END = "end" + ANSWER = "answer" + LLM = "llm" + KNOWLEDGE_RETRIEVAL = "knowledge-retrieval" + IF_ELSE = "if-else" + CODE = "code" + TEMPLATE_TRANSFORM = "template-transform" + QUESTION_CLASSIFIER = "question-classifier" + HTTP_REQUEST = "http-request" + TOOL = "tool" + VARIABLE_AGGREGATOR = "variable-aggregator" + VARIABLE_ASSIGNER = "variable-assigner" # TODO: Merge this into VARIABLE_AGGREGATOR in the database. + LOOP = "loop" + ITERATION = "iteration" + ITERATION_START = "iteration-start" # Fake start node for iteration. 
+ PARAMETER_EXTRACTOR = "parameter-extractor" + CONVERSATION_VARIABLE_ASSIGNER = "assigner" + DOCUMENT_EXTRACTOR = "document-extractor" + LIST_OPERATOR = "list-operator" diff --git a/api/core/workflow/nodes/event/__init__.py b/api/core/workflow/nodes/event/__init__.py new file mode 100644 index 0000000000..581def9553 --- /dev/null +++ b/api/core/workflow/nodes/event/__init__.py @@ -0,0 +1,10 @@ +from .event import ModelInvokeCompletedEvent, RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent +from .types import NodeEvent + +__all__ = [ + "RunCompletedEvent", + "RunRetrieverResourceEvent", + "RunStreamChunkEvent", + "NodeEvent", + "ModelInvokeCompletedEvent", +] diff --git a/api/core/workflow/nodes/event.py b/api/core/workflow/nodes/event/event.py similarity index 72% rename from api/core/workflow/nodes/event.py rename to api/core/workflow/nodes/event/event.py index 276c13a6d4..b7034561bf 100644 --- a/api/core/workflow/nodes/event.py +++ b/api/core/workflow/nodes/event/event.py @@ -1,5 +1,6 @@ from pydantic import BaseModel, Field +from core.model_runtime.entities.llm_entities import LLMUsage from core.workflow.entities.node_entities import NodeRunResult @@ -17,4 +18,11 @@ class RunRetrieverResourceEvent(BaseModel): context: str = Field(..., description="context") -RunEvent = RunCompletedEvent | RunStreamChunkEvent | RunRetrieverResourceEvent +class ModelInvokeCompletedEvent(BaseModel): + """ + Model invoke completed + """ + + text: str + usage: LLMUsage + finish_reason: str | None = None diff --git a/api/core/workflow/nodes/event/types.py b/api/core/workflow/nodes/event/types.py new file mode 100644 index 0000000000..b19a91022d --- /dev/null +++ b/api/core/workflow/nodes/event/types.py @@ -0,0 +1,3 @@ +from .event import ModelInvokeCompletedEvent, RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent + +NodeEvent = RunCompletedEvent | RunStreamChunkEvent | RunRetrieverResourceEvent | ModelInvokeCompletedEvent diff --git a/api/core/workflow/nodes/http_request/__init__.py b/api/core/workflow/nodes/http_request/__init__.py index e69de29bb2..9408c2dde0 100644 --- a/api/core/workflow/nodes/http_request/__init__.py +++ b/api/core/workflow/nodes/http_request/__init__.py @@ -0,0 +1,4 @@ +from .entities import BodyData, HttpRequestNodeAuthorization, HttpRequestNodeBody, HttpRequestNodeData +from .node import HttpRequestNode + +__all__ = ["HttpRequestNodeData", "HttpRequestNodeAuthorization", "HttpRequestNodeBody", "BodyData", "HttpRequestNode"] diff --git a/api/core/workflow/nodes/http_request/entities.py b/api/core/workflow/nodes/http_request/entities.py index 66dd1f2dc6..816ece9577 100644 --- a/api/core/workflow/nodes/http_request/entities.py +++ b/api/core/workflow/nodes/http_request/entities.py @@ -1,15 +1,25 @@ -from typing import Literal, Optional, Union +from collections.abc import Sequence +from typing import Literal, Optional -from pydantic import BaseModel, ValidationInfo, field_validator +import httpx +from pydantic import BaseModel, Field, ValidationInfo, field_validator from configs import dify_config -from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.nodes.base import BaseNodeData + +NON_FILE_CONTENT_TYPES = ( + "application/json", + "application/xml", + "text/html", + "text/plain", + "application/x-www-form-urlencoded", +) class HttpRequestNodeAuthorizationConfig(BaseModel): - type: Literal[None, "basic", "bearer", "custom"] - api_key: Union[None, str] = None - header: Union[None, str] = None + type: Literal["basic", 
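
NodeEvent above replaces the old RunEvent alias with a plain PEP 604 union of pydantic models, so consumers dispatch with isinstance checks instead of relying on a shared base class. A toy model of that dispatch, with stand-in event classes rather than the real ones:

```python
from pydantic import BaseModel


class RunStreamChunkEvent(BaseModel):
    chunk: str


class RunCompletedEvent(BaseModel):
    output: str


# A PEP 604 union; no common base class required.
NodeEvent = RunStreamChunkEvent | RunCompletedEvent


def handle(event: NodeEvent) -> str:
    if isinstance(event, RunStreamChunkEvent):
        return f"chunk: {event.chunk}"
    return f"done: {event.output}"


assert handle(RunStreamChunkEvent(chunk="he")) == "chunk: he"
assert handle(RunCompletedEvent(output="hello")) == "done: hello"
```
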
"bearer", "custom"] + api_key: str + header: str = "" class HttpRequestNodeAuthorization(BaseModel): @@ -31,9 +41,16 @@ class HttpRequestNodeAuthorization(BaseModel): return v +class BodyData(BaseModel): + key: str = "" + type: Literal["file", "text"] + value: str = "" + file: Sequence[str] = Field(default_factory=list) + + class HttpRequestNodeBody(BaseModel): - type: Literal["none", "form-data", "x-www-form-urlencoded", "raw-text", "json"] - data: Union[None, str] = None + type: Literal["none", "form-data", "x-www-form-urlencoded", "raw-text", "json", "binary"] + data: Sequence[BodyData] = Field(default_factory=list) class HttpRequestNodeTimeout(BaseModel): @@ -54,3 +71,51 @@ class HttpRequestNodeData(BaseNodeData): params: str body: Optional[HttpRequestNodeBody] = None timeout: Optional[HttpRequestNodeTimeout] = None + + +class Response: + headers: dict[str, str] + response: httpx.Response + + def __init__(self, response: httpx.Response): + self.response = response + self.headers = dict(response.headers) + + @property + def is_file(self): + content_type = self.content_type + content_disposition = self.response.headers.get("Content-Disposition", "") + + return "attachment" in content_disposition or ( + not any(non_file in content_type for non_file in NON_FILE_CONTENT_TYPES) + and any(file_type in content_type for file_type in ("application/", "image/", "audio/", "video/")) + ) + + @property + def content_type(self) -> str: + return self.headers.get("Content-Type", "") + + @property + def text(self) -> str: + return self.response.text + + @property + def content(self) -> bytes: + return self.response.content + + @property + def status_code(self) -> int: + return self.response.status_code + + @property + def size(self) -> int: + return len(self.content) + + @property + def readable_size(self) -> str: + if self.size < 1024: + return f"{self.size} bytes" + elif self.size < 1024 * 1024: + return f"{(self.size / 1024):.2f} KB" + else: + return f"{(self.size / 1024 / 1024):.2f} MB" diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/core/workflow/nodes/http_request/executor.py new file mode 100644 index 0000000000..71bb0ac86a --- /dev/null +++ b/api/core/workflow/nodes/http_request/executor.py @@ -0,0 +1,327 @@ +import json +from collections.abc import Mapping, Sequence +from copy import deepcopy +from random import randint +from typing import Any, Literal +from urllib.parse import urlencode, urlparse + +import httpx + +from configs import dify_config +from core.file import file_manager +from core.helper import ssrf_proxy +from core.workflow.entities.variable_pool import VariablePool + +from .entities import ( + HttpRequestNodeAuthorization, + HttpRequestNodeData, + HttpRequestNodeTimeout, + Response, +) + +BODY_TYPE_TO_CONTENT_TYPE = { + "json": "application/json", + "x-www-form-urlencoded": "application/x-www-form-urlencoded", + "form-data": "multipart/form-data", + "raw-text": "text/plain", +} + + +class Executor: + method: Literal["get", "head", "post", "put", "delete", "patch"] + url: str + params: Mapping[str, str] | None + content: str | bytes | None + data: Mapping[str, Any] | None + files: Mapping[str, bytes] | None + json: Any + headers: dict[str, str] + auth: HttpRequestNodeAuthorization + timeout: HttpRequestNodeTimeout + + boundary: str + + def __init__( + self, + *, + node_data: HttpRequestNodeData, + timeout: HttpRequestNodeTimeout, + variable_pool: VariablePool, + ): + # If authorization API key is present, convert the API key using the variable pool + if 
node_data.authorization.type == "api-key": + if node_data.authorization.config is None: + raise ValueError("authorization config is required") + node_data.authorization.config.api_key = variable_pool.convert_template( + node_data.authorization.config.api_key + ).text + + self.url: str = node_data.url + self.method = node_data.method + self.auth = node_data.authorization + self.timeout = timeout + self.params = None + self.headers = {} + self.content = None + self.files = None + self.data = None + self.json = None + + # init template + self.variable_pool = variable_pool + self.node_data = node_data + self._initialize() + + def _initialize(self): + self._init_url() + self._init_params() + self._init_headers() + self._init_body() + + def _init_url(self): + self.url = self.variable_pool.convert_template(self.node_data.url).text + + def _init_params(self): + params = self.variable_pool.convert_template(self.node_data.params).text + self.params = _plain_text_to_dict(params) + + def _init_headers(self): + headers = self.variable_pool.convert_template(self.node_data.headers).text + self.headers = _plain_text_to_dict(headers) + + body = self.node_data.body + if body is None: + return + if "content-type" not in (k.lower() for k in self.headers) and body.type in BODY_TYPE_TO_CONTENT_TYPE: + self.headers["Content-Type"] = BODY_TYPE_TO_CONTENT_TYPE[body.type] + if body.type == "form-data": + self.boundary = f"----WebKitFormBoundary{_generate_random_string(16)}" + self.headers["Content-Type"] = f"multipart/form-data; boundary={self.boundary}" + + def _init_body(self): + body = self.node_data.body + if body is not None: + data = body.data + match body.type: + case "none": + self.content = "" + case "raw-text": + self.content = self.variable_pool.convert_template(data[0].value).text + case "json": + json_object = json.loads(data[0].value) + self.json = self._parse_object_contains_variables(json_object) + case "binary": + file_selector = data[0].file + file_variable = self.variable_pool.get_file(file_selector) + if file_variable is None: + raise ValueError(f"cannot fetch file with selector {file_selector}") + file = file_variable.value + self.content = file_manager.download(file) + case "x-www-form-urlencoded": + form_data = { + self.variable_pool.convert_template(item.key).text: self.variable_pool.convert_template( + item.value + ).text + for item in data + } + self.data = form_data + case "form-data": + form_data = { + self.variable_pool.convert_template(item.key).text: self.variable_pool.convert_template( + item.value + ).text + for item in filter(lambda item: item.type == "text", data) + } + file_selectors = { + self.variable_pool.convert_template(item.key).text: item.file + for item in filter(lambda item: item.type == "file", data) + } + files = {k: self.variable_pool.get_file(selector) for k, selector in file_selectors.items()} + files = {k: v for k, v in files.items() if v is not None} + files = {k: variable.value for k, variable in files.items()} + files = {k: file_manager.download(v) for k, v in files.items() if v.related_id is not None} + + self.data = form_data + self.files = files + + def _assembling_headers(self) -> dict[str, Any]: + authorization = deepcopy(self.auth) + headers = deepcopy(self.headers) or {} + if self.auth.type == "api-key": + if self.auth.config is None: + raise ValueError("self.authorization config is required") + if authorization.config is None: + raise ValueError("authorization config is required") + + if self.auth.config.api_key is None: + raise ValueError("api_key is 
required") + + if not authorization.config.header: + authorization.config.header = "Authorization" + + if self.auth.config.type == "bearer": + headers[authorization.config.header] = f"Bearer {authorization.config.api_key}" + elif self.auth.config.type == "basic": + headers[authorization.config.header] = f"Basic {authorization.config.api_key}" + elif self.auth.config.type == "custom": + headers[authorization.config.header] = authorization.config.api_key or "" + + return headers + + def _validate_and_parse_response(self, response: httpx.Response) -> Response: + executor_response = Response(response) + + threshold_size = ( + dify_config.HTTP_REQUEST_NODE_MAX_BINARY_SIZE + if executor_response.is_file + else dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE + ) + if executor_response.size > threshold_size: + raise ValueError( + f'{"File" if executor_response.is_file else "Text"} size is too large,' + f' max size is {threshold_size / 1024 / 1024:.2f} MB,' + f' but current size is {executor_response.readable_size}.' + ) + + return executor_response + + def _do_http_request(self, headers: dict[str, Any]) -> httpx.Response: + """ + do http request depending on api bundle + """ + if self.method not in {"get", "head", "post", "put", "delete", "patch"}: + raise ValueError(f"Invalid http method {self.method}") + + request_args = { + "url": self.url, + "data": self.data, + "files": self.files, + "json": self.json, + "content": self.content, + "headers": headers, + "params": self.params, + "timeout": (self.timeout.connect, self.timeout.read, self.timeout.write), + "follow_redirects": True, + } + + response = getattr(ssrf_proxy, self.method)(**request_args) + return response + + def invoke(self) -> Response: + # assemble headers + headers = self._assembling_headers() + # do http request + response = self._do_http_request(headers) + # validate response + return self._validate_and_parse_response(response) + + def to_log(self): + url_parts = urlparse(self.url) + path = url_parts.path or "/" + + # Add query parameters + if self.params: + query_string = urlencode(self.params) + path += f"?{query_string}" + elif url_parts.query: + path += f"?{url_parts.query}" + + raw = f"{self.method.upper()} {path} HTTP/1.1\r\n" + raw += f"Host: {url_parts.netloc}\r\n" + + headers = self._assembling_headers() + for k, v in headers.items(): + if self.auth.type == "api-key": + authorization_header = "Authorization" + if self.auth.config and self.auth.config.header: + authorization_header = self.auth.config.header + if k.lower() == authorization_header.lower(): + raw += f'{k}: {"*" * len(v)}\r\n' + continue + raw += f"{k}: {v}\r\n" + + body = "" + if self.files: + boundary = self.boundary + for k, v in self.files.items(): + body += f"--{boundary}\r\n" + body += f'Content-Disposition: form-data; name="{k}"\r\n\r\n' + body += f"{v[1]}\r\n" + body += f"--{boundary}--\r\n" + elif self.node_data.body: + if self.content: + if isinstance(self.content, str): + body = self.content + elif isinstance(self.content, bytes): + body = self.content.decode("utf-8", errors="replace") + elif self.data and self.node_data.body.type == "x-www-form-urlencoded": + body = urlencode(self.data) + elif self.data and self.node_data.body.type == "form-data": + boundary = self.boundary + for key, value in self.data.items(): + body += f"--{boundary}\r\n" + body += f'Content-Disposition: form-data; name="{key}"\r\n\r\n' + body += f"{value}\r\n" + body += f"--{boundary}--\r\n" + elif self.json: + body = json.dumps(self.json) + elif self.node_data.body.type == 
"raw-text": + body = self.node_data.body.data[0].value + if body: + raw += f"Content-Length: {len(body)}\r\n" + raw += "\r\n" # Empty line between headers and body + raw += body + + return raw + + def _parse_object_contains_variables(self, obj: str | dict | list, /) -> Mapping[str, Any] | Sequence[Any] | str: + if isinstance(obj, dict): + return {k: self._parse_object_contains_variables(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [self._parse_object_contains_variables(v) for v in obj] + elif isinstance(obj, str): + return self.variable_pool.convert_template(obj).text + + +def _plain_text_to_dict(text: str, /) -> dict[str, str]: + """ + Convert a string of key-value pairs to a dictionary. + + Each line in the input string represents a key-value pair. + Keys and values are separated by ':'. + Empty values are allowed. + + Examples: + 'aa:bb\n cc:dd' -> {'aa': 'bb', 'cc': 'dd'} + 'aa:\n cc:dd\n' -> {'aa': '', 'cc': 'dd'} + 'aa\n cc : dd' -> {'aa': '', 'cc': 'dd'} + + Args: + convert_text (str): The input string to convert. + + Returns: + dict[str, str]: A dictionary of key-value pairs. + """ + return { + key.strip(): (value[0].strip() if value else "") + for line in text.splitlines() + if line.strip() + for key, *value in [line.split(":", 1)] + } + + +def _generate_random_string(n: int) -> str: + """ + Generate a random string of lowercase ASCII letters. + + Args: + n (int): The length of the random string to generate. + + Returns: + str: A random string of lowercase ASCII letters with length n. + + Example: + >>> _generate_random_string(5) + 'abcde' + """ + return "".join([chr(randint(97, 122)) for _ in range(n)]) diff --git a/api/core/workflow/nodes/http_request/http_executor.py b/api/core/workflow/nodes/http_request/http_executor.py deleted file mode 100644 index f8ab4e3132..0000000000 --- a/api/core/workflow/nodes/http_request/http_executor.py +++ /dev/null @@ -1,343 +0,0 @@ -import json -from copy import deepcopy -from random import randint -from typing import Any, Optional, Union -from urllib.parse import urlencode - -import httpx - -from configs import dify_config -from core.helper import ssrf_proxy -from core.workflow.entities.variable_entities import VariableSelector -from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.http_request.entities import ( - HttpRequestNodeAuthorization, - HttpRequestNodeBody, - HttpRequestNodeData, - HttpRequestNodeTimeout, -) -from core.workflow.utils.variable_template_parser import VariableTemplateParser - - -class HttpExecutorResponse: - headers: dict[str, str] - response: httpx.Response - - def __init__(self, response: httpx.Response): - self.response = response - self.headers = dict(response.headers) if isinstance(self.response, httpx.Response) else {} - - @property - def is_file(self) -> bool: - """ - check if response is file - """ - content_type = self.get_content_type() - file_content_types = ["image", "audio", "video"] - - return any(v in content_type for v in file_content_types) - - def get_content_type(self) -> str: - return self.headers.get("content-type", "") - - def extract_file(self) -> tuple[str, bytes]: - """ - extract file from response if content type is file related - """ - if self.is_file: - return self.get_content_type(), self.body - - return "", b"" - - @property - def content(self) -> str: - if isinstance(self.response, httpx.Response): - return self.response.text - else: - raise ValueError(f"Invalid response type {type(self.response)}") - - @property - def body(self) -> 
bytes: - if isinstance(self.response, httpx.Response): - return self.response.content - else: - raise ValueError(f"Invalid response type {type(self.response)}") - - @property - def status_code(self) -> int: - if isinstance(self.response, httpx.Response): - return self.response.status_code - else: - raise ValueError(f"Invalid response type {type(self.response)}") - - @property - def size(self) -> int: - return len(self.body) - - @property - def readable_size(self) -> str: - if self.size < 1024: - return f"{self.size} bytes" - elif self.size < 1024 * 1024: - return f"{(self.size / 1024):.2f} KB" - else: - return f"{(self.size / 1024 / 1024):.2f} MB" - - -class HttpExecutor: - server_url: str - method: str - authorization: HttpRequestNodeAuthorization - params: dict[str, Any] - headers: dict[str, Any] - body: Union[None, str] - files: Union[None, dict[str, Any]] - boundary: str - variable_selectors: list[VariableSelector] - timeout: HttpRequestNodeTimeout - - def __init__( - self, - node_data: HttpRequestNodeData, - timeout: HttpRequestNodeTimeout, - variable_pool: Optional[VariablePool] = None, - ): - self.server_url = node_data.url - self.method = node_data.method - self.authorization = node_data.authorization - self.timeout = timeout - self.params = {} - self.headers = {} - self.body = None - self.files = None - - # init template - self.variable_selectors = [] - self._init_template(node_data, variable_pool) - - @staticmethod - def _is_json_body(body: HttpRequestNodeBody): - """ - check if body is json - """ - if body and body.type == "json" and body.data: - try: - json.loads(body.data) - return True - except: - return False - - return False - - @staticmethod - def _to_dict(convert_text: str): - """ - Convert the string like `aa:bb\n cc:dd` to dict `{aa:bb, cc:dd}` - """ - kv_paris = convert_text.split("\n") - result = {} - for kv in kv_paris: - if not kv.strip(): - continue - - kv = kv.split(":", maxsplit=1) - if len(kv) == 1: - k, v = kv[0], "" - else: - k, v = kv - result[k.strip()] = v - return result - - def _init_template(self, node_data: HttpRequestNodeData, variable_pool: Optional[VariablePool] = None): - # extract all template in url - self.server_url, server_url_variable_selectors = self._format_template(node_data.url, variable_pool) - - # extract all template in params - params, params_variable_selectors = self._format_template(node_data.params, variable_pool) - self.params = self._to_dict(params) - - # extract all template in headers - headers, headers_variable_selectors = self._format_template(node_data.headers, variable_pool) - self.headers = self._to_dict(headers) - - # extract all template in body - body_data_variable_selectors = [] - if node_data.body: - # check if it's a valid JSON - is_valid_json = self._is_json_body(node_data.body) - - body_data = node_data.body.data or "" - if body_data: - body_data, body_data_variable_selectors = self._format_template(body_data, variable_pool, is_valid_json) - - content_type_is_set = any(key.lower() == "content-type" for key in self.headers) - if node_data.body.type == "json" and not content_type_is_set: - self.headers["Content-Type"] = "application/json" - elif node_data.body.type == "x-www-form-urlencoded" and not content_type_is_set: - self.headers["Content-Type"] = "application/x-www-form-urlencoded" - - if node_data.body.type in {"form-data", "x-www-form-urlencoded"}: - body = self._to_dict(body_data) - - if node_data.body.type == "form-data": - self.files = {k: ("", v) for k, v in body.items()} - random_str = lambda n: 
"".join([chr(randint(97, 122)) for _ in range(n)]) - self.boundary = f"----WebKitFormBoundary{random_str(16)}" - - self.headers["Content-Type"] = f"multipart/form-data; boundary={self.boundary}" - else: - self.body = urlencode(body) - elif node_data.body.type in {"json", "raw-text"}: - self.body = body_data - elif node_data.body.type == "none": - self.body = "" - - self.variable_selectors = ( - server_url_variable_selectors - + params_variable_selectors - + headers_variable_selectors - + body_data_variable_selectors - ) - - def _assembling_headers(self) -> dict[str, Any]: - authorization = deepcopy(self.authorization) - headers = deepcopy(self.headers) or {} - if self.authorization.type == "api-key": - if self.authorization.config is None: - raise ValueError("self.authorization config is required") - if authorization.config is None: - raise ValueError("authorization config is required") - - if self.authorization.config.api_key is None: - raise ValueError("api_key is required") - - if not authorization.config.header: - authorization.config.header = "Authorization" - - if self.authorization.config.type == "bearer": - headers[authorization.config.header] = f"Bearer {authorization.config.api_key}" - elif self.authorization.config.type == "basic": - headers[authorization.config.header] = f"Basic {authorization.config.api_key}" - elif self.authorization.config.type == "custom": - headers[authorization.config.header] = authorization.config.api_key - - return headers - - def _validate_and_parse_response(self, response: httpx.Response) -> HttpExecutorResponse: - """ - validate the response - """ - if isinstance(response, httpx.Response): - executor_response = HttpExecutorResponse(response) - else: - raise ValueError(f"Invalid response type {type(response)}") - - threshold_size = ( - dify_config.HTTP_REQUEST_NODE_MAX_BINARY_SIZE - if executor_response.is_file - else dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE - ) - if executor_response.size > threshold_size: - raise ValueError( - f'{"File" if executor_response.is_file else "Text"} size is too large,' - f' max size is {threshold_size / 1024 / 1024:.2f} MB,' - f' but current size is {executor_response.readable_size}.' 
- ) - - return executor_response - - def _do_http_request(self, headers: dict[str, Any]) -> httpx.Response: - """ - do http request depending on api bundle - """ - kwargs = { - "url": self.server_url, - "headers": headers, - "params": self.params, - "timeout": (self.timeout.connect, self.timeout.read, self.timeout.write), - "follow_redirects": True, - } - - if self.method in {"get", "head", "post", "put", "delete", "patch"}: - response = getattr(ssrf_proxy, self.method)(data=self.body, files=self.files, **kwargs) - else: - raise ValueError(f"Invalid http method {self.method}") - return response - - def invoke(self) -> HttpExecutorResponse: - """ - invoke http request - """ - # assemble headers - headers = self._assembling_headers() - - # do http request - response = self._do_http_request(headers) - - # validate response - return self._validate_and_parse_response(response) - - def to_raw_request(self) -> str: - """ - convert to raw request - """ - server_url = self.server_url - if self.params: - server_url += f"?{urlencode(self.params)}" - - raw_request = f"{self.method.upper()} {server_url} HTTP/1.1\n" - - headers = self._assembling_headers() - for k, v in headers.items(): - # get authorization header - if self.authorization.type == "api-key": - authorization_header = "Authorization" - if self.authorization.config and self.authorization.config.header: - authorization_header = self.authorization.config.header - - if k.lower() == authorization_header.lower(): - raw_request += f'{k}: {"*" * len(v)}\n' - continue - - raw_request += f"{k}: {v}\n" - - raw_request += "\n" - - # if files, use multipart/form-data with boundary - if self.files: - boundary = self.boundary - raw_request += f"--{boundary}" - for k, v in self.files.items(): - raw_request += f'\nContent-Disposition: form-data; name="{k}"\n\n' - raw_request += f"{v[1]}\n" - raw_request += f"--{boundary}" - raw_request += "--" - else: - raw_request += self.body or "" - - return raw_request - - def _format_template( - self, template: str, variable_pool: Optional[VariablePool], escape_quotes: bool = False - ) -> tuple[str, list[VariableSelector]]: - """ - format template - """ - variable_template_parser = VariableTemplateParser(template=template) - variable_selectors = variable_template_parser.extract_variable_selectors() - - if variable_pool: - variable_value_mapping = {} - for variable_selector in variable_selectors: - variable = variable_pool.get_any(variable_selector.value_selector) - if variable is None: - raise ValueError(f"Variable {variable_selector.variable} not found") - if escape_quotes and isinstance(variable, str): - value = variable.replace('"', '\\"').replace("\n", "\\n") - else: - value = variable - variable_value_mapping[variable_selector.variable] = value - - return variable_template_parser.format(variable_value_mapping), variable_selectors - else: - return template, variable_selectors diff --git a/api/core/workflow/nodes/http_request/http_request_node.py b/api/core/workflow/nodes/http_request/http_request_node.py deleted file mode 100644 index cd40819126..0000000000 --- a/api/core/workflow/nodes/http_request/http_request_node.py +++ /dev/null @@ -1,165 +0,0 @@ -import logging -from collections.abc import Mapping, Sequence -from mimetypes import guess_extension -from os import path -from typing import Any, cast - -from configs import dify_config -from core.app.segments import parser -from core.file.file_obj import FileTransferMethod, FileType, FileVar -from core.tools.tool_file_manager import ToolFileManager -from 
core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.nodes.base_node import BaseNode -from core.workflow.nodes.http_request.entities import ( - HttpRequestNodeData, - HttpRequestNodeTimeout, -) -from core.workflow.nodes.http_request.http_executor import HttpExecutor, HttpExecutorResponse -from models.workflow import WorkflowNodeExecutionStatus - -HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout( - connect=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, - read=dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT, - write=dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT, -) - - -class HttpRequestNode(BaseNode): - _node_data_cls = HttpRequestNodeData - _node_type = NodeType.HTTP_REQUEST - - @classmethod - def get_default_config(cls, filters: dict | None = None) -> dict: - return { - "type": "http-request", - "config": { - "method": "get", - "authorization": { - "type": "no-auth", - }, - "body": {"type": "none"}, - "timeout": { - **HTTP_REQUEST_DEFAULT_TIMEOUT.model_dump(), - "max_connect_timeout": dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, - "max_read_timeout": dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT, - "max_write_timeout": dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT, - }, - }, - } - - def _run(self) -> NodeRunResult: - node_data: HttpRequestNodeData = cast(HttpRequestNodeData, self.node_data) - # TODO: Switch to use segment directly - if node_data.authorization.config and node_data.authorization.config.api_key: - node_data.authorization.config.api_key = parser.convert_template( - template=node_data.authorization.config.api_key, variable_pool=self.graph_runtime_state.variable_pool - ).text - - # init http executor - http_executor = None - try: - http_executor = HttpExecutor( - node_data=node_data, - timeout=self._get_request_timeout(node_data), - variable_pool=self.graph_runtime_state.variable_pool, - ) - - # invoke http executor - response = http_executor.invoke() - except Exception as e: - process_data = {} - if http_executor: - process_data = { - "request": http_executor.to_raw_request(), - } - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e), - process_data=process_data, - ) - - files = self.extract_files(http_executor.server_url, response) - - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={ - "status_code": response.status_code, - "body": response.content if not files else "", - "headers": response.headers, - "files": files, - }, - process_data={ - "request": http_executor.to_raw_request(), - }, - ) - - @staticmethod - def _get_request_timeout(node_data: HttpRequestNodeData) -> HttpRequestNodeTimeout: - timeout = node_data.timeout - if timeout is None: - return HTTP_REQUEST_DEFAULT_TIMEOUT - - timeout.connect = timeout.connect or HTTP_REQUEST_DEFAULT_TIMEOUT.connect - timeout.read = timeout.read or HTTP_REQUEST_DEFAULT_TIMEOUT.read - timeout.write = timeout.write or HTTP_REQUEST_DEFAULT_TIMEOUT.write - return timeout - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, graph_config: Mapping[str, Any], node_id: str, node_data: HttpRequestNodeData - ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping - :param graph_config: graph config - :param node_id: node id - :param node_data: node data - :return: - """ - try: - http_executor = HttpExecutor(node_data=node_data, timeout=HTTP_REQUEST_DEFAULT_TIMEOUT) - - variable_selectors = http_executor.variable_selectors - - variable_mapping = {} - for variable_selector in variable_selectors: - 
variable_mapping[node_id + "." + variable_selector.variable] = variable_selector.value_selector - - return variable_mapping - except Exception as e: - logging.exception(f"Failed to extract variable selector to variable mapping: {e}") - return {} - - def extract_files(self, url: str, response: HttpExecutorResponse) -> list[FileVar]: - """ - Extract files from response - """ - files = [] - mimetype, file_binary = response.extract_file() - - if mimetype: - # extract filename from url - filename = path.basename(url) - # extract extension if possible - extension = guess_extension(mimetype) or ".bin" - - tool_file = ToolFileManager.create_file_by_raw( - user_id=self.user_id, - tenant_id=self.tenant_id, - conversation_id=None, - file_binary=file_binary, - mimetype=mimetype, - ) - - files.append( - FileVar( - tenant_id=self.tenant_id, - type=FileType.IMAGE, - transfer_method=FileTransferMethod.TOOL_FILE, - related_id=tool_file.id, - filename=filename, - extension=extension, - mime_type=mimetype, - ) - ) - - return files diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py new file mode 100644 index 0000000000..483d0e2b7e --- /dev/null +++ b/api/core/workflow/nodes/http_request/node.py @@ -0,0 +1,174 @@ +import logging +from collections.abc import Mapping, Sequence +from mimetypes import guess_extension +from os import path +from typing import Any + +from configs import dify_config +from core.file import File, FileTransferMethod, FileType +from core.tools.tool_file_manager import ToolFileManager +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.entities.variable_entities import VariableSelector +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.http_request.executor import Executor +from core.workflow.utils import variable_template_parser +from models.workflow import WorkflowNodeExecutionStatus + +from .entities import ( + HttpRequestNodeData, + HttpRequestNodeTimeout, + Response, +) + +HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout( + connect=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, + read=dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT, + write=dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT, +) + +logger = logging.getLogger(__name__) + + +class HttpRequestNode(BaseNode[HttpRequestNodeData]): + _node_data_cls = HttpRequestNodeData + _node_type = NodeType.HTTP_REQUEST + + @classmethod + def get_default_config(cls, filters: dict | None = None) -> dict: + return { + "type": "http-request", + "config": { + "method": "get", + "authorization": { + "type": "no-auth", + }, + "body": {"type": "none"}, + "timeout": { + **HTTP_REQUEST_DEFAULT_TIMEOUT.model_dump(), + "max_connect_timeout": dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, + "max_read_timeout": dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT, + "max_write_timeout": dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT, + }, + }, + } + + def _run(self) -> NodeRunResult: + process_data = {} + try: + http_executor = Executor( + node_data=self.node_data, + timeout=self._get_request_timeout(self.node_data), + variable_pool=self.graph_runtime_state.variable_pool, + ) + process_data["request"] = http_executor.to_log() + + response = http_executor.invoke() + files = self.extract_files(url=http_executor.url, response=response) + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + outputs={ + "status_code": response.status_code, + "body": response.text if not files else "", + "headers": 
response.headers, + "files": files, + }, + process_data={ + "request": http_executor.to_log(), + }, + ) + except Exception as e: + logger.warning(f"http request node {self.node_id} failed to run: {e}") + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e), + process_data=process_data, + ) + + @staticmethod + def _get_request_timeout(node_data: HttpRequestNodeData) -> HttpRequestNodeTimeout: + timeout = node_data.timeout + if timeout is None: + return HTTP_REQUEST_DEFAULT_TIMEOUT + + timeout.connect = timeout.connect or HTTP_REQUEST_DEFAULT_TIMEOUT.connect + timeout.read = timeout.read or HTTP_REQUEST_DEFAULT_TIMEOUT.read + timeout.write = timeout.write or HTTP_REQUEST_DEFAULT_TIMEOUT.write + return timeout + + @classmethod + def _extract_variable_selector_to_variable_mapping( + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: HttpRequestNodeData, + ) -> Mapping[str, Sequence[str]]: + selectors: list[VariableSelector] = [] + selectors += variable_template_parser.extract_selectors_from_template(node_data.headers) + selectors += variable_template_parser.extract_selectors_from_template(node_data.params) + if node_data.body: + body_type = node_data.body.type + data = node_data.body.data + match body_type: + case "binary": + selector = data[0].file + selectors.append(VariableSelector(variable="#" + ".".join(selector) + "#", value_selector=selector)) + case "json" | "raw-text": + selectors += variable_template_parser.extract_selectors_from_template(data[0].key) + selectors += variable_template_parser.extract_selectors_from_template(data[0].value) + case "x-www-form-urlencoded": + for item in data: + selectors += variable_template_parser.extract_selectors_from_template(item.key) + selectors += variable_template_parser.extract_selectors_from_template(item.value) + case "form-data": + for item in data: + selectors += variable_template_parser.extract_selectors_from_template(item.key) + if item.type == "text": + selectors += variable_template_parser.extract_selectors_from_template(item.value) + elif item.type == "file": + selectors.append( + VariableSelector(variable="#" + ".".join(item.file) + "#", value_selector=item.file) + ) + + mapping = {} + for selector in selectors: + mapping[node_id + "." 
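
_get_request_timeout above merges a node's partially specified timeout with global defaults field by field rather than all-or-nothing. A minimal sketch of that merge, with made-up numbers standing in for the dify_config values:

```python
from typing import Optional

from pydantic import BaseModel


class Timeout(BaseModel):
    connect: Optional[int] = None
    read: Optional[int] = None
    write: Optional[int] = None


DEFAULT = Timeout(connect=10, read=60, write=20)  # stand-in defaults


def merge_timeout(timeout: Optional[Timeout]) -> Timeout:
    if timeout is None:
        return DEFAULT
    # Each unset field falls back to its default independently.
    timeout.connect = timeout.connect or DEFAULT.connect
    timeout.read = timeout.read or DEFAULT.read
    timeout.write = timeout.write or DEFAULT.write
    return timeout


assert merge_timeout(None) == DEFAULT
assert merge_timeout(Timeout(read=300)).connect == 10
```
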
+ selector.variable] = selector.value_selector + + return mapping + + def extract_files(self, url: str, response: Response) -> list[File]: + """ + Extract files from response + """ + files = [] + content_type = response.content_type + content = response.content + + if content_type: + # extract filename from url + filename = path.basename(url) + # extract extension if possible + extension = guess_extension(content_type) or ".bin" + + tool_file = ToolFileManager.create_file_by_raw( + user_id=self.user_id, + tenant_id=self.tenant_id, + conversation_id=None, + file_binary=content, + mimetype=content_type, + ) + + files.append( + File( + tenant_id=self.tenant_id, + type=FileType.IMAGE, + transfer_method=FileTransferMethod.TOOL_FILE, + related_id=tool_file.id, + filename=filename, + extension=extension, + mime_type=content_type, + ) + ) + + return files diff --git a/api/core/workflow/nodes/if_else/__init__.py b/api/core/workflow/nodes/if_else/__init__.py index e69de29bb2..afa0e8112c 100644 --- a/api/core/workflow/nodes/if_else/__init__.py +++ b/api/core/workflow/nodes/if_else/__init__.py @@ -0,0 +1,3 @@ +from .if_else_node import IfElseNode + +__all__ = ["IfElseNode"] diff --git a/api/core/workflow/nodes/if_else/entities.py b/api/core/workflow/nodes/if_else/entities.py index 54c1081fd3..23f5d2cc31 100644 --- a/api/core/workflow/nodes/if_else/entities.py +++ b/api/core/workflow/nodes/if_else/entities.py @@ -1,8 +1,8 @@ from typing import Literal, Optional -from pydantic import BaseModel +from pydantic import BaseModel, Field -from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.nodes.base import BaseNodeData from core.workflow.utils.condition.entities import Condition @@ -21,6 +21,6 @@ class IfElseNodeData(BaseNodeData): conditions: list[Condition] logical_operator: Optional[Literal["and", "or"]] = "and" - conditions: Optional[list[Condition]] = None + conditions: Optional[list[Condition]] = Field(default=None, deprecated=True) cases: Optional[list[Case]] = None diff --git a/api/core/workflow/nodes/if_else/if_else_node.py b/api/core/workflow/nodes/if_else/if_else_node.py index 37384202d8..6960fc045a 100644 --- a/api/core/workflow/nodes/if_else/if_else_node.py +++ b/api/core/workflow/nodes/if_else/if_else_node.py @@ -1,14 +1,19 @@ from collections.abc import Mapping, Sequence -from typing import Any, cast +from typing import Any, Literal -from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.nodes.base_node import BaseNode +from typing_extensions import deprecated + +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.enums import NodeType from core.workflow.nodes.if_else.entities import IfElseNodeData +from core.workflow.utils.condition.entities import Condition from core.workflow.utils.condition.processor import ConditionProcessor from models.workflow import WorkflowNodeExecutionStatus -class IfElseNode(BaseNode): +class IfElseNode(BaseNode[IfElseNodeData]): _node_data_cls = IfElseNodeData _node_type = NodeType.IF_ELSE @@ -17,9 +22,6 @@ class IfElseNode(BaseNode): Run node :return: """ - node_data = self.node_data - node_data = cast(IfElseNodeData, node_data) - node_inputs: dict[str, list] = {"conditions": []} process_datas: dict[str, list] = {"condition_results": []} @@ -30,15 +32,14 @@ class IfElseNode(BaseNode): condition_processor = ConditionProcessor() try: # Check if 
the new cases structure is used - if node_data.cases: - for case in node_data.cases: - input_conditions, group_result = condition_processor.process_conditions( - variable_pool=self.graph_runtime_state.variable_pool, conditions=case.conditions + if self.node_data.cases: + for case in self.node_data.cases: + input_conditions, group_result, final_result = condition_processor.process_conditions( + variable_pool=self.graph_runtime_state.variable_pool, + conditions=case.conditions, + operator=case.logical_operator, ) - # Apply the logical operator for the current case - final_result = all(group_result) if case.logical_operator == "and" else any(group_result) - process_datas["condition_results"].append( { "group": case.model_dump(), @@ -53,13 +54,15 @@ class IfElseNode(BaseNode): break else: + # TODO: Update database then remove this # Fallback to old structure if cases are not defined - input_conditions, group_result = condition_processor.process_conditions( - variable_pool=self.graph_runtime_state.variable_pool, conditions=node_data.conditions + input_conditions, group_result, final_result = _should_not_use_old_function( + condition_processor=condition_processor, + variable_pool=self.graph_runtime_state.variable_pool, + conditions=self.node_data.conditions or [], + operator=self.node_data.logical_operator or "and", ) - final_result = all(group_result) if node_data.logical_operator == "and" else any(group_result) - selected_case_id = "true" if final_result else "false" process_datas["condition_results"].append( @@ -87,7 +90,11 @@ class IfElseNode(BaseNode): @classmethod def _extract_variable_selector_to_variable_mapping( - cls, graph_config: Mapping[str, Any], node_id: str, node_data: IfElseNodeData + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: IfElseNodeData, ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -97,3 +104,18 @@ class IfElseNode(BaseNode): :return: """ return {} + + +@deprecated("This function is deprecated. You should use the new cases structure.") +def _should_not_use_old_function( + *, + condition_processor: ConditionProcessor, + variable_pool: VariablePool, + conditions: list[Condition], + operator: Literal["and", "or"], +): + return condition_processor.process_conditions( + variable_pool=variable_pool, + conditions=conditions, + operator=operator, + ) diff --git a/api/core/workflow/nodes/iteration/__init__.py b/api/core/workflow/nodes/iteration/__init__.py index e69de29bb2..5bb87aaffa 100644 --- a/api/core/workflow/nodes/iteration/__init__.py +++ b/api/core/workflow/nodes/iteration/__init__.py @@ -0,0 +1,5 @@ +from .entities import IterationNodeData +from .iteration_node import IterationNode +from .iteration_start_node import IterationStartNode + +__all__ = ["IterationNode", "IterationNodeData", "IterationStartNode"] diff --git a/api/core/workflow/nodes/iteration/entities.py b/api/core/workflow/nodes/iteration/entities.py index 3c2c189159..4afc870e50 100644 --- a/api/core/workflow/nodes/iteration/entities.py +++ b/api/core/workflow/nodes/iteration/entities.py @@ -1,6 +1,8 @@ from typing import Any, Optional -from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState, BaseNodeData +from pydantic import Field + +from core.workflow.nodes.base import BaseIterationNodeData, BaseIterationState, BaseNodeData class IterationNodeData(BaseIterationNodeData): @@ -26,7 +28,7 @@ class IterationState(BaseIterationState): Iteration State. 
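
The legacy flat-conditions fallback above is funneled through a helper marked with typing_extensions.deprecated, so type checkers flag any remaining callers while runtime behavior stays intact. A small sketch of the pattern:

```python
from typing_extensions import deprecated


def evaluate(results: list[bool], operator: str) -> bool:
    """Combine per-condition results under "and"/"or", as the processor does."""
    return all(results) if operator == "and" else any(results)


@deprecated("Use the cases structure instead of flat conditions.")
def evaluate_legacy(results: list[bool], operator: str) -> bool:
    # Same behavior; the decorator emits a DeprecationWarning at runtime
    # and surfaces a diagnostic in type checkers that support PEP 702.
    return evaluate(results, operator)


assert evaluate_legacy([True, False], "or") is True
```
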
""" - outputs: list[Any] = None + outputs: list[Any] = Field(default_factory=list) current_output: Optional[Any] = None class MetaData(BaseIterationState.MetaData): diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 01bb4e9076..b28ae0a85c 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -5,7 +5,7 @@ from typing import Any, cast from configs import dify_config from core.model_runtime.utils.encoders import jsonable_encoder -from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult from core.workflow.graph_engine.entities.event import ( BaseGraphEvent, BaseNodeEvent, @@ -20,15 +20,16 @@ from core.workflow.graph_engine.entities.event import ( NodeRunSucceededEvent, ) from core.workflow.graph_engine.entities.graph import Graph -from core.workflow.nodes.base_node import BaseNode -from core.workflow.nodes.event import RunCompletedEvent, RunEvent +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.event import NodeEvent, RunCompletedEvent from core.workflow.nodes.iteration.entities import IterationNodeData from models.workflow import WorkflowNodeExecutionStatus logger = logging.getLogger(__name__) -class IterationNode(BaseNode): +class IterationNode(BaseNode[IterationNodeData]): """ Iteration Node. """ @@ -36,11 +37,10 @@ class IterationNode(BaseNode): _node_data_cls = IterationNodeData _node_type = NodeType.ITERATION - def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]: + def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: """ Run the node. 
""" - self.node_data = cast(IterationNodeData, self.node_data) iterator_list_segment = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector) if not iterator_list_segment: @@ -177,7 +177,7 @@ class IterationNode(BaseNode): # remove all nodes outputs from variable pool for node_id in iteration_graph.node_ids: - variable_pool.remove_node(node_id) + variable_pool.remove([node_id]) # move to next iteration current_index = variable_pool.get([self.node_id, "index"]) @@ -247,7 +247,11 @@ class IterationNode(BaseNode): @classmethod def _extract_variable_selector_to_variable_mapping( - cls, graph_config: Mapping[str, Any], node_id: str, node_data: IterationNodeData + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: IterationNodeData, ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -273,15 +277,13 @@ class IterationNode(BaseNode): # variable selector to variable mapping try: # Get node class - from core.workflow.nodes.node_mapping import node_classes + from core.workflow.nodes.node_mapping import node_type_classes_mapping - node_type = NodeType.value_of(sub_node_config.get("data", {}).get("type")) - node_cls = node_classes.get(node_type) + node_type = NodeType(sub_node_config.get("data", {}).get("type")) + node_cls = node_type_classes_mapping.get(node_type) if not node_cls: continue - node_cls = cast(BaseNode, node_cls) - sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( graph_config=graph_config, config=sub_node_config ) diff --git a/api/core/workflow/nodes/iteration/iteration_start_node.py b/api/core/workflow/nodes/iteration/iteration_start_node.py index 88b9665ac6..6ab7c30106 100644 --- a/api/core/workflow/nodes/iteration/iteration_start_node.py +++ b/api/core/workflow/nodes/iteration/iteration_start_node.py @@ -1,8 +1,9 @@ from collections.abc import Mapping, Sequence from typing import Any -from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.nodes.base_node import BaseNode +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.enums import NodeType from core.workflow.nodes.iteration.entities import IterationNodeData, IterationStartNodeData from models.workflow import WorkflowNodeExecutionStatus diff --git a/api/core/workflow/nodes/knowledge_retrieval/__init__.py b/api/core/workflow/nodes/knowledge_retrieval/__init__.py index e69de29bb2..4d4a4cbd9f 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/__init__.py +++ b/api/core/workflow/nodes/knowledge_retrieval/__init__.py @@ -0,0 +1,3 @@ +from .knowledge_retrieval_node import KnowledgeRetrievalNode + +__all__ = ["KnowledgeRetrievalNode"] diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py index 1cd88039b1..e8972d1381 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -2,7 +2,7 @@ from typing import Any, Literal, Optional from pydantic import BaseModel -from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.nodes.base import BaseNodeData class RerankingModelConfig(BaseModel): diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 8cd208d7fc..b286f34d7f 100644 --- 
a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -14,8 +14,9 @@ from core.model_runtime.entities.model_entities import ModelFeature, ModelType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod -from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.nodes.base_node import BaseNode +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.enums import NodeType from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment @@ -32,15 +33,13 @@ default_retrieval_model = { } -class KnowledgeRetrievalNode(BaseNode): +class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]): _node_data_cls = KnowledgeRetrievalNodeData - node_type = NodeType.KNOWLEDGE_RETRIEVAL + _node_type = NodeType.KNOWLEDGE_RETRIEVAL def _run(self) -> NodeRunResult: - node_data = cast(KnowledgeRetrievalNodeData, self.node_data) - # extract variables - variable = self.graph_runtime_state.variable_pool.get_any(node_data.query_variable_selector) + variable = self.graph_runtime_state.variable_pool.get_any(self.node_data.query_variable_selector) query = variable variables = {"query": query} if not query: @@ -49,7 +48,7 @@ class KnowledgeRetrievalNode(BaseNode): ) # retrieve knowledge try: - results = self._fetch_dataset_retriever(node_data=node_data, query=query) + results = self._fetch_dataset_retriever(node_data=self.node_data, query=query) outputs = {"result": results} return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs @@ -244,7 +243,11 @@ class KnowledgeRetrievalNode(BaseNode): @classmethod def _extract_variable_selector_to_variable_mapping( - cls, graph_config: Mapping[str, Any], node_id: str, node_data: KnowledgeRetrievalNodeData + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: KnowledgeRetrievalNodeData, ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping diff --git a/api/core/workflow/nodes/list_operator/__init__.py b/api/core/workflow/nodes/list_operator/__init__.py new file mode 100644 index 0000000000..1877586ef4 --- /dev/null +++ b/api/core/workflow/nodes/list_operator/__init__.py @@ -0,0 +1,3 @@ +from .node import ListOperatorNode + +__all__ = ["ListOperatorNode"] diff --git a/api/core/workflow/nodes/list_operator/entities.py b/api/core/workflow/nodes/list_operator/entities.py new file mode 100644 index 0000000000..79cef1c27a --- /dev/null +++ b/api/core/workflow/nodes/list_operator/entities.py @@ -0,0 +1,56 @@ +from collections.abc import Sequence +from typing import Literal + +from pydantic import BaseModel, Field + +from core.workflow.nodes.base import BaseNodeData + +_Condition = Literal[ + # string conditions + "contains", + "start with", + "end with", + "is", + "in", + "empty", + "not contains", + "is not", + "not in", + "not empty", + # number conditions + "=", + "≠", + "<", + ">", + "≥", + "≤", +] + + +class FilterCondition(BaseModel): + key: str = "" + comparison_operator: _Condition = "contains" + value: str | Sequence[str] = "" + + +class FilterBy(BaseModel): 
+ enabled: bool = False + conditions: Sequence[FilterCondition] = Field(default_factory=list) + + +class OrderBy(BaseModel): + enabled: bool = False + key: str = "" + value: Literal["asc", "desc"] = "asc" + + +class Limit(BaseModel): + enabled: bool = False + size: int = -1 + + +class ListOperatorNodeData(BaseNodeData): + variable: Sequence[str] = Field(default_factory=list) + filter_by: FilterBy + order_by: OrderBy + limit: Limit diff --git a/api/core/workflow/nodes/list_operator/node.py b/api/core/workflow/nodes/list_operator/node.py new file mode 100644 index 0000000000..d7e4c64313 --- /dev/null +++ b/api/core/workflow/nodes/list_operator/node.py @@ -0,0 +1,259 @@ +from collections.abc import Callable, Sequence +from typing import Literal + +from core.file import File +from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.enums import NodeType +from models.workflow import WorkflowNodeExecutionStatus + +from .entities import ListOperatorNodeData + + +class ListOperatorNode(BaseNode[ListOperatorNodeData]): + _node_data_cls = ListOperatorNodeData + _node_type = NodeType.LIST_OPERATOR + + def _run(self): + inputs = {} + process_data = {} + outputs = {} + + variable = self.graph_runtime_state.variable_pool.get(self.node_data.variable) + if variable is None: + error_message = f"Variable not found for selector: {self.node_data.variable}" + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs + ) + if variable.value and not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment): + error_message = ( + f"Variable {self.node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment " + "or ArrayStringSegment" + ) + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs + ) + + if isinstance(variable, ArrayFileSegment): + process_data["variable"] = [item.to_dict() for item in variable.value] + else: + process_data["variable"] = variable.value + + # Filter + if self.node_data.filter_by.enabled: + for condition in self.node_data.filter_by.conditions: + if isinstance(variable, ArrayStringSegment): + if not isinstance(condition.value, str): + raise ValueError(f"Invalid filter value: {condition.value}") + value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text + filter_func = _get_string_filter_func(condition=condition.comparison_operator, value=value) + result = list(filter(filter_func, variable.value)) + variable = variable.model_copy(update={"value": result}) + elif isinstance(variable, ArrayNumberSegment): + if not isinstance(condition.value, str): + raise ValueError(f"Invalid filter value: {condition.value}") + value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text + filter_func = _get_number_filter_func(condition=condition.comparison_operator, value=float(value)) + result = list(filter(filter_func, variable.value)) + variable = variable.model_copy(update={"value": result}) + elif isinstance(variable, ArrayFileSegment): + if isinstance(condition.value, str): + value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text + else: + value = condition.value + filter_func = _get_file_filter_func( + key=condition.key, + condition=condition.comparison_operator, + value=value, + ) + result = 
list(filter(filter_func, variable.value)) + variable = variable.model_copy(update={"value": result}) + + # Order + if self.node_data.order_by.enabled: + if isinstance(variable, ArrayStringSegment): + result = _order_string(order=self.node_data.order_by.value, array=variable.value) + variable = variable.model_copy(update={"value": result}) + elif isinstance(variable, ArrayNumberSegment): + result = _order_number(order=self.node_data.order_by.value, array=variable.value) + variable = variable.model_copy(update={"value": result}) + elif isinstance(variable, ArrayFileSegment): + result = _order_file( + order=self.node_data.order_by.value, order_by=self.node_data.order_by.key, array=variable.value + ) + variable = variable.model_copy(update={"value": result}) + + # Slice + if self.node_data.limit.enabled: + result = variable.value[: self.node_data.limit.size] + variable = variable.model_copy(update={"value": result}) + + outputs = { + "result": variable.value, + "first_record": variable.value[0] if variable.value else None, + "last_record": variable.value[-1] if variable.value else None, + } + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=inputs, + process_data=process_data, + outputs=outputs, + ) + + +def _get_file_extract_number_func(*, key: str) -> Callable[[File], int]: + match key: + case "size": + return lambda x: x.size + case _: + raise ValueError(f"Invalid key: {key}") + + +def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]: + match key: + case "name": + return lambda x: x.filename or "" + case "type": + return lambda x: x.type + case "extension": + return lambda x: x.extension or "" + case "mime_type": + return lambda x: x.mime_type or "" + case "transfer_method": + return lambda x: x.transfer_method + case "url": + return lambda x: x.remote_url or "" + case _: + raise ValueError(f"Invalid key: {key}") + + +def _get_string_filter_func(*, condition: str, value: str) -> Callable[[str], bool]: + match condition: + case "contains": + return _contains(value) + case "start with": + return _startswith(value) + case "end with": + return _endswith(value) + case "is": + return _is(value) + case "in": + return _in(value) + case "empty": + return lambda x: x == "" + case "not contains": + return lambda x: not _contains(value)(x) + case "is not": + return lambda x: not _is(value)(x) + case "not in": + return lambda x: not _in(value)(x) + case "not empty": + return lambda x: x != "" + case _: + raise ValueError(f"Invalid condition: {condition}") + + +def _get_sequence_filter_func(*, condition: str, value: Sequence[str]) -> Callable[[str], bool]: + match condition: + case "in": + return _in(value) + case "not in": + return lambda x: not _in(value)(x) + case _: + raise ValueError(f"Invalid condition: {condition}") + + +def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[int | float], bool]: + match condition: + case "=": + return _eq(value) + case "≠": + return _ne(value) + case "<": + return _lt(value) + case "≤": + return _le(value) + case ">": + return _gt(value) + case "≥": + return _ge(value) + case _: + raise ValueError(f"Invalid condition: {condition}") + + +def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]: + if key in {"name", "extension", "mime_type", "url"} and isinstance(value, str): + extract_func = _get_file_extract_string_func(key=key) + return lambda x: _get_string_filter_func(condition=condition, value=value)(extract_func(x)) + elif key in {"type",
"transfer_method"} and isinstance(value, Sequence): + extract_func = _get_file_extract_string_func(key=key) + return lambda x: _get_sequence_filter_func(condition=condition, value=value)(extract_func(x)) + elif key == "size" and isinstance(value, str): + extract_func = _get_file_extract_number_func(key=key) + return lambda x: _get_number_filter_func(condition=condition, value=float(value))(extract_func(x)) + else: + raise ValueError(f"Invalid key: {key}") + + +def _contains(value: str): + return lambda x: value in x + + +def _startswith(value: str): + return lambda x: x.startswith(value) + + +def _endswith(value: str): + return lambda x: x.endswith(value) + + +def _is(value: str): + return lambda x: x == value + + +def _in(value: str | Sequence[str]): + return lambda x: x in value + + +def _eq(value: int | float): + return lambda x: x == value + + +def _ne(value: int | float): + return lambda x: x != value + + +def _lt(value: int | float): + return lambda x: x < value + + +def _le(value: int | float): + return lambda x: x <= value + + +def _gt(value: int | float): + return lambda x: x > value + + +def _ge(value: int | float): + return lambda x: x >= value + + +def _order_number(*, order: Literal["asc", "desc"], array: Sequence[int | float]): + return sorted(array, key=lambda x: x, reverse=order == "desc") + + +def _order_string(*, order: Literal["asc", "desc"], array: Sequence[str]): + return sorted(array, key=lambda x: x, reverse=order == "desc") + + +def _order_file(*, order: Literal["asc", "desc"], order_by: str = "", array: Sequence[File]): + if order_by in {"name", "type", "extension", "mime_type", "transfer_method", "url"}: + extract_func = _get_file_extract_string_func(key=order_by) + return sorted(array, key=lambda x: extract_func(x), reverse=order == "desc") + elif order_by == "size": + extract_func = _get_file_extract_number_func(key=order_by) + return sorted(array, key=lambda x: extract_func(x), reverse=order == "desc") + else: + raise ValueError(f"Invalid order key: {order_by}") diff --git a/api/core/workflow/nodes/llm/__init__.py b/api/core/workflow/nodes/llm/__init__.py index e69de29bb2..f7bc713f63 100644 --- a/api/core/workflow/nodes/llm/__init__.py +++ b/api/core/workflow/nodes/llm/__init__.py @@ -0,0 +1,17 @@ +from .entities import ( + LLMNodeChatModelMessage, + LLMNodeCompletionModelPromptTemplate, + LLMNodeData, + ModelConfig, + VisionConfig, +) +from .node import LLMNode + +__all__ = [ + "LLMNode", + "LLMNodeChatModelMessage", + "LLMNodeCompletionModelPromptTemplate", + "LLMNodeData", + "ModelConfig", + "VisionConfig", +] diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py index 93ee0ac250..b4de312461 100644 --- a/api/core/workflow/nodes/llm/entities.py +++ b/api/core/workflow/nodes/llm/entities.py @@ -1,17 +1,15 @@ -from typing import Any, Literal, Optional, Union +from collections.abc import Sequence +from typing import Any, Optional -from pydantic import BaseModel +from pydantic import BaseModel, Field +from core.model_runtime.entities import ImagePromptMessageContent from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig -from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.variable_entities import VariableSelector +from core.workflow.nodes.base import BaseNodeData class ModelConfig(BaseModel): - """ - Model Config.
- """ - provider: str name: str mode: str @@ -19,62 +17,36 @@ class ModelConfig(BaseModel): class ContextConfig(BaseModel): - """ - Context Config. - """ - enabled: bool variable_selector: Optional[list[str]] = None +class VisionConfigOptions(BaseModel): + variable_selector: Sequence[str] = Field(default_factory=lambda: ["sys", "files"]) + detail: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.HIGH + + class VisionConfig(BaseModel): - """ - Vision Config. - """ - - class Configs(BaseModel): - """ - Configs. - """ - - detail: Literal["low", "high"] - - enabled: bool - configs: Optional[Configs] = None + enabled: bool = False + configs: VisionConfigOptions = Field(default_factory=VisionConfigOptions) class PromptConfig(BaseModel): - """ - Prompt Config. - """ - jinja2_variables: Optional[list[VariableSelector]] = None class LLMNodeChatModelMessage(ChatModelMessage): - """ - LLM Node Chat Model Message. - """ - jinja2_text: Optional[str] = None class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate): - """ - LLM Node Chat Model Prompt Template. - """ - jinja2_text: Optional[str] = None class LLMNodeData(BaseNodeData): - """ - LLM Node Data. - """ - model: ModelConfig - prompt_template: Union[list[LLMNodeChatModelMessage], LLMNodeCompletionModelPromptTemplate] + prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate prompt_config: Optional[PromptConfig] = None memory: Optional[MemoryConfig] = None context: ContextConfig - vision: VisionConfig + vision: VisionConfig = Field(default_factory=VisionConfig) diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/node.py similarity index 76% rename from api/core/workflow/nodes/llm/llm_node.py rename to api/core/workflow/nodes/llm/node.py index 3d336b0b0b..24e479153e 100644 --- a/api/core/workflow/nodes/llm/llm_node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -1,39 +1,40 @@ import json from collections.abc import Generator, Mapping, Sequence -from copy import deepcopy from typing import TYPE_CHECKING, Any, Optional, cast -from pydantic import BaseModel - from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities.model_entities import ModelStatus from core.entities.provider_entities import QuotaUnit from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance, ModelManager -from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from core.model_runtime.entities.message_entities import ( +from core.model_runtime.entities import ( + AudioPromptMessageContent, ImagePromptMessageContent, PromptMessage, PromptMessageContentType, + TextPromptMessageContent, ) +from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig from core.prompt.utils.prompt_message_util import PromptMessageUtil -from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool 
+from core.variables import ArrayAnySegment, ArrayFileSegment, FileSegment +from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.event import InNodeEvent -from core.workflow.nodes.base_node import BaseNode -from core.workflow.nodes.event import RunCompletedEvent, RunEvent, RunRetrieverResourceEvent, RunStreamChunkEvent -from core.workflow.nodes.llm.entities import ( - LLMNodeChatModelMessage, - LLMNodeCompletionModelPromptTemplate, - LLMNodeData, - ModelConfig, +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.event import ( + ModelInvokeCompletedEvent, + NodeEvent, + RunCompletedEvent, + RunRetrieverResourceEvent, + RunStreamChunkEvent, ) from core.workflow.utils.variable_template_parser import VariableTemplateParser from extensions.ext_database import db @@ -41,44 +42,34 @@ from models.model import Conversation from models.provider import Provider, ProviderType from models.workflow import WorkflowNodeExecutionStatus +from .entities import ( + LLMNodeChatModelMessage, + LLMNodeCompletionModelPromptTemplate, + LLMNodeData, + ModelConfig, +) + if TYPE_CHECKING: - from core.file.file_obj import FileVar + from core.file.models import File -class ModelInvokeCompleted(BaseModel): - """ - Model invoke completed - """ - - text: str - usage: LLMUsage - finish_reason: Optional[str] = None - - -class LLMNode(BaseNode): +class LLMNode(BaseNode[LLMNodeData]): _node_data_cls = LLMNodeData _node_type = NodeType.LLM - def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]: - """ - Run node - :return: - """ - node_data = cast(LLMNodeData, deepcopy(self.node_data)) - variable_pool = self.graph_runtime_state.variable_pool - + def _run(self) -> NodeRunResult | Generator[NodeEvent | InNodeEvent, None, None]: node_inputs = None process_data = None try: # init messages template - node_data.prompt_template = self._transform_chat_messages(node_data.prompt_template) + self.node_data.prompt_template = self._transform_chat_messages(self.node_data.prompt_template) # fetch variables and fetch values from variable pool - inputs = self._fetch_inputs(node_data, variable_pool) + inputs = self._fetch_inputs(node_data=self.node_data) # fetch jinja2 inputs - jinja_inputs = self._fetch_jinja_inputs(node_data, variable_pool) + jinja_inputs = self._fetch_jinja_inputs(node_data=self.node_data) # merge inputs inputs.update(jinja_inputs) @@ -86,13 +77,17 @@ class LLMNode(BaseNode): node_inputs = {} # fetch files - files = self._fetch_files(node_data, variable_pool) + files = ( + self._fetch_files(selector=self.node_data.vision.configs.variable_selector) + if self.node_data.vision.enabled + else [] + ) if files: node_inputs["#files#"] = [file.to_dict() for file in files] # fetch context value - generator = self._fetch_context(node_data, variable_pool) + generator = self._fetch_context(node_data=self.node_data) context = None for event in generator: if isinstance(event, RunRetrieverResourceEvent): @@ -103,21 +98,30 @@ class LLMNode(BaseNode): node_inputs["#context#"] = context # type: ignore # fetch model config - model_instance, model_config = self._fetch_model_config(node_data.model) + model_instance, model_config = self._fetch_model_config(self.node_data.model) # fetch memory - memory = self._fetch_memory(node_data.memory, variable_pool, model_instance) + memory = 
self._fetch_memory(node_data_memory=self.node_data.memory, model_instance=model_instance) # fetch prompt messages + if self.node_data.memory: + query = self.graph_runtime_state.variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY)) + if not query: + raise ValueError("Query not found") + query = query.text + else: + query = None + prompt_messages, stop = self._fetch_prompt_messages( - node_data=node_data, - query=variable_pool.get_any(["sys", SystemVariableKey.QUERY.value]) if node_data.memory else None, - query_prompt_template=node_data.memory.query_prompt_template if node_data.memory else None, + system_query=query, inputs=inputs, files=files, context=context, memory=memory, model_config=model_config, + vision_detail=self.node_data.vision.configs.detail, + prompt_template=self.node_data.prompt_template, + memory_config=self.node_data.memory, ) process_data = { @@ -131,7 +135,7 @@ class LLMNode(BaseNode): # handle invoke result generator = self._invoke_llm( - node_data_model=node_data.model, + node_data_model=self.node_data.model, model_instance=model_instance, prompt_messages=prompt_messages, stop=stop, @@ -143,7 +147,7 @@ class LLMNode(BaseNode): for event in generator: if isinstance(event, RunStreamChunkEvent): yield event - elif isinstance(event, ModelInvokeCompleted): + elif isinstance(event, ModelInvokeCompletedEvent): result_text = event.text usage = event.usage finish_reason = event.finish_reason @@ -182,15 +186,7 @@ class LLMNode(BaseNode): model_instance: ModelInstance, prompt_messages: list[PromptMessage], stop: Optional[list[str]] = None, - ) -> Generator[RunEvent | ModelInvokeCompleted, None, None]: - """ - Invoke large language model - :param node_data_model: node data model - :param model_instance: model instance - :param prompt_messages: prompt messages - :param stop: stop - :return: - """ + ) -> Generator[NodeEvent, None, None]: db.session.close() invoke_result = model_instance.invoke_llm( @@ -207,20 +203,13 @@ class LLMNode(BaseNode): usage = LLMUsage.empty_usage() for event in generator: yield event - if isinstance(event, ModelInvokeCompleted): + if isinstance(event, ModelInvokeCompletedEvent): usage = event.usage # deduct quota self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage) - def _handle_invoke_result( - self, invoke_result: LLMResult | Generator - ) -> Generator[RunEvent | ModelInvokeCompleted, None, None]: - """ - Handle invoke result - :param invoke_result: invoke result - :return: - """ + def _handle_invoke_result(self, invoke_result: LLMResult | Generator) -> Generator[NodeEvent, None, None]: if isinstance(invoke_result, LLMResult): return @@ -250,18 +239,11 @@ class LLMNode(BaseNode): if not usage: usage = LLMUsage.empty_usage() - yield ModelInvokeCompleted(text=full_text, usage=usage, finish_reason=finish_reason) + yield ModelInvokeCompletedEvent(text=full_text, usage=usage, finish_reason=finish_reason) def _transform_chat_messages( - self, messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate - ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate: - """ - Transform chat messages - - :param messages: chat messages - :return: - """ - + self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, / + ) -> Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate: if isinstance(messages, LLMNodeCompletionModelPromptTemplate): if messages.edition_type == "jinja2" and messages.jinja2_text: messages.text = 
messages.jinja2_text @@ -274,13 +256,7 @@ class LLMNode(BaseNode): return messages - def _fetch_jinja_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]: - """ - Fetch jinja inputs - :param node_data: node data - :param variable_pool: variable pool - :return: - """ + def _fetch_jinja_inputs(self, node_data: LLMNodeData) -> dict[str, str]: variables = {} if not node_data.prompt_config: @@ -288,7 +264,7 @@ class LLMNode(BaseNode): for variable_selector in node_data.prompt_config.jinja2_variables or []: variable = variable_selector.variable - value = variable_pool.get_any(variable_selector.value_selector) + value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector) def parse_dict(d: dict) -> str: """ @@ -330,13 +306,7 @@ class LLMNode(BaseNode): return variables - def _fetch_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]: - """ - Fetch inputs - :param node_data: node data - :param variable_pool: variable pool - :return: - """ + def _fetch_inputs(self, node_data: LLMNodeData) -> dict[str, str]: inputs = {} prompt_template = node_data.prompt_template @@ -350,7 +320,7 @@ class LLMNode(BaseNode): variable_selectors = variable_template_parser.extract_variable_selectors() for variable_selector in variable_selectors: - variable_value = variable_pool.get_any(variable_selector.value_selector) + variable_value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector) if variable_value is None: raise ValueError(f"Variable {variable_selector.variable} not found") @@ -362,7 +332,7 @@ class LLMNode(BaseNode): template=memory.query_prompt_template ).extract_variable_selectors() for variable_selector in query_variable_selectors: - variable_value = variable_pool.get_any(variable_selector.value_selector) + variable_value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector) if variable_value is None: raise ValueError(f"Variable {variable_selector.variable} not found") @@ -370,36 +340,28 @@ class LLMNode(BaseNode): return inputs - def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list["FileVar"]: - """ - Fetch files - :param node_data: node data - :param variable_pool: variable pool - :return: - """ - if not node_data.vision.enabled: + def _fetch_files(self, *, selector: Sequence[str]) -> Sequence["File"]: + variable = self.graph_runtime_state.variable_pool.get(selector) + if variable is None: return [] - - files = variable_pool.get_any(["sys", SystemVariableKey.FILES.value]) - if not files: + if isinstance(variable, FileSegment): + return [variable.value] + if isinstance(variable, ArrayFileSegment): + return variable.value + # FIXME: Temporary fix for empty array, + # all variables added to variable pool should be a Segment instance. 
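The ArrayAnySegment check that follows covers the edge case the FIXME describes: an empty list written to the variable pool carries no element type, so it is presumably stored as an ArrayAnySegment rather than an ArrayFileSegment and would otherwise hit the ValueError below. A hedged illustration (pool.add is a hypothetical write helper; only get and remove appear in this diff):

    # hypothetical: a run where the user uploaded no files
    pool.add(("sys", "files"), [])
    segment = pool.get(("sys", "files"))
    assert isinstance(segment, ArrayAnySegment)  # not ArrayFileSegment
    assert segment.value == []                   # so _fetch_files returns []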
+ if isinstance(variable, ArrayAnySegment) and len(variable.value) == 0: return [] + raise ValueError(f"Invalid variable type: {type(variable)}") - return files - - def _fetch_context(self, node_data: LLMNodeData, variable_pool: VariablePool) -> Generator[RunEvent, None, None]: - """ - Fetch context - :param node_data: node data - :param variable_pool: variable pool - :return: - """ + def _fetch_context(self, node_data: LLMNodeData): if not node_data.context.enabled: return if not node_data.context.variable_selector: return - context_value = variable_pool.get_any(node_data.context.variable_selector) + context_value = self.graph_runtime_state.variable_pool.get_any(node_data.context.variable_selector) if context_value: if isinstance(context_value, str): yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value) @@ -424,11 +386,6 @@ class LLMNode(BaseNode): ) def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optional[dict]: - """ - Convert to original retriever resource, temp. - :param context_dict: context dict - :return: - """ if ( "metadata" in context_dict and "_source" in context_dict["metadata"] @@ -451,6 +408,7 @@ class LLMNode(BaseNode): "segment_position": metadata.get("segment_position"), "index_node_hash": metadata.get("segment_index_node_hash"), "content": context_dict.get("content"), + "page": metadata.get("page"), } return source @@ -460,11 +418,6 @@ class LLMNode(BaseNode): def _fetch_model_config( self, node_data_model: ModelConfig ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: - """ - Fetch model config - :param node_data_model: node data model - :return: - """ model_name = node_data_model.name provider_name = node_data_model.provider @@ -523,19 +476,15 @@ class LLMNode(BaseNode): ) def _fetch_memory( - self, node_data_memory: Optional[MemoryConfig], variable_pool: VariablePool, model_instance: ModelInstance + self, node_data_memory: Optional[MemoryConfig], model_instance: ModelInstance ) -> Optional[TokenBufferMemory]: - """ - Fetch memory - :param node_data_memory: node data memory - :param variable_pool: variable pool - :return: - """ if not node_data_memory: return None # get conversation id - conversation_id = variable_pool.get_any(["sys", SystemVariableKey.CONVERSATION_ID.value]) + conversation_id = self.graph_runtime_state.variable_pool.get_any( + ["sys", SystemVariableKey.CONVERSATION_ID.value] + ) if conversation_id is None: return None @@ -555,43 +504,31 @@ class LLMNode(BaseNode): def _fetch_prompt_messages( self, - node_data: LLMNodeData, - query: Optional[str], - query_prompt_template: Optional[str], - inputs: dict[str, str], - files: list["FileVar"], - context: Optional[str], - memory: Optional[TokenBufferMemory], + *, + system_query: str | None = None, + inputs: dict[str, str] | None = None, + files: Sequence["File"], + context: str | None = None, + memory: TokenBufferMemory | None = None, model_config: ModelConfigWithCredentialsEntity, + prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, + memory_config: MemoryConfig | None = None, + vision_detail: ImagePromptMessageContent.DETAIL, ) -> tuple[list[PromptMessage], Optional[list[str]]]: - """ - Fetch prompt messages - :param node_data: node data - :param query: query - :param query_prompt_template: query prompt template - :param inputs: inputs - :param files: files - :param context: context - :param memory: memory - :param model_config: model config - :return: - """ + inputs = inputs or {} + prompt_transform = 
AdvancedPromptTransform(with_variable_tmpl=True) prompt_messages = prompt_transform.get_prompt( - prompt_template=node_data.prompt_template, + prompt_template=prompt_template, inputs=inputs, - query=query or "", + query=system_query or "", files=files, context=context, - memory_config=node_data.memory, + memory_config=memory_config, memory=memory, model_config=model_config, - query_prompt_template=query_prompt_template, ) stop = model_config.stop - - vision_enabled = node_data.vision.enabled - vision_detail = node_data.vision.configs.detail if node_data.vision.configs else None filtered_prompt_messages = [] for prompt_message in prompt_messages: if prompt_message.is_empty(): @@ -599,17 +536,13 @@ class LLMNode(BaseNode): if not isinstance(prompt_message.content, str): prompt_message_content = [] - for content_item in prompt_message.content: - if ( - vision_enabled - and content_item.type == PromptMessageContentType.IMAGE - and isinstance(content_item, ImagePromptMessageContent) - ): - # Override vision config if LLM node has vision config - if vision_detail: - content_item.detail = ImagePromptMessageContent.DETAIL(vision_detail) + for content_item in prompt_message.content or []: + if isinstance(content_item, ImagePromptMessageContent): + # Override vision config if LLM node has vision config, + # cuz vision detail is related to the configuration from FileUpload feature. + content_item.detail = vision_detail prompt_message_content.append(content_item) - elif content_item.type == PromptMessageContentType.TEXT: + elif isinstance(content_item, TextPromptMessageContent | AudioPromptMessageContent): prompt_message_content.append(content_item) if len(prompt_message_content) > 1: @@ -631,13 +564,6 @@ class LLMNode(BaseNode): @classmethod def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None: - """ - Deduct LLM quota - :param tenant_id: tenant id - :param model_instance: model instance - :param usage: usage - :return: - """ provider_model_bundle = model_instance.provider_model_bundle provider_configuration = provider_model_bundle.configuration @@ -668,7 +594,7 @@ class LLMNode(BaseNode): else: used_quota = 1 - if used_quota is not None: + if used_quota is not None and system_configuration.current_quota_type is not None: db.session.query(Provider).filter( Provider.tenant_id == tenant_id, Provider.provider_name == model_instance.provider, @@ -680,27 +606,28 @@ class LLMNode(BaseNode): @classmethod def _extract_variable_selector_to_variable_mapping( - cls, graph_config: Mapping[str, Any], node_id: str, node_data: LLMNodeData + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: LLMNodeData, ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping - :param graph_config: graph config - :param node_id: node id - :param node_data: node data - :return: - """ prompt_template = node_data.prompt_template variable_selectors = [] - if isinstance(prompt_template, list): + if isinstance(prompt_template, list) and all( + isinstance(prompt, LLMNodeChatModelMessage) for prompt in prompt_template + ): for prompt in prompt_template: if prompt.edition_type != "jinja2": variable_template_parser = VariableTemplateParser(template=prompt.text) variable_selectors.extend(variable_template_parser.extract_variable_selectors()) - else: + elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate): if prompt_template.edition_type != "jinja2": variable_template_parser = VariableTemplateParser(template=prompt_template.text) 
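An aside that applies to this method and its counterparts throughout the diff: _extract_variable_selector_to_variable_mapping becomes keyword-only (the bare * added in the if_else, iteration, knowledge_retrieval, llm, and parameter_extractor nodes), so positional calls now raise TypeError. A sketch of a conforming call (illustrative names, not taken from the diff):

    mapping = LLMNode._extract_variable_selector_to_variable_mapping(
        graph_config=graph_config,
        node_id="llm_node_1",
        node_data=node_data,
    )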
variable_selectors = variable_template_parser.extract_variable_selectors() + else: + raise ValueError(f"Invalid prompt template type: {type(prompt_template)}") variable_mapping = {} for variable_selector in variable_selectors: @@ -745,11 +672,6 @@ class LLMNode(BaseNode): @classmethod def get_default_config(cls, filters: Optional[dict] = None) -> dict: - """ - Get default config of node. - :param filters: filter by node config parameters. - :return: - """ return { "type": "llm", "config": { diff --git a/api/core/workflow/nodes/loop/entities.py b/api/core/workflow/nodes/loop/entities.py index a8a0debe64..b7cd7a948e 100644 --- a/api/core/workflow/nodes/loop/entities.py +++ b/api/core/workflow/nodes/loop/entities.py @@ -1,4 +1,4 @@ -from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState +from core.workflow.nodes.base import BaseIterationNodeData, BaseIterationState class LoopNodeData(BaseIterationNodeData): diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index fbc68b79cb..6fdff96602 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -1,12 +1,12 @@ from typing import Any -from core.workflow.entities.node_entities import NodeType -from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.enums import NodeType from core.workflow.nodes.loop.entities import LoopNodeData, LoopState from core.workflow.utils.condition.entities import Condition -class LoopNode(BaseNode): +class LoopNode(BaseNode[LoopNodeData]): """ Loop Node. """ diff --git a/api/core/workflow/nodes/node_mapping.py b/api/core/workflow/nodes/node_mapping.py index b98525e86e..c13b5ff76f 100644 --- a/api/core/workflow/nodes/node_mapping.py +++ b/api/core/workflow/nodes/node_mapping.py @@ -1,22 +1,24 @@ -from core.workflow.entities.node_entities import NodeType -from core.workflow.nodes.answer.answer_node import AnswerNode -from core.workflow.nodes.code.code_node import CodeNode -from core.workflow.nodes.end.end_node import EndNode -from core.workflow.nodes.http_request.http_request_node import HttpRequestNode -from core.workflow.nodes.if_else.if_else_node import IfElseNode -from core.workflow.nodes.iteration.iteration_node import IterationNode -from core.workflow.nodes.iteration.iteration_start_node import IterationStartNode -from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode -from core.workflow.nodes.llm.llm_node import LLMNode -from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode -from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode -from core.workflow.nodes.start.start_node import StartNode -from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode -from core.workflow.nodes.tool.tool_node import ToolNode -from core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode +from core.workflow.nodes.answer import AnswerNode +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.code import CodeNode +from core.workflow.nodes.document_extractor import DocumentExtractorNode +from core.workflow.nodes.end import EndNode +from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.http_request import HttpRequestNode +from core.workflow.nodes.if_else import IfElseNode +from 
core.workflow.nodes.iteration import IterationNode, IterationStartNode +from core.workflow.nodes.knowledge_retrieval import KnowledgeRetrievalNode +from core.workflow.nodes.list_operator import ListOperatorNode +from core.workflow.nodes.llm import LLMNode +from core.workflow.nodes.parameter_extractor import ParameterExtractorNode +from core.workflow.nodes.question_classifier import QuestionClassifierNode +from core.workflow.nodes.start import StartNode +from core.workflow.nodes.template_transform import TemplateTransformNode +from core.workflow.nodes.tool import ToolNode +from core.workflow.nodes.variable_aggregator import VariableAggregatorNode from core.workflow.nodes.variable_assigner import VariableAssignerNode -node_classes = { +node_type_classes_mapping: dict[NodeType, type[BaseNode]] = { NodeType.START: StartNode, NodeType.END: EndNode, NodeType.ANSWER: AnswerNode, @@ -34,4 +36,6 @@ node_classes = { NodeType.ITERATION_START: IterationStartNode, NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode, NodeType.CONVERSATION_VARIABLE_ASSIGNER: VariableAssignerNode, + NodeType.DOCUMENT_EXTRACTOR: DocumentExtractorNode, + NodeType.LIST_OPERATOR: ListOperatorNode, } diff --git a/api/core/workflow/nodes/parameter_extractor/__init__.py b/api/core/workflow/nodes/parameter_extractor/__init__.py index e69de29bb2..bdbf19a7d3 100644 --- a/api/core/workflow/nodes/parameter_extractor/__init__.py +++ b/api/core/workflow/nodes/parameter_extractor/__init__.py @@ -0,0 +1,3 @@ +from .parameter_extractor_node import ParameterExtractorNode + +__all__ = ["ParameterExtractorNode"] diff --git a/api/core/workflow/nodes/parameter_extractor/entities.py b/api/core/workflow/nodes/parameter_extractor/entities.py index 5697d7c049..a001b44dc7 100644 --- a/api/core/workflow/nodes/parameter_extractor/entities.py +++ b/api/core/workflow/nodes/parameter_extractor/entities.py @@ -1,20 +1,10 @@ from typing import Any, Literal, Optional -from pydantic import BaseModel, field_validator +from pydantic import BaseModel, Field, field_validator from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from core.workflow.entities.base_node_data_entities import BaseNodeData - - -class ModelConfig(BaseModel): - """ - Model Config. 
- """ - - provider: str - name: str - mode: str - completion_params: dict[str, Any] = {} +from core.workflow.nodes.base import BaseNodeData +from core.workflow.nodes.llm import ModelConfig, VisionConfig class ParameterConfig(BaseModel): @@ -49,6 +39,7 @@ class ParameterExtractorNodeData(BaseNodeData): instruction: Optional[str] = None memory: Optional[MemoryConfig] = None reasoning_mode: Literal["function_call", "prompt"] + vision: VisionConfig = Field(default_factory=VisionConfig) @field_validator("reasoning_mode", mode="before") @classmethod @@ -64,7 +55,7 @@ class ParameterExtractorNodeData(BaseNodeData): parameters = {"type": "object", "properties": {}, "required": []} for parameter in self.parameters: - parameter_schema = {"description": parameter.description} + parameter_schema: dict[str, Any] = {"description": parameter.description} if parameter.type in {"string", "select"}: parameter_schema["type"] = "string" diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index a6454bd1cd..49546e9356 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -4,6 +4,7 @@ from collections.abc import Mapping, Sequence from typing import Any, Optional, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.file import File from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage @@ -22,12 +23,16 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate from core.prompt.simple_prompt_transform import ModelMode from core.prompt.utils.prompt_message_util import PromptMessageUtil -from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.llm.entities import ModelConfig -from core.workflow.nodes.llm.llm_node import LLMNode -from core.workflow.nodes.parameter_extractor.entities import ParameterExtractorNodeData -from core.workflow.nodes.parameter_extractor.prompts import ( +from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.llm import LLMNode, ModelConfig +from core.workflow.utils import variable_template_parser +from extensions.ext_database import db +from models.workflow import WorkflowNodeExecutionStatus + +from .entities import ParameterExtractorNodeData +from .prompts import ( CHAT_EXAMPLE, CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE, COMPLETION_GENERATE_JSON_PROMPT, @@ -36,9 +41,6 @@ from core.workflow.nodes.parameter_extractor.prompts import ( FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT, FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE, ) -from core.workflow.utils.variable_template_parser import VariableTemplateParser -from extensions.ext_database import db -from models.workflow import WorkflowNodeExecutionStatus class ParameterExtractorNode(LLMNode): @@ -65,33 +67,39 @@ class ParameterExtractorNode(LLMNode): } } - def _run(self) -> NodeRunResult: + def _run(self): """ Run the node. 
""" node_data = cast(ParameterExtractorNodeData, self.node_data) - variable = self.graph_runtime_state.variable_pool.get_any(node_data.query) - if not variable: - raise ValueError("Input variable content not found or is empty") - query = variable + variable = self.graph_runtime_state.variable_pool.get(node_data.query) + query = variable.text if variable else "" - inputs = { - "query": query, - "parameters": jsonable_encoder(node_data.parameters), - "instruction": jsonable_encoder(node_data.instruction), - } + files = ( + self._fetch_files( + selector=node_data.vision.configs.variable_selector, + ) + if node_data.vision.enabled + else [] + ) model_instance, model_config = self._fetch_model_config(node_data.model) if not isinstance(model_instance.model_type_instance, LargeLanguageModel): raise ValueError("Model is not a Large Language Model") llm_model = model_instance.model_type_instance - model_schema = llm_model.get_model_schema(model_config.model, model_config.credentials) + model_schema = llm_model.get_model_schema( + model=model_config.model, + credentials=model_config.credentials, + ) if not model_schema: raise ValueError("Model schema not found") # fetch memory - memory = self._fetch_memory(node_data.memory, self.graph_runtime_state.variable_pool, model_instance) + memory = self._fetch_memory( + node_data_memory=node_data.memory, + model_instance=model_instance, + ) if ( set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL} @@ -99,15 +107,33 @@ class ParameterExtractorNode(LLMNode): ): # use function call prompt_messages, prompt_message_tools = self._generate_function_call_prompt( - node_data, query, self.graph_runtime_state.variable_pool, model_config, memory + node_data=node_data, + query=query, + variable_pool=self.graph_runtime_state.variable_pool, + model_config=model_config, + memory=memory, + files=files, ) else: # use prompt engineering prompt_messages = self._generate_prompt_engineering_prompt( - node_data, query, self.graph_runtime_state.variable_pool, model_config, memory + data=node_data, + query=query, + variable_pool=self.graph_runtime_state.variable_pool, + model_config=model_config, + memory=memory, + files=files, ) + prompt_message_tools = [] + inputs = { + "query": query, + "files": [f.to_dict() for f in files], + "parameters": jsonable_encoder(node_data.parameters), + "instruction": jsonable_encoder(node_data.instruction), + } + process_data = { "model_mode": model_config.mode, "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( @@ -119,7 +145,7 @@ class ParameterExtractorNode(LLMNode): } try: - text, usage, tool_call = self._invoke_llm( + text, usage, tool_call = self._invoke( node_data_model=node_data.model, model_instance=model_instance, prompt_messages=prompt_messages, @@ -150,12 +176,12 @@ class ParameterExtractorNode(LLMNode): error = "Failed to extract result from function call or text response, using empty result." 
try: - result = self._validate_result(node_data, result) + result = self._validate_result(data=node_data, result=result or {}) except Exception as e: error = str(e) # transform result into standard format - result = self._transform_result(node_data, result) + result = self._transform_result(data=node_data, result=result or {}) return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, @@ -170,7 +196,7 @@ class ParameterExtractorNode(LLMNode): llm_usage=usage, ) - def _invoke_llm( + def _invoke( self, node_data_model: ModelConfig, model_instance: ModelInstance, @@ -178,14 +204,6 @@ class ParameterExtractorNode(LLMNode): tools: list[PromptMessageTool], stop: list[str], ) -> tuple[str, LLMUsage, Optional[AssistantPromptMessage.ToolCall]]: - """ - Invoke large language model - :param node_data_model: node data model - :param model_instance: model instance - :param prompt_messages: prompt messages - :param stop: stop - :return: - """ db.session.close() invoke_result = model_instance.invoke_llm( @@ -202,6 +220,9 @@ class ParameterExtractorNode(LLMNode): raise ValueError(f"Invalid invoke result: {invoke_result}") text = invoke_result.message.content + if not isinstance(text, str): + raise ValueError(f"Invalid text content type: {type(text)}. Expected str.") + usage = invoke_result.usage tool_call = invoke_result.message.tool_calls[0] if invoke_result.message.tool_calls else None @@ -217,6 +238,7 @@ class ParameterExtractorNode(LLMNode): variable_pool: VariablePool, model_config: ModelConfigWithCredentialsEntity, memory: Optional[TokenBufferMemory], + files: Sequence[File], ) -> tuple[list[PromptMessage], list[PromptMessageTool]]: """ Generate function call prompt. @@ -234,7 +256,7 @@ class ParameterExtractorNode(LLMNode): prompt_template=prompt_template, inputs={}, query="", - files=[], + files=files, context="", memory_config=node_data.memory, memory=None, @@ -296,6 +318,7 @@ class ParameterExtractorNode(LLMNode): variable_pool: VariablePool, model_config: ModelConfigWithCredentialsEntity, memory: Optional[TokenBufferMemory], + files: Sequence[File], ) -> list[PromptMessage]: """ Generate prompt engineering prompt. @@ -303,9 +326,23 @@ class ParameterExtractorNode(LLMNode): model_mode = ModelMode.value_of(data.model.mode) if model_mode == ModelMode.COMPLETION: - return self._generate_prompt_engineering_completion_prompt(data, query, variable_pool, model_config, memory) + return self._generate_prompt_engineering_completion_prompt( + node_data=data, + query=query, + variable_pool=variable_pool, + model_config=model_config, + memory=memory, + files=files, + ) elif model_mode == ModelMode.CHAT: - return self._generate_prompt_engineering_chat_prompt(data, query, variable_pool, model_config, memory) + return self._generate_prompt_engineering_chat_prompt( + node_data=data, + query=query, + variable_pool=variable_pool, + model_config=model_config, + memory=memory, + files=files, + ) else: raise ValueError(f"Invalid model mode: {model_mode}") @@ -316,20 +353,23 @@ class ParameterExtractorNode(LLMNode): variable_pool: VariablePool, model_config: ModelConfigWithCredentialsEntity, memory: Optional[TokenBufferMemory], + files: Sequence[File], ) -> list[PromptMessage]: """ Generate completion prompt. 
""" prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) - rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, "") + rest_token = self._calculate_rest_token( + node_data=node_data, query=query, variable_pool=variable_pool, model_config=model_config, context="" + ) prompt_template = self._get_prompt_engineering_prompt_template( - node_data, query, variable_pool, memory, rest_token + node_data=node_data, query=query, variable_pool=variable_pool, memory=memory, max_token_limit=rest_token ) prompt_messages = prompt_transform.get_prompt( prompt_template=prompt_template, inputs={"structure": json.dumps(node_data.get_parameter_json_schema())}, query="", - files=[], + files=files, context="", memory_config=node_data.memory, memory=memory, @@ -345,27 +385,30 @@ class ParameterExtractorNode(LLMNode): variable_pool: VariablePool, model_config: ModelConfigWithCredentialsEntity, memory: Optional[TokenBufferMemory], + files: Sequence[File], ) -> list[PromptMessage]: """ Generate chat prompt. """ prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) - rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, "") + rest_token = self._calculate_rest_token( + node_data=node_data, query=query, variable_pool=variable_pool, model_config=model_config, context="" + ) prompt_template = self._get_prompt_engineering_prompt_template( - node_data, - CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format( + node_data=node_data, + query=CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format( structure=json.dumps(node_data.get_parameter_json_schema()), text=query ), - variable_pool, - memory, - rest_token, + variable_pool=variable_pool, + memory=memory, + max_token_limit=rest_token, ) prompt_messages = prompt_transform.get_prompt( prompt_template=prompt_template, inputs={}, query="", - files=[], + files=files, context="", memory_config=node_data.memory, memory=None, @@ -425,10 +468,11 @@ class ParameterExtractorNode(LLMNode): raise ValueError(f"Invalid `string` value for parameter {parameter.name}") if parameter.type.startswith("array"): - if not isinstance(result.get(parameter.name), list): + parameters = result.get(parameter.name) + if not isinstance(parameters, list): raise ValueError(f"Invalid `array` value for parameter {parameter.name}") nested_type = parameter.type[6:-1] - for item in result.get(parameter.name): + for item in parameters: if nested_type == "number" and not isinstance(item, int | float): raise ValueError(f"Invalid `array[number]` value for parameter {parameter.name}") if nested_type == "string" and not isinstance(item, str): @@ -565,18 +609,6 @@ class ParameterExtractorNode(LLMNode): return result - def _render_instruction(self, instruction: str, variable_pool: VariablePool) -> str: - """ - Render instruction. 
- """ - variable_template_parser = VariableTemplateParser(instruction) - inputs = {} - for selector in variable_template_parser.extract_variable_selectors(): - variable = variable_pool.get_any(selector.value_selector) - inputs[selector.variable] = variable - - return variable_template_parser.format(inputs) - def _get_function_calling_prompt_template( self, node_data: ParameterExtractorNodeData, @@ -588,9 +620,9 @@ class ParameterExtractorNode(LLMNode): model_mode = ModelMode.value_of(node_data.model.mode) input_text = query memory_str = "" - instruction = self._render_instruction(node_data.instruction or "", variable_pool) + instruction = variable_pool.convert_template(node_data.instruction or "").text - if memory: + if memory and node_data.memory and node_data.memory.window: memory_str = memory.get_history_prompt_text( max_token_limit=max_token_limit, message_limit=node_data.memory.window.size ) @@ -611,13 +643,13 @@ class ParameterExtractorNode(LLMNode): variable_pool: VariablePool, memory: Optional[TokenBufferMemory], max_token_limit: int = 2000, - ) -> list[ChatModelMessage]: + ): model_mode = ModelMode.value_of(node_data.model.mode) input_text = query memory_str = "" - instruction = self._render_instruction(node_data.instruction or "", variable_pool) + instruction = variable_pool.convert_template(node_data.instruction or "").text - if memory: + if memory and node_data.memory and node_data.memory.window: memory_str = memory.get_history_prompt_text( max_token_limit=max_token_limit, message_limit=node_data.memory.window.size ) @@ -691,7 +723,7 @@ class ParameterExtractorNode(LLMNode): ): max_tokens = ( model_config.parameters.get(parameter_rule.name) - or model_config.parameters.get(parameter_rule.use_template) + or model_config.parameters.get(parameter_rule.use_template or "") ) or 0 rest_tokens = model_context_tokens - max_tokens - curr_message_tokens @@ -712,7 +744,11 @@ class ParameterExtractorNode(LLMNode): @classmethod def _extract_variable_selector_to_variable_mapping( - cls, graph_config: Mapping[str, Any], node_id: str, node_data: ParameterExtractorNodeData + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: ParameterExtractorNodeData, ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -721,11 +757,11 @@ class ParameterExtractorNode(LLMNode): :param node_data: node data :return: """ - variable_mapping = {"query": node_data.query} + variable_mapping: dict[str, Sequence[str]] = {"query": node_data.query} if node_data.instruction: - variable_template_parser = VariableTemplateParser(template=node_data.instruction) - for selector in variable_template_parser.extract_variable_selectors(): + selectors = variable_template_parser.extract_selectors_from_template(node_data.instruction) + for selector in selectors: variable_mapping[selector.variable] = selector.value_selector variable_mapping = {node_id + "." 
+ key: value for key, value in variable_mapping.items()} diff --git a/api/core/workflow/nodes/question_classifier/__init__.py b/api/core/workflow/nodes/question_classifier/__init__.py index e69de29bb2..70414c4199 100644 --- a/api/core/workflow/nodes/question_classifier/__init__.py +++ b/api/core/workflow/nodes/question_classifier/__init__.py @@ -0,0 +1,4 @@ +from .entities import QuestionClassifierNodeData +from .question_classifier_node import QuestionClassifierNode + +__all__ = ["QuestionClassifierNodeData", "QuestionClassifierNode"] diff --git a/api/core/workflow/nodes/question_classifier/entities.py b/api/core/workflow/nodes/question_classifier/entities.py index 40f7ce7582..5219f11d26 100644 --- a/api/core/workflow/nodes/question_classifier/entities.py +++ b/api/core/workflow/nodes/question_classifier/entities.py @@ -1,39 +1,21 @@ -from typing import Any, Optional +from typing import Optional -from pydantic import BaseModel +from pydantic import BaseModel, Field from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from core.workflow.entities.base_node_data_entities import BaseNodeData - - -class ModelConfig(BaseModel): - """ - Model Config. - """ - - provider: str - name: str - mode: str - completion_params: dict[str, Any] = {} +from core.workflow.nodes.base import BaseNodeData +from core.workflow.nodes.llm import ModelConfig, VisionConfig class ClassConfig(BaseModel): - """ - Class Config. - """ - id: str name: str class QuestionClassifierNodeData(BaseNodeData): - """ - Knowledge retrieval Node Data. - """ - query_variable_selector: list[str] - type: str = "question-classifier" model: ModelConfig classes: list[ClassConfig] instruction: Optional[str] = None memory: Optional[MemoryConfig] = None + vision: VisionConfig = Field(default_factory=VisionConfig) diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index 2ae58bc5f7..e6af453dcf 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -1,25 +1,30 @@ import json import logging from collections.abc import Mapping, Sequence -from typing import Any, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Optional, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance -from core.model_runtime.entities.llm_entities import LLMUsage -from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole -from core.model_runtime.entities.model_entities import ModelPropertyKey +from core.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.advanced_prompt_transform import AdvancedPromptTransform -from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate from core.prompt.simple_prompt_transform import ModelMode from core.prompt.utils.prompt_message_util import PromptMessageUtil -from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.llm.llm_node import LLMNode, ModelInvokeCompleted -from 
core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData -from core.workflow.nodes.question_classifier.template_prompts import ( +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult +from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.event import ModelInvokeCompletedEvent +from core.workflow.nodes.llm import ( + LLMNode, + LLMNodeChatModelMessage, + LLMNodeCompletionModelPromptTemplate, +) +from core.workflow.utils.variable_template_parser import VariableTemplateParser +from libs.json_in_md_parser import parse_and_check_json_markdown +from models.workflow import WorkflowNodeExecutionStatus + +from .entities import QuestionClassifierNodeData +from .template_prompts import ( QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1, QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2, QUESTION_CLASSIFIER_COMPLETION_PROMPT, @@ -28,46 +33,77 @@ from core.workflow.nodes.question_classifier.template_prompts import ( QUESTION_CLASSIFIER_USER_PROMPT_2, QUESTION_CLASSIFIER_USER_PROMPT_3, ) -from core.workflow.utils.variable_template_parser import VariableTemplateParser -from libs.json_in_md_parser import parse_and_check_json_markdown -from models.workflow import WorkflowNodeExecutionStatus + +if TYPE_CHECKING: + from core.file import File class QuestionClassifierNode(LLMNode): _node_data_cls = QuestionClassifierNodeData - node_type = NodeType.QUESTION_CLASSIFIER + _node_type = NodeType.QUESTION_CLASSIFIER - def _run(self) -> NodeRunResult: - node_data: QuestionClassifierNodeData = cast(self._node_data_cls, self.node_data) - node_data = cast(QuestionClassifierNodeData, node_data) + def _run(self): + node_data = cast(QuestionClassifierNodeData, self.node_data) variable_pool = self.graph_runtime_state.variable_pool # extract variables - variable = variable_pool.get(node_data.query_variable_selector) + variable = variable_pool.get(node_data.query_variable_selector) if node_data.query_variable_selector else None query = variable.value if variable else None variables = {"query": query} # fetch model config model_instance, model_config = self._fetch_model_config(node_data.model) # fetch memory - memory = self._fetch_memory(node_data.memory, variable_pool, model_instance) + memory = self._fetch_memory( + node_data_memory=node_data.memory, + model_instance=model_instance, + ) # fetch instruction - instruction = self._format_instruction(node_data.instruction, variable_pool) if node_data.instruction else "" - node_data.instruction = instruction + node_data.instruction = node_data.instruction or "" + node_data.instruction = variable_pool.convert_template(node_data.instruction).text + + files: Sequence[File] = ( + self._fetch_files( + selector=node_data.vision.configs.variable_selector, + ) + if node_data.vision.enabled + else [] + ) + # fetch prompt messages - prompt_messages, stop = self._fetch_prompt( - node_data=node_data, context="", query=query, memory=memory, model_config=model_config + rest_token = self._calculate_rest_token( + node_data=node_data, + query=query or "", + model_config=model_config, + context="", + ) + prompt_template = self._get_prompt_template( + node_data=node_data, + query=query or "", + memory=memory, + max_token_limit=rest_token, + ) + prompt_messages, stop = self._fetch_prompt_messages( + prompt_template=prompt_template, + system_query=query, + memory=memory, + model_config=model_config, + files=files, + vision_detail=node_data.vision.configs.detail, ) # handle invoke result generator = self._invoke_llm( - 
node_data_model=node_data.model, model_instance=model_instance, prompt_messages=prompt_messages, stop=stop + node_data_model=node_data.model, + model_instance=model_instance, + prompt_messages=prompt_messages, + stop=stop, ) result_text = "" usage = LLMUsage.empty_usage() finish_reason = None for event in generator: - if isinstance(event, ModelInvokeCompleted): + if isinstance(event, ModelInvokeCompletedEvent): result_text = event.text usage = event.usage finish_reason = event.finish_reason @@ -129,7 +165,11 @@ class QuestionClassifierNode(LLMNode): @classmethod def _extract_variable_selector_to_variable_mapping( - cls, graph_config: Mapping[str, Any], node_id: str, node_data: QuestionClassifierNodeData + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: QuestionClassifierNodeData, ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -159,40 +199,6 @@ class QuestionClassifierNode(LLMNode): """ return {"type": "question-classifier", "config": {"instructions": ""}} - def _fetch_prompt( - self, - node_data: QuestionClassifierNodeData, - query: str, - context: Optional[str], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigWithCredentialsEntity, - ) -> tuple[list[PromptMessage], Optional[list[str]]]: - """ - Fetch prompt - :param node_data: node data - :param query: inputs - :param context: context - :param memory: memory - :param model_config: model config - :return: - """ - prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) - rest_token = self._calculate_rest_token(node_data, query, model_config, context) - prompt_template = self._get_prompt_template(node_data, query, memory, rest_token) - prompt_messages = prompt_transform.get_prompt( - prompt_template=prompt_template, - inputs={}, - query="", - files=[], - context=context, - memory_config=node_data.memory, - memory=None, - model_config=model_config, - ) - stop = model_config.stop - - return prompt_messages, stop - def _calculate_rest_token( self, node_data: QuestionClassifierNodeData, @@ -229,7 +235,7 @@ class QuestionClassifierNode(LLMNode): ): max_tokens = ( model_config.parameters.get(parameter_rule.name) - or model_config.parameters.get(parameter_rule.use_template) + or model_config.parameters.get(parameter_rule.use_template or "") ) or 0 rest_tokens = model_context_tokens - max_tokens - curr_message_tokens @@ -243,7 +249,7 @@ class QuestionClassifierNode(LLMNode): query: str, memory: Optional[TokenBufferMemory], max_token_limit: int = 2000, - ) -> Union[list[ChatModelMessage], CompletionModelPromptTemplate]: + ): model_mode = ModelMode.value_of(node_data.model.mode) classes = node_data.classes categories = [] @@ -255,31 +261,32 @@ class QuestionClassifierNode(LLMNode): memory_str = "" if memory: memory_str = memory.get_history_prompt_text( - max_token_limit=max_token_limit, message_limit=node_data.memory.window.size + max_token_limit=max_token_limit, + message_limit=node_data.memory.window.size if node_data.memory and node_data.memory.window else None, ) - prompt_messages = [] + prompt_messages: list[LLMNodeChatModelMessage] = [] if model_mode == ModelMode.CHAT: - system_prompt_messages = ChatModelMessage( + system_prompt_messages = LLMNodeChatModelMessage( role=PromptMessageRole.SYSTEM, text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(histories=memory_str) ) prompt_messages.append(system_prompt_messages) - user_prompt_message_1 = ChatModelMessage( + user_prompt_message_1 = LLMNodeChatModelMessage( role=PromptMessageRole.USER, 
text=QUESTION_CLASSIFIER_USER_PROMPT_1 ) prompt_messages.append(user_prompt_message_1) - assistant_prompt_message_1 = ChatModelMessage( + assistant_prompt_message_1 = LLMNodeChatModelMessage( role=PromptMessageRole.ASSISTANT, text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1 ) prompt_messages.append(assistant_prompt_message_1) - user_prompt_message_2 = ChatModelMessage( + user_prompt_message_2 = LLMNodeChatModelMessage( role=PromptMessageRole.USER, text=QUESTION_CLASSIFIER_USER_PROMPT_2 ) prompt_messages.append(user_prompt_message_2) - assistant_prompt_message_2 = ChatModelMessage( + assistant_prompt_message_2 = LLMNodeChatModelMessage( role=PromptMessageRole.ASSISTANT, text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2 ) prompt_messages.append(assistant_prompt_message_2) - user_prompt_message_3 = ChatModelMessage( + user_prompt_message_3 = LLMNodeChatModelMessage( role=PromptMessageRole.USER, text=QUESTION_CLASSIFIER_USER_PROMPT_3.format( input_text=input_text, @@ -290,7 +297,7 @@ class QuestionClassifierNode(LLMNode): prompt_messages.append(user_prompt_message_3) return prompt_messages elif model_mode == ModelMode.COMPLETION: - return CompletionModelPromptTemplate( + return LLMNodeCompletionModelPromptTemplate( text=QUESTION_CLASSIFIER_COMPLETION_PROMPT.format( histories=memory_str, input_text=input_text, @@ -302,23 +309,3 @@ class QuestionClassifierNode(LLMNode): else: raise ValueError(f"Model mode {model_mode} not support.") - - def _format_instruction(self, instruction: str, variable_pool: VariablePool) -> str: - inputs = {} - - variable_selectors = [] - variable_template_parser = VariableTemplateParser(template=instruction) - variable_selectors.extend(variable_template_parser.extract_variable_selectors()) - for variable_selector in variable_selectors: - variable = variable_pool.get(variable_selector.value_selector) - variable_value = variable.value if variable else None - if variable_value is None: - raise ValueError(f"Variable {variable_selector.variable} not found") - - inputs[variable_selector.variable] = variable_value - - prompt_template = PromptTemplateParser(template=instruction, with_variable_tmpl=True) - prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} - - instruction = prompt_template.format(prompt_inputs) - return instruction diff --git a/api/core/workflow/nodes/start/__init__.py b/api/core/workflow/nodes/start/__init__.py index e69de29bb2..5411780423 100644 --- a/api/core/workflow/nodes/start/__init__.py +++ b/api/core/workflow/nodes/start/__init__.py @@ -0,0 +1,3 @@ +from .start_node import StartNode + +__all__ = ["StartNode"] diff --git a/api/core/workflow/nodes/start/entities.py b/api/core/workflow/nodes/start/entities.py index 11d2ebe5dd..594d1b7bab 100644 --- a/api/core/workflow/nodes/start/entities.py +++ b/api/core/workflow/nodes/start/entities.py @@ -3,7 +3,7 @@ from collections.abc import Sequence from pydantic import Field from core.app.app_config.entities import VariableEntity -from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.nodes.base import BaseNodeData class StartNodeData(BaseNodeData): diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index 96c887c58d..a7b91e82bb 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -1,25 +1,24 @@ from collections.abc import Mapping, Sequence from typing import Any -from core.workflow.entities.node_entities import NodeRunResult, NodeType -from 
core.workflow.entities.variable_pool import SYSTEM_VARIABLE_NODE_ID -from core.workflow.nodes.base_node import BaseNode +from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.enums import NodeType from core.workflow.nodes.start.entities import StartNodeData from models.workflow import WorkflowNodeExecutionStatus -class StartNode(BaseNode): +class StartNode(BaseNode[StartNodeData]): _node_data_cls = StartNodeData _node_type = NodeType.START def _run(self) -> NodeRunResult: - """ - Run node - :return: - """ node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs) system_inputs = self.graph_runtime_state.variable_pool.system_variables + # TODO: System variables should be directly accessible, no need for special handling + # Set system variables as node outputs. for var in system_inputs: node_inputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var] @@ -27,13 +26,10 @@ class StartNode(BaseNode): @classmethod def _extract_variable_selector_to_variable_mapping( - cls, graph_config: Mapping[str, Any], node_id: str, node_data: StartNodeData + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: StartNodeData, ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping - :param graph_config: graph config - :param node_id: node id - :param node_data: node data - :return: - """ return {} diff --git a/api/core/workflow/nodes/template_transform/__init__.py b/api/core/workflow/nodes/template_transform/__init__.py index e69de29bb2..43863b9d59 100644 --- a/api/core/workflow/nodes/template_transform/__init__.py +++ b/api/core/workflow/nodes/template_transform/__init__.py @@ -0,0 +1,3 @@ +from .template_transform_node import TemplateTransformNode + +__all__ = ["TemplateTransformNode"] diff --git a/api/core/workflow/nodes/template_transform/entities.py b/api/core/workflow/nodes/template_transform/entities.py index e934d69fa3..96adff6ffa 100644 --- a/api/core/workflow/nodes/template_transform/entities.py +++ b/api/core/workflow/nodes/template_transform/entities.py @@ -1,5 +1,5 @@ -from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.variable_entities import VariableSelector +from core.workflow.nodes.base import BaseNodeData class TemplateTransformNodeData(BaseNodeData): diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py index 32c99e0d1c..857a693c5b 100644 --- a/api/core/workflow/nodes/template_transform/template_transform_node.py +++ b/api/core/workflow/nodes/template_transform/template_transform_node.py @@ -1,17 +1,18 @@ import os from collections.abc import Mapping, Sequence -from typing import Any, Optional, cast +from typing import Any, Optional from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage -from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.nodes.base_node import BaseNode +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.enums import NodeType from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData from models.workflow import WorkflowNodeExecutionStatus MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = 
int(os.environ.get("TEMPLATE_TRANSFORM_MAX_LENGTH", "80000")) -class TemplateTransformNode(BaseNode): +class TemplateTransformNode(BaseNode[TemplateTransformNodeData]): _node_data_cls = TemplateTransformNodeData _node_type = NodeType.TEMPLATE_TRANSFORM @@ -28,22 +29,16 @@ class TemplateTransformNode(BaseNode): } def _run(self) -> NodeRunResult: - """ - Run node - """ - node_data = self.node_data - node_data: TemplateTransformNodeData = cast(self._node_data_cls, node_data) - # Get variables variables = {} - for variable_selector in node_data.variables: + for variable_selector in self.node_data.variables: variable_name = variable_selector.variable value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector) variables[variable_name] = value # Run code try: result = CodeExecutor.execute_workflow_code_template( - language=CodeLanguage.JINJA2, code=node_data.template, inputs=variables + language=CodeLanguage.JINJA2, code=self.node_data.template, inputs=variables ) except CodeExecutionError as e: return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e)) diff --git a/api/core/workflow/nodes/tool/__init__.py b/api/core/workflow/nodes/tool/__init__.py index e69de29bb2..f4982e655d 100644 --- a/api/core/workflow/nodes/tool/__init__.py +++ b/api/core/workflow/nodes/tool/__init__.py @@ -0,0 +1,3 @@ +from .tool_node import ToolNode + +__all__ = ["ToolNode"] diff --git a/api/core/workflow/nodes/tool/entities.py b/api/core/workflow/nodes/tool/entities.py index 1a408d96cb..a3eed8fa5b 100644 --- a/api/core/workflow/nodes/tool/entities.py +++ b/api/core/workflow/nodes/tool/entities.py @@ -4,7 +4,7 @@ from pydantic import BaseModel, field_validator from pydantic_core.core_schema import ValidationInfo from core.tools.entities.tool_entities import ToolProviderType -from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.nodes.base.entities import BaseNodeData class ToolEntity(BaseModel): @@ -52,7 +52,4 @@ class ToolNodeData(BaseNodeData, ToolEntity): raise ValueError("value must be a string, int, float, or bool") return typ - """ - Tool Node Schema - """ tool_parameters: dict[str, ToolInput] diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index fea04f1fe9..becf11c3d4 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -2,24 +2,31 @@ from collections.abc import Generator, Mapping, Sequence from os import path from typing import Any, cast -from core.app.segments import ArrayAnySegment, ArrayAnyVariable, parser +from sqlalchemy import select +from sqlalchemy.orm import Session + from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler -from core.file.file_obj import FileTransferMethod, FileType, FileVar +from core.file.models import File, FileTransferMethod, FileType from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter from core.tools.tool_engine import ToolEngine from core.tools.tool_manager import ToolManager from core.tools.utils.message_transformer import ToolFileMessageTransformer -from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType +from core.variables.segments import ArrayAnySegment +from core.variables.variables import ArrayAnyVariable +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums 
import SystemVariableKey -from core.workflow.nodes.base_node import BaseNode -from core.workflow.nodes.event import RunCompletedEvent, RunEvent, RunStreamChunkEvent +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent from core.workflow.nodes.tool.entities import ToolNodeData from core.workflow.utils.variable_template_parser import VariableTemplateParser -from models import WorkflowNodeExecutionStatus +from extensions.ext_database import db +from models.tools import ToolFile +from models.workflow import WorkflowNodeExecutionStatus -class ToolNode(BaseNode): +class ToolNode(BaseNode[ToolNodeData]): """ Tool Node """ @@ -27,7 +34,7 @@ class ToolNode(BaseNode): _node_data_cls = ToolNodeData _node_type = NodeType.TOOL - def _run(self) -> Generator[RunEvent]: + def _run(self) -> Generator: """ Run the tool node """ @@ -40,7 +47,7 @@ class ToolNode(BaseNode): # get tool runtime try: tool_runtime = ToolManager.get_workflow_tool_runtime( - self.tenant_id, self.app_id, self.node_id, node_data, self.invoke_from + self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from ) except Exception as e: yield RunCompletedEvent( @@ -56,12 +63,14 @@ class ToolNode(BaseNode): # get parameters tool_parameters = tool_runtime.get_merged_runtime_parameters() or [] parameters = self._generate_parameters( - tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=node_data + tool_parameters=tool_parameters, + variable_pool=self.graph_runtime_state.variable_pool, + node_data=self.node_data, ) parameters_for_log = self._generate_parameters( tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, - node_data=node_data, + node_data=self.node_data, for_log=True, ) @@ -107,7 +116,7 @@ class ToolNode(BaseNode): node_data (ToolNodeData): The data associated with the tool node. Returns: - dict[str, Any]: A dictionary containing the generated parameters. + Mapping[str, Any]: A dictionary containing the generated parameters. 
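+        Example (an illustrative sketch only; assumes the variable pool
+        already holds every selector referenced by the tool inputs, mirroring
+        the call made from _run above):
+
+            parameters = self._generate_parameters(
+                tool_parameters=tool_runtime.get_merged_runtime_parameters(),
+                variable_pool=self.graph_runtime_state.variable_pool,
+                node_data=self.node_data,
+            )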
""" tool_parameters_dictionary = {parameter.name: parameter for parameter in tool_parameters} @@ -118,26 +127,22 @@ class ToolNode(BaseNode): if not parameter: result[parameter_name] = None continue - if parameter.type == ToolParameter.ToolParameterType.FILE: - result[parameter_name] = [v.to_dict() for v in self._fetch_files(variable_pool)] + tool_input = node_data.tool_parameters[parameter_name] + if tool_input.type == "variable": + variable = variable_pool.get(tool_input.value) + if variable is None: + raise ValueError(f"variable {tool_input.value} not exists") + parameter_value = variable.value + elif tool_input.type in {"mixed", "constant"}: + segment_group = variable_pool.convert_template(str(tool_input.value)) + parameter_value = segment_group.log if for_log else segment_group.text else: - tool_input = node_data.tool_parameters[parameter_name] - if tool_input.type == "variable": - parameter_value_segment = variable_pool.get(tool_input.value) - if not parameter_value_segment: - raise Exception("input variable dose not exists") - parameter_value = parameter_value_segment.value - else: - segment_group = parser.convert_template( - template=str(tool_input.value), - variable_pool=variable_pool, - ) - parameter_value = segment_group.log if for_log else segment_group.text - result[parameter_name] = parameter_value + raise ValueError(f"unknown tool input type '{tool_input.type}'") + result[parameter_name] = parameter_value return result - def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]: + def _fetch_files(self, variable_pool: VariablePool) -> list[File]: variable = variable_pool.get(["sys", SystemVariableKey.FILES.value]) assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment) return list(variable.value) if variable else [] @@ -147,7 +152,7 @@ class ToolNode(BaseNode): messages: Generator[ToolInvokeMessage, None, None], tool_info: Mapping[str, Any], parameters_for_log: dict[str, Any], - ) -> Generator[RunEvent, None, None]: + ) -> Generator: """ Convert ToolInvokeMessages into tuple[plain_text, files] """ @@ -159,7 +164,7 @@ class ToolNode(BaseNode): conversation_id=None, ) - files: list[FileVar] = [] + files: list[File] = [] text = "" json: list[dict] = [] @@ -172,22 +177,28 @@ class ToolNode(BaseNode): url = message.message.text ext = path.splitext(url)[1] + tool_file_id = str(url).split("/")[-1].split(".")[0] mimetype = message.meta.get("mime_type", "image/jpeg") filename = message.save_as or url.split("/")[-1] transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE) - # get tool file id - tool_file_id = url.split("/")[-1].split(".")[0] + with Session(db.engine) as session: + stmt = select(ToolFile).where(ToolFile.id == tool_file_id) + tool_file = session.scalar(stmt) + if tool_file is None: + raise ValueError(f"tool file {tool_file_id} not exists") + files.append( - FileVar( + File( tenant_id=self.tenant_id, type=FileType.IMAGE, transfer_method=transfer_method, - url=url, + remote_url=url, related_id=tool_file_id, filename=filename, extension=ext, mime_type=mimetype, + size=tool_file.size, ) ) elif message.type == ToolInvokeMessage.MessageType.BLOB: @@ -196,8 +207,14 @@ class ToolNode(BaseNode): assert message.meta tool_file_id = message.message.text.split("/")[-1].split(".")[0] + with Session(db.engine) as session: + stmt = select(ToolFile).where(ToolFile.id == tool_file_id) + tool_file = session.scalar(stmt) + if tool_file is None: + raise ValueError(f"tool file {tool_file_id} not exists") + files.append( - FileVar( + File( 
tenant_id=self.tenant_id, type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, @@ -237,6 +254,9 @@ class ToolNode(BaseNode): ) else: variables[variable_name] = variable_value + elif message.type == ToolInvokeMessage.MessageType.FILE: + assert message.meta is not None + files.append(message.meta["file"]) yield RunCompletedEvent( run_result=NodeRunResult( @@ -249,7 +269,11 @@ class ToolNode(BaseNode): @classmethod def _extract_variable_selector_to_variable_mapping( - cls, graph_config: Mapping[str, Any], node_id: str, node_data: ToolNodeData + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: ToolNodeData, ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping diff --git a/api/core/workflow/nodes/variable_aggregator/__init__.py b/api/core/workflow/nodes/variable_aggregator/__init__.py index e69de29bb2..0b6bf2a5b6 100644 --- a/api/core/workflow/nodes/variable_aggregator/__init__.py +++ b/api/core/workflow/nodes/variable_aggregator/__init__.py @@ -0,0 +1,3 @@ +from .variable_aggregator_node import VariableAggregatorNode + +__all__ = ["VariableAggregatorNode"] diff --git a/api/core/workflow/nodes/variable_aggregator/entities.py b/api/core/workflow/nodes/variable_aggregator/entities.py index eb893a04e3..71a930e6b0 100644 --- a/api/core/workflow/nodes/variable_aggregator/entities.py +++ b/api/core/workflow/nodes/variable_aggregator/entities.py @@ -2,7 +2,7 @@ from typing import Literal, Optional from pydantic import BaseModel -from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.nodes.base import BaseNodeData class AdvancedSettings(BaseModel): diff --git a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py index f03eae257a..05477e2a90 100644 --- a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py +++ b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py @@ -1,24 +1,24 @@ from collections.abc import Mapping, Sequence -from typing import Any, cast +from typing import Any -from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.nodes.base_node import BaseNode +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.enums import NodeType from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData from models.workflow import WorkflowNodeExecutionStatus -class VariableAggregatorNode(BaseNode): +class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]): _node_data_cls = VariableAssignerNodeData _node_type = NodeType.VARIABLE_AGGREGATOR def _run(self) -> NodeRunResult: - node_data = cast(VariableAssignerNodeData, self.node_data) # Get variables outputs = {} inputs = {} - if not node_data.advanced_settings or not node_data.advanced_settings.group_enabled: - for selector in node_data.variables: + if not self.node_data.advanced_settings or not self.node_data.advanced_settings.group_enabled: + for selector in self.node_data.variables: variable = self.graph_runtime_state.variable_pool.get_any(selector) if variable is not None: outputs = {"output": variable} @@ -26,7 +26,7 @@ class VariableAggregatorNode(BaseNode): inputs = {".".join(selector[1:]): variable} break else: - for group in node_data.advanced_settings.groups: + for group in self.node_data.advanced_settings.groups: for selector in group.variables: variable = 
self.graph_runtime_state.variable_pool.get_any(selector) diff --git a/api/core/workflow/nodes/variable_assigner/node.py b/api/core/workflow/nodes/variable_assigner/node.py index 3969299795..4e66f640df 100644 --- a/api/core/workflow/nodes/variable_assigner/node.py +++ b/api/core/workflow/nodes/variable_assigner/node.py @@ -1,40 +1,38 @@ -from typing import cast - from sqlalchemy import select from sqlalchemy.orm import Session -from core.app.segments import SegmentType, Variable, factory -from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.nodes.base_node import BaseNode +from core.variables import SegmentType, Variable +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.nodes.base import BaseNode, BaseNodeData +from core.workflow.nodes.enums import NodeType from extensions.ext_database import db -from models import ConversationVariable, WorkflowNodeExecutionStatus +from factories import variable_factory +from models import ConversationVariable +from models.workflow import WorkflowNodeExecutionStatus from .exc import VariableAssignerNodeError from .node_data import VariableAssignerData, WriteMode -class VariableAssignerNode(BaseNode): +class VariableAssignerNode(BaseNode[VariableAssignerData]): _node_data_cls: type[BaseNodeData] = VariableAssignerData _node_type: NodeType = NodeType.CONVERSATION_VARIABLE_ASSIGNER def _run(self) -> NodeRunResult: - data = cast(VariableAssignerData, self.node_data) - # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject - original_variable = self.graph_runtime_state.variable_pool.get(data.assigned_variable_selector) + original_variable = self.graph_runtime_state.variable_pool.get(self.node_data.assigned_variable_selector) if not isinstance(original_variable, Variable): raise VariableAssignerNodeError("assigned variable not found") - match data.write_mode: + match self.node_data.write_mode: case WriteMode.OVER_WRITE: - income_value = self.graph_runtime_state.variable_pool.get(data.input_variable_selector) + income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector) if not income_value: raise VariableAssignerNodeError("input value not found") updated_variable = original_variable.model_copy(update={"value": income_value.value}) case WriteMode.APPEND: - income_value = self.graph_runtime_state.variable_pool.get(data.input_variable_selector) + income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector) if not income_value: raise VariableAssignerNodeError("input value not found") updated_value = original_variable.value + [income_value.value] @@ -45,10 +43,10 @@ class VariableAssignerNode(BaseNode): updated_variable = original_variable.model_copy(update={"value": income_value.to_object()}) case _: - raise VariableAssignerNodeError(f"unsupported write mode: {data.write_mode}") + raise VariableAssignerNodeError(f"unsupported write mode: {self.node_data.write_mode}") # Over write the variable. - self.graph_runtime_state.variable_pool.add(data.assigned_variable_selector, updated_variable) + self.graph_runtime_state.variable_pool.add(self.node_data.assigned_variable_selector, updated_variable) # TODO: Move database operation to the pipeline. # Update conversation variable. 
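A minimal sketch of the write-mode dispatch that the rewritten _run above implements (hedged: the WriteMode.CLEAR member, the variable's value_type attribute, and the local names original/income are assumptions, since those lines are elided from this hunk):

    match self.node_data.write_mode:
        case WriteMode.OVER_WRITE:
            # Replace the stored conversation-variable value outright.
            updated = original.model_copy(update={"value": income.value})
        case WriteMode.APPEND:
            # Extend an array-typed value with the incoming one.
            updated = original.model_copy(update={"value": original.value + [income.value]})
        case WriteMode.CLEAR:
            # Reset to the zero value for the variable's SegmentType.
            updated = original.model_copy(update={"value": get_zero_value(original.value_type).to_object()})
    self.graph_runtime_state.variable_pool.add(self.node_data.assigned_variable_selector, updated)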
@@ -80,12 +78,12 @@ def update_conversation_variable(conversation_id: str, variable: Variable): def get_zero_value(t: SegmentType): match t: case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER: - return factory.build_segment([]) + return variable_factory.build_segment([]) case SegmentType.OBJECT: - return factory.build_segment({}) + return variable_factory.build_segment({}) case SegmentType.STRING: - return factory.build_segment("") + return variable_factory.build_segment("") case SegmentType.NUMBER: - return factory.build_segment(0) + return variable_factory.build_segment(0) case _: raise VariableAssignerNodeError(f"unsupported variable type: {t}") diff --git a/api/core/workflow/nodes/variable_assigner/node_data.py b/api/core/workflow/nodes/variable_assigner/node_data.py index 8ac8eadf7c..70ae29d45f 100644 --- a/api/core/workflow/nodes/variable_assigner/node_data.py +++ b/api/core/workflow/nodes/variable_assigner/node_data.py @@ -2,7 +2,7 @@ from collections.abc import Sequence from enum import Enum from typing import Optional -from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.nodes.base import BaseNodeData class WriteMode(str, Enum): diff --git a/api/core/workflow/utils/condition/entities.py b/api/core/workflow/utils/condition/entities.py index b8e8b881a5..1d96743879 100644 --- a/api/core/workflow/utils/condition/entities.py +++ b/api/core/workflow/utils/condition/entities.py @@ -1,32 +1,46 @@ -from typing import Literal, Optional +from collections.abc import Sequence +from typing import Literal -from pydantic import BaseModel +from pydantic import BaseModel, Field + +SupportedComparisonOperator = Literal[ + # for string or array + "contains", + "not contains", + "start with", + "end with", + "is", + "is not", + "empty", + "not empty", + "in", + "not in", + "all of", + # for number + "=", + "≠", + ">", + "<", + "≥", + "≤", + "null", + "not null", +] + + +class SubCondition(BaseModel): + key: str + comparison_operator: SupportedComparisonOperator + value: str | Sequence[str] | None = None + + +class SubVariableCondition(BaseModel): + logical_operator: Literal["and", "or"] + conditions: list[SubCondition] = Field(default_factory=list) 

class Condition(BaseModel): - """ - Condition entity - """ - variable_selector: list[str] - comparison_operator: Literal[ - # for string or array - "contains", - "not contains", - "start with", - "end with", - "is", - "is not", - "empty", - "not empty", - # for number - "=", - "≠", - ">", - "<", - "≥", - "≤", - "null", - "not null", - ] - value: Optional[str] = None + comparison_operator: SupportedComparisonOperator + value: str | Sequence[str] | None = None + sub_variable_condition: SubVariableCondition | None = None diff --git a/api/core/workflow/utils/condition/processor.py b/api/core/workflow/utils/condition/processor.py index 395ee82478..f4a80fa5e1 100644 --- a/api/core/workflow/utils/condition/processor.py +++ b/api/core/workflow/utils/condition/processor.py @@ -1,381 +1,362 @@ from collections.abc import Sequence -from typing import Any, Optional +from typing import Any, Literal -from core.file.file_obj import FileVar +from core.file import FileAttribute, file_manager +from core.variables.segments import ArrayFileSegment from core.workflow.entities.variable_pool import VariablePool -from core.workflow.utils.condition.entities import Condition -from core.workflow.utils.variable_template_parser import VariableTemplateParser + +from .entities import Condition, SubCondition, 
SupportedComparisonOperator class ConditionProcessor: - def process_conditions(self, variable_pool: VariablePool, conditions: Sequence[Condition]): - input_conditions = [] - group_result = [] - - index = 0 - for condition in conditions: - index += 1 - actual_value = variable_pool.get_any(condition.variable_selector) - - expected_value = None - if condition.value is not None: - variable_template_parser = VariableTemplateParser(template=condition.value) - variable_selectors = variable_template_parser.extract_variable_selectors() - if variable_selectors: - for variable_selector in variable_selectors: - value = variable_pool.get_any(variable_selector.value_selector) - expected_value = variable_template_parser.format({variable_selector.variable: value}) - - if expected_value is None: - expected_value = condition.value - else: - expected_value = condition.value - - comparison_operator = condition.comparison_operator - input_conditions.append( - { - "actual_value": actual_value, - "expected_value": expected_value, - "comparison_operator": comparison_operator, - } - ) - - result = self.evaluate_condition(actual_value, comparison_operator, expected_value) - group_result.append(result) - - return input_conditions, group_result - - def evaluate_condition( + def process_conditions( self, - actual_value: Optional[str | int | float | dict[Any, Any] | list[Any] | FileVar | None], - comparison_operator: str, - expected_value: Optional[str] = None, - ) -> bool: - """ - Evaluate condition - :param actual_value: actual value - :param expected_value: expected value - :param comparison_operator: comparison operator + *, + variable_pool: VariablePool, + conditions: Sequence[Condition], + operator: Literal["and", "or"], + ): + input_conditions = [] + group_results = [] - :return: bool - """ - if comparison_operator == "contains": - return self._assert_contains(actual_value, expected_value) - elif comparison_operator == "not contains": - return self._assert_not_contains(actual_value, expected_value) - elif comparison_operator == "start with": - return self._assert_start_with(actual_value, expected_value) - elif comparison_operator == "end with": - return self._assert_end_with(actual_value, expected_value) - elif comparison_operator == "is": - return self._assert_is(actual_value, expected_value) - elif comparison_operator == "is not": - return self._assert_is_not(actual_value, expected_value) - elif comparison_operator == "empty": - return self._assert_empty(actual_value) - elif comparison_operator == "not empty": - return self._assert_not_empty(actual_value) - elif comparison_operator == "=": - return self._assert_equal(actual_value, expected_value) - elif comparison_operator == "≠": - return self._assert_not_equal(actual_value, expected_value) - elif comparison_operator == ">": - return self._assert_greater_than(actual_value, expected_value) - elif comparison_operator == "<": - return self._assert_less_than(actual_value, expected_value) - elif comparison_operator == "≥": - return self._assert_greater_than_or_equal(actual_value, expected_value) - elif comparison_operator == "≤": - return self._assert_less_than_or_equal(actual_value, expected_value) - elif comparison_operator == "null": - return self._assert_null(actual_value) - elif comparison_operator == "not null": - return self._assert_not_null(actual_value) - else: - raise ValueError(f"Invalid comparison operator: {comparison_operator}") + for condition in conditions: + variable = variable_pool.get(condition.variable_selector) - def _assert_contains(self, 
actual_value: Optional[str | list], expected_value: str) -> bool: - """ - Assert contains - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if not actual_value: - return False + if isinstance(variable, ArrayFileSegment) and condition.comparison_operator in { + "contains", + "not contains", + "all of", + }: + # check sub conditions + if not condition.sub_variable_condition: + raise ValueError("Sub variable is required") + result = _process_sub_conditions( + variable=variable, + sub_conditions=condition.sub_variable_condition.conditions, + operator=condition.sub_variable_condition.logical_operator, + ) + else: + actual_value = variable.value if variable else None + expected_value = condition.value + if isinstance(expected_value, str): + expected_value = variable_pool.convert_template(expected_value).text + input_conditions.append( + { + "actual_value": actual_value, + "expected_value": expected_value, + "comparison_operator": condition.comparison_operator, + } + ) + result = _evaluate_condition( + value=actual_value, + operator=condition.comparison_operator, + expected=expected_value, + ) + group_results.append(result) - if not isinstance(actual_value, str | list): - raise ValueError("Invalid actual value type: string or array") + final_result = all(group_results) if operator == "and" else any(group_results) + return input_conditions, group_results, final_result - if expected_value not in actual_value: - return False - return True - def _assert_not_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool: - """ - Assert not contains - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if not actual_value: - return True +def _evaluate_condition( + *, + operator: SupportedComparisonOperator, + value: Any, + expected: str | Sequence[str] | None, +) -> bool: + match operator: + case "contains": + return _assert_contains(value=value, expected=expected) + case "not contains": + return _assert_not_contains(value=value, expected=expected) + case "start with": + return _assert_start_with(value=value, expected=expected) + case "end with": + return _assert_end_with(value=value, expected=expected) + case "is": + return _assert_is(value=value, expected=expected) + case "is not": + return _assert_is_not(value=value, expected=expected) + case "empty": + return _assert_empty(value=value) + case "not empty": + return _assert_not_empty(value=value) + case "=": + return _assert_equal(value=value, expected=expected) + case "≠": + return _assert_not_equal(value=value, expected=expected) + case ">": + return _assert_greater_than(value=value, expected=expected) + case "<": + return _assert_less_than(value=value, expected=expected) + case "≥": + return _assert_greater_than_or_equal(value=value, expected=expected) + case "≤": + return _assert_less_than_or_equal(value=value, expected=expected) + case "null": + return _assert_null(value=value) + case "not null": + return _assert_not_null(value=value) + case "in": + return _assert_in(value=value, expected=expected) + case "not in": + return _assert_not_in(value=value, expected=expected) + case "all of" if isinstance(expected, list): + return _assert_all_of(value=value, expected=expected) + case _: + raise ValueError(f"Unsupported operator: {operator}") - if not isinstance(actual_value, str | list): - raise ValueError("Invalid actual value type: string or array") - if expected_value in actual_value: - return False - return True - - def _assert_start_with(self, 
actual_value: Optional[str], expected_value: str) -> bool: - """ - Assert start with - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if not actual_value: - return False - - if not isinstance(actual_value, str): - raise ValueError("Invalid actual value type: string") - - if not actual_value.startswith(expected_value): - return False - return True - - def _assert_end_with(self, actual_value: Optional[str], expected_value: str) -> bool: - """ - Assert end with - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if not actual_value: - return False - - if not isinstance(actual_value, str): - raise ValueError("Invalid actual value type: string") - - if not actual_value.endswith(expected_value): - return False - return True - - def _assert_is(self, actual_value: Optional[str], expected_value: str) -> bool: - """ - Assert is - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if actual_value is None: - return False - - if not isinstance(actual_value, str): - raise ValueError("Invalid actual value type: string") - - if actual_value != expected_value: - return False - return True - - def _assert_is_not(self, actual_value: Optional[str], expected_value: str) -> bool: - """ - Assert is not - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if actual_value is None: - return False - - if not isinstance(actual_value, str): - raise ValueError("Invalid actual value type: string") - - if actual_value == expected_value: - return False - return True - - def _assert_empty(self, actual_value: Optional[str]) -> bool: - """ - Assert empty - :param actual_value: actual value - :return: - """ - if not actual_value: - return True +def _assert_contains(*, value: Any, expected: Any) -> bool: + if not value: return False - def _assert_not_empty(self, actual_value: Optional[str]) -> bool: - """ - Assert not empty - :param actual_value: actual value - :return: - """ - if actual_value: - return True + if not isinstance(value, str | list): + raise ValueError("Invalid actual value type: string or array") + + if expected not in value: + return False + return True + + +def _assert_not_contains(*, value: Any, expected: Any) -> bool: + if not value: + return True + + if not isinstance(value, str | list): + raise ValueError("Invalid actual value type: string or array") + + if expected in value: + return False + return True + + +def _assert_start_with(*, value: Any, expected: Any) -> bool: + if not value: return False - def _assert_equal(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool: - """ - Assert equal - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if actual_value is None: - return False + if not isinstance(value, str): + raise ValueError("Invalid actual value type: string") - if not isinstance(actual_value, int | float): - raise ValueError("Invalid actual value type: number") + if not value.startswith(expected): + return False + return True - if isinstance(actual_value, int): - expected_value = int(expected_value) - else: - expected_value = float(expected_value) - if actual_value != expected_value: - return False - return True - - def _assert_not_equal(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool: - """ - Assert not equal - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if 
actual_value is None: - return False - - if not isinstance(actual_value, int | float): - raise ValueError("Invalid actual value type: number") - - if isinstance(actual_value, int): - expected_value = int(expected_value) - else: - expected_value = float(expected_value) - - if actual_value == expected_value: - return False - return True - - def _assert_greater_than(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool: - """ - Assert greater than - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if actual_value is None: - return False - - if not isinstance(actual_value, int | float): - raise ValueError("Invalid actual value type: number") - - if isinstance(actual_value, int): - expected_value = int(expected_value) - else: - expected_value = float(expected_value) - - if actual_value <= expected_value: - return False - return True - - def _assert_less_than(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool: - """ - Assert less than - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if actual_value is None: - return False - - if not isinstance(actual_value, int | float): - raise ValueError("Invalid actual value type: number") - - if isinstance(actual_value, int): - expected_value = int(expected_value) - else: - expected_value = float(expected_value) - - if actual_value >= expected_value: - return False - return True - - def _assert_greater_than_or_equal( - self, actual_value: Optional[int | float], expected_value: str | int | float - ) -> bool: - """ - Assert greater than or equal - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if actual_value is None: - return False - - if not isinstance(actual_value, int | float): - raise ValueError("Invalid actual value type: number") - - if isinstance(actual_value, int): - expected_value = int(expected_value) - else: - expected_value = float(expected_value) - - if actual_value < expected_value: - return False - return True - - def _assert_less_than_or_equal( - self, actual_value: Optional[int | float], expected_value: str | int | float - ) -> bool: - """ - Assert less than or equal - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if actual_value is None: - return False - - if not isinstance(actual_value, int | float): - raise ValueError("Invalid actual value type: number") - - if isinstance(actual_value, int): - expected_value = int(expected_value) - else: - expected_value = float(expected_value) - - if actual_value > expected_value: - return False - return True - - def _assert_null(self, actual_value: Optional[int | float]) -> bool: - """ - Assert null - :param actual_value: actual value - :return: - """ - if actual_value is None: - return True +def _assert_end_with(*, value: Any, expected: Any) -> bool: + if not value: return False - def _assert_not_null(self, actual_value: Optional[int | float]) -> bool: - """ - Assert not null - :param actual_value: actual value - :return: - """ - if actual_value is not None: - return True + if not isinstance(value, str): + raise ValueError("Invalid actual value type: string") + + if not value.endswith(expected): + return False + return True + + +def _assert_is(*, value: Any, expected: Any) -> bool: + if value is None: return False + if not isinstance(value, str): + raise ValueError("Invalid actual value type: string") -class ConditionAssertionError(Exception): - def 
__init__(self, message: str, conditions: list[dict], sub_condition_compare_results: list[dict]) -> None: - self.message = message - self.conditions = conditions - self.sub_condition_compare_results = sub_condition_compare_results - super().__init__(self.message) + if value != expected: + return False + return True + + +def _assert_is_not(*, value: Any, expected: Any) -> bool: + if value is None: + return False + + if not isinstance(value, str): + raise ValueError("Invalid actual value type: string") + + if value == expected: + return False + return True + + +def _assert_empty(*, value: Any) -> bool: + if not value: + return True + return False + + +def _assert_not_empty(*, value: Any) -> bool: + if value: + return True + return False + + +def _assert_equal(*, value: Any, expected: Any) -> bool: + if value is None: + return False + + if not isinstance(value, int | float): + raise ValueError("Invalid actual value type: number") + + if isinstance(value, int): + expected = int(expected) + else: + expected = float(expected) + + if value != expected: + return False + return True + + +def _assert_not_equal(*, value: Any, expected: Any) -> bool: + if value is None: + return False + + if not isinstance(value, int | float): + raise ValueError("Invalid actual value type: number") + + if isinstance(value, int): + expected = int(expected) + else: + expected = float(expected) + + if value == expected: + return False + return True + + +def _assert_greater_than(*, value: Any, expected: Any) -> bool: + if value is None: + return False + + if not isinstance(value, int | float): + raise ValueError("Invalid actual value type: number") + + if isinstance(value, int): + expected = int(expected) + else: + expected = float(expected) + + if value <= expected: + return False + return True + + +def _assert_less_than(*, value: Any, expected: Any) -> bool: + if value is None: + return False + + if not isinstance(value, int | float): + raise ValueError("Invalid actual value type: number") + + if isinstance(value, int): + expected = int(expected) + else: + expected = float(expected) + + if value >= expected: + return False + return True + + +def _assert_greater_than_or_equal(*, value: Any, expected: Any) -> bool: + if value is None: + return False + + if not isinstance(value, int | float): + raise ValueError("Invalid actual value type: number") + + if isinstance(value, int): + expected = int(expected) + else: + expected = float(expected) + + if value < expected: + return False + return True + + +def _assert_less_than_or_equal(*, value: Any, expected: Any) -> bool: + if value is None: + return False + + if not isinstance(value, int | float): + raise ValueError("Invalid actual value type: number") + + if isinstance(value, int): + expected = int(expected) + else: + expected = float(expected) + + if value > expected: + return False + return True + + +def _assert_null(*, value: Any) -> bool: + if value is None: + return True + return False + + +def _assert_not_null(*, value: Any) -> bool: + if value is not None: + return True + return False + + +def _assert_in(*, value: Any, expected: Any) -> bool: + if not value: + return False + + if not isinstance(expected, list): + raise ValueError("Invalid expected value type: array") + + if value not in expected: + return False + return True + + +def _assert_not_in(*, value: Any, expected: Any) -> bool: + if not value: + return True + + if not isinstance(expected, list): + raise ValueError("Invalid expected value type: array") + + if value in expected: + return False + return True + + 
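+# Illustrative dispatch examples (values invented for clarity only): each
+# keyword-only helper above is reached through _evaluate_condition, e.g.
+#
+#     _evaluate_condition(value="dify.ai", operator="end with", expected=".ai")  # True
+#     _evaluate_condition(value=3, operator="≥", expected="3")                   # True, expected coerced to int
+#     _evaluate_condition(value="a", operator="in", expected=["a", "b"])         # True
+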
+def _assert_all_of(*, value: Any, expected: Sequence[str]) -> bool: + if not value: + return False + + if not all(item in value for item in expected): + return False + return True + + +def _process_sub_conditions( + variable: ArrayFileSegment, + sub_conditions: Sequence[SubCondition], + operator: Literal["and", "or"], +) -> bool: + files = variable.value + group_results = [] + for condition in sub_conditions: + key = FileAttribute(condition.key) + values = [file_manager.get_attr(file=file, attr=key) for file in files] + sub_group_results = [ + _evaluate_condition( + value=value, + operator=condition.comparison_operator, + expected=condition.value, + ) + for value in values + ] + # Determine the result based on the presence of "not" in the comparison operator + result = all(sub_group_results) if "not" in condition.comparison_operator else any(sub_group_results) + group_results.append(result) + return all(group_results) if operator == "and" else any(group_results) diff --git a/api/core/workflow/utils/variable_template_parser.py b/api/core/workflow/utils/variable_template_parser.py index fd0e48b862..1d8fb38ebf 100644 --- a/api/core/workflow/utils/variable_template_parser.py +++ b/api/core/workflow/utils/variable_template_parser.py @@ -1,42 +1,21 @@ import re -from collections.abc import Mapping +from collections.abc import Mapping, Sequence from typing import Any from core.workflow.entities.variable_entities import VariableSelector -from core.workflow.entities.variable_pool import VariablePool REGEX = re.compile(r"\{\{(#[a-zA-Z0-9_]{1,50}(\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}") +SELECTOR_PATTERN = re.compile(r"\{\{(#[a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}") -def parse_mixed_template(*, template: str, variable_pool: VariablePool) -> str: - """ - This is an alternative to the VariableTemplateParser class, - offering the same functionality but with better readability and ease of use. - """ - variable_keys = [match[0] for match in re.findall(REGEX, template)] - variable_keys = list(set(variable_keys)) - # This key_selector is a tuple of (key, selector) where selector is a list of keys - # e.g. ('#node_id.query.name#', ['node_id', 'query', 'name']) - key_selectors = filter( - lambda t: len(t[1]) >= 2, - ((key, selector.replace("#", "").split(".")) for key, selector in zip(variable_keys, variable_keys)), - ) - inputs = {key: variable_pool.get_any(selector) for key, selector in key_selectors} - - def replacer(match): - key = match.group(1) - # return original matched string if key not found - value = inputs.get(key, match.group(0)) - if value is None: - value = "" - value = str(value) - # remove template variables if required - return re.sub(REGEX, r"{\1}", value) - - result = re.sub(REGEX, replacer, template) - result = re.sub(r"<\|.*?\|>", "", result) - return result +def extract_selectors_from_template(template: str, /) -> Sequence[VariableSelector]: + parts = SELECTOR_PATTERN.split(template) + selectors = [] + for part in filter(lambda x: x, parts): + if "." 
in part and part[0] == "#" and part[-1] == "#": + selectors.append(VariableSelector(variable=f"{part}", value_selector=part[1:-1].split("."))) + return selectors class VariableTemplateParser: diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 0420d62ef7..fdc17cb4da 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -8,10 +8,8 @@ from configs import dify_config from core.app.app_config.entities import FileExtraConfig from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom -from core.file.file_obj import FileTransferMethod, FileType, FileVar -from core.workflow.callbacks.base_workflow_callback import WorkflowCallback -from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeType, UserFrom +from core.file.models import File, FileTransferMethod, FileType, ImageConfig +from core.workflow.callbacks import WorkflowCallback from core.workflow.entities.variable_pool import VariablePool from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.graph_engine.entities.event import GraphEngineEvent, GraphRunFailedEvent, InNodeEvent @@ -19,10 +17,12 @@ from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.graph_engine.graph_engine import GraphEngine -from core.workflow.nodes.base_node import BaseNode -from core.workflow.nodes.event import RunEvent -from core.workflow.nodes.llm.entities import LLMNodeData -from core.workflow.nodes.node_mapping import node_classes +from core.workflow.nodes import NodeType +from core.workflow.nodes.base import BaseNode, BaseNodeData +from core.workflow.nodes.event import NodeEvent +from core.workflow.nodes.llm import LLMNodeData +from core.workflow.nodes.node_mapping import node_type_classes_mapping +from models.enums import UserFrom from models.workflow import ( Workflow, WorkflowType, @@ -115,7 +115,7 @@ class WorkflowEntry: @classmethod def single_step_run( cls, workflow: Workflow, node_id: str, user_id: str, user_inputs: dict - ) -> tuple[BaseNode, Generator[RunEvent | InNodeEvent, None, None]]: + ) -> tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]: """ Single step run workflow node :param workflow: Workflow instance @@ -144,8 +144,8 @@ class WorkflowEntry: raise ValueError("node id not found in workflow graph") # Get node class - node_type = NodeType.value_of(node_config.get("data", {}).get("type")) - node_cls = node_classes.get(node_type) + node_type = NodeType(node_config.get("data", {}).get("type")) + node_cls = node_type_classes_mapping.get(node_type) node_cls = cast(type[BaseNode], node_cls) if not node_cls: @@ -162,7 +162,7 @@ class WorkflowEntry: graph = Graph.init(graph_config=workflow.graph_dict) # init workflow run state - node_instance: BaseNode = node_cls( + node_instance = node_cls( id=str(uuid.uuid4()), config=node_config, graph_init_params=GraphInitParams( @@ -208,7 +208,7 @@ class WorkflowEntry: @classmethod def run_free_node( cls, node_data: dict, node_id: str, tenant_id: str, user_id: str, user_inputs: dict[str, Any] - ) -> tuple[BaseNode, Generator[RunEvent | InNodeEvent, None, None]]: + ) -> tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]: """ Run free node @@ -244,11 +244,11 @@ 
class WorkflowEntry: ], } - node_type = NodeType.value_of(node_data.get("type", "")) + node_type = NodeType(node_data.get("type", "")) if node_type not in {NodeType.PARAMETER_EXTRACTOR, NodeType.QUESTION_CLASSIFIER}: raise ValueError(f"Node type {node_type} not supported") - node_cls = node_classes.get(node_type) + node_cls = node_type_classes_mapping.get(node_type) if not node_cls: raise ValueError(f"Node class not found for node type {node_type}") @@ -306,32 +306,27 @@ class WorkflowEntry: except Exception as e: raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e)) - @classmethod - def handle_special_values(cls, value: Optional[Mapping[str, Any]]) -> Optional[dict]: - """ - Handle special values - :param value: value - :return: - """ - if not value: - return None + @staticmethod + def handle_special_values(value: Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None: + return WorkflowEntry._handle_special_values(value) - new_value = dict(value) if value else {} - if isinstance(new_value, dict): - for key, val in new_value.items(): - if isinstance(val, FileVar): - new_value[key] = val.to_dict() - elif isinstance(val, list): - new_val = [] - for v in val: - if isinstance(v, FileVar): - new_val.append(v.to_dict()) - else: - new_val.append(v) - - new_value[key] = new_val - - return new_value + @staticmethod + def _handle_special_values(value: Any) -> Any: + if value is None: + return value + if isinstance(value, dict): + res = {} + for k, v in value.items(): + res[k] = WorkflowEntry._handle_special_values(v) + return res + if isinstance(value, list): + res = [] + for item in value: + res.append(WorkflowEntry._handle_special_values(item)) + return res + if isinstance(value, File): + return value.to_dict() + return value @classmethod def mapping_user_inputs_to_variable_pool( @@ -377,20 +372,24 @@ class WorkflowEntry: for item in input_value: if isinstance(item, dict) and "type" in item and item["type"] == "image": transfer_method = FileTransferMethod.value_of(item.get("transfer_method")) - file = FileVar( + file = File( tenant_id=tenant_id, type=FileType.IMAGE, transfer_method=transfer_method, - url=item.get("url") if transfer_method == FileTransferMethod.REMOTE_URL else None, + remote_url=item.get("url") + if transfer_method == FileTransferMethod.REMOTE_URL + else None, related_id=item.get("upload_file_id") if transfer_method == FileTransferMethod.LOCAL_FILE else None, - extra_config=FileExtraConfig(image_config={"detail": detail} if detail else None), + _extra_config=FileExtraConfig( + image_config=ImageConfig(detail=detail) if detail else None + ), ) new_value.append(file) if new_value: - value = new_value + input_value = new_value # append variable and value to variable pool variable_pool.add([variable_node_id] + variable_key_list, input_value) diff --git a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py index f96bb5ef74..9c5955c8c5 100644 --- a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py +++ b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py @@ -1,6 +1,6 @@ from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ToolParameterConfigurationManager -from core.workflow.entities.node_entities import NodeType +from core.workflow.nodes import NodeType from core.workflow.nodes.tool.entities import ToolEntity from events.app_event import 
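The rewritten handle_special_values is a depth-first copy: dicts and lists are rebuilt recursively, and any File encountered is converted with to_dict(). A self-contained sketch of the same traversal (a to_dict duck-type stands in for the real File model):

from typing import Any

def handle_special_values_sketch(value: Any) -> Any:
    # Depth-first copy: rebuild dicts and lists, convert File-like
    # values, and pass everything else through unchanged.
    if value is None:
        return value
    if isinstance(value, dict):
        return {k: handle_special_values_sketch(v) for k, v in value.items()}
    if isinstance(value, list):
        return [handle_special_values_sketch(item) for item in value]
    if hasattr(value, "to_dict"):  # stands in for isinstance(value, File)
        return value.to_dict()
    return value

class _FakeFile:
    def to_dict(self) -> dict:
        return {"type": "image"}

assert handle_special_values_sketch({"files": [_FakeFile()]}) == {"files": [{"type": "image"}]}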
app_draft_workflow_was_synced diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py index c5e98e263f..453395e8d7 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py @@ -1,6 +1,6 @@ from typing import cast -from core.workflow.entities.node_entities import NodeType +from core.workflow.nodes import NodeType from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData from events.app_event import app_published_workflow_was_updated from extensions.ext_database import db diff --git a/api/extensions/ext_storage.py b/api/extensions/ext_storage.py index f90629262d..5fc4f88832 100644 --- a/api/extensions/ext_storage.py +++ b/api/extensions/ext_storage.py @@ -72,7 +72,7 @@ class Storage: logging.exception("Failed to save file: %s", e) raise e - def load(self, filename: str, stream: bool = False) -> Union[bytes, Generator]: + def load(self, filename: str, /, *, stream: bool = False) -> Union[bytes, Generator]: try: if stream: return self.load_stream(filename) diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py new file mode 100644 index 0000000000..eac5090c2b --- /dev/null +++ b/api/factories/file_factory.py @@ -0,0 +1,254 @@ +import mimetypes +from collections.abc import Mapping, Sequence +from typing import Any + +from sqlalchemy import select + +from constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS +from core.file import File, FileBelongsTo, FileExtraConfig, FileTransferMethod, FileType +from core.helper import ssrf_proxy +from extensions.ext_database import db +from models import MessageFile, ToolFile, UploadFile +from models.enums import CreatedByRole + + +def build_from_message_files( + *, + message_files: Sequence["MessageFile"], + tenant_id: str, + config: FileExtraConfig, +) -> Sequence[File]: + results = [ + build_from_message_file(message_file=file, tenant_id=tenant_id, config=config) + for file in message_files + if file.belongs_to != FileBelongsTo.ASSISTANT + ] + return results + + +def build_from_message_file( + *, + message_file: "MessageFile", + tenant_id: str, + config: FileExtraConfig, +): + mapping = { + "transfer_method": message_file.transfer_method, + "url": message_file.url, + "id": message_file.id, + "type": message_file.type, + "upload_file_id": message_file.upload_file_id, + } + return build_from_mapping( + mapping=mapping, + tenant_id=tenant_id, + user_id=message_file.created_by, + role=CreatedByRole(message_file.created_by_role), + config=config, + ) + + +def build_from_mapping( + *, + mapping: Mapping[str, Any], + tenant_id: str, + user_id: str, + role: "CreatedByRole", + config: FileExtraConfig, +): + transfer_method = FileTransferMethod.value_of(mapping.get("transfer_method")) + match transfer_method: + case FileTransferMethod.REMOTE_URL: + file = _build_from_remote_url( + mapping=mapping, + tenant_id=tenant_id, + config=config, + transfer_method=transfer_method, + ) + case FileTransferMethod.LOCAL_FILE: + file = _build_from_local_file( + mapping=mapping, + tenant_id=tenant_id, + user_id=user_id, + role=role, + config=config, + transfer_method=transfer_method, + ) + case FileTransferMethod.TOOL_FILE: + file = _build_from_tool_file( + mapping=mapping, + tenant_id=tenant_id, + user_id=user_id, + 
config=config, + transfer_method=transfer_method, + ) + case _: + raise ValueError(f"Invalid file transfer method: {transfer_method}") + + return file + + +def build_from_mappings( + *, + mappings: Sequence[Mapping[str, Any]], + config: FileExtraConfig | None, + tenant_id: str, + user_id: str, + role: "CreatedByRole", +) -> Sequence[File]: + if not config: + return [] + + files = [ + build_from_mapping( + mapping=mapping, + tenant_id=tenant_id, + user_id=user_id, + role=role, + config=config, + ) + for mapping in mappings + ] + + if ( + # If image config is set. + config.image_config + # And the number of image files exceeds the maximum limit + and sum(1 for _ in (filter(lambda x: x.type == FileType.IMAGE, files))) > config.image_config.number_limits + ): + raise ValueError(f"Number of image files exceeds the maximum limit {config.image_config.number_limits}") + if config.number_limits and len(files) > config.number_limits: + raise ValueError(f"Number of files exceeds the maximum limit {config.number_limits}") + + return files + + +def _build_from_local_file( + *, + mapping: Mapping[str, Any], + tenant_id: str, + user_id: str, + role: "CreatedByRole", + config: FileExtraConfig, + transfer_method: FileTransferMethod, +): + # check if the upload file exists. + file_type = FileType.value_of(mapping.get("type")) + stmt = select(UploadFile).where( + UploadFile.id == mapping.get("upload_file_id"), + UploadFile.tenant_id == tenant_id, + UploadFile.created_by == user_id, + UploadFile.created_by_role == role, + ) + if file_type == FileType.IMAGE: + stmt = stmt.where(UploadFile.extension.in_(IMAGE_EXTENSIONS)) + elif file_type == FileType.VIDEO: + stmt = stmt.where(UploadFile.extension.in_(VIDEO_EXTENSIONS)) + elif file_type == FileType.AUDIO: + stmt = stmt.where(UploadFile.extension.in_(AUDIO_EXTENSIONS)) + elif file_type == FileType.DOCUMENT: + stmt = stmt.where(UploadFile.extension.in_(DOCUMENT_EXTENSIONS)) + row = db.session.scalar(stmt) + if row is None: + raise ValueError("Invalid upload file") + file = File( + id=mapping.get("id"), + filename=row.name, + extension=row.extension, + mime_type=row.mime_type, + tenant_id=tenant_id, + type=file_type, + transfer_method=transfer_method, + remote_url=None, + related_id=mapping.get("upload_file_id"), + _extra_config=config, + size=row.size, + ) + return file + + +def _build_from_remote_url( + *, + mapping: Mapping[str, Any], + tenant_id: str, + config: FileExtraConfig, + transfer_method: FileTransferMethod, +): + url = mapping.get("url") + if not url: + raise ValueError("Invalid file url") + resp = ssrf_proxy.head(url, follow_redirects=True) + resp.raise_for_status() + + # Try to extract filename from response headers or URL + content_disposition = resp.headers.get("Content-Disposition") + if content_disposition: + filename = content_disposition.split("filename=")[-1].strip('"') + else: + filename = url.split("/")[-1].split("?")[0] + # If filename is empty, set a default one + if not filename: + filename = "unknown_file" + + # Determine file extension + extension = "." + filename.split(".")[-1] if "." 
in filename else ".bin" + + # Create the File object + file_size = int(resp.headers.get("Content-Length", -1)) + mime_type = str(resp.headers.get("Content-Type", "")) + if not mime_type: + mime_type, _ = mimetypes.guess_type(url) + file = File( + id=mapping.get("id"), + filename=filename, + tenant_id=tenant_id, + type=FileType.value_of(mapping.get("type")), + transfer_method=transfer_method, + remote_url=url, + _extra_config=config, + mime_type=mime_type, + extension=extension, + size=file_size, + ) + return file + + +def _build_from_tool_file( + *, + mapping: Mapping[str, Any], + tenant_id: str, + user_id: str, + config: FileExtraConfig, + transfer_method: FileTransferMethod, +): + tool_file = ( + db.session.query(ToolFile) + .filter( + ToolFile.id == mapping.get("tool_file_id"), + ToolFile.tenant_id == tenant_id, + ToolFile.user_id == user_id, + ) + .first() + ) + if tool_file is None: + raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found") + + path = tool_file.file_key + if "." in path: + extension = "." + path.split("/")[-1].split(".")[-1] + else: + extension = ".bin" + file = File( + id=mapping.get("id"), + tenant_id=tenant_id, + filename=tool_file.name, + type=FileType.value_of(mapping.get("type")), + transfer_method=transfer_method, + remote_url=tool_file.original_url, + related_id=tool_file.id, + extension=extension, + mime_type=tool_file.mimetype, + size=tool_file.size, + _extra_config=config, + ) + return file diff --git a/api/core/app/segments/factory.py b/api/factories/variable_factory.py similarity index 73% rename from api/core/app/segments/factory.py rename to api/factories/variable_factory.py index 40a69ed4eb..a758f9981f 100644 --- a/api/core/app/segments/factory.py +++ b/api/factories/variable_factory.py @@ -2,29 +2,32 @@ from collections.abc import Mapping from typing import Any from configs import dify_config - -from .exc import VariableError -from .segments import ( +from core.file import File +from core.variables import ( ArrayAnySegment, + ArrayFileSegment, + ArrayNumberSegment, + ArrayNumberVariable, + ArrayObjectSegment, + ArrayObjectVariable, + ArrayStringSegment, + ArrayStringVariable, + FileSegment, FloatSegment, + FloatVariable, IntegerSegment, + IntegerVariable, NoneSegment, ObjectSegment, - Segment, - StringSegment, -) -from .types import SegmentType -from .variables import ( - ArrayNumberVariable, - ArrayObjectVariable, - ArrayStringVariable, - FloatVariable, - IntegerVariable, ObjectVariable, SecretVariable, + Segment, + SegmentType, + StringSegment, StringVariable, Variable, ) +from core.variables.exc import VariableError def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: @@ -71,6 +74,22 @@ def build_segment(value: Any, /) -> Segment: return FloatSegment(value=value) if isinstance(value, dict): return ObjectSegment(value=value) + if isinstance(value, File): + return FileSegment(value=value) if isinstance(value, list): - return ArrayAnySegment(value=value) + items = [build_segment(item) for item in value] + types = {item.value_type for item in items} + if len(types) != 1: + return ArrayAnySegment(value=value) + match types.pop(): + case SegmentType.STRING: + return ArrayStringSegment(value=value) + case SegmentType.NUMBER: + return ArrayNumberSegment(value=value) + case SegmentType.OBJECT: + return ArrayObjectSegment(value=value) + case SegmentType.FILE: + return ArrayFileSegment(value=value) + case _: + raise ValueError(f"not supported value {value}") raise ValueError(f"not supported value {value}") diff --git 
a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index 3dcd88d1de..bf1c491a05 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -3,6 +3,8 @@ from flask_restful import fields from fields.member_fields import simple_account_fields from libs.helper import TimestampField +from .raws import FilesContainedField + class MessageTextField(fields.Raw): def format(self, value): @@ -33,8 +35,12 @@ annotation_hit_history_fields = { message_file_fields = { "id": fields.String, + "filename": fields.String, "type": fields.String, "url": fields.String, + "mime_type": fields.String, + "size": fields.Integer, + "transfer_method": fields.String, "belongs_to": fields.String(default="user"), } @@ -55,7 +61,7 @@ agent_thought_fields = { message_detail_fields = { "id": fields.String, "conversation_id": fields.String, - "inputs": fields.Raw, + "inputs": FilesContainedField, "query": fields.String, "message": fields.Raw, "message_tokens": fields.Integer, @@ -71,7 +77,7 @@ message_detail_fields = { "annotation_hit_history": fields.Nested(annotation_hit_history_fields, allow_null=True), "created_at": TimestampField, "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)), - "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"), + "message_files": fields.List(fields.Nested(message_file_fields)), "metadata": fields.Raw(attribute="message_metadata_dict"), "status": fields.String, "error": fields.String, @@ -99,7 +105,7 @@ simple_model_config_fields = { } simple_message_detail_fields = { - "inputs": fields.Raw, + "inputs": FilesContainedField, "query": fields.String, "message": MessageTextField, "answer": fields.String, @@ -187,7 +193,7 @@ conversation_detail_fields = { simple_conversation_fields = { "id": fields.String, "name": fields.String, - "inputs": fields.Raw, + "inputs": FilesContainedField, "status": fields.String, "introduction": fields.String, "created_at": TimestampField, diff --git a/api/fields/file_fields.py b/api/fields/file_fields.py index e5a03ce77e..4ce7644e9d 100644 --- a/api/fields/file_fields.py +++ b/api/fields/file_fields.py @@ -17,3 +17,8 @@ file_fields = { "created_by": fields.String, "created_at": TimestampField, } + +remote_file_info_fields = { + "file_type": fields.String(attribute="file_type"), + "file_length": fields.Integer(attribute="file_length"), +} diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py index c938097131..5f6e7884a6 100644 --- a/api/fields/message_fields.py +++ b/api/fields/message_fields.py @@ -3,6 +3,8 @@ from flask_restful import fields from fields.conversation_fields import message_file_fields from libs.helper import TimestampField +from .raws import FilesContainedField + feedback_fields = {"rating": fields.String} retriever_resource_fields = { @@ -63,14 +65,14 @@ message_fields = { "id": fields.String, "conversation_id": fields.String, "parent_message_id": fields.String, - "inputs": fields.Raw, + "inputs": FilesContainedField, "query": fields.String, "answer": fields.String(attribute="re_sign_file_url_answer"), "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), "retriever_resources": fields.List(fields.Nested(retriever_resource_fields)), "created_at": TimestampField, "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)), - "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"), + "message_files": fields.List(fields.Nested(message_file_fields)), "status": fields.String, 
"error": fields.String, } diff --git a/api/fields/raws.py b/api/fields/raws.py new file mode 100644 index 0000000000..15ec16ab13 --- /dev/null +++ b/api/fields/raws.py @@ -0,0 +1,17 @@ +from flask_restful import fields + +from core.file import File + + +class FilesContainedField(fields.Raw): + def format(self, value): + return self._format_file_object(value) + + def _format_file_object(self, v): + if isinstance(v, File): + return v.model_dump() + if isinstance(v, dict): + return {k: self._format_file_object(vv) for k, vv in v.items()} + if isinstance(v, list): + return [self._format_file_object(vv) for vv in v] + return v diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index 2adef63ada..0d860d6f40 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -1,7 +1,7 @@ from flask_restful import fields -from core.app.segments import SecretVariable, SegmentType, Variable from core.helper import encrypter +from core.variables import SecretVariable, SegmentType, Variable from fields.member_fields import simple_account_fields from libs.helper import TimestampField diff --git a/api/libs/helper.py b/api/libs/helper.py index e22469e964..48ef1bdc48 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -16,7 +16,7 @@ from flask import Response, current_app, stream_with_context from flask_restful import fields from core.app.features.rate_limiting.rate_limit import RateLimitGenerator -from core.file.upload_file_parser import UploadFileParser +from core.file import helpers as file_helpers from extensions.ext_redis import redis_client from models.account import Account @@ -33,7 +33,7 @@ class AppIconUrlField(fields.Raw): from models.model import IconType if obj.icon_type == IconType.IMAGE.value: - return UploadFileParser.get_signed_temp_image_url(obj.icon) + return file_helpers.get_signed_file_url(obj.icon) return None @@ -189,23 +189,39 @@ def compact_generate_response(response: Union[dict, Generator, RateLimitGenerato class TokenManager: @classmethod - def generate_token(cls, account: Account, token_type: str, additional_data: Optional[dict] = None) -> str: - old_token = cls._get_current_token_for_account(account.id, token_type) - if old_token: - if isinstance(old_token, bytes): - old_token = old_token.decode("utf-8") - cls.revoke_token(old_token, token_type) + def generate_token( + cls, + token_type: str, + account: Optional[Account] = None, + email: Optional[str] = None, + additional_data: Optional[dict] = None, + ) -> str: + if account is None and email is None: + raise ValueError("Account or email must be provided") + + account_id = account.id if account else None + account_email = account.email if account else email + + if account_id: + old_token = cls._get_current_token_for_account(account_id, token_type) + if old_token: + if isinstance(old_token, bytes): + old_token = old_token.decode("utf-8") + cls.revoke_token(old_token, token_type) token = str(uuid.uuid4()) - token_data = {"account_id": account.id, "email": account.email, "token_type": token_type} + token_data = {"account_id": account_id, "email": account_email, "token_type": token_type} if additional_data: token_data.update(additional_data) - expiry_hours = current_app.config[f"{token_type.upper()}_TOKEN_EXPIRY_HOURS"] + expiry_minutes = current_app.config[f"{token_type.upper()}_TOKEN_EXPIRY_MINUTES"] token_key = cls._get_token_key(token, token_type) - redis_client.setex(token_key, expiry_hours * 60 * 60, json.dumps(token_data)) + expiry_time = int(expiry_minutes * 60) + 
redis_client.setex(token_key, expiry_time, json.dumps(token_data)) + + if account_id: + cls._set_current_token_for_account(account_id, token, token_type, expiry_minutes) - cls._set_current_token_for_account(account.id, token, token_type, expiry_hours) return token @classmethod @@ -234,9 +250,12 @@ class TokenManager: return current_token @classmethod - def _set_current_token_for_account(cls, account_id: str, token: str, token_type: str, expiry_hours: int): + def _set_current_token_for_account( + cls, account_id: str, token: str, token_type: str, expiry_minutes: Union[int, float] + ): key = cls._get_account_token_key(account_id, token_type) - redis_client.setex(key, expiry_hours * 60 * 60, token) + expiry_time = int(expiry_minutes * 60) + redis_client.setex(key, expiry_time, token) @classmethod def _get_account_token_key(cls, account_id: str, token_type: str) -> str: diff --git a/api/libs/oauth.py b/api/libs/oauth.py index d8ce1a1e66..6b6919de24 100644 --- a/api/libs/oauth.py +++ b/api/libs/oauth.py @@ -1,5 +1,6 @@ import urllib.parse from dataclasses import dataclass +from typing import Optional import requests @@ -40,12 +41,14 @@ class GitHubOAuth(OAuth): _USER_INFO_URL = "https://api.github.com/user" _EMAIL_INFO_URL = "https://api.github.com/user/emails" - def get_authorization_url(self): + def get_authorization_url(self, invite_token: Optional[str] = None): params = { "client_id": self.client_id, "redirect_uri": self.redirect_uri, "scope": "user:email", # Request only basic user information } + if invite_token: + params["state"] = invite_token return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}" def get_access_token(self, code: str): @@ -90,13 +93,15 @@ class GoogleOAuth(OAuth): _TOKEN_URL = "https://oauth2.googleapis.com/token" _USER_INFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo" - def get_authorization_url(self): + def get_authorization_url(self, invite_token: Optional[str] = None): params = { "client_id": self.client_id, "response_type": "code", "redirect_uri": self.redirect_uri, "scope": "openid email", } + if invite_token: + params["state"] = invite_token return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}" def get_access_token(self, code: str): diff --git a/api/libs/password.py b/api/libs/password.py index cfcc0db22d..cdf55c57e5 100644 --- a/api/libs/password.py +++ b/api/libs/password.py @@ -13,7 +13,7 @@ def valid_password(password): if re.match(pattern, password) is not None: return password - raise ValueError("Not a valid password.") + raise ValueError("Password must contain letters and numbers, and the length must be greater than 8.") def hash_password(password_str, salt_byte): diff --git a/api/migrations/versions/04c602f5dc9b_update_appmodelconfig_and_add_table_.py b/api/migrations/versions/04c602f5dc9b_update_appmodelconfig_and_add_table_.py index be2c615525..6a7402b16a 100644 --- a/api/migrations/versions/04c602f5dc9b_update_appmodelconfig_and_add_table_.py +++ b/api/migrations/versions/04c602f5dc9b_update_appmodelconfig_and_add_table_.py @@ -8,7 +8,7 @@ Create Date: 2024-06-12 07:49:07.666510 import sqlalchemy as sa from alembic import op -import models as models +import models.types # revision identifiers, used by Alembic. revision = '04c602f5dc9b' @@ -20,8 +20,8 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust!
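Token expiry is now configured in minutes, and both redis entries (the token record and the per-account pointer) take their TTL from the same minutes-to-seconds conversion. A small sketch of that arithmetic; the helper name is illustrative:

def token_ttl_seconds(expiry_minutes: int | float) -> int:
    # Same conversion the token paths apply before redis_client.setex.
    return int(expiry_minutes * 60)

assert token_ttl_seconds(5) == 300
assert token_ttl_seconds(0.5) == 30  # fractional minutes truncate to whole seconds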
### op.create_table('tracing_app_configs', - sa.Column('id', models.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), - sa.Column('app_id', models.StringUUID(), nullable=False), + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('app_id', models.types.StringUUID(), nullable=False), sa.Column('tracing_provider', sa.String(length=255), nullable=True), sa.Column('tracing_config', sa.JSON(), nullable=True), sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False), diff --git a/api/migrations/versions/2024_10_10_0516-bbadea11becb_add_name_and_size_to_tool_files.py b/api/migrations/versions/2024_10_10_0516-bbadea11becb_add_name_and_size_to_tool_files.py new file mode 100644 index 0000000000..c17d1db77a --- /dev/null +++ b/api/migrations/versions/2024_10_10_0516-bbadea11becb_add_name_and_size_to_tool_files.py @@ -0,0 +1,49 @@ +"""add name and size to tool_files + +Revision ID: bbadea11becb +Revises: 33f5fac87f29 +Create Date: 2024-10-10 05:16:14.764268 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'bbadea11becb' +down_revision = 'd8e744d88ed6' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + # Get the database connection + conn = op.get_bind() + + # Use SQLAlchemy inspector to get the columns of the 'tool_files' table + inspector = sa.inspect(conn) + columns = [col['name'] for col in inspector.get_columns('tool_files')] + + # If 'name' or 'size' columns already exist, exit the upgrade function + if 'name' in columns or 'size' in columns: + return + + with op.batch_alter_table('tool_files', schema=None) as batch_op: + batch_op.add_column(sa.Column('name', sa.String(), nullable=True)) + batch_op.add_column(sa.Column('size', sa.Integer(), nullable=True)) + op.execute("UPDATE tool_files SET name = '' WHERE name IS NULL") + op.execute("UPDATE tool_files SET size = -1 WHERE size IS NULL") + with op.batch_alter_table('tool_files', schema=None) as batch_op: + batch_op.alter_column('name', existing_type=sa.String(), nullable=False) + batch_op.alter_column('size', existing_type=sa.Integer(), nullable=False) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_files', schema=None) as batch_op: + batch_op.drop_column('size') + batch_op.drop_column('name') + # ### end Alembic commands ### diff --git a/api/migrations/versions/3b18fea55204_add_tool_label_bings.py b/api/migrations/versions/3b18fea55204_add_tool_label_bings.py index db3119badf..bf54c247ea 100644 --- a/api/migrations/versions/3b18fea55204_add_tool_label_bings.py +++ b/api/migrations/versions/3b18fea55204_add_tool_label_bings.py @@ -8,7 +8,7 @@ Create Date: 2024-05-14 09:27:18.857890 import sqlalchemy as sa from alembic import op -import models as models +import models.types # revision identifiers, used by Alembic. revision = '3b18fea55204' @@ -20,7 +20,7 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! 
### op.create_table('tool_label_bindings', - sa.Column('id', models.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), sa.Column('tool_id', sa.String(length=64), nullable=False), sa.Column('tool_type', sa.String(length=40), nullable=False), sa.Column('label_name', sa.String(length=40), nullable=False), diff --git a/api/migrations/versions/4e99a8df00ff_add_load_balancing.py b/api/migrations/versions/4e99a8df00ff_add_load_balancing.py index 67d7b9fbf5..3be4ba4f2a 100644 --- a/api/migrations/versions/4e99a8df00ff_add_load_balancing.py +++ b/api/migrations/versions/4e99a8df00ff_add_load_balancing.py @@ -8,7 +8,7 @@ Create Date: 2024-05-10 12:08:09.812736 import sqlalchemy as sa from alembic import op -import models as models +import models.types # revision identifiers, used by Alembic. revision = '4e99a8df00ff' @@ -20,8 +20,8 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! ### op.create_table('load_balancing_model_configs', - sa.Column('id', models.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), - sa.Column('tenant_id', models.StringUUID(), nullable=False), + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), sa.Column('provider_name', sa.String(length=255), nullable=False), sa.Column('model_name', sa.String(length=255), nullable=False), sa.Column('model_type', sa.String(length=40), nullable=False), @@ -36,8 +36,8 @@ def upgrade(): batch_op.create_index('load_balancing_model_config_tenant_provider_model_idx', ['tenant_id', 'provider_name', 'model_type'], unique=False) op.create_table('provider_model_settings', - sa.Column('id', models.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), - sa.Column('tenant_id', models.StringUUID(), nullable=False), + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), sa.Column('provider_name', sa.String(length=255), nullable=False), sa.Column('model_name', sa.String(length=255), nullable=False), sa.Column('model_type', sa.String(length=40), nullable=False), diff --git a/api/migrations/versions/7b45942e39bb_add_api_key_auth_binding.py b/api/migrations/versions/7b45942e39bb_add_api_key_auth_binding.py index f63bad9345..2ba0e13caa 100644 --- a/api/migrations/versions/7b45942e39bb_add_api_key_auth_binding.py +++ b/api/migrations/versions/7b45942e39bb_add_api_key_auth_binding.py @@ -8,7 +8,7 @@ Create Date: 2024-05-14 07:31:29.702766 import sqlalchemy as sa from alembic import op -import models as models +import models.types # revision identifiers, used by Alembic. revision = '7b45942e39bb' @@ -20,8 +20,8 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! 
### op.create_table('data_source_api_key_auth_bindings', - sa.Column('id', models.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), - sa.Column('tenant_id', models.StringUUID(), nullable=False), + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), sa.Column('category', sa.String(length=255), nullable=False), sa.Column('provider', sa.String(length=255), nullable=False), sa.Column('credentials', sa.Text(), nullable=True), diff --git a/api/migrations/versions/7bdef072e63a_add_workflow_tool.py b/api/migrations/versions/7bdef072e63a_add_workflow_tool.py index 67b61e5c76..f09a682f28 100644 --- a/api/migrations/versions/7bdef072e63a_add_workflow_tool.py +++ b/api/migrations/versions/7bdef072e63a_add_workflow_tool.py @@ -1,6 +1,6 @@ """add workflow tool -Revision ID: 7bdef072e63a +Revision ID: 7bdef072e63a Revises: 5fda94355fce Create Date: 2024-05-04 09:47:19.366961 @@ -8,7 +8,7 @@ Create Date: 2024-05-04 09:47:19.366961 import sqlalchemy as sa from alembic import op -import models as models +import models.types # revision identifiers, used by Alembic. revision = '7bdef072e63a' @@ -20,12 +20,12 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! ### op.create_table('tool_workflow_providers', - sa.Column('id', models.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), sa.Column('name', sa.String(length=40), nullable=False), sa.Column('icon', sa.String(length=255), nullable=False), - sa.Column('app_id', models.StringUUID(), nullable=False), - sa.Column('user_id', models.StringUUID(), nullable=False), - sa.Column('tenant_id', models.StringUUID(), nullable=False), + sa.Column('app_id', models.types.StringUUID(), nullable=False), + sa.Column('user_id', models.types.StringUUID(), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), sa.Column('description', sa.Text(), nullable=False), sa.Column('parameter_configuration', sa.Text(), server_default='[]', nullable=False), sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), diff --git a/api/migrations/versions/7e6a8693e07a_add_table_dataset_permissions.py b/api/migrations/versions/7e6a8693e07a_add_table_dataset_permissions.py index ff53eb65a6..865572f3a7 100644 --- a/api/migrations/versions/7e6a8693e07a_add_table_dataset_permissions.py +++ b/api/migrations/versions/7e6a8693e07a_add_table_dataset_permissions.py @@ -8,7 +8,7 @@ Create Date: 2024-06-25 03:20:46.012193 import sqlalchemy as sa from alembic import op -import models as models +import models.types # revision identifiers, used by Alembic. revision = '7e6a8693e07a' @@ -20,9 +20,9 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! 
### op.create_table('dataset_permissions', - sa.Column('id', models.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), - sa.Column('dataset_id', models.StringUUID(), nullable=False), - sa.Column('account_id', models.StringUUID(), nullable=False), + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('dataset_id', models.types.StringUUID(), nullable=False), + sa.Column('account_id', models.types.StringUUID(), nullable=False), sa.Column('has_permission', sa.Boolean(), server_default=sa.text('true'), nullable=False), sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), sa.PrimaryKeyConstraint('id', name='dataset_permission_pkey') diff --git a/api/migrations/versions/c031d46af369_remove_app_model_config_trace_config_.py b/api/migrations/versions/c031d46af369_remove_app_model_config_trace_config_.py index 1ac44d083a..469c04338a 100644 --- a/api/migrations/versions/c031d46af369_remove_app_model_config_trace_config_.py +++ b/api/migrations/versions/c031d46af369_remove_app_model_config_trace_config_.py @@ -9,7 +9,7 @@ import sqlalchemy as sa from alembic import op from sqlalchemy.dialects import postgresql -import models as models +import models.types # revision identifiers, used by Alembic. revision = 'c031d46af369' @@ -21,8 +21,8 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! ### op.create_table('trace_app_config', - sa.Column('id', models.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), - sa.Column('app_id', models.StringUUID(), nullable=False), + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('app_id', models.types.StringUUID(), nullable=False), sa.Column('tracing_provider', sa.String(length=255), nullable=True), sa.Column('tracing_config', sa.JSON(), nullable=True), sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False), diff --git a/api/models/__init__.py b/api/models/__init__.py index 30ceef057e..1d8bae6cfa 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -1,29 +1,55 @@ -from enum import Enum +from .account import Account, AccountIntegrate, InvitationCode, Tenant +from .dataset import Dataset, DatasetProcessRule, Document, DocumentSegment +from .model import ( + ApiToken, + App, + AppMode, + Conversation, + EndUser, + FileUploadConfig, + InstalledApp, + Message, + MessageAnnotation, + MessageFile, + RecommendedApp, + Site, + UploadFile, +) +from .source import DataSourceOauthBinding +from .tools import ToolFile +from .workflow import ( + ConversationVariable, + Workflow, + WorkflowAppLog, + WorkflowRun, +) -from .model import App, AppMode, Message -from .types import StringUUID -from .workflow import ConversationVariable, Workflow, WorkflowNodeExecutionStatus - -__all__ = ["ConversationVariable", "StringUUID", "AppMode", "WorkflowNodeExecutionStatus", "Workflow", "App", "Message"] - - -class CreatedByRole(Enum): - """ - Enum class for createdByRole - """ - - ACCOUNT = "account" - END_USER = "end_user" - - @classmethod - def value_of(cls, value: str) -> "CreatedByRole": - """ - Get value of given mode. 
- - :param value: mode value - :return: mode - """ - for role in cls: - if role.value == value: - return role - raise ValueError(f"invalid createdByRole value {value}") +__all__ = [ + "ConversationVariable", + "Document", + "Dataset", + "DatasetProcessRule", + "DocumentSegment", + "DataSourceOauthBinding", + "AppMode", + "Workflow", + "App", + "Message", + "EndUser", + "MessageFile", + "UploadFile", + "Account", + "WorkflowAppLog", + "WorkflowRun", + "Site", + "InstalledApp", + "RecommendedApp", + "ApiToken", + "AccountIntegrate", + "InvitationCode", + "Tenant", + "Conversation", + "MessageAnnotation", + "FileUploadConfig", + "ToolFile", +] diff --git a/api/models/dataset.py b/api/models/dataset.py index 4224ee5e9c..4e2ccab7e8 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -560,7 +560,7 @@ class DocumentSegment(db.Model): ) def get_sign_content(self): - pattern = r"/files/([a-f0-9\-]+)/image-preview" + pattern = r"/files/([a-f0-9\-]+)/file-preview" text = self.content matches = re.finditer(pattern, text) signed_urls = [] @@ -568,7 +568,7 @@ class DocumentSegment(db.Model): upload_file_id = match.group(1) nonce = os.urandom(16).hex() timestamp = str(int(time.time())) - data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" + data_to_sign = f"file-preview|{upload_file_id}|{timestamp}|{nonce}" secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() encoded_sign = base64.urlsafe_b64encode(sign).decode() diff --git a/api/models/enums.py b/api/models/enums.py new file mode 100644 index 0000000000..a83d35e042 --- /dev/null +++ b/api/models/enums.py @@ -0,0 +1,16 @@ +from enum import Enum + + +class CreatedByRole(str, Enum): + ACCOUNT = "account" + END_USER = "end_user" + + +class UserFrom(str, Enum): + ACCOUNT = "account" + END_USER = "end-user" + + +class WorkflowRunTriggeredFrom(str, Enum): + DEBUGGING = "debugging" + APP_RUN = "app-run" diff --git a/api/models/model.py b/api/models/model.py index cefbb96b8d..12c57ab372 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -1,29 +1,44 @@ import json import re import uuid +from collections.abc import Mapping, Sequence +from datetime import datetime from enum import Enum from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: from models.workflow import Workflow +from typing import Any, Literal + from flask import request from flask_login import UserMixin +from pydantic import BaseModel, Field from sqlalchemy import Float, Index, PrimaryKeyConstraint, func, text from sqlalchemy.orm import Mapped, mapped_column from configs import dify_config +from core.file import FILE_MODEL_IDENTITY, File, FileExtraConfig, FileTransferMethod, FileType +from core.file import helpers as file_helpers from core.file.tool_file_parser import ToolFileParser -from core.file.upload_file_parser import UploadFileParser from extensions.ext_database import db from libs.helper import generate_string from models.base import Base +from models.enums import CreatedByRole from .account import Account, Tenant from .types import StringUUID -class DifySetup(Base): +class FileUploadConfig(BaseModel): + enabled: bool = Field(default=False) + allowed_file_types: Sequence[FileType] = Field(default_factory=list) + allowed_extensions: Sequence[str] = Field(default_factory=list) + allowed_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list) + number_limits: int = Field(default=0, gt=0, le=10) + + +class DifySetup(db.Model): 
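CreatedByRole, UserFrom, and AppMode now mix in str, so members compare equal to the raw strings stored in the database and the old value_of lookup becomes a plain Enum call. A quick standalone illustration of the mixin's behavior:

from enum import Enum

class CreatedByRole(str, Enum):  # same shape as models/enums.py above
    ACCOUNT = "account"
    END_USER = "end_user"

assert CreatedByRole.ACCOUNT == "account"                   # a plain Enum member would not equal the string
assert CreatedByRole("end_user") is CreatedByRole.END_USER  # the Enum call replaces value_of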
__tablename__ = "dify_setups" __table_args__ = (db.PrimaryKeyConstraint("version", name="dify_setup_pkey"),) @@ -31,7 +46,7 @@ class DifySetup(Base): setup_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) -class AppMode(Enum): +class AppMode(str, Enum): COMPLETION = "completion" WORKFLOW = "workflow" CHAT = "chat" @@ -63,7 +78,7 @@ class App(Base): __table_args__ = (db.PrimaryKeyConstraint("id", name="app_pkey"), db.Index("app_tenant_id_idx", "tenant_id")) id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) + tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False) name = db.Column(db.String(255), nullable=False) description = db.Column(db.Text, nullable=False, server_default=db.text("''::character varying")) mode = db.Column(db.String(255), nullable=False) @@ -538,7 +553,7 @@ class Conversation(Base): mode = db.Column(db.String(255), nullable=False) name = db.Column(db.String(255), nullable=False) summary = db.Column(db.Text) - inputs = db.Column(db.JSON) + _inputs: Mapped[dict] = mapped_column("inputs", db.JSON) introduction = db.Column(db.Text) system_instruction = db.Column(db.Text) system_instruction_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) @@ -560,6 +575,28 @@ class Conversation(Base): is_deleted = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + @property + def inputs(self): + inputs = self._inputs.copy() + for key, value in inputs.items(): + if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY: + inputs[key] = File.model_validate(value) + elif isinstance(value, list) and all( + isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY for item in value + ): + inputs[key] = [File.model_validate(item) for item in value] + return inputs + + @inputs.setter + def inputs(self, value: Mapping[str, Any]): + inputs = dict(value) + for k, v in inputs.items(): + if isinstance(v, File): + inputs[k] = v.model_dump() + elif isinstance(v, list) and all(isinstance(item, File) for item in v): + inputs[k] = [item.model_dump() for item in v] + self._inputs = inputs + @property def model_config(self): model_config = {} @@ -711,13 +748,13 @@ class Message(Base): model_id = db.Column(db.String(255), nullable=True) override_model_configs = db.Column(db.Text) conversation_id = db.Column(StringUUID, db.ForeignKey("conversations.id"), nullable=False) - inputs = db.Column(db.JSON) - query = db.Column(db.Text, nullable=False) + _inputs: Mapped[dict] = mapped_column("inputs", db.JSON) + query: Mapped[str] = db.Column(db.Text, nullable=False) message = db.Column(db.JSON, nullable=False) message_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) message_unit_price = db.Column(db.Numeric(10, 4), nullable=False) message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) - answer = db.Column(db.Text, nullable=False) + answer: Mapped[str] = db.Column(db.Text, nullable=False) answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) answer_unit_price = db.Column(db.Numeric(10, 4), nullable=False) answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) @@ -728,15 +765,37 @@ class Message(Base): status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) error = db.Column(db.Text) message_metadata = 
db.Column(db.Text) - invoke_from = db.Column(db.String(255), nullable=True) + invoke_from: Mapped[Optional[str]] = db.Column(db.String(255), nullable=True) from_source = db.Column(db.String(255), nullable=False) - from_end_user_id = db.Column(StringUUID) - from_account_id = db.Column(StringUUID) + from_end_user_id: Mapped[Optional[str]] = db.Column(StringUUID) + from_account_id: Mapped[Optional[str]] = db.Column(StringUUID) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) agent_based = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) workflow_run_id = db.Column(StringUUID) + @property + def inputs(self): + inputs = self._inputs.copy() + for key, value in inputs.items(): + if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY: + inputs[key] = File.model_validate(value) + elif isinstance(value, list) and all( + isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY for item in value + ): + inputs[key] = [File.model_validate(item) for item in value] + return inputs + + @inputs.setter + def inputs(self, value: Mapping[str, Any]): + inputs = dict(value) + for k, v in inputs.items(): + if isinstance(v, File): + inputs[k] = v.model_dump() + elif isinstance(v, list) and all(isinstance(item, File) for item in v): + inputs[k] = [item.model_dump() for item in v] + self._inputs = inputs + @property def re_sign_file_url_answer(self) -> str: if not self.answer: @@ -783,19 +842,29 @@ class Message(Base): sign_url = ToolFileParser.get_tool_file_manager().sign_file( tool_file_id=tool_file_id, extension=extension ) - else: + elif "file-preview" in url: # get upload file id - upload_file_id_pattern = r"\/files\/([\w-]+)\/image-preview?\?timestamp=" + upload_file_id_pattern = r"\/files\/([\w-]+)\/file-preview?\?timestamp=" result = re.search(upload_file_id_pattern, url) if not result: continue upload_file_id = result.group(1) - if not upload_file_id: continue - - sign_url = UploadFileParser.get_signed_temp_image_url(upload_file_id) + sign_url = file_helpers.get_signed_file_url(upload_file_id) + elif "image-preview" in url: + # image-preview is deprecated, use file-preview instead + upload_file_id_pattern = r"\/files\/([\w-]+)\/image-preview?\?timestamp=" + result = re.search(upload_file_id_pattern, url) + if not result: + continue + upload_file_id = result.group(1) + if not upload_file_id: + continue + sign_url = file_helpers.get_signed_file_url(upload_file_id) + else: + continue re_sign_file_url_answer = re_sign_file_url_answer.replace(url, sign_url) @@ -881,50 +950,71 @@ class Message(Base): @property def message_files(self): - return db.session.query(MessageFile).filter(MessageFile.message_id == self.id).all() + from factories import file_factory - @property - def files(self): - message_files = self.message_files + message_files = db.session.query(MessageFile).filter(MessageFile.message_id == self.id).all() + current_app = db.session.query(App).filter(App.id == self.app_id).first() + if not current_app: + raise ValueError(f"App {self.app_id} not found") - files = [] + files: list[File] = [] for message_file in message_files: - url = message_file.url - if message_file.type == "image": - if message_file.transfer_method == "local_file": - upload_file = ( - db.session.query(UploadFile).filter(UploadFile.id == message_file.upload_file_id).first() - ) - - url = 
UploadFileParser.get_image_data(upload_file=upload_file, force_url=True) - if message_file.transfer_method == "tool_file": - # get tool file id - tool_file_id = message_file.url.split("/")[-1] - # trim extension - tool_file_id = tool_file_id.split(".")[0] - - # get extension - if "." in message_file.url: - extension = f'.{message_file.url.split(".")[-1]}' - if len(extension) > 10: - extension = ".bin" - else: - extension = ".bin" - # add sign url - url = ToolFileParser.get_tool_file_manager().sign_file( - tool_file_id=tool_file_id, extension=extension - ) - - files.append( - { + if message_file.transfer_method == "local_file": + if message_file.upload_file_id is None: + raise ValueError(f"MessageFile {message_file.id} is a local file but has no upload_file_id") + file = file_factory.build_from_mapping( + mapping={ + "id": message_file.id, + "upload_file_id": message_file.upload_file_id, + "transfer_method": message_file.transfer_method, + "type": message_file.type, + }, + tenant_id=current_app.tenant_id, + user_id=self.from_account_id or self.from_end_user_id or "", + role=CreatedByRole(message_file.created_by_role), + config=FileExtraConfig(), + ) + elif message_file.transfer_method == "remote_url": + if message_file.url is None: + raise ValueError(f"MessageFile {message_file.id} is a remote url but has no url") + file = file_factory.build_from_mapping( + mapping={ + "id": message_file.id, + "type": message_file.type, + "transfer_method": message_file.transfer_method, + "url": message_file.url, + }, + tenant_id=current_app.tenant_id, + user_id=self.from_account_id or self.from_end_user_id or "", + role=CreatedByRole(message_file.created_by_role), + config=FileExtraConfig(), + ) + elif message_file.transfer_method == "tool_file": + mapping = { "id": message_file.id, "type": message_file.type, - "url": url, - "belongs_to": message_file.belongs_to or "user", + "transfer_method": message_file.transfer_method, + "tool_file_id": message_file.upload_file_id, } - ) + file = file_factory.build_from_mapping( + mapping=mapping, + tenant_id=current_app.tenant_id, + user_id=self.from_account_id or self.from_end_user_id or "", + role=CreatedByRole(message_file.created_by_role), + config=FileExtraConfig(), + ) + else: + raise ValueError( + f"MessageFile {message_file.id} has an invalid transfer_method {message_file.transfer_method}" + ) + files.append(file) - return files + result = [ + {"belongs_to": message_file.belongs_to, **file.to_dict()} + for (file, message_file) in zip(files, message_files) + ] + + return result @property def workflow_run(self): @@ -1014,16 +1104,39 @@ class MessageFile(Base): db.Index("message_file_created_by_idx", "created_by"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - message_id = db.Column(StringUUID, nullable=False) - type = db.Column(db.String(255), nullable=False) - transfer_method = db.Column(db.String(255), nullable=False) - url = db.Column(db.Text, nullable=True) - belongs_to = db.Column(db.String(255), nullable=True) - upload_file_id = db.Column(StringUUID, nullable=True) - created_by_role = db.Column(db.String(255), nullable=False) - created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + def __init__( + self, + *, + message_id: str, + type: FileType, + transfer_method: FileTransferMethod, + url: str | None = None, + belongs_to: Literal["user", "assistant"] | None = None, + upload_file_id: str | None = None, + created_by_role: 
CreatedByRole, + created_by: str, + ): + self.message_id = message_id + self.type = type + self.transfer_method = transfer_method + self.url = url + self.belongs_to = belongs_to + self.upload_file_id = upload_file_id + self.created_by_role = created_by_role + self.created_by = created_by + + id: Mapped[str] = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + message_id: Mapped[str] = db.Column(StringUUID, nullable=False) + type: Mapped[str] = db.Column(db.String(255), nullable=False) + transfer_method: Mapped[str] = db.Column(db.String(255), nullable=False) + url: Mapped[Optional[str]] = db.Column(db.Text, nullable=True) + belongs_to: Mapped[Optional[str]] = db.Column(db.String(255), nullable=True) + upload_file_id: Mapped[Optional[str]] = db.Column(StringUUID, nullable=True) + created_by_role: Mapped[str] = db.Column(db.String(255), nullable=False) + created_by: Mapped[str] = db.Column(StringUUID, nullable=False) + created_at: Mapped[datetime] = db.Column( + db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + ) class MessageAnnotation(Base): @@ -1261,21 +1374,58 @@ class UploadFile(Base): db.Index("upload_file_tenant_idx", "tenant_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - storage_type = db.Column(db.String(255), nullable=False) - key = db.Column(db.String(255), nullable=False) - name = db.Column(db.String(255), nullable=False) - size = db.Column(db.Integer, nullable=False) - extension = db.Column(db.String(255), nullable=False) - mime_type = db.Column(db.String(255), nullable=True) - created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'account'::character varying")) - created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - used = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - used_by = db.Column(StringUUID, nullable=True) - used_at = db.Column(db.DateTime, nullable=True) - hash = db.Column(db.String(255), nullable=True) + id: Mapped[str] = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False) + storage_type: Mapped[str] = db.Column(db.String(255), nullable=False) + key: Mapped[str] = db.Column(db.String(255), nullable=False) + name: Mapped[str] = db.Column(db.String(255), nullable=False) + size: Mapped[int] = db.Column(db.Integer, nullable=False) + extension: Mapped[str] = db.Column(db.String(255), nullable=False) + mime_type: Mapped[str] = db.Column(db.String(255), nullable=True) + created_by_role: Mapped[str] = db.Column( + db.String(255), nullable=False, server_default=db.text("'account'::character varying") + ) + created_by: Mapped[str] = db.Column(StringUUID, nullable=False) + created_at: Mapped[datetime] = db.Column( + db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + ) + used: Mapped[bool] = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + used_by: Mapped[str | None] = db.Column(StringUUID, nullable=True) + used_at: Mapped[datetime | None] = db.Column(db.DateTime, nullable=True) + hash: Mapped[str | None] = db.Column(db.String(255), nullable=True) + + def __init__( + self, + *, + tenant_id: str, + storage_type: str, + key: str, + name: str, + size: int, + extension: str, + mime_type: str, + created_by_role: str, + created_by: str, + created_at: datetime, + 
used: bool, + used_by: str | None = None, + used_at: datetime | None = None, + hash: str | None = None, + ) -> None: + self.tenant_id = tenant_id + self.storage_type = storage_type + self.key = key + self.name = name + self.size = size + self.extension = extension + self.mime_type = mime_type + self.created_by_role = created_by_role + self.created_by = created_by + self.created_at = created_at + self.used = used + self.used_by = used_by + self.used_at = used_at + self.hash = hash class ApiRequest(Base): diff --git a/api/models/provider.py b/api/models/provider.py index d3c6db9bab..58c1978573 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -18,24 +18,6 @@ class ProviderType(Enum): raise ValueError(f"No matching enum found for value '{value}'") -class ProviderQuotaType(Enum): - PAID = "paid" - """hosted paid quota""" - - FREE = "free" - """third-party free quota""" - - TRIAL = "trial" - """hosted trial quota""" - - @staticmethod - def value_of(value): - for member in ProviderQuotaType: - if member.value == value: - return member - raise ValueError(f"No matching enum found for value '{value}'") - - class Provider(Base): """ Provider model representing the API providers and their configurations. diff --git a/api/models/tools.py b/api/models/tools.py index b0d4ea3399..8712f37946 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -1,5 +1,6 @@ import json from datetime import datetime +from typing import Optional from deprecated import deprecated from sqlalchemy import ForeignKey @@ -67,7 +68,7 @@ class ApiToolProvider(Base): icon = db.Column(db.String(255), nullable=False) # original schema schema = db.Column(db.Text, nullable=False) - schema_type_str = db.Column(db.String(40), nullable=False) + schema_type_str: Mapped[str] = db.Column(db.String(40), nullable=False) # who created this tool user_id = db.Column(StringUUID, nullable=False) # tenant id @@ -168,6 +169,10 @@ class WorkflowToolProvider(Base): db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") ) + @property + def schema_type(self) -> ApiProviderSchemaType: + return ApiProviderSchemaType.value_of(self.schema_type_str) + @property def user(self) -> Account | None: return db.session.query(Account).filter(Account.id == self.user_id).first() @@ -262,7 +267,6 @@ class ToolFile(Base): __tablename__ = "tool_files" __table_args__ = ( db.PrimaryKeyConstraint("id", name="tool_file_pkey"), - # add index for conversation_id db.Index("tool_file_conversation_id_idx", "conversation_id"), ) @@ -322,3 +326,34 @@ class DeprecatedPublishedAppTool(Base): @property def app(self) -> App: return db.session.query(App).filter(App.id == self.app_id).first() + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + user_id: Mapped[str] = db.Column(StringUUID, nullable=False) + tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False) + conversation_id: Mapped[Optional[str]] = db.Column(StringUUID, nullable=True) + file_key: Mapped[str] = db.Column(db.String(255), nullable=False) + mimetype: Mapped[str] = db.Column(db.String(255), nullable=False) + original_url: Mapped[Optional[str]] = db.Column(db.String(2048), nullable=True) + name: Mapped[str] = mapped_column(default="") + size: Mapped[int] = mapped_column(default=-1) + + def __init__( + self, + *, + user_id: str, + tenant_id: str, + conversation_id: Optional[str] = None, + file_key: str, + mimetype: str, + original_url: Optional[str] = None, + name: str, + size: int, + ): + self.user_id = user_id + self.tenant_id = tenant_id + 
self.conversation_id = conversation_id + self.file_key = file_key + self.mimetype = mimetype + self.original_url = original_url + self.name = name + self.size = size diff --git a/api/models/workflow.py b/api/models/workflow.py index 0b7d255954..da3152ec75 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -12,38 +12,18 @@ from sqlalchemy.orm import Mapped, mapped_column import contexts from constants import HIDDEN_VALUE -from core.app.segments import SecretVariable, Variable, factory from core.helper import encrypter +from core.variables import SecretVariable, Variable from extensions.ext_database import db +from factories import variable_factory from libs import helper from models.base import Base +from models.enums import CreatedByRole from .account import Account from .types import StringUUID -class CreatedByRole(Enum): - """ - Created By Role Enum - """ - - ACCOUNT = "account" - END_USER = "end_user" - - @classmethod - def value_of(cls, value: str) -> "CreatedByRole": - """ - Get value of given mode. - - :param value: mode value - :return: mode - """ - for mode in cls: - if mode.value == value: - return mode - raise ValueError(f"invalid created by role value {value}") - - class WorkflowType(Enum): """ Workflow Type Enum @@ -118,23 +98,23 @@ class Workflow(Base): db.Index("workflow_version_idx", "tenant_id", "app_id", "version"), ) - id: Mapped[str] = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False) - app_id: Mapped[str] = db.Column(StringUUID, nullable=False) - type: Mapped[str] = db.Column(db.String(255), nullable=False) - version: Mapped[str] = db.Column(db.String(255), nullable=False) - graph: Mapped[str] = db.Column(db.Text) - features: Mapped[str] = db.Column(db.Text) - created_by: Mapped[str] = db.Column(StringUUID, nullable=False) - created_at: Mapped[datetime] = db.Column( + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + type: Mapped[str] = mapped_column(db.String(255), nullable=False) + version: Mapped[str] = mapped_column(db.String(255), nullable=False) + graph: Mapped[str] = mapped_column(db.Text) + _features: Mapped[str] = mapped_column("features") + created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) + created_at: Mapped[datetime] = mapped_column( db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") ) - updated_by: Mapped[str] = db.Column(StringUUID) - updated_at: Mapped[datetime] = db.Column(db.DateTime) - _environment_variables: Mapped[str] = db.Column( + updated_by: Mapped[str] = mapped_column(StringUUID) + updated_at: Mapped[datetime] = mapped_column(db.DateTime) + _environment_variables: Mapped[str] = mapped_column( "environment_variables", db.Text, nullable=False, server_default="{}" ) - _conversation_variables: Mapped[str] = db.Column( + _conversation_variables: Mapped[str] = mapped_column( "conversation_variables", db.Text, nullable=False, server_default="{}" ) @@ -173,6 +153,34 @@ class Workflow(Base): def graph_dict(self) -> Mapping[str, Any]: return json.loads(self.graph) if self.graph else {} + @property + def features(self) -> str: + """ + Convert old features structure to new features structure. 
+ """ + if not self._features: + return self._features + + features = json.loads(self._features) + if features.get("file_upload", {}).get("image", {}).get("enabled", False): + image_enabled = True + image_number_limits = int(features["file_upload"]["image"].get("number_limits", 1)) + image_transfer_methods = features["file_upload"]["image"].get( + "transfer_methods", ["remote_url", "local_file"] + ) + features["file_upload"]["enabled"] = image_enabled + features["file_upload"]["number_limits"] = image_number_limits + features["file_upload"]["allowed_upload_methods"] = image_transfer_methods + features["file_upload"]["allowed_file_types"] = ["image"] + features["file_upload"]["allowed_extensions"] = [] + del features["file_upload"]["image"] + self._features = json.dumps(features) + return self._features + + @features.setter + def features(self, value: str) -> None: + self._features = value + @property def features_dict(self) -> Mapping[str, Any]: return json.loads(self.features) if self.features else {} @@ -231,7 +239,7 @@ class Workflow(Base): tenant_id = contexts.tenant_id.get() environment_variables_dict: dict[str, Any] = json.loads(self._environment_variables) - results = [factory.build_variable_from_mapping(v) for v in environment_variables_dict.values()] + results = [variable_factory.build_variable_from_mapping(v) for v in environment_variables_dict.values()] # decrypt secret variables value decrypt_func = ( @@ -244,6 +252,10 @@ class Workflow(Base): @environment_variables.setter def environment_variables(self, value: Sequence[Variable]): + if not value: + self._environment_variables = "{}" + return + tenant_id = contexts.tenant_id.get() value = list(value) @@ -292,7 +304,7 @@ class Workflow(Base): self._conversation_variables = "{}" variables_dict: dict[str, Any] = json.loads(self._conversation_variables) - results = [factory.build_variable_from_mapping(v) for v in variables_dict.values()] + results = [variable_factory.build_variable_from_mapping(v) for v in variables_dict.values()] return results @conversation_variables.setter @@ -303,28 +315,6 @@ class Workflow(Base): ) -class WorkflowRunTriggeredFrom(Enum): - """ - Workflow Run Triggered From Enum - """ - - DEBUGGING = "debugging" - APP_RUN = "app-run" - - @classmethod - def value_of(cls, value: str) -> "WorkflowRunTriggeredFrom": - """ - Get value of given mode. 
- - :param value: mode value - :return: mode - """ - for mode in cls: - if mode.value == value: - return mode - raise ValueError(f"invalid workflow run triggered from value {value}") - - class WorkflowRunStatus(Enum): """ Workflow Run Status Enum @@ -405,7 +395,7 @@ class WorkflowRun(Base): graph = db.Column(db.Text) inputs = db.Column(db.Text) status = db.Column(db.String(255), nullable=False) - outputs = db.Column(db.Text) + outputs: Mapped[str] = db.Column(db.Text) error = db.Column(db.Text) elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0")) total_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) @@ -417,27 +407,27 @@ class WorkflowRun(Base): @property def created_by_account(self): - created_by_role = CreatedByRole.value_of(self.created_by_role) + created_by_role = CreatedByRole(self.created_by_role) return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None @property def created_by_end_user(self): from models.model import EndUser - created_by_role = CreatedByRole.value_of(self.created_by_role) + created_by_role = CreatedByRole(self.created_by_role) return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None @property def graph_dict(self): - return json.loads(self.graph) if self.graph else None + return json.loads(self.graph) if self.graph else {} @property - def inputs_dict(self): - return json.loads(self.inputs) if self.inputs else None + def inputs_dict(self) -> Mapping[str, Any]: + return json.loads(self.inputs) if self.inputs else {} @property - def outputs_dict(self): - return json.loads(self.outputs) if self.outputs else None + def outputs_dict(self) -> Mapping[str, Any]: + return json.loads(self.outputs) if self.outputs else {} @property def message(self): @@ -644,14 +634,14 @@ class WorkflowNodeExecution(Base): @property def created_by_account(self): - created_by_role = CreatedByRole.value_of(self.created_by_role) + created_by_role = CreatedByRole(self.created_by_role) return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None @property def created_by_end_user(self): from models.model import EndUser - created_by_role = CreatedByRole.value_of(self.created_by_role) + created_by_role = CreatedByRole(self.created_by_role) return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None @property @@ -676,7 +666,7 @@ class WorkflowNodeExecution(Base): extras = {} if self.execution_metadata_dict: - from core.workflow.entities.node_entities import NodeType + from core.workflow.nodes import NodeType if self.node_type == NodeType.TOOL.value and "tool_info" in self.execution_metadata_dict: tool_info = self.execution_metadata_dict["tool_info"] @@ -763,14 +753,14 @@ class WorkflowAppLog(Base): @property def created_by_account(self): - created_by_role = CreatedByRole.value_of(self.created_by_role) + created_by_role = CreatedByRole(self.created_by_role) return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None @property def created_by_end_user(self): from models.model import EndUser - created_by_role = CreatedByRole.value_of(self.created_by_role) + created_by_role = CreatedByRole(self.created_by_role) return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None @@ -809,4 +799,4 @@ class ConversationVariable(Base): def to_variable(self) -> Variable: mapping = json.loads(self.data) - return 
factory.build_variable_from_mapping(mapping) + return variable_factory.build_variable_from_mapping(mapping) diff --git a/api/poetry.lock b/api/poetry.lock index 8967345e89..18e26d2be3 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -5762,13 +5762,13 @@ sympy = "*" [[package]] name = "openai" -version = "1.51.2" +version = "1.52.0" description = "The official Python library for the openai API" optional = false python-versions = ">=3.7.1" files = [ - {file = "openai-1.51.2-py3-none-any.whl", hash = "sha256:5c5954711cba931423e471c37ff22ae0fd3892be9b083eee36459865fbbb83fa"}, - {file = "openai-1.51.2.tar.gz", hash = "sha256:c6a51fac62a1ca9df85a522e462918f6bb6bc51a8897032217e453a0730123a6"}, + {file = "openai-1.52.0-py3-none-any.whl", hash = "sha256:0c249f20920183b0a2ca4f7dba7b0452df3ecd0fa7985eb1d91ad884bc3ced9c"}, + {file = "openai-1.52.0.tar.gz", hash = "sha256:95c65a5f77559641ab8f3e4c3a050804f7b51d278870e2ec1f7444080bfe565a"}, ] [package.dependencies] @@ -7098,6 +7098,17 @@ typing-extensions = ">3.10,<4.6.0 || >4.6.0" [package.extras] dev = ["build", "coverage", "furo", "invoke", "mypy", "pytest", "pytest-cov", "pytest-mypy-testing", "ruff", "sphinx", "sphinx-autodoc-typehints", "tox", "twine", "wheel"] +[[package]] +name = "pydub" +version = "0.25.1" +description = "Manipulate audio with an simple and easy high level interface" +optional = false +python-versions = "*" +files = [ + {file = "pydub-0.25.1-py2.py3-none-any.whl", hash = "sha256:65617e33033874b59d87db603aa1ed450633288aefead953b30bded59cb599a6"}, + {file = "pydub-0.25.1.tar.gz", hash = "sha256:980a33ce9949cab2a569606b65674d748ecbca4f0796887fd6f46173a7b0d30f"}, +] + [[package]] name = "pygments" version = "2.18.0" @@ -10784,4 +10795,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "51f048197baebf9ffdc393e5990b9a90185bc5ff515b8b5d2d9b72de900cf6e2" +content-hash = "642b2dae9e18ee6671d3d2c7129cb9a77327b69dacba996d00de2a9475d5bad3" diff --git a/api/pyproject.toml b/api/pyproject.toml index a44968cbea..b705161e58 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -158,7 +158,7 @@ nomic = "~3.1.2" novita-client = "~0.5.7" numpy = "~1.26.4" oci = "~2.135.1" -openai = "~1.51.2" +openai = "~1.52.0" openpyxl = "~3.1.5" pandas = { version = "~2.2.2", extras = ["performance", "excel"] } psycopg2-binary = "~2.9.6" @@ -216,6 +216,7 @@ matplotlib = "~3.8.2" newspaper3k = "0.2.8" nltk = "3.8.1" numexpr = "~2.9.0" +pydub = "~0.25.1" qrcode = "~7.4.2" twilio = "~9.0.4" vanna = { version = "0.7.3", extras = ["postgres", "mysql", "clickhouse", "duckdb"] } diff --git a/api/services/account_service.py b/api/services/account_service.py index eda6011aef..529b716773 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -1,6 +1,7 @@ import base64 import json import logging +import random import secrets import uuid from datetime import datetime, timedelta, timezone @@ -34,7 +35,9 @@ from models.model import DifySetup from services.errors.account import ( AccountAlreadyInTenantError, AccountLoginError, + AccountNotFoundError, AccountNotLinkTenantError, + AccountPasswordError, AccountRegisterError, CannotOperateSelfError, CurrentPasswordIncorrectError, @@ -42,10 +45,12 @@ from services.errors.account import ( LinkAccountIntegrateError, MemberNotInTenantError, NoPermissionError, - RateLimitExceededError, RoleAlreadyAssignedError, TenantNotFoundError, ) +from services.errors.workspace import WorkSpaceNotAllowedCreateError +from services.feature_service 
import FeatureService +from tasks.mail_email_code_login import send_email_code_login_mail_task from tasks.mail_invite_member_task import send_invite_member_mail_task from tasks.mail_reset_password_task import send_reset_password_mail_task @@ -61,7 +66,11 @@ REFRESH_TOKEN_EXPIRY = timedelta(days=30) class AccountService: - reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=5, time_window=60 * 60) + reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=1, time_window=60 * 1) + email_code_login_rate_limiter = RateLimiter( + prefix="email_code_login_rate_limit", max_attempts=1, time_window=60 * 1 + ) + LOGIN_MAX_ERROR_LIMITS = 5 @staticmethod def _get_refresh_token_key(refresh_token: str) -> str: @@ -127,23 +136,34 @@ class AccountService: return token @staticmethod - def authenticate(email: str, password: str) -> Account: + def authenticate(email: str, password: str, invite_token: Optional[str] = None) -> Account: """authenticate account with email and password""" account = Account.query.filter_by(email=email).first() if not account: - raise AccountLoginError("Invalid email or password.") + raise AccountNotFoundError() if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}: raise AccountLoginError("Account is banned or closed.") + if password and invite_token and account.password is None: + # if invite_token is valid, set password and password_salt + salt = secrets.token_bytes(16) + base64_salt = base64.b64encode(salt).decode() + password_hashed = hash_password(password, salt) + base64_password_hashed = base64.b64encode(password_hashed).decode() + account.password = base64_password_hashed + account.password_salt = base64_salt + + if account.password is None or not compare_password(password, account.password, account.password_salt): + raise AccountPasswordError("Invalid email or password.") + if account.status == AccountStatus.PENDING.value: account.status = AccountStatus.ACTIVE.value account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None) - db.session.commit() - if account.password is None or not compare_password(password, account.password, account.password_salt): - raise AccountLoginError("Invalid email or password.") + db.session.commit() + return account @staticmethod @@ -169,9 +189,18 @@ class AccountService: @staticmethod def create_account( - email: str, name: str, interface_language: str, password: Optional[str] = None, interface_theme: str = "light" + email: str, + name: str, + interface_language: str, + password: Optional[str] = None, + interface_theme: str = "light", + is_setup: Optional[bool] = False, ) -> Account: """create account""" + if not FeatureService.get_system_features().is_allow_register and not is_setup: + from controllers.console.error import NotAllowedRegister + + raise NotAllowedRegister() account = Account() account.email = email account.name = name @@ -198,6 +227,19 @@ class AccountService: db.session.commit() return account + @staticmethod + def create_account_and_tenant( + email: str, name: str, interface_language: str, password: Optional[str] = None + ) -> Account: + """create account""" + account = AccountService.create_account( + email=email, name=name, interface_language=interface_language, password=password + ) + + TenantService.create_owner_tenant_if_not_exist(account=account) + + return account + @staticmethod def link_account_integrate(provider: str, open_id: str, account: Account) -> None: """Link account integrate""" @@ -256,6 +298,10 @@ 
class AccountService: if ip_address: AccountService.update_login_info(account=account, ip_address=ip_address) + if account.status == AccountStatus.PENDING.value: + account.status = AccountStatus.ACTIVE.value + db.session.commit() + access_token = AccountService.get_account_jwt_token(account=account) refresh_token = _generate_refresh_token() @@ -294,13 +340,29 @@ class AccountService: return AccountService.load_user(account_id) @classmethod - def send_reset_password_email(cls, account): - if cls.reset_password_rate_limiter.is_rate_limited(account.email): - raise RateLimitExceededError(f"Rate limit exceeded for email: {account.email}. Please try again later.") + def send_reset_password_email( + cls, + account: Optional[Account] = None, + email: Optional[str] = None, + language: Optional[str] = "en-US", + ): + account_email = account.email if account else email - token = TokenManager.generate_token(account, "reset_password") - send_reset_password_mail_task.delay(language=account.interface_language, to=account.email, token=token) - cls.reset_password_rate_limiter.increment_rate_limit(account.email) + if cls.reset_password_rate_limiter.is_rate_limited(account_email): + from controllers.console.auth.error import PasswordResetRateLimitExceededError + + raise PasswordResetRateLimitExceededError() + + code = "".join([str(random.randint(0, 9)) for _ in range(6)]) + token = TokenManager.generate_token( + account=account, email=email, token_type="reset_password", additional_data={"code": code} + ) + send_reset_password_mail_task.delay( + language=language, + to=account_email, + code=code, + ) + cls.reset_password_rate_limiter.increment_rate_limit(account_email) return token @classmethod @@ -311,11 +373,125 @@ class AccountService: def get_reset_password_data(cls, token: str) -> Optional[dict[str, Any]]: return TokenManager.get_token_data(token, "reset_password") + @classmethod + def send_email_code_login_email( + cls, account: Optional[Account] = None, email: Optional[str] = None, language: Optional[str] = "en-US" + ): + if cls.email_code_login_rate_limiter.is_rate_limited(email): + from controllers.console.auth.error import EmailCodeLoginRateLimitExceededError + + raise EmailCodeLoginRateLimitExceededError() + + code = "".join([str(random.randint(0, 9)) for _ in range(6)]) + token = TokenManager.generate_token( + account=account, email=email, token_type="email_code_login", additional_data={"code": code} + ) + send_email_code_login_mail_task.delay( + language=language, + to=account.email if account else email, + code=code, + ) + cls.email_code_login_rate_limiter.increment_rate_limit(email) + return token + + @classmethod + def get_email_code_login_data(cls, token: str) -> Optional[dict[str, Any]]: + return TokenManager.get_token_data(token, "email_code_login") + + @classmethod + def revoke_email_code_login_token(cls, token: str): + TokenManager.revoke_token(token, "email_code_login") + + @classmethod + def get_user_through_email(cls, email: str): + account = db.session.query(Account).filter(Account.email == email).first() + if not account: + return None + + if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}: + raise Unauthorized("Account is banned or closed.") + + return account + + @staticmethod + def add_login_error_rate_limit(email: str) -> None: + key = f"login_error_rate_limit:{email}" + count = redis_client.get(key) + if count is None: + count = 0 + count = int(count) + 1 + redis_client.setex(key, 60 * 60 * 24, count) + + @staticmethod + def 
is_login_error_rate_limit(email: str) -> bool: + key = f"login_error_rate_limit:{email}" + count = redis_client.get(key) + if count is None: + return False + + count = int(count) + if count > AccountService.LOGIN_MAX_ERROR_LIMITS: + return True + return False + + @staticmethod + def reset_login_error_rate_limit(email: str): + key = f"login_error_rate_limit:{email}" + redis_client.delete(key) + + @staticmethod + def is_email_send_ip_limit(ip_address: str): + minute_key = f"email_send_ip_limit_minute:{ip_address}" + freeze_key = f"email_send_ip_limit_freeze:{ip_address}" + hour_limit_key = f"email_send_ip_limit_hour:{ip_address}" + + # check ip is frozen + if redis_client.get(freeze_key): + return True + + # check current minute count + current_minute_count = redis_client.get(minute_key) + if current_minute_count is None: + current_minute_count = 0 + current_minute_count = int(current_minute_count) + + # check current hour count + if current_minute_count > dify_config.EMAIL_SEND_IP_LIMIT_PER_MINUTE: + hour_limit_count = redis_client.get(hour_limit_key) + if hour_limit_count is None: + hour_limit_count = 0 + hour_limit_count = int(hour_limit_count) + + if hour_limit_count >= 1: + redis_client.setex(freeze_key, 60 * 60, 1) + return True + else: + redis_client.setex(hour_limit_key, 60 * 10, hour_limit_count + 1) # first time limit 10 minutes + + # add hour limit count + redis_client.incr(hour_limit_key) + redis_client.expire(hour_limit_key, 60 * 60) + + return True + + redis_client.setex(minute_key, 60, current_minute_count + 1) + redis_client.expire(minute_key, 60) + + return False + + +def _get_login_cache_key(*, account_id: str, token: str): + return f"account_login:{account_id}:{token}" + class TenantService: @staticmethod - def create_tenant(name: str) -> Tenant: + def create_tenant(name: str, is_setup: Optional[bool] = False) -> Tenant: """Create tenant""" + if not FeatureService.get_system_features().is_allow_create_workspace and not is_setup: + from controllers.console.error import NotAllowedCreateWorkspace + + raise NotAllowedCreateWorkspace() tenant = Tenant(name=name) db.session.add(tenant) @@ -326,8 +502,12 @@ class TenantService: return tenant @staticmethod - def create_owner_tenant_if_not_exist(account: Account, name: Optional[str] = None): + def create_owner_tenant_if_not_exist( + account: Account, name: Optional[str] = None, is_setup: Optional[bool] = False + ): """Create owner tenant if not exist""" + if not FeatureService.get_system_features().is_allow_create_workspace and not is_setup: + raise WorkSpaceNotAllowedCreateError() available_ta = ( TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first() ) @@ -336,9 +516,9 @@ class TenantService: return if name: - tenant = TenantService.create_tenant(name) + tenant = TenantService.create_tenant(name=name, is_setup=is_setup) else: - tenant = TenantService.create_tenant(f"{account.name}'s Workspace") + tenant = TenantService.create_tenant(name=f"{account.name}'s Workspace", is_setup=is_setup) TenantService.create_tenant_member(tenant, account, role="owner") account.current_tenant = tenant db.session.commit() @@ -352,8 +532,13 @@ class TenantService: logging.error(f"Tenant {tenant.id} has already an owner.") raise Exception("Tenant already has an owner.") - ta = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, role=role) - db.session.add(ta) + ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first() + if ta: + ta.role = role + else: 
+ ta = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, role=role) + db.session.add(ta) + db.session.commit() return ta @@ -570,12 +755,13 @@ class RegisterService: name=name, interface_language=languages[0], password=password, + is_setup=True, ) account.last_login_ip = ip_address account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None) - TenantService.create_owner_tenant_if_not_exist(account) + TenantService.create_owner_tenant_if_not_exist(account=account, is_setup=True) dify_setup = DifySetup(version=dify_config.CURRENT_VERSION) db.session.add(dify_setup) @@ -600,27 +786,33 @@ class RegisterService: provider: Optional[str] = None, language: Optional[str] = None, status: Optional[AccountStatus] = None, + is_setup: Optional[bool] = False, ) -> Account: db.session.begin_nested() """Register account""" try: account = AccountService.create_account( - email=email, name=name, interface_language=language or languages[0], password=password + email=email, + name=name, + interface_language=language or languages[0], + password=password, + is_setup=is_setup, ) account.status = AccountStatus.ACTIVE.value if not status else status.value account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None) if open_id is not None or provider is not None: AccountService.link_account_integrate(provider, open_id, account) - if dify_config.EDITION != "SELF_HOSTED": - tenant = TenantService.create_tenant(f"{account.name}'s Workspace") + if FeatureService.get_system_features().is_allow_create_workspace: + tenant = TenantService.create_tenant(f"{account.name}'s Workspace") TenantService.create_tenant_member(tenant, account, role="owner") account.current_tenant = tenant - tenant_was_created.send(tenant) db.session.commit() + except WorkSpaceNotAllowedCreateError: + db.session.rollback() except Exception as e: db.session.rollback() logging.error(f"Register failed: {e}") @@ -639,7 +831,9 @@ class RegisterService: TenantService.check_member_permission(tenant, inviter, None, "add") name = email.split("@")[0] - account = cls.register(email=email, name=name, language=language, status=AccountStatus.PENDING) + account = cls.register( + email=email, name=name, language=language, status=AccountStatus.PENDING, is_setup=True + ) # Create new tenant member for invited tenant TenantService.create_tenant_member(tenant, account, role) TenantService.switch_tenant(account, tenant.id) @@ -679,6 +873,11 @@ class RegisterService: redis_client.setex(cls._get_invitation_token_key(token), expiry_hours * 60 * 60, json.dumps(invitation_data)) return token + @classmethod + def is_valid_invite_token(cls, token: str) -> bool: + data = redis_client.get(cls._get_invitation_token_key(token)) + return data is not None + @classmethod def revoke_token(cls, workspace_id: str, email: str, token: str): if workspace_id and email: @@ -727,7 +926,9 @@ class RegisterService: } @classmethod - def _get_invitation_by_token(cls, token: str, workspace_id: str, email: str) -> Optional[dict[str, str]]: + def _get_invitation_by_token( + cls, token: str, workspace_id: Optional[str] = None, email: Optional[str] = None + ) -> Optional[dict[str, str]]: if workspace_id is not None and email is not None: email_hash = sha256(email.encode()).hexdigest() cache_key = f"member_invite_token:{workspace_id}, {email_hash}:{token}" diff --git a/api/services/agent_service.py b/api/services/agent_service.py index 887fb878b9..c8819535f1 100644 --- a/api/services/agent_service.py +++ b/api/services/agent_service.py @@ -68,7 +68,7 @@ class 
AgentService: "iterations": len(agent_thoughts), }, "iterations": [], - "files": message.files, + "files": message.message_files, } agent_config = AgentConfigManager.convert(app_model.app_model_config.to_dict()) diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 54594e1175..750d0a8cd2 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -3,9 +3,9 @@ import logging import httpx import yaml # type: ignore -from core.app.segments import factory from events.app_event import app_model_config_was_updated, app_was_created from extensions.ext_database import db +from factories import variable_factory from models.account import Account from models.model import App, AppMode, AppModelConfig from models.workflow import Workflow @@ -254,14 +254,18 @@ class AppDslService: # init draft workflow environment_variables_list = workflow_data.get("environment_variables") or [] - environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list] + environment_variables = [ + variable_factory.build_variable_from_mapping(obj) for obj in environment_variables_list + ] conversation_variables_list = workflow_data.get("conversation_variables") or [] - conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list] + conversation_variables = [ + variable_factory.build_variable_from_mapping(obj) for obj in conversation_variables_list + ] workflow_service = WorkflowService() draft_workflow = workflow_service.sync_draft_workflow( app_model=app, graph=workflow_data.get("graph", {}), - features=workflow_data.get("../core/app/features", {}), + features=workflow_data.get("features", {}), unique_hash=None, account=account, environment_variables=environment_variables, @@ -295,9 +299,13 @@ class AppDslService: # sync draft workflow environment_variables_list = workflow_data.get("environment_variables") or [] - environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list] + environment_variables = [ + variable_factory.build_variable_from_mapping(obj) for obj in environment_variables_list + ] conversation_variables_list = workflow_data.get("conversation_variables") or [] - conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list] + conversation_variables = [ + variable_factory.build_variable_from_mapping(obj) for obj in conversation_variables_list + ] draft_workflow = workflow_service.sync_draft_workflow( app_model=app_model, graph=workflow_data.get("graph", {}), diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py index b13ed9718d..b356ec521a 100644 --- a/api/services/app_generate_service.py +++ b/api/services/app_generate_service.py @@ -1,4 +1,4 @@ -from collections.abc import Generator +from collections.abc import Generator, Mapping from typing import Any, Union from openai._exceptions import RateLimitError @@ -23,7 +23,7 @@ class AppGenerateService: cls, app_model: App, user: Union[Account, EndUser], - args: Any, + args: Mapping[str, Any], invoke_from: InvokeFrom, streaming: bool = True, ): diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py index c519f0b0e5..4eed26efdf 100644 --- a/api/services/entities/model_provider_entities.py +++ b/api/services/entities/model_provider_entities.py @@ -5,7 +5,7 @@ from pydantic import BaseModel, ConfigDict from configs import dify_config from 
core.entities.model_entities import ModelWithProviderEntity, ProviderModelWithStatusEntity -from core.entities.provider_entities import QuotaConfiguration +from core.entities.provider_entities import ProviderQuotaType, QuotaConfiguration from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.provider_entities import ( @@ -15,7 +15,7 @@ from core.model_runtime.entities.provider_entities import ( ProviderHelpEntity, SimpleProviderEntity, ) -from models.provider import ProviderQuotaType, ProviderType +from models.provider import ProviderType class CustomConfigurationStatus(Enum): diff --git a/api/services/errors/account.py b/api/services/errors/account.py index 82dd9f944a..5aca12ffeb 100644 --- a/api/services/errors/account.py +++ b/api/services/errors/account.py @@ -13,6 +13,10 @@ class AccountLoginError(BaseServiceError): pass +class AccountPasswordError(BaseServiceError): + pass + + class AccountNotLinkTenantError(BaseServiceError): pass diff --git a/api/services/errors/workspace.py b/api/services/errors/workspace.py new file mode 100644 index 0000000000..714064ffdf --- /dev/null +++ b/api/services/errors/workspace.py @@ -0,0 +1,9 @@ +from services.errors.base import BaseServiceError + + +class WorkSpaceNotAllowedCreateError(BaseServiceError): + pass + + +class WorkSpaceNotFoundError(BaseServiceError): + pass diff --git a/api/services/feature_service.py b/api/services/feature_service.py index 30d819bd30..4d0a5f67ce 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -44,6 +44,11 @@ class SystemFeatureModel(BaseModel): enable_web_sso_switch_component: bool = False enable_marketplace: bool = True max_plugin_package_size: int = dify_config.PLUGIN_MAX_PACKAGE_SIZE + enable_email_code_login: bool = False + enable_email_password_login: bool = True + enable_social_oauth_login: bool = False + is_allow_register: bool = False + is_allow_create_workspace: bool = False class FeatureService: @@ -62,8 +67,11 @@ class FeatureService: def get_system_features(cls) -> SystemFeatureModel: system_features = SystemFeatureModel() + cls._fulfill_system_params_from_env(system_features) + if dify_config.ENTERPRISE_ENABLED: system_features.enable_web_sso_switch_component = True + cls._fulfill_params_from_enterprise(system_features) if dify_config.MARKETPLACE_ENABLED: @@ -71,6 +79,14 @@ class FeatureService: return system_features + @classmethod + def _fulfill_system_params_from_env(cls, system_features: SystemFeatureModel): + system_features.enable_email_code_login = dify_config.ENABLE_EMAIL_CODE_LOGIN + system_features.enable_email_password_login = dify_config.ENABLE_EMAIL_PASSWORD_LOGIN + system_features.enable_social_oauth_login = dify_config.ENABLE_SOCIAL_OAUTH_LOGIN + system_features.is_allow_register = dify_config.ALLOW_REGISTER + system_features.is_allow_create_workspace = dify_config.ALLOW_CREATE_WORKSPACE + @classmethod def _fulfill_params_from_env(cls, features: FeatureModel): features.can_replace_logo = dify_config.CAN_REPLACE_LOGO @@ -118,7 +134,19 @@ class FeatureService: def _fulfill_params_from_enterprise(cls, features): enterprise_info = EnterpriseService.get_info() - features.sso_enforced_for_signin = enterprise_info["sso_enforced_for_signin"] - features.sso_enforced_for_signin_protocol = enterprise_info["sso_enforced_for_signin_protocol"] - features.sso_enforced_for_web = enterprise_info["sso_enforced_for_web"] - features.sso_enforced_for_web_protocol = 
enterprise_info["sso_enforced_for_web_protocol"] + if "sso_enforced_for_signin" in enterprise_info: + features.sso_enforced_for_signin = enterprise_info["sso_enforced_for_signin"] + if "sso_enforced_for_signin_protocol" in enterprise_info: + features.sso_enforced_for_signin_protocol = enterprise_info["sso_enforced_for_signin_protocol"] + if "sso_enforced_for_web" in enterprise_info: + features.sso_enforced_for_web = enterprise_info["sso_enforced_for_web"] + if "sso_enforced_for_web_protocol" in enterprise_info: + features.sso_enforced_for_web_protocol = enterprise_info["sso_enforced_for_web_protocol"] + if "enable_email_code_login" in enterprise_info: + features.enable_email_code_login = enterprise_info["enable_email_code_login"] + if "enable_email_password_login" in enterprise_info: + features.enable_email_password_login = enterprise_info["enable_email_password_login"] + if "is_allow_register" in enterprise_info: + features.is_allow_register = enterprise_info["is_allow_register"] + if "is_allow_create_workspace" in enterprise_info: + features.is_allow_create_workspace = enterprise_info["is_allow_create_workspace"] diff --git a/api/services/file_service.py b/api/services/file_service.py index bedec76334..84ccc4e882 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -2,79 +2,67 @@ import datetime import hashlib import uuid from collections.abc import Generator -from typing import Union +from typing import Literal, Union from flask_login import current_user from werkzeug.datastructures import FileStorage from werkzeug.exceptions import NotFound from configs import dify_config -from core.file.upload_file_parser import UploadFileParser +from constants import ( + AUDIO_EXTENSIONS, + DOCUMENT_EXTENSIONS, + IMAGE_EXTENSIONS, + VIDEO_EXTENSIONS, +) +from core.file import helpers as file_helpers from core.rag.extractor.extract_processor import ExtractProcessor from extensions.ext_database import db from extensions.ext_storage import storage from models.account import Account from models.model import EndUser, UploadFile -from services.errors.file import FileTooLargeError, UnsupportedFileTypeError - -IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"] -IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS]) - -ALLOWED_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls", "docx", "csv"] -UNSTRUCTURED_ALLOWED_EXTENSIONS = [ - "txt", - "markdown", - "md", - "pdf", - "html", - "htm", - "xlsx", - "xls", - "docx", - "csv", - "eml", - "msg", - "pptx", - "ppt", - "xml", - "epub", -] +from services.errors.file import FileNotExistsError, FileTooLargeError, UnsupportedFileTypeError PREVIEW_WORDS_LIMIT = 3000 class FileService: @staticmethod - def upload_file(file: FileStorage, user: Union[Account, EndUser], only_image: bool = False) -> UploadFile: + def upload_file( + file: FileStorage, user: Union[Account, EndUser], source: Literal["datasets"] | None = None + ) -> UploadFile: + # get file name filename = file.filename - extension = file.filename.split(".")[-1] + if not filename: + raise FileNotExistsError + extension = filename.split(".")[-1] if len(filename) > 200: filename = filename.split(".")[0][:200] + "." 
+ extension - etl_type = dify_config.ETL_TYPE - allowed_extensions = ( - UNSTRUCTURED_ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS - if etl_type == "Unstructured" - else ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS - ) - if extension.lower() not in allowed_extensions or only_image and extension.lower() not in IMAGE_EXTENSIONS: + + if source == "datasets" and extension not in DOCUMENT_EXTENSIONS: raise UnsupportedFileTypeError() - # read file content - file_content = file.read() - - # get file size - file_size = len(file_content) - - if extension.lower() in IMAGE_EXTENSIONS: + # select file size limit + if extension in IMAGE_EXTENSIONS: file_size_limit = dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT * 1024 * 1024 + elif extension in VIDEO_EXTENSIONS: + file_size_limit = dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT * 1024 * 1024 + elif extension in AUDIO_EXTENSIONS: + file_size_limit = dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT * 1024 * 1024 else: file_size_limit = dify_config.UPLOAD_FILE_SIZE_LIMIT * 1024 * 1024 + # read file content + file_content = file.read() + # get file size + file_size = len(file_content) + + # check if the file size is exceeded if file_size > file_size_limit: message = f"File size exceeded. {file_size} > {file_size_limit}" raise FileTooLargeError(message) - # user uuid as file name + # generate file key file_uuid = str(uuid.uuid4()) if isinstance(user, Account): @@ -150,9 +138,7 @@ class FileService: # extract text from file extension = upload_file.extension - etl_type = dify_config.ETL_TYPE - allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == "Unstructured" else ALLOWED_EXTENSIONS - if extension.lower() not in allowed_extensions: + if extension.lower() not in DOCUMENT_EXTENSIONS: raise UnsupportedFileTypeError() text = ExtractProcessor.load_from_upload_file(upload_file, return_text=True) @@ -161,8 +147,10 @@ class FileService: return text @staticmethod - def get_image_preview(file_id: str, timestamp: str, nonce: str, sign: str) -> tuple[Generator, str]: - result = UploadFileParser.verify_image_file_signature(file_id, timestamp, nonce, sign) + def get_image_preview(file_id: str, timestamp: str, nonce: str, sign: str): + result = file_helpers.verify_image_signature( + upload_file_id=file_id, timestamp=timestamp, nonce=nonce, sign=sign + ) if not result: raise NotFound("File not found or signature is invalid") @@ -180,6 +168,21 @@ class FileService: return generator, upload_file.mime_type + @staticmethod + def get_signed_file_preview(file_id: str, timestamp: str, nonce: str, sign: str): + result = file_helpers.verify_file_signature(upload_file_id=file_id, timestamp=timestamp, nonce=nonce, sign=sign) + if not result: + raise NotFound("File not found or signature is invalid") + + upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first() + + if not upload_file: + raise NotFound("File not found or signature is invalid") + + generator = storage.load(upload_file.key, stream=True) + + return generator, upload_file.mime_type + @staticmethod def get_public_image_preview(file_id: str) -> tuple[Generator, str]: upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first() diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py index 87ddaf3e67..09aa2caf55 100644 --- a/api/services/tools/workflow_tools_manage_service.py +++ b/api/services/tools/workflow_tools_manage_service.py @@ -1,5 +1,7 @@ import json +from collections.abc import Mapping, Sequence from datetime import datetime 
+from typing import Any from sqlalchemy import or_ @@ -20,9 +22,9 @@ class WorkflowToolManageService: Service class for managing workflow tools. """ - @classmethod + @staticmethod def create_workflow_tool( - cls, + *, user_id: str, tenant_id: str, workflow_app_id: str, @@ -30,22 +32,10 @@ class WorkflowToolManageService: label: str, icon: dict, description: str, - parameters: list[dict], + parameters: Mapping[str, Any], privacy_policy: str = "", labels: list[str] | None = None, ) -> dict: - """ - Create a workflow tool. - :param user_id: the user id - :param tenant_id: the tenant id - :param name: the name - :param icon: the icon - :param description: the description - :param parameters: the parameters - :param privacy_policy: the privacy policy - :param labels: labels - :return: the created tool - """ WorkflowToolConfigurationUtils.check_parameter_configurations(parameters) # check if the name is unique @@ -192,7 +182,7 @@ class WorkflowToolManageService: """ db_tools = db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all() - tools = [] + tools: Sequence[WorkflowToolProviderController] = [] for provider in db_tools: try: tools.append(ToolTransformService.workflow_provider_to_controller(provider)) @@ -211,7 +201,7 @@ class WorkflowToolManageService: ToolTransformService.repack_provider(user_tool_provider) user_tool_provider.tools = [ ToolTransformService.convert_tool_entity_to_api_entity( - tool=tool.get_tools(user_id, tenant_id)[0], + tool=tool.get_tools(tenant_id)[0], labels=labels.get(tool.provider_id, []), tenant_id=tenant_id, ) diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index db1a036e68..75c11afa94 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -13,12 +13,12 @@ from core.app.app_config.entities import ( from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager from core.app.apps.chat.app_config_manager import ChatAppConfigManager from core.app.apps.completion.app_config_manager import CompletionAppConfigManager -from core.file.file_obj import FileExtraConfig +from core.file.models import FileExtraConfig from core.helper import encrypter from core.model_runtime.entities.llm_entities import LLMMode from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.simple_prompt_transform import SimplePromptTransform -from core.workflow.entities.node_entities import NodeType +from core.workflow.nodes import NodeType from events.app_event import app_was_created from extensions.ext_database import db from models.account import Account @@ -522,7 +522,7 @@ class WorkflowConverter: "vision": { "enabled": file_upload is not None, "variable_selector": ["sys", "files"] if file_upload is not None else None, - "configs": {"detail": file_upload.image_config["detail"]} + "configs": {"detail": file_upload.image_config.detail} if file_upload is not None and file_upload.image_config is not None else None, }, diff --git a/api/services/workflow_app_service.py b/api/services/workflow_app_service.py index b4f0882a3a..f89487415d 100644 --- a/api/services/workflow_app_service.py +++ b/api/services/workflow_app_service.py @@ -4,9 +4,9 @@ from flask_sqlalchemy.pagination import Pagination from sqlalchemy import and_, or_ from extensions.ext_database import db -from models import CreatedByRole -from models.model import App, EndUser -from models.workflow import WorkflowAppLog, WorkflowRun, WorkflowRunStatus +from 
models import App, EndUser, WorkflowAppLog, WorkflowRun +from models.enums import CreatedByRole +from models.workflow import WorkflowRunStatus class WorkflowAppService: @@ -21,7 +21,7 @@ class WorkflowAppService: WorkflowAppLog.tenant_id == app_model.tenant_id, WorkflowAppLog.app_id == app_model.id ) - status = WorkflowRunStatus.value_of(args.get("status")) if args.get("status") else None + status = WorkflowRunStatus.value_of(args.get("status", "")) if args.get("status") else None keyword = args["keyword"] if keyword or status: query = query.join(WorkflowRun, WorkflowRun.id == WorkflowAppLog.workflow_run_id) @@ -42,7 +42,7 @@ class WorkflowAppService: query = query.outerjoin( EndUser, - and_(WorkflowRun.created_by == EndUser.id, WorkflowRun.created_by_role == CreatedByRole.END_USER.value), + and_(WorkflowRun.created_by == EndUser.id, WorkflowRun.created_by_role == CreatedByRole.END_USER), ).filter(or_(*keyword_conditions)) if status: diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py index b7b3abeaa2..d8ee323908 100644 --- a/api/services/workflow_run_service.py +++ b/api/services/workflow_run_service.py @@ -1,11 +1,11 @@ from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination +from models.enums import WorkflowRunTriggeredFrom from models.model import App from models.workflow import ( WorkflowNodeExecution, WorkflowNodeExecutionTriggeredFrom, WorkflowRun, - WorkflowRunTriggeredFrom, ) diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index eec3d26a7b..8cb3f9fe6e 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -6,21 +6,23 @@ from typing import Any, Optional from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager -from core.app.segments import Variable from core.model_runtime.utils.encoders import jsonable_encoder -from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.variables import Variable +from core.workflow.entities.node_entities import NodeRunResult from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.graph_engine.entities.event import InNodeEvent -from core.workflow.nodes.base_node import BaseNode -from core.workflow.nodes.event import RunCompletedEvent, RunEvent -from core.workflow.nodes.node_mapping import node_classes +from core.workflow.nodes import NodeType +from core.workflow.nodes.base.node import BaseNode +from core.workflow.nodes.event import RunCompletedEvent +from core.workflow.nodes.event.types import NodeEvent +from core.workflow.nodes.node_mapping import node_type_classes_mapping from core.workflow.workflow_entry import WorkflowEntry from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated from extensions.ext_database import db from models.account import Account +from models.enums import CreatedByRole from models.model import App, AppMode from models.workflow import ( - CreatedByRole, Workflow, WorkflowNodeExecution, WorkflowNodeExecutionStatus, @@ -177,7 +179,7 @@ class WorkflowService: """ # return default block config default_block_configs = [] - for node_type, node_class in node_classes.items(): + for node_type, node_class in node_type_classes_mapping.items(): default_config = node_class.get_default_config() if default_config: default_block_configs.append(default_config) @@ -191,10 +193,10 @@ class 
WorkflowService: :param filters: filter by node config parameters. :return: """ - node_type_enum: NodeType = NodeType.value_of(node_type) + node_type_enum: NodeType = NodeType(node_type) # return default block config - node_class = node_classes.get(node_type_enum) + node_class = node_type_classes_mapping.get(node_type_enum) if not node_class: return None @@ -265,7 +267,7 @@ class WorkflowService: def _handle_node_run_result( self, - getter: Callable[[], tuple[BaseNode, Generator[RunEvent | InNodeEvent, None, None]]], + getter: Callable[[], tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]], start_at: float, tenant_id: str, node_id: str, @@ -306,7 +308,7 @@ class WorkflowService: workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value workflow_node_execution.index = 1 workflow_node_execution.node_id = node_id - workflow_node_execution.node_type = node_instance.node_type.value + workflow_node_execution.node_type = node_instance.node_type workflow_node_execution.title = node_instance.node_data.title workflow_node_execution.elapsed_time = time.perf_counter() - start_at workflow_node_execution.created_by_role = CreatedByRole.ACCOUNT.value diff --git a/api/tasks/mail_email_code_login.py b/api/tasks/mail_email_code_login.py new file mode 100644 index 0000000000..d78fc2b891 --- /dev/null +++ b/api/tasks/mail_email_code_login.py @@ -0,0 +1,41 @@ +import logging +import time + +import click +from celery import shared_task +from flask import render_template + +from extensions.ext_mail import mail + + +@shared_task(queue="mail") +def send_email_code_login_mail_task(language: str, to: str, code: str): + """ + Async Send email code login mail + :param language: Language in which the email should be sent (e.g., 'en', 'zh') + :param to: Recipient email address + :param code: Email code to be included in the email + """ + if not mail.is_inited(): + return + + logging.info(click.style("Start email code login mail to {}".format(to), fg="green")) + start_at = time.perf_counter() + + # send email code login mail using different languages + try: + if language == "zh-Hans": + html_content = render_template("email_code_login_mail_template_zh-CN.html", to=to, code=code) + mail.send(to=to, subject="邮箱验证码", html=html_content) + else: + html_content = render_template("email_code_login_mail_template_en-US.html", to=to, code=code) + mail.send(to=to, subject="Email Code", html=html_content) + + end_at = time.perf_counter() + logging.info( + click.style( + "Send email code login mail to {} succeeded: latency: {}".format(to, end_at - start_at), fg="green" + ) + ) + except Exception: + logging.exception("Send email code login mail to {} failed".format(to)) diff --git a/api/tasks/mail_reset_password_task.py b/api/tasks/mail_reset_password_task.py index cbb78976ca..8596ca07cf 100644 --- a/api/tasks/mail_reset_password_task.py +++ b/api/tasks/mail_reset_password_task.py @@ -5,17 +5,16 @@ import click from celery import shared_task from flask import render_template -from configs import dify_config from extensions.ext_mail import mail @shared_task(queue="mail") -def send_reset_password_mail_task(language: str, to: str, token: str): +def send_reset_password_mail_task(language: str, to: str, code: str): """ Async Send reset password mail :param language: Language in which the email should be sent (e.g., 'en', 'zh') :param to: Recipient email address - :param token: Reset password token to be included in the email + :param code: Reset password code """ if not mail.is_inited(): 
diff --git a/api/tasks/mail_email_code_login.py b/api/tasks/mail_email_code_login.py
new file mode 100644
index 0000000000..d78fc2b891
--- /dev/null
+++ b/api/tasks/mail_email_code_login.py
@@ -0,0 +1,41 @@
+import logging
+import time
+
+import click
+from celery import shared_task
+from flask import render_template
+
+from extensions.ext_mail import mail
+
+
+@shared_task(queue="mail")
+def send_email_code_login_mail_task(language: str, to: str, code: str):
+    """
+    Async Send email code login mail
+    :param language: Language in which the email should be sent (e.g., 'en', 'zh')
+    :param to: Recipient email address
+    :param code: Email code to be included in the email
+    """
+    if not mail.is_inited():
+        return
+
+    logging.info(click.style("Start email code login mail to {}".format(to), fg="green"))
+    start_at = time.perf_counter()
+
+    # send email code login mail using different languages
+    try:
+        if language == "zh-Hans":
+            html_content = render_template("email_code_login_mail_template_zh-CN.html", to=to, code=code)
+            mail.send(to=to, subject="邮箱验证码", html=html_content)
+        else:
+            html_content = render_template("email_code_login_mail_template_en-US.html", to=to, code=code)
+            mail.send(to=to, subject="Email Code", html=html_content)
+
+        end_at = time.perf_counter()
+        logging.info(
+            click.style(
+                "Send email code login mail to {} succeeded: latency: {}".format(to, end_at - start_at), fg="green"
+            )
+        )
+    except Exception:
+        logging.exception("Send email code login mail to {} failed".format(to))
diff --git a/api/tasks/mail_reset_password_task.py b/api/tasks/mail_reset_password_task.py
index cbb78976ca..8596ca07cf 100644
--- a/api/tasks/mail_reset_password_task.py
+++ b/api/tasks/mail_reset_password_task.py
@@ -5,17 +5,16 @@ import click
 from celery import shared_task
 from flask import render_template

-from configs import dify_config
 from extensions.ext_mail import mail


 @shared_task(queue="mail")
-def send_reset_password_mail_task(language: str, to: str, token: str):
+def send_reset_password_mail_task(language: str, to: str, code: str):
     """
     Async Send reset password mail
     :param language: Language in which the email should be sent (e.g., 'en', 'zh')
     :param to: Recipient email address
-    :param token: Reset password token to be included in the email
+    :param code: Reset password code
     """
     if not mail.is_inited():
         return
@@ -25,13 +24,12 @@ def send_reset_password_mail_task(language: str, to: str, token: str):

     # send reset password mail using different languages
     try:
-        url = f"{dify_config.CONSOLE_WEB_URL}/forgot-password?token={token}"
         if language == "zh-Hans":
-            html_content = render_template("reset_password_mail_template_zh-CN.html", to=to, url=url)
-            mail.send(to=to, subject="重置您的 Dify 密码", html=html_content)
+            html_content = render_template("reset_password_mail_template_zh-CN.html", to=to, code=code)
+            mail.send(to=to, subject="设置您的 Dify 密码", html=html_content)
         else:
-            html_content = render_template("reset_password_mail_template_en-US.html", to=to, url=url)
-            mail.send(to=to, subject="Reset Your Dify Password", html=html_content)
+            html_content = render_template("reset_password_mail_template_en-US.html", to=to, code=code)
+            mail.send(to=to, subject="Set Your Dify Password", html=html_content)

         end_at = time.perf_counter()
         logging.info(
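Both mail tasks are Celery `shared_task`s bound to the "mail" queue, so callers enqueue them instead of sending inline. A hedged usage sketch (the calling service code is assumed, not part of this diff):

    from tasks.mail_email_code_login import send_email_code_login_mail_task

    # .delay() pushes the job onto the "mail" queue; the web request
    # returns without waiting for the SMTP round-trip.
    send_email_code_login_mail_task.delay(
        language="en-US",  # anything other than "zh-Hans" falls back to English
        to="user@example.com",
        code="914870",  # illustrative six-digit code
    )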
diff --git a/api/templates/email_code_login_mail_template_en-US.html b/api/templates/email_code_login_mail_template_en-US.html
new file mode 100644
index 0000000000..066818d10c
--- /dev/null
+++ b/api/templates/email_code_login_mail_template_en-US.html
@@ -0,0 +1,74 @@
+[74-line HTML email template; the markup was stripped during extraction. Recoverable content:]
+  Dify Logo (header image)
+  Your login code for Dify
+  Copy and paste this code; it will only be valid for the next 5 minutes.
+  {{code}}
+  If you didn't request a login, don't worry. You can safely ignore this email.
diff --git a/api/templates/email_code_login_mail_template_zh-CN.html b/api/templates/email_code_login_mail_template_zh-CN.html
new file mode 100644
index 0000000000..0c2b63a1f1
--- /dev/null
+++ b/api/templates/email_code_login_mail_template_zh-CN.html
@@ -0,0 +1,74 @@
+[74-line HTML email template; the markup was stripped during extraction. Chinese counterpart of the template above. Recoverable content:]
+  Dify Logo (header image)
+  Dify 的登录验证码 ("Your login code for Dify")
+  复制并粘贴此验证码,注意验证码仅在接下来的 5 分钟内有效。 ("Copy and paste this code; it is only valid for the next 5 minutes.")
+  {{code}}
+  如果您没有请求登录,请不要担心。您可以安全地忽略此电子邮件。 ("If you didn't request a login, don't worry. You can safely ignore this email.")
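The "5 minutes" wording hard-coded in both templates implies a matching server-side lifetime for the code. A hedged sketch of how such a code could be issued with an aligned TTL (hypothetical helper, not code from this PR; assumes the project's Redis extension):

    import secrets

    from extensions.ext_redis import redis_client  # assumed import path

    CODE_EXPIRY_MINUTES = 5  # keep in sync with the template copy


    def issue_email_login_code(email: str) -> str:
        # Hypothetical: generate a random six-digit code and store it with
        # a TTL so it expires exactly when the email says it will.
        code = f"{secrets.randbelow(10**6):06d}"
        redis_client.setex(f"email_code_login:{email}", CODE_EXPIRY_MINUTES * 60, code)
        return code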

diff --git a/api/templates/invite_member_mail_template_en-US.html b/api/templates/invite_member_mail_template_en-US.html
index 80f7d42c20..e8bf7f5a52 100644
--- a/api/templates/invite_member_mail_template_en-US.html
+++ b/api/templates/invite_member_mail_template_en-US.html
@@ -59,7 +59,7 @@
 [surrounding HTML markup stripped during extraction; recoverable hunk text:]
   Dear {{ to }},
   {{ inviter_name }} is pleased to invite you to join our workspace on Dify, a platform specifically designed for LLM application development. On Dify, you can explore, create, and collaborate to build and operate AI applications.
-  You can now log in to Dify using the GitHub or Google account associated with this email.
+  Click the button below to log in to Dify and join the workspace.
   Login Here

diff --git a/web/app/(commonLayout)/apps/Apps.tsx b/web/app/(commonLayout)/apps/Apps.tsx
index 132096c6b4..9d6345aa6c 100644
--- a/web/app/(commonLayout)/apps/Apps.tsx
+++ b/web/app/(commonLayout)/apps/Apps.tsx
@@ -21,7 +21,7 @@ import { NEED_REFRESH_APP_LIST_KEY } from '@/config'
 import { CheckModal } from '@/hooks/use-pay'
 import TabSliderNew from '@/app/components/base/tab-slider-new'
 import { useTabSearchParams } from '@/hooks/use-tab-searchparams'
-import SearchInput from '@/app/components/base/search-input'
+import Input from '@/app/components/base/input'
 import { useStore as useTagStore } from '@/app/components/base/tag-management/store'
 import TagManagementModal from '@/app/components/base/tag-management'
 import TagFilter from '@/app/components/base/tag-management/filter'
@@ -87,15 +87,15 @@ const Apps = () => {
       localStorage.removeItem(NEED_REFRESH_APP_LIST_KEY)
       mutate()
     }
-  }, [])
+  }, [mutate, t])

   useEffect(() => {
     if (isCurrentWorkspaceDatasetOperator)
       return router.replace('/datasets')
-  }, [isCurrentWorkspaceDatasetOperator])
+  }, [router, isCurrentWorkspaceDatasetOperator])

-  const hasMore = data?.at(-1)?.has_more ?? true
   useEffect(() => {
+    const hasMore = data?.at(-1)?.has_more ?? true
     let observer: IntersectionObserver | undefined
     if (anchorRef.current) {
       observer = new IntersectionObserver((entries) => {
@@ -105,7 +105,7 @@ const Apps = () => {
       observer.observe(anchorRef.current)
     }
     return () => observer?.disconnect()
-  }, [isLoading, setSize, anchorRef, mutate, hasMore])
+  }, [isLoading, setSize, anchorRef, mutate, data])

   const { run: handleSearch } = useDebounceFn(() => {
     setSearchKeywords(keywords)
@@ -133,7 +133,14 @@ const Apps = () => {
         />
-        [old SearchInput element; its JSX tag was stripped during extraction]
+        [new Input element; the opening JSX tag was stripped during extraction. Recoverable props:]
+          onChange={e => handleKeywordsChange(e.target.value)}
+          onClear={() => handleKeywordsChange('')}
+        />