diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 106c26bbed..36fa39b5d7 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -9,6 +9,9 @@ # CODEOWNERS file /.github/CODEOWNERS @laipz8200 @crazywoola +# Agents +/.agents/skills/ @hyoban + # Docs /docs/ @crazywoola diff --git a/.github/workflows/build-push.yml b/.github/workflows/build-push.yml index ac7f3a6b48..704d896192 100644 --- a/.github/workflows/build-push.yml +++ b/.github/workflows/build-push.yml @@ -8,7 +8,6 @@ on: - "build/**" - "release/e-*" - "hotfix/**" - - "feat/hitl-backend" tags: - "*" diff --git a/api/.env.example b/api/.env.example index fcadfa1c3b..8bd2c706c1 100644 --- a/api/.env.example +++ b/api/.env.example @@ -717,28 +717,3 @@ SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21 SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000 SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30 SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL=90000 - - -# Redis URL used for PubSub between API and -# celery worker -# defaults to url constructed from `REDIS_*` -# configurations -PUBSUB_REDIS_URL= -# Pub/sub channel type for streaming events. -# valid options are: -# -# - pubsub: for normal Pub/Sub -# - sharded: for sharded Pub/Sub -# -# It's highly recommended to use sharded Pub/Sub AND redis cluster -# for large deployments. -PUBSUB_REDIS_CHANNEL_TYPE=pubsub -# Whether to use Redis cluster mode while running -# PubSub. -# It's highly recommended to enable this for large deployments. -PUBSUB_REDIS_USE_CLUSTERS=false - -# Whether to Enable human input timeout check task -ENABLE_HUMAN_INPUT_TIMEOUT_TASK=true -# Human input timeout check interval in minutes -HUMAN_INPUT_TIMEOUT_TASK_INTERVAL=1 diff --git a/api/.importlinter b/api/.importlinter index 98f87710ed..9dad254560 100644 --- a/api/.importlinter +++ b/api/.importlinter @@ -36,8 +36,6 @@ ignore_imports = core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine core.workflow.nodes.loop.loop_node -> core.workflow.graph core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine.command_channels - # TODO(QuantumGhost): fix the import violation later - core.workflow.entities.pause_reason -> core.workflow.nodes.human_input.entities [importlinter:contract:workflow-infrastructure-dependencies] name = Workflow Infrastructure Dependencies @@ -60,8 +58,6 @@ ignore_imports = core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis core.workflow.graph_engine.manager -> extensions.ext_redis core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis - # TODO(QuantumGhost): use DI to avoid depending on global DB. - core.workflow.nodes.human_input.human_input_node -> extensions.ext_database [importlinter:contract:workflow-external-imports] name = Workflow External Imports @@ -149,7 +145,6 @@ ignore_imports = core.workflow.nodes.agent.agent_node -> core.agent.entities core.workflow.nodes.agent.agent_node -> core.agent.plugin_entities core.workflow.nodes.base.node -> core.app.entities.app_invoke_entities - core.workflow.nodes.human_input.human_input_node -> core.app.entities.app_invoke_entities core.workflow.nodes.knowledge_index.knowledge_index_node -> core.app.entities.app_invoke_entities core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.app_config.entities core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.entities.app_invoke_entities @@ -253,7 +248,6 @@ ignore_imports = core.workflow.nodes.document_extractor.node -> core.variables.segments core.workflow.nodes.http_request.executor -> core.variables.segments core.workflow.nodes.http_request.node -> core.variables.segments - core.workflow.nodes.human_input.entities -> core.variables.consts core.workflow.nodes.iteration.iteration_node -> core.variables core.workflow.nodes.iteration.iteration_node -> core.variables.segments core.workflow.nodes.iteration.iteration_node -> core.variables.variables @@ -300,8 +294,6 @@ ignore_imports = core.workflow.nodes.llm.llm_utils -> extensions.ext_database core.workflow.nodes.llm.node -> extensions.ext_database core.workflow.nodes.tool.tool_node -> extensions.ext_database - core.workflow.nodes.human_input.human_input_node -> extensions.ext_database - core.workflow.nodes.human_input.human_input_node -> core.repositories.human_input_repository core.workflow.workflow_entry -> extensions.otel.runtime core.workflow.nodes.agent.agent_node -> models core.workflow.nodes.base.node -> models.enums diff --git a/api/.ruff.toml b/api/.ruff.toml index 8db0cbcb21..3301452ad9 100644 --- a/api/.ruff.toml +++ b/api/.ruff.toml @@ -53,6 +53,7 @@ select = [ "S301", # suspicious-pickle-usage, disallow use of `pickle` and its wrappers. "S302", # suspicious-marshal-usage, disallow use of `marshal` module "S311", # suspicious-non-cryptographic-random-usage, + "TID", # flake8-tidy-imports ] @@ -88,6 +89,7 @@ ignore = [ "SIM113", # enumerate-for-loop "SIM117", # multiple-with-statements "SIM210", # if-expr-with-true-false + "TID252", # allow relative imports from parent modules ] [lint.per-file-ignores] @@ -109,10 +111,20 @@ ignore = [ "S110", # allow ignoring exceptions in tests code (currently) ] +"controllers/console/explore/trial.py" = ["TID251"] +"controllers/console/human_input_form.py" = ["TID251"] +"controllers/web/human_input_form.py" = ["TID251"] [lint.pyflakes] allowed-unused-imports = [ - "_pytest.monkeypatch", "tests.integration_tests", "tests.unit_tests", ] + +[lint.flake8-tidy-imports] + +[lint.flake8-tidy-imports.banned-api."flask_restx.reqparse"] +msg = "Use Pydantic payload/query models instead of reqparse." + +[lint.flake8-tidy-imports.banned-api."flask_restx.reqparse.RequestParser"] +msg = "Use Pydantic payload/query models instead of reqparse." diff --git a/api/commands.py b/api/commands.py index 4b811fb1e6..c4f2c9edbb 100644 --- a/api/commands.py +++ b/api/commands.py @@ -1450,54 +1450,58 @@ def clear_orphaned_file_records(force: bool): all_ids_in_tables = [] for ids_table in ids_tables: query = "" - if ids_table["type"] == "uuid": - click.echo( - click.style( - f"- Listing file ids in column {ids_table['column']} in table {ids_table['table']}", fg="white" + match ids_table["type"]: + case "uuid": + click.echo( + click.style( + f"- Listing file ids in column {ids_table['column']} in table {ids_table['table']}", + fg="white", + ) ) - ) - query = ( - f"SELECT {ids_table['column']} FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL" - ) - with db.engine.begin() as conn: - rs = conn.execute(sa.text(query)) - for i in rs: - all_ids_in_tables.append({"table": ids_table["table"], "id": str(i[0])}) - elif ids_table["type"] == "text": - click.echo( - click.style( - f"- Listing file-id-like strings in column {ids_table['column']} in table {ids_table['table']}", - fg="white", + c = ids_table["column"] + query = f"SELECT {c} FROM {ids_table['table']} WHERE {c} IS NOT NULL" + with db.engine.begin() as conn: + rs = conn.execute(sa.text(query)) + for i in rs: + all_ids_in_tables.append({"table": ids_table["table"], "id": str(i[0])}) + case "text": + t = ids_table["table"] + click.echo( + click.style( + f"- Listing file-id-like strings in column {ids_table['column']} in table {t}", + fg="white", + ) ) - ) - query = ( - f"SELECT regexp_matches({ids_table['column']}, '{guid_regexp}', 'g') AS extracted_id " - f"FROM {ids_table['table']}" - ) - with db.engine.begin() as conn: - rs = conn.execute(sa.text(query)) - for i in rs: - for j in i[0]: - all_ids_in_tables.append({"table": ids_table["table"], "id": j}) - elif ids_table["type"] == "json": - click.echo( - click.style( - ( - f"- Listing file-id-like JSON string in column {ids_table['column']} " - f"in table {ids_table['table']}" - ), - fg="white", + query = ( + f"SELECT regexp_matches({ids_table['column']}, '{guid_regexp}', 'g') AS extracted_id " + f"FROM {ids_table['table']}" ) - ) - query = ( - f"SELECT regexp_matches({ids_table['column']}::text, '{guid_regexp}', 'g') AS extracted_id " - f"FROM {ids_table['table']}" - ) - with db.engine.begin() as conn: - rs = conn.execute(sa.text(query)) - for i in rs: - for j in i[0]: - all_ids_in_tables.append({"table": ids_table["table"], "id": j}) + with db.engine.begin() as conn: + rs = conn.execute(sa.text(query)) + for i in rs: + for j in i[0]: + all_ids_in_tables.append({"table": ids_table["table"], "id": j}) + case "json": + click.echo( + click.style( + ( + f"- Listing file-id-like JSON string in column {ids_table['column']} " + f"in table {ids_table['table']}" + ), + fg="white", + ) + ) + query = ( + f"SELECT regexp_matches({ids_table['column']}::text, '{guid_regexp}', 'g') AS extracted_id " + f"FROM {ids_table['table']}" + ) + with db.engine.begin() as conn: + rs = conn.execute(sa.text(query)) + for i in rs: + for j in i[0]: + all_ids_in_tables.append({"table": ids_table["table"], "id": j}) + case _: + pass click.echo(click.style(f"Found {len(all_ids_in_tables)} file ids in tables.", fg="white")) except Exception as e: @@ -1737,59 +1741,18 @@ def file_usage( if src_filter != src: continue - if ids_table["type"] == "uuid": - # Direct UUID match - query = ( - f"SELECT {ids_table['pk_column']}, {ids_table['column']} " - f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL" - ) - with db.engine.begin() as conn: - rs = conn.execute(sa.text(query)) - for row in rs: - record_id = str(row[0]) - ref_file_id = str(row[1]) - if ref_file_id not in file_key_map: - continue - storage_key = file_key_map[ref_file_id] - - # Apply filters - if file_id and ref_file_id != file_id: - continue - if key and not storage_key.endswith(key): - continue - - # Only collect items within the requested page range - if offset <= total_count < offset + limit: - paginated_usages.append( - { - "src": f"{ids_table['table']}.{ids_table['column']}", - "record_id": record_id, - "file_id": ref_file_id, - "key": storage_key, - } - ) - total_count += 1 - - elif ids_table["type"] in ("text", "json"): - # Extract UUIDs from text/json content - column_cast = f"{ids_table['column']}::text" if ids_table["type"] == "json" else ids_table["column"] - query = ( - f"SELECT {ids_table['pk_column']}, {column_cast} " - f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL" - ) - with db.engine.begin() as conn: - rs = conn.execute(sa.text(query)) - for row in rs: - record_id = str(row[0]) - content = str(row[1]) - - # Find all UUIDs in the content - import re - - uuid_pattern = re.compile(guid_regexp, re.IGNORECASE) - matches = uuid_pattern.findall(content) - - for ref_file_id in matches: + match ids_table["type"]: + case "uuid": + # Direct UUID match + query = ( + f"SELECT {ids_table['pk_column']}, {ids_table['column']} " + f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL" + ) + with db.engine.begin() as conn: + rs = conn.execute(sa.text(query)) + for row in rs: + record_id = str(row[0]) + ref_file_id = str(row[1]) if ref_file_id not in file_key_map: continue storage_key = file_key_map[ref_file_id] @@ -1812,6 +1775,50 @@ def file_usage( ) total_count += 1 + case "text" | "json": + # Extract UUIDs from text/json content + column_cast = f"{ids_table['column']}::text" if ids_table["type"] == "json" else ids_table["column"] + query = ( + f"SELECT {ids_table['pk_column']}, {column_cast} " + f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL" + ) + with db.engine.begin() as conn: + rs = conn.execute(sa.text(query)) + for row in rs: + record_id = str(row[0]) + content = str(row[1]) + + # Find all UUIDs in the content + import re + + uuid_pattern = re.compile(guid_regexp, re.IGNORECASE) + matches = uuid_pattern.findall(content) + + for ref_file_id in matches: + if ref_file_id not in file_key_map: + continue + storage_key = file_key_map[ref_file_id] + + # Apply filters + if file_id and ref_file_id != file_id: + continue + if key and not storage_key.endswith(key): + continue + + # Only collect items within the requested page range + if offset <= total_count < offset + limit: + paginated_usages.append( + { + "src": f"{ids_table['table']}.{ids_table['column']}", + "record_id": record_id, + "file_id": ref_file_id, + "key": storage_key, + } + ) + total_count += 1 + case _: + pass + # Output results if output_json: result = { diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 8295e1739c..d97e9a0440 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -1,4 +1,3 @@ -from datetime import timedelta from enum import StrEnum from typing import Literal @@ -49,16 +48,6 @@ class SecurityConfig(BaseSettings): default=5, ) - WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS: PositiveInt = Field( - description="Maximum number of web form submissions allowed per IP within the rate limit window", - default=30, - ) - - WEB_FORM_SUBMIT_RATE_LIMIT_WINDOW_SECONDS: PositiveInt = Field( - description="Time window in seconds for web form submission rate limiting", - default=60, - ) - LOGIN_DISABLED: bool = Field( description="Whether to disable login checks", default=False, @@ -93,12 +82,6 @@ class AppExecutionConfig(BaseSettings): default=0, ) - HITL_GLOBAL_TIMEOUT_SECONDS: PositiveInt = Field( - description="Maximum seconds a workflow run can stay paused waiting for human input before global timeout.", - default=int(timedelta(days=3).total_seconds()), - ge=1, - ) - class CodeExecutionSandboxConfig(BaseSettings): """ @@ -1151,14 +1134,6 @@ class CeleryScheduleTasksConfig(BaseSettings): description="Enable queue monitor task", default=False, ) - ENABLE_HUMAN_INPUT_TIMEOUT_TASK: bool = Field( - description="Enable human input timeout check task", - default=True, - ) - HUMAN_INPUT_TIMEOUT_TASK_INTERVAL: PositiveInt = Field( - description="Human input timeout check interval in minutes", - default=1, - ) ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK: bool = Field( description="Enable check upgradable plugin task", default=True, diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py index a15e42babf..63f75924bf 100644 --- a/api/configs/middleware/__init__.py +++ b/api/configs/middleware/__init__.py @@ -6,7 +6,6 @@ from pydantic import Field, NonNegativeFloat, NonNegativeInt, PositiveFloat, Pos from pydantic_settings import BaseSettings from .cache.redis_config import RedisConfig -from .cache.redis_pubsub_config import RedisPubSubConfig from .storage.aliyun_oss_storage_config import AliyunOSSStorageConfig from .storage.amazon_s3_storage_config import S3StorageConfig from .storage.azure_blob_storage_config import AzureBlobStorageConfig @@ -318,7 +317,6 @@ class MiddlewareConfig( CeleryConfig, # Note: CeleryConfig already inherits from DatabaseConfig KeywordStoreConfig, RedisConfig, - RedisPubSubConfig, # configs of storage and storage providers StorageConfig, AliyunOSSStorageConfig, diff --git a/api/configs/middleware/cache/redis_pubsub_config.py b/api/configs/middleware/cache/redis_pubsub_config.py deleted file mode 100644 index a72e1dd28f..0000000000 --- a/api/configs/middleware/cache/redis_pubsub_config.py +++ /dev/null @@ -1,96 +0,0 @@ -from typing import Literal, Protocol -from urllib.parse import quote_plus, urlunparse - -from pydantic import Field -from pydantic_settings import BaseSettings - - -class RedisConfigDefaults(Protocol): - REDIS_HOST: str - REDIS_PORT: int - REDIS_USERNAME: str | None - REDIS_PASSWORD: str | None - REDIS_DB: int - REDIS_USE_SSL: bool - REDIS_USE_SENTINEL: bool | None - REDIS_USE_CLUSTERS: bool - - -class RedisConfigDefaultsMixin: - def _redis_defaults(self: RedisConfigDefaults) -> RedisConfigDefaults: - return self - - -class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin): - """ - Configuration settings for Redis pub/sub streaming. - """ - - PUBSUB_REDIS_URL: str | None = Field( - alias="PUBSUB_REDIS_URL", - description=( - "Redis connection URL for pub/sub streaming events between API " - "and celery worker, defaults to url constructed from " - "`REDIS_*` configurations" - ), - default=None, - ) - - PUBSUB_REDIS_USE_CLUSTERS: bool = Field( - description=( - "Enable Redis Cluster mode for pub/sub streaming. It's highly " - "recommended to enable this for large deployments." - ), - default=False, - ) - - PUBSUB_REDIS_CHANNEL_TYPE: Literal["pubsub", "sharded"] = Field( - description=( - "Pub/sub channel type for streaming events. " - "Valid options are:\n" - "\n" - " - pubsub: for normal Pub/Sub\n" - " - sharded: for sharded Pub/Sub\n" - "\n" - "It's highly recommended to use sharded Pub/Sub AND redis cluster " - "for large deployments." - ), - default="pubsub", - ) - - def _build_default_pubsub_url(self) -> str: - defaults = self._redis_defaults() - if not defaults.REDIS_HOST or not defaults.REDIS_PORT: - raise ValueError("PUBSUB_REDIS_URL must be set when default Redis URL cannot be constructed") - - scheme = "rediss" if defaults.REDIS_USE_SSL else "redis" - username = defaults.REDIS_USERNAME or None - password = defaults.REDIS_PASSWORD or None - - userinfo = "" - if username: - userinfo = quote_plus(username) - if password: - password_part = quote_plus(password) - userinfo = f"{userinfo}:{password_part}" if userinfo else f":{password_part}" - if userinfo: - userinfo = f"{userinfo}@" - - host = defaults.REDIS_HOST - port = defaults.REDIS_PORT - db = defaults.REDIS_DB - - netloc = f"{userinfo}{host}:{port}" - return urlunparse((scheme, netloc, f"/{db}", "", "", "")) - - @property - def normalized_pubsub_redis_url(self) -> str: - pubsub_redis_url = self.PUBSUB_REDIS_URL - if pubsub_redis_url: - cleaned = pubsub_redis_url.strip() - pubsub_redis_url = cleaned or None - - if pubsub_redis_url: - return pubsub_redis_url - - return self._build_default_pubsub_url() diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index 902d67174b..fdc9aabc83 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -37,7 +37,6 @@ from . import ( apikey, extension, feature, - human_input_form, init_validate, ping, setup, @@ -172,7 +171,6 @@ __all__ = [ "forgot_password", "generator", "hit_testing", - "human_input_form", "init_validate", "installed_app", "load_balancing_config", diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index 6a4c1528b0..9931bb5dd7 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -1,10 +1,11 @@ from typing import Any, Literal from flask import abort, make_response, request -from flask_restx import Resource, fields, marshal, marshal_with -from pydantic import BaseModel, Field, field_validator +from flask_restx import Resource +from pydantic import BaseModel, Field, TypeAdapter, field_validator from controllers.common.errors import NoFileUploadedError, TooManyFilesError +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import ( account_initialization_required, @@ -16,9 +17,11 @@ from controllers.console.wraps import ( ) from extensions.ext_redis import redis_client from fields.annotation_fields import ( - annotation_fields, - annotation_hit_history_fields, - build_annotation_model, + Annotation, + AnnotationExportList, + AnnotationHitHistory, + AnnotationHitHistoryList, + AnnotationList, ) from libs.helper import uuid_value from libs.login import login_required @@ -89,6 +92,14 @@ reg(CreateAnnotationPayload) reg(UpdateAnnotationPayload) reg(AnnotationReplyStatusQuery) reg(AnnotationFilePayload) +register_schema_models( + console_ns, + Annotation, + AnnotationList, + AnnotationExportList, + AnnotationHitHistory, + AnnotationHitHistoryList, +) @console_ns.route("/apps//annotation-reply/") @@ -107,10 +118,11 @@ class AnnotationReplyActionApi(Resource): def post(self, app_id, action: Literal["enable", "disable"]): app_id = str(app_id) args = AnnotationReplyPayload.model_validate(console_ns.payload) - if action == "enable": - result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id) - elif action == "disable": - result = AppAnnotationService.disable_app_annotation(app_id) + match action: + case "enable": + result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id) + case "disable": + result = AppAnnotationService.disable_app_annotation(app_id) return result, 200 @@ -201,33 +213,33 @@ class AnnotationApi(Resource): app_id = str(app_id) annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword) - response = { - "data": marshal(annotation_list, annotation_fields), - "has_more": len(annotation_list) == limit, - "limit": limit, - "total": total, - "page": page, - } - return response, 200 + annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True) + response = AnnotationList( + data=annotation_models, + has_more=len(annotation_list) == limit, + limit=limit, + total=total, + page=page, + ) + return response.model_dump(mode="json"), 200 @console_ns.doc("create_annotation") @console_ns.doc(description="Create a new annotation for an app") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[CreateAnnotationPayload.__name__]) - @console_ns.response(201, "Annotation created successfully", build_annotation_model(console_ns)) + @console_ns.response(201, "Annotation created successfully", console_ns.models[Annotation.__name__]) @console_ns.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @cloud_edition_billing_resource_check("annotation") - @marshal_with(annotation_fields) @edit_permission_required def post(self, app_id): app_id = str(app_id) args = CreateAnnotationPayload.model_validate(console_ns.payload) data = args.model_dump(exclude_none=True) annotation = AppAnnotationService.up_insert_app_annotation_from_message(data, app_id) - return annotation + return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json") @setup_required @login_required @@ -264,7 +276,7 @@ class AnnotationExportApi(Resource): @console_ns.response( 200, "Annotations exported successfully", - console_ns.model("AnnotationList", {"data": fields.List(fields.Nested(build_annotation_model(console_ns)))}), + console_ns.models[AnnotationExportList.__name__], ) @console_ns.response(403, "Insufficient permissions") @setup_required @@ -274,7 +286,8 @@ class AnnotationExportApi(Resource): def get(self, app_id): app_id = str(app_id) annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id) - response_data = {"data": marshal(annotation_list, annotation_fields)} + annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True) + response_data = AnnotationExportList(data=annotation_models).model_dump(mode="json") # Create response with secure headers for CSV export response = make_response(response_data, 200) @@ -289,7 +302,7 @@ class AnnotationUpdateDeleteApi(Resource): @console_ns.doc("update_delete_annotation") @console_ns.doc(description="Update or delete an annotation") @console_ns.doc(params={"app_id": "Application ID", "annotation_id": "Annotation ID"}) - @console_ns.response(200, "Annotation updated successfully", build_annotation_model(console_ns)) + @console_ns.response(200, "Annotation updated successfully", console_ns.models[Annotation.__name__]) @console_ns.response(204, "Annotation deleted successfully") @console_ns.response(403, "Insufficient permissions") @console_ns.expect(console_ns.models[UpdateAnnotationPayload.__name__]) @@ -298,7 +311,6 @@ class AnnotationUpdateDeleteApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("annotation") @edit_permission_required - @marshal_with(annotation_fields) def post(self, app_id, annotation_id): app_id = str(app_id) annotation_id = str(annotation_id) @@ -306,7 +318,7 @@ class AnnotationUpdateDeleteApi(Resource): annotation = AppAnnotationService.update_app_annotation_directly( args.model_dump(exclude_none=True), app_id, annotation_id ) - return annotation + return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json") @setup_required @login_required @@ -414,14 +426,7 @@ class AnnotationHitHistoryListApi(Resource): @console_ns.response( 200, "Hit histories retrieved successfully", - console_ns.model( - "AnnotationHitHistoryList", - { - "data": fields.List( - fields.Nested(console_ns.model("AnnotationHitHistoryItem", annotation_hit_history_fields)) - ) - }, - ), + console_ns.models[AnnotationHitHistoryList.__name__], ) @console_ns.response(403, "Insufficient permissions") @setup_required @@ -436,11 +441,14 @@ class AnnotationHitHistoryListApi(Resource): annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories( app_id, annotation_id, page, limit ) - response = { - "data": marshal(annotation_hit_history_list, annotation_hit_history_fields), - "has_more": len(annotation_hit_history_list) == limit, - "limit": limit, - "total": total, - "page": page, - } - return response + history_models = TypeAdapter(list[AnnotationHitHistory]).validate_python( + annotation_hit_history_list, from_attributes=True + ) + response = AnnotationHitHistoryList( + data=history_models, + has_more=len(annotation_hit_history_list) == limit, + limit=limit, + total=total, + page=page, + ) + return response.model_dump(mode="json") diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index d344ede466..941db325bf 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -6,6 +6,7 @@ from pydantic import BaseModel, Field from werkzeug.exceptions import InternalServerError import services +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.error import ( AppUnavailableError, @@ -33,7 +34,6 @@ from services.errors.audio import ( ) logger = logging.getLogger(__name__) -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" class TextToSpeechPayload(BaseModel): @@ -47,13 +47,11 @@ class TextToSpeechVoiceQuery(BaseModel): language: str = Field(..., description="Language code") -console_ns.schema_model( - TextToSpeechPayload.__name__, TextToSpeechPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) -console_ns.schema_model( - TextToSpeechVoiceQuery.__name__, - TextToSpeechVoiceQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) +class AudioTranscriptResponse(BaseModel): + text: str = Field(description="Transcribed text from audio") + + +register_schema_models(console_ns, AudioTranscriptResponse, TextToSpeechPayload, TextToSpeechVoiceQuery) @console_ns.route("/apps//audio-to-text") @@ -64,7 +62,7 @@ class ChatMessageAudioApi(Resource): @console_ns.response( 200, "Audio transcription successful", - console_ns.model("AudioTranscriptResponse", {"text": fields.String(description="Transcribed text from audio")}), + console_ns.models[AudioTranscriptResponse.__name__], ) @console_ns.response(400, "Bad request - No audio uploaded or unsupported type") @console_ns.response(413, "Audio file too large") diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index 14910c5895..82cc957d04 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -89,7 +89,6 @@ status_count_model = console_ns.model( "success": fields.Integer, "failed": fields.Integer, "partial_success": fields.Integer, - "paused": fields.Integer, }, ) @@ -509,16 +508,19 @@ class ChatConversationApi(Resource): case "created_at" | "-created_at" | _: query = query.where(Conversation.created_at <= end_datetime_utc) - if args.annotation_status == "annotated": - query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore - MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id - ) - elif args.annotation_status == "not_annotated": - query = ( - query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id) - .group_by(Conversation.id) - .having(func.count(MessageAnnotation.id) == 0) - ) + match args.annotation_status: + case "annotated": + query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore + MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id + ) + case "not_annotated": + query = ( + query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id) + .group_by(Conversation.id) + .having(func.count(MessageAnnotation.id) == 0) + ) + case "all": + pass if app_model.mode == AppMode.ADVANCED_CHAT: query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER) diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index b4fc44767a..1ac55b5e8d 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -1,5 +1,4 @@ from collections.abc import Sequence -from typing import Any from flask_restx import Resource from pydantic import BaseModel, Field @@ -12,10 +11,12 @@ from controllers.console.app.error import ( ProviderQuotaExceededError, ) from controllers.console.wraps import account_initialization_required, setup_required +from core.app.app_config.entities import ModelConfig from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.helper.code_executor.code_node_provider import CodeNodeProvider from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider +from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload from core.llm_generator.llm_generator import LLMGenerator from core.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db @@ -26,28 +27,13 @@ from services.workflow_service import WorkflowService DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" -class RuleGeneratePayload(BaseModel): - instruction: str = Field(..., description="Rule generation instruction") - model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration") - no_variable: bool = Field(default=False, description="Whether to exclude variables") - - -class RuleCodeGeneratePayload(RuleGeneratePayload): - code_language: str = Field(default="javascript", description="Programming language for code generation") - - -class RuleStructuredOutputPayload(BaseModel): - instruction: str = Field(..., description="Structured output generation instruction") - model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration") - - class InstructionGeneratePayload(BaseModel): flow_id: str = Field(..., description="Workflow/Flow ID") node_id: str = Field(default="", description="Node ID for workflow context") current: str = Field(default="", description="Current instruction text") language: str = Field(default="javascript", description="Programming language (javascript/python)") instruction: str = Field(..., description="Instruction for generation") - model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration") + model_config_data: ModelConfig = Field(..., alias="model_config", description="Model configuration") ideal_output: str = Field(default="", description="Expected ideal output") @@ -64,6 +50,7 @@ reg(RuleCodeGeneratePayload) reg(RuleStructuredOutputPayload) reg(InstructionGeneratePayload) reg(InstructionTemplatePayload) +reg(ModelConfig) @console_ns.route("/rule-generate") @@ -82,12 +69,7 @@ class RuleGenerateApi(Resource): _, current_tenant_id = current_account_with_tenant() try: - rules = LLMGenerator.generate_rule_config( - tenant_id=current_tenant_id, - instruction=args.instruction, - model_config=args.model_config_data, - no_variable=args.no_variable, - ) + rules = LLMGenerator.generate_rule_config(tenant_id=current_tenant_id, args=args) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) except QuotaExceededError: @@ -118,9 +100,7 @@ class RuleCodeGenerateApi(Resource): try: code_result = LLMGenerator.generate_code( tenant_id=current_tenant_id, - instruction=args.instruction, - model_config=args.model_config_data, - code_language=args.code_language, + args=args, ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) @@ -152,8 +132,7 @@ class RuleStructuredOutputGenerateApi(Resource): try: structured_output = LLMGenerator.generate_structured_output( tenant_id=current_tenant_id, - instruction=args.instruction, - model_config=args.model_config_data, + args=args, ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) @@ -204,23 +183,29 @@ class InstructionGenerateApi(Resource): case "llm": return LLMGenerator.generate_rule_config( current_tenant_id, - instruction=args.instruction, - model_config=args.model_config_data, - no_variable=True, + args=RuleGeneratePayload( + instruction=args.instruction, + model_config=args.model_config_data, + no_variable=True, + ), ) case "agent": return LLMGenerator.generate_rule_config( current_tenant_id, - instruction=args.instruction, - model_config=args.model_config_data, - no_variable=True, + args=RuleGeneratePayload( + instruction=args.instruction, + model_config=args.model_config_data, + no_variable=True, + ), ) case "code": return LLMGenerator.generate_code( tenant_id=current_tenant_id, - instruction=args.instruction, - model_config=args.model_config_data, - code_language=args.language, + args=RuleCodeGeneratePayload( + instruction=args.instruction, + model_config=args.model_config_data, + code_language=args.language, + ), ) case _: return {"error": f"invalid node type: {node_type}"} diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index ab1628d5d4..0be3e0ec49 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -7,6 +7,7 @@ from pydantic import BaseModel, Field, field_validator from sqlalchemy import exists, select from werkzeug.exceptions import InternalServerError, NotFound +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.error import ( CompletionRequestError, @@ -32,10 +33,9 @@ from libs.login import current_account_with_tenant, login_required from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback from services.errors.conversation import ConversationNotExistsError from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError -from services.message_service import MessageService, attach_message_extra_contents +from services.message_service import MessageService logger = logging.getLogger(__name__) -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" class ChatMessagesQuery(BaseModel): @@ -90,13 +90,22 @@ class FeedbackExportQuery(BaseModel): raise ValueError("has_comment must be a boolean value") -def reg(cls: type[BaseModel]): - console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) +class AnnotationCountResponse(BaseModel): + count: int = Field(description="Number of annotations") -reg(ChatMessagesQuery) -reg(MessageFeedbackPayload) -reg(FeedbackExportQuery) +class SuggestedQuestionsResponse(BaseModel): + data: list[str] = Field(description="Suggested question") + + +register_schema_models( + console_ns, + ChatMessagesQuery, + MessageFeedbackPayload, + FeedbackExportQuery, + AnnotationCountResponse, + SuggestedQuestionsResponse, +) # Register models for flask_restx to avoid dict type issues in Swagger # Register in dependency order: base models first, then dependent models @@ -198,7 +207,6 @@ message_detail_model = console_ns.model( "created_at": TimestampField, "agent_thoughts": fields.List(fields.Nested(agent_thought_model)), "message_files": fields.List(fields.Nested(message_file_model)), - "extra_contents": fields.List(fields.Raw), "metadata": fields.Raw(attribute="message_metadata_dict"), "status": fields.String, "error": fields.String, @@ -232,7 +240,7 @@ class ChatMessageListApi(Resource): @marshal_with(message_infinite_scroll_pagination_model) @edit_permission_required def get(self, app_model): - args = ChatMessagesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = ChatMessagesQuery.model_validate(request.args.to_dict()) conversation = ( db.session.query(Conversation) @@ -291,7 +299,6 @@ class ChatMessageListApi(Resource): has_more = False history_messages = list(reversed(history_messages)) - attach_message_extra_contents(history_messages) return InfiniteScrollPagination(data=history_messages, limit=args.limit, has_more=has_more) @@ -358,7 +365,7 @@ class MessageAnnotationCountApi(Resource): @console_ns.response( 200, "Annotation count retrieved successfully", - console_ns.model("AnnotationCountResponse", {"count": fields.Integer(description="Number of annotations")}), + console_ns.models[AnnotationCountResponse.__name__], ) @get_app_model @setup_required @@ -378,9 +385,7 @@ class MessageSuggestedQuestionApi(Resource): @console_ns.response( 200, "Suggested questions retrieved successfully", - console_ns.model( - "SuggestedQuestionsResponse", {"data": fields.List(fields.String(description="Suggested question"))} - ), + console_ns.models[SuggestedQuestionsResponse.__name__], ) @console_ns.response(404, "Message or conversation not found") @setup_required @@ -430,7 +435,7 @@ class MessageFeedbackExportApi(Resource): @login_required @account_initialization_required def get(self, app_model): - args = FeedbackExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = FeedbackExportQuery.model_validate(request.args.to_dict()) # Import the service function from services.feedback_service import FeedbackService @@ -476,5 +481,4 @@ class MessageApi(Resource): if not message: raise NotFound("Message Not Exists.") - attach_message_extra_contents([message]) return message diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 27e1d01af6..755463cb70 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -507,179 +507,6 @@ class WorkflowDraftRunLoopNodeApi(Resource): raise InternalServerError() -class HumanInputFormPreviewPayload(BaseModel): - inputs: dict[str, Any] = Field( - default_factory=dict, - description="Values used to fill missing upstream variables referenced in form_content", - ) - - -class HumanInputFormSubmitPayload(BaseModel): - form_inputs: dict[str, Any] = Field(..., description="Values the user provides for the form's own fields") - inputs: dict[str, Any] = Field( - ..., - description="Values used to fill missing upstream variables referenced in form_content", - ) - action: str = Field(..., description="Selected action ID") - - -class HumanInputDeliveryTestPayload(BaseModel): - delivery_method_id: str = Field(..., description="Delivery method ID") - inputs: dict[str, Any] = Field( - default_factory=dict, - description="Values used to fill missing upstream variables referenced in form_content", - ) - - -reg(HumanInputFormPreviewPayload) -reg(HumanInputFormSubmitPayload) -reg(HumanInputDeliveryTestPayload) - - -@console_ns.route("/apps//advanced-chat/workflows/draft/human-input/nodes//form/preview") -class AdvancedChatDraftHumanInputFormPreviewApi(Resource): - @console_ns.doc("get_advanced_chat_draft_human_input_form") - @console_ns.doc(description="Get human input form preview for advanced chat workflow") - @console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"}) - @console_ns.expect(console_ns.models[HumanInputFormPreviewPayload.__name__]) - @setup_required - @login_required - @account_initialization_required - @get_app_model(mode=[AppMode.ADVANCED_CHAT]) - @edit_permission_required - def post(self, app_model: App, node_id: str): - """ - Preview human input form content and placeholders - """ - current_user, _ = current_account_with_tenant() - args = HumanInputFormPreviewPayload.model_validate(console_ns.payload or {}) - inputs = args.inputs - - workflow_service = WorkflowService() - preview = workflow_service.get_human_input_form_preview( - app_model=app_model, - account=current_user, - node_id=node_id, - inputs=inputs, - ) - return jsonable_encoder(preview) - - -@console_ns.route("/apps//advanced-chat/workflows/draft/human-input/nodes//form/run") -class AdvancedChatDraftHumanInputFormRunApi(Resource): - @console_ns.doc("submit_advanced_chat_draft_human_input_form") - @console_ns.doc(description="Submit human input form preview for advanced chat workflow") - @console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"}) - @console_ns.expect(console_ns.models[HumanInputFormSubmitPayload.__name__]) - @setup_required - @login_required - @account_initialization_required - @get_app_model(mode=[AppMode.ADVANCED_CHAT]) - @edit_permission_required - def post(self, app_model: App, node_id: str): - """ - Submit human input form preview - """ - current_user, _ = current_account_with_tenant() - args = HumanInputFormSubmitPayload.model_validate(console_ns.payload or {}) - workflow_service = WorkflowService() - result = workflow_service.submit_human_input_form_preview( - app_model=app_model, - account=current_user, - node_id=node_id, - form_inputs=args.form_inputs, - inputs=args.inputs, - action=args.action, - ) - return jsonable_encoder(result) - - -@console_ns.route("/apps//workflows/draft/human-input/nodes//form/preview") -class WorkflowDraftHumanInputFormPreviewApi(Resource): - @console_ns.doc("get_workflow_draft_human_input_form") - @console_ns.doc(description="Get human input form preview for workflow") - @console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"}) - @console_ns.expect(console_ns.models[HumanInputFormPreviewPayload.__name__]) - @setup_required - @login_required - @account_initialization_required - @get_app_model(mode=[AppMode.WORKFLOW]) - @edit_permission_required - def post(self, app_model: App, node_id: str): - """ - Preview human input form content and placeholders - """ - current_user, _ = current_account_with_tenant() - args = HumanInputFormPreviewPayload.model_validate(console_ns.payload or {}) - inputs = args.inputs - - workflow_service = WorkflowService() - preview = workflow_service.get_human_input_form_preview( - app_model=app_model, - account=current_user, - node_id=node_id, - inputs=inputs, - ) - return jsonable_encoder(preview) - - -@console_ns.route("/apps//workflows/draft/human-input/nodes//form/run") -class WorkflowDraftHumanInputFormRunApi(Resource): - @console_ns.doc("submit_workflow_draft_human_input_form") - @console_ns.doc(description="Submit human input form preview for workflow") - @console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"}) - @console_ns.expect(console_ns.models[HumanInputFormSubmitPayload.__name__]) - @setup_required - @login_required - @account_initialization_required - @get_app_model(mode=[AppMode.WORKFLOW]) - @edit_permission_required - def post(self, app_model: App, node_id: str): - """ - Submit human input form preview - """ - current_user, _ = current_account_with_tenant() - workflow_service = WorkflowService() - args = HumanInputFormSubmitPayload.model_validate(console_ns.payload or {}) - result = workflow_service.submit_human_input_form_preview( - app_model=app_model, - account=current_user, - node_id=node_id, - form_inputs=args.form_inputs, - inputs=args.inputs, - action=args.action, - ) - return jsonable_encoder(result) - - -@console_ns.route("/apps//workflows/draft/human-input/nodes//delivery-test") -class WorkflowDraftHumanInputDeliveryTestApi(Resource): - @console_ns.doc("test_workflow_draft_human_input_delivery") - @console_ns.doc(description="Test human input delivery for workflow") - @console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"}) - @console_ns.expect(console_ns.models[HumanInputDeliveryTestPayload.__name__]) - @setup_required - @login_required - @account_initialization_required - @get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT]) - @edit_permission_required - def post(self, app_model: App, node_id: str): - """ - Test human input delivery - """ - current_user, _ = current_account_with_tenant() - workflow_service = WorkflowService() - args = HumanInputDeliveryTestPayload.model_validate(console_ns.payload or {}) - workflow_service.test_human_input_delivery( - app_model=app_model, - account=current_user, - node_id=node_id, - delivery_method_id=args.delivery_method_id, - inputs=args.inputs, - ) - return jsonable_encoder({}) - - @console_ns.route("/apps//workflows/draft/run") class DraftWorkflowRunApi(Resource): @console_ns.doc("run_draft_workflow") diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py index d9a5dde55a..fa74f8aea1 100644 --- a/api/controllers/console/app/workflow_run.py +++ b/api/controllers/console/app/workflow_run.py @@ -5,15 +5,10 @@ from flask import request from flask_restx import Resource, fields, marshal_with from pydantic import BaseModel, Field, field_validator from sqlalchemy import select -from sqlalchemy.orm import sessionmaker -from configs import dify_config from controllers.console import console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required -from controllers.web.error import NotFoundError -from core.workflow.entities.pause_reason import HumanInputRequired -from core.workflow.enums import WorkflowExecutionStatus from extensions.ext_database import db from fields.end_user_fields import simple_end_user_fields from fields.member_fields import simple_account_fields @@ -32,21 +27,9 @@ from libs.custom_inputs import time_duration from libs.helper import uuid_value from libs.login import current_user, login_required from models import Account, App, AppMode, EndUser, WorkflowArchiveLog, WorkflowRunTriggeredFrom -from models.workflow import WorkflowRun -from repositories.factory import DifyAPIRepositoryFactory from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME from services.workflow_run_service import WorkflowRunService - -def _build_backstage_input_url(form_token: str | None) -> str | None: - if not form_token: - return None - base_url = dify_config.APP_WEB_URL - if not base_url: - return None - return f"{base_url.rstrip('/')}/form/{form_token}" - - # Workflow run status choices for filtering WORKFLOW_RUN_STATUS_CHOICES = ["running", "succeeded", "failed", "stopped", "partial-succeeded"] EXPORT_SIGNED_URL_EXPIRE_SECONDS = 3600 @@ -457,63 +440,3 @@ class WorkflowRunNodeExecutionListApi(Resource): ) return {"data": node_executions} - - -@console_ns.route("/workflow//pause-details") -class ConsoleWorkflowPauseDetailsApi(Resource): - """Console API for getting workflow pause details.""" - - @account_initialization_required - @login_required - def get(self, workflow_run_id: str): - """ - Get workflow pause details. - - GET /console/api/workflow//pause-details - - Returns information about why and where the workflow is paused. - """ - - # Query WorkflowRun to determine if workflow is suspended - session_maker = sessionmaker(bind=db.engine) - workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker=session_maker) - workflow_run = db.session.get(WorkflowRun, workflow_run_id) - if not workflow_run: - raise NotFoundError("Workflow run not found") - - # Check if workflow is suspended - is_paused = workflow_run.status == WorkflowExecutionStatus.PAUSED - if not is_paused: - return { - "paused_at": None, - "paused_nodes": [], - }, 200 - - pause_entity = workflow_run_repo.get_workflow_pause(workflow_run_id) - pause_reasons = pause_entity.get_pause_reasons() if pause_entity else [] - - # Build response - paused_at = pause_entity.paused_at if pause_entity else None - paused_nodes = [] - response = { - "paused_at": paused_at.isoformat() + "Z" if paused_at else None, - "paused_nodes": paused_nodes, - } - - for reason in pause_reasons: - if isinstance(reason, HumanInputRequired): - paused_nodes.append( - { - "node_id": reason.node_id, - "node_title": reason.node_title, - "pause_type": { - "type": "human_input", - "form_id": reason.form_id, - "backstage_input_url": _build_backstage_input_url(reason.form_token), - }, - } - ) - else: - raise AssertionError("unimplemented.") - - return response, 200 diff --git a/api/controllers/console/auth/data_source_oauth.py b/api/controllers/console/auth/data_source_oauth.py index 0dd7d33ae9..3a3278ec9d 100644 --- a/api/controllers/console/auth/data_source_oauth.py +++ b/api/controllers/console/auth/data_source_oauth.py @@ -2,9 +2,11 @@ import logging import httpx from flask import current_app, redirect, request -from flask_restx import Resource, fields +from flask_restx import Resource +from pydantic import BaseModel, Field from configs import dify_config +from controllers.common.schema import register_schema_models from libs.login import login_required from libs.oauth_data_source import NotionOAuth @@ -14,6 +16,26 @@ from ..wraps import account_initialization_required, is_admin_or_owner_required, logger = logging.getLogger(__name__) +class OAuthDataSourceResponse(BaseModel): + data: str = Field(description="Authorization URL or 'internal' for internal setup") + + +class OAuthDataSourceBindingResponse(BaseModel): + result: str = Field(description="Operation result") + + +class OAuthDataSourceSyncResponse(BaseModel): + result: str = Field(description="Operation result") + + +register_schema_models( + console_ns, + OAuthDataSourceResponse, + OAuthDataSourceBindingResponse, + OAuthDataSourceSyncResponse, +) + + def get_oauth_providers(): with current_app.app_context(): notion_oauth = NotionOAuth( @@ -34,10 +56,7 @@ class OAuthDataSource(Resource): @console_ns.response( 200, "Authorization URL or internal setup success", - console_ns.model( - "OAuthDataSourceResponse", - {"data": fields.Raw(description="Authorization URL or 'internal' for internal setup")}, - ), + console_ns.models[OAuthDataSourceResponse.__name__], ) @console_ns.response(400, "Invalid provider") @console_ns.response(403, "Admin privileges required") @@ -101,7 +120,7 @@ class OAuthDataSourceBinding(Resource): @console_ns.response( 200, "Data source binding success", - console_ns.model("OAuthDataSourceBindingResponse", {"result": fields.String(description="Operation result")}), + console_ns.models[OAuthDataSourceBindingResponse.__name__], ) @console_ns.response(400, "Invalid provider or code") def get(self, provider: str): @@ -133,7 +152,7 @@ class OAuthDataSourceSync(Resource): @console_ns.response( 200, "Data source sync success", - console_ns.model("OAuthDataSourceSyncResponse", {"result": fields.String(description="Operation result")}), + console_ns.models[OAuthDataSourceSyncResponse.__name__], ) @console_ns.response(400, "Invalid provider or sync failed") @setup_required diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index 394f205d93..1ed931b0d7 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -2,10 +2,11 @@ import base64 import secrets from flask import request -from flask_restx import Resource, fields +from flask_restx import Resource from pydantic import BaseModel, Field, field_validator from sqlalchemy.orm import Session +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.auth.error import ( EmailCodeError, @@ -48,8 +49,31 @@ class ForgotPasswordResetPayload(BaseModel): return valid_password(value) -for model in (ForgotPasswordSendPayload, ForgotPasswordCheckPayload, ForgotPasswordResetPayload): - console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) +class ForgotPasswordEmailResponse(BaseModel): + result: str = Field(description="Operation result") + data: str | None = Field(default=None, description="Reset token") + code: str | None = Field(default=None, description="Error code if account not found") + + +class ForgotPasswordCheckResponse(BaseModel): + is_valid: bool = Field(description="Whether code is valid") + email: EmailStr = Field(description="Email address") + token: str = Field(description="New reset token") + + +class ForgotPasswordResetResponse(BaseModel): + result: str = Field(description="Operation result") + + +register_schema_models( + console_ns, + ForgotPasswordSendPayload, + ForgotPasswordCheckPayload, + ForgotPasswordResetPayload, + ForgotPasswordEmailResponse, + ForgotPasswordCheckResponse, + ForgotPasswordResetResponse, +) @console_ns.route("/forgot-password") @@ -60,14 +84,7 @@ class ForgotPasswordSendEmailApi(Resource): @console_ns.response( 200, "Email sent successfully", - console_ns.model( - "ForgotPasswordEmailResponse", - { - "result": fields.String(description="Operation result"), - "data": fields.String(description="Reset token"), - "code": fields.String(description="Error code if account not found"), - }, - ), + console_ns.models[ForgotPasswordEmailResponse.__name__], ) @console_ns.response(400, "Invalid email or rate limit exceeded") @setup_required @@ -106,14 +123,7 @@ class ForgotPasswordCheckApi(Resource): @console_ns.response( 200, "Code verified successfully", - console_ns.model( - "ForgotPasswordCheckResponse", - { - "is_valid": fields.Boolean(description="Whether code is valid"), - "email": fields.String(description="Email address"), - "token": fields.String(description="New reset token"), - }, - ), + console_ns.models[ForgotPasswordCheckResponse.__name__], ) @console_ns.response(400, "Invalid code or token") @setup_required @@ -163,7 +173,7 @@ class ForgotPasswordResetApi(Resource): @console_ns.response( 200, "Password reset successfully", - console_ns.model("ForgotPasswordResetResponse", {"result": fields.String(description="Operation result")}), + console_ns.models[ForgotPasswordResetResponse.__name__], ) @console_ns.response(400, "Invalid token or password mismatch") @setup_required diff --git a/api/controllers/console/auth/oauth_server.py b/api/controllers/console/auth/oauth_server.py index 6162d88a0b..38ea5d2dae 100644 --- a/api/controllers/console/auth/oauth_server.py +++ b/api/controllers/console/auth/oauth_server.py @@ -155,43 +155,43 @@ class OAuthServerUserTokenApi(Resource): grant_type = OAuthGrantType(payload.grant_type) except ValueError: raise BadRequest("invalid grant_type") + match grant_type: + case OAuthGrantType.AUTHORIZATION_CODE: + if not payload.code: + raise BadRequest("code is required") - if grant_type == OAuthGrantType.AUTHORIZATION_CODE: - if not payload.code: - raise BadRequest("code is required") + if payload.client_secret != oauth_provider_app.client_secret: + raise BadRequest("client_secret is invalid") - if payload.client_secret != oauth_provider_app.client_secret: - raise BadRequest("client_secret is invalid") + if payload.redirect_uri not in oauth_provider_app.redirect_uris: + raise BadRequest("redirect_uri is invalid") - if payload.redirect_uri not in oauth_provider_app.redirect_uris: - raise BadRequest("redirect_uri is invalid") + access_token, refresh_token = OAuthServerService.sign_oauth_access_token( + grant_type, code=payload.code, client_id=oauth_provider_app.client_id + ) + return jsonable_encoder( + { + "access_token": access_token, + "token_type": "Bearer", + "expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN, + "refresh_token": refresh_token, + } + ) + case OAuthGrantType.REFRESH_TOKEN: + if not payload.refresh_token: + raise BadRequest("refresh_token is required") - access_token, refresh_token = OAuthServerService.sign_oauth_access_token( - grant_type, code=payload.code, client_id=oauth_provider_app.client_id - ) - return jsonable_encoder( - { - "access_token": access_token, - "token_type": "Bearer", - "expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN, - "refresh_token": refresh_token, - } - ) - elif grant_type == OAuthGrantType.REFRESH_TOKEN: - if not payload.refresh_token: - raise BadRequest("refresh_token is required") - - access_token, refresh_token = OAuthServerService.sign_oauth_access_token( - grant_type, refresh_token=payload.refresh_token, client_id=oauth_provider_app.client_id - ) - return jsonable_encoder( - { - "access_token": access_token, - "token_type": "Bearer", - "expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN, - "refresh_token": refresh_token, - } - ) + access_token, refresh_token = OAuthServerService.sign_oauth_access_token( + grant_type, refresh_token=payload.refresh_token, client_id=oauth_provider_app.client_id + ) + return jsonable_encoder( + { + "access_token": access_token, + "token_type": "Bearer", + "expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN, + "refresh_token": refresh_token, + } + ) @console_ns.route("/oauth/provider/account") diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index 01e9bf77c0..daef4e005a 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -1,6 +1,6 @@ import json from collections.abc import Generator -from typing import Any, cast +from typing import Any, Literal, cast from flask import request from flask_restx import Resource, fields, marshal_with @@ -157,9 +157,8 @@ class DataSourceApi(Resource): @setup_required @login_required @account_initialization_required - def patch(self, binding_id, action): + def patch(self, binding_id, action: Literal["enable", "disable"]): binding_id = str(binding_id) - action = str(action) with Session(db.engine) as session: data_source_binding = session.execute( select(DataSourceOauthBinding).filter_by(id=binding_id) @@ -167,23 +166,24 @@ class DataSourceApi(Resource): if data_source_binding is None: raise NotFound("Data source binding not found.") # enable binding - if action == "enable": - if data_source_binding.disabled: - data_source_binding.disabled = False - data_source_binding.updated_at = naive_utc_now() - db.session.add(data_source_binding) - db.session.commit() - else: - raise ValueError("Data source is not disabled.") - # disable binding - if action == "disable": - if not data_source_binding.disabled: - data_source_binding.disabled = True - data_source_binding.updated_at = naive_utc_now() - db.session.add(data_source_binding) - db.session.commit() - else: - raise ValueError("Data source is disabled.") + match action: + case "enable": + if data_source_binding.disabled: + data_source_binding.disabled = False + data_source_binding.updated_at = naive_utc_now() + db.session.add(data_source_binding) + db.session.commit() + else: + raise ValueError("Data source is not disabled.") + # disable binding + case "disable": + if not data_source_binding.disabled: + data_source_binding.disabled = True + data_source_binding.updated_at = naive_utc_now() + db.session.add(data_source_binding) + db.session.commit() + else: + raise ValueError("Data source is disabled.") return {"result": "success"}, 200 diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 6e3c0db8a3..bf097d374a 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -576,63 +576,62 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): if document.indexing_status in {"completed", "error"}: raise DocumentAlreadyFinishedError() data_source_info = document.data_source_info_dict + match document.data_source_type: + case "upload_file": + if not data_source_info: + continue + file_id = data_source_info["upload_file_id"] + file_detail = ( + db.session.query(UploadFile) + .where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id) + .first() + ) - if document.data_source_type == "upload_file": - if not data_source_info: - continue - file_id = data_source_info["upload_file_id"] - file_detail = ( - db.session.query(UploadFile) - .where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id) - .first() - ) + if file_detail is None: + raise NotFound("File not found.") - if file_detail is None: - raise NotFound("File not found.") + extract_setting = ExtractSetting( + datasource_type=DatasourceType.FILE, upload_file=file_detail, document_model=document.doc_form + ) + extract_settings.append(extract_setting) + case "notion_import": + if not data_source_info: + continue + extract_setting = ExtractSetting( + datasource_type=DatasourceType.NOTION, + notion_info=NotionInfo.model_validate( + { + "credential_id": data_source_info.get("credential_id"), + "notion_workspace_id": data_source_info["notion_workspace_id"], + "notion_obj_id": data_source_info["notion_page_id"], + "notion_page_type": data_source_info["type"], + "tenant_id": current_tenant_id, + } + ), + document_model=document.doc_form, + ) + extract_settings.append(extract_setting) + case "website_crawl": + if not data_source_info: + continue + extract_setting = ExtractSetting( + datasource_type=DatasourceType.WEBSITE, + website_info=WebsiteInfo.model_validate( + { + "provider": data_source_info["provider"], + "job_id": data_source_info["job_id"], + "url": data_source_info["url"], + "tenant_id": current_tenant_id, + "mode": data_source_info["mode"], + "only_main_content": data_source_info["only_main_content"], + } + ), + document_model=document.doc_form, + ) + extract_settings.append(extract_setting) - extract_setting = ExtractSetting( - datasource_type=DatasourceType.FILE, upload_file=file_detail, document_model=document.doc_form - ) - extract_settings.append(extract_setting) - - elif document.data_source_type == "notion_import": - if not data_source_info: - continue - extract_setting = ExtractSetting( - datasource_type=DatasourceType.NOTION, - notion_info=NotionInfo.model_validate( - { - "credential_id": data_source_info.get("credential_id"), - "notion_workspace_id": data_source_info["notion_workspace_id"], - "notion_obj_id": data_source_info["notion_page_id"], - "notion_page_type": data_source_info["type"], - "tenant_id": current_tenant_id, - } - ), - document_model=document.doc_form, - ) - extract_settings.append(extract_setting) - elif document.data_source_type == "website_crawl": - if not data_source_info: - continue - extract_setting = ExtractSetting( - datasource_type=DatasourceType.WEBSITE, - website_info=WebsiteInfo.model_validate( - { - "provider": data_source_info["provider"], - "job_id": data_source_info["job_id"], - "url": data_source_info["url"], - "tenant_id": current_tenant_id, - "mode": data_source_info["mode"], - "only_main_content": data_source_info["only_main_content"], - } - ), - document_model=document.doc_form, - ) - extract_settings.append(extract_setting) - - else: - raise ValueError("Data source type not support") + case _: + raise ValueError("Data source type not support") indexing_runner = IndexingRunner() try: response = indexing_runner.indexing_estimate( @@ -954,23 +953,24 @@ class DocumentProcessingApi(DocumentResource): if not current_user.is_dataset_editor: raise Forbidden() - if action == "pause": - if document.indexing_status != "indexing": - raise InvalidActionError("Document not in indexing state.") + match action: + case "pause": + if document.indexing_status != "indexing": + raise InvalidActionError("Document not in indexing state.") - document.paused_by = current_user.id - document.paused_at = naive_utc_now() - document.is_paused = True - db.session.commit() + document.paused_by = current_user.id + document.paused_at = naive_utc_now() + document.is_paused = True + db.session.commit() - elif action == "resume": - if document.indexing_status not in {"paused", "error"}: - raise InvalidActionError("Document not in paused or error state.") + case "resume": + if document.indexing_status not in {"paused", "error"}: + raise InvalidActionError("Document not in paused or error state.") - document.paused_by = None - document.paused_at = None - document.is_paused = False - db.session.commit() + document.paused_by = None + document.paused_at = None + document.is_paused = False + db.session.commit() return {"result": "success"}, 200 @@ -1339,6 +1339,18 @@ class DocumentGenerateSummaryApi(Resource): missing_ids = set(document_list) - found_ids raise NotFound(f"Some documents not found: {list(missing_ids)}") + # Update need_summary to True for documents that don't have it set + # This handles the case where documents were created when summary_index_setting was disabled + documents_to_update = [doc for doc in documents if not doc.need_summary and doc.doc_form != "qa_model"] + + if documents_to_update: + document_ids_to_update = [str(doc.id) for doc in documents_to_update] + DocumentService.update_documents_need_summary( + dataset_id=dataset_id, + document_ids=document_ids_to_update, + need_summary=True, + ) + # Dispatch async tasks for each document for document in documents: # Skip qa_model documents as they don't generate summaries diff --git a/api/controllers/console/datasets/metadata.py b/api/controllers/console/datasets/metadata.py index 05fc4cd714..2e69ddc5ab 100644 --- a/api/controllers/console/datasets/metadata.py +++ b/api/controllers/console/datasets/metadata.py @@ -126,10 +126,11 @@ class DatasetMetadataBuiltInFieldActionApi(Resource): raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user) - if action == "enable": - MetadataService.enable_built_in_field(dataset) - elif action == "disable": - MetadataService.disable_built_in_field(dataset) + match action: + case "enable": + MetadataService.enable_built_in_field(dataset) + case "disable": + MetadataService.disable_built_in_field(dataset) return {"result": "success"}, 200 diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index d34fd5088d..29b6b64b94 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -1,10 +1,9 @@ import json import logging from typing import Any, Literal, cast -from uuid import UUID from flask import abort, request -from flask_restx import Resource, marshal_with, reqparse # type: ignore +from flask_restx import Resource, marshal_with # type: ignore from pydantic import BaseModel, Field from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden, InternalServerError, NotFound @@ -38,7 +37,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_database import db from factories import variable_factory from libs import helper -from libs.helper import TimestampField +from libs.helper import TimestampField, UUIDStrOrEmpty from libs.login import current_account_with_tenant, current_user, login_required from models import Account from models.dataset import Pipeline @@ -110,7 +109,7 @@ class NodeIdQuery(BaseModel): class WorkflowRunQuery(BaseModel): - last_id: UUID | None = None + last_id: UUIDStrOrEmpty | None = None limit: int = Field(default=20, ge=1, le=100) @@ -121,6 +120,10 @@ class DatasourceVariablesPayload(BaseModel): start_node_title: str +class RagPipelineRecommendedPluginQuery(BaseModel): + type: str = "all" + + register_schema_models( console_ns, DraftWorkflowSyncPayload, @@ -135,6 +138,7 @@ register_schema_models( NodeIdQuery, WorkflowRunQuery, DatasourceVariablesPayload, + RagPipelineRecommendedPluginQuery, ) @@ -975,11 +979,8 @@ class RagPipelineRecommendedPluginApi(Resource): @login_required @account_initialization_required def get(self): - parser = reqparse.RequestParser() - parser.add_argument("type", type=str, location="args", required=False, default="all") - args = parser.parse_args() - type = args["type"] + query = RagPipelineRecommendedPluginQuery.model_validate(request.args.to_dict()) rag_pipeline_service = RagPipelineService() - recommended_plugins = rag_pipeline_service.get_recommended_plugins(type) + recommended_plugins = rag_pipeline_service.get_recommended_plugins(query.type) return recommended_plugins diff --git a/api/controllers/console/explore/trial.py b/api/controllers/console/explore/trial.py index 1eb0cdb019..cd523b481c 100644 --- a/api/controllers/console/explore/trial.py +++ b/api/controllers/console/explore/trial.py @@ -9,7 +9,7 @@ import services from controllers.common.fields import Parameters as ParametersResponse from controllers.common.fields import Site as SiteResponse from controllers.common.schema import get_or_create_model -from controllers.console import api, console_ns +from controllers.console import api from controllers.console.app.error import ( AppUnavailableError, AudioTooLargeError, @@ -51,7 +51,7 @@ from fields.app_fields import ( tag_fields, ) from fields.dataset_fields import dataset_fields -from fields.member_fields import build_simple_account_model +from fields.member_fields import simple_account_fields from fields.workflow_fields import ( conversation_variable_fields, pipeline_variable_fields, @@ -103,7 +103,7 @@ app_detail_fields_with_site_copy["tags"] = fields.List(fields.Nested(tag_model)) app_detail_fields_with_site_copy["site"] = fields.Nested(site_model) app_detail_with_site_model = get_or_create_model("TrialAppDetailWithSite", app_detail_fields_with_site_copy) -simple_account_model = build_simple_account_model(console_ns) +simple_account_model = get_or_create_model("SimpleAccount", simple_account_fields) conversation_variable_model = get_or_create_model("TrialConversationVariable", conversation_variable_fields) pipeline_variable_model = get_or_create_model("TrialPipelineVariable", pipeline_variable_fields) diff --git a/api/controllers/console/human_input_form.py b/api/controllers/console/human_input_form.py deleted file mode 100644 index 7207f7fd1d..0000000000 --- a/api/controllers/console/human_input_form.py +++ /dev/null @@ -1,217 +0,0 @@ -""" -Console/Studio Human Input Form APIs. -""" - -import json -import logging -from collections.abc import Generator - -from flask import Response, jsonify, request -from flask_restx import Resource, reqparse -from sqlalchemy import select -from sqlalchemy.orm import Session, sessionmaker - -from controllers.console import console_ns -from controllers.console.wraps import account_initialization_required, setup_required -from controllers.web.error import InvalidArgumentError, NotFoundError -from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator -from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter -from core.app.apps.message_generator import MessageGenerator -from core.app.apps.workflow.app_generator import WorkflowAppGenerator -from extensions.ext_database import db -from libs.login import current_account_with_tenant, login_required -from models import App -from models.enums import CreatorUserRole -from models.human_input import RecipientType -from models.model import AppMode -from models.workflow import WorkflowRun -from repositories.factory import DifyAPIRepositoryFactory -from services.human_input_service import Form, HumanInputService -from services.workflow_event_snapshot_service import build_workflow_event_stream - -logger = logging.getLogger(__name__) - - -def _jsonify_form_definition(form: Form) -> Response: - payload = form.get_definition().model_dump() - payload["expiration_time"] = int(form.expiration_time.timestamp()) - return Response(json.dumps(payload, ensure_ascii=False), mimetype="application/json") - - -@console_ns.route("/form/human_input/") -class ConsoleHumanInputFormApi(Resource): - """Console API for getting human input form definition.""" - - @staticmethod - def _ensure_console_access(form: Form): - _, current_tenant_id = current_account_with_tenant() - - if form.tenant_id != current_tenant_id: - raise NotFoundError("App not found") - - @setup_required - @login_required - @account_initialization_required - def get(self, form_token: str): - """ - Get human input form definition by form token. - - GET /console/api/form/human_input/ - """ - service = HumanInputService(db.engine) - form = service.get_form_definition_by_token_for_console(form_token) - if form is None: - raise NotFoundError(f"form not found, token={form_token}") - - self._ensure_console_access(form) - - return _jsonify_form_definition(form) - - @account_initialization_required - @login_required - def post(self, form_token: str): - """ - Submit human input form by form token. - - POST /console/api/form/human_input/ - - Request body: - { - "inputs": { - "content": "User input content" - }, - "action": "Approve" - } - """ - parser = reqparse.RequestParser() - parser.add_argument("inputs", type=dict, required=True, location="json") - parser.add_argument("action", type=str, required=True, location="json") - args = parser.parse_args() - current_user, _ = current_account_with_tenant() - - service = HumanInputService(db.engine) - form = service.get_form_by_token(form_token) - if form is None: - raise NotFoundError(f"form not found, token={form_token}") - - self._ensure_console_access(form) - - recipient_type = form.recipient_type - if recipient_type not in {RecipientType.CONSOLE, RecipientType.BACKSTAGE}: - raise NotFoundError(f"form not found, token={form_token}") - # The type checker is not smart enought to validate the following invariant. - # So we need to assert it manually. - assert recipient_type is not None, "recipient_type cannot be None here." - - service.submit_form_by_token( - recipient_type=recipient_type, - form_token=form_token, - selected_action_id=args["action"], - form_data=args["inputs"], - submission_user_id=current_user.id, - ) - - return jsonify({}) - - -@console_ns.route("/workflow//events") -class ConsoleWorkflowEventsApi(Resource): - """Console API for getting workflow execution events after resume.""" - - @account_initialization_required - @login_required - def get(self, workflow_run_id: str): - """ - Get workflow execution events stream after resume. - - GET /console/api/workflow//events - - Returns Server-Sent Events stream. - """ - - user, tenant_id = current_account_with_tenant() - session_maker = sessionmaker(db.engine) - repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) - workflow_run = repo.get_workflow_run_by_id_and_tenant_id( - tenant_id=tenant_id, - run_id=workflow_run_id, - ) - if workflow_run is None: - raise NotFoundError(f"WorkflowRun not found, id={workflow_run_id}") - - if workflow_run.created_by_role != CreatorUserRole.ACCOUNT: - raise NotFoundError(f"WorkflowRun not created by account, id={workflow_run_id}") - - if workflow_run.created_by != user.id: - raise NotFoundError(f"WorkflowRun not created by the current account, id={workflow_run_id}") - - with Session(expire_on_commit=False, bind=db.engine) as session: - app = _retrieve_app_for_workflow_run(session, workflow_run) - - if workflow_run.finished_at is not None: - # TODO(QuantumGhost): should we modify the handling for finished workflow run here? - response = WorkflowResponseConverter.workflow_run_result_to_finish_response( - task_id=workflow_run.id, - workflow_run=workflow_run, - creator_user=user, - ) - - payload = response.model_dump(mode="json") - payload["event"] = response.event.value - - def _generate_finished_events() -> Generator[str, None, None]: - yield f"data: {json.dumps(payload)}\n\n" - - event_generator = _generate_finished_events - - else: - msg_generator = MessageGenerator() - if app.mode == AppMode.ADVANCED_CHAT: - generator = AdvancedChatAppGenerator() - elif app.mode == AppMode.WORKFLOW: - generator = WorkflowAppGenerator() - else: - raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}") - - include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true" - - def _generate_stream_events(): - if include_state_snapshot: - return generator.convert_to_event_stream( - build_workflow_event_stream( - app_mode=AppMode(app.mode), - workflow_run=workflow_run, - tenant_id=workflow_run.tenant_id, - app_id=workflow_run.app_id, - session_maker=session_maker, - ) - ) - return generator.convert_to_event_stream( - msg_generator.retrieve_events(AppMode(app.mode), workflow_run.id), - ) - - event_generator = _generate_stream_events - - return Response( - event_generator(), - mimetype="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - }, - ) - - -def _retrieve_app_for_workflow_run(session: Session, workflow_run: WorkflowRun): - query = select(App).where( - App.id == workflow_run.app_id, - App.tenant_id == workflow_run.tenant_id, - ) - app = session.scalars(query).first() - if app is None: - raise AssertionError( - f"App not found for WorkflowRun, workflow_run_id={workflow_run.id}, " - f"app_id={workflow_run.app_id}, tenant_id={workflow_run.tenant_id}" - ) - - return app diff --git a/api/controllers/console/init_validate.py b/api/controllers/console/init_validate.py index 2bebe79eac..f086bf1862 100644 --- a/api/controllers/console/init_validate.py +++ b/api/controllers/console/init_validate.py @@ -1,87 +1,74 @@ import os +from typing import Literal from flask import session -from flask_restx import Resource, fields from pydantic import BaseModel, Field from sqlalchemy import select from sqlalchemy.orm import Session from configs import dify_config +from controllers.fastopenapi import console_router from extensions.ext_database import db from models.model import DifySetup from services.account_service import TenantService -from . import console_ns from .error import AlreadySetupError, InitValidateFailedError from .wraps import only_edition_self_hosted -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class InitValidatePayload(BaseModel): - password: str = Field(..., max_length=30) + password: str = Field(..., max_length=30, description="Initialization password") -console_ns.schema_model( - InitValidatePayload.__name__, - InitValidatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +class InitStatusResponse(BaseModel): + status: Literal["finished", "not_started"] = Field(..., description="Initialization status") + + +class InitValidateResponse(BaseModel): + result: str = Field(description="Operation result", examples=["success"]) + + +@console_router.get( + "/init", + response_model=InitStatusResponse, + tags=["console"], ) +def get_init_status() -> InitStatusResponse: + """Get initialization validation status.""" + init_status = get_init_validate_status() + if init_status: + return InitStatusResponse(status="finished") + return InitStatusResponse(status="not_started") -@console_ns.route("/init") -class InitValidateAPI(Resource): - @console_ns.doc("get_init_status") - @console_ns.doc(description="Get initialization validation status") - @console_ns.response( - 200, - "Success", - model=console_ns.model( - "InitStatusResponse", - {"status": fields.String(description="Initialization status", enum=["finished", "not_started"])}, - ), - ) - def get(self): - """Get initialization validation status""" - init_status = get_init_validate_status() - if init_status: - return {"status": "finished"} - return {"status": "not_started"} +@console_router.post( + "/init", + response_model=InitValidateResponse, + tags=["console"], + status_code=201, +) +@only_edition_self_hosted +def validate_init_password(payload: InitValidatePayload) -> InitValidateResponse: + """Validate initialization password.""" + tenant_count = TenantService.get_tenant_count() + if tenant_count > 0: + raise AlreadySetupError() - @console_ns.doc("validate_init_password") - @console_ns.doc(description="Validate initialization password for self-hosted edition") - @console_ns.expect(console_ns.models[InitValidatePayload.__name__]) - @console_ns.response( - 201, - "Success", - model=console_ns.model("InitValidateResponse", {"result": fields.String(description="Operation result")}), - ) - @console_ns.response(400, "Already setup or validation failed") - @only_edition_self_hosted - def post(self): - """Validate initialization password""" - # is tenant created - tenant_count = TenantService.get_tenant_count() - if tenant_count > 0: - raise AlreadySetupError() + if payload.password != os.environ.get("INIT_PASSWORD"): + session["is_init_validated"] = False + raise InitValidateFailedError() - payload = InitValidatePayload.model_validate(console_ns.payload) - input_password = payload.password - - if input_password != os.environ.get("INIT_PASSWORD"): - session["is_init_validated"] = False - raise InitValidateFailedError() - - session["is_init_validated"] = True - return {"result": "success"}, 201 + session["is_init_validated"] = True + return InitValidateResponse(result="success") -def get_init_validate_status(): +def get_init_validate_status() -> bool: if dify_config.EDITION == "SELF_HOSTED": if os.environ.get("INIT_PASSWORD"): if session.get("is_init_validated"): return True with Session(db.engine) as db_session: - return db_session.execute(select(DifySetup)).scalar_one_or_none() + return db_session.execute(select(DifySetup)).scalar_one_or_none() is not None return True diff --git a/api/controllers/console/remote_files.py b/api/controllers/console/remote_files.py index 70c7b80ffa..88a9ce3a79 100644 --- a/api/controllers/console/remote_files.py +++ b/api/controllers/console/remote_files.py @@ -1,7 +1,6 @@ import urllib.parse import httpx -from flask_restx import Resource from pydantic import BaseModel, Field import services @@ -11,7 +10,7 @@ from controllers.common.errors import ( RemoteFileUploadError, UnsupportedFileTypeError, ) -from controllers.common.schema import register_schema_models +from controllers.fastopenapi import console_router from core.file import helpers as file_helpers from core.helper import ssrf_proxy from extensions.ext_database import db @@ -19,84 +18,74 @@ from fields.file_fields import FileWithSignedUrl, RemoteFileInfo from libs.login import current_account_with_tenant from services.file_service import FileService -from . import console_ns - -register_schema_models(console_ns, RemoteFileInfo, FileWithSignedUrl) - - -@console_ns.route("/remote-files/") -class RemoteFileInfoApi(Resource): - @console_ns.response(200, "Remote file info", console_ns.models[RemoteFileInfo.__name__]) - def get(self, url): - decoded_url = urllib.parse.unquote(url) - resp = ssrf_proxy.head(decoded_url) - if resp.status_code != httpx.codes.OK: - # failed back to get method - resp = ssrf_proxy.get(decoded_url, timeout=3) - resp.raise_for_status() - info = RemoteFileInfo( - file_type=resp.headers.get("Content-Type", "application/octet-stream"), - file_length=int(resp.headers.get("Content-Length", 0)), - ) - return info.model_dump(mode="json") - class RemoteFileUploadPayload(BaseModel): url: str = Field(..., description="URL to fetch") -console_ns.schema_model( - RemoteFileUploadPayload.__name__, - RemoteFileUploadPayload.model_json_schema(ref_template="#/definitions/{model}"), +@console_router.get( + "/remote-files/", + response_model=RemoteFileInfo, + tags=["console"], ) +def get_remote_file_info(url: str) -> RemoteFileInfo: + decoded_url = urllib.parse.unquote(url) + resp = ssrf_proxy.head(decoded_url) + if resp.status_code != httpx.codes.OK: + resp = ssrf_proxy.get(decoded_url, timeout=3) + resp.raise_for_status() + return RemoteFileInfo( + file_type=resp.headers.get("Content-Type", "application/octet-stream"), + file_length=int(resp.headers.get("Content-Length", 0)), + ) -@console_ns.route("/remote-files/upload") -class RemoteFileUploadApi(Resource): - @console_ns.expect(console_ns.models[RemoteFileUploadPayload.__name__]) - @console_ns.response(201, "Remote file uploaded", console_ns.models[FileWithSignedUrl.__name__]) - def post(self): - args = RemoteFileUploadPayload.model_validate(console_ns.payload) - url = args.url +@console_router.post( + "/remote-files/upload", + response_model=FileWithSignedUrl, + tags=["console"], + status_code=201, +) +def upload_remote_file(payload: RemoteFileUploadPayload) -> FileWithSignedUrl: + url = payload.url - try: - resp = ssrf_proxy.head(url=url) - if resp.status_code != httpx.codes.OK: - resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True) - if resp.status_code != httpx.codes.OK: - raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}") - except httpx.RequestError as e: - raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(e)}") + try: + resp = ssrf_proxy.head(url=url) + if resp.status_code != httpx.codes.OK: + resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True) + if resp.status_code != httpx.codes.OK: + raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}") + except httpx.RequestError as e: + raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(e)}") - file_info = helpers.guess_file_info_from_response(resp) + file_info = helpers.guess_file_info_from_response(resp) - if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size): - raise FileTooLargeError + if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size): + raise FileTooLargeError - content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content + content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content - try: - user, _ = current_account_with_tenant() - upload_file = FileService(db.engine).upload_file( - filename=file_info.filename, - content=content, - mimetype=file_info.mimetype, - user=user, - source_url=url, - ) - except services.errors.file.FileTooLargeError as file_too_large_error: - raise FileTooLargeError(file_too_large_error.description) - except services.errors.file.UnsupportedFileTypeError: - raise UnsupportedFileTypeError() - - payload = FileWithSignedUrl( - id=upload_file.id, - name=upload_file.name, - size=upload_file.size, - extension=upload_file.extension, - url=file_helpers.get_signed_file_url(upload_file_id=upload_file.id), - mime_type=upload_file.mime_type, - created_by=upload_file.created_by, - created_at=int(upload_file.created_at.timestamp()), + try: + user, _ = current_account_with_tenant() + upload_file = FileService(db.engine).upload_file( + filename=file_info.filename, + content=content, + mimetype=file_info.mimetype, + user=user, + source_url=url, ) - return payload.model_dump(mode="json"), 201 + except services.errors.file.FileTooLargeError as file_too_large_error: + raise FileTooLargeError(file_too_large_error.description) + except services.errors.file.UnsupportedFileTypeError: + raise UnsupportedFileTypeError() + + return FileWithSignedUrl( + id=upload_file.id, + name=upload_file.name, + size=upload_file.size, + extension=upload_file.extension, + url=file_helpers.get_signed_file_url(upload_file_id=upload_file.id), + mime_type=upload_file.mime_type, + created_by=upload_file.created_by, + created_at=int(upload_file.created_at.timestamp()), + ) diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index 9988524a80..e828d54ff4 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -1,14 +1,11 @@ from typing import Literal +from uuid import UUID -from flask import request -from flask_restx import Resource, marshal_with from pydantic import BaseModel, Field from werkzeug.exceptions import Forbidden -from controllers.common.schema import register_schema_models -from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required -from fields.tag_fields import dataset_tag_fields +from controllers.fastopenapi import console_router from libs.login import current_account_with_tenant, login_required from services.tag_service import TagService @@ -35,115 +32,129 @@ class TagListQueryParam(BaseModel): keyword: str | None = Field(None, description="Search keyword") -register_schema_models( - console_ns, - TagBasePayload, - TagBindingPayload, - TagBindingRemovePayload, - TagListQueryParam, +class TagResponse(BaseModel): + id: str = Field(description="Tag ID") + name: str = Field(description="Tag name") + type: str = Field(description="Tag type") + binding_count: int = Field(description="Number of bindings") + + +class TagBindingResult(BaseModel): + result: Literal["success"] = Field(description="Operation result", examples=["success"]) + + +@console_router.get( + "/tags", + response_model=list[TagResponse], + tags=["console"], ) +@setup_required +@login_required +@account_initialization_required +def list_tags(query: TagListQueryParam) -> list[TagResponse]: + _, current_tenant_id = current_account_with_tenant() + tags = TagService.get_tags(query.type, current_tenant_id, query.keyword) + + return [ + TagResponse( + id=tag.id, + name=tag.name, + type=tag.type, + binding_count=int(tag.binding_count), + ) + for tag in tags + ] -@console_ns.route("/tags") -class TagListApi(Resource): - @setup_required - @login_required - @account_initialization_required - @console_ns.doc( - params={"type": 'Tag type filter. Can be "knowledge" or "app".', "keyword": "Search keyword for tag name."} - ) - @marshal_with(dataset_tag_fields) - def get(self): - _, current_tenant_id = current_account_with_tenant() - raw_args = request.args.to_dict() - param = TagListQueryParam.model_validate(raw_args) - tags = TagService.get_tags(param.type, current_tenant_id, param.keyword) +@console_router.post( + "/tags", + response_model=TagResponse, + tags=["console"], +) +@setup_required +@login_required +@account_initialization_required +def create_tag(payload: TagBasePayload) -> TagResponse: + current_user, _ = current_account_with_tenant() + # The role of the current user in the tag table must be admin, owner, or editor + if not (current_user.has_edit_permission or current_user.is_dataset_editor): + raise Forbidden() - return tags, 200 + tag = TagService.save_tags(payload.model_dump()) - @console_ns.expect(console_ns.models[TagBasePayload.__name__]) - @setup_required - @login_required - @account_initialization_required - def post(self): - current_user, _ = current_account_with_tenant() - # The role of the current user in the ta table must be admin, owner, or editor - if not (current_user.has_edit_permission or current_user.is_dataset_editor): - raise Forbidden() - - payload = TagBasePayload.model_validate(console_ns.payload or {}) - tag = TagService.save_tags(payload.model_dump()) - - response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} - - return response, 200 + return TagResponse(id=tag.id, name=tag.name, type=tag.type, binding_count=0) -@console_ns.route("/tags/") -class TagUpdateDeleteApi(Resource): - @console_ns.expect(console_ns.models[TagBasePayload.__name__]) - @setup_required - @login_required - @account_initialization_required - def patch(self, tag_id): - current_user, _ = current_account_with_tenant() - tag_id = str(tag_id) - # The role of the current user in the ta table must be admin, owner, or editor - if not (current_user.has_edit_permission or current_user.is_dataset_editor): - raise Forbidden() +@console_router.patch( + "/tags/", + response_model=TagResponse, + tags=["console"], +) +@setup_required +@login_required +@account_initialization_required +def update_tag(tag_id: UUID, payload: TagBasePayload) -> TagResponse: + current_user, _ = current_account_with_tenant() + tag_id_str = str(tag_id) + # The role of the current user in the ta table must be admin, owner, or editor + if not (current_user.has_edit_permission or current_user.is_dataset_editor): + raise Forbidden() - payload = TagBasePayload.model_validate(console_ns.payload or {}) - tag = TagService.update_tags(payload.model_dump(), tag_id) + tag = TagService.update_tags(payload.model_dump(), tag_id_str) - binding_count = TagService.get_tag_binding_count(tag_id) + binding_count = TagService.get_tag_binding_count(tag_id_str) - response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count} - - return response, 200 - - @setup_required - @login_required - @account_initialization_required - @edit_permission_required - def delete(self, tag_id): - tag_id = str(tag_id) - - TagService.delete_tag(tag_id) - - return 204 + return TagResponse(id=tag.id, name=tag.name, type=tag.type, binding_count=binding_count) -@console_ns.route("/tag-bindings/create") -class TagBindingCreateApi(Resource): - @console_ns.expect(console_ns.models[TagBindingPayload.__name__]) - @setup_required - @login_required - @account_initialization_required - def post(self): - current_user, _ = current_account_with_tenant() - # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator - if not (current_user.has_edit_permission or current_user.is_dataset_editor): - raise Forbidden() +@console_router.delete( + "/tags/", + tags=["console"], + status_code=204, +) +@setup_required +@login_required +@account_initialization_required +@edit_permission_required +def delete_tag(tag_id: UUID) -> None: + tag_id_str = str(tag_id) - payload = TagBindingPayload.model_validate(console_ns.payload or {}) - TagService.save_tag_binding(payload.model_dump()) - - return {"result": "success"}, 200 + TagService.delete_tag(tag_id_str) -@console_ns.route("/tag-bindings/remove") -class TagBindingDeleteApi(Resource): - @console_ns.expect(console_ns.models[TagBindingRemovePayload.__name__]) - @setup_required - @login_required - @account_initialization_required - def post(self): - current_user, _ = current_account_with_tenant() - # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator - if not (current_user.has_edit_permission or current_user.is_dataset_editor): - raise Forbidden() +@console_router.post( + "/tag-bindings/create", + response_model=TagBindingResult, + tags=["console"], +) +@setup_required +@login_required +@account_initialization_required +def create_tag_binding(payload: TagBindingPayload) -> TagBindingResult: + current_user, _ = current_account_with_tenant() + # The role of the current user in the tag table must be admin, owner, editor, or dataset_operator + if not (current_user.has_edit_permission or current_user.is_dataset_editor): + raise Forbidden() - payload = TagBindingRemovePayload.model_validate(console_ns.payload or {}) - TagService.delete_tag_binding(payload.model_dump()) + TagService.save_tag_binding(payload.model_dump()) - return {"result": "success"}, 200 + return TagBindingResult(result="success") + + +@console_router.post( + "/tag-bindings/remove", + response_model=TagBindingResult, + tags=["console"], +) +@setup_required +@login_required +@account_initialization_required +def delete_tag_binding(payload: TagBindingRemovePayload) -> TagBindingResult: + current_user, _ = current_account_with_tenant() + # The role of the current user in the tag table must be admin, owner, editor, or dataset_operator + if not (current_user.has_edit_permission or current_user.is_dataset_editor): + raise Forbidden() + + TagService.delete_tag_binding(payload.model_dump()) + + return TagBindingResult(result="success") diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 38c66525b3..708df62642 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -12,6 +12,7 @@ from sqlalchemy.orm import Session from configs import dify_config from constants.languages import supported_language +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.auth.error import ( EmailAlreadyInUseError, @@ -37,7 +38,7 @@ from controllers.console.wraps import ( setup_required, ) from extensions.ext_database import db -from fields.member_fields import account_fields +from fields.member_fields import Account as AccountResponse from libs.datetime_utils import naive_utc_now from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone from libs.login import current_account_with_tenant, login_required @@ -170,6 +171,12 @@ reg(ChangeEmailSendPayload) reg(ChangeEmailValidityPayload) reg(ChangeEmailResetPayload) reg(CheckEmailUniquePayload) +register_schema_models(console_ns, AccountResponse) + + +def _serialize_account(account) -> dict: + return AccountResponse.model_validate(account, from_attributes=True).model_dump(mode="json") + integrate_fields = { "provider": fields.String, @@ -236,11 +243,11 @@ class AccountProfileApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_fields) + @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__]) @enterprise_license_required def get(self): current_user, _ = current_account_with_tenant() - return current_user + return _serialize_account(current_user) @console_ns.route("/account/name") @@ -249,14 +256,14 @@ class AccountNameApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_fields) + @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__]) def post(self): current_user, _ = current_account_with_tenant() payload = console_ns.payload or {} args = AccountNamePayload.model_validate(payload) updated_account = AccountService.update_account(current_user, name=args.name) - return updated_account + return _serialize_account(updated_account) @console_ns.route("/account/avatar") @@ -265,7 +272,7 @@ class AccountAvatarApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_fields) + @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__]) def post(self): current_user, _ = current_account_with_tenant() payload = console_ns.payload or {} @@ -273,7 +280,7 @@ class AccountAvatarApi(Resource): updated_account = AccountService.update_account(current_user, avatar=args.avatar) - return updated_account + return _serialize_account(updated_account) @console_ns.route("/account/interface-language") @@ -282,7 +289,7 @@ class AccountInterfaceLanguageApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_fields) + @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__]) def post(self): current_user, _ = current_account_with_tenant() payload = console_ns.payload or {} @@ -290,7 +297,7 @@ class AccountInterfaceLanguageApi(Resource): updated_account = AccountService.update_account(current_user, interface_language=args.interface_language) - return updated_account + return _serialize_account(updated_account) @console_ns.route("/account/interface-theme") @@ -299,7 +306,7 @@ class AccountInterfaceThemeApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_fields) + @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__]) def post(self): current_user, _ = current_account_with_tenant() payload = console_ns.payload or {} @@ -307,7 +314,7 @@ class AccountInterfaceThemeApi(Resource): updated_account = AccountService.update_account(current_user, interface_theme=args.interface_theme) - return updated_account + return _serialize_account(updated_account) @console_ns.route("/account/timezone") @@ -316,7 +323,7 @@ class AccountTimezoneApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_fields) + @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__]) def post(self): current_user, _ = current_account_with_tenant() payload = console_ns.payload or {} @@ -324,7 +331,7 @@ class AccountTimezoneApi(Resource): updated_account = AccountService.update_account(current_user, timezone=args.timezone) - return updated_account + return _serialize_account(updated_account) @console_ns.route("/account/password") @@ -333,7 +340,7 @@ class AccountPasswordApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_fields) + @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__]) def post(self): current_user, _ = current_account_with_tenant() payload = console_ns.payload or {} @@ -344,7 +351,7 @@ class AccountPasswordApi(Resource): except ServiceCurrentPasswordIncorrectError: raise CurrentPasswordIncorrectError() - return {"result": "success"} + return _serialize_account(current_user) @console_ns.route("/account/integrates") @@ -620,7 +627,7 @@ class ChangeEmailResetApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_fields) + @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__]) def post(self): payload = console_ns.payload or {} args = ChangeEmailResetPayload.model_validate(payload) @@ -649,7 +656,7 @@ class ChangeEmailResetApi(Resource): email=normalized_new_email, ) - return updated_account + return _serialize_account(updated_account) @console_ns.route("/account/change-email/check-email-unique") diff --git a/api/controllers/console/workspace/endpoint.py b/api/controllers/console/workspace/endpoint.py index bfd9fc6c29..1897cbdca7 100644 --- a/api/controllers/console/workspace/endpoint.py +++ b/api/controllers/console/workspace/endpoint.py @@ -1,9 +1,10 @@ from typing import Any from flask import request -from flask_restx import Resource, fields +from flask_restx import Resource from pydantic import BaseModel, Field +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required from core.model_runtime.utils.encoders import jsonable_encoder @@ -38,15 +39,53 @@ class EndpointListForPluginQuery(EndpointListQuery): plugin_id: str +class EndpointCreateResponse(BaseModel): + success: bool = Field(description="Operation success") + + +class EndpointListResponse(BaseModel): + endpoints: list[dict[str, Any]] = Field(description="Endpoint information") + + +class PluginEndpointListResponse(BaseModel): + endpoints: list[dict[str, Any]] = Field(description="Endpoint information") + + +class EndpointDeleteResponse(BaseModel): + success: bool = Field(description="Operation success") + + +class EndpointUpdateResponse(BaseModel): + success: bool = Field(description="Operation success") + + +class EndpointEnableResponse(BaseModel): + success: bool = Field(description="Operation success") + + +class EndpointDisableResponse(BaseModel): + success: bool = Field(description="Operation success") + + def reg(cls: type[BaseModel]): console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) -reg(EndpointCreatePayload) -reg(EndpointIdPayload) -reg(EndpointUpdatePayload) -reg(EndpointListQuery) -reg(EndpointListForPluginQuery) +register_schema_models( + console_ns, + EndpointCreatePayload, + EndpointIdPayload, + EndpointUpdatePayload, + EndpointListQuery, + EndpointListForPluginQuery, + EndpointCreateResponse, + EndpointListResponse, + PluginEndpointListResponse, + EndpointDeleteResponse, + EndpointUpdateResponse, + EndpointEnableResponse, + EndpointDisableResponse, +) @console_ns.route("/workspaces/current/endpoints/create") @@ -57,7 +96,7 @@ class EndpointCreateApi(Resource): @console_ns.response( 200, "Endpoint created successfully", - console_ns.model("EndpointCreateResponse", {"success": fields.Boolean(description="Operation success")}), + console_ns.models[EndpointCreateResponse.__name__], ) @console_ns.response(403, "Admin privileges required") @setup_required @@ -91,9 +130,7 @@ class EndpointListApi(Resource): @console_ns.response( 200, "Success", - console_ns.model( - "EndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))} - ), + console_ns.models[EndpointListResponse.__name__], ) @setup_required @login_required @@ -126,9 +163,7 @@ class EndpointListForSinglePluginApi(Resource): @console_ns.response( 200, "Success", - console_ns.model( - "PluginEndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))} - ), + console_ns.models[PluginEndpointListResponse.__name__], ) @setup_required @login_required @@ -163,7 +198,7 @@ class EndpointDeleteApi(Resource): @console_ns.response( 200, "Endpoint deleted successfully", - console_ns.model("EndpointDeleteResponse", {"success": fields.Boolean(description="Operation success")}), + console_ns.models[EndpointDeleteResponse.__name__], ) @console_ns.response(403, "Admin privileges required") @setup_required @@ -190,7 +225,7 @@ class EndpointUpdateApi(Resource): @console_ns.response( 200, "Endpoint updated successfully", - console_ns.model("EndpointUpdateResponse", {"success": fields.Boolean(description="Operation success")}), + console_ns.models[EndpointUpdateResponse.__name__], ) @console_ns.response(403, "Admin privileges required") @setup_required @@ -221,7 +256,7 @@ class EndpointEnableApi(Resource): @console_ns.response( 200, "Endpoint enabled successfully", - console_ns.model("EndpointEnableResponse", {"success": fields.Boolean(description="Operation success")}), + console_ns.models[EndpointEnableResponse.__name__], ) @console_ns.response(403, "Admin privileges required") @setup_required @@ -248,7 +283,7 @@ class EndpointDisableApi(Resource): @console_ns.response( 200, "Endpoint disabled successfully", - console_ns.model("EndpointDisableResponse", {"success": fields.Boolean(description="Operation success")}), + console_ns.models[EndpointDisableResponse.__name__], ) @console_ns.response(403, "Admin privileges required") @setup_required diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index 271cdce3c3..dd302b90d6 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -1,12 +1,12 @@ from urllib import parse from flask import abort, request -from flask_restx import Resource, fields, marshal_with -from pydantic import BaseModel, Field +from flask_restx import Resource +from pydantic import BaseModel, Field, TypeAdapter import services from configs import dify_config -from controllers.common.schema import get_or_create_model, register_enum_models +from controllers.common.schema import register_enum_models, register_schema_models from controllers.console import console_ns from controllers.console.auth.error import ( CannotTransferOwnerToSelfError, @@ -25,7 +25,7 @@ from controllers.console.wraps import ( setup_required, ) from extensions.ext_database import db -from fields.member_fields import account_with_role_fields, account_with_role_list_fields +from fields.member_fields import AccountWithRole, AccountWithRoleList from libs.helper import extract_remote_ip from libs.login import current_account_with_tenant, login_required from models.account import Account, TenantAccountRole @@ -69,12 +69,7 @@ reg(OwnerTransferEmailPayload) reg(OwnerTransferCheckPayload) reg(OwnerTransferPayload) register_enum_models(console_ns, TenantAccountRole) - -account_with_role_model = get_or_create_model("AccountWithRole", account_with_role_fields) - -account_with_role_list_fields_copy = account_with_role_list_fields.copy() -account_with_role_list_fields_copy["accounts"] = fields.List(fields.Nested(account_with_role_model)) -account_with_role_list_model = get_or_create_model("AccountWithRoleList", account_with_role_list_fields_copy) +register_schema_models(console_ns, AccountWithRole, AccountWithRoleList) @console_ns.route("/workspaces/current/members") @@ -84,13 +79,15 @@ class MemberListApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_with_role_list_model) + @console_ns.response(200, "Success", console_ns.models[AccountWithRoleList.__name__]) def get(self): current_user, _ = current_account_with_tenant() if not current_user.current_tenant: raise ValueError("No current tenant") members = TenantService.get_tenant_members(current_user.current_tenant) - return {"result": "success", "accounts": members}, 200 + member_models = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True) + response = AccountWithRoleList(accounts=member_models) + return response.model_dump(mode="json"), 200 @console_ns.route("/workspaces/current/members/invite-email") @@ -235,13 +232,15 @@ class DatasetOperatorMemberListApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_with_role_list_model) + @console_ns.response(200, "Success", console_ns.models[AccountWithRoleList.__name__]) def get(self): current_user, _ = current_account_with_tenant() if not current_user.current_tenant: raise ValueError("No current tenant") members = TenantService.get_dataset_operator_members(current_user.current_tenant) - return {"result": "success", "accounts": members}, 200 + member_models = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True) + response = AccountWithRoleList(accounts=member_models) + return response.model_dump(mode="json"), 200 @console_ns.route("/workspaces/current/members/send-owner-transfer-confirm-email") diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index e9e7b72718..5bfa895849 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -1,16 +1,16 @@ import io import logging +from typing import Any, Literal from urllib.parse import urlparse from flask import make_response, redirect, request, send_file -from flask_restx import ( - Resource, - reqparse, -) +from flask_restx import Resource +from pydantic import BaseModel, Field, HttpUrl, field_validator, model_validator from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden from configs import dify_config +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import ( account_initialization_required, @@ -26,8 +26,9 @@ from core.mcp.mcp_client import MCPClient from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.impl.oauth import OAuthHandler +from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration from extensions.ext_database import db -from libs.helper import StrLen, alphanumeric, uuid_value +from libs.helper import alphanumeric, uuid_value from libs.login import current_account_with_tenant, login_required from models.provider_ids import ToolProviderID @@ -52,24 +53,209 @@ def is_valid_url(url: str) -> bool: parsed = urlparse(url) return all([parsed.scheme, parsed.netloc]) and parsed.scheme in ["http", "https"] except (ValueError, TypeError): - # ValueError: Invalid URL format - # TypeError: url is not a string return False -parser_tool = reqparse.RequestParser().add_argument( - "type", - type=str, - choices=["builtin", "model", "api", "workflow", "mcp"], - required=False, - nullable=True, - location="args", +class ToolProviderListQuery(BaseModel): + type: Literal["builtin", "model", "api", "workflow", "mcp"] | None = None + + +class BuiltinToolCredentialDeletePayload(BaseModel): + credential_id: str + + +class BuiltinToolAddPayload(BaseModel): + credentials: dict[str, Any] + name: str | None = Field(default=None, max_length=30) + type: CredentialType + + +class BuiltinToolUpdatePayload(BaseModel): + credential_id: str + credentials: dict[str, Any] | None = None + name: str | None = Field(default=None, max_length=30) + + +class ApiToolProviderBasePayload(BaseModel): + credentials: dict[str, Any] + schema_type: ApiProviderSchemaType + schema_: str = Field(alias="schema") + provider: str + icon: dict[str, Any] + privacy_policy: str | None = None + labels: list[str] | None = None + custom_disclaimer: str = "" + + +class ApiToolProviderAddPayload(ApiToolProviderBasePayload): + pass + + +class ApiToolProviderUpdatePayload(ApiToolProviderBasePayload): + original_provider: str + + +class UrlQuery(BaseModel): + url: HttpUrl + + +class ProviderQuery(BaseModel): + provider: str + + +class ApiToolProviderDeletePayload(BaseModel): + provider: str + + +class ApiToolSchemaPayload(BaseModel): + schema_: str = Field(alias="schema") + + +class ApiToolTestPayload(BaseModel): + tool_name: str + provider_name: str | None = None + credentials: dict[str, Any] + parameters: dict[str, Any] + schema_type: ApiProviderSchemaType + schema_: str = Field(alias="schema") + + +class WorkflowToolBasePayload(BaseModel): + name: str + label: str + description: str + icon: dict[str, Any] + parameters: list[WorkflowToolParameterConfiguration] = Field(default_factory=list) + privacy_policy: str | None = "" + labels: list[str] | None = None + + @field_validator("name") + @classmethod + def validate_name(cls, value: str) -> str: + return alphanumeric(value) + + +class WorkflowToolCreatePayload(WorkflowToolBasePayload): + workflow_app_id: str + + @field_validator("workflow_app_id") + @classmethod + def validate_workflow_app_id(cls, value: str) -> str: + return uuid_value(value) + + +class WorkflowToolUpdatePayload(WorkflowToolBasePayload): + workflow_tool_id: str + + @field_validator("workflow_tool_id") + @classmethod + def validate_workflow_tool_id(cls, value: str) -> str: + return uuid_value(value) + + +class WorkflowToolDeletePayload(BaseModel): + workflow_tool_id: str + + @field_validator("workflow_tool_id") + @classmethod + def validate_workflow_tool_id(cls, value: str) -> str: + return uuid_value(value) + + +class WorkflowToolGetQuery(BaseModel): + workflow_tool_id: str | None = None + workflow_app_id: str | None = None + + @field_validator("workflow_tool_id", "workflow_app_id") + @classmethod + def validate_ids(cls, value: str | None) -> str | None: + if value is None: + return value + return uuid_value(value) + + @model_validator(mode="after") + def ensure_one(self) -> "WorkflowToolGetQuery": + if not self.workflow_tool_id and not self.workflow_app_id: + raise ValueError("workflow_tool_id or workflow_app_id is required") + return self + + +class WorkflowToolListQuery(BaseModel): + workflow_tool_id: str + + @field_validator("workflow_tool_id") + @classmethod + def validate_workflow_tool_id(cls, value: str) -> str: + return uuid_value(value) + + +class BuiltinProviderDefaultCredentialPayload(BaseModel): + id: str + + +class ToolOAuthCustomClientPayload(BaseModel): + client_params: dict[str, Any] | None = None + enable_oauth_custom_client: bool | None = True + + +class MCPProviderBasePayload(BaseModel): + server_url: str + name: str + icon: str + icon_type: str + icon_background: str = "" + server_identifier: str + configuration: dict[str, Any] | None = Field(default_factory=dict) + headers: dict[str, Any] | None = Field(default_factory=dict) + authentication: dict[str, Any] | None = Field(default_factory=dict) + + +class MCPProviderCreatePayload(MCPProviderBasePayload): + pass + + +class MCPProviderUpdatePayload(MCPProviderBasePayload): + provider_id: str + + +class MCPProviderDeletePayload(BaseModel): + provider_id: str + + +class MCPAuthPayload(BaseModel): + provider_id: str + authorization_code: str | None = None + + +class MCPCallbackQuery(BaseModel): + code: str + state: str + + +register_schema_models( + console_ns, + BuiltinToolCredentialDeletePayload, + BuiltinToolAddPayload, + BuiltinToolUpdatePayload, + ApiToolProviderAddPayload, + ApiToolProviderUpdatePayload, + ApiToolProviderDeletePayload, + ApiToolSchemaPayload, + ApiToolTestPayload, + WorkflowToolCreatePayload, + WorkflowToolUpdatePayload, + WorkflowToolDeletePayload, + BuiltinProviderDefaultCredentialPayload, + ToolOAuthCustomClientPayload, + MCPProviderCreatePayload, + MCPProviderUpdatePayload, + MCPProviderDeletePayload, + MCPAuthPayload, ) @console_ns.route("/workspaces/current/tool-providers") class ToolProviderListApi(Resource): - @console_ns.expect(parser_tool) @setup_required @login_required @account_initialization_required @@ -78,9 +264,10 @@ class ToolProviderListApi(Resource): user_id = user.id - args = parser_tool.parse_args() + raw_args = request.args.to_dict() + query = ToolProviderListQuery.model_validate(raw_args) - return ToolCommonService.list_tool_providers(user_id, tenant_id, args.get("type", None)) + return ToolCommonService.list_tool_providers(user_id, tenant_id, query.type) # type: ignore @console_ns.route("/workspaces/current/tool-provider/builtin//tools") @@ -110,14 +297,9 @@ class ToolBuiltinProviderInfoApi(Resource): return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider)) -parser_delete = reqparse.RequestParser().add_argument( - "credential_id", type=str, required=True, nullable=False, location="json" -) - - @console_ns.route("/workspaces/current/tool-provider/builtin//delete") class ToolBuiltinProviderDeleteApi(Resource): - @console_ns.expect(parser_delete) + @console_ns.expect(console_ns.models[BuiltinToolCredentialDeletePayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -125,26 +307,18 @@ class ToolBuiltinProviderDeleteApi(Resource): def post(self, provider): _, tenant_id = current_account_with_tenant() - args = parser_delete.parse_args() + payload = BuiltinToolCredentialDeletePayload.model_validate(console_ns.payload or {}) return BuiltinToolManageService.delete_builtin_tool_provider( tenant_id, provider, - args["credential_id"], + payload.credential_id, ) -parser_add = ( - reqparse.RequestParser() - .add_argument("credentials", type=dict, required=True, nullable=False, location="json") - .add_argument("name", type=StrLen(30), required=False, nullable=False, location="json") - .add_argument("type", type=str, required=True, nullable=False, location="json") -) - - @console_ns.route("/workspaces/current/tool-provider/builtin//add") class ToolBuiltinProviderAddApi(Resource): - @console_ns.expect(parser_add) + @console_ns.expect(console_ns.models[BuiltinToolAddPayload.__name__]) @setup_required @login_required @account_initialization_required @@ -153,32 +327,21 @@ class ToolBuiltinProviderAddApi(Resource): user_id = user.id - args = parser_add.parse_args() - - if args["type"] not in CredentialType.values(): - raise ValueError(f"Invalid credential type: {args['type']}") + payload = BuiltinToolAddPayload.model_validate(console_ns.payload or {}) return BuiltinToolManageService.add_builtin_tool_provider( user_id=user_id, tenant_id=tenant_id, provider=provider, - credentials=args["credentials"], - name=args["name"], - api_type=CredentialType.of(args["type"]), + credentials=payload.credentials, + name=payload.name, + api_type=CredentialType.of(payload.type), ) -parser_update = ( - reqparse.RequestParser() - .add_argument("credential_id", type=str, required=True, nullable=False, location="json") - .add_argument("credentials", type=dict, required=False, nullable=True, location="json") - .add_argument("name", type=StrLen(30), required=False, nullable=True, location="json") -) - - @console_ns.route("/workspaces/current/tool-provider/builtin//update") class ToolBuiltinProviderUpdateApi(Resource): - @console_ns.expect(parser_update) + @console_ns.expect(console_ns.models[BuiltinToolUpdatePayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -187,15 +350,15 @@ class ToolBuiltinProviderUpdateApi(Resource): user, tenant_id = current_account_with_tenant() user_id = user.id - args = parser_update.parse_args() + payload = BuiltinToolUpdatePayload.model_validate(console_ns.payload or {}) result = BuiltinToolManageService.update_builtin_tool_provider( user_id=user_id, tenant_id=tenant_id, provider=provider, - credential_id=args["credential_id"], - credentials=args.get("credentials", None), - name=args.get("name", ""), + credential_id=payload.credential_id, + credentials=payload.credentials, + name=payload.name or "", ) return result @@ -225,22 +388,9 @@ class ToolBuiltinProviderIconApi(Resource): return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age) -parser_api_add = ( - reqparse.RequestParser() - .add_argument("credentials", type=dict, required=True, nullable=False, location="json") - .add_argument("schema_type", type=str, required=True, nullable=False, location="json") - .add_argument("schema", type=str, required=True, nullable=False, location="json") - .add_argument("provider", type=str, required=True, nullable=False, location="json") - .add_argument("icon", type=dict, required=True, nullable=False, location="json") - .add_argument("privacy_policy", type=str, required=False, nullable=True, location="json") - .add_argument("labels", type=list[str], required=False, nullable=True, location="json", default=[]) - .add_argument("custom_disclaimer", type=str, required=False, nullable=True, location="json") -) - - @console_ns.route("/workspaces/current/tool-provider/api/add") class ToolApiProviderAddApi(Resource): - @console_ns.expect(parser_api_add) + @console_ns.expect(console_ns.models[ApiToolProviderAddPayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -250,28 +400,24 @@ class ToolApiProviderAddApi(Resource): user_id = user.id - args = parser_api_add.parse_args() + payload = ApiToolProviderAddPayload.model_validate(console_ns.payload or {}) return ApiToolManageService.create_api_tool_provider( user_id, tenant_id, - args["provider"], - args["icon"], - args["credentials"], - args["schema_type"], - args["schema"], - args.get("privacy_policy", ""), - args.get("custom_disclaimer", ""), - args.get("labels", []), + payload.provider, + payload.icon, + payload.credentials, + payload.schema_type, + payload.schema_, + payload.privacy_policy or "", + payload.custom_disclaimer or "", + payload.labels or [], ) -parser_remote = reqparse.RequestParser().add_argument("url", type=str, required=True, nullable=False, location="args") - - @console_ns.route("/workspaces/current/tool-provider/api/remote") class ToolApiProviderGetRemoteSchemaApi(Resource): - @console_ns.expect(parser_remote) @setup_required @login_required @account_initialization_required @@ -280,23 +426,18 @@ class ToolApiProviderGetRemoteSchemaApi(Resource): user_id = user.id - args = parser_remote.parse_args() + raw_args = request.args.to_dict() + query = UrlQuery.model_validate(raw_args) return ApiToolManageService.get_api_tool_provider_remote_schema( user_id, tenant_id, - args["url"], + str(query.url), ) -parser_tools = reqparse.RequestParser().add_argument( - "provider", type=str, required=True, nullable=False, location="args" -) - - @console_ns.route("/workspaces/current/tool-provider/api/tools") class ToolApiProviderListToolsApi(Resource): - @console_ns.expect(parser_tools) @setup_required @login_required @account_initialization_required @@ -305,34 +446,21 @@ class ToolApiProviderListToolsApi(Resource): user_id = user.id - args = parser_tools.parse_args() + raw_args = request.args.to_dict() + query = ProviderQuery.model_validate(raw_args) return jsonable_encoder( ApiToolManageService.list_api_tool_provider_tools( user_id, tenant_id, - args["provider"], + query.provider, ) ) -parser_api_update = ( - reqparse.RequestParser() - .add_argument("credentials", type=dict, required=True, nullable=False, location="json") - .add_argument("schema_type", type=str, required=True, nullable=False, location="json") - .add_argument("schema", type=str, required=True, nullable=False, location="json") - .add_argument("provider", type=str, required=True, nullable=False, location="json") - .add_argument("original_provider", type=str, required=True, nullable=False, location="json") - .add_argument("icon", type=dict, required=True, nullable=False, location="json") - .add_argument("privacy_policy", type=str, required=True, nullable=True, location="json") - .add_argument("labels", type=list[str], required=False, nullable=True, location="json") - .add_argument("custom_disclaimer", type=str, required=True, nullable=True, location="json") -) - - @console_ns.route("/workspaces/current/tool-provider/api/update") class ToolApiProviderUpdateApi(Resource): - @console_ns.expect(parser_api_update) + @console_ns.expect(console_ns.models[ApiToolProviderUpdatePayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -342,31 +470,26 @@ class ToolApiProviderUpdateApi(Resource): user_id = user.id - args = parser_api_update.parse_args() + payload = ApiToolProviderUpdatePayload.model_validate(console_ns.payload or {}) return ApiToolManageService.update_api_tool_provider( user_id, tenant_id, - args["provider"], - args["original_provider"], - args["icon"], - args["credentials"], - args["schema_type"], - args["schema"], - args["privacy_policy"], - args["custom_disclaimer"], - args.get("labels", []), + payload.provider, + payload.original_provider, + payload.icon, + payload.credentials, + payload.schema_type, + payload.schema_, + payload.privacy_policy, + payload.custom_disclaimer, + payload.labels or [], ) -parser_api_delete = reqparse.RequestParser().add_argument( - "provider", type=str, required=True, nullable=False, location="json" -) - - @console_ns.route("/workspaces/current/tool-provider/api/delete") class ToolApiProviderDeleteApi(Resource): - @console_ns.expect(parser_api_delete) + @console_ns.expect(console_ns.models[ApiToolProviderDeletePayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -376,21 +499,17 @@ class ToolApiProviderDeleteApi(Resource): user_id = user.id - args = parser_api_delete.parse_args() + payload = ApiToolProviderDeletePayload.model_validate(console_ns.payload or {}) return ApiToolManageService.delete_api_tool_provider( user_id, tenant_id, - args["provider"], + payload.provider, ) -parser_get = reqparse.RequestParser().add_argument("provider", type=str, required=True, nullable=False, location="args") - - @console_ns.route("/workspaces/current/tool-provider/api/get") class ToolApiProviderGetApi(Resource): - @console_ns.expect(parser_get) @setup_required @login_required @account_initialization_required @@ -399,12 +518,13 @@ class ToolApiProviderGetApi(Resource): user_id = user.id - args = parser_get.parse_args() + raw_args = request.args.to_dict() + query = ProviderQuery.model_validate(raw_args) return ApiToolManageService.get_api_tool_provider( user_id, tenant_id, - args["provider"], + query.provider, ) @@ -423,72 +543,43 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource): ) -parser_schema = reqparse.RequestParser().add_argument( - "schema", type=str, required=True, nullable=False, location="json" -) - - @console_ns.route("/workspaces/current/tool-provider/api/schema") class ToolApiProviderSchemaApi(Resource): - @console_ns.expect(parser_schema) + @console_ns.expect(console_ns.models[ApiToolSchemaPayload.__name__]) @setup_required @login_required @account_initialization_required def post(self): - args = parser_schema.parse_args() + payload = ApiToolSchemaPayload.model_validate(console_ns.payload or {}) return ApiToolManageService.parser_api_schema( - schema=args["schema"], + schema=payload.schema_, ) -parser_pre = ( - reqparse.RequestParser() - .add_argument("tool_name", type=str, required=True, nullable=False, location="json") - .add_argument("provider_name", type=str, required=False, nullable=False, location="json") - .add_argument("credentials", type=dict, required=True, nullable=False, location="json") - .add_argument("parameters", type=dict, required=True, nullable=False, location="json") - .add_argument("schema_type", type=str, required=True, nullable=False, location="json") - .add_argument("schema", type=str, required=True, nullable=False, location="json") -) - - @console_ns.route("/workspaces/current/tool-provider/api/test/pre") class ToolApiProviderPreviousTestApi(Resource): - @console_ns.expect(parser_pre) + @console_ns.expect(console_ns.models[ApiToolTestPayload.__name__]) @setup_required @login_required @account_initialization_required def post(self): - args = parser_pre.parse_args() + payload = ApiToolTestPayload.model_validate(console_ns.payload or {}) _, current_tenant_id = current_account_with_tenant() return ApiToolManageService.test_api_tool_preview( current_tenant_id, - args["provider_name"] or "", - args["tool_name"], - args["credentials"], - args["parameters"], - args["schema_type"], - args["schema"], + payload.provider_name or "", + payload.tool_name, + payload.credentials, + payload.parameters, + payload.schema_type, + payload.schema_, ) -parser_create = ( - reqparse.RequestParser() - .add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json") - .add_argument("name", type=alphanumeric, required=True, nullable=False, location="json") - .add_argument("label", type=str, required=True, nullable=False, location="json") - .add_argument("description", type=str, required=True, nullable=False, location="json") - .add_argument("icon", type=dict, required=True, nullable=False, location="json") - .add_argument("parameters", type=list[dict], required=True, nullable=False, location="json") - .add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="") - .add_argument("labels", type=list[str], required=False, nullable=True, location="json") -) - - @console_ns.route("/workspaces/current/tool-provider/workflow/create") class ToolWorkflowProviderCreateApi(Resource): - @console_ns.expect(parser_create) + @console_ns.expect(console_ns.models[WorkflowToolCreatePayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -498,38 +589,25 @@ class ToolWorkflowProviderCreateApi(Resource): user_id = user.id - args = parser_create.parse_args() + payload = WorkflowToolCreatePayload.model_validate(console_ns.payload or {}) return WorkflowToolManageService.create_workflow_tool( user_id=user_id, tenant_id=tenant_id, - workflow_app_id=args["workflow_app_id"], - name=args["name"], - label=args["label"], - icon=args["icon"], - description=args["description"], - parameters=args["parameters"], - privacy_policy=args["privacy_policy"], - labels=args["labels"], + workflow_app_id=payload.workflow_app_id, + name=payload.name, + label=payload.label, + icon=payload.icon, + description=payload.description, + parameters=payload.parameters, + privacy_policy=payload.privacy_policy or "", + labels=payload.labels or [], ) -parser_workflow_update = ( - reqparse.RequestParser() - .add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json") - .add_argument("name", type=alphanumeric, required=True, nullable=False, location="json") - .add_argument("label", type=str, required=True, nullable=False, location="json") - .add_argument("description", type=str, required=True, nullable=False, location="json") - .add_argument("icon", type=dict, required=True, nullable=False, location="json") - .add_argument("parameters", type=list[dict], required=True, nullable=False, location="json") - .add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="") - .add_argument("labels", type=list[str], required=False, nullable=True, location="json") -) - - @console_ns.route("/workspaces/current/tool-provider/workflow/update") class ToolWorkflowProviderUpdateApi(Resource): - @console_ns.expect(parser_workflow_update) + @console_ns.expect(console_ns.models[WorkflowToolUpdatePayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -538,33 +616,25 @@ class ToolWorkflowProviderUpdateApi(Resource): user, tenant_id = current_account_with_tenant() user_id = user.id - args = parser_workflow_update.parse_args() - - if not args["workflow_tool_id"]: - raise ValueError("incorrect workflow_tool_id") + payload = WorkflowToolUpdatePayload.model_validate(console_ns.payload or {}) return WorkflowToolManageService.update_workflow_tool( user_id, tenant_id, - args["workflow_tool_id"], - args["name"], - args["label"], - args["icon"], - args["description"], - args["parameters"], - args["privacy_policy"], - args.get("labels", []), + payload.workflow_tool_id, + payload.name, + payload.label, + payload.icon, + payload.description, + payload.parameters, + payload.privacy_policy or "", + payload.labels or [], ) -parser_workflow_delete = reqparse.RequestParser().add_argument( - "workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json" -) - - @console_ns.route("/workspaces/current/tool-provider/workflow/delete") class ToolWorkflowProviderDeleteApi(Resource): - @console_ns.expect(parser_workflow_delete) + @console_ns.expect(console_ns.models[WorkflowToolDeletePayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -574,25 +644,17 @@ class ToolWorkflowProviderDeleteApi(Resource): user_id = user.id - args = parser_workflow_delete.parse_args() + payload = WorkflowToolDeletePayload.model_validate(console_ns.payload or {}) return WorkflowToolManageService.delete_workflow_tool( user_id, tenant_id, - args["workflow_tool_id"], + payload.workflow_tool_id, ) -parser_wf_get = ( - reqparse.RequestParser() - .add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args") - .add_argument("workflow_app_id", type=uuid_value, required=False, nullable=True, location="args") -) - - @console_ns.route("/workspaces/current/tool-provider/workflow/get") class ToolWorkflowProviderGetApi(Resource): - @console_ns.expect(parser_wf_get) @setup_required @login_required @account_initialization_required @@ -601,19 +663,20 @@ class ToolWorkflowProviderGetApi(Resource): user_id = user.id - args = parser_wf_get.parse_args() + raw_args = request.args.to_dict() + query = WorkflowToolGetQuery.model_validate(raw_args) - if args.get("workflow_tool_id"): + if query.workflow_tool_id: tool = WorkflowToolManageService.get_workflow_tool_by_tool_id( user_id, tenant_id, - args["workflow_tool_id"], + query.workflow_tool_id, ) - elif args.get("workflow_app_id"): + elif query.workflow_app_id: tool = WorkflowToolManageService.get_workflow_tool_by_app_id( user_id, tenant_id, - args["workflow_app_id"], + query.workflow_app_id, ) else: raise ValueError("incorrect workflow_tool_id or workflow_app_id") @@ -621,14 +684,8 @@ class ToolWorkflowProviderGetApi(Resource): return jsonable_encoder(tool) -parser_wf_tools = reqparse.RequestParser().add_argument( - "workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args" -) - - @console_ns.route("/workspaces/current/tool-provider/workflow/tools") class ToolWorkflowProviderListToolApi(Resource): - @console_ns.expect(parser_wf_tools) @setup_required @login_required @account_initialization_required @@ -637,13 +694,14 @@ class ToolWorkflowProviderListToolApi(Resource): user_id = user.id - args = parser_wf_tools.parse_args() + raw_args = request.args.to_dict() + query = WorkflowToolListQuery.model_validate(raw_args) return jsonable_encoder( WorkflowToolManageService.list_single_workflow_tools( user_id, tenant_id, - args["workflow_tool_id"], + query.workflow_tool_id, ) ) @@ -810,49 +868,39 @@ class ToolOAuthCallback(Resource): return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback") -parser_default_cred = reqparse.RequestParser().add_argument( - "id", type=str, required=True, nullable=False, location="json" -) - - @console_ns.route("/workspaces/current/tool-provider/builtin//default-credential") class ToolBuiltinProviderSetDefaultApi(Resource): - @console_ns.expect(parser_default_cred) + @console_ns.expect(console_ns.models[BuiltinProviderDefaultCredentialPayload.__name__]) @setup_required @login_required @account_initialization_required def post(self, provider): current_user, current_tenant_id = current_account_with_tenant() - args = parser_default_cred.parse_args() + payload = BuiltinProviderDefaultCredentialPayload.model_validate(console_ns.payload or {}) return BuiltinToolManageService.set_default_provider( - tenant_id=current_tenant_id, user_id=current_user.id, provider=provider, id=args["id"] + tenant_id=current_tenant_id, user_id=current_user.id, provider=provider, id=payload.id ) -parser_custom = ( - reqparse.RequestParser() - .add_argument("client_params", type=dict, required=False, nullable=True, location="json") - .add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json") -) - - @console_ns.route("/workspaces/current/tool-provider/builtin//oauth/custom-client") class ToolOAuthCustomClient(Resource): - @console_ns.expect(parser_custom) + @console_ns.expect(console_ns.models[ToolOAuthCustomClientPayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @account_initialization_required def post(self, provider: str): - args = parser_custom.parse_args() + payload = ToolOAuthCustomClientPayload.model_validate(console_ns.payload or {}) _, tenant_id = current_account_with_tenant() return BuiltinToolManageService.save_custom_oauth_client_params( tenant_id=tenant_id, provider=provider, - client_params=args.get("client_params", {}), - enable_oauth_custom_client=args.get("enable_oauth_custom_client", True), + client_params=payload.client_params or {}, + enable_oauth_custom_client=payload.enable_oauth_custom_client + if payload.enable_oauth_custom_client is not None + else True, ) @setup_required @@ -904,49 +952,19 @@ class ToolBuiltinProviderGetCredentialInfoApi(Resource): ) -parser_mcp = ( - reqparse.RequestParser() - .add_argument("server_url", type=str, required=True, nullable=False, location="json") - .add_argument("name", type=str, required=True, nullable=False, location="json") - .add_argument("icon", type=str, required=True, nullable=False, location="json") - .add_argument("icon_type", type=str, required=True, nullable=False, location="json") - .add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="") - .add_argument("server_identifier", type=str, required=True, nullable=False, location="json") - .add_argument("configuration", type=dict, required=False, nullable=True, location="json", default={}) - .add_argument("headers", type=dict, required=False, nullable=True, location="json", default={}) - .add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={}) -) -parser_mcp_put = ( - reqparse.RequestParser() - .add_argument("server_url", type=str, required=True, nullable=False, location="json") - .add_argument("name", type=str, required=True, nullable=False, location="json") - .add_argument("icon", type=str, required=True, nullable=False, location="json") - .add_argument("icon_type", type=str, required=True, nullable=False, location="json") - .add_argument("icon_background", type=str, required=False, nullable=True, location="json") - .add_argument("provider_id", type=str, required=True, nullable=False, location="json") - .add_argument("server_identifier", type=str, required=True, nullable=False, location="json") - .add_argument("configuration", type=dict, required=False, nullable=True, location="json", default={}) - .add_argument("headers", type=dict, required=False, nullable=True, location="json", default={}) - .add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={}) -) -parser_mcp_delete = reqparse.RequestParser().add_argument( - "provider_id", type=str, required=True, nullable=False, location="json" -) - - @console_ns.route("/workspaces/current/tool-provider/mcp") class ToolProviderMCPApi(Resource): - @console_ns.expect(parser_mcp) + @console_ns.expect(console_ns.models[MCPProviderCreatePayload.__name__]) @setup_required @login_required @account_initialization_required def post(self): - args = parser_mcp.parse_args() + payload = MCPProviderCreatePayload.model_validate(console_ns.payload or {}) user, tenant_id = current_account_with_tenant() # Parse and validate models - configuration = MCPConfiguration.model_validate(args["configuration"]) - authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None + configuration = MCPConfiguration.model_validate(payload.configuration or {}) + authentication = MCPAuthentication.model_validate(payload.authentication) if payload.authentication else None # 1) Create provider in a short transaction (no network I/O inside) with session_factory.create_session() as session, session.begin(): @@ -954,13 +972,13 @@ class ToolProviderMCPApi(Resource): result = service.create_provider( tenant_id=tenant_id, user_id=user.id, - server_url=args["server_url"], - name=args["name"], - icon=args["icon"], - icon_type=args["icon_type"], - icon_background=args["icon_background"], - server_identifier=args["server_identifier"], - headers=args["headers"], + server_url=payload.server_url, + name=payload.name, + icon=payload.icon, + icon_type=payload.icon_type, + icon_background=payload.icon_background, + server_identifier=payload.server_identifier, + headers=payload.headers or {}, configuration=configuration, authentication=authentication, ) @@ -969,8 +987,8 @@ class ToolProviderMCPApi(Resource): # Perform network I/O outside any DB session to avoid holding locks. try: reconnect = MCPToolManageService.reconnect_with_url( - server_url=args["server_url"], - headers=args.get("headers") or {}, + server_url=payload.server_url, + headers=payload.headers or {}, timeout=configuration.timeout, sse_read_timeout=configuration.sse_read_timeout, ) @@ -988,14 +1006,14 @@ class ToolProviderMCPApi(Resource): return jsonable_encoder(result) - @console_ns.expect(parser_mcp_put) + @console_ns.expect(console_ns.models[MCPProviderUpdatePayload.__name__]) @setup_required @login_required @account_initialization_required def put(self): - args = parser_mcp_put.parse_args() - configuration = MCPConfiguration.model_validate(args["configuration"]) - authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None + payload = MCPProviderUpdatePayload.model_validate(console_ns.payload or {}) + configuration = MCPConfiguration.model_validate(payload.configuration or {}) + authentication = MCPAuthentication.model_validate(payload.authentication) if payload.authentication else None _, current_tenant_id = current_account_with_tenant() # Step 1: Get provider data for URL validation (short-lived session, no network I/O) @@ -1003,14 +1021,14 @@ class ToolProviderMCPApi(Resource): with Session(db.engine) as session: service = MCPToolManageService(session=session) validation_data = service.get_provider_for_url_validation( - tenant_id=current_tenant_id, provider_id=args["provider_id"] + tenant_id=current_tenant_id, provider_id=payload.provider_id ) # Step 2: Perform URL validation with network I/O OUTSIDE of any database session # This prevents holding database locks during potentially slow network operations validation_result = MCPToolManageService.validate_server_url_standalone( tenant_id=current_tenant_id, - new_server_url=args["server_url"], + new_server_url=payload.server_url, validation_data=validation_data, ) @@ -1019,14 +1037,14 @@ class ToolProviderMCPApi(Resource): service = MCPToolManageService(session=session) service.update_provider( tenant_id=current_tenant_id, - provider_id=args["provider_id"], - server_url=args["server_url"], - name=args["name"], - icon=args["icon"], - icon_type=args["icon_type"], - icon_background=args["icon_background"], - server_identifier=args["server_identifier"], - headers=args["headers"], + provider_id=payload.provider_id, + server_url=payload.server_url, + name=payload.name, + icon=payload.icon, + icon_type=payload.icon_type, + icon_background=payload.icon_background, + server_identifier=payload.server_identifier, + headers=payload.headers or {}, configuration=configuration, authentication=authentication, validation_result=validation_result, @@ -1034,37 +1052,30 @@ class ToolProviderMCPApi(Resource): return {"result": "success"} - @console_ns.expect(parser_mcp_delete) + @console_ns.expect(console_ns.models[MCPProviderDeletePayload.__name__]) @setup_required @login_required @account_initialization_required def delete(self): - args = parser_mcp_delete.parse_args() + payload = MCPProviderDeletePayload.model_validate(console_ns.payload or {}) _, current_tenant_id = current_account_with_tenant() with Session(db.engine) as session, session.begin(): service = MCPToolManageService(session=session) - service.delete_provider(tenant_id=current_tenant_id, provider_id=args["provider_id"]) + service.delete_provider(tenant_id=current_tenant_id, provider_id=payload.provider_id) return {"result": "success"} -parser_auth = ( - reqparse.RequestParser() - .add_argument("provider_id", type=str, required=True, nullable=False, location="json") - .add_argument("authorization_code", type=str, required=False, nullable=True, location="json") -) - - @console_ns.route("/workspaces/current/tool-provider/mcp/auth") class ToolMCPAuthApi(Resource): - @console_ns.expect(parser_auth) + @console_ns.expect(console_ns.models[MCPAuthPayload.__name__]) @setup_required @login_required @account_initialization_required def post(self): - args = parser_auth.parse_args() - provider_id = args["provider_id"] + payload = MCPAuthPayload.model_validate(console_ns.payload or {}) + provider_id = payload.provider_id _, tenant_id = current_account_with_tenant() with Session(db.engine) as session, session.begin(): @@ -1102,7 +1113,7 @@ class ToolMCPAuthApi(Resource): # Pass the extracted OAuth metadata hints to auth() auth_result = auth( provider_entity, - args.get("authorization_code"), + payload.authorization_code, resource_metadata_url=e.resource_metadata_url, scope_hint=e.scope_hint, ) @@ -1167,20 +1178,13 @@ class ToolMCPUpdateApi(Resource): return jsonable_encoder(tools) -parser_cb = ( - reqparse.RequestParser() - .add_argument("code", type=str, required=True, nullable=False, location="args") - .add_argument("state", type=str, required=True, nullable=False, location="args") -) - - @console_ns.route("/mcp/oauth/callback") class ToolMCPCallbackApi(Resource): - @console_ns.expect(parser_cb) def get(self): - args = parser_cb.parse_args() - state_key = args["state"] - authorization_code = args["code"] + raw_args = request.args.to_dict() + query = MCPCallbackQuery.model_validate(raw_args) + state_key = query.state + authorization_code = query.code # Create service instance for handle_callback with Session(db.engine) as session, session.begin(): diff --git a/api/controllers/service_api/app/annotation.py b/api/controllers/service_api/app/annotation.py index 85ac9336d6..ef254ca357 100644 --- a/api/controllers/service_api/app/annotation.py +++ b/api/controllers/service_api/app/annotation.py @@ -1,16 +1,16 @@ from typing import Literal from flask import request -from flask_restx import Namespace, Resource, fields +from flask_restx import Resource from flask_restx.api import HTTPStatus -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, TypeAdapter from controllers.common.schema import register_schema_models from controllers.console.wraps import edit_permission_required from controllers.service_api import service_api_ns from controllers.service_api.wraps import validate_app_token from extensions.ext_redis import redis_client -from fields.annotation_fields import annotation_fields, build_annotation_model +from fields.annotation_fields import Annotation, AnnotationList from models.model import App from services.annotation_service import AppAnnotationService @@ -26,7 +26,9 @@ class AnnotationReplyActionPayload(BaseModel): embedding_model_name: str = Field(description="Embedding model name") -register_schema_models(service_api_ns, AnnotationCreatePayload, AnnotationReplyActionPayload) +register_schema_models( + service_api_ns, AnnotationCreatePayload, AnnotationReplyActionPayload, Annotation, AnnotationList +) @service_api_ns.route("/apps/annotation-reply/") @@ -45,10 +47,11 @@ class AnnotationReplyActionApi(Resource): def post(self, app_model: App, action: Literal["enable", "disable"]): """Enable or disable annotation reply feature.""" args = AnnotationReplyActionPayload.model_validate(service_api_ns.payload or {}).model_dump() - if action == "enable": - result = AppAnnotationService.enable_app_annotation(args, app_model.id) - elif action == "disable": - result = AppAnnotationService.disable_app_annotation(app_model.id) + match action: + case "enable": + result = AppAnnotationService.enable_app_annotation(args, app_model.id) + case "disable": + result = AppAnnotationService.disable_app_annotation(app_model.id) return result, 200 @@ -82,23 +85,6 @@ class AnnotationReplyActionStatusApi(Resource): return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200 -# Define annotation list response model -annotation_list_fields = { - "data": fields.List(fields.Nested(annotation_fields)), - "has_more": fields.Boolean, - "limit": fields.Integer, - "total": fields.Integer, - "page": fields.Integer, -} - - -def build_annotation_list_model(api_or_ns: Namespace): - """Build the annotation list model for the API or Namespace.""" - copied_annotation_list_fields = annotation_list_fields.copy() - copied_annotation_list_fields["data"] = fields.List(fields.Nested(build_annotation_model(api_or_ns))) - return api_or_ns.model("AnnotationList", copied_annotation_list_fields) - - @service_api_ns.route("/apps/annotations") class AnnotationListApi(Resource): @service_api_ns.doc("list_annotations") @@ -109,8 +95,12 @@ class AnnotationListApi(Resource): 401: "Unauthorized - invalid API token", } ) + @service_api_ns.response( + 200, + "Annotations retrieved successfully", + service_api_ns.models[AnnotationList.__name__], + ) @validate_app_token - @service_api_ns.marshal_with(build_annotation_list_model(service_api_ns)) def get(self, app_model: App): """List annotations for the application.""" page = request.args.get("page", default=1, type=int) @@ -118,13 +108,15 @@ class AnnotationListApi(Resource): keyword = request.args.get("keyword", default="", type=str) annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_model.id, page, limit, keyword) - return { - "data": annotation_list, - "has_more": len(annotation_list) == limit, - "limit": limit, - "total": total, - "page": page, - } + annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True) + response = AnnotationList( + data=annotation_models, + has_more=len(annotation_list) == limit, + limit=limit, + total=total, + page=page, + ) + return response.model_dump(mode="json") @service_api_ns.expect(service_api_ns.models[AnnotationCreatePayload.__name__]) @service_api_ns.doc("create_annotation") @@ -135,13 +127,18 @@ class AnnotationListApi(Resource): 401: "Unauthorized - invalid API token", } ) + @service_api_ns.response( + HTTPStatus.CREATED, + "Annotation created successfully", + service_api_ns.models[Annotation.__name__], + ) @validate_app_token - @service_api_ns.marshal_with(build_annotation_model(service_api_ns), code=HTTPStatus.CREATED) def post(self, app_model: App): """Create a new annotation.""" args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump() annotation = AppAnnotationService.insert_app_annotation_directly(args, app_model.id) - return annotation, 201 + response = Annotation.model_validate(annotation, from_attributes=True) + return response.model_dump(mode="json"), HTTPStatus.CREATED @service_api_ns.route("/apps/annotations/") @@ -158,14 +155,19 @@ class AnnotationUpdateDeleteApi(Resource): 404: "Annotation not found", } ) + @service_api_ns.response( + 200, + "Annotation updated successfully", + service_api_ns.models[Annotation.__name__], + ) @validate_app_token @edit_permission_required - @service_api_ns.marshal_with(build_annotation_model(service_api_ns)) def put(self, app_model: App, annotation_id: str): """Update an existing annotation.""" args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump() annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id) - return annotation + response = Annotation.model_validate(annotation, from_attributes=True) + return response.model_dump(mode="json") @service_api_ns.doc("delete_annotation") @service_api_ns.doc(description="Delete an annotation") diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index b3836f3a47..9d8431f066 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -30,6 +30,7 @@ from core.errors.error import ( from core.helper.trace_id_helper import get_external_trace_id from core.model_runtime.errors.invoke import InvokeError from libs import helper +from libs.helper import UUIDStrOrEmpty from models.model import App, AppMode, EndUser from services.app_generate_service import AppGenerateService from services.app_task_service import AppTaskService @@ -52,7 +53,7 @@ class ChatRequestPayload(BaseModel): query: str files: list[dict[str, Any]] | None = None response_mode: Literal["blocking", "streaming"] | None = None - conversation_id: str | None = Field(default=None, description="Conversation UUID") + conversation_id: UUIDStrOrEmpty | None = Field(default=None, description="Conversation UUID") retriever_from: str = Field(default="dev") auto_generate_name: bool = Field(default=True, description="Auto generate conversation name") workflow_id: str | None = Field(default=None, description="Workflow ID for advanced chat") diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index 62e8258e25..8e29c9ff0f 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -1,5 +1,4 @@ from typing import Any, Literal -from uuid import UUID from flask import request from flask_restx import Resource @@ -23,12 +22,13 @@ from fields.conversation_variable_fields import ( build_conversation_variable_infinite_scroll_pagination_model, build_conversation_variable_model, ) +from libs.helper import UUIDStrOrEmpty from models.model import App, AppMode, EndUser from services.conversation_service import ConversationService class ConversationListQuery(BaseModel): - last_id: UUID | None = Field(default=None, description="Last conversation ID for pagination") + last_id: UUIDStrOrEmpty | None = Field(default=None, description="Last conversation ID for pagination") limit: int = Field(default=20, ge=1, le=100, description="Number of conversations to return") sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = Field( default="-updated_at", description="Sort order for conversations" @@ -48,7 +48,7 @@ class ConversationRenamePayload(BaseModel): class ConversationVariablesQuery(BaseModel): - last_id: UUID | None = Field(default=None, description="Last variable ID for pagination") + last_id: UUIDStrOrEmpty | None = Field(default=None, description="Last variable ID for pagination") limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return") variable_name: str | None = Field( default=None, description="Filter variables by name", min_length=1, max_length=255 diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index 8981bbd7d5..2aaf920efb 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -1,6 +1,5 @@ import logging from typing import Literal -from uuid import UUID from flask import request from flask_restx import Resource @@ -15,6 +14,7 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate from core.app.entities.app_invoke_entities import InvokeFrom from fields.conversation_fields import ResultResponse from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem +from libs.helper import UUIDStrOrEmpty from models.model import App, AppMode, EndUser from services.errors.message import ( FirstMessageNotExistsError, @@ -27,8 +27,8 @@ logger = logging.getLogger(__name__) class MessageListQuery(BaseModel): - conversation_id: UUID - first_id: UUID | None = None + conversation_id: UUIDStrOrEmpty + first_id: UUIDStrOrEmpty | None = None limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return") diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index 6088b142c2..6a549fc926 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -33,9 +33,8 @@ from core.workflow.graph_engine.manager import GraphEngineManager from extensions.ext_database import db from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model from libs import helper -from libs.helper import OptionalTimestampField, TimestampField +from libs.helper import TimestampField from models.model import App, AppMode, EndUser -from models.workflow import WorkflowRun from repositories.factory import DifyAPIRepositoryFactory from services.app_generate_service import AppGenerateService from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError @@ -64,32 +63,17 @@ class WorkflowLogQuery(BaseModel): register_schema_models(service_api_ns, WorkflowRunPayload, WorkflowLogQuery) - -class WorkflowRunStatusField(fields.Raw): - def output(self, key, obj: WorkflowRun, **kwargs): - return obj.status.value - - -class WorkflowRunOutputsField(fields.Raw): - def output(self, key, obj: WorkflowRun, **kwargs): - if obj.status == WorkflowExecutionStatus.PAUSED: - return {} - - outputs = obj.outputs_dict - return outputs or {} - - workflow_run_fields = { "id": fields.String, "workflow_id": fields.String, - "status": WorkflowRunStatusField, + "status": fields.String, "inputs": fields.Raw, - "outputs": WorkflowRunOutputsField, + "outputs": fields.Raw, "error": fields.String, "total_steps": fields.Integer, "total_tokens": fields.Integer, "created_at": TimestampField, - "finished_at": OptionalTimestampField, + "finished_at": TimestampField, "elapsed_time": fields.Float, } diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index c11f64585a..db5cabe8aa 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -17,7 +17,7 @@ from controllers.service_api.wraps import ( from core.model_runtime.entities.model_entities import ModelType from core.provider_manager import ProviderManager from fields.dataset_fields import dataset_detail_fields -from fields.tag_fields import build_dataset_tag_fields +from fields.tag_fields import DataSetTag from libs.login import current_user from models.account import Account from models.dataset import DatasetPermissionEnum @@ -114,6 +114,7 @@ register_schema_models( TagBindingPayload, TagUnbindingPayload, DatasetListQuery, + DataSetTag, ) @@ -480,15 +481,14 @@ class DatasetTagsApi(DatasetApiResource): 401: "Unauthorized - invalid API token", } ) - @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns)) def get(self, _): """Get all knowledge type tags.""" assert isinstance(current_user, Account) cid = current_user.current_tenant_id assert cid is not None tags = TagService.get_tags("knowledge", cid) - - return tags, 200 + tag_models = TypeAdapter(list[DataSetTag]).validate_python(tags, from_attributes=True) + return [tag.model_dump(mode="json") for tag in tag_models], 200 @service_api_ns.expect(service_api_ns.models[TagCreatePayload.__name__]) @service_api_ns.doc("create_dataset_tag") @@ -500,7 +500,6 @@ class DatasetTagsApi(DatasetApiResource): 403: "Forbidden - insufficient permissions", } ) - @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns)) def post(self, _): """Add a knowledge type tag.""" assert isinstance(current_user, Account) @@ -510,7 +509,9 @@ class DatasetTagsApi(DatasetApiResource): payload = TagCreatePayload.model_validate(service_api_ns.payload or {}) tag = TagService.save_tags({"name": payload.name, "type": "knowledge"}) - response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} + response = DataSetTag.model_validate( + {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} + ).model_dump(mode="json") return response, 200 @service_api_ns.expect(service_api_ns.models[TagUpdatePayload.__name__]) @@ -523,7 +524,6 @@ class DatasetTagsApi(DatasetApiResource): 403: "Forbidden - insufficient permissions", } ) - @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns)) def patch(self, _): assert isinstance(current_user, Account) if not (current_user.has_edit_permission or current_user.is_dataset_editor): @@ -536,8 +536,9 @@ class DatasetTagsApi(DatasetApiResource): binding_count = TagService.get_tag_binding_count(tag_id) - response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count} - + response = DataSetTag.model_validate( + {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count} + ).model_dump(mode="json") return response, 200 @service_api_ns.expect(service_api_ns.models[TagDeletePayload.__name__]) diff --git a/api/controllers/service_api/dataset/hit_testing.py b/api/controllers/service_api/dataset/hit_testing.py index 8dbb690901..97a70f5d0e 100644 --- a/api/controllers/service_api/dataset/hit_testing.py +++ b/api/controllers/service_api/dataset/hit_testing.py @@ -1,7 +1,10 @@ -from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase +from controllers.common.schema import register_schema_model +from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase, HitTestingPayload from controllers.service_api import service_api_ns from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check +register_schema_model(service_api_ns, HitTestingPayload) + @service_api_ns.route("/datasets//hit-testing", "/datasets//retrieve") class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase): @@ -15,6 +18,7 @@ class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase): 404: "Dataset not found", } ) + @service_api_ns.expect(service_api_ns.models[HitTestingPayload.__name__]) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id): """Perform hit testing on a dataset. diff --git a/api/controllers/service_api/dataset/metadata.py b/api/controllers/service_api/dataset/metadata.py index b8d9508004..692342a38a 100644 --- a/api/controllers/service_api/dataset/metadata.py +++ b/api/controllers/service_api/dataset/metadata.py @@ -168,10 +168,11 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource): raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user) - if action == "enable": - MetadataService.enable_built_in_field(dataset) - elif action == "disable": - MetadataService.disable_built_in_field(dataset) + match action: + case "enable": + MetadataService.enable_built_in_field(dataset) + case "disable": + MetadataService.disable_built_in_field(dataset) return {"result": "success"}, 200 diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index 24acced0d1..e597a72fc0 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -73,14 +73,14 @@ def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: Fe # If caller needs end-user context, attach EndUser to current_user if fetch_user_arg: - if fetch_user_arg.fetch_from == WhereisUserArg.QUERY: - user_id = request.args.get("user") - elif fetch_user_arg.fetch_from == WhereisUserArg.JSON: - user_id = request.get_json().get("user") - elif fetch_user_arg.fetch_from == WhereisUserArg.FORM: - user_id = request.form.get("user") - else: - user_id = None + user_id = None + match fetch_user_arg.fetch_from: + case WhereisUserArg.QUERY: + user_id = request.args.get("user") + case WhereisUserArg.JSON: + user_id = request.get_json().get("user") + case WhereisUserArg.FORM: + user_id = request.form.get("user") if not user_id and fetch_user_arg.required: raise ValueError("Arg user must be provided.") diff --git a/api/controllers/web/__init__.py b/api/controllers/web/__init__.py index cfa39e0dfd..1d22954308 100644 --- a/api/controllers/web/__init__.py +++ b/api/controllers/web/__init__.py @@ -23,7 +23,6 @@ from . import ( feature, files, forgot_password, - human_input_form, login, message, passport, @@ -31,7 +30,6 @@ from . import ( saved_message, site, workflow, - workflow_events, ) api.add_namespace(web_ns) @@ -46,7 +44,6 @@ __all__ = [ "feature", "files", "forgot_password", - "human_input_form", "login", "message", "passport", @@ -55,5 +52,4 @@ __all__ = [ "site", "web_ns", "workflow", - "workflow_events", ] diff --git a/api/controllers/web/error.py b/api/controllers/web/error.py index d1f936768e..196a27e348 100644 --- a/api/controllers/web/error.py +++ b/api/controllers/web/error.py @@ -117,12 +117,6 @@ class InvokeRateLimitError(BaseHTTPException): code = 429 -class WebFormRateLimitExceededError(BaseHTTPException): - error_code = "web_form_rate_limit_exceeded" - description = "Too many form requests. Please try again later." - code = 429 - - class NotFoundError(BaseHTTPException): error_code = "not_found" code = 404 diff --git a/api/controllers/web/human_input_form.py b/api/controllers/web/human_input_form.py deleted file mode 100644 index c3989b1965..0000000000 --- a/api/controllers/web/human_input_form.py +++ /dev/null @@ -1,164 +0,0 @@ -""" -Web App Human Input Form APIs. -""" - -import json -import logging -from datetime import datetime - -from flask import Response, request -from flask_restx import Resource, reqparse -from werkzeug.exceptions import Forbidden - -from configs import dify_config -from controllers.web import web_ns -from controllers.web.error import NotFoundError, WebFormRateLimitExceededError -from controllers.web.site import serialize_app_site_payload -from extensions.ext_database import db -from libs.helper import RateLimiter, extract_remote_ip -from models.account import TenantStatus -from models.model import App, Site -from services.human_input_service import Form, FormNotFoundError, HumanInputService - -logger = logging.getLogger(__name__) - -_FORM_SUBMIT_RATE_LIMITER = RateLimiter( - prefix="web_form_submit_rate_limit", - max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS, - time_window=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_WINDOW_SECONDS, -) -_FORM_ACCESS_RATE_LIMITER = RateLimiter( - prefix="web_form_access_rate_limit", - max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS, - time_window=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_WINDOW_SECONDS, -) - - -def _stringify_default_values(values: dict[str, object]) -> dict[str, str]: - result: dict[str, str] = {} - for key, value in values.items(): - if value is None: - result[key] = "" - elif isinstance(value, (dict, list)): - result[key] = json.dumps(value, ensure_ascii=False) - else: - result[key] = str(value) - return result - - -def _to_timestamp(value: datetime) -> int: - return int(value.timestamp()) - - -def _jsonify_form_definition(form: Form, site_payload: dict | None = None) -> Response: - """Return the form payload (optionally with site) as a JSON response.""" - definition_payload = form.get_definition().model_dump() - payload = { - "form_content": definition_payload["rendered_content"], - "inputs": definition_payload["inputs"], - "resolved_default_values": _stringify_default_values(definition_payload["default_values"]), - "user_actions": definition_payload["user_actions"], - "expiration_time": _to_timestamp(form.expiration_time), - } - if site_payload is not None: - payload["site"] = site_payload - return Response(json.dumps(payload, ensure_ascii=False), mimetype="application/json") - - -# TODO(QuantumGhost): disable authorization for web app -# form api temporarily - - -@web_ns.route("/form/human_input/") -# class HumanInputFormApi(WebApiResource): -class HumanInputFormApi(Resource): - """API for getting and submitting human input forms via the web app.""" - - # def get(self, _app_model: App, _end_user: EndUser, form_token: str): - def get(self, form_token: str): - """ - Get human input form definition by token. - - GET /api/form/human_input/ - """ - ip_address = extract_remote_ip(request) - if _FORM_ACCESS_RATE_LIMITER.is_rate_limited(ip_address): - raise WebFormRateLimitExceededError() - _FORM_ACCESS_RATE_LIMITER.increment_rate_limit(ip_address) - - service = HumanInputService(db.engine) - # TODO(QuantumGhost): forbid submision for form tokens - # that are only for console. - form = service.get_form_by_token(form_token) - - if form is None: - raise NotFoundError("Form not found") - - service.ensure_form_active(form) - app_model, site = _get_app_site_from_form(form) - - return _jsonify_form_definition(form, site_payload=serialize_app_site_payload(app_model, site, None)) - - # def post(self, _app_model: App, _end_user: EndUser, form_token: str): - def post(self, form_token: str): - """ - Submit human input form by token. - - POST /api/form/human_input/ - - Request body: - { - "inputs": { - "content": "User input content" - }, - "action": "Approve" - } - """ - parser = reqparse.RequestParser() - parser.add_argument("inputs", type=dict, required=True, location="json") - parser.add_argument("action", type=str, required=True, location="json") - args = parser.parse_args() - - ip_address = extract_remote_ip(request) - if _FORM_SUBMIT_RATE_LIMITER.is_rate_limited(ip_address): - raise WebFormRateLimitExceededError() - _FORM_SUBMIT_RATE_LIMITER.increment_rate_limit(ip_address) - - service = HumanInputService(db.engine) - form = service.get_form_by_token(form_token) - if form is None: - raise NotFoundError("Form not found") - - if (recipient_type := form.recipient_type) is None: - logger.warning("Recipient type is None for form, form_id=%", form.id) - raise AssertionError("Recipient type is None") - - try: - service.submit_form_by_token( - recipient_type=recipient_type, - form_token=form_token, - selected_action_id=args["action"], - form_data=args["inputs"], - submission_end_user_id=None, - # submission_end_user_id=_end_user.id, - ) - except FormNotFoundError: - raise NotFoundError("Form not found") - - return {}, 200 - - -def _get_app_site_from_form(form: Form) -> tuple[App, Site]: - """Resolve App/Site for the form's app and validate tenant status.""" - app_model = db.session.query(App).where(App.id == form.app_id).first() - if app_model is None or app_model.tenant_id != form.tenant_id: - raise NotFoundError("Form not found") - - site = db.session.query(Site).where(Site.app_id == app_model.id).first() - if site is None: - raise Forbidden() - - if app_model.tenant and app_model.tenant.status == TenantStatus.ARCHIVE: - raise Forbidden() - - return app_model, site diff --git a/api/controllers/web/site.py b/api/controllers/web/site.py index f957229ece..b01aaba357 100644 --- a/api/controllers/web/site.py +++ b/api/controllers/web/site.py @@ -1,6 +1,4 @@ -from typing import cast - -from flask_restx import fields, marshal, marshal_with +from flask_restx import fields, marshal_with from werkzeug.exceptions import Forbidden from configs import dify_config @@ -9,7 +7,7 @@ from controllers.web.wraps import WebApiResource from extensions.ext_database import db from libs.helper import AppIconUrlField from models.account import TenantStatus -from models.model import App, Site +from models.model import Site from services.feature_service import FeatureService @@ -110,14 +108,3 @@ class AppSiteInfo: "remove_webapp_brand": remove_webapp_brand, "replace_webapp_logo": replace_webapp_logo, } - - -def serialize_site(site: Site) -> dict: - """Serialize Site model using the same schema as AppSiteApi.""" - return cast(dict, marshal(site, AppSiteApi.site_fields)) - - -def serialize_app_site_payload(app_model: App, site: Site, end_user_id: str | None) -> dict: - can_replace_logo = FeatureService.get_features(app_model.tenant_id).can_replace_logo - app_site_info = AppSiteInfo(app_model.tenant, app_model, site, end_user_id, can_replace_logo) - return cast(dict, marshal(app_site_info, AppSiteApi.app_fields)) diff --git a/api/controllers/web/workflow_events.py b/api/controllers/web/workflow_events.py deleted file mode 100644 index 61568e70e6..0000000000 --- a/api/controllers/web/workflow_events.py +++ /dev/null @@ -1,112 +0,0 @@ -""" -Web App Workflow Resume APIs. -""" - -import json -from collections.abc import Generator - -from flask import Response, request -from sqlalchemy.orm import sessionmaker - -from controllers.web import api -from controllers.web.error import InvalidArgumentError, NotFoundError -from controllers.web.wraps import WebApiResource -from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator -from core.app.apps.base_app_generator import BaseAppGenerator -from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter -from core.app.apps.message_generator import MessageGenerator -from core.app.apps.workflow.app_generator import WorkflowAppGenerator -from extensions.ext_database import db -from models.enums import CreatorUserRole -from models.model import App, AppMode, EndUser -from repositories.factory import DifyAPIRepositoryFactory -from services.workflow_event_snapshot_service import build_workflow_event_stream - - -class WorkflowEventsApi(WebApiResource): - """API for getting workflow execution events after resume.""" - - def get(self, app_model: App, end_user: EndUser, task_id: str): - """ - Get workflow execution events stream after resume. - - GET /api/workflow//events - - Returns Server-Sent Events stream. - """ - workflow_run_id = task_id - session_maker = sessionmaker(db.engine) - repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) - workflow_run = repo.get_workflow_run_by_id_and_tenant_id( - tenant_id=app_model.tenant_id, - run_id=workflow_run_id, - ) - - if workflow_run is None: - raise NotFoundError(f"WorkflowRun not found, id={workflow_run_id}") - - if workflow_run.app_id != app_model.id: - raise NotFoundError(f"WorkflowRun not found, id={workflow_run_id}") - - if workflow_run.created_by_role != CreatorUserRole.END_USER: - raise NotFoundError(f"WorkflowRun not created by end user, id={workflow_run_id}") - - if workflow_run.created_by != end_user.id: - raise NotFoundError(f"WorkflowRun not created by the current end user, id={workflow_run_id}") - - if workflow_run.finished_at is not None: - response = WorkflowResponseConverter.workflow_run_result_to_finish_response( - task_id=workflow_run.id, - workflow_run=workflow_run, - creator_user=end_user, - ) - - payload = response.model_dump(mode="json") - payload["event"] = response.event.value - - def _generate_finished_events() -> Generator[str, None, None]: - yield f"data: {json.dumps(payload)}\n\n" - - event_generator = _generate_finished_events - else: - app_mode = AppMode.value_of(app_model.mode) - msg_generator = MessageGenerator() - generator: BaseAppGenerator - if app_mode == AppMode.ADVANCED_CHAT: - generator = AdvancedChatAppGenerator() - elif app_mode == AppMode.WORKFLOW: - generator = WorkflowAppGenerator() - else: - raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}") - - include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true" - - def _generate_stream_events(): - if include_state_snapshot: - return generator.convert_to_event_stream( - build_workflow_event_stream( - app_mode=app_mode, - workflow_run=workflow_run, - tenant_id=app_model.tenant_id, - app_id=app_model.id, - session_maker=session_maker, - ) - ) - return generator.convert_to_event_stream( - msg_generator.retrieve_events(app_mode, workflow_run.id), - ) - - event_generator = _generate_stream_events - - return Response( - event_generator(), - mimetype="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - }, - ) - - -# Register the APIs -api.add_resource(WorkflowEventsApi, "/workflow//events") diff --git a/api/core/app/app_config/easy_ui_based_app/agent/manager.py b/api/core/app/app_config/easy_ui_based_app/agent/manager.py index c1f336fdde..9b981dfc09 100644 --- a/api/core/app/app_config/easy_ui_based_app/agent/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/agent/manager.py @@ -14,16 +14,17 @@ class AgentConfigManager: agent_dict = config.get("agent_mode", {}) agent_strategy = agent_dict.get("strategy", "cot") - if agent_strategy == "function_call": - strategy = AgentEntity.Strategy.FUNCTION_CALLING - elif agent_strategy in {"cot", "react"}: - strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT - else: - # old configs, try to detect default strategy - if config["model"]["provider"] == "openai": + match agent_strategy: + case "function_call": strategy = AgentEntity.Strategy.FUNCTION_CALLING - else: + case "cot" | "react": strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT + case _: + # old configs, try to detect default strategy + if config["model"]["provider"] == "openai": + strategy = AgentEntity.Strategy.FUNCTION_CALLING + else: + strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT agent_tools = [] for tool in agent_dict.get("tools", []): diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 2891d3ceeb..528c45f6c8 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -4,8 +4,8 @@ import contextvars import logging import threading import uuid -from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union, overload +from collections.abc import Generator, Mapping +from typing import TYPE_CHECKING, Any, Literal, Union, overload from flask import Flask, current_app from pydantic import ValidationError @@ -29,25 +29,21 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse -from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer from core.helper.trace_id_helper import extract_external_trace_id_from_args from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.ops.ops_trace_manager import TraceQueueManager from core.prompt.utils.get_thread_messages_length import get_thread_messages_length from core.repositories import DifyCoreRepositoryFactory -from core.workflow.graph_engine.layers.base import GraphEngineLayer from core.workflow.repositories.draft_variable_repository import ( DraftVariableSaverFactory, ) from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from core.workflow.runtime import GraphRuntimeState from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from extensions.ext_database import db from factories import file_factory from libs.flask_utils import preserve_flask_contexts from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom -from models.base import Base from models.enums import WorkflowRunTriggeredFrom from services.conversation_service import ConversationService from services.workflow_draft_variable_service import ( @@ -69,9 +65,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): user: Union[Account, EndUser], args: Mapping[str, Any], invoke_from: InvokeFrom, - workflow_run_id: str, streaming: Literal[False], - pause_state_config: PauseStateLayerConfig | None = None, ) -> Mapping[str, Any]: ... @overload @@ -80,11 +74,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): app_model: App, workflow: Workflow, user: Union[Account, EndUser], - args: Mapping[str, Any], + args: Mapping, invoke_from: InvokeFrom, - workflow_run_id: str, streaming: Literal[True], - pause_state_config: PauseStateLayerConfig | None = None, ) -> Generator[Mapping | str, None, None]: ... @overload @@ -93,11 +85,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): app_model: App, workflow: Workflow, user: Union[Account, EndUser], - args: Mapping[str, Any], + args: Mapping, invoke_from: InvokeFrom, - workflow_run_id: str, streaming: bool, - pause_state_config: PauseStateLayerConfig | None = None, ) -> Mapping[str, Any] | Generator[str | Mapping, None, None]: ... def generate( @@ -105,11 +95,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): app_model: App, workflow: Workflow, user: Union[Account, EndUser], - args: Mapping[str, Any], + args: Mapping, invoke_from: InvokeFrom, - workflow_run_id: str, streaming: bool = True, - pause_state_config: PauseStateLayerConfig | None = None, ) -> Mapping[str, Any] | Generator[str | Mapping, None, None]: """ Generate App response. @@ -173,6 +161,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): # always enable retriever resource in debugger mode app_config.additional_features.show_retrieve_source = True # type: ignore + workflow_run_id = str(uuid.uuid4()) # init application generate entity application_generate_entity = AdvancedChatAppGenerateEntity( task_id=str(uuid.uuid4()), @@ -190,7 +179,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): invoke_from=invoke_from, extras=extras, trace_manager=trace_manager, - workflow_run_id=str(workflow_run_id), + workflow_run_id=workflow_run_id, ) contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers_lock.set(threading.Lock()) @@ -227,38 +216,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): workflow_node_execution_repository=workflow_node_execution_repository, conversation=conversation, stream=streaming, - pause_state_config=pause_state_config, - ) - - def resume( - self, - *, - app_model: App, - workflow: Workflow, - user: Union[Account, EndUser], - conversation: Conversation, - message: Message, - application_generate_entity: AdvancedChatAppGenerateEntity, - workflow_execution_repository: WorkflowExecutionRepository, - workflow_node_execution_repository: WorkflowNodeExecutionRepository, - graph_runtime_state: GraphRuntimeState, - pause_state_config: PauseStateLayerConfig | None = None, - ): - """ - Resume a paused advanced chat execution. - """ - return self._generate( - workflow=workflow, - user=user, - invoke_from=application_generate_entity.invoke_from, - application_generate_entity=application_generate_entity, - workflow_execution_repository=workflow_execution_repository, - workflow_node_execution_repository=workflow_node_execution_repository, - conversation=conversation, - message=message, - stream=application_generate_entity.stream, - pause_state_config=pause_state_config, - graph_runtime_state=graph_runtime_state, ) def single_iteration_generate( @@ -439,12 +396,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): workflow_execution_repository: WorkflowExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository, conversation: Conversation | None = None, - message: Message | None = None, stream: bool = True, variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER, - pause_state_config: PauseStateLayerConfig | None = None, - graph_runtime_state: GraphRuntimeState | None = None, - graph_engine_layers: Sequence[GraphEngineLayer] = (), ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]: """ Generate App response. @@ -458,12 +411,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): :param conversation: conversation :param stream: is stream """ - is_first_conversation = conversation is None + is_first_conversation = False + if not conversation: + is_first_conversation = True - if conversation is not None and message is not None: - pass - else: - conversation, message = self._init_generate_records(application_generate_entity, conversation) + # init generate records + (conversation, message) = self._init_generate_records(application_generate_entity, conversation) if is_first_conversation: # update conversation features @@ -486,16 +439,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): message_id=message.id, ) - graph_layers: list[GraphEngineLayer] = list(graph_engine_layers) - if pause_state_config is not None: - graph_layers.append( - PauseStatePersistenceLayer( - session_factory=pause_state_config.session_factory, - generate_entity=application_generate_entity, - state_owner_user_id=pause_state_config.state_owner_user_id, - ) - ) - # new thread with request context and contextvars context = contextvars.copy_context() @@ -511,25 +454,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): "variable_loader": variable_loader, "workflow_execution_repository": workflow_execution_repository, "workflow_node_execution_repository": workflow_node_execution_repository, - "graph_engine_layers": tuple(graph_layers), - "graph_runtime_state": graph_runtime_state, }, ) worker_thread.start() # release database connection, because the following new thread operations may take a long time - with Session(bind=db.engine, expire_on_commit=False) as session: - workflow = _refresh_model(session, workflow) - message = _refresh_model(session, message) - # workflow_ = session.get(Workflow, workflow.id) - # assert workflow_ is not None - # workflow = workflow_ - # message_ = session.get(Message, message.id) - # assert message_ is not None - # message = message_ - # db.session.refresh(workflow) - # db.session.refresh(message) + db.session.refresh(workflow) + db.session.refresh(message) # db.session.refresh(user) db.session.close() @@ -558,8 +490,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): variable_loader: VariableLoader, workflow_execution_repository: WorkflowExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository, - graph_engine_layers: Sequence[GraphEngineLayer] = (), - graph_runtime_state: GraphRuntimeState | None = None, ): """ Generate worker in a new thread. @@ -617,8 +547,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): app=app, workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, - graph_engine_layers=graph_engine_layers, - graph_runtime_state=graph_runtime_state, ) try: @@ -686,13 +614,3 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): else: logger.exception("Failed to process generate task pipeline, conversation_id: %s", conversation.id) raise e - - -_T = TypeVar("_T", bound=Base) - - -def _refresh_model(session, model: _T) -> _T: - with Session(bind=db.engine, expire_on_commit=False) as session: - detach_model = session.get(type(model), model.id) - assert detach_model is not None - return detach_model diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 8b20442eab..d702db0908 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -66,7 +66,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): workflow_execution_repository: WorkflowExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository, graph_engine_layers: Sequence[GraphEngineLayer] = (), - graph_runtime_state: GraphRuntimeState | None = None, ): super().__init__( queue_manager=queue_manager, @@ -83,7 +82,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): self._app = app self._workflow_execution_repository = workflow_execution_repository self._workflow_node_execution_repository = workflow_node_execution_repository - self._resume_graph_runtime_state = graph_runtime_state @trace_span(WorkflowAppRunnerHandler) def run(self): @@ -112,21 +110,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): invoke_from = InvokeFrom.DEBUGGER user_from = self._resolve_user_from(invoke_from) - resume_state = self._resume_graph_runtime_state - - if resume_state is not None: - graph_runtime_state = resume_state - variable_pool = graph_runtime_state.variable_pool - graph = self._init_graph( - graph_config=self._workflow.graph_dict, - graph_runtime_state=graph_runtime_state, - workflow_id=self._workflow.id, - tenant_id=self._workflow.tenant_id, - user_id=self.application_generate_entity.user_id, - invoke_from=invoke_from, - user_from=user_from, - ) - elif self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run: + if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run: # Handle single iteration or single loop run graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution( workflow=self._workflow, diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 00a6a3d9af..da1e9f19b6 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -24,8 +24,6 @@ from core.app.entities.queue_entities import ( QueueAgentLogEvent, QueueAnnotationReplyEvent, QueueErrorEvent, - QueueHumanInputFormFilledEvent, - QueueHumanInputFormTimeoutEvent, QueueIterationCompletedEvent, QueueIterationNextEvent, QueueIterationStartEvent, @@ -44,7 +42,6 @@ from core.app.entities.queue_entities import ( QueueTextChunkEvent, QueueWorkflowFailedEvent, QueueWorkflowPartialSuccessEvent, - QueueWorkflowPausedEvent, QueueWorkflowStartedEvent, QueueWorkflowSucceededEvent, WorkflowQueueMessage, @@ -66,8 +63,6 @@ from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.utils.encoders import jsonable_encoder from core.ops.ops_trace_manager import TraceQueueManager -from core.repositories.human_input_repository import HumanInputFormRepositoryImpl -from core.workflow.entities.pause_reason import HumanInputRequired from core.workflow.enums import WorkflowExecutionStatus from core.workflow.nodes import NodeType from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory @@ -76,8 +71,7 @@ from core.workflow.system_variable import SystemVariable from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models import Account, Conversation, EndUser, Message, MessageFile -from models.enums import CreatorUserRole, MessageStatus -from models.execution_extra_content import HumanInputContent +from models.enums import CreatorUserRole from models.workflow import Workflow logger = logging.getLogger(__name__) @@ -134,7 +128,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): ) self._task_state = WorkflowTaskState() - self._seed_task_state_from_message(message) self._message_cycle_manager = MessageCycleManager( application_generate_entity=application_generate_entity, task_state=self._task_state ) @@ -142,7 +135,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): self._application_generate_entity = application_generate_entity self._workflow_id = workflow.id self._workflow_features_dict = workflow.features_dict - self._workflow_tenant_id = workflow.tenant_id self._conversation_id = conversation.id self._conversation_mode = conversation.mode self._message_id = message.id @@ -152,13 +144,8 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): self._workflow_run_id: str = "" self._draft_var_saver_factory = draft_var_saver_factory self._graph_runtime_state: GraphRuntimeState | None = None - self._message_saved_on_pause = False self._seed_graph_runtime_state_from_queue_manager() - def _seed_task_state_from_message(self, message: Message) -> None: - if message.status == MessageStatus.PAUSED and message.answer: - self._task_state.answer = message.answer - def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]: """ Process generate task pipeline. @@ -321,7 +308,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): task_id=self._application_generate_entity.task_id, workflow_run_id=run_id, workflow_id=self._workflow_id, - reason=event.reason, ) yield workflow_start_resp @@ -539,35 +525,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): ) yield workflow_finish_resp - - def _handle_workflow_paused_event( - self, - event: QueueWorkflowPausedEvent, - **kwargs, - ) -> Generator[StreamResponse, None, None]: - """Handle workflow paused events.""" - validated_state = self._ensure_graph_runtime_initialized() - responses = self._workflow_response_converter.workflow_pause_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - graph_runtime_state=validated_state, - ) - for reason in event.reasons: - if isinstance(reason, HumanInputRequired): - self._persist_human_input_extra_content(form_id=reason.form_id, node_id=reason.node_id) - yield from responses - resolved_state: GraphRuntimeState | None = None - try: - resolved_state = self._ensure_graph_runtime_initialized() - except ValueError: - resolved_state = None - - with self._database_session() as session: - self._save_message(session=session, graph_runtime_state=resolved_state) - message = self._get_message(session=session) - if message is not None: - message.status = MessageStatus.PAUSED - self._message_saved_on_pause = True self._base_task_pipeline.queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) def _handle_workflow_failed_event( @@ -657,10 +614,9 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION, ) - # Save message unless it has already been persisted on pause. - if not self._message_saved_on_pause: - with self._database_session() as session: - self._save_message(session=session, graph_runtime_state=resolved_state) + # Save message + with self._database_session() as session: + self._save_message(session=session, graph_runtime_state=resolved_state) yield self._message_end_to_stream_response() @@ -686,65 +642,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): """Handle message replace events.""" yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text, reason=event.reason) - def _handle_human_input_form_filled_event( - self, event: QueueHumanInputFormFilledEvent, **kwargs - ) -> Generator[StreamResponse, None, None]: - """Handle human input form filled events.""" - self._persist_human_input_extra_content(node_id=event.node_id) - yield self._workflow_response_converter.human_input_form_filled_to_stream_response( - event=event, task_id=self._application_generate_entity.task_id - ) - - def _handle_human_input_form_timeout_event( - self, event: QueueHumanInputFormTimeoutEvent, **kwargs - ) -> Generator[StreamResponse, None, None]: - """Handle human input form timeout events.""" - yield self._workflow_response_converter.human_input_form_timeout_to_stream_response( - event=event, task_id=self._application_generate_entity.task_id - ) - - def _persist_human_input_extra_content(self, *, node_id: str | None = None, form_id: str | None = None) -> None: - if not self._workflow_run_id or not self._message_id: - return - - if form_id is None: - if node_id is None: - return - form_id = self._load_human_input_form_id(node_id=node_id) - if form_id is None: - logger.warning( - "HumanInput form not found for workflow run %s node %s", - self._workflow_run_id, - node_id, - ) - return - - with self._database_session() as session: - exists_stmt = select(HumanInputContent).where( - HumanInputContent.workflow_run_id == self._workflow_run_id, - HumanInputContent.message_id == self._message_id, - HumanInputContent.form_id == form_id, - ) - if session.scalar(exists_stmt) is not None: - return - - content = HumanInputContent( - workflow_run_id=self._workflow_run_id, - message_id=self._message_id, - form_id=form_id, - ) - session.add(content) - - def _load_human_input_form_id(self, *, node_id: str) -> str | None: - form_repository = HumanInputFormRepositoryImpl( - session_factory=db.engine, - tenant_id=self._workflow_tenant_id, - ) - form = form_repository.get_form(self._workflow_run_id, node_id) - if form is None: - return None - return form.id - def _handle_agent_log_event(self, event: QueueAgentLogEvent, **kwargs) -> Generator[StreamResponse, None, None]: """Handle agent log events.""" yield self._workflow_response_converter.handle_agent_log( @@ -762,7 +659,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): QueueWorkflowStartedEvent: self._handle_workflow_started_event, QueueWorkflowSucceededEvent: self._handle_workflow_succeeded_event, QueueWorkflowPartialSuccessEvent: self._handle_workflow_partial_success_event, - QueueWorkflowPausedEvent: self._handle_workflow_paused_event, QueueWorkflowFailedEvent: self._handle_workflow_failed_event, # Node events QueueNodeRetryEvent: self._handle_node_retry_event, @@ -784,8 +680,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): QueueMessageReplaceEvent: self._handle_message_replace_event, QueueAdvancedChatMessageEndEvent: self._handle_advanced_chat_message_end_event, QueueAgentLogEvent: self._handle_agent_log_event, - QueueHumanInputFormFilledEvent: self._handle_human_input_form_filled_event, - QueueHumanInputFormTimeoutEvent: self._handle_human_input_form_timeout_event, } def _dispatch_event( @@ -853,9 +747,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): case QueueWorkflowFailedEvent(): yield from self._handle_workflow_failed_event(event, trace_manager=trace_manager) break - case QueueWorkflowPausedEvent(): - yield from self._handle_workflow_paused_event(event) - break case QueueStopEvent(): yield from self._handle_stop_event(event, graph_runtime_state=None, trace_manager=trace_manager) @@ -881,11 +772,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): def _save_message(self, *, session: Session, graph_runtime_state: GraphRuntimeState | None = None): message = self._get_message(session=session) - if message is None: - return - - if message.status == MessageStatus.PAUSED: - message.status = MessageStatus.NORMAL # If there are assistant files, remove markdown image links from answer answer_text = self._task_state.answer diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py index 6d329063f8..cefff7be92 100644 --- a/api/core/app/apps/common/workflow_response_converter.py +++ b/api/core/app/apps/common/workflow_response_converter.py @@ -5,14 +5,9 @@ from dataclasses import dataclass from datetime import datetime from typing import Any, NewType, Union -from sqlalchemy import select -from sqlalchemy.orm import Session - from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.queue_entities import ( QueueAgentLogEvent, - QueueHumanInputFormFilledEvent, - QueueHumanInputFormTimeoutEvent, QueueIterationCompletedEvent, QueueIterationNextEvent, QueueIterationStartEvent, @@ -24,13 +19,9 @@ from core.app.entities.queue_entities import ( QueueNodeRetryEvent, QueueNodeStartedEvent, QueueNodeSucceededEvent, - QueueWorkflowPausedEvent, ) from core.app.entities.task_entities import ( AgentLogStreamResponse, - HumanInputFormFilledResponse, - HumanInputFormTimeoutResponse, - HumanInputRequiredResponse, IterationNodeCompletedStreamResponse, IterationNodeNextStreamResponse, IterationNodeStartStreamResponse, @@ -40,9 +31,7 @@ from core.app.entities.task_entities import ( NodeFinishStreamResponse, NodeRetryStreamResponse, NodeStartStreamResponse, - StreamResponse, WorkflowFinishStreamResponse, - WorkflowPauseStreamResponse, WorkflowStartStreamResponse, ) from core.file import FILE_MODEL_IDENTITY, File @@ -51,8 +40,6 @@ from core.tools.entities.tool_entities import ToolProviderType from core.tools.tool_manager import ToolManager from core.trigger.trigger_manager import TriggerManager from core.variables.segments import ArrayFileSegment, FileSegment, Segment -from core.workflow.entities.pause_reason import HumanInputRequired -from core.workflow.entities.workflow_start_reason import WorkflowStartReason from core.workflow.enums import ( NodeType, SystemVariableKey, @@ -64,11 +51,8 @@ from core.workflow.runtime import GraphRuntimeState from core.workflow.system_variable import SystemVariable from core.workflow.workflow_entry import WorkflowEntry from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter -from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models import Account, EndUser -from models.human_input import HumanInputForm -from models.workflow import WorkflowRun from services.variable_truncator import BaseTruncator, DummyVariableTruncator, VariableTruncator NodeExecutionId = NewType("NodeExecutionId", str) @@ -207,7 +191,6 @@ class WorkflowResponseConverter: task_id: str, workflow_run_id: str, workflow_id: str, - reason: WorkflowStartReason, ) -> WorkflowStartStreamResponse: run_id = self._ensure_workflow_run_id(workflow_run_id) started_at = naive_utc_now() @@ -221,7 +204,6 @@ class WorkflowResponseConverter: workflow_id=workflow_id, inputs=self._workflow_inputs, created_at=int(started_at.timestamp()), - reason=reason, ), ) @@ -268,7 +250,7 @@ class WorkflowResponseConverter: data=WorkflowFinishStreamResponse.Data( id=run_id, workflow_id=workflow_id, - status=status.value, + status=status, outputs=encoded_outputs, error=error, elapsed_time=elapsed_time, @@ -282,160 +264,6 @@ class WorkflowResponseConverter: ), ) - def workflow_pause_to_stream_response( - self, - *, - event: QueueWorkflowPausedEvent, - task_id: str, - graph_runtime_state: GraphRuntimeState, - ) -> list[StreamResponse]: - run_id = self._ensure_workflow_run_id() - started_at = self._workflow_started_at - if started_at is None: - raise ValueError( - "workflow_pause_to_stream_response called before workflow_start_to_stream_response", - ) - paused_at = naive_utc_now() - elapsed_time = (paused_at - started_at).total_seconds() - encoded_outputs = self._encode_outputs(event.outputs) or {} - if self._application_generate_entity.invoke_from == InvokeFrom.SERVICE_API: - encoded_outputs = {} - pause_reasons = [reason.model_dump(mode="json") for reason in event.reasons] - human_input_form_ids = [reason.form_id for reason in event.reasons if isinstance(reason, HumanInputRequired)] - expiration_times_by_form_id: dict[str, datetime] = {} - if human_input_form_ids: - stmt = select(HumanInputForm.id, HumanInputForm.expiration_time).where( - HumanInputForm.id.in_(human_input_form_ids) - ) - with Session(bind=db.engine) as session: - for form_id, expiration_time in session.execute(stmt): - expiration_times_by_form_id[str(form_id)] = expiration_time - - responses: list[StreamResponse] = [] - - for reason in event.reasons: - if isinstance(reason, HumanInputRequired): - expiration_time = expiration_times_by_form_id.get(reason.form_id) - if expiration_time is None: - raise ValueError(f"HumanInputForm not found for pause reason, form_id={reason.form_id}") - responses.append( - HumanInputRequiredResponse( - task_id=task_id, - workflow_run_id=run_id, - data=HumanInputRequiredResponse.Data( - form_id=reason.form_id, - node_id=reason.node_id, - node_title=reason.node_title, - form_content=reason.form_content, - inputs=reason.inputs, - actions=reason.actions, - display_in_ui=reason.display_in_ui, - form_token=reason.form_token, - resolved_default_values=reason.resolved_default_values, - expiration_time=int(expiration_time.timestamp()), - ), - ) - ) - - responses.append( - WorkflowPauseStreamResponse( - task_id=task_id, - workflow_run_id=run_id, - data=WorkflowPauseStreamResponse.Data( - workflow_run_id=run_id, - paused_nodes=list(event.paused_nodes), - outputs=encoded_outputs, - reasons=pause_reasons, - status=WorkflowExecutionStatus.PAUSED.value, - created_at=int(started_at.timestamp()), - elapsed_time=elapsed_time, - total_tokens=graph_runtime_state.total_tokens, - total_steps=graph_runtime_state.node_run_steps, - ), - ) - ) - - return responses - - def human_input_form_filled_to_stream_response( - self, *, event: QueueHumanInputFormFilledEvent, task_id: str - ) -> HumanInputFormFilledResponse: - run_id = self._ensure_workflow_run_id() - return HumanInputFormFilledResponse( - task_id=task_id, - workflow_run_id=run_id, - data=HumanInputFormFilledResponse.Data( - node_id=event.node_id, - node_title=event.node_title, - rendered_content=event.rendered_content, - action_id=event.action_id, - action_text=event.action_text, - ), - ) - - def human_input_form_timeout_to_stream_response( - self, *, event: QueueHumanInputFormTimeoutEvent, task_id: str - ) -> HumanInputFormTimeoutResponse: - run_id = self._ensure_workflow_run_id() - return HumanInputFormTimeoutResponse( - task_id=task_id, - workflow_run_id=run_id, - data=HumanInputFormTimeoutResponse.Data( - node_id=event.node_id, - node_title=event.node_title, - expiration_time=int(event.expiration_time.timestamp()), - ), - ) - - @classmethod - def workflow_run_result_to_finish_response( - cls, - *, - task_id: str, - workflow_run: WorkflowRun, - creator_user: Account | EndUser, - ) -> WorkflowFinishStreamResponse: - run_id = workflow_run.id - elapsed_time = workflow_run.elapsed_time - - encoded_outputs = workflow_run.outputs_dict - finished_at = workflow_run.finished_at - assert finished_at is not None - - created_by: Mapping[str, object] - user = creator_user - if isinstance(user, Account): - created_by = { - "id": user.id, - "name": user.name, - "email": user.email, - } - else: - created_by = { - "id": user.id, - "user": user.session_id, - } - - return WorkflowFinishStreamResponse( - task_id=task_id, - workflow_run_id=run_id, - data=WorkflowFinishStreamResponse.Data( - id=run_id, - workflow_id=workflow_run.workflow_id, - status=workflow_run.status.value, - outputs=encoded_outputs, - error=workflow_run.error, - elapsed_time=elapsed_time, - total_tokens=workflow_run.total_tokens, - total_steps=workflow_run.total_steps, - created_by=created_by, - created_at=int(workflow_run.created_at.timestamp()), - finished_at=int(finished_at.timestamp()), - files=cls.fetch_files_from_node_outputs(encoded_outputs), - exceptions_count=workflow_run.exceptions_count, - ), - ) - def workflow_node_start_to_stream_response( self, *, @@ -512,13 +340,13 @@ class WorkflowResponseConverter: metadata = self._merge_metadata(event.execution_metadata, snapshot) if isinstance(event, QueueNodeSucceededEvent): - status = WorkflowNodeExecutionStatus.SUCCEEDED.value + status = WorkflowNodeExecutionStatus.SUCCEEDED error_message = event.error elif isinstance(event, QueueNodeFailedEvent): - status = WorkflowNodeExecutionStatus.FAILED.value + status = WorkflowNodeExecutionStatus.FAILED error_message = event.error else: - status = WorkflowNodeExecutionStatus.EXCEPTION.value + status = WorkflowNodeExecutionStatus.EXCEPTION error_message = event.error return NodeFinishStreamResponse( @@ -585,7 +413,7 @@ class WorkflowResponseConverter: process_data_truncated=process_data_truncated, outputs=outputs, outputs_truncated=outputs_truncated, - status=WorkflowNodeExecutionStatus.RETRY.value, + status=WorkflowNodeExecutionStatus.RETRY, error=event.error, elapsed_time=elapsed_time, execution_metadata=metadata, @@ -764,8 +592,7 @@ class WorkflowResponseConverter: ), ) - @classmethod - def fetch_files_from_node_outputs(cls, outputs_dict: Mapping[str, Any] | None) -> Sequence[Mapping[str, Any]]: + def fetch_files_from_node_outputs(self, outputs_dict: Mapping[str, Any] | None) -> Sequence[Mapping[str, Any]]: """ Fetch files from node outputs :param outputs_dict: node outputs dict @@ -774,7 +601,7 @@ class WorkflowResponseConverter: if not outputs_dict: return [] - files = [cls._fetch_files_from_variable_value(output_value) for output_value in outputs_dict.values()] + files = [self._fetch_files_from_variable_value(output_value) for output_value in outputs_dict.values()] # Remove None files = [file for file in files if file] # Flatten list diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index 4e9a191dae..57617d8863 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -1,6 +1,6 @@ import json import logging -from collections.abc import Callable, Generator, Mapping +from collections.abc import Generator from typing import Union, cast from sqlalchemy import select @@ -10,14 +10,12 @@ from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppMod from core.app.apps.base_app_generator import BaseAppGenerator from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.exc import GenerateTaskStoppedError -from core.app.apps.streaming_utils import stream_topic_events from core.app.entities.app_invoke_entities import ( AdvancedChatAppGenerateEntity, AgentChatAppGenerateEntity, AppGenerateEntity, ChatAppGenerateEntity, CompletionAppGenerateEntity, - ConversationAppGenerateEntity, InvokeFrom, ) from core.app.entities.task_entities import ( @@ -29,8 +27,6 @@ from core.app.entities.task_entities import ( from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline from core.prompt.utils.prompt_template_parser import PromptTemplateParser from extensions.ext_database import db -from extensions.ext_redis import get_pubsub_broadcast_channel -from libs.broadcast_channel.channel import Topic from libs.datetime_utils import naive_utc_now from models import Account from models.enums import CreatorUserRole @@ -160,7 +156,6 @@ class MessageBasedAppGenerator(BaseAppGenerator): query = application_generate_entity.query or "New conversation" conversation_name = (query[:20] + "…") if len(query) > 20 else query - created_new_conversation = conversation is None try: if not conversation: conversation = Conversation( @@ -237,10 +232,6 @@ class MessageBasedAppGenerator(BaseAppGenerator): db.session.add_all(message_files) db.session.commit() - - if isinstance(application_generate_entity, ConversationAppGenerateEntity): - application_generate_entity.conversation_id = conversation.id - application_generate_entity.is_new_conversation = created_new_conversation return conversation, message except Exception: db.session.rollback() @@ -293,29 +284,3 @@ class MessageBasedAppGenerator(BaseAppGenerator): raise MessageNotExistsError("Message not exists") return message - - @staticmethod - def _make_channel_key(app_mode: AppMode, workflow_run_id: str): - return f"channel:{app_mode}:{workflow_run_id}" - - @classmethod - def get_response_topic(cls, app_mode: AppMode, workflow_run_id: str) -> Topic: - key = cls._make_channel_key(app_mode, workflow_run_id) - channel = get_pubsub_broadcast_channel() - topic = channel.topic(key) - return topic - - @classmethod - def retrieve_events( - cls, - app_mode: AppMode, - workflow_run_id: str, - idle_timeout=300, - on_subscribe: Callable[[], None] | None = None, - ) -> Generator[Mapping | str, None, None]: - topic = cls.get_response_topic(app_mode, workflow_run_id) - return stream_topic_events( - topic=topic, - idle_timeout=idle_timeout, - on_subscribe=on_subscribe, - ) diff --git a/api/core/app/apps/message_generator.py b/api/core/app/apps/message_generator.py deleted file mode 100644 index 68631bb230..0000000000 --- a/api/core/app/apps/message_generator.py +++ /dev/null @@ -1,36 +0,0 @@ -from collections.abc import Callable, Generator, Mapping - -from core.app.apps.streaming_utils import stream_topic_events -from extensions.ext_redis import get_pubsub_broadcast_channel -from libs.broadcast_channel.channel import Topic -from models.model import AppMode - - -class MessageGenerator: - @staticmethod - def _make_channel_key(app_mode: AppMode, workflow_run_id: str): - return f"channel:{app_mode}:{str(workflow_run_id)}" - - @classmethod - def get_response_topic(cls, app_mode: AppMode, workflow_run_id: str) -> Topic: - key = cls._make_channel_key(app_mode, workflow_run_id) - channel = get_pubsub_broadcast_channel() - topic = channel.topic(key) - return topic - - @classmethod - def retrieve_events( - cls, - app_mode: AppMode, - workflow_run_id: str, - idle_timeout=300, - ping_interval: float = 10.0, - on_subscribe: Callable[[], None] | None = None, - ) -> Generator[Mapping | str, None, None]: - topic = cls.get_response_topic(app_mode, workflow_run_id) - return stream_topic_events( - topic=topic, - idle_timeout=idle_timeout, - ping_interval=ping_interval, - on_subscribe=on_subscribe, - ) diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index ea4441b5d8..eca96cb074 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -120,7 +120,7 @@ class PipelineGenerator(BaseAppGenerator): raise ValueError("Pipeline dataset is required") inputs: Mapping[str, Any] = args["inputs"] start_node_id: str = args["start_node_id"] - datasource_type: str = args["datasource_type"] + datasource_type = DatasourceProviderType(args["datasource_type"]) datasource_info_list: list[Mapping[str, Any]] = self._format_datasource_info_list( datasource_type, args["datasource_info_list"], pipeline, workflow, start_node_id, user ) @@ -660,7 +660,7 @@ class PipelineGenerator(BaseAppGenerator): tenant_id: str, dataset_id: str, built_in_field_enabled: bool, - datasource_type: str, + datasource_type: DatasourceProviderType, datasource_info: Mapping[str, Any], created_from: str, position: int, @@ -668,17 +668,17 @@ class PipelineGenerator(BaseAppGenerator): batch: str, document_form: str, ): - if datasource_type == "local_file": - name = datasource_info.get("name", "untitled") - elif datasource_type == "online_document": - name = datasource_info.get("page", {}).get("page_name", "untitled") - elif datasource_type == "website_crawl": - name = datasource_info.get("title", "untitled") - elif datasource_type == "online_drive": - name = datasource_info.get("name", "untitled") - else: - raise ValueError(f"Unsupported datasource type: {datasource_type}") - + match datasource_type: + case DatasourceProviderType.LOCAL_FILE: + name = datasource_info.get("name", "untitled") + case DatasourceProviderType.ONLINE_DOCUMENT: + name = datasource_info.get("page", {}).get("page_name", "untitled") + case DatasourceProviderType.WEBSITE_CRAWL: + name = datasource_info.get("title", "untitled") + case DatasourceProviderType.ONLINE_DRIVE: + name = datasource_info.get("name", "untitled") + case _: + raise ValueError(f"Unsupported datasource type: {datasource_type}") document = Document( tenant_id=tenant_id, dataset_id=dataset_id, @@ -706,7 +706,7 @@ class PipelineGenerator(BaseAppGenerator): def _format_datasource_info_list( self, - datasource_type: str, + datasource_type: DatasourceProviderType, datasource_info_list: list[Mapping[str, Any]], pipeline: Pipeline, workflow: Workflow, @@ -716,7 +716,7 @@ class PipelineGenerator(BaseAppGenerator): """ Format datasource info list. """ - if datasource_type == "online_drive": + if datasource_type == DatasourceProviderType.ONLINE_DRIVE: all_files: list[Mapping[str, Any]] = [] datasource_node_data = None datasource_nodes = workflow.graph_dict.get("nodes", []) diff --git a/api/core/app/apps/streaming_utils.py b/api/core/app/apps/streaming_utils.py deleted file mode 100644 index 57d4b537a4..0000000000 --- a/api/core/app/apps/streaming_utils.py +++ /dev/null @@ -1,70 +0,0 @@ -from __future__ import annotations - -import json -import time -from collections.abc import Callable, Generator, Iterable, Mapping -from typing import Any - -from core.app.entities.task_entities import StreamEvent -from libs.broadcast_channel.channel import Topic -from libs.broadcast_channel.exc import SubscriptionClosedError - - -def stream_topic_events( - *, - topic: Topic, - idle_timeout: float, - ping_interval: float | None = None, - on_subscribe: Callable[[], None] | None = None, - terminal_events: Iterable[str | StreamEvent] | None = None, -) -> Generator[Mapping[str, Any] | str, None, None]: - # send a PING event immediately to prevent the connection staying in pending state for a long time. - # - # This simplify the debugging process as the DevTools in Chrome does not - # provide complete curl command for pending connections. - yield StreamEvent.PING.value - - terminal_values = _normalize_terminal_events(terminal_events) - last_msg_time = time.time() - last_ping_time = last_msg_time - with topic.subscribe() as sub: - # on_subscribe fires only after the Redis subscription is active. - # This is used to gate task start and reduce pub/sub race for the first event. - if on_subscribe is not None: - on_subscribe() - while True: - try: - msg = sub.receive(timeout=0.1) - except SubscriptionClosedError: - return - if msg is None: - current_time = time.time() - if current_time - last_msg_time > idle_timeout: - return - if ping_interval is not None and current_time - last_ping_time >= ping_interval: - yield StreamEvent.PING.value - last_ping_time = current_time - continue - - last_msg_time = time.time() - last_ping_time = last_msg_time - event = json.loads(msg) - yield event - if not isinstance(event, dict): - continue - - event_type = event.get("event") - if event_type in terminal_values: - return - - -def _normalize_terminal_events(terminal_events: Iterable[str | StreamEvent] | None) -> set[str]: - if not terminal_events: - return {StreamEvent.WORKFLOW_FINISHED.value, StreamEvent.WORKFLOW_PAUSED.value} - values: set[str] = set() - for item in terminal_events: - if isinstance(item, StreamEvent): - values.add(item.value) - else: - values.add(str(item)) - return values diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index dc5852d552..ee205ed153 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -25,7 +25,6 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse -from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer from core.db.session_factory import session_factory from core.helper.trace_id_helper import extract_external_trace_id_from_args from core.model_runtime.errors.invoke import InvokeAuthorizationError @@ -35,15 +34,12 @@ from core.workflow.graph_engine.layers.base import GraphEngineLayer from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from core.workflow.runtime import GraphRuntimeState from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from extensions.ext_database import db from factories import file_factory from libs.flask_utils import preserve_flask_contexts -from models.account import Account +from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom from models.enums import WorkflowRunTriggeredFrom -from models.model import App, EndUser -from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService if TYPE_CHECKING: @@ -70,11 +66,9 @@ class WorkflowAppGenerator(BaseAppGenerator): invoke_from: InvokeFrom, streaming: Literal[True], call_depth: int, - workflow_run_id: str | uuid.UUID | None = None, triggered_from: WorkflowRunTriggeredFrom | None = None, root_node_id: str | None = None, graph_engine_layers: Sequence[GraphEngineLayer] = (), - pause_state_config: PauseStateLayerConfig | None = None, ) -> Generator[Mapping[str, Any] | str, None, None]: ... @overload @@ -88,11 +82,9 @@ class WorkflowAppGenerator(BaseAppGenerator): invoke_from: InvokeFrom, streaming: Literal[False], call_depth: int, - workflow_run_id: str | uuid.UUID | None = None, triggered_from: WorkflowRunTriggeredFrom | None = None, root_node_id: str | None = None, graph_engine_layers: Sequence[GraphEngineLayer] = (), - pause_state_config: PauseStateLayerConfig | None = None, ) -> Mapping[str, Any]: ... @overload @@ -106,11 +98,9 @@ class WorkflowAppGenerator(BaseAppGenerator): invoke_from: InvokeFrom, streaming: bool, call_depth: int, - workflow_run_id: str | uuid.UUID | None = None, triggered_from: WorkflowRunTriggeredFrom | None = None, root_node_id: str | None = None, graph_engine_layers: Sequence[GraphEngineLayer] = (), - pause_state_config: PauseStateLayerConfig | None = None, ) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: ... def generate( @@ -123,11 +113,9 @@ class WorkflowAppGenerator(BaseAppGenerator): invoke_from: InvokeFrom, streaming: bool = True, call_depth: int = 0, - workflow_run_id: str | uuid.UUID | None = None, triggered_from: WorkflowRunTriggeredFrom | None = None, root_node_id: str | None = None, graph_engine_layers: Sequence[GraphEngineLayer] = (), - pause_state_config: PauseStateLayerConfig | None = None, ) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: files: Sequence[Mapping[str, Any]] = args.get("files") or [] @@ -162,7 +150,7 @@ class WorkflowAppGenerator(BaseAppGenerator): extras = { **extract_external_trace_id_from_args(args), } - workflow_run_id = str(workflow_run_id or uuid.uuid4()) + workflow_run_id = str(uuid.uuid4()) # FIXME (Yeuoly): we need to remove the SKIP_PREPARE_USER_INPUTS_KEY from the args # trigger shouldn't prepare user inputs if self._should_prepare_user_inputs(args): @@ -228,40 +216,13 @@ class WorkflowAppGenerator(BaseAppGenerator): streaming=streaming, root_node_id=root_node_id, graph_engine_layers=graph_engine_layers, - pause_state_config=pause_state_config, ) - def resume( - self, - *, - app_model: App, - workflow: Workflow, - user: Union[Account, EndUser], - application_generate_entity: WorkflowAppGenerateEntity, - graph_runtime_state: GraphRuntimeState, - workflow_execution_repository: WorkflowExecutionRepository, - workflow_node_execution_repository: WorkflowNodeExecutionRepository, - graph_engine_layers: Sequence[GraphEngineLayer] = (), - pause_state_config: PauseStateLayerConfig | None = None, - variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER, - ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: + def resume(self, *, workflow_run_id: str) -> None: """ - Resume a paused workflow execution using the persisted runtime state. + @TBD """ - return self._generate( - app_model=app_model, - workflow=workflow, - user=user, - application_generate_entity=application_generate_entity, - invoke_from=application_generate_entity.invoke_from, - workflow_execution_repository=workflow_execution_repository, - workflow_node_execution_repository=workflow_node_execution_repository, - streaming=application_generate_entity.stream, - variable_loader=variable_loader, - graph_engine_layers=graph_engine_layers, - graph_runtime_state=graph_runtime_state, - pause_state_config=pause_state_config, - ) + pass def _generate( self, @@ -277,8 +238,6 @@ class WorkflowAppGenerator(BaseAppGenerator): variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER, root_node_id: str | None = None, graph_engine_layers: Sequence[GraphEngineLayer] = (), - graph_runtime_state: GraphRuntimeState | None = None, - pause_state_config: PauseStateLayerConfig | None = None, ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: """ Generate App response. @@ -292,8 +251,6 @@ class WorkflowAppGenerator(BaseAppGenerator): :param workflow_node_execution_repository: repository for workflow node execution :param streaming: is stream """ - graph_layers: list[GraphEngineLayer] = list(graph_engine_layers) - # init queue manager queue_manager = WorkflowAppQueueManager( task_id=application_generate_entity.task_id, @@ -302,15 +259,6 @@ class WorkflowAppGenerator(BaseAppGenerator): app_mode=app_model.mode, ) - if pause_state_config is not None: - graph_layers.append( - PauseStatePersistenceLayer( - session_factory=pause_state_config.session_factory, - generate_entity=application_generate_entity, - state_owner_user_id=pause_state_config.state_owner_user_id, - ) - ) - # new thread with request context and contextvars context = contextvars.copy_context() @@ -328,8 +276,7 @@ class WorkflowAppGenerator(BaseAppGenerator): "root_node_id": root_node_id, "workflow_execution_repository": workflow_execution_repository, "workflow_node_execution_repository": workflow_node_execution_repository, - "graph_engine_layers": tuple(graph_layers), - "graph_runtime_state": graph_runtime_state, + "graph_engine_layers": graph_engine_layers, }, ) @@ -431,7 +378,6 @@ class WorkflowAppGenerator(BaseAppGenerator): workflow_node_execution_repository=workflow_node_execution_repository, streaming=streaming, variable_loader=var_loader, - pause_state_config=None, ) def single_loop_generate( @@ -513,7 +459,6 @@ class WorkflowAppGenerator(BaseAppGenerator): workflow_node_execution_repository=workflow_node_execution_repository, streaming=streaming, variable_loader=var_loader, - pause_state_config=None, ) def _generate_worker( @@ -527,7 +472,6 @@ class WorkflowAppGenerator(BaseAppGenerator): workflow_node_execution_repository: WorkflowNodeExecutionRepository, root_node_id: str | None = None, graph_engine_layers: Sequence[GraphEngineLayer] = (), - graph_runtime_state: GraphRuntimeState | None = None, ) -> None: """ Generate worker in a new thread. @@ -573,7 +517,6 @@ class WorkflowAppGenerator(BaseAppGenerator): workflow_node_execution_repository=workflow_node_execution_repository, root_node_id=root_node_id, graph_engine_layers=graph_engine_layers, - graph_runtime_state=graph_runtime_state, ) try: diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index a43f7879d6..0ee3c177f2 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -42,7 +42,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): workflow_execution_repository: WorkflowExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository, graph_engine_layers: Sequence[GraphEngineLayer] = (), - graph_runtime_state: GraphRuntimeState | None = None, ): super().__init__( queue_manager=queue_manager, @@ -56,7 +55,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): self._root_node_id = root_node_id self._workflow_execution_repository = workflow_execution_repository self._workflow_node_execution_repository = workflow_node_execution_repository - self._resume_graph_runtime_state = graph_runtime_state @trace_span(WorkflowAppRunnerHandler) def run(self): @@ -65,28 +63,23 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): """ app_config = self.application_generate_entity.app_config app_config = cast(WorkflowAppConfig, app_config) + + system_inputs = SystemVariable( + files=self.application_generate_entity.files, + user_id=self._sys_user_id, + app_id=app_config.app_id, + timestamp=int(naive_utc_now().timestamp()), + workflow_id=app_config.workflow_id, + workflow_execution_id=self.application_generate_entity.workflow_execution_id, + ) + invoke_from = self.application_generate_entity.invoke_from # if only single iteration or single loop run is requested if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run: invoke_from = InvokeFrom.DEBUGGER user_from = self._resolve_user_from(invoke_from) - resume_state = self._resume_graph_runtime_state - - if resume_state is not None: - graph_runtime_state = resume_state - variable_pool = graph_runtime_state.variable_pool - graph = self._init_graph( - graph_config=self._workflow.graph_dict, - graph_runtime_state=graph_runtime_state, - workflow_id=self._workflow.id, - tenant_id=self._workflow.tenant_id, - user_id=self.application_generate_entity.user_id, - user_from=user_from, - invoke_from=invoke_from, - root_node_id=self._root_node_id, - ) - elif self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run: + if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run: graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution( workflow=self._workflow, single_iteration_run=self.application_generate_entity.single_iteration_run, @@ -96,14 +89,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): inputs = self.application_generate_entity.inputs # Create a variable pool. - system_inputs = SystemVariable( - files=self.application_generate_entity.files, - user_id=self._sys_user_id, - app_id=app_config.app_id, - timestamp=int(naive_utc_now().timestamp()), - workflow_id=app_config.workflow_id, - workflow_execution_id=self.application_generate_entity.workflow_execution_id, - ) + variable_pool = VariablePool( system_variables=system_inputs, user_inputs=inputs, @@ -112,6 +98,8 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): ) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + + # init graph graph = self._init_graph( graph_config=self._workflow.graph_dict, graph_runtime_state=graph_runtime_state, diff --git a/api/core/app/apps/workflow/errors.py b/api/core/app/apps/workflow/errors.py deleted file mode 100644 index 16cd864209..0000000000 --- a/api/core/app/apps/workflow/errors.py +++ /dev/null @@ -1,7 +0,0 @@ -from libs.exception import BaseHTTPException - - -class WorkflowPausedInBlockingModeError(BaseHTTPException): - error_code = "workflow_paused_in_blocking_mode" - description = "Workflow execution paused for human input; blocking response mode is not supported." - code = 400 diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 0a567a4315..842ad545ad 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -16,8 +16,6 @@ from core.app.entities.queue_entities import ( MessageQueueMessage, QueueAgentLogEvent, QueueErrorEvent, - QueueHumanInputFormFilledEvent, - QueueHumanInputFormTimeoutEvent, QueueIterationCompletedEvent, QueueIterationNextEvent, QueueIterationStartEvent, @@ -34,7 +32,6 @@ from core.app.entities.queue_entities import ( QueueTextChunkEvent, QueueWorkflowFailedEvent, QueueWorkflowPartialSuccessEvent, - QueueWorkflowPausedEvent, QueueWorkflowStartedEvent, QueueWorkflowSucceededEvent, WorkflowQueueMessage, @@ -49,13 +46,11 @@ from core.app.entities.task_entities import ( WorkflowAppBlockingResponse, WorkflowAppStreamResponse, WorkflowFinishStreamResponse, - WorkflowPauseStreamResponse, WorkflowStartStreamResponse, ) from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.ops.ops_trace_manager import TraceQueueManager -from core.workflow.entities.workflow_start_reason import WorkflowStartReason from core.workflow.enums import WorkflowExecutionStatus from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory from core.workflow.runtime import GraphRuntimeState @@ -137,25 +132,6 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): for stream_response in generator: if isinstance(stream_response, ErrorStreamResponse): raise stream_response.err - elif isinstance(stream_response, WorkflowPauseStreamResponse): - response = WorkflowAppBlockingResponse( - task_id=self._application_generate_entity.task_id, - workflow_run_id=stream_response.data.workflow_run_id, - data=WorkflowAppBlockingResponse.Data( - id=stream_response.data.workflow_run_id, - workflow_id=self._workflow.id, - status=stream_response.data.status, - outputs=stream_response.data.outputs or {}, - error=None, - elapsed_time=stream_response.data.elapsed_time, - total_tokens=stream_response.data.total_tokens, - total_steps=stream_response.data.total_steps, - created_at=stream_response.data.created_at, - finished_at=None, - ), - ) - - return response elif isinstance(stream_response, WorkflowFinishStreamResponse): response = WorkflowAppBlockingResponse( task_id=self._application_generate_entity.task_id, @@ -170,7 +146,7 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): total_tokens=stream_response.data.total_tokens, total_steps=stream_response.data.total_steps, created_at=int(stream_response.data.created_at), - finished_at=int(stream_response.data.finished_at) if stream_response.data.finished_at else None, + finished_at=int(stream_response.data.finished_at), ), ) @@ -283,15 +259,13 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): run_id = self._extract_workflow_run_id(runtime_state) self._workflow_execution_id = run_id - if event.reason == WorkflowStartReason.INITIAL: - with self._database_session() as session: - self._save_workflow_app_log(session=session, workflow_run_id=self._workflow_execution_id) + with self._database_session() as session: + self._save_workflow_app_log(session=session, workflow_run_id=self._workflow_execution_id) start_resp = self._workflow_response_converter.workflow_start_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_run_id=run_id, workflow_id=self._workflow.id, - reason=event.reason, ) yield start_resp @@ -466,21 +440,6 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): ) yield workflow_finish_resp - def _handle_workflow_paused_event( - self, - event: QueueWorkflowPausedEvent, - **kwargs, - ) -> Generator[StreamResponse, None, None]: - """Handle workflow paused events.""" - self._ensure_workflow_initialized() - validated_state = self._ensure_graph_runtime_initialized() - responses = self._workflow_response_converter.workflow_pause_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - graph_runtime_state=validated_state, - ) - yield from responses - def _handle_workflow_failed_and_stop_events( self, event: Union[QueueWorkflowFailedEvent, QueueStopEvent], @@ -536,22 +495,6 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): task_id=self._application_generate_entity.task_id, event=event ) - def _handle_human_input_form_filled_event( - self, event: QueueHumanInputFormFilledEvent, **kwargs - ) -> Generator[StreamResponse, None, None]: - """Handle human input form filled events.""" - yield self._workflow_response_converter.human_input_form_filled_to_stream_response( - event=event, task_id=self._application_generate_entity.task_id - ) - - def _handle_human_input_form_timeout_event( - self, event: QueueHumanInputFormTimeoutEvent, **kwargs - ) -> Generator[StreamResponse, None, None]: - """Handle human input form timeout events.""" - yield self._workflow_response_converter.human_input_form_timeout_to_stream_response( - event=event, task_id=self._application_generate_entity.task_id - ) - def _get_event_handlers(self) -> dict[type, Callable]: """Get mapping of event types to their handlers using fluent pattern.""" return { @@ -563,7 +506,6 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): QueueWorkflowStartedEvent: self._handle_workflow_started_event, QueueWorkflowSucceededEvent: self._handle_workflow_succeeded_event, QueueWorkflowPartialSuccessEvent: self._handle_workflow_partial_success_event, - QueueWorkflowPausedEvent: self._handle_workflow_paused_event, # Node events QueueNodeRetryEvent: self._handle_node_retry_event, QueueNodeStartedEvent: self._handle_node_started_event, @@ -578,8 +520,6 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): QueueLoopCompletedEvent: self._handle_loop_completed_event, # Agent events QueueAgentLogEvent: self._handle_agent_log_event, - QueueHumanInputFormFilledEvent: self._handle_human_input_form_filled_event, - QueueHumanInputFormTimeoutEvent: self._handle_human_input_form_timeout_event, } def _dispatch_event( @@ -662,9 +602,6 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): case QueueWorkflowFailedEvent(): yield from self._handle_workflow_failed_and_stop_events(event) break - case QueueWorkflowPausedEvent(): - yield from self._handle_workflow_paused_event(event) - break case QueueStopEvent(): yield from self._handle_workflow_failed_and_stop_events(event) diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index c9d7464c17..13b7865f55 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -1,4 +1,3 @@ -import logging import time from collections.abc import Mapping, Sequence from typing import Any, cast @@ -8,8 +7,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import ( AppQueueEvent, QueueAgentLogEvent, - QueueHumanInputFormFilledEvent, - QueueHumanInputFormTimeoutEvent, QueueIterationCompletedEvent, QueueIterationNextEvent, QueueIterationStartEvent, @@ -25,27 +22,22 @@ from core.app.entities.queue_entities import ( QueueTextChunkEvent, QueueWorkflowFailedEvent, QueueWorkflowPartialSuccessEvent, - QueueWorkflowPausedEvent, QueueWorkflowStartedEvent, QueueWorkflowSucceededEvent, ) from core.app.workflow.node_factory import DifyNodeFactory from core.workflow.entities import GraphInitParams -from core.workflow.entities.pause_reason import HumanInputRequired from core.workflow.graph import Graph from core.workflow.graph_engine.layers.base import GraphEngineLayer from core.workflow.graph_events import ( GraphEngineEvent, GraphRunFailedEvent, GraphRunPartialSucceededEvent, - GraphRunPausedEvent, GraphRunStartedEvent, GraphRunSucceededEvent, NodeRunAgentLogEvent, NodeRunExceptionEvent, NodeRunFailedEvent, - NodeRunHumanInputFormFilledEvent, - NodeRunHumanInputFormTimeoutEvent, NodeRunIterationFailedEvent, NodeRunIterationNextEvent, NodeRunIterationStartedEvent, @@ -69,9 +61,6 @@ from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, from core.workflow.workflow_entry import WorkflowEntry from models.enums import UserFrom from models.workflow import Workflow -from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task - -logger = logging.getLogger(__name__) class WorkflowBasedAppRunner: @@ -338,7 +327,7 @@ class WorkflowBasedAppRunner: :param event: event """ if isinstance(event, GraphRunStartedEvent): - self._publish_event(QueueWorkflowStartedEvent(reason=event.reason)) + self._publish_event(QueueWorkflowStartedEvent()) elif isinstance(event, GraphRunSucceededEvent): self._publish_event(QueueWorkflowSucceededEvent(outputs=event.outputs)) elif isinstance(event, GraphRunPartialSucceededEvent): @@ -349,38 +338,6 @@ class WorkflowBasedAppRunner: self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count)) elif isinstance(event, GraphRunAbortedEvent): self._publish_event(QueueWorkflowFailedEvent(error=event.reason or "Unknown error", exceptions_count=0)) - elif isinstance(event, GraphRunPausedEvent): - runtime_state = workflow_entry.graph_engine.graph_runtime_state - paused_nodes = runtime_state.get_paused_nodes() - self._enqueue_human_input_notifications(event.reasons) - self._publish_event( - QueueWorkflowPausedEvent( - reasons=event.reasons, - outputs=event.outputs, - paused_nodes=paused_nodes, - ) - ) - elif isinstance(event, NodeRunHumanInputFormFilledEvent): - self._publish_event( - QueueHumanInputFormFilledEvent( - node_execution_id=event.id, - node_id=event.node_id, - node_type=event.node_type, - node_title=event.node_title, - rendered_content=event.rendered_content, - action_id=event.action_id, - action_text=event.action_text, - ) - ) - elif isinstance(event, NodeRunHumanInputFormTimeoutEvent): - self._publish_event( - QueueHumanInputFormTimeoutEvent( - node_id=event.node_id, - node_type=event.node_type, - node_title=event.node_title, - expiration_time=event.expiration_time, - ) - ) elif isinstance(event, NodeRunRetryEvent): node_run_result = event.node_run_result inputs = node_run_result.inputs @@ -587,19 +544,5 @@ class WorkflowBasedAppRunner: ) ) - def _enqueue_human_input_notifications(self, reasons: Sequence[object]) -> None: - for reason in reasons: - if not isinstance(reason, HumanInputRequired): - continue - if not reason.form_id: - continue - try: - dispatch_human_input_email_task.apply_async( - kwargs={"form_id": reason.form_id, "node_title": reason.node_title}, - queue="mail", - ) - except Exception: # pragma: no cover - defensive logging - logger.exception("Failed to enqueue human input email task for form %s", reason.form_id) - def _publish_event(self, event: AppQueueEvent): self._queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER) diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 0e68e554c8..5bc453420d 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -132,7 +132,7 @@ class AppGenerateEntity(BaseModel): extras: dict[str, Any] = Field(default_factory=dict) # tracing instance - trace_manager: Optional["TraceQueueManager"] = Field(default=None, exclude=True, repr=False) + trace_manager: Optional["TraceQueueManager"] = None class EasyUIBasedAppGenerateEntity(AppGenerateEntity): @@ -156,7 +156,6 @@ class ConversationAppGenerateEntity(AppGenerateEntity): """ conversation_id: str | None = None - is_new_conversation: bool = False parent_message_id: str | None = Field( default=None, description=( diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index 5b2fa29b56..77d6bf03b4 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -8,8 +8,6 @@ from pydantic import BaseModel, ConfigDict, Field from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.workflow.entities import AgentNodeStrategyInit -from core.workflow.entities.pause_reason import PauseReason -from core.workflow.entities.workflow_start_reason import WorkflowStartReason from core.workflow.enums import WorkflowNodeExecutionMetadataKey from core.workflow.nodes import NodeType @@ -48,9 +46,6 @@ class QueueEvent(StrEnum): PING = "ping" STOP = "stop" RETRY = "retry" - PAUSE = "pause" - HUMAN_INPUT_FORM_FILLED = "human_input_form_filled" - HUMAN_INPUT_FORM_TIMEOUT = "human_input_form_timeout" class AppQueueEvent(BaseModel): @@ -266,8 +261,6 @@ class QueueWorkflowStartedEvent(AppQueueEvent): """QueueWorkflowStartedEvent entity.""" event: QueueEvent = QueueEvent.WORKFLOW_STARTED - # Always present; mirrors GraphRunStartedEvent.reason for downstream consumers. - reason: WorkflowStartReason = WorkflowStartReason.INITIAL class QueueWorkflowSucceededEvent(AppQueueEvent): @@ -491,35 +484,6 @@ class QueueStopEvent(AppQueueEvent): return reason_mapping.get(self.stopped_by, "Stopped by unknown reason.") -class QueueHumanInputFormFilledEvent(AppQueueEvent): - """ - QueueHumanInputFormFilledEvent entity - """ - - event: QueueEvent = QueueEvent.HUMAN_INPUT_FORM_FILLED - - node_execution_id: str - node_id: str - node_type: NodeType - node_title: str - rendered_content: str - action_id: str - action_text: str - - -class QueueHumanInputFormTimeoutEvent(AppQueueEvent): - """ - QueueHumanInputFormTimeoutEvent entity - """ - - event: QueueEvent = QueueEvent.HUMAN_INPUT_FORM_TIMEOUT - - node_id: str - node_type: NodeType - node_title: str - expiration_time: datetime - - class QueueMessage(BaseModel): """ QueueMessage abstract entity @@ -545,14 +509,3 @@ class WorkflowQueueMessage(QueueMessage): """ pass - - -class QueueWorkflowPausedEvent(AppQueueEvent): - """ - QueueWorkflowPausedEvent entity - """ - - event: QueueEvent = QueueEvent.PAUSE - reasons: Sequence[PauseReason] = Field(default_factory=list) - outputs: Mapping[str, object] = Field(default_factory=dict) - paused_nodes: Sequence[str] = Field(default_factory=list) diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 3f38904d2f..26fb17ccef 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -7,9 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.workflow.entities import AgentNodeStrategyInit -from core.workflow.entities.workflow_start_reason import WorkflowStartReason -from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from core.workflow.nodes.human_input.entities import FormInput, UserAction +from core.workflow.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus class AnnotationReplyAccount(BaseModel): @@ -71,7 +69,6 @@ class StreamEvent(StrEnum): AGENT_THOUGHT = "agent_thought" AGENT_MESSAGE = "agent_message" WORKFLOW_STARTED = "workflow_started" - WORKFLOW_PAUSED = "workflow_paused" WORKFLOW_FINISHED = "workflow_finished" NODE_STARTED = "node_started" NODE_FINISHED = "node_finished" @@ -85,9 +82,6 @@ class StreamEvent(StrEnum): TEXT_CHUNK = "text_chunk" TEXT_REPLACE = "text_replace" AGENT_LOG = "agent_log" - HUMAN_INPUT_REQUIRED = "human_input_required" - HUMAN_INPUT_FORM_FILLED = "human_input_form_filled" - HUMAN_INPUT_FORM_TIMEOUT = "human_input_form_timeout" class StreamResponse(BaseModel): @@ -211,8 +205,6 @@ class WorkflowStartStreamResponse(StreamResponse): workflow_id: str inputs: Mapping[str, Any] created_at: int - # Always present; mirrors QueueWorkflowStartedEvent.reason for SSE clients. - reason: WorkflowStartReason = WorkflowStartReason.INITIAL event: StreamEvent = StreamEvent.WORKFLOW_STARTED workflow_run_id: str @@ -231,7 +223,7 @@ class WorkflowFinishStreamResponse(StreamResponse): id: str workflow_id: str - status: str + status: WorkflowExecutionStatus outputs: Mapping[str, Any] | None = None error: str | None = None elapsed_time: float @@ -239,7 +231,7 @@ class WorkflowFinishStreamResponse(StreamResponse): total_steps: int created_by: Mapping[str, object] = Field(default_factory=dict) created_at: int - finished_at: int | None + finished_at: int exceptions_count: int | None = 0 files: Sequence[Mapping[str, Any]] | None = [] @@ -248,85 +240,6 @@ class WorkflowFinishStreamResponse(StreamResponse): data: Data -class WorkflowPauseStreamResponse(StreamResponse): - """ - WorkflowPauseStreamResponse entity - """ - - class Data(BaseModel): - """ - Data entity - """ - - workflow_run_id: str - paused_nodes: Sequence[str] = Field(default_factory=list) - outputs: Mapping[str, Any] = Field(default_factory=dict) - reasons: Sequence[Mapping[str, Any]] = Field(default_factory=list) - status: str - created_at: int - elapsed_time: float - total_tokens: int - total_steps: int - - event: StreamEvent = StreamEvent.WORKFLOW_PAUSED - workflow_run_id: str - data: Data - - -class HumanInputRequiredResponse(StreamResponse): - class Data(BaseModel): - """ - Data entity - """ - - form_id: str - node_id: str - node_title: str - form_content: str - inputs: Sequence[FormInput] = Field(default_factory=list) - actions: Sequence[UserAction] = Field(default_factory=list) - display_in_ui: bool = False - form_token: str | None = None - resolved_default_values: Mapping[str, Any] = Field(default_factory=dict) - expiration_time: int = Field(..., description="Unix timestamp in seconds") - - event: StreamEvent = StreamEvent.HUMAN_INPUT_REQUIRED - workflow_run_id: str - data: Data - - -class HumanInputFormFilledResponse(StreamResponse): - class Data(BaseModel): - """ - Data entity - """ - - node_id: str - node_title: str - rendered_content: str - action_id: str - action_text: str - - event: StreamEvent = StreamEvent.HUMAN_INPUT_FORM_FILLED - workflow_run_id: str - data: Data - - -class HumanInputFormTimeoutResponse(StreamResponse): - class Data(BaseModel): - """ - Data entity - """ - - node_id: str - node_title: str - expiration_time: int - - event: StreamEvent = StreamEvent.HUMAN_INPUT_FORM_TIMEOUT - workflow_run_id: str - data: Data - - class NodeStartStreamResponse(StreamResponse): """ NodeStartStreamResponse entity @@ -398,7 +311,7 @@ class NodeFinishStreamResponse(StreamResponse): process_data_truncated: bool = False outputs: Mapping[str, Any] | None = None outputs_truncated: bool = True - status: str + status: WorkflowNodeExecutionStatus error: str | None = None elapsed_time: float execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None @@ -462,7 +375,7 @@ class NodeRetryStreamResponse(StreamResponse): process_data_truncated: bool = False outputs: Mapping[str, Any] | None = None outputs_truncated: bool = False - status: str + status: WorkflowNodeExecutionStatus error: str | None = None elapsed_time: float execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None @@ -806,14 +719,14 @@ class WorkflowAppBlockingResponse(AppBlockingResponse): id: str workflow_id: str - status: str + status: WorkflowExecutionStatus outputs: Mapping[str, Any] | None = None error: str | None = None elapsed_time: float total_tokens: int total_steps: int created_at: int - finished_at: int | None + finished_at: int workflow_run_id: str data: Data diff --git a/api/core/app/features/rate_limiting/rate_limit.py b/api/core/app/features/rate_limiting/rate_limit.py index 2ca1275a8a..565905be0d 100644 --- a/api/core/app/features/rate_limiting/rate_limit.py +++ b/api/core/app/features/rate_limiting/rate_limit.py @@ -1,4 +1,3 @@ -import contextlib import logging import time import uuid @@ -104,14 +103,6 @@ class RateLimit: ) -@contextlib.contextmanager -def rate_limit_context(rate_limit: RateLimit, request_id: str | None): - request_id = rate_limit.enter(request_id) - yield - if request_id is not None: - rate_limit.exit(request_id) - - class RateLimitGenerator: def __init__(self, rate_limit: RateLimit, generator: Generator[str, None, None], request_id: str): self.rate_limit = rate_limit diff --git a/api/core/app/layers/pause_state_persist_layer.py b/api/core/app/layers/pause_state_persist_layer.py index 1c267091a4..bf76ae8178 100644 --- a/api/core/app/layers/pause_state_persist_layer.py +++ b/api/core/app/layers/pause_state_persist_layer.py @@ -1,4 +1,3 @@ -from dataclasses import dataclass from typing import Annotated, Literal, Self, TypeAlias from pydantic import BaseModel, Field @@ -53,14 +52,6 @@ class WorkflowResumptionContext(BaseModel): return self.generate_entity.entity -@dataclass(frozen=True) -class PauseStateLayerConfig: - """Configuration container for instantiating pause persistence layers.""" - - session_factory: Engine | sessionmaker[Session] - state_owner_user_id: str - - class PauseStatePersistenceLayer(GraphEngineLayer): def __init__( self, diff --git a/api/core/app/task_pipeline/message_cycle_manager.py b/api/core/app/task_pipeline/message_cycle_manager.py index d682083f34..2d4ee08daf 100644 --- a/api/core/app/task_pipeline/message_cycle_manager.py +++ b/api/core/app/task_pipeline/message_cycle_manager.py @@ -82,11 +82,10 @@ class MessageCycleManager: if isinstance(self._application_generate_entity, CompletionAppGenerateEntity): return None - is_first_message = self._application_generate_entity.is_new_conversation + is_first_message = self._application_generate_entity.conversation_id is None extras = self._application_generate_entity.extras auto_generate_conversation_name = extras.get("auto_generate_conversation_name", True) - thread: Thread | None = None if auto_generate_conversation_name and is_first_message: # start generate thread # time.sleep not block other logic @@ -102,10 +101,9 @@ class MessageCycleManager: thread.daemon = True thread.start() - if is_first_message: - self._application_generate_entity.is_new_conversation = False + return thread - return thread + return None def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str): with flask_app.app_context(): diff --git a/api/core/app/workflow/node_factory.py b/api/core/app/workflow/node_factory.py index e0a0059a38..a5773bbef8 100644 --- a/api/core/app/workflow/node_factory.py +++ b/api/core/app/workflow/node_factory.py @@ -4,13 +4,14 @@ from typing import TYPE_CHECKING, final from typing_extensions import override from configs import dify_config -from core.file import file_manager -from core.helper import ssrf_proxy +from core.file.file_manager import file_manager from core.helper.code_executor.code_executor import CodeExecutor from core.helper.code_executor.code_node_provider import CodeNodeProvider +from core.helper.ssrf_proxy import ssrf_proxy from core.tools.tool_file_manager import ToolFileManager +from core.workflow.entities.graph_config import NodeConfigDict from core.workflow.enums import NodeType -from core.workflow.graph import NodeFactory +from core.workflow.graph.graph import NodeFactory from core.workflow.nodes.base.node import Node from core.workflow.nodes.code.code_node import CodeNode from core.workflow.nodes.code.limits import CodeNodeLimits @@ -22,7 +23,6 @@ from core.workflow.nodes.template_transform.template_renderer import ( Jinja2TemplateRenderer, ) from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode -from libs.typing import is_str, is_str_dict if TYPE_CHECKING: from core.workflow.entities import GraphInitParams @@ -47,9 +47,9 @@ class DifyNodeFactory(NodeFactory): code_providers: Sequence[type[CodeNodeProvider]] | None = None, code_limits: CodeNodeLimits | None = None, template_renderer: Jinja2TemplateRenderer | None = None, - http_request_http_client: HttpClientProtocol = ssrf_proxy, + http_request_http_client: HttpClientProtocol | None = None, http_request_tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager, - http_request_file_manager: FileManagerProtocol = file_manager, + http_request_file_manager: FileManagerProtocol | None = None, ) -> None: self.graph_init_params = graph_init_params self.graph_runtime_state = graph_runtime_state @@ -68,12 +68,12 @@ class DifyNodeFactory(NodeFactory): max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH, ) self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer() - self._http_request_http_client = http_request_http_client + self._http_request_http_client = http_request_http_client or ssrf_proxy self._http_request_tool_file_manager_factory = http_request_tool_file_manager_factory - self._http_request_file_manager = http_request_file_manager + self._http_request_file_manager = http_request_file_manager or file_manager @override - def create_node(self, node_config: dict[str, object]) -> Node: + def create_node(self, node_config: NodeConfigDict) -> Node: """ Create a Node instance from node configuration data using the traditional mapping. @@ -82,23 +82,14 @@ class DifyNodeFactory(NodeFactory): :raises ValueError: if node type is unknown or configuration is invalid """ # Get node_id from config - node_id = node_config.get("id") - if not is_str(node_id): - raise ValueError("Node config missing id") + node_id = node_config["id"] # Get node type from config - node_data = node_config.get("data", {}) - if not is_str_dict(node_data): - raise ValueError(f"Node {node_id} missing data information") - - node_type_str = node_data.get("type") - if not is_str(node_type_str): - raise ValueError(f"Node {node_id} missing or invalid type information") - + node_data = node_config["data"] try: - node_type = NodeType(node_type_str) + node_type = NodeType(node_data["type"]) except ValueError: - raise ValueError(f"Unknown node type: {node_type_str}") + raise ValueError(f"Unknown node type: {node_data['type']}") # Get node class node_mapping = NODE_TYPE_CLASSES_MAPPING.get(node_type) diff --git a/api/core/entities/execution_extra_content.py b/api/core/entities/execution_extra_content.py deleted file mode 100644 index 46006f4381..0000000000 --- a/api/core/entities/execution_extra_content.py +++ /dev/null @@ -1,54 +0,0 @@ -from __future__ import annotations - -from collections.abc import Mapping, Sequence -from typing import Any, TypeAlias - -from pydantic import BaseModel, ConfigDict, Field - -from core.workflow.nodes.human_input.entities import FormInput, UserAction -from models.execution_extra_content import ExecutionContentType - - -class HumanInputFormDefinition(BaseModel): - model_config = ConfigDict(frozen=True) - - form_id: str - node_id: str - node_title: str - form_content: str - inputs: Sequence[FormInput] = Field(default_factory=list) - actions: Sequence[UserAction] = Field(default_factory=list) - display_in_ui: bool = False - form_token: str | None = None - resolved_default_values: Mapping[str, Any] = Field(default_factory=dict) - expiration_time: int - - -class HumanInputFormSubmissionData(BaseModel): - model_config = ConfigDict(frozen=True) - - node_id: str - node_title: str - rendered_content: str - action_id: str - action_text: str - - -class HumanInputContent(BaseModel): - model_config = ConfigDict(frozen=True) - - workflow_run_id: str - submitted: bool - form_definition: HumanInputFormDefinition | None = None - form_submission_data: HumanInputFormSubmissionData | None = None - type: ExecutionContentType = Field(default=ExecutionContentType.HUMAN_INPUT) - - -ExecutionExtraContentDomainModel: TypeAlias = HumanInputContent - -__all__ = [ - "ExecutionExtraContentDomainModel", - "HumanInputContent", - "HumanInputFormDefinition", - "HumanInputFormSubmissionData", -] diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 8a26b2e91b..e8d41b9387 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -28,8 +28,8 @@ from core.model_runtime.entities.provider_entities import ( ) from core.model_runtime.model_providers.__base.ai_model import AIModel from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from extensions.ext_database import db from libs.datetime_utils import naive_utc_now -from models.engine import db from models.provider import ( LoadBalancingModelConfig, Provider, diff --git a/api/core/file/file_manager.py b/api/core/file/file_manager.py index c0fefef3d0..9945d7c1ab 100644 --- a/api/core/file/file_manager.py +++ b/api/core/file/file_manager.py @@ -168,3 +168,18 @@ def _to_url(f: File, /): return sign_tool_file(tool_file_id=f.related_id, extension=f.extension) else: raise ValueError(f"Unsupported transfer method: {f.transfer_method}") + + +class FileManager: + """ + Adapter exposing file manager helpers behind FileManagerProtocol. + + This is intentionally a thin wrapper over the existing module-level functions so callers can inject it + where a protocol-typed file manager is expected. + """ + + def download(self, f: File, /) -> bytes: + return download(f) + + +file_manager = FileManager() diff --git a/api/core/helper/code_executor/code_node_provider.py b/api/core/helper/code_executor/code_node_provider.py index e93e1e4414..f4cce0b332 100644 --- a/api/core/helper/code_executor/code_node_provider.py +++ b/api/core/helper/code_executor/code_node_provider.py @@ -47,15 +47,16 @@ class CodeNodeProvider(BaseModel, ABC): @classmethod def get_default_config(cls) -> DefaultConfig: - return { - "type": "code", - "config": { - "variables": [ - {"variable": "arg1", "value_selector": []}, - {"variable": "arg2", "value_selector": []}, - ], - "code_language": cls.get_language(), - "code": cls.get_default_code(), - "outputs": {"result": {"type": "string", "children": None}}, - }, + variables: list[VariableConfig] = [ + {"variable": "arg1", "value_selector": []}, + {"variable": "arg2", "value_selector": []}, + ] + outputs: dict[str, OutputConfig] = {"result": {"type": "string", "children": None}} + + config: CodeConfig = { + "variables": variables, + "code_language": cls.get_language(), + "code": cls.get_default_code(), + "outputs": outputs, } + return {"type": "code", "config": config} diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index ddccfbaf45..54068fc28d 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -230,3 +230,41 @@ def delete(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) def head(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response: return make_request("HEAD", url, max_retries=max_retries, **kwargs) + + +class SSRFProxy: + """ + Adapter exposing SSRF-protected HTTP helpers behind HttpClientProtocol. + + This is intentionally a thin wrapper over the existing module-level functions so callers can inject it + where a protocol-typed HTTP client is expected. + """ + + @property + def max_retries_exceeded_error(self) -> type[Exception]: + return max_retries_exceeded_error + + @property + def request_error(self) -> type[Exception]: + return request_error + + def get(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response: + return get(url=url, max_retries=max_retries, **kwargs) + + def head(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response: + return head(url=url, max_retries=max_retries, **kwargs) + + def post(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response: + return post(url=url, max_retries=max_retries, **kwargs) + + def put(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response: + return put(url=url, max_retries=max_retries, **kwargs) + + def delete(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response: + return delete(url=url, max_retries=max_retries, **kwargs) + + def patch(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response: + return patch(url=url, max_retries=max_retries, **kwargs) + + +ssrf_proxy = SSRFProxy() diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index e172e88298..4e3ad7bb75 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -369,77 +369,78 @@ class IndexingRunner: # Generate summary preview summary_index_setting = tmp_processing_rule.get("summary_index_setting") if summary_index_setting and summary_index_setting.get("enable") and preview_texts: - preview_texts = index_processor.generate_summary_preview(tenant_id, preview_texts, summary_index_setting) + preview_texts = index_processor.generate_summary_preview( + tenant_id, preview_texts, summary_index_setting, doc_language + ) return IndexingEstimate(total_segments=total_segments, preview=preview_texts) def _extract( self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict ) -> list[Document]: - # load file - if dataset_document.data_source_type not in {"upload_file", "notion_import", "website_crawl"}: - return [] - data_source_info = dataset_document.data_source_info_dict text_docs = [] - if dataset_document.data_source_type == "upload_file": - if not data_source_info or "upload_file_id" not in data_source_info: - raise ValueError("no upload file found") - stmt = select(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"]) - file_detail = db.session.scalars(stmt).one_or_none() + match dataset_document.data_source_type: + case "upload_file": + if not data_source_info or "upload_file_id" not in data_source_info: + raise ValueError("no upload file found") + stmt = select(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"]) + file_detail = db.session.scalars(stmt).one_or_none() - if file_detail: + if file_detail: + extract_setting = ExtractSetting( + datasource_type=DatasourceType.FILE, + upload_file=file_detail, + document_model=dataset_document.doc_form, + ) + text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) + case "notion_import": + if ( + not data_source_info + or "notion_workspace_id" not in data_source_info + or "notion_page_id" not in data_source_info + ): + raise ValueError("no notion import info found") extract_setting = ExtractSetting( - datasource_type=DatasourceType.FILE, - upload_file=file_detail, + datasource_type=DatasourceType.NOTION, + notion_info=NotionInfo.model_validate( + { + "credential_id": data_source_info.get("credential_id"), + "notion_workspace_id": data_source_info["notion_workspace_id"], + "notion_obj_id": data_source_info["notion_page_id"], + "notion_page_type": data_source_info["type"], + "document": dataset_document, + "tenant_id": dataset_document.tenant_id, + } + ), document_model=dataset_document.doc_form, ) text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) - elif dataset_document.data_source_type == "notion_import": - if ( - not data_source_info - or "notion_workspace_id" not in data_source_info - or "notion_page_id" not in data_source_info - ): - raise ValueError("no notion import info found") - extract_setting = ExtractSetting( - datasource_type=DatasourceType.NOTION, - notion_info=NotionInfo.model_validate( - { - "credential_id": data_source_info.get("credential_id"), - "notion_workspace_id": data_source_info["notion_workspace_id"], - "notion_obj_id": data_source_info["notion_page_id"], - "notion_page_type": data_source_info["type"], - "document": dataset_document, - "tenant_id": dataset_document.tenant_id, - } - ), - document_model=dataset_document.doc_form, - ) - text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) - elif dataset_document.data_source_type == "website_crawl": - if ( - not data_source_info - or "provider" not in data_source_info - or "url" not in data_source_info - or "job_id" not in data_source_info - ): - raise ValueError("no website import info found") - extract_setting = ExtractSetting( - datasource_type=DatasourceType.WEBSITE, - website_info=WebsiteInfo.model_validate( - { - "provider": data_source_info["provider"], - "job_id": data_source_info["job_id"], - "tenant_id": dataset_document.tenant_id, - "url": data_source_info["url"], - "mode": data_source_info["mode"], - "only_main_content": data_source_info["only_main_content"], - } - ), - document_model=dataset_document.doc_form, - ) - text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) + case "website_crawl": + if ( + not data_source_info + or "provider" not in data_source_info + or "url" not in data_source_info + or "job_id" not in data_source_info + ): + raise ValueError("no website import info found") + extract_setting = ExtractSetting( + datasource_type=DatasourceType.WEBSITE, + website_info=WebsiteInfo.model_validate( + { + "provider": data_source_info["provider"], + "job_id": data_source_info["job_id"], + "tenant_id": dataset_document.tenant_id, + "url": data_source_info["url"], + "mode": data_source_info["mode"], + "only_main_content": data_source_info["only_main_content"], + } + ), + document_model=dataset_document.doc_form, + ) + text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) + case _: + return [] # update document status to splitting self._update_document_index_status( document_id=dataset_document.id, diff --git a/api/core/llm_generator/entities.py b/api/core/llm_generator/entities.py new file mode 100644 index 0000000000..3bb8d2c899 --- /dev/null +++ b/api/core/llm_generator/entities.py @@ -0,0 +1,20 @@ +"""Shared payload models for LLM generator helpers and controllers.""" + +from pydantic import BaseModel, Field + +from core.app.app_config.entities import ModelConfig + + +class RuleGeneratePayload(BaseModel): + instruction: str = Field(..., description="Rule generation instruction") + model_config_data: ModelConfig = Field(..., alias="model_config", description="Model configuration") + no_variable: bool = Field(default=False, description="Whether to exclude variables") + + +class RuleCodeGeneratePayload(RuleGeneratePayload): + code_language: str = Field(default="javascript", description="Programming language for code generation") + + +class RuleStructuredOutputPayload(BaseModel): + instruction: str = Field(..., description="Structured output generation instruction") + model_config_data: ModelConfig = Field(..., alias="model_config", description="Model configuration") diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index be1e306d47..5b2c640265 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -6,6 +6,8 @@ from typing import Protocol, cast import json_repair +from core.app.app_config.entities import ModelConfig +from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser from core.llm_generator.prompts import ( @@ -151,19 +153,19 @@ class LLMGenerator: return questions @classmethod - def generate_rule_config(cls, tenant_id: str, instruction: str, model_config: dict, no_variable: bool): + def generate_rule_config(cls, tenant_id: str, args: RuleGeneratePayload): output_parser = RuleConfigGeneratorOutputParser() error = "" error_step = "" rule_config = {"prompt": "", "variables": [], "opening_statement": "", "error": ""} - model_parameters = model_config.get("completion_params", {}) - if no_variable: + model_parameters = args.model_config_data.completion_params + if args.no_variable: prompt_template = PromptTemplateParser(WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE) prompt_generate = prompt_template.format( inputs={ - "TASK_DESCRIPTION": instruction, + "TASK_DESCRIPTION": args.instruction, }, remove_template_variables=False, ) @@ -175,8 +177,8 @@ class LLMGenerator: model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, - provider=model_config.get("provider", ""), - model=model_config.get("name", ""), + provider=args.model_config_data.provider, + model=args.model_config_data.name, ) try: @@ -190,7 +192,7 @@ class LLMGenerator: error = str(e) error_step = "generate rule config" except Exception as e: - logger.exception("Failed to generate rule config, model: %s", model_config.get("name")) + logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name) rule_config["error"] = str(e) rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else "" @@ -209,7 +211,7 @@ class LLMGenerator: # format the prompt_generate_prompt prompt_generate_prompt = prompt_template.format( inputs={ - "TASK_DESCRIPTION": instruction, + "TASK_DESCRIPTION": args.instruction, }, remove_template_variables=False, ) @@ -220,8 +222,8 @@ class LLMGenerator: model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, - provider=model_config.get("provider", ""), - model=model_config.get("name", ""), + provider=args.model_config_data.provider, + model=args.model_config_data.name, ) try: @@ -250,7 +252,7 @@ class LLMGenerator: # the second step to generate the task_parameter and task_statement statement_generate_prompt = statement_template.format( inputs={ - "TASK_DESCRIPTION": instruction, + "TASK_DESCRIPTION": args.instruction, "INPUT_TEXT": prompt_content.message.get_text_content(), }, remove_template_variables=False, @@ -276,7 +278,7 @@ class LLMGenerator: error_step = "generate conversation opener" except Exception as e: - logger.exception("Failed to generate rule config, model: %s", model_config.get("name")) + logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name) rule_config["error"] = str(e) rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else "" @@ -284,16 +286,20 @@ class LLMGenerator: return rule_config @classmethod - def generate_code(cls, tenant_id: str, instruction: str, model_config: dict, code_language: str = "javascript"): - if code_language == "python": + def generate_code( + cls, + tenant_id: str, + args: RuleCodeGeneratePayload, + ): + if args.code_language == "python": prompt_template = PromptTemplateParser(PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE) else: prompt_template = PromptTemplateParser(JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE) prompt = prompt_template.format( inputs={ - "INSTRUCTION": instruction, - "CODE_LANGUAGE": code_language, + "INSTRUCTION": args.instruction, + "CODE_LANGUAGE": args.code_language, }, remove_template_variables=False, ) @@ -302,28 +308,28 @@ class LLMGenerator: model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, - provider=model_config.get("provider", ""), - model=model_config.get("name", ""), + provider=args.model_config_data.provider, + model=args.model_config_data.name, ) prompt_messages = [UserPromptMessage(content=prompt)] - model_parameters = model_config.get("completion_params", {}) + model_parameters = args.model_config_data.completion_params try: response: LLMResult = model_instance.invoke_llm( prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False ) generated_code = response.message.get_text_content() - return {"code": generated_code, "language": code_language, "error": ""} + return {"code": generated_code, "language": args.code_language, "error": ""} except InvokeError as e: error = str(e) - return {"code": "", "language": code_language, "error": f"Failed to generate code. Error: {error}"} + return {"code": "", "language": args.code_language, "error": f"Failed to generate code. Error: {error}"} except Exception as e: logger.exception( - "Failed to invoke LLM model, model: %s, language: %s", model_config.get("name"), code_language + "Failed to invoke LLM model, model: %s, language: %s", args.model_config_data.name, args.code_language ) - return {"code": "", "language": code_language, "error": f"An unexpected error occurred: {str(e)}"} + return {"code": "", "language": args.code_language, "error": f"An unexpected error occurred: {str(e)}"} @classmethod def generate_qa_document(cls, tenant_id: str, query, document_language: str): @@ -353,20 +359,20 @@ class LLMGenerator: return answer.strip() @classmethod - def generate_structured_output(cls, tenant_id: str, instruction: str, model_config: dict): + def generate_structured_output(cls, tenant_id: str, args: RuleStructuredOutputPayload): model_manager = ModelManager() model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, - provider=model_config.get("provider", ""), - model=model_config.get("name", ""), + provider=args.model_config_data.provider, + model=args.model_config_data.name, ) prompt_messages = [ SystemPromptMessage(content=SYSTEM_STRUCTURED_OUTPUT_GENERATE), - UserPromptMessage(content=instruction), + UserPromptMessage(content=args.instruction), ] - model_parameters = model_config.get("model_parameters", {}) + model_parameters = args.model_config_data.completion_params try: response: LLMResult = model_instance.invoke_llm( @@ -390,12 +396,17 @@ class LLMGenerator: error = str(e) return {"output": "", "error": f"Failed to generate JSON Schema. Error: {error}"} except Exception as e: - logger.exception("Failed to invoke LLM model, model: %s", model_config.get("name")) + logger.exception("Failed to invoke LLM model, model: %s", args.model_config_data.name) return {"output": "", "error": f"An unexpected error occurred: {str(e)}"} @staticmethod def instruction_modify_legacy( - tenant_id: str, flow_id: str, current: str, instruction: str, model_config: dict, ideal_output: str | None + tenant_id: str, + flow_id: str, + current: str, + instruction: str, + model_config: ModelConfig, + ideal_output: str | None, ): last_run: Message | None = ( db.session.query(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).first() @@ -434,7 +445,7 @@ class LLMGenerator: node_id: str, current: str, instruction: str, - model_config: dict, + model_config: ModelConfig, ideal_output: str | None, workflow_service: WorkflowServiceInterface, ): @@ -505,7 +516,7 @@ class LLMGenerator: @staticmethod def __instruction_modify_common( tenant_id: str, - model_config: dict, + model_config: ModelConfig, last_run: dict | None, current: str | None, error_message: str | None, @@ -526,8 +537,8 @@ class LLMGenerator: model_instance = ModelManager().get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, - provider=model_config.get("provider", ""), - model=model_config.get("name", ""), + provider=model_config.provider, + model=model_config.name, ) match node_type: case "llm" | "agent": @@ -570,7 +581,5 @@ class LLMGenerator: error = str(e) return {"error": f"Failed to generate code. Error: {error}"} except Exception as e: - logger.exception( - "Failed to invoke LLM model, model: %s", json.dumps(model_config.get("name")), exc_info=True - ) + logger.exception("Failed to invoke LLM model, model: %s", json.dumps(model_config.name), exc_info=True) return {"error": f"An unexpected error occurred: {str(e)}"} diff --git a/api/core/llm_generator/prompts.py b/api/core/llm_generator/prompts.py index d46cf049dd..ee9a016c95 100644 --- a/api/core/llm_generator/prompts.py +++ b/api/core/llm_generator/prompts.py @@ -441,11 +441,13 @@ DEFAULT_GENERATOR_SUMMARY_PROMPT = ( Requirements: 1. Write a concise summary in plain text -2. Use the same language as the input content +2. You must write in {language}. No language other than {language} should be used. 3. Focus on important facts, concepts, and details 4. If images are included, describe their key information 5. Do not use words like "好的", "ok", "I understand", "This text discusses", "The content mentions" 6. Write directly without extra words +7. If there is not enough content to generate a meaningful summary, + return an empty string without any explanation or prompt Output only the summary text. Start summarizing now: diff --git a/api/core/model_runtime/entities/defaults.py b/api/core/model_runtime/entities/defaults.py index 76969fea70..51c9c51257 100644 --- a/api/core/model_runtime/entities/defaults.py +++ b/api/core/model_runtime/entities/defaults.py @@ -88,7 +88,7 @@ PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = { DefaultParameterName.MAX_TOKENS: { "label": { "en_US": "Max Tokens", - "zh_Hans": "最大标记", + "zh_Hans": "最大 Token 数", }, "type": "int", "help": { diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py index 8638ee7d64..bbbdec61d1 100644 --- a/api/core/model_runtime/model_providers/__base/large_language_model.py +++ b/api/core/model_runtime/model_providers/__base/large_language_model.py @@ -92,6 +92,10 @@ def _build_llm_result_from_first_chunk( Build a single `LLMResult` from the first returned chunk. This is used for `stream=False` because the plugin side may still implement the response via a chunked stream. + + Note: + This function always drains the `chunks` iterator after reading the first chunk to ensure any underlying + streaming resources are released (e.g., HTTP connections owned by the plugin runtime). """ content = "" content_list: list[PromptMessageContentUnionTypes] = [] @@ -99,18 +103,25 @@ def _build_llm_result_from_first_chunk( system_fingerprint: str | None = None tools_calls: list[AssistantPromptMessage.ToolCall] = [] - first_chunk = next(chunks, None) - if first_chunk is not None: - if isinstance(first_chunk.delta.message.content, str): - content += first_chunk.delta.message.content - elif isinstance(first_chunk.delta.message.content, list): - content_list.extend(first_chunk.delta.message.content) + try: + first_chunk = next(chunks, None) + if first_chunk is not None: + if isinstance(first_chunk.delta.message.content, str): + content += first_chunk.delta.message.content + elif isinstance(first_chunk.delta.message.content, list): + content_list.extend(first_chunk.delta.message.content) - if first_chunk.delta.message.tool_calls: - _increase_tool_call(first_chunk.delta.message.tool_calls, tools_calls) + if first_chunk.delta.message.tool_calls: + _increase_tool_call(first_chunk.delta.message.tool_calls, tools_calls) - usage = first_chunk.delta.usage or LLMUsage.empty_usage() - system_fingerprint = first_chunk.system_fingerprint + usage = first_chunk.delta.usage or LLMUsage.empty_usage() + system_fingerprint = first_chunk.system_fingerprint + finally: + try: + for _ in chunks: + pass + except Exception: + logger.debug("Failed to drain non-stream plugin chunk iterator.", exc_info=True) return LLMResult( model=model, diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index 549e428f88..84f5bf5512 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -15,7 +15,10 @@ from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker from core.helper.encrypter import batch_decrypt_token, encrypt_token, obfuscated_token -from core.ops.entities.config_entity import OPS_FILE_PATH, TracingProviderEnum +from core.ops.entities.config_entity import ( + OPS_FILE_PATH, + TracingProviderEnum, +) from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, GenerateNameTraceInfo, @@ -28,8 +31,8 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.ops.utils import get_message_data +from extensions.ext_database import db from extensions.ext_storage import storage -from models.engine import db from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig from models.workflow import WorkflowAppLog from tasks.ops_trace_task import process_trace_tasks @@ -466,8 +469,6 @@ class TraceTask: @classmethod def _get_workflow_run_repo(cls): - from repositories.factory import DifyAPIRepositoryFactory - if cls._workflow_run_repo is None: with cls._repo_lock: if cls._workflow_run_repo is None: diff --git a/api/core/ops/utils.py b/api/core/ops/utils.py index a5196d66c0..631e3b77b2 100644 --- a/api/core/ops/utils.py +++ b/api/core/ops/utils.py @@ -5,7 +5,7 @@ from urllib.parse import urlparse from sqlalchemy import select -from models.engine import db +from extensions.ext_database import db from models.model import Message diff --git a/api/core/plugin/backwards_invocation/app.py b/api/core/plugin/backwards_invocation/app.py index ca7b6506f3..32e8ef385c 100644 --- a/api/core/plugin/backwards_invocation/app.py +++ b/api/core/plugin/backwards_invocation/app.py @@ -1,4 +1,3 @@ -import uuid from collections.abc import Generator, Mapping from typing import Union @@ -113,7 +112,6 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): "conversation_id": conversation_id, }, invoke_from=InvokeFrom.SERVICE_API, - workflow_run_id=str(uuid.uuid4()), streaming=stream, ) elif app.mode == AppMode.AGENT_CHAT: diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py index 151a3de7d9..6e76321ea0 100644 --- a/api/core/rag/index_processor/index_processor_base.py +++ b/api/core/rag/index_processor/index_processor_base.py @@ -48,12 +48,22 @@ class BaseIndexProcessor(ABC): @abstractmethod def generate_summary_preview( - self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict + self, + tenant_id: str, + preview_texts: list[PreviewDetail], + summary_index_setting: dict, + doc_language: str | None = None, ) -> list[PreviewDetail]: """ For each segment in preview_texts, generate a summary using LLM and attach it to the segment. The summary can be stored in a new attribute, e.g., summary. This method should be implemented by subclasses. + + Args: + tenant_id: Tenant ID + preview_texts: List of preview details to generate summaries for + summary_index_setting: Summary index configuration + doc_language: Optional document language to ensure summary is generated in the correct language """ raise NotImplementedError diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index ab91e29145..41d7656f8a 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -275,7 +275,11 @@ class ParagraphIndexProcessor(BaseIndexProcessor): raise ValueError("Chunks is not a list") def generate_summary_preview( - self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict + self, + tenant_id: str, + preview_texts: list[PreviewDetail], + summary_index_setting: dict, + doc_language: str | None = None, ) -> list[PreviewDetail]: """ For each segment, concurrently call generate_summary to generate a summary @@ -298,11 +302,15 @@ class ParagraphIndexProcessor(BaseIndexProcessor): if flask_app: # Ensure Flask app context in worker thread with flask_app.app_context(): - summary, _ = self.generate_summary(tenant_id, preview.content, summary_index_setting) + summary, _ = self.generate_summary( + tenant_id, preview.content, summary_index_setting, document_language=doc_language + ) preview.summary = summary else: # Fallback: try without app context (may fail) - summary, _ = self.generate_summary(tenant_id, preview.content, summary_index_setting) + summary, _ = self.generate_summary( + tenant_id, preview.content, summary_index_setting, document_language=doc_language + ) preview.summary = summary # Generate summaries concurrently using ThreadPoolExecutor @@ -356,6 +364,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): text: str, summary_index_setting: dict | None = None, segment_id: str | None = None, + document_language: str | None = None, ) -> tuple[str, LLMUsage]: """ Generate summary for the given text using ModelInstance.invoke_llm and the default or custom summary prompt, @@ -366,6 +375,8 @@ class ParagraphIndexProcessor(BaseIndexProcessor): text: Text content to summarize summary_index_setting: Summary index configuration segment_id: Optional segment ID to fetch attachments from SegmentAttachmentBinding table + document_language: Optional document language (e.g., "Chinese", "English") + to ensure summary is generated in the correct language Returns: Tuple of (summary_content, llm_usage) where llm_usage is LLMUsage object @@ -381,8 +392,22 @@ class ParagraphIndexProcessor(BaseIndexProcessor): raise ValueError("model_name and model_provider_name are required in summary_index_setting") # Import default summary prompt + is_default_prompt = False if not summary_prompt: summary_prompt = DEFAULT_GENERATOR_SUMMARY_PROMPT + is_default_prompt = True + + # Format prompt with document language only for default prompt + # Custom prompts are used as-is to avoid interfering with user-defined templates + # If document_language is provided, use it; otherwise, use "the same language as the input content" + # This is especially important for image-only chunks where text is empty or minimal + if is_default_prompt: + language_for_prompt = document_language or "the same language as the input content" + try: + summary_prompt = summary_prompt.format(language=language_for_prompt) + except KeyError: + # If default prompt doesn't have {language} placeholder, use it as-is + pass provider_manager = ProviderManager() provider_model_bundle = provider_manager.get_provider_model_bundle( diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index 961df2e50c..0ea77405ed 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -358,7 +358,11 @@ class ParentChildIndexProcessor(BaseIndexProcessor): } def generate_summary_preview( - self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict + self, + tenant_id: str, + preview_texts: list[PreviewDetail], + summary_index_setting: dict, + doc_language: str | None = None, ) -> list[PreviewDetail]: """ For each parent chunk in preview_texts, concurrently call generate_summary to generate a summary @@ -389,6 +393,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): tenant_id=tenant_id, text=preview.content, summary_index_setting=summary_index_setting, + document_language=doc_language, ) preview.summary = summary else: @@ -397,6 +402,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): tenant_id=tenant_id, text=preview.content, summary_index_setting=summary_index_setting, + document_language=doc_language, ) preview.summary = summary diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index 272d2ed351..40d9caaa69 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -241,7 +241,11 @@ class QAIndexProcessor(BaseIndexProcessor): } def generate_summary_preview( - self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict + self, + tenant_id: str, + preview_texts: list[PreviewDetail], + summary_index_setting: dict, + doc_language: str | None = None, ) -> list[PreviewDetail]: """ QA model doesn't generate summaries, so this method returns preview_texts unchanged. diff --git a/api/core/repositories/__init__.py b/api/core/repositories/__init__.py index 6f2826f634..d83823d7b9 100644 --- a/api/core/repositories/__init__.py +++ b/api/core/repositories/__init__.py @@ -1,18 +1,19 @@ -"""Repository implementations for data access.""" +""" +Repository implementations for data access. -from __future__ import annotations +This package contains concrete implementations of the repository interfaces +defined in the core.workflow.repository package. +""" -from .celery_workflow_execution_repository import CeleryWorkflowExecutionRepository -from .celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository -from .factory import DifyCoreRepositoryFactory, RepositoryImportError -from .sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository -from .sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository +from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository +from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository +from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError +from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository __all__ = [ "CeleryWorkflowExecutionRepository", "CeleryWorkflowNodeExecutionRepository", "DifyCoreRepositoryFactory", "RepositoryImportError", - "SQLAlchemyWorkflowExecutionRepository", "SQLAlchemyWorkflowNodeExecutionRepository", ] diff --git a/api/core/repositories/human_input_repository.py b/api/core/repositories/human_input_repository.py deleted file mode 100644 index 0e04c56e0e..0000000000 --- a/api/core/repositories/human_input_repository.py +++ /dev/null @@ -1,553 +0,0 @@ -import dataclasses -import json -from collections.abc import Mapping, Sequence -from datetime import datetime -from typing import Any - -from sqlalchemy import Engine, select -from sqlalchemy.orm import Session, selectinload, sessionmaker - -from core.workflow.nodes.human_input.entities import ( - DeliveryChannelConfig, - EmailDeliveryMethod, - EmailRecipients, - ExternalRecipient, - FormDefinition, - HumanInputNodeData, - MemberRecipient, - WebAppDeliveryMethod, -) -from core.workflow.nodes.human_input.enums import ( - DeliveryMethodType, - HumanInputFormKind, - HumanInputFormStatus, -) -from core.workflow.repositories.human_input_form_repository import ( - FormCreateParams, - FormNotFoundError, - HumanInputFormEntity, - HumanInputFormRecipientEntity, -) -from libs.datetime_utils import naive_utc_now -from libs.uuid_utils import uuidv7 -from models.account import Account, TenantAccountJoin -from models.human_input import ( - BackstageRecipientPayload, - ConsoleDeliveryPayload, - ConsoleRecipientPayload, - EmailExternalRecipientPayload, - EmailMemberRecipientPayload, - HumanInputDelivery, - HumanInputForm, - HumanInputFormRecipient, - RecipientType, - StandaloneWebAppRecipientPayload, -) - - -@dataclasses.dataclass(frozen=True) -class _DeliveryAndRecipients: - delivery: HumanInputDelivery - recipients: Sequence[HumanInputFormRecipient] - - -@dataclasses.dataclass(frozen=True) -class _WorkspaceMemberInfo: - user_id: str - email: str - - -class _HumanInputFormRecipientEntityImpl(HumanInputFormRecipientEntity): - def __init__(self, recipient_model: HumanInputFormRecipient): - self._recipient_model = recipient_model - - @property - def id(self) -> str: - return self._recipient_model.id - - @property - def token(self) -> str: - if self._recipient_model.access_token is None: - raise AssertionError(f"access_token should not be None for recipient {self._recipient_model.id}") - return self._recipient_model.access_token - - -class _HumanInputFormEntityImpl(HumanInputFormEntity): - def __init__(self, form_model: HumanInputForm, recipient_models: Sequence[HumanInputFormRecipient]): - self._form_model = form_model - self._recipients = [_HumanInputFormRecipientEntityImpl(recipient) for recipient in recipient_models] - self._web_app_recipient = next( - ( - recipient - for recipient in recipient_models - if recipient.recipient_type == RecipientType.STANDALONE_WEB_APP - ), - None, - ) - self._console_recipient = next( - (recipient for recipient in recipient_models if recipient.recipient_type == RecipientType.CONSOLE), - None, - ) - self._submitted_data: Mapping[str, Any] | None = ( - json.loads(form_model.submitted_data) if form_model.submitted_data is not None else None - ) - - @property - def id(self) -> str: - return self._form_model.id - - @property - def web_app_token(self): - if self._console_recipient is not None: - return self._console_recipient.access_token - if self._web_app_recipient is None: - return None - return self._web_app_recipient.access_token - - @property - def recipients(self) -> list[HumanInputFormRecipientEntity]: - return list(self._recipients) - - @property - def rendered_content(self) -> str: - return self._form_model.rendered_content - - @property - def selected_action_id(self) -> str | None: - return self._form_model.selected_action_id - - @property - def submitted_data(self) -> Mapping[str, Any] | None: - return self._submitted_data - - @property - def submitted(self) -> bool: - return self._form_model.submitted_at is not None - - @property - def status(self) -> HumanInputFormStatus: - return self._form_model.status - - @property - def expiration_time(self) -> datetime: - return self._form_model.expiration_time - - -@dataclasses.dataclass(frozen=True) -class HumanInputFormRecord: - form_id: str - workflow_run_id: str | None - node_id: str - tenant_id: str - app_id: str - form_kind: HumanInputFormKind - definition: FormDefinition - rendered_content: str - created_at: datetime - expiration_time: datetime - status: HumanInputFormStatus - selected_action_id: str | None - submitted_data: Mapping[str, Any] | None - submitted_at: datetime | None - submission_user_id: str | None - submission_end_user_id: str | None - completed_by_recipient_id: str | None - recipient_id: str | None - recipient_type: RecipientType | None - access_token: str | None - - @property - def submitted(self) -> bool: - return self.submitted_at is not None - - @classmethod - def from_models( - cls, form_model: HumanInputForm, recipient_model: HumanInputFormRecipient | None - ) -> "HumanInputFormRecord": - definition_payload = json.loads(form_model.form_definition) - if "expiration_time" not in definition_payload: - definition_payload["expiration_time"] = form_model.expiration_time - return cls( - form_id=form_model.id, - workflow_run_id=form_model.workflow_run_id, - node_id=form_model.node_id, - tenant_id=form_model.tenant_id, - app_id=form_model.app_id, - form_kind=form_model.form_kind, - definition=FormDefinition.model_validate(definition_payload), - rendered_content=form_model.rendered_content, - created_at=form_model.created_at, - expiration_time=form_model.expiration_time, - status=form_model.status, - selected_action_id=form_model.selected_action_id, - submitted_data=json.loads(form_model.submitted_data) if form_model.submitted_data else None, - submitted_at=form_model.submitted_at, - submission_user_id=form_model.submission_user_id, - submission_end_user_id=form_model.submission_end_user_id, - completed_by_recipient_id=form_model.completed_by_recipient_id, - recipient_id=recipient_model.id if recipient_model else None, - recipient_type=recipient_model.recipient_type if recipient_model else None, - access_token=recipient_model.access_token if recipient_model else None, - ) - - -class _InvalidTimeoutStatusError(ValueError): - pass - - -class HumanInputFormRepositoryImpl: - def __init__( - self, - session_factory: sessionmaker | Engine, - tenant_id: str, - ): - if isinstance(session_factory, Engine): - session_factory = sessionmaker(bind=session_factory) - self._session_factory = session_factory - self._tenant_id = tenant_id - - def _delivery_method_to_model( - self, - session: Session, - form_id: str, - delivery_method: DeliveryChannelConfig, - ) -> _DeliveryAndRecipients: - delivery_id = str(uuidv7()) - delivery_model = HumanInputDelivery( - id=delivery_id, - form_id=form_id, - delivery_method_type=delivery_method.type, - delivery_config_id=delivery_method.id, - channel_payload=delivery_method.model_dump_json(), - ) - recipients: list[HumanInputFormRecipient] = [] - if isinstance(delivery_method, WebAppDeliveryMethod): - recipient_model = HumanInputFormRecipient( - form_id=form_id, - delivery_id=delivery_id, - recipient_type=RecipientType.STANDALONE_WEB_APP, - recipient_payload=StandaloneWebAppRecipientPayload().model_dump_json(), - ) - recipients.append(recipient_model) - elif isinstance(delivery_method, EmailDeliveryMethod): - email_recipients_config = delivery_method.config.recipients - recipients.extend( - self._build_email_recipients( - session=session, - form_id=form_id, - delivery_id=delivery_id, - recipients_config=email_recipients_config, - ) - ) - - return _DeliveryAndRecipients(delivery=delivery_model, recipients=recipients) - - def _build_email_recipients( - self, - session: Session, - form_id: str, - delivery_id: str, - recipients_config: EmailRecipients, - ) -> list[HumanInputFormRecipient]: - member_user_ids = [ - recipient.user_id for recipient in recipients_config.items if isinstance(recipient, MemberRecipient) - ] - external_emails = [ - recipient.email for recipient in recipients_config.items if isinstance(recipient, ExternalRecipient) - ] - if recipients_config.whole_workspace: - members = self._query_all_workspace_members(session=session) - else: - members = self._query_workspace_members_by_ids(session=session, restrict_to_user_ids=member_user_ids) - - return self._create_email_recipients_from_resolved( - form_id=form_id, - delivery_id=delivery_id, - members=members, - external_emails=external_emails, - ) - - @staticmethod - def _create_email_recipients_from_resolved( - *, - form_id: str, - delivery_id: str, - members: Sequence[_WorkspaceMemberInfo], - external_emails: Sequence[str], - ) -> list[HumanInputFormRecipient]: - recipient_models: list[HumanInputFormRecipient] = [] - seen_emails: set[str] = set() - - for member in members: - if not member.email: - continue - if member.email in seen_emails: - continue - seen_emails.add(member.email) - payload = EmailMemberRecipientPayload(user_id=member.user_id, email=member.email) - recipient_models.append( - HumanInputFormRecipient.new( - form_id=form_id, - delivery_id=delivery_id, - payload=payload, - ) - ) - - for email in external_emails: - if not email: - continue - if email in seen_emails: - continue - seen_emails.add(email) - recipient_models.append( - HumanInputFormRecipient.new( - form_id=form_id, - delivery_id=delivery_id, - payload=EmailExternalRecipientPayload(email=email), - ) - ) - - return recipient_models - - def _query_all_workspace_members( - self, - session: Session, - ) -> list[_WorkspaceMemberInfo]: - stmt = ( - select(Account.id, Account.email) - .join(TenantAccountJoin, TenantAccountJoin.account_id == Account.id) - .where(TenantAccountJoin.tenant_id == self._tenant_id) - ) - rows = session.execute(stmt).all() - return [_WorkspaceMemberInfo(user_id=account_id, email=email) for account_id, email in rows] - - def _query_workspace_members_by_ids( - self, - session: Session, - restrict_to_user_ids: Sequence[str], - ) -> list[_WorkspaceMemberInfo]: - unique_ids = {user_id for user_id in restrict_to_user_ids if user_id} - if not unique_ids: - return [] - - stmt = ( - select(Account.id, Account.email) - .join(TenantAccountJoin, TenantAccountJoin.account_id == Account.id) - .where(TenantAccountJoin.tenant_id == self._tenant_id) - ) - stmt = stmt.where(Account.id.in_(unique_ids)) - - rows = session.execute(stmt).all() - return [_WorkspaceMemberInfo(user_id=account_id, email=email) for account_id, email in rows] - - def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: - form_config: HumanInputNodeData = params.form_config - - with self._session_factory(expire_on_commit=False) as session, session.begin(): - # Generate unique form ID - form_id = str(uuidv7()) - start_time = naive_utc_now() - node_expiration = form_config.expiration_time(start_time) - form_definition = FormDefinition( - form_content=form_config.form_content, - inputs=form_config.inputs, - user_actions=form_config.user_actions, - rendered_content=params.rendered_content, - expiration_time=node_expiration, - default_values=dict(params.resolved_default_values), - display_in_ui=params.display_in_ui, - node_title=form_config.title, - ) - form_model = HumanInputForm( - id=form_id, - tenant_id=self._tenant_id, - app_id=params.app_id, - workflow_run_id=params.workflow_execution_id, - form_kind=params.form_kind, - node_id=params.node_id, - form_definition=form_definition.model_dump_json(), - rendered_content=params.rendered_content, - expiration_time=node_expiration, - created_at=start_time, - ) - session.add(form_model) - recipient_models: list[HumanInputFormRecipient] = [] - for delivery in params.delivery_methods: - delivery_and_recipients = self._delivery_method_to_model( - session=session, - form_id=form_id, - delivery_method=delivery, - ) - session.add(delivery_and_recipients.delivery) - session.add_all(delivery_and_recipients.recipients) - recipient_models.extend(delivery_and_recipients.recipients) - if params.console_recipient_required and not any( - recipient.recipient_type == RecipientType.CONSOLE for recipient in recipient_models - ): - console_delivery_id = str(uuidv7()) - console_delivery = HumanInputDelivery( - id=console_delivery_id, - form_id=form_id, - delivery_method_type=DeliveryMethodType.WEBAPP, - delivery_config_id=None, - channel_payload=ConsoleDeliveryPayload().model_dump_json(), - ) - console_recipient = HumanInputFormRecipient( - form_id=form_id, - delivery_id=console_delivery_id, - recipient_type=RecipientType.CONSOLE, - recipient_payload=ConsoleRecipientPayload( - account_id=params.console_creator_account_id, - ).model_dump_json(), - ) - session.add(console_delivery) - session.add(console_recipient) - recipient_models.append(console_recipient) - if params.backstage_recipient_required and not any( - recipient.recipient_type == RecipientType.BACKSTAGE for recipient in recipient_models - ): - backstage_delivery_id = str(uuidv7()) - backstage_delivery = HumanInputDelivery( - id=backstage_delivery_id, - form_id=form_id, - delivery_method_type=DeliveryMethodType.WEBAPP, - delivery_config_id=None, - channel_payload=ConsoleDeliveryPayload().model_dump_json(), - ) - backstage_recipient = HumanInputFormRecipient( - form_id=form_id, - delivery_id=backstage_delivery_id, - recipient_type=RecipientType.BACKSTAGE, - recipient_payload=BackstageRecipientPayload( - account_id=params.console_creator_account_id, - ).model_dump_json(), - ) - session.add(backstage_delivery) - session.add(backstage_recipient) - recipient_models.append(backstage_recipient) - session.flush() - - return _HumanInputFormEntityImpl(form_model=form_model, recipient_models=recipient_models) - - def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: - form_query = select(HumanInputForm).where( - HumanInputForm.workflow_run_id == workflow_execution_id, - HumanInputForm.node_id == node_id, - HumanInputForm.tenant_id == self._tenant_id, - ) - with self._session_factory(expire_on_commit=False) as session: - form_model: HumanInputForm | None = session.scalars(form_query).first() - if form_model is None: - return None - - recipient_query = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id == form_model.id) - recipient_models = session.scalars(recipient_query).all() - return _HumanInputFormEntityImpl(form_model=form_model, recipient_models=recipient_models) - - -class HumanInputFormSubmissionRepository: - """Repository for fetching and submitting human input forms.""" - - def __init__(self, session_factory: sessionmaker | Engine): - if isinstance(session_factory, Engine): - session_factory = sessionmaker(bind=session_factory) - self._session_factory = session_factory - - def get_by_token(self, form_token: str) -> HumanInputFormRecord | None: - query = ( - select(HumanInputFormRecipient) - .options(selectinload(HumanInputFormRecipient.form)) - .where(HumanInputFormRecipient.access_token == form_token) - ) - with self._session_factory(expire_on_commit=False) as session: - recipient_model = session.scalars(query).first() - if recipient_model is None or recipient_model.form is None: - return None - return HumanInputFormRecord.from_models(recipient_model.form, recipient_model) - - def get_by_form_id_and_recipient_type( - self, - form_id: str, - recipient_type: RecipientType, - ) -> HumanInputFormRecord | None: - query = ( - select(HumanInputFormRecipient) - .options(selectinload(HumanInputFormRecipient.form)) - .where( - HumanInputFormRecipient.form_id == form_id, - HumanInputFormRecipient.recipient_type == recipient_type, - ) - ) - with self._session_factory(expire_on_commit=False) as session: - recipient_model = session.scalars(query).first() - if recipient_model is None or recipient_model.form is None: - return None - return HumanInputFormRecord.from_models(recipient_model.form, recipient_model) - - def mark_submitted( - self, - *, - form_id: str, - recipient_id: str | None, - selected_action_id: str, - form_data: Mapping[str, Any], - submission_user_id: str | None, - submission_end_user_id: str | None, - ) -> HumanInputFormRecord: - with self._session_factory(expire_on_commit=False) as session, session.begin(): - form_model = session.get(HumanInputForm, form_id) - if form_model is None: - raise FormNotFoundError(f"form not found, id={form_id}") - - recipient_model = session.get(HumanInputFormRecipient, recipient_id) if recipient_id else None - - form_model.selected_action_id = selected_action_id - form_model.submitted_data = json.dumps(form_data) - form_model.submitted_at = naive_utc_now() - form_model.status = HumanInputFormStatus.SUBMITTED - form_model.submission_user_id = submission_user_id - form_model.submission_end_user_id = submission_end_user_id - form_model.completed_by_recipient_id = recipient_id - - session.add(form_model) - session.flush() - session.refresh(form_model) - if recipient_model is not None: - session.refresh(recipient_model) - - return HumanInputFormRecord.from_models(form_model, recipient_model) - - def mark_timeout( - self, - *, - form_id: str, - timeout_status: HumanInputFormStatus, - reason: str | None = None, - ) -> HumanInputFormRecord: - with self._session_factory(expire_on_commit=False) as session, session.begin(): - form_model = session.get(HumanInputForm, form_id) - if form_model is None: - raise FormNotFoundError(f"form not found, id={form_id}") - - if timeout_status not in {HumanInputFormStatus.TIMEOUT, HumanInputFormStatus.EXPIRED}: - raise _InvalidTimeoutStatusError(f"invalid timeout status: {timeout_status}") - - # already handled or submitted - if form_model.status in {HumanInputFormStatus.TIMEOUT, HumanInputFormStatus.EXPIRED}: - return HumanInputFormRecord.from_models(form_model, None) - - if form_model.submitted_at is not None or form_model.status == HumanInputFormStatus.SUBMITTED: - raise FormNotFoundError(f"form already submitted, id={form_id}") - - form_model.status = timeout_status - form_model.selected_action_id = None - form_model.submitted_data = None - form_model.submission_user_id = None - form_model.submission_end_user_id = None - form_model.completed_by_recipient_id = None - # Reason is recorded in status/error downstream; not stored on form. - session.add(form_model) - session.flush() - session.refresh(form_model) - - return HumanInputFormRecord.from_models(form_model, None) diff --git a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py index 324dd059d1..4436773d25 100644 --- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py @@ -488,7 +488,6 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id, WorkflowNodeExecutionModel.tenant_id == self._tenant_id, WorkflowNodeExecutionModel.triggered_from == triggered_from, - WorkflowNodeExecutionModel.status != WorkflowNodeExecutionStatus.PAUSED, ) if self._app_id: diff --git a/api/core/schemas/registry.py b/api/core/schemas/registry.py index b4ecfe47ff..b87fba4eaa 100644 --- a/api/core/schemas/registry.py +++ b/api/core/schemas/registry.py @@ -35,6 +35,7 @@ class SchemaRegistry: registry.load_all_versions() cls._default_instance = registry + return cls._default_instance return cls._default_instance diff --git a/api/core/tools/errors.py b/api/core/tools/errors.py index 4c3efd6ff9..e4afe24426 100644 --- a/api/core/tools/errors.py +++ b/api/core/tools/errors.py @@ -1,5 +1,4 @@ from core.tools.entities.tool_entities import ToolInvokeMeta -from libs.exception import BaseHTTPException class ToolProviderNotFoundError(ValueError): @@ -38,12 +37,6 @@ class ToolCredentialPolicyViolationError(ValueError): pass -class WorkflowToolHumanInputNotSupportedError(BaseHTTPException): - error_code = "workflow_tool_human_input_not_supported" - description = "Workflow with Human Input nodes cannot be published as a workflow tool." - code = 400 - - class ToolEngineInvokeError(Exception): meta: ToolInvokeMeta diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index f8213d9fd7..d561d39923 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -189,16 +189,13 @@ class ToolManager: raise ToolProviderNotFoundError(f"builtin tool {tool_name} not found") if not provider_controller.need_credentials: - return cast( - BuiltinTool, - builtin_tool.fork_tool_runtime( - runtime=ToolRuntime( - tenant_id=tenant_id, - credentials={}, - invoke_from=invoke_from, - tool_invoke_from=tool_invoke_from, - ) - ), + return builtin_tool.fork_tool_runtime( + runtime=ToolRuntime( + tenant_id=tenant_id, + credentials={}, + invoke_from=invoke_from, + tool_invoke_from=tool_invoke_from, + ) ) builtin_provider = None if isinstance(provider_controller, PluginToolProviderController): @@ -300,18 +297,15 @@ class ToolManager: decrypted_credentials = refreshed_credentials.credentials cache.delete() - return cast( - BuiltinTool, - builtin_tool.fork_tool_runtime( - runtime=ToolRuntime( - tenant_id=tenant_id, - credentials=dict(decrypted_credentials), - credential_type=CredentialType.of(builtin_provider.credential_type), - runtime_parameters={}, - invoke_from=invoke_from, - tool_invoke_from=tool_invoke_from, - ) - ), + return builtin_tool.fork_tool_runtime( + runtime=ToolRuntime( + tenant_id=tenant_id, + credentials=dict(decrypted_credentials), + credential_type=CredentialType.of(builtin_provider.credential_type), + runtime_parameters={}, + invoke_from=invoke_from, + tool_invoke_from=tool_invoke_from, + ) ) elif provider_type == ToolProviderType.API: diff --git a/api/core/tools/utils/workflow_configuration_sync.py b/api/core/tools/utils/workflow_configuration_sync.py index 8588ccc718..6d75df3603 100644 --- a/api/core/tools/utils/workflow_configuration_sync.py +++ b/api/core/tools/utils/workflow_configuration_sync.py @@ -3,17 +3,10 @@ from typing import Any from core.app.app_config.entities import VariableEntity from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration -from core.tools.errors import WorkflowToolHumanInputNotSupportedError -from core.workflow.enums import NodeType from core.workflow.nodes.base.entities import OutputVariableEntity class WorkflowToolConfigurationUtils: - @classmethod - def check_parameter_configurations(cls, configurations: list[Mapping[str, Any]]): - for configuration in configurations: - WorkflowToolParameterConfiguration.model_validate(configuration) - @classmethod def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]: """ @@ -52,13 +45,6 @@ class WorkflowToolConfigurationUtils: return [outputs_by_variable[variable] for variable in variable_order] - @classmethod - def ensure_no_human_input_nodes(cls, graph: Mapping[str, Any]) -> None: - nodes = graph.get("nodes", []) - for node in nodes: - if node.get("data", {}).get("type") == NodeType.HUMAN_INPUT: - raise WorkflowToolHumanInputNotSupportedError() - @classmethod def check_is_synced( cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration] diff --git a/api/core/trigger/debug/event_bus.py b/api/core/trigger/debug/event_bus.py index 9d10e1a0e0..e3fb6a13d9 100644 --- a/api/core/trigger/debug/event_bus.py +++ b/api/core/trigger/debug/event_bus.py @@ -23,8 +23,8 @@ class TriggerDebugEventBus: """ # LUA_SELECT: Atomic poll or register for event - # KEYS[1] = trigger_debug_inbox:{tenant_id}:{address_id} - # KEYS[2] = trigger_debug_waiting_pool:{tenant_id}:... + # KEYS[1] = trigger_debug_inbox:{}: + # KEYS[2] = trigger_debug_waiting_pool:{}:... # ARGV[1] = address_id LUA_SELECT = ( "local v=redis.call('GET',KEYS[1]);" @@ -35,7 +35,7 @@ class TriggerDebugEventBus: ) # LUA_DISPATCH: Dispatch event to all waiting addresses - # KEYS[1] = trigger_debug_waiting_pool:{tenant_id}:... + # KEYS[1] = trigger_debug_waiting_pool:{}:... # ARGV[1] = tenant_id # ARGV[2] = event_json LUA_DISPATCH = ( @@ -43,7 +43,7 @@ class TriggerDebugEventBus: "if #a==0 then return 0 end;" "redis.call('DEL',KEYS[1]);" "for i=1,#a do " - f"redis.call('SET','trigger_debug_inbox:'..ARGV[1]..':'..a[i],ARGV[2],'EX',{TRIGGER_DEBUG_EVENT_TTL});" + f"redis.call('SET','trigger_debug_inbox:{{'..ARGV[1]..'}}'..':'..a[i],ARGV[2],'EX',{TRIGGER_DEBUG_EVENT_TTL});" "end;" "return #a" ) @@ -108,7 +108,7 @@ class TriggerDebugEventBus: Event object if available, None otherwise """ address_id: str = hashlib.sha256(f"{user_id}|{app_id}|{node_id}".encode()).hexdigest() - address: str = f"trigger_debug_inbox:{tenant_id}:{address_id}" + address: str = f"trigger_debug_inbox:{{{tenant_id}}}:{address_id}" try: event_data = redis_client.eval( diff --git a/api/core/trigger/debug/events.py b/api/core/trigger/debug/events.py index 9f7bab5e49..9aec342ed1 100644 --- a/api/core/trigger/debug/events.py +++ b/api/core/trigger/debug/events.py @@ -42,7 +42,7 @@ def build_webhook_pool_key(tenant_id: str, app_id: str, node_id: str) -> str: app_id: App ID node_id: Node ID """ - return f"{TriggerDebugPoolKey.WEBHOOK}:{tenant_id}:{app_id}:{node_id}" + return f"{TriggerDebugPoolKey.WEBHOOK}:{{{tenant_id}}}:{app_id}:{node_id}" class PluginTriggerDebugEvent(BaseDebugEvent): @@ -64,4 +64,4 @@ def build_plugin_pool_key(tenant_id: str, provider_id: str, subscription_id: str provider_id: Provider ID subscription_id: Subscription ID """ - return f"{TriggerDebugPoolKey.PLUGIN}:{tenant_id}:{str(provider_id)}:{subscription_id}:{name}" + return f"{TriggerDebugPoolKey.PLUGIN}:{{{tenant_id}}}:{str(provider_id)}:{subscription_id}:{name}" diff --git a/api/core/workflow/entities/__init__.py b/api/core/workflow/entities/__init__.py index e73c38c1d3..be70e467a0 100644 --- a/api/core/workflow/entities/__init__.py +++ b/api/core/workflow/entities/__init__.py @@ -2,12 +2,10 @@ from .agent import AgentNodeStrategyInit from .graph_init_params import GraphInitParams from .workflow_execution import WorkflowExecution from .workflow_node_execution import WorkflowNodeExecution -from .workflow_start_reason import WorkflowStartReason __all__ = [ "AgentNodeStrategyInit", "GraphInitParams", "WorkflowExecution", "WorkflowNodeExecution", - "WorkflowStartReason", ] diff --git a/api/core/workflow/entities/graph_config.py b/api/core/workflow/entities/graph_config.py new file mode 100644 index 0000000000..209dcfe6bc --- /dev/null +++ b/api/core/workflow/entities/graph_config.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +import sys + +from pydantic import TypeAdapter, with_config + +if sys.version_info >= (3, 12): + from typing import TypedDict +else: + from typing_extensions import TypedDict + + +@with_config(extra="allow") +class NodeConfigData(TypedDict): + type: str + + +@with_config(extra="allow") +class NodeConfigDict(TypedDict): + id: str + data: NodeConfigData + + +NodeConfigDictAdapter = TypeAdapter(NodeConfigDict) diff --git a/api/core/workflow/entities/graph_init_params.py b/api/core/workflow/entities/graph_init_params.py index ff224a28d1..7bf25b9f43 100644 --- a/api/core/workflow/entities/graph_init_params.py +++ b/api/core/workflow/entities/graph_init_params.py @@ -5,16 +5,6 @@ from pydantic import BaseModel, Field class GraphInitParams(BaseModel): - """GraphInitParams encapsulates the configurations and contextual information - that remain constant throughout a single execution of the graph engine. - - A single execution is defined as follows: as long as the execution has not reached - its conclusion, it is considered one execution. For instance, if a workflow is suspended - and later resumed, it is still regarded as a single execution, not two. - - For the state diagram of workflow execution, refer to `WorkflowExecutionStatus`. - """ - # init params tenant_id: str = Field(..., description="tenant / workspace id") app_id: str = Field(..., description="app id") diff --git a/api/core/workflow/entities/pause_reason.py b/api/core/workflow/entities/pause_reason.py index 147f56e8be..c6655b7eab 100644 --- a/api/core/workflow/entities/pause_reason.py +++ b/api/core/workflow/entities/pause_reason.py @@ -1,11 +1,8 @@ -from collections.abc import Mapping from enum import StrEnum, auto -from typing import Annotated, Any, Literal, TypeAlias +from typing import Annotated, Literal, TypeAlias from pydantic import BaseModel, Field -from core.workflow.nodes.human_input.entities import FormInput, UserAction - class PauseReasonType(StrEnum): HUMAN_INPUT_REQUIRED = auto() @@ -14,31 +11,10 @@ class PauseReasonType(StrEnum): class HumanInputRequired(BaseModel): TYPE: Literal[PauseReasonType.HUMAN_INPUT_REQUIRED] = PauseReasonType.HUMAN_INPUT_REQUIRED + form_id: str - form_content: str - inputs: list[FormInput] = Field(default_factory=list) - actions: list[UserAction] = Field(default_factory=list) - display_in_ui: bool = False + # The identifier of the human input node causing the pause. node_id: str - node_title: str - - # The `resolved_default_values` stores the resolved values of variable defaults. It's a mapping from - # `output_variable_name` to their resolved values. - # - # For example, The form contains a input with output variable name `name` and placeholder type `VARIABLE`, its - # selector is ["start", "name"]. While the HumanInputNode is executed, the correspond value of variable - # `start.name` in variable pool is `John`. Thus, the resolved value of the output variable `name` is `John`. The - # `resolved_default_values` is `{"name": "John"}`. - # - # Only form inputs with default value type `VARIABLE` will be resolved and stored in `resolved_default_values`. - resolved_default_values: Mapping[str, Any] = Field(default_factory=dict) - - # The `form_token` is the token used to submit the form via UI surfaces. It corresponds to - # `HumanInputFormRecipient.access_token`. - # - # This field is `None` if webapp delivery is not set and not - # in orchestrating mode. - form_token: str | None = None class SchedulingPause(BaseModel): diff --git a/api/core/workflow/entities/workflow_start_reason.py b/api/core/workflow/entities/workflow_start_reason.py deleted file mode 100644 index df0f75383b..0000000000 --- a/api/core/workflow/entities/workflow_start_reason.py +++ /dev/null @@ -1,8 +0,0 @@ -from enum import StrEnum - - -class WorkflowStartReason(StrEnum): - """Reason for workflow start events across graph/queue/SSE layers.""" - - INITIAL = "initial" # First start of a workflow run. - RESUMPTION = "resumption" # Start triggered after resuming a paused run. diff --git a/api/core/workflow/graph/graph.py b/api/core/workflow/graph/graph.py index 31bf6f3b27..52bbbb20cc 100644 --- a/api/core/workflow/graph/graph.py +++ b/api/core/workflow/graph/graph.py @@ -5,15 +5,20 @@ from collections import defaultdict from collections.abc import Mapping, Sequence from typing import Protocol, cast, final +from pydantic import TypeAdapter + +from core.workflow.entities.graph_config import NodeConfigDict from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType from core.workflow.nodes.base.node import Node -from libs.typing import is_str, is_str_dict +from libs.typing import is_str from .edge import Edge from .validation import get_graph_validator logger = logging.getLogger(__name__) +_ListNodeConfigDict = TypeAdapter(list[NodeConfigDict]) + class NodeFactory(Protocol): """ @@ -23,7 +28,7 @@ class NodeFactory(Protocol): allowing for different node creation strategies while maintaining type safety. """ - def create_node(self, node_config: dict[str, object]) -> Node: + def create_node(self, node_config: NodeConfigDict) -> Node: """ Create a Node instance from node configuration data. @@ -63,28 +68,24 @@ class Graph: self.root_node = root_node @classmethod - def _parse_node_configs(cls, node_configs: list[dict[str, object]]) -> dict[str, dict[str, object]]: + def _parse_node_configs(cls, node_configs: list[NodeConfigDict]) -> dict[str, NodeConfigDict]: """ Parse node configurations and build a mapping of node IDs to configs. :param node_configs: list of node configuration dictionaries :return: mapping of node ID to node config """ - node_configs_map: dict[str, dict[str, object]] = {} + node_configs_map: dict[str, NodeConfigDict] = {} for node_config in node_configs: - node_id = node_config.get("id") - if not node_id or not isinstance(node_id, str): - continue - - node_configs_map[node_id] = node_config + node_configs_map[node_config["id"]] = node_config return node_configs_map @classmethod def _find_root_node_id( cls, - node_configs_map: Mapping[str, Mapping[str, object]], + node_configs_map: Mapping[str, NodeConfigDict], edge_configs: Sequence[Mapping[str, object]], root_node_id: str | None = None, ) -> str: @@ -113,10 +114,8 @@ class Graph: # Prefer START node if available start_node_id = None for nid in root_candidates: - node_data = node_configs_map[nid].get("data") - if not is_str_dict(node_data): - continue - node_type = node_data.get("type") + node_data = node_configs_map[nid]["data"] + node_type = node_data["type"] if not isinstance(node_type, str): continue if NodeType(node_type).is_start_node: @@ -176,7 +175,7 @@ class Graph: @classmethod def _create_node_instances( cls, - node_configs_map: dict[str, dict[str, object]], + node_configs_map: dict[str, NodeConfigDict], node_factory: NodeFactory, ) -> dict[str, Node]: """ @@ -303,7 +302,7 @@ class Graph: node_configs = graph_config.get("nodes", []) edge_configs = cast(list[dict[str, object]], edge_configs) - node_configs = cast(list[dict[str, object]], node_configs) + node_configs = _ListNodeConfigDict.validate_python(node_configs) if not node_configs: raise ValueError("Graph must have at least one node") diff --git a/api/core/workflow/graph_engine/_engine_utils.py b/api/core/workflow/graph_engine/_engine_utils.py deleted file mode 100644 index 28898268fe..0000000000 --- a/api/core/workflow/graph_engine/_engine_utils.py +++ /dev/null @@ -1,15 +0,0 @@ -import time - - -def get_timestamp() -> float: - """Retrieve a timestamp as a float point numer representing the number of seconds - since the Unix epoch. - - This function is primarily used to measure the execution time of the workflow engine. - Since workflow execution may be paused and resumed on a different machine, - `time.perf_counter` cannot be used as it is inconsistent across machines. - - To address this, the function uses the wall clock as the time source. - However, it assumes that the clocks of all servers are properly synchronized. - """ - return round(time.time()) diff --git a/api/core/workflow/graph_engine/config.py b/api/core/workflow/graph_engine/config.py index d56a69cee0..10dbbd7535 100644 --- a/api/core/workflow/graph_engine/config.py +++ b/api/core/workflow/graph_engine/config.py @@ -2,14 +2,12 @@ GraphEngine configuration models. """ -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel class GraphEngineConfig(BaseModel): """Configuration for GraphEngine worker pool scaling.""" - model_config = ConfigDict(frozen=True) - min_workers: int = 1 max_workers: int = 5 scale_up_threshold: int = 3 diff --git a/api/core/workflow/graph_engine/event_management/event_handlers.py b/api/core/workflow/graph_engine/event_management/event_handlers.py index 98a0702e1c..5b0f56e59d 100644 --- a/api/core/workflow/graph_engine/event_management/event_handlers.py +++ b/api/core/workflow/graph_engine/event_management/event_handlers.py @@ -192,13 +192,9 @@ class EventHandler: self._event_collector.collect(edge_event) # Enqueue ready nodes - if self._graph_execution.is_paused: - for node_id in ready_nodes: - self._graph_runtime_state.register_deferred_node(node_id) - else: - for node_id in ready_nodes: - self._state_manager.enqueue_node(node_id) - self._state_manager.start_execution(node_id) + for node_id in ready_nodes: + self._state_manager.enqueue_node(node_id) + self._state_manager.start_execution(node_id) # Update execution tracking self._state_manager.finish_execution(event.node_id) diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index ac9e00e29e..2b76b563ff 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -14,7 +14,6 @@ from collections.abc import Generator from typing import TYPE_CHECKING, cast, final from core.workflow.context import capture_current_context -from core.workflow.entities.workflow_start_reason import WorkflowStartReason from core.workflow.enums import NodeExecutionType from core.workflow.graph import Graph from core.workflow.graph_events import ( @@ -47,7 +46,6 @@ from .graph_traversal import EdgeProcessor, SkipPropagator from .layers.base import GraphEngineLayer from .orchestration import Dispatcher, ExecutionCoordinator from .protocols.command_channel import CommandChannel -from .ready_queue import ReadyQueue from .worker_management import WorkerPool if TYPE_CHECKING: @@ -57,9 +55,6 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -_DEFAULT_CONFIG = GraphEngineConfig() - - @final class GraphEngine: """ @@ -75,7 +70,7 @@ class GraphEngine: graph: Graph, graph_runtime_state: GraphRuntimeState, command_channel: CommandChannel, - config: GraphEngineConfig = _DEFAULT_CONFIG, + config: GraphEngineConfig, ) -> None: """Initialize the graph engine with all subsystems and dependencies.""" # stop event @@ -94,7 +89,7 @@ class GraphEngine: self._graph_execution.workflow_id = workflow_id # === Execution Queues === - self._ready_queue = cast(ReadyQueue, self._graph_runtime_state.ready_queue) + self._ready_queue = self._graph_runtime_state.ready_queue # Queue for events generated during execution self._event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue() @@ -239,9 +234,7 @@ class GraphEngine: self._graph_execution.paused = False self._graph_execution.pause_reasons = [] - start_event = GraphRunStartedEvent( - reason=WorkflowStartReason.RESUMPTION if is_resume else WorkflowStartReason.INITIAL, - ) + start_event = GraphRunStartedEvent() self._event_manager.notify_layers(start_event) yield start_event @@ -310,17 +303,15 @@ class GraphEngine: for layer in self._layers: try: layer.on_graph_start() - except Exception: - logger.exception("Layer %s failed on_graph_start", layer.__class__.__name__) + except Exception as e: + logger.warning("Layer %s failed on_graph_start: %s", layer.__class__.__name__, e) def _start_execution(self, *, resume: bool = False) -> None: """Start execution subsystems.""" self._stop_event.clear() paused_nodes: list[str] = [] - deferred_nodes: list[str] = [] if resume: paused_nodes = self._graph_runtime_state.consume_paused_nodes() - deferred_nodes = self._graph_runtime_state.consume_deferred_nodes() # Start worker pool (it calculates initial workers internally) self._worker_pool.start() @@ -336,11 +327,7 @@ class GraphEngine: self._state_manager.enqueue_node(root_node.id) self._state_manager.start_execution(root_node.id) else: - seen_nodes: set[str] = set() - for node_id in paused_nodes + deferred_nodes: - if node_id in seen_nodes: - continue - seen_nodes.add(node_id) + for node_id in paused_nodes: self._state_manager.enqueue_node(node_id) self._state_manager.start_execution(node_id) @@ -358,8 +345,8 @@ class GraphEngine: for layer in self._layers: try: layer.on_graph_end(self._graph_execution.error) - except Exception: - logger.exception("Layer %s failed on_graph_end", layer.__class__.__name__) + except Exception as e: + logger.warning("Layer %s failed on_graph_end: %s", layer.__class__.__name__, e) # Public property accessors for attributes that need external access @property diff --git a/api/core/workflow/graph_engine/graph_state_manager.py b/api/core/workflow/graph_engine/graph_state_manager.py index d9773645c3..22a3a826fc 100644 --- a/api/core/workflow/graph_engine/graph_state_manager.py +++ b/api/core/workflow/graph_engine/graph_state_manager.py @@ -224,8 +224,6 @@ class GraphStateManager: Returns: Number of executing nodes """ - # This count is a best-effort snapshot and can change concurrently. - # Only use it for pause-drain checks where scheduling is already frozen. with self._lock: return len(self._executing_nodes) diff --git a/api/core/workflow/graph_engine/orchestration/dispatcher.py b/api/core/workflow/graph_engine/orchestration/dispatcher.py index d40d15c545..27439a2412 100644 --- a/api/core/workflow/graph_engine/orchestration/dispatcher.py +++ b/api/core/workflow/graph_engine/orchestration/dispatcher.py @@ -83,12 +83,12 @@ class Dispatcher: """Main dispatcher loop.""" try: self._process_commands() - paused = False while not self._stop_event.is_set(): - if self._execution_coordinator.aborted or self._execution_coordinator.execution_complete: - break - if self._execution_coordinator.paused: - paused = True + if ( + self._execution_coordinator.aborted + or self._execution_coordinator.paused + or self._execution_coordinator.execution_complete + ): break self._execution_coordinator.check_scaling() @@ -101,10 +101,13 @@ class Dispatcher: time.sleep(0.1) self._process_commands() - if paused: - self._drain_events_until_idle() - else: - self._drain_event_queue() + while True: + try: + event = self._event_queue.get(block=False) + self._event_handler.dispatch(event) + self._event_queue.task_done() + except queue.Empty: + break except Exception as e: logger.exception("Dispatcher error") @@ -119,24 +122,3 @@ class Dispatcher: def _process_commands(self, event: GraphNodeEventBase | None = None): if event is None or isinstance(event, self._COMMAND_TRIGGER_EVENTS): self._execution_coordinator.process_commands() - - def _drain_event_queue(self) -> None: - while True: - try: - event = self._event_queue.get(block=False) - self._event_handler.dispatch(event) - self._event_queue.task_done() - except queue.Empty: - break - - def _drain_events_until_idle(self) -> None: - while not self._stop_event.is_set(): - try: - event = self._event_queue.get(timeout=0.1) - self._event_handler.dispatch(event) - self._event_queue.task_done() - self._process_commands(event) - except queue.Empty: - if not self._execution_coordinator.has_executing_nodes(): - break - self._drain_event_queue() diff --git a/api/core/workflow/graph_engine/orchestration/execution_coordinator.py b/api/core/workflow/graph_engine/orchestration/execution_coordinator.py index 0f8550eb12..e8e8f9f16c 100644 --- a/api/core/workflow/graph_engine/orchestration/execution_coordinator.py +++ b/api/core/workflow/graph_engine/orchestration/execution_coordinator.py @@ -94,11 +94,3 @@ class ExecutionCoordinator: self._worker_pool.stop() self._state_manager.clear_executing() - - def has_executing_nodes(self) -> bool: - """Return True if any nodes are currently marked as executing.""" - # This check is only safe once execution has already paused. - # Before pause, executing state can change concurrently, which makes the result unreliable. - if not self._graph_execution.is_paused: - raise AssertionError("has_executing_nodes should only be called after execution is paused") - return self._state_manager.get_executing_count() > 0 diff --git a/api/core/workflow/graph_engine/response_coordinator/coordinator.py b/api/core/workflow/graph_engine/response_coordinator/coordinator.py index 98e0ea91ef..e82ba29438 100644 --- a/api/core/workflow/graph_engine/response_coordinator/coordinator.py +++ b/api/core/workflow/graph_engine/response_coordinator/coordinator.py @@ -15,10 +15,10 @@ from uuid import uuid4 from pydantic import BaseModel, Field from core.workflow.enums import NodeExecutionType, NodeState -from core.workflow.graph import Graph from core.workflow.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent from core.workflow.nodes.base.template import TextSegment, VariableSegment from core.workflow.runtime import VariablePool +from core.workflow.runtime.graph_runtime_state import GraphProtocol from .path import Path from .session import ResponseSession @@ -75,7 +75,7 @@ class ResponseStreamCoordinator: Ensures ordered streaming of responses based on upstream node outputs and constants. """ - def __init__(self, variable_pool: "VariablePool", graph: "Graph") -> None: + def __init__(self, variable_pool: "VariablePool", graph: GraphProtocol) -> None: """ Initialize coordinator with variable pool. diff --git a/api/core/workflow/graph_engine/response_coordinator/session.py b/api/core/workflow/graph_engine/response_coordinator/session.py index 8ceaa428c3..5e4fada7d9 100644 --- a/api/core/workflow/graph_engine/response_coordinator/session.py +++ b/api/core/workflow/graph_engine/response_coordinator/session.py @@ -10,10 +10,10 @@ from __future__ import annotations from dataclasses import dataclass from core.workflow.nodes.answer.answer_node import AnswerNode -from core.workflow.nodes.base.node import Node from core.workflow.nodes.base.template import Template from core.workflow.nodes.end.end_node import EndNode from core.workflow.nodes.knowledge_index import KnowledgeIndexNode +from core.workflow.runtime.graph_runtime_state import NodeProtocol @dataclass @@ -29,21 +29,26 @@ class ResponseSession: index: int = 0 # Current position in the template segments @classmethod - def from_node(cls, node: Node) -> ResponseSession: + def from_node(cls, node: NodeProtocol) -> ResponseSession: """ - Create a ResponseSession from an AnswerNode or EndNode. + Create a ResponseSession from a response-capable node. + + The parameter is typed as `NodeProtocol` because the graph is exposed behind a protocol at the runtime layer, + but at runtime this must be an `AnswerNode`, `EndNode`, or `KnowledgeIndexNode` that provides: + - `id: str` + - `get_streaming_template() -> Template` Args: - node: Must be either an AnswerNode or EndNode instance + node: Node from the materialized workflow graph. Returns: ResponseSession configured with the node's streaming template Raises: - TypeError: If node is not an AnswerNode or EndNode + TypeError: If node is not a supported response node type. """ if not isinstance(node, AnswerNode | EndNode | KnowledgeIndexNode): - raise TypeError + raise TypeError("ResponseSession.from_node only supports AnswerNode, EndNode, or KnowledgeIndexNode") return cls( node_id=node.id, template=node.get_streaming_template(), diff --git a/api/core/workflow/graph_events/__init__.py b/api/core/workflow/graph_events/__init__.py index 56ea642092..2b6ee4ec1c 100644 --- a/api/core/workflow/graph_events/__init__.py +++ b/api/core/workflow/graph_events/__init__.py @@ -38,8 +38,6 @@ from .loop import ( from .node import ( NodeRunExceptionEvent, NodeRunFailedEvent, - NodeRunHumanInputFormFilledEvent, - NodeRunHumanInputFormTimeoutEvent, NodeRunPauseRequestedEvent, NodeRunRetrieverResourceEvent, NodeRunRetryEvent, @@ -62,8 +60,6 @@ __all__ = [ "NodeRunAgentLogEvent", "NodeRunExceptionEvent", "NodeRunFailedEvent", - "NodeRunHumanInputFormFilledEvent", - "NodeRunHumanInputFormTimeoutEvent", "NodeRunIterationFailedEvent", "NodeRunIterationNextEvent", "NodeRunIterationStartedEvent", diff --git a/api/core/workflow/graph_events/graph.py b/api/core/workflow/graph_events/graph.py index f46526bcab..5d10a76c15 100644 --- a/api/core/workflow/graph_events/graph.py +++ b/api/core/workflow/graph_events/graph.py @@ -1,16 +1,11 @@ from pydantic import Field from core.workflow.entities.pause_reason import PauseReason -from core.workflow.entities.workflow_start_reason import WorkflowStartReason from core.workflow.graph_events import BaseGraphEvent class GraphRunStartedEvent(BaseGraphEvent): - # Reason is emitted for workflow start events and is always set. - reason: WorkflowStartReason = Field( - default=WorkflowStartReason.INITIAL, - description="reason for workflow start", - ) + pass class GraphRunSucceededEvent(BaseGraphEvent): diff --git a/api/core/workflow/graph_events/human_input.py b/api/core/workflow/graph_events/human_input.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/core/workflow/graph_events/node.py b/api/core/workflow/graph_events/node.py index 975d72ad1f..4d0108e77b 100644 --- a/api/core/workflow/graph_events/node.py +++ b/api/core/workflow/graph_events/node.py @@ -54,22 +54,6 @@ class NodeRunRetryEvent(NodeRunStartedEvent): retry_index: int = Field(..., description="which retry attempt is about to be performed") -class NodeRunHumanInputFormFilledEvent(GraphNodeEventBase): - """Emitted when a HumanInput form is submitted and before the node finishes.""" - - node_title: str = Field(..., description="HumanInput node title") - rendered_content: str = Field(..., description="Markdown content rendered with user inputs.") - action_id: str = Field(..., description="User action identifier chosen in the form.") - action_text: str = Field(..., description="Display text of the chosen action button.") - - -class NodeRunHumanInputFormTimeoutEvent(GraphNodeEventBase): - """Emitted when a HumanInput form times out.""" - - node_title: str = Field(..., description="HumanInput node title") - expiration_time: datetime = Field(..., description="Form expiration time") - - class NodeRunPauseRequestedEvent(GraphNodeEventBase): reason: PauseReason = Field(..., description="pause reason") diff --git a/api/core/workflow/node_events/__init__.py b/api/core/workflow/node_events/__init__.py index a9bef8f9a2..f14a594c85 100644 --- a/api/core/workflow/node_events/__init__.py +++ b/api/core/workflow/node_events/__init__.py @@ -13,8 +13,6 @@ from .loop import ( LoopSucceededEvent, ) from .node import ( - HumanInputFormFilledEvent, - HumanInputFormTimeoutEvent, ModelInvokeCompletedEvent, PauseRequestedEvent, RunRetrieverResourceEvent, @@ -25,8 +23,6 @@ from .node import ( __all__ = [ "AgentLogEvent", - "HumanInputFormFilledEvent", - "HumanInputFormTimeoutEvent", "IterationFailedEvent", "IterationNextEvent", "IterationStartedEvent", diff --git a/api/core/workflow/node_events/node.py b/api/core/workflow/node_events/node.py index 9c76b7d7c2..e4fa52f444 100644 --- a/api/core/workflow/node_events/node.py +++ b/api/core/workflow/node_events/node.py @@ -47,19 +47,3 @@ class StreamCompletedEvent(NodeEventBase): class PauseRequestedEvent(NodeEventBase): reason: PauseReason = Field(..., description="pause reason") - - -class HumanInputFormFilledEvent(NodeEventBase): - """Event emitted when a human input form is submitted.""" - - node_title: str - rendered_content: str - action_id: str - action_text: str - - -class HumanInputFormTimeoutEvent(NodeEventBase): - """Event emitted when a human input form times out.""" - - node_title: str - expiration_time: datetime diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 5a365f769d..e195aebe6d 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -192,32 +192,33 @@ class AgentNode(Node[AgentNodeData]): result[parameter_name] = None continue agent_input = node_data.agent_parameters[parameter_name] - if agent_input.type == "variable": - variable = variable_pool.get(agent_input.value) # type: ignore - if variable is None: - raise AgentVariableNotFoundError(str(agent_input.value)) - parameter_value = variable.value - elif agent_input.type in {"mixed", "constant"}: - # variable_pool.convert_template expects a string template, - # but if passing a dict, convert to JSON string first before rendering - try: - if not isinstance(agent_input.value, str): - parameter_value = json.dumps(agent_input.value, ensure_ascii=False) - else: + match agent_input.type: + case "variable": + variable = variable_pool.get(agent_input.value) # type: ignore + if variable is None: + raise AgentVariableNotFoundError(str(agent_input.value)) + parameter_value = variable.value + case "mixed" | "constant": + # variable_pool.convert_template expects a string template, + # but if passing a dict, convert to JSON string first before rendering + try: + if not isinstance(agent_input.value, str): + parameter_value = json.dumps(agent_input.value, ensure_ascii=False) + else: + parameter_value = str(agent_input.value) + except TypeError: parameter_value = str(agent_input.value) - except TypeError: - parameter_value = str(agent_input.value) - segment_group = variable_pool.convert_template(parameter_value) - parameter_value = segment_group.log if for_log else segment_group.text - # variable_pool.convert_template returns a string, - # so we need to convert it back to a dictionary - try: - if not isinstance(agent_input.value, str): - parameter_value = json.loads(parameter_value) - except json.JSONDecodeError: - parameter_value = parameter_value - else: - raise AgentInputTypeError(agent_input.type) + segment_group = variable_pool.convert_template(parameter_value) + parameter_value = segment_group.log if for_log else segment_group.text + # variable_pool.convert_template returns a string, + # so we need to convert it back to a dictionary + try: + if not isinstance(agent_input.value, str): + parameter_value = json.loads(parameter_value) + except json.JSONDecodeError: + parameter_value = parameter_value + case _: + raise AgentInputTypeError(agent_input.type) value = parameter_value if parameter.type == "array[tools]": value = cast(list[dict[str, Any]], value) @@ -374,12 +375,13 @@ class AgentNode(Node[AgentNodeData]): result: dict[str, Any] = {} for parameter_name in typed_node_data.agent_parameters: input = typed_node_data.agent_parameters[parameter_name] - if input.type in ["mixed", "constant"]: - selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors() - for selector in selectors: - result[selector.variable] = selector.value_selector - elif input.type == "variable": - result[parameter_name] = input.value + match input.type: + case "mixed" | "constant": + selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors() + for selector in selectors: + result[selector.variable] = selector.value_selector + case "variable": + result[parameter_name] = input.value result = {node_id + "." + key: value for key, value in result.items()} diff --git a/api/core/workflow/nodes/base/entities.py b/api/core/workflow/nodes/base/entities.py index e5a20c8e91..c5426e3fb7 100644 --- a/api/core/workflow/nodes/base/entities.py +++ b/api/core/workflow/nodes/base/entities.py @@ -115,7 +115,7 @@ class DefaultValue(BaseModel): @model_validator(mode="after") def validate_value_type(self) -> DefaultValue: # Type validation configuration - type_validators = { + type_validators: dict[DefaultValueType, dict[str, Any]] = { DefaultValueType.STRING: { "type": str, "converter": lambda x: x, diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index 2b773b537c..63e0260341 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -18,8 +18,6 @@ from core.workflow.graph_events import ( GraphNodeEventBase, NodeRunAgentLogEvent, NodeRunFailedEvent, - NodeRunHumanInputFormFilledEvent, - NodeRunHumanInputFormTimeoutEvent, NodeRunIterationFailedEvent, NodeRunIterationNextEvent, NodeRunIterationStartedEvent, @@ -36,8 +34,6 @@ from core.workflow.graph_events import ( ) from core.workflow.node_events import ( AgentLogEvent, - HumanInputFormFilledEvent, - HumanInputFormTimeoutEvent, IterationFailedEvent, IterationNextEvent, IterationStartedEvent, @@ -65,15 +61,6 @@ logger = logging.getLogger(__name__) class Node(Generic[NodeDataT]): - """BaseNode serves as the foundational class for all node implementations. - - Nodes are allowed to maintain transient states (e.g., `LLMNode` uses the `_file_output` - attribute to track files generated by the LLM). However, these states are not persisted - when the workflow is suspended or resumed. If a node needs its state to be preserved - across workflow suspension and resumption, it should include the relevant state data - in its output. - """ - node_type: ClassVar[NodeType] execution_type: NodeExecutionType = NodeExecutionType.EXECUTABLE _node_data_type: ClassVar[type[BaseNodeData]] = BaseNodeData @@ -264,33 +251,10 @@ class Node(Generic[NodeDataT]): return self._node_execution_id def ensure_execution_id(self) -> str: - if self._node_execution_id: - return self._node_execution_id - - resumed_execution_id = self._restore_execution_id_from_runtime_state() - if resumed_execution_id: - self._node_execution_id = resumed_execution_id - return self._node_execution_id - - self._node_execution_id = str(uuid4()) + if not self._node_execution_id: + self._node_execution_id = str(uuid4()) return self._node_execution_id - def _restore_execution_id_from_runtime_state(self) -> str | None: - graph_execution = self.graph_runtime_state.graph_execution - try: - node_executions = graph_execution.node_executions - except AttributeError: - return None - if not isinstance(node_executions, dict): - return None - node_execution = node_executions.get(self._node_id) - if node_execution is None: - return None - execution_id = node_execution.execution_id - if not execution_id: - return None - return str(execution_id) - def _hydrate_node_data(self, data: Mapping[str, Any]) -> NodeDataT: return cast(NodeDataT, self._node_data_type.model_validate(data)) @@ -656,28 +620,6 @@ class Node(Generic[NodeDataT]): metadata=event.metadata, ) - @_dispatch.register - def _(self, event: HumanInputFormFilledEvent): - return NodeRunHumanInputFormFilledEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=event.node_title, - rendered_content=event.rendered_content, - action_id=event.action_id, - action_text=event.action_text, - ) - - @_dispatch.register - def _(self, event: HumanInputFormTimeoutEvent): - return NodeRunHumanInputFormTimeoutEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=event.node_title, - expiration_time=event.expiration_time, - ) - @_dispatch.register def _(self, event: LoopStartedEvent) -> NodeRunLoopStartedEvent: return NodeRunLoopStartedEvent( diff --git a/api/core/workflow/nodes/code/entities.py b/api/core/workflow/nodes/code/entities.py index 10a1c897e9..8026011196 100644 --- a/api/core/workflow/nodes/code/entities.py +++ b/api/core/workflow/nodes/code/entities.py @@ -1,4 +1,4 @@ -from typing import Annotated, Literal, Self +from typing import Annotated, Literal from pydantic import AfterValidator, BaseModel @@ -34,7 +34,7 @@ class CodeNodeData(BaseNodeData): class Output(BaseModel): type: Annotated[SegmentType, AfterValidator(_validate_type)] - children: dict[str, Self] | None = None + children: dict[str, "CodeNodeData.Output"] | None = None class Dependency(BaseModel): name: str diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index 925561cf7c..a732a70417 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -69,11 +69,13 @@ class DatasourceNode(Node[DatasourceNodeData]): if datasource_type is None: raise DatasourceNodeError("Datasource type is not set") + datasource_type = DatasourceProviderType.value_of(datasource_type) + datasource_runtime = DatasourceManager.get_datasource_runtime( provider_id=f"{node_data.plugin_id}/{node_data.provider_name}", datasource_name=node_data.datasource_name or "", tenant_id=self.tenant_id, - datasource_type=DatasourceProviderType.value_of(datasource_type), + datasource_type=datasource_type, ) datasource_info["icon"] = datasource_runtime.get_icon_url(self.tenant_id) @@ -268,15 +270,18 @@ class DatasourceNode(Node[DatasourceNodeData]): if typed_node_data.datasource_parameters: for parameter_name in typed_node_data.datasource_parameters: input = typed_node_data.datasource_parameters[parameter_name] - if input.type == "mixed": - assert isinstance(input.value, str) - selectors = VariableTemplateParser(input.value).extract_variable_selectors() - for selector in selectors: - result[selector.variable] = selector.value_selector - elif input.type == "variable": - result[parameter_name] = input.value - elif input.type == "constant": - pass + match input.type: + case "mixed": + assert isinstance(input.value, str) + selectors = VariableTemplateParser(input.value).extract_variable_selectors() + for selector in selectors: + result[selector.variable] = selector.value_selector + case "variable": + result[parameter_name] = input.value + case "constant": + pass + case None: + pass result = {node_id + "." + key: value for key, value in result.items()} @@ -306,99 +311,107 @@ class DatasourceNode(Node[DatasourceNodeData]): variables: dict[str, Any] = {} for message in message_stream: - if message.type in { - DatasourceMessage.MessageType.IMAGE_LINK, - DatasourceMessage.MessageType.BINARY_LINK, - DatasourceMessage.MessageType.IMAGE, - }: - assert isinstance(message.message, DatasourceMessage.TextMessage) + match message.type: + case ( + DatasourceMessage.MessageType.IMAGE_LINK + | DatasourceMessage.MessageType.BINARY_LINK + | DatasourceMessage.MessageType.IMAGE + ): + assert isinstance(message.message, DatasourceMessage.TextMessage) - url = message.message.text - transfer_method = FileTransferMethod.TOOL_FILE + url = message.message.text + transfer_method = FileTransferMethod.TOOL_FILE - datasource_file_id = str(url).split("/")[-1].split(".")[0] + datasource_file_id = str(url).split("/")[-1].split(".")[0] - with Session(db.engine) as session: - stmt = select(ToolFile).where(ToolFile.id == datasource_file_id) - datasource_file = session.scalar(stmt) - if datasource_file is None: - raise ToolFileError(f"Tool file {datasource_file_id} does not exist") + with Session(db.engine) as session: + stmt = select(ToolFile).where(ToolFile.id == datasource_file_id) + datasource_file = session.scalar(stmt) + if datasource_file is None: + raise ToolFileError(f"Tool file {datasource_file_id} does not exist") - mapping = { - "tool_file_id": datasource_file_id, - "type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype), - "transfer_method": transfer_method, - "url": url, - } - file = file_factory.build_from_mapping( - mapping=mapping, - tenant_id=self.tenant_id, - ) - files.append(file) - elif message.type == DatasourceMessage.MessageType.BLOB: - # get tool file id - assert isinstance(message.message, DatasourceMessage.TextMessage) - assert message.meta - - datasource_file_id = message.message.text.split("/")[-1].split(".")[0] - with Session(db.engine) as session: - stmt = select(ToolFile).where(ToolFile.id == datasource_file_id) - datasource_file = session.scalar(stmt) - if datasource_file is None: - raise ToolFileError(f"datasource file {datasource_file_id} not exists") - - mapping = { - "tool_file_id": datasource_file_id, - "transfer_method": FileTransferMethod.TOOL_FILE, - } - - files.append( - file_factory.build_from_mapping( + mapping = { + "tool_file_id": datasource_file_id, + "type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype), + "transfer_method": transfer_method, + "url": url, + } + file = file_factory.build_from_mapping( mapping=mapping, tenant_id=self.tenant_id, ) - ) - elif message.type == DatasourceMessage.MessageType.TEXT: - assert isinstance(message.message, DatasourceMessage.TextMessage) - text += message.message.text - yield StreamChunkEvent( - selector=[self._node_id, "text"], - chunk=message.message.text, - is_final=False, - ) - elif message.type == DatasourceMessage.MessageType.JSON: - assert isinstance(message.message, DatasourceMessage.JsonMessage) - json.append(message.message.json_object) - elif message.type == DatasourceMessage.MessageType.LINK: - assert isinstance(message.message, DatasourceMessage.TextMessage) - stream_text = f"Link: {message.message.text}\n" - text += stream_text - yield StreamChunkEvent( - selector=[self._node_id, "text"], - chunk=stream_text, - is_final=False, - ) - elif message.type == DatasourceMessage.MessageType.VARIABLE: - assert isinstance(message.message, DatasourceMessage.VariableMessage) - variable_name = message.message.variable_name - variable_value = message.message.variable_value - if message.message.stream: - if not isinstance(variable_value, str): - raise ValueError("When 'stream' is True, 'variable_value' must be a string.") - if variable_name not in variables: - variables[variable_name] = "" - variables[variable_name] += variable_value + files.append(file) + case DatasourceMessage.MessageType.BLOB: + # get tool file id + assert isinstance(message.message, DatasourceMessage.TextMessage) + assert message.meta + datasource_file_id = message.message.text.split("/")[-1].split(".")[0] + with Session(db.engine) as session: + stmt = select(ToolFile).where(ToolFile.id == datasource_file_id) + datasource_file = session.scalar(stmt) + if datasource_file is None: + raise ToolFileError(f"datasource file {datasource_file_id} not exists") + + mapping = { + "tool_file_id": datasource_file_id, + "transfer_method": FileTransferMethod.TOOL_FILE, + } + + files.append( + file_factory.build_from_mapping( + mapping=mapping, + tenant_id=self.tenant_id, + ) + ) + case DatasourceMessage.MessageType.TEXT: + assert isinstance(message.message, DatasourceMessage.TextMessage) + text += message.message.text yield StreamChunkEvent( - selector=[self._node_id, variable_name], - chunk=variable_value, + selector=[self._node_id, "text"], + chunk=message.message.text, is_final=False, ) - else: - variables[variable_name] = variable_value - elif message.type == DatasourceMessage.MessageType.FILE: - assert message.meta is not None - files.append(message.meta["file"]) + case DatasourceMessage.MessageType.JSON: + assert isinstance(message.message, DatasourceMessage.JsonMessage) + json.append(message.message.json_object) + case DatasourceMessage.MessageType.LINK: + assert isinstance(message.message, DatasourceMessage.TextMessage) + stream_text = f"Link: {message.message.text}\n" + text += stream_text + yield StreamChunkEvent( + selector=[self._node_id, "text"], + chunk=stream_text, + is_final=False, + ) + case DatasourceMessage.MessageType.VARIABLE: + assert isinstance(message.message, DatasourceMessage.VariableMessage) + variable_name = message.message.variable_name + variable_value = message.message.variable_value + if message.message.stream: + if not isinstance(variable_value, str): + raise ValueError("When 'stream' is True, 'variable_value' must be a string.") + if variable_name not in variables: + variables[variable_name] = "" + variables[variable_name] += variable_value + + yield StreamChunkEvent( + selector=[self._node_id, variable_name], + chunk=variable_value, + is_final=False, + ) + else: + variables[variable_name] = variable_value + case DatasourceMessage.MessageType.FILE: + assert message.meta is not None + files.append(message.meta["file"]) + case ( + DatasourceMessage.MessageType.BLOB_CHUNK + | DatasourceMessage.MessageType.LOG + | DatasourceMessage.MessageType.RETRIEVER_RESOURCES + ): + pass + # mark the end of the stream yield StreamChunkEvent( selector=[self._node_id, "text"], diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/core/workflow/nodes/http_request/executor.py index 429f8411a6..7de8216562 100644 --- a/api/core/workflow/nodes/http_request/executor.py +++ b/api/core/workflow/nodes/http_request/executor.py @@ -2,7 +2,7 @@ import base64 import json import secrets import string -from collections.abc import Mapping +from collections.abc import Callable, Mapping from copy import deepcopy from typing import Any, Literal from urllib.parse import urlencode, urlparse @@ -11,9 +11,9 @@ import httpx from json_repair import repair_json from configs import dify_config -from core.file import file_manager from core.file.enums import FileTransferMethod -from core.helper import ssrf_proxy +from core.file.file_manager import file_manager as default_file_manager +from core.helper.ssrf_proxy import ssrf_proxy from core.variables.segments import ArrayFileSegment, FileSegment from core.workflow.runtime import VariablePool @@ -79,8 +79,8 @@ class Executor: timeout: HttpRequestNodeTimeout, variable_pool: VariablePool, max_retries: int = dify_config.SSRF_DEFAULT_MAX_RETRIES, - http_client: HttpClientProtocol = ssrf_proxy, - file_manager: FileManagerProtocol = file_manager, + http_client: HttpClientProtocol | None = None, + file_manager: FileManagerProtocol | None = None, ): # If authorization API key is present, convert the API key using the variable pool if node_data.authorization.type == "api-key": @@ -107,8 +107,8 @@ class Executor: self.data = None self.json = None self.max_retries = max_retries - self._http_client = http_client - self._file_manager = file_manager + self._http_client = http_client or ssrf_proxy + self._file_manager = file_manager or default_file_manager # init template self.variable_pool = variable_pool @@ -336,7 +336,7 @@ class Executor: """ do http request depending on api bundle """ - _METHOD_MAP = { + _METHOD_MAP: dict[str, Callable[..., httpx.Response]] = { "get": self._http_client.get, "head": self._http_client.head, "post": self._http_client.post, @@ -348,7 +348,7 @@ class Executor: if method_lc not in _METHOD_MAP: raise InvalidHttpMethodError(f"Invalid http method {self.method}") - request_args = { + request_args: dict[str, Any] = { "data": self.data, "files": self.files, "json": self.json, @@ -361,14 +361,13 @@ class Executor: } # request_args = {k: v for k, v in request_args.items() if v is not None} try: - response: httpx.Response = _METHOD_MAP[method_lc]( + response = _METHOD_MAP[method_lc]( url=self.url, **request_args, max_retries=self.max_retries, ) except (self._http_client.max_retries_exceeded_error, self._http_client.request_error) as e: raise HttpRequestNodeError(str(e)) from e - # FIXME: fix type ignore, this maybe httpx type issue return response def invoke(self) -> Response: diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py index 964e53e03c..480482375f 100644 --- a/api/core/workflow/nodes/http_request/node.py +++ b/api/core/workflow/nodes/http_request/node.py @@ -4,8 +4,9 @@ from collections.abc import Callable, Mapping, Sequence from typing import TYPE_CHECKING, Any from configs import dify_config -from core.file import File, FileTransferMethod, file_manager -from core.helper import ssrf_proxy +from core.file import File, FileTransferMethod +from core.file.file_manager import file_manager as default_file_manager +from core.helper.ssrf_proxy import ssrf_proxy from core.tools.tool_file_manager import ToolFileManager from core.variables.segments import ArrayFileSegment from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus @@ -47,9 +48,9 @@ class HttpRequestNode(Node[HttpRequestNodeData]): graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", *, - http_client: HttpClientProtocol = ssrf_proxy, + http_client: HttpClientProtocol | None = None, tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager, - file_manager: FileManagerProtocol = file_manager, + file_manager: FileManagerProtocol | None = None, ) -> None: super().__init__( id=id, @@ -57,9 +58,9 @@ class HttpRequestNode(Node[HttpRequestNodeData]): graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) - self._http_client = http_client + self._http_client = http_client or ssrf_proxy self._tool_file_manager_factory = tool_file_manager_factory - self._file_manager = file_manager + self._file_manager = file_manager or default_file_manager @classmethod def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: diff --git a/api/core/workflow/nodes/human_input/__init__.py b/api/core/workflow/nodes/human_input/__init__.py index 1789604577..379440557c 100644 --- a/api/core/workflow/nodes/human_input/__init__.py +++ b/api/core/workflow/nodes/human_input/__init__.py @@ -1,3 +1,3 @@ -""" -Human Input node implementation. -""" +from .human_input_node import HumanInputNode + +__all__ = ["HumanInputNode"] diff --git a/api/core/workflow/nodes/human_input/entities.py b/api/core/workflow/nodes/human_input/entities.py index 72d4fc675b..02913d93c3 100644 --- a/api/core/workflow/nodes/human_input/entities.py +++ b/api/core/workflow/nodes/human_input/entities.py @@ -1,350 +1,10 @@ -""" -Human Input node entities. -""" +from pydantic import Field -import re -import uuid -from collections.abc import Mapping, Sequence -from datetime import datetime, timedelta -from typing import Annotated, Any, ClassVar, Literal, Self - -from pydantic import BaseModel, Field, field_validator, model_validator - -from core.variables.consts import SELECTORS_LENGTH from core.workflow.nodes.base import BaseNodeData -from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser -from core.workflow.runtime import VariablePool - -from .enums import ButtonStyle, DeliveryMethodType, EmailRecipientType, FormInputType, PlaceholderType, TimeoutUnit - -_OUTPUT_VARIABLE_PATTERN = re.compile(r"\{\{#\$output\.(?P[a-zA-Z_][a-zA-Z0-9_]{0,29})#\}\}") - - -class _WebAppDeliveryConfig(BaseModel): - """Configuration for webapp delivery method.""" - - pass # Empty for webapp delivery - - -class MemberRecipient(BaseModel): - """Member recipient for email delivery.""" - - type: Literal[EmailRecipientType.MEMBER] = EmailRecipientType.MEMBER - user_id: str - - -class ExternalRecipient(BaseModel): - """External recipient for email delivery.""" - - type: Literal[EmailRecipientType.EXTERNAL] = EmailRecipientType.EXTERNAL - email: str - - -EmailRecipient = Annotated[MemberRecipient | ExternalRecipient, Field(discriminator="type")] - - -class EmailRecipients(BaseModel): - """Email recipients configuration.""" - - # When true, recipients are the union of all workspace members and external items. - # Member items are ignored because they are already covered by the workspace scope. - # De-duplication is applied by email, with member recipients taking precedence. - whole_workspace: bool = False - items: list[EmailRecipient] = Field(default_factory=list) - - -class EmailDeliveryConfig(BaseModel): - """Configuration for email delivery method.""" - - URL_PLACEHOLDER: ClassVar[str] = "{{#url#}}" - - recipients: EmailRecipients - - # the subject of email - subject: str - - # Body is the content of email.It may contain the speical placeholder `{{#url#}}`, which - # represent the url to submit the form. - # - # It may also reference the output variable of the previous node with the syntax - # `{{#.#}}`. - body: str - debug_mode: bool = False - - def with_debug_recipient(self, user_id: str) -> "EmailDeliveryConfig": - if not user_id: - debug_recipients = EmailRecipients(whole_workspace=False, items=[]) - return self.model_copy(update={"recipients": debug_recipients}) - debug_recipients = EmailRecipients(whole_workspace=False, items=[MemberRecipient(user_id=user_id)]) - return self.model_copy(update={"recipients": debug_recipients}) - - @classmethod - def replace_url_placeholder(cls, body: str, url: str | None) -> str: - """Replace the url placeholder with provided value.""" - return body.replace(cls.URL_PLACEHOLDER, url or "") - - @classmethod - def render_body_template( - cls, - *, - body: str, - url: str | None, - variable_pool: VariablePool | None = None, - ) -> str: - """Render email body by replacing placeholders with runtime values.""" - templated_body = cls.replace_url_placeholder(body, url) - if variable_pool is None: - return templated_body - return variable_pool.convert_template(templated_body).text - - -class _DeliveryMethodBase(BaseModel): - """Base delivery method configuration.""" - - enabled: bool = True - id: uuid.UUID = Field(default_factory=uuid.uuid4) - - def extract_variable_selectors(self) -> Sequence[Sequence[str]]: - return () - - -class WebAppDeliveryMethod(_DeliveryMethodBase): - """Webapp delivery method configuration.""" - - type: Literal[DeliveryMethodType.WEBAPP] = DeliveryMethodType.WEBAPP - # The config field is not used currently. - config: _WebAppDeliveryConfig = Field(default_factory=_WebAppDeliveryConfig) - - -class EmailDeliveryMethod(_DeliveryMethodBase): - """Email delivery method configuration.""" - - type: Literal[DeliveryMethodType.EMAIL] = DeliveryMethodType.EMAIL - config: EmailDeliveryConfig - - def extract_variable_selectors(self) -> Sequence[Sequence[str]]: - variable_template_parser = VariableTemplateParser(template=self.config.body) - selectors: list[Sequence[str]] = [] - for variable_selector in variable_template_parser.extract_variable_selectors(): - value_selector = list(variable_selector.value_selector) - if len(value_selector) < SELECTORS_LENGTH: - continue - selectors.append(value_selector[:SELECTORS_LENGTH]) - return selectors - - -DeliveryChannelConfig = Annotated[WebAppDeliveryMethod | EmailDeliveryMethod, Field(discriminator="type")] - - -def apply_debug_email_recipient( - method: DeliveryChannelConfig, - *, - enabled: bool, - user_id: str, -) -> DeliveryChannelConfig: - if not enabled: - return method - if not isinstance(method, EmailDeliveryMethod): - return method - if not method.config.debug_mode: - return method - debug_config = method.config.with_debug_recipient(user_id or "") - return method.model_copy(update={"config": debug_config}) - - -class FormInputDefault(BaseModel): - """Default configuration for form inputs.""" - - # NOTE: Ideally, a discriminated union would be used to model - # FormInputDefault. However, the UI requires preserving the previous - # value when switching between `VARIABLE` and `CONSTANT` types. This - # necessitates retaining all fields, making a discriminated union unsuitable. - - type: PlaceholderType - - # The selector of default variable, used when `type` is `VARIABLE`. - selector: Sequence[str] = Field(default_factory=tuple) # - - # The value of the default, used when `type` is `CONSTANT`. - # TODO: How should we express JSON values? - value: str = "" - - @model_validator(mode="after") - def _validate_selector(self) -> Self: - if self.type == PlaceholderType.CONSTANT: - return self - if len(self.selector) < SELECTORS_LENGTH: - raise ValueError(f"the length of selector should be at least {SELECTORS_LENGTH}, selector={self.selector}") - return self - - -class FormInput(BaseModel): - """Form input definition.""" - - type: FormInputType - output_variable_name: str - default: FormInputDefault | None = None - - -_IDENTIFIER_PATTERN = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") - - -class UserAction(BaseModel): - """User action configuration.""" - - # id is the identifier for this action. - # It also serves as the identifiers of output handle. - # - # The id must be a valid identifier (satisfy the _IDENTIFIER_PATTERN above.) - id: str = Field(max_length=20) - title: str = Field(max_length=20) - button_style: ButtonStyle = ButtonStyle.DEFAULT - - @field_validator("id") - @classmethod - def _validate_id(cls, value: str) -> str: - if not _IDENTIFIER_PATTERN.match(value): - raise ValueError( - f"'{value}' is not a valid identifier. It must start with a letter or underscore, " - f"and contain only letters, numbers, or underscores." - ) - return value class HumanInputNodeData(BaseNodeData): - """Human Input node data.""" + """Configuration schema for the HumanInput node.""" - delivery_methods: list[DeliveryChannelConfig] = Field(default_factory=list) - form_content: str = "" - inputs: list[FormInput] = Field(default_factory=list) - user_actions: list[UserAction] = Field(default_factory=list) - timeout: int = 36 - timeout_unit: TimeoutUnit = TimeoutUnit.HOUR - - @field_validator("inputs") - @classmethod - def _validate_inputs(cls, inputs: list[FormInput]) -> list[FormInput]: - seen_names: set[str] = set() - for form_input in inputs: - name = form_input.output_variable_name - if name in seen_names: - raise ValueError(f"duplicated output_variable_name '{name}' in inputs") - seen_names.add(name) - return inputs - - @field_validator("user_actions") - @classmethod - def _validate_user_actions(cls, user_actions: list[UserAction]) -> list[UserAction]: - seen_ids: set[str] = set() - for action in user_actions: - action_id = action.id - if action_id in seen_ids: - raise ValueError(f"duplicated user action id '{action_id}'") - seen_ids.add(action_id) - return user_actions - - def is_webapp_enabled(self) -> bool: - for dm in self.delivery_methods: - if not dm.enabled: - continue - if dm.type == DeliveryMethodType.WEBAPP: - return True - return False - - def expiration_time(self, start_time: datetime) -> datetime: - if self.timeout_unit == TimeoutUnit.HOUR: - return start_time + timedelta(hours=self.timeout) - elif self.timeout_unit == TimeoutUnit.DAY: - return start_time + timedelta(days=self.timeout) - else: - raise AssertionError("unknown timeout unit.") - - def outputs_field_names(self) -> Sequence[str]: - field_names = [] - for match in _OUTPUT_VARIABLE_PATTERN.finditer(self.form_content): - field_names.append(match.group("field_name")) - return field_names - - def extract_variable_selector_to_variable_mapping(self, node_id: str) -> Mapping[str, Sequence[str]]: - variable_mappings: dict[str, Sequence[str]] = {} - - def _add_variable_selectors(selectors: Sequence[Sequence[str]]) -> None: - for selector in selectors: - if len(selector) < SELECTORS_LENGTH: - continue - qualified_variable_mapping_key = f"{node_id}.#{'.'.join(selector[:SELECTORS_LENGTH])}#" - variable_mappings[qualified_variable_mapping_key] = list(selector[:SELECTORS_LENGTH]) - - form_template_parser = VariableTemplateParser(template=self.form_content) - _add_variable_selectors( - [selector.value_selector for selector in form_template_parser.extract_variable_selectors()] - ) - for delivery_method in self.delivery_methods: - if not delivery_method.enabled: - continue - _add_variable_selectors(delivery_method.extract_variable_selectors()) - - for input in self.inputs: - default_value = input.default - if default_value is None: - continue - if default_value.type == PlaceholderType.CONSTANT: - continue - default_value_key = ".".join(default_value.selector) - qualified_variable_mapping_key = f"{node_id}.#{default_value_key}#" - variable_mappings[qualified_variable_mapping_key] = default_value.selector - - return variable_mappings - - def find_action_text(self, action_id: str) -> str: - """ - Resolve action display text by id. - """ - for action in self.user_actions: - if action.id == action_id: - return action.title - return action_id - - -class FormDefinition(BaseModel): - form_content: str - inputs: list[FormInput] = Field(default_factory=list) - user_actions: list[UserAction] = Field(default_factory=list) - rendered_content: str - expiration_time: datetime - - # this is used to store the resolved default values - default_values: dict[str, Any] = Field(default_factory=dict) - - # node_title records the title of the HumanInput node. - node_title: str | None = None - - # display_in_ui controls whether the form should be displayed in UI surfaces. - display_in_ui: bool | None = None - - -class HumanInputSubmissionValidationError(ValueError): - pass - - -def validate_human_input_submission( - *, - inputs: Sequence[FormInput], - user_actions: Sequence[UserAction], - selected_action_id: str, - form_data: Mapping[str, Any], -) -> None: - available_actions = {action.id for action in user_actions} - if selected_action_id not in available_actions: - raise HumanInputSubmissionValidationError(f"Invalid action: {selected_action_id}") - - provided_inputs = set(form_data.keys()) - missing_inputs = [ - form_input.output_variable_name - for form_input in inputs - if form_input.output_variable_name not in provided_inputs - ] - - if missing_inputs: - missing_list = ", ".join(missing_inputs) - raise HumanInputSubmissionValidationError(f"Missing required inputs: {missing_list}") + required_variables: list[str] = Field(default_factory=list) + pause_reason: str | None = Field(default=None) diff --git a/api/core/workflow/nodes/human_input/enums.py b/api/core/workflow/nodes/human_input/enums.py deleted file mode 100644 index da85728828..0000000000 --- a/api/core/workflow/nodes/human_input/enums.py +++ /dev/null @@ -1,72 +0,0 @@ -import enum - - -class HumanInputFormStatus(enum.StrEnum): - """Status of a human input form.""" - - # Awaiting submission from any recipient. Forms stay in this state until - # submitted or a timeout rule applies. - WAITING = enum.auto() - # Global timeout reached. The workflow run is stopped and will not resume. - # This is distinct from node-level timeout. - EXPIRED = enum.auto() - # Submitted by a recipient; form data is available and execution resumes - # along the selected action edge. - SUBMITTED = enum.auto() - # Node-level timeout reached. The human input node should emit a timeout - # event and the workflow should resume along the timeout edge. - TIMEOUT = enum.auto() - - -class HumanInputFormKind(enum.StrEnum): - """Kind of a human input form.""" - - RUNTIME = enum.auto() # Form created during workflow execution. - DELIVERY_TEST = enum.auto() # Form created for delivery tests. - - -class DeliveryMethodType(enum.StrEnum): - """Delivery method types for human input forms.""" - - # WEBAPP controls whether the form is delivered to the web app. It not only controls - # the standalone web app, but also controls the installed apps in the console. - WEBAPP = enum.auto() - - EMAIL = enum.auto() - - -class ButtonStyle(enum.StrEnum): - """Button styles for user actions.""" - - PRIMARY = enum.auto() - DEFAULT = enum.auto() - ACCENT = enum.auto() - GHOST = enum.auto() - - -class TimeoutUnit(enum.StrEnum): - """Timeout unit for form expiration.""" - - HOUR = enum.auto() - DAY = enum.auto() - - -class FormInputType(enum.StrEnum): - """Form input types.""" - - TEXT_INPUT = enum.auto() - PARAGRAPH = enum.auto() - - -class PlaceholderType(enum.StrEnum): - """Default value types for form inputs.""" - - VARIABLE = enum.auto() - CONSTANT = enum.auto() - - -class EmailRecipientType(enum.StrEnum): - """Email recipient types.""" - - MEMBER = enum.auto() - EXTERNAL = enum.auto() diff --git a/api/core/workflow/nodes/human_input/human_input_node.py b/api/core/workflow/nodes/human_input/human_input_node.py index 1d7522ea25..6c8bf36fab 100644 --- a/api/core/workflow/nodes/human_input/human_input_node.py +++ b/api/core/workflow/nodes/human_input/human_input_node.py @@ -1,42 +1,12 @@ -import json -import logging -from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any +from collections.abc import Mapping +from typing import Any -from core.app.entities.app_invoke_entities import InvokeFrom -from core.repositories.human_input_repository import HumanInputFormRepositoryImpl from core.workflow.entities.pause_reason import HumanInputRequired from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus -from core.workflow.node_events import ( - HumanInputFormFilledEvent, - HumanInputFormTimeoutEvent, - NodeRunResult, - PauseRequestedEvent, -) -from core.workflow.node_events.base import NodeEventBase -from core.workflow.node_events.node import StreamCompletedEvent +from core.workflow.node_events import NodeRunResult, PauseRequestedEvent from core.workflow.nodes.base.node import Node -from core.workflow.repositories.human_input_form_repository import ( - FormCreateParams, - HumanInputFormEntity, - HumanInputFormRepository, -) -from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter -from extensions.ext_database import db -from libs.datetime_utils import naive_utc_now -from .entities import DeliveryChannelConfig, HumanInputNodeData, apply_debug_email_recipient -from .enums import DeliveryMethodType, HumanInputFormStatus, PlaceholderType - -if TYPE_CHECKING: - from core.workflow.entities.graph_init_params import GraphInitParams - from core.workflow.runtime.graph_runtime_state import GraphRuntimeState - - -_SELECTED_BRANCH_KEY = "selected_branch" - - -logger = logging.getLogger(__name__) +from .entities import HumanInputNodeData class HumanInputNode(Node[HumanInputNodeData]): @@ -47,7 +17,7 @@ class HumanInputNode(Node[HumanInputNodeData]): "edge_source_handle", "edgeSourceHandle", "source_handle", - _SELECTED_BRANCH_KEY, + "selected_branch", "selectedBranch", "branch", "branch_id", @@ -55,37 +25,43 @@ class HumanInputNode(Node[HumanInputNodeData]): "handle", ) - _node_data: HumanInputNodeData - _form_repository: HumanInputFormRepository - _OUTPUT_FIELD_ACTION_ID = "__action_id" - _OUTPUT_FIELD_RENDERED_CONTENT = "__rendered_content" - _TIMEOUT_HANDLE = _TIMEOUT_ACTION_ID = "__timeout" - - def __init__( - self, - id: str, - config: Mapping[str, Any], - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", - form_repository: HumanInputFormRepository | None = None, - ) -> None: - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - if form_repository is None: - form_repository = HumanInputFormRepositoryImpl( - session_factory=db.engine, - tenant_id=self.tenant_id, - ) - self._form_repository = form_repository - @classmethod def version(cls) -> str: return "1" + def _run(self): # type: ignore[override] + if self._is_completion_ready(): + branch_handle = self._resolve_branch_selection() + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + outputs={}, + edge_source_handle=branch_handle or "source", + ) + + return self._pause_generator() + + def _pause_generator(self): + # TODO(QuantumGhost): yield a real form id. + yield PauseRequestedEvent(reason=HumanInputRequired(form_id="test_form_id", node_id=self.id)) + + def _is_completion_ready(self) -> bool: + """Determine whether all required inputs are satisfied.""" + + if not self.node_data.required_variables: + return False + + variable_pool = self.graph_runtime_state.variable_pool + + for selector_str in self.node_data.required_variables: + parts = selector_str.split(".") + if len(parts) != 2: + return False + segment = variable_pool.get(parts) + if segment is None: + return False + + return True + def _resolve_branch_selection(self) -> str | None: """Determine the branch handle selected by human input if available.""" @@ -132,224 +108,3 @@ class HumanInputNode(Node[HumanInputNodeData]): return candidate return None - - @property - def _workflow_execution_id(self) -> str: - workflow_exec_id = self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id - assert workflow_exec_id is not None - return workflow_exec_id - - def _form_to_pause_event(self, form_entity: HumanInputFormEntity): - required_event = self._human_input_required_event(form_entity) - pause_requested_event = PauseRequestedEvent(reason=required_event) - return pause_requested_event - - def resolve_default_values(self) -> Mapping[str, Any]: - variable_pool = self.graph_runtime_state.variable_pool - resolved_defaults: dict[str, Any] = {} - for input in self._node_data.inputs: - if (default_value := input.default) is None: - continue - if default_value.type == PlaceholderType.CONSTANT: - continue - resolved_value = variable_pool.get(default_value.selector) - if resolved_value is None: - # TODO: How should we handle this? - continue - resolved_defaults[input.output_variable_name] = ( - WorkflowRuntimeTypeConverter().value_to_json_encodable_recursive(resolved_value.value) - ) - - return resolved_defaults - - def _should_require_console_recipient(self) -> bool: - if self.invoke_from == InvokeFrom.DEBUGGER: - return True - if self.invoke_from == InvokeFrom.EXPLORE: - return self._node_data.is_webapp_enabled() - return False - - def _display_in_ui(self) -> bool: - if self.invoke_from == InvokeFrom.DEBUGGER: - return True - return self._node_data.is_webapp_enabled() - - def _effective_delivery_methods(self) -> Sequence[DeliveryChannelConfig]: - enabled_methods = [method for method in self._node_data.delivery_methods if method.enabled] - if self.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE}: - enabled_methods = [method for method in enabled_methods if method.type != DeliveryMethodType.WEBAPP] - return [ - apply_debug_email_recipient( - method, - enabled=self.invoke_from == InvokeFrom.DEBUGGER, - user_id=self.user_id or "", - ) - for method in enabled_methods - ] - - def _human_input_required_event(self, form_entity: HumanInputFormEntity) -> HumanInputRequired: - node_data = self._node_data - resolved_default_values = self.resolve_default_values() - display_in_ui = self._display_in_ui() - form_token = form_entity.web_app_token - if display_in_ui and form_token is None: - raise AssertionError("Form token should be available for UI execution.") - return HumanInputRequired( - form_id=form_entity.id, - form_content=form_entity.rendered_content, - inputs=node_data.inputs, - actions=node_data.user_actions, - display_in_ui=display_in_ui, - node_id=self.id, - node_title=node_data.title, - form_token=form_token, - resolved_default_values=resolved_default_values, - ) - - def _run(self) -> Generator[NodeEventBase, None, None]: - """ - Execute the human input node. - - This method will: - 1. Generate a unique form ID - 2. Create form content with variable substitution - 3. Create form in database - 4. Send form via configured delivery methods - 5. Suspend workflow execution - 6. Wait for form submission to resume - """ - repo = self._form_repository - form = repo.get_form(self._workflow_execution_id, self.id) - if form is None: - display_in_ui = self._display_in_ui() - params = FormCreateParams( - app_id=self.app_id, - workflow_execution_id=self._workflow_execution_id, - node_id=self.id, - form_config=self._node_data, - rendered_content=self.render_form_content_before_submission(), - delivery_methods=self._effective_delivery_methods(), - display_in_ui=display_in_ui, - resolved_default_values=self.resolve_default_values(), - console_recipient_required=self._should_require_console_recipient(), - console_creator_account_id=( - self.user_id if self.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE} else None - ), - backstage_recipient_required=True, - ) - form_entity = self._form_repository.create_form(params) - # Create human input required event - - logger.info( - "Human Input node suspended workflow for form. workflow_run_id=%s, node_id=%s, form_id=%s", - self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id, - self.id, - form_entity.id, - ) - yield self._form_to_pause_event(form_entity) - return - - if ( - form.status in {HumanInputFormStatus.TIMEOUT, HumanInputFormStatus.EXPIRED} - or form.expiration_time <= naive_utc_now() - ): - yield HumanInputFormTimeoutEvent( - node_title=self._node_data.title, - expiration_time=form.expiration_time, - ) - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={self._OUTPUT_FIELD_ACTION_ID: ""}, - edge_source_handle=self._TIMEOUT_HANDLE, - ) - ) - return - - if not form.submitted: - yield self._form_to_pause_event(form) - return - - selected_action_id = form.selected_action_id - if selected_action_id is None: - raise AssertionError(f"selected_action_id should not be None when form submitted, form_id={form.id}") - submitted_data = form.submitted_data or {} - outputs: dict[str, Any] = dict(submitted_data) - outputs[self._OUTPUT_FIELD_ACTION_ID] = selected_action_id - rendered_content = self.render_form_content_with_outputs( - form.rendered_content, - outputs, - self._node_data.outputs_field_names(), - ) - outputs[self._OUTPUT_FIELD_RENDERED_CONTENT] = rendered_content - - action_text = self._node_data.find_action_text(selected_action_id) - - yield HumanInputFormFilledEvent( - node_title=self._node_data.title, - rendered_content=rendered_content, - action_id=selected_action_id, - action_text=action_text, - ) - - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs=outputs, - edge_source_handle=selected_action_id, - ) - ) - - def render_form_content_before_submission(self) -> str: - """ - Process form content by substituting variables. - - This method should: - 1. Parse the form_content markdown - 2. Substitute {{#node_name.var_name#}} with actual values - 3. Keep {{#$output.field_name#}} placeholders for form inputs - """ - rendered_form_content = self.graph_runtime_state.variable_pool.convert_template( - self._node_data.form_content, - ) - return rendered_form_content.markdown - - @staticmethod - def render_form_content_with_outputs( - form_content: str, - outputs: Mapping[str, Any], - field_names: Sequence[str], - ) -> str: - """ - Replace {{#$output.xxx#}} placeholders with submitted values. - """ - rendered_content = form_content - for field_name in field_names: - placeholder = "{{#$output." + field_name + "#}}" - value = outputs.get(field_name) - if value is None: - replacement = "" - elif isinstance(value, (dict, list)): - replacement = json.dumps(value, ensure_ascii=False) - else: - replacement = str(value) - rendered_content = rendered_content.replace(placeholder, replacement) - return rendered_content - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: Mapping[str, Any], - ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selectors referenced in form content and input default values. - - This method should parse: - 1. Variables referenced in form_content ({{#node_name.var_name#}}) - 2. Variables referenced in input default values - """ - validated_node_data = HumanInputNodeData.model_validate(node_data) - return validated_node_data.extract_variable_selector_to_variable_mapping(node_id) diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index c19182549f..25a881ea7d 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -397,7 +397,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): return outputs # Check if all non-None outputs are lists - non_none_outputs = [output for output in outputs if output is not None] + non_none_outputs: list[object] = [output for output in outputs if output is not None] if not non_none_outputs: return outputs diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index b88c2d510f..2aff953bc6 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -78,12 +78,21 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): indexing_technique = node_data.indexing_technique or dataset.indexing_technique summary_index_setting = node_data.summary_index_setting or dataset.summary_index_setting + # Try to get document language if document_id is available + doc_language = None + document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID]) + if document_id: + document = db.session.query(Document).filter_by(id=document_id.value).first() + if document and document.doc_language: + doc_language = document.doc_language + outputs = self._get_preview_output_with_summaries( node_data.chunk_structure, chunks, dataset=dataset, indexing_technique=indexing_technique, summary_index_setting=summary_index_setting, + doc_language=doc_language, ) return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, @@ -315,6 +324,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): dataset: Dataset, indexing_technique: str | None = None, summary_index_setting: dict | None = None, + doc_language: str | None = None, ) -> Mapping[str, Any]: """ Generate preview output with summaries for chunks in preview mode. @@ -326,6 +336,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): dataset: Dataset object (for tenant_id) indexing_technique: Indexing technique from node config or dataset summary_index_setting: Summary index setting from node config or dataset + doc_language: Optional document language to ensure summary is generated in the correct language """ index_processor = IndexProcessorFactory(chunk_structure).init_index_processor() preview_output = index_processor.format_preview(chunks) @@ -365,6 +376,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): tenant_id=dataset.tenant_id, text=preview_item["content"], summary_index_setting=summary_index_setting, + document_language=doc_language, ) if summary: preview_item["summary"] = summary @@ -374,6 +386,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): tenant_id=dataset.tenant_id, text=preview_item["content"], summary_index_setting=summary_index_setting, + document_language=doc_language, ) if summary: preview_item["summary"] = summary diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 3c4850ebac..0827494a48 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -303,33 +303,34 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD elif str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE: if node_data.multiple_retrieval_config is None: raise ValueError("multiple_retrieval_config is required") - if node_data.multiple_retrieval_config.reranking_mode == "reranking_model": - if node_data.multiple_retrieval_config.reranking_model: - reranking_model = { - "reranking_provider_name": node_data.multiple_retrieval_config.reranking_model.provider, - "reranking_model_name": node_data.multiple_retrieval_config.reranking_model.model, - } - else: + match node_data.multiple_retrieval_config.reranking_mode: + case "reranking_model": + if node_data.multiple_retrieval_config.reranking_model: + reranking_model = { + "reranking_provider_name": node_data.multiple_retrieval_config.reranking_model.provider, + "reranking_model_name": node_data.multiple_retrieval_config.reranking_model.model, + } + else: + reranking_model = None + weights = None + case "weighted_score": + if node_data.multiple_retrieval_config.weights is None: + raise ValueError("weights is required") reranking_model = None - weights = None - elif node_data.multiple_retrieval_config.reranking_mode == "weighted_score": - if node_data.multiple_retrieval_config.weights is None: - raise ValueError("weights is required") - reranking_model = None - vector_setting = node_data.multiple_retrieval_config.weights.vector_setting - weights = { - "vector_setting": { - "vector_weight": vector_setting.vector_weight, - "embedding_provider_name": vector_setting.embedding_provider_name, - "embedding_model_name": vector_setting.embedding_model_name, - }, - "keyword_setting": { - "keyword_weight": node_data.multiple_retrieval_config.weights.keyword_setting.keyword_weight - }, - } - else: - reranking_model = None - weights = None + vector_setting = node_data.multiple_retrieval_config.weights.vector_setting + weights = { + "vector_setting": { + "vector_weight": vector_setting.vector_weight, + "embedding_provider_name": vector_setting.embedding_provider_name, + "embedding_model_name": vector_setting.embedding_model_name, + }, + "keyword_setting": { + "keyword_weight": node_data.multiple_retrieval_config.weights.keyword_setting.keyword_weight + }, + } + case _: + reranking_model = None + weights = None all_documents = dataset_retrieval.multiple_retrieve( app_id=self.app_id, tenant_id=self.tenant_id, @@ -453,73 +454,74 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD ) filters: list[Any] = [] metadata_condition = None - if node_data.metadata_filtering_mode == "disabled": - return None, None, usage - elif node_data.metadata_filtering_mode == "automatic": - automatic_metadata_filters, automatic_usage = self._automatic_metadata_filter_func( - dataset_ids, query, node_data - ) - usage = self._merge_usage(usage, automatic_usage) - if automatic_metadata_filters: - conditions = [] - for sequence, filter in enumerate(automatic_metadata_filters): - DatasetRetrieval.process_metadata_filter_func( - sequence, - filter.get("condition", ""), - filter.get("metadata_name", ""), - filter.get("value"), - filters, - ) - conditions.append( - Condition( - name=filter.get("metadata_name"), # type: ignore - comparison_operator=filter.get("condition"), # type: ignore - value=filter.get("value"), - ) - ) - metadata_condition = MetadataCondition( - logical_operator=node_data.metadata_filtering_conditions.logical_operator - if node_data.metadata_filtering_conditions - else "or", - conditions=conditions, + match node_data.metadata_filtering_mode: + case "disabled": + return None, None, usage + case "automatic": + automatic_metadata_filters, automatic_usage = self._automatic_metadata_filter_func( + dataset_ids, query, node_data ) - elif node_data.metadata_filtering_mode == "manual": - if node_data.metadata_filtering_conditions: - conditions = [] - for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore - metadata_name = condition.name - expected_value = condition.value - if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"): - if isinstance(expected_value, str): - expected_value = self.graph_runtime_state.variable_pool.convert_template( - expected_value - ).value[0] - if expected_value.value_type in {"number", "integer", "float"}: - expected_value = expected_value.value - elif expected_value.value_type == "string": - expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip() - else: - raise ValueError("Invalid expected metadata value type") - conditions.append( - Condition( - name=metadata_name, - comparison_operator=condition.comparison_operator, - value=expected_value, + usage = self._merge_usage(usage, automatic_usage) + if automatic_metadata_filters: + conditions = [] + for sequence, filter in enumerate(automatic_metadata_filters): + DatasetRetrieval.process_metadata_filter_func( + sequence, + filter.get("condition", ""), + filter.get("metadata_name", ""), + filter.get("value"), + filters, ) + conditions.append( + Condition( + name=filter.get("metadata_name"), # type: ignore + comparison_operator=filter.get("condition"), # type: ignore + value=filter.get("value"), + ) + ) + metadata_condition = MetadataCondition( + logical_operator=node_data.metadata_filtering_conditions.logical_operator + if node_data.metadata_filtering_conditions + else "or", + conditions=conditions, ) - filters = DatasetRetrieval.process_metadata_filter_func( - sequence, - condition.comparison_operator, - metadata_name, - expected_value, - filters, + case "manual": + if node_data.metadata_filtering_conditions: + conditions = [] + for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore + metadata_name = condition.name + expected_value = condition.value + if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"): + if isinstance(expected_value, str): + expected_value = self.graph_runtime_state.variable_pool.convert_template( + expected_value + ).value[0] + if expected_value.value_type in {"number", "integer", "float"}: + expected_value = expected_value.value + elif expected_value.value_type == "string": + expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip() + else: + raise ValueError("Invalid expected metadata value type") + conditions.append( + Condition( + name=metadata_name, + comparison_operator=condition.comparison_operator, + value=expected_value, + ) + ) + filters = DatasetRetrieval.process_metadata_filter_func( + sequence, + condition.comparison_operator, + metadata_name, + expected_value, + filters, + ) + metadata_condition = MetadataCondition( + logical_operator=node_data.metadata_filtering_conditions.logical_operator, + conditions=conditions, ) - metadata_condition = MetadataCondition( - logical_operator=node_data.metadata_filtering_conditions.logical_operator, - conditions=conditions, - ) - else: - raise ValueError("Invalid metadata filtering mode") + case _: + raise ValueError("Invalid metadata filtering mode") if filters: if ( node_data.metadata_filtering_conditions diff --git a/api/core/workflow/nodes/list_operator/node.py b/api/core/workflow/nodes/list_operator/node.py index 813d898b9a..235f5b9c52 100644 --- a/api/core/workflow/nodes/list_operator/node.py +++ b/api/core/workflow/nodes/list_operator/node.py @@ -196,13 +196,13 @@ def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]: case "name": return lambda x: x.filename or "" case "type": - return lambda x: x.type + return lambda x: str(x.type) case "extension": return lambda x: x.extension or "" case "mime_type": return lambda x: x.mime_type or "" case "transfer_method": - return lambda x: x.transfer_method + return lambda x: str(x.transfer_method) case "url": return lambda x: x.remote_url or "" case "related_id": @@ -276,7 +276,6 @@ def _get_boolean_filter_func(*, condition: FilterOperator, value: bool) -> Calla def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]: - extract_func: Callable[[File], Any] if key in {"name", "extension", "mime_type", "url", "related_id"} and isinstance(value, str): extract_func = _get_file_extract_string_func(key=key) return lambda x: _get_string_filter_func(condition=condition, value=value)(extract_func(x)) @@ -284,8 +283,8 @@ def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str extract_func = _get_file_extract_string_func(key=key) return lambda x: _get_sequence_filter_func(condition=condition, value=value)(extract_func(x)) elif key == "size" and isinstance(value, str): - extract_func = _get_file_extract_number_func(key=key) - return lambda x: _get_number_filter_func(condition=condition, value=float(value))(extract_func(x)) + extract_number = _get_file_extract_number_func(key=key) + return lambda x: _get_number_filter_func(condition=condition, value=float(value))(extract_number(x)) else: raise InvalidKeyError(f"Invalid key: {key}") diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 17d82c2118..beccf79344 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -852,18 +852,16 @@ class LLMNode(Node[LLMNodeData]): # Insert histories into the prompt prompt_content = prompt_messages[0].content # For issue #11247 - Check if prompt content is a string or a list - prompt_content_type = type(prompt_content) - if prompt_content_type == str: + if isinstance(prompt_content, str): prompt_content = str(prompt_content) if "#histories#" in prompt_content: prompt_content = prompt_content.replace("#histories#", memory_text) else: prompt_content = memory_text + "\n" + prompt_content prompt_messages[0].content = prompt_content - elif prompt_content_type == list: - prompt_content = prompt_content if isinstance(prompt_content, list) else [] + elif isinstance(prompt_content, list): for content_item in prompt_content: - if content_item.type == PromptMessageContentType.TEXT: + if isinstance(content_item, TextPromptMessageContent): if "#histories#" in content_item.data: content_item.data = content_item.data.replace("#histories#", memory_text) else: @@ -873,13 +871,12 @@ class LLMNode(Node[LLMNodeData]): # Add current query to the prompt message if sys_query: - if prompt_content_type == str: + if isinstance(prompt_content, str): prompt_content = str(prompt_messages[0].content).replace("#sys.query#", sys_query) prompt_messages[0].content = prompt_content - elif prompt_content_type == list: - prompt_content = prompt_content if isinstance(prompt_content, list) else [] + elif isinstance(prompt_content, list): for content_item in prompt_content: - if content_item.type == PromptMessageContentType.TEXT: + if isinstance(content_item, TextPromptMessageContent): content_item.data = sys_query + "\n" + content_item.data else: raise ValueError("Invalid prompt content type") @@ -1033,14 +1030,14 @@ class LLMNode(Node[LLMNodeData]): if typed_node_data.prompt_config: enable_jinja = False - if isinstance(prompt_template, list): + if isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate): + if prompt_template.edition_type == "jinja2": + enable_jinja = True + else: for prompt in prompt_template: if prompt.edition_type == "jinja2": enable_jinja = True break - else: - if prompt_template.edition_type == "jinja2": - enable_jinja = True if enable_jinja: for variable_selector in typed_node_data.prompt_config.jinja2_variables or []: diff --git a/api/core/workflow/nodes/protocols.py b/api/core/workflow/nodes/protocols.py index e7dcf62fcf..2ad39e0ab5 100644 --- a/api/core/workflow/nodes/protocols.py +++ b/api/core/workflow/nodes/protocols.py @@ -1,4 +1,4 @@ -from typing import Protocol +from typing import Any, Protocol import httpx @@ -12,17 +12,17 @@ class HttpClientProtocol(Protocol): @property def request_error(self) -> type[Exception]: ... - def get(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ... + def get(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ... - def head(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ... + def head(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ... - def post(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ... + def post(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ... - def put(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ... + def put(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ... - def delete(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ... + def delete(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ... - def patch(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ... + def patch(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ... class FileManagerProtocol(Protocol): diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 68ac60e4f6..60d76db9b6 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -482,16 +482,17 @@ class ToolNode(Node[ToolNodeData]): result = {} for parameter_name in typed_node_data.tool_parameters: input = typed_node_data.tool_parameters[parameter_name] - if input.type == "mixed": - assert isinstance(input.value, str) - selectors = VariableTemplateParser(input.value).extract_variable_selectors() - for selector in selectors: - result[selector.variable] = selector.value_selector - elif input.type == "variable": - selector_key = ".".join(input.value) - result[f"#{selector_key}#"] = input.value - elif input.type == "constant": - pass + match input.type: + case "mixed": + assert isinstance(input.value, str) + selectors = VariableTemplateParser(input.value).extract_variable_selectors() + for selector in selectors: + result[selector.variable] = selector.value_selector + case "variable": + selector_key = ".".join(input.value) + result[f"#{selector_key}#"] = input.value + case "constant": + pass result = {node_id + "." + key: value for key, value in result.items()} diff --git a/api/core/workflow/repositories/human_input_form_repository.py b/api/core/workflow/repositories/human_input_form_repository.py deleted file mode 100644 index efde59c6fd..0000000000 --- a/api/core/workflow/repositories/human_input_form_repository.py +++ /dev/null @@ -1,152 +0,0 @@ -import abc -import dataclasses -from collections.abc import Mapping, Sequence -from datetime import datetime -from typing import Any, Protocol - -from core.workflow.nodes.human_input.entities import DeliveryChannelConfig, HumanInputNodeData -from core.workflow.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus - - -class HumanInputError(Exception): - pass - - -class FormNotFoundError(HumanInputError): - pass - - -@dataclasses.dataclass -class FormCreateParams: - # app_id is the identifier for the app that the form belongs to. - # It is a string with uuid format. - app_id: str - # None when creating a delivery test form; set for runtime forms. - workflow_execution_id: str | None - - # node_id is the identifier for a specific - # node in the graph. - # - # TODO: for node inside loop / iteration, this would - # cause problems, as a single node may be executed multiple times. - node_id: str - - form_config: HumanInputNodeData - rendered_content: str - # Delivery methods already filtered by runtime context (invoke_from). - delivery_methods: Sequence[DeliveryChannelConfig] - # UI display flag computed by runtime context. - display_in_ui: bool - - # resolved_default_values saves the values for defaults with - # type = VARIABLE. - # - # For type = CONSTANT, the value is not stored inside `resolved_default_values` - resolved_default_values: Mapping[str, Any] - form_kind: HumanInputFormKind = HumanInputFormKind.RUNTIME - - # Force creating a console-only recipient for submission in Console. - console_recipient_required: bool = False - console_creator_account_id: str | None = None - # Force creating a backstage recipient for submission in Console. - backstage_recipient_required: bool = False - - -class HumanInputFormEntity(abc.ABC): - @property - @abc.abstractmethod - def id(self) -> str: - """id returns the identifer of the form.""" - pass - - @property - @abc.abstractmethod - def web_app_token(self) -> str | None: - """web_app_token returns the token for submission inside webapp. - - For console/debug execution, this may point to the console submission token - if the form is configured to require console delivery. - """ - - # TODO: what if the users are allowed to add multiple - # webapp delivery? - pass - - @property - @abc.abstractmethod - def recipients(self) -> list["HumanInputFormRecipientEntity"]: ... - - @property - @abc.abstractmethod - def rendered_content(self) -> str: - """Rendered markdown content associated with the form.""" - ... - - @property - @abc.abstractmethod - def selected_action_id(self) -> str | None: - """Identifier of the selected user action if the form has been submitted.""" - ... - - @property - @abc.abstractmethod - def submitted_data(self) -> Mapping[str, Any] | None: - """Submitted form data if available.""" - ... - - @property - @abc.abstractmethod - def submitted(self) -> bool: - """Whether the form has been submitted.""" - ... - - @property - @abc.abstractmethod - def status(self) -> HumanInputFormStatus: - """Current status of the form.""" - ... - - @property - @abc.abstractmethod - def expiration_time(self) -> datetime: - """When the form expires.""" - ... - - -class HumanInputFormRecipientEntity(abc.ABC): - @property - @abc.abstractmethod - def id(self) -> str: - """id returns the identifer of this recipient.""" - ... - - @property - @abc.abstractmethod - def token(self) -> str: - """token returns a random string used to submit form""" - ... - - -class HumanInputFormRepository(Protocol): - """ - Repository interface for HumanInputForm. - - This interface defines the contract for accessing and manipulating - HumanInputForm data, regardless of the underlying storage mechanism. - - Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id), - and other implementation details should be handled at the implementation level, not in - the core interface. This keeps the core domain model clean and independent of specific - application domains or deployment scenarios. - """ - - def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: - """Get the form created for a given human input node in a workflow execution. Returns - `None` if the form has not been created yet.""" - ... - - def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: - """ - Create a human input form from form definition. - """ - ... diff --git a/api/core/workflow/runtime/graph_runtime_state.py b/api/core/workflow/runtime/graph_runtime_state.py index f79230217c..acf0ee6839 100644 --- a/api/core/workflow/runtime/graph_runtime_state.py +++ b/api/core/workflow/runtime/graph_runtime_state.py @@ -6,18 +6,15 @@ import threading from collections.abc import Mapping, Sequence from copy import deepcopy from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Protocol +from typing import Any, ClassVar, Protocol -from pydantic import BaseModel, Field from pydantic.json import pydantic_encoder from core.model_runtime.entities.llm_entities import LLMUsage -from core.workflow.enums import NodeState +from core.workflow.entities.pause_reason import PauseReason +from core.workflow.enums import NodeExecutionType, NodeState, NodeType from core.workflow.runtime.variable_pool import VariablePool -if TYPE_CHECKING: - from core.workflow.entities.pause_reason import PauseReason - class ReadyQueueProtocol(Protocol): """Structural interface required from ready queue implementations.""" @@ -64,7 +61,7 @@ class GraphExecutionProtocol(Protocol): aborted: bool error: Exception | None exceptions_count: int - pause_reasons: Sequence[PauseReason] + pause_reasons: list[PauseReason] def start(self) -> None: """Transition execution into the running state.""" @@ -112,11 +109,18 @@ class NodeProtocol(Protocol): id: str state: NodeState + execution_type: NodeExecutionType + node_type: ClassVar[NodeType] + + def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool: ... class EdgeProtocol(Protocol): id: str state: NodeState + tail: str + head: str + source_handle: str class GraphProtocol(Protocol): @@ -129,13 +133,6 @@ class GraphProtocol(Protocol): def get_outgoing_edges(self, node_id: str) -> Sequence[EdgeProtocol]: ... -class _GraphStateSnapshot(BaseModel): - """Serializable graph state snapshot for node/edge states.""" - - nodes: dict[str, NodeState] = Field(default_factory=dict) - edges: dict[str, NodeState] = Field(default_factory=dict) - - @dataclass(slots=True) class _GraphRuntimeStateSnapshot: """Immutable view of a serialized runtime state snapshot.""" @@ -151,20 +148,10 @@ class _GraphRuntimeStateSnapshot: graph_execution_dump: str | None response_coordinator_dump: str | None paused_nodes: tuple[str, ...] - deferred_nodes: tuple[str, ...] - graph_node_states: dict[str, NodeState] - graph_edge_states: dict[str, NodeState] class GraphRuntimeState: - """Mutable runtime state shared across graph execution components. - - `GraphRuntimeState` encapsulates the runtime state of workflow execution, - including scheduling details, variable values, and timing information. - - Values that are initialized prior to workflow execution and remain constant - throughout the execution should be part of `GraphInitParams` instead. - """ + """Mutable runtime state shared across graph execution components.""" def __init__( self, @@ -202,16 +189,6 @@ class GraphRuntimeState: self._pending_response_coordinator_dump: str | None = None self._pending_graph_execution_workflow_id: str | None = None self._paused_nodes: set[str] = set() - self._deferred_nodes: set[str] = set() - - # Node and edges states needed to be restored into - # graph object. - # - # These two fields are non-None only when resuming from a snapshot. - # Once the graph is attached, these two fields will be set to None. - self._pending_graph_node_states: dict[str, NodeState] | None = None - self._pending_graph_edge_states: dict[str, NodeState] | None = None - self.stop_event: threading.Event = threading.Event() if graph is not None: @@ -233,7 +210,6 @@ class GraphRuntimeState: if self._pending_response_coordinator_dump is not None and self._response_coordinator is not None: self._response_coordinator.loads(self._pending_response_coordinator_dump) self._pending_response_coordinator_dump = None - self._apply_pending_graph_state() def configure(self, *, graph: GraphProtocol | None = None) -> None: """Ensure core collaborators are initialized with the provided context.""" @@ -355,13 +331,8 @@ class GraphRuntimeState: "ready_queue": self.ready_queue.dumps(), "graph_execution": self.graph_execution.dumps(), "paused_nodes": list(self._paused_nodes), - "deferred_nodes": list(self._deferred_nodes), } - graph_state = self._snapshot_graph_state() - if graph_state is not None: - snapshot["graph_state"] = graph_state - if self._response_coordinator is not None and self._graph is not None: snapshot["response_coordinator"] = self._response_coordinator.dumps() @@ -395,11 +366,6 @@ class GraphRuntimeState: self._paused_nodes.add(node_id) - def get_paused_nodes(self) -> list[str]: - """Retrieve the list of paused nodes without mutating internal state.""" - - return list(self._paused_nodes) - def consume_paused_nodes(self) -> list[str]: """Retrieve and clear the list of paused nodes awaiting resume.""" @@ -407,23 +373,6 @@ class GraphRuntimeState: self._paused_nodes.clear() return nodes - def register_deferred_node(self, node_id: str) -> None: - """Record a node that became ready during pause and should resume later.""" - - self._deferred_nodes.add(node_id) - - def get_deferred_nodes(self) -> list[str]: - """Retrieve deferred nodes without mutating internal state.""" - - return list(self._deferred_nodes) - - def consume_deferred_nodes(self) -> list[str]: - """Retrieve and clear deferred nodes awaiting resume.""" - - nodes = list(self._deferred_nodes) - self._deferred_nodes.clear() - return nodes - # ------------------------------------------------------------------ # Builders # ------------------------------------------------------------------ @@ -485,10 +434,6 @@ class GraphRuntimeState: graph_execution_payload = payload.get("graph_execution") response_payload = payload.get("response_coordinator") paused_nodes_payload = payload.get("paused_nodes", []) - deferred_nodes_payload = payload.get("deferred_nodes", []) - graph_state_payload = payload.get("graph_state", {}) or {} - graph_node_states = _coerce_graph_state_map(graph_state_payload, "nodes") - graph_edge_states = _coerce_graph_state_map(graph_state_payload, "edges") return _GraphRuntimeStateSnapshot( start_at=start_at, @@ -502,9 +447,6 @@ class GraphRuntimeState: graph_execution_dump=graph_execution_payload, response_coordinator_dump=response_payload, paused_nodes=tuple(map(str, paused_nodes_payload)), - deferred_nodes=tuple(map(str, deferred_nodes_payload)), - graph_node_states=graph_node_states, - graph_edge_states=graph_edge_states, ) def _apply_snapshot(self, snapshot: _GraphRuntimeStateSnapshot) -> None: @@ -520,10 +462,6 @@ class GraphRuntimeState: self._restore_graph_execution(snapshot.graph_execution_dump) self._restore_response_coordinator(snapshot.response_coordinator_dump) self._paused_nodes = set(snapshot.paused_nodes) - self._deferred_nodes = set(snapshot.deferred_nodes) - self._pending_graph_node_states = snapshot.graph_node_states or None - self._pending_graph_edge_states = snapshot.graph_edge_states or None - self._apply_pending_graph_state() def _restore_ready_queue(self, payload: str | None) -> None: if payload is not None: @@ -560,68 +498,3 @@ class GraphRuntimeState: self._pending_response_coordinator_dump = payload self._response_coordinator = None - - def _snapshot_graph_state(self) -> _GraphStateSnapshot: - graph = self._graph - if graph is None: - if self._pending_graph_node_states is None and self._pending_graph_edge_states is None: - return _GraphStateSnapshot() - return _GraphStateSnapshot( - nodes=self._pending_graph_node_states or {}, - edges=self._pending_graph_edge_states or {}, - ) - - nodes = graph.nodes - edges = graph.edges - if not isinstance(nodes, Mapping) or not isinstance(edges, Mapping): - return _GraphStateSnapshot() - - node_states = {} - for node_id, node in nodes.items(): - if not isinstance(node_id, str): - continue - node_states[node_id] = node.state - - edge_states = {} - for edge_id, edge in edges.items(): - if not isinstance(edge_id, str): - continue - edge_states[edge_id] = edge.state - - return _GraphStateSnapshot(nodes=node_states, edges=edge_states) - - def _apply_pending_graph_state(self) -> None: - if self._graph is None: - return - if self._pending_graph_node_states: - for node_id, state in self._pending_graph_node_states.items(): - node = self._graph.nodes.get(node_id) - if node is None: - continue - node.state = state - if self._pending_graph_edge_states: - for edge_id, state in self._pending_graph_edge_states.items(): - edge = self._graph.edges.get(edge_id) - if edge is None: - continue - edge.state = state - - self._pending_graph_node_states = None - self._pending_graph_edge_states = None - - -def _coerce_graph_state_map(payload: Any, key: str) -> dict[str, NodeState]: - if not isinstance(payload, Mapping): - return {} - raw_map = payload.get(key, {}) - if not isinstance(raw_map, Mapping): - return {} - result: dict[str, NodeState] = {} - for node_id, raw_state in raw_map.items(): - if not isinstance(node_id, str): - continue - try: - result[node_id] = NodeState(str(raw_state)) - except ValueError: - continue - return result diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 43f15f6fd0..4b1845cda2 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -144,11 +144,11 @@ class WorkflowEntry: :param user_inputs: user inputs :return: """ - node_config = dict(workflow.get_node_config_by_id(node_id)) - node_config_data = node_config.get("data", {}) + node_config = workflow.get_node_config_by_id(node_id) + node_config_data = node_config["data"] # Get node type - node_type = NodeType(node_config_data.get("type")) + node_type = NodeType(node_config_data["type"]) # init graph init params and runtime state graph_init_params = GraphInitParams( diff --git a/api/core/workflow/workflow_type_encoder.py b/api/core/workflow/workflow_type_encoder.py index f1f549e1f8..5456043ccd 100644 --- a/api/core/workflow/workflow_type_encoder.py +++ b/api/core/workflow/workflow_type_encoder.py @@ -15,14 +15,12 @@ class WorkflowRuntimeTypeConverter: def to_json_encodable(self, value: None) -> None: ... def to_json_encodable(self, value: Mapping[str, Any] | None) -> Mapping[str, Any] | None: - """Convert runtime values to JSON-serializable structures.""" - - result = self.value_to_json_encodable_recursive(value) + result = self._to_json_encodable_recursive(value) if isinstance(result, Mapping) or result is None: return result return {} - def value_to_json_encodable_recursive(self, value: Any): + def _to_json_encodable_recursive(self, value: Any): if value is None: return value if isinstance(value, (bool, int, str, float)): @@ -31,7 +29,7 @@ class WorkflowRuntimeTypeConverter: # Convert Decimal to float for JSON serialization return float(value) if isinstance(value, Segment): - return self.value_to_json_encodable_recursive(value.value) + return self._to_json_encodable_recursive(value.value) if isinstance(value, File): return value.to_dict() if isinstance(value, BaseModel): @@ -39,11 +37,11 @@ class WorkflowRuntimeTypeConverter: if isinstance(value, dict): res = {} for k, v in value.items(): - res[k] = self.value_to_json_encodable_recursive(v) + res[k] = self._to_json_encodable_recursive(v) return res if isinstance(value, list): res_list = [] for item in value: - res_list.append(self.value_to_json_encodable_recursive(item)) + res_list.append(self._to_json_encodable_recursive(item)) return res_list return value diff --git a/api/docker/entrypoint.sh b/api/docker/entrypoint.sh index 03e6cbda68..c0279f893b 100755 --- a/api/docker/entrypoint.sh +++ b/api/docker/entrypoint.sh @@ -35,10 +35,10 @@ if [[ "${MODE}" == "worker" ]]; then if [[ -z "${CELERY_QUEUES}" ]]; then if [[ "${EDITION}" == "CLOUD" ]]; then # Cloud edition: separate queues for dataset and trigger tasks - DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution" + DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention" else # Community edition (SELF_HOSTED): dataset, pipeline and workflow have separate queues - DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution" + DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention" fi else DEFAULT_QUEUES="${CELERY_QUEUES}" @@ -102,7 +102,7 @@ elif [[ "${MODE}" == "job" ]]; then fi echo "Running Flask job command: flask $*" - + # Temporarily disable exit on error to capture exit code set +e flask "$@" diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index aa9723f375..af983f6d87 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -151,12 +151,6 @@ def init_app(app: DifyApp) -> Celery: "task": "schedule.queue_monitor_task.queue_monitor_task", "schedule": timedelta(minutes=dify_config.QUEUE_MONITOR_INTERVAL or 30), } - if dify_config.ENABLE_HUMAN_INPUT_TIMEOUT_TASK: - imports.append("tasks.human_input_timeout_tasks") - beat_schedule["human_input_form_timeout"] = { - "task": "human_input_form_timeout.check_and_resume", - "schedule": timedelta(minutes=dify_config.HUMAN_INPUT_TIMEOUT_TASK_INTERVAL), - } if dify_config.ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK and dify_config.MARKETPLACE_ENABLED: imports.append("schedule.check_upgradable_plugin_task") imports.append("tasks.process_tenant_plugin_autoupgrade_check_task") diff --git a/api/extensions/ext_fastopenapi.py b/api/extensions/ext_fastopenapi.py index e6c1bc6bee..ab4d23a072 100644 --- a/api/extensions/ext_fastopenapi.py +++ b/api/extensions/ext_fastopenapi.py @@ -27,10 +27,13 @@ def init_app(app: DifyApp) -> None: ) # Ensure route decorators are evaluated. + import controllers.console.init_validate as init_validate_module import controllers.console.ping as ping_module - from controllers.console import setup + from controllers.console import remote_files, setup + _ = init_validate_module _ = ping_module + _ = remote_files _ = setup router.include_router(console_router, prefix="/console/api") diff --git a/api/extensions/ext_redis.py b/api/extensions/ext_redis.py index 0797a3cb98..5e75bc36b0 100644 --- a/api/extensions/ext_redis.py +++ b/api/extensions/ext_redis.py @@ -8,16 +8,12 @@ from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, Union import redis from redis import RedisError from redis.cache import CacheConfig -from redis.client import PubSub from redis.cluster import ClusterNode, RedisCluster from redis.connection import Connection, SSLConnection from redis.sentinel import Sentinel from configs import dify_config from dify_app import DifyApp -from libs.broadcast_channel.channel import BroadcastChannel as BroadcastChannelProtocol -from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel -from libs.broadcast_channel.redis.sharded_channel import ShardedRedisBroadcastChannel if TYPE_CHECKING: from redis.lock import Lock @@ -110,7 +106,6 @@ class RedisClientWrapper: def zremrangebyscore(self, name: str | bytes, min: float | str, max: float | str) -> Any: ... def zcard(self, name: str | bytes) -> Any: ... def getdel(self, name: str | bytes) -> Any: ... - def pubsub(self) -> PubSub: ... def __getattr__(self, item: str) -> Any: if self._client is None: @@ -119,7 +114,6 @@ class RedisClientWrapper: redis_client: RedisClientWrapper = RedisClientWrapper() -pubsub_redis_client: RedisClientWrapper = RedisClientWrapper() def _get_ssl_configuration() -> tuple[type[Union[Connection, SSLConnection]], dict[str, Any]]: @@ -232,12 +226,6 @@ def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis return client -def _create_pubsub_client(pubsub_url: str, use_clusters: bool) -> Union[redis.Redis, RedisCluster]: - if use_clusters: - return RedisCluster.from_url(pubsub_url) - return redis.Redis.from_url(pubsub_url) - - def init_app(app: DifyApp): """Initialize Redis client and attach it to the app.""" global redis_client @@ -256,24 +244,6 @@ def init_app(app: DifyApp): redis_client.initialize(client) app.extensions["redis"] = redis_client - pubsub_client = client - if dify_config.normalized_pubsub_redis_url: - pubsub_client = _create_pubsub_client( - dify_config.normalized_pubsub_redis_url, dify_config.PUBSUB_REDIS_USE_CLUSTERS - ) - pubsub_redis_client.initialize(pubsub_client) - - -def get_pubsub_redis_client() -> RedisClientWrapper: - return pubsub_redis_client - - -def get_pubsub_broadcast_channel() -> BroadcastChannelProtocol: - redis_conn = get_pubsub_redis_client() - if dify_config.PUBSUB_REDIS_CHANNEL_TYPE == "sharded": - return ShardedRedisBroadcastChannel(redis_conn) # pyright: ignore[reportArgumentType] - return RedisBroadcastChannel(redis_conn) # pyright: ignore[reportArgumentType] - P = ParamSpec("P") R = TypeVar("R") diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py index 817c8b0448..f67723630b 100644 --- a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py @@ -13,7 +13,6 @@ from typing import Any from sqlalchemy.orm import sessionmaker -from core.workflow.enums import WorkflowNodeExecutionStatus from extensions.logstore.aliyun_logstore import AliyunLogStore from extensions.logstore.repositories import safe_float, safe_int from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value @@ -208,10 +207,8 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep reverse=True, ) - for row in deduplicated_results: - model = _dict_to_workflow_node_execution_model(row) - if model.status != WorkflowNodeExecutionStatus.PAUSED: - return model + if deduplicated_results: + return _dict_to_workflow_node_execution_model(deduplicated_results[0]) return None @@ -312,8 +309,6 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep if model and model.id: # Ensure model is valid models.append(model) - models = [model for model in models if model.status != WorkflowNodeExecutionStatus.PAUSED] - # Sort by index DESC for trace visualization models.sort(key=lambda x: x.index, reverse=True) diff --git a/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py b/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py index c1608f58a5..18eed4e481 100644 --- a/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py +++ b/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py @@ -390,8 +390,7 @@ class ClickZettaVolumeStorage(BaseStorage): """ content = self.load_once(filename) - with Path(target_filepath).open("wb") as f: - f.write(content) + Path(target_filepath).write_bytes(content) logger.debug("File %s downloaded from ClickZetta Volume to %s", filename, target_filepath) diff --git a/api/fields/annotation_fields.py b/api/fields/annotation_fields.py index e69306dcb2..a646950722 100644 --- a/api/fields/annotation_fields.py +++ b/api/fields/annotation_fields.py @@ -1,36 +1,69 @@ -from flask_restx import Namespace, fields +from __future__ import annotations -from libs.helper import TimestampField +from datetime import datetime -annotation_fields = { - "id": fields.String, - "question": fields.String, - "answer": fields.Raw(attribute="content"), - "hit_count": fields.Integer, - "created_at": TimestampField, - # 'account': fields.Nested(simple_account_fields, allow_null=True) -} +from pydantic import BaseModel, ConfigDict, Field, field_validator -def build_annotation_model(api_or_ns: Namespace): - """Build the annotation model for the API or Namespace.""" - return api_or_ns.model("Annotation", annotation_fields) +def _to_timestamp(value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value -annotation_list_fields = { - "data": fields.List(fields.Nested(annotation_fields)), -} +class ResponseModel(BaseModel): + model_config = ConfigDict( + from_attributes=True, + extra="ignore", + populate_by_name=True, + serialize_by_alias=True, + protected_namespaces=(), + ) -annotation_hit_history_fields = { - "id": fields.String, - "source": fields.String, - "score": fields.Float, - "question": fields.String, - "created_at": TimestampField, - "match": fields.String(attribute="annotation_question"), - "response": fields.String(attribute="annotation_content"), -} -annotation_hit_history_list_fields = { - "data": fields.List(fields.Nested(annotation_hit_history_fields)), -} +class Annotation(ResponseModel): + id: str + question: str | None = None + answer: str | None = Field(default=None, validation_alias="content") + hit_count: int | None = None + created_at: int | None = None + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_created_at(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class AnnotationList(ResponseModel): + data: list[Annotation] + has_more: bool + limit: int + total: int + page: int + + +class AnnotationExportList(ResponseModel): + data: list[Annotation] + + +class AnnotationHitHistory(ResponseModel): + id: str + source: str | None = None + score: float | None = None + question: str | None = None + created_at: int | None = None + match: str | None = Field(default=None, validation_alias="annotation_question") + response: str | None = Field(default=None, validation_alias="annotation_content") + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_created_at(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class AnnotationHitHistoryList(ResponseModel): + data: list[AnnotationHitHistory] + has_more: bool + limit: int + total: int + page: int diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index cda46f2339..d8ae0ad8b8 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -192,7 +192,6 @@ class StatusCount(ResponseModel): success: int failed: int partial_success: int - paused: int class ModelConfig(ResponseModel): diff --git a/api/fields/end_user_fields.py b/api/fields/end_user_fields.py index 5389b0213a..effe7bfb20 100644 --- a/api/fields/end_user_fields.py +++ b/api/fields/end_user_fields.py @@ -1,4 +1,7 @@ -from flask_restx import Namespace, fields +from __future__ import annotations + +from flask_restx import fields +from pydantic import BaseModel, ConfigDict simple_end_user_fields = { "id": fields.String, @@ -8,5 +11,18 @@ simple_end_user_fields = { } -def build_simple_end_user_model(api_or_ns: Namespace): - return api_or_ns.model("SimpleEndUser", simple_end_user_fields) +class ResponseModel(BaseModel): + model_config = ConfigDict( + from_attributes=True, + extra="ignore", + populate_by_name=True, + serialize_by_alias=True, + protected_namespaces=(), + ) + + +class SimpleEndUser(ResponseModel): + id: str + type: str + is_anonymous: bool + session_id: str | None = None diff --git a/api/fields/member_fields.py b/api/fields/member_fields.py index 25160927e6..11d9a1a2fc 100644 --- a/api/fields/member_fields.py +++ b/api/fields/member_fields.py @@ -1,6 +1,11 @@ -from flask_restx import Namespace, fields +from __future__ import annotations -from libs.helper import AvatarUrlField, TimestampField +from datetime import datetime + +from flask_restx import fields +from pydantic import BaseModel, ConfigDict, computed_field, field_validator + +from core.file import helpers as file_helpers simple_account_fields = { "id": fields.String, @@ -9,36 +14,78 @@ simple_account_fields = { } -def build_simple_account_model(api_or_ns: Namespace): - return api_or_ns.model("SimpleAccount", simple_account_fields) +def _to_timestamp(value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value -account_fields = { - "id": fields.String, - "name": fields.String, - "avatar": fields.String, - "avatar_url": AvatarUrlField, - "email": fields.String, - "is_password_set": fields.Boolean, - "interface_language": fields.String, - "interface_theme": fields.String, - "timezone": fields.String, - "last_login_at": TimestampField, - "last_login_ip": fields.String, - "created_at": TimestampField, -} +def _build_avatar_url(avatar: str | None) -> str | None: + if avatar is None: + return None + if avatar.startswith(("http://", "https://")): + return avatar + return file_helpers.get_signed_file_url(avatar) -account_with_role_fields = { - "id": fields.String, - "name": fields.String, - "avatar": fields.String, - "avatar_url": AvatarUrlField, - "email": fields.String, - "last_login_at": TimestampField, - "last_active_at": TimestampField, - "created_at": TimestampField, - "role": fields.String, - "status": fields.String, -} -account_with_role_list_fields = {"accounts": fields.List(fields.Nested(account_with_role_fields))} +class ResponseModel(BaseModel): + model_config = ConfigDict( + from_attributes=True, + extra="ignore", + populate_by_name=True, + serialize_by_alias=True, + protected_namespaces=(), + ) + + +class SimpleAccount(ResponseModel): + id: str + name: str + email: str + + +class _AccountAvatar(ResponseModel): + avatar: str | None = None + + @computed_field(return_type=str | None) # type: ignore[prop-decorator] + @property + def avatar_url(self) -> str | None: + return _build_avatar_url(self.avatar) + + +class Account(_AccountAvatar): + id: str + name: str + email: str + is_password_set: bool + interface_language: str | None = None + interface_theme: str | None = None + timezone: str | None = None + last_login_at: int | None = None + last_login_ip: str | None = None + created_at: int | None = None + + @field_validator("last_login_at", "created_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class AccountWithRole(_AccountAvatar): + id: str + name: str + email: str + last_login_at: int | None = None + last_active_at: int | None = None + created_at: int | None = None + role: str + status: str + + @field_validator("last_login_at", "last_active_at", "created_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class AccountWithRoleList(ResponseModel): + accounts: list[AccountWithRole] diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py index 77b26a7423..e6c3b42f93 100644 --- a/api/fields/message_fields.py +++ b/api/fields/message_fields.py @@ -6,7 +6,6 @@ from uuid import uuid4 from pydantic import BaseModel, ConfigDict, Field, field_validator -from core.entities.execution_extra_content import ExecutionExtraContentDomainModel from core.file import File from fields.conversation_fields import AgentThought, JSONValue, MessageFile @@ -62,7 +61,6 @@ class MessageListItem(ResponseModel): message_files: list[MessageFile] status: str error: str | None = None - extra_contents: list[ExecutionExtraContentDomainModel] @field_validator("inputs", mode="before") @classmethod diff --git a/api/fields/tag_fields.py b/api/fields/tag_fields.py index e359a4408c..7cb64e5ca8 100644 --- a/api/fields/tag_fields.py +++ b/api/fields/tag_fields.py @@ -1,12 +1,20 @@ -from flask_restx import Namespace, fields +from __future__ import annotations -dataset_tag_fields = { - "id": fields.String, - "name": fields.String, - "type": fields.String, - "binding_count": fields.String, -} +from pydantic import BaseModel, ConfigDict -def build_dataset_tag_fields(api_or_ns: Namespace): - return api_or_ns.model("DataSetTag", dataset_tag_fields) +class ResponseModel(BaseModel): + model_config = ConfigDict( + from_attributes=True, + extra="ignore", + populate_by_name=True, + serialize_by_alias=True, + protected_namespaces=(), + ) + + +class DataSetTag(ResponseModel): + id: str + name: str + type: str + binding_count: str | None = None diff --git a/api/fields/workflow_app_log_fields.py b/api/fields/workflow_app_log_fields.py index ae70356322..d0e762f62b 100644 --- a/api/fields/workflow_app_log_fields.py +++ b/api/fields/workflow_app_log_fields.py @@ -1,7 +1,7 @@ from flask_restx import Namespace, fields -from fields.end_user_fields import build_simple_end_user_model, simple_end_user_fields -from fields.member_fields import build_simple_account_model, simple_account_fields +from fields.end_user_fields import simple_end_user_fields +from fields.member_fields import simple_account_fields from fields.workflow_run_fields import ( build_workflow_run_for_archived_log_model, build_workflow_run_for_log_model, @@ -25,17 +25,9 @@ workflow_app_log_partial_fields = { def build_workflow_app_log_partial_model(api_or_ns: Namespace): """Build the workflow app log partial model for the API or Namespace.""" workflow_run_model = build_workflow_run_for_log_model(api_or_ns) - simple_account_model = build_simple_account_model(api_or_ns) - simple_end_user_model = build_simple_end_user_model(api_or_ns) copied_fields = workflow_app_log_partial_fields.copy() copied_fields["workflow_run"] = fields.Nested(workflow_run_model, attribute="workflow_run", allow_null=True) - copied_fields["created_by_account"] = fields.Nested( - simple_account_model, attribute="created_by_account", allow_null=True - ) - copied_fields["created_by_end_user"] = fields.Nested( - simple_end_user_model, attribute="created_by_end_user", allow_null=True - ) return api_or_ns.model("WorkflowAppLogPartial", copied_fields) @@ -52,17 +44,9 @@ workflow_archived_log_partial_fields = { def build_workflow_archived_log_partial_model(api_or_ns: Namespace): """Build the workflow archived log partial model for the API or Namespace.""" workflow_run_model = build_workflow_run_for_archived_log_model(api_or_ns) - simple_account_model = build_simple_account_model(api_or_ns) - simple_end_user_model = build_simple_end_user_model(api_or_ns) copied_fields = workflow_archived_log_partial_fields.copy() copied_fields["workflow_run"] = fields.Nested(workflow_run_model, allow_null=True) - copied_fields["created_by_account"] = fields.Nested( - simple_account_model, attribute="created_by_account", allow_null=True - ) - copied_fields["created_by_end_user"] = fields.Nested( - simple_end_user_model, attribute="created_by_end_user", allow_null=True - ) return api_or_ns.model("WorkflowArchivedLogPartial", copied_fields) diff --git a/api/libs/broadcast_channel/redis/_subscription.py b/api/libs/broadcast_channel/redis/_subscription.py index fa2be421a1..7d4b8e63ca 100644 --- a/api/libs/broadcast_channel/redis/_subscription.py +++ b/api/libs/broadcast_channel/redis/_subscription.py @@ -162,7 +162,7 @@ class RedisSubscriptionBase(Subscription): self._start_if_needed() return iter(self._message_iterator()) - def receive(self, timeout: float | None = 0.1) -> bytes | None: + def receive(self, timeout: float | None = None) -> bytes | None: """Receive the next message from the subscription.""" if self._closed.is_set(): raise SubscriptionClosedError(f"The Redis {self._get_subscription_type()} subscription is closed") diff --git a/api/libs/broadcast_channel/redis/sharded_channel.py b/api/libs/broadcast_channel/redis/sharded_channel.py index 9e8ab90e8e..d190c51bbc 100644 --- a/api/libs/broadcast_channel/redis/sharded_channel.py +++ b/api/libs/broadcast_channel/redis/sharded_channel.py @@ -61,14 +61,7 @@ class _RedisShardedSubscription(RedisSubscriptionBase): def _get_message(self) -> dict | None: assert self._pubsub is not None - # NOTE(QuantumGhost): this is an issue in - # upstream code. If Sharded PubSub is used with Cluster, the - # `ClusterPubSub.get_sharded_message` will return `None` regardless of - # message['type']. - # - # Since we have already filtered at the caller's site, we can safely set - # `ignore_subscribe_messages=False`. - return self._pubsub.get_sharded_message(ignore_subscribe_messages=False, timeout=0.1) # type: ignore[attr-defined] + return self._pubsub.get_sharded_message(ignore_subscribe_messages=True, timeout=0.1) # type: ignore[attr-defined] def _get_message_type(self) -> str: return "smessage" diff --git a/api/libs/email_template_renderer.py b/api/libs/email_template_renderer.py deleted file mode 100644 index 98ea30ab46..0000000000 --- a/api/libs/email_template_renderer.py +++ /dev/null @@ -1,49 +0,0 @@ -""" -Email template rendering helpers with configurable safety modes. -""" - -import time -from collections.abc import Mapping -from typing import Any - -from flask import render_template_string -from jinja2.runtime import Context -from jinja2.sandbox import ImmutableSandboxedEnvironment - -from configs import dify_config -from configs.feature import TemplateMode - - -class SandboxedEnvironment(ImmutableSandboxedEnvironment): - """Sandboxed environment with execution timeout.""" - - def __init__(self, timeout: int, *args: Any, **kwargs: Any): - self._deadline = time.time() + timeout if timeout else None - super().__init__(*args, **kwargs) - - def call(self, context: Context, obj: Any, *args: Any, **kwargs: Any) -> Any: - if self._deadline is not None and time.time() > self._deadline: - raise TimeoutError("Template rendering timeout") - return super().call(context, obj, *args, **kwargs) - - -def render_email_template(template: str, substitutions: Mapping[str, str]) -> str: - """ - Render email template content according to the configured template mode. - - In unsafe mode, Jinja expressions are evaluated directly. - In sandbox mode, a sandboxed environment with timeout is used. - In disabled mode, the template is returned without rendering. - """ - mode = dify_config.MAIL_TEMPLATING_MODE - timeout = dify_config.MAIL_TEMPLATING_TIMEOUT - - if mode == TemplateMode.UNSAFE: - return render_template_string(template, **substitutions) - if mode == TemplateMode.SANDBOX: - env = SandboxedEnvironment(timeout=timeout) - tmpl = env.from_string(template) - return tmpl.render(substitutions) - if mode == TemplateMode.DISABLED: - return template - raise ValueError(f"Unsupported mail templating mode: {mode}") diff --git a/api/libs/flask_utils.py b/api/libs/flask_utils.py index e45c8fe319..beade7eb25 100644 --- a/api/libs/flask_utils.py +++ b/api/libs/flask_utils.py @@ -1,15 +1,12 @@ import contextvars from collections.abc import Iterator from contextlib import contextmanager -from typing import TYPE_CHECKING, TypeVar +from typing import TypeVar from flask import Flask, g T = TypeVar("T") -if TYPE_CHECKING: - from models import Account, EndUser - @contextmanager def preserve_flask_contexts( @@ -67,7 +64,3 @@ def preserve_flask_contexts( finally: # Any cleanup can be added here if needed pass - - -def set_login_user(user: "Account | EndUser"): - g._login_user = user diff --git a/api/libs/helper.py b/api/libs/helper.py index fb577b9c99..07c4823727 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -7,10 +7,10 @@ import struct import subprocess import time import uuid -from collections.abc import Callable, Generator, Mapping +from collections.abc import Generator, Mapping from datetime import datetime from hashlib import sha256 -from typing import TYPE_CHECKING, Annotated, Any, Optional, Protocol, Union, cast +from typing import TYPE_CHECKING, Annotated, Any, Optional, Union, cast from uuid import UUID from zoneinfo import available_timezones @@ -126,13 +126,6 @@ class TimestampField(fields.Raw): return int(value.timestamp()) -class OptionalTimestampField(fields.Raw): - def format(self, value) -> int | None: - if value is None: - return None - return int(value.timestamp()) - - def email(email): # Define a regex pattern for email addresses pattern = r"^[\w\.!#$%&'*+\-/=?^_`{|}~]+@([\w-]+\.)+[\w-]{2,}$" @@ -244,26 +237,6 @@ def convert_datetime_to_date(field, target_timezone: str = ":tz"): def generate_string(n): - """ - Generates a cryptographically secure random string of the specified length. - - This function uses a cryptographically secure pseudorandom number generator (CSPRNG) - to create a string composed of ASCII letters (both uppercase and lowercase) and digits. - - Each character in the generated string provides approximately 5.95 bits of entropy - (log2(62)). To ensure a minimum of 128 bits of entropy for security purposes, the - length of the string (`n`) should be at least 22 characters. - - Args: - n (int): The length of the random string to generate. For secure usage, - `n` should be 22 or greater. - - Returns: - str: A random string of length `n` composed of ASCII letters and digits. - - Note: - This function is suitable for generating credentials or other secure tokens. - """ letters_digits = string.ascii_letters + string.digits result = "" for _ in range(n): @@ -432,35 +405,11 @@ class TokenManager: return f"{token_type}:account:{account_id}" -class _RateLimiterRedisClient(Protocol): - def zadd(self, name: str | bytes, mapping: dict[str | bytes | int | float, float | int | str | bytes]) -> int: ... - - def zremrangebyscore(self, name: str | bytes, min: str | float, max: str | float) -> int: ... - - def zcard(self, name: str | bytes) -> int: ... - - def expire(self, name: str | bytes, time: int) -> bool: ... - - -def _default_rate_limit_member_factory() -> str: - current_time = int(time.time()) - return f"{current_time}:{secrets.token_urlsafe(nbytes=8)}" - - class RateLimiter: - def __init__( - self, - prefix: str, - max_attempts: int, - time_window: int, - member_factory: Callable[[], str] = _default_rate_limit_member_factory, - redis_client: _RateLimiterRedisClient = redis_client, - ): + def __init__(self, prefix: str, max_attempts: int, time_window: int): self.prefix = prefix self.max_attempts = max_attempts self.time_window = time_window - self._member_factory = member_factory - self._redis_client = redis_client def _get_key(self, email: str) -> str: return f"{self.prefix}:{email}" @@ -470,8 +419,8 @@ class RateLimiter: current_time = int(time.time()) window_start_time = current_time - self.time_window - self._redis_client.zremrangebyscore(key, "-inf", window_start_time) - attempts = self._redis_client.zcard(key) + redis_client.zremrangebyscore(key, "-inf", window_start_time) + attempts = redis_client.zcard(key) if attempts and int(attempts) >= self.max_attempts: return True @@ -479,8 +428,7 @@ class RateLimiter: def increment_rate_limit(self, email: str): key = self._get_key(email) - member = self._member_factory() current_time = int(time.time()) - self._redis_client.zadd(key, {member: current_time}) - self._redis_client.expire(key, self.time_window * 2) + redis_client.zadd(key, {current_time: current_time}) + redis_client.expire(key, self.time_window * 2) diff --git a/api/migrations/versions/2026_01_27_1815-788d3099ae3a_add_summary_index_feature.py b/api/migrations/versions/2026_01_27_1815-788d3099ae3a_add_summary_index_feature.py index 3c2e0822e1..c6c72859dc 100644 --- a/api/migrations/versions/2026_01_27_1815-788d3099ae3a_add_summary_index_feature.py +++ b/api/migrations/versions/2026_01_27_1815-788d3099ae3a_add_summary_index_feature.py @@ -51,7 +51,7 @@ def upgrade(): batch_op.add_column(sa.Column('summary_index_setting', models.types.AdjustedJSON(), nullable=True)) with op.batch_alter_table('documents', schema=None) as batch_op: - batch_op.add_column(sa.Column('need_summary', sa.Boolean(), server_default=sa.text('false'), nullable=True)) + batch_op.add_column(sa.Column('need_summary', sa.Boolean(), server_default=sa.text('false'), nullable=False)) else: # MySQL: Use compatible syntax op.create_table( @@ -83,7 +83,7 @@ def upgrade(): batch_op.add_column(sa.Column('summary_index_setting', models.types.AdjustedJSON(), nullable=True)) with op.batch_alter_table('documents', schema=None) as batch_op: - batch_op.add_column(sa.Column('need_summary', sa.Boolean(), server_default=sa.text('false'), nullable=True)) + batch_op.add_column(sa.Column('need_summary', sa.Boolean(), server_default=sa.text('false'), nullable=False)) # ### end Alembic commands ### diff --git a/api/migrations/versions/2026_01_29_1415-e8c3b3c46151_add_human_input_related_db_models.py b/api/migrations/versions/2026_01_29_1415-e8c3b3c46151_add_human_input_related_db_models.py deleted file mode 100644 index a1546ef940..0000000000 --- a/api/migrations/versions/2026_01_29_1415-e8c3b3c46151_add_human_input_related_db_models.py +++ /dev/null @@ -1,99 +0,0 @@ -"""Add human input related db models - -Revision ID: e8c3b3c46151 -Revises: 788d3099ae3a -Create Date: 2026-01-29 14:15:23.081903 - -""" - -from alembic import op -import models as models -import sqlalchemy as sa - - -# revision identifiers, used by Alembic. -revision = "e8c3b3c46151" -down_revision = "788d3099ae3a" -branch_labels = None -depends_on = None - - -def upgrade(): - op.create_table( - "execution_extra_contents", - sa.Column("id", models.types.StringUUID(), nullable=False), - sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), - sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), - - sa.Column("type", sa.String(length=30), nullable=False), - sa.Column("workflow_run_id", models.types.StringUUID(), nullable=False), - sa.Column("message_id", models.types.StringUUID(), nullable=True), - sa.Column("form_id", models.types.StringUUID(), nullable=True), - sa.PrimaryKeyConstraint("id", name=op.f("execution_extra_contents_pkey")), - ) - with op.batch_alter_table("execution_extra_contents", schema=None) as batch_op: - batch_op.create_index(batch_op.f("execution_extra_contents_message_id_idx"), ["message_id"], unique=False) - batch_op.create_index( - batch_op.f("execution_extra_contents_workflow_run_id_idx"), ["workflow_run_id"], unique=False - ) - - op.create_table( - "human_input_form_deliveries", - sa.Column("id", models.types.StringUUID(), nullable=False), - sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), - sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), - - sa.Column("form_id", models.types.StringUUID(), nullable=False), - sa.Column("delivery_method_type", sa.String(length=20), nullable=False), - sa.Column("delivery_config_id", models.types.StringUUID(), nullable=True), - sa.Column("channel_payload", sa.Text(), nullable=False), - sa.PrimaryKeyConstraint("id", name=op.f("human_input_form_deliveries_pkey")), - ) - - op.create_table( - "human_input_form_recipients", - sa.Column("id", models.types.StringUUID(), nullable=False), - sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), - sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), - - sa.Column("form_id", models.types.StringUUID(), nullable=False), - sa.Column("delivery_id", models.types.StringUUID(), nullable=False), - sa.Column("recipient_type", sa.String(length=20), nullable=False), - sa.Column("recipient_payload", sa.Text(), nullable=False), - sa.Column("access_token", sa.VARCHAR(length=32), nullable=False), - sa.PrimaryKeyConstraint("id", name=op.f("human_input_form_recipients_pkey")), - ) - with op.batch_alter_table('human_input_form_recipients', schema=None) as batch_op: - batch_op.create_unique_constraint(batch_op.f('human_input_form_recipients_access_token_key'), ['access_token']) - - op.create_table( - "human_input_forms", - sa.Column("id", models.types.StringUUID(), nullable=False), - sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), - sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), - - sa.Column("tenant_id", models.types.StringUUID(), nullable=False), - sa.Column("app_id", models.types.StringUUID(), nullable=False), - sa.Column("workflow_run_id", models.types.StringUUID(), nullable=True), - sa.Column("form_kind", sa.String(length=20), nullable=False), - sa.Column("node_id", sa.String(length=60), nullable=False), - sa.Column("form_definition", sa.Text(), nullable=False), - sa.Column("rendered_content", sa.Text(), nullable=False), - sa.Column("status", sa.String(length=20), nullable=False), - sa.Column("expiration_time", sa.DateTime(), nullable=False), - sa.Column("selected_action_id", sa.String(length=200), nullable=True), - sa.Column("submitted_data", sa.Text(), nullable=True), - sa.Column("submitted_at", sa.DateTime(), nullable=True), - sa.Column("submission_user_id", models.types.StringUUID(), nullable=True), - sa.Column("submission_end_user_id", models.types.StringUUID(), nullable=True), - sa.Column("completed_by_recipient_id", models.types.StringUUID(), nullable=True), - - sa.PrimaryKeyConstraint("id", name=op.f("human_input_forms_pkey")), - ) - - -def downgrade(): - op.drop_table("human_input_forms") - op.drop_table("human_input_form_recipients") - op.drop_table("human_input_form_deliveries") - op.drop_table("execution_extra_contents") diff --git a/api/models/__init__.py b/api/models/__init__.py index 1d5d604ba7..74b33130ef 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -34,8 +34,6 @@ from .enums import ( WorkflowRunTriggeredFrom, WorkflowTriggerStatus, ) -from .execution_extra_content import ExecutionExtraContent, HumanInputContent -from .human_input import HumanInputForm from .model import ( AccountTrialAppRecord, ApiRequest, @@ -157,12 +155,9 @@ __all__ = [ "DocumentSegment", "Embedding", "EndUser", - "ExecutionExtraContent", "ExporleBanner", "ExternalKnowledgeApis", "ExternalKnowledgeBindings", - "HumanInputContent", - "HumanInputForm", "IconType", "InstalledApp", "InvitationCode", diff --git a/api/models/base.py b/api/models/base.py index aa93d31199..c8a5e20f25 100644 --- a/api/models/base.py +++ b/api/models/base.py @@ -41,7 +41,7 @@ class DefaultFieldsMixin: ) updated_at: Mapped[datetime] = mapped_column( - DateTime, + __name_pos=DateTime, nullable=False, default=naive_utc_now, server_default=func.current_timestamp(), diff --git a/api/models/dataset.py b/api/models/dataset.py index 6ab8f372bf..e7da2961bc 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -420,7 +420,7 @@ class Document(Base): doc_metadata = mapped_column(AdjustedJSON, nullable=True) doc_form = mapped_column(String(255), nullable=False, server_default=sa.text("'text_model'")) doc_language = mapped_column(String(255), nullable=True) - need_summary: Mapped[bool | None] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false")) + need_summary: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) DATA_SOURCES = ["upload_file", "notion_import", "website_crawl"] diff --git a/api/models/enums.py b/api/models/enums.py index 2bc61120ce..8cd3d4cf2a 100644 --- a/api/models/enums.py +++ b/api/models/enums.py @@ -36,7 +36,6 @@ class MessageStatus(StrEnum): """ NORMAL = "normal" - PAUSED = "paused" ERROR = "error" diff --git a/api/models/execution_extra_content.py b/api/models/execution_extra_content.py deleted file mode 100644 index d0bd34efec..0000000000 --- a/api/models/execution_extra_content.py +++ /dev/null @@ -1,78 +0,0 @@ -from enum import StrEnum, auto -from typing import TYPE_CHECKING - -from sqlalchemy.orm import Mapped, mapped_column, relationship - -from .base import Base, DefaultFieldsMixin -from .types import EnumText, StringUUID - -if TYPE_CHECKING: - from .human_input import HumanInputForm - - -class ExecutionContentType(StrEnum): - HUMAN_INPUT = auto() - - -class ExecutionExtraContent(DefaultFieldsMixin, Base): - """ExecutionExtraContent stores extra contents produced during workflow / chatflow execution.""" - - # The `ExecutionExtraContent` uses single table inheritance to model different - # kinds of contents produced during message generation. - # - # See: https://docs.sqlalchemy.org/en/20/orm/inheritance.html#single-table-inheritance - - __tablename__ = "execution_extra_contents" - __mapper_args__ = { - "polymorphic_abstract": True, - "polymorphic_on": "type", - "with_polymorphic": "*", - } - # type records the type of the content. It serves as the `discriminator` for the - # single table inheritance. - type: Mapped[ExecutionContentType] = mapped_column( - EnumText(ExecutionContentType, length=30), - nullable=False, - ) - - # `workflow_run_id` records the workflow execution which generates this content, correspond to - # `WorkflowRun.id`. - workflow_run_id: Mapped[str] = mapped_column(StringUUID, nullable=False, index=True) - - # `message_id` records the messages generated by the execution associated with this `ExecutionExtraContent`. - # It references to `Message.id`. - # - # For workflow execution, this field is `None`. - # - # For chatflow execution, `message_id`` is not None, and the following condition holds: - # - # The message referenced by `message_id` has `message.workflow_run_id == execution_extra_content.workflow_run_id` - # - message_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, index=True) - - -class HumanInputContent(ExecutionExtraContent): - """HumanInputContent is a concrete class that represents human input content. - It should only be initialized with the `new` class method.""" - - __mapper_args__ = { - "polymorphic_identity": ExecutionContentType.HUMAN_INPUT, - } - - # A relation to HumanInputForm table. - # - # While the form_id column is nullable in database (due to the nature of single table inheritance), - # the form_id field should not be null for a given `HumanInputContent` instance. - form_id: Mapped[str] = mapped_column(StringUUID, nullable=True) - - @classmethod - def new(cls, form_id: str, message_id: str | None) -> "HumanInputContent": - return cls(form_id=form_id, message_id=message_id) - - form: Mapped["HumanInputForm"] = relationship( - "HumanInputForm", - foreign_keys=[form_id], - uselist=False, - lazy="raise", - primaryjoin="foreign(HumanInputContent.form_id) == HumanInputForm.id", - ) diff --git a/api/models/human_input.py b/api/models/human_input.py deleted file mode 100644 index 5208461de1..0000000000 --- a/api/models/human_input.py +++ /dev/null @@ -1,237 +0,0 @@ -from datetime import datetime -from enum import StrEnum -from typing import Annotated, Literal, Self, final - -import sqlalchemy as sa -from pydantic import BaseModel, Field -from sqlalchemy.orm import Mapped, mapped_column, relationship - -from core.workflow.nodes.human_input.enums import ( - DeliveryMethodType, - HumanInputFormKind, - HumanInputFormStatus, -) -from libs.helper import generate_string - -from .base import Base, DefaultFieldsMixin -from .types import EnumText, StringUUID - -_token_length = 22 -# A 32-character string can store a base64-encoded value with 192 bits of entropy -# or a base62-encoded value with over 180 bits of entropy, providing sufficient -# uniqueness for most use cases. -_token_field_length = 32 -_email_field_length = 330 - - -def _generate_token() -> str: - return generate_string(_token_length) - - -class HumanInputForm(DefaultFieldsMixin, Base): - __tablename__ = "human_input_forms" - - tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - workflow_run_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) - form_kind: Mapped[HumanInputFormKind] = mapped_column( - EnumText(HumanInputFormKind), - nullable=False, - default=HumanInputFormKind.RUNTIME, - ) - - # The human input node the current form corresponds to. - node_id: Mapped[str] = mapped_column(sa.String(60), nullable=False) - form_definition: Mapped[str] = mapped_column(sa.Text, nullable=False) - rendered_content: Mapped[str] = mapped_column(sa.Text, nullable=False) - status: Mapped[HumanInputFormStatus] = mapped_column( - EnumText(HumanInputFormStatus), - nullable=False, - default=HumanInputFormStatus.WAITING, - ) - - expiration_time: Mapped[datetime] = mapped_column( - sa.DateTime, - nullable=False, - ) - - # Submission-related fields (nullable until a submission happens). - selected_action_id: Mapped[str | None] = mapped_column(sa.String(200), nullable=True) - submitted_data: Mapped[str | None] = mapped_column(sa.Text, nullable=True) - submitted_at: Mapped[datetime | None] = mapped_column(sa.DateTime, nullable=True) - submission_user_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) - submission_end_user_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) - - completed_by_recipient_id: Mapped[str | None] = mapped_column( - StringUUID, - nullable=True, - ) - - deliveries: Mapped[list["HumanInputDelivery"]] = relationship( - "HumanInputDelivery", - primaryjoin="HumanInputForm.id == foreign(HumanInputDelivery.form_id)", - uselist=True, - back_populates="form", - lazy="raise", - ) - completed_by_recipient: Mapped["HumanInputFormRecipient | None"] = relationship( - "HumanInputFormRecipient", - primaryjoin="HumanInputForm.completed_by_recipient_id == foreign(HumanInputFormRecipient.id)", - lazy="raise", - viewonly=True, - ) - - -class HumanInputDelivery(DefaultFieldsMixin, Base): - __tablename__ = "human_input_form_deliveries" - - form_id: Mapped[str] = mapped_column( - StringUUID, - nullable=False, - ) - delivery_method_type: Mapped[DeliveryMethodType] = mapped_column( - EnumText(DeliveryMethodType), - nullable=False, - ) - delivery_config_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) - channel_payload: Mapped[str] = mapped_column(sa.Text, nullable=False) - - form: Mapped[HumanInputForm] = relationship( - "HumanInputForm", - uselist=False, - foreign_keys=[form_id], - primaryjoin="HumanInputDelivery.form_id == HumanInputForm.id", - back_populates="deliveries", - lazy="raise", - ) - - recipients: Mapped[list["HumanInputFormRecipient"]] = relationship( - "HumanInputFormRecipient", - primaryjoin="HumanInputDelivery.id == foreign(HumanInputFormRecipient.delivery_id)", - uselist=True, - back_populates="delivery", - # Require explicit preloading - lazy="raise", - ) - - -class RecipientType(StrEnum): - # EMAIL_MEMBER member means that the - EMAIL_MEMBER = "email_member" - EMAIL_EXTERNAL = "email_external" - # STANDALONE_WEB_APP is used by the standalone web app. - # - # It's not used while running workflows / chatflows containing HumanInput - # node inside console. - STANDALONE_WEB_APP = "standalone_web_app" - # CONSOLE is used while running workflows / chatflows containing HumanInput - # node inside console. (E.G. running installed apps or debugging workflows / chatflows) - CONSOLE = "console" - # BACKSTAGE is used for backstage input inside console. - BACKSTAGE = "backstage" - - -@final -class EmailMemberRecipientPayload(BaseModel): - TYPE: Literal[RecipientType.EMAIL_MEMBER] = RecipientType.EMAIL_MEMBER - user_id: str - - # The `email` field here is only used for mail sending. - email: str - - -@final -class EmailExternalRecipientPayload(BaseModel): - TYPE: Literal[RecipientType.EMAIL_EXTERNAL] = RecipientType.EMAIL_EXTERNAL - email: str - - -@final -class StandaloneWebAppRecipientPayload(BaseModel): - TYPE: Literal[RecipientType.STANDALONE_WEB_APP] = RecipientType.STANDALONE_WEB_APP - - -@final -class ConsoleRecipientPayload(BaseModel): - TYPE: Literal[RecipientType.CONSOLE] = RecipientType.CONSOLE - account_id: str | None = None - - -@final -class BackstageRecipientPayload(BaseModel): - TYPE: Literal[RecipientType.BACKSTAGE] = RecipientType.BACKSTAGE - account_id: str | None = None - - -@final -class ConsoleDeliveryPayload(BaseModel): - type: Literal["console"] = "console" - internal: bool = True - - -RecipientPayload = Annotated[ - EmailMemberRecipientPayload - | EmailExternalRecipientPayload - | StandaloneWebAppRecipientPayload - | ConsoleRecipientPayload - | BackstageRecipientPayload, - Field(discriminator="TYPE"), -] - - -class HumanInputFormRecipient(DefaultFieldsMixin, Base): - __tablename__ = "human_input_form_recipients" - - form_id: Mapped[str] = mapped_column( - StringUUID, - nullable=False, - ) - delivery_id: Mapped[str] = mapped_column( - StringUUID, - nullable=False, - ) - recipient_type: Mapped["RecipientType"] = mapped_column(EnumText(RecipientType), nullable=False) - recipient_payload: Mapped[str] = mapped_column(sa.Text, nullable=False) - - # Token primarily used for authenticated resume links (email, etc.). - access_token: Mapped[str | None] = mapped_column( - sa.VARCHAR(_token_field_length), - nullable=False, - default=_generate_token, - unique=True, - ) - - delivery: Mapped[HumanInputDelivery] = relationship( - "HumanInputDelivery", - uselist=False, - foreign_keys=[delivery_id], - back_populates="recipients", - primaryjoin="HumanInputFormRecipient.delivery_id == HumanInputDelivery.id", - # Require explicit preloading - lazy="raise", - ) - - form: Mapped[HumanInputForm] = relationship( - "HumanInputForm", - uselist=False, - foreign_keys=[form_id], - primaryjoin="HumanInputFormRecipient.form_id == HumanInputForm.id", - # Require explicit preloading - lazy="raise", - ) - - @classmethod - def new( - cls, - form_id: str, - delivery_id: str, - payload: RecipientPayload, - ) -> Self: - recipient_model = cls( - form_id=form_id, - delivery_id=delivery_id, - recipient_type=payload.TYPE, - recipient_payload=payload.model_dump_json(), - access_token=_generate_token(), - ) - return recipient_model diff --git a/api/models/model.py b/api/models/model.py index c12362f359..c1c6e04ce9 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -3,7 +3,7 @@ from __future__ import annotations import json import re import uuid -from collections.abc import Mapping, Sequence +from collections.abc import Mapping from datetime import datetime from decimal import Decimal from enum import StrEnum, auto @@ -943,7 +943,6 @@ class Conversation(Base): WorkflowExecutionStatus.FAILED: 0, WorkflowExecutionStatus.STOPPED: 0, WorkflowExecutionStatus.PARTIAL_SUCCEEDED: 0, - WorkflowExecutionStatus.PAUSED: 0, } for message in messages: @@ -964,7 +963,6 @@ class Conversation(Base): "success": status_counts[WorkflowExecutionStatus.SUCCEEDED], "failed": status_counts[WorkflowExecutionStatus.FAILED], "partial_success": status_counts[WorkflowExecutionStatus.PARTIAL_SUCCEEDED], - "paused": status_counts[WorkflowExecutionStatus.PAUSED], } @property @@ -1347,14 +1345,6 @@ class Message(Base): db.session.commit() return result - # TODO(QuantumGhost): dirty hacks, fix this later. - def set_extra_contents(self, contents: Sequence[dict[str, Any]]) -> None: - self._extra_contents = list(contents) - - @property - def extra_contents(self) -> list[dict[str, Any]]: - return getattr(self, "_extra_contents", []) - @property def workflow_run(self): if self.workflow_run_id: diff --git a/api/models/workflow.py b/api/models/workflow.py index 94e0881bd1..83956b1114 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -20,7 +20,6 @@ from sqlalchemy import ( select, ) from sqlalchemy.orm import Mapped, declared_attr, mapped_column -from typing_extensions import deprecated from core.file.constants import maybe_file_object from core.file.models import File @@ -30,8 +29,9 @@ from core.workflow.constants import ( CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID, ) +from core.workflow.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from core.workflow.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause -from core.workflow.enums import NodeType, WorkflowExecutionStatus +from core.workflow.enums import NodeType from extensions.ext_storage import Storage from factories.variable_factory import TypeMismatchError, build_segment_with_type from libs.datetime_utils import naive_utc_now @@ -230,7 +230,7 @@ class Workflow(Base): # bug # - `_get_graph_and_variable_pool_for_single_node_run`. return json.loads(self.graph) if self.graph else {} - def get_node_config_by_id(self, node_id: str) -> Mapping[str, Any]: + def get_node_config_by_id(self, node_id: str) -> NodeConfigDict: """Extract a node configuration from the workflow graph by node ID. A node configuration is a dictionary containing the node's properties, including the node's id, title, and its data as a dict. @@ -248,8 +248,7 @@ class Workflow(Base): # bug node_config: dict[str, Any] = next(filter(lambda node: node["id"] == node_id, nodes)) except StopIteration: raise NodeNotFoundError(node_id) - assert isinstance(node_config, dict) - return node_config + return NodeConfigDictAdapter.validate_python(node_config) @staticmethod def get_node_type_from_node_config(node_config: Mapping[str, Any]) -> NodeType: @@ -406,11 +405,6 @@ class Workflow(Base): # bug return helper.generate_text_hash(json.dumps(entity, sort_keys=True)) @property - @deprecated( - "This property is not accurate for determining if a workflow is published as a tool." - "It only checks if there's a WorkflowToolProvider for the app, " - "not if this specific workflow version is the one being used by the tool." - ) def tool_published(self) -> bool: """ DEPRECATED: This property is not accurate for determining if a workflow is published as a tool. @@ -613,16 +607,13 @@ class WorkflowRun(Base): version: Mapped[str] = mapped_column(String(255)) graph: Mapped[str | None] = mapped_column(LongText) inputs: Mapped[str | None] = mapped_column(LongText) - status: Mapped[WorkflowExecutionStatus] = mapped_column( - EnumText(WorkflowExecutionStatus, length=255), - nullable=False, - ) + status: Mapped[str] = mapped_column(String(255)) # running, succeeded, failed, stopped, partial-succeeded outputs: Mapped[str | None] = mapped_column(LongText, default="{}") error: Mapped[str | None] = mapped_column(LongText) elapsed_time: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0")) total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0")) total_steps: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True) - created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255)) # account, end_user + created_by_role: Mapped[str] = mapped_column(String(255)) # account, end_user created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) finished_at: Mapped[datetime | None] = mapped_column(DateTime) @@ -638,13 +629,11 @@ class WorkflowRun(Base): ) @property - @deprecated("This method is retained for historical reasons; avoid using it if possible.") def created_by_account(self): created_by_role = CreatorUserRole(self.created_by_role) return db.session.get(Account, self.created_by) if created_by_role == CreatorUserRole.ACCOUNT else None @property - @deprecated("This method is retained for historical reasons; avoid using it if possible.") def created_by_end_user(self): from .model import EndUser @@ -664,7 +653,6 @@ class WorkflowRun(Base): return json.loads(self.outputs) if self.outputs else {} @property - @deprecated("This method is retained for historical reasons; avoid using it if possible.") def message(self): from .model import Message @@ -673,7 +661,6 @@ class WorkflowRun(Base): ) @property - @deprecated("This method is retained for historical reasons; avoid using it if possible.") def workflow(self): return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first() @@ -1874,12 +1861,7 @@ class WorkflowPauseReason(DefaultFieldsMixin, Base): def to_entity(self) -> PauseReason: if self.type_ == PauseReasonType.HUMAN_INPUT_REQUIRED: - return HumanInputRequired( - form_id=self.form_id, - form_content="", - node_id=self.node_id, - node_title="", - ) + return HumanInputRequired(form_id=self.form_id, node_id=self.node_id) elif self.type_ == PauseReasonType.SCHEDULED_PAUSE: return SchedulingPause(message=self.message) else: diff --git a/api/pyproject.toml b/api/pyproject.toml index af2dba6fac..ab1f523267 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -87,7 +87,7 @@ dependencies = [ "sseclient-py~=1.8.0", "httpx-sse~=0.4.0", "sendgrid~=6.12.3", - "flask-restx~=1.3.0", + "flask-restx~=1.3.2", "packaging~=23.2", "croniter>=6.0.0", "weaviate-client==4.17.0", @@ -116,7 +116,7 @@ dev = [ "dotenv-linter~=0.5.0", "faker~=38.2.0", "lxml-stubs~=0.5.1", - "ty~=0.0.1a19", + "ty>=0.0.14", "basedpyright~=1.31.0", "ruff~=0.14.0", "pytest~=8.3.2", @@ -145,7 +145,7 @@ dev = [ "types-openpyxl~=3.1.5", "types-pexpect~=4.9.0", "types-protobuf~=5.29.1", - "types-psutil~=7.0.0", + "types-psutil~=7.2.2", "types-psycopg2~=2.9.21", "types-pygments~=2.19.0", "types-pymysql~=1.1.0", diff --git a/api/repositories/api_workflow_node_execution_repository.py b/api/repositories/api_workflow_node_execution_repository.py index 6446eb0d6e..5b3f635301 100644 --- a/api/repositories/api_workflow_node_execution_repository.py +++ b/api/repositories/api_workflow_node_execution_repository.py @@ -10,7 +10,6 @@ tenant_id, app_id, triggered_from, etc., which are not part of the core domain m """ from collections.abc import Sequence -from dataclasses import dataclass from datetime import datetime from typing import Protocol @@ -20,27 +19,6 @@ from core.workflow.repositories.workflow_node_execution_repository import Workfl from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload -@dataclass(frozen=True) -class WorkflowNodeExecutionSnapshot: - """ - Minimal snapshot of workflow node execution for stream recovery. - - Only includes fields required by snapshot events. - """ - - execution_id: str # Unique execution identifier (node_execution_id or row id). - node_id: str # Workflow graph node id. - node_type: str # Workflow graph node type (e.g. "human-input"). - title: str # Human-friendly node title. - index: int # Execution order index within the workflow run. - status: str # Execution status (running/succeeded/failed/paused). - elapsed_time: float # Execution elapsed time in seconds. - created_at: datetime # Execution created timestamp. - finished_at: datetime | None # Execution finished timestamp. - iteration_id: str | None = None # Iteration id from execution metadata, if any. - loop_id: str | None = None # Loop id from execution metadata, if any. - - class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Protocol): """ Protocol for service-layer operations on WorkflowNodeExecutionModel. @@ -101,8 +79,6 @@ class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Pr Args: tenant_id: The tenant identifier app_id: The application identifier - workflow_id: The workflow identifier - triggered_from: The workflow trigger source workflow_run_id: The workflow run identifier Returns: @@ -110,27 +86,6 @@ class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Pr """ ... - def get_execution_snapshots_by_workflow_run( - self, - tenant_id: str, - app_id: str, - workflow_id: str, - triggered_from: str, - workflow_run_id: str, - ) -> Sequence[WorkflowNodeExecutionSnapshot]: - """ - Get minimal snapshots for node executions in a workflow run. - - Args: - tenant_id: The tenant identifier - app_id: The application identifier - workflow_run_id: The workflow run identifier - - Returns: - A sequence of WorkflowNodeExecutionSnapshot ordered by creation time - """ - ... - def get_execution_by_id( self, execution_id: str, diff --git a/api/repositories/api_workflow_run_repository.py b/api/repositories/api_workflow_run_repository.py index 17e01a6e18..1d3954571f 100644 --- a/api/repositories/api_workflow_run_repository.py +++ b/api/repositories/api_workflow_run_repository.py @@ -432,13 +432,6 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol): # while creating pause. ... - def get_workflow_pause(self, workflow_run_id: str) -> WorkflowPauseEntity | None: - """Retrieve the current pause for a workflow execution. - - If there is no current pause, this method would return `None`. - """ - ... - def resume_workflow_pause( self, workflow_run_id: str, @@ -634,19 +627,3 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol): [{"date": "2024-01-01", "interactions": 2.5}, ...] """ ... - - def get_workflow_run_by_id_and_tenant_id(self, tenant_id: str, run_id: str) -> WorkflowRun | None: - """ - Get a specific workflow run by its id and the associated tenant id. - - This function does not apply application isolation. It should only be used when - the application identifier is not available. - - Args: - tenant_id: Tenant identifier for multi-tenant isolation - run_id: Workflow run identifier - - Returns: - WorkflowRun object if found, None otherwise - """ - ... diff --git a/api/repositories/entities/workflow_pause.py b/api/repositories/entities/workflow_pause.py index a3c4039aaa..b970f39816 100644 --- a/api/repositories/entities/workflow_pause.py +++ b/api/repositories/entities/workflow_pause.py @@ -63,12 +63,6 @@ class WorkflowPauseEntity(ABC): """ pass - @property - @abstractmethod - def paused_at(self) -> datetime: - """`paused_at` returns the creation time of the pause.""" - pass - @abstractmethod def get_pause_reasons(self) -> Sequence[PauseReason]: """ @@ -76,5 +70,7 @@ class WorkflowPauseEntity(ABC): Returns a sequence of `PauseReason` objects describing the specific nodes and reasons for which the workflow execution was paused. + This information is related to, but distinct from, the `PauseReason` type + defined in `api/core/workflow/entities/pause_reason.py`. """ ... diff --git a/api/repositories/execution_extra_content_repository.py b/api/repositories/execution_extra_content_repository.py deleted file mode 100644 index 72b5443d2c..0000000000 --- a/api/repositories/execution_extra_content_repository.py +++ /dev/null @@ -1,13 +0,0 @@ -from __future__ import annotations - -from collections.abc import Sequence -from typing import Protocol - -from core.entities.execution_extra_content import ExecutionExtraContentDomainModel - - -class ExecutionExtraContentRepository(Protocol): - def get_by_message_ids(self, message_ids: Sequence[str]) -> list[list[ExecutionExtraContentDomainModel]]: ... - - -__all__ = ["ExecutionExtraContentRepository"] diff --git a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py index 6c696b6478..b19cc73bd1 100644 --- a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py @@ -5,7 +5,6 @@ This module provides a concrete implementation of the service repository protoco using SQLAlchemy 2.0 style queries for WorkflowNodeExecutionModel operations. """ -import json from collections.abc import Sequence from datetime import datetime from typing import cast @@ -14,12 +13,11 @@ from sqlalchemy import asc, delete, desc, func, select from sqlalchemy.engine import CursorResult from sqlalchemy.orm import Session, sessionmaker -from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload -from repositories.api_workflow_node_execution_repository import ( - DifyAPIWorkflowNodeExecutionRepository, - WorkflowNodeExecutionSnapshot, +from models.workflow import ( + WorkflowNodeExecutionModel, + WorkflowNodeExecutionOffload, ) +from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRepository): @@ -81,7 +79,6 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut WorkflowNodeExecutionModel.app_id == app_id, WorkflowNodeExecutionModel.workflow_id == workflow_id, WorkflowNodeExecutionModel.node_id == node_id, - WorkflowNodeExecutionModel.status != WorkflowNodeExecutionStatus.PAUSED, ) .order_by(desc(WorkflowNodeExecutionModel.created_at)) .limit(1) @@ -120,80 +117,6 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut with self._session_maker() as session: return session.execute(stmt).scalars().all() - def get_execution_snapshots_by_workflow_run( - self, - tenant_id: str, - app_id: str, - workflow_id: str, - triggered_from: str, - workflow_run_id: str, - ) -> Sequence[WorkflowNodeExecutionSnapshot]: - stmt = ( - select( - WorkflowNodeExecutionModel.id, - WorkflowNodeExecutionModel.node_execution_id, - WorkflowNodeExecutionModel.node_id, - WorkflowNodeExecutionModel.node_type, - WorkflowNodeExecutionModel.title, - WorkflowNodeExecutionModel.index, - WorkflowNodeExecutionModel.status, - WorkflowNodeExecutionModel.elapsed_time, - WorkflowNodeExecutionModel.created_at, - WorkflowNodeExecutionModel.finished_at, - WorkflowNodeExecutionModel.execution_metadata, - ) - .where( - WorkflowNodeExecutionModel.tenant_id == tenant_id, - WorkflowNodeExecutionModel.app_id == app_id, - WorkflowNodeExecutionModel.workflow_id == workflow_id, - WorkflowNodeExecutionModel.triggered_from == triggered_from, - WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id, - ) - .order_by( - asc(WorkflowNodeExecutionModel.created_at), - asc(WorkflowNodeExecutionModel.index), - ) - ) - - with self._session_maker() as session: - rows = session.execute(stmt).all() - - return [self._row_to_snapshot(row) for row in rows] - - @staticmethod - def _row_to_snapshot(row: object) -> WorkflowNodeExecutionSnapshot: - metadata: dict[str, object] = {} - execution_metadata = getattr(row, "execution_metadata", None) - if execution_metadata: - try: - metadata = json.loads(execution_metadata) - except json.JSONDecodeError: - metadata = {} - iteration_id = metadata.get(WorkflowNodeExecutionMetadataKey.ITERATION_ID.value) - loop_id = metadata.get(WorkflowNodeExecutionMetadataKey.LOOP_ID.value) - execution_id = getattr(row, "node_execution_id", None) or row.id - elapsed_time = getattr(row, "elapsed_time", None) - created_at = row.created_at - finished_at = getattr(row, "finished_at", None) - if elapsed_time is None: - if finished_at is not None and created_at is not None: - elapsed_time = (finished_at - created_at).total_seconds() - else: - elapsed_time = 0.0 - return WorkflowNodeExecutionSnapshot( - execution_id=str(execution_id), - node_id=row.node_id, - node_type=row.node_type, - title=row.title, - index=row.index, - status=row.status, - elapsed_time=float(elapsed_time), - created_at=created_at, - finished_at=finished_at, - iteration_id=str(iteration_id) if iteration_id else None, - loop_id=str(loop_id) if loop_id else None, - ) - def get_execution_by_id( self, execution_id: str, diff --git a/api/repositories/sqlalchemy_api_workflow_run_repository.py b/api/repositories/sqlalchemy_api_workflow_run_repository.py index 00cb979e17..d5214be042 100644 --- a/api/repositories/sqlalchemy_api_workflow_run_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py @@ -19,7 +19,6 @@ Implementation Notes: - Maintains data consistency with proper transaction handling """ -import json import logging import uuid from collections.abc import Callable, Sequence @@ -28,14 +27,12 @@ from decimal import Decimal from typing import Any, cast import sqlalchemy as sa -from pydantic import ValidationError from sqlalchemy import and_, delete, func, null, or_, select from sqlalchemy.engine import CursorResult from sqlalchemy.orm import Session, selectinload, sessionmaker -from core.workflow.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause +from core.workflow.entities.pause_reason import HumanInputRequired, PauseReason, SchedulingPause from core.workflow.enums import WorkflowExecutionStatus, WorkflowType -from core.workflow.nodes.human_input.entities import FormDefinition from extensions.ext_storage import storage from libs.datetime_utils import naive_utc_now from libs.helper import convert_datetime_to_date @@ -43,7 +40,6 @@ from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.time_parser import get_time_threshold from libs.uuid_utils import uuidv7 from models.enums import WorkflowRunTriggeredFrom -from models.human_input import HumanInputForm, HumanInputFormRecipient, RecipientType from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.entities.workflow_pause import WorkflowPauseEntity @@ -61,67 +57,6 @@ class _WorkflowRunError(Exception): pass -def _select_recipient_token( - recipients: Sequence[HumanInputFormRecipient], - recipient_type: RecipientType, -) -> str | None: - for recipient in recipients: - if recipient.recipient_type == recipient_type and recipient.access_token: - return recipient.access_token - return None - - -def _build_human_input_required_reason( - reason_model: WorkflowPauseReason, - form_model: HumanInputForm | None, - recipients: Sequence[HumanInputFormRecipient], -) -> HumanInputRequired: - form_content = "" - inputs = [] - actions = [] - display_in_ui = False - resolved_default_values: dict[str, Any] = {} - node_title = "Human Input" - form_id = reason_model.form_id - node_id = reason_model.node_id - if form_model is not None: - form_id = form_model.id - node_id = form_model.node_id or node_id - try: - definition_payload = json.loads(form_model.form_definition) - if "expiration_time" not in definition_payload: - definition_payload["expiration_time"] = form_model.expiration_time - definition = FormDefinition.model_validate(definition_payload) - except ValidationError: - definition = None - - if definition is not None: - form_content = definition.form_content - inputs = list(definition.inputs) - actions = list(definition.user_actions) - display_in_ui = bool(definition.display_in_ui) - resolved_default_values = dict(definition.default_values) - node_title = definition.node_title or node_title - - form_token = ( - _select_recipient_token(recipients, RecipientType.BACKSTAGE) - or _select_recipient_token(recipients, RecipientType.CONSOLE) - or _select_recipient_token(recipients, RecipientType.STANDALONE_WEB_APP) - ) - - return HumanInputRequired( - form_id=form_id, - form_content=form_content, - inputs=inputs, - actions=actions, - display_in_ui=display_in_ui, - node_id=node_id, - node_title=node_title, - form_token=form_token, - resolved_default_values=resolved_default_values, - ) - - class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): """ SQLAlchemy implementation of APIWorkflowRunRepository. @@ -741,11 +676,9 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): raise ValueError(f"WorkflowRun not found: {workflow_run_id}") # Check if workflow is in RUNNING status - # TODO(QuantumGhost): It seems that the persistence of `WorkflowRun.status` - # happens before the execution of GraphLayer - if workflow_run.status not in {WorkflowExecutionStatus.RUNNING, WorkflowExecutionStatus.PAUSED}: + if workflow_run.status != WorkflowExecutionStatus.RUNNING: raise _WorkflowRunError( - f"Only WorkflowRun with RUNNING or PAUSED status can be paused, " + f"Only WorkflowRun with RUNNING status can be paused, " f"workflow_run_id={workflow_run_id}, current_status={workflow_run.status}" ) # @@ -796,48 +729,13 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): logger.info("Created workflow pause %s for workflow run %s", pause_model.id, workflow_run_id) - return _PrivateWorkflowPauseEntity( - pause_model=pause_model, - reason_models=pause_reason_models, - pause_reasons=pause_reasons, - ) + return _PrivateWorkflowPauseEntity(pause_model=pause_model, reason_models=pause_reason_models) def _get_reasons_by_pause_id(self, session: Session, pause_id: str): reason_stmt = select(WorkflowPauseReason).where(WorkflowPauseReason.pause_id == pause_id) pause_reason_models = session.scalars(reason_stmt).all() return pause_reason_models - def _hydrate_pause_reasons( - self, - session: Session, - pause_reason_models: Sequence[WorkflowPauseReason], - ) -> list[PauseReason]: - form_ids = [ - reason.form_id - for reason in pause_reason_models - if reason.type_ == PauseReasonType.HUMAN_INPUT_REQUIRED and reason.form_id - ] - form_models: dict[str, HumanInputForm] = {} - recipient_models_by_form: dict[str, list[HumanInputFormRecipient]] = {} - if form_ids: - form_stmt = select(HumanInputForm).where(HumanInputForm.id.in_(form_ids)) - for form in session.scalars(form_stmt).all(): - form_models[form.id] = form - - recipient_stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids)) - for recipient in session.scalars(recipient_stmt).all(): - recipient_models_by_form.setdefault(recipient.form_id, []).append(recipient) - - pause_reasons: list[PauseReason] = [] - for reason in pause_reason_models: - if reason.type_ == PauseReasonType.HUMAN_INPUT_REQUIRED: - form_model = form_models.get(reason.form_id) - recipients = recipient_models_by_form.get(reason.form_id, []) - pause_reasons.append(_build_human_input_required_reason(reason, form_model, recipients)) - else: - pause_reasons.append(reason.to_entity()) - return pause_reasons - def get_workflow_pause( self, workflow_run_id: str, @@ -869,12 +767,14 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): if pause_model is None: return None pause_reason_models = self._get_reasons_by_pause_id(session, pause_model.id) - pause_reasons = self._hydrate_pause_reasons(session, pause_reason_models) + + human_input_form: list[Any] = [] + # TODO(QuantumGhost): query human_input_forms model and rebuild PauseReason return _PrivateWorkflowPauseEntity( pause_model=pause_model, reason_models=pause_reason_models, - pause_reasons=pause_reasons, + human_input_form=human_input_form, ) def resume_workflow_pause( @@ -928,10 +828,10 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): raise _WorkflowRunError(f"Cannot resume an already resumed pause, pause_id={pause_model.id}") pause_reasons = self._get_reasons_by_pause_id(session, pause_model.id) - hydrated_pause_reasons = self._hydrate_pause_reasons(session, pause_reasons) # Mark as resumed pause_model.resumed_at = naive_utc_now() + workflow_run.pause_id = None # type: ignore workflow_run.status = WorkflowExecutionStatus.RUNNING session.add(pause_model) @@ -939,11 +839,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): logger.info("Resumed workflow pause %s for workflow run %s", pause_model.id, workflow_run_id) - return _PrivateWorkflowPauseEntity( - pause_model=pause_model, - reason_models=pause_reasons, - pause_reasons=hydrated_pause_reasons, - ) + return _PrivateWorkflowPauseEntity(pause_model=pause_model, reason_models=pause_reasons) def delete_workflow_pause( self, @@ -1269,15 +1165,6 @@ GROUP BY return cast(list[AverageInteractionStats], response_data) - def get_workflow_run_by_id_and_tenant_id(self, tenant_id: str, run_id: str) -> WorkflowRun | None: - """Get a specific workflow run by its id and the associated tenant id.""" - with self._session_maker() as session: - stmt = select(WorkflowRun).where( - WorkflowRun.tenant_id == tenant_id, - WorkflowRun.id == run_id, - ) - return session.scalar(stmt) - class _PrivateWorkflowPauseEntity(WorkflowPauseEntity): """ @@ -1292,12 +1179,10 @@ class _PrivateWorkflowPauseEntity(WorkflowPauseEntity): *, pause_model: WorkflowPause, reason_models: Sequence[WorkflowPauseReason], - pause_reasons: Sequence[PauseReason] | None = None, human_input_form: Sequence = (), ) -> None: self._pause_model = pause_model self._reason_models = reason_models - self._pause_reasons = pause_reasons self._cached_state: bytes | None = None self._human_input_form = human_input_form @@ -1334,10 +1219,4 @@ class _PrivateWorkflowPauseEntity(WorkflowPauseEntity): return self._pause_model.resumed_at def get_pause_reasons(self) -> Sequence[PauseReason]: - if self._pause_reasons is not None: - return list(self._pause_reasons) return [reason.to_entity() for reason in self._reason_models] - - @property - def paused_at(self) -> datetime: - return self._pause_model.created_at diff --git a/api/repositories/sqlalchemy_execution_extra_content_repository.py b/api/repositories/sqlalchemy_execution_extra_content_repository.py deleted file mode 100644 index 5a2c0ea46f..0000000000 --- a/api/repositories/sqlalchemy_execution_extra_content_repository.py +++ /dev/null @@ -1,200 +0,0 @@ -from __future__ import annotations - -import json -import logging -import re -from collections import defaultdict -from collections.abc import Sequence -from typing import Any - -from sqlalchemy import select -from sqlalchemy.orm import Session, selectinload, sessionmaker - -from core.entities.execution_extra_content import ( - ExecutionExtraContentDomainModel, - HumanInputFormDefinition, - HumanInputFormSubmissionData, -) -from core.entities.execution_extra_content import ( - HumanInputContent as HumanInputContentDomainModel, -) -from core.workflow.nodes.human_input.entities import FormDefinition -from core.workflow.nodes.human_input.enums import HumanInputFormStatus -from core.workflow.nodes.human_input.human_input_node import HumanInputNode -from models.execution_extra_content import ( - ExecutionExtraContent as ExecutionExtraContentModel, -) -from models.execution_extra_content import ( - HumanInputContent as HumanInputContentModel, -) -from models.human_input import HumanInputFormRecipient, RecipientType -from repositories.execution_extra_content_repository import ExecutionExtraContentRepository - -logger = logging.getLogger(__name__) - -_OUTPUT_VARIABLE_PATTERN = re.compile(r"\{\{#\$output\.(?P[a-zA-Z_][a-zA-Z0-9_]{0,29})#\}\}") - - -def _extract_output_field_names(form_content: str) -> list[str]: - if not form_content: - return [] - return [match.group("field_name") for match in _OUTPUT_VARIABLE_PATTERN.finditer(form_content)] - - -class SQLAlchemyExecutionExtraContentRepository(ExecutionExtraContentRepository): - def __init__(self, session_maker: sessionmaker[Session]): - self._session_maker = session_maker - - def get_by_message_ids(self, message_ids: Sequence[str]) -> list[list[ExecutionExtraContentDomainModel]]: - if not message_ids: - return [] - - grouped_contents: dict[str, list[ExecutionExtraContentDomainModel]] = { - message_id: [] for message_id in message_ids - } - - stmt = ( - select(ExecutionExtraContentModel) - .where(ExecutionExtraContentModel.message_id.in_(message_ids)) - .options(selectinload(HumanInputContentModel.form)) - .order_by(ExecutionExtraContentModel.created_at.asc()) - ) - - with self._session_maker() as session: - results = session.scalars(stmt).all() - - form_ids = { - content.form_id - for content in results - if isinstance(content, HumanInputContentModel) and content.form_id is not None - } - recipients_by_form_id: dict[str, list[HumanInputFormRecipient]] = defaultdict(list) - if form_ids: - recipient_stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids)) - recipients = session.scalars(recipient_stmt).all() - for recipient in recipients: - recipients_by_form_id[recipient.form_id].append(recipient) - else: - recipients_by_form_id = {} - - for content in results: - message_id = content.message_id - if not message_id or message_id not in grouped_contents: - continue - - domain_model = self._map_model_to_domain(content, recipients_by_form_id) - if domain_model is None: - continue - - grouped_contents[message_id].append(domain_model) - - return [grouped_contents[message_id] for message_id in message_ids] - - def _map_model_to_domain( - self, - model: ExecutionExtraContentModel, - recipients_by_form_id: dict[str, list[HumanInputFormRecipient]], - ) -> ExecutionExtraContentDomainModel | None: - if isinstance(model, HumanInputContentModel): - return self._map_human_input_content(model, recipients_by_form_id) - - logger.debug("Unsupported execution extra content type encountered: %s", model.type) - return None - - def _map_human_input_content( - self, - model: HumanInputContentModel, - recipients_by_form_id: dict[str, list[HumanInputFormRecipient]], - ) -> HumanInputContentDomainModel | None: - form = model.form - if form is None: - logger.warning("HumanInputContent(id=%s) has no associated form loaded", model.id) - return None - - try: - definition_payload = json.loads(form.form_definition) - if "expiration_time" not in definition_payload: - definition_payload["expiration_time"] = form.expiration_time - form_definition = FormDefinition.model_validate(definition_payload) - except ValueError: - logger.warning("Failed to load form definition for HumanInputContent(id=%s)", model.id) - return None - node_title = form_definition.node_title or form.node_id - display_in_ui = bool(form_definition.display_in_ui) - - submitted = form.submitted_at is not None or form.status == HumanInputFormStatus.SUBMITTED - if not submitted: - form_token = self._resolve_form_token(recipients_by_form_id.get(form.id, [])) - return HumanInputContentDomainModel( - workflow_run_id=model.workflow_run_id, - submitted=False, - form_definition=HumanInputFormDefinition( - form_id=form.id, - node_id=form.node_id, - node_title=node_title, - form_content=form.rendered_content, - inputs=form_definition.inputs, - actions=form_definition.user_actions, - display_in_ui=display_in_ui, - form_token=form_token, - resolved_default_values=form_definition.default_values, - expiration_time=int(form.expiration_time.timestamp()), - ), - ) - - selected_action_id = form.selected_action_id - if not selected_action_id: - logger.warning("HumanInputContent(id=%s) form has no selected action", model.id) - return None - - action_text = next( - (action.title for action in form_definition.user_actions if action.id == selected_action_id), - selected_action_id, - ) - - submitted_data: dict[str, Any] = {} - if form.submitted_data: - try: - submitted_data = json.loads(form.submitted_data) - except ValueError: - logger.warning("Failed to load submitted data for HumanInputContent(id=%s)", model.id) - return None - - rendered_content = HumanInputNode.render_form_content_with_outputs( - form.rendered_content, - submitted_data, - _extract_output_field_names(form_definition.form_content), - ) - - return HumanInputContentDomainModel( - workflow_run_id=model.workflow_run_id, - submitted=True, - form_submission_data=HumanInputFormSubmissionData( - node_id=form.node_id, - node_title=node_title, - rendered_content=rendered_content, - action_id=selected_action_id, - action_text=action_text, - ), - ) - - @staticmethod - def _resolve_form_token(recipients: Sequence[HumanInputFormRecipient]) -> str | None: - console_recipient = next( - (recipient for recipient in recipients if recipient.recipient_type == RecipientType.CONSOLE), - None, - ) - if console_recipient and console_recipient.access_token: - return console_recipient.access_token - - web_app_recipient = next( - (recipient for recipient in recipients if recipient.recipient_type == RecipientType.STANDALONE_WEB_APP), - None, - ) - if web_app_recipient and web_app_recipient.access_token: - return web_app_recipient.access_token - - return None - - -__all__ = ["SQLAlchemyExecutionExtraContentRepository"] diff --git a/api/repositories/sqlalchemy_workflow_trigger_log_repository.py b/api/repositories/sqlalchemy_workflow_trigger_log_repository.py index 1f6740b066..f3dc4cd60b 100644 --- a/api/repositories/sqlalchemy_workflow_trigger_log_repository.py +++ b/api/repositories/sqlalchemy_workflow_trigger_log_repository.py @@ -92,16 +92,6 @@ class SQLAlchemyWorkflowTriggerLogRepository(WorkflowTriggerLogRepository): return list(self.session.scalars(query).all()) - def get_by_workflow_run_id(self, workflow_run_id: str) -> WorkflowTriggerLog | None: - """Get the trigger log associated with a workflow run.""" - query = ( - select(WorkflowTriggerLog) - .where(WorkflowTriggerLog.workflow_run_id == workflow_run_id) - .order_by(WorkflowTriggerLog.created_at.desc()) - .limit(1) - ) - return self.session.scalar(query) - def delete_by_run_ids(self, run_ids: Sequence[str]) -> int: """ Delete trigger logs associated with the given workflow run ids. diff --git a/api/repositories/workflow_trigger_log_repository.py b/api/repositories/workflow_trigger_log_repository.py index 7f9e6b7b68..b0009e398d 100644 --- a/api/repositories/workflow_trigger_log_repository.py +++ b/api/repositories/workflow_trigger_log_repository.py @@ -110,18 +110,6 @@ class WorkflowTriggerLogRepository(Protocol): """ ... - def get_by_workflow_run_id(self, workflow_run_id: str) -> WorkflowTriggerLog | None: - """ - Retrieve a trigger log associated with a specific workflow run. - - Args: - workflow_run_id: Identifier of the workflow run - - Returns: - The matching WorkflowTriggerLog if present, None otherwise - """ - ... - def delete_by_run_ids(self, run_ids: Sequence[str]) -> int: """ Delete trigger logs for workflow run IDs. diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index 56e9cc6a00..8ebc87a670 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -158,7 +158,7 @@ class AppAnnotationService: .order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc()) ) annotations = db.paginate(select=stmt, page=page, per_page=limit, max_per_page=100, error_out=False) - return annotations.items, annotations.total + return annotations.items, annotations.total or 0 @classmethod def export_annotation_list_by_app_id(cls, app_id: str): @@ -524,7 +524,7 @@ class AppAnnotationService: annotation_hit_histories = db.paginate( select=stmt, page=page, per_page=limit, max_per_page=100, error_out=False ) - return annotation_hit_histories.items, annotation_hit_histories.total + return annotation_hit_histories.items, annotation_hit_histories.total or 0 @classmethod def get_annotation_by_id(cls, annotation_id: str) -> MessageAnnotation | None: diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 9400362605..0f42c99246 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -44,7 +44,7 @@ IMPORT_INFO_REDIS_KEY_PREFIX = "app_import_info:" CHECK_DEPENDENCIES_REDIS_KEY_PREFIX = "app_check_dependencies:" IMPORT_INFO_REDIS_EXPIRY = 10 * 60 # 10 minutes DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB -CURRENT_DSL_VERSION = "0.6.0" +CURRENT_DSL_VERSION = "0.5.0" class ImportMode(StrEnum): diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py index a3de046d99..ce85f2e914 100644 --- a/api/services/app_generate_service.py +++ b/api/services/app_generate_service.py @@ -1,9 +1,7 @@ from __future__ import annotations -import logging -import threading import uuid -from collections.abc import Callable, Generator, Mapping +from collections.abc import Generator, Mapping from typing import TYPE_CHECKING, Any, Union from configs import dify_config @@ -11,61 +9,22 @@ from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator from core.app.apps.chat.app_generator import ChatAppGenerator from core.app.apps.completion.app_generator import CompletionAppGenerator -from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.workflow.app_generator import WorkflowAppGenerator from core.app.entities.app_invoke_entities import InvokeFrom from core.app.features.rate_limiting import RateLimit -from core.app.features.rate_limiting.rate_limit import rate_limit_context from enums.quota_type import QuotaType, unlimited from extensions.otel import AppGenerateHandler, trace_span from models.model import Account, App, AppMode, EndUser -from models.workflow import Workflow, WorkflowRun +from models.workflow import Workflow from services.errors.app import QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError from services.errors.llm import InvokeRateLimitError from services.workflow_service import WorkflowService -from tasks.app_generate.workflow_execute_task import AppExecutionParams, workflow_based_app_execution_task - -logger = logging.getLogger(__name__) - -SSE_TASK_START_FALLBACK_MS = 200 if TYPE_CHECKING: from controllers.console.app.workflow import LoopNodeRunPayload class AppGenerateService: - @staticmethod - def _build_streaming_task_on_subscribe(start_task: Callable[[], None]) -> Callable[[], None]: - started = False - lock = threading.Lock() - - def _try_start() -> bool: - nonlocal started - with lock: - if started: - return True - try: - start_task() - except Exception: - logger.exception("Failed to enqueue streaming task") - return False - started = True - return True - - # XXX(QuantumGhost): dirty hacks to avoid a race between publisher and SSE subscriber. - # The Celery task may publish the first event before the API side actually subscribes, - # causing an "at most once" drop with Redis Pub/Sub. We start the task on subscribe, - # but also use a short fallback timer so the task still runs if the client never consumes. - timer = threading.Timer(SSE_TASK_START_FALLBACK_MS / 1000.0, _try_start) - timer.daemon = True - timer.start() - - def _on_subscribe() -> None: - if _try_start(): - timer.cancel() - - return _on_subscribe - @classmethod @trace_span(AppGenerateHandler) def generate( @@ -129,29 +88,15 @@ class AppGenerateService: elif app_model.mode == AppMode.ADVANCED_CHAT: workflow_id = args.get("workflow_id") workflow = cls._get_workflow(app_model, invoke_from, workflow_id) - with rate_limit_context(rate_limit, request_id): - payload = AppExecutionParams.new( - app_model=app_model, - workflow=workflow, - user=user, - args=args, - invoke_from=invoke_from, - streaming=streaming, - call_depth=0, - ) - payload_json = payload.model_dump_json() - - def on_subscribe(): - workflow_based_app_execution_task.delay(payload_json) - - on_subscribe = cls._build_streaming_task_on_subscribe(on_subscribe) - generator = AdvancedChatAppGenerator() return rate_limit.generate( - generator.convert_to_event_stream( - generator.retrieve_events( - AppMode.ADVANCED_CHAT, - payload.workflow_run_id, - on_subscribe=on_subscribe, + AdvancedChatAppGenerator.convert_to_event_stream( + AdvancedChatAppGenerator().generate( + app_model=app_model, + workflow=workflow, + user=user, + args=args, + invoke_from=invoke_from, + streaming=streaming, ), ), request_id=request_id, @@ -159,36 +104,6 @@ class AppGenerateService: elif app_model.mode == AppMode.WORKFLOW: workflow_id = args.get("workflow_id") workflow = cls._get_workflow(app_model, invoke_from, workflow_id) - if streaming: - with rate_limit_context(rate_limit, request_id): - payload = AppExecutionParams.new( - app_model=app_model, - workflow=workflow, - user=user, - args=args, - invoke_from=invoke_from, - streaming=True, - call_depth=0, - root_node_id=root_node_id, - workflow_run_id=str(uuid.uuid4()), - ) - payload_json = payload.model_dump_json() - - def on_subscribe(): - workflow_based_app_execution_task.delay(payload_json) - - on_subscribe = cls._build_streaming_task_on_subscribe(on_subscribe) - return rate_limit.generate( - WorkflowAppGenerator.convert_to_event_stream( - MessageBasedAppGenerator.retrieve_events( - AppMode.WORKFLOW, - payload.workflow_run_id, - on_subscribe=on_subscribe, - ), - ), - request_id, - ) - return rate_limit.generate( WorkflowAppGenerator.convert_to_event_stream( WorkflowAppGenerator().generate( @@ -197,7 +112,7 @@ class AppGenerateService: user=user, args=args, invoke_from=invoke_from, - streaming=False, + streaming=streaming, root_node_id=root_node_id, call_depth=0, ), @@ -333,19 +248,3 @@ class AppGenerateService: raise ValueError("Workflow not published") return workflow - - @classmethod - def get_response_generator( - cls, - app_model: App, - workflow_run: WorkflowRun, - ): - if workflow_run.status.is_ended(): - # TODO(QuantumGhost): handled the ended scenario. - pass - - generator = AdvancedChatAppGenerator() - - return generator.convert_to_event_stream( - generator.retrieve_events(AppMode(app_model.mode), workflow_run.id), - ) diff --git a/api/services/audio_service.py b/api/services/audio_service.py index a95361cebd..41ee9c88aa 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -136,7 +136,7 @@ class AudioService: message = db.session.query(Message).where(Message.id == message_id).first() if message is None: return None - if message.answer == "" and message.status in {MessageStatus.NORMAL, MessageStatus.PAUSED}: + if message.answer == "" and message.status == MessageStatus.NORMAL: return None else: diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 0b3fcbe4ae..1ea6c4e1c3 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -16,6 +16,7 @@ from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden, NotFound from configs import dify_config +from core.db.session_factory import session_factory from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.file import helpers as file_helpers from core.helper.name_generator import generate_incremental_name @@ -1388,6 +1389,46 @@ class DocumentService: ).all() return documents + @staticmethod + def update_documents_need_summary(dataset_id: str, document_ids: Sequence[str], need_summary: bool = True) -> int: + """ + Update need_summary field for multiple documents. + + This method handles the case where documents were created when summary_index_setting was disabled, + and need to be updated when summary_index_setting is later enabled. + + Args: + dataset_id: Dataset ID + document_ids: List of document IDs to update + need_summary: Value to set for need_summary field (default: True) + + Returns: + Number of documents updated + """ + if not document_ids: + return 0 + + document_id_list: list[str] = [str(document_id) for document_id in document_ids] + + with session_factory.create_session() as session: + updated_count = ( + session.query(Document) + .filter( + Document.id.in_(document_id_list), + Document.dataset_id == dataset_id, + Document.doc_form != "qa_model", # Skip qa_model documents + ) + .update({Document.need_summary: need_summary}, synchronize_session=False) + ) + session.commit() + logger.info( + "Updated need_summary to %s for %d documents in dataset %s", + need_summary, + updated_count, + dataset_id, + ) + return updated_count + @staticmethod def get_document_download_url(document: Document) -> str: """ @@ -2937,14 +2978,15 @@ class DocumentService: """ now = naive_utc_now() - if action == "enable": - return DocumentService._prepare_enable_update(document, now) - elif action == "disable": - return DocumentService._prepare_disable_update(document, user, now) - elif action == "archive": - return DocumentService._prepare_archive_update(document, user, now) - elif action == "un_archive": - return DocumentService._prepare_unarchive_update(document, now) + match action: + case "enable": + return DocumentService._prepare_enable_update(document, now) + case "disable": + return DocumentService._prepare_disable_update(document, user, now) + case "archive": + return DocumentService._prepare_archive_update(document, user, now) + case "un_archive": + return DocumentService._prepare_unarchive_update(document, now) return None @@ -3581,56 +3623,57 @@ class SegmentService: # Check if segment_ids is not empty to avoid WHERE false condition if not segment_ids or len(segment_ids) == 0: return - if action == "enable": - segments = db.session.scalars( - select(DocumentSegment).where( - DocumentSegment.id.in_(segment_ids), - DocumentSegment.dataset_id == dataset.id, - DocumentSegment.document_id == document.id, - DocumentSegment.enabled == False, - ) - ).all() - if not segments: - return - real_deal_segment_ids = [] - for segment in segments: - indexing_cache_key = f"segment_{segment.id}_indexing" - cache_result = redis_client.get(indexing_cache_key) - if cache_result is not None: - continue - segment.enabled = True - segment.disabled_at = None - segment.disabled_by = None - db.session.add(segment) - real_deal_segment_ids.append(segment.id) - db.session.commit() + match action: + case "enable": + segments = db.session.scalars( + select(DocumentSegment).where( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.document_id == document.id, + DocumentSegment.enabled == False, + ) + ).all() + if not segments: + return + real_deal_segment_ids = [] + for segment in segments: + indexing_cache_key = f"segment_{segment.id}_indexing" + cache_result = redis_client.get(indexing_cache_key) + if cache_result is not None: + continue + segment.enabled = True + segment.disabled_at = None + segment.disabled_by = None + db.session.add(segment) + real_deal_segment_ids.append(segment.id) + db.session.commit() - enable_segments_to_index_task.delay(real_deal_segment_ids, dataset.id, document.id) - elif action == "disable": - segments = db.session.scalars( - select(DocumentSegment).where( - DocumentSegment.id.in_(segment_ids), - DocumentSegment.dataset_id == dataset.id, - DocumentSegment.document_id == document.id, - DocumentSegment.enabled == True, - ) - ).all() - if not segments: - return - real_deal_segment_ids = [] - for segment in segments: - indexing_cache_key = f"segment_{segment.id}_indexing" - cache_result = redis_client.get(indexing_cache_key) - if cache_result is not None: - continue - segment.enabled = False - segment.disabled_at = naive_utc_now() - segment.disabled_by = current_user.id - db.session.add(segment) - real_deal_segment_ids.append(segment.id) - db.session.commit() + enable_segments_to_index_task.delay(real_deal_segment_ids, dataset.id, document.id) + case "disable": + segments = db.session.scalars( + select(DocumentSegment).where( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.document_id == document.id, + DocumentSegment.enabled == True, + ) + ).all() + if not segments: + return + real_deal_segment_ids = [] + for segment in segments: + indexing_cache_key = f"segment_{segment.id}_indexing" + cache_result = redis_client.get(indexing_cache_key) + if cache_result is not None: + continue + segment.enabled = False + segment.disabled_at = naive_utc_now() + segment.disabled_by = current_user.id + db.session.add(segment) + real_deal_segment_ids.append(segment.id) + db.session.commit() - disable_segments_from_index_task.delay(real_deal_segment_ids, dataset.id, document.id) + disable_segments_from_index_task.delay(real_deal_segment_ids, dataset.id, document.id) @classmethod def create_child_chunk( diff --git a/api/services/feature_service.py b/api/services/feature_service.py index fda3a15144..d94ae49d91 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -138,8 +138,6 @@ class FeatureModel(BaseModel): is_allow_transfer_workspace: bool = True trigger_event: Quota = Quota(usage=0, limit=3000, reset_date=0) api_rate_limit: Quota = Quota(usage=0, limit=5000, reset_date=0) - # Controls whether email delivery is allowed for HumanInput nodes. - human_input_email_delivery_enabled: bool = False # pydantic configs model_config = ConfigDict(protected_namespaces=()) knowledge_pipeline: KnowledgePipeline = KnowledgePipeline() @@ -193,11 +191,6 @@ class FeatureService: features.knowledge_pipeline.publish_enabled = True cls._fulfill_params_from_workspace_info(features, tenant_id) - features.human_input_email_delivery_enabled = cls._resolve_human_input_email_delivery_enabled( - features=features, - tenant_id=tenant_id, - ) - return features @classmethod @@ -210,17 +203,6 @@ class FeatureService: knowledge_rate_limit.subscription_plan = limit_info.get("subscription_plan", CloudPlan.SANDBOX) return knowledge_rate_limit - @classmethod - def _resolve_human_input_email_delivery_enabled(cls, *, features: FeatureModel, tenant_id: str | None) -> bool: - if dify_config.ENTERPRISE_ENABLED or not dify_config.BILLING_ENABLED: - return True - if not tenant_id: - return False - return features.billing.enabled and features.billing.subscription.plan in ( - CloudPlan.PROFESSIONAL, - CloudPlan.TEAM, - ) - @classmethod def get_system_features(cls, is_authenticated: bool = False) -> SystemFeatureModel: system_features = SystemFeatureModel() diff --git a/api/services/human_input_delivery_test_service.py b/api/services/human_input_delivery_test_service.py deleted file mode 100644 index ff37ff098f..0000000000 --- a/api/services/human_input_delivery_test_service.py +++ /dev/null @@ -1,249 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass, field -from enum import StrEnum -from typing import Protocol - -from sqlalchemy import Engine, select -from sqlalchemy.orm import sessionmaker - -from configs import dify_config -from core.workflow.nodes.human_input.entities import ( - DeliveryChannelConfig, - EmailDeliveryConfig, - EmailDeliveryMethod, - ExternalRecipient, - MemberRecipient, -) -from core.workflow.runtime import VariablePool -from extensions.ext_database import db -from extensions.ext_mail import mail -from libs.email_template_renderer import render_email_template -from models import Account, TenantAccountJoin -from services.feature_service import FeatureService - - -class DeliveryTestStatus(StrEnum): - OK = "ok" - FAILED = "failed" - - -@dataclass(frozen=True) -class DeliveryTestEmailRecipient: - email: str - form_token: str - - -@dataclass(frozen=True) -class DeliveryTestContext: - tenant_id: str - app_id: str - node_id: str - node_title: str | None - rendered_content: str - template_vars: dict[str, str] = field(default_factory=dict) - recipients: list[DeliveryTestEmailRecipient] = field(default_factory=list) - variable_pool: VariablePool | None = None - - -@dataclass(frozen=True) -class DeliveryTestResult: - status: DeliveryTestStatus - delivered_to: list[str] = field(default_factory=list) - warnings: list[str] = field(default_factory=list) - - -class DeliveryTestError(Exception): - pass - - -class DeliveryTestUnsupportedError(DeliveryTestError): - pass - - -def _build_form_link(token: str | None) -> str | None: - if not token: - return None - base_url = dify_config.APP_WEB_URL - if not base_url: - return None - return f"{base_url.rstrip('/')}/form/{token}" - - -class DeliveryTestHandler(Protocol): - def supports(self, method: DeliveryChannelConfig) -> bool: ... - - def send_test( - self, - *, - context: DeliveryTestContext, - method: DeliveryChannelConfig, - ) -> DeliveryTestResult: ... - - -class DeliveryTestRegistry: - def __init__(self, handlers: list[DeliveryTestHandler] | None = None) -> None: - self._handlers = list(handlers or []) - - def register(self, handler: DeliveryTestHandler) -> None: - self._handlers.append(handler) - - def dispatch( - self, - *, - context: DeliveryTestContext, - method: DeliveryChannelConfig, - ) -> DeliveryTestResult: - for handler in self._handlers: - if handler.supports(method): - return handler.send_test(context=context, method=method) - raise DeliveryTestUnsupportedError("Delivery method does not support test send.") - - @classmethod - def default(cls) -> DeliveryTestRegistry: - return cls([EmailDeliveryTestHandler()]) - - -class HumanInputDeliveryTestService: - def __init__(self, registry: DeliveryTestRegistry | None = None) -> None: - self._registry = registry or DeliveryTestRegistry.default() - - def send_test( - self, - *, - context: DeliveryTestContext, - method: DeliveryChannelConfig, - ) -> DeliveryTestResult: - return self._registry.dispatch(context=context, method=method) - - -class EmailDeliveryTestHandler: - def __init__(self, session_factory: sessionmaker | Engine | None = None) -> None: - if session_factory is None: - session_factory = sessionmaker(bind=db.engine) - elif isinstance(session_factory, Engine): - session_factory = sessionmaker(bind=session_factory) - self._session_factory = session_factory - - def supports(self, method: DeliveryChannelConfig) -> bool: - return isinstance(method, EmailDeliveryMethod) - - def send_test( - self, - *, - context: DeliveryTestContext, - method: DeliveryChannelConfig, - ) -> DeliveryTestResult: - if not isinstance(method, EmailDeliveryMethod): - raise DeliveryTestUnsupportedError("Delivery method does not support test send.") - features = FeatureService.get_features(context.tenant_id) - if not features.human_input_email_delivery_enabled: - raise DeliveryTestError("Email delivery is not available for current plan.") - if not mail.is_inited(): - raise DeliveryTestError("Mail client is not initialized.") - - recipients = self._resolve_recipients( - tenant_id=context.tenant_id, - method=method, - ) - if not recipients: - raise DeliveryTestError("No recipients configured for delivery method.") - - delivered: list[str] = [] - for recipient_email in recipients: - substitutions = self._build_substitutions( - context=context, - recipient_email=recipient_email, - ) - subject = render_email_template(method.config.subject, substitutions) - templated_body = EmailDeliveryConfig.render_body_template( - body=method.config.body, - url=substitutions.get("form_link"), - variable_pool=context.variable_pool, - ) - body = render_email_template(templated_body, substitutions) - - mail.send( - to=recipient_email, - subject=subject, - html=body, - ) - delivered.append(recipient_email) - - return DeliveryTestResult(status=DeliveryTestStatus.OK, delivered_to=delivered) - - def _resolve_recipients(self, *, tenant_id: str, method: EmailDeliveryMethod) -> list[str]: - recipients = method.config.recipients - emails: list[str] = [] - member_user_ids: list[str] = [] - for recipient in recipients.items: - if isinstance(recipient, MemberRecipient): - member_user_ids.append(recipient.user_id) - elif isinstance(recipient, ExternalRecipient): - if recipient.email: - emails.append(recipient.email) - - if recipients.whole_workspace: - member_user_ids = [] - member_emails = self._query_workspace_member_emails(tenant_id=tenant_id, user_ids=None) - emails.extend(member_emails.values()) - elif member_user_ids: - member_emails = self._query_workspace_member_emails(tenant_id=tenant_id, user_ids=member_user_ids) - for user_id in member_user_ids: - email = member_emails.get(user_id) - if email: - emails.append(email) - - return list(dict.fromkeys([email for email in emails if email])) - - def _query_workspace_member_emails( - self, - *, - tenant_id: str, - user_ids: list[str] | None, - ) -> dict[str, str]: - if user_ids is None: - unique_ids = None - else: - unique_ids = {user_id for user_id in user_ids if user_id} - if not unique_ids: - return {} - - stmt = ( - select(Account.id, Account.email) - .join(TenantAccountJoin, TenantAccountJoin.account_id == Account.id) - .where(TenantAccountJoin.tenant_id == tenant_id) - ) - if unique_ids is not None: - stmt = stmt.where(Account.id.in_(unique_ids)) - - with self._session_factory() as session: - rows = session.execute(stmt).all() - return dict(rows) - - @staticmethod - def _build_substitutions( - *, - context: DeliveryTestContext, - recipient_email: str, - ) -> dict[str, str]: - raw_values: dict[str, str | None] = { - "form_id": "", - "node_title": context.node_title, - "workflow_run_id": "", - "form_token": "", - "form_link": "", - "form_content": context.rendered_content, - "recipient_email": recipient_email, - } - substitutions = {key: value or "" for key, value in raw_values.items()} - if context.template_vars: - substitutions.update({key: value for key, value in context.template_vars.items() if value is not None}) - token = next( - (recipient.form_token for recipient in context.recipients if recipient.email == recipient_email), - None, - ) - if token: - substitutions["form_token"] = token - substitutions["form_link"] = _build_form_link(token) or "" - return substitutions diff --git a/api/services/human_input_service.py b/api/services/human_input_service.py deleted file mode 100644 index d50325e5e5..0000000000 --- a/api/services/human_input_service.py +++ /dev/null @@ -1,250 +0,0 @@ -import logging -from collections.abc import Mapping -from datetime import datetime, timedelta -from typing import Any - -from sqlalchemy import Engine, select -from sqlalchemy.orm import Session, sessionmaker - -from configs import dify_config -from core.repositories.human_input_repository import ( - HumanInputFormRecord, - HumanInputFormSubmissionRepository, -) -from core.workflow.nodes.human_input.entities import ( - FormDefinition, - HumanInputSubmissionValidationError, - validate_human_input_submission, -) -from core.workflow.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus -from libs.datetime_utils import ensure_naive_utc, naive_utc_now -from libs.exception import BaseHTTPException -from models.human_input import RecipientType -from models.model import App, AppMode -from repositories.factory import DifyAPIRepositoryFactory -from tasks.app_generate.workflow_execute_task import WORKFLOW_BASED_APP_EXECUTION_QUEUE, resume_app_execution - - -class Form: - def __init__(self, record: HumanInputFormRecord): - self._record = record - - def get_definition(self) -> FormDefinition: - return self._record.definition - - @property - def submitted(self) -> bool: - return self._record.submitted - - @property - def id(self) -> str: - return self._record.form_id - - @property - def workflow_run_id(self) -> str | None: - """Workflow run id for runtime forms; None for delivery tests.""" - return self._record.workflow_run_id - - @property - def tenant_id(self) -> str: - return self._record.tenant_id - - @property - def app_id(self) -> str: - return self._record.app_id - - @property - def recipient_id(self) -> str | None: - return self._record.recipient_id - - @property - def recipient_type(self) -> RecipientType | None: - return self._record.recipient_type - - @property - def status(self) -> HumanInputFormStatus: - return self._record.status - - @property - def form_kind(self) -> HumanInputFormKind: - return self._record.form_kind - - @property - def created_at(self) -> "datetime": - return self._record.created_at - - @property - def expiration_time(self) -> "datetime": - return self._record.expiration_time - - -class HumanInputError(Exception): - pass - - -class FormSubmittedError(HumanInputError, BaseHTTPException): - error_code = "human_input_form_submitted" - description = "This form has already been submitted by another user, form_id={form_id}" - code = 412 - - def __init__(self, form_id: str): - template = self.description or "This form has already been submitted by another user, form_id={form_id}" - description = template.format(form_id=form_id) - super().__init__(description=description) - - -class FormNotFoundError(HumanInputError, BaseHTTPException): - error_code = "human_input_form_not_found" - code = 404 - - -class InvalidFormDataError(HumanInputError, BaseHTTPException): - error_code = "invalid_form_data" - code = 400 - - def __init__(self, description: str): - super().__init__(description=description) - - -class WebAppDeliveryNotEnabledError(HumanInputError, BaseException): - pass - - -class FormExpiredError(HumanInputError, BaseHTTPException): - error_code = "human_input_form_expired" - code = 412 - - def __init__(self, form_id: str): - super().__init__(description=f"This form has expired, form_id={form_id}") - - -logger = logging.getLogger(__name__) - - -class HumanInputService: - def __init__( - self, - session_factory: sessionmaker[Session] | Engine, - form_repository: HumanInputFormSubmissionRepository | None = None, - ): - if isinstance(session_factory, Engine): - session_factory = sessionmaker(bind=session_factory) - self._session_factory = session_factory - self._form_repository = form_repository or HumanInputFormSubmissionRepository(session_factory) - - def get_form_by_token(self, form_token: str) -> Form | None: - record = self._form_repository.get_by_token(form_token) - if record is None: - return None - return Form(record) - - def get_form_definition_by_token(self, recipient_type: RecipientType, form_token: str) -> Form | None: - form = self.get_form_by_token(form_token) - if form is None or form.recipient_type != recipient_type: - return None - self._ensure_not_submitted(form) - return form - - def get_form_definition_by_token_for_console(self, form_token: str) -> Form | None: - form = self.get_form_by_token(form_token) - if form is None or form.recipient_type not in {RecipientType.CONSOLE, RecipientType.BACKSTAGE}: - return None - self._ensure_not_submitted(form) - return form - - def submit_form_by_token( - self, - recipient_type: RecipientType, - form_token: str, - selected_action_id: str, - form_data: Mapping[str, Any], - submission_end_user_id: str | None = None, - submission_user_id: str | None = None, - ): - form = self.get_form_by_token(form_token) - if form is None or form.recipient_type != recipient_type: - raise WebAppDeliveryNotEnabledError() - - self.ensure_form_active(form) - self._validate_submission(form=form, selected_action_id=selected_action_id, form_data=form_data) - - result = self._form_repository.mark_submitted( - form_id=form.id, - recipient_id=form.recipient_id, - selected_action_id=selected_action_id, - form_data=form_data, - submission_user_id=submission_user_id, - submission_end_user_id=submission_end_user_id, - ) - - if result.form_kind != HumanInputFormKind.RUNTIME: - return - if result.workflow_run_id is None: - return - self.enqueue_resume(result.workflow_run_id) - - def ensure_form_active(self, form: Form) -> None: - if form.submitted: - raise FormSubmittedError(form.id) - if form.status in {HumanInputFormStatus.TIMEOUT, HumanInputFormStatus.EXPIRED}: - raise FormExpiredError(form.id) - now = naive_utc_now() - if ensure_naive_utc(form.expiration_time) <= now: - raise FormExpiredError(form.id) - if self._is_globally_expired(form, now=now): - raise FormExpiredError(form.id) - - def _ensure_not_submitted(self, form: Form) -> None: - if form.submitted: - raise FormSubmittedError(form.id) - - def _validate_submission(self, form: Form, selected_action_id: str, form_data: Mapping[str, Any]) -> None: - definition = form.get_definition() - try: - validate_human_input_submission( - inputs=definition.inputs, - user_actions=definition.user_actions, - selected_action_id=selected_action_id, - form_data=form_data, - ) - except HumanInputSubmissionValidationError as exc: - raise InvalidFormDataError(str(exc)) from exc - - def enqueue_resume(self, workflow_run_id: str) -> None: - workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(self._session_factory) - workflow_run = workflow_run_repo.get_workflow_run_by_id_without_tenant(workflow_run_id) - - if workflow_run is None: - raise AssertionError(f"WorkflowRun not found, id={workflow_run_id}") - with self._session_factory(expire_on_commit=False) as session: - app_query = select(App).where(App.id == workflow_run.app_id) - app = session.execute(app_query).scalar_one_or_none() - if app is None: - logger.error( - "App not found for WorkflowRun, workflow_run_id=%s, app_id=%s", workflow_run_id, workflow_run.app_id - ) - return - - if app.mode in {AppMode.WORKFLOW, AppMode.ADVANCED_CHAT}: - payload = {"workflow_run_id": workflow_run_id} - try: - resume_app_execution.apply_async( - kwargs={"payload": payload}, - queue=WORKFLOW_BASED_APP_EXECUTION_QUEUE, - ) - except Exception: # pragma: no cover - logger.exception("Failed to enqueue resume task for workflow run %s", workflow_run_id) - return - - logger.warning("App mode %s does not support resume for workflow run %s", app.mode, workflow_run_id) - - def _is_globally_expired(self, form: Form, *, now: datetime | None = None) -> bool: - global_timeout_seconds = dify_config.HITL_GLOBAL_TIMEOUT_SECONDS - if global_timeout_seconds <= 0: - return False - if form.workflow_run_id is None: - return False - current = now or naive_utc_now() - created_at = ensure_naive_utc(form.created_at) - global_deadline = created_at + timedelta(seconds=global_timeout_seconds) - return global_deadline <= current diff --git a/api/services/message_service.py b/api/services/message_service.py index ce699e79d4..a53ca8b22d 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -1,9 +1,6 @@ import json -from collections.abc import Sequence from typing import Union -from sqlalchemy.orm import sessionmaker - from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.entities.app_invoke_entities import InvokeFrom from core.llm_generator.llm_generator import LLMGenerator @@ -17,10 +14,6 @@ from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account from models.model import App, AppMode, AppModelConfig, EndUser, Message, MessageFeedback -from repositories.execution_extra_content_repository import ExecutionExtraContentRepository -from repositories.sqlalchemy_execution_extra_content_repository import ( - SQLAlchemyExecutionExtraContentRepository, -) from services.conversation_service import ConversationService from services.errors.message import ( FirstMessageNotExistsError, @@ -31,23 +24,6 @@ from services.errors.message import ( from services.workflow_service import WorkflowService -def _create_execution_extra_content_repository() -> ExecutionExtraContentRepository: - session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) - return SQLAlchemyExecutionExtraContentRepository(session_maker=session_maker) - - -def attach_message_extra_contents(messages: Sequence[Message]) -> None: - if not messages: - return - - repository = _create_execution_extra_content_repository() - extra_contents_lists = repository.get_by_message_ids([message.id for message in messages]) - - for index, message in enumerate(messages): - contents = extra_contents_lists[index] if index < len(extra_contents_lists) else [] - message.set_extra_contents([content.model_dump(mode="json", exclude_none=True) for content in contents]) - - class MessageService: @classmethod def pagination_by_first_id( @@ -109,8 +85,6 @@ class MessageService: if order == "asc": history_messages = list(reversed(history_messages)) - attach_message_extra_contents(history_messages) - return InfiniteScrollPagination(data=history_messages, limit=limit, has_more=has_more) @classmethod diff --git a/api/services/rag_pipeline/rag_pipeline_transform_service.py b/api/services/rag_pipeline/rag_pipeline_transform_service.py index 8ea365e907..d0dfbc1070 100644 --- a/api/services/rag_pipeline/rag_pipeline_transform_service.py +++ b/api/services/rag_pipeline/rag_pipeline_transform_service.py @@ -174,6 +174,10 @@ class RagPipelineTransformService: else: dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump() + # Copy summary_index_setting from dataset to knowledge_index node configuration + if dataset.summary_index_setting: + knowledge_configuration.summary_index_setting = dataset.summary_index_setting + knowledge_configuration_dict.update(knowledge_configuration.model_dump()) node["data"] = knowledge_configuration_dict return node diff --git a/api/services/summary_index_service.py b/api/services/summary_index_service.py index b8e1f8bc3f..7c03ceed5b 100644 --- a/api/services/summary_index_service.py +++ b/api/services/summary_index_service.py @@ -49,11 +49,18 @@ class SummaryIndexService: # Use lazy import to avoid circular import from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor + # Get document language to ensure summary is generated in the correct language + # This is especially important for image-only chunks where text is empty or minimal + document_language = None + if segment.document and segment.document.doc_language: + document_language = segment.document.doc_language + summary_content, usage = ParagraphIndexProcessor.generate_summary( tenant_id=dataset.tenant_id, text=segment.content, summary_index_setting=summary_index_setting, segment_id=segment.id, + document_language=document_language, ) if not summary_content: @@ -558,6 +565,9 @@ class SummaryIndexService: ) session.add(summary_record) + # Commit the batch created records + session.commit() + @staticmethod def update_summary_record_error( segment: DocumentSegment, @@ -762,7 +772,6 @@ class SummaryIndexService: dataset=dataset, status="not_started", ) - session.commit() # Commit initial records summary_records = [] diff --git a/api/services/tag_service.py b/api/services/tag_service.py index bd3585acf4..56f4ae9494 100644 --- a/api/services/tag_service.py +++ b/api/services/tag_service.py @@ -24,7 +24,7 @@ class TagService: escaped_keyword = escape_like_pattern(keyword) query = query.where(sa.and_(Tag.name.ilike(f"%{escaped_keyword}%", escape="\\"))) query = query.group_by(Tag.id, Tag.type, Tag.name, Tag.created_at) - results: list = query.order_by(Tag.created_at.desc()).all() + results = query.order_by(Tag.created_at.desc()).all() return results @staticmethod diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py index 0ae40199ab..6d84d4e250 100644 --- a/api/services/tools/workflow_tools_manage_service.py +++ b/api/services/tools/workflow_tools_manage_service.py @@ -1,8 +1,6 @@ import json import logging -from collections.abc import Mapping from datetime import datetime -from typing import Any from sqlalchemy import or_, select from sqlalchemy.orm import Session @@ -10,8 +8,8 @@ from sqlalchemy.orm import Session from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.__base.tool_provider import ToolProviderController from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity +from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration from core.tools.tool_label_manager import ToolLabelManager -from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from core.tools.workflow_as_tool.tool import WorkflowTool from extensions.ext_database import db @@ -38,12 +36,10 @@ class WorkflowToolManageService: label: str, icon: dict, description: str, - parameters: list[Mapping[str, Any]], + parameters: list[WorkflowToolParameterConfiguration], privacy_policy: str = "", labels: list[str] | None = None, ): - WorkflowToolConfigurationUtils.check_parameter_configurations(parameters) - # check if the name is unique existing_workflow_tool_provider = ( db.session.query(WorkflowToolProvider) @@ -67,8 +63,6 @@ class WorkflowToolManageService: if workflow is None: raise ValueError(f"Workflow not found for app {workflow_app_id}") - WorkflowToolConfigurationUtils.ensure_no_human_input_nodes(workflow.graph_dict) - workflow_tool_provider = WorkflowToolProvider( tenant_id=tenant_id, user_id=user_id, @@ -77,7 +71,7 @@ class WorkflowToolManageService: label=label, icon=json.dumps(icon), description=description, - parameter_configuration=json.dumps(parameters), + parameter_configuration=json.dumps([p.model_dump() for p in parameters]), privacy_policy=privacy_policy, version=workflow.version, ) @@ -106,7 +100,7 @@ class WorkflowToolManageService: label: str, icon: dict, description: str, - parameters: list[Mapping[str, Any]], + parameters: list[WorkflowToolParameterConfiguration], privacy_policy: str = "", labels: list[str] | None = None, ): @@ -124,8 +118,6 @@ class WorkflowToolManageService: :param labels: labels :return: the updated tool """ - WorkflowToolConfigurationUtils.check_parameter_configurations(parameters) - # check if the name is unique existing_workflow_tool_provider = ( db.session.query(WorkflowToolProvider) @@ -160,13 +152,11 @@ class WorkflowToolManageService: if workflow is None: raise ValueError(f"Workflow not found for app {workflow_tool_provider.app_id}") - WorkflowToolConfigurationUtils.ensure_no_human_input_nodes(workflow.graph_dict) - workflow_tool_provider.name = name workflow_tool_provider.label = label workflow_tool_provider.icon = json.dumps(icon) workflow_tool_provider.description = description - workflow_tool_provider.parameter_configuration = json.dumps(parameters) + workflow_tool_provider.parameter_configuration = json.dumps([p.model_dump() for p in parameters]) workflow_tool_provider.privacy_policy = privacy_policy workflow_tool_provider.version = workflow.version workflow_tool_provider.updated_at = datetime.now() diff --git a/api/services/workflow/entities.py b/api/services/workflow/entities.py index 2af0d1fd90..70ec8d6e2a 100644 --- a/api/services/workflow/entities.py +++ b/api/services/workflow/entities.py @@ -98,12 +98,6 @@ class WorkflowTaskData(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) -class WorkflowResumeTaskData(BaseModel): - """Payload for workflow resumption tasks.""" - - workflow_run_id: str - - class AsyncTriggerExecutionResult(BaseModel): """Result from async trigger-based workflow execution""" diff --git a/api/services/workflow_event_snapshot_service.py b/api/services/workflow_event_snapshot_service.py deleted file mode 100644 index dd4651f130..0000000000 --- a/api/services/workflow_event_snapshot_service.py +++ /dev/null @@ -1,460 +0,0 @@ -from __future__ import annotations - -import json -import logging -import queue -import threading -import time -from collections.abc import Generator, Mapping, Sequence -from dataclasses import dataclass -from typing import Any - -from sqlalchemy import desc, select -from sqlalchemy.orm import Session, sessionmaker - -from core.app.apps.message_generator import MessageGenerator -from core.app.entities.task_entities import ( - MessageReplaceStreamResponse, - NodeFinishStreamResponse, - NodeStartStreamResponse, - StreamEvent, - WorkflowPauseStreamResponse, - WorkflowStartStreamResponse, -) -from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext -from core.workflow.entities import WorkflowStartReason -from core.workflow.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus -from core.workflow.runtime import GraphRuntimeState -from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter -from models.model import AppMode, Message -from models.workflow import WorkflowNodeExecutionTriggeredFrom, WorkflowRun -from repositories.api_workflow_node_execution_repository import WorkflowNodeExecutionSnapshot -from repositories.entities.workflow_pause import WorkflowPauseEntity -from repositories.factory import DifyAPIRepositoryFactory - -logger = logging.getLogger(__name__) - - -@dataclass(frozen=True) -class MessageContext: - conversation_id: str - message_id: str - created_at: int - answer: str | None = None - - -@dataclass -class BufferState: - queue: queue.Queue[Mapping[str, Any]] - stop_event: threading.Event - done_event: threading.Event - task_id_ready: threading.Event - task_id_hint: str | None = None - - -def build_workflow_event_stream( - *, - app_mode: AppMode, - workflow_run: WorkflowRun, - tenant_id: str, - app_id: str, - session_maker: sessionmaker[Session], - idle_timeout: float = 300, - ping_interval: float = 10.0, -) -> Generator[Mapping[str, Any] | str, None, None]: - topic = MessageGenerator.get_response_topic(app_mode, workflow_run.id) - workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) - node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker) - message_context = ( - _get_message_context(session_maker, workflow_run.id) if app_mode == AppMode.ADVANCED_CHAT else None - ) - - pause_entity: WorkflowPauseEntity | None = None - if workflow_run.status == WorkflowExecutionStatus.PAUSED: - try: - pause_entity = workflow_run_repo.get_workflow_pause(workflow_run.id) - except Exception: - logger.exception("Failed to load workflow pause for run %s", workflow_run.id) - pause_entity = None - - resumption_context = _load_resumption_context(pause_entity) - node_snapshots = node_execution_repo.get_execution_snapshots_by_workflow_run( - tenant_id=tenant_id, - app_id=app_id, - workflow_id=workflow_run.workflow_id, - # NOTE(QuantumGhost): for events resumption, we only care about - # the execution records from `WORKFLOW_RUN`. - # - # Ideally filtering with `workflow_run_id` is enough. However, - # due to the index of `WorkflowNodeExecution` table, we have to - # add a filter condition of `triggered_from` to - # ensure that we can utilize the index. - triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, - workflow_run_id=workflow_run.id, - ) - - def _generate() -> Generator[Mapping[str, Any] | str, None, None]: - # send a PING event immediately to prevent the connection staying in pending state for a long time. - # - # This simplify the debugging process as the DevTools in Chrome does not - # provide complete curl command for pending connections. - yield StreamEvent.PING.value - - last_msg_time = time.time() - last_ping_time = last_msg_time - - with topic.subscribe() as sub: - buffer_state = _start_buffering(sub) - try: - task_id = _resolve_task_id(resumption_context, buffer_state, workflow_run.id) - - snapshot_events = _build_snapshot_events( - workflow_run=workflow_run, - node_snapshots=node_snapshots, - task_id=task_id, - message_context=message_context, - pause_entity=pause_entity, - resumption_context=resumption_context, - ) - - for event in snapshot_events: - last_msg_time = time.time() - last_ping_time = last_msg_time - yield event - if _is_terminal_event(event, include_paused=True): - return - - while True: - if buffer_state.done_event.is_set() and buffer_state.queue.empty(): - return - - try: - event = buffer_state.queue.get(timeout=0.1) - except queue.Empty: - current_time = time.time() - if current_time - last_msg_time > idle_timeout: - logger.debug( - "No workflow events received for %s seconds, keeping stream open", - idle_timeout, - ) - last_msg_time = current_time - if current_time - last_ping_time >= ping_interval: - yield StreamEvent.PING.value - last_ping_time = current_time - continue - - last_msg_time = time.time() - last_ping_time = last_msg_time - yield event - if _is_terminal_event(event, include_paused=True): - return - finally: - buffer_state.stop_event.set() - - return _generate() - - -def _get_message_context(session_maker: sessionmaker[Session], workflow_run_id: str) -> MessageContext | None: - with session_maker() as session: - stmt = select(Message).where(Message.workflow_run_id == workflow_run_id).order_by(desc(Message.created_at)) - message = session.scalar(stmt) - if message is None: - return None - created_at = int(message.created_at.timestamp()) if message.created_at else 0 - return MessageContext( - conversation_id=message.conversation_id, - message_id=message.id, - created_at=created_at, - answer=message.answer, - ) - - -def _load_resumption_context(pause_entity: WorkflowPauseEntity | None) -> WorkflowResumptionContext | None: - if pause_entity is None: - return None - try: - raw_state = pause_entity.get_state().decode() - return WorkflowResumptionContext.loads(raw_state) - except Exception: - logger.exception("Failed to load resumption context") - return None - - -def _resolve_task_id( - resumption_context: WorkflowResumptionContext | None, - buffer_state: BufferState | None, - workflow_run_id: str, - wait_timeout: float = 0.2, -) -> str: - if resumption_context is not None: - generate_entity = resumption_context.get_generate_entity() - if generate_entity.task_id: - return generate_entity.task_id - if buffer_state is None: - return workflow_run_id - if buffer_state.task_id_hint is None: - buffer_state.task_id_ready.wait(timeout=wait_timeout) - if buffer_state.task_id_hint: - return buffer_state.task_id_hint - return workflow_run_id - - -def _build_snapshot_events( - *, - workflow_run: WorkflowRun, - node_snapshots: Sequence[WorkflowNodeExecutionSnapshot], - task_id: str, - message_context: MessageContext | None, - pause_entity: WorkflowPauseEntity | None, - resumption_context: WorkflowResumptionContext | None, -) -> list[Mapping[str, Any]]: - events: list[Mapping[str, Any]] = [] - - workflow_started = _build_workflow_started_event( - workflow_run=workflow_run, - task_id=task_id, - ) - _apply_message_context(workflow_started, message_context) - events.append(workflow_started) - - if message_context is not None and message_context.answer is not None: - message_replace = _build_message_replace_event(task_id=task_id, answer=message_context.answer) - _apply_message_context(message_replace, message_context) - events.append(message_replace) - - for snapshot in node_snapshots: - node_started = _build_node_started_event( - workflow_run_id=workflow_run.id, - task_id=task_id, - snapshot=snapshot, - ) - _apply_message_context(node_started, message_context) - events.append(node_started) - - if snapshot.status != WorkflowNodeExecutionStatus.RUNNING.value: - node_finished = _build_node_finished_event( - workflow_run_id=workflow_run.id, - task_id=task_id, - snapshot=snapshot, - ) - _apply_message_context(node_finished, message_context) - events.append(node_finished) - - if workflow_run.status == WorkflowExecutionStatus.PAUSED and pause_entity is not None: - pause_event = _build_pause_event( - workflow_run=workflow_run, - workflow_run_id=workflow_run.id, - task_id=task_id, - pause_entity=pause_entity, - resumption_context=resumption_context, - ) - if pause_event is not None: - _apply_message_context(pause_event, message_context) - events.append(pause_event) - - return events - - -def _build_workflow_started_event( - *, - workflow_run: WorkflowRun, - task_id: str, -) -> dict[str, Any]: - response = WorkflowStartStreamResponse( - task_id=task_id, - workflow_run_id=workflow_run.id, - data=WorkflowStartStreamResponse.Data( - id=workflow_run.id, - workflow_id=workflow_run.workflow_id, - inputs=workflow_run.inputs_dict or {}, - created_at=int(workflow_run.created_at.timestamp()), - reason=WorkflowStartReason.INITIAL, - ), - ) - payload = response.model_dump(mode="json") - payload["event"] = response.event.value - return payload - - -def _build_message_replace_event(*, task_id: str, answer: str) -> dict[str, Any]: - response = MessageReplaceStreamResponse( - task_id=task_id, - answer=answer, - reason="", - ) - payload = response.model_dump(mode="json") - payload["event"] = response.event.value - return payload - - -def _build_node_started_event( - *, - workflow_run_id: str, - task_id: str, - snapshot: WorkflowNodeExecutionSnapshot, -) -> dict[str, Any]: - created_at = int(snapshot.created_at.timestamp()) if snapshot.created_at else 0 - response = NodeStartStreamResponse( - task_id=task_id, - workflow_run_id=workflow_run_id, - data=NodeStartStreamResponse.Data( - id=snapshot.execution_id, - node_id=snapshot.node_id, - node_type=snapshot.node_type, - title=snapshot.title, - index=snapshot.index, - predecessor_node_id=None, - inputs=None, - created_at=created_at, - extras={}, - iteration_id=snapshot.iteration_id, - loop_id=snapshot.loop_id, - ), - ) - return response.to_ignore_detail_dict() - - -def _build_node_finished_event( - *, - workflow_run_id: str, - task_id: str, - snapshot: WorkflowNodeExecutionSnapshot, -) -> dict[str, Any]: - created_at = int(snapshot.created_at.timestamp()) if snapshot.created_at else 0 - finished_at = int(snapshot.finished_at.timestamp()) if snapshot.finished_at else created_at - response = NodeFinishStreamResponse( - task_id=task_id, - workflow_run_id=workflow_run_id, - data=NodeFinishStreamResponse.Data( - id=snapshot.execution_id, - node_id=snapshot.node_id, - node_type=snapshot.node_type, - title=snapshot.title, - index=snapshot.index, - predecessor_node_id=None, - inputs=None, - process_data=None, - outputs=None, - status=snapshot.status, - error=None, - elapsed_time=snapshot.elapsed_time, - execution_metadata=None, - created_at=created_at, - finished_at=finished_at, - files=[], - iteration_id=snapshot.iteration_id, - loop_id=snapshot.loop_id, - ), - ) - return response.to_ignore_detail_dict() - - -def _build_pause_event( - *, - workflow_run: WorkflowRun, - workflow_run_id: str, - task_id: str, - pause_entity: WorkflowPauseEntity, - resumption_context: WorkflowResumptionContext | None, -) -> dict[str, Any] | None: - paused_nodes: list[str] = [] - outputs: dict[str, Any] = {} - if resumption_context is not None: - state = GraphRuntimeState.from_snapshot(resumption_context.serialized_graph_runtime_state) - paused_nodes = state.get_paused_nodes() - outputs = dict(WorkflowRuntimeTypeConverter().to_json_encodable(state.outputs or {})) - - reasons = [reason.model_dump(mode="json") for reason in pause_entity.get_pause_reasons()] - response = WorkflowPauseStreamResponse( - task_id=task_id, - workflow_run_id=workflow_run_id, - data=WorkflowPauseStreamResponse.Data( - workflow_run_id=workflow_run_id, - paused_nodes=paused_nodes, - outputs=outputs, - reasons=reasons, - status=workflow_run.status.value, - created_at=int(workflow_run.created_at.timestamp()), - elapsed_time=float(workflow_run.elapsed_time or 0.0), - total_tokens=int(workflow_run.total_tokens or 0), - total_steps=int(workflow_run.total_steps or 0), - ), - ) - payload = response.model_dump(mode="json") - payload["event"] = response.event.value - return payload - - -def _apply_message_context(payload: dict[str, Any], message_context: MessageContext | None) -> None: - if message_context is None: - return - payload["conversation_id"] = message_context.conversation_id - payload["message_id"] = message_context.message_id - payload["created_at"] = message_context.created_at - - -def _start_buffering(subscription) -> BufferState: - buffer_state = BufferState( - queue=queue.Queue(maxsize=2048), - stop_event=threading.Event(), - done_event=threading.Event(), - task_id_ready=threading.Event(), - ) - - def _worker() -> None: - dropped_count = 0 - try: - while not buffer_state.stop_event.is_set(): - msg = subscription.receive(timeout=0.1) - if msg is None: - continue - event = _parse_event_message(msg) - if event is None: - continue - task_id = event.get("task_id") - if task_id and buffer_state.task_id_hint is None: - buffer_state.task_id_hint = str(task_id) - buffer_state.task_id_ready.set() - try: - buffer_state.queue.put_nowait(event) - except queue.Full: - dropped_count += 1 - try: - buffer_state.queue.get_nowait() - except queue.Empty: - pass - try: - buffer_state.queue.put_nowait(event) - except queue.Full: - continue - logger.warning("Dropped buffered workflow event, total_dropped=%s", dropped_count) - except Exception: - logger.exception("Failed while buffering workflow events") - finally: - buffer_state.done_event.set() - - thread = threading.Thread(target=_worker, name=f"workflow-event-buffer-{id(subscription)}", daemon=True) - thread.start() - return buffer_state - - -def _parse_event_message(message: bytes) -> Mapping[str, Any] | None: - try: - event = json.loads(message) - except json.JSONDecodeError: - logger.warning("Failed to decode workflow event payload") - return None - if not isinstance(event, dict): - return None - return event - - -def _is_terminal_event(event: Mapping[str, Any] | str, include_paused=False) -> bool: - if not isinstance(event, Mapping): - return False - event_type = event.get("event") - if event_type == StreamEvent.WORKFLOW_FINISHED.value: - return True - if include_paused: - return event_type == StreamEvent.WORKFLOW_PAUSED.value - return False diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 4e1e515de5..6404136994 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -1,5 +1,4 @@ import json -import logging import time import uuid from collections.abc import Callable, Generator, Mapping, Sequence @@ -12,34 +11,21 @@ from configs import dify_config from core.app.app_config.entities import VariableEntityType from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager -from core.app.entities.app_invoke_entities import InvokeFrom from core.file import File from core.repositories import DifyCoreRepositoryFactory -from core.repositories.human_input_repository import HumanInputFormRepositoryImpl from core.variables import VariableBase from core.variables.variables import Variable -from core.workflow.entities import GraphInitParams, WorkflowNodeExecution -from core.workflow.entities.pause_reason import HumanInputRequired +from core.workflow.entities import WorkflowNodeExecution from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent from core.workflow.node_events import NodeRunResult from core.workflow.nodes import NodeType from core.workflow.nodes.base.node import Node -from core.workflow.nodes.human_input.entities import ( - DeliveryChannelConfig, - HumanInputNodeData, - apply_debug_email_recipient, - validate_human_input_submission, -) -from core.workflow.nodes.human_input.enums import HumanInputFormKind -from core.workflow.nodes.human_input.human_input_node import HumanInputNode from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING from core.workflow.nodes.start.entities import StartNodeData -from core.workflow.repositories.human_input_form_repository import FormCreateParams -from core.workflow.runtime import GraphRuntimeState, VariablePool +from core.workflow.runtime import VariablePool from core.workflow.system_variable import SystemVariable -from core.workflow.variable_loader import load_into_variable_pool from core.workflow.workflow_entry import WorkflowEntry from enums.cloud_plan import CloudPlan from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated @@ -48,8 +34,6 @@ from extensions.ext_storage import storage from factories.file_factory import build_from_mapping, build_from_mappings from libs.datetime_utils import naive_utc_now from models import Account -from models.enums import UserFrom -from models.human_input import HumanInputFormRecipient, RecipientType from models.model import App, AppMode from models.tools import WorkflowToolProvider from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, WorkflowType @@ -60,13 +44,6 @@ from services.errors.app import IsDraftWorkflowError, TriggerNodeLimitExceededEr from services.workflow.workflow_converter import WorkflowConverter from .errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError -from .human_input_delivery_test_service import ( - DeliveryTestContext, - DeliveryTestEmailRecipient, - DeliveryTestError, - DeliveryTestUnsupportedError, - HumanInputDeliveryTestService, -) from .workflow_draft_variable_service import DraftVariableSaver, DraftVarLoader, WorkflowDraftVariableService @@ -767,344 +744,6 @@ class WorkflowService: return workflow_node_execution - def get_human_input_form_preview( - self, - *, - app_model: App, - account: Account, - node_id: str, - inputs: Mapping[str, Any] | None = None, - ) -> Mapping[str, Any]: - """ - Build a human input form preview for a draft workflow. - - Args: - app_model: Target application model. - account: Current account. - node_id: Human input node ID. - inputs: Values used to fill missing upstream variables referenced in form_content. - """ - draft_workflow = self.get_draft_workflow(app_model=app_model) - if not draft_workflow: - raise ValueError("Workflow not initialized") - - node_config = draft_workflow.get_node_config_by_id(node_id) - node_type = Workflow.get_node_type_from_node_config(node_config) - if node_type is not NodeType.HUMAN_INPUT: - raise ValueError("Node type must be human-input.") - - # inputs: values used to fill missing upstream variables referenced in form_content. - variable_pool = self._build_human_input_variable_pool( - app_model=app_model, - workflow=draft_workflow, - node_config=node_config, - manual_inputs=inputs or {}, - ) - node = self._build_human_input_node( - workflow=draft_workflow, - account=account, - node_config=node_config, - variable_pool=variable_pool, - ) - - rendered_content = node.render_form_content_before_submission() - resolved_default_values = node.resolve_default_values() - node_data = node.node_data - human_input_required = HumanInputRequired( - form_id=node_id, - form_content=rendered_content, - inputs=node_data.inputs, - actions=node_data.user_actions, - node_id=node_id, - node_title=node.title, - resolved_default_values=resolved_default_values, - form_token=None, - ) - return human_input_required.model_dump(mode="json") - - def submit_human_input_form_preview( - self, - *, - app_model: App, - account: Account, - node_id: str, - form_inputs: Mapping[str, Any], - inputs: Mapping[str, Any] | None = None, - action: str, - ) -> Mapping[str, Any]: - """ - Submit a human input form preview for a draft workflow. - - Args: - app_model: Target application model. - account: Current account. - node_id: Human input node ID. - form_inputs: Values the user provides for the form's own fields. - inputs: Values used to fill missing upstream variables referenced in form_content. - action: Selected action ID. - """ - draft_workflow = self.get_draft_workflow(app_model=app_model) - if not draft_workflow: - raise ValueError("Workflow not initialized") - - node_config = draft_workflow.get_node_config_by_id(node_id) - node_type = Workflow.get_node_type_from_node_config(node_config) - if node_type is not NodeType.HUMAN_INPUT: - raise ValueError("Node type must be human-input.") - - # inputs: values used to fill missing upstream variables referenced in form_content. - # form_inputs: values the user provides for the form's own fields. - variable_pool = self._build_human_input_variable_pool( - app_model=app_model, - workflow=draft_workflow, - node_config=node_config, - manual_inputs=inputs or {}, - ) - node = self._build_human_input_node( - workflow=draft_workflow, - account=account, - node_config=node_config, - variable_pool=variable_pool, - ) - node_data = node.node_data - - validate_human_input_submission( - inputs=node_data.inputs, - user_actions=node_data.user_actions, - selected_action_id=action, - form_data=form_inputs, - ) - - rendered_content = node.render_form_content_before_submission() - outputs: dict[str, Any] = dict(form_inputs) - outputs["__action_id"] = action - outputs["__rendered_content"] = node.render_form_content_with_outputs( - rendered_content, outputs, node_data.outputs_field_names() - ) - - enclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config) - enclosing_node_id = enclosing_node_type_and_id[1] if enclosing_node_type_and_id else None - with Session(bind=db.engine) as session, session.begin(): - draft_var_saver = DraftVariableSaver( - session=session, - app_id=app_model.id, - node_id=node_id, - node_type=NodeType.HUMAN_INPUT, - node_execution_id=str(uuid.uuid4()), - user=account, - enclosing_node_id=enclosing_node_id, - ) - draft_var_saver.save(outputs=outputs, process_data={}) - session.commit() - - return outputs - - def test_human_input_delivery( - self, - *, - app_model: App, - account: Account, - node_id: str, - delivery_method_id: str, - inputs: Mapping[str, Any] | None = None, - ) -> None: - draft_workflow = self.get_draft_workflow(app_model=app_model) - if not draft_workflow: - raise ValueError("Workflow not initialized") - - node_config = draft_workflow.get_node_config_by_id(node_id) - node_type = Workflow.get_node_type_from_node_config(node_config) - if node_type is not NodeType.HUMAN_INPUT: - raise ValueError("Node type must be human-input.") - - node_data = HumanInputNodeData.model_validate(node_config.get("data", {})) - delivery_method = self._resolve_human_input_delivery_method( - node_data=node_data, - delivery_method_id=delivery_method_id, - ) - if delivery_method is None: - raise ValueError("Delivery method not found.") - delivery_method = apply_debug_email_recipient( - delivery_method, - enabled=True, - user_id=account.id or "", - ) - - variable_pool = self._build_human_input_variable_pool( - app_model=app_model, - workflow=draft_workflow, - node_config=node_config, - manual_inputs=inputs or {}, - ) - node = self._build_human_input_node( - workflow=draft_workflow, - account=account, - node_config=node_config, - variable_pool=variable_pool, - ) - rendered_content = node.render_form_content_before_submission() - resolved_default_values = node.resolve_default_values() - form_id, recipients = self._create_human_input_delivery_test_form( - app_model=app_model, - node_id=node_id, - node_data=node_data, - delivery_method=delivery_method, - rendered_content=rendered_content, - resolved_default_values=resolved_default_values, - ) - test_service = HumanInputDeliveryTestService() - context = DeliveryTestContext( - tenant_id=app_model.tenant_id, - app_id=app_model.id, - node_id=node_id, - node_title=node_data.title, - rendered_content=rendered_content, - template_vars={"form_id": form_id}, - recipients=recipients, - variable_pool=variable_pool, - ) - try: - test_service.send_test(context=context, method=delivery_method) - except DeliveryTestUnsupportedError as exc: - raise ValueError("Delivery method does not support test send.") from exc - except DeliveryTestError as exc: - raise ValueError(str(exc)) from exc - - @staticmethod - def _resolve_human_input_delivery_method( - *, - node_data: HumanInputNodeData, - delivery_method_id: str, - ) -> DeliveryChannelConfig | None: - for method in node_data.delivery_methods: - if str(method.id) == delivery_method_id: - return method - return None - - def _create_human_input_delivery_test_form( - self, - *, - app_model: App, - node_id: str, - node_data: HumanInputNodeData, - delivery_method: DeliveryChannelConfig, - rendered_content: str, - resolved_default_values: Mapping[str, Any], - ) -> tuple[str, list[DeliveryTestEmailRecipient]]: - repo = HumanInputFormRepositoryImpl(session_factory=db.engine, tenant_id=app_model.tenant_id) - params = FormCreateParams( - app_id=app_model.id, - workflow_execution_id=None, - node_id=node_id, - form_config=node_data, - rendered_content=rendered_content, - delivery_methods=[delivery_method], - display_in_ui=False, - resolved_default_values=resolved_default_values, - form_kind=HumanInputFormKind.DELIVERY_TEST, - ) - form_entity = repo.create_form(params) - return form_entity.id, self._load_email_recipients(form_entity.id) - - @staticmethod - def _load_email_recipients(form_id: str) -> list[DeliveryTestEmailRecipient]: - logger = logging.getLogger(__name__) - - with Session(bind=db.engine) as session: - recipients = session.scalars( - select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id == form_id) - ).all() - recipients_data: list[DeliveryTestEmailRecipient] = [] - for recipient in recipients: - if recipient.recipient_type not in {RecipientType.EMAIL_MEMBER, RecipientType.EMAIL_EXTERNAL}: - continue - if not recipient.access_token: - continue - try: - payload = json.loads(recipient.recipient_payload) - except Exception: - logger.exception("Failed to parse human input recipient payload for delivery test.") - continue - email = payload.get("email") - if email: - recipients_data.append(DeliveryTestEmailRecipient(email=email, form_token=recipient.access_token)) - return recipients_data - - def _build_human_input_node( - self, - *, - workflow: Workflow, - account: Account, - node_config: Mapping[str, Any], - variable_pool: VariablePool, - ) -> HumanInputNode: - graph_init_params = GraphInitParams( - tenant_id=workflow.tenant_id, - app_id=workflow.app_id, - workflow_id=workflow.id, - graph_config=workflow.graph_dict, - user_id=account.id, - user_from=UserFrom.ACCOUNT.value, - invoke_from=InvokeFrom.DEBUGGER.value, - call_depth=0, - ) - graph_runtime_state = GraphRuntimeState( - variable_pool=variable_pool, - start_at=time.perf_counter(), - ) - node = HumanInputNode( - id=node_config.get("id", str(uuid.uuid4())), - config=node_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - return node - - def _build_human_input_variable_pool( - self, - *, - app_model: App, - workflow: Workflow, - node_config: Mapping[str, Any], - manual_inputs: Mapping[str, Any], - ) -> VariablePool: - with Session(bind=db.engine, expire_on_commit=False) as session, session.begin(): - draft_var_srv = WorkflowDraftVariableService(session) - draft_var_srv.prefill_conversation_variable_default_values(workflow) - - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={}, - environment_variables=workflow.environment_variables, - conversation_variables=[], - ) - - variable_loader = DraftVarLoader( - engine=db.engine, - app_id=app_model.id, - tenant_id=app_model.tenant_id, - ) - variable_mapping = HumanInputNode.extract_variable_selector_to_variable_mapping( - graph_config=workflow.graph_dict, - config=node_config, - ) - normalized_user_inputs: dict[str, Any] = dict(manual_inputs) - - load_into_variable_pool( - variable_loader=variable_loader, - variable_pool=variable_pool, - variable_mapping=variable_mapping, - user_inputs=normalized_user_inputs, - ) - WorkflowEntry.mapping_user_inputs_to_variable_pool( - variable_mapping=variable_mapping, - user_inputs=normalized_user_inputs, - variable_pool=variable_pool, - tenant_id=app_model.tenant_id, - ) - - return variable_pool - def run_free_workflow_node( self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any] ) -> WorkflowNodeExecution: @@ -1306,13 +945,6 @@ class WorkflowService: if any(nt.is_trigger_node for nt in node_types): raise ValueError("Start node and trigger nodes cannot coexist in the same workflow") - for node in node_configs: - node_data = node.get("data", {}) - node_type = node_data.get("type") - - if node_type == NodeType.HUMAN_INPUT: - self._validate_human_input_node_data(node_data) - def validate_features_structure(self, app_model: App, features: dict): if app_model.mode == AppMode.ADVANCED_CHAT: return AdvancedChatAppConfigManager.config_validate( @@ -1325,23 +957,6 @@ class WorkflowService: else: raise ValueError(f"Invalid app mode: {app_model.mode}") - def _validate_human_input_node_data(self, node_data: dict) -> None: - """ - Validate HumanInput node data format. - - Args: - node_data: The node data dictionary - - Raises: - ValueError: If the node data format is invalid - """ - from core.workflow.nodes.human_input.entities import HumanInputNodeData - - try: - HumanInputNodeData.model_validate(node_data) - except Exception as e: - raise ValueError(f"Invalid HumanInput node data: {str(e)}") - def update_workflow( self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict ) -> Workflow | None: diff --git a/api/tasks/app_generate/__init__.py b/api/tasks/app_generate/__init__.py deleted file mode 100644 index 4aa02ef39f..0000000000 --- a/api/tasks/app_generate/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .workflow_execute_task import AppExecutionParams, resume_app_execution, workflow_based_app_execution_task - -__all__ = ["AppExecutionParams", "resume_app_execution", "workflow_based_app_execution_task"] diff --git a/api/tasks/app_generate/workflow_execute_task.py b/api/tasks/app_generate/workflow_execute_task.py deleted file mode 100644 index e58d334f41..0000000000 --- a/api/tasks/app_generate/workflow_execute_task.py +++ /dev/null @@ -1,491 +0,0 @@ -import contextlib -import logging -import uuid -from collections.abc import Generator, Mapping -from enum import StrEnum -from typing import Annotated, Any, TypeAlias, Union - -from celery import shared_task -from flask import current_app, json -from pydantic import BaseModel, Discriminator, Field, Tag -from sqlalchemy import Engine, select -from sqlalchemy.orm import Session, sessionmaker - -from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator -from core.app.apps.message_based_app_generator import MessageBasedAppGenerator -from core.app.apps.workflow.app_generator import WorkflowAppGenerator -from core.app.entities.app_invoke_entities import ( - AdvancedChatAppGenerateEntity, - InvokeFrom, - WorkflowAppGenerateEntity, -) -from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, WorkflowResumptionContext -from core.repositories import DifyCoreRepositoryFactory -from core.workflow.runtime import GraphRuntimeState -from extensions.ext_database import db -from libs.flask_utils import set_login_user -from models.account import Account -from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom -from models.model import App, AppMode, Conversation, EndUser, Message -from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom, WorkflowRun -from repositories.factory import DifyAPIRepositoryFactory - -logger = logging.getLogger(__name__) - -WORKFLOW_BASED_APP_EXECUTION_QUEUE = "workflow_based_app_execution" - - -class _UserType(StrEnum): - ACCOUNT = "account" - END_USER = "end_user" - - -class _Account(BaseModel): - TYPE: _UserType = _UserType.ACCOUNT - - user_id: str - - -class _EndUser(BaseModel): - TYPE: _UserType = _UserType.END_USER - end_user_id: str - - -def _get_user_type_descriminator(value: Any): - if isinstance(value, (_Account, _EndUser)): - return value.TYPE - elif isinstance(value, dict): - user_type_str = value.get("TYPE") - if user_type_str is None: - return None - try: - user_type = _UserType(user_type_str) - except ValueError: - return None - return user_type - else: - # return None if the discriminator value isn't found - return None - - -User: TypeAlias = Annotated[ - (Annotated[_Account, Tag(_UserType.ACCOUNT)] | Annotated[_EndUser, Tag(_UserType.END_USER)]), - Discriminator(_get_user_type_descriminator), -] - - -class AppExecutionParams(BaseModel): - app_id: str - workflow_id: str - tenant_id: str - app_mode: AppMode = AppMode.ADVANCED_CHAT - user: User - args: Mapping[str, Any] - - invoke_from: InvokeFrom - streaming: bool = True - call_depth: int = 0 - root_node_id: str | None = None - workflow_run_id: str = Field(default_factory=lambda: str(uuid.uuid4())) - - @classmethod - def new( - cls, - app_model: App, - workflow: Workflow, - user: Union[Account, EndUser], - args: Mapping[str, Any], - invoke_from: InvokeFrom, - streaming: bool = True, - call_depth: int = 0, - root_node_id: str | None = None, - workflow_run_id: str | None = None, - ): - user_params: _Account | _EndUser - if isinstance(user, Account): - user_params = _Account(user_id=user.id) - elif isinstance(user, EndUser): - user_params = _EndUser(end_user_id=user.id) - else: - raise AssertionError("this statement should be unreachable.") - return cls( - app_id=app_model.id, - workflow_id=workflow.id, - tenant_id=app_model.tenant_id, - app_mode=AppMode.value_of(app_model.mode), - user=user_params, - args=args, - invoke_from=invoke_from, - streaming=streaming, - call_depth=call_depth, - root_node_id=root_node_id, - workflow_run_id=workflow_run_id or str(uuid.uuid4()), - ) - - -class _AppRunner: - def __init__(self, session_factory: sessionmaker | Engine, exec_params: AppExecutionParams): - if isinstance(session_factory, Engine): - session_factory = sessionmaker(bind=session_factory) - self._session_factory = session_factory - self._exec_params = exec_params - - @contextlib.contextmanager - def _session(self): - with self._session_factory(expire_on_commit=False) as session, session.begin(): - yield session - - @contextlib.contextmanager - def _setup_flask_context(self, user: Account | EndUser): - flask_app = current_app._get_current_object() # type: ignore - with flask_app.app_context(): - set_login_user(user) - yield - - def run(self): - exec_params = self._exec_params - with self._session() as session: - workflow = session.get(Workflow, exec_params.workflow_id) - if workflow is None: - logger.warning("Workflow %s not found for execution", exec_params.workflow_id) - return None - app = session.get(App, workflow.app_id) - if app is None: - logger.warning("App %s not found for workflow %s", workflow.app_id, exec_params.workflow_id) - return None - - pause_config = PauseStateLayerConfig( - session_factory=self._session_factory, - state_owner_user_id=workflow.created_by, - ) - - user = self._resolve_user() - - with self._setup_flask_context(user): - response = self._run_app( - app=app, - workflow=workflow, - user=user, - pause_state_config=pause_config, - ) - if not exec_params.streaming: - return response - - assert isinstance(response, Generator) - _publish_streaming_response(response, exec_params.workflow_run_id, exec_params.app_mode) - - def _run_app( - self, - *, - app: App, - workflow: Workflow, - user: Account | EndUser, - pause_state_config: PauseStateLayerConfig, - ): - exec_params = self._exec_params - if exec_params.app_mode == AppMode.ADVANCED_CHAT: - return AdvancedChatAppGenerator().generate( - app_model=app, - workflow=workflow, - user=user, - args=exec_params.args, - invoke_from=exec_params.invoke_from, - streaming=exec_params.streaming, - workflow_run_id=exec_params.workflow_run_id, - pause_state_config=pause_state_config, - ) - if exec_params.app_mode == AppMode.WORKFLOW: - return WorkflowAppGenerator().generate( - app_model=app, - workflow=workflow, - user=user, - args=exec_params.args, - invoke_from=exec_params.invoke_from, - streaming=exec_params.streaming, - call_depth=exec_params.call_depth, - root_node_id=exec_params.root_node_id, - workflow_run_id=exec_params.workflow_run_id, - pause_state_config=pause_state_config, - ) - - logger.error("Unsupported app mode for execution: %s", exec_params.app_mode) - return None - - def _resolve_user(self) -> Account | EndUser: - user_params = self._exec_params.user - - if isinstance(user_params, _EndUser): - with self._session() as session: - return session.get(EndUser, user_params.end_user_id) - elif not isinstance(user_params, _Account): - raise AssertionError(f"user should only be _Account or _EndUser, got {type(user_params)}") - - with self._session() as session: - user: Account = session.get(Account, user_params.user_id) - user.set_tenant_id(self._exec_params.tenant_id) - - return user - - -def _resolve_user_for_run(session: Session, workflow_run: WorkflowRun) -> Account | EndUser | None: - role = CreatorUserRole(workflow_run.created_by_role) - if role == CreatorUserRole.ACCOUNT: - user = session.get(Account, workflow_run.created_by) - if user: - user.set_tenant_id(workflow_run.tenant_id) - return user - - return session.get(EndUser, workflow_run.created_by) - - -def _publish_streaming_response( - response_stream: Generator[str | Mapping[str, Any], None, None], workflow_run_id: str, app_mode: AppMode -) -> None: - topic = MessageBasedAppGenerator.get_response_topic(app_mode, workflow_run_id) - for event in response_stream: - try: - payload = json.dumps(event) - except TypeError: - logger.exception("error while encoding event") - continue - - topic.publish(payload.encode()) - - -@shared_task(queue=WORKFLOW_BASED_APP_EXECUTION_QUEUE) -def workflow_based_app_execution_task( - payload: str, -) -> Generator[Mapping[str, Any] | str, None, None] | Mapping[str, Any] | None: - exec_params = AppExecutionParams.model_validate_json(payload) - - logger.info("workflow_based_app_execution_task run with params: %s", exec_params) - - runner = _AppRunner(db.engine, exec_params=exec_params) - return runner.run() - - -def _resume_app_execution(payload: dict[str, Any]) -> None: - workflow_run_id = payload["workflow_run_id"] - - session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) - workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker=session_factory) - - pause_entity = workflow_run_repo.get_workflow_pause(workflow_run_id) - if pause_entity is None: - logger.warning("No pause entity found for workflow run %s", workflow_run_id) - return - - try: - resumption_context = WorkflowResumptionContext.loads(pause_entity.get_state().decode()) - except Exception: - logger.exception("Failed to load resumption context for workflow run %s", workflow_run_id) - return - - generate_entity = resumption_context.get_generate_entity() - - graph_runtime_state = GraphRuntimeState.from_snapshot(resumption_context.serialized_graph_runtime_state) - - conversation = None - message = None - with Session(db.engine, expire_on_commit=False) as session: - workflow_run = session.get(WorkflowRun, workflow_run_id) - if workflow_run is None: - logger.warning("Workflow run %s not found during resume", workflow_run_id) - return - - workflow = session.get(Workflow, workflow_run.workflow_id) - if workflow is None: - logger.warning("Workflow %s not found during resume", workflow_run.workflow_id) - return - - app_model = session.get(App, workflow_run.app_id) - if app_model is None: - logger.warning("App %s not found during resume", workflow_run.app_id) - return - - user = _resolve_user_for_run(session, workflow_run) - if user is None: - logger.warning("User %s not found for workflow run %s", workflow_run.created_by, workflow_run_id) - return - - if isinstance(generate_entity, AdvancedChatAppGenerateEntity): - if generate_entity.conversation_id is None: - logger.warning("Conversation id missing in resumption context for workflow run %s", workflow_run_id) - return - - conversation = session.get(Conversation, generate_entity.conversation_id) - if conversation is None: - logger.warning( - "Conversation %s not found for workflow run %s", generate_entity.conversation_id, workflow_run_id - ) - return - - message = session.scalar( - select(Message).where(Message.workflow_run_id == workflow_run_id).order_by(Message.created_at.desc()) - ) - if message is None: - logger.warning("Message not found for workflow run %s", workflow_run_id) - return - - if not isinstance(generate_entity, (AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity)): - logger.error( - "Unsupported resumption entity for workflow run %s (found %s)", - workflow_run_id, - type(generate_entity), - ) - return - - workflow_run_repo.resume_workflow_pause(workflow_run_id, pause_entity) - - pause_config = PauseStateLayerConfig( - session_factory=session_factory, - state_owner_user_id=workflow.created_by, - ) - - if isinstance(generate_entity, AdvancedChatAppGenerateEntity): - assert conversation is not None - assert message is not None - _resume_advanced_chat( - app_model=app_model, - workflow=workflow, - user=user, - conversation=conversation, - message=message, - generate_entity=generate_entity, - graph_runtime_state=graph_runtime_state, - session_factory=session_factory, - pause_state_config=pause_config, - workflow_run_id=workflow_run_id, - workflow_run=workflow_run, - ) - elif isinstance(generate_entity, WorkflowAppGenerateEntity): - _resume_workflow( - app_model=app_model, - workflow=workflow, - user=user, - generate_entity=generate_entity, - graph_runtime_state=graph_runtime_state, - session_factory=session_factory, - pause_state_config=pause_config, - workflow_run_id=workflow_run_id, - workflow_run=workflow_run, - workflow_run_repo=workflow_run_repo, - pause_entity=pause_entity, - ) - - -def _resume_advanced_chat( - *, - app_model: App, - workflow: Workflow, - user: Account | EndUser, - conversation: Conversation, - message: Message, - generate_entity: AdvancedChatAppGenerateEntity, - graph_runtime_state: GraphRuntimeState, - session_factory: sessionmaker, - pause_state_config: PauseStateLayerConfig, - workflow_run_id: str, - workflow_run: WorkflowRun, -) -> None: - try: - triggered_from = WorkflowRunTriggeredFrom(workflow_run.triggered_from) - except ValueError: - triggered_from = WorkflowRunTriggeredFrom.APP_RUN - - workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( - session_factory=session_factory, - user=user, - app_id=app_model.id, - triggered_from=triggered_from, - ) - workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( - session_factory=session_factory, - user=user, - app_id=app_model.id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, - ) - - generator = AdvancedChatAppGenerator() - - try: - response = generator.resume( - app_model=app_model, - workflow=workflow, - user=user, - conversation=conversation, - message=message, - application_generate_entity=generate_entity, - workflow_execution_repository=workflow_execution_repository, - workflow_node_execution_repository=workflow_node_execution_repository, - graph_runtime_state=graph_runtime_state, - pause_state_config=pause_state_config, - ) - except Exception: - logger.exception("Failed to resume chatflow execution for workflow run %s", workflow_run_id) - raise - - if generate_entity.stream: - assert isinstance(response, Generator) - _publish_streaming_response(response, workflow_run_id, AppMode.ADVANCED_CHAT) - - -def _resume_workflow( - *, - app_model: App, - workflow: Workflow, - user: Account | EndUser, - generate_entity: WorkflowAppGenerateEntity, - graph_runtime_state: GraphRuntimeState, - session_factory: sessionmaker, - pause_state_config: PauseStateLayerConfig, - workflow_run_id: str, - workflow_run: WorkflowRun, - workflow_run_repo, - pause_entity, -) -> None: - try: - triggered_from = WorkflowRunTriggeredFrom(workflow_run.triggered_from) - except ValueError: - triggered_from = WorkflowRunTriggeredFrom.APP_RUN - - workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( - session_factory=session_factory, - user=user, - app_id=app_model.id, - triggered_from=triggered_from, - ) - workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( - session_factory=session_factory, - user=user, - app_id=app_model.id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, - ) - - generator = WorkflowAppGenerator() - - try: - response = generator.resume( - app_model=app_model, - workflow=workflow, - user=user, - application_generate_entity=generate_entity, - graph_runtime_state=graph_runtime_state, - workflow_execution_repository=workflow_execution_repository, - workflow_node_execution_repository=workflow_node_execution_repository, - pause_state_config=pause_state_config, - ) - except Exception: - logger.exception("Failed to resume workflow execution for workflow run %s", workflow_run_id) - raise - - if generate_entity.stream: - assert isinstance(response, Generator) - _publish_streaming_response(response, workflow_run_id, AppMode.WORKFLOW) - - workflow_run_repo.delete_workflow_pause(pause_entity) - - -@shared_task(queue=WORKFLOW_BASED_APP_EXECUTION_QUEUE, name="resume_app_execution") -def resume_app_execution(payload: dict[str, Any]) -> None: - _resume_app_execution(payload) diff --git a/api/tasks/async_workflow_tasks.py b/api/tasks/async_workflow_tasks.py index cc96542d4b..b51884148e 100644 --- a/api/tasks/async_workflow_tasks.py +++ b/api/tasks/async_workflow_tasks.py @@ -5,42 +5,32 @@ These tasks handle workflow execution for different subscription tiers with appropriate retry policies and error handling. """ -import logging from datetime import UTC, datetime from typing import Any from celery import shared_task from sqlalchemy import select -from sqlalchemy.orm import Session, sessionmaker +from sqlalchemy.orm import Session from configs import dify_config from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY, WorkflowAppGenerator -from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity -from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, WorkflowResumptionContext -from core.app.layers.timeslice_layer import TimeSliceLayer +from core.app.entities.app_invoke_entities import InvokeFrom from core.app.layers.trigger_post_layer import TriggerPostLayer from core.db.session_factory import session_factory -from core.repositories import DifyCoreRepositoryFactory -from core.workflow.runtime import GraphRuntimeState -from extensions.ext_database import db from models.account import Account -from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom, WorkflowTriggerStatus +from models.enums import CreatorUserRole, WorkflowTriggerStatus from models.model import App, EndUser, Tenant from models.trigger import WorkflowTriggerLog -from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom, WorkflowRun -from repositories.factory import DifyAPIRepositoryFactory +from models.workflow import Workflow from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository from services.errors.app import WorkflowNotFoundError from services.workflow.entities import ( TriggerData, - WorkflowResumeTaskData, WorkflowTaskData, ) from tasks.workflow_cfs_scheduler.cfs_scheduler import AsyncWorkflowCFSPlanEntity, AsyncWorkflowCFSPlanScheduler from tasks.workflow_cfs_scheduler.entities import AsyncWorkflowQueue, AsyncWorkflowSystemStrategy -logger = logging.getLogger(__name__) - @shared_task(queue=AsyncWorkflowQueue.PROFESSIONAL_QUEUE) def execute_workflow_professional(task_data_dict: dict[str, Any]): @@ -151,11 +141,6 @@ def _execute_workflow_common( if trigger_data.workflow_id: args["workflow_id"] = str(trigger_data.workflow_id) - pause_config = PauseStateLayerConfig( - session_factory=session_factory.get_session_maker(), - state_owner_user_id=workflow.created_by, - ) - # Execute the workflow with the trigger type generator.generate( app_model=app_model, @@ -171,7 +156,6 @@ def _execute_workflow_common( # TODO: Re-enable TimeSliceLayer after the HITL release. TriggerPostLayer(cfs_plan_scheduler_entity, start_time, trigger_log.id), ], - pause_state_config=pause_config, ) except Exception as e: @@ -189,153 +173,21 @@ def _execute_workflow_common( session.commit() -@shared_task(name="resume_workflow_execution") -def resume_workflow_execution(task_data_dict: dict[str, Any]) -> None: - """Resume a paused workflow run via Celery.""" - task_data = WorkflowResumeTaskData.model_validate(task_data_dict) - session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) - workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_factory) - - pause_entity = workflow_run_repo.get_workflow_pause(task_data.workflow_run_id) - if pause_entity is None: - logger.warning("No pause state for workflow run %s", task_data.workflow_run_id) - return - workflow_run = workflow_run_repo.get_workflow_run_by_id_without_tenant(pause_entity.workflow_execution_id) - if workflow_run is None: - logger.warning("Workflow run not found for pause entity: pause_entity_id=%s", pause_entity.id) - return - - try: - resumption_context = WorkflowResumptionContext.loads(pause_entity.get_state().decode()) - except Exception as exc: - logger.exception("Failed to load resumption context for workflow run %s", task_data.workflow_run_id) - raise exc - - generate_entity = resumption_context.get_generate_entity() - if not isinstance(generate_entity, WorkflowAppGenerateEntity): - logger.error( - "Unsupported resumption entity for workflow run %s: %s", - task_data.workflow_run_id, - type(generate_entity), - ) - return - - graph_runtime_state = GraphRuntimeState.from_snapshot(resumption_context.serialized_graph_runtime_state) - - with session_factory() as session: - workflow = session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id)) - if workflow is None: - raise WorkflowNotFoundError( - "Workflow not found: workflow_run_id=%s, workflow_id=%s", workflow_run.id, workflow_run.workflow_id - ) - user = _get_user(session, workflow_run) - app_model = session.scalar(select(App).where(App.id == workflow_run.app_id)) - if app_model is None: - raise _AppNotFoundError( - "App not found: app_id=%s, workflow_run_id=%s", workflow_run.app_id, workflow_run.id - ) - - workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( - session_factory=session_factory, - user=user, - app_id=generate_entity.app_config.app_id, - triggered_from=WorkflowRunTriggeredFrom(workflow_run.triggered_from), - ) - workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( - session_factory=session_factory, - user=user, - app_id=generate_entity.app_config.app_id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, - ) - - pause_config = PauseStateLayerConfig( - session_factory=session_factory, - state_owner_user_id=workflow.created_by, - ) - - generator = WorkflowAppGenerator() - start_time = datetime.now(UTC) - graph_engine_layers = [] - trigger_log = _query_trigger_log_info(session_factory, task_data.workflow_run_id) - - if trigger_log: - cfs_plan_scheduler_entity = AsyncWorkflowCFSPlanEntity( - queue=AsyncWorkflowQueue(trigger_log.queue_name), - schedule_strategy=AsyncWorkflowSystemStrategy, - granularity=dify_config.ASYNC_WORKFLOW_SCHEDULER_GRANULARITY, - ) - cfs_plan_scheduler = AsyncWorkflowCFSPlanScheduler(plan=cfs_plan_scheduler_entity) - - graph_engine_layers.extend( - [ - TimeSliceLayer(cfs_plan_scheduler), - TriggerPostLayer(cfs_plan_scheduler_entity, start_time, trigger_log.id), - ] - ) - - workflow_run_repo.resume_workflow_pause(task_data.workflow_run_id, pause_entity) - - generator.resume( - app_model=app_model, - workflow=workflow, - user=user, - application_generate_entity=generate_entity, - graph_runtime_state=graph_runtime_state, - workflow_execution_repository=workflow_execution_repository, - workflow_node_execution_repository=workflow_node_execution_repository, - graph_engine_layers=graph_engine_layers, - pause_state_config=pause_config, - ) - workflow_run_repo.delete_workflow_pause(pause_entity) - - -def _get_user(session: Session, workflow_run: WorkflowRun | WorkflowTriggerLog) -> Account | EndUser: +def _get_user(session: Session, trigger_log: WorkflowTriggerLog) -> Account | EndUser: """Compose user from trigger log""" - tenant = session.scalar(select(Tenant).where(Tenant.id == workflow_run.tenant_id)) + tenant = session.scalar(select(Tenant).where(Tenant.id == trigger_log.tenant_id)) if not tenant: - raise _TenantNotFoundError( - "Tenant not found for WorkflowRun: tenant_id=%s, workflow_run_id=%s", - workflow_run.tenant_id, - workflow_run.id, - ) + raise ValueError(f"Tenant not found: {trigger_log.tenant_id}") # Get user from trigger log - if workflow_run.created_by_role == CreatorUserRole.ACCOUNT: - user = session.scalar(select(Account).where(Account.id == workflow_run.created_by)) + if trigger_log.created_by_role == CreatorUserRole.ACCOUNT: + user = session.scalar(select(Account).where(Account.id == trigger_log.created_by)) if user: user.current_tenant = tenant else: # CreatorUserRole.END_USER - user = session.scalar(select(EndUser).where(EndUser.id == workflow_run.created_by)) + user = session.scalar(select(EndUser).where(EndUser.id == trigger_log.created_by)) if not user: - raise _UserNotFoundError( - "User not found: user_id=%s, created_by_role=%s, workflow_run_id=%s", - workflow_run.created_by, - workflow_run.created_by_role, - workflow_run.id, - ) + raise ValueError(f"User not found: {trigger_log.created_by} (role: {trigger_log.created_by_role})") return user - - -def _query_trigger_log_info(session_factory: sessionmaker[Session], workflow_run_id) -> WorkflowTriggerLog | None: - with session_factory() as session, session.begin(): - trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session) - trigger_log = trigger_log_repo.get_by_workflow_run_id(workflow_run_id) - if not trigger_log: - logger.debug("Trigger log not found for workflow_run: workflow_run_id=%s", workflow_run_id) - return None - - return trigger_log - - -class _TenantNotFoundError(Exception): - pass - - -class _UserNotFoundError(Exception): - pass - - -class _AppNotFoundError(Exception): - pass diff --git a/api/tasks/human_input_timeout_tasks.py b/api/tasks/human_input_timeout_tasks.py deleted file mode 100644 index 0c40877309..0000000000 --- a/api/tasks/human_input_timeout_tasks.py +++ /dev/null @@ -1,113 +0,0 @@ -import logging -from datetime import timedelta - -from celery import shared_task -from sqlalchemy import or_, select -from sqlalchemy.orm import sessionmaker - -from configs import dify_config -from core.repositories.human_input_repository import HumanInputFormSubmissionRepository -from core.workflow.enums import WorkflowExecutionStatus -from core.workflow.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus -from extensions.ext_database import db -from extensions.ext_storage import storage -from libs.datetime_utils import ensure_naive_utc, naive_utc_now -from models.human_input import HumanInputForm -from models.workflow import WorkflowPause, WorkflowRun -from services.human_input_service import HumanInputService - -logger = logging.getLogger(__name__) - - -def _is_global_timeout(form_model: HumanInputForm, global_timeout_seconds: int, *, now) -> bool: - if global_timeout_seconds <= 0: - return False - if form_model.workflow_run_id is None: - return False - created_at = ensure_naive_utc(form_model.created_at) - global_deadline = created_at + timedelta(seconds=global_timeout_seconds) - return global_deadline <= now - - -def _handle_global_timeout(*, form_id: str, workflow_run_id: str, node_id: str, session_factory: sessionmaker) -> None: - now = naive_utc_now() - with session_factory() as session, session.begin(): - workflow_run = session.get(WorkflowRun, workflow_run_id) - if workflow_run is not None: - workflow_run.status = WorkflowExecutionStatus.STOPPED - workflow_run.error = f"Human input global timeout at node {node_id}" - workflow_run.finished_at = now - session.add(workflow_run) - - pause_model = session.scalar(select(WorkflowPause).where(WorkflowPause.workflow_run_id == workflow_run_id)) - if pause_model is not None: - try: - storage.delete(pause_model.state_object_key) - except Exception: - logger.exception( - "Failed to delete pause state object for workflow_run_id=%s, pause_id=%s", - workflow_run_id, - pause_model.id, - ) - pause_model.resumed_at = now - session.add(pause_model) - - -@shared_task(name="human_input_form_timeout.check_and_resume", queue="schedule_executor") -def check_and_handle_human_input_timeouts(limit: int = 100) -> None: - """Scan for expired human input forms and resume or end workflows.""" - - session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) - form_repo = HumanInputFormSubmissionRepository(session_factory) - service = HumanInputService(session_factory, form_repository=form_repo) - now = naive_utc_now() - global_timeout_seconds = dify_config.HITL_GLOBAL_TIMEOUT_SECONDS - - with session_factory() as session: - global_deadline = now - timedelta(seconds=global_timeout_seconds) if global_timeout_seconds > 0 else None - timeout_filter = HumanInputForm.expiration_time <= now - if global_deadline is not None: - timeout_filter = or_(timeout_filter, HumanInputForm.created_at <= global_deadline) - stmt = ( - select(HumanInputForm) - .where( - HumanInputForm.status == HumanInputFormStatus.WAITING, - timeout_filter, - ) - .order_by(HumanInputForm.id.asc()) - .limit(limit) - ) - expired_forms = session.scalars(stmt).all() - - for form_model in expired_forms: - try: - if form_model.form_kind == HumanInputFormKind.DELIVERY_TEST: - form_repo.mark_timeout( - form_id=form_model.id, - timeout_status=HumanInputFormStatus.TIMEOUT, - reason="delivery_test_timeout", - ) - continue - - is_global = _is_global_timeout(form_model, global_timeout_seconds, now=now) - record = form_repo.mark_timeout( - form_id=form_model.id, - timeout_status=HumanInputFormStatus.EXPIRED if is_global else HumanInputFormStatus.TIMEOUT, - reason="global_timeout" if is_global else "node_timeout", - ) - assert record.workflow_run_id is not None, "workflow_run_id should not be None for non-test form" - if is_global: - _handle_global_timeout( - form_id=record.form_id, - workflow_run_id=record.workflow_run_id, - node_id=record.node_id, - session_factory=session_factory, - ) - else: - service.enqueue_resume(record.workflow_run_id) - except Exception: - logger.exception( - "Failed to handle timeout for form_id=%s workflow_run_id=%s", - form_model.id, - form_model.workflow_run_id, - ) diff --git a/api/tasks/mail_human_input_delivery_task.py b/api/tasks/mail_human_input_delivery_task.py deleted file mode 100644 index d1cd0fbadc..0000000000 --- a/api/tasks/mail_human_input_delivery_task.py +++ /dev/null @@ -1,190 +0,0 @@ -import json -import logging -import time -from dataclasses import dataclass -from typing import Any - -import click -from celery import shared_task -from sqlalchemy import select -from sqlalchemy.orm import Session, sessionmaker - -from configs import dify_config -from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext -from core.workflow.nodes.human_input.entities import EmailDeliveryConfig, EmailDeliveryMethod -from core.workflow.runtime import GraphRuntimeState, VariablePool -from extensions.ext_database import db -from extensions.ext_mail import mail -from models.human_input import ( - DeliveryMethodType, - HumanInputDelivery, - HumanInputForm, - HumanInputFormRecipient, - RecipientType, -) -from repositories.factory import DifyAPIRepositoryFactory -from services.feature_service import FeatureService - -logger = logging.getLogger(__name__) - - -@dataclass(frozen=True) -class _EmailRecipient: - email: str - token: str - - -@dataclass(frozen=True) -class _EmailDeliveryJob: - form_id: str - subject: str - body: str - form_content: str - recipients: list[_EmailRecipient] - - -def _build_form_link(token: str) -> str: - base_url = dify_config.APP_WEB_URL - return f"{base_url.rstrip('/')}/form/{token}" - - -def _parse_recipient_payload(payload: str) -> tuple[str | None, RecipientType | None]: - try: - payload_dict: dict[str, Any] = json.loads(payload) - except Exception: - logger.exception("Failed to parse recipient payload") - return None, None - - return payload_dict.get("email"), payload_dict.get("TYPE") - - -def _load_email_jobs(session: Session, form: HumanInputForm) -> list[_EmailDeliveryJob]: - deliveries = session.scalars( - select(HumanInputDelivery).where( - HumanInputDelivery.form_id == form.id, - HumanInputDelivery.delivery_method_type == DeliveryMethodType.EMAIL, - ) - ).all() - jobs: list[_EmailDeliveryJob] = [] - for delivery in deliveries: - delivery_config = EmailDeliveryMethod.model_validate_json(delivery.channel_payload) - - recipients = session.scalars( - select(HumanInputFormRecipient).where(HumanInputFormRecipient.delivery_id == delivery.id) - ).all() - - recipient_entities: list[_EmailRecipient] = [] - for recipient in recipients: - email, recipient_type = _parse_recipient_payload(recipient.recipient_payload) - if recipient_type not in {RecipientType.EMAIL_MEMBER, RecipientType.EMAIL_EXTERNAL}: - continue - if not email: - continue - token = recipient.access_token - if not token: - continue - recipient_entities.append(_EmailRecipient(email=email, token=token)) - - if not recipient_entities: - continue - - jobs.append( - _EmailDeliveryJob( - form_id=form.id, - subject=delivery_config.config.subject, - body=delivery_config.config.body, - form_content=form.rendered_content, - recipients=recipient_entities, - ) - ) - return jobs - - -def _render_body( - body_template: str, - form_link: str, - *, - variable_pool: VariablePool | None, -) -> str: - body = EmailDeliveryConfig.render_body_template( - body=body_template, - url=form_link, - variable_pool=variable_pool, - ) - return body - - -def _load_variable_pool(workflow_run_id: str | None) -> VariablePool | None: - if not workflow_run_id: - return None - - session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) - workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_factory) - pause_entity = workflow_run_repo.get_workflow_pause(workflow_run_id) - if pause_entity is None: - logger.info("No pause state found for workflow run %s", workflow_run_id) - return None - - try: - resumption_context = WorkflowResumptionContext.loads(pause_entity.get_state().decode()) - except Exception: - logger.exception("Failed to load resumption context for workflow run %s", workflow_run_id) - return None - - graph_runtime_state = GraphRuntimeState.from_snapshot(resumption_context.serialized_graph_runtime_state) - return graph_runtime_state.variable_pool - - -def _open_session(session_factory: sessionmaker | Session | None): - if session_factory is None: - return Session(db.engine) - if isinstance(session_factory, Session): - return session_factory - return session_factory() - - -@shared_task(queue="mail") -def dispatch_human_input_email_task(form_id: str, node_title: str | None = None, session_factory=None): - if not mail.is_inited(): - return - - logger.info(click.style(f"Start human input email delivery for form {form_id}", fg="green")) - start_at = time.perf_counter() - - try: - with _open_session(session_factory) as session: - form = session.get(HumanInputForm, form_id) - if form is None: - logger.warning("Human input form not found, form_id=%s", form_id) - return - features = FeatureService.get_features(form.tenant_id) - if not features.human_input_email_delivery_enabled: - logger.info( - "Human input email delivery is not available for tenant=%s, form_id=%s", - form.tenant_id, - form_id, - ) - return - jobs = _load_email_jobs(session, form) - - variable_pool = _load_variable_pool(form.workflow_run_id) - - for job in jobs: - for recipient in job.recipients: - form_link = _build_form_link(recipient.token) - body = _render_body(job.body, form_link, variable_pool=variable_pool) - - mail.send( - to=recipient.email, - subject=job.subject, - html=body, - ) - - end_at = time.perf_counter() - logger.info( - click.style( - f"Human input email delivery succeeded for form {form_id}: latency: {end_at - start_at}", fg="green" - ) - ) - except Exception: - logger.exception("Send human input email failed, form_id=%s", form_id) diff --git a/api/tests/integration_tests/conftest.py b/api/tests/integration_tests/conftest.py index 44adadeaa5..948cf8b3a0 100644 --- a/api/tests/integration_tests/conftest.py +++ b/api/tests/integration_tests/conftest.py @@ -1,4 +1,3 @@ -import logging import os import pathlib import random @@ -11,34 +10,26 @@ from flask.testing import FlaskClient from sqlalchemy.orm import Session from app_factory import create_app -from configs.app_config import DifyConfig from extensions.ext_database import db from models import Account, DifySetup, Tenant, TenantAccountJoin from services.account_service import AccountService, RegisterService -_DEFUALT_TEST_ENV = ".env" -_DEFAULT_VDB_TEST_ENV = "vdb.env" - -_logger = logging.getLogger(__name__) - # Loading the .env file if it exists def _load_env(): current_file_path = pathlib.Path(__file__).absolute() # Items later in the list have higher precedence. - env_file_paths = [ - os.getenv("DIFY_TEST_ENV_FILE", str(current_file_path.parent / _DEFUALT_TEST_ENV)), - os.getenv("DIFY_VDB_TEST_ENV_FILE", str(current_file_path.parent / _DEFAULT_VDB_TEST_ENV)), - ] + files_to_load = [".env", "vdb.env"] - for env_path_str in env_file_paths: - if not pathlib.Path(env_path_str).exists(): - _logger.warning("specified configuration file %s not exist", env_path_str) + env_file_paths = [current_file_path.parent / i for i in files_to_load] + for path in env_file_paths: + if not path.exists(): + continue from dotenv import load_dotenv # Set `override=True` to ensure values from `vdb.env` take priority over values from `.env` - load_dotenv(str(env_path_str), override=True) + load_dotenv(str(path), override=True) _load_env() @@ -50,12 +41,6 @@ os.environ.setdefault("OPENDAL_SCHEME", "fs") _CACHED_APP = create_app() -@pytest.fixture(scope="session") -def dify_config() -> DifyConfig: - config = DifyConfig() # type: ignore - return config - - @pytest.fixture def flask_app() -> Flask: return _CACHED_APP diff --git a/api/tests/integration_tests/libs/broadcast_channel/redis/utils/__init__.py b/api/tests/integration_tests/libs/broadcast_channel/redis/utils/__init__.py deleted file mode 100644 index e3f0d8a96e..0000000000 --- a/api/tests/integration_tests/libs/broadcast_channel/redis/utils/__init__.py +++ /dev/null @@ -1,36 +0,0 @@ -""" -Utilities and helpers for Redis broadcast channel integration tests. - -This module provides utility classes and functions for testing -Redis broadcast channel functionality. -""" - -from .test_data import ( - LARGE_MESSAGES, - SMALL_MESSAGES, - SPECIAL_MESSAGES, - BufferTestConfig, - ConcurrencyTestConfig, - ErrorTestConfig, -) -from .test_helpers import ( - ConcurrentPublisher, - SubscriptionMonitor, - assert_message_order, - measure_throughput, - wait_for_condition, -) - -__all__ = [ - "LARGE_MESSAGES", - "SMALL_MESSAGES", - "SPECIAL_MESSAGES", - "BufferTestConfig", - "ConcurrencyTestConfig", - "ConcurrentPublisher", - "ErrorTestConfig", - "SubscriptionMonitor", - "assert_message_order", - "measure_throughput", - "wait_for_condition", -] diff --git a/api/tests/integration_tests/libs/broadcast_channel/redis/utils/test_data.py b/api/tests/integration_tests/libs/broadcast_channel/redis/utils/test_data.py deleted file mode 100644 index 2cccb08304..0000000000 --- a/api/tests/integration_tests/libs/broadcast_channel/redis/utils/test_data.py +++ /dev/null @@ -1,315 +0,0 @@ -""" -Test data and configuration classes for Redis broadcast channel integration tests. - -This module provides dataclasses and constants for test configurations, -message sets, and test scenarios. -""" - -import dataclasses -from typing import Any - -from libs.broadcast_channel.channel import Overflow - - -@dataclasses.dataclass(frozen=True) -class BufferTestConfig: - """Configuration for buffer management tests.""" - - buffer_size: int - overflow_strategy: Overflow - message_count: int - expected_behavior: str - description: str - - -@dataclasses.dataclass(frozen=True) -class ConcurrencyTestConfig: - """Configuration for concurrency tests.""" - - publisher_count: int - subscriber_count: int - messages_per_publisher: int - test_duration: float - description: str - - -@dataclasses.dataclass(frozen=True) -class ErrorTestConfig: - """Configuration for error handling tests.""" - - error_type: str - test_input: Any - expected_exception: type[Exception] - description: str - - -# Test message sets for different scenarios -SMALL_MESSAGES = [ - b"msg_1", - b"msg_2", - b"msg_3", - b"msg_4", - b"msg_5", -] - -MEDIUM_MESSAGES = [ - b"medium_message_1_with_more_content", - b"medium_message_2_with_more_content", - b"medium_message_3_with_more_content", - b"medium_message_4_with_more_content", - b"medium_message_5_with_more_content", -] - -LARGE_MESSAGES = [ - b"large_message_" + b"x" * 1000, - b"large_message_" + b"y" * 1000, - b"large_message_" + b"z" * 1000, -] - -VERY_LARGE_MESSAGES = [ - b"very_large_message_" + b"x" * 10000, # ~10KB - b"very_large_message_" + b"y" * 50000, # ~50KB - b"very_large_message_" + b"z" * 100000, # ~100KB -] - -SPECIAL_MESSAGES = [ - b"", # Empty message - b"\x00\x01\x02", # Binary data with null bytes - "unicode_test_你好".encode(), # Unicode - b"special_chars_!@#$%^&*()_+-=[]{}|;':\",./<>?", # Special characters - b"newlines\n\r\t", # Control characters -] - -BINARY_MESSAGES = [ - bytes(range(256)), # All possible byte values - b"\xff\xfe\xfd\xfc\xfb\xfa\xf9\xf8", # High byte values - b"\x00\x01\x02\x03\x04\x05\x06\x07", # Low byte values -] - -# Buffer test configurations -BUFFER_TEST_CONFIGS = [ - BufferTestConfig( - buffer_size=3, - overflow_strategy=Overflow.DROP_OLDEST, - message_count=5, - expected_behavior="drop_oldest", - description="Drop oldest messages when buffer is full", - ), - BufferTestConfig( - buffer_size=3, - overflow_strategy=Overflow.DROP_NEWEST, - message_count=5, - expected_behavior="drop_newest", - description="Drop newest messages when buffer is full", - ), - BufferTestConfig( - buffer_size=3, - overflow_strategy=Overflow.BLOCK, - message_count=5, - expected_behavior="block", - description="Block when buffer is full", - ), -] - -# Concurrency test configurations -CONCURRENCY_TEST_CONFIGS = [ - ConcurrencyTestConfig( - publisher_count=1, - subscriber_count=1, - messages_per_publisher=10, - test_duration=5.0, - description="Single publisher, single subscriber", - ), - ConcurrencyTestConfig( - publisher_count=3, - subscriber_count=1, - messages_per_publisher=10, - test_duration=5.0, - description="Multiple publishers, single subscriber", - ), - ConcurrencyTestConfig( - publisher_count=1, - subscriber_count=3, - messages_per_publisher=10, - test_duration=5.0, - description="Single publisher, multiple subscribers", - ), - ConcurrencyTestConfig( - publisher_count=3, - subscriber_count=3, - messages_per_publisher=10, - test_duration=5.0, - description="Multiple publishers, multiple subscribers", - ), -] - -# Error test configurations -ERROR_TEST_CONFIGS = [ - ErrorTestConfig( - error_type="invalid_buffer_size", - test_input=0, - expected_exception=ValueError, - description="Zero buffer size should raise ValueError", - ), - ErrorTestConfig( - error_type="invalid_buffer_size", - test_input=-1, - expected_exception=ValueError, - description="Negative buffer size should raise ValueError", - ), - ErrorTestConfig( - error_type="invalid_buffer_size", - test_input=1.5, - expected_exception=TypeError, - description="Float buffer size should raise TypeError", - ), - ErrorTestConfig( - error_type="invalid_buffer_size", - test_input="invalid", - expected_exception=TypeError, - description="String buffer size should raise TypeError", - ), -] - -# Topic name test cases -TOPIC_NAME_TEST_CASES = [ - "simple_topic", - "topic_with_underscores", - "topic-with-dashes", - "topic.with.dots", - "topic_with_numbers_123", - "UPPERCASE_TOPIC", - "mixed_Case_Topic", - "topic_with_symbols_!@#$%", - "very_long_topic_name_" + "x" * 100, - "unicode_topic_你好", - "topic:with:colons", - "topic/with/slashes", - "topic\\with\\backslashes", -] - -# Performance test configurations -PERFORMANCE_TEST_CONFIGS = [ - { - "name": "small_messages_high_frequency", - "message_size": 50, - "message_count": 1000, - "description": "Many small messages", - }, - { - "name": "medium_messages_medium_frequency", - "message_size": 500, - "message_count": 100, - "description": "Medium messages", - }, - { - "name": "large_messages_low_frequency", - "message_size": 5000, - "message_count": 10, - "description": "Large messages", - }, -] - -# Stress test configurations -STRESS_TEST_CONFIGS = [ - { - "name": "high_frequency_publishing", - "publisher_count": 5, - "messages_per_publisher": 100, - "subscriber_count": 3, - "description": "High frequency publishing with multiple publishers", - }, - { - "name": "many_subscribers", - "publisher_count": 1, - "messages_per_publisher": 50, - "subscriber_count": 10, - "description": "Many subscribers to single publisher", - }, - { - "name": "mixed_load", - "publisher_count": 3, - "messages_per_publisher": 100, - "subscriber_count": 5, - "description": "Mixed load with multiple publishers and subscribers", - }, -] - -# Edge case test data -EDGE_CASE_MESSAGES = [ - b"", # Empty message - b"\x00", # Single null byte - b"\xff", # Single max byte value - b"a", # Single ASCII character - "ä".encode(), # Single unicode character (2 bytes) - "𐍈".encode(), # Unicode character outside BMP (4 bytes) - b"\x00" * 1000, # 1000 null bytes - b"\xff" * 1000, # 1000 max byte values -] - -# Message validation test data -MESSAGE_VALIDATION_TEST_CASES = [ - { - "name": "valid_bytes", - "input": b"valid_message", - "should_pass": True, - "description": "Valid bytes message", - }, - { - "name": "empty_bytes", - "input": b"", - "should_pass": True, - "description": "Empty bytes message", - }, - { - "name": "binary_data", - "input": bytes(range(256)), - "should_pass": True, - "description": "Binary data with all byte values", - }, - { - "name": "large_message", - "input": b"x" * 1000000, # 1MB - "should_pass": True, - "description": "Large message (1MB)", - }, -] - -# Redis connection test scenarios -REDIS_CONNECTION_TEST_SCENARIOS = [ - { - "name": "normal_connection", - "should_fail": False, - "description": "Normal Redis connection", - }, - { - "name": "connection_timeout", - "should_fail": True, - "description": "Connection timeout scenario", - }, - { - "name": "connection_refused", - "should_fail": True, - "description": "Connection refused scenario", - }, -] - -# Test constants -DEFAULT_TIMEOUT = 10.0 -SHORT_TIMEOUT = 2.0 -LONG_TIMEOUT = 30.0 - -# Message size limits for testing -MAX_SMALL_MESSAGE_SIZE = 100 -MAX_MEDIUM_MESSAGE_SIZE = 1000 -MAX_LARGE_MESSAGE_SIZE = 10000 - -# Thread counts for concurrency testing -MIN_THREAD_COUNT = 1 -MAX_THREAD_COUNT = 10 -DEFAULT_THREAD_COUNT = 3 - -# Buffer sizes for testing -MIN_BUFFER_SIZE = 1 -MAX_BUFFER_SIZE = 1000 -DEFAULT_BUFFER_SIZE = 10 diff --git a/api/tests/integration_tests/libs/broadcast_channel/redis/utils/test_helpers.py b/api/tests/integration_tests/libs/broadcast_channel/redis/utils/test_helpers.py deleted file mode 100644 index 65f3007b01..0000000000 --- a/api/tests/integration_tests/libs/broadcast_channel/redis/utils/test_helpers.py +++ /dev/null @@ -1,396 +0,0 @@ -""" -Test helper utilities for Redis broadcast channel integration tests. - -This module provides utility classes and functions for testing concurrent -operations, monitoring subscriptions, and measuring performance. -""" - -import logging -import threading -import time -from collections.abc import Callable -from typing import Any - -_logger = logging.getLogger(__name__) - - -class ConcurrentPublisher: - """ - Utility class for publishing messages concurrently from multiple threads. - - This class manages multiple publisher threads that can publish messages - to the same or different topics concurrently, useful for stress testing - and concurrency validation. - """ - - def __init__(self, producer, message_count: int = 10, delay: float = 0.0): - """ - Initialize the concurrent publisher. - - Args: - producer: The producer instance to publish with - message_count: Number of messages to publish per thread - delay: Delay between messages in seconds - """ - self.producer = producer - self.message_count = message_count - self.delay = delay - self.threads: list[threading.Thread] = [] - self.published_messages: list[list[bytes]] = [] - self._lock = threading.Lock() - self._started = False - - def start_publishers(self, thread_count: int = 3) -> None: - """ - Start multiple publisher threads. - - Args: - thread_count: Number of publisher threads to start - """ - if self._started: - raise RuntimeError("Publishers already started") - - self._started = True - - def _publisher(thread_id: int) -> None: - messages: list[bytes] = [] - for i in range(self.message_count): - message = f"thread_{thread_id}_msg_{i}".encode() - try: - self.producer.publish(message) - messages.append(message) - if self.delay > 0: - time.sleep(self.delay) - except Exception: - _logger.exception("Pubmsg=lisher %s", thread_id) - - with self._lock: - self.published_messages.append(messages) - - for thread_id in range(thread_count): - thread = threading.Thread( - target=_publisher, - args=(thread_id,), - name=f"publisher-{thread_id}", - daemon=True, - ) - thread.start() - self.threads.append(thread) - - def wait_for_completion(self, timeout: float = 30.0) -> bool: - """ - Wait for all publisher threads to complete. - - Args: - timeout: Maximum time to wait in seconds - - Returns: - bool: True if all threads completed successfully - """ - for thread in self.threads: - thread.join(timeout) - if thread.is_alive(): - return False - return True - - def get_all_messages(self) -> list[bytes]: - """ - Get all messages published by all threads. - - Returns: - list[bytes]: Flattened list of all published messages - """ - with self._lock: - all_messages = [] - for thread_messages in self.published_messages: - all_messages.extend(thread_messages) - return all_messages - - def get_thread_messages(self, thread_id: int) -> list[bytes]: - """ - Get messages published by a specific thread. - - Args: - thread_id: ID of the thread - - Returns: - list[bytes]: Messages published by the specified thread - """ - with self._lock: - if 0 <= thread_id < len(self.published_messages): - return self.published_messages[thread_id].copy() - return [] - - -class SubscriptionMonitor: - """ - Utility class for monitoring subscription activity in tests. - - This class monitors a subscription and tracks message reception, - errors, and completion status for testing purposes. - """ - - def __init__(self, subscription, timeout: float = 10.0): - """ - Initialize the subscription monitor. - - Args: - subscription: The subscription to monitor - timeout: Default timeout for operations - """ - self.subscription = subscription - self.timeout = timeout - self.messages: list[bytes] = [] - self.errors: list[Exception] = [] - self.completed = False - self._lock = threading.Lock() - self._condition = threading.Condition(self._lock) - self._monitor_thread: threading.Thread | None = None - self._start_time: float | None = None - - def start_monitoring(self) -> None: - """Start monitoring the subscription in a separate thread.""" - if self._monitor_thread is not None: - raise RuntimeError("Monitoring already started") - - self._start_time = time.time() - - def _monitor(): - try: - for message in self.subscription: - with self._lock: - self.messages.append(message) - self._condition.notify_all() - except Exception as e: - with self._lock: - self.errors.append(e) - self._condition.notify_all() - finally: - with self._lock: - self.completed = True - self._condition.notify_all() - - self._monitor_thread = threading.Thread( - target=_monitor, - name="subscription-monitor", - daemon=True, - ) - self._monitor_thread.start() - - def wait_for_messages(self, count: int, timeout: float | None = None) -> bool: - """ - Wait for a specific number of messages. - - Args: - count: Number of messages to wait for - timeout: Timeout in seconds (uses default if None) - - Returns: - bool: True if expected messages were received - """ - if timeout is None: - timeout = self.timeout - - deadline = time.time() + timeout - - with self._condition: - while len(self.messages) < count and not self.completed: - remaining = deadline - time.time() - if remaining <= 0: - return False - self._condition.wait(remaining) - - return len(self.messages) >= count - - def wait_for_completion(self, timeout: float | None = None) -> bool: - """ - Wait for monitoring to complete. - - Args: - timeout: Timeout in seconds (uses default if None) - - Returns: - bool: True if monitoring completed successfully - """ - if timeout is None: - timeout = self.timeout - - deadline = time.time() + timeout - - with self._condition: - while not self.completed: - remaining = deadline - time.time() - if remaining <= 0: - return False - self._condition.wait(remaining) - - return True - - def get_messages(self) -> list[bytes]: - """ - Get all received messages. - - Returns: - list[bytes]: Copy of received messages - """ - with self._lock: - return self.messages.copy() - - def get_error_count(self) -> int: - """ - Get the number of errors encountered. - - Returns: - int: Number of errors - """ - with self._lock: - return len(self.errors) - - def get_elapsed_time(self) -> float: - """ - Get the elapsed monitoring time. - - Returns: - float: Elapsed time in seconds - """ - if self._start_time is None: - return 0.0 - return time.time() - self._start_time - - def stop(self) -> None: - """Stop monitoring and close the subscription.""" - if self._monitor_thread is not None: - self.subscription.close() - self._monitor_thread.join(timeout=1.0) - - -def assert_message_order(received: list[bytes], expected: list[bytes]) -> bool: - """ - Assert that messages were received in the expected order. - - Args: - received: List of received messages - expected: List of expected messages in order - - Returns: - bool: True if order matches expected - """ - if len(received) != len(expected): - return False - - for i, (recv_msg, exp_msg) in enumerate(zip(received, expected)): - if recv_msg != exp_msg: - _logger.error("Message order mismatch at index %s: expected %s, got %s", i, exp_msg, recv_msg) - return False - - return True - - -def measure_throughput( - operation: Callable[[], Any], - duration: float = 1.0, -) -> tuple[float, int]: - """ - Measure the throughput of an operation over a specified duration. - - Args: - operation: The operation to measure - duration: Duration to run the operation in seconds - - Returns: - tuple[float, int]: (operations per second, total operations) - """ - start_time = time.time() - end_time = start_time + duration - count = 0 - - while time.time() < end_time: - try: - operation() - count += 1 - except Exception: - _logger.exception("Operation failed") - break - - elapsed = time.time() - start_time - ops_per_sec = count / elapsed if elapsed > 0 else 0.0 - - return ops_per_sec, count - - -def wait_for_condition( - condition: Callable[[], bool], - timeout: float = 10.0, - interval: float = 0.1, -) -> bool: - """ - Wait for a condition to become true. - - Args: - condition: Function that returns True when condition is met - timeout: Maximum time to wait in seconds - interval: Check interval in seconds - - Returns: - bool: True if condition was met within timeout - """ - deadline = time.time() + timeout - - while time.time() < deadline: - if condition(): - return True - time.sleep(interval) - - return False - - -def create_stress_test_messages( - count: int, - size: int = 100, -) -> list[bytes]: - """ - Create messages for stress testing. - - Args: - count: Number of messages to create - size: Size of each message in bytes - - Returns: - list[bytes]: List of test messages - """ - messages = [] - for i in range(count): - message = f"stress_test_msg_{i:06d}_".ljust(size, "x").encode() - messages.append(message) - return messages - - -def validate_message_integrity( - original_messages: list[bytes], - received_messages: list[bytes], -) -> dict[str, Any]: - """ - Validate the integrity of received messages. - - Args: - original_messages: Messages that were sent - received_messages: Messages that were received - - Returns: - dict[str, Any]: Validation results - """ - original_set = set(original_messages) - received_set = set(received_messages) - - missing_messages = original_set - received_set - extra_messages = received_set - original_set - - return { - "total_sent": len(original_messages), - "total_received": len(received_messages), - "missing_count": len(missing_messages), - "extra_count": len(extra_messages), - "missing_messages": list(missing_messages), - "extra_messages": list(extra_messages), - "integrity_ok": len(missing_messages) == 0 and len(extra_messages) == 0, - } diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py deleted file mode 100644 index 7fad603a6d..0000000000 --- a/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py +++ /dev/null @@ -1,166 +0,0 @@ -"""TestContainers integration tests for ChatConversationApi status_count behavior.""" - -import json -import uuid - -from flask.testing import FlaskClient -from sqlalchemy.orm import Session - -from configs import dify_config -from constants import HEADER_NAME_CSRF_TOKEN -from core.workflow.enums import WorkflowExecutionStatus -from libs.datetime_utils import naive_utc_now -from libs.token import _real_cookie_name, generate_csrf_token -from models import Account, DifySetup, Tenant, TenantAccountJoin -from models.account import AccountStatus, TenantAccountRole -from models.enums import CreatorUserRole -from models.model import App, AppMode, Conversation, Message -from models.workflow import WorkflowRun -from services.account_service import AccountService - - -def _create_account_and_tenant(db_session: Session) -> tuple[Account, Tenant]: - account = Account( - email=f"test-{uuid.uuid4()}@example.com", - name="Test User", - interface_language="en-US", - status=AccountStatus.ACTIVE, - ) - account.initialized_at = naive_utc_now() - db_session.add(account) - db_session.commit() - - tenant = Tenant(name="Test Tenant", status="normal") - db_session.add(tenant) - db_session.commit() - - join = TenantAccountJoin( - tenant_id=tenant.id, - account_id=account.id, - role=TenantAccountRole.OWNER, - current=True, - ) - db_session.add(join) - db_session.commit() - - account.set_tenant_id(tenant.id) - account.timezone = "UTC" - db_session.commit() - - dify_setup = DifySetup(version=dify_config.project.version) - db_session.add(dify_setup) - db_session.commit() - - return account, tenant - - -def _create_app(db_session: Session, tenant_id: str, account_id: str) -> App: - app = App( - tenant_id=tenant_id, - name="Test Chat App", - mode=AppMode.CHAT, - enable_site=True, - enable_api=True, - created_by=account_id, - ) - db_session.add(app) - db_session.commit() - return app - - -def _create_conversation(db_session: Session, app_id: str, account_id: str) -> Conversation: - conversation = Conversation( - app_id=app_id, - name="Test Conversation", - inputs={}, - status="normal", - mode=AppMode.CHAT, - from_source=CreatorUserRole.ACCOUNT, - from_account_id=account_id, - ) - db_session.add(conversation) - db_session.commit() - return conversation - - -def _create_workflow_run(db_session: Session, app_id: str, tenant_id: str, account_id: str) -> WorkflowRun: - workflow_run = WorkflowRun( - tenant_id=tenant_id, - app_id=app_id, - workflow_id=str(uuid.uuid4()), - type="chat", - triggered_from="app-run", - version="1.0.0", - graph=json.dumps({"nodes": [], "edges": []}), - inputs=json.dumps({"query": "test"}), - status=WorkflowExecutionStatus.PAUSED, - outputs=json.dumps({}), - elapsed_time=0.0, - total_tokens=0, - total_steps=0, - created_by_role=CreatorUserRole.ACCOUNT, - created_by=account_id, - created_at=naive_utc_now(), - ) - db_session.add(workflow_run) - db_session.commit() - return workflow_run - - -def _create_message( - db_session: Session, app_id: str, conversation_id: str, workflow_run_id: str, account_id: str -) -> Message: - message = Message( - app_id=app_id, - conversation_id=conversation_id, - query="Hello", - message={"type": "text", "content": "Hello"}, - answer="Hi there", - message_tokens=1, - answer_tokens=1, - message_unit_price=0.001, - answer_unit_price=0.001, - message_price_unit=0.001, - answer_price_unit=0.001, - currency="USD", - status="normal", - from_source=CreatorUserRole.ACCOUNT, - from_account_id=account_id, - workflow_run_id=workflow_run_id, - inputs={"query": "Hello"}, - ) - db_session.add(message) - db_session.commit() - return message - - -def test_chat_conversation_status_count_includes_paused( - db_session_with_containers: Session, - test_client_with_containers: FlaskClient, -): - account, tenant = _create_account_and_tenant(db_session_with_containers) - app = _create_app(db_session_with_containers, tenant.id, account.id) - conversation = _create_conversation(db_session_with_containers, app.id, account.id) - conversation_id = conversation.id - workflow_run = _create_workflow_run(db_session_with_containers, app.id, tenant.id, account.id) - _create_message(db_session_with_containers, app.id, conversation.id, workflow_run.id, account.id) - - access_token = AccountService.get_account_jwt_token(account) - csrf_token = generate_csrf_token(account.id) - cookie_name = _real_cookie_name("csrf_token") - - test_client_with_containers.set_cookie(cookie_name, csrf_token, domain="localhost") - response = test_client_with_containers.get( - f"/console/api/apps/{app.id}/chat-conversations", - headers={ - "Authorization": f"Bearer {access_token}", - HEADER_NAME_CSRF_TOKEN: csrf_token, - }, - ) - - assert response.status_code == 200 - payload = response.get_json() - assert payload is not None - assert payload["total"] == 1 - assert payload["data"][0]["id"] == conversation_id - assert payload["data"][0]["status_count"]["paused"] == 1 diff --git a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py deleted file mode 100644 index 079e4934bb..0000000000 --- a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py +++ /dev/null @@ -1,240 +0,0 @@ -"""TestContainers integration tests for HumanInputFormRepositoryImpl.""" - -from __future__ import annotations - -from uuid import uuid4 - -from sqlalchemy import Engine, select -from sqlalchemy.orm import Session - -from core.repositories.human_input_repository import HumanInputFormRepositoryImpl -from core.workflow.nodes.human_input.entities import ( - DeliveryChannelConfig, - EmailDeliveryConfig, - EmailDeliveryMethod, - EmailRecipients, - ExternalRecipient, - FormDefinition, - HumanInputNodeData, - MemberRecipient, - UserAction, - WebAppDeliveryMethod, -) -from core.workflow.repositories.human_input_form_repository import FormCreateParams -from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole -from models.human_input import ( - EmailExternalRecipientPayload, - EmailMemberRecipientPayload, - HumanInputForm, - HumanInputFormRecipient, - RecipientType, -) - - -def _create_tenant_with_members(session: Session, member_emails: list[str]) -> tuple[Tenant, list[Account]]: - tenant = Tenant(name="Test Tenant", status="normal") - session.add(tenant) - session.flush() - - members: list[Account] = [] - for index, email in enumerate(member_emails): - account = Account( - email=email, - name=f"Member {index}", - interface_language="en-US", - status="active", - ) - session.add(account) - session.flush() - - tenant_join = TenantAccountJoin( - tenant_id=tenant.id, - account_id=account.id, - role=TenantAccountRole.NORMAL, - current=True, - ) - session.add(tenant_join) - members.append(account) - - session.commit() - return tenant, members - - -def _build_form_params(delivery_methods: list[DeliveryChannelConfig]) -> FormCreateParams: - form_config = HumanInputNodeData( - title="Human Approval", - delivery_methods=delivery_methods, - form_content="

Approve?

", - user_actions=[UserAction(id="approve", title="Approve")], - ) - return FormCreateParams( - app_id=str(uuid4()), - workflow_execution_id=str(uuid4()), - node_id="human-input-node", - form_config=form_config, - rendered_content="

Approve?

", - delivery_methods=delivery_methods, - display_in_ui=False, - resolved_default_values={}, - ) - - -def _build_email_delivery( - whole_workspace: bool, recipients: list[MemberRecipient | ExternalRecipient] -) -> EmailDeliveryMethod: - return EmailDeliveryMethod( - config=EmailDeliveryConfig( - recipients=EmailRecipients(whole_workspace=whole_workspace, items=recipients), - subject="Approval Needed", - body="Please review", - ) - ) - - -class TestHumanInputFormRepositoryImplWithContainers: - def test_create_form_with_whole_workspace_recipients(self, db_session_with_containers: Session) -> None: - engine = db_session_with_containers.get_bind() - assert isinstance(engine, Engine) - tenant, members = _create_tenant_with_members( - db_session_with_containers, - member_emails=["member1@example.com", "member2@example.com"], - ) - - repository = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id) - params = _build_form_params( - delivery_methods=[_build_email_delivery(whole_workspace=True, recipients=[])], - ) - - form_entity = repository.create_form(params) - - with Session(engine) as verification_session: - recipients = verification_session.scalars( - select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id == form_entity.id) - ).all() - - assert len(recipients) == len(members) - member_payloads = [ - EmailMemberRecipientPayload.model_validate_json(recipient.recipient_payload) - for recipient in recipients - if recipient.recipient_type == RecipientType.EMAIL_MEMBER - ] - member_emails = {payload.email for payload in member_payloads} - assert member_emails == {member.email for member in members} - - def test_create_form_with_specific_members_and_external(self, db_session_with_containers: Session) -> None: - engine = db_session_with_containers.get_bind() - assert isinstance(engine, Engine) - tenant, members = _create_tenant_with_members( - db_session_with_containers, - member_emails=["primary@example.com", "secondary@example.com"], - ) - - repository = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id) - params = _build_form_params( - delivery_methods=[ - _build_email_delivery( - whole_workspace=False, - recipients=[ - MemberRecipient(user_id=members[0].id), - ExternalRecipient(email="external@example.com"), - ], - ) - ], - ) - - form_entity = repository.create_form(params) - - with Session(engine) as verification_session: - recipients = verification_session.scalars( - select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id == form_entity.id) - ).all() - - member_recipient_payloads = [ - EmailMemberRecipientPayload.model_validate_json(recipient.recipient_payload) - for recipient in recipients - if recipient.recipient_type == RecipientType.EMAIL_MEMBER - ] - assert len(member_recipient_payloads) == 1 - assert member_recipient_payloads[0].user_id == members[0].id - - external_payloads = [ - EmailExternalRecipientPayload.model_validate_json(recipient.recipient_payload) - for recipient in recipients - if recipient.recipient_type == RecipientType.EMAIL_EXTERNAL - ] - assert len(external_payloads) == 1 - assert external_payloads[0].email == "external@example.com" - - def test_create_form_persists_default_values(self, db_session_with_containers: Session) -> None: - engine = db_session_with_containers.get_bind() - assert isinstance(engine, Engine) - tenant, _ = _create_tenant_with_members( - db_session_with_containers, - member_emails=["prefill@example.com"], - ) - - repository = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id) - resolved_values = {"greeting": "Hello!"} - params = FormCreateParams( - app_id=str(uuid4()), - workflow_execution_id=str(uuid4()), - node_id="human-input-node", - form_config=HumanInputNodeData( - title="Human Approval", - form_content="

Approve?

", - inputs=[], - user_actions=[UserAction(id="approve", title="Approve")], - ), - rendered_content="

Approve?

", - delivery_methods=[], - display_in_ui=False, - resolved_default_values=resolved_values, - ) - - form_entity = repository.create_form(params) - - with Session(engine) as verification_session: - form_model = verification_session.scalars( - select(HumanInputForm).where(HumanInputForm.id == form_entity.id) - ).first() - - assert form_model is not None - definition = FormDefinition.model_validate_json(form_model.form_definition) - assert definition.default_values == resolved_values - - def test_create_form_persists_display_in_ui(self, db_session_with_containers: Session) -> None: - engine = db_session_with_containers.get_bind() - assert isinstance(engine, Engine) - tenant, _ = _create_tenant_with_members( - db_session_with_containers, - member_emails=["ui@example.com"], - ) - - repository = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id) - params = FormCreateParams( - app_id=str(uuid4()), - workflow_execution_id=str(uuid4()), - node_id="human-input-node", - form_config=HumanInputNodeData( - title="Human Approval", - form_content="

Approve?

", - inputs=[], - user_actions=[UserAction(id="approve", title="Approve")], - delivery_methods=[WebAppDeliveryMethod()], - ), - rendered_content="

Approve?

", - delivery_methods=[WebAppDeliveryMethod()], - display_in_ui=True, - resolved_default_values={}, - ) - - form_entity = repository.create_form(params) - - with Session(engine) as verification_session: - form_model = verification_session.scalars( - select(HumanInputForm).where(HumanInputForm.id == form_entity.id) - ).first() - - assert form_model is not None - definition = FormDefinition.model_validate_json(form_model.form_definition) - assert definition.display_in_ui is True diff --git a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py deleted file mode 100644 index 06d55177eb..0000000000 --- a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py +++ /dev/null @@ -1,336 +0,0 @@ -import time -import uuid -from datetime import timedelta -from unittest.mock import MagicMock - -import pytest -from sqlalchemy import delete, select -from sqlalchemy.orm import Session - -from core.app.app_config.entities import WorkflowUIBasedAppConfig -from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity -from core.app.workflow.layers import PersistenceWorkflowInfo, WorkflowPersistenceLayer -from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository -from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository -from core.workflow.entities import GraphInitParams -from core.workflow.enums import WorkflowType -from core.workflow.graph import Graph -from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from core.workflow.graph_engine.graph_engine import GraphEngine -from core.workflow.nodes.end.end_node import EndNode -from core.workflow.nodes.end.entities import EndNodeData -from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction -from core.workflow.nodes.human_input.enums import HumanInputFormStatus -from core.workflow.nodes.human_input.human_input_node import HumanInputNode -from core.workflow.nodes.start.entities import StartNodeData -from core.workflow.nodes.start.start_node import StartNode -from core.workflow.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from libs.datetime_utils import naive_utc_now -from models import Account -from models.account import Tenant, TenantAccountJoin, TenantAccountRole -from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom -from models.model import App, AppMode, IconType -from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, WorkflowRun - - -def _mock_form_repository_without_submission() -> HumanInputFormRepository: - repo = MagicMock(spec=HumanInputFormRepository) - form_entity = MagicMock(spec=HumanInputFormEntity) - form_entity.id = "test-form-id" - form_entity.web_app_token = "test-form-token" - form_entity.recipients = [] - form_entity.rendered_content = "rendered" - form_entity.submitted = False - repo.create_form.return_value = form_entity - repo.get_form.return_value = None - return repo - - -def _mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepository: - repo = MagicMock(spec=HumanInputFormRepository) - form_entity = MagicMock(spec=HumanInputFormEntity) - form_entity.id = "test-form-id" - form_entity.web_app_token = "test-form-token" - form_entity.recipients = [] - form_entity.rendered_content = "rendered" - form_entity.submitted = True - form_entity.selected_action_id = action_id - form_entity.submitted_data = {} - form_entity.status = HumanInputFormStatus.WAITING - form_entity.expiration_time = naive_utc_now() + timedelta(hours=1) - repo.get_form.return_value = form_entity - return repo - - -def _build_runtime_state(workflow_execution_id: str, app_id: str, workflow_id: str, user_id: str) -> GraphRuntimeState: - variable_pool = VariablePool( - system_variables=SystemVariable( - workflow_execution_id=workflow_execution_id, - app_id=app_id, - workflow_id=workflow_id, - user_id=user_id, - ), - user_inputs={}, - conversation_variables=[], - ) - return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - -def _build_graph( - runtime_state: GraphRuntimeState, - tenant_id: str, - app_id: str, - workflow_id: str, - user_id: str, - form_repository: HumanInputFormRepository, -) -> Graph: - graph_config: dict[str, object] = {"nodes": [], "edges": []} - params = GraphInitParams( - tenant_id=tenant_id, - app_id=app_id, - workflow_id=workflow_id, - graph_config=graph_config, - user_id=user_id, - user_from="account", - invoke_from="debugger", - call_depth=0, - ) - - start_data = StartNodeData(title="start", variables=[]) - start_node = StartNode( - id="start", - config={"id": "start", "data": start_data.model_dump()}, - graph_init_params=params, - graph_runtime_state=runtime_state, - ) - - human_data = HumanInputNodeData( - title="human", - form_content="Awaiting human input", - inputs=[], - user_actions=[ - UserAction(id="continue", title="Continue"), - ], - ) - human_node = HumanInputNode( - id="human", - config={"id": "human", "data": human_data.model_dump()}, - graph_init_params=params, - graph_runtime_state=runtime_state, - form_repository=form_repository, - ) - - end_data = EndNodeData( - title="end", - outputs=[], - desc=None, - ) - end_node = EndNode( - id="end", - config={"id": "end", "data": end_data.model_dump()}, - graph_init_params=params, - graph_runtime_state=runtime_state, - ) - - return ( - Graph.new() - .add_root(start_node) - .add_node(human_node) - .add_node(end_node, from_node_id="human", source_handle="continue") - .build() - ) - - -def _build_generate_entity( - tenant_id: str, - app_id: str, - workflow_id: str, - workflow_execution_id: str, - user_id: str, -) -> WorkflowAppGenerateEntity: - app_config = WorkflowUIBasedAppConfig( - tenant_id=tenant_id, - app_id=app_id, - app_mode=AppMode.WORKFLOW, - workflow_id=workflow_id, - ) - return WorkflowAppGenerateEntity( - task_id=str(uuid.uuid4()), - app_config=app_config, - inputs={}, - files=[], - user_id=user_id, - stream=False, - invoke_from=InvokeFrom.DEBUGGER, - workflow_execution_id=workflow_execution_id, - ) - - -class TestHumanInputResumeNodeExecutionIntegration: - @pytest.fixture(autouse=True) - def setup_test_data(self, db_session_with_containers: Session): - tenant = Tenant( - name="Test Tenant", - status="normal", - ) - db_session_with_containers.add(tenant) - db_session_with_containers.commit() - - account = Account( - email="test@example.com", - name="Test User", - interface_language="en-US", - status="active", - ) - db_session_with_containers.add(account) - db_session_with_containers.commit() - - tenant_join = TenantAccountJoin( - tenant_id=tenant.id, - account_id=account.id, - role=TenantAccountRole.OWNER, - current=True, - ) - db_session_with_containers.add(tenant_join) - db_session_with_containers.commit() - - account.current_tenant = tenant - - app = App( - tenant_id=tenant.id, - name="Test App", - description="", - mode=AppMode.WORKFLOW.value, - icon_type=IconType.EMOJI.value, - icon="rocket", - icon_background="#4ECDC4", - enable_site=False, - enable_api=False, - api_rpm=0, - api_rph=0, - is_demo=False, - is_public=False, - is_universal=False, - max_active_requests=None, - created_by=account.id, - updated_by=account.id, - ) - db_session_with_containers.add(app) - db_session_with_containers.commit() - - workflow = Workflow( - tenant_id=tenant.id, - app_id=app.id, - type="workflow", - version="draft", - graph='{"nodes": [], "edges": []}', - features='{"file_upload": {"enabled": false}}', - created_by=account.id, - created_at=naive_utc_now(), - ) - db_session_with_containers.add(workflow) - db_session_with_containers.commit() - - self.session = db_session_with_containers - self.tenant = tenant - self.account = account - self.app = app - self.workflow = workflow - - yield - - self.session.execute(delete(WorkflowNodeExecutionModel)) - self.session.execute(delete(WorkflowRun)) - self.session.execute(delete(Workflow).where(Workflow.id == self.workflow.id)) - self.session.execute(delete(App).where(App.id == self.app.id)) - self.session.execute(delete(TenantAccountJoin).where(TenantAccountJoin.tenant_id == self.tenant.id)) - self.session.execute(delete(Account).where(Account.id == self.account.id)) - self.session.execute(delete(Tenant).where(Tenant.id == self.tenant.id)) - self.session.commit() - - def _build_persistence_layer(self, execution_id: str) -> WorkflowPersistenceLayer: - generate_entity = _build_generate_entity( - tenant_id=self.tenant.id, - app_id=self.app.id, - workflow_id=self.workflow.id, - workflow_execution_id=execution_id, - user_id=self.account.id, - ) - execution_repo = SQLAlchemyWorkflowExecutionRepository( - session_factory=self.session.get_bind(), - user=self.account, - app_id=self.app.id, - triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, - ) - node_execution_repo = SQLAlchemyWorkflowNodeExecutionRepository( - session_factory=self.session.get_bind(), - user=self.account, - app_id=self.app.id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, - ) - return WorkflowPersistenceLayer( - application_generate_entity=generate_entity, - workflow_info=PersistenceWorkflowInfo( - workflow_id=self.workflow.id, - workflow_type=WorkflowType.WORKFLOW, - version=self.workflow.version, - graph_data=self.workflow.graph_dict, - ), - workflow_execution_repository=execution_repo, - workflow_node_execution_repository=node_execution_repo, - ) - - def _run_graph(self, graph: Graph, runtime_state: GraphRuntimeState, execution_id: str) -> None: - engine = GraphEngine( - workflow_id=self.workflow.id, - graph=graph, - graph_runtime_state=runtime_state, - command_channel=InMemoryChannel(), - ) - engine.layer(self._build_persistence_layer(execution_id)) - for _ in engine.run(): - continue - - def test_resume_human_input_does_not_create_duplicate_node_execution(self): - execution_id = str(uuid.uuid4()) - runtime_state = _build_runtime_state( - workflow_execution_id=execution_id, - app_id=self.app.id, - workflow_id=self.workflow.id, - user_id=self.account.id, - ) - pause_repo = _mock_form_repository_without_submission() - paused_graph = _build_graph( - runtime_state, - self.tenant.id, - self.app.id, - self.workflow.id, - self.account.id, - pause_repo, - ) - self._run_graph(paused_graph, runtime_state, execution_id) - - snapshot = runtime_state.dumps() - resumed_state = GraphRuntimeState.from_snapshot(snapshot) - resume_repo = _mock_form_repository_with_submission(action_id="continue") - resumed_graph = _build_graph( - resumed_state, - self.tenant.id, - self.app.id, - self.workflow.id, - self.account.id, - resume_repo, - ) - self._run_graph(resumed_graph, resumed_state, execution_id) - - stmt = select(WorkflowNodeExecutionModel).where( - WorkflowNodeExecutionModel.workflow_run_id == execution_id, - WorkflowNodeExecutionModel.node_id == "human", - ) - records = self.session.execute(stmt).scalars().all() - assert len(records) == 1 - assert records[0].status != "paused" - assert records[0].triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN - assert records[0].created_by_role == CreatorUserRole.ACCOUNT diff --git a/api/tests/test_containers_integration_tests/helpers/__init__.py b/api/tests/test_containers_integration_tests/helpers/__init__.py deleted file mode 100644 index 40d03889a9..0000000000 --- a/api/tests/test_containers_integration_tests/helpers/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Helper utilities for integration tests.""" diff --git a/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py b/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py deleted file mode 100644 index 19d7772c39..0000000000 --- a/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py +++ /dev/null @@ -1,154 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass -from datetime import datetime, timedelta -from decimal import Decimal -from uuid import uuid4 - -from core.workflow.nodes.human_input.entities import FormDefinition, UserAction -from models.account import Account, Tenant, TenantAccountJoin -from models.execution_extra_content import HumanInputContent -from models.human_input import HumanInputForm, HumanInputFormStatus -from models.model import App, Conversation, Message - - -@dataclass -class HumanInputMessageFixture: - app: App - account: Account - conversation: Conversation - message: Message - form: HumanInputForm - action_id: str - action_text: str - node_title: str - - -def create_human_input_message_fixture(db_session) -> HumanInputMessageFixture: - tenant = Tenant(name=f"Tenant {uuid4()}") - db_session.add(tenant) - db_session.flush() - - account = Account( - name=f"Account {uuid4()}", - email=f"human_input_{uuid4()}@example.com", - password="hashed-password", - password_salt="salt", - interface_language="en-US", - timezone="UTC", - ) - db_session.add(account) - db_session.flush() - - tenant_join = TenantAccountJoin( - tenant_id=tenant.id, - account_id=account.id, - role="owner", - current=True, - ) - db_session.add(tenant_join) - db_session.flush() - - app = App( - tenant_id=tenant.id, - name=f"App {uuid4()}", - description="", - mode="chat", - icon_type="emoji", - icon="🤖", - icon_background="#FFFFFF", - enable_site=False, - enable_api=True, - api_rpm=100, - api_rph=100, - is_demo=False, - is_public=False, - is_universal=False, - created_by=account.id, - updated_by=account.id, - ) - db_session.add(app) - db_session.flush() - - conversation = Conversation( - app_id=app.id, - mode="chat", - name="Test Conversation", - summary="", - introduction="", - system_instruction="", - status="normal", - invoke_from="console", - from_source="console", - from_account_id=account.id, - from_end_user_id=None, - ) - conversation.inputs = {} - db_session.add(conversation) - db_session.flush() - - workflow_run_id = str(uuid4()) - message = Message( - app_id=app.id, - conversation_id=conversation.id, - inputs={}, - query="Human input query", - message={"messages": []}, - answer="Human input answer", - message_tokens=50, - message_unit_price=Decimal("0.001"), - answer_tokens=80, - answer_unit_price=Decimal("0.001"), - provider_response_latency=0.5, - currency="USD", - from_source="console", - from_account_id=account.id, - workflow_run_id=workflow_run_id, - ) - db_session.add(message) - db_session.flush() - - action_id = "approve" - action_text = "Approve request" - node_title = "Approval" - form_definition = FormDefinition( - form_content="content", - inputs=[], - user_actions=[UserAction(id=action_id, title=action_text)], - rendered_content="Rendered block", - expiration_time=datetime.utcnow() + timedelta(days=1), - node_title=node_title, - display_in_ui=True, - ) - form = HumanInputForm( - tenant_id=tenant.id, - app_id=app.id, - workflow_run_id=workflow_run_id, - node_id="node-id", - form_definition=form_definition.model_dump_json(), - rendered_content="Rendered block", - status=HumanInputFormStatus.SUBMITTED, - expiration_time=datetime.utcnow() + timedelta(days=1), - selected_action_id=action_id, - ) - db_session.add(form) - db_session.flush() - - content = HumanInputContent( - workflow_run_id=workflow_run_id, - message_id=message.id, - form_id=form.id, - ) - db_session.add(content) - db_session.commit() - - return HumanInputMessageFixture( - app=app, - account=account, - conversation=conversation, - message=message, - form=form, - action_id=action_id, - action_text=action_text, - node_title=node_title, - ) diff --git a/api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_sharded_channel.py b/api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_sharded_channel.py index 43915a204d..d612e70910 100644 --- a/api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_sharded_channel.py +++ b/api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_sharded_channel.py @@ -16,7 +16,6 @@ from concurrent.futures import ThreadPoolExecutor, as_completed import pytest import redis -from redis.cluster import RedisCluster from testcontainers.redis import RedisContainer from libs.broadcast_channel.channel import BroadcastChannel, Subscription, Topic @@ -333,95 +332,3 @@ class TestShardedRedisBroadcastChannelIntegration: # Verify subscriptions are cleaned up topic_subscribers_after = self._get_sharded_numsub(redis_client, topic_name) assert topic_subscribers_after == 0 - - -class TestShardedRedisBroadcastChannelClusterIntegration: - """Integration tests for sharded pub/sub with RedisCluster client.""" - - @pytest.fixture(scope="class") - def redis_cluster_container(self) -> Iterator[RedisContainer]: - """Create a Redis 7 container with cluster mode enabled.""" - command = ( - "redis-server --port 6379 " - "--cluster-enabled yes " - "--cluster-config-file nodes.conf " - "--cluster-node-timeout 5000 " - "--appendonly no " - "--protected-mode no" - ) - with RedisContainer(image="redis:7-alpine").with_command(command) as container: - yield container - - @classmethod - def _get_test_topic_name(cls) -> str: - return f"test_sharded_cluster_topic_{uuid.uuid4()}" - - @staticmethod - def _ensure_single_node_cluster(host: str, port: int) -> None: - client = redis.Redis(host=host, port=port, decode_responses=False) - client.config_set("cluster-announce-ip", host) - client.config_set("cluster-announce-port", port) - slots = client.execute_command("CLUSTER", "SLOTS") - if not slots: - client.execute_command("CLUSTER", "ADDSLOTSRANGE", 0, 16383) - - deadline = time.time() + 5.0 - while time.time() < deadline: - info = client.execute_command("CLUSTER", "INFO") - info_text = info.decode("utf-8") if isinstance(info, (bytes, bytearray)) else str(info) - if "cluster_state:ok" in info_text: - return - time.sleep(0.05) - raise RuntimeError("Redis cluster did not become ready in time") - - @pytest.fixture(scope="class") - def redis_cluster_client(self, redis_cluster_container: RedisContainer) -> RedisCluster: - host = redis_cluster_container.get_container_host_ip() - port = int(redis_cluster_container.get_exposed_port(6379)) - self._ensure_single_node_cluster(host, port) - return RedisCluster(host=host, port=port, decode_responses=False) - - @pytest.fixture - def broadcast_channel(self, redis_cluster_client: RedisCluster) -> BroadcastChannel: - return ShardedRedisBroadcastChannel(redis_cluster_client) - - def test_cluster_sharded_pubsub_delivers_message(self, broadcast_channel: BroadcastChannel): - """Ensure sharded subscription receives messages when using RedisCluster client.""" - topic_name = self._get_test_topic_name() - message = b"cluster sharded message" - - topic = broadcast_channel.topic(topic_name) - producer = topic.as_producer() - subscription = topic.subscribe() - ready_event = threading.Event() - - def consumer_thread() -> list[bytes]: - received = [] - try: - _ = subscription.receive(0.01) - except SubscriptionClosedError: - return received - ready_event.set() - deadline = time.time() + 5.0 - while time.time() < deadline: - msg = subscription.receive(timeout=0.1) - if msg is None: - continue - received.append(msg) - break - subscription.close() - return received - - def producer_thread(): - if not ready_event.wait(timeout=2.0): - pytest.fail("subscriber did not become ready before publish") - producer.publish(message) - - with ThreadPoolExecutor(max_workers=2) as executor: - consumer_future = executor.submit(consumer_thread) - producer_future = executor.submit(producer_thread) - - producer_future.result(timeout=5.0) - received_messages = consumer_future.result(timeout=5.0) - - assert received_messages == [message] diff --git a/api/tests/test_containers_integration_tests/libs/test_rate_limiter_integration.py b/api/tests/test_containers_integration_tests/libs/test_rate_limiter_integration.py deleted file mode 100644 index 178fc2e4fb..0000000000 --- a/api/tests/test_containers_integration_tests/libs/test_rate_limiter_integration.py +++ /dev/null @@ -1,25 +0,0 @@ -""" -Integration tests for RateLimiter using testcontainers Redis. -""" - -import uuid - -import pytest - -from extensions.ext_redis import redis_client -from libs import helper as helper_module - - -@pytest.mark.usefixtures("flask_app_with_containers") -def test_rate_limiter_counts_multiple_attempts_in_same_second(monkeypatch): - prefix = f"test_rate_limit:{uuid.uuid4().hex}" - limiter = helper_module.RateLimiter(prefix=prefix, max_attempts=2, time_window=60) - key = limiter._get_key("203.0.113.10") - - redis_client.delete(key) - monkeypatch.setattr(helper_module.time, "time", lambda: 1_700_000_000) - - limiter.increment_rate_limit("203.0.113.10") - limiter.increment_rate_limit("203.0.113.10") - - assert limiter.is_rate_limited("203.0.113.10") is True diff --git a/api/tests/test_containers_integration_tests/models/test_account.py b/api/tests/test_containers_integration_tests/models/test_account.py deleted file mode 100644 index 078dc0e8de..0000000000 --- a/api/tests/test_containers_integration_tests/models/test_account.py +++ /dev/null @@ -1,79 +0,0 @@ -# import secrets - -# import pytest -# from sqlalchemy import select -# from sqlalchemy.orm import Session -# from sqlalchemy.orm.exc import DetachedInstanceError - -# from libs.datetime_utils import naive_utc_now -# from models.account import Account, Tenant, TenantAccountJoin - - -# @pytest.fixture -# def session(db_session_with_containers): -# with Session(db_session_with_containers.get_bind()) as session: -# yield session - - -# @pytest.fixture -# def account(session): -# account = Account( -# name="test account", -# email=f"test_{secrets.token_hex(8)}@example.com", -# ) -# session.add(account) -# session.commit() -# return account - - -# @pytest.fixture -# def tenant(session): -# tenant = Tenant(name="test tenant") -# session.add(tenant) -# session.commit() -# return tenant - - -# @pytest.fixture -# def tenant_account_join(session, account, tenant): -# tenant_join = TenantAccountJoin(account_id=account.id, tenant_id=tenant.id) -# session.add(tenant_join) -# session.commit() -# yield tenant_join -# session.delete(tenant_join) -# session.commit() - - -# class TestAccountTenant: -# def test_set_current_tenant_should_reload_tenant( -# self, -# db_session_with_containers, -# account, -# tenant, -# tenant_account_join, -# ): -# with Session(db_session_with_containers.get_bind(), expire_on_commit=True) as session: -# scoped_tenant = session.scalars(select(Tenant).where(Tenant.id == tenant.id)).one() -# account.current_tenant = scoped_tenant -# scoped_tenant.created_at = naive_utc_now() -# # session.commit() - -# # Ensure the tenant used in assignment is detached. -# with pytest.raises(DetachedInstanceError): -# _ = scoped_tenant.name - -# assert account._current_tenant.id == tenant.id -# assert account._current_tenant.id == tenant.id - -# def test_set_tenant_id_should_load_tenant_as_not_expire( -# self, -# flask_app_with_containers, -# account, -# tenant, -# tenant_account_join, -# ): -# with flask_app_with_containers.test_request_context(): -# account.set_tenant_id(tenant.id) - -# assert account._current_tenant.id == tenant.id -# assert account._current_tenant.id == tenant.id diff --git a/api/tests/test_containers_integration_tests/repositories/test_execution_extra_content_repository.py b/api/tests/test_containers_integration_tests/repositories/test_execution_extra_content_repository.py deleted file mode 100644 index c9058626d1..0000000000 --- a/api/tests/test_containers_integration_tests/repositories/test_execution_extra_content_repository.py +++ /dev/null @@ -1,27 +0,0 @@ -from __future__ import annotations - -from sqlalchemy.orm import sessionmaker - -from extensions.ext_database import db -from repositories.sqlalchemy_execution_extra_content_repository import SQLAlchemyExecutionExtraContentRepository -from tests.test_containers_integration_tests.helpers.execution_extra_content import ( - create_human_input_message_fixture, -) - - -def test_get_by_message_ids_returns_human_input_content(db_session_with_containers): - fixture = create_human_input_message_fixture(db_session_with_containers) - repository = SQLAlchemyExecutionExtraContentRepository( - session_maker=sessionmaker(bind=db.engine, expire_on_commit=False) - ) - - results = repository.get_by_message_ids([fixture.message.id]) - - assert len(results) == 1 - assert len(results[0]) == 1 - content = results[0][0] - assert content.submitted is True - assert content.form_submission_data is not None - assert content.form_submission_data.action_id == fixture.action_id - assert content.form_submission_data.action_text == fixture.action_text - assert content.form_submission_data.rendered_content == fixture.form.rendered_content diff --git a/api/tests/test_containers_integration_tests/services/test_account_service.py b/api/tests/test_containers_integration_tests/services/test_account_service.py index 4b6b5048a1..4d4e77a802 100644 --- a/api/tests/test_containers_integration_tests/services/test_account_service.py +++ b/api/tests/test_containers_integration_tests/services/test_account_service.py @@ -2293,12 +2293,6 @@ class TestRegisterService: mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False - from extensions.ext_database import db - from models.model import DifySetup - - db.session.query(DifySetup).delete() - db.session.commit() - # Execute setup RegisterService.setup( email=admin_email, @@ -2309,7 +2303,9 @@ class TestRegisterService: ) # Verify account was created + from extensions.ext_database import db from models import Account + from models.model import DifySetup account = db.session.query(Account).filter_by(email=admin_email).first() assert account is not None diff --git a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py index 81bfa0ea20..476f58585d 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py @@ -1,5 +1,5 @@ import uuid -from unittest.mock import ANY, MagicMock, patch +from unittest.mock import MagicMock, patch import pytest from faker import Faker @@ -26,7 +26,6 @@ class TestAppGenerateService: patch("services.app_generate_service.AgentChatAppGenerator") as mock_agent_chat_generator, patch("services.app_generate_service.AdvancedChatAppGenerator") as mock_advanced_chat_generator, patch("services.app_generate_service.WorkflowAppGenerator") as mock_workflow_generator, - patch("services.app_generate_service.MessageBasedAppGenerator") as mock_message_based_generator, patch("services.account_service.FeatureService") as mock_account_feature_service, patch("services.app_generate_service.dify_config") as mock_dify_config, patch("configs.dify_config") as mock_global_dify_config, @@ -39,13 +38,9 @@ class TestAppGenerateService: # Setup default mock returns for workflow service mock_workflow_service_instance = mock_workflow_service.return_value - mock_published_workflow = MagicMock(spec=Workflow) - mock_published_workflow.id = str(uuid.uuid4()) - mock_workflow_service_instance.get_published_workflow.return_value = mock_published_workflow - mock_draft_workflow = MagicMock(spec=Workflow) - mock_draft_workflow.id = str(uuid.uuid4()) - mock_workflow_service_instance.get_draft_workflow.return_value = mock_draft_workflow - mock_workflow_service_instance.get_published_workflow_by_id.return_value = mock_published_workflow + mock_workflow_service_instance.get_published_workflow.return_value = MagicMock(spec=Workflow) + mock_workflow_service_instance.get_draft_workflow.return_value = MagicMock(spec=Workflow) + mock_workflow_service_instance.get_published_workflow_by_id.return_value = MagicMock(spec=Workflow) # Setup default mock returns for rate limiting mock_rate_limit_instance = mock_rate_limit.return_value @@ -71,8 +66,6 @@ class TestAppGenerateService: mock_advanced_chat_generator_instance.generate.return_value = ["advanced_chat_response"] mock_advanced_chat_generator_instance.single_iteration_generate.return_value = ["single_iteration_response"] mock_advanced_chat_generator_instance.single_loop_generate.return_value = ["single_loop_response"] - mock_advanced_chat_generator_instance.retrieve_events.return_value = ["advanced_chat_events"] - mock_advanced_chat_generator_instance.convert_to_event_stream.return_value = ["advanced_chat_stream"] mock_advanced_chat_generator.convert_to_event_stream.return_value = ["advanced_chat_stream"] mock_workflow_generator_instance = mock_workflow_generator.return_value @@ -83,8 +76,6 @@ class TestAppGenerateService: mock_workflow_generator_instance.single_loop_generate.return_value = ["workflow_single_loop_response"] mock_workflow_generator.convert_to_event_stream.return_value = ["workflow_stream"] - mock_message_based_generator.retrieve_events.return_value = ["workflow_events"] - # Setup default mock returns for account service mock_account_feature_service.get_system_features.return_value.is_allow_register = True @@ -97,7 +88,6 @@ class TestAppGenerateService: mock_global_dify_config.BILLING_ENABLED = False mock_global_dify_config.APP_MAX_ACTIVE_REQUESTS = 100 mock_global_dify_config.APP_DAILY_RATE_LIMIT = 1000 - mock_global_dify_config.HOSTED_POOL_CREDITS = 1000 yield { "billing_service": mock_billing_service, @@ -108,7 +98,6 @@ class TestAppGenerateService: "agent_chat_generator": mock_agent_chat_generator, "advanced_chat_generator": mock_advanced_chat_generator, "workflow_generator": mock_workflow_generator, - "message_based_generator": mock_message_based_generator, "account_feature_service": mock_account_feature_service, "dify_config": mock_dify_config, "global_dify_config": mock_global_dify_config, @@ -291,10 +280,8 @@ class TestAppGenerateService: assert result == ["test_response"] # Verify advanced chat generator was called - mock_external_service_dependencies["advanced_chat_generator"].return_value.retrieve_events.assert_called_once() - mock_external_service_dependencies[ - "advanced_chat_generator" - ].return_value.convert_to_event_stream.assert_called_once() + mock_external_service_dependencies["advanced_chat_generator"].return_value.generate.assert_called_once() + mock_external_service_dependencies["advanced_chat_generator"].convert_to_event_stream.assert_called_once() def test_generate_workflow_mode_success(self, db_session_with_containers, mock_external_service_dependencies): """ @@ -317,7 +304,7 @@ class TestAppGenerateService: assert result == ["test_response"] # Verify workflow generator was called - mock_external_service_dependencies["message_based_generator"].retrieve_events.assert_called_once() + mock_external_service_dependencies["workflow_generator"].return_value.generate.assert_called_once() mock_external_service_dependencies["workflow_generator"].convert_to_event_stream.assert_called_once() def test_generate_with_specific_workflow_id(self, db_session_with_containers, mock_external_service_dependencies): @@ -983,27 +970,14 @@ class TestAppGenerateService: } # Execute the method under test - with patch("services.app_generate_service.AppExecutionParams") as mock_exec_params: - mock_payload = MagicMock() - mock_payload.workflow_run_id = fake.uuid4() - mock_payload.model_dump_json.return_value = "{}" - mock_exec_params.new.return_value = mock_payload - - result = AppGenerateService.generate( - app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True - ) + result = AppGenerateService.generate( + app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True + ) # Verify the result assert result == ["test_response"] - # Verify payload was built with complex args - mock_exec_params.new.assert_called_once() - call_kwargs = mock_exec_params.new.call_args.kwargs - assert call_kwargs["args"] == args - - # Verify workflow streaming event retrieval was used - mock_external_service_dependencies["message_based_generator"].retrieve_events.assert_called_once_with( - ANY, - mock_payload.workflow_run_id, - on_subscribe=ANY, - ) + # Verify workflow generator was called with complex args + mock_external_service_dependencies["workflow_generator"].return_value.generate.assert_called_once() + call_args = mock_external_service_dependencies["workflow_generator"].return_value.generate.call_args + assert call_args[1]["args"] == args diff --git a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py deleted file mode 100644 index 9c978f830f..0000000000 --- a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py +++ /dev/null @@ -1,112 +0,0 @@ -import json -import uuid -from unittest.mock import MagicMock - -import pytest - -from core.workflow.enums import NodeType -from core.workflow.nodes.human_input.entities import ( - EmailDeliveryConfig, - EmailDeliveryMethod, - EmailRecipients, - ExternalRecipient, - HumanInputNodeData, -) -from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole -from models.model import App, AppMode -from models.workflow import Workflow, WorkflowType -from services.workflow_service import WorkflowService - - -def _create_app_with_draft_workflow(session, *, delivery_method_id: uuid.UUID) -> tuple[App, Account]: - tenant = Tenant(name="Test Tenant") - account = Account(name="Tester", email="tester@example.com") - session.add_all([tenant, account]) - session.flush() - - session.add( - TenantAccountJoin( - tenant_id=tenant.id, - account_id=account.id, - current=True, - role=TenantAccountRole.OWNER.value, - ) - ) - - app = App( - tenant_id=tenant.id, - name="Test App", - description="", - mode=AppMode.WORKFLOW.value, - icon_type="emoji", - icon="app", - icon_background="#ffffff", - enable_site=True, - enable_api=True, - created_by=account.id, - updated_by=account.id, - ) - session.add(app) - session.flush() - - email_method = EmailDeliveryMethod( - id=delivery_method_id, - enabled=True, - config=EmailDeliveryConfig( - recipients=EmailRecipients( - whole_workspace=False, - items=[ExternalRecipient(email="recipient@example.com")], - ), - subject="Test {{recipient_email}}", - body="Body {{#url#}} {{form_content}}", - ), - ) - node_data = HumanInputNodeData( - title="Human Input", - delivery_methods=[email_method], - form_content="Hello Human Input", - inputs=[], - user_actions=[], - ).model_dump(mode="json") - node_data["type"] = NodeType.HUMAN_INPUT.value - graph = json.dumps({"nodes": [{"id": "human-node", "data": node_data}], "edges": []}) - - workflow = Workflow.new( - tenant_id=tenant.id, - app_id=app.id, - type=WorkflowType.WORKFLOW.value, - version=Workflow.VERSION_DRAFT, - graph=graph, - features=json.dumps({}), - created_by=account.id, - environment_variables=[], - conversation_variables=[], - rag_pipeline_variables=[], - ) - session.add(workflow) - session.commit() - - return app, account - - -def test_human_input_delivery_test_sends_email( - db_session_with_containers, - monkeypatch: pytest.MonkeyPatch, -) -> None: - delivery_method_id = uuid.uuid4() - app, account = _create_app_with_draft_workflow(db_session_with_containers, delivery_method_id=delivery_method_id) - - send_mock = MagicMock() - monkeypatch.setattr("services.human_input_delivery_test_service.mail.is_inited", lambda: True) - monkeypatch.setattr("services.human_input_delivery_test_service.mail.send", send_mock) - - service = WorkflowService() - service.test_human_input_delivery( - app_model=app, - account=account, - node_id="human-node", - delivery_method_id=str(delivery_method_id), - ) - - assert send_mock.call_count == 1 - assert send_mock.call_args.kwargs["to"] == "recipient@example.com" diff --git a/api/tests/test_containers_integration_tests/services/test_message_service_execution_extra_content.py b/api/tests/test_containers_integration_tests/services/test_message_service_execution_extra_content.py deleted file mode 100644 index 44e5a82868..0000000000 --- a/api/tests/test_containers_integration_tests/services/test_message_service_execution_extra_content.py +++ /dev/null @@ -1,38 +0,0 @@ -from __future__ import annotations - -import pytest - -from services.message_service import MessageService -from tests.test_containers_integration_tests.helpers.execution_extra_content import ( - create_human_input_message_fixture, -) - - -@pytest.mark.usefixtures("flask_req_ctx_with_containers") -def test_pagination_returns_extra_contents(db_session_with_containers): - fixture = create_human_input_message_fixture(db_session_with_containers) - - pagination = MessageService.pagination_by_first_id( - app_model=fixture.app, - user=fixture.account, - conversation_id=fixture.conversation.id, - first_id=None, - limit=10, - ) - - assert pagination.data - message = pagination.data[0] - assert message.extra_contents == [ - { - "type": "human_input", - "workflow_run_id": fixture.message.workflow_run_id, - "submitted": True, - "form_submission_data": { - "node_id": fixture.form.node_id, - "node_title": fixture.node_title, - "rendered_content": fixture.form.rendered_content, - "action_id": fixture.action_id, - "action_text": fixture.action_text, - }, - } - ] diff --git a/api/tests/test_containers_integration_tests/services/test_webhook_service.py b/api/tests/test_containers_integration_tests/services/test_webhook_service.py index e3431fd382..934d1bdd34 100644 --- a/api/tests/test_containers_integration_tests/services/test_webhook_service.py +++ b/api/tests/test_containers_integration_tests/services/test_webhook_service.py @@ -90,6 +90,7 @@ class TestWebhookService: "id": "webhook_node", "type": "webhook", "data": { + "type": "trigger-webhook", "title": "Test Webhook", "method": "post", "content_type": "application/json", diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py index 3a88081db3..23c4eeb82f 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py @@ -465,27 +465,6 @@ class TestWorkflowRunService: db.session.add(node_execution) node_executions.append(node_execution) - paused_node_execution = WorkflowNodeExecutionModel( - tenant_id=app.tenant_id, - app_id=app.id, - workflow_id=workflow_run.workflow_id, - triggered_from="workflow-run", - workflow_run_id=workflow_run.id, - index=99, - node_id="node_paused", - node_type="human_input", - title="Paused Node", - inputs=json.dumps({"input": "paused"}), - process_data=json.dumps({"process": "paused"}), - status="paused", - elapsed_time=0.5, - execution_metadata=json.dumps({"tokens": 0}), - created_by_role=CreatorUserRole.ACCOUNT, - created_by=account.id, - created_at=datetime.now(UTC), - ) - db.session.add(paused_node_execution) - db.session.commit() # Act: Execute the method under test @@ -494,19 +473,16 @@ class TestWorkflowRunService: # Assert: Verify the expected outcomes assert result is not None - assert len(result) == 4 + assert len(result) == 3 # Verify node execution properties - statuses = [node_execution.status for node_execution in result] - assert "paused" in statuses - assert statuses.count("succeeded") == 3 - assert statuses.count("paused") == 1 - for node_execution in result: assert node_execution.tenant_id == app.tenant_id assert node_execution.app_id == app.id assert node_execution.workflow_run_id == workflow_run.id - assert node_execution.node_id.startswith("node_") + assert node_execution.index in [0, 1, 2] # Check that index is one of the expected values + assert node_execution.node_id.startswith("node_") # Check that node_id starts with "node_" + assert node_execution.status == "succeeded" def test_get_workflow_run_node_executions_empty( self, db_session_with_containers, mock_external_service_dependencies diff --git a/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py index acd9d78c91..3c0a660e7c 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py @@ -3,8 +3,9 @@ from unittest.mock import patch import pytest from faker import Faker +from pydantic import ValidationError -from core.tools.errors import WorkflowToolHumanInputNotSupportedError +from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration from models.tools import WorkflowToolProvider from models.workflow import Workflow as WorkflowModel from services.account_service import AccountService, TenantService @@ -131,20 +132,24 @@ class TestWorkflowToolManageService: def _create_test_workflow_tool_parameters(self): """Helper method to create valid workflow tool parameters.""" return [ - { - "name": "input_text", - "description": "Input text for processing", - "form": "form", - "type": "string", - "required": True, - }, - { - "name": "output_format", - "description": "Output format specification", - "form": "form", - "type": "select", - "required": False, - }, + WorkflowToolParameterConfiguration.model_validate( + { + "name": "input_text", + "description": "Input text for processing", + "form": "form", + "type": "string", + "required": True, + } + ), + WorkflowToolParameterConfiguration.model_validate( + { + "name": "output_format", + "description": "Output format specification", + "form": "form", + "type": "select", + "required": False, + } + ), ] def test_create_workflow_tool_success(self, db_session_with_containers, mock_external_service_dependencies): @@ -209,7 +214,7 @@ class TestWorkflowToolManageService: assert created_tool_provider.label == tool_label assert created_tool_provider.icon == json.dumps(tool_icon) assert created_tool_provider.description == tool_description - assert created_tool_provider.parameter_configuration == json.dumps(tool_parameters) + assert created_tool_provider.parameter_configuration == json.dumps([p.model_dump() for p in tool_parameters]) assert created_tool_provider.privacy_policy == tool_privacy_policy assert created_tool_provider.version == workflow.version assert created_tool_provider.user_id == account.id @@ -354,18 +359,9 @@ class TestWorkflowToolManageService: app, account, workflow = self._create_test_app_and_account( db_session_with_containers, mock_external_service_dependencies ) - - # Setup invalid workflow tool parameters (missing required fields) - invalid_parameters = [ - { - "name": "input_text", - # Missing description and form fields - "type": "string", - "required": True, - } - ] # Attempt to create workflow tool with invalid parameters - with pytest.raises(ValueError) as exc_info: + with pytest.raises(ValidationError) as exc_info: + # Setup invalid workflow tool parameters (missing required fields) WorkflowToolManageService.create_workflow_tool( user_id=account.id, tenant_id=account.current_tenant.id, @@ -374,7 +370,16 @@ class TestWorkflowToolManageService: label=fake.word(), icon={"type": "emoji", "emoji": "🔧"}, description=fake.text(max_nb_chars=200), - parameters=invalid_parameters, + parameters=[ + WorkflowToolParameterConfiguration.model_validate( + { + "name": "input_text", + # Missing description and form fields + "type": "string", + "required": True, + } + ) + ], ) # Verify error message contains validation error @@ -508,62 +513,6 @@ class TestWorkflowToolManageService: assert tool_count == 0 - def test_create_workflow_tool_human_input_node_error( - self, db_session_with_containers, mock_external_service_dependencies - ): - """ - Test workflow tool creation fails when workflow contains human input nodes. - - This test verifies: - - Human input nodes prevent workflow tool publishing - - Correct error message - - No database changes when workflow is invalid - """ - fake = Faker() - - # Create test data - app, account, workflow = self._create_test_app_and_account( - db_session_with_containers, mock_external_service_dependencies - ) - - workflow.graph = json.dumps( - { - "nodes": [ - { - "id": "human_input_node", - "data": {"type": "human-input"}, - } - ] - } - ) - - tool_parameters = self._create_test_workflow_tool_parameters() - with pytest.raises(WorkflowToolHumanInputNotSupportedError) as exc_info: - WorkflowToolManageService.create_workflow_tool( - user_id=account.id, - tenant_id=account.current_tenant.id, - workflow_app_id=app.id, - name=fake.word(), - label=fake.word(), - icon={"type": "emoji", "emoji": "🔧"}, - description=fake.text(max_nb_chars=200), - parameters=tool_parameters, - ) - - assert exc_info.value.error_code == "workflow_tool_human_input_not_supported" - - from extensions.ext_database import db - - tool_count = ( - db.session.query(WorkflowToolProvider) - .where( - WorkflowToolProvider.tenant_id == account.current_tenant.id, - ) - .count() - ) - - assert tool_count == 0 - def test_update_workflow_tool_success(self, db_session_with_containers, mock_external_service_dependencies): """ Test successful workflow tool update with valid parameters. @@ -636,11 +585,12 @@ class TestWorkflowToolManageService: # Verify database state was updated db.session.refresh(created_tool) + assert created_tool is not None assert created_tool.name == updated_tool_name assert created_tool.label == updated_tool_label assert created_tool.icon == json.dumps(updated_tool_icon) assert created_tool.description == updated_tool_description - assert created_tool.parameter_configuration == json.dumps(updated_tool_parameters) + assert created_tool.parameter_configuration == json.dumps([p.model_dump() for p in updated_tool_parameters]) assert created_tool.privacy_policy == updated_tool_privacy_policy assert created_tool.version == workflow.version assert created_tool.updated_at is not None @@ -650,80 +600,6 @@ class TestWorkflowToolManageService: mock_external_service_dependencies["tool_label_manager"].update_tool_labels.assert_called() mock_external_service_dependencies["tool_transform_service"].workflow_provider_to_controller.assert_called() - def test_update_workflow_tool_human_input_node_error( - self, db_session_with_containers, mock_external_service_dependencies - ): - """ - Test workflow tool update fails when workflow contains human input nodes. - - This test verifies: - - Human input nodes prevent workflow tool updates - - Correct error message - - Existing tool data remains unchanged - """ - fake = Faker() - - # Create test data - app, account, workflow = self._create_test_app_and_account( - db_session_with_containers, mock_external_service_dependencies - ) - - # Create initial workflow tool - initial_tool_name = fake.word() - initial_tool_parameters = self._create_test_workflow_tool_parameters() - WorkflowToolManageService.create_workflow_tool( - user_id=account.id, - tenant_id=account.current_tenant.id, - workflow_app_id=app.id, - name=initial_tool_name, - label=fake.word(), - icon={"type": "emoji", "emoji": "🔧"}, - description=fake.text(max_nb_chars=200), - parameters=initial_tool_parameters, - ) - - from extensions.ext_database import db - - created_tool = ( - db.session.query(WorkflowToolProvider) - .where( - WorkflowToolProvider.tenant_id == account.current_tenant.id, - WorkflowToolProvider.app_id == app.id, - ) - .first() - ) - - original_name = created_tool.name - - workflow.graph = json.dumps( - { - "nodes": [ - { - "id": "human_input_node", - "data": {"type": "human-input"}, - } - ] - } - ) - db.session.commit() - - with pytest.raises(WorkflowToolHumanInputNotSupportedError) as exc_info: - WorkflowToolManageService.update_workflow_tool( - user_id=account.id, - tenant_id=account.current_tenant.id, - workflow_tool_id=created_tool.id, - name=fake.word(), - label=fake.word(), - icon={"type": "emoji", "emoji": "⚙️"}, - description=fake.text(max_nb_chars=200), - parameters=initial_tool_parameters, - ) - - assert exc_info.value.error_code == "workflow_tool_human_input_not_supported" - - db.session.refresh(created_tool) - assert created_tool.name == original_name - def test_update_workflow_tool_not_found_error(self, db_session_with_containers, mock_external_service_dependencies): """ Test workflow tool update fails when tool does not exist. @@ -881,13 +757,15 @@ class TestWorkflowToolManageService: # Setup workflow tool parameters with FILE type file_parameters = [ - { - "name": "document", - "description": "Upload a document", - "form": "form", - "type": "file", - "required": False, - } + WorkflowToolParameterConfiguration.model_validate( + { + "name": "document", + "description": "Upload a document", + "form": "form", + "type": "file", + "required": False, + } + ) ] # Execute the method under test @@ -954,13 +832,15 @@ class TestWorkflowToolManageService: # Setup workflow tool parameters with FILES type files_parameters = [ - { - "name": "documents", - "description": "Upload multiple documents", - "form": "form", - "type": "files", - "required": False, - } + WorkflowToolParameterConfiguration.model_validate( + { + "name": "documents", + "description": "Upload multiple documents", + "form": "form", + "type": "files", + "required": False, + } + ) ] # Execute the method under test diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py deleted file mode 100644 index 5fd6c56f7a..0000000000 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py +++ /dev/null @@ -1,214 +0,0 @@ -import uuid -from datetime import UTC, datetime -from unittest.mock import patch - -import pytest - -from configs import dify_config -from core.app.app_config.entities import WorkflowUIBasedAppConfig -from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity -from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext -from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl -from core.workflow.enums import WorkflowExecutionStatus -from core.workflow.nodes.human_input.entities import ( - EmailDeliveryConfig, - EmailDeliveryMethod, - EmailRecipients, - ExternalRecipient, - HumanInputNodeData, - MemberRecipient, -) -from core.workflow.runtime import GraphRuntimeState, VariablePool -from extensions.ext_storage import storage -from models.account import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole -from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom -from models.human_input import HumanInputDelivery, HumanInputForm, HumanInputFormRecipient -from models.model import AppMode -from models.workflow import WorkflowPause, WorkflowRun, WorkflowType -from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task - - -@pytest.fixture(autouse=True) -def cleanup_database(db_session_with_containers): - db_session_with_containers.query(HumanInputFormRecipient).delete() - db_session_with_containers.query(HumanInputDelivery).delete() - db_session_with_containers.query(HumanInputForm).delete() - db_session_with_containers.query(WorkflowPause).delete() - db_session_with_containers.query(WorkflowRun).delete() - db_session_with_containers.query(TenantAccountJoin).delete() - db_session_with_containers.query(Tenant).delete() - db_session_with_containers.query(Account).delete() - db_session_with_containers.commit() - - -def _create_workspace_member(db_session_with_containers): - account = Account( - email="owner@example.com", - name="Owner", - password="password", - interface_language="en-US", - status=AccountStatus.ACTIVE, - ) - account.created_at = datetime.now(UTC) - account.updated_at = datetime.now(UTC) - db_session_with_containers.add(account) - db_session_with_containers.commit() - db_session_with_containers.refresh(account) - - tenant = Tenant(name="Test Tenant") - tenant.created_at = datetime.now(UTC) - tenant.updated_at = datetime.now(UTC) - db_session_with_containers.add(tenant) - db_session_with_containers.commit() - db_session_with_containers.refresh(tenant) - - tenant_join = TenantAccountJoin( - tenant_id=tenant.id, - account_id=account.id, - role=TenantAccountRole.OWNER, - ) - tenant_join.created_at = datetime.now(UTC) - tenant_join.updated_at = datetime.now(UTC) - db_session_with_containers.add(tenant_join) - db_session_with_containers.commit() - - return tenant, account - - -def _build_form(db_session_with_containers, tenant, account, *, app_id: str, workflow_execution_id: str): - delivery_method = EmailDeliveryMethod( - config=EmailDeliveryConfig( - recipients=EmailRecipients( - whole_workspace=False, - items=[ - MemberRecipient(user_id=account.id), - ExternalRecipient(email="external@example.com"), - ], - ), - subject="Action needed {{ node_title }} {{#node1.value#}}", - body="Token {{ form_token }} link {{#url#}} content {{#node1.value#}}", - ) - ) - - node_data = HumanInputNodeData( - title="Review", - form_content="Form content", - delivery_methods=[delivery_method], - ) - - engine = db_session_with_containers.get_bind() - repo = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id) - params = FormCreateParams( - app_id=app_id, - workflow_execution_id=workflow_execution_id, - node_id="node-1", - form_config=node_data, - rendered_content="Rendered", - delivery_methods=node_data.delivery_methods, - display_in_ui=False, - resolved_default_values={}, - ) - return repo.create_form(params) - - -def _create_workflow_pause_state( - db_session_with_containers, - *, - workflow_run_id: str, - workflow_id: str, - tenant_id: str, - app_id: str, - account_id: str, - variable_pool: VariablePool, -): - workflow_run = WorkflowRun( - id=workflow_run_id, - tenant_id=tenant_id, - app_id=app_id, - workflow_id=workflow_id, - type=WorkflowType.WORKFLOW, - triggered_from=WorkflowRunTriggeredFrom.APP_RUN, - version="1", - graph="{}", - inputs="{}", - status=WorkflowExecutionStatus.PAUSED, - created_by_role=CreatorUserRole.ACCOUNT, - created_by=account_id, - created_at=datetime.now(UTC), - ) - db_session_with_containers.add(workflow_run) - - runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) - resumption_context = WorkflowResumptionContext( - generate_entity={ - "type": AppMode.WORKFLOW, - "entity": WorkflowAppGenerateEntity( - task_id=str(uuid.uuid4()), - app_config=WorkflowUIBasedAppConfig( - tenant_id=tenant_id, - app_id=app_id, - app_mode=AppMode.WORKFLOW, - workflow_id=workflow_id, - ), - inputs={}, - files=[], - user_id=account_id, - stream=False, - invoke_from=InvokeFrom.WEB_APP, - workflow_execution_id=workflow_run_id, - ), - }, - serialized_graph_runtime_state=runtime_state.dumps(), - ) - - state_object_key = f"workflow_pause_states/{workflow_run_id}.json" - storage.save(state_object_key, resumption_context.dumps().encode()) - - pause_state = WorkflowPause( - workflow_id=workflow_id, - workflow_run_id=workflow_run_id, - state_object_key=state_object_key, - ) - db_session_with_containers.add(pause_state) - db_session_with_containers.commit() - - -def test_dispatch_human_input_email_task_integration(monkeypatch: pytest.MonkeyPatch, db_session_with_containers): - tenant, account = _create_workspace_member(db_session_with_containers) - workflow_run_id = str(uuid.uuid4()) - workflow_id = str(uuid.uuid4()) - app_id = str(uuid.uuid4()) - variable_pool = VariablePool() - variable_pool.add(["node1", "value"], "OK") - _create_workflow_pause_state( - db_session_with_containers, - workflow_run_id=workflow_run_id, - workflow_id=workflow_id, - tenant_id=tenant.id, - app_id=app_id, - account_id=account.id, - variable_pool=variable_pool, - ) - form_entity = _build_form( - db_session_with_containers, - tenant, - account, - app_id=app_id, - workflow_execution_id=workflow_run_id, - ) - - monkeypatch.setattr(dify_config, "APP_WEB_URL", "https://app.example.com") - - with patch("tasks.mail_human_input_delivery_task.mail") as mock_mail: - mock_mail.is_inited.return_value = True - - dispatch_human_input_email_task(form_id=form_entity.id, node_title="Approval") - - assert mock_mail.send.call_count == 2 - send_args = [call.kwargs for call in mock_mail.send.call_args_list] - recipients = {kwargs["to"] for kwargs in send_args} - assert recipients == {"owner@example.com", "external@example.com"} - assert all(kwargs["subject"] == "Action needed {{ node_title }} {{#node1.value#}}" for kwargs in send_args) - assert all("app.example.com/form/" in kwargs["html"] for kwargs in send_args) - assert all("content OK" in kwargs["html"] for kwargs in send_args) - assert all("{{ form_token }}" in kwargs["html"] for kwargs in send_args) diff --git a/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py index 5f4f28cf4f..889e3d1d83 100644 --- a/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py +++ b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py @@ -94,6 +94,11 @@ class PrunePausesTestCase: def pause_workflow_failure_cases() -> list[PauseWorkflowFailureCase]: """Create test cases for pause workflow failure scenarios.""" return [ + PauseWorkflowFailureCase( + name="pause_already_paused_workflow", + initial_status=WorkflowExecutionStatus.PAUSED, + description="Should fail to pause an already paused workflow", + ), PauseWorkflowFailureCase( name="pause_completed_workflow", initial_status=WorkflowExecutionStatus.SUCCEEDED, diff --git a/api/tests/unit_tests/configs/test_dify_config.py b/api/tests/unit_tests/configs/test_dify_config.py index cf52980e57..6fce7849f9 100644 --- a/api/tests/unit_tests/configs/test_dify_config.py +++ b/api/tests/unit_tests/configs/test_dify_config.py @@ -164,62 +164,6 @@ def test_db_extras_options_merging(monkeypatch: pytest.MonkeyPatch): assert "timezone=UTC" in options -def test_pubsub_redis_url_default(monkeypatch: pytest.MonkeyPatch): - os.environ.clear() - - monkeypatch.setenv("CONSOLE_API_URL", "https://example.com") - monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com") - monkeypatch.setenv("DB_USERNAME", "postgres") - monkeypatch.setenv("DB_PASSWORD", "postgres") - monkeypatch.setenv("DB_HOST", "localhost") - monkeypatch.setenv("DB_PORT", "5432") - monkeypatch.setenv("DB_DATABASE", "dify") - monkeypatch.setenv("REDIS_HOST", "redis.example.com") - monkeypatch.setenv("REDIS_PORT", "6380") - monkeypatch.setenv("REDIS_USERNAME", "user") - monkeypatch.setenv("REDIS_PASSWORD", "pass@word") - monkeypatch.setenv("REDIS_DB", "2") - monkeypatch.setenv("REDIS_USE_SSL", "true") - - config = DifyConfig() - - assert config.normalized_pubsub_redis_url == "rediss://user:pass%40word@redis.example.com:6380/2" - assert config.PUBSUB_REDIS_CHANNEL_TYPE == "pubsub" - - -def test_pubsub_redis_url_override(monkeypatch: pytest.MonkeyPatch): - os.environ.clear() - - monkeypatch.setenv("CONSOLE_API_URL", "https://example.com") - monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com") - monkeypatch.setenv("DB_USERNAME", "postgres") - monkeypatch.setenv("DB_PASSWORD", "postgres") - monkeypatch.setenv("DB_HOST", "localhost") - monkeypatch.setenv("DB_PORT", "5432") - monkeypatch.setenv("DB_DATABASE", "dify") - monkeypatch.setenv("PUBSUB_REDIS_URL", "redis://pubsub-host:6381/5") - - config = DifyConfig() - - assert config.normalized_pubsub_redis_url == "redis://pubsub-host:6381/5" - - -def test_pubsub_redis_url_required_when_default_unavailable(monkeypatch: pytest.MonkeyPatch): - os.environ.clear() - - monkeypatch.setenv("CONSOLE_API_URL", "https://example.com") - monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com") - monkeypatch.setenv("DB_USERNAME", "postgres") - monkeypatch.setenv("DB_PASSWORD", "postgres") - monkeypatch.setenv("DB_HOST", "localhost") - monkeypatch.setenv("DB_PORT", "5432") - monkeypatch.setenv("DB_DATABASE", "dify") - monkeypatch.setenv("REDIS_HOST", "") - - with pytest.raises(ValueError, match="PUBSUB_REDIS_URL must be set"): - _ = DifyConfig().normalized_pubsub_redis_url - - @pytest.mark.parametrize( ("broker_url", "expected_host", "expected_port", "expected_username", "expected_password", "expected_db"), [ diff --git a/api/tests/unit_tests/conftest.py b/api/tests/unit_tests/conftest.py index da957d3a81..e3c1a617f7 100644 --- a/api/tests/unit_tests/conftest.py +++ b/api/tests/unit_tests/conftest.py @@ -51,8 +51,6 @@ def _patch_redis_clients_on_loaded_modules(): continue if hasattr(module, "redis_client"): module.redis_client = redis_mock - if hasattr(module, "pubsub_redis_client"): - module.pubsub_redis_client = redis_mock @pytest.fixture @@ -70,10 +68,7 @@ def _provide_app_context(app: Flask): def _patch_redis_clients(): """Patch redis_client to MagicMock only for unit test executions.""" - with ( - patch.object(ext_redis, "redis_client", redis_mock), - patch.object(ext_redis, "pubsub_redis_client", redis_mock), - ): + with patch.object(ext_redis, "redis_client", redis_mock): _patch_redis_clients_on_loaded_modules() yield diff --git a/api/tests/unit_tests/controllers/console/app/test_app_response_models.py b/api/tests/unit_tests/controllers/console/app/test_app_response_models.py index 2ac3dc037d..c557605916 100644 --- a/api/tests/unit_tests/controllers/console/app/test_app_response_models.py +++ b/api/tests/unit_tests/controllers/console/app/test_app_response_models.py @@ -16,9 +16,11 @@ if not hasattr(builtins, "MethodView"): builtins.MethodView = MethodView # type: ignore[attr-defined] -@pytest.fixture(scope="module") -def app_module(): +def _load_app_module(): module_name = "controllers.console.app.app" + if module_name in sys.modules: + return sys.modules[module_name] + root = Path(__file__).resolve().parents[5] module_path = root / "controllers" / "console" / "app" / "app.py" @@ -57,12 +59,8 @@ def app_module(): stub_namespace = _StubNamespace() - original_modules: dict[str, ModuleType | None] = { - "controllers.console": sys.modules.get("controllers.console"), - "controllers.console.app": sys.modules.get("controllers.console.app"), - "controllers.common.schema": sys.modules.get("controllers.common.schema"), - module_name: sys.modules.get(module_name), - } + original_console = sys.modules.get("controllers.console") + original_app_pkg = sys.modules.get("controllers.console.app") stubbed_modules: list[tuple[str, ModuleType | None]] = [] console_module = ModuleType("controllers.console") @@ -107,35 +105,35 @@ def app_module(): module = util.module_from_spec(spec) sys.modules[module_name] = module - assert spec.loader is not None - spec.loader.exec_module(module) - try: - yield module + assert spec.loader is not None + spec.loader.exec_module(module) finally: for name, original in reversed(stubbed_modules): if original is not None: sys.modules[name] = original else: sys.modules.pop(name, None) - for name, original in original_modules.items(): - if original is not None: - sys.modules[name] = original - else: - sys.modules.pop(name, None) + if original_console is not None: + sys.modules["controllers.console"] = original_console + else: + sys.modules.pop("controllers.console", None) + if original_app_pkg is not None: + sys.modules["controllers.console.app"] = original_app_pkg + else: + sys.modules.pop("controllers.console.app", None) + + return module -@pytest.fixture(scope="module") -def app_models(app_module): - return SimpleNamespace( - AppDetailWithSite=app_module.AppDetailWithSite, - AppPagination=app_module.AppPagination, - AppPartial=app_module.AppPartial, - ) +_app_module = _load_app_module() +AppDetailWithSite = _app_module.AppDetailWithSite +AppPagination = _app_module.AppPagination +AppPartial = _app_module.AppPartial @pytest.fixture(autouse=True) -def patch_signed_url(monkeypatch, app_module): +def patch_signed_url(monkeypatch): """Ensure icon URL generation uses a deterministic helper for tests.""" def _fake_signed_url(key: str | None) -> str | None: @@ -143,7 +141,7 @@ def patch_signed_url(monkeypatch, app_module): return None return f"signed:{key}" - monkeypatch.setattr(app_module.file_helpers, "get_signed_file_url", _fake_signed_url) + monkeypatch.setattr(_app_module.file_helpers, "get_signed_file_url", _fake_signed_url) def _ts(hour: int = 12) -> datetime: @@ -171,8 +169,7 @@ def _dummy_workflow(): ) -def test_app_partial_serialization_uses_aliases(app_models): - AppPartial = app_models.AppPartial +def test_app_partial_serialization_uses_aliases(): created_at = _ts() app_obj = SimpleNamespace( id="app-1", @@ -207,8 +204,7 @@ def test_app_partial_serialization_uses_aliases(app_models): assert serialized["tags"][0]["name"] == "Utilities" -def test_app_detail_with_site_includes_nested_serialization(app_models): - AppDetailWithSite = app_models.AppDetailWithSite +def test_app_detail_with_site_includes_nested_serialization(): timestamp = _ts(14) site = SimpleNamespace( code="site-code", @@ -257,8 +253,7 @@ def test_app_detail_with_site_includes_nested_serialization(app_models): assert serialized["site"]["created_at"] == int(timestamp.timestamp()) -def test_app_pagination_aliases_per_page_and_has_next(app_models): - AppPagination = app_models.AppPagination +def test_app_pagination_aliases_per_page_and_has_next(): item_one = SimpleNamespace( id="app-10", name="Paginated One", diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_human_input_debug_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_human_input_debug_api.py deleted file mode 100644 index 86a3b2bd93..0000000000 --- a/api/tests/unit_tests/controllers/console/app/test_workflow_human_input_debug_api.py +++ /dev/null @@ -1,229 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass -from types import SimpleNamespace -from unittest.mock import MagicMock - -import pytest -from flask import Flask -from pydantic import ValidationError - -from controllers.console import wraps as console_wraps -from controllers.console.app import workflow as workflow_module -from controllers.console.app import wraps as app_wraps -from libs import login as login_lib -from models.account import Account, AccountStatus, TenantAccountRole -from models.model import AppMode - - -def _make_account() -> Account: - account = Account(name="tester", email="tester@example.com") - account.status = AccountStatus.ACTIVE - account.role = TenantAccountRole.OWNER - account.id = "account-123" # type: ignore[assignment] - account._current_tenant = SimpleNamespace(id="tenant-123") # type: ignore[attr-defined] - account._get_current_object = lambda: account # type: ignore[attr-defined] - return account - - -def _make_app(mode: AppMode) -> SimpleNamespace: - return SimpleNamespace(id="app-123", tenant_id="tenant-123", mode=mode.value) - - -def _patch_console_guards(monkeypatch: pytest.MonkeyPatch, account: Account, app_model: SimpleNamespace) -> None: - # Skip setup and auth guardrails - monkeypatch.setattr("configs.dify_config.EDITION", "CLOUD") - monkeypatch.setattr(login_lib.dify_config, "LOGIN_DISABLED", True) - monkeypatch.setattr(login_lib, "current_user", account) - monkeypatch.setattr(login_lib, "current_account_with_tenant", lambda: (account, account.current_tenant_id)) - monkeypatch.setattr(login_lib, "check_csrf_token", lambda *_, **__: None) - monkeypatch.setattr(console_wraps, "current_account_with_tenant", lambda: (account, account.current_tenant_id)) - monkeypatch.setattr(app_wraps, "current_account_with_tenant", lambda: (account, account.current_tenant_id)) - monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (account, account.current_tenant_id)) - monkeypatch.setattr(console_wraps.dify_config, "EDITION", "CLOUD") - monkeypatch.delenv("INIT_PASSWORD", raising=False) - - # Avoid hitting the database when resolving the app model - monkeypatch.setattr(app_wraps, "_load_app_model", lambda _app_id: app_model) - - -@dataclass -class PreviewCase: - resource_cls: type - path: str - mode: AppMode - - -@pytest.mark.parametrize( - "case", - [ - PreviewCase( - resource_cls=workflow_module.AdvancedChatDraftHumanInputFormPreviewApi, - path="/console/api/apps/app-123/advanced-chat/workflows/draft/human-input/nodes/node-42/form/preview", - mode=AppMode.ADVANCED_CHAT, - ), - PreviewCase( - resource_cls=workflow_module.WorkflowDraftHumanInputFormPreviewApi, - path="/console/api/apps/app-123/workflows/draft/human-input/nodes/node-42/form/preview", - mode=AppMode.WORKFLOW, - ), - ], -) -def test_human_input_preview_delegates_to_service( - app: Flask, monkeypatch: pytest.MonkeyPatch, case: PreviewCase -) -> None: - account = _make_account() - app_model = _make_app(case.mode) - _patch_console_guards(monkeypatch, account, app_model) - - preview_payload = { - "form_id": "node-42", - "form_content": "
example
", - "inputs": [{"name": "topic"}], - "actions": [{"id": "continue"}], - } - service_instance = MagicMock() - service_instance.get_human_input_form_preview.return_value = preview_payload - monkeypatch.setattr(workflow_module, "WorkflowService", MagicMock(return_value=service_instance)) - - with app.test_request_context(case.path, method="POST", json={"inputs": {"topic": "tech"}}): - response = case.resource_cls().post(app_id=app_model.id, node_id="node-42") - - assert response == preview_payload - service_instance.get_human_input_form_preview.assert_called_once_with( - app_model=app_model, - account=account, - node_id="node-42", - inputs={"topic": "tech"}, - ) - - -@dataclass -class SubmitCase: - resource_cls: type - path: str - mode: AppMode - - -@pytest.mark.parametrize( - "case", - [ - SubmitCase( - resource_cls=workflow_module.AdvancedChatDraftHumanInputFormRunApi, - path="/console/api/apps/app-123/advanced-chat/workflows/draft/human-input/nodes/node-99/form/run", - mode=AppMode.ADVANCED_CHAT, - ), - SubmitCase( - resource_cls=workflow_module.WorkflowDraftHumanInputFormRunApi, - path="/console/api/apps/app-123/workflows/draft/human-input/nodes/node-99/form/run", - mode=AppMode.WORKFLOW, - ), - ], -) -def test_human_input_submit_forwards_payload(app: Flask, monkeypatch: pytest.MonkeyPatch, case: SubmitCase) -> None: - account = _make_account() - app_model = _make_app(case.mode) - _patch_console_guards(monkeypatch, account, app_model) - - result_payload = {"node_id": "node-99", "outputs": {"__rendered_content": "

done

"}, "action": "approve"} - service_instance = MagicMock() - service_instance.submit_human_input_form_preview.return_value = result_payload - monkeypatch.setattr(workflow_module, "WorkflowService", MagicMock(return_value=service_instance)) - - with app.test_request_context( - case.path, - method="POST", - json={"form_inputs": {"answer": "42"}, "inputs": {"#node-1.result#": "LLM output"}, "action": "approve"}, - ): - response = case.resource_cls().post(app_id=app_model.id, node_id="node-99") - - assert response == result_payload - service_instance.submit_human_input_form_preview.assert_called_once_with( - app_model=app_model, - account=account, - node_id="node-99", - form_inputs={"answer": "42"}, - inputs={"#node-1.result#": "LLM output"}, - action="approve", - ) - - -@dataclass -class DeliveryTestCase: - resource_cls: type - path: str - mode: AppMode - - -@pytest.mark.parametrize( - "case", - [ - DeliveryTestCase( - resource_cls=workflow_module.WorkflowDraftHumanInputDeliveryTestApi, - path="/console/api/apps/app-123/workflows/draft/human-input/nodes/node-7/delivery-test", - mode=AppMode.ADVANCED_CHAT, - ), - DeliveryTestCase( - resource_cls=workflow_module.WorkflowDraftHumanInputDeliveryTestApi, - path="/console/api/apps/app-123/workflows/draft/human-input/nodes/node-7/delivery-test", - mode=AppMode.WORKFLOW, - ), - ], -) -def test_human_input_delivery_test_calls_service( - app: Flask, monkeypatch: pytest.MonkeyPatch, case: DeliveryTestCase -) -> None: - account = _make_account() - app_model = _make_app(case.mode) - _patch_console_guards(monkeypatch, account, app_model) - - service_instance = MagicMock() - monkeypatch.setattr(workflow_module, "WorkflowService", MagicMock(return_value=service_instance)) - - with app.test_request_context( - case.path, - method="POST", - json={"delivery_method_id": "delivery-123"}, - ): - response = case.resource_cls().post(app_id=app_model.id, node_id="node-7") - - assert response == {} - service_instance.test_human_input_delivery.assert_called_once_with( - app_model=app_model, - account=account, - node_id="node-7", - delivery_method_id="delivery-123", - inputs={}, - ) - - -def test_human_input_delivery_test_maps_validation_error(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: - account = _make_account() - app_model = _make_app(AppMode.ADVANCED_CHAT) - _patch_console_guards(monkeypatch, account, app_model) - - service_instance = MagicMock() - service_instance.test_human_input_delivery.side_effect = ValueError("bad delivery method") - monkeypatch.setattr(workflow_module, "WorkflowService", MagicMock(return_value=service_instance)) - - with app.test_request_context( - "/console/api/apps/app-123/workflows/draft/human-input/nodes/node-1/delivery-test", - method="POST", - json={"delivery_method_id": "bad"}, - ): - with pytest.raises(ValueError): - workflow_module.WorkflowDraftHumanInputDeliveryTestApi().post(app_id=app_model.id, node_id="node-1") - - -def test_human_input_preview_rejects_non_mapping(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: - account = _make_account() - app_model = _make_app(AppMode.ADVANCED_CHAT) - _patch_console_guards(monkeypatch, account, app_model) - - with app.test_request_context( - "/console/api/apps/app-123/advanced-chat/workflows/draft/human-input/nodes/node-1/form/preview", - method="POST", - json={"inputs": ["not-a-dict"]}, - ): - with pytest.raises(ValidationError): - workflow_module.AdvancedChatDraftHumanInputFormPreviewApi().post(app_id=app_model.id, node_id="node-1") diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py deleted file mode 100644 index 34d6a2232c..0000000000 --- a/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py +++ /dev/null @@ -1,91 +0,0 @@ -from __future__ import annotations - -from datetime import datetime -from types import SimpleNamespace -from unittest.mock import Mock - -import pytest -from flask import Flask - -from controllers.console import wraps as console_wraps -from controllers.console.app import workflow_run as workflow_run_module -from core.workflow.entities.pause_reason import HumanInputRequired -from core.workflow.enums import WorkflowExecutionStatus -from core.workflow.nodes.human_input.entities import FormInput, UserAction -from core.workflow.nodes.human_input.enums import FormInputType -from libs import login as login_lib -from models.account import Account, AccountStatus, TenantAccountRole -from models.workflow import WorkflowRun - - -def _make_account() -> Account: - account = Account(name="tester", email="tester@example.com") - account.status = AccountStatus.ACTIVE - account.role = TenantAccountRole.OWNER - account.id = "account-123" # type: ignore[assignment] - account._current_tenant = SimpleNamespace(id="tenant-123") # type: ignore[attr-defined] - account._get_current_object = lambda: account # type: ignore[attr-defined] - return account - - -def _patch_console_guards(monkeypatch: pytest.MonkeyPatch, account: Account) -> None: - monkeypatch.setattr(login_lib.dify_config, "LOGIN_DISABLED", True) - monkeypatch.setattr(login_lib, "current_user", account) - monkeypatch.setattr(login_lib, "current_account_with_tenant", lambda: (account, account.current_tenant_id)) - monkeypatch.setattr(login_lib, "check_csrf_token", lambda *_, **__: None) - monkeypatch.setattr(console_wraps, "current_account_with_tenant", lambda: (account, account.current_tenant_id)) - monkeypatch.setattr(workflow_run_module, "current_user", account) - monkeypatch.setattr(console_wraps.dify_config, "EDITION", "CLOUD") - - -class _PauseEntity: - def __init__(self, paused_at: datetime, reasons: list[HumanInputRequired]): - self.paused_at = paused_at - self._reasons = reasons - - def get_pause_reasons(self): - return self._reasons - - -def test_pause_details_returns_backstage_input_url(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: - account = _make_account() - _patch_console_guards(monkeypatch, account) - monkeypatch.setattr(workflow_run_module.dify_config, "APP_WEB_URL", "https://web.example.com") - - workflow_run = Mock(spec=WorkflowRun) - workflow_run.status = WorkflowExecutionStatus.PAUSED - workflow_run.created_at = datetime(2024, 1, 1, 12, 0, 0) - fake_db = SimpleNamespace(engine=Mock(), session=SimpleNamespace(get=lambda *_: workflow_run)) - monkeypatch.setattr(workflow_run_module, "db", fake_db) - - reason = HumanInputRequired( - form_id="form-1", - form_content="content", - inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")], - actions=[UserAction(id="approve", title="Approve")], - node_id="node-1", - node_title="Ask Name", - form_token="backstage-token", - ) - pause_entity = _PauseEntity(paused_at=datetime(2024, 1, 1, 12, 0, 0), reasons=[reason]) - - repo = Mock() - repo.get_workflow_pause.return_value = pause_entity - monkeypatch.setattr( - workflow_run_module.DifyAPIRepositoryFactory, - "create_api_workflow_run_repository", - lambda *_, **__: repo, - ) - - with app.test_request_context("/console/api/workflow/run-1/pause-details", method="GET"): - response, status = workflow_run_module.ConsoleWorkflowPauseDetailsApi().get(workflow_run_id="run-1") - - assert status == 200 - assert response["paused_at"] == "2024-01-01T12:00:00Z" - assert response["paused_nodes"][0]["node_id"] == "node-1" - assert response["paused_nodes"][0]["pause_type"]["type"] == "human_input" - assert ( - response["paused_nodes"][0]["pause_type"]["backstage_input_url"] - == "https://web.example.com/form/backstage-token" - ) - assert "pending_human_inputs" not in response diff --git a/api/tests/unit_tests/controllers/console/test_fastopenapi_init_validate.py b/api/tests/unit_tests/controllers/console/test_fastopenapi_init_validate.py new file mode 100644 index 0000000000..b9bc42fb25 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/test_fastopenapi_init_validate.py @@ -0,0 +1,46 @@ +import builtins +from unittest.mock import patch + +import pytest +from flask import Flask +from flask.views import MethodView + +from extensions import ext_fastopenapi + +if not hasattr(builtins, "MethodView"): + builtins.MethodView = MethodView # type: ignore[attr-defined] + + +@pytest.fixture +def app() -> Flask: + app = Flask(__name__) + app.config["TESTING"] = True + app.secret_key = "test-secret-key" + return app + + +def test_console_init_get_returns_finished_when_no_init_password(app: Flask, monkeypatch: pytest.MonkeyPatch): + ext_fastopenapi.init_app(app) + monkeypatch.delenv("INIT_PASSWORD", raising=False) + + with patch("controllers.console.init_validate.dify_config.EDITION", "SELF_HOSTED"): + client = app.test_client() + response = client.get("/console/api/init") + + assert response.status_code == 200 + assert response.get_json() == {"status": "finished"} + + +def test_console_init_post_returns_success(app: Flask, monkeypatch: pytest.MonkeyPatch): + ext_fastopenapi.init_app(app) + monkeypatch.setenv("INIT_PASSWORD", "test-init-password") + + with ( + patch("controllers.console.init_validate.dify_config.EDITION", "SELF_HOSTED"), + patch("controllers.console.init_validate.TenantService.get_tenant_count", return_value=0), + ): + client = app.test_client() + response = client.post("/console/api/init", json={"password": "test-init-password"}) + + assert response.status_code == 201 + assert response.get_json() == {"result": "success"} diff --git a/api/tests/unit_tests/controllers/console/test_fastopenapi_remote_files.py b/api/tests/unit_tests/controllers/console/test_fastopenapi_remote_files.py new file mode 100644 index 0000000000..cb2604cf1c --- /dev/null +++ b/api/tests/unit_tests/controllers/console/test_fastopenapi_remote_files.py @@ -0,0 +1,92 @@ +import builtins +from datetime import datetime +from types import SimpleNamespace +from unittest.mock import patch + +import httpx +import pytest +from flask import Flask +from flask.views import MethodView + +from extensions import ext_fastopenapi + +if not hasattr(builtins, "MethodView"): + builtins.MethodView = MethodView # type: ignore[attr-defined] + + +@pytest.fixture +def app() -> Flask: + app = Flask(__name__) + app.config["TESTING"] = True + return app + + +def test_console_remote_files_fastopenapi_get_info(app: Flask): + ext_fastopenapi.init_app(app) + + response = httpx.Response( + 200, + request=httpx.Request("HEAD", "http://example.com/file.txt"), + headers={"Content-Type": "text/plain", "Content-Length": "10"}, + ) + + with patch("controllers.console.remote_files.ssrf_proxy.head", return_value=response): + client = app.test_client() + encoded_url = "http%3A%2F%2Fexample.com%2Ffile.txt" + resp = client.get(f"/console/api/remote-files/{encoded_url}") + + assert resp.status_code == 200 + assert resp.get_json() == {"file_type": "text/plain", "file_length": 10} + + +def test_console_remote_files_fastopenapi_upload(app: Flask): + ext_fastopenapi.init_app(app) + + head_response = httpx.Response( + 200, + request=httpx.Request("GET", "http://example.com/file.txt"), + content=b"hello", + ) + file_info = SimpleNamespace( + extension="txt", + size=5, + filename="file.txt", + mimetype="text/plain", + ) + uploaded = SimpleNamespace( + id="file-id", + name="file.txt", + size=5, + extension="txt", + mime_type="text/plain", + created_by="user-id", + created_at=datetime(2024, 1, 1), + ) + + with ( + patch("controllers.console.remote_files.db", new=SimpleNamespace(engine=object())), + patch("controllers.console.remote_files.ssrf_proxy.head", return_value=head_response), + patch("controllers.console.remote_files.helpers.guess_file_info_from_response", return_value=file_info), + patch("controllers.console.remote_files.FileService.is_file_size_within_limit", return_value=True), + patch("controllers.console.remote_files.FileService.__init__", return_value=None), + patch("controllers.console.remote_files.current_account_with_tenant", return_value=(object(), "tenant-id")), + patch("controllers.console.remote_files.FileService.upload_file", return_value=uploaded), + patch("controllers.console.remote_files.file_helpers.get_signed_file_url", return_value="signed-url"), + ): + client = app.test_client() + resp = client.post( + "/console/api/remote-files/upload", + json={"url": "http://example.com/file.txt"}, + ) + + assert resp.status_code == 201 + assert resp.get_json() == { + "id": "file-id", + "name": "file.txt", + "size": 5, + "extension": "txt", + "url": "signed-url", + "mime_type": "text/plain", + "created_by": "user-id", + "created_at": int(uploaded.created_at.timestamp()), + } diff --git a/api/tests/unit_tests/controllers/console/test_fastopenapi_tags.py b/api/tests/unit_tests/controllers/console/test_fastopenapi_tags.py new file mode 100644 index 0000000000..62d143f32d --- /dev/null +++ b/api/tests/unit_tests/controllers/console/test_fastopenapi_tags.py @@ -0,0 +1,222 @@ +import builtins +import contextlib +import importlib +import sys +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask +from flask.views import MethodView + +from extensions import ext_fastopenapi +from extensions.ext_database import db + + +@pytest.fixture +def app(): + app = Flask(__name__) + app.config["TESTING"] = True + app.config["SECRET_KEY"] = "test-secret" + app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:" + + db.init_app(app) + + return app + + +@pytest.fixture(autouse=True) +def fix_method_view_issue(monkeypatch): + if not hasattr(builtins, "MethodView"): + monkeypatch.setattr(builtins, "MethodView", MethodView, raising=False) + + +def _create_isolated_router(): + import controllers.fastopenapi + + router_class = type(controllers.fastopenapi.console_router) + return router_class() + + +@contextlib.contextmanager +def _patch_auth_and_router(temp_router): + def noop(func): + return func + + default_user = MagicMock(has_edit_permission=True, is_dataset_editor=False) + + with ( + patch("controllers.fastopenapi.console_router", temp_router), + patch("extensions.ext_fastopenapi.console_router", temp_router), + patch("controllers.console.wraps.setup_required", side_effect=noop), + patch("libs.login.login_required", side_effect=noop), + patch("controllers.console.wraps.account_initialization_required", side_effect=noop), + patch("controllers.console.wraps.edit_permission_required", side_effect=noop), + patch("libs.login.current_account_with_tenant", return_value=(default_user, "tenant-id")), + patch("configs.dify_config.EDITION", "CLOUD"), + ): + import extensions.ext_fastopenapi + + importlib.reload(extensions.ext_fastopenapi) + + yield + + +def _force_reload_module(target_module: str, alias_module: str): + if target_module in sys.modules: + del sys.modules[target_module] + if alias_module in sys.modules: + del sys.modules[alias_module] + + module = importlib.import_module(target_module) + sys.modules[alias_module] = sys.modules[target_module] + + return module + + +def _dedupe_routes(router): + seen = set() + unique_routes = [] + for path, method, endpoint in reversed(router.get_routes()): + key = (path, method, endpoint.__name__) + if key in seen: + continue + seen.add(key) + unique_routes.append((path, method, endpoint)) + router._routes = list(reversed(unique_routes)) + + +def _cleanup_modules(target_module: str, alias_module: str): + if target_module in sys.modules: + del sys.modules[target_module] + if alias_module in sys.modules: + del sys.modules[alias_module] + + +@pytest.fixture +def mock_tags_module_env(): + target_module = "controllers.console.tag.tags" + alias_module = "api.controllers.console.tag.tags" + temp_router = _create_isolated_router() + + try: + with _patch_auth_and_router(temp_router): + tags_module = _force_reload_module(target_module, alias_module) + _dedupe_routes(temp_router) + yield tags_module + finally: + _cleanup_modules(target_module, alias_module) + + +def test_list_tags_success(app: Flask, mock_tags_module_env): + # Arrange + tag = SimpleNamespace(id="tag-1", name="Alpha", type="app", binding_count=2) + with patch("controllers.console.tag.tags.TagService.get_tags", return_value=[tag]): + ext_fastopenapi.init_app(app) + client = app.test_client() + + # Act + response = client.get("/console/api/tags?type=app&keyword=Alpha") + + # Assert + assert response.status_code == 200 + assert response.get_json() == [ + {"id": "tag-1", "name": "Alpha", "type": "app", "binding_count": 2}, + ] + + +def test_create_tag_success(app: Flask, mock_tags_module_env): + # Arrange + tag = SimpleNamespace(id="tag-2", name="Beta", type="app") + with patch("controllers.console.tag.tags.TagService.save_tags", return_value=tag) as mock_save: + ext_fastopenapi.init_app(app) + client = app.test_client() + + # Act + response = client.post("/console/api/tags", json={"name": "Beta", "type": "app"}) + + # Assert + assert response.status_code == 200 + assert response.get_json() == { + "id": "tag-2", + "name": "Beta", + "type": "app", + "binding_count": 0, + } + mock_save.assert_called_once_with({"name": "Beta", "type": "app"}) + + +def test_update_tag_success(app: Flask, mock_tags_module_env): + # Arrange + tag = SimpleNamespace(id="tag-3", name="Gamma", type="app") + with ( + patch("controllers.console.tag.tags.TagService.update_tags", return_value=tag) as mock_update, + patch("controllers.console.tag.tags.TagService.get_tag_binding_count", return_value=4), + ): + ext_fastopenapi.init_app(app) + client = app.test_client() + + # Act + response = client.patch( + "/console/api/tags/11111111-1111-1111-1111-111111111111", + json={"name": "Gamma", "type": "app"}, + ) + + # Assert + assert response.status_code == 200 + assert response.get_json() == { + "id": "tag-3", + "name": "Gamma", + "type": "app", + "binding_count": 4, + } + mock_update.assert_called_once_with( + {"name": "Gamma", "type": "app"}, + "11111111-1111-1111-1111-111111111111", + ) + + +def test_delete_tag_success(app: Flask, mock_tags_module_env): + # Arrange + with patch("controllers.console.tag.tags.TagService.delete_tag") as mock_delete: + ext_fastopenapi.init_app(app) + client = app.test_client() + + # Act + response = client.delete("/console/api/tags/11111111-1111-1111-1111-111111111111") + + # Assert + assert response.status_code == 204 + mock_delete.assert_called_once_with("11111111-1111-1111-1111-111111111111") + + +def test_create_tag_binding_success(app: Flask, mock_tags_module_env): + # Arrange + payload = {"tag_ids": ["tag-1", "tag-2"], "target_id": "target-1", "type": "app"} + with patch("controllers.console.tag.tags.TagService.save_tag_binding") as mock_bind: + ext_fastopenapi.init_app(app) + client = app.test_client() + + # Act + response = client.post("/console/api/tag-bindings/create", json=payload) + + # Assert + assert response.status_code == 200 + assert response.get_json() == {"result": "success"} + mock_bind.assert_called_once_with(payload) + + +def test_delete_tag_binding_success(app: Flask, mock_tags_module_env): + # Arrange + payload = {"tag_id": "tag-1", "target_id": "target-1", "type": "app"} + with patch("controllers.console.tag.tags.TagService.delete_tag_binding") as mock_unbind: + ext_fastopenapi.init_app(app) + client = app.test_client() + + # Act + response = client.post("/console/api/tag-bindings/remove", json=payload) + + # Assert + assert response.status_code == 200 + assert response.get_json() == {"result": "success"} + mock_unbind.assert_called_once_with(payload) diff --git a/api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py b/api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py new file mode 100644 index 0000000000..94c3019d5e --- /dev/null +++ b/api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py @@ -0,0 +1,364 @@ +"""Endpoint tests for controllers.console.workspace.tool_providers.""" + +from __future__ import annotations + +import builtins +import importlib +from contextlib import contextmanager +from types import ModuleType, SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask +from flask.views import MethodView + +if not hasattr(builtins, "MethodView"): + builtins.MethodView = MethodView # type: ignore[attr-defined] + + +_CONTROLLER_MODULE: ModuleType | None = None +_WRAPS_MODULE: ModuleType | None = None +_CONTROLLER_PATCHERS: list[patch] = [] + + +@contextmanager +def _mock_db(): + mock_session = SimpleNamespace(query=lambda *args, **kwargs: SimpleNamespace(first=lambda: True)) + with patch("extensions.ext_database.db.session", mock_session): + yield + + +@pytest.fixture +def app() -> Flask: + flask_app = Flask(__name__) + flask_app.config["TESTING"] = True + return flask_app + + +@pytest.fixture +def controller_module(monkeypatch: pytest.MonkeyPatch): + module_name = "controllers.console.workspace.tool_providers" + global _CONTROLLER_MODULE + if _CONTROLLER_MODULE is None: + + def _noop(func): + return func + + patch_targets = [ + ("libs.login.login_required", _noop), + ("controllers.console.wraps.setup_required", _noop), + ("controllers.console.wraps.account_initialization_required", _noop), + ("controllers.console.wraps.is_admin_or_owner_required", _noop), + ("controllers.console.wraps.enterprise_license_required", _noop), + ] + for target, value in patch_targets: + patcher = patch(target, value) + patcher.start() + _CONTROLLER_PATCHERS.append(patcher) + monkeypatch.setenv("DIFY_SETUP_READY", "true") + with _mock_db(): + _CONTROLLER_MODULE = importlib.import_module(module_name) + + module = _CONTROLLER_MODULE + monkeypatch.setattr(module, "jsonable_encoder", lambda payload: payload) + + # Ensure decorators that consult deployment edition do not reach the database. + global _WRAPS_MODULE + wraps_module = importlib.import_module("controllers.console.wraps") + _WRAPS_MODULE = wraps_module + monkeypatch.setattr(module.dify_config, "EDITION", "CLOUD") + monkeypatch.setattr(wraps_module.dify_config, "EDITION", "CLOUD") + + login_module = importlib.import_module("libs.login") + monkeypatch.setattr(login_module, "check_csrf_token", lambda *args, **kwargs: None) + return module + + +def _mock_account(user_id: str = "user-123") -> SimpleNamespace: + return SimpleNamespace(id=user_id, status="active", is_authenticated=True, current_tenant_id=None) + + +def _set_current_account( + monkeypatch: pytest.MonkeyPatch, + controller_module: ModuleType, + user: SimpleNamespace, + tenant_id: str, +) -> None: + def _getter(): + return user, tenant_id + + user.current_tenant_id = tenant_id + + monkeypatch.setattr(controller_module, "current_account_with_tenant", _getter) + if _WRAPS_MODULE is not None: + monkeypatch.setattr(_WRAPS_MODULE, "current_account_with_tenant", _getter) + + login_module = importlib.import_module("libs.login") + monkeypatch.setattr(login_module, "_get_user", lambda: user) + + +def test_tool_provider_list_calls_service_with_query( + app: Flask, controller_module: ModuleType, monkeypatch: pytest.MonkeyPatch +): + user = _mock_account() + _set_current_account(monkeypatch, controller_module, user, "tenant-456") + + service_mock = MagicMock(return_value=[{"provider": "builtin"}]) + monkeypatch.setattr(controller_module.ToolCommonService, "list_tool_providers", service_mock) + + with app.test_request_context("/workspaces/current/tool-providers?type=builtin"): + response = controller_module.ToolProviderListApi().get() + + assert response == [{"provider": "builtin"}] + service_mock.assert_called_once_with(user.id, "tenant-456", "builtin") + + +def test_builtin_provider_add_passes_payload( + app: Flask, controller_module: ModuleType, monkeypatch: pytest.MonkeyPatch +): + user = _mock_account() + _set_current_account(monkeypatch, controller_module, user, "tenant-456") + + service_mock = MagicMock(return_value={"status": "ok"}) + monkeypatch.setattr(controller_module.BuiltinToolManageService, "add_builtin_tool_provider", service_mock) + + payload = { + "credentials": {"api_key": "sk-test"}, + "name": "MyTool", + "type": controller_module.CredentialType.API_KEY, + } + + with app.test_request_context( + "/workspaces/current/tool-provider/builtin/openai/add", + method="POST", + json=payload, + ): + response = controller_module.ToolBuiltinProviderAddApi().post(provider="openai") + + assert response == {"status": "ok"} + service_mock.assert_called_once_with( + user_id="user-123", + tenant_id="tenant-456", + provider="openai", + credentials={"api_key": "sk-test"}, + name="MyTool", + api_type=controller_module.CredentialType.API_KEY, + ) + + +def test_builtin_provider_tools_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account("user-tenant-789") + _set_current_account(monkeypatch, controller_module, user, "tenant-789") + + service_mock = MagicMock(return_value=[{"name": "tool-a"}]) + monkeypatch.setattr(controller_module.BuiltinToolManageService, "list_builtin_tool_provider_tools", service_mock) + monkeypatch.setattr(controller_module, "jsonable_encoder", lambda payload: payload) + + with app.test_request_context( + "/workspaces/current/tool-provider/builtin/my-provider/tools", + method="GET", + ): + response = controller_module.ToolBuiltinProviderListToolsApi().get(provider="my-provider") + + assert response == [{"name": "tool-a"}] + service_mock.assert_called_once_with("tenant-789", "my-provider") + + +def test_builtin_provider_info_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account("user-tenant-9") + _set_current_account(monkeypatch, controller_module, user, "tenant-9") + service_mock = MagicMock(return_value={"info": True}) + monkeypatch.setattr(controller_module.BuiltinToolManageService, "get_builtin_tool_provider_info", service_mock) + + with app.test_request_context("/info", method="GET"): + resp = controller_module.ToolBuiltinProviderInfoApi().get(provider="demo") + + assert resp == {"info": True} + service_mock.assert_called_once_with("tenant-9", "demo") + + +def test_builtin_provider_credentials_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account("user-tenant-cred") + _set_current_account(monkeypatch, controller_module, user, "tenant-cred") + service_mock = MagicMock(return_value=[{"cred": 1}]) + monkeypatch.setattr( + controller_module.BuiltinToolManageService, + "get_builtin_tool_provider_credentials", + service_mock, + ) + + with app.test_request_context("/creds", method="GET"): + resp = controller_module.ToolBuiltinProviderGetCredentialsApi().get(provider="demo") + + assert resp == [{"cred": 1}] + service_mock.assert_called_once_with(tenant_id="tenant-cred", provider_name="demo") + + +def test_api_provider_remote_schema_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account() + _set_current_account(monkeypatch, controller_module, user, "tenant-10") + service_mock = MagicMock(return_value={"schema": "ok"}) + monkeypatch.setattr(controller_module.ApiToolManageService, "get_api_tool_provider_remote_schema", service_mock) + + with app.test_request_context("/remote?url=https://example.com/"): + resp = controller_module.ToolApiProviderGetRemoteSchemaApi().get() + + assert resp == {"schema": "ok"} + service_mock.assert_called_once_with(user.id, "tenant-10", "https://example.com/") + + +def test_api_provider_list_tools_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account() + _set_current_account(monkeypatch, controller_module, user, "tenant-11") + service_mock = MagicMock(return_value=[{"tool": "t"}]) + monkeypatch.setattr(controller_module.ApiToolManageService, "list_api_tool_provider_tools", service_mock) + + with app.test_request_context("/tools?provider=foo"): + resp = controller_module.ToolApiProviderListToolsApi().get() + + assert resp == [{"tool": "t"}] + service_mock.assert_called_once_with(user.id, "tenant-11", "foo") + + +def test_api_provider_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account() + _set_current_account(monkeypatch, controller_module, user, "tenant-12") + service_mock = MagicMock(return_value={"provider": "foo"}) + monkeypatch.setattr(controller_module.ApiToolManageService, "get_api_tool_provider", service_mock) + + with app.test_request_context("/get?provider=foo"): + resp = controller_module.ToolApiProviderGetApi().get() + + assert resp == {"provider": "foo"} + service_mock.assert_called_once_with(user.id, "tenant-12", "foo") + + +def test_builtin_provider_credentials_schema_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account("user-tenant-13") + _set_current_account(monkeypatch, controller_module, user, "tenant-13") + service_mock = MagicMock(return_value={"schema": True}) + monkeypatch.setattr( + controller_module.BuiltinToolManageService, + "list_builtin_provider_credentials_schema", + service_mock, + ) + + with app.test_request_context("/schema", method="GET"): + resp = controller_module.ToolBuiltinProviderCredentialsSchemaApi().get( + provider="demo", credential_type="api-key" + ) + + assert resp == {"schema": True} + service_mock.assert_called_once() + + +def test_workflow_provider_get_by_tool(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account() + _set_current_account(monkeypatch, controller_module, user, "tenant-wf") + tool_service = MagicMock(return_value={"wf": 1}) + monkeypatch.setattr( + controller_module.WorkflowToolManageService, + "get_workflow_tool_by_tool_id", + tool_service, + ) + + tool_id = "00000000-0000-0000-0000-000000000001" + with app.test_request_context(f"/workflow?workflow_tool_id={tool_id}"): + resp = controller_module.ToolWorkflowProviderGetApi().get() + + assert resp == {"wf": 1} + tool_service.assert_called_once_with(user.id, "tenant-wf", tool_id) + + +def test_workflow_provider_get_by_app(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account() + _set_current_account(monkeypatch, controller_module, user, "tenant-wf2") + service_mock = MagicMock(return_value={"app": 1}) + monkeypatch.setattr( + controller_module.WorkflowToolManageService, + "get_workflow_tool_by_app_id", + service_mock, + ) + + app_id = "00000000-0000-0000-0000-000000000002" + with app.test_request_context(f"/workflow?workflow_app_id={app_id}"): + resp = controller_module.ToolWorkflowProviderGetApi().get() + + assert resp == {"app": 1} + service_mock.assert_called_once_with(user.id, "tenant-wf2", app_id) + + +def test_workflow_provider_list_tools(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account() + _set_current_account(monkeypatch, controller_module, user, "tenant-wf3") + service_mock = MagicMock(return_value=[{"id": 1}]) + monkeypatch.setattr(controller_module.WorkflowToolManageService, "list_single_workflow_tools", service_mock) + + tool_id = "00000000-0000-0000-0000-000000000003" + with app.test_request_context(f"/workflow/tools?workflow_tool_id={tool_id}"): + resp = controller_module.ToolWorkflowProviderListToolApi().get() + + assert resp == [{"id": 1}] + service_mock.assert_called_once_with(user.id, "tenant-wf3", tool_id) + + +def test_builtin_tools_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account() + _set_current_account(monkeypatch, controller_module, user, "tenant-bt") + + provider = SimpleNamespace(to_dict=lambda: {"name": "builtin"}) + monkeypatch.setattr( + controller_module.BuiltinToolManageService, + "list_builtin_tools", + MagicMock(return_value=[provider]), + ) + + with app.test_request_context("/tools/builtin"): + resp = controller_module.ToolBuiltinListApi().get() + + assert resp == [{"name": "builtin"}] + + +def test_api_tools_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account("user-tenant-api") + _set_current_account(monkeypatch, controller_module, user, "tenant-api") + + provider = SimpleNamespace(to_dict=lambda: {"name": "api"}) + monkeypatch.setattr( + controller_module.ApiToolManageService, + "list_api_tools", + MagicMock(return_value=[provider]), + ) + + with app.test_request_context("/tools/api"): + resp = controller_module.ToolApiListApi().get() + + assert resp == [{"name": "api"}] + + +def test_workflow_tools_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account() + _set_current_account(monkeypatch, controller_module, user, "tenant-wf4") + + provider = SimpleNamespace(to_dict=lambda: {"name": "wf"}) + monkeypatch.setattr( + controller_module.WorkflowToolManageService, + "list_tenant_workflow_tools", + MagicMock(return_value=[provider]), + ) + + with app.test_request_context("/tools/workflow"): + resp = controller_module.ToolWorkflowListApi().get() + + assert resp == [{"name": "wf"}] + + +def test_tool_labels_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account("user-label") + _set_current_account(monkeypatch, controller_module, user, "tenant-labels") + monkeypatch.setattr(controller_module.ToolLabelsService, "list_tool_labels", lambda: ["a", "b"]) + + with app.test_request_context("/tool-labels"): + resp = controller_module.ToolLabelsApi().get() + + assert resp == ["a", "b"] diff --git a/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py b/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py deleted file mode 100644 index fcaa61a871..0000000000 --- a/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py +++ /dev/null @@ -1,25 +0,0 @@ -from types import SimpleNamespace - -from controllers.service_api.app.workflow import WorkflowRunOutputsField, WorkflowRunStatusField -from core.workflow.enums import WorkflowExecutionStatus - - -def test_workflow_run_status_field_with_enum() -> None: - field = WorkflowRunStatusField() - obj = SimpleNamespace(status=WorkflowExecutionStatus.PAUSED) - - assert field.output("status", obj) == "paused" - - -def test_workflow_run_outputs_field_paused_returns_empty() -> None: - field = WorkflowRunOutputsField() - obj = SimpleNamespace(status=WorkflowExecutionStatus.PAUSED, outputs_dict={"foo": "bar"}) - - assert field.output("outputs", obj) == {} - - -def test_workflow_run_outputs_field_running_returns_outputs() -> None: - field = WorkflowRunOutputsField() - obj = SimpleNamespace(status=WorkflowExecutionStatus.RUNNING, outputs_dict={"foo": "bar"}) - - assert field.output("outputs", obj) == {"foo": "bar"} diff --git a/api/tests/unit_tests/controllers/web/test_human_input_form.py b/api/tests/unit_tests/controllers/web/test_human_input_form.py deleted file mode 100644 index 4fb735b033..0000000000 --- a/api/tests/unit_tests/controllers/web/test_human_input_form.py +++ /dev/null @@ -1,456 +0,0 @@ -"""Unit tests for controllers.web.human_input_form endpoints.""" - -from __future__ import annotations - -import json -from datetime import UTC, datetime -from types import SimpleNamespace -from typing import Any -from unittest.mock import MagicMock - -import pytest -from flask import Flask -from werkzeug.exceptions import Forbidden - -import controllers.web.human_input_form as human_input_module -import controllers.web.site as site_module -from controllers.web.error import WebFormRateLimitExceededError -from models.human_input import RecipientType -from services.human_input_service import FormExpiredError - -HumanInputFormApi = human_input_module.HumanInputFormApi -TenantStatus = human_input_module.TenantStatus - - -@pytest.fixture -def app() -> Flask: - """Configure a minimal Flask app for request contexts.""" - - app = Flask(__name__) - app.config["TESTING"] = True - return app - - -class _FakeSession: - """Simple stand-in for db.session that returns pre-seeded objects.""" - - def __init__(self, mapping: dict[str, Any]): - self._mapping = mapping - self._model_name: str | None = None - - def query(self, model): - self._model_name = model.__name__ - return self - - def where(self, *args, **kwargs): - return self - - def first(self): - assert self._model_name is not None - return self._mapping.get(self._model_name) - - -class _FakeDB: - """Minimal db stub exposing engine and session.""" - - def __init__(self, session: _FakeSession): - self.session = session - self.engine = object() - - -def test_get_form_includes_site(monkeypatch: pytest.MonkeyPatch, app: Flask): - """GET returns form definition merged with site payload.""" - - expiration_time = datetime(2099, 1, 1, tzinfo=UTC) - - class _FakeDefinition: - def model_dump(self): - return { - "form_content": "Raw content", - "rendered_content": "Rendered {{#$output.name#}}", - "inputs": [{"type": "text", "output_variable_name": "name", "default": None}], - "default_values": {"name": "Alice", "age": 30, "meta": {"k": "v"}}, - "user_actions": [{"id": "approve", "title": "Approve", "button_style": "default"}], - } - - class _FakeForm: - def __init__(self, expiration: datetime): - self.workflow_run_id = "workflow-1" - self.app_id = "app-1" - self.tenant_id = "tenant-1" - self.expiration_time = expiration - self.recipient_type = RecipientType.BACKSTAGE - - def get_definition(self): - return _FakeDefinition() - - form = _FakeForm(expiration_time) - limiter_mock = MagicMock() - limiter_mock.is_rate_limited.return_value = False - monkeypatch.setattr(human_input_module, "_FORM_ACCESS_RATE_LIMITER", limiter_mock) - monkeypatch.setattr(human_input_module, "extract_remote_ip", lambda req: "203.0.113.10") - - tenant = SimpleNamespace( - id="tenant-1", - status=TenantStatus.NORMAL, - plan="basic", - custom_config_dict={"remove_webapp_brand": True, "replace_webapp_logo": False}, - ) - app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant, enable_site=True) - workflow_run = SimpleNamespace(app_id="app-1") - site_model = SimpleNamespace( - title="My Site", - icon_type="emoji", - icon="robot", - icon_background="#fff", - description="desc", - default_language="en", - chat_color_theme="light", - chat_color_theme_inverted=False, - copyright=None, - privacy_policy=None, - custom_disclaimer=None, - prompt_public=False, - show_workflow_steps=True, - use_icon_as_answer_icon=False, - ) - - # Patch service to return fake form. - service_mock = MagicMock() - service_mock.get_form_by_token.return_value = form - monkeypatch.setattr(human_input_module, "HumanInputService", lambda engine: service_mock) - - # Patch db session. - db_stub = _FakeDB(_FakeSession({"WorkflowRun": workflow_run, "App": app_model, "Site": site_model})) - monkeypatch.setattr(human_input_module, "db", db_stub) - - monkeypatch.setattr( - site_module.FeatureService, - "get_features", - lambda tenant_id: SimpleNamespace(can_replace_logo=True), - ) - - with app.test_request_context("/api/form/human_input/token-1", method="GET"): - response = HumanInputFormApi().get("token-1") - - body = json.loads(response.get_data(as_text=True)) - assert set(body.keys()) == { - "site", - "form_content", - "inputs", - "resolved_default_values", - "user_actions", - "expiration_time", - } - assert body["form_content"] == "Rendered {{#$output.name#}}" - assert body["inputs"] == [{"type": "text", "output_variable_name": "name", "default": None}] - assert body["resolved_default_values"] == {"name": "Alice", "age": "30", "meta": '{"k": "v"}'} - assert body["user_actions"] == [{"id": "approve", "title": "Approve", "button_style": "default"}] - assert body["expiration_time"] == int(expiration_time.timestamp()) - assert body["site"] == { - "app_id": "app-1", - "end_user_id": None, - "enable_site": True, - "site": { - "title": "My Site", - "chat_color_theme": "light", - "chat_color_theme_inverted": False, - "icon_type": "emoji", - "icon": "robot", - "icon_background": "#fff", - "icon_url": None, - "description": "desc", - "copyright": None, - "privacy_policy": None, - "custom_disclaimer": None, - "default_language": "en", - "prompt_public": False, - "show_workflow_steps": True, - "use_icon_as_answer_icon": False, - }, - "model_config": None, - "plan": "basic", - "can_replace_logo": True, - "custom_config": { - "remove_webapp_brand": True, - "replace_webapp_logo": None, - }, - } - service_mock.get_form_by_token.assert_called_once_with("token-1") - limiter_mock.is_rate_limited.assert_called_once_with("203.0.113.10") - limiter_mock.increment_rate_limit.assert_called_once_with("203.0.113.10") - - -def test_get_form_allows_backstage_token(monkeypatch: pytest.MonkeyPatch, app: Flask): - """GET returns form payload for backstage token.""" - - expiration_time = datetime(2099, 1, 2, tzinfo=UTC) - - class _FakeDefinition: - def model_dump(self): - return { - "form_content": "Raw content", - "rendered_content": "Rendered", - "inputs": [], - "default_values": {}, - "user_actions": [], - } - - class _FakeForm: - def __init__(self, expiration: datetime): - self.workflow_run_id = "workflow-1" - self.app_id = "app-1" - self.tenant_id = "tenant-1" - self.expiration_time = expiration - - def get_definition(self): - return _FakeDefinition() - - form = _FakeForm(expiration_time) - limiter_mock = MagicMock() - limiter_mock.is_rate_limited.return_value = False - monkeypatch.setattr(human_input_module, "_FORM_ACCESS_RATE_LIMITER", limiter_mock) - monkeypatch.setattr(human_input_module, "extract_remote_ip", lambda req: "203.0.113.10") - tenant = SimpleNamespace( - id="tenant-1", - status=TenantStatus.NORMAL, - plan="basic", - custom_config_dict={"remove_webapp_brand": True, "replace_webapp_logo": False}, - ) - app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant, enable_site=True) - workflow_run = SimpleNamespace(app_id="app-1") - site_model = SimpleNamespace( - title="My Site", - icon_type="emoji", - icon="robot", - icon_background="#fff", - description="desc", - default_language="en", - chat_color_theme="light", - chat_color_theme_inverted=False, - copyright=None, - privacy_policy=None, - custom_disclaimer=None, - prompt_public=False, - show_workflow_steps=True, - use_icon_as_answer_icon=False, - ) - - service_mock = MagicMock() - service_mock.get_form_by_token.return_value = form - monkeypatch.setattr(human_input_module, "HumanInputService", lambda engine: service_mock) - - db_stub = _FakeDB(_FakeSession({"WorkflowRun": workflow_run, "App": app_model, "Site": site_model})) - monkeypatch.setattr(human_input_module, "db", db_stub) - - monkeypatch.setattr( - site_module.FeatureService, - "get_features", - lambda tenant_id: SimpleNamespace(can_replace_logo=True), - ) - - with app.test_request_context("/api/form/human_input/token-1", method="GET"): - response = HumanInputFormApi().get("token-1") - - body = json.loads(response.get_data(as_text=True)) - assert set(body.keys()) == { - "site", - "form_content", - "inputs", - "resolved_default_values", - "user_actions", - "expiration_time", - } - assert body["form_content"] == "Rendered" - assert body["inputs"] == [] - assert body["resolved_default_values"] == {} - assert body["user_actions"] == [] - assert body["expiration_time"] == int(expiration_time.timestamp()) - assert body["site"] == { - "app_id": "app-1", - "end_user_id": None, - "enable_site": True, - "site": { - "title": "My Site", - "chat_color_theme": "light", - "chat_color_theme_inverted": False, - "icon_type": "emoji", - "icon": "robot", - "icon_background": "#fff", - "icon_url": None, - "description": "desc", - "copyright": None, - "privacy_policy": None, - "custom_disclaimer": None, - "default_language": "en", - "prompt_public": False, - "show_workflow_steps": True, - "use_icon_as_answer_icon": False, - }, - "model_config": None, - "plan": "basic", - "can_replace_logo": True, - "custom_config": { - "remove_webapp_brand": True, - "replace_webapp_logo": None, - }, - } - service_mock.get_form_by_token.assert_called_once_with("token-1") - limiter_mock.is_rate_limited.assert_called_once_with("203.0.113.10") - limiter_mock.increment_rate_limit.assert_called_once_with("203.0.113.10") - - -def test_get_form_raises_forbidden_when_site_missing(monkeypatch: pytest.MonkeyPatch, app: Flask): - """GET raises Forbidden if site cannot be resolved.""" - - expiration_time = datetime(2099, 1, 3, tzinfo=UTC) - - class _FakeDefinition: - def model_dump(self): - return { - "form_content": "Raw content", - "rendered_content": "Rendered", - "inputs": [], - "default_values": {}, - "user_actions": [], - } - - class _FakeForm: - def __init__(self, expiration: datetime): - self.workflow_run_id = "workflow-1" - self.app_id = "app-1" - self.tenant_id = "tenant-1" - self.expiration_time = expiration - - def get_definition(self): - return _FakeDefinition() - - form = _FakeForm(expiration_time) - limiter_mock = MagicMock() - limiter_mock.is_rate_limited.return_value = False - monkeypatch.setattr(human_input_module, "_FORM_ACCESS_RATE_LIMITER", limiter_mock) - monkeypatch.setattr(human_input_module, "extract_remote_ip", lambda req: "203.0.113.10") - tenant = SimpleNamespace(status=TenantStatus.NORMAL) - app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant) - workflow_run = SimpleNamespace(app_id="app-1") - - service_mock = MagicMock() - service_mock.get_form_by_token.return_value = form - monkeypatch.setattr(human_input_module, "HumanInputService", lambda engine: service_mock) - - db_stub = _FakeDB(_FakeSession({"WorkflowRun": workflow_run, "App": app_model, "Site": None})) - monkeypatch.setattr(human_input_module, "db", db_stub) - - with app.test_request_context("/api/form/human_input/token-1", method="GET"): - with pytest.raises(Forbidden): - HumanInputFormApi().get("token-1") - limiter_mock.is_rate_limited.assert_called_once_with("203.0.113.10") - limiter_mock.increment_rate_limit.assert_called_once_with("203.0.113.10") - - -def test_submit_form_accepts_backstage_token(monkeypatch: pytest.MonkeyPatch, app: Flask): - """POST forwards backstage submissions to the service.""" - - class _FakeForm: - recipient_type = RecipientType.BACKSTAGE - - form = _FakeForm() - limiter_mock = MagicMock() - limiter_mock.is_rate_limited.return_value = False - monkeypatch.setattr(human_input_module, "_FORM_SUBMIT_RATE_LIMITER", limiter_mock) - monkeypatch.setattr(human_input_module, "extract_remote_ip", lambda req: "203.0.113.10") - service_mock = MagicMock() - service_mock.get_form_by_token.return_value = form - monkeypatch.setattr(human_input_module, "HumanInputService", lambda engine: service_mock) - monkeypatch.setattr(human_input_module, "db", _FakeDB(_FakeSession({}))) - - with app.test_request_context( - "/api/form/human_input/token-1", - method="POST", - json={"inputs": {"content": "ok"}, "action": "approve"}, - ): - response, status = HumanInputFormApi().post("token-1") - - assert status == 200 - assert response == {} - service_mock.submit_form_by_token.assert_called_once_with( - recipient_type=RecipientType.BACKSTAGE, - form_token="token-1", - selected_action_id="approve", - form_data={"content": "ok"}, - submission_end_user_id=None, - ) - limiter_mock.is_rate_limited.assert_called_once_with("203.0.113.10") - limiter_mock.increment_rate_limit.assert_called_once_with("203.0.113.10") - - -def test_submit_form_rate_limited(monkeypatch: pytest.MonkeyPatch, app: Flask): - """POST rejects submissions when rate limit is exceeded.""" - - limiter_mock = MagicMock() - limiter_mock.is_rate_limited.return_value = True - monkeypatch.setattr(human_input_module, "_FORM_SUBMIT_RATE_LIMITER", limiter_mock) - monkeypatch.setattr(human_input_module, "extract_remote_ip", lambda req: "203.0.113.10") - - service_mock = MagicMock() - service_mock.get_form_by_token.return_value = None - monkeypatch.setattr(human_input_module, "HumanInputService", lambda engine: service_mock) - monkeypatch.setattr(human_input_module, "db", _FakeDB(_FakeSession({}))) - - with app.test_request_context( - "/api/form/human_input/token-1", - method="POST", - json={"inputs": {"content": "ok"}, "action": "approve"}, - ): - with pytest.raises(WebFormRateLimitExceededError): - HumanInputFormApi().post("token-1") - - limiter_mock.is_rate_limited.assert_called_once_with("203.0.113.10") - limiter_mock.increment_rate_limit.assert_not_called() - service_mock.get_form_by_token.assert_not_called() - - -def test_get_form_rate_limited(monkeypatch: pytest.MonkeyPatch, app: Flask): - """GET rejects requests when rate limit is exceeded.""" - - limiter_mock = MagicMock() - limiter_mock.is_rate_limited.return_value = True - monkeypatch.setattr(human_input_module, "_FORM_ACCESS_RATE_LIMITER", limiter_mock) - monkeypatch.setattr(human_input_module, "extract_remote_ip", lambda req: "203.0.113.10") - - service_mock = MagicMock() - service_mock.get_form_by_token.return_value = None - monkeypatch.setattr(human_input_module, "HumanInputService", lambda engine: service_mock) - monkeypatch.setattr(human_input_module, "db", _FakeDB(_FakeSession({}))) - - with app.test_request_context("/api/form/human_input/token-1", method="GET"): - with pytest.raises(WebFormRateLimitExceededError): - HumanInputFormApi().get("token-1") - - limiter_mock.is_rate_limited.assert_called_once_with("203.0.113.10") - limiter_mock.increment_rate_limit.assert_not_called() - service_mock.get_form_by_token.assert_not_called() - - -def test_get_form_raises_expired(monkeypatch: pytest.MonkeyPatch, app: Flask): - class _FakeForm: - pass - - form = _FakeForm() - limiter_mock = MagicMock() - limiter_mock.is_rate_limited.return_value = False - monkeypatch.setattr(human_input_module, "_FORM_ACCESS_RATE_LIMITER", limiter_mock) - monkeypatch.setattr(human_input_module, "extract_remote_ip", lambda req: "203.0.113.10") - service_mock = MagicMock() - service_mock.get_form_by_token.return_value = form - service_mock.ensure_form_active.side_effect = FormExpiredError("form-id") - monkeypatch.setattr(human_input_module, "HumanInputService", lambda engine: service_mock) - monkeypatch.setattr(human_input_module, "db", _FakeDB(_FakeSession({}))) - - with app.test_request_context("/api/form/human_input/token-1", method="GET"): - with pytest.raises(FormExpiredError): - HumanInputFormApi().get("token-1") - - service_mock.ensure_form_active.assert_called_once_with(form) - limiter_mock.is_rate_limited.assert_called_once_with("203.0.113.10") - limiter_mock.increment_rate_limit.assert_called_once_with("203.0.113.10") diff --git a/api/tests/unit_tests/controllers/web/test_message_list.py b/api/tests/unit_tests/controllers/web/test_message_list.py index 1c096bfbcf..2835f7ffbf 100644 --- a/api/tests/unit_tests/controllers/web/test_message_list.py +++ b/api/tests/unit_tests/controllers/web/test_message_list.py @@ -3,7 +3,6 @@ from __future__ import annotations import builtins -import uuid from datetime import datetime from types import ModuleType, SimpleNamespace from unittest.mock import patch @@ -13,8 +12,6 @@ import pytest from flask import Flask from flask.views import MethodView -from core.entities.execution_extra_content import HumanInputContent - # Ensure flask_restx.api finds MethodView during import. if not hasattr(builtins, "MethodView"): builtins.MethodView = MethodView # type: ignore[attr-defined] @@ -140,12 +137,6 @@ def test_message_list_mapping(app: Flask) -> None: status="success", error=None, message_metadata_dict={"meta": "value"}, - extra_contents=[ - HumanInputContent( - workflow_run_id=str(uuid.uuid4()), - submitted=True, - ) - ], ) pagination = SimpleNamespace(limit=20, has_more=False, data=[message]) @@ -178,8 +169,6 @@ def test_message_list_mapping(app: Flask) -> None: assert item["agent_thoughts"][0]["chain_id"] == "chain-1" assert item["agent_thoughts"][0]["created_at"] == int(thought_created_at.timestamp()) - assert item["extra_contents"][0]["workflow_run_id"] == message.extra_contents[0].workflow_run_id - assert item["extra_contents"][0]["submitted"] == message.extra_contents[0].submitted assert item["message_files"][0]["id"] == "file-dict" assert item["message_files"][1]["id"] == "file-obj" diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_extra_contents.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_extra_contents.py deleted file mode 100644 index a94b5445f7..0000000000 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_extra_contents.py +++ /dev/null @@ -1,187 +0,0 @@ -from __future__ import annotations - -from contextlib import contextmanager -from datetime import datetime -from types import SimpleNamespace -from unittest import mock - -import pytest - -from core.app.apps.advanced_chat import generate_task_pipeline as pipeline_module -from core.app.entities.app_invoke_entities import InvokeFrom -from core.app.entities.queue_entities import QueueTextChunkEvent, QueueWorkflowPausedEvent -from core.workflow.entities.pause_reason import HumanInputRequired -from models.enums import MessageStatus -from models.execution_extra_content import HumanInputContent -from models.model import EndUser - - -def _build_pipeline() -> pipeline_module.AdvancedChatAppGenerateTaskPipeline: - pipeline = pipeline_module.AdvancedChatAppGenerateTaskPipeline.__new__( - pipeline_module.AdvancedChatAppGenerateTaskPipeline - ) - pipeline._workflow_run_id = "run-1" - pipeline._message_id = "message-1" - pipeline._workflow_tenant_id = "tenant-1" - return pipeline - - -def test_persist_human_input_extra_content_adds_record(monkeypatch: pytest.MonkeyPatch) -> None: - pipeline = _build_pipeline() - monkeypatch.setattr(pipeline, "_load_human_input_form_id", lambda **kwargs: "form-1") - - captured_session: dict[str, mock.Mock] = {} - - @contextmanager - def fake_session(): - session = mock.Mock() - session.scalar.return_value = None - captured_session["session"] = session - yield session - - pipeline._database_session = fake_session # type: ignore[method-assign] - - pipeline._persist_human_input_extra_content(node_id="node-1") - - session = captured_session["session"] - session.add.assert_called_once() - content = session.add.call_args.args[0] - assert isinstance(content, HumanInputContent) - assert content.workflow_run_id == "run-1" - assert content.message_id == "message-1" - assert content.form_id == "form-1" - - -def test_persist_human_input_extra_content_skips_when_form_missing(monkeypatch: pytest.MonkeyPatch) -> None: - pipeline = _build_pipeline() - monkeypatch.setattr(pipeline, "_load_human_input_form_id", lambda **kwargs: None) - - called = {"value": False} - - @contextmanager - def fake_session(): - called["value"] = True - session = mock.Mock() - yield session - - pipeline._database_session = fake_session # type: ignore[method-assign] - - pipeline._persist_human_input_extra_content(node_id="node-1") - - assert called["value"] is False - - -def test_persist_human_input_extra_content_skips_when_existing(monkeypatch: pytest.MonkeyPatch) -> None: - pipeline = _build_pipeline() - monkeypatch.setattr(pipeline, "_load_human_input_form_id", lambda **kwargs: "form-1") - - captured_session: dict[str, mock.Mock] = {} - - @contextmanager - def fake_session(): - session = mock.Mock() - session.scalar.return_value = HumanInputContent( - workflow_run_id="run-1", - message_id="message-1", - form_id="form-1", - ) - captured_session["session"] = session - yield session - - pipeline._database_session = fake_session # type: ignore[method-assign] - - pipeline._persist_human_input_extra_content(node_id="node-1") - - session = captured_session["session"] - session.add.assert_not_called() - - -def test_handle_workflow_paused_event_persists_human_input_extra_content() -> None: - pipeline = _build_pipeline() - pipeline._application_generate_entity = SimpleNamespace(task_id="task-1") - pipeline._workflow_response_converter = mock.Mock() - pipeline._workflow_response_converter.workflow_pause_to_stream_response.return_value = [] - pipeline._ensure_graph_runtime_initialized = mock.Mock( - return_value=SimpleNamespace( - total_tokens=0, - node_run_steps=0, - ), - ) - pipeline._save_message = mock.Mock() - message = SimpleNamespace(status=MessageStatus.NORMAL) - pipeline._get_message = mock.Mock(return_value=message) - pipeline._persist_human_input_extra_content = mock.Mock() - pipeline._base_task_pipeline = mock.Mock() - pipeline._base_task_pipeline.queue_manager = mock.Mock() - pipeline._message_saved_on_pause = False - - @contextmanager - def fake_session(): - session = mock.Mock() - yield session - - pipeline._database_session = fake_session # type: ignore[method-assign] - - reason = HumanInputRequired( - form_id="form-1", - form_content="content", - inputs=[], - actions=[], - node_id="node-1", - node_title="Approval", - form_token="token-1", - resolved_default_values={}, - ) - event = QueueWorkflowPausedEvent(reasons=[reason], outputs={}, paused_nodes=["node-1"]) - - list(pipeline._handle_workflow_paused_event(event)) - - pipeline._persist_human_input_extra_content.assert_called_once_with(form_id="form-1", node_id="node-1") - assert message.status == MessageStatus.PAUSED - - -def test_resume_appends_chunks_to_paused_answer() -> None: - app_config = SimpleNamespace(app_id="app-1", tenant_id="tenant-1", sensitive_word_avoidance=None) - application_generate_entity = SimpleNamespace( - app_config=app_config, - files=[], - workflow_run_id="run-1", - query="hello", - invoke_from=InvokeFrom.WEB_APP, - inputs={}, - task_id="task-1", - ) - queue_manager = SimpleNamespace(graph_runtime_state=None) - conversation = SimpleNamespace(id="conversation-1", mode="advanced-chat") - message = SimpleNamespace( - id="message-1", - created_at=datetime(2024, 1, 1), - query="hello", - answer="before", - status=MessageStatus.PAUSED, - ) - user = EndUser() - user.id = "user-1" - user.session_id = "session-1" - workflow = SimpleNamespace(id="workflow-1", tenant_id="tenant-1", features_dict={}) - - pipeline = pipeline_module.AdvancedChatAppGenerateTaskPipeline( - application_generate_entity=application_generate_entity, - workflow=workflow, - queue_manager=queue_manager, - conversation=conversation, - message=message, - user=user, - stream=True, - dialogue_count=1, - draft_var_saver_factory=SimpleNamespace(), - ) - - pipeline._get_message = mock.Mock(return_value=message) - pipeline._recorded_files = [] - - list(pipeline._handle_text_chunk_event(QueueTextChunkEvent(text="after"))) - pipeline._save_message(session=mock.Mock()) - - assert message.answer == "beforeafter" - assert message.status == MessageStatus.NORMAL diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py deleted file mode 100644 index 1c36b4d12b..0000000000 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py +++ /dev/null @@ -1,87 +0,0 @@ -from datetime import UTC, datetime -from types import SimpleNamespace - -from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter -from core.app.entities.app_invoke_entities import InvokeFrom -from core.app.entities.queue_entities import QueueHumanInputFormFilledEvent, QueueHumanInputFormTimeoutEvent -from core.workflow.entities.workflow_start_reason import WorkflowStartReason -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable - - -def _build_converter(): - system_variables = SystemVariable( - files=[], - user_id="user-1", - app_id="app-1", - workflow_id="wf-1", - workflow_execution_id="run-1", - ) - runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0) - app_entity = SimpleNamespace( - task_id="task-1", - app_config=SimpleNamespace(app_id="app-1", tenant_id="tenant-1"), - invoke_from=InvokeFrom.EXPLORE, - files=[], - inputs={}, - workflow_execution_id="run-1", - call_depth=0, - ) - account = SimpleNamespace(id="acc-1", name="tester", email="tester@example.com") - return WorkflowResponseConverter( - application_generate_entity=app_entity, - user=account, - system_variables=system_variables, - ) - - -def test_human_input_form_filled_stream_response_contains_rendered_content(): - converter = _build_converter() - converter.workflow_start_to_stream_response( - task_id="task-1", - workflow_run_id="run-1", - workflow_id="wf-1", - reason=WorkflowStartReason.INITIAL, - ) - - queue_event = QueueHumanInputFormFilledEvent( - node_execution_id="exec-1", - node_id="node-1", - node_type="human-input", - node_title="Human Input", - rendered_content="# Title\nvalue", - action_id="Approve", - action_text="Approve", - ) - - resp = converter.human_input_form_filled_to_stream_response(event=queue_event, task_id="task-1") - - assert resp.workflow_run_id == "run-1" - assert resp.data.node_id == "node-1" - assert resp.data.node_title == "Human Input" - assert resp.data.rendered_content.startswith("# Title") - assert resp.data.action_id == "Approve" - - -def test_human_input_form_timeout_stream_response_contains_timeout_metadata(): - converter = _build_converter() - converter.workflow_start_to_stream_response( - task_id="task-1", - workflow_run_id="run-1", - workflow_id="wf-1", - reason=WorkflowStartReason.INITIAL, - ) - - queue_event = QueueHumanInputFormTimeoutEvent( - node_id="node-1", - node_type="human-input", - node_title="Human Input", - expiration_time=datetime(2025, 1, 1, tzinfo=UTC), - ) - - resp = converter.human_input_form_timeout_to_stream_response(event=queue_event, task_id="task-1") - - assert resp.workflow_run_id == "run-1" - assert resp.data.node_id == "node-1" - assert resp.data.node_title == "Human Input" - assert resp.data.expiration_time == 1735689600 diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py deleted file mode 100644 index 0a9794e41c..0000000000 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py +++ /dev/null @@ -1,56 +0,0 @@ -from types import SimpleNamespace - -from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter -from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities.workflow_start_reason import WorkflowStartReason -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable - - -def _build_converter() -> WorkflowResponseConverter: - """Construct a minimal WorkflowResponseConverter for testing.""" - system_variables = SystemVariable( - files=[], - user_id="user-1", - app_id="app-1", - workflow_id="wf-1", - workflow_execution_id="run-1", - ) - runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0) - app_entity = SimpleNamespace( - task_id="task-1", - app_config=SimpleNamespace(app_id="app-1", tenant_id="tenant-1"), - invoke_from=InvokeFrom.EXPLORE, - files=[], - inputs={}, - workflow_execution_id="run-1", - call_depth=0, - ) - account = SimpleNamespace(id="acc-1", name="tester", email="tester@example.com") - return WorkflowResponseConverter( - application_generate_entity=app_entity, - user=account, - system_variables=system_variables, - ) - - -def test_workflow_start_stream_response_carries_resumption_reason(): - converter = _build_converter() - resp = converter.workflow_start_to_stream_response( - task_id="task-1", - workflow_run_id="run-1", - workflow_id="wf-1", - reason=WorkflowStartReason.RESUMPTION, - ) - assert resp.data.reason is WorkflowStartReason.RESUMPTION - - -def test_workflow_start_stream_response_carries_initial_reason(): - converter = _build_converter() - resp = converter.workflow_start_to_stream_response( - task_id="task-1", - workflow_run_id="run-1", - workflow_id="wf-1", - reason=WorkflowStartReason.INITIAL, - ) - assert resp.data.reason is WorkflowStartReason.INITIAL diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py index d25bff92dc..6b40bf462b 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py @@ -23,7 +23,6 @@ from core.app.entities.queue_entities import ( QueueNodeStartedEvent, QueueNodeSucceededEvent, ) -from core.workflow.entities.workflow_start_reason import WorkflowStartReason from core.workflow.enums import NodeType from core.workflow.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now @@ -125,12 +124,7 @@ class TestWorkflowResponseConverter: original_data = {"large_field": "x" * 10000, "metadata": "info"} truncated_data = {"large_field": "[TRUNCATED]", "metadata": "info"} - converter.workflow_start_to_stream_response( - task_id="bootstrap", - workflow_run_id="run-id", - workflow_id="wf-id", - reason=WorkflowStartReason.INITIAL, - ) + converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id") start_event = self.create_node_started_event() converter.workflow_node_start_to_stream_response( event=start_event, @@ -166,12 +160,7 @@ class TestWorkflowResponseConverter: original_data = {"small": "data"} - converter.workflow_start_to_stream_response( - task_id="bootstrap", - workflow_run_id="run-id", - workflow_id="wf-id", - reason=WorkflowStartReason.INITIAL, - ) + converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id") start_event = self.create_node_started_event() converter.workflow_node_start_to_stream_response( event=start_event, @@ -202,12 +191,7 @@ class TestWorkflowResponseConverter: """Test node finish response when process_data is None.""" converter = self.create_workflow_response_converter() - converter.workflow_start_to_stream_response( - task_id="bootstrap", - workflow_run_id="run-id", - workflow_id="wf-id", - reason=WorkflowStartReason.INITIAL, - ) + converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id") start_event = self.create_node_started_event() converter.workflow_node_start_to_stream_response( event=start_event, @@ -241,12 +225,7 @@ class TestWorkflowResponseConverter: original_data = {"large_field": "x" * 10000, "metadata": "info"} truncated_data = {"large_field": "[TRUNCATED]", "metadata": "info"} - converter.workflow_start_to_stream_response( - task_id="bootstrap", - workflow_run_id="run-id", - workflow_id="wf-id", - reason=WorkflowStartReason.INITIAL, - ) + converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id") start_event = self.create_node_started_event() converter.workflow_node_start_to_stream_response( event=start_event, @@ -282,12 +261,7 @@ class TestWorkflowResponseConverter: original_data = {"small": "data"} - converter.workflow_start_to_stream_response( - task_id="bootstrap", - workflow_run_id="run-id", - workflow_id="wf-id", - reason=WorkflowStartReason.INITIAL, - ) + converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id") start_event = self.create_node_started_event() converter.workflow_node_start_to_stream_response( event=start_event, @@ -426,7 +400,6 @@ class TestWorkflowResponseConverterServiceApiTruncation: task_id="test-task-id", workflow_run_id="test-workflow-run-id", workflow_id="test-workflow-id", - reason=WorkflowStartReason.INITIAL, ) return converter diff --git a/api/tests/unit_tests/core/app/apps/test_advanced_chat_app_generator.py b/api/tests/unit_tests/core/app/apps/test_advanced_chat_app_generator.py deleted file mode 100644 index f0d9afc0db..0000000000 --- a/api/tests/unit_tests/core/app/apps/test_advanced_chat_app_generator.py +++ /dev/null @@ -1,139 +0,0 @@ -from __future__ import annotations - -from types import SimpleNamespace -from unittest.mock import MagicMock - -import pytest - -from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig -from core.app.apps import message_based_app_generator -from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator -from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom -from core.app.task_pipeline import message_cycle_manager -from core.app.task_pipeline.message_cycle_manager import MessageCycleManager -from models.model import AppMode, Conversation, Message - - -def _make_app_config() -> WorkflowUIBasedAppConfig: - return WorkflowUIBasedAppConfig( - tenant_id="tenant-id", - app_id="app-id", - app_mode=AppMode.ADVANCED_CHAT, - workflow_id="workflow-id", - additional_features=AppAdditionalFeatures(), - variables=[], - ) - - -def _make_generate_entity(app_config: WorkflowUIBasedAppConfig) -> AdvancedChatAppGenerateEntity: - return AdvancedChatAppGenerateEntity( - task_id="task-id", - app_config=app_config, - file_upload_config=None, - conversation_id=None, - inputs={}, - query="hello", - files=[], - parent_message_id=None, - user_id="user-id", - stream=True, - invoke_from=InvokeFrom.WEB_APP, - extras={}, - workflow_run_id="workflow-run-id", - ) - - -@pytest.fixture(autouse=True) -def _mock_db_session(monkeypatch): - session = MagicMock() - - def refresh_side_effect(obj): - if isinstance(obj, Conversation) and obj.id is None: - obj.id = "generated-conversation-id" - if isinstance(obj, Message) and obj.id is None: - obj.id = "generated-message-id" - - session.refresh.side_effect = refresh_side_effect - session.add.return_value = None - session.commit.return_value = None - - monkeypatch.setattr(message_based_app_generator, "db", SimpleNamespace(session=session)) - return session - - -def test_init_generate_records_sets_conversation_metadata(): - app_config = _make_app_config() - entity = _make_generate_entity(app_config) - - generator = AdvancedChatAppGenerator() - - conversation, _ = generator._init_generate_records(entity, conversation=None) - - assert entity.conversation_id == "generated-conversation-id" - assert conversation.id == "generated-conversation-id" - assert entity.is_new_conversation is True - - -def test_init_generate_records_marks_existing_conversation(): - app_config = _make_app_config() - entity = _make_generate_entity(app_config) - - existing_conversation = Conversation( - app_id=app_config.app_id, - app_model_config_id=None, - model_provider=None, - override_model_configs=None, - model_id=None, - mode=app_config.app_mode.value, - name="existing", - inputs={}, - introduction="", - system_instruction="", - system_instruction_tokens=0, - status="normal", - invoke_from=InvokeFrom.WEB_APP.value, - from_source="api", - from_end_user_id="user-id", - from_account_id=None, - ) - existing_conversation.id = "existing-conversation-id" - - generator = AdvancedChatAppGenerator() - - conversation, _ = generator._init_generate_records(entity, conversation=existing_conversation) - - assert entity.conversation_id == "existing-conversation-id" - assert conversation is existing_conversation - assert entity.is_new_conversation is False - - -def test_message_cycle_manager_uses_new_conversation_flag(monkeypatch): - app_config = _make_app_config() - entity = _make_generate_entity(app_config) - entity.conversation_id = "existing-conversation-id" - entity.is_new_conversation = True - entity.extras = {"auto_generate_conversation_name": True} - - captured = {} - - class DummyThread: - def __init__(self, **kwargs): - self.kwargs = kwargs - self.started = False - - def start(self): - self.started = True - - def fake_thread(**kwargs): - thread = DummyThread(**kwargs) - captured["thread"] = thread - return thread - - monkeypatch.setattr(message_cycle_manager, "Thread", fake_thread) - - manager = MessageCycleManager(application_generate_entity=entity, task_state=MagicMock()) - thread = manager.generate_conversation_name(conversation_id="existing-conversation-id", query="hello") - - assert thread is captured["thread"] - assert thread.started is True - assert entity.is_new_conversation is False diff --git a/api/tests/unit_tests/core/app/apps/test_message_based_app_generator.py b/api/tests/unit_tests/core/app/apps/test_message_based_app_generator.py deleted file mode 100644 index 87b8dc51e7..0000000000 --- a/api/tests/unit_tests/core/app/apps/test_message_based_app_generator.py +++ /dev/null @@ -1,127 +0,0 @@ -from __future__ import annotations - -from types import SimpleNamespace -from unittest.mock import MagicMock - -import pytest - -from core.app.app_config.entities import ( - AppAdditionalFeatures, - EasyUIBasedAppConfig, - EasyUIBasedAppModelConfigFrom, - ModelConfigEntity, - PromptTemplateEntity, -) -from core.app.apps import message_based_app_generator -from core.app.apps.message_based_app_generator import MessageBasedAppGenerator -from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom -from models.model import AppMode, Conversation, Message - - -class DummyModelConf: - def __init__(self, provider: str = "mock-provider", model: str = "mock-model") -> None: - self.provider = provider - self.model = model - - -class DummyCompletionGenerateEntity: - __slots__ = ("app_config", "invoke_from", "user_id", "query", "inputs", "files", "model_conf") - app_config: EasyUIBasedAppConfig - invoke_from: InvokeFrom - user_id: str - query: str - inputs: dict - files: list - model_conf: DummyModelConf - - def __init__(self, app_config: EasyUIBasedAppConfig) -> None: - self.app_config = app_config - self.invoke_from = InvokeFrom.WEB_APP - self.user_id = "user-id" - self.query = "hello" - self.inputs = {} - self.files = [] - self.model_conf = DummyModelConf() - - -def _make_app_config(app_mode: AppMode) -> EasyUIBasedAppConfig: - return EasyUIBasedAppConfig( - tenant_id="tenant-id", - app_id="app-id", - app_mode=app_mode, - app_model_config_from=EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG, - app_model_config_id="model-config-id", - app_model_config_dict={}, - model=ModelConfigEntity(provider="mock-provider", model="mock-model", mode="chat"), - prompt_template=PromptTemplateEntity( - prompt_type=PromptTemplateEntity.PromptType.SIMPLE, - simple_prompt_template="Hello", - ), - additional_features=AppAdditionalFeatures(), - variables=[], - ) - - -def _make_chat_generate_entity(app_config: EasyUIBasedAppConfig) -> ChatAppGenerateEntity: - return ChatAppGenerateEntity.model_construct( - task_id="task-id", - app_config=app_config, - model_conf=DummyModelConf(), - file_upload_config=None, - conversation_id=None, - inputs={}, - query="hello", - files=[], - parent_message_id=None, - user_id="user-id", - stream=False, - invoke_from=InvokeFrom.WEB_APP, - extras={}, - call_depth=0, - trace_manager=None, - ) - - -@pytest.fixture(autouse=True) -def _mock_db_session(monkeypatch): - session = MagicMock() - - def refresh_side_effect(obj): - if isinstance(obj, Conversation) and obj.id is None: - obj.id = "generated-conversation-id" - if isinstance(obj, Message) and obj.id is None: - obj.id = "generated-message-id" - - session.refresh.side_effect = refresh_side_effect - session.add.return_value = None - session.commit.return_value = None - - monkeypatch.setattr(message_based_app_generator, "db", SimpleNamespace(session=session)) - return session - - -def test_init_generate_records_skips_conversation_fields_for_non_conversation_entity(): - app_config = _make_app_config(AppMode.COMPLETION) - entity = DummyCompletionGenerateEntity(app_config=app_config) - - generator = MessageBasedAppGenerator() - - conversation, message = generator._init_generate_records(entity, conversation=None) - - assert conversation.id == "generated-conversation-id" - assert message.id == "generated-message-id" - assert hasattr(entity, "conversation_id") is False - assert hasattr(entity, "is_new_conversation") is False - - -def test_init_generate_records_sets_conversation_fields_for_chat_entity(): - app_config = _make_app_config(AppMode.CHAT) - entity = _make_chat_generate_entity(app_config) - - generator = MessageBasedAppGenerator() - - conversation, _ = generator._init_generate_records(entity, conversation=None) - - assert entity.conversation_id == "generated-conversation-id" - assert entity.is_new_conversation is True - assert conversation.id == "generated-conversation-id" diff --git a/api/tests/unit_tests/core/app/apps/test_pause_resume.py b/api/tests/unit_tests/core/app/apps/test_pause_resume.py deleted file mode 100644 index 97c993928e..0000000000 --- a/api/tests/unit_tests/core/app/apps/test_pause_resume.py +++ /dev/null @@ -1,287 +0,0 @@ -import sys -import time -from pathlib import Path -from types import ModuleType, SimpleNamespace -from typing import Any - -API_DIR = str(Path(__file__).resolve().parents[5]) -if API_DIR not in sys.path: - sys.path.insert(0, API_DIR) - -import core.workflow.nodes.human_input.entities # noqa: F401 -from core.app.apps.advanced_chat import app_generator as adv_app_gen_module -from core.app.apps.workflow import app_generator as wf_app_gen_module -from core.app.entities.app_invoke_entities import InvokeFrom -from core.app.workflow.node_factory import DifyNodeFactory -from core.workflow.entities import GraphInitParams -from core.workflow.entities.pause_reason import SchedulingPause -from core.workflow.entities.workflow_start_reason import WorkflowStartReason -from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus -from core.workflow.graph import Graph -from core.workflow.graph_engine import GraphEngine -from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from core.workflow.graph_events import ( - GraphEngineEvent, - GraphRunPausedEvent, - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunSucceededEvent, -) -from core.workflow.node_events import NodeRunResult, PauseRequestedEvent -from core.workflow.nodes.base.entities import BaseNodeData, OutputVariableEntity, RetryConfig -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.end.entities import EndNodeData -from core.workflow.nodes.start.entities import StartNodeData -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable - -if "core.ops.ops_trace_manager" not in sys.modules: - ops_stub = ModuleType("core.ops.ops_trace_manager") - - class _StubTraceQueueManager: - def __init__(self, *_, **__): - pass - - ops_stub.TraceQueueManager = _StubTraceQueueManager - sys.modules["core.ops.ops_trace_manager"] = ops_stub - - -class _StubToolNodeData(BaseNodeData): - pause_on: bool = False - - -class _StubToolNode(Node[_StubToolNodeData]): - node_type = NodeType.TOOL - - @classmethod - def version(cls) -> str: - return "1" - - def init_node_data(self, data): - self._node_data = _StubToolNodeData.model_validate(data) - - def _get_error_strategy(self): - return self._node_data.error_strategy - - def _get_retry_config(self) -> RetryConfig: - return self._node_data.retry_config - - def _get_title(self) -> str: - return self._node_data.title - - def _get_description(self): - return self._node_data.desc - - def _get_default_value_dict(self) -> dict[str, Any]: - return self._node_data.default_value_dict - - def get_base_node_data(self) -> BaseNodeData: - return self._node_data - - def _run(self): - if self.node_data.pause_on: - yield PauseRequestedEvent(reason=SchedulingPause(message="test pause")) - return - - result = NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={"value": f"{self.id}-done"}, - ) - yield self._convert_node_run_result_to_graph_node_event(result) - - -def _patch_tool_node(mocker): - original_create_node = DifyNodeFactory.create_node - - def _patched_create_node(self, node_config: dict[str, object]) -> Node: - node_data = node_config.get("data", {}) - if isinstance(node_data, dict) and node_data.get("type") == NodeType.TOOL.value: - return _StubToolNode( - id=str(node_config["id"]), - config=node_config, - graph_init_params=self.graph_init_params, - graph_runtime_state=self.graph_runtime_state, - ) - return original_create_node(self, node_config) - - mocker.patch.object(DifyNodeFactory, "create_node", _patched_create_node) - - -def _node_data(node_type: NodeType, data: BaseNodeData) -> dict[str, object]: - node_data = data.model_dump() - node_data["type"] = node_type.value - return node_data - - -def _build_graph_config(*, pause_on: str | None) -> dict[str, object]: - start_data = StartNodeData(title="start", variables=[]) - tool_data_a = _StubToolNodeData(title="tool", pause_on=pause_on == "tool_a") - tool_data_b = _StubToolNodeData(title="tool", pause_on=pause_on == "tool_b") - tool_data_c = _StubToolNodeData(title="tool", pause_on=pause_on == "tool_c") - end_data = EndNodeData( - title="end", - outputs=[OutputVariableEntity(variable="result", value_selector=["tool_c", "value"])], - desc=None, - ) - - nodes = [ - {"id": "start", "data": _node_data(NodeType.START, start_data)}, - {"id": "tool_a", "data": _node_data(NodeType.TOOL, tool_data_a)}, - {"id": "tool_b", "data": _node_data(NodeType.TOOL, tool_data_b)}, - {"id": "tool_c", "data": _node_data(NodeType.TOOL, tool_data_c)}, - {"id": "end", "data": _node_data(NodeType.END, end_data)}, - ] - edges = [ - {"source": "start", "target": "tool_a"}, - {"source": "tool_a", "target": "tool_b"}, - {"source": "tool_b", "target": "tool_c"}, - {"source": "tool_c", "target": "end"}, - ] - return {"nodes": nodes, "edges": edges} - - -def _build_graph(runtime_state: GraphRuntimeState, *, pause_on: str | None) -> Graph: - graph_config = _build_graph_config(pause_on=pause_on) - params = GraphInitParams( - tenant_id="tenant", - app_id="app", - workflow_id="workflow", - graph_config=graph_config, - user_id="user", - user_from="account", - invoke_from="service-api", - call_depth=0, - ) - - node_factory = DifyNodeFactory( - graph_init_params=params, - graph_runtime_state=runtime_state, - ) - - return Graph.init(graph_config=graph_config, node_factory=node_factory) - - -def _build_runtime_state(run_id: str) -> GraphRuntimeState: - variable_pool = VariablePool( - system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"), - user_inputs={}, - conversation_variables=[], - ) - variable_pool.system_variables.workflow_execution_id = run_id - return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - -def _run_with_optional_pause(runtime_state: GraphRuntimeState, *, pause_on: str | None) -> list[GraphEngineEvent]: - command_channel = InMemoryChannel() - graph = _build_graph(runtime_state, pause_on=pause_on) - engine = GraphEngine( - workflow_id="workflow", - graph=graph, - graph_runtime_state=runtime_state, - command_channel=command_channel, - ) - - events: list[GraphEngineEvent] = [] - for event in engine.run(): - events.append(event) - return events - - -def _node_successes(events: list[GraphEngineEvent]) -> list[str]: - return [evt.node_id for evt in events if isinstance(evt, NodeRunSucceededEvent)] - - -def test_workflow_app_pause_resume_matches_baseline(mocker): - _patch_tool_node(mocker) - - baseline_state = _build_runtime_state("baseline") - baseline_events = _run_with_optional_pause(baseline_state, pause_on=None) - assert isinstance(baseline_events[-1], GraphRunSucceededEvent) - baseline_nodes = _node_successes(baseline_events) - baseline_outputs = baseline_state.outputs - - paused_state = _build_runtime_state("paused-run") - paused_events = _run_with_optional_pause(paused_state, pause_on="tool_a") - assert isinstance(paused_events[-1], GraphRunPausedEvent) - paused_nodes = _node_successes(paused_events) - snapshot = paused_state.dumps() - - resumed_state = GraphRuntimeState.from_snapshot(snapshot) - - generator = wf_app_gen_module.WorkflowAppGenerator() - - def _fake_generate(**kwargs): - state: GraphRuntimeState = kwargs["graph_runtime_state"] - events = _run_with_optional_pause(state, pause_on=None) - return _node_successes(events) - - mocker.patch.object(generator, "_generate", side_effect=_fake_generate) - - resumed_nodes = generator.resume( - app_model=SimpleNamespace(mode="workflow"), - workflow=SimpleNamespace(), - user=SimpleNamespace(), - application_generate_entity=SimpleNamespace(stream=False, invoke_from=InvokeFrom.SERVICE_API), - graph_runtime_state=resumed_state, - workflow_execution_repository=SimpleNamespace(), - workflow_node_execution_repository=SimpleNamespace(), - ) - - assert paused_nodes + resumed_nodes == baseline_nodes - assert resumed_state.outputs == baseline_outputs - - -def test_advanced_chat_pause_resume_matches_baseline(mocker): - _patch_tool_node(mocker) - - baseline_state = _build_runtime_state("adv-baseline") - baseline_events = _run_with_optional_pause(baseline_state, pause_on=None) - assert isinstance(baseline_events[-1], GraphRunSucceededEvent) - baseline_nodes = _node_successes(baseline_events) - baseline_outputs = baseline_state.outputs - - paused_state = _build_runtime_state("adv-paused") - paused_events = _run_with_optional_pause(paused_state, pause_on="tool_a") - assert isinstance(paused_events[-1], GraphRunPausedEvent) - paused_nodes = _node_successes(paused_events) - snapshot = paused_state.dumps() - - resumed_state = GraphRuntimeState.from_snapshot(snapshot) - - generator = adv_app_gen_module.AdvancedChatAppGenerator() - - def _fake_generate(**kwargs): - state: GraphRuntimeState = kwargs["graph_runtime_state"] - events = _run_with_optional_pause(state, pause_on=None) - return _node_successes(events) - - mocker.patch.object(generator, "_generate", side_effect=_fake_generate) - - resumed_nodes = generator.resume( - app_model=SimpleNamespace(mode="workflow"), - workflow=SimpleNamespace(), - user=SimpleNamespace(), - conversation=SimpleNamespace(id="conv"), - message=SimpleNamespace(id="msg"), - application_generate_entity=SimpleNamespace(stream=False, invoke_from=InvokeFrom.SERVICE_API), - workflow_execution_repository=SimpleNamespace(), - workflow_node_execution_repository=SimpleNamespace(), - graph_runtime_state=resumed_state, - ) - - assert paused_nodes + resumed_nodes == baseline_nodes - assert resumed_state.outputs == baseline_outputs - - -def test_resume_emits_resumption_start_reason(mocker) -> None: - _patch_tool_node(mocker) - - paused_state = _build_runtime_state("resume-reason") - paused_events = _run_with_optional_pause(paused_state, pause_on="tool_a") - initial_start = next(event for event in paused_events if isinstance(event, GraphRunStartedEvent)) - assert initial_start.reason == WorkflowStartReason.INITIAL - - resumed_state = GraphRuntimeState.from_snapshot(paused_state.dumps()) - resumed_events = _run_with_optional_pause(resumed_state, pause_on=None) - resume_start = next(event for event in resumed_events if isinstance(event, GraphRunStartedEvent)) - assert resume_start.reason == WorkflowStartReason.RESUMPTION diff --git a/api/tests/unit_tests/core/app/apps/test_streaming_utils.py b/api/tests/unit_tests/core/app/apps/test_streaming_utils.py deleted file mode 100644 index 7b5447c01e..0000000000 --- a/api/tests/unit_tests/core/app/apps/test_streaming_utils.py +++ /dev/null @@ -1,80 +0,0 @@ -from __future__ import annotations - -import json -import queue - -import pytest - -from core.app.apps.message_based_app_generator import MessageBasedAppGenerator -from core.app.entities.task_entities import StreamEvent -from models.model import AppMode - - -class FakeSubscription: - def __init__(self, message_queue: queue.Queue[bytes], state: dict[str, bool]) -> None: - self._queue = message_queue - self._state = state - self._closed = False - - def __enter__(self): - self._state["subscribed"] = True - return self - - def __exit__(self, exc_type, exc_value, traceback): - self.close() - - def close(self) -> None: - self._closed = True - - def receive(self, timeout: float | None = 0.1) -> bytes | None: - if self._closed: - return None - try: - if timeout is None: - return self._queue.get() - return self._queue.get(timeout=timeout) - except queue.Empty: - return None - - -class FakeTopic: - def __init__(self) -> None: - self._queue: queue.Queue[bytes] = queue.Queue() - self._state = {"subscribed": False} - - def subscribe(self) -> FakeSubscription: - return FakeSubscription(self._queue, self._state) - - def publish(self, payload: bytes) -> None: - self._queue.put(payload) - - @property - def subscribed(self) -> bool: - return self._state["subscribed"] - - -def test_retrieve_events_calls_on_subscribe_after_subscription(monkeypatch): - topic = FakeTopic() - - def fake_get_response_topic(cls, app_mode, workflow_run_id): - return topic - - monkeypatch.setattr(MessageBasedAppGenerator, "get_response_topic", classmethod(fake_get_response_topic)) - - def on_subscribe() -> None: - assert topic.subscribed is True - event = {"event": StreamEvent.WORKFLOW_FINISHED.value} - topic.publish(json.dumps(event).encode()) - - generator = MessageBasedAppGenerator.retrieve_events( - AppMode.WORKFLOW, - "workflow-run-id", - idle_timeout=0.5, - on_subscribe=on_subscribe, - ) - - assert next(generator) == StreamEvent.PING.value - event = next(generator) - assert event["event"] == StreamEvent.WORKFLOW_FINISHED.value - with pytest.raises(StopIteration): - next(generator) diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_generator.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_generator.py index 7e8367c6c4..83ac3a5591 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_generator.py @@ -1,6 +1,3 @@ -from types import SimpleNamespace -from unittest.mock import MagicMock - from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY, WorkflowAppGenerator @@ -20,193 +17,3 @@ def test_should_prepare_user_inputs_keeps_validation_when_flag_false(): args = {"inputs": {}, SKIP_PREPARE_USER_INPUTS_KEY: False} assert WorkflowAppGenerator()._should_prepare_user_inputs(args) - - -def test_resume_delegates_to_generate(mocker): - generator = WorkflowAppGenerator() - mock_generate = mocker.patch.object(generator, "_generate", return_value="ok") - - application_generate_entity = SimpleNamespace(stream=False, invoke_from="debugger") - runtime_state = MagicMock(name="runtime-state") - pause_config = MagicMock(name="pause-config") - - result = generator.resume( - app_model=MagicMock(), - workflow=MagicMock(), - user=MagicMock(), - application_generate_entity=application_generate_entity, - graph_runtime_state=runtime_state, - workflow_execution_repository=MagicMock(), - workflow_node_execution_repository=MagicMock(), - graph_engine_layers=("layer",), - pause_state_config=pause_config, - variable_loader=MagicMock(), - ) - - assert result == "ok" - mock_generate.assert_called_once() - kwargs = mock_generate.call_args.kwargs - assert kwargs["graph_runtime_state"] is runtime_state - assert kwargs["pause_state_config"] is pause_config - assert kwargs["streaming"] is False - assert kwargs["invoke_from"] == "debugger" - - -def test_generate_appends_pause_layer_and_forwards_state(mocker): - generator = WorkflowAppGenerator() - - mock_queue_manager = MagicMock() - mocker.patch("core.app.apps.workflow.app_generator.WorkflowAppQueueManager", return_value=mock_queue_manager) - - fake_current_app = MagicMock() - fake_current_app._get_current_object.return_value = MagicMock() - mocker.patch("core.app.apps.workflow.app_generator.current_app", fake_current_app) - - mocker.patch( - "core.app.apps.workflow.app_generator.WorkflowAppGenerateResponseConverter.convert", - return_value="converted", - ) - mocker.patch.object(WorkflowAppGenerator, "_handle_response", return_value="response") - mocker.patch.object(WorkflowAppGenerator, "_get_draft_var_saver_factory", return_value=MagicMock()) - - pause_layer = MagicMock(name="pause-layer") - mocker.patch( - "core.app.apps.workflow.app_generator.PauseStatePersistenceLayer", - return_value=pause_layer, - ) - - dummy_session = MagicMock() - dummy_session.close = MagicMock() - mocker.patch("core.app.apps.workflow.app_generator.db.session", dummy_session) - - worker_kwargs: dict[str, object] = {} - - class DummyThread: - def __init__(self, target, kwargs): - worker_kwargs["target"] = target - worker_kwargs["kwargs"] = kwargs - - def start(self): - return None - - mocker.patch("core.app.apps.workflow.app_generator.threading.Thread", DummyThread) - - app_model = SimpleNamespace(mode="workflow") - app_config = SimpleNamespace(app_id="app", tenant_id="tenant", workflow_id="wf") - application_generate_entity = SimpleNamespace( - task_id="task", - user_id="user", - invoke_from="service-api", - app_config=app_config, - files=[], - stream=True, - workflow_execution_id="run", - ) - - graph_runtime_state = MagicMock() - - result = generator._generate( - app_model=app_model, - workflow=MagicMock(), - user=MagicMock(), - application_generate_entity=application_generate_entity, - invoke_from="service-api", - workflow_execution_repository=MagicMock(), - workflow_node_execution_repository=MagicMock(), - streaming=True, - graph_engine_layers=("base-layer",), - graph_runtime_state=graph_runtime_state, - pause_state_config=SimpleNamespace(session_factory=MagicMock(), state_owner_user_id="owner"), - ) - - assert result == "converted" - assert worker_kwargs["kwargs"]["graph_engine_layers"] == ("base-layer", pause_layer) - assert worker_kwargs["kwargs"]["graph_runtime_state"] is graph_runtime_state - - -def test_resume_path_runs_worker_with_runtime_state(mocker): - generator = WorkflowAppGenerator() - runtime_state = MagicMock(name="runtime-state") - - pause_layer = MagicMock(name="pause-layer") - mocker.patch("core.app.apps.workflow.app_generator.PauseStatePersistenceLayer", return_value=pause_layer) - - queue_manager = MagicMock() - mocker.patch("core.app.apps.workflow.app_generator.WorkflowAppQueueManager", return_value=queue_manager) - - mocker.patch.object(generator, "_handle_response", return_value="raw-response") - mocker.patch( - "core.app.apps.workflow.app_generator.WorkflowAppGenerateResponseConverter.convert", - side_effect=lambda response, invoke_from: response, - ) - - fake_db = SimpleNamespace(session=MagicMock(), engine=MagicMock()) - mocker.patch("core.app.apps.workflow.app_generator.db", fake_db) - - workflow = SimpleNamespace( - id="workflow", tenant_id="tenant", app_id="app", graph_dict={}, type="workflow", version="1" - ) - end_user = SimpleNamespace(session_id="end-user-session") - app_record = SimpleNamespace(id="app") - - session = MagicMock() - session.__enter__.return_value = session - session.__exit__.return_value = False - session.scalar.side_effect = [workflow, end_user, app_record] - mocker.patch("core.app.apps.workflow.app_generator.session_factory", return_value=session) - - runner_instance = MagicMock() - - def runner_ctor(**kwargs): - assert kwargs["graph_runtime_state"] is runtime_state - return runner_instance - - mocker.patch("core.app.apps.workflow.app_generator.WorkflowAppRunner", side_effect=runner_ctor) - - class ImmediateThread: - def __init__(self, target, kwargs): - target(**kwargs) - - def start(self): - return None - - mocker.patch("core.app.apps.workflow.app_generator.threading.Thread", ImmediateThread) - - mocker.patch( - "core.app.apps.workflow.app_generator.DifyCoreRepositoryFactory.create_workflow_execution_repository", - return_value=MagicMock(), - ) - mocker.patch( - "core.app.apps.workflow.app_generator.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", - return_value=MagicMock(), - ) - - pause_config = SimpleNamespace(session_factory=MagicMock(), state_owner_user_id="owner") - - app_model = SimpleNamespace(mode="workflow") - app_config = SimpleNamespace(app_id="app", tenant_id="tenant", workflow_id="workflow") - application_generate_entity = SimpleNamespace( - task_id="task", - user_id="user", - invoke_from="service-api", - app_config=app_config, - files=[], - stream=True, - workflow_execution_id="run", - trace_manager=MagicMock(), - ) - - result = generator.resume( - app_model=app_model, - workflow=workflow, - user=MagicMock(), - application_generate_entity=application_generate_entity, - graph_runtime_state=runtime_state, - workflow_execution_repository=MagicMock(), - workflow_node_execution_repository=MagicMock(), - pause_state_config=pause_config, - ) - - assert result == "raw-response" - runner_instance.run.assert_called_once() - queue_manager.graph_runtime_state = runtime_state diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py deleted file mode 100644 index f4efb240c0..0000000000 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py +++ /dev/null @@ -1,59 +0,0 @@ -from unittest.mock import MagicMock - -import pytest - -from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner -from core.app.entities.queue_entities import QueueWorkflowPausedEvent -from core.workflow.entities.pause_reason import HumanInputRequired -from core.workflow.graph_events.graph import GraphRunPausedEvent - - -class _DummyQueueManager: - def __init__(self): - self.published = [] - - def publish(self, event, _from): - self.published.append(event) - - -class _DummyRuntimeState: - def get_paused_nodes(self): - return ["node-1"] - - -class _DummyGraphEngine: - def __init__(self): - self.graph_runtime_state = _DummyRuntimeState() - - -class _DummyWorkflowEntry: - def __init__(self): - self.graph_engine = _DummyGraphEngine() - - -def test_handle_pause_event_enqueues_email_task(monkeypatch: pytest.MonkeyPatch): - queue_manager = _DummyQueueManager() - runner = WorkflowBasedAppRunner(queue_manager=queue_manager, app_id="app-id") - workflow_entry = _DummyWorkflowEntry() - - reason = HumanInputRequired( - form_id="form-123", - form_content="content", - inputs=[], - actions=[], - node_id="node-1", - node_title="Review", - ) - event = GraphRunPausedEvent(reasons=[reason], outputs={}) - - email_task = MagicMock() - monkeypatch.setattr("core.app.apps.workflow_app_runner.dispatch_human_input_email_task", email_task) - - runner._handle_event(workflow_entry, event) - - email_task.apply_async.assert_called_once() - kwargs = email_task.apply_async.call_args.kwargs["kwargs"] - assert kwargs["form_id"] == "form-123" - assert kwargs["node_title"] == "Review" - - assert any(isinstance(evt, QueueWorkflowPausedEvent) for evt in queue_manager.published) diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py b/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py deleted file mode 100644 index c30b925d88..0000000000 --- a/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py +++ /dev/null @@ -1,183 +0,0 @@ -from datetime import UTC, datetime -from types import SimpleNamespace -from unittest.mock import MagicMock - -import pytest - -from core.app.apps.common import workflow_response_converter -from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter -from core.app.apps.workflow.app_runner import WorkflowAppRunner -from core.app.entities.app_invoke_entities import InvokeFrom -from core.app.entities.queue_entities import QueueWorkflowPausedEvent -from core.app.entities.task_entities import HumanInputRequiredResponse, WorkflowPauseStreamResponse -from core.workflow.entities.pause_reason import HumanInputRequired -from core.workflow.entities.workflow_start_reason import WorkflowStartReason -from core.workflow.graph_events.graph import GraphRunPausedEvent -from core.workflow.nodes.human_input.entities import FormInput, UserAction -from core.workflow.nodes.human_input.enums import FormInputType -from core.workflow.system_variable import SystemVariable -from models.account import Account - - -class _RecordingWorkflowAppRunner(WorkflowAppRunner): - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.published_events = [] - - def _publish_event(self, event): - self.published_events.append(event) - - -class _FakeRuntimeState: - def get_paused_nodes(self): - return ["node-pause-1"] - - -def _build_runner(): - app_entity = SimpleNamespace( - app_config=SimpleNamespace(app_id="app-id"), - inputs={}, - files=[], - invoke_from=InvokeFrom.SERVICE_API, - single_iteration_run=None, - single_loop_run=None, - workflow_execution_id="run-id", - user_id="user-id", - ) - workflow = SimpleNamespace( - graph_dict={}, - tenant_id="tenant-id", - environment_variables={}, - id="workflow-id", - ) - queue_manager = SimpleNamespace(publish=lambda event, pub_from: None) - return _RecordingWorkflowAppRunner( - application_generate_entity=app_entity, - queue_manager=queue_manager, - variable_loader=MagicMock(), - workflow=workflow, - system_user_id="sys-user", - root_node_id=None, - workflow_execution_repository=MagicMock(), - workflow_node_execution_repository=MagicMock(), - graph_engine_layers=(), - graph_runtime_state=None, - ) - - -def test_graph_run_paused_event_emits_queue_pause_event(): - runner = _build_runner() - reason = HumanInputRequired( - form_id="form-1", - form_content="content", - inputs=[], - actions=[], - node_id="node-human", - node_title="Human Step", - form_token="tok", - ) - event = GraphRunPausedEvent(reasons=[reason], outputs={"foo": "bar"}) - workflow_entry = SimpleNamespace( - graph_engine=SimpleNamespace(graph_runtime_state=_FakeRuntimeState()), - ) - - runner._handle_event(workflow_entry, event) - - assert len(runner.published_events) == 1 - queue_event = runner.published_events[0] - assert isinstance(queue_event, QueueWorkflowPausedEvent) - assert queue_event.reasons == [reason] - assert queue_event.outputs == {"foo": "bar"} - assert queue_event.paused_nodes == ["node-pause-1"] - - -def _build_converter(): - application_generate_entity = SimpleNamespace( - inputs={}, - files=[], - invoke_from=InvokeFrom.SERVICE_API, - app_config=SimpleNamespace(app_id="app-id", tenant_id="tenant-id"), - ) - system_variables = SystemVariable( - user_id="user", - app_id="app-id", - workflow_id="workflow-id", - workflow_execution_id="run-id", - ) - user = MagicMock(spec=Account) - user.id = "account-id" - user.name = "Tester" - user.email = "tester@example.com" - return WorkflowResponseConverter( - application_generate_entity=application_generate_entity, - user=user, - system_variables=system_variables, - ) - - -def test_queue_workflow_paused_event_to_stream_responses(monkeypatch: pytest.MonkeyPatch): - converter = _build_converter() - converter.workflow_start_to_stream_response( - task_id="task", - workflow_run_id="run-id", - workflow_id="workflow-id", - reason=WorkflowStartReason.INITIAL, - ) - - expiration_time = datetime(2024, 1, 1, tzinfo=UTC) - - class _FakeSession: - def execute(self, _stmt): - return [("form-1", expiration_time)] - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - monkeypatch.setattr(workflow_response_converter, "Session", lambda **_: _FakeSession()) - monkeypatch.setattr(workflow_response_converter, "db", SimpleNamespace(engine=object())) - - reason = HumanInputRequired( - form_id="form-1", - form_content="Rendered", - inputs=[ - FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="field", default=None), - ], - actions=[UserAction(id="approve", title="Approve")], - display_in_ui=True, - node_id="node-id", - node_title="Human Step", - form_token="token", - ) - queue_event = QueueWorkflowPausedEvent( - reasons=[reason], - outputs={"answer": "value"}, - paused_nodes=["node-id"], - ) - - runtime_state = SimpleNamespace(total_tokens=0, node_run_steps=0) - responses = converter.workflow_pause_to_stream_response( - event=queue_event, - task_id="task", - graph_runtime_state=runtime_state, - ) - - assert isinstance(responses[-1], WorkflowPauseStreamResponse) - pause_resp = responses[-1] - assert pause_resp.workflow_run_id == "run-id" - assert pause_resp.data.paused_nodes == ["node-id"] - assert pause_resp.data.outputs == {} - assert pause_resp.data.reasons[0]["form_id"] == "form-1" - assert pause_resp.data.reasons[0]["display_in_ui"] is True - - assert isinstance(responses[0], HumanInputRequiredResponse) - hi_resp = responses[0] - assert hi_resp.data.form_id == "form-1" - assert hi_resp.data.node_id == "node-id" - assert hi_resp.data.node_title == "Human Step" - assert hi_resp.data.inputs[0].output_variable_name == "field" - assert hi_resp.data.actions[0].id == "approve" - assert hi_resp.data.display_in_ui is True - assert hi_resp.data.expiration_time == int(expiration_time.timestamp()) diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py deleted file mode 100644 index 32cb1ed47c..0000000000 --- a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py +++ /dev/null @@ -1,96 +0,0 @@ -import time -from contextlib import contextmanager -from unittest.mock import MagicMock - -from core.app.app_config.entities import WorkflowUIBasedAppConfig -from core.app.apps.base_app_queue_manager import AppQueueManager -from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline -from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity -from core.app.entities.queue_entities import QueueWorkflowStartedEvent -from core.workflow.entities.workflow_start_reason import WorkflowStartReason -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from models.account import Account -from models.model import AppMode - - -def _build_workflow_app_config() -> WorkflowUIBasedAppConfig: - return WorkflowUIBasedAppConfig( - tenant_id="tenant-id", - app_id="app-id", - app_mode=AppMode.WORKFLOW, - workflow_id="workflow-id", - ) - - -def _build_generate_entity(run_id: str) -> WorkflowAppGenerateEntity: - return WorkflowAppGenerateEntity( - task_id="task-id", - app_config=_build_workflow_app_config(), - inputs={}, - files=[], - user_id="user-id", - stream=False, - invoke_from=InvokeFrom.SERVICE_API, - workflow_execution_id=run_id, - ) - - -def _build_runtime_state(run_id: str) -> GraphRuntimeState: - variable_pool = VariablePool( - system_variables=SystemVariable(workflow_execution_id=run_id), - user_inputs={}, - conversation_variables=[], - ) - return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - -@contextmanager -def _noop_session(): - yield MagicMock() - - -def _build_pipeline(run_id: str) -> WorkflowAppGenerateTaskPipeline: - queue_manager = MagicMock(spec=AppQueueManager) - queue_manager.invoke_from = InvokeFrom.SERVICE_API - queue_manager.graph_runtime_state = _build_runtime_state(run_id) - workflow = MagicMock() - workflow.id = "workflow-id" - workflow.features_dict = {} - user = Account(name="user", email="user@example.com") - pipeline = WorkflowAppGenerateTaskPipeline( - application_generate_entity=_build_generate_entity(run_id), - workflow=workflow, - queue_manager=queue_manager, - user=user, - stream=False, - draft_var_saver_factory=MagicMock(), - ) - pipeline._database_session = _noop_session - return pipeline - - -def test_workflow_app_log_saved_only_on_initial_start() -> None: - run_id = "run-initial" - pipeline = _build_pipeline(run_id) - pipeline._save_workflow_app_log = MagicMock() - - event = QueueWorkflowStartedEvent(reason=WorkflowStartReason.INITIAL) - list(pipeline._handle_workflow_started_event(event)) - - pipeline._save_workflow_app_log.assert_called_once() - _, kwargs = pipeline._save_workflow_app_log.call_args - assert kwargs["workflow_run_id"] == run_id - assert pipeline._workflow_execution_id == run_id - - -def test_workflow_app_log_skipped_on_resumption_start() -> None: - run_id = "run-resume" - pipeline = _build_pipeline(run_id) - pipeline._save_workflow_app_log = MagicMock() - - event = QueueWorkflowStartedEvent(reason=WorkflowStartReason.RESUMPTION) - list(pipeline._handle_workflow_started_event(event)) - - pipeline._save_workflow_app_log.assert_not_called() - assert pipeline._workflow_execution_id == run_id diff --git a/api/tests/unit_tests/core/app/entities/test_app_invoke_entities.py b/api/tests/unit_tests/core/app/entities/test_app_invoke_entities.py deleted file mode 100644 index 86c80985c4..0000000000 --- a/api/tests/unit_tests/core/app/entities/test_app_invoke_entities.py +++ /dev/null @@ -1,143 +0,0 @@ -import json -from collections.abc import Callable -from dataclasses import dataclass - -import pytest - -from core.app.app_config.entities import WorkflowUIBasedAppConfig -from core.app.entities.app_invoke_entities import ( - AdvancedChatAppGenerateEntity, - InvokeFrom, - WorkflowAppGenerateEntity, -) -from core.app.layers.pause_state_persist_layer import ( - WorkflowResumptionContext, - _AdvancedChatAppGenerateEntityWrapper, - _WorkflowGenerateEntityWrapper, -) -from core.ops.ops_trace_manager import TraceQueueManager -from models.model import AppMode - - -class TraceQueueManagerStub(TraceQueueManager): - """Minimal TraceQueueManager stub that avoids Flask dependencies.""" - - def __init__(self): - # Skip parent initialization to avoid starting timers or accessing Flask globals. - pass - - -def _build_workflow_app_config(app_mode: AppMode) -> WorkflowUIBasedAppConfig: - return WorkflowUIBasedAppConfig( - tenant_id="tenant-id", - app_id="app-id", - app_mode=app_mode, - workflow_id=f"{app_mode.value}-workflow-id", - ) - - -def _create_workflow_generate_entity(trace_manager: TraceQueueManager | None = None) -> WorkflowAppGenerateEntity: - return WorkflowAppGenerateEntity( - task_id="workflow-task", - app_config=_build_workflow_app_config(AppMode.WORKFLOW), - inputs={"topic": "serialization"}, - files=[], - user_id="user-workflow", - stream=True, - invoke_from=InvokeFrom.DEBUGGER, - call_depth=1, - trace_manager=trace_manager, - workflow_execution_id="workflow-exec-id", - extras={"external_trace_id": "trace-id"}, - ) - - -def _create_advanced_chat_generate_entity( - trace_manager: TraceQueueManager | None = None, -) -> AdvancedChatAppGenerateEntity: - return AdvancedChatAppGenerateEntity( - task_id="advanced-task", - app_config=_build_workflow_app_config(AppMode.ADVANCED_CHAT), - conversation_id="conversation-id", - inputs={"topic": "roundtrip"}, - files=[], - user_id="user-advanced", - stream=False, - invoke_from=InvokeFrom.DEBUGGER, - query="Explain serialization", - extras={"auto_generate_conversation_name": True}, - trace_manager=trace_manager, - workflow_run_id="workflow-run-id", - ) - - -def test_workflow_app_generate_entity_roundtrip_excludes_trace_manager(): - entity = _create_workflow_generate_entity(trace_manager=TraceQueueManagerStub()) - - serialized = entity.model_dump_json() - payload = json.loads(serialized) - - assert "trace_manager" not in payload - - restored = WorkflowAppGenerateEntity.model_validate_json(serialized) - - assert restored.model_dump() == entity.model_dump() - assert restored.trace_manager is None - - -def test_advanced_chat_generate_entity_roundtrip_excludes_trace_manager(): - entity = _create_advanced_chat_generate_entity(trace_manager=TraceQueueManagerStub()) - - serialized = entity.model_dump_json() - payload = json.loads(serialized) - - assert "trace_manager" not in payload - - restored = AdvancedChatAppGenerateEntity.model_validate_json(serialized) - - assert restored.model_dump() == entity.model_dump() - assert restored.trace_manager is None - - -@dataclass(frozen=True) -class ResumptionContextCase: - name: str - context_factory: Callable[[], tuple[WorkflowResumptionContext, type]] - - -def _workflow_resumption_case() -> tuple[WorkflowResumptionContext, type]: - entity = _create_workflow_generate_entity(trace_manager=TraceQueueManagerStub()) - context = WorkflowResumptionContext( - serialized_graph_runtime_state=json.dumps({"state": "workflow"}), - generate_entity=_WorkflowGenerateEntityWrapper(entity=entity), - ) - return context, WorkflowAppGenerateEntity - - -def _advanced_chat_resumption_case() -> tuple[WorkflowResumptionContext, type]: - entity = _create_advanced_chat_generate_entity(trace_manager=TraceQueueManagerStub()) - context = WorkflowResumptionContext( - serialized_graph_runtime_state=json.dumps({"state": "advanced"}), - generate_entity=_AdvancedChatAppGenerateEntityWrapper(entity=entity), - ) - return context, AdvancedChatAppGenerateEntity - - -@pytest.mark.parametrize( - "case", - [ - pytest.param(ResumptionContextCase("workflow", _workflow_resumption_case), id="workflow"), - pytest.param(ResumptionContextCase("advanced_chat", _advanced_chat_resumption_case), id="advanced_chat"), - ], -) -def test_workflow_resumption_context_roundtrip(case: ResumptionContextCase): - context, expected_type = case.context_factory() - - serialized = context.dumps() - restored = WorkflowResumptionContext.loads(serialized) - - assert restored.serialized_graph_runtime_state == context.serialized_graph_runtime_state - entity = restored.get_generate_entity() - assert isinstance(entity, expected_type) - assert entity.model_dump() == context.get_generate_entity().model_dump() - assert entity.trace_manager is None diff --git a/api/tests/unit_tests/core/model_runtime/__base/test_large_language_model_non_stream_parsing.py b/api/tests/unit_tests/core/model_runtime/__base/test_large_language_model_non_stream_parsing.py index 91352b2a5f..cfdeef6a8d 100644 --- a/api/tests/unit_tests/core/model_runtime/__base/test_large_language_model_non_stream_parsing.py +++ b/api/tests/unit_tests/core/model_runtime/__base/test_large_language_model_non_stream_parsing.py @@ -101,3 +101,26 @@ def test__normalize_non_stream_plugin_result__empty_iterator_defaults(): assert result.message.tool_calls == [] assert result.usage == LLMUsage.empty_usage() assert result.system_fingerprint is None + + +def test__normalize_non_stream_plugin_result__closes_chunk_iterator(): + prompt_messages = [UserPromptMessage(content="hi")] + + chunk = _make_chunk(content="hello", usage=LLMUsage.empty_usage()) + closed: list[bool] = [] + + def _chunk_iter(): + try: + yield chunk + yield _make_chunk(content="ignored", usage=LLMUsage.empty_usage()) + finally: + closed.append(True) + + result = _normalize_non_stream_plugin_result( + model="test-model", + prompt_messages=prompt_messages, + result=_chunk_iter(), + ) + + assert result.message.content == "hello" + assert closed == [True] diff --git a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py deleted file mode 100644 index 811ed2143b..0000000000 --- a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py +++ /dev/null @@ -1,574 +0,0 @@ -"""Unit tests for HumanInputFormRepositoryImpl private helpers.""" - -from __future__ import annotations - -import dataclasses -from datetime import datetime -from types import SimpleNamespace -from unittest.mock import MagicMock - -import pytest - -from core.repositories.human_input_repository import ( - HumanInputFormRecord, - HumanInputFormRepositoryImpl, - HumanInputFormSubmissionRepository, - _WorkspaceMemberInfo, -) -from core.workflow.nodes.human_input.entities import ( - EmailDeliveryConfig, - EmailDeliveryMethod, - EmailRecipients, - ExternalRecipient, - FormDefinition, - MemberRecipient, - UserAction, -) -from core.workflow.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus -from libs.datetime_utils import naive_utc_now -from models.human_input import ( - EmailExternalRecipientPayload, - EmailMemberRecipientPayload, - HumanInputFormRecipient, - RecipientType, -) - - -def _build_repository() -> HumanInputFormRepositoryImpl: - return HumanInputFormRepositoryImpl(session_factory=MagicMock(), tenant_id="tenant-id") - - -def _patch_recipient_factory(monkeypatch: pytest.MonkeyPatch) -> list[SimpleNamespace]: - created: list[SimpleNamespace] = [] - - def fake_new(cls, form_id: str, delivery_id: str, payload): # type: ignore[no-untyped-def] - recipient = SimpleNamespace( - form_id=form_id, - delivery_id=delivery_id, - recipient_type=payload.TYPE, - recipient_payload=payload.model_dump_json(), - ) - created.append(recipient) - return recipient - - monkeypatch.setattr(HumanInputFormRecipient, "new", classmethod(fake_new)) - return created - - -@pytest.fixture(autouse=True) -def _stub_selectinload(monkeypatch: pytest.MonkeyPatch) -> None: - """Avoid SQLAlchemy mapper configuration in tests using fake sessions.""" - - class _FakeSelect: - def options(self, *_args, **_kwargs): # type: ignore[no-untyped-def] - return self - - def where(self, *_args, **_kwargs): # type: ignore[no-untyped-def] - return self - - monkeypatch.setattr( - "core.repositories.human_input_repository.selectinload", lambda *args, **kwargs: "_loader_option" - ) - monkeypatch.setattr("core.repositories.human_input_repository.select", lambda *args, **kwargs: _FakeSelect()) - - -class TestHumanInputFormRepositoryImplHelpers: - def test_build_email_recipients_with_member_and_external(self, monkeypatch: pytest.MonkeyPatch) -> None: - repo = _build_repository() - session_stub = object() - _patch_recipient_factory(monkeypatch) - - def fake_query(self, session, restrict_to_user_ids): # type: ignore[no-untyped-def] - assert session is session_stub - assert restrict_to_user_ids == ["member-1"] - return [_WorkspaceMemberInfo(user_id="member-1", email="member@example.com")] - - monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_workspace_members_by_ids", fake_query) - - recipients = repo._build_email_recipients( - session=session_stub, - form_id="form-id", - delivery_id="delivery-id", - recipients_config=EmailRecipients( - whole_workspace=False, - items=[ - MemberRecipient(user_id="member-1"), - ExternalRecipient(email="external@example.com"), - ], - ), - ) - - assert len(recipients) == 2 - member_recipient = next(r for r in recipients if r.recipient_type == RecipientType.EMAIL_MEMBER) - external_recipient = next(r for r in recipients if r.recipient_type == RecipientType.EMAIL_EXTERNAL) - - member_payload = EmailMemberRecipientPayload.model_validate_json(member_recipient.recipient_payload) - assert member_payload.user_id == "member-1" - assert member_payload.email == "member@example.com" - - external_payload = EmailExternalRecipientPayload.model_validate_json(external_recipient.recipient_payload) - assert external_payload.email == "external@example.com" - - def test_build_email_recipients_skips_unknown_members(self, monkeypatch: pytest.MonkeyPatch) -> None: - repo = _build_repository() - session_stub = object() - created = _patch_recipient_factory(monkeypatch) - - def fake_query(self, session, restrict_to_user_ids): # type: ignore[no-untyped-def] - assert session is session_stub - assert restrict_to_user_ids == ["missing-member"] - return [] - - monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_workspace_members_by_ids", fake_query) - - recipients = repo._build_email_recipients( - session=session_stub, - form_id="form-id", - delivery_id="delivery-id", - recipients_config=EmailRecipients( - whole_workspace=False, - items=[ - MemberRecipient(user_id="missing-member"), - ExternalRecipient(email="external@example.com"), - ], - ), - ) - - assert len(recipients) == 1 - assert recipients[0].recipient_type == RecipientType.EMAIL_EXTERNAL - assert len(created) == 1 # only external recipient created via factory - - def test_build_email_recipients_whole_workspace_uses_all_members(self, monkeypatch: pytest.MonkeyPatch) -> None: - repo = _build_repository() - session_stub = object() - _patch_recipient_factory(monkeypatch) - - def fake_query(self, session): # type: ignore[no-untyped-def] - assert session is session_stub - return [ - _WorkspaceMemberInfo(user_id="member-1", email="member1@example.com"), - _WorkspaceMemberInfo(user_id="member-2", email="member2@example.com"), - ] - - monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_all_workspace_members", fake_query) - - recipients = repo._build_email_recipients( - session=session_stub, - form_id="form-id", - delivery_id="delivery-id", - recipients_config=EmailRecipients( - whole_workspace=True, - items=[], - ), - ) - - assert len(recipients) == 2 - emails = {EmailMemberRecipientPayload.model_validate_json(r.recipient_payload).email for r in recipients} - assert emails == {"member1@example.com", "member2@example.com"} - - def test_build_email_recipients_dedupes_external_by_email(self, monkeypatch: pytest.MonkeyPatch) -> None: - repo = _build_repository() - session_stub = object() - created = _patch_recipient_factory(monkeypatch) - - def fake_query(self, session, restrict_to_user_ids): # type: ignore[no-untyped-def] - assert session is session_stub - assert restrict_to_user_ids == [] - return [] - - monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_workspace_members_by_ids", fake_query) - - recipients = repo._build_email_recipients( - session=session_stub, - form_id="form-id", - delivery_id="delivery-id", - recipients_config=EmailRecipients( - whole_workspace=False, - items=[ - ExternalRecipient(email="external@example.com"), - ExternalRecipient(email="external@example.com"), - ], - ), - ) - - assert len(recipients) == 1 - assert len(created) == 1 - - def test_build_email_recipients_prefers_member_over_external_by_email( - self, monkeypatch: pytest.MonkeyPatch - ) -> None: - repo = _build_repository() - session_stub = object() - _patch_recipient_factory(monkeypatch) - - def fake_query(self, session, restrict_to_user_ids): # type: ignore[no-untyped-def] - assert session is session_stub - assert restrict_to_user_ids == ["member-1"] - return [_WorkspaceMemberInfo(user_id="member-1", email="shared@example.com")] - - monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_workspace_members_by_ids", fake_query) - - recipients = repo._build_email_recipients( - session=session_stub, - form_id="form-id", - delivery_id="delivery-id", - recipients_config=EmailRecipients( - whole_workspace=False, - items=[ - MemberRecipient(user_id="member-1"), - ExternalRecipient(email="shared@example.com"), - ], - ), - ) - - assert len(recipients) == 1 - assert recipients[0].recipient_type == RecipientType.EMAIL_MEMBER - - def test_delivery_method_to_model_includes_external_recipients_with_whole_workspace( - self, - monkeypatch: pytest.MonkeyPatch, - ) -> None: - repo = _build_repository() - session_stub = object() - _patch_recipient_factory(monkeypatch) - - def fake_query(self, session): # type: ignore[no-untyped-def] - assert session is session_stub - return [ - _WorkspaceMemberInfo(user_id="member-1", email="member1@example.com"), - _WorkspaceMemberInfo(user_id="member-2", email="member2@example.com"), - ] - - monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_all_workspace_members", fake_query) - - method = EmailDeliveryMethod( - config=EmailDeliveryConfig( - recipients=EmailRecipients( - whole_workspace=True, - items=[ExternalRecipient(email="external@example.com")], - ), - subject="subject", - body="body", - ) - ) - - result = repo._delivery_method_to_model(session=session_stub, form_id="form-id", delivery_method=method) - - assert len(result.recipients) == 3 - member_emails = { - EmailMemberRecipientPayload.model_validate_json(r.recipient_payload).email - for r in result.recipients - if r.recipient_type == RecipientType.EMAIL_MEMBER - } - assert member_emails == {"member1@example.com", "member2@example.com"} - external_payload = EmailExternalRecipientPayload.model_validate_json( - next(r for r in result.recipients if r.recipient_type == RecipientType.EMAIL_EXTERNAL).recipient_payload - ) - assert external_payload.email == "external@example.com" - - -def _make_form_definition() -> str: - return FormDefinition( - form_content="hello", - inputs=[], - user_actions=[UserAction(id="submit", title="Submit")], - rendered_content="

hello

", - expiration_time=datetime.utcnow(), - ).model_dump_json() - - -@dataclasses.dataclass -class _DummyForm: - id: str - workflow_run_id: str - node_id: str - tenant_id: str - app_id: str - form_definition: str - rendered_content: str - expiration_time: datetime - form_kind: HumanInputFormKind = HumanInputFormKind.RUNTIME - created_at: datetime = dataclasses.field(default_factory=naive_utc_now) - selected_action_id: str | None = None - submitted_data: str | None = None - submitted_at: datetime | None = None - submission_user_id: str | None = None - submission_end_user_id: str | None = None - completed_by_recipient_id: str | None = None - status: HumanInputFormStatus = HumanInputFormStatus.WAITING - - -@dataclasses.dataclass -class _DummyRecipient: - id: str - form_id: str - recipient_type: RecipientType - access_token: str - form: _DummyForm | None = None - - -class _FakeScalarResult: - def __init__(self, obj): - self._obj = obj - - def first(self): - if isinstance(self._obj, list): - return self._obj[0] if self._obj else None - return self._obj - - def all(self): - if isinstance(self._obj, list): - return list(self._obj) - if self._obj is None: - return [] - return [self._obj] - - -class _FakeSession: - def __init__( - self, - *, - scalars_result=None, - scalars_results: list[object] | None = None, - forms: dict[str, _DummyForm] | None = None, - recipients: dict[str, _DummyRecipient] | None = None, - ): - if scalars_results is not None: - self._scalars_queue = list(scalars_results) - elif scalars_result is not None: - self._scalars_queue = [scalars_result] - else: - self._scalars_queue = [] - self.forms = forms or {} - self.recipients = recipients or {} - - def scalars(self, _query): - if self._scalars_queue: - result = self._scalars_queue.pop(0) - else: - result = None - return _FakeScalarResult(result) - - def get(self, model_cls, obj_id): # type: ignore[no-untyped-def] - if getattr(model_cls, "__name__", None) == "HumanInputForm": - return self.forms.get(obj_id) - if getattr(model_cls, "__name__", None) == "HumanInputFormRecipient": - return self.recipients.get(obj_id) - return None - - def add(self, _obj): - return None - - def flush(self): - return None - - def refresh(self, _obj): - return None - - def begin(self): - return self - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return None - - -def _session_factory(session: _FakeSession): - class _SessionContext: - def __enter__(self): - return session - - def __exit__(self, exc_type, exc, tb): - return None - - def _factory(*_args, **_kwargs): - return _SessionContext() - - return _factory - - -class TestHumanInputFormRepositoryImplPublicMethods: - def test_get_form_returns_entity_and_recipients(self): - form = _DummyForm( - id="form-1", - workflow_run_id="run-1", - node_id="node-1", - tenant_id="tenant-id", - app_id="app-id", - form_definition=_make_form_definition(), - rendered_content="

hello

", - expiration_time=naive_utc_now(), - ) - recipient = _DummyRecipient( - id="recipient-1", - form_id=form.id, - recipient_type=RecipientType.STANDALONE_WEB_APP, - access_token="token-123", - ) - session = _FakeSession(scalars_results=[form, [recipient]]) - repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id") - - entity = repo.get_form(form.workflow_run_id, form.node_id) - - assert entity is not None - assert entity.id == form.id - assert entity.web_app_token == "token-123" - assert len(entity.recipients) == 1 - assert entity.recipients[0].token == "token-123" - - def test_get_form_returns_none_when_missing(self): - session = _FakeSession(scalars_results=[None]) - repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id") - - assert repo.get_form("run-1", "node-1") is None - - def test_get_form_returns_unsubmitted_state(self): - form = _DummyForm( - id="form-1", - workflow_run_id="run-1", - node_id="node-1", - tenant_id="tenant-id", - app_id="app-id", - form_definition=_make_form_definition(), - rendered_content="

hello

", - expiration_time=naive_utc_now(), - ) - session = _FakeSession(scalars_results=[form, []]) - repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id") - - entity = repo.get_form(form.workflow_run_id, form.node_id) - - assert entity is not None - assert entity.submitted is False - assert entity.selected_action_id is None - assert entity.submitted_data is None - - def test_get_form_returns_submission_when_completed(self): - form = _DummyForm( - id="form-1", - workflow_run_id="run-1", - node_id="node-1", - tenant_id="tenant-id", - app_id="app-id", - form_definition=_make_form_definition(), - rendered_content="

hello

", - expiration_time=naive_utc_now(), - selected_action_id="approve", - submitted_data='{"field": "value"}', - submitted_at=naive_utc_now(), - ) - session = _FakeSession(scalars_results=[form, []]) - repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id") - - entity = repo.get_form(form.workflow_run_id, form.node_id) - - assert entity is not None - assert entity.submitted is True - assert entity.selected_action_id == "approve" - assert entity.submitted_data == {"field": "value"} - - -class TestHumanInputFormSubmissionRepository: - def test_get_by_token_returns_record(self): - form = _DummyForm( - id="form-1", - workflow_run_id="run-1", - node_id="node-1", - tenant_id="tenant-1", - app_id="app-1", - form_definition=_make_form_definition(), - rendered_content="

hello

", - expiration_time=naive_utc_now(), - ) - recipient = _DummyRecipient( - id="recipient-1", - form_id=form.id, - recipient_type=RecipientType.STANDALONE_WEB_APP, - access_token="token-123", - form=form, - ) - session = _FakeSession(scalars_result=recipient) - repo = HumanInputFormSubmissionRepository(_session_factory(session)) - - record = repo.get_by_token("token-123") - - assert record is not None - assert record.form_id == form.id - assert record.recipient_type == RecipientType.STANDALONE_WEB_APP - assert record.submitted is False - - def test_get_by_form_id_and_recipient_type_uses_recipient(self): - form = _DummyForm( - id="form-1", - workflow_run_id="run-1", - node_id="node-1", - tenant_id="tenant-1", - app_id="app-1", - form_definition=_make_form_definition(), - rendered_content="

hello

", - expiration_time=naive_utc_now(), - ) - recipient = _DummyRecipient( - id="recipient-1", - form_id=form.id, - recipient_type=RecipientType.STANDALONE_WEB_APP, - access_token="token-123", - form=form, - ) - session = _FakeSession(scalars_result=recipient) - repo = HumanInputFormSubmissionRepository(_session_factory(session)) - - record = repo.get_by_form_id_and_recipient_type( - form_id=form.id, - recipient_type=RecipientType.STANDALONE_WEB_APP, - ) - - assert record is not None - assert record.recipient_id == recipient.id - assert record.access_token == recipient.access_token - - def test_mark_submitted_updates_fields(self, monkeypatch: pytest.MonkeyPatch): - fixed_now = datetime(2024, 1, 1, 0, 0, 0) - monkeypatch.setattr("core.repositories.human_input_repository.naive_utc_now", lambda: fixed_now) - - form = _DummyForm( - id="form-1", - workflow_run_id="run-1", - node_id="node-1", - tenant_id="tenant-1", - app_id="app-1", - form_definition=_make_form_definition(), - rendered_content="

hello

", - expiration_time=fixed_now, - ) - recipient = _DummyRecipient( - id="recipient-1", - form_id="form-1", - recipient_type=RecipientType.STANDALONE_WEB_APP, - access_token="token-123", - ) - session = _FakeSession( - forms={form.id: form}, - recipients={recipient.id: recipient}, - ) - repo = HumanInputFormSubmissionRepository(_session_factory(session)) - - record: HumanInputFormRecord = repo.mark_submitted( - form_id=form.id, - recipient_id=recipient.id, - selected_action_id="approve", - form_data={"field": "value"}, - submission_user_id="user-1", - submission_end_user_id="end-user-1", - ) - - assert form.selected_action_id == "approve" - assert form.completed_by_recipient_id == recipient.id - assert form.submission_user_id == "user-1" - assert form.submission_end_user_id == "end-user-1" - assert form.submitted_at == fixed_now - assert record.submitted is True - assert record.selected_action_id == "approve" - assert record.submitted_data == {"field": "value"} diff --git a/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py b/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py deleted file mode 100644 index c46e31d90f..0000000000 --- a/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py +++ /dev/null @@ -1,33 +0,0 @@ -import pytest - -from core.tools.errors import WorkflowToolHumanInputNotSupportedError -from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils - - -def test_ensure_no_human_input_nodes_passes_for_non_human_input(): - graph = { - "nodes": [ - { - "id": "start_node", - "data": {"type": "start"}, - } - ] - } - - WorkflowToolConfigurationUtils.ensure_no_human_input_nodes(graph) - - -def test_ensure_no_human_input_nodes_raises_for_human_input(): - graph = { - "nodes": [ - { - "id": "human_input_node", - "data": {"type": "human-input"}, - } - ] - } - - with pytest.raises(WorkflowToolHumanInputNotSupportedError) as exc_info: - WorkflowToolConfigurationUtils.ensure_no_human_input_nodes(graph) - - assert exc_info.value.error_code == "workflow_tool_human_input_not_supported" diff --git a/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py b/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py index 1b6d03e36a..deff06fc5d 100644 --- a/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py +++ b/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py @@ -118,6 +118,7 @@ class TestGraphRuntimeState: from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue assert isinstance(queue, InMemoryReadyQueue) + assert state.ready_queue is queue def test_graph_execution_lazy_instantiation(self): state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time()) diff --git a/api/tests/unit_tests/core/workflow/entities/test_pause_reason.py b/api/tests/unit_tests/core/workflow/entities/test_pause_reason.py deleted file mode 100644 index 6144df06e0..0000000000 --- a/api/tests/unit_tests/core/workflow/entities/test_pause_reason.py +++ /dev/null @@ -1,88 +0,0 @@ -""" -Tests for PauseReason discriminated union serialization/deserialization. -""" - -import pytest -from pydantic import BaseModel, ValidationError - -from core.workflow.entities.pause_reason import ( - HumanInputRequired, - PauseReason, - SchedulingPause, -) - - -class _Holder(BaseModel): - """Helper model that embeds PauseReason for union tests.""" - - reason: PauseReason - - -class TestPauseReasonDiscriminator: - """Test suite for PauseReason union discriminator.""" - - @pytest.mark.parametrize( - ("dict_value", "expected"), - [ - pytest.param( - { - "reason": { - "TYPE": "human_input_required", - "form_id": "form_id", - "form_content": "form_content", - "node_id": "node_id", - "node_title": "node_title", - }, - }, - HumanInputRequired( - form_id="form_id", - form_content="form_content", - node_id="node_id", - node_title="node_title", - ), - id="HumanInputRequired", - ), - pytest.param( - { - "reason": { - "TYPE": "scheduled_pause", - "message": "Hold on", - } - }, - SchedulingPause(message="Hold on"), - id="SchedulingPause", - ), - ], - ) - def test_model_validate(self, dict_value, expected): - """Ensure scheduled pause payloads with lowercase TYPE deserialize.""" - holder = _Holder.model_validate(dict_value) - - assert type(holder.reason) == type(expected) - assert holder.reason == expected - - @pytest.mark.parametrize( - "reason", - [ - HumanInputRequired( - form_id="form_id", - form_content="form_content", - node_id="node_id", - node_title="node_title", - ), - SchedulingPause(message="Hold on"), - ], - ids=lambda x: type(x).__name__, - ) - def test_model_construct(self, reason): - holder = _Holder(reason=reason) - assert holder.reason == reason - - def test_model_construct_with_invalid_type(self): - with pytest.raises(ValidationError): - holder = _Holder(reason=object()) # type: ignore - - def test_unknown_type_fails_validation(self): - """Unknown TYPE values should raise a validation error.""" - with pytest.raises(ValidationError): - _Holder.model_validate({"reason": {"TYPE": "UNKNOWN"}}) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py b/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py deleted file mode 100644 index 2ef23c7f0f..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py +++ /dev/null @@ -1,131 +0,0 @@ -"""Utilities for testing HumanInputNode without database dependencies.""" - -from __future__ import annotations - -from collections.abc import Mapping -from dataclasses import dataclass -from datetime import datetime, timedelta -from typing import Any - -from core.workflow.nodes.human_input.enums import HumanInputFormStatus -from core.workflow.repositories.human_input_form_repository import ( - FormCreateParams, - HumanInputFormEntity, - HumanInputFormRecipientEntity, - HumanInputFormRepository, -) -from libs.datetime_utils import naive_utc_now - - -class _InMemoryFormRecipient(HumanInputFormRecipientEntity): - """Minimal recipient entity required by the repository interface.""" - - def __init__(self, recipient_id: str, token: str) -> None: - self._id = recipient_id - self._token = token - - @property - def id(self) -> str: - return self._id - - @property - def token(self) -> str: - return self._token - - -@dataclass -class _InMemoryFormEntity(HumanInputFormEntity): - form_id: str - rendered: str - token: str | None = None - action_id: str | None = None - data: Mapping[str, Any] | None = None - is_submitted: bool = False - status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING - expiration: datetime = naive_utc_now() - - @property - def id(self) -> str: - return self.form_id - - @property - def web_app_token(self) -> str | None: - return self.token - - @property - def recipients(self) -> list[HumanInputFormRecipientEntity]: - return [] - - @property - def rendered_content(self) -> str: - return self.rendered - - @property - def selected_action_id(self) -> str | None: - return self.action_id - - @property - def submitted_data(self) -> Mapping[str, Any] | None: - return self.data - - @property - def submitted(self) -> bool: - return self.is_submitted - - @property - def status(self) -> HumanInputFormStatus: - return self.status_value - - @property - def expiration_time(self) -> datetime: - return self.expiration - - -class InMemoryHumanInputFormRepository(HumanInputFormRepository): - """Pure in-memory repository used by workflow graph engine tests.""" - - def __init__(self) -> None: - self._form_counter = 0 - self.created_params: list[FormCreateParams] = [] - self.created_forms: list[_InMemoryFormEntity] = [] - self._forms_by_key: dict[tuple[str, str], _InMemoryFormEntity] = {} - - def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: - self.created_params.append(params) - self._form_counter += 1 - form_id = f"form-{self._form_counter}" - token = f"console-{form_id}" if params.console_recipient_required else f"token-{form_id}" - entity = _InMemoryFormEntity( - form_id=form_id, - rendered=params.rendered_content, - token=token, - ) - self.created_forms.append(entity) - self._forms_by_key[(params.workflow_execution_id, params.node_id)] = entity - return entity - - def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: - return self._forms_by_key.get((workflow_execution_id, node_id)) - - # Convenience helpers for tests ------------------------------------- - - def set_submission(self, *, action_id: str, form_data: Mapping[str, Any] | None = None) -> None: - """Simulate a human submission for the next repository lookup.""" - - if not self.created_forms: - raise AssertionError("no form has been created to attach submission data") - entity = self.created_forms[-1] - entity.action_id = action_id - entity.data = form_data or {} - entity.is_submitted = True - entity.status_value = HumanInputFormStatus.SUBMITTED - entity.expiration = naive_utc_now() + timedelta(days=1) - - def clear_submission(self) -> None: - if not self.created_forms: - return - for form in self.created_forms: - form.action_id = None - form.data = None - form.is_submitted = False - form.status_value = HumanInputFormStatus.WAITING diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher_pause_drain.py b/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher_pause_drain.py deleted file mode 100644 index 6038a15211..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher_pause_drain.py +++ /dev/null @@ -1,74 +0,0 @@ -import queue -import threading -from datetime import datetime - -from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus -from core.workflow.graph_engine.orchestration.dispatcher import Dispatcher -from core.workflow.graph_events import NodeRunSucceededEvent -from core.workflow.node_events import NodeRunResult - - -class StubExecutionCoordinator: - def __init__(self, paused: bool) -> None: - self._paused = paused - self.mark_complete_called = False - self.failed_error: Exception | None = None - - @property - def aborted(self) -> bool: - return False - - @property - def paused(self) -> bool: - return self._paused - - @property - def execution_complete(self) -> bool: - return False - - def check_scaling(self) -> None: - return None - - def process_commands(self) -> None: - return None - - def mark_complete(self) -> None: - self.mark_complete_called = True - - def mark_failed(self, error: Exception) -> None: - self.failed_error = error - - -class StubEventHandler: - def __init__(self) -> None: - self.events: list[object] = [] - - def dispatch(self, event: object) -> None: - self.events.append(event) - - -def test_dispatcher_drains_events_when_paused() -> None: - event_queue: queue.Queue = queue.Queue() - event = NodeRunSucceededEvent( - id="exec-1", - node_id="node-1", - node_type=NodeType.START, - start_at=datetime.utcnow(), - node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED), - ) - event_queue.put(event) - - handler = StubEventHandler() - coordinator = StubExecutionCoordinator(paused=True) - dispatcher = Dispatcher( - event_queue=event_queue, - event_handler=handler, - execution_coordinator=coordinator, - event_emitter=None, - stop_event=threading.Event(), - ) - - dispatcher._dispatcher_loop() - - assert handler.events == [event] - assert coordinator.mark_complete_called is True diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_execution_coordinator.py b/api/tests/unit_tests/core/workflow/graph_engine/test_execution_coordinator.py index 53de8908a8..0d67a76169 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_execution_coordinator.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_execution_coordinator.py @@ -2,8 +2,6 @@ from unittest.mock import MagicMock -import pytest - from core.workflow.graph_engine.command_processing.command_processor import CommandProcessor from core.workflow.graph_engine.domain.graph_execution import GraphExecution from core.workflow.graph_engine.graph_state_manager import GraphStateManager @@ -50,13 +48,3 @@ def test_handle_pause_noop_when_execution_running() -> None: worker_pool.stop.assert_not_called() state_manager.clear_executing.assert_not_called() - - -def test_has_executing_nodes_requires_pause() -> None: - graph_execution = GraphExecution(workflow_id="workflow") - graph_execution.start() - - coordinator, _, _ = _build_coordinator(graph_execution) - - with pytest.raises(AssertionError): - coordinator.has_executing_nodes() diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py deleted file mode 100644 index 65d34c2009..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py +++ /dev/null @@ -1,189 +0,0 @@ -import time -from collections.abc import Mapping - -from core.model_runtime.entities.llm_entities import LLMMode -from core.model_runtime.entities.message_entities import PromptMessageRole -from core.workflow.entities import GraphInitParams -from core.workflow.enums import NodeState -from core.workflow.graph import Graph -from core.workflow.graph_engine.graph_state_manager import GraphStateManager -from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue -from core.workflow.nodes.end.end_node import EndNode -from core.workflow.nodes.end.entities import EndNodeData -from core.workflow.nodes.llm.entities import ( - ContextConfig, - LLMNodeChatModelMessage, - LLMNodeData, - ModelConfig, - VisionConfig, -) -from core.workflow.nodes.start.entities import StartNodeData -from core.workflow.nodes.start.start_node import StartNode -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable - -from .test_mock_config import MockConfig -from .test_mock_nodes import MockLLMNode - - -def _build_runtime_state() -> GraphRuntimeState: - variable_pool = VariablePool( - system_variables=SystemVariable( - user_id="user", - app_id="app", - workflow_id="workflow", - workflow_execution_id="exec-1", - ), - user_inputs={}, - conversation_variables=[], - ) - return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - -def _build_llm_node( - *, - node_id: str, - runtime_state: GraphRuntimeState, - graph_init_params: GraphInitParams, - mock_config: MockConfig, -) -> MockLLMNode: - llm_data = LLMNodeData( - title=f"LLM {node_id}", - model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}), - prompt_template=[ - LLMNodeChatModelMessage( - text=f"Prompt {node_id}", - role=PromptMessageRole.USER, - edition_type="basic", - ) - ], - context=ContextConfig(enabled=False, variable_selector=None), - vision=VisionConfig(enabled=False), - reasoning_format="tagged", - ) - llm_config = {"id": node_id, "data": llm_data.model_dump()} - return MockLLMNode( - id=llm_config["id"], - config=llm_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - mock_config=mock_config, - ) - - -def _build_graph(runtime_state: GraphRuntimeState) -> Graph: - graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", - workflow_id="workflow", - graph_config=graph_config, - user_id="user", - user_from="account", - invoke_from="debugger", - call_depth=0, - ) - - start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} - start_node = StartNode( - id=start_config["id"], - config=start_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - ) - - mock_config = MockConfig() - llm_a = _build_llm_node( - node_id="llm_a", - runtime_state=runtime_state, - graph_init_params=graph_init_params, - mock_config=mock_config, - ) - llm_b = _build_llm_node( - node_id="llm_b", - runtime_state=runtime_state, - graph_init_params=graph_init_params, - mock_config=mock_config, - ) - - end_data = EndNodeData(title="End", outputs=[], desc=None) - end_config = {"id": "end", "data": end_data.model_dump()} - end_node = EndNode( - id=end_config["id"], - config=end_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - ) - - builder = ( - Graph.new() - .add_root(start_node) - .add_node(llm_a, from_node_id="start") - .add_node(llm_b, from_node_id="start") - .add_node(end_node, from_node_id="llm_a") - ) - return builder.connect(tail="llm_b", head="end").build() - - -def _edge_state_map(graph: Graph) -> Mapping[tuple[str, str, str], NodeState]: - return {(edge.tail, edge.head, edge.source_handle): edge.state for edge in graph.edges.values()} - - -def test_runtime_state_snapshot_restores_graph_states() -> None: - runtime_state = _build_runtime_state() - graph = _build_graph(runtime_state) - runtime_state.attach_graph(graph) - - graph.nodes["llm_a"].state = NodeState.TAKEN - graph.nodes["llm_b"].state = NodeState.SKIPPED - - for edge in graph.edges.values(): - if edge.tail == "start" and edge.head == "llm_a": - edge.state = NodeState.TAKEN - elif edge.tail == "start" and edge.head == "llm_b": - edge.state = NodeState.SKIPPED - elif edge.head == "end" and edge.tail == "llm_a": - edge.state = NodeState.TAKEN - elif edge.head == "end" and edge.tail == "llm_b": - edge.state = NodeState.SKIPPED - - snapshot = runtime_state.dumps() - - resumed_state = GraphRuntimeState.from_snapshot(snapshot) - resumed_graph = _build_graph(resumed_state) - resumed_state.attach_graph(resumed_graph) - - assert resumed_graph.nodes["llm_a"].state == NodeState.TAKEN - assert resumed_graph.nodes["llm_b"].state == NodeState.SKIPPED - assert _edge_state_map(resumed_graph) == _edge_state_map(graph) - - -def test_join_readiness_uses_restored_edge_states() -> None: - runtime_state = _build_runtime_state() - graph = _build_graph(runtime_state) - runtime_state.attach_graph(graph) - - ready_queue = InMemoryReadyQueue() - state_manager = GraphStateManager(graph, ready_queue) - - for edge in graph.get_incoming_edges("end"): - if edge.tail == "llm_a": - edge.state = NodeState.TAKEN - if edge.tail == "llm_b": - edge.state = NodeState.UNKNOWN - - assert state_manager.is_node_ready("end") is False - - for edge in graph.get_incoming_edges("end"): - if edge.tail == "llm_b": - edge.state = NodeState.TAKEN - - assert state_manager.is_node_ready("end") is True - - snapshot = runtime_state.dumps() - resumed_state = GraphRuntimeState.from_snapshot(snapshot) - resumed_graph = _build_graph(resumed_state) - resumed_state.attach_graph(resumed_graph) - - resumed_state_manager = GraphStateManager(resumed_graph, InMemoryReadyQueue()) - assert resumed_state_manager.is_node_ready("end") is True diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py index 194d009288..c398e4e8c1 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py @@ -1,7 +1,5 @@ -import datetime import time from collections.abc import Iterable -from unittest.mock import MagicMock from core.model_runtime.entities.llm_entities import LLMMode from core.model_runtime.entities.message_entities import PromptMessageRole @@ -16,12 +14,11 @@ from core.workflow.graph_events import ( NodeRunStreamChunkEvent, NodeRunSucceededEvent, ) -from core.workflow.graph_events.node import NodeRunHumanInputFormFilledEvent from core.workflow.nodes.base.entities import OutputVariableEntity, OutputVariableType from core.workflow.nodes.end.end_node import EndNode from core.workflow.nodes.end.entities import EndNodeData -from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction -from core.workflow.nodes.human_input.human_input_node import HumanInputNode +from core.workflow.nodes.human_input import HumanInputNode +from core.workflow.nodes.human_input.entities import HumanInputNodeData from core.workflow.nodes.llm.entities import ( ContextConfig, LLMNodeChatModelMessage, @@ -31,21 +28,15 @@ from core.workflow.nodes.llm.entities import ( ) from core.workflow.nodes.start.entities import StartNodeData from core.workflow.nodes.start.start_node import StartNode -from core.workflow.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable -from libs.datetime_utils import naive_utc_now from .test_mock_config import MockConfig from .test_mock_nodes import MockLLMNode from .test_table_runner import TableTestRunner, WorkflowTestCase -def _build_branching_graph( - mock_config: MockConfig, - form_repository: HumanInputFormRepository, - graph_runtime_state: GraphRuntimeState | None = None, -) -> tuple[Graph, GraphRuntimeState]: +def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntimeState]: graph_config: dict[str, object] = {"nodes": [], "edges": []} graph_init_params = GraphInitParams( tenant_id="tenant", @@ -58,18 +49,12 @@ def _build_branching_graph( call_depth=0, ) - if graph_runtime_state is None: - variable_pool = VariablePool( - system_variables=SystemVariable( - user_id="user", - app_id="app", - workflow_id="workflow", - workflow_execution_id="test-execution-id", - ), - user_inputs={}, - conversation_variables=[], - ) - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + variable_pool = VariablePool( + system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"), + user_inputs={}, + conversation_variables=[], + ) + graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} start_node = StartNode( @@ -108,21 +93,15 @@ def _build_branching_graph( human_data = HumanInputNodeData( title="Human Input", - form_content="Human input required", - inputs=[], - user_actions=[ - UserAction(id="primary", title="Primary"), - UserAction(id="secondary", title="Secondary"), - ], + required_variables=["human.input_ready"], + pause_reason="Awaiting human input", ) - human_config = {"id": "human", "data": human_data.model_dump()} human_node = HumanInputNode( id=human_config["id"], config=human_config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, - form_repository=form_repository, ) llm_primary = _create_llm_node("llm_primary", "Primary LLM", "Primary stream output") @@ -240,18 +219,8 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None: for scenario in branch_scenarios: runner = TableTestRunner() - mock_create_repo = MagicMock(spec=HumanInputFormRepository) - mock_create_repo.get_form.return_value = None - mock_form_entity = MagicMock(spec=HumanInputFormEntity) - mock_form_entity.id = "test_form_id" - mock_form_entity.web_app_token = "test_web_app_token" - mock_form_entity.recipients = [] - mock_form_entity.rendered_content = "rendered" - mock_form_entity.submitted = False - mock_create_repo.create_form.return_value = mock_form_entity - - def initial_graph_factory(mock_create_repo=mock_create_repo) -> tuple[Graph, GraphRuntimeState]: - return _build_branching_graph(mock_config, mock_create_repo) + def initial_graph_factory() -> tuple[Graph, GraphRuntimeState]: + return _build_branching_graph(mock_config) initial_case = WorkflowTestCase( description="HumanInput pause before branching decision", @@ -273,16 +242,23 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None: assert initial_result.success, initial_result.event_mismatch_details assert not any(isinstance(event, NodeRunStreamChunkEvent) for event in initial_result.events) + graph_runtime_state = initial_result.graph_runtime_state + graph = initial_result.graph + assert graph_runtime_state is not None + assert graph is not None + + graph_runtime_state.variable_pool.add(("human", "input_ready"), True) + graph_runtime_state.variable_pool.add(("human", "edge_source_handle"), scenario["handle"]) + graph_runtime_state.graph_execution.pause_reason = None + pre_chunk_count = sum(len(chunks) for _, chunks in scenario["expected_pre_chunks"]) post_chunk_count = sum(len(chunks) for _, chunks in scenario["expected_post_chunks"]) - expected_pre_chunk_events_in_resumption = [ - GraphRunStartedEvent, - NodeRunStartedEvent, - NodeRunHumanInputFormFilledEvent, - ] expected_resume_sequence: list[type] = ( - expected_pre_chunk_events_in_resumption + [ + GraphRunStartedEvent, + NodeRunStartedEvent, + ] + [NodeRunStreamChunkEvent] * pre_chunk_count + [ NodeRunSucceededEvent, @@ -297,25 +273,11 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None: ] ) - mock_get_repo = MagicMock(spec=HumanInputFormRepository) - submitted_form = MagicMock(spec=HumanInputFormEntity) - submitted_form.id = mock_form_entity.id - submitted_form.web_app_token = mock_form_entity.web_app_token - submitted_form.recipients = [] - submitted_form.rendered_content = mock_form_entity.rendered_content - submitted_form.submitted = True - submitted_form.selected_action_id = scenario["handle"] - submitted_form.submitted_data = {} - submitted_form.expiration_time = naive_utc_now() + datetime.timedelta(days=1) - mock_get_repo.get_form.return_value = submitted_form - def resume_graph_factory( - initial_result=initial_result, mock_get_repo=mock_get_repo + graph_snapshot: Graph = graph, + state_snapshot: GraphRuntimeState = graph_runtime_state, ) -> tuple[Graph, GraphRuntimeState]: - assert initial_result.graph_runtime_state is not None - serialized_runtime_state = initial_result.graph_runtime_state.dumps() - resume_runtime_state = GraphRuntimeState.from_snapshot(serialized_runtime_state) - return _build_branching_graph(mock_config, mock_get_repo, resume_runtime_state) + return graph_snapshot, state_snapshot resume_case = WorkflowTestCase( description=f"HumanInput resumes via {scenario['handle']} branch", @@ -359,8 +321,7 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None: for index, event in enumerate(resume_events) if isinstance(event, NodeRunStreamChunkEvent) and index < human_success_index ] - expected_pre_chunk_events_count_in_resumption = len(expected_pre_chunk_events_in_resumption) - assert pre_indices == list(range(expected_pre_chunk_events_count_in_resumption, human_success_index)) + assert pre_indices == list(range(2, 2 + pre_chunk_count)) resume_chunk_indices = [ index diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py index d8f229205b..ece69b080b 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py @@ -1,6 +1,4 @@ -import datetime import time -from unittest.mock import MagicMock from core.model_runtime.entities.llm_entities import LLMMode from core.model_runtime.entities.message_entities import PromptMessageRole @@ -15,12 +13,11 @@ from core.workflow.graph_events import ( NodeRunStreamChunkEvent, NodeRunSucceededEvent, ) -from core.workflow.graph_events.node import NodeRunHumanInputFormFilledEvent from core.workflow.nodes.base.entities import OutputVariableEntity, OutputVariableType from core.workflow.nodes.end.end_node import EndNode from core.workflow.nodes.end.entities import EndNodeData -from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction -from core.workflow.nodes.human_input.human_input_node import HumanInputNode +from core.workflow.nodes.human_input import HumanInputNode +from core.workflow.nodes.human_input.entities import HumanInputNodeData from core.workflow.nodes.llm.entities import ( ContextConfig, LLMNodeChatModelMessage, @@ -30,21 +27,15 @@ from core.workflow.nodes.llm.entities import ( ) from core.workflow.nodes.start.entities import StartNodeData from core.workflow.nodes.start.start_node import StartNode -from core.workflow.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable -from libs.datetime_utils import naive_utc_now from .test_mock_config import MockConfig from .test_mock_nodes import MockLLMNode from .test_table_runner import TableTestRunner, WorkflowTestCase -def _build_llm_human_llm_graph( - mock_config: MockConfig, - form_repository: HumanInputFormRepository, - graph_runtime_state: GraphRuntimeState | None = None, -) -> tuple[Graph, GraphRuntimeState]: +def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntimeState]: graph_config: dict[str, object] = {"nodes": [], "edges": []} graph_init_params = GraphInitParams( tenant_id="tenant", @@ -57,15 +48,12 @@ def _build_llm_human_llm_graph( call_depth=0, ) - if graph_runtime_state is None: - variable_pool = VariablePool( - system_variables=SystemVariable( - user_id="user", app_id="app", workflow_id="workflow", workflow_execution_id="test-execution-id," - ), - user_inputs={}, - conversation_variables=[], - ) - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + variable_pool = VariablePool( + system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"), + user_inputs={}, + conversation_variables=[], + ) + graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} start_node = StartNode( @@ -104,21 +92,15 @@ def _build_llm_human_llm_graph( human_data = HumanInputNodeData( title="Human Input", - form_content="Human input required", - inputs=[], - user_actions=[ - UserAction(id="accept", title="Accept"), - UserAction(id="reject", title="Reject"), - ], + required_variables=["human.input_ready"], + pause_reason="Awaiting human input", ) - human_config = {"id": "human", "data": human_data.model_dump()} human_node = HumanInputNode( id=human_config["id"], config=human_config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, - form_repository=form_repository, ) llm_second = _create_llm_node("llm_resume", "Follow-up LLM", "Follow-up prompt") @@ -148,7 +130,7 @@ def _build_llm_human_llm_graph( .add_root(start_node) .add_node(llm_first) .add_node(human_node) - .add_node(llm_second, source_handle="accept") + .add_node(llm_second) .add_node(end_node) .build() ) @@ -185,18 +167,8 @@ def test_human_input_llm_streaming_order_across_pause() -> None: GraphRunPausedEvent, # graph run pauses awaiting resume ] - mock_create_repo = MagicMock(spec=HumanInputFormRepository) - mock_create_repo.get_form.return_value = None - mock_form_entity = MagicMock(spec=HumanInputFormEntity) - mock_form_entity.id = "test_form_id" - mock_form_entity.web_app_token = "test_web_app_token" - mock_form_entity.recipients = [] - mock_form_entity.rendered_content = "rendered" - mock_form_entity.submitted = False - mock_create_repo.create_form.return_value = mock_form_entity - def graph_factory() -> tuple[Graph, GraphRuntimeState]: - return _build_llm_human_llm_graph(mock_config, mock_create_repo) + return _build_llm_human_llm_graph(mock_config) initial_case = WorkflowTestCase( description="HumanInput pause preserves LLM streaming order", @@ -238,8 +210,6 @@ def test_human_input_llm_streaming_order_across_pause() -> None: expected_resume_sequence: list[type] = [ GraphRunStartedEvent, # resumed graph run begins NodeRunStartedEvent, # human node restarts - # Form Filled should be generated first, then the node execution ends and stream chunk is generated. - NodeRunHumanInputFormFilledEvent, NodeRunStreamChunkEvent, # cached llm_initial chunk 1 NodeRunStreamChunkEvent, # cached llm_initial chunk 2 NodeRunStreamChunkEvent, # cached llm_initial final chunk @@ -255,27 +225,12 @@ def test_human_input_llm_streaming_order_across_pause() -> None: GraphRunSucceededEvent, # graph run succeeds after resume ] - mock_get_repo = MagicMock(spec=HumanInputFormRepository) - submitted_form = MagicMock(spec=HumanInputFormEntity) - submitted_form.id = mock_form_entity.id - submitted_form.web_app_token = mock_form_entity.web_app_token - submitted_form.recipients = [] - submitted_form.rendered_content = mock_form_entity.rendered_content - submitted_form.submitted = True - submitted_form.selected_action_id = "accept" - submitted_form.submitted_data = {} - submitted_form.expiration_time = naive_utc_now() + datetime.timedelta(days=1) - mock_get_repo.get_form.return_value = submitted_form - def resume_graph_factory() -> tuple[Graph, GraphRuntimeState]: - # restruct the graph runtime state - serialized_runtime_state = initial_result.graph_runtime_state.dumps() - resume_runtime_state = GraphRuntimeState.from_snapshot(serialized_runtime_state) - return _build_llm_human_llm_graph( - mock_config, - mock_get_repo, - resume_runtime_state, - ) + assert graph_runtime_state is not None + assert graph is not None + graph_runtime_state.variable_pool.add(("human", "input_ready"), True) + graph_runtime_state.graph_execution.pause_reason = None + return graph, graph_runtime_state resume_case = WorkflowTestCase( description="HumanInput resume continues LLM streaming order", diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py deleted file mode 100644 index a6aab81f6c..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py +++ /dev/null @@ -1,270 +0,0 @@ -import time -from collections.abc import Mapping -from dataclasses import dataclass -from datetime import datetime, timedelta -from typing import Any, Protocol - -from core.workflow.entities import GraphInitParams -from core.workflow.entities.workflow_start_reason import WorkflowStartReason -from core.workflow.graph import Graph -from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from core.workflow.graph_engine.config import GraphEngineConfig -from core.workflow.graph_engine.graph_engine import GraphEngine -from core.workflow.graph_events import ( - GraphRunPausedEvent, - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunSucceededEvent, -) -from core.workflow.nodes.base.entities import OutputVariableEntity -from core.workflow.nodes.end.end_node import EndNode -from core.workflow.nodes.end.entities import EndNodeData -from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction -from core.workflow.nodes.human_input.enums import HumanInputFormStatus -from core.workflow.nodes.human_input.human_input_node import HumanInputNode -from core.workflow.nodes.start.entities import StartNodeData -from core.workflow.nodes.start.start_node import StartNode -from core.workflow.repositories.human_input_form_repository import ( - FormCreateParams, - HumanInputFormEntity, - HumanInputFormRepository, -) -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from libs.datetime_utils import naive_utc_now - - -class PauseStateStore(Protocol): - def save(self, runtime_state: GraphRuntimeState) -> None: ... - - def load(self) -> GraphRuntimeState: ... - - -class InMemoryPauseStore: - def __init__(self) -> None: - self._snapshot: str | None = None - - def save(self, runtime_state: GraphRuntimeState) -> None: - self._snapshot = runtime_state.dumps() - - def load(self) -> GraphRuntimeState: - assert self._snapshot is not None - return GraphRuntimeState.from_snapshot(self._snapshot) - - -@dataclass -class StaticForm(HumanInputFormEntity): - form_id: str - rendered: str - is_submitted: bool - action_id: str | None = None - data: Mapping[str, Any] | None = None - status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING - expiration: datetime = naive_utc_now() + timedelta(days=1) - - @property - def id(self) -> str: - return self.form_id - - @property - def web_app_token(self) -> str | None: - return "token" - - @property - def recipients(self) -> list: - return [] - - @property - def rendered_content(self) -> str: - return self.rendered - - @property - def selected_action_id(self) -> str | None: - return self.action_id - - @property - def submitted_data(self) -> Mapping[str, Any] | None: - return self.data - - @property - def submitted(self) -> bool: - return self.is_submitted - - @property - def status(self) -> HumanInputFormStatus: - return self.status_value - - @property - def expiration_time(self) -> datetime: - return self.expiration - - -class StaticRepo(HumanInputFormRepository): - def __init__(self, forms_by_node_id: Mapping[str, HumanInputFormEntity]) -> None: - self._forms_by_node_id = dict(forms_by_node_id) - - def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: - return self._forms_by_node_id.get(node_id) - - def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: - raise AssertionError("create_form should not be called in resume scenario") - - -def _build_runtime_state() -> GraphRuntimeState: - variable_pool = VariablePool( - system_variables=SystemVariable( - user_id="user", - app_id="app", - workflow_id="workflow", - workflow_execution_id="exec-1", - ), - user_inputs={}, - conversation_variables=[], - ) - return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - -def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepository) -> Graph: - graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", - workflow_id="workflow", - graph_config=graph_config, - user_id="user", - user_from="account", - invoke_from="debugger", - call_depth=0, - ) - - start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} - start_node = StartNode( - id=start_config["id"], - config=start_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - ) - - human_data = HumanInputNodeData( - title="Human Input", - form_content="Human input required", - inputs=[], - user_actions=[UserAction(id="approve", title="Approve")], - ) - - human_a_config = {"id": "human_a", "data": human_data.model_dump()} - human_a = HumanInputNode( - id=human_a_config["id"], - config=human_a_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - form_repository=repo, - ) - - human_b_config = {"id": "human_b", "data": human_data.model_dump()} - human_b = HumanInputNode( - id=human_b_config["id"], - config=human_b_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - form_repository=repo, - ) - - end_data = EndNodeData( - title="End", - outputs=[ - OutputVariableEntity(variable="res_a", value_selector=["human_a", "__action_id"]), - OutputVariableEntity(variable="res_b", value_selector=["human_b", "__action_id"]), - ], - desc=None, - ) - end_config = {"id": "end", "data": end_data.model_dump()} - end_node = EndNode( - id=end_config["id"], - config=end_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - ) - - builder = ( - Graph.new() - .add_root(start_node) - .add_node(human_a, from_node_id="start") - .add_node(human_b, from_node_id="start") - .add_node(end_node, from_node_id="human_a", source_handle="approve") - ) - return builder.connect(tail="human_b", head="end", source_handle="approve").build() - - -def _run_graph(graph: Graph, runtime_state: GraphRuntimeState) -> list[object]: - engine = GraphEngine( - workflow_id="workflow", - graph=graph, - graph_runtime_state=runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig( - min_workers=2, - max_workers=2, - scale_up_threshold=1, - scale_down_idle_time=30.0, - ), - ) - return list(engine.run()) - - -def _form(submitted: bool, action_id: str | None) -> StaticForm: - return StaticForm( - form_id="form", - rendered="rendered", - is_submitted=submitted, - action_id=action_id, - data={}, - status_value=HumanInputFormStatus.SUBMITTED if submitted else HumanInputFormStatus.WAITING, - ) - - -def test_parallel_human_input_join_completes_after_second_resume() -> None: - pause_store: PauseStateStore = InMemoryPauseStore() - - initial_state = _build_runtime_state() - initial_repo = StaticRepo( - { - "human_a": _form(submitted=False, action_id=None), - "human_b": _form(submitted=False, action_id=None), - } - ) - initial_graph = _build_graph(initial_state, initial_repo) - initial_events = _run_graph(initial_graph, initial_state) - - assert isinstance(initial_events[-1], GraphRunPausedEvent) - pause_store.save(initial_state) - - first_resume_state = pause_store.load() - first_resume_repo = StaticRepo( - { - "human_a": _form(submitted=True, action_id="approve"), - "human_b": _form(submitted=False, action_id=None), - } - ) - first_resume_graph = _build_graph(first_resume_state, first_resume_repo) - first_resume_events = _run_graph(first_resume_graph, first_resume_state) - - assert isinstance(first_resume_events[0], GraphRunStartedEvent) - assert first_resume_events[0].reason is WorkflowStartReason.RESUMPTION - assert isinstance(first_resume_events[-1], GraphRunPausedEvent) - pause_store.save(first_resume_state) - - second_resume_state = pause_store.load() - second_resume_repo = StaticRepo( - { - "human_a": _form(submitted=True, action_id="approve"), - "human_b": _form(submitted=True, action_id="approve"), - } - ) - second_resume_graph = _build_graph(second_resume_state, second_resume_repo) - second_resume_events = _run_graph(second_resume_graph, second_resume_state) - - assert isinstance(second_resume_events[0], GraphRunStartedEvent) - assert second_resume_events[0].reason is WorkflowStartReason.RESUMPTION - assert isinstance(second_resume_events[-1], GraphRunSucceededEvent) - assert any(isinstance(event, NodeRunSucceededEvent) and event.node_id == "end" for event in second_resume_events) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py deleted file mode 100644 index 62aa56fc57..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py +++ /dev/null @@ -1,333 +0,0 @@ -import time -from collections.abc import Mapping -from dataclasses import dataclass -from datetime import datetime, timedelta -from typing import Any - -from core.model_runtime.entities.llm_entities import LLMMode -from core.model_runtime.entities.message_entities import PromptMessageRole -from core.workflow.entities import GraphInitParams -from core.workflow.entities.workflow_start_reason import WorkflowStartReason -from core.workflow.graph import Graph -from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from core.workflow.graph_engine.config import GraphEngineConfig -from core.workflow.graph_engine.graph_engine import GraphEngine -from core.workflow.graph_events import ( - GraphRunPausedEvent, - GraphRunStartedEvent, - NodeRunPauseRequestedEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, -) -from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction -from core.workflow.nodes.human_input.enums import HumanInputFormStatus -from core.workflow.nodes.human_input.human_input_node import HumanInputNode -from core.workflow.nodes.llm.entities import ( - ContextConfig, - LLMNodeChatModelMessage, - LLMNodeData, - ModelConfig, - VisionConfig, -) -from core.workflow.nodes.start.entities import StartNodeData -from core.workflow.nodes.start.start_node import StartNode -from core.workflow.repositories.human_input_form_repository import ( - FormCreateParams, - HumanInputFormEntity, - HumanInputFormRepository, -) -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from libs.datetime_utils import naive_utc_now - -from .test_mock_config import MockConfig, NodeMockConfig -from .test_mock_nodes import MockLLMNode - - -@dataclass -class StaticForm(HumanInputFormEntity): - form_id: str - rendered: str - is_submitted: bool - action_id: str | None = None - data: Mapping[str, Any] | None = None - status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING - expiration: datetime = naive_utc_now() + timedelta(days=1) - - @property - def id(self) -> str: - return self.form_id - - @property - def web_app_token(self) -> str | None: - return "token" - - @property - def recipients(self) -> list: - return [] - - @property - def rendered_content(self) -> str: - return self.rendered - - @property - def selected_action_id(self) -> str | None: - return self.action_id - - @property - def submitted_data(self) -> Mapping[str, Any] | None: - return self.data - - @property - def submitted(self) -> bool: - return self.is_submitted - - @property - def status(self) -> HumanInputFormStatus: - return self.status_value - - @property - def expiration_time(self) -> datetime: - return self.expiration - - -class StaticRepo(HumanInputFormRepository): - def __init__(self, forms_by_node_id: Mapping[str, HumanInputFormEntity]) -> None: - self._forms_by_node_id = dict(forms_by_node_id) - - def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: - return self._forms_by_node_id.get(node_id) - - def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: - raise AssertionError("create_form should not be called in resume scenario") - - -class DelayedHumanInputNode(HumanInputNode): - def __init__(self, delay_seconds: float, **kwargs: Any) -> None: - super().__init__(**kwargs) - self._delay_seconds = delay_seconds - - def _run(self): - if self._delay_seconds > 0: - time.sleep(self._delay_seconds) - yield from super()._run() - - -def _build_runtime_state() -> GraphRuntimeState: - variable_pool = VariablePool( - system_variables=SystemVariable( - user_id="user", - app_id="app", - workflow_id="workflow", - workflow_execution_id="exec-1", - ), - user_inputs={}, - conversation_variables=[], - ) - return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - -def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepository, mock_config: MockConfig) -> Graph: - graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", - workflow_id="workflow", - graph_config=graph_config, - user_id="user", - user_from="account", - invoke_from="debugger", - call_depth=0, - ) - - start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} - start_node = StartNode( - id=start_config["id"], - config=start_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - ) - - human_data = HumanInputNodeData( - title="Human Input", - form_content="Human input required", - inputs=[], - user_actions=[UserAction(id="approve", title="Approve")], - ) - - human_a_config = {"id": "human_a", "data": human_data.model_dump()} - human_a = HumanInputNode( - id=human_a_config["id"], - config=human_a_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - form_repository=repo, - ) - - human_b_config = {"id": "human_b", "data": human_data.model_dump()} - human_b = DelayedHumanInputNode( - id=human_b_config["id"], - config=human_b_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - form_repository=repo, - delay_seconds=0.2, - ) - - llm_data = LLMNodeData( - title="LLM A", - model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}), - prompt_template=[ - LLMNodeChatModelMessage( - text="Prompt A", - role=PromptMessageRole.USER, - edition_type="basic", - ) - ], - context=ContextConfig(enabled=False, variable_selector=None), - vision=VisionConfig(enabled=False), - reasoning_format="tagged", - structured_output_enabled=False, - ) - llm_config = {"id": "llm_a", "data": llm_data.model_dump()} - llm_a = MockLLMNode( - id=llm_config["id"], - config=llm_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - mock_config=mock_config, - ) - - return ( - Graph.new() - .add_root(start_node) - .add_node(human_a, from_node_id="start") - .add_node(human_b, from_node_id="start") - .add_node(llm_a, from_node_id="human_a", source_handle="approve") - .build() - ) - - -def test_parallel_human_input_pause_preserves_node_finished() -> None: - runtime_state = _build_runtime_state() - - runtime_state.graph_execution.start() - runtime_state.register_paused_node("human_a") - runtime_state.register_paused_node("human_b") - - submitted = StaticForm( - form_id="form-a", - rendered="rendered", - is_submitted=True, - action_id="approve", - data={}, - status_value=HumanInputFormStatus.SUBMITTED, - ) - pending = StaticForm( - form_id="form-b", - rendered="rendered", - is_submitted=False, - action_id=None, - data=None, - status_value=HumanInputFormStatus.WAITING, - ) - repo = StaticRepo({"human_a": submitted, "human_b": pending}) - - mock_config = MockConfig() - mock_config.simulate_delays = True - mock_config.set_node_config( - "llm_a", - NodeMockConfig(node_id="llm_a", outputs={"text": "LLM A output"}, delay=0.5), - ) - - graph = _build_graph(runtime_state, repo, mock_config) - engine = GraphEngine( - workflow_id="workflow", - graph=graph, - graph_runtime_state=runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig( - min_workers=2, - max_workers=2, - scale_up_threshold=1, - scale_down_idle_time=30.0, - ), - ) - - events = list(engine.run()) - - llm_started = any(isinstance(e, NodeRunStartedEvent) and e.node_id == "llm_a" for e in events) - llm_succeeded = any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_a" for e in events) - human_b_pause = any(isinstance(e, NodeRunPauseRequestedEvent) and e.node_id == "human_b" for e in events) - graph_paused = any(isinstance(e, GraphRunPausedEvent) for e in events) - graph_started = any(isinstance(e, GraphRunStartedEvent) for e in events) - - assert graph_started - assert graph_paused - assert human_b_pause - assert llm_started - assert llm_succeeded - - -def test_parallel_human_input_pause_preserves_node_finished_after_snapshot_resume() -> None: - base_state = _build_runtime_state() - base_state.graph_execution.start() - base_state.register_paused_node("human_a") - base_state.register_paused_node("human_b") - snapshot = base_state.dumps() - - resumed_state = GraphRuntimeState.from_snapshot(snapshot) - - submitted = StaticForm( - form_id="form-a", - rendered="rendered", - is_submitted=True, - action_id="approve", - data={}, - status_value=HumanInputFormStatus.SUBMITTED, - ) - pending = StaticForm( - form_id="form-b", - rendered="rendered", - is_submitted=False, - action_id=None, - data=None, - status_value=HumanInputFormStatus.WAITING, - ) - repo = StaticRepo({"human_a": submitted, "human_b": pending}) - - mock_config = MockConfig() - mock_config.simulate_delays = True - mock_config.set_node_config( - "llm_a", - NodeMockConfig(node_id="llm_a", outputs={"text": "LLM A output"}, delay=0.5), - ) - - graph = _build_graph(resumed_state, repo, mock_config) - engine = GraphEngine( - workflow_id="workflow", - graph=graph, - graph_runtime_state=resumed_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig( - min_workers=2, - max_workers=2, - scale_up_threshold=1, - scale_down_idle_time=30.0, - ), - ) - - events = list(engine.run()) - - start_event = next(e for e in events if isinstance(e, GraphRunStartedEvent)) - assert start_event.reason is WorkflowStartReason.RESUMPTION - - llm_started = any(isinstance(e, NodeRunStartedEvent) and e.node_id == "llm_a" for e in events) - llm_succeeded = any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_a" for e in events) - human_b_pause = any(isinstance(e, NodeRunPauseRequestedEvent) and e.node_id == "human_b" for e in events) - graph_paused = any(isinstance(e, GraphRunPausedEvent) for e in events) - - assert graph_paused - assert human_b_pause - assert llm_started - assert llm_succeeded diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py deleted file mode 100644 index 156cfefcd6..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py +++ /dev/null @@ -1,309 +0,0 @@ -import time -from collections.abc import Mapping -from dataclasses import dataclass -from datetime import datetime, timedelta -from typing import Any - -from core.model_runtime.entities.llm_entities import LLMMode -from core.model_runtime.entities.message_entities import PromptMessageRole -from core.workflow.entities import GraphInitParams -from core.workflow.entities.workflow_start_reason import WorkflowStartReason -from core.workflow.graph import Graph -from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from core.workflow.graph_engine.config import GraphEngineConfig -from core.workflow.graph_engine.graph_engine import GraphEngine -from core.workflow.graph_events import ( - GraphRunPausedEvent, - GraphRunStartedEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, -) -from core.workflow.nodes.end.end_node import EndNode -from core.workflow.nodes.end.entities import EndNodeData -from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction -from core.workflow.nodes.human_input.enums import HumanInputFormStatus -from core.workflow.nodes.human_input.human_input_node import HumanInputNode -from core.workflow.nodes.llm.entities import ( - ContextConfig, - LLMNodeChatModelMessage, - LLMNodeData, - ModelConfig, - VisionConfig, -) -from core.workflow.nodes.start.entities import StartNodeData -from core.workflow.nodes.start.start_node import StartNode -from core.workflow.repositories.human_input_form_repository import ( - FormCreateParams, - HumanInputFormEntity, - HumanInputFormRepository, -) -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from libs.datetime_utils import naive_utc_now - -from .test_mock_config import MockConfig, NodeMockConfig -from .test_mock_nodes import MockLLMNode - - -@dataclass -class StaticForm(HumanInputFormEntity): - form_id: str - rendered: str - is_submitted: bool - action_id: str | None = None - data: Mapping[str, Any] | None = None - status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING - expiration: datetime = naive_utc_now() + timedelta(days=1) - - @property - def id(self) -> str: - return self.form_id - - @property - def web_app_token(self) -> str | None: - return "token" - - @property - def recipients(self) -> list: - return [] - - @property - def rendered_content(self) -> str: - return self.rendered - - @property - def selected_action_id(self) -> str | None: - return self.action_id - - @property - def submitted_data(self) -> Mapping[str, Any] | None: - return self.data - - @property - def submitted(self) -> bool: - return self.is_submitted - - @property - def status(self) -> HumanInputFormStatus: - return self.status_value - - @property - def expiration_time(self) -> datetime: - return self.expiration - - -class StaticRepo(HumanInputFormRepository): - def __init__(self, form: HumanInputFormEntity) -> None: - self._form = form - - def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: - if node_id != "human_pause": - return None - return self._form - - def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: - raise AssertionError("create_form should not be called in this test") - - -def _build_runtime_state() -> GraphRuntimeState: - variable_pool = VariablePool( - system_variables=SystemVariable( - user_id="user", - app_id="app", - workflow_id="workflow", - workflow_execution_id="exec-1", - ), - user_inputs={}, - conversation_variables=[], - ) - return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - -def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepository, mock_config: MockConfig) -> Graph: - graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", - workflow_id="workflow", - graph_config=graph_config, - user_id="user", - user_from="account", - invoke_from="debugger", - call_depth=0, - ) - - start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} - start_node = StartNode( - id=start_config["id"], - config=start_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - ) - - llm_a_data = LLMNodeData( - title="LLM A", - model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}), - prompt_template=[ - LLMNodeChatModelMessage( - text="Prompt A", - role=PromptMessageRole.USER, - edition_type="basic", - ) - ], - context=ContextConfig(enabled=False, variable_selector=None), - vision=VisionConfig(enabled=False), - reasoning_format="tagged", - structured_output_enabled=False, - ) - llm_a_config = {"id": "llm_a", "data": llm_a_data.model_dump()} - llm_a = MockLLMNode( - id=llm_a_config["id"], - config=llm_a_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - mock_config=mock_config, - ) - - llm_b_data = LLMNodeData( - title="LLM B", - model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}), - prompt_template=[ - LLMNodeChatModelMessage( - text="Prompt B", - role=PromptMessageRole.USER, - edition_type="basic", - ) - ], - context=ContextConfig(enabled=False, variable_selector=None), - vision=VisionConfig(enabled=False), - reasoning_format="tagged", - structured_output_enabled=False, - ) - llm_b_config = {"id": "llm_b", "data": llm_b_data.model_dump()} - llm_b = MockLLMNode( - id=llm_b_config["id"], - config=llm_b_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - mock_config=mock_config, - ) - - human_data = HumanInputNodeData( - title="Human Input", - form_content="Pause here", - inputs=[], - user_actions=[UserAction(id="approve", title="Approve")], - ) - human_config = {"id": "human_pause", "data": human_data.model_dump()} - human_node = HumanInputNode( - id=human_config["id"], - config=human_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - form_repository=repo, - ) - - end_human_data = EndNodeData(title="End Human", outputs=[], desc=None) - end_human_config = {"id": "end_human", "data": end_human_data.model_dump()} - end_human = EndNode( - id=end_human_config["id"], - config=end_human_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - ) - - return ( - Graph.new() - .add_root(start_node) - .add_node(llm_a, from_node_id="start") - .add_node(human_node, from_node_id="start") - .add_node(llm_b, from_node_id="llm_a") - .add_node(end_human, from_node_id="human_pause", source_handle="approve") - .build() - ) - - -def _get_node_started_event(events: list[object], node_id: str) -> NodeRunStartedEvent | None: - for event in events: - if isinstance(event, NodeRunStartedEvent) and event.node_id == node_id: - return event - return None - - -def test_pause_defers_ready_nodes_until_resume() -> None: - runtime_state = _build_runtime_state() - - paused_form = StaticForm( - form_id="form-pause", - rendered="rendered", - is_submitted=False, - status_value=HumanInputFormStatus.WAITING, - ) - pause_repo = StaticRepo(paused_form) - - mock_config = MockConfig() - mock_config.simulate_delays = True - mock_config.set_node_config( - "llm_a", - NodeMockConfig(node_id="llm_a", outputs={"text": "LLM A output"}, delay=0.5), - ) - mock_config.set_node_config( - "llm_b", - NodeMockConfig(node_id="llm_b", outputs={"text": "LLM B output"}, delay=0.0), - ) - - graph = _build_graph(runtime_state, pause_repo, mock_config) - engine = GraphEngine( - workflow_id="workflow", - graph=graph, - graph_runtime_state=runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig( - min_workers=2, - max_workers=2, - scale_up_threshold=1, - scale_down_idle_time=30.0, - ), - ) - - paused_events = list(engine.run()) - - assert any(isinstance(e, GraphRunPausedEvent) for e in paused_events) - assert any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_a" for e in paused_events) - assert _get_node_started_event(paused_events, "llm_b") is None - - snapshot = runtime_state.dumps() - resumed_state = GraphRuntimeState.from_snapshot(snapshot) - - submitted_form = StaticForm( - form_id="form-pause", - rendered="rendered", - is_submitted=True, - action_id="approve", - data={}, - status_value=HumanInputFormStatus.SUBMITTED, - ) - resume_repo = StaticRepo(submitted_form) - - resumed_graph = _build_graph(resumed_state, resume_repo, mock_config) - resumed_engine = GraphEngine( - workflow_id="workflow", - graph=resumed_graph, - graph_runtime_state=resumed_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig( - min_workers=2, - max_workers=2, - scale_up_threshold=1, - scale_down_idle_time=30.0, - ), - ) - - resumed_events = list(resumed_engine.run()) - - start_event = next(e for e in resumed_events if isinstance(e, GraphRunStartedEvent)) - assert start_event.reason is WorkflowStartReason.RESUMPTION - - llm_b_started = _get_node_started_event(resumed_events, "llm_b") - assert llm_b_started is not None - assert any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_b" for e in resumed_events) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py deleted file mode 100644 index 700b3f4b8b..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py +++ /dev/null @@ -1,217 +0,0 @@ -import datetime -import time -from typing import Any -from unittest.mock import MagicMock - -from core.workflow.entities import GraphInitParams -from core.workflow.entities.workflow_start_reason import WorkflowStartReason -from core.workflow.graph import Graph -from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from core.workflow.graph_engine.graph_engine import GraphEngine -from core.workflow.graph_events import ( - GraphEngineEvent, - GraphRunPausedEvent, - GraphRunSucceededEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, -) -from core.workflow.graph_events.graph import GraphRunStartedEvent -from core.workflow.nodes.base.entities import OutputVariableEntity -from core.workflow.nodes.end.end_node import EndNode -from core.workflow.nodes.end.entities import EndNodeData -from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction -from core.workflow.nodes.human_input.human_input_node import HumanInputNode -from core.workflow.nodes.start.entities import StartNodeData -from core.workflow.nodes.start.start_node import StartNode -from core.workflow.repositories.human_input_form_repository import ( - HumanInputFormEntity, - HumanInputFormRepository, -) -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from libs.datetime_utils import naive_utc_now - - -def _build_runtime_state() -> GraphRuntimeState: - variable_pool = VariablePool( - system_variables=SystemVariable( - user_id="user", - app_id="app", - workflow_id="workflow", - workflow_execution_id="test-execution-id", - ), - user_inputs={}, - conversation_variables=[], - ) - return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - -def _mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepository: - repo = MagicMock(spec=HumanInputFormRepository) - form_entity = MagicMock(spec=HumanInputFormEntity) - form_entity.id = "test-form-id" - form_entity.web_app_token = "test-form-token" - form_entity.recipients = [] - form_entity.rendered_content = "rendered" - form_entity.submitted = True - form_entity.selected_action_id = action_id - form_entity.submitted_data = {} - form_entity.expiration_time = naive_utc_now() + datetime.timedelta(days=1) - repo.get_form.return_value = form_entity - return repo - - -def _mock_form_repository_without_submission() -> HumanInputFormRepository: - repo = MagicMock(spec=HumanInputFormRepository) - form_entity = MagicMock(spec=HumanInputFormEntity) - form_entity.id = "test-form-id" - form_entity.web_app_token = "test-form-token" - form_entity.recipients = [] - form_entity.rendered_content = "rendered" - form_entity.submitted = False - repo.create_form.return_value = form_entity - repo.get_form.return_value = None - return repo - - -def _build_human_input_graph( - runtime_state: GraphRuntimeState, - form_repository: HumanInputFormRepository, -) -> Graph: - graph_config: dict[str, object] = {"nodes": [], "edges": []} - params = GraphInitParams( - tenant_id="tenant", - app_id="app", - workflow_id="workflow", - graph_config=graph_config, - user_id="user", - user_from="account", - invoke_from="service-api", - call_depth=0, - ) - - start_data = StartNodeData(title="start", variables=[]) - start_node = StartNode( - id="start", - config={"id": "start", "data": start_data.model_dump()}, - graph_init_params=params, - graph_runtime_state=runtime_state, - ) - - human_data = HumanInputNodeData( - title="human", - form_content="Awaiting human input", - inputs=[], - user_actions=[ - UserAction(id="continue", title="Continue"), - ], - ) - human_node = HumanInputNode( - id="human", - config={"id": "human", "data": human_data.model_dump()}, - graph_init_params=params, - graph_runtime_state=runtime_state, - form_repository=form_repository, - ) - - end_data = EndNodeData( - title="end", - outputs=[ - OutputVariableEntity(variable="result", value_selector=["human", "action_id"]), - ], - desc=None, - ) - end_node = EndNode( - id="end", - config={"id": "end", "data": end_data.model_dump()}, - graph_init_params=params, - graph_runtime_state=runtime_state, - ) - - return ( - Graph.new() - .add_root(start_node) - .add_node(human_node) - .add_node(end_node, from_node_id="human", source_handle="continue") - .build() - ) - - -def _run_graph(graph: Graph, runtime_state: GraphRuntimeState) -> list[GraphEngineEvent]: - engine = GraphEngine( - workflow_id="workflow", - graph=graph, - graph_runtime_state=runtime_state, - command_channel=InMemoryChannel(), - ) - return list(engine.run()) - - -def _node_successes(events: list[GraphEngineEvent]) -> list[str]: - return [event.node_id for event in events if isinstance(event, NodeRunSucceededEvent)] - - -def _node_start_event(events: list[GraphEngineEvent], node_id: str) -> NodeRunStartedEvent | None: - for event in events: - if isinstance(event, NodeRunStartedEvent) and event.node_id == node_id: - return event - return None - - -def _segment_value(variable_pool: VariablePool, selector: tuple[str, str]) -> Any: - segment = variable_pool.get(selector) - assert segment is not None - return getattr(segment, "value", segment) - - -def test_engine_resume_restores_state_and_completion(): - # Baseline run without pausing - baseline_state = _build_runtime_state() - baseline_repo = _mock_form_repository_with_submission(action_id="continue") - baseline_graph = _build_human_input_graph(baseline_state, baseline_repo) - baseline_events = _run_graph(baseline_graph, baseline_state) - assert baseline_events - first_paused_event = baseline_events[0] - assert isinstance(first_paused_event, GraphRunStartedEvent) - assert first_paused_event.reason is WorkflowStartReason.INITIAL - assert isinstance(baseline_events[-1], GraphRunSucceededEvent) - baseline_success_nodes = _node_successes(baseline_events) - - # Run with pause - paused_state = _build_runtime_state() - pause_repo = _mock_form_repository_without_submission() - paused_graph = _build_human_input_graph(paused_state, pause_repo) - paused_events = _run_graph(paused_graph, paused_state) - assert paused_events - first_paused_event = paused_events[0] - assert isinstance(first_paused_event, GraphRunStartedEvent) - assert first_paused_event.reason is WorkflowStartReason.INITIAL - assert isinstance(paused_events[-1], GraphRunPausedEvent) - snapshot = paused_state.dumps() - - # Resume from snapshot - resumed_state = GraphRuntimeState.from_snapshot(snapshot) - resume_repo = _mock_form_repository_with_submission(action_id="continue") - resumed_graph = _build_human_input_graph(resumed_state, resume_repo) - resumed_events = _run_graph(resumed_graph, resumed_state) - assert resumed_events - first_resumed_event = resumed_events[0] - assert isinstance(first_resumed_event, GraphRunStartedEvent) - assert first_resumed_event.reason is WorkflowStartReason.RESUMPTION - assert isinstance(resumed_events[-1], GraphRunSucceededEvent) - - combined_success_nodes = _node_successes(paused_events) + _node_successes(resumed_events) - assert combined_success_nodes == baseline_success_nodes - - paused_human_started = _node_start_event(paused_events, "human") - resumed_human_started = _node_start_event(resumed_events, "human") - assert paused_human_started is not None - assert resumed_human_started is not None - assert paused_human_started.id == resumed_human_started.id - - assert baseline_state.outputs == resumed_state.outputs - assert _segment_value(baseline_state.variable_pool, ("human", "__action_id")) == _segment_value( - resumed_state.variable_pool, ("human", "__action_id") - ) - assert baseline_state.graph_execution.completed - assert resumed_state.graph_execution.completed diff --git a/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py index 21a642c2f8..488b47761b 100644 --- a/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py @@ -7,7 +7,6 @@ from core.workflow.nodes.base.node import Node # Ensures that all node classes are imported. from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING -# Ensure `NODE_TYPE_CLASSES_MAPPING` is used and not automatically removed. _ = NODE_TYPE_CLASSES_MAPPING @@ -46,9 +45,7 @@ def test_ensure_subclasses_of_base_node_has_node_type_and_version_method_defined assert isinstance(cls.node_type, NodeType) assert isinstance(node_version, str) node_type_and_version = (node_type, node_version) - assert node_type_and_version not in type_version_set, ( - f"Duplicate node type and version for class: {cls=} {node_type_and_version=}" - ) + assert node_type_and_version not in type_version_set type_version_set.add(node_type_and_version) diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/__init__.py b/api/tests/unit_tests/core/workflow/nodes/human_input/__init__.py deleted file mode 100644 index 20807e9ef9..0000000000 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Unit tests for human input node diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py deleted file mode 100644 index ca4a887d20..0000000000 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py +++ /dev/null @@ -1,16 +0,0 @@ -from core.workflow.nodes.human_input.entities import EmailDeliveryConfig, EmailRecipients -from core.workflow.runtime import VariablePool - - -def test_render_body_template_replaces_variable_values(): - config = EmailDeliveryConfig( - recipients=EmailRecipients(), - subject="Subject", - body="Hello {{#node1.value#}} {{#url#}}", - ) - variable_pool = VariablePool() - variable_pool.add(["node1", "value"], "World") - - result = config.render_body_template(body=config.body, url="https://example.com", variable_pool=variable_pool) - - assert result == "Hello World https://example.com" diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py deleted file mode 100644 index bfe7b03c13..0000000000 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py +++ /dev/null @@ -1,597 +0,0 @@ -""" -Unit tests for human input node entities. -""" - -from types import SimpleNamespace -from unittest.mock import MagicMock - -import pytest -from pydantic import ValidationError - -from core.workflow.entities import GraphInitParams -from core.workflow.node_events import PauseRequestedEvent -from core.workflow.node_events.node import StreamCompletedEvent -from core.workflow.nodes.human_input.entities import ( - EmailDeliveryConfig, - EmailDeliveryMethod, - EmailRecipients, - ExternalRecipient, - FormInput, - FormInputDefault, - HumanInputNodeData, - MemberRecipient, - UserAction, - WebAppDeliveryMethod, - _WebAppDeliveryConfig, -) -from core.workflow.nodes.human_input.enums import ( - ButtonStyle, - DeliveryMethodType, - EmailRecipientType, - FormInputType, - PlaceholderType, - TimeoutUnit, -) -from core.workflow.nodes.human_input.human_input_node import HumanInputNode -from core.workflow.repositories.human_input_form_repository import HumanInputFormRepository -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from tests.unit_tests.core.workflow.graph_engine.human_input_test_utils import InMemoryHumanInputFormRepository - - -class TestDeliveryMethod: - """Test DeliveryMethod entity.""" - - def test_webapp_delivery_method(self): - """Test webapp delivery method creation.""" - delivery_method = WebAppDeliveryMethod(enabled=True, config=_WebAppDeliveryConfig()) - - assert delivery_method.type == DeliveryMethodType.WEBAPP - assert delivery_method.enabled is True - assert isinstance(delivery_method.config, _WebAppDeliveryConfig) - - def test_email_delivery_method(self): - """Test email delivery method creation.""" - recipients = EmailRecipients( - whole_workspace=False, - items=[ - MemberRecipient(type=EmailRecipientType.MEMBER, user_id="test-user-123"), - ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="test@example.com"), - ], - ) - - config = EmailDeliveryConfig( - recipients=recipients, subject="Test Subject", body="Test body with {{#url#}} placeholder" - ) - - delivery_method = EmailDeliveryMethod(enabled=True, config=config) - - assert delivery_method.type == DeliveryMethodType.EMAIL - assert delivery_method.enabled is True - assert isinstance(delivery_method.config, EmailDeliveryConfig) - assert delivery_method.config.subject == "Test Subject" - assert len(delivery_method.config.recipients.items) == 2 - - -class TestFormInput: - """Test FormInput entity.""" - - def test_text_input_with_constant_default(self): - """Test text input with constant default value.""" - default = FormInputDefault(type=PlaceholderType.CONSTANT, value="Enter your response here...") - - form_input = FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="user_input", default=default) - - assert form_input.type == FormInputType.TEXT_INPUT - assert form_input.output_variable_name == "user_input" - assert form_input.default.type == PlaceholderType.CONSTANT - assert form_input.default.value == "Enter your response here..." - - def test_text_input_with_variable_default(self): - """Test text input with variable default value.""" - default = FormInputDefault(type=PlaceholderType.VARIABLE, selector=["node_123", "output_var"]) - - form_input = FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="user_input", default=default) - - assert form_input.default.type == PlaceholderType.VARIABLE - assert form_input.default.selector == ["node_123", "output_var"] - - def test_form_input_without_default(self): - """Test form input without default value.""" - form_input = FormInput(type=FormInputType.PARAGRAPH, output_variable_name="description") - - assert form_input.type == FormInputType.PARAGRAPH - assert form_input.output_variable_name == "description" - assert form_input.default is None - - -class TestUserAction: - """Test UserAction entity.""" - - def test_user_action_creation(self): - """Test user action creation.""" - action = UserAction(id="approve", title="Approve", button_style=ButtonStyle.PRIMARY) - - assert action.id == "approve" - assert action.title == "Approve" - assert action.button_style == ButtonStyle.PRIMARY - - def test_user_action_default_button_style(self): - """Test user action with default button style.""" - action = UserAction(id="cancel", title="Cancel") - - assert action.button_style == ButtonStyle.DEFAULT - - def test_user_action_length_boundaries(self): - """Test user action id and title length boundaries.""" - action = UserAction(id="a" * 20, title="b" * 20) - - assert action.id == "a" * 20 - assert action.title == "b" * 20 - - @pytest.mark.parametrize( - ("field_name", "value"), - [ - ("id", "a" * 21), - ("title", "b" * 21), - ], - ) - def test_user_action_length_limits(self, field_name: str, value: str): - """User action fields should enforce max length.""" - data = {"id": "approve", "title": "Approve"} - data[field_name] = value - - with pytest.raises(ValidationError) as exc_info: - UserAction(**data) - - errors = exc_info.value.errors() - assert any(error["loc"] == (field_name,) and error["type"] == "string_too_long" for error in errors) - - -class TestHumanInputNodeData: - """Test HumanInputNodeData entity.""" - - def test_valid_node_data_creation(self): - """Test creating valid human input node data.""" - delivery_methods = [WebAppDeliveryMethod(enabled=True, config=_WebAppDeliveryConfig())] - - inputs = [ - FormInput( - type=FormInputType.TEXT_INPUT, - output_variable_name="content", - default=FormInputDefault(type=PlaceholderType.CONSTANT, value="Enter content..."), - ) - ] - - user_actions = [UserAction(id="submit", title="Submit", button_style=ButtonStyle.PRIMARY)] - - node_data = HumanInputNodeData( - title="Human Input Test", - desc="Test node description", - delivery_methods=delivery_methods, - form_content="# Test Form\n\nPlease provide input:\n\n{{#$output.content#}}", - inputs=inputs, - user_actions=user_actions, - timeout=24, - timeout_unit=TimeoutUnit.HOUR, - ) - - assert node_data.title == "Human Input Test" - assert node_data.desc == "Test node description" - assert len(node_data.delivery_methods) == 1 - assert node_data.form_content.startswith("# Test Form") - assert len(node_data.inputs) == 1 - assert len(node_data.user_actions) == 1 - assert node_data.timeout == 24 - assert node_data.timeout_unit == TimeoutUnit.HOUR - - def test_node_data_with_multiple_delivery_methods(self): - """Test node data with multiple delivery methods.""" - delivery_methods = [ - WebAppDeliveryMethod(enabled=True, config=_WebAppDeliveryConfig()), - EmailDeliveryMethod( - enabled=False, # Disabled method should be fine - config=EmailDeliveryConfig( - subject="Hi there", body="", recipients=EmailRecipients(whole_workspace=True) - ), - ), - ] - - node_data = HumanInputNodeData( - title="Test Node", delivery_methods=delivery_methods, timeout=1, timeout_unit=TimeoutUnit.DAY - ) - - assert len(node_data.delivery_methods) == 2 - assert node_data.timeout == 1 - assert node_data.timeout_unit == TimeoutUnit.DAY - - def test_node_data_defaults(self): - """Test node data with default values.""" - node_data = HumanInputNodeData(title="Test Node") - - assert node_data.title == "Test Node" - assert node_data.desc is None - assert node_data.delivery_methods == [] - assert node_data.form_content == "" - assert node_data.inputs == [] - assert node_data.user_actions == [] - assert node_data.timeout == 36 - assert node_data.timeout_unit == TimeoutUnit.HOUR - - def test_duplicate_input_output_variable_name_raises_validation_error(self): - """Duplicate form input output_variable_name should raise validation error.""" - duplicate_inputs = [ - FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="content"), - FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="content"), - ] - - with pytest.raises(ValidationError, match="duplicated output_variable_name 'content'"): - HumanInputNodeData(title="Test Node", inputs=duplicate_inputs) - - def test_duplicate_user_action_ids_raise_validation_error(self): - """Duplicate user action ids should raise validation error.""" - duplicate_actions = [ - UserAction(id="submit", title="Submit"), - UserAction(id="submit", title="Submit Again"), - ] - - with pytest.raises(ValidationError, match="duplicated user action id 'submit'"): - HumanInputNodeData(title="Test Node", user_actions=duplicate_actions) - - def test_extract_outputs_field_names(self): - content = r"""This is titile {{#start.title#}} - - A content is required: - - {{#$output.content#}} - - A ending is required: - - {{#$output.ending#}} - """ - - node_data = HumanInputNodeData(title="Human Input", form_content=content) - field_names = node_data.outputs_field_names() - assert field_names == ["content", "ending"] - - -class TestRecipients: - """Test email recipient entities.""" - - def test_member_recipient(self): - """Test member recipient creation.""" - recipient = MemberRecipient(type=EmailRecipientType.MEMBER, user_id="user-123") - - assert recipient.type == EmailRecipientType.MEMBER - assert recipient.user_id == "user-123" - - def test_external_recipient(self): - """Test external recipient creation.""" - recipient = ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="test@example.com") - - assert recipient.type == EmailRecipientType.EXTERNAL - assert recipient.email == "test@example.com" - - def test_email_recipients_whole_workspace(self): - """Test email recipients with whole workspace enabled.""" - recipients = EmailRecipients( - whole_workspace=True, items=[MemberRecipient(type=EmailRecipientType.MEMBER, user_id="user-123")] - ) - - assert recipients.whole_workspace is True - assert len(recipients.items) == 1 # Items are preserved even when whole_workspace is True - - def test_email_recipients_specific_users(self): - """Test email recipients with specific users.""" - recipients = EmailRecipients( - whole_workspace=False, - items=[ - MemberRecipient(type=EmailRecipientType.MEMBER, user_id="user-123"), - ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="external@example.com"), - ], - ) - - assert recipients.whole_workspace is False - assert len(recipients.items) == 2 - assert recipients.items[0].user_id == "user-123" - assert recipients.items[1].email == "external@example.com" - - -class TestHumanInputNodeVariableResolution: - """Tests for resolving variable-based defaults in HumanInputNode.""" - - def test_resolves_variable_defaults(self): - variable_pool = VariablePool( - system_variables=SystemVariable( - user_id="user", - app_id="app", - workflow_id="workflow", - workflow_execution_id="exec-1", - ), - user_inputs={}, - conversation_variables=[], - ) - variable_pool.add(("start", "name"), "Jane Doe") - runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) - graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", - workflow_id="workflow", - graph_config={"nodes": [], "edges": []}, - user_id="user", - user_from="account", - invoke_from="debugger", - call_depth=0, - ) - - node_data = HumanInputNodeData( - title="Human Input", - form_content="Provide your name", - inputs=[ - FormInput( - type=FormInputType.TEXT_INPUT, - output_variable_name="user_name", - default=FormInputDefault(type=PlaceholderType.VARIABLE, selector=["start", "name"]), - ), - FormInput( - type=FormInputType.TEXT_INPUT, - output_variable_name="user_email", - default=FormInputDefault(type=PlaceholderType.CONSTANT, value="foo@example.com"), - ), - ], - user_actions=[UserAction(id="submit", title="Submit")], - ) - config = {"id": "human", "data": node_data.model_dump()} - - mock_repo = MagicMock(spec=HumanInputFormRepository) - mock_repo.get_form.return_value = None - mock_repo.create_form.return_value = SimpleNamespace( - id="form-1", - rendered_content="Provide your name", - web_app_token="token", - recipients=[], - submitted=False, - ) - - node = HumanInputNode( - id=config["id"], - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - form_repository=mock_repo, - ) - - run_result = node._run() - pause_event = next(run_result) - - assert isinstance(pause_event, PauseRequestedEvent) - expected_values = {"user_name": "Jane Doe"} - assert pause_event.reason.resolved_default_values == expected_values - - params = mock_repo.create_form.call_args.args[0] - assert params.resolved_default_values == expected_values - - def test_debugger_falls_back_to_recipient_token_when_webapp_disabled(self): - variable_pool = VariablePool( - system_variables=SystemVariable( - user_id="user", - app_id="app", - workflow_id="workflow", - workflow_execution_id="exec-2", - ), - user_inputs={}, - conversation_variables=[], - ) - runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) - graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", - workflow_id="workflow", - graph_config={"nodes": [], "edges": []}, - user_id="user", - user_from="account", - invoke_from="debugger", - call_depth=0, - ) - - node_data = HumanInputNodeData( - title="Human Input", - form_content="Provide your name", - inputs=[], - user_actions=[UserAction(id="submit", title="Submit")], - ) - config = {"id": "human", "data": node_data.model_dump()} - - mock_repo = MagicMock(spec=HumanInputFormRepository) - mock_repo.get_form.return_value = None - mock_repo.create_form.return_value = SimpleNamespace( - id="form-2", - rendered_content="Provide your name", - web_app_token="console-token", - recipients=[SimpleNamespace(token="recipient-token")], - submitted=False, - ) - - node = HumanInputNode( - id=config["id"], - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - form_repository=mock_repo, - ) - - run_result = node._run() - pause_event = next(run_result) - - assert isinstance(pause_event, PauseRequestedEvent) - assert pause_event.reason.form_token == "console-token" - - def test_debugger_debug_mode_overrides_email_recipients(self): - variable_pool = VariablePool( - system_variables=SystemVariable( - user_id="user-123", - app_id="app", - workflow_id="workflow", - workflow_execution_id="exec-3", - ), - user_inputs={}, - conversation_variables=[], - ) - runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) - graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", - workflow_id="workflow", - graph_config={"nodes": [], "edges": []}, - user_id="user-123", - user_from="account", - invoke_from="debugger", - call_depth=0, - ) - - node_data = HumanInputNodeData( - title="Human Input", - form_content="Provide your name", - inputs=[], - user_actions=[UserAction(id="submit", title="Submit")], - delivery_methods=[ - EmailDeliveryMethod( - enabled=True, - config=EmailDeliveryConfig( - recipients=EmailRecipients( - whole_workspace=False, - items=[ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="target@example.com")], - ), - subject="Subject", - body="Body", - debug_mode=True, - ), - ) - ], - ) - config = {"id": "human", "data": node_data.model_dump()} - - mock_repo = MagicMock(spec=HumanInputFormRepository) - mock_repo.get_form.return_value = None - mock_repo.create_form.return_value = SimpleNamespace( - id="form-3", - rendered_content="Provide your name", - web_app_token="token", - recipients=[], - submitted=False, - ) - - node = HumanInputNode( - id=config["id"], - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - form_repository=mock_repo, - ) - - run_result = node._run() - pause_event = next(run_result) - assert isinstance(pause_event, PauseRequestedEvent) - - params = mock_repo.create_form.call_args.args[0] - assert len(params.delivery_methods) == 1 - method = params.delivery_methods[0] - assert isinstance(method, EmailDeliveryMethod) - assert method.config.debug_mode is True - assert method.config.recipients.whole_workspace is False - assert len(method.config.recipients.items) == 1 - recipient = method.config.recipients.items[0] - assert isinstance(recipient, MemberRecipient) - assert recipient.user_id == "user-123" - - -class TestValidation: - """Test validation scenarios.""" - - def test_invalid_form_input_type(self): - """Test validation with invalid form input type.""" - with pytest.raises(ValidationError): - FormInput( - type="invalid-type", # Invalid type - output_variable_name="test", - ) - - def test_invalid_button_style(self): - """Test validation with invalid button style.""" - with pytest.raises(ValidationError): - UserAction( - id="test", - title="Test", - button_style="invalid-style", # Invalid style - ) - - def test_invalid_timeout_unit(self): - """Test validation with invalid timeout unit.""" - with pytest.raises(ValidationError): - HumanInputNodeData( - title="Test", - timeout_unit="invalid-unit", # Invalid unit - ) - - -class TestHumanInputNodeRenderedContent: - """Tests for rendering submitted content.""" - - def test_replaces_outputs_placeholders_after_submission(self): - variable_pool = VariablePool( - system_variables=SystemVariable( - user_id="user", - app_id="app", - workflow_id="workflow", - workflow_execution_id="exec-1", - ), - user_inputs={}, - conversation_variables=[], - ) - runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) - graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", - workflow_id="workflow", - graph_config={"nodes": [], "edges": []}, - user_id="user", - user_from="account", - invoke_from="debugger", - call_depth=0, - ) - - node_data = HumanInputNodeData( - title="Human Input", - form_content="Name: {{#$output.name#}}", - inputs=[ - FormInput( - type=FormInputType.TEXT_INPUT, - output_variable_name="name", - ) - ], - user_actions=[UserAction(id="approve", title="Approve")], - ) - config = {"id": "human", "data": node_data.model_dump()} - - form_repository = InMemoryHumanInputFormRepository() - node = HumanInputNode( - id=config["id"], - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - form_repository=form_repository, - ) - - pause_gen = node._run() - pause_event = next(pause_gen) - assert isinstance(pause_event, PauseRequestedEvent) - with pytest.raises(StopIteration): - next(pause_gen) - - form_repository.set_submission(action_id="approve", form_data={"name": "Alice"}) - - events = list(node._run()) - last_event = events[-1] - assert isinstance(last_event, StreamCompletedEvent) - node_run_result = last_event.node_run_result - assert node_run_result.outputs["__rendered_content"] == "Name: Alice" diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py deleted file mode 100644 index a19ee4dee3..0000000000 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py +++ /dev/null @@ -1,172 +0,0 @@ -import datetime -from types import SimpleNamespace - -from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities.graph_init_params import GraphInitParams -from core.workflow.enums import NodeType -from core.workflow.graph_events import ( - NodeRunHumanInputFormFilledEvent, - NodeRunHumanInputFormTimeoutEvent, - NodeRunStartedEvent, -) -from core.workflow.nodes.human_input.enums import HumanInputFormStatus -from core.workflow.nodes.human_input.human_input_node import HumanInputNode -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from libs.datetime_utils import naive_utc_now -from models.enums import UserFrom - - -class _FakeFormRepository: - def __init__(self, form): - self._form = form - - def get_form(self, *_args, **_kwargs): - return self._form - - -def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name#}}") -> HumanInputNode: - system_variables = SystemVariable.default() - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=system_variables, user_inputs={}, environment_variables=[]), - start_at=0.0, - ) - graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", - workflow_id="workflow", - graph_config={"nodes": [], "edges": []}, - user_id="user", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, - call_depth=0, - ) - - config = { - "id": "node-1", - "type": NodeType.HUMAN_INPUT.value, - "data": { - "title": "Human Input", - "form_content": form_content, - "inputs": [ - { - "type": "text_input", - "output_variable_name": "name", - "default": {"type": "constant", "value": ""}, - } - ], - "user_actions": [ - { - "id": "Accept", - "title": "Approve", - "button_style": "default", - } - ], - }, - } - - fake_form = SimpleNamespace( - id="form-1", - rendered_content=form_content, - submitted=True, - selected_action_id="Accept", - submitted_data={"name": "Alice"}, - status=HumanInputFormStatus.SUBMITTED, - expiration_time=naive_utc_now() + datetime.timedelta(days=1), - ) - - repo = _FakeFormRepository(fake_form) - return HumanInputNode( - id="node-1", - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - form_repository=repo, - ) - - -def _build_timeout_node() -> HumanInputNode: - system_variables = SystemVariable.default() - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=system_variables, user_inputs={}, environment_variables=[]), - start_at=0.0, - ) - graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", - workflow_id="workflow", - graph_config={"nodes": [], "edges": []}, - user_id="user", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, - call_depth=0, - ) - - config = { - "id": "node-1", - "type": NodeType.HUMAN_INPUT.value, - "data": { - "title": "Human Input", - "form_content": "Please enter your name:\n\n{{#$output.name#}}", - "inputs": [ - { - "type": "text_input", - "output_variable_name": "name", - "default": {"type": "constant", "value": ""}, - } - ], - "user_actions": [ - { - "id": "Accept", - "title": "Approve", - "button_style": "default", - } - ], - }, - } - - fake_form = SimpleNamespace( - id="form-1", - rendered_content="content", - submitted=False, - selected_action_id=None, - submitted_data=None, - status=HumanInputFormStatus.TIMEOUT, - expiration_time=naive_utc_now() - datetime.timedelta(minutes=1), - ) - - repo = _FakeFormRepository(fake_form) - return HumanInputNode( - id="node-1", - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - form_repository=repo, - ) - - -def test_human_input_node_emits_form_filled_event_before_succeeded(): - node = _build_node() - - events = list(node.run()) - - assert isinstance(events[0], NodeRunStartedEvent) - assert isinstance(events[1], NodeRunHumanInputFormFilledEvent) - - filled_event = events[1] - assert filled_event.node_title == "Human Input" - assert filled_event.rendered_content.endswith("Alice") - assert filled_event.action_id == "Accept" - assert filled_event.action_text == "Approve" - - -def test_human_input_node_emits_timeout_event_before_succeeded(): - node = _build_timeout_node() - - events = list(node.run()) - - assert isinstance(events[0], NodeRunStartedEvent) - assert isinstance(events[1], NodeRunHumanInputFormTimeoutEvent) - - timeout_event = events[1] - assert timeout_event.node_title == "Human Input" diff --git a/api/tests/unit_tests/core/workflow/test_variable_pool_conver.py b/api/tests/unit_tests/core/workflow/test_variable_pool_conver.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/tests/unit_tests/extensions/test_celery_ssl.py b/api/tests/unit_tests/extensions/test_celery_ssl.py index 38477409bb..d3a4d69f07 100644 --- a/api/tests/unit_tests/extensions/test_celery_ssl.py +++ b/api/tests/unit_tests/extensions/test_celery_ssl.py @@ -104,7 +104,6 @@ class TestCelerySSLConfiguration: def test_celery_init_applies_ssl_to_broker_and_backend(self): """Test that SSL options are applied to both broker and backend when using Redis.""" mock_config = MagicMock() - mock_config.HUMAN_INPUT_TIMEOUT_TASK_INTERVAL = 1 mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0" mock_config.CELERY_BACKEND = "redis" mock_config.CELERY_RESULT_BACKEND = "redis://localhost:6379/0" diff --git a/api/tests/unit_tests/extensions/test_pubsub_channel.py b/api/tests/unit_tests/extensions/test_pubsub_channel.py deleted file mode 100644 index a5b41a7266..0000000000 --- a/api/tests/unit_tests/extensions/test_pubsub_channel.py +++ /dev/null @@ -1,20 +0,0 @@ -from configs import dify_config -from extensions import ext_redis -from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel -from libs.broadcast_channel.redis.sharded_channel import ShardedRedisBroadcastChannel - - -def test_get_pubsub_broadcast_channel_defaults_to_pubsub(monkeypatch): - monkeypatch.setattr(dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "pubsub") - - channel = ext_redis.get_pubsub_broadcast_channel() - - assert isinstance(channel, RedisBroadcastChannel) - - -def test_get_pubsub_broadcast_channel_sharded(monkeypatch): - monkeypatch.setattr(dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "sharded") - - channel = ext_redis.get_pubsub_broadcast_channel() - - assert isinstance(channel, ShardedRedisBroadcastChannel) diff --git a/api/tests/unit_tests/libs/_human_input/__init__.py b/api/tests/unit_tests/libs/_human_input/__init__.py deleted file mode 100644 index 66714e72f8..0000000000 --- a/api/tests/unit_tests/libs/_human_input/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Treat this directory as a package so support modules can be imported relatively. diff --git a/api/tests/unit_tests/libs/_human_input/support.py b/api/tests/unit_tests/libs/_human_input/support.py deleted file mode 100644 index bd86c13a2c..0000000000 --- a/api/tests/unit_tests/libs/_human_input/support.py +++ /dev/null @@ -1,249 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass, field -from datetime import datetime, timedelta -from typing import Any - -from core.workflow.nodes.human_input.entities import FormInput -from core.workflow.nodes.human_input.enums import TimeoutUnit - - -# Exceptions -class HumanInputError(Exception): - error_code: str = "unknown" - - def __init__(self, message: str = "", error_code: str | None = None): - super().__init__(message) - self.message = message or self.__class__.__name__ - if error_code: - self.error_code = error_code - - -class FormNotFoundError(HumanInputError): - error_code = "form_not_found" - - -class FormExpiredError(HumanInputError): - error_code = "human_input_form_expired" - - -class FormAlreadySubmittedError(HumanInputError): - error_code = "human_input_form_submitted" - - -class InvalidFormDataError(HumanInputError): - error_code = "invalid_form_data" - - -# Models -@dataclass -class HumanInputForm: - form_id: str - workflow_run_id: str - node_id: str - tenant_id: str - app_id: str | None - form_content: str - inputs: list[FormInput] - user_actions: list[dict[str, Any]] - timeout: int - timeout_unit: TimeoutUnit - form_token: str | None = None - created_at: datetime = field(default_factory=datetime.utcnow) - expires_at: datetime | None = None - submitted_at: datetime | None = None - submitted_data: dict[str, Any] | None = None - submitted_action: str | None = None - - def __post_init__(self) -> None: - if self.expires_at is None: - self.calculate_expiration() - - @property - def is_expired(self) -> bool: - return self.expires_at is not None and datetime.utcnow() > self.expires_at - - @property - def is_submitted(self) -> bool: - return self.submitted_at is not None - - def mark_submitted(self, inputs: dict[str, Any], action: str) -> None: - self.submitted_data = inputs - self.submitted_action = action - self.submitted_at = datetime.utcnow() - - def submit(self, inputs: dict[str, Any], action: str) -> None: - self.mark_submitted(inputs, action) - - def calculate_expiration(self) -> None: - start = self.created_at - if self.timeout_unit == TimeoutUnit.HOUR: - self.expires_at = start + timedelta(hours=self.timeout) - elif self.timeout_unit == TimeoutUnit.DAY: - self.expires_at = start + timedelta(days=self.timeout) - else: - raise ValueError(f"Unsupported timeout unit {self.timeout_unit}") - - def to_response_dict(self, *, include_site_info: bool) -> dict[str, Any]: - inputs_response = [ - { - "type": form_input.type.name.lower().replace("_", "-"), - "output_variable_name": form_input.output_variable_name, - } - for form_input in self.inputs - ] - response = { - "form_content": self.form_content, - "inputs": inputs_response, - "user_actions": self.user_actions, - } - if include_site_info: - response["site"] = {"app_id": self.app_id, "title": "Workflow Form"} - return response - - -@dataclass -class FormSubmissionData: - form_id: str - inputs: dict[str, Any] - action: str - submitted_at: datetime = field(default_factory=datetime.utcnow) - - @classmethod - def from_request(cls, form_id: str, request: FormSubmissionRequest) -> FormSubmissionData: # type: ignore - return cls(form_id=form_id, inputs=request.inputs, action=request.action) - - -@dataclass -class FormSubmissionRequest: - inputs: dict[str, Any] - action: str - - -# Repository -class InMemoryFormRepository: - """ - Simple in-memory repository used by unit tests. - """ - - def __init__(self): - self._forms: dict[str, HumanInputForm] = {} - - @property - def forms(self) -> dict[str, HumanInputForm]: - return self._forms - - def save(self, form: HumanInputForm) -> None: - self._forms[form.form_id] = form - - def get_by_id(self, form_id: str) -> HumanInputForm | None: - return self._forms.get(form_id) - - def get_by_token(self, token: str) -> HumanInputForm | None: - for form in self._forms.values(): - if form.form_token == token: - return form - return None - - def delete(self, form_id: str) -> None: - self._forms.pop(form_id, None) - - -# Service -class FormService: - """Service layer for managing human input forms in tests.""" - - def __init__(self, repository: InMemoryFormRepository): - self.repository = repository - - def create_form( - self, - *, - form_id: str, - workflow_run_id: str, - node_id: str, - tenant_id: str, - app_id: str | None, - form_content: str, - inputs, - user_actions, - timeout: int, - timeout_unit: TimeoutUnit, - form_token: str | None = None, - ) -> HumanInputForm: - form = HumanInputForm( - form_id=form_id, - workflow_run_id=workflow_run_id, - node_id=node_id, - tenant_id=tenant_id, - app_id=app_id, - form_content=form_content, - inputs=list(inputs), - user_actions=[{"id": action.id, "title": action.title} for action in user_actions], - timeout=timeout, - timeout_unit=timeout_unit, - form_token=form_token, - ) - form.calculate_expiration() - self.repository.save(form) - return form - - def get_form_by_id(self, form_id: str) -> HumanInputForm: - form = self.repository.get_by_id(form_id) - if form is None: - raise FormNotFoundError() - return form - - def get_form_by_token(self, token: str) -> HumanInputForm: - form = self.repository.get_by_token(token) - if form is None: - raise FormNotFoundError() - return form - - def get_form_definition(self, form_id: str, *, is_token: bool) -> dict: - form = self.get_form_by_token(form_id) if is_token else self.get_form_by_id(form_id) - if form.is_expired: - raise FormExpiredError() - if form.is_submitted: - raise FormAlreadySubmittedError() - - definition = { - "form_content": form.form_content, - "inputs": form.inputs, - "user_actions": form.user_actions, - } - if is_token: - definition["site"] = {"title": "Workflow Form"} - return definition - - def submit_form(self, form_id: str, submission_data: FormSubmissionData, *, is_token: bool) -> None: - form = self.get_form_by_token(form_id) if is_token else self.get_form_by_id(form_id) - if form.is_expired: - raise FormExpiredError() - if form.is_submitted: - raise FormAlreadySubmittedError() - - self._validate_submission(form=form, submission_data=submission_data) - form.mark_submitted(inputs=submission_data.inputs, action=submission_data.action) - self.repository.save(form) - - def cleanup_expired_forms(self) -> int: - expired_ids = [form_id for form_id, form in list(self.repository.forms.items()) if form.is_expired] - for form_id in expired_ids: - self.repository.delete(form_id) - return len(expired_ids) - - def _validate_submission(self, form: HumanInputForm, submission_data: FormSubmissionData) -> None: - defined_actions = {action["id"] for action in form.user_actions} - if submission_data.action not in defined_actions: - raise InvalidFormDataError(f"Invalid action: {submission_data.action}") - - missing_inputs = [] - for form_input in form.inputs: - if form_input.output_variable_name not in submission_data.inputs: - missing_inputs.append(form_input.output_variable_name) - - if missing_inputs: - raise InvalidFormDataError(f"Missing required inputs: {', '.join(missing_inputs)}") - - # Extra inputs are allowed; no further validation required. diff --git a/api/tests/unit_tests/libs/_human_input/test_form_service.py b/api/tests/unit_tests/libs/_human_input/test_form_service.py deleted file mode 100644 index 15e7d41e85..0000000000 --- a/api/tests/unit_tests/libs/_human_input/test_form_service.py +++ /dev/null @@ -1,326 +0,0 @@ -""" -Unit tests for FormService. -""" - -from datetime import datetime, timedelta - -import pytest - -from core.workflow.nodes.human_input.entities import ( - FormInput, - UserAction, -) -from core.workflow.nodes.human_input.enums import ( - FormInputType, - TimeoutUnit, -) -from libs.datetime_utils import naive_utc_now - -from .support import ( - FormAlreadySubmittedError, - FormExpiredError, - FormNotFoundError, - FormService, - FormSubmissionData, - InMemoryFormRepository, - InvalidFormDataError, -) - - -class TestFormService: - """Test FormService functionality.""" - - @pytest.fixture - def repository(self): - """Create in-memory repository for testing.""" - return InMemoryFormRepository() - - @pytest.fixture - def form_service(self, repository): - """Create FormService with in-memory repository.""" - return FormService(repository) - - @pytest.fixture - def sample_form_data(self): - """Create sample form data.""" - return { - "form_id": "form-123", - "workflow_run_id": "run-456", - "node_id": "node-789", - "tenant_id": "tenant-abc", - "app_id": "app-def", - "form_content": "# Test Form\n\nInput: {{#$output.input#}}", - "inputs": [FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="input", default=None)], - "user_actions": [UserAction(id="submit", title="Submit")], - "timeout": 1, - "timeout_unit": TimeoutUnit.HOUR, - "form_token": "token-xyz", - } - - def test_create_form(self, form_service, sample_form_data): - """Test form creation.""" - form = form_service.create_form(**sample_form_data) - - assert form.form_id == "form-123" - assert form.workflow_run_id == "run-456" - assert form.node_id == "node-789" - assert form.tenant_id == "tenant-abc" - assert form.app_id == "app-def" - assert form.form_token == "token-xyz" - assert form.timeout == 1 - assert form.timeout_unit == TimeoutUnit.HOUR - assert form.expires_at is not None - assert not form.is_expired - assert not form.is_submitted - - def test_get_form_by_id(self, form_service, sample_form_data): - """Test getting form by ID.""" - # Create form first - created_form = form_service.create_form(**sample_form_data) - - # Retrieve form - retrieved_form = form_service.get_form_by_id("form-123") - - assert retrieved_form.form_id == created_form.form_id - assert retrieved_form.workflow_run_id == created_form.workflow_run_id - - def test_get_form_by_id_not_found(self, form_service): - """Test getting non-existent form by ID.""" - with pytest.raises(FormNotFoundError) as exc_info: - form_service.get_form_by_id("non-existent-form") - - assert exc_info.value.error_code == "form_not_found" - - def test_get_form_by_token(self, form_service, sample_form_data): - """Test getting form by token.""" - # Create form first - created_form = form_service.create_form(**sample_form_data) - - # Retrieve form by token - retrieved_form = form_service.get_form_by_token("token-xyz") - - assert retrieved_form.form_id == created_form.form_id - assert retrieved_form.form_token == "token-xyz" - - def test_get_form_by_token_not_found(self, form_service): - """Test getting non-existent form by token.""" - with pytest.raises(FormNotFoundError) as exc_info: - form_service.get_form_by_token("non-existent-token") - - assert exc_info.value.error_code == "form_not_found" - - def test_get_form_definition_by_id(self, form_service, sample_form_data): - """Test getting form definition by ID.""" - # Create form first - form_service.create_form(**sample_form_data) - - # Get form definition - definition = form_service.get_form_definition("form-123", is_token=False) - - assert "form_content" in definition - assert "inputs" in definition - assert definition["form_content"] == "# Test Form\n\nInput: {{#$output.input#}}" - assert len(definition["inputs"]) == 1 - assert "site" not in definition # Should not include site info for ID-based access - - def test_get_form_definition_by_token(self, form_service, sample_form_data): - """Test getting form definition by token.""" - # Create form first - form_service.create_form(**sample_form_data) - - # Get form definition - definition = form_service.get_form_definition("token-xyz", is_token=True) - - assert "form_content" in definition - assert "inputs" in definition - assert "site" in definition # Should include site info for token-based access - - def test_get_form_definition_expired_form(self, form_service, sample_form_data): - """Test getting definition for expired form.""" - # Create form with past expiry - form_service.create_form(**sample_form_data) - - # Manually expire the form by modifying expiry time - form = form_service.get_form_by_id("form-123") - form.expires_at = datetime.utcnow() - timedelta(hours=1) - form_service.repository.save(form) - - # Should raise FormExpiredError - with pytest.raises(FormExpiredError) as exc_info: - form_service.get_form_definition("form-123", is_token=False) - - assert exc_info.value.error_code == "human_input_form_expired" - - def test_get_form_definition_submitted_form(self, form_service, sample_form_data): - """Test getting definition for already submitted form.""" - # Create form first - form_service.create_form(**sample_form_data) - - # Submit the form - submission_data = FormSubmissionData(form_id="form-123", inputs={"input": "test value"}, action="submit") - form_service.submit_form("form-123", submission_data, is_token=False) - - # Should raise FormAlreadySubmittedError - with pytest.raises(FormAlreadySubmittedError) as exc_info: - form_service.get_form_definition("form-123", is_token=False) - - assert exc_info.value.error_code == "human_input_form_submitted" - - def test_submit_form_success(self, form_service, sample_form_data): - """Test successful form submission.""" - # Create form first - form_service.create_form(**sample_form_data) - - # Submit form - submission_data = FormSubmissionData(form_id="form-123", inputs={"input": "test value"}, action="submit") - - # Should not raise any exception - form_service.submit_form("form-123", submission_data, is_token=False) - - # Verify form is marked as submitted - form = form_service.get_form_by_id("form-123") - assert form.is_submitted - assert form.submitted_data == {"input": "test value"} - assert form.submitted_action == "submit" - assert form.submitted_at is not None - - def test_submit_form_missing_inputs(self, form_service, sample_form_data): - """Test form submission with missing inputs.""" - # Create form first - form_service.create_form(**sample_form_data) - - # Submit form with missing required input - submission_data = FormSubmissionData( - form_id="form-123", - inputs={}, # Missing required "input" field - action="submit", - ) - - with pytest.raises(InvalidFormDataError) as exc_info: - form_service.submit_form("form-123", submission_data, is_token=False) - - assert "Missing required inputs" in exc_info.value.message - assert "input" in exc_info.value.message - - def test_submit_form_invalid_action(self, form_service, sample_form_data): - """Test form submission with invalid action.""" - # Create form first - form_service.create_form(**sample_form_data) - - # Submit form with invalid action - submission_data = FormSubmissionData( - form_id="form-123", - inputs={"input": "test value"}, - action="invalid_action", # Not in the allowed actions - ) - - with pytest.raises(InvalidFormDataError) as exc_info: - form_service.submit_form("form-123", submission_data, is_token=False) - - assert "Invalid action" in exc_info.value.message - assert "invalid_action" in exc_info.value.message - - def test_submit_form_expired(self, form_service, sample_form_data): - """Test submitting expired form.""" - # Create form first - form_service.create_form(**sample_form_data) - - # Manually expire the form - form = form_service.get_form_by_id("form-123") - form.expires_at = datetime.utcnow() - timedelta(hours=1) - form_service.repository.save(form) - - # Try to submit expired form - submission_data = FormSubmissionData(form_id="form-123", inputs={"input": "test value"}, action="submit") - - with pytest.raises(FormExpiredError) as exc_info: - form_service.submit_form("form-123", submission_data, is_token=False) - - assert exc_info.value.error_code == "human_input_form_expired" - - def test_submit_form_already_submitted(self, form_service, sample_form_data): - """Test submitting form that's already submitted.""" - # Create and submit form first - form_service.create_form(**sample_form_data) - - submission_data = FormSubmissionData(form_id="form-123", inputs={"input": "first submission"}, action="submit") - form_service.submit_form("form-123", submission_data, is_token=False) - - # Try to submit again - second_submission = FormSubmissionData( - form_id="form-123", inputs={"input": "second submission"}, action="submit" - ) - - with pytest.raises(FormAlreadySubmittedError) as exc_info: - form_service.submit_form("form-123", second_submission, is_token=False) - - assert exc_info.value.error_code == "human_input_form_submitted" - - def test_cleanup_expired_forms(self, form_service, sample_form_data): - """Test cleanup of expired forms.""" - # Create multiple forms - for i in range(3): - data = sample_form_data.copy() - data["form_id"] = f"form-{i}" - data["form_token"] = f"token-{i}" - form_service.create_form(**data) - - # Manually expire some forms - for i in range(2): # Expire first 2 forms - form = form_service.get_form_by_id(f"form-{i}") - form.expires_at = naive_utc_now() - timedelta(hours=1) - form_service.repository.save(form) - - # Clean up expired forms - cleaned_count = form_service.cleanup_expired_forms() - - assert cleaned_count == 2 - - # Verify expired forms are gone - with pytest.raises(FormNotFoundError): - form_service.get_form_by_id("form-0") - - with pytest.raises(FormNotFoundError): - form_service.get_form_by_id("form-1") - - # Verify non-expired form still exists - form = form_service.get_form_by_id("form-2") - assert form.form_id == "form-2" - - -class TestFormValidation: - """Test form validation logic.""" - - def test_validate_submission_with_extra_inputs(self): - """Test validation allows extra inputs that aren't defined in form.""" - repository = InMemoryFormRepository() - form_service = FormService(repository) - - # Create form with one input - form_data = { - "form_id": "form-123", - "workflow_run_id": "run-456", - "node_id": "node-789", - "tenant_id": "tenant-abc", - "app_id": "app-def", - "form_content": "Test form", - "inputs": [FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="required_input", default=None)], - "user_actions": [UserAction(id="submit", title="Submit")], - "timeout": 1, - "timeout_unit": TimeoutUnit.HOUR, - } - - form_service.create_form(**form_data) - - # Submit with extra input (should be allowed) - submission_data = FormSubmissionData( - form_id="form-123", - inputs={ - "required_input": "value1", - "extra_input": "value2", # Extra input not defined in form - }, - action="submit", - ) - - # Should not raise any exception - form_service.submit_form("form-123", submission_data, is_token=False) diff --git a/api/tests/unit_tests/libs/_human_input/test_models.py b/api/tests/unit_tests/libs/_human_input/test_models.py deleted file mode 100644 index 962eeb9e11..0000000000 --- a/api/tests/unit_tests/libs/_human_input/test_models.py +++ /dev/null @@ -1,232 +0,0 @@ -""" -Unit tests for human input form models. -""" - -from datetime import datetime, timedelta - -import pytest - -from core.workflow.nodes.human_input.entities import ( - FormInput, - UserAction, -) -from core.workflow.nodes.human_input.enums import ( - FormInputType, - TimeoutUnit, -) - -from .support import FormSubmissionData, FormSubmissionRequest, HumanInputForm - - -class TestHumanInputForm: - """Test HumanInputForm model.""" - - @pytest.fixture - def sample_form_data(self): - """Create sample form data.""" - return { - "form_id": "form-123", - "workflow_run_id": "run-456", - "node_id": "node-789", - "tenant_id": "tenant-abc", - "app_id": "app-def", - "form_content": "# Test Form\n\nInput: {{#$output.input#}}", - "inputs": [FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="input", default=None)], - "user_actions": [UserAction(id="submit", title="Submit")], - "timeout": 2, - "timeout_unit": TimeoutUnit.HOUR, - "form_token": "token-xyz", - } - - def test_form_creation(self, sample_form_data): - """Test form creation.""" - form = HumanInputForm(**sample_form_data) - - assert form.form_id == "form-123" - assert form.workflow_run_id == "run-456" - assert form.node_id == "node-789" - assert form.tenant_id == "tenant-abc" - assert form.app_id == "app-def" - assert form.form_token == "token-xyz" - assert form.timeout == 2 - assert form.timeout_unit == TimeoutUnit.HOUR - assert form.created_at is not None - assert form.expires_at is not None - assert form.submitted_at is None - assert form.submitted_data is None - assert form.submitted_action is None - - def test_form_expiry_calculation_hours(self, sample_form_data): - """Test form expiry calculation for hours.""" - form = HumanInputForm(**sample_form_data) - - # Should expire 2 hours after creation - expected_expiry = form.created_at + timedelta(hours=2) - assert abs((form.expires_at - expected_expiry).total_seconds()) < 1 # Within 1 second - - def test_form_expiry_calculation_days(self, sample_form_data): - """Test form expiry calculation for days.""" - sample_form_data["timeout"] = 3 - sample_form_data["timeout_unit"] = TimeoutUnit.DAY - - form = HumanInputForm(**sample_form_data) - - # Should expire 3 days after creation - expected_expiry = form.created_at + timedelta(days=3) - assert abs((form.expires_at - expected_expiry).total_seconds()) < 1 # Within 1 second - - def test_form_expiry_property_not_expired(self, sample_form_data): - """Test is_expired property for non-expired form.""" - form = HumanInputForm(**sample_form_data) - assert not form.is_expired - - def test_form_expiry_property_expired(self, sample_form_data): - """Test is_expired property for expired form.""" - # Create form with past expiry - past_time = datetime.utcnow() - timedelta(hours=1) - sample_form_data["created_at"] = past_time - - form = HumanInputForm(**sample_form_data) - # Manually set expiry to past time - form.expires_at = past_time - - assert form.is_expired - - def test_form_submission_property_not_submitted(self, sample_form_data): - """Test is_submitted property for non-submitted form.""" - form = HumanInputForm(**sample_form_data) - assert not form.is_submitted - - def test_form_submission_property_submitted(self, sample_form_data): - """Test is_submitted property for submitted form.""" - form = HumanInputForm(**sample_form_data) - form.submit({"input": "test value"}, "submit") - - assert form.is_submitted - assert form.submitted_at is not None - assert form.submitted_data == {"input": "test value"} - assert form.submitted_action == "submit" - - def test_form_submit_method(self, sample_form_data): - """Test form submit method.""" - form = HumanInputForm(**sample_form_data) - - submission_time_before = datetime.utcnow() - form.submit({"input": "test value"}, "submit") - submission_time_after = datetime.utcnow() - - assert form.is_submitted - assert form.submitted_data == {"input": "test value"} - assert form.submitted_action == "submit" - assert submission_time_before <= form.submitted_at <= submission_time_after - - def test_form_to_response_dict_without_site_info(self, sample_form_data): - """Test converting form to response dict without site info.""" - form = HumanInputForm(**sample_form_data) - - response = form.to_response_dict(include_site_info=False) - - assert "form_content" in response - assert "inputs" in response - assert "site" not in response - assert response["form_content"] == "# Test Form\n\nInput: {{#$output.input#}}" - assert len(response["inputs"]) == 1 - assert response["inputs"][0]["type"] == "text-input" - assert response["inputs"][0]["output_variable_name"] == "input" - - def test_form_to_response_dict_with_site_info(self, sample_form_data): - """Test converting form to response dict with site info.""" - form = HumanInputForm(**sample_form_data) - - response = form.to_response_dict(include_site_info=True) - - assert "form_content" in response - assert "inputs" in response - assert "site" in response - assert response["site"]["app_id"] == "app-def" - assert response["site"]["title"] == "Workflow Form" - - def test_form_without_web_app_token(self, sample_form_data): - """Test form creation without web app token.""" - sample_form_data["form_token"] = None - - form = HumanInputForm(**sample_form_data) - - assert form.form_token is None - assert form.form_id == "form-123" # Other fields should still work - - def test_form_with_explicit_timestamps(self): - """Test form creation with explicit timestamps.""" - created_time = datetime(2024, 1, 15, 10, 30, 0) - expires_time = datetime(2024, 1, 15, 12, 30, 0) - - form = HumanInputForm( - form_id="form-123", - workflow_run_id="run-456", - node_id="node-789", - tenant_id="tenant-abc", - app_id="app-def", - form_content="Test content", - inputs=[], - user_actions=[], - timeout=2, - timeout_unit=TimeoutUnit.HOUR, - created_at=created_time, - expires_at=expires_time, - ) - - assert form.created_at == created_time - assert form.expires_at == expires_time - - -class TestFormSubmissionData: - """Test FormSubmissionData model.""" - - def test_submission_data_creation(self): - """Test submission data creation.""" - submission_data = FormSubmissionData( - form_id="form-123", inputs={"field1": "value1", "field2": "value2"}, action="submit" - ) - - assert submission_data.form_id == "form-123" - assert submission_data.inputs == {"field1": "value1", "field2": "value2"} - assert submission_data.action == "submit" - assert submission_data.submitted_at is not None - - def test_submission_data_from_request(self): - """Test creating submission data from API request.""" - request = FormSubmissionRequest(inputs={"input": "test value"}, action="confirm") - - submission_data = FormSubmissionData.from_request("form-456", request) - - assert submission_data.form_id == "form-456" - assert submission_data.inputs == {"input": "test value"} - assert submission_data.action == "confirm" - assert submission_data.submitted_at is not None - - def test_submission_data_with_empty_inputs(self): - """Test submission data with empty inputs.""" - submission_data = FormSubmissionData(form_id="form-123", inputs={}, action="cancel") - - assert submission_data.inputs == {} - assert submission_data.action == "cancel" - - def test_submission_data_timestamps(self): - """Test submission data timestamp handling.""" - before_time = datetime.utcnow() - - submission_data = FormSubmissionData(form_id="form-123", inputs={"test": "value"}, action="submit") - - after_time = datetime.utcnow() - - assert before_time <= submission_data.submitted_at <= after_time - - def test_submission_data_with_explicit_timestamp(self): - """Test submission data with explicit timestamp.""" - specific_time = datetime(2024, 1, 15, 14, 30, 0) - - submission_data = FormSubmissionData( - form_id="form-123", inputs={"test": "value"}, action="submit", submitted_at=specific_time - ) - - assert submission_data.submitted_at == specific_time diff --git a/api/tests/unit_tests/libs/test_helper.py b/api/tests/unit_tests/libs/test_helper.py index 1a93dbbca1..de74eff82f 100644 --- a/api/tests/unit_tests/libs/test_helper.py +++ b/api/tests/unit_tests/libs/test_helper.py @@ -1,8 +1,6 @@ -from datetime import datetime - import pytest -from libs.helper import OptionalTimestampField, escape_like_pattern, extract_tenant_id +from libs.helper import escape_like_pattern, extract_tenant_id from models.account import Account from models.model import EndUser @@ -67,19 +65,6 @@ class TestExtractTenantId: extract_tenant_id(dict_user) -class TestOptionalTimestampField: - def test_format_returns_none_for_none(self): - field = OptionalTimestampField() - - assert field.format(None) is None - - def test_format_returns_unix_timestamp_for_datetime(self): - field = OptionalTimestampField() - value = datetime(2024, 1, 2, 3, 4, 5) - - assert field.format(value) == int(value.timestamp()) - - class TestEscapeLikePattern: """Test cases for the escape_like_pattern utility function.""" diff --git a/api/tests/unit_tests/libs/test_rate_limiter.py b/api/tests/unit_tests/libs/test_rate_limiter.py deleted file mode 100644 index 9d44b07b5e..0000000000 --- a/api/tests/unit_tests/libs/test_rate_limiter.py +++ /dev/null @@ -1,68 +0,0 @@ -from unittest.mock import MagicMock - -from libs import helper as helper_module - - -class _FakeRedis: - def __init__(self) -> None: - self._zsets: dict[str, dict[str, float]] = {} - self._expiry: dict[str, int] = {} - - def zadd(self, key: str, mapping: dict[str, float]) -> int: - zset = self._zsets.setdefault(key, {}) - for member, score in mapping.items(): - zset[str(member)] = float(score) - return len(mapping) - - def zremrangebyscore(self, key: str, min_score: str | float, max_score: str | float) -> int: - zset = self._zsets.get(key, {}) - min_value = float("-inf") if min_score == "-inf" else float(min_score) - max_value = float("inf") if max_score == "+inf" else float(max_score) - to_delete = [member for member, score in zset.items() if min_value <= score <= max_value] - for member in to_delete: - del zset[member] - return len(to_delete) - - def zcard(self, key: str) -> int: - return len(self._zsets.get(key, {})) - - def expire(self, key: str, ttl: int) -> bool: - self._expiry[key] = ttl - return True - - -def test_rate_limiter_counts_attempts_within_same_second(monkeypatch): - fake_redis = _FakeRedis() - monkeypatch.setattr(helper_module.time, "time", lambda: 1000) - - limiter = helper_module.RateLimiter( - prefix="test_rate_limit", - max_attempts=2, - time_window=60, - redis_client=fake_redis, - ) - - limiter.increment_rate_limit("203.0.113.10") - limiter.increment_rate_limit("203.0.113.10") - - assert limiter.is_rate_limited("203.0.113.10") is True - - -def test_rate_limiter_uses_injected_redis(monkeypatch): - redis_client = MagicMock() - redis_client.zcard.return_value = 1 - monkeypatch.setattr(helper_module.time, "time", lambda: 1000) - - limiter = helper_module.RateLimiter( - prefix="test_rate_limit", - max_attempts=1, - time_window=60, - redis_client=redis_client, - ) - - limiter.increment_rate_limit("203.0.113.10") - limiter.is_rate_limited("203.0.113.10") - - assert redis_client.zadd.called is True - assert redis_client.zremrangebyscore.called is True - assert redis_client.zcard.called is True diff --git a/api/tests/unit_tests/models/test_app_models.py b/api/tests/unit_tests/models/test_app_models.py index c6dfd41803..8be2eea121 100644 --- a/api/tests/unit_tests/models/test_app_models.py +++ b/api/tests/unit_tests/models/test_app_models.py @@ -1296,7 +1296,6 @@ class TestConversationStatusCount: assert result["success"] == 1 # One SUCCEEDED assert result["failed"] == 1 # One FAILED assert result["partial_success"] == 1 # One PARTIAL_SUCCEEDED - assert result["paused"] == 0 def test_status_count_app_id_filtering(self): """Test that status_count filters workflow runs by app_id for security.""" @@ -1351,7 +1350,6 @@ class TestConversationStatusCount: assert result["success"] == 0 assert result["failed"] == 0 assert result["partial_success"] == 0 - assert result["paused"] == 0 def test_status_count_handles_invalid_workflow_status(self): """Test that status_count gracefully handles invalid workflow status values.""" @@ -1406,57 +1404,3 @@ class TestConversationStatusCount: assert result["success"] == 0 assert result["failed"] == 0 assert result["partial_success"] == 0 - assert result["paused"] == 0 - - def test_status_count_paused(self): - """Test status_count includes paused workflow runs.""" - # Arrange - from core.workflow.enums import WorkflowExecutionStatus - - app_id = str(uuid4()) - conversation_id = str(uuid4()) - workflow_run_id = str(uuid4()) - - conversation = Conversation( - app_id=app_id, - mode=AppMode.CHAT, - name="Test Conversation", - status="normal", - from_source="api", - ) - conversation.id = conversation_id - - mock_messages = [ - MagicMock( - conversation_id=conversation_id, - workflow_run_id=workflow_run_id, - ), - ] - - mock_workflow_runs = [ - MagicMock( - id=workflow_run_id, - status=WorkflowExecutionStatus.PAUSED.value, - app_id=app_id, - ), - ] - - with patch("models.model.db.session.scalars") as mock_scalars: - - def mock_scalars_side_effect(query): - mock_result = MagicMock() - if "messages" in str(query): - mock_result.all.return_value = mock_messages - elif "workflow_runs" in str(query): - mock_result.all.return_value = mock_workflow_runs - else: - mock_result.all.return_value = [] - return mock_result - - mock_scalars.side_effect = mock_scalars_side_effect - - # Act - result = conversation.status_count - - # Assert - assert result["paused"] == 1 diff --git a/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py b/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py deleted file mode 100644 index ceb1406a4b..0000000000 --- a/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py +++ /dev/null @@ -1,40 +0,0 @@ -"""Unit tests for DifyAPISQLAlchemyWorkflowNodeExecutionRepository implementation.""" - -from unittest.mock import Mock - -from sqlalchemy.orm import Session, sessionmaker - -from repositories.sqlalchemy_api_workflow_node_execution_repository import ( - DifyAPISQLAlchemyWorkflowNodeExecutionRepository, -) - - -class TestDifyAPISQLAlchemyWorkflowNodeExecutionRepository: - def test_get_executions_by_workflow_run_keeps_paused_records(self): - mock_session = Mock(spec=Session) - execute_result = Mock() - execute_result.scalars.return_value.all.return_value = [] - mock_session.execute.return_value = execute_result - - session_maker = Mock(spec=sessionmaker) - context_manager = Mock() - context_manager.__enter__ = Mock(return_value=mock_session) - context_manager.__exit__ = Mock(return_value=None) - session_maker.return_value = context_manager - - repository = DifyAPISQLAlchemyWorkflowNodeExecutionRepository(session_maker) - - repository.get_executions_by_workflow_run( - tenant_id="tenant-123", - app_id="app-123", - workflow_run_id="workflow-run-123", - ) - - stmt = mock_session.execute.call_args[0][0] - where_clauses = list(getattr(stmt, "_where_criteria", []) or []) - where_strs = [str(clause).lower() for clause in where_clauses] - - assert any("tenant_id" in clause for clause in where_strs) - assert any("app_id" in clause for clause in where_strs) - assert any("workflow_run_id" in clause for clause in where_strs) - assert not any("paused" in clause for clause in where_strs) diff --git a/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py b/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py index 4caaa056ff..d443c4c9a5 100644 --- a/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py +++ b/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py @@ -1,6 +1,5 @@ """Unit tests for DifyAPISQLAlchemyWorkflowRunRepository implementation.""" -import secrets from datetime import UTC, datetime from unittest.mock import Mock, patch @@ -8,17 +7,12 @@ import pytest from sqlalchemy.dialects import postgresql from sqlalchemy.orm import Session, sessionmaker -from core.workflow.entities.pause_reason import HumanInputRequired, PauseReasonType from core.workflow.enums import WorkflowExecutionStatus -from core.workflow.nodes.human_input.entities import FormDefinition, FormInput, UserAction -from core.workflow.nodes.human_input.enums import FormInputType, HumanInputFormStatus -from models.human_input import BackstageRecipientPayload, HumanInputForm, HumanInputFormRecipient, RecipientType from models.workflow import WorkflowPause as WorkflowPauseModel -from models.workflow import WorkflowPauseReason, WorkflowRun +from models.workflow import WorkflowRun from repositories.entities.workflow_pause import WorkflowPauseEntity from repositories.sqlalchemy_api_workflow_run_repository import ( DifyAPISQLAlchemyWorkflowRunRepository, - _build_human_input_required_reason, _PrivateWorkflowPauseEntity, _WorkflowRunError, ) @@ -211,11 +205,11 @@ class TestCreateWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository): ): """Test workflow pause creation when workflow not in RUNNING status.""" # Arrange - sample_workflow_run.status = WorkflowExecutionStatus.SUCCEEDED + sample_workflow_run.status = WorkflowExecutionStatus.PAUSED mock_session.get.return_value = sample_workflow_run # Act & Assert - with pytest.raises(_WorkflowRunError, match="Only WorkflowRun with RUNNING or PAUSED status can be paused"): + with pytest.raises(_WorkflowRunError, match="Only WorkflowRun with RUNNING status can be paused"): repository.create_workflow_pause( workflow_run_id="workflow-run-123", state_owner_user_id="user-123", @@ -301,7 +295,6 @@ class TestResumeWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository): sample_workflow_pause.resumed_at = None mock_session.scalar.return_value = sample_workflow_run - mock_session.scalars.return_value.all.return_value = [] with patch("repositories.sqlalchemy_api_workflow_run_repository.naive_utc_now") as mock_now: mock_now.return_value = datetime.now(UTC) @@ -462,53 +455,3 @@ class TestPrivateWorkflowPauseEntity(TestDifyAPISQLAlchemyWorkflowRunRepository) assert result1 == expected_state assert result2 == expected_state mock_storage.load.assert_called_once() # Only called once due to caching - - -class TestBuildHumanInputRequiredReason: - def test_prefers_backstage_token_when_available(self): - expiration_time = datetime.now(UTC) - form_definition = FormDefinition( - form_content="content", - inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")], - user_actions=[UserAction(id="approve", title="Approve")], - rendered_content="rendered", - expiration_time=expiration_time, - default_values={"name": "Alice"}, - node_title="Ask Name", - display_in_ui=True, - ) - form_model = HumanInputForm( - id="form-1", - tenant_id="tenant-1", - app_id="app-1", - workflow_run_id="run-1", - node_id="node-1", - form_definition=form_definition.model_dump_json(), - rendered_content="rendered", - status=HumanInputFormStatus.WAITING, - expiration_time=expiration_time, - ) - reason_model = WorkflowPauseReason( - pause_id="pause-1", - type_=PauseReasonType.HUMAN_INPUT_REQUIRED, - form_id="form-1", - node_id="node-1", - message="", - ) - access_token = secrets.token_urlsafe(8) - backstage_recipient = HumanInputFormRecipient( - form_id="form-1", - delivery_id="delivery-1", - recipient_type=RecipientType.BACKSTAGE, - recipient_payload=BackstageRecipientPayload().model_dump_json(), - access_token=access_token, - ) - - reason = _build_human_input_required_reason(reason_model, form_model, [backstage_recipient]) - - assert isinstance(reason, HumanInputRequired) - assert reason.form_token == access_token - assert reason.node_title == "Ask Name" - assert reason.form_content == "content" - assert reason.inputs[0].output_variable_name == "name" - assert reason.actions[0].id == "approve" diff --git a/api/tests/unit_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py b/api/tests/unit_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py deleted file mode 100644 index f5428b46ff..0000000000 --- a/api/tests/unit_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py +++ /dev/null @@ -1,180 +0,0 @@ -from __future__ import annotations - -from collections.abc import Sequence -from dataclasses import dataclass -from datetime import UTC, datetime, timedelta - -from core.entities.execution_extra_content import HumanInputContent as HumanInputContentDomain -from core.entities.execution_extra_content import HumanInputFormSubmissionData -from core.workflow.nodes.human_input.entities import ( - FormDefinition, - UserAction, -) -from core.workflow.nodes.human_input.enums import HumanInputFormStatus -from models.execution_extra_content import HumanInputContent as HumanInputContentModel -from models.human_input import ConsoleRecipientPayload, HumanInputForm, HumanInputFormRecipient, RecipientType -from repositories.sqlalchemy_execution_extra_content_repository import SQLAlchemyExecutionExtraContentRepository - - -class _FakeScalarResult: - def __init__(self, values: Sequence[HumanInputContentModel]): - self._values = list(values) - - def all(self) -> list[HumanInputContentModel]: - return list(self._values) - - -class _FakeSession: - def __init__(self, values: Sequence[Sequence[object]]): - self._values = list(values) - - def scalars(self, _stmt): - if not self._values: - return _FakeScalarResult([]) - return _FakeScalarResult(self._values.pop(0)) - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - -@dataclass -class _FakeSessionMaker: - session: _FakeSession - - def __call__(self) -> _FakeSession: - return self.session - - -def _build_form(action_id: str, action_title: str, rendered_content: str) -> HumanInputForm: - expiration_time = datetime.now(UTC) + timedelta(days=1) - definition = FormDefinition( - form_content="content", - inputs=[], - user_actions=[UserAction(id=action_id, title=action_title)], - rendered_content="rendered", - expiration_time=expiration_time, - node_title="Approval", - display_in_ui=True, - ) - form = HumanInputForm( - id=f"form-{action_id}", - tenant_id="tenant-id", - app_id="app-id", - workflow_run_id="workflow-run", - node_id="node-id", - form_definition=definition.model_dump_json(), - rendered_content=rendered_content, - status=HumanInputFormStatus.SUBMITTED, - expiration_time=expiration_time, - ) - form.selected_action_id = action_id - return form - - -def _build_content(message_id: str, action_id: str, action_title: str) -> HumanInputContentModel: - form = _build_form( - action_id=action_id, - action_title=action_title, - rendered_content=f"Rendered {action_title}", - ) - content = HumanInputContentModel( - id=f"content-{message_id}", - form_id=form.id, - message_id=message_id, - workflow_run_id=form.workflow_run_id, - ) - content.form = form - return content - - -def test_get_by_message_ids_groups_contents_by_message() -> None: - message_ids = ["msg-1", "msg-2"] - contents = [_build_content("msg-1", "approve", "Approve")] - repository = SQLAlchemyExecutionExtraContentRepository( - session_maker=_FakeSessionMaker(session=_FakeSession(values=[contents, []])) - ) - - result = repository.get_by_message_ids(message_ids) - - assert len(result) == 2 - assert [content.model_dump(mode="json", exclude_none=True) for content in result[0]] == [ - HumanInputContentDomain( - workflow_run_id="workflow-run", - submitted=True, - form_submission_data=HumanInputFormSubmissionData( - node_id="node-id", - node_title="Approval", - rendered_content="Rendered Approve", - action_id="approve", - action_text="Approve", - ), - ).model_dump(mode="json", exclude_none=True) - ] - assert result[1] == [] - - -def test_get_by_message_ids_returns_unsubmitted_form_definition() -> None: - expiration_time = datetime.now(UTC) + timedelta(days=1) - definition = FormDefinition( - form_content="content", - inputs=[], - user_actions=[UserAction(id="approve", title="Approve")], - rendered_content="rendered", - expiration_time=expiration_time, - default_values={"name": "John"}, - node_title="Approval", - display_in_ui=True, - ) - form = HumanInputForm( - id="form-1", - tenant_id="tenant-id", - app_id="app-id", - workflow_run_id="workflow-run", - node_id="node-id", - form_definition=definition.model_dump_json(), - rendered_content="Rendered block", - status=HumanInputFormStatus.WAITING, - expiration_time=expiration_time, - ) - content = HumanInputContentModel( - id="content-msg-1", - form_id=form.id, - message_id="msg-1", - workflow_run_id=form.workflow_run_id, - ) - content.form = form - - recipient = HumanInputFormRecipient( - form_id=form.id, - delivery_id="delivery-1", - recipient_type=RecipientType.CONSOLE, - recipient_payload=ConsoleRecipientPayload(account_id=None).model_dump_json(), - access_token="token-1", - ) - - repository = SQLAlchemyExecutionExtraContentRepository( - session_maker=_FakeSessionMaker(session=_FakeSession(values=[[content], [recipient]])) - ) - - result = repository.get_by_message_ids(["msg-1"]) - - assert len(result) == 1 - assert len(result[0]) == 1 - domain_content = result[0][0] - assert domain_content.submitted is False - assert domain_content.workflow_run_id == "workflow-run" - assert domain_content.form_definition is not None - assert domain_content.form_definition.expiration_time == int(form.expiration_time.timestamp()) - assert domain_content.form_definition is not None - form_definition = domain_content.form_definition - assert form_definition.form_id == "form-1" - assert form_definition.node_id == "node-id" - assert form_definition.node_title == "Approval" - assert form_definition.form_content == "Rendered block" - assert form_definition.display_in_ui is True - assert form_definition.form_token == "token-1" - assert form_definition.resolved_default_values == {"name": "John"} - assert form_definition.expiration_time == int(form.expiration_time.timestamp()) diff --git a/api/tests/unit_tests/services/test_conversation_service.py b/api/tests/unit_tests/services/test_conversation_service.py index eca1d44d23..81135dbbdf 100644 --- a/api/tests/unit_tests/services/test_conversation_service.py +++ b/api/tests/unit_tests/services/test_conversation_service.py @@ -508,12 +508,9 @@ class TestConversationServiceMessageCreation: within conversations. """ - @patch("services.message_service._create_execution_extra_content_repository") @patch("services.message_service.db.session") @patch("services.message_service.ConversationService.get_conversation") - def test_pagination_by_first_id_without_first_id( - self, mock_get_conversation, mock_db_session, mock_create_extra_repo - ): + def test_pagination_by_first_id_without_first_id(self, mock_get_conversation, mock_db_session): """ Test message pagination without specifying first_id. @@ -543,9 +540,6 @@ class TestConversationServiceMessageCreation: mock_query.order_by.return_value = mock_query # ORDER BY returns self for chaining mock_query.limit.return_value = mock_query # LIMIT returns self for chaining mock_query.all.return_value = messages # Final .all() returns the messages - mock_repository = MagicMock() - mock_repository.get_by_message_ids.return_value = [[] for _ in messages] - mock_create_extra_repo.return_value = mock_repository # Act - Call the pagination method without first_id result = MessageService.pagination_by_first_id( @@ -562,10 +556,9 @@ class TestConversationServiceMessageCreation: # Verify conversation was looked up with correct parameters mock_get_conversation.assert_called_once_with(app_model=app_model, user=user, conversation_id=conversation.id) - @patch("services.message_service._create_execution_extra_content_repository") @patch("services.message_service.db.session") @patch("services.message_service.ConversationService.get_conversation") - def test_pagination_by_first_id_with_first_id(self, mock_get_conversation, mock_db_session, mock_create_extra_repo): + def test_pagination_by_first_id_with_first_id(self, mock_get_conversation, mock_db_session): """ Test message pagination with first_id specified. @@ -597,9 +590,6 @@ class TestConversationServiceMessageCreation: mock_query.limit.return_value = mock_query # LIMIT returns self for chaining mock_query.first.return_value = first_message # First message returned mock_query.all.return_value = messages # Remaining messages returned - mock_repository = MagicMock() - mock_repository.get_by_message_ids.return_value = [[] for _ in messages] - mock_create_extra_repo.return_value = mock_repository # Act - Call the pagination method with first_id result = MessageService.pagination_by_first_id( @@ -694,10 +684,9 @@ class TestConversationServiceMessageCreation: assert result.data == [] assert result.has_more is False - @patch("services.message_service._create_execution_extra_content_repository") @patch("services.message_service.db.session") @patch("services.message_service.ConversationService.get_conversation") - def test_pagination_with_has_more_flag(self, mock_get_conversation, mock_db_session, mock_create_extra_repo): + def test_pagination_with_has_more_flag(self, mock_get_conversation, mock_db_session): """ Test that has_more flag is correctly set when there are more messages. @@ -727,9 +716,6 @@ class TestConversationServiceMessageCreation: mock_query.order_by.return_value = mock_query # ORDER BY returns self for chaining mock_query.limit.return_value = mock_query # LIMIT returns self for chaining mock_query.all.return_value = messages # Final .all() returns the messages - mock_repository = MagicMock() - mock_repository.get_by_message_ids.return_value = [[] for _ in messages] - mock_create_extra_repo.return_value = mock_repository # Act result = MessageService.pagination_by_first_id( @@ -744,10 +730,9 @@ class TestConversationServiceMessageCreation: assert len(result.data) == limit # Extra message should be removed assert result.has_more is True # Flag should be set - @patch("services.message_service._create_execution_extra_content_repository") @patch("services.message_service.db.session") @patch("services.message_service.ConversationService.get_conversation") - def test_pagination_with_ascending_order(self, mock_get_conversation, mock_db_session, mock_create_extra_repo): + def test_pagination_with_ascending_order(self, mock_get_conversation, mock_db_session): """ Test message pagination with ascending order. @@ -776,9 +761,6 @@ class TestConversationServiceMessageCreation: mock_query.order_by.return_value = mock_query # ORDER BY returns self for chaining mock_query.limit.return_value = mock_query # LIMIT returns self for chaining mock_query.all.return_value = messages # Final .all() returns the messages - mock_repository = MagicMock() - mock_repository.get_by_message_ids.return_value = [[] for _ in messages] - mock_create_extra_repo.return_value = mock_repository # Act result = MessageService.pagination_by_first_id( diff --git a/api/tests/unit_tests/services/test_feature_service_human_input_email_delivery.py b/api/tests/unit_tests/services/test_feature_service_human_input_email_delivery.py deleted file mode 100644 index ab141a7b2d..0000000000 --- a/api/tests/unit_tests/services/test_feature_service_human_input_email_delivery.py +++ /dev/null @@ -1,104 +0,0 @@ -from dataclasses import dataclass - -import pytest - -from enums.cloud_plan import CloudPlan -from services import feature_service as feature_service_module -from services.feature_service import FeatureModel, FeatureService - - -@dataclass(frozen=True) -class HumanInputEmailDeliveryCase: - name: str - enterprise_enabled: bool - billing_enabled: bool - tenant_id: str | None - billing_feature_enabled: bool - plan: str - expected: bool - - -CASES = [ - HumanInputEmailDeliveryCase( - name="enterprise_enabled", - enterprise_enabled=True, - billing_enabled=True, - tenant_id=None, - billing_feature_enabled=False, - plan=CloudPlan.SANDBOX, - expected=True, - ), - HumanInputEmailDeliveryCase( - name="billing_disabled", - enterprise_enabled=False, - billing_enabled=False, - tenant_id=None, - billing_feature_enabled=False, - plan=CloudPlan.SANDBOX, - expected=True, - ), - HumanInputEmailDeliveryCase( - name="billing_enabled_requires_tenant", - enterprise_enabled=False, - billing_enabled=True, - tenant_id=None, - billing_feature_enabled=True, - plan=CloudPlan.PROFESSIONAL, - expected=False, - ), - HumanInputEmailDeliveryCase( - name="billing_feature_off", - enterprise_enabled=False, - billing_enabled=True, - tenant_id="tenant-1", - billing_feature_enabled=False, - plan=CloudPlan.PROFESSIONAL, - expected=False, - ), - HumanInputEmailDeliveryCase( - name="professional_plan", - enterprise_enabled=False, - billing_enabled=True, - tenant_id="tenant-1", - billing_feature_enabled=True, - plan=CloudPlan.PROFESSIONAL, - expected=True, - ), - HumanInputEmailDeliveryCase( - name="team_plan", - enterprise_enabled=False, - billing_enabled=True, - tenant_id="tenant-1", - billing_feature_enabled=True, - plan=CloudPlan.TEAM, - expected=True, - ), - HumanInputEmailDeliveryCase( - name="sandbox_plan", - enterprise_enabled=False, - billing_enabled=True, - tenant_id="tenant-1", - billing_feature_enabled=True, - plan=CloudPlan.SANDBOX, - expected=False, - ), -] - - -@pytest.mark.parametrize("case", CASES, ids=lambda case: case.name) -def test_resolve_human_input_email_delivery_enabled_matrix( - monkeypatch: pytest.MonkeyPatch, - case: HumanInputEmailDeliveryCase, -): - monkeypatch.setattr(feature_service_module.dify_config, "ENTERPRISE_ENABLED", case.enterprise_enabled) - monkeypatch.setattr(feature_service_module.dify_config, "BILLING_ENABLED", case.billing_enabled) - features = FeatureModel() - features.billing.enabled = case.billing_feature_enabled - features.billing.subscription.plan = case.plan - - result = FeatureService._resolve_human_input_email_delivery_enabled( - features=features, - tenant_id=case.tenant_id, - ) - - assert result is case.expected diff --git a/api/tests/unit_tests/services/test_human_input_delivery_test_service.py b/api/tests/unit_tests/services/test_human_input_delivery_test_service.py deleted file mode 100644 index e0d6ad1b39..0000000000 --- a/api/tests/unit_tests/services/test_human_input_delivery_test_service.py +++ /dev/null @@ -1,97 +0,0 @@ -from types import SimpleNamespace - -import pytest - -from core.workflow.nodes.human_input.entities import ( - EmailDeliveryConfig, - EmailDeliveryMethod, - EmailRecipients, - ExternalRecipient, -) -from core.workflow.runtime import VariablePool -from services import human_input_delivery_test_service as service_module -from services.human_input_delivery_test_service import ( - DeliveryTestContext, - DeliveryTestError, - EmailDeliveryTestHandler, -) - - -def _make_email_method() -> EmailDeliveryMethod: - return EmailDeliveryMethod( - config=EmailDeliveryConfig( - recipients=EmailRecipients( - whole_workspace=False, - items=[ExternalRecipient(email="tester@example.com")], - ), - subject="Test subject", - body="Test body", - ) - ) - - -def test_email_delivery_test_handler_rejects_when_feature_disabled(monkeypatch: pytest.MonkeyPatch): - monkeypatch.setattr( - service_module.FeatureService, - "get_features", - lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=False), - ) - - handler = EmailDeliveryTestHandler(session_factory=object()) - context = DeliveryTestContext( - tenant_id="tenant-1", - app_id="app-1", - node_id="node-1", - node_title="Human Input", - rendered_content="content", - ) - method = _make_email_method() - - with pytest.raises(DeliveryTestError, match="Email delivery is not available"): - handler.send_test(context=context, method=method) - - -def test_email_delivery_test_handler_replaces_body_variables(monkeypatch: pytest.MonkeyPatch): - class DummyMail: - def __init__(self): - self.sent: list[dict[str, str]] = [] - - def is_inited(self) -> bool: - return True - - def send(self, *, to: str, subject: str, html: str): - self.sent.append({"to": to, "subject": subject, "html": html}) - - mail = DummyMail() - monkeypatch.setattr(service_module, "mail", mail) - monkeypatch.setattr(service_module, "render_email_template", lambda template, _substitutions: template) - monkeypatch.setattr( - service_module.FeatureService, - "get_features", - lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=True), - ) - - handler = EmailDeliveryTestHandler(session_factory=object()) - handler._resolve_recipients = lambda **_kwargs: ["tester@example.com"] # type: ignore[assignment] - - method = EmailDeliveryMethod( - config=EmailDeliveryConfig( - recipients=EmailRecipients(whole_workspace=False, items=[ExternalRecipient(email="tester@example.com")]), - subject="Subject", - body="Value {{#node1.value#}}", - ) - ) - variable_pool = VariablePool() - variable_pool.add(["node1", "value"], "OK") - context = DeliveryTestContext( - tenant_id="tenant-1", - app_id="app-1", - node_id="node-1", - node_title="Human Input", - rendered_content="content", - variable_pool=variable_pool, - ) - - handler.send_test(context=context, method=method) - - assert mail.sent[0]["html"] == "Value OK" diff --git a/api/tests/unit_tests/services/test_human_input_service.py b/api/tests/unit_tests/services/test_human_input_service.py deleted file mode 100644 index 72e19447bd..0000000000 --- a/api/tests/unit_tests/services/test_human_input_service.py +++ /dev/null @@ -1,290 +0,0 @@ -import dataclasses -from datetime import datetime, timedelta -from unittest.mock import MagicMock - -import pytest - -import services.human_input_service as human_input_service_module -from core.repositories.human_input_repository import ( - HumanInputFormRecord, - HumanInputFormSubmissionRepository, -) -from core.workflow.nodes.human_input.entities import ( - FormDefinition, - FormInput, - UserAction, -) -from core.workflow.nodes.human_input.enums import FormInputType, HumanInputFormKind, HumanInputFormStatus -from models.human_input import RecipientType -from services.human_input_service import Form, FormExpiredError, HumanInputService, InvalidFormDataError -from tasks.app_generate.workflow_execute_task import WORKFLOW_BASED_APP_EXECUTION_QUEUE - - -@pytest.fixture -def mock_session_factory(): - session = MagicMock() - session_cm = MagicMock() - session_cm.__enter__.return_value = session - session_cm.__exit__.return_value = None - - factory = MagicMock() - factory.return_value = session_cm - return factory, session - - -@pytest.fixture -def sample_form_record(): - return HumanInputFormRecord( - form_id="form-id", - workflow_run_id="workflow-run-id", - node_id="node-id", - tenant_id="tenant-id", - app_id="app-id", - form_kind=HumanInputFormKind.RUNTIME, - definition=FormDefinition( - form_content="hello", - inputs=[], - user_actions=[UserAction(id="submit", title="Submit")], - rendered_content="

hello

", - expiration_time=datetime.utcnow() + timedelta(hours=1), - ), - rendered_content="

hello

", - created_at=datetime.utcnow(), - expiration_time=datetime.utcnow() + timedelta(hours=1), - status=HumanInputFormStatus.WAITING, - selected_action_id=None, - submitted_data=None, - submitted_at=None, - submission_user_id=None, - submission_end_user_id=None, - completed_by_recipient_id=None, - recipient_id="recipient-id", - recipient_type=RecipientType.STANDALONE_WEB_APP, - access_token="token", - ) - - -def test_enqueue_resume_dispatches_task_for_workflow(mocker, mock_session_factory): - session_factory, session = mock_session_factory - service = HumanInputService(session_factory) - - workflow_run = MagicMock() - workflow_run.app_id = "app-id" - - workflow_run_repo = MagicMock() - workflow_run_repo.get_workflow_run_by_id_without_tenant.return_value = workflow_run - mocker.patch( - "services.human_input_service.DifyAPIRepositoryFactory.create_api_workflow_run_repository", - return_value=workflow_run_repo, - ) - - app = MagicMock() - app.mode = "workflow" - session.execute.return_value.scalar_one_or_none.return_value = app - - resume_task = mocker.patch("services.human_input_service.resume_app_execution") - - service.enqueue_resume("workflow-run-id") - - resume_task.apply_async.assert_called_once() - call_kwargs = resume_task.apply_async.call_args.kwargs - assert call_kwargs["queue"] == WORKFLOW_BASED_APP_EXECUTION_QUEUE - assert call_kwargs["kwargs"]["payload"]["workflow_run_id"] == "workflow-run-id" - - -def test_ensure_form_active_respects_global_timeout(monkeypatch, sample_form_record, mock_session_factory): - session_factory, _ = mock_session_factory - service = HumanInputService(session_factory) - expired_record = dataclasses.replace( - sample_form_record, - created_at=datetime.utcnow() - timedelta(hours=2), - expiration_time=datetime.utcnow() + timedelta(hours=2), - ) - monkeypatch.setattr(human_input_service_module.dify_config, "HITL_GLOBAL_TIMEOUT_SECONDS", 3600) - - with pytest.raises(FormExpiredError): - service.ensure_form_active(Form(expired_record)) - - -def test_enqueue_resume_dispatches_task_for_advanced_chat(mocker, mock_session_factory): - session_factory, session = mock_session_factory - service = HumanInputService(session_factory) - - workflow_run = MagicMock() - workflow_run.app_id = "app-id" - - workflow_run_repo = MagicMock() - workflow_run_repo.get_workflow_run_by_id_without_tenant.return_value = workflow_run - mocker.patch( - "services.human_input_service.DifyAPIRepositoryFactory.create_api_workflow_run_repository", - return_value=workflow_run_repo, - ) - - app = MagicMock() - app.mode = "advanced-chat" - session.execute.return_value.scalar_one_or_none.return_value = app - - resume_task = mocker.patch("services.human_input_service.resume_app_execution") - - service.enqueue_resume("workflow-run-id") - - resume_task.apply_async.assert_called_once() - call_kwargs = resume_task.apply_async.call_args.kwargs - assert call_kwargs["queue"] == WORKFLOW_BASED_APP_EXECUTION_QUEUE - assert call_kwargs["kwargs"]["payload"]["workflow_run_id"] == "workflow-run-id" - - -def test_enqueue_resume_skips_unsupported_app_mode(mocker, mock_session_factory): - session_factory, session = mock_session_factory - service = HumanInputService(session_factory) - - workflow_run = MagicMock() - workflow_run.app_id = "app-id" - - workflow_run_repo = MagicMock() - workflow_run_repo.get_workflow_run_by_id_without_tenant.return_value = workflow_run - mocker.patch( - "services.human_input_service.DifyAPIRepositoryFactory.create_api_workflow_run_repository", - return_value=workflow_run_repo, - ) - - app = MagicMock() - app.mode = "completion" - session.execute.return_value.scalar_one_or_none.return_value = app - - resume_task = mocker.patch("services.human_input_service.resume_app_execution") - - service.enqueue_resume("workflow-run-id") - - resume_task.apply_async.assert_not_called() - - -def test_get_form_definition_by_token_for_console_uses_repository(sample_form_record, mock_session_factory): - session_factory, _ = mock_session_factory - repo = MagicMock(spec=HumanInputFormSubmissionRepository) - console_record = dataclasses.replace(sample_form_record, recipient_type=RecipientType.CONSOLE) - repo.get_by_token.return_value = console_record - - service = HumanInputService(session_factory, form_repository=repo) - form = service.get_form_definition_by_token_for_console("token") - - repo.get_by_token.assert_called_once_with("token") - assert form is not None - assert form.get_definition() == console_record.definition - - -def test_submit_form_by_token_calls_repository_and_enqueue(sample_form_record, mock_session_factory, mocker): - session_factory, _ = mock_session_factory - repo = MagicMock(spec=HumanInputFormSubmissionRepository) - repo.get_by_token.return_value = sample_form_record - repo.mark_submitted.return_value = sample_form_record - service = HumanInputService(session_factory, form_repository=repo) - enqueue_spy = mocker.patch.object(service, "enqueue_resume") - - service.submit_form_by_token( - recipient_type=RecipientType.STANDALONE_WEB_APP, - form_token="token", - selected_action_id="submit", - form_data={"field": "value"}, - submission_end_user_id="end-user-id", - ) - - repo.get_by_token.assert_called_once_with("token") - repo.mark_submitted.assert_called_once() - call_kwargs = repo.mark_submitted.call_args.kwargs - assert call_kwargs["form_id"] == sample_form_record.form_id - assert call_kwargs["recipient_id"] == sample_form_record.recipient_id - assert call_kwargs["selected_action_id"] == "submit" - assert call_kwargs["form_data"] == {"field": "value"} - assert call_kwargs["submission_end_user_id"] == "end-user-id" - enqueue_spy.assert_called_once_with(sample_form_record.workflow_run_id) - - -def test_submit_form_by_token_skips_enqueue_for_delivery_test(sample_form_record, mock_session_factory, mocker): - session_factory, _ = mock_session_factory - repo = MagicMock(spec=HumanInputFormSubmissionRepository) - test_record = dataclasses.replace( - sample_form_record, - form_kind=HumanInputFormKind.DELIVERY_TEST, - workflow_run_id=None, - ) - repo.get_by_token.return_value = test_record - repo.mark_submitted.return_value = test_record - service = HumanInputService(session_factory, form_repository=repo) - enqueue_spy = mocker.patch.object(service, "enqueue_resume") - - service.submit_form_by_token( - recipient_type=RecipientType.STANDALONE_WEB_APP, - form_token="token", - selected_action_id="submit", - form_data={"field": "value"}, - ) - - enqueue_spy.assert_not_called() - - -def test_submit_form_by_token_passes_submission_user_id(sample_form_record, mock_session_factory, mocker): - session_factory, _ = mock_session_factory - repo = MagicMock(spec=HumanInputFormSubmissionRepository) - repo.get_by_token.return_value = sample_form_record - repo.mark_submitted.return_value = sample_form_record - service = HumanInputService(session_factory, form_repository=repo) - enqueue_spy = mocker.patch.object(service, "enqueue_resume") - - service.submit_form_by_token( - recipient_type=RecipientType.STANDALONE_WEB_APP, - form_token="token", - selected_action_id="submit", - form_data={"field": "value"}, - submission_user_id="account-id", - ) - - call_kwargs = repo.mark_submitted.call_args.kwargs - assert call_kwargs["submission_user_id"] == "account-id" - assert call_kwargs["submission_end_user_id"] is None - enqueue_spy.assert_called_once_with(sample_form_record.workflow_run_id) - - -def test_submit_form_by_token_invalid_action(sample_form_record, mock_session_factory): - session_factory, _ = mock_session_factory - repo = MagicMock(spec=HumanInputFormSubmissionRepository) - repo.get_by_token.return_value = dataclasses.replace(sample_form_record) - service = HumanInputService(session_factory, form_repository=repo) - - with pytest.raises(InvalidFormDataError) as exc_info: - service.submit_form_by_token( - recipient_type=RecipientType.STANDALONE_WEB_APP, - form_token="token", - selected_action_id="invalid", - form_data={}, - ) - - assert "Invalid action" in str(exc_info.value) - repo.mark_submitted.assert_not_called() - - -def test_submit_form_by_token_missing_inputs(sample_form_record, mock_session_factory): - session_factory, _ = mock_session_factory - repo = MagicMock(spec=HumanInputFormSubmissionRepository) - - definition_with_input = FormDefinition( - form_content="hello", - inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="content")], - user_actions=sample_form_record.definition.user_actions, - rendered_content="

hello

", - expiration_time=sample_form_record.expiration_time, - ) - form_with_input = dataclasses.replace(sample_form_record, definition=definition_with_input) - repo.get_by_token.return_value = form_with_input - service = HumanInputService(session_factory, form_repository=repo) - - with pytest.raises(InvalidFormDataError) as exc_info: - service.submit_form_by_token( - recipient_type=RecipientType.STANDALONE_WEB_APP, - form_token="token", - selected_action_id="submit", - form_data={}, - ) - - assert "Missing required inputs" in str(exc_info.value) - repo.mark_submitted.assert_not_called() diff --git a/api/tests/unit_tests/services/test_message_service_extra_contents.py b/api/tests/unit_tests/services/test_message_service_extra_contents.py deleted file mode 100644 index 3c8e301caa..0000000000 --- a/api/tests/unit_tests/services/test_message_service_extra_contents.py +++ /dev/null @@ -1,61 +0,0 @@ -from __future__ import annotations - -import pytest - -from core.entities.execution_extra_content import HumanInputContent, HumanInputFormSubmissionData -from services import message_service - - -class _FakeMessage: - def __init__(self, message_id: str): - self.id = message_id - self.extra_contents = None - - def set_extra_contents(self, contents): - self.extra_contents = contents - - -def test_attach_message_extra_contents_assigns_serialized_payload(monkeypatch: pytest.MonkeyPatch) -> None: - messages = [_FakeMessage("msg-1"), _FakeMessage("msg-2")] - repo = type( - "Repo", - (), - { - "get_by_message_ids": lambda _self, message_ids: [ - [ - HumanInputContent( - workflow_run_id="workflow-run-1", - submitted=True, - form_submission_data=HumanInputFormSubmissionData( - node_id="node-1", - node_title="Approval", - rendered_content="Rendered", - action_id="approve", - action_text="Approve", - ), - ) - ], - [], - ] - }, - )() - - monkeypatch.setattr(message_service, "_create_execution_extra_content_repository", lambda: repo) - - message_service.attach_message_extra_contents(messages) - - assert messages[0].extra_contents == [ - { - "type": "human_input", - "workflow_run_id": "workflow-run-1", - "submitted": True, - "form_submission_data": { - "node_id": "node-1", - "node_title": "Approval", - "rendered_content": "Rendered", - "action_id": "approve", - "action_text": "Approve", - }, - } - ] - assert messages[1].extra_contents == [] diff --git a/api/tests/unit_tests/services/test_workflow_run_service_pause.py b/api/tests/unit_tests/services/test_workflow_run_service_pause.py index ded141f01a..f45a72927e 100644 --- a/api/tests/unit_tests/services/test_workflow_run_service_pause.py +++ b/api/tests/unit_tests/services/test_workflow_run_service_pause.py @@ -35,6 +35,7 @@ class TestDataFactory: app_id: str = "app-789", workflow_id: str = "workflow-101", status: str | WorkflowExecutionStatus = "paused", + pause_id: str | None = None, **kwargs, ) -> MagicMock: """Create a mock WorkflowRun object.""" @@ -44,6 +45,7 @@ class TestDataFactory: mock_run.app_id = app_id mock_run.workflow_id = workflow_id mock_run.status = status + mock_run.pause_id = pause_id for key, value in kwargs.items(): setattr(mock_run, key, value) diff --git a/api/tests/unit_tests/services/tools/test_workflow_tools_manage_service.py b/api/tests/unit_tests/services/tools/test_workflow_tools_manage_service.py deleted file mode 100644 index d6c92f1013..0000000000 --- a/api/tests/unit_tests/services/tools/test_workflow_tools_manage_service.py +++ /dev/null @@ -1,158 +0,0 @@ -import json -from types import SimpleNamespace -from unittest.mock import MagicMock - -import pytest - -from core.tools.errors import WorkflowToolHumanInputNotSupportedError -from models.model import App -from models.tools import WorkflowToolProvider -from services.tools import workflow_tools_manage_service - - -class DummyWorkflow: - def __init__(self, graph_dict: dict, version: str = "1.0.0") -> None: - self._graph_dict = graph_dict - self.version = version - - @property - def graph_dict(self) -> dict: - return self._graph_dict - - -class FakeQuery: - def __init__(self, result): - self._result = result - - def where(self, *args, **kwargs): - return self - - def first(self): - return self._result - - -class DummySession: - def __init__(self) -> None: - self.added: list[object] = [] - - def __enter__(self) -> "DummySession": - return self - - def __exit__(self, exc_type, exc, tb) -> bool: - return False - - def add(self, obj) -> None: - self.added.append(obj) - - def begin(self): - return DummyBegin(self) - - -class DummyBegin: - def __init__(self, session: DummySession) -> None: - self._session = session - - def __enter__(self) -> DummySession: - return self._session - - def __exit__(self, exc_type, exc, tb) -> bool: - return False - - -class DummySessionContext: - def __init__(self, session: DummySession) -> None: - self._session = session - - def __enter__(self) -> DummySession: - return self._session - - def __exit__(self, exc_type, exc, tb) -> bool: - return False - - -class DummySessionFactory: - def __init__(self, session: DummySession) -> None: - self._session = session - - def create_session(self) -> DummySessionContext: - return DummySessionContext(self._session) - - -def _build_fake_session(app) -> SimpleNamespace: - def query(model): - if model is WorkflowToolProvider: - return FakeQuery(None) - if model is App: - return FakeQuery(app) - return FakeQuery(None) - - return SimpleNamespace(query=query) - - -def test_create_workflow_tool_rejects_human_input_nodes(monkeypatch): - workflow = DummyWorkflow(graph_dict={"nodes": [{"id": "node_1", "data": {"type": "human-input"}}]}) - app = SimpleNamespace(workflow=workflow) - - fake_session = _build_fake_session(app) - monkeypatch.setattr(workflow_tools_manage_service.db, "session", fake_session) - - mock_from_db = MagicMock() - monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", mock_from_db) - mock_invalidate = MagicMock() - - parameters = [{"name": "input", "description": "input", "form": "form"}] - - with pytest.raises(WorkflowToolHumanInputNotSupportedError) as exc_info: - workflow_tools_manage_service.WorkflowToolManageService.create_workflow_tool( - user_id="user-id", - tenant_id="tenant-id", - workflow_app_id="app-id", - name="tool_name", - label="Tool", - icon={"type": "emoji", "emoji": "tool"}, - description="desc", - parameters=parameters, - ) - - assert exc_info.value.error_code == "workflow_tool_human_input_not_supported" - mock_from_db.assert_not_called() - mock_invalidate.assert_not_called() - - -def test_create_workflow_tool_success(monkeypatch): - workflow = DummyWorkflow(graph_dict={"nodes": [{"id": "node_1", "data": {"type": "start"}}]}) - app = SimpleNamespace(workflow=workflow) - - fake_db = MagicMock() - fake_session = _build_fake_session(app) - fake_db.session = fake_session - monkeypatch.setattr(workflow_tools_manage_service, "db", fake_db) - - dummy_session = DummySession() - monkeypatch.setattr(workflow_tools_manage_service, "Session", lambda *_, **__: dummy_session) - - mock_from_db = MagicMock() - monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", mock_from_db) - - parameters = [{"name": "input", "description": "input", "form": "form"}] - icon = {"type": "emoji", "emoji": "tool"} - - result = workflow_tools_manage_service.WorkflowToolManageService.create_workflow_tool( - user_id="user-id", - tenant_id="tenant-id", - workflow_app_id="app-id", - name="tool_name", - label="Tool", - icon=icon, - description="desc", - parameters=parameters, - ) - - assert result == {"result": "success"} - assert len(dummy_session.added) == 1 - created_provider = dummy_session.added[0] - assert created_provider.name == "tool_name" - assert created_provider.label == "Tool" - assert created_provider.icon == json.dumps(icon) - assert created_provider.version == workflow.version - mock_from_db.assert_called_once() diff --git a/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py b/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py deleted file mode 100644 index 844dab8976..0000000000 --- a/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py +++ /dev/null @@ -1,226 +0,0 @@ -from __future__ import annotations - -import json -import queue -from collections.abc import Sequence -from dataclasses import dataclass -from datetime import UTC, datetime -from threading import Event - -import pytest - -from core.app.app_config.entities import WorkflowUIBasedAppConfig -from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity -from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext, _WorkflowGenerateEntityWrapper -from core.workflow.entities.pause_reason import HumanInputRequired -from core.workflow.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus -from core.workflow.runtime import GraphRuntimeState, VariablePool -from models.enums import CreatorUserRole -from models.model import AppMode -from models.workflow import WorkflowRun -from repositories.api_workflow_node_execution_repository import WorkflowNodeExecutionSnapshot -from repositories.entities.workflow_pause import WorkflowPauseEntity -from services.workflow_event_snapshot_service import ( - BufferState, - MessageContext, - _build_snapshot_events, - _resolve_task_id, -) - - -@dataclass(frozen=True) -class _FakePauseEntity(WorkflowPauseEntity): - pause_id: str - workflow_run_id: str - paused_at_value: datetime - pause_reasons: Sequence[HumanInputRequired] - - @property - def id(self) -> str: - return self.pause_id - - @property - def workflow_execution_id(self) -> str: - return self.workflow_run_id - - def get_state(self) -> bytes: - raise AssertionError("state is not required for snapshot tests") - - @property - def resumed_at(self) -> datetime | None: - return None - - @property - def paused_at(self) -> datetime: - return self.paused_at_value - - def get_pause_reasons(self) -> Sequence[HumanInputRequired]: - return self.pause_reasons - - -def _build_workflow_run(status: WorkflowExecutionStatus) -> WorkflowRun: - return WorkflowRun( - id="run-1", - tenant_id="tenant-1", - app_id="app-1", - workflow_id="workflow-1", - type="workflow", - triggered_from="app-run", - version="v1", - graph=None, - inputs=json.dumps({"input": "value"}), - status=status, - outputs=json.dumps({}), - error=None, - elapsed_time=0.0, - total_tokens=0, - total_steps=0, - created_by_role=CreatorUserRole.END_USER, - created_by="user-1", - created_at=datetime(2024, 1, 1, tzinfo=UTC), - ) - - -def _build_snapshot(status: WorkflowNodeExecutionStatus) -> WorkflowNodeExecutionSnapshot: - created_at = datetime(2024, 1, 1, tzinfo=UTC) - finished_at = datetime(2024, 1, 1, 0, 0, 5, tzinfo=UTC) - return WorkflowNodeExecutionSnapshot( - execution_id="exec-1", - node_id="node-1", - node_type="human-input", - title="Human Input", - index=1, - status=status.value, - elapsed_time=0.5, - created_at=created_at, - finished_at=finished_at, - iteration_id=None, - loop_id=None, - ) - - -def _build_resumption_context(task_id: str) -> WorkflowResumptionContext: - app_config = WorkflowUIBasedAppConfig( - tenant_id="tenant-1", - app_id="app-1", - app_mode=AppMode.WORKFLOW, - workflow_id="workflow-1", - ) - generate_entity = WorkflowAppGenerateEntity( - task_id=task_id, - app_config=app_config, - inputs={}, - files=[], - user_id="user-1", - stream=True, - invoke_from=InvokeFrom.EXPLORE, - call_depth=0, - workflow_execution_id="run-1", - ) - runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0) - runtime_state.register_paused_node("node-1") - runtime_state.outputs = {"result": "value"} - wrapper = _WorkflowGenerateEntityWrapper(entity=generate_entity) - return WorkflowResumptionContext( - generate_entity=wrapper, - serialized_graph_runtime_state=runtime_state.dumps(), - ) - - -def test_build_snapshot_events_includes_pause_event() -> None: - workflow_run = _build_workflow_run(WorkflowExecutionStatus.PAUSED) - snapshot = _build_snapshot(WorkflowNodeExecutionStatus.PAUSED) - resumption_context = _build_resumption_context("task-ctx") - pause_entity = _FakePauseEntity( - pause_id="pause-1", - workflow_run_id="run-1", - paused_at_value=datetime(2024, 1, 1, tzinfo=UTC), - pause_reasons=[ - HumanInputRequired( - form_id="form-1", - form_content="content", - node_id="node-1", - node_title="Human Input", - ) - ], - ) - - events = _build_snapshot_events( - workflow_run=workflow_run, - node_snapshots=[snapshot], - task_id="task-ctx", - message_context=None, - pause_entity=pause_entity, - resumption_context=resumption_context, - ) - - assert [event["event"] for event in events] == [ - "workflow_started", - "node_started", - "node_finished", - "workflow_paused", - ] - assert events[2]["data"]["status"] == WorkflowNodeExecutionStatus.PAUSED.value - pause_data = events[-1]["data"] - assert pause_data["paused_nodes"] == ["node-1"] - assert pause_data["outputs"] == {"result": "value"} - assert pause_data["status"] == WorkflowExecutionStatus.PAUSED.value - assert pause_data["created_at"] == int(workflow_run.created_at.timestamp()) - assert pause_data["elapsed_time"] == workflow_run.elapsed_time - assert pause_data["total_tokens"] == workflow_run.total_tokens - assert pause_data["total_steps"] == workflow_run.total_steps - - -def test_build_snapshot_events_applies_message_context() -> None: - workflow_run = _build_workflow_run(WorkflowExecutionStatus.RUNNING) - snapshot = _build_snapshot(WorkflowNodeExecutionStatus.SUCCEEDED) - message_context = MessageContext( - conversation_id="conv-1", - message_id="msg-1", - created_at=1700000000, - answer="snapshot message", - ) - - events = _build_snapshot_events( - workflow_run=workflow_run, - node_snapshots=[snapshot], - task_id="task-1", - message_context=message_context, - pause_entity=None, - resumption_context=None, - ) - - assert [event["event"] for event in events] == [ - "workflow_started", - "message_replace", - "node_started", - "node_finished", - ] - assert events[1]["answer"] == "snapshot message" - for event in events: - assert event["conversation_id"] == "conv-1" - assert event["message_id"] == "msg-1" - assert event["created_at"] == 1700000000 - - -@pytest.mark.parametrize( - ("context_task_id", "buffered_task_id", "expected"), - [ - ("task-ctx", "task-buffer", "task-ctx"), - (None, "task-buffer", "task-buffer"), - (None, None, "run-1"), - ], -) -def test_resolve_task_id_priority(context_task_id, buffered_task_id, expected) -> None: - resumption_context = _build_resumption_context(context_task_id) if context_task_id else None - buffer_state = BufferState( - queue=queue.Queue(), - stop_event=Event(), - done_event=Event(), - task_id_ready=Event(), - task_id_hint=buffered_task_id, - ) - if buffered_task_id: - buffer_state.task_id_ready.set() - task_id = _resolve_task_id(resumption_context, buffer_state, "run-1", wait_timeout=0.0) - assert task_id == expected diff --git a/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py b/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py deleted file mode 100644 index 5ac5ac8ad2..0000000000 --- a/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py +++ /dev/null @@ -1,184 +0,0 @@ -import uuid -from types import SimpleNamespace -from unittest.mock import MagicMock - -import pytest -from sqlalchemy.orm import sessionmaker - -from core.workflow.enums import NodeType -from core.workflow.nodes.human_input.entities import ( - EmailDeliveryConfig, - EmailDeliveryMethod, - EmailRecipients, - ExternalRecipient, - HumanInputNodeData, - MemberRecipient, -) -from services import workflow_service as workflow_service_module -from services.workflow_service import WorkflowService - - -def _make_service() -> WorkflowService: - return WorkflowService(session_maker=sessionmaker()) - - -def _build_node_config(delivery_methods): - node_data = HumanInputNodeData( - title="Human Input", - delivery_methods=delivery_methods, - form_content="Test content", - inputs=[], - user_actions=[], - ).model_dump(mode="json") - node_data["type"] = NodeType.HUMAN_INPUT.value - return {"id": "node-1", "data": node_data} - - -def _make_email_method(enabled: bool = True, debug_mode: bool = False) -> EmailDeliveryMethod: - return EmailDeliveryMethod( - id=uuid.uuid4(), - enabled=enabled, - config=EmailDeliveryConfig( - recipients=EmailRecipients( - whole_workspace=False, - items=[ExternalRecipient(email="tester@example.com")], - ), - subject="Test subject", - body="Test body", - debug_mode=debug_mode, - ), - ) - - -def test_human_input_delivery_requires_draft_workflow(): - service = _make_service() - service.get_draft_workflow = MagicMock(return_value=None) # type: ignore[method-assign] - app_model = SimpleNamespace(tenant_id="tenant-1", id="app-1") - account = SimpleNamespace(id="account-1") - - with pytest.raises(ValueError, match="Workflow not initialized"): - service.test_human_input_delivery( - app_model=app_model, - account=account, - node_id="node-1", - delivery_method_id="delivery-1", - ) - - -def test_human_input_delivery_allows_disabled_method(monkeypatch: pytest.MonkeyPatch): - service = _make_service() - delivery_method = _make_email_method(enabled=False) - node_config = _build_node_config([delivery_method]) - workflow = MagicMock() - workflow.get_node_config_by_id.return_value = node_config - service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign] - service._build_human_input_variable_pool = MagicMock(return_value=MagicMock()) # type: ignore[attr-defined] - node_stub = MagicMock() - node_stub._render_form_content_before_submission.return_value = "rendered" - node_stub._resolve_default_values.return_value = {} - service._build_human_input_node = MagicMock(return_value=node_stub) # type: ignore[attr-defined] - service._create_human_input_delivery_test_form = MagicMock( # type: ignore[attr-defined] - return_value=("form-1", {}) - ) - - test_service_instance = MagicMock() - monkeypatch.setattr( - workflow_service_module, - "HumanInputDeliveryTestService", - MagicMock(return_value=test_service_instance), - ) - - app_model = SimpleNamespace(tenant_id="tenant-1", id="app-1") - account = SimpleNamespace(id="account-1") - - service.test_human_input_delivery( - app_model=app_model, - account=account, - node_id="node-1", - delivery_method_id=str(delivery_method.id), - ) - - test_service_instance.send_test.assert_called_once() - - -def test_human_input_delivery_dispatches_to_test_service(monkeypatch: pytest.MonkeyPatch): - service = _make_service() - delivery_method = _make_email_method(enabled=True) - node_config = _build_node_config([delivery_method]) - workflow = MagicMock() - workflow.get_node_config_by_id.return_value = node_config - service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign] - service._build_human_input_variable_pool = MagicMock(return_value=MagicMock()) # type: ignore[attr-defined] - node_stub = MagicMock() - node_stub._render_form_content_before_submission.return_value = "rendered" - node_stub._resolve_default_values.return_value = {} - service._build_human_input_node = MagicMock(return_value=node_stub) # type: ignore[attr-defined] - service._create_human_input_delivery_test_form = MagicMock( # type: ignore[attr-defined] - return_value=("form-1", {}) - ) - - test_service_instance = MagicMock() - monkeypatch.setattr( - workflow_service_module, - "HumanInputDeliveryTestService", - MagicMock(return_value=test_service_instance), - ) - - app_model = SimpleNamespace(tenant_id="tenant-1", id="app-1") - account = SimpleNamespace(id="account-1") - - service.test_human_input_delivery( - app_model=app_model, - account=account, - node_id="node-1", - delivery_method_id=str(delivery_method.id), - inputs={"#node-1.output#": "value"}, - ) - - pool_args = service._build_human_input_variable_pool.call_args.kwargs - assert pool_args["manual_inputs"] == {"#node-1.output#": "value"} - test_service_instance.send_test.assert_called_once() - - -def test_human_input_delivery_debug_mode_overrides_recipients(monkeypatch: pytest.MonkeyPatch): - service = _make_service() - delivery_method = _make_email_method(enabled=True, debug_mode=True) - node_config = _build_node_config([delivery_method]) - workflow = MagicMock() - workflow.get_node_config_by_id.return_value = node_config - service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign] - service._build_human_input_variable_pool = MagicMock(return_value=MagicMock()) # type: ignore[attr-defined] - node_stub = MagicMock() - node_stub._render_form_content_before_submission.return_value = "rendered" - node_stub._resolve_default_values.return_value = {} - service._build_human_input_node = MagicMock(return_value=node_stub) # type: ignore[attr-defined] - service._create_human_input_delivery_test_form = MagicMock( # type: ignore[attr-defined] - return_value=("form-1", {}) - ) - - test_service_instance = MagicMock() - monkeypatch.setattr( - workflow_service_module, - "HumanInputDeliveryTestService", - MagicMock(return_value=test_service_instance), - ) - - app_model = SimpleNamespace(tenant_id="tenant-1", id="app-1") - account = SimpleNamespace(id="account-1") - - service.test_human_input_delivery( - app_model=app_model, - account=account, - node_id="node-1", - delivery_method_id=str(delivery_method.id), - ) - - test_service_instance.send_test.assert_called_once() - sent_method = test_service_instance.send_test.call_args.kwargs["method"] - assert isinstance(sent_method, EmailDeliveryMethod) - assert sent_method.config.debug_mode is True - assert sent_method.config.recipients.whole_workspace is False - assert len(sent_method.config.recipients.items) == 1 - recipient = sent_method.config.recipients.items[0] - assert isinstance(recipient, MemberRecipient) - assert recipient.user_id == account.id diff --git a/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py b/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py index 70d7bde870..32d2f8b7e0 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py @@ -5,7 +5,6 @@ from uuid import uuid4 import pytest from sqlalchemy.orm import Session -from core.workflow.enums import WorkflowNodeExecutionStatus from models.workflow import WorkflowNodeExecutionModel from repositories.sqlalchemy_api_workflow_node_execution_repository import ( DifyAPISQLAlchemyWorkflowNodeExecutionRepository, @@ -53,9 +52,6 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: call_args = mock_session.scalar.call_args[0][0] assert hasattr(call_args, "compile") # It's a SQLAlchemy statement - compiled = call_args.compile() - assert WorkflowNodeExecutionStatus.PAUSED in compiled.params.values() - def test_get_node_last_execution_not_found(self, repository): """Test getting the last execution for a node when it doesn't exist.""" # Arrange @@ -75,6 +71,28 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: assert result is None mock_session.scalar.assert_called_once() + def test_get_executions_by_workflow_run(self, repository, mock_execution): + """Test getting all executions for a workflow run.""" + # Arrange + mock_session = MagicMock(spec=Session) + repository._session_maker.return_value.__enter__.return_value = mock_session + executions = [mock_execution] + mock_session.execute.return_value.scalars.return_value.all.return_value = executions + + # Act + result = repository.get_executions_by_workflow_run( + tenant_id="tenant-123", + app_id="app-456", + workflow_run_id="run-101", + ) + + # Assert + assert result == executions + mock_session.execute.assert_called_once() + # Verify the query was constructed correctly + call_args = mock_session.execute.call_args[0][0] + assert hasattr(call_args, "compile") # It's a SQLAlchemy statement + def test_get_executions_by_workflow_run_empty(self, repository): """Test getting executions for a workflow run when none exist.""" # Arrange diff --git a/api/tests/unit_tests/services/workflow/test_workflow_service.py b/api/tests/unit_tests/services/workflow/test_workflow_service.py index 015dac257e..9700cbaf0e 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_service.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_service.py @@ -1,15 +1,9 @@ -from contextlib import nullcontext -from types import SimpleNamespace from unittest.mock import MagicMock import pytest -from core.workflow.enums import NodeType -from core.workflow.nodes.human_input.entities import FormInput, HumanInputNodeData, UserAction -from core.workflow.nodes.human_input.enums import FormInputType from models.model import App from models.workflow import Workflow -from services import workflow_service as workflow_service_module from services.workflow_service import WorkflowService @@ -167,120 +161,3 @@ class TestWorkflowService: assert workflows == [] assert has_more is False mock_session.scalars.assert_called_once() - - def test_submit_human_input_form_preview_uses_rendered_content( - self, workflow_service: WorkflowService, monkeypatch: pytest.MonkeyPatch - ) -> None: - service = workflow_service - node_data = HumanInputNodeData( - title="Human Input", - form_content="

{{#$output.name#}}

", - inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")], - user_actions=[UserAction(id="approve", title="Approve")], - ) - node = MagicMock() - node.node_data = node_data - node.render_form_content_before_submission.return_value = "

preview

" - node.render_form_content_with_outputs.return_value = "

rendered

" - - service._build_human_input_variable_pool = MagicMock(return_value=MagicMock()) # type: ignore[method-assign] - service._build_human_input_node = MagicMock(return_value=node) # type: ignore[method-assign] - - workflow = MagicMock() - workflow.get_node_config_by_id.return_value = {"id": "node-1", "data": {"type": NodeType.HUMAN_INPUT.value}} - workflow.get_enclosing_node_type_and_id.return_value = None - service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign] - - saved_outputs: dict[str, object] = {} - - class DummySession: - def __init__(self, *args, **kwargs): - self.commit = MagicMock() - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - def begin(self): - return nullcontext() - - class DummySaver: - def __init__(self, *args, **kwargs): - pass - - def save(self, outputs, process_data): - saved_outputs.update(outputs) - - monkeypatch.setattr(workflow_service_module, "Session", DummySession) - monkeypatch.setattr(workflow_service_module, "DraftVariableSaver", DummySaver) - monkeypatch.setattr(workflow_service_module, "db", SimpleNamespace(engine=MagicMock())) - - app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1") - account = SimpleNamespace(id="account-1") - - result = service.submit_human_input_form_preview( - app_model=app_model, - account=account, - node_id="node-1", - form_inputs={"name": "Ada", "extra": "ignored"}, - inputs={"#node-0.result#": "LLM output"}, - action="approve", - ) - - service._build_human_input_variable_pool.assert_called_once_with( - app_model=app_model, - workflow=workflow, - node_config={"id": "node-1", "data": {"type": NodeType.HUMAN_INPUT.value}}, - manual_inputs={"#node-0.result#": "LLM output"}, - ) - - node.render_form_content_with_outputs.assert_called_once() - called_args = node.render_form_content_with_outputs.call_args.args - assert called_args[0] == "

preview

" - assert called_args[2] == node_data.outputs_field_names() - rendered_outputs = called_args[1] - assert rendered_outputs["name"] == "Ada" - assert rendered_outputs["extra"] == "ignored" - assert "extra" in saved_outputs - assert "extra" in result - assert saved_outputs["name"] == "Ada" - assert result["name"] == "Ada" - assert result["__action_id"] == "approve" - assert "__rendered_content" in result - - def test_submit_human_input_form_preview_missing_inputs_message(self, workflow_service: WorkflowService) -> None: - service = workflow_service - node_data = HumanInputNodeData( - title="Human Input", - form_content="

{{#$output.name#}}

", - inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")], - user_actions=[UserAction(id="approve", title="Approve")], - ) - node = MagicMock() - node.node_data = node_data - node._render_form_content_before_submission.return_value = "

preview

" - node._render_form_content_with_outputs.return_value = "

rendered

" - - service._build_human_input_variable_pool = MagicMock(return_value=MagicMock()) # type: ignore[method-assign] - service._build_human_input_node = MagicMock(return_value=node) # type: ignore[method-assign] - - workflow = MagicMock() - workflow.get_node_config_by_id.return_value = {"id": "node-1", "data": {"type": NodeType.HUMAN_INPUT.value}} - service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign] - - app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1") - account = SimpleNamespace(id="account-1") - - with pytest.raises(ValueError) as exc_info: - service.submit_human_input_form_preview( - app_model=app_model, - account=account, - node_id="node-1", - form_inputs={}, - inputs={}, - action="approve", - ) - - assert "Missing required inputs" in str(exc_info.value) diff --git a/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py b/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py deleted file mode 100644 index 051eefa60a..0000000000 --- a/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py +++ /dev/null @@ -1,210 +0,0 @@ -from __future__ import annotations - -from datetime import datetime, timedelta -from types import SimpleNamespace -from typing import Any - -import pytest - -from core.workflow.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus -from tasks import human_input_timeout_tasks as task_module - - -class _FakeScalarResult: - def __init__(self, items: list[Any]): - self._items = items - - def all(self) -> list[Any]: - return self._items - - -class _FakeSession: - def __init__(self, items: list[Any], capture: dict[str, Any]): - self._items = items - self._capture = capture - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - def scalars(self, stmt): - self._capture["stmt"] = stmt - return _FakeScalarResult(self._items) - - -class _FakeSessionFactory: - def __init__(self, items: list[Any], capture: dict[str, Any]): - self._items = items - self._capture = capture - self._capture["session_factory"] = self - - def __call__(self): - session = _FakeSession(self._items, self._capture) - self._capture["session"] = session - return session - - -class _FakeFormRepo: - def __init__(self, _session_factory, form_map: dict[str, Any] | None = None): - self.calls: list[dict[str, Any]] = [] - self._form_map = form_map or {} - - def mark_timeout(self, *, form_id: str, timeout_status: HumanInputFormStatus, reason: str | None = None): - self.calls.append( - { - "form_id": form_id, - "timeout_status": timeout_status, - "reason": reason, - } - ) - form = self._form_map.get(form_id) - return SimpleNamespace( - form_id=form_id, - workflow_run_id=getattr(form, "workflow_run_id", None), - node_id=getattr(form, "node_id", None), - ) - - -class _FakeService: - def __init__(self, _session_factory, form_repository=None): - self.enqueued: list[str] = [] - - def enqueue_resume(self, workflow_run_id: str | None) -> None: - if workflow_run_id is not None: - self.enqueued.append(workflow_run_id) - - -def _build_form( - *, - form_id: str, - form_kind: HumanInputFormKind, - created_at: datetime, - expiration_time: datetime, - workflow_run_id: str | None, - node_id: str, -) -> SimpleNamespace: - return SimpleNamespace( - id=form_id, - form_kind=form_kind, - created_at=created_at, - expiration_time=expiration_time, - workflow_run_id=workflow_run_id, - node_id=node_id, - status=HumanInputFormStatus.WAITING, - ) - - -def test_is_global_timeout_uses_created_at(): - now = datetime(2025, 1, 1, 12, 0, 0) - form = SimpleNamespace(created_at=now - timedelta(seconds=61), workflow_run_id="run-1") - - assert task_module._is_global_timeout(form, 60, now=now) is True - - form.workflow_run_id = None - assert task_module._is_global_timeout(form, 60, now=now) is False - - form.workflow_run_id = "run-1" - form.created_at = now - timedelta(seconds=59) - assert task_module._is_global_timeout(form, 60, now=now) is False - - assert task_module._is_global_timeout(form, 0, now=now) is False - - -def test_check_and_handle_human_input_timeouts_marks_and_routes(monkeypatch: pytest.MonkeyPatch): - now = datetime(2025, 1, 1, 12, 0, 0) - monkeypatch.setattr(task_module, "naive_utc_now", lambda: now) - monkeypatch.setattr(task_module.dify_config, "HITL_GLOBAL_TIMEOUT_SECONDS", 3600) - monkeypatch.setattr(task_module, "db", SimpleNamespace(engine=object())) - - forms = [ - _build_form( - form_id="form-global", - form_kind=HumanInputFormKind.RUNTIME, - created_at=now - timedelta(hours=2), - expiration_time=now + timedelta(hours=1), - workflow_run_id="run-global", - node_id="node-global", - ), - _build_form( - form_id="form-node", - form_kind=HumanInputFormKind.RUNTIME, - created_at=now - timedelta(minutes=5), - expiration_time=now - timedelta(seconds=1), - workflow_run_id="run-node", - node_id="node-node", - ), - _build_form( - form_id="form-delivery", - form_kind=HumanInputFormKind.DELIVERY_TEST, - created_at=now - timedelta(minutes=1), - expiration_time=now - timedelta(seconds=1), - workflow_run_id=None, - node_id="node-delivery", - ), - ] - - capture: dict[str, Any] = {} - monkeypatch.setattr(task_module, "sessionmaker", lambda *args, **kwargs: _FakeSessionFactory(forms, capture)) - - form_map = {form.id: form for form in forms} - repo = _FakeFormRepo(None, form_map=form_map) - - def _repo_factory(_session_factory): - return repo - - service = _FakeService(None) - - def _service_factory(_session_factory, form_repository=None): - return service - - global_calls: list[dict[str, Any]] = [] - - monkeypatch.setattr(task_module, "HumanInputFormSubmissionRepository", _repo_factory) - monkeypatch.setattr(task_module, "HumanInputService", _service_factory) - monkeypatch.setattr(task_module, "_handle_global_timeout", lambda **kwargs: global_calls.append(kwargs)) - - task_module.check_and_handle_human_input_timeouts(limit=100) - - assert {(call["form_id"], call["timeout_status"], call["reason"]) for call in repo.calls} == { - ("form-global", HumanInputFormStatus.EXPIRED, "global_timeout"), - ("form-node", HumanInputFormStatus.TIMEOUT, "node_timeout"), - ("form-delivery", HumanInputFormStatus.TIMEOUT, "delivery_test_timeout"), - } - assert service.enqueued == ["run-node"] - assert global_calls == [ - { - "form_id": "form-global", - "workflow_run_id": "run-global", - "node_id": "node-global", - "session_factory": capture.get("session_factory"), - } - ] - - stmt = capture.get("stmt") - assert stmt is not None - stmt_text = str(stmt) - assert "created_at <=" in stmt_text - assert "expiration_time <=" in stmt_text - assert "ORDER BY human_input_forms.id" in stmt_text - - -def test_check_and_handle_human_input_timeouts_omits_global_filter_when_disabled(monkeypatch: pytest.MonkeyPatch): - now = datetime(2025, 1, 1, 12, 0, 0) - monkeypatch.setattr(task_module, "naive_utc_now", lambda: now) - monkeypatch.setattr(task_module.dify_config, "HITL_GLOBAL_TIMEOUT_SECONDS", 0) - monkeypatch.setattr(task_module, "db", SimpleNamespace(engine=object())) - - capture: dict[str, Any] = {} - monkeypatch.setattr(task_module, "sessionmaker", lambda *args, **kwargs: _FakeSessionFactory([], capture)) - monkeypatch.setattr(task_module, "HumanInputFormSubmissionRepository", _FakeFormRepo) - monkeypatch.setattr(task_module, "HumanInputService", _FakeService) - monkeypatch.setattr(task_module, "_handle_global_timeout", lambda **_kwargs: None) - - task_module.check_and_handle_human_input_timeouts(limit=1) - - stmt = capture.get("stmt") - assert stmt is not None - stmt_text = str(stmt) - assert "created_at <=" not in stmt_text diff --git a/api/tests/unit_tests/tasks/test_mail_human_input_delivery_task.py b/api/tests/unit_tests/tasks/test_mail_human_input_delivery_task.py deleted file mode 100644 index 20cb7a211e..0000000000 --- a/api/tests/unit_tests/tasks/test_mail_human_input_delivery_task.py +++ /dev/null @@ -1,123 +0,0 @@ -from collections.abc import Sequence -from types import SimpleNamespace - -import pytest - -from tasks import mail_human_input_delivery_task as task_module - - -class _DummyMail: - def __init__(self): - self.sent: list[dict[str, str]] = [] - self._inited = True - - def is_inited(self) -> bool: - return self._inited - - def send(self, *, to: str, subject: str, html: str): - self.sent.append({"to": to, "subject": subject, "html": html}) - - -class _DummySession: - def __init__(self, form): - self._form = form - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - return False - - def get(self, _model, _form_id): - return self._form - - -def _build_job(recipient_count: int = 1) -> task_module._EmailDeliveryJob: - recipients: list[task_module._EmailRecipient] = [] - for idx in range(recipient_count): - recipients.append(task_module._EmailRecipient(email=f"user{idx}@example.com", token=f"token-{idx}")) - - return task_module._EmailDeliveryJob( - form_id="form-1", - subject="Subject", - body="Body for {{#url}}", - form_content="content", - recipients=recipients, - ) - - -def test_dispatch_human_input_email_task_sends_to_each_recipient(monkeypatch: pytest.MonkeyPatch): - mail = _DummyMail() - form = SimpleNamespace(id="form-1", tenant_id="tenant-1", workflow_run_id=None) - - monkeypatch.setattr(task_module, "mail", mail) - monkeypatch.setattr( - task_module.FeatureService, - "get_features", - lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=True), - ) - jobs: Sequence[task_module._EmailDeliveryJob] = [_build_job(recipient_count=2)] - monkeypatch.setattr(task_module, "_load_email_jobs", lambda _session, _form: jobs) - - task_module.dispatch_human_input_email_task( - form_id="form-1", - node_title="Approve", - session_factory=lambda: _DummySession(form), - ) - - assert len(mail.sent) == 2 - assert all(payload["subject"] == "Subject" for payload in mail.sent) - assert all("Body for" in payload["html"] for payload in mail.sent) - - -def test_dispatch_human_input_email_task_skips_when_feature_disabled(monkeypatch: pytest.MonkeyPatch): - mail = _DummyMail() - form = SimpleNamespace(id="form-1", tenant_id="tenant-1", workflow_run_id=None) - - monkeypatch.setattr(task_module, "mail", mail) - monkeypatch.setattr( - task_module.FeatureService, - "get_features", - lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=False), - ) - monkeypatch.setattr(task_module, "_load_email_jobs", lambda _session, _form: []) - - task_module.dispatch_human_input_email_task( - form_id="form-1", - node_title="Approve", - session_factory=lambda: _DummySession(form), - ) - - assert mail.sent == [] - - -def test_dispatch_human_input_email_task_replaces_body_variables(monkeypatch: pytest.MonkeyPatch): - mail = _DummyMail() - form = SimpleNamespace(id="form-1", tenant_id="tenant-1", workflow_run_id="run-1") - job = task_module._EmailDeliveryJob( - form_id="form-1", - subject="Subject", - body="Body {{#node1.value#}}", - form_content="content", - recipients=[task_module._EmailRecipient(email="user@example.com", token="token-1")], - ) - - variable_pool = task_module.VariablePool() - variable_pool.add(["node1", "value"], "OK") - - monkeypatch.setattr(task_module, "mail", mail) - monkeypatch.setattr( - task_module.FeatureService, - "get_features", - lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=True), - ) - monkeypatch.setattr(task_module, "_load_email_jobs", lambda _session, _form: [job]) - monkeypatch.setattr(task_module, "_load_variable_pool", lambda _workflow_run_id: variable_pool) - - task_module.dispatch_human_input_email_task( - form_id="form-1", - node_title="Approve", - session_factory=lambda: _DummySession(form), - ) - - assert mail.sent[0]["html"] == "Body OK" diff --git a/api/tests/unit_tests/tasks/test_workflow_execute_task.py b/api/tests/unit_tests/tasks/test_workflow_execute_task.py deleted file mode 100644 index 161151305d..0000000000 --- a/api/tests/unit_tests/tasks/test_workflow_execute_task.py +++ /dev/null @@ -1,39 +0,0 @@ -from __future__ import annotations - -import json -import uuid -from unittest.mock import MagicMock - -import pytest - -from models.model import AppMode -from tasks.app_generate.workflow_execute_task import _publish_streaming_response - - -@pytest.fixture -def mock_topic(mocker) -> MagicMock: - topic = MagicMock() - mocker.patch( - "tasks.app_generate.workflow_execute_task.MessageBasedAppGenerator.get_response_topic", - return_value=topic, - ) - return topic - - -def test_publish_streaming_response_with_uuid(mock_topic: MagicMock): - workflow_run_id = uuid.uuid4() - response_stream = iter([{"event": "foo"}, "ping"]) - - _publish_streaming_response(response_stream, workflow_run_id, app_mode=AppMode.ADVANCED_CHAT) - - payloads = [call.args[0] for call in mock_topic.publish.call_args_list] - assert payloads == [json.dumps({"event": "foo"}).encode(), json.dumps("ping").encode()] - - -def test_publish_streaming_response_coerces_string_uuid(mock_topic: MagicMock): - workflow_run_id = uuid.uuid4() - response_stream = iter([{"event": "bar"}]) - - _publish_streaming_response(response_stream, str(workflow_run_id), app_mode=AppMode.ADVANCED_CHAT) - - mock_topic.publish.assert_called_once_with(json.dumps({"event": "bar"}).encode()) diff --git a/api/tests/unit_tests/tasks/test_workflow_node_execution_tasks.py b/api/tests/unit_tests/tasks/test_workflow_node_execution_tasks.py deleted file mode 100644 index fd5f0713a4..0000000000 --- a/api/tests/unit_tests/tasks/test_workflow_node_execution_tasks.py +++ /dev/null @@ -1,488 +0,0 @@ -# """ -# Unit tests for workflow node execution Celery tasks. - -# These tests verify the asynchronous storage functionality for workflow node execution data, -# including truncation and offloading logic. -# """ - -# import json -# from unittest.mock import MagicMock, Mock, patch -# from uuid import uuid4 - -# import pytest - -# from core.workflow.entities.workflow_node_execution import ( -# WorkflowNodeExecution, -# WorkflowNodeExecutionStatus, -# ) -# from core.workflow.enums import NodeType -# from libs.datetime_utils import naive_utc_now -# from models import WorkflowNodeExecutionModel -# from models.enums import ExecutionOffLoadType -# from models.model import UploadFile -# from models.workflow import WorkflowNodeExecutionOffload, WorkflowNodeExecutionTriggeredFrom -# from tasks.workflow_node_execution_tasks import ( -# _create_truncator, -# _json_encode, -# _replace_or_append_offload, -# _truncate_and_upload_async, -# save_workflow_node_execution_data_task, -# save_workflow_node_execution_task, -# ) - - -# @pytest.fixture -# def sample_execution_data(): -# """Sample execution data for testing.""" -# execution = WorkflowNodeExecution( -# id=str(uuid4()), -# node_execution_id=str(uuid4()), -# workflow_id=str(uuid4()), -# workflow_execution_id=str(uuid4()), -# index=1, -# node_id="test_node", -# node_type=NodeType.LLM, -# title="Test Node", -# inputs={"input_key": "input_value"}, -# outputs={"output_key": "output_value"}, -# process_data={"process_key": "process_value"}, -# status=WorkflowNodeExecutionStatus.RUNNING, -# created_at=naive_utc_now(), -# ) -# return execution.model_dump() - - -# @pytest.fixture -# def mock_db_model(): -# """Mock database model for testing.""" -# db_model = Mock(spec=WorkflowNodeExecutionModel) -# db_model.id = "test-execution-id" -# db_model.offload_data = [] -# return db_model - - -# @pytest.fixture -# def mock_file_service(): -# """Mock file service for testing.""" -# file_service = Mock() -# mock_upload_file = Mock(spec=UploadFile) -# mock_upload_file.id = "mock-file-id" -# file_service.upload_file.return_value = mock_upload_file -# return file_service - - -# class TestSaveWorkflowNodeExecutionDataTask: -# """Test cases for save_workflow_node_execution_data_task.""" - -# @patch("tasks.workflow_node_execution_tasks.sessionmaker") -# @patch("tasks.workflow_node_execution_tasks.select") -# def test_save_execution_data_task_success( -# self, mock_select, mock_sessionmaker, sample_execution_data, mock_db_model -# ): -# """Test successful execution of save_workflow_node_execution_data_task.""" -# # Setup mocks -# mock_session = MagicMock() -# mock_sessionmaker.return_value.return_value.__enter__.return_value = mock_session -# mock_session.execute.return_value.scalars.return_value.first.return_value = mock_db_model - -# # Execute task -# result = save_workflow_node_execution_data_task( -# execution_data=sample_execution_data, -# tenant_id="test-tenant-id", -# app_id="test-app-id", -# user_data={"user_id": "test-user-id", "user_type": "account"}, -# ) - -# # Verify success -# assert result is True -# mock_session.merge.assert_called_once_with(mock_db_model) -# mock_session.commit.assert_called_once() - -# @patch("tasks.workflow_node_execution_tasks.sessionmaker") -# @patch("tasks.workflow_node_execution_tasks.select") -# def test_save_execution_data_task_execution_not_found(self, mock_select, mock_sessionmaker, -# sample_execution_data): -# """Test task when execution is not found in database.""" -# # Setup mocks -# mock_session = MagicMock() -# mock_sessionmaker.return_value.return_value.__enter__.return_value = mock_session -# mock_session.execute.return_value.scalars.return_value.first.return_value = None - -# # Execute task -# result = save_workflow_node_execution_data_task( -# execution_data=sample_execution_data, -# tenant_id="test-tenant-id", -# app_id="test-app-id", -# user_data={"user_id": "test-user-id", "user_type": "account"}, -# ) - -# # Verify failure -# assert result is False -# mock_session.merge.assert_not_called() -# mock_session.commit.assert_not_called() - -# @patch("tasks.workflow_node_execution_tasks.sessionmaker") -# @patch("tasks.workflow_node_execution_tasks.select") -# def test_save_execution_data_task_with_truncation(self, mock_select, mock_sessionmaker, mock_db_model): -# """Test task with data that requires truncation.""" -# # Create execution with large data -# large_data = {"large_field": "x" * 10000} -# execution = WorkflowNodeExecution( -# id=str(uuid4()), -# node_execution_id=str(uuid4()), -# workflow_id=str(uuid4()), -# workflow_execution_id=str(uuid4()), -# index=1, -# node_id="test_node", -# node_type=NodeType.LLM, -# title="Test Node", -# inputs=large_data, -# outputs=large_data, -# process_data=large_data, -# status=WorkflowNodeExecutionStatus.RUNNING, -# created_at=naive_utc_now(), -# ) -# execution_data = execution.model_dump() - -# # Setup mocks -# mock_session = MagicMock() -# mock_sessionmaker.return_value.return_value.__enter__.return_value = mock_session -# mock_session.execute.return_value.scalars.return_value.first.return_value = mock_db_model - -# # Create mock upload file -# mock_upload_file = Mock(spec=UploadFile) -# mock_upload_file.id = "mock-file-id" - -# # Execute task -# with patch("tasks.workflow_node_execution_tasks._truncate_and_upload_async") as mock_truncate: -# # Mock truncation results -# mock_truncate.return_value = { -# "truncated_value": {"large_field": "[TRUNCATED]"}, -# "file": mock_upload_file, -# "offload": WorkflowNodeExecutionOffload( -# id=str(uuid4()), -# tenant_id="test-tenant-id", -# app_id="test-app-id", -# node_execution_id=execution.id, -# type_=ExecutionOffLoadType.INPUTS, -# file_id=mock_upload_file.id, -# ), -# } - -# result = save_workflow_node_execution_data_task( -# execution_data=execution_data, -# tenant_id="test-tenant-id", -# app_id="test-app-id", -# user_data={"user_id": "test-user-id", "user_type": "account"}, -# ) - -# # Verify success and truncation was called -# assert result is True -# assert mock_truncate.call_count == 3 # inputs, outputs, process_data -# mock_session.merge.assert_called_once_with(mock_db_model) -# mock_session.commit.assert_called_once() - -# @patch("tasks.workflow_node_execution_tasks.sessionmaker") -# def test_save_execution_data_task_retry_on_exception(self, mock_sessionmaker, sample_execution_data): -# """Test task retry mechanism on exception.""" -# # Setup mock to raise exception -# mock_sessionmaker.side_effect = Exception("Database error") - -# # Create a mock task instance with proper retry behavior -# with patch.object(save_workflow_node_execution_data_task, "retry") as mock_retry: -# mock_retry.side_effect = Exception("Retry called") - -# # Execute task and expect retry -# with pytest.raises(Exception, match="Retry called"): -# save_workflow_node_execution_data_task( -# execution_data=sample_execution_data, -# tenant_id="test-tenant-id", -# app_id="test-app-id", -# user_data={"user_id": "test-user-id", "user_type": "account"}, -# ) - -# # Verify retry was called -# mock_retry.assert_called_once() - - -# class TestTruncateAndUploadAsync: -# """Test cases for _truncate_and_upload_async function.""" - -# def test_truncate_and_upload_with_none_values(self, mock_file_service): -# """Test _truncate_and_upload_async with None values.""" -# # The function handles None values internally, so we test with empty dict instead -# result = _truncate_and_upload_async( -# values={}, -# execution_id="test-id", -# type_=ExecutionOffLoadType.INPUTS, -# tenant_id="test-tenant", -# app_id="test-app", -# user_data={"user_id": "test-user", "user_type": "account"}, -# file_service=mock_file_service, -# ) - -# # Empty dict should not require truncation -# assert result is None -# mock_file_service.upload_file.assert_not_called() - -# @patch("tasks.workflow_node_execution_tasks._create_truncator") -# def test_truncate_and_upload_no_truncation_needed(self, mock_create_truncator, mock_file_service): -# """Test _truncate_and_upload_async when no truncation is needed.""" -# # Mock truncator to return no truncation -# mock_truncator = Mock() -# mock_truncator.truncate_variable_mapping.return_value = ({"small": "data"}, False) -# mock_create_truncator.return_value = mock_truncator - -# small_values = {"small": "data"} -# result = _truncate_and_upload_async( -# values=small_values, -# execution_id="test-id", -# type_=ExecutionOffLoadType.INPUTS, -# tenant_id="test-tenant", -# app_id="test-app", -# user_data={"user_id": "test-user", "user_type": "account"}, -# file_service=mock_file_service, -# ) - -# assert result is None -# mock_file_service.upload_file.assert_not_called() - -# @patch("tasks.workflow_node_execution_tasks._create_truncator") -# @patch("models.Account") -# @patch("models.Tenant") -# def test_truncate_and_upload_with_account_user( -# self, mock_tenant_class, mock_account_class, mock_create_truncator, mock_file_service -# ): -# """Test _truncate_and_upload_async with account user.""" -# # Mock truncator to return truncation needed -# mock_truncator = Mock() -# mock_truncator.truncate_variable_mapping.return_value = ({"truncated": "data"}, True) -# mock_create_truncator.return_value = mock_truncator - -# # Mock user and tenant creation -# mock_account = Mock() -# mock_account.id = "test-user" -# mock_account_class.return_value = mock_account - -# mock_tenant = Mock() -# mock_tenant.id = "test-tenant" -# mock_tenant_class.return_value = mock_tenant - -# large_values = {"large": "x" * 10000} -# result = _truncate_and_upload_async( -# values=large_values, -# execution_id="test-id", -# type_=ExecutionOffLoadType.INPUTS, -# tenant_id="test-tenant", -# app_id="test-app", -# user_data={"user_id": "test-user", "user_type": "account"}, -# file_service=mock_file_service, -# ) - -# # Verify result structure -# assert result is not None -# assert "truncated_value" in result -# assert "file" in result -# assert "offload" in result -# assert result["truncated_value"] == {"truncated": "data"} - -# # Verify file upload was called -# mock_file_service.upload_file.assert_called_once() -# upload_call = mock_file_service.upload_file.call_args -# assert upload_call[1]["filename"] == "node_execution_test-id_inputs.json" -# assert upload_call[1]["mimetype"] == "application/json" -# assert upload_call[1]["user"] == mock_account - -# @patch("tasks.workflow_node_execution_tasks._create_truncator") -# @patch("models.EndUser") -# def test_truncate_and_upload_with_end_user(self, mock_end_user_class, mock_create_truncator, mock_file_service): -# """Test _truncate_and_upload_async with end user.""" -# # Mock truncator to return truncation needed -# mock_truncator = Mock() -# mock_truncator.truncate_variable_mapping.return_value = ({"truncated": "data"}, True) -# mock_create_truncator.return_value = mock_truncator - -# # Mock end user creation -# mock_end_user = Mock() -# mock_end_user.id = "test-user" -# mock_end_user.tenant_id = "test-tenant" -# mock_end_user_class.return_value = mock_end_user - -# large_values = {"large": "x" * 10000} -# result = _truncate_and_upload_async( -# values=large_values, -# execution_id="test-id", -# type_=ExecutionOffLoadType.OUTPUTS, -# tenant_id="test-tenant", -# app_id="test-app", -# user_data={"user_id": "test-user", "user_type": "end_user"}, -# file_service=mock_file_service, -# ) - -# # Verify result structure -# assert result is not None -# assert result["truncated_value"] == {"truncated": "data"} - -# # Verify file upload was called with end user -# mock_file_service.upload_file.assert_called_once() -# upload_call = mock_file_service.upload_file.call_args -# assert upload_call[1]["filename"] == "node_execution_test-id_outputs.json" -# assert upload_call[1]["user"] == mock_end_user - - -# class TestHelperFunctions: -# """Test cases for helper functions.""" - -# @patch("tasks.workflow_node_execution_tasks.dify_config") -# def test_create_truncator(self, mock_config): -# """Test _create_truncator function.""" -# mock_config.WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE = 1000 -# mock_config.WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH = 100 -# mock_config.WORKFLOW_VARIABLE_TRUNCATION_STRING_LENGTH = 500 - -# truncator = _create_truncator() - -# # Verify truncator was created with correct config -# assert truncator is not None - -# def test_json_encode(self): -# """Test _json_encode function.""" -# test_data = {"key": "value", "number": 42} -# result = _json_encode(test_data) - -# assert isinstance(result, str) -# decoded = json.loads(result) -# assert decoded == test_data - -# def test_replace_or_append_offload_replace_existing(self): -# """Test _replace_or_append_offload replaces existing offload of same type.""" -# existing_offload = WorkflowNodeExecutionOffload( -# id=str(uuid4()), -# tenant_id="test-tenant", -# app_id="test-app", -# node_execution_id="test-execution", -# type_=ExecutionOffLoadType.INPUTS, -# file_id="old-file-id", -# ) - -# new_offload = WorkflowNodeExecutionOffload( -# id=str(uuid4()), -# tenant_id="test-tenant", -# app_id="test-app", -# node_execution_id="test-execution", -# type_=ExecutionOffLoadType.INPUTS, -# file_id="new-file-id", -# ) - -# result = _replace_or_append_offload([existing_offload], new_offload) - -# assert len(result) == 1 -# assert result[0].file_id == "new-file-id" - -# def test_replace_or_append_offload_append_new_type(self): -# """Test _replace_or_append_offload appends new offload of different type.""" -# existing_offload = WorkflowNodeExecutionOffload( -# id=str(uuid4()), -# tenant_id="test-tenant", -# app_id="test-app", -# node_execution_id="test-execution", -# type_=ExecutionOffLoadType.INPUTS, -# file_id="inputs-file-id", -# ) - -# new_offload = WorkflowNodeExecutionOffload( -# id=str(uuid4()), -# tenant_id="test-tenant", -# app_id="test-app", -# node_execution_id="test-execution", -# type_=ExecutionOffLoadType.OUTPUTS, -# file_id="outputs-file-id", -# ) - -# result = _replace_or_append_offload([existing_offload], new_offload) - -# assert len(result) == 2 -# file_ids = [offload.file_id for offload in result] -# assert "inputs-file-id" in file_ids -# assert "outputs-file-id" in file_ids - - -# class TestSaveWorkflowNodeExecutionTask: -# """Test cases for save_workflow_node_execution_task.""" - -# @patch("tasks.workflow_node_execution_tasks.sessionmaker") -# @patch("tasks.workflow_node_execution_tasks.select") -# def test_save_workflow_node_execution_task_create_new(self, mock_select, mock_sessionmaker, -# sample_execution_data): -# """Test creating a new workflow node execution.""" -# # Setup mocks -# mock_session = MagicMock() -# mock_sessionmaker.return_value.return_value.__enter__.return_value = mock_session -# mock_session.scalar.return_value = None # No existing execution - -# # Execute task -# result = save_workflow_node_execution_task( -# execution_data=sample_execution_data, -# tenant_id="test-tenant-id", -# app_id="test-app-id", -# triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, -# creator_user_id="test-user-id", -# creator_user_role="account", -# ) - -# # Verify success -# assert result is True -# mock_session.add.assert_called_once() -# mock_session.commit.assert_called_once() - -# @patch("tasks.workflow_node_execution_tasks.sessionmaker") -# @patch("tasks.workflow_node_execution_tasks.select") -# def test_save_workflow_node_execution_task_update_existing( -# self, mock_select, mock_sessionmaker, sample_execution_data -# ): -# """Test updating an existing workflow node execution.""" -# # Setup mocks -# mock_session = MagicMock() -# mock_sessionmaker.return_value.return_value.__enter__.return_value = mock_session - -# existing_execution = Mock(spec=WorkflowNodeExecutionModel) -# mock_session.scalar.return_value = existing_execution - -# # Execute task -# result = save_workflow_node_execution_task( -# execution_data=sample_execution_data, -# tenant_id="test-tenant-id", -# app_id="test-app-id", -# triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, -# creator_user_id="test-user-id", -# creator_user_role="account", -# ) - -# # Verify success -# assert result is True -# mock_session.add.assert_not_called() # Should not add new, just update existing -# mock_session.commit.assert_called_once() - -# @patch("tasks.workflow_node_execution_tasks.sessionmaker") -# def test_save_workflow_node_execution_task_retry_on_exception(self, mock_sessionmaker, sample_execution_data): -# """Test task retry mechanism on exception.""" -# # Setup mock to raise exception -# mock_sessionmaker.side_effect = Exception("Database error") - -# # Create a mock task instance with proper retry behavior -# with patch.object(save_workflow_node_execution_task, "retry") as mock_retry: -# mock_retry.side_effect = Exception("Retry called") - -# # Execute task and expect retry -# with pytest.raises(Exception, match="Retry called"): -# save_workflow_node_execution_task( -# execution_data=sample_execution_data, -# tenant_id="test-tenant-id", -# app_id="test-app-id", -# triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, -# creator_user_id="test-user-id", -# creator_user_role="account", -# ) - -# # Verify retry was called -# mock_retry.assert_called_once() diff --git a/api/ty.toml b/api/ty.toml index 3d3dda4595..380e14dbef 100644 --- a/api/ty.toml +++ b/api/ty.toml @@ -1,16 +1,15 @@ [src] exclude = [ # deps groups (A1/A2/B/C/D/E) - # A2: workflow engine/nodes - "core/workflow", - "core/app/workflow", - "core/helper/code_executor", # B: app runner + prompt "core/prompt", "core/app/apps/base_app_runner.py", "core/app/apps/workflow_app_runner.py", + "core/agent", + "core/plugin", # C: services/controllers/fields/libs "services", + "controllers/inner_api", "controllers/console/app", "controllers/console/explore", "controllers/console/datasets", @@ -26,21 +25,10 @@ exclude = [ # non-producition or generated code "migrations", "tests", - # targeted ignores for current type-check errors - # TODO(QuantumGhost): suppress type errors in HITL related code. - # fix the type error later - "configs/middleware/cache/redis_pubsub_config.py", - "extensions/ext_redis.py", - "models/execution_extra_content.py", - "tasks/workflow_execution_tasks.py", - "core/workflow/nodes/base/node.py", - "services/human_input_delivery_test_service.py", - "core/app/apps/advanced_chat/app_generator.py", - "controllers/console/human_input_form.py", - "controllers/console/app/workflow_run.py", - "repositories/sqlalchemy_api_workflow_node_execution_repository.py", - "extensions/logstore/repositories/logstore_api_workflow_run_repository.py", - "controllers/web/workflow_events.py", - "tasks/app_generate/workflow_execute_task.py", ] + +[rules] +deprecated = "ignore" +unused-ignore-comment = "ignore" +# possibly-missing-attribute = "ignore" \ No newline at end of file diff --git a/api/uv.lock b/api/uv.lock index a3ad292168..f253976cc1 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1589,7 +1589,7 @@ requires-dist = [ { name = "flask-login", specifier = "~=0.6.3" }, { name = "flask-migrate", specifier = "~=4.0.7" }, { name = "flask-orjson", specifier = "~=2.0.0" }, - { name = "flask-restx", specifier = "~=1.3.0" }, + { name = "flask-restx", specifier = "~=1.3.2" }, { name = "flask-sqlalchemy", specifier = "~=3.1.1" }, { name = "gevent", specifier = "~=25.9.1" }, { name = "gmpy2", specifier = "~=2.2.1" }, @@ -1684,7 +1684,7 @@ dev = [ { name = "scipy-stubs", specifier = ">=1.15.3.0" }, { name = "sseclient-py", specifier = ">=1.8.0" }, { name = "testcontainers", specifier = "~=4.13.2" }, - { name = "ty", specifier = "~=0.0.1a19" }, + { name = "ty", specifier = ">=0.0.14" }, { name = "types-aiofiles", specifier = "~=24.1.0" }, { name = "types-beautifulsoup4", specifier = "~=4.12.0" }, { name = "types-cachetools", specifier = "~=5.5.0" }, @@ -1707,7 +1707,7 @@ dev = [ { name = "types-openpyxl", specifier = "~=3.1.5" }, { name = "types-pexpect", specifier = "~=4.9.0" }, { name = "types-protobuf", specifier = "~=5.29.1" }, - { name = "types-psutil", specifier = "~=7.0.0" }, + { name = "types-psutil", specifier = "~=7.2.2" }, { name = "types-psycopg2", specifier = "~=2.9.21" }, { name = "types-pygments", specifier = "~=2.19.0" }, { name = "types-pymysql", specifier = "~=1.1.0" }, @@ -6239,27 +6239,26 @@ wheels = [ [[package]] name = "ty" -version = "0.0.1a27" +version = "0.0.14" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/8f/65/3592d7c73d80664378fc90d0a00c33449a99cbf13b984433c883815245f3/ty-0.0.1a27.tar.gz", hash = "sha256:d34fe04979f2c912700cbf0919e8f9b4eeaa10c4a2aff7450e5e4c90f998bc28", size = 4516059, upload-time = "2025-11-18T21:55:18.381Z" } +sdist = { url = "https://files.pythonhosted.org/packages/af/57/22c3d6bf95c2229120c49ffc2f0da8d9e8823755a1c3194da56e51f1cc31/ty-0.0.14.tar.gz", hash = "sha256:a691010565f59dd7f15cf324cdcd1d9065e010c77a04f887e1ea070ba34a7de2", size = 5036573, upload-time = "2026-01-27T00:57:31.427Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/e6/05/7945aa97356446fd53ed3ddc7ee02a88d8ad394217acd9428f472d6b109d/ty-0.0.1a27-py3-none-linux_armv6l.whl", hash = "sha256:3cbb735f5ecb3a7a5f5b82fb24da17912788c109086df4e97d454c8fb236fbc5", size = 9375047, upload-time = "2025-11-18T21:54:31.577Z" }, - { url = "https://files.pythonhosted.org/packages/69/4e/89b167a03de0e9ec329dc89bc02e8694768e4576337ef6c0699987681342/ty-0.0.1a27-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:4a6367236dc456ba2416563301d498aef8c6f8959be88777ef7ba5ac1bf15f0b", size = 9169540, upload-time = "2025-11-18T21:54:34.036Z" }, - { url = "https://files.pythonhosted.org/packages/38/07/e62009ab9cc242e1becb2bd992097c80a133fce0d4f055fba6576150d08a/ty-0.0.1a27-py3-none-macosx_11_0_arm64.whl", hash = "sha256:8e93e231a1bcde964cdb062d2d5e549c24493fb1638eecae8fcc42b81e9463a4", size = 8711942, upload-time = "2025-11-18T21:54:36.3Z" }, - { url = "https://files.pythonhosted.org/packages/b5/43/f35716ec15406f13085db52e762a3cc663c651531a8124481d0ba602eca0/ty-0.0.1a27-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c5b6a8166b60117da1179851a3d719cc798bf7e61f91b35d76242f0059e9ae1d", size = 8984208, upload-time = "2025-11-18T21:54:39.453Z" }, - { url = "https://files.pythonhosted.org/packages/2d/79/486a3374809523172379768de882c7a369861165802990177fe81489b85f/ty-0.0.1a27-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:bfbe8b0e831c072b79a078d6c126d7f4d48ca17f64a103de1b93aeda32265dc5", size = 9157209, upload-time = "2025-11-18T21:54:42.664Z" }, - { url = "https://files.pythonhosted.org/packages/ff/08/9a7c8efcb327197d7d347c548850ef4b54de1c254981b65e8cd0672dc327/ty-0.0.1a27-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:90e09678331552e7c25d7eb47868b0910dc5b9b212ae22c8ce71a52d6576ddbb", size = 9519207, upload-time = "2025-11-18T21:54:45.311Z" }, - { url = "https://files.pythonhosted.org/packages/e0/9d/7b4680683e83204b9edec551bb91c21c789ebc586b949c5218157ee474b7/ty-0.0.1a27-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:88c03e4beeca79d85a5618921e44b3a6ea957e0453e08b1cdd418b51da645939", size = 10148794, upload-time = "2025-11-18T21:54:48.329Z" }, - { url = "https://files.pythonhosted.org/packages/89/21/8b961b0ab00c28223f06b33222427a8e31aa04f39d1b236acc93021c626c/ty-0.0.1a27-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3ece5811322789fefe22fc088ed36c5879489cd39e913f9c1ff2a7678f089c61", size = 9900563, upload-time = "2025-11-18T21:54:51.214Z" }, - { url = "https://files.pythonhosted.org/packages/85/eb/95e1f0b426c2ea8d443aa923fcab509059c467bbe64a15baaf573fea1203/ty-0.0.1a27-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2f2ccb4f0fddcd6e2017c268dfce2489e9a36cb82a5900afe6425835248b1086", size = 9926355, upload-time = "2025-11-18T21:54:53.927Z" }, - { url = "https://files.pythonhosted.org/packages/f5/78/40e7f072049e63c414f2845df780be3a494d92198c87c2ffa65e63aecf3f/ty-0.0.1a27-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:33450528312e41d003e96a1647780b2783ab7569bbc29c04fc76f2d1908061e3", size = 9480580, upload-time = "2025-11-18T21:54:56.617Z" }, - { url = "https://files.pythonhosted.org/packages/18/da/f4a2dfedab39096808ddf7475f35ceb750d9a9da840bee4afd47b871742f/ty-0.0.1a27-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:a0a9ac635deaa2b15947701197ede40cdecd13f89f19351872d16f9ccd773fa1", size = 8957524, upload-time = "2025-11-18T21:54:59.085Z" }, - { url = "https://files.pythonhosted.org/packages/21/ea/26fee9a20cf77a157316fd3ab9c6db8ad5a0b20b2d38a43f3452622587ac/ty-0.0.1a27-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:797fb2cd49b6b9b3ac9f2f0e401fb02d3aa155badc05a8591d048d38d28f1e0c", size = 9201098, upload-time = "2025-11-18T21:55:01.845Z" }, - { url = "https://files.pythonhosted.org/packages/b0/53/e14591d1275108c9ae28f97ac5d4b93adcc2c8a4b1b9a880dfa9d07c15f8/ty-0.0.1a27-py3-none-musllinux_1_2_i686.whl", hash = "sha256:7fe81679a0941f85e98187d444604e24b15bde0a85874957c945751756314d03", size = 9275470, upload-time = "2025-11-18T21:55:04.23Z" }, - { url = "https://files.pythonhosted.org/packages/37/44/e2c9acecac70bf06fb41de285e7be2433c2c9828f71e3bf0e886fc85c4fd/ty-0.0.1a27-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:355f651d0cdb85535a82bd9f0583f77b28e3fd7bba7b7da33dcee5a576eff28b", size = 9592394, upload-time = "2025-11-18T21:55:06.542Z" }, - { url = "https://files.pythonhosted.org/packages/ee/a7/4636369731b24ed07c2b4c7805b8d990283d677180662c532d82e4ef1a36/ty-0.0.1a27-py3-none-win32.whl", hash = "sha256:61782e5f40e6df622093847b34c366634b75d53f839986f1bf4481672ad6cb55", size = 8783816, upload-time = "2025-11-18T21:55:09.648Z" }, - { url = "https://files.pythonhosted.org/packages/a7/1d/b76487725628d9e81d9047dc0033a5e167e0d10f27893d04de67fe1a9763/ty-0.0.1a27-py3-none-win_amd64.whl", hash = "sha256:c682b238085d3191acddcf66ef22641562946b1bba2a7f316012d5b2a2f4de11", size = 9616833, upload-time = "2025-11-18T21:55:12.457Z" }, - { url = "https://files.pythonhosted.org/packages/3a/db/c7cd5276c8f336a3cf87992b75ba9d486a7cf54e753fcd42495b3bc56fb7/ty-0.0.1a27-py3-none-win_arm64.whl", hash = "sha256:e146dfa32cbb0ac6afb0cb65659e87e4e313715e68d76fe5ae0a4b3d5b912ce8", size = 9137796, upload-time = "2025-11-18T21:55:15.897Z" }, + { url = "https://files.pythonhosted.org/packages/99/cb/cc6d1d8de59beb17a41f9a614585f884ec2d95450306c173b3b7cc090d2e/ty-0.0.14-py3-none-linux_armv6l.whl", hash = "sha256:32cf2a7596e693094621d3ae568d7ee16707dce28c34d1762947874060fdddaa", size = 10034228, upload-time = "2026-01-27T00:57:53.133Z" }, + { url = "https://files.pythonhosted.org/packages/f3/96/dd42816a2075a8f31542296ae687483a8d047f86a6538dfba573223eaf9a/ty-0.0.14-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:f971bf9805f49ce8c0968ad53e29624d80b970b9eb597b7cbaba25d8a18ce9a2", size = 9939162, upload-time = "2026-01-27T00:57:43.857Z" }, + { url = "https://files.pythonhosted.org/packages/ff/b4/73c4859004e0f0a9eead9ecb67021438b2e8e5fdd8d03e7f5aca77623992/ty-0.0.14-py3-none-macosx_11_0_arm64.whl", hash = "sha256:45448b9e4806423523268bc15e9208c4f3f2ead7c344f615549d2e2354d6e924", size = 9418661, upload-time = "2026-01-27T00:58:03.411Z" }, + { url = "https://files.pythonhosted.org/packages/58/35/839c4551b94613db4afa20ee555dd4f33bfa7352d5da74c5fa416ffa0fd2/ty-0.0.14-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ee94a9b747ff40114085206bdb3205a631ef19a4d3fb89e302a88754cbbae54c", size = 9837872, upload-time = "2026-01-27T00:57:23.718Z" }, + { url = "https://files.pythonhosted.org/packages/41/2b/bbecf7e2faa20c04bebd35fc478668953ca50ee5847ce23e08acf20ea119/ty-0.0.14-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6756715a3c33182e9ab8ffca2bb314d3c99b9c410b171736e145773ee0ae41c3", size = 9848819, upload-time = "2026-01-27T00:57:58.501Z" }, + { url = "https://files.pythonhosted.org/packages/be/60/3c0ba0f19c0f647ad9d2b5b5ac68c0f0b4dc899001bd53b3a7537fb247a2/ty-0.0.14-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:89d0038a2f698ba8b6fec5cf216a4e44e2f95e4a5095a8c0f57fe549f87087c2", size = 10324371, upload-time = "2026-01-27T00:57:29.291Z" }, + { url = "https://files.pythonhosted.org/packages/24/32/99d0a0b37d0397b0a989ffc2682493286aa3bc252b24004a6714368c2c3d/ty-0.0.14-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2c64a83a2d669b77f50a4957039ca1450626fb474619f18f6f8a3eb885bf7544", size = 10865898, upload-time = "2026-01-27T00:57:33.542Z" }, + { url = "https://files.pythonhosted.org/packages/1a/88/30b583a9e0311bb474269cfa91db53350557ebec09002bfc3fb3fc364e8c/ty-0.0.14-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:242488bfb547ef080199f6fd81369ab9cb638a778bb161511d091ffd49c12129", size = 10555777, upload-time = "2026-01-27T00:58:05.853Z" }, + { url = "https://files.pythonhosted.org/packages/cd/a2/cb53fb6325dcf3d40f2b1d0457a25d55bfbae633c8e337bde8ec01a190eb/ty-0.0.14-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4790c3866f6c83a4f424fc7d09ebdb225c1f1131647ba8bdc6fcdc28f09ed0ff", size = 10412913, upload-time = "2026-01-27T00:57:38.834Z" }, + { url = "https://files.pythonhosted.org/packages/42/8f/f2f5202d725ed1e6a4e5ffaa32b190a1fe70c0b1a2503d38515da4130b4c/ty-0.0.14-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:950f320437f96d4ea9a2332bbfb5b68f1c1acd269ebfa4c09b6970cc1565bd9d", size = 9837608, upload-time = "2026-01-27T00:57:55.898Z" }, + { url = "https://files.pythonhosted.org/packages/f7/ba/59a2a0521640c489dafa2c546ae1f8465f92956fede18660653cce73b4c5/ty-0.0.14-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:4a0ec3ee70d83887f86925bbc1c56f4628bd58a0f47f6f32ddfe04e1f05466df", size = 9884324, upload-time = "2026-01-27T00:57:46.786Z" }, + { url = "https://files.pythonhosted.org/packages/03/95/8d2a49880f47b638743212f011088552ecc454dd7a665ddcbdabea25772a/ty-0.0.14-py3-none-musllinux_1_2_i686.whl", hash = "sha256:a1a4e6b6da0c58b34415955279eff754d6206b35af56a18bb70eb519d8d139ef", size = 10033537, upload-time = "2026-01-27T00:58:01.149Z" }, + { url = "https://files.pythonhosted.org/packages/e9/40/4523b36f2ce69f92ccf783855a9e0ebbbd0f0bb5cdce6211ee1737159ed3/ty-0.0.14-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:dc04384e874c5de4c5d743369c277c8aa73d1edea3c7fc646b2064b637db4db3", size = 10495910, upload-time = "2026-01-27T00:57:26.691Z" }, + { url = "https://files.pythonhosted.org/packages/08/d5/655beb51224d1bfd4f9ddc0bb209659bfe71ff141bcf05c418ab670698f0/ty-0.0.14-py3-none-win32.whl", hash = "sha256:b20e22cf54c66b3e37e87377635da412d9a552c9bf4ad9fc449fed8b2e19dad2", size = 9507626, upload-time = "2026-01-27T00:57:41.43Z" }, + { url = "https://files.pythonhosted.org/packages/b6/d9/c569c9961760e20e0a4bc008eeb1415754564304fd53997a371b7cf3f864/ty-0.0.14-py3-none-win_amd64.whl", hash = "sha256:e312ff9475522d1a33186657fe74d1ec98e4a13e016d66f5758a452c90ff6409", size = 10437980, upload-time = "2026-01-27T00:57:36.422Z" }, + { url = "https://files.pythonhosted.org/packages/ad/0c/186829654f5bfd9a028f6648e9caeb11271960a61de97484627d24443f91/ty-0.0.14-py3-none-win_arm64.whl", hash = "sha256:b6facdbe9b740cb2c15293a1d178e22ffc600653646452632541d01c36d5e378", size = 9885831, upload-time = "2026-01-27T00:57:49.747Z" }, ] [[package]] @@ -6509,11 +6508,11 @@ wheels = [ [[package]] name = "types-psutil" -version = "7.0.0.20251116" +version = "7.2.2.20260130" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/47/ec/c1e9308b91582cad1d7e7d3007fd003ef45a62c2500f8219313df5fc3bba/types_psutil-7.0.0.20251116.tar.gz", hash = "sha256:92b5c78962e55ce1ed7b0189901a4409ece36ab9fd50c3029cca7e681c606c8a", size = 22192, upload-time = "2025-11-16T03:10:32.859Z" } +sdist = { url = "https://files.pythonhosted.org/packages/69/14/fc5fb0a6ddfadf68c27e254a02ececd4d5c7fdb0efcb7e7e917a183497fb/types_psutil-7.2.2.20260130.tar.gz", hash = "sha256:15b0ab69c52841cf9ce3c383e8480c620a4d13d6a8e22b16978ebddac5590950", size = 26535, upload-time = "2026-01-30T03:58:14.116Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c3/0e/11ba08a5375c21039ed5f8e6bba41e9452fb69f0e2f7ee05ed5cca2a2cdf/types_psutil-7.0.0.20251116-py3-none-any.whl", hash = "sha256:74c052de077c2024b85cd435e2cba971165fe92a5eace79cbeb821e776dbc047", size = 25376, upload-time = "2025-11-16T03:10:31.813Z" }, + { url = "https://files.pythonhosted.org/packages/17/d7/60974b7e31545d3768d1770c5fe6e093182c3bfd819429b33133ba6b3e89/types_psutil-7.2.2.20260130-py3-none-any.whl", hash = "sha256:15523a3caa7b3ff03ac7f9b78a6470a59f88f48df1d74a39e70e06d2a99107da", size = 32876, upload-time = "2026-01-30T03:58:13.172Z" }, ] [[package]] diff --git a/docker/.env.example b/docker/.env.example index 93099347bd..41a0205bf5 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -1399,9 +1399,9 @@ PLUGIN_STDIO_BUFFER_SIZE=1024 PLUGIN_STDIO_MAX_BUFFER_SIZE=5242880 PLUGIN_PYTHON_ENV_INIT_TIMEOUT=120 -# Plugin Daemon side timeout (configure to match the API side below) +# Plugin Daemon side timeout (configure to match the API side below) PLUGIN_MAX_EXECUTION_TIMEOUT=600 -# API side timeout (configure to match the Plugin Daemon side above) +# API side timeout (configure to match the Plugin Daemon side above) PLUGIN_DAEMON_TIMEOUT=600.0 # PIP_MIRROR_URL=https://pypi.tuna.tsinghua.edu.cn/simple PIP_MIRROR_URL= @@ -1519,31 +1519,4 @@ AMPLITUDE_API_KEY= SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21 SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000 SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30 - - -# Redis URL used for PubSub between API and -# celery worker -# defaults to url constructed from `REDIS_*` -# configurations -PUBSUB_REDIS_URL= -# Pub/sub channel type for streaming events. -# valid options are: -# -# - pubsub: for normal Pub/Sub -# - sharded: for sharded Pub/Sub -# -# It's highly recommended to use sharded Pub/Sub AND redis cluster -# for large deployments. -PUBSUB_REDIS_CHANNEL_TYPE=pubsub -# Whether to use Redis cluster mode while running -# PubSub. -# It's highly recommended to enable this for large deployments. -PUBSUB_REDIS_USE_CLUSTERS=false - -# Whether to Enable human input timeout check task -ENABLE_HUMAN_INPUT_TIMEOUT_TASK=true -# Human input timeout check interval in minutes -HUMAN_INPUT_TIMEOUT_TASK_INTERVAL=1 - - SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL=90000 diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index 9659990383..eb8c2b53c5 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -270,7 +270,7 @@ services: # plugin daemon plugin_daemon: - image: langgenius/dify-plugin-daemon:0.5.2-local + image: langgenius/dify-plugin-daemon:0.5.3-local restart: always environment: # Use the shared environment variables. diff --git a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml index 81c34fc6a2..4a739bbbe0 100644 --- a/docker/docker-compose.middleware.yaml +++ b/docker/docker-compose.middleware.yaml @@ -123,7 +123,7 @@ services: # plugin daemon plugin_daemon: - image: langgenius/dify-plugin-daemon:0.5.2-local + image: langgenius/dify-plugin-daemon:0.5.3-local restart: always env_file: - ./middleware.env diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index f9a254c1a6..02b8146aa9 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -683,11 +683,6 @@ x-shared-env: &shared-api-worker-env SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD: ${SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD:-21} SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE: ${SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE:-1000} SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS: ${SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS:-30} - PUBSUB_REDIS_URL: ${PUBSUB_REDIS_URL:-} - PUBSUB_REDIS_CHANNEL_TYPE: ${PUBSUB_REDIS_CHANNEL_TYPE:-pubsub} - PUBSUB_REDIS_USE_CLUSTERS: ${PUBSUB_REDIS_USE_CLUSTERS:-false} - ENABLE_HUMAN_INPUT_TIMEOUT_TASK: ${ENABLE_HUMAN_INPUT_TIMEOUT_TASK:-true} - HUMAN_INPUT_TIMEOUT_TASK_INTERVAL: ${HUMAN_INPUT_TIMEOUT_TASK_INTERVAL:-1} SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL: ${SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL:-90000} services: @@ -961,7 +956,7 @@ services: # plugin daemon plugin_daemon: - image: langgenius/dify-plugin-daemon:0.5.2-local + image: langgenius/dify-plugin-daemon:0.5.3-local restart: always environment: # Use the shared environment variables. diff --git a/web/__mocks__/provider-context.ts b/web/__mocks__/provider-context.ts index d3296bacd0..373c2f86d3 100644 --- a/web/__mocks__/provider-context.ts +++ b/web/__mocks__/provider-context.ts @@ -35,7 +35,6 @@ export const baseProviderContextValue: ProviderContextState = { refreshLicenseLimit: noop, isAllowTransferWorkspace: false, isAllowPublishAsCustomKnowledgePipelineTemplate: false, - humanInputEmailDeliveryEnabled: false, } export const createMockProviderContextValue = (overrides: Partial = {}): ProviderContextState => { diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/__tests__/svg-attribute-error-reproduction.spec.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/__tests__/svg-attribute-error-reproduction.spec.tsx index fffc1ff2a5..fc27f84c60 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/__tests__/svg-attribute-error-reproduction.spec.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/__tests__/svg-attribute-error-reproduction.spec.tsx @@ -8,6 +8,7 @@ describe('SVG Attribute Error Reproduction', () => { // Capture console errors const originalError = console.error let errorMessages: string[] = [] + beforeEach(() => { errorMessages = [] console.error = vi.fn((message) => { diff --git a/web/app/(humanInputLayout)/form/[token]/form.tsx b/web/app/(humanInputLayout)/form/[token]/form.tsx deleted file mode 100644 index d027ef8b7d..0000000000 --- a/web/app/(humanInputLayout)/form/[token]/form.tsx +++ /dev/null @@ -1,289 +0,0 @@ -'use client' -import type { ButtonProps } from '@/app/components/base/button' -import type { FormInputItem, UserAction } from '@/app/components/workflow/nodes/human-input/types' -import type { SiteInfo } from '@/models/share' -import type { HumanInputFormError } from '@/service/use-share' -import { - RiCheckboxCircleFill, - RiErrorWarningFill, - RiInformation2Fill, -} from '@remixicon/react' -import { produce } from 'immer' -import { useParams } from 'next/navigation' -import * as React from 'react' -import { useEffect, useMemo, useState } from 'react' -import { useTranslation } from 'react-i18next' -import AppIcon from '@/app/components/base/app-icon' -import Button from '@/app/components/base/button' -import ContentItem from '@/app/components/base/chat/chat/answer/human-input-content/content-item' -import ExpirationTime from '@/app/components/base/chat/chat/answer/human-input-content/expiration-time' -import { getButtonStyle } from '@/app/components/base/chat/chat/answer/human-input-content/utils' -import Loading from '@/app/components/base/loading' -import DifyLogo from '@/app/components/base/logo/dify-logo' -import useDocumentTitle from '@/hooks/use-document-title' -import { useGetHumanInputForm, useSubmitHumanInputForm } from '@/service/use-share' -import { cn } from '@/utils/classnames' - -export type FormData = { - site: { site: SiteInfo } - form_content: string - inputs: FormInputItem[] - resolved_default_values: Record - user_actions: UserAction[] - expiration_time: number -} - -const FormContent = () => { - const { t } = useTranslation() - - const { token } = useParams<{ token: string }>() - useDocumentTitle('') - - const [inputs, setInputs] = useState>({}) - const [success, setSuccess] = useState(false) - - const { mutate: submitForm, isPending: isSubmitting } = useSubmitHumanInputForm() - - const { data: formData, isLoading, error } = useGetHumanInputForm(token) - - const expired = (error as HumanInputFormError | null)?.code === 'human_input_form_expired' - const submitted = (error as HumanInputFormError | null)?.code === 'human_input_form_submitted' - const rateLimitExceeded = (error as HumanInputFormError | null)?.code === 'web_form_rate_limit_exceeded' - - const splitByOutputVar = (content: string): string[] => { - const outputVarRegex = /(\{\{#\$output\.[^#]+#\}\})/g - const parts = content.split(outputVarRegex) - return parts.filter(part => part.length > 0) - } - - const contentList = useMemo(() => { - if (!formData?.form_content) - return [] - return splitByOutputVar(formData.form_content) - }, [formData?.form_content]) - - useEffect(() => { - if (!formData?.inputs) - return - const initialInputs: Record = {} - formData.inputs.forEach((item) => { - initialInputs[item.output_variable_name] = item.default.type === 'variable' ? formData.resolved_default_values[item.output_variable_name] || '' : item.default.value - }) - setInputs(initialInputs) - }, [formData?.inputs, formData?.resolved_default_values]) - - // use immer - const handleInputsChange = (name: string, value: string) => { - const newInputs = produce(inputs, (draft) => { - draft[name] = value - }) - setInputs(newInputs) - } - - const submit = (actionID: string) => { - submitForm( - { token, data: { inputs, action: actionID } }, - { - onSuccess: () => { - setSuccess(true) - }, - }, - ) - } - - if (isLoading) { - return ( - - ) - } - - if (success) { - return ( -
-
-
-
- -
-
-
{t('humanInput.thanks', { ns: 'share' })}
-
{t('humanInput.recorded', { ns: 'share' })}
-
-
{t('humanInput.submissionID', { id: token, ns: 'share' })}
-
-
-
-
{t('chat.poweredBy', { ns: 'share' })}
- -
-
-
-
- ) - } - - if (expired) { - return ( -
-
-
-
- -
-
-
{t('humanInput.sorry', { ns: 'share' })}
-
{t('humanInput.expired', { ns: 'share' })}
-
-
{t('humanInput.submissionID', { id: token, ns: 'share' })}
-
-
-
-
{t('chat.poweredBy', { ns: 'share' })}
- -
-
-
-
- ) - } - - if (submitted) { - return ( -
-
-
-
- -
-
-
{t('humanInput.sorry', { ns: 'share' })}
-
{t('humanInput.completed', { ns: 'share' })}
-
-
{t('humanInput.submissionID', { id: token, ns: 'share' })}
-
-
-
-
{t('chat.poweredBy', { ns: 'share' })}
- -
-
-
-
- ) - } - - if (rateLimitExceeded) { - return ( -
-
-
-
- -
-
-
{t('humanInput.rateLimitExceeded', { ns: 'share' })}
-
-
-
-
-
{t('chat.poweredBy', { ns: 'share' })}
- -
-
-
-
- ) - } - - if (!formData) { - return ( -
-
-
-
- -
-
-
{t('humanInput.formNotFound', { ns: 'share' })}
-
-
-
-
-
{t('chat.poweredBy', { ns: 'share' })}
- -
-
-
-
- ) - } - - const site = formData.site.site - - return ( -
-
- -
{site.title}
-
-
-
- {contentList.map((content, index) => ( - - ))} -
- {formData.user_actions.map((action: UserAction) => ( - - ))} -
- -
-
-
-
{t('chat.poweredBy', { ns: 'share' })}
- -
-
-
-
- ) -} - -export default React.memo(FormContent) diff --git a/web/app/(humanInputLayout)/form/[token]/page.tsx b/web/app/(humanInputLayout)/form/[token]/page.tsx deleted file mode 100644 index a7e2305b2b..0000000000 --- a/web/app/(humanInputLayout)/form/[token]/page.tsx +++ /dev/null @@ -1,13 +0,0 @@ -'use client' -import * as React from 'react' -import FormContent from './form' - -const FormPage = () => { - return ( -
- -
- ) -} - -export default React.memo(FormPage) diff --git a/web/app/(shareLayout)/components/authenticated-layout.tsx b/web/app/(shareLayout)/components/authenticated-layout.tsx index c874990448..113f3b5680 100644 --- a/web/app/(shareLayout)/components/authenticated-layout.tsx +++ b/web/app/(shareLayout)/components/authenticated-layout.tsx @@ -47,7 +47,7 @@ const AuthenticatedLayout = ({ children }: { children: React.ReactNode }) => { await webAppLogout(shareCode!) const url = getSigninUrl() router.replace(url) - }, [getSigninUrl, router, shareCode]) + }, [getSigninUrl, router, webAppLogout, shareCode]) if (appInfoError) { return ( diff --git a/web/app/(shareLayout)/components/splash.tsx b/web/app/(shareLayout)/components/splash.tsx index a2b847f74f..9f89a03993 100644 --- a/web/app/(shareLayout)/components/splash.tsx +++ b/web/app/(shareLayout)/components/splash.tsx @@ -31,7 +31,7 @@ const Splash: FC = ({ children }) => { await webAppLogout(shareCode!) const url = getSigninUrl() router.replace(url) - }, [getSigninUrl, router, shareCode]) + }, [getSigninUrl, router, webAppLogout, shareCode]) const [isLoading, setIsLoading] = useState(true) useEffect(() => { diff --git a/web/app/components/app/app-publisher/index.tsx b/web/app/components/app/app-publisher/index.tsx index 1348e3111f..0fc364cb7e 100644 --- a/web/app/components/app/app-publisher/index.tsx +++ b/web/app/components/app/app-publisher/index.tsx @@ -115,7 +115,6 @@ export type AppPublisherProps = { missingStartNode?: boolean hasTriggerNode?: boolean // Whether workflow currently contains any trigger nodes (used to hide missing-start CTA when triggers exist). startNodeLimitExceeded?: boolean - hasHumanInputNode?: boolean } const PUBLISH_SHORTCUT = ['ctrl', '⇧', 'P'] @@ -139,14 +138,13 @@ const AppPublisher = ({ missingStartNode = false, hasTriggerNode = false, startNodeLimitExceeded = false, - hasHumanInputNode = false, }: AppPublisherProps) => { const { t } = useTranslation() const [published, setPublished] = useState(false) const [open, setOpen] = useState(false) const [showAppAccessControl, setShowAppAccessControl] = useState(false) - + const [isAppAccessSet, setIsAppAccessSet] = useState(true) const [embeddingModalOpen, setEmbeddingModalOpen] = useState(false) const appDetail = useAppStore(state => state.appDetail) @@ -163,13 +161,6 @@ const AppPublisher = ({ const { data: appAccessSubjects, isLoading: isGettingAppWhiteListSubjects } = useAppWhiteListSubjects(appDetail?.id, open && systemFeatures.webapp_auth.enabled && appDetail?.access_mode === AccessMode.SPECIFIC_GROUPS_MEMBERS) const openAsyncWindow = useAsyncWindowOpen() - const isAppAccessSet = useMemo(() => { - if (appDetail && appAccessSubjects) { - return !(appDetail.access_mode === AccessMode.SPECIFIC_GROUPS_MEMBERS && appAccessSubjects.groups?.length === 0 && appAccessSubjects.members?.length === 0) - } - return true - }, [appAccessSubjects, appDetail]) - const noAccessPermission = useMemo(() => systemFeatures.webapp_auth.enabled && appDetail && appDetail.access_mode !== AccessMode.EXTERNAL_MEMBERS && !userCanAccessApp?.result, [systemFeatures, appDetail, userCanAccessApp]) const disabledFunctionButton = useMemo(() => (!publishedAt || missingStartNode || noAccessPermission), [publishedAt, missingStartNode, noAccessPermission]) @@ -180,13 +171,25 @@ const AppPublisher = ({ return t('noUserInputNode', { ns: 'app' }) if (noAccessPermission) return t('noAccessPermission', { ns: 'app' }) - }, [missingStartNode, noAccessPermission, publishedAt, t]) + }, [missingStartNode, noAccessPermission, publishedAt]) useEffect(() => { if (systemFeatures.webapp_auth.enabled && open && appDetail) refetch() }, [open, appDetail, refetch, systemFeatures]) + useEffect(() => { + if (appDetail && appAccessSubjects) { + if (appDetail.access_mode === AccessMode.SPECIFIC_GROUPS_MEMBERS && appAccessSubjects.groups?.length === 0 && appAccessSubjects.members?.length === 0) + setIsAppAccessSet(false) + else + setIsAppAccessSet(true) + } + else { + setIsAppAccessSet(true) + } + }, [appAccessSubjects, appDetail]) + const handlePublish = useCallback(async (params?: ModelAndParameter | PublishWorkflowParams) => { try { await onPublish?.(params) @@ -458,7 +461,7 @@ const AppPublisher = ({ {t('common.accessAPIReference', { ns: 'workflow' })} - {appDetail?.mode === AppModeEnum.WORKFLOW && !hasHumanInputNode && ( + {appDetail?.mode === AppModeEnum.WORKFLOW && ( { const nameInput = screen.getByPlaceholderText('app.newApp.appNamePlaceholder') fireEvent.change(nameInput, { target: { value: 'My App' } }) - fireEvent.click(screen.getByRole('button', { name: 'app.newApp.Create' })) + fireEvent.click(screen.getByRole('button', { name: /app\.newApp\.Create/ })) await waitFor(() => expect(mockCreateApp).toHaveBeenCalledWith({ name: 'My App', @@ -152,7 +152,7 @@ describe('CreateAppModal', () => { const nameInput = screen.getByPlaceholderText('app.newApp.appNamePlaceholder') fireEvent.change(nameInput, { target: { value: 'My App' } }) - fireEvent.click(screen.getByRole('button', { name: 'app.newApp.Create' })) + fireEvent.click(screen.getByRole('button', { name: /app\.newApp\.Create/ })) await waitFor(() => expect(mockCreateApp).toHaveBeenCalled()) expect(mockNotify).toHaveBeenCalledWith({ type: 'error', message: 'boom' }) diff --git a/web/app/components/app/log/list.tsx b/web/app/components/app/log/list.tsx index 40519dcb36..b13eec2e3d 100644 --- a/web/app/components/app/log/list.tsx +++ b/web/app/components/app/log/list.tsx @@ -68,7 +68,6 @@ type IDrawerContext = { } type StatusCount = { - paused: number success: number failed: number partial_success: number @@ -94,15 +93,7 @@ const statusTdRender = (statusCount: StatusCount) => { if (!statusCount) return null - if (statusCount.paused > 0) { - return ( -
- - Pending -
- ) - } - else if (statusCount.partial_success + statusCount.failed === 0) { + if (statusCount.partial_success + statusCount.failed === 0) { return (
@@ -305,7 +296,7 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) { if (abortControllerRef.current === controller) abortControllerRef.current = null } - }, [detail.id, hasMore, timezone, t, appDetail]) + }, [detail.id, hasMore, timezone, t, appDetail, detail?.model_config?.configs?.introduction]) // Derive chatItemTree, threadChatItems, and oldestAnswerIdRef from allChatItems useEffect(() => { @@ -420,7 +411,7 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) { notify({ type: 'error', message: t('actionMsg.modifiedUnsuccessfully', { ns: 'common' }) }) return false } - }, [allChatItems, appDetail?.id, notify, t]) + }, [allChatItems, appDetail?.id, t]) const fetchInitiated = useRef(false) @@ -513,7 +504,7 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) { finally { setIsLoading(false) } - }, [detail.id, hasMore, isLoading, timezone, t, appDetail]) + }, [detail.id, hasMore, isLoading, timezone, t, appDetail, detail?.model_config?.configs?.introduction]) const handleScroll = useCallback(() => { const scrollableDiv = document.getElementById('scrollableDiv') diff --git a/web/app/components/app/overview/apikey-info-panel/apikey-info-panel.test-utils.tsx b/web/app/components/app/overview/apikey-info-panel/apikey-info-panel.test-utils.tsx index 54763907df..17857ec702 100644 --- a/web/app/components/app/overview/apikey-info-panel/apikey-info-panel.test-utils.tsx +++ b/web/app/components/app/overview/apikey-info-panel/apikey-info-panel.test-utils.tsx @@ -53,7 +53,6 @@ const defaultProviderContext = { refreshLicenseLimit: noop, isAllowTransferWorkspace: false, isAllowPublishAsCustomKnowledgePipelineTemplate: false, - humanInputEmailDeliveryEnabled: false, } const defaultModalContext: ModalContextState = { diff --git a/web/app/components/app/text-generate/item/index.tsx b/web/app/components/app/text-generate/item/index.tsx index 22358805a7..c39282a022 100644 --- a/web/app/components/app/text-generate/item/index.tsx +++ b/web/app/components/app/text-generate/item/index.tsx @@ -8,7 +8,7 @@ import { RiClipboardLine, RiFileList3Line, RiPlayList2Line, - RiResetLeftLine, + RiReplay15Line, RiSparklingFill, RiSparklingLine, RiThumbDownLine, @@ -18,12 +18,10 @@ import { useBoolean } from 'ahooks' import copy from 'copy-to-clipboard' import { useParams } from 'next/navigation' import * as React from 'react' -import { useCallback, useEffect, useState } from 'react' +import { useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' import { useStore as useAppStore } from '@/app/components/app/store' import ActionButton, { ActionButtonState } from '@/app/components/base/action-button' -import HumanInputFilledFormList from '@/app/components/base/chat/chat/answer/human-input-filled-form-list' -import HumanInputFormList from '@/app/components/base/chat/chat/answer/human-input-form-list' import WorkflowProcessItem from '@/app/components/base/chat/chat/answer/workflow-process' import { useChatContext } from '@/app/components/base/chat/chat/context' import Loading from '@/app/components/base/loading' @@ -31,8 +29,7 @@ import { Markdown } from '@/app/components/base/markdown' import NewAudioButton from '@/app/components/base/new-audio-button' import Toast from '@/app/components/base/toast' import { fetchTextGenerationMessage } from '@/service/debug' -import { AppSourceType, fetchMoreLikeThis, submitHumanInputForm, updateFeedback } from '@/service/share' -import { submitHumanInputForm as submitHumanInputFormService } from '@/service/workflow' +import { AppSourceType, fetchMoreLikeThis, updateFeedback } from '@/service/share' import { cn } from '@/utils/classnames' import ResultTab from './result-tab' @@ -124,7 +121,7 @@ const GenerationItem: FC = ({ const [isQuerying, { setTrue: startQuerying, setFalse: stopQuerying }] = useBoolean(false) const childProps = { - isInWebApp, + isInWebApp: true, content: completionRes, messageId: childMessageId, depth: depth + 1, @@ -205,22 +202,16 @@ const GenerationItem: FC = ({ } const [currentTab, setCurrentTab] = useState('DETAIL') - const showResultTabs = !!workflowProcessData?.resultText || !!workflowProcessData?.files?.length || (workflowProcessData?.humanInputFormDataList && workflowProcessData?.humanInputFormDataList.length > 0) || (workflowProcessData?.humanInputFilledFormDataList && workflowProcessData?.humanInputFilledFormDataList.length > 0) + const showResultTabs = !!workflowProcessData?.resultText || !!workflowProcessData?.files?.length const switchTab = async (tab: string) => { setCurrentTab(tab) } useEffect(() => { - if (workflowProcessData?.resultText || !!workflowProcessData?.files?.length || (workflowProcessData?.humanInputFormDataList && workflowProcessData?.humanInputFormDataList.length > 0) || (workflowProcessData?.humanInputFilledFormDataList && workflowProcessData?.humanInputFilledFormDataList.length > 0)) + if (workflowProcessData?.resultText || !!workflowProcessData?.files?.length) switchTab('RESULT') else switchTab('DETAIL') - }, [workflowProcessData?.files?.length, workflowProcessData?.resultText, workflowProcessData?.humanInputFormDataList, workflowProcessData?.humanInputFilledFormDataList]) - const handleSubmitHumanInputForm = useCallback(async (formToken: string, formData: { inputs: Record, action: string }) => { - if (appSourceType === AppSourceType.installedApp) - await submitHumanInputFormService(formToken, formData) - else - await submitHumanInputForm(formToken, formData) - }, [appSourceType]) + }, [workflowProcessData?.files?.length, workflowProcessData?.resultText]) return ( <> @@ -284,24 +275,7 @@ const GenerationItem: FC = ({ )}
{!isError && ( - <> - {currentTab === 'RESULT' && workflowProcessData.humanInputFormDataList && workflowProcessData.humanInputFormDataList.length > 0 && ( -
- -
- )} - {currentTab === 'RESULT' && workflowProcessData.humanInputFilledFormDataList && workflowProcessData.humanInputFilledFormDataList.length > 0 && ( -
- -
- )} - - + )} )} @@ -374,7 +348,7 @@ const GenerationItem: FC = ({ )} {isInWebApp && isError && ( - + )} {isInWebApp && !isWorkflow && !isTryApp && ( diff --git a/web/app/components/app/workflow-log/list.tsx b/web/app/components/app/workflow-log/list.tsx index 262efad781..b9597c8ea1 100644 --- a/web/app/components/app/workflow-log/list.tsx +++ b/web/app/components/app/workflow-log/list.tsx @@ -81,14 +81,6 @@ const WorkflowAppLogList: FC = ({ logs, appDetail, onRefresh }) => { ) } - if (status === 'paused') { - return ( -
- - Pending -
- ) - } if (status === 'running') { return (
diff --git a/web/app/components/base/action-button/index.css b/web/app/components/base/action-button/index.css index 4ede34aeb5..3c1a10b86f 100644 --- a/web/app/components/base/action-button/index.css +++ b/web/app/components/base/action-button/index.css @@ -26,10 +26,6 @@ @apply p-0.5 w-6 h-6 rounded-lg } - .action-btn-s { - @apply w-5 h-5 rounded-[6px] - } - .action-btn-xs { @apply p-0 w-4 h-4 rounded } diff --git a/web/app/components/base/action-button/index.tsx b/web/app/components/base/action-button/index.tsx index d182193b00..c91d472087 100644 --- a/web/app/components/base/action-button/index.tsx +++ b/web/app/components/base/action-button/index.tsx @@ -18,7 +18,6 @@ const actionButtonVariants = cva( variants: { size: { xs: 'action-btn-xs', - s: 'action-btn-s', m: 'action-btn-m', l: 'action-btn-l', xl: 'action-btn-xl', diff --git a/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx b/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx index 304425b9a7..38a3f6c6b2 100644 --- a/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx +++ b/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx @@ -2,7 +2,6 @@ import type { FileEntity } from '../../file-uploader/types' import type { ChatConfig, ChatItem, - ChatItemInTree, OnSend, } from '../types' import { useCallback, useEffect, useMemo, useState } from 'react' @@ -17,9 +16,7 @@ import { fetchSuggestedQuestions, getUrl, stopChatMessageResponding, - submitHumanInputForm, } from '@/service/share' -import { submitHumanInputForm as submitHumanInputFormService } from '@/service/workflow' import { TransferMethod } from '@/types/app' import { cn } from '@/utils/classnames' import { formatBooleanInputs } from '@/utils/model-config' @@ -76,9 +73,9 @@ const ChatWrapper = () => { }, [appParams, currentConversationItem?.introduction]) const { chatList, + setTargetMessageId, handleSend, handleStop, - handleSwitchSibling, isResponding: respondingState, suggestedQuestions, } = useChat( @@ -125,11 +122,8 @@ const ChatWrapper = () => { if (fileIsUploading) return true - - if (chatList.some(item => item.isAnswer && item.humanInputFormDataList && item.humanInputFormDataList.length > 0)) - return true return false - }, [allInputsHidden, inputsForms, chatList, inputsFormValue]) + }, [inputsFormValue, inputsForms, allInputsHidden]) useEffect(() => { if (currentChatInstanceRef.current) @@ -140,40 +134,6 @@ const ChatWrapper = () => { setIsResponding(respondingState) }, [respondingState, setIsResponding]) - // Resume paused workflows when chat history is loaded - useEffect(() => { - if (!appPrevChatTree || appPrevChatTree.length === 0) - return - - // Find the last answer item with workflow_run_id that needs resumption (DFS - find deepest first) - let lastPausedNode: ChatItemInTree | undefined - const findLastPausedWorkflow = (nodes: ChatItemInTree[]) => { - nodes.forEach((node) => { - // DFS: recurse to children first - if (node.children && node.children.length > 0) - findLastPausedWorkflow(node.children) - - // Track the last node with humanInputFormDataList - if (node.isAnswer && node.workflow_run_id && node.humanInputFormDataList && node.humanInputFormDataList.length > 0) - lastPausedNode = node - }) - } - - findLastPausedWorkflow(appPrevChatTree) - - // Only resume the last paused workflow - if (lastPausedNode) { - handleSwitchSibling( - lastPausedNode.id, - { - onGetSuggestedQuestions: responseItemId => fetchSuggestedQuestions(responseItemId, appSourceType, appId), - onConversationComplete: currentConversationId ? undefined : handleNewConversationCompleted, - isPublicAPI: appSourceType === AppSourceType.webApp, - }, - ) - } - }, []) - const doSend: OnSend = useCallback((message, files, isRegenerate = false, parentAnswer: ChatItem | null = null) => { const data: any = { query: message, @@ -189,10 +149,10 @@ const ChatWrapper = () => { { onGetSuggestedQuestions: responseItemId => fetchSuggestedQuestions(responseItemId, appSourceType, appId), onConversationComplete: isHistoryConversation ? undefined : handleNewConversationCompleted, - isPublicAPI: appSourceType === AppSourceType.webApp, + isPublicAPI: !isInstalledApp, }, ) - }, [inputsForms, currentConversationId, currentConversationInputs, newConversationInputs, chatList, handleSend, appSourceType, appId, isHistoryConversation, handleNewConversationCompleted]) + }, [chatList, handleNewConversationCompleted, handleSend, currentConversationId, currentConversationInputs, newConversationInputs, isInstalledApp, appId]) const doRegenerate = useCallback((chatItem: ChatItem, editedQuestion?: { message: string, files?: FileEntity[] }) => { const question = editedQuestion ? chatItem : chatList.find(item => item.id === chatItem.parentMessageId)! @@ -200,27 +160,12 @@ const ChatWrapper = () => { doSend(editedQuestion ? editedQuestion.message : question.content, editedQuestion ? editedQuestion.files : question.message_files, true, isValidGeneratedAnswer(parentAnswer) ? parentAnswer : null) }, [chatList, doSend]) - const doSwitchSibling = useCallback((siblingMessageId: string) => { - handleSwitchSibling(siblingMessageId, { - onGetSuggestedQuestions: responseItemId => fetchSuggestedQuestions(responseItemId, appSourceType, appId), - onConversationComplete: currentConversationId ? undefined : handleNewConversationCompleted, - isPublicAPI: appSourceType === AppSourceType.webApp, - }) - }, [handleSwitchSibling, currentConversationId, handleNewConversationCompleted, appSourceType, appId]) - const messageList = useMemo(() => { if (currentConversationId || chatList.length > 1) return chatList // Without messages we are in the welcome screen, so hide the opening statement from chatlist return chatList.filter(item => !item.isOpeningStatement) - }, [chatList, currentConversationId]) - - const handleSubmitHumanInputForm = useCallback(async (formToken: string, formData: any) => { - if (isInstalledApp) - await submitHumanInputFormService(formToken, formData) - else - await submitHumanInputForm(formToken, formData) - }, [isInstalledApp]) + }, [chatList]) const [collapsed, setCollapsed] = useState(!!currentConversationId) @@ -329,7 +274,6 @@ const ChatWrapper = () => { inputsForm={inputsForms} onRegenerate={doRegenerate} onStopResponding={handleStop} - onHumanInputFormSubmit={handleSubmitHumanInputForm} chatNode={( <> {chatNode} @@ -342,7 +286,7 @@ const ChatWrapper = () => { answerIcon={answerIcon} hideProcessDetail themeBuilder={themeBuilder} - switchSibling={doSwitchSibling} + switchSibling={siblingMessageId => setTargetMessageId(siblingMessageId)} inputDisabled={inputDisabled} sidebarCollapseState={sidebarCollapseState} questionIcon={ diff --git a/web/app/components/base/chat/chat-with-history/hooks.tsx b/web/app/components/base/chat/chat-with-history/hooks.tsx index da344a9789..ad1de38d07 100644 --- a/web/app/components/base/chat/chat-with-history/hooks.tsx +++ b/web/app/components/base/chat/chat-with-history/hooks.tsx @@ -1,4 +1,3 @@ -import type { ExtraContent } from '../chat/type' import type { Callback, ChatConfig, @@ -10,7 +9,6 @@ import type { AppData, ConversationItem, } from '@/models/share' -import type { HumanInputFilledFormData, HumanInputFormData } from '@/types/workflow' import { useLocalStorageState } from 'ahooks' import { noop } from 'es-toolkit/function' import { produce } from 'immer' @@ -59,24 +57,6 @@ function getFormattedChatList(messages: any[]) { parentMessageId: item.parent_message_id || undefined, }) const answerFiles = item.message_files?.filter((file: any) => file.belongs_to === 'assistant') || [] - const humanInputFormDataList: HumanInputFormData[] = [] - const humanInputFilledFormDataList: HumanInputFilledFormData[] = [] - let workflowRunId = '' - if (item.status === 'paused') { - item.extra_contents?.forEach((content: ExtraContent) => { - if (content.type === 'human_input' && !content.submitted) { - humanInputFormDataList.push(content.form_definition) - workflowRunId = content.workflow_run_id - } - }) - } - else if (item.status === 'normal') { - item.extra_contents?.forEach((content: ExtraContent) => { - if (content.type === 'human_input' && content.submitted) { - humanInputFilledFormDataList.push(content.form_submission_data) - } - }) - } newChatList.push({ id: item.id, content: item.answer, @@ -86,9 +66,6 @@ function getFormattedChatList(messages: any[]) { citation: item.retriever_resources, message_files: getProcessedFilesFromResponse(answerFiles.map((item: any) => ({ ...item, related_id: item.id, upload_file_id: item.upload_file_id }))), parentMessageId: `question-${item.id}`, - humanInputFormDataList, - humanInputFilledFormDataList, - workflow_run_id: workflowRunId, }) }) return newChatList diff --git a/web/app/components/base/chat/chat/answer/human-input-content/content-item.tsx b/web/app/components/base/chat/chat/answer/human-input-content/content-item.tsx deleted file mode 100644 index 3ed777d41e..0000000000 --- a/web/app/components/base/chat/chat/answer/human-input-content/content-item.tsx +++ /dev/null @@ -1,54 +0,0 @@ -import type { ContentItemProps } from './type' -import * as React from 'react' -import { useMemo } from 'react' -import { Markdown } from '@/app/components/base/markdown' -import Textarea from '@/app/components/base/textarea' - -const ContentItem = ({ - content, - formInputFields, - inputs, - onInputChange, -}: ContentItemProps) => { - const isInputField = (field: string) => { - const outputVarRegex = /\{\{#\$output\.[^#]+#\}\}/ - return outputVarRegex.test(field) - } - - const extractFieldName = (str: string): string => { - const outputVarRegex = /\{\{#\$output\.([^#]+)#\}\}/ - const match = str.match(outputVarRegex) - return match ? match[1] : '' - } - - const fieldName = useMemo(() => { - return extractFieldName(content) - }, [content]) - - const formInputField = useMemo(() => { - return formInputFields.find(field => field.output_variable_name === fieldName) - }, [formInputFields, fieldName]) - - if (!isInputField(content)) { - return ( - - ) - } - - if (!formInputField) - return null - - return ( -
- {formInputField.type === 'paragraph' && ( -