diff --git a/.agents/skills/component-refactoring/SKILL.md b/.agents/skills/component-refactoring/SKILL.md index 7006c382c8..140e0ef434 100644 --- a/.agents/skills/component-refactoring/SKILL.md +++ b/.agents/skills/component-refactoring/SKILL.md @@ -480,4 +480,4 @@ const useButtonState = () => { ### Related Skills - `frontend-testing` - For testing refactored components -- `web/testing/testing.md` - Testing specification +- `web/docs/test.md` - Testing specification diff --git a/.agents/skills/frontend-testing/SKILL.md b/.agents/skills/frontend-testing/SKILL.md index 0716c81ef7..280fcb6341 100644 --- a/.agents/skills/frontend-testing/SKILL.md +++ b/.agents/skills/frontend-testing/SKILL.md @@ -7,7 +7,7 @@ description: Generate Vitest + React Testing Library tests for Dify frontend com This skill enables Claude to generate high-quality, comprehensive frontend tests for the Dify project following established conventions and best practices. -> **⚠️ Authoritative Source**: This skill is derived from `web/testing/testing.md`. Use Vitest mock/timer APIs (`vi.*`). +> **⚠️ Authoritative Source**: This skill is derived from `web/docs/test.md`. Use Vitest mock/timer APIs (`vi.*`). ## When to Apply This Skill @@ -309,7 +309,7 @@ For more detailed information, refer to: ### Primary Specification (MUST follow) -- **`web/testing/testing.md`** - The canonical testing specification. This skill is derived from this document. +- **`web/docs/test.md`** - The canonical testing specification. This skill is derived from this document. 
### Reference Examples in Codebase diff --git a/.agents/skills/frontend-testing/references/workflow.md b/.agents/skills/frontend-testing/references/workflow.md index 009c3e013b..bc4ed8285a 100644 --- a/.agents/skills/frontend-testing/references/workflow.md +++ b/.agents/skills/frontend-testing/references/workflow.md @@ -4,7 +4,7 @@ This guide defines the workflow for generating tests, especially for complex com ## Scope Clarification -This guide addresses **multi-file workflow** (how to process multiple test files). For coverage requirements within a single test file, see `web/testing/testing.md` § Coverage Goals. +This guide addresses **multi-file workflow** (how to process multiple test files). For coverage requirements within a single test file, see `web/docs/test.md` § Coverage Goals. | Scope | Rule | |-------|------| diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml index 190e00d9fe..52e3272f99 100644 --- a/.github/workflows/api-tests.yml +++ b/.github/workflows/api-tests.yml @@ -72,6 +72,7 @@ jobs: OPENDAL_FS_ROOT: /tmp/dify-storage run: | uv run --project api pytest \ + -n auto \ --timeout "${PYTEST_TIMEOUT:-180}" \ api/tests/integration_tests/workflow \ api/tests/integration_tests/tools \ diff --git a/.github/workflows/build-push.yml b/.github/workflows/build-push.yml index 704d896192..ac7f3a6b48 100644 --- a/.github/workflows/build-push.yml +++ b/.github/workflows/build-push.yml @@ -8,6 +8,7 @@ on: - "build/**" - "release/e-*" - "hotfix/**" + - "feat/hitl-backend" tags: - "*" diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index fdc05d1d65..cbd6edf94b 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -47,13 +47,9 @@ jobs: if: steps.changed-files.outputs.any_changed == 'true' run: uv run --directory api --dev lint-imports - - name: Run Basedpyright Checks + - name: Run Type Checks if: steps.changed-files.outputs.any_changed == 'true' - run: dev/basedpyright-check - - - name: Run 
Mypy Type Checks - if: steps.changed-files.outputs.any_changed == 'true' - run: uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped . + run: make type-check - name: Dotenv check if: steps.changed-files.outputs.any_changed == 'true' diff --git a/AGENTS.md b/AGENTS.md index 7d96ac3a6d..51fa6e4527 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -7,7 +7,7 @@ Dify is an open-source platform for developing LLM applications with an intuitiv The codebase is split into: - **Backend API** (`/api`): Python Flask application organized with Domain-Driven Design -- **Frontend Web** (`/web`): Next.js 15 application using TypeScript and React 19 +- **Frontend Web** (`/web`): Next.js application using TypeScript and React - **Docker deployment** (`/docker`): Containerized deployment configurations ## Backend Workflow @@ -18,36 +18,7 @@ The codebase is split into: ## Frontend Workflow -```bash -cd web -pnpm lint:fix -pnpm type-check:tsgo -pnpm test -``` - -### Frontend Linting - -ESLint is used for frontend code quality. Available commands: - -```bash -# Lint all files (report only) -pnpm lint - -# Lint and auto-fix issues -pnpm lint:fix - -# Lint specific files or directories -pnpm lint:fix app/components/base/button/ -pnpm lint:fix app/components/base/button/index.tsx - -# Lint quietly (errors only, no warnings) -pnpm lint:quiet - -# Check code complexity -pnpm lint:complexity -``` - -**Important**: Always run `pnpm lint:fix` before committing. The pre-commit hook runs `lint-staged` which only lints staged files. 
+- Read `web/AGENTS.md` for details ## Testing & Quality Practices diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 20a7d6c6f6..d7f007af67 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -77,7 +77,7 @@ How we prioritize: For setting up the frontend service, please refer to our comprehensive [guide](https://github.com/langgenius/dify/blob/main/web/README.md) in the `web/README.md` file. This document provides detailed instructions to help you set up the frontend environment properly. -**Testing**: All React components must have comprehensive test coverage. See [web/testing/testing.md](https://github.com/langgenius/dify/blob/main/web/testing/testing.md) for the canonical frontend testing guidelines and follow every requirement described there. +**Testing**: All React components must have comprehensive test coverage. See [web/docs/test.md](https://github.com/langgenius/dify/blob/main/web/docs/test.md) for the canonical frontend testing guidelines and follow every requirement described there. #### Backend diff --git a/Makefile b/Makefile index e92a7b1314..984e8676ee 100644 --- a/Makefile +++ b/Makefile @@ -68,9 +68,11 @@ lint: @echo "✅ Linting complete" type-check: - @echo "📝 Running type check with basedpyright..." - @uv run --directory api --dev basedpyright - @echo "✅ Type check complete" + @echo "📝 Running type checks (basedpyright + mypy + ty)..." + @./dev/basedpyright-check $(PATH_TO_CHECK) + @uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped . + @cd api && uv run ty check + @echo "✅ Type checks complete" test: @echo "🧪 Running backend unit tests..." 
@@ -78,7 +80,7 @@ test: echo "Target: $(TARGET_TESTS)"; \ uv run --project api --dev pytest $(TARGET_TESTS); \ else \ - uv run --project api --dev dev/pytest/pytest_unit_tests.sh; \ + PYTEST_XDIST_ARGS="-n auto" uv run --project api --dev dev/pytest/pytest_unit_tests.sh; \ fi @echo "✅ Tests complete" @@ -130,7 +132,7 @@ help: @echo " make format - Format code with ruff" @echo " make check - Check code with ruff" @echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)" - @echo " make type-check - Run type checking with basedpyright" + @echo " make type-check - Run type checks (basedpyright, mypy, ty)" @echo " make test - Run backend unit tests (or TARGET_TESTS=./api/tests/)" @echo "" @echo "Docker Build Targets:" diff --git a/api/.env.example b/api/.env.example index c3b1474549..fcadfa1c3b 100644 --- a/api/.env.example +++ b/api/.env.example @@ -617,6 +617,7 @@ PLUGIN_DAEMON_URL=http://127.0.0.1:5002 PLUGIN_REMOTE_INSTALL_PORT=5003 PLUGIN_REMOTE_INSTALL_HOST=localhost PLUGIN_MAX_PACKAGE_SIZE=15728640 +PLUGIN_MODEL_SCHEMA_CACHE_TTL=3600 INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1 # Marketplace configuration @@ -717,3 +718,27 @@ SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000 SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30 SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL=90000 + +# Redis URL used for PubSub between API and +# celery worker +# defaults to url constructed from `REDIS_*` +# configurations +PUBSUB_REDIS_URL= +# Pub/sub channel type for streaming events. +# valid options are: +# +# - pubsub: for normal Pub/Sub +# - sharded: for sharded Pub/Sub +# +# It's highly recommended to use sharded Pub/Sub AND redis cluster +# for large deployments. +PUBSUB_REDIS_CHANNEL_TYPE=pubsub +# Whether to use Redis cluster mode while running +# PubSub. +# It's highly recommended to enable this for large deployments. 
+PUBSUB_REDIS_USE_CLUSTERS=false + +# Whether to Enable human input timeout check task +ENABLE_HUMAN_INPUT_TIMEOUT_TASK=true +# Human input timeout check interval in minutes +HUMAN_INPUT_TIMEOUT_TASK_INTERVAL=1 diff --git a/api/.importlinter b/api/.importlinter index 2b4a3a5bd6..98f87710ed 100644 --- a/api/.importlinter +++ b/api/.importlinter @@ -36,6 +36,8 @@ ignore_imports = core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine core.workflow.nodes.loop.loop_node -> core.workflow.graph core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine.command_channels + # TODO(QuantumGhost): fix the import violation later + core.workflow.entities.pause_reason -> core.workflow.nodes.human_input.entities [importlinter:contract:workflow-infrastructure-dependencies] name = Workflow Infrastructure Dependencies @@ -58,6 +60,8 @@ ignore_imports = core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis core.workflow.graph_engine.manager -> extensions.ext_redis core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis + # TODO(QuantumGhost): use DI to avoid depending on global DB. 
+ core.workflow.nodes.human_input.human_input_node -> extensions.ext_database [importlinter:contract:workflow-external-imports] name = Workflow External Imports @@ -145,6 +149,7 @@ ignore_imports = core.workflow.nodes.agent.agent_node -> core.agent.entities core.workflow.nodes.agent.agent_node -> core.agent.plugin_entities core.workflow.nodes.base.node -> core.app.entities.app_invoke_entities + core.workflow.nodes.human_input.human_input_node -> core.app.entities.app_invoke_entities core.workflow.nodes.knowledge_index.knowledge_index_node -> core.app.entities.app_invoke_entities core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.app_config.entities core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.entities.app_invoke_entities @@ -227,6 +232,9 @@ ignore_imports = core.workflow.nodes.knowledge_index.entities -> core.rag.retrieval.retrieval_methods core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.retrieval.retrieval_methods core.workflow.nodes.knowledge_index.knowledge_index_node -> models.dataset + core.workflow.nodes.knowledge_index.knowledge_index_node -> services.summary_index_service + core.workflow.nodes.knowledge_index.knowledge_index_node -> tasks.generate_summary_index_task + core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.processor.paragraph_index_processor core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.retrieval.retrieval_methods core.workflow.nodes.llm.node -> models.dataset core.workflow.nodes.agent.agent_node -> core.tools.utils.message_transformer @@ -245,6 +253,7 @@ ignore_imports = core.workflow.nodes.document_extractor.node -> core.variables.segments core.workflow.nodes.http_request.executor -> core.variables.segments core.workflow.nodes.http_request.node -> core.variables.segments + core.workflow.nodes.human_input.entities -> core.variables.consts core.workflow.nodes.iteration.iteration_node -> core.variables 
core.workflow.nodes.iteration.iteration_node -> core.variables.segments core.workflow.nodes.iteration.iteration_node -> core.variables.variables @@ -291,6 +300,8 @@ ignore_imports = core.workflow.nodes.llm.llm_utils -> extensions.ext_database core.workflow.nodes.llm.node -> extensions.ext_database core.workflow.nodes.tool.tool_node -> extensions.ext_database + core.workflow.nodes.human_input.human_input_node -> extensions.ext_database + core.workflow.nodes.human_input.human_input_node -> core.repositories.human_input_repository core.workflow.workflow_entry -> extensions.otel.runtime core.workflow.nodes.agent.agent_node -> models core.workflow.nodes.base.node -> models.enums @@ -300,6 +311,58 @@ ignore_imports = core.workflow.nodes.agent.agent_node -> services core.workflow.nodes.tool.tool_node -> services +[importlinter:contract:model-runtime-no-internal-imports] +name = Model Runtime Internal Imports +type = forbidden +source_modules = + core.model_runtime +forbidden_modules = + configs + controllers + extensions + models + services + tasks + core.agent + core.app + core.base + core.callback_handler + core.datasource + core.db + core.entities + core.errors + core.extension + core.external_data_tool + core.file + core.helper + core.hosting_configuration + core.indexing_runner + core.llm_generator + core.logging + core.mcp + core.memory + core.model_manager + core.moderation + core.ops + core.plugin + core.prompt + core.provider_manager + core.rag + core.repositories + core.schemas + core.tools + core.trigger + core.variables + core.workflow +ignore_imports = + core.model_runtime.model_providers.__base.ai_model -> configs + core.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis + core.model_runtime.model_providers.__base.large_language_model -> configs + core.model_runtime.model_providers.__base.text_embedding_model -> core.entities.embedding_type + core.model_runtime.model_providers.model_provider_factory -> configs + 
core.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis + core.model_runtime.model_providers.model_provider_factory -> models.provider_ids + [importlinter:contract:rsc] name = RSC type = layers diff --git a/api/app.py b/api/app.py index 99f70f32d5..c018c8a045 100644 --- a/api/app.py +++ b/api/app.py @@ -1,4 +1,12 @@ +from __future__ import annotations + import sys +from typing import TYPE_CHECKING, cast + +if TYPE_CHECKING: + from celery import Celery + + celery: Celery def is_db_command() -> bool: @@ -23,7 +31,7 @@ else: from app_factory import create_app app = create_app() - celery = app.extensions["celery"] + celery = cast("Celery", app.extensions["celery"]) if __name__ == "__main__": app.run(host="0.0.0.0", port=5001) diff --git a/api/app_factory.py b/api/app_factory.py index 07859a3758..dcbc821687 100644 --- a/api/app_factory.py +++ b/api/app_factory.py @@ -149,7 +149,7 @@ def initialize_extensions(app: DifyApp): logger.info("Loaded %s (%s ms)", short_name, round((end_time - start_time) * 1000, 2)) -def create_migrations_app(): +def create_migrations_app() -> DifyApp: app = create_flask_app_with_configs() from extensions import ext_database, ext_migrate diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 786094f295..c405d5d44c 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -1,3 +1,4 @@ +from datetime import timedelta from enum import StrEnum from typing import Literal @@ -48,6 +49,16 @@ class SecurityConfig(BaseSettings): default=5, ) + WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS: PositiveInt = Field( + description="Maximum number of web form submissions allowed per IP within the rate limit window", + default=30, + ) + + WEB_FORM_SUBMIT_RATE_LIMIT_WINDOW_SECONDS: PositiveInt = Field( + description="Time window in seconds for web form submission rate limiting", + default=60, + ) + LOGIN_DISABLED: bool = Field( description="Whether to disable login checks", 
default=False, @@ -82,6 +93,12 @@ class AppExecutionConfig(BaseSettings): default=0, ) + HUMAN_INPUT_GLOBAL_TIMEOUT_SECONDS: PositiveInt = Field( + description="Maximum seconds a workflow run can stay paused waiting for human input before global timeout.", + default=int(timedelta(days=7).total_seconds()), + ge=1, + ) + class CodeExecutionSandboxConfig(BaseSettings): """ @@ -243,6 +260,11 @@ class PluginConfig(BaseSettings): default=15728640 * 12, ) + PLUGIN_MODEL_SCHEMA_CACHE_TTL: PositiveInt = Field( + description="TTL in seconds for caching plugin model schemas in Redis", + default=60 * 60, + ) + class MarketplaceConfig(BaseSettings): """ @@ -1129,6 +1151,14 @@ class CeleryScheduleTasksConfig(BaseSettings): description="Enable queue monitor task", default=False, ) + ENABLE_HUMAN_INPUT_TIMEOUT_TASK: bool = Field( + description="Enable human input timeout check task", + default=True, + ) + HUMAN_INPUT_TIMEOUT_TASK_INTERVAL: PositiveInt = Field( + description="Human input timeout check interval in minutes", + default=1, + ) ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK: bool = Field( description="Enable check upgradable plugin task", default=True, diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py index 63f75924bf..a15e42babf 100644 --- a/api/configs/middleware/__init__.py +++ b/api/configs/middleware/__init__.py @@ -6,6 +6,7 @@ from pydantic import Field, NonNegativeFloat, NonNegativeInt, PositiveFloat, Pos from pydantic_settings import BaseSettings from .cache.redis_config import RedisConfig +from .cache.redis_pubsub_config import RedisPubSubConfig from .storage.aliyun_oss_storage_config import AliyunOSSStorageConfig from .storage.amazon_s3_storage_config import S3StorageConfig from .storage.azure_blob_storage_config import AzureBlobStorageConfig @@ -317,6 +318,7 @@ class MiddlewareConfig( CeleryConfig, # Note: CeleryConfig already inherits from DatabaseConfig KeywordStoreConfig, RedisConfig, + RedisPubSubConfig, # configs of storage and 
storage providers StorageConfig, AliyunOSSStorageConfig, diff --git a/api/configs/middleware/cache/redis_pubsub_config.py b/api/configs/middleware/cache/redis_pubsub_config.py new file mode 100644 index 0000000000..a72e1dd28f --- /dev/null +++ b/api/configs/middleware/cache/redis_pubsub_config.py @@ -0,0 +1,96 @@ +from typing import Literal, Protocol +from urllib.parse import quote_plus, urlunparse + +from pydantic import Field +from pydantic_settings import BaseSettings + + +class RedisConfigDefaults(Protocol): + REDIS_HOST: str + REDIS_PORT: int + REDIS_USERNAME: str | None + REDIS_PASSWORD: str | None + REDIS_DB: int + REDIS_USE_SSL: bool + REDIS_USE_SENTINEL: bool | None + REDIS_USE_CLUSTERS: bool + + +class RedisConfigDefaultsMixin: + def _redis_defaults(self: RedisConfigDefaults) -> RedisConfigDefaults: + return self + + +class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin): + """ + Configuration settings for Redis pub/sub streaming. + """ + + PUBSUB_REDIS_URL: str | None = Field( + alias="PUBSUB_REDIS_URL", + description=( + "Redis connection URL for pub/sub streaming events between API " + "and celery worker, defaults to url constructed from " + "`REDIS_*` configurations" + ), + default=None, + ) + + PUBSUB_REDIS_USE_CLUSTERS: bool = Field( + description=( + "Enable Redis Cluster mode for pub/sub streaming. It's highly " + "recommended to enable this for large deployments." + ), + default=False, + ) + + PUBSUB_REDIS_CHANNEL_TYPE: Literal["pubsub", "sharded"] = Field( + description=( + "Pub/sub channel type for streaming events. " + "Valid options are:\n" + "\n" + " - pubsub: for normal Pub/Sub\n" + " - sharded: for sharded Pub/Sub\n" + "\n" + "It's highly recommended to use sharded Pub/Sub AND redis cluster " + "for large deployments." 
+ ), + default="pubsub", + ) + + def _build_default_pubsub_url(self) -> str: + defaults = self._redis_defaults() + if not defaults.REDIS_HOST or not defaults.REDIS_PORT: + raise ValueError("PUBSUB_REDIS_URL must be set when default Redis URL cannot be constructed") + + scheme = "rediss" if defaults.REDIS_USE_SSL else "redis" + username = defaults.REDIS_USERNAME or None + password = defaults.REDIS_PASSWORD or None + + userinfo = "" + if username: + userinfo = quote_plus(username) + if password: + password_part = quote_plus(password) + userinfo = f"{userinfo}:{password_part}" if userinfo else f":{password_part}" + if userinfo: + userinfo = f"{userinfo}@" + + host = defaults.REDIS_HOST + port = defaults.REDIS_PORT + db = defaults.REDIS_DB + + netloc = f"{userinfo}{host}:{port}" + return urlunparse((scheme, netloc, f"/{db}", "", "", "")) + + @property + def normalized_pubsub_redis_url(self) -> str: + pubsub_redis_url = self.PUBSUB_REDIS_URL + if pubsub_redis_url: + cleaned = pubsub_redis_url.strip() + pubsub_redis_url = cleaned or None + + if pubsub_redis_url: + return pubsub_redis_url + + return self._build_default_pubsub_url() diff --git a/api/contexts/__init__.py b/api/contexts/__init__.py index 7c16bc231f..c52dcf8a57 100644 --- a/api/contexts/__init__.py +++ b/api/contexts/__init__.py @@ -6,7 +6,6 @@ from contexts.wrapper import RecyclableContextVar if TYPE_CHECKING: from core.datasource.__base.datasource_provider import DatasourcePluginProviderController - from core.model_runtime.entities.model_entities import AIModelEntity from core.plugin.entities.plugin_daemon import PluginModelProviderEntity from core.tools.plugin_tool.provider import PluginToolProviderController from core.trigger.provider import PluginTriggerProviderController @@ -29,12 +28,6 @@ plugin_model_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar( ContextVar("plugin_model_providers_lock") ) -plugin_model_schema_lock: RecyclableContextVar[Lock] = 
RecyclableContextVar(ContextVar("plugin_model_schema_lock")) - -plugin_model_schemas: RecyclableContextVar[dict[str, "AIModelEntity"]] = RecyclableContextVar( - ContextVar("plugin_model_schemas") -) - datasource_plugin_providers: RecyclableContextVar[dict[str, "DatasourcePluginProviderController"]] = ( RecyclableContextVar(ContextVar("datasource_plugin_providers")) ) diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index fdc9aabc83..902d67174b 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -37,6 +37,7 @@ from . import ( apikey, extension, feature, + human_input_form, init_validate, ping, setup, @@ -171,6 +172,7 @@ __all__ = [ "forgot_password", "generator", "hit_testing", + "human_input_form", "init_validate", "installed_app", "load_balancing_config", diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py index e1ee2c24b8..03b602f6e8 100644 --- a/api/controllers/console/admin.py +++ b/api/controllers/console/admin.py @@ -243,15 +243,13 @@ class InsertExploreBannerApi(Resource): def post(self): payload = InsertExploreBannerPayload.model_validate(console_ns.payload) - content = { - "category": payload.category, - "title": payload.title, - "description": payload.description, - "img-src": payload.img_src, - } - banner = ExporleBanner( - content=content, + content={ + "category": payload.category, + "title": payload.title, + "description": payload.description, + "img-src": payload.img_src, + }, link=payload.link, sort=payload.sort, language=payload.language, diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index 55fdcb51e4..14910c5895 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -89,6 +89,7 @@ status_count_model = console_ns.model( "success": fields.Integer, "failed": fields.Integer, "partial_success": fields.Integer, + "paused": 
fields.Integer, }, ) diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index b4fc44767a..1ac55b5e8d 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -1,5 +1,4 @@ from collections.abc import Sequence -from typing import Any from flask_restx import Resource from pydantic import BaseModel, Field @@ -12,10 +11,12 @@ from controllers.console.app.error import ( ProviderQuotaExceededError, ) from controllers.console.wraps import account_initialization_required, setup_required +from core.app.app_config.entities import ModelConfig from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.helper.code_executor.code_node_provider import CodeNodeProvider from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider +from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload from core.llm_generator.llm_generator import LLMGenerator from core.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db @@ -26,28 +27,13 @@ from services.workflow_service import WorkflowService DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" -class RuleGeneratePayload(BaseModel): - instruction: str = Field(..., description="Rule generation instruction") - model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration") - no_variable: bool = Field(default=False, description="Whether to exclude variables") - - -class RuleCodeGeneratePayload(RuleGeneratePayload): - code_language: str = Field(default="javascript", description="Programming language for code generation") - - -class RuleStructuredOutputPayload(BaseModel): - instruction: str = Field(..., description="Structured output generation 
instruction") - model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration") - - class InstructionGeneratePayload(BaseModel): flow_id: str = Field(..., description="Workflow/Flow ID") node_id: str = Field(default="", description="Node ID for workflow context") current: str = Field(default="", description="Current instruction text") language: str = Field(default="javascript", description="Programming language (javascript/python)") instruction: str = Field(..., description="Instruction for generation") - model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration") + model_config_data: ModelConfig = Field(..., alias="model_config", description="Model configuration") ideal_output: str = Field(default="", description="Expected ideal output") @@ -64,6 +50,7 @@ reg(RuleCodeGeneratePayload) reg(RuleStructuredOutputPayload) reg(InstructionGeneratePayload) reg(InstructionTemplatePayload) +reg(ModelConfig) @console_ns.route("/rule-generate") @@ -82,12 +69,7 @@ class RuleGenerateApi(Resource): _, current_tenant_id = current_account_with_tenant() try: - rules = LLMGenerator.generate_rule_config( - tenant_id=current_tenant_id, - instruction=args.instruction, - model_config=args.model_config_data, - no_variable=args.no_variable, - ) + rules = LLMGenerator.generate_rule_config(tenant_id=current_tenant_id, args=args) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) except QuotaExceededError: @@ -118,9 +100,7 @@ class RuleCodeGenerateApi(Resource): try: code_result = LLMGenerator.generate_code( tenant_id=current_tenant_id, - instruction=args.instruction, - model_config=args.model_config_data, - code_language=args.code_language, + args=args, ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) @@ -152,8 +132,7 @@ class RuleStructuredOutputGenerateApi(Resource): try: structured_output = 
LLMGenerator.generate_structured_output( tenant_id=current_tenant_id, - instruction=args.instruction, - model_config=args.model_config_data, + args=args, ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) @@ -204,23 +183,29 @@ class InstructionGenerateApi(Resource): case "llm": return LLMGenerator.generate_rule_config( current_tenant_id, - instruction=args.instruction, - model_config=args.model_config_data, - no_variable=True, + args=RuleGeneratePayload( + instruction=args.instruction, + model_config=args.model_config_data, + no_variable=True, + ), ) case "agent": return LLMGenerator.generate_rule_config( current_tenant_id, - instruction=args.instruction, - model_config=args.model_config_data, - no_variable=True, + args=RuleGeneratePayload( + instruction=args.instruction, + model_config=args.model_config_data, + no_variable=True, + ), ) case "code": return LLMGenerator.generate_code( tenant_id=current_tenant_id, - instruction=args.instruction, - model_config=args.model_config_data, - code_language=args.language, + args=RuleCodeGeneratePayload( + instruction=args.instruction, + model_config=args.model_config_data, + code_language=args.language, + ), ) case _: return {"error": f"invalid node type: {node_type}"} diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 12ada8b798..ab1628d5d4 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -32,7 +32,7 @@ from libs.login import current_account_with_tenant, login_required from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback from services.errors.conversation import ConversationNotExistsError from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError -from services.message_service import MessageService +from services.message_service import MessageService, attach_message_extra_contents logger = 
logging.getLogger(__name__) DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" @@ -198,6 +198,7 @@ message_detail_model = console_ns.model( "created_at": TimestampField, "agent_thoughts": fields.List(fields.Nested(agent_thought_model)), "message_files": fields.List(fields.Nested(message_file_model)), + "extra_contents": fields.List(fields.Raw), "metadata": fields.Raw(attribute="message_metadata_dict"), "status": fields.String, "error": fields.String, @@ -290,6 +291,7 @@ class ChatMessageListApi(Resource): has_more = False history_messages = list(reversed(history_messages)) + attach_message_extra_contents(history_messages) return InfiniteScrollPagination(data=history_messages, limit=args.limit, has_more=has_more) @@ -474,4 +476,5 @@ class MessageApi(Resource): if not message: raise NotFound("Message Not Exists.") + attach_message_extra_contents([message]) return message diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 755463cb70..27e1d01af6 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -507,6 +507,179 @@ class WorkflowDraftRunLoopNodeApi(Resource): raise InternalServerError() +class HumanInputFormPreviewPayload(BaseModel): + inputs: dict[str, Any] = Field( + default_factory=dict, + description="Values used to fill missing upstream variables referenced in form_content", + ) + + +class HumanInputFormSubmitPayload(BaseModel): + form_inputs: dict[str, Any] = Field(..., description="Values the user provides for the form's own fields") + inputs: dict[str, Any] = Field( + ..., + description="Values used to fill missing upstream variables referenced in form_content", + ) + action: str = Field(..., description="Selected action ID") + + +class HumanInputDeliveryTestPayload(BaseModel): + delivery_method_id: str = Field(..., description="Delivery method ID") + inputs: dict[str, Any] = Field( + default_factory=dict, + description="Values used to fill missing 
upstream variables referenced in form_content", + ) + + +reg(HumanInputFormPreviewPayload) +reg(HumanInputFormSubmitPayload) +reg(HumanInputDeliveryTestPayload) + + +@console_ns.route("/apps//advanced-chat/workflows/draft/human-input/nodes//form/preview") +class AdvancedChatDraftHumanInputFormPreviewApi(Resource): + @console_ns.doc("get_advanced_chat_draft_human_input_form") + @console_ns.doc(description="Get human input form preview for advanced chat workflow") + @console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"}) + @console_ns.expect(console_ns.models[HumanInputFormPreviewPayload.__name__]) + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT]) + @edit_permission_required + def post(self, app_model: App, node_id: str): + """ + Preview human input form content and placeholders + """ + current_user, _ = current_account_with_tenant() + args = HumanInputFormPreviewPayload.model_validate(console_ns.payload or {}) + inputs = args.inputs + + workflow_service = WorkflowService() + preview = workflow_service.get_human_input_form_preview( + app_model=app_model, + account=current_user, + node_id=node_id, + inputs=inputs, + ) + return jsonable_encoder(preview) + + +@console_ns.route("/apps//advanced-chat/workflows/draft/human-input/nodes//form/run") +class AdvancedChatDraftHumanInputFormRunApi(Resource): + @console_ns.doc("submit_advanced_chat_draft_human_input_form") + @console_ns.doc(description="Submit human input form preview for advanced chat workflow") + @console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"}) + @console_ns.expect(console_ns.models[HumanInputFormSubmitPayload.__name__]) + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT]) + @edit_permission_required + def post(self, app_model: App, node_id: str): + """ + Submit human input form preview + """ + current_user, _ = 
current_account_with_tenant() + args = HumanInputFormSubmitPayload.model_validate(console_ns.payload or {}) + workflow_service = WorkflowService() + result = workflow_service.submit_human_input_form_preview( + app_model=app_model, + account=current_user, + node_id=node_id, + form_inputs=args.form_inputs, + inputs=args.inputs, + action=args.action, + ) + return jsonable_encoder(result) + + +@console_ns.route("/apps//workflows/draft/human-input/nodes//form/preview") +class WorkflowDraftHumanInputFormPreviewApi(Resource): + @console_ns.doc("get_workflow_draft_human_input_form") + @console_ns.doc(description="Get human input form preview for workflow") + @console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"}) + @console_ns.expect(console_ns.models[HumanInputFormPreviewPayload.__name__]) + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.WORKFLOW]) + @edit_permission_required + def post(self, app_model: App, node_id: str): + """ + Preview human input form content and placeholders + """ + current_user, _ = current_account_with_tenant() + args = HumanInputFormPreviewPayload.model_validate(console_ns.payload or {}) + inputs = args.inputs + + workflow_service = WorkflowService() + preview = workflow_service.get_human_input_form_preview( + app_model=app_model, + account=current_user, + node_id=node_id, + inputs=inputs, + ) + return jsonable_encoder(preview) + + +@console_ns.route("/apps//workflows/draft/human-input/nodes//form/run") +class WorkflowDraftHumanInputFormRunApi(Resource): + @console_ns.doc("submit_workflow_draft_human_input_form") + @console_ns.doc(description="Submit human input form preview for workflow") + @console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"}) + @console_ns.expect(console_ns.models[HumanInputFormSubmitPayload.__name__]) + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.WORKFLOW]) + 
@edit_permission_required + def post(self, app_model: App, node_id: str): + """ + Submit human input form preview + """ + current_user, _ = current_account_with_tenant() + workflow_service = WorkflowService() + args = HumanInputFormSubmitPayload.model_validate(console_ns.payload or {}) + result = workflow_service.submit_human_input_form_preview( + app_model=app_model, + account=current_user, + node_id=node_id, + form_inputs=args.form_inputs, + inputs=args.inputs, + action=args.action, + ) + return jsonable_encoder(result) + + +@console_ns.route("/apps//workflows/draft/human-input/nodes//delivery-test") +class WorkflowDraftHumanInputDeliveryTestApi(Resource): + @console_ns.doc("test_workflow_draft_human_input_delivery") + @console_ns.doc(description="Test human input delivery for workflow") + @console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"}) + @console_ns.expect(console_ns.models[HumanInputDeliveryTestPayload.__name__]) + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT]) + @edit_permission_required + def post(self, app_model: App, node_id: str): + """ + Test human input delivery + """ + current_user, _ = current_account_with_tenant() + workflow_service = WorkflowService() + args = HumanInputDeliveryTestPayload.model_validate(console_ns.payload or {}) + workflow_service.test_human_input_delivery( + app_model=app_model, + account=current_user, + node_id=node_id, + delivery_method_id=args.delivery_method_id, + inputs=args.inputs, + ) + return jsonable_encoder({}) + + @console_ns.route("/apps//workflows/draft/run") class DraftWorkflowRunApi(Resource): @console_ns.doc("run_draft_workflow") diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py index fa74f8aea1..d9a5dde55a 100644 --- a/api/controllers/console/app/workflow_run.py +++ b/api/controllers/console/app/workflow_run.py @@ -5,10 +5,15 @@ from flask import 
request from flask_restx import Resource, fields, marshal_with from pydantic import BaseModel, Field, field_validator from sqlalchemy import select +from sqlalchemy.orm import sessionmaker +from configs import dify_config from controllers.console import console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required +from controllers.web.error import NotFoundError +from core.workflow.entities.pause_reason import HumanInputRequired +from core.workflow.enums import WorkflowExecutionStatus from extensions.ext_database import db from fields.end_user_fields import simple_end_user_fields from fields.member_fields import simple_account_fields @@ -27,9 +32,21 @@ from libs.custom_inputs import time_duration from libs.helper import uuid_value from libs.login import current_user, login_required from models import Account, App, AppMode, EndUser, WorkflowArchiveLog, WorkflowRunTriggeredFrom +from models.workflow import WorkflowRun +from repositories.factory import DifyAPIRepositoryFactory from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME from services.workflow_run_service import WorkflowRunService + +def _build_backstage_input_url(form_token: str | None) -> str | None: + if not form_token: + return None + base_url = dify_config.APP_WEB_URL + if not base_url: + return None + return f"{base_url.rstrip('/')}/form/{form_token}" + + # Workflow run status choices for filtering WORKFLOW_RUN_STATUS_CHOICES = ["running", "succeeded", "failed", "stopped", "partial-succeeded"] EXPORT_SIGNED_URL_EXPIRE_SECONDS = 3600 @@ -440,3 +457,63 @@ class WorkflowRunNodeExecutionListApi(Resource): ) return {"data": node_executions} + + +@console_ns.route("/workflow//pause-details") +class ConsoleWorkflowPauseDetailsApi(Resource): + """Console API for getting workflow pause details.""" + + @account_initialization_required + @login_required + def get(self, workflow_run_id: str): + """ + Get 
workflow pause details. + + GET /console/api/workflow//pause-details + + Returns information about why and where the workflow is paused. + """ + + # Query WorkflowRun to determine if workflow is suspended + session_maker = sessionmaker(bind=db.engine) + workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker=session_maker) + workflow_run = db.session.get(WorkflowRun, workflow_run_id) + if not workflow_run: + raise NotFoundError("Workflow run not found") + + # Check if workflow is suspended + is_paused = workflow_run.status == WorkflowExecutionStatus.PAUSED + if not is_paused: + return { + "paused_at": None, + "paused_nodes": [], + }, 200 + + pause_entity = workflow_run_repo.get_workflow_pause(workflow_run_id) + pause_reasons = pause_entity.get_pause_reasons() if pause_entity else [] + + # Build response + paused_at = pause_entity.paused_at if pause_entity else None + paused_nodes = [] + response = { + "paused_at": paused_at.isoformat() + "Z" if paused_at else None, + "paused_nodes": paused_nodes, + } + + for reason in pause_reasons: + if isinstance(reason, HumanInputRequired): + paused_nodes.append( + { + "node_id": reason.node_id, + "node_title": reason.node_title, + "pause_type": { + "type": "human_input", + "form_id": reason.form_id, + "backstage_input_url": _build_backstage_input_url(reason.form_token), + }, + } + ) + else: + raise AssertionError("unimplemented.") + + return response, 200 diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 8fbbc51e21..30e4ed1119 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -148,6 +148,7 @@ class DatasetUpdatePayload(BaseModel): embedding_model: str | None = None embedding_model_provider: str | None = None retrieval_model: dict[str, Any] | None = None + summary_index_setting: dict[str, Any] | None = None partial_member_list: list[dict[str, str]] | None = None 
external_retrieval_model: dict[str, Any] | None = None external_knowledge_id: str | None = None @@ -288,7 +289,14 @@ class DatasetListApi(Resource): @enterprise_license_required def get(self): current_user, current_tenant_id = current_account_with_tenant() - query = ConsoleDatasetListQuery.model_validate(request.args.to_dict()) + # Convert query parameters to dict, handling list parameters correctly + query_params: dict[str, str | list[str]] = dict(request.args.to_dict()) + # Handle ids and tag_ids as lists (Flask request.args.getlist returns list even for single value) + if "ids" in request.args: + query_params["ids"] = request.args.getlist("ids") + if "tag_ids" in request.args: + query_params["tag_ids"] = request.args.getlist("tag_ids") + query = ConsoleDatasetListQuery.model_validate(query_params) # provider = request.args.get("provider", default="vendor") if query.ids: datasets, total = DatasetService.get_datasets_by_ids(query.ids, current_tenant_id) diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 57fb9abf29..6e3c0db8a3 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -45,6 +45,7 @@ from models.dataset import DocumentPipelineExecutionLog from services.dataset_service import DatasetService, DocumentService from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel from services.file_service import FileService +from tasks.generate_summary_index_task import generate_summary_index_task from ..app.error import ( ProviderModelCurrentlyNotSupportError, @@ -103,6 +104,10 @@ class DocumentRenamePayload(BaseModel): name: str +class GenerateSummaryPayload(BaseModel): + document_list: list[str] + + class DocumentBatchDownloadZipPayload(BaseModel): """Request payload for bulk downloading documents as a zip archive.""" @@ -125,6 +130,7 @@ register_schema_models( 
RetrievalModel, DocumentRetryPayload, DocumentRenamePayload, + GenerateSummaryPayload, DocumentBatchDownloadZipPayload, ) @@ -312,6 +318,13 @@ class DatasetDocumentListApi(Resource): paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) documents = paginated_documents.items + + DocumentService.enrich_documents_with_summary_index_status( + documents=documents, + dataset=dataset, + tenant_id=current_tenant_id, + ) + if fetch: for document in documents: completed_segments = ( @@ -797,6 +810,7 @@ class DocumentApi(DocumentResource): "display_status": document.display_status, "doc_form": document.doc_form, "doc_language": document.doc_language, + "need_summary": document.need_summary if document.need_summary is not None else False, } else: dataset_process_rules = DatasetService.get_process_rules(dataset_id) @@ -832,6 +846,7 @@ class DocumentApi(DocumentResource): "display_status": document.display_status, "doc_form": document.doc_form, "doc_language": document.doc_language, + "need_summary": document.need_summary if document.need_summary is not None else False, } return response, 200 @@ -1255,3 +1270,137 @@ class DocumentPipelineExecutionLogApi(DocumentResource): "input_data": log.input_data, "datasource_node_id": log.datasource_node_id, }, 200 + + +@console_ns.route("/datasets//documents/generate-summary") +class DocumentGenerateSummaryApi(Resource): + @console_ns.doc("generate_summary_for_documents") + @console_ns.doc(description="Generate summary index for documents") + @console_ns.doc(params={"dataset_id": "Dataset ID"}) + @console_ns.expect(console_ns.models[GenerateSummaryPayload.__name__]) + @console_ns.response(200, "Summary generation started successfully") + @console_ns.response(400, "Invalid request or dataset configuration") + @console_ns.response(403, "Permission denied") + @console_ns.response(404, "Dataset not found") + @setup_required + @login_required + @account_initialization_required + 
@cloud_edition_billing_rate_limit_check("knowledge") + def post(self, dataset_id): + """ + Generate summary index for specified documents. + + This endpoint checks if the dataset configuration supports summary generation + (indexing_technique must be 'high_quality' and summary_index_setting.enable must be true), + then asynchronously generates summary indexes for the provided documents. + """ + current_user, _ = current_account_with_tenant() + dataset_id = str(dataset_id) + + # Get dataset + dataset = DatasetService.get_dataset(dataset_id) + if not dataset: + raise NotFound("Dataset not found.") + + # Check permissions + if not current_user.is_dataset_editor: + raise Forbidden() + + try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + + # Validate request payload + payload = GenerateSummaryPayload.model_validate(console_ns.payload or {}) + document_list = payload.document_list + + if not document_list: + from werkzeug.exceptions import BadRequest + + raise BadRequest("document_list cannot be empty.") + + # Check if dataset configuration supports summary generation + if dataset.indexing_technique != "high_quality": + raise ValueError( + f"Summary generation is only available for 'high_quality' indexing technique. " + f"Current indexing technique: {dataset.indexing_technique}" + ) + + summary_index_setting = dataset.summary_index_setting + if not summary_index_setting or not summary_index_setting.get("enable"): + raise ValueError("Summary index is not enabled for this dataset. 
Please enable it in the dataset settings.") + + # Verify all documents exist and belong to the dataset + documents = DocumentService.get_documents_by_ids(dataset_id, document_list) + + if len(documents) != len(document_list): + found_ids = {doc.id for doc in documents} + missing_ids = set(document_list) - found_ids + raise NotFound(f"Some documents not found: {list(missing_ids)}") + + # Dispatch async tasks for each document + for document in documents: + # Skip qa_model documents as they don't generate summaries + if document.doc_form == "qa_model": + logger.info("Skipping summary generation for qa_model document %s", document.id) + continue + + # Dispatch async task + generate_summary_index_task.delay(dataset_id, document.id) + logger.info( + "Dispatched summary generation task for document %s in dataset %s", + document.id, + dataset_id, + ) + + return {"result": "success"}, 200 + + +@console_ns.route("/datasets//documents//summary-status") +class DocumentSummaryStatusApi(DocumentResource): + @console_ns.doc("get_document_summary_status") + @console_ns.doc(description="Get summary index generation status for a document") + @console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) + @console_ns.response(200, "Summary status retrieved successfully") + @console_ns.response(404, "Document not found") + @setup_required + @login_required + @account_initialization_required + def get(self, dataset_id, document_id): + """ + Get summary index generation status for a document. 
+ + Returns: + - total_segments: Total number of segments in the document + - summary_status: Dictionary with status counts + - completed: Number of summaries completed + - generating: Number of summaries being generated + - error: Number of summaries with errors + - not_started: Number of segments without summary records + - summaries: List of summary records with status and content preview + """ + current_user, _ = current_account_with_tenant() + dataset_id = str(dataset_id) + document_id = str(document_id) + + # Get dataset + dataset = DatasetService.get_dataset(dataset_id) + if not dataset: + raise NotFound("Dataset not found.") + + # Check permissions + try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + + # Get summary status detail from service + from services.summary_index_service import SummaryIndexService + + result = SummaryIndexService.get_document_summary_status_detail( + document_id=document_id, + dataset_id=dataset_id, + ) + + return result, 200 diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 08e1ddd3e0..23a668112d 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -41,6 +41,17 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task +def _get_segment_with_summary(segment, dataset_id): + """Helper function to marshal segment and add summary information.""" + from services.summary_index_service import SummaryIndexService + + segment_dict = dict(marshal(segment, segment_fields)) + # Query summary for this segment (only enabled summaries) + summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id) + segment_dict["summary"] = 
summary.summary_content if summary else None + return segment_dict + + class SegmentListQuery(BaseModel): limit: int = Field(default=20, ge=1, le=100) status: list[str] = Field(default_factory=list) @@ -63,6 +74,7 @@ class SegmentUpdatePayload(BaseModel): keywords: list[str] | None = None regenerate_child_chunks: bool = False attachment_ids: list[str] | None = None + summary: str | None = None # Summary content for summary index class BatchImportPayload(BaseModel): @@ -181,8 +193,25 @@ class DatasetDocumentSegmentListApi(Resource): segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) + # Query summaries for all segments in this page (batch query for efficiency) + segment_ids = [segment.id for segment in segments.items] + summaries = {} + if segment_ids: + from services.summary_index_service import SummaryIndexService + + summary_records = SummaryIndexService.get_segments_summaries(segment_ids=segment_ids, dataset_id=dataset_id) + # Only include enabled summaries (already filtered by service) + summaries = {chunk_id: summary.summary_content for chunk_id, summary in summary_records.items()} + + # Add summary to each segment + segments_with_summary = [] + for segment in segments.items: + segment_dict = dict(marshal(segment, segment_fields)) + segment_dict["summary"] = summaries.get(segment.id) + segments_with_summary.append(segment_dict) + response = { - "data": marshal(segments.items, segment_fields), + "data": segments_with_summary, "limit": limit, "total": segments.total, "total_pages": segments.pages, @@ -328,7 +357,7 @@ class DatasetDocumentSegmentAddApi(Resource): payload_dict = payload.model_dump(exclude_none=True) SegmentService.segment_create_args_validate(payload_dict, document) segment = SegmentService.create_segment(payload_dict, document, dataset) - return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 + return {"data": _get_segment_with_summary(segment, dataset_id), "doc_form": 
document.doc_form}, 200 @console_ns.route("/datasets//documents//segments/") @@ -390,10 +419,12 @@ class DatasetDocumentSegmentUpdateApi(Resource): payload = SegmentUpdatePayload.model_validate(console_ns.payload or {}) payload_dict = payload.model_dump(exclude_none=True) SegmentService.segment_create_args_validate(payload_dict, document) + + # Update segment (summary update with change detection is handled in SegmentService.update_segment) segment = SegmentService.update_segment( SegmentUpdateArgs.model_validate(payload.model_dump(exclude_none=True)), segment, document, dataset ) - return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 + return {"data": _get_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200 @setup_required @login_required diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py index 932cb4fcce..e62be13c2f 100644 --- a/api/controllers/console/datasets/hit_testing.py +++ b/api/controllers/console/datasets/hit_testing.py @@ -1,6 +1,13 @@ -from flask_restx import Resource +from flask_restx import Resource, fields from controllers.common.schema import register_schema_model +from fields.hit_testing_fields import ( + child_chunk_fields, + document_fields, + files_fields, + hit_testing_record_fields, + segment_fields, +) from libs.login import login_required from .. 
import console_ns @@ -14,13 +21,45 @@ from ..wraps import ( register_schema_model(console_ns, HitTestingPayload) +def _get_or_create_model(model_name: str, field_def): + """Get or create a flask_restx model to avoid dict type issues in Swagger.""" + existing = console_ns.models.get(model_name) + if existing is None: + existing = console_ns.model(model_name, field_def) + return existing + + +# Register models for flask_restx to avoid dict type issues in Swagger +document_model = _get_or_create_model("HitTestingDocument", document_fields) + +segment_fields_copy = segment_fields.copy() +segment_fields_copy["document"] = fields.Nested(document_model) +segment_model = _get_or_create_model("HitTestingSegment", segment_fields_copy) + +child_chunk_model = _get_or_create_model("HitTestingChildChunk", child_chunk_fields) +files_model = _get_or_create_model("HitTestingFile", files_fields) + +hit_testing_record_fields_copy = hit_testing_record_fields.copy() +hit_testing_record_fields_copy["segment"] = fields.Nested(segment_model) +hit_testing_record_fields_copy["child_chunks"] = fields.List(fields.Nested(child_chunk_model)) +hit_testing_record_fields_copy["files"] = fields.List(fields.Nested(files_model)) +hit_testing_record_model = _get_or_create_model("HitTestingRecord", hit_testing_record_fields_copy) + +# Response model for hit testing API +hit_testing_response_fields = { + "query": fields.String, + "records": fields.List(fields.Nested(hit_testing_record_model)), +} +hit_testing_response_model = _get_or_create_model("HitTestingResponse", hit_testing_response_fields) + + @console_ns.route("/datasets//hit-testing") class HitTestingApi(Resource, DatasetsHitTestingBase): @console_ns.doc("test_dataset_retrieval") @console_ns.doc(description="Test dataset knowledge retrieval") @console_ns.doc(params={"dataset_id": "Dataset ID"}) @console_ns.expect(console_ns.models[HitTestingPayload.__name__]) - @console_ns.response(200, "Hit testing completed successfully") + 
@console_ns.response(200, "Hit testing completed successfully", model=hit_testing_response_model) @console_ns.response(404, "Dataset not found") @console_ns.response(400, "Invalid parameters") @setup_required diff --git a/api/controllers/console/human_input_form.py b/api/controllers/console/human_input_form.py new file mode 100644 index 0000000000..7207f7fd1d --- /dev/null +++ b/api/controllers/console/human_input_form.py @@ -0,0 +1,217 @@ +""" +Console/Studio Human Input Form APIs. +""" + +import json +import logging +from collections.abc import Generator + +from flask import Response, jsonify, request +from flask_restx import Resource, reqparse +from sqlalchemy import select +from sqlalchemy.orm import Session, sessionmaker + +from controllers.console import console_ns +from controllers.console.wraps import account_initialization_required, setup_required +from controllers.web.error import InvalidArgumentError, NotFoundError +from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator +from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter +from core.app.apps.message_generator import MessageGenerator +from core.app.apps.workflow.app_generator import WorkflowAppGenerator +from extensions.ext_database import db +from libs.login import current_account_with_tenant, login_required +from models import App +from models.enums import CreatorUserRole +from models.human_input import RecipientType +from models.model import AppMode +from models.workflow import WorkflowRun +from repositories.factory import DifyAPIRepositoryFactory +from services.human_input_service import Form, HumanInputService +from services.workflow_event_snapshot_service import build_workflow_event_stream + +logger = logging.getLogger(__name__) + + +def _jsonify_form_definition(form: Form) -> Response: + payload = form.get_definition().model_dump() + payload["expiration_time"] = int(form.expiration_time.timestamp()) + return Response(json.dumps(payload, 
ensure_ascii=False), mimetype="application/json") + + +@console_ns.route("/form/human_input/") +class ConsoleHumanInputFormApi(Resource): + """Console API for getting human input form definition.""" + + @staticmethod + def _ensure_console_access(form: Form): + _, current_tenant_id = current_account_with_tenant() + + if form.tenant_id != current_tenant_id: + raise NotFoundError("App not found") + + @setup_required + @login_required + @account_initialization_required + def get(self, form_token: str): + """ + Get human input form definition by form token. + + GET /console/api/form/human_input/ + """ + service = HumanInputService(db.engine) + form = service.get_form_definition_by_token_for_console(form_token) + if form is None: + raise NotFoundError(f"form not found, token={form_token}") + + self._ensure_console_access(form) + + return _jsonify_form_definition(form) + + @account_initialization_required + @login_required + def post(self, form_token: str): + """ + Submit human input form by form token. + + POST /console/api/form/human_input/ + + Request body: + { + "inputs": { + "content": "User input content" + }, + "action": "Approve" + } + """ + parser = reqparse.RequestParser() + parser.add_argument("inputs", type=dict, required=True, location="json") + parser.add_argument("action", type=str, required=True, location="json") + args = parser.parse_args() + current_user, _ = current_account_with_tenant() + + service = HumanInputService(db.engine) + form = service.get_form_by_token(form_token) + if form is None: + raise NotFoundError(f"form not found, token={form_token}") + + self._ensure_console_access(form) + + recipient_type = form.recipient_type + if recipient_type not in {RecipientType.CONSOLE, RecipientType.BACKSTAGE}: + raise NotFoundError(f"form not found, token={form_token}") + # The type checker is not smart enough to validate the following invariant. + # So we need to assert it manually.
+ assert recipient_type is not None, "recipient_type cannot be None here." + + service.submit_form_by_token( + recipient_type=recipient_type, + form_token=form_token, + selected_action_id=args["action"], + form_data=args["inputs"], + submission_user_id=current_user.id, + ) + + return jsonify({}) + + +@console_ns.route("/workflow//events") +class ConsoleWorkflowEventsApi(Resource): + """Console API for getting workflow execution events after resume.""" + + @account_initialization_required + @login_required + def get(self, workflow_run_id: str): + """ + Get workflow execution events stream after resume. + + GET /console/api/workflow//events + + Returns Server-Sent Events stream. + """ + + user, tenant_id = current_account_with_tenant() + session_maker = sessionmaker(db.engine) + repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) + workflow_run = repo.get_workflow_run_by_id_and_tenant_id( + tenant_id=tenant_id, + run_id=workflow_run_id, + ) + if workflow_run is None: + raise NotFoundError(f"WorkflowRun not found, id={workflow_run_id}") + + if workflow_run.created_by_role != CreatorUserRole.ACCOUNT: + raise NotFoundError(f"WorkflowRun not created by account, id={workflow_run_id}") + + if workflow_run.created_by != user.id: + raise NotFoundError(f"WorkflowRun not created by the current account, id={workflow_run_id}") + + with Session(expire_on_commit=False, bind=db.engine) as session: + app = _retrieve_app_for_workflow_run(session, workflow_run) + + if workflow_run.finished_at is not None: + # TODO(QuantumGhost): should we modify the handling for finished workflow run here? 
+ response = WorkflowResponseConverter.workflow_run_result_to_finish_response( + task_id=workflow_run.id, + workflow_run=workflow_run, + creator_user=user, + ) + + payload = response.model_dump(mode="json") + payload["event"] = response.event.value + + def _generate_finished_events() -> Generator[str, None, None]: + yield f"data: {json.dumps(payload)}\n\n" + + event_generator = _generate_finished_events + + else: + msg_generator = MessageGenerator() + if app.mode == AppMode.ADVANCED_CHAT: + generator = AdvancedChatAppGenerator() + elif app.mode == AppMode.WORKFLOW: + generator = WorkflowAppGenerator() + else: + raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}") + + include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true" + + def _generate_stream_events(): + if include_state_snapshot: + return generator.convert_to_event_stream( + build_workflow_event_stream( + app_mode=AppMode(app.mode), + workflow_run=workflow_run, + tenant_id=workflow_run.tenant_id, + app_id=workflow_run.app_id, + session_maker=session_maker, + ) + ) + return generator.convert_to_event_stream( + msg_generator.retrieve_events(AppMode(app.mode), workflow_run.id), + ) + + event_generator = _generate_stream_events + + return Response( + event_generator(), + mimetype="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + }, + ) + + +def _retrieve_app_for_workflow_run(session: Session, workflow_run: WorkflowRun): + query = select(App).where( + App.id == workflow_run.app_id, + App.tenant_id == workflow_run.tenant_id, + ) + app = session.scalars(query).first() + if app is None: + raise AssertionError( + f"App not found for WorkflowRun, workflow_run_id={workflow_run.id}, " + f"app_id={workflow_run.app_id}, tenant_id={workflow_run.tenant_id}" + ) + + return app diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index 
6a549fc926..6088b142c2 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -33,8 +33,9 @@ from core.workflow.graph_engine.manager import GraphEngineManager from extensions.ext_database import db from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model from libs import helper -from libs.helper import TimestampField +from libs.helper import OptionalTimestampField, TimestampField from models.model import App, AppMode, EndUser +from models.workflow import WorkflowRun from repositories.factory import DifyAPIRepositoryFactory from services.app_generate_service import AppGenerateService from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError @@ -63,17 +64,32 @@ class WorkflowLogQuery(BaseModel): register_schema_models(service_api_ns, WorkflowRunPayload, WorkflowLogQuery) + +class WorkflowRunStatusField(fields.Raw): + def output(self, key, obj: WorkflowRun, **kwargs): + return obj.status.value + + +class WorkflowRunOutputsField(fields.Raw): + def output(self, key, obj: WorkflowRun, **kwargs): + if obj.status == WorkflowExecutionStatus.PAUSED: + return {} + + outputs = obj.outputs_dict + return outputs or {} + + workflow_run_fields = { "id": fields.String, "workflow_id": fields.String, - "status": fields.String, + "status": WorkflowRunStatusField, "inputs": fields.Raw, - "outputs": fields.Raw, + "outputs": WorkflowRunOutputsField, "error": fields.String, "total_steps": fields.Integer, "total_tokens": fields.Integer, "created_at": TimestampField, - "finished_at": TimestampField, + "finished_at": OptionalTimestampField, "elapsed_time": fields.Float, } diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 28864a140a..c11f64585a 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -46,6 +46,7 @@ class 
DatasetCreatePayload(BaseModel): retrieval_model: RetrievalModel | None = None embedding_model: str | None = None embedding_model_provider: str | None = None + summary_index_setting: dict | None = None class DatasetUpdatePayload(BaseModel): @@ -217,6 +218,7 @@ class DatasetListApi(DatasetApiResource): embedding_model_provider=payload.embedding_model_provider, embedding_model_name=payload.embedding_model, retrieval_model=payload.retrieval_model, + summary_index_setting=payload.summary_index_setting, ) except services.errors.dataset.DatasetNameDuplicateError: raise DatasetNameDuplicateError() diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index c85c1cf81e..a01524f1bc 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -45,6 +45,7 @@ from services.entities.knowledge_entities.knowledge_entities import ( Segmentation, ) from services.file_service import FileService +from services.summary_index_service import SummaryIndexService class DocumentTextCreatePayload(BaseModel): @@ -508,6 +509,12 @@ class DocumentListApi(DatasetApiResource): ) documents = paginated_documents.items + DocumentService.enrich_documents_with_summary_index_status( + documents=documents, + dataset=dataset, + tenant_id=tenant_id, + ) + response = { "data": marshal(documents, document_fields), "has_more": len(documents) == query_params.limit, @@ -612,6 +619,16 @@ class DocumentApi(DatasetApiResource): if metadata not in self.METADATA_CHOICES: raise InvalidMetadataError(f"Invalid metadata value: {metadata}") + # Calculate summary_index_status if needed + summary_index_status = None + has_summary_index = dataset.summary_index_setting and dataset.summary_index_setting.get("enable") is True + if has_summary_index and document.need_summary is True: + summary_index_status = SummaryIndexService.get_document_summary_index_status( + document_id=document_id, + dataset_id=dataset_id, + 
tenant_id=tenant_id, + ) + if metadata == "only": response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details} elif metadata == "without": @@ -646,6 +663,8 @@ class DocumentApi(DatasetApiResource): "display_status": document.display_status, "doc_form": document.doc_form, "doc_language": document.doc_language, + "summary_index_status": summary_index_status, + "need_summary": document.need_summary if document.need_summary is not None else False, } else: dataset_process_rules = DatasetService.get_process_rules(dataset_id) @@ -681,6 +700,8 @@ class DocumentApi(DatasetApiResource): "display_status": document.display_status, "doc_form": document.doc_form, "doc_language": document.doc_language, + "summary_index_status": summary_index_status, + "need_summary": document.need_summary if document.need_summary is not None else False, } return response diff --git a/api/controllers/web/__init__.py b/api/controllers/web/__init__.py index 1d22954308..cfa39e0dfd 100644 --- a/api/controllers/web/__init__.py +++ b/api/controllers/web/__init__.py @@ -23,6 +23,7 @@ from . import ( feature, files, forgot_password, + human_input_form, login, message, passport, @@ -30,6 +31,7 @@ from . import ( saved_message, site, workflow, + workflow_events, ) api.add_namespace(web_ns) @@ -44,6 +46,7 @@ __all__ = [ "feature", "files", "forgot_password", + "human_input_form", "login", "message", "passport", @@ -52,4 +55,5 @@ __all__ = [ "site", "web_ns", "workflow", + "workflow_events", ] diff --git a/api/controllers/web/error.py b/api/controllers/web/error.py index 196a27e348..d1f936768e 100644 --- a/api/controllers/web/error.py +++ b/api/controllers/web/error.py @@ -117,6 +117,12 @@ class InvokeRateLimitError(BaseHTTPException): code = 429 +class WebFormRateLimitExceededError(BaseHTTPException): + error_code = "web_form_rate_limit_exceeded" + description = "Too many form requests. Please try again later." 
+ code = 429 + + class NotFoundError(BaseHTTPException): error_code = "not_found" code = 404 diff --git a/api/controllers/web/human_input_form.py b/api/controllers/web/human_input_form.py new file mode 100644 index 0000000000..c3989b1965 --- /dev/null +++ b/api/controllers/web/human_input_form.py @@ -0,0 +1,164 @@ +""" +Web App Human Input Form APIs. +""" + +import json +import logging +from datetime import datetime + +from flask import Response, request +from flask_restx import Resource, reqparse +from werkzeug.exceptions import Forbidden + +from configs import dify_config +from controllers.web import web_ns +from controllers.web.error import NotFoundError, WebFormRateLimitExceededError +from controllers.web.site import serialize_app_site_payload +from extensions.ext_database import db +from libs.helper import RateLimiter, extract_remote_ip +from models.account import TenantStatus +from models.model import App, Site +from services.human_input_service import Form, FormNotFoundError, HumanInputService + +logger = logging.getLogger(__name__) + +_FORM_SUBMIT_RATE_LIMITER = RateLimiter( + prefix="web_form_submit_rate_limit", + max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS, + time_window=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_WINDOW_SECONDS, +) +_FORM_ACCESS_RATE_LIMITER = RateLimiter( + prefix="web_form_access_rate_limit", + max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS, + time_window=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_WINDOW_SECONDS, +) + + +def _stringify_default_values(values: dict[str, object]) -> dict[str, str]: + result: dict[str, str] = {} + for key, value in values.items(): + if value is None: + result[key] = "" + elif isinstance(value, (dict, list)): + result[key] = json.dumps(value, ensure_ascii=False) + else: + result[key] = str(value) + return result + + +def _to_timestamp(value: datetime) -> int: + return int(value.timestamp()) + + +def _jsonify_form_definition(form: Form, site_payload: dict | None = None) -> 
Response: + """Return the form payload (optionally with site) as a JSON response.""" + definition_payload = form.get_definition().model_dump() + payload = { + "form_content": definition_payload["rendered_content"], + "inputs": definition_payload["inputs"], + "resolved_default_values": _stringify_default_values(definition_payload["default_values"]), + "user_actions": definition_payload["user_actions"], + "expiration_time": _to_timestamp(form.expiration_time), + } + if site_payload is not None: + payload["site"] = site_payload + return Response(json.dumps(payload, ensure_ascii=False), mimetype="application/json") + + +# TODO(QuantumGhost): disable authorization for web app +# form api temporarily + + +@web_ns.route("/form/human_input/") +# class HumanInputFormApi(WebApiResource): +class HumanInputFormApi(Resource): + """API for getting and submitting human input forms via the web app.""" + + # def get(self, _app_model: App, _end_user: EndUser, form_token: str): + def get(self, form_token: str): + """ + Get human input form definition by token. + + GET /api/form/human_input/ + """ + ip_address = extract_remote_ip(request) + if _FORM_ACCESS_RATE_LIMITER.is_rate_limited(ip_address): + raise WebFormRateLimitExceededError() + _FORM_ACCESS_RATE_LIMITER.increment_rate_limit(ip_address) + + service = HumanInputService(db.engine) + # TODO(QuantumGhost): forbid submision for form tokens + # that are only for console. + form = service.get_form_by_token(form_token) + + if form is None: + raise NotFoundError("Form not found") + + service.ensure_form_active(form) + app_model, site = _get_app_site_from_form(form) + + return _jsonify_form_definition(form, site_payload=serialize_app_site_payload(app_model, site, None)) + + # def post(self, _app_model: App, _end_user: EndUser, form_token: str): + def post(self, form_token: str): + """ + Submit human input form by token. 
+ + POST /api/form/human_input/<form_token> + + Request body: + { + "inputs": { + "content": "User input content" + }, + "action": "Approve" + } + """ + parser = reqparse.RequestParser() + parser.add_argument("inputs", type=dict, required=True, location="json") + parser.add_argument("action", type=str, required=True, location="json") + args = parser.parse_args() + + ip_address = extract_remote_ip(request) + if _FORM_SUBMIT_RATE_LIMITER.is_rate_limited(ip_address): + raise WebFormRateLimitExceededError() + _FORM_SUBMIT_RATE_LIMITER.increment_rate_limit(ip_address) + + service = HumanInputService(db.engine) + form = service.get_form_by_token(form_token) + if form is None: + raise NotFoundError("Form not found") + + if (recipient_type := form.recipient_type) is None: + logger.warning("Recipient type is None for form, form_id=%s", form.id) + raise AssertionError("Recipient type is None") + + try: + service.submit_form_by_token( + recipient_type=recipient_type, + form_token=form_token, + selected_action_id=args["action"], + form_data=args["inputs"], + submission_end_user_id=None, + # submission_end_user_id=_end_user.id, + ) + except FormNotFoundError: + raise NotFoundError("Form not found") + + return {}, 200 + + +def _get_app_site_from_form(form: Form) -> tuple[App, Site]: + """Resolve App/Site for the form's app and validate tenant status.""" + app_model = db.session.query(App).where(App.id == form.app_id).first() + if app_model is None or app_model.tenant_id != form.tenant_id: + raise NotFoundError("Form not found") + + site = db.session.query(Site).where(Site.app_id == app_model.id).first() + if site is None: + raise Forbidden() + + if app_model.tenant and app_model.tenant.status == TenantStatus.ARCHIVE: + raise Forbidden() + + return app_model, site diff --git a/api/controllers/web/site.py b/api/controllers/web/site.py index b01aaba357..f957229ece 100644 --- a/api/controllers/web/site.py +++ b/api/controllers/web/site.py @@ -1,4 +1,6 @@ -from flask_restx import fields,
marshal_with +from typing import cast + +from flask_restx import fields, marshal, marshal_with from werkzeug.exceptions import Forbidden from configs import dify_config @@ -7,7 +9,7 @@ from controllers.web.wraps import WebApiResource from extensions.ext_database import db from libs.helper import AppIconUrlField from models.account import TenantStatus -from models.model import Site +from models.model import App, Site from services.feature_service import FeatureService @@ -108,3 +110,14 @@ class AppSiteInfo: "remove_webapp_brand": remove_webapp_brand, "replace_webapp_logo": replace_webapp_logo, } + + +def serialize_site(site: Site) -> dict: + """Serialize Site model using the same schema as AppSiteApi.""" + return cast(dict, marshal(site, AppSiteApi.site_fields)) + + +def serialize_app_site_payload(app_model: App, site: Site, end_user_id: str | None) -> dict: + can_replace_logo = FeatureService.get_features(app_model.tenant_id).can_replace_logo + app_site_info = AppSiteInfo(app_model.tenant, app_model, site, end_user_id, can_replace_logo) + return cast(dict, marshal(app_site_info, AppSiteApi.app_fields)) diff --git a/api/controllers/web/workflow_events.py b/api/controllers/web/workflow_events.py new file mode 100644 index 0000000000..61568e70e6 --- /dev/null +++ b/api/controllers/web/workflow_events.py @@ -0,0 +1,112 @@ +""" +Web App Workflow Resume APIs. 
+""" + +import json +from collections.abc import Generator + +from flask import Response, request +from sqlalchemy.orm import sessionmaker + +from controllers.web import api +from controllers.web.error import InvalidArgumentError, NotFoundError +from controllers.web.wraps import WebApiResource +from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator +from core.app.apps.base_app_generator import BaseAppGenerator +from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter +from core.app.apps.message_generator import MessageGenerator +from core.app.apps.workflow.app_generator import WorkflowAppGenerator +from extensions.ext_database import db +from models.enums import CreatorUserRole +from models.model import App, AppMode, EndUser +from repositories.factory import DifyAPIRepositoryFactory +from services.workflow_event_snapshot_service import build_workflow_event_stream + + +class WorkflowEventsApi(WebApiResource): + """API for getting workflow execution events after resume.""" + + def get(self, app_model: App, end_user: EndUser, task_id: str): + """ + Get workflow execution events stream after resume. + + GET /api/workflow//events + + Returns Server-Sent Events stream. 
+ """ + workflow_run_id = task_id + session_maker = sessionmaker(db.engine) + repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) + workflow_run = repo.get_workflow_run_by_id_and_tenant_id( + tenant_id=app_model.tenant_id, + run_id=workflow_run_id, + ) + + if workflow_run is None: + raise NotFoundError(f"WorkflowRun not found, id={workflow_run_id}") + + if workflow_run.app_id != app_model.id: + raise NotFoundError(f"WorkflowRun not found, id={workflow_run_id}") + + if workflow_run.created_by_role != CreatorUserRole.END_USER: + raise NotFoundError(f"WorkflowRun not created by end user, id={workflow_run_id}") + + if workflow_run.created_by != end_user.id: + raise NotFoundError(f"WorkflowRun not created by the current end user, id={workflow_run_id}") + + if workflow_run.finished_at is not None: + response = WorkflowResponseConverter.workflow_run_result_to_finish_response( + task_id=workflow_run.id, + workflow_run=workflow_run, + creator_user=end_user, + ) + + payload = response.model_dump(mode="json") + payload["event"] = response.event.value + + def _generate_finished_events() -> Generator[str, None, None]: + yield f"data: {json.dumps(payload)}\n\n" + + event_generator = _generate_finished_events + else: + app_mode = AppMode.value_of(app_model.mode) + msg_generator = MessageGenerator() + generator: BaseAppGenerator + if app_mode == AppMode.ADVANCED_CHAT: + generator = AdvancedChatAppGenerator() + elif app_mode == AppMode.WORKFLOW: + generator = WorkflowAppGenerator() + else: + raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}") + + include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true" + + def _generate_stream_events(): + if include_state_snapshot: + return generator.convert_to_event_stream( + build_workflow_event_stream( + app_mode=app_mode, + workflow_run=workflow_run, + tenant_id=app_model.tenant_id, + app_id=app_model.id, + 
session_maker=session_maker, + ) + ) + return generator.convert_to_event_stream( + msg_generator.retrieve_events(app_mode, workflow_run.id), + ) + + event_generator = _generate_stream_events + + return Response( + event_generator(), + mimetype="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + }, + ) + + +# Register the APIs +api.add_resource(WorkflowEventsApi, "/workflow/<string:task_id>/events") diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 528c45f6c8..2891d3ceeb 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -4,8 +4,8 @@ import contextvars import logging import threading import uuid -from collections.abc import Generator, Mapping -from typing import TYPE_CHECKING, Any, Literal, Union, overload +from collections.abc import Generator, Mapping, Sequence +from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union, overload from flask import Flask, current_app from pydantic import ValidationError @@ -29,21 +29,25 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse +from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer from core.helper.trace_id_helper import extract_external_trace_id_from_args from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.ops.ops_trace_manager import TraceQueueManager from core.prompt.utils.get_thread_messages_length import get_thread_messages_length from core.repositories import DifyCoreRepositoryFactory +from core.workflow.graph_engine.layers.base import GraphEngineLayer from
core.workflow.repositories.draft_variable_repository import ( DraftVariableSaverFactory, ) from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from core.workflow.runtime import GraphRuntimeState from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from extensions.ext_database import db from factories import file_factory from libs.flask_utils import preserve_flask_contexts from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom +from models.base import Base from models.enums import WorkflowRunTriggeredFrom from services.conversation_service import ConversationService from services.workflow_draft_variable_service import ( @@ -65,7 +69,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): user: Union[Account, EndUser], args: Mapping[str, Any], invoke_from: InvokeFrom, + workflow_run_id: str, streaming: Literal[False], + pause_state_config: PauseStateLayerConfig | None = None, ) -> Mapping[str, Any]: ... @overload @@ -74,9 +80,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): app_model: App, workflow: Workflow, user: Union[Account, EndUser], - args: Mapping, + args: Mapping[str, Any], invoke_from: InvokeFrom, + workflow_run_id: str, streaming: Literal[True], + pause_state_config: PauseStateLayerConfig | None = None, ) -> Generator[Mapping | str, None, None]: ... @overload @@ -85,9 +93,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): app_model: App, workflow: Workflow, user: Union[Account, EndUser], - args: Mapping, + args: Mapping[str, Any], invoke_from: InvokeFrom, + workflow_run_id: str, streaming: bool, + pause_state_config: PauseStateLayerConfig | None = None, ) -> Mapping[str, Any] | Generator[str | Mapping, None, None]: ... 
def generate( @@ -95,9 +105,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): app_model: App, workflow: Workflow, user: Union[Account, EndUser], - args: Mapping, + args: Mapping[str, Any], invoke_from: InvokeFrom, + workflow_run_id: str, streaming: bool = True, + pause_state_config: PauseStateLayerConfig | None = None, ) -> Mapping[str, Any] | Generator[str | Mapping, None, None]: """ Generate App response. @@ -161,7 +173,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): # always enable retriever resource in debugger mode app_config.additional_features.show_retrieve_source = True # type: ignore - workflow_run_id = str(uuid.uuid4()) # init application generate entity application_generate_entity = AdvancedChatAppGenerateEntity( task_id=str(uuid.uuid4()), @@ -179,7 +190,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): invoke_from=invoke_from, extras=extras, trace_manager=trace_manager, - workflow_run_id=workflow_run_id, + workflow_run_id=str(workflow_run_id), ) contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers_lock.set(threading.Lock()) @@ -216,6 +227,38 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): workflow_node_execution_repository=workflow_node_execution_repository, conversation=conversation, stream=streaming, + pause_state_config=pause_state_config, + ) + + def resume( + self, + *, + app_model: App, + workflow: Workflow, + user: Union[Account, EndUser], + conversation: Conversation, + message: Message, + application_generate_entity: AdvancedChatAppGenerateEntity, + workflow_execution_repository: WorkflowExecutionRepository, + workflow_node_execution_repository: WorkflowNodeExecutionRepository, + graph_runtime_state: GraphRuntimeState, + pause_state_config: PauseStateLayerConfig | None = None, + ): + """ + Resume a paused advanced chat execution. 
+ """ + return self._generate( + workflow=workflow, + user=user, + invoke_from=application_generate_entity.invoke_from, + application_generate_entity=application_generate_entity, + workflow_execution_repository=workflow_execution_repository, + workflow_node_execution_repository=workflow_node_execution_repository, + conversation=conversation, + message=message, + stream=application_generate_entity.stream, + pause_state_config=pause_state_config, + graph_runtime_state=graph_runtime_state, ) def single_iteration_generate( @@ -396,8 +439,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): workflow_execution_repository: WorkflowExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository, conversation: Conversation | None = None, + message: Message | None = None, stream: bool = True, variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER, + pause_state_config: PauseStateLayerConfig | None = None, + graph_runtime_state: GraphRuntimeState | None = None, + graph_engine_layers: Sequence[GraphEngineLayer] = (), ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]: """ Generate App response. 
@@ -411,12 +458,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): :param conversation: conversation :param stream: is stream """ - is_first_conversation = False - if not conversation: - is_first_conversation = True + is_first_conversation = conversation is None - # init generate records - (conversation, message) = self._init_generate_records(application_generate_entity, conversation) + if conversation is not None and message is not None: + pass + else: + conversation, message = self._init_generate_records(application_generate_entity, conversation) if is_first_conversation: # update conversation features @@ -439,6 +486,16 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): message_id=message.id, ) + graph_layers: list[GraphEngineLayer] = list(graph_engine_layers) + if pause_state_config is not None: + graph_layers.append( + PauseStatePersistenceLayer( + session_factory=pause_state_config.session_factory, + generate_entity=application_generate_entity, + state_owner_user_id=pause_state_config.state_owner_user_id, + ) + ) + # new thread with request context and contextvars context = contextvars.copy_context() @@ -454,14 +511,25 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): "variable_loader": variable_loader, "workflow_execution_repository": workflow_execution_repository, "workflow_node_execution_repository": workflow_node_execution_repository, + "graph_engine_layers": tuple(graph_layers), + "graph_runtime_state": graph_runtime_state, }, ) worker_thread.start() # release database connection, because the following new thread operations may take a long time - db.session.refresh(workflow) - db.session.refresh(message) + with Session(bind=db.engine, expire_on_commit=False) as session: + workflow = _refresh_model(session, workflow) + message = _refresh_model(session, message) + # workflow_ = session.get(Workflow, workflow.id) + # assert workflow_ is not None + # workflow = workflow_ + # message_ = session.get(Message, message.id) + # 
assert message_ is not None + # message = message_ + # db.session.refresh(workflow) + # db.session.refresh(message) # db.session.refresh(user) db.session.close() @@ -490,6 +558,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): variable_loader: VariableLoader, workflow_execution_repository: WorkflowExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository, + graph_engine_layers: Sequence[GraphEngineLayer] = (), + graph_runtime_state: GraphRuntimeState | None = None, ): """ Generate worker in a new thread. @@ -547,6 +617,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): app=app, workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, + graph_engine_layers=graph_engine_layers, + graph_runtime_state=graph_runtime_state, ) try: @@ -614,3 +686,13 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): else: logger.exception("Failed to process generate task pipeline, conversation_id: %s", conversation.id) raise e + + +_T = TypeVar("_T", bound=Base) + + +def _refresh_model(session, model: _T) -> _T: + with Session(bind=db.engine, expire_on_commit=False) as session: + detach_model = session.get(type(model), model.id) + assert detach_model is not None + return detach_model diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index d702db0908..8b20442eab 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -66,6 +66,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): workflow_execution_repository: WorkflowExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository, graph_engine_layers: Sequence[GraphEngineLayer] = (), + graph_runtime_state: GraphRuntimeState | None = None, ): super().__init__( queue_manager=queue_manager, @@ -82,6 +83,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): 
self._app = app self._workflow_execution_repository = workflow_execution_repository self._workflow_node_execution_repository = workflow_node_execution_repository + self._resume_graph_runtime_state = graph_runtime_state @trace_span(WorkflowAppRunnerHandler) def run(self): @@ -110,7 +112,21 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): invoke_from = InvokeFrom.DEBUGGER user_from = self._resolve_user_from(invoke_from) - if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run: + resume_state = self._resume_graph_runtime_state + + if resume_state is not None: + graph_runtime_state = resume_state + variable_pool = graph_runtime_state.variable_pool + graph = self._init_graph( + graph_config=self._workflow.graph_dict, + graph_runtime_state=graph_runtime_state, + workflow_id=self._workflow.id, + tenant_id=self._workflow.tenant_id, + user_id=self.application_generate_entity.user_id, + invoke_from=invoke_from, + user_from=user_from, + ) + elif self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run: # Handle single iteration or single loop run graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution( workflow=self._workflow, diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index da1e9f19b6..00a6a3d9af 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -24,6 +24,8 @@ from core.app.entities.queue_entities import ( QueueAgentLogEvent, QueueAnnotationReplyEvent, QueueErrorEvent, + QueueHumanInputFormFilledEvent, + QueueHumanInputFormTimeoutEvent, QueueIterationCompletedEvent, QueueIterationNextEvent, QueueIterationStartEvent, @@ -42,6 +44,7 @@ from core.app.entities.queue_entities import ( QueueTextChunkEvent, QueueWorkflowFailedEvent, QueueWorkflowPartialSuccessEvent, + 
QueueWorkflowPausedEvent, QueueWorkflowStartedEvent, QueueWorkflowSucceededEvent, WorkflowQueueMessage, @@ -63,6 +66,8 @@ from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.utils.encoders import jsonable_encoder from core.ops.ops_trace_manager import TraceQueueManager +from core.repositories.human_input_repository import HumanInputFormRepositoryImpl +from core.workflow.entities.pause_reason import HumanInputRequired from core.workflow.enums import WorkflowExecutionStatus from core.workflow.nodes import NodeType from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory @@ -71,7 +76,8 @@ from core.workflow.system_variable import SystemVariable from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models import Account, Conversation, EndUser, Message, MessageFile -from models.enums import CreatorUserRole +from models.enums import CreatorUserRole, MessageStatus +from models.execution_extra_content import HumanInputContent from models.workflow import Workflow logger = logging.getLogger(__name__) @@ -128,6 +134,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): ) self._task_state = WorkflowTaskState() + self._seed_task_state_from_message(message) self._message_cycle_manager = MessageCycleManager( application_generate_entity=application_generate_entity, task_state=self._task_state ) @@ -135,6 +142,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): self._application_generate_entity = application_generate_entity self._workflow_id = workflow.id self._workflow_features_dict = workflow.features_dict + self._workflow_tenant_id = workflow.tenant_id self._conversation_id = conversation.id self._conversation_mode = conversation.mode self._message_id = message.id @@ -144,8 +152,13 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): 
self._workflow_run_id: str = "" self._draft_var_saver_factory = draft_var_saver_factory self._graph_runtime_state: GraphRuntimeState | None = None + self._message_saved_on_pause = False self._seed_graph_runtime_state_from_queue_manager() + def _seed_task_state_from_message(self, message: Message) -> None: + if message.status == MessageStatus.PAUSED and message.answer: + self._task_state.answer = message.answer + def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]: """ Process generate task pipeline. @@ -308,6 +321,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): task_id=self._application_generate_entity.task_id, workflow_run_id=run_id, workflow_id=self._workflow_id, + reason=event.reason, ) yield workflow_start_resp @@ -525,6 +539,35 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): ) yield workflow_finish_resp + + def _handle_workflow_paused_event( + self, + event: QueueWorkflowPausedEvent, + **kwargs, + ) -> Generator[StreamResponse, None, None]: + """Handle workflow paused events.""" + validated_state = self._ensure_graph_runtime_initialized() + responses = self._workflow_response_converter.workflow_pause_to_stream_response( + event=event, + task_id=self._application_generate_entity.task_id, + graph_runtime_state=validated_state, + ) + for reason in event.reasons: + if isinstance(reason, HumanInputRequired): + self._persist_human_input_extra_content(form_id=reason.form_id, node_id=reason.node_id) + yield from responses + resolved_state: GraphRuntimeState | None = None + try: + resolved_state = self._ensure_graph_runtime_initialized() + except ValueError: + resolved_state = None + + with self._database_session() as session: + self._save_message(session=session, graph_runtime_state=resolved_state) + message = self._get_message(session=session) + if message is not None: + message.status = MessageStatus.PAUSED + self._message_saved_on_pause = True 
self._base_task_pipeline.queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) def _handle_workflow_failed_event( @@ -614,9 +657,10 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION, ) - # Save message - with self._database_session() as session: - self._save_message(session=session, graph_runtime_state=resolved_state) + # Save message unless it has already been persisted on pause. + if not self._message_saved_on_pause: + with self._database_session() as session: + self._save_message(session=session, graph_runtime_state=resolved_state) yield self._message_end_to_stream_response() @@ -642,6 +686,65 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): """Handle message replace events.""" yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text, reason=event.reason) + def _handle_human_input_form_filled_event( + self, event: QueueHumanInputFormFilledEvent, **kwargs + ) -> Generator[StreamResponse, None, None]: + """Handle human input form filled events.""" + self._persist_human_input_extra_content(node_id=event.node_id) + yield self._workflow_response_converter.human_input_form_filled_to_stream_response( + event=event, task_id=self._application_generate_entity.task_id + ) + + def _handle_human_input_form_timeout_event( + self, event: QueueHumanInputFormTimeoutEvent, **kwargs + ) -> Generator[StreamResponse, None, None]: + """Handle human input form timeout events.""" + yield self._workflow_response_converter.human_input_form_timeout_to_stream_response( + event=event, task_id=self._application_generate_entity.task_id + ) + + def _persist_human_input_extra_content(self, *, node_id: str | None = None, form_id: str | None = None) -> None: + if not self._workflow_run_id or not self._message_id: + return + + if form_id is None: + if node_id is None: + return + form_id = 
self._load_human_input_form_id(node_id=node_id) + if form_id is None: + logger.warning( + "HumanInput form not found for workflow run %s node %s", + self._workflow_run_id, + node_id, + ) + return + + with self._database_session() as session: + exists_stmt = select(HumanInputContent).where( + HumanInputContent.workflow_run_id == self._workflow_run_id, + HumanInputContent.message_id == self._message_id, + HumanInputContent.form_id == form_id, + ) + if session.scalar(exists_stmt) is not None: + return + + content = HumanInputContent( + workflow_run_id=self._workflow_run_id, + message_id=self._message_id, + form_id=form_id, + ) + session.add(content) + + def _load_human_input_form_id(self, *, node_id: str) -> str | None: + form_repository = HumanInputFormRepositoryImpl( + session_factory=db.engine, + tenant_id=self._workflow_tenant_id, + ) + form = form_repository.get_form(self._workflow_run_id, node_id) + if form is None: + return None + return form.id + def _handle_agent_log_event(self, event: QueueAgentLogEvent, **kwargs) -> Generator[StreamResponse, None, None]: """Handle agent log events.""" yield self._workflow_response_converter.handle_agent_log( @@ -659,6 +762,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): QueueWorkflowStartedEvent: self._handle_workflow_started_event, QueueWorkflowSucceededEvent: self._handle_workflow_succeeded_event, QueueWorkflowPartialSuccessEvent: self._handle_workflow_partial_success_event, + QueueWorkflowPausedEvent: self._handle_workflow_paused_event, QueueWorkflowFailedEvent: self._handle_workflow_failed_event, # Node events QueueNodeRetryEvent: self._handle_node_retry_event, @@ -680,6 +784,8 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): QueueMessageReplaceEvent: self._handle_message_replace_event, QueueAdvancedChatMessageEndEvent: self._handle_advanced_chat_message_end_event, QueueAgentLogEvent: self._handle_agent_log_event, + QueueHumanInputFormFilledEvent: 
self._handle_human_input_form_filled_event, + QueueHumanInputFormTimeoutEvent: self._handle_human_input_form_timeout_event, } def _dispatch_event( @@ -747,6 +853,9 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): case QueueWorkflowFailedEvent(): yield from self._handle_workflow_failed_event(event, trace_manager=trace_manager) break + case QueueWorkflowPausedEvent(): + yield from self._handle_workflow_paused_event(event) + break case QueueStopEvent(): yield from self._handle_stop_event(event, graph_runtime_state=None, trace_manager=trace_manager) @@ -772,6 +881,11 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): def _save_message(self, *, session: Session, graph_runtime_state: GraphRuntimeState | None = None): message = self._get_message(session=session) + if message is None: + return + + if message.status == MessageStatus.PAUSED: + message.status = MessageStatus.NORMAL # If there are assistant files, remove markdown image links from answer answer_text = self._task_state.answer diff --git a/api/core/app/apps/base_app_generate_response_converter.py b/api/core/app/apps/base_app_generate_response_converter.py index 74c6d2eca6..d1e2f16b6f 100644 --- a/api/core/app/apps/base_app_generate_response_converter.py +++ b/api/core/app/apps/base_app_generate_response_converter.py @@ -79,6 +79,7 @@ class AppGenerateResponseConverter(ABC): "document_name": resource["document_name"], "score": resource["score"], "content": resource["content"], + "summary": resource.get("summary"), } ) metadata["retriever_resources"] = updated_resources diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py index 38ecec5d30..6d329063f8 100644 --- a/api/core/app/apps/common/workflow_response_converter.py +++ b/api/core/app/apps/common/workflow_response_converter.py @@ -5,9 +5,14 @@ from dataclasses import dataclass from datetime import datetime from typing import Any, NewType, Union 
+from sqlalchemy import select +from sqlalchemy.orm import Session + from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.queue_entities import ( QueueAgentLogEvent, + QueueHumanInputFormFilledEvent, + QueueHumanInputFormTimeoutEvent, QueueIterationCompletedEvent, QueueIterationNextEvent, QueueIterationStartEvent, @@ -19,9 +24,13 @@ from core.app.entities.queue_entities import ( QueueNodeRetryEvent, QueueNodeStartedEvent, QueueNodeSucceededEvent, + QueueWorkflowPausedEvent, ) from core.app.entities.task_entities import ( AgentLogStreamResponse, + HumanInputFormFilledResponse, + HumanInputFormTimeoutResponse, + HumanInputRequiredResponse, IterationNodeCompletedStreamResponse, IterationNodeNextStreamResponse, IterationNodeStartStreamResponse, @@ -31,7 +40,9 @@ from core.app.entities.task_entities import ( NodeFinishStreamResponse, NodeRetryStreamResponse, NodeStartStreamResponse, + StreamResponse, WorkflowFinishStreamResponse, + WorkflowPauseStreamResponse, WorkflowStartStreamResponse, ) from core.file import FILE_MODEL_IDENTITY, File @@ -40,6 +51,8 @@ from core.tools.entities.tool_entities import ToolProviderType from core.tools.tool_manager import ToolManager from core.trigger.trigger_manager import TriggerManager from core.variables.segments import ArrayFileSegment, FileSegment, Segment +from core.workflow.entities.pause_reason import HumanInputRequired +from core.workflow.entities.workflow_start_reason import WorkflowStartReason from core.workflow.enums import ( NodeType, SystemVariableKey, @@ -51,8 +64,11 @@ from core.workflow.runtime import GraphRuntimeState from core.workflow.system_variable import SystemVariable from core.workflow.workflow_entry import WorkflowEntry from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter +from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models import Account, EndUser +from 
models.human_input import HumanInputForm +from models.workflow import WorkflowRun from services.variable_truncator import BaseTruncator, DummyVariableTruncator, VariableTruncator NodeExecutionId = NewType("NodeExecutionId", str) @@ -191,6 +207,7 @@ class WorkflowResponseConverter: task_id: str, workflow_run_id: str, workflow_id: str, + reason: WorkflowStartReason, ) -> WorkflowStartStreamResponse: run_id = self._ensure_workflow_run_id(workflow_run_id) started_at = naive_utc_now() @@ -204,6 +221,7 @@ class WorkflowResponseConverter: workflow_id=workflow_id, inputs=self._workflow_inputs, created_at=int(started_at.timestamp()), + reason=reason, ), ) @@ -264,6 +282,160 @@ class WorkflowResponseConverter: ), ) + def workflow_pause_to_stream_response( + self, + *, + event: QueueWorkflowPausedEvent, + task_id: str, + graph_runtime_state: GraphRuntimeState, + ) -> list[StreamResponse]: + run_id = self._ensure_workflow_run_id() + started_at = self._workflow_started_at + if started_at is None: + raise ValueError( + "workflow_pause_to_stream_response called before workflow_start_to_stream_response", + ) + paused_at = naive_utc_now() + elapsed_time = (paused_at - started_at).total_seconds() + encoded_outputs = self._encode_outputs(event.outputs) or {} + if self._application_generate_entity.invoke_from == InvokeFrom.SERVICE_API: + encoded_outputs = {} + pause_reasons = [reason.model_dump(mode="json") for reason in event.reasons] + human_input_form_ids = [reason.form_id for reason in event.reasons if isinstance(reason, HumanInputRequired)] + expiration_times_by_form_id: dict[str, datetime] = {} + if human_input_form_ids: + stmt = select(HumanInputForm.id, HumanInputForm.expiration_time).where( + HumanInputForm.id.in_(human_input_form_ids) + ) + with Session(bind=db.engine) as session: + for form_id, expiration_time in session.execute(stmt): + expiration_times_by_form_id[str(form_id)] = expiration_time + + responses: list[StreamResponse] = [] + + for reason in event.reasons: + if 
isinstance(reason, HumanInputRequired): + expiration_time = expiration_times_by_form_id.get(reason.form_id) + if expiration_time is None: + raise ValueError(f"HumanInputForm not found for pause reason, form_id={reason.form_id}") + responses.append( + HumanInputRequiredResponse( + task_id=task_id, + workflow_run_id=run_id, + data=HumanInputRequiredResponse.Data( + form_id=reason.form_id, + node_id=reason.node_id, + node_title=reason.node_title, + form_content=reason.form_content, + inputs=reason.inputs, + actions=reason.actions, + display_in_ui=reason.display_in_ui, + form_token=reason.form_token, + resolved_default_values=reason.resolved_default_values, + expiration_time=int(expiration_time.timestamp()), + ), + ) + ) + + responses.append( + WorkflowPauseStreamResponse( + task_id=task_id, + workflow_run_id=run_id, + data=WorkflowPauseStreamResponse.Data( + workflow_run_id=run_id, + paused_nodes=list(event.paused_nodes), + outputs=encoded_outputs, + reasons=pause_reasons, + status=WorkflowExecutionStatus.PAUSED.value, + created_at=int(started_at.timestamp()), + elapsed_time=elapsed_time, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + ), + ) + ) + + return responses + + def human_input_form_filled_to_stream_response( + self, *, event: QueueHumanInputFormFilledEvent, task_id: str + ) -> HumanInputFormFilledResponse: + run_id = self._ensure_workflow_run_id() + return HumanInputFormFilledResponse( + task_id=task_id, + workflow_run_id=run_id, + data=HumanInputFormFilledResponse.Data( + node_id=event.node_id, + node_title=event.node_title, + rendered_content=event.rendered_content, + action_id=event.action_id, + action_text=event.action_text, + ), + ) + + def human_input_form_timeout_to_stream_response( + self, *, event: QueueHumanInputFormTimeoutEvent, task_id: str + ) -> HumanInputFormTimeoutResponse: + run_id = self._ensure_workflow_run_id() + return HumanInputFormTimeoutResponse( + task_id=task_id, + 
workflow_run_id=run_id, + data=HumanInputFormTimeoutResponse.Data( + node_id=event.node_id, + node_title=event.node_title, + expiration_time=int(event.expiration_time.timestamp()), + ), + ) + + @classmethod + def workflow_run_result_to_finish_response( + cls, + *, + task_id: str, + workflow_run: WorkflowRun, + creator_user: Account | EndUser, + ) -> WorkflowFinishStreamResponse: + run_id = workflow_run.id + elapsed_time = workflow_run.elapsed_time + + encoded_outputs = workflow_run.outputs_dict + finished_at = workflow_run.finished_at + assert finished_at is not None + + created_by: Mapping[str, object] + user = creator_user + if isinstance(user, Account): + created_by = { + "id": user.id, + "name": user.name, + "email": user.email, + } + else: + created_by = { + "id": user.id, + "user": user.session_id, + } + + return WorkflowFinishStreamResponse( + task_id=task_id, + workflow_run_id=run_id, + data=WorkflowFinishStreamResponse.Data( + id=run_id, + workflow_id=workflow_run.workflow_id, + status=workflow_run.status.value, + outputs=encoded_outputs, + error=workflow_run.error, + elapsed_time=elapsed_time, + total_tokens=workflow_run.total_tokens, + total_steps=workflow_run.total_steps, + created_by=created_by, + created_at=int(workflow_run.created_at.timestamp()), + finished_at=int(finished_at.timestamp()), + files=cls.fetch_files_from_node_outputs(encoded_outputs), + exceptions_count=workflow_run.exceptions_count, + ), + ) + def workflow_node_start_to_stream_response( self, *, @@ -592,7 +764,8 @@ class WorkflowResponseConverter: ), ) - def fetch_files_from_node_outputs(self, outputs_dict: Mapping[str, Any] | None) -> Sequence[Mapping[str, Any]]: + @classmethod + def fetch_files_from_node_outputs(cls, outputs_dict: Mapping[str, Any] | None) -> Sequence[Mapping[str, Any]]: """ Fetch files from node outputs :param outputs_dict: node outputs dict @@ -601,7 +774,7 @@ class WorkflowResponseConverter: if not outputs_dict: return [] - files = 
[self._fetch_files_from_variable_value(output_value) for output_value in outputs_dict.values()] + files = [cls._fetch_files_from_variable_value(output_value) for output_value in outputs_dict.values()] # Remove None files = [file for file in files if file] # Flatten list diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index 57617d8863..4e9a191dae 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -1,6 +1,6 @@ import json import logging -from collections.abc import Generator +from collections.abc import Callable, Generator, Mapping from typing import Union, cast from sqlalchemy import select @@ -10,12 +10,14 @@ from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppMod from core.app.apps.base_app_generator import BaseAppGenerator from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.exc import GenerateTaskStoppedError +from core.app.apps.streaming_utils import stream_topic_events from core.app.entities.app_invoke_entities import ( AdvancedChatAppGenerateEntity, AgentChatAppGenerateEntity, AppGenerateEntity, ChatAppGenerateEntity, CompletionAppGenerateEntity, + ConversationAppGenerateEntity, InvokeFrom, ) from core.app.entities.task_entities import ( @@ -27,6 +29,8 @@ from core.app.entities.task_entities import ( from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline from core.prompt.utils.prompt_template_parser import PromptTemplateParser from extensions.ext_database import db +from extensions.ext_redis import get_pubsub_broadcast_channel +from libs.broadcast_channel.channel import Topic from libs.datetime_utils import naive_utc_now from models import Account from models.enums import CreatorUserRole @@ -156,6 +160,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): query = application_generate_entity.query or "New conversation" 
conversation_name = (query[:20] + "…") if len(query) > 20 else query + created_new_conversation = conversation is None try: if not conversation: conversation = Conversation( @@ -232,6 +237,10 @@ class MessageBasedAppGenerator(BaseAppGenerator): db.session.add_all(message_files) db.session.commit() + + if isinstance(application_generate_entity, ConversationAppGenerateEntity): + application_generate_entity.conversation_id = conversation.id + application_generate_entity.is_new_conversation = created_new_conversation return conversation, message except Exception: db.session.rollback() @@ -284,3 +293,29 @@ class MessageBasedAppGenerator(BaseAppGenerator): raise MessageNotExistsError("Message not exists") return message + + @staticmethod + def _make_channel_key(app_mode: AppMode, workflow_run_id: str): + return f"channel:{app_mode}:{workflow_run_id}" + + @classmethod + def get_response_topic(cls, app_mode: AppMode, workflow_run_id: str) -> Topic: + key = cls._make_channel_key(app_mode, workflow_run_id) + channel = get_pubsub_broadcast_channel() + topic = channel.topic(key) + return topic + + @classmethod + def retrieve_events( + cls, + app_mode: AppMode, + workflow_run_id: str, + idle_timeout=300, + on_subscribe: Callable[[], None] | None = None, + ) -> Generator[Mapping | str, None, None]: + topic = cls.get_response_topic(app_mode, workflow_run_id) + return stream_topic_events( + topic=topic, + idle_timeout=idle_timeout, + on_subscribe=on_subscribe, + ) diff --git a/api/core/app/apps/message_generator.py b/api/core/app/apps/message_generator.py new file mode 100644 index 0000000000..68631bb230 --- /dev/null +++ b/api/core/app/apps/message_generator.py @@ -0,0 +1,36 @@ +from collections.abc import Callable, Generator, Mapping + +from core.app.apps.streaming_utils import stream_topic_events +from extensions.ext_redis import get_pubsub_broadcast_channel +from libs.broadcast_channel.channel import Topic +from models.model import AppMode + + +class MessageGenerator: + 
@staticmethod + def _make_channel_key(app_mode: AppMode, workflow_run_id: str): + return f"channel:{app_mode}:{str(workflow_run_id)}" + + @classmethod + def get_response_topic(cls, app_mode: AppMode, workflow_run_id: str) -> Topic: + key = cls._make_channel_key(app_mode, workflow_run_id) + channel = get_pubsub_broadcast_channel() + topic = channel.topic(key) + return topic + + @classmethod + def retrieve_events( + cls, + app_mode: AppMode, + workflow_run_id: str, + idle_timeout=300, + ping_interval: float = 10.0, + on_subscribe: Callable[[], None] | None = None, + ) -> Generator[Mapping | str, None, None]: + topic = cls.get_response_topic(app_mode, workflow_run_id) + return stream_topic_events( + topic=topic, + idle_timeout=idle_timeout, + ping_interval=ping_interval, + on_subscribe=on_subscribe, + ) diff --git a/api/core/app/apps/streaming_utils.py b/api/core/app/apps/streaming_utils.py new file mode 100644 index 0000000000..57d4b537a4 --- /dev/null +++ b/api/core/app/apps/streaming_utils.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +import json +import time +from collections.abc import Callable, Generator, Iterable, Mapping +from typing import Any + +from core.app.entities.task_entities import StreamEvent +from libs.broadcast_channel.channel import Topic +from libs.broadcast_channel.exc import SubscriptionClosedError + + +def stream_topic_events( + *, + topic: Topic, + idle_timeout: float, + ping_interval: float | None = None, + on_subscribe: Callable[[], None] | None = None, + terminal_events: Iterable[str | StreamEvent] | None = None, +) -> Generator[Mapping[str, Any] | str, None, None]: + # send a PING event immediately to prevent the connection staying in pending state for a long time. + # + # This simplifies the debugging process as the DevTools in Chrome does not + # provide complete curl command for pending connections.
+ yield StreamEvent.PING.value + + terminal_values = _normalize_terminal_events(terminal_events) + last_msg_time = time.time() + last_ping_time = last_msg_time + with topic.subscribe() as sub: + # on_subscribe fires only after the Redis subscription is active. + # This is used to gate task start and reduce pub/sub race for the first event. + if on_subscribe is not None: + on_subscribe() + while True: + try: + msg = sub.receive(timeout=0.1) + except SubscriptionClosedError: + return + if msg is None: + current_time = time.time() + if current_time - last_msg_time > idle_timeout: + return + if ping_interval is not None and current_time - last_ping_time >= ping_interval: + yield StreamEvent.PING.value + last_ping_time = current_time + continue + + last_msg_time = time.time() + last_ping_time = last_msg_time + event = json.loads(msg) + yield event + if not isinstance(event, dict): + continue + + event_type = event.get("event") + if event_type in terminal_values: + return + + +def _normalize_terminal_events(terminal_events: Iterable[str | StreamEvent] | None) -> set[str]: + if not terminal_events: + return {StreamEvent.WORKFLOW_FINISHED.value, StreamEvent.WORKFLOW_PAUSED.value} + values: set[str] = set() + for item in terminal_events: + if isinstance(item, StreamEvent): + values.add(item.value) + else: + values.add(str(item)) + return values diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index ee205ed153..dc5852d552 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -25,6 +25,7 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse +from 
core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer from core.db.session_factory import session_factory from core.helper.trace_id_helper import extract_external_trace_id_from_args from core.model_runtime.errors.invoke import InvokeAuthorizationError @@ -34,12 +35,15 @@ from core.workflow.graph_engine.layers.base import GraphEngineLayer from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from core.workflow.runtime import GraphRuntimeState from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from extensions.ext_database import db from factories import file_factory from libs.flask_utils import preserve_flask_contexts -from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom +from models.account import Account from models.enums import WorkflowRunTriggeredFrom +from models.model import App, EndUser +from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService if TYPE_CHECKING: @@ -66,9 +70,11 @@ class WorkflowAppGenerator(BaseAppGenerator): invoke_from: InvokeFrom, streaming: Literal[True], call_depth: int, + workflow_run_id: str | uuid.UUID | None = None, triggered_from: WorkflowRunTriggeredFrom | None = None, root_node_id: str | None = None, graph_engine_layers: Sequence[GraphEngineLayer] = (), + pause_state_config: PauseStateLayerConfig | None = None, ) -> Generator[Mapping[str, Any] | str, None, None]: ... 
@overload @@ -82,9 +88,11 @@ class WorkflowAppGenerator(BaseAppGenerator): invoke_from: InvokeFrom, streaming: Literal[False], call_depth: int, + workflow_run_id: str | uuid.UUID | None = None, triggered_from: WorkflowRunTriggeredFrom | None = None, root_node_id: str | None = None, graph_engine_layers: Sequence[GraphEngineLayer] = (), + pause_state_config: PauseStateLayerConfig | None = None, ) -> Mapping[str, Any]: ... @overload @@ -98,9 +106,11 @@ class WorkflowAppGenerator(BaseAppGenerator): invoke_from: InvokeFrom, streaming: bool, call_depth: int, + workflow_run_id: str | uuid.UUID | None = None, triggered_from: WorkflowRunTriggeredFrom | None = None, root_node_id: str | None = None, graph_engine_layers: Sequence[GraphEngineLayer] = (), + pause_state_config: PauseStateLayerConfig | None = None, ) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: ... def generate( @@ -113,9 +123,11 @@ class WorkflowAppGenerator(BaseAppGenerator): invoke_from: InvokeFrom, streaming: bool = True, call_depth: int = 0, + workflow_run_id: str | uuid.UUID | None = None, triggered_from: WorkflowRunTriggeredFrom | None = None, root_node_id: str | None = None, graph_engine_layers: Sequence[GraphEngineLayer] = (), + pause_state_config: PauseStateLayerConfig | None = None, ) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: files: Sequence[Mapping[str, Any]] = args.get("files") or [] @@ -150,7 +162,7 @@ class WorkflowAppGenerator(BaseAppGenerator): extras = { **extract_external_trace_id_from_args(args), } - workflow_run_id = str(uuid.uuid4()) + workflow_run_id = str(workflow_run_id or uuid.uuid4()) # FIXME (Yeuoly): we need to remove the SKIP_PREPARE_USER_INPUTS_KEY from the args # trigger shouldn't prepare user inputs if self._should_prepare_user_inputs(args): @@ -216,13 +228,40 @@ class WorkflowAppGenerator(BaseAppGenerator): streaming=streaming, root_node_id=root_node_id, graph_engine_layers=graph_engine_layers, + 
pause_state_config=pause_state_config, ) - def resume(self, *, workflow_run_id: str) -> None: + def resume( + self, + *, + app_model: App, + workflow: Workflow, + user: Union[Account, EndUser], + application_generate_entity: WorkflowAppGenerateEntity, + graph_runtime_state: GraphRuntimeState, + workflow_execution_repository: WorkflowExecutionRepository, + workflow_node_execution_repository: WorkflowNodeExecutionRepository, + graph_engine_layers: Sequence[GraphEngineLayer] = (), + pause_state_config: PauseStateLayerConfig | None = None, + variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER, + ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: """ - @TBD + Resume a paused workflow execution using the persisted runtime state. """ - pass + return self._generate( + app_model=app_model, + workflow=workflow, + user=user, + application_generate_entity=application_generate_entity, + invoke_from=application_generate_entity.invoke_from, + workflow_execution_repository=workflow_execution_repository, + workflow_node_execution_repository=workflow_node_execution_repository, + streaming=application_generate_entity.stream, + variable_loader=variable_loader, + graph_engine_layers=graph_engine_layers, + graph_runtime_state=graph_runtime_state, + pause_state_config=pause_state_config, + ) def _generate( self, @@ -238,6 +277,8 @@ class WorkflowAppGenerator(BaseAppGenerator): variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER, root_node_id: str | None = None, graph_engine_layers: Sequence[GraphEngineLayer] = (), + graph_runtime_state: GraphRuntimeState | None = None, + pause_state_config: PauseStateLayerConfig | None = None, ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: """ Generate App response. 
@@ -251,6 +292,8 @@ class WorkflowAppGenerator(BaseAppGenerator): :param workflow_node_execution_repository: repository for workflow node execution :param streaming: is stream """ + graph_layers: list[GraphEngineLayer] = list(graph_engine_layers) + # init queue manager queue_manager = WorkflowAppQueueManager( task_id=application_generate_entity.task_id, @@ -259,6 +302,15 @@ class WorkflowAppGenerator(BaseAppGenerator): app_mode=app_model.mode, ) + if pause_state_config is not None: + graph_layers.append( + PauseStatePersistenceLayer( + session_factory=pause_state_config.session_factory, + generate_entity=application_generate_entity, + state_owner_user_id=pause_state_config.state_owner_user_id, + ) + ) + # new thread with request context and contextvars context = contextvars.copy_context() @@ -276,7 +328,8 @@ class WorkflowAppGenerator(BaseAppGenerator): "root_node_id": root_node_id, "workflow_execution_repository": workflow_execution_repository, "workflow_node_execution_repository": workflow_node_execution_repository, - "graph_engine_layers": graph_engine_layers, + "graph_engine_layers": tuple(graph_layers), + "graph_runtime_state": graph_runtime_state, }, ) @@ -378,6 +431,7 @@ class WorkflowAppGenerator(BaseAppGenerator): workflow_node_execution_repository=workflow_node_execution_repository, streaming=streaming, variable_loader=var_loader, + pause_state_config=None, ) def single_loop_generate( @@ -459,6 +513,7 @@ class WorkflowAppGenerator(BaseAppGenerator): workflow_node_execution_repository=workflow_node_execution_repository, streaming=streaming, variable_loader=var_loader, + pause_state_config=None, ) def _generate_worker( @@ -472,6 +527,7 @@ class WorkflowAppGenerator(BaseAppGenerator): workflow_node_execution_repository: WorkflowNodeExecutionRepository, root_node_id: str | None = None, graph_engine_layers: Sequence[GraphEngineLayer] = (), + graph_runtime_state: GraphRuntimeState | None = None, ) -> None: """ Generate worker in a new thread. 
@@ -517,6 +573,7 @@ class WorkflowAppGenerator(BaseAppGenerator): workflow_node_execution_repository=workflow_node_execution_repository, root_node_id=root_node_id, graph_engine_layers=graph_engine_layers, + graph_runtime_state=graph_runtime_state, ) try: diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index 0ee3c177f2..a43f7879d6 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -42,6 +42,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): workflow_execution_repository: WorkflowExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository, graph_engine_layers: Sequence[GraphEngineLayer] = (), + graph_runtime_state: GraphRuntimeState | None = None, ): super().__init__( queue_manager=queue_manager, @@ -55,6 +56,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): self._root_node_id = root_node_id self._workflow_execution_repository = workflow_execution_repository self._workflow_node_execution_repository = workflow_node_execution_repository + self._resume_graph_runtime_state = graph_runtime_state @trace_span(WorkflowAppRunnerHandler) def run(self): @@ -63,23 +65,28 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): """ app_config = self.application_generate_entity.app_config app_config = cast(WorkflowAppConfig, app_config) - - system_inputs = SystemVariable( - files=self.application_generate_entity.files, - user_id=self._sys_user_id, - app_id=app_config.app_id, - timestamp=int(naive_utc_now().timestamp()), - workflow_id=app_config.workflow_id, - workflow_execution_id=self.application_generate_entity.workflow_execution_id, - ) - invoke_from = self.application_generate_entity.invoke_from # if only single iteration or single loop run is requested if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run: invoke_from = InvokeFrom.DEBUGGER user_from = self._resolve_user_from(invoke_from) - 
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run: + resume_state = self._resume_graph_runtime_state + + if resume_state is not None: + graph_runtime_state = resume_state + variable_pool = graph_runtime_state.variable_pool + graph = self._init_graph( + graph_config=self._workflow.graph_dict, + graph_runtime_state=graph_runtime_state, + workflow_id=self._workflow.id, + tenant_id=self._workflow.tenant_id, + user_id=self.application_generate_entity.user_id, + user_from=user_from, + invoke_from=invoke_from, + root_node_id=self._root_node_id, + ) + elif self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run: graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution( workflow=self._workflow, single_iteration_run=self.application_generate_entity.single_iteration_run, @@ -89,7 +96,14 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): inputs = self.application_generate_entity.inputs # Create a variable pool. 
- + system_inputs = SystemVariable( + files=self.application_generate_entity.files, + user_id=self._sys_user_id, + app_id=app_config.app_id, + timestamp=int(naive_utc_now().timestamp()), + workflow_id=app_config.workflow_id, + workflow_execution_id=self.application_generate_entity.workflow_execution_id, + ) variable_pool = VariablePool( system_variables=system_inputs, user_inputs=inputs, @@ -98,8 +112,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): ) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - # init graph graph = self._init_graph( graph_config=self._workflow.graph_dict, graph_runtime_state=graph_runtime_state, diff --git a/api/core/app/apps/workflow/errors.py b/api/core/app/apps/workflow/errors.py new file mode 100644 index 0000000000..16cd864209 --- /dev/null +++ b/api/core/app/apps/workflow/errors.py @@ -0,0 +1,7 @@ +from libs.exception import BaseHTTPException + + +class WorkflowPausedInBlockingModeError(BaseHTTPException): + error_code = "workflow_paused_in_blocking_mode" + description = "Workflow execution paused for human input; blocking response mode is not supported." 
+ code = 400 diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 842ad545ad..0a567a4315 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -16,6 +16,8 @@ from core.app.entities.queue_entities import ( MessageQueueMessage, QueueAgentLogEvent, QueueErrorEvent, + QueueHumanInputFormFilledEvent, + QueueHumanInputFormTimeoutEvent, QueueIterationCompletedEvent, QueueIterationNextEvent, QueueIterationStartEvent, @@ -32,6 +34,7 @@ from core.app.entities.queue_entities import ( QueueTextChunkEvent, QueueWorkflowFailedEvent, QueueWorkflowPartialSuccessEvent, + QueueWorkflowPausedEvent, QueueWorkflowStartedEvent, QueueWorkflowSucceededEvent, WorkflowQueueMessage, @@ -46,11 +49,13 @@ from core.app.entities.task_entities import ( WorkflowAppBlockingResponse, WorkflowAppStreamResponse, WorkflowFinishStreamResponse, + WorkflowPauseStreamResponse, WorkflowStartStreamResponse, ) from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.ops.ops_trace_manager import TraceQueueManager +from core.workflow.entities.workflow_start_reason import WorkflowStartReason from core.workflow.enums import WorkflowExecutionStatus from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory from core.workflow.runtime import GraphRuntimeState @@ -132,6 +137,25 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): for stream_response in generator: if isinstance(stream_response, ErrorStreamResponse): raise stream_response.err + elif isinstance(stream_response, WorkflowPauseStreamResponse): + response = WorkflowAppBlockingResponse( + task_id=self._application_generate_entity.task_id, + workflow_run_id=stream_response.data.workflow_run_id, + data=WorkflowAppBlockingResponse.Data( + 
id=stream_response.data.workflow_run_id, + workflow_id=self._workflow.id, + status=stream_response.data.status, + outputs=stream_response.data.outputs or {}, + error=None, + elapsed_time=stream_response.data.elapsed_time, + total_tokens=stream_response.data.total_tokens, + total_steps=stream_response.data.total_steps, + created_at=stream_response.data.created_at, + finished_at=None, + ), + ) + + return response elif isinstance(stream_response, WorkflowFinishStreamResponse): response = WorkflowAppBlockingResponse( task_id=self._application_generate_entity.task_id, @@ -146,7 +170,7 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): total_tokens=stream_response.data.total_tokens, total_steps=stream_response.data.total_steps, created_at=int(stream_response.data.created_at), - finished_at=int(stream_response.data.finished_at), + finished_at=int(stream_response.data.finished_at) if stream_response.data.finished_at else None, ), ) @@ -259,13 +283,15 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): run_id = self._extract_workflow_run_id(runtime_state) self._workflow_execution_id = run_id - with self._database_session() as session: - self._save_workflow_app_log(session=session, workflow_run_id=self._workflow_execution_id) + if event.reason == WorkflowStartReason.INITIAL: + with self._database_session() as session: + self._save_workflow_app_log(session=session, workflow_run_id=self._workflow_execution_id) start_resp = self._workflow_response_converter.workflow_start_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_run_id=run_id, workflow_id=self._workflow.id, + reason=event.reason, ) yield start_resp @@ -440,6 +466,21 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): ) yield workflow_finish_resp + def _handle_workflow_paused_event( + self, + event: QueueWorkflowPausedEvent, + **kwargs, + ) -> Generator[StreamResponse, None, None]: + """Handle workflow paused events.""" + 
self._ensure_workflow_initialized() + validated_state = self._ensure_graph_runtime_initialized() + responses = self._workflow_response_converter.workflow_pause_to_stream_response( + event=event, + task_id=self._application_generate_entity.task_id, + graph_runtime_state=validated_state, + ) + yield from responses + def _handle_workflow_failed_and_stop_events( self, event: Union[QueueWorkflowFailedEvent, QueueStopEvent], @@ -495,6 +536,22 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): task_id=self._application_generate_entity.task_id, event=event ) + def _handle_human_input_form_filled_event( + self, event: QueueHumanInputFormFilledEvent, **kwargs + ) -> Generator[StreamResponse, None, None]: + """Handle human input form filled events.""" + yield self._workflow_response_converter.human_input_form_filled_to_stream_response( + event=event, task_id=self._application_generate_entity.task_id + ) + + def _handle_human_input_form_timeout_event( + self, event: QueueHumanInputFormTimeoutEvent, **kwargs + ) -> Generator[StreamResponse, None, None]: + """Handle human input form timeout events.""" + yield self._workflow_response_converter.human_input_form_timeout_to_stream_response( + event=event, task_id=self._application_generate_entity.task_id + ) + def _get_event_handlers(self) -> dict[type, Callable]: """Get mapping of event types to their handlers using fluent pattern.""" return { @@ -506,6 +563,7 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): QueueWorkflowStartedEvent: self._handle_workflow_started_event, QueueWorkflowSucceededEvent: self._handle_workflow_succeeded_event, QueueWorkflowPartialSuccessEvent: self._handle_workflow_partial_success_event, + QueueWorkflowPausedEvent: self._handle_workflow_paused_event, # Node events QueueNodeRetryEvent: self._handle_node_retry_event, QueueNodeStartedEvent: self._handle_node_started_event, @@ -520,6 +578,8 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): 
QueueLoopCompletedEvent: self._handle_loop_completed_event, # Agent events QueueAgentLogEvent: self._handle_agent_log_event, + QueueHumanInputFormFilledEvent: self._handle_human_input_form_filled_event, + QueueHumanInputFormTimeoutEvent: self._handle_human_input_form_timeout_event, } def _dispatch_event( @@ -602,6 +662,9 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): case QueueWorkflowFailedEvent(): yield from self._handle_workflow_failed_and_stop_events(event) break + case QueueWorkflowPausedEvent(): + yield from self._handle_workflow_paused_event(event) + break case QueueStopEvent(): yield from self._handle_workflow_failed_and_stop_events(event) diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index 13b7865f55..c9d7464c17 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -1,3 +1,4 @@ +import logging import time from collections.abc import Mapping, Sequence from typing import Any, cast @@ -7,6 +8,8 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import ( AppQueueEvent, QueueAgentLogEvent, + QueueHumanInputFormFilledEvent, + QueueHumanInputFormTimeoutEvent, QueueIterationCompletedEvent, QueueIterationNextEvent, QueueIterationStartEvent, @@ -22,22 +25,27 @@ from core.app.entities.queue_entities import ( QueueTextChunkEvent, QueueWorkflowFailedEvent, QueueWorkflowPartialSuccessEvent, + QueueWorkflowPausedEvent, QueueWorkflowStartedEvent, QueueWorkflowSucceededEvent, ) from core.app.workflow.node_factory import DifyNodeFactory from core.workflow.entities import GraphInitParams +from core.workflow.entities.pause_reason import HumanInputRequired from core.workflow.graph import Graph from core.workflow.graph_engine.layers.base import GraphEngineLayer from core.workflow.graph_events import ( GraphEngineEvent, GraphRunFailedEvent, GraphRunPartialSucceededEvent, + GraphRunPausedEvent, 
GraphRunStartedEvent, GraphRunSucceededEvent, NodeRunAgentLogEvent, NodeRunExceptionEvent, NodeRunFailedEvent, + NodeRunHumanInputFormFilledEvent, + NodeRunHumanInputFormTimeoutEvent, NodeRunIterationFailedEvent, NodeRunIterationNextEvent, NodeRunIterationStartedEvent, @@ -61,6 +69,9 @@ from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, from core.workflow.workflow_entry import WorkflowEntry from models.enums import UserFrom from models.workflow import Workflow +from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task + +logger = logging.getLogger(__name__) class WorkflowBasedAppRunner: @@ -327,7 +338,7 @@ class WorkflowBasedAppRunner: :param event: event """ if isinstance(event, GraphRunStartedEvent): - self._publish_event(QueueWorkflowStartedEvent()) + self._publish_event(QueueWorkflowStartedEvent(reason=event.reason)) elif isinstance(event, GraphRunSucceededEvent): self._publish_event(QueueWorkflowSucceededEvent(outputs=event.outputs)) elif isinstance(event, GraphRunPartialSucceededEvent): @@ -338,6 +349,38 @@ class WorkflowBasedAppRunner: self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count)) elif isinstance(event, GraphRunAbortedEvent): self._publish_event(QueueWorkflowFailedEvent(error=event.reason or "Unknown error", exceptions_count=0)) + elif isinstance(event, GraphRunPausedEvent): + runtime_state = workflow_entry.graph_engine.graph_runtime_state + paused_nodes = runtime_state.get_paused_nodes() + self._enqueue_human_input_notifications(event.reasons) + self._publish_event( + QueueWorkflowPausedEvent( + reasons=event.reasons, + outputs=event.outputs, + paused_nodes=paused_nodes, + ) + ) + elif isinstance(event, NodeRunHumanInputFormFilledEvent): + self._publish_event( + QueueHumanInputFormFilledEvent( + node_execution_id=event.id, + node_id=event.node_id, + node_type=event.node_type, + node_title=event.node_title, + 
rendered_content=event.rendered_content, + action_id=event.action_id, + action_text=event.action_text, + ) + ) + elif isinstance(event, NodeRunHumanInputFormTimeoutEvent): + self._publish_event( + QueueHumanInputFormTimeoutEvent( + node_id=event.node_id, + node_type=event.node_type, + node_title=event.node_title, + expiration_time=event.expiration_time, + ) + ) elif isinstance(event, NodeRunRetryEvent): node_run_result = event.node_run_result inputs = node_run_result.inputs @@ -544,5 +587,19 @@ class WorkflowBasedAppRunner: ) ) + def _enqueue_human_input_notifications(self, reasons: Sequence[object]) -> None: + for reason in reasons: + if not isinstance(reason, HumanInputRequired): + continue + if not reason.form_id: + continue + try: + dispatch_human_input_email_task.apply_async( + kwargs={"form_id": reason.form_id, "node_title": reason.node_title}, + queue="mail", + ) + except Exception: # pragma: no cover - defensive logging + logger.exception("Failed to enqueue human input email task for form %s", reason.form_id) + def _publish_event(self, event: AppQueueEvent): self._queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER) diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 5bc453420d..0e68e554c8 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -132,7 +132,7 @@ class AppGenerateEntity(BaseModel): extras: dict[str, Any] = Field(default_factory=dict) # tracing instance - trace_manager: Optional["TraceQueueManager"] = None + trace_manager: Optional["TraceQueueManager"] = Field(default=None, exclude=True, repr=False) class EasyUIBasedAppGenerateEntity(AppGenerateEntity): @@ -156,6 +156,7 @@ class ConversationAppGenerateEntity(AppGenerateEntity): """ conversation_id: str | None = None + is_new_conversation: bool = False parent_message_id: str | None = Field( default=None, description=( diff --git a/api/core/app/entities/queue_entities.py 
b/api/core/app/entities/queue_entities.py index 77d6bf03b4..5b2fa29b56 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -8,6 +8,8 @@ from pydantic import BaseModel, ConfigDict, Field from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.workflow.entities import AgentNodeStrategyInit +from core.workflow.entities.pause_reason import PauseReason +from core.workflow.entities.workflow_start_reason import WorkflowStartReason from core.workflow.enums import WorkflowNodeExecutionMetadataKey from core.workflow.nodes import NodeType @@ -46,6 +48,9 @@ class QueueEvent(StrEnum): PING = "ping" STOP = "stop" RETRY = "retry" + PAUSE = "pause" + HUMAN_INPUT_FORM_FILLED = "human_input_form_filled" + HUMAN_INPUT_FORM_TIMEOUT = "human_input_form_timeout" class AppQueueEvent(BaseModel): @@ -261,6 +266,8 @@ class QueueWorkflowStartedEvent(AppQueueEvent): """QueueWorkflowStartedEvent entity.""" event: QueueEvent = QueueEvent.WORKFLOW_STARTED + # Always present; mirrors GraphRunStartedEvent.reason for downstream consumers. 
+ reason: WorkflowStartReason = WorkflowStartReason.INITIAL class QueueWorkflowSucceededEvent(AppQueueEvent): @@ -484,6 +491,35 @@ class QueueStopEvent(AppQueueEvent): return reason_mapping.get(self.stopped_by, "Stopped by unknown reason.") +class QueueHumanInputFormFilledEvent(AppQueueEvent): + """ + QueueHumanInputFormFilledEvent entity + """ + + event: QueueEvent = QueueEvent.HUMAN_INPUT_FORM_FILLED + + node_execution_id: str + node_id: str + node_type: NodeType + node_title: str + rendered_content: str + action_id: str + action_text: str + + +class QueueHumanInputFormTimeoutEvent(AppQueueEvent): + """ + QueueHumanInputFormTimeoutEvent entity + """ + + event: QueueEvent = QueueEvent.HUMAN_INPUT_FORM_TIMEOUT + + node_id: str + node_type: NodeType + node_title: str + expiration_time: datetime + + class QueueMessage(BaseModel): """ QueueMessage abstract entity @@ -509,3 +545,14 @@ class WorkflowQueueMessage(QueueMessage): """ pass + + +class QueueWorkflowPausedEvent(AppQueueEvent): + """ + QueueWorkflowPausedEvent entity + """ + + event: QueueEvent = QueueEvent.PAUSE + reasons: Sequence[PauseReason] = Field(default_factory=list) + outputs: Mapping[str, object] = Field(default_factory=dict) + paused_nodes: Sequence[str] = Field(default_factory=list) diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 79a5e657b3..3f38904d2f 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -7,7 +7,9 @@ from pydantic import BaseModel, ConfigDict, Field from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.workflow.entities import AgentNodeStrategyInit +from core.workflow.entities.workflow_start_reason import WorkflowStartReason from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from core.workflow.nodes.human_input.entities import FormInput, 
UserAction class AnnotationReplyAccount(BaseModel): @@ -69,6 +71,7 @@ class StreamEvent(StrEnum): AGENT_THOUGHT = "agent_thought" AGENT_MESSAGE = "agent_message" WORKFLOW_STARTED = "workflow_started" + WORKFLOW_PAUSED = "workflow_paused" WORKFLOW_FINISHED = "workflow_finished" NODE_STARTED = "node_started" NODE_FINISHED = "node_finished" @@ -82,6 +85,9 @@ class StreamEvent(StrEnum): TEXT_CHUNK = "text_chunk" TEXT_REPLACE = "text_replace" AGENT_LOG = "agent_log" + HUMAN_INPUT_REQUIRED = "human_input_required" + HUMAN_INPUT_FORM_FILLED = "human_input_form_filled" + HUMAN_INPUT_FORM_TIMEOUT = "human_input_form_timeout" class StreamResponse(BaseModel): @@ -205,6 +211,8 @@ class WorkflowStartStreamResponse(StreamResponse): workflow_id: str inputs: Mapping[str, Any] created_at: int + # Always present; mirrors QueueWorkflowStartedEvent.reason for SSE clients. + reason: WorkflowStartReason = WorkflowStartReason.INITIAL event: StreamEvent = StreamEvent.WORKFLOW_STARTED workflow_run_id: str @@ -231,7 +239,7 @@ class WorkflowFinishStreamResponse(StreamResponse): total_steps: int created_by: Mapping[str, object] = Field(default_factory=dict) created_at: int - finished_at: int + finished_at: int | None exceptions_count: int | None = 0 files: Sequence[Mapping[str, Any]] | None = [] @@ -240,6 +248,85 @@ class WorkflowFinishStreamResponse(StreamResponse): data: Data +class WorkflowPauseStreamResponse(StreamResponse): + """ + WorkflowPauseStreamResponse entity + """ + + class Data(BaseModel): + """ + Data entity + """ + + workflow_run_id: str + paused_nodes: Sequence[str] = Field(default_factory=list) + outputs: Mapping[str, Any] = Field(default_factory=dict) + reasons: Sequence[Mapping[str, Any]] = Field(default_factory=list) + status: str + created_at: int + elapsed_time: float + total_tokens: int + total_steps: int + + event: StreamEvent = StreamEvent.WORKFLOW_PAUSED + workflow_run_id: str + data: Data + + +class HumanInputRequiredResponse(StreamResponse): + class 
Data(BaseModel): + """ + Data entity + """ + + form_id: str + node_id: str + node_title: str + form_content: str + inputs: Sequence[FormInput] = Field(default_factory=list) + actions: Sequence[UserAction] = Field(default_factory=list) + display_in_ui: bool = False + form_token: str | None = None + resolved_default_values: Mapping[str, Any] = Field(default_factory=dict) + expiration_time: int = Field(..., description="Unix timestamp in seconds") + + event: StreamEvent = StreamEvent.HUMAN_INPUT_REQUIRED + workflow_run_id: str + data: Data + + +class HumanInputFormFilledResponse(StreamResponse): + class Data(BaseModel): + """ + Data entity + """ + + node_id: str + node_title: str + rendered_content: str + action_id: str + action_text: str + + event: StreamEvent = StreamEvent.HUMAN_INPUT_FORM_FILLED + workflow_run_id: str + data: Data + + +class HumanInputFormTimeoutResponse(StreamResponse): + class Data(BaseModel): + """ + Data entity + """ + + node_id: str + node_title: str + expiration_time: int + + event: StreamEvent = StreamEvent.HUMAN_INPUT_FORM_TIMEOUT + workflow_run_id: str + data: Data + + class NodeStartStreamResponse(StreamResponse): """ NodeStartStreamResponse entity @@ -726,7 +813,7 @@ class WorkflowAppBlockingResponse(AppBlockingResponse): total_tokens: int total_steps: int created_at: int - finished_at: int + finished_at: int | None workflow_run_id: str data: Data diff --git a/api/core/app/features/rate_limiting/rate_limit.py b/api/core/app/features/rate_limiting/rate_limit.py index 565905be0d..2ca1275a8a 100644 --- a/api/core/app/features/rate_limiting/rate_limit.py +++ b/api/core/app/features/rate_limiting/rate_limit.py @@ -1,3 +1,4 @@ +import contextlib import logging import time import uuid @@ -103,6 +104,14 @@ class RateLimit: ) +@contextlib.contextmanager +def rate_limit_context(rate_limit: RateLimit, request_id: str | None): + request_id = rate_limit.enter(request_id) + yield + if request_id is not None: + rate_limit.exit(request_id) + + class 
RateLimitGenerator: def __init__(self, rate_limit: RateLimit, generator: Generator[str, None, None], request_id: str): self.rate_limit = rate_limit diff --git a/api/core/app/layers/pause_state_persist_layer.py b/api/core/app/layers/pause_state_persist_layer.py index bf76ae8178..1c267091a4 100644 --- a/api/core/app/layers/pause_state_persist_layer.py +++ b/api/core/app/layers/pause_state_persist_layer.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass from typing import Annotated, Literal, Self, TypeAlias from pydantic import BaseModel, Field @@ -52,6 +53,14 @@ class WorkflowResumptionContext(BaseModel): return self.generate_entity.entity +@dataclass(frozen=True) +class PauseStateLayerConfig: + """Configuration container for instantiating pause persistence layers.""" + + session_factory: Engine | sessionmaker[Session] + state_owner_user_id: str + + class PauseStatePersistenceLayer(GraphEngineLayer): def __init__( self, diff --git a/api/core/app/task_pipeline/message_cycle_manager.py b/api/core/app/task_pipeline/message_cycle_manager.py index 2d4ee08daf..d682083f34 100644 --- a/api/core/app/task_pipeline/message_cycle_manager.py +++ b/api/core/app/task_pipeline/message_cycle_manager.py @@ -82,10 +82,11 @@ class MessageCycleManager: if isinstance(self._application_generate_entity, CompletionAppGenerateEntity): return None - is_first_message = self._application_generate_entity.conversation_id is None + is_first_message = self._application_generate_entity.is_new_conversation extras = self._application_generate_entity.extras auto_generate_conversation_name = extras.get("auto_generate_conversation_name", True) + thread: Thread | None = None if auto_generate_conversation_name and is_first_message: # start generate thread # time.sleep not block other logic @@ -101,9 +102,10 @@ class MessageCycleManager: thread.daemon = True thread.start() - return thread + if is_first_message: + self._application_generate_entity.is_new_conversation = False - return None + return thread 
def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str): with flask_app.app_context(): diff --git a/api/core/entities/execution_extra_content.py b/api/core/entities/execution_extra_content.py new file mode 100644 index 0000000000..46006f4381 --- /dev/null +++ b/api/core/entities/execution_extra_content.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +from collections.abc import Mapping, Sequence +from typing import Any, TypeAlias + +from pydantic import BaseModel, ConfigDict, Field + +from core.workflow.nodes.human_input.entities import FormInput, UserAction +from models.execution_extra_content import ExecutionContentType + + +class HumanInputFormDefinition(BaseModel): + model_config = ConfigDict(frozen=True) + + form_id: str + node_id: str + node_title: str + form_content: str + inputs: Sequence[FormInput] = Field(default_factory=list) + actions: Sequence[UserAction] = Field(default_factory=list) + display_in_ui: bool = False + form_token: str | None = None + resolved_default_values: Mapping[str, Any] = Field(default_factory=dict) + expiration_time: int + + +class HumanInputFormSubmissionData(BaseModel): + model_config = ConfigDict(frozen=True) + + node_id: str + node_title: str + rendered_content: str + action_id: str + action_text: str + + +class HumanInputContent(BaseModel): + model_config = ConfigDict(frozen=True) + + workflow_run_id: str + submitted: bool + form_definition: HumanInputFormDefinition | None = None + form_submission_data: HumanInputFormSubmissionData | None = None + type: ExecutionContentType = Field(default=ExecutionContentType.HUMAN_INPUT) + + +ExecutionExtraContentDomainModel: TypeAlias = HumanInputContent + +__all__ = [ + "ExecutionExtraContentDomainModel", + "HumanInputContent", + "HumanInputFormDefinition", + "HumanInputFormSubmissionData", +] diff --git a/api/core/entities/knowledge_entities.py b/api/core/entities/knowledge_entities.py index d4093b5245..b1ba3c3e2a 100644 --- 
a/api/core/entities/knowledge_entities.py +++ b/api/core/entities/knowledge_entities.py @@ -3,6 +3,7 @@ from pydantic import BaseModel, Field, field_validator class PreviewDetail(BaseModel): content: str + summary: str | None = None child_chunks: list[str] | None = None diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index e8d41b9387..8a26b2e91b 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -28,8 +28,8 @@ from core.model_runtime.entities.provider_entities import ( ) from core.model_runtime.model_providers.__base.ai_model import AIModel from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory -from extensions.ext_database import db from libs.datetime_utils import naive_utc_now +from models.engine import db from models.provider import ( LoadBalancingModelConfig, Provider, diff --git a/api/core/file/file_manager.py b/api/core/file/file_manager.py index 120fb73cdb..c0fefef3d0 100644 --- a/api/core/file/file_manager.py +++ b/api/core/file/file_manager.py @@ -104,6 +104,8 @@ def download(f: File, /): ): return _download_file_content(f.storage_key) elif f.transfer_method == FileTransferMethod.REMOTE_URL: + if f.remote_url is None: + raise ValueError("Missing file remote_url") response = ssrf_proxy.get(f.remote_url, follow_redirects=True) response.raise_for_status() return response.content @@ -134,6 +136,8 @@ def _download_file_content(path: str, /): def _get_encoded_string(f: File, /): match f.transfer_method: case FileTransferMethod.REMOTE_URL: + if f.remote_url is None: + raise ValueError("Missing file remote_url") response = ssrf_proxy.get(f.remote_url, follow_redirects=True) response.raise_for_status() data = response.content diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index 128c64ff2c..ddccfbaf45 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -4,8 
+4,10 @@ Proxy requests to avoid SSRF import logging import time +from typing import Any, TypeAlias import httpx +from pydantic import TypeAdapter, ValidationError from configs import dify_config from core.helper.http_client_pooling import get_pooled_http_client @@ -18,6 +20,9 @@ SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES BACKOFF_FACTOR = 0.5 STATUS_FORCELIST = [429, 500, 502, 503, 504] +Headers: TypeAlias = dict[str, str] +_HEADERS_ADAPTER = TypeAdapter(Headers) + _SSL_VERIFIED_POOL_KEY = "ssrf:verified" _SSL_UNVERIFIED_POOL_KEY = "ssrf:unverified" _SSRF_CLIENT_LIMITS = httpx.Limits( @@ -76,7 +81,7 @@ def _get_ssrf_client(ssl_verify_enabled: bool) -> httpx.Client: ) -def _get_user_provided_host_header(headers: dict | None) -> str | None: +def _get_user_provided_host_header(headers: Headers | None) -> str | None: """ Extract the user-provided Host header from the headers dict. @@ -92,7 +97,7 @@ def _get_user_provided_host_header(headers: dict | None) -> str | None: return None -def _inject_trace_headers(headers: dict | None) -> dict: +def _inject_trace_headers(headers: Headers | None) -> Headers: """ Inject W3C traceparent header for distributed tracing. 
@@ -125,7 +130,7 @@ def _inject_trace_headers(headers: dict | None) -> dict: return headers -def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): +def make_request(method: str, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response: # Convert requests-style allow_redirects to httpx-style follow_redirects if "allow_redirects" in kwargs: allow_redirects = kwargs.pop("allow_redirects") @@ -142,10 +147,15 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): # prioritize per-call option, which can be switched on and off inside the HTTP node on the web UI verify_option = kwargs.pop("ssl_verify", dify_config.HTTP_REQUEST_NODE_SSL_VERIFY) + if not isinstance(verify_option, bool): + raise ValueError("ssl_verify must be a boolean") client = _get_ssrf_client(verify_option) # Inject traceparent header for distributed tracing (when OTEL is not enabled) - headers = kwargs.get("headers") or {} + try: + headers: Headers = _HEADERS_ADAPTER.validate_python(kwargs.get("headers") or {}) + except ValidationError as e: + raise ValueError("headers must be a mapping of string keys to string values") from e headers = _inject_trace_headers(headers) kwargs["headers"] = headers @@ -198,25 +208,25 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): raise MaxRetriesExceededError(f"Reached maximum retries ({max_retries}) for URL {url}") -def get(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): +def get(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response: return make_request("GET", url, max_retries=max_retries, **kwargs) -def post(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): +def post(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response: return make_request("POST", url, max_retries=max_retries, **kwargs) -def put(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): +def put(url: str, max_retries: 
int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response: return make_request("PUT", url, max_retries=max_retries, **kwargs) -def patch(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): +def patch(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response: return make_request("PATCH", url, max_retries=max_retries, **kwargs) -def delete(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): +def delete(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response: return make_request("DELETE", url, max_retries=max_retries, **kwargs) -def head(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): +def head(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response: return make_request("HEAD", url, max_retries=max_retries, **kwargs) diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index f1b50f360b..e172e88298 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -311,14 +311,18 @@ class IndexingRunner: qa_preview_texts: list[QAPreviewDetail] = [] total_segments = 0 + # doc_form represents the segmentation method (general, parent-child, QA) index_type = doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() + # one extract_setting is one source document for extract_setting in extract_settings: # extract processing_rule = DatasetProcessRule( mode=tmp_processing_rule["mode"], rules=json.dumps(tmp_processing_rule["rules"]) ) + # Extract document content text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"]) + # Cleaning and segmentation documents = index_processor.transform( text_docs, current_user=None, @@ -361,6 +365,12 @@ class IndexingRunner: if doc_form and doc_form == "qa_model": return IndexingEstimate(total_segments=total_segments * 20, qa_preview=qa_preview_texts, preview=[]) + + # Generate summary preview + summary_index_setting = 
tmp_processing_rule.get("summary_index_setting") + if summary_index_setting and summary_index_setting.get("enable") and preview_texts: + preview_texts = index_processor.generate_summary_preview(tenant_id, preview_texts, summary_index_setting) + return IndexingEstimate(total_segments=total_segments, preview=preview_texts) def _extract( diff --git a/api/core/llm_generator/entities.py b/api/core/llm_generator/entities.py new file mode 100644 index 0000000000..3bb8d2c899 --- /dev/null +++ b/api/core/llm_generator/entities.py @@ -0,0 +1,20 @@ +"""Shared payload models for LLM generator helpers and controllers.""" + +from pydantic import BaseModel, Field + +from core.app.app_config.entities import ModelConfig + + +class RuleGeneratePayload(BaseModel): + instruction: str = Field(..., description="Rule generation instruction") + model_config_data: ModelConfig = Field(..., alias="model_config", description="Model configuration") + no_variable: bool = Field(default=False, description="Whether to exclude variables") + + +class RuleCodeGeneratePayload(RuleGeneratePayload): + code_language: str = Field(default="javascript", description="Programming language for code generation") + + +class RuleStructuredOutputPayload(BaseModel): + instruction: str = Field(..., description="Structured output generation instruction") + model_config_data: ModelConfig = Field(..., alias="model_config", description="Model configuration") diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index be1e306d47..5b2c640265 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -6,6 +6,8 @@ from typing import Protocol, cast import json_repair +from core.app.app_config.entities import ModelConfig +from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser from 
core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser from core.llm_generator.prompts import ( @@ -151,19 +153,19 @@ class LLMGenerator: return questions @classmethod - def generate_rule_config(cls, tenant_id: str, instruction: str, model_config: dict, no_variable: bool): + def generate_rule_config(cls, tenant_id: str, args: RuleGeneratePayload): output_parser = RuleConfigGeneratorOutputParser() error = "" error_step = "" rule_config = {"prompt": "", "variables": [], "opening_statement": "", "error": ""} - model_parameters = model_config.get("completion_params", {}) - if no_variable: + model_parameters = args.model_config_data.completion_params + if args.no_variable: prompt_template = PromptTemplateParser(WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE) prompt_generate = prompt_template.format( inputs={ - "TASK_DESCRIPTION": instruction, + "TASK_DESCRIPTION": args.instruction, }, remove_template_variables=False, ) @@ -175,8 +177,8 @@ class LLMGenerator: model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, - provider=model_config.get("provider", ""), - model=model_config.get("name", ""), + provider=args.model_config_data.provider, + model=args.model_config_data.name, ) try: @@ -190,7 +192,7 @@ class LLMGenerator: error = str(e) error_step = "generate rule config" except Exception as e: - logger.exception("Failed to generate rule config, model: %s", model_config.get("name")) + logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name) rule_config["error"] = str(e) rule_config["error"] = f"Failed to {error_step}. 
Error: {error}" if error else "" @@ -209,7 +211,7 @@ class LLMGenerator: # format the prompt_generate_prompt prompt_generate_prompt = prompt_template.format( inputs={ - "TASK_DESCRIPTION": instruction, + "TASK_DESCRIPTION": args.instruction, }, remove_template_variables=False, ) @@ -220,8 +222,8 @@ class LLMGenerator: model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, - provider=model_config.get("provider", ""), - model=model_config.get("name", ""), + provider=args.model_config_data.provider, + model=args.model_config_data.name, ) try: @@ -250,7 +252,7 @@ class LLMGenerator: # the second step to generate the task_parameter and task_statement statement_generate_prompt = statement_template.format( inputs={ - "TASK_DESCRIPTION": instruction, + "TASK_DESCRIPTION": args.instruction, "INPUT_TEXT": prompt_content.message.get_text_content(), }, remove_template_variables=False, @@ -276,7 +278,7 @@ class LLMGenerator: error_step = "generate conversation opener" except Exception as e: - logger.exception("Failed to generate rule config, model: %s", model_config.get("name")) + logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name) rule_config["error"] = str(e) rule_config["error"] = f"Failed to {error_step}. 
Error: {error}" if error else "" @@ -284,16 +286,20 @@ class LLMGenerator: return rule_config @classmethod - def generate_code(cls, tenant_id: str, instruction: str, model_config: dict, code_language: str = "javascript"): - if code_language == "python": + def generate_code( + cls, + tenant_id: str, + args: RuleCodeGeneratePayload, + ): + if args.code_language == "python": prompt_template = PromptTemplateParser(PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE) else: prompt_template = PromptTemplateParser(JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE) prompt = prompt_template.format( inputs={ - "INSTRUCTION": instruction, - "CODE_LANGUAGE": code_language, + "INSTRUCTION": args.instruction, + "CODE_LANGUAGE": args.code_language, }, remove_template_variables=False, ) @@ -302,28 +308,28 @@ class LLMGenerator: model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, - provider=model_config.get("provider", ""), - model=model_config.get("name", ""), + provider=args.model_config_data.provider, + model=args.model_config_data.name, ) prompt_messages = [UserPromptMessage(content=prompt)] - model_parameters = model_config.get("completion_params", {}) + model_parameters = args.model_config_data.completion_params try: response: LLMResult = model_instance.invoke_llm( prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False ) generated_code = response.message.get_text_content() - return {"code": generated_code, "language": code_language, "error": ""} + return {"code": generated_code, "language": args.code_language, "error": ""} except InvokeError as e: error = str(e) - return {"code": "", "language": code_language, "error": f"Failed to generate code. Error: {error}"} + return {"code": "", "language": args.code_language, "error": f"Failed to generate code. 
Error: {error}"} except Exception as e: logger.exception( - "Failed to invoke LLM model, model: %s, language: %s", model_config.get("name"), code_language + "Failed to invoke LLM model, model: %s, language: %s", args.model_config_data.name, args.code_language ) - return {"code": "", "language": code_language, "error": f"An unexpected error occurred: {str(e)}"} + return {"code": "", "language": args.code_language, "error": f"An unexpected error occurred: {str(e)}"} @classmethod def generate_qa_document(cls, tenant_id: str, query, document_language: str): @@ -353,20 +359,20 @@ class LLMGenerator: return answer.strip() @classmethod - def generate_structured_output(cls, tenant_id: str, instruction: str, model_config: dict): + def generate_structured_output(cls, tenant_id: str, args: RuleStructuredOutputPayload): model_manager = ModelManager() model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, - provider=model_config.get("provider", ""), - model=model_config.get("name", ""), + provider=args.model_config_data.provider, + model=args.model_config_data.name, ) prompt_messages = [ SystemPromptMessage(content=SYSTEM_STRUCTURED_OUTPUT_GENERATE), - UserPromptMessage(content=instruction), + UserPromptMessage(content=args.instruction), ] - model_parameters = model_config.get("model_parameters", {}) + model_parameters = args.model_config_data.completion_params try: response: LLMResult = model_instance.invoke_llm( @@ -390,12 +396,17 @@ class LLMGenerator: error = str(e) return {"output": "", "error": f"Failed to generate JSON Schema. 
Error: {error}"} except Exception as e: - logger.exception("Failed to invoke LLM model, model: %s", model_config.get("name")) + logger.exception("Failed to invoke LLM model, model: %s", args.model_config_data.name) return {"output": "", "error": f"An unexpected error occurred: {str(e)}"} @staticmethod def instruction_modify_legacy( - tenant_id: str, flow_id: str, current: str, instruction: str, model_config: dict, ideal_output: str | None + tenant_id: str, + flow_id: str, + current: str, + instruction: str, + model_config: ModelConfig, + ideal_output: str | None, ): last_run: Message | None = ( db.session.query(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).first() @@ -434,7 +445,7 @@ class LLMGenerator: node_id: str, current: str, instruction: str, - model_config: dict, + model_config: ModelConfig, ideal_output: str | None, workflow_service: WorkflowServiceInterface, ): @@ -505,7 +516,7 @@ class LLMGenerator: @staticmethod def __instruction_modify_common( tenant_id: str, - model_config: dict, + model_config: ModelConfig, last_run: dict | None, current: str | None, error_message: str | None, @@ -526,8 +537,8 @@ class LLMGenerator: model_instance = ModelManager().get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, - provider=model_config.get("provider", ""), - model=model_config.get("name", ""), + provider=model_config.provider, + model=model_config.name, ) match node_type: case "llm" | "agent": @@ -570,7 +581,5 @@ class LLMGenerator: error = str(e) return {"error": f"Failed to generate code. 
Error: {error}"} except Exception as e: - logger.exception( - "Failed to invoke LLM model, model: %s", json.dumps(model_config.get("name")), exc_info=True - ) + logger.exception("Failed to invoke LLM model, model: %s", json.dumps(model_config.name), exc_info=True) return {"error": f"An unexpected error occurred: {str(e)}"} diff --git a/api/core/llm_generator/prompts.py b/api/core/llm_generator/prompts.py index ec2b7f2d44..d46cf049dd 100644 --- a/api/core/llm_generator/prompts.py +++ b/api/core/llm_generator/prompts.py @@ -434,3 +434,20 @@ INSTRUCTION_GENERATE_TEMPLATE_PROMPT = """The output of this prompt is not as ex You should edit the prompt according to the IDEAL OUTPUT.""" INSTRUCTION_GENERATE_TEMPLATE_CODE = """Please fix the errors in the {{#error_message#}}.""" + +DEFAULT_GENERATOR_SUMMARY_PROMPT = ( + """Summarize the following content. Extract only the key information and main points. """ + """Remove redundant details. + +Requirements: +1. Write a concise summary in plain text +2. Use the same language as the input content +3. Focus on important facts, concepts, and details +4. If images are included, describe their key information +5. Do not use words like "好的", "ok", "I understand", "This text discusses", "The content mentions" +6. Write directly without extra words + +Output only the summary text. 
Start summarizing now: + +""" +) diff --git a/api/core/mcp/session/base_session.py b/api/core/mcp/session/base_session.py index 84a6fd0d1f..e1a40593e7 100644 --- a/api/core/mcp/session/base_session.py +++ b/api/core/mcp/session/base_session.py @@ -347,7 +347,7 @@ class BaseSession( message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True) ) - responder = RequestResponder( + responder = RequestResponder[ReceiveRequestT, SendResultT]( request_id=message.message.root.id, request_meta=validated_request.root.params.meta if validated_request.root.params else None, request=validated_request, diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py b/api/core/model_runtime/model_providers/__base/ai_model.py index 45f0335c2e..c3e50eaddd 100644 --- a/api/core/model_runtime/model_providers/__base/ai_model.py +++ b/api/core/model_runtime/model_providers/__base/ai_model.py @@ -1,10 +1,11 @@ import decimal import hashlib -from threading import Lock +import logging -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, ValidationError +from redis import RedisError -import contexts +from configs import dify_config from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE from core.model_runtime.entities.model_entities import ( @@ -24,6 +25,9 @@ from core.model_runtime.errors.invoke import ( InvokeServerUnavailableError, ) from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from extensions.ext_redis import redis_client + +logger = logging.getLogger(__name__) class AIModel(BaseModel): @@ -144,34 +148,60 @@ class AIModel(BaseModel): plugin_model_manager = PluginModelClient() cache_key = f"{self.tenant_id}:{self.plugin_id}:{self.provider_name}:{self.model_type.value}:{model}" - # sort credentials sorted_credentials = sorted(credentials.items()) if credentials else [] cache_key += 
":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials]) + cached_schema_json = None try: - contexts.plugin_model_schemas.get() - except LookupError: - contexts.plugin_model_schemas.set({}) - contexts.plugin_model_schema_lock.set(Lock()) - - with contexts.plugin_model_schema_lock.get(): - if cache_key in contexts.plugin_model_schemas.get(): - return contexts.plugin_model_schemas.get()[cache_key] - - schema = plugin_model_manager.get_model_schema( - tenant_id=self.tenant_id, - user_id="unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, - model_type=self.model_type.value, - model=model, - credentials=credentials or {}, + cached_schema_json = redis_client.get(cache_key) + except (RedisError, RuntimeError) as exc: + logger.warning( + "Failed to read plugin model schema cache for model %s: %s", + model, + str(exc), + exc_info=True, ) + if cached_schema_json: + try: + return AIModelEntity.model_validate_json(cached_schema_json) + except ValidationError: + logger.warning( + "Failed to validate cached plugin model schema for model %s", + model, + exc_info=True, + ) + try: + redis_client.delete(cache_key) + except (RedisError, RuntimeError) as exc: + logger.warning( + "Failed to delete invalid plugin model schema cache for model %s: %s", + model, + str(exc), + exc_info=True, + ) - if schema: - contexts.plugin_model_schemas.get()[cache_key] = schema + schema = plugin_model_manager.get_model_schema( + tenant_id=self.tenant_id, + user_id="unknown", + plugin_id=self.plugin_id, + provider=self.provider_name, + model_type=self.model_type.value, + model=model, + credentials=credentials or {}, + ) - return schema + if schema: + try: + redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json()) + except (RedisError, RuntimeError) as exc: + logger.warning( + "Failed to write plugin model schema cache for model %s: %s", + model, + str(exc), + exc_info=True, + ) + + return schema def 
get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> AIModelEntity | None: """ diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py index 7a0757f219..bbbdec61d1 100644 --- a/api/core/model_runtime/model_providers/__base/large_language_model.py +++ b/api/core/model_runtime/model_providers/__base/large_language_model.py @@ -92,6 +92,10 @@ def _build_llm_result_from_first_chunk( Build a single `LLMResult` from the first returned chunk. This is used for `stream=False` because the plugin side may still implement the response via a chunked stream. + + Note: + This function always drains the `chunks` iterator after reading the first chunk to ensure any underlying + streaming resources are released (e.g., HTTP connections owned by the plugin runtime). """ content = "" content_list: list[PromptMessageContentUnionTypes] = [] @@ -99,18 +103,25 @@ def _build_llm_result_from_first_chunk( system_fingerprint: str | None = None tools_calls: list[AssistantPromptMessage.ToolCall] = [] - first_chunk = next(chunks, None) - if first_chunk is not None: - if isinstance(first_chunk.delta.message.content, str): - content += first_chunk.delta.message.content - elif isinstance(first_chunk.delta.message.content, list): - content_list.extend(first_chunk.delta.message.content) + try: + first_chunk = next(chunks, None) + if first_chunk is not None: + if isinstance(first_chunk.delta.message.content, str): + content += first_chunk.delta.message.content + elif isinstance(first_chunk.delta.message.content, list): + content_list.extend(first_chunk.delta.message.content) - if first_chunk.delta.message.tool_calls: - _increase_tool_call(first_chunk.delta.message.tool_calls, tools_calls) + if first_chunk.delta.message.tool_calls: + _increase_tool_call(first_chunk.delta.message.tool_calls, tools_calls) - usage = first_chunk.delta.usage or LLMUsage.empty_usage() - 
system_fingerprint = first_chunk.system_fingerprint + usage = first_chunk.delta.usage or LLMUsage.empty_usage() + system_fingerprint = first_chunk.system_fingerprint + finally: + try: + for _ in chunks: + pass + except Exception: + logger.debug("Failed to drain non-stream plugin chunk iterator.", exc_info=True) return LLMResult( model=model, @@ -283,7 +294,7 @@ class LargeLanguageModel(AIModel): # TODO raise self._transform_invoke_error(e) - if stream and isinstance(result, Generator): + if stream and not isinstance(result, LLMResult): return self._invoke_result_generator( model=model, result=result, diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py index 28f162a928..9cfc6889ac 100644 --- a/api/core/model_runtime/model_providers/model_provider_factory.py +++ b/api/core/model_runtime/model_providers/model_provider_factory.py @@ -5,7 +5,11 @@ import logging from collections.abc import Sequence from threading import Lock +from pydantic import ValidationError +from redis import RedisError + import contexts +from configs import dify_config from core.model_runtime.entities.model_entities import AIModelEntity, ModelType from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity from core.model_runtime.model_providers.__base.ai_model import AIModel @@ -18,6 +22,7 @@ from core.model_runtime.model_providers.__base.tts_model import TTSModel from core.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator from core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from extensions.ext_redis import redis_client from models.provider_ids import ModelProviderID logger = logging.getLogger(__name__) @@ -175,34 +180,60 @@ class ModelProviderFactory: """ 
plugin_id, provider_name = self.get_plugin_id_and_provider_name_from_provider(provider) cache_key = f"{self.tenant_id}:{plugin_id}:{provider_name}:{model_type.value}:{model}" - # sort credentials sorted_credentials = sorted(credentials.items()) if credentials else [] cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials]) + cached_schema_json = None try: - contexts.plugin_model_schemas.get() - except LookupError: - contexts.plugin_model_schemas.set({}) - contexts.plugin_model_schema_lock.set(Lock()) - - with contexts.plugin_model_schema_lock.get(): - if cache_key in contexts.plugin_model_schemas.get(): - return contexts.plugin_model_schemas.get()[cache_key] - - schema = self.plugin_model_manager.get_model_schema( - tenant_id=self.tenant_id, - user_id="unknown", - plugin_id=plugin_id, - provider=provider_name, - model_type=model_type.value, - model=model, - credentials=credentials or {}, + cached_schema_json = redis_client.get(cache_key) + except (RedisError, RuntimeError) as exc: + logger.warning( + "Failed to read plugin model schema cache for model %s: %s", + model, + str(exc), + exc_info=True, ) + if cached_schema_json: + try: + return AIModelEntity.model_validate_json(cached_schema_json) + except ValidationError: + logger.warning( + "Failed to validate cached plugin model schema for model %s", + model, + exc_info=True, + ) + try: + redis_client.delete(cache_key) + except (RedisError, RuntimeError) as exc: + logger.warning( + "Failed to delete invalid plugin model schema cache for model %s: %s", + model, + str(exc), + exc_info=True, + ) - if schema: - contexts.plugin_model_schemas.get()[cache_key] = schema + schema = self.plugin_model_manager.get_model_schema( + tenant_id=self.tenant_id, + user_id="unknown", + plugin_id=plugin_id, + provider=provider_name, + model_type=model_type.value, + model=model, + credentials=credentials or {}, + ) - return schema + if schema: + try: + redis_client.setex(cache_key, 
dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json()) + except (RedisError, RuntimeError) as exc: + logger.warning( + "Failed to write plugin model schema cache for model %s: %s", + model, + str(exc), + exc_info=True, + ) + + return schema def get_models( self, @@ -283,6 +314,8 @@ class ModelProviderFactory: elif model_type == ModelType.TTS: return TTSModel.model_validate(init_params) + raise ValueError(f"Unsupported model type: {model_type}") + def get_provider_icon(self, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]: """ Get provider icon diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index 84f5bf5512..549e428f88 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -15,10 +15,7 @@ from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker from core.helper.encrypter import batch_decrypt_token, encrypt_token, obfuscated_token -from core.ops.entities.config_entity import ( - OPS_FILE_PATH, - TracingProviderEnum, -) +from core.ops.entities.config_entity import OPS_FILE_PATH, TracingProviderEnum from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, GenerateNameTraceInfo, @@ -31,8 +28,8 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.ops.utils import get_message_data -from extensions.ext_database import db from extensions.ext_storage import storage +from models.engine import db from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig from models.workflow import WorkflowAppLog from tasks.ops_trace_task import process_trace_tasks @@ -469,6 +466,8 @@ class TraceTask: @classmethod def _get_workflow_run_repo(cls): + from repositories.factory import DifyAPIRepositoryFactory + if cls._workflow_run_repo is None: with cls._repo_lock: if cls._workflow_run_repo is None: diff --git a/api/core/ops/utils.py b/api/core/ops/utils.py index 631e3b77b2..a5196d66c0 100644 --- 
a/api/core/ops/utils.py +++ b/api/core/ops/utils.py @@ -5,7 +5,7 @@ from urllib.parse import urlparse from sqlalchemy import select -from extensions.ext_database import db +from models.engine import db from models.model import Message diff --git a/api/core/plugin/backwards_invocation/app.py b/api/core/plugin/backwards_invocation/app.py index 32e8ef385c..3c5df2b905 100644 --- a/api/core/plugin/backwards_invocation/app.py +++ b/api/core/plugin/backwards_invocation/app.py @@ -1,3 +1,4 @@ +import uuid from collections.abc import Generator, Mapping from typing import Union @@ -11,6 +12,7 @@ from core.app.apps.chat.app_generator import ChatAppGenerator from core.app.apps.completion.app_generator import CompletionAppGenerator from core.app.apps.workflow.app_generator import WorkflowAppGenerator from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig from core.plugin.backwards_invocation.base import BaseBackwardsInvocation from extensions.ext_database import db from models import Account @@ -101,6 +103,11 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): if not workflow: raise ValueError("unexpected app type") + pause_config = PauseStateLayerConfig( + session_factory=db.engine, + state_owner_user_id=workflow.created_by, + ) + return AdvancedChatAppGenerator().generate( app_model=app, workflow=workflow, @@ -112,7 +119,9 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): "conversation_id": conversation_id, }, invoke_from=InvokeFrom.SERVICE_API, + workflow_run_id=str(uuid.uuid4()), streaming=stream, + pause_state_config=pause_config, ) elif app.mode == AppMode.AGENT_CHAT: return AgentChatAppGenerator().generate( @@ -159,6 +168,11 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): if not workflow: raise ValueError("unexpected app type") + pause_config = PauseStateLayerConfig( + session_factory=db.engine, + state_owner_user_id=workflow.created_by, + ) + 
return WorkflowAppGenerator().generate( app_model=app, workflow=workflow, @@ -167,6 +181,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): invoke_from=InvokeFrom.SERVICE_API, streaming=stream, call_depth=1, + pause_state_config=pause_config, ) @classmethod diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 8ec1ce6242..91c16ce079 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -24,7 +24,13 @@ from core.rag.rerank.rerank_type import RerankMode from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.signature import sign_upload_file from extensions.ext_database import db -from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding +from models.dataset import ( + ChildChunk, + Dataset, + DocumentSegment, + DocumentSegmentSummary, + SegmentAttachmentBinding, +) from models.dataset import Document as DatasetDocument from models.model import UploadFile from services.external_knowledge_service import ExternalDatasetService @@ -389,15 +395,15 @@ class RetrievalService: .all() } - records = [] - include_segment_ids = set() - segment_child_map = {} - valid_dataset_documents = {} image_doc_ids: list[Any] = [] child_index_node_ids = [] index_node_ids = [] doc_to_document_map = {} + summary_segment_ids = set() # Track segments retrieved via summary + summary_score_map: dict[str, float] = {} # Map original_chunk_id to summary score + + # First pass: collect all document IDs and identify summary documents for document in documents: document_id = document.metadata.get("document_id") if document_id not in dataset_documents: @@ -408,16 +414,39 @@ class RetrievalService: continue valid_dataset_documents[document_id] = dataset_document + doc_id = document.metadata.get("doc_id") or "" + doc_to_document_map[doc_id] = document + + # Check if this is a summary document + is_summary = 
document.metadata.get("is_summary", False) + if is_summary: + # For summary documents, find the original chunk via original_chunk_id + original_chunk_id = document.metadata.get("original_chunk_id") + if original_chunk_id: + summary_segment_ids.add(original_chunk_id) + # Save summary's score for later use + summary_score = document.metadata.get("score") + if summary_score is not None: + try: + summary_score_float = float(summary_score) + # If the same segment has multiple summary hits, take the highest score + if original_chunk_id not in summary_score_map: + summary_score_map[original_chunk_id] = summary_score_float + else: + summary_score_map[original_chunk_id] = max( + summary_score_map[original_chunk_id], summary_score_float + ) + except (ValueError, TypeError): + # Skip invalid score values + pass + continue # Skip adding to other lists for summary documents + if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: - doc_id = document.metadata.get("doc_id") or "" - doc_to_document_map[doc_id] = document if document.metadata.get("doc_type") == DocType.IMAGE: image_doc_ids.append(doc_id) else: child_index_node_ids.append(doc_id) else: - doc_id = document.metadata.get("doc_id") or "" - doc_to_document_map[doc_id] = document if document.metadata.get("doc_type") == DocType.IMAGE: image_doc_ids.append(doc_id) else: @@ -433,6 +462,7 @@ class RetrievalService: attachment_map: dict[str, list[dict[str, Any]]] = {} child_chunk_map: dict[str, list[ChildChunk]] = {} doc_segment_map: dict[str, list[str]] = {} + segment_summary_map: dict[str, str] = {} # Map segment_id to summary content with session_factory.create_session() as session: attachments = cls.get_segment_attachment_infos(image_doc_ids, session) @@ -447,6 +477,7 @@ class RetrievalService: doc_segment_map[attachment["segment_id"]].append(attachment["attachment_id"]) else: doc_segment_map[attachment["segment_id"]] = [attachment["attachment_id"]] + child_chunk_stmt = 
select(ChildChunk).where(ChildChunk.index_node_id.in_(child_index_node_ids)) child_index_nodes = session.execute(child_chunk_stmt).scalars().all() @@ -470,6 +501,7 @@ class RetrievalService: index_node_segments = session.execute(document_segment_stmt).scalars().all() # type: ignore for index_node_segment in index_node_segments: doc_segment_map[index_node_segment.id] = [index_node_segment.index_node_id] + if segment_ids: document_segment_stmt = select(DocumentSegment).where( DocumentSegment.enabled == True, @@ -481,6 +513,40 @@ class RetrievalService: if index_node_segments: segments.extend(index_node_segments) + # Handle summary documents: query segments by original_chunk_id + if summary_segment_ids: + summary_segment_ids_list = list(summary_segment_ids) + summary_segment_stmt = select(DocumentSegment).where( + DocumentSegment.enabled == True, + DocumentSegment.status == "completed", + DocumentSegment.id.in_(summary_segment_ids_list), + ) + summary_segments = session.execute(summary_segment_stmt).scalars().all() # type: ignore + segments.extend(summary_segments) + # Add summary segment IDs to segment_ids for summary query + for seg in summary_segments: + if seg.id not in segment_ids: + segment_ids.append(seg.id) + + # Batch query summaries for segments retrieved via summary (only enabled summaries) + if summary_segment_ids: + summaries = ( + session.query(DocumentSegmentSummary) + .filter( + DocumentSegmentSummary.chunk_id.in_(list(summary_segment_ids)), + DocumentSegmentSummary.status == "completed", + DocumentSegmentSummary.enabled == True, # Only retrieve enabled summaries + ) + .all() + ) + for summary in summaries: + if summary.summary_content: + segment_summary_map[summary.chunk_id] = summary.summary_content + + include_segment_ids = set() + segment_child_map: dict[str, dict[str, Any]] = {} + records: list[dict[str, Any]] = [] + for segment in segments: child_chunks: list[ChildChunk] = child_chunk_map.get(segment.id, []) attachment_infos: list[dict[str, Any]] 
= attachment_map.get(segment.id, []) @@ -489,30 +555,44 @@ class RetrievalService: if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: if segment.id not in include_segment_ids: include_segment_ids.add(segment.id) + # Check if this segment was retrieved via summary + # Use summary score as base score if available, otherwise 0.0 + max_score = summary_score_map.get(segment.id, 0.0) + if child_chunks or attachment_infos: child_chunk_details = [] - max_score = 0.0 for child_chunk in child_chunks: - document = doc_to_document_map[child_chunk.index_node_id] + child_document: Document | None = doc_to_document_map.get(child_chunk.index_node_id) + if child_document: + child_score = child_document.metadata.get("score", 0.0) + else: + child_score = 0.0 child_chunk_detail = { "id": child_chunk.id, "content": child_chunk.content, "position": child_chunk.position, - "score": document.metadata.get("score", 0.0) if document else 0.0, + "score": child_score, } child_chunk_details.append(child_chunk_detail) - max_score = max(max_score, document.metadata.get("score", 0.0) if document else 0.0) + max_score = max(max_score, child_score) for attachment_info in attachment_infos: - file_document = doc_to_document_map[attachment_info["id"]] - max_score = max( - max_score, file_document.metadata.get("score", 0.0) if file_document else 0.0 - ) + file_document = doc_to_document_map.get(attachment_info["id"]) + if file_document: + max_score = max(max_score, file_document.metadata.get("score", 0.0)) map_detail = { "max_score": max_score, "child_chunks": child_chunk_details, } segment_child_map[segment.id] = map_detail + else: + # No child chunks or attachments, use summary score if available + summary_score = summary_score_map.get(segment.id) + if summary_score is not None: + segment_child_map[segment.id] = { + "max_score": summary_score, + "child_chunks": [], + } record: dict[str, Any] = { "segment": segment, } @@ -520,14 +600,23 @@ class 
RetrievalService: else: if segment.id not in include_segment_ids: include_segment_ids.add(segment.id) - max_score = 0.0 - segment_document = doc_to_document_map.get(segment.index_node_id) - if segment_document: - max_score = max(max_score, segment_document.metadata.get("score", 0.0)) + + # Check if this segment was retrieved via summary + # Use summary score if available (summary retrieval takes priority) + max_score = summary_score_map.get(segment.id, 0.0) + + # If not retrieved via summary, use original segment's score + if segment.id not in summary_score_map: + segment_document = doc_to_document_map.get(segment.index_node_id) + if segment_document: + max_score = max(max_score, segment_document.metadata.get("score", 0.0)) + + # Also consider attachment scores for attachment_info in attachment_infos: file_doc = doc_to_document_map.get(attachment_info["id"]) if file_doc: max_score = max(max_score, file_doc.metadata.get("score", 0.0)) + record = { "segment": segment, "score": max_score, @@ -576,9 +665,16 @@ class RetrievalService: else None ) + # Extract summary if this segment was retrieved via summary + summary_content = segment_summary_map.get(segment.id) + # Create RetrievalSegments object retrieval_segment = RetrievalSegments( - segment=segment, child_chunks=child_chunks_list, score=score, files=files + segment=segment, + child_chunks=child_chunks_list, + score=score, + files=files, + summary=summary_content, ) result.append(retrieval_segment) diff --git a/api/core/rag/embedding/retrieval.py b/api/core/rag/embedding/retrieval.py index b54a37b49e..f6834ab87b 100644 --- a/api/core/rag/embedding/retrieval.py +++ b/api/core/rag/embedding/retrieval.py @@ -20,3 +20,4 @@ class RetrievalSegments(BaseModel): child_chunks: list[RetrievalChildChunk] | None = None score: float | None = None files: list[dict[str, str | int]] | None = None + summary: str | None = None # Summary content if retrieved via summary index diff --git a/api/core/rag/entities/citation_metadata.py 
b/api/core/rag/entities/citation_metadata.py index 9f66cd9a03..aec5c353f8 100644 --- a/api/core/rag/entities/citation_metadata.py +++ b/api/core/rag/entities/citation_metadata.py @@ -22,3 +22,4 @@ class RetrievalSourceMetadata(BaseModel): doc_metadata: dict[str, Any] | None = None title: str | None = None files: list[dict[str, Any]] | None = None + summary: str | None = None diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index 511f5a698d..1ddbfc5864 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -1,4 +1,7 @@ -"""Abstract interface for document loader implementations.""" +"""Word (.docx) document extractor used for RAG ingestion. + +Supports local file paths and remote URLs (downloaded via `core.helper.ssrf_proxy`). +""" import logging import mimetypes @@ -8,7 +11,6 @@ import tempfile import uuid from urllib.parse import urlparse -import httpx from docx import Document as DocxDocument from docx.oxml.ns import qn from docx.text.run import Run @@ -44,7 +46,7 @@ class WordExtractor(BaseExtractor): # If the file is a web path, download it to a temporary file, and use that if not os.path.isfile(self.file_path) and self._is_valid_url(self.file_path): - response = httpx.get(self.file_path, timeout=None) + response = ssrf_proxy.get(self.file_path) if response.status_code != 200: response.close() @@ -55,6 +57,7 @@ class WordExtractor(BaseExtractor): self.temp_file = tempfile.NamedTemporaryFile() # noqa SIM115 try: self.temp_file.write(response.content) + self.temp_file.flush() finally: response.close() self.file_path = self.temp_file.name diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py index e36b54eedd..151a3de7d9 100644 --- a/api/core/rag/index_processor/index_processor_base.py +++ b/api/core/rag/index_processor/index_processor_base.py @@ -13,6 +13,7 @@ from urllib.parse import unquote, urlparse 
import httpx from configs import dify_config +from core.entities.knowledge_entities import PreviewDetail from core.helper import ssrf_proxy from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.index_processor.constant.doc_type import DocType @@ -45,6 +46,17 @@ class BaseIndexProcessor(ABC): def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]: raise NotImplementedError + @abstractmethod + def generate_summary_preview( + self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict + ) -> list[PreviewDetail]: + """ + For each segment in preview_texts, generate a summary using LLM and attach it to the segment. + The summary can be stored in a new attribute, e.g., summary. + This method should be implemented by subclasses. + """ + raise NotImplementedError + @abstractmethod def load( self, diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index cf68cff7dc..ab91e29145 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -1,9 +1,27 @@ """Paragraph index processor.""" +import logging +import re import uuid from collections.abc import Mapping -from typing import Any +from typing import Any, cast +logger = logging.getLogger(__name__) + +from core.entities.knowledge_entities import PreviewDetail +from core.file import File, FileTransferMethod, FileType, file_manager +from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT +from core.model_manager import ModelInstance +from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from core.model_runtime.entities.message_entities import ( + ImagePromptMessageContent, + PromptMessage, + PromptMessageContentUnionTypes, + TextPromptMessageContent, + UserPromptMessage, +) +from 
core.model_runtime.entities.model_entities import ModelFeature, ModelType +from core.provider_manager import ProviderManager from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.retrieval_service import RetrievalService @@ -17,12 +35,17 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.utils.text_processing_utils import remove_leading_symbols +from core.workflow.nodes.llm import llm_utils +from extensions.ext_database import db +from factories.file_factory import build_from_mapping from libs import helper +from models import UploadFile from models.account import Account -from models.dataset import Dataset, DatasetProcessRule +from models.dataset import Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding from models.dataset import Document as DatasetDocument from services.account_service import AccountService from services.entities.knowledge_entities.knowledge_entities import Rule +from services.summary_index_service import SummaryIndexService class ParagraphIndexProcessor(BaseIndexProcessor): @@ -108,6 +131,29 @@ class ParagraphIndexProcessor(BaseIndexProcessor): keyword.add_texts(documents) def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs): + # Note: Summary indexes are now disabled (not deleted) when segments are disabled. + # This method is called for actual deletion scenarios (e.g., when segment is deleted). + # For disable operations, disable_summaries_for_segments is called directly in the task. 
+ # Only delete summaries if explicitly requested (e.g., when segment is actually deleted) + delete_summaries = kwargs.get("delete_summaries", False) + if delete_summaries: + if node_ids: + # Find segments by index_node_id + segments = ( + db.session.query(DocumentSegment) + .filter( + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.index_node_id.in_(node_ids), + ) + .all() + ) + segment_ids = [segment.id for segment in segments] + if segment_ids: + SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids) + else: + # Delete all summaries for the dataset + SummaryIndexService.delete_summaries_for_segments(dataset, None) + if dataset.indexing_technique == "high_quality": vector = Vector(dataset) if node_ids: @@ -227,3 +273,322 @@ class ParagraphIndexProcessor(BaseIndexProcessor): } else: raise ValueError("Chunks is not a list") + + def generate_summary_preview( + self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict + ) -> list[PreviewDetail]: + """ + For each segment, concurrently call generate_summary to generate a summary + and write it to the summary attribute of PreviewDetail. + In preview mode (indexing-estimate), if any summary generation fails, the method will raise an exception. 
+ """ + import concurrent.futures + + from flask import current_app + + # Capture Flask app context for worker threads + flask_app = None + try: + flask_app = current_app._get_current_object() # type: ignore + except RuntimeError: + logger.warning("No Flask application context available, summary generation may fail") + + def process(preview: PreviewDetail) -> None: + """Generate summary for a single preview item.""" + if flask_app: + # Ensure Flask app context in worker thread + with flask_app.app_context(): + summary, _ = self.generate_summary(tenant_id, preview.content, summary_index_setting) + preview.summary = summary + else: + # Fallback: try without app context (may fail) + summary, _ = self.generate_summary(tenant_id, preview.content, summary_index_setting) + preview.summary = summary + + # Generate summaries concurrently using ThreadPoolExecutor + # Set a reasonable timeout to prevent hanging (60 seconds per chunk, max 5 minutes total) + timeout_seconds = min(300, 60 * len(preview_texts)) + errors: list[Exception] = [] + + with concurrent.futures.ThreadPoolExecutor(max_workers=min(10, len(preview_texts))) as executor: + futures = [executor.submit(process, preview) for preview in preview_texts] + # Wait for all tasks to complete with timeout + done, not_done = concurrent.futures.wait(futures, timeout=timeout_seconds) + + # Cancel tasks that didn't complete in time + if not_done: + timeout_error_msg = ( + f"Summary generation timeout: {len(not_done)} chunks did not complete within {timeout_seconds}s" + ) + logger.warning("%s. 
Cancelling remaining tasks...", timeout_error_msg) + # In preview mode, timeout is also an error + errors.append(TimeoutError(timeout_error_msg)) + for future in not_done: + future.cancel() + # Wait a bit for cancellation to take effect + concurrent.futures.wait(not_done, timeout=5) + + # Collect exceptions from completed futures + for future in done: + try: + future.result() # This will raise any exception that occurred + except Exception as e: + logger.exception("Error in summary generation future") + errors.append(e) + + # In preview mode (indexing-estimate), if there are any errors, fail the request + if errors: + error_messages = [str(e) for e in errors] + error_summary = ( + f"Failed to generate summaries for {len(errors)} chunk(s). " + f"Errors: {'; '.join(error_messages[:3])}" # Show first 3 errors + ) + if len(errors) > 3: + error_summary += f" (and {len(errors) - 3} more)" + logger.error("Summary generation failed in preview mode: %s", error_summary) + raise ValueError(error_summary) + + return preview_texts + + @staticmethod + def generate_summary( + tenant_id: str, + text: str, + summary_index_setting: dict | None = None, + segment_id: str | None = None, + ) -> tuple[str, LLMUsage]: + """ + Generate summary for the given text using ModelInstance.invoke_llm and the default or custom summary prompt, + and supports vision models by including images from the segment attachments or text content. 
+ + Args: + tenant_id: Tenant ID + text: Text content to summarize + summary_index_setting: Summary index configuration + segment_id: Optional segment ID to fetch attachments from SegmentAttachmentBinding table + + Returns: + Tuple of (summary_content, llm_usage) where llm_usage is LLMUsage object + """ + if not summary_index_setting or not summary_index_setting.get("enable"): + raise ValueError("summary_index_setting is required and must be enabled to generate summary.") + + model_name = summary_index_setting.get("model_name") + model_provider_name = summary_index_setting.get("model_provider_name") + summary_prompt = summary_index_setting.get("summary_prompt") + + if not model_name or not model_provider_name: + raise ValueError("model_name and model_provider_name are required in summary_index_setting") + + # Import default summary prompt + if not summary_prompt: + summary_prompt = DEFAULT_GENERATOR_SUMMARY_PROMPT + + provider_manager = ProviderManager() + provider_model_bundle = provider_manager.get_provider_model_bundle( + tenant_id, model_provider_name, ModelType.LLM + ) + model_instance = ModelInstance(provider_model_bundle, model_name) + + # Get model schema to check if vision is supported + model_schema = model_instance.model_type_instance.get_model_schema(model_name, model_instance.credentials) + supports_vision = model_schema and model_schema.features and ModelFeature.VISION in model_schema.features + + # Extract images if model supports vision + image_files = [] + if supports_vision: + # First, try to get images from SegmentAttachmentBinding (preferred method) + if segment_id: + image_files = ParagraphIndexProcessor._extract_images_from_segment_attachments(tenant_id, segment_id) + + # If no images from attachments, fall back to extracting from text + if not image_files: + image_files = ParagraphIndexProcessor._extract_images_from_text(tenant_id, text) + + # Build prompt messages + prompt_messages = [] + + if image_files: + # If we have images, create a 
UserPromptMessage with both text and images + prompt_message_contents: list[PromptMessageContentUnionTypes] = [] + + # Add images first + for file in image_files: + try: + file_content = file_manager.to_prompt_message_content( + file, image_detail_config=ImagePromptMessageContent.DETAIL.LOW + ) + prompt_message_contents.append(file_content) + except Exception as e: + logger.warning("Failed to convert image file to prompt message content: %s", str(e)) + continue + + # Add text content + if prompt_message_contents: # Only add text if we successfully added images + prompt_message_contents.append(TextPromptMessageContent(data=f"{summary_prompt}\n{text}")) + prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) + else: + # If image conversion failed, fall back to text-only + prompt = f"{summary_prompt}\n{text}" + prompt_messages.append(UserPromptMessage(content=prompt)) + else: + # No images, use simple text prompt + prompt = f"{summary_prompt}\n{text}" + prompt_messages.append(UserPromptMessage(content=prompt)) + + result = model_instance.invoke_llm( + prompt_messages=cast(list[PromptMessage], prompt_messages), model_parameters={}, stream=False + ) + + # Type assertion: when stream=False, invoke_llm returns LLMResult, not Generator + if not isinstance(result, LLMResult): + raise ValueError("Expected LLMResult when stream=False") + + summary_content = getattr(result.message, "content", "") + usage = result.usage + + # Deduct quota for summary generation (same as workflow nodes) + try: + llm_utils.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage) + except Exception as e: + # Log but don't fail summary generation if quota deduction fails + logger.warning("Failed to deduct quota for summary generation: %s", str(e)) + + return summary_content, usage + + @staticmethod + def _extract_images_from_text(tenant_id: str, text: str) -> list[File]: + """ + Extract images from markdown text and convert them to File objects. 
+ + Args: + tenant_id: Tenant ID + text: Text content that may contain markdown image links + + Returns: + List of File objects representing images found in the text + """ + # Extract markdown images using regex pattern + pattern = r"!\[.*?\]\((.*?)\)" + images = re.findall(pattern, text) + + if not images: + return [] + + upload_file_id_list = [] + + for image in images: + # For data before v0.10.0 + pattern = r"/files/([a-f0-9\-]+)/image-preview(?:\?.*?)?" + match = re.search(pattern, image) + if match: + upload_file_id = match.group(1) + upload_file_id_list.append(upload_file_id) + continue + + # For data after v0.10.0 + pattern = r"/files/([a-f0-9\-]+)/file-preview(?:\?.*?)?" + match = re.search(pattern, image) + if match: + upload_file_id = match.group(1) + upload_file_id_list.append(upload_file_id) + continue + + # For tools directory - direct file formats (e.g., .png, .jpg, etc.) + pattern = r"/files/tools/([a-f0-9\-]+)\.([a-zA-Z0-9]+)(?:\?[^\s\)\"\']*)?" + match = re.search(pattern, image) + if match: + # Tool files are handled differently, skip for now + continue + + if not upload_file_id_list: + return [] + + # Get unique IDs for database query + unique_upload_file_ids = list(set(upload_file_id_list)) + upload_files = ( + db.session.query(UploadFile) + .where(UploadFile.id.in_(unique_upload_file_ids), UploadFile.tenant_id == tenant_id) + .all() + ) + + # Create File objects from UploadFile records + file_objects = [] + for upload_file in upload_files: + # Only process image files + if not upload_file.mime_type or "image" not in upload_file.mime_type: + continue + + mapping = { + "upload_file_id": upload_file.id, + "transfer_method": FileTransferMethod.LOCAL_FILE.value, + "type": FileType.IMAGE.value, + } + + try: + file_obj = build_from_mapping( + mapping=mapping, + tenant_id=tenant_id, + ) + file_objects.append(file_obj) + except Exception as e: + logger.warning("Failed to create File object from UploadFile %s: %s", upload_file.id, str(e)) + continue + + 
return file_objects + + @staticmethod + def _extract_images_from_segment_attachments(tenant_id: str, segment_id: str) -> list[File]: + """ + Extract images from SegmentAttachmentBinding table (preferred method). + This matches how DatasetRetrieval gets segment attachments. + + Args: + tenant_id: Tenant ID + segment_id: Segment ID to fetch attachments for + + Returns: + List of File objects representing images found in segment attachments + """ + from sqlalchemy import select + + # Query attachments from SegmentAttachmentBinding table + attachments_with_bindings = db.session.execute( + select(SegmentAttachmentBinding, UploadFile) + .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id) + .where( + SegmentAttachmentBinding.segment_id == segment_id, + SegmentAttachmentBinding.tenant_id == tenant_id, + ) + ).all() + + if not attachments_with_bindings: + return [] + + file_objects = [] + for _, upload_file in attachments_with_bindings: + # Only process image files + if not upload_file.mime_type or "image" not in upload_file.mime_type: + continue + + try: + # Create File object directly (similar to DatasetRetrieval) + file_obj = File( + id=upload_file.id, + filename=upload_file.name, + extension="." 
+ upload_file.extension, + mime_type=upload_file.mime_type, + tenant_id=tenant_id, + type=FileType.IMAGE, + transfer_method=FileTransferMethod.LOCAL_FILE, + remote_url=upload_file.source_url, + related_id=upload_file.id, + size=upload_file.size, + storage_key=upload_file.key, + ) + file_objects.append(file_obj) + except Exception as e: + logger.warning("Failed to create File object from UploadFile %s: %s", upload_file.id, str(e)) + continue + + return file_objects diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index 0366f3259f..961df2e50c 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -1,11 +1,14 @@ """Paragraph index processor.""" import json +import logging import uuid from collections.abc import Mapping from typing import Any from configs import dify_config +from core.db.session_factory import session_factory +from core.entities.knowledge_entities import PreviewDetail from core.model_manager import ModelInstance from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.datasource.retrieval_service import RetrievalService @@ -25,6 +28,9 @@ from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegm from models.dataset import Document as DatasetDocument from services.account_service import AccountService from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule +from services.summary_index_service import SummaryIndexService + +logger = logging.getLogger(__name__) class ParentChildIndexProcessor(BaseIndexProcessor): @@ -135,6 +141,30 @@ class ParentChildIndexProcessor(BaseIndexProcessor): def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs): # node_ids is segment's node_ids + # Note: Summary indexes are now disabled (not deleted) 
when segments are disabled. + # This method is called for actual deletion scenarios (e.g., when segment is deleted). + # For disable operations, disable_summaries_for_segments is called directly in the task. + # Only delete summaries if explicitly requested (e.g., when segment is actually deleted) + delete_summaries = kwargs.get("delete_summaries", False) + if delete_summaries: + if node_ids: + # Find segments by index_node_id + with session_factory.create_session() as session: + segments = ( + session.query(DocumentSegment) + .filter( + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.index_node_id.in_(node_ids), + ) + .all() + ) + segment_ids = [segment.id for segment in segments] + if segment_ids: + SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids) + else: + # Delete all summaries for the dataset + SummaryIndexService.delete_summaries_for_segments(dataset, None) + if dataset.indexing_technique == "high_quality": delete_child_chunks = kwargs.get("delete_child_chunks") or False precomputed_child_node_ids = kwargs.get("precomputed_child_node_ids") @@ -326,3 +356,91 @@ class ParentChildIndexProcessor(BaseIndexProcessor): "preview": preview, "total_segments": len(parent_childs.parent_child_chunks), } + + def generate_summary_preview( + self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict + ) -> list[PreviewDetail]: + """ + For each parent chunk in preview_texts, concurrently call generate_summary to generate a summary + and write it to the summary attribute of PreviewDetail. + In preview mode (indexing-estimate), if any summary generation fails, the method will raise an exception. + + Note: For parent-child structure, we only generate summaries for parent chunks. 
+ """ + import concurrent.futures + + from flask import current_app + + # Capture Flask app context for worker threads + flask_app = None + try: + flask_app = current_app._get_current_object() # type: ignore + except RuntimeError: + logger.warning("No Flask application context available, summary generation may fail") + + def process(preview: PreviewDetail) -> None: + """Generate summary for a single preview item (parent chunk).""" + from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor + + if flask_app: + # Ensure Flask app context in worker thread + with flask_app.app_context(): + summary, _ = ParagraphIndexProcessor.generate_summary( + tenant_id=tenant_id, + text=preview.content, + summary_index_setting=summary_index_setting, + ) + preview.summary = summary + else: + # Fallback: try without app context (may fail) + summary, _ = ParagraphIndexProcessor.generate_summary( + tenant_id=tenant_id, + text=preview.content, + summary_index_setting=summary_index_setting, + ) + preview.summary = summary + + # Generate summaries concurrently using ThreadPoolExecutor + # Set a reasonable timeout to prevent hanging (60 seconds per chunk, max 5 minutes total) + timeout_seconds = min(300, 60 * len(preview_texts)) + errors: list[Exception] = [] + + with concurrent.futures.ThreadPoolExecutor(max_workers=min(10, len(preview_texts))) as executor: + futures = [executor.submit(process, preview) for preview in preview_texts] + # Wait for all tasks to complete with timeout + done, not_done = concurrent.futures.wait(futures, timeout=timeout_seconds) + + # Cancel tasks that didn't complete in time + if not_done: + timeout_error_msg = ( + f"Summary generation timeout: {len(not_done)} chunks did not complete within {timeout_seconds}s" + ) + logger.warning("%s. 
Cancelling remaining tasks...", timeout_error_msg) + # In preview mode, timeout is also an error + errors.append(TimeoutError(timeout_error_msg)) + for future in not_done: + future.cancel() + # Wait a bit for cancellation to take effect + concurrent.futures.wait(not_done, timeout=5) + + # Collect exceptions from completed futures + for future in done: + try: + future.result() # This will raise any exception that occurred + except Exception as e: + logger.exception("Error in summary generation future") + errors.append(e) + + # In preview mode (indexing-estimate), if there are any errors, fail the request + if errors: + error_messages = [str(e) for e in errors] + error_summary = ( + f"Failed to generate summaries for {len(errors)} chunk(s). " + f"Errors: {'; '.join(error_messages[:3])}" # Show first 3 errors + ) + if len(errors) > 3: + error_summary += f" (and {len(errors) - 3} more)" + logger.error("Summary generation failed in preview mode: %s", error_summary) + raise ValueError(error_summary) + + return preview_texts diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index 1183d5fbd7..272d2ed351 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -11,6 +11,8 @@ import pandas as pd from flask import Flask, current_app from werkzeug.datastructures import FileStorage +from core.db.session_factory import session_factory +from core.entities.knowledge_entities import PreviewDetail from core.llm_generator.llm_generator import LLMGenerator from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.datasource.retrieval_service import RetrievalService @@ -25,9 +27,10 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.utils.text_processing_utils import remove_leading_symbols from libs import helper from models.account import Account -from models.dataset 
import Dataset +from models.dataset import Dataset, DocumentSegment from models.dataset import Document as DatasetDocument from services.entities.knowledge_entities.knowledge_entities import Rule +from services.summary_index_service import SummaryIndexService logger = logging.getLogger(__name__) @@ -144,6 +147,31 @@ class QAIndexProcessor(BaseIndexProcessor): vector.create_multimodal(multimodal_documents) def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs): + # Note: Summary indexes are now disabled (not deleted) when segments are disabled. + # This method is called for actual deletion scenarios (e.g., when segment is deleted). + # For disable operations, disable_summaries_for_segments is called directly in the task. + # Note: qa_model doesn't generate summaries, but we clean them for completeness + # Only delete summaries if explicitly requested (e.g., when segment is actually deleted) + delete_summaries = kwargs.get("delete_summaries", False) + if delete_summaries: + if node_ids: + # Find segments by index_node_id + with session_factory.create_session() as session: + segments = ( + session.query(DocumentSegment) + .filter( + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.index_node_id.in_(node_ids), + ) + .all() + ) + segment_ids = [segment.id for segment in segments] + if segment_ids: + SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids) + else: + # Delete all summaries for the dataset + SummaryIndexService.delete_summaries_for_segments(dataset, None) + vector = Vector(dataset) if node_ids: vector.delete_by_ids(node_ids) @@ -212,6 +240,17 @@ class QAIndexProcessor(BaseIndexProcessor): "total_segments": len(qa_chunks.qa_chunks), } + def generate_summary_preview( + self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict + ) -> list[PreviewDetail]: + """ + QA model doesn't generate summaries, so this method returns preview_texts unchanged. 
+ + Note: QA model uses question-answer pairs, which don't require summary generation. + """ + # QA model doesn't generate summaries, return as-is + return preview_texts + def _format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language): format_documents = [] if document_node.page_content is None or not document_node.page_content.strip(): diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index f8f85d141a..541c241ae5 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -236,20 +236,24 @@ class DatasetRetrieval: if records: for record in records: segment = record.segment + # Build content: if summary exists, add it before the segment content if segment.answer: - document_context_list.append( - DocumentContext( - content=f"question:{segment.get_sign_content()} answer:{segment.answer}", - score=record.score, - ) - ) + segment_content = f"question:{segment.get_sign_content()} answer:{segment.answer}" else: - document_context_list.append( - DocumentContext( - content=segment.get_sign_content(), - score=record.score, - ) + segment_content = segment.get_sign_content() + + # If summary exists, prepend it to the content + if record.summary: + final_content = f"{record.summary}\n{segment_content}" + else: + final_content = segment_content + + document_context_list.append( + DocumentContext( + content=final_content, + score=record.score, ) + ) if vision_enabled: attachments_with_bindings = db.session.execute( select(SegmentAttachmentBinding, UploadFile) @@ -316,6 +320,9 @@ class DatasetRetrieval: source.content = f"question:{segment.content} \nanswer:{segment.answer}" else: source.content = segment.content + # Add summary if this segment was retrieved via summary + if hasattr(record, "summary") and record.summary: + source.summary = record.summary retrieval_resource_list.append(source) if hit_callback and 
retrieval_resource_list: retrieval_resource_list = sorted(retrieval_resource_list, key=lambda x: x.score or 0.0, reverse=True) diff --git a/api/core/repositories/__init__.py b/api/core/repositories/__init__.py index d83823d7b9..6f2826f634 100644 --- a/api/core/repositories/__init__.py +++ b/api/core/repositories/__init__.py @@ -1,19 +1,18 @@ -""" -Repository implementations for data access. +"""Repository implementations for data access.""" -This package contains concrete implementations of the repository interfaces -defined in the core.workflow.repository package. -""" +from __future__ import annotations -from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository -from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository -from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError -from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository +from .celery_workflow_execution_repository import CeleryWorkflowExecutionRepository +from .celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository +from .factory import DifyCoreRepositoryFactory, RepositoryImportError +from .sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository +from .sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository __all__ = [ "CeleryWorkflowExecutionRepository", "CeleryWorkflowNodeExecutionRepository", "DifyCoreRepositoryFactory", "RepositoryImportError", + "SQLAlchemyWorkflowExecutionRepository", "SQLAlchemyWorkflowNodeExecutionRepository", ] diff --git a/api/core/repositories/human_input_repository.py b/api/core/repositories/human_input_repository.py new file mode 100644 index 0000000000..0e04c56e0e --- /dev/null +++ b/api/core/repositories/human_input_repository.py @@ -0,0 +1,553 @@ +import dataclasses +import json +from 
collections.abc import Mapping, Sequence +from datetime import datetime +from typing import Any + +from sqlalchemy import Engine, select +from sqlalchemy.orm import Session, selectinload, sessionmaker + +from core.workflow.nodes.human_input.entities import ( + DeliveryChannelConfig, + EmailDeliveryMethod, + EmailRecipients, + ExternalRecipient, + FormDefinition, + HumanInputNodeData, + MemberRecipient, + WebAppDeliveryMethod, +) +from core.workflow.nodes.human_input.enums import ( + DeliveryMethodType, + HumanInputFormKind, + HumanInputFormStatus, +) +from core.workflow.repositories.human_input_form_repository import ( + FormCreateParams, + FormNotFoundError, + HumanInputFormEntity, + HumanInputFormRecipientEntity, +) +from libs.datetime_utils import naive_utc_now +from libs.uuid_utils import uuidv7 +from models.account import Account, TenantAccountJoin +from models.human_input import ( + BackstageRecipientPayload, + ConsoleDeliveryPayload, + ConsoleRecipientPayload, + EmailExternalRecipientPayload, + EmailMemberRecipientPayload, + HumanInputDelivery, + HumanInputForm, + HumanInputFormRecipient, + RecipientType, + StandaloneWebAppRecipientPayload, +) + + +@dataclasses.dataclass(frozen=True) +class _DeliveryAndRecipients: + delivery: HumanInputDelivery + recipients: Sequence[HumanInputFormRecipient] + + +@dataclasses.dataclass(frozen=True) +class _WorkspaceMemberInfo: + user_id: str + email: str + + +class _HumanInputFormRecipientEntityImpl(HumanInputFormRecipientEntity): + def __init__(self, recipient_model: HumanInputFormRecipient): + self._recipient_model = recipient_model + + @property + def id(self) -> str: + return self._recipient_model.id + + @property + def token(self) -> str: + if self._recipient_model.access_token is None: + raise AssertionError(f"access_token should not be None for recipient {self._recipient_model.id}") + return self._recipient_model.access_token + + +class _HumanInputFormEntityImpl(HumanInputFormEntity): + def __init__(self, form_model: 
HumanInputForm, recipient_models: Sequence[HumanInputFormRecipient]): + self._form_model = form_model + self._recipients = [_HumanInputFormRecipientEntityImpl(recipient) for recipient in recipient_models] + self._web_app_recipient = next( + ( + recipient + for recipient in recipient_models + if recipient.recipient_type == RecipientType.STANDALONE_WEB_APP + ), + None, + ) + self._console_recipient = next( + (recipient for recipient in recipient_models if recipient.recipient_type == RecipientType.CONSOLE), + None, + ) + self._submitted_data: Mapping[str, Any] | None = ( + json.loads(form_model.submitted_data) if form_model.submitted_data is not None else None + ) + + @property + def id(self) -> str: + return self._form_model.id + + @property + def web_app_token(self): + if self._console_recipient is not None: + return self._console_recipient.access_token + if self._web_app_recipient is None: + return None + return self._web_app_recipient.access_token + + @property + def recipients(self) -> list[HumanInputFormRecipientEntity]: + return list(self._recipients) + + @property + def rendered_content(self) -> str: + return self._form_model.rendered_content + + @property + def selected_action_id(self) -> str | None: + return self._form_model.selected_action_id + + @property + def submitted_data(self) -> Mapping[str, Any] | None: + return self._submitted_data + + @property + def submitted(self) -> bool: + return self._form_model.submitted_at is not None + + @property + def status(self) -> HumanInputFormStatus: + return self._form_model.status + + @property + def expiration_time(self) -> datetime: + return self._form_model.expiration_time + + +@dataclasses.dataclass(frozen=True) +class HumanInputFormRecord: + form_id: str + workflow_run_id: str | None + node_id: str + tenant_id: str + app_id: str + form_kind: HumanInputFormKind + definition: FormDefinition + rendered_content: str + created_at: datetime + expiration_time: datetime + status: HumanInputFormStatus + 
selected_action_id: str | None + submitted_data: Mapping[str, Any] | None + submitted_at: datetime | None + submission_user_id: str | None + submission_end_user_id: str | None + completed_by_recipient_id: str | None + recipient_id: str | None + recipient_type: RecipientType | None + access_token: str | None + + @property + def submitted(self) -> bool: + return self.submitted_at is not None + + @classmethod + def from_models( + cls, form_model: HumanInputForm, recipient_model: HumanInputFormRecipient | None + ) -> "HumanInputFormRecord": + definition_payload = json.loads(form_model.form_definition) + if "expiration_time" not in definition_payload: + definition_payload["expiration_time"] = form_model.expiration_time + return cls( + form_id=form_model.id, + workflow_run_id=form_model.workflow_run_id, + node_id=form_model.node_id, + tenant_id=form_model.tenant_id, + app_id=form_model.app_id, + form_kind=form_model.form_kind, + definition=FormDefinition.model_validate(definition_payload), + rendered_content=form_model.rendered_content, + created_at=form_model.created_at, + expiration_time=form_model.expiration_time, + status=form_model.status, + selected_action_id=form_model.selected_action_id, + submitted_data=json.loads(form_model.submitted_data) if form_model.submitted_data else None, + submitted_at=form_model.submitted_at, + submission_user_id=form_model.submission_user_id, + submission_end_user_id=form_model.submission_end_user_id, + completed_by_recipient_id=form_model.completed_by_recipient_id, + recipient_id=recipient_model.id if recipient_model else None, + recipient_type=recipient_model.recipient_type if recipient_model else None, + access_token=recipient_model.access_token if recipient_model else None, + ) + + +class _InvalidTimeoutStatusError(ValueError): + pass + + +class HumanInputFormRepositoryImpl: + def __init__( + self, + session_factory: sessionmaker | Engine, + tenant_id: str, + ): + if isinstance(session_factory, Engine): + session_factory = 
sessionmaker(bind=session_factory) + self._session_factory = session_factory + self._tenant_id = tenant_id + + def _delivery_method_to_model( + self, + session: Session, + form_id: str, + delivery_method: DeliveryChannelConfig, + ) -> _DeliveryAndRecipients: + delivery_id = str(uuidv7()) + delivery_model = HumanInputDelivery( + id=delivery_id, + form_id=form_id, + delivery_method_type=delivery_method.type, + delivery_config_id=delivery_method.id, + channel_payload=delivery_method.model_dump_json(), + ) + recipients: list[HumanInputFormRecipient] = [] + if isinstance(delivery_method, WebAppDeliveryMethod): + recipient_model = HumanInputFormRecipient( + form_id=form_id, + delivery_id=delivery_id, + recipient_type=RecipientType.STANDALONE_WEB_APP, + recipient_payload=StandaloneWebAppRecipientPayload().model_dump_json(), + ) + recipients.append(recipient_model) + elif isinstance(delivery_method, EmailDeliveryMethod): + email_recipients_config = delivery_method.config.recipients + recipients.extend( + self._build_email_recipients( + session=session, + form_id=form_id, + delivery_id=delivery_id, + recipients_config=email_recipients_config, + ) + ) + + return _DeliveryAndRecipients(delivery=delivery_model, recipients=recipients) + + def _build_email_recipients( + self, + session: Session, + form_id: str, + delivery_id: str, + recipients_config: EmailRecipients, + ) -> list[HumanInputFormRecipient]: + member_user_ids = [ + recipient.user_id for recipient in recipients_config.items if isinstance(recipient, MemberRecipient) + ] + external_emails = [ + recipient.email for recipient in recipients_config.items if isinstance(recipient, ExternalRecipient) + ] + if recipients_config.whole_workspace: + members = self._query_all_workspace_members(session=session) + else: + members = self._query_workspace_members_by_ids(session=session, restrict_to_user_ids=member_user_ids) + + return self._create_email_recipients_from_resolved( + form_id=form_id, + delivery_id=delivery_id, + 
members=members, + external_emails=external_emails, + ) + + @staticmethod + def _create_email_recipients_from_resolved( + *, + form_id: str, + delivery_id: str, + members: Sequence[_WorkspaceMemberInfo], + external_emails: Sequence[str], + ) -> list[HumanInputFormRecipient]: + recipient_models: list[HumanInputFormRecipient] = [] + seen_emails: set[str] = set() + + for member in members: + if not member.email: + continue + if member.email in seen_emails: + continue + seen_emails.add(member.email) + payload = EmailMemberRecipientPayload(user_id=member.user_id, email=member.email) + recipient_models.append( + HumanInputFormRecipient.new( + form_id=form_id, + delivery_id=delivery_id, + payload=payload, + ) + ) + + for email in external_emails: + if not email: + continue + if email in seen_emails: + continue + seen_emails.add(email) + recipient_models.append( + HumanInputFormRecipient.new( + form_id=form_id, + delivery_id=delivery_id, + payload=EmailExternalRecipientPayload(email=email), + ) + ) + + return recipient_models + + def _query_all_workspace_members( + self, + session: Session, + ) -> list[_WorkspaceMemberInfo]: + stmt = ( + select(Account.id, Account.email) + .join(TenantAccountJoin, TenantAccountJoin.account_id == Account.id) + .where(TenantAccountJoin.tenant_id == self._tenant_id) + ) + rows = session.execute(stmt).all() + return [_WorkspaceMemberInfo(user_id=account_id, email=email) for account_id, email in rows] + + def _query_workspace_members_by_ids( + self, + session: Session, + restrict_to_user_ids: Sequence[str], + ) -> list[_WorkspaceMemberInfo]: + unique_ids = {user_id for user_id in restrict_to_user_ids if user_id} + if not unique_ids: + return [] + + stmt = ( + select(Account.id, Account.email) + .join(TenantAccountJoin, TenantAccountJoin.account_id == Account.id) + .where(TenantAccountJoin.tenant_id == self._tenant_id) + ) + stmt = stmt.where(Account.id.in_(unique_ids)) + + rows = session.execute(stmt).all() + return 
[_WorkspaceMemberInfo(user_id=account_id, email=email) for account_id, email in rows] + + def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: + form_config: HumanInputNodeData = params.form_config + + with self._session_factory(expire_on_commit=False) as session, session.begin(): + # Generate unique form ID + form_id = str(uuidv7()) + start_time = naive_utc_now() + node_expiration = form_config.expiration_time(start_time) + form_definition = FormDefinition( + form_content=form_config.form_content, + inputs=form_config.inputs, + user_actions=form_config.user_actions, + rendered_content=params.rendered_content, + expiration_time=node_expiration, + default_values=dict(params.resolved_default_values), + display_in_ui=params.display_in_ui, + node_title=form_config.title, + ) + form_model = HumanInputForm( + id=form_id, + tenant_id=self._tenant_id, + app_id=params.app_id, + workflow_run_id=params.workflow_execution_id, + form_kind=params.form_kind, + node_id=params.node_id, + form_definition=form_definition.model_dump_json(), + rendered_content=params.rendered_content, + expiration_time=node_expiration, + created_at=start_time, + ) + session.add(form_model) + recipient_models: list[HumanInputFormRecipient] = [] + for delivery in params.delivery_methods: + delivery_and_recipients = self._delivery_method_to_model( + session=session, + form_id=form_id, + delivery_method=delivery, + ) + session.add(delivery_and_recipients.delivery) + session.add_all(delivery_and_recipients.recipients) + recipient_models.extend(delivery_and_recipients.recipients) + if params.console_recipient_required and not any( + recipient.recipient_type == RecipientType.CONSOLE for recipient in recipient_models + ): + console_delivery_id = str(uuidv7()) + console_delivery = HumanInputDelivery( + id=console_delivery_id, + form_id=form_id, + delivery_method_type=DeliveryMethodType.WEBAPP, + delivery_config_id=None, + channel_payload=ConsoleDeliveryPayload().model_dump_json(), + ) + 
console_recipient = HumanInputFormRecipient( + form_id=form_id, + delivery_id=console_delivery_id, + recipient_type=RecipientType.CONSOLE, + recipient_payload=ConsoleRecipientPayload( + account_id=params.console_creator_account_id, + ).model_dump_json(), + ) + session.add(console_delivery) + session.add(console_recipient) + recipient_models.append(console_recipient) + if params.backstage_recipient_required and not any( + recipient.recipient_type == RecipientType.BACKSTAGE for recipient in recipient_models + ): + backstage_delivery_id = str(uuidv7()) + backstage_delivery = HumanInputDelivery( + id=backstage_delivery_id, + form_id=form_id, + delivery_method_type=DeliveryMethodType.WEBAPP, + delivery_config_id=None, + channel_payload=ConsoleDeliveryPayload().model_dump_json(), + ) + backstage_recipient = HumanInputFormRecipient( + form_id=form_id, + delivery_id=backstage_delivery_id, + recipient_type=RecipientType.BACKSTAGE, + recipient_payload=BackstageRecipientPayload( + account_id=params.console_creator_account_id, + ).model_dump_json(), + ) + session.add(backstage_delivery) + session.add(backstage_recipient) + recipient_models.append(backstage_recipient) + session.flush() + + return _HumanInputFormEntityImpl(form_model=form_model, recipient_models=recipient_models) + + def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: + form_query = select(HumanInputForm).where( + HumanInputForm.workflow_run_id == workflow_execution_id, + HumanInputForm.node_id == node_id, + HumanInputForm.tenant_id == self._tenant_id, + ) + with self._session_factory(expire_on_commit=False) as session: + form_model: HumanInputForm | None = session.scalars(form_query).first() + if form_model is None: + return None + + recipient_query = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id == form_model.id) + recipient_models = session.scalars(recipient_query).all() + return _HumanInputFormEntityImpl(form_model=form_model, 
recipient_models=recipient_models) + + +class HumanInputFormSubmissionRepository: + """Repository for fetching and submitting human input forms.""" + + def __init__(self, session_factory: sessionmaker | Engine): + if isinstance(session_factory, Engine): + session_factory = sessionmaker(bind=session_factory) + self._session_factory = session_factory + + def get_by_token(self, form_token: str) -> HumanInputFormRecord | None: + query = ( + select(HumanInputFormRecipient) + .options(selectinload(HumanInputFormRecipient.form)) + .where(HumanInputFormRecipient.access_token == form_token) + ) + with self._session_factory(expire_on_commit=False) as session: + recipient_model = session.scalars(query).first() + if recipient_model is None or recipient_model.form is None: + return None + return HumanInputFormRecord.from_models(recipient_model.form, recipient_model) + + def get_by_form_id_and_recipient_type( + self, + form_id: str, + recipient_type: RecipientType, + ) -> HumanInputFormRecord | None: + query = ( + select(HumanInputFormRecipient) + .options(selectinload(HumanInputFormRecipient.form)) + .where( + HumanInputFormRecipient.form_id == form_id, + HumanInputFormRecipient.recipient_type == recipient_type, + ) + ) + with self._session_factory(expire_on_commit=False) as session: + recipient_model = session.scalars(query).first() + if recipient_model is None or recipient_model.form is None: + return None + return HumanInputFormRecord.from_models(recipient_model.form, recipient_model) + + def mark_submitted( + self, + *, + form_id: str, + recipient_id: str | None, + selected_action_id: str, + form_data: Mapping[str, Any], + submission_user_id: str | None, + submission_end_user_id: str | None, + ) -> HumanInputFormRecord: + with self._session_factory(expire_on_commit=False) as session, session.begin(): + form_model = session.get(HumanInputForm, form_id) + if form_model is None: + raise FormNotFoundError(f"form not found, id={form_id}") + + recipient_model = 
session.get(HumanInputFormRecipient, recipient_id) if recipient_id else None + + form_model.selected_action_id = selected_action_id + form_model.submitted_data = json.dumps(form_data) + form_model.submitted_at = naive_utc_now() + form_model.status = HumanInputFormStatus.SUBMITTED + form_model.submission_user_id = submission_user_id + form_model.submission_end_user_id = submission_end_user_id + form_model.completed_by_recipient_id = recipient_id + + session.add(form_model) + session.flush() + session.refresh(form_model) + if recipient_model is not None: + session.refresh(recipient_model) + + return HumanInputFormRecord.from_models(form_model, recipient_model) + + def mark_timeout( + self, + *, + form_id: str, + timeout_status: HumanInputFormStatus, + reason: str | None = None, + ) -> HumanInputFormRecord: + with self._session_factory(expire_on_commit=False) as session, session.begin(): + form_model = session.get(HumanInputForm, form_id) + if form_model is None: + raise FormNotFoundError(f"form not found, id={form_id}") + + if timeout_status not in {HumanInputFormStatus.TIMEOUT, HumanInputFormStatus.EXPIRED}: + raise _InvalidTimeoutStatusError(f"invalid timeout status: {timeout_status}") + + # already handled or submitted + if form_model.status in {HumanInputFormStatus.TIMEOUT, HumanInputFormStatus.EXPIRED}: + return HumanInputFormRecord.from_models(form_model, None) + + if form_model.submitted_at is not None or form_model.status == HumanInputFormStatus.SUBMITTED: + raise FormNotFoundError(f"form already submitted, id={form_id}") + + form_model.status = timeout_status + form_model.selected_action_id = None + form_model.submitted_data = None + form_model.submission_user_id = None + form_model.submission_end_user_id = None + form_model.completed_by_recipient_id = None + # Reason is recorded in status/error downstream; not stored on form. 
+ session.add(form_model) + session.flush() + session.refresh(form_model) + + return HumanInputFormRecord.from_models(form_model, None) diff --git a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py index 4436773d25..324dd059d1 100644 --- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py @@ -488,6 +488,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id, WorkflowNodeExecutionModel.tenant_id == self._tenant_id, WorkflowNodeExecutionModel.triggered_from == triggered_from, + WorkflowNodeExecutionModel.status != WorkflowNodeExecutionStatus.PAUSED, ) if self._app_id: diff --git a/api/core/tools/errors.py b/api/core/tools/errors.py index e4afe24426..4c3efd6ff9 100644 --- a/api/core/tools/errors.py +++ b/api/core/tools/errors.py @@ -1,4 +1,5 @@ from core.tools.entities.tool_entities import ToolInvokeMeta +from libs.exception import BaseHTTPException class ToolProviderNotFoundError(ValueError): @@ -37,6 +38,12 @@ class ToolCredentialPolicyViolationError(ValueError): pass +class WorkflowToolHumanInputNotSupportedError(BaseHTTPException): + error_code = "workflow_tool_human_input_not_supported" + description = "Workflow with Human Input nodes cannot be published as a workflow tool." 
+ code = 400 + + class ToolEngineInvokeError(Exception): meta: ToolInvokeMeta diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py index f96510fb45..057ec41f65 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py @@ -169,20 +169,24 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): if records: for record in records: segment = record.segment + # Build content: if summary exists, add it before the segment content if segment.answer: - document_context_list.append( - DocumentContext( - content=f"question:{segment.get_sign_content()} answer:{segment.answer}", - score=record.score, - ) - ) + segment_content = f"question:{segment.get_sign_content()} answer:{segment.answer}" else: - document_context_list.append( - DocumentContext( - content=segment.get_sign_content(), - score=record.score, - ) + segment_content = segment.get_sign_content() + + # If summary exists, prepend it to the content + if record.summary: + final_content = f"{record.summary}\n{segment_content}" + else: + final_content = segment_content + + document_context_list.append( + DocumentContext( + content=final_content, + score=record.score, ) + ) if self.return_resource: for record in records: @@ -216,6 +220,9 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): source.content = f"question:{segment.content} \nanswer:{segment.answer}" else: source.content = segment.content + # Add summary if this segment was retrieved via summary + if hasattr(record, "summary") and record.summary: + source.summary = record.summary retrieval_resource_list.append(source) if self.return_resource and retrieval_resource_list: diff --git a/api/core/tools/utils/workflow_configuration_sync.py b/api/core/tools/utils/workflow_configuration_sync.py index 188da0c32d..8588ccc718 100644 --- a/api/core/tools/utils/workflow_configuration_sync.py 
+++ b/api/core/tools/utils/workflow_configuration_sync.py @@ -3,6 +3,8 @@ from typing import Any from core.app.app_config.entities import VariableEntity from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration +from core.tools.errors import WorkflowToolHumanInputNotSupportedError +from core.workflow.enums import NodeType from core.workflow.nodes.base.entities import OutputVariableEntity @@ -50,6 +52,13 @@ class WorkflowToolConfigurationUtils: return [outputs_by_variable[variable] for variable in variable_order] + @classmethod + def ensure_no_human_input_nodes(cls, graph: Mapping[str, Any]) -> None: + nodes = graph.get("nodes", []) + for node in nodes: + if node.get("data", {}).get("type") == NodeType.HUMAN_INPUT: + raise WorkflowToolHumanInputNotSupportedError() + @classmethod def check_is_synced( cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration] diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index 9c1ceff145..01fa5de31e 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -98,6 +98,10 @@ class WorkflowTool(Tool): invoke_from=self.runtime.invoke_from, streaming=False, call_depth=self.workflow_call_depth + 1, + # NOTE(QuantumGhost): We explicitly set `pause_state_config` to `None` + # because workflow pausing mechanisms (such as HumanInput) are not + # supported within WorkflowTool execution context. 
+ pause_state_config=None, ) assert isinstance(result, dict) data = result.get("data", {}) diff --git a/api/core/trigger/debug/event_bus.py b/api/core/trigger/debug/event_bus.py index 9d10e1a0e0..e3fb6a13d9 100644 --- a/api/core/trigger/debug/event_bus.py +++ b/api/core/trigger/debug/event_bus.py @@ -23,8 +23,8 @@ class TriggerDebugEventBus: """ # LUA_SELECT: Atomic poll or register for event - # KEYS[1] = trigger_debug_inbox:{tenant_id}:{address_id} - # KEYS[2] = trigger_debug_waiting_pool:{tenant_id}:... + # KEYS[1] = trigger_debug_inbox:{}: + # KEYS[2] = trigger_debug_waiting_pool:{}:... # ARGV[1] = address_id LUA_SELECT = ( "local v=redis.call('GET',KEYS[1]);" @@ -35,7 +35,7 @@ class TriggerDebugEventBus: ) # LUA_DISPATCH: Dispatch event to all waiting addresses - # KEYS[1] = trigger_debug_waiting_pool:{tenant_id}:... + # KEYS[1] = trigger_debug_waiting_pool:{}:... # ARGV[1] = tenant_id # ARGV[2] = event_json LUA_DISPATCH = ( @@ -43,7 +43,7 @@ class TriggerDebugEventBus: "if #a==0 then return 0 end;" "redis.call('DEL',KEYS[1]);" "for i=1,#a do " - f"redis.call('SET','trigger_debug_inbox:'..ARGV[1]..':'..a[i],ARGV[2],'EX',{TRIGGER_DEBUG_EVENT_TTL});" + f"redis.call('SET','trigger_debug_inbox:{{'..ARGV[1]..'}}'..':'..a[i],ARGV[2],'EX',{TRIGGER_DEBUG_EVENT_TTL});" "end;" "return #a" ) @@ -108,7 +108,7 @@ class TriggerDebugEventBus: Event object if available, None otherwise """ address_id: str = hashlib.sha256(f"{user_id}|{app_id}|{node_id}".encode()).hexdigest() - address: str = f"trigger_debug_inbox:{tenant_id}:{address_id}" + address: str = f"trigger_debug_inbox:{{{tenant_id}}}:{address_id}" try: event_data = redis_client.eval( diff --git a/api/core/trigger/debug/events.py b/api/core/trigger/debug/events.py index 9f7bab5e49..9aec342ed1 100644 --- a/api/core/trigger/debug/events.py +++ b/api/core/trigger/debug/events.py @@ -42,7 +42,7 @@ def build_webhook_pool_key(tenant_id: str, app_id: str, node_id: str) -> str: app_id: App ID node_id: Node ID """ - return 
f"{TriggerDebugPoolKey.WEBHOOK}:{tenant_id}:{app_id}:{node_id}" + return f"{TriggerDebugPoolKey.WEBHOOK}:{{{tenant_id}}}:{app_id}:{node_id}" class PluginTriggerDebugEvent(BaseDebugEvent): @@ -64,4 +64,4 @@ def build_plugin_pool_key(tenant_id: str, provider_id: str, subscription_id: str provider_id: Provider ID subscription_id: Subscription ID """ - return f"{TriggerDebugPoolKey.PLUGIN}:{tenant_id}:{str(provider_id)}:{subscription_id}:{name}" + return f"{TriggerDebugPoolKey.PLUGIN}:{{{tenant_id}}}:{str(provider_id)}:{subscription_id}:{name}" diff --git a/api/core/workflow/entities/__init__.py b/api/core/workflow/entities/__init__.py index be70e467a0..e73c38c1d3 100644 --- a/api/core/workflow/entities/__init__.py +++ b/api/core/workflow/entities/__init__.py @@ -2,10 +2,12 @@ from .agent import AgentNodeStrategyInit from .graph_init_params import GraphInitParams from .workflow_execution import WorkflowExecution from .workflow_node_execution import WorkflowNodeExecution +from .workflow_start_reason import WorkflowStartReason __all__ = [ "AgentNodeStrategyInit", "GraphInitParams", "WorkflowExecution", "WorkflowNodeExecution", + "WorkflowStartReason", ] diff --git a/api/core/workflow/entities/graph_init_params.py b/api/core/workflow/entities/graph_init_params.py index 7bf25b9f43..ff224a28d1 100644 --- a/api/core/workflow/entities/graph_init_params.py +++ b/api/core/workflow/entities/graph_init_params.py @@ -5,6 +5,16 @@ from pydantic import BaseModel, Field class GraphInitParams(BaseModel): + """GraphInitParams encapsulates the configurations and contextual information + that remain constant throughout a single execution of the graph engine. + + A single execution is defined as follows: as long as the execution has not reached + its conclusion, it is considered one execution. For instance, if a workflow is suspended + and later resumed, it is still regarded as a single execution, not two. 
+ + For the state diagram of workflow execution, refer to `WorkflowExecutionStatus`. + """ + # init params tenant_id: str = Field(..., description="tenant / workspace id") app_id: str = Field(..., description="app id") diff --git a/api/core/workflow/entities/pause_reason.py b/api/core/workflow/entities/pause_reason.py index c6655b7eab..147f56e8be 100644 --- a/api/core/workflow/entities/pause_reason.py +++ b/api/core/workflow/entities/pause_reason.py @@ -1,8 +1,11 @@ +from collections.abc import Mapping from enum import StrEnum, auto -from typing import Annotated, Literal, TypeAlias +from typing import Annotated, Any, Literal, TypeAlias from pydantic import BaseModel, Field +from core.workflow.nodes.human_input.entities import FormInput, UserAction + class PauseReasonType(StrEnum): HUMAN_INPUT_REQUIRED = auto() @@ -11,10 +14,31 @@ class PauseReasonType(StrEnum): class HumanInputRequired(BaseModel): TYPE: Literal[PauseReasonType.HUMAN_INPUT_REQUIRED] = PauseReasonType.HUMAN_INPUT_REQUIRED - form_id: str - # The identifier of the human input node causing the pause. + form_content: str + inputs: list[FormInput] = Field(default_factory=list) + actions: list[UserAction] = Field(default_factory=list) + display_in_ui: bool = False node_id: str + node_title: str + + # The `resolved_default_values` stores the resolved values of variable defaults. It's a mapping from + # `output_variable_name` to their resolved values. + # + # For example, The form contains a input with output variable name `name` and placeholder type `VARIABLE`, its + # selector is ["start", "name"]. While the HumanInputNode is executed, the correspond value of variable + # `start.name` in variable pool is `John`. Thus, the resolved value of the output variable `name` is `John`. The + # `resolved_default_values` is `{"name": "John"}`. + # + # Only form inputs with default value type `VARIABLE` will be resolved and stored in `resolved_default_values`. 
+ resolved_default_values: Mapping[str, Any] = Field(default_factory=dict) + + # The `form_token` is the token used to submit the form via UI surfaces. It corresponds to + # `HumanInputFormRecipient.access_token`. + # + # This field is `None` if webapp delivery is not set and not + # in orchestrating mode. + form_token: str | None = None class SchedulingPause(BaseModel): diff --git a/api/core/workflow/entities/workflow_start_reason.py b/api/core/workflow/entities/workflow_start_reason.py new file mode 100644 index 0000000000..df0f75383b --- /dev/null +++ b/api/core/workflow/entities/workflow_start_reason.py @@ -0,0 +1,8 @@ +from enum import StrEnum + + +class WorkflowStartReason(StrEnum): + """Reason for workflow start events across graph/queue/SSE layers.""" + + INITIAL = "initial" # First start of a workflow run. + RESUMPTION = "resumption" # Start triggered after resuming a paused run. diff --git a/api/core/workflow/graph_engine/_engine_utils.py b/api/core/workflow/graph_engine/_engine_utils.py new file mode 100644 index 0000000000..28898268fe --- /dev/null +++ b/api/core/workflow/graph_engine/_engine_utils.py @@ -0,0 +1,15 @@ +import time + + +def get_timestamp() -> float: + """Retrieve a timestamp as a float point numer representing the number of seconds + since the Unix epoch. + + This function is primarily used to measure the execution time of the workflow engine. + Since workflow execution may be paused and resumed on a different machine, + `time.perf_counter` cannot be used as it is inconsistent across machines. + + To address this, the function uses the wall clock as the time source. + However, it assumes that the clocks of all servers are properly synchronized. 
+ """ + return round(time.time()) diff --git a/api/core/workflow/graph_engine/config.py b/api/core/workflow/graph_engine/config.py index 10dbbd7535..d56a69cee0 100644 --- a/api/core/workflow/graph_engine/config.py +++ b/api/core/workflow/graph_engine/config.py @@ -2,12 +2,14 @@ GraphEngine configuration models. """ -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict class GraphEngineConfig(BaseModel): """Configuration for GraphEngine worker pool scaling.""" + model_config = ConfigDict(frozen=True) + min_workers: int = 1 max_workers: int = 5 scale_up_threshold: int = 3 diff --git a/api/core/workflow/graph_engine/event_management/event_handlers.py b/api/core/workflow/graph_engine/event_management/event_handlers.py index 5b0f56e59d..98a0702e1c 100644 --- a/api/core/workflow/graph_engine/event_management/event_handlers.py +++ b/api/core/workflow/graph_engine/event_management/event_handlers.py @@ -192,9 +192,13 @@ class EventHandler: self._event_collector.collect(edge_event) # Enqueue ready nodes - for node_id in ready_nodes: - self._state_manager.enqueue_node(node_id) - self._state_manager.start_execution(node_id) + if self._graph_execution.is_paused: + for node_id in ready_nodes: + self._graph_runtime_state.register_deferred_node(node_id) + else: + for node_id in ready_nodes: + self._state_manager.enqueue_node(node_id) + self._state_manager.start_execution(node_id) # Update execution tracking self._state_manager.finish_execution(event.node_id) diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 0b359a2392..ac9e00e29e 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -14,6 +14,7 @@ from collections.abc import Generator from typing import TYPE_CHECKING, cast, final from core.workflow.context import capture_current_context +from core.workflow.entities.workflow_start_reason import WorkflowStartReason from 
core.workflow.enums import NodeExecutionType from core.workflow.graph import Graph from core.workflow.graph_events import ( @@ -56,6 +57,9 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +_DEFAULT_CONFIG = GraphEngineConfig() + + @final class GraphEngine: """ @@ -71,7 +75,7 @@ class GraphEngine: graph: Graph, graph_runtime_state: GraphRuntimeState, command_channel: CommandChannel, - config: GraphEngineConfig, + config: GraphEngineConfig = _DEFAULT_CONFIG, ) -> None: """Initialize the graph engine with all subsystems and dependencies.""" # stop event @@ -235,7 +239,9 @@ class GraphEngine: self._graph_execution.paused = False self._graph_execution.pause_reasons = [] - start_event = GraphRunStartedEvent() + start_event = GraphRunStartedEvent( + reason=WorkflowStartReason.RESUMPTION if is_resume else WorkflowStartReason.INITIAL, + ) self._event_manager.notify_layers(start_event) yield start_event @@ -304,15 +310,17 @@ class GraphEngine: for layer in self._layers: try: layer.on_graph_start() - except Exception as e: - logger.warning("Layer %s failed on_graph_start: %s", layer.__class__.__name__, e) + except Exception: + logger.exception("Layer %s failed on_graph_start", layer.__class__.__name__) def _start_execution(self, *, resume: bool = False) -> None: """Start execution subsystems.""" self._stop_event.clear() paused_nodes: list[str] = [] + deferred_nodes: list[str] = [] if resume: paused_nodes = self._graph_runtime_state.consume_paused_nodes() + deferred_nodes = self._graph_runtime_state.consume_deferred_nodes() # Start worker pool (it calculates initial workers internally) self._worker_pool.start() @@ -328,7 +336,11 @@ class GraphEngine: self._state_manager.enqueue_node(root_node.id) self._state_manager.start_execution(root_node.id) else: - for node_id in paused_nodes: + seen_nodes: set[str] = set() + for node_id in paused_nodes + deferred_nodes: + if node_id in seen_nodes: + continue + seen_nodes.add(node_id) self._state_manager.enqueue_node(node_id) 
self._state_manager.start_execution(node_id) @@ -346,8 +358,8 @@ class GraphEngine: for layer in self._layers: try: layer.on_graph_end(self._graph_execution.error) - except Exception as e: - logger.warning("Layer %s failed on_graph_end: %s", layer.__class__.__name__, e) + except Exception: + logger.exception("Layer %s failed on_graph_end", layer.__class__.__name__) # Public property accessors for attributes that need external access @property diff --git a/api/core/workflow/graph_engine/graph_state_manager.py b/api/core/workflow/graph_engine/graph_state_manager.py index 22a3a826fc..d9773645c3 100644 --- a/api/core/workflow/graph_engine/graph_state_manager.py +++ b/api/core/workflow/graph_engine/graph_state_manager.py @@ -224,6 +224,8 @@ class GraphStateManager: Returns: Number of executing nodes """ + # This count is a best-effort snapshot and can change concurrently. + # Only use it for pause-drain checks where scheduling is already frozen. with self._lock: return len(self._executing_nodes) diff --git a/api/core/workflow/graph_engine/orchestration/dispatcher.py b/api/core/workflow/graph_engine/orchestration/dispatcher.py index 27439a2412..d40d15c545 100644 --- a/api/core/workflow/graph_engine/orchestration/dispatcher.py +++ b/api/core/workflow/graph_engine/orchestration/dispatcher.py @@ -83,12 +83,12 @@ class Dispatcher: """Main dispatcher loop.""" try: self._process_commands() + paused = False while not self._stop_event.is_set(): - if ( - self._execution_coordinator.aborted - or self._execution_coordinator.paused - or self._execution_coordinator.execution_complete - ): + if self._execution_coordinator.aborted or self._execution_coordinator.execution_complete: + break + if self._execution_coordinator.paused: + paused = True break self._execution_coordinator.check_scaling() @@ -101,13 +101,10 @@ class Dispatcher: time.sleep(0.1) self._process_commands() - while True: - try: - event = self._event_queue.get(block=False) - self._event_handler.dispatch(event) - 
self._event_queue.task_done() - except queue.Empty: - break + if paused: + self._drain_events_until_idle() + else: + self._drain_event_queue() except Exception as e: logger.exception("Dispatcher error") @@ -122,3 +119,24 @@ class Dispatcher: def _process_commands(self, event: GraphNodeEventBase | None = None): if event is None or isinstance(event, self._COMMAND_TRIGGER_EVENTS): self._execution_coordinator.process_commands() + + def _drain_event_queue(self) -> None: + while True: + try: + event = self._event_queue.get(block=False) + self._event_handler.dispatch(event) + self._event_queue.task_done() + except queue.Empty: + break + + def _drain_events_until_idle(self) -> None: + while not self._stop_event.is_set(): + try: + event = self._event_queue.get(timeout=0.1) + self._event_handler.dispatch(event) + self._event_queue.task_done() + self._process_commands(event) + except queue.Empty: + if not self._execution_coordinator.has_executing_nodes(): + break + self._drain_event_queue() diff --git a/api/core/workflow/graph_engine/orchestration/execution_coordinator.py b/api/core/workflow/graph_engine/orchestration/execution_coordinator.py index e8e8f9f16c..0f8550eb12 100644 --- a/api/core/workflow/graph_engine/orchestration/execution_coordinator.py +++ b/api/core/workflow/graph_engine/orchestration/execution_coordinator.py @@ -94,3 +94,11 @@ class ExecutionCoordinator: self._worker_pool.stop() self._state_manager.clear_executing() + + def has_executing_nodes(self) -> bool: + """Return True if any nodes are currently marked as executing.""" + # This check is only safe once execution has already paused. + # Before pause, executing state can change concurrently, which makes the result unreliable. 
+ if not self._graph_execution.is_paused: + raise AssertionError("has_executing_nodes should only be called after execution is paused") + return self._state_manager.get_executing_count() > 0 diff --git a/api/core/workflow/graph_events/__init__.py b/api/core/workflow/graph_events/__init__.py index 2b6ee4ec1c..56ea642092 100644 --- a/api/core/workflow/graph_events/__init__.py +++ b/api/core/workflow/graph_events/__init__.py @@ -38,6 +38,8 @@ from .loop import ( from .node import ( NodeRunExceptionEvent, NodeRunFailedEvent, + NodeRunHumanInputFormFilledEvent, + NodeRunHumanInputFormTimeoutEvent, NodeRunPauseRequestedEvent, NodeRunRetrieverResourceEvent, NodeRunRetryEvent, @@ -60,6 +62,8 @@ __all__ = [ "NodeRunAgentLogEvent", "NodeRunExceptionEvent", "NodeRunFailedEvent", + "NodeRunHumanInputFormFilledEvent", + "NodeRunHumanInputFormTimeoutEvent", "NodeRunIterationFailedEvent", "NodeRunIterationNextEvent", "NodeRunIterationStartedEvent", diff --git a/api/core/workflow/graph_events/graph.py b/api/core/workflow/graph_events/graph.py index 5d10a76c15..f46526bcab 100644 --- a/api/core/workflow/graph_events/graph.py +++ b/api/core/workflow/graph_events/graph.py @@ -1,11 +1,16 @@ from pydantic import Field from core.workflow.entities.pause_reason import PauseReason +from core.workflow.entities.workflow_start_reason import WorkflowStartReason from core.workflow.graph_events import BaseGraphEvent class GraphRunStartedEvent(BaseGraphEvent): - pass + # Reason is emitted for workflow start events and is always set. 
+ reason: WorkflowStartReason = Field( + default=WorkflowStartReason.INITIAL, + description="reason for workflow start", + ) class GraphRunSucceededEvent(BaseGraphEvent): diff --git a/api/core/workflow/graph_events/human_input.py b/api/core/workflow/graph_events/human_input.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/workflow/graph_events/node.py b/api/core/workflow/graph_events/node.py index 4d0108e77b..975d72ad1f 100644 --- a/api/core/workflow/graph_events/node.py +++ b/api/core/workflow/graph_events/node.py @@ -54,6 +54,22 @@ class NodeRunRetryEvent(NodeRunStartedEvent): retry_index: int = Field(..., description="which retry attempt is about to be performed") +class NodeRunHumanInputFormFilledEvent(GraphNodeEventBase): + """Emitted when a HumanInput form is submitted and before the node finishes.""" + + node_title: str = Field(..., description="HumanInput node title") + rendered_content: str = Field(..., description="Markdown content rendered with user inputs.") + action_id: str = Field(..., description="User action identifier chosen in the form.") + action_text: str = Field(..., description="Display text of the chosen action button.") + + +class NodeRunHumanInputFormTimeoutEvent(GraphNodeEventBase): + """Emitted when a HumanInput form times out.""" + + node_title: str = Field(..., description="HumanInput node title") + expiration_time: datetime = Field(..., description="Form expiration time") + + class NodeRunPauseRequestedEvent(GraphNodeEventBase): reason: PauseReason = Field(..., description="pause reason") diff --git a/api/core/workflow/node_events/__init__.py b/api/core/workflow/node_events/__init__.py index f14a594c85..a9bef8f9a2 100644 --- a/api/core/workflow/node_events/__init__.py +++ b/api/core/workflow/node_events/__init__.py @@ -13,6 +13,8 @@ from .loop import ( LoopSucceededEvent, ) from .node import ( + HumanInputFormFilledEvent, + HumanInputFormTimeoutEvent, ModelInvokeCompletedEvent, PauseRequestedEvent, 
RunRetrieverResourceEvent, @@ -23,6 +25,8 @@ from .node import ( __all__ = [ "AgentLogEvent", + "HumanInputFormFilledEvent", + "HumanInputFormTimeoutEvent", "IterationFailedEvent", "IterationNextEvent", "IterationStartedEvent", diff --git a/api/core/workflow/node_events/node.py b/api/core/workflow/node_events/node.py index e4fa52f444..9c76b7d7c2 100644 --- a/api/core/workflow/node_events/node.py +++ b/api/core/workflow/node_events/node.py @@ -47,3 +47,19 @@ class StreamCompletedEvent(NodeEventBase): class PauseRequestedEvent(NodeEventBase): reason: PauseReason = Field(..., description="pause reason") + + +class HumanInputFormFilledEvent(NodeEventBase): + """Event emitted when a human input form is submitted.""" + + node_title: str + rendered_content: str + action_id: str + action_text: str + + +class HumanInputFormTimeoutEvent(NodeEventBase): + """Event emitted when a human input form times out.""" + + node_title: str + expiration_time: datetime diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index 63e0260341..2b773b537c 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -18,6 +18,8 @@ from core.workflow.graph_events import ( GraphNodeEventBase, NodeRunAgentLogEvent, NodeRunFailedEvent, + NodeRunHumanInputFormFilledEvent, + NodeRunHumanInputFormTimeoutEvent, NodeRunIterationFailedEvent, NodeRunIterationNextEvent, NodeRunIterationStartedEvent, @@ -34,6 +36,8 @@ from core.workflow.graph_events import ( ) from core.workflow.node_events import ( AgentLogEvent, + HumanInputFormFilledEvent, + HumanInputFormTimeoutEvent, IterationFailedEvent, IterationNextEvent, IterationStartedEvent, @@ -61,6 +65,15 @@ logger = logging.getLogger(__name__) class Node(Generic[NodeDataT]): + """BaseNode serves as the foundational class for all node implementations. + + Nodes are allowed to maintain transient states (e.g., `LLMNode` uses the `_file_output` + attribute to track files generated by the LLM). 
However, these states are not persisted + when the workflow is suspended or resumed. If a node needs its state to be preserved + across workflow suspension and resumption, it should include the relevant state data + in its output. + """ + node_type: ClassVar[NodeType] execution_type: NodeExecutionType = NodeExecutionType.EXECUTABLE _node_data_type: ClassVar[type[BaseNodeData]] = BaseNodeData @@ -251,10 +264,33 @@ class Node(Generic[NodeDataT]): return self._node_execution_id def ensure_execution_id(self) -> str: - if not self._node_execution_id: - self._node_execution_id = str(uuid4()) + if self._node_execution_id: + return self._node_execution_id + + resumed_execution_id = self._restore_execution_id_from_runtime_state() + if resumed_execution_id: + self._node_execution_id = resumed_execution_id + return self._node_execution_id + + self._node_execution_id = str(uuid4()) return self._node_execution_id + def _restore_execution_id_from_runtime_state(self) -> str | None: + graph_execution = self.graph_runtime_state.graph_execution + try: + node_executions = graph_execution.node_executions + except AttributeError: + return None + if not isinstance(node_executions, dict): + return None + node_execution = node_executions.get(self._node_id) + if node_execution is None: + return None + execution_id = node_execution.execution_id + if not execution_id: + return None + return str(execution_id) + def _hydrate_node_data(self, data: Mapping[str, Any]) -> NodeDataT: return cast(NodeDataT, self._node_data_type.model_validate(data)) @@ -620,6 +656,28 @@ class Node(Generic[NodeDataT]): metadata=event.metadata, ) + @_dispatch.register + def _(self, event: HumanInputFormFilledEvent): + return NodeRunHumanInputFormFilledEvent( + id=self.execution_id, + node_id=self._node_id, + node_type=self.node_type, + node_title=event.node_title, + rendered_content=event.rendered_content, + action_id=event.action_id, + action_text=event.action_text, + ) + + @_dispatch.register + def _(self, event: 
HumanInputFormTimeoutEvent): + return NodeRunHumanInputFormTimeoutEvent( + id=self.execution_id, + node_id=self._node_id, + node_type=self.node_type, + node_title=event.node_title, + expiration_time=event.expiration_time, + ) + @_dispatch.register def _(self, event: LoopStartedEvent) -> NodeRunLoopStartedEvent: return NodeRunLoopStartedEvent( diff --git a/api/core/workflow/nodes/human_input/__init__.py b/api/core/workflow/nodes/human_input/__init__.py index 379440557c..1789604577 100644 --- a/api/core/workflow/nodes/human_input/__init__.py +++ b/api/core/workflow/nodes/human_input/__init__.py @@ -1,3 +1,3 @@ -from .human_input_node import HumanInputNode - -__all__ = ["HumanInputNode"] +""" +Human Input node implementation. +""" diff --git a/api/core/workflow/nodes/human_input/entities.py b/api/core/workflow/nodes/human_input/entities.py index 02913d93c3..72d4fc675b 100644 --- a/api/core/workflow/nodes/human_input/entities.py +++ b/api/core/workflow/nodes/human_input/entities.py @@ -1,10 +1,350 @@ -from pydantic import Field +""" +Human Input node entities. 
+""" +import re +import uuid +from collections.abc import Mapping, Sequence +from datetime import datetime, timedelta +from typing import Annotated, Any, ClassVar, Literal, Self + +from pydantic import BaseModel, Field, field_validator, model_validator + +from core.variables.consts import SELECTORS_LENGTH from core.workflow.nodes.base import BaseNodeData +from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser +from core.workflow.runtime import VariablePool + +from .enums import ButtonStyle, DeliveryMethodType, EmailRecipientType, FormInputType, PlaceholderType, TimeoutUnit + +_OUTPUT_VARIABLE_PATTERN = re.compile(r"\{\{#\$output\.(?P<field_name>[a-zA-Z_][a-zA-Z0-9_]{0,29})#\}\}") + + +class _WebAppDeliveryConfig(BaseModel): + """Configuration for webapp delivery method.""" + + pass # Empty for webapp delivery + + +class MemberRecipient(BaseModel): + """Member recipient for email delivery.""" + + type: Literal[EmailRecipientType.MEMBER] = EmailRecipientType.MEMBER + user_id: str + + +class ExternalRecipient(BaseModel): + """External recipient for email delivery.""" + + type: Literal[EmailRecipientType.EXTERNAL] = EmailRecipientType.EXTERNAL + email: str + + +EmailRecipient = Annotated[MemberRecipient | ExternalRecipient, Field(discriminator="type")] + + +class EmailRecipients(BaseModel): + """Email recipients configuration.""" + + # When true, recipients are the union of all workspace members and external items. + # Member items are ignored because they are already covered by the workspace scope. + # De-duplication is applied by email, with member recipients taking precedence.
+    whole_workspace: bool = False +    items: list[EmailRecipient] = Field(default_factory=list) + + +class EmailDeliveryConfig(BaseModel): + """Configuration for email delivery method.""" + + URL_PLACEHOLDER: ClassVar[str] = "{{#url#}}" + + recipients: EmailRecipients + + # the subject of email + subject: str + + # Body is the content of email. It may contain the special placeholder `{{#url#}}`, which + # represents the URL to submit the form. + # + # It may also reference the output variable of the previous node with the syntax + # `{{#<node_name>.<variable_name>#}}`. + body: str + debug_mode: bool = False + + def with_debug_recipient(self, user_id: str) -> "EmailDeliveryConfig": + if not user_id: + debug_recipients = EmailRecipients(whole_workspace=False, items=[]) + return self.model_copy(update={"recipients": debug_recipients}) + debug_recipients = EmailRecipients(whole_workspace=False, items=[MemberRecipient(user_id=user_id)]) + return self.model_copy(update={"recipients": debug_recipients}) + + @classmethod + def replace_url_placeholder(cls, body: str, url: str | None) -> str: + """Replace the url placeholder with provided value.""" + return body.replace(cls.URL_PLACEHOLDER, url or "") + + @classmethod + def render_body_template( + cls, + *, + body: str, + url: str | None, + variable_pool: VariablePool | None = None, + ) -> str: + """Render email body by replacing placeholders with runtime values.""" + templated_body = cls.replace_url_placeholder(body, url) + if variable_pool is None: + return templated_body + return variable_pool.convert_template(templated_body).text + + +class _DeliveryMethodBase(BaseModel): + """Base delivery method configuration.""" + + enabled: bool = True + id: uuid.UUID = Field(default_factory=uuid.uuid4) + + def extract_variable_selectors(self) -> Sequence[Sequence[str]]: + return () + + +class WebAppDeliveryMethod(_DeliveryMethodBase): + """Webapp delivery method configuration.""" + + type: Literal[DeliveryMethodType.WEBAPP] = DeliveryMethodType.WEBAPP + # The
config field is not used currently. + config: _WebAppDeliveryConfig = Field(default_factory=_WebAppDeliveryConfig) + + +class EmailDeliveryMethod(_DeliveryMethodBase): + """Email delivery method configuration.""" + + type: Literal[DeliveryMethodType.EMAIL] = DeliveryMethodType.EMAIL + config: EmailDeliveryConfig + + def extract_variable_selectors(self) -> Sequence[Sequence[str]]: + variable_template_parser = VariableTemplateParser(template=self.config.body) + selectors: list[Sequence[str]] = [] + for variable_selector in variable_template_parser.extract_variable_selectors(): + value_selector = list(variable_selector.value_selector) + if len(value_selector) < SELECTORS_LENGTH: + continue + selectors.append(value_selector[:SELECTORS_LENGTH]) + return selectors + + +DeliveryChannelConfig = Annotated[WebAppDeliveryMethod | EmailDeliveryMethod, Field(discriminator="type")] + + +def apply_debug_email_recipient( + method: DeliveryChannelConfig, + *, + enabled: bool, + user_id: str, +) -> DeliveryChannelConfig: + if not enabled: + return method + if not isinstance(method, EmailDeliveryMethod): + return method + if not method.config.debug_mode: + return method + debug_config = method.config.with_debug_recipient(user_id or "") + return method.model_copy(update={"config": debug_config}) + + +class FormInputDefault(BaseModel): + """Default configuration for form inputs.""" + + # NOTE: Ideally, a discriminated union would be used to model + # FormInputDefault. However, the UI requires preserving the previous + # value when switching between `VARIABLE` and `CONSTANT` types. This + # necessitates retaining all fields, making a discriminated union unsuitable. + + type: PlaceholderType + + # The selector of default variable, used when `type` is `VARIABLE`. + selector: Sequence[str] = Field(default_factory=tuple) # + + # The value of the default, used when `type` is `CONSTANT`. + # TODO: How should we express JSON values? 
+ value: str = "" + + @model_validator(mode="after") + def _validate_selector(self) -> Self: + if self.type == PlaceholderType.CONSTANT: + return self + if len(self.selector) < SELECTORS_LENGTH: + raise ValueError(f"the length of selector should be at least {SELECTORS_LENGTH}, selector={self.selector}") + return self + + +class FormInput(BaseModel): + """Form input definition.""" + + type: FormInputType + output_variable_name: str + default: FormInputDefault | None = None + + +_IDENTIFIER_PATTERN = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") + + +class UserAction(BaseModel): + """User action configuration.""" + + # id is the identifier for this action. + # It also serves as the identifiers of output handle. + # + # The id must be a valid identifier (satisfy the _IDENTIFIER_PATTERN above.) + id: str = Field(max_length=20) + title: str = Field(max_length=20) + button_style: ButtonStyle = ButtonStyle.DEFAULT + + @field_validator("id") + @classmethod + def _validate_id(cls, value: str) -> str: + if not _IDENTIFIER_PATTERN.match(value): + raise ValueError( + f"'{value}' is not a valid identifier. It must start with a letter or underscore, " + f"and contain only letters, numbers, or underscores." 
+ ) + return value class HumanInputNodeData(BaseNodeData): - """Configuration schema for the HumanInput node.""" + """Human Input node data.""" - required_variables: list[str] = Field(default_factory=list) - pause_reason: str | None = Field(default=None) + delivery_methods: list[DeliveryChannelConfig] = Field(default_factory=list) + form_content: str = "" + inputs: list[FormInput] = Field(default_factory=list) + user_actions: list[UserAction] = Field(default_factory=list) + timeout: int = 36 + timeout_unit: TimeoutUnit = TimeoutUnit.HOUR + + @field_validator("inputs") + @classmethod + def _validate_inputs(cls, inputs: list[FormInput]) -> list[FormInput]: + seen_names: set[str] = set() + for form_input in inputs: + name = form_input.output_variable_name + if name in seen_names: + raise ValueError(f"duplicated output_variable_name '{name}' in inputs") + seen_names.add(name) + return inputs + + @field_validator("user_actions") + @classmethod + def _validate_user_actions(cls, user_actions: list[UserAction]) -> list[UserAction]: + seen_ids: set[str] = set() + for action in user_actions: + action_id = action.id + if action_id in seen_ids: + raise ValueError(f"duplicated user action id '{action_id}'") + seen_ids.add(action_id) + return user_actions + + def is_webapp_enabled(self) -> bool: + for dm in self.delivery_methods: + if not dm.enabled: + continue + if dm.type == DeliveryMethodType.WEBAPP: + return True + return False + + def expiration_time(self, start_time: datetime) -> datetime: + if self.timeout_unit == TimeoutUnit.HOUR: + return start_time + timedelta(hours=self.timeout) + elif self.timeout_unit == TimeoutUnit.DAY: + return start_time + timedelta(days=self.timeout) + else: + raise AssertionError("unknown timeout unit.") + + def outputs_field_names(self) -> Sequence[str]: + field_names = [] + for match in _OUTPUT_VARIABLE_PATTERN.finditer(self.form_content): + field_names.append(match.group("field_name")) + return field_names + + def 
extract_variable_selector_to_variable_mapping(self, node_id: str) -> Mapping[str, Sequence[str]]: + variable_mappings: dict[str, Sequence[str]] = {} + + def _add_variable_selectors(selectors: Sequence[Sequence[str]]) -> None: + for selector in selectors: + if len(selector) < SELECTORS_LENGTH: + continue + qualified_variable_mapping_key = f"{node_id}.#{'.'.join(selector[:SELECTORS_LENGTH])}#" + variable_mappings[qualified_variable_mapping_key] = list(selector[:SELECTORS_LENGTH]) + + form_template_parser = VariableTemplateParser(template=self.form_content) + _add_variable_selectors( + [selector.value_selector for selector in form_template_parser.extract_variable_selectors()] + ) + for delivery_method in self.delivery_methods: + if not delivery_method.enabled: + continue + _add_variable_selectors(delivery_method.extract_variable_selectors()) + + for input in self.inputs: + default_value = input.default + if default_value is None: + continue + if default_value.type == PlaceholderType.CONSTANT: + continue + default_value_key = ".".join(default_value.selector) + qualified_variable_mapping_key = f"{node_id}.#{default_value_key}#" + variable_mappings[qualified_variable_mapping_key] = default_value.selector + + return variable_mappings + + def find_action_text(self, action_id: str) -> str: + """ + Resolve action display text by id. + """ + for action in self.user_actions: + if action.id == action_id: + return action.title + return action_id + + +class FormDefinition(BaseModel): + form_content: str + inputs: list[FormInput] = Field(default_factory=list) + user_actions: list[UserAction] = Field(default_factory=list) + rendered_content: str + expiration_time: datetime + + # this is used to store the resolved default values + default_values: dict[str, Any] = Field(default_factory=dict) + + # node_title records the title of the HumanInput node. + node_title: str | None = None + + # display_in_ui controls whether the form should be displayed in UI surfaces. 
+ display_in_ui: bool | None = None + + +class HumanInputSubmissionValidationError(ValueError): + pass + + +def validate_human_input_submission( + *, + inputs: Sequence[FormInput], + user_actions: Sequence[UserAction], + selected_action_id: str, + form_data: Mapping[str, Any], +) -> None: + available_actions = {action.id for action in user_actions} + if selected_action_id not in available_actions: + raise HumanInputSubmissionValidationError(f"Invalid action: {selected_action_id}") + + provided_inputs = set(form_data.keys()) + missing_inputs = [ + form_input.output_variable_name + for form_input in inputs + if form_input.output_variable_name not in provided_inputs + ] + + if missing_inputs: + missing_list = ", ".join(missing_inputs) + raise HumanInputSubmissionValidationError(f"Missing required inputs: {missing_list}") diff --git a/api/core/workflow/nodes/human_input/enums.py b/api/core/workflow/nodes/human_input/enums.py new file mode 100644 index 0000000000..da85728828 --- /dev/null +++ b/api/core/workflow/nodes/human_input/enums.py @@ -0,0 +1,72 @@ +import enum + + +class HumanInputFormStatus(enum.StrEnum): + """Status of a human input form.""" + + # Awaiting submission from any recipient. Forms stay in this state until + # submitted or a timeout rule applies. + WAITING = enum.auto() + # Global timeout reached. The workflow run is stopped and will not resume. + # This is distinct from node-level timeout. + EXPIRED = enum.auto() + # Submitted by a recipient; form data is available and execution resumes + # along the selected action edge. + SUBMITTED = enum.auto() + # Node-level timeout reached. The human input node should emit a timeout + # event and the workflow should resume along the timeout edge. + TIMEOUT = enum.auto() + + +class HumanInputFormKind(enum.StrEnum): + """Kind of a human input form.""" + + RUNTIME = enum.auto() # Form created during workflow execution. + DELIVERY_TEST = enum.auto() # Form created for delivery tests. 
+ + +class DeliveryMethodType(enum.StrEnum): + """Delivery method types for human input forms.""" + + # WEBAPP controls whether the form is delivered to the web app. It not only controls + # the standalone web app, but also controls the installed apps in the console. + WEBAPP = enum.auto() + + EMAIL = enum.auto() + + +class ButtonStyle(enum.StrEnum): + """Button styles for user actions.""" + + PRIMARY = enum.auto() + DEFAULT = enum.auto() + ACCENT = enum.auto() + GHOST = enum.auto() + + +class TimeoutUnit(enum.StrEnum): + """Timeout unit for form expiration.""" + + HOUR = enum.auto() + DAY = enum.auto() + + +class FormInputType(enum.StrEnum): + """Form input types.""" + + TEXT_INPUT = enum.auto() + PARAGRAPH = enum.auto() + + +class PlaceholderType(enum.StrEnum): + """Default value types for form inputs.""" + + VARIABLE = enum.auto() + CONSTANT = enum.auto() + + +class EmailRecipientType(enum.StrEnum): + """Email recipient types.""" + + MEMBER = enum.auto() + EXTERNAL = enum.auto() diff --git a/api/core/workflow/nodes/human_input/human_input_node.py b/api/core/workflow/nodes/human_input/human_input_node.py index 6c8bf36fab..1d7522ea25 100644 --- a/api/core/workflow/nodes/human_input/human_input_node.py +++ b/api/core/workflow/nodes/human_input/human_input_node.py @@ -1,12 +1,42 @@ -from collections.abc import Mapping -from typing import Any +import json +import logging +from collections.abc import Generator, Mapping, Sequence +from typing import TYPE_CHECKING, Any +from core.app.entities.app_invoke_entities import InvokeFrom +from core.repositories.human_input_repository import HumanInputFormRepositoryImpl from core.workflow.entities.pause_reason import HumanInputRequired from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus -from core.workflow.node_events import NodeRunResult, PauseRequestedEvent +from core.workflow.node_events import ( + HumanInputFormFilledEvent, + HumanInputFormTimeoutEvent, + NodeRunResult, + 
PauseRequestedEvent, +) +from core.workflow.node_events.base import NodeEventBase +from core.workflow.node_events.node import StreamCompletedEvent from core.workflow.nodes.base.node import Node +from core.workflow.repositories.human_input_form_repository import ( + FormCreateParams, + HumanInputFormEntity, + HumanInputFormRepository, +) +from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter +from extensions.ext_database import db +from libs.datetime_utils import naive_utc_now -from .entities import HumanInputNodeData +from .entities import DeliveryChannelConfig, HumanInputNodeData, apply_debug_email_recipient +from .enums import DeliveryMethodType, HumanInputFormStatus, PlaceholderType + +if TYPE_CHECKING: + from core.workflow.entities.graph_init_params import GraphInitParams + from core.workflow.runtime.graph_runtime_state import GraphRuntimeState + + +_SELECTED_BRANCH_KEY = "selected_branch" + + +logger = logging.getLogger(__name__) class HumanInputNode(Node[HumanInputNodeData]): @@ -17,7 +47,7 @@ class HumanInputNode(Node[HumanInputNodeData]): "edge_source_handle", "edgeSourceHandle", "source_handle", - "selected_branch", + _SELECTED_BRANCH_KEY, "selectedBranch", "branch", "branch_id", @@ -25,43 +55,37 @@ class HumanInputNode(Node[HumanInputNodeData]): "handle", ) + _node_data: HumanInputNodeData + _form_repository: HumanInputFormRepository + _OUTPUT_FIELD_ACTION_ID = "__action_id" + _OUTPUT_FIELD_RENDERED_CONTENT = "__rendered_content" + _TIMEOUT_HANDLE = _TIMEOUT_ACTION_ID = "__timeout" + + def __init__( + self, + id: str, + config: Mapping[str, Any], + graph_init_params: "GraphInitParams", + graph_runtime_state: "GraphRuntimeState", + form_repository: HumanInputFormRepository | None = None, + ) -> None: + super().__init__( + id=id, + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + if form_repository is None: + form_repository = HumanInputFormRepositoryImpl( + 
session_factory=db.engine, + tenant_id=self.tenant_id, + ) + self._form_repository = form_repository + @classmethod def version(cls) -> str: return "1" - def _run(self): # type: ignore[override] - if self._is_completion_ready(): - branch_handle = self._resolve_branch_selection() - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={}, - edge_source_handle=branch_handle or "source", - ) - - return self._pause_generator() - - def _pause_generator(self): - # TODO(QuantumGhost): yield a real form id. - yield PauseRequestedEvent(reason=HumanInputRequired(form_id="test_form_id", node_id=self.id)) - - def _is_completion_ready(self) -> bool: - """Determine whether all required inputs are satisfied.""" - - if not self.node_data.required_variables: - return False - - variable_pool = self.graph_runtime_state.variable_pool - - for selector_str in self.node_data.required_variables: - parts = selector_str.split(".") - if len(parts) != 2: - return False - segment = variable_pool.get(parts) - if segment is None: - return False - - return True - def _resolve_branch_selection(self) -> str | None: """Determine the branch handle selected by human input if available.""" @@ -108,3 +132,224 @@ class HumanInputNode(Node[HumanInputNodeData]): return candidate return None + + @property + def _workflow_execution_id(self) -> str: + workflow_exec_id = self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id + assert workflow_exec_id is not None + return workflow_exec_id + + def _form_to_pause_event(self, form_entity: HumanInputFormEntity): + required_event = self._human_input_required_event(form_entity) + pause_requested_event = PauseRequestedEvent(reason=required_event) + return pause_requested_event + + def resolve_default_values(self) -> Mapping[str, Any]: + variable_pool = self.graph_runtime_state.variable_pool + resolved_defaults: dict[str, Any] = {} + for input in self._node_data.inputs: + if (default_value := input.default) is None: + 
continue + if default_value.type == PlaceholderType.CONSTANT: + continue + resolved_value = variable_pool.get(default_value.selector) + if resolved_value is None: + # TODO: How should we handle this? + continue + resolved_defaults[input.output_variable_name] = ( + WorkflowRuntimeTypeConverter().value_to_json_encodable_recursive(resolved_value.value) + ) + + return resolved_defaults + + def _should_require_console_recipient(self) -> bool: + if self.invoke_from == InvokeFrom.DEBUGGER: + return True + if self.invoke_from == InvokeFrom.EXPLORE: + return self._node_data.is_webapp_enabled() + return False + + def _display_in_ui(self) -> bool: + if self.invoke_from == InvokeFrom.DEBUGGER: + return True + return self._node_data.is_webapp_enabled() + + def _effective_delivery_methods(self) -> Sequence[DeliveryChannelConfig]: + enabled_methods = [method for method in self._node_data.delivery_methods if method.enabled] + if self.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE}: + enabled_methods = [method for method in enabled_methods if method.type != DeliveryMethodType.WEBAPP] + return [ + apply_debug_email_recipient( + method, + enabled=self.invoke_from == InvokeFrom.DEBUGGER, + user_id=self.user_id or "", + ) + for method in enabled_methods + ] + + def _human_input_required_event(self, form_entity: HumanInputFormEntity) -> HumanInputRequired: + node_data = self._node_data + resolved_default_values = self.resolve_default_values() + display_in_ui = self._display_in_ui() + form_token = form_entity.web_app_token + if display_in_ui and form_token is None: + raise AssertionError("Form token should be available for UI execution.") + return HumanInputRequired( + form_id=form_entity.id, + form_content=form_entity.rendered_content, + inputs=node_data.inputs, + actions=node_data.user_actions, + display_in_ui=display_in_ui, + node_id=self.id, + node_title=node_data.title, + form_token=form_token, + resolved_default_values=resolved_default_values, + ) + + def _run(self) -> 
Generator[NodeEventBase, None, None]: + """ + Execute the human input node. + + This method will: + 1. Generate a unique form ID + 2. Create form content with variable substitution + 3. Create form in database + 4. Send form via configured delivery methods + 5. Suspend workflow execution + 6. Wait for form submission to resume + """ + repo = self._form_repository + form = repo.get_form(self._workflow_execution_id, self.id) + if form is None: + display_in_ui = self._display_in_ui() + params = FormCreateParams( + app_id=self.app_id, + workflow_execution_id=self._workflow_execution_id, + node_id=self.id, + form_config=self._node_data, + rendered_content=self.render_form_content_before_submission(), + delivery_methods=self._effective_delivery_methods(), + display_in_ui=display_in_ui, + resolved_default_values=self.resolve_default_values(), + console_recipient_required=self._should_require_console_recipient(), + console_creator_account_id=( + self.user_id if self.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE} else None + ), + backstage_recipient_required=True, + ) + form_entity = self._form_repository.create_form(params) + # Create human input required event + + logger.info( + "Human Input node suspended workflow for form. 
workflow_run_id=%s, node_id=%s, form_id=%s", + self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id, + self.id, + form_entity.id, + ) + yield self._form_to_pause_event(form_entity) + return + + if ( + form.status in {HumanInputFormStatus.TIMEOUT, HumanInputFormStatus.EXPIRED} + or form.expiration_time <= naive_utc_now() + ): + yield HumanInputFormTimeoutEvent( + node_title=self._node_data.title, + expiration_time=form.expiration_time, + ) + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + outputs={self._OUTPUT_FIELD_ACTION_ID: ""}, + edge_source_handle=self._TIMEOUT_HANDLE, + ) + ) + return + + if not form.submitted: + yield self._form_to_pause_event(form) + return + + selected_action_id = form.selected_action_id + if selected_action_id is None: + raise AssertionError(f"selected_action_id should not be None when form submitted, form_id={form.id}") + submitted_data = form.submitted_data or {} + outputs: dict[str, Any] = dict(submitted_data) + outputs[self._OUTPUT_FIELD_ACTION_ID] = selected_action_id + rendered_content = self.render_form_content_with_outputs( + form.rendered_content, + outputs, + self._node_data.outputs_field_names(), + ) + outputs[self._OUTPUT_FIELD_RENDERED_CONTENT] = rendered_content + + action_text = self._node_data.find_action_text(selected_action_id) + + yield HumanInputFormFilledEvent( + node_title=self._node_data.title, + rendered_content=rendered_content, + action_id=selected_action_id, + action_text=action_text, + ) + + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + outputs=outputs, + edge_source_handle=selected_action_id, + ) + ) + + def render_form_content_before_submission(self) -> str: + """ + Process form content by substituting variables. + + This method should: + 1. Parse the form_content markdown + 2. Substitute {{#node_name.var_name#}} with actual values + 3. 
Keep {{#$output.field_name#}} placeholders for form inputs + """ + rendered_form_content = self.graph_runtime_state.variable_pool.convert_template( + self._node_data.form_content, + ) + return rendered_form_content.markdown + + @staticmethod + def render_form_content_with_outputs( + form_content: str, + outputs: Mapping[str, Any], + field_names: Sequence[str], + ) -> str: + """ + Replace {{#$output.xxx#}} placeholders with submitted values. + """ + rendered_content = form_content + for field_name in field_names: + placeholder = "{{#$output." + field_name + "#}}" + value = outputs.get(field_name) + if value is None: + replacement = "" + elif isinstance(value, (dict, list)): + replacement = json.dumps(value, ensure_ascii=False) + else: + replacement = str(value) + rendered_content = rendered_content.replace(placeholder, replacement) + return rendered_content + + @classmethod + def _extract_variable_selector_to_variable_mapping( + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: Mapping[str, Any], + ) -> Mapping[str, Sequence[str]]: + """ + Extract variable selectors referenced in form content and input default values. + + This method should parse: + 1. Variables referenced in form_content ({{#node_name.var_name#}}) + 2. 
Variables referenced in input default values + """ + validated_node_data = HumanInputNodeData.model_validate(node_data) + return validated_node_data.extract_variable_selector_to_variable_mapping(node_id) diff --git a/api/core/workflow/nodes/knowledge_index/entities.py b/api/core/workflow/nodes/knowledge_index/entities.py index 3daca90b9b..bfeb9b5b79 100644 --- a/api/core/workflow/nodes/knowledge_index/entities.py +++ b/api/core/workflow/nodes/knowledge_index/entities.py @@ -158,3 +158,5 @@ class KnowledgeIndexNodeData(BaseNodeData): type: str = "knowledge-index" chunk_structure: str index_chunk_variable_selector: list[str] + indexing_technique: str | None = None + summary_index_setting: dict | None = None diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index 17ca4bef7b..b88c2d510f 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -1,9 +1,11 @@ +import concurrent.futures import datetime import logging import time from collections.abc import Mapping from typing import Any +from flask import current_app from sqlalchemy import func, select from core.app.entities.app_invoke_entities import InvokeFrom @@ -16,7 +18,9 @@ from core.workflow.nodes.base.node import Node from core.workflow.nodes.base.template import Template from core.workflow.runtime import VariablePool from extensions.ext_database import db -from models.dataset import Dataset, Document, DocumentSegment +from models.dataset import Dataset, Document, DocumentSegment, DocumentSegmentSummary +from services.summary_index_service import SummaryIndexService +from tasks.generate_summary_index_task import generate_summary_index_task from .entities import KnowledgeIndexNodeData from .exc import ( @@ -67,7 +71,20 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): # index knowledge try: if is_preview: - outputs = 
self._get_preview_output(node_data.chunk_structure, chunks) + # Preview mode: generate summaries for chunks directly without saving to database + # Format preview and generate summaries on-the-fly + # Get indexing_technique and summary_index_setting from node_data (workflow graph config) + # or fallback to dataset if not available in node_data + indexing_technique = node_data.indexing_technique or dataset.indexing_technique + summary_index_setting = node_data.summary_index_setting or dataset.summary_index_setting + + outputs = self._get_preview_output_with_summaries( + node_data.chunk_structure, + chunks, + dataset=dataset, + indexing_technique=indexing_technique, + summary_index_setting=summary_index_setting, + ) return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, @@ -148,6 +165,11 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): ) .scalar() ) + # Update need_summary based on dataset's summary_index_setting + if dataset.summary_index_setting and dataset.summary_index_setting.get("enable") is True: + document.need_summary = True + else: + document.need_summary = False db.session.add(document) # update document segment status db.session.query(DocumentSegment).where( @@ -163,6 +185,9 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): db.session.commit() + # Generate summary index if enabled + self._handle_summary_index_generation(dataset, document, variable_pool) + return { "dataset_id": ds_id_value, "dataset_name": dataset_name_value, @@ -173,9 +198,304 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): "display_status": "completed", } - def _get_preview_output(self, chunk_structure: str, chunks: Any) -> Mapping[str, Any]: + def _handle_summary_index_generation( + self, + dataset: Dataset, + document: Document, + variable_pool: VariablePool, + ) -> None: + """ + Handle summary index generation based on mode (debug/preview or production). 
+ + Args: + dataset: Dataset containing the document + document: Document to generate summaries for + variable_pool: Variable pool to check invoke_from + """ + # Only generate summary index for high_quality indexing technique + if dataset.indexing_technique != "high_quality": + return + + # Check if summary index is enabled + summary_index_setting = dataset.summary_index_setting + if not summary_index_setting or not summary_index_setting.get("enable"): + return + + # Skip qa_model documents + if document.doc_form == "qa_model": + return + + # Determine if in preview/debug mode + invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM]) + is_preview = invoke_from and invoke_from.value == InvokeFrom.DEBUGGER + + if is_preview: + try: + # Query segments that need summary generation + query = db.session.query(DocumentSegment).filter_by( + dataset_id=dataset.id, + document_id=document.id, + status="completed", + enabled=True, + ) + segments = query.all() + + if not segments: + logger.info("No segments found for document %s", document.id) + return + + # Filter segments based on mode + segments_to_process = [] + for segment in segments: + # Skip if summary already exists + existing_summary = ( + db.session.query(DocumentSegmentSummary) + .filter_by(chunk_id=segment.id, dataset_id=dataset.id, status="completed") + .first() + ) + if existing_summary: + continue + + # For parent-child mode, all segments are parent chunks, so process all + segments_to_process.append(segment) + + if not segments_to_process: + logger.info("No segments need summary generation for document %s", document.id) + return + + # Use ThreadPoolExecutor for concurrent generation + flask_app = current_app._get_current_object() # type: ignore + max_workers = min(10, len(segments_to_process)) # Limit to 10 workers + + def process_segment(segment: DocumentSegment) -> None: + """Process a single segment in a thread with Flask app context.""" + with flask_app.app_context(): + try: + 
SummaryIndexService.generate_and_vectorize_summary(segment, dataset, summary_index_setting) + except Exception: + logger.exception( + "Failed to generate summary for segment %s", + segment.id, + ) + # Continue processing other segments + + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = [executor.submit(process_segment, segment) for segment in segments_to_process] + # Wait for all tasks to complete + concurrent.futures.wait(futures) + + logger.info( + "Successfully generated summary index for %s segments in document %s", + len(segments_to_process), + document.id, + ) + except Exception: + logger.exception("Failed to generate summary index for document %s", document.id) + # Don't fail the entire indexing process if summary generation fails + else: + # Production mode: asynchronous generation + logger.info( + "Queuing summary index generation task for document %s (production mode)", + document.id, + ) + try: + generate_summary_index_task.delay(dataset.id, document.id, None) + logger.info("Summary index generation task queued for document %s", document.id) + except Exception: + logger.exception( + "Failed to queue summary index generation task for document %s", + document.id, + ) + # Don't fail the entire indexing process if task queuing fails + + def _get_preview_output_with_summaries( + self, + chunk_structure: str, + chunks: Any, + dataset: Dataset, + indexing_technique: str | None = None, + summary_index_setting: dict | None = None, + ) -> Mapping[str, Any]: + """ + Generate preview output with summaries for chunks in preview mode. + This method generates summaries on-the-fly without saving to database. 
+ + Args: + chunk_structure: Chunk structure type + chunks: Chunks to generate preview for + dataset: Dataset object (for tenant_id) + indexing_technique: Indexing technique from node config or dataset + summary_index_setting: Summary index setting from node config or dataset + """ index_processor = IndexProcessorFactory(chunk_structure).init_index_processor() - return index_processor.format_preview(chunks) + preview_output = index_processor.format_preview(chunks) + + # Check if summary index is enabled + if indexing_technique != "high_quality": + return preview_output + + if not summary_index_setting or not summary_index_setting.get("enable"): + return preview_output + + # Generate summaries for chunks + if "preview" in preview_output and isinstance(preview_output["preview"], list): + chunk_count = len(preview_output["preview"]) + logger.info( + "Generating summaries for %s chunks in preview mode (dataset: %s)", + chunk_count, + dataset.id, + ) + # Use ParagraphIndexProcessor's generate_summary method + from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor + + # Get Flask app for application context in worker threads + flask_app = None + try: + flask_app = current_app._get_current_object() # type: ignore + except RuntimeError: + logger.warning("No Flask application context available, summary generation may fail") + + def generate_summary_for_chunk(preview_item: dict) -> None: + """Generate summary for a single chunk.""" + if "content" in preview_item: + # Set Flask application context in worker thread + if flask_app: + with flask_app.app_context(): + summary, _ = ParagraphIndexProcessor.generate_summary( + tenant_id=dataset.tenant_id, + text=preview_item["content"], + summary_index_setting=summary_index_setting, + ) + if summary: + preview_item["summary"] = summary + else: + # Fallback: try without app context (may fail) + summary, _ = ParagraphIndexProcessor.generate_summary( + tenant_id=dataset.tenant_id, + 
text=preview_item["content"], + summary_index_setting=summary_index_setting, + ) + if summary: + preview_item["summary"] = summary + + # Generate summaries concurrently using ThreadPoolExecutor + # Set a reasonable timeout to prevent hanging (60 seconds per chunk, max 5 minutes total) + timeout_seconds = min(300, 60 * len(preview_output["preview"])) + errors: list[Exception] = [] + + with concurrent.futures.ThreadPoolExecutor(max_workers=min(10, len(preview_output["preview"]))) as executor: + futures = [ + executor.submit(generate_summary_for_chunk, preview_item) + for preview_item in preview_output["preview"] + ] + # Wait for all tasks to complete with timeout + done, not_done = concurrent.futures.wait(futures, timeout=timeout_seconds) + + # Cancel tasks that didn't complete in time + if not_done: + timeout_error_msg = ( + f"Summary generation timeout: {len(not_done)} chunks did not complete within {timeout_seconds}s" + ) + logger.warning("%s. Cancelling remaining tasks...", timeout_error_msg) + # In preview mode, timeout is also an error + errors.append(TimeoutError(timeout_error_msg)) + for future in not_done: + future.cancel() + # Wait a bit for cancellation to take effect + concurrent.futures.wait(not_done, timeout=5) + + # Collect exceptions from completed futures + for future in done: + try: + future.result() # This will raise any exception that occurred + except Exception as e: + logger.exception("Error in summary generation future") + errors.append(e) + + # In preview mode, if there are any errors, fail the request + if errors: + error_messages = [str(e) for e in errors] + error_summary = ( + f"Failed to generate summaries for {len(errors)} chunk(s). 
" + f"Errors: {'; '.join(error_messages[:3])}" # Show first 3 errors + ) + if len(errors) > 3: + error_summary += f" (and {len(errors) - 3} more)" + logger.error("Summary generation failed in preview mode: %s", error_summary) + raise KnowledgeIndexNodeError(error_summary) + + completed_count = sum(1 for item in preview_output["preview"] if item.get("summary") is not None) + logger.info( + "Completed summary generation for preview chunks: %s/%s succeeded", + completed_count, + len(preview_output["preview"]), + ) + + return preview_output + + def _get_preview_output( + self, + chunk_structure: str, + chunks: Any, + dataset: Dataset | None = None, + variable_pool: VariablePool | None = None, + ) -> Mapping[str, Any]: + index_processor = IndexProcessorFactory(chunk_structure).init_index_processor() + preview_output = index_processor.format_preview(chunks) + + # If dataset is provided, try to enrich preview with summaries + if dataset and variable_pool: + document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID]) + if document_id: + document = db.session.query(Document).filter_by(id=document_id.value).first() + if document: + # Query summaries for this document + summaries = ( + db.session.query(DocumentSegmentSummary) + .filter_by( + dataset_id=dataset.id, + document_id=document.id, + status="completed", + enabled=True, + ) + .all() + ) + + if summaries: + # Create a map of segment content to summary for matching + # Use content matching as chunks in preview might not be indexed yet + summary_by_content = {} + for summary in summaries: + segment = ( + db.session.query(DocumentSegment) + .filter_by(id=summary.chunk_id, dataset_id=dataset.id) + .first() + ) + if segment: + # Normalize content for matching (strip whitespace) + normalized_content = segment.content.strip() + summary_by_content[normalized_content] = summary.summary_content + + # Enrich preview with summaries by content matching + if "preview" in preview_output and 
isinstance(preview_output["preview"], list): + matched_count = 0 + for preview_item in preview_output["preview"]: + if "content" in preview_item: + # Normalize content for matching + normalized_chunk_content = preview_item["content"].strip() + if normalized_chunk_content in summary_by_content: + preview_item["summary"] = summary_by_content[normalized_chunk_content] + matched_count += 1 + + if matched_count > 0: + logger.info( + "Enriched preview with %s existing summaries (dataset: %s, document: %s)", + matched_count, + dataset.id, + document.id, + ) + + return preview_output @classmethod def version(cls) -> str: diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 8670a71aa3..3c4850ebac 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -419,6 +419,9 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD source["content"] = f"question:{segment.get_sign_content()} \nanswer:{segment.answer}" else: source["content"] = segment.get_sign_content() + # Add summary if available + if record.summary: + source["summary"] = record.summary retrieval_resource_list.append(source) if retrieval_resource_list: retrieval_resource_list = sorted( diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index dfb55dcd80..17d82c2118 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -685,6 +685,8 @@ class LLMNode(Node[LLMNodeData]): if "content" not in item: raise InvalidContextStructureError(f"Invalid context structure: {item}") + if item.get("summary"): + context_str += item["summary"] + "\n" context_str += item["content"] + "\n" retriever_resource = self._convert_to_original_retriever_resource(item) @@ -746,6 +748,7 @@ class LLMNode(Node[LLMNodeData]): 
page=metadata.get("page"), doc_metadata=metadata.get("doc_metadata"), files=context_dict.get("files"), + summary=context_dict.get("summary"), ) return source diff --git a/api/core/workflow/repositories/human_input_form_repository.py b/api/core/workflow/repositories/human_input_form_repository.py new file mode 100644 index 0000000000..efde59c6fd --- /dev/null +++ b/api/core/workflow/repositories/human_input_form_repository.py @@ -0,0 +1,152 @@ +import abc +import dataclasses +from collections.abc import Mapping, Sequence +from datetime import datetime +from typing import Any, Protocol + +from core.workflow.nodes.human_input.entities import DeliveryChannelConfig, HumanInputNodeData +from core.workflow.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus + + +class HumanInputError(Exception): + pass + + +class FormNotFoundError(HumanInputError): + pass + + +@dataclasses.dataclass +class FormCreateParams: + # app_id is the identifier for the app that the form belongs to. + # It is a string with uuid format. + app_id: str + # None when creating a delivery test form; set for runtime forms. + workflow_execution_id: str | None + + # node_id is the identifier for a specific + # node in the graph. + # + # TODO: for node inside loop / iteration, this would + # cause problems, as a single node may be executed multiple times. + node_id: str + + form_config: HumanInputNodeData + rendered_content: str + # Delivery methods already filtered by runtime context (invoke_from). + delivery_methods: Sequence[DeliveryChannelConfig] + # UI display flag computed by runtime context. + display_in_ui: bool + + # resolved_default_values saves the values for defaults with + # type = VARIABLE. + # + # For type = CONSTANT, the value is not stored inside `resolved_default_values` + resolved_default_values: Mapping[str, Any] + form_kind: HumanInputFormKind = HumanInputFormKind.RUNTIME + + # Force creating a console-only recipient for submission in Console. 
+ console_recipient_required: bool = False + console_creator_account_id: str | None = None + # Force creating a backstage recipient for submission in Console. + backstage_recipient_required: bool = False + + +class HumanInputFormEntity(abc.ABC): + @property + @abc.abstractmethod + def id(self) -> str: + """id returns the identifer of the form.""" + pass + + @property + @abc.abstractmethod + def web_app_token(self) -> str | None: + """web_app_token returns the token for submission inside webapp. + + For console/debug execution, this may point to the console submission token + if the form is configured to require console delivery. + """ + + # TODO: what if the users are allowed to add multiple + # webapp delivery? + pass + + @property + @abc.abstractmethod + def recipients(self) -> list["HumanInputFormRecipientEntity"]: ... + + @property + @abc.abstractmethod + def rendered_content(self) -> str: + """Rendered markdown content associated with the form.""" + ... + + @property + @abc.abstractmethod + def selected_action_id(self) -> str | None: + """Identifier of the selected user action if the form has been submitted.""" + ... + + @property + @abc.abstractmethod + def submitted_data(self) -> Mapping[str, Any] | None: + """Submitted form data if available.""" + ... + + @property + @abc.abstractmethod + def submitted(self) -> bool: + """Whether the form has been submitted.""" + ... + + @property + @abc.abstractmethod + def status(self) -> HumanInputFormStatus: + """Current status of the form.""" + ... + + @property + @abc.abstractmethod + def expiration_time(self) -> datetime: + """When the form expires.""" + ... + + +class HumanInputFormRecipientEntity(abc.ABC): + @property + @abc.abstractmethod + def id(self) -> str: + """id returns the identifer of this recipient.""" + ... + + @property + @abc.abstractmethod + def token(self) -> str: + """token returns a random string used to submit form""" + ... 
+ + +class HumanInputFormRepository(Protocol): + """ + Repository interface for HumanInputForm. + + This interface defines the contract for accessing and manipulating + HumanInputForm data, regardless of the underlying storage mechanism. + + Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id), + and other implementation details should be handled at the implementation level, not in + the core interface. This keeps the core domain model clean and independent of specific + application domains or deployment scenarios. + """ + + def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: + """Get the form created for a given human input node in a workflow execution. Returns + `None` if the form has not been created yet.""" + ... + + def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: + """ + Create a human input form from form definition. + """ + ... diff --git a/api/core/workflow/runtime/graph_runtime_state.py b/api/core/workflow/runtime/graph_runtime_state.py index 401cecc162..f79230217c 100644 --- a/api/core/workflow/runtime/graph_runtime_state.py +++ b/api/core/workflow/runtime/graph_runtime_state.py @@ -6,14 +6,18 @@ import threading from collections.abc import Mapping, Sequence from copy import deepcopy from dataclasses import dataclass -from typing import Any, Protocol +from typing import TYPE_CHECKING, Any, Protocol +from pydantic import BaseModel, Field from pydantic.json import pydantic_encoder from core.model_runtime.entities.llm_entities import LLMUsage -from core.workflow.entities.pause_reason import PauseReason +from core.workflow.enums import NodeState from core.workflow.runtime.variable_pool import VariablePool +if TYPE_CHECKING: + from core.workflow.entities.pause_reason import PauseReason + class ReadyQueueProtocol(Protocol): """Structural interface required from ready queue implementations.""" @@ -60,7 +64,7 @@ class GraphExecutionProtocol(Protocol): aborted: 
bool error: Exception | None exceptions_count: int - pause_reasons: list[PauseReason] + pause_reasons: Sequence[PauseReason] def start(self) -> None: """Transition execution into the running state.""" @@ -103,14 +107,33 @@ class ResponseStreamCoordinatorProtocol(Protocol): ... +class NodeProtocol(Protocol): + """Structural interface for graph nodes.""" + + id: str + state: NodeState + + +class EdgeProtocol(Protocol): + id: str + state: NodeState + + class GraphProtocol(Protocol): """Structural interface required from graph instances attached to the runtime state.""" - nodes: Mapping[str, object] - edges: Mapping[str, object] - root_node: object + nodes: Mapping[str, NodeProtocol] + edges: Mapping[str, EdgeProtocol] + root_node: NodeProtocol - def get_outgoing_edges(self, node_id: str) -> Sequence[object]: ... + def get_outgoing_edges(self, node_id: str) -> Sequence[EdgeProtocol]: ... + + +class _GraphStateSnapshot(BaseModel): + """Serializable graph state snapshot for node/edge states.""" + + nodes: dict[str, NodeState] = Field(default_factory=dict) + edges: dict[str, NodeState] = Field(default_factory=dict) @dataclass(slots=True) @@ -128,10 +151,20 @@ class _GraphRuntimeStateSnapshot: graph_execution_dump: str | None response_coordinator_dump: str | None paused_nodes: tuple[str, ...] + deferred_nodes: tuple[str, ...] + graph_node_states: dict[str, NodeState] + graph_edge_states: dict[str, NodeState] class GraphRuntimeState: - """Mutable runtime state shared across graph execution components.""" + """Mutable runtime state shared across graph execution components. + + `GraphRuntimeState` encapsulates the runtime state of workflow execution, + including scheduling details, variable values, and timing information. + + Values that are initialized prior to workflow execution and remain constant + throughout the execution should be part of `GraphInitParams` instead. 
+ """ def __init__( self, @@ -169,6 +202,16 @@ class GraphRuntimeState: self._pending_response_coordinator_dump: str | None = None self._pending_graph_execution_workflow_id: str | None = None self._paused_nodes: set[str] = set() + self._deferred_nodes: set[str] = set() + + # Node and edges states needed to be restored into + # graph object. + # + # These two fields are non-None only when resuming from a snapshot. + # Once the graph is attached, these two fields will be set to None. + self._pending_graph_node_states: dict[str, NodeState] | None = None + self._pending_graph_edge_states: dict[str, NodeState] | None = None + self.stop_event: threading.Event = threading.Event() if graph is not None: @@ -190,6 +233,7 @@ class GraphRuntimeState: if self._pending_response_coordinator_dump is not None and self._response_coordinator is not None: self._response_coordinator.loads(self._pending_response_coordinator_dump) self._pending_response_coordinator_dump = None + self._apply_pending_graph_state() def configure(self, *, graph: GraphProtocol | None = None) -> None: """Ensure core collaborators are initialized with the provided context.""" @@ -311,8 +355,13 @@ class GraphRuntimeState: "ready_queue": self.ready_queue.dumps(), "graph_execution": self.graph_execution.dumps(), "paused_nodes": list(self._paused_nodes), + "deferred_nodes": list(self._deferred_nodes), } + graph_state = self._snapshot_graph_state() + if graph_state is not None: + snapshot["graph_state"] = graph_state + if self._response_coordinator is not None and self._graph is not None: snapshot["response_coordinator"] = self._response_coordinator.dumps() @@ -346,6 +395,11 @@ class GraphRuntimeState: self._paused_nodes.add(node_id) + def get_paused_nodes(self) -> list[str]: + """Retrieve the list of paused nodes without mutating internal state.""" + + return list(self._paused_nodes) + def consume_paused_nodes(self) -> list[str]: """Retrieve and clear the list of paused nodes awaiting resume.""" @@ -353,6 +407,23 
@@ class GraphRuntimeState: self._paused_nodes.clear() return nodes + def register_deferred_node(self, node_id: str) -> None: + """Record a node that became ready during pause and should resume later.""" + + self._deferred_nodes.add(node_id) + + def get_deferred_nodes(self) -> list[str]: + """Retrieve deferred nodes without mutating internal state.""" + + return list(self._deferred_nodes) + + def consume_deferred_nodes(self) -> list[str]: + """Retrieve and clear deferred nodes awaiting resume.""" + + nodes = list(self._deferred_nodes) + self._deferred_nodes.clear() + return nodes + # ------------------------------------------------------------------ # Builders # ------------------------------------------------------------------ @@ -414,6 +485,10 @@ class GraphRuntimeState: graph_execution_payload = payload.get("graph_execution") response_payload = payload.get("response_coordinator") paused_nodes_payload = payload.get("paused_nodes", []) + deferred_nodes_payload = payload.get("deferred_nodes", []) + graph_state_payload = payload.get("graph_state", {}) or {} + graph_node_states = _coerce_graph_state_map(graph_state_payload, "nodes") + graph_edge_states = _coerce_graph_state_map(graph_state_payload, "edges") return _GraphRuntimeStateSnapshot( start_at=start_at, @@ -427,6 +502,9 @@ class GraphRuntimeState: graph_execution_dump=graph_execution_payload, response_coordinator_dump=response_payload, paused_nodes=tuple(map(str, paused_nodes_payload)), + deferred_nodes=tuple(map(str, deferred_nodes_payload)), + graph_node_states=graph_node_states, + graph_edge_states=graph_edge_states, ) def _apply_snapshot(self, snapshot: _GraphRuntimeStateSnapshot) -> None: @@ -442,6 +520,10 @@ class GraphRuntimeState: self._restore_graph_execution(snapshot.graph_execution_dump) self._restore_response_coordinator(snapshot.response_coordinator_dump) self._paused_nodes = set(snapshot.paused_nodes) + self._deferred_nodes = set(snapshot.deferred_nodes) + self._pending_graph_node_states = 
snapshot.graph_node_states or None + self._pending_graph_edge_states = snapshot.graph_edge_states or None + self._apply_pending_graph_state() def _restore_ready_queue(self, payload: str | None) -> None: if payload is not None: @@ -478,3 +560,68 @@ class GraphRuntimeState: self._pending_response_coordinator_dump = payload self._response_coordinator = None + + def _snapshot_graph_state(self) -> _GraphStateSnapshot: + graph = self._graph + if graph is None: + if self._pending_graph_node_states is None and self._pending_graph_edge_states is None: + return _GraphStateSnapshot() + return _GraphStateSnapshot( + nodes=self._pending_graph_node_states or {}, + edges=self._pending_graph_edge_states or {}, + ) + + nodes = graph.nodes + edges = graph.edges + if not isinstance(nodes, Mapping) or not isinstance(edges, Mapping): + return _GraphStateSnapshot() + + node_states = {} + for node_id, node in nodes.items(): + if not isinstance(node_id, str): + continue + node_states[node_id] = node.state + + edge_states = {} + for edge_id, edge in edges.items(): + if not isinstance(edge_id, str): + continue + edge_states[edge_id] = edge.state + + return _GraphStateSnapshot(nodes=node_states, edges=edge_states) + + def _apply_pending_graph_state(self) -> None: + if self._graph is None: + return + if self._pending_graph_node_states: + for node_id, state in self._pending_graph_node_states.items(): + node = self._graph.nodes.get(node_id) + if node is None: + continue + node.state = state + if self._pending_graph_edge_states: + for edge_id, state in self._pending_graph_edge_states.items(): + edge = self._graph.edges.get(edge_id) + if edge is None: + continue + edge.state = state + + self._pending_graph_node_states = None + self._pending_graph_edge_states = None + + +def _coerce_graph_state_map(payload: Any, key: str) -> dict[str, NodeState]: + if not isinstance(payload, Mapping): + return {} + raw_map = payload.get(key, {}) + if not isinstance(raw_map, Mapping): + return {} + result: 
dict[str, NodeState] = {} + for node_id, raw_state in raw_map.items(): + if not isinstance(node_id, str): + continue + try: + result[node_id] = NodeState(str(raw_state)) + except ValueError: + continue + return result diff --git a/api/core/workflow/workflow_type_encoder.py b/api/core/workflow/workflow_type_encoder.py index 5456043ccd..f1f549e1f8 100644 --- a/api/core/workflow/workflow_type_encoder.py +++ b/api/core/workflow/workflow_type_encoder.py @@ -15,12 +15,14 @@ class WorkflowRuntimeTypeConverter: def to_json_encodable(self, value: None) -> None: ... def to_json_encodable(self, value: Mapping[str, Any] | None) -> Mapping[str, Any] | None: - result = self._to_json_encodable_recursive(value) + """Convert runtime values to JSON-serializable structures.""" + + result = self.value_to_json_encodable_recursive(value) if isinstance(result, Mapping) or result is None: return result return {} - def _to_json_encodable_recursive(self, value: Any): + def value_to_json_encodable_recursive(self, value: Any): if value is None: return value if isinstance(value, (bool, int, str, float)): @@ -29,7 +31,7 @@ class WorkflowRuntimeTypeConverter: # Convert Decimal to float for JSON serialization return float(value) if isinstance(value, Segment): - return self._to_json_encodable_recursive(value.value) + return self.value_to_json_encodable_recursive(value.value) if isinstance(value, File): return value.to_dict() if isinstance(value, BaseModel): @@ -37,11 +39,11 @@ class WorkflowRuntimeTypeConverter: if isinstance(value, dict): res = {} for k, v in value.items(): - res[k] = self._to_json_encodable_recursive(v) + res[k] = self.value_to_json_encodable_recursive(v) return res if isinstance(value, list): res_list = [] for item in value: - res_list.append(self._to_json_encodable_recursive(item)) + res_list.append(self.value_to_json_encodable_recursive(item)) return res_list return value diff --git a/api/docker/entrypoint.sh b/api/docker/entrypoint.sh index c0279f893b..03e6cbda68 100755 --- 
a/api/docker/entrypoint.sh +++ b/api/docker/entrypoint.sh @@ -35,10 +35,10 @@ if [[ "${MODE}" == "worker" ]]; then if [[ -z "${CELERY_QUEUES}" ]]; then if [[ "${EDITION}" == "CLOUD" ]]; then # Cloud edition: separate queues for dataset and trigger tasks - DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention" + DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution" else # Community edition (SELF_HOSTED): dataset, pipeline and workflow have separate queues - DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention" + DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution" fi else DEFAULT_QUEUES="${CELERY_QUEUES}" @@ -102,7 +102,7 @@ elif [[ "${MODE}" == "job" ]]; then fi echo "Running Flask job command: flask $*" - + # Temporarily disable exit on error to capture exit code set +e flask "$@" diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index 08cf96c1c1..aa9723f375 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -102,6 +102,8 @@ def init_app(app: DifyApp) -> Celery: imports = [ "tasks.async_workflow_tasks", # trigger workers 
"tasks.trigger_processing_tasks", # async trigger processing + "tasks.generate_summary_index_task", # summary index generation + "tasks.regenerate_summary_index_task", # summary index regeneration ] day = dify_config.CELERY_BEAT_SCHEDULER_TIME @@ -149,6 +151,12 @@ def init_app(app: DifyApp) -> Celery: "task": "schedule.queue_monitor_task.queue_monitor_task", "schedule": timedelta(minutes=dify_config.QUEUE_MONITOR_INTERVAL or 30), } + if dify_config.ENABLE_HUMAN_INPUT_TIMEOUT_TASK: + imports.append("tasks.human_input_timeout_tasks") + beat_schedule["human_input_form_timeout"] = { + "task": "human_input_form_timeout.check_and_resume", + "schedule": timedelta(minutes=dify_config.HUMAN_INPUT_TIMEOUT_TASK_INTERVAL), + } if dify_config.ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK and dify_config.MARKETPLACE_ENABLED: imports.append("schedule.check_upgradable_plugin_task") imports.append("tasks.process_tenant_plugin_autoupgrade_check_task") diff --git a/api/extensions/ext_redis.py b/api/extensions/ext_redis.py index 5e75bc36b0..0797a3cb98 100644 --- a/api/extensions/ext_redis.py +++ b/api/extensions/ext_redis.py @@ -8,12 +8,16 @@ from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, Union import redis from redis import RedisError from redis.cache import CacheConfig +from redis.client import PubSub from redis.cluster import ClusterNode, RedisCluster from redis.connection import Connection, SSLConnection from redis.sentinel import Sentinel from configs import dify_config from dify_app import DifyApp +from libs.broadcast_channel.channel import BroadcastChannel as BroadcastChannelProtocol +from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel +from libs.broadcast_channel.redis.sharded_channel import ShardedRedisBroadcastChannel if TYPE_CHECKING: from redis.lock import Lock @@ -106,6 +110,7 @@ class RedisClientWrapper: def zremrangebyscore(self, name: str | bytes, min: float | str, max: float | str) -> Any: ... 
def zcard(self, name: str | bytes) -> Any: ... def getdel(self, name: str | bytes) -> Any: ... + def pubsub(self) -> PubSub: ... def __getattr__(self, item: str) -> Any: if self._client is None: @@ -114,6 +119,7 @@ class RedisClientWrapper: redis_client: RedisClientWrapper = RedisClientWrapper() +pubsub_redis_client: RedisClientWrapper = RedisClientWrapper() def _get_ssl_configuration() -> tuple[type[Union[Connection, SSLConnection]], dict[str, Any]]: @@ -226,6 +232,12 @@ def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis return client +def _create_pubsub_client(pubsub_url: str, use_clusters: bool) -> Union[redis.Redis, RedisCluster]: + if use_clusters: + return RedisCluster.from_url(pubsub_url) + return redis.Redis.from_url(pubsub_url) + + def init_app(app: DifyApp): """Initialize Redis client and attach it to the app.""" global redis_client @@ -244,6 +256,24 @@ def init_app(app: DifyApp): redis_client.initialize(client) app.extensions["redis"] = redis_client + pubsub_client = client + if dify_config.normalized_pubsub_redis_url: + pubsub_client = _create_pubsub_client( + dify_config.normalized_pubsub_redis_url, dify_config.PUBSUB_REDIS_USE_CLUSTERS + ) + pubsub_redis_client.initialize(pubsub_client) + + +def get_pubsub_redis_client() -> RedisClientWrapper: + return pubsub_redis_client + + +def get_pubsub_broadcast_channel() -> BroadcastChannelProtocol: + redis_conn = get_pubsub_redis_client() + if dify_config.PUBSUB_REDIS_CHANNEL_TYPE == "sharded": + return ShardedRedisBroadcastChannel(redis_conn) # pyright: ignore[reportArgumentType] + return RedisBroadcastChannel(redis_conn) # pyright: ignore[reportArgumentType] + P = ParamSpec("P") R = TypeVar("R") diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py index f67723630b..817c8b0448 100644 --- 
a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py @@ -13,6 +13,7 @@ from typing import Any from sqlalchemy.orm import sessionmaker +from core.workflow.enums import WorkflowNodeExecutionStatus from extensions.logstore.aliyun_logstore import AliyunLogStore from extensions.logstore.repositories import safe_float, safe_int from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value @@ -207,8 +208,10 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep reverse=True, ) - if deduplicated_results: - return _dict_to_workflow_node_execution_model(deduplicated_results[0]) + for row in deduplicated_results: + model = _dict_to_workflow_node_execution_model(row) + if model.status != WorkflowNodeExecutionStatus.PAUSED: + return model return None @@ -309,6 +312,8 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep if model and model.id: # Ensure model is valid models.append(model) + models = [model for model in models if model.status != WorkflowNodeExecutionStatus.PAUSED] + # Sort by index DESC for trace visualization models.sort(key=lambda x: x.index, reverse=True) diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index d8ae0ad8b8..cda46f2339 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -192,6 +192,7 @@ class StatusCount(ResponseModel): success: int failed: int partial_success: int + paused: int class ModelConfig(ResponseModel): diff --git a/api/fields/dataset_fields.py b/api/fields/dataset_fields.py index 1e5ec7d200..ff6578098b 100644 --- a/api/fields/dataset_fields.py +++ b/api/fields/dataset_fields.py @@ -39,6 +39,14 @@ dataset_retrieval_model_fields = { "score_threshold_enabled": fields.Boolean, "score_threshold": fields.Float, } + +dataset_summary_index_fields = { + "enable": 
fields.Boolean, + "model_name": fields.String, + "model_provider_name": fields.String, + "summary_prompt": fields.String, +} + external_retrieval_model_fields = { "top_k": fields.Integer, "score_threshold": fields.Float, @@ -83,6 +91,7 @@ dataset_detail_fields = { "embedding_model_provider": fields.String, "embedding_available": fields.Boolean, "retrieval_model_dict": fields.Nested(dataset_retrieval_model_fields), + "summary_index_setting": fields.Nested(dataset_summary_index_fields), "tags": fields.List(fields.Nested(tag_fields)), "doc_form": fields.String, "external_knowledge_info": fields.Nested(external_knowledge_info_fields), diff --git a/api/fields/document_fields.py b/api/fields/document_fields.py index 9be59f7454..35a2a04f3e 100644 --- a/api/fields/document_fields.py +++ b/api/fields/document_fields.py @@ -33,6 +33,11 @@ document_fields = { "hit_count": fields.Integer, "doc_form": fields.String, "doc_metadata": fields.List(fields.Nested(document_metadata_fields), attribute="doc_metadata_details"), + # Summary index generation status: + # "SUMMARIZING" (when task is queued and generating) + "summary_index_status": fields.String, + # Whether this document needs summary index generation + "need_summary": fields.Boolean, } document_with_segments_fields = { @@ -60,6 +65,10 @@ document_with_segments_fields = { "completed_segments": fields.Integer, "total_segments": fields.Integer, "doc_metadata": fields.List(fields.Nested(document_metadata_fields), attribute="doc_metadata_details"), + # Summary index generation status: + # "SUMMARIZING" (when task is queued and generating) + "summary_index_status": fields.String, + "need_summary": fields.Boolean, # Whether this document needs summary index generation } dataset_and_document_fields = { diff --git a/api/fields/hit_testing_fields.py b/api/fields/hit_testing_fields.py index e70f9fa722..0b54992835 100644 --- a/api/fields/hit_testing_fields.py +++ b/api/fields/hit_testing_fields.py @@ -58,4 +58,5 @@ 
hit_testing_record_fields = { "score": fields.Float, "tsne_position": fields.Raw, "files": fields.List(fields.Nested(files_fields)), + "summary": fields.String, # Summary content if retrieved via summary index } diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py index c81e482f73..77b26a7423 100644 --- a/api/fields/message_fields.py +++ b/api/fields/message_fields.py @@ -6,6 +6,7 @@ from uuid import uuid4 from pydantic import BaseModel, ConfigDict, Field, field_validator +from core.entities.execution_extra_content import ExecutionExtraContentDomainModel from core.file import File from fields.conversation_fields import AgentThought, JSONValue, MessageFile @@ -36,6 +37,7 @@ class RetrieverResource(ResponseModel): segment_position: int | None = None index_node_hash: str | None = None content: str | None = None + summary: str | None = None created_at: int | None = None @field_validator("created_at", mode="before") @@ -60,6 +62,7 @@ class MessageListItem(ResponseModel): message_files: list[MessageFile] status: str error: str | None = None + extra_contents: list[ExecutionExtraContentDomainModel] @field_validator("inputs", mode="before") @classmethod diff --git a/api/fields/segment_fields.py b/api/fields/segment_fields.py index 56d6b68378..2ce9fb154c 100644 --- a/api/fields/segment_fields.py +++ b/api/fields/segment_fields.py @@ -49,4 +49,5 @@ segment_fields = { "stopped_at": TimestampField, "child_chunks": fields.List(fields.Nested(child_chunk_fields)), "attachments": fields.List(fields.Nested(attachment_fields)), + "summary": fields.String, # Summary content for the segment } diff --git a/api/libs/broadcast_channel/redis/_subscription.py b/api/libs/broadcast_channel/redis/_subscription.py index 7d4b8e63ca..fa2be421a1 100644 --- a/api/libs/broadcast_channel/redis/_subscription.py +++ b/api/libs/broadcast_channel/redis/_subscription.py @@ -162,7 +162,7 @@ class RedisSubscriptionBase(Subscription): self._start_if_needed() return 
iter(self._message_iterator()) - def receive(self, timeout: float | None = None) -> bytes | None: + def receive(self, timeout: float | None = 0.1) -> bytes | None: """Receive the next message from the subscription.""" if self._closed.is_set(): raise SubscriptionClosedError(f"The Redis {self._get_subscription_type()} subscription is closed") diff --git a/api/libs/broadcast_channel/redis/sharded_channel.py b/api/libs/broadcast_channel/redis/sharded_channel.py index d190c51bbc..9e8ab90e8e 100644 --- a/api/libs/broadcast_channel/redis/sharded_channel.py +++ b/api/libs/broadcast_channel/redis/sharded_channel.py @@ -61,7 +61,14 @@ class _RedisShardedSubscription(RedisSubscriptionBase): def _get_message(self) -> dict | None: assert self._pubsub is not None - return self._pubsub.get_sharded_message(ignore_subscribe_messages=True, timeout=0.1) # type: ignore[attr-defined] + # NOTE(QuantumGhost): this is an issue in + # upstream code. If Sharded PubSub is used with Cluster, the + # `ClusterPubSub.get_sharded_message` will return `None` regardless of + # message['type']. + # + # Since we have already filtered at the caller's site, we can safely set + # `ignore_subscribe_messages=False`. + return self._pubsub.get_sharded_message(ignore_subscribe_messages=False, timeout=0.1) # type: ignore[attr-defined] def _get_message_type(self) -> str: return "smessage" diff --git a/api/libs/email_template_renderer.py b/api/libs/email_template_renderer.py new file mode 100644 index 0000000000..98ea30ab46 --- /dev/null +++ b/api/libs/email_template_renderer.py @@ -0,0 +1,49 @@ +""" +Email template rendering helpers with configurable safety modes. 
+""" + +import time +from collections.abc import Mapping +from typing import Any + +from flask import render_template_string +from jinja2.runtime import Context +from jinja2.sandbox import ImmutableSandboxedEnvironment + +from configs import dify_config +from configs.feature import TemplateMode + + +class SandboxedEnvironment(ImmutableSandboxedEnvironment): + """Sandboxed environment with execution timeout.""" + + def __init__(self, timeout: int, *args: Any, **kwargs: Any): + self._deadline = time.time() + timeout if timeout else None + super().__init__(*args, **kwargs) + + def call(self, context: Context, obj: Any, *args: Any, **kwargs: Any) -> Any: + if self._deadline is not None and time.time() > self._deadline: + raise TimeoutError("Template rendering timeout") + return super().call(context, obj, *args, **kwargs) + + +def render_email_template(template: str, substitutions: Mapping[str, str]) -> str: + """ + Render email template content according to the configured template mode. + + In unsafe mode, Jinja expressions are evaluated directly. + In sandbox mode, a sandboxed environment with timeout is used. + In disabled mode, the template is returned without rendering. 
+ """ + mode = dify_config.MAIL_TEMPLATING_MODE + timeout = dify_config.MAIL_TEMPLATING_TIMEOUT + + if mode == TemplateMode.UNSAFE: + return render_template_string(template, **substitutions) + if mode == TemplateMode.SANDBOX: + env = SandboxedEnvironment(timeout=timeout) + tmpl = env.from_string(template) + return tmpl.render(substitutions) + if mode == TemplateMode.DISABLED: + return template + raise ValueError(f"Unsupported mail templating mode: {mode}") diff --git a/api/libs/flask_utils.py b/api/libs/flask_utils.py index beade7eb25..e45c8fe319 100644 --- a/api/libs/flask_utils.py +++ b/api/libs/flask_utils.py @@ -1,12 +1,15 @@ import contextvars from collections.abc import Iterator from contextlib import contextmanager -from typing import TypeVar +from typing import TYPE_CHECKING, TypeVar from flask import Flask, g T = TypeVar("T") +if TYPE_CHECKING: + from models import Account, EndUser + @contextmanager def preserve_flask_contexts( @@ -64,3 +67,7 @@ def preserve_flask_contexts( finally: # Any cleanup can be added here if needed pass + + +def set_login_user(user: "Account | EndUser"): + g._login_user = user diff --git a/api/libs/gmpy2_pkcs10aep_cipher.py b/api/libs/gmpy2_pkcs10aep_cipher.py index 23eb8dca05..ef26699fb3 100644 --- a/api/libs/gmpy2_pkcs10aep_cipher.py +++ b/api/libs/gmpy2_pkcs10aep_cipher.py @@ -136,7 +136,7 @@ class PKCS1OAepCipher: # Step 3a (OS2IP) em_int = bytes_to_long(em) # Step 3b (RSAEP) - m_int = gmpy2.powmod(em_int, self._key.e, self._key.n) + m_int: int = gmpy2.powmod(em_int, self._key.e, self._key.n) # type: ignore[attr-defined] # Step 3c (I2OSP) c = long_to_bytes(m_int, k) return c @@ -169,7 +169,7 @@ class PKCS1OAepCipher: ct_int = bytes_to_long(ciphertext) # Step 2b (RSADP) # m_int = self._key._decrypt(ct_int) - m_int = gmpy2.powmod(ct_int, self._key.d, self._key.n) + m_int: int = gmpy2.powmod(ct_int, self._key.d, self._key.n) # type: ignore[attr-defined] # Complete step 2c (I2OSP) em = long_to_bytes(m_int, k) # Step 3a diff --git 
a/api/libs/helper.py b/api/libs/helper.py index 07c4823727..fb577b9c99 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -7,10 +7,10 @@ import struct import subprocess import time import uuid -from collections.abc import Generator, Mapping +from collections.abc import Callable, Generator, Mapping from datetime import datetime from hashlib import sha256 -from typing import TYPE_CHECKING, Annotated, Any, Optional, Union, cast +from typing import TYPE_CHECKING, Annotated, Any, Optional, Protocol, Union, cast from uuid import UUID from zoneinfo import available_timezones @@ -126,6 +126,13 @@ class TimestampField(fields.Raw): return int(value.timestamp()) +class OptionalTimestampField(fields.Raw): + def format(self, value) -> int | None: + if value is None: + return None + return int(value.timestamp()) + + def email(email): # Define a regex pattern for email addresses pattern = r"^[\w\.!#$%&'*+\-/=?^_`{|}~]+@([\w-]+\.)+[\w-]{2,}$" @@ -237,6 +244,26 @@ def convert_datetime_to_date(field, target_timezone: str = ":tz"): def generate_string(n): + """ + Generates a cryptographically secure random string of the specified length. + + This function uses a cryptographically secure pseudorandom number generator (CSPRNG) + to create a string composed of ASCII letters (both uppercase and lowercase) and digits. + + Each character in the generated string provides approximately 5.95 bits of entropy + (log2(62)). To ensure a minimum of 128 bits of entropy for security purposes, the + length of the string (`n`) should be at least 22 characters. + + Args: + n (int): The length of the random string to generate. For secure usage, + `n` should be 22 or greater. + + Returns: + str: A random string of length `n` composed of ASCII letters and digits. + + Note: + This function is suitable for generating credentials or other secure tokens. 
+ """ letters_digits = string.ascii_letters + string.digits result = "" for _ in range(n): @@ -405,11 +432,35 @@ class TokenManager: return f"{token_type}:account:{account_id}" +class _RateLimiterRedisClient(Protocol): + def zadd(self, name: str | bytes, mapping: dict[str | bytes | int | float, float | int | str | bytes]) -> int: ... + + def zremrangebyscore(self, name: str | bytes, min: str | float, max: str | float) -> int: ... + + def zcard(self, name: str | bytes) -> int: ... + + def expire(self, name: str | bytes, time: int) -> bool: ... + + +def _default_rate_limit_member_factory() -> str: + current_time = int(time.time()) + return f"{current_time}:{secrets.token_urlsafe(nbytes=8)}" + + class RateLimiter: - def __init__(self, prefix: str, max_attempts: int, time_window: int): + def __init__( + self, + prefix: str, + max_attempts: int, + time_window: int, + member_factory: Callable[[], str] = _default_rate_limit_member_factory, + redis_client: _RateLimiterRedisClient = redis_client, + ): self.prefix = prefix self.max_attempts = max_attempts self.time_window = time_window + self._member_factory = member_factory + self._redis_client = redis_client def _get_key(self, email: str) -> str: return f"{self.prefix}:{email}" @@ -419,8 +470,8 @@ class RateLimiter: current_time = int(time.time()) window_start_time = current_time - self.time_window - redis_client.zremrangebyscore(key, "-inf", window_start_time) - attempts = redis_client.zcard(key) + self._redis_client.zremrangebyscore(key, "-inf", window_start_time) + attempts = self._redis_client.zcard(key) if attempts and int(attempts) >= self.max_attempts: return True @@ -428,7 +479,8 @@ class RateLimiter: def increment_rate_limit(self, email: str): key = self._get_key(email) + member = self._member_factory() current_time = int(time.time()) - redis_client.zadd(key, {current_time: current_time}) - redis_client.expire(key, self.time_window * 2) + self._redis_client.zadd(key, {member: current_time}) + 
self._redis_client.expire(key, self.time_window * 2) diff --git a/api/migrations/versions/2026_01_27_1815-788d3099ae3a_add_summary_index_feature.py b/api/migrations/versions/2026_01_27_1815-788d3099ae3a_add_summary_index_feature.py new file mode 100644 index 0000000000..c6c72859dc --- /dev/null +++ b/api/migrations/versions/2026_01_27_1815-788d3099ae3a_add_summary_index_feature.py @@ -0,0 +1,107 @@ +"""add summary index feature + +Revision ID: 788d3099ae3a +Revises: 9d77545f524e +Create Date: 2026-01-27 18:15:45.277928 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + +def _is_pg(conn): + return conn.dialect.name == "postgresql" + +# revision identifiers, used by Alembic. +revision = '788d3099ae3a' +down_revision = '9d77545f524e' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + conn = op.get_bind() + if _is_pg(conn): + op.create_table('document_segment_summaries', + sa.Column('id', models.types.StringUUID(), nullable=False), + sa.Column('dataset_id', models.types.StringUUID(), nullable=False), + sa.Column('document_id', models.types.StringUUID(), nullable=False), + sa.Column('chunk_id', models.types.StringUUID(), nullable=False), + sa.Column('summary_content', models.types.LongText(), nullable=True), + sa.Column('summary_index_node_id', sa.String(length=255), nullable=True), + sa.Column('summary_index_node_hash', sa.String(length=255), nullable=True), + sa.Column('tokens', sa.Integer(), nullable=True), + sa.Column('status', sa.String(length=32), server_default=sa.text("'generating'"), nullable=False), + sa.Column('error', models.types.LongText(), nullable=True), + sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False), + sa.Column('disabled_at', sa.DateTime(), nullable=True), + sa.Column('disabled_by', models.types.StringUUID(), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), 
nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', name='document_segment_summaries_pkey') + ) + with op.batch_alter_table('document_segment_summaries', schema=None) as batch_op: + batch_op.create_index('document_segment_summaries_chunk_id_idx', ['chunk_id'], unique=False) + batch_op.create_index('document_segment_summaries_dataset_id_idx', ['dataset_id'], unique=False) + batch_op.create_index('document_segment_summaries_document_id_idx', ['document_id'], unique=False) + batch_op.create_index('document_segment_summaries_status_idx', ['status'], unique=False) + + with op.batch_alter_table('datasets', schema=None) as batch_op: + batch_op.add_column(sa.Column('summary_index_setting', models.types.AdjustedJSON(), nullable=True)) + + with op.batch_alter_table('documents', schema=None) as batch_op: + batch_op.add_column(sa.Column('need_summary', sa.Boolean(), server_default=sa.text('false'), nullable=False)) + else: + # MySQL: Use compatible syntax + op.create_table( + 'document_segment_summaries', + sa.Column('id', models.types.StringUUID(), nullable=False), + sa.Column('dataset_id', models.types.StringUUID(), nullable=False), + sa.Column('document_id', models.types.StringUUID(), nullable=False), + sa.Column('chunk_id', models.types.StringUUID(), nullable=False), + sa.Column('summary_content', models.types.LongText(), nullable=True), + sa.Column('summary_index_node_id', sa.String(length=255), nullable=True), + sa.Column('summary_index_node_hash', sa.String(length=255), nullable=True), + sa.Column('tokens', sa.Integer(), nullable=True), + sa.Column('status', sa.String(length=32), server_default=sa.text("'generating'"), nullable=False), + sa.Column('error', models.types.LongText(), nullable=True), + sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False), + sa.Column('disabled_at', sa.DateTime(), nullable=True), + sa.Column('disabled_by', 
models.types.StringUUID(), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', name='document_segment_summaries_pkey'), + ) + with op.batch_alter_table('document_segment_summaries', schema=None) as batch_op: + batch_op.create_index('document_segment_summaries_chunk_id_idx', ['chunk_id'], unique=False) + batch_op.create_index('document_segment_summaries_dataset_id_idx', ['dataset_id'], unique=False) + batch_op.create_index('document_segment_summaries_document_id_idx', ['document_id'], unique=False) + batch_op.create_index('document_segment_summaries_status_idx', ['status'], unique=False) + + with op.batch_alter_table('datasets', schema=None) as batch_op: + batch_op.add_column(sa.Column('summary_index_setting', models.types.AdjustedJSON(), nullable=True)) + + with op.batch_alter_table('documents', schema=None) as batch_op: + batch_op.add_column(sa.Column('need_summary', sa.Boolean(), server_default=sa.text('false'), nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + + with op.batch_alter_table('documents', schema=None) as batch_op: + batch_op.drop_column('need_summary') + + with op.batch_alter_table('datasets', schema=None) as batch_op: + batch_op.drop_column('summary_index_setting') + + with op.batch_alter_table('document_segment_summaries', schema=None) as batch_op: + batch_op.drop_index('document_segment_summaries_status_idx') + batch_op.drop_index('document_segment_summaries_document_id_idx') + batch_op.drop_index('document_segment_summaries_dataset_id_idx') + batch_op.drop_index('document_segment_summaries_chunk_id_idx') + + op.drop_table('document_segment_summaries') + # ### end Alembic commands ### diff --git a/api/migrations/versions/2026_01_29_1415-e8c3b3c46151_add_human_input_related_db_models.py b/api/migrations/versions/2026_01_29_1415-e8c3b3c46151_add_human_input_related_db_models.py new file mode 100644 index 0000000000..a1546ef940 --- /dev/null +++ b/api/migrations/versions/2026_01_29_1415-e8c3b3c46151_add_human_input_related_db_models.py @@ -0,0 +1,99 @@ +"""Add human input related db models + +Revision ID: e8c3b3c46151 +Revises: 788d3099ae3a +Create Date: 2026-01-29 14:15:23.081903 + +""" + +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. 
+revision = "e8c3b3c46151" +down_revision = "788d3099ae3a" +branch_labels = None +depends_on = None + + +def upgrade(): + op.create_table( + "execution_extra_contents", + sa.Column("id", models.types.StringUUID(), nullable=False), + sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), + sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), + + sa.Column("type", sa.String(length=30), nullable=False), + sa.Column("workflow_run_id", models.types.StringUUID(), nullable=False), + sa.Column("message_id", models.types.StringUUID(), nullable=True), + sa.Column("form_id", models.types.StringUUID(), nullable=True), + sa.PrimaryKeyConstraint("id", name=op.f("execution_extra_contents_pkey")), + ) + with op.batch_alter_table("execution_extra_contents", schema=None) as batch_op: + batch_op.create_index(batch_op.f("execution_extra_contents_message_id_idx"), ["message_id"], unique=False) + batch_op.create_index( + batch_op.f("execution_extra_contents_workflow_run_id_idx"), ["workflow_run_id"], unique=False + ) + + op.create_table( + "human_input_form_deliveries", + sa.Column("id", models.types.StringUUID(), nullable=False), + sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), + sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), + + sa.Column("form_id", models.types.StringUUID(), nullable=False), + sa.Column("delivery_method_type", sa.String(length=20), nullable=False), + sa.Column("delivery_config_id", models.types.StringUUID(), nullable=True), + sa.Column("channel_payload", sa.Text(), nullable=False), + sa.PrimaryKeyConstraint("id", name=op.f("human_input_form_deliveries_pkey")), + ) + + op.create_table( + "human_input_form_recipients", + sa.Column("id", models.types.StringUUID(), nullable=False), + sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), 
nullable=False), + sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), + + sa.Column("form_id", models.types.StringUUID(), nullable=False), + sa.Column("delivery_id", models.types.StringUUID(), nullable=False), + sa.Column("recipient_type", sa.String(length=20), nullable=False), + sa.Column("recipient_payload", sa.Text(), nullable=False), + sa.Column("access_token", sa.VARCHAR(length=32), nullable=False), + sa.PrimaryKeyConstraint("id", name=op.f("human_input_form_recipients_pkey")), + ) + with op.batch_alter_table('human_input_form_recipients', schema=None) as batch_op: + batch_op.create_unique_constraint(batch_op.f('human_input_form_recipients_access_token_key'), ['access_token']) + + op.create_table( + "human_input_forms", + sa.Column("id", models.types.StringUUID(), nullable=False), + sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), + sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), + + sa.Column("tenant_id", models.types.StringUUID(), nullable=False), + sa.Column("app_id", models.types.StringUUID(), nullable=False), + sa.Column("workflow_run_id", models.types.StringUUID(), nullable=True), + sa.Column("form_kind", sa.String(length=20), nullable=False), + sa.Column("node_id", sa.String(length=60), nullable=False), + sa.Column("form_definition", sa.Text(), nullable=False), + sa.Column("rendered_content", sa.Text(), nullable=False), + sa.Column("status", sa.String(length=20), nullable=False), + sa.Column("expiration_time", sa.DateTime(), nullable=False), + sa.Column("selected_action_id", sa.String(length=200), nullable=True), + sa.Column("submitted_data", sa.Text(), nullable=True), + sa.Column("submitted_at", sa.DateTime(), nullable=True), + sa.Column("submission_user_id", models.types.StringUUID(), nullable=True), + sa.Column("submission_end_user_id", models.types.StringUUID(), nullable=True), + 
sa.Column("completed_by_recipient_id", models.types.StringUUID(), nullable=True), + + sa.PrimaryKeyConstraint("id", name=op.f("human_input_forms_pkey")), + ) + + +def downgrade(): + op.drop_table("human_input_forms") + op.drop_table("human_input_form_recipients") + op.drop_table("human_input_form_deliveries") + op.drop_table("execution_extra_contents") diff --git a/api/models/__init__.py b/api/models/__init__.py index 74b33130ef..1d5d604ba7 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -34,6 +34,8 @@ from .enums import ( WorkflowRunTriggeredFrom, WorkflowTriggerStatus, ) +from .execution_extra_content import ExecutionExtraContent, HumanInputContent +from .human_input import HumanInputForm from .model import ( AccountTrialAppRecord, ApiRequest, @@ -155,9 +157,12 @@ __all__ = [ "DocumentSegment", "Embedding", "EndUser", + "ExecutionExtraContent", "ExporleBanner", "ExternalKnowledgeApis", "ExternalKnowledgeBindings", + "HumanInputContent", + "HumanInputForm", "IconType", "InstalledApp", "InvitationCode", diff --git a/api/models/base.py b/api/models/base.py index c8a5e20f25..aa93d31199 100644 --- a/api/models/base.py +++ b/api/models/base.py @@ -41,7 +41,7 @@ class DefaultFieldsMixin: ) updated_at: Mapped[datetime] = mapped_column( - __name_pos=DateTime, + DateTime, nullable=False, default=naive_utc_now, server_default=func.current_timestamp(), diff --git a/api/models/dataset.py b/api/models/dataset.py index 62f11b8c72..e7da2961bc 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -72,6 +72,7 @@ class Dataset(Base): keyword_number = mapped_column(sa.Integer, nullable=True, server_default=sa.text("10")) collection_binding_id = mapped_column(StringUUID, nullable=True) retrieval_model = mapped_column(AdjustedJSON, nullable=True) + summary_index_setting = mapped_column(AdjustedJSON, nullable=True) built_in_field_enabled = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) icon_info = mapped_column(AdjustedJSON, 
nullable=True) runtime_mode = mapped_column(sa.String(255), nullable=True, server_default=sa.text("'general'")) @@ -419,6 +420,7 @@ class Document(Base): doc_metadata = mapped_column(AdjustedJSON, nullable=True) doc_form = mapped_column(String(255), nullable=False, server_default=sa.text("'text_model'")) doc_language = mapped_column(String(255), nullable=True) + need_summary: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) DATA_SOURCES = ["upload_file", "notion_import", "website_crawl"] @@ -1575,3 +1577,36 @@ class SegmentAttachmentBinding(Base): segment_id: Mapped[str] = mapped_column(StringUUID, nullable=False) attachment_id: Mapped[str] = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + + +class DocumentSegmentSummary(Base): + __tablename__ = "document_segment_summaries" + __table_args__ = ( + sa.PrimaryKeyConstraint("id", name="document_segment_summaries_pkey"), + sa.Index("document_segment_summaries_dataset_id_idx", "dataset_id"), + sa.Index("document_segment_summaries_document_id_idx", "document_id"), + sa.Index("document_segment_summaries_chunk_id_idx", "chunk_id"), + sa.Index("document_segment_summaries_status_idx", "status"), + ) + + id: Mapped[str] = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4())) + dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + document_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + # corresponds to DocumentSegment.id or parent chunk id + chunk_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + summary_content: Mapped[str] = mapped_column(LongText, nullable=True) + summary_index_node_id: Mapped[str] = mapped_column(String(255), nullable=True) + summary_index_node_hash: Mapped[str] = mapped_column(String(255), nullable=True) + tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True) + status: Mapped[str] 
= mapped_column(String(32), nullable=False, server_default=sa.text("'generating'")) + error: Mapped[str] = mapped_column(LongText, nullable=True) + enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) + disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) + disabled_by = mapped_column(StringUUID, nullable=True) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + ) + + def __repr__(self): + return f"" diff --git a/api/models/enums.py b/api/models/enums.py index 8cd3d4cf2a..2bc61120ce 100644 --- a/api/models/enums.py +++ b/api/models/enums.py @@ -36,6 +36,7 @@ class MessageStatus(StrEnum): """ NORMAL = "normal" + PAUSED = "paused" ERROR = "error" diff --git a/api/models/execution_extra_content.py b/api/models/execution_extra_content.py new file mode 100644 index 0000000000..d0bd34efec --- /dev/null +++ b/api/models/execution_extra_content.py @@ -0,0 +1,78 @@ +from enum import StrEnum, auto +from typing import TYPE_CHECKING + +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from .base import Base, DefaultFieldsMixin +from .types import EnumText, StringUUID + +if TYPE_CHECKING: + from .human_input import HumanInputForm + + +class ExecutionContentType(StrEnum): + HUMAN_INPUT = auto() + + +class ExecutionExtraContent(DefaultFieldsMixin, Base): + """ExecutionExtraContent stores extra contents produced during workflow / chatflow execution.""" + + # The `ExecutionExtraContent` uses single table inheritance to model different + # kinds of contents produced during message generation. 
+ # + # See: https://docs.sqlalchemy.org/en/20/orm/inheritance.html#single-table-inheritance + + __tablename__ = "execution_extra_contents" + __mapper_args__ = { + "polymorphic_abstract": True, + "polymorphic_on": "type", + "with_polymorphic": "*", + } + # type records the type of the content. It serves as the `discriminator` for the + # single table inheritance. + type: Mapped[ExecutionContentType] = mapped_column( + EnumText(ExecutionContentType, length=30), + nullable=False, + ) + + # `workflow_run_id` records the workflow execution which generates this content, correspond to + # `WorkflowRun.id`. + workflow_run_id: Mapped[str] = mapped_column(StringUUID, nullable=False, index=True) + + # `message_id` records the messages generated by the execution associated with this `ExecutionExtraContent`. + # It references to `Message.id`. + # + # For workflow execution, this field is `None`. + # + # For chatflow execution, `message_id`` is not None, and the following condition holds: + # + # The message referenced by `message_id` has `message.workflow_run_id == execution_extra_content.workflow_run_id` + # + message_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, index=True) + + +class HumanInputContent(ExecutionExtraContent): + """HumanInputContent is a concrete class that represents human input content. + It should only be initialized with the `new` class method.""" + + __mapper_args__ = { + "polymorphic_identity": ExecutionContentType.HUMAN_INPUT, + } + + # A relation to HumanInputForm table. + # + # While the form_id column is nullable in database (due to the nature of single table inheritance), + # the form_id field should not be null for a given `HumanInputContent` instance. 
+ form_id: Mapped[str] = mapped_column(StringUUID, nullable=True) + + @classmethod + def new(cls, form_id: str, message_id: str | None) -> "HumanInputContent": + return cls(form_id=form_id, message_id=message_id) + + form: Mapped["HumanInputForm"] = relationship( + "HumanInputForm", + foreign_keys=[form_id], + uselist=False, + lazy="raise", + primaryjoin="foreign(HumanInputContent.form_id) == HumanInputForm.id", + ) diff --git a/api/models/human_input.py b/api/models/human_input.py new file mode 100644 index 0000000000..5208461de1 --- /dev/null +++ b/api/models/human_input.py @@ -0,0 +1,237 @@ +from datetime import datetime +from enum import StrEnum +from typing import Annotated, Literal, Self, final + +import sqlalchemy as sa +from pydantic import BaseModel, Field +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from core.workflow.nodes.human_input.enums import ( + DeliveryMethodType, + HumanInputFormKind, + HumanInputFormStatus, +) +from libs.helper import generate_string + +from .base import Base, DefaultFieldsMixin +from .types import EnumText, StringUUID + +_token_length = 22 +# A 32-character string can store a base64-encoded value with 192 bits of entropy +# or a base62-encoded value with over 180 bits of entropy, providing sufficient +# uniqueness for most use cases. +_token_field_length = 32 +_email_field_length = 330 + + +def _generate_token() -> str: + return generate_string(_token_length) + + +class HumanInputForm(DefaultFieldsMixin, Base): + __tablename__ = "human_input_forms" + + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + workflow_run_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) + form_kind: Mapped[HumanInputFormKind] = mapped_column( + EnumText(HumanInputFormKind), + nullable=False, + default=HumanInputFormKind.RUNTIME, + ) + + # The human input node the current form corresponds to. 
+ node_id: Mapped[str] = mapped_column(sa.String(60), nullable=False) + form_definition: Mapped[str] = mapped_column(sa.Text, nullable=False) + rendered_content: Mapped[str] = mapped_column(sa.Text, nullable=False) + status: Mapped[HumanInputFormStatus] = mapped_column( + EnumText(HumanInputFormStatus), + nullable=False, + default=HumanInputFormStatus.WAITING, + ) + + expiration_time: Mapped[datetime] = mapped_column( + sa.DateTime, + nullable=False, + ) + + # Submission-related fields (nullable until a submission happens). + selected_action_id: Mapped[str | None] = mapped_column(sa.String(200), nullable=True) + submitted_data: Mapped[str | None] = mapped_column(sa.Text, nullable=True) + submitted_at: Mapped[datetime | None] = mapped_column(sa.DateTime, nullable=True) + submission_user_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) + submission_end_user_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) + + completed_by_recipient_id: Mapped[str | None] = mapped_column( + StringUUID, + nullable=True, + ) + + deliveries: Mapped[list["HumanInputDelivery"]] = relationship( + "HumanInputDelivery", + primaryjoin="HumanInputForm.id == foreign(HumanInputDelivery.form_id)", + uselist=True, + back_populates="form", + lazy="raise", + ) + completed_by_recipient: Mapped["HumanInputFormRecipient | None"] = relationship( + "HumanInputFormRecipient", + primaryjoin="HumanInputForm.completed_by_recipient_id == foreign(HumanInputFormRecipient.id)", + lazy="raise", + viewonly=True, + ) + + +class HumanInputDelivery(DefaultFieldsMixin, Base): + __tablename__ = "human_input_form_deliveries" + + form_id: Mapped[str] = mapped_column( + StringUUID, + nullable=False, + ) + delivery_method_type: Mapped[DeliveryMethodType] = mapped_column( + EnumText(DeliveryMethodType), + nullable=False, + ) + delivery_config_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) + channel_payload: Mapped[str] = mapped_column(sa.Text, nullable=False) + + form: 
Mapped[HumanInputForm] = relationship( + "HumanInputForm", + uselist=False, + foreign_keys=[form_id], + primaryjoin="HumanInputDelivery.form_id == HumanInputForm.id", + back_populates="deliveries", + lazy="raise", + ) + + recipients: Mapped[list["HumanInputFormRecipient"]] = relationship( + "HumanInputFormRecipient", + primaryjoin="HumanInputDelivery.id == foreign(HumanInputFormRecipient.delivery_id)", + uselist=True, + back_populates="delivery", + # Require explicit preloading + lazy="raise", + ) + + +class RecipientType(StrEnum): + # EMAIL_MEMBER member means that the + EMAIL_MEMBER = "email_member" + EMAIL_EXTERNAL = "email_external" + # STANDALONE_WEB_APP is used by the standalone web app. + # + # It's not used while running workflows / chatflows containing HumanInput + # node inside console. + STANDALONE_WEB_APP = "standalone_web_app" + # CONSOLE is used while running workflows / chatflows containing HumanInput + # node inside console. (E.G. running installed apps or debugging workflows / chatflows) + CONSOLE = "console" + # BACKSTAGE is used for backstage input inside console. + BACKSTAGE = "backstage" + + +@final +class EmailMemberRecipientPayload(BaseModel): + TYPE: Literal[RecipientType.EMAIL_MEMBER] = RecipientType.EMAIL_MEMBER + user_id: str + + # The `email` field here is only used for mail sending. 
+ email: str + + +@final +class EmailExternalRecipientPayload(BaseModel): + TYPE: Literal[RecipientType.EMAIL_EXTERNAL] = RecipientType.EMAIL_EXTERNAL + email: str + + +@final +class StandaloneWebAppRecipientPayload(BaseModel): + TYPE: Literal[RecipientType.STANDALONE_WEB_APP] = RecipientType.STANDALONE_WEB_APP + + +@final +class ConsoleRecipientPayload(BaseModel): + TYPE: Literal[RecipientType.CONSOLE] = RecipientType.CONSOLE + account_id: str | None = None + + +@final +class BackstageRecipientPayload(BaseModel): + TYPE: Literal[RecipientType.BACKSTAGE] = RecipientType.BACKSTAGE + account_id: str | None = None + + +@final +class ConsoleDeliveryPayload(BaseModel): + type: Literal["console"] = "console" + internal: bool = True + + +RecipientPayload = Annotated[ + EmailMemberRecipientPayload + | EmailExternalRecipientPayload + | StandaloneWebAppRecipientPayload + | ConsoleRecipientPayload + | BackstageRecipientPayload, + Field(discriminator="TYPE"), +] + + +class HumanInputFormRecipient(DefaultFieldsMixin, Base): + __tablename__ = "human_input_form_recipients" + + form_id: Mapped[str] = mapped_column( + StringUUID, + nullable=False, + ) + delivery_id: Mapped[str] = mapped_column( + StringUUID, + nullable=False, + ) + recipient_type: Mapped["RecipientType"] = mapped_column(EnumText(RecipientType), nullable=False) + recipient_payload: Mapped[str] = mapped_column(sa.Text, nullable=False) + + # Token primarily used for authenticated resume links (email, etc.). 
+ access_token: Mapped[str | None] = mapped_column( + sa.VARCHAR(_token_field_length), + nullable=False, + default=_generate_token, + unique=True, + ) + + delivery: Mapped[HumanInputDelivery] = relationship( + "HumanInputDelivery", + uselist=False, + foreign_keys=[delivery_id], + back_populates="recipients", + primaryjoin="HumanInputFormRecipient.delivery_id == HumanInputDelivery.id", + # Require explicit preloading + lazy="raise", + ) + + form: Mapped[HumanInputForm] = relationship( + "HumanInputForm", + uselist=False, + foreign_keys=[form_id], + primaryjoin="HumanInputFormRecipient.form_id == HumanInputForm.id", + # Require explicit preloading + lazy="raise", + ) + + @classmethod + def new( + cls, + form_id: str, + delivery_id: str, + payload: RecipientPayload, + ) -> Self: + recipient_model = cls( + form_id=form_id, + delivery_id=delivery_id, + recipient_type=payload.TYPE, + recipient_payload=payload.model_dump_json(), + access_token=_generate_token(), + ) + return recipient_model diff --git a/api/models/model.py b/api/models/model.py index be0cfd58a7..c12362f359 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -3,7 +3,7 @@ from __future__ import annotations import json import re import uuid -from collections.abc import Mapping +from collections.abc import Mapping, Sequence from datetime import datetime from decimal import Decimal from enum import StrEnum, auto @@ -657,16 +657,22 @@ class AccountTrialAppRecord(Base): return user -class ExporleBanner(Base): +class ExporleBanner(TypeBase): __tablename__ = "exporle_banners" __table_args__ = (sa.PrimaryKeyConstraint("id", name="exporler_banner_pkey"),) - id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) - content = mapped_column(sa.JSON, nullable=False) - link = mapped_column(String(255), nullable=False) - sort = mapped_column(sa.Integer, nullable=False) - status = mapped_column(sa.String(255), nullable=False, server_default=sa.text("'enabled'::character varying")) - 
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) - language = mapped_column(String(255), nullable=False, server_default=sa.text("'en-US'::character varying")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False) + content: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False) + link: Mapped[str] = mapped_column(String(255), nullable=False) + sort: Mapped[int] = mapped_column(sa.Integer, nullable=False) + status: Mapped[str] = mapped_column( + sa.String(255), nullable=False, server_default=sa.text("'enabled'::character varying"), default="enabled" + ) + created_at: Mapped[datetime] = mapped_column( + sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False + ) + language: Mapped[str] = mapped_column( + String(255), nullable=False, server_default=sa.text("'en-US'::character varying"), default="en-US" + ) class OAuthProviderApp(TypeBase): @@ -937,6 +943,7 @@ class Conversation(Base): WorkflowExecutionStatus.FAILED: 0, WorkflowExecutionStatus.STOPPED: 0, WorkflowExecutionStatus.PARTIAL_SUCCEEDED: 0, + WorkflowExecutionStatus.PAUSED: 0, } for message in messages: @@ -957,6 +964,7 @@ class Conversation(Base): "success": status_counts[WorkflowExecutionStatus.SUCCEEDED], "failed": status_counts[WorkflowExecutionStatus.FAILED], "partial_success": status_counts[WorkflowExecutionStatus.PARTIAL_SUCCEEDED], + "paused": status_counts[WorkflowExecutionStatus.PAUSED], } @property @@ -1339,6 +1347,14 @@ class Message(Base): db.session.commit() return result + # TODO(QuantumGhost): dirty hacks, fix this later. 
+ def set_extra_contents(self, contents: Sequence[dict[str, Any]]) -> None: + self._extra_contents = list(contents) + + @property + def extra_contents(self) -> list[dict[str, Any]]: + return getattr(self, "_extra_contents", []) + @property def workflow_run(self): if self.workflow_run_id: diff --git a/api/models/workflow.py b/api/models/workflow.py index df83228c2a..94e0881bd1 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -20,6 +20,7 @@ from sqlalchemy import ( select, ) from sqlalchemy.orm import Mapped, declared_attr, mapped_column +from typing_extensions import deprecated from core.file.constants import maybe_file_object from core.file.models import File @@ -30,7 +31,7 @@ from core.workflow.constants import ( SYSTEM_VARIABLE_NODE_ID, ) from core.workflow.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause -from core.workflow.enums import NodeType +from core.workflow.enums import NodeType, WorkflowExecutionStatus from extensions.ext_storage import Storage from factories.variable_factory import TypeMismatchError, build_segment_with_type from libs.datetime_utils import naive_utc_now @@ -405,6 +406,11 @@ class Workflow(Base): # bug return helper.generate_text_hash(json.dumps(entity, sort_keys=True)) @property + @deprecated( + "This property is not accurate for determining if a workflow is published as a tool." + "It only checks if there's a WorkflowToolProvider for the app, " + "not if this specific workflow version is the one being used by the tool." + ) def tool_published(self) -> bool: """ DEPRECATED: This property is not accurate for determining if a workflow is published as a tool. 
@@ -607,13 +613,16 @@ class WorkflowRun(Base): version: Mapped[str] = mapped_column(String(255)) graph: Mapped[str | None] = mapped_column(LongText) inputs: Mapped[str | None] = mapped_column(LongText) - status: Mapped[str] = mapped_column(String(255)) # running, succeeded, failed, stopped, partial-succeeded + status: Mapped[WorkflowExecutionStatus] = mapped_column( + EnumText(WorkflowExecutionStatus, length=255), + nullable=False, + ) outputs: Mapped[str | None] = mapped_column(LongText, default="{}") error: Mapped[str | None] = mapped_column(LongText) elapsed_time: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0")) total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0")) total_steps: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True) - created_by_role: Mapped[str] = mapped_column(String(255)) # account, end_user + created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255)) # account, end_user created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) finished_at: Mapped[datetime | None] = mapped_column(DateTime) @@ -629,11 +638,13 @@ class WorkflowRun(Base): ) @property + @deprecated("This method is retained for historical reasons; avoid using it if possible.") def created_by_account(self): created_by_role = CreatorUserRole(self.created_by_role) return db.session.get(Account, self.created_by) if created_by_role == CreatorUserRole.ACCOUNT else None @property + @deprecated("This method is retained for historical reasons; avoid using it if possible.") def created_by_end_user(self): from .model import EndUser @@ -653,6 +664,7 @@ class WorkflowRun(Base): return json.loads(self.outputs) if self.outputs else {} @property + @deprecated("This method is retained for historical reasons; avoid using it if possible.") def 
message(self): from .model import Message @@ -661,6 +673,7 @@ class WorkflowRun(Base): ) @property + @deprecated("This method is retained for historical reasons; avoid using it if possible.") def workflow(self): return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first() @@ -1861,7 +1874,12 @@ class WorkflowPauseReason(DefaultFieldsMixin, Base): def to_entity(self) -> PauseReason: if self.type_ == PauseReasonType.HUMAN_INPUT_REQUIRED: - return HumanInputRequired(form_id=self.form_id, node_id=self.node_id) + return HumanInputRequired( + form_id=self.form_id, + form_content="", + node_id=self.node_id, + node_title="", + ) elif self.type_ == PauseReasonType.SCHEDULED_PAUSE: return SchedulingPause(message=self.message) else: diff --git a/api/pyproject.toml b/api/pyproject.toml index 575c1434c5..16395573f4 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dify-api" -version = "1.11.4" +version = "1.12.0" requires-python = ">=3.11,<3.13" dependencies = [ @@ -40,7 +40,7 @@ dependencies = [ "numpy~=1.26.4", "openpyxl~=3.1.5", "opik~=1.8.72", - "litellm==1.77.1", # Pinned to avoid madoka dependency issue + "litellm==1.77.1", # Pinned to avoid madoka dependency issue "opentelemetry-api==1.27.0", "opentelemetry-distro==0.48b0", "opentelemetry-exporter-otlp==1.27.0", @@ -175,6 +175,7 @@ dev = [ # "locust>=2.40.4", # Temporarily removed due to compatibility issues. Uncomment when resolved. "sseclient-py>=1.8.0", "pytest-timeout>=2.4.0", + "pytest-xdist>=3.8.0", ] ############################################################ @@ -229,3 +230,23 @@ vdb = [ "mo-vector~=0.1.13", "mysql-connector-python>=9.3.0", ] + +[tool.mypy] + +[[tool.mypy.overrides]] +# targeted ignores for current type-check errors +# TODO(QuantumGhost): suppress type errors in HITL related code. 
+# fix the type error later +module = [ + "configs.middleware.cache.redis_pubsub_config", + "extensions.ext_redis", + "tasks.workflow_execution_tasks", + "core.workflow.nodes.base.node", + "services.human_input_delivery_test_service", + "core.app.apps.advanced_chat.app_generator", + "controllers.console.human_input_form", + "controllers.console.app.workflow_run", + "repositories.sqlalchemy_api_workflow_node_execution_repository", + "extensions.logstore.repositories.logstore_api_workflow_run_repository", +] +ignore_errors = true diff --git a/api/repositories/api_workflow_node_execution_repository.py b/api/repositories/api_workflow_node_execution_repository.py index 5b3f635301..6446eb0d6e 100644 --- a/api/repositories/api_workflow_node_execution_repository.py +++ b/api/repositories/api_workflow_node_execution_repository.py @@ -10,6 +10,7 @@ tenant_id, app_id, triggered_from, etc., which are not part of the core domain m """ from collections.abc import Sequence +from dataclasses import dataclass from datetime import datetime from typing import Protocol @@ -19,6 +20,27 @@ from core.workflow.repositories.workflow_node_execution_repository import Workfl from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload +@dataclass(frozen=True) +class WorkflowNodeExecutionSnapshot: + """ + Minimal snapshot of workflow node execution for stream recovery. + + Only includes fields required by snapshot events. + """ + + execution_id: str # Unique execution identifier (node_execution_id or row id). + node_id: str # Workflow graph node id. + node_type: str # Workflow graph node type (e.g. "human-input"). + title: str # Human-friendly node title. + index: int # Execution order index within the workflow run. + status: str # Execution status (running/succeeded/failed/paused). + elapsed_time: float # Execution elapsed time in seconds. + created_at: datetime # Execution created timestamp. + finished_at: datetime | None # Execution finished timestamp. 
+ iteration_id: str | None = None # Iteration id from execution metadata, if any. + loop_id: str | None = None # Loop id from execution metadata, if any. + + class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Protocol): """ Protocol for service-layer operations on WorkflowNodeExecutionModel. @@ -79,6 +101,8 @@ class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Pr Args: tenant_id: The tenant identifier app_id: The application identifier + workflow_id: The workflow identifier + triggered_from: The workflow trigger source workflow_run_id: The workflow run identifier Returns: @@ -86,6 +110,27 @@ class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Pr """ ... + def get_execution_snapshots_by_workflow_run( + self, + tenant_id: str, + app_id: str, + workflow_id: str, + triggered_from: str, + workflow_run_id: str, + ) -> Sequence[WorkflowNodeExecutionSnapshot]: + """ + Get minimal snapshots for node executions in a workflow run. + + Args: + tenant_id: The tenant identifier + app_id: The application identifier + workflow_run_id: The workflow run identifier + + Returns: + A sequence of WorkflowNodeExecutionSnapshot ordered by creation time + """ + ... + def get_execution_by_id( self, execution_id: str, diff --git a/api/repositories/api_workflow_run_repository.py b/api/repositories/api_workflow_run_repository.py index 1d3954571f..17e01a6e18 100644 --- a/api/repositories/api_workflow_run_repository.py +++ b/api/repositories/api_workflow_run_repository.py @@ -432,6 +432,13 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol): # while creating pause. ... + def get_workflow_pause(self, workflow_run_id: str) -> WorkflowPauseEntity | None: + """Retrieve the current pause for a workflow execution. + + If there is no current pause, this method would return `None`. + """ + ... 
+ def resume_workflow_pause( self, workflow_run_id: str, @@ -627,3 +634,19 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol): [{"date": "2024-01-01", "interactions": 2.5}, ...] """ ... + + def get_workflow_run_by_id_and_tenant_id(self, tenant_id: str, run_id: str) -> WorkflowRun | None: + """ + Get a specific workflow run by its id and the associated tenant id. + + This function does not apply application isolation. It should only be used when + the application identifier is not available. + + Args: + tenant_id: Tenant identifier for multi-tenant isolation + run_id: Workflow run identifier + + Returns: + WorkflowRun object if found, None otherwise + """ + ... diff --git a/api/repositories/entities/workflow_pause.py b/api/repositories/entities/workflow_pause.py index b970f39816..a3c4039aaa 100644 --- a/api/repositories/entities/workflow_pause.py +++ b/api/repositories/entities/workflow_pause.py @@ -63,6 +63,12 @@ class WorkflowPauseEntity(ABC): """ pass + @property + @abstractmethod + def paused_at(self) -> datetime: + """`paused_at` returns the creation time of the pause.""" + pass + @abstractmethod def get_pause_reasons(self) -> Sequence[PauseReason]: """ @@ -70,7 +76,5 @@ class WorkflowPauseEntity(ABC): Returns a sequence of `PauseReason` objects describing the specific nodes and reasons for which the workflow execution was paused. - This information is related to, but distinct from, the `PauseReason` type - defined in `api/core/workflow/entities/pause_reason.py`. """ ... 
diff --git a/api/repositories/execution_extra_content_repository.py b/api/repositories/execution_extra_content_repository.py new file mode 100644 index 0000000000..72b5443d2c --- /dev/null +++ b/api/repositories/execution_extra_content_repository.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +from collections.abc import Sequence +from typing import Protocol + +from core.entities.execution_extra_content import ExecutionExtraContentDomainModel + + +class ExecutionExtraContentRepository(Protocol): + def get_by_message_ids(self, message_ids: Sequence[str]) -> list[list[ExecutionExtraContentDomainModel]]: ... + + +__all__ = ["ExecutionExtraContentRepository"] diff --git a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py index b19cc73bd1..6c696b6478 100644 --- a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py @@ -5,6 +5,7 @@ This module provides a concrete implementation of the service repository protoco using SQLAlchemy 2.0 style queries for WorkflowNodeExecutionModel operations. 
""" +import json from collections.abc import Sequence from datetime import datetime from typing import cast @@ -13,11 +14,12 @@ from sqlalchemy import asc, delete, desc, func, select from sqlalchemy.engine import CursorResult from sqlalchemy.orm import Session, sessionmaker -from models.workflow import ( - WorkflowNodeExecutionModel, - WorkflowNodeExecutionOffload, +from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload +from repositories.api_workflow_node_execution_repository import ( + DifyAPIWorkflowNodeExecutionRepository, + WorkflowNodeExecutionSnapshot, ) -from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRepository): @@ -79,6 +81,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut WorkflowNodeExecutionModel.app_id == app_id, WorkflowNodeExecutionModel.workflow_id == workflow_id, WorkflowNodeExecutionModel.node_id == node_id, + WorkflowNodeExecutionModel.status != WorkflowNodeExecutionStatus.PAUSED, ) .order_by(desc(WorkflowNodeExecutionModel.created_at)) .limit(1) @@ -117,6 +120,80 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut with self._session_maker() as session: return session.execute(stmt).scalars().all() + def get_execution_snapshots_by_workflow_run( + self, + tenant_id: str, + app_id: str, + workflow_id: str, + triggered_from: str, + workflow_run_id: str, + ) -> Sequence[WorkflowNodeExecutionSnapshot]: + stmt = ( + select( + WorkflowNodeExecutionModel.id, + WorkflowNodeExecutionModel.node_execution_id, + WorkflowNodeExecutionModel.node_id, + WorkflowNodeExecutionModel.node_type, + WorkflowNodeExecutionModel.title, + WorkflowNodeExecutionModel.index, + WorkflowNodeExecutionModel.status, + 
WorkflowNodeExecutionModel.elapsed_time, + WorkflowNodeExecutionModel.created_at, + WorkflowNodeExecutionModel.finished_at, + WorkflowNodeExecutionModel.execution_metadata, + ) + .where( + WorkflowNodeExecutionModel.tenant_id == tenant_id, + WorkflowNodeExecutionModel.app_id == app_id, + WorkflowNodeExecutionModel.workflow_id == workflow_id, + WorkflowNodeExecutionModel.triggered_from == triggered_from, + WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id, + ) + .order_by( + asc(WorkflowNodeExecutionModel.created_at), + asc(WorkflowNodeExecutionModel.index), + ) + ) + + with self._session_maker() as session: + rows = session.execute(stmt).all() + + return [self._row_to_snapshot(row) for row in rows] + + @staticmethod + def _row_to_snapshot(row: object) -> WorkflowNodeExecutionSnapshot: + metadata: dict[str, object] = {} + execution_metadata = getattr(row, "execution_metadata", None) + if execution_metadata: + try: + metadata = json.loads(execution_metadata) + except json.JSONDecodeError: + metadata = {} + iteration_id = metadata.get(WorkflowNodeExecutionMetadataKey.ITERATION_ID.value) + loop_id = metadata.get(WorkflowNodeExecutionMetadataKey.LOOP_ID.value) + execution_id = getattr(row, "node_execution_id", None) or row.id + elapsed_time = getattr(row, "elapsed_time", None) + created_at = row.created_at + finished_at = getattr(row, "finished_at", None) + if elapsed_time is None: + if finished_at is not None and created_at is not None: + elapsed_time = (finished_at - created_at).total_seconds() + else: + elapsed_time = 0.0 + return WorkflowNodeExecutionSnapshot( + execution_id=str(execution_id), + node_id=row.node_id, + node_type=row.node_type, + title=row.title, + index=row.index, + status=row.status, + elapsed_time=float(elapsed_time), + created_at=created_at, + finished_at=finished_at, + iteration_id=str(iteration_id) if iteration_id else None, + loop_id=str(loop_id) if loop_id else None, + ) + def get_execution_by_id( self, execution_id: str, diff 
--git a/api/repositories/sqlalchemy_api_workflow_run_repository.py b/api/repositories/sqlalchemy_api_workflow_run_repository.py index d5214be042..00cb979e17 100644 --- a/api/repositories/sqlalchemy_api_workflow_run_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py @@ -19,6 +19,7 @@ Implementation Notes: - Maintains data consistency with proper transaction handling """ +import json import logging import uuid from collections.abc import Callable, Sequence @@ -27,12 +28,14 @@ from decimal import Decimal from typing import Any, cast import sqlalchemy as sa +from pydantic import ValidationError from sqlalchemy import and_, delete, func, null, or_, select from sqlalchemy.engine import CursorResult from sqlalchemy.orm import Session, selectinload, sessionmaker -from core.workflow.entities.pause_reason import HumanInputRequired, PauseReason, SchedulingPause +from core.workflow.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause from core.workflow.enums import WorkflowExecutionStatus, WorkflowType +from core.workflow.nodes.human_input.entities import FormDefinition from extensions.ext_storage import storage from libs.datetime_utils import naive_utc_now from libs.helper import convert_datetime_to_date @@ -40,6 +43,7 @@ from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.time_parser import get_time_threshold from libs.uuid_utils import uuidv7 from models.enums import WorkflowRunTriggeredFrom +from models.human_input import HumanInputForm, HumanInputFormRecipient, RecipientType from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.entities.workflow_pause import WorkflowPauseEntity @@ -57,6 +61,67 @@ class _WorkflowRunError(Exception): pass +def _select_recipient_token( + recipients: Sequence[HumanInputFormRecipient], + 
recipient_type: RecipientType, +) -> str | None: + for recipient in recipients: + if recipient.recipient_type == recipient_type and recipient.access_token: + return recipient.access_token + return None + + +def _build_human_input_required_reason( + reason_model: WorkflowPauseReason, + form_model: HumanInputForm | None, + recipients: Sequence[HumanInputFormRecipient], +) -> HumanInputRequired: + form_content = "" + inputs = [] + actions = [] + display_in_ui = False + resolved_default_values: dict[str, Any] = {} + node_title = "Human Input" + form_id = reason_model.form_id + node_id = reason_model.node_id + if form_model is not None: + form_id = form_model.id + node_id = form_model.node_id or node_id + try: + definition_payload = json.loads(form_model.form_definition) + if "expiration_time" not in definition_payload: + definition_payload["expiration_time"] = form_model.expiration_time + definition = FormDefinition.model_validate(definition_payload) + except ValidationError: + definition = None + + if definition is not None: + form_content = definition.form_content + inputs = list(definition.inputs) + actions = list(definition.user_actions) + display_in_ui = bool(definition.display_in_ui) + resolved_default_values = dict(definition.default_values) + node_title = definition.node_title or node_title + + form_token = ( + _select_recipient_token(recipients, RecipientType.BACKSTAGE) + or _select_recipient_token(recipients, RecipientType.CONSOLE) + or _select_recipient_token(recipients, RecipientType.STANDALONE_WEB_APP) + ) + + return HumanInputRequired( + form_id=form_id, + form_content=form_content, + inputs=inputs, + actions=actions, + display_in_ui=display_in_ui, + node_id=node_id, + node_title=node_title, + form_token=form_token, + resolved_default_values=resolved_default_values, + ) + + class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): """ SQLAlchemy implementation of APIWorkflowRunRepository. 
@@ -676,9 +741,11 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): raise ValueError(f"WorkflowRun not found: {workflow_run_id}") # Check if workflow is in RUNNING status - if workflow_run.status != WorkflowExecutionStatus.RUNNING: + # TODO(QuantumGhost): It seems that the persistence of `WorkflowRun.status` + # happens before the execution of GraphLayer + if workflow_run.status not in {WorkflowExecutionStatus.RUNNING, WorkflowExecutionStatus.PAUSED}: raise _WorkflowRunError( - f"Only WorkflowRun with RUNNING status can be paused, " + f"Only WorkflowRun with RUNNING or PAUSED status can be paused, " f"workflow_run_id={workflow_run_id}, current_status={workflow_run.status}" ) # @@ -729,13 +796,48 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): logger.info("Created workflow pause %s for workflow run %s", pause_model.id, workflow_run_id) - return _PrivateWorkflowPauseEntity(pause_model=pause_model, reason_models=pause_reason_models) + return _PrivateWorkflowPauseEntity( + pause_model=pause_model, + reason_models=pause_reason_models, + pause_reasons=pause_reasons, + ) def _get_reasons_by_pause_id(self, session: Session, pause_id: str): reason_stmt = select(WorkflowPauseReason).where(WorkflowPauseReason.pause_id == pause_id) pause_reason_models = session.scalars(reason_stmt).all() return pause_reason_models + def _hydrate_pause_reasons( + self, + session: Session, + pause_reason_models: Sequence[WorkflowPauseReason], + ) -> list[PauseReason]: + form_ids = [ + reason.form_id + for reason in pause_reason_models + if reason.type_ == PauseReasonType.HUMAN_INPUT_REQUIRED and reason.form_id + ] + form_models: dict[str, HumanInputForm] = {} + recipient_models_by_form: dict[str, list[HumanInputFormRecipient]] = {} + if form_ids: + form_stmt = select(HumanInputForm).where(HumanInputForm.id.in_(form_ids)) + for form in session.scalars(form_stmt).all(): + form_models[form.id] = form + + recipient_stmt = 
select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids)) + for recipient in session.scalars(recipient_stmt).all(): + recipient_models_by_form.setdefault(recipient.form_id, []).append(recipient) + + pause_reasons: list[PauseReason] = [] + for reason in pause_reason_models: + if reason.type_ == PauseReasonType.HUMAN_INPUT_REQUIRED: + form_model = form_models.get(reason.form_id) + recipients = recipient_models_by_form.get(reason.form_id, []) + pause_reasons.append(_build_human_input_required_reason(reason, form_model, recipients)) + else: + pause_reasons.append(reason.to_entity()) + return pause_reasons + def get_workflow_pause( self, workflow_run_id: str, @@ -767,14 +869,12 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): if pause_model is None: return None pause_reason_models = self._get_reasons_by_pause_id(session, pause_model.id) - - human_input_form: list[Any] = [] - # TODO(QuantumGhost): query human_input_forms model and rebuild PauseReason + pause_reasons = self._hydrate_pause_reasons(session, pause_reason_models) return _PrivateWorkflowPauseEntity( pause_model=pause_model, reason_models=pause_reason_models, - human_input_form=human_input_form, + pause_reasons=pause_reasons, ) def resume_workflow_pause( @@ -828,10 +928,10 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): raise _WorkflowRunError(f"Cannot resume an already resumed pause, pause_id={pause_model.id}") pause_reasons = self._get_reasons_by_pause_id(session, pause_model.id) + hydrated_pause_reasons = self._hydrate_pause_reasons(session, pause_reasons) # Mark as resumed pause_model.resumed_at = naive_utc_now() - workflow_run.pause_id = None # type: ignore workflow_run.status = WorkflowExecutionStatus.RUNNING session.add(pause_model) @@ -839,7 +939,11 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): logger.info("Resumed workflow pause %s for workflow run %s", pause_model.id, workflow_run_id) - return 
_PrivateWorkflowPauseEntity(pause_model=pause_model, reason_models=pause_reasons) + return _PrivateWorkflowPauseEntity( + pause_model=pause_model, + reason_models=pause_reasons, + pause_reasons=hydrated_pause_reasons, + ) def delete_workflow_pause( self, @@ -1165,6 +1269,15 @@ GROUP BY return cast(list[AverageInteractionStats], response_data) + def get_workflow_run_by_id_and_tenant_id(self, tenant_id: str, run_id: str) -> WorkflowRun | None: + """Get a specific workflow run by its id and the associated tenant id.""" + with self._session_maker() as session: + stmt = select(WorkflowRun).where( + WorkflowRun.tenant_id == tenant_id, + WorkflowRun.id == run_id, + ) + return session.scalar(stmt) + class _PrivateWorkflowPauseEntity(WorkflowPauseEntity): """ @@ -1179,10 +1292,12 @@ class _PrivateWorkflowPauseEntity(WorkflowPauseEntity): *, pause_model: WorkflowPause, reason_models: Sequence[WorkflowPauseReason], + pause_reasons: Sequence[PauseReason] | None = None, human_input_form: Sequence = (), ) -> None: self._pause_model = pause_model self._reason_models = reason_models + self._pause_reasons = pause_reasons self._cached_state: bytes | None = None self._human_input_form = human_input_form @@ -1219,4 +1334,10 @@ class _PrivateWorkflowPauseEntity(WorkflowPauseEntity): return self._pause_model.resumed_at def get_pause_reasons(self) -> Sequence[PauseReason]: + if self._pause_reasons is not None: + return list(self._pause_reasons) return [reason.to_entity() for reason in self._reason_models] + + @property + def paused_at(self) -> datetime: + return self._pause_model.created_at diff --git a/api/repositories/sqlalchemy_execution_extra_content_repository.py b/api/repositories/sqlalchemy_execution_extra_content_repository.py new file mode 100644 index 0000000000..5a2c0ea46f --- /dev/null +++ b/api/repositories/sqlalchemy_execution_extra_content_repository.py @@ -0,0 +1,200 @@ +from __future__ import annotations + +import json +import logging +import re +from collections import 
defaultdict +from collections.abc import Sequence +from typing import Any + +from sqlalchemy import select +from sqlalchemy.orm import Session, selectinload, sessionmaker + +from core.entities.execution_extra_content import ( + ExecutionExtraContentDomainModel, + HumanInputFormDefinition, + HumanInputFormSubmissionData, +) +from core.entities.execution_extra_content import ( + HumanInputContent as HumanInputContentDomainModel, +) +from core.workflow.nodes.human_input.entities import FormDefinition +from core.workflow.nodes.human_input.enums import HumanInputFormStatus +from core.workflow.nodes.human_input.human_input_node import HumanInputNode +from models.execution_extra_content import ( + ExecutionExtraContent as ExecutionExtraContentModel, +) +from models.execution_extra_content import ( + HumanInputContent as HumanInputContentModel, +) +from models.human_input import HumanInputFormRecipient, RecipientType +from repositories.execution_extra_content_repository import ExecutionExtraContentRepository + +logger = logging.getLogger(__name__) + +_OUTPUT_VARIABLE_PATTERN = re.compile(r"\{\{#\$output\.(?P<field_name>[a-zA-Z_][a-zA-Z0-9_]{0,29})#\}\}") + + +def _extract_output_field_names(form_content: str) -> list[str]: + if not form_content: + return [] + return [match.group("field_name") for match in _OUTPUT_VARIABLE_PATTERN.finditer(form_content)] + + +class SQLAlchemyExecutionExtraContentRepository(ExecutionExtraContentRepository): + def __init__(self, session_maker: sessionmaker[Session]): + self._session_maker = session_maker + + def get_by_message_ids(self, message_ids: Sequence[str]) -> list[list[ExecutionExtraContentDomainModel]]: + if not message_ids: + return [] + + grouped_contents: dict[str, list[ExecutionExtraContentDomainModel]] = { + message_id: [] for message_id in message_ids + } + + stmt = ( + select(ExecutionExtraContentModel) + .where(ExecutionExtraContentModel.message_id.in_(message_ids)) + .options(selectinload(HumanInputContentModel.form)) + 
.order_by(ExecutionExtraContentModel.created_at.asc()) + ) + + with self._session_maker() as session: + results = session.scalars(stmt).all() + + form_ids = { + content.form_id + for content in results + if isinstance(content, HumanInputContentModel) and content.form_id is not None + } + recipients_by_form_id: dict[str, list[HumanInputFormRecipient]] = defaultdict(list) + if form_ids: + recipient_stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids)) + recipients = session.scalars(recipient_stmt).all() + for recipient in recipients: + recipients_by_form_id[recipient.form_id].append(recipient) + else: + recipients_by_form_id = {} + + for content in results: + message_id = content.message_id + if not message_id or message_id not in grouped_contents: + continue + + domain_model = self._map_model_to_domain(content, recipients_by_form_id) + if domain_model is None: + continue + + grouped_contents[message_id].append(domain_model) + + return [grouped_contents[message_id] for message_id in message_ids] + + def _map_model_to_domain( + self, + model: ExecutionExtraContentModel, + recipients_by_form_id: dict[str, list[HumanInputFormRecipient]], + ) -> ExecutionExtraContentDomainModel | None: + if isinstance(model, HumanInputContentModel): + return self._map_human_input_content(model, recipients_by_form_id) + + logger.debug("Unsupported execution extra content type encountered: %s", model.type) + return None + + def _map_human_input_content( + self, + model: HumanInputContentModel, + recipients_by_form_id: dict[str, list[HumanInputFormRecipient]], + ) -> HumanInputContentDomainModel | None: + form = model.form + if form is None: + logger.warning("HumanInputContent(id=%s) has no associated form loaded", model.id) + return None + + try: + definition_payload = json.loads(form.form_definition) + if "expiration_time" not in definition_payload: + definition_payload["expiration_time"] = form.expiration_time + form_definition = 
FormDefinition.model_validate(definition_payload) + except ValueError: + logger.warning("Failed to load form definition for HumanInputContent(id=%s)", model.id) + return None + node_title = form_definition.node_title or form.node_id + display_in_ui = bool(form_definition.display_in_ui) + + submitted = form.submitted_at is not None or form.status == HumanInputFormStatus.SUBMITTED + if not submitted: + form_token = self._resolve_form_token(recipients_by_form_id.get(form.id, [])) + return HumanInputContentDomainModel( + workflow_run_id=model.workflow_run_id, + submitted=False, + form_definition=HumanInputFormDefinition( + form_id=form.id, + node_id=form.node_id, + node_title=node_title, + form_content=form.rendered_content, + inputs=form_definition.inputs, + actions=form_definition.user_actions, + display_in_ui=display_in_ui, + form_token=form_token, + resolved_default_values=form_definition.default_values, + expiration_time=int(form.expiration_time.timestamp()), + ), + ) + + selected_action_id = form.selected_action_id + if not selected_action_id: + logger.warning("HumanInputContent(id=%s) form has no selected action", model.id) + return None + + action_text = next( + (action.title for action in form_definition.user_actions if action.id == selected_action_id), + selected_action_id, + ) + + submitted_data: dict[str, Any] = {} + if form.submitted_data: + try: + submitted_data = json.loads(form.submitted_data) + except ValueError: + logger.warning("Failed to load submitted data for HumanInputContent(id=%s)", model.id) + return None + + rendered_content = HumanInputNode.render_form_content_with_outputs( + form.rendered_content, + submitted_data, + _extract_output_field_names(form_definition.form_content), + ) + + return HumanInputContentDomainModel( + workflow_run_id=model.workflow_run_id, + submitted=True, + form_submission_data=HumanInputFormSubmissionData( + node_id=form.node_id, + node_title=node_title, + rendered_content=rendered_content, + 
action_id=selected_action_id, + action_text=action_text, + ), + ) + + @staticmethod + def _resolve_form_token(recipients: Sequence[HumanInputFormRecipient]) -> str | None: + console_recipient = next( + (recipient for recipient in recipients if recipient.recipient_type == RecipientType.CONSOLE), + None, + ) + if console_recipient and console_recipient.access_token: + return console_recipient.access_token + + web_app_recipient = next( + (recipient for recipient in recipients if recipient.recipient_type == RecipientType.STANDALONE_WEB_APP), + None, + ) + if web_app_recipient and web_app_recipient.access_token: + return web_app_recipient.access_token + + return None + + +__all__ = ["SQLAlchemyExecutionExtraContentRepository"] diff --git a/api/repositories/sqlalchemy_workflow_trigger_log_repository.py b/api/repositories/sqlalchemy_workflow_trigger_log_repository.py index f3dc4cd60b..1f6740b066 100644 --- a/api/repositories/sqlalchemy_workflow_trigger_log_repository.py +++ b/api/repositories/sqlalchemy_workflow_trigger_log_repository.py @@ -92,6 +92,16 @@ class SQLAlchemyWorkflowTriggerLogRepository(WorkflowTriggerLogRepository): return list(self.session.scalars(query).all()) + def get_by_workflow_run_id(self, workflow_run_id: str) -> WorkflowTriggerLog | None: + """Get the trigger log associated with a workflow run.""" + query = ( + select(WorkflowTriggerLog) + .where(WorkflowTriggerLog.workflow_run_id == workflow_run_id) + .order_by(WorkflowTriggerLog.created_at.desc()) + .limit(1) + ) + return self.session.scalar(query) + def delete_by_run_ids(self, run_ids: Sequence[str]) -> int: """ Delete trigger logs associated with the given workflow run ids. 
diff --git a/api/repositories/workflow_trigger_log_repository.py b/api/repositories/workflow_trigger_log_repository.py index b0009e398d..7f9e6b7b68 100644 --- a/api/repositories/workflow_trigger_log_repository.py +++ b/api/repositories/workflow_trigger_log_repository.py @@ -110,6 +110,18 @@ class WorkflowTriggerLogRepository(Protocol): """ ... + def get_by_workflow_run_id(self, workflow_run_id: str) -> WorkflowTriggerLog | None: + """ + Retrieve a trigger log associated with a specific workflow run. + + Args: + workflow_run_id: Identifier of the workflow run + + Returns: + The matching WorkflowTriggerLog if present, None otherwise + """ + ... + def delete_by_run_ids(self, run_ids: Sequence[str]) -> int: """ Delete trigger logs for workflow run IDs. diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 0f42c99246..9400362605 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -44,7 +44,7 @@ IMPORT_INFO_REDIS_KEY_PREFIX = "app_import_info:" CHECK_DEPENDENCIES_REDIS_KEY_PREFIX = "app_check_dependencies:" IMPORT_INFO_REDIS_EXPIRY = 10 * 60 # 10 minutes DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB -CURRENT_DSL_VERSION = "0.5.0" +CURRENT_DSL_VERSION = "0.6.0" class ImportMode(StrEnum): diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py index ce85f2e914..0c27c403f8 100644 --- a/api/services/app_generate_service.py +++ b/api/services/app_generate_service.py @@ -1,7 +1,9 @@ from __future__ import annotations +import logging +import threading import uuid -from collections.abc import Generator, Mapping +from collections.abc import Callable, Generator, Mapping from typing import TYPE_CHECKING, Any, Union from configs import dify_config @@ -9,22 +11,63 @@ from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator from core.app.apps.chat.app_generator import ChatAppGenerator from 
core.app.apps.completion.app_generator import CompletionAppGenerator +from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.workflow.app_generator import WorkflowAppGenerator from core.app.entities.app_invoke_entities import InvokeFrom from core.app.features.rate_limiting import RateLimit +from core.app.features.rate_limiting.rate_limit import rate_limit_context +from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig +from core.db import session_factory from enums.quota_type import QuotaType, unlimited from extensions.otel import AppGenerateHandler, trace_span from models.model import Account, App, AppMode, EndUser -from models.workflow import Workflow +from models.workflow import Workflow, WorkflowRun from services.errors.app import QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError from services.errors.llm import InvokeRateLimitError from services.workflow_service import WorkflowService +from tasks.app_generate.workflow_execute_task import AppExecutionParams, workflow_based_app_execution_task + +logger = logging.getLogger(__name__) + +SSE_TASK_START_FALLBACK_MS = 200 if TYPE_CHECKING: from controllers.console.app.workflow import LoopNodeRunPayload class AppGenerateService: + @staticmethod + def _build_streaming_task_on_subscribe(start_task: Callable[[], None]) -> Callable[[], None]: + started = False + lock = threading.Lock() + + def _try_start() -> bool: + nonlocal started + with lock: + if started: + return True + try: + start_task() + except Exception: + logger.exception("Failed to enqueue streaming task") + return False + started = True + return True + + # XXX(QuantumGhost): dirty hacks to avoid a race between publisher and SSE subscriber. + # The Celery task may publish the first event before the API side actually subscribes, + # causing an "at most once" drop with Redis Pub/Sub. 
We start the task on subscribe, + # but also use a short fallback timer so the task still runs if the client never consumes. + timer = threading.Timer(SSE_TASK_START_FALLBACK_MS / 1000.0, _try_start) + timer.daemon = True + timer.start() + + def _on_subscribe() -> None: + if _try_start(): + timer.cancel() + + return _on_subscribe + @classmethod @trace_span(AppGenerateHandler) def generate( @@ -88,15 +131,29 @@ class AppGenerateService: elif app_model.mode == AppMode.ADVANCED_CHAT: workflow_id = args.get("workflow_id") workflow = cls._get_workflow(app_model, invoke_from, workflow_id) + with rate_limit_context(rate_limit, request_id): + payload = AppExecutionParams.new( + app_model=app_model, + workflow=workflow, + user=user, + args=args, + invoke_from=invoke_from, + streaming=streaming, + call_depth=0, + ) + payload_json = payload.model_dump_json() + + def on_subscribe(): + workflow_based_app_execution_task.delay(payload_json) + + on_subscribe = cls._build_streaming_task_on_subscribe(on_subscribe) + generator = AdvancedChatAppGenerator() return rate_limit.generate( - AdvancedChatAppGenerator.convert_to_event_stream( - AdvancedChatAppGenerator().generate( - app_model=app_model, - workflow=workflow, - user=user, - args=args, - invoke_from=invoke_from, - streaming=streaming, + generator.convert_to_event_stream( + generator.retrieve_events( + AppMode.ADVANCED_CHAT, + payload.workflow_run_id, + on_subscribe=on_subscribe, ), ), request_id=request_id, @@ -104,6 +161,40 @@ class AppGenerateService: elif app_model.mode == AppMode.WORKFLOW: workflow_id = args.get("workflow_id") workflow = cls._get_workflow(app_model, invoke_from, workflow_id) + if streaming: + with rate_limit_context(rate_limit, request_id): + payload = AppExecutionParams.new( + app_model=app_model, + workflow=workflow, + user=user, + args=args, + invoke_from=invoke_from, + streaming=True, + call_depth=0, + root_node_id=root_node_id, + workflow_run_id=str(uuid.uuid4()), + ) + payload_json = 
payload.model_dump_json() + + def on_subscribe(): + workflow_based_app_execution_task.delay(payload_json) + + on_subscribe = cls._build_streaming_task_on_subscribe(on_subscribe) + return rate_limit.generate( + WorkflowAppGenerator.convert_to_event_stream( + MessageBasedAppGenerator.retrieve_events( + AppMode.WORKFLOW, + payload.workflow_run_id, + on_subscribe=on_subscribe, + ), + ), + request_id, + ) + + pause_config = PauseStateLayerConfig( + session_factory=session_factory.get_session_maker(), + state_owner_user_id=workflow.created_by, + ) return rate_limit.generate( WorkflowAppGenerator.convert_to_event_stream( WorkflowAppGenerator().generate( @@ -112,9 +203,10 @@ user=user, args=args, invoke_from=invoke_from, - streaming=streaming, + streaming=False, root_node_id=root_node_id, call_depth=0, + pause_state_config=pause_config, ), ), request_id, @@ -248,3 +340,19 @@ raise ValueError("Workflow not published") return workflow + + @classmethod + def get_response_generator( + cls, + app_model: App, + workflow_run: WorkflowRun, + ): + if workflow_run.status.is_ended(): + # TODO(QuantumGhost): handle the ended scenario.
+ pass + + generator = AdvancedChatAppGenerator() + + return generator.convert_to_event_stream( + generator.retrieve_events(AppMode(app_model.mode), workflow_run.id), + ) diff --git a/api/services/audio_service.py b/api/services/audio_service.py index 41ee9c88aa..a95361cebd 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -136,7 +136,7 @@ class AudioService: message = db.session.query(Message).where(Message.id == message_id).first() if message is None: return None - if message.answer == "" and message.status == MessageStatus.NORMAL: + if message.answer == "" and message.status in {MessageStatus.NORMAL, MessageStatus.PAUSED}: return None else: diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index be9a0e9279..0b3fcbe4ae 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -89,6 +89,7 @@ from tasks.disable_segments_from_index_task import disable_segments_from_index_t from tasks.document_indexing_update_task import document_indexing_update_task from tasks.enable_segments_to_index_task import enable_segments_to_index_task from tasks.recover_document_indexing_task import recover_document_indexing_task +from tasks.regenerate_summary_index_task import regenerate_summary_index_task from tasks.remove_document_from_index_task import remove_document_from_index_task from tasks.retry_document_indexing_task import retry_document_indexing_task from tasks.sync_website_document_indexing_task import sync_website_document_indexing_task @@ -211,6 +212,7 @@ class DatasetService: embedding_model_provider: str | None = None, embedding_model_name: str | None = None, retrieval_model: RetrievalModel | None = None, + summary_index_setting: dict | None = None, ): # check if dataset name already exists if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first(): @@ -253,6 +255,8 @@ class DatasetService: dataset.retrieval_model = retrieval_model.model_dump() if retrieval_model else 
None dataset.permission = permission or DatasetPermissionEnum.ONLY_ME dataset.provider = provider + if summary_index_setting is not None: + dataset.summary_index_setting = summary_index_setting db.session.add(dataset) db.session.flush() @@ -476,6 +480,11 @@ class DatasetService: if external_retrieval_model: dataset.retrieval_model = external_retrieval_model + # Update summary index setting if provided + summary_index_setting = data.get("summary_index_setting", None) + if summary_index_setting is not None: + dataset.summary_index_setting = summary_index_setting + # Update basic dataset properties dataset.name = data.get("name", dataset.name) dataset.description = data.get("description", dataset.description) @@ -564,6 +573,9 @@ class DatasetService: # update Retrieval model if data.get("retrieval_model"): filtered_data["retrieval_model"] = data["retrieval_model"] + # update summary index setting + if data.get("summary_index_setting"): + filtered_data["summary_index_setting"] = data.get("summary_index_setting") # update icon info if data.get("icon_info"): filtered_data["icon_info"] = data.get("icon_info") @@ -572,12 +584,27 @@ class DatasetService: db.session.query(Dataset).filter_by(id=dataset.id).update(filtered_data) db.session.commit() + # Reload dataset to get updated values + db.session.refresh(dataset) + # update pipeline knowledge base node data DatasetService._update_pipeline_knowledge_base_node_data(dataset, user.id) # Trigger vector index task if indexing technique changed if action: deal_dataset_vector_index_task.delay(dataset.id, action) + # If embedding_model changed, also regenerate summary vectors + if action == "update": + regenerate_summary_index_task.delay( + dataset.id, + regenerate_reason="embedding_model_changed", + regenerate_vectors_only=True, + ) + + # Note: summary_index_setting changes do not trigger automatic regeneration of existing summaries. + # The new setting will only apply to: + # 1. 
New documents added after the setting change + # 2. Manual summary generation requests return dataset @@ -616,6 +643,7 @@ class DatasetService: knowledge_index_node_data["chunk_structure"] = dataset.chunk_structure knowledge_index_node_data["indexing_technique"] = dataset.indexing_technique # pyright: ignore[reportAttributeAccessIssue] knowledge_index_node_data["keyword_number"] = dataset.keyword_number + knowledge_index_node_data["summary_index_setting"] = dataset.summary_index_setting node["data"] = knowledge_index_node_data updated = True except Exception: @@ -854,6 +882,54 @@ class DatasetService: ) filtered_data["collection_binding_id"] = dataset_collection_binding.id + @staticmethod + def _check_summary_index_setting_model_changed(dataset: Dataset, data: dict[str, Any]) -> bool: + """ + Check if summary_index_setting model (model_name or model_provider_name) has changed. + + Args: + dataset: Current dataset object + data: Update data dictionary + + Returns: + bool: True if summary model changed, False otherwise + """ + # Check if summary_index_setting is being updated + if "summary_index_setting" not in data or data.get("summary_index_setting") is None: + return False + + new_summary_setting = data.get("summary_index_setting") + old_summary_setting = dataset.summary_index_setting + + # If new setting is disabled, no need to regenerate + if not new_summary_setting or not new_summary_setting.get("enable"): + return False + + # If old setting doesn't exist, no need to regenerate (no existing summaries to regenerate) + # Note: This task only regenerates existing summaries, not generates new ones + if not old_summary_setting: + return False + + # Compare model_name and model_provider_name + old_model_name = old_summary_setting.get("model_name") + old_model_provider = old_summary_setting.get("model_provider_name") + new_model_name = new_summary_setting.get("model_name") + new_model_provider = new_summary_setting.get("model_provider_name") + + # Check if model 
changed + if old_model_name != new_model_name or old_model_provider != new_model_provider: + logger.info( + "Summary index setting model changed for dataset %s: old=%s/%s, new=%s/%s", + dataset.id, + old_model_provider, + old_model_name, + new_model_provider, + new_model_name, + ) + return True + + return False + @staticmethod def update_rag_pipeline_dataset_settings( session: Session, dataset: Dataset, knowledge_configuration: KnowledgeConfiguration, has_published: bool = False @@ -889,6 +965,9 @@ class DatasetService: else: raise ValueError("Invalid index method") dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump() + # Update summary_index_setting if provided + if knowledge_configuration.summary_index_setting is not None: + dataset.summary_index_setting = knowledge_configuration.summary_index_setting session.add(dataset) else: if dataset.chunk_structure and dataset.chunk_structure != knowledge_configuration.chunk_structure: @@ -994,6 +1073,9 @@ class DatasetService: if dataset.keyword_number != knowledge_configuration.keyword_number: dataset.keyword_number = knowledge_configuration.keyword_number dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump() + # Update summary_index_setting if provided + if knowledge_configuration.summary_index_setting is not None: + dataset.summary_index_setting = knowledge_configuration.summary_index_setting session.add(dataset) session.commit() if action: @@ -1314,6 +1396,50 @@ class DocumentService: upload_file = DocumentService._get_upload_file_for_upload_file_document(document) return file_helpers.get_signed_file_url(upload_file_id=upload_file.id, as_attachment=True) + @staticmethod + def enrich_documents_with_summary_index_status( + documents: Sequence[Document], + dataset: Dataset, + tenant_id: str, + ) -> None: + """ + Enrich documents with summary_index_status based on dataset summary index settings. 
+ + This method calculates and sets the summary_index_status for each document that needs summary. + Documents that don't need summary or when summary index is disabled will have status set to None. + + Args: + documents: List of Document instances to enrich + dataset: Dataset instance containing summary_index_setting + tenant_id: Tenant ID for summary status lookup + """ + # Check if dataset has summary index enabled + has_summary_index = dataset.summary_index_setting and dataset.summary_index_setting.get("enable") is True + + # Filter documents that need summary calculation + documents_need_summary = [doc for doc in documents if doc.need_summary is True] + document_ids_need_summary = [str(doc.id) for doc in documents_need_summary] + + # Calculate summary_index_status for documents that need summary (only if dataset summary index is enabled) + summary_status_map: dict[str, str | None] = {} + if has_summary_index and document_ids_need_summary: + from services.summary_index_service import SummaryIndexService + + summary_status_map = SummaryIndexService.get_documents_summary_index_status( + document_ids=document_ids_need_summary, + dataset_id=dataset.id, + tenant_id=tenant_id, + ) + + # Add summary_index_status to each document + for document in documents: + if has_summary_index and document.need_summary is True: + # Get status from map, default to None (not queued yet) + document.summary_index_status = summary_status_map.get(str(document.id)) # type: ignore[attr-defined] + else: + # Return null if summary index is not enabled or document doesn't need summary + document.summary_index_status = None # type: ignore[attr-defined] + @staticmethod def prepare_document_batch_download_zip( *, @@ -1964,6 +2090,8 @@ class DocumentService: DuplicateDocumentIndexingTaskProxy( dataset.tenant_id, dataset.id, duplicate_document_ids ).delay() + # Note: Summary index generation is triggered in document_indexing_task after indexing completes + # to ensure segments are available. 
See tasks/document_indexing_task.py except LockNotOwnedError: pass @@ -2268,6 +2396,11 @@ class DocumentService: name: str, batch: str, ): + # Set need_summary based on dataset's summary_index_setting + need_summary = False + if dataset.summary_index_setting and dataset.summary_index_setting.get("enable") is True: + need_summary = True + document = Document( tenant_id=dataset.tenant_id, dataset_id=dataset.id, @@ -2281,6 +2414,7 @@ class DocumentService: created_by=account.id, doc_form=document_form, doc_language=document_language, + need_summary=need_summary, ) doc_metadata = {} if dataset.built_in_field_enabled: @@ -2505,6 +2639,7 @@ class DocumentService: embedding_model_provider=knowledge_config.embedding_model_provider, collection_binding_id=dataset_collection_binding_id, retrieval_model=retrieval_model.model_dump() if retrieval_model else None, + summary_index_setting=knowledge_config.summary_index_setting, is_multimodal=knowledge_config.is_multimodal, ) @@ -2686,6 +2821,14 @@ class DocumentService: if not isinstance(args["process_rule"]["rules"]["segmentation"]["max_tokens"], int): raise ValueError("Process rule segmentation max_tokens is invalid") + # valid summary index setting + summary_index_setting = args["process_rule"].get("summary_index_setting") + if summary_index_setting and summary_index_setting.get("enable"): + if "model_name" not in summary_index_setting or not summary_index_setting["model_name"]: + raise ValueError("Summary index model name is required") + if "model_provider_name" not in summary_index_setting or not summary_index_setting["model_provider_name"]: + raise ValueError("Summary index model provider name is required") + @staticmethod def batch_update_document_status( dataset: Dataset, document_ids: list[str], action: Literal["enable", "disable", "archive", "un_archive"], user @@ -3154,6 +3297,35 @@ class SegmentService: if args.enabled or keyword_changed: # update segment vector index VectorService.update_segment_vector(args.keywords, 
segment, dataset) + # update summary index if summary is provided and has changed + if args.summary is not None: + # When user manually provides summary, allow saving even if summary_index_setting doesn't exist + # summary_index_setting is only needed for LLM generation, not for manual summary vectorization + # Vectorization uses dataset.embedding_model, which doesn't require summary_index_setting + if dataset.indexing_technique == "high_quality": + # Query existing summary from database + from models.dataset import DocumentSegmentSummary + + existing_summary = ( + db.session.query(DocumentSegmentSummary) + .where( + DocumentSegmentSummary.chunk_id == segment.id, + DocumentSegmentSummary.dataset_id == dataset.id, + ) + .first() + ) + + # Check if summary has changed + existing_summary_content = existing_summary.summary_content if existing_summary else None + if existing_summary_content != args.summary: + # Summary has changed, update it + from services.summary_index_service import SummaryIndexService + + try: + SummaryIndexService.update_summary_for_segment(segment, dataset, args.summary) + except Exception: + logger.exception("Failed to update summary for segment %s", segment.id) + # Don't fail the entire update if summary update fails else: segment_hash = helper.generate_text_hash(content) tokens = 0 @@ -3228,6 +3400,73 @@ class SegmentService: elif document.doc_form in (IndexStructureType.PARAGRAPH_INDEX, IndexStructureType.QA_INDEX): # update segment vector index VectorService.update_segment_vector(args.keywords, segment, dataset) + # Handle summary index when content changed + if dataset.indexing_technique == "high_quality": + from models.dataset import DocumentSegmentSummary + + existing_summary = ( + db.session.query(DocumentSegmentSummary) + .where( + DocumentSegmentSummary.chunk_id == segment.id, + DocumentSegmentSummary.dataset_id == dataset.id, + ) + .first() + ) + + if args.summary is None: + # User didn't provide summary, auto-regenerate if segment 
previously had summary + # Auto-regeneration only happens if summary_index_setting exists and enable is True + if ( + existing_summary + and dataset.summary_index_setting + and dataset.summary_index_setting.get("enable") is True + ): + # Segment previously had summary, regenerate it with new content + from services.summary_index_service import SummaryIndexService + + try: + SummaryIndexService.generate_and_vectorize_summary( + segment, dataset, dataset.summary_index_setting + ) + logger.info("Auto-regenerated summary for segment %s after content change", segment.id) + except Exception: + logger.exception("Failed to auto-regenerate summary for segment %s", segment.id) + # Don't fail the entire update if summary regeneration fails + else: + # User provided summary, check if it has changed + # Manual summary updates are allowed even if summary_index_setting doesn't exist + existing_summary_content = existing_summary.summary_content if existing_summary else None + if existing_summary_content != args.summary: + # Summary has changed, use user-provided summary + from services.summary_index_service import SummaryIndexService + + try: + SummaryIndexService.update_summary_for_segment(segment, dataset, args.summary) + logger.info("Updated summary for segment %s with user-provided content", segment.id) + except Exception: + logger.exception("Failed to update summary for segment %s", segment.id) + # Don't fail the entire update if summary update fails + else: + # Summary hasn't changed, regenerate based on new content + # Auto-regeneration only happens if summary_index_setting exists and enable is True + if ( + existing_summary + and dataset.summary_index_setting + and dataset.summary_index_setting.get("enable") is True + ): + from services.summary_index_service import SummaryIndexService + + try: + SummaryIndexService.generate_and_vectorize_summary( + segment, dataset, dataset.summary_index_setting + ) + logger.info( + "Regenerated summary for segment %s after content change 
@classmethod
def get_segments_by_document_and_dataset(
    cls,
    document_id: str,
    dataset_id: str,
    status: str | None = None,
    enabled: bool | None = None,
) -> Sequence[DocumentSegment]:
    """Fetch the segments belonging to one document within a dataset.

    Args:
        document_id: Document ID.
        dataset_id: Dataset ID.
        status: When given, keep only segments with this status (e.g. "completed").
        enabled: When given, keep only segments whose ``enabled`` flag matches.

    Returns:
        All matching DocumentSegment rows.
    """
    criteria = [
        DocumentSegment.document_id == document_id,
        DocumentSegment.dataset_id == dataset_id,
    ]
    if status is not None:
        criteria.append(DocumentSegment.status == status)
    if enabled is not None:
        criteria.append(DocumentSegment.enabled == enabled)
    # Multiple criteria passed to where() are AND-combined, equivalent to
    # the chained-where form.
    return db.session.scalars(select(DocumentSegment).where(*criteria)).all()
class SegmentUpdateArgs(BaseModel): regenerate_child_chunks: bool = False enabled: bool | None = None attachment_ids: list[str] | None = None + summary: str | None = None # Summary content for summary index class ChildChunkUpdateArgs(BaseModel): diff --git a/api/services/entities/knowledge_entities/rag_pipeline_entities.py b/api/services/entities/knowledge_entities/rag_pipeline_entities.py index cbb0efcc2a..041ae4edba 100644 --- a/api/services/entities/knowledge_entities/rag_pipeline_entities.py +++ b/api/services/entities/knowledge_entities/rag_pipeline_entities.py @@ -116,6 +116,8 @@ class KnowledgeConfiguration(BaseModel): embedding_model: str = "" keyword_number: int | None = 10 retrieval_model: RetrievalSetting + # add summary index setting + summary_index_setting: dict | None = None @field_validator("embedding_model_provider", mode="before") @classmethod diff --git a/api/services/feature_service.py b/api/services/feature_service.py index d94ae49d91..fda3a15144 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -138,6 +138,8 @@ class FeatureModel(BaseModel): is_allow_transfer_workspace: bool = True trigger_event: Quota = Quota(usage=0, limit=3000, reset_date=0) api_rate_limit: Quota = Quota(usage=0, limit=5000, reset_date=0) + # Controls whether email delivery is allowed for HumanInput nodes. 
@classmethod
def _resolve_human_input_email_delivery_enabled(cls, *, features: FeatureModel, tenant_id: str | None) -> bool:
    """Decide whether HumanInput nodes may deliver forms via email.

    Enterprise deployments and installations without billing always allow
    it; otherwise delivery is reserved for paid cloud plans of a known
    tenant.
    """
    # Self-hosted/enterprise or billing-disabled installs have no plan gating.
    if dify_config.ENTERPRISE_ENABLED or not dify_config.BILLING_ENABLED:
        return True
    # Without a tenant there is no subscription to inspect.
    if not tenant_id:
        return False
    eligible_plans = (CloudPlan.PROFESSIONAL, CloudPlan.TEAM)
    return features.billing.enabled and features.billing.subscription.plan in eligible_plans
def _build_form_link(token: str | None) -> str | None:
    """Build the public URL for a form token.

    Returns None when the token is missing or ``APP_WEB_URL`` is unset/empty.
    """
    if not token:
        return None
    web_root = dify_config.APP_WEB_URL
    if not web_root:
        return None
    return "/".join((web_root.rstrip("/"), "form", token))
class DeliveryTestRegistry:
    """Ordered collection of delivery-test handlers.

    Routes a test send to the first registered handler that claims support
    for the given delivery method.
    """

    def __init__(self, handlers: list[DeliveryTestHandler] | None = None) -> None:
        # Copy so later register() calls never mutate the caller's list.
        self._registered = [] if handlers is None else list(handlers)

    def register(self, handler: DeliveryTestHandler) -> None:
        """Append a handler; earlier handlers take precedence in dispatch."""
        self._registered.append(handler)

    def dispatch(
        self,
        *,
        context: DeliveryTestContext,
        method: DeliveryChannelConfig,
    ) -> DeliveryTestResult:
        """Send a test via the first handler supporting *method*.

        Raises:
            DeliveryTestUnsupportedError: no registered handler supports it.
        """
        chosen = next((candidate for candidate in self._registered if candidate.supports(method)), None)
        if chosen is None:
            raise DeliveryTestUnsupportedError("Delivery method does not support test send.")
        return chosen.send_test(context=context, method=method)

    @classmethod
    def default(cls) -> DeliveryTestRegistry:
        """Registry preconfigured with the built-in email handler."""
        return cls([EmailDeliveryTestHandler()])
DeliveryTestError("Mail client is not initialized.") + + recipients = self._resolve_recipients( + tenant_id=context.tenant_id, + method=method, + ) + if not recipients: + raise DeliveryTestError("No recipients configured for delivery method.") + + delivered: list[str] = [] + for recipient_email in recipients: + substitutions = self._build_substitutions( + context=context, + recipient_email=recipient_email, + ) + subject = render_email_template(method.config.subject, substitutions) + templated_body = EmailDeliveryConfig.render_body_template( + body=method.config.body, + url=substitutions.get("form_link"), + variable_pool=context.variable_pool, + ) + body = render_email_template(templated_body, substitutions) + + mail.send( + to=recipient_email, + subject=subject, + html=body, + ) + delivered.append(recipient_email) + + return DeliveryTestResult(status=DeliveryTestStatus.OK, delivered_to=delivered) + + def _resolve_recipients(self, *, tenant_id: str, method: EmailDeliveryMethod) -> list[str]: + recipients = method.config.recipients + emails: list[str] = [] + member_user_ids: list[str] = [] + for recipient in recipients.items: + if isinstance(recipient, MemberRecipient): + member_user_ids.append(recipient.user_id) + elif isinstance(recipient, ExternalRecipient): + if recipient.email: + emails.append(recipient.email) + + if recipients.whole_workspace: + member_user_ids = [] + member_emails = self._query_workspace_member_emails(tenant_id=tenant_id, user_ids=None) + emails.extend(member_emails.values()) + elif member_user_ids: + member_emails = self._query_workspace_member_emails(tenant_id=tenant_id, user_ids=member_user_ids) + for user_id in member_user_ids: + email = member_emails.get(user_id) + if email: + emails.append(email) + + return list(dict.fromkeys([email for email in emails if email])) + + def _query_workspace_member_emails( + self, + *, + tenant_id: str, + user_ids: list[str] | None, + ) -> dict[str, str]: + if user_ids is None: + unique_ids = None + else: + 
unique_ids = {user_id for user_id in user_ids if user_id} + if not unique_ids: + return {} + + stmt = ( + select(Account.id, Account.email) + .join(TenantAccountJoin, TenantAccountJoin.account_id == Account.id) + .where(TenantAccountJoin.tenant_id == tenant_id) + ) + if unique_ids is not None: + stmt = stmt.where(Account.id.in_(unique_ids)) + + with self._session_factory() as session: + rows = session.execute(stmt).all() + return dict(rows) + + @staticmethod + def _build_substitutions( + *, + context: DeliveryTestContext, + recipient_email: str, + ) -> dict[str, str]: + raw_values: dict[str, str | None] = { + "form_id": "", + "node_title": context.node_title, + "workflow_run_id": "", + "form_token": "", + "form_link": "", + "form_content": context.rendered_content, + "recipient_email": recipient_email, + } + substitutions = {key: value or "" for key, value in raw_values.items()} + if context.template_vars: + substitutions.update({key: value for key, value in context.template_vars.items() if value is not None}) + token = next( + (recipient.form_token for recipient in context.recipients if recipient.email == recipient_email), + None, + ) + if token: + substitutions["form_token"] = token + substitutions["form_link"] = _build_form_link(token) or "" + return substitutions diff --git a/api/services/human_input_service.py b/api/services/human_input_service.py new file mode 100644 index 0000000000..76b6e6e0e6 --- /dev/null +++ b/api/services/human_input_service.py @@ -0,0 +1,250 @@ +import logging +from collections.abc import Mapping +from datetime import datetime, timedelta +from typing import Any + +from sqlalchemy import Engine, select +from sqlalchemy.orm import Session, sessionmaker + +from configs import dify_config +from core.repositories.human_input_repository import ( + HumanInputFormRecord, + HumanInputFormSubmissionRepository, +) +from core.workflow.nodes.human_input.entities import ( + FormDefinition, + HumanInputSubmissionValidationError, + 
class Form:
    """Read-only view over a persisted human-input form record.

    Thin delegation wrapper: every accessor forwards to the underlying
    ``HumanInputFormRecord`` without transformation.
    """

    def __init__(self, record: HumanInputFormRecord):
        self._row = record

    def get_definition(self) -> FormDefinition:
        """Structured form definition (inputs and user actions)."""
        return self._row.definition

    @property
    def id(self) -> str:
        # Exposed as ``id`` although the record stores it as ``form_id``.
        return self._row.form_id

    @property
    def submitted(self) -> bool:
        return self._row.submitted

    @property
    def status(self) -> HumanInputFormStatus:
        return self._row.status

    @property
    def form_kind(self) -> HumanInputFormKind:
        return self._row.form_kind

    @property
    def workflow_run_id(self) -> str | None:
        """Workflow run id for runtime forms; None for delivery tests."""
        return self._row.workflow_run_id

    @property
    def tenant_id(self) -> str:
        return self._row.tenant_id

    @property
    def app_id(self) -> str:
        return self._row.app_id

    @property
    def recipient_id(self) -> str | None:
        return self._row.recipient_id

    @property
    def recipient_type(self) -> RecipientType | None:
        return self._row.recipient_type

    @property
    def created_at(self) -> datetime:
        return self._row.created_at

    @property
    def expiration_time(self) -> datetime:
        return self._row.expiration_time
class WebAppDeliveryNotEnabledError(HumanInputError):
    """Raised by form submission when no form matches the given token and
    recipient type (see ``HumanInputService.submit_form_by_token``).

    Fix: the original declared bases ``(HumanInputError, BaseException)``.
    The extra ``BaseException`` base was redundant — ``HumanInputError``
    already derives from ``Exception`` — and deriving application errors
    from ``BaseException`` directly is an anti-pattern, so it has been
    removed; behavior of any existing ``except`` clause is unchanged.

    NOTE(review): every sibling error in this module subclasses
    ``BaseHTTPException`` with an ``error_code``/``code``; this one was
    presumably meant to as well — confirm the intended HTTP status before
    migrating it.
    """

    pass
def ensure_form_active(self, form: Form) -> None:
    """Assert that *form* can still accept a submission.

    Raises:
        FormSubmittedError: the form was already submitted.
        FormExpiredError: the form's status is TIMEOUT/EXPIRED, its own
            expiration time has passed, or the global timeout elapsed.
    """
    # Check order matters: an already-submitted form must report
    # "submitted" even if it is also past its expiration time.
    if form.submitted:
        raise FormSubmittedError(form.id)
    # Statuses persisted by background expiry take precedence over the
    # wall-clock checks below.
    if form.status in {HumanInputFormStatus.TIMEOUT, HumanInputFormStatus.EXPIRED}:
        raise FormExpiredError(form.id)
    now = naive_utc_now()
    # Timestamps are compared as naive UTC — assumes stored datetimes are
    # UTC (see libs.datetime_utils); TODO confirm for non-UTC deployments.
    if ensure_naive_utc(form.expiration_time) <= now:
        raise FormExpiredError(form.id)
    if self._is_globally_expired(form, now=now):
        raise FormExpiredError(form.id)
def _is_globally_expired(self, form: Form, *, now: datetime | None = None) -> bool:
    """Return True when the form has outlived the global human-input timeout.

    A non-positive ``HUMAN_INPUT_GLOBAL_TIMEOUT_SECONDS`` disables the
    check, and forms without a workflow run (delivery tests) are never
    globally expired.
    """
    timeout_seconds = dify_config.HUMAN_INPUT_GLOBAL_TIMEOUT_SECONDS
    if timeout_seconds <= 0 or form.workflow_run_id is None:
        return False
    reference_time = now or naive_utc_now()
    deadline = ensure_naive_utc(form.created_at) + timedelta(seconds=timeout_seconds)
    return deadline <= reference_time
core.app.entities.app_invoke_entities import InvokeFrom from core.llm_generator.llm_generator import LLMGenerator @@ -14,6 +17,10 @@ from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account from models.model import App, AppMode, AppModelConfig, EndUser, Message, MessageFeedback +from repositories.execution_extra_content_repository import ExecutionExtraContentRepository +from repositories.sqlalchemy_execution_extra_content_repository import ( + SQLAlchemyExecutionExtraContentRepository, +) from services.conversation_service import ConversationService from services.errors.message import ( FirstMessageNotExistsError, @@ -24,6 +31,23 @@ from services.errors.message import ( from services.workflow_service import WorkflowService +def _create_execution_extra_content_repository() -> ExecutionExtraContentRepository: + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + return SQLAlchemyExecutionExtraContentRepository(session_maker=session_maker) + + +def attach_message_extra_contents(messages: Sequence[Message]) -> None: + if not messages: + return + + repository = _create_execution_extra_content_repository() + extra_contents_lists = repository.get_by_message_ids([message.id for message in messages]) + + for index, message in enumerate(messages): + contents = extra_contents_lists[index] if index < len(extra_contents_lists) else [] + message.set_extra_contents([content.model_dump(mode="json", exclude_none=True) for content in contents]) + + class MessageService: @classmethod def pagination_by_first_id( @@ -85,6 +109,8 @@ class MessageService: if order == "asc": history_messages = list(reversed(history_messages)) + attach_message_extra_contents(history_messages) + return InfiniteScrollPagination(data=history_messages, limit=limit, has_more=has_more) @classmethod diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py 
b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index c1c6e204fb..be1ce834f6 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -343,6 +343,9 @@ class RagPipelineDslService: dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider elif knowledge_configuration.indexing_technique == "economy": dataset.keyword_number = knowledge_configuration.keyword_number + # Update summary_index_setting if provided + if knowledge_configuration.summary_index_setting is not None: + dataset.summary_index_setting = knowledge_configuration.summary_index_setting dataset.pipeline_id = pipeline.id self._session.add(dataset) self._session.commit() @@ -477,6 +480,9 @@ class RagPipelineDslService: dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider elif knowledge_configuration.indexing_technique == "economy": dataset.keyword_number = knowledge_configuration.keyword_number + # Update summary_index_setting if provided + if knowledge_configuration.summary_index_setting is not None: + dataset.summary_index_setting = knowledge_configuration.summary_index_setting dataset.pipeline_id = pipeline.id self._session.add(dataset) self._session.commit() diff --git a/api/services/summary_index_service.py b/api/services/summary_index_service.py new file mode 100644 index 0000000000..b8e1f8bc3f --- /dev/null +++ b/api/services/summary_index_service.py @@ -0,0 +1,1432 @@ +"""Summary index service for generating and managing document segment summaries.""" + +import logging +import time +import uuid +from datetime import UTC, datetime +from typing import Any + +from sqlalchemy.orm import Session + +from core.db.session_factory import session_factory +from core.model_manager import ModelManager +from core.model_runtime.entities.llm_entities import LLMUsage +from core.model_runtime.entities.model_entities import ModelType +from core.rag.datasource.vdb.vector_factory 
import Vector +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.models.document import Document +from libs import helper +from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary +from models.dataset import Document as DatasetDocument + +logger = logging.getLogger(__name__) + + +class SummaryIndexService: + """Service for generating and managing summary indexes.""" + + @staticmethod + def generate_summary_for_segment( + segment: DocumentSegment, + dataset: Dataset, + summary_index_setting: dict, + ) -> tuple[str, LLMUsage]: + """ + Generate summary for a single segment. + + Args: + segment: DocumentSegment to generate summary for + dataset: Dataset containing the segment + summary_index_setting: Summary index configuration + + Returns: + Tuple of (summary_content, llm_usage) where llm_usage is LLMUsage object + + Raises: + ValueError: If summary_index_setting is invalid or generation fails + """ + # Reuse the existing generate_summary method from ParagraphIndexProcessor + # Use lazy import to avoid circular import + from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor + + summary_content, usage = ParagraphIndexProcessor.generate_summary( + tenant_id=dataset.tenant_id, + text=segment.content, + summary_index_setting=summary_index_setting, + segment_id=segment.id, + ) + + if not summary_content: + raise ValueError("Generated summary is empty") + + return summary_content, usage + + @staticmethod + def create_summary_record( + segment: DocumentSegment, + dataset: Dataset, + summary_content: str, + status: str = "generating", + ) -> DocumentSegmentSummary: + """ + Create or update a DocumentSegmentSummary record. + If a summary record already exists for this segment, it will be updated instead of creating a new one. 
+ + Args: + segment: DocumentSegment to create summary for + dataset: Dataset containing the segment + summary_content: Generated summary content + status: Summary status (default: "generating") + + Returns: + Created or updated DocumentSegmentSummary instance + """ + with session_factory.create_session() as session: + # Check if summary record already exists + existing_summary = ( + session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() + ) + + if existing_summary: + # Update existing record + existing_summary.summary_content = summary_content + existing_summary.status = status + existing_summary.error = None # type: ignore[assignment] # Clear any previous errors + # Re-enable if it was disabled + if not existing_summary.enabled: + existing_summary.enabled = True + existing_summary.disabled_at = None + existing_summary.disabled_by = None + session.add(existing_summary) + session.flush() + return existing_summary + else: + # Create new record (enabled by default) + summary_record = DocumentSegmentSummary( + dataset_id=dataset.id, + document_id=segment.document_id, + chunk_id=segment.id, + summary_content=summary_content, + status=status, + enabled=True, # Explicitly set enabled to True + ) + session.add(summary_record) + session.flush() + return summary_record + + @staticmethod + def vectorize_summary( + summary_record: DocumentSegmentSummary, + segment: DocumentSegment, + dataset: Dataset, + session: Session | None = None, + ) -> None: + """ + Vectorize summary and store in vector database. + + Args: + summary_record: DocumentSegmentSummary record + segment: Original DocumentSegment + dataset: Dataset containing the segment + session: Optional SQLAlchemy session. If provided, uses this session instead of creating a new one. + If not provided, creates a new session and commits automatically. 
+ """ + if dataset.indexing_technique != "high_quality": + logger.warning( + "Summary vectorization skipped for dataset %s: indexing_technique is not high_quality", + dataset.id, + ) + return + + # Get summary_record_id for later session queries + summary_record_id = summary_record.id + # Save the original session parameter for use in error handling + original_session = session + logger.debug( + "Starting vectorization for segment %s, summary_record_id=%s, using_provided_session=%s", + segment.id, + summary_record_id, + original_session is not None, + ) + + # Reuse existing index_node_id if available (like segment does), otherwise generate new one + old_summary_node_id = summary_record.summary_index_node_id + if old_summary_node_id: + # Reuse existing index_node_id (like segment behavior) + summary_index_node_id = old_summary_node_id + logger.debug("Reusing existing index_node_id %s for segment %s", summary_index_node_id, segment.id) + else: + # Generate new index node ID only for new summaries + summary_index_node_id = str(uuid.uuid4()) + logger.debug("Generated new index_node_id %s for segment %s", summary_index_node_id, segment.id) + + # Always regenerate hash (in case summary content changed) + summary_content = summary_record.summary_content + if not summary_content or not summary_content.strip(): + raise ValueError(f"Summary content is empty for segment {segment.id}, cannot vectorize") + summary_hash = helper.generate_text_hash(summary_content) + + # Delete old vector only if we're reusing the same index_node_id (to overwrite) + # If index_node_id changed, the old vector should have been deleted elsewhere + if old_summary_node_id and old_summary_node_id == summary_index_node_id: + try: + vector = Vector(dataset) + vector.delete_by_ids([old_summary_node_id]) + except Exception as e: + logger.warning( + "Failed to delete old summary vector for segment %s: %s. 
Continuing with new vectorization.", + segment.id, + str(e), + ) + + # Calculate embedding tokens for summary (for logging and statistics) + embedding_tokens = 0 + try: + model_manager = ModelManager() + embedding_model = model_manager.get_model_instance( + tenant_id=dataset.tenant_id, + provider=dataset.embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=dataset.embedding_model, + ) + if embedding_model: + tokens_list = embedding_model.get_text_embedding_num_tokens([summary_content]) + embedding_tokens = tokens_list[0] if tokens_list else 0 + except Exception as e: + logger.warning("Failed to calculate embedding tokens for summary: %s", str(e)) + + # Create document with summary content and metadata + summary_document = Document( + page_content=summary_content, + metadata={ + "doc_id": summary_index_node_id, + "doc_hash": summary_hash, + "dataset_id": dataset.id, + "document_id": segment.document_id, + "original_chunk_id": segment.id, # Key: link to original chunk + "doc_type": DocType.TEXT, + "is_summary": True, # Identifier for summary documents + }, + ) + + # Vectorize and store with retry mechanism for connection errors + max_retries = 3 + retry_delay = 2.0 + + for attempt in range(max_retries): + try: + logger.debug( + "Attempting to vectorize summary for segment %s (attempt %s/%s)", + segment.id, + attempt + 1, + max_retries, + ) + vector = Vector(dataset) + # Use duplicate_check=False to ensure re-vectorization even if old vector still exists + # The old vector should have been deleted above, but if deletion failed, + # we still want to re-vectorize (upsert will overwrite) + vector.add_texts([summary_document], duplicate_check=False) + logger.debug( + "Successfully added summary vector to database for segment %s (attempt %s/%s)", + segment.id, + attempt + 1, + max_retries, + ) + + # Log embedding token usage + if embedding_tokens > 0: + logger.info( + "Summary embedding for segment %s used %s tokens", + segment.id, + embedding_tokens, + 
) + + # Success - update summary record with index node info + # Use provided session if available, otherwise create a new one + use_provided_session = session is not None + if not use_provided_session: + logger.debug("Creating new session for vectorization of segment %s", segment.id) + session_context = session_factory.create_session() + session = session_context.__enter__() + else: + logger.debug("Using provided session for vectorization of segment %s", segment.id) + session_context = None # Don't use context manager for provided session + + # At this point, session is guaranteed to be not None + # Type narrowing: session is definitely not None after the if/else above + if session is None: + raise RuntimeError("Session should not be None at this point") + + try: + # Declare summary_record_in_session variable + summary_record_in_session: DocumentSegmentSummary | None + + # If using provided session, merge the summary_record into it + if use_provided_session: + # Merge the summary_record into the provided session + logger.debug( + "Merging summary_record (id=%s) into provided session for segment %s", + summary_record_id, + segment.id, + ) + summary_record_in_session = session.merge(summary_record) + logger.debug( + "Successfully merged summary_record for segment %s, merged_id=%s", + segment.id, + summary_record_in_session.id, + ) + else: + # Query the summary record in the new session + logger.debug( + "Querying summary_record by id=%s for segment %s in new session", + summary_record_id, + segment.id, + ) + summary_record_in_session = ( + session.query(DocumentSegmentSummary).filter_by(id=summary_record_id).first() + ) + + if not summary_record_in_session: + # Record not found - try to find by chunk_id and dataset_id instead + logger.debug( + "Summary record not found by id=%s, trying chunk_id=%s and dataset_id=%s " + "for segment %s", + summary_record_id, + segment.id, + dataset.id, + segment.id, + ) + summary_record_in_session = ( + 
session.query(DocumentSegmentSummary) + .filter_by(chunk_id=segment.id, dataset_id=dataset.id) + .first() + ) + + if not summary_record_in_session: + # Still not found - create a new one using the parameter data + logger.warning( + "Summary record not found in database for segment %s (id=%s), creating new one. " + "This may indicate a session isolation issue.", + segment.id, + summary_record_id, + ) + summary_record_in_session = DocumentSegmentSummary( + id=summary_record_id, # Use the same ID if available + dataset_id=dataset.id, + document_id=segment.document_id, + chunk_id=segment.id, + summary_content=summary_content, + summary_index_node_id=summary_index_node_id, + summary_index_node_hash=summary_hash, + tokens=embedding_tokens, + status="completed", + enabled=True, + ) + session.add(summary_record_in_session) + logger.info( + "Created new summary record (id=%s) for segment %s after vectorization", + summary_record_id, + segment.id, + ) + else: + # Found by chunk_id - update it + logger.info( + "Found summary record for segment %s by chunk_id " + "(id mismatch: expected %s, found %s). 
" + "This may indicate the record was created in a different session.", + segment.id, + summary_record_id, + summary_record_in_session.id, + ) + else: + logger.debug( + "Found summary_record (id=%s) for segment %s in new session", + summary_record_id, + segment.id, + ) + + # At this point, summary_record_in_session is guaranteed to be not None + if summary_record_in_session is None: + raise RuntimeError("summary_record_in_session should not be None at this point") + + # Update all fields including summary_content + # Always use the summary_content from the parameter (which is the latest from outer session) + # rather than relying on what's in the database, in case outer session hasn't committed yet + summary_record_in_session.summary_index_node_id = summary_index_node_id + summary_record_in_session.summary_index_node_hash = summary_hash + summary_record_in_session.tokens = embedding_tokens # Save embedding tokens + summary_record_in_session.status = "completed" + # Ensure summary_content is preserved (use the latest from summary_record parameter) + # This is critical: use the parameter value, not the database value + summary_record_in_session.summary_content = summary_content + # Explicitly update updated_at to ensure it's refreshed even if other fields haven't changed + summary_record_in_session.updated_at = datetime.now(UTC).replace(tzinfo=None) + session.add(summary_record_in_session) + + # Only commit if we created the session ourselves + if not use_provided_session: + logger.debug("Committing session for segment %s (self-created session)", segment.id) + session.commit() + logger.debug("Successfully committed session for segment %s", segment.id) + else: + # When using provided session, flush to ensure changes are written to database + # This prevents refresh() from overwriting our changes + logger.debug( + "Flushing session for segment %s (using provided session, caller will commit)", + segment.id, + ) + session.flush() + logger.debug("Successfully flushed 
session for segment %s", segment.id) + # If using provided session, let the caller handle commit + + logger.info( + "Successfully vectorized summary for segment %s, index_node_id=%s, index_node_hash=%s, " + "tokens=%s, summary_record_id=%s, use_provided_session=%s", + segment.id, + summary_index_node_id, + summary_hash, + embedding_tokens, + summary_record_in_session.id, + use_provided_session, + ) + # Update the original object for consistency + summary_record.summary_index_node_id = summary_index_node_id + summary_record.summary_index_node_hash = summary_hash + summary_record.tokens = embedding_tokens + summary_record.status = "completed" + summary_record.summary_content = summary_content + if summary_record_in_session.updated_at: + summary_record.updated_at = summary_record_in_session.updated_at + finally: + # Only close session if we created it ourselves + if not use_provided_session and session_context: + session_context.__exit__(None, None, None) + # Success, exit function + return + + except (ConnectionError, Exception) as e: + error_str = str(e).lower() + # Check if it's a connection-related error that might be transient + is_connection_error = any( + keyword in error_str + for keyword in [ + "connection", + "disconnected", + "timeout", + "network", + "could not connect", + "server disconnected", + "weaviate", + ] + ) + + if is_connection_error and attempt < max_retries - 1: + # Retry for connection errors + wait_time = retry_delay * (2**attempt) # Exponential backoff + logger.warning( + "Vectorization attempt %s/%s failed for segment %s (connection error): %s. " + "Retrying in %.1f seconds...", + attempt + 1, + max_retries, + segment.id, + str(e), + wait_time, + ) + time.sleep(wait_time) + continue + else: + # Final attempt failed or non-connection error - log and update status + logger.error( + "Failed to vectorize summary for segment %s after %s attempts: %s. 
" + "summary_record_id=%s, index_node_id=%s, use_provided_session=%s", + segment.id, + attempt + 1, + str(e), + summary_record_id, + summary_index_node_id, + session is not None, + exc_info=True, + ) + # Update error status in session + # Use the original_session saved at function start (the function parameter) + logger.debug( + "Updating error status for segment %s, summary_record_id=%s, has_original_session=%s", + segment.id, + summary_record_id, + original_session is not None, + ) + # Always create a new session for error handling to avoid issues with closed sessions + # Even if original_session was provided, we create a new one for safety + with session_factory.create_session() as error_session: + # Try to find the record by id first + # Note: Using assignment only (no type annotation) to avoid redeclaration error + summary_record_in_session = ( + error_session.query(DocumentSegmentSummary).filter_by(id=summary_record_id).first() + ) + if not summary_record_in_session: + # Try to find by chunk_id and dataset_id + logger.debug( + "Summary record not found by id=%s, trying chunk_id=%s and dataset_id=%s " + "for segment %s", + summary_record_id, + segment.id, + dataset.id, + segment.id, + ) + summary_record_in_session = ( + error_session.query(DocumentSegmentSummary) + .filter_by(chunk_id=segment.id, dataset_id=dataset.id) + .first() + ) + + if summary_record_in_session: + summary_record_in_session.status = "error" + summary_record_in_session.error = f"Vectorization failed: {str(e)}" + summary_record_in_session.updated_at = datetime.now(UTC).replace(tzinfo=None) + error_session.add(summary_record_in_session) + error_session.commit() + logger.info( + "Updated error status in new session for segment %s, record_id=%s", + segment.id, + summary_record_in_session.id, + ) + # Update the original object for consistency + summary_record.status = "error" + summary_record.error = summary_record_in_session.error + summary_record.updated_at = 
summary_record_in_session.updated_at + else: + logger.warning( + "Could not update error status: summary record not found for segment %s (id=%s). " + "This may indicate a session isolation issue.", + segment.id, + summary_record_id, + ) + raise + + @staticmethod + def batch_create_summary_records( + segments: list[DocumentSegment], + dataset: Dataset, + status: str = "not_started", + ) -> None: + """ + Batch create summary records for segments with specified status. + If a record already exists, update its status. + + Args: + segments: List of DocumentSegment instances + dataset: Dataset containing the segments + status: Initial status for the records (default: "not_started") + """ + segment_ids = [segment.id for segment in segments] + if not segment_ids: + return + + with session_factory.create_session() as session: + # Query existing summary records + existing_summaries = ( + session.query(DocumentSegmentSummary) + .filter( + DocumentSegmentSummary.chunk_id.in_(segment_ids), + DocumentSegmentSummary.dataset_id == dataset.id, + ) + .all() + ) + existing_summary_map = {summary.chunk_id: summary for summary in existing_summaries} + + # Create or update records + for segment in segments: + existing_summary = existing_summary_map.get(segment.id) + if existing_summary: + # Update existing record + existing_summary.status = status + existing_summary.error = None # type: ignore[assignment] # Clear any previous errors + if not existing_summary.enabled: + existing_summary.enabled = True + existing_summary.disabled_at = None + existing_summary.disabled_by = None + session.add(existing_summary) + else: + # Create new record + summary_record = DocumentSegmentSummary( + dataset_id=dataset.id, + document_id=segment.document_id, + chunk_id=segment.id, + summary_content=None, # Will be filled later + status=status, + enabled=True, + ) + session.add(summary_record) + + @staticmethod + def update_summary_record_error( + segment: DocumentSegment, + dataset: Dataset, + error: str, + 
) -> None: + """ + Update summary record with error status. + + Args: + segment: DocumentSegment + dataset: Dataset containing the segment + error: Error message + """ + with session_factory.create_session() as session: + summary_record = ( + session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() + ) + + if summary_record: + summary_record.status = "error" + summary_record.error = error + session.add(summary_record) + session.commit() + else: + logger.warning("Summary record not found for segment %s when updating error", segment.id) + + @staticmethod + def generate_and_vectorize_summary( + segment: DocumentSegment, + dataset: Dataset, + summary_index_setting: dict, + ) -> DocumentSegmentSummary: + """ + Generate summary for a segment and vectorize it. + Assumes summary record already exists (created by batch_create_summary_records). + + Args: + segment: DocumentSegment to generate summary for + dataset: Dataset containing the segment + summary_index_setting: Summary index configuration + + Returns: + Created DocumentSegmentSummary instance + + Raises: + ValueError: If summary generation fails + """ + with session_factory.create_session() as session: + try: + # Get or refresh summary record in this session + summary_record_in_session = ( + session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() + ) + + if not summary_record_in_session: + # If not found, create one + logger.warning("Summary record not found for segment %s, creating one", segment.id) + summary_record_in_session = DocumentSegmentSummary( + dataset_id=dataset.id, + document_id=segment.document_id, + chunk_id=segment.id, + summary_content="", + status="generating", + enabled=True, + ) + session.add(summary_record_in_session) + session.flush() + + # Update status to "generating" + summary_record_in_session.status = "generating" + summary_record_in_session.error = None # type: ignore[assignment] + 
session.add(summary_record_in_session) + # Don't flush here - wait until after vectorization succeeds + + # Generate summary (returns summary_content and llm_usage) + summary_content, llm_usage = SummaryIndexService.generate_summary_for_segment( + segment, dataset, summary_index_setting + ) + + # Update summary content + summary_record_in_session.summary_content = summary_content + session.add(summary_record_in_session) + # Flush to ensure summary_content is saved before vectorize_summary queries it + session.flush() + + # Log LLM usage for summary generation + if llm_usage and llm_usage.total_tokens > 0: + logger.info( + "Summary generation for segment %s used %s tokens (prompt: %s, completion: %s)", + segment.id, + llm_usage.total_tokens, + llm_usage.prompt_tokens, + llm_usage.completion_tokens, + ) + + # Vectorize summary (will delete old vector if exists before creating new one) + # Pass the session-managed record to vectorize_summary + # vectorize_summary will update status to "completed" and tokens in its own session + # vectorize_summary will also ensure summary_content is preserved + try: + # Pass the session to vectorize_summary to avoid session isolation issues + SummaryIndexService.vectorize_summary(summary_record_in_session, segment, dataset, session=session) + # Refresh the object from database to get the updated status and tokens from vectorize_summary + session.refresh(summary_record_in_session) + # Commit the session + # (summary_record_in_session should have status="completed" and tokens from refresh) + session.commit() + logger.info("Successfully generated and vectorized summary for segment %s", segment.id) + return summary_record_in_session + except Exception as vectorize_error: + # If vectorization fails, update status to error in current session + logger.exception("Failed to vectorize summary for segment %s", segment.id) + summary_record_in_session.status = "error" + summary_record_in_session.error = f"Vectorization failed: 
{str(vectorize_error)}" + session.add(summary_record_in_session) + session.commit() + raise + + except Exception as e: + logger.exception("Failed to generate summary for segment %s", segment.id) + # Update summary record with error status + summary_record_in_session = ( + session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() + ) + if summary_record_in_session: + summary_record_in_session.status = "error" + summary_record_in_session.error = str(e) + session.add(summary_record_in_session) + session.commit() + raise + + @staticmethod + def generate_summaries_for_document( + dataset: Dataset, + document: DatasetDocument, + summary_index_setting: dict, + segment_ids: list[str] | None = None, + only_parent_chunks: bool = False, + ) -> list[DocumentSegmentSummary]: + """ + Generate summaries for all segments in a document including vectorization. + + Args: + dataset: Dataset containing the document + document: DatasetDocument to generate summaries for + summary_index_setting: Summary index configuration + segment_ids: Optional list of specific segment IDs to process + only_parent_chunks: If True, only process parent chunks (for parent-child mode) + + Returns: + List of created DocumentSegmentSummary instances + """ + # Only generate summary index for high_quality indexing technique + if dataset.indexing_technique != "high_quality": + logger.info( + "Skipping summary generation for dataset %s: indexing_technique is %s, not 'high_quality'", + dataset.id, + dataset.indexing_technique, + ) + return [] + + if not summary_index_setting or not summary_index_setting.get("enable"): + logger.info("Summary index is disabled for dataset %s", dataset.id) + return [] + + # Skip qa_model documents + if document.doc_form == "qa_model": + logger.info("Skipping summary generation for qa_model document %s", document.id) + return [] + + logger.info( + "Starting summary generation for document %s in dataset %s, segment_ids: %s, only_parent_chunks: 
%s", + document.id, + dataset.id, + len(segment_ids) if segment_ids else "all", + only_parent_chunks, + ) + + with session_factory.create_session() as session: + # Query segments (only enabled segments) + query = session.query(DocumentSegment).filter_by( + dataset_id=dataset.id, + document_id=document.id, + status="completed", + enabled=True, # Only generate summaries for enabled segments + ) + + if segment_ids: + query = query.filter(DocumentSegment.id.in_(segment_ids)) + + segments = query.all() + + if not segments: + logger.info("No segments found for document %s", document.id) + return [] + + # Batch create summary records with "not_started" status before processing + # This ensures all records exist upfront, allowing status tracking + SummaryIndexService.batch_create_summary_records( + segments=segments, + dataset=dataset, + status="not_started", + ) + session.commit() # Commit initial records + + summary_records = [] + + for segment in segments: + # For parent-child mode, only process parent chunks + # In parent-child mode, all DocumentSegments are parent chunks, + # so we process all of them. Child chunks are stored in ChildChunk table + # and are not DocumentSegments, so they won't be in the segments list. + # This check is mainly for clarity and future-proofing. 
+ if only_parent_chunks: + # In parent-child mode, all segments in the query are parent chunks + # Child chunks are not DocumentSegments, so they won't appear here + # We can process all segments + pass + + try: + summary_record = SummaryIndexService.generate_and_vectorize_summary( + segment, dataset, summary_index_setting + ) + summary_records.append(summary_record) + except Exception as e: + logger.exception("Failed to generate summary for segment %s", segment.id) + # Update summary record with error status + SummaryIndexService.update_summary_record_error( + segment=segment, + dataset=dataset, + error=str(e), + ) + # Continue with other segments + continue + + logger.info( + "Completed summary generation for document %s: %s summaries generated and vectorized", + document.id, + len(summary_records), + ) + return summary_records + + @staticmethod + def disable_summaries_for_segments( + dataset: Dataset, + segment_ids: list[str] | None = None, + disabled_by: str | None = None, + ) -> None: + """ + Disable summary records and remove vectors from vector database for segments. + Unlike delete, this preserves the summary records but marks them as disabled. + + Args: + dataset: Dataset containing the segments + segment_ids: List of segment IDs to disable summaries for. If None, disable all. 
+ disabled_by: User ID who disabled the summaries + """ + from libs.datetime_utils import naive_utc_now + + with session_factory.create_session() as session: + query = session.query(DocumentSegmentSummary).filter_by( + dataset_id=dataset.id, + enabled=True, # Only disable enabled summaries + ) + + if segment_ids: + query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids)) + + summaries = query.all() + + if not summaries: + return + + logger.info( + "Disabling %s summary records for dataset %s, segment_ids: %s", + len(summaries), + dataset.id, + len(segment_ids) if segment_ids else "all", + ) + + # Remove from vector database (but keep records) + if dataset.indexing_technique == "high_quality": + summary_node_ids = [s.summary_index_node_id for s in summaries if s.summary_index_node_id] + if summary_node_ids: + try: + vector = Vector(dataset) + vector.delete_by_ids(summary_node_ids) + except Exception as e: + logger.warning("Failed to remove summary vectors: %s", str(e)) + + # Disable summary records (don't delete) + now = naive_utc_now() + for summary in summaries: + summary.enabled = False + summary.disabled_at = now + summary.disabled_by = disabled_by + session.add(summary) + + session.commit() + logger.info("Disabled %s summary records for dataset %s", len(summaries), dataset.id) + + @staticmethod + def enable_summaries_for_segments( + dataset: Dataset, + segment_ids: list[str] | None = None, + ) -> None: + """ + Enable summary records and re-add vectors to vector database for segments. + + Note: This method enables summaries based on chunk status, not summary_index_setting.enable. + The summary_index_setting.enable flag only controls automatic generation, + not whether existing summaries can be used. + Summary.enabled should always be kept in sync with chunk.enabled. + + Args: + dataset: Dataset containing the segments + segment_ids: List of segment IDs to enable summaries for. If None, enable all. 
+ """ + # Only enable summary index for high_quality indexing technique + if dataset.indexing_technique != "high_quality": + return + + with session_factory.create_session() as session: + query = session.query(DocumentSegmentSummary).filter_by( + dataset_id=dataset.id, + enabled=False, # Only enable disabled summaries + ) + + if segment_ids: + query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids)) + + summaries = query.all() + + if not summaries: + return + + logger.info( + "Enabling %s summary records for dataset %s, segment_ids: %s", + len(summaries), + dataset.id, + len(segment_ids) if segment_ids else "all", + ) + + # Re-vectorize and re-add to vector database + enabled_count = 0 + for summary in summaries: + # Get the original segment + segment = ( + session.query(DocumentSegment) + .filter_by( + id=summary.chunk_id, + dataset_id=dataset.id, + ) + .first() + ) + + # Summary.enabled stays in sync with chunk.enabled, + # only enable summary if the associated chunk is enabled. 
+ if not segment or not segment.enabled or segment.status != "completed": + continue + + if not summary.summary_content: + continue + + try: + # Re-vectorize summary (this will update status and tokens in its own session) + # Pass the session to vectorize_summary to avoid session isolation issues + SummaryIndexService.vectorize_summary(summary, segment, dataset, session=session) + + # Refresh the object from database to get the updated status and tokens from vectorize_summary + session.refresh(summary) + + # Enable summary record + summary.enabled = True + summary.disabled_at = None + summary.disabled_by = None + session.add(summary) + enabled_count += 1 + except Exception: + logger.exception("Failed to re-vectorize summary %s", summary.id) + # Keep it disabled if vectorization fails + continue + + session.commit() + logger.info("Enabled %s summary records for dataset %s", enabled_count, dataset.id) + + @staticmethod + def delete_summaries_for_segments( + dataset: Dataset, + segment_ids: list[str] | None = None, + ) -> None: + """ + Delete summary records and vectors for segments (used only for actual deletion scenarios). + For disable/enable operations, use disable_summaries_for_segments/enable_summaries_for_segments. + + Args: + dataset: Dataset containing the segments + segment_ids: List of segment IDs to delete summaries for. If None, delete all. 
+ """ + with session_factory.create_session() as session: + query = session.query(DocumentSegmentSummary).filter_by(dataset_id=dataset.id) + + if segment_ids: + query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids)) + + summaries = query.all() + + if not summaries: + return + + # Delete from vector database + if dataset.indexing_technique == "high_quality": + summary_node_ids = [s.summary_index_node_id for s in summaries if s.summary_index_node_id] + if summary_node_ids: + vector = Vector(dataset) + vector.delete_by_ids(summary_node_ids) + + # Delete summary records + for summary in summaries: + session.delete(summary) + + session.commit() + logger.info("Deleted %s summary records for dataset %s", len(summaries), dataset.id) + + @staticmethod + def update_summary_for_segment( + segment: DocumentSegment, + dataset: Dataset, + summary_content: str, + ) -> DocumentSegmentSummary | None: + """ + Update summary for a segment and re-vectorize it. + + Args: + segment: DocumentSegment to update summary for + dataset: Dataset containing the segment + summary_content: New summary content + + Returns: + Updated DocumentSegmentSummary instance, or None if indexing technique is not high_quality + """ + # Only update summary index for high_quality indexing technique + if dataset.indexing_technique != "high_quality": + return None + + # When user manually provides summary, allow saving even if summary_index_setting doesn't exist + # summary_index_setting is only needed for LLM generation, not for manual summary vectorization + # Vectorization uses dataset.embedding_model, which doesn't require summary_index_setting + + # Skip qa_model documents + if segment.document and segment.document.doc_form == "qa_model": + return None + + with session_factory.create_session() as session: + try: + # Check if summary_content is empty (whitespace-only strings are considered empty) + if not summary_content or not summary_content.strip(): + # If summary is empty, only delete 
existing summary vector and record + summary_record = ( + session.query(DocumentSegmentSummary) + .filter_by(chunk_id=segment.id, dataset_id=dataset.id) + .first() + ) + + if summary_record: + # Delete old vector if exists + old_summary_node_id = summary_record.summary_index_node_id + if old_summary_node_id: + try: + vector = Vector(dataset) + vector.delete_by_ids([old_summary_node_id]) + except Exception as e: + logger.warning( + "Failed to delete old summary vector for segment %s: %s", + segment.id, + str(e), + ) + + # Delete summary record since summary is empty + session.delete(summary_record) + session.commit() + logger.info("Deleted summary for segment %s (empty content provided)", segment.id) + return None + else: + # No existing summary record, nothing to do + logger.info("No summary record found for segment %s, nothing to delete", segment.id) + return None + + # Find existing summary record + summary_record = ( + session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() + ) + + if summary_record: + # Update existing summary + old_summary_node_id = summary_record.summary_index_node_id + + # Update summary content + summary_record.summary_content = summary_content + summary_record.status = "generating" + summary_record.error = None # type: ignore[assignment] # Clear any previous errors + session.add(summary_record) + # Flush to ensure summary_content is saved before vectorize_summary queries it + session.flush() + + # Delete old vector if exists (before vectorization) + if old_summary_node_id: + try: + vector = Vector(dataset) + vector.delete_by_ids([old_summary_node_id]) + except Exception as e: + logger.warning( + "Failed to delete old summary vector for segment %s: %s", + segment.id, + str(e), + ) + + # Re-vectorize summary (this will update status to "completed" and tokens in its own session) + # vectorize_summary will also ensure summary_content is preserved + # Note: vectorize_summary may take time due to 
embedding API calls, but we need to complete it + # to ensure the summary is properly indexed + try: + # Pass the session to vectorize_summary to avoid session isolation issues + SummaryIndexService.vectorize_summary(summary_record, segment, dataset, session=session) + # Refresh the object from database to get the updated status and tokens from vectorize_summary + session.refresh(summary_record) + # Now commit the session (summary_record should have status="completed" and tokens from refresh) + session.commit() + logger.info("Successfully updated and re-vectorized summary for segment %s", segment.id) + return summary_record + except Exception as e: + # If vectorization fails, update status to error in current session + # Don't raise the exception - just log it and return the record with error status + # This allows the segment update to complete even if vectorization fails + summary_record.status = "error" + summary_record.error = f"Vectorization failed: {str(e)}" + session.commit() + logger.exception("Failed to vectorize summary for segment %s", segment.id) + # Return the record with error status instead of raising + # The caller can check the status if needed + return summary_record + else: + # Create new summary record if doesn't exist + summary_record = SummaryIndexService.create_summary_record( + segment, dataset, summary_content, status="generating" + ) + # Re-vectorize summary (this will update status to "completed" and tokens in its own session) + # Note: summary_record was created in a different session, + # so we need to merge it into current session + try: + # Merge the record into current session first (since it was created in a different session) + summary_record = session.merge(summary_record) + # Pass the session to vectorize_summary - it will update the merged record + SummaryIndexService.vectorize_summary(summary_record, segment, dataset, session=session) + # Refresh to get updated status and tokens from database + session.refresh(summary_record) + 
# Commit the session to persist the changes + session.commit() + logger.info("Successfully created and vectorized summary for segment %s", segment.id) + return summary_record + except Exception as e: + # If vectorization fails, update status to error in current session + # Merge the record into current session first + error_record = session.merge(summary_record) + error_record.status = "error" + error_record.error = f"Vectorization failed: {str(e)}" + session.commit() + logger.exception("Failed to vectorize summary for segment %s", segment.id) + # Return the record with error status instead of raising + return error_record + + except Exception as e: + logger.exception("Failed to update summary for segment %s", segment.id) + # Update summary record with error status if it exists + summary_record = ( + session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() + ) + if summary_record: + summary_record.status = "error" + summary_record.error = str(e) + session.add(summary_record) + session.commit() + raise + + @staticmethod + def get_segment_summary(segment_id: str, dataset_id: str) -> DocumentSegmentSummary | None: + """ + Get summary for a single segment. + + Args: + segment_id: Segment ID (chunk_id) + dataset_id: Dataset ID + + Returns: + DocumentSegmentSummary instance if found, None otherwise + """ + with session_factory.create_session() as session: + return ( + session.query(DocumentSegmentSummary) + .where( + DocumentSegmentSummary.chunk_id == segment_id, + DocumentSegmentSummary.dataset_id == dataset_id, + DocumentSegmentSummary.enabled == True, # Only return enabled summaries + ) + .first() + ) + + @staticmethod + def get_segments_summaries(segment_ids: list[str], dataset_id: str) -> dict[str, DocumentSegmentSummary]: + """ + Get summaries for multiple segments. 
+ + Args: + segment_ids: List of segment IDs (chunk_ids) + dataset_id: Dataset ID + + Returns: + Dictionary mapping segment_id to DocumentSegmentSummary (only enabled summaries) + """ + if not segment_ids: + return {} + + with session_factory.create_session() as session: + summary_records = ( + session.query(DocumentSegmentSummary) + .where( + DocumentSegmentSummary.chunk_id.in_(segment_ids), + DocumentSegmentSummary.dataset_id == dataset_id, + DocumentSegmentSummary.enabled == True, # Only return enabled summaries + ) + .all() + ) + + return {summary.chunk_id: summary for summary in summary_records} + + @staticmethod + def get_document_summaries( + document_id: str, dataset_id: str, segment_ids: list[str] | None = None + ) -> list[DocumentSegmentSummary]: + """ + Get all summary records for a document. + + Args: + document_id: Document ID + dataset_id: Dataset ID + segment_ids: Optional list of segment IDs to filter by + + Returns: + List of DocumentSegmentSummary instances (only enabled summaries) + """ + with session_factory.create_session() as session: + query = session.query(DocumentSegmentSummary).filter( + DocumentSegmentSummary.document_id == document_id, + DocumentSegmentSummary.dataset_id == dataset_id, + DocumentSegmentSummary.enabled == True, # Only return enabled summaries + ) + + if segment_ids: + query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids)) + + return query.all() + + @staticmethod + def get_document_summary_index_status(document_id: str, dataset_id: str, tenant_id: str) -> str | None: + """ + Get summary_index_status for a single document. 
+ + Args: + document_id: Document ID + dataset_id: Dataset ID + tenant_id: Tenant ID + + Returns: + "SUMMARIZING" if there are pending summaries, None otherwise + """ + # Get all segments for this document (excluding qa_model and re_segment) + with session_factory.create_session() as session: + segments = ( + session.query(DocumentSegment.id) + .where( + DocumentSegment.document_id == document_id, + DocumentSegment.status != "re_segment", + DocumentSegment.tenant_id == tenant_id, + ) + .all() + ) + segment_ids = [seg.id for seg in segments] + + if not segment_ids: + return None + + # Get all summary records for these segments + summaries = SummaryIndexService.get_segments_summaries(segment_ids, dataset_id) + summary_status_map = {chunk_id: summary.status for chunk_id, summary in summaries.items()} + + # Check if there are any "not_started" or "generating" status summaries + has_pending_summaries = any( + summary_status_map.get(segment_id) is not None # Ensure summary exists (enabled=True) + and summary_status_map[segment_id] in ("not_started", "generating") + for segment_id in segment_ids + ) + + return "SUMMARIZING" if has_pending_summaries else None + + @staticmethod + def get_documents_summary_index_status( + document_ids: list[str], dataset_id: str, tenant_id: str + ) -> dict[str, str | None]: + """ + Get summary_index_status for multiple documents. 
+ + Args: + document_ids: List of document IDs + dataset_id: Dataset ID + tenant_id: Tenant ID + + Returns: + Dictionary mapping document_id to summary_index_status ("SUMMARIZING" or None) + """ + if not document_ids: + return {} + + # Get all segments for these documents (excluding qa_model and re_segment) + with session_factory.create_session() as session: + segments = ( + session.query(DocumentSegment.id, DocumentSegment.document_id) + .where( + DocumentSegment.document_id.in_(document_ids), + DocumentSegment.status != "re_segment", + DocumentSegment.tenant_id == tenant_id, + ) + .all() + ) + + # Group segments by document_id + document_segments_map: dict[str, list[str]] = {} + for segment in segments: + doc_id = str(segment.document_id) + if doc_id not in document_segments_map: + document_segments_map[doc_id] = [] + document_segments_map[doc_id].append(segment.id) + + # Get all summary records for these segments + all_segment_ids = [seg.id for seg in segments] + summaries = SummaryIndexService.get_segments_summaries(all_segment_ids, dataset_id) + summary_status_map = {chunk_id: summary.status for chunk_id, summary in summaries.items()} + + # Calculate summary_index_status for each document + result: dict[str, str | None] = {} + for doc_id in document_ids: + segment_ids = document_segments_map.get(doc_id, []) + if not segment_ids: + # No segments, status is None (not started) + result[doc_id] = None + continue + + # Check if there are any "not_started" or "generating" status summaries + # Only check enabled=True summaries (already filtered in query) + # If segment has no summary record (summary_status_map.get returns None), + # it means the summary is disabled (enabled=False) or not created yet, ignore it + has_pending_summaries = any( + summary_status_map.get(segment_id) is not None # Ensure summary exists (enabled=True) + and summary_status_map[segment_id] in ("not_started", "generating") + for segment_id in segment_ids + ) + + if has_pending_summaries: + # 
Task is still running (not started or generating) + result[doc_id] = "SUMMARIZING" + else: + # All enabled=True summaries are "completed" or "error", task finished + # Or no enabled=True summaries exist (all disabled) + result[doc_id] = None + + return result + + @staticmethod + def get_document_summary_status_detail( + document_id: str, + dataset_id: str, + ) -> dict[str, Any]: + """ + Get detailed summary status for a document. + + Args: + document_id: Document ID + dataset_id: Dataset ID + + Returns: + Dictionary containing: + - total_segments: Total number of segments in the document + - summary_status: Dictionary with status counts + - completed: Number of summaries completed + - generating: Number of summaries being generated + - error: Number of summaries with errors + - not_started: Number of segments without summary records + - summaries: List of summary records with status and content preview + """ + from services.dataset_service import SegmentService + + # Get all segments for this document + segments = SegmentService.get_segments_by_document_and_dataset( + document_id=document_id, + dataset_id=dataset_id, + status="completed", + enabled=True, + ) + + total_segments = len(segments) + + # Get all summary records for these segments + segment_ids = [segment.id for segment in segments] + summaries = [] + if segment_ids: + summaries = SummaryIndexService.get_document_summaries( + document_id=document_id, + dataset_id=dataset_id, + segment_ids=segment_ids, + ) + + # Create a mapping of chunk_id to summary + summary_map = {summary.chunk_id: summary for summary in summaries} + + # Count statuses + status_counts = { + "completed": 0, + "generating": 0, + "error": 0, + "not_started": 0, + } + + summary_list = [] + for segment in segments: + summary = summary_map.get(segment.id) + if summary: + status = summary.status + status_counts[status] = status_counts.get(status, 0) + 1 + summary_list.append( + { + "segment_id": segment.id, + "segment_position": 
segment.position, + "status": summary.status, + "summary_preview": ( + summary.summary_content[:100] + "..." + if summary.summary_content and len(summary.summary_content) > 100 + else summary.summary_content + ), + "error": summary.error, + "created_at": int(summary.created_at.timestamp()) if summary.created_at else None, + "updated_at": int(summary.updated_at.timestamp()) if summary.updated_at else None, + } + ) + else: + status_counts["not_started"] += 1 + summary_list.append( + { + "segment_id": segment.id, + "segment_position": segment.position, + "status": "not_started", + "summary_preview": None, + "error": None, + "created_at": None, + "updated_at": None, + } + ) + + return { + "total_segments": total_segments, + "summary_status": status_counts, + "summaries": summary_list, + } diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py index ab5d5480df..0ae40199ab 100644 --- a/api/services/tools/workflow_tools_manage_service.py +++ b/api/services/tools/workflow_tools_manage_service.py @@ -67,6 +67,8 @@ class WorkflowToolManageService: if workflow is None: raise ValueError(f"Workflow not found for app {workflow_app_id}") + WorkflowToolConfigurationUtils.ensure_no_human_input_nodes(workflow.graph_dict) + workflow_tool_provider = WorkflowToolProvider( tenant_id=tenant_id, user_id=user_id, @@ -158,6 +160,8 @@ class WorkflowToolManageService: if workflow is None: raise ValueError(f"Workflow not found for app {workflow_tool_provider.app_id}") + WorkflowToolConfigurationUtils.ensure_no_human_input_nodes(workflow.graph_dict) + workflow_tool_provider.name = name workflow_tool_provider.label = label workflow_tool_provider.icon = json.dumps(icon) diff --git a/api/services/workflow/entities.py b/api/services/workflow/entities.py index 70ec8d6e2a..2af0d1fd90 100644 --- a/api/services/workflow/entities.py +++ b/api/services/workflow/entities.py @@ -98,6 +98,12 @@ class WorkflowTaskData(BaseModel): model_config = 
ConfigDict(arbitrary_types_allowed=True) +class WorkflowResumeTaskData(BaseModel): + """Payload for workflow resumption tasks.""" + + workflow_run_id: str + + class AsyncTriggerExecutionResult(BaseModel): """Result from async trigger-based workflow execution""" diff --git a/api/services/workflow_event_snapshot_service.py b/api/services/workflow_event_snapshot_service.py new file mode 100644 index 0000000000..dd4651f130 --- /dev/null +++ b/api/services/workflow_event_snapshot_service.py @@ -0,0 +1,460 @@ +from __future__ import annotations + +import json +import logging +import queue +import threading +import time +from collections.abc import Generator, Mapping, Sequence +from dataclasses import dataclass +from typing import Any + +from sqlalchemy import desc, select +from sqlalchemy.orm import Session, sessionmaker + +from core.app.apps.message_generator import MessageGenerator +from core.app.entities.task_entities import ( + MessageReplaceStreamResponse, + NodeFinishStreamResponse, + NodeStartStreamResponse, + StreamEvent, + WorkflowPauseStreamResponse, + WorkflowStartStreamResponse, +) +from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext +from core.workflow.entities import WorkflowStartReason +from core.workflow.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus +from core.workflow.runtime import GraphRuntimeState +from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter +from models.model import AppMode, Message +from models.workflow import WorkflowNodeExecutionTriggeredFrom, WorkflowRun +from repositories.api_workflow_node_execution_repository import WorkflowNodeExecutionSnapshot +from repositories.entities.workflow_pause import WorkflowPauseEntity +from repositories.factory import DifyAPIRepositoryFactory + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class MessageContext: + conversation_id: str + message_id: str + created_at: int + answer: str | None = None + + +@dataclass 
+class BufferState: + queue: queue.Queue[Mapping[str, Any]] + stop_event: threading.Event + done_event: threading.Event + task_id_ready: threading.Event + task_id_hint: str | None = None + + +def build_workflow_event_stream( + *, + app_mode: AppMode, + workflow_run: WorkflowRun, + tenant_id: str, + app_id: str, + session_maker: sessionmaker[Session], + idle_timeout: float = 300, + ping_interval: float = 10.0, +) -> Generator[Mapping[str, Any] | str, None, None]: + topic = MessageGenerator.get_response_topic(app_mode, workflow_run.id) + workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) + node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker) + message_context = ( + _get_message_context(session_maker, workflow_run.id) if app_mode == AppMode.ADVANCED_CHAT else None + ) + + pause_entity: WorkflowPauseEntity | None = None + if workflow_run.status == WorkflowExecutionStatus.PAUSED: + try: + pause_entity = workflow_run_repo.get_workflow_pause(workflow_run.id) + except Exception: + logger.exception("Failed to load workflow pause for run %s", workflow_run.id) + pause_entity = None + + resumption_context = _load_resumption_context(pause_entity) + node_snapshots = node_execution_repo.get_execution_snapshots_by_workflow_run( + tenant_id=tenant_id, + app_id=app_id, + workflow_id=workflow_run.workflow_id, + # NOTE(QuantumGhost): for events resumption, we only care about + # the execution records from `WORKFLOW_RUN`. + # + # Ideally filtering with `workflow_run_id` is enough. However, + # due to the index of `WorkflowNodeExecution` table, we have to + # add a filter condition of `triggered_from` to + # ensure that we can utilize the index. 
+ triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + workflow_run_id=workflow_run.id, + ) + + def _generate() -> Generator[Mapping[str, Any] | str, None, None]: + # send a PING event immediately to prevent the connection staying in pending state for a long time. + # + # This simplify the debugging process as the DevTools in Chrome does not + # provide complete curl command for pending connections. + yield StreamEvent.PING.value + + last_msg_time = time.time() + last_ping_time = last_msg_time + + with topic.subscribe() as sub: + buffer_state = _start_buffering(sub) + try: + task_id = _resolve_task_id(resumption_context, buffer_state, workflow_run.id) + + snapshot_events = _build_snapshot_events( + workflow_run=workflow_run, + node_snapshots=node_snapshots, + task_id=task_id, + message_context=message_context, + pause_entity=pause_entity, + resumption_context=resumption_context, + ) + + for event in snapshot_events: + last_msg_time = time.time() + last_ping_time = last_msg_time + yield event + if _is_terminal_event(event, include_paused=True): + return + + while True: + if buffer_state.done_event.is_set() and buffer_state.queue.empty(): + return + + try: + event = buffer_state.queue.get(timeout=0.1) + except queue.Empty: + current_time = time.time() + if current_time - last_msg_time > idle_timeout: + logger.debug( + "No workflow events received for %s seconds, keeping stream open", + idle_timeout, + ) + last_msg_time = current_time + if current_time - last_ping_time >= ping_interval: + yield StreamEvent.PING.value + last_ping_time = current_time + continue + + last_msg_time = time.time() + last_ping_time = last_msg_time + yield event + if _is_terminal_event(event, include_paused=True): + return + finally: + buffer_state.stop_event.set() + + return _generate() + + +def _get_message_context(session_maker: sessionmaker[Session], workflow_run_id: str) -> MessageContext | None: + with session_maker() as session: + stmt = 
select(Message).where(Message.workflow_run_id == workflow_run_id).order_by(desc(Message.created_at)) + message = session.scalar(stmt) + if message is None: + return None + created_at = int(message.created_at.timestamp()) if message.created_at else 0 + return MessageContext( + conversation_id=message.conversation_id, + message_id=message.id, + created_at=created_at, + answer=message.answer, + ) + + +def _load_resumption_context(pause_entity: WorkflowPauseEntity | None) -> WorkflowResumptionContext | None: + if pause_entity is None: + return None + try: + raw_state = pause_entity.get_state().decode() + return WorkflowResumptionContext.loads(raw_state) + except Exception: + logger.exception("Failed to load resumption context") + return None + + +def _resolve_task_id( + resumption_context: WorkflowResumptionContext | None, + buffer_state: BufferState | None, + workflow_run_id: str, + wait_timeout: float = 0.2, +) -> str: + if resumption_context is not None: + generate_entity = resumption_context.get_generate_entity() + if generate_entity.task_id: + return generate_entity.task_id + if buffer_state is None: + return workflow_run_id + if buffer_state.task_id_hint is None: + buffer_state.task_id_ready.wait(timeout=wait_timeout) + if buffer_state.task_id_hint: + return buffer_state.task_id_hint + return workflow_run_id + + +def _build_snapshot_events( + *, + workflow_run: WorkflowRun, + node_snapshots: Sequence[WorkflowNodeExecutionSnapshot], + task_id: str, + message_context: MessageContext | None, + pause_entity: WorkflowPauseEntity | None, + resumption_context: WorkflowResumptionContext | None, +) -> list[Mapping[str, Any]]: + events: list[Mapping[str, Any]] = [] + + workflow_started = _build_workflow_started_event( + workflow_run=workflow_run, + task_id=task_id, + ) + _apply_message_context(workflow_started, message_context) + events.append(workflow_started) + + if message_context is not None and message_context.answer is not None: + message_replace = 
_build_message_replace_event(task_id=task_id, answer=message_context.answer) + _apply_message_context(message_replace, message_context) + events.append(message_replace) + + for snapshot in node_snapshots: + node_started = _build_node_started_event( + workflow_run_id=workflow_run.id, + task_id=task_id, + snapshot=snapshot, + ) + _apply_message_context(node_started, message_context) + events.append(node_started) + + if snapshot.status != WorkflowNodeExecutionStatus.RUNNING.value: + node_finished = _build_node_finished_event( + workflow_run_id=workflow_run.id, + task_id=task_id, + snapshot=snapshot, + ) + _apply_message_context(node_finished, message_context) + events.append(node_finished) + + if workflow_run.status == WorkflowExecutionStatus.PAUSED and pause_entity is not None: + pause_event = _build_pause_event( + workflow_run=workflow_run, + workflow_run_id=workflow_run.id, + task_id=task_id, + pause_entity=pause_entity, + resumption_context=resumption_context, + ) + if pause_event is not None: + _apply_message_context(pause_event, message_context) + events.append(pause_event) + + return events + + +def _build_workflow_started_event( + *, + workflow_run: WorkflowRun, + task_id: str, +) -> dict[str, Any]: + response = WorkflowStartStreamResponse( + task_id=task_id, + workflow_run_id=workflow_run.id, + data=WorkflowStartStreamResponse.Data( + id=workflow_run.id, + workflow_id=workflow_run.workflow_id, + inputs=workflow_run.inputs_dict or {}, + created_at=int(workflow_run.created_at.timestamp()), + reason=WorkflowStartReason.INITIAL, + ), + ) + payload = response.model_dump(mode="json") + payload["event"] = response.event.value + return payload + + +def _build_message_replace_event(*, task_id: str, answer: str) -> dict[str, Any]: + response = MessageReplaceStreamResponse( + task_id=task_id, + answer=answer, + reason="", + ) + payload = response.model_dump(mode="json") + payload["event"] = response.event.value + return payload + + +def _build_node_started_event( + *, + 
workflow_run_id: str, + task_id: str, + snapshot: WorkflowNodeExecutionSnapshot, +) -> dict[str, Any]: + created_at = int(snapshot.created_at.timestamp()) if snapshot.created_at else 0 + response = NodeStartStreamResponse( + task_id=task_id, + workflow_run_id=workflow_run_id, + data=NodeStartStreamResponse.Data( + id=snapshot.execution_id, + node_id=snapshot.node_id, + node_type=snapshot.node_type, + title=snapshot.title, + index=snapshot.index, + predecessor_node_id=None, + inputs=None, + created_at=created_at, + extras={}, + iteration_id=snapshot.iteration_id, + loop_id=snapshot.loop_id, + ), + ) + return response.to_ignore_detail_dict() + + +def _build_node_finished_event( + *, + workflow_run_id: str, + task_id: str, + snapshot: WorkflowNodeExecutionSnapshot, +) -> dict[str, Any]: + created_at = int(snapshot.created_at.timestamp()) if snapshot.created_at else 0 + finished_at = int(snapshot.finished_at.timestamp()) if snapshot.finished_at else created_at + response = NodeFinishStreamResponse( + task_id=task_id, + workflow_run_id=workflow_run_id, + data=NodeFinishStreamResponse.Data( + id=snapshot.execution_id, + node_id=snapshot.node_id, + node_type=snapshot.node_type, + title=snapshot.title, + index=snapshot.index, + predecessor_node_id=None, + inputs=None, + process_data=None, + outputs=None, + status=snapshot.status, + error=None, + elapsed_time=snapshot.elapsed_time, + execution_metadata=None, + created_at=created_at, + finished_at=finished_at, + files=[], + iteration_id=snapshot.iteration_id, + loop_id=snapshot.loop_id, + ), + ) + return response.to_ignore_detail_dict() + + +def _build_pause_event( + *, + workflow_run: WorkflowRun, + workflow_run_id: str, + task_id: str, + pause_entity: WorkflowPauseEntity, + resumption_context: WorkflowResumptionContext | None, +) -> dict[str, Any] | None: + paused_nodes: list[str] = [] + outputs: dict[str, Any] = {} + if resumption_context is not None: + state = 
GraphRuntimeState.from_snapshot(resumption_context.serialized_graph_runtime_state) + paused_nodes = state.get_paused_nodes() + outputs = dict(WorkflowRuntimeTypeConverter().to_json_encodable(state.outputs or {})) + + reasons = [reason.model_dump(mode="json") for reason in pause_entity.get_pause_reasons()] + response = WorkflowPauseStreamResponse( + task_id=task_id, + workflow_run_id=workflow_run_id, + data=WorkflowPauseStreamResponse.Data( + workflow_run_id=workflow_run_id, + paused_nodes=paused_nodes, + outputs=outputs, + reasons=reasons, + status=workflow_run.status.value, + created_at=int(workflow_run.created_at.timestamp()), + elapsed_time=float(workflow_run.elapsed_time or 0.0), + total_tokens=int(workflow_run.total_tokens or 0), + total_steps=int(workflow_run.total_steps or 0), + ), + ) + payload = response.model_dump(mode="json") + payload["event"] = response.event.value + return payload + + +def _apply_message_context(payload: dict[str, Any], message_context: MessageContext | None) -> None: + if message_context is None: + return + payload["conversation_id"] = message_context.conversation_id + payload["message_id"] = message_context.message_id + payload["created_at"] = message_context.created_at + + +def _start_buffering(subscription) -> BufferState: + buffer_state = BufferState( + queue=queue.Queue(maxsize=2048), + stop_event=threading.Event(), + done_event=threading.Event(), + task_id_ready=threading.Event(), + ) + + def _worker() -> None: + dropped_count = 0 + try: + while not buffer_state.stop_event.is_set(): + msg = subscription.receive(timeout=0.1) + if msg is None: + continue + event = _parse_event_message(msg) + if event is None: + continue + task_id = event.get("task_id") + if task_id and buffer_state.task_id_hint is None: + buffer_state.task_id_hint = str(task_id) + buffer_state.task_id_ready.set() + try: + buffer_state.queue.put_nowait(event) + except queue.Full: + dropped_count += 1 + try: + buffer_state.queue.get_nowait() + except queue.Empty: + 
pass + try: + buffer_state.queue.put_nowait(event) + except queue.Full: + continue + logger.warning("Dropped buffered workflow event, total_dropped=%s", dropped_count) + except Exception: + logger.exception("Failed while buffering workflow events") + finally: + buffer_state.done_event.set() + + thread = threading.Thread(target=_worker, name=f"workflow-event-buffer-{id(subscription)}", daemon=True) + thread.start() + return buffer_state + + +def _parse_event_message(message: bytes) -> Mapping[str, Any] | None: + try: + event = json.loads(message) + except json.JSONDecodeError: + logger.warning("Failed to decode workflow event payload") + return None + if not isinstance(event, dict): + return None + return event + + +def _is_terminal_event(event: Mapping[str, Any] | str, include_paused=False) -> bool: + if not isinstance(event, Mapping): + return False + event_type = event.get("event") + if event_type == StreamEvent.WORKFLOW_FINISHED.value: + return True + if include_paused: + return event_type == StreamEvent.WORKFLOW_PAUSED.value + return False diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 6404136994..4e1e515de5 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -1,4 +1,5 @@ import json +import logging import time import uuid from collections.abc import Callable, Generator, Mapping, Sequence @@ -11,21 +12,34 @@ from configs import dify_config from core.app.app_config.entities import VariableEntityType from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager +from core.app.entities.app_invoke_entities import InvokeFrom from core.file import File from core.repositories import DifyCoreRepositoryFactory +from core.repositories.human_input_repository import HumanInputFormRepositoryImpl from core.variables import VariableBase from core.variables.variables import Variable -from 
core.workflow.entities import WorkflowNodeExecution +from core.workflow.entities import GraphInitParams, WorkflowNodeExecution +from core.workflow.entities.pause_reason import HumanInputRequired from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent from core.workflow.node_events import NodeRunResult from core.workflow.nodes import NodeType from core.workflow.nodes.base.node import Node +from core.workflow.nodes.human_input.entities import ( + DeliveryChannelConfig, + HumanInputNodeData, + apply_debug_email_recipient, + validate_human_input_submission, +) +from core.workflow.nodes.human_input.enums import HumanInputFormKind +from core.workflow.nodes.human_input.human_input_node import HumanInputNode from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING from core.workflow.nodes.start.entities import StartNodeData -from core.workflow.runtime import VariablePool +from core.workflow.repositories.human_input_form_repository import FormCreateParams +from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable +from core.workflow.variable_loader import load_into_variable_pool from core.workflow.workflow_entry import WorkflowEntry from enums.cloud_plan import CloudPlan from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated @@ -34,6 +48,8 @@ from extensions.ext_storage import storage from factories.file_factory import build_from_mapping, build_from_mappings from libs.datetime_utils import naive_utc_now from models import Account +from models.enums import UserFrom +from models.human_input import HumanInputFormRecipient, RecipientType from models.model import App, AppMode from models.tools import WorkflowToolProvider from 
models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, WorkflowType @@ -44,6 +60,13 @@ from services.errors.app import IsDraftWorkflowError, TriggerNodeLimitExceededEr from services.workflow.workflow_converter import WorkflowConverter from .errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError +from .human_input_delivery_test_service import ( + DeliveryTestContext, + DeliveryTestEmailRecipient, + DeliveryTestError, + DeliveryTestUnsupportedError, + HumanInputDeliveryTestService, +) from .workflow_draft_variable_service import DraftVariableSaver, DraftVarLoader, WorkflowDraftVariableService @@ -744,6 +767,344 @@ class WorkflowService: return workflow_node_execution + def get_human_input_form_preview( + self, + *, + app_model: App, + account: Account, + node_id: str, + inputs: Mapping[str, Any] | None = None, + ) -> Mapping[str, Any]: + """ + Build a human input form preview for a draft workflow. + + Args: + app_model: Target application model. + account: Current account. + node_id: Human input node ID. + inputs: Values used to fill missing upstream variables referenced in form_content. + """ + draft_workflow = self.get_draft_workflow(app_model=app_model) + if not draft_workflow: + raise ValueError("Workflow not initialized") + + node_config = draft_workflow.get_node_config_by_id(node_id) + node_type = Workflow.get_node_type_from_node_config(node_config) + if node_type is not NodeType.HUMAN_INPUT: + raise ValueError("Node type must be human-input.") + + # inputs: values used to fill missing upstream variables referenced in form_content. 
+ variable_pool = self._build_human_input_variable_pool( + app_model=app_model, + workflow=draft_workflow, + node_config=node_config, + manual_inputs=inputs or {}, + ) + node = self._build_human_input_node( + workflow=draft_workflow, + account=account, + node_config=node_config, + variable_pool=variable_pool, + ) + + rendered_content = node.render_form_content_before_submission() + resolved_default_values = node.resolve_default_values() + node_data = node.node_data + human_input_required = HumanInputRequired( + form_id=node_id, + form_content=rendered_content, + inputs=node_data.inputs, + actions=node_data.user_actions, + node_id=node_id, + node_title=node.title, + resolved_default_values=resolved_default_values, + form_token=None, + ) + return human_input_required.model_dump(mode="json") + + def submit_human_input_form_preview( + self, + *, + app_model: App, + account: Account, + node_id: str, + form_inputs: Mapping[str, Any], + inputs: Mapping[str, Any] | None = None, + action: str, + ) -> Mapping[str, Any]: + """ + Submit a human input form preview for a draft workflow. + + Args: + app_model: Target application model. + account: Current account. + node_id: Human input node ID. + form_inputs: Values the user provides for the form's own fields. + inputs: Values used to fill missing upstream variables referenced in form_content. + action: Selected action ID. + """ + draft_workflow = self.get_draft_workflow(app_model=app_model) + if not draft_workflow: + raise ValueError("Workflow not initialized") + + node_config = draft_workflow.get_node_config_by_id(node_id) + node_type = Workflow.get_node_type_from_node_config(node_config) + if node_type is not NodeType.HUMAN_INPUT: + raise ValueError("Node type must be human-input.") + + # inputs: values used to fill missing upstream variables referenced in form_content. + # form_inputs: values the user provides for the form's own fields. 
+ variable_pool = self._build_human_input_variable_pool( + app_model=app_model, + workflow=draft_workflow, + node_config=node_config, + manual_inputs=inputs or {}, + ) + node = self._build_human_input_node( + workflow=draft_workflow, + account=account, + node_config=node_config, + variable_pool=variable_pool, + ) + node_data = node.node_data + + validate_human_input_submission( + inputs=node_data.inputs, + user_actions=node_data.user_actions, + selected_action_id=action, + form_data=form_inputs, + ) + + rendered_content = node.render_form_content_before_submission() + outputs: dict[str, Any] = dict(form_inputs) + outputs["__action_id"] = action + outputs["__rendered_content"] = node.render_form_content_with_outputs( + rendered_content, outputs, node_data.outputs_field_names() + ) + + enclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config) + enclosing_node_id = enclosing_node_type_and_id[1] if enclosing_node_type_and_id else None + with Session(bind=db.engine) as session, session.begin(): + draft_var_saver = DraftVariableSaver( + session=session, + app_id=app_model.id, + node_id=node_id, + node_type=NodeType.HUMAN_INPUT, + node_execution_id=str(uuid.uuid4()), + user=account, + enclosing_node_id=enclosing_node_id, + ) + draft_var_saver.save(outputs=outputs, process_data={}) + session.commit() + + return outputs + + def test_human_input_delivery( + self, + *, + app_model: App, + account: Account, + node_id: str, + delivery_method_id: str, + inputs: Mapping[str, Any] | None = None, + ) -> None: + draft_workflow = self.get_draft_workflow(app_model=app_model) + if not draft_workflow: + raise ValueError("Workflow not initialized") + + node_config = draft_workflow.get_node_config_by_id(node_id) + node_type = Workflow.get_node_type_from_node_config(node_config) + if node_type is not NodeType.HUMAN_INPUT: + raise ValueError("Node type must be human-input.") + + node_data = HumanInputNodeData.model_validate(node_config.get("data", {})) + 
delivery_method = self._resolve_human_input_delivery_method( + node_data=node_data, + delivery_method_id=delivery_method_id, + ) + if delivery_method is None: + raise ValueError("Delivery method not found.") + delivery_method = apply_debug_email_recipient( + delivery_method, + enabled=True, + user_id=account.id or "", + ) + + variable_pool = self._build_human_input_variable_pool( + app_model=app_model, + workflow=draft_workflow, + node_config=node_config, + manual_inputs=inputs or {}, + ) + node = self._build_human_input_node( + workflow=draft_workflow, + account=account, + node_config=node_config, + variable_pool=variable_pool, + ) + rendered_content = node.render_form_content_before_submission() + resolved_default_values = node.resolve_default_values() + form_id, recipients = self._create_human_input_delivery_test_form( + app_model=app_model, + node_id=node_id, + node_data=node_data, + delivery_method=delivery_method, + rendered_content=rendered_content, + resolved_default_values=resolved_default_values, + ) + test_service = HumanInputDeliveryTestService() + context = DeliveryTestContext( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + node_id=node_id, + node_title=node_data.title, + rendered_content=rendered_content, + template_vars={"form_id": form_id}, + recipients=recipients, + variable_pool=variable_pool, + ) + try: + test_service.send_test(context=context, method=delivery_method) + except DeliveryTestUnsupportedError as exc: + raise ValueError("Delivery method does not support test send.") from exc + except DeliveryTestError as exc: + raise ValueError(str(exc)) from exc + + @staticmethod + def _resolve_human_input_delivery_method( + *, + node_data: HumanInputNodeData, + delivery_method_id: str, + ) -> DeliveryChannelConfig | None: + for method in node_data.delivery_methods: + if str(method.id) == delivery_method_id: + return method + return None + + def _create_human_input_delivery_test_form( + self, + *, + app_model: App, + node_id: str, + 
node_data: HumanInputNodeData, + delivery_method: DeliveryChannelConfig, + rendered_content: str, + resolved_default_values: Mapping[str, Any], + ) -> tuple[str, list[DeliveryTestEmailRecipient]]: + repo = HumanInputFormRepositoryImpl(session_factory=db.engine, tenant_id=app_model.tenant_id) + params = FormCreateParams( + app_id=app_model.id, + workflow_execution_id=None, + node_id=node_id, + form_config=node_data, + rendered_content=rendered_content, + delivery_methods=[delivery_method], + display_in_ui=False, + resolved_default_values=resolved_default_values, + form_kind=HumanInputFormKind.DELIVERY_TEST, + ) + form_entity = repo.create_form(params) + return form_entity.id, self._load_email_recipients(form_entity.id) + + @staticmethod + def _load_email_recipients(form_id: str) -> list[DeliveryTestEmailRecipient]: + logger = logging.getLogger(__name__) + + with Session(bind=db.engine) as session: + recipients = session.scalars( + select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id == form_id) + ).all() + recipients_data: list[DeliveryTestEmailRecipient] = [] + for recipient in recipients: + if recipient.recipient_type not in {RecipientType.EMAIL_MEMBER, RecipientType.EMAIL_EXTERNAL}: + continue + if not recipient.access_token: + continue + try: + payload = json.loads(recipient.recipient_payload) + except Exception: + logger.exception("Failed to parse human input recipient payload for delivery test.") + continue + email = payload.get("email") + if email: + recipients_data.append(DeliveryTestEmailRecipient(email=email, form_token=recipient.access_token)) + return recipients_data + + def _build_human_input_node( + self, + *, + workflow: Workflow, + account: Account, + node_config: Mapping[str, Any], + variable_pool: VariablePool, + ) -> HumanInputNode: + graph_init_params = GraphInitParams( + tenant_id=workflow.tenant_id, + app_id=workflow.app_id, + workflow_id=workflow.id, + graph_config=workflow.graph_dict, + user_id=account.id, + 
user_from=UserFrom.ACCOUNT.value, + invoke_from=InvokeFrom.DEBUGGER.value, + call_depth=0, + ) + graph_runtime_state = GraphRuntimeState( + variable_pool=variable_pool, + start_at=time.perf_counter(), + ) + node = HumanInputNode( + id=node_config.get("id", str(uuid.uuid4())), + config=node_config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + return node + + def _build_human_input_variable_pool( + self, + *, + app_model: App, + workflow: Workflow, + node_config: Mapping[str, Any], + manual_inputs: Mapping[str, Any], + ) -> VariablePool: + with Session(bind=db.engine, expire_on_commit=False) as session, session.begin(): + draft_var_srv = WorkflowDraftVariableService(session) + draft_var_srv.prefill_conversation_variable_default_values(workflow) + + variable_pool = VariablePool( + system_variables=SystemVariable.default(), + user_inputs={}, + environment_variables=workflow.environment_variables, + conversation_variables=[], + ) + + variable_loader = DraftVarLoader( + engine=db.engine, + app_id=app_model.id, + tenant_id=app_model.tenant_id, + ) + variable_mapping = HumanInputNode.extract_variable_selector_to_variable_mapping( + graph_config=workflow.graph_dict, + config=node_config, + ) + normalized_user_inputs: dict[str, Any] = dict(manual_inputs) + + load_into_variable_pool( + variable_loader=variable_loader, + variable_pool=variable_pool, + variable_mapping=variable_mapping, + user_inputs=normalized_user_inputs, + ) + WorkflowEntry.mapping_user_inputs_to_variable_pool( + variable_mapping=variable_mapping, + user_inputs=normalized_user_inputs, + variable_pool=variable_pool, + tenant_id=app_model.tenant_id, + ) + + return variable_pool + def run_free_workflow_node( self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any] ) -> WorkflowNodeExecution: @@ -945,6 +1306,13 @@ class WorkflowService: if any(nt.is_trigger_node for nt in node_types): raise ValueError("Start node and trigger nodes 
cannot coexist in the same workflow") + for node in node_configs: + node_data = node.get("data", {}) + node_type = node_data.get("type") + + if node_type == NodeType.HUMAN_INPUT: + self._validate_human_input_node_data(node_data) + def validate_features_structure(self, app_model: App, features: dict): if app_model.mode == AppMode.ADVANCED_CHAT: return AdvancedChatAppConfigManager.config_validate( @@ -957,6 +1325,23 @@ class WorkflowService: else: raise ValueError(f"Invalid app mode: {app_model.mode}") + def _validate_human_input_node_data(self, node_data: dict) -> None: + """ + Validate HumanInput node data format. + + Args: + node_data: The node data dictionary + + Raises: + ValueError: If the node data format is invalid + """ + from core.workflow.nodes.human_input.entities import HumanInputNodeData + + try: + HumanInputNodeData.model_validate(node_data) + except Exception as e: + raise ValueError(f"Invalid HumanInput node data: {str(e)}") + def update_workflow( self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict ) -> Workflow | None: diff --git a/api/tasks/add_document_to_index_task.py b/api/tasks/add_document_to_index_task.py index 62e6497e9d..2d3d00cd50 100644 --- a/api/tasks/add_document_to_index_task.py +++ b/api/tasks/add_document_to_index_task.py @@ -118,6 +118,19 @@ def add_document_to_index_task(dataset_document_id: str): ) session.commit() + # Enable summary indexes for all segments in this document + from services.summary_index_service import SummaryIndexService + + segment_ids_list = [segment.id for segment in segments] + if segment_ids_list: + try: + SummaryIndexService.enable_summaries_for_segments( + dataset=dataset, + segment_ids=segment_ids_list, + ) + except Exception as e: + logger.warning("Failed to enable summaries for document %s: %s", dataset_document.id, str(e)) + end_at = time.perf_counter() logger.info( click.style(f"Document added to index: {dataset_document.id} latency: {end_at - start_at}", 
fg="green") diff --git a/api/tasks/app_generate/__init__.py b/api/tasks/app_generate/__init__.py new file mode 100644 index 0000000000..4aa02ef39f --- /dev/null +++ b/api/tasks/app_generate/__init__.py @@ -0,0 +1,3 @@ +from .workflow_execute_task import AppExecutionParams, resume_app_execution, workflow_based_app_execution_task + +__all__ = ["AppExecutionParams", "resume_app_execution", "workflow_based_app_execution_task"] diff --git a/api/tasks/app_generate/workflow_execute_task.py b/api/tasks/app_generate/workflow_execute_task.py new file mode 100644 index 0000000000..e58d334f41 --- /dev/null +++ b/api/tasks/app_generate/workflow_execute_task.py @@ -0,0 +1,491 @@ +import contextlib +import logging +import uuid +from collections.abc import Generator, Mapping +from enum import StrEnum +from typing import Annotated, Any, TypeAlias, Union + +from celery import shared_task +from flask import current_app, json +from pydantic import BaseModel, Discriminator, Field, Tag +from sqlalchemy import Engine, select +from sqlalchemy.orm import Session, sessionmaker + +from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator +from core.app.apps.message_based_app_generator import MessageBasedAppGenerator +from core.app.apps.workflow.app_generator import WorkflowAppGenerator +from core.app.entities.app_invoke_entities import ( + AdvancedChatAppGenerateEntity, + InvokeFrom, + WorkflowAppGenerateEntity, +) +from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, WorkflowResumptionContext +from core.repositories import DifyCoreRepositoryFactory +from core.workflow.runtime import GraphRuntimeState +from extensions.ext_database import db +from libs.flask_utils import set_login_user +from models.account import Account +from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom +from models.model import App, AppMode, Conversation, EndUser, Message +from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom, WorkflowRun +from 
repositories.factory import DifyAPIRepositoryFactory + +logger = logging.getLogger(__name__) + +WORKFLOW_BASED_APP_EXECUTION_QUEUE = "workflow_based_app_execution" + + +class _UserType(StrEnum): + ACCOUNT = "account" + END_USER = "end_user" + + +class _Account(BaseModel): + TYPE: _UserType = _UserType.ACCOUNT + + user_id: str + + +class _EndUser(BaseModel): + TYPE: _UserType = _UserType.END_USER + end_user_id: str + + +def _get_user_type_descriminator(value: Any): + if isinstance(value, (_Account, _EndUser)): + return value.TYPE + elif isinstance(value, dict): + user_type_str = value.get("TYPE") + if user_type_str is None: + return None + try: + user_type = _UserType(user_type_str) + except ValueError: + return None + return user_type + else: + # return None if the discriminator value isn't found + return None + + +User: TypeAlias = Annotated[ + (Annotated[_Account, Tag(_UserType.ACCOUNT)] | Annotated[_EndUser, Tag(_UserType.END_USER)]), + Discriminator(_get_user_type_descriminator), +] + + +class AppExecutionParams(BaseModel): + app_id: str + workflow_id: str + tenant_id: str + app_mode: AppMode = AppMode.ADVANCED_CHAT + user: User + args: Mapping[str, Any] + + invoke_from: InvokeFrom + streaming: bool = True + call_depth: int = 0 + root_node_id: str | None = None + workflow_run_id: str = Field(default_factory=lambda: str(uuid.uuid4())) + + @classmethod + def new( + cls, + app_model: App, + workflow: Workflow, + user: Union[Account, EndUser], + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: bool = True, + call_depth: int = 0, + root_node_id: str | None = None, + workflow_run_id: str | None = None, + ): + user_params: _Account | _EndUser + if isinstance(user, Account): + user_params = _Account(user_id=user.id) + elif isinstance(user, EndUser): + user_params = _EndUser(end_user_id=user.id) + else: + raise AssertionError("this statement should be unreachable.") + return cls( + app_id=app_model.id, + workflow_id=workflow.id, + 
tenant_id=app_model.tenant_id, + app_mode=AppMode.value_of(app_model.mode), + user=user_params, + args=args, + invoke_from=invoke_from, + streaming=streaming, + call_depth=call_depth, + root_node_id=root_node_id, + workflow_run_id=workflow_run_id or str(uuid.uuid4()), + ) + + +class _AppRunner: + def __init__(self, session_factory: sessionmaker | Engine, exec_params: AppExecutionParams): + if isinstance(session_factory, Engine): + session_factory = sessionmaker(bind=session_factory) + self._session_factory = session_factory + self._exec_params = exec_params + + @contextlib.contextmanager + def _session(self): + with self._session_factory(expire_on_commit=False) as session, session.begin(): + yield session + + @contextlib.contextmanager + def _setup_flask_context(self, user: Account | EndUser): + flask_app = current_app._get_current_object() # type: ignore + with flask_app.app_context(): + set_login_user(user) + yield + + def run(self): + exec_params = self._exec_params + with self._session() as session: + workflow = session.get(Workflow, exec_params.workflow_id) + if workflow is None: + logger.warning("Workflow %s not found for execution", exec_params.workflow_id) + return None + app = session.get(App, workflow.app_id) + if app is None: + logger.warning("App %s not found for workflow %s", workflow.app_id, exec_params.workflow_id) + return None + + pause_config = PauseStateLayerConfig( + session_factory=self._session_factory, + state_owner_user_id=workflow.created_by, + ) + + user = self._resolve_user() + + with self._setup_flask_context(user): + response = self._run_app( + app=app, + workflow=workflow, + user=user, + pause_state_config=pause_config, + ) + if not exec_params.streaming: + return response + + assert isinstance(response, Generator) + _publish_streaming_response(response, exec_params.workflow_run_id, exec_params.app_mode) + + def _run_app( + self, + *, + app: App, + workflow: Workflow, + user: Account | EndUser, + pause_state_config: 
PauseStateLayerConfig, + ): + exec_params = self._exec_params + if exec_params.app_mode == AppMode.ADVANCED_CHAT: + return AdvancedChatAppGenerator().generate( + app_model=app, + workflow=workflow, + user=user, + args=exec_params.args, + invoke_from=exec_params.invoke_from, + streaming=exec_params.streaming, + workflow_run_id=exec_params.workflow_run_id, + pause_state_config=pause_state_config, + ) + if exec_params.app_mode == AppMode.WORKFLOW: + return WorkflowAppGenerator().generate( + app_model=app, + workflow=workflow, + user=user, + args=exec_params.args, + invoke_from=exec_params.invoke_from, + streaming=exec_params.streaming, + call_depth=exec_params.call_depth, + root_node_id=exec_params.root_node_id, + workflow_run_id=exec_params.workflow_run_id, + pause_state_config=pause_state_config, + ) + + logger.error("Unsupported app mode for execution: %s", exec_params.app_mode) + return None + + def _resolve_user(self) -> Account | EndUser: + user_params = self._exec_params.user + + if isinstance(user_params, _EndUser): + with self._session() as session: + return session.get(EndUser, user_params.end_user_id) + elif not isinstance(user_params, _Account): + raise AssertionError(f"user should only be _Account or _EndUser, got {type(user_params)}") + + with self._session() as session: + user: Account = session.get(Account, user_params.user_id) + user.set_tenant_id(self._exec_params.tenant_id) + + return user + + +def _resolve_user_for_run(session: Session, workflow_run: WorkflowRun) -> Account | EndUser | None: + role = CreatorUserRole(workflow_run.created_by_role) + if role == CreatorUserRole.ACCOUNT: + user = session.get(Account, workflow_run.created_by) + if user: + user.set_tenant_id(workflow_run.tenant_id) + return user + + return session.get(EndUser, workflow_run.created_by) + + +def _publish_streaming_response( + response_stream: Generator[str | Mapping[str, Any], None, None], workflow_run_id: str, app_mode: AppMode +) -> None: + topic = 
MessageBasedAppGenerator.get_response_topic(app_mode, workflow_run_id) + for event in response_stream: + try: + payload = json.dumps(event) + except TypeError: + logger.exception("error while encoding event") + continue + + topic.publish(payload.encode()) + + +@shared_task(queue=WORKFLOW_BASED_APP_EXECUTION_QUEUE) +def workflow_based_app_execution_task( + payload: str, +) -> Generator[Mapping[str, Any] | str, None, None] | Mapping[str, Any] | None: + exec_params = AppExecutionParams.model_validate_json(payload) + + logger.info("workflow_based_app_execution_task run with params: %s", exec_params) + + runner = _AppRunner(db.engine, exec_params=exec_params) + return runner.run() + + +def _resume_app_execution(payload: dict[str, Any]) -> None: + workflow_run_id = payload["workflow_run_id"] + + session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker=session_factory) + + pause_entity = workflow_run_repo.get_workflow_pause(workflow_run_id) + if pause_entity is None: + logger.warning("No pause entity found for workflow run %s", workflow_run_id) + return + + try: + resumption_context = WorkflowResumptionContext.loads(pause_entity.get_state().decode()) + except Exception: + logger.exception("Failed to load resumption context for workflow run %s", workflow_run_id) + return + + generate_entity = resumption_context.get_generate_entity() + + graph_runtime_state = GraphRuntimeState.from_snapshot(resumption_context.serialized_graph_runtime_state) + + conversation = None + message = None + with Session(db.engine, expire_on_commit=False) as session: + workflow_run = session.get(WorkflowRun, workflow_run_id) + if workflow_run is None: + logger.warning("Workflow run %s not found during resume", workflow_run_id) + return + + workflow = session.get(Workflow, workflow_run.workflow_id) + if workflow is None: + logger.warning("Workflow %s not found during resume", 
workflow_run.workflow_id) + return + + app_model = session.get(App, workflow_run.app_id) + if app_model is None: + logger.warning("App %s not found during resume", workflow_run.app_id) + return + + user = _resolve_user_for_run(session, workflow_run) + if user is None: + logger.warning("User %s not found for workflow run %s", workflow_run.created_by, workflow_run_id) + return + + if isinstance(generate_entity, AdvancedChatAppGenerateEntity): + if generate_entity.conversation_id is None: + logger.warning("Conversation id missing in resumption context for workflow run %s", workflow_run_id) + return + + conversation = session.get(Conversation, generate_entity.conversation_id) + if conversation is None: + logger.warning( + "Conversation %s not found for workflow run %s", generate_entity.conversation_id, workflow_run_id + ) + return + + message = session.scalar( + select(Message).where(Message.workflow_run_id == workflow_run_id).order_by(Message.created_at.desc()) + ) + if message is None: + logger.warning("Message not found for workflow run %s", workflow_run_id) + return + + if not isinstance(generate_entity, (AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity)): + logger.error( + "Unsupported resumption entity for workflow run %s (found %s)", + workflow_run_id, + type(generate_entity), + ) + return + + workflow_run_repo.resume_workflow_pause(workflow_run_id, pause_entity) + + pause_config = PauseStateLayerConfig( + session_factory=session_factory, + state_owner_user_id=workflow.created_by, + ) + + if isinstance(generate_entity, AdvancedChatAppGenerateEntity): + assert conversation is not None + assert message is not None + _resume_advanced_chat( + app_model=app_model, + workflow=workflow, + user=user, + conversation=conversation, + message=message, + generate_entity=generate_entity, + graph_runtime_state=graph_runtime_state, + session_factory=session_factory, + pause_state_config=pause_config, + workflow_run_id=workflow_run_id, + workflow_run=workflow_run, + ) + 
elif isinstance(generate_entity, WorkflowAppGenerateEntity): + _resume_workflow( + app_model=app_model, + workflow=workflow, + user=user, + generate_entity=generate_entity, + graph_runtime_state=graph_runtime_state, + session_factory=session_factory, + pause_state_config=pause_config, + workflow_run_id=workflow_run_id, + workflow_run=workflow_run, + workflow_run_repo=workflow_run_repo, + pause_entity=pause_entity, + ) + + +def _resume_advanced_chat( + *, + app_model: App, + workflow: Workflow, + user: Account | EndUser, + conversation: Conversation, + message: Message, + generate_entity: AdvancedChatAppGenerateEntity, + graph_runtime_state: GraphRuntimeState, + session_factory: sessionmaker, + pause_state_config: PauseStateLayerConfig, + workflow_run_id: str, + workflow_run: WorkflowRun, +) -> None: + try: + triggered_from = WorkflowRunTriggeredFrom(workflow_run.triggered_from) + except ValueError: + triggered_from = WorkflowRunTriggeredFrom.APP_RUN + + workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( + session_factory=session_factory, + user=user, + app_id=app_model.id, + triggered_from=triggered_from, + ) + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( + session_factory=session_factory, + user=user, + app_id=app_model.id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + generator = AdvancedChatAppGenerator() + + try: + response = generator.resume( + app_model=app_model, + workflow=workflow, + user=user, + conversation=conversation, + message=message, + application_generate_entity=generate_entity, + workflow_execution_repository=workflow_execution_repository, + workflow_node_execution_repository=workflow_node_execution_repository, + graph_runtime_state=graph_runtime_state, + pause_state_config=pause_state_config, + ) + except Exception: + logger.exception("Failed to resume chatflow execution for workflow run %s", workflow_run_id) + 
raise + + if generate_entity.stream: + assert isinstance(response, Generator) + _publish_streaming_response(response, workflow_run_id, AppMode.ADVANCED_CHAT) + + +def _resume_workflow( + *, + app_model: App, + workflow: Workflow, + user: Account | EndUser, + generate_entity: WorkflowAppGenerateEntity, + graph_runtime_state: GraphRuntimeState, + session_factory: sessionmaker, + pause_state_config: PauseStateLayerConfig, + workflow_run_id: str, + workflow_run: WorkflowRun, + workflow_run_repo, + pause_entity, +) -> None: + try: + triggered_from = WorkflowRunTriggeredFrom(workflow_run.triggered_from) + except ValueError: + triggered_from = WorkflowRunTriggeredFrom.APP_RUN + + workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( + session_factory=session_factory, + user=user, + app_id=app_model.id, + triggered_from=triggered_from, + ) + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( + session_factory=session_factory, + user=user, + app_id=app_model.id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + generator = WorkflowAppGenerator() + + try: + response = generator.resume( + app_model=app_model, + workflow=workflow, + user=user, + application_generate_entity=generate_entity, + graph_runtime_state=graph_runtime_state, + workflow_execution_repository=workflow_execution_repository, + workflow_node_execution_repository=workflow_node_execution_repository, + pause_state_config=pause_state_config, + ) + except Exception: + logger.exception("Failed to resume workflow execution for workflow run %s", workflow_run_id) + raise + + if generate_entity.stream: + assert isinstance(response, Generator) + _publish_streaming_response(response, workflow_run_id, AppMode.WORKFLOW) + + workflow_run_repo.delete_workflow_pause(pause_entity) + + +@shared_task(queue=WORKFLOW_BASED_APP_EXECUTION_QUEUE, name="resume_app_execution") +def resume_app_execution(payload: 
dict[str, Any]) -> None: + _resume_app_execution(payload) diff --git a/api/tasks/async_workflow_tasks.py b/api/tasks/async_workflow_tasks.py index b51884148e..cc96542d4b 100644 --- a/api/tasks/async_workflow_tasks.py +++ b/api/tasks/async_workflow_tasks.py @@ -5,32 +5,42 @@ These tasks handle workflow execution for different subscription tiers with appropriate retry policies and error handling. """ +import logging from datetime import UTC, datetime from typing import Any from celery import shared_task from sqlalchemy import select -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker from configs import dify_config from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY, WorkflowAppGenerator -from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity +from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, WorkflowResumptionContext +from core.app.layers.timeslice_layer import TimeSliceLayer from core.app.layers.trigger_post_layer import TriggerPostLayer from core.db.session_factory import session_factory +from core.repositories import DifyCoreRepositoryFactory +from core.workflow.runtime import GraphRuntimeState +from extensions.ext_database import db from models.account import Account -from models.enums import CreatorUserRole, WorkflowTriggerStatus +from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom, WorkflowTriggerStatus from models.model import App, EndUser, Tenant from models.trigger import WorkflowTriggerLog -from models.workflow import Workflow +from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom, WorkflowRun +from repositories.factory import DifyAPIRepositoryFactory from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository from services.errors.app import WorkflowNotFoundError from services.workflow.entities 
import ( TriggerData, + WorkflowResumeTaskData, WorkflowTaskData, ) from tasks.workflow_cfs_scheduler.cfs_scheduler import AsyncWorkflowCFSPlanEntity, AsyncWorkflowCFSPlanScheduler from tasks.workflow_cfs_scheduler.entities import AsyncWorkflowQueue, AsyncWorkflowSystemStrategy +logger = logging.getLogger(__name__) + @shared_task(queue=AsyncWorkflowQueue.PROFESSIONAL_QUEUE) def execute_workflow_professional(task_data_dict: dict[str, Any]): @@ -141,6 +151,11 @@ def _execute_workflow_common( if trigger_data.workflow_id: args["workflow_id"] = str(trigger_data.workflow_id) + pause_config = PauseStateLayerConfig( + session_factory=session_factory.get_session_maker(), + state_owner_user_id=workflow.created_by, + ) + # Execute the workflow with the trigger type generator.generate( app_model=app_model, @@ -156,6 +171,7 @@ def _execute_workflow_common( # TODO: Re-enable TimeSliceLayer after the HITL release. TriggerPostLayer(cfs_plan_scheduler_entity, start_time, trigger_log.id), ], + pause_state_config=pause_config, ) except Exception as e: @@ -173,21 +189,153 @@ def _execute_workflow_common( session.commit() -def _get_user(session: Session, trigger_log: WorkflowTriggerLog) -> Account | EndUser: +@shared_task(name="resume_workflow_execution") +def resume_workflow_execution(task_data_dict: dict[str, Any]) -> None: + """Resume a paused workflow run via Celery.""" + task_data = WorkflowResumeTaskData.model_validate(task_data_dict) + session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_factory) + + pause_entity = workflow_run_repo.get_workflow_pause(task_data.workflow_run_id) + if pause_entity is None: + logger.warning("No pause state for workflow run %s", task_data.workflow_run_id) + return + workflow_run = workflow_run_repo.get_workflow_run_by_id_without_tenant(pause_entity.workflow_execution_id) + if workflow_run is None: + logger.warning("Workflow run not found for 
pause entity: pause_entity_id=%s", pause_entity.id) + return + + try: + resumption_context = WorkflowResumptionContext.loads(pause_entity.get_state().decode()) + except Exception as exc: + logger.exception("Failed to load resumption context for workflow run %s", task_data.workflow_run_id) + raise exc + + generate_entity = resumption_context.get_generate_entity() + if not isinstance(generate_entity, WorkflowAppGenerateEntity): + logger.error( + "Unsupported resumption entity for workflow run %s: %s", + task_data.workflow_run_id, + type(generate_entity), + ) + return + + graph_runtime_state = GraphRuntimeState.from_snapshot(resumption_context.serialized_graph_runtime_state) + + with session_factory() as session: + workflow = session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id)) + if workflow is None: + raise WorkflowNotFoundError( + "Workflow not found: workflow_run_id=%s, workflow_id=%s", workflow_run.id, workflow_run.workflow_id + ) + user = _get_user(session, workflow_run) + app_model = session.scalar(select(App).where(App.id == workflow_run.app_id)) + if app_model is None: + raise _AppNotFoundError( + "App not found: app_id=%s, workflow_run_id=%s", workflow_run.app_id, workflow_run.id + ) + + workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( + session_factory=session_factory, + user=user, + app_id=generate_entity.app_config.app_id, + triggered_from=WorkflowRunTriggeredFrom(workflow_run.triggered_from), + ) + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( + session_factory=session_factory, + user=user, + app_id=generate_entity.app_config.app_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + pause_config = PauseStateLayerConfig( + session_factory=session_factory, + state_owner_user_id=workflow.created_by, + ) + + generator = WorkflowAppGenerator() + start_time = datetime.now(UTC) + graph_engine_layers = [] + 
trigger_log = _query_trigger_log_info(session_factory, task_data.workflow_run_id) + + if trigger_log: + cfs_plan_scheduler_entity = AsyncWorkflowCFSPlanEntity( + queue=AsyncWorkflowQueue(trigger_log.queue_name), + schedule_strategy=AsyncWorkflowSystemStrategy, + granularity=dify_config.ASYNC_WORKFLOW_SCHEDULER_GRANULARITY, + ) + cfs_plan_scheduler = AsyncWorkflowCFSPlanScheduler(plan=cfs_plan_scheduler_entity) + + graph_engine_layers.extend( + [ + TimeSliceLayer(cfs_plan_scheduler), + TriggerPostLayer(cfs_plan_scheduler_entity, start_time, trigger_log.id), + ] + ) + + workflow_run_repo.resume_workflow_pause(task_data.workflow_run_id, pause_entity) + + generator.resume( + app_model=app_model, + workflow=workflow, + user=user, + application_generate_entity=generate_entity, + graph_runtime_state=graph_runtime_state, + workflow_execution_repository=workflow_execution_repository, + workflow_node_execution_repository=workflow_node_execution_repository, + graph_engine_layers=graph_engine_layers, + pause_state_config=pause_config, + ) + workflow_run_repo.delete_workflow_pause(pause_entity) + + +def _get_user(session: Session, workflow_run: WorkflowRun | WorkflowTriggerLog) -> Account | EndUser: """Compose user from trigger log""" - tenant = session.scalar(select(Tenant).where(Tenant.id == trigger_log.tenant_id)) + tenant = session.scalar(select(Tenant).where(Tenant.id == workflow_run.tenant_id)) if not tenant: - raise ValueError(f"Tenant not found: {trigger_log.tenant_id}") + raise _TenantNotFoundError( + "Tenant not found for WorkflowRun: tenant_id=%s, workflow_run_id=%s", + workflow_run.tenant_id, + workflow_run.id, + ) # Get user from trigger log - if trigger_log.created_by_role == CreatorUserRole.ACCOUNT: - user = session.scalar(select(Account).where(Account.id == trigger_log.created_by)) + if workflow_run.created_by_role == CreatorUserRole.ACCOUNT: + user = session.scalar(select(Account).where(Account.id == workflow_run.created_by)) if user: user.current_tenant = 
tenant else: # CreatorUserRole.END_USER - user = session.scalar(select(EndUser).where(EndUser.id == trigger_log.created_by)) + user = session.scalar(select(EndUser).where(EndUser.id == workflow_run.created_by)) if not user: - raise ValueError(f"User not found: {trigger_log.created_by} (role: {trigger_log.created_by_role})") + raise _UserNotFoundError( + "User not found: user_id=%s, created_by_role=%s, workflow_run_id=%s", + workflow_run.created_by, + workflow_run.created_by_role, + workflow_run.id, + ) return user + + +def _query_trigger_log_info(session_factory: sessionmaker[Session], workflow_run_id) -> WorkflowTriggerLog | None: + with session_factory() as session, session.begin(): + trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session) + trigger_log = trigger_log_repo.get_by_workflow_run_id(workflow_run_id) + if not trigger_log: + logger.debug("Trigger log not found for workflow_run: workflow_run_id=%s", workflow_run_id) + return None + + return trigger_log + + +class _TenantNotFoundError(Exception): + pass + + +class _UserNotFoundError(Exception): + pass + + +class _AppNotFoundError(Exception): + pass diff --git a/api/tasks/batch_clean_document_task.py b/api/tasks/batch_clean_document_task.py index 74b939e84d..d388284980 100644 --- a/api/tasks/batch_clean_document_task.py +++ b/api/tasks/batch_clean_document_task.py @@ -50,7 +50,9 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form if segments: index_node_ids = [segment.index_node_id for segment in segments] index_processor = IndexProcessorFactory(doc_form).init_index_processor() - index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) + index_processor.clean( + dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True + ) for segment in segments: image_upload_file_ids = get_image_upload_file_ids(segment.content) diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py 
index 86e7cc7160..91ace6be02 100644 --- a/api/tasks/clean_document_task.py +++ b/api/tasks/clean_document_task.py @@ -51,7 +51,9 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i if segments: index_node_ids = [segment.index_node_id for segment in segments] index_processor = IndexProcessorFactory(doc_form).init_index_processor() - index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) + index_processor.clean( + dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True + ) for segment in segments: image_upload_file_ids = get_image_upload_file_ids(segment.content) diff --git a/api/tasks/clean_notion_document_task.py b/api/tasks/clean_notion_document_task.py index bcca1bf49f..4214f043e0 100644 --- a/api/tasks/clean_notion_document_task.py +++ b/api/tasks/clean_notion_document_task.py @@ -42,7 +42,9 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str): ).all() index_node_ids = [segment.index_node_id for segment in segments] - index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) + index_processor.clean( + dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True + ) segment_ids = [segment.id for segment in segments] segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)) session.execute(segment_delete_stmt) diff --git a/api/tasks/delete_segment_from_index_task.py b/api/tasks/delete_segment_from_index_task.py index bfa709502c..764c635d83 100644 --- a/api/tasks/delete_segment_from_index_task.py +++ b/api/tasks/delete_segment_from_index_task.py @@ -47,6 +47,7 @@ def delete_segment_from_index_task( doc_form = dataset_document.doc_form # Proceed with index cleanup using the index_node_ids directly + # For actual deletion, we should delete summaries (not just disable them) index_processor = 
IndexProcessorFactory(doc_form).init_index_processor() index_processor.clean( dataset, @@ -54,6 +55,7 @@ def delete_segment_from_index_task( with_keywords=True, delete_child_chunks=True, precomputed_child_node_ids=child_node_ids, + delete_summaries=True, # Actually delete summaries when segment is deleted ) if dataset.is_multimodal: # delete segment attachment binding diff --git a/api/tasks/disable_segment_from_index_task.py b/api/tasks/disable_segment_from_index_task.py index 0ce6429a94..bc45171623 100644 --- a/api/tasks/disable_segment_from_index_task.py +++ b/api/tasks/disable_segment_from_index_task.py @@ -60,6 +60,18 @@ def disable_segment_from_index_task(segment_id: str): index_processor = IndexProcessorFactory(index_type).init_index_processor() index_processor.clean(dataset, [segment.index_node_id]) + # Disable summary index for this segment + from services.summary_index_service import SummaryIndexService + + try: + SummaryIndexService.disable_summaries_for_segments( + dataset=dataset, + segment_ids=[segment.id], + disabled_by=segment.disabled_by, + ) + except Exception as e: + logger.warning("Failed to disable summary for segment %s: %s", segment.id, str(e)) + end_at = time.perf_counter() logger.info( click.style( diff --git a/api/tasks/disable_segments_from_index_task.py b/api/tasks/disable_segments_from_index_task.py index 03635902d1..3cc267e821 100644 --- a/api/tasks/disable_segments_from_index_task.py +++ b/api/tasks/disable_segments_from_index_task.py @@ -68,6 +68,21 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen index_node_ids.extend(attachment_ids) index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False) + # Disable summary indexes for these segments + from services.summary_index_service import SummaryIndexService + + segment_ids_list = [segment.id for segment in segments] + try: + # Get disabled_by from first segment (they should all have the same disabled_by) + disabled_by = 
segments[0].disabled_by if segments else None + SummaryIndexService.disable_summaries_for_segments( + dataset=dataset, + segment_ids=segment_ids_list, + disabled_by=disabled_by, + ) + except Exception as e: + logger.warning("Failed to disable summaries for segments: %s", str(e)) + end_at = time.perf_counter() logger.info(click.style(f"Segments removed from index latency: {end_at - start_at}", fg="green")) except Exception: diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index 3bdff60196..34496e9c6f 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -14,6 +14,7 @@ from enums.cloud_plan import CloudPlan from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document from services.feature_service import FeatureService +from tasks.generate_summary_index_task import generate_summary_index_task logger = logging.getLogger(__name__) @@ -99,6 +100,78 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): indexing_runner.run(documents) end_at = time.perf_counter() logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green")) + + # Trigger summary index generation for completed documents if enabled + # Only generate for high_quality indexing technique and when summary_index_setting is enabled + # Re-query dataset to get latest summary_index_setting (in case it was updated) + dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + if not dataset: + logger.warning("Dataset %s not found after indexing", dataset_id) + return + + if dataset.indexing_technique == "high_quality": + summary_index_setting = dataset.summary_index_setting + if summary_index_setting and summary_index_setting.get("enable"): + # expire all session to get latest document's indexing status + session.expire_all() + # Check each document's indexing status and trigger summary generation if completed + for document_id in 
document_ids: + # Re-query document to get latest status (IndexingRunner may have updated it) + document = ( + session.query(Document) + .where(Document.id == document_id, Document.dataset_id == dataset_id) + .first() + ) + if document: + logger.info( + "Checking document %s for summary generation: status=%s, doc_form=%s, need_summary=%s", + document_id, + document.indexing_status, + document.doc_form, + document.need_summary, + ) + if ( + document.indexing_status == "completed" + and document.doc_form != "qa_model" + and document.need_summary is True + ): + try: + generate_summary_index_task.delay(dataset.id, document_id, None) + logger.info( + "Queued summary index generation task for document %s in dataset %s " + "after indexing completed", + document_id, + dataset.id, + ) + except Exception: + logger.exception( + "Failed to queue summary index generation task for document %s", + document_id, + ) + # Don't fail the entire indexing process if summary task queuing fails + else: + logger.info( + "Skipping summary generation for document %s: " + "status=%s, doc_form=%s, need_summary=%s", + document_id, + document.indexing_status, + document.doc_form, + document.need_summary, + ) + else: + logger.warning("Document %s not found after indexing", document_id) + else: + logger.info( + "Summary index generation skipped for dataset %s: summary_index_setting.enable=%s", + dataset.id, + summary_index_setting.get("enable") if summary_index_setting else None, + ) + else: + logger.info( + "Summary index generation skipped for dataset %s: indexing_technique=%s (not 'high_quality')", + dataset.id, + dataset.indexing_technique, + ) except DocumentIsPausedError as ex: logger.info(click.style(str(ex), fg="yellow")) except Exception: diff --git a/api/tasks/enable_segment_to_index_task.py b/api/tasks/enable_segment_to_index_task.py index 1f9f21aa7e..41ebb0b076 100644 --- a/api/tasks/enable_segment_to_index_task.py +++ b/api/tasks/enable_segment_to_index_task.py @@ -106,6 +106,17 @@ 
def enable_segment_to_index_task(segment_id: str): # save vector index index_processor.load(dataset, [document], multimodal_documents=multimodel_documents) + # Enable summary index for this segment + from services.summary_index_service import SummaryIndexService + + try: + SummaryIndexService.enable_summaries_for_segments( + dataset=dataset, + segment_ids=[segment.id], + ) + except Exception as e: + logger.warning("Failed to enable summary for segment %s: %s", segment.id, str(e)) + end_at = time.perf_counter() logger.info(click.style(f"Segment enabled to index: {segment.id} latency: {end_at - start_at}", fg="green")) except Exception as e: diff --git a/api/tasks/enable_segments_to_index_task.py b/api/tasks/enable_segments_to_index_task.py index 48d3c8e178..d90eb4c39f 100644 --- a/api/tasks/enable_segments_to_index_task.py +++ b/api/tasks/enable_segments_to_index_task.py @@ -106,6 +106,18 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i # save vector index index_processor.load(dataset, documents, multimodal_documents=multimodal_documents) + # Enable summary indexes for these segments + from services.summary_index_service import SummaryIndexService + + segment_ids_list = [segment.id for segment in segments] + try: + SummaryIndexService.enable_summaries_for_segments( + dataset=dataset, + segment_ids=segment_ids_list, + ) + except Exception as e: + logger.warning("Failed to enable summaries for segments: %s", str(e)) + end_at = time.perf_counter() logger.info(click.style(f"Segments enabled to index latency: {end_at - start_at}", fg="green")) except Exception as e: diff --git a/api/tasks/generate_summary_index_task.py b/api/tasks/generate_summary_index_task.py new file mode 100644 index 0000000000..e4273e16b5 --- /dev/null +++ b/api/tasks/generate_summary_index_task.py @@ -0,0 +1,119 @@ +"""Async task for generating summary indexes.""" + +import logging +import time + +import click +from celery import shared_task + +from 
core.db.session_factory import session_factory +from models.dataset import Dataset, DocumentSegment +from models.dataset import Document as DatasetDocument +from services.summary_index_service import SummaryIndexService + +logger = logging.getLogger(__name__) + + +@shared_task(queue="dataset") +def generate_summary_index_task(dataset_id: str, document_id: str, segment_ids: list[str] | None = None): + """ + Async generate summary index for document segments. + + Args: + dataset_id: Dataset ID + document_id: Document ID + segment_ids: Optional list of specific segment IDs to process. If None, process all segments. + + Usage: + generate_summary_index_task.delay(dataset_id, document_id) + generate_summary_index_task.delay(dataset_id, document_id, segment_ids) + """ + logger.info( + click.style( + f"Start generating summary index for document {document_id} in dataset {dataset_id}", + fg="green", + ) + ) + start_at = time.perf_counter() + + try: + with session_factory.create_session() as session: + dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + if not dataset: + logger.error(click.style(f"Dataset not found: {dataset_id}", fg="red")) + return + + document = session.query(DatasetDocument).where(DatasetDocument.id == document_id).first() + if not document: + logger.error(click.style(f"Document not found: {document_id}", fg="red")) + return + + # Check if document needs summary + if not document.need_summary: + logger.info( + click.style( + f"Skipping summary generation for document {document_id}: need_summary is False", + fg="cyan", + ) + ) + return + + # Only generate summary index for high_quality indexing technique + if dataset.indexing_technique != "high_quality": + logger.info( + click.style( + f"Skipping summary generation for dataset {dataset_id}: " + f"indexing_technique is {dataset.indexing_technique}, not 'high_quality'", + fg="cyan", + ) + ) + return + + # Check if summary index is enabled + summary_index_setting = 
dataset.summary_index_setting + if not summary_index_setting or not summary_index_setting.get("enable"): + logger.info( + click.style( + f"Summary index is disabled for dataset {dataset_id}", + fg="cyan", + ) + ) + return + + # Determine if only parent chunks should be processed + only_parent_chunks = dataset.chunk_structure == "parent_child_index" + + # Generate summaries + summary_records = SummaryIndexService.generate_summaries_for_document( + dataset=dataset, + document=document, + summary_index_setting=summary_index_setting, + segment_ids=segment_ids, + only_parent_chunks=only_parent_chunks, + ) + + end_at = time.perf_counter() + logger.info( + click.style( + f"Summary index generation completed for document {document_id}: " + f"{len(summary_records)} summaries generated, latency: {end_at - start_at}", + fg="green", + ) + ) + + except Exception as e: + logger.exception("Failed to generate summary index for document %s", document_id) + # Update document segments with error status if needed + if segment_ids: + error_message = f"Summary generation failed: {str(e)}" + with session_factory.create_session() as session: + session.query(DocumentSegment).filter( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset_id, + ).update( + { + DocumentSegment.error: error_message, + }, + synchronize_session=False, + ) + session.commit() diff --git a/api/tasks/human_input_timeout_tasks.py b/api/tasks/human_input_timeout_tasks.py new file mode 100644 index 0000000000..5413a33d6a --- /dev/null +++ b/api/tasks/human_input_timeout_tasks.py @@ -0,0 +1,113 @@ +import logging +from datetime import timedelta + +from celery import shared_task +from sqlalchemy import or_, select +from sqlalchemy.orm import sessionmaker + +from configs import dify_config +from core.repositories.human_input_repository import HumanInputFormSubmissionRepository +from core.workflow.enums import WorkflowExecutionStatus +from core.workflow.nodes.human_input.enums import 
HumanInputFormKind, HumanInputFormStatus +from extensions.ext_database import db +from extensions.ext_storage import storage +from libs.datetime_utils import ensure_naive_utc, naive_utc_now +from models.human_input import HumanInputForm +from models.workflow import WorkflowPause, WorkflowRun +from services.human_input_service import HumanInputService + +logger = logging.getLogger(__name__) + + +def _is_global_timeout(form_model: HumanInputForm, global_timeout_seconds: int, *, now) -> bool: + if global_timeout_seconds <= 0: + return False + if form_model.workflow_run_id is None: + return False + created_at = ensure_naive_utc(form_model.created_at) + global_deadline = created_at + timedelta(seconds=global_timeout_seconds) + return global_deadline <= now + + +def _handle_global_timeout(*, form_id: str, workflow_run_id: str, node_id: str, session_factory: sessionmaker) -> None: + now = naive_utc_now() + with session_factory() as session, session.begin(): + workflow_run = session.get(WorkflowRun, workflow_run_id) + if workflow_run is not None: + workflow_run.status = WorkflowExecutionStatus.STOPPED + workflow_run.error = f"Human input global timeout at node {node_id}" + workflow_run.finished_at = now + session.add(workflow_run) + + pause_model = session.scalar(select(WorkflowPause).where(WorkflowPause.workflow_run_id == workflow_run_id)) + if pause_model is not None: + try: + storage.delete(pause_model.state_object_key) + except Exception: + logger.exception( + "Failed to delete pause state object for workflow_run_id=%s, pause_id=%s", + workflow_run_id, + pause_model.id, + ) + pause_model.resumed_at = now + session.add(pause_model) + + +@shared_task(name="human_input_form_timeout.check_and_resume", queue="schedule_executor") +def check_and_handle_human_input_timeouts(limit: int = 100) -> None: + """Scan for expired human input forms and resume or end workflows.""" + + session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + form_repo = 
HumanInputFormSubmissionRepository(session_factory) + service = HumanInputService(session_factory, form_repository=form_repo) + now = naive_utc_now() + global_timeout_seconds = dify_config.HUMAN_INPUT_GLOBAL_TIMEOUT_SECONDS + + with session_factory() as session: + global_deadline = now - timedelta(seconds=global_timeout_seconds) if global_timeout_seconds > 0 else None + timeout_filter = HumanInputForm.expiration_time <= now + if global_deadline is not None: + timeout_filter = or_(timeout_filter, HumanInputForm.created_at <= global_deadline) + stmt = ( + select(HumanInputForm) + .where( + HumanInputForm.status == HumanInputFormStatus.WAITING, + timeout_filter, + ) + .order_by(HumanInputForm.id.asc()) + .limit(limit) + ) + expired_forms = session.scalars(stmt).all() + + for form_model in expired_forms: + try: + if form_model.form_kind == HumanInputFormKind.DELIVERY_TEST: + form_repo.mark_timeout( + form_id=form_model.id, + timeout_status=HumanInputFormStatus.TIMEOUT, + reason="delivery_test_timeout", + ) + continue + + is_global = _is_global_timeout(form_model, global_timeout_seconds, now=now) + record = form_repo.mark_timeout( + form_id=form_model.id, + timeout_status=HumanInputFormStatus.EXPIRED if is_global else HumanInputFormStatus.TIMEOUT, + reason="global_timeout" if is_global else "node_timeout", + ) + assert record.workflow_run_id is not None, "workflow_run_id should not be None for non-test form" + if is_global: + _handle_global_timeout( + form_id=record.form_id, + workflow_run_id=record.workflow_run_id, + node_id=record.node_id, + session_factory=session_factory, + ) + else: + service.enqueue_resume(record.workflow_run_id) + except Exception: + logger.exception( + "Failed to handle timeout for form_id=%s workflow_run_id=%s", + form_model.id, + form_model.workflow_run_id, + ) diff --git a/api/tasks/mail_human_input_delivery_task.py b/api/tasks/mail_human_input_delivery_task.py new file mode 100644 index 0000000000..d1cd0fbadc --- /dev/null +++ 
b/api/tasks/mail_human_input_delivery_task.py @@ -0,0 +1,190 @@ +import json +import logging +import time +from dataclasses import dataclass +from typing import Any + +import click +from celery import shared_task +from sqlalchemy import select +from sqlalchemy.orm import Session, sessionmaker + +from configs import dify_config +from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext +from core.workflow.nodes.human_input.entities import EmailDeliveryConfig, EmailDeliveryMethod +from core.workflow.runtime import GraphRuntimeState, VariablePool +from extensions.ext_database import db +from extensions.ext_mail import mail +from models.human_input import ( + DeliveryMethodType, + HumanInputDelivery, + HumanInputForm, + HumanInputFormRecipient, + RecipientType, +) +from repositories.factory import DifyAPIRepositoryFactory +from services.feature_service import FeatureService + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class _EmailRecipient: + email: str + token: str + + +@dataclass(frozen=True) +class _EmailDeliveryJob: + form_id: str + subject: str + body: str + form_content: str + recipients: list[_EmailRecipient] + + +def _build_form_link(token: str) -> str: + base_url = dify_config.APP_WEB_URL + return f"{base_url.rstrip('/')}/form/{token}" + + +def _parse_recipient_payload(payload: str) -> tuple[str | None, RecipientType | None]: + try: + payload_dict: dict[str, Any] = json.loads(payload) + except Exception: + logger.exception("Failed to parse recipient payload") + return None, None + + return payload_dict.get("email"), payload_dict.get("TYPE") + + +def _load_email_jobs(session: Session, form: HumanInputForm) -> list[_EmailDeliveryJob]: + deliveries = session.scalars( + select(HumanInputDelivery).where( + HumanInputDelivery.form_id == form.id, + HumanInputDelivery.delivery_method_type == DeliveryMethodType.EMAIL, + ) + ).all() + jobs: list[_EmailDeliveryJob] = [] + for delivery in deliveries: + delivery_config = 
EmailDeliveryMethod.model_validate_json(delivery.channel_payload) + + recipients = session.scalars( + select(HumanInputFormRecipient).where(HumanInputFormRecipient.delivery_id == delivery.id) + ).all() + + recipient_entities: list[_EmailRecipient] = [] + for recipient in recipients: + email, recipient_type = _parse_recipient_payload(recipient.recipient_payload) + if recipient_type not in {RecipientType.EMAIL_MEMBER, RecipientType.EMAIL_EXTERNAL}: + continue + if not email: + continue + token = recipient.access_token + if not token: + continue + recipient_entities.append(_EmailRecipient(email=email, token=token)) + + if not recipient_entities: + continue + + jobs.append( + _EmailDeliveryJob( + form_id=form.id, + subject=delivery_config.config.subject, + body=delivery_config.config.body, + form_content=form.rendered_content, + recipients=recipient_entities, + ) + ) + return jobs + + +def _render_body( + body_template: str, + form_link: str, + *, + variable_pool: VariablePool | None, +) -> str: + body = EmailDeliveryConfig.render_body_template( + body=body_template, + url=form_link, + variable_pool=variable_pool, + ) + return body + + +def _load_variable_pool(workflow_run_id: str | None) -> VariablePool | None: + if not workflow_run_id: + return None + + session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_factory) + pause_entity = workflow_run_repo.get_workflow_pause(workflow_run_id) + if pause_entity is None: + logger.info("No pause state found for workflow run %s", workflow_run_id) + return None + + try: + resumption_context = WorkflowResumptionContext.loads(pause_entity.get_state().decode()) + except Exception: + logger.exception("Failed to load resumption context for workflow run %s", workflow_run_id) + return None + + graph_runtime_state = GraphRuntimeState.from_snapshot(resumption_context.serialized_graph_runtime_state) + return 
graph_runtime_state.variable_pool + + +def _open_session(session_factory: sessionmaker | Session | None): + if session_factory is None: + return Session(db.engine) + if isinstance(session_factory, Session): + return session_factory + return session_factory() + + +@shared_task(queue="mail") +def dispatch_human_input_email_task(form_id: str, node_title: str | None = None, session_factory=None): + if not mail.is_inited(): + return + + logger.info(click.style(f"Start human input email delivery for form {form_id}", fg="green")) + start_at = time.perf_counter() + + try: + with _open_session(session_factory) as session: + form = session.get(HumanInputForm, form_id) + if form is None: + logger.warning("Human input form not found, form_id=%s", form_id) + return + features = FeatureService.get_features(form.tenant_id) + if not features.human_input_email_delivery_enabled: + logger.info( + "Human input email delivery is not available for tenant=%s, form_id=%s", + form.tenant_id, + form_id, + ) + return + jobs = _load_email_jobs(session, form) + + variable_pool = _load_variable_pool(form.workflow_run_id) + + for job in jobs: + for recipient in job.recipients: + form_link = _build_form_link(recipient.token) + body = _render_body(job.body, form_link, variable_pool=variable_pool) + + mail.send( + to=recipient.email, + subject=job.subject, + html=body, + ) + + end_at = time.perf_counter() + logger.info( + click.style( + f"Human input email delivery succeeded for form {form_id}: latency: {end_at - start_at}", fg="green" + ) + ) + except Exception: + logger.exception("Send human input email failed, form_id=%s", form_id) diff --git a/api/tasks/regenerate_summary_index_task.py b/api/tasks/regenerate_summary_index_task.py new file mode 100644 index 0000000000..cf8988d13e --- /dev/null +++ b/api/tasks/regenerate_summary_index_task.py @@ -0,0 +1,315 @@ +"""Task for regenerating summary indexes when dataset settings change.""" + +import logging +import time +from collections import 
defaultdict + +import click +from celery import shared_task +from sqlalchemy import or_, select + +from core.db.session_factory import session_factory +from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary +from models.dataset import Document as DatasetDocument +from services.summary_index_service import SummaryIndexService + +logger = logging.getLogger(__name__) + + +@shared_task(queue="dataset") +def regenerate_summary_index_task( + dataset_id: str, + regenerate_reason: str = "summary_model_changed", + regenerate_vectors_only: bool = False, +): + """ + Regenerate summary indexes for all documents in a dataset. + + This task is triggered when: + 1. summary_index_setting model changes (regenerate_reason="summary_model_changed") + - Regenerates summary content and vectors for all existing summaries + 2. embedding_model changes (regenerate_reason="embedding_model_changed") + - Only regenerates vectors for existing summaries (keeps summary content) + + Args: + dataset_id: Dataset ID + regenerate_reason: Reason for regeneration ("summary_model_changed" or "embedding_model_changed") + regenerate_vectors_only: If True, only regenerate vectors without regenerating summary content + """ + logger.info( + click.style( + f"Start regenerate summary index for dataset {dataset_id}, reason: {regenerate_reason}", + fg="green", + ) + ) + start_at = time.perf_counter() + + try: + with session_factory.create_session() as session: + dataset = session.query(Dataset).filter_by(id=dataset_id).first() + if not dataset: + logger.error(click.style(f"Dataset not found: {dataset_id}", fg="red")) + return + + # Only regenerate summary index for high_quality indexing technique + if dataset.indexing_technique != "high_quality": + logger.info( + click.style( + f"Skipping summary regeneration for dataset {dataset_id}: " + f"indexing_technique is {dataset.indexing_technique}, not 'high_quality'", + fg="cyan", + ) + ) + return + + # Check if summary index is enabled (only for 
summary_model change) + # For embedding_model change, we still re-vectorize existing summaries even if setting is disabled + summary_index_setting = dataset.summary_index_setting + if not regenerate_vectors_only: + # For summary_model change, require summary_index_setting to be enabled + if not summary_index_setting or not summary_index_setting.get("enable"): + logger.info( + click.style( + f"Summary index is disabled for dataset {dataset_id}", + fg="cyan", + ) + ) + return + + total_segments_processed = 0 + total_segments_failed = 0 + + if regenerate_vectors_only: + # For embedding_model change: directly query all segments with existing summaries + # Don't require document indexing_status == "completed" + # Include summaries with status "completed" or "error" (if they have content) + segments_with_summaries = ( + session.query(DocumentSegment, DocumentSegmentSummary) + .join( + DocumentSegmentSummary, + DocumentSegment.id == DocumentSegmentSummary.chunk_id, + ) + .join( + DatasetDocument, + DocumentSegment.document_id == DatasetDocument.id, + ) + .where( + DocumentSegment.dataset_id == dataset_id, + DocumentSegment.status == "completed", # Segment must be completed + DocumentSegment.enabled == True, + DocumentSegmentSummary.dataset_id == dataset_id, + DocumentSegmentSummary.summary_content.isnot(None), # Must have summary content + # Include completed summaries or error summaries (with content) + or_( + DocumentSegmentSummary.status == "completed", + DocumentSegmentSummary.status == "error", + ), + DatasetDocument.enabled == True, # Document must be enabled + DatasetDocument.archived == False, # Document must not be archived + DatasetDocument.doc_form != "qa_model", # Skip qa_model documents + ) + .order_by(DocumentSegment.document_id.asc(), DocumentSegment.position.asc()) + .all() + ) + + if not segments_with_summaries: + logger.info( + click.style( + f"No segments with summaries found for re-vectorization in dataset {dataset_id}", + fg="cyan", + ) + ) + return + 
+ logger.info( + "Found %s segments with summaries for re-vectorization in dataset %s", + len(segments_with_summaries), + dataset_id, + ) + + # Group by document for logging + segments_by_document = defaultdict(list) + for segment, summary_record in segments_with_summaries: + segments_by_document[segment.document_id].append((segment, summary_record)) + + logger.info( + "Segments grouped into %s documents for re-vectorization", + len(segments_by_document), + ) + + for document_id, segment_summary_pairs in segments_by_document.items(): + logger.info( + "Re-vectorizing summaries for %s segments in document %s", + len(segment_summary_pairs), + document_id, + ) + + for segment, summary_record in segment_summary_pairs: + try: + # Delete old vector + if summary_record.summary_index_node_id: + try: + from core.rag.datasource.vdb.vector_factory import Vector + + vector = Vector(dataset) + vector.delete_by_ids([summary_record.summary_index_node_id]) + except Exception as e: + logger.warning( + "Failed to delete old summary vector for segment %s: %s", + segment.id, + str(e), + ) + + # Re-vectorize with new embedding model + SummaryIndexService.vectorize_summary(summary_record, segment, dataset) + session.commit() + total_segments_processed += 1 + + except Exception as e: + logger.error( + "Failed to re-vectorize summary for segment %s: %s", + segment.id, + str(e), + exc_info=True, + ) + total_segments_failed += 1 + # Update summary record with error status + summary_record.status = "error" + summary_record.error = f"Re-vectorization failed: {str(e)}" + session.add(summary_record) + session.commit() + continue + + else: + # For summary_model change: require document indexing_status == "completed" + # Get all documents with completed indexing status + dataset_documents = session.scalars( + select(DatasetDocument).where( + DatasetDocument.dataset_id == dataset_id, + DatasetDocument.indexing_status == "completed", + DatasetDocument.enabled == True, + DatasetDocument.archived == 
False, + ) + ).all() + + if not dataset_documents: + logger.info( + click.style( + f"No documents found for summary regeneration in dataset {dataset_id}", + fg="cyan", + ) + ) + return + + logger.info( + "Found %s documents for summary regeneration in dataset %s", + len(dataset_documents), + dataset_id, + ) + + for dataset_document in dataset_documents: + # Skip qa_model documents + if dataset_document.doc_form == "qa_model": + continue + + try: + # Get all segments with existing summaries + segments = ( + session.query(DocumentSegment) + .join( + DocumentSegmentSummary, + DocumentSegment.id == DocumentSegmentSummary.chunk_id, + ) + .where( + DocumentSegment.document_id == dataset_document.id, + DocumentSegment.dataset_id == dataset_id, + DocumentSegment.status == "completed", + DocumentSegment.enabled == True, + DocumentSegmentSummary.dataset_id == dataset_id, + ) + .order_by(DocumentSegment.position.asc()) + .all() + ) + + if not segments: + continue + + logger.info( + "Regenerating summaries for %s segments in document %s", + len(segments), + dataset_document.id, + ) + + for segment in segments: + summary_record = None + try: + # Get existing summary record + summary_record = ( + session.query(DocumentSegmentSummary) + .filter_by( + chunk_id=segment.id, + dataset_id=dataset_id, + ) + .first() + ) + + if not summary_record: + logger.warning("Summary record not found for segment %s, skipping", segment.id) + continue + + # Regenerate both summary content and vectors (for summary_model change) + SummaryIndexService.generate_and_vectorize_summary( + segment, dataset, summary_index_setting + ) + session.commit() + total_segments_processed += 1 + + except Exception as e: + logger.error( + "Failed to regenerate summary for segment %s: %s", + segment.id, + str(e), + exc_info=True, + ) + total_segments_failed += 1 + # Update summary record with error status + if summary_record: + summary_record.status = "error" + summary_record.error = f"Regeneration failed: {str(e)}" + 
session.add(summary_record) + session.commit() + continue + + except Exception as e: + logger.error( + "Failed to process document %s for summary regeneration: %s", + dataset_document.id, + str(e), + exc_info=True, + ) + continue + + end_at = time.perf_counter() + if regenerate_vectors_only: + logger.info( + click.style( + f"Summary re-vectorization completed for dataset {dataset_id}: " + f"{total_segments_processed} segments processed successfully, " + f"{total_segments_failed} segments failed, " + f"latency: {end_at - start_at:.2f}s", + fg="green", + ) + ) + else: + logger.info( + click.style( + f"Summary index regeneration completed for dataset {dataset_id}: " + f"{total_segments_processed} segments processed successfully, " + f"{total_segments_failed} segments failed, " + f"latency: {end_at - start_at:.2f}s", + fg="green", + ) + ) + + except Exception: + logger.exception("Regenerate summary index failed for dataset %s", dataset_id) diff --git a/api/tasks/remove_document_from_index_task.py b/api/tasks/remove_document_from_index_task.py index c3c255fb17..55259ab527 100644 --- a/api/tasks/remove_document_from_index_task.py +++ b/api/tasks/remove_document_from_index_task.py @@ -46,6 +46,21 @@ def remove_document_from_index_task(document_id: str): index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document.id)).all() + + # Disable summary indexes for all segments in this document + from services.summary_index_service import SummaryIndexService + + segment_ids_list = [segment.id for segment in segments] + if segment_ids_list: + try: + SummaryIndexService.disable_summaries_for_segments( + dataset=dataset, + segment_ids=segment_ids_list, + disabled_by=document.disabled_by, + ) + except Exception as e: + logger.warning("Failed to disable summaries for document %s: %s", document.id, str(e)) + index_node_ids = [segment.index_node_id for segment in 
segments]
     if index_node_ids:
         try:
diff --git a/api/tests/integration_tests/conftest.py b/api/tests/integration_tests/conftest.py
index 948cf8b3a0..44adadeaa5 100644
--- a/api/tests/integration_tests/conftest.py
+++ b/api/tests/integration_tests/conftest.py
@@ -1,3 +1,4 @@
+import logging
 import os
 import pathlib
 import random
@@ -10,26 +11,34 @@ from flask.testing import FlaskClient
 from sqlalchemy.orm import Session
 
 from app_factory import create_app
+from configs.app_config import DifyConfig
 from extensions.ext_database import db
 from models import Account, DifySetup, Tenant, TenantAccountJoin
 from services.account_service import AccountService, RegisterService
 
+_DEFAULT_TEST_ENV = ".env"
+_DEFAULT_VDB_TEST_ENV = "vdb.env"
+
+_logger = logging.getLogger(__name__)
+
 
 # Loading the .env file if it exists
 def _load_env():
     current_file_path = pathlib.Path(__file__).absolute()
 
     # Items later in the list have higher precedence.
-    files_to_load = [".env", "vdb.env"]
+    env_file_paths = [
+        os.getenv("DIFY_TEST_ENV_FILE", str(current_file_path.parent / _DEFAULT_TEST_ENV)),
+        os.getenv("DIFY_VDB_TEST_ENV_FILE", str(current_file_path.parent / _DEFAULT_VDB_TEST_ENV)),
+    ]
 
-    env_file_paths = [current_file_path.parent / i for i in files_to_load]
-    for path in env_file_paths:
-        if not path.exists():
-            continue
+    for env_path_str in env_file_paths:
+        if not pathlib.Path(env_path_str).exists():
+            _logger.warning("specified configuration file %s does not exist", env_path_str)
 
         from dotenv import load_dotenv
 
         # Set `override=True` to ensure values from `vdb.env` take priority over values from `.env`
-        load_dotenv(str(path), override=True)
+        load_dotenv(str(env_path_str), override=True)
 
 
 _load_env()
@@ -41,6 +50,12 @@ os.environ.setdefault("OPENDAL_SCHEME", "fs")
 _CACHED_APP = create_app()
 
 
+@pytest.fixture(scope="session")
+def dify_config() -> DifyConfig:
+    config = DifyConfig()  # type: ignore
+    return config
+
+
 @pytest.fixture
 def flask_app() -> Flask:
     return _CACHED_APP
diff --git
a/api/tests/integration_tests/libs/broadcast_channel/redis/utils/__init__.py b/api/tests/integration_tests/libs/broadcast_channel/redis/utils/__init__.py new file mode 100644 index 0000000000..e3f0d8a96e --- /dev/null +++ b/api/tests/integration_tests/libs/broadcast_channel/redis/utils/__init__.py @@ -0,0 +1,36 @@ +""" +Utilities and helpers for Redis broadcast channel integration tests. + +This module provides utility classes and functions for testing +Redis broadcast channel functionality. +""" + +from .test_data import ( + LARGE_MESSAGES, + SMALL_MESSAGES, + SPECIAL_MESSAGES, + BufferTestConfig, + ConcurrencyTestConfig, + ErrorTestConfig, +) +from .test_helpers import ( + ConcurrentPublisher, + SubscriptionMonitor, + assert_message_order, + measure_throughput, + wait_for_condition, +) + +__all__ = [ + "LARGE_MESSAGES", + "SMALL_MESSAGES", + "SPECIAL_MESSAGES", + "BufferTestConfig", + "ConcurrencyTestConfig", + "ConcurrentPublisher", + "ErrorTestConfig", + "SubscriptionMonitor", + "assert_message_order", + "measure_throughput", + "wait_for_condition", +] diff --git a/api/tests/integration_tests/libs/broadcast_channel/redis/utils/test_data.py b/api/tests/integration_tests/libs/broadcast_channel/redis/utils/test_data.py new file mode 100644 index 0000000000..2cccb08304 --- /dev/null +++ b/api/tests/integration_tests/libs/broadcast_channel/redis/utils/test_data.py @@ -0,0 +1,315 @@ +""" +Test data and configuration classes for Redis broadcast channel integration tests. + +This module provides dataclasses and constants for test configurations, +message sets, and test scenarios. 
+""" + +import dataclasses +from typing import Any + +from libs.broadcast_channel.channel import Overflow + + +@dataclasses.dataclass(frozen=True) +class BufferTestConfig: + """Configuration for buffer management tests.""" + + buffer_size: int + overflow_strategy: Overflow + message_count: int + expected_behavior: str + description: str + + +@dataclasses.dataclass(frozen=True) +class ConcurrencyTestConfig: + """Configuration for concurrency tests.""" + + publisher_count: int + subscriber_count: int + messages_per_publisher: int + test_duration: float + description: str + + +@dataclasses.dataclass(frozen=True) +class ErrorTestConfig: + """Configuration for error handling tests.""" + + error_type: str + test_input: Any + expected_exception: type[Exception] + description: str + + +# Test message sets for different scenarios +SMALL_MESSAGES = [ + b"msg_1", + b"msg_2", + b"msg_3", + b"msg_4", + b"msg_5", +] + +MEDIUM_MESSAGES = [ + b"medium_message_1_with_more_content", + b"medium_message_2_with_more_content", + b"medium_message_3_with_more_content", + b"medium_message_4_with_more_content", + b"medium_message_5_with_more_content", +] + +LARGE_MESSAGES = [ + b"large_message_" + b"x" * 1000, + b"large_message_" + b"y" * 1000, + b"large_message_" + b"z" * 1000, +] + +VERY_LARGE_MESSAGES = [ + b"very_large_message_" + b"x" * 10000, # ~10KB + b"very_large_message_" + b"y" * 50000, # ~50KB + b"very_large_message_" + b"z" * 100000, # ~100KB +] + +SPECIAL_MESSAGES = [ + b"", # Empty message + b"\x00\x01\x02", # Binary data with null bytes + "unicode_test_你好".encode(), # Unicode + b"special_chars_!@#$%^&*()_+-=[]{}|;':\",./<>?", # Special characters + b"newlines\n\r\t", # Control characters +] + +BINARY_MESSAGES = [ + bytes(range(256)), # All possible byte values + b"\xff\xfe\xfd\xfc\xfb\xfa\xf9\xf8", # High byte values + b"\x00\x01\x02\x03\x04\x05\x06\x07", # Low byte values +] + +# Buffer test configurations +BUFFER_TEST_CONFIGS = [ + BufferTestConfig( + buffer_size=3, + 
overflow_strategy=Overflow.DROP_OLDEST, + message_count=5, + expected_behavior="drop_oldest", + description="Drop oldest messages when buffer is full", + ), + BufferTestConfig( + buffer_size=3, + overflow_strategy=Overflow.DROP_NEWEST, + message_count=5, + expected_behavior="drop_newest", + description="Drop newest messages when buffer is full", + ), + BufferTestConfig( + buffer_size=3, + overflow_strategy=Overflow.BLOCK, + message_count=5, + expected_behavior="block", + description="Block when buffer is full", + ), +] + +# Concurrency test configurations +CONCURRENCY_TEST_CONFIGS = [ + ConcurrencyTestConfig( + publisher_count=1, + subscriber_count=1, + messages_per_publisher=10, + test_duration=5.0, + description="Single publisher, single subscriber", + ), + ConcurrencyTestConfig( + publisher_count=3, + subscriber_count=1, + messages_per_publisher=10, + test_duration=5.0, + description="Multiple publishers, single subscriber", + ), + ConcurrencyTestConfig( + publisher_count=1, + subscriber_count=3, + messages_per_publisher=10, + test_duration=5.0, + description="Single publisher, multiple subscribers", + ), + ConcurrencyTestConfig( + publisher_count=3, + subscriber_count=3, + messages_per_publisher=10, + test_duration=5.0, + description="Multiple publishers, multiple subscribers", + ), +] + +# Error test configurations +ERROR_TEST_CONFIGS = [ + ErrorTestConfig( + error_type="invalid_buffer_size", + test_input=0, + expected_exception=ValueError, + description="Zero buffer size should raise ValueError", + ), + ErrorTestConfig( + error_type="invalid_buffer_size", + test_input=-1, + expected_exception=ValueError, + description="Negative buffer size should raise ValueError", + ), + ErrorTestConfig( + error_type="invalid_buffer_size", + test_input=1.5, + expected_exception=TypeError, + description="Float buffer size should raise TypeError", + ), + ErrorTestConfig( + error_type="invalid_buffer_size", + test_input="invalid", + expected_exception=TypeError, + 
description="String buffer size should raise TypeError", + ), +] + +# Topic name test cases +TOPIC_NAME_TEST_CASES = [ + "simple_topic", + "topic_with_underscores", + "topic-with-dashes", + "topic.with.dots", + "topic_with_numbers_123", + "UPPERCASE_TOPIC", + "mixed_Case_Topic", + "topic_with_symbols_!@#$%", + "very_long_topic_name_" + "x" * 100, + "unicode_topic_你好", + "topic:with:colons", + "topic/with/slashes", + "topic\\with\\backslashes", +] + +# Performance test configurations +PERFORMANCE_TEST_CONFIGS = [ + { + "name": "small_messages_high_frequency", + "message_size": 50, + "message_count": 1000, + "description": "Many small messages", + }, + { + "name": "medium_messages_medium_frequency", + "message_size": 500, + "message_count": 100, + "description": "Medium messages", + }, + { + "name": "large_messages_low_frequency", + "message_size": 5000, + "message_count": 10, + "description": "Large messages", + }, +] + +# Stress test configurations +STRESS_TEST_CONFIGS = [ + { + "name": "high_frequency_publishing", + "publisher_count": 5, + "messages_per_publisher": 100, + "subscriber_count": 3, + "description": "High frequency publishing with multiple publishers", + }, + { + "name": "many_subscribers", + "publisher_count": 1, + "messages_per_publisher": 50, + "subscriber_count": 10, + "description": "Many subscribers to single publisher", + }, + { + "name": "mixed_load", + "publisher_count": 3, + "messages_per_publisher": 100, + "subscriber_count": 5, + "description": "Mixed load with multiple publishers and subscribers", + }, +] + +# Edge case test data +EDGE_CASE_MESSAGES = [ + b"", # Empty message + b"\x00", # Single null byte + b"\xff", # Single max byte value + b"a", # Single ASCII character + "ä".encode(), # Single unicode character (2 bytes) + "𐍈".encode(), # Unicode character outside BMP (4 bytes) + b"\x00" * 1000, # 1000 null bytes + b"\xff" * 1000, # 1000 max byte values +] + +# Message validation test data +MESSAGE_VALIDATION_TEST_CASES = [ + { + 
"name": "valid_bytes", + "input": b"valid_message", + "should_pass": True, + "description": "Valid bytes message", + }, + { + "name": "empty_bytes", + "input": b"", + "should_pass": True, + "description": "Empty bytes message", + }, + { + "name": "binary_data", + "input": bytes(range(256)), + "should_pass": True, + "description": "Binary data with all byte values", + }, + { + "name": "large_message", + "input": b"x" * 1000000, # 1MB + "should_pass": True, + "description": "Large message (1MB)", + }, +] + +# Redis connection test scenarios +REDIS_CONNECTION_TEST_SCENARIOS = [ + { + "name": "normal_connection", + "should_fail": False, + "description": "Normal Redis connection", + }, + { + "name": "connection_timeout", + "should_fail": True, + "description": "Connection timeout scenario", + }, + { + "name": "connection_refused", + "should_fail": True, + "description": "Connection refused scenario", + }, +] + +# Test constants +DEFAULT_TIMEOUT = 10.0 +SHORT_TIMEOUT = 2.0 +LONG_TIMEOUT = 30.0 + +# Message size limits for testing +MAX_SMALL_MESSAGE_SIZE = 100 +MAX_MEDIUM_MESSAGE_SIZE = 1000 +MAX_LARGE_MESSAGE_SIZE = 10000 + +# Thread counts for concurrency testing +MIN_THREAD_COUNT = 1 +MAX_THREAD_COUNT = 10 +DEFAULT_THREAD_COUNT = 3 + +# Buffer sizes for testing +MIN_BUFFER_SIZE = 1 +MAX_BUFFER_SIZE = 1000 +DEFAULT_BUFFER_SIZE = 10 diff --git a/api/tests/integration_tests/libs/broadcast_channel/redis/utils/test_helpers.py b/api/tests/integration_tests/libs/broadcast_channel/redis/utils/test_helpers.py new file mode 100644 index 0000000000..65f3007b01 --- /dev/null +++ b/api/tests/integration_tests/libs/broadcast_channel/redis/utils/test_helpers.py @@ -0,0 +1,396 @@ +""" +Test helper utilities for Redis broadcast channel integration tests. + +This module provides utility classes and functions for testing concurrent +operations, monitoring subscriptions, and measuring performance. 
+"""
+
+import logging
+import threading
+import time
+from collections.abc import Callable
+from typing import Any
+
+_logger = logging.getLogger(__name__)
+
+
+class ConcurrentPublisher:
+    """
+    Utility class for publishing messages concurrently from multiple threads.
+
+    This class manages multiple publisher threads that can publish messages
+    to the same or different topics concurrently, useful for stress testing
+    and concurrency validation.
+    """
+
+    def __init__(self, producer, message_count: int = 10, delay: float = 0.0):
+        """
+        Initialize the concurrent publisher.
+
+        Args:
+            producer: The producer instance to publish with
+            message_count: Number of messages to publish per thread
+            delay: Delay between messages in seconds
+        """
+        self.producer = producer
+        self.message_count = message_count
+        self.delay = delay
+        self.threads: list[threading.Thread] = []
+        self.published_messages: list[list[bytes]] = []
+        self._lock = threading.Lock()
+        self._started = False
+
+    def start_publishers(self, thread_count: int = 3) -> None:
+        """
+        Start multiple publisher threads.
+
+        Args:
+            thread_count: Number of publisher threads to start
+        """
+        if self._started:
+            raise RuntimeError("Publishers already started")
+
+        self._started = True
+
+        def _publisher(thread_id: int) -> None:
+            messages: list[bytes] = []
+            for i in range(self.message_count):
+                message = f"thread_{thread_id}_msg_{i}".encode()
+                try:
+                    self.producer.publish(message)
+                    messages.append(message)
+                    if self.delay > 0:
+                        time.sleep(self.delay)
+                except Exception:
+                    _logger.exception("Publisher %s failed", thread_id)
+
+            with self._lock:
+                self.published_messages.append(messages)
+
+        for thread_id in range(thread_count):
+            thread = threading.Thread(
+                target=_publisher,
+                args=(thread_id,),
+                name=f"publisher-{thread_id}",
+                daemon=True,
+            )
+            thread.start()
+            self.threads.append(thread)
+
+    def wait_for_completion(self, timeout: float = 30.0) -> bool:
+        """
+        Wait for all publisher threads to complete.
+ + Args: + timeout: Maximum time to wait in seconds + + Returns: + bool: True if all threads completed successfully + """ + for thread in self.threads: + thread.join(timeout) + if thread.is_alive(): + return False + return True + + def get_all_messages(self) -> list[bytes]: + """ + Get all messages published by all threads. + + Returns: + list[bytes]: Flattened list of all published messages + """ + with self._lock: + all_messages = [] + for thread_messages in self.published_messages: + all_messages.extend(thread_messages) + return all_messages + + def get_thread_messages(self, thread_id: int) -> list[bytes]: + """ + Get messages published by a specific thread. + + Args: + thread_id: ID of the thread + + Returns: + list[bytes]: Messages published by the specified thread + """ + with self._lock: + if 0 <= thread_id < len(self.published_messages): + return self.published_messages[thread_id].copy() + return [] + + +class SubscriptionMonitor: + """ + Utility class for monitoring subscription activity in tests. + + This class monitors a subscription and tracks message reception, + errors, and completion status for testing purposes. + """ + + def __init__(self, subscription, timeout: float = 10.0): + """ + Initialize the subscription monitor. 
+ + Args: + subscription: The subscription to monitor + timeout: Default timeout for operations + """ + self.subscription = subscription + self.timeout = timeout + self.messages: list[bytes] = [] + self.errors: list[Exception] = [] + self.completed = False + self._lock = threading.Lock() + self._condition = threading.Condition(self._lock) + self._monitor_thread: threading.Thread | None = None + self._start_time: float | None = None + + def start_monitoring(self) -> None: + """Start monitoring the subscription in a separate thread.""" + if self._monitor_thread is not None: + raise RuntimeError("Monitoring already started") + + self._start_time = time.time() + + def _monitor(): + try: + for message in self.subscription: + with self._lock: + self.messages.append(message) + self._condition.notify_all() + except Exception as e: + with self._lock: + self.errors.append(e) + self._condition.notify_all() + finally: + with self._lock: + self.completed = True + self._condition.notify_all() + + self._monitor_thread = threading.Thread( + target=_monitor, + name="subscription-monitor", + daemon=True, + ) + self._monitor_thread.start() + + def wait_for_messages(self, count: int, timeout: float | None = None) -> bool: + """ + Wait for a specific number of messages. + + Args: + count: Number of messages to wait for + timeout: Timeout in seconds (uses default if None) + + Returns: + bool: True if expected messages were received + """ + if timeout is None: + timeout = self.timeout + + deadline = time.time() + timeout + + with self._condition: + while len(self.messages) < count and not self.completed: + remaining = deadline - time.time() + if remaining <= 0: + return False + self._condition.wait(remaining) + + return len(self.messages) >= count + + def wait_for_completion(self, timeout: float | None = None) -> bool: + """ + Wait for monitoring to complete. 
+ + Args: + timeout: Timeout in seconds (uses default if None) + + Returns: + bool: True if monitoring completed successfully + """ + if timeout is None: + timeout = self.timeout + + deadline = time.time() + timeout + + with self._condition: + while not self.completed: + remaining = deadline - time.time() + if remaining <= 0: + return False + self._condition.wait(remaining) + + return True + + def get_messages(self) -> list[bytes]: + """ + Get all received messages. + + Returns: + list[bytes]: Copy of received messages + """ + with self._lock: + return self.messages.copy() + + def get_error_count(self) -> int: + """ + Get the number of errors encountered. + + Returns: + int: Number of errors + """ + with self._lock: + return len(self.errors) + + def get_elapsed_time(self) -> float: + """ + Get the elapsed monitoring time. + + Returns: + float: Elapsed time in seconds + """ + if self._start_time is None: + return 0.0 + return time.time() - self._start_time + + def stop(self) -> None: + """Stop monitoring and close the subscription.""" + if self._monitor_thread is not None: + self.subscription.close() + self._monitor_thread.join(timeout=1.0) + + +def assert_message_order(received: list[bytes], expected: list[bytes]) -> bool: + """ + Assert that messages were received in the expected order. + + Args: + received: List of received messages + expected: List of expected messages in order + + Returns: + bool: True if order matches expected + """ + if len(received) != len(expected): + return False + + for i, (recv_msg, exp_msg) in enumerate(zip(received, expected)): + if recv_msg != exp_msg: + _logger.error("Message order mismatch at index %s: expected %s, got %s", i, exp_msg, recv_msg) + return False + + return True + + +def measure_throughput( + operation: Callable[[], Any], + duration: float = 1.0, +) -> tuple[float, int]: + """ + Measure the throughput of an operation over a specified duration. 
+ + Args: + operation: The operation to measure + duration: Duration to run the operation in seconds + + Returns: + tuple[float, int]: (operations per second, total operations) + """ + start_time = time.time() + end_time = start_time + duration + count = 0 + + while time.time() < end_time: + try: + operation() + count += 1 + except Exception: + _logger.exception("Operation failed") + break + + elapsed = time.time() - start_time + ops_per_sec = count / elapsed if elapsed > 0 else 0.0 + + return ops_per_sec, count + + +def wait_for_condition( + condition: Callable[[], bool], + timeout: float = 10.0, + interval: float = 0.1, +) -> bool: + """ + Wait for a condition to become true. + + Args: + condition: Function that returns True when condition is met + timeout: Maximum time to wait in seconds + interval: Check interval in seconds + + Returns: + bool: True if condition was met within timeout + """ + deadline = time.time() + timeout + + while time.time() < deadline: + if condition(): + return True + time.sleep(interval) + + return False + + +def create_stress_test_messages( + count: int, + size: int = 100, +) -> list[bytes]: + """ + Create messages for stress testing. + + Args: + count: Number of messages to create + size: Size of each message in bytes + + Returns: + list[bytes]: List of test messages + """ + messages = [] + for i in range(count): + message = f"stress_test_msg_{i:06d}_".ljust(size, "x").encode() + messages.append(message) + return messages + + +def validate_message_integrity( + original_messages: list[bytes], + received_messages: list[bytes], +) -> dict[str, Any]: + """ + Validate the integrity of received messages. 
+ + Args: + original_messages: Messages that were sent + received_messages: Messages that were received + + Returns: + dict[str, Any]: Validation results + """ + original_set = set(original_messages) + received_set = set(received_messages) + + missing_messages = original_set - received_set + extra_messages = received_set - original_set + + return { + "total_sent": len(original_messages), + "total_received": len(received_messages), + "missing_count": len(missing_messages), + "extra_count": len(extra_messages), + "missing_messages": list(missing_messages), + "extra_messages": list(extra_messages), + "integrity_ok": len(missing_messages) == 0 and len(extra_messages) == 0, + } diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py new file mode 100644 index 0000000000..7fad603a6d --- /dev/null +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py @@ -0,0 +1,166 @@ +"""TestContainers integration tests for ChatConversationApi status_count behavior.""" + +import json +import uuid + +from flask.testing import FlaskClient +from sqlalchemy.orm import Session + +from configs import dify_config +from constants import HEADER_NAME_CSRF_TOKEN +from core.workflow.enums import WorkflowExecutionStatus +from libs.datetime_utils import naive_utc_now +from libs.token import _real_cookie_name, generate_csrf_token +from models import Account, DifySetup, Tenant, TenantAccountJoin +from models.account import AccountStatus, TenantAccountRole +from models.enums import CreatorUserRole +from models.model import App, AppMode, Conversation, Message +from models.workflow import WorkflowRun +from services.account_service import AccountService + + +def _create_account_and_tenant(db_session: Session) -> tuple[Account, Tenant]: + account = Account( + 
email=f"test-{uuid.uuid4()}@example.com", + name="Test User", + interface_language="en-US", + status=AccountStatus.ACTIVE, + ) + account.initialized_at = naive_utc_now() + db_session.add(account) + db_session.commit() + + tenant = Tenant(name="Test Tenant", status="normal") + db_session.add(tenant) + db_session.commit() + + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER, + current=True, + ) + db_session.add(join) + db_session.commit() + + account.set_tenant_id(tenant.id) + account.timezone = "UTC" + db_session.commit() + + dify_setup = DifySetup(version=dify_config.project.version) + db_session.add(dify_setup) + db_session.commit() + + return account, tenant + + +def _create_app(db_session: Session, tenant_id: str, account_id: str) -> App: + app = App( + tenant_id=tenant_id, + name="Test Chat App", + mode=AppMode.CHAT, + enable_site=True, + enable_api=True, + created_by=account_id, + ) + db_session.add(app) + db_session.commit() + return app + + +def _create_conversation(db_session: Session, app_id: str, account_id: str) -> Conversation: + conversation = Conversation( + app_id=app_id, + name="Test Conversation", + inputs={}, + status="normal", + mode=AppMode.CHAT, + from_source=CreatorUserRole.ACCOUNT, + from_account_id=account_id, + ) + db_session.add(conversation) + db_session.commit() + return conversation + + +def _create_workflow_run(db_session: Session, app_id: str, tenant_id: str, account_id: str) -> WorkflowRun: + workflow_run = WorkflowRun( + tenant_id=tenant_id, + app_id=app_id, + workflow_id=str(uuid.uuid4()), + type="chat", + triggered_from="app-run", + version="1.0.0", + graph=json.dumps({"nodes": [], "edges": []}), + inputs=json.dumps({"query": "test"}), + status=WorkflowExecutionStatus.PAUSED, + outputs=json.dumps({}), + elapsed_time=0.0, + total_tokens=0, + total_steps=0, + created_by_role=CreatorUserRole.ACCOUNT, + created_by=account_id, + created_at=naive_utc_now(), + ) + 
db_session.add(workflow_run) + db_session.commit() + return workflow_run + + +def _create_message( + db_session: Session, app_id: str, conversation_id: str, workflow_run_id: str, account_id: str +) -> Message: + message = Message( + app_id=app_id, + conversation_id=conversation_id, + query="Hello", + message={"type": "text", "content": "Hello"}, + answer="Hi there", + message_tokens=1, + answer_tokens=1, + message_unit_price=0.001, + answer_unit_price=0.001, + message_price_unit=0.001, + answer_price_unit=0.001, + currency="USD", + status="normal", + from_source=CreatorUserRole.ACCOUNT, + from_account_id=account_id, + workflow_run_id=workflow_run_id, + inputs={"query": "Hello"}, + ) + db_session.add(message) + db_session.commit() + return message + + +def test_chat_conversation_status_count_includes_paused( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +): + account, tenant = _create_account_and_tenant(db_session_with_containers) + app = _create_app(db_session_with_containers, tenant.id, account.id) + conversation = _create_conversation(db_session_with_containers, app.id, account.id) + conversation_id = conversation.id + workflow_run = _create_workflow_run(db_session_with_containers, app.id, tenant.id, account.id) + _create_message(db_session_with_containers, app.id, conversation.id, workflow_run.id, account.id) + + access_token = AccountService.get_account_jwt_token(account) + csrf_token = generate_csrf_token(account.id) + cookie_name = _real_cookie_name("csrf_token") + + test_client_with_containers.set_cookie(cookie_name, csrf_token, domain="localhost") + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/chat-conversations", + headers={ + "Authorization": f"Bearer {access_token}", + HEADER_NAME_CSRF_TOKEN: csrf_token, + }, + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload is not None + assert payload["total"] == 1 + assert payload["data"][0]["id"] == 
conversation_id + assert payload["data"][0]["status_count"]["paused"] == 1 diff --git a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py new file mode 100644 index 0000000000..079e4934bb --- /dev/null +++ b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py @@ -0,0 +1,240 @@ +"""TestContainers integration tests for HumanInputFormRepositoryImpl.""" + +from __future__ import annotations + +from uuid import uuid4 + +from sqlalchemy import Engine, select +from sqlalchemy.orm import Session + +from core.repositories.human_input_repository import HumanInputFormRepositoryImpl +from core.workflow.nodes.human_input.entities import ( + DeliveryChannelConfig, + EmailDeliveryConfig, + EmailDeliveryMethod, + EmailRecipients, + ExternalRecipient, + FormDefinition, + HumanInputNodeData, + MemberRecipient, + UserAction, + WebAppDeliveryMethod, +) +from core.workflow.repositories.human_input_form_repository import FormCreateParams +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.human_input import ( + EmailExternalRecipientPayload, + EmailMemberRecipientPayload, + HumanInputForm, + HumanInputFormRecipient, + RecipientType, +) + + +def _create_tenant_with_members(session: Session, member_emails: list[str]) -> tuple[Tenant, list[Account]]: + tenant = Tenant(name="Test Tenant", status="normal") + session.add(tenant) + session.flush() + + members: list[Account] = [] + for index, email in enumerate(member_emails): + account = Account( + email=email, + name=f"Member {index}", + interface_language="en-US", + status="active", + ) + session.add(account) + session.flush() + + tenant_join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.NORMAL, + current=True, + ) + session.add(tenant_join) + 
members.append(account) + + session.commit() + return tenant, members + + +def _build_form_params(delivery_methods: list[DeliveryChannelConfig]) -> FormCreateParams: + form_config = HumanInputNodeData( + title="Human Approval", + delivery_methods=delivery_methods, + form_content="

Approve?

", + user_actions=[UserAction(id="approve", title="Approve")], + ) + return FormCreateParams( + app_id=str(uuid4()), + workflow_execution_id=str(uuid4()), + node_id="human-input-node", + form_config=form_config, + rendered_content="

Approve?

", + delivery_methods=delivery_methods, + display_in_ui=False, + resolved_default_values={}, + ) + + +def _build_email_delivery( + whole_workspace: bool, recipients: list[MemberRecipient | ExternalRecipient] +) -> EmailDeliveryMethod: + return EmailDeliveryMethod( + config=EmailDeliveryConfig( + recipients=EmailRecipients(whole_workspace=whole_workspace, items=recipients), + subject="Approval Needed", + body="Please review", + ) + ) + + +class TestHumanInputFormRepositoryImplWithContainers: + def test_create_form_with_whole_workspace_recipients(self, db_session_with_containers: Session) -> None: + engine = db_session_with_containers.get_bind() + assert isinstance(engine, Engine) + tenant, members = _create_tenant_with_members( + db_session_with_containers, + member_emails=["member1@example.com", "member2@example.com"], + ) + + repository = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id) + params = _build_form_params( + delivery_methods=[_build_email_delivery(whole_workspace=True, recipients=[])], + ) + + form_entity = repository.create_form(params) + + with Session(engine) as verification_session: + recipients = verification_session.scalars( + select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id == form_entity.id) + ).all() + + assert len(recipients) == len(members) + member_payloads = [ + EmailMemberRecipientPayload.model_validate_json(recipient.recipient_payload) + for recipient in recipients + if recipient.recipient_type == RecipientType.EMAIL_MEMBER + ] + member_emails = {payload.email for payload in member_payloads} + assert member_emails == {member.email for member in members} + + def test_create_form_with_specific_members_and_external(self, db_session_with_containers: Session) -> None: + engine = db_session_with_containers.get_bind() + assert isinstance(engine, Engine) + tenant, members = _create_tenant_with_members( + db_session_with_containers, + member_emails=["primary@example.com", "secondary@example.com"], + ) 
+ + repository = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id) + params = _build_form_params( + delivery_methods=[ + _build_email_delivery( + whole_workspace=False, + recipients=[ + MemberRecipient(user_id=members[0].id), + ExternalRecipient(email="external@example.com"), + ], + ) + ], + ) + + form_entity = repository.create_form(params) + + with Session(engine) as verification_session: + recipients = verification_session.scalars( + select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id == form_entity.id) + ).all() + + member_recipient_payloads = [ + EmailMemberRecipientPayload.model_validate_json(recipient.recipient_payload) + for recipient in recipients + if recipient.recipient_type == RecipientType.EMAIL_MEMBER + ] + assert len(member_recipient_payloads) == 1 + assert member_recipient_payloads[0].user_id == members[0].id + + external_payloads = [ + EmailExternalRecipientPayload.model_validate_json(recipient.recipient_payload) + for recipient in recipients + if recipient.recipient_type == RecipientType.EMAIL_EXTERNAL + ] + assert len(external_payloads) == 1 + assert external_payloads[0].email == "external@example.com" + + def test_create_form_persists_default_values(self, db_session_with_containers: Session) -> None: + engine = db_session_with_containers.get_bind() + assert isinstance(engine, Engine) + tenant, _ = _create_tenant_with_members( + db_session_with_containers, + member_emails=["prefill@example.com"], + ) + + repository = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id) + resolved_values = {"greeting": "Hello!"} + params = FormCreateParams( + app_id=str(uuid4()), + workflow_execution_id=str(uuid4()), + node_id="human-input-node", + form_config=HumanInputNodeData( + title="Human Approval", + form_content="

Approve?

", + inputs=[], + user_actions=[UserAction(id="approve", title="Approve")], + ), + rendered_content="

Approve?

", + delivery_methods=[], + display_in_ui=False, + resolved_default_values=resolved_values, + ) + + form_entity = repository.create_form(params) + + with Session(engine) as verification_session: + form_model = verification_session.scalars( + select(HumanInputForm).where(HumanInputForm.id == form_entity.id) + ).first() + + assert form_model is not None + definition = FormDefinition.model_validate_json(form_model.form_definition) + assert definition.default_values == resolved_values + + def test_create_form_persists_display_in_ui(self, db_session_with_containers: Session) -> None: + engine = db_session_with_containers.get_bind() + assert isinstance(engine, Engine) + tenant, _ = _create_tenant_with_members( + db_session_with_containers, + member_emails=["ui@example.com"], + ) + + repository = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id) + params = FormCreateParams( + app_id=str(uuid4()), + workflow_execution_id=str(uuid4()), + node_id="human-input-node", + form_config=HumanInputNodeData( + title="Human Approval", + form_content="

Approve?

", + inputs=[], + user_actions=[UserAction(id="approve", title="Approve")], + delivery_methods=[WebAppDeliveryMethod()], + ), + rendered_content="

Approve?

", + delivery_methods=[WebAppDeliveryMethod()], + display_in_ui=True, + resolved_default_values={}, + ) + + form_entity = repository.create_form(params) + + with Session(engine) as verification_session: + form_model = verification_session.scalars( + select(HumanInputForm).where(HumanInputForm.id == form_entity.id) + ).first() + + assert form_model is not None + definition = FormDefinition.model_validate_json(form_model.form_definition) + assert definition.display_in_ui is True diff --git a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py new file mode 100644 index 0000000000..06d55177eb --- /dev/null +++ b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py @@ -0,0 +1,336 @@ +import time +import uuid +from datetime import timedelta +from unittest.mock import MagicMock + +import pytest +from sqlalchemy import delete, select +from sqlalchemy.orm import Session + +from core.app.app_config.entities import WorkflowUIBasedAppConfig +from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity +from core.app.workflow.layers import PersistenceWorkflowInfo, WorkflowPersistenceLayer +from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository +from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository +from core.workflow.entities import GraphInitParams +from core.workflow.enums import WorkflowType +from core.workflow.graph import Graph +from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel +from core.workflow.graph_engine.graph_engine import GraphEngine +from core.workflow.nodes.end.end_node import EndNode +from core.workflow.nodes.end.entities import EndNodeData +from core.workflow.nodes.human_input.entities 
import HumanInputNodeData, UserAction +from core.workflow.nodes.human_input.enums import HumanInputFormStatus +from core.workflow.nodes.human_input.human_input_node import HumanInputNode +from core.workflow.nodes.start.entities import StartNodeData +from core.workflow.nodes.start.start_node import StartNode +from core.workflow.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository +from core.workflow.runtime import GraphRuntimeState, VariablePool +from core.workflow.system_variable import SystemVariable +from libs.datetime_utils import naive_utc_now +from models import Account +from models.account import Tenant, TenantAccountJoin, TenantAccountRole +from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom +from models.model import App, AppMode, IconType +from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, WorkflowRun + + +def _mock_form_repository_without_submission() -> HumanInputFormRepository: + repo = MagicMock(spec=HumanInputFormRepository) + form_entity = MagicMock(spec=HumanInputFormEntity) + form_entity.id = "test-form-id" + form_entity.web_app_token = "test-form-token" + form_entity.recipients = [] + form_entity.rendered_content = "rendered" + form_entity.submitted = False + repo.create_form.return_value = form_entity + repo.get_form.return_value = None + return repo + + +def _mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepository: + repo = MagicMock(spec=HumanInputFormRepository) + form_entity = MagicMock(spec=HumanInputFormEntity) + form_entity.id = "test-form-id" + form_entity.web_app_token = "test-form-token" + form_entity.recipients = [] + form_entity.rendered_content = "rendered" + form_entity.submitted = True + form_entity.selected_action_id = action_id + form_entity.submitted_data = {} + form_entity.status = HumanInputFormStatus.WAITING + form_entity.expiration_time = naive_utc_now() + timedelta(hours=1) + 
repo.get_form.return_value = form_entity + return repo + + +def _build_runtime_state(workflow_execution_id: str, app_id: str, workflow_id: str, user_id: str) -> GraphRuntimeState: + variable_pool = VariablePool( + system_variables=SystemVariable( + workflow_execution_id=workflow_execution_id, + app_id=app_id, + workflow_id=workflow_id, + user_id=user_id, + ), + user_inputs={}, + conversation_variables=[], + ) + return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + + +def _build_graph( + runtime_state: GraphRuntimeState, + tenant_id: str, + app_id: str, + workflow_id: str, + user_id: str, + form_repository: HumanInputFormRepository, +) -> Graph: + graph_config: dict[str, object] = {"nodes": [], "edges": []} + params = GraphInitParams( + tenant_id=tenant_id, + app_id=app_id, + workflow_id=workflow_id, + graph_config=graph_config, + user_id=user_id, + user_from="account", + invoke_from="debugger", + call_depth=0, + ) + + start_data = StartNodeData(title="start", variables=[]) + start_node = StartNode( + id="start", + config={"id": "start", "data": start_data.model_dump()}, + graph_init_params=params, + graph_runtime_state=runtime_state, + ) + + human_data = HumanInputNodeData( + title="human", + form_content="Awaiting human input", + inputs=[], + user_actions=[ + UserAction(id="continue", title="Continue"), + ], + ) + human_node = HumanInputNode( + id="human", + config={"id": "human", "data": human_data.model_dump()}, + graph_init_params=params, + graph_runtime_state=runtime_state, + form_repository=form_repository, + ) + + end_data = EndNodeData( + title="end", + outputs=[], + desc=None, + ) + end_node = EndNode( + id="end", + config={"id": "end", "data": end_data.model_dump()}, + graph_init_params=params, + graph_runtime_state=runtime_state, + ) + + return ( + Graph.new() + .add_root(start_node) + .add_node(human_node) + .add_node(end_node, from_node_id="human", source_handle="continue") + .build() + ) + + +def _build_generate_entity( 
+ tenant_id: str, + app_id: str, + workflow_id: str, + workflow_execution_id: str, + user_id: str, +) -> WorkflowAppGenerateEntity: + app_config = WorkflowUIBasedAppConfig( + tenant_id=tenant_id, + app_id=app_id, + app_mode=AppMode.WORKFLOW, + workflow_id=workflow_id, + ) + return WorkflowAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + inputs={}, + files=[], + user_id=user_id, + stream=False, + invoke_from=InvokeFrom.DEBUGGER, + workflow_execution_id=workflow_execution_id, + ) + + +class TestHumanInputResumeNodeExecutionIntegration: + @pytest.fixture(autouse=True) + def setup_test_data(self, db_session_with_containers: Session): + tenant = Tenant( + name="Test Tenant", + status="normal", + ) + db_session_with_containers.add(tenant) + db_session_with_containers.commit() + + account = Account( + email="test@example.com", + name="Test User", + interface_language="en-US", + status="active", + ) + db_session_with_containers.add(account) + db_session_with_containers.commit() + + tenant_join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER, + current=True, + ) + db_session_with_containers.add(tenant_join) + db_session_with_containers.commit() + + account.current_tenant = tenant + + app = App( + tenant_id=tenant.id, + name="Test App", + description="", + mode=AppMode.WORKFLOW.value, + icon_type=IconType.EMOJI.value, + icon="rocket", + icon_background="#4ECDC4", + enable_site=False, + enable_api=False, + api_rpm=0, + api_rph=0, + is_demo=False, + is_public=False, + is_universal=False, + max_active_requests=None, + created_by=account.id, + updated_by=account.id, + ) + db_session_with_containers.add(app) + db_session_with_containers.commit() + + workflow = Workflow( + tenant_id=tenant.id, + app_id=app.id, + type="workflow", + version="draft", + graph='{"nodes": [], "edges": []}', + features='{"file_upload": {"enabled": false}}', + created_by=account.id, + created_at=naive_utc_now(), + ) + 
db_session_with_containers.add(workflow) + db_session_with_containers.commit() + + self.session = db_session_with_containers + self.tenant = tenant + self.account = account + self.app = app + self.workflow = workflow + + yield + + self.session.execute(delete(WorkflowNodeExecutionModel)) + self.session.execute(delete(WorkflowRun)) + self.session.execute(delete(Workflow).where(Workflow.id == self.workflow.id)) + self.session.execute(delete(App).where(App.id == self.app.id)) + self.session.execute(delete(TenantAccountJoin).where(TenantAccountJoin.tenant_id == self.tenant.id)) + self.session.execute(delete(Account).where(Account.id == self.account.id)) + self.session.execute(delete(Tenant).where(Tenant.id == self.tenant.id)) + self.session.commit() + + def _build_persistence_layer(self, execution_id: str) -> WorkflowPersistenceLayer: + generate_entity = _build_generate_entity( + tenant_id=self.tenant.id, + app_id=self.app.id, + workflow_id=self.workflow.id, + workflow_execution_id=execution_id, + user_id=self.account.id, + ) + execution_repo = SQLAlchemyWorkflowExecutionRepository( + session_factory=self.session.get_bind(), + user=self.account, + app_id=self.app.id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + ) + node_execution_repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=self.session.get_bind(), + user=self.account, + app_id=self.app.id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + return WorkflowPersistenceLayer( + application_generate_entity=generate_entity, + workflow_info=PersistenceWorkflowInfo( + workflow_id=self.workflow.id, + workflow_type=WorkflowType.WORKFLOW, + version=self.workflow.version, + graph_data=self.workflow.graph_dict, + ), + workflow_execution_repository=execution_repo, + workflow_node_execution_repository=node_execution_repo, + ) + + def _run_graph(self, graph: Graph, runtime_state: GraphRuntimeState, execution_id: str) -> None: + engine = GraphEngine( + workflow_id=self.workflow.id, 
+ graph=graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + engine.layer(self._build_persistence_layer(execution_id)) + for _ in engine.run(): + continue + + def test_resume_human_input_does_not_create_duplicate_node_execution(self): + execution_id = str(uuid.uuid4()) + runtime_state = _build_runtime_state( + workflow_execution_id=execution_id, + app_id=self.app.id, + workflow_id=self.workflow.id, + user_id=self.account.id, + ) + pause_repo = _mock_form_repository_without_submission() + paused_graph = _build_graph( + runtime_state, + self.tenant.id, + self.app.id, + self.workflow.id, + self.account.id, + pause_repo, + ) + self._run_graph(paused_graph, runtime_state, execution_id) + + snapshot = runtime_state.dumps() + resumed_state = GraphRuntimeState.from_snapshot(snapshot) + resume_repo = _mock_form_repository_with_submission(action_id="continue") + resumed_graph = _build_graph( + resumed_state, + self.tenant.id, + self.app.id, + self.workflow.id, + self.account.id, + resume_repo, + ) + self._run_graph(resumed_graph, resumed_state, execution_id) + + stmt = select(WorkflowNodeExecutionModel).where( + WorkflowNodeExecutionModel.workflow_run_id == execution_id, + WorkflowNodeExecutionModel.node_id == "human", + ) + records = self.session.execute(stmt).scalars().all() + assert len(records) == 1 + assert records[0].status != "paused" + assert records[0].triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN + assert records[0].created_by_role == CreatorUserRole.ACCOUNT diff --git a/api/tests/test_containers_integration_tests/helpers/__init__.py b/api/tests/test_containers_integration_tests/helpers/__init__.py new file mode 100644 index 0000000000..40d03889a9 --- /dev/null +++ b/api/tests/test_containers_integration_tests/helpers/__init__.py @@ -0,0 +1 @@ +"""Helper utilities for integration tests.""" diff --git a/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py 
b/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py new file mode 100644 index 0000000000..19d7772c39 --- /dev/null +++ b/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py @@ -0,0 +1,154 @@ +from __future__ import annotations + +from dataclasses import dataclass +from datetime import datetime, timedelta +from decimal import Decimal +from uuid import uuid4 + +from core.workflow.nodes.human_input.entities import FormDefinition, UserAction +from models.account import Account, Tenant, TenantAccountJoin +from models.execution_extra_content import HumanInputContent +from models.human_input import HumanInputForm, HumanInputFormStatus +from models.model import App, Conversation, Message + + +@dataclass +class HumanInputMessageFixture: + app: App + account: Account + conversation: Conversation + message: Message + form: HumanInputForm + action_id: str + action_text: str + node_title: str + + +def create_human_input_message_fixture(db_session) -> HumanInputMessageFixture: + tenant = Tenant(name=f"Tenant {uuid4()}") + db_session.add(tenant) + db_session.flush() + + account = Account( + name=f"Account {uuid4()}", + email=f"human_input_{uuid4()}@example.com", + password="hashed-password", + password_salt="salt", + interface_language="en-US", + timezone="UTC", + ) + db_session.add(account) + db_session.flush() + + tenant_join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role="owner", + current=True, + ) + db_session.add(tenant_join) + db_session.flush() + + app = App( + tenant_id=tenant.id, + name=f"App {uuid4()}", + description="", + mode="chat", + icon_type="emoji", + icon="🤖", + icon_background="#FFFFFF", + enable_site=False, + enable_api=True, + api_rpm=100, + api_rph=100, + is_demo=False, + is_public=False, + is_universal=False, + created_by=account.id, + updated_by=account.id, + ) + db_session.add(app) + db_session.flush() + + conversation = Conversation( + app_id=app.id, + mode="chat", + 
name="Test Conversation", + summary="", + introduction="", + system_instruction="", + status="normal", + invoke_from="console", + from_source="console", + from_account_id=account.id, + from_end_user_id=None, + ) + conversation.inputs = {} + db_session.add(conversation) + db_session.flush() + + workflow_run_id = str(uuid4()) + message = Message( + app_id=app.id, + conversation_id=conversation.id, + inputs={}, + query="Human input query", + message={"messages": []}, + answer="Human input answer", + message_tokens=50, + message_unit_price=Decimal("0.001"), + answer_tokens=80, + answer_unit_price=Decimal("0.001"), + provider_response_latency=0.5, + currency="USD", + from_source="console", + from_account_id=account.id, + workflow_run_id=workflow_run_id, + ) + db_session.add(message) + db_session.flush() + + action_id = "approve" + action_text = "Approve request" + node_title = "Approval" + form_definition = FormDefinition( + form_content="content", + inputs=[], + user_actions=[UserAction(id=action_id, title=action_text)], + rendered_content="Rendered block", + expiration_time=datetime.utcnow() + timedelta(days=1), + node_title=node_title, + display_in_ui=True, + ) + form = HumanInputForm( + tenant_id=tenant.id, + app_id=app.id, + workflow_run_id=workflow_run_id, + node_id="node-id", + form_definition=form_definition.model_dump_json(), + rendered_content="Rendered block", + status=HumanInputFormStatus.SUBMITTED, + expiration_time=datetime.utcnow() + timedelta(days=1), + selected_action_id=action_id, + ) + db_session.add(form) + db_session.flush() + + content = HumanInputContent( + workflow_run_id=workflow_run_id, + message_id=message.id, + form_id=form.id, + ) + db_session.add(content) + db_session.commit() + + return HumanInputMessageFixture( + app=app, + account=account, + conversation=conversation, + message=message, + form=form, + action_id=action_id, + action_text=action_text, + node_title=node_title, + ) diff --git 
a/api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_sharded_channel.py b/api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_sharded_channel.py index d612e70910..43915a204d 100644 --- a/api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_sharded_channel.py +++ b/api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_sharded_channel.py @@ -16,6 +16,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed import pytest import redis +from redis.cluster import RedisCluster from testcontainers.redis import RedisContainer from libs.broadcast_channel.channel import BroadcastChannel, Subscription, Topic @@ -332,3 +333,95 @@ class TestShardedRedisBroadcastChannelIntegration: # Verify subscriptions are cleaned up topic_subscribers_after = self._get_sharded_numsub(redis_client, topic_name) assert topic_subscribers_after == 0 + + +class TestShardedRedisBroadcastChannelClusterIntegration: + """Integration tests for sharded pub/sub with RedisCluster client.""" + + @pytest.fixture(scope="class") + def redis_cluster_container(self) -> Iterator[RedisContainer]: + """Create a Redis 7 container with cluster mode enabled.""" + command = ( + "redis-server --port 6379 " + "--cluster-enabled yes " + "--cluster-config-file nodes.conf " + "--cluster-node-timeout 5000 " + "--appendonly no " + "--protected-mode no" + ) + with RedisContainer(image="redis:7-alpine").with_command(command) as container: + yield container + + @classmethod + def _get_test_topic_name(cls) -> str: + return f"test_sharded_cluster_topic_{uuid.uuid4()}" + + @staticmethod + def _ensure_single_node_cluster(host: str, port: int) -> None: + client = redis.Redis(host=host, port=port, decode_responses=False) + client.config_set("cluster-announce-ip", host) + client.config_set("cluster-announce-port", port) + slots = client.execute_command("CLUSTER", "SLOTS") + if not slots: + client.execute_command("CLUSTER", 
"ADDSLOTSRANGE", 0, 16383) + + deadline = time.time() + 5.0 + while time.time() < deadline: + info = client.execute_command("CLUSTER", "INFO") + info_text = info.decode("utf-8") if isinstance(info, (bytes, bytearray)) else str(info) + if "cluster_state:ok" in info_text: + return + time.sleep(0.05) + raise RuntimeError("Redis cluster did not become ready in time") + + @pytest.fixture(scope="class") + def redis_cluster_client(self, redis_cluster_container: RedisContainer) -> RedisCluster: + host = redis_cluster_container.get_container_host_ip() + port = int(redis_cluster_container.get_exposed_port(6379)) + self._ensure_single_node_cluster(host, port) + return RedisCluster(host=host, port=port, decode_responses=False) + + @pytest.fixture + def broadcast_channel(self, redis_cluster_client: RedisCluster) -> BroadcastChannel: + return ShardedRedisBroadcastChannel(redis_cluster_client) + + def test_cluster_sharded_pubsub_delivers_message(self, broadcast_channel: BroadcastChannel): + """Ensure sharded subscription receives messages when using RedisCluster client.""" + topic_name = self._get_test_topic_name() + message = b"cluster sharded message" + + topic = broadcast_channel.topic(topic_name) + producer = topic.as_producer() + subscription = topic.subscribe() + ready_event = threading.Event() + + def consumer_thread() -> list[bytes]: + received = [] + try: + _ = subscription.receive(0.01) + except SubscriptionClosedError: + return received + ready_event.set() + deadline = time.time() + 5.0 + while time.time() < deadline: + msg = subscription.receive(timeout=0.1) + if msg is None: + continue + received.append(msg) + break + subscription.close() + return received + + def producer_thread(): + if not ready_event.wait(timeout=2.0): + pytest.fail("subscriber did not become ready before publish") + producer.publish(message) + + with ThreadPoolExecutor(max_workers=2) as executor: + consumer_future = executor.submit(consumer_thread) + producer_future = 
executor.submit(producer_thread) + + producer_future.result(timeout=5.0) + received_messages = consumer_future.result(timeout=5.0) + + assert received_messages == [message] diff --git a/api/tests/test_containers_integration_tests/libs/test_rate_limiter_integration.py b/api/tests/test_containers_integration_tests/libs/test_rate_limiter_integration.py new file mode 100644 index 0000000000..178fc2e4fb --- /dev/null +++ b/api/tests/test_containers_integration_tests/libs/test_rate_limiter_integration.py @@ -0,0 +1,25 @@ +""" +Integration tests for RateLimiter using testcontainers Redis. +""" + +import uuid + +import pytest + +from extensions.ext_redis import redis_client +from libs import helper as helper_module + + +@pytest.mark.usefixtures("flask_app_with_containers") +def test_rate_limiter_counts_multiple_attempts_in_same_second(monkeypatch): + prefix = f"test_rate_limit:{uuid.uuid4().hex}" + limiter = helper_module.RateLimiter(prefix=prefix, max_attempts=2, time_window=60) + key = limiter._get_key("203.0.113.10") + + redis_client.delete(key) + monkeypatch.setattr(helper_module.time, "time", lambda: 1_700_000_000) + + limiter.increment_rate_limit("203.0.113.10") + limiter.increment_rate_limit("203.0.113.10") + + assert limiter.is_rate_limited("203.0.113.10") is True diff --git a/api/tests/test_containers_integration_tests/models/test_account.py b/api/tests/test_containers_integration_tests/models/test_account.py new file mode 100644 index 0000000000..078dc0e8de --- /dev/null +++ b/api/tests/test_containers_integration_tests/models/test_account.py @@ -0,0 +1,79 @@ +# import secrets + +# import pytest +# from sqlalchemy import select +# from sqlalchemy.orm import Session +# from sqlalchemy.orm.exc import DetachedInstanceError + +# from libs.datetime_utils import naive_utc_now +# from models.account import Account, Tenant, TenantAccountJoin + + +# @pytest.fixture +# def session(db_session_with_containers): +# with Session(db_session_with_containers.get_bind()) as 
session: +# yield session + + +# @pytest.fixture +# def account(session): +# account = Account( +# name="test account", +# email=f"test_{secrets.token_hex(8)}@example.com", +# ) +# session.add(account) +# session.commit() +# return account + + +# @pytest.fixture +# def tenant(session): +# tenant = Tenant(name="test tenant") +# session.add(tenant) +# session.commit() +# return tenant + + +# @pytest.fixture +# def tenant_account_join(session, account, tenant): +# tenant_join = TenantAccountJoin(account_id=account.id, tenant_id=tenant.id) +# session.add(tenant_join) +# session.commit() +# yield tenant_join +# session.delete(tenant_join) +# session.commit() + + +# class TestAccountTenant: +# def test_set_current_tenant_should_reload_tenant( +# self, +# db_session_with_containers, +# account, +# tenant, +# tenant_account_join, +# ): +# with Session(db_session_with_containers.get_bind(), expire_on_commit=True) as session: +# scoped_tenant = session.scalars(select(Tenant).where(Tenant.id == tenant.id)).one() +# account.current_tenant = scoped_tenant +# scoped_tenant.created_at = naive_utc_now() +# # session.commit() + +# # Ensure the tenant used in assignment is detached. 
+# with pytest.raises(DetachedInstanceError): +# _ = scoped_tenant.name + +# assert account._current_tenant.id == tenant.id +# assert account._current_tenant.id == tenant.id + +# def test_set_tenant_id_should_load_tenant_as_not_expire( +# self, +# flask_app_with_containers, +# account, +# tenant, +# tenant_account_join, +# ): +# with flask_app_with_containers.test_request_context(): +# account.set_tenant_id(tenant.id) + +# assert account._current_tenant.id == tenant.id +# assert account._current_tenant.id == tenant.id diff --git a/api/tests/test_containers_integration_tests/repositories/test_execution_extra_content_repository.py b/api/tests/test_containers_integration_tests/repositories/test_execution_extra_content_repository.py new file mode 100644 index 0000000000..c9058626d1 --- /dev/null +++ b/api/tests/test_containers_integration_tests/repositories/test_execution_extra_content_repository.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +from sqlalchemy.orm import sessionmaker + +from extensions.ext_database import db +from repositories.sqlalchemy_execution_extra_content_repository import SQLAlchemyExecutionExtraContentRepository +from tests.test_containers_integration_tests.helpers.execution_extra_content import ( + create_human_input_message_fixture, +) + + +def test_get_by_message_ids_returns_human_input_content(db_session_with_containers): + fixture = create_human_input_message_fixture(db_session_with_containers) + repository = SQLAlchemyExecutionExtraContentRepository( + session_maker=sessionmaker(bind=db.engine, expire_on_commit=False) + ) + + results = repository.get_by_message_ids([fixture.message.id]) + + assert len(results) == 1 + assert len(results[0]) == 1 + content = results[0][0] + assert content.submitted is True + assert content.form_submission_data is not None + assert content.form_submission_data.action_id == fixture.action_id + assert content.form_submission_data.action_text == fixture.action_text + assert 
content.form_submission_data.rendered_content == fixture.form.rendered_content diff --git a/api/tests/test_containers_integration_tests/services/test_account_service.py b/api/tests/test_containers_integration_tests/services/test_account_service.py index 4d4e77a802..4b6b5048a1 100644 --- a/api/tests/test_containers_integration_tests/services/test_account_service.py +++ b/api/tests/test_containers_integration_tests/services/test_account_service.py @@ -2293,6 +2293,12 @@ class TestRegisterService: mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + from extensions.ext_database import db + from models.model import DifySetup + + db.session.query(DifySetup).delete() + db.session.commit() + # Execute setup RegisterService.setup( email=admin_email, @@ -2303,9 +2309,7 @@ class TestRegisterService: ) # Verify account was created - from extensions.ext_database import db from models import Account - from models.model import DifySetup account = db.session.query(Account).filter_by(email=admin_email).first() assert account is not None diff --git a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py index 476f58585d..81bfa0ea20 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py @@ -1,5 +1,5 @@ import uuid -from unittest.mock import MagicMock, patch +from unittest.mock import ANY, MagicMock, patch import pytest from faker import Faker @@ -26,6 +26,7 @@ class TestAppGenerateService: patch("services.app_generate_service.AgentChatAppGenerator") as mock_agent_chat_generator, patch("services.app_generate_service.AdvancedChatAppGenerator") as mock_advanced_chat_generator, 
patch("services.app_generate_service.WorkflowAppGenerator") as mock_workflow_generator, + patch("services.app_generate_service.MessageBasedAppGenerator") as mock_message_based_generator, patch("services.account_service.FeatureService") as mock_account_feature_service, patch("services.app_generate_service.dify_config") as mock_dify_config, patch("configs.dify_config") as mock_global_dify_config, @@ -38,9 +39,13 @@ class TestAppGenerateService: # Setup default mock returns for workflow service mock_workflow_service_instance = mock_workflow_service.return_value - mock_workflow_service_instance.get_published_workflow.return_value = MagicMock(spec=Workflow) - mock_workflow_service_instance.get_draft_workflow.return_value = MagicMock(spec=Workflow) - mock_workflow_service_instance.get_published_workflow_by_id.return_value = MagicMock(spec=Workflow) + mock_published_workflow = MagicMock(spec=Workflow) + mock_published_workflow.id = str(uuid.uuid4()) + mock_workflow_service_instance.get_published_workflow.return_value = mock_published_workflow + mock_draft_workflow = MagicMock(spec=Workflow) + mock_draft_workflow.id = str(uuid.uuid4()) + mock_workflow_service_instance.get_draft_workflow.return_value = mock_draft_workflow + mock_workflow_service_instance.get_published_workflow_by_id.return_value = mock_published_workflow # Setup default mock returns for rate limiting mock_rate_limit_instance = mock_rate_limit.return_value @@ -66,6 +71,8 @@ class TestAppGenerateService: mock_advanced_chat_generator_instance.generate.return_value = ["advanced_chat_response"] mock_advanced_chat_generator_instance.single_iteration_generate.return_value = ["single_iteration_response"] mock_advanced_chat_generator_instance.single_loop_generate.return_value = ["single_loop_response"] + mock_advanced_chat_generator_instance.retrieve_events.return_value = ["advanced_chat_events"] + mock_advanced_chat_generator_instance.convert_to_event_stream.return_value = ["advanced_chat_stream"] 
mock_advanced_chat_generator.convert_to_event_stream.return_value = ["advanced_chat_stream"] mock_workflow_generator_instance = mock_workflow_generator.return_value @@ -76,6 +83,8 @@ class TestAppGenerateService: mock_workflow_generator_instance.single_loop_generate.return_value = ["workflow_single_loop_response"] mock_workflow_generator.convert_to_event_stream.return_value = ["workflow_stream"] + mock_message_based_generator.retrieve_events.return_value = ["workflow_events"] + # Setup default mock returns for account service mock_account_feature_service.get_system_features.return_value.is_allow_register = True @@ -88,6 +97,7 @@ class TestAppGenerateService: mock_global_dify_config.BILLING_ENABLED = False mock_global_dify_config.APP_MAX_ACTIVE_REQUESTS = 100 mock_global_dify_config.APP_DAILY_RATE_LIMIT = 1000 + mock_global_dify_config.HOSTED_POOL_CREDITS = 1000 yield { "billing_service": mock_billing_service, @@ -98,6 +108,7 @@ class TestAppGenerateService: "agent_chat_generator": mock_agent_chat_generator, "advanced_chat_generator": mock_advanced_chat_generator, "workflow_generator": mock_workflow_generator, + "message_based_generator": mock_message_based_generator, "account_feature_service": mock_account_feature_service, "dify_config": mock_dify_config, "global_dify_config": mock_global_dify_config, @@ -280,8 +291,10 @@ class TestAppGenerateService: assert result == ["test_response"] # Verify advanced chat generator was called - mock_external_service_dependencies["advanced_chat_generator"].return_value.generate.assert_called_once() - mock_external_service_dependencies["advanced_chat_generator"].convert_to_event_stream.assert_called_once() + mock_external_service_dependencies["advanced_chat_generator"].return_value.retrieve_events.assert_called_once() + mock_external_service_dependencies[ + "advanced_chat_generator" + ].return_value.convert_to_event_stream.assert_called_once() def test_generate_workflow_mode_success(self, db_session_with_containers, 
mock_external_service_dependencies): """ @@ -304,7 +317,7 @@ class TestAppGenerateService: assert result == ["test_response"] # Verify workflow generator was called - mock_external_service_dependencies["workflow_generator"].return_value.generate.assert_called_once() + mock_external_service_dependencies["message_based_generator"].retrieve_events.assert_called_once() mock_external_service_dependencies["workflow_generator"].convert_to_event_stream.assert_called_once() def test_generate_with_specific_workflow_id(self, db_session_with_containers, mock_external_service_dependencies): @@ -970,14 +983,27 @@ class TestAppGenerateService: } # Execute the method under test - result = AppGenerateService.generate( - app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True - ) + with patch("services.app_generate_service.AppExecutionParams") as mock_exec_params: + mock_payload = MagicMock() + mock_payload.workflow_run_id = fake.uuid4() + mock_payload.model_dump_json.return_value = "{}" + mock_exec_params.new.return_value = mock_payload + + result = AppGenerateService.generate( + app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True + ) # Verify the result assert result == ["test_response"] - # Verify workflow generator was called with complex args - mock_external_service_dependencies["workflow_generator"].return_value.generate.assert_called_once() - call_args = mock_external_service_dependencies["workflow_generator"].return_value.generate.call_args - assert call_args[1]["args"] == args + # Verify payload was built with complex args + mock_exec_params.new.assert_called_once() + call_kwargs = mock_exec_params.new.call_args.kwargs + assert call_kwargs["args"] == args + + # Verify workflow streaming event retrieval was used + mock_external_service_dependencies["message_based_generator"].retrieve_events.assert_called_once_with( + ANY, + mock_payload.workflow_run_id, + on_subscribe=ANY, + ) diff --git 
a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py new file mode 100644 index 0000000000..9c978f830f --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py @@ -0,0 +1,112 @@ +import json +import uuid +from unittest.mock import MagicMock + +import pytest + +from core.workflow.enums import NodeType +from core.workflow.nodes.human_input.entities import ( + EmailDeliveryConfig, + EmailDeliveryMethod, + EmailRecipients, + ExternalRecipient, + HumanInputNodeData, +) +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.model import App, AppMode +from models.workflow import Workflow, WorkflowType +from services.workflow_service import WorkflowService + + +def _create_app_with_draft_workflow(session, *, delivery_method_id: uuid.UUID) -> tuple[App, Account]: + tenant = Tenant(name="Test Tenant") + account = Account(name="Tester", email="tester@example.com") + session.add_all([tenant, account]) + session.flush() + + session.add( + TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + current=True, + role=TenantAccountRole.OWNER.value, + ) + ) + + app = App( + tenant_id=tenant.id, + name="Test App", + description="", + mode=AppMode.WORKFLOW.value, + icon_type="emoji", + icon="app", + icon_background="#ffffff", + enable_site=True, + enable_api=True, + created_by=account.id, + updated_by=account.id, + ) + session.add(app) + session.flush() + + email_method = EmailDeliveryMethod( + id=delivery_method_id, + enabled=True, + config=EmailDeliveryConfig( + recipients=EmailRecipients( + whole_workspace=False, + items=[ExternalRecipient(email="recipient@example.com")], + ), + subject="Test {{recipient_email}}", + body="Body {{#url#}} {{form_content}}", + ), + ) + node_data = HumanInputNodeData( + title="Human Input", + delivery_methods=[email_method], + 
form_content="Hello Human Input", + inputs=[], + user_actions=[], + ).model_dump(mode="json") + node_data["type"] = NodeType.HUMAN_INPUT.value + graph = json.dumps({"nodes": [{"id": "human-node", "data": node_data}], "edges": []}) + + workflow = Workflow.new( + tenant_id=tenant.id, + app_id=app.id, + type=WorkflowType.WORKFLOW.value, + version=Workflow.VERSION_DRAFT, + graph=graph, + features=json.dumps({}), + created_by=account.id, + environment_variables=[], + conversation_variables=[], + rag_pipeline_variables=[], + ) + session.add(workflow) + session.commit() + + return app, account + + +def test_human_input_delivery_test_sends_email( + db_session_with_containers, + monkeypatch: pytest.MonkeyPatch, +) -> None: + delivery_method_id = uuid.uuid4() + app, account = _create_app_with_draft_workflow(db_session_with_containers, delivery_method_id=delivery_method_id) + + send_mock = MagicMock() + monkeypatch.setattr("services.human_input_delivery_test_service.mail.is_inited", lambda: True) + monkeypatch.setattr("services.human_input_delivery_test_service.mail.send", send_mock) + + service = WorkflowService() + service.test_human_input_delivery( + app_model=app, + account=account, + node_id="human-node", + delivery_method_id=str(delivery_method_id), + ) + + assert send_mock.call_count == 1 + assert send_mock.call_args.kwargs["to"] == "recipient@example.com" diff --git a/api/tests/test_containers_integration_tests/services/test_message_service_execution_extra_content.py b/api/tests/test_containers_integration_tests/services/test_message_service_execution_extra_content.py new file mode 100644 index 0000000000..44e5a82868 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_message_service_execution_extra_content.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +import pytest + +from services.message_service import MessageService +from tests.test_containers_integration_tests.helpers.execution_extra_content import ( + 
create_human_input_message_fixture, +) + + +@pytest.mark.usefixtures("flask_req_ctx_with_containers") +def test_pagination_returns_extra_contents(db_session_with_containers): + fixture = create_human_input_message_fixture(db_session_with_containers) + + pagination = MessageService.pagination_by_first_id( + app_model=fixture.app, + user=fixture.account, + conversation_id=fixture.conversation.id, + first_id=None, + limit=10, + ) + + assert pagination.data + message = pagination.data[0] + assert message.extra_contents == [ + { + "type": "human_input", + "workflow_run_id": fixture.message.workflow_run_id, + "submitted": True, + "form_submission_data": { + "node_id": fixture.form.node_id, + "node_title": fixture.node_title, + "rendered_content": fixture.form.rendered_content, + "action_id": fixture.action_id, + "action_text": fixture.action_text, + }, + } + ] diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py index 23c4eeb82f..3a88081db3 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py @@ -465,6 +465,27 @@ class TestWorkflowRunService: db.session.add(node_execution) node_executions.append(node_execution) + paused_node_execution = WorkflowNodeExecutionModel( + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow_run.workflow_id, + triggered_from="workflow-run", + workflow_run_id=workflow_run.id, + index=99, + node_id="node_paused", + node_type="human_input", + title="Paused Node", + inputs=json.dumps({"input": "paused"}), + process_data=json.dumps({"process": "paused"}), + status="paused", + elapsed_time=0.5, + execution_metadata=json.dumps({"tokens": 0}), + created_by_role=CreatorUserRole.ACCOUNT, + created_by=account.id, + created_at=datetime.now(UTC), + ) + db.session.add(paused_node_execution) + 
db.session.commit() # Act: Execute the method under test @@ -473,16 +494,19 @@ class TestWorkflowRunService: # Assert: Verify the expected outcomes assert result is not None - assert len(result) == 3 + assert len(result) == 4 # Verify node execution properties + statuses = [node_execution.status for node_execution in result] + assert "paused" in statuses + assert statuses.count("succeeded") == 3 + assert statuses.count("paused") == 1 + for node_execution in result: assert node_execution.tenant_id == app.tenant_id assert node_execution.app_id == app.id assert node_execution.workflow_run_id == workflow_run.id - assert node_execution.index in [0, 1, 2] # Check that index is one of the expected values - assert node_execution.node_id.startswith("node_") # Check that node_id starts with "node_" - assert node_execution.status == "succeeded" + assert node_execution.node_id.startswith("node_") def test_get_workflow_run_node_executions_empty( self, db_session_with_containers, mock_external_service_dependencies diff --git a/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py index 3d46735a1a..acd9d78c91 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py @@ -4,6 +4,7 @@ from unittest.mock import patch import pytest from faker import Faker +from core.tools.errors import WorkflowToolHumanInputNotSupportedError from models.tools import WorkflowToolProvider from models.workflow import Workflow as WorkflowModel from services.account_service import AccountService, TenantService @@ -507,6 +508,62 @@ class TestWorkflowToolManageService: assert tool_count == 0 + def test_create_workflow_tool_human_input_node_error( + self, db_session_with_containers, mock_external_service_dependencies 
+ ): + """ + Test workflow tool creation fails when workflow contains human input nodes. + + This test verifies: + - Human input nodes prevent workflow tool publishing + - Correct error message + - No database changes when workflow is invalid + """ + fake = Faker() + + # Create test data + app, account, workflow = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + + workflow.graph = json.dumps( + { + "nodes": [ + { + "id": "human_input_node", + "data": {"type": "human-input"}, + } + ] + } + ) + + tool_parameters = self._create_test_workflow_tool_parameters() + with pytest.raises(WorkflowToolHumanInputNotSupportedError) as exc_info: + WorkflowToolManageService.create_workflow_tool( + user_id=account.id, + tenant_id=account.current_tenant.id, + workflow_app_id=app.id, + name=fake.word(), + label=fake.word(), + icon={"type": "emoji", "emoji": "🔧"}, + description=fake.text(max_nb_chars=200), + parameters=tool_parameters, + ) + + assert exc_info.value.error_code == "workflow_tool_human_input_not_supported" + + from extensions.ext_database import db + + tool_count = ( + db.session.query(WorkflowToolProvider) + .where( + WorkflowToolProvider.tenant_id == account.current_tenant.id, + ) + .count() + ) + + assert tool_count == 0 + def test_update_workflow_tool_success(self, db_session_with_containers, mock_external_service_dependencies): """ Test successful workflow tool update with valid parameters. @@ -593,6 +650,80 @@ class TestWorkflowToolManageService: mock_external_service_dependencies["tool_label_manager"].update_tool_labels.assert_called() mock_external_service_dependencies["tool_transform_service"].workflow_provider_to_controller.assert_called() + def test_update_workflow_tool_human_input_node_error( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test workflow tool update fails when workflow contains human input nodes. 
+ + This test verifies: + - Human input nodes prevent workflow tool updates + - Correct error message + - Existing tool data remains unchanged + """ + fake = Faker() + + # Create test data + app, account, workflow = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create initial workflow tool + initial_tool_name = fake.word() + initial_tool_parameters = self._create_test_workflow_tool_parameters() + WorkflowToolManageService.create_workflow_tool( + user_id=account.id, + tenant_id=account.current_tenant.id, + workflow_app_id=app.id, + name=initial_tool_name, + label=fake.word(), + icon={"type": "emoji", "emoji": "🔧"}, + description=fake.text(max_nb_chars=200), + parameters=initial_tool_parameters, + ) + + from extensions.ext_database import db + + created_tool = ( + db.session.query(WorkflowToolProvider) + .where( + WorkflowToolProvider.tenant_id == account.current_tenant.id, + WorkflowToolProvider.app_id == app.id, + ) + .first() + ) + + original_name = created_tool.name + + workflow.graph = json.dumps( + { + "nodes": [ + { + "id": "human_input_node", + "data": {"type": "human-input"}, + } + ] + } + ) + db.session.commit() + + with pytest.raises(WorkflowToolHumanInputNotSupportedError) as exc_info: + WorkflowToolManageService.update_workflow_tool( + user_id=account.id, + tenant_id=account.current_tenant.id, + workflow_tool_id=created_tool.id, + name=fake.word(), + label=fake.word(), + icon={"type": "emoji", "emoji": "⚙️"}, + description=fake.text(max_nb_chars=200), + parameters=initial_tool_parameters, + ) + + assert exc_info.value.error_code == "workflow_tool_human_input_not_supported" + + db.session.refresh(created_tool) + assert created_tool.name == original_name + def test_update_workflow_tool_not_found_error(self, db_session_with_containers, mock_external_service_dependencies): """ Test workflow tool update fails when tool does not exist. 
diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py new file mode 100644 index 0000000000..5fd6c56f7a --- /dev/null +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py @@ -0,0 +1,214 @@ +import uuid +from datetime import UTC, datetime +from unittest.mock import patch + +import pytest + +from configs import dify_config +from core.app.app_config.entities import WorkflowUIBasedAppConfig +from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity +from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext +from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl +from core.workflow.enums import WorkflowExecutionStatus +from core.workflow.nodes.human_input.entities import ( + EmailDeliveryConfig, + EmailDeliveryMethod, + EmailRecipients, + ExternalRecipient, + HumanInputNodeData, + MemberRecipient, +) +from core.workflow.runtime import GraphRuntimeState, VariablePool +from extensions.ext_storage import storage +from models.account import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole +from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom +from models.human_input import HumanInputDelivery, HumanInputForm, HumanInputFormRecipient +from models.model import AppMode +from models.workflow import WorkflowPause, WorkflowRun, WorkflowType +from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task + + +@pytest.fixture(autouse=True) +def cleanup_database(db_session_with_containers): + db_session_with_containers.query(HumanInputFormRecipient).delete() + db_session_with_containers.query(HumanInputDelivery).delete() + db_session_with_containers.query(HumanInputForm).delete() + db_session_with_containers.query(WorkflowPause).delete() + 
db_session_with_containers.query(WorkflowRun).delete() + db_session_with_containers.query(TenantAccountJoin).delete() + db_session_with_containers.query(Tenant).delete() + db_session_with_containers.query(Account).delete() + db_session_with_containers.commit() + + +def _create_workspace_member(db_session_with_containers): + account = Account( + email="owner@example.com", + name="Owner", + password="password", + interface_language="en-US", + status=AccountStatus.ACTIVE, + ) + account.created_at = datetime.now(UTC) + account.updated_at = datetime.now(UTC) + db_session_with_containers.add(account) + db_session_with_containers.commit() + db_session_with_containers.refresh(account) + + tenant = Tenant(name="Test Tenant") + tenant.created_at = datetime.now(UTC) + tenant.updated_at = datetime.now(UTC) + db_session_with_containers.add(tenant) + db_session_with_containers.commit() + db_session_with_containers.refresh(tenant) + + tenant_join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER, + ) + tenant_join.created_at = datetime.now(UTC) + tenant_join.updated_at = datetime.now(UTC) + db_session_with_containers.add(tenant_join) + db_session_with_containers.commit() + + return tenant, account + + +def _build_form(db_session_with_containers, tenant, account, *, app_id: str, workflow_execution_id: str): + delivery_method = EmailDeliveryMethod( + config=EmailDeliveryConfig( + recipients=EmailRecipients( + whole_workspace=False, + items=[ + MemberRecipient(user_id=account.id), + ExternalRecipient(email="external@example.com"), + ], + ), + subject="Action needed {{ node_title }} {{#node1.value#}}", + body="Token {{ form_token }} link {{#url#}} content {{#node1.value#}}", + ) + ) + + node_data = HumanInputNodeData( + title="Review", + form_content="Form content", + delivery_methods=[delivery_method], + ) + + engine = db_session_with_containers.get_bind() + repo = HumanInputFormRepositoryImpl(session_factory=engine, 
tenant_id=tenant.id) + params = FormCreateParams( + app_id=app_id, + workflow_execution_id=workflow_execution_id, + node_id="node-1", + form_config=node_data, + rendered_content="Rendered", + delivery_methods=node_data.delivery_methods, + display_in_ui=False, + resolved_default_values={}, + ) + return repo.create_form(params) + + +def _create_workflow_pause_state( + db_session_with_containers, + *, + workflow_run_id: str, + workflow_id: str, + tenant_id: str, + app_id: str, + account_id: str, + variable_pool: VariablePool, +): + workflow_run = WorkflowRun( + id=workflow_run_id, + tenant_id=tenant_id, + app_id=app_id, + workflow_id=workflow_id, + type=WorkflowType.WORKFLOW, + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + version="1", + graph="{}", + inputs="{}", + status=WorkflowExecutionStatus.PAUSED, + created_by_role=CreatorUserRole.ACCOUNT, + created_by=account_id, + created_at=datetime.now(UTC), + ) + db_session_with_containers.add(workflow_run) + + runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) + resumption_context = WorkflowResumptionContext( + generate_entity={ + "type": AppMode.WORKFLOW, + "entity": WorkflowAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=WorkflowUIBasedAppConfig( + tenant_id=tenant_id, + app_id=app_id, + app_mode=AppMode.WORKFLOW, + workflow_id=workflow_id, + ), + inputs={}, + files=[], + user_id=account_id, + stream=False, + invoke_from=InvokeFrom.WEB_APP, + workflow_execution_id=workflow_run_id, + ), + }, + serialized_graph_runtime_state=runtime_state.dumps(), + ) + + state_object_key = f"workflow_pause_states/{workflow_run_id}.json" + storage.save(state_object_key, resumption_context.dumps().encode()) + + pause_state = WorkflowPause( + workflow_id=workflow_id, + workflow_run_id=workflow_run_id, + state_object_key=state_object_key, + ) + db_session_with_containers.add(pause_state) + db_session_with_containers.commit() + + +def test_dispatch_human_input_email_task_integration(monkeypatch: 
pytest.MonkeyPatch, db_session_with_containers): + tenant, account = _create_workspace_member(db_session_with_containers) + workflow_run_id = str(uuid.uuid4()) + workflow_id = str(uuid.uuid4()) + app_id = str(uuid.uuid4()) + variable_pool = VariablePool() + variable_pool.add(["node1", "value"], "OK") + _create_workflow_pause_state( + db_session_with_containers, + workflow_run_id=workflow_run_id, + workflow_id=workflow_id, + tenant_id=tenant.id, + app_id=app_id, + account_id=account.id, + variable_pool=variable_pool, + ) + form_entity = _build_form( + db_session_with_containers, + tenant, + account, + app_id=app_id, + workflow_execution_id=workflow_run_id, + ) + + monkeypatch.setattr(dify_config, "APP_WEB_URL", "https://app.example.com") + + with patch("tasks.mail_human_input_delivery_task.mail") as mock_mail: + mock_mail.is_inited.return_value = True + + dispatch_human_input_email_task(form_id=form_entity.id, node_title="Approval") + + assert mock_mail.send.call_count == 2 + send_args = [call.kwargs for call in mock_mail.send.call_args_list] + recipients = {kwargs["to"] for kwargs in send_args} + assert recipients == {"owner@example.com", "external@example.com"} + assert all(kwargs["subject"] == "Action needed {{ node_title }} {{#node1.value#}}" for kwargs in send_args) + assert all("app.example.com/form/" in kwargs["html"] for kwargs in send_args) + assert all("content OK" in kwargs["html"] for kwargs in send_args) + assert all("{{ form_token }}" in kwargs["html"] for kwargs in send_args) diff --git a/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py index 889e3d1d83..5f4f28cf4f 100644 --- a/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py +++ b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py @@ -94,11 +94,6 @@ class PrunePausesTestCase: def pause_workflow_failure_cases() -> 
list[PauseWorkflowFailureCase]: """Create test cases for pause workflow failure scenarios.""" return [ - PauseWorkflowFailureCase( - name="pause_already_paused_workflow", - initial_status=WorkflowExecutionStatus.PAUSED, - description="Should fail to pause an already paused workflow", - ), PauseWorkflowFailureCase( name="pause_completed_workflow", initial_status=WorkflowExecutionStatus.SUCCEEDED, diff --git a/api/tests/unit_tests/configs/test_dify_config.py b/api/tests/unit_tests/configs/test_dify_config.py index 6fce7849f9..cf52980e57 100644 --- a/api/tests/unit_tests/configs/test_dify_config.py +++ b/api/tests/unit_tests/configs/test_dify_config.py @@ -164,6 +164,62 @@ def test_db_extras_options_merging(monkeypatch: pytest.MonkeyPatch): assert "timezone=UTC" in options +def test_pubsub_redis_url_default(monkeypatch: pytest.MonkeyPatch): + os.environ.clear() + + monkeypatch.setenv("CONSOLE_API_URL", "https://example.com") + monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com") + monkeypatch.setenv("DB_USERNAME", "postgres") + monkeypatch.setenv("DB_PASSWORD", "postgres") + monkeypatch.setenv("DB_HOST", "localhost") + monkeypatch.setenv("DB_PORT", "5432") + monkeypatch.setenv("DB_DATABASE", "dify") + monkeypatch.setenv("REDIS_HOST", "redis.example.com") + monkeypatch.setenv("REDIS_PORT", "6380") + monkeypatch.setenv("REDIS_USERNAME", "user") + monkeypatch.setenv("REDIS_PASSWORD", "pass@word") + monkeypatch.setenv("REDIS_DB", "2") + monkeypatch.setenv("REDIS_USE_SSL", "true") + + config = DifyConfig() + + assert config.normalized_pubsub_redis_url == "rediss://user:pass%40word@redis.example.com:6380/2" + assert config.PUBSUB_REDIS_CHANNEL_TYPE == "pubsub" + + +def test_pubsub_redis_url_override(monkeypatch: pytest.MonkeyPatch): + os.environ.clear() + + monkeypatch.setenv("CONSOLE_API_URL", "https://example.com") + monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com") + monkeypatch.setenv("DB_USERNAME", "postgres") + monkeypatch.setenv("DB_PASSWORD", 
"postgres") + monkeypatch.setenv("DB_HOST", "localhost") + monkeypatch.setenv("DB_PORT", "5432") + monkeypatch.setenv("DB_DATABASE", "dify") + monkeypatch.setenv("PUBSUB_REDIS_URL", "redis://pubsub-host:6381/5") + + config = DifyConfig() + + assert config.normalized_pubsub_redis_url == "redis://pubsub-host:6381/5" + + +def test_pubsub_redis_url_required_when_default_unavailable(monkeypatch: pytest.MonkeyPatch): + os.environ.clear() + + monkeypatch.setenv("CONSOLE_API_URL", "https://example.com") + monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com") + monkeypatch.setenv("DB_USERNAME", "postgres") + monkeypatch.setenv("DB_PASSWORD", "postgres") + monkeypatch.setenv("DB_HOST", "localhost") + monkeypatch.setenv("DB_PORT", "5432") + monkeypatch.setenv("DB_DATABASE", "dify") + monkeypatch.setenv("REDIS_HOST", "") + + with pytest.raises(ValueError, match="PUBSUB_REDIS_URL must be set"): + _ = DifyConfig().normalized_pubsub_redis_url + + @pytest.mark.parametrize( ("broker_url", "expected_host", "expected_port", "expected_username", "expected_password", "expected_db"), [ diff --git a/api/tests/unit_tests/conftest.py b/api/tests/unit_tests/conftest.py index c5e1576186..da957d3a81 100644 --- a/api/tests/unit_tests/conftest.py +++ b/api/tests/unit_tests/conftest.py @@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch import pytest from flask import Flask +from sqlalchemy import create_engine # Getting the absolute path of the current file's directory ABS_PATH = os.path.dirname(os.path.abspath(__file__)) @@ -36,6 +37,7 @@ import sys sys.path.insert(0, PROJECT_DIR) +from core.db.session_factory import configure_session_factory, session_factory from extensions import ext_redis @@ -49,6 +51,8 @@ def _patch_redis_clients_on_loaded_modules(): continue if hasattr(module, "redis_client"): module.redis_client = redis_mock + if hasattr(module, "pubsub_redis_client"): + module.pubsub_redis_client = redis_mock @pytest.fixture @@ -66,7 +70,10 @@ def _provide_app_context(app: 
Flask): def _patch_redis_clients(): """Patch redis_client to MagicMock only for unit test executions.""" - with patch.object(ext_redis, "redis_client", redis_mock): + with ( + patch.object(ext_redis, "redis_client", redis_mock), + patch.object(ext_redis, "pubsub_redis_client", redis_mock), + ): _patch_redis_clients_on_loaded_modules() yield @@ -102,3 +109,18 @@ def reset_secret_key(): yield finally: dify_config.SECRET_KEY = original + + +@pytest.fixture(scope="session") +def _unit_test_engine(): + engine = create_engine("sqlite:///:memory:") + yield engine + engine.dispose() + + +@pytest.fixture(autouse=True) +def _configure_session_factory(_unit_test_engine): + try: + session_factory.get_session_maker() + except RuntimeError: + configure_session_factory(_unit_test_engine, expire_on_commit=False) diff --git a/api/tests/unit_tests/controllers/console/app/test_app_response_models.py b/api/tests/unit_tests/controllers/console/app/test_app_response_models.py index 40eb59a8f4..2ac3dc037d 100644 --- a/api/tests/unit_tests/controllers/console/app/test_app_response_models.py +++ b/api/tests/unit_tests/controllers/console/app/test_app_response_models.py @@ -16,11 +16,9 @@ if not hasattr(builtins, "MethodView"): builtins.MethodView = MethodView # type: ignore[attr-defined] -def _load_app_module(): +@pytest.fixture(scope="module") +def app_module(): module_name = "controllers.console.app.app" - if module_name in sys.modules: - return sys.modules[module_name] - root = Path(__file__).resolve().parents[5] module_path = root / "controllers" / "console" / "app" / "app.py" @@ -31,6 +29,13 @@ def _load_app_module(): def schema_model(self, name, schema): self.models[name] = schema + return schema + + def model(self, name, model_dict=None, **kwargs): + """Register a model with the namespace (flask-restx compatibility).""" + if model_dict is not None: + self.models[name] = model_dict + return model_dict def _decorator(self, obj): return obj @@ -52,8 +57,12 @@ def _load_app_module(): 
stub_namespace = _StubNamespace() - original_console = sys.modules.get("controllers.console") - original_app_pkg = sys.modules.get("controllers.console.app") + original_modules: dict[str, ModuleType | None] = { + "controllers.console": sys.modules.get("controllers.console"), + "controllers.console.app": sys.modules.get("controllers.console.app"), + "controllers.common.schema": sys.modules.get("controllers.common.schema"), + module_name: sys.modules.get(module_name), + } stubbed_modules: list[tuple[str, ModuleType | None]] = [] console_module = ModuleType("controllers.console") @@ -98,35 +107,35 @@ def _load_app_module(): module = util.module_from_spec(spec) sys.modules[module_name] = module + assert spec.loader is not None + spec.loader.exec_module(module) + try: - assert spec.loader is not None - spec.loader.exec_module(module) + yield module finally: for name, original in reversed(stubbed_modules): if original is not None: sys.modules[name] = original else: sys.modules.pop(name, None) - if original_console is not None: - sys.modules["controllers.console"] = original_console - else: - sys.modules.pop("controllers.console", None) - if original_app_pkg is not None: - sys.modules["controllers.console.app"] = original_app_pkg - else: - sys.modules.pop("controllers.console.app", None) - - return module + for name, original in original_modules.items(): + if original is not None: + sys.modules[name] = original + else: + sys.modules.pop(name, None) -_app_module = _load_app_module() -AppDetailWithSite = _app_module.AppDetailWithSite -AppPagination = _app_module.AppPagination -AppPartial = _app_module.AppPartial +@pytest.fixture(scope="module") +def app_models(app_module): + return SimpleNamespace( + AppDetailWithSite=app_module.AppDetailWithSite, + AppPagination=app_module.AppPagination, + AppPartial=app_module.AppPartial, + ) @pytest.fixture(autouse=True) -def patch_signed_url(monkeypatch): +def patch_signed_url(monkeypatch, app_module): """Ensure icon URL generation uses 
a deterministic helper for tests.""" def _fake_signed_url(key: str | None) -> str | None: @@ -134,7 +143,7 @@ def patch_signed_url(monkeypatch): return None return f"signed:{key}" - monkeypatch.setattr(_app_module.file_helpers, "get_signed_file_url", _fake_signed_url) + monkeypatch.setattr(app_module.file_helpers, "get_signed_file_url", _fake_signed_url) def _ts(hour: int = 12) -> datetime: @@ -162,7 +171,8 @@ def _dummy_workflow(): ) -def test_app_partial_serialization_uses_aliases(): +def test_app_partial_serialization_uses_aliases(app_models): + AppPartial = app_models.AppPartial created_at = _ts() app_obj = SimpleNamespace( id="app-1", @@ -197,7 +207,8 @@ def test_app_partial_serialization_uses_aliases(): assert serialized["tags"][0]["name"] == "Utilities" -def test_app_detail_with_site_includes_nested_serialization(): +def test_app_detail_with_site_includes_nested_serialization(app_models): + AppDetailWithSite = app_models.AppDetailWithSite timestamp = _ts(14) site = SimpleNamespace( code="site-code", @@ -246,7 +257,8 @@ def test_app_detail_with_site_includes_nested_serialization(): assert serialized["site"]["created_at"] == int(timestamp.timestamp()) -def test_app_pagination_aliases_per_page_and_has_next(): +def test_app_pagination_aliases_per_page_and_has_next(app_models): + AppPagination = app_models.AppPagination item_one = SimpleNamespace( id="app-10", name="Paginated One", diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_human_input_debug_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_human_input_debug_api.py new file mode 100644 index 0000000000..86a3b2bd93 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_workflow_human_input_debug_api.py @@ -0,0 +1,229 @@ +from __future__ import annotations + +from dataclasses import dataclass +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest +from flask import Flask +from pydantic import ValidationError + +from 
controllers.console import wraps as console_wraps +from controllers.console.app import workflow as workflow_module +from controllers.console.app import wraps as app_wraps +from libs import login as login_lib +from models.account import Account, AccountStatus, TenantAccountRole +from models.model import AppMode + + +def _make_account() -> Account: + account = Account(name="tester", email="tester@example.com") + account.status = AccountStatus.ACTIVE + account.role = TenantAccountRole.OWNER + account.id = "account-123" # type: ignore[assignment] + account._current_tenant = SimpleNamespace(id="tenant-123") # type: ignore[attr-defined] + account._get_current_object = lambda: account # type: ignore[attr-defined] + return account + + +def _make_app(mode: AppMode) -> SimpleNamespace: + return SimpleNamespace(id="app-123", tenant_id="tenant-123", mode=mode.value) + + +def _patch_console_guards(monkeypatch: pytest.MonkeyPatch, account: Account, app_model: SimpleNamespace) -> None: + # Skip setup and auth guardrails + monkeypatch.setattr("configs.dify_config.EDITION", "CLOUD") + monkeypatch.setattr(login_lib.dify_config, "LOGIN_DISABLED", True) + monkeypatch.setattr(login_lib, "current_user", account) + monkeypatch.setattr(login_lib, "current_account_with_tenant", lambda: (account, account.current_tenant_id)) + monkeypatch.setattr(login_lib, "check_csrf_token", lambda *_, **__: None) + monkeypatch.setattr(console_wraps, "current_account_with_tenant", lambda: (account, account.current_tenant_id)) + monkeypatch.setattr(app_wraps, "current_account_with_tenant", lambda: (account, account.current_tenant_id)) + monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (account, account.current_tenant_id)) + monkeypatch.setattr(console_wraps.dify_config, "EDITION", "CLOUD") + monkeypatch.delenv("INIT_PASSWORD", raising=False) + + # Avoid hitting the database when resolving the app model + monkeypatch.setattr(app_wraps, "_load_app_model", lambda _app_id: app_model) 
+ + +@dataclass +class PreviewCase: + resource_cls: type + path: str + mode: AppMode + + +@pytest.mark.parametrize( + "case", + [ + PreviewCase( + resource_cls=workflow_module.AdvancedChatDraftHumanInputFormPreviewApi, + path="/console/api/apps/app-123/advanced-chat/workflows/draft/human-input/nodes/node-42/form/preview", + mode=AppMode.ADVANCED_CHAT, + ), + PreviewCase( + resource_cls=workflow_module.WorkflowDraftHumanInputFormPreviewApi, + path="/console/api/apps/app-123/workflows/draft/human-input/nodes/node-42/form/preview", + mode=AppMode.WORKFLOW, + ), + ], +) +def test_human_input_preview_delegates_to_service( + app: Flask, monkeypatch: pytest.MonkeyPatch, case: PreviewCase +) -> None: + account = _make_account() + app_model = _make_app(case.mode) + _patch_console_guards(monkeypatch, account, app_model) + + preview_payload = { + "form_id": "node-42", + "form_content": "
example
", + "inputs": [{"name": "topic"}], + "actions": [{"id": "continue"}], + } + service_instance = MagicMock() + service_instance.get_human_input_form_preview.return_value = preview_payload + monkeypatch.setattr(workflow_module, "WorkflowService", MagicMock(return_value=service_instance)) + + with app.test_request_context(case.path, method="POST", json={"inputs": {"topic": "tech"}}): + response = case.resource_cls().post(app_id=app_model.id, node_id="node-42") + + assert response == preview_payload + service_instance.get_human_input_form_preview.assert_called_once_with( + app_model=app_model, + account=account, + node_id="node-42", + inputs={"topic": "tech"}, + ) + + +@dataclass +class SubmitCase: + resource_cls: type + path: str + mode: AppMode + + +@pytest.mark.parametrize( + "case", + [ + SubmitCase( + resource_cls=workflow_module.AdvancedChatDraftHumanInputFormRunApi, + path="/console/api/apps/app-123/advanced-chat/workflows/draft/human-input/nodes/node-99/form/run", + mode=AppMode.ADVANCED_CHAT, + ), + SubmitCase( + resource_cls=workflow_module.WorkflowDraftHumanInputFormRunApi, + path="/console/api/apps/app-123/workflows/draft/human-input/nodes/node-99/form/run", + mode=AppMode.WORKFLOW, + ), + ], +) +def test_human_input_submit_forwards_payload(app: Flask, monkeypatch: pytest.MonkeyPatch, case: SubmitCase) -> None: + account = _make_account() + app_model = _make_app(case.mode) + _patch_console_guards(monkeypatch, account, app_model) + + result_payload = {"node_id": "node-99", "outputs": {"__rendered_content": "

done

"}, "action": "approve"} + service_instance = MagicMock() + service_instance.submit_human_input_form_preview.return_value = result_payload + monkeypatch.setattr(workflow_module, "WorkflowService", MagicMock(return_value=service_instance)) + + with app.test_request_context( + case.path, + method="POST", + json={"form_inputs": {"answer": "42"}, "inputs": {"#node-1.result#": "LLM output"}, "action": "approve"}, + ): + response = case.resource_cls().post(app_id=app_model.id, node_id="node-99") + + assert response == result_payload + service_instance.submit_human_input_form_preview.assert_called_once_with( + app_model=app_model, + account=account, + node_id="node-99", + form_inputs={"answer": "42"}, + inputs={"#node-1.result#": "LLM output"}, + action="approve", + ) + + +@dataclass +class DeliveryTestCase: + resource_cls: type + path: str + mode: AppMode + + +@pytest.mark.parametrize( + "case", + [ + DeliveryTestCase( + resource_cls=workflow_module.WorkflowDraftHumanInputDeliveryTestApi, + path="/console/api/apps/app-123/workflows/draft/human-input/nodes/node-7/delivery-test", + mode=AppMode.ADVANCED_CHAT, + ), + DeliveryTestCase( + resource_cls=workflow_module.WorkflowDraftHumanInputDeliveryTestApi, + path="/console/api/apps/app-123/workflows/draft/human-input/nodes/node-7/delivery-test", + mode=AppMode.WORKFLOW, + ), + ], +) +def test_human_input_delivery_test_calls_service( + app: Flask, monkeypatch: pytest.MonkeyPatch, case: DeliveryTestCase +) -> None: + account = _make_account() + app_model = _make_app(case.mode) + _patch_console_guards(monkeypatch, account, app_model) + + service_instance = MagicMock() + monkeypatch.setattr(workflow_module, "WorkflowService", MagicMock(return_value=service_instance)) + + with app.test_request_context( + case.path, + method="POST", + json={"delivery_method_id": "delivery-123"}, + ): + response = case.resource_cls().post(app_id=app_model.id, node_id="node-7") + + assert response == {} + 
service_instance.test_human_input_delivery.assert_called_once_with( + app_model=app_model, + account=account, + node_id="node-7", + delivery_method_id="delivery-123", + inputs={}, + ) + + +def test_human_input_delivery_test_maps_validation_error(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: + account = _make_account() + app_model = _make_app(AppMode.ADVANCED_CHAT) + _patch_console_guards(monkeypatch, account, app_model) + + service_instance = MagicMock() + service_instance.test_human_input_delivery.side_effect = ValueError("bad delivery method") + monkeypatch.setattr(workflow_module, "WorkflowService", MagicMock(return_value=service_instance)) + + with app.test_request_context( + "/console/api/apps/app-123/workflows/draft/human-input/nodes/node-1/delivery-test", + method="POST", + json={"delivery_method_id": "bad"}, + ): + with pytest.raises(ValueError): + workflow_module.WorkflowDraftHumanInputDeliveryTestApi().post(app_id=app_model.id, node_id="node-1") + + +def test_human_input_preview_rejects_non_mapping(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: + account = _make_account() + app_model = _make_app(AppMode.ADVANCED_CHAT) + _patch_console_guards(monkeypatch, account, app_model) + + with app.test_request_context( + "/console/api/apps/app-123/advanced-chat/workflows/draft/human-input/nodes/node-1/form/preview", + method="POST", + json={"inputs": ["not-a-dict"]}, + ): + with pytest.raises(ValidationError): + workflow_module.AdvancedChatDraftHumanInputFormPreviewApi().post(app_id=app_model.id, node_id="node-1") diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py new file mode 100644 index 0000000000..34d6a2232c --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +from datetime import datetime +from types import SimpleNamespace 
+from unittest.mock import Mock + +import pytest +from flask import Flask + +from controllers.console import wraps as console_wraps +from controllers.console.app import workflow_run as workflow_run_module +from core.workflow.entities.pause_reason import HumanInputRequired +from core.workflow.enums import WorkflowExecutionStatus +from core.workflow.nodes.human_input.entities import FormInput, UserAction +from core.workflow.nodes.human_input.enums import FormInputType +from libs import login as login_lib +from models.account import Account, AccountStatus, TenantAccountRole +from models.workflow import WorkflowRun + + +def _make_account() -> Account: + account = Account(name="tester", email="tester@example.com") + account.status = AccountStatus.ACTIVE + account.role = TenantAccountRole.OWNER + account.id = "account-123" # type: ignore[assignment] + account._current_tenant = SimpleNamespace(id="tenant-123") # type: ignore[attr-defined] + account._get_current_object = lambda: account # type: ignore[attr-defined] + return account + + +def _patch_console_guards(monkeypatch: pytest.MonkeyPatch, account: Account) -> None: + monkeypatch.setattr(login_lib.dify_config, "LOGIN_DISABLED", True) + monkeypatch.setattr(login_lib, "current_user", account) + monkeypatch.setattr(login_lib, "current_account_with_tenant", lambda: (account, account.current_tenant_id)) + monkeypatch.setattr(login_lib, "check_csrf_token", lambda *_, **__: None) + monkeypatch.setattr(console_wraps, "current_account_with_tenant", lambda: (account, account.current_tenant_id)) + monkeypatch.setattr(workflow_run_module, "current_user", account) + monkeypatch.setattr(console_wraps.dify_config, "EDITION", "CLOUD") + + +class _PauseEntity: + def __init__(self, paused_at: datetime, reasons: list[HumanInputRequired]): + self.paused_at = paused_at + self._reasons = reasons + + def get_pause_reasons(self): + return self._reasons + + +def test_pause_details_returns_backstage_input_url(app: Flask, monkeypatch: 
pytest.MonkeyPatch) -> None: + account = _make_account() + _patch_console_guards(monkeypatch, account) + monkeypatch.setattr(workflow_run_module.dify_config, "APP_WEB_URL", "https://web.example.com") + + workflow_run = Mock(spec=WorkflowRun) + workflow_run.status = WorkflowExecutionStatus.PAUSED + workflow_run.created_at = datetime(2024, 1, 1, 12, 0, 0) + fake_db = SimpleNamespace(engine=Mock(), session=SimpleNamespace(get=lambda *_: workflow_run)) + monkeypatch.setattr(workflow_run_module, "db", fake_db) + + reason = HumanInputRequired( + form_id="form-1", + form_content="content", + inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")], + actions=[UserAction(id="approve", title="Approve")], + node_id="node-1", + node_title="Ask Name", + form_token="backstage-token", + ) + pause_entity = _PauseEntity(paused_at=datetime(2024, 1, 1, 12, 0, 0), reasons=[reason]) + + repo = Mock() + repo.get_workflow_pause.return_value = pause_entity + monkeypatch.setattr( + workflow_run_module.DifyAPIRepositoryFactory, + "create_api_workflow_run_repository", + lambda *_, **__: repo, + ) + + with app.test_request_context("/console/api/workflow/run-1/pause-details", method="GET"): + response, status = workflow_run_module.ConsoleWorkflowPauseDetailsApi().get(workflow_run_id="run-1") + + assert status == 200 + assert response["paused_at"] == "2024-01-01T12:00:00Z" + assert response["paused_nodes"][0]["node_id"] == "node-1" + assert response["paused_nodes"][0]["pause_type"]["type"] == "human_input" + assert ( + response["paused_nodes"][0]["pause_type"]["backstage_input_url"] + == "https://web.example.com/form/backstage-token" + ) + assert "pending_human_inputs" not in response diff --git a/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py b/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py new file mode 100644 index 0000000000..fcaa61a871 --- /dev/null +++ 
b/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py @@ -0,0 +1,25 @@ +from types import SimpleNamespace + +from controllers.service_api.app.workflow import WorkflowRunOutputsField, WorkflowRunStatusField +from core.workflow.enums import WorkflowExecutionStatus + + +def test_workflow_run_status_field_with_enum() -> None: + field = WorkflowRunStatusField() + obj = SimpleNamespace(status=WorkflowExecutionStatus.PAUSED) + + assert field.output("status", obj) == "paused" + + +def test_workflow_run_outputs_field_paused_returns_empty() -> None: + field = WorkflowRunOutputsField() + obj = SimpleNamespace(status=WorkflowExecutionStatus.PAUSED, outputs_dict={"foo": "bar"}) + + assert field.output("outputs", obj) == {} + + +def test_workflow_run_outputs_field_running_returns_outputs() -> None: + field = WorkflowRunOutputsField() + obj = SimpleNamespace(status=WorkflowExecutionStatus.RUNNING, outputs_dict={"foo": "bar"}) + + assert field.output("outputs", obj) == {"foo": "bar"} diff --git a/api/tests/unit_tests/controllers/web/test_human_input_form.py b/api/tests/unit_tests/controllers/web/test_human_input_form.py new file mode 100644 index 0000000000..4fb735b033 --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_human_input_form.py @@ -0,0 +1,456 @@ +"""Unit tests for controllers.web.human_input_form endpoints.""" + +from __future__ import annotations + +import json +from datetime import UTC, datetime +from types import SimpleNamespace +from typing import Any +from unittest.mock import MagicMock + +import pytest +from flask import Flask +from werkzeug.exceptions import Forbidden + +import controllers.web.human_input_form as human_input_module +import controllers.web.site as site_module +from controllers.web.error import WebFormRateLimitExceededError +from models.human_input import RecipientType +from services.human_input_service import FormExpiredError + +HumanInputFormApi = human_input_module.HumanInputFormApi +TenantStatus = 
human_input_module.TenantStatus + + +@pytest.fixture +def app() -> Flask: + """Configure a minimal Flask app for request contexts.""" + + app = Flask(__name__) + app.config["TESTING"] = True + return app + + +class _FakeSession: + """Simple stand-in for db.session that returns pre-seeded objects.""" + + def __init__(self, mapping: dict[str, Any]): + self._mapping = mapping + self._model_name: str | None = None + + def query(self, model): + self._model_name = model.__name__ + return self + + def where(self, *args, **kwargs): + return self + + def first(self): + assert self._model_name is not None + return self._mapping.get(self._model_name) + + +class _FakeDB: + """Minimal db stub exposing engine and session.""" + + def __init__(self, session: _FakeSession): + self.session = session + self.engine = object() + + +def test_get_form_includes_site(monkeypatch: pytest.MonkeyPatch, app: Flask): + """GET returns form definition merged with site payload.""" + + expiration_time = datetime(2099, 1, 1, tzinfo=UTC) + + class _FakeDefinition: + def model_dump(self): + return { + "form_content": "Raw content", + "rendered_content": "Rendered {{#$output.name#}}", + "inputs": [{"type": "text", "output_variable_name": "name", "default": None}], + "default_values": {"name": "Alice", "age": 30, "meta": {"k": "v"}}, + "user_actions": [{"id": "approve", "title": "Approve", "button_style": "default"}], + } + + class _FakeForm: + def __init__(self, expiration: datetime): + self.workflow_run_id = "workflow-1" + self.app_id = "app-1" + self.tenant_id = "tenant-1" + self.expiration_time = expiration + self.recipient_type = RecipientType.BACKSTAGE + + def get_definition(self): + return _FakeDefinition() + + form = _FakeForm(expiration_time) + limiter_mock = MagicMock() + limiter_mock.is_rate_limited.return_value = False + monkeypatch.setattr(human_input_module, "_FORM_ACCESS_RATE_LIMITER", limiter_mock) + monkeypatch.setattr(human_input_module, "extract_remote_ip", lambda req: "203.0.113.10") 
+ + tenant = SimpleNamespace( + id="tenant-1", + status=TenantStatus.NORMAL, + plan="basic", + custom_config_dict={"remove_webapp_brand": True, "replace_webapp_logo": False}, + ) + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant, enable_site=True) + workflow_run = SimpleNamespace(app_id="app-1") + site_model = SimpleNamespace( + title="My Site", + icon_type="emoji", + icon="robot", + icon_background="#fff", + description="desc", + default_language="en", + chat_color_theme="light", + chat_color_theme_inverted=False, + copyright=None, + privacy_policy=None, + custom_disclaimer=None, + prompt_public=False, + show_workflow_steps=True, + use_icon_as_answer_icon=False, + ) + + # Patch service to return fake form. + service_mock = MagicMock() + service_mock.get_form_by_token.return_value = form + monkeypatch.setattr(human_input_module, "HumanInputService", lambda engine: service_mock) + + # Patch db session. + db_stub = _FakeDB(_FakeSession({"WorkflowRun": workflow_run, "App": app_model, "Site": site_model})) + monkeypatch.setattr(human_input_module, "db", db_stub) + + monkeypatch.setattr( + site_module.FeatureService, + "get_features", + lambda tenant_id: SimpleNamespace(can_replace_logo=True), + ) + + with app.test_request_context("/api/form/human_input/token-1", method="GET"): + response = HumanInputFormApi().get("token-1") + + body = json.loads(response.get_data(as_text=True)) + assert set(body.keys()) == { + "site", + "form_content", + "inputs", + "resolved_default_values", + "user_actions", + "expiration_time", + } + assert body["form_content"] == "Rendered {{#$output.name#}}" + assert body["inputs"] == [{"type": "text", "output_variable_name": "name", "default": None}] + assert body["resolved_default_values"] == {"name": "Alice", "age": "30", "meta": '{"k": "v"}'} + assert body["user_actions"] == [{"id": "approve", "title": "Approve", "button_style": "default"}] + assert body["expiration_time"] == int(expiration_time.timestamp()) + 
assert body["site"] == { + "app_id": "app-1", + "end_user_id": None, + "enable_site": True, + "site": { + "title": "My Site", + "chat_color_theme": "light", + "chat_color_theme_inverted": False, + "icon_type": "emoji", + "icon": "robot", + "icon_background": "#fff", + "icon_url": None, + "description": "desc", + "copyright": None, + "privacy_policy": None, + "custom_disclaimer": None, + "default_language": "en", + "prompt_public": False, + "show_workflow_steps": True, + "use_icon_as_answer_icon": False, + }, + "model_config": None, + "plan": "basic", + "can_replace_logo": True, + "custom_config": { + "remove_webapp_brand": True, + "replace_webapp_logo": None, + }, + } + service_mock.get_form_by_token.assert_called_once_with("token-1") + limiter_mock.is_rate_limited.assert_called_once_with("203.0.113.10") + limiter_mock.increment_rate_limit.assert_called_once_with("203.0.113.10") + + +def test_get_form_allows_backstage_token(monkeypatch: pytest.MonkeyPatch, app: Flask): + """GET returns form payload for backstage token.""" + + expiration_time = datetime(2099, 1, 2, tzinfo=UTC) + + class _FakeDefinition: + def model_dump(self): + return { + "form_content": "Raw content", + "rendered_content": "Rendered", + "inputs": [], + "default_values": {}, + "user_actions": [], + } + + class _FakeForm: + def __init__(self, expiration: datetime): + self.workflow_run_id = "workflow-1" + self.app_id = "app-1" + self.tenant_id = "tenant-1" + self.expiration_time = expiration + + def get_definition(self): + return _FakeDefinition() + + form = _FakeForm(expiration_time) + limiter_mock = MagicMock() + limiter_mock.is_rate_limited.return_value = False + monkeypatch.setattr(human_input_module, "_FORM_ACCESS_RATE_LIMITER", limiter_mock) + monkeypatch.setattr(human_input_module, "extract_remote_ip", lambda req: "203.0.113.10") + tenant = SimpleNamespace( + id="tenant-1", + status=TenantStatus.NORMAL, + plan="basic", + custom_config_dict={"remove_webapp_brand": True, "replace_webapp_logo": 
False}, + ) + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant, enable_site=True) + workflow_run = SimpleNamespace(app_id="app-1") + site_model = SimpleNamespace( + title="My Site", + icon_type="emoji", + icon="robot", + icon_background="#fff", + description="desc", + default_language="en", + chat_color_theme="light", + chat_color_theme_inverted=False, + copyright=None, + privacy_policy=None, + custom_disclaimer=None, + prompt_public=False, + show_workflow_steps=True, + use_icon_as_answer_icon=False, + ) + + service_mock = MagicMock() + service_mock.get_form_by_token.return_value = form + monkeypatch.setattr(human_input_module, "HumanInputService", lambda engine: service_mock) + + db_stub = _FakeDB(_FakeSession({"WorkflowRun": workflow_run, "App": app_model, "Site": site_model})) + monkeypatch.setattr(human_input_module, "db", db_stub) + + monkeypatch.setattr( + site_module.FeatureService, + "get_features", + lambda tenant_id: SimpleNamespace(can_replace_logo=True), + ) + + with app.test_request_context("/api/form/human_input/token-1", method="GET"): + response = HumanInputFormApi().get("token-1") + + body = json.loads(response.get_data(as_text=True)) + assert set(body.keys()) == { + "site", + "form_content", + "inputs", + "resolved_default_values", + "user_actions", + "expiration_time", + } + assert body["form_content"] == "Rendered" + assert body["inputs"] == [] + assert body["resolved_default_values"] == {} + assert body["user_actions"] == [] + assert body["expiration_time"] == int(expiration_time.timestamp()) + assert body["site"] == { + "app_id": "app-1", + "end_user_id": None, + "enable_site": True, + "site": { + "title": "My Site", + "chat_color_theme": "light", + "chat_color_theme_inverted": False, + "icon_type": "emoji", + "icon": "robot", + "icon_background": "#fff", + "icon_url": None, + "description": "desc", + "copyright": None, + "privacy_policy": None, + "custom_disclaimer": None, + "default_language": "en", + 
"prompt_public": False, + "show_workflow_steps": True, + "use_icon_as_answer_icon": False, + }, + "model_config": None, + "plan": "basic", + "can_replace_logo": True, + "custom_config": { + "remove_webapp_brand": True, + "replace_webapp_logo": None, + }, + } + service_mock.get_form_by_token.assert_called_once_with("token-1") + limiter_mock.is_rate_limited.assert_called_once_with("203.0.113.10") + limiter_mock.increment_rate_limit.assert_called_once_with("203.0.113.10") + + +def test_get_form_raises_forbidden_when_site_missing(monkeypatch: pytest.MonkeyPatch, app: Flask): + """GET raises Forbidden if site cannot be resolved.""" + + expiration_time = datetime(2099, 1, 3, tzinfo=UTC) + + class _FakeDefinition: + def model_dump(self): + return { + "form_content": "Raw content", + "rendered_content": "Rendered", + "inputs": [], + "default_values": {}, + "user_actions": [], + } + + class _FakeForm: + def __init__(self, expiration: datetime): + self.workflow_run_id = "workflow-1" + self.app_id = "app-1" + self.tenant_id = "tenant-1" + self.expiration_time = expiration + + def get_definition(self): + return _FakeDefinition() + + form = _FakeForm(expiration_time) + limiter_mock = MagicMock() + limiter_mock.is_rate_limited.return_value = False + monkeypatch.setattr(human_input_module, "_FORM_ACCESS_RATE_LIMITER", limiter_mock) + monkeypatch.setattr(human_input_module, "extract_remote_ip", lambda req: "203.0.113.10") + tenant = SimpleNamespace(status=TenantStatus.NORMAL) + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant) + workflow_run = SimpleNamespace(app_id="app-1") + + service_mock = MagicMock() + service_mock.get_form_by_token.return_value = form + monkeypatch.setattr(human_input_module, "HumanInputService", lambda engine: service_mock) + + db_stub = _FakeDB(_FakeSession({"WorkflowRun": workflow_run, "App": app_model, "Site": None})) + monkeypatch.setattr(human_input_module, "db", db_stub) + + with 
app.test_request_context("/api/form/human_input/token-1", method="GET"): + with pytest.raises(Forbidden): + HumanInputFormApi().get("token-1") + limiter_mock.is_rate_limited.assert_called_once_with("203.0.113.10") + limiter_mock.increment_rate_limit.assert_called_once_with("203.0.113.10") + + +def test_submit_form_accepts_backstage_token(monkeypatch: pytest.MonkeyPatch, app: Flask): + """POST forwards backstage submissions to the service.""" + + class _FakeForm: + recipient_type = RecipientType.BACKSTAGE + + form = _FakeForm() + limiter_mock = MagicMock() + limiter_mock.is_rate_limited.return_value = False + monkeypatch.setattr(human_input_module, "_FORM_SUBMIT_RATE_LIMITER", limiter_mock) + monkeypatch.setattr(human_input_module, "extract_remote_ip", lambda req: "203.0.113.10") + service_mock = MagicMock() + service_mock.get_form_by_token.return_value = form + monkeypatch.setattr(human_input_module, "HumanInputService", lambda engine: service_mock) + monkeypatch.setattr(human_input_module, "db", _FakeDB(_FakeSession({}))) + + with app.test_request_context( + "/api/form/human_input/token-1", + method="POST", + json={"inputs": {"content": "ok"}, "action": "approve"}, + ): + response, status = HumanInputFormApi().post("token-1") + + assert status == 200 + assert response == {} + service_mock.submit_form_by_token.assert_called_once_with( + recipient_type=RecipientType.BACKSTAGE, + form_token="token-1", + selected_action_id="approve", + form_data={"content": "ok"}, + submission_end_user_id=None, + ) + limiter_mock.is_rate_limited.assert_called_once_with("203.0.113.10") + limiter_mock.increment_rate_limit.assert_called_once_with("203.0.113.10") + + +def test_submit_form_rate_limited(monkeypatch: pytest.MonkeyPatch, app: Flask): + """POST rejects submissions when rate limit is exceeded.""" + + limiter_mock = MagicMock() + limiter_mock.is_rate_limited.return_value = True + monkeypatch.setattr(human_input_module, "_FORM_SUBMIT_RATE_LIMITER", limiter_mock) + 
monkeypatch.setattr(human_input_module, "extract_remote_ip", lambda req: "203.0.113.10") + + service_mock = MagicMock() + service_mock.get_form_by_token.return_value = None + monkeypatch.setattr(human_input_module, "HumanInputService", lambda engine: service_mock) + monkeypatch.setattr(human_input_module, "db", _FakeDB(_FakeSession({}))) + + with app.test_request_context( + "/api/form/human_input/token-1", + method="POST", + json={"inputs": {"content": "ok"}, "action": "approve"}, + ): + with pytest.raises(WebFormRateLimitExceededError): + HumanInputFormApi().post("token-1") + + limiter_mock.is_rate_limited.assert_called_once_with("203.0.113.10") + limiter_mock.increment_rate_limit.assert_not_called() + service_mock.get_form_by_token.assert_not_called() + + +def test_get_form_rate_limited(monkeypatch: pytest.MonkeyPatch, app: Flask): + """GET rejects requests when rate limit is exceeded.""" + + limiter_mock = MagicMock() + limiter_mock.is_rate_limited.return_value = True + monkeypatch.setattr(human_input_module, "_FORM_ACCESS_RATE_LIMITER", limiter_mock) + monkeypatch.setattr(human_input_module, "extract_remote_ip", lambda req: "203.0.113.10") + + service_mock = MagicMock() + service_mock.get_form_by_token.return_value = None + monkeypatch.setattr(human_input_module, "HumanInputService", lambda engine: service_mock) + monkeypatch.setattr(human_input_module, "db", _FakeDB(_FakeSession({}))) + + with app.test_request_context("/api/form/human_input/token-1", method="GET"): + with pytest.raises(WebFormRateLimitExceededError): + HumanInputFormApi().get("token-1") + + limiter_mock.is_rate_limited.assert_called_once_with("203.0.113.10") + limiter_mock.increment_rate_limit.assert_not_called() + service_mock.get_form_by_token.assert_not_called() + + +def test_get_form_raises_expired(monkeypatch: pytest.MonkeyPatch, app: Flask): + class _FakeForm: + pass + + form = _FakeForm() + limiter_mock = MagicMock() + limiter_mock.is_rate_limited.return_value = False + 
monkeypatch.setattr(human_input_module, "_FORM_ACCESS_RATE_LIMITER", limiter_mock) + monkeypatch.setattr(human_input_module, "extract_remote_ip", lambda req: "203.0.113.10") + service_mock = MagicMock() + service_mock.get_form_by_token.return_value = form + service_mock.ensure_form_active.side_effect = FormExpiredError("form-id") + monkeypatch.setattr(human_input_module, "HumanInputService", lambda engine: service_mock) + monkeypatch.setattr(human_input_module, "db", _FakeDB(_FakeSession({}))) + + with app.test_request_context("/api/form/human_input/token-1", method="GET"): + with pytest.raises(FormExpiredError): + HumanInputFormApi().get("token-1") + + service_mock.ensure_form_active.assert_called_once_with(form) + limiter_mock.is_rate_limited.assert_called_once_with("203.0.113.10") + limiter_mock.increment_rate_limit.assert_called_once_with("203.0.113.10") diff --git a/api/tests/unit_tests/controllers/web/test_message_list.py b/api/tests/unit_tests/controllers/web/test_message_list.py index 2835f7ffbf..1c096bfbcf 100644 --- a/api/tests/unit_tests/controllers/web/test_message_list.py +++ b/api/tests/unit_tests/controllers/web/test_message_list.py @@ -3,6 +3,7 @@ from __future__ import annotations import builtins +import uuid from datetime import datetime from types import ModuleType, SimpleNamespace from unittest.mock import patch @@ -12,6 +13,8 @@ import pytest from flask import Flask from flask.views import MethodView +from core.entities.execution_extra_content import HumanInputContent + # Ensure flask_restx.api finds MethodView during import. 
if not hasattr(builtins, "MethodView"): builtins.MethodView = MethodView # type: ignore[attr-defined] @@ -137,6 +140,12 @@ def test_message_list_mapping(app: Flask) -> None: status="success", error=None, message_metadata_dict={"meta": "value"}, + extra_contents=[ + HumanInputContent( + workflow_run_id=str(uuid.uuid4()), + submitted=True, + ) + ], ) pagination = SimpleNamespace(limit=20, has_more=False, data=[message]) @@ -169,6 +178,8 @@ def test_message_list_mapping(app: Flask) -> None: assert item["agent_thoughts"][0]["chain_id"] == "chain-1" assert item["agent_thoughts"][0]["created_at"] == int(thought_created_at.timestamp()) + assert item["extra_contents"][0]["workflow_run_id"] == message.extra_contents[0].workflow_run_id + assert item["extra_contents"][0]["submitted"] == message.extra_contents[0].submitted assert item["message_files"][0]["id"] == "file-dict" assert item["message_files"][1]["id"] == "file-obj" diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_extra_contents.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_extra_contents.py new file mode 100644 index 0000000000..a94b5445f7 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_extra_contents.py @@ -0,0 +1,187 @@ +from __future__ import annotations + +from contextlib import contextmanager +from datetime import datetime +from types import SimpleNamespace +from unittest import mock + +import pytest + +from core.app.apps.advanced_chat import generate_task_pipeline as pipeline_module +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.queue_entities import QueueTextChunkEvent, QueueWorkflowPausedEvent +from core.workflow.entities.pause_reason import HumanInputRequired +from models.enums import MessageStatus +from models.execution_extra_content import HumanInputContent +from models.model import EndUser + + +def _build_pipeline() -> 
pipeline_module.AdvancedChatAppGenerateTaskPipeline: + pipeline = pipeline_module.AdvancedChatAppGenerateTaskPipeline.__new__( + pipeline_module.AdvancedChatAppGenerateTaskPipeline + ) + pipeline._workflow_run_id = "run-1" + pipeline._message_id = "message-1" + pipeline._workflow_tenant_id = "tenant-1" + return pipeline + + +def test_persist_human_input_extra_content_adds_record(monkeypatch: pytest.MonkeyPatch) -> None: + pipeline = _build_pipeline() + monkeypatch.setattr(pipeline, "_load_human_input_form_id", lambda **kwargs: "form-1") + + captured_session: dict[str, mock.Mock] = {} + + @contextmanager + def fake_session(): + session = mock.Mock() + session.scalar.return_value = None + captured_session["session"] = session + yield session + + pipeline._database_session = fake_session # type: ignore[method-assign] + + pipeline._persist_human_input_extra_content(node_id="node-1") + + session = captured_session["session"] + session.add.assert_called_once() + content = session.add.call_args.args[0] + assert isinstance(content, HumanInputContent) + assert content.workflow_run_id == "run-1" + assert content.message_id == "message-1" + assert content.form_id == "form-1" + + +def test_persist_human_input_extra_content_skips_when_form_missing(monkeypatch: pytest.MonkeyPatch) -> None: + pipeline = _build_pipeline() + monkeypatch.setattr(pipeline, "_load_human_input_form_id", lambda **kwargs: None) + + called = {"value": False} + + @contextmanager + def fake_session(): + called["value"] = True + session = mock.Mock() + yield session + + pipeline._database_session = fake_session # type: ignore[method-assign] + + pipeline._persist_human_input_extra_content(node_id="node-1") + + assert called["value"] is False + + +def test_persist_human_input_extra_content_skips_when_existing(monkeypatch: pytest.MonkeyPatch) -> None: + pipeline = _build_pipeline() + monkeypatch.setattr(pipeline, "_load_human_input_form_id", lambda **kwargs: "form-1") + + captured_session: dict[str, mock.Mock] 
= {} + + @contextmanager + def fake_session(): + session = mock.Mock() + session.scalar.return_value = HumanInputContent( + workflow_run_id="run-1", + message_id="message-1", + form_id="form-1", + ) + captured_session["session"] = session + yield session + + pipeline._database_session = fake_session # type: ignore[method-assign] + + pipeline._persist_human_input_extra_content(node_id="node-1") + + session = captured_session["session"] + session.add.assert_not_called() + + +def test_handle_workflow_paused_event_persists_human_input_extra_content() -> None: + pipeline = _build_pipeline() + pipeline._application_generate_entity = SimpleNamespace(task_id="task-1") + pipeline._workflow_response_converter = mock.Mock() + pipeline._workflow_response_converter.workflow_pause_to_stream_response.return_value = [] + pipeline._ensure_graph_runtime_initialized = mock.Mock( + return_value=SimpleNamespace( + total_tokens=0, + node_run_steps=0, + ), + ) + pipeline._save_message = mock.Mock() + message = SimpleNamespace(status=MessageStatus.NORMAL) + pipeline._get_message = mock.Mock(return_value=message) + pipeline._persist_human_input_extra_content = mock.Mock() + pipeline._base_task_pipeline = mock.Mock() + pipeline._base_task_pipeline.queue_manager = mock.Mock() + pipeline._message_saved_on_pause = False + + @contextmanager + def fake_session(): + session = mock.Mock() + yield session + + pipeline._database_session = fake_session # type: ignore[method-assign] + + reason = HumanInputRequired( + form_id="form-1", + form_content="content", + inputs=[], + actions=[], + node_id="node-1", + node_title="Approval", + form_token="token-1", + resolved_default_values={}, + ) + event = QueueWorkflowPausedEvent(reasons=[reason], outputs={}, paused_nodes=["node-1"]) + + list(pipeline._handle_workflow_paused_event(event)) + + pipeline._persist_human_input_extra_content.assert_called_once_with(form_id="form-1", node_id="node-1") + assert message.status == MessageStatus.PAUSED + + +def 
test_resume_appends_chunks_to_paused_answer() -> None: + app_config = SimpleNamespace(app_id="app-1", tenant_id="tenant-1", sensitive_word_avoidance=None) + application_generate_entity = SimpleNamespace( + app_config=app_config, + files=[], + workflow_run_id="run-1", + query="hello", + invoke_from=InvokeFrom.WEB_APP, + inputs={}, + task_id="task-1", + ) + queue_manager = SimpleNamespace(graph_runtime_state=None) + conversation = SimpleNamespace(id="conversation-1", mode="advanced-chat") + message = SimpleNamespace( + id="message-1", + created_at=datetime(2024, 1, 1), + query="hello", + answer="before", + status=MessageStatus.PAUSED, + ) + user = EndUser() + user.id = "user-1" + user.session_id = "session-1" + workflow = SimpleNamespace(id="workflow-1", tenant_id="tenant-1", features_dict={}) + + pipeline = pipeline_module.AdvancedChatAppGenerateTaskPipeline( + application_generate_entity=application_generate_entity, + workflow=workflow, + queue_manager=queue_manager, + conversation=conversation, + message=message, + user=user, + stream=True, + dialogue_count=1, + draft_var_saver_factory=SimpleNamespace(), + ) + + pipeline._get_message = mock.Mock(return_value=message) + pipeline._recorded_files = [] + + list(pipeline._handle_text_chunk_event(QueueTextChunkEvent(text="after"))) + pipeline._save_message(session=mock.Mock()) + + assert message.answer == "beforeafter" + assert message.status == MessageStatus.NORMAL diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py new file mode 100644 index 0000000000..1c36b4d12b --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py @@ -0,0 +1,87 @@ +from datetime import UTC, datetime +from types import SimpleNamespace + +from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter +from 
core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.queue_entities import QueueHumanInputFormFilledEvent, QueueHumanInputFormTimeoutEvent +from core.workflow.entities.workflow_start_reason import WorkflowStartReason +from core.workflow.runtime import GraphRuntimeState, VariablePool +from core.workflow.system_variable import SystemVariable + + +def _build_converter(): + system_variables = SystemVariable( + files=[], + user_id="user-1", + app_id="app-1", + workflow_id="wf-1", + workflow_execution_id="run-1", + ) + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0) + app_entity = SimpleNamespace( + task_id="task-1", + app_config=SimpleNamespace(app_id="app-1", tenant_id="tenant-1"), + invoke_from=InvokeFrom.EXPLORE, + files=[], + inputs={}, + workflow_execution_id="run-1", + call_depth=0, + ) + account = SimpleNamespace(id="acc-1", name="tester", email="tester@example.com") + return WorkflowResponseConverter( + application_generate_entity=app_entity, + user=account, + system_variables=system_variables, + ) + + +def test_human_input_form_filled_stream_response_contains_rendered_content(): + converter = _build_converter() + converter.workflow_start_to_stream_response( + task_id="task-1", + workflow_run_id="run-1", + workflow_id="wf-1", + reason=WorkflowStartReason.INITIAL, + ) + + queue_event = QueueHumanInputFormFilledEvent( + node_execution_id="exec-1", + node_id="node-1", + node_type="human-input", + node_title="Human Input", + rendered_content="# Title\nvalue", + action_id="Approve", + action_text="Approve", + ) + + resp = converter.human_input_form_filled_to_stream_response(event=queue_event, task_id="task-1") + + assert resp.workflow_run_id == "run-1" + assert resp.data.node_id == "node-1" + assert resp.data.node_title == "Human Input" + assert resp.data.rendered_content.startswith("# Title") + assert resp.data.action_id == "Approve" + + +def 
test_human_input_form_timeout_stream_response_contains_timeout_metadata(): + converter = _build_converter() + converter.workflow_start_to_stream_response( + task_id="task-1", + workflow_run_id="run-1", + workflow_id="wf-1", + reason=WorkflowStartReason.INITIAL, + ) + + queue_event = QueueHumanInputFormTimeoutEvent( + node_id="node-1", + node_type="human-input", + node_title="Human Input", + expiration_time=datetime(2025, 1, 1, tzinfo=UTC), + ) + + resp = converter.human_input_form_timeout_to_stream_response(event=queue_event, task_id="task-1") + + assert resp.workflow_run_id == "run-1" + assert resp.data.node_id == "node-1" + assert resp.data.node_title == "Human Input" + assert resp.data.expiration_time == 1735689600 diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py new file mode 100644 index 0000000000..0a9794e41c --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py @@ -0,0 +1,56 @@ +from types import SimpleNamespace + +from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.workflow_start_reason import WorkflowStartReason +from core.workflow.runtime import GraphRuntimeState, VariablePool +from core.workflow.system_variable import SystemVariable + + +def _build_converter() -> WorkflowResponseConverter: + """Construct a minimal WorkflowResponseConverter for testing.""" + system_variables = SystemVariable( + files=[], + user_id="user-1", + app_id="app-1", + workflow_id="wf-1", + workflow_execution_id="run-1", + ) + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0) + app_entity = SimpleNamespace( + task_id="task-1", + app_config=SimpleNamespace(app_id="app-1", tenant_id="tenant-1"), + invoke_from=InvokeFrom.EXPLORE, + 
files=[], + inputs={}, + workflow_execution_id="run-1", + call_depth=0, + ) + account = SimpleNamespace(id="acc-1", name="tester", email="tester@example.com") + return WorkflowResponseConverter( + application_generate_entity=app_entity, + user=account, + system_variables=system_variables, + ) + + +def test_workflow_start_stream_response_carries_resumption_reason(): + converter = _build_converter() + resp = converter.workflow_start_to_stream_response( + task_id="task-1", + workflow_run_id="run-1", + workflow_id="wf-1", + reason=WorkflowStartReason.RESUMPTION, + ) + assert resp.data.reason is WorkflowStartReason.RESUMPTION + + +def test_workflow_start_stream_response_carries_initial_reason(): + converter = _build_converter() + resp = converter.workflow_start_to_stream_response( + task_id="task-1", + workflow_run_id="run-1", + workflow_id="wf-1", + reason=WorkflowStartReason.INITIAL, + ) + assert resp.data.reason is WorkflowStartReason.INITIAL diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py index 6b40bf462b..d25bff92dc 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py @@ -23,6 +23,7 @@ from core.app.entities.queue_entities import ( QueueNodeStartedEvent, QueueNodeSucceededEvent, ) +from core.workflow.entities.workflow_start_reason import WorkflowStartReason from core.workflow.enums import NodeType from core.workflow.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now @@ -124,7 +125,12 @@ class TestWorkflowResponseConverter: original_data = {"large_field": "x" * 10000, "metadata": "info"} truncated_data = {"large_field": "[TRUNCATED]", "metadata": "info"} - converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", 
workflow_id="wf-id") + converter.workflow_start_to_stream_response( + task_id="bootstrap", + workflow_run_id="run-id", + workflow_id="wf-id", + reason=WorkflowStartReason.INITIAL, + ) start_event = self.create_node_started_event() converter.workflow_node_start_to_stream_response( event=start_event, @@ -160,7 +166,12 @@ class TestWorkflowResponseConverter: original_data = {"small": "data"} - converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id") + converter.workflow_start_to_stream_response( + task_id="bootstrap", + workflow_run_id="run-id", + workflow_id="wf-id", + reason=WorkflowStartReason.INITIAL, + ) start_event = self.create_node_started_event() converter.workflow_node_start_to_stream_response( event=start_event, @@ -191,7 +202,12 @@ class TestWorkflowResponseConverter: """Test node finish response when process_data is None.""" converter = self.create_workflow_response_converter() - converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id") + converter.workflow_start_to_stream_response( + task_id="bootstrap", + workflow_run_id="run-id", + workflow_id="wf-id", + reason=WorkflowStartReason.INITIAL, + ) start_event = self.create_node_started_event() converter.workflow_node_start_to_stream_response( event=start_event, @@ -225,7 +241,12 @@ class TestWorkflowResponseConverter: original_data = {"large_field": "x" * 10000, "metadata": "info"} truncated_data = {"large_field": "[TRUNCATED]", "metadata": "info"} - converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id") + converter.workflow_start_to_stream_response( + task_id="bootstrap", + workflow_run_id="run-id", + workflow_id="wf-id", + reason=WorkflowStartReason.INITIAL, + ) start_event = self.create_node_started_event() converter.workflow_node_start_to_stream_response( event=start_event, @@ -261,7 +282,12 @@ class TestWorkflowResponseConverter: 
original_data = {"small": "data"} - converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id") + converter.workflow_start_to_stream_response( + task_id="bootstrap", + workflow_run_id="run-id", + workflow_id="wf-id", + reason=WorkflowStartReason.INITIAL, + ) start_event = self.create_node_started_event() converter.workflow_node_start_to_stream_response( event=start_event, @@ -400,6 +426,7 @@ class TestWorkflowResponseConverterServiceApiTruncation: task_id="test-task-id", workflow_run_id="test-workflow-run-id", workflow_id="test-workflow-id", + reason=WorkflowStartReason.INITIAL, ) return converter diff --git a/api/tests/unit_tests/core/app/apps/test_advanced_chat_app_generator.py b/api/tests/unit_tests/core/app/apps/test_advanced_chat_app_generator.py new file mode 100644 index 0000000000..f0d9afc0db --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/test_advanced_chat_app_generator.py @@ -0,0 +1,139 @@ +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig +from core.app.apps import message_based_app_generator +from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom +from core.app.task_pipeline import message_cycle_manager +from core.app.task_pipeline.message_cycle_manager import MessageCycleManager +from models.model import AppMode, Conversation, Message + + +def _make_app_config() -> WorkflowUIBasedAppConfig: + return WorkflowUIBasedAppConfig( + tenant_id="tenant-id", + app_id="app-id", + app_mode=AppMode.ADVANCED_CHAT, + workflow_id="workflow-id", + additional_features=AppAdditionalFeatures(), + variables=[], + ) + + +def _make_generate_entity(app_config: WorkflowUIBasedAppConfig) -> AdvancedChatAppGenerateEntity: + return 
AdvancedChatAppGenerateEntity( + task_id="task-id", + app_config=app_config, + file_upload_config=None, + conversation_id=None, + inputs={}, + query="hello", + files=[], + parent_message_id=None, + user_id="user-id", + stream=True, + invoke_from=InvokeFrom.WEB_APP, + extras={}, + workflow_run_id="workflow-run-id", + ) + + +@pytest.fixture(autouse=True) +def _mock_db_session(monkeypatch): + session = MagicMock() + + def refresh_side_effect(obj): + if isinstance(obj, Conversation) and obj.id is None: + obj.id = "generated-conversation-id" + if isinstance(obj, Message) and obj.id is None: + obj.id = "generated-message-id" + + session.refresh.side_effect = refresh_side_effect + session.add.return_value = None + session.commit.return_value = None + + monkeypatch.setattr(message_based_app_generator, "db", SimpleNamespace(session=session)) + return session + + +def test_init_generate_records_sets_conversation_metadata(): + app_config = _make_app_config() + entity = _make_generate_entity(app_config) + + generator = AdvancedChatAppGenerator() + + conversation, _ = generator._init_generate_records(entity, conversation=None) + + assert entity.conversation_id == "generated-conversation-id" + assert conversation.id == "generated-conversation-id" + assert entity.is_new_conversation is True + + +def test_init_generate_records_marks_existing_conversation(): + app_config = _make_app_config() + entity = _make_generate_entity(app_config) + + existing_conversation = Conversation( + app_id=app_config.app_id, + app_model_config_id=None, + model_provider=None, + override_model_configs=None, + model_id=None, + mode=app_config.app_mode.value, + name="existing", + inputs={}, + introduction="", + system_instruction="", + system_instruction_tokens=0, + status="normal", + invoke_from=InvokeFrom.WEB_APP.value, + from_source="api", + from_end_user_id="user-id", + from_account_id=None, + ) + existing_conversation.id = "existing-conversation-id" + + generator = AdvancedChatAppGenerator() + + 
conversation, _ = generator._init_generate_records(entity, conversation=existing_conversation) + + assert entity.conversation_id == "existing-conversation-id" + assert conversation is existing_conversation + assert entity.is_new_conversation is False + + +def test_message_cycle_manager_uses_new_conversation_flag(monkeypatch): + app_config = _make_app_config() + entity = _make_generate_entity(app_config) + entity.conversation_id = "existing-conversation-id" + entity.is_new_conversation = True + entity.extras = {"auto_generate_conversation_name": True} + + captured = {} + + class DummyThread: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.started = False + + def start(self): + self.started = True + + def fake_thread(**kwargs): + thread = DummyThread(**kwargs) + captured["thread"] = thread + return thread + + monkeypatch.setattr(message_cycle_manager, "Thread", fake_thread) + + manager = MessageCycleManager(application_generate_entity=entity, task_state=MagicMock()) + thread = manager.generate_conversation_name(conversation_id="existing-conversation-id", query="hello") + + assert thread is captured["thread"] + assert thread.started is True + assert entity.is_new_conversation is False diff --git a/api/tests/unit_tests/core/app/apps/test_message_based_app_generator.py b/api/tests/unit_tests/core/app/apps/test_message_based_app_generator.py new file mode 100644 index 0000000000..87b8dc51e7 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/test_message_based_app_generator.py @@ -0,0 +1,127 @@ +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from core.app.app_config.entities import ( + AppAdditionalFeatures, + EasyUIBasedAppConfig, + EasyUIBasedAppModelConfigFrom, + ModelConfigEntity, + PromptTemplateEntity, +) +from core.app.apps import message_based_app_generator +from core.app.apps.message_based_app_generator import MessageBasedAppGenerator +from 
core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom +from models.model import AppMode, Conversation, Message + + +class DummyModelConf: + def __init__(self, provider: str = "mock-provider", model: str = "mock-model") -> None: + self.provider = provider + self.model = model + + +class DummyCompletionGenerateEntity: + __slots__ = ("app_config", "invoke_from", "user_id", "query", "inputs", "files", "model_conf") + app_config: EasyUIBasedAppConfig + invoke_from: InvokeFrom + user_id: str + query: str + inputs: dict + files: list + model_conf: DummyModelConf + + def __init__(self, app_config: EasyUIBasedAppConfig) -> None: + self.app_config = app_config + self.invoke_from = InvokeFrom.WEB_APP + self.user_id = "user-id" + self.query = "hello" + self.inputs = {} + self.files = [] + self.model_conf = DummyModelConf() + + +def _make_app_config(app_mode: AppMode) -> EasyUIBasedAppConfig: + return EasyUIBasedAppConfig( + tenant_id="tenant-id", + app_id="app-id", + app_mode=app_mode, + app_model_config_from=EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG, + app_model_config_id="model-config-id", + app_model_config_dict={}, + model=ModelConfigEntity(provider="mock-provider", model="mock-model", mode="chat"), + prompt_template=PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.SIMPLE, + simple_prompt_template="Hello", + ), + additional_features=AppAdditionalFeatures(), + variables=[], + ) + + +def _make_chat_generate_entity(app_config: EasyUIBasedAppConfig) -> ChatAppGenerateEntity: + return ChatAppGenerateEntity.model_construct( + task_id="task-id", + app_config=app_config, + model_conf=DummyModelConf(), + file_upload_config=None, + conversation_id=None, + inputs={}, + query="hello", + files=[], + parent_message_id=None, + user_id="user-id", + stream=False, + invoke_from=InvokeFrom.WEB_APP, + extras={}, + call_depth=0, + trace_manager=None, + ) + + +@pytest.fixture(autouse=True) +def _mock_db_session(monkeypatch): + session = 
MagicMock() + + def refresh_side_effect(obj): + if isinstance(obj, Conversation) and obj.id is None: + obj.id = "generated-conversation-id" + if isinstance(obj, Message) and obj.id is None: + obj.id = "generated-message-id" + + session.refresh.side_effect = refresh_side_effect + session.add.return_value = None + session.commit.return_value = None + + monkeypatch.setattr(message_based_app_generator, "db", SimpleNamespace(session=session)) + return session + + +def test_init_generate_records_skips_conversation_fields_for_non_conversation_entity(): + app_config = _make_app_config(AppMode.COMPLETION) + entity = DummyCompletionGenerateEntity(app_config=app_config) + + generator = MessageBasedAppGenerator() + + conversation, message = generator._init_generate_records(entity, conversation=None) + + assert conversation.id == "generated-conversation-id" + assert message.id == "generated-message-id" + assert hasattr(entity, "conversation_id") is False + assert hasattr(entity, "is_new_conversation") is False + + +def test_init_generate_records_sets_conversation_fields_for_chat_entity(): + app_config = _make_app_config(AppMode.CHAT) + entity = _make_chat_generate_entity(app_config) + + generator = MessageBasedAppGenerator() + + conversation, _ = generator._init_generate_records(entity, conversation=None) + + assert entity.conversation_id == "generated-conversation-id" + assert entity.is_new_conversation is True + assert conversation.id == "generated-conversation-id" diff --git a/api/tests/unit_tests/core/app/apps/test_pause_resume.py b/api/tests/unit_tests/core/app/apps/test_pause_resume.py new file mode 100644 index 0000000000..97c993928e --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/test_pause_resume.py @@ -0,0 +1,287 @@ +import sys +import time +from pathlib import Path +from types import ModuleType, SimpleNamespace +from typing import Any + +API_DIR = str(Path(__file__).resolve().parents[5]) +if API_DIR not in sys.path: + sys.path.insert(0, API_DIR) + +import 
core.workflow.nodes.human_input.entities # noqa: F401 +from core.app.apps.advanced_chat import app_generator as adv_app_gen_module +from core.app.apps.workflow import app_generator as wf_app_gen_module +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.workflow.node_factory import DifyNodeFactory +from core.workflow.entities import GraphInitParams +from core.workflow.entities.pause_reason import SchedulingPause +from core.workflow.entities.workflow_start_reason import WorkflowStartReason +from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus +from core.workflow.graph import Graph +from core.workflow.graph_engine import GraphEngine +from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel +from core.workflow.graph_events import ( + GraphEngineEvent, + GraphRunPausedEvent, + GraphRunStartedEvent, + GraphRunSucceededEvent, + NodeRunSucceededEvent, +) +from core.workflow.node_events import NodeRunResult, PauseRequestedEvent +from core.workflow.nodes.base.entities import BaseNodeData, OutputVariableEntity, RetryConfig +from core.workflow.nodes.base.node import Node +from core.workflow.nodes.end.entities import EndNodeData +from core.workflow.nodes.start.entities import StartNodeData +from core.workflow.runtime import GraphRuntimeState, VariablePool +from core.workflow.system_variable import SystemVariable + +if "core.ops.ops_trace_manager" not in sys.modules: + ops_stub = ModuleType("core.ops.ops_trace_manager") + + class _StubTraceQueueManager: + def __init__(self, *_, **__): + pass + + ops_stub.TraceQueueManager = _StubTraceQueueManager + sys.modules["core.ops.ops_trace_manager"] = ops_stub + + +class _StubToolNodeData(BaseNodeData): + pause_on: bool = False + + +class _StubToolNode(Node[_StubToolNodeData]): + node_type = NodeType.TOOL + + @classmethod + def version(cls) -> str: + return "1" + + def init_node_data(self, data): + self._node_data = _StubToolNodeData.model_validate(data) + + def 
_get_error_strategy(self): + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self): + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + + def _run(self): + if self.node_data.pause_on: + yield PauseRequestedEvent(reason=SchedulingPause(message="test pause")) + return + + result = NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + outputs={"value": f"{self.id}-done"}, + ) + yield self._convert_node_run_result_to_graph_node_event(result) + + +def _patch_tool_node(mocker): + original_create_node = DifyNodeFactory.create_node + + def _patched_create_node(self, node_config: dict[str, object]) -> Node: + node_data = node_config.get("data", {}) + if isinstance(node_data, dict) and node_data.get("type") == NodeType.TOOL.value: + return _StubToolNode( + id=str(node_config["id"]), + config=node_config, + graph_init_params=self.graph_init_params, + graph_runtime_state=self.graph_runtime_state, + ) + return original_create_node(self, node_config) + + mocker.patch.object(DifyNodeFactory, "create_node", _patched_create_node) + + +def _node_data(node_type: NodeType, data: BaseNodeData) -> dict[str, object]: + node_data = data.model_dump() + node_data["type"] = node_type.value + return node_data + + +def _build_graph_config(*, pause_on: str | None) -> dict[str, object]: + start_data = StartNodeData(title="start", variables=[]) + tool_data_a = _StubToolNodeData(title="tool", pause_on=pause_on == "tool_a") + tool_data_b = _StubToolNodeData(title="tool", pause_on=pause_on == "tool_b") + tool_data_c = _StubToolNodeData(title="tool", pause_on=pause_on == "tool_c") + end_data = EndNodeData( + title="end", + outputs=[OutputVariableEntity(variable="result", 
value_selector=["tool_c", "value"])], + desc=None, + ) + + nodes = [ + {"id": "start", "data": _node_data(NodeType.START, start_data)}, + {"id": "tool_a", "data": _node_data(NodeType.TOOL, tool_data_a)}, + {"id": "tool_b", "data": _node_data(NodeType.TOOL, tool_data_b)}, + {"id": "tool_c", "data": _node_data(NodeType.TOOL, tool_data_c)}, + {"id": "end", "data": _node_data(NodeType.END, end_data)}, + ] + edges = [ + {"source": "start", "target": "tool_a"}, + {"source": "tool_a", "target": "tool_b"}, + {"source": "tool_b", "target": "tool_c"}, + {"source": "tool_c", "target": "end"}, + ] + return {"nodes": nodes, "edges": edges} + + +def _build_graph(runtime_state: GraphRuntimeState, *, pause_on: str | None) -> Graph: + graph_config = _build_graph_config(pause_on=pause_on) + params = GraphInitParams( + tenant_id="tenant", + app_id="app", + workflow_id="workflow", + graph_config=graph_config, + user_id="user", + user_from="account", + invoke_from="service-api", + call_depth=0, + ) + + node_factory = DifyNodeFactory( + graph_init_params=params, + graph_runtime_state=runtime_state, + ) + + return Graph.init(graph_config=graph_config, node_factory=node_factory) + + +def _build_runtime_state(run_id: str) -> GraphRuntimeState: + variable_pool = VariablePool( + system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"), + user_inputs={}, + conversation_variables=[], + ) + variable_pool.system_variables.workflow_execution_id = run_id + return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + + +def _run_with_optional_pause(runtime_state: GraphRuntimeState, *, pause_on: str | None) -> list[GraphEngineEvent]: + command_channel = InMemoryChannel() + graph = _build_graph(runtime_state, pause_on=pause_on) + engine = GraphEngine( + workflow_id="workflow", + graph=graph, + graph_runtime_state=runtime_state, + command_channel=command_channel, + ) + + events: list[GraphEngineEvent] = [] + for event in engine.run(): + 
events.append(event) + return events + + +def _node_successes(events: list[GraphEngineEvent]) -> list[str]: + return [evt.node_id for evt in events if isinstance(evt, NodeRunSucceededEvent)] + + +def test_workflow_app_pause_resume_matches_baseline(mocker): + _patch_tool_node(mocker) + + baseline_state = _build_runtime_state("baseline") + baseline_events = _run_with_optional_pause(baseline_state, pause_on=None) + assert isinstance(baseline_events[-1], GraphRunSucceededEvent) + baseline_nodes = _node_successes(baseline_events) + baseline_outputs = baseline_state.outputs + + paused_state = _build_runtime_state("paused-run") + paused_events = _run_with_optional_pause(paused_state, pause_on="tool_a") + assert isinstance(paused_events[-1], GraphRunPausedEvent) + paused_nodes = _node_successes(paused_events) + snapshot = paused_state.dumps() + + resumed_state = GraphRuntimeState.from_snapshot(snapshot) + + generator = wf_app_gen_module.WorkflowAppGenerator() + + def _fake_generate(**kwargs): + state: GraphRuntimeState = kwargs["graph_runtime_state"] + events = _run_with_optional_pause(state, pause_on=None) + return _node_successes(events) + + mocker.patch.object(generator, "_generate", side_effect=_fake_generate) + + resumed_nodes = generator.resume( + app_model=SimpleNamespace(mode="workflow"), + workflow=SimpleNamespace(), + user=SimpleNamespace(), + application_generate_entity=SimpleNamespace(stream=False, invoke_from=InvokeFrom.SERVICE_API), + graph_runtime_state=resumed_state, + workflow_execution_repository=SimpleNamespace(), + workflow_node_execution_repository=SimpleNamespace(), + ) + + assert paused_nodes + resumed_nodes == baseline_nodes + assert resumed_state.outputs == baseline_outputs + + +def test_advanced_chat_pause_resume_matches_baseline(mocker): + _patch_tool_node(mocker) + + baseline_state = _build_runtime_state("adv-baseline") + baseline_events = _run_with_optional_pause(baseline_state, pause_on=None) + assert isinstance(baseline_events[-1], 
GraphRunSucceededEvent) + baseline_nodes = _node_successes(baseline_events) + baseline_outputs = baseline_state.outputs + + paused_state = _build_runtime_state("adv-paused") + paused_events = _run_with_optional_pause(paused_state, pause_on="tool_a") + assert isinstance(paused_events[-1], GraphRunPausedEvent) + paused_nodes = _node_successes(paused_events) + snapshot = paused_state.dumps() + + resumed_state = GraphRuntimeState.from_snapshot(snapshot) + + generator = adv_app_gen_module.AdvancedChatAppGenerator() + + def _fake_generate(**kwargs): + state: GraphRuntimeState = kwargs["graph_runtime_state"] + events = _run_with_optional_pause(state, pause_on=None) + return _node_successes(events) + + mocker.patch.object(generator, "_generate", side_effect=_fake_generate) + + resumed_nodes = generator.resume( + app_model=SimpleNamespace(mode="workflow"), + workflow=SimpleNamespace(), + user=SimpleNamespace(), + conversation=SimpleNamespace(id="conv"), + message=SimpleNamespace(id="msg"), + application_generate_entity=SimpleNamespace(stream=False, invoke_from=InvokeFrom.SERVICE_API), + workflow_execution_repository=SimpleNamespace(), + workflow_node_execution_repository=SimpleNamespace(), + graph_runtime_state=resumed_state, + ) + + assert paused_nodes + resumed_nodes == baseline_nodes + assert resumed_state.outputs == baseline_outputs + + +def test_resume_emits_resumption_start_reason(mocker) -> None: + _patch_tool_node(mocker) + + paused_state = _build_runtime_state("resume-reason") + paused_events = _run_with_optional_pause(paused_state, pause_on="tool_a") + initial_start = next(event for event in paused_events if isinstance(event, GraphRunStartedEvent)) + assert initial_start.reason == WorkflowStartReason.INITIAL + + resumed_state = GraphRuntimeState.from_snapshot(paused_state.dumps()) + resumed_events = _run_with_optional_pause(resumed_state, pause_on=None) + resume_start = next(event for event in resumed_events if isinstance(event, GraphRunStartedEvent)) + assert 
resume_start.reason == WorkflowStartReason.RESUMPTION diff --git a/api/tests/unit_tests/core/app/apps/test_streaming_utils.py b/api/tests/unit_tests/core/app/apps/test_streaming_utils.py new file mode 100644 index 0000000000..7b5447c01e --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/test_streaming_utils.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +import json +import queue + +import pytest + +from core.app.apps.message_based_app_generator import MessageBasedAppGenerator +from core.app.entities.task_entities import StreamEvent +from models.model import AppMode + + +class FakeSubscription: + def __init__(self, message_queue: queue.Queue[bytes], state: dict[str, bool]) -> None: + self._queue = message_queue + self._state = state + self._closed = False + + def __enter__(self): + self._state["subscribed"] = True + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.close() + + def close(self) -> None: + self._closed = True + + def receive(self, timeout: float | None = 0.1) -> bytes | None: + if self._closed: + return None + try: + if timeout is None: + return self._queue.get() + return self._queue.get(timeout=timeout) + except queue.Empty: + return None + + +class FakeTopic: + def __init__(self) -> None: + self._queue: queue.Queue[bytes] = queue.Queue() + self._state = {"subscribed": False} + + def subscribe(self) -> FakeSubscription: + return FakeSubscription(self._queue, self._state) + + def publish(self, payload: bytes) -> None: + self._queue.put(payload) + + @property + def subscribed(self) -> bool: + return self._state["subscribed"] + + +def test_retrieve_events_calls_on_subscribe_after_subscription(monkeypatch): + topic = FakeTopic() + + def fake_get_response_topic(cls, app_mode, workflow_run_id): + return topic + + monkeypatch.setattr(MessageBasedAppGenerator, "get_response_topic", classmethod(fake_get_response_topic)) + + def on_subscribe() -> None: + assert topic.subscribed is True + event = {"event": 
StreamEvent.WORKFLOW_FINISHED.value} + topic.publish(json.dumps(event).encode()) + + generator = MessageBasedAppGenerator.retrieve_events( + AppMode.WORKFLOW, + "workflow-run-id", + idle_timeout=0.5, + on_subscribe=on_subscribe, + ) + + assert next(generator) == StreamEvent.PING.value + event = next(generator) + assert event["event"] == StreamEvent.WORKFLOW_FINISHED.value + with pytest.raises(StopIteration): + next(generator) diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_generator.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_generator.py index 83ac3a5591..7e8367c6c4 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_generator.py @@ -1,3 +1,6 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock + from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY, WorkflowAppGenerator @@ -17,3 +20,193 @@ def test_should_prepare_user_inputs_keeps_validation_when_flag_false(): args = {"inputs": {}, SKIP_PREPARE_USER_INPUTS_KEY: False} assert WorkflowAppGenerator()._should_prepare_user_inputs(args) + + +def test_resume_delegates_to_generate(mocker): + generator = WorkflowAppGenerator() + mock_generate = mocker.patch.object(generator, "_generate", return_value="ok") + + application_generate_entity = SimpleNamespace(stream=False, invoke_from="debugger") + runtime_state = MagicMock(name="runtime-state") + pause_config = MagicMock(name="pause-config") + + result = generator.resume( + app_model=MagicMock(), + workflow=MagicMock(), + user=MagicMock(), + application_generate_entity=application_generate_entity, + graph_runtime_state=runtime_state, + workflow_execution_repository=MagicMock(), + workflow_node_execution_repository=MagicMock(), + graph_engine_layers=("layer",), + pause_state_config=pause_config, + variable_loader=MagicMock(), + ) + + assert result == "ok" + mock_generate.assert_called_once() + kwargs = 
mock_generate.call_args.kwargs + assert kwargs["graph_runtime_state"] is runtime_state + assert kwargs["pause_state_config"] is pause_config + assert kwargs["streaming"] is False + assert kwargs["invoke_from"] == "debugger" + + +def test_generate_appends_pause_layer_and_forwards_state(mocker): + generator = WorkflowAppGenerator() + + mock_queue_manager = MagicMock() + mocker.patch("core.app.apps.workflow.app_generator.WorkflowAppQueueManager", return_value=mock_queue_manager) + + fake_current_app = MagicMock() + fake_current_app._get_current_object.return_value = MagicMock() + mocker.patch("core.app.apps.workflow.app_generator.current_app", fake_current_app) + + mocker.patch( + "core.app.apps.workflow.app_generator.WorkflowAppGenerateResponseConverter.convert", + return_value="converted", + ) + mocker.patch.object(WorkflowAppGenerator, "_handle_response", return_value="response") + mocker.patch.object(WorkflowAppGenerator, "_get_draft_var_saver_factory", return_value=MagicMock()) + + pause_layer = MagicMock(name="pause-layer") + mocker.patch( + "core.app.apps.workflow.app_generator.PauseStatePersistenceLayer", + return_value=pause_layer, + ) + + dummy_session = MagicMock() + dummy_session.close = MagicMock() + mocker.patch("core.app.apps.workflow.app_generator.db.session", dummy_session) + + worker_kwargs: dict[str, object] = {} + + class DummyThread: + def __init__(self, target, kwargs): + worker_kwargs["target"] = target + worker_kwargs["kwargs"] = kwargs + + def start(self): + return None + + mocker.patch("core.app.apps.workflow.app_generator.threading.Thread", DummyThread) + + app_model = SimpleNamespace(mode="workflow") + app_config = SimpleNamespace(app_id="app", tenant_id="tenant", workflow_id="wf") + application_generate_entity = SimpleNamespace( + task_id="task", + user_id="user", + invoke_from="service-api", + app_config=app_config, + files=[], + stream=True, + workflow_execution_id="run", + ) + + graph_runtime_state = MagicMock() + + result = 
generator._generate( + app_model=app_model, + workflow=MagicMock(), + user=MagicMock(), + application_generate_entity=application_generate_entity, + invoke_from="service-api", + workflow_execution_repository=MagicMock(), + workflow_node_execution_repository=MagicMock(), + streaming=True, + graph_engine_layers=("base-layer",), + graph_runtime_state=graph_runtime_state, + pause_state_config=SimpleNamespace(session_factory=MagicMock(), state_owner_user_id="owner"), + ) + + assert result == "converted" + assert worker_kwargs["kwargs"]["graph_engine_layers"] == ("base-layer", pause_layer) + assert worker_kwargs["kwargs"]["graph_runtime_state"] is graph_runtime_state + + +def test_resume_path_runs_worker_with_runtime_state(mocker): + generator = WorkflowAppGenerator() + runtime_state = MagicMock(name="runtime-state") + + pause_layer = MagicMock(name="pause-layer") + mocker.patch("core.app.apps.workflow.app_generator.PauseStatePersistenceLayer", return_value=pause_layer) + + queue_manager = MagicMock() + mocker.patch("core.app.apps.workflow.app_generator.WorkflowAppQueueManager", return_value=queue_manager) + + mocker.patch.object(generator, "_handle_response", return_value="raw-response") + mocker.patch( + "core.app.apps.workflow.app_generator.WorkflowAppGenerateResponseConverter.convert", + side_effect=lambda response, invoke_from: response, + ) + + fake_db = SimpleNamespace(session=MagicMock(), engine=MagicMock()) + mocker.patch("core.app.apps.workflow.app_generator.db", fake_db) + + workflow = SimpleNamespace( + id="workflow", tenant_id="tenant", app_id="app", graph_dict={}, type="workflow", version="1" + ) + end_user = SimpleNamespace(session_id="end-user-session") + app_record = SimpleNamespace(id="app") + + session = MagicMock() + session.__enter__.return_value = session + session.__exit__.return_value = False + session.scalar.side_effect = [workflow, end_user, app_record] + mocker.patch("core.app.apps.workflow.app_generator.session_factory", return_value=session) 
+ + runner_instance = MagicMock() + + def runner_ctor(**kwargs): + assert kwargs["graph_runtime_state"] is runtime_state + return runner_instance + + mocker.patch("core.app.apps.workflow.app_generator.WorkflowAppRunner", side_effect=runner_ctor) + + class ImmediateThread: + def __init__(self, target, kwargs): + target(**kwargs) + + def start(self): + return None + + mocker.patch("core.app.apps.workflow.app_generator.threading.Thread", ImmediateThread) + + mocker.patch( + "core.app.apps.workflow.app_generator.DifyCoreRepositoryFactory.create_workflow_execution_repository", + return_value=MagicMock(), + ) + mocker.patch( + "core.app.apps.workflow.app_generator.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", + return_value=MagicMock(), + ) + + pause_config = SimpleNamespace(session_factory=MagicMock(), state_owner_user_id="owner") + + app_model = SimpleNamespace(mode="workflow") + app_config = SimpleNamespace(app_id="app", tenant_id="tenant", workflow_id="workflow") + application_generate_entity = SimpleNamespace( + task_id="task", + user_id="user", + invoke_from="service-api", + app_config=app_config, + files=[], + stream=True, + workflow_execution_id="run", + trace_manager=MagicMock(), + ) + + result = generator.resume( + app_model=app_model, + workflow=workflow, + user=MagicMock(), + application_generate_entity=application_generate_entity, + graph_runtime_state=runtime_state, + workflow_execution_repository=MagicMock(), + workflow_node_execution_repository=MagicMock(), + pause_state_config=pause_config, + ) + + assert result == "raw-response" + runner_instance.run.assert_called_once() + queue_manager.graph_runtime_state = runtime_state diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py new file mode 100644 index 0000000000..f4efb240c0 --- /dev/null +++ 
b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py @@ -0,0 +1,59 @@ +from unittest.mock import MagicMock + +import pytest + +from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner +from core.app.entities.queue_entities import QueueWorkflowPausedEvent +from core.workflow.entities.pause_reason import HumanInputRequired +from core.workflow.graph_events.graph import GraphRunPausedEvent + + +class _DummyQueueManager: + def __init__(self): + self.published = [] + + def publish(self, event, _from): + self.published.append(event) + + +class _DummyRuntimeState: + def get_paused_nodes(self): + return ["node-1"] + + +class _DummyGraphEngine: + def __init__(self): + self.graph_runtime_state = _DummyRuntimeState() + + +class _DummyWorkflowEntry: + def __init__(self): + self.graph_engine = _DummyGraphEngine() + + +def test_handle_pause_event_enqueues_email_task(monkeypatch: pytest.MonkeyPatch): + queue_manager = _DummyQueueManager() + runner = WorkflowBasedAppRunner(queue_manager=queue_manager, app_id="app-id") + workflow_entry = _DummyWorkflowEntry() + + reason = HumanInputRequired( + form_id="form-123", + form_content="content", + inputs=[], + actions=[], + node_id="node-1", + node_title="Review", + ) + event = GraphRunPausedEvent(reasons=[reason], outputs={}) + + email_task = MagicMock() + monkeypatch.setattr("core.app.apps.workflow_app_runner.dispatch_human_input_email_task", email_task) + + runner._handle_event(workflow_entry, event) + + email_task.apply_async.assert_called_once() + kwargs = email_task.apply_async.call_args.kwargs["kwargs"] + assert kwargs["form_id"] == "form-123" + assert kwargs["node_title"] == "Review" + + assert any(isinstance(evt, QueueWorkflowPausedEvent) for evt in queue_manager.published) diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py b/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py new file mode 100644 index 0000000000..c30b925d88 --- /dev/null +++ 
b/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py @@ -0,0 +1,183 @@ +from datetime import UTC, datetime +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from core.app.apps.common import workflow_response_converter +from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter +from core.app.apps.workflow.app_runner import WorkflowAppRunner +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.queue_entities import QueueWorkflowPausedEvent +from core.app.entities.task_entities import HumanInputRequiredResponse, WorkflowPauseStreamResponse +from core.workflow.entities.pause_reason import HumanInputRequired +from core.workflow.entities.workflow_start_reason import WorkflowStartReason +from core.workflow.graph_events.graph import GraphRunPausedEvent +from core.workflow.nodes.human_input.entities import FormInput, UserAction +from core.workflow.nodes.human_input.enums import FormInputType +from core.workflow.system_variable import SystemVariable +from models.account import Account + + +class _RecordingWorkflowAppRunner(WorkflowAppRunner): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.published_events = [] + + def _publish_event(self, event): + self.published_events.append(event) + + +class _FakeRuntimeState: + def get_paused_nodes(self): + return ["node-pause-1"] + + +def _build_runner(): + app_entity = SimpleNamespace( + app_config=SimpleNamespace(app_id="app-id"), + inputs={}, + files=[], + invoke_from=InvokeFrom.SERVICE_API, + single_iteration_run=None, + single_loop_run=None, + workflow_execution_id="run-id", + user_id="user-id", + ) + workflow = SimpleNamespace( + graph_dict={}, + tenant_id="tenant-id", + environment_variables={}, + id="workflow-id", + ) + queue_manager = SimpleNamespace(publish=lambda event, pub_from: None) + return _RecordingWorkflowAppRunner( + application_generate_entity=app_entity, + 
queue_manager=queue_manager, + variable_loader=MagicMock(), + workflow=workflow, + system_user_id="sys-user", + root_node_id=None, + workflow_execution_repository=MagicMock(), + workflow_node_execution_repository=MagicMock(), + graph_engine_layers=(), + graph_runtime_state=None, + ) + + +def test_graph_run_paused_event_emits_queue_pause_event(): + runner = _build_runner() + reason = HumanInputRequired( + form_id="form-1", + form_content="content", + inputs=[], + actions=[], + node_id="node-human", + node_title="Human Step", + form_token="tok", + ) + event = GraphRunPausedEvent(reasons=[reason], outputs={"foo": "bar"}) + workflow_entry = SimpleNamespace( + graph_engine=SimpleNamespace(graph_runtime_state=_FakeRuntimeState()), + ) + + runner._handle_event(workflow_entry, event) + + assert len(runner.published_events) == 1 + queue_event = runner.published_events[0] + assert isinstance(queue_event, QueueWorkflowPausedEvent) + assert queue_event.reasons == [reason] + assert queue_event.outputs == {"foo": "bar"} + assert queue_event.paused_nodes == ["node-pause-1"] + + +def _build_converter(): + application_generate_entity = SimpleNamespace( + inputs={}, + files=[], + invoke_from=InvokeFrom.SERVICE_API, + app_config=SimpleNamespace(app_id="app-id", tenant_id="tenant-id"), + ) + system_variables = SystemVariable( + user_id="user", + app_id="app-id", + workflow_id="workflow-id", + workflow_execution_id="run-id", + ) + user = MagicMock(spec=Account) + user.id = "account-id" + user.name = "Tester" + user.email = "tester@example.com" + return WorkflowResponseConverter( + application_generate_entity=application_generate_entity, + user=user, + system_variables=system_variables, + ) + + +def test_queue_workflow_paused_event_to_stream_responses(monkeypatch: pytest.MonkeyPatch): + converter = _build_converter() + converter.workflow_start_to_stream_response( + task_id="task", + workflow_run_id="run-id", + workflow_id="workflow-id", + reason=WorkflowStartReason.INITIAL, + ) + + 
expiration_time = datetime(2024, 1, 1, tzinfo=UTC) + + class _FakeSession: + def execute(self, _stmt): + return [("form-1", expiration_time)] + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(workflow_response_converter, "Session", lambda **_: _FakeSession()) + monkeypatch.setattr(workflow_response_converter, "db", SimpleNamespace(engine=object())) + + reason = HumanInputRequired( + form_id="form-1", + form_content="Rendered", + inputs=[ + FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="field", default=None), + ], + actions=[UserAction(id="approve", title="Approve")], + display_in_ui=True, + node_id="node-id", + node_title="Human Step", + form_token="token", + ) + queue_event = QueueWorkflowPausedEvent( + reasons=[reason], + outputs={"answer": "value"}, + paused_nodes=["node-id"], + ) + + runtime_state = SimpleNamespace(total_tokens=0, node_run_steps=0) + responses = converter.workflow_pause_to_stream_response( + event=queue_event, + task_id="task", + graph_runtime_state=runtime_state, + ) + + assert isinstance(responses[-1], WorkflowPauseStreamResponse) + pause_resp = responses[-1] + assert pause_resp.workflow_run_id == "run-id" + assert pause_resp.data.paused_nodes == ["node-id"] + assert pause_resp.data.outputs == {} + assert pause_resp.data.reasons[0]["form_id"] == "form-1" + assert pause_resp.data.reasons[0]["display_in_ui"] is True + + assert isinstance(responses[0], HumanInputRequiredResponse) + hi_resp = responses[0] + assert hi_resp.data.form_id == "form-1" + assert hi_resp.data.node_id == "node-id" + assert hi_resp.data.node_title == "Human Step" + assert hi_resp.data.inputs[0].output_variable_name == "field" + assert hi_resp.data.actions[0].id == "approve" + assert hi_resp.data.display_in_ui is True + assert hi_resp.data.expiration_time == int(expiration_time.timestamp()) diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py 
b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py new file mode 100644 index 0000000000..32cb1ed47c --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py @@ -0,0 +1,96 @@ +import time +from contextlib import contextmanager +from unittest.mock import MagicMock + +from core.app.app_config.entities import WorkflowUIBasedAppConfig +from core.app.apps.base_app_queue_manager import AppQueueManager +from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline +from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity +from core.app.entities.queue_entities import QueueWorkflowStartedEvent +from core.workflow.entities.workflow_start_reason import WorkflowStartReason +from core.workflow.runtime import GraphRuntimeState, VariablePool +from core.workflow.system_variable import SystemVariable +from models.account import Account +from models.model import AppMode + + +def _build_workflow_app_config() -> WorkflowUIBasedAppConfig: + return WorkflowUIBasedAppConfig( + tenant_id="tenant-id", + app_id="app-id", + app_mode=AppMode.WORKFLOW, + workflow_id="workflow-id", + ) + + +def _build_generate_entity(run_id: str) -> WorkflowAppGenerateEntity: + return WorkflowAppGenerateEntity( + task_id="task-id", + app_config=_build_workflow_app_config(), + inputs={}, + files=[], + user_id="user-id", + stream=False, + invoke_from=InvokeFrom.SERVICE_API, + workflow_execution_id=run_id, + ) + + +def _build_runtime_state(run_id: str) -> GraphRuntimeState: + variable_pool = VariablePool( + system_variables=SystemVariable(workflow_execution_id=run_id), + user_inputs={}, + conversation_variables=[], + ) + return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + + +@contextmanager +def _noop_session(): + yield MagicMock() + + +def _build_pipeline(run_id: str) -> WorkflowAppGenerateTaskPipeline: + queue_manager = MagicMock(spec=AppQueueManager) + 
queue_manager.invoke_from = InvokeFrom.SERVICE_API + queue_manager.graph_runtime_state = _build_runtime_state(run_id) + workflow = MagicMock() + workflow.id = "workflow-id" + workflow.features_dict = {} + user = Account(name="user", email="user@example.com") + pipeline = WorkflowAppGenerateTaskPipeline( + application_generate_entity=_build_generate_entity(run_id), + workflow=workflow, + queue_manager=queue_manager, + user=user, + stream=False, + draft_var_saver_factory=MagicMock(), + ) + pipeline._database_session = _noop_session + return pipeline + + +def test_workflow_app_log_saved_only_on_initial_start() -> None: + run_id = "run-initial" + pipeline = _build_pipeline(run_id) + pipeline._save_workflow_app_log = MagicMock() + + event = QueueWorkflowStartedEvent(reason=WorkflowStartReason.INITIAL) + list(pipeline._handle_workflow_started_event(event)) + + pipeline._save_workflow_app_log.assert_called_once() + _, kwargs = pipeline._save_workflow_app_log.call_args + assert kwargs["workflow_run_id"] == run_id + assert pipeline._workflow_execution_id == run_id + + +def test_workflow_app_log_skipped_on_resumption_start() -> None: + run_id = "run-resume" + pipeline = _build_pipeline(run_id) + pipeline._save_workflow_app_log = MagicMock() + + event = QueueWorkflowStartedEvent(reason=WorkflowStartReason.RESUMPTION) + list(pipeline._handle_workflow_started_event(event)) + + pipeline._save_workflow_app_log.assert_not_called() + assert pipeline._workflow_execution_id == run_id diff --git a/api/tests/unit_tests/core/app/entities/test_app_invoke_entities.py b/api/tests/unit_tests/core/app/entities/test_app_invoke_entities.py new file mode 100644 index 0000000000..86c80985c4 --- /dev/null +++ b/api/tests/unit_tests/core/app/entities/test_app_invoke_entities.py @@ -0,0 +1,143 @@ +import json +from collections.abc import Callable +from dataclasses import dataclass + +import pytest + +from core.app.app_config.entities import WorkflowUIBasedAppConfig +from 
core.app.entities.app_invoke_entities import ( + AdvancedChatAppGenerateEntity, + InvokeFrom, + WorkflowAppGenerateEntity, +) +from core.app.layers.pause_state_persist_layer import ( + WorkflowResumptionContext, + _AdvancedChatAppGenerateEntityWrapper, + _WorkflowGenerateEntityWrapper, +) +from core.ops.ops_trace_manager import TraceQueueManager +from models.model import AppMode + + +class TraceQueueManagerStub(TraceQueueManager): + """Minimal TraceQueueManager stub that avoids Flask dependencies.""" + + def __init__(self): + # Skip parent initialization to avoid starting timers or accessing Flask globals. + pass + + +def _build_workflow_app_config(app_mode: AppMode) -> WorkflowUIBasedAppConfig: + return WorkflowUIBasedAppConfig( + tenant_id="tenant-id", + app_id="app-id", + app_mode=app_mode, + workflow_id=f"{app_mode.value}-workflow-id", + ) + + +def _create_workflow_generate_entity(trace_manager: TraceQueueManager | None = None) -> WorkflowAppGenerateEntity: + return WorkflowAppGenerateEntity( + task_id="workflow-task", + app_config=_build_workflow_app_config(AppMode.WORKFLOW), + inputs={"topic": "serialization"}, + files=[], + user_id="user-workflow", + stream=True, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=1, + trace_manager=trace_manager, + workflow_execution_id="workflow-exec-id", + extras={"external_trace_id": "trace-id"}, + ) + + +def _create_advanced_chat_generate_entity( + trace_manager: TraceQueueManager | None = None, +) -> AdvancedChatAppGenerateEntity: + return AdvancedChatAppGenerateEntity( + task_id="advanced-task", + app_config=_build_workflow_app_config(AppMode.ADVANCED_CHAT), + conversation_id="conversation-id", + inputs={"topic": "roundtrip"}, + files=[], + user_id="user-advanced", + stream=False, + invoke_from=InvokeFrom.DEBUGGER, + query="Explain serialization", + extras={"auto_generate_conversation_name": True}, + trace_manager=trace_manager, + workflow_run_id="workflow-run-id", + ) + + +def 
test_workflow_app_generate_entity_roundtrip_excludes_trace_manager(): + entity = _create_workflow_generate_entity(trace_manager=TraceQueueManagerStub()) + + serialized = entity.model_dump_json() + payload = json.loads(serialized) + + assert "trace_manager" not in payload + + restored = WorkflowAppGenerateEntity.model_validate_json(serialized) + + assert restored.model_dump() == entity.model_dump() + assert restored.trace_manager is None + + +def test_advanced_chat_generate_entity_roundtrip_excludes_trace_manager(): + entity = _create_advanced_chat_generate_entity(trace_manager=TraceQueueManagerStub()) + + serialized = entity.model_dump_json() + payload = json.loads(serialized) + + assert "trace_manager" not in payload + + restored = AdvancedChatAppGenerateEntity.model_validate_json(serialized) + + assert restored.model_dump() == entity.model_dump() + assert restored.trace_manager is None + + +@dataclass(frozen=True) +class ResumptionContextCase: + name: str + context_factory: Callable[[], tuple[WorkflowResumptionContext, type]] + + +def _workflow_resumption_case() -> tuple[WorkflowResumptionContext, type]: + entity = _create_workflow_generate_entity(trace_manager=TraceQueueManagerStub()) + context = WorkflowResumptionContext( + serialized_graph_runtime_state=json.dumps({"state": "workflow"}), + generate_entity=_WorkflowGenerateEntityWrapper(entity=entity), + ) + return context, WorkflowAppGenerateEntity + + +def _advanced_chat_resumption_case() -> tuple[WorkflowResumptionContext, type]: + entity = _create_advanced_chat_generate_entity(trace_manager=TraceQueueManagerStub()) + context = WorkflowResumptionContext( + serialized_graph_runtime_state=json.dumps({"state": "advanced"}), + generate_entity=_AdvancedChatAppGenerateEntityWrapper(entity=entity), + ) + return context, AdvancedChatAppGenerateEntity + + +@pytest.mark.parametrize( + "case", + [ + pytest.param(ResumptionContextCase("workflow", _workflow_resumption_case), id="workflow"), + 
pytest.param(ResumptionContextCase("advanced_chat", _advanced_chat_resumption_case), id="advanced_chat"), + ], +) +def test_workflow_resumption_context_roundtrip(case: ResumptionContextCase): + context, expected_type = case.context_factory() + + serialized = context.dumps() + restored = WorkflowResumptionContext.loads(serialized) + + assert restored.serialized_graph_runtime_state == context.serialized_graph_runtime_state + entity = restored.get_generate_entity() + assert isinstance(entity, expected_type) + assert entity.model_dump() == context.get_generate_entity().model_dump() + assert entity.trace_manager is None diff --git a/api/tests/unit_tests/core/model_runtime/__base/test_large_language_model_non_stream_parsing.py b/api/tests/unit_tests/core/model_runtime/__base/test_large_language_model_non_stream_parsing.py index 91352b2a5f..cfdeef6a8d 100644 --- a/api/tests/unit_tests/core/model_runtime/__base/test_large_language_model_non_stream_parsing.py +++ b/api/tests/unit_tests/core/model_runtime/__base/test_large_language_model_non_stream_parsing.py @@ -101,3 +101,26 @@ def test__normalize_non_stream_plugin_result__empty_iterator_defaults(): assert result.message.tool_calls == [] assert result.usage == LLMUsage.empty_usage() assert result.system_fingerprint is None + + +def test__normalize_non_stream_plugin_result__closes_chunk_iterator(): + prompt_messages = [UserPromptMessage(content="hi")] + + chunk = _make_chunk(content="hello", usage=LLMUsage.empty_usage()) + closed: list[bool] = [] + + def _chunk_iter(): + try: + yield chunk + yield _make_chunk(content="ignored", usage=LLMUsage.empty_usage()) + finally: + closed.append(True) + + result = _normalize_non_stream_plugin_result( + model="test-model", + prompt_messages=prompt_messages, + result=_chunk_iter(), + ) + + assert result.message.content == "hello" + assert closed == [True] diff --git a/api/tests/unit_tests/core/plugin/test_backwards_invocation_app.py 
b/api/tests/unit_tests/core/plugin/test_backwards_invocation_app.py new file mode 100644 index 0000000000..a380149554 --- /dev/null +++ b/api/tests/unit_tests/core/plugin/test_backwards_invocation_app.py @@ -0,0 +1,72 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock + +from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig +from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation +from models.model import AppMode + + +def test_invoke_chat_app_advanced_chat_injects_pause_state_config(mocker): + workflow = MagicMock() + workflow.created_by = "owner-id" + + app = MagicMock() + app.mode = AppMode.ADVANCED_CHAT + app.workflow = workflow + + mocker.patch( + "core.plugin.backwards_invocation.app.db", + SimpleNamespace(engine=MagicMock()), + ) + generator_spy = mocker.patch( + "core.plugin.backwards_invocation.app.AdvancedChatAppGenerator.generate", + return_value={"result": "ok"}, + ) + + result = PluginAppBackwardsInvocation.invoke_chat_app( + app=app, + user=MagicMock(), + conversation_id="conv-1", + query="hello", + stream=False, + inputs={"k": "v"}, + files=[], + ) + + assert result == {"result": "ok"} + call_kwargs = generator_spy.call_args.kwargs + pause_state_config = call_kwargs.get("pause_state_config") + assert isinstance(pause_state_config, PauseStateLayerConfig) + assert pause_state_config.state_owner_user_id == "owner-id" + + +def test_invoke_workflow_app_injects_pause_state_config(mocker): + workflow = MagicMock() + workflow.created_by = "owner-id" + + app = MagicMock() + app.mode = AppMode.WORKFLOW + app.workflow = workflow + + mocker.patch( + "core.plugin.backwards_invocation.app.db", + SimpleNamespace(engine=MagicMock()), + ) + generator_spy = mocker.patch( + "core.plugin.backwards_invocation.app.WorkflowAppGenerator.generate", + return_value={"result": "ok"}, + ) + + result = PluginAppBackwardsInvocation.invoke_workflow_app( + app=app, + user=MagicMock(), + stream=False, + inputs={"k": 
"v"}, + files=[], + ) + + assert result == {"result": "ok"} + call_kwargs = generator_spy.call_args.kwargs + pause_state_config = call_kwargs.get("pause_state_config") + assert isinstance(pause_state_config, PauseStateLayerConfig) + assert pause_state_config.state_owner_user_id == "owner-id" diff --git a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py index f9e59a5f05..0792ada194 100644 --- a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py +++ b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py @@ -1,7 +1,9 @@ """Primarily used for testing merged cell scenarios""" +import io import os import tempfile +from pathlib import Path from types import SimpleNamespace from docx import Document @@ -56,6 +58,42 @@ def test_parse_row(): assert extractor._parse_row(row, {}, 3) == gt[idx] +def test_init_downloads_via_ssrf_proxy(monkeypatch): + doc = Document() + doc.add_paragraph("hello") + buf = io.BytesIO() + doc.save(buf) + docx_bytes = buf.getvalue() + + calls: list[tuple[str, object]] = [] + + class FakeResponse: + status_code = 200 + content = docx_bytes + + def close(self) -> None: + calls.append(("close", None)) + + def fake_get(url: str, **kwargs): + calls.append(("get", (url, kwargs))) + return FakeResponse() + + monkeypatch.setattr(we, "ssrf_proxy", SimpleNamespace(get=fake_get)) + + extractor = WordExtractor("https://example.com/test.docx", "tenant_id", "user_id") + try: + assert calls + assert calls[0][0] == "get" + url, kwargs = calls[0][1] + assert url == "https://example.com/test.docx" + assert kwargs.get("timeout") is None + assert extractor.web_path == "https://example.com/test.docx" + assert extractor.file_path != extractor.web_path + assert Path(extractor.file_path).read_bytes() == docx_bytes + finally: + extractor.temp_file.close() + + def test_extract_images_from_docx(monkeypatch): external_bytes = b"ext-bytes" internal_bytes = b"int-bytes" diff --git 
a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py new file mode 100644 index 0000000000..811ed2143b --- /dev/null +++ b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py @@ -0,0 +1,574 @@ +"""Unit tests for HumanInputFormRepositoryImpl private helpers.""" + +from __future__ import annotations + +import dataclasses +from datetime import datetime +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from core.repositories.human_input_repository import ( + HumanInputFormRecord, + HumanInputFormRepositoryImpl, + HumanInputFormSubmissionRepository, + _WorkspaceMemberInfo, +) +from core.workflow.nodes.human_input.entities import ( + EmailDeliveryConfig, + EmailDeliveryMethod, + EmailRecipients, + ExternalRecipient, + FormDefinition, + MemberRecipient, + UserAction, +) +from core.workflow.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus +from libs.datetime_utils import naive_utc_now +from models.human_input import ( + EmailExternalRecipientPayload, + EmailMemberRecipientPayload, + HumanInputFormRecipient, + RecipientType, +) + + +def _build_repository() -> HumanInputFormRepositoryImpl: + return HumanInputFormRepositoryImpl(session_factory=MagicMock(), tenant_id="tenant-id") + + +def _patch_recipient_factory(monkeypatch: pytest.MonkeyPatch) -> list[SimpleNamespace]: + created: list[SimpleNamespace] = [] + + def fake_new(cls, form_id: str, delivery_id: str, payload): # type: ignore[no-untyped-def] + recipient = SimpleNamespace( + form_id=form_id, + delivery_id=delivery_id, + recipient_type=payload.TYPE, + recipient_payload=payload.model_dump_json(), + ) + created.append(recipient) + return recipient + + monkeypatch.setattr(HumanInputFormRecipient, "new", classmethod(fake_new)) + return created + + +@pytest.fixture(autouse=True) +def _stub_selectinload(monkeypatch: 
pytest.MonkeyPatch) -> None: + """Avoid SQLAlchemy mapper configuration in tests using fake sessions.""" + + class _FakeSelect: + def options(self, *_args, **_kwargs): # type: ignore[no-untyped-def] + return self + + def where(self, *_args, **_kwargs): # type: ignore[no-untyped-def] + return self + + monkeypatch.setattr( + "core.repositories.human_input_repository.selectinload", lambda *args, **kwargs: "_loader_option" + ) + monkeypatch.setattr("core.repositories.human_input_repository.select", lambda *args, **kwargs: _FakeSelect()) + + +class TestHumanInputFormRepositoryImplHelpers: + def test_build_email_recipients_with_member_and_external(self, monkeypatch: pytest.MonkeyPatch) -> None: + repo = _build_repository() + session_stub = object() + _patch_recipient_factory(monkeypatch) + + def fake_query(self, session, restrict_to_user_ids): # type: ignore[no-untyped-def] + assert session is session_stub + assert restrict_to_user_ids == ["member-1"] + return [_WorkspaceMemberInfo(user_id="member-1", email="member@example.com")] + + monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_workspace_members_by_ids", fake_query) + + recipients = repo._build_email_recipients( + session=session_stub, + form_id="form-id", + delivery_id="delivery-id", + recipients_config=EmailRecipients( + whole_workspace=False, + items=[ + MemberRecipient(user_id="member-1"), + ExternalRecipient(email="external@example.com"), + ], + ), + ) + + assert len(recipients) == 2 + member_recipient = next(r for r in recipients if r.recipient_type == RecipientType.EMAIL_MEMBER) + external_recipient = next(r for r in recipients if r.recipient_type == RecipientType.EMAIL_EXTERNAL) + + member_payload = EmailMemberRecipientPayload.model_validate_json(member_recipient.recipient_payload) + assert member_payload.user_id == "member-1" + assert member_payload.email == "member@example.com" + + external_payload = EmailExternalRecipientPayload.model_validate_json(external_recipient.recipient_payload) + assert 
external_payload.email == "external@example.com" + + def test_build_email_recipients_skips_unknown_members(self, monkeypatch: pytest.MonkeyPatch) -> None: + repo = _build_repository() + session_stub = object() + created = _patch_recipient_factory(monkeypatch) + + def fake_query(self, session, restrict_to_user_ids): # type: ignore[no-untyped-def] + assert session is session_stub + assert restrict_to_user_ids == ["missing-member"] + return [] + + monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_workspace_members_by_ids", fake_query) + + recipients = repo._build_email_recipients( + session=session_stub, + form_id="form-id", + delivery_id="delivery-id", + recipients_config=EmailRecipients( + whole_workspace=False, + items=[ + MemberRecipient(user_id="missing-member"), + ExternalRecipient(email="external@example.com"), + ], + ), + ) + + assert len(recipients) == 1 + assert recipients[0].recipient_type == RecipientType.EMAIL_EXTERNAL + assert len(created) == 1 # only external recipient created via factory + + def test_build_email_recipients_whole_workspace_uses_all_members(self, monkeypatch: pytest.MonkeyPatch) -> None: + repo = _build_repository() + session_stub = object() + _patch_recipient_factory(monkeypatch) + + def fake_query(self, session): # type: ignore[no-untyped-def] + assert session is session_stub + return [ + _WorkspaceMemberInfo(user_id="member-1", email="member1@example.com"), + _WorkspaceMemberInfo(user_id="member-2", email="member2@example.com"), + ] + + monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_all_workspace_members", fake_query) + + recipients = repo._build_email_recipients( + session=session_stub, + form_id="form-id", + delivery_id="delivery-id", + recipients_config=EmailRecipients( + whole_workspace=True, + items=[], + ), + ) + + assert len(recipients) == 2 + emails = {EmailMemberRecipientPayload.model_validate_json(r.recipient_payload).email for r in recipients} + assert emails == {"member1@example.com", 
"member2@example.com"} + + def test_build_email_recipients_dedupes_external_by_email(self, monkeypatch: pytest.MonkeyPatch) -> None: + repo = _build_repository() + session_stub = object() + created = _patch_recipient_factory(monkeypatch) + + def fake_query(self, session, restrict_to_user_ids): # type: ignore[no-untyped-def] + assert session is session_stub + assert restrict_to_user_ids == [] + return [] + + monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_workspace_members_by_ids", fake_query) + + recipients = repo._build_email_recipients( + session=session_stub, + form_id="form-id", + delivery_id="delivery-id", + recipients_config=EmailRecipients( + whole_workspace=False, + items=[ + ExternalRecipient(email="external@example.com"), + ExternalRecipient(email="external@example.com"), + ], + ), + ) + + assert len(recipients) == 1 + assert len(created) == 1 + + def test_build_email_recipients_prefers_member_over_external_by_email( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + repo = _build_repository() + session_stub = object() + _patch_recipient_factory(monkeypatch) + + def fake_query(self, session, restrict_to_user_ids): # type: ignore[no-untyped-def] + assert session is session_stub + assert restrict_to_user_ids == ["member-1"] + return [_WorkspaceMemberInfo(user_id="member-1", email="shared@example.com")] + + monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_workspace_members_by_ids", fake_query) + + recipients = repo._build_email_recipients( + session=session_stub, + form_id="form-id", + delivery_id="delivery-id", + recipients_config=EmailRecipients( + whole_workspace=False, + items=[ + MemberRecipient(user_id="member-1"), + ExternalRecipient(email="shared@example.com"), + ], + ), + ) + + assert len(recipients) == 1 + assert recipients[0].recipient_type == RecipientType.EMAIL_MEMBER + + def test_delivery_method_to_model_includes_external_recipients_with_whole_workspace( + self, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + repo = 
_build_repository() + session_stub = object() + _patch_recipient_factory(monkeypatch) + + def fake_query(self, session): # type: ignore[no-untyped-def] + assert session is session_stub + return [ + _WorkspaceMemberInfo(user_id="member-1", email="member1@example.com"), + _WorkspaceMemberInfo(user_id="member-2", email="member2@example.com"), + ] + + monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_all_workspace_members", fake_query) + + method = EmailDeliveryMethod( + config=EmailDeliveryConfig( + recipients=EmailRecipients( + whole_workspace=True, + items=[ExternalRecipient(email="external@example.com")], + ), + subject="subject", + body="body", + ) + ) + + result = repo._delivery_method_to_model(session=session_stub, form_id="form-id", delivery_method=method) + + assert len(result.recipients) == 3 + member_emails = { + EmailMemberRecipientPayload.model_validate_json(r.recipient_payload).email + for r in result.recipients + if r.recipient_type == RecipientType.EMAIL_MEMBER + } + assert member_emails == {"member1@example.com", "member2@example.com"} + external_payload = EmailExternalRecipientPayload.model_validate_json( + next(r for r in result.recipients if r.recipient_type == RecipientType.EMAIL_EXTERNAL).recipient_payload + ) + assert external_payload.email == "external@example.com" + + +def _make_form_definition() -> str: + return FormDefinition( + form_content="hello", + inputs=[], + user_actions=[UserAction(id="submit", title="Submit")], + rendered_content="

hello

", + expiration_time=datetime.utcnow(), + ).model_dump_json() + + +@dataclasses.dataclass +class _DummyForm: + id: str + workflow_run_id: str + node_id: str + tenant_id: str + app_id: str + form_definition: str + rendered_content: str + expiration_time: datetime + form_kind: HumanInputFormKind = HumanInputFormKind.RUNTIME + created_at: datetime = dataclasses.field(default_factory=naive_utc_now) + selected_action_id: str | None = None + submitted_data: str | None = None + submitted_at: datetime | None = None + submission_user_id: str | None = None + submission_end_user_id: str | None = None + completed_by_recipient_id: str | None = None + status: HumanInputFormStatus = HumanInputFormStatus.WAITING + + +@dataclasses.dataclass +class _DummyRecipient: + id: str + form_id: str + recipient_type: RecipientType + access_token: str + form: _DummyForm | None = None + + +class _FakeScalarResult: + def __init__(self, obj): + self._obj = obj + + def first(self): + if isinstance(self._obj, list): + return self._obj[0] if self._obj else None + return self._obj + + def all(self): + if isinstance(self._obj, list): + return list(self._obj) + if self._obj is None: + return [] + return [self._obj] + + +class _FakeSession: + def __init__( + self, + *, + scalars_result=None, + scalars_results: list[object] | None = None, + forms: dict[str, _DummyForm] | None = None, + recipients: dict[str, _DummyRecipient] | None = None, + ): + if scalars_results is not None: + self._scalars_queue = list(scalars_results) + elif scalars_result is not None: + self._scalars_queue = [scalars_result] + else: + self._scalars_queue = [] + self.forms = forms or {} + self.recipients = recipients or {} + + def scalars(self, _query): + if self._scalars_queue: + result = self._scalars_queue.pop(0) + else: + result = None + return _FakeScalarResult(result) + + def get(self, model_cls, obj_id): # type: ignore[no-untyped-def] + if getattr(model_cls, "__name__", None) == "HumanInputForm": + return 
self.forms.get(obj_id) + if getattr(model_cls, "__name__", None) == "HumanInputFormRecipient": + return self.recipients.get(obj_id) + return None + + def add(self, _obj): + return None + + def flush(self): + return None + + def refresh(self, _obj): + return None + + def begin(self): + return self + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return None + + +def _session_factory(session: _FakeSession): + class _SessionContext: + def __enter__(self): + return session + + def __exit__(self, exc_type, exc, tb): + return None + + def _factory(*_args, **_kwargs): + return _SessionContext() + + return _factory + + +class TestHumanInputFormRepositoryImplPublicMethods: + def test_get_form_returns_entity_and_recipients(self): + form = _DummyForm( + id="form-1", + workflow_run_id="run-1", + node_id="node-1", + tenant_id="tenant-id", + app_id="app-id", + form_definition=_make_form_definition(), + rendered_content="

hello

", + expiration_time=naive_utc_now(), + ) + recipient = _DummyRecipient( + id="recipient-1", + form_id=form.id, + recipient_type=RecipientType.STANDALONE_WEB_APP, + access_token="token-123", + ) + session = _FakeSession(scalars_results=[form, [recipient]]) + repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id") + + entity = repo.get_form(form.workflow_run_id, form.node_id) + + assert entity is not None + assert entity.id == form.id + assert entity.web_app_token == "token-123" + assert len(entity.recipients) == 1 + assert entity.recipients[0].token == "token-123" + + def test_get_form_returns_none_when_missing(self): + session = _FakeSession(scalars_results=[None]) + repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id") + + assert repo.get_form("run-1", "node-1") is None + + def test_get_form_returns_unsubmitted_state(self): + form = _DummyForm( + id="form-1", + workflow_run_id="run-1", + node_id="node-1", + tenant_id="tenant-id", + app_id="app-id", + form_definition=_make_form_definition(), + rendered_content="

hello

", + expiration_time=naive_utc_now(), + ) + session = _FakeSession(scalars_results=[form, []]) + repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id") + + entity = repo.get_form(form.workflow_run_id, form.node_id) + + assert entity is not None + assert entity.submitted is False + assert entity.selected_action_id is None + assert entity.submitted_data is None + + def test_get_form_returns_submission_when_completed(self): + form = _DummyForm( + id="form-1", + workflow_run_id="run-1", + node_id="node-1", + tenant_id="tenant-id", + app_id="app-id", + form_definition=_make_form_definition(), + rendered_content="

hello

", + expiration_time=naive_utc_now(), + selected_action_id="approve", + submitted_data='{"field": "value"}', + submitted_at=naive_utc_now(), + ) + session = _FakeSession(scalars_results=[form, []]) + repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id") + + entity = repo.get_form(form.workflow_run_id, form.node_id) + + assert entity is not None + assert entity.submitted is True + assert entity.selected_action_id == "approve" + assert entity.submitted_data == {"field": "value"} + + +class TestHumanInputFormSubmissionRepository: + def test_get_by_token_returns_record(self): + form = _DummyForm( + id="form-1", + workflow_run_id="run-1", + node_id="node-1", + tenant_id="tenant-1", + app_id="app-1", + form_definition=_make_form_definition(), + rendered_content="

hello

", + expiration_time=naive_utc_now(), + ) + recipient = _DummyRecipient( + id="recipient-1", + form_id=form.id, + recipient_type=RecipientType.STANDALONE_WEB_APP, + access_token="token-123", + form=form, + ) + session = _FakeSession(scalars_result=recipient) + repo = HumanInputFormSubmissionRepository(_session_factory(session)) + + record = repo.get_by_token("token-123") + + assert record is not None + assert record.form_id == form.id + assert record.recipient_type == RecipientType.STANDALONE_WEB_APP + assert record.submitted is False + + def test_get_by_form_id_and_recipient_type_uses_recipient(self): + form = _DummyForm( + id="form-1", + workflow_run_id="run-1", + node_id="node-1", + tenant_id="tenant-1", + app_id="app-1", + form_definition=_make_form_definition(), + rendered_content="

hello

", + expiration_time=naive_utc_now(), + ) + recipient = _DummyRecipient( + id="recipient-1", + form_id=form.id, + recipient_type=RecipientType.STANDALONE_WEB_APP, + access_token="token-123", + form=form, + ) + session = _FakeSession(scalars_result=recipient) + repo = HumanInputFormSubmissionRepository(_session_factory(session)) + + record = repo.get_by_form_id_and_recipient_type( + form_id=form.id, + recipient_type=RecipientType.STANDALONE_WEB_APP, + ) + + assert record is not None + assert record.recipient_id == recipient.id + assert record.access_token == recipient.access_token + + def test_mark_submitted_updates_fields(self, monkeypatch: pytest.MonkeyPatch): + fixed_now = datetime(2024, 1, 1, 0, 0, 0) + monkeypatch.setattr("core.repositories.human_input_repository.naive_utc_now", lambda: fixed_now) + + form = _DummyForm( + id="form-1", + workflow_run_id="run-1", + node_id="node-1", + tenant_id="tenant-1", + app_id="app-1", + form_definition=_make_form_definition(), + rendered_content="

hello

", + expiration_time=fixed_now, + ) + recipient = _DummyRecipient( + id="recipient-1", + form_id="form-1", + recipient_type=RecipientType.STANDALONE_WEB_APP, + access_token="token-123", + ) + session = _FakeSession( + forms={form.id: form}, + recipients={recipient.id: recipient}, + ) + repo = HumanInputFormSubmissionRepository(_session_factory(session)) + + record: HumanInputFormRecord = repo.mark_submitted( + form_id=form.id, + recipient_id=recipient.id, + selected_action_id="approve", + form_data={"field": "value"}, + submission_user_id="user-1", + submission_end_user_id="end-user-1", + ) + + assert form.selected_action_id == "approve" + assert form.completed_by_recipient_id == recipient.id + assert form.submission_user_id == "user-1" + assert form.submission_end_user_id == "end-user-1" + assert form.submitted_at == fixed_now + assert record.submitted is True + assert record.selected_action_id == "approve" + assert record.submitted_data == {"field": "value"} diff --git a/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py b/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py new file mode 100644 index 0000000000..c46e31d90f --- /dev/null +++ b/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py @@ -0,0 +1,33 @@ +import pytest + +from core.tools.errors import WorkflowToolHumanInputNotSupportedError +from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils + + +def test_ensure_no_human_input_nodes_passes_for_non_human_input(): + graph = { + "nodes": [ + { + "id": "start_node", + "data": {"type": "start"}, + } + ] + } + + WorkflowToolConfigurationUtils.ensure_no_human_input_nodes(graph) + + +def test_ensure_no_human_input_nodes_raises_for_human_input(): + graph = { + "nodes": [ + { + "id": "human_input_node", + "data": {"type": "human-input"}, + } + ] + } + + with pytest.raises(WorkflowToolHumanInputNotSupportedError) as exc_info: + 
WorkflowToolConfigurationUtils.ensure_no_human_input_nodes(graph) + + assert exc_info.value.error_code == "workflow_tool_human_input_not_supported" diff --git a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py index cd45292488..bbedfdb6ae 100644 --- a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py +++ b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py @@ -55,6 +55,43 @@ def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_fiel assert exc_info.value.args == ("oops",) +def test_workflow_tool_does_not_use_pause_state_config(monkeypatch: pytest.MonkeyPatch): + entity = ToolEntity( + identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"), + parameters=[], + description=None, + has_runtime_parameters=False, + ) + runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE) + tool = WorkflowTool( + workflow_app_id="", + workflow_as_tool_id="", + version="1", + workflow_entities={}, + workflow_call_depth=1, + entity=entity, + runtime=runtime, + ) + + monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None) + monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None) + + from unittest.mock import MagicMock, Mock + + mock_user = Mock() + monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: mock_user) + + generate_mock = MagicMock(return_value={"data": {}}) + monkeypatch.setattr("core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate", generate_mock) + monkeypatch.setattr("libs.login.current_user", lambda *args, **kwargs: None) + + list(tool.invoke("test_user", {})) + + call_kwargs = generate_mock.call_args.kwargs + assert "pause_state_config" in call_kwargs + assert call_kwargs["pause_state_config"] is None + + def test_workflow_tool_should_generate_variable_messages_for_outputs(monkeypatch: pytest.MonkeyPatch): """Test that 
WorkflowTool should generate variable messages when there are outputs""" entity = ToolEntity( diff --git a/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py b/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py index deff06fc5d..1b6d03e36a 100644 --- a/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py +++ b/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py @@ -118,7 +118,6 @@ class TestGraphRuntimeState: from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue assert isinstance(queue, InMemoryReadyQueue) - assert state.ready_queue is queue def test_graph_execution_lazy_instantiation(self): state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time()) diff --git a/api/tests/unit_tests/core/workflow/entities/test_pause_reason.py b/api/tests/unit_tests/core/workflow/entities/test_pause_reason.py new file mode 100644 index 0000000000..6144df06e0 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/entities/test_pause_reason.py @@ -0,0 +1,88 @@ +""" +Tests for PauseReason discriminated union serialization/deserialization. 
+""" + +import pytest +from pydantic import BaseModel, ValidationError + +from core.workflow.entities.pause_reason import ( + HumanInputRequired, + PauseReason, + SchedulingPause, +) + + +class _Holder(BaseModel): + """Helper model that embeds PauseReason for union tests.""" + + reason: PauseReason + + +class TestPauseReasonDiscriminator: + """Test suite for PauseReason union discriminator.""" + + @pytest.mark.parametrize( + ("dict_value", "expected"), + [ + pytest.param( + { + "reason": { + "TYPE": "human_input_required", + "form_id": "form_id", + "form_content": "form_content", + "node_id": "node_id", + "node_title": "node_title", + }, + }, + HumanInputRequired( + form_id="form_id", + form_content="form_content", + node_id="node_id", + node_title="node_title", + ), + id="HumanInputRequired", + ), + pytest.param( + { + "reason": { + "TYPE": "scheduled_pause", + "message": "Hold on", + } + }, + SchedulingPause(message="Hold on"), + id="SchedulingPause", + ), + ], + ) + def test_model_validate(self, dict_value, expected): + """Ensure scheduled pause payloads with lowercase TYPE deserialize.""" + holder = _Holder.model_validate(dict_value) + + assert type(holder.reason) == type(expected) + assert holder.reason == expected + + @pytest.mark.parametrize( + "reason", + [ + HumanInputRequired( + form_id="form_id", + form_content="form_content", + node_id="node_id", + node_title="node_title", + ), + SchedulingPause(message="Hold on"), + ], + ids=lambda x: type(x).__name__, + ) + def test_model_construct(self, reason): + holder = _Holder(reason=reason) + assert holder.reason == reason + + def test_model_construct_with_invalid_type(self): + with pytest.raises(ValidationError): + holder = _Holder(reason=object()) # type: ignore + + def test_unknown_type_fails_validation(self): + """Unknown TYPE values should raise a validation error.""" + with pytest.raises(ValidationError): + _Holder.model_validate({"reason": {"TYPE": "UNKNOWN"}}) diff --git 
a/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py b/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py new file mode 100644 index 0000000000..2ef23c7f0f --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py @@ -0,0 +1,131 @@ +"""Utilities for testing HumanInputNode without database dependencies.""" + +from __future__ import annotations + +from collections.abc import Mapping +from dataclasses import dataclass +from datetime import datetime, timedelta +from typing import Any + +from core.workflow.nodes.human_input.enums import HumanInputFormStatus +from core.workflow.repositories.human_input_form_repository import ( + FormCreateParams, + HumanInputFormEntity, + HumanInputFormRecipientEntity, + HumanInputFormRepository, +) +from libs.datetime_utils import naive_utc_now + + +class _InMemoryFormRecipient(HumanInputFormRecipientEntity): + """Minimal recipient entity required by the repository interface.""" + + def __init__(self, recipient_id: str, token: str) -> None: + self._id = recipient_id + self._token = token + + @property + def id(self) -> str: + return self._id + + @property + def token(self) -> str: + return self._token + + +@dataclass +class _InMemoryFormEntity(HumanInputFormEntity): + form_id: str + rendered: str + token: str | None = None + action_id: str | None = None + data: Mapping[str, Any] | None = None + is_submitted: bool = False + status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING + expiration: datetime = naive_utc_now() + + @property + def id(self) -> str: + return self.form_id + + @property + def web_app_token(self) -> str | None: + return self.token + + @property + def recipients(self) -> list[HumanInputFormRecipientEntity]: + return [] + + @property + def rendered_content(self) -> str: + return self.rendered + + @property + def selected_action_id(self) -> str | None: + return self.action_id + + @property + def submitted_data(self) -> 
Mapping[str, Any] | None: + return self.data + + @property + def submitted(self) -> bool: + return self.is_submitted + + @property + def status(self) -> HumanInputFormStatus: + return self.status_value + + @property + def expiration_time(self) -> datetime: + return self.expiration + + +class InMemoryHumanInputFormRepository(HumanInputFormRepository): + """Pure in-memory repository used by workflow graph engine tests.""" + + def __init__(self) -> None: + self._form_counter = 0 + self.created_params: list[FormCreateParams] = [] + self.created_forms: list[_InMemoryFormEntity] = [] + self._forms_by_key: dict[tuple[str, str], _InMemoryFormEntity] = {} + + def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: + self.created_params.append(params) + self._form_counter += 1 + form_id = f"form-{self._form_counter}" + token = f"console-{form_id}" if params.console_recipient_required else f"token-{form_id}" + entity = _InMemoryFormEntity( + form_id=form_id, + rendered=params.rendered_content, + token=token, + ) + self.created_forms.append(entity) + self._forms_by_key[(params.workflow_execution_id, params.node_id)] = entity + return entity + + def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: + return self._forms_by_key.get((workflow_execution_id, node_id)) + + # Convenience helpers for tests ------------------------------------- + + def set_submission(self, *, action_id: str, form_data: Mapping[str, Any] | None = None) -> None: + """Simulate a human submission for the next repository lookup.""" + + if not self.created_forms: + raise AssertionError("no form has been created to attach submission data") + entity = self.created_forms[-1] + entity.action_id = action_id + entity.data = form_data or {} + entity.is_submitted = True + entity.status_value = HumanInputFormStatus.SUBMITTED + entity.expiration = naive_utc_now() + timedelta(days=1) + + def clear_submission(self) -> None: + if not self.created_forms: + return + 
for form in self.created_forms: + form.action_id = None + form.data = None + form.is_submitted = False + form.status_value = HumanInputFormStatus.WAITING diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher_pause_drain.py b/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher_pause_drain.py new file mode 100644 index 0000000000..6038a15211 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher_pause_drain.py @@ -0,0 +1,74 @@ +import queue +import threading +from datetime import datetime + +from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus +from core.workflow.graph_engine.orchestration.dispatcher import Dispatcher +from core.workflow.graph_events import NodeRunSucceededEvent +from core.workflow.node_events import NodeRunResult + + +class StubExecutionCoordinator: + def __init__(self, paused: bool) -> None: + self._paused = paused + self.mark_complete_called = False + self.failed_error: Exception | None = None + + @property + def aborted(self) -> bool: + return False + + @property + def paused(self) -> bool: + return self._paused + + @property + def execution_complete(self) -> bool: + return False + + def check_scaling(self) -> None: + return None + + def process_commands(self) -> None: + return None + + def mark_complete(self) -> None: + self.mark_complete_called = True + + def mark_failed(self, error: Exception) -> None: + self.failed_error = error + + +class StubEventHandler: + def __init__(self) -> None: + self.events: list[object] = [] + + def dispatch(self, event: object) -> None: + self.events.append(event) + + +def test_dispatcher_drains_events_when_paused() -> None: + event_queue: queue.Queue = queue.Queue() + event = NodeRunSucceededEvent( + id="exec-1", + node_id="node-1", + node_type=NodeType.START, + start_at=datetime.utcnow(), + node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED), + ) + event_queue.put(event) + + handler = StubEventHandler() + 
coordinator = StubExecutionCoordinator(paused=True) + dispatcher = Dispatcher( + event_queue=event_queue, + event_handler=handler, + execution_coordinator=coordinator, + event_emitter=None, + stop_event=threading.Event(), + ) + + dispatcher._dispatcher_loop() + + assert handler.events == [event] + assert coordinator.mark_complete_called is True diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_execution_coordinator.py b/api/tests/unit_tests/core/workflow/graph_engine/test_execution_coordinator.py index 0d67a76169..53de8908a8 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_execution_coordinator.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_execution_coordinator.py @@ -2,6 +2,8 @@ from unittest.mock import MagicMock +import pytest + from core.workflow.graph_engine.command_processing.command_processor import CommandProcessor from core.workflow.graph_engine.domain.graph_execution import GraphExecution from core.workflow.graph_engine.graph_state_manager import GraphStateManager @@ -48,3 +50,13 @@ def test_handle_pause_noop_when_execution_running() -> None: worker_pool.stop.assert_not_called() state_manager.clear_executing.assert_not_called() + + +def test_has_executing_nodes_requires_pause() -> None: + graph_execution = GraphExecution(workflow_id="workflow") + graph_execution.start() + + coordinator, _, _ = _build_coordinator(graph_execution) + + with pytest.raises(AssertionError): + coordinator.has_executing_nodes() diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py new file mode 100644 index 0000000000..65d34c2009 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py @@ -0,0 +1,189 @@ +import time +from collections.abc import Mapping + +from core.model_runtime.entities.llm_entities import LLMMode +from core.model_runtime.entities.message_entities import 
PromptMessageRole +from core.workflow.entities import GraphInitParams +from core.workflow.enums import NodeState +from core.workflow.graph import Graph +from core.workflow.graph_engine.graph_state_manager import GraphStateManager +from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue +from core.workflow.nodes.end.end_node import EndNode +from core.workflow.nodes.end.entities import EndNodeData +from core.workflow.nodes.llm.entities import ( + ContextConfig, + LLMNodeChatModelMessage, + LLMNodeData, + ModelConfig, + VisionConfig, +) +from core.workflow.nodes.start.entities import StartNodeData +from core.workflow.nodes.start.start_node import StartNode +from core.workflow.runtime import GraphRuntimeState, VariablePool +from core.workflow.system_variable import SystemVariable + +from .test_mock_config import MockConfig +from .test_mock_nodes import MockLLMNode + + +def _build_runtime_state() -> GraphRuntimeState: + variable_pool = VariablePool( + system_variables=SystemVariable( + user_id="user", + app_id="app", + workflow_id="workflow", + workflow_execution_id="exec-1", + ), + user_inputs={}, + conversation_variables=[], + ) + return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + + +def _build_llm_node( + *, + node_id: str, + runtime_state: GraphRuntimeState, + graph_init_params: GraphInitParams, + mock_config: MockConfig, +) -> MockLLMNode: + llm_data = LLMNodeData( + title=f"LLM {node_id}", + model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}), + prompt_template=[ + LLMNodeChatModelMessage( + text=f"Prompt {node_id}", + role=PromptMessageRole.USER, + edition_type="basic", + ) + ], + context=ContextConfig(enabled=False, variable_selector=None), + vision=VisionConfig(enabled=False), + reasoning_format="tagged", + ) + llm_config = {"id": node_id, "data": llm_data.model_dump()} + return MockLLMNode( + id=llm_config["id"], + config=llm_config, + 
graph_init_params=graph_init_params, + graph_runtime_state=runtime_state, + mock_config=mock_config, + ) + + +def _build_graph(runtime_state: GraphRuntimeState) -> Graph: + graph_config: dict[str, object] = {"nodes": [], "edges": []} + graph_init_params = GraphInitParams( + tenant_id="tenant", + app_id="app", + workflow_id="workflow", + graph_config=graph_config, + user_id="user", + user_from="account", + invoke_from="debugger", + call_depth=0, + ) + + start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} + start_node = StartNode( + id=start_config["id"], + config=start_config, + graph_init_params=graph_init_params, + graph_runtime_state=runtime_state, + ) + + mock_config = MockConfig() + llm_a = _build_llm_node( + node_id="llm_a", + runtime_state=runtime_state, + graph_init_params=graph_init_params, + mock_config=mock_config, + ) + llm_b = _build_llm_node( + node_id="llm_b", + runtime_state=runtime_state, + graph_init_params=graph_init_params, + mock_config=mock_config, + ) + + end_data = EndNodeData(title="End", outputs=[], desc=None) + end_config = {"id": "end", "data": end_data.model_dump()} + end_node = EndNode( + id=end_config["id"], + config=end_config, + graph_init_params=graph_init_params, + graph_runtime_state=runtime_state, + ) + + builder = ( + Graph.new() + .add_root(start_node) + .add_node(llm_a, from_node_id="start") + .add_node(llm_b, from_node_id="start") + .add_node(end_node, from_node_id="llm_a") + ) + return builder.connect(tail="llm_b", head="end").build() + + +def _edge_state_map(graph: Graph) -> Mapping[tuple[str, str, str], NodeState]: + return {(edge.tail, edge.head, edge.source_handle): edge.state for edge in graph.edges.values()} + + +def test_runtime_state_snapshot_restores_graph_states() -> None: + runtime_state = _build_runtime_state() + graph = _build_graph(runtime_state) + runtime_state.attach_graph(graph) + + graph.nodes["llm_a"].state = NodeState.TAKEN + graph.nodes["llm_b"].state = 
NodeState.SKIPPED + + for edge in graph.edges.values(): + if edge.tail == "start" and edge.head == "llm_a": + edge.state = NodeState.TAKEN + elif edge.tail == "start" and edge.head == "llm_b": + edge.state = NodeState.SKIPPED + elif edge.head == "end" and edge.tail == "llm_a": + edge.state = NodeState.TAKEN + elif edge.head == "end" and edge.tail == "llm_b": + edge.state = NodeState.SKIPPED + + snapshot = runtime_state.dumps() + + resumed_state = GraphRuntimeState.from_snapshot(snapshot) + resumed_graph = _build_graph(resumed_state) + resumed_state.attach_graph(resumed_graph) + + assert resumed_graph.nodes["llm_a"].state == NodeState.TAKEN + assert resumed_graph.nodes["llm_b"].state == NodeState.SKIPPED + assert _edge_state_map(resumed_graph) == _edge_state_map(graph) + + +def test_join_readiness_uses_restored_edge_states() -> None: + runtime_state = _build_runtime_state() + graph = _build_graph(runtime_state) + runtime_state.attach_graph(graph) + + ready_queue = InMemoryReadyQueue() + state_manager = GraphStateManager(graph, ready_queue) + + for edge in graph.get_incoming_edges("end"): + if edge.tail == "llm_a": + edge.state = NodeState.TAKEN + if edge.tail == "llm_b": + edge.state = NodeState.UNKNOWN + + assert state_manager.is_node_ready("end") is False + + for edge in graph.get_incoming_edges("end"): + if edge.tail == "llm_b": + edge.state = NodeState.TAKEN + + assert state_manager.is_node_ready("end") is True + + snapshot = runtime_state.dumps() + resumed_state = GraphRuntimeState.from_snapshot(snapshot) + resumed_graph = _build_graph(resumed_state) + resumed_state.attach_graph(resumed_graph) + + resumed_state_manager = GraphStateManager(resumed_graph, InMemoryReadyQueue()) + assert resumed_state_manager.is_node_ready("end") is True diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py index c398e4e8c1..194d009288 100644 --- 
a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py @@ -1,5 +1,7 @@ +import datetime import time from collections.abc import Iterable +from unittest.mock import MagicMock from core.model_runtime.entities.llm_entities import LLMMode from core.model_runtime.entities.message_entities import PromptMessageRole @@ -14,11 +16,12 @@ from core.workflow.graph_events import ( NodeRunStreamChunkEvent, NodeRunSucceededEvent, ) +from core.workflow.graph_events.node import NodeRunHumanInputFormFilledEvent from core.workflow.nodes.base.entities import OutputVariableEntity, OutputVariableType from core.workflow.nodes.end.end_node import EndNode from core.workflow.nodes.end.entities import EndNodeData -from core.workflow.nodes.human_input import HumanInputNode -from core.workflow.nodes.human_input.entities import HumanInputNodeData +from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction +from core.workflow.nodes.human_input.human_input_node import HumanInputNode from core.workflow.nodes.llm.entities import ( ContextConfig, LLMNodeChatModelMessage, @@ -28,15 +31,21 @@ from core.workflow.nodes.llm.entities import ( ) from core.workflow.nodes.start.entities import StartNodeData from core.workflow.nodes.start.start_node import StartNode +from core.workflow.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable +from libs.datetime_utils import naive_utc_now from .test_mock_config import MockConfig from .test_mock_nodes import MockLLMNode from .test_table_runner import TableTestRunner, WorkflowTestCase -def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntimeState]: +def _build_branching_graph( + mock_config: MockConfig, + form_repository: 
HumanInputFormRepository, + graph_runtime_state: GraphRuntimeState | None = None, +) -> tuple[Graph, GraphRuntimeState]: graph_config: dict[str, object] = {"nodes": [], "edges": []} graph_init_params = GraphInitParams( tenant_id="tenant", @@ -49,12 +58,18 @@ def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntime call_depth=0, ) - variable_pool = VariablePool( - system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"), - user_inputs={}, - conversation_variables=[], - ) - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + if graph_runtime_state is None: + variable_pool = VariablePool( + system_variables=SystemVariable( + user_id="user", + app_id="app", + workflow_id="workflow", + workflow_execution_id="test-execution-id", + ), + user_inputs={}, + conversation_variables=[], + ) + graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} start_node = StartNode( @@ -93,15 +108,21 @@ def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntime human_data = HumanInputNodeData( title="Human Input", - required_variables=["human.input_ready"], - pause_reason="Awaiting human input", + form_content="Human input required", + inputs=[], + user_actions=[ + UserAction(id="primary", title="Primary"), + UserAction(id="secondary", title="Secondary"), + ], ) + human_config = {"id": "human", "data": human_data.model_dump()} human_node = HumanInputNode( id=human_config["id"], config=human_config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, + form_repository=form_repository, ) llm_primary = _create_llm_node("llm_primary", "Primary LLM", "Primary stream output") @@ -219,8 +240,18 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None: for scenario in branch_scenarios: runner = 
TableTestRunner() - def initial_graph_factory() -> tuple[Graph, GraphRuntimeState]: - return _build_branching_graph(mock_config) + mock_create_repo = MagicMock(spec=HumanInputFormRepository) + mock_create_repo.get_form.return_value = None + mock_form_entity = MagicMock(spec=HumanInputFormEntity) + mock_form_entity.id = "test_form_id" + mock_form_entity.web_app_token = "test_web_app_token" + mock_form_entity.recipients = [] + mock_form_entity.rendered_content = "rendered" + mock_form_entity.submitted = False + mock_create_repo.create_form.return_value = mock_form_entity + + def initial_graph_factory(mock_create_repo=mock_create_repo) -> tuple[Graph, GraphRuntimeState]: + return _build_branching_graph(mock_config, mock_create_repo) initial_case = WorkflowTestCase( description="HumanInput pause before branching decision", @@ -242,23 +273,16 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None: assert initial_result.success, initial_result.event_mismatch_details assert not any(isinstance(event, NodeRunStreamChunkEvent) for event in initial_result.events) - graph_runtime_state = initial_result.graph_runtime_state - graph = initial_result.graph - assert graph_runtime_state is not None - assert graph is not None - - graph_runtime_state.variable_pool.add(("human", "input_ready"), True) - graph_runtime_state.variable_pool.add(("human", "edge_source_handle"), scenario["handle"]) - graph_runtime_state.graph_execution.pause_reason = None - pre_chunk_count = sum(len(chunks) for _, chunks in scenario["expected_pre_chunks"]) post_chunk_count = sum(len(chunks) for _, chunks in scenario["expected_post_chunks"]) + expected_pre_chunk_events_in_resumption = [ + GraphRunStartedEvent, + NodeRunStartedEvent, + NodeRunHumanInputFormFilledEvent, + ] expected_resume_sequence: list[type] = ( - [ - GraphRunStartedEvent, - NodeRunStartedEvent, - ] + expected_pre_chunk_events_in_resumption + [NodeRunStreamChunkEvent] * pre_chunk_count + [ NodeRunSucceededEvent, @@ -273,11 
+297,25 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None: ] ) + mock_get_repo = MagicMock(spec=HumanInputFormRepository) + submitted_form = MagicMock(spec=HumanInputFormEntity) + submitted_form.id = mock_form_entity.id + submitted_form.web_app_token = mock_form_entity.web_app_token + submitted_form.recipients = [] + submitted_form.rendered_content = mock_form_entity.rendered_content + submitted_form.submitted = True + submitted_form.selected_action_id = scenario["handle"] + submitted_form.submitted_data = {} + submitted_form.expiration_time = naive_utc_now() + datetime.timedelta(days=1) + mock_get_repo.get_form.return_value = submitted_form + def resume_graph_factory( - graph_snapshot: Graph = graph, - state_snapshot: GraphRuntimeState = graph_runtime_state, + initial_result=initial_result, mock_get_repo=mock_get_repo ) -> tuple[Graph, GraphRuntimeState]: - return graph_snapshot, state_snapshot + assert initial_result.graph_runtime_state is not None + serialized_runtime_state = initial_result.graph_runtime_state.dumps() + resume_runtime_state = GraphRuntimeState.from_snapshot(serialized_runtime_state) + return _build_branching_graph(mock_config, mock_get_repo, resume_runtime_state) resume_case = WorkflowTestCase( description=f"HumanInput resumes via {scenario['handle']} branch", @@ -321,7 +359,8 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None: for index, event in enumerate(resume_events) if isinstance(event, NodeRunStreamChunkEvent) and index < human_success_index ] - assert pre_indices == list(range(2, 2 + pre_chunk_count)) + expected_pre_chunk_events_count_in_resumption = len(expected_pre_chunk_events_in_resumption) + assert pre_indices == list(range(expected_pre_chunk_events_count_in_resumption, human_success_index)) resume_chunk_indices = [ index diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py 
b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py index ece69b080b..d8f229205b 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py @@ -1,4 +1,6 @@ +import datetime import time +from unittest.mock import MagicMock from core.model_runtime.entities.llm_entities import LLMMode from core.model_runtime.entities.message_entities import PromptMessageRole @@ -13,11 +15,12 @@ from core.workflow.graph_events import ( NodeRunStreamChunkEvent, NodeRunSucceededEvent, ) +from core.workflow.graph_events.node import NodeRunHumanInputFormFilledEvent from core.workflow.nodes.base.entities import OutputVariableEntity, OutputVariableType from core.workflow.nodes.end.end_node import EndNode from core.workflow.nodes.end.entities import EndNodeData -from core.workflow.nodes.human_input import HumanInputNode -from core.workflow.nodes.human_input.entities import HumanInputNodeData +from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction +from core.workflow.nodes.human_input.human_input_node import HumanInputNode from core.workflow.nodes.llm.entities import ( ContextConfig, LLMNodeChatModelMessage, @@ -27,15 +30,21 @@ from core.workflow.nodes.llm.entities import ( ) from core.workflow.nodes.start.entities import StartNodeData from core.workflow.nodes.start.start_node import StartNode +from core.workflow.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable +from libs.datetime_utils import naive_utc_now from .test_mock_config import MockConfig from .test_mock_nodes import MockLLMNode from .test_table_runner import TableTestRunner, WorkflowTestCase -def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, 
GraphRuntimeState]: +def _build_llm_human_llm_graph( + mock_config: MockConfig, + form_repository: HumanInputFormRepository, + graph_runtime_state: GraphRuntimeState | None = None, +) -> tuple[Graph, GraphRuntimeState]: graph_config: dict[str, object] = {"nodes": [], "edges": []} graph_init_params = GraphInitParams( tenant_id="tenant", @@ -48,12 +57,15 @@ def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRun call_depth=0, ) - variable_pool = VariablePool( - system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"), - user_inputs={}, - conversation_variables=[], - ) - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + if graph_runtime_state is None: + variable_pool = VariablePool( + system_variables=SystemVariable( + user_id="user", app_id="app", workflow_id="workflow", workflow_execution_id="test-execution-id" + ), + user_inputs={}, + conversation_variables=[], + ) + graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} start_node = StartNode( @@ -92,15 +104,21 @@ def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRun human_data = HumanInputNodeData( title="Human Input", - required_variables=["human.input_ready"], - pause_reason="Awaiting human input", + form_content="Human input required", + inputs=[], + user_actions=[ + UserAction(id="accept", title="Accept"), + UserAction(id="reject", title="Reject"), + ], ) + human_config = {"id": "human", "data": human_data.model_dump()} human_node = HumanInputNode( id=human_config["id"], config=human_config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, + form_repository=form_repository, ) llm_second = _create_llm_node("llm_resume", "Follow-up LLM", "Follow-up prompt") @@ -130,7 +148,7 @@ def 
_build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRun .add_root(start_node) .add_node(llm_first) .add_node(human_node) - .add_node(llm_second) + .add_node(llm_second, source_handle="accept") .add_node(end_node) .build() ) @@ -167,8 +185,18 @@ def test_human_input_llm_streaming_order_across_pause() -> None: GraphRunPausedEvent, # graph run pauses awaiting resume ] + mock_create_repo = MagicMock(spec=HumanInputFormRepository) + mock_create_repo.get_form.return_value = None + mock_form_entity = MagicMock(spec=HumanInputFormEntity) + mock_form_entity.id = "test_form_id" + mock_form_entity.web_app_token = "test_web_app_token" + mock_form_entity.recipients = [] + mock_form_entity.rendered_content = "rendered" + mock_form_entity.submitted = False + mock_create_repo.create_form.return_value = mock_form_entity + def graph_factory() -> tuple[Graph, GraphRuntimeState]: - return _build_llm_human_llm_graph(mock_config) + return _build_llm_human_llm_graph(mock_config, mock_create_repo) initial_case = WorkflowTestCase( description="HumanInput pause preserves LLM streaming order", @@ -210,6 +238,8 @@ def test_human_input_llm_streaming_order_across_pause() -> None: expected_resume_sequence: list[type] = [ GraphRunStartedEvent, # resumed graph run begins NodeRunStartedEvent, # human node restarts + # Form Filled should be generated first, then the node execution ends and stream chunk is generated. 
+ NodeRunHumanInputFormFilledEvent, NodeRunStreamChunkEvent, # cached llm_initial chunk 1 NodeRunStreamChunkEvent, # cached llm_initial chunk 2 NodeRunStreamChunkEvent, # cached llm_initial final chunk @@ -225,12 +255,27 @@ def test_human_input_llm_streaming_order_across_pause() -> None: GraphRunSucceededEvent, # graph run succeeds after resume ] + mock_get_repo = MagicMock(spec=HumanInputFormRepository) + submitted_form = MagicMock(spec=HumanInputFormEntity) + submitted_form.id = mock_form_entity.id + submitted_form.web_app_token = mock_form_entity.web_app_token + submitted_form.recipients = [] + submitted_form.rendered_content = mock_form_entity.rendered_content + submitted_form.submitted = True + submitted_form.selected_action_id = "accept" + submitted_form.submitted_data = {} + submitted_form.expiration_time = naive_utc_now() + datetime.timedelta(days=1) + mock_get_repo.get_form.return_value = submitted_form + def resume_graph_factory() -> tuple[Graph, GraphRuntimeState]: - assert graph_runtime_state is not None - assert graph is not None - graph_runtime_state.variable_pool.add(("human", "input_ready"), True) - graph_runtime_state.graph_execution.pause_reason = None - return graph, graph_runtime_state + # restruct the graph runtime state + serialized_runtime_state = initial_result.graph_runtime_state.dumps() + resume_runtime_state = GraphRuntimeState.from_snapshot(serialized_runtime_state) + return _build_llm_human_llm_graph( + mock_config, + mock_get_repo, + resume_runtime_state, + ) resume_case = WorkflowTestCase( description="HumanInput resume continues LLM streaming order", diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py new file mode 100644 index 0000000000..a6aab81f6c --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py @@ -0,0 +1,270 @@ +import time +from 
collections.abc import Mapping +from dataclasses import dataclass +from datetime import datetime, timedelta +from typing import Any, Protocol + +from core.workflow.entities import GraphInitParams +from core.workflow.entities.workflow_start_reason import WorkflowStartReason +from core.workflow.graph import Graph +from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel +from core.workflow.graph_engine.config import GraphEngineConfig +from core.workflow.graph_engine.graph_engine import GraphEngine +from core.workflow.graph_events import ( + GraphRunPausedEvent, + GraphRunStartedEvent, + GraphRunSucceededEvent, + NodeRunSucceededEvent, +) +from core.workflow.nodes.base.entities import OutputVariableEntity +from core.workflow.nodes.end.end_node import EndNode +from core.workflow.nodes.end.entities import EndNodeData +from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction +from core.workflow.nodes.human_input.enums import HumanInputFormStatus +from core.workflow.nodes.human_input.human_input_node import HumanInputNode +from core.workflow.nodes.start.entities import StartNodeData +from core.workflow.nodes.start.start_node import StartNode +from core.workflow.repositories.human_input_form_repository import ( + FormCreateParams, + HumanInputFormEntity, + HumanInputFormRepository, +) +from core.workflow.runtime import GraphRuntimeState, VariablePool +from core.workflow.system_variable import SystemVariable +from libs.datetime_utils import naive_utc_now + + +class PauseStateStore(Protocol): + def save(self, runtime_state: GraphRuntimeState) -> None: ... + + def load(self) -> GraphRuntimeState: ... 
+ + +class InMemoryPauseStore: + def __init__(self) -> None: + self._snapshot: str | None = None + + def save(self, runtime_state: GraphRuntimeState) -> None: + self._snapshot = runtime_state.dumps() + + def load(self) -> GraphRuntimeState: + assert self._snapshot is not None + return GraphRuntimeState.from_snapshot(self._snapshot) + + +@dataclass +class StaticForm(HumanInputFormEntity): + form_id: str + rendered: str + is_submitted: bool + action_id: str | None = None + data: Mapping[str, Any] | None = None + status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING + expiration: datetime = naive_utc_now() + timedelta(days=1) + + @property + def id(self) -> str: + return self.form_id + + @property + def web_app_token(self) -> str | None: + return "token" + + @property + def recipients(self) -> list: + return [] + + @property + def rendered_content(self) -> str: + return self.rendered + + @property + def selected_action_id(self) -> str | None: + return self.action_id + + @property + def submitted_data(self) -> Mapping[str, Any] | None: + return self.data + + @property + def submitted(self) -> bool: + return self.is_submitted + + @property + def status(self) -> HumanInputFormStatus: + return self.status_value + + @property + def expiration_time(self) -> datetime: + return self.expiration + + +class StaticRepo(HumanInputFormRepository): + def __init__(self, forms_by_node_id: Mapping[str, HumanInputFormEntity]) -> None: + self._forms_by_node_id = dict(forms_by_node_id) + + def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: + return self._forms_by_node_id.get(node_id) + + def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: + raise AssertionError("create_form should not be called in resume scenario") + + +def _build_runtime_state() -> GraphRuntimeState: + variable_pool = VariablePool( + system_variables=SystemVariable( + user_id="user", + app_id="app", + workflow_id="workflow", + 
workflow_execution_id="exec-1", + ), + user_inputs={}, + conversation_variables=[], + ) + return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + + +def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepository) -> Graph: + graph_config: dict[str, object] = {"nodes": [], "edges": []} + graph_init_params = GraphInitParams( + tenant_id="tenant", + app_id="app", + workflow_id="workflow", + graph_config=graph_config, + user_id="user", + user_from="account", + invoke_from="debugger", + call_depth=0, + ) + + start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} + start_node = StartNode( + id=start_config["id"], + config=start_config, + graph_init_params=graph_init_params, + graph_runtime_state=runtime_state, + ) + + human_data = HumanInputNodeData( + title="Human Input", + form_content="Human input required", + inputs=[], + user_actions=[UserAction(id="approve", title="Approve")], + ) + + human_a_config = {"id": "human_a", "data": human_data.model_dump()} + human_a = HumanInputNode( + id=human_a_config["id"], + config=human_a_config, + graph_init_params=graph_init_params, + graph_runtime_state=runtime_state, + form_repository=repo, + ) + + human_b_config = {"id": "human_b", "data": human_data.model_dump()} + human_b = HumanInputNode( + id=human_b_config["id"], + config=human_b_config, + graph_init_params=graph_init_params, + graph_runtime_state=runtime_state, + form_repository=repo, + ) + + end_data = EndNodeData( + title="End", + outputs=[ + OutputVariableEntity(variable="res_a", value_selector=["human_a", "__action_id"]), + OutputVariableEntity(variable="res_b", value_selector=["human_b", "__action_id"]), + ], + desc=None, + ) + end_config = {"id": "end", "data": end_data.model_dump()} + end_node = EndNode( + id=end_config["id"], + config=end_config, + graph_init_params=graph_init_params, + graph_runtime_state=runtime_state, + ) + + builder = ( + Graph.new() + 
.add_root(start_node) + .add_node(human_a, from_node_id="start") + .add_node(human_b, from_node_id="start") + .add_node(end_node, from_node_id="human_a", source_handle="approve") + ) + return builder.connect(tail="human_b", head="end", source_handle="approve").build() + + +def _run_graph(graph: Graph, runtime_state: GraphRuntimeState) -> list[object]: + engine = GraphEngine( + workflow_id="workflow", + graph=graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + config=GraphEngineConfig( + min_workers=2, + max_workers=2, + scale_up_threshold=1, + scale_down_idle_time=30.0, + ), + ) + return list(engine.run()) + + +def _form(submitted: bool, action_id: str | None) -> StaticForm: + return StaticForm( + form_id="form", + rendered="rendered", + is_submitted=submitted, + action_id=action_id, + data={}, + status_value=HumanInputFormStatus.SUBMITTED if submitted else HumanInputFormStatus.WAITING, + ) + + +def test_parallel_human_input_join_completes_after_second_resume() -> None: + pause_store: PauseStateStore = InMemoryPauseStore() + + initial_state = _build_runtime_state() + initial_repo = StaticRepo( + { + "human_a": _form(submitted=False, action_id=None), + "human_b": _form(submitted=False, action_id=None), + } + ) + initial_graph = _build_graph(initial_state, initial_repo) + initial_events = _run_graph(initial_graph, initial_state) + + assert isinstance(initial_events[-1], GraphRunPausedEvent) + pause_store.save(initial_state) + + first_resume_state = pause_store.load() + first_resume_repo = StaticRepo( + { + "human_a": _form(submitted=True, action_id="approve"), + "human_b": _form(submitted=False, action_id=None), + } + ) + first_resume_graph = _build_graph(first_resume_state, first_resume_repo) + first_resume_events = _run_graph(first_resume_graph, first_resume_state) + + assert isinstance(first_resume_events[0], GraphRunStartedEvent) + assert first_resume_events[0].reason is WorkflowStartReason.RESUMPTION + assert 
isinstance(first_resume_events[-1], GraphRunPausedEvent) + pause_store.save(first_resume_state) + + second_resume_state = pause_store.load() + second_resume_repo = StaticRepo( + { + "human_a": _form(submitted=True, action_id="approve"), + "human_b": _form(submitted=True, action_id="approve"), + } + ) + second_resume_graph = _build_graph(second_resume_state, second_resume_repo) + second_resume_events = _run_graph(second_resume_graph, second_resume_state) + + assert isinstance(second_resume_events[0], GraphRunStartedEvent) + assert second_resume_events[0].reason is WorkflowStartReason.RESUMPTION + assert isinstance(second_resume_events[-1], GraphRunSucceededEvent) + assert any(isinstance(event, NodeRunSucceededEvent) and event.node_id == "end" for event in second_resume_events) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py new file mode 100644 index 0000000000..62aa56fc57 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py @@ -0,0 +1,333 @@ +import time +from collections.abc import Mapping +from dataclasses import dataclass +from datetime import datetime, timedelta +from typing import Any + +from core.model_runtime.entities.llm_entities import LLMMode +from core.model_runtime.entities.message_entities import PromptMessageRole +from core.workflow.entities import GraphInitParams +from core.workflow.entities.workflow_start_reason import WorkflowStartReason +from core.workflow.graph import Graph +from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel +from core.workflow.graph_engine.config import GraphEngineConfig +from core.workflow.graph_engine.graph_engine import GraphEngine +from core.workflow.graph_events import ( + GraphRunPausedEvent, + GraphRunStartedEvent, + NodeRunPauseRequestedEvent, + 
NodeRunStartedEvent, + NodeRunSucceededEvent, +) +from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction +from core.workflow.nodes.human_input.enums import HumanInputFormStatus +from core.workflow.nodes.human_input.human_input_node import HumanInputNode +from core.workflow.nodes.llm.entities import ( + ContextConfig, + LLMNodeChatModelMessage, + LLMNodeData, + ModelConfig, + VisionConfig, +) +from core.workflow.nodes.start.entities import StartNodeData +from core.workflow.nodes.start.start_node import StartNode +from core.workflow.repositories.human_input_form_repository import ( + FormCreateParams, + HumanInputFormEntity, + HumanInputFormRepository, +) +from core.workflow.runtime import GraphRuntimeState, VariablePool +from core.workflow.system_variable import SystemVariable +from libs.datetime_utils import naive_utc_now + +from .test_mock_config import MockConfig, NodeMockConfig +from .test_mock_nodes import MockLLMNode + + +@dataclass +class StaticForm(HumanInputFormEntity): + form_id: str + rendered: str + is_submitted: bool + action_id: str | None = None + data: Mapping[str, Any] | None = None + status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING + expiration: datetime = naive_utc_now() + timedelta(days=1) + + @property + def id(self) -> str: + return self.form_id + + @property + def web_app_token(self) -> str | None: + return "token" + + @property + def recipients(self) -> list: + return [] + + @property + def rendered_content(self) -> str: + return self.rendered + + @property + def selected_action_id(self) -> str | None: + return self.action_id + + @property + def submitted_data(self) -> Mapping[str, Any] | None: + return self.data + + @property + def submitted(self) -> bool: + return self.is_submitted + + @property + def status(self) -> HumanInputFormStatus: + return self.status_value + + @property + def expiration_time(self) -> datetime: + return self.expiration + + +class StaticRepo(HumanInputFormRepository): + 
def __init__(self, forms_by_node_id: Mapping[str, HumanInputFormEntity]) -> None: + self._forms_by_node_id = dict(forms_by_node_id) + + def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: + return self._forms_by_node_id.get(node_id) + + def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: + raise AssertionError("create_form should not be called in resume scenario") + + +class DelayedHumanInputNode(HumanInputNode): + def __init__(self, delay_seconds: float, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._delay_seconds = delay_seconds + + def _run(self): + if self._delay_seconds > 0: + time.sleep(self._delay_seconds) + yield from super()._run() + + +def _build_runtime_state() -> GraphRuntimeState: + variable_pool = VariablePool( + system_variables=SystemVariable( + user_id="user", + app_id="app", + workflow_id="workflow", + workflow_execution_id="exec-1", + ), + user_inputs={}, + conversation_variables=[], + ) + return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + + +def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepository, mock_config: MockConfig) -> Graph: + graph_config: dict[str, object] = {"nodes": [], "edges": []} + graph_init_params = GraphInitParams( + tenant_id="tenant", + app_id="app", + workflow_id="workflow", + graph_config=graph_config, + user_id="user", + user_from="account", + invoke_from="debugger", + call_depth=0, + ) + + start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} + start_node = StartNode( + id=start_config["id"], + config=start_config, + graph_init_params=graph_init_params, + graph_runtime_state=runtime_state, + ) + + human_data = HumanInputNodeData( + title="Human Input", + form_content="Human input required", + inputs=[], + user_actions=[UserAction(id="approve", title="Approve")], + ) + + human_a_config = {"id": "human_a", "data": human_data.model_dump()} + human_a 
= HumanInputNode( + id=human_a_config["id"], + config=human_a_config, + graph_init_params=graph_init_params, + graph_runtime_state=runtime_state, + form_repository=repo, + ) + + human_b_config = {"id": "human_b", "data": human_data.model_dump()} + human_b = DelayedHumanInputNode( + id=human_b_config["id"], + config=human_b_config, + graph_init_params=graph_init_params, + graph_runtime_state=runtime_state, + form_repository=repo, + delay_seconds=0.2, + ) + + llm_data = LLMNodeData( + title="LLM A", + model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}), + prompt_template=[ + LLMNodeChatModelMessage( + text="Prompt A", + role=PromptMessageRole.USER, + edition_type="basic", + ) + ], + context=ContextConfig(enabled=False, variable_selector=None), + vision=VisionConfig(enabled=False), + reasoning_format="tagged", + structured_output_enabled=False, + ) + llm_config = {"id": "llm_a", "data": llm_data.model_dump()} + llm_a = MockLLMNode( + id=llm_config["id"], + config=llm_config, + graph_init_params=graph_init_params, + graph_runtime_state=runtime_state, + mock_config=mock_config, + ) + + return ( + Graph.new() + .add_root(start_node) + .add_node(human_a, from_node_id="start") + .add_node(human_b, from_node_id="start") + .add_node(llm_a, from_node_id="human_a", source_handle="approve") + .build() + ) + + +def test_parallel_human_input_pause_preserves_node_finished() -> None: + runtime_state = _build_runtime_state() + + runtime_state.graph_execution.start() + runtime_state.register_paused_node("human_a") + runtime_state.register_paused_node("human_b") + + submitted = StaticForm( + form_id="form-a", + rendered="rendered", + is_submitted=True, + action_id="approve", + data={}, + status_value=HumanInputFormStatus.SUBMITTED, + ) + pending = StaticForm( + form_id="form-b", + rendered="rendered", + is_submitted=False, + action_id=None, + data=None, + status_value=HumanInputFormStatus.WAITING, + ) + repo = StaticRepo({"human_a": 
submitted, "human_b": pending}) + + mock_config = MockConfig() + mock_config.simulate_delays = True + mock_config.set_node_config( + "llm_a", + NodeMockConfig(node_id="llm_a", outputs={"text": "LLM A output"}, delay=0.5), + ) + + graph = _build_graph(runtime_state, repo, mock_config) + engine = GraphEngine( + workflow_id="workflow", + graph=graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + config=GraphEngineConfig( + min_workers=2, + max_workers=2, + scale_up_threshold=1, + scale_down_idle_time=30.0, + ), + ) + + events = list(engine.run()) + + llm_started = any(isinstance(e, NodeRunStartedEvent) and e.node_id == "llm_a" for e in events) + llm_succeeded = any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_a" for e in events) + human_b_pause = any(isinstance(e, NodeRunPauseRequestedEvent) and e.node_id == "human_b" for e in events) + graph_paused = any(isinstance(e, GraphRunPausedEvent) for e in events) + graph_started = any(isinstance(e, GraphRunStartedEvent) for e in events) + + assert graph_started + assert graph_paused + assert human_b_pause + assert llm_started + assert llm_succeeded + + +def test_parallel_human_input_pause_preserves_node_finished_after_snapshot_resume() -> None: + base_state = _build_runtime_state() + base_state.graph_execution.start() + base_state.register_paused_node("human_a") + base_state.register_paused_node("human_b") + snapshot = base_state.dumps() + + resumed_state = GraphRuntimeState.from_snapshot(snapshot) + + submitted = StaticForm( + form_id="form-a", + rendered="rendered", + is_submitted=True, + action_id="approve", + data={}, + status_value=HumanInputFormStatus.SUBMITTED, + ) + pending = StaticForm( + form_id="form-b", + rendered="rendered", + is_submitted=False, + action_id=None, + data=None, + status_value=HumanInputFormStatus.WAITING, + ) + repo = StaticRepo({"human_a": submitted, "human_b": pending}) + + mock_config = MockConfig() + mock_config.simulate_delays = True + 
mock_config.set_node_config( + "llm_a", + NodeMockConfig(node_id="llm_a", outputs={"text": "LLM A output"}, delay=0.5), + ) + + graph = _build_graph(resumed_state, repo, mock_config) + engine = GraphEngine( + workflow_id="workflow", + graph=graph, + graph_runtime_state=resumed_state, + command_channel=InMemoryChannel(), + config=GraphEngineConfig( + min_workers=2, + max_workers=2, + scale_up_threshold=1, + scale_down_idle_time=30.0, + ), + ) + + events = list(engine.run()) + + start_event = next(e for e in events if isinstance(e, GraphRunStartedEvent)) + assert start_event.reason is WorkflowStartReason.RESUMPTION + + llm_started = any(isinstance(e, NodeRunStartedEvent) and e.node_id == "llm_a" for e in events) + llm_succeeded = any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_a" for e in events) + human_b_pause = any(isinstance(e, NodeRunPauseRequestedEvent) and e.node_id == "human_b" for e in events) + graph_paused = any(isinstance(e, GraphRunPausedEvent) for e in events) + + assert graph_paused + assert human_b_pause + assert llm_started + assert llm_succeeded diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py new file mode 100644 index 0000000000..156cfefcd6 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py @@ -0,0 +1,309 @@ +import time +from collections.abc import Mapping +from dataclasses import dataclass +from datetime import datetime, timedelta +from typing import Any + +from core.model_runtime.entities.llm_entities import LLMMode +from core.model_runtime.entities.message_entities import PromptMessageRole +from core.workflow.entities import GraphInitParams +from core.workflow.entities.workflow_start_reason import WorkflowStartReason +from core.workflow.graph import Graph +from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel 
+from core.workflow.graph_engine.config import GraphEngineConfig +from core.workflow.graph_engine.graph_engine import GraphEngine +from core.workflow.graph_events import ( + GraphRunPausedEvent, + GraphRunStartedEvent, + NodeRunStartedEvent, + NodeRunSucceededEvent, +) +from core.workflow.nodes.end.end_node import EndNode +from core.workflow.nodes.end.entities import EndNodeData +from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction +from core.workflow.nodes.human_input.enums import HumanInputFormStatus +from core.workflow.nodes.human_input.human_input_node import HumanInputNode +from core.workflow.nodes.llm.entities import ( + ContextConfig, + LLMNodeChatModelMessage, + LLMNodeData, + ModelConfig, + VisionConfig, +) +from core.workflow.nodes.start.entities import StartNodeData +from core.workflow.nodes.start.start_node import StartNode +from core.workflow.repositories.human_input_form_repository import ( + FormCreateParams, + HumanInputFormEntity, + HumanInputFormRepository, +) +from core.workflow.runtime import GraphRuntimeState, VariablePool +from core.workflow.system_variable import SystemVariable +from libs.datetime_utils import naive_utc_now + +from .test_mock_config import MockConfig, NodeMockConfig +from .test_mock_nodes import MockLLMNode + + +@dataclass +class StaticForm(HumanInputFormEntity): + form_id: str + rendered: str + is_submitted: bool + action_id: str | None = None + data: Mapping[str, Any] | None = None + status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING + expiration: datetime = naive_utc_now() + timedelta(days=1) + + @property + def id(self) -> str: + return self.form_id + + @property + def web_app_token(self) -> str | None: + return "token" + + @property + def recipients(self) -> list: + return [] + + @property + def rendered_content(self) -> str: + return self.rendered + + @property + def selected_action_id(self) -> str | None: + return self.action_id + + @property + def submitted_data(self) -> 
Mapping[str, Any] | None: + return self.data + + @property + def submitted(self) -> bool: + return self.is_submitted + + @property + def status(self) -> HumanInputFormStatus: + return self.status_value + + @property + def expiration_time(self) -> datetime: + return self.expiration + + +class StaticRepo(HumanInputFormRepository): + def __init__(self, form: HumanInputFormEntity) -> None: + self._form = form + + def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: + if node_id != "human_pause": + return None + return self._form + + def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: + raise AssertionError("create_form should not be called in this test") + + +def _build_runtime_state() -> GraphRuntimeState: + variable_pool = VariablePool( + system_variables=SystemVariable( + user_id="user", + app_id="app", + workflow_id="workflow", + workflow_execution_id="exec-1", + ), + user_inputs={}, + conversation_variables=[], + ) + return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + + +def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepository, mock_config: MockConfig) -> Graph: + graph_config: dict[str, object] = {"nodes": [], "edges": []} + graph_init_params = GraphInitParams( + tenant_id="tenant", + app_id="app", + workflow_id="workflow", + graph_config=graph_config, + user_id="user", + user_from="account", + invoke_from="debugger", + call_depth=0, + ) + + start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} + start_node = StartNode( + id=start_config["id"], + config=start_config, + graph_init_params=graph_init_params, + graph_runtime_state=runtime_state, + ) + + llm_a_data = LLMNodeData( + title="LLM A", + model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}), + prompt_template=[ + LLMNodeChatModelMessage( + text="Prompt A", + role=PromptMessageRole.USER, + 
edition_type="basic", + ) + ], + context=ContextConfig(enabled=False, variable_selector=None), + vision=VisionConfig(enabled=False), + reasoning_format="tagged", + structured_output_enabled=False, + ) + llm_a_config = {"id": "llm_a", "data": llm_a_data.model_dump()} + llm_a = MockLLMNode( + id=llm_a_config["id"], + config=llm_a_config, + graph_init_params=graph_init_params, + graph_runtime_state=runtime_state, + mock_config=mock_config, + ) + + llm_b_data = LLMNodeData( + title="LLM B", + model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}), + prompt_template=[ + LLMNodeChatModelMessage( + text="Prompt B", + role=PromptMessageRole.USER, + edition_type="basic", + ) + ], + context=ContextConfig(enabled=False, variable_selector=None), + vision=VisionConfig(enabled=False), + reasoning_format="tagged", + structured_output_enabled=False, + ) + llm_b_config = {"id": "llm_b", "data": llm_b_data.model_dump()} + llm_b = MockLLMNode( + id=llm_b_config["id"], + config=llm_b_config, + graph_init_params=graph_init_params, + graph_runtime_state=runtime_state, + mock_config=mock_config, + ) + + human_data = HumanInputNodeData( + title="Human Input", + form_content="Pause here", + inputs=[], + user_actions=[UserAction(id="approve", title="Approve")], + ) + human_config = {"id": "human_pause", "data": human_data.model_dump()} + human_node = HumanInputNode( + id=human_config["id"], + config=human_config, + graph_init_params=graph_init_params, + graph_runtime_state=runtime_state, + form_repository=repo, + ) + + end_human_data = EndNodeData(title="End Human", outputs=[], desc=None) + end_human_config = {"id": "end_human", "data": end_human_data.model_dump()} + end_human = EndNode( + id=end_human_config["id"], + config=end_human_config, + graph_init_params=graph_init_params, + graph_runtime_state=runtime_state, + ) + + return ( + Graph.new() + .add_root(start_node) + .add_node(llm_a, from_node_id="start") + .add_node(human_node, 
from_node_id="start") + .add_node(llm_b, from_node_id="llm_a") + .add_node(end_human, from_node_id="human_pause", source_handle="approve") + .build() + ) + + +def _get_node_started_event(events: list[object], node_id: str) -> NodeRunStartedEvent | None: + for event in events: + if isinstance(event, NodeRunStartedEvent) and event.node_id == node_id: + return event + return None + + +def test_pause_defers_ready_nodes_until_resume() -> None: + runtime_state = _build_runtime_state() + + paused_form = StaticForm( + form_id="form-pause", + rendered="rendered", + is_submitted=False, + status_value=HumanInputFormStatus.WAITING, + ) + pause_repo = StaticRepo(paused_form) + + mock_config = MockConfig() + mock_config.simulate_delays = True + mock_config.set_node_config( + "llm_a", + NodeMockConfig(node_id="llm_a", outputs={"text": "LLM A output"}, delay=0.5), + ) + mock_config.set_node_config( + "llm_b", + NodeMockConfig(node_id="llm_b", outputs={"text": "LLM B output"}, delay=0.0), + ) + + graph = _build_graph(runtime_state, pause_repo, mock_config) + engine = GraphEngine( + workflow_id="workflow", + graph=graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + config=GraphEngineConfig( + min_workers=2, + max_workers=2, + scale_up_threshold=1, + scale_down_idle_time=30.0, + ), + ) + + paused_events = list(engine.run()) + + assert any(isinstance(e, GraphRunPausedEvent) for e in paused_events) + assert any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_a" for e in paused_events) + assert _get_node_started_event(paused_events, "llm_b") is None + + snapshot = runtime_state.dumps() + resumed_state = GraphRuntimeState.from_snapshot(snapshot) + + submitted_form = StaticForm( + form_id="form-pause", + rendered="rendered", + is_submitted=True, + action_id="approve", + data={}, + status_value=HumanInputFormStatus.SUBMITTED, + ) + resume_repo = StaticRepo(submitted_form) + + resumed_graph = _build_graph(resumed_state, resume_repo, mock_config) + 
resumed_engine = GraphEngine( + workflow_id="workflow", + graph=resumed_graph, + graph_runtime_state=resumed_state, + command_channel=InMemoryChannel(), + config=GraphEngineConfig( + min_workers=2, + max_workers=2, + scale_up_threshold=1, + scale_down_idle_time=30.0, + ), + ) + + resumed_events = list(resumed_engine.run()) + + start_event = next(e for e in resumed_events if isinstance(e, GraphRunStartedEvent)) + assert start_event.reason is WorkflowStartReason.RESUMPTION + + llm_b_started = _get_node_started_event(resumed_events, "llm_b") + assert llm_b_started is not None + assert any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_b" for e in resumed_events) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py new file mode 100644 index 0000000000..700b3f4b8b --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py @@ -0,0 +1,217 @@ +import datetime +import time +from typing import Any +from unittest.mock import MagicMock + +from core.workflow.entities import GraphInitParams +from core.workflow.entities.workflow_start_reason import WorkflowStartReason +from core.workflow.graph import Graph +from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel +from core.workflow.graph_engine.graph_engine import GraphEngine +from core.workflow.graph_events import ( + GraphEngineEvent, + GraphRunPausedEvent, + GraphRunSucceededEvent, + NodeRunStartedEvent, + NodeRunSucceededEvent, +) +from core.workflow.graph_events.graph import GraphRunStartedEvent +from core.workflow.nodes.base.entities import OutputVariableEntity +from core.workflow.nodes.end.end_node import EndNode +from core.workflow.nodes.end.entities import EndNodeData +from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction +from core.workflow.nodes.human_input.human_input_node import HumanInputNode 
+from core.workflow.nodes.start.entities import StartNodeData +from core.workflow.nodes.start.start_node import StartNode +from core.workflow.repositories.human_input_form_repository import ( + HumanInputFormEntity, + HumanInputFormRepository, +) +from core.workflow.runtime import GraphRuntimeState, VariablePool +from core.workflow.system_variable import SystemVariable +from libs.datetime_utils import naive_utc_now + + +def _build_runtime_state() -> GraphRuntimeState: + variable_pool = VariablePool( + system_variables=SystemVariable( + user_id="user", + app_id="app", + workflow_id="workflow", + workflow_execution_id="test-execution-id", + ), + user_inputs={}, + conversation_variables=[], + ) + return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + + +def _mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepository: + repo = MagicMock(spec=HumanInputFormRepository) + form_entity = MagicMock(spec=HumanInputFormEntity) + form_entity.id = "test-form-id" + form_entity.web_app_token = "test-form-token" + form_entity.recipients = [] + form_entity.rendered_content = "rendered" + form_entity.submitted = True + form_entity.selected_action_id = action_id + form_entity.submitted_data = {} + form_entity.expiration_time = naive_utc_now() + datetime.timedelta(days=1) + repo.get_form.return_value = form_entity + return repo + + +def _mock_form_repository_without_submission() -> HumanInputFormRepository: + repo = MagicMock(spec=HumanInputFormRepository) + form_entity = MagicMock(spec=HumanInputFormEntity) + form_entity.id = "test-form-id" + form_entity.web_app_token = "test-form-token" + form_entity.recipients = [] + form_entity.rendered_content = "rendered" + form_entity.submitted = False + repo.create_form.return_value = form_entity + repo.get_form.return_value = None + return repo + + +def _build_human_input_graph( + runtime_state: GraphRuntimeState, + form_repository: HumanInputFormRepository, +) -> Graph: + graph_config: 
dict[str, object] = {"nodes": [], "edges": []} + params = GraphInitParams( + tenant_id="tenant", + app_id="app", + workflow_id="workflow", + graph_config=graph_config, + user_id="user", + user_from="account", + invoke_from="service-api", + call_depth=0, + ) + + start_data = StartNodeData(title="start", variables=[]) + start_node = StartNode( + id="start", + config={"id": "start", "data": start_data.model_dump()}, + graph_init_params=params, + graph_runtime_state=runtime_state, + ) + + human_data = HumanInputNodeData( + title="human", + form_content="Awaiting human input", + inputs=[], + user_actions=[ + UserAction(id="continue", title="Continue"), + ], + ) + human_node = HumanInputNode( + id="human", + config={"id": "human", "data": human_data.model_dump()}, + graph_init_params=params, + graph_runtime_state=runtime_state, + form_repository=form_repository, + ) + + end_data = EndNodeData( + title="end", + outputs=[ + OutputVariableEntity(variable="result", value_selector=["human", "action_id"]), + ], + desc=None, + ) + end_node = EndNode( + id="end", + config={"id": "end", "data": end_data.model_dump()}, + graph_init_params=params, + graph_runtime_state=runtime_state, + ) + + return ( + Graph.new() + .add_root(start_node) + .add_node(human_node) + .add_node(end_node, from_node_id="human", source_handle="continue") + .build() + ) + + +def _run_graph(graph: Graph, runtime_state: GraphRuntimeState) -> list[GraphEngineEvent]: + engine = GraphEngine( + workflow_id="workflow", + graph=graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + return list(engine.run()) + + +def _node_successes(events: list[GraphEngineEvent]) -> list[str]: + return [event.node_id for event in events if isinstance(event, NodeRunSucceededEvent)] + + +def _node_start_event(events: list[GraphEngineEvent], node_id: str) -> NodeRunStartedEvent | None: + for event in events: + if isinstance(event, NodeRunStartedEvent) and event.node_id == node_id: + return event + 
return None + + +def _segment_value(variable_pool: VariablePool, selector: tuple[str, str]) -> Any: + segment = variable_pool.get(selector) + assert segment is not None + return getattr(segment, "value", segment) + + +def test_engine_resume_restores_state_and_completion(): + # Baseline run without pausing + baseline_state = _build_runtime_state() + baseline_repo = _mock_form_repository_with_submission(action_id="continue") + baseline_graph = _build_human_input_graph(baseline_state, baseline_repo) + baseline_events = _run_graph(baseline_graph, baseline_state) + assert baseline_events + first_paused_event = baseline_events[0] + assert isinstance(first_paused_event, GraphRunStartedEvent) + assert first_paused_event.reason is WorkflowStartReason.INITIAL + assert isinstance(baseline_events[-1], GraphRunSucceededEvent) + baseline_success_nodes = _node_successes(baseline_events) + + # Run with pause + paused_state = _build_runtime_state() + pause_repo = _mock_form_repository_without_submission() + paused_graph = _build_human_input_graph(paused_state, pause_repo) + paused_events = _run_graph(paused_graph, paused_state) + assert paused_events + first_paused_event = paused_events[0] + assert isinstance(first_paused_event, GraphRunStartedEvent) + assert first_paused_event.reason is WorkflowStartReason.INITIAL + assert isinstance(paused_events[-1], GraphRunPausedEvent) + snapshot = paused_state.dumps() + + # Resume from snapshot + resumed_state = GraphRuntimeState.from_snapshot(snapshot) + resume_repo = _mock_form_repository_with_submission(action_id="continue") + resumed_graph = _build_human_input_graph(resumed_state, resume_repo) + resumed_events = _run_graph(resumed_graph, resumed_state) + assert resumed_events + first_resumed_event = resumed_events[0] + assert isinstance(first_resumed_event, GraphRunStartedEvent) + assert first_resumed_event.reason is WorkflowStartReason.RESUMPTION + assert isinstance(resumed_events[-1], GraphRunSucceededEvent) + + combined_success_nodes = 
_node_successes(paused_events) + _node_successes(resumed_events) + assert combined_success_nodes == baseline_success_nodes + + paused_human_started = _node_start_event(paused_events, "human") + resumed_human_started = _node_start_event(resumed_events, "human") + assert paused_human_started is not None + assert resumed_human_started is not None + assert paused_human_started.id == resumed_human_started.id + + assert baseline_state.outputs == resumed_state.outputs + assert _segment_value(baseline_state.variable_pool, ("human", "__action_id")) == _segment_value( + resumed_state.variable_pool, ("human", "__action_id") + ) + assert baseline_state.graph_execution.completed + assert resumed_state.graph_execution.completed diff --git a/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py index 488b47761b..21a642c2f8 100644 --- a/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py @@ -7,6 +7,7 @@ from core.workflow.nodes.base.node import Node # Ensures that all node classes are imported. from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING +# Ensure `NODE_TYPE_CLASSES_MAPPING` is used and not automatically removed. 
_ = NODE_TYPE_CLASSES_MAPPING @@ -45,7 +46,9 @@ def test_ensure_subclasses_of_base_node_has_node_type_and_version_method_defined assert isinstance(cls.node_type, NodeType) assert isinstance(node_version, str) node_type_and_version = (node_type, node_version) - assert node_type_and_version not in type_version_set + assert node_type_and_version not in type_version_set, ( + f"Duplicate node type and version for class: {cls=} {node_type_and_version=}" + ) type_version_set.add(node_type_and_version) diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/__init__.py b/api/tests/unit_tests/core/workflow/nodes/human_input/__init__.py new file mode 100644 index 0000000000..20807e9ef9 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/__init__.py @@ -0,0 +1 @@ +# Unit tests for human input node diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py new file mode 100644 index 0000000000..ca4a887d20 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py @@ -0,0 +1,16 @@ +from core.workflow.nodes.human_input.entities import EmailDeliveryConfig, EmailRecipients +from core.workflow.runtime import VariablePool + + +def test_render_body_template_replaces_variable_values(): + config = EmailDeliveryConfig( + recipients=EmailRecipients(), + subject="Subject", + body="Hello {{#node1.value#}} {{#url#}}", + ) + variable_pool = VariablePool() + variable_pool.add(["node1", "value"], "World") + + result = config.render_body_template(body=config.body, url="https://example.com", variable_pool=variable_pool) + + assert result == "Hello World https://example.com" diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py new file mode 100644 index 0000000000..bfe7b03c13 --- /dev/null +++ 
b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py @@ -0,0 +1,597 @@ +""" +Unit tests for human input node entities. +""" + +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest +from pydantic import ValidationError + +from core.workflow.entities import GraphInitParams +from core.workflow.node_events import PauseRequestedEvent +from core.workflow.node_events.node import StreamCompletedEvent +from core.workflow.nodes.human_input.entities import ( + EmailDeliveryConfig, + EmailDeliveryMethod, + EmailRecipients, + ExternalRecipient, + FormInput, + FormInputDefault, + HumanInputNodeData, + MemberRecipient, + UserAction, + WebAppDeliveryMethod, + _WebAppDeliveryConfig, +) +from core.workflow.nodes.human_input.enums import ( + ButtonStyle, + DeliveryMethodType, + EmailRecipientType, + FormInputType, + PlaceholderType, + TimeoutUnit, +) +from core.workflow.nodes.human_input.human_input_node import HumanInputNode +from core.workflow.repositories.human_input_form_repository import HumanInputFormRepository +from core.workflow.runtime import GraphRuntimeState, VariablePool +from core.workflow.system_variable import SystemVariable +from tests.unit_tests.core.workflow.graph_engine.human_input_test_utils import InMemoryHumanInputFormRepository + + +class TestDeliveryMethod: + """Test DeliveryMethod entity.""" + + def test_webapp_delivery_method(self): + """Test webapp delivery method creation.""" + delivery_method = WebAppDeliveryMethod(enabled=True, config=_WebAppDeliveryConfig()) + + assert delivery_method.type == DeliveryMethodType.WEBAPP + assert delivery_method.enabled is True + assert isinstance(delivery_method.config, _WebAppDeliveryConfig) + + def test_email_delivery_method(self): + """Test email delivery method creation.""" + recipients = EmailRecipients( + whole_workspace=False, + items=[ + MemberRecipient(type=EmailRecipientType.MEMBER, user_id="test-user-123"), + 
ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="test@example.com"), + ], + ) + + config = EmailDeliveryConfig( + recipients=recipients, subject="Test Subject", body="Test body with {{#url#}} placeholder" + ) + + delivery_method = EmailDeliveryMethod(enabled=True, config=config) + + assert delivery_method.type == DeliveryMethodType.EMAIL + assert delivery_method.enabled is True + assert isinstance(delivery_method.config, EmailDeliveryConfig) + assert delivery_method.config.subject == "Test Subject" + assert len(delivery_method.config.recipients.items) == 2 + + +class TestFormInput: + """Test FormInput entity.""" + + def test_text_input_with_constant_default(self): + """Test text input with constant default value.""" + default = FormInputDefault(type=PlaceholderType.CONSTANT, value="Enter your response here...") + + form_input = FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="user_input", default=default) + + assert form_input.type == FormInputType.TEXT_INPUT + assert form_input.output_variable_name == "user_input" + assert form_input.default.type == PlaceholderType.CONSTANT + assert form_input.default.value == "Enter your response here..." 
+ + def test_text_input_with_variable_default(self): + """Test text input with variable default value.""" + default = FormInputDefault(type=PlaceholderType.VARIABLE, selector=["node_123", "output_var"]) + + form_input = FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="user_input", default=default) + + assert form_input.default.type == PlaceholderType.VARIABLE + assert form_input.default.selector == ["node_123", "output_var"] + + def test_form_input_without_default(self): + """Test form input without default value.""" + form_input = FormInput(type=FormInputType.PARAGRAPH, output_variable_name="description") + + assert form_input.type == FormInputType.PARAGRAPH + assert form_input.output_variable_name == "description" + assert form_input.default is None + + +class TestUserAction: + """Test UserAction entity.""" + + def test_user_action_creation(self): + """Test user action creation.""" + action = UserAction(id="approve", title="Approve", button_style=ButtonStyle.PRIMARY) + + assert action.id == "approve" + assert action.title == "Approve" + assert action.button_style == ButtonStyle.PRIMARY + + def test_user_action_default_button_style(self): + """Test user action with default button style.""" + action = UserAction(id="cancel", title="Cancel") + + assert action.button_style == ButtonStyle.DEFAULT + + def test_user_action_length_boundaries(self): + """Test user action id and title length boundaries.""" + action = UserAction(id="a" * 20, title="b" * 20) + + assert action.id == "a" * 20 + assert action.title == "b" * 20 + + @pytest.mark.parametrize( + ("field_name", "value"), + [ + ("id", "a" * 21), + ("title", "b" * 21), + ], + ) + def test_user_action_length_limits(self, field_name: str, value: str): + """User action fields should enforce max length.""" + data = {"id": "approve", "title": "Approve"} + data[field_name] = value + + with pytest.raises(ValidationError) as exc_info: + UserAction(**data) + + errors = exc_info.value.errors() + assert 
any(error["loc"] == (field_name,) and error["type"] == "string_too_long" for error in errors) + + +class TestHumanInputNodeData: + """Test HumanInputNodeData entity.""" + + def test_valid_node_data_creation(self): + """Test creating valid human input node data.""" + delivery_methods = [WebAppDeliveryMethod(enabled=True, config=_WebAppDeliveryConfig())] + + inputs = [ + FormInput( + type=FormInputType.TEXT_INPUT, + output_variable_name="content", + default=FormInputDefault(type=PlaceholderType.CONSTANT, value="Enter content..."), + ) + ] + + user_actions = [UserAction(id="submit", title="Submit", button_style=ButtonStyle.PRIMARY)] + + node_data = HumanInputNodeData( + title="Human Input Test", + desc="Test node description", + delivery_methods=delivery_methods, + form_content="# Test Form\n\nPlease provide input:\n\n{{#$output.content#}}", + inputs=inputs, + user_actions=user_actions, + timeout=24, + timeout_unit=TimeoutUnit.HOUR, + ) + + assert node_data.title == "Human Input Test" + assert node_data.desc == "Test node description" + assert len(node_data.delivery_methods) == 1 + assert node_data.form_content.startswith("# Test Form") + assert len(node_data.inputs) == 1 + assert len(node_data.user_actions) == 1 + assert node_data.timeout == 24 + assert node_data.timeout_unit == TimeoutUnit.HOUR + + def test_node_data_with_multiple_delivery_methods(self): + """Test node data with multiple delivery methods.""" + delivery_methods = [ + WebAppDeliveryMethod(enabled=True, config=_WebAppDeliveryConfig()), + EmailDeliveryMethod( + enabled=False, # Disabled method should be fine + config=EmailDeliveryConfig( + subject="Hi there", body="", recipients=EmailRecipients(whole_workspace=True) + ), + ), + ] + + node_data = HumanInputNodeData( + title="Test Node", delivery_methods=delivery_methods, timeout=1, timeout_unit=TimeoutUnit.DAY + ) + + assert len(node_data.delivery_methods) == 2 + assert node_data.timeout == 1 + assert node_data.timeout_unit == TimeoutUnit.DAY + + def 
test_node_data_defaults(self): + """Test node data with default values.""" + node_data = HumanInputNodeData(title="Test Node") + + assert node_data.title == "Test Node" + assert node_data.desc is None + assert node_data.delivery_methods == [] + assert node_data.form_content == "" + assert node_data.inputs == [] + assert node_data.user_actions == [] + assert node_data.timeout == 36 + assert node_data.timeout_unit == TimeoutUnit.HOUR + + def test_duplicate_input_output_variable_name_raises_validation_error(self): + """Duplicate form input output_variable_name should raise validation error.""" + duplicate_inputs = [ + FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="content"), + FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="content"), + ] + + with pytest.raises(ValidationError, match="duplicated output_variable_name 'content'"): + HumanInputNodeData(title="Test Node", inputs=duplicate_inputs) + + def test_duplicate_user_action_ids_raise_validation_error(self): + """Duplicate user action ids should raise validation error.""" + duplicate_actions = [ + UserAction(id="submit", title="Submit"), + UserAction(id="submit", title="Submit Again"), + ] + + with pytest.raises(ValidationError, match="duplicated user action id 'submit'"): + HumanInputNodeData(title="Test Node", user_actions=duplicate_actions) + + def test_extract_outputs_field_names(self): + content = r"""This is titile {{#start.title#}} + + A content is required: + + {{#$output.content#}} + + A ending is required: + + {{#$output.ending#}} + """ + + node_data = HumanInputNodeData(title="Human Input", form_content=content) + field_names = node_data.outputs_field_names() + assert field_names == ["content", "ending"] + + +class TestRecipients: + """Test email recipient entities.""" + + def test_member_recipient(self): + """Test member recipient creation.""" + recipient = MemberRecipient(type=EmailRecipientType.MEMBER, user_id="user-123") + + assert recipient.type == 
EmailRecipientType.MEMBER + assert recipient.user_id == "user-123" + + def test_external_recipient(self): + """Test external recipient creation.""" + recipient = ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="test@example.com") + + assert recipient.type == EmailRecipientType.EXTERNAL + assert recipient.email == "test@example.com" + + def test_email_recipients_whole_workspace(self): + """Test email recipients with whole workspace enabled.""" + recipients = EmailRecipients( + whole_workspace=True, items=[MemberRecipient(type=EmailRecipientType.MEMBER, user_id="user-123")] + ) + + assert recipients.whole_workspace is True + assert len(recipients.items) == 1 # Items are preserved even when whole_workspace is True + + def test_email_recipients_specific_users(self): + """Test email recipients with specific users.""" + recipients = EmailRecipients( + whole_workspace=False, + items=[ + MemberRecipient(type=EmailRecipientType.MEMBER, user_id="user-123"), + ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="external@example.com"), + ], + ) + + assert recipients.whole_workspace is False + assert len(recipients.items) == 2 + assert recipients.items[0].user_id == "user-123" + assert recipients.items[1].email == "external@example.com" + + +class TestHumanInputNodeVariableResolution: + """Tests for resolving variable-based defaults in HumanInputNode.""" + + def test_resolves_variable_defaults(self): + variable_pool = VariablePool( + system_variables=SystemVariable( + user_id="user", + app_id="app", + workflow_id="workflow", + workflow_execution_id="exec-1", + ), + user_inputs={}, + conversation_variables=[], + ) + variable_pool.add(("start", "name"), "Jane Doe") + runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) + graph_init_params = GraphInitParams( + tenant_id="tenant", + app_id="app", + workflow_id="workflow", + graph_config={"nodes": [], "edges": []}, + user_id="user", + user_from="account", + invoke_from="debugger", + 
call_depth=0, + ) + + node_data = HumanInputNodeData( + title="Human Input", + form_content="Provide your name", + inputs=[ + FormInput( + type=FormInputType.TEXT_INPUT, + output_variable_name="user_name", + default=FormInputDefault(type=PlaceholderType.VARIABLE, selector=["start", "name"]), + ), + FormInput( + type=FormInputType.TEXT_INPUT, + output_variable_name="user_email", + default=FormInputDefault(type=PlaceholderType.CONSTANT, value="foo@example.com"), + ), + ], + user_actions=[UserAction(id="submit", title="Submit")], + ) + config = {"id": "human", "data": node_data.model_dump()} + + mock_repo = MagicMock(spec=HumanInputFormRepository) + mock_repo.get_form.return_value = None + mock_repo.create_form.return_value = SimpleNamespace( + id="form-1", + rendered_content="Provide your name", + web_app_token="token", + recipients=[], + submitted=False, + ) + + node = HumanInputNode( + id=config["id"], + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=runtime_state, + form_repository=mock_repo, + ) + + run_result = node._run() + pause_event = next(run_result) + + assert isinstance(pause_event, PauseRequestedEvent) + expected_values = {"user_name": "Jane Doe"} + assert pause_event.reason.resolved_default_values == expected_values + + params = mock_repo.create_form.call_args.args[0] + assert params.resolved_default_values == expected_values + + def test_debugger_falls_back_to_recipient_token_when_webapp_disabled(self): + variable_pool = VariablePool( + system_variables=SystemVariable( + user_id="user", + app_id="app", + workflow_id="workflow", + workflow_execution_id="exec-2", + ), + user_inputs={}, + conversation_variables=[], + ) + runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) + graph_init_params = GraphInitParams( + tenant_id="tenant", + app_id="app", + workflow_id="workflow", + graph_config={"nodes": [], "edges": []}, + user_id="user", + user_from="account", + invoke_from="debugger", + call_depth=0, + ) + 
+ node_data = HumanInputNodeData( + title="Human Input", + form_content="Provide your name", + inputs=[], + user_actions=[UserAction(id="submit", title="Submit")], + ) + config = {"id": "human", "data": node_data.model_dump()} + + mock_repo = MagicMock(spec=HumanInputFormRepository) + mock_repo.get_form.return_value = None + mock_repo.create_form.return_value = SimpleNamespace( + id="form-2", + rendered_content="Provide your name", + web_app_token="console-token", + recipients=[SimpleNamespace(token="recipient-token")], + submitted=False, + ) + + node = HumanInputNode( + id=config["id"], + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=runtime_state, + form_repository=mock_repo, + ) + + run_result = node._run() + pause_event = next(run_result) + + assert isinstance(pause_event, PauseRequestedEvent) + assert pause_event.reason.form_token == "console-token" + + def test_debugger_debug_mode_overrides_email_recipients(self): + variable_pool = VariablePool( + system_variables=SystemVariable( + user_id="user-123", + app_id="app", + workflow_id="workflow", + workflow_execution_id="exec-3", + ), + user_inputs={}, + conversation_variables=[], + ) + runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) + graph_init_params = GraphInitParams( + tenant_id="tenant", + app_id="app", + workflow_id="workflow", + graph_config={"nodes": [], "edges": []}, + user_id="user-123", + user_from="account", + invoke_from="debugger", + call_depth=0, + ) + + node_data = HumanInputNodeData( + title="Human Input", + form_content="Provide your name", + inputs=[], + user_actions=[UserAction(id="submit", title="Submit")], + delivery_methods=[ + EmailDeliveryMethod( + enabled=True, + config=EmailDeliveryConfig( + recipients=EmailRecipients( + whole_workspace=False, + items=[ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="target@example.com")], + ), + subject="Subject", + body="Body", + debug_mode=True, + ), + ) + ], + ) + config = {"id": 
"human", "data": node_data.model_dump()} + + mock_repo = MagicMock(spec=HumanInputFormRepository) + mock_repo.get_form.return_value = None + mock_repo.create_form.return_value = SimpleNamespace( + id="form-3", + rendered_content="Provide your name", + web_app_token="token", + recipients=[], + submitted=False, + ) + + node = HumanInputNode( + id=config["id"], + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=runtime_state, + form_repository=mock_repo, + ) + + run_result = node._run() + pause_event = next(run_result) + assert isinstance(pause_event, PauseRequestedEvent) + + params = mock_repo.create_form.call_args.args[0] + assert len(params.delivery_methods) == 1 + method = params.delivery_methods[0] + assert isinstance(method, EmailDeliveryMethod) + assert method.config.debug_mode is True + assert method.config.recipients.whole_workspace is False + assert len(method.config.recipients.items) == 1 + recipient = method.config.recipients.items[0] + assert isinstance(recipient, MemberRecipient) + assert recipient.user_id == "user-123" + + +class TestValidation: + """Test validation scenarios.""" + + def test_invalid_form_input_type(self): + """Test validation with invalid form input type.""" + with pytest.raises(ValidationError): + FormInput( + type="invalid-type", # Invalid type + output_variable_name="test", + ) + + def test_invalid_button_style(self): + """Test validation with invalid button style.""" + with pytest.raises(ValidationError): + UserAction( + id="test", + title="Test", + button_style="invalid-style", # Invalid style + ) + + def test_invalid_timeout_unit(self): + """Test validation with invalid timeout unit.""" + with pytest.raises(ValidationError): + HumanInputNodeData( + title="Test", + timeout_unit="invalid-unit", # Invalid unit + ) + + +class TestHumanInputNodeRenderedContent: + """Tests for rendering submitted content.""" + + def test_replaces_outputs_placeholders_after_submission(self): + variable_pool = VariablePool( + 
system_variables=SystemVariable( + user_id="user", + app_id="app", + workflow_id="workflow", + workflow_execution_id="exec-1", + ), + user_inputs={}, + conversation_variables=[], + ) + runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) + graph_init_params = GraphInitParams( + tenant_id="tenant", + app_id="app", + workflow_id="workflow", + graph_config={"nodes": [], "edges": []}, + user_id="user", + user_from="account", + invoke_from="debugger", + call_depth=0, + ) + + node_data = HumanInputNodeData( + title="Human Input", + form_content="Name: {{#$output.name#}}", + inputs=[ + FormInput( + type=FormInputType.TEXT_INPUT, + output_variable_name="name", + ) + ], + user_actions=[UserAction(id="approve", title="Approve")], + ) + config = {"id": "human", "data": node_data.model_dump()} + + form_repository = InMemoryHumanInputFormRepository() + node = HumanInputNode( + id=config["id"], + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=runtime_state, + form_repository=form_repository, + ) + + pause_gen = node._run() + pause_event = next(pause_gen) + assert isinstance(pause_event, PauseRequestedEvent) + with pytest.raises(StopIteration): + next(pause_gen) + + form_repository.set_submission(action_id="approve", form_data={"name": "Alice"}) + + events = list(node._run()) + last_event = events[-1] + assert isinstance(last_event, StreamCompletedEvent) + node_run_result = last_event.node_run_result + assert node_run_result.outputs["__rendered_content"] == "Name: Alice" diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py new file mode 100644 index 0000000000..a19ee4dee3 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py @@ -0,0 +1,172 @@ +import datetime +from types import SimpleNamespace + +from 
core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.graph_init_params import GraphInitParams +from core.workflow.enums import NodeType +from core.workflow.graph_events import ( + NodeRunHumanInputFormFilledEvent, + NodeRunHumanInputFormTimeoutEvent, + NodeRunStartedEvent, +) +from core.workflow.nodes.human_input.enums import HumanInputFormStatus +from core.workflow.nodes.human_input.human_input_node import HumanInputNode +from core.workflow.runtime import GraphRuntimeState, VariablePool +from core.workflow.system_variable import SystemVariable +from libs.datetime_utils import naive_utc_now +from models.enums import UserFrom + + +class _FakeFormRepository: + def __init__(self, form): + self._form = form + + def get_form(self, *_args, **_kwargs): + return self._form + + +def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name#}}") -> HumanInputNode: + system_variables = SystemVariable.default() + graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=system_variables, user_inputs={}, environment_variables=[]), + start_at=0.0, + ) + graph_init_params = GraphInitParams( + tenant_id="tenant", + app_id="app", + workflow_id="workflow", + graph_config={"nodes": [], "edges": []}, + user_id="user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.SERVICE_API, + call_depth=0, + ) + + config = { + "id": "node-1", + "type": NodeType.HUMAN_INPUT.value, + "data": { + "title": "Human Input", + "form_content": form_content, + "inputs": [ + { + "type": "text_input", + "output_variable_name": "name", + "default": {"type": "constant", "value": ""}, + } + ], + "user_actions": [ + { + "id": "Accept", + "title": "Approve", + "button_style": "default", + } + ], + }, + } + + fake_form = SimpleNamespace( + id="form-1", + rendered_content=form_content, + submitted=True, + selected_action_id="Accept", + submitted_data={"name": "Alice"}, + status=HumanInputFormStatus.SUBMITTED, + 
expiration_time=naive_utc_now() + datetime.timedelta(days=1), + ) + + repo = _FakeFormRepository(fake_form) + return HumanInputNode( + id="node-1", + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + form_repository=repo, + ) + + +def _build_timeout_node() -> HumanInputNode: + system_variables = SystemVariable.default() + graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=system_variables, user_inputs={}, environment_variables=[]), + start_at=0.0, + ) + graph_init_params = GraphInitParams( + tenant_id="tenant", + app_id="app", + workflow_id="workflow", + graph_config={"nodes": [], "edges": []}, + user_id="user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.SERVICE_API, + call_depth=0, + ) + + config = { + "id": "node-1", + "type": NodeType.HUMAN_INPUT.value, + "data": { + "title": "Human Input", + "form_content": "Please enter your name:\n\n{{#$output.name#}}", + "inputs": [ + { + "type": "text_input", + "output_variable_name": "name", + "default": {"type": "constant", "value": ""}, + } + ], + "user_actions": [ + { + "id": "Accept", + "title": "Approve", + "button_style": "default", + } + ], + }, + } + + fake_form = SimpleNamespace( + id="form-1", + rendered_content="content", + submitted=False, + selected_action_id=None, + submitted_data=None, + status=HumanInputFormStatus.TIMEOUT, + expiration_time=naive_utc_now() - datetime.timedelta(minutes=1), + ) + + repo = _FakeFormRepository(fake_form) + return HumanInputNode( + id="node-1", + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + form_repository=repo, + ) + + +def test_human_input_node_emits_form_filled_event_before_succeeded(): + node = _build_node() + + events = list(node.run()) + + assert isinstance(events[0], NodeRunStartedEvent) + assert isinstance(events[1], NodeRunHumanInputFormFilledEvent) + + filled_event = events[1] + assert filled_event.node_title == "Human 
Input" + assert filled_event.rendered_content.endswith("Alice") + assert filled_event.action_id == "Accept" + assert filled_event.action_text == "Approve" + + +def test_human_input_node_emits_timeout_event_before_succeeded(): + node = _build_timeout_node() + + events = list(node.run()) + + assert isinstance(events[0], NodeRunStartedEvent) + assert isinstance(events[1], NodeRunHumanInputFormTimeoutEvent) + + timeout_event = events[1] + assert timeout_event.node_title == "Human Input" diff --git a/api/tests/unit_tests/core/workflow/test_variable_pool_conver.py b/api/tests/unit_tests/core/workflow/test_variable_pool_conver.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/extensions/test_celery_ssl.py b/api/tests/unit_tests/extensions/test_celery_ssl.py index d3a4d69f07..38477409bb 100644 --- a/api/tests/unit_tests/extensions/test_celery_ssl.py +++ b/api/tests/unit_tests/extensions/test_celery_ssl.py @@ -104,6 +104,7 @@ class TestCelerySSLConfiguration: def test_celery_init_applies_ssl_to_broker_and_backend(self): """Test that SSL options are applied to both broker and backend when using Redis.""" mock_config = MagicMock() + mock_config.HUMAN_INPUT_TIMEOUT_TASK_INTERVAL = 1 mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0" mock_config.CELERY_BACKEND = "redis" mock_config.CELERY_RESULT_BACKEND = "redis://localhost:6379/0" diff --git a/api/tests/unit_tests/extensions/test_pubsub_channel.py b/api/tests/unit_tests/extensions/test_pubsub_channel.py new file mode 100644 index 0000000000..a5b41a7266 --- /dev/null +++ b/api/tests/unit_tests/extensions/test_pubsub_channel.py @@ -0,0 +1,20 @@ +from configs import dify_config +from extensions import ext_redis +from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel +from libs.broadcast_channel.redis.sharded_channel import ShardedRedisBroadcastChannel + + +def test_get_pubsub_broadcast_channel_defaults_to_pubsub(monkeypatch): + 
monkeypatch.setattr(dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "pubsub") + + channel = ext_redis.get_pubsub_broadcast_channel() + + assert isinstance(channel, RedisBroadcastChannel) + + +def test_get_pubsub_broadcast_channel_sharded(monkeypatch): + monkeypatch.setattr(dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "sharded") + + channel = ext_redis.get_pubsub_broadcast_channel() + + assert isinstance(channel, ShardedRedisBroadcastChannel) diff --git a/api/tests/unit_tests/libs/_human_input/__init__.py b/api/tests/unit_tests/libs/_human_input/__init__.py new file mode 100644 index 0000000000..66714e72f8 --- /dev/null +++ b/api/tests/unit_tests/libs/_human_input/__init__.py @@ -0,0 +1 @@ +# Treat this directory as a package so support modules can be imported relatively. diff --git a/api/tests/unit_tests/libs/_human_input/support.py b/api/tests/unit_tests/libs/_human_input/support.py new file mode 100644 index 0000000000..bd86c13a2c --- /dev/null +++ b/api/tests/unit_tests/libs/_human_input/support.py @@ -0,0 +1,249 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from datetime import datetime, timedelta +from typing import Any + +from core.workflow.nodes.human_input.entities import FormInput +from core.workflow.nodes.human_input.enums import TimeoutUnit + + +# Exceptions +class HumanInputError(Exception): + error_code: str = "unknown" + + def __init__(self, message: str = "", error_code: str | None = None): + super().__init__(message) + self.message = message or self.__class__.__name__ + if error_code: + self.error_code = error_code + + +class FormNotFoundError(HumanInputError): + error_code = "form_not_found" + + +class FormExpiredError(HumanInputError): + error_code = "human_input_form_expired" + + +class FormAlreadySubmittedError(HumanInputError): + error_code = "human_input_form_submitted" + + +class InvalidFormDataError(HumanInputError): + error_code = "invalid_form_data" + + +# Models +@dataclass +class HumanInputForm: + form_id: 
str + workflow_run_id: str + node_id: str + tenant_id: str + app_id: str | None + form_content: str + inputs: list[FormInput] + user_actions: list[dict[str, Any]] + timeout: int + timeout_unit: TimeoutUnit + form_token: str | None = None + created_at: datetime = field(default_factory=datetime.utcnow) + expires_at: datetime | None = None + submitted_at: datetime | None = None + submitted_data: dict[str, Any] | None = None + submitted_action: str | None = None + + def __post_init__(self) -> None: + if self.expires_at is None: + self.calculate_expiration() + + @property + def is_expired(self) -> bool: + return self.expires_at is not None and datetime.utcnow() > self.expires_at + + @property + def is_submitted(self) -> bool: + return self.submitted_at is not None + + def mark_submitted(self, inputs: dict[str, Any], action: str) -> None: + self.submitted_data = inputs + self.submitted_action = action + self.submitted_at = datetime.utcnow() + + def submit(self, inputs: dict[str, Any], action: str) -> None: + self.mark_submitted(inputs, action) + + def calculate_expiration(self) -> None: + start = self.created_at + if self.timeout_unit == TimeoutUnit.HOUR: + self.expires_at = start + timedelta(hours=self.timeout) + elif self.timeout_unit == TimeoutUnit.DAY: + self.expires_at = start + timedelta(days=self.timeout) + else: + raise ValueError(f"Unsupported timeout unit {self.timeout_unit}") + + def to_response_dict(self, *, include_site_info: bool) -> dict[str, Any]: + inputs_response = [ + { + "type": form_input.type.name.lower().replace("_", "-"), + "output_variable_name": form_input.output_variable_name, + } + for form_input in self.inputs + ] + response = { + "form_content": self.form_content, + "inputs": inputs_response, + "user_actions": self.user_actions, + } + if include_site_info: + response["site"] = {"app_id": self.app_id, "title": "Workflow Form"} + return response + + +@dataclass +class FormSubmissionData: + form_id: str + inputs: dict[str, Any] + action: str + 
submitted_at: datetime = field(default_factory=datetime.utcnow) + + @classmethod + def from_request(cls, form_id: str, request: FormSubmissionRequest) -> FormSubmissionData: # type: ignore + return cls(form_id=form_id, inputs=request.inputs, action=request.action) + + +@dataclass +class FormSubmissionRequest: + inputs: dict[str, Any] + action: str + + +# Repository +class InMemoryFormRepository: + """ + Simple in-memory repository used by unit tests. + """ + + def __init__(self): + self._forms: dict[str, HumanInputForm] = {} + + @property + def forms(self) -> dict[str, HumanInputForm]: + return self._forms + + def save(self, form: HumanInputForm) -> None: + self._forms[form.form_id] = form + + def get_by_id(self, form_id: str) -> HumanInputForm | None: + return self._forms.get(form_id) + + def get_by_token(self, token: str) -> HumanInputForm | None: + for form in self._forms.values(): + if form.form_token == token: + return form + return None + + def delete(self, form_id: str) -> None: + self._forms.pop(form_id, None) + + +# Service +class FormService: + """Service layer for managing human input forms in tests.""" + + def __init__(self, repository: InMemoryFormRepository): + self.repository = repository + + def create_form( + self, + *, + form_id: str, + workflow_run_id: str, + node_id: str, + tenant_id: str, + app_id: str | None, + form_content: str, + inputs, + user_actions, + timeout: int, + timeout_unit: TimeoutUnit, + form_token: str | None = None, + ) -> HumanInputForm: + form = HumanInputForm( + form_id=form_id, + workflow_run_id=workflow_run_id, + node_id=node_id, + tenant_id=tenant_id, + app_id=app_id, + form_content=form_content, + inputs=list(inputs), + user_actions=[{"id": action.id, "title": action.title} for action in user_actions], + timeout=timeout, + timeout_unit=timeout_unit, + form_token=form_token, + ) + form.calculate_expiration() + self.repository.save(form) + return form + + def get_form_by_id(self, form_id: str) -> HumanInputForm: + form = 
self.repository.get_by_id(form_id) + if form is None: + raise FormNotFoundError() + return form + + def get_form_by_token(self, token: str) -> HumanInputForm: + form = self.repository.get_by_token(token) + if form is None: + raise FormNotFoundError() + return form + + def get_form_definition(self, form_id: str, *, is_token: bool) -> dict: + form = self.get_form_by_token(form_id) if is_token else self.get_form_by_id(form_id) + if form.is_expired: + raise FormExpiredError() + if form.is_submitted: + raise FormAlreadySubmittedError() + + definition = { + "form_content": form.form_content, + "inputs": form.inputs, + "user_actions": form.user_actions, + } + if is_token: + definition["site"] = {"title": "Workflow Form"} + return definition + + def submit_form(self, form_id: str, submission_data: FormSubmissionData, *, is_token: bool) -> None: + form = self.get_form_by_token(form_id) if is_token else self.get_form_by_id(form_id) + if form.is_expired: + raise FormExpiredError() + if form.is_submitted: + raise FormAlreadySubmittedError() + + self._validate_submission(form=form, submission_data=submission_data) + form.mark_submitted(inputs=submission_data.inputs, action=submission_data.action) + self.repository.save(form) + + def cleanup_expired_forms(self) -> int: + expired_ids = [form_id for form_id, form in list(self.repository.forms.items()) if form.is_expired] + for form_id in expired_ids: + self.repository.delete(form_id) + return len(expired_ids) + + def _validate_submission(self, form: HumanInputForm, submission_data: FormSubmissionData) -> None: + defined_actions = {action["id"] for action in form.user_actions} + if submission_data.action not in defined_actions: + raise InvalidFormDataError(f"Invalid action: {submission_data.action}") + + missing_inputs = [] + for form_input in form.inputs: + if form_input.output_variable_name not in submission_data.inputs: + missing_inputs.append(form_input.output_variable_name) + + if missing_inputs: + raise 
InvalidFormDataError(f"Missing required inputs: {', '.join(missing_inputs)}") + + # Extra inputs are allowed; no further validation required. diff --git a/api/tests/unit_tests/libs/_human_input/test_form_service.py b/api/tests/unit_tests/libs/_human_input/test_form_service.py new file mode 100644 index 0000000000..15e7d41e85 --- /dev/null +++ b/api/tests/unit_tests/libs/_human_input/test_form_service.py @@ -0,0 +1,326 @@ +""" +Unit tests for FormService. +""" + +from datetime import datetime, timedelta + +import pytest + +from core.workflow.nodes.human_input.entities import ( + FormInput, + UserAction, +) +from core.workflow.nodes.human_input.enums import ( + FormInputType, + TimeoutUnit, +) +from libs.datetime_utils import naive_utc_now + +from .support import ( + FormAlreadySubmittedError, + FormExpiredError, + FormNotFoundError, + FormService, + FormSubmissionData, + InMemoryFormRepository, + InvalidFormDataError, +) + + +class TestFormService: + """Test FormService functionality.""" + + @pytest.fixture + def repository(self): + """Create in-memory repository for testing.""" + return InMemoryFormRepository() + + @pytest.fixture + def form_service(self, repository): + """Create FormService with in-memory repository.""" + return FormService(repository) + + @pytest.fixture + def sample_form_data(self): + """Create sample form data.""" + return { + "form_id": "form-123", + "workflow_run_id": "run-456", + "node_id": "node-789", + "tenant_id": "tenant-abc", + "app_id": "app-def", + "form_content": "# Test Form\n\nInput: {{#$output.input#}}", + "inputs": [FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="input", default=None)], + "user_actions": [UserAction(id="submit", title="Submit")], + "timeout": 1, + "timeout_unit": TimeoutUnit.HOUR, + "form_token": "token-xyz", + } + + def test_create_form(self, form_service, sample_form_data): + """Test form creation.""" + form = form_service.create_form(**sample_form_data) + + assert form.form_id == "form-123" + 
assert form.workflow_run_id == "run-456" + assert form.node_id == "node-789" + assert form.tenant_id == "tenant-abc" + assert form.app_id == "app-def" + assert form.form_token == "token-xyz" + assert form.timeout == 1 + assert form.timeout_unit == TimeoutUnit.HOUR + assert form.expires_at is not None + assert not form.is_expired + assert not form.is_submitted + + def test_get_form_by_id(self, form_service, sample_form_data): + """Test getting form by ID.""" + # Create form first + created_form = form_service.create_form(**sample_form_data) + + # Retrieve form + retrieved_form = form_service.get_form_by_id("form-123") + + assert retrieved_form.form_id == created_form.form_id + assert retrieved_form.workflow_run_id == created_form.workflow_run_id + + def test_get_form_by_id_not_found(self, form_service): + """Test getting non-existent form by ID.""" + with pytest.raises(FormNotFoundError) as exc_info: + form_service.get_form_by_id("non-existent-form") + + assert exc_info.value.error_code == "form_not_found" + + def test_get_form_by_token(self, form_service, sample_form_data): + """Test getting form by token.""" + # Create form first + created_form = form_service.create_form(**sample_form_data) + + # Retrieve form by token + retrieved_form = form_service.get_form_by_token("token-xyz") + + assert retrieved_form.form_id == created_form.form_id + assert retrieved_form.form_token == "token-xyz" + + def test_get_form_by_token_not_found(self, form_service): + """Test getting non-existent form by token.""" + with pytest.raises(FormNotFoundError) as exc_info: + form_service.get_form_by_token("non-existent-token") + + assert exc_info.value.error_code == "form_not_found" + + def test_get_form_definition_by_id(self, form_service, sample_form_data): + """Test getting form definition by ID.""" + # Create form first + form_service.create_form(**sample_form_data) + + # Get form definition + definition = form_service.get_form_definition("form-123", is_token=False) + + assert 
"form_content" in definition + assert "inputs" in definition + assert definition["form_content"] == "# Test Form\n\nInput: {{#$output.input#}}" + assert len(definition["inputs"]) == 1 + assert "site" not in definition # Should not include site info for ID-based access + + def test_get_form_definition_by_token(self, form_service, sample_form_data): + """Test getting form definition by token.""" + # Create form first + form_service.create_form(**sample_form_data) + + # Get form definition + definition = form_service.get_form_definition("token-xyz", is_token=True) + + assert "form_content" in definition + assert "inputs" in definition + assert "site" in definition # Should include site info for token-based access + + def test_get_form_definition_expired_form(self, form_service, sample_form_data): + """Test getting definition for expired form.""" + # Create form with past expiry + form_service.create_form(**sample_form_data) + + # Manually expire the form by modifying expiry time + form = form_service.get_form_by_id("form-123") + form.expires_at = datetime.utcnow() - timedelta(hours=1) + form_service.repository.save(form) + + # Should raise FormExpiredError + with pytest.raises(FormExpiredError) as exc_info: + form_service.get_form_definition("form-123", is_token=False) + + assert exc_info.value.error_code == "human_input_form_expired" + + def test_get_form_definition_submitted_form(self, form_service, sample_form_data): + """Test getting definition for already submitted form.""" + # Create form first + form_service.create_form(**sample_form_data) + + # Submit the form + submission_data = FormSubmissionData(form_id="form-123", inputs={"input": "test value"}, action="submit") + form_service.submit_form("form-123", submission_data, is_token=False) + + # Should raise FormAlreadySubmittedError + with pytest.raises(FormAlreadySubmittedError) as exc_info: + form_service.get_form_definition("form-123", is_token=False) + + assert exc_info.value.error_code == 
"human_input_form_submitted" + + def test_submit_form_success(self, form_service, sample_form_data): + """Test successful form submission.""" + # Create form first + form_service.create_form(**sample_form_data) + + # Submit form + submission_data = FormSubmissionData(form_id="form-123", inputs={"input": "test value"}, action="submit") + + # Should not raise any exception + form_service.submit_form("form-123", submission_data, is_token=False) + + # Verify form is marked as submitted + form = form_service.get_form_by_id("form-123") + assert form.is_submitted + assert form.submitted_data == {"input": "test value"} + assert form.submitted_action == "submit" + assert form.submitted_at is not None + + def test_submit_form_missing_inputs(self, form_service, sample_form_data): + """Test form submission with missing inputs.""" + # Create form first + form_service.create_form(**sample_form_data) + + # Submit form with missing required input + submission_data = FormSubmissionData( + form_id="form-123", + inputs={}, # Missing required "input" field + action="submit", + ) + + with pytest.raises(InvalidFormDataError) as exc_info: + form_service.submit_form("form-123", submission_data, is_token=False) + + assert "Missing required inputs" in exc_info.value.message + assert "input" in exc_info.value.message + + def test_submit_form_invalid_action(self, form_service, sample_form_data): + """Test form submission with invalid action.""" + # Create form first + form_service.create_form(**sample_form_data) + + # Submit form with invalid action + submission_data = FormSubmissionData( + form_id="form-123", + inputs={"input": "test value"}, + action="invalid_action", # Not in the allowed actions + ) + + with pytest.raises(InvalidFormDataError) as exc_info: + form_service.submit_form("form-123", submission_data, is_token=False) + + assert "Invalid action" in exc_info.value.message + assert "invalid_action" in exc_info.value.message + + def test_submit_form_expired(self, form_service, 
sample_form_data): + """Test submitting expired form.""" + # Create form first + form_service.create_form(**sample_form_data) + + # Manually expire the form + form = form_service.get_form_by_id("form-123") + form.expires_at = datetime.utcnow() - timedelta(hours=1) + form_service.repository.save(form) + + # Try to submit expired form + submission_data = FormSubmissionData(form_id="form-123", inputs={"input": "test value"}, action="submit") + + with pytest.raises(FormExpiredError) as exc_info: + form_service.submit_form("form-123", submission_data, is_token=False) + + assert exc_info.value.error_code == "human_input_form_expired" + + def test_submit_form_already_submitted(self, form_service, sample_form_data): + """Test submitting form that's already submitted.""" + # Create and submit form first + form_service.create_form(**sample_form_data) + + submission_data = FormSubmissionData(form_id="form-123", inputs={"input": "first submission"}, action="submit") + form_service.submit_form("form-123", submission_data, is_token=False) + + # Try to submit again + second_submission = FormSubmissionData( + form_id="form-123", inputs={"input": "second submission"}, action="submit" + ) + + with pytest.raises(FormAlreadySubmittedError) as exc_info: + form_service.submit_form("form-123", second_submission, is_token=False) + + assert exc_info.value.error_code == "human_input_form_submitted" + + def test_cleanup_expired_forms(self, form_service, sample_form_data): + """Test cleanup of expired forms.""" + # Create multiple forms + for i in range(3): + data = sample_form_data.copy() + data["form_id"] = f"form-{i}" + data["form_token"] = f"token-{i}" + form_service.create_form(**data) + + # Manually expire some forms + for i in range(2): # Expire first 2 forms + form = form_service.get_form_by_id(f"form-{i}") + form.expires_at = naive_utc_now() - timedelta(hours=1) + form_service.repository.save(form) + + # Clean up expired forms + cleaned_count = form_service.cleanup_expired_forms() + 
+ assert cleaned_count == 2 + + # Verify expired forms are gone + with pytest.raises(FormNotFoundError): + form_service.get_form_by_id("form-0") + + with pytest.raises(FormNotFoundError): + form_service.get_form_by_id("form-1") + + # Verify non-expired form still exists + form = form_service.get_form_by_id("form-2") + assert form.form_id == "form-2" + + +class TestFormValidation: + """Test form validation logic.""" + + def test_validate_submission_with_extra_inputs(self): + """Test validation allows extra inputs that aren't defined in form.""" + repository = InMemoryFormRepository() + form_service = FormService(repository) + + # Create form with one input + form_data = { + "form_id": "form-123", + "workflow_run_id": "run-456", + "node_id": "node-789", + "tenant_id": "tenant-abc", + "app_id": "app-def", + "form_content": "Test form", + "inputs": [FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="required_input", default=None)], + "user_actions": [UserAction(id="submit", title="Submit")], + "timeout": 1, + "timeout_unit": TimeoutUnit.HOUR, + } + + form_service.create_form(**form_data) + + # Submit with extra input (should be allowed) + submission_data = FormSubmissionData( + form_id="form-123", + inputs={ + "required_input": "value1", + "extra_input": "value2", # Extra input not defined in form + }, + action="submit", + ) + + # Should not raise any exception + form_service.submit_form("form-123", submission_data, is_token=False) diff --git a/api/tests/unit_tests/libs/_human_input/test_models.py b/api/tests/unit_tests/libs/_human_input/test_models.py new file mode 100644 index 0000000000..962eeb9e11 --- /dev/null +++ b/api/tests/unit_tests/libs/_human_input/test_models.py @@ -0,0 +1,232 @@ +""" +Unit tests for human input form models. 
+""" + +from datetime import datetime, timedelta + +import pytest + +from core.workflow.nodes.human_input.entities import ( + FormInput, + UserAction, +) +from core.workflow.nodes.human_input.enums import ( + FormInputType, + TimeoutUnit, +) + +from .support import FormSubmissionData, FormSubmissionRequest, HumanInputForm + + +class TestHumanInputForm: + """Test HumanInputForm model.""" + + @pytest.fixture + def sample_form_data(self): + """Create sample form data.""" + return { + "form_id": "form-123", + "workflow_run_id": "run-456", + "node_id": "node-789", + "tenant_id": "tenant-abc", + "app_id": "app-def", + "form_content": "# Test Form\n\nInput: {{#$output.input#}}", + "inputs": [FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="input", default=None)], + "user_actions": [UserAction(id="submit", title="Submit")], + "timeout": 2, + "timeout_unit": TimeoutUnit.HOUR, + "form_token": "token-xyz", + } + + def test_form_creation(self, sample_form_data): + """Test form creation.""" + form = HumanInputForm(**sample_form_data) + + assert form.form_id == "form-123" + assert form.workflow_run_id == "run-456" + assert form.node_id == "node-789" + assert form.tenant_id == "tenant-abc" + assert form.app_id == "app-def" + assert form.form_token == "token-xyz" + assert form.timeout == 2 + assert form.timeout_unit == TimeoutUnit.HOUR + assert form.created_at is not None + assert form.expires_at is not None + assert form.submitted_at is None + assert form.submitted_data is None + assert form.submitted_action is None + + def test_form_expiry_calculation_hours(self, sample_form_data): + """Test form expiry calculation for hours.""" + form = HumanInputForm(**sample_form_data) + + # Should expire 2 hours after creation + expected_expiry = form.created_at + timedelta(hours=2) + assert abs((form.expires_at - expected_expiry).total_seconds()) < 1 # Within 1 second + + def test_form_expiry_calculation_days(self, sample_form_data): + """Test form expiry calculation for 
days.""" + sample_form_data["timeout"] = 3 + sample_form_data["timeout_unit"] = TimeoutUnit.DAY + + form = HumanInputForm(**sample_form_data) + + # Should expire 3 days after creation + expected_expiry = form.created_at + timedelta(days=3) + assert abs((form.expires_at - expected_expiry).total_seconds()) < 1 # Within 1 second + + def test_form_expiry_property_not_expired(self, sample_form_data): + """Test is_expired property for non-expired form.""" + form = HumanInputForm(**sample_form_data) + assert not form.is_expired + + def test_form_expiry_property_expired(self, sample_form_data): + """Test is_expired property for expired form.""" + # Create form with past expiry + past_time = datetime.utcnow() - timedelta(hours=1) + sample_form_data["created_at"] = past_time + + form = HumanInputForm(**sample_form_data) + # Manually set expiry to past time + form.expires_at = past_time + + assert form.is_expired + + def test_form_submission_property_not_submitted(self, sample_form_data): + """Test is_submitted property for non-submitted form.""" + form = HumanInputForm(**sample_form_data) + assert not form.is_submitted + + def test_form_submission_property_submitted(self, sample_form_data): + """Test is_submitted property for submitted form.""" + form = HumanInputForm(**sample_form_data) + form.submit({"input": "test value"}, "submit") + + assert form.is_submitted + assert form.submitted_at is not None + assert form.submitted_data == {"input": "test value"} + assert form.submitted_action == "submit" + + def test_form_submit_method(self, sample_form_data): + """Test form submit method.""" + form = HumanInputForm(**sample_form_data) + + submission_time_before = datetime.utcnow() + form.submit({"input": "test value"}, "submit") + submission_time_after = datetime.utcnow() + + assert form.is_submitted + assert form.submitted_data == {"input": "test value"} + assert form.submitted_action == "submit" + assert submission_time_before <= form.submitted_at <= submission_time_after + + 
def test_form_to_response_dict_without_site_info(self, sample_form_data): + """Test converting form to response dict without site info.""" + form = HumanInputForm(**sample_form_data) + + response = form.to_response_dict(include_site_info=False) + + assert "form_content" in response + assert "inputs" in response + assert "site" not in response + assert response["form_content"] == "# Test Form\n\nInput: {{#$output.input#}}" + assert len(response["inputs"]) == 1 + assert response["inputs"][0]["type"] == "text-input" + assert response["inputs"][0]["output_variable_name"] == "input" + + def test_form_to_response_dict_with_site_info(self, sample_form_data): + """Test converting form to response dict with site info.""" + form = HumanInputForm(**sample_form_data) + + response = form.to_response_dict(include_site_info=True) + + assert "form_content" in response + assert "inputs" in response + assert "site" in response + assert response["site"]["app_id"] == "app-def" + assert response["site"]["title"] == "Workflow Form" + + def test_form_without_web_app_token(self, sample_form_data): + """Test form creation without web app token.""" + sample_form_data["form_token"] = None + + form = HumanInputForm(**sample_form_data) + + assert form.form_token is None + assert form.form_id == "form-123" # Other fields should still work + + def test_form_with_explicit_timestamps(self): + """Test form creation with explicit timestamps.""" + created_time = datetime(2024, 1, 15, 10, 30, 0) + expires_time = datetime(2024, 1, 15, 12, 30, 0) + + form = HumanInputForm( + form_id="form-123", + workflow_run_id="run-456", + node_id="node-789", + tenant_id="tenant-abc", + app_id="app-def", + form_content="Test content", + inputs=[], + user_actions=[], + timeout=2, + timeout_unit=TimeoutUnit.HOUR, + created_at=created_time, + expires_at=expires_time, + ) + + assert form.created_at == created_time + assert form.expires_at == expires_time + + +class TestFormSubmissionData: + """Test FormSubmissionData 
model.""" + + def test_submission_data_creation(self): + """Test submission data creation.""" + submission_data = FormSubmissionData( + form_id="form-123", inputs={"field1": "value1", "field2": "value2"}, action="submit" + ) + + assert submission_data.form_id == "form-123" + assert submission_data.inputs == {"field1": "value1", "field2": "value2"} + assert submission_data.action == "submit" + assert submission_data.submitted_at is not None + + def test_submission_data_from_request(self): + """Test creating submission data from API request.""" + request = FormSubmissionRequest(inputs={"input": "test value"}, action="confirm") + + submission_data = FormSubmissionData.from_request("form-456", request) + + assert submission_data.form_id == "form-456" + assert submission_data.inputs == {"input": "test value"} + assert submission_data.action == "confirm" + assert submission_data.submitted_at is not None + + def test_submission_data_with_empty_inputs(self): + """Test submission data with empty inputs.""" + submission_data = FormSubmissionData(form_id="form-123", inputs={}, action="cancel") + + assert submission_data.inputs == {} + assert submission_data.action == "cancel" + + def test_submission_data_timestamps(self): + """Test submission data timestamp handling.""" + before_time = datetime.utcnow() + + submission_data = FormSubmissionData(form_id="form-123", inputs={"test": "value"}, action="submit") + + after_time = datetime.utcnow() + + assert before_time <= submission_data.submitted_at <= after_time + + def test_submission_data_with_explicit_timestamp(self): + """Test submission data with explicit timestamp.""" + specific_time = datetime(2024, 1, 15, 14, 30, 0) + + submission_data = FormSubmissionData( + form_id="form-123", inputs={"test": "value"}, action="submit", submitted_at=specific_time + ) + + assert submission_data.submitted_at == specific_time diff --git a/api/tests/unit_tests/libs/test_helper.py b/api/tests/unit_tests/libs/test_helper.py index 
de74eff82f..1a93dbbca1 100644 --- a/api/tests/unit_tests/libs/test_helper.py +++ b/api/tests/unit_tests/libs/test_helper.py @@ -1,6 +1,8 @@ +from datetime import datetime + import pytest -from libs.helper import escape_like_pattern, extract_tenant_id +from libs.helper import OptionalTimestampField, escape_like_pattern, extract_tenant_id from models.account import Account from models.model import EndUser @@ -65,6 +67,19 @@ class TestExtractTenantId: extract_tenant_id(dict_user) +class TestOptionalTimestampField: + def test_format_returns_none_for_none(self): + field = OptionalTimestampField() + + assert field.format(None) is None + + def test_format_returns_unix_timestamp_for_datetime(self): + field = OptionalTimestampField() + value = datetime(2024, 1, 2, 3, 4, 5) + + assert field.format(value) == int(value.timestamp()) + + class TestEscapeLikePattern: """Test cases for the escape_like_pattern utility function.""" diff --git a/api/tests/unit_tests/libs/test_rate_limiter.py b/api/tests/unit_tests/libs/test_rate_limiter.py new file mode 100644 index 0000000000..9d44b07b5e --- /dev/null +++ b/api/tests/unit_tests/libs/test_rate_limiter.py @@ -0,0 +1,68 @@ +from unittest.mock import MagicMock + +from libs import helper as helper_module + + +class _FakeRedis: + def __init__(self) -> None: + self._zsets: dict[str, dict[str, float]] = {} + self._expiry: dict[str, int] = {} + + def zadd(self, key: str, mapping: dict[str, float]) -> int: + zset = self._zsets.setdefault(key, {}) + for member, score in mapping.items(): + zset[str(member)] = float(score) + return len(mapping) + + def zremrangebyscore(self, key: str, min_score: str | float, max_score: str | float) -> int: + zset = self._zsets.get(key, {}) + min_value = float("-inf") if min_score == "-inf" else float(min_score) + max_value = float("inf") if max_score == "+inf" else float(max_score) + to_delete = [member for member, score in zset.items() if min_value <= score <= max_value] + for member in to_delete: + del 
zset[member] + return len(to_delete) + + def zcard(self, key: str) -> int: + return len(self._zsets.get(key, {})) + + def expire(self, key: str, ttl: int) -> bool: + self._expiry[key] = ttl + return True + + +def test_rate_limiter_counts_attempts_within_same_second(monkeypatch): + fake_redis = _FakeRedis() + monkeypatch.setattr(helper_module.time, "time", lambda: 1000) + + limiter = helper_module.RateLimiter( + prefix="test_rate_limit", + max_attempts=2, + time_window=60, + redis_client=fake_redis, + ) + + limiter.increment_rate_limit("203.0.113.10") + limiter.increment_rate_limit("203.0.113.10") + + assert limiter.is_rate_limited("203.0.113.10") is True + + +def test_rate_limiter_uses_injected_redis(monkeypatch): + redis_client = MagicMock() + redis_client.zcard.return_value = 1 + monkeypatch.setattr(helper_module.time, "time", lambda: 1000) + + limiter = helper_module.RateLimiter( + prefix="test_rate_limit", + max_attempts=1, + time_window=60, + redis_client=redis_client, + ) + + limiter.increment_rate_limit("203.0.113.10") + limiter.is_rate_limited("203.0.113.10") + + assert redis_client.zadd.called is True + assert redis_client.zremrangebyscore.called is True + assert redis_client.zcard.called is True diff --git a/api/tests/unit_tests/models/test_app_models.py b/api/tests/unit_tests/models/test_app_models.py index 8be2eea121..c6dfd41803 100644 --- a/api/tests/unit_tests/models/test_app_models.py +++ b/api/tests/unit_tests/models/test_app_models.py @@ -1296,6 +1296,7 @@ class TestConversationStatusCount: assert result["success"] == 1 # One SUCCEEDED assert result["failed"] == 1 # One FAILED assert result["partial_success"] == 1 # One PARTIAL_SUCCEEDED + assert result["paused"] == 0 def test_status_count_app_id_filtering(self): """Test that status_count filters workflow runs by app_id for security.""" @@ -1350,6 +1351,7 @@ class TestConversationStatusCount: assert result["success"] == 0 assert result["failed"] == 0 assert result["partial_success"] == 0 + assert 
result["paused"] == 0 def test_status_count_handles_invalid_workflow_status(self): """Test that status_count gracefully handles invalid workflow status values.""" @@ -1404,3 +1406,57 @@ class TestConversationStatusCount: assert result["success"] == 0 assert result["failed"] == 0 assert result["partial_success"] == 0 + assert result["paused"] == 0 + + def test_status_count_paused(self): + """Test status_count includes paused workflow runs.""" + # Arrange + from core.workflow.enums import WorkflowExecutionStatus + + app_id = str(uuid4()) + conversation_id = str(uuid4()) + workflow_run_id = str(uuid4()) + + conversation = Conversation( + app_id=app_id, + mode=AppMode.CHAT, + name="Test Conversation", + status="normal", + from_source="api", + ) + conversation.id = conversation_id + + mock_messages = [ + MagicMock( + conversation_id=conversation_id, + workflow_run_id=workflow_run_id, + ), + ] + + mock_workflow_runs = [ + MagicMock( + id=workflow_run_id, + status=WorkflowExecutionStatus.PAUSED.value, + app_id=app_id, + ), + ] + + with patch("models.model.db.session.scalars") as mock_scalars: + + def mock_scalars_side_effect(query): + mock_result = MagicMock() + if "messages" in str(query): + mock_result.all.return_value = mock_messages + elif "workflow_runs" in str(query): + mock_result.all.return_value = mock_workflow_runs + else: + mock_result.all.return_value = [] + return mock_result + + mock_scalars.side_effect = mock_scalars_side_effect + + # Act + result = conversation.status_count + + # Assert + assert result["paused"] == 1 diff --git a/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py b/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py new file mode 100644 index 0000000000..ceb1406a4b --- /dev/null +++ b/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py @@ -0,0 +1,40 @@ +"""Unit tests for DifyAPISQLAlchemyWorkflowNodeExecutionRepository 
implementation.""" + +from unittest.mock import Mock + +from sqlalchemy.orm import Session, sessionmaker + +from repositories.sqlalchemy_api_workflow_node_execution_repository import ( + DifyAPISQLAlchemyWorkflowNodeExecutionRepository, +) + + +class TestDifyAPISQLAlchemyWorkflowNodeExecutionRepository: + def test_get_executions_by_workflow_run_keeps_paused_records(self): + mock_session = Mock(spec=Session) + execute_result = Mock() + execute_result.scalars.return_value.all.return_value = [] + mock_session.execute.return_value = execute_result + + session_maker = Mock(spec=sessionmaker) + context_manager = Mock() + context_manager.__enter__ = Mock(return_value=mock_session) + context_manager.__exit__ = Mock(return_value=None) + session_maker.return_value = context_manager + + repository = DifyAPISQLAlchemyWorkflowNodeExecutionRepository(session_maker) + + repository.get_executions_by_workflow_run( + tenant_id="tenant-123", + app_id="app-123", + workflow_run_id="workflow-run-123", + ) + + stmt = mock_session.execute.call_args[0][0] + where_clauses = list(getattr(stmt, "_where_criteria", []) or []) + where_strs = [str(clause).lower() for clause in where_clauses] + + assert any("tenant_id" in clause for clause in where_strs) + assert any("app_id" in clause for clause in where_strs) + assert any("workflow_run_id" in clause for clause in where_strs) + assert not any("paused" in clause for clause in where_strs) diff --git a/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py b/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py index d443c4c9a5..4caaa056ff 100644 --- a/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py +++ b/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py @@ -1,5 +1,6 @@ """Unit tests for DifyAPISQLAlchemyWorkflowRunRepository implementation.""" +import secrets from datetime import UTC, datetime from unittest.mock import Mock, patch @@ 
-7,12 +8,17 @@ import pytest from sqlalchemy.dialects import postgresql from sqlalchemy.orm import Session, sessionmaker +from core.workflow.entities.pause_reason import HumanInputRequired, PauseReasonType from core.workflow.enums import WorkflowExecutionStatus +from core.workflow.nodes.human_input.entities import FormDefinition, FormInput, UserAction +from core.workflow.nodes.human_input.enums import FormInputType, HumanInputFormStatus +from models.human_input import BackstageRecipientPayload, HumanInputForm, HumanInputFormRecipient, RecipientType from models.workflow import WorkflowPause as WorkflowPauseModel -from models.workflow import WorkflowRun +from models.workflow import WorkflowPauseReason, WorkflowRun from repositories.entities.workflow_pause import WorkflowPauseEntity from repositories.sqlalchemy_api_workflow_run_repository import ( DifyAPISQLAlchemyWorkflowRunRepository, + _build_human_input_required_reason, _PrivateWorkflowPauseEntity, _WorkflowRunError, ) @@ -205,11 +211,11 @@ class TestCreateWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository): ): """Test workflow pause creation when workflow not in RUNNING status.""" # Arrange - sample_workflow_run.status = WorkflowExecutionStatus.PAUSED + sample_workflow_run.status = WorkflowExecutionStatus.SUCCEEDED mock_session.get.return_value = sample_workflow_run # Act & Assert - with pytest.raises(_WorkflowRunError, match="Only WorkflowRun with RUNNING status can be paused"): + with pytest.raises(_WorkflowRunError, match="Only WorkflowRun with RUNNING or PAUSED status can be paused"): repository.create_workflow_pause( workflow_run_id="workflow-run-123", state_owner_user_id="user-123", @@ -295,6 +301,7 @@ class TestResumeWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository): sample_workflow_pause.resumed_at = None mock_session.scalar.return_value = sample_workflow_run + mock_session.scalars.return_value.all.return_value = [] with 
patch("repositories.sqlalchemy_api_workflow_run_repository.naive_utc_now") as mock_now: mock_now.return_value = datetime.now(UTC) @@ -455,3 +462,53 @@ class TestPrivateWorkflowPauseEntity(TestDifyAPISQLAlchemyWorkflowRunRepository) assert result1 == expected_state assert result2 == expected_state mock_storage.load.assert_called_once() # Only called once due to caching + + +class TestBuildHumanInputRequiredReason: + def test_prefers_backstage_token_when_available(self): + expiration_time = datetime.now(UTC) + form_definition = FormDefinition( + form_content="content", + inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")], + user_actions=[UserAction(id="approve", title="Approve")], + rendered_content="rendered", + expiration_time=expiration_time, + default_values={"name": "Alice"}, + node_title="Ask Name", + display_in_ui=True, + ) + form_model = HumanInputForm( + id="form-1", + tenant_id="tenant-1", + app_id="app-1", + workflow_run_id="run-1", + node_id="node-1", + form_definition=form_definition.model_dump_json(), + rendered_content="rendered", + status=HumanInputFormStatus.WAITING, + expiration_time=expiration_time, + ) + reason_model = WorkflowPauseReason( + pause_id="pause-1", + type_=PauseReasonType.HUMAN_INPUT_REQUIRED, + form_id="form-1", + node_id="node-1", + message="", + ) + access_token = secrets.token_urlsafe(8) + backstage_recipient = HumanInputFormRecipient( + form_id="form-1", + delivery_id="delivery-1", + recipient_type=RecipientType.BACKSTAGE, + recipient_payload=BackstageRecipientPayload().model_dump_json(), + access_token=access_token, + ) + + reason = _build_human_input_required_reason(reason_model, form_model, [backstage_recipient]) + + assert isinstance(reason, HumanInputRequired) + assert reason.form_token == access_token + assert reason.node_title == "Ask Name" + assert reason.form_content == "content" + assert reason.inputs[0].output_variable_name == "name" + assert reason.actions[0].id == "approve" diff --git 
a/api/tests/unit_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py b/api/tests/unit_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py new file mode 100644 index 0000000000..f5428b46ff --- /dev/null +++ b/api/tests/unit_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py @@ -0,0 +1,180 @@ +from __future__ import annotations + +from collections.abc import Sequence +from dataclasses import dataclass +from datetime import UTC, datetime, timedelta + +from core.entities.execution_extra_content import HumanInputContent as HumanInputContentDomain +from core.entities.execution_extra_content import HumanInputFormSubmissionData +from core.workflow.nodes.human_input.entities import ( + FormDefinition, + UserAction, +) +from core.workflow.nodes.human_input.enums import HumanInputFormStatus +from models.execution_extra_content import HumanInputContent as HumanInputContentModel +from models.human_input import ConsoleRecipientPayload, HumanInputForm, HumanInputFormRecipient, RecipientType +from repositories.sqlalchemy_execution_extra_content_repository import SQLAlchemyExecutionExtraContentRepository + + +class _FakeScalarResult: + def __init__(self, values: Sequence[HumanInputContentModel]): + self._values = list(values) + + def all(self) -> list[HumanInputContentModel]: + return list(self._values) + + +class _FakeSession: + def __init__(self, values: Sequence[Sequence[object]]): + self._values = list(values) + + def scalars(self, _stmt): + if not self._values: + return _FakeScalarResult([]) + return _FakeScalarResult(self._values.pop(0)) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + +@dataclass +class _FakeSessionMaker: + session: _FakeSession + + def __call__(self) -> _FakeSession: + return self.session + + +def _build_form(action_id: str, action_title: str, rendered_content: str) -> HumanInputForm: + expiration_time = datetime.now(UTC) + timedelta(days=1) + 
definition = FormDefinition( + form_content="content", + inputs=[], + user_actions=[UserAction(id=action_id, title=action_title)], + rendered_content="rendered", + expiration_time=expiration_time, + node_title="Approval", + display_in_ui=True, + ) + form = HumanInputForm( + id=f"form-{action_id}", + tenant_id="tenant-id", + app_id="app-id", + workflow_run_id="workflow-run", + node_id="node-id", + form_definition=definition.model_dump_json(), + rendered_content=rendered_content, + status=HumanInputFormStatus.SUBMITTED, + expiration_time=expiration_time, + ) + form.selected_action_id = action_id + return form + + +def _build_content(message_id: str, action_id: str, action_title: str) -> HumanInputContentModel: + form = _build_form( + action_id=action_id, + action_title=action_title, + rendered_content=f"Rendered {action_title}", + ) + content = HumanInputContentModel( + id=f"content-{message_id}", + form_id=form.id, + message_id=message_id, + workflow_run_id=form.workflow_run_id, + ) + content.form = form + return content + + +def test_get_by_message_ids_groups_contents_by_message() -> None: + message_ids = ["msg-1", "msg-2"] + contents = [_build_content("msg-1", "approve", "Approve")] + repository = SQLAlchemyExecutionExtraContentRepository( + session_maker=_FakeSessionMaker(session=_FakeSession(values=[contents, []])) + ) + + result = repository.get_by_message_ids(message_ids) + + assert len(result) == 2 + assert [content.model_dump(mode="json", exclude_none=True) for content in result[0]] == [ + HumanInputContentDomain( + workflow_run_id="workflow-run", + submitted=True, + form_submission_data=HumanInputFormSubmissionData( + node_id="node-id", + node_title="Approval", + rendered_content="Rendered Approve", + action_id="approve", + action_text="Approve", + ), + ).model_dump(mode="json", exclude_none=True) + ] + assert result[1] == [] + + +def test_get_by_message_ids_returns_unsubmitted_form_definition() -> None: + expiration_time = datetime.now(UTC) + 
timedelta(days=1) + definition = FormDefinition( + form_content="content", + inputs=[], + user_actions=[UserAction(id="approve", title="Approve")], + rendered_content="rendered", + expiration_time=expiration_time, + default_values={"name": "John"}, + node_title="Approval", + display_in_ui=True, + ) + form = HumanInputForm( + id="form-1", + tenant_id="tenant-id", + app_id="app-id", + workflow_run_id="workflow-run", + node_id="node-id", + form_definition=definition.model_dump_json(), + rendered_content="Rendered block", + status=HumanInputFormStatus.WAITING, + expiration_time=expiration_time, + ) + content = HumanInputContentModel( + id="content-msg-1", + form_id=form.id, + message_id="msg-1", + workflow_run_id=form.workflow_run_id, + ) + content.form = form + + recipient = HumanInputFormRecipient( + form_id=form.id, + delivery_id="delivery-1", + recipient_type=RecipientType.CONSOLE, + recipient_payload=ConsoleRecipientPayload(account_id=None).model_dump_json(), + access_token="token-1", + ) + + repository = SQLAlchemyExecutionExtraContentRepository( + session_maker=_FakeSessionMaker(session=_FakeSession(values=[[content], [recipient]])) + ) + + result = repository.get_by_message_ids(["msg-1"]) + + assert len(result) == 1 + assert len(result[0]) == 1 + domain_content = result[0][0] + assert domain_content.submitted is False + assert domain_content.workflow_run_id == "workflow-run" + assert domain_content.form_definition is not None + assert domain_content.form_definition.expiration_time == int(form.expiration_time.timestamp()) + assert domain_content.form_definition is not None + form_definition = domain_content.form_definition + assert form_definition.form_id == "form-1" + assert form_definition.node_id == "node-id" + assert form_definition.node_title == "Approval" + assert form_definition.form_content == "Rendered block" + assert form_definition.display_in_ui is True + assert form_definition.form_token == "token-1" + assert form_definition.resolved_default_values 
== {"name": "John"} + assert form_definition.expiration_time == int(form.expiration_time.timestamp()) diff --git a/api/tests/unit_tests/services/test_app_generate_service.py b/api/tests/unit_tests/services/test_app_generate_service.py new file mode 100644 index 0000000000..71134464e6 --- /dev/null +++ b/api/tests/unit_tests/services/test_app_generate_service.py @@ -0,0 +1,65 @@ +from unittest.mock import MagicMock + +import services.app_generate_service as app_generate_service_module +from models.model import AppMode +from services.app_generate_service import AppGenerateService + + +class _DummyRateLimit: + def __init__(self, client_id: str, max_active_requests: int) -> None: + self.client_id = client_id + self.max_active_requests = max_active_requests + + @staticmethod + def gen_request_key() -> str: + return "dummy-request-id" + + def enter(self, request_id: str | None = None) -> str: + return request_id or "dummy-request-id" + + def exit(self, request_id: str) -> None: + return None + + def generate(self, generator, request_id: str): + return generator + + +def test_workflow_blocking_injects_pause_state_config(mocker, monkeypatch): + monkeypatch.setattr(app_generate_service_module.dify_config, "BILLING_ENABLED", False) + mocker.patch("services.app_generate_service.RateLimit", _DummyRateLimit) + + workflow = MagicMock() + workflow.id = "workflow-id" + workflow.created_by = "owner-id" + + mocker.patch.object(AppGenerateService, "_get_workflow", return_value=workflow) + + generator_spy = mocker.patch( + "services.app_generate_service.WorkflowAppGenerator.generate", + return_value={"result": "ok"}, + ) + + app_model = MagicMock() + app_model.mode = AppMode.WORKFLOW + app_model.id = "app-id" + app_model.tenant_id = "tenant-id" + app_model.max_active_requests = 0 + app_model.is_agent = False + + user = MagicMock() + user.id = "user-id" + + result = AppGenerateService.generate( + app_model=app_model, + user=user, + args={"inputs": {"k": "v"}}, + 
invoke_from=MagicMock(), + streaming=False, + ) + + assert result == {"result": "ok"} + + call_kwargs = generator_spy.call_args.kwargs + pause_state_config = call_kwargs.get("pause_state_config") + assert pause_state_config is not None + assert pause_state_config.state_owner_user_id == "owner-id" diff --git a/api/tests/unit_tests/services/test_conversation_service.py b/api/tests/unit_tests/services/test_conversation_service.py index 81135dbbdf..eca1d44d23 100644 --- a/api/tests/unit_tests/services/test_conversation_service.py +++ b/api/tests/unit_tests/services/test_conversation_service.py @@ -508,9 +508,12 @@ class TestConversationServiceMessageCreation: within conversations. """ + @patch("services.message_service._create_execution_extra_content_repository") @patch("services.message_service.db.session") @patch("services.message_service.ConversationService.get_conversation") - def test_pagination_by_first_id_without_first_id(self, mock_get_conversation, mock_db_session): + def test_pagination_by_first_id_without_first_id( + self, mock_get_conversation, mock_db_session, mock_create_extra_repo + ): """ Test message pagination without specifying first_id. 
@@ -540,6 +543,9 @@ class TestConversationServiceMessageCreation: mock_query.order_by.return_value = mock_query # ORDER BY returns self for chaining mock_query.limit.return_value = mock_query # LIMIT returns self for chaining mock_query.all.return_value = messages # Final .all() returns the messages + mock_repository = MagicMock() + mock_repository.get_by_message_ids.return_value = [[] for _ in messages] + mock_create_extra_repo.return_value = mock_repository # Act - Call the pagination method without first_id result = MessageService.pagination_by_first_id( @@ -556,9 +562,10 @@ class TestConversationServiceMessageCreation: # Verify conversation was looked up with correct parameters mock_get_conversation.assert_called_once_with(app_model=app_model, user=user, conversation_id=conversation.id) + @patch("services.message_service._create_execution_extra_content_repository") @patch("services.message_service.db.session") @patch("services.message_service.ConversationService.get_conversation") - def test_pagination_by_first_id_with_first_id(self, mock_get_conversation, mock_db_session): + def test_pagination_by_first_id_with_first_id(self, mock_get_conversation, mock_db_session, mock_create_extra_repo): """ Test message pagination with first_id specified. 
@@ -590,6 +597,9 @@ class TestConversationServiceMessageCreation: mock_query.limit.return_value = mock_query # LIMIT returns self for chaining mock_query.first.return_value = first_message # First message returned mock_query.all.return_value = messages # Remaining messages returned + mock_repository = MagicMock() + mock_repository.get_by_message_ids.return_value = [[] for _ in messages] + mock_create_extra_repo.return_value = mock_repository # Act - Call the pagination method with first_id result = MessageService.pagination_by_first_id( @@ -684,9 +694,10 @@ class TestConversationServiceMessageCreation: assert result.data == [] assert result.has_more is False + @patch("services.message_service._create_execution_extra_content_repository") @patch("services.message_service.db.session") @patch("services.message_service.ConversationService.get_conversation") - def test_pagination_with_has_more_flag(self, mock_get_conversation, mock_db_session): + def test_pagination_with_has_more_flag(self, mock_get_conversation, mock_db_session, mock_create_extra_repo): """ Test that has_more flag is correctly set when there are more messages. 
@@ -716,6 +727,9 @@ class TestConversationServiceMessageCreation: mock_query.order_by.return_value = mock_query # ORDER BY returns self for chaining mock_query.limit.return_value = mock_query # LIMIT returns self for chaining mock_query.all.return_value = messages # Final .all() returns the messages + mock_repository = MagicMock() + mock_repository.get_by_message_ids.return_value = [[] for _ in messages] + mock_create_extra_repo.return_value = mock_repository # Act result = MessageService.pagination_by_first_id( @@ -730,9 +744,10 @@ class TestConversationServiceMessageCreation: assert len(result.data) == limit # Extra message should be removed assert result.has_more is True # Flag should be set + @patch("services.message_service._create_execution_extra_content_repository") @patch("services.message_service.db.session") @patch("services.message_service.ConversationService.get_conversation") - def test_pagination_with_ascending_order(self, mock_get_conversation, mock_db_session): + def test_pagination_with_ascending_order(self, mock_get_conversation, mock_db_session, mock_create_extra_repo): """ Test message pagination with ascending order. 
@@ -761,6 +776,9 @@ class TestConversationServiceMessageCreation: mock_query.order_by.return_value = mock_query # ORDER BY returns self for chaining mock_query.limit.return_value = mock_query # LIMIT returns self for chaining mock_query.all.return_value = messages # Final .all() returns the messages + mock_repository = MagicMock() + mock_repository.get_by_message_ids.return_value = [[] for _ in messages] + mock_create_extra_repo.return_value = mock_repository # Act result = MessageService.pagination_by_first_id( diff --git a/api/tests/unit_tests/services/test_dataset_service_update_dataset.py b/api/tests/unit_tests/services/test_dataset_service_update_dataset.py index 0aabe2fc30..08818945e3 100644 --- a/api/tests/unit_tests/services/test_dataset_service_update_dataset.py +++ b/api/tests/unit_tests/services/test_dataset_service_update_dataset.py @@ -138,6 +138,7 @@ class TestDatasetServiceUpdateDataset: "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding" ) as mock_get_binding, patch("services.dataset_service.deal_dataset_vector_index_task") as mock_task, + patch("services.dataset_service.regenerate_summary_index_task") as mock_regenerate_task, patch( "services.dataset_service.current_user", create_autospec(Account, instance=True) ) as mock_current_user, @@ -147,6 +148,7 @@ class TestDatasetServiceUpdateDataset: "model_manager": mock_model_manager, "get_binding": mock_get_binding, "task": mock_task, + "regenerate_task": mock_regenerate_task, "current_user": mock_current_user, } @@ -549,6 +551,13 @@ class TestDatasetServiceUpdateDataset: # Verify vector index task was triggered mock_internal_provider_dependencies["task"].delay.assert_called_once_with("dataset-123", "update") + # Verify regenerate summary index task was triggered (when embedding_model changes) + mock_internal_provider_dependencies["regenerate_task"].delay.assert_called_once_with( + "dataset-123", + regenerate_reason="embedding_model_changed", + 
regenerate_vectors_only=True, + ) + # Verify return value assert result == dataset diff --git a/api/tests/unit_tests/services/test_feature_service_human_input_email_delivery.py b/api/tests/unit_tests/services/test_feature_service_human_input_email_delivery.py new file mode 100644 index 0000000000..ab141a7b2d --- /dev/null +++ b/api/tests/unit_tests/services/test_feature_service_human_input_email_delivery.py @@ -0,0 +1,104 @@ +from dataclasses import dataclass + +import pytest + +from enums.cloud_plan import CloudPlan +from services import feature_service as feature_service_module +from services.feature_service import FeatureModel, FeatureService + + +@dataclass(frozen=True) +class HumanInputEmailDeliveryCase: + name: str + enterprise_enabled: bool + billing_enabled: bool + tenant_id: str | None + billing_feature_enabled: bool + plan: str + expected: bool + + +CASES = [ + HumanInputEmailDeliveryCase( + name="enterprise_enabled", + enterprise_enabled=True, + billing_enabled=True, + tenant_id=None, + billing_feature_enabled=False, + plan=CloudPlan.SANDBOX, + expected=True, + ), + HumanInputEmailDeliveryCase( + name="billing_disabled", + enterprise_enabled=False, + billing_enabled=False, + tenant_id=None, + billing_feature_enabled=False, + plan=CloudPlan.SANDBOX, + expected=True, + ), + HumanInputEmailDeliveryCase( + name="billing_enabled_requires_tenant", + enterprise_enabled=False, + billing_enabled=True, + tenant_id=None, + billing_feature_enabled=True, + plan=CloudPlan.PROFESSIONAL, + expected=False, + ), + HumanInputEmailDeliveryCase( + name="billing_feature_off", + enterprise_enabled=False, + billing_enabled=True, + tenant_id="tenant-1", + billing_feature_enabled=False, + plan=CloudPlan.PROFESSIONAL, + expected=False, + ), + HumanInputEmailDeliveryCase( + name="professional_plan", + enterprise_enabled=False, + billing_enabled=True, + tenant_id="tenant-1", + billing_feature_enabled=True, + plan=CloudPlan.PROFESSIONAL, + expected=True, + ), + 
HumanInputEmailDeliveryCase( + name="team_plan", + enterprise_enabled=False, + billing_enabled=True, + tenant_id="tenant-1", + billing_feature_enabled=True, + plan=CloudPlan.TEAM, + expected=True, + ), + HumanInputEmailDeliveryCase( + name="sandbox_plan", + enterprise_enabled=False, + billing_enabled=True, + tenant_id="tenant-1", + billing_feature_enabled=True, + plan=CloudPlan.SANDBOX, + expected=False, + ), +] + + +@pytest.mark.parametrize("case", CASES, ids=lambda case: case.name) +def test_resolve_human_input_email_delivery_enabled_matrix( + monkeypatch: pytest.MonkeyPatch, + case: HumanInputEmailDeliveryCase, +): + monkeypatch.setattr(feature_service_module.dify_config, "ENTERPRISE_ENABLED", case.enterprise_enabled) + monkeypatch.setattr(feature_service_module.dify_config, "BILLING_ENABLED", case.billing_enabled) + features = FeatureModel() + features.billing.enabled = case.billing_feature_enabled + features.billing.subscription.plan = case.plan + + result = FeatureService._resolve_human_input_email_delivery_enabled( + features=features, + tenant_id=case.tenant_id, + ) + + assert result is case.expected diff --git a/api/tests/unit_tests/services/test_human_input_delivery_test_service.py b/api/tests/unit_tests/services/test_human_input_delivery_test_service.py new file mode 100644 index 0000000000..e0d6ad1b39 --- /dev/null +++ b/api/tests/unit_tests/services/test_human_input_delivery_test_service.py @@ -0,0 +1,97 @@ +from types import SimpleNamespace + +import pytest + +from core.workflow.nodes.human_input.entities import ( + EmailDeliveryConfig, + EmailDeliveryMethod, + EmailRecipients, + ExternalRecipient, +) +from core.workflow.runtime import VariablePool +from services import human_input_delivery_test_service as service_module +from services.human_input_delivery_test_service import ( + DeliveryTestContext, + DeliveryTestError, + EmailDeliveryTestHandler, +) + + +def _make_email_method() -> EmailDeliveryMethod: + return EmailDeliveryMethod( + 
config=EmailDeliveryConfig( + recipients=EmailRecipients( + whole_workspace=False, + items=[ExternalRecipient(email="tester@example.com")], + ), + subject="Test subject", + body="Test body", + ) + ) + + +def test_email_delivery_test_handler_rejects_when_feature_disabled(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr( + service_module.FeatureService, + "get_features", + lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=False), + ) + + handler = EmailDeliveryTestHandler(session_factory=object()) + context = DeliveryTestContext( + tenant_id="tenant-1", + app_id="app-1", + node_id="node-1", + node_title="Human Input", + rendered_content="content", + ) + method = _make_email_method() + + with pytest.raises(DeliveryTestError, match="Email delivery is not available"): + handler.send_test(context=context, method=method) + + +def test_email_delivery_test_handler_replaces_body_variables(monkeypatch: pytest.MonkeyPatch): + class DummyMail: + def __init__(self): + self.sent: list[dict[str, str]] = [] + + def is_inited(self) -> bool: + return True + + def send(self, *, to: str, subject: str, html: str): + self.sent.append({"to": to, "subject": subject, "html": html}) + + mail = DummyMail() + monkeypatch.setattr(service_module, "mail", mail) + monkeypatch.setattr(service_module, "render_email_template", lambda template, _substitutions: template) + monkeypatch.setattr( + service_module.FeatureService, + "get_features", + lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=True), + ) + + handler = EmailDeliveryTestHandler(session_factory=object()) + handler._resolve_recipients = lambda **_kwargs: ["tester@example.com"] # type: ignore[assignment] + + method = EmailDeliveryMethod( + config=EmailDeliveryConfig( + recipients=EmailRecipients(whole_workspace=False, items=[ExternalRecipient(email="tester@example.com")]), + subject="Subject", + body="Value {{#node1.value#}}", + ) + ) + variable_pool = VariablePool() + 
variable_pool.add(["node1", "value"], "OK") + context = DeliveryTestContext( + tenant_id="tenant-1", + app_id="app-1", + node_id="node-1", + node_title="Human Input", + rendered_content="content", + variable_pool=variable_pool, + ) + + handler.send_test(context=context, method=method) + + assert mail.sent[0]["html"] == "Value OK" diff --git a/api/tests/unit_tests/services/test_human_input_service.py b/api/tests/unit_tests/services/test_human_input_service.py new file mode 100644 index 0000000000..d2cf74daf3 --- /dev/null +++ b/api/tests/unit_tests/services/test_human_input_service.py @@ -0,0 +1,290 @@ +import dataclasses +from datetime import datetime, timedelta +from unittest.mock import MagicMock + +import pytest + +import services.human_input_service as human_input_service_module +from core.repositories.human_input_repository import ( + HumanInputFormRecord, + HumanInputFormSubmissionRepository, +) +from core.workflow.nodes.human_input.entities import ( + FormDefinition, + FormInput, + UserAction, +) +from core.workflow.nodes.human_input.enums import FormInputType, HumanInputFormKind, HumanInputFormStatus +from models.human_input import RecipientType +from services.human_input_service import Form, FormExpiredError, HumanInputService, InvalidFormDataError +from tasks.app_generate.workflow_execute_task import WORKFLOW_BASED_APP_EXECUTION_QUEUE + + +@pytest.fixture +def mock_session_factory(): + session = MagicMock() + session_cm = MagicMock() + session_cm.__enter__.return_value = session + session_cm.__exit__.return_value = None + + factory = MagicMock() + factory.return_value = session_cm + return factory, session + + +@pytest.fixture +def sample_form_record(): + return HumanInputFormRecord( + form_id="form-id", + workflow_run_id="workflow-run-id", + node_id="node-id", + tenant_id="tenant-id", + app_id="app-id", + form_kind=HumanInputFormKind.RUNTIME, + definition=FormDefinition( + form_content="hello", + inputs=[], + user_actions=[UserAction(id="submit", 
title="Submit")], + rendered_content="

hello

", + expiration_time=datetime.utcnow() + timedelta(hours=1), + ), + rendered_content="

hello

", + created_at=datetime.utcnow(), + expiration_time=datetime.utcnow() + timedelta(hours=1), + status=HumanInputFormStatus.WAITING, + selected_action_id=None, + submitted_data=None, + submitted_at=None, + submission_user_id=None, + submission_end_user_id=None, + completed_by_recipient_id=None, + recipient_id="recipient-id", + recipient_type=RecipientType.STANDALONE_WEB_APP, + access_token="token", + ) + + +def test_enqueue_resume_dispatches_task_for_workflow(mocker, mock_session_factory): + session_factory, session = mock_session_factory + service = HumanInputService(session_factory) + + workflow_run = MagicMock() + workflow_run.app_id = "app-id" + + workflow_run_repo = MagicMock() + workflow_run_repo.get_workflow_run_by_id_without_tenant.return_value = workflow_run + mocker.patch( + "services.human_input_service.DifyAPIRepositoryFactory.create_api_workflow_run_repository", + return_value=workflow_run_repo, + ) + + app = MagicMock() + app.mode = "workflow" + session.execute.return_value.scalar_one_or_none.return_value = app + + resume_task = mocker.patch("services.human_input_service.resume_app_execution") + + service.enqueue_resume("workflow-run-id") + + resume_task.apply_async.assert_called_once() + call_kwargs = resume_task.apply_async.call_args.kwargs + assert call_kwargs["queue"] == WORKFLOW_BASED_APP_EXECUTION_QUEUE + assert call_kwargs["kwargs"]["payload"]["workflow_run_id"] == "workflow-run-id" + + +def test_ensure_form_active_respects_global_timeout(monkeypatch, sample_form_record, mock_session_factory): + session_factory, _ = mock_session_factory + service = HumanInputService(session_factory) + expired_record = dataclasses.replace( + sample_form_record, + created_at=datetime.utcnow() - timedelta(hours=2), + expiration_time=datetime.utcnow() + timedelta(hours=2), + ) + monkeypatch.setattr(human_input_service_module.dify_config, "HUMAN_INPUT_GLOBAL_TIMEOUT_SECONDS", 3600) + + with pytest.raises(FormExpiredError): + 
service.ensure_form_active(Form(expired_record)) + + +def test_enqueue_resume_dispatches_task_for_advanced_chat(mocker, mock_session_factory): + session_factory, session = mock_session_factory + service = HumanInputService(session_factory) + + workflow_run = MagicMock() + workflow_run.app_id = "app-id" + + workflow_run_repo = MagicMock() + workflow_run_repo.get_workflow_run_by_id_without_tenant.return_value = workflow_run + mocker.patch( + "services.human_input_service.DifyAPIRepositoryFactory.create_api_workflow_run_repository", + return_value=workflow_run_repo, + ) + + app = MagicMock() + app.mode = "advanced-chat" + session.execute.return_value.scalar_one_or_none.return_value = app + + resume_task = mocker.patch("services.human_input_service.resume_app_execution") + + service.enqueue_resume("workflow-run-id") + + resume_task.apply_async.assert_called_once() + call_kwargs = resume_task.apply_async.call_args.kwargs + assert call_kwargs["queue"] == WORKFLOW_BASED_APP_EXECUTION_QUEUE + assert call_kwargs["kwargs"]["payload"]["workflow_run_id"] == "workflow-run-id" + + +def test_enqueue_resume_skips_unsupported_app_mode(mocker, mock_session_factory): + session_factory, session = mock_session_factory + service = HumanInputService(session_factory) + + workflow_run = MagicMock() + workflow_run.app_id = "app-id" + + workflow_run_repo = MagicMock() + workflow_run_repo.get_workflow_run_by_id_without_tenant.return_value = workflow_run + mocker.patch( + "services.human_input_service.DifyAPIRepositoryFactory.create_api_workflow_run_repository", + return_value=workflow_run_repo, + ) + + app = MagicMock() + app.mode = "completion" + session.execute.return_value.scalar_one_or_none.return_value = app + + resume_task = mocker.patch("services.human_input_service.resume_app_execution") + + service.enqueue_resume("workflow-run-id") + + resume_task.apply_async.assert_not_called() + + +def test_get_form_definition_by_token_for_console_uses_repository(sample_form_record, 
mock_session_factory): + session_factory, _ = mock_session_factory + repo = MagicMock(spec=HumanInputFormSubmissionRepository) + console_record = dataclasses.replace(sample_form_record, recipient_type=RecipientType.CONSOLE) + repo.get_by_token.return_value = console_record + + service = HumanInputService(session_factory, form_repository=repo) + form = service.get_form_definition_by_token_for_console("token") + + repo.get_by_token.assert_called_once_with("token") + assert form is not None + assert form.get_definition() == console_record.definition + + +def test_submit_form_by_token_calls_repository_and_enqueue(sample_form_record, mock_session_factory, mocker): + session_factory, _ = mock_session_factory + repo = MagicMock(spec=HumanInputFormSubmissionRepository) + repo.get_by_token.return_value = sample_form_record + repo.mark_submitted.return_value = sample_form_record + service = HumanInputService(session_factory, form_repository=repo) + enqueue_spy = mocker.patch.object(service, "enqueue_resume") + + service.submit_form_by_token( + recipient_type=RecipientType.STANDALONE_WEB_APP, + form_token="token", + selected_action_id="submit", + form_data={"field": "value"}, + submission_end_user_id="end-user-id", + ) + + repo.get_by_token.assert_called_once_with("token") + repo.mark_submitted.assert_called_once() + call_kwargs = repo.mark_submitted.call_args.kwargs + assert call_kwargs["form_id"] == sample_form_record.form_id + assert call_kwargs["recipient_id"] == sample_form_record.recipient_id + assert call_kwargs["selected_action_id"] == "submit" + assert call_kwargs["form_data"] == {"field": "value"} + assert call_kwargs["submission_end_user_id"] == "end-user-id" + enqueue_spy.assert_called_once_with(sample_form_record.workflow_run_id) + + +def test_submit_form_by_token_skips_enqueue_for_delivery_test(sample_form_record, mock_session_factory, mocker): + session_factory, _ = mock_session_factory + repo = MagicMock(spec=HumanInputFormSubmissionRepository) + test_record = 
dataclasses.replace( + sample_form_record, + form_kind=HumanInputFormKind.DELIVERY_TEST, + workflow_run_id=None, + ) + repo.get_by_token.return_value = test_record + repo.mark_submitted.return_value = test_record + service = HumanInputService(session_factory, form_repository=repo) + enqueue_spy = mocker.patch.object(service, "enqueue_resume") + + service.submit_form_by_token( + recipient_type=RecipientType.STANDALONE_WEB_APP, + form_token="token", + selected_action_id="submit", + form_data={"field": "value"}, + ) + + enqueue_spy.assert_not_called() + + +def test_submit_form_by_token_passes_submission_user_id(sample_form_record, mock_session_factory, mocker): + session_factory, _ = mock_session_factory + repo = MagicMock(spec=HumanInputFormSubmissionRepository) + repo.get_by_token.return_value = sample_form_record + repo.mark_submitted.return_value = sample_form_record + service = HumanInputService(session_factory, form_repository=repo) + enqueue_spy = mocker.patch.object(service, "enqueue_resume") + + service.submit_form_by_token( + recipient_type=RecipientType.STANDALONE_WEB_APP, + form_token="token", + selected_action_id="submit", + form_data={"field": "value"}, + submission_user_id="account-id", + ) + + call_kwargs = repo.mark_submitted.call_args.kwargs + assert call_kwargs["submission_user_id"] == "account-id" + assert call_kwargs["submission_end_user_id"] is None + enqueue_spy.assert_called_once_with(sample_form_record.workflow_run_id) + + +def test_submit_form_by_token_invalid_action(sample_form_record, mock_session_factory): + session_factory, _ = mock_session_factory + repo = MagicMock(spec=HumanInputFormSubmissionRepository) + repo.get_by_token.return_value = dataclasses.replace(sample_form_record) + service = HumanInputService(session_factory, form_repository=repo) + + with pytest.raises(InvalidFormDataError) as exc_info: + service.submit_form_by_token( + recipient_type=RecipientType.STANDALONE_WEB_APP, + form_token="token", + 
selected_action_id="invalid", + form_data={}, + ) + + assert "Invalid action" in str(exc_info.value) + repo.mark_submitted.assert_not_called() + + +def test_submit_form_by_token_missing_inputs(sample_form_record, mock_session_factory): + session_factory, _ = mock_session_factory + repo = MagicMock(spec=HumanInputFormSubmissionRepository) + + definition_with_input = FormDefinition( + form_content="hello", + inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="content")], + user_actions=sample_form_record.definition.user_actions, + rendered_content="

hello

", + expiration_time=sample_form_record.expiration_time, + ) + form_with_input = dataclasses.replace(sample_form_record, definition=definition_with_input) + repo.get_by_token.return_value = form_with_input + service = HumanInputService(session_factory, form_repository=repo) + + with pytest.raises(InvalidFormDataError) as exc_info: + service.submit_form_by_token( + recipient_type=RecipientType.STANDALONE_WEB_APP, + form_token="token", + selected_action_id="submit", + form_data={}, + ) + + assert "Missing required inputs" in str(exc_info.value) + repo.mark_submitted.assert_not_called() diff --git a/api/tests/unit_tests/services/test_message_service_extra_contents.py b/api/tests/unit_tests/services/test_message_service_extra_contents.py new file mode 100644 index 0000000000..3c8e301caa --- /dev/null +++ b/api/tests/unit_tests/services/test_message_service_extra_contents.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +import pytest + +from core.entities.execution_extra_content import HumanInputContent, HumanInputFormSubmissionData +from services import message_service + + +class _FakeMessage: + def __init__(self, message_id: str): + self.id = message_id + self.extra_contents = None + + def set_extra_contents(self, contents): + self.extra_contents = contents + + +def test_attach_message_extra_contents_assigns_serialized_payload(monkeypatch: pytest.MonkeyPatch) -> None: + messages = [_FakeMessage("msg-1"), _FakeMessage("msg-2")] + repo = type( + "Repo", + (), + { + "get_by_message_ids": lambda _self, message_ids: [ + [ + HumanInputContent( + workflow_run_id="workflow-run-1", + submitted=True, + form_submission_data=HumanInputFormSubmissionData( + node_id="node-1", + node_title="Approval", + rendered_content="Rendered", + action_id="approve", + action_text="Approve", + ), + ) + ], + [], + ] + }, + )() + + monkeypatch.setattr(message_service, "_create_execution_extra_content_repository", lambda: repo) + + message_service.attach_message_extra_contents(messages) + 
+ assert messages[0].extra_contents == [ + { + "type": "human_input", + "workflow_run_id": "workflow-run-1", + "submitted": True, + "form_submission_data": { + "node_id": "node-1", + "node_title": "Approval", + "rendered_content": "Rendered", + "action_id": "approve", + "action_text": "Approve", + }, + } + ] + assert messages[1].extra_contents == [] diff --git a/api/tests/unit_tests/services/test_workflow_run_service_pause.py b/api/tests/unit_tests/services/test_workflow_run_service_pause.py index f45a72927e..ded141f01a 100644 --- a/api/tests/unit_tests/services/test_workflow_run_service_pause.py +++ b/api/tests/unit_tests/services/test_workflow_run_service_pause.py @@ -35,7 +35,6 @@ class TestDataFactory: app_id: str = "app-789", workflow_id: str = "workflow-101", status: str | WorkflowExecutionStatus = "paused", - pause_id: str | None = None, **kwargs, ) -> MagicMock: """Create a mock WorkflowRun object.""" @@ -45,7 +44,6 @@ class TestDataFactory: mock_run.app_id = app_id mock_run.workflow_id = workflow_id mock_run.status = status - mock_run.pause_id = pause_id for key, value in kwargs.items(): setattr(mock_run, key, value) diff --git a/api/tests/unit_tests/services/tools/test_workflow_tools_manage_service.py b/api/tests/unit_tests/services/tools/test_workflow_tools_manage_service.py new file mode 100644 index 0000000000..d6c92f1013 --- /dev/null +++ b/api/tests/unit_tests/services/tools/test_workflow_tools_manage_service.py @@ -0,0 +1,158 @@ +import json +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from core.tools.errors import WorkflowToolHumanInputNotSupportedError +from models.model import App +from models.tools import WorkflowToolProvider +from services.tools import workflow_tools_manage_service + + +class DummyWorkflow: + def __init__(self, graph_dict: dict, version: str = "1.0.0") -> None: + self._graph_dict = graph_dict + self.version = version + + @property + def graph_dict(self) -> dict: + return 
self._graph_dict + + +class FakeQuery: + def __init__(self, result): + self._result = result + + def where(self, *args, **kwargs): + return self + + def first(self): + return self._result + + +class DummySession: + def __init__(self) -> None: + self.added: list[object] = [] + + def __enter__(self) -> "DummySession": + return self + + def __exit__(self, exc_type, exc, tb) -> bool: + return False + + def add(self, obj) -> None: + self.added.append(obj) + + def begin(self): + return DummyBegin(self) + + +class DummyBegin: + def __init__(self, session: DummySession) -> None: + self._session = session + + def __enter__(self) -> DummySession: + return self._session + + def __exit__(self, exc_type, exc, tb) -> bool: + return False + + +class DummySessionContext: + def __init__(self, session: DummySession) -> None: + self._session = session + + def __enter__(self) -> DummySession: + return self._session + + def __exit__(self, exc_type, exc, tb) -> bool: + return False + + +class DummySessionFactory: + def __init__(self, session: DummySession) -> None: + self._session = session + + def create_session(self) -> DummySessionContext: + return DummySessionContext(self._session) + + +def _build_fake_session(app) -> SimpleNamespace: + def query(model): + if model is WorkflowToolProvider: + return FakeQuery(None) + if model is App: + return FakeQuery(app) + return FakeQuery(None) + + return SimpleNamespace(query=query) + + +def test_create_workflow_tool_rejects_human_input_nodes(monkeypatch): + workflow = DummyWorkflow(graph_dict={"nodes": [{"id": "node_1", "data": {"type": "human-input"}}]}) + app = SimpleNamespace(workflow=workflow) + + fake_session = _build_fake_session(app) + monkeypatch.setattr(workflow_tools_manage_service.db, "session", fake_session) + + mock_from_db = MagicMock() + monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", mock_from_db) + mock_invalidate = MagicMock() + + parameters = [{"name": "input", "description": 
"input", "form": "form"}] + + with pytest.raises(WorkflowToolHumanInputNotSupportedError) as exc_info: + workflow_tools_manage_service.WorkflowToolManageService.create_workflow_tool( + user_id="user-id", + tenant_id="tenant-id", + workflow_app_id="app-id", + name="tool_name", + label="Tool", + icon={"type": "emoji", "emoji": "tool"}, + description="desc", + parameters=parameters, + ) + + assert exc_info.value.error_code == "workflow_tool_human_input_not_supported" + mock_from_db.assert_not_called() + mock_invalidate.assert_not_called() + + +def test_create_workflow_tool_success(monkeypatch): + workflow = DummyWorkflow(graph_dict={"nodes": [{"id": "node_1", "data": {"type": "start"}}]}) + app = SimpleNamespace(workflow=workflow) + + fake_db = MagicMock() + fake_session = _build_fake_session(app) + fake_db.session = fake_session + monkeypatch.setattr(workflow_tools_manage_service, "db", fake_db) + + dummy_session = DummySession() + monkeypatch.setattr(workflow_tools_manage_service, "Session", lambda *_, **__: dummy_session) + + mock_from_db = MagicMock() + monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", mock_from_db) + + parameters = [{"name": "input", "description": "input", "form": "form"}] + icon = {"type": "emoji", "emoji": "tool"} + + result = workflow_tools_manage_service.WorkflowToolManageService.create_workflow_tool( + user_id="user-id", + tenant_id="tenant-id", + workflow_app_id="app-id", + name="tool_name", + label="Tool", + icon=icon, + description="desc", + parameters=parameters, + ) + + assert result == {"result": "success"} + assert len(dummy_session.added) == 1 + created_provider = dummy_session.added[0] + assert created_provider.name == "tool_name" + assert created_provider.label == "Tool" + assert created_provider.icon == json.dumps(icon) + assert created_provider.version == workflow.version + mock_from_db.assert_called_once() diff --git 
a/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py b/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py new file mode 100644 index 0000000000..844dab8976 --- /dev/null +++ b/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py @@ -0,0 +1,226 @@ +from __future__ import annotations + +import json +import queue +from collections.abc import Sequence +from dataclasses import dataclass +from datetime import UTC, datetime +from threading import Event + +import pytest + +from core.app.app_config.entities import WorkflowUIBasedAppConfig +from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity +from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext, _WorkflowGenerateEntityWrapper +from core.workflow.entities.pause_reason import HumanInputRequired +from core.workflow.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus +from core.workflow.runtime import GraphRuntimeState, VariablePool +from models.enums import CreatorUserRole +from models.model import AppMode +from models.workflow import WorkflowRun +from repositories.api_workflow_node_execution_repository import WorkflowNodeExecutionSnapshot +from repositories.entities.workflow_pause import WorkflowPauseEntity +from services.workflow_event_snapshot_service import ( + BufferState, + MessageContext, + _build_snapshot_events, + _resolve_task_id, +) + + +@dataclass(frozen=True) +class _FakePauseEntity(WorkflowPauseEntity): + pause_id: str + workflow_run_id: str + paused_at_value: datetime + pause_reasons: Sequence[HumanInputRequired] + + @property + def id(self) -> str: + return self.pause_id + + @property + def workflow_execution_id(self) -> str: + return self.workflow_run_id + + def get_state(self) -> bytes: + raise AssertionError("state is not required for snapshot tests") + + @property + def resumed_at(self) -> datetime | None: + return None + + @property + def 
paused_at(self) -> datetime: + return self.paused_at_value + + def get_pause_reasons(self) -> Sequence[HumanInputRequired]: + return self.pause_reasons + + +def _build_workflow_run(status: WorkflowExecutionStatus) -> WorkflowRun: + return WorkflowRun( + id="run-1", + tenant_id="tenant-1", + app_id="app-1", + workflow_id="workflow-1", + type="workflow", + triggered_from="app-run", + version="v1", + graph=None, + inputs=json.dumps({"input": "value"}), + status=status, + outputs=json.dumps({}), + error=None, + elapsed_time=0.0, + total_tokens=0, + total_steps=0, + created_by_role=CreatorUserRole.END_USER, + created_by="user-1", + created_at=datetime(2024, 1, 1, tzinfo=UTC), + ) + + +def _build_snapshot(status: WorkflowNodeExecutionStatus) -> WorkflowNodeExecutionSnapshot: + created_at = datetime(2024, 1, 1, tzinfo=UTC) + finished_at = datetime(2024, 1, 1, 0, 0, 5, tzinfo=UTC) + return WorkflowNodeExecutionSnapshot( + execution_id="exec-1", + node_id="node-1", + node_type="human-input", + title="Human Input", + index=1, + status=status.value, + elapsed_time=0.5, + created_at=created_at, + finished_at=finished_at, + iteration_id=None, + loop_id=None, + ) + + +def _build_resumption_context(task_id: str) -> WorkflowResumptionContext: + app_config = WorkflowUIBasedAppConfig( + tenant_id="tenant-1", + app_id="app-1", + app_mode=AppMode.WORKFLOW, + workflow_id="workflow-1", + ) + generate_entity = WorkflowAppGenerateEntity( + task_id=task_id, + app_config=app_config, + inputs={}, + files=[], + user_id="user-1", + stream=True, + invoke_from=InvokeFrom.EXPLORE, + call_depth=0, + workflow_execution_id="run-1", + ) + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0) + runtime_state.register_paused_node("node-1") + runtime_state.outputs = {"result": "value"} + wrapper = _WorkflowGenerateEntityWrapper(entity=generate_entity) + return WorkflowResumptionContext( + generate_entity=wrapper, + serialized_graph_runtime_state=runtime_state.dumps(), + ) + + 
+def test_build_snapshot_events_includes_pause_event() -> None: + workflow_run = _build_workflow_run(WorkflowExecutionStatus.PAUSED) + snapshot = _build_snapshot(WorkflowNodeExecutionStatus.PAUSED) + resumption_context = _build_resumption_context("task-ctx") + pause_entity = _FakePauseEntity( + pause_id="pause-1", + workflow_run_id="run-1", + paused_at_value=datetime(2024, 1, 1, tzinfo=UTC), + pause_reasons=[ + HumanInputRequired( + form_id="form-1", + form_content="content", + node_id="node-1", + node_title="Human Input", + ) + ], + ) + + events = _build_snapshot_events( + workflow_run=workflow_run, + node_snapshots=[snapshot], + task_id="task-ctx", + message_context=None, + pause_entity=pause_entity, + resumption_context=resumption_context, + ) + + assert [event["event"] for event in events] == [ + "workflow_started", + "node_started", + "node_finished", + "workflow_paused", + ] + assert events[2]["data"]["status"] == WorkflowNodeExecutionStatus.PAUSED.value + pause_data = events[-1]["data"] + assert pause_data["paused_nodes"] == ["node-1"] + assert pause_data["outputs"] == {"result": "value"} + assert pause_data["status"] == WorkflowExecutionStatus.PAUSED.value + assert pause_data["created_at"] == int(workflow_run.created_at.timestamp()) + assert pause_data["elapsed_time"] == workflow_run.elapsed_time + assert pause_data["total_tokens"] == workflow_run.total_tokens + assert pause_data["total_steps"] == workflow_run.total_steps + + +def test_build_snapshot_events_applies_message_context() -> None: + workflow_run = _build_workflow_run(WorkflowExecutionStatus.RUNNING) + snapshot = _build_snapshot(WorkflowNodeExecutionStatus.SUCCEEDED) + message_context = MessageContext( + conversation_id="conv-1", + message_id="msg-1", + created_at=1700000000, + answer="snapshot message", + ) + + events = _build_snapshot_events( + workflow_run=workflow_run, + node_snapshots=[snapshot], + task_id="task-1", + message_context=message_context, + pause_entity=None, + 
resumption_context=None, + ) + + assert [event["event"] for event in events] == [ + "workflow_started", + "message_replace", + "node_started", + "node_finished", + ] + assert events[1]["answer"] == "snapshot message" + for event in events: + assert event["conversation_id"] == "conv-1" + assert event["message_id"] == "msg-1" + assert event["created_at"] == 1700000000 + + +@pytest.mark.parametrize( + ("context_task_id", "buffered_task_id", "expected"), + [ + ("task-ctx", "task-buffer", "task-ctx"), + (None, "task-buffer", "task-buffer"), + (None, None, "run-1"), + ], +) +def test_resolve_task_id_priority(context_task_id, buffered_task_id, expected) -> None: + resumption_context = _build_resumption_context(context_task_id) if context_task_id else None + buffer_state = BufferState( + queue=queue.Queue(), + stop_event=Event(), + done_event=Event(), + task_id_ready=Event(), + task_id_hint=buffered_task_id, + ) + if buffered_task_id: + buffer_state.task_id_ready.set() + task_id = _resolve_task_id(resumption_context, buffer_state, "run-1", wait_timeout=0.0) + assert task_id == expected diff --git a/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py b/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py new file mode 100644 index 0000000000..5ac5ac8ad2 --- /dev/null +++ b/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py @@ -0,0 +1,184 @@ +import uuid +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest +from sqlalchemy.orm import sessionmaker + +from core.workflow.enums import NodeType +from core.workflow.nodes.human_input.entities import ( + EmailDeliveryConfig, + EmailDeliveryMethod, + EmailRecipients, + ExternalRecipient, + HumanInputNodeData, + MemberRecipient, +) +from services import workflow_service as workflow_service_module +from services.workflow_service import WorkflowService + + +def _make_service() -> WorkflowService: + return 
WorkflowService(session_maker=sessionmaker()) + + +def _build_node_config(delivery_methods): + node_data = HumanInputNodeData( + title="Human Input", + delivery_methods=delivery_methods, + form_content="Test content", + inputs=[], + user_actions=[], + ).model_dump(mode="json") + node_data["type"] = NodeType.HUMAN_INPUT.value + return {"id": "node-1", "data": node_data} + + +def _make_email_method(enabled: bool = True, debug_mode: bool = False) -> EmailDeliveryMethod: + return EmailDeliveryMethod( + id=uuid.uuid4(), + enabled=enabled, + config=EmailDeliveryConfig( + recipients=EmailRecipients( + whole_workspace=False, + items=[ExternalRecipient(email="tester@example.com")], + ), + subject="Test subject", + body="Test body", + debug_mode=debug_mode, + ), + ) + + +def test_human_input_delivery_requires_draft_workflow(): + service = _make_service() + service.get_draft_workflow = MagicMock(return_value=None) # type: ignore[method-assign] + app_model = SimpleNamespace(tenant_id="tenant-1", id="app-1") + account = SimpleNamespace(id="account-1") + + with pytest.raises(ValueError, match="Workflow not initialized"): + service.test_human_input_delivery( + app_model=app_model, + account=account, + node_id="node-1", + delivery_method_id="delivery-1", + ) + + +def test_human_input_delivery_allows_disabled_method(monkeypatch: pytest.MonkeyPatch): + service = _make_service() + delivery_method = _make_email_method(enabled=False) + node_config = _build_node_config([delivery_method]) + workflow = MagicMock() + workflow.get_node_config_by_id.return_value = node_config + service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign] + service._build_human_input_variable_pool = MagicMock(return_value=MagicMock()) # type: ignore[attr-defined] + node_stub = MagicMock() + node_stub._render_form_content_before_submission.return_value = "rendered" + node_stub._resolve_default_values.return_value = {} + service._build_human_input_node = 
MagicMock(return_value=node_stub) # type: ignore[attr-defined] + service._create_human_input_delivery_test_form = MagicMock( # type: ignore[attr-defined] + return_value=("form-1", {}) + ) + + test_service_instance = MagicMock() + monkeypatch.setattr( + workflow_service_module, + "HumanInputDeliveryTestService", + MagicMock(return_value=test_service_instance), + ) + + app_model = SimpleNamespace(tenant_id="tenant-1", id="app-1") + account = SimpleNamespace(id="account-1") + + service.test_human_input_delivery( + app_model=app_model, + account=account, + node_id="node-1", + delivery_method_id=str(delivery_method.id), + ) + + test_service_instance.send_test.assert_called_once() + + +def test_human_input_delivery_dispatches_to_test_service(monkeypatch: pytest.MonkeyPatch): + service = _make_service() + delivery_method = _make_email_method(enabled=True) + node_config = _build_node_config([delivery_method]) + workflow = MagicMock() + workflow.get_node_config_by_id.return_value = node_config + service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign] + service._build_human_input_variable_pool = MagicMock(return_value=MagicMock()) # type: ignore[attr-defined] + node_stub = MagicMock() + node_stub._render_form_content_before_submission.return_value = "rendered" + node_stub._resolve_default_values.return_value = {} + service._build_human_input_node = MagicMock(return_value=node_stub) # type: ignore[attr-defined] + service._create_human_input_delivery_test_form = MagicMock( # type: ignore[attr-defined] + return_value=("form-1", {}) + ) + + test_service_instance = MagicMock() + monkeypatch.setattr( + workflow_service_module, + "HumanInputDeliveryTestService", + MagicMock(return_value=test_service_instance), + ) + + app_model = SimpleNamespace(tenant_id="tenant-1", id="app-1") + account = SimpleNamespace(id="account-1") + + service.test_human_input_delivery( + app_model=app_model, + account=account, + node_id="node-1", + 
delivery_method_id=str(delivery_method.id), + inputs={"#node-1.output#": "value"}, + ) + + pool_args = service._build_human_input_variable_pool.call_args.kwargs + assert pool_args["manual_inputs"] == {"#node-1.output#": "value"} + test_service_instance.send_test.assert_called_once() + + +def test_human_input_delivery_debug_mode_overrides_recipients(monkeypatch: pytest.MonkeyPatch): + service = _make_service() + delivery_method = _make_email_method(enabled=True, debug_mode=True) + node_config = _build_node_config([delivery_method]) + workflow = MagicMock() + workflow.get_node_config_by_id.return_value = node_config + service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign] + service._build_human_input_variable_pool = MagicMock(return_value=MagicMock()) # type: ignore[attr-defined] + node_stub = MagicMock() + node_stub._render_form_content_before_submission.return_value = "rendered" + node_stub._resolve_default_values.return_value = {} + service._build_human_input_node = MagicMock(return_value=node_stub) # type: ignore[attr-defined] + service._create_human_input_delivery_test_form = MagicMock( # type: ignore[attr-defined] + return_value=("form-1", {}) + ) + + test_service_instance = MagicMock() + monkeypatch.setattr( + workflow_service_module, + "HumanInputDeliveryTestService", + MagicMock(return_value=test_service_instance), + ) + + app_model = SimpleNamespace(tenant_id="tenant-1", id="app-1") + account = SimpleNamespace(id="account-1") + + service.test_human_input_delivery( + app_model=app_model, + account=account, + node_id="node-1", + delivery_method_id=str(delivery_method.id), + ) + + test_service_instance.send_test.assert_called_once() + sent_method = test_service_instance.send_test.call_args.kwargs["method"] + assert isinstance(sent_method, EmailDeliveryMethod) + assert sent_method.config.debug_mode is True + assert sent_method.config.recipients.whole_workspace is False + assert len(sent_method.config.recipients.items) == 1 + 
recipient = sent_method.config.recipients.items[0] + assert isinstance(recipient, MemberRecipient) + assert recipient.user_id == account.id diff --git a/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py b/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py index 32d2f8b7e0..70d7bde870 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py @@ -5,6 +5,7 @@ from uuid import uuid4 import pytest from sqlalchemy.orm import Session +from core.workflow.enums import WorkflowNodeExecutionStatus from models.workflow import WorkflowNodeExecutionModel from repositories.sqlalchemy_api_workflow_node_execution_repository import ( DifyAPISQLAlchemyWorkflowNodeExecutionRepository, @@ -52,6 +53,9 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: call_args = mock_session.scalar.call_args[0][0] assert hasattr(call_args, "compile") # It's a SQLAlchemy statement + compiled = call_args.compile() + assert WorkflowNodeExecutionStatus.PAUSED in compiled.params.values() + def test_get_node_last_execution_not_found(self, repository): """Test getting the last execution for a node when it doesn't exist.""" # Arrange @@ -71,28 +75,6 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: assert result is None mock_session.scalar.assert_called_once() - def test_get_executions_by_workflow_run(self, repository, mock_execution): - """Test getting all executions for a workflow run.""" - # Arrange - mock_session = MagicMock(spec=Session) - repository._session_maker.return_value.__enter__.return_value = mock_session - executions = [mock_execution] - mock_session.execute.return_value.scalars.return_value.all.return_value = executions - - # Act - result = repository.get_executions_by_workflow_run( - tenant_id="tenant-123", - app_id="app-456", - 
workflow_run_id="run-101", - ) - - # Assert - assert result == executions - mock_session.execute.assert_called_once() - # Verify the query was constructed correctly - call_args = mock_session.execute.call_args[0][0] - assert hasattr(call_args, "compile") # It's a SQLAlchemy statement - def test_get_executions_by_workflow_run_empty(self, repository): """Test getting executions for a workflow run when none exist.""" # Arrange diff --git a/api/tests/unit_tests/services/workflow/test_workflow_service.py b/api/tests/unit_tests/services/workflow/test_workflow_service.py index 9700cbaf0e..015dac257e 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_service.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_service.py @@ -1,9 +1,15 @@ +from contextlib import nullcontext +from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from core.workflow.enums import NodeType +from core.workflow.nodes.human_input.entities import FormInput, HumanInputNodeData, UserAction +from core.workflow.nodes.human_input.enums import FormInputType from models.model import App from models.workflow import Workflow +from services import workflow_service as workflow_service_module from services.workflow_service import WorkflowService @@ -161,3 +167,120 @@ class TestWorkflowService: assert workflows == [] assert has_more is False mock_session.scalars.assert_called_once() + + def test_submit_human_input_form_preview_uses_rendered_content( + self, workflow_service: WorkflowService, monkeypatch: pytest.MonkeyPatch + ) -> None: + service = workflow_service + node_data = HumanInputNodeData( + title="Human Input", + form_content="

{{#$output.name#}}

", + inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")], + user_actions=[UserAction(id="approve", title="Approve")], + ) + node = MagicMock() + node.node_data = node_data + node.render_form_content_before_submission.return_value = "

preview

" + node.render_form_content_with_outputs.return_value = "

rendered

" + + service._build_human_input_variable_pool = MagicMock(return_value=MagicMock()) # type: ignore[method-assign] + service._build_human_input_node = MagicMock(return_value=node) # type: ignore[method-assign] + + workflow = MagicMock() + workflow.get_node_config_by_id.return_value = {"id": "node-1", "data": {"type": NodeType.HUMAN_INPUT.value}} + workflow.get_enclosing_node_type_and_id.return_value = None + service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign] + + saved_outputs: dict[str, object] = {} + + class DummySession: + def __init__(self, *args, **kwargs): + self.commit = MagicMock() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def begin(self): + return nullcontext() + + class DummySaver: + def __init__(self, *args, **kwargs): + pass + + def save(self, outputs, process_data): + saved_outputs.update(outputs) + + monkeypatch.setattr(workflow_service_module, "Session", DummySession) + monkeypatch.setattr(workflow_service_module, "DraftVariableSaver", DummySaver) + monkeypatch.setattr(workflow_service_module, "db", SimpleNamespace(engine=MagicMock())) + + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1") + account = SimpleNamespace(id="account-1") + + result = service.submit_human_input_form_preview( + app_model=app_model, + account=account, + node_id="node-1", + form_inputs={"name": "Ada", "extra": "ignored"}, + inputs={"#node-0.result#": "LLM output"}, + action="approve", + ) + + service._build_human_input_variable_pool.assert_called_once_with( + app_model=app_model, + workflow=workflow, + node_config={"id": "node-1", "data": {"type": NodeType.HUMAN_INPUT.value}}, + manual_inputs={"#node-0.result#": "LLM output"}, + ) + + node.render_form_content_with_outputs.assert_called_once() + called_args = node.render_form_content_with_outputs.call_args.args + assert called_args[0] == "

preview

" + assert called_args[2] == node_data.outputs_field_names() + rendered_outputs = called_args[1] + assert rendered_outputs["name"] == "Ada" + assert rendered_outputs["extra"] == "ignored" + assert "extra" in saved_outputs + assert "extra" in result + assert saved_outputs["name"] == "Ada" + assert result["name"] == "Ada" + assert result["__action_id"] == "approve" + assert "__rendered_content" in result + + def test_submit_human_input_form_preview_missing_inputs_message(self, workflow_service: WorkflowService) -> None: + service = workflow_service + node_data = HumanInputNodeData( + title="Human Input", + form_content="

{{#$output.name#}}

", + inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")], + user_actions=[UserAction(id="approve", title="Approve")], + ) + node = MagicMock() + node.node_data = node_data + node._render_form_content_before_submission.return_value = "

preview

" + node._render_form_content_with_outputs.return_value = "

rendered

" + + service._build_human_input_variable_pool = MagicMock(return_value=MagicMock()) # type: ignore[method-assign] + service._build_human_input_node = MagicMock(return_value=node) # type: ignore[method-assign] + + workflow = MagicMock() + workflow.get_node_config_by_id.return_value = {"id": "node-1", "data": {"type": NodeType.HUMAN_INPUT.value}} + service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign] + + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1") + account = SimpleNamespace(id="account-1") + + with pytest.raises(ValueError) as exc_info: + service.submit_human_input_form_preview( + app_model=app_model, + account=account, + node_id="node-1", + form_inputs={}, + inputs={}, + action="approve", + ) + + assert "Missing required inputs" in str(exc_info.value) diff --git a/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py b/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py new file mode 100644 index 0000000000..ee0699ba2d --- /dev/null +++ b/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py @@ -0,0 +1,210 @@ +from __future__ import annotations + +from datetime import datetime, timedelta +from types import SimpleNamespace +from typing import Any + +import pytest + +from core.workflow.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus +from tasks import human_input_timeout_tasks as task_module + + +class _FakeScalarResult: + def __init__(self, items: list[Any]): + self._items = items + + def all(self) -> list[Any]: + return self._items + + +class _FakeSession: + def __init__(self, items: list[Any], capture: dict[str, Any]): + self._items = items + self._capture = capture + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def scalars(self, stmt): + self._capture["stmt"] = stmt + return _FakeScalarResult(self._items) + + +class _FakeSessionFactory: + def __init__(self, items: list[Any], capture: dict[str, Any]): + 
self._items = items + self._capture = capture + self._capture["session_factory"] = self + + def __call__(self): + session = _FakeSession(self._items, self._capture) + self._capture["session"] = session + return session + + +class _FakeFormRepo: + def __init__(self, _session_factory, form_map: dict[str, Any] | None = None): + self.calls: list[dict[str, Any]] = [] + self._form_map = form_map or {} + + def mark_timeout(self, *, form_id: str, timeout_status: HumanInputFormStatus, reason: str | None = None): + self.calls.append( + { + "form_id": form_id, + "timeout_status": timeout_status, + "reason": reason, + } + ) + form = self._form_map.get(form_id) + return SimpleNamespace( + form_id=form_id, + workflow_run_id=getattr(form, "workflow_run_id", None), + node_id=getattr(form, "node_id", None), + ) + + +class _FakeService: + def __init__(self, _session_factory, form_repository=None): + self.enqueued: list[str] = [] + + def enqueue_resume(self, workflow_run_id: str | None) -> None: + if workflow_run_id is not None: + self.enqueued.append(workflow_run_id) + + +def _build_form( + *, + form_id: str, + form_kind: HumanInputFormKind, + created_at: datetime, + expiration_time: datetime, + workflow_run_id: str | None, + node_id: str, +) -> SimpleNamespace: + return SimpleNamespace( + id=form_id, + form_kind=form_kind, + created_at=created_at, + expiration_time=expiration_time, + workflow_run_id=workflow_run_id, + node_id=node_id, + status=HumanInputFormStatus.WAITING, + ) + + +def test_is_global_timeout_uses_created_at(): + now = datetime(2025, 1, 1, 12, 0, 0) + form = SimpleNamespace(created_at=now - timedelta(seconds=61), workflow_run_id="run-1") + + assert task_module._is_global_timeout(form, 60, now=now) is True + + form.workflow_run_id = None + assert task_module._is_global_timeout(form, 60, now=now) is False + + form.workflow_run_id = "run-1" + form.created_at = now - timedelta(seconds=59) + assert task_module._is_global_timeout(form, 60, now=now) is False + + assert 
task_module._is_global_timeout(form, 0, now=now) is False + + +def test_check_and_handle_human_input_timeouts_marks_and_routes(monkeypatch: pytest.MonkeyPatch): + now = datetime(2025, 1, 1, 12, 0, 0) + monkeypatch.setattr(task_module, "naive_utc_now", lambda: now) + monkeypatch.setattr(task_module.dify_config, "HUMAN_INPUT_GLOBAL_TIMEOUT_SECONDS", 3600) + monkeypatch.setattr(task_module, "db", SimpleNamespace(engine=object())) + + forms = [ + _build_form( + form_id="form-global", + form_kind=HumanInputFormKind.RUNTIME, + created_at=now - timedelta(hours=2), + expiration_time=now + timedelta(hours=1), + workflow_run_id="run-global", + node_id="node-global", + ), + _build_form( + form_id="form-node", + form_kind=HumanInputFormKind.RUNTIME, + created_at=now - timedelta(minutes=5), + expiration_time=now - timedelta(seconds=1), + workflow_run_id="run-node", + node_id="node-node", + ), + _build_form( + form_id="form-delivery", + form_kind=HumanInputFormKind.DELIVERY_TEST, + created_at=now - timedelta(minutes=1), + expiration_time=now - timedelta(seconds=1), + workflow_run_id=None, + node_id="node-delivery", + ), + ] + + capture: dict[str, Any] = {} + monkeypatch.setattr(task_module, "sessionmaker", lambda *args, **kwargs: _FakeSessionFactory(forms, capture)) + + form_map = {form.id: form for form in forms} + repo = _FakeFormRepo(None, form_map=form_map) + + def _repo_factory(_session_factory): + return repo + + service = _FakeService(None) + + def _service_factory(_session_factory, form_repository=None): + return service + + global_calls: list[dict[str, Any]] = [] + + monkeypatch.setattr(task_module, "HumanInputFormSubmissionRepository", _repo_factory) + monkeypatch.setattr(task_module, "HumanInputService", _service_factory) + monkeypatch.setattr(task_module, "_handle_global_timeout", lambda **kwargs: global_calls.append(kwargs)) + + task_module.check_and_handle_human_input_timeouts(limit=100) + + assert {(call["form_id"], call["timeout_status"], call["reason"]) for call 
in repo.calls} == { + ("form-global", HumanInputFormStatus.EXPIRED, "global_timeout"), + ("form-node", HumanInputFormStatus.TIMEOUT, "node_timeout"), + ("form-delivery", HumanInputFormStatus.TIMEOUT, "delivery_test_timeout"), + } + assert service.enqueued == ["run-node"] + assert global_calls == [ + { + "form_id": "form-global", + "workflow_run_id": "run-global", + "node_id": "node-global", + "session_factory": capture.get("session_factory"), + } + ] + + stmt = capture.get("stmt") + assert stmt is not None + stmt_text = str(stmt) + assert "created_at <=" in stmt_text + assert "expiration_time <=" in stmt_text + assert "ORDER BY human_input_forms.id" in stmt_text + + +def test_check_and_handle_human_input_timeouts_omits_global_filter_when_disabled(monkeypatch: pytest.MonkeyPatch): + now = datetime(2025, 1, 1, 12, 0, 0) + monkeypatch.setattr(task_module, "naive_utc_now", lambda: now) + monkeypatch.setattr(task_module.dify_config, "HUMAN_INPUT_GLOBAL_TIMEOUT_SECONDS", 0) + monkeypatch.setattr(task_module, "db", SimpleNamespace(engine=object())) + + capture: dict[str, Any] = {} + monkeypatch.setattr(task_module, "sessionmaker", lambda *args, **kwargs: _FakeSessionFactory([], capture)) + monkeypatch.setattr(task_module, "HumanInputFormSubmissionRepository", _FakeFormRepo) + monkeypatch.setattr(task_module, "HumanInputService", _FakeService) + monkeypatch.setattr(task_module, "_handle_global_timeout", lambda **_kwargs: None) + + task_module.check_and_handle_human_input_timeouts(limit=1) + + stmt = capture.get("stmt") + assert stmt is not None + stmt_text = str(stmt) + assert "created_at <=" not in stmt_text diff --git a/api/tests/unit_tests/tasks/test_mail_human_input_delivery_task.py b/api/tests/unit_tests/tasks/test_mail_human_input_delivery_task.py new file mode 100644 index 0000000000..20cb7a211e --- /dev/null +++ b/api/tests/unit_tests/tasks/test_mail_human_input_delivery_task.py @@ -0,0 +1,123 @@ +from collections.abc import Sequence +from types import 
SimpleNamespace + +import pytest + +from tasks import mail_human_input_delivery_task as task_module + + +class _DummyMail: + def __init__(self): + self.sent: list[dict[str, str]] = [] + self._inited = True + + def is_inited(self) -> bool: + return self._inited + + def send(self, *, to: str, subject: str, html: str): + self.sent.append({"to": to, "subject": subject, "html": html}) + + +class _DummySession: + def __init__(self, form): + self._form = form + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + return False + + def get(self, _model, _form_id): + return self._form + + +def _build_job(recipient_count: int = 1) -> task_module._EmailDeliveryJob: + recipients: list[task_module._EmailRecipient] = [] + for idx in range(recipient_count): + recipients.append(task_module._EmailRecipient(email=f"user{idx}@example.com", token=f"token-{idx}")) + + return task_module._EmailDeliveryJob( + form_id="form-1", + subject="Subject", + body="Body for {{#url}}", + form_content="content", + recipients=recipients, + ) + + +def test_dispatch_human_input_email_task_sends_to_each_recipient(monkeypatch: pytest.MonkeyPatch): + mail = _DummyMail() + form = SimpleNamespace(id="form-1", tenant_id="tenant-1", workflow_run_id=None) + + monkeypatch.setattr(task_module, "mail", mail) + monkeypatch.setattr( + task_module.FeatureService, + "get_features", + lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=True), + ) + jobs: Sequence[task_module._EmailDeliveryJob] = [_build_job(recipient_count=2)] + monkeypatch.setattr(task_module, "_load_email_jobs", lambda _session, _form: jobs) + + task_module.dispatch_human_input_email_task( + form_id="form-1", + node_title="Approve", + session_factory=lambda: _DummySession(form), + ) + + assert len(mail.sent) == 2 + assert all(payload["subject"] == "Subject" for payload in mail.sent) + assert all("Body for" in payload["html"] for payload in mail.sent) + + +def 
test_dispatch_human_input_email_task_skips_when_feature_disabled(monkeypatch: pytest.MonkeyPatch): + mail = _DummyMail() + form = SimpleNamespace(id="form-1", tenant_id="tenant-1", workflow_run_id=None) + + monkeypatch.setattr(task_module, "mail", mail) + monkeypatch.setattr( + task_module.FeatureService, + "get_features", + lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=False), + ) + monkeypatch.setattr(task_module, "_load_email_jobs", lambda _session, _form: []) + + task_module.dispatch_human_input_email_task( + form_id="form-1", + node_title="Approve", + session_factory=lambda: _DummySession(form), + ) + + assert mail.sent == [] + + +def test_dispatch_human_input_email_task_replaces_body_variables(monkeypatch: pytest.MonkeyPatch): + mail = _DummyMail() + form = SimpleNamespace(id="form-1", tenant_id="tenant-1", workflow_run_id="run-1") + job = task_module._EmailDeliveryJob( + form_id="form-1", + subject="Subject", + body="Body {{#node1.value#}}", + form_content="content", + recipients=[task_module._EmailRecipient(email="user@example.com", token="token-1")], + ) + + variable_pool = task_module.VariablePool() + variable_pool.add(["node1", "value"], "OK") + + monkeypatch.setattr(task_module, "mail", mail) + monkeypatch.setattr( + task_module.FeatureService, + "get_features", + lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=True), + ) + monkeypatch.setattr(task_module, "_load_email_jobs", lambda _session, _form: [job]) + monkeypatch.setattr(task_module, "_load_variable_pool", lambda _workflow_run_id: variable_pool) + + task_module.dispatch_human_input_email_task( + form_id="form-1", + node_title="Approve", + session_factory=lambda: _DummySession(form), + ) + + assert mail.sent[0]["html"] == "Body OK" diff --git a/api/tests/unit_tests/tasks/test_workflow_execute_task.py b/api/tests/unit_tests/tasks/test_workflow_execute_task.py new file mode 100644 index 0000000000..161151305d --- /dev/null +++ 
b/api/tests/unit_tests/tasks/test_workflow_execute_task.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +import json +import uuid +from unittest.mock import MagicMock + +import pytest + +from models.model import AppMode +from tasks.app_generate.workflow_execute_task import _publish_streaming_response + + +@pytest.fixture +def mock_topic(mocker) -> MagicMock: + topic = MagicMock() + mocker.patch( + "tasks.app_generate.workflow_execute_task.MessageBasedAppGenerator.get_response_topic", + return_value=topic, + ) + return topic + + +def test_publish_streaming_response_with_uuid(mock_topic: MagicMock): + workflow_run_id = uuid.uuid4() + response_stream = iter([{"event": "foo"}, "ping"]) + + _publish_streaming_response(response_stream, workflow_run_id, app_mode=AppMode.ADVANCED_CHAT) + + payloads = [call.args[0] for call in mock_topic.publish.call_args_list] + assert payloads == [json.dumps({"event": "foo"}).encode(), json.dumps("ping").encode()] + + +def test_publish_streaming_response_coerces_string_uuid(mock_topic: MagicMock): + workflow_run_id = uuid.uuid4() + response_stream = iter([{"event": "bar"}]) + + _publish_streaming_response(response_stream, str(workflow_run_id), app_mode=AppMode.ADVANCED_CHAT) + + mock_topic.publish.assert_called_once_with(json.dumps({"event": "bar"}).encode()) diff --git a/api/tests/unit_tests/tasks/test_workflow_node_execution_tasks.py b/api/tests/unit_tests/tasks/test_workflow_node_execution_tasks.py new file mode 100644 index 0000000000..fd5f0713a4 --- /dev/null +++ b/api/tests/unit_tests/tasks/test_workflow_node_execution_tasks.py @@ -0,0 +1,488 @@ +# """ +# Unit tests for workflow node execution Celery tasks. + +# These tests verify the asynchronous storage functionality for workflow node execution data, +# including truncation and offloading logic. 
+# """ + +# import json +# from unittest.mock import MagicMock, Mock, patch +# from uuid import uuid4 + +# import pytest + +# from core.workflow.entities.workflow_node_execution import ( +# WorkflowNodeExecution, +# WorkflowNodeExecutionStatus, +# ) +# from core.workflow.enums import NodeType +# from libs.datetime_utils import naive_utc_now +# from models import WorkflowNodeExecutionModel +# from models.enums import ExecutionOffLoadType +# from models.model import UploadFile +# from models.workflow import WorkflowNodeExecutionOffload, WorkflowNodeExecutionTriggeredFrom +# from tasks.workflow_node_execution_tasks import ( +# _create_truncator, +# _json_encode, +# _replace_or_append_offload, +# _truncate_and_upload_async, +# save_workflow_node_execution_data_task, +# save_workflow_node_execution_task, +# ) + + +# @pytest.fixture +# def sample_execution_data(): +# """Sample execution data for testing.""" +# execution = WorkflowNodeExecution( +# id=str(uuid4()), +# node_execution_id=str(uuid4()), +# workflow_id=str(uuid4()), +# workflow_execution_id=str(uuid4()), +# index=1, +# node_id="test_node", +# node_type=NodeType.LLM, +# title="Test Node", +# inputs={"input_key": "input_value"}, +# outputs={"output_key": "output_value"}, +# process_data={"process_key": "process_value"}, +# status=WorkflowNodeExecutionStatus.RUNNING, +# created_at=naive_utc_now(), +# ) +# return execution.model_dump() + + +# @pytest.fixture +# def mock_db_model(): +# """Mock database model for testing.""" +# db_model = Mock(spec=WorkflowNodeExecutionModel) +# db_model.id = "test-execution-id" +# db_model.offload_data = [] +# return db_model + + +# @pytest.fixture +# def mock_file_service(): +# """Mock file service for testing.""" +# file_service = Mock() +# mock_upload_file = Mock(spec=UploadFile) +# mock_upload_file.id = "mock-file-id" +# file_service.upload_file.return_value = mock_upload_file +# return file_service + + +# class TestSaveWorkflowNodeExecutionDataTask: +# """Test cases for 
save_workflow_node_execution_data_task.""" + +# @patch("tasks.workflow_node_execution_tasks.sessionmaker") +# @patch("tasks.workflow_node_execution_tasks.select") +# def test_save_execution_data_task_success( +# self, mock_select, mock_sessionmaker, sample_execution_data, mock_db_model +# ): +# """Test successful execution of save_workflow_node_execution_data_task.""" +# # Setup mocks +# mock_session = MagicMock() +# mock_sessionmaker.return_value.return_value.__enter__.return_value = mock_session +# mock_session.execute.return_value.scalars.return_value.first.return_value = mock_db_model + +# # Execute task +# result = save_workflow_node_execution_data_task( +# execution_data=sample_execution_data, +# tenant_id="test-tenant-id", +# app_id="test-app-id", +# user_data={"user_id": "test-user-id", "user_type": "account"}, +# ) + +# # Verify success +# assert result is True +# mock_session.merge.assert_called_once_with(mock_db_model) +# mock_session.commit.assert_called_once() + +# @patch("tasks.workflow_node_execution_tasks.sessionmaker") +# @patch("tasks.workflow_node_execution_tasks.select") +# def test_save_execution_data_task_execution_not_found(self, mock_select, mock_sessionmaker, +# sample_execution_data): +# """Test task when execution is not found in database.""" +# # Setup mocks +# mock_session = MagicMock() +# mock_sessionmaker.return_value.return_value.__enter__.return_value = mock_session +# mock_session.execute.return_value.scalars.return_value.first.return_value = None + +# # Execute task +# result = save_workflow_node_execution_data_task( +# execution_data=sample_execution_data, +# tenant_id="test-tenant-id", +# app_id="test-app-id", +# user_data={"user_id": "test-user-id", "user_type": "account"}, +# ) + +# # Verify failure +# assert result is False +# mock_session.merge.assert_not_called() +# mock_session.commit.assert_not_called() + +# @patch("tasks.workflow_node_execution_tasks.sessionmaker") +# @patch("tasks.workflow_node_execution_tasks.select") 
+# def test_save_execution_data_task_with_truncation(self, mock_select, mock_sessionmaker, mock_db_model): +# """Test task with data that requires truncation.""" +# # Create execution with large data +# large_data = {"large_field": "x" * 10000} +# execution = WorkflowNodeExecution( +# id=str(uuid4()), +# node_execution_id=str(uuid4()), +# workflow_id=str(uuid4()), +# workflow_execution_id=str(uuid4()), +# index=1, +# node_id="test_node", +# node_type=NodeType.LLM, +# title="Test Node", +# inputs=large_data, +# outputs=large_data, +# process_data=large_data, +# status=WorkflowNodeExecutionStatus.RUNNING, +# created_at=naive_utc_now(), +# ) +# execution_data = execution.model_dump() + +# # Setup mocks +# mock_session = MagicMock() +# mock_sessionmaker.return_value.return_value.__enter__.return_value = mock_session +# mock_session.execute.return_value.scalars.return_value.first.return_value = mock_db_model + +# # Create mock upload file +# mock_upload_file = Mock(spec=UploadFile) +# mock_upload_file.id = "mock-file-id" + +# # Execute task +# with patch("tasks.workflow_node_execution_tasks._truncate_and_upload_async") as mock_truncate: +# # Mock truncation results +# mock_truncate.return_value = { +# "truncated_value": {"large_field": "[TRUNCATED]"}, +# "file": mock_upload_file, +# "offload": WorkflowNodeExecutionOffload( +# id=str(uuid4()), +# tenant_id="test-tenant-id", +# app_id="test-app-id", +# node_execution_id=execution.id, +# type_=ExecutionOffLoadType.INPUTS, +# file_id=mock_upload_file.id, +# ), +# } + +# result = save_workflow_node_execution_data_task( +# execution_data=execution_data, +# tenant_id="test-tenant-id", +# app_id="test-app-id", +# user_data={"user_id": "test-user-id", "user_type": "account"}, +# ) + +# # Verify success and truncation was called +# assert result is True +# assert mock_truncate.call_count == 3 # inputs, outputs, process_data +# mock_session.merge.assert_called_once_with(mock_db_model) +# mock_session.commit.assert_called_once() + 
+# @patch("tasks.workflow_node_execution_tasks.sessionmaker") +# def test_save_execution_data_task_retry_on_exception(self, mock_sessionmaker, sample_execution_data): +# """Test task retry mechanism on exception.""" +# # Setup mock to raise exception +# mock_sessionmaker.side_effect = Exception("Database error") + +# # Create a mock task instance with proper retry behavior +# with patch.object(save_workflow_node_execution_data_task, "retry") as mock_retry: +# mock_retry.side_effect = Exception("Retry called") + +# # Execute task and expect retry +# with pytest.raises(Exception, match="Retry called"): +# save_workflow_node_execution_data_task( +# execution_data=sample_execution_data, +# tenant_id="test-tenant-id", +# app_id="test-app-id", +# user_data={"user_id": "test-user-id", "user_type": "account"}, +# ) + +# # Verify retry was called +# mock_retry.assert_called_once() + + +# class TestTruncateAndUploadAsync: +# """Test cases for _truncate_and_upload_async function.""" + +# def test_truncate_and_upload_with_none_values(self, mock_file_service): +# """Test _truncate_and_upload_async with None values.""" +# # The function handles None values internally, so we test with empty dict instead +# result = _truncate_and_upload_async( +# values={}, +# execution_id="test-id", +# type_=ExecutionOffLoadType.INPUTS, +# tenant_id="test-tenant", +# app_id="test-app", +# user_data={"user_id": "test-user", "user_type": "account"}, +# file_service=mock_file_service, +# ) + +# # Empty dict should not require truncation +# assert result is None +# mock_file_service.upload_file.assert_not_called() + +# @patch("tasks.workflow_node_execution_tasks._create_truncator") +# def test_truncate_and_upload_no_truncation_needed(self, mock_create_truncator, mock_file_service): +# """Test _truncate_and_upload_async when no truncation is needed.""" +# # Mock truncator to return no truncation +# mock_truncator = Mock() +# mock_truncator.truncate_variable_mapping.return_value = ({"small": "data"}, 
False) +# mock_create_truncator.return_value = mock_truncator + +# small_values = {"small": "data"} +# result = _truncate_and_upload_async( +# values=small_values, +# execution_id="test-id", +# type_=ExecutionOffLoadType.INPUTS, +# tenant_id="test-tenant", +# app_id="test-app", +# user_data={"user_id": "test-user", "user_type": "account"}, +# file_service=mock_file_service, +# ) + +# assert result is None +# mock_file_service.upload_file.assert_not_called() + +# @patch("tasks.workflow_node_execution_tasks._create_truncator") +# @patch("models.Account") +# @patch("models.Tenant") +# def test_truncate_and_upload_with_account_user( +# self, mock_tenant_class, mock_account_class, mock_create_truncator, mock_file_service +# ): +# """Test _truncate_and_upload_async with account user.""" +# # Mock truncator to return truncation needed +# mock_truncator = Mock() +# mock_truncator.truncate_variable_mapping.return_value = ({"truncated": "data"}, True) +# mock_create_truncator.return_value = mock_truncator + +# # Mock user and tenant creation +# mock_account = Mock() +# mock_account.id = "test-user" +# mock_account_class.return_value = mock_account + +# mock_tenant = Mock() +# mock_tenant.id = "test-tenant" +# mock_tenant_class.return_value = mock_tenant + +# large_values = {"large": "x" * 10000} +# result = _truncate_and_upload_async( +# values=large_values, +# execution_id="test-id", +# type_=ExecutionOffLoadType.INPUTS, +# tenant_id="test-tenant", +# app_id="test-app", +# user_data={"user_id": "test-user", "user_type": "account"}, +# file_service=mock_file_service, +# ) + +# # Verify result structure +# assert result is not None +# assert "truncated_value" in result +# assert "file" in result +# assert "offload" in result +# assert result["truncated_value"] == {"truncated": "data"} + +# # Verify file upload was called +# mock_file_service.upload_file.assert_called_once() +# upload_call = mock_file_service.upload_file.call_args +# assert upload_call[1]["filename"] == 
"node_execution_test-id_inputs.json" +# assert upload_call[1]["mimetype"] == "application/json" +# assert upload_call[1]["user"] == mock_account + +# @patch("tasks.workflow_node_execution_tasks._create_truncator") +# @patch("models.EndUser") +# def test_truncate_and_upload_with_end_user(self, mock_end_user_class, mock_create_truncator, mock_file_service): +# """Test _truncate_and_upload_async with end user.""" +# # Mock truncator to return truncation needed +# mock_truncator = Mock() +# mock_truncator.truncate_variable_mapping.return_value = ({"truncated": "data"}, True) +# mock_create_truncator.return_value = mock_truncator + +# # Mock end user creation +# mock_end_user = Mock() +# mock_end_user.id = "test-user" +# mock_end_user.tenant_id = "test-tenant" +# mock_end_user_class.return_value = mock_end_user + +# large_values = {"large": "x" * 10000} +# result = _truncate_and_upload_async( +# values=large_values, +# execution_id="test-id", +# type_=ExecutionOffLoadType.OUTPUTS, +# tenant_id="test-tenant", +# app_id="test-app", +# user_data={"user_id": "test-user", "user_type": "end_user"}, +# file_service=mock_file_service, +# ) + +# # Verify result structure +# assert result is not None +# assert result["truncated_value"] == {"truncated": "data"} + +# # Verify file upload was called with end user +# mock_file_service.upload_file.assert_called_once() +# upload_call = mock_file_service.upload_file.call_args +# assert upload_call[1]["filename"] == "node_execution_test-id_outputs.json" +# assert upload_call[1]["user"] == mock_end_user + + +# class TestHelperFunctions: +# """Test cases for helper functions.""" + +# @patch("tasks.workflow_node_execution_tasks.dify_config") +# def test_create_truncator(self, mock_config): +# """Test _create_truncator function.""" +# mock_config.WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE = 1000 +# mock_config.WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH = 100 +# mock_config.WORKFLOW_VARIABLE_TRUNCATION_STRING_LENGTH = 500 + +# truncator = 
_create_truncator() + +# # Verify truncator was created with correct config +# assert truncator is not None + +# def test_json_encode(self): +# """Test _json_encode function.""" +# test_data = {"key": "value", "number": 42} +# result = _json_encode(test_data) + +# assert isinstance(result, str) +# decoded = json.loads(result) +# assert decoded == test_data + +# def test_replace_or_append_offload_replace_existing(self): +# """Test _replace_or_append_offload replaces existing offload of same type.""" +# existing_offload = WorkflowNodeExecutionOffload( +# id=str(uuid4()), +# tenant_id="test-tenant", +# app_id="test-app", +# node_execution_id="test-execution", +# type_=ExecutionOffLoadType.INPUTS, +# file_id="old-file-id", +# ) + +# new_offload = WorkflowNodeExecutionOffload( +# id=str(uuid4()), +# tenant_id="test-tenant", +# app_id="test-app", +# node_execution_id="test-execution", +# type_=ExecutionOffLoadType.INPUTS, +# file_id="new-file-id", +# ) + +# result = _replace_or_append_offload([existing_offload], new_offload) + +# assert len(result) == 1 +# assert result[0].file_id == "new-file-id" + +# def test_replace_or_append_offload_append_new_type(self): +# """Test _replace_or_append_offload appends new offload of different type.""" +# existing_offload = WorkflowNodeExecutionOffload( +# id=str(uuid4()), +# tenant_id="test-tenant", +# app_id="test-app", +# node_execution_id="test-execution", +# type_=ExecutionOffLoadType.INPUTS, +# file_id="inputs-file-id", +# ) + +# new_offload = WorkflowNodeExecutionOffload( +# id=str(uuid4()), +# tenant_id="test-tenant", +# app_id="test-app", +# node_execution_id="test-execution", +# type_=ExecutionOffLoadType.OUTPUTS, +# file_id="outputs-file-id", +# ) + +# result = _replace_or_append_offload([existing_offload], new_offload) + +# assert len(result) == 2 +# file_ids = [offload.file_id for offload in result] +# assert "inputs-file-id" in file_ids +# assert "outputs-file-id" in file_ids + + +# class 
TestSaveWorkflowNodeExecutionTask: +# """Test cases for save_workflow_node_execution_task.""" + +# @patch("tasks.workflow_node_execution_tasks.sessionmaker") +# @patch("tasks.workflow_node_execution_tasks.select") +# def test_save_workflow_node_execution_task_create_new(self, mock_select, mock_sessionmaker, +# sample_execution_data): +# """Test creating a new workflow node execution.""" +# # Setup mocks +# mock_session = MagicMock() +# mock_sessionmaker.return_value.return_value.__enter__.return_value = mock_session +# mock_session.scalar.return_value = None # No existing execution + +# # Execute task +# result = save_workflow_node_execution_task( +# execution_data=sample_execution_data, +# tenant_id="test-tenant-id", +# app_id="test-app-id", +# triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, +# creator_user_id="test-user-id", +# creator_user_role="account", +# ) + +# # Verify success +# assert result is True +# mock_session.add.assert_called_once() +# mock_session.commit.assert_called_once() + +# @patch("tasks.workflow_node_execution_tasks.sessionmaker") +# @patch("tasks.workflow_node_execution_tasks.select") +# def test_save_workflow_node_execution_task_update_existing( +# self, mock_select, mock_sessionmaker, sample_execution_data +# ): +# """Test updating an existing workflow node execution.""" +# # Setup mocks +# mock_session = MagicMock() +# mock_sessionmaker.return_value.return_value.__enter__.return_value = mock_session + +# existing_execution = Mock(spec=WorkflowNodeExecutionModel) +# mock_session.scalar.return_value = existing_execution + +# # Execute task +# result = save_workflow_node_execution_task( +# execution_data=sample_execution_data, +# tenant_id="test-tenant-id", +# app_id="test-app-id", +# triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, +# creator_user_id="test-user-id", +# creator_user_role="account", +# ) + +# # Verify success +# assert result is True +# mock_session.add.assert_not_called() # Should 
not add new, just update existing +# mock_session.commit.assert_called_once() + +# @patch("tasks.workflow_node_execution_tasks.sessionmaker") +# def test_save_workflow_node_execution_task_retry_on_exception(self, mock_sessionmaker, sample_execution_data): +# """Test task retry mechanism on exception.""" +# # Setup mock to raise exception +# mock_sessionmaker.side_effect = Exception("Database error") + +# # Create a mock task instance with proper retry behavior +# with patch.object(save_workflow_node_execution_task, "retry") as mock_retry: +# mock_retry.side_effect = Exception("Retry called") + +# # Execute task and expect retry +# with pytest.raises(Exception, match="Retry called"): +# save_workflow_node_execution_task( +# execution_data=sample_execution_data, +# tenant_id="test-tenant-id", +# app_id="test-app-id", +# triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, +# creator_user_id="test-user-id", +# creator_user_role="account", +# ) + +# # Verify retry was called +# mock_retry.assert_called_once() diff --git a/api/ty.toml b/api/ty.toml index bb4ff5bbcf..6869ca98c4 100644 --- a/api/ty.toml +++ b/api/ty.toml @@ -1,16 +1,45 @@ [src] exclude = [ - # TODO: enable when violations fixed + # deps groups (A1/A2/B/C/D/E) + # A2: workflow engine/nodes + "core/workflow", + "core/app/workflow", + "core/helper/code_executor", + # B: app runner + prompt + "core/prompt", + "core/app/apps/base_app_runner.py", "core/app/apps/workflow_app_runner.py", + # C: services/controllers/fields/libs + "services", "controllers/console/app", "controllers/console/explore", "controllers/console/datasets", "controllers/console/workspace", + "controllers/service_api/wraps.py", + "fields/conversation_fields.py", + "libs/external_api.py", + # D: observability + integrations + "core/ops", + "extensions", + # E: vector DB integrations + "core/rag/datasource/vdb", # non-producition or generated code "migrations", "tests", + # targeted ignores for current type-check errors + # 
TODO(QuantumGhost): suppress type errors in HITL related code. + # fix the type error later + "configs/middleware/cache/redis_pubsub_config.py", + "extensions/ext_redis.py", + "models/execution_extra_content.py", + "tasks/workflow_execution_tasks.py", + "core/workflow/nodes/base/node.py", + "services/human_input_delivery_test_service.py", + "core/app/apps/advanced_chat/app_generator.py", + "controllers/console/human_input_form.py", + "controllers/console/app/workflow_run.py", + "repositories/sqlalchemy_api_workflow_node_execution_repository.py", + "extensions/logstore/repositories/logstore_api_workflow_run_repository.py", + "controllers/web/workflow_events.py", + "tasks/app_generate/workflow_execute_task.py", ] - -[rules] -missing-argument = "ignore" # TODO: restore when **args for constructor is supported properly -possibly-unbound-attribute = "ignore" diff --git a/api/uv.lock b/api/uv.lock index 7808c16a8c..7bb43fbb12 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1368,7 +1368,7 @@ wheels = [ [[package]] name = "dify-api" -version = "1.11.4" +version = "1.12.0" source = { virtual = "." 
} dependencies = [ { name = "aliyun-log-python-sdk" }, @@ -1479,6 +1479,7 @@ dev = [ { name = "pytest-env" }, { name = "pytest-mock" }, { name = "pytest-timeout" }, + { name = "pytest-xdist" }, { name = "ruff" }, { name = "scipy-stubs" }, { name = "sseclient-py" }, @@ -1678,6 +1679,7 @@ dev = [ { name = "pytest-env", specifier = "~=1.1.3" }, { name = "pytest-mock", specifier = "~=3.14.0" }, { name = "pytest-timeout", specifier = ">=2.4.0" }, + { name = "pytest-xdist", specifier = ">=3.8.0" }, { name = "ruff", specifier = "~=0.14.0" }, { name = "scipy-stubs", specifier = ">=1.15.3.0" }, { name = "sseclient-py", specifier = ">=1.8.0" }, @@ -1896,6 +1898,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/19/d8/2a1c638d9e0aa7e269269a1a1bf423ddd94267f1a01bbe3ad03432b67dd4/eval_type_backport-0.3.0-py3-none-any.whl", hash = "sha256:975a10a0fe333c8b6260d7fdb637698c9a16c3a9e3b6eb943fee6a6f67a37fe8", size = 6061, upload-time = "2025-11-13T20:56:49.499Z" }, ] +[[package]] +name = "execnet" +version = "2.1.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/bf/89/780e11f9588d9e7128a3f87788354c7946a9cbb1401ad38a48c4db9a4f07/execnet-2.1.2.tar.gz", hash = "sha256:63d83bfdd9a23e35b9c6a3261412324f964c2ec8dcd8d3c6916ee9373e0befcd", size = 166622, upload-time = "2025-11-12T09:56:37.75Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ab/84/02fc1827e8cdded4aa65baef11296a9bbe595c474f0d6d758af082d849fd/execnet-2.1.2-py3-none-any.whl", hash = "sha256:67fba928dd5a544b783f6056f449e5e3931a5c378b128bc18501f7ea79e296ec", size = 40708, upload-time = "2025-11-12T09:56:36.333Z" }, +] + [[package]] name = "faker" version = "38.2.0" @@ -5141,6 +5152,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fa/b6/3127540ecdf1464a00e5a01ee60a1b09175f6913f0644ac748494d9c4b21/pytest_timeout-2.4.0-py3-none-any.whl", hash = "sha256:c42667e5cdadb151aeb5b26d114aff6bdf5a907f176a007a30b940d3d865b5c2", size = 
14382, upload-time = "2025-05-05T19:44:33.502Z" }, ] +[[package]] +name = "pytest-xdist" +version = "3.8.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "execnet" }, + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/78/b4/439b179d1ff526791eb921115fca8e44e596a13efeda518b9d845a619450/pytest_xdist-3.8.0.tar.gz", hash = "sha256:7e578125ec9bc6050861aa93f2d59f1d8d085595d6551c2c90b6f4fad8d3a9f1", size = 88069, upload-time = "2025-07-01T13:30:59.346Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ca/31/d4e37e9e550c2b92a9cbc2e4d0b7420a27224968580b5a447f420847c975/pytest_xdist-3.8.0-py3-none-any.whl", hash = "sha256:202ca578cfeb7370784a8c33d6d05bc6e13b4f25b5053c30a152269fd10f0b88", size = 46396, upload-time = "2025-07-01T13:30:56.632Z" }, +] + [[package]] name = "python-calamine" version = "0.5.4" diff --git a/dev/pytest/pytest_unit_tests.sh b/dev/pytest/pytest_unit_tests.sh index 496cb40952..7c39a48bf4 100755 --- a/dev/pytest/pytest_unit_tests.sh +++ b/dev/pytest/pytest_unit_tests.sh @@ -5,6 +5,12 @@ SCRIPT_DIR="$(dirname "$(realpath "$0")")" cd "$SCRIPT_DIR/../.." PYTEST_TIMEOUT="${PYTEST_TIMEOUT:-20}" +PYTEST_XDIST_ARGS="${PYTEST_XDIST_ARGS:--n auto}" -# libs -pytest --timeout "${PYTEST_TIMEOUT}" api/tests/unit_tests +# Run most tests in parallel (excluding controllers which have import conflicts with xdist) +# Controller tests have module-level side effects (Flask route registration) that cause +# race conditions when imported concurrently by multiple pytest-xdist workers. 
+pytest --timeout "${PYTEST_TIMEOUT}" ${PYTEST_XDIST_ARGS} api/tests/unit_tests --ignore=api/tests/unit_tests/controllers + +# Run controller tests sequentially to avoid import race conditions +pytest --timeout "${PYTEST_TIMEOUT}" api/tests/unit_tests/controllers diff --git a/docker/.env.example b/docker/.env.example index b6c04fdb77..93099347bd 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -1375,6 +1375,7 @@ PLUGIN_DAEMON_PORT=5002 PLUGIN_DAEMON_KEY=lYkiYYT6owG+71oLerGzA7GXCgOT++6ovaezWAjpCjf+Sjc3ZtU+qUEi PLUGIN_DAEMON_URL=http://plugin_daemon:5002 PLUGIN_MAX_PACKAGE_SIZE=52428800 +PLUGIN_MODEL_SCHEMA_CACHE_TTL=3600 PLUGIN_PPROF_ENABLED=false PLUGIN_DEBUGGING_HOST=0.0.0.0 @@ -1398,9 +1399,9 @@ PLUGIN_STDIO_BUFFER_SIZE=1024 PLUGIN_STDIO_MAX_BUFFER_SIZE=5242880 PLUGIN_PYTHON_ENV_INIT_TIMEOUT=120 -# Plugin Daemon side timeout (configure to match the API side below) +# Plugin Daemon side timeout (configure to match the API side below) PLUGIN_MAX_EXECUTION_TIMEOUT=600 -# API side timeout (configure to match the Plugin Daemon side above) +# API side timeout (configure to match the Plugin Daemon side above) PLUGIN_DAEMON_TIMEOUT=600.0 # PIP_MIRROR_URL=https://pypi.tuna.tsinghua.edu.cn/simple PIP_MIRROR_URL= @@ -1518,4 +1519,31 @@ AMPLITUDE_API_KEY= SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21 SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000 SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30 + + +# Redis URL used for PubSub between API and +# celery worker +# defaults to url constructed from `REDIS_*` +# configurations +PUBSUB_REDIS_URL= +# Pub/sub channel type for streaming events. +# valid options are: +# +# - pubsub: for normal Pub/Sub +# - sharded: for sharded Pub/Sub +# +# It's highly recommended to use sharded Pub/Sub AND redis cluster +# for large deployments. +PUBSUB_REDIS_CHANNEL_TYPE=pubsub +# Whether to use Redis cluster mode while running +# PubSub. +# It's highly recommended to enable this for large deployments. 
+PUBSUB_REDIS_USE_CLUSTERS=false + +# Whether to enable human input timeout check task +ENABLE_HUMAN_INPUT_TIMEOUT_TASK=true +# Human input timeout check interval in minutes +HUMAN_INPUT_TIMEOUT_TASK_INTERVAL=1 + + SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL=90000 diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index 9659990383..e27b51bcc0 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -21,7 +21,7 @@ services: # API service api: - image: langgenius/dify-api:1.11.4 + image: langgenius/dify-api:1.12.0 restart: always environment: # Use the shared environment variables. @@ -63,7 +63,7 @@ services: # worker service # The Celery worker for processing all queues (dataset, workflow, mail, etc.) worker: - image: langgenius/dify-api:1.11.4 + image: langgenius/dify-api:1.12.0 restart: always environment: # Use the shared environment variables. @@ -102,7 +102,7 @@ services: # worker_beat service # Celery beat for scheduling periodic tasks. worker_beat: - image: langgenius/dify-api:1.11.4 + image: langgenius/dify-api:1.12.0 restart: always environment: # Use the shared environment variables. @@ -132,7 +132,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:1.11.4 + image: langgenius/dify-web:1.12.0 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} @@ -270,7 +270,7 @@ services: # plugin daemon plugin_daemon: - image: langgenius/dify-plugin-daemon:0.5.2-local + image: langgenius/dify-plugin-daemon:0.5.3-local restart: always environment: # Use the shared environment variables.
diff --git a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml index 81c34fc6a2..4a739bbbe0 100644 --- a/docker/docker-compose.middleware.yaml +++ b/docker/docker-compose.middleware.yaml @@ -123,7 +123,7 @@ services: # plugin daemon plugin_daemon: - image: langgenius/dify-plugin-daemon:0.5.2-local + image: langgenius/dify-plugin-daemon:0.5.3-local restart: always env_file: - ./middleware.env diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 902ca3103c..a5518ceee9 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -589,6 +589,7 @@ x-shared-env: &shared-api-worker-env PLUGIN_DAEMON_KEY: ${PLUGIN_DAEMON_KEY:-lYkiYYT6owG+71oLerGzA7GXCgOT++6ovaezWAjpCjf+Sjc3ZtU+qUEi} PLUGIN_DAEMON_URL: ${PLUGIN_DAEMON_URL:-http://plugin_daemon:5002} PLUGIN_MAX_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800} + PLUGIN_MODEL_SCHEMA_CACHE_TTL: ${PLUGIN_MODEL_SCHEMA_CACHE_TTL:-3600} PLUGIN_PPROF_ENABLED: ${PLUGIN_PPROF_ENABLED:-false} PLUGIN_DEBUGGING_HOST: ${PLUGIN_DEBUGGING_HOST:-0.0.0.0} PLUGIN_DEBUGGING_PORT: ${PLUGIN_DEBUGGING_PORT:-5003} @@ -682,6 +683,11 @@ x-shared-env: &shared-api-worker-env SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD: ${SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD:-21} SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE: ${SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE:-1000} SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS: ${SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS:-30} + PUBSUB_REDIS_URL: ${PUBSUB_REDIS_URL:-} + PUBSUB_REDIS_CHANNEL_TYPE: ${PUBSUB_REDIS_CHANNEL_TYPE:-pubsub} + PUBSUB_REDIS_USE_CLUSTERS: ${PUBSUB_REDIS_USE_CLUSTERS:-false} + ENABLE_HUMAN_INPUT_TIMEOUT_TASK: ${ENABLE_HUMAN_INPUT_TIMEOUT_TASK:-true} + HUMAN_INPUT_TIMEOUT_TASK_INTERVAL: ${HUMAN_INPUT_TIMEOUT_TASK_INTERVAL:-1} SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL: ${SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL:-90000} services: @@ -706,7 +712,7 @@ services: # API service api: - image: langgenius/dify-api:1.11.4 + image: 
langgenius/dify-api:1.12.0 restart: always environment: # Use the shared environment variables. @@ -748,7 +754,7 @@ services: # worker service # The Celery worker for processing all queues (dataset, workflow, mail, etc.) worker: - image: langgenius/dify-api:1.11.4 + image: langgenius/dify-api:1.12.0 restart: always environment: # Use the shared environment variables. @@ -787,7 +793,7 @@ services: # worker_beat service # Celery beat for scheduling periodic tasks. worker_beat: - image: langgenius/dify-api:1.11.4 + image: langgenius/dify-api:1.12.0 restart: always environment: # Use the shared environment variables. @@ -817,7 +823,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:1.11.4 + image: langgenius/dify-web:1.12.0 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} @@ -955,7 +961,7 @@ services: # plugin daemon plugin_daemon: - image: langgenius/dify-plugin-daemon:0.5.2-local + image: langgenius/dify-plugin-daemon:0.5.3-local restart: always environment: # Use the shared environment variables. diff --git a/web/AGENTS.md b/web/AGENTS.md index 7362cd51db..5dd41b8a3c 100644 --- a/web/AGENTS.md +++ b/web/AGENTS.md @@ -1,5 +1,9 @@ +## Frontend Workflow + +- Refer to the `./docs/test.md` and `./docs/lint.md` for detailed frontend workflow instructions. + ## Automated Test Generation -- Use `web/testing/testing.md` as the canonical instruction set for generating frontend automated tests. +- Use `./docs/test.md` as the canonical instruction set for generating frontend automated tests. - When proposing or saving tests, re-read that document and follow every requirement. - All frontend tests MUST also comply with the `frontend-testing` skill. Treat the skill as a mandatory constraint, not optional guidance. 
diff --git a/web/README.md b/web/README.md index 9c731a081a..64039709dc 100644 --- a/web/README.md +++ b/web/README.md @@ -107,6 +107,8 @@ Open [http://localhost:6006](http://localhost:6006) with your browser to see the If your IDE is VSCode, rename `web/.vscode/settings.example.json` to `web/.vscode/settings.json` for lint code setting. +Then follow the [Lint Documentation](./docs/lint.md) to lint the code. + ## Test We use [Vitest](https://vitest.dev/) and [React Testing Library](https://testing-library.com/docs/react-testing-library/intro/) for Unit Testing. diff --git a/web/__mocks__/provider-context.ts b/web/__mocks__/provider-context.ts index 373c2f86d3..d3296bacd0 100644 --- a/web/__mocks__/provider-context.ts +++ b/web/__mocks__/provider-context.ts @@ -35,6 +35,7 @@ export const baseProviderContextValue: ProviderContextState = { refreshLicenseLimit: noop, isAllowTransferWorkspace: false, isAllowPublishAsCustomKnowledgePipelineTemplate: false, + humanInputEmailDeliveryEnabled: false, } export const createMockProviderContextValue = (overrides: Partial = {}): ProviderContextState => { diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/__tests__/svg-attribute-error-reproduction.spec.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/__tests__/svg-attribute-error-reproduction.spec.tsx index fc27f84c60..fffc1ff2a5 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/__tests__/svg-attribute-error-reproduction.spec.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/__tests__/svg-attribute-error-reproduction.spec.tsx @@ -8,7 +8,6 @@ describe('SVG Attribute Error Reproduction', () => { // Capture console errors const originalError = console.error let errorMessages: string[] = [] - beforeEach(() => { errorMessages = [] console.error = vi.fn((message) => { diff --git a/web/app/(humanInputLayout)/form/[token]/form.tsx 
b/web/app/(humanInputLayout)/form/[token]/form.tsx new file mode 100644 index 0000000000..d027ef8b7d --- /dev/null +++ b/web/app/(humanInputLayout)/form/[token]/form.tsx @@ -0,0 +1,289 @@ +'use client' +import type { ButtonProps } from '@/app/components/base/button' +import type { FormInputItem, UserAction } from '@/app/components/workflow/nodes/human-input/types' +import type { SiteInfo } from '@/models/share' +import type { HumanInputFormError } from '@/service/use-share' +import { + RiCheckboxCircleFill, + RiErrorWarningFill, + RiInformation2Fill, +} from '@remixicon/react' +import { produce } from 'immer' +import { useParams } from 'next/navigation' +import * as React from 'react' +import { useEffect, useMemo, useState } from 'react' +import { useTranslation } from 'react-i18next' +import AppIcon from '@/app/components/base/app-icon' +import Button from '@/app/components/base/button' +import ContentItem from '@/app/components/base/chat/chat/answer/human-input-content/content-item' +import ExpirationTime from '@/app/components/base/chat/chat/answer/human-input-content/expiration-time' +import { getButtonStyle } from '@/app/components/base/chat/chat/answer/human-input-content/utils' +import Loading from '@/app/components/base/loading' +import DifyLogo from '@/app/components/base/logo/dify-logo' +import useDocumentTitle from '@/hooks/use-document-title' +import { useGetHumanInputForm, useSubmitHumanInputForm } from '@/service/use-share' +import { cn } from '@/utils/classnames' + +export type FormData = { + site: { site: SiteInfo } + form_content: string + inputs: FormInputItem[] + resolved_default_values: Record + user_actions: UserAction[] + expiration_time: number +} + +const FormContent = () => { + const { t } = useTranslation() + + const { token } = useParams<{ token: string }>() + useDocumentTitle('') + + const [inputs, setInputs] = useState>({}) + const [success, setSuccess] = useState(false) + + const { mutate: submitForm, isPending: isSubmitting } = 
useSubmitHumanInputForm() + + const { data: formData, isLoading, error } = useGetHumanInputForm(token) + + const expired = (error as HumanInputFormError | null)?.code === 'human_input_form_expired' + const submitted = (error as HumanInputFormError | null)?.code === 'human_input_form_submitted' + const rateLimitExceeded = (error as HumanInputFormError | null)?.code === 'web_form_rate_limit_exceeded' + + const splitByOutputVar = (content: string): string[] => { + const outputVarRegex = /(\{\{#\$output\.[^#]+#\}\})/g + const parts = content.split(outputVarRegex) + return parts.filter(part => part.length > 0) + } + + const contentList = useMemo(() => { + if (!formData?.form_content) + return [] + return splitByOutputVar(formData.form_content) + }, [formData?.form_content]) + + useEffect(() => { + if (!formData?.inputs) + return + const initialInputs: Record = {} + formData.inputs.forEach((item) => { + initialInputs[item.output_variable_name] = item.default.type === 'variable' ? formData.resolved_default_values[item.output_variable_name] || '' : item.default.value + }) + setInputs(initialInputs) + }, [formData?.inputs, formData?.resolved_default_values]) + + // use immer + const handleInputsChange = (name: string, value: string) => { + const newInputs = produce(inputs, (draft) => { + draft[name] = value + }) + setInputs(newInputs) + } + + const submit = (actionID: string) => { + submitForm( + { token, data: { inputs, action: actionID } }, + { + onSuccess: () => { + setSuccess(true) + }, + }, + ) + } + + if (isLoading) { + return ( + + ) + } + + if (success) { + return ( +
+
+
+
+ +
+
+
{t('humanInput.thanks', { ns: 'share' })}
+
{t('humanInput.recorded', { ns: 'share' })}
+
+
{t('humanInput.submissionID', { id: token, ns: 'share' })}
+
+
+
+
{t('chat.poweredBy', { ns: 'share' })}
+ +
+
+
+
+ ) + } + + if (expired) { + return ( +
+
+
+
+ +
+
+
{t('humanInput.sorry', { ns: 'share' })}
+
{t('humanInput.expired', { ns: 'share' })}
+
+
{t('humanInput.submissionID', { id: token, ns: 'share' })}
+
+
+
+
{t('chat.poweredBy', { ns: 'share' })}
+ +
+
+
+
+ ) + } + + if (submitted) { + return ( +
+
+
+
+ +
+
+
{t('humanInput.sorry', { ns: 'share' })}
+
{t('humanInput.completed', { ns: 'share' })}
+
+
{t('humanInput.submissionID', { id: token, ns: 'share' })}
+
+
+
+
{t('chat.poweredBy', { ns: 'share' })}
+ +
+
+
+
+ ) + } + + if (rateLimitExceeded) { + return ( +
+
+
+
+ +
+
+
{t('humanInput.rateLimitExceeded', { ns: 'share' })}
+
+
+
+
+
{t('chat.poweredBy', { ns: 'share' })}
+ +
+
+
+
+ ) + } + + if (!formData) { + return ( +
+
+
+
+ +
+
+
{t('humanInput.formNotFound', { ns: 'share' })}
+
+
+
+
+
{t('chat.poweredBy', { ns: 'share' })}
+ +
+
+
+
+ ) + } + + const site = formData.site.site + + return ( +
+
+ +
{site.title}
+
+
+
+ {contentList.map((content, index) => ( + + ))} +
+ {formData.user_actions.map((action: UserAction) => ( + + ))} +
+ +
+
+
+
{t('chat.poweredBy', { ns: 'share' })}
+ +
+
+
+
+ ) +} + +export default React.memo(FormContent) diff --git a/web/app/(humanInputLayout)/form/[token]/page.tsx b/web/app/(humanInputLayout)/form/[token]/page.tsx new file mode 100644 index 0000000000..a7e2305b2b --- /dev/null +++ b/web/app/(humanInputLayout)/form/[token]/page.tsx @@ -0,0 +1,13 @@ +'use client' +import * as React from 'react' +import FormContent from './form' + +const FormPage = () => { + return ( +
+ +
+ ) +} + +export default React.memo(FormPage) diff --git a/web/app/(shareLayout)/components/authenticated-layout.tsx b/web/app/(shareLayout)/components/authenticated-layout.tsx index 113f3b5680..c874990448 100644 --- a/web/app/(shareLayout)/components/authenticated-layout.tsx +++ b/web/app/(shareLayout)/components/authenticated-layout.tsx @@ -47,7 +47,7 @@ const AuthenticatedLayout = ({ children }: { children: React.ReactNode }) => { await webAppLogout(shareCode!) const url = getSigninUrl() router.replace(url) - }, [getSigninUrl, router, webAppLogout, shareCode]) + }, [getSigninUrl, router, shareCode]) if (appInfoError) { return ( diff --git a/web/app/(shareLayout)/components/splash.tsx b/web/app/(shareLayout)/components/splash.tsx index 9f89a03993..a2b847f74f 100644 --- a/web/app/(shareLayout)/components/splash.tsx +++ b/web/app/(shareLayout)/components/splash.tsx @@ -31,7 +31,7 @@ const Splash: FC = ({ children }) => { await webAppLogout(shareCode!) const url = getSigninUrl() router.replace(url) - }, [getSigninUrl, router, webAppLogout, shareCode]) + }, [getSigninUrl, router, shareCode]) const [isLoading, setIsLoading] = useState(true) useEffect(() => { diff --git a/web/app/components/app-sidebar/app-info.tsx b/web/app/components/app-sidebar/app-info.tsx index 255feaccdf..aa31f0201f 100644 --- a/web/app/components/app-sidebar/app-info.tsx +++ b/web/app/components/app-sidebar/app-info.tsx @@ -31,6 +31,7 @@ import { fetchWorkflowDraft } from '@/service/workflow' import { AppModeEnum } from '@/types/app' import { getRedirection } from '@/utils/app-redirection' import { cn } from '@/utils/classnames' +import { downloadBlob } from '@/utils/download' import AppIcon from '../base/app-icon' import AppOperations from './app-operations' @@ -145,13 +146,8 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx appID: appDetail.id, include, }) - const a = document.createElement('a') const file = new Blob([data], { type: 'application/yaml' }) - 
const url = URL.createObjectURL(file) - a.href = url - a.download = `${appDetail.name}.yml` - a.click() - URL.revokeObjectURL(url) + downloadBlob({ data: file, fileName: `${appDetail.name}.yml` }) } catch { notify({ type: 'error', message: t('exportFailed', { ns: 'app' }) }) diff --git a/web/app/components/app-sidebar/dataset-info/dropdown.tsx b/web/app/components/app-sidebar/dataset-info/dropdown.tsx index 4d7c832e04..96127c4210 100644 --- a/web/app/components/app-sidebar/dataset-info/dropdown.tsx +++ b/web/app/components/app-sidebar/dataset-info/dropdown.tsx @@ -11,6 +11,7 @@ import { datasetDetailQueryKeyPrefix, useInvalidDatasetList } from '@/service/kn import { useInvalid } from '@/service/use-base' import { useExportPipelineDSL } from '@/service/use-pipeline' import { cn } from '@/utils/classnames' +import { downloadBlob } from '@/utils/download' import ActionButton from '../../base/action-button' import Confirm from '../../base/confirm' import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '../../base/portal-to-follow-elem' @@ -64,13 +65,8 @@ const DropDown = ({ pipelineId: pipeline_id, include, }) - const a = document.createElement('a') const file = new Blob([data], { type: 'application/yaml' }) - const url = URL.createObjectURL(file) - a.href = url - a.download = `${name}.pipeline` - a.click() - URL.revokeObjectURL(url) + downloadBlob({ data: file, fileName: `${name}.pipeline` }) } catch { Toast.notify({ type: 'error', message: t('exportFailed', { ns: 'app' }) }) diff --git a/web/app/components/app-sidebar/toggle-button.tsx b/web/app/components/app-sidebar/toggle-button.tsx index a6bdee4f78..cbfbeee452 100644 --- a/web/app/components/app-sidebar/toggle-button.tsx +++ b/web/app/components/app-sidebar/toggle-button.tsx @@ -4,7 +4,7 @@ import { useTranslation } from 'react-i18next' import { cn } from '@/utils/classnames' import Button from '../base/button' import Tooltip from '../base/tooltip' -import { 
getKeyboardKeyNameBySystem } from '../workflow/utils' +import ShortcutsName from '../workflow/shortcuts-name' type TooltipContentProps = { expand: boolean @@ -20,18 +20,7 @@ const TooltipContent = ({ return (
{expand ? t('sidebar.collapseSidebar', { ns: 'layout' }) : t('sidebar.expandSidebar', { ns: 'layout' })} -
- { - TOGGLE_SHORTCUT.map(key => ( - - {getKeyboardKeyNameBySystem(key)} - - )) - } -
+
) } diff --git a/web/app/components/app/annotation/header-opts/index.tsx b/web/app/components/app/annotation/header-opts/index.tsx index 5add1aed32..4fc1e26007 100644 --- a/web/app/components/app/annotation/header-opts/index.tsx +++ b/web/app/components/app/annotation/header-opts/index.tsx @@ -21,6 +21,7 @@ import { LanguagesSupported } from '@/i18n-config/language' import { clearAllAnnotations, fetchExportAnnotationList } from '@/service/annotation' import { cn } from '@/utils/classnames' +import { downloadBlob } from '@/utils/download' import Button from '../../../base/button' import AddAnnotationModal from '../add-annotation-modal' import BatchAddModal from '../batch-add-annotation-modal' @@ -56,28 +57,23 @@ const HeaderOptions: FC = ({ ) const JSONLOutput = () => { - const a = document.createElement('a') const content = listTransformer(list).join('\n') const file = new Blob([content], { type: 'application/jsonl' }) - const url = URL.createObjectURL(file) - a.href = url - a.download = `annotations-${locale}.jsonl` - a.click() - URL.revokeObjectURL(url) + downloadBlob({ data: file, fileName: `annotations-${locale}.jsonl` }) } - const fetchList = async () => { + const fetchList = React.useCallback(async () => { const { data }: any = await fetchExportAnnotationList(appId) setList(data as AnnotationItemBasic[]) - } + }, [appId]) useEffect(() => { fetchList() - }, []) + }, [fetchList]) useEffect(() => { if (controlUpdateList) fetchList() - }, [controlUpdateList]) + }, [controlUpdateList, fetchList]) const [showBulkImportModal, setShowBulkImportModal] = useState(false) const [showClearConfirm, setShowClearConfirm] = useState(false) diff --git a/web/app/components/app/app-publisher/index.tsx b/web/app/components/app/app-publisher/index.tsx index 0a026a680b..1348e3111f 100644 --- a/web/app/components/app/app-publisher/index.tsx +++ b/web/app/components/app/app-publisher/index.tsx @@ -49,7 +49,8 @@ import Divider from '../../base/divider' import Loading from 
'../../base/loading' import Toast from '../../base/toast' import Tooltip from '../../base/tooltip' -import { getKeyboardKeyCodeBySystem, getKeyboardKeyNameBySystem } from '../../workflow/utils' +import ShortcutsName from '../../workflow/shortcuts-name' +import { getKeyboardKeyCodeBySystem } from '../../workflow/utils' import AccessControl from '../app-access-control' import PublishWithMultipleModel from './publish-with-multiple-model' import SuggestedAction from './suggested-action' @@ -114,6 +115,7 @@ export type AppPublisherProps = { missingStartNode?: boolean hasTriggerNode?: boolean // Whether workflow currently contains any trigger nodes (used to hide missing-start CTA when triggers exist). startNodeLimitExceeded?: boolean + hasHumanInputNode?: boolean } const PUBLISH_SHORTCUT = ['ctrl', '⇧', 'P'] @@ -137,13 +139,14 @@ const AppPublisher = ({ missingStartNode = false, hasTriggerNode = false, startNodeLimitExceeded = false, + hasHumanInputNode = false, }: AppPublisherProps) => { const { t } = useTranslation() const [published, setPublished] = useState(false) const [open, setOpen] = useState(false) const [showAppAccessControl, setShowAppAccessControl] = useState(false) - const [isAppAccessSet, setIsAppAccessSet] = useState(true) + const [embeddingModalOpen, setEmbeddingModalOpen] = useState(false) const appDetail = useAppStore(state => state.appDetail) @@ -160,6 +163,13 @@ const AppPublisher = ({ const { data: appAccessSubjects, isLoading: isGettingAppWhiteListSubjects } = useAppWhiteListSubjects(appDetail?.id, open && systemFeatures.webapp_auth.enabled && appDetail?.access_mode === AccessMode.SPECIFIC_GROUPS_MEMBERS) const openAsyncWindow = useAsyncWindowOpen() + const isAppAccessSet = useMemo(() => { + if (appDetail && appAccessSubjects) { + return !(appDetail.access_mode === AccessMode.SPECIFIC_GROUPS_MEMBERS && appAccessSubjects.groups?.length === 0 && appAccessSubjects.members?.length === 0) + } + return true + }, [appAccessSubjects, appDetail]) + const 
noAccessPermission = useMemo(() => systemFeatures.webapp_auth.enabled && appDetail && appDetail.access_mode !== AccessMode.EXTERNAL_MEMBERS && !userCanAccessApp?.result, [systemFeatures, appDetail, userCanAccessApp]) const disabledFunctionButton = useMemo(() => (!publishedAt || missingStartNode || noAccessPermission), [publishedAt, missingStartNode, noAccessPermission]) @@ -170,25 +180,13 @@ const AppPublisher = ({ return t('noUserInputNode', { ns: 'app' }) if (noAccessPermission) return t('noAccessPermission', { ns: 'app' }) - }, [missingStartNode, noAccessPermission, publishedAt]) + }, [missingStartNode, noAccessPermission, publishedAt, t]) useEffect(() => { if (systemFeatures.webapp_auth.enabled && open && appDetail) refetch() }, [open, appDetail, refetch, systemFeatures]) - useEffect(() => { - if (appDetail && appAccessSubjects) { - if (appDetail.access_mode === AccessMode.SPECIFIC_GROUPS_MEMBERS && appAccessSubjects.groups?.length === 0 && appAccessSubjects.members?.length === 0) - setIsAppAccessSet(false) - else - setIsAppAccessSet(true) - } - else { - setIsAppAccessSet(true) - } - }, [appAccessSubjects, appDetail]) - const handlePublish = useCallback(async (params?: ModelAndParameter | PublishWorkflowParams) => { try { await onPublish?.(params) @@ -345,13 +343,7 @@ const AppPublisher = ({ : (
{t('common.publishUpdate', { ns: 'workflow' })} -
- {PUBLISH_SHORTCUT.map(key => ( - - {getKeyboardKeyNameBySystem(key)} - - ))} -
+
) } @@ -466,7 +458,7 @@ const AppPublisher = ({ {t('common.accessAPIReference', { ns: 'workflow' })} - {appDetail?.mode === AppModeEnum.WORKFLOW && ( + {appDetail?.mode === AppModeEnum.WORKFLOW && !hasHumanInputNode && ( { const saveButton = await screen.findByRole('button', { name: 'common.operation.save' }) fireEvent.click(saveButton) - expect(onPromptVariablesChange).toHaveBeenCalledTimes(1) + await waitFor(() => { + expect(onPromptVariablesChange).toHaveBeenCalledTimes(1) + }) }) it('should show error when variable key is duplicated', async () => { diff --git a/web/app/components/app/create-app-modal/index.tsx b/web/app/components/app/create-app-modal/index.tsx index e2b50cf030..66c7bce80c 100644 --- a/web/app/components/app/create-app-modal/index.tsx +++ b/web/app/components/app/create-app-modal/index.tsx @@ -1,7 +1,7 @@ 'use client' import type { AppIconSelection } from '../../base/app-icon-picker' -import { RiArrowRightLine, RiArrowRightSLine, RiCommandLine, RiCornerDownLeftLine, RiExchange2Fill } from '@remixicon/react' +import { RiArrowRightLine, RiArrowRightSLine, RiExchange2Fill } from '@remixicon/react' import { useDebounceFn, useKeyPress } from 'ahooks' import Image from 'next/image' @@ -29,6 +29,7 @@ import { getRedirection } from '@/utils/app-redirection' import { cn } from '@/utils/classnames' import { basePath } from '@/utils/var' import AppIconPicker from '../../base/app-icon-picker' +import ShortcutsName from '../../workflow/shortcuts-name' type CreateAppProps = { onSuccess: () => void @@ -269,10 +270,7 @@ function CreateApp({ onClose, onSuccess, onCreateFromTemplate, defaultAppMode }: diff --git a/web/app/components/app/create-from-dsl-modal/index.tsx b/web/app/components/app/create-from-dsl-modal/index.tsx index 838e9cc03f..04d8b1e754 100644 --- a/web/app/components/app/create-from-dsl-modal/index.tsx +++ b/web/app/components/app/create-from-dsl-modal/index.tsx @@ -1,7 +1,7 @@ 'use client' import type { MouseEventHandler } from 'react' -import 
{ RiCloseLine, RiCommandLine, RiCornerDownLeftLine } from '@remixicon/react' +import { RiCloseLine } from '@remixicon/react' import { useDebounceFn, useKeyPress } from 'ahooks' import { noop } from 'es-toolkit/function' import { useRouter } from 'next/navigation' @@ -28,6 +28,7 @@ import { } from '@/service/apps' import { getRedirection } from '@/utils/app-redirection' import { cn } from '@/utils/classnames' +import ShortcutsName from '../../workflow/shortcuts-name' import Uploader from './uploader' type CreateFromDSLModalProps = { @@ -298,10 +299,7 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS className="gap-1" > {t('newApp.Create', { ns: 'app' })} -
- - -
+ diff --git a/web/app/components/app/log/list.tsx b/web/app/components/app/log/list.tsx index b13eec2e3d..40519dcb36 100644 --- a/web/app/components/app/log/list.tsx +++ b/web/app/components/app/log/list.tsx @@ -68,6 +68,7 @@ type IDrawerContext = { } type StatusCount = { + paused: number success: number failed: number partial_success: number @@ -93,7 +94,15 @@ const statusTdRender = (statusCount: StatusCount) => { if (!statusCount) return null - if (statusCount.partial_success + statusCount.failed === 0) { + if (statusCount.paused > 0) { + return ( +
+ + Pending +
+ ) + } + else if (statusCount.partial_success + statusCount.failed === 0) { return (
@@ -296,7 +305,7 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) { if (abortControllerRef.current === controller) abortControllerRef.current = null } - }, [detail.id, hasMore, timezone, t, appDetail, detail?.model_config?.configs?.introduction]) + }, [detail.id, hasMore, timezone, t, appDetail]) // Derive chatItemTree, threadChatItems, and oldestAnswerIdRef from allChatItems useEffect(() => { @@ -411,7 +420,7 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) { notify({ type: 'error', message: t('actionMsg.modifiedUnsuccessfully', { ns: 'common' }) }) return false } - }, [allChatItems, appDetail?.id, t]) + }, [allChatItems, appDetail?.id, notify, t]) const fetchInitiated = useRef(false) @@ -504,7 +513,7 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) { finally { setIsLoading(false) } - }, [detail.id, hasMore, isLoading, timezone, t, appDetail, detail?.model_config?.configs?.introduction]) + }, [detail.id, hasMore, isLoading, timezone, t, appDetail]) const handleScroll = useCallback(() => { const scrollableDiv = document.getElementById('scrollableDiv') diff --git a/web/app/components/app/overview/apikey-info-panel/apikey-info-panel.test-utils.tsx b/web/app/components/app/overview/apikey-info-panel/apikey-info-panel.test-utils.tsx index 17857ec702..54763907df 100644 --- a/web/app/components/app/overview/apikey-info-panel/apikey-info-panel.test-utils.tsx +++ b/web/app/components/app/overview/apikey-info-panel/apikey-info-panel.test-utils.tsx @@ -53,6 +53,7 @@ const defaultProviderContext = { refreshLicenseLimit: noop, isAllowTransferWorkspace: false, isAllowPublishAsCustomKnowledgePipelineTemplate: false, + humanInputEmailDeliveryEnabled: false, } const defaultModalContext: ModalContextState = { diff --git a/web/app/components/app/text-generate/item/index.tsx b/web/app/components/app/text-generate/item/index.tsx index c39282a022..22358805a7 100644 --- a/web/app/components/app/text-generate/item/index.tsx +++ 
b/web/app/components/app/text-generate/item/index.tsx @@ -8,7 +8,7 @@ import { RiClipboardLine, RiFileList3Line, RiPlayList2Line, - RiReplay15Line, + RiResetLeftLine, RiSparklingFill, RiSparklingLine, RiThumbDownLine, @@ -18,10 +18,12 @@ import { useBoolean } from 'ahooks' import copy from 'copy-to-clipboard' import { useParams } from 'next/navigation' import * as React from 'react' -import { useEffect, useState } from 'react' +import { useCallback, useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' import { useStore as useAppStore } from '@/app/components/app/store' import ActionButton, { ActionButtonState } from '@/app/components/base/action-button' +import HumanInputFilledFormList from '@/app/components/base/chat/chat/answer/human-input-filled-form-list' +import HumanInputFormList from '@/app/components/base/chat/chat/answer/human-input-form-list' import WorkflowProcessItem from '@/app/components/base/chat/chat/answer/workflow-process' import { useChatContext } from '@/app/components/base/chat/chat/context' import Loading from '@/app/components/base/loading' @@ -29,7 +31,8 @@ import { Markdown } from '@/app/components/base/markdown' import NewAudioButton from '@/app/components/base/new-audio-button' import Toast from '@/app/components/base/toast' import { fetchTextGenerationMessage } from '@/service/debug' -import { AppSourceType, fetchMoreLikeThis, updateFeedback } from '@/service/share' +import { AppSourceType, fetchMoreLikeThis, submitHumanInputForm, updateFeedback } from '@/service/share' +import { submitHumanInputForm as submitHumanInputFormService } from '@/service/workflow' import { cn } from '@/utils/classnames' import ResultTab from './result-tab' @@ -121,7 +124,7 @@ const GenerationItem: FC = ({ const [isQuerying, { setTrue: startQuerying, setFalse: stopQuerying }] = useBoolean(false) const childProps = { - isInWebApp: true, + isInWebApp, content: completionRes, messageId: childMessageId, depth: depth + 1, @@ -202,16 
+205,22 @@ const GenerationItem: FC = ({ } const [currentTab, setCurrentTab] = useState('DETAIL') - const showResultTabs = !!workflowProcessData?.resultText || !!workflowProcessData?.files?.length + const showResultTabs = !!workflowProcessData?.resultText || !!workflowProcessData?.files?.length || (workflowProcessData?.humanInputFormDataList && workflowProcessData?.humanInputFormDataList.length > 0) || (workflowProcessData?.humanInputFilledFormDataList && workflowProcessData?.humanInputFilledFormDataList.length > 0) const switchTab = async (tab: string) => { setCurrentTab(tab) } useEffect(() => { - if (workflowProcessData?.resultText || !!workflowProcessData?.files?.length) + if (workflowProcessData?.resultText || !!workflowProcessData?.files?.length || (workflowProcessData?.humanInputFormDataList && workflowProcessData?.humanInputFormDataList.length > 0) || (workflowProcessData?.humanInputFilledFormDataList && workflowProcessData?.humanInputFilledFormDataList.length > 0)) switchTab('RESULT') else switchTab('DETAIL') - }, [workflowProcessData?.files?.length, workflowProcessData?.resultText]) + }, [workflowProcessData?.files?.length, workflowProcessData?.resultText, workflowProcessData?.humanInputFormDataList, workflowProcessData?.humanInputFilledFormDataList]) + const handleSubmitHumanInputForm = useCallback(async (formToken: string, formData: { inputs: Record, action: string }) => { + if (appSourceType === AppSourceType.installedApp) + await submitHumanInputFormService(formToken, formData) + else + await submitHumanInputForm(formToken, formData) + }, [appSourceType]) return ( <> @@ -275,7 +284,24 @@ const GenerationItem: FC = ({ )}
{!isError && ( - + <> + {currentTab === 'RESULT' && workflowProcessData.humanInputFormDataList && workflowProcessData.humanInputFormDataList.length > 0 && ( +
+ +
+ )} + {currentTab === 'RESULT' && workflowProcessData.humanInputFilledFormDataList && workflowProcessData.humanInputFilledFormDataList.length > 0 && ( +
+ +
+ )} + + )} )} @@ -348,7 +374,7 @@ const GenerationItem: FC = ({ )} {isInWebApp && isError && ( - + )} {isInWebApp && !isWorkflow && !isTryApp && ( diff --git a/web/app/components/app/workflow-log/list.tsx b/web/app/components/app/workflow-log/list.tsx index b9597c8ea1..262efad781 100644 --- a/web/app/components/app/workflow-log/list.tsx +++ b/web/app/components/app/workflow-log/list.tsx @@ -81,6 +81,14 @@ const WorkflowAppLogList: FC = ({ logs, appDetail, onRefresh }) => { ) } + if (status === 'paused') { + return ( +
+ + Pending +
+ ) + } if (status === 'running') { return (
diff --git a/web/app/components/apps/app-card.tsx b/web/app/components/apps/app-card.tsx index f1eadb9d05..730a39b68d 100644 --- a/web/app/components/apps/app-card.tsx +++ b/web/app/components/apps/app-card.tsx @@ -33,6 +33,7 @@ import { fetchWorkflowDraft } from '@/service/workflow' import { AppModeEnum } from '@/types/app' import { getRedirection } from '@/utils/app-redirection' import { cn } from '@/utils/classnames' +import { downloadBlob } from '@/utils/download' import { formatTime } from '@/utils/time' import { basePath } from '@/utils/var' @@ -161,13 +162,8 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { appID: app.id, include, }) - const a = document.createElement('a') const file = new Blob([data], { type: 'application/yaml' }) - const url = URL.createObjectURL(file) - a.href = url - a.download = `${app.name}.yml` - a.click() - URL.revokeObjectURL(url) + downloadBlob({ data: file, fileName: `${app.name}.yml` }) } catch { notify({ type: 'error', message: t('exportFailed', { ns: 'app' }) }) @@ -346,7 +342,7 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { dateFormat: `${t('segment.dateTimeFormat', { ns: 'datasetDocuments' })}`, }) return `${t('segment.editedAt', { ns: 'datasetDocuments' })} ${timeText}` - }, [app.updated_at, app.created_at]) + }, [app.updated_at, app.created_at, t]) return ( <> diff --git a/web/app/components/apps/index.tsx b/web/app/components/apps/index.tsx index 255bfbf9c5..3be8492489 100644 --- a/web/app/components/apps/index.tsx +++ b/web/app/components/apps/index.tsx @@ -105,6 +105,7 @@ const Apps = () => { {isShowTryAppPanel && ( { }, [appParams, currentConversationItem?.introduction]) const { chatList, - setTargetMessageId, handleSend, handleStop, + handleSwitchSibling, isResponding: respondingState, suggestedQuestions, } = useChat( @@ -122,8 +125,11 @@ const ChatWrapper = () => { if (fileIsUploading) return true + + if (chatList.some(item => item.isAnswer && item.humanInputFormDataList && 
item.humanInputFormDataList.length > 0)) + return true return false - }, [inputsFormValue, inputsForms, allInputsHidden]) + }, [allInputsHidden, inputsForms, chatList, inputsFormValue]) useEffect(() => { if (currentChatInstanceRef.current) @@ -134,6 +140,40 @@ const ChatWrapper = () => { setIsResponding(respondingState) }, [respondingState, setIsResponding]) + // Resume paused workflows when chat history is loaded + useEffect(() => { + if (!appPrevChatTree || appPrevChatTree.length === 0) + return + + // Find the last answer item with workflow_run_id that needs resumption (DFS - find deepest first) + let lastPausedNode: ChatItemInTree | undefined + const findLastPausedWorkflow = (nodes: ChatItemInTree[]) => { + nodes.forEach((node) => { + // DFS: recurse to children first + if (node.children && node.children.length > 0) + findLastPausedWorkflow(node.children) + + // Track the last node with humanInputFormDataList + if (node.isAnswer && node.workflow_run_id && node.humanInputFormDataList && node.humanInputFormDataList.length > 0) + lastPausedNode = node + }) + } + + findLastPausedWorkflow(appPrevChatTree) + + // Only resume the last paused workflow + if (lastPausedNode) { + handleSwitchSibling( + lastPausedNode.id, + { + onGetSuggestedQuestions: responseItemId => fetchSuggestedQuestions(responseItemId, appSourceType, appId), + onConversationComplete: currentConversationId ? undefined : handleNewConversationCompleted, + isPublicAPI: appSourceType === AppSourceType.webApp, + }, + ) + } + }, []) + const doSend: OnSend = useCallback((message, files, isRegenerate = false, parentAnswer: ChatItem | null = null) => { const data: any = { query: message, @@ -149,10 +189,10 @@ const ChatWrapper = () => { { onGetSuggestedQuestions: responseItemId => fetchSuggestedQuestions(responseItemId, appSourceType, appId), onConversationComplete: isHistoryConversation ? 
undefined : handleNewConversationCompleted, - isPublicAPI: !isInstalledApp, + isPublicAPI: appSourceType === AppSourceType.webApp, }, ) - }, [chatList, handleNewConversationCompleted, handleSend, currentConversationId, currentConversationInputs, newConversationInputs, isInstalledApp, appId]) + }, [inputsForms, currentConversationId, currentConversationInputs, newConversationInputs, chatList, handleSend, appSourceType, appId, isHistoryConversation, handleNewConversationCompleted]) const doRegenerate = useCallback((chatItem: ChatItem, editedQuestion?: { message: string, files?: FileEntity[] }) => { const question = editedQuestion ? chatItem : chatList.find(item => item.id === chatItem.parentMessageId)! @@ -160,12 +200,27 @@ const ChatWrapper = () => { doSend(editedQuestion ? editedQuestion.message : question.content, editedQuestion ? editedQuestion.files : question.message_files, true, isValidGeneratedAnswer(parentAnswer) ? parentAnswer : null) }, [chatList, doSend]) + const doSwitchSibling = useCallback((siblingMessageId: string) => { + handleSwitchSibling(siblingMessageId, { + onGetSuggestedQuestions: responseItemId => fetchSuggestedQuestions(responseItemId, appSourceType, appId), + onConversationComplete: currentConversationId ? 
undefined : handleNewConversationCompleted, + isPublicAPI: appSourceType === AppSourceType.webApp, + }) + }, [handleSwitchSibling, currentConversationId, handleNewConversationCompleted, appSourceType, appId]) + const messageList = useMemo(() => { if (currentConversationId || chatList.length > 1) return chatList // Without messages we are in the welcome screen, so hide the opening statement from chatlist return chatList.filter(item => !item.isOpeningStatement) - }, [chatList]) + }, [chatList, currentConversationId]) + + const handleSubmitHumanInputForm = useCallback(async (formToken: string, formData: any) => { + if (isInstalledApp) + await submitHumanInputFormService(formToken, formData) + else + await submitHumanInputForm(formToken, formData) + }, [isInstalledApp]) const [collapsed, setCollapsed] = useState(!!currentConversationId) @@ -274,6 +329,7 @@ const ChatWrapper = () => { inputsForm={inputsForms} onRegenerate={doRegenerate} onStopResponding={handleStop} + onHumanInputFormSubmit={handleSubmitHumanInputForm} chatNode={( <> {chatNode} @@ -286,7 +342,7 @@ const ChatWrapper = () => { answerIcon={answerIcon} hideProcessDetail themeBuilder={themeBuilder} - switchSibling={siblingMessageId => setTargetMessageId(siblingMessageId)} + switchSibling={doSwitchSibling} inputDisabled={inputDisabled} sidebarCollapseState={sidebarCollapseState} questionIcon={ diff --git a/web/app/components/base/chat/chat-with-history/hooks.tsx b/web/app/components/base/chat/chat-with-history/hooks.tsx index ad1de38d07..da344a9789 100644 --- a/web/app/components/base/chat/chat-with-history/hooks.tsx +++ b/web/app/components/base/chat/chat-with-history/hooks.tsx @@ -1,3 +1,4 @@ +import type { ExtraContent } from '../chat/type' import type { Callback, ChatConfig, @@ -9,6 +10,7 @@ import type { AppData, ConversationItem, } from '@/models/share' +import type { HumanInputFilledFormData, HumanInputFormData } from '@/types/workflow' import { useLocalStorageState } from 'ahooks' import { noop } from 
'es-toolkit/function' import { produce } from 'immer' @@ -57,6 +59,24 @@ function getFormattedChatList(messages: any[]) { parentMessageId: item.parent_message_id || undefined, }) const answerFiles = item.message_files?.filter((file: any) => file.belongs_to === 'assistant') || [] + const humanInputFormDataList: HumanInputFormData[] = [] + const humanInputFilledFormDataList: HumanInputFilledFormData[] = [] + let workflowRunId = '' + if (item.status === 'paused') { + item.extra_contents?.forEach((content: ExtraContent) => { + if (content.type === 'human_input' && !content.submitted) { + humanInputFormDataList.push(content.form_definition) + workflowRunId = content.workflow_run_id + } + }) + } + else if (item.status === 'normal') { + item.extra_contents?.forEach((content: ExtraContent) => { + if (content.type === 'human_input' && content.submitted) { + humanInputFilledFormDataList.push(content.form_submission_data) + } + }) + } newChatList.push({ id: item.id, content: item.answer, @@ -66,6 +86,9 @@ function getFormattedChatList(messages: any[]) { citation: item.retriever_resources, message_files: getProcessedFilesFromResponse(answerFiles.map((item: any) => ({ ...item, related_id: item.id, upload_file_id: item.upload_file_id }))), parentMessageId: `question-${item.id}`, + humanInputFormDataList, + humanInputFilledFormDataList, + workflow_run_id: workflowRunId, }) }) return newChatList diff --git a/web/app/components/base/chat/chat/answer/human-input-content/content-item.tsx b/web/app/components/base/chat/chat/answer/human-input-content/content-item.tsx new file mode 100644 index 0000000000..3ed777d41e --- /dev/null +++ b/web/app/components/base/chat/chat/answer/human-input-content/content-item.tsx @@ -0,0 +1,54 @@ +import type { ContentItemProps } from './type' +import * as React from 'react' +import { useMemo } from 'react' +import { Markdown } from '@/app/components/base/markdown' +import Textarea from '@/app/components/base/textarea' + +const ContentItem = ({ + 
content, + formInputFields, + inputs, + onInputChange, +}: ContentItemProps) => { + const isInputField = (field: string) => { + const outputVarRegex = /\{\{#\$output\.[^#]+#\}\}/ + return outputVarRegex.test(field) + } + + const extractFieldName = (str: string): string => { + const outputVarRegex = /\{\{#\$output\.([^#]+)#\}\}/ + const match = str.match(outputVarRegex) + return match ? match[1] : '' + } + + const fieldName = useMemo(() => { + return extractFieldName(content) + }, [content]) + + const formInputField = useMemo(() => { + return formInputFields.find(field => field.output_variable_name === fieldName) + }, [formInputFields, fieldName]) + + if (!isInputField(content)) { + return ( + + ) + } + + if (!formInputField) + return null + + return ( +
+ {formInputField.type === 'paragraph' && ( +