diff --git a/api/.env.example b/api/.env.example index 6cfe0266c2..f6f65011ea 100644 --- a/api/.env.example +++ b/api/.env.example @@ -659,6 +659,11 @@ INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y MARKETPLACE_ENABLED=true MARKETPLACE_API_URL=https://marketplace.dify.ai +# Creators Platform configuration +CREATORS_PLATFORM_FEATURES_ENABLED=true +CREATORS_PLATFORM_API_URL=https://creators.dify.ai +CREATORS_PLATFORM_OAUTH_CLIENT_ID= + # Endpoint configuration ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id} diff --git a/api/commands/plugin.py b/api/commands/plugin.py index c34391025a..8bd5392d7b 100644 --- a/api/commands/plugin.py +++ b/api/commands/plugin.py @@ -11,7 +11,7 @@ from configs import dify_config from core.helper import encrypter from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.impl.plugin import PluginInstaller -from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params +from core.tools.utils.system_encryption import encrypt_system_params from extensions.ext_database import db from models import Tenant from models.oauth import DatasourceOauthParamConfig, DatasourceProvider @@ -44,7 +44,7 @@ def setup_system_tool_oauth_client(provider, client_params): click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow")) click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow")) - oauth_client_params = encrypt_system_oauth_params(client_params_dict) + oauth_client_params = encrypt_system_params(client_params_dict) click.echo(click.style("Client params encrypted successfully.", fg="green")) except Exception as e: click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red")) @@ -94,7 +94,7 @@ def setup_system_trigger_oauth_client(provider, client_params): click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow")) click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow")) - oauth_client_params = encrypt_system_oauth_params(client_params_dict) + oauth_client_params = encrypt_system_params(client_params_dict) click.echo(click.style("Client params encrypted successfully.", fg="green")) except Exception as e: click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red")) diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index ae49ae47d0..52e33c1789 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -287,6 +287,27 @@ class MarketplaceConfig(BaseSettings): ) +class CreatorsPlatformConfig(BaseSettings): + """ + Configuration for Creators Platform integration + """ + + CREATORS_PLATFORM_FEATURES_ENABLED: bool = Field( + description="Enable or disable Creators Platform features", + default=True, + ) + + CREATORS_PLATFORM_API_URL: HttpUrl = Field( + description="Creators Platform API URL", + default=HttpUrl("https://creators.dify.ai"), + ) + + CREATORS_PLATFORM_OAUTH_CLIENT_ID: str = Field( + description="OAuth client ID for Creators Platform integration", + default="", + ) + + class EndpointConfig(BaseSettings): """ Configuration for various application endpoints and URLs @@ -1379,6 +1400,7 @@ class FeatureConfig( AuthConfig, # Changed from OAuthConfig to AuthConfig BillingConfig, CodeExecutionSandboxConfig, + CreatorsPlatformConfig, TriggerConfig, AsyncWorkflowConfig, PluginConfig, diff --git a/api/controllers/common/human_input.py b/api/controllers/common/human_input.py new file mode 100644 index 0000000000..5d6f4efb95 --- /dev/null +++ b/api/controllers/common/human_input.py @@ -0,0 +1,6 @@ +from pydantic import BaseModel, JsonValue + + +class HumanInputFormSubmitPayload(BaseModel): + inputs: dict[str, JsonValue] + action: str diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 9102983d86..a736fc8bc8 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -692,6 +692,32 @@ class AppExportApi(Resource): return payload.model_dump(mode="json") +@console_ns.route("/apps//publish-to-creators-platform") +class AppPublishToCreatorsPlatformApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=None) + @edit_permission_required + def post(self, app_model): + """Publish app to Creators Platform""" + from configs import dify_config + from core.helper.creators import get_redirect_url, upload_dsl + + if not dify_config.CREATORS_PLATFORM_FEATURES_ENABLED: + return {"error": "Creators Platform features are not enabled"}, 403 + + current_user, _ = current_account_with_tenant() + + dsl_content = AppDslService.export_dsl(app_model=app_model, include_secret=False) + dsl_bytes = dsl_content.encode("utf-8") + + claim_code = upload_dsl(dsl_bytes) + redirect_url = get_redirect_url(str(current_user.id), claim_code) + + return {"redirect_url": redirect_url} + + @console_ns.route("/apps//name") class AppNameApi(Resource): @console_ns.doc("check_app_name") diff --git a/api/controllers/console/human_input_form.py b/api/controllers/console/human_input_form.py index 845af37365..79b3e6cc9f 100644 --- a/api/controllers/console/human_input_form.py +++ b/api/controllers/console/human_input_form.py @@ -8,10 +8,10 @@ from collections.abc import Generator from flask import Response, jsonify, request from flask_restx import Resource -from pydantic import BaseModel from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker +from controllers.common.human_input import HumanInputFormSubmitPayload from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, setup_required from controllers.web.error import InvalidArgumentError, NotFoundError @@ -20,11 +20,11 @@ from core.app.apps.base_app_generator import BaseAppGenerator from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter from core.app.apps.message_generator import MessageGenerator from core.app.apps.workflow.app_generator import WorkflowAppGenerator +from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface from extensions.ext_database import db from libs.login import current_account_with_tenant, login_required from models import App from models.enums import CreatorUserRole -from models.human_input import RecipientType from models.model import AppMode from models.workflow import WorkflowRun from repositories.factory import DifyAPIRepositoryFactory @@ -34,11 +34,6 @@ from services.workflow_event_snapshot_service import build_workflow_event_stream logger = logging.getLogger(__name__) -class HumanInputFormSubmitPayload(BaseModel): - inputs: dict - action: str - - def _jsonify_form_definition(form: Form) -> Response: payload = form.get_definition().model_dump() payload["expiration_time"] = int(form.expiration_time.timestamp()) @@ -56,6 +51,11 @@ class ConsoleHumanInputFormApi(Resource): if form.tenant_id != current_tenant_id: raise NotFoundError("App not found") + @staticmethod + def _ensure_console_recipient_type(form: Form) -> None: + if not is_recipient_type_allowed_for_surface(form.recipient_type, HumanInputSurface.CONSOLE): + raise NotFoundError("form not found") + @setup_required @login_required @account_initialization_required @@ -99,10 +99,8 @@ class ConsoleHumanInputFormApi(Resource): raise NotFoundError(f"form not found, token={form_token}") self._ensure_console_access(form) - + self._ensure_console_recipient_type(form) recipient_type = form.recipient_type - if recipient_type not in {RecipientType.CONSOLE, RecipientType.BACKSTAGE}: - raise NotFoundError(f"form not found, token={form_token}") # The type checker is not smart enought to validate the following invariant. # So we need to assert it manually. assert recipient_type is not None, "recipient_type cannot be None here." diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index 614bf03ea5..f73e2da54e 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -37,6 +37,11 @@ class TagBindingRemovePayload(BaseModel): type: TagType = Field(description="Tag type") +class TagBindingItemDeletePayload(BaseModel): + target_id: str = Field(description="Target ID to unbind tag from") + type: TagType = Field(description="Tag type") + + class TagListQueryParam(BaseModel): type: Literal["knowledge", "app", ""] = Field("", description="Tag type filter") keyword: str | None = Field(None, description="Search keyword") @@ -70,6 +75,7 @@ register_schema_models( TagBasePayload, TagBindingPayload, TagBindingRemovePayload, + TagBindingItemDeletePayload, TagListQueryParam, TagResponse, ) @@ -152,41 +158,107 @@ class TagUpdateDeleteApi(Resource): return "", 204 -@console_ns.route("/tag-bindings/create") -class TagBindingCreateApi(Resource): +def _require_tag_binding_edit_permission() -> None: + """ + Ensure the current account can edit tag bindings. + + Tag binding operations are allowed for users who can edit resources (app/dataset) within the current tenant. + """ + current_user, _ = current_account_with_tenant() + # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator + if not (current_user.has_edit_permission or current_user.is_dataset_editor): + raise Forbidden() + + +def _create_tag_bindings() -> tuple[dict[str, str], int]: + _require_tag_binding_edit_permission() + + payload = TagBindingPayload.model_validate(console_ns.payload or {}) + TagService.save_tag_binding( + TagBindingCreatePayload( + tag_ids=payload.tag_ids, + target_id=payload.target_id, + type=payload.type, + ) + ) + return {"result": "success"}, 200 + + +def _remove_tag_binding() -> tuple[dict[str, str], int]: + _require_tag_binding_edit_permission() + + payload = TagBindingRemovePayload.model_validate(console_ns.payload or {}) + TagService.delete_tag_binding( + TagBindingDeletePayload( + tag_id=payload.tag_id, + target_id=payload.target_id, + type=payload.type, + ) + ) + return {"result": "success"}, 200 + + +@console_ns.route("/tag-bindings") +class TagBindingCollectionApi(Resource): + """Canonical collection resource for tag binding creation.""" + + @console_ns.doc("create_tag_binding") @console_ns.expect(console_ns.models[TagBindingPayload.__name__]) @setup_required @login_required @account_initialization_required def post(self): - current_user, _ = current_account_with_tenant() - # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator - if not (current_user.has_edit_permission or current_user.is_dataset_editor): - raise Forbidden() + return _create_tag_bindings() - payload = TagBindingPayload.model_validate(console_ns.payload or {}) - TagService.save_tag_binding( - TagBindingCreatePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=payload.type) + +@console_ns.route("/tag-bindings/") +class TagBindingItemApi(Resource): + """Canonical item resource for tag binding deletion.""" + + @console_ns.doc("delete_tag_binding") + @console_ns.doc(params={"id": "Tag ID"}) + @console_ns.expect(console_ns.models[TagBindingItemDeletePayload.__name__]) + @setup_required + @login_required + @account_initialization_required + def delete(self, id): + _require_tag_binding_edit_permission() + payload = TagBindingItemDeletePayload.model_validate(console_ns.payload or {}) + TagService.delete_tag_binding( + TagBindingDeletePayload( + tag_id=str(id), + target_id=payload.target_id, + type=payload.type, + ) ) - return {"result": "success"}, 200 +@console_ns.route("/tag-bindings/create") +class DeprecatedTagBindingCreateApi(Resource): + """Deprecated verb-based alias for tag binding creation.""" + + @console_ns.doc("create_tag_binding_deprecated") + @console_ns.doc(deprecated=True) + @console_ns.doc(description="Deprecated legacy alias. Use POST /tag-bindings instead.") + @console_ns.expect(console_ns.models[TagBindingPayload.__name__]) + @setup_required + @login_required + @account_initialization_required + def post(self): + return _create_tag_bindings() + + @console_ns.route("/tag-bindings/remove") -class TagBindingDeleteApi(Resource): +class DeprecatedTagBindingRemoveApi(Resource): + """Deprecated verb-based alias for tag binding deletion.""" + + @console_ns.doc("delete_tag_binding_deprecated") + @console_ns.doc(deprecated=True) + @console_ns.doc(description="Deprecated legacy alias. Use DELETE /tag-bindings/{id} instead.") @console_ns.expect(console_ns.models[TagBindingRemovePayload.__name__]) @setup_required @login_required @account_initialization_required def post(self): - current_user, _ = current_account_with_tenant() - # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator - if not (current_user.has_edit_permission or current_user.is_dataset_editor): - raise Forbidden() - - payload = TagBindingRemovePayload.model_validate(console_ns.payload or {}) - TagService.delete_tag_binding( - TagBindingDeletePayload(tag_id=payload.tag_id, target_id=payload.target_id, type=payload.type) - ) - - return {"result": "success"}, 200 + return _remove_tag_binding() diff --git a/api/controllers/service_api/__init__.py b/api/controllers/service_api/__init__.py index 4f7f7d9a98..182631e8f5 100644 --- a/api/controllers/service_api/__init__.py +++ b/api/controllers/service_api/__init__.py @@ -23,9 +23,11 @@ from .app import ( conversation, file, file_preview, + human_input_form, message, site, workflow, + workflow_events, ) from .dataset import ( dataset, @@ -50,6 +52,7 @@ __all__ = [ "file", "file_preview", "hit_testing", + "human_input_form", "index", "message", "metadata", @@ -58,6 +61,7 @@ __all__ = [ "segment", "site", "workflow", + "workflow_events", ] api.add_namespace(service_api_ns) diff --git a/api/controllers/service_api/app/human_input_form.py b/api/controllers/service_api/app/human_input_form.py new file mode 100644 index 0000000000..8e5003dbbf --- /dev/null +++ b/api/controllers/service_api/app/human_input_form.py @@ -0,0 +1,137 @@ +""" +Service API human input form endpoints. + +This module exposes app-token authenticated APIs for fetching and submitting +paused human input forms in workflow/chatflow runs. +""" + +import json +import logging +from datetime import datetime + +from flask import Response +from flask_restx import Resource +from werkzeug.exceptions import BadRequest, NotFound + +from controllers.common.human_input import HumanInputFormSubmitPayload +from controllers.common.schema import register_schema_models +from controllers.service_api import service_api_ns +from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token +from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface +from extensions.ext_database import db +from models.model import App, EndUser +from services.human_input_service import Form, FormNotFoundError, HumanInputService + +logger = logging.getLogger(__name__) + + +register_schema_models(service_api_ns, HumanInputFormSubmitPayload) + + +def _stringify_default_values(values: dict[str, object]) -> dict[str, str]: + result: dict[str, str] = {} + for key, value in values.items(): + if value is None: + result[key] = "" + elif isinstance(value, (dict, list)): + result[key] = json.dumps(value, ensure_ascii=False) + else: + result[key] = str(value) + return result + + +def _to_timestamp(value: datetime) -> int: + return int(value.timestamp()) + + +def _jsonify_form_definition(form: Form) -> Response: + definition_payload = form.get_definition().model_dump() + payload = { + "form_content": definition_payload["rendered_content"], + "inputs": definition_payload["inputs"], + "resolved_default_values": _stringify_default_values(definition_payload["default_values"]), + "user_actions": definition_payload["user_actions"], + "expiration_time": _to_timestamp(form.expiration_time), + } + return Response(json.dumps(payload, ensure_ascii=False), mimetype="application/json") + + +def _ensure_form_belongs_to_app(form: Form, app_model: App) -> None: + if form.app_id != app_model.id or form.tenant_id != app_model.tenant_id: + raise NotFound("Form not found") + + +def _ensure_form_is_allowed_for_service_api(form: Form) -> None: + # Keep app-token callers scoped to the public web-form surface; internal HITL + # routes must continue to flow through console-only authentication. + if not is_recipient_type_allowed_for_surface(form.recipient_type, HumanInputSurface.SERVICE_API): + raise NotFound("Form not found") + + +@service_api_ns.route("/form/human_input/") +class WorkflowHumanInputFormApi(Resource): + @service_api_ns.doc("get_human_input_form") + @service_api_ns.doc(description="Get a paused human input form by token") + @service_api_ns.doc(params={"form_token": "Human input form token"}) + @service_api_ns.doc( + responses={ + 200: "Form retrieved successfully", + 401: "Unauthorized - invalid API token", + 404: "Form not found", + 412: "Form already submitted or expired", + } + ) + @validate_app_token + def get(self, app_model: App, form_token: str): + service = HumanInputService(db.engine) + form = service.get_form_by_token(form_token) + if form is None: + raise NotFound("Form not found") + + _ensure_form_belongs_to_app(form, app_model) + _ensure_form_is_allowed_for_service_api(form) + service.ensure_form_active(form) + return _jsonify_form_definition(form) + + @service_api_ns.expect(service_api_ns.models[HumanInputFormSubmitPayload.__name__]) + @service_api_ns.doc("submit_human_input_form") + @service_api_ns.doc(description="Submit a paused human input form by token") + @service_api_ns.doc(params={"form_token": "Human input form token"}) + @service_api_ns.doc( + responses={ + 200: "Form submitted successfully", + 400: "Bad request - invalid submission data", + 401: "Unauthorized - invalid API token", + 404: "Form not found", + 412: "Form already submitted or expired", + } + ) + @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) + def post(self, app_model: App, end_user: EndUser, form_token: str): + payload = HumanInputFormSubmitPayload.model_validate(service_api_ns.payload or {}) + + service = HumanInputService(db.engine) + form = service.get_form_by_token(form_token) + if form is None: + raise NotFound("Form not found") + + _ensure_form_belongs_to_app(form, app_model) + _ensure_form_is_allowed_for_service_api(form) + + recipient_type = form.recipient_type + if recipient_type is None: + logger.warning("Recipient type is None for form, form_id=%s", form.id) + raise BadRequest("Form recipient type is invalid") + + try: + service.submit_form_by_token( + recipient_type=recipient_type, + form_token=form_token, + selected_action_id=payload.action, + form_data=payload.inputs, + submission_end_user_id=end_user.id, + ) + except FormNotFoundError: + raise NotFound("Form not found") + + return {}, 200 diff --git a/api/controllers/service_api/app/workflow_events.py b/api/controllers/service_api/app/workflow_events.py new file mode 100644 index 0000000000..b281b271c0 --- /dev/null +++ b/api/controllers/service_api/app/workflow_events.py @@ -0,0 +1,142 @@ +""" +Service API workflow resume event stream endpoints. +""" + +import json +from collections.abc import Generator + +from flask import Response, request +from flask_restx import Resource +from sqlalchemy.orm import sessionmaker +from werkzeug.exceptions import NotFound + +from controllers.service_api import service_api_ns +from controllers.service_api.app.error import NotWorkflowAppError +from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token +from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator +from core.app.apps.base_app_generator import BaseAppGenerator +from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter +from core.app.apps.message_generator import MessageGenerator +from core.app.apps.workflow.app_generator import WorkflowAppGenerator +from core.app.entities.task_entities import StreamEvent +from core.workflow.human_input_policy import HumanInputSurface +from extensions.ext_database import db +from models.enums import CreatorUserRole +from models.model import App, AppMode, EndUser +from repositories.factory import DifyAPIRepositoryFactory +from services.workflow_event_snapshot_service import build_workflow_event_stream + + +@service_api_ns.route("/workflow//events") +class WorkflowEventsApi(Resource): + """Service API for getting workflow execution events after resume.""" + + @service_api_ns.doc("get_workflow_events") + @service_api_ns.doc(description="Get workflow execution events stream after resume") + @service_api_ns.doc( + params={ + "task_id": "Workflow run ID", + "user": "End user identifier (query param)", + "include_state_snapshot": ( + "Whether to replay from persisted state snapshot, " + 'specify `"true"` to include a status snapshot of executed nodes' + ), + "continue_on_pause": ( + "Whether to keep the stream open across workflow_paused events," + 'specify `"true"` to keep the stream open for `workflow_paused` events.' + ), + } + ) + @service_api_ns.doc( + responses={ + 200: "SSE event stream", + 401: "Unauthorized - invalid API token", + 404: "Workflow run not found", + } + ) + @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY, required=True)) + def get(self, app_model: App, end_user: EndUser, task_id: str): + app_mode = AppMode.value_of(app_model.mode) + if app_mode not in {AppMode.WORKFLOW, AppMode.ADVANCED_CHAT}: + raise NotWorkflowAppError() + + session_maker = sessionmaker(db.engine) + repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) + workflow_run = repo.get_workflow_run_by_id_and_tenant_id( + tenant_id=app_model.tenant_id, + run_id=task_id, + ) + + if workflow_run is None: + raise NotFound("Workflow run not found") + + if workflow_run.app_id != app_model.id: + raise NotFound("Workflow run not found") + + if workflow_run.created_by_role != CreatorUserRole.END_USER: + raise NotFound("Workflow run not found") + + if workflow_run.created_by != end_user.id: + raise NotFound("Workflow run not found") + + workflow_run_entity = workflow_run + + if workflow_run_entity.finished_at is not None: + response = WorkflowResponseConverter.workflow_run_result_to_finish_response( + task_id=workflow_run_entity.id, + workflow_run=workflow_run_entity, + creator_user=end_user, + ) + + payload = response.model_dump(mode="json") + payload["event"] = response.event.value + + def _generate_finished_events() -> Generator[str, None, None]: + yield f"data: {json.dumps(payload)}\n\n" + + event_generator = _generate_finished_events + else: + msg_generator = MessageGenerator() + generator: BaseAppGenerator + if app_mode == AppMode.ADVANCED_CHAT: + generator = AdvancedChatAppGenerator() + elif app_mode == AppMode.WORKFLOW: + generator = WorkflowAppGenerator() + else: + raise NotWorkflowAppError() + + include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true" + continue_on_pause = request.args.get("continue_on_pause", "false").lower() == "true" + terminal_events: list[StreamEvent] | None = [] if continue_on_pause else None + + def _generate_stream_events(): + if include_state_snapshot: + return generator.convert_to_event_stream( + build_workflow_event_stream( + app_mode=app_mode, + workflow_run=workflow_run_entity, + tenant_id=app_model.tenant_id, + app_id=app_model.id, + session_maker=session_maker, + human_input_surface=HumanInputSurface.SERVICE_API, + close_on_pause=not continue_on_pause, + ) + ) + return generator.convert_to_event_stream( + msg_generator.retrieve_events( + app_mode, + workflow_run_entity.id, + terminal_events=terminal_events, + ), + ) + + event_generator = _generate_stream_events + + return Response( + event_generator(), + mimetype="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + }, + ) diff --git a/api/controllers/web/human_input_form.py b/api/controllers/web/human_input_form.py index 44876f8303..1ddf2e0717 100644 --- a/api/controllers/web/human_input_form.py +++ b/api/controllers/web/human_input_form.py @@ -9,11 +9,11 @@ from typing import Any, NotRequired, TypedDict from flask import Response, request from flask_restx import Resource -from pydantic import BaseModel from sqlalchemy import select from werkzeug.exceptions import Forbidden from configs import dify_config +from controllers.common.human_input import HumanInputFormSubmitPayload from controllers.web import web_ns from controllers.web.error import NotFoundError, WebFormRateLimitExceededError from controllers.web.site import serialize_app_site_payload @@ -26,11 +26,6 @@ from services.human_input_service import Form, FormNotFoundError, HumanInputServ logger = logging.getLogger(__name__) -class HumanInputFormSubmitPayload(BaseModel): - inputs: dict - action: str - - _FORM_SUBMIT_RATE_LIMITER = RateLimiter( prefix="web_form_submit_rate_limit", max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS, diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 9e64b471cb..b79d5514b4 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -34,7 +34,11 @@ from core.app.apps.exc import GenerateTaskStoppedError from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom -from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse +from core.app.entities.task_entities import ( + AdvancedChatPausedBlockingResponse, + ChatbotAppBlockingResponse, + ChatbotAppStreamResponse, +) from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer from core.helper.trace_id_helper import extract_external_trace_id_from_args from core.ops.ops_trace_manager import TraceQueueManager @@ -655,7 +659,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): user: Account | EndUser, draft_var_saver_factory: DraftVariableSaverFactory, stream: bool = False, - ) -> ChatbotAppBlockingResponse | Generator[ChatbotAppStreamResponse, None, None]: + ) -> ( + ChatbotAppBlockingResponse + | AdvancedChatPausedBlockingResponse + | Generator[ChatbotAppStreamResponse, None, None] + ): """ Handle response. :param application_generate_entity: application generate entity diff --git a/api/core/app/apps/advanced_chat/generate_response_converter.py b/api/core/app/apps/advanced_chat/generate_response_converter.py index fe2702ed69..7cb0c9a8d3 100644 --- a/api/core/app/apps/advanced_chat/generate_response_converter.py +++ b/api/core/app/apps/advanced_chat/generate_response_converter.py @@ -3,7 +3,7 @@ from typing import Any, cast from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter from core.app.entities.task_entities import ( - AppBlockingResponse, + AdvancedChatPausedBlockingResponse, AppStreamResponse, ChatbotAppBlockingResponse, ChatbotAppStreamResponse, @@ -12,22 +12,40 @@ from core.app.entities.task_entities import ( NodeFinishStreamResponse, NodeStartStreamResponse, PingStreamResponse, + StreamEvent, ) -class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): - _blocking_response_type = ChatbotAppBlockingResponse - +class AdvancedChatAppGenerateResponseConverter( + AppGenerateResponseConverter[ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse] +): @classmethod - def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]: + def convert_blocking_full_response( + cls, blocking_response: ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse + ) -> dict[str, Any]: """ Convert blocking full response. :param blocking_response: blocking response :return: """ - blocking_response = cast(ChatbotAppBlockingResponse, blocking_response) + if isinstance(blocking_response, AdvancedChatPausedBlockingResponse): + paused_data = blocking_response.data.model_dump(mode="json") + return { + "event": StreamEvent.WORKFLOW_PAUSED.value, + "task_id": blocking_response.task_id, + "id": blocking_response.data.id, + "message_id": blocking_response.data.message_id, + "conversation_id": blocking_response.data.conversation_id, + "mode": blocking_response.data.mode, + "answer": blocking_response.data.answer, + "metadata": blocking_response.data.metadata, + "created_at": blocking_response.data.created_at, + "workflow_run_id": blocking_response.data.workflow_run_id, + "data": paused_data, + } + response = { - "event": "message", + "event": StreamEvent.MESSAGE.value, "task_id": blocking_response.task_id, "id": blocking_response.data.id, "message_id": blocking_response.data.message_id, @@ -41,7 +59,9 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): return response @classmethod - def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]: + def convert_blocking_simple_response( + cls, blocking_response: ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse + ) -> dict[str, Any]: """ Convert blocking simple response. :param blocking_response: blocking response @@ -50,7 +70,8 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): response = cls.convert_blocking_full_response(blocking_response) metadata = response.get("metadata", {}) - response["metadata"] = cls._get_simple_metadata(metadata) + if isinstance(metadata, dict): + response["metadata"] = cls._get_simple_metadata(metadata) return response diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 78b582bdf5..82dbf5381d 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -53,14 +53,18 @@ from core.app.entities.queue_entities import ( WorkflowQueueMessage, ) from core.app.entities.task_entities import ( + AdvancedChatPausedBlockingResponse, ChatbotAppBlockingResponse, ChatbotAppStreamResponse, ErrorStreamResponse, + HumanInputRequiredPauseReasonPayload, + HumanInputRequiredResponse, MessageAudioEndStreamResponse, MessageAudioStreamResponse, MessageEndStreamResponse, PingStreamResponse, StreamResponse, + WorkflowPauseStreamResponse, WorkflowTaskState, ) from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline @@ -210,7 +214,13 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): if message.status == MessageStatus.PAUSED and message.answer: self._task_state.answer = message.answer - def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]: + def process( + self, + ) -> Union[ + ChatbotAppBlockingResponse, + AdvancedChatPausedBlockingResponse, + Generator[ChatbotAppStreamResponse, None, None], + ]: """ Process generate task pipeline. :return: @@ -226,14 +236,39 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): else: return self._to_blocking_response(generator) - def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> ChatbotAppBlockingResponse: + def _to_blocking_response( + self, generator: Generator[StreamResponse, None, None] + ) -> Union[ChatbotAppBlockingResponse, AdvancedChatPausedBlockingResponse]: """ Process blocking response. :return: """ + human_input_responses: list[HumanInputRequiredResponse] = [] for stream_response in generator: if isinstance(stream_response, ErrorStreamResponse): raise stream_response.err + elif isinstance(stream_response, HumanInputRequiredResponse): + human_input_responses.append(stream_response) + elif isinstance(stream_response, WorkflowPauseStreamResponse): + return AdvancedChatPausedBlockingResponse( + task_id=stream_response.task_id, + data=AdvancedChatPausedBlockingResponse.Data( + id=self._message_id, + mode=self._conversation_mode, + conversation_id=self._conversation_id, + message_id=self._message_id, + workflow_run_id=stream_response.data.workflow_run_id, + answer=self._task_state.answer, + metadata=self._message_end_to_stream_response().metadata, + created_at=self._message_created_at, + paused_nodes=stream_response.data.paused_nodes, + reasons=stream_response.data.reasons, + status=stream_response.data.status, + elapsed_time=stream_response.data.elapsed_time, + total_tokens=stream_response.data.total_tokens, + total_steps=stream_response.data.total_steps, + ), + ) elif isinstance(stream_response, MessageEndStreamResponse): extras = {} if stream_response.metadata: @@ -254,8 +289,41 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): else: continue + if human_input_responses: + return self._build_paused_blocking_response_from_human_input(human_input_responses) + raise ValueError("queue listening stopped unexpectedly.") + def _build_paused_blocking_response_from_human_input( + self, human_input_responses: list[HumanInputRequiredResponse] + ) -> AdvancedChatPausedBlockingResponse: + runtime_state = self._resolve_graph_runtime_state() + paused_nodes = list(dict.fromkeys(response.data.node_id for response in human_input_responses)) + reasons = [ + HumanInputRequiredPauseReasonPayload.from_response_data(response.data).model_dump(mode="json") + for response in human_input_responses + ] + + return AdvancedChatPausedBlockingResponse( + task_id=self._application_generate_entity.task_id, + data=AdvancedChatPausedBlockingResponse.Data( + id=self._message_id, + mode=self._conversation_mode, + conversation_id=self._conversation_id, + message_id=self._message_id, + workflow_run_id=human_input_responses[-1].workflow_run_id, + answer=self._task_state.answer, + metadata=self._message_end_to_stream_response().metadata, + created_at=self._message_created_at, + paused_nodes=paused_nodes, + reasons=reasons, + status=WorkflowExecutionStatus.PAUSED, + elapsed_time=time.perf_counter() - self._base_task_pipeline.start_at, + total_tokens=runtime_state.total_tokens, + total_steps=runtime_state.node_run_steps, + ), + ) + def _to_stream_response( self, generator: Generator[StreamResponse, None, None] ) -> Generator[ChatbotAppStreamResponse, Any, None]: diff --git a/api/core/app/apps/agent_chat/generate_response_converter.py b/api/core/app/apps/agent_chat/generate_response_converter.py index 731c6ee12e..03bc0a9108 100644 --- a/api/core/app/apps/agent_chat/generate_response_converter.py +++ b/api/core/app/apps/agent_chat/generate_response_converter.py @@ -1,6 +1,8 @@ from collections.abc import Generator from typing import Any, cast +from pydantic import JsonValue + from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter from core.app.entities.task_entities import ( AppStreamResponse, @@ -12,11 +14,9 @@ from core.app.entities.task_entities import ( ) -class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): - _blocking_response_type = ChatbotAppBlockingResponse - +class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBlockingResponse]): @classmethod - def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override] + def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): """ Convert blocking full response. :param blocking_response: blocking response @@ -37,7 +37,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): return response @classmethod - def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override] + def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): """ Convert blocking simple response. :param blocking_response: blocking response @@ -70,7 +70,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): yield "ping" continue - response_chunk = { + response_chunk: dict[str, JsonValue] = { "event": sub_stream_response.event.value, "conversation_id": chunk.conversation_id, "message_id": chunk.message_id, @@ -101,7 +101,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): yield "ping" continue - response_chunk = { + response_chunk: dict[str, JsonValue] = { "event": sub_stream_response.event.value, "conversation_id": chunk.conversation_id, "message_id": chunk.message_id, diff --git a/api/core/app/apps/base_app_generate_response_converter.py b/api/core/app/apps/base_app_generate_response_converter.py index d5edfaeb25..abcbb2f943 100644 --- a/api/core/app/apps/base_app_generate_response_converter.py +++ b/api/core/app/apps/base_app_generate_response_converter.py @@ -1,7 +1,9 @@ import logging from abc import ABC, abstractmethod from collections.abc import Generator, Mapping -from typing import Any, Union +from typing import Any, Union, cast + +from pydantic import JsonValue from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse @@ -11,8 +13,10 @@ from graphon.model_runtime.errors.invoke import InvokeError logger = logging.getLogger(__name__) -class AppGenerateResponseConverter(ABC): - _blocking_response_type: type[AppBlockingResponse] +class AppGenerateResponseConverter[TBlockingResponse: AppBlockingResponse](ABC): + @classmethod + def _cast_blocking_response(cls, response: AppBlockingResponse) -> TBlockingResponse: + return cast(TBlockingResponse, response) @classmethod def convert( @@ -20,7 +24,7 @@ class AppGenerateResponseConverter(ABC): ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]: if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}: if isinstance(response, AppBlockingResponse): - return cls.convert_blocking_full_response(response) + return cls.convert_blocking_full_response(cls._cast_blocking_response(response)) else: def _generate_full_response() -> Generator[dict[str, Any] | str, Any, None]: @@ -29,7 +33,7 @@ class AppGenerateResponseConverter(ABC): return _generate_full_response() else: if isinstance(response, AppBlockingResponse): - return cls.convert_blocking_simple_response(response) + return cls.convert_blocking_simple_response(cls._cast_blocking_response(response)) else: def _generate_simple_response() -> Generator[dict[str, Any] | str, Any, None]: @@ -39,12 +43,12 @@ class AppGenerateResponseConverter(ABC): @classmethod @abstractmethod - def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]: + def convert_blocking_full_response(cls, blocking_response: TBlockingResponse) -> dict[str, Any]: raise NotImplementedError @classmethod @abstractmethod - def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]: + def convert_blocking_simple_response(cls, blocking_response: TBlockingResponse) -> dict[str, Any]: raise NotImplementedError @classmethod @@ -106,13 +110,13 @@ class AppGenerateResponseConverter(ABC): return metadata @classmethod - def _error_to_stream_response(cls, e: Exception) -> dict[str, Any]: + def _error_to_stream_response(cls, e: Exception) -> dict[str, JsonValue]: """ Error to stream response. :param e: exception :return: """ - error_responses: dict[type[Exception], dict[str, Any]] = { + error_responses: dict[type[Exception], dict[str, JsonValue]] = { ValueError: {"code": "invalid_param", "status": 400}, ProviderTokenNotInitError: {"code": "provider_not_initialize", "status": 400}, QuotaExceededError: { @@ -126,7 +130,7 @@ class AppGenerateResponseConverter(ABC): } # Determine the response based on the type of exception - data: dict[str, Any] | None = None + data: dict[str, JsonValue] | None = None for k, v in error_responses.items(): if isinstance(e, k): data = v diff --git a/api/core/app/apps/chat/generate_response_converter.py b/api/core/app/apps/chat/generate_response_converter.py index 3d0375151d..26efcbfafd 100644 --- a/api/core/app/apps/chat/generate_response_converter.py +++ b/api/core/app/apps/chat/generate_response_converter.py @@ -1,6 +1,8 @@ from collections.abc import Generator from typing import Any, cast +from pydantic import JsonValue + from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter from core.app.entities.task_entities import ( AppStreamResponse, @@ -12,11 +14,9 @@ from core.app.entities.task_entities import ( ) -class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): - _blocking_response_type = ChatbotAppBlockingResponse - +class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBlockingResponse]): @classmethod - def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override] + def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): """ Convert blocking full response. :param blocking_response: blocking response @@ -37,7 +37,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): return response @classmethod - def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override] + def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): """ Convert blocking simple response. :param blocking_response: blocking response @@ -70,7 +70,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): yield "ping" continue - response_chunk = { + response_chunk: dict[str, JsonValue] = { "event": sub_stream_response.event.value, "conversation_id": chunk.conversation_id, "message_id": chunk.message_id, @@ -101,7 +101,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): yield "ping" continue - response_chunk = { + response_chunk: dict[str, JsonValue] = { "event": sub_stream_response.event.value, "conversation_id": chunk.conversation_id, "message_id": chunk.message_id, diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py index bd685d5189..7bab3f7bff 100644 --- a/api/core/app/apps/common/workflow_response_converter.py +++ b/api/core/app/apps/common/workflow_response_converter.py @@ -52,6 +52,7 @@ from core.tools.tool_manager import ToolManager from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE from core.trigger.trigger_manager import TriggerManager from core.workflow.human_input_forms import load_form_tokens_by_form_id +from core.workflow.human_input_policy import HumanInputSurface, enrich_human_input_pause_reasons from core.workflow.system_variables import SystemVariableKey, system_variables_to_mapping from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db @@ -336,7 +337,26 @@ class WorkflowResponseConverter: except (TypeError, json.JSONDecodeError): definition_payload = {} display_in_ui_by_form_id[str(form_id)] = bool(definition_payload.get("display_in_ui")) - form_token_by_form_id = load_form_tokens_by_form_id(human_input_form_ids, session=session) + form_token_by_form_id = load_form_tokens_by_form_id( + human_input_form_ids, + session=session, + surface=( + HumanInputSurface.SERVICE_API + if self._application_generate_entity.invoke_from == InvokeFrom.SERVICE_API + else None + ), + ) + + # Reconnect paths must preserve the same pause-reason contract as live streams; + # otherwise clients see schema drift after resume. + pause_reasons = enrich_human_input_pause_reasons( + pause_reasons, + form_tokens_by_form_id=form_token_by_form_id, + expiration_times_by_form_id={ + form_id: int(expiration_time.timestamp()) + for form_id, expiration_time in expiration_times_by_form_id.items() + }, + ) responses: list[StreamResponse] = [] diff --git a/api/core/app/apps/completion/generate_response_converter.py b/api/core/app/apps/completion/generate_response_converter.py index 71886b39ba..ad978f58e0 100644 --- a/api/core/app/apps/completion/generate_response_converter.py +++ b/api/core/app/apps/completion/generate_response_converter.py @@ -1,6 +1,8 @@ from collections.abc import Generator from typing import Any, cast +from pydantic import JsonValue + from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter from core.app.entities.task_entities import ( AppStreamResponse, @@ -12,17 +14,15 @@ from core.app.entities.task_entities import ( ) -class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): - _blocking_response_type = CompletionAppBlockingResponse - +class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[CompletionAppBlockingResponse]): @classmethod - def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse): # type: ignore[override] + def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse): """ Convert blocking full response. :param blocking_response: blocking response :return: """ - response = { + response: dict[str, Any] = { "event": "message", "task_id": blocking_response.task_id, "id": blocking_response.data.id, @@ -36,7 +36,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): return response @classmethod - def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse): # type: ignore[override] + def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse): """ Convert blocking simple response. :param blocking_response: blocking response @@ -69,7 +69,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): yield "ping" continue - response_chunk = { + response_chunk: dict[str, JsonValue] = { "event": sub_stream_response.event.value, "message_id": chunk.message_id, "created_at": chunk.created_at, @@ -99,7 +99,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): yield "ping" continue - response_chunk = { + response_chunk: dict[str, JsonValue] = { "event": sub_stream_response.event.value, "message_id": chunk.message_id, "created_at": chunk.created_at, diff --git a/api/core/app/apps/message_generator.py b/api/core/app/apps/message_generator.py index 68631bb230..c04f20c796 100644 --- a/api/core/app/apps/message_generator.py +++ b/api/core/app/apps/message_generator.py @@ -1,6 +1,7 @@ -from collections.abc import Callable, Generator, Mapping +from collections.abc import Callable, Generator, Iterable, Mapping from core.app.apps.streaming_utils import stream_topic_events +from core.app.entities.task_entities import StreamEvent from extensions.ext_redis import get_pubsub_broadcast_channel from libs.broadcast_channel.channel import Topic from models.model import AppMode @@ -26,6 +27,7 @@ class MessageGenerator: idle_timeout=300, ping_interval: float = 10.0, on_subscribe: Callable[[], None] | None = None, + terminal_events: Iterable[str | StreamEvent] | None = None, ) -> Generator[Mapping | str, None, None]: topic = cls.get_response_topic(app_mode, workflow_run_id) return stream_topic_events( @@ -33,4 +35,5 @@ class MessageGenerator: idle_timeout=idle_timeout, ping_interval=ping_interval, on_subscribe=on_subscribe, + terminal_events=terminal_events, ) diff --git a/api/core/app/apps/pipeline/generate_response_converter.py b/api/core/app/apps/pipeline/generate_response_converter.py index 02b3160b7c..3913657ae8 100644 --- a/api/core/app/apps/pipeline/generate_response_converter.py +++ b/api/core/app/apps/pipeline/generate_response_converter.py @@ -13,11 +13,9 @@ from core.app.entities.task_entities import ( ) -class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): - _blocking_response_type = WorkflowAppBlockingResponse - +class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter[WorkflowAppBlockingResponse]): @classmethod - def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]: # type: ignore[override] + def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, object]: """ Convert blocking full response. :param blocking_response: blocking response @@ -26,7 +24,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): return dict(blocking_response.model_dump()) @classmethod - def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]: # type: ignore[override] + def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, object]: """ Convert blocking simple response. :param blocking_response: blocking response diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index 4b2f17189b..4a76d0809e 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -27,7 +27,11 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity -from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse +from core.app.entities.task_entities import ( + WorkflowAppBlockingResponse, + WorkflowAppPausedBlockingResponse, + WorkflowAppStreamResponse, +) from core.datasource.entities.datasource_entities import ( DatasourceProviderType, OnlineDriveBrowseFilesRequest, @@ -627,7 +631,11 @@ class PipelineGenerator(BaseAppGenerator): user: Account | EndUser, draft_var_saver_factory: DraftVariableSaverFactory, stream: bool = False, - ) -> WorkflowAppBlockingResponse | Generator[WorkflowAppStreamResponse, None, None]: + ) -> ( + WorkflowAppBlockingResponse + | WorkflowAppPausedBlockingResponse + | Generator[WorkflowAppStreamResponse, None, None] + ): """ Handle response. :param application_generate_entity: application generate entity diff --git a/api/core/app/apps/streaming_utils.py b/api/core/app/apps/streaming_utils.py index af3441aca3..5743bad4b6 100644 --- a/api/core/app/apps/streaming_utils.py +++ b/api/core/app/apps/streaming_utils.py @@ -59,7 +59,7 @@ def stream_topic_events( def _normalize_terminal_events(terminal_events: Iterable[str | StreamEvent] | None) -> set[str]: - if not terminal_events: + if terminal_events is None: return {StreamEvent.WORKFLOW_FINISHED.value, StreamEvent.WORKFLOW_PAUSED.value} values: set[str] = set() for item in terminal_events: diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 6937014a06..e811c2b2e0 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -25,7 +25,11 @@ from core.app.apps.workflow.app_runner import WorkflowAppRunner from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity -from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse +from core.app.entities.task_entities import ( + WorkflowAppBlockingResponse, + WorkflowAppPausedBlockingResponse, + WorkflowAppStreamResponse, +) from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer from core.db.session_factory import session_factory from core.helper.trace_id_helper import extract_external_trace_id_from_args @@ -612,7 +616,11 @@ class WorkflowAppGenerator(BaseAppGenerator): user: Account | EndUser, draft_var_saver_factory: DraftVariableSaverFactory, stream: bool = False, - ) -> WorkflowAppBlockingResponse | Generator[WorkflowAppStreamResponse, None, None]: + ) -> ( + WorkflowAppBlockingResponse + | WorkflowAppPausedBlockingResponse + | Generator[WorkflowAppStreamResponse, None, None] + ): """ Handle response. :param application_generate_entity: application generate entity diff --git a/api/core/app/apps/workflow/generate_response_converter.py b/api/core/app/apps/workflow/generate_response_converter.py index c69826cbef..4037388798 100644 --- a/api/core/app/apps/workflow/generate_response_converter.py +++ b/api/core/app/apps/workflow/generate_response_converter.py @@ -9,24 +9,29 @@ from core.app.entities.task_entities import ( NodeStartStreamResponse, PingStreamResponse, WorkflowAppBlockingResponse, + WorkflowAppPausedBlockingResponse, WorkflowAppStreamResponse, ) -class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): - _blocking_response_type = WorkflowAppBlockingResponse - +class WorkflowAppGenerateResponseConverter( + AppGenerateResponseConverter[WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse] +): @classmethod - def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override] + def convert_blocking_full_response( + cls, blocking_response: WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse + ) -> dict[str, Any]: """ Convert blocking full response. :param blocking_response: blocking response :return: """ - return blocking_response.model_dump() + return dict(blocking_response.model_dump()) @classmethod - def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override] + def convert_blocking_simple_response( + cls, blocking_response: WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse + ) -> dict[str, Any]: """ Convert blocking simple response. :param blocking_response: blocking response diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 15645add57..87d9b73078 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -42,12 +42,15 @@ from core.app.entities.queue_entities import ( ) from core.app.entities.task_entities import ( ErrorStreamResponse, + HumanInputRequiredPauseReasonPayload, + HumanInputRequiredResponse, MessageAudioEndStreamResponse, MessageAudioStreamResponse, PingStreamResponse, StreamResponse, TextChunkStreamResponse, WorkflowAppBlockingResponse, + WorkflowAppPausedBlockingResponse, WorkflowAppStreamResponse, WorkflowFinishStreamResponse, WorkflowPauseStreamResponse, @@ -118,7 +121,11 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): ) self._graph_runtime_state: GraphRuntimeState | None = self._base_task_pipeline.queue_manager.graph_runtime_state - def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: + def process( + self, + ) -> Union[ + WorkflowAppBlockingResponse, WorkflowAppPausedBlockingResponse, Generator[WorkflowAppStreamResponse, None, None] + ]: """ Process generate task pipeline. :return: @@ -129,19 +136,24 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): else: return self._to_blocking_response(generator) - def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> WorkflowAppBlockingResponse: + def _to_blocking_response( + self, generator: Generator[StreamResponse, None, None] + ) -> Union[WorkflowAppBlockingResponse, WorkflowAppPausedBlockingResponse]: """ To blocking response. :return: """ + human_input_responses: list[HumanInputRequiredResponse] = [] for stream_response in generator: if isinstance(stream_response, ErrorStreamResponse): raise stream_response.err + elif isinstance(stream_response, HumanInputRequiredResponse): + human_input_responses.append(stream_response) elif isinstance(stream_response, WorkflowPauseStreamResponse): - response = WorkflowAppBlockingResponse( + return WorkflowAppPausedBlockingResponse( task_id=self._application_generate_entity.task_id, workflow_run_id=stream_response.data.workflow_run_id, - data=WorkflowAppBlockingResponse.Data( + data=WorkflowAppPausedBlockingResponse.Data( id=stream_response.data.workflow_run_id, workflow_id=self._workflow.id, status=stream_response.data.status, @@ -152,12 +164,13 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): total_steps=stream_response.data.total_steps, created_at=stream_response.data.created_at, finished_at=None, + paused_nodes=stream_response.data.paused_nodes, + reasons=stream_response.data.reasons, ), ) - return response elif isinstance(stream_response, WorkflowFinishStreamResponse): - response = WorkflowAppBlockingResponse( + return WorkflowAppBlockingResponse( task_id=self._application_generate_entity.task_id, workflow_run_id=stream_response.data.id, data=WorkflowAppBlockingResponse.Data( @@ -174,12 +187,44 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): ), ) - return response else: continue + if human_input_responses: + return self._build_paused_blocking_response_from_human_input(human_input_responses) + raise ValueError("queue listening stopped unexpectedly.") + def _build_paused_blocking_response_from_human_input( + self, human_input_responses: list[HumanInputRequiredResponse] + ) -> WorkflowAppPausedBlockingResponse: + runtime_state = self._resolve_graph_runtime_state() + paused_nodes = list(dict.fromkeys(response.data.node_id for response in human_input_responses)) + created_at = int(runtime_state.start_at) + reasons = [ + HumanInputRequiredPauseReasonPayload.from_response_data(response.data).model_dump(mode="json") + for response in human_input_responses + ] + + return WorkflowAppPausedBlockingResponse( + task_id=self._application_generate_entity.task_id, + workflow_run_id=human_input_responses[-1].workflow_run_id, + data=WorkflowAppPausedBlockingResponse.Data( + id=human_input_responses[-1].workflow_run_id, + workflow_id=self._workflow.id, + status=WorkflowExecutionStatus.PAUSED, + outputs={}, + error=None, + elapsed_time=time.perf_counter() - self._base_task_pipeline.start_at, + total_tokens=runtime_state.total_tokens, + total_steps=runtime_state.node_run_steps, + created_at=created_at, + finished_at=None, + paused_nodes=paused_nodes, + reasons=reasons, + ), + ) + def _to_stream_response( self, generator: Generator[StreamResponse, None, None] ) -> Generator[WorkflowAppStreamResponse, None, None]: diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 6e4ca69cf0..ad05566521 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -1,12 +1,13 @@ from collections.abc import Mapping, Sequence from enum import StrEnum -from typing import Any +from typing import Any, Literal -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, JsonValue from core.app.entities.agent_strategy import AgentStrategyInfo from core.rag.entities import RetrievalSourceMetadata from graphon.entities import WorkflowStartReason +from graphon.entities.pause_reason import PauseReasonType from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage from graphon.nodes.human_input.entities import FormInput, UserAction @@ -295,6 +296,40 @@ class HumanInputRequiredResponse(StreamResponse): data: Data +class HumanInputRequiredPauseReasonPayload(BaseModel): + """ + Public pause-reason payload used by blocking responses when only + ``human_input_required`` events are available. + """ + + TYPE: Literal[PauseReasonType.HUMAN_INPUT_REQUIRED] = PauseReasonType.HUMAN_INPUT_REQUIRED + form_id: str + node_id: str + node_title: str + form_content: str + inputs: Sequence[FormInput] = Field(default_factory=list) + actions: Sequence[UserAction] = Field(default_factory=list) + display_in_ui: bool = False + form_token: str | None = None + resolved_default_values: Mapping[str, Any] = Field(default_factory=dict) + expiration_time: int + + @classmethod + def from_response_data(cls, data: HumanInputRequiredResponse.Data) -> "HumanInputRequiredPauseReasonPayload": + return cls( + form_id=data.form_id, + node_id=data.node_id, + node_title=data.node_title, + form_content=data.form_content, + inputs=data.inputs, + actions=data.actions, + display_in_ui=data.display_in_ui, + form_token=data.form_token, + resolved_default_values=data.resolved_default_values, + expiration_time=data.expiration_time, + ) + + class HumanInputFormFilledResponse(StreamResponse): class Data(BaseModel): """ @@ -355,7 +390,7 @@ class NodeStartStreamResponse(StreamResponse): workflow_run_id: str data: Data - def to_ignore_detail_dict(self): + def to_ignore_detail_dict(self) -> dict[str, JsonValue]: return { "event": self.event.value, "task_id": self.task_id, @@ -412,7 +447,7 @@ class NodeFinishStreamResponse(StreamResponse): workflow_run_id: str data: Data - def to_ignore_detail_dict(self): + def to_ignore_detail_dict(self) -> dict[str, JsonValue]: return { "event": self.event.value, "task_id": self.task_id, @@ -774,6 +809,34 @@ class ChatbotAppBlockingResponse(AppBlockingResponse): data: Data +class AdvancedChatPausedBlockingResponse(AppBlockingResponse): + """ + ChatbotAppPausedBlockingResponse entity + """ + + class Data(BaseModel): + """ + Data entity + """ + + id: str + mode: str + conversation_id: str + message_id: str + workflow_run_id: str + answer: str + metadata: Mapping[str, object] = Field(default_factory=dict) + created_at: int + paused_nodes: Sequence[str] = Field(default_factory=list) + reasons: Sequence[Mapping[str, Any]] = Field(default_factory=list[Mapping[str, Any]]) + status: WorkflowExecutionStatus + elapsed_time: float + total_tokens: int + total_steps: int + + data: Data + + class CompletionAppBlockingResponse(AppBlockingResponse): """ CompletionAppBlockingResponse entity @@ -819,6 +882,33 @@ class WorkflowAppBlockingResponse(AppBlockingResponse): data: Data +class WorkflowAppPausedBlockingResponse(AppBlockingResponse): + """ + WorkflowAppPausedBlockingResponse entity + """ + + class Data(BaseModel): + """ + Data entity + """ + + id: str + workflow_id: str + status: WorkflowExecutionStatus + outputs: Mapping[str, Any] | None = None + error: str | None = None + elapsed_time: float + total_tokens: int + total_steps: int + created_at: int + finished_at: int | None + paused_nodes: Sequence[str] = Field(default_factory=list) + reasons: Sequence[Mapping[str, Any]] = Field(default_factory=list) + + workflow_run_id: str + data: Data + + class AgentLogStreamResponse(StreamResponse): """ AgentLogStreamResponse entity diff --git a/api/core/app/llm/model_access.py b/api/core/app/llm/model_access.py index c49c4eb0ac..5631caa1a5 100644 --- a/api/core/app/llm/model_access.py +++ b/api/core/app/llm/model_access.py @@ -1,5 +1,6 @@ from __future__ import annotations +from copy import deepcopy from typing import Any from core.app.entities.app_invoke_entities import DifyRunContext, ModelConfigWithCredentialsEntity @@ -14,8 +15,21 @@ from graphon.nodes.llm.protocols import CredentialsProvider class DifyCredentialsProvider: + """Resolves and returns LLM credentials for a given provider and model. + + Fetched credentials are stored in :attr:`credentials_cache` and reused for + subsequent ``fetch`` calls for the same ``(provider_name, model_name)``. + Because of that cache, a single instance can return stale credentials after + the tenant or provider configuration changes (e.g. API key rotation). + + Do **not** keep one instance for the lifetime of a process or across + unrelated invocations. Create a new provider per request, workflow run, or + other bounded scope where up-to-date credentials matter. + """ + tenant_id: str provider_manager: ProviderManager + credentials_cache: dict[tuple[str, str], dict[str, Any]] def __init__( self, @@ -30,8 +44,12 @@ class DifyCredentialsProvider: user_id=run_context.user_id, ) self.provider_manager = provider_manager + self.credentials_cache = {} def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]: + if (provider_name, model_name) in self.credentials_cache: + return deepcopy(self.credentials_cache[(provider_name, model_name)]) + provider_configurations = self.provider_manager.get_configurations(self.tenant_id) provider_configuration = provider_configurations.get(provider_name) if not provider_configuration: @@ -46,6 +64,7 @@ class DifyCredentialsProvider: if credentials is None: raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") + self.credentials_cache[(provider_name, model_name)] = deepcopy(credentials) return credentials @@ -65,7 +84,8 @@ class DifyModelFactory: provider_manager=create_plugin_provider_manager( tenant_id=run_context.tenant_id, user_id=run_context.user_id, - ) + ), + enable_credentials_cache=True, ) self.model_manager = model_manager @@ -84,7 +104,7 @@ def build_dify_model_access(run_context: DifyRunContext) -> tuple[CredentialsPro tenant_id=run_context.tenant_id, user_id=run_context.user_id, ) - model_manager = ModelManager(provider_manager=provider_manager) + model_manager = ModelManager(provider_manager=provider_manager, enable_credentials_cache=True) return ( DifyCredentialsProvider(run_context=run_context, provider_manager=provider_manager), diff --git a/api/core/helper/creators.py b/api/core/helper/creators.py new file mode 100644 index 0000000000..b01e16f18a --- /dev/null +++ b/api/core/helper/creators.py @@ -0,0 +1,41 @@ +""" +Helper module for Creators Platform integration. + +Provides functionality to upload DSL files to the Creators Platform +and generate redirect URLs with OAuth authorization codes. +""" + +import logging +from urllib.parse import urlencode + +import httpx +from yarl import URL + +from configs import dify_config + +logger = logging.getLogger(__name__) + +creators_platform_api_url = URL(str(dify_config.CREATORS_PLATFORM_API_URL)) + + +def upload_dsl(dsl_file_bytes: bytes, filename: str = "template.yaml") -> str: + url = str(creators_platform_api_url / "api/v1/templates/anonymous-upload") + response = httpx.post(url, files={"file": (filename, dsl_file_bytes)}, timeout=30) + response.raise_for_status() + data = response.json() + claim_code = data.get("data", {}).get("claim_code") + if not claim_code: + raise ValueError("Creators Platform did not return a valid claim_code") + return claim_code + + +def get_redirect_url(user_account_id: str, claim_code: str) -> str: + base_url = str(dify_config.CREATORS_PLATFORM_API_URL).rstrip("/") + params: dict[str, str] = {"dsl_claim_code": claim_code} + client_id = str(dify_config.CREATORS_PLATFORM_OAUTH_CLIENT_ID or "") + if client_id: + from services.oauth_server import OAuthServerService + + oauth_code = OAuthServerService.sign_oauth_authorization_code(client_id, user_account_id) + params["oauth_code"] = oauth_code + return f"{base_url}?{urlencode(params)}" diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index 6454f4f0dc..af2611bb0b 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -13,8 +13,6 @@ from core.llm_generator.output_parser.rule_config_generator import RuleConfigGen from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser from core.llm_generator.prompts import ( CONVERSATION_TITLE_PROMPT, - DEFAULT_SUGGESTED_QUESTIONS_MAX_TOKENS, - DEFAULT_SUGGESTED_QUESTIONS_TEMPERATURE, GENERATOR_QA_PROMPT, JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE, LLM_MODIFY_CODE_SYSTEM, @@ -217,8 +215,8 @@ class LLMGenerator: else: # Default-model generation keeps the built-in suggested-questions tuning. model_parameters = { - "max_tokens": DEFAULT_SUGGESTED_QUESTIONS_MAX_TOKENS, - "temperature": DEFAULT_SUGGESTED_QUESTIONS_TEMPERATURE, + "max_tokens": 2560, + "temperature": 0.0, } stop = [] diff --git a/api/core/llm_generator/output_parser/suggested_questions_after_answer.py b/api/core/llm_generator/output_parser/suggested_questions_after_answer.py index c030802c79..7ac340926d 100644 --- a/api/core/llm_generator/output_parser/suggested_questions_after_answer.py +++ b/api/core/llm_generator/output_parser/suggested_questions_after_answer.py @@ -10,7 +10,14 @@ logger = logging.getLogger(__name__) class SuggestedQuestionsAfterAnswerOutputParser: def __init__(self, instruction_prompt: str | None = None) -> None: - self._instruction_prompt = instruction_prompt or DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT + self._instruction_prompt = self._build_instruction_prompt(instruction_prompt) + + @staticmethod + def _build_instruction_prompt(instruction_prompt: str | None) -> str: + if not instruction_prompt or not instruction_prompt.strip(): + return DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT + + return f'{instruction_prompt}\nYou must output a JSON array like ["question1", "question2", "question3"].' def get_format_instructions(self) -> str: return self._instruction_prompt diff --git a/api/core/llm_generator/prompts.py b/api/core/llm_generator/prompts.py index 855a00c9cd..3c6f8c468a 100644 --- a/api/core/llm_generator/prompts.py +++ b/api/core/llm_generator/prompts.py @@ -104,9 +104,6 @@ DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = ( '["question1","question2","question3"]\n' ) -DEFAULT_SUGGESTED_QUESTIONS_MAX_TOKENS = 256 -DEFAULT_SUGGESTED_QUESTIONS_TEMPERATURE = 0.0 - GENERATOR_QA_PROMPT = ( " The user will send a long text. Generate a Question and Answer pairs only using the knowledge" " in the long text. Please think step by step." diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 86d0e3baaa..457c888e33 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -1,5 +1,6 @@ import logging from collections.abc import Callable, Generator, Iterable, Mapping, Sequence +from copy import deepcopy from typing import IO, Any, Literal, Optional, ParamSpec, TypeVar, Union, cast, overload from configs import dify_config @@ -36,11 +37,13 @@ class ModelInstance: Model instance class. """ - def __init__(self, provider_model_bundle: ProviderModelBundle, model: str): + def __init__(self, provider_model_bundle: ProviderModelBundle, model: str, credentials: dict | None = None) -> None: self.provider_model_bundle = provider_model_bundle self.model_name = model self.provider = provider_model_bundle.configuration.provider.provider - self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model) + if credentials is None: + credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model) + self.credentials = credentials # Runtime LLM invocation fields. self.parameters: Mapping[str, Any] = {} self.stop: Sequence[str] = () @@ -434,8 +437,30 @@ class ModelInstance: class ModelManager: - def __init__(self, provider_manager: ProviderManager): + """Resolves :class:`ModelInstance` objects for a tenant and provider. + + When ``enable_credentials_cache`` is ``True``, resolved credentials for each + ``(tenant_id, provider, model_type, model)`` are stored in + ``_credentials_cache`` and reused. That can return **stale** credentials after + API keys or provider settings change, so a manager constructed with + ``enable_credentials_cache=True`` should not be kept for the lifetime of a + process or shared across unrelated work. Prefer a new manager per request, + workflow run, or similar bounded scope. + + The default is ``enable_credentials_cache=False``; in that mode the internal + credential cache is not populated, and each ``get_model_instance`` call + loads credentials from the current provider configuration. + """ + + def __init__( + self, + provider_manager: ProviderManager, + *, + enable_credentials_cache: bool = False, + ) -> None: self._provider_manager = provider_manager + self._credentials_cache: dict[tuple[str, str, str, str], Any] = {} + self._enable_credentials_cache = enable_credentials_cache @classmethod def for_tenant(cls, tenant_id: str, user_id: str | None = None) -> "ModelManager": @@ -463,8 +488,19 @@ class ModelManager: tenant_id=tenant_id, provider=provider, model_type=model_type ) - model_instance = ModelInstance(provider_model_bundle, model) - return model_instance + cred_cache_key = (tenant_id, provider, model_type.value, model) + + if cred_cache_key in self._credentials_cache: + return ModelInstance( + provider_model_bundle, + model, + deepcopy(self._credentials_cache[cred_cache_key]), + ) + + ret = ModelInstance(provider_model_bundle, model) + if self._enable_credentials_cache: + self._credentials_cache[cred_cache_key] = deepcopy(ret.credentials) + return ret def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str | None, str | None]: """ diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py index 242da520c1..392af351b6 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba.py +++ b/api/core/rag/datasource/keyword/jieba/jieba.py @@ -156,7 +156,8 @@ class Jieba(BaseKeyword): if dataset_keyword_table: keyword_table_dict = dataset_keyword_table.keyword_table_dict if keyword_table_dict: - return dict(keyword_table_dict["__data__"]["table"]) + data: Any = keyword_table_dict["__data__"] + return dict(data["table"]) else: keyword_data_source_type = dify_config.KEYWORD_DATA_SOURCE_TYPE dataset_keyword_table = DatasetKeywordTable( diff --git a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py index 1ca6303af6..2af8238cc4 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py +++ b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py @@ -109,7 +109,7 @@ class JiebaKeywordTableHandler: """Extract keywords with JIEBA tfidf.""" keywords = self._tfidf.extract_tags( sentence=text, - topK=max_keywords_per_chunk, + topK=max_keywords_per_chunk or 10, ) # jieba.analyse.extract_tags returns an untyped list when withFlag is False by default. keywords = cast(list[str], keywords) diff --git a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py index 426d1b67dc..dd17545c86 100644 --- a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py +++ b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py @@ -31,7 +31,7 @@ class FunctionCallMultiDatasetRouter: result: LLMResult = model_instance.invoke_llm( # pyright: ignore[reportCallIssue, reportArgumentType] prompt_messages=prompt_messages, tools=dataset_tools, - stream=False, + stream=False, # pyright: ignore[reportArgumentType] model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500}, ) usage = result.usage or LLMUsage.empty_usage() diff --git a/api/core/tools/utils/system_oauth_encryption.py b/api/core/tools/utils/system_encryption.py similarity index 57% rename from api/core/tools/utils/system_oauth_encryption.py rename to api/core/tools/utils/system_encryption.py index 6b7007842d..ca7e6a13fe 100644 --- a/api/core/tools/utils/system_oauth_encryption.py +++ b/api/core/tools/utils/system_encryption.py @@ -14,23 +14,23 @@ from configs import dify_config logger = logging.getLogger(__name__) -class OAuthEncryptionError(Exception): - """OAuth encryption/decryption specific error""" +class EncryptionError(Exception): + """Encryption/decryption specific error""" pass -class SystemOAuthEncrypter: +class SystemEncrypter: """ - A simple OAuth parameters encrypter using AES-CBC encryption. + A simple parameters encrypter using AES-CBC encryption. - This class provides methods to encrypt and decrypt OAuth parameters + This class provides methods to encrypt and decrypt parameters using AES-CBC mode with a key derived from the application's SECRET_KEY. """ def __init__(self, secret_key: str | None = None): """ - Initialize the OAuth encrypter. + Initialize the encrypter. Args: secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY @@ -43,19 +43,19 @@ class SystemOAuthEncrypter: # Generate a fixed 256-bit key using SHA-256 self.key = hashlib.sha256(secret_key.encode()).digest() - def encrypt_oauth_params(self, oauth_params: Mapping[str, Any]) -> str: + def encrypt_params(self, params: Mapping[str, Any]) -> str: """ - Encrypt OAuth parameters. + Encrypt parameters. Args: - oauth_params: OAuth parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"} + params: Parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"} Returns: Base64-encoded encrypted string Raises: - OAuthEncryptionError: If encryption fails - ValueError: If oauth_params is invalid + EncryptionError: If encryption fails + ValueError: If params is invalid """ try: @@ -66,7 +66,7 @@ class SystemOAuthEncrypter: cipher = AES.new(self.key, AES.MODE_CBC, iv) # Encrypt data - padded_data = pad(TypeAdapter(dict).dump_json(dict(oauth_params)), AES.block_size) + padded_data = pad(TypeAdapter(dict).dump_json(dict(params)), AES.block_size) encrypted_data = cipher.encrypt(padded_data) # Combine IV and encrypted data @@ -76,20 +76,20 @@ class SystemOAuthEncrypter: return base64.b64encode(combined).decode() except Exception as e: - raise OAuthEncryptionError(f"Encryption failed: {str(e)}") from e + raise EncryptionError(f"Encryption failed: {str(e)}") from e - def decrypt_oauth_params(self, encrypted_data: str) -> Mapping[str, Any]: + def decrypt_params(self, encrypted_data: str) -> Mapping[str, Any]: """ - Decrypt OAuth parameters. + Decrypt parameters. Args: encrypted_data: Base64-encoded encrypted string Returns: - Decrypted OAuth parameters dictionary + Decrypted parameters dictionary Raises: - OAuthEncryptionError: If decryption fails + EncryptionError: If decryption fails ValueError: If encrypted_data is invalid """ if not isinstance(encrypted_data, str): @@ -118,70 +118,70 @@ class SystemOAuthEncrypter: unpadded_data = unpad(decrypted_data, AES.block_size) # Parse JSON - oauth_params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data) + params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data) - if not isinstance(oauth_params, dict): + if not isinstance(params, dict): raise ValueError("Decrypted data is not a valid dictionary") - return oauth_params + return params except Exception as e: - raise OAuthEncryptionError(f"Decryption failed: {str(e)}") from e + raise EncryptionError(f"Decryption failed: {str(e)}") from e # Factory function for creating encrypter instances -def create_system_oauth_encrypter(secret_key: str | None = None) -> SystemOAuthEncrypter: +def create_system_encrypter(secret_key: str | None = None) -> SystemEncrypter: """ - Create an OAuth encrypter instance. + Create an encrypter instance. Args: secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY Returns: - SystemOAuthEncrypter instance + SystemEncrypter instance """ - return SystemOAuthEncrypter(secret_key=secret_key) + return SystemEncrypter(secret_key=secret_key) # Global encrypter instance (for backward compatibility) -_oauth_encrypter: SystemOAuthEncrypter | None = None +_encrypter: SystemEncrypter | None = None -def get_system_oauth_encrypter() -> SystemOAuthEncrypter: +def get_system_encrypter() -> SystemEncrypter: """ - Get the global OAuth encrypter instance. + Get the global encrypter instance. Returns: - SystemOAuthEncrypter instance + SystemEncrypter instance """ - global _oauth_encrypter - if _oauth_encrypter is None: - _oauth_encrypter = SystemOAuthEncrypter() - return _oauth_encrypter + global _encrypter + if _encrypter is None: + _encrypter = SystemEncrypter() + return _encrypter # Convenience functions for backward compatibility -def encrypt_system_oauth_params(oauth_params: Mapping[str, Any]) -> str: +def encrypt_system_params(params: Mapping[str, Any]) -> str: """ - Encrypt OAuth parameters using the global encrypter. + Encrypt parameters using the global encrypter. Args: - oauth_params: OAuth parameters dictionary + params: Parameters dictionary Returns: Base64-encoded encrypted string """ - return get_system_oauth_encrypter().encrypt_oauth_params(oauth_params) + return get_system_encrypter().encrypt_params(params) -def decrypt_system_oauth_params(encrypted_data: str) -> Mapping[str, Any]: +def decrypt_system_params(encrypted_data: str) -> Mapping[str, Any]: """ - Decrypt OAuth parameters using the global encrypter. + Decrypt parameters using the global encrypter. Args: encrypted_data: Base64-encoded encrypted string Returns: - Decrypted OAuth parameters dictionary + Decrypted parameters dictionary """ - return get_system_oauth_encrypter().decrypt_oauth_params(encrypted_data) + return get_system_encrypter().decrypt_params(encrypted_data) diff --git a/api/core/workflow/human_input_forms.py b/api/core/workflow/human_input_forms.py index f124b321d4..b02f69ec33 100644 --- a/api/core/workflow/human_input_forms.py +++ b/api/core/workflow/human_input_forms.py @@ -12,20 +12,16 @@ from collections.abc import Sequence from sqlalchemy import select from sqlalchemy.orm import Session +from core.workflow.human_input_policy import HumanInputSurface, get_preferred_form_token from extensions.ext_database import db from models.human_input import HumanInputFormRecipient, RecipientType -_FORM_TOKEN_PRIORITY = { - RecipientType.BACKSTAGE: 0, - RecipientType.CONSOLE: 1, - RecipientType.STANDALONE_WEB_APP: 2, -} - def load_form_tokens_by_form_id( form_ids: Sequence[str], *, session: Session | None = None, + surface: HumanInputSurface | None = None, ) -> dict[str, str]: """Load the preferred access token for each human input form.""" unique_form_ids = list(dict.fromkeys(form_ids)) @@ -33,23 +29,43 @@ def load_form_tokens_by_form_id( return {} if session is not None: - return _load_form_tokens_by_form_id(session, unique_form_ids) + return _load_form_tokens_by_form_id(session, unique_form_ids, surface=surface) with Session(bind=db.engine, expire_on_commit=False) as new_session: - return _load_form_tokens_by_form_id(new_session, unique_form_ids) + return _load_form_tokens_by_form_id(new_session, unique_form_ids, surface=surface) -def _load_form_tokens_by_form_id(session: Session, form_ids: Sequence[str]) -> dict[str, str]: - tokens_by_form_id: dict[str, tuple[int, str]] = {} +def _load_form_tokens_by_form_id( + session: Session, + form_ids: Sequence[str], + *, + surface: HumanInputSurface | None = None, +) -> dict[str, str]: + recipients_by_form_id: dict[str, list[tuple[RecipientType, str]]] = {} stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids)) for recipient in session.scalars(stmt): - priority = _FORM_TOKEN_PRIORITY.get(recipient.recipient_type) - if priority is None or not recipient.access_token: + if not recipient.access_token: continue + recipients_by_form_id.setdefault(recipient.form_id, []).append( + (recipient.recipient_type, recipient.access_token) + ) - candidate = (priority, recipient.access_token) - current = tokens_by_form_id.get(recipient.form_id) - if current is None or candidate[0] < current[0]: - tokens_by_form_id[recipient.form_id] = candidate + tokens_by_form_id: dict[str, str] = {} + for form_id, recipients in recipients_by_form_id.items(): + token = _get_surface_form_token(recipients, surface=surface) + if token is not None: + tokens_by_form_id[form_id] = token + return tokens_by_form_id - return {form_id: token for form_id, (_, token) in tokens_by_form_id.items()} + +def _get_surface_form_token( + recipients: Sequence[tuple[RecipientType, str]], + *, + surface: HumanInputSurface | None, +) -> str | None: + if surface == HumanInputSurface.SERVICE_API: + for recipient_type, token in recipients: + if recipient_type == RecipientType.STANDALONE_WEB_APP and token: + return token + + return get_preferred_form_token(recipients) diff --git a/api/core/workflow/human_input_policy.py b/api/core/workflow/human_input_policy.py new file mode 100644 index 0000000000..798eb8723f --- /dev/null +++ b/api/core/workflow/human_input_policy.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +from collections.abc import Mapping, Sequence +from enum import StrEnum +from typing import Any + +from graphon.entities.pause_reason import PauseReasonType +from models.human_input import RecipientType + + +class HumanInputSurface(StrEnum): + SERVICE_API = "service_api" + CONSOLE = "console" + + +# Service API is intentionally narrower than other surfaces: app-token callers +# should only be able to act on end-user web forms, not internal console flows. +_ALLOWED_RECIPIENT_TYPES_BY_SURFACE: dict[HumanInputSurface, frozenset[RecipientType]] = { + HumanInputSurface.SERVICE_API: frozenset({RecipientType.STANDALONE_WEB_APP}), + HumanInputSurface.CONSOLE: frozenset({RecipientType.CONSOLE, RecipientType.BACKSTAGE}), +} + +# A single HITL form can have multiple recipient records; this shared priority +# keeps every API surface consistent about which resume token to expose. +_RECIPIENT_TOKEN_PRIORITY: dict[RecipientType, int] = { + RecipientType.BACKSTAGE: 0, + RecipientType.CONSOLE: 1, + RecipientType.STANDALONE_WEB_APP: 2, +} + + +def is_recipient_type_allowed_for_surface( + recipient_type: RecipientType | None, + surface: HumanInputSurface, +) -> bool: + if recipient_type is None: + return False + return recipient_type in _ALLOWED_RECIPIENT_TYPES_BY_SURFACE[surface] + + +def get_preferred_form_token( + recipients: Sequence[tuple[RecipientType, str]], +) -> str | None: + chosen_token: str | None = None + chosen_priority: int | None = None + for recipient_type, token in recipients: + priority = _RECIPIENT_TOKEN_PRIORITY.get(recipient_type) + if priority is None or not token: + continue + if chosen_priority is None or priority < chosen_priority: + chosen_priority = priority + chosen_token = token + return chosen_token + + +def enrich_human_input_pause_reasons( + reasons: Sequence[Mapping[str, Any]], + *, + form_tokens_by_form_id: Mapping[str, str], + expiration_times_by_form_id: Mapping[str, int], +) -> list[dict[str, Any]]: + enriched: list[dict[str, Any]] = [] + for reason in reasons: + updated = dict(reason) + if updated.get("TYPE") == PauseReasonType.HUMAN_INPUT_REQUIRED: + form_id = updated.get("form_id") + if isinstance(form_id, str): + updated["form_token"] = form_tokens_by_form_id.get(form_id) + expiration_time = expiration_times_by_form_id.get(form_id) + if expiration_time is not None: + updated["expiration_time"] = expiration_time + enriched.append(updated) + return enriched diff --git a/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/data_exporter/test_traceclient.py b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/data_exporter/test_traceclient.py index 286dda419c..ac09060e9d 100644 --- a/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/data_exporter/test_traceclient.py +++ b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/data_exporter/test_traceclient.py @@ -225,8 +225,10 @@ class TestSpanBuilder: span = builder.build_span(span_data) assert isinstance(span, ReadableSpan) assert span.name == "test-span" + assert span.context is not None assert span.context.trace_id == 123 assert span.context.span_id == 456 + assert span.parent is not None assert span.parent.span_id == 789 assert span.resource == resource assert span.attributes == {"attr1": "val1"} diff --git a/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_aliyun_trace_entity.py b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_aliyun_trace_entity.py index 38d33dd21b..a6808fec0a 100644 --- a/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_aliyun_trace_entity.py +++ b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_aliyun_trace_entity.py @@ -64,12 +64,13 @@ class TestSpanData: def test_span_data_missing_required_fields(self): with pytest.raises(ValidationError): - SpanData( - trace_id=123, - # span_id missing - name="test_span", - start_time=1000, - end_time=2000, + SpanData.model_validate( + { + "trace_id": 123, + "name": "test_span", + "start_time": 1000, + "end_time": 2000, + } ) def test_span_data_arbitrary_types_allowed(self): diff --git a/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace.py b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace.py index c1b11c9186..fa00829653 100644 --- a/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace.py +++ b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace.py @@ -2,12 +2,14 @@ from __future__ import annotations from datetime import UTC, datetime from types import SimpleNamespace +from typing import cast from unittest.mock import MagicMock import dify_trace_aliyun.aliyun_trace as aliyun_trace_module import pytest from dify_trace_aliyun.aliyun_trace import AliyunDataTrace from dify_trace_aliyun.config import AliyunConfig +from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData, TraceMetadata from dify_trace_aliyun.entities.semconv import ( GEN_AI_COMPLETION, GEN_AI_INPUT_MESSAGE, @@ -44,7 +46,7 @@ class RecordingTraceClient: self.endpoint = endpoint self.added_spans: list[object] = [] - def add_span(self, span) -> None: + def add_span(self, span: object) -> None: self.added_spans.append(span) def api_check(self) -> bool: @@ -63,11 +65,35 @@ def _make_link(trace_id: int = 1, span_id: int = 2) -> Link: trace_id=trace_id, span_id=span_id, is_remote=False, - trace_flags=TraceFlags.SAMPLED, + trace_flags=TraceFlags(TraceFlags.SAMPLED), ) return Link(context) +def _make_trace_metadata( + trace_id: int = 1, + workflow_span_id: int = 2, + session_id: str = "s", + user_id: str = "u", + links: list[Link] | None = None, +) -> TraceMetadata: + return TraceMetadata( + trace_id=trace_id, + workflow_span_id=workflow_span_id, + session_id=session_id, + user_id=user_id, + links=[] if links is None else links, + ) + + +def _recording_trace_client(trace_instance: AliyunDataTrace) -> RecordingTraceClient: + return cast(RecordingTraceClient, trace_instance.trace_client) + + +def _recorded_span_data(trace_instance: AliyunDataTrace) -> list[SpanData]: + return cast(list[SpanData], _recording_trace_client(trace_instance).added_spans) + + def _make_workflow_trace_info(**overrides) -> WorkflowTraceInfo: defaults = { "workflow_id": "workflow-id", @@ -263,20 +289,20 @@ def test_workflow_trace_adds_workflow_and_node_spans(trace_instance: AliyunDataT trace_instance.workflow_trace(trace_info) add_workflow_span.assert_called_once() - passed_trace_metadata = add_workflow_span.call_args.args[1] + passed_trace_metadata = cast(TraceMetadata, add_workflow_span.call_args.args[1]) assert passed_trace_metadata.trace_id == 111 assert passed_trace_metadata.workflow_span_id == 222 assert passed_trace_metadata.session_id == "c" assert passed_trace_metadata.user_id == "u" assert passed_trace_metadata.links == [] - assert trace_instance.trace_client.added_spans == ["span-1", "span-2"] + assert _recording_trace_client(trace_instance).added_spans == ["span-1", "span-2"] def test_message_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace): trace_info = _make_message_trace_info(message_data=None) trace_instance.message_trace(trace_info) - assert trace_instance.trace_client.added_spans == [] + assert _recording_trace_client(trace_instance).added_spans == [] def test_message_trace_creates_message_and_llm_spans(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): @@ -302,8 +328,9 @@ def test_message_trace_creates_message_and_llm_spans(trace_instance: AliyunDataT ) trace_instance.message_trace(trace_info) - assert len(trace_instance.trace_client.added_spans) == 2 - message_span, llm_span = trace_instance.trace_client.added_spans + spans = _recorded_span_data(trace_instance) + assert len(spans) == 2 + message_span, llm_span = spans assert message_span.name == "message" assert message_span.trace_id == 10 @@ -324,7 +351,7 @@ def test_message_trace_creates_message_and_llm_spans(trace_instance: AliyunDataT def test_dataset_retrieval_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace): trace_info = _make_dataset_retrieval_trace_info(message_data=None) trace_instance.dataset_retrieval_trace(trace_info) - assert trace_instance.trace_client.added_spans == [] + assert _recording_trace_client(trace_instance).added_spans == [] def test_dataset_retrieval_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): @@ -338,8 +365,9 @@ def test_dataset_retrieval_trace_creates_span(trace_instance: AliyunDataTrace, m monkeypatch.setattr(aliyun_trace_module, "extract_retrieval_documents", lambda _: [{"doc": "d"}]) trace_instance.dataset_retrieval_trace(_make_dataset_retrieval_trace_info(inputs="query")) - assert len(trace_instance.trace_client.added_spans) == 1 - span = trace_instance.trace_client.added_spans[0] + spans = _recorded_span_data(trace_instance) + assert len(spans) == 1 + span = spans[0] assert span.name == "dataset_retrieval" assert span.attributes[RETRIEVAL_QUERY] == "query" assert span.attributes[RETRIEVAL_DOCUMENT] == '[{"doc": "d"}]' @@ -348,7 +376,7 @@ def test_dataset_retrieval_trace_creates_span(trace_instance: AliyunDataTrace, m def test_tool_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace): trace_info = _make_tool_trace_info(message_data=None) trace_instance.tool_trace(trace_info) - assert trace_instance.trace_client.added_spans == [] + assert _recording_trace_client(trace_instance).added_spans == [] def test_tool_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): @@ -371,8 +399,9 @@ def test_tool_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: p ) ) - assert len(trace_instance.trace_client.added_spans) == 1 - span = trace_instance.trace_client.added_spans[0] + spans = _recorded_span_data(trace_instance) + assert len(spans) == 1 + span = spans[0] assert span.name == "my-tool" assert span.status == status assert span.attributes[TOOL_NAME] == "my-tool" @@ -409,7 +438,7 @@ def test_get_workflow_node_executions_builds_repo_and_fetches( def test_build_workflow_node_span_routes_llm_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): node_execution = MagicMock(spec=WorkflowNodeExecution) trace_info = _make_workflow_trace_info() - trace_metadata = MagicMock() + trace_metadata = _make_trace_metadata() monkeypatch.setattr(trace_instance, "build_workflow_llm_span", MagicMock(return_value="llm")) @@ -422,7 +451,7 @@ def test_build_workflow_node_span_routes_knowledge_retrieval_type( ): node_execution = MagicMock(spec=WorkflowNodeExecution) trace_info = _make_workflow_trace_info() - trace_metadata = MagicMock() + trace_metadata = _make_trace_metadata() monkeypatch.setattr(trace_instance, "build_workflow_retrieval_span", MagicMock(return_value="retrieval")) @@ -433,7 +462,7 @@ def test_build_workflow_node_span_routes_knowledge_retrieval_type( def test_build_workflow_node_span_routes_tool_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): node_execution = MagicMock(spec=WorkflowNodeExecution) trace_info = _make_workflow_trace_info() - trace_metadata = MagicMock() + trace_metadata = _make_trace_metadata() monkeypatch.setattr(trace_instance, "build_workflow_tool_span", MagicMock(return_value="tool")) @@ -444,7 +473,7 @@ def test_build_workflow_node_span_routes_tool_type(trace_instance: AliyunDataTra def test_build_workflow_node_span_routes_code_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): node_execution = MagicMock(spec=WorkflowNodeExecution) trace_info = _make_workflow_trace_info() - trace_metadata = MagicMock() + trace_metadata = _make_trace_metadata() monkeypatch.setattr(trace_instance, "build_workflow_task_span", MagicMock(return_value="task")) @@ -457,7 +486,7 @@ def test_build_workflow_node_span_handles_errors( ): node_execution = MagicMock(spec=WorkflowNodeExecution) trace_info = _make_workflow_trace_info() - trace_metadata = MagicMock() + trace_metadata = _make_trace_metadata() monkeypatch.setattr(trace_instance, "build_workflow_task_span", MagicMock(side_effect=RuntimeError("boom"))) node_execution.node_type = BuiltinNodeTypes.CODE @@ -472,7 +501,7 @@ def test_build_workflow_task_span(trace_instance: AliyunDataTrace, monkeypatch: status = Status(StatusCode.OK) monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status) - trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[]) + trace_metadata = _make_trace_metadata() node_execution = MagicMock(spec=WorkflowNodeExecution) node_execution.id = "node-id" node_execution.title = "title" @@ -494,7 +523,7 @@ def test_build_workflow_tool_span(trace_instance: AliyunDataTrace, monkeypatch: status = Status(StatusCode.OK) monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status) - trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[_make_link()]) + trace_metadata = _make_trace_metadata(links=[_make_link()]) node_execution = MagicMock(spec=WorkflowNodeExecution) node_execution.id = "node-id" node_execution.title = "my-tool" @@ -527,7 +556,7 @@ def test_build_workflow_retrieval_span(trace_instance: AliyunDataTrace, monkeypa aliyun_trace_module, "format_retrieval_documents", lambda docs: [{"formatted": True}] if docs else [] ) - trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[]) + trace_metadata = _make_trace_metadata() node_execution = MagicMock(spec=WorkflowNodeExecution) node_execution.id = "node-id" node_execution.title = "retrieval" @@ -556,7 +585,7 @@ def test_build_workflow_llm_span(trace_instance: AliyunDataTrace, monkeypatch: p monkeypatch.setattr(aliyun_trace_module, "format_input_messages", lambda _: "in") monkeypatch.setattr(aliyun_trace_module, "format_output_messages", lambda _: "out") - trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[]) + trace_metadata = _make_trace_metadata() node_execution = MagicMock(spec=WorkflowNodeExecution) node_execution.id = "node-id" node_execution.title = "llm" @@ -594,7 +623,7 @@ def test_add_workflow_span(trace_instance: AliyunDataTrace, monkeypatch: pytest. status = Status(StatusCode.OK) monkeypatch.setattr(aliyun_trace_module, "create_status_from_error", lambda _: status) - trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[]) + trace_metadata = _make_trace_metadata() # CASE 1: With message_id trace_info = _make_workflow_trace_info( @@ -602,9 +631,11 @@ def test_add_workflow_span(trace_instance: AliyunDataTrace, monkeypatch: pytest. ) trace_instance.add_workflow_span(trace_info, trace_metadata) - assert len(trace_instance.trace_client.added_spans) == 2 - message_span = trace_instance.trace_client.added_spans[0] - workflow_span = trace_instance.trace_client.added_spans[1] + client = _recording_trace_client(trace_instance) + spans = _recorded_span_data(trace_instance) + assert len(spans) == 2 + message_span = spans[0] + workflow_span = spans[1] assert message_span.name == "message" assert message_span.span_kind == SpanKind.SERVER @@ -614,13 +645,14 @@ def test_add_workflow_span(trace_instance: AliyunDataTrace, monkeypatch: pytest. assert workflow_span.span_kind == SpanKind.INTERNAL assert workflow_span.parent_span_id == 20 - trace_instance.trace_client.added_spans.clear() + client.added_spans.clear() # CASE 2: Without message_id trace_info_no_msg = _make_workflow_trace_info(message_id=None) trace_instance.add_workflow_span(trace_info_no_msg, trace_metadata) - assert len(trace_instance.trace_client.added_spans) == 1 - span = trace_instance.trace_client.added_spans[0] + spans = _recorded_span_data(trace_instance) + assert len(spans) == 1 + span = spans[0] assert span.name == "workflow" assert span.span_kind == SpanKind.SERVER assert span.parent_span_id is None @@ -641,7 +673,8 @@ def test_suggested_question_trace(trace_instance: AliyunDataTrace, monkeypatch: trace_info = _make_suggested_question_trace_info(suggested_question=["how?"]) trace_instance.suggested_question_trace(trace_info) - assert len(trace_instance.trace_client.added_spans) == 1 - span = trace_instance.trace_client.added_spans[0] + spans = _recorded_span_data(trace_instance) + assert len(spans) == 1 + span = spans[0] assert span.name == "suggested_question" assert span.attributes[GEN_AI_COMPLETION] == '["how?"]' diff --git a/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace_utils.py b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace_utils.py index a9e7b80c2a..1b97746dea 100644 --- a/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace_utils.py +++ b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace_utils.py @@ -1,4 +1,6 @@ import json +from collections.abc import Mapping +from typing import Any, cast from unittest.mock import MagicMock from dify_trace_aliyun.entities.semconv import ( @@ -170,7 +172,7 @@ def test_create_common_span_attributes(): def test_format_retrieval_documents(): # Not a list - assert format_retrieval_documents("not a list") == [] + assert format_retrieval_documents(cast(list[object], "not a list")) == [] # Valid list docs = [ @@ -211,7 +213,7 @@ def test_format_retrieval_documents(): def test_format_input_messages(): # Not a dict - assert format_input_messages(None) == serialize_json_data([]) + assert format_input_messages(cast(Mapping[str, Any], None)) == serialize_json_data([]) # No prompts assert format_input_messages({}) == serialize_json_data([]) @@ -244,7 +246,7 @@ def test_format_input_messages(): def test_format_output_messages(): # Not a dict - assert format_output_messages(None) == serialize_json_data([]) + assert format_output_messages(cast(Mapping[str, Any], None)) == serialize_json_data([]) # No text assert format_output_messages({"finish_reason": "stop"}) == serialize_json_data([]) diff --git a/api/providers/trace/trace-aliyun/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-aliyun/tests/unit_tests/test_config_entity.py index 1b24ee7421..8068ee1328 100644 --- a/api/providers/trace/trace-aliyun/tests/unit_tests/test_config_entity.py +++ b/api/providers/trace/trace-aliyun/tests/unit_tests/test_config_entity.py @@ -25,13 +25,13 @@ class TestAliyunConfig: def test_missing_required_fields(self): """Test that required fields are enforced""" with pytest.raises(ValidationError): - AliyunConfig() + AliyunConfig.model_validate({}) with pytest.raises(ValidationError): - AliyunConfig(license_key="test_license") + AliyunConfig.model_validate({"license_key": "test_license"}) with pytest.raises(ValidationError): - AliyunConfig(endpoint="https://tracing-analysis-dc-hz.aliyuncs.com") + AliyunConfig.model_validate({"endpoint": "https://tracing-analysis-dc-hz.aliyuncs.com"}) def test_app_name_validation_empty(self): """Test app_name validation with empty value""" diff --git a/api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py index b0691a87ea..e9ecc2e083 100644 --- a/api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py +++ b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py @@ -1,4 +1,5 @@ from datetime import UTC, datetime, timedelta +from typing import cast from unittest.mock import MagicMock, patch import pytest @@ -129,7 +130,7 @@ def test_set_span_status(): return "SilentErrorRepr" span.reset_mock() - set_span_status(span, SilentError()) + set_span_status(span, cast(Exception | str | None, SilentError())) assert span.add_event.call_args[1]["attributes"][OTELSpanAttributes.EXCEPTION_MESSAGE] == "SilentErrorRepr" diff --git a/api/providers/trace/trace-langfuse/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-langfuse/tests/unit_tests/test_config_entity.py index 103d888eef..0c3c3fc81e 100644 --- a/api/providers/trace/trace-langfuse/tests/unit_tests/test_config_entity.py +++ b/api/providers/trace/trace-langfuse/tests/unit_tests/test_config_entity.py @@ -28,13 +28,13 @@ class TestLangfuseConfig: def test_missing_required_fields(self): """Test that required fields are enforced""" with pytest.raises(ValidationError): - LangfuseConfig() + LangfuseConfig.model_validate({}) with pytest.raises(ValidationError): - LangfuseConfig(public_key="public") + LangfuseConfig.model_validate({"public_key": "public"}) with pytest.raises(ValidationError): - LangfuseConfig(secret_key="secret") + LangfuseConfig.model_validate({"secret_key": "secret"}) def test_host_validation_empty(self): """Test host validation with empty value""" diff --git a/api/providers/trace/trace-langfuse/tests/unit_tests/test_langfuse_trace.py b/api/providers/trace/trace-langfuse/tests/unit_tests/test_langfuse_trace.py index 0340ffb669..82d69b6180 100644 --- a/api/providers/trace/trace-langfuse/tests/unit_tests/test_langfuse_trace.py +++ b/api/providers/trace/trace-langfuse/tests/unit_tests/test_langfuse_trace.py @@ -2,6 +2,7 @@ from datetime import datetime, timedelta from types import SimpleNamespace +from typing import cast from unittest.mock import MagicMock, patch from dify_trace_langfuse.config import LangfuseConfig @@ -134,4 +135,4 @@ class TestLangFuseDataTraceCompletionStartTime: assert trace._get_completion_start_time(start_time, None) is None assert trace._get_completion_start_time(start_time, -1) is None - assert trace._get_completion_start_time(start_time, "invalid") is None + assert trace._get_completion_start_time(start_time, cast(float | int | None, "invalid")) is None diff --git a/api/providers/trace/trace-langsmith/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-langsmith/tests/unit_tests/test_config_entity.py index 37efaf69cf..bd226c9f1a 100644 --- a/api/providers/trace/trace-langsmith/tests/unit_tests/test_config_entity.py +++ b/api/providers/trace/trace-langsmith/tests/unit_tests/test_config_entity.py @@ -21,13 +21,13 @@ class TestLangSmithConfig: def test_missing_required_fields(self): """Test that required fields are enforced""" with pytest.raises(ValidationError): - LangSmithConfig() + LangSmithConfig.model_validate({}) with pytest.raises(ValidationError): - LangSmithConfig(api_key="key") + LangSmithConfig.model_validate({"api_key": "key"}) with pytest.raises(ValidationError): - LangSmithConfig(project="project") + LangSmithConfig.model_validate({"project": "project"}) def test_endpoint_validation_https_only(self): """Test endpoint validation only allows HTTPS""" diff --git a/api/providers/trace/trace-mlflow/tests/unit_tests/mlflow_trace/test_mlflow_trace.py b/api/providers/trace/trace-mlflow/tests/unit_tests/mlflow_trace/test_mlflow_trace.py index 20211456e3..46c9750a5d 100644 --- a/api/providers/trace/trace-mlflow/tests/unit_tests/mlflow_trace/test_mlflow_trace.py +++ b/api/providers/trace/trace-mlflow/tests/unit_tests/mlflow_trace/test_mlflow_trace.py @@ -599,7 +599,6 @@ class TestMessageTrace: span = MagicMock() mock_tracing["start"].return_value = span mock_tracing["set"].return_value = "token" - mock_db.session.query.return_value.where.return_value.first.return_value = None trace_instance.message_trace(_make_message_trace_info()) mock_tracing["start"].assert_called_once() @@ -609,7 +608,6 @@ class TestMessageTrace: span = MagicMock() mock_tracing["start"].return_value = span mock_tracing["set"].return_value = "token" - mock_db.session.query.return_value.where.return_value.first.return_value = None trace_info = _make_message_trace_info(error="something broke") trace_instance.message_trace(trace_info) @@ -620,7 +618,6 @@ class TestMessageTrace: span = MagicMock() mock_tracing["start"].return_value = span mock_tracing["set"].return_value = "token" - mock_db.session.query.return_value.where.return_value.first.return_value = None monkeypatch.setenv("FILES_URL", "http://files.test") file_data = SimpleNamespace(url="path/to/file.png") @@ -638,7 +635,6 @@ class TestMessageTrace: span = MagicMock() mock_tracing["start"].return_value = span mock_tracing["set"].return_value = "token" - mock_db.session.query.return_value.where.return_value.first.return_value = None trace_info = _make_message_trace_info(file_list=None, message_file_data=None) trace_instance.message_trace(trace_info) @@ -651,7 +647,6 @@ class TestMessageTrace: end_user = MagicMock() end_user.session_id = "session-xyz" - mock_db.session.query.return_value.where.return_value.first.return_value = end_user trace_info = _make_message_trace_info( metadata={"from_end_user_id": "eu-1", "conversation_id": "c1"}, @@ -664,7 +659,6 @@ class TestMessageTrace: span = MagicMock() mock_tracing["start"].return_value = span mock_tracing["set"].return_value = "token" - mock_db.session.query.return_value.where.return_value.first.return_value = None trace_info = _make_message_trace_info( metadata={"from_account_id": "acc-1"}, diff --git a/api/providers/trace/trace-opik/tests/unit_tests/test_opik_trace.py b/api/providers/trace/trace-opik/tests/unit_tests/test_opik_trace.py index fba290f5b8..2e0796c291 100644 --- a/api/providers/trace/trace-opik/tests/unit_tests/test_opik_trace.py +++ b/api/providers/trace/trace-opik/tests/unit_tests/test_opik_trace.py @@ -12,6 +12,7 @@ from __future__ import annotations import uuid from datetime import datetime +from typing import cast from unittest.mock import MagicMock, patch from dify_trace_opik.opik_trace import OpikDataTrace, _seed_to_uuid4, prepare_opik_uuid @@ -69,6 +70,14 @@ def _make_opik_trace_instance() -> OpikDataTrace: return instance +def _add_trace_mock(instance: OpikDataTrace) -> MagicMock: + return cast(MagicMock, instance.add_trace) + + +def _add_span_mock(instance: OpikDataTrace) -> MagicMock: + return cast(MagicMock, instance.add_span) + + # --------------------------------------------------------------------------- # _seed_to_uuid4 # --------------------------------------------------------------------------- @@ -155,21 +164,21 @@ class TestWorkflowTraceWithoutMessageId: def test_root_span_is_created(self): trace_info = _make_workflow_trace_info(message_id=None) instance = self._run(trace_info) - assert instance.add_span.called + assert _add_span_mock(instance).called def test_root_span_id_matches_expected(self): trace_info = _make_workflow_trace_info(message_id=None) instance = self._run(trace_info) expected = self._expected_root_span_id(trace_info) - root_span_kwargs = instance.add_span.call_args_list[0][0][0] + root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0] assert root_span_kwargs["id"] == expected def test_root_span_has_no_parent(self): trace_info = _make_workflow_trace_info(message_id=None) instance = self._run(trace_info) - root_span_kwargs = instance.add_span.call_args_list[0][0][0] + root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0] assert root_span_kwargs["parent_span_id"] is None def test_trace_name_is_workflow_trace(self): @@ -177,21 +186,21 @@ class TestWorkflowTraceWithoutMessageId: trace_info = _make_workflow_trace_info(message_id=None) instance = self._run(trace_info) - trace_kwargs = instance.add_trace.call_args_list[0][0][0] + trace_kwargs = _add_trace_mock(instance).call_args_list[0][0][0] assert trace_kwargs["name"] == TraceTaskName.WORKFLOW_TRACE def test_root_span_name_is_workflow_trace(self): trace_info = _make_workflow_trace_info(message_id=None) instance = self._run(trace_info) - root_span_kwargs = instance.add_span.call_args_list[0][0][0] + root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0] assert root_span_kwargs["name"] == TraceTaskName.WORKFLOW_TRACE def test_root_span_has_workflow_tag(self): trace_info = _make_workflow_trace_info(message_id=None) instance = self._run(trace_info) - root_span_kwargs = instance.add_span.call_args_list[0][0][0] + root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0] assert "workflow" in root_span_kwargs["tags"] def test_node_execution_spans_are_parented_to_root(self): @@ -214,8 +223,9 @@ class TestWorkflowTraceWithoutMessageId: instance = self._run(trace_info, node_executions=[node_exec]) # call_args_list[0] = root span, [1] = node execution span - assert instance.add_span.call_count == 2 - node_span_kwargs = instance.add_span.call_args_list[1][0][0] + add_span = _add_span_mock(instance) + assert add_span.call_count == 2 + node_span_kwargs = add_span.call_args_list[1][0][0] assert node_span_kwargs["parent_span_id"] == expected_root_span_id def test_node_span_not_parented_to_workflow_app_log_id(self): @@ -240,7 +250,7 @@ class TestWorkflowTraceWithoutMessageId: instance = self._run(trace_info, node_executions=[node_exec]) old_parent_id = prepare_opik_uuid(trace_info.start_time, trace_info.workflow_app_log_id) - node_span_kwargs = instance.add_span.call_args_list[1][0][0] + node_span_kwargs = _add_span_mock(instance).call_args_list[1][0][0] assert node_span_kwargs["parent_span_id"] != old_parent_id def test_root_span_id_differs_from_trace_id(self): @@ -283,7 +293,7 @@ class TestWorkflowTraceWithMessageId: trace_info = _make_workflow_trace_info(message_id=self._MESSAGE_ID) instance = self._run(trace_info) - trace_kwargs = instance.add_trace.call_args_list[0][0][0] + trace_kwargs = _add_trace_mock(instance).call_args_list[0][0][0] assert trace_kwargs["name"] == TraceTaskName.MESSAGE_TRACE def test_root_span_uses_workflow_run_id_directly(self): @@ -292,7 +302,7 @@ class TestWorkflowTraceWithMessageId: instance = self._run(trace_info) expected_root_span_id = prepare_opik_uuid(trace_info.start_time, trace_info.workflow_run_id) - root_span_kwargs = instance.add_span.call_args_list[0][0][0] + root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0] assert root_span_kwargs["id"] == expected_root_span_id def test_root_span_id_differs_from_no_message_id_case(self): @@ -326,5 +336,5 @@ class TestWorkflowTraceWithMessageId: instance = self._run(trace_info, node_executions=[node_exec]) - node_span_kwargs = instance.add_span.call_args_list[1][0][0] + node_span_kwargs = _add_span_mock(instance).call_args_list[1][0][0] assert node_span_kwargs["parent_span_id"] == expected_root_span_id diff --git a/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_client.py b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_client.py index 1e656e2462..3cd918f408 100644 --- a/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_client.py +++ b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_client.py @@ -5,6 +5,7 @@ from __future__ import annotations import sys import types from types import SimpleNamespace +from typing import Any, TypedDict, cast from unittest.mock import MagicMock import pytest @@ -12,7 +13,7 @@ from dify_trace_tencent import client as client_module from dify_trace_tencent.client import TencentTraceClient, _get_opentelemetry_sdk_version from dify_trace_tencent.entities.tencent_trace_entity import SpanData from opentelemetry.sdk.trace import Event -from opentelemetry.trace import Status, StatusCode +from opentelemetry.trace import SpanContext, Status, StatusCode, TraceFlags metric_reader_instances: list[DummyMetricReader] = [] meter_provider_instances: list[DummyMeterProvider] = [] @@ -80,6 +81,16 @@ class DummyJsonMetricExporterNoTemporality: self.kwargs = kwargs +class PatchedCoreComponents(TypedDict): + span_exporter: MagicMock + span_processor: MagicMock + tracer: MagicMock + span: MagicMock + tracer_provider: MagicMock + logger: MagicMock + trace_api: Any + + def _add_stub_modules(monkeypatch: pytest.MonkeyPatch) -> None: """Drop fake metric modules into sys.modules so the client imports resolve.""" @@ -118,7 +129,7 @@ def stub_metric_modules(monkeypatch: pytest.MonkeyPatch) -> None: @pytest.fixture(autouse=True) -def patch_core_components(monkeypatch: pytest.MonkeyPatch) -> dict[str, object]: +def patch_core_components(monkeypatch: pytest.MonkeyPatch) -> PatchedCoreComponents: span_exporter = MagicMock(name="span_exporter") monkeypatch.setattr(client_module, "OTLPSpanExporter", MagicMock(return_value=span_exporter)) @@ -168,6 +179,15 @@ def patch_core_components(monkeypatch: pytest.MonkeyPatch) -> dict[str, object]: } +def _make_span_context(trace_id: int = 1, span_id: int = 2) -> SpanContext: + return SpanContext( + trace_id=trace_id, + span_id=span_id, + is_remote=False, + trace_flags=TraceFlags(TraceFlags.SAMPLED), + ) + + def _build_client() -> TencentTraceClient: return TencentTraceClient( service_name="service", @@ -208,7 +228,7 @@ def test_resolve_grpc_target_parsable_variants(endpoint: str, expected: tuple[st def test_resolve_grpc_target_handles_errors() -> None: - assert TencentTraceClient._resolve_grpc_target(123) == ("localhost:4317", True, "localhost", 4317) + assert TencentTraceClient._resolve_grpc_target(cast(str, 123)) == ("localhost:4317", True, "localhost", 4317) @pytest.mark.parametrize( @@ -248,7 +268,7 @@ def test_record_methods_skip_when_histogram_missing() -> None: client.record_trace_duration(0.5) -def test_record_llm_duration_handles_exceptions(patch_core_components: dict[str, object]) -> None: +def test_record_llm_duration_handles_exceptions(patch_core_components: PatchedCoreComponents) -> None: client = _build_client() client.hist_llm_duration = MagicMock(name="hist_llm_duration") client.hist_llm_duration.record.side_effect = RuntimeError("boom") @@ -258,10 +278,11 @@ def test_record_llm_duration_handles_exceptions(patch_core_components: dict[str, logger.debug.assert_called() -def test_create_and_export_span_sets_attributes(patch_core_components: dict[str, object]) -> None: +def test_create_and_export_span_sets_attributes(patch_core_components: PatchedCoreComponents) -> None: client = _build_client() span = patch_core_components["span"] - span.get_span_context.return_value = "ctx" + ctx = _make_span_context(span_id=2) + span.get_span_context.return_value = ctx data = SpanData( trace_id=1, @@ -280,14 +301,15 @@ def test_create_and_export_span_sets_attributes(patch_core_components: dict[str, span.add_event.assert_called_once() span.set_status.assert_called_once() span.end.assert_called_once_with(end_time=20) - assert client.span_contexts[2] == "ctx" + assert client.span_contexts[2] == ctx -def test_create_and_export_span_uses_parent_context(patch_core_components: dict[str, object]) -> None: +def test_create_and_export_span_uses_parent_context(patch_core_components: PatchedCoreComponents) -> None: client = _build_client() - client.span_contexts[10] = "existing" + existing_context = _make_span_context(span_id=10) + client.span_contexts[10] = existing_context span = patch_core_components["span"] - span.get_span_context.return_value = "child" + span.get_span_context.return_value = _make_span_context(span_id=11) data = SpanData( trace_id=1, @@ -302,14 +324,14 @@ def test_create_and_export_span_uses_parent_context(patch_core_components: dict[ client._create_and_export_span(data) trace_api = patch_core_components["trace_api"] - trace_api.NonRecordingSpan.assert_called_once_with("existing") + trace_api.NonRecordingSpan.assert_called_once_with(existing_context) trace_api.set_span_in_context.assert_called_once() -def test_create_and_export_span_exception_logs_error(patch_core_components: dict[str, object]) -> None: +def test_create_and_export_span_exception_logs_error(patch_core_components: PatchedCoreComponents) -> None: client = _build_client() span = patch_core_components["span"] - span.get_span_context.return_value = "ctx" + span.get_span_context.return_value = _make_span_context(span_id=2) client.tracer.start_span.side_effect = RuntimeError("boom") client._create_and_export_span( @@ -385,7 +407,7 @@ def test_get_project_url() -> None: assert client.get_project_url() == "https://console.cloud.tencent.com/apm" -def test_shutdown_flushes_all_components(patch_core_components: dict[str, object]) -> None: +def test_shutdown_flushes_all_components(patch_core_components: PatchedCoreComponents) -> None: client = _build_client() span_processor = patch_core_components["span_processor"] tracer_provider = patch_core_components["tracer_provider"] @@ -401,10 +423,11 @@ def test_shutdown_flushes_all_components(patch_core_components: dict[str, object metric_reader.shutdown.assert_called_once() -def test_shutdown_logs_when_meter_provider_fails(patch_core_components: dict[str, object]) -> None: +def test_shutdown_logs_when_meter_provider_fails(patch_core_components: PatchedCoreComponents) -> None: client = _build_client() meter_provider = meter_provider_instances[-1] meter_provider.shutdown.side_effect = RuntimeError("boom") + assert client.metric_reader is not None client.metric_reader.shutdown.side_effect = RuntimeError("boom") client.shutdown() @@ -433,7 +456,7 @@ def test_metrics_initialization_failure_sets_histogram_attributes(monkeypatch: p assert client.metric_reader is None -def test_add_span_logs_exception(monkeypatch: pytest.MonkeyPatch, patch_core_components: dict[str, object]) -> None: +def test_add_span_logs_exception(monkeypatch: pytest.MonkeyPatch, patch_core_components: PatchedCoreComponents) -> None: client = _build_client() monkeypatch.setattr(client, "_create_and_export_span", MagicMock(side_effect=RuntimeError("boom"))) @@ -454,10 +477,10 @@ def test_add_span_logs_exception(monkeypatch: pytest.MonkeyPatch, patch_core_com logger.exception.assert_called_once() -def test_create_and_export_span_converts_attribute_types(patch_core_components: dict[str, object]) -> None: +def test_create_and_export_span_converts_attribute_types(patch_core_components: PatchedCoreComponents) -> None: client = _build_client() span = patch_core_components["span"] - span.get_span_context.return_value = "ctx" + span.get_span_context.return_value = _make_span_context(span_id=2) data = SpanData.model_construct( trace_id=1, @@ -485,7 +508,7 @@ def test_record_llm_duration_converts_attributes() -> None: hist_mock = MagicMock(name="hist_llm_duration") client.hist_llm_duration = hist_mock - client.record_llm_duration(0.3, {"foo": object(), "bar": 2}) + client.record_llm_duration(0.3, cast(dict[str, str], {"foo": object(), "bar": 2})) _, attrs = hist_mock.record.call_args.args assert isinstance(attrs["foo"], str) assert attrs["bar"] == 2 @@ -496,7 +519,7 @@ def test_record_trace_duration_converts_attributes() -> None: hist_mock = MagicMock(name="hist_trace_duration") client.hist_trace_duration = hist_mock - client.record_trace_duration(1.0, {"meta": object(), "ok": True}) + client.record_trace_duration(1.0, cast(dict[str, str], {"meta": object(), "ok": True})) _, attrs = hist_mock.record.call_args.args assert isinstance(attrs["meta"], str) assert attrs["ok"] is True @@ -512,7 +535,7 @@ def test_record_trace_duration_converts_attributes() -> None: ], ) def test_record_methods_handle_exceptions( - method: str, attr_name: str, args: tuple[object, ...], patch_core_components: dict[str, object] + method: str, attr_name: str, args: tuple[object, ...], patch_core_components: PatchedCoreComponents ) -> None: client = _build_client() hist_mock = MagicMock(name=attr_name) @@ -527,35 +550,38 @@ def test_record_methods_handle_exceptions( def test_metrics_initializes_grpc_metric_exporter() -> None: client = _build_client() metric_reader = metric_reader_instances[-1] + exporter = cast(DummyGrpcMetricExporter, metric_reader.exporter) - assert isinstance(metric_reader.exporter, DummyGrpcMetricExporter) + assert isinstance(exporter, DummyGrpcMetricExporter) assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000 - assert metric_reader.exporter.kwargs["endpoint"] == "trace.example.com:4317" - assert metric_reader.exporter.kwargs["insecure"] is False - assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token" + assert exporter.kwargs["endpoint"] == "trace.example.com:4317" + assert exporter.kwargs["insecure"] is False + assert cast(dict[str, dict[str, str]], exporter.kwargs)["headers"]["authorization"] == "Bearer token" def test_metrics_initializes_http_protobuf_metric_exporter(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/protobuf") client = _build_client() metric_reader = metric_reader_instances[-1] + exporter = cast(DummyHttpMetricExporter, metric_reader.exporter) - assert isinstance(metric_reader.exporter, DummyHttpMetricExporter) + assert isinstance(exporter, DummyHttpMetricExporter) assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000 - assert metric_reader.exporter.kwargs["endpoint"] == client.endpoint - assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token" + assert exporter.kwargs["endpoint"] == client.endpoint + assert cast(dict[str, dict[str, str]], exporter.kwargs)["headers"]["authorization"] == "Bearer token" def test_metrics_initializes_http_json_metric_exporter(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/json") client = _build_client() metric_reader = metric_reader_instances[-1] + exporter = cast(DummyJsonMetricExporter, metric_reader.exporter) - assert isinstance(metric_reader.exporter, DummyJsonMetricExporter) + assert isinstance(exporter, DummyJsonMetricExporter) assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000 - assert metric_reader.exporter.kwargs["endpoint"] == client.endpoint - assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token" - assert "preferred_temporality" in metric_reader.exporter.kwargs + assert exporter.kwargs["endpoint"] == client.endpoint + assert cast(dict[str, dict[str, str]], exporter.kwargs)["headers"]["authorization"] == "Bearer token" + assert "preferred_temporality" in exporter.kwargs def test_metrics_http_json_metric_exporter_falls_back_without_temporality(monkeypatch: pytest.MonkeyPatch) -> None: @@ -564,9 +590,10 @@ def test_metrics_http_json_metric_exporter_falls_back_without_temporality(monkey monkeypatch.setattr(exporter_module, "OTLPMetricExporter", DummyJsonMetricExporterNoTemporality) _ = _build_client() metric_reader = metric_reader_instances[-1] + exporter = cast(DummyJsonMetricExporterNoTemporality, metric_reader.exporter) - assert isinstance(metric_reader.exporter, DummyJsonMetricExporterNoTemporality) - assert "preferred_temporality" not in metric_reader.exporter.kwargs + assert isinstance(exporter, DummyJsonMetricExporterNoTemporality) + assert "preferred_temporality" not in exporter.kwargs def test_metrics_http_json_uses_http_fallback_when_no_json_exporter(monkeypatch: pytest.MonkeyPatch) -> None: diff --git a/api/providers/trace/trace-weave/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-weave/tests/unit_tests/test_config_entity.py index eeb1fe1d87..377c768198 100644 --- a/api/providers/trace/trace-weave/tests/unit_tests/test_config_entity.py +++ b/api/providers/trace/trace-weave/tests/unit_tests/test_config_entity.py @@ -31,13 +31,13 @@ class TestWeaveConfig: def test_missing_required_fields(self): """Test that required fields are enforced""" with pytest.raises(ValidationError): - WeaveConfig() + WeaveConfig.model_validate({}) with pytest.raises(ValidationError): - WeaveConfig(api_key="key") + WeaveConfig.model_validate({"api_key": "key"}) with pytest.raises(ValidationError): - WeaveConfig(project="project") + WeaveConfig.model_validate({"project": "project"}) def test_endpoint_validation_https_only(self): """Test endpoint validation only allows HTTPS""" diff --git a/api/providers/vdb/vdb-couchbase/src/dify_vdb_couchbase/couchbase_vector.py b/api/providers/vdb/vdb-couchbase/src/dify_vdb_couchbase/couchbase_vector.py index 815ac30c0b..bab176e285 100644 --- a/api/providers/vdb/vdb-couchbase/src/dify_vdb_couchbase/couchbase_vector.py +++ b/api/providers/vdb/vdb-couchbase/src/dify_vdb_couchbase/couchbase_vector.py @@ -59,7 +59,7 @@ class CouchbaseVector(BaseVector): auth = PasswordAuthenticator(config.user, config.password) options = ClusterOptions(auth) - self._cluster = Cluster(config.connection_string, options) + self._cluster = Cluster(config.connection_string, options) # pyright: ignore[reportArgumentType] self._bucket = self._cluster.bucket(config.bucket_name) self._scope = self._bucket.scope(config.scope_name) self._bucket_name = config.bucket_name @@ -306,7 +306,7 @@ class CouchbaseVector(BaseVector): def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: top_k = kwargs.get("top_k", 4) try: - CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query)) + CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query)) # pyright: ignore[reportCallIssue] search_iter = self._scope.search( self._collection_name + "_search", CBrequest, SearchOptions(limit=top_k, fields=["*"]) ) diff --git a/api/providers/vdb/vdb-milvus/src/dify_vdb_milvus/milvus_vector.py b/api/providers/vdb/vdb-milvus/src/dify_vdb_milvus/milvus_vector.py index 46f3224a95..823b877707 100644 --- a/api/providers/vdb/vdb-milvus/src/dify_vdb_milvus/milvus_vector.py +++ b/api/providers/vdb/vdb-milvus/src/dify_vdb_milvus/milvus_vector.py @@ -1,6 +1,6 @@ import json import logging -from typing import Any, TypedDict +from typing import Any, TypedDict, cast from packaging import version from pydantic import BaseModel, model_validator @@ -92,7 +92,7 @@ class MilvusVector(BaseVector): def _load_collection_fields(self, fields: list[str] | None = None): if fields is None: # Load collection fields from remote server - collection_info = self._client.describe_collection(self._collection_name) + collection_info = cast(dict[str, Any], self._client.describe_collection(self._collection_name)) fields = [field["name"] for field in collection_info["fields"]] # Since primary field is auto-id, no need to track it self._fields = [f for f in fields if f != Field.PRIMARY_KEY] @@ -106,7 +106,8 @@ class MilvusVector(BaseVector): return False try: - milvus_version = self._client.get_server_version() + milvus_version_raw = self._client.get_server_version() + milvus_version = milvus_version_raw if isinstance(milvus_version_raw, str) else str(milvus_version_raw) # Check if it's Zilliz Cloud - it supports full-text search with Milvus 2.5 compatibility if "Zilliz Cloud" in milvus_version: return True diff --git a/api/providers/vdb/vdb-oracle/src/dify_vdb_oracle/oraclevector.py b/api/providers/vdb/vdb-oracle/src/dify_vdb_oracle/oraclevector.py index 70377c82c8..5d9ab38529 100644 --- a/api/providers/vdb/vdb-oracle/src/dify_vdb_oracle/oraclevector.py +++ b/api/providers/vdb/vdb-oracle/src/dify_vdb_oracle/oraclevector.py @@ -3,7 +3,7 @@ import json import logging import re import uuid -from typing import Any +from typing import Any, TypedDict import jieba.posseg as pseg # type: ignore import numpy @@ -25,6 +25,18 @@ logger = logging.getLogger(__name__) oracledb.defaults.fetch_lobs = False +class _OraclePoolParams(TypedDict, total=False): + user: str + password: str + dsn: str + min: int + max: int + increment: int + config_dir: str | None + wallet_location: str | None + wallet_password: str | None + + class OracleVectorConfig(BaseModel): user: str password: str @@ -127,22 +139,18 @@ class OracleVector(BaseVector): return connection def _create_connection_pool(self, config: OracleVectorConfig): - pool_params = { - "user": config.user, - "password": config.password, - "dsn": config.dsn, - "min": 1, - "max": 5, - "increment": 1, - } + pool_params = _OraclePoolParams( + user=config.user, + password=config.password, + dsn=config.dsn, + min=1, + max=5, + increment=1, + ) if config.is_autonomous: - pool_params.update( - { - "config_dir": config.config_dir, - "wallet_location": config.wallet_location, - "wallet_password": config.wallet_password, - } - ) + pool_params["config_dir"] = config.config_dir + pool_params["wallet_location"] = config.wallet_location + pool_params["wallet_password"] = config.wallet_password return oracledb.create_pool(**pool_params) def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): diff --git a/api/repositories/sqlalchemy_api_workflow_run_repository.py b/api/repositories/sqlalchemy_api_workflow_run_repository.py index 474b200fc5..71a2554a60 100644 --- a/api/repositories/sqlalchemy_api_workflow_run_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py @@ -42,7 +42,7 @@ from libs.helper import convert_datetime_to_date from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.time_parser import get_time_threshold from models.enums import WorkflowRunTriggeredFrom -from models.human_input import HumanInputForm +from models.human_input import HumanInputForm, HumanInputFormRecipient from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun from repositories.api_workflow_run_repository import APIWorkflowRunRepository, RunsWithRelatedCountsDict from repositories.entities.workflow_pause import WorkflowPauseEntity @@ -63,6 +63,7 @@ class _WorkflowRunError(Exception): def _build_human_input_required_reason( reason_model: WorkflowPauseReason, form_model: HumanInputForm | None, + recipients: Sequence[HumanInputFormRecipient] = (), ) -> HumanInputRequired: form_content = "" inputs = [] @@ -89,7 +90,7 @@ def _build_human_input_required_reason( resolved_default_values = dict(definition.default_values) node_title = definition.node_title or node_title - return HumanInputRequired( + reason = HumanInputRequired( form_id=form_id, form_content=form_content, inputs=inputs, @@ -98,6 +99,7 @@ def _build_human_input_required_reason( node_title=node_title, resolved_default_values=resolved_default_values, ) + return reason class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): @@ -804,12 +806,23 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): form_stmt = select(HumanInputForm).where(HumanInputForm.id.in_(form_ids)) for form in session.scalars(form_stmt).all(): form_models[form.id] = form + recipients_by_form_id: dict[str, list[HumanInputFormRecipient]] = {} + if form_ids: + recipient_stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids)) + for recipient in session.scalars(recipient_stmt).all(): + recipients_by_form_id.setdefault(recipient.form_id, []).append(recipient) pause_reasons: list[PauseReason] = [] for reason in pause_reason_models: if reason.type_ == PauseReasonType.HUMAN_INPUT_REQUIRED: form_model = form_models.get(reason.form_id) - pause_reasons.append(_build_human_input_required_reason(reason, form_model)) + pause_reasons.append( + _build_human_input_required_reason( + reason, + form_model, + recipients_by_form_id.get(reason.form_id, ()), + ) + ) else: pause_reasons.append(reason.to_entity()) return pause_reasons diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py index 5e8c7aa337..8ff53d143b 100644 --- a/api/services/app_generate_service.py +++ b/api/services/app_generate_service.py @@ -162,6 +162,7 @@ class AppGenerateService: invoke_from=invoke_from, streaming=True, call_depth=0, + workflow_run_id=str(uuid.uuid4()), ) payload_json = payload.model_dump_json() @@ -183,6 +184,10 @@ class AppGenerateService: else: # Blocking mode: run synchronously and return JSON instead of SSE # Keep behaviour consistent with WORKFLOW blocking branch. + pause_config = PauseStateLayerConfig( + session_factory=session_factory.get_session_maker(), + state_owner_user_id=workflow.created_by, + ) advanced_generator = AdvancedChatAppGenerator() return rate_limit.generate( advanced_generator.convert_to_event_stream( @@ -194,6 +199,7 @@ class AppGenerateService: invoke_from=invoke_from, workflow_run_id=str(uuid.uuid4()), streaming=False, + pause_state_config=pause_config, ) ), request_id=request_id, diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py index 5040fcc7e3..bd7758f1c0 100644 --- a/api/services/enterprise/enterprise_service.py +++ b/api/services/enterprise/enterprise_service.py @@ -5,6 +5,7 @@ import uuid from datetime import datetime from typing import TYPE_CHECKING +from cachetools.func import ttl_cache from pydantic import BaseModel, ConfigDict, Field, model_validator from configs import dify_config @@ -99,6 +100,7 @@ def try_join_default_workspace(account_id: str) -> None: class EnterpriseService: @classmethod + @ttl_cache(ttl=5) def get_info(cls): return EnterpriseRequest.send_request("GET", "/info") diff --git a/api/services/feature_service.py b/api/services/feature_service.py index e18eb096c9..38518378f7 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -177,6 +177,7 @@ class SystemFeatureModel(BaseModel): enable_change_email: bool = True plugin_manager: PluginManagerModel = PluginManagerModel() trial_models: list[str] = [] + enable_creators_platform: bool = False enable_trial_app: bool = False enable_explore_banner: bool = False @@ -241,6 +242,9 @@ class FeatureService: if dify_config.MARKETPLACE_ENABLED: system_features.enable_marketplace = True + if dify_config.CREATORS_PLATFORM_FEATURES_ENABLED: + system_features.enable_creators_platform = True + return system_features @classmethod diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 7bd056b8a0..b8242ab3a5 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -26,7 +26,7 @@ from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_manager import ToolManager from core.tools.utils.encryption import create_provider_encrypter -from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params +from core.tools.utils.system_encryption import decrypt_system_params from extensions.ext_database import db from extensions.ext_redis import redis_client from models.provider_ids import ToolProviderID @@ -521,7 +521,7 @@ class BuiltinToolManageService: ) if system_client: try: - oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params) + oauth_params = decrypt_system_params(system_client.encrypted_oauth_params) except Exception as e: raise ValueError(f"Error decrypting system oauth params: {e}") diff --git a/api/services/trigger/trigger_provider_service.py b/api/services/trigger/trigger_provider_service.py index 6e14d996ea..b8a76e4945 100644 --- a/api/services/trigger/trigger_provider_service.py +++ b/api/services/trigger/trigger_provider_service.py @@ -14,7 +14,7 @@ from core.helper.provider_cache import NoOpProviderCredentialCache from core.helper.provider_encryption import ProviderConfigEncrypter, create_provider_encrypter from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.impl.oauth import OAuthHandler -from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params +from core.tools.utils.system_encryption import decrypt_system_params from core.trigger.entities.api_entities import ( TriggerProviderApiEntity, TriggerProviderSubscriptionApiEntity, @@ -635,7 +635,7 @@ class TriggerProviderService: if system_client: try: - oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params) + oauth_params = decrypt_system_params(system_client.encrypted_oauth_params) except Exception as e: raise ValueError(f"Error decrypting system oauth params: {e}") diff --git a/api/services/workflow_event_snapshot_service.py b/api/services/workflow_event_snapshot_service.py index 5fca444723..94f88f8c49 100644 --- a/api/services/workflow_event_snapshot_service.py +++ b/api/services/workflow_event_snapshot_service.py @@ -14,6 +14,7 @@ from sqlalchemy.orm import Session, sessionmaker from core.app.apps.message_generator import MessageGenerator from core.app.entities.task_entities import ( + HumanInputRequiredResponse, MessageReplaceStreamResponse, NodeFinishStreamResponse, NodeStartStreamResponse, @@ -22,10 +23,14 @@ from core.app.entities.task_entities import ( WorkflowStartStreamResponse, ) from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext +from core.workflow.human_input_forms import load_form_tokens_by_form_id +from core.workflow.human_input_policy import HumanInputSurface, enrich_human_input_pause_reasons from graphon.entities import WorkflowStartReason +from graphon.entities.pause_reason import PauseReasonType from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus from graphon.runtime import GraphRuntimeState from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter +from models.human_input import HumanInputForm from models.model import AppMode, Message from models.workflow import WorkflowNodeExecutionTriggeredFrom, WorkflowRun from repositories.api_workflow_node_execution_repository import WorkflowNodeExecutionSnapshot @@ -59,8 +64,10 @@ def build_workflow_event_stream( tenant_id: str, app_id: str, session_maker: sessionmaker[Session], + human_input_surface: HumanInputSurface | None = None, idle_timeout: float = 300, ping_interval: float = 10.0, + close_on_pause: bool = True, ) -> Generator[Mapping[str, Any] | str, None, None]: topic = MessageGenerator.get_response_topic(app_mode, workflow_run.id) workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) @@ -115,13 +122,15 @@ def build_workflow_event_stream( message_context=message_context, pause_entity=pause_entity, resumption_context=resumption_context, + session_maker=session_maker, + human_input_surface=human_input_surface, ) for event in snapshot_events: last_msg_time = time.time() last_ping_time = last_msg_time yield event - if _is_terminal_event(event, include_paused=True): + if _is_terminal_event(event, close_on_pause=close_on_pause): return while True: @@ -146,7 +155,7 @@ def build_workflow_event_stream( last_msg_time = time.time() last_ping_time = last_msg_time yield event - if _is_terminal_event(event, include_paused=True): + if _is_terminal_event(event, close_on_pause=close_on_pause): return finally: buffer_state.stop_event.set() @@ -207,6 +216,8 @@ def _build_snapshot_events( message_context: MessageContext | None, pause_entity: WorkflowPauseEntity | None, resumption_context: WorkflowResumptionContext | None, + session_maker: sessionmaker[Session] | None = None, + human_input_surface: HumanInputSurface | None = None, ) -> list[Mapping[str, Any]]: events: list[Mapping[str, Any]] = [] @@ -241,12 +252,24 @@ def _build_snapshot_events( events.append(node_finished) if workflow_run.status == WorkflowExecutionStatus.PAUSED and pause_entity is not None: + for human_input_event in _build_human_input_required_events( + workflow_run_id=workflow_run.id, + task_id=task_id, + pause_entity=pause_entity, + session_maker=session_maker, + human_input_surface=human_input_surface, + ): + _apply_message_context(human_input_event, message_context) + events.append(human_input_event) + pause_event = _build_pause_event( workflow_run=workflow_run, workflow_run_id=workflow_run.id, task_id=task_id, pause_entity=pause_entity, resumption_context=resumption_context, + session_maker=session_maker, + human_input_surface=human_input_surface, ) if pause_event is not None: _apply_message_context(pause_event, message_context) @@ -314,6 +337,97 @@ def _build_node_started_event( return response.to_ignore_detail_dict() +def _build_human_input_required_events( + *, + workflow_run_id: str, + task_id: str, + pause_entity: WorkflowPauseEntity, + session_maker: sessionmaker[Session] | None, + human_input_surface: HumanInputSurface | None, +) -> list[dict[str, Any]]: + reasons = [reason.model_dump(mode="json") for reason in pause_entity.get_pause_reasons()] + human_input_form_ids = [ + form_id + for reason in reasons + if reason.get("TYPE") == PauseReasonType.HUMAN_INPUT_REQUIRED + for form_id in [reason.get("form_id")] + if isinstance(form_id, str) + ] + + expiration_times_by_form_id: dict[str, int] = {} + display_in_ui_by_form_id: dict[str, bool] = {} + form_tokens_by_form_id: dict[str, str] = {} + if human_input_form_ids and session_maker is not None: + stmt = select(HumanInputForm.id, HumanInputForm.expiration_time, HumanInputForm.form_definition).where( + HumanInputForm.id.in_(human_input_form_ids) + ) + with session_maker() as session: + for form_id, expiration_time, form_definition in session.execute(stmt): + expiration_times_by_form_id[str(form_id)] = int(expiration_time.timestamp()) + try: + definition_payload = json.loads(form_definition) if form_definition else {} + except (TypeError, json.JSONDecodeError): + definition_payload = {} + display_in_ui_by_form_id[str(form_id)] = bool(definition_payload.get("display_in_ui")) + form_tokens_by_form_id = load_form_tokens_by_form_id( + human_input_form_ids, + session=session, + surface=human_input_surface, + ) + + events: list[dict[str, Any]] = [] + for reason in reasons: + if reason.get("TYPE") != PauseReasonType.HUMAN_INPUT_REQUIRED: + continue + + form_id_raw = reason.get("form_id") + node_id_raw = reason.get("node_id") + node_title_raw = reason.get("node_title") + form_content_raw = reason.get("form_content") + if not isinstance(form_id_raw, str): + continue + if not isinstance(node_id_raw, str): + continue + if not isinstance(node_title_raw, str): + continue + if not isinstance(form_content_raw, str): + continue + form_id = form_id_raw + node_id = node_id_raw + node_title = node_title_raw + form_content = form_content_raw + + inputs = reason.get("inputs") + actions = reason.get("actions") + resolved_default_values = reason.get("resolved_default_values") + + expiration_time = expiration_times_by_form_id.get(form_id) + if expiration_time is None: + continue + + response = HumanInputRequiredResponse( + task_id=task_id, + workflow_run_id=workflow_run_id, + data=HumanInputRequiredResponse.Data( + form_id=form_id, + node_id=node_id, + node_title=node_title, + form_content=form_content, + inputs=inputs if isinstance(inputs, list) else [], + actions=actions if isinstance(actions, list) else [], + display_in_ui=display_in_ui_by_form_id.get(form_id, False), + form_token=form_tokens_by_form_id.get(form_id), + resolved_default_values=(resolved_default_values if isinstance(resolved_default_values, dict) else {}), + expiration_time=expiration_time, + ), + ) + payload = response.model_dump(mode="json") + payload["event"] = response.event.value + events.append(payload) + + return events + + def _build_node_finished_event( *, workflow_run_id: str, @@ -356,6 +470,8 @@ def _build_pause_event( task_id: str, pause_entity: WorkflowPauseEntity, resumption_context: WorkflowResumptionContext | None, + session_maker: sessionmaker[Session] | None, + human_input_surface: HumanInputSurface | None = None, ) -> dict[str, Any] | None: paused_nodes: list[str] = [] outputs: dict[str, Any] = {} @@ -365,6 +481,36 @@ def _build_pause_event( outputs = dict(WorkflowRuntimeTypeConverter().to_json_encodable(state.outputs or {})) reasons = [reason.model_dump(mode="json") for reason in pause_entity.get_pause_reasons()] + human_input_form_ids = [ + form_id + for reason in reasons + if reason.get("TYPE") == PauseReasonType.HUMAN_INPUT_REQUIRED + for form_id in [reason.get("form_id")] + if isinstance(form_id, str) + ] + form_tokens_by_form_id: dict[str, str] = {} + expiration_times_by_form_id: dict[str, int] = {} + if human_input_form_ids and session_maker is not None: + with session_maker() as session: + form_tokens_by_form_id = load_form_tokens_by_form_id( + human_input_form_ids, + session=session, + surface=human_input_surface, + ) + stmt = select(HumanInputForm.id, HumanInputForm.expiration_time).where( + HumanInputForm.id.in_(human_input_form_ids) + ) + for row in session.execute(stmt): + form_id, expiration_time, *_rest = row + expiration_times_by_form_id[str(form_id)] = int(expiration_time.timestamp()) + # Reconnect paths must preserve the same pause-reason contract as live streams; + # otherwise clients see schema drift after resume. + reasons = enrich_human_input_pause_reasons( + reasons, + form_tokens_by_form_id=form_tokens_by_form_id, + expiration_times_by_form_id=expiration_times_by_form_id, + ) + response = WorkflowPauseStreamResponse( task_id=task_id, workflow_run_id=workflow_run_id, @@ -449,12 +595,19 @@ def _parse_event_message(message: bytes) -> Mapping[str, Any] | None: return event -def _is_terminal_event(event: Mapping[str, Any] | str, include_paused=False) -> bool: +def _is_terminal_event( + event: Mapping[str, Any] | str, + close_on_pause: bool = True, + *, + include_paused: bool | None = None, +) -> bool: + if include_paused is not None: + close_on_pause = include_paused if not isinstance(event, Mapping): return False event_type = event.get("event") if event_type == StreamEvent.WORKFLOW_FINISHED.value: return True - if include_paused: + if close_on_pause: return event_type == StreamEvent.WORKFLOW_PAUSED.value return False diff --git a/api/tasks/app_generate/workflow_execute_task.py b/api/tasks/app_generate/workflow_execute_task.py index c22e7e9918..5ceeb302c8 100644 --- a/api/tasks/app_generate/workflow_execute_task.py +++ b/api/tasks/app_generate/workflow_execute_task.py @@ -399,6 +399,8 @@ def _resume_advanced_chat( workflow_run_id: str, workflow_run: WorkflowRun, ) -> None: + resumed_generate_entity = generate_entity.model_copy(update={"stream": True}) + try: triggered_from = WorkflowRunTriggeredFrom(workflow_run.triggered_from) except ValueError: @@ -426,7 +428,7 @@ def _resume_advanced_chat( user=user, conversation=conversation, message=message, - application_generate_entity=generate_entity, + application_generate_entity=resumed_generate_entity, workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, graph_runtime_state=graph_runtime_state, @@ -436,9 +438,8 @@ def _resume_advanced_chat( logger.exception("Failed to resume chatflow execution for workflow run %s", workflow_run_id) raise - if generate_entity.stream: - assert isinstance(response, Generator) - _publish_streaming_response(response, workflow_run_id, AppMode.ADVANCED_CHAT) + assert isinstance(response, Generator) + _publish_streaming_response(response, workflow_run_id, AppMode.ADVANCED_CHAT) def _resume_workflow( @@ -455,6 +456,8 @@ def _resume_workflow( workflow_run_repo, pause_entity, ) -> None: + resumed_generate_entity = generate_entity.model_copy(update={"stream": True}) + try: triggered_from = WorkflowRunTriggeredFrom(workflow_run.triggered_from) except ValueError: @@ -480,7 +483,7 @@ def _resume_workflow( app_model=app_model, workflow=workflow, user=user, - application_generate_entity=generate_entity, + application_generate_entity=resumed_generate_entity, graph_runtime_state=graph_runtime_state, workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, @@ -490,11 +493,18 @@ def _resume_workflow( logger.exception("Failed to resume workflow execution for workflow run %s", workflow_run_id) raise - if generate_entity.stream: - assert isinstance(response, Generator) - _publish_streaming_response(response, workflow_run_id, AppMode.WORKFLOW) + assert isinstance(response, Generator) + _publish_streaming_response(response, workflow_run_id, AppMode.WORKFLOW) - workflow_run_repo.delete_workflow_pause(pause_entity) + try: + workflow_run_repo.delete_workflow_pause(pause_entity) + except Exception as exc: + if exc.__class__.__name__ != "_WorkflowRunError" or "WorkflowPause not found" not in str(exc): + raise + logger.info( + "Skipped deleting workflow pause %s after resume because it was already replaced or removed", + pause_entity.id, + ) @shared_task(queue=WORKFLOW_BASED_APP_EXECUTION_QUEUE, name="resume_app_execution") diff --git a/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py b/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py index d10e5ed13c..3b5e822b90 100644 --- a/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py +++ b/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py @@ -171,35 +171,13 @@ class TestChatMessageApiPermissions: parent_message_id=None, ) - class MockQuery: - def __init__(self, model): - self.model = model - - def where(self, *args, **kwargs): - return self - - def first(self): - if getattr(self.model, "__name__", "") == "Conversation": - return mock_conversation - return None - - def order_by(self, *args, **kwargs): - return self - - def limit(self, *_): - return self - - def all(self): - if getattr(self.model, "__name__", "") == "Message": - return [mock_message] - return [] - mock_session = mock.Mock() - mock_session.query.side_effect = MockQuery - mock_session.scalar.return_value = False + mock_session.scalar.return_value = mock_conversation + mock_session.scalars.return_value.all.return_value = [mock_message] monkeypatch.setattr(message_api, "db", SimpleNamespace(session=mock_session)) monkeypatch.setattr(message_api, "current_user", mock_account) + monkeypatch.setattr(message_api, "attach_message_extra_contents", mock.Mock()) class DummyPagination: def __init__(self, data, limit, has_more): diff --git a/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py b/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py index f14b2c0ae5..635cfee2da 100644 --- a/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py +++ b/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py @@ -24,7 +24,6 @@ def _patch_wraps(): patch("controllers.console.wraps.dify_config", dify_settings), patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features), ): - mock_db.session.query.return_value.first.return_value = MagicMock() yield diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py index aebe87839c..d9828e19c5 100644 --- a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py +++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py @@ -2,6 +2,7 @@ from __future__ import annotations +import secrets from dataclasses import dataclass, field from datetime import datetime, timedelta from unittest.mock import Mock @@ -11,6 +12,7 @@ import pytest from sqlalchemy import Engine, delete, select from sqlalchemy.orm import Session, sessionmaker +from core.workflow.human_input_adapter import DeliveryMethodType from extensions.ext_storage import storage from graphon.entities import WorkflowExecution from graphon.entities.pause_reason import HumanInputRequired, PauseReasonType @@ -20,9 +22,11 @@ from graphon.nodes.human_input.enums import FormInputType, HumanInputFormStatus from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.human_input import ( + BackstageRecipientPayload, HumanInputDelivery, HumanInputForm, HumanInputFormRecipient, + RecipientType, ) from models.workflow import WorkflowAppLog, WorkflowAppLogCreatedFrom, WorkflowPause, WorkflowPauseReason, WorkflowRun from repositories.entities.workflow_pause import WorkflowPauseEntity @@ -628,12 +632,12 @@ class TestPrivateWorkflowPauseEntity: class TestBuildHumanInputRequiredReason: """Integration tests for _build_human_input_required_reason using real DB models.""" - def test_builds_reason_from_form_definition( + def test_prefers_standalone_web_app_token_when_available( self, db_session_with_containers: Session, test_scope: _TestScope, ) -> None: - """Build the graph pause reason from the stored form definition.""" + """Use the public standalone web-app token for service API payloads.""" expiration_time = naive_utc_now() form_definition = FormDefinition( @@ -660,6 +664,40 @@ class TestBuildHumanInputRequiredReason: db_session_with_containers.add(form_model) db_session_with_containers.flush() + delivery = HumanInputDelivery( + form_id=form_model.id, + delivery_method_type=DeliveryMethodType.WEBAPP, + channel_payload="{}", + ) + db_session_with_containers.add(delivery) + db_session_with_containers.flush() + + backstage_access_token = secrets.token_urlsafe(8) + backstage_recipient = HumanInputFormRecipient( + form_id=form_model.id, + delivery_id=delivery.id, + recipient_type=RecipientType.BACKSTAGE, + recipient_payload=BackstageRecipientPayload().model_dump_json(), + access_token=backstage_access_token, + ) + console_access_token = secrets.token_urlsafe(8) + console_recipient = HumanInputFormRecipient( + form_id=form_model.id, + delivery_id=delivery.id, + recipient_type=RecipientType.CONSOLE, + recipient_payload="{}", + access_token=console_access_token, + ) + web_app_access_token = secrets.token_urlsafe(8) + web_app_recipient = HumanInputFormRecipient( + form_id=form_model.id, + delivery_id=delivery.id, + recipient_type=RecipientType.STANDALONE_WEB_APP, + recipient_payload="{}", + access_token=web_app_access_token, + ) + db_session_with_containers.add_all([backstage_recipient, console_recipient, web_app_recipient]) + db_session_with_containers.flush() # Create a pause so the reason has a valid pause_id workflow_run = _create_workflow_run( db_session_with_containers, @@ -688,8 +726,15 @@ class TestBuildHumanInputRequiredReason: # Refresh to ensure we have DB-round-tripped objects db_session_with_containers.refresh(form_model) db_session_with_containers.refresh(reason_model) + db_session_with_containers.refresh(backstage_recipient) + db_session_with_containers.refresh(console_recipient) + db_session_with_containers.refresh(web_app_recipient) - reason = _build_human_input_required_reason(reason_model, form_model) + reason = _build_human_input_required_reason( + reason_model, + form_model, + [backstage_recipient, console_recipient, web_app_recipient], + ) assert isinstance(reason, HumanInputRequired) assert reason.node_title == "Ask Name" @@ -697,3 +742,92 @@ class TestBuildHumanInputRequiredReason: assert reason.inputs[0].output_variable_name == "name" assert reason.actions[0].id == "approve" assert reason.resolved_default_values == {"name": "Alice"} + assert not hasattr(reason, "form_token") + + def test_falls_back_to_console_token_when_web_app_token_missing( + self, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """Use the console token only when no standalone web-app token exists.""" + + expiration_time = naive_utc_now() + form_definition = FormDefinition( + form_content="content", + inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")], + user_actions=[UserAction(id="approve", title="Approve")], + rendered_content="rendered", + expiration_time=expiration_time, + default_values={"name": "Alice"}, + node_title="Ask Name", + display_in_ui=True, + ) + + form_model = HumanInputForm( + tenant_id=test_scope.tenant_id, + app_id=test_scope.app_id, + workflow_run_id=str(uuid4()), + node_id="node-1", + form_definition=form_definition.model_dump_json(), + rendered_content="rendered", + status=HumanInputFormStatus.WAITING, + expiration_time=expiration_time, + ) + db_session_with_containers.add(form_model) + db_session_with_containers.flush() + + delivery = HumanInputDelivery( + form_id=form_model.id, + delivery_method_type=DeliveryMethodType.WEBAPP, + channel_payload="{}", + ) + db_session_with_containers.add(delivery) + db_session_with_containers.flush() + + backstage_access_token = secrets.token_urlsafe(8) + backstage_recipient = HumanInputFormRecipient( + form_id=form_model.id, + delivery_id=delivery.id, + recipient_type=RecipientType.BACKSTAGE, + recipient_payload=BackstageRecipientPayload().model_dump_json(), + access_token=backstage_access_token, + ) + console_access_token = secrets.token_urlsafe(8) + console_recipient = HumanInputFormRecipient( + form_id=form_model.id, + delivery_id=delivery.id, + recipient_type=RecipientType.CONSOLE, + recipient_payload="{}", + access_token=console_access_token, + ) + db_session_with_containers.add_all([backstage_recipient, console_recipient]) + db_session_with_containers.flush() + + workflow_run = _create_workflow_run( + db_session_with_containers, + test_scope, + status=WorkflowExecutionStatus.RUNNING, + ) + pause = WorkflowPause( + workflow_id=test_scope.workflow_id, + workflow_run_id=workflow_run.id, + state_object_key=f"workflow-state-{uuid4()}.json", + ) + db_session_with_containers.add(pause) + db_session_with_containers.flush() + test_scope.state_keys.add(pause.state_object_key) + + reason_model = WorkflowPauseReason( + pause_id=pause.id, + type_=PauseReasonType.HUMAN_INPUT_REQUIRED, + form_id=form_model.id, + node_id="node-1", + message="", + ) + db_session_with_containers.add(reason_model) + db_session_with_containers.commit() + + reason = _build_human_input_required_reason(reason_model, form_model, [backstage_recipient, console_recipient]) + + assert isinstance(reason, HumanInputRequired) + assert not hasattr(reason, "form_token") diff --git a/api/tests/test_containers_integration_tests/services/test_feedback_service.py b/api/tests/test_containers_integration_tests/services/test_feedback_service.py index d82933ccb9..3dcd6586e2 100644 --- a/api/tests/test_containers_integration_tests/services/test_feedback_service.py +++ b/api/tests/test_containers_integration_tests/services/test_feedback_service.py @@ -13,6 +13,12 @@ from models.model import App, Conversation, Message from services.feedback_service import FeedbackService +def _execute_result(rows): + result = mock.Mock() + result.all.return_value = rows + return result + + class TestFeedbackService: """Test FeedbackService methods.""" @@ -81,25 +87,17 @@ class TestFeedbackService: def test_export_feedbacks_csv_format(self, mock_db_session, sample_data): """Test exporting feedback data in CSV format.""" - - # Setup mock query result - mock_query = mock.Mock() - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [ - ( - sample_data["user_feedback"], - sample_data["message"], - sample_data["conversation"], - sample_data["app"], - sample_data["user_feedback"].from_account, - ) - ] - - mock_db_session.execute.return_value = mock_query + mock_db_session.execute.return_value = _execute_result( + [ + ( + sample_data["user_feedback"], + sample_data["message"], + sample_data["conversation"], + sample_data["app"], + sample_data["user_feedback"].from_account, + ) + ] + ) # Test CSV export result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv") @@ -120,25 +118,17 @@ class TestFeedbackService: def test_export_feedbacks_json_format(self, mock_db_session, sample_data): """Test exporting feedback data in JSON format.""" - - # Setup mock query result - mock_query = mock.Mock() - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [ - ( - sample_data["admin_feedback"], - sample_data["message"], - sample_data["conversation"], - sample_data["app"], - sample_data["admin_feedback"].from_account, - ) - ] - - mock_db_session.execute.return_value = mock_query + mock_db_session.execute.return_value = _execute_result( + [ + ( + sample_data["admin_feedback"], + sample_data["message"], + sample_data["conversation"], + sample_data["app"], + sample_data["admin_feedback"].from_account, + ) + ] + ) # Test JSON export result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json") @@ -157,25 +147,17 @@ class TestFeedbackService: def test_export_feedbacks_with_filters(self, mock_db_session, sample_data): """Test exporting feedback with various filters.""" - - # Setup mock query result - mock_query = mock.Mock() - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [ - ( - sample_data["admin_feedback"], - sample_data["message"], - sample_data["conversation"], - sample_data["app"], - sample_data["admin_feedback"].from_account, - ) - ] - - mock_db_session.execute.return_value = mock_query + mock_db_session.execute.return_value = _execute_result( + [ + ( + sample_data["admin_feedback"], + sample_data["message"], + sample_data["conversation"], + sample_data["app"], + sample_data["admin_feedback"].from_account, + ) + ] + ) # Test with filters result = FeedbackService.export_feedbacks( @@ -193,17 +175,7 @@ class TestFeedbackService: def test_export_feedbacks_no_data(self, mock_db_session, sample_data): """Test exporting feedback when no data exists.""" - - # Setup mock query result with no data - mock_query = mock.Mock() - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [] - - mock_db_session.execute.return_value = mock_query + mock_db_session.execute.return_value = _execute_result([]) result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv") @@ -251,24 +223,17 @@ class TestFeedbackService: created_at=datetime(2024, 1, 1, 10, 0, 0), ) - # Setup mock query result - mock_query = mock.Mock() - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [ - ( - sample_data["user_feedback"], - long_message, - sample_data["conversation"], - sample_data["app"], - sample_data["user_feedback"].from_account, - ) - ] - - mock_db_session.execute.return_value = mock_query + mock_db_session.execute.return_value = _execute_result( + [ + ( + sample_data["user_feedback"], + long_message, + sample_data["conversation"], + sample_data["app"], + sample_data["user_feedback"].from_account, + ) + ] + ) # Test export result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json") @@ -309,24 +274,17 @@ class TestFeedbackService: created_at=datetime(2024, 1, 1, 10, 0, 0), ) - # Setup mock query result - mock_query = mock.Mock() - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [ - ( - chinese_feedback, - chinese_message, - sample_data["conversation"], - sample_data["app"], - None, # No account for user feedback - ) - ] - - mock_db_session.execute.return_value = mock_query + mock_db_session.execute.return_value = _execute_result( + [ + ( + chinese_feedback, + chinese_message, + sample_data["conversation"], + sample_data["app"], + None, + ) + ] + ) # Test export result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv") @@ -339,32 +297,24 @@ class TestFeedbackService: def test_export_feedbacks_emoji_ratings(self, mock_db_session, sample_data): """Test that rating emojis are properly formatted in export.""" - - # Setup mock query result with both like and dislike feedback - mock_query = mock.Mock() - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [ - ( - sample_data["user_feedback"], - sample_data["message"], - sample_data["conversation"], - sample_data["app"], - sample_data["user_feedback"].from_account, - ), - ( - sample_data["admin_feedback"], - sample_data["message"], - sample_data["conversation"], - sample_data["app"], - sample_data["admin_feedback"].from_account, - ), - ] - - mock_db_session.execute.return_value = mock_query + mock_db_session.execute.return_value = _execute_result( + [ + ( + sample_data["user_feedback"], + sample_data["message"], + sample_data["conversation"], + sample_data["app"], + sample_data["user_feedback"].from_account, + ), + ( + sample_data["admin_feedback"], + sample_data["message"], + sample_data["conversation"], + sample_data["app"], + sample_data["admin_feedback"].from_account, + ), + ] + ) # Test export result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json") diff --git a/api/tests/unit_tests/conftest.py b/api/tests/unit_tests/conftest.py index 55873b06a8..7174530e97 100644 --- a/api/tests/unit_tests/conftest.py +++ b/api/tests/unit_tests/conftest.py @@ -121,33 +121,32 @@ def _configure_session_factory(_unit_test_engine): configure_session_factory(_unit_test_engine, expire_on_commit=False) -def setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account): +def setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_owner): """ - Helper to set up the mock DB execute chain for tenant/account authentication. + Helper to stub the tenant-owner execute result for service API app authentication. - This configures the mock to return (tenant, account) for the - db.session.execute(select(...).join().join().where()).one_or_none() - query used by validate_app_token decorator. + The validate_app_token decorator currently resolves the active tenant owner + via db.session.execute(select(Tenant, Account)...).one_or_none(). Args: mock_db: The mocked db object mock_tenant: Mock tenant object to return - mock_account: Mock account object to return + mock_owner: Mock owner object to return from the execute result """ - mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_account) + mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_owner) -def setup_mock_dataset_tenant_query(mock_db, mock_tenant, mock_ta): +def setup_mock_dataset_owner_execute_result(mock_db, mock_tenant, mock_tenant_account_join): """ - Helper to set up the mock DB execute chain for dataset tenant authentication. + Helper to stub the tenant-owner execute result for dataset token authentication. - This configures the mock to return (tenant, tenant_account) for the - db.session.execute(select(...).where().where().where().where()).one_or_none() - query used by validate_dataset_token decorator. + The validate_dataset_token decorator currently resolves the owner mapping via + db.session.execute(select(Tenant, TenantAccountJoin)...).one_or_none(), and + then loads the Account separately via db.session.get(...). Args: mock_db: The mocked db object mock_tenant: Mock tenant object to return - mock_ta: Mock tenant account object to return + mock_tenant_account_join: Mock tenant-account join object to return """ - mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_ta) + mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_tenant_account_join) diff --git a/api/tests/unit_tests/controllers/console/app/test_annotation_security.py b/api/tests/unit_tests/controllers/console/app/test_annotation_security.py index 9f1ff9b40f..bfa4048191 100644 --- a/api/tests/unit_tests/controllers/console/app/test_annotation_security.py +++ b/api/tests/unit_tests/controllers/console/app/test_annotation_security.py @@ -208,8 +208,6 @@ class TestAnnotationImportServiceValidation: file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv") - mock_db_session.query.return_value.where.return_value.first.return_value = mock_app - with patch("services.annotation_service.current_account_with_tenant") as mock_auth: mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id") @@ -230,8 +228,6 @@ class TestAnnotationImportServiceValidation: file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv") - mock_db_session.query.return_value.where.return_value.first.return_value = mock_app - with patch("services.annotation_service.current_account_with_tenant") as mock_auth: mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id") @@ -248,8 +244,6 @@ class TestAnnotationImportServiceValidation: csv_content = 'invalid,csv,format\nwith,unbalanced,quotes,and"stuff' file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv") - mock_db_session.query.return_value.where.return_value.first.return_value = mock_app - with ( patch("services.annotation_service.current_account_with_tenant") as mock_auth, patch("services.annotation_service.pd.read_csv", side_effect=ParserError("malformed CSV")), @@ -269,8 +263,6 @@ class TestAnnotationImportServiceValidation: file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv") - mock_db_session.query.return_value.where.return_value.first.return_value = mock_app - with patch("services.annotation_service.current_account_with_tenant") as mock_auth: mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id") diff --git a/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py b/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py index cb4fe40944..17bee94c52 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py +++ b/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py @@ -43,7 +43,6 @@ class TestAuthenticationSecurity: mock_is_rate_limit.return_value = False mock_get_invitation.return_value = None mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Invalid email or password.") - mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists mock_features.return_value.is_allow_register = True # Act @@ -76,7 +75,6 @@ class TestAuthenticationSecurity: mock_is_rate_limit.return_value = False mock_get_invitation.return_value = None mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Wrong password") - mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists # Act with self.app.test_request_context( @@ -109,7 +107,6 @@ class TestAuthenticationSecurity: mock_is_rate_limit.return_value = False mock_get_invitation.return_value = None mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Invalid email or password.") - mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists mock_features.return_value.is_allow_register = False # Act @@ -135,7 +132,6 @@ class TestAuthenticationSecurity: def test_reset_password_with_existing_account(self, mock_send_email, mock_get_user, mock_features, mock_db): """Test that reset password returns success with token for existing accounts.""" # Mock the setup check - mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists # Test with existing account mock_get_user.return_value = MagicMock(email="existing@example.com") diff --git a/api/tests/unit_tests/controllers/console/auth/test_email_verification.py b/api/tests/unit_tests/controllers/console/auth/test_email_verification.py index 9929a71120..b7bc73da5f 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_email_verification.py +++ b/api/tests/unit_tests/controllers/console/auth/test_email_verification.py @@ -65,7 +65,6 @@ class TestEmailCodeLoginSendEmailApi: - IP rate limiting is checked """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_ip_limit.return_value = False mock_get_user.return_value = mock_account mock_send_email.return_value = "email_token_123" @@ -98,7 +97,6 @@ class TestEmailCodeLoginSendEmailApi: - Registration is allowed by system features """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_ip_limit.return_value = False mock_get_user.return_value = None mock_get_features.return_value.is_allow_register = True @@ -130,7 +128,6 @@ class TestEmailCodeLoginSendEmailApi: - Registration is blocked by system features """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_ip_limit.return_value = False mock_get_user.return_value = None mock_get_features.return_value.is_allow_register = False @@ -152,7 +149,6 @@ class TestEmailCodeLoginSendEmailApi: - Prevents spam and abuse """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_ip_limit.return_value = True # Act & Assert @@ -172,7 +168,6 @@ class TestEmailCodeLoginSendEmailApi: - AccountInFreezeError is raised for frozen accounts """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_ip_limit.return_value = False mock_get_user.side_effect = AccountRegisterError("Account frozen") @@ -213,7 +208,6 @@ class TestEmailCodeLoginSendEmailApi: - Defaults to en-US when not specified """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_ip_limit.return_value = False mock_get_user.return_value = mock_account mock_send_email.return_value = "token" @@ -286,7 +280,6 @@ class TestEmailCodeLoginApi: - User is logged in with token pair """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "test@example.com", "code": "123456"} mock_get_user.return_value = mock_account mock_get_tenants.return_value = [MagicMock()] @@ -335,7 +328,6 @@ class TestEmailCodeLoginApi: - User is logged in after account creation """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "newuser@example.com", "code": "123456"} mock_get_user.return_value = None mock_create_account.return_value = mock_account @@ -369,7 +361,6 @@ class TestEmailCodeLoginApi: - InvalidTokenError is raised for invalid/expired tokens """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = None # Act & Assert @@ -392,7 +383,6 @@ class TestEmailCodeLoginApi: - InvalidEmailError is raised when email doesn't match token """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "original@example.com", "code": "123456"} # Act & Assert @@ -415,7 +405,6 @@ class TestEmailCodeLoginApi: - EmailCodeError is raised for wrong verification code """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "test@example.com", "code": "123456"} # Act & Assert @@ -453,7 +442,6 @@ class TestEmailCodeLoginApi: - User is added as owner of new workspace """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "test@example.com", "code": "123456"} mock_get_user.return_value = mock_account mock_get_tenants.return_value = [] @@ -496,7 +484,6 @@ class TestEmailCodeLoginApi: - WorkspacesLimitExceeded is raised when limit reached """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "test@example.com", "code": "123456"} mock_get_user.return_value = mock_account mock_get_tenants.return_value = [] @@ -538,7 +525,6 @@ class TestEmailCodeLoginApi: - NotAllowedCreateWorkspace is raised when creation disabled """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "test@example.com", "code": "123456"} mock_get_user.return_value = mock_account mock_get_tenants.return_value = [] diff --git a/api/tests/unit_tests/controllers/console/auth/test_login_logout.py b/api/tests/unit_tests/controllers/console/auth/test_login_logout.py index 0cf97da878..d089be8905 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_login_logout.py +++ b/api/tests/unit_tests/controllers/console/auth/test_login_logout.py @@ -110,7 +110,6 @@ class TestLoginApi: - Rate limit is reset after successful login """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_invitation.return_value = None mock_authenticate.return_value = mock_account @@ -162,7 +161,6 @@ class TestLoginApi: - Authentication proceeds with invitation token """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_invitation.return_value = {"data": {"email": "test@example.com"}} mock_authenticate.return_value = mock_account @@ -199,7 +197,6 @@ class TestLoginApi: - EmailPasswordLoginLimitError is raised when limit exceeded """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = True mock_get_invitation.return_value = None @@ -228,7 +225,6 @@ class TestLoginApi: - AccountInFreezeError is raised for frozen accounts """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_frozen.return_value = True # Act & Assert @@ -268,7 +264,6 @@ class TestLoginApi: - Generic error message prevents user enumeration """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_invitation.return_value = None mock_authenticate.side_effect = AccountPasswordError("Invalid password") @@ -305,7 +300,6 @@ class TestLoginApi: - Login is prevented even with valid credentials """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_invitation.return_value = None mock_authenticate.side_effect = AccountLoginError("Account is banned") @@ -351,7 +345,6 @@ class TestLoginApi: - User cannot login without an assigned workspace """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_invitation.return_value = None mock_authenticate.return_value = mock_account @@ -383,7 +376,6 @@ class TestLoginApi: - Security check prevents invitation token abuse """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_invitation.return_value = {"data": {"email": "invited@example.com"}} @@ -425,7 +417,6 @@ class TestLoginApi: mock_token_pair, ): """Test that login retries with lowercase email when uppercase lookup fails.""" - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_invitation.return_value = None mock_authenticate.side_effect = [AccountPasswordError("Invalid"), mock_account] @@ -459,7 +450,6 @@ class TestLoginApi: mock_db, app, ): - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_token_data.return_value = {"email": "User@Example.com", "code": "123456"} mock_get_account.side_effect = Unauthorized("Account is banned.") @@ -513,7 +503,6 @@ class TestLogoutApi: - Success response is returned """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_current_account.return_value = (mock_account, MagicMock()) # Act @@ -539,7 +528,6 @@ class TestLogoutApi: - Success response is returned """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() # Create a mock anonymous user that will pass isinstance check anonymous_user = MagicMock() mock_flask_login.AnonymousUserMixin = type("AnonymousUserMixin", (), {}) diff --git a/api/tests/unit_tests/controllers/console/billing/test_billing.py b/api/tests/unit_tests/controllers/console/billing/test_billing.py index c80758c857..810f1b94fc 100644 --- a/api/tests/unit_tests/controllers/console/billing/test_billing.py +++ b/api/tests/unit_tests/controllers/console/billing/test_billing.py @@ -46,7 +46,6 @@ class TestPartnerTenants: patch("libs.login.dify_config.LOGIN_DISABLED", False), patch("libs.login.check_csrf_token") as mock_csrf, ): - mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists mock_csrf.return_value = None yield {"db": mock_db, "csrf": mock_csrf} diff --git a/api/tests/unit_tests/controllers/console/tag/test_tags.py b/api/tests/unit_tests/controllers/console/tag/test_tags.py index 2be5a21f28..6405558bb4 100644 --- a/api/tests/unit_tests/controllers/console/tag/test_tags.py +++ b/api/tests/unit_tests/controllers/console/tag/test_tags.py @@ -8,8 +8,10 @@ from werkzeug.exceptions import Forbidden import controllers.console.tag.tags as module from controllers.console import console_ns from controllers.console.tag.tags import ( - TagBindingCreateApi, - TagBindingDeleteApi, + DeprecatedTagBindingCreateApi, + DeprecatedTagBindingRemoveApi, + TagBindingCollectionApi, + TagBindingItemApi, TagListApi, TagUpdateDeleteApi, ) @@ -205,9 +207,9 @@ class TestTagUpdateDeleteApi: assert status == 204 -class TestTagBindingCreateApi: +class TestTagBindingCollectionApi: def test_create_success(self, app, admin_user, payload_patch): - api = TagBindingCreateApi() + api = TagBindingCollectionApi() method = unwrap(api.post) payload = { @@ -232,7 +234,7 @@ class TestTagBindingCreateApi: assert result["result"] == "success" def test_create_forbidden(self, app, readonly_user, payload_patch): - api = TagBindingCreateApi() + api = TagBindingCollectionApi() method = unwrap(api.post) with app.test_request_context("/", json={}): @@ -247,9 +249,78 @@ class TestTagBindingCreateApi: method(api) -class TestTagBindingDeleteApi: +class TestDeprecatedTagBindingCreateApi: + def test_create_success(self, app, admin_user, payload_patch): + api = DeprecatedTagBindingCreateApi() + method = unwrap(api.post) + + payload = { + "tag_ids": ["tag-1"], + "target_id": "target-1", + "type": "knowledge", + } + + with app.test_request_context("/", json=payload): + with ( + patch( + "controllers.console.tag.tags.current_account_with_tenant", + return_value=(admin_user, None), + ), + payload_patch(payload), + patch("controllers.console.tag.tags.TagService.save_tag_binding") as save_mock, + ): + result, status = method(api) + + save_mock.assert_called_once() + assert status == 200 + assert result["result"] == "success" + + +class TestTagBindingItemApi: + def test_delete_success(self, app, admin_user, payload_patch): + api = TagBindingItemApi() + method = unwrap(api.delete) + + payload = { + "target_id": "target-1", + "type": "knowledge", + } + + with app.test_request_context("/", json=payload): + with ( + patch( + "controllers.console.tag.tags.current_account_with_tenant", + return_value=(admin_user, None), + ), + payload_patch(payload), + patch("controllers.console.tag.tags.TagService.delete_tag_binding") as delete_mock, + ): + result, status = method(api, "tag-1") + + delete_mock.assert_called_once() + delete_payload = delete_mock.call_args.args[0] + assert delete_payload.tag_id == "tag-1" + assert delete_payload.target_id == "target-1" + assert delete_payload.type == TagType.KNOWLEDGE + assert status == 200 + assert result["result"] == "success" + + def test_delete_forbidden(self, app, readonly_user): + api = TagBindingItemApi() + method = unwrap(api.delete) + + with app.test_request_context("/"): + with patch( + "controllers.console.tag.tags.current_account_with_tenant", + return_value=(readonly_user, None), + ): + with pytest.raises(Forbidden): + method(api, "tag-1") + + +class TestDeprecatedTagBindingRemoveApi: def test_remove_success(self, app, admin_user, payload_patch): - api = TagBindingDeleteApi() + api = DeprecatedTagBindingRemoveApi() method = unwrap(api.post) payload = { @@ -274,7 +345,7 @@ class TestTagBindingDeleteApi: assert result["result"] == "success" def test_remove_forbidden(self, app, readonly_user, payload_patch): - api = TagBindingDeleteApi() + api = DeprecatedTagBindingRemoveApi() method = unwrap(api.post) with app.test_request_context("/", json={}): @@ -297,3 +368,35 @@ class TestTagResponseModel: assert payload["type"] == "knowledge" assert payload["binding_count"] == "1" + + +class TestTagBindingRouteMetadata: + def test_legacy_write_routes_are_marked_deprecated(self): + assert DeprecatedTagBindingCreateApi.post.__apidoc__["deprecated"] is True + assert DeprecatedTagBindingRemoveApi.post.__apidoc__["deprecated"] is True + assert TagBindingCollectionApi.post.__apidoc__.get("deprecated") is not True + assert TagBindingItemApi.delete.__apidoc__.get("deprecated") is not True + + def test_write_routes_have_stable_operation_ids(self): + assert TagBindingCollectionApi.post.__apidoc__["id"] == "create_tag_binding" + assert TagBindingItemApi.delete.__apidoc__["id"] == "delete_tag_binding" + assert DeprecatedTagBindingCreateApi.post.__apidoc__["id"] == "create_tag_binding_deprecated" + assert DeprecatedTagBindingRemoveApi.post.__apidoc__["id"] == "delete_tag_binding_deprecated" + + def test_canonical_and_legacy_write_routes_are_registered(self): + route_map = { + resource.__name__: urls + for resource, urls, _route_doc, _kwargs in console_ns.resources + if resource.__name__ + in { + "TagBindingCollectionApi", + "TagBindingItemApi", + "DeprecatedTagBindingCreateApi", + "DeprecatedTagBindingRemoveApi", + } + } + + assert route_map["TagBindingCollectionApi"] == ("/tag-bindings",) + assert route_map["TagBindingItemApi"] == ("/tag-bindings/",) + assert route_map["DeprecatedTagBindingCreateApi"] == ("/tag-bindings/create",) + assert route_map["DeprecatedTagBindingRemoveApi"] == ("/tag-bindings/remove",) diff --git a/api/tests/unit_tests/controllers/console/test_human_input_form.py b/api/tests/unit_tests/controllers/console/test_human_input_form.py index 232b6eee79..ebf803cac9 100644 --- a/api/tests/unit_tests/controllers/console/test_human_input_form.py +++ b/api/tests/unit_tests/controllers/console/test_human_input_form.py @@ -122,6 +122,35 @@ def test_post_form_invalid_recipient_type(app, monkeypatch: pytest.MonkeyPatch) handler(api, form_token="token") +def test_post_form_rejects_webapp_recipient_type(app, monkeypatch: pytest.MonkeyPatch) -> None: + form = SimpleNamespace(tenant_id="tenant-1", recipient_type=RecipientType.STANDALONE_WEB_APP) + + class _ServiceStub: + def __init__(self, *_args, **_kwargs): + pass + + def get_form_by_token(self, _token): + return form + + monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub) + monkeypatch.setattr( + "controllers.console.human_input_form.current_account_with_tenant", + lambda: (SimpleNamespace(id="user-1"), "tenant-1"), + ) + monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object())) + + api = ConsoleHumanInputFormApi() + handler = _unwrap(api.post) + + with app.test_request_context( + "/console/api/form/human_input/token", + method="POST", + json={"inputs": {"content": "ok"}, "action": "approve"}, + ): + with pytest.raises(NotFoundError): + handler(api, form_token="token") + + def test_post_form_success(app, monkeypatch: pytest.MonkeyPatch) -> None: submit_mock = Mock() form = SimpleNamespace(tenant_id="tenant-1", recipient_type=RecipientType.CONSOLE) diff --git a/api/tests/unit_tests/controllers/console/test_workspace_account.py b/api/tests/unit_tests/controllers/console/test_workspace_account.py index 26ff264f18..0b1a32581a 100644 --- a/api/tests/unit_tests/controllers/console/test_workspace_account.py +++ b/api/tests/unit_tests/controllers/console/test_workspace_account.py @@ -24,10 +24,6 @@ def app(): return app -def _mock_wraps_db(mock_db): - mock_db.session.query.return_value.first.return_value = MagicMock() - - def _build_account(email: str, account_id: str = "acc", tenant: object | None = None) -> Account: tenant_obj = tenant if tenant is not None else SimpleNamespace(id="tenant-id") account = Account(name=account_id, email=email) @@ -64,7 +60,6 @@ class TestChangeEmailSend: mock_db, app, ): - _mock_wraps_db(mock_db) mock_features.return_value = SimpleNamespace(enable_change_email=True) mock_account = _build_account("current@example.com", "acc1") mock_current_account.return_value = (mock_account, None) @@ -117,7 +112,6 @@ class TestChangeEmailSend: """GHSA-4q3w-q5mc-45rq: a phase-1 token must not unlock the new-email send step.""" from controllers.console.auth.error import InvalidTokenError - _mock_wraps_db(mock_db) mock_features.return_value = SimpleNamespace(enable_change_email=True) mock_account = _build_account("current@example.com", "acc1") mock_current_account.return_value = (mock_account, None) @@ -163,7 +157,6 @@ class TestChangeEmailValidity: mock_db, app, ): - _mock_wraps_db(mock_db) mock_features.return_value = SimpleNamespace(enable_change_email=True) mock_account = _build_account("user@example.com", "acc2") mock_current_account.return_value = (mock_account, None) @@ -223,7 +216,6 @@ class TestChangeEmailValidity: mock_db, app, ): - _mock_wraps_db(mock_db) mock_features.return_value = SimpleNamespace(enable_change_email=True) mock_current_account.return_value = (_build_account("old@example.com", "acc"), None) mock_is_rate_limit.return_value = False @@ -280,7 +272,6 @@ class TestChangeEmailValidity: """A token whose phase marker is a string but not a known transition must be rejected.""" from controllers.console.auth.error import InvalidTokenError - _mock_wraps_db(mock_db) mock_features.return_value = SimpleNamespace(enable_change_email=True) mock_current_account.return_value = (_build_account("old@example.com", "acc"), None) mock_is_rate_limit.return_value = False @@ -330,7 +321,6 @@ class TestChangeEmailValidity: """A token minted without a phase marker (e.g. a hand-crafted token) must not validate.""" from controllers.console.auth.error import InvalidTokenError - _mock_wraps_db(mock_db) mock_features.return_value = SimpleNamespace(enable_change_email=True) mock_current_account.return_value = (_build_account("old@example.com", "acc"), None) mock_is_rate_limit.return_value = False @@ -378,7 +368,6 @@ class TestChangeEmailReset: mock_db, app, ): - _mock_wraps_db(mock_db) mock_features.return_value = SimpleNamespace(enable_change_email=True) current_user = _build_account("old@example.com", "acc3") mock_current_account.return_value = (current_user, None) @@ -434,7 +423,6 @@ class TestChangeEmailReset: """GHSA-4q3w-q5mc-45rq PoC: phase-1 token must not be usable against /reset.""" from controllers.console.auth.error import InvalidTokenError - _mock_wraps_db(mock_db) mock_features.return_value = SimpleNamespace(enable_change_email=True) current_user = _build_account("old@example.com", "acc3") mock_current_account.return_value = (current_user, None) @@ -488,7 +476,6 @@ class TestChangeEmailReset: """A verified token for address A must not be replayed to change to address B.""" from controllers.console.auth.error import InvalidTokenError - _mock_wraps_db(mock_db) mock_features.return_value = SimpleNamespace(enable_change_email=True) current_user = _build_account("old@example.com", "acc3") mock_current_account.return_value = (current_user, None) @@ -561,7 +548,6 @@ class TestAccountDeletionFeedback: @patch("controllers.console.wraps.db") @patch("controllers.console.workspace.account.BillingService.update_account_deletion_feedback") def test_should_normalize_feedback_email(self, mock_update, mock_db, app): - _mock_wraps_db(mock_db) with app.test_request_context( "/account/delete/feedback", method="POST", @@ -578,7 +564,6 @@ class TestCheckEmailUnique: @patch("controllers.console.workspace.account.AccountService.check_email_unique") @patch("controllers.console.workspace.account.AccountService.is_account_in_freeze") def test_should_normalize_email(self, mock_is_freeze, mock_check_unique, mock_db, app): - _mock_wraps_db(mock_db) mock_is_freeze.return_value = False mock_check_unique.return_value = True diff --git a/api/tests/unit_tests/controllers/console/test_workspace_members.py b/api/tests/unit_tests/controllers/console/test_workspace_members.py index 239fec8430..811bf5b1e7 100644 --- a/api/tests/unit_tests/controllers/console/test_workspace_members.py +++ b/api/tests/unit_tests/controllers/console/test_workspace_members.py @@ -1,5 +1,5 @@ from types import SimpleNamespace -from unittest.mock import MagicMock, patch +from unittest.mock import patch import pytest from flask import Flask, g @@ -16,10 +16,6 @@ def app(): return flask_app -def _mock_wraps_db(mock_db): - mock_db.session.query.return_value.first.return_value = MagicMock() - - def _build_feature_flags(): placeholder_quota = SimpleNamespace(limit=0, size=0) workspace_members = SimpleNamespace(is_available=lambda count: True) @@ -49,7 +45,6 @@ class TestMemberInviteEmailApi: mock_get_features, app, ): - _mock_wraps_db(mock_db) mock_get_features.return_value = _build_feature_flags() mock_invite_member.return_value = "token-abc" diff --git a/api/tests/unit_tests/controllers/console/test_wraps.py b/api/tests/unit_tests/controllers/console/test_wraps.py index f6e096a97b..aa4973851a 100644 --- a/api/tests/unit_tests/controllers/console/test_wraps.py +++ b/api/tests/unit_tests/controllers/console/test_wraps.py @@ -310,7 +310,6 @@ class TestSystemSetup: def test_should_allow_when_setup_complete(self, mock_db): """Test that requests are allowed when setup is complete""" # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() # Setup exists @setup_required def admin_view(): diff --git a/api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py b/api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py index 44feacf2ad..1422f29849 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py @@ -22,7 +22,7 @@ _WRAPS_MODULE: ModuleType | None = None @contextmanager def _mock_db(): - mock_session = SimpleNamespace(query=lambda *args, **kwargs: SimpleNamespace(first=lambda: True)) + mock_session = SimpleNamespace(scalar=lambda *args, **kwargs: True) with patch("extensions.ext_database.db.session", mock_session): yield diff --git a/api/tests/unit_tests/controllers/service_api/app/test_app.py b/api/tests/unit_tests/controllers/service_api/app/test_app.py index f48ace427d..f5d93b5ac3 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_app.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_app.py @@ -12,7 +12,7 @@ from controllers.service_api.app.app import AppInfoApi, AppMetaApi, AppParameter from controllers.service_api.app.error import AppUnavailableError from models.account import TenantStatus from models.model import App, AppMode -from tests.unit_tests.conftest import setup_mock_tenant_account_query +from tests.unit_tests.conftest import setup_mock_tenant_owner_execute_result class TestAppParameterApi: @@ -74,7 +74,7 @@ class TestAppParameterApi: # Mock tenant owner info for login mock_account = Mock() mock_account.current_tenant = mock_tenant - setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account) + setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account) # Act with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}): @@ -120,7 +120,7 @@ class TestAppParameterApi: mock_account = Mock() mock_account.current_tenant = mock_tenant - setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account) + setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account) # Act with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}): @@ -161,7 +161,7 @@ class TestAppParameterApi: mock_account = Mock() mock_account.current_tenant = mock_tenant - setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account) + setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account) # Act & Assert with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}): @@ -200,7 +200,7 @@ class TestAppParameterApi: mock_account = Mock() mock_account.current_tenant = mock_tenant - setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account) + setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account) # Act & Assert with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}): @@ -263,7 +263,7 @@ class TestAppMetaApi: mock_account = Mock() mock_account.current_tenant = mock_tenant - setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account) + setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account) # Act with app.test_request_context("/meta", method="GET", headers={"Authorization": "Bearer test_token"}): @@ -331,7 +331,7 @@ class TestAppInfoApi: mock_account = Mock() mock_account.current_tenant = mock_tenant - setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account) + setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account) # Act with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}): @@ -388,7 +388,7 @@ class TestAppInfoApi: mock_account = Mock() mock_account.current_tenant = mock_tenant - setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account) + setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account) # Act with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}): @@ -434,7 +434,7 @@ class TestAppInfoApi: mock_account = Mock() mock_account.current_tenant = mock_tenant - setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account) + setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account) # Act with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}): @@ -486,7 +486,7 @@ class TestAppInfoApi: mock_account = Mock() mock_account.current_tenant = mock_tenant - setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account) + setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account) # Act with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}): diff --git a/api/tests/unit_tests/controllers/service_api/app/test_hitl_service_api.py b/api/tests/unit_tests/controllers/service_api/app/test_hitl_service_api.py new file mode 100644 index 0000000000..846d5368f3 --- /dev/null +++ b/api/tests/unit_tests/controllers/service_api/app/test_hitl_service_api.py @@ -0,0 +1,707 @@ +"""Dedicated tests for HITL behavior exposed through the Service API.""" + +from __future__ import annotations + +import json +import sys +from collections.abc import Sequence +from dataclasses import dataclass +from datetime import UTC, datetime +from types import SimpleNamespace +from unittest.mock import ANY, MagicMock, Mock + +import pytest + +import services.app_generate_service as ags_module +from controllers.service_api.app.workflow_events import WorkflowEventsApi +from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig +from core.app.apps.common import workflow_response_converter +from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity +from core.app.entities.queue_entities import QueueWorkflowPausedEvent +from core.app.entities.task_entities import ( + AdvancedChatPausedBlockingResponse, + HumanInputRequiredResponse, + WorkflowAppPausedBlockingResponse, + WorkflowPauseStreamResponse, +) +from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext, _WorkflowGenerateEntityWrapper +from core.workflow.human_input_policy import HumanInputSurface +from core.workflow.system_variables import build_system_variables +from graphon.entities import WorkflowStartReason +from graphon.entities.pause_reason import HumanInputRequired, PauseReasonType +from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus +from graphon.nodes.human_input.entities import FormInput, UserAction +from graphon.nodes.human_input.enums import FormInputType +from graphon.runtime import GraphRuntimeState, VariablePool +from models.account import Account +from models.enums import CreatorUserRole +from models.model import AppMode +from models.workflow import WorkflowRun +from repositories.api_workflow_node_execution_repository import WorkflowNodeExecutionSnapshot +from repositories.entities.workflow_pause import WorkflowPauseEntity +from services.app_generate_service import AppGenerateService +from services.workflow_event_snapshot_service import _build_snapshot_events +from tests.unit_tests.controllers.service_api.conftest import _unwrap + + +class _DummyRateLimit: + @staticmethod + def gen_request_key() -> str: + return "dummy-request-id" + + def __init__(self, client_id: str, max_active_requests: int) -> None: + self.client_id = client_id + self.max_active_requests = max_active_requests + + def enter(self, request_id: str | None = None) -> str: + return request_id or "dummy-request-id" + + def exit(self, request_id: str) -> None: + return None + + def generate(self, generator, request_id: str): + return generator + + +def _mock_repo_for_run(monkeypatch: pytest.MonkeyPatch, workflow_run): + workflow_events_module = sys.modules["controllers.service_api.app.workflow_events"] + repo = SimpleNamespace(get_workflow_run_by_id_and_tenant_id=lambda **_kwargs: workflow_run) + monkeypatch.setattr( + workflow_events_module.DifyAPIRepositoryFactory, + "create_api_workflow_run_repository", + lambda *_args, **_kwargs: repo, + ) + monkeypatch.setattr(workflow_events_module, "db", SimpleNamespace(engine=object())) + return workflow_events_module + + +def _build_service_api_pause_converter() -> WorkflowResponseConverter: + application_generate_entity = SimpleNamespace( + inputs={}, + files=[], + invoke_from=InvokeFrom.SERVICE_API, + app_config=SimpleNamespace(app_id="app-id", tenant_id="tenant-id"), + ) + system_variables = build_system_variables( + user_id="user", + app_id="app-id", + workflow_id="workflow-id", + workflow_execution_id="run-id", + ) + user = MagicMock(spec=Account) + user.id = "account-id" + user.name = "Tester" + user.email = "tester@example.com" + return WorkflowResponseConverter( + application_generate_entity=application_generate_entity, + user=user, + system_variables=system_variables, + ) + + +def _build_advanced_chat_paused_blocking_response() -> AdvancedChatPausedBlockingResponse: + data = AdvancedChatPausedBlockingResponse.Data( + id="msg-1", + mode="chat", + conversation_id="c1", + message_id="m1", + workflow_run_id="run-1", + answer="partial", + metadata={"usage": {"total_tokens": 1}}, + created_at=1, + paused_nodes=["node-1"], + reasons=[ + { + "type": PauseReasonType.HUMAN_INPUT_REQUIRED, + "form_id": "form-1", + "expiration_time": 100, + } + ], + status=WorkflowExecutionStatus.PAUSED, + elapsed_time=0.1, + total_tokens=0, + total_steps=0, + ) + return AdvancedChatPausedBlockingResponse(task_id="t1", data=data) + + +def _build_workflow_paused_blocking_response() -> WorkflowAppPausedBlockingResponse: + return WorkflowAppPausedBlockingResponse( + task_id="t1", + workflow_run_id="r1", + data=WorkflowAppPausedBlockingResponse.Data( + id="r1", + workflow_id="wf-1", + status=WorkflowExecutionStatus.PAUSED, + outputs={}, + error=None, + elapsed_time=0.5, + total_tokens=0, + total_steps=2, + created_at=1, + finished_at=None, + paused_nodes=["node-1"], + reasons=[{"TYPE": "human_input_required", "form_id": "form-1", "expiration_time": 100}], + ), + ) + + +@dataclass(frozen=True) +class _FakePauseEntity(WorkflowPauseEntity): + pause_id: str + workflow_run_id: str + paused_at_value: datetime + pause_reasons: Sequence[HumanInputRequired] + + @property + def id(self) -> str: + return self.pause_id + + @property + def workflow_execution_id(self) -> str: + return self.workflow_run_id + + def get_state(self) -> bytes: + raise AssertionError("state is not required for snapshot tests") + + @property + def resumed_at(self) -> datetime | None: + return None + + @property + def paused_at(self) -> datetime: + return self.paused_at_value + + def get_pause_reasons(self) -> Sequence[HumanInputRequired]: + return self.pause_reasons + + +def _build_workflow_run(status: WorkflowExecutionStatus) -> WorkflowRun: + return WorkflowRun( + id="run-1", + tenant_id="tenant-1", + app_id="app-1", + workflow_id="workflow-1", + type="workflow", + triggered_from="app-run", + version="v1", + graph=None, + inputs=json.dumps({"input": "value"}), + status=status, + outputs=json.dumps({}), + error=None, + elapsed_time=0.0, + total_tokens=0, + total_steps=0, + created_by_role=CreatorUserRole.END_USER, + created_by="user-1", + created_at=datetime(2024, 1, 1, tzinfo=UTC), + ) + + +def _build_snapshot(status: WorkflowNodeExecutionStatus) -> WorkflowNodeExecutionSnapshot: + created_at = datetime(2024, 1, 1, tzinfo=UTC) + finished_at = datetime(2024, 1, 1, 0, 0, 5, tzinfo=UTC) + return WorkflowNodeExecutionSnapshot( + execution_id="exec-1", + node_id="node-1", + node_type="human-input", + title="Human Input", + index=1, + status=status.value, + elapsed_time=0.5, + created_at=created_at, + finished_at=finished_at, + iteration_id=None, + loop_id=None, + ) + + +def _build_resumption_context(task_id: str) -> WorkflowResumptionContext: + app_config = WorkflowUIBasedAppConfig( + tenant_id="tenant-1", + app_id="app-1", + app_mode=AppMode.WORKFLOW, + workflow_id="workflow-1", + ) + generate_entity = WorkflowAppGenerateEntity( + task_id=task_id, + app_config=app_config, + inputs={}, + files=[], + user_id="user-1", + stream=True, + invoke_from=InvokeFrom.EXPLORE, + call_depth=0, + workflow_execution_id="run-1", + ) + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0) + runtime_state.register_paused_node("node-1") + runtime_state.outputs = {"result": "value"} + wrapper = _WorkflowGenerateEntityWrapper(entity=generate_entity) + return WorkflowResumptionContext( + generate_entity=wrapper, + serialized_graph_runtime_state=runtime_state.dumps(), + ) + + +class TestHitlServiceApi: + # Service API event-stream continuation + def test_workflow_events_continue_on_pause_keeps_stream_open(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + workflow_run = SimpleNamespace( + id="run-1", + app_id="app-1", + created_by_role=CreatorUserRole.END_USER, + created_by="end-user-1", + finished_at=None, + ) + workflow_events_module = _mock_repo_for_run(monkeypatch, workflow_run=workflow_run) + msg_generator = Mock() + msg_generator.retrieve_events.return_value = ["raw-event"] + workflow_generator = Mock() + workflow_generator.convert_to_event_stream.return_value = iter(["data: streamed\n\n"]) + monkeypatch.setattr(workflow_events_module, "MessageGenerator", lambda: msg_generator) + monkeypatch.setattr(workflow_events_module, "WorkflowAppGenerator", lambda: workflow_generator) + + api = WorkflowEventsApi() + handler = _unwrap(api.get) + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value) + end_user = SimpleNamespace(id="end-user-1") + + with app.test_request_context("/workflow/run-1/events?user=u1&continue_on_pause=true", method="GET"): + response = handler(api, app_model=app_model, end_user=end_user, task_id="run-1") + + assert response.get_data(as_text=True) == "data: streamed\n\n" + msg_generator.retrieve_events.assert_called_once_with( + AppMode.WORKFLOW, + "run-1", + terminal_events=[], + ) + workflow_generator.convert_to_event_stream.assert_called_once_with(["raw-event"]) + + def test_workflow_events_snapshot_continue_on_pause_keeps_pause_open( + self, app, monkeypatch: pytest.MonkeyPatch + ) -> None: + workflow_run = SimpleNamespace( + id="run-1", + app_id="app-1", + created_by_role=CreatorUserRole.END_USER, + created_by="end-user-1", + finished_at=None, + ) + workflow_events_module = _mock_repo_for_run(monkeypatch, workflow_run=workflow_run) + msg_generator = Mock() + workflow_generator = Mock() + workflow_generator.convert_to_event_stream.return_value = iter(["data: snapshot\n\n"]) + snapshot_builder = Mock(return_value=["snapshot-events"]) + monkeypatch.setattr(workflow_events_module, "MessageGenerator", lambda: msg_generator) + monkeypatch.setattr(workflow_events_module, "WorkflowAppGenerator", lambda: workflow_generator) + monkeypatch.setattr(workflow_events_module, "build_workflow_event_stream", snapshot_builder) + + api = WorkflowEventsApi() + handler = _unwrap(api.get) + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value) + end_user = SimpleNamespace(id="end-user-1") + + with app.test_request_context( + "/workflow/run-1/events?user=u1&include_state_snapshot=true&continue_on_pause=true", + method="GET", + ): + response = handler(api, app_model=app_model, end_user=end_user, task_id="run-1") + + assert response.get_data(as_text=True) == "data: snapshot\n\n" + msg_generator.retrieve_events.assert_not_called() + snapshot_builder.assert_called_once_with( + app_mode=AppMode.WORKFLOW, + workflow_run=workflow_run, + tenant_id="tenant-1", + app_id="app-1", + session_maker=ANY, + human_input_surface=HumanInputSurface.SERVICE_API, + close_on_pause=False, + ) + workflow_generator.convert_to_event_stream.assert_called_once_with(["snapshot-events"]) + + def test_advanced_chat_blocking_injects_pause_state_config(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", False) + monkeypatch.setattr(ags_module, "RateLimit", _DummyRateLimit) + + workflow = MagicMock() + workflow.created_by = "owner-id" + monkeypatch.setattr(AppGenerateService, "_get_workflow", lambda *args, **kwargs: workflow) + monkeypatch.setattr(ags_module.session_factory, "get_session_maker", lambda: "session-maker") + + generator_instance = MagicMock() + generator_instance.generate.return_value = {"result": "advanced-blocking"} + generator_instance.convert_to_event_stream.side_effect = lambda payload: payload + monkeypatch.setattr(ags_module, "AdvancedChatAppGenerator", lambda: generator_instance) + + app_model = MagicMock() + app_model.mode = AppMode.ADVANCED_CHAT + app_model.id = "app-id" + app_model.tenant_id = "tenant-id" + app_model.max_active_requests = 0 + app_model.is_agent = False + + user = MagicMock() + user.id = "user-id" + + result = AppGenerateService.generate( + app_model=app_model, + user=user, + args={"workflow_id": None, "query": "hi", "inputs": {}}, + invoke_from=InvokeFrom.SERVICE_API, + streaming=False, + ) + + assert result == {"result": "advanced-blocking"} + call_kwargs = generator_instance.generate.call_args.kwargs + assert call_kwargs["streaming"] is False + assert call_kwargs["pause_state_config"] is not None + assert call_kwargs["pause_state_config"].session_factory == "session-maker" + assert call_kwargs["pause_state_config"].state_owner_user_id == "owner-id" + + # Blocking payload contract + def test_advanced_chat_blocking_pause_payload_contract(self) -> None: + from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter + + response = AdvancedChatAppGenerateResponseConverter.convert_blocking_full_response( + _build_advanced_chat_paused_blocking_response() + ) + + assert response["event"] == "workflow_paused" + assert response["workflow_run_id"] == "run-1" + assert response["answer"] == "partial" + assert response["data"]["reasons"][0]["type"] == PauseReasonType.HUMAN_INPUT_REQUIRED + assert response["data"]["reasons"][0]["expiration_time"] == 100 + assert "human_input_forms" not in response["data"] + + def test_workflow_blocking_pause_payload_contract(self) -> None: + from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter + + response = WorkflowAppGenerateResponseConverter.convert_blocking_full_response( + _build_workflow_paused_blocking_response() + ) + + assert response["workflow_run_id"] == "r1" + assert response["data"]["status"] == WorkflowExecutionStatus.PAUSED + assert response["data"]["paused_nodes"] == ["node-1"] + assert response["data"]["reasons"] == [ + {"TYPE": "human_input_required", "form_id": "form-1", "expiration_time": 100} + ] + assert "human_input_forms" not in response["data"] + + def test_advanced_chat_blocking_pipeline_pause_payload_contract(self) -> None: + from core.app.app_config.entities import AppAdditionalFeatures + from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline + from models.enums import MessageStatus + from models.model import EndUser + + app_config = WorkflowUIBasedAppConfig( + tenant_id="tenant", + app_id="app", + app_mode=AppMode.ADVANCED_CHAT, + additional_features=AppAdditionalFeatures(), + variables=[], + workflow_id="workflow-id", + ) + application_generate_entity = AdvancedChatAppGenerateEntity.model_construct( + task_id="task", + app_config=app_config, + inputs={}, + query="hello", + files=[], + user_id="user", + stream=False, + invoke_from=InvokeFrom.WEB_APP, + extras={}, + trace_manager=None, + workflow_run_id="run-id", + ) + pipeline = AdvancedChatAppGenerateTaskPipeline( + application_generate_entity=application_generate_entity, + workflow=SimpleNamespace(id="workflow-id", tenant_id="tenant", features_dict={}), + queue_manager=SimpleNamespace(invoke_from=InvokeFrom.WEB_APP, graph_runtime_state=None), + conversation=SimpleNamespace(id="conv-id", mode=AppMode.ADVANCED_CHAT), + message=SimpleNamespace( + id="message-id", + query="hello", + created_at=datetime.utcnow(), + status=MessageStatus.NORMAL, + answer="", + ), + user=EndUser(tenant_id="tenant", type="session", name="tester", session_id="session"), + stream=False, + dialogue_count=1, + draft_var_saver_factory=lambda **kwargs: None, + ) + pipeline._task_state.answer = "partial answer" + pipeline._workflow_run_id = "run-id" + + def _gen(): + yield HumanInputRequiredResponse( + task_id="task", + workflow_run_id="run-id", + data=HumanInputRequiredResponse.Data( + form_id="form-1", + node_id="node-1", + node_title="Approval", + form_content="Need approval", + inputs=[], + actions=[UserAction(id="approve", title="Approve")], + display_in_ui=True, + form_token="token-1", + resolved_default_values={}, + expiration_time=123, + ), + ) + yield WorkflowPauseStreamResponse( + task_id="task", + workflow_run_id="run-id", + data=WorkflowPauseStreamResponse.Data( + workflow_run_id="run-id", + paused_nodes=["node-1"], + outputs={}, + reasons=[ + { + "type": PauseReasonType.HUMAN_INPUT_REQUIRED, + "form_id": "form-1", + "node_id": "node-1", + "expiration_time": 123, + }, + ], + status="paused", + created_at=1, + elapsed_time=0.1, + total_tokens=0, + total_steps=0, + ), + ) + + response = pipeline._to_blocking_response(_gen()) + + assert isinstance(response, AdvancedChatPausedBlockingResponse) + assert response.data.answer == "partial answer" + assert response.data.workflow_run_id == "run-id" + assert response.data.reasons[0]["form_id"] == "form-1" + assert response.data.reasons[0]["expiration_time"] == 123 + + def test_workflow_blocking_pipeline_pause_payload_contract(self, monkeypatch: pytest.MonkeyPatch) -> None: + from core.app.apps.workflow import generate_task_pipeline as workflow_pipeline_module + from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline + + app_config = WorkflowUIBasedAppConfig( + tenant_id="tenant", + app_id="app", + app_mode=AppMode.WORKFLOW, + additional_features=AppAdditionalFeatures(), + variables=[], + workflow_id="workflow-id", + ) + application_generate_entity = WorkflowAppGenerateEntity.model_construct( + task_id="task", + app_config=app_config, + inputs={}, + files=[], + user_id="user", + stream=False, + invoke_from=InvokeFrom.WEB_APP, + trace_manager=None, + workflow_execution_id="run-id", + extras={}, + call_depth=0, + ) + pipeline = WorkflowAppGenerateTaskPipeline( + application_generate_entity=application_generate_entity, + workflow=SimpleNamespace(id="workflow-id", tenant_id="tenant", features_dict={}), + queue_manager=SimpleNamespace(invoke_from=InvokeFrom.WEB_APP, graph_runtime_state=None), + user=SimpleNamespace(id="user", session_id="session"), + stream=False, + draft_var_saver_factory=lambda **kwargs: None, + ) + monkeypatch.setattr(workflow_pipeline_module.time, "time", lambda: 1700000000) + + def _gen(): + yield HumanInputRequiredResponse( + task_id="task", + workflow_run_id="run", + data=HumanInputRequiredResponse.Data( + form_id="form-1", + node_id="node-1", + node_title="Human Input", + form_content="content", + expiration_time=1, + ), + ) + yield WorkflowPauseStreamResponse( + task_id="task", + workflow_run_id="run", + data=WorkflowPauseStreamResponse.Data( + workflow_run_id="run", + status=WorkflowExecutionStatus.PAUSED, + outputs={}, + paused_nodes=["node-1"], + reasons=[{"TYPE": "human_input_required", "form_id": "form-1", "expiration_time": 1}], + created_at=1, + elapsed_time=0.1, + total_tokens=0, + total_steps=0, + ), + ) + + response = pipeline._to_blocking_response(_gen()) + + assert isinstance(response, WorkflowAppPausedBlockingResponse) + assert response.data.status == WorkflowExecutionStatus.PAUSED + assert response.data.paused_nodes == ["node-1"] + assert response.data.reasons == [{"TYPE": "human_input_required", "form_id": "form-1", "expiration_time": 1}] + + def test_service_api_pause_event_serializes_hitl_reason(self, monkeypatch: pytest.MonkeyPatch) -> None: + converter = _build_service_api_pause_converter() + converter.workflow_start_to_stream_response( + task_id="task", + workflow_run_id="run-id", + workflow_id="workflow-id", + reason=WorkflowStartReason.INITIAL, + ) + + expiration_time = datetime(2024, 1, 1, tzinfo=UTC) + + class _FakeSession: + def execute(self, _stmt): + return [("form-1", expiration_time, '{"display_in_ui": true}')] + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(workflow_response_converter, "Session", lambda **_: _FakeSession()) + monkeypatch.setattr(workflow_response_converter, "db", SimpleNamespace(engine=object())) + monkeypatch.setattr( + workflow_response_converter, + "load_form_tokens_by_form_id", + lambda form_ids, session=None, surface=None: {"form-1": "token"}, + ) + + reason = HumanInputRequired( + form_id="form-1", + form_content="Rendered", + inputs=[ + FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="field", default=None), + ], + actions=[UserAction(id="approve", title="Approve")], + display_in_ui=True, + node_id="node-id", + node_title="Human Step", + form_token="token", + ) + queue_event = QueueWorkflowPausedEvent( + reasons=[reason], + outputs={"answer": "value"}, + paused_nodes=["node-id"], + ) + + runtime_state = SimpleNamespace(total_tokens=0, node_run_steps=0) + responses = converter.workflow_pause_to_stream_response( + event=queue_event, + task_id="task", + graph_runtime_state=runtime_state, + ) + + assert isinstance(responses[-1], WorkflowPauseStreamResponse) + pause_resp = responses[-1] + assert pause_resp.workflow_run_id == "run-id" + assert pause_resp.data.paused_nodes == ["node-id"] + assert pause_resp.data.outputs == {} + assert pause_resp.data.reasons[0]["TYPE"] == "human_input_required" + assert pause_resp.data.reasons[0]["form_id"] == "form-1" + assert pause_resp.data.reasons[0]["form_token"] == "token" + assert pause_resp.data.reasons[0]["expiration_time"] == int(expiration_time.timestamp()) + + assert isinstance(responses[0], HumanInputRequiredResponse) + hi_resp = responses[0] + assert hi_resp.data.form_id == "form-1" + assert hi_resp.data.node_id == "node-id" + assert hi_resp.data.node_title == "Human Step" + assert hi_resp.data.inputs[0].output_variable_name == "field" + assert hi_resp.data.actions[0].id == "approve" + assert hi_resp.data.display_in_ui is True + assert hi_resp.data.form_token == "token" + assert hi_resp.data.expiration_time == int(expiration_time.timestamp()) + + # Snapshot payload contract + def test_snapshot_events_include_pause_payload_contract(self, monkeypatch: pytest.MonkeyPatch) -> None: + workflow_run = _build_workflow_run(WorkflowExecutionStatus.PAUSED) + snapshot = _build_snapshot(WorkflowNodeExecutionStatus.PAUSED) + resumption_context = _build_resumption_context("task-ctx") + monkeypatch.setattr( + "services.workflow_event_snapshot_service.load_form_tokens_by_form_id", + lambda form_ids, session=None, surface=None: {"form-1": "wtok"}, + ) + + class _SessionContext: + def __init__(self, session): + self._session = session + + def __enter__(self): + return self._session + + def __exit__(self, exc_type, exc, tb): + return False + + def session_maker() -> _SessionContext: + return _SessionContext( + SimpleNamespace( + execute=lambda _stmt: [("form-1", datetime(2024, 1, 1, tzinfo=UTC), '{"display_in_ui": true}')], + ) + ) + + pause_entity = _FakePauseEntity( + pause_id="pause-1", + workflow_run_id="run-1", + paused_at_value=datetime(2024, 1, 1, tzinfo=UTC), + pause_reasons=[ + HumanInputRequired( + form_id="form-1", + form_content="content", + node_id="node-1", + node_title="Human Input", + form_token="wtok", + ) + ], + ) + + events = _build_snapshot_events( + workflow_run=workflow_run, + node_snapshots=[snapshot], + task_id="task-ctx", + message_context=None, + pause_entity=pause_entity, + resumption_context=resumption_context, + session_maker=session_maker, + ) + + assert [event["event"] for event in events] == [ + "workflow_started", + "node_started", + "node_finished", + "human_input_required", + "workflow_paused", + ] + assert events[2]["data"]["status"] == WorkflowNodeExecutionStatus.PAUSED.value + assert events[3]["data"]["form_token"] == "wtok" + assert events[3]["data"]["expiration_time"] == int(datetime(2024, 1, 1, tzinfo=UTC).timestamp()) + pause_data = events[-1]["data"] + assert pause_data["paused_nodes"] == ["node-1"] + assert pause_data["outputs"] == {"result": "value"} + assert pause_data["reasons"][0]["TYPE"] == "human_input_required" + assert pause_data["reasons"][0]["form_token"] == "wtok" + assert pause_data["reasons"][0]["expiration_time"] == int(datetime(2024, 1, 1, tzinfo=UTC).timestamp()) + assert pause_data["status"] == WorkflowExecutionStatus.PAUSED.value + assert pause_data["created_at"] == int(workflow_run.created_at.timestamp()) + assert pause_data["elapsed_time"] == workflow_run.elapsed_time + assert pause_data["total_tokens"] == workflow_run.total_tokens + assert pause_data["total_steps"] == workflow_run.total_steps diff --git a/api/tests/unit_tests/controllers/service_api/app/test_human_input_form.py b/api/tests/unit_tests/controllers/service_api/app/test_human_input_form.py new file mode 100644 index 0000000000..531f722ceb --- /dev/null +++ b/api/tests/unit_tests/controllers/service_api/app/test_human_input_form.py @@ -0,0 +1,184 @@ +"""Unit tests for Service API human input form endpoints.""" + +from __future__ import annotations + +import json +import sys +from datetime import UTC, datetime +from types import SimpleNamespace +from unittest.mock import Mock + +import pytest +from werkzeug.exceptions import NotFound + +from controllers.service_api.app.human_input_form import WorkflowHumanInputFormApi +from models.human_input import RecipientType +from tests.unit_tests.controllers.service_api.conftest import _unwrap + + +class TestWorkflowHumanInputFormApi: + def test_get_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + definition = SimpleNamespace( + model_dump=lambda: { + "rendered_content": "Rendered form content", + "inputs": [{"output_variable_name": "name"}], + "default_values": {"name": "Alice", "age": 30, "meta": {"k": "v"}}, + "user_actions": [{"id": "approve", "title": "Approve"}], + } + ) + form = SimpleNamespace( + app_id="app-1", + tenant_id="tenant-1", + recipient_type=RecipientType.STANDALONE_WEB_APP, + expiration_time=datetime(2099, 1, 1, tzinfo=UTC), + get_definition=lambda: definition, + ) + service_mock = Mock() + service_mock.get_form_by_token.return_value = form + workflow_module = sys.modules["controllers.service_api.app.human_input_form"] + monkeypatch.setattr(workflow_module, "HumanInputService", lambda _engine: service_mock) + monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object())) + + api = WorkflowHumanInputFormApi() + handler = _unwrap(api.get) + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1") + + with app.test_request_context("/form/human_input/token-1", method="GET"): + response = handler(api, app_model=app_model, form_token="token-1") + + payload = json.loads(response.get_data(as_text=True)) + assert payload == { + "form_content": "Rendered form content", + "inputs": [{"output_variable_name": "name"}], + "resolved_default_values": {"name": "Alice", "age": "30", "meta": '{"k": "v"}'}, + "user_actions": [{"id": "approve", "title": "Approve"}], + "expiration_time": int(form.expiration_time.timestamp()), + } + service_mock.get_form_by_token.assert_called_once_with("token-1") + service_mock.ensure_form_active.assert_called_once_with(form) + + def test_get_form_not_in_app(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + form = SimpleNamespace( + app_id="another-app", + tenant_id="tenant-1", + expiration_time=datetime(2099, 1, 1, tzinfo=UTC), + ) + service_mock = Mock() + service_mock.get_form_by_token.return_value = form + workflow_module = sys.modules["controllers.service_api.app.human_input_form"] + monkeypatch.setattr(workflow_module, "HumanInputService", lambda _engine: service_mock) + monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object())) + + api = WorkflowHumanInputFormApi() + handler = _unwrap(api.get) + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1") + + with app.test_request_context("/form/human_input/token-1", method="GET"): + with pytest.raises(NotFound): + handler(api, app_model=app_model, form_token="token-1") + + @pytest.mark.parametrize( + "recipient_type", + [ + RecipientType.CONSOLE, + RecipientType.BACKSTAGE, + RecipientType.EMAIL_MEMBER, + RecipientType.EMAIL_EXTERNAL, + ], + ) + def test_get_rejects_non_service_api_recipient_types( + self, app, monkeypatch: pytest.MonkeyPatch, recipient_type: RecipientType + ) -> None: + form = SimpleNamespace( + app_id="app-1", + tenant_id="tenant-1", + recipient_type=recipient_type, + expiration_time=datetime(2099, 1, 1, tzinfo=UTC), + ) + service_mock = Mock() + service_mock.get_form_by_token.return_value = form + workflow_module = sys.modules["controllers.service_api.app.human_input_form"] + monkeypatch.setattr(workflow_module, "HumanInputService", lambda _engine: service_mock) + monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object())) + + api = WorkflowHumanInputFormApi() + handler = _unwrap(api.get) + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1") + + with app.test_request_context("/form/human_input/token-1", method="GET"): + with pytest.raises(NotFound): + handler(api, app_model=app_model, form_token="token-1") + + service_mock.ensure_form_active.assert_not_called() + + def test_post_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + form = SimpleNamespace( + app_id="app-1", + tenant_id="tenant-1", + recipient_type=RecipientType.STANDALONE_WEB_APP, + ) + service_mock = Mock() + service_mock.get_form_by_token.return_value = form + workflow_module = sys.modules["controllers.service_api.app.human_input_form"] + monkeypatch.setattr(workflow_module, "HumanInputService", lambda _engine: service_mock) + monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object())) + + api = WorkflowHumanInputFormApi() + handler = _unwrap(api.post) + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1") + end_user = SimpleNamespace(id="end-user-1") + + with app.test_request_context( + "/form/human_input/token-1", + method="POST", + json={"inputs": {"name": "Alice"}, "action": "approve", "user": "external-1"}, + ): + response, status = handler(api, app_model=app_model, end_user=end_user, form_token="token-1") + + assert response == {} + assert status == 200 + service_mock.submit_form_by_token.assert_called_once_with( + recipient_type=RecipientType.STANDALONE_WEB_APP, + form_token="token-1", + selected_action_id="approve", + form_data={"name": "Alice"}, + submission_end_user_id="end-user-1", + ) + + @pytest.mark.parametrize( + "recipient_type", + [ + RecipientType.CONSOLE, + RecipientType.BACKSTAGE, + RecipientType.EMAIL_MEMBER, + RecipientType.EMAIL_EXTERNAL, + ], + ) + def test_post_rejects_non_service_api_recipient_types( + self, app, monkeypatch: pytest.MonkeyPatch, recipient_type: RecipientType + ) -> None: + form = SimpleNamespace( + app_id="app-1", + tenant_id="tenant-1", + recipient_type=recipient_type, + ) + service_mock = Mock() + service_mock.get_form_by_token.return_value = form + workflow_module = sys.modules["controllers.service_api.app.human_input_form"] + monkeypatch.setattr(workflow_module, "HumanInputService", lambda _engine: service_mock) + monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object())) + + api = WorkflowHumanInputFormApi() + handler = _unwrap(api.post) + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1") + end_user = SimpleNamespace(id="end-user-1") + + with app.test_request_context( + "/form/human_input/token-1", + method="POST", + json={"inputs": {"name": "Alice"}, "action": "approve", "user": "external-1"}, + ): + with pytest.raises(NotFound): + handler(api, app_model=app_model, end_user=end_user, form_token="token-1") + + service_mock.submit_form_by_token.assert_not_called() diff --git a/api/tests/unit_tests/controllers/service_api/app/test_workflow_events.py b/api/tests/unit_tests/controllers/service_api/app/test_workflow_events.py new file mode 100644 index 0000000000..f45a7f9632 --- /dev/null +++ b/api/tests/unit_tests/controllers/service_api/app/test_workflow_events.py @@ -0,0 +1,166 @@ +"""Unit tests for Service API workflow event stream endpoints.""" + +from __future__ import annotations + +import json +import sys +from datetime import UTC, datetime +from types import SimpleNamespace +from unittest.mock import Mock + +import pytest +from werkzeug.exceptions import NotFound + +from controllers.service_api.app.error import NotWorkflowAppError +from controllers.service_api.app.workflow_events import WorkflowEventsApi +from models.enums import CreatorUserRole +from models.model import AppMode +from tests.unit_tests.controllers.service_api.conftest import _unwrap + + +def _mock_repo_for_run(monkeypatch: pytest.MonkeyPatch, workflow_run): + workflow_events_module = sys.modules["controllers.service_api.app.workflow_events"] + repo = SimpleNamespace(get_workflow_run_by_id_and_tenant_id=lambda **_kwargs: workflow_run) + monkeypatch.setattr( + workflow_events_module.DifyAPIRepositoryFactory, + "create_api_workflow_run_repository", + lambda *_args, **_kwargs: repo, + ) + monkeypatch.setattr(workflow_events_module, "db", SimpleNamespace(engine=object())) + return workflow_events_module + + +class TestWorkflowEventsApi: + def test_wrong_app_mode(self, app) -> None: + api = WorkflowEventsApi() + handler = _unwrap(api.get) + app_model = SimpleNamespace(mode=AppMode.CHAT.value) + end_user = SimpleNamespace(id="end-user-1") + + with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"): + with pytest.raises(NotWorkflowAppError): + handler(api, app_model=app_model, end_user=end_user, task_id="run-1") + + def test_workflow_run_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + _mock_repo_for_run(monkeypatch, workflow_run=None) + api = WorkflowEventsApi() + handler = _unwrap(api.get) + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value) + end_user = SimpleNamespace(id="end-user-1") + + with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"): + with pytest.raises(NotFound): + handler(api, app_model=app_model, end_user=end_user, task_id="run-1") + + def test_workflow_run_permission_denied(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + workflow_run = SimpleNamespace( + id="run-1", + app_id="app-1", + created_by_role=CreatorUserRole.ACCOUNT, + created_by="another-user", + finished_at=None, + ) + _mock_repo_for_run(monkeypatch, workflow_run=workflow_run) + api = WorkflowEventsApi() + handler = _unwrap(api.get) + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value) + end_user = SimpleNamespace(id="end-user-1") + + with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"): + with pytest.raises(NotFound): + handler(api, app_model=app_model, end_user=end_user, task_id="run-1") + + def test_finished_run_returns_sse(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + workflow_run = SimpleNamespace( + id="run-1", + app_id="app-1", + created_by_role=CreatorUserRole.END_USER, + created_by="end-user-1", + finished_at=datetime(2099, 1, 1, tzinfo=UTC), + ) + workflow_events_module = _mock_repo_for_run(monkeypatch, workflow_run=workflow_run) + monkeypatch.setattr( + workflow_events_module.WorkflowResponseConverter, + "workflow_run_result_to_finish_response", + lambda **_kwargs: SimpleNamespace( + model_dump=lambda mode="json": {"task_id": "run-1", "status": "succeeded"}, + event=SimpleNamespace(value="workflow_finished"), + ), + ) + + api = WorkflowEventsApi() + handler = _unwrap(api.get) + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value) + end_user = SimpleNamespace(id="end-user-1") + + with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"): + response = handler(api, app_model=app_model, end_user=end_user, task_id="run-1") + + assert response.mimetype == "text/event-stream" + body = response.get_data(as_text=True).strip() + assert body.startswith("data: ") + payload = json.loads(body[len("data: ") :]) + assert payload["task_id"] == "run-1" + assert payload["event"] == "workflow_finished" + + def test_running_run_streams_events(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + workflow_run = SimpleNamespace( + id="run-1", + app_id="app-1", + created_by_role=CreatorUserRole.END_USER, + created_by="end-user-1", + finished_at=None, + ) + workflow_events_module = _mock_repo_for_run(monkeypatch, workflow_run=workflow_run) + msg_generator = Mock() + msg_generator.retrieve_events.return_value = ["raw-event"] + workflow_generator = Mock() + workflow_generator.convert_to_event_stream.return_value = iter(["data: streamed\n\n"]) + monkeypatch.setattr(workflow_events_module, "MessageGenerator", lambda: msg_generator) + monkeypatch.setattr(workflow_events_module, "WorkflowAppGenerator", lambda: workflow_generator) + + api = WorkflowEventsApi() + handler = _unwrap(api.get) + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value) + end_user = SimpleNamespace(id="end-user-1") + + with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"): + response = handler(api, app_model=app_model, end_user=end_user, task_id="run-1") + + assert response.get_data(as_text=True) == "data: streamed\n\n" + msg_generator.retrieve_events.assert_called_once_with( + AppMode.WORKFLOW, + "run-1", + terminal_events=None, + ) + workflow_generator.convert_to_event_stream.assert_called_once_with(["raw-event"]) + + def test_running_run_with_snapshot(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + workflow_run = SimpleNamespace( + id="run-1", + app_id="app-1", + created_by_role=CreatorUserRole.END_USER, + created_by="end-user-1", + finished_at=None, + ) + workflow_events_module = _mock_repo_for_run(monkeypatch, workflow_run=workflow_run) + msg_generator = Mock() + workflow_generator = Mock() + workflow_generator.convert_to_event_stream.return_value = iter(["data: snapshot\n\n"]) + snapshot_builder = Mock(return_value=["snapshot-events"]) + monkeypatch.setattr(workflow_events_module, "MessageGenerator", lambda: msg_generator) + monkeypatch.setattr(workflow_events_module, "WorkflowAppGenerator", lambda: workflow_generator) + monkeypatch.setattr(workflow_events_module, "build_workflow_event_stream", snapshot_builder) + + api = WorkflowEventsApi() + handler = _unwrap(api.get) + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value) + end_user = SimpleNamespace(id="end-user-1") + + with app.test_request_context("/workflow/run-1/events?user=u1&include_state_snapshot=true", method="GET"): + response = handler(api, app_model=app_model, end_user=end_user, task_id="run-1") + + assert response.get_data(as_text=True) == "data: snapshot\n\n" + msg_generator.retrieve_events.assert_not_called() + snapshot_builder.assert_called_once() + workflow_generator.convert_to_event_stream.assert_called_once_with(["snapshot-events"]) diff --git a/api/tests/unit_tests/controllers/service_api/conftest.py b/api/tests/unit_tests/controllers/service_api/conftest.py index eddba5a517..8c89812cb4 100644 --- a/api/tests/unit_tests/controllers/service_api/conftest.py +++ b/api/tests/unit_tests/controllers/service_api/conftest.py @@ -15,7 +15,10 @@ from flask import Flask from core.rag.index_processor.constant.index_type import IndexStructureType from models.account import TenantStatus from models.model import App, AppMode, EndUser -from tests.unit_tests.conftest import setup_mock_tenant_account_query +from tests.unit_tests.conftest import ( + setup_mock_dataset_owner_execute_result, + setup_mock_tenant_owner_execute_result, +) @pytest.fixture @@ -123,9 +126,7 @@ class AuthenticationMocker: mock_db.session.get.side_effect = [mock_app, mock_tenant] if mock_account: - mock_ta = Mock() - mock_ta.account_id = mock_account.id - setup_mock_tenant_account_query(mock_db, mock_tenant, mock_ta) + setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account) @staticmethod def setup_dataset_auth(mock_db, mock_tenant, mock_account): @@ -133,8 +134,7 @@ class AuthenticationMocker: mock_ta = Mock() mock_ta.account_id = mock_account.id - mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_ta) - + setup_mock_dataset_owner_execute_result(mock_db, mock_tenant, mock_ta) mock_db.session.get.return_value = mock_account diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_document.py b/api/tests/unit_tests/controllers/service_api/dataset/test_document.py index 288659b192..1b391e67ec 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_document.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_document.py @@ -701,8 +701,8 @@ class TestDocumentApiDelete: ``delete`` is wrapped by ``@cloud_edition_billing_rate_limit_check`` which internally calls ``validate_and_get_api_token``. To bypass the decorator we call the original function via ``__wrapped__`` (preserved by - ``functools.wraps``). ``delete`` queries the dataset via - ``db.session.query(Dataset)`` directly, so we patch ``db`` at the + ``functools.wraps``). ``delete`` loads the dataset via + ``db.session.scalar(select(Dataset)...)``, so we patch ``db`` at the controller module. """ diff --git a/api/tests/unit_tests/controllers/service_api/test_wraps.py b/api/tests/unit_tests/controllers/service_api/test_wraps.py index a2008e024b..6dfbdcf98e 100644 --- a/api/tests/unit_tests/controllers/service_api/test_wraps.py +++ b/api/tests/unit_tests/controllers/service_api/test_wraps.py @@ -24,8 +24,8 @@ from enums.cloud_plan import CloudPlan from models.account import TenantStatus from models.model import ApiToken from tests.unit_tests.conftest import ( - setup_mock_dataset_tenant_query, - setup_mock_tenant_account_query, + setup_mock_dataset_owner_execute_result, + setup_mock_tenant_owner_execute_result, ) @@ -141,14 +141,11 @@ class TestValidateAppToken: mock_account = Mock() mock_account.id = str(uuid.uuid4()) - mock_ta = Mock() - mock_ta.account_id = mock_account.id - # Use side_effect to return app first, then tenant via session.get() mock_db.session.get.side_effect = [mock_app, mock_tenant] - # Mock the tenant owner query (execute(select(...)).one_or_none()) - setup_mock_tenant_account_query(mock_db, mock_tenant, mock_ta) + # Mock the tenant owner execute result (execute(select(...)).one_or_none()) + setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account) @validate_app_token def protected_view(app_model): @@ -471,7 +468,7 @@ class TestValidateDatasetToken: mock_account.current_tenant = mock_tenant # Mock the tenant account join query (execute(select(...)).one_or_none()) - setup_mock_dataset_tenant_query(mock_db, mock_tenant, mock_ta) + setup_mock_dataset_owner_execute_result(mock_db, mock_tenant, mock_ta) # Mock the account lookup via session.get() mock_db.session.get.return_value = mock_account diff --git a/api/tests/unit_tests/controllers/web/conftest.py b/api/tests/unit_tests/controllers/web/conftest.py index 274d78c9cf..b7f3244c6c 100644 --- a/api/tests/unit_tests/controllers/web/conftest.py +++ b/api/tests/unit_tests/controllers/web/conftest.py @@ -22,18 +22,16 @@ class FakeSession: def __init__(self, mapping: dict[str, Any] | None = None): self._mapping: dict[str, Any] = mapping or {} - self._model_name: str | None = None - def query(self, model: type) -> FakeSession: - self._model_name = model.__name__ - return self + def get(self, model: type, _ident: object) -> Any: + return self._mapping.get(model.__name__) - def where(self, *_args: object, **_kwargs: object) -> FakeSession: - return self - - def first(self) -> Any: - assert self._model_name is not None - return self._mapping.get(self._model_name) + def scalar(self, stmt: Any) -> Any: + try: + model = stmt.column_descriptions[0]["entity"] + except (AttributeError, IndexError, KeyError, TypeError): + return None + return self._mapping.get(model.__name__) class FakeDB: diff --git a/api/tests/unit_tests/controllers/web/test_human_input_form.py b/api/tests/unit_tests/controllers/web/test_human_input_form.py index a1dbc80b20..5f2dc19aab 100644 --- a/api/tests/unit_tests/controllers/web/test_human_input_form.py +++ b/api/tests/unit_tests/controllers/web/test_human_input_form.py @@ -36,18 +36,6 @@ class _FakeSession: def __init__(self, mapping: dict[str, Any]): self._mapping = mapping - self._model_name: str | None = None - - def query(self, model): - self._model_name = model.__name__ - return self - - def where(self, *args, **kwargs): - return self - - def first(self): - assert self._model_name is not None - return self._mapping.get(self._model_name) def get(self, model, ident): return self._mapping.get(model.__name__) diff --git a/api/tests/unit_tests/controllers/web/test_web_login.py b/api/tests/unit_tests/controllers/web/test_web_login.py index a01587d64a..13b953c04d 100644 --- a/api/tests/unit_tests/controllers/web/test_web_login.py +++ b/api/tests/unit_tests/controllers/web/test_web_login.py @@ -34,7 +34,6 @@ def _patch_wraps(): patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features), patch("controllers.web.login.dify_config", web_dify), ): - mock_db.session.query.return_value.first.return_value = MagicMock() yield diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py index 45d4b0e321..370f7abb8b 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py @@ -154,7 +154,6 @@ class TestAdvancedChatAppRunnerConversationVariables: mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False) mock_session_class.return_value.__enter__.return_value = MagicMock() - mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists mock_db.engine = MagicMock() # Mock GraphRuntimeState to accept the variable pool @@ -301,7 +300,6 @@ class TestAdvancedChatAppRunnerConversationVariables: mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False) mock_session_class.return_value.__enter__.return_value = MagicMock() - mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists mock_db.engine = MagicMock() # Mock ConversationVariable.from_variable to return mock objects @@ -453,7 +451,6 @@ class TestAdvancedChatAppRunnerConversationVariables: mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False) mock_session_class.return_value.__enter__.return_value = MagicMock() - mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists mock_db.engine = MagicMock() # Mock GraphRuntimeState to accept the variable pool diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py index f2df35d7d0..6debeb4fdd 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py @@ -1,7 +1,10 @@ from collections.abc import Generator +import pytest + from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter from core.app.entities.task_entities import ( + AdvancedChatPausedBlockingResponse, ChatbotAppBlockingResponse, ChatbotAppStreamResponse, ErrorStreamResponse, @@ -10,7 +13,8 @@ from core.app.entities.task_entities import ( NodeStartStreamResponse, PingStreamResponse, ) -from graphon.enums import WorkflowNodeExecutionStatus +from graphon.entities.pause_reason import PauseReasonType +from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus class TestAdvancedChatGenerateResponseConverter: @@ -28,6 +32,37 @@ class TestAdvancedChatGenerateResponseConverter: response = AdvancedChatAppGenerateResponseConverter.convert_blocking_simple_response(blocking) assert "usage" not in response["metadata"] + def test_blocking_full_response_derives_pause_data_from_model_dump(self, monkeypatch: pytest.MonkeyPatch): + data = AdvancedChatPausedBlockingResponse.Data( + id="msg-1", + mode="chat", + conversation_id="c1", + message_id="m1", + workflow_run_id="run-1", + answer="partial", + metadata={"usage": {"total_tokens": 1}}, + created_at=1, + paused_nodes=["node-1"], + reasons=[{"type": PauseReasonType.HUMAN_INPUT_REQUIRED, "form_id": "form-1"}], + status=WorkflowExecutionStatus.PAUSED, + elapsed_time=0.1, + total_tokens=0, + total_steps=0, + ) + original_model_dump = type(data).model_dump + + def _model_dump_with_future_field(self, *args, **kwargs): + payload = original_model_dump(self, *args, **kwargs) + payload["future_field"] = "future-value" + return payload + + monkeypatch.setattr(type(data), "model_dump", _model_dump_with_future_field) + blocking = AdvancedChatPausedBlockingResponse(task_id="t1", data=data) + + response = AdvancedChatAppGenerateResponseConverter.convert_blocking_full_response(blocking) + + assert response["data"]["future_field"] == "future-value" + def test_stream_simple_response_includes_node_events(self): node_start = NodeStartStreamResponse( task_id="t1", diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py index 29fd63c063..64bcfa9a18 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py @@ -39,15 +39,19 @@ from core.app.entities.queue_entities import ( QueueWorkflowSucceededEvent, ) from core.app.entities.task_entities import ( + AdvancedChatPausedBlockingResponse, AnnotationReply, AnnotationReplyAccount, + HumanInputRequiredResponse, MessageAudioStreamResponse, MessageEndStreamResponse, PingStreamResponse, ) from core.base.tts.app_generator_tts_publisher import AudioTrunk from core.workflow.system_variables import build_system_variables +from graphon.entities.pause_reason import PauseReasonType from graphon.enums import BuiltinNodeTypes +from graphon.nodes.human_input.entities import UserAction from graphon.runtime import GraphRuntimeState, VariablePool from libs.datetime_utils import naive_utc_now from models.enums import MessageStatus @@ -123,6 +127,57 @@ class TestAdvancedChatGenerateTaskPipeline: assert response.data.answer == "done" assert response.data.metadata == {"k": "v"} + def test_to_blocking_response_falls_back_to_human_input_required_when_pause_event_missing(self): + pipeline = _make_pipeline() + pipeline._task_state.answer = "partial answer" + pipeline._workflow_run_id = "run-id" + pipeline._graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), + start_at=0.0, + total_tokens=7, + node_run_steps=3, + ) + + def _gen(): + yield HumanInputRequiredResponse( + task_id="task", + workflow_run_id="run-id", + data=HumanInputRequiredResponse.Data( + form_id="form-1", + node_id="node-1", + node_title="Approval", + form_content="Need approval", + inputs=[], + actions=[UserAction(id="approve", title="Approve")], + display_in_ui=True, + form_token="token-1", + resolved_default_values={}, + expiration_time=123, + ), + ) + + response = pipeline._to_blocking_response(_gen()) + + assert isinstance(response, AdvancedChatPausedBlockingResponse) + assert response.data.workflow_run_id == "run-id" + assert response.data.status == "paused" + assert response.data.paused_nodes == ["node-1"] + assert response.data.reasons == [ + { + "TYPE": PauseReasonType.HUMAN_INPUT_REQUIRED, + "form_id": "form-1", + "node_id": "node-1", + "node_title": "Approval", + "form_content": "Need approval", + "inputs": [], + "actions": [{"id": "approve", "title": "Approve", "button_style": "default"}], + "display_in_ui": True, + "form_token": "token-1", + "resolved_default_values": {}, + "expiration_time": 123, + } + ] + def test_handle_text_chunk_event_updates_state(self): pipeline = _make_pipeline() pipeline._message_cycle_manager = SimpleNamespace( diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generator.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generator.py index 9a2dc38f74..c36edf48fc 100644 --- a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generator.py +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generator.py @@ -375,7 +375,7 @@ def test_generate_success_returns_converted(generator, mocker): workflow = MagicMock(id="wf", tenant_id="tenant", app_id="pipe", graph_dict={}) session = MagicMock() - session.query.return_value.where.return_value.first.return_value = workflow + session.get.return_value = workflow mocker.patch.object(module.db, "session", session) queue_manager = MagicMock() diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py index 618c8fd76f..603062a51c 100644 --- a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py @@ -132,11 +132,8 @@ def test_run_pipeline_not_found(mocker): app_generate_entity.single_iteration_run = None app_generate_entity.single_loop_run = None - query = MagicMock() - query.where.return_value.first.return_value = None - session = MagicMock() - session.query.return_value = query + session.get.side_effect = [None, None] mocker.patch.object(module.db, "session", session) runner = PipelineRunner( @@ -157,11 +154,9 @@ def test_run_workflow_not_initialized(mocker): app_generate_entity = _build_app_generate_entity() pipeline = MagicMock(id="pipe") - query_pipeline = MagicMock() - query_pipeline.where.return_value.first.return_value = pipeline session = MagicMock() - session.query.return_value = query_pipeline + session.get.side_effect = [None, pipeline] mocker.patch.object(module.db, "session", session) runner = PipelineRunner( diff --git a/api/tests/unit_tests/core/app/apps/test_base_app_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/test_base_app_generate_response_converter.py new file mode 100644 index 0000000000..560652f8cb --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/test_base_app_generate_response_converter.py @@ -0,0 +1,102 @@ +from __future__ import annotations + +from collections.abc import Generator + +from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.task_entities import ( + AppStreamResponse, + PingStreamResponse, + WorkflowAppBlockingResponse, + WorkflowAppStreamResponse, +) +from graphon.enums import WorkflowExecutionStatus + + +class _DummyConverter(AppGenerateResponseConverter[WorkflowAppBlockingResponse]): + blocking_full_calls: list[WorkflowAppBlockingResponse] = [] + blocking_simple_calls: list[WorkflowAppBlockingResponse] = [] + stream_full_calls: list[Generator[AppStreamResponse, None, None]] = [] + stream_simple_calls: list[Generator[AppStreamResponse, None, None]] = [] + + @classmethod + def reset(cls) -> None: + cls.blocking_full_calls = [] + cls.blocking_simple_calls = [] + cls.stream_full_calls = [] + cls.stream_simple_calls = [] + + @classmethod + def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, object]: + cls.blocking_full_calls.append(blocking_response) + return {"kind": "blocking-full", "task_id": blocking_response.task_id} + + @classmethod + def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, object]: + cls.blocking_simple_calls.append(blocking_response) + return {"kind": "blocking-simple", "task_id": blocking_response.task_id} + + @classmethod + def convert_stream_full_response( + cls, stream_response: Generator[AppStreamResponse, None, None] + ) -> Generator[dict | str, None, None]: + cls.stream_full_calls.append(stream_response) + yield {"kind": "stream-full"} + + @classmethod + def convert_stream_simple_response( + cls, stream_response: Generator[AppStreamResponse, None, None] + ) -> Generator[dict | str, None, None]: + cls.stream_simple_calls.append(stream_response) + yield {"kind": "stream-simple"} + + +def _build_blocking_response() -> WorkflowAppBlockingResponse: + return WorkflowAppBlockingResponse( + task_id="task-1", + workflow_run_id="run-1", + data=WorkflowAppBlockingResponse.Data( + id="run-1", + workflow_id="workflow-1", + status=WorkflowExecutionStatus.SUCCEEDED, + outputs={"ok": True}, + error=None, + elapsed_time=0.1, + total_tokens=0, + total_steps=1, + created_at=1, + finished_at=2, + ), + ) + + +def _build_stream_response() -> Generator[AppStreamResponse, None, None]: + yield WorkflowAppStreamResponse( + workflow_run_id="run-1", + stream_response=PingStreamResponse(task_id="task-1"), + ) + + +def test_convert_routes_blocking_response_by_invoke_from() -> None: + _DummyConverter.reset() + blocking_response = _build_blocking_response() + + full_result = _DummyConverter.convert(blocking_response, InvokeFrom.SERVICE_API) + simple_result = _DummyConverter.convert(blocking_response, InvokeFrom.WEB_APP) + + assert full_result == {"kind": "blocking-full", "task_id": "task-1"} + assert simple_result == {"kind": "blocking-simple", "task_id": "task-1"} + assert _DummyConverter.blocking_full_calls == [blocking_response] + assert _DummyConverter.blocking_simple_calls == [blocking_response] + + +def test_convert_routes_stream_response_by_invoke_from() -> None: + _DummyConverter.reset() + + full_result = list(_DummyConverter.convert(_build_stream_response(), InvokeFrom.SERVICE_API)) + simple_result = list(_DummyConverter.convert(_build_stream_response(), InvokeFrom.WEB_APP)) + + assert full_result == [{"kind": "stream-full"}] + assert simple_result == [{"kind": "stream-simple"}] + assert len(_DummyConverter.stream_full_calls) == 1 + assert len(_DummyConverter.stream_simple_calls) == 1 diff --git a/api/tests/unit_tests/core/app/apps/test_message_generator.py b/api/tests/unit_tests/core/app/apps/test_message_generator.py index 25377e633e..90c9abf35c 100644 --- a/api/tests/unit_tests/core/app/apps/test_message_generator.py +++ b/api/tests/unit_tests/core/app/apps/test_message_generator.py @@ -1,6 +1,7 @@ from unittest.mock import Mock, patch from core.app.apps.message_generator import MessageGenerator +from core.app.entities.task_entities import StreamEvent from models.model import AppMode @@ -23,7 +24,21 @@ class TestMessageGenerator: "core.app.apps.message_generator.stream_topic_events", return_value=iter([{"event": "ping"}]) ) as mock_stream, ): - events = list(MessageGenerator.retrieve_events(AppMode.WORKFLOW, "run-1", idle_timeout=1, ping_interval=2)) + events = list( + MessageGenerator.retrieve_events( + AppMode.WORKFLOW, + "run-1", + idle_timeout=1, + ping_interval=2, + terminal_events=[StreamEvent.WORKFLOW_FINISHED.value], + ) + ) assert events == [{"event": "ping"}] - mock_stream.assert_called_once() + mock_stream.assert_called_once_with( + topic="topic", + idle_timeout=1, + ping_interval=2, + on_subscribe=None, + terminal_events=[StreamEvent.WORKFLOW_FINISHED.value], + ) diff --git a/api/tests/unit_tests/core/app/apps/test_streaming_utils.py b/api/tests/unit_tests/core/app/apps/test_streaming_utils.py index a7714c56ce..58f0e47a4b 100644 --- a/api/tests/unit_tests/core/app/apps/test_streaming_utils.py +++ b/api/tests/unit_tests/core/app/apps/test_streaming_utils.py @@ -88,6 +88,10 @@ def test_normalize_terminal_events_defaults(): } +def test_normalize_terminal_events_empty_values(): + assert _normalize_terminal_events([]) == set({}) + + def test_stream_topic_events_emits_ping_and_idle_timeout(monkeypatch): topic = FakeTopic() times = [1000.0, 1000.0, 1001.0, 1001.0, 1002.0] @@ -106,3 +110,21 @@ def test_stream_topic_events_emits_ping_and_idle_timeout(monkeypatch): assert next(generator) == StreamEvent.PING.value # next receive yields None -> ping interval triggers assert next(generator) == StreamEvent.PING.value + + +def test_stream_topic_events_can_continue_past_pause(): + topic = FakeTopic() + topic.publish(json.dumps({"event": StreamEvent.WORKFLOW_PAUSED.value}).encode()) + topic.publish(json.dumps({"event": StreamEvent.WORKFLOW_FINISHED.value}).encode()) + + generator = stream_topic_events( + topic=topic, + idle_timeout=1.0, + terminal_events=[StreamEvent.WORKFLOW_FINISHED.value], + ) + + assert next(generator) == StreamEvent.PING.value + assert next(generator)["event"] == StreamEvent.WORKFLOW_PAUSED.value + assert next(generator)["event"] == StreamEvent.WORKFLOW_FINISHED.value + with pytest.raises(StopIteration): + next(generator) diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py index 99433478d3..0bcc1029b0 100644 --- a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py +++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py @@ -36,11 +36,12 @@ from core.app.entities.queue_entities import ( ) from core.app.entities.task_entities import ( ErrorStreamResponse, + HumanInputRequiredResponse, MessageAudioEndStreamResponse, MessageAudioStreamResponse, PingStreamResponse, + WorkflowAppPausedBlockingResponse, WorkflowFinishStreamResponse, - WorkflowPauseStreamResponse, WorkflowStartStreamResponse, ) from core.base.tts.app_generator_tts_publisher import AudioTrunk @@ -91,27 +92,50 @@ def _make_pipeline(): class TestWorkflowGenerateTaskPipeline: - def test_to_blocking_response_handles_pause(self): + def test_to_blocking_response_falls_back_to_human_input_required_when_pause_event_missing(self): pipeline = _make_pipeline() + pipeline._graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), + start_at=0.0, + total_tokens=5, + node_run_steps=2, + ) def _gen(): - yield WorkflowPauseStreamResponse( + yield HumanInputRequiredResponse( task_id="task", - workflow_run_id="run", - data=WorkflowPauseStreamResponse.Data( - workflow_run_id="run", - status=WorkflowExecutionStatus.PAUSED, - outputs={}, - created_at=1, - elapsed_time=0.1, - total_tokens=0, - total_steps=0, + workflow_run_id="run-id", + data=HumanInputRequiredResponse.Data( + form_id="form-1", + node_id="node-1", + node_title="Human Input", + form_content="content", + expiration_time=1, ), ) response = pipeline._to_blocking_response(_gen()) + assert isinstance(response, WorkflowAppPausedBlockingResponse) + assert response.workflow_run_id == "run-id" assert response.data.status == WorkflowExecutionStatus.PAUSED + assert response.data.created_at == 0 + assert response.data.paused_nodes == ["node-1"] + assert response.data.reasons == [ + { + "TYPE": "human_input_required", + "form_id": "form-1", + "node_id": "node-1", + "node_title": "Human Input", + "form_content": "content", + "inputs": [], + "actions": [], + "display_in_ui": False, + "form_token": None, + "resolved_default_values": {}, + "expiration_time": 1, + } + ] def test_to_blocking_response_handles_finish(self): pipeline = _make_pipeline() diff --git a/api/tests/unit_tests/core/datasource/test_notion_provider.py b/api/tests/unit_tests/core/datasource/test_notion_provider.py index e4bd7d3bdf..d21b9e471b 100644 --- a/api/tests/unit_tests/core/datasource/test_notion_provider.py +++ b/api/tests/unit_tests/core/datasource/test_notion_provider.py @@ -775,9 +775,6 @@ class TestNotionExtractorLastEditedTime: "last_edited_time": "2024-11-27T18:00:00.000Z", } mock_request.return_value = mock_response - mock_query = Mock() - mock_db.session.query.return_value = mock_query - mock_query.filter_by.return_value = mock_query # Act extractor_page.update_last_edited_time(mock_document_model) @@ -863,9 +860,6 @@ class TestNotionExtractorIntegration: } mock_request.side_effect = [last_edited_response, block_response] - mock_query = Mock() - mock_db.session.query.return_value = mock_query - mock_query.filter_by.return_value = mock_query # Act documents = extractor.extract() @@ -919,10 +913,6 @@ class TestNotionExtractorIntegration: } mock_post.return_value = database_response - mock_query = Mock() - mock_db.session.query.return_value = mock_query - mock_query.filter_by.return_value = mock_query - # Act documents = extractor.extract() diff --git a/api/tests/unit_tests/core/helper/test_creators.py b/api/tests/unit_tests/core/helper/test_creators.py new file mode 100644 index 0000000000..df67d3f513 --- /dev/null +++ b/api/tests/unit_tests/core/helper/test_creators.py @@ -0,0 +1,106 @@ +"""Tests for the Creators Platform helper module.""" + +from unittest.mock import MagicMock, patch + +import httpx +import pytest +from yarl import URL + + +@pytest.fixture(autouse=True) +def _patch_creators_url(monkeypatch): + """Patch the module-level creators_platform_api_url for all tests.""" + monkeypatch.setattr( + "core.helper.creators.creators_platform_api_url", + URL("https://creators.example.com"), + ) + + +class TestUploadDSL: + @patch("core.helper.creators.httpx.post") + def test_returns_claim_code(self, mock_post): + mock_response = MagicMock(spec=httpx.Response) + mock_response.json.return_value = {"data": {"claim_code": "abc123"}} + mock_response.raise_for_status = MagicMock() + mock_post.return_value = mock_response + + from core.helper.creators import upload_dsl + + result = upload_dsl(b"app: demo", "demo.yaml") + + assert result == "abc123" + mock_post.assert_called_once() + call_kwargs = mock_post.call_args + assert "anonymous-upload" in call_kwargs.args[0] + assert call_kwargs.kwargs["timeout"] == 30 + + @patch("core.helper.creators.httpx.post") + def test_raises_on_missing_claim_code(self, mock_post): + mock_response = MagicMock(spec=httpx.Response) + mock_response.json.return_value = {"data": {}} + mock_response.raise_for_status = MagicMock() + mock_post.return_value = mock_response + + from core.helper.creators import upload_dsl + + with pytest.raises(ValueError, match="claim_code"): + upload_dsl(b"app: demo") + + @patch("core.helper.creators.httpx.post") + def test_raises_on_http_error(self, mock_post): + mock_response = MagicMock(spec=httpx.Response) + mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( + "Server Error", + request=MagicMock(), + response=MagicMock(), + ) + mock_post.return_value = mock_response + + from core.helper.creators import upload_dsl + + with pytest.raises(httpx.HTTPStatusError): + upload_dsl(b"app: demo") + + +class TestGetRedirectUrl: + @patch("core.helper.creators.dify_config") + def test_without_oauth_client_id(self, mock_config): + mock_config.CREATORS_PLATFORM_API_URL = "https://creators.example.com" + mock_config.CREATORS_PLATFORM_OAUTH_CLIENT_ID = "" + + from core.helper.creators import get_redirect_url + + url = get_redirect_url("user-1", "claim-abc") + + assert "dsl_claim_code=claim-abc" in url + assert "oauth_code" not in url + assert url.startswith("https://creators.example.com") + + @patch("core.helper.creators.dify_config") + def test_with_oauth_client_id(self, mock_config): + mock_config.CREATORS_PLATFORM_API_URL = "https://creators.example.com" + mock_config.CREATORS_PLATFORM_OAUTH_CLIENT_ID = "client-xyz" + + with patch( + "services.oauth_server.OAuthServerService.sign_oauth_authorization_code", + return_value="oauth-code-123", + ) as mock_sign: + from core.helper.creators import get_redirect_url + + url = get_redirect_url("user-1", "claim-abc") + + mock_sign.assert_called_once_with("client-xyz", "user-1") + assert "dsl_claim_code=claim-abc" in url + assert "oauth_code=oauth-code-123" in url + + @patch("core.helper.creators.dify_config") + def test_strips_trailing_slash(self, mock_config): + mock_config.CREATORS_PLATFORM_API_URL = "https://creators.example.com/" + mock_config.CREATORS_PLATFORM_OAUTH_CLIENT_ID = "" + + from core.helper.creators import get_redirect_url + + url = get_redirect_url("user-1", "claim-abc") + + assert url.startswith("https://creators.example.com?") + assert "creators.example.com/?" not in url diff --git a/api/tests/unit_tests/core/helper/test_encrypter.py b/api/tests/unit_tests/core/helper/test_encrypter.py index f3ef7fccd0..73e081a570 100644 --- a/api/tests/unit_tests/core/helper/test_encrypter.py +++ b/api/tests/unit_tests/core/helper/test_encrypter.py @@ -40,11 +40,11 @@ class TestObfuscatedToken: class TestEncryptToken: @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") - def test_successful_encryption(self, mock_encrypt, mock_query): + def test_successful_encryption(self, mock_encrypt, mock_get): """Test successful token encryption""" mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "mock_public_key" - mock_query.return_value = mock_tenant + mock_get.return_value = mock_tenant mock_encrypt.return_value = b"encrypted_data" result = encrypt_token("tenant-123", "test_token") @@ -53,9 +53,9 @@ class TestEncryptToken: mock_encrypt.assert_called_with("test_token", "mock_public_key") @patch("extensions.ext_database.db.session.get") - def test_tenant_not_found(self, mock_query): + def test_tenant_not_found(self, mock_get): """Test error when tenant doesn't exist""" - mock_query.return_value = None + mock_get.return_value = None with pytest.raises(ValueError) as exc_info: encrypt_token("invalid-tenant", "test_token") @@ -122,12 +122,12 @@ class TestEncryptDecryptIntegration: @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") @patch("libs.rsa.decrypt") - def test_should_encrypt_and_decrypt_consistently(self, mock_decrypt, mock_encrypt, mock_query): + def test_should_encrypt_and_decrypt_consistently(self, mock_decrypt, mock_encrypt, mock_get): """Test that encryption and decryption are consistent""" # Setup mock tenant mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "mock_public_key" - mock_query.return_value = mock_tenant + mock_get.return_value = mock_tenant # Setup mock encryption/decryption original_token = "test_token_123" @@ -148,12 +148,12 @@ class TestSecurity: @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") - def test_cross_tenant_isolation(self, mock_encrypt, mock_query): + def test_cross_tenant_isolation(self, mock_encrypt, mock_get): """Ensure tokens encrypted for one tenant cannot be used by another""" # Setup mock tenant mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "tenant1_public_key" - mock_query.return_value = mock_tenant + mock_get.return_value = mock_tenant mock_encrypt.return_value = b"encrypted_for_tenant1" # Encrypt token for tenant1 @@ -183,10 +183,10 @@ class TestSecurity: @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") - def test_encryption_randomness(self, mock_encrypt, mock_query): + def test_encryption_randomness(self, mock_encrypt, mock_get): """Ensure same plaintext produces different ciphertext""" mock_tenant = MagicMock(encrypt_public_key="key") - mock_query.return_value = mock_tenant + mock_get.return_value = mock_tenant # Different outputs for same input mock_encrypt.side_effect = [b"enc1", b"enc2", b"enc3"] @@ -207,11 +207,11 @@ class TestEdgeCases: @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") - def test_should_handle_empty_token_encryption(self, mock_encrypt, mock_query): + def test_should_handle_empty_token_encryption(self, mock_encrypt, mock_get): """Test encryption of empty token""" mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "mock_public_key" - mock_query.return_value = mock_tenant + mock_get.return_value = mock_tenant mock_encrypt.return_value = b"encrypted_empty" result = encrypt_token("tenant-123", "") @@ -221,11 +221,11 @@ class TestEdgeCases: @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") - def test_should_handle_special_characters_in_token(self, mock_encrypt, mock_query): + def test_should_handle_special_characters_in_token(self, mock_encrypt, mock_get): """Test tokens containing special/unicode characters""" mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "mock_public_key" - mock_query.return_value = mock_tenant + mock_get.return_value = mock_tenant mock_encrypt.return_value = b"encrypted_special" # Test various special characters @@ -244,11 +244,11 @@ class TestEdgeCases: @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") - def test_should_handle_rsa_size_limits(self, mock_encrypt, mock_query): + def test_should_handle_rsa_size_limits(self, mock_encrypt, mock_get): """Test behavior when token exceeds RSA encryption limits""" mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "mock_public_key" - mock_query.return_value = mock_tenant + mock_get.return_value = mock_tenant # RSA 2048-bit can only encrypt ~245 bytes # The actual limit depends on padding scheme diff --git a/api/tests/unit_tests/core/llm_generator/test_llm_generator.py b/api/tests/unit_tests/core/llm_generator/test_llm_generator.py index 3b64ce6b5c..c4e610d5b0 100644 --- a/api/tests/unit_tests/core/llm_generator/test_llm_generator.py +++ b/api/tests/unit_tests/core/llm_generator/test_llm_generator.py @@ -6,10 +6,6 @@ import pytest from core.app.app_config.entities import ModelConfig from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload from core.llm_generator.llm_generator import LLMGenerator -from core.llm_generator.prompts import ( - DEFAULT_SUGGESTED_QUESTIONS_MAX_TOKENS, - DEFAULT_SUGGESTED_QUESTIONS_TEMPERATURE, -) from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult from graphon.model_runtime.entities.model_entities import ModelType from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError @@ -102,8 +98,8 @@ class TestLLMGenerator: assert len(questions) == 2 assert questions[0] == "Question 1?" assert mock_model_instance.invoke_llm.call_args.kwargs["model_parameters"] == { - "max_tokens": DEFAULT_SUGGESTED_QUESTIONS_MAX_TOKENS, - "temperature": DEFAULT_SUGGESTED_QUESTIONS_TEMPERATURE, + "max_tokens": 2560, + "temperature": 0.0, } def test_generate_suggested_questions_after_answer_auth_error(self, mock_model_instance): @@ -181,8 +177,8 @@ class TestLLMGenerator: model_type=ModelType.LLM, ) assert default_model_instance.invoke_llm.call_args.kwargs["model_parameters"] == { - "max_tokens": DEFAULT_SUGGESTED_QUESTIONS_MAX_TOKENS, - "temperature": DEFAULT_SUGGESTED_QUESTIONS_TEMPERATURE, + "max_tokens": 2560, + "temperature": 0.0, } assert default_model_instance.invoke_llm.call_args.kwargs["stop"] == [] @@ -495,7 +491,7 @@ class TestLLMGenerator: def test_instruction_modify_workflow_no_last_run_fallback(self, mock_model_instance, model_config_entity): with patch("extensions.ext_database.db.session") as mock_session: - mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock() + mock_session.return_value.scalar.return_value = MagicMock() workflow = MagicMock() workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "code"}}]}} @@ -521,7 +517,7 @@ class TestLLMGenerator: def test_instruction_modify_workflow_node_type_fallback(self, mock_model_instance, model_config_entity): with patch("extensions.ext_database.db.session") as mock_session: - mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock() + mock_session.return_value.scalar.return_value = MagicMock() workflow = MagicMock() # Cause exception in node_type logic workflow.graph_dict = {"graph": {"nodes": []}} @@ -548,7 +544,7 @@ class TestLLMGenerator: def test_instruction_modify_workflow_empty_agent_log(self, mock_model_instance, model_config_entity): with patch("extensions.ext_database.db.session") as mock_session: - mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock() + mock_session.return_value.scalar.return_value = MagicMock() workflow = MagicMock() workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "llm"}}]}} @@ -636,7 +632,7 @@ class TestLLMGenerator: instance.invoke_llm.return_value = mock_response with patch("extensions.ext_database.db.session") as mock_session: - mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock() + mock_session.return_value.scalar.return_value = MagicMock() workflow = MagicMock() workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "other"}}]}} diff --git a/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba.py b/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba.py index 136ac0c72a..1e91c2dd88 100644 --- a/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba.py +++ b/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba.py @@ -29,15 +29,6 @@ class _Field: return ("in", self._name, tuple(values)) -class _FakeQuery: - def __init__(self): - self.where_calls: list[tuple] = [] - - def where(self, *conditions): - self.where_calls.append(conditions) - return self - - class _FakeExecuteResult: def __init__(self, segments: list[SimpleNamespace]): self._segments = segments diff --git a/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py b/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py index 0baf85c314..b0ecad4d0c 100644 --- a/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py +++ b/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py @@ -109,17 +109,6 @@ class _FakeExecuteResult: return _FakeExecuteScalarResult(self._data) -class _FakeSummaryQuery: - def __init__(self, summaries: list) -> None: - self._summaries = summaries - - def filter(self, *args, **kwargs): - return self - - def all(self) -> list: - return self._summaries - - class _FakeScalarsResult: def __init__(self, data: list) -> None: self._data = data diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_factory.py b/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_factory.py index dc21d378a2..9de04c80ba 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_factory.py +++ b/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_factory.py @@ -372,19 +372,11 @@ def test_vector_delegation_methods(vector_factory_module): def test_search_by_file_handles_missing_and_existing_upload(vector_factory_module, monkeypatch): - class _Field: - def __eq__(self, value): - return value - - upload_query = MagicMock() - upload_query.where.return_value = upload_query - vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) vector._embeddings = MagicMock() vector._vector_processor = MagicMock() mock_session = SimpleNamespace(get=lambda _model, _id: None) - monkeypatch.setattr(vector_factory_module, "UploadFile", SimpleNamespace(id=_Field())) monkeypatch.setattr(vector_factory_module, "db", SimpleNamespace(session=mock_session)) assert vector.search_by_file("file-1") == [] diff --git a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py index 7c4defc180..b4bb343533 100644 --- a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py +++ b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py @@ -1484,11 +1484,8 @@ class TestIndexingRunnerProcessChunk: mock_dependencies["redis"].get.return_value = None - # Mock database query for segment updates - mock_query = MagicMock() - mock_dependencies["db"].session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.update.return_value = None + # Mock database update for segment status + mock_dependencies["db"].session.execute.return_value = None # Create a proper context manager mock mock_context = MagicMock() diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py index 89830f7517..fd607210f1 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py @@ -2417,12 +2417,11 @@ class TestDatasetRetrievalKnowledgeRetrieval: mock_document.data_source_type = "upload_file" mock_document.doc_metadata = {} - mock_session.query.return_value.filter.return_value.all.return_value = [ - mock_dataset_from_db - ] - mock_session.query.return_value.filter.return_value.all.__iter__ = lambda self: iter( - [mock_dataset_from_db, mock_document] - ) + mock_datasets = MagicMock() + mock_datasets.all.return_value = [mock_dataset_from_db] + mock_documents = MagicMock() + mock_documents.all.return_value = [mock_document] + mock_session.scalars.side_effect = [mock_datasets, mock_documents] # Act result = dataset_retrieval.knowledge_retrieval(request) diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_methods.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_methods.py index 90feb4cf01..aace419d15 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_methods.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_methods.py @@ -451,12 +451,11 @@ class TestDatasetRetrievalKnowledgeRetrieval: mock_document.data_source_type = "upload_file" mock_document.doc_metadata = {} - mock_session.query.return_value.filter.return_value.all.return_value = [ - mock_dataset_from_db - ] - mock_session.query.return_value.filter.return_value.all.__iter__ = lambda self: iter( - [mock_dataset_from_db, mock_document] - ) + mock_datasets = MagicMock() + mock_datasets.all.return_value = [mock_dataset_from_db] + mock_documents = MagicMock() + mock_documents.all.return_value = [mock_document] + mock_session.scalars.side_effect = [mock_datasets, mock_documents] # Act result = dataset_retrieval.knowledge_retrieval(request) diff --git a/api/tests/unit_tests/core/test_model_manager.py b/api/tests/unit_tests/core/test_model_manager.py index afea9144c0..5a7e7e30a5 100644 --- a/api/tests/unit_tests/core/test_model_manager.py +++ b/api/tests/unit_tests/core/test_model_manager.py @@ -5,7 +5,7 @@ import redis from pytest_mock import MockerFixture from core.entities.provider_entities import ModelLoadBalancingConfiguration -from core.model_manager import LBModelManager +from core.model_manager import LBModelManager, ModelManager from extensions.ext_redis import redis_client from graphon.model_runtime.entities.model_entities import ModelType @@ -40,6 +40,29 @@ def lb_model_manager(): return lb_model_manager +def test_model_manager_with_cache_enabled_reuses_stored_credentials(): + """With ``enable_credentials_cache=True``, later calls for the same key return cached creds.""" + provider_manager = MagicMock() + bundle = MagicMock() + bundle.configuration.provider.provider = "openai" + bundle.configuration.tenant_id = "tenant-1" + bundle.configuration.model_settings = [] + bundle.model_type_instance.model_type = ModelType.LLM + get_creds = MagicMock(return_value={"api_key": "first"}) + bundle.configuration.get_current_credentials = get_creds + provider_manager.get_provider_model_bundle.return_value = bundle + + manager = ModelManager(provider_manager, enable_credentials_cache=True) + first = manager.get_model_instance("tenant-1", "openai", ModelType.LLM, "gpt-4") + assert first.credentials == {"api_key": "first"} + get_creds.assert_called_once() + + get_creds.return_value = {"api_key": "second"} + second = manager.get_model_instance("tenant-1", "openai", ModelType.LLM, "gpt-4") + assert second.credentials == {"api_key": "first"} + get_creds.assert_called_once() + + def test_lb_model_manager_fetch_next(mocker: MockerFixture, lb_model_manager: LBModelManager): # initialize redis client redis_client.initialize(redis.Redis()) diff --git a/api/tests/unit_tests/core/tools/utils/test_system_oauth_encryption.py b/api/tests/unit_tests/core/tools/utils/test_system_oauth_encryption.py index 5691f33e65..6bb86ebe78 100644 --- a/api/tests/unit_tests/core/tools/utils/test_system_oauth_encryption.py +++ b/api/tests/unit_tests/core/tools/utils/test_system_oauth_encryption.py @@ -2,50 +2,50 @@ from __future__ import annotations import pytest -from core.tools.utils import system_oauth_encryption as oauth_encryption -from core.tools.utils.system_oauth_encryption import OAuthEncryptionError, SystemOAuthEncrypter +from core.tools.utils import system_encryption as encryption +from core.tools.utils.system_encryption import EncryptionError, SystemEncrypter -def test_system_oauth_encrypter_roundtrip(): - encrypter = SystemOAuthEncrypter(secret_key="test-secret") +def test_system_encrypter_roundtrip(): + encrypter = SystemEncrypter(secret_key="test-secret") payload = {"client_id": "cid", "client_secret": "csecret", "grant_type": "authorization_code"} - encrypted = encrypter.encrypt_oauth_params(payload) - decrypted = encrypter.decrypt_oauth_params(encrypted) + encrypted = encrypter.encrypt_params(payload) + decrypted = encrypter.decrypt_params(encrypted) assert encrypted assert dict(decrypted) == payload -def test_system_oauth_encrypter_decrypt_validates_input(): - encrypter = SystemOAuthEncrypter(secret_key="test-secret") +def test_system_encrypter_decrypt_validates_input(): + encrypter = SystemEncrypter(secret_key="test-secret") with pytest.raises(ValueError, match="must be a string"): - encrypter.decrypt_oauth_params(123) # type: ignore[arg-type] + encrypter.decrypt_params(123) # type: ignore[arg-type] with pytest.raises(ValueError, match="cannot be empty"): - encrypter.decrypt_oauth_params("") + encrypter.decrypt_params("") -def test_system_oauth_encrypter_raises_oauth_error_for_invalid_ciphertext(): - encrypter = SystemOAuthEncrypter(secret_key="test-secret") +def test_system_encrypter_raises_error_for_invalid_ciphertext(): + encrypter = SystemEncrypter(secret_key="test-secret") - with pytest.raises(OAuthEncryptionError, match="Decryption failed"): - encrypter.decrypt_oauth_params("not-base64") + with pytest.raises(EncryptionError, match="Decryption failed"): + encrypter.decrypt_params("not-base64") -def test_system_oauth_helpers_use_global_cached_instance(monkeypatch): - monkeypatch.setattr(oauth_encryption, "_oauth_encrypter", None) - monkeypatch.setattr("core.tools.utils.system_oauth_encryption.dify_config.SECRET_KEY", "global-secret") +def test_system_helpers_use_global_cached_instance(monkeypatch): + monkeypatch.setattr(encryption, "_encrypter", None) + monkeypatch.setattr("core.tools.utils.system_encryption.dify_config.SECRET_KEY", "global-secret") - first = oauth_encryption.get_system_oauth_encrypter() - second = oauth_encryption.get_system_oauth_encrypter() + first = encryption.get_system_encrypter() + second = encryption.get_system_encrypter() assert first is second - encrypted = oauth_encryption.encrypt_system_oauth_params({"k": "v"}) - assert oauth_encryption.decrypt_system_oauth_params(encrypted) == {"k": "v"} + encrypted = encryption.encrypt_system_params({"k": "v"}) + assert encryption.decrypt_system_params(encrypted) == {"k": "v"} -def test_create_system_oauth_encrypter_factory(): - encrypter = oauth_encryption.create_system_oauth_encrypter(secret_key="factory-secret") - assert isinstance(encrypter, SystemOAuthEncrypter) +def test_create_system_encrypter_factory(): + encrypter = encryption.create_system_encrypter(secret_key="factory-secret") + assert isinstance(encrypter, SystemEncrypter) diff --git a/api/tests/unit_tests/core/workflow/test_human_input_forms.py b/api/tests/unit_tests/core/workflow/test_human_input_forms.py index 6071a95a57..e508815b35 100644 --- a/api/tests/unit_tests/core/workflow/test_human_input_forms.py +++ b/api/tests/unit_tests/core/workflow/test_human_input_forms.py @@ -1,6 +1,7 @@ from types import SimpleNamespace -from core.workflow.human_input_forms import load_form_tokens_by_form_id +from core.workflow.human_input_forms import _load_form_tokens_by_form_id, load_form_tokens_by_form_id +from core.workflow.human_input_policy import HumanInputSurface from models.human_input import RecipientType @@ -53,3 +54,50 @@ def test_load_form_tokens_by_form_id_ignores_unsupported_recipients() -> None: ) assert load_form_tokens_by_form_id(["form-1"], session=session) == {} + + +def test_load_form_tokens_by_form_id_uses_shared_priority() -> None: + session = _FakeSession( + recipients=[ + SimpleNamespace( + form_id="form-1", + recipient_type=RecipientType.STANDALONE_WEB_APP, + access_token="web-token", + ), + SimpleNamespace( + form_id="form-1", + recipient_type=RecipientType.CONSOLE, + access_token="console-token", + ), + ] + ) + + assert _load_form_tokens_by_form_id(session, ["form-1"]) == {"form-1": "console-token"} + + +def test_load_form_tokens_by_form_id_uses_web_token_for_service_api_surface() -> None: + session = _FakeSession( + recipients=[ + SimpleNamespace( + form_id="form-1", + recipient_type=RecipientType.STANDALONE_WEB_APP, + access_token="web-token", + ), + SimpleNamespace( + form_id="form-1", + recipient_type=RecipientType.CONSOLE, + access_token="console-token", + ), + SimpleNamespace( + form_id="form-1", + recipient_type=RecipientType.BACKSTAGE, + access_token="backstage-token", + ), + ] + ) + + assert load_form_tokens_by_form_id( + ["form-1"], + session=session, + surface=HumanInputSurface.SERVICE_API, + ) == {"form-1": "web-token"} diff --git a/api/tests/unit_tests/core/workflow/test_human_input_policy.py b/api/tests/unit_tests/core/workflow/test_human_input_policy.py new file mode 100644 index 0000000000..e6d0366af5 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/test_human_input_policy.py @@ -0,0 +1,50 @@ +from core.workflow.human_input_policy import ( + HumanInputSurface, + get_preferred_form_token, + is_recipient_type_allowed_for_surface, +) +from models.human_input import RecipientType + + +def test_service_api_only_allows_public_webapp_forms() -> None: + assert is_recipient_type_allowed_for_surface( + RecipientType.STANDALONE_WEB_APP, + HumanInputSurface.SERVICE_API, + ) + assert not is_recipient_type_allowed_for_surface( + RecipientType.CONSOLE, + HumanInputSurface.SERVICE_API, + ) + assert not is_recipient_type_allowed_for_surface( + RecipientType.BACKSTAGE, + HumanInputSurface.SERVICE_API, + ) + assert not is_recipient_type_allowed_for_surface( + RecipientType.EMAIL_MEMBER, + HumanInputSurface.SERVICE_API, + ) + + +def test_console_only_allows_internal_console_surfaces() -> None: + assert is_recipient_type_allowed_for_surface( + RecipientType.CONSOLE, + HumanInputSurface.CONSOLE, + ) + assert is_recipient_type_allowed_for_surface( + RecipientType.BACKSTAGE, + HumanInputSurface.CONSOLE, + ) + assert not is_recipient_type_allowed_for_surface( + RecipientType.STANDALONE_WEB_APP, + HumanInputSurface.CONSOLE, + ) + + +def test_preferred_form_token_uses_shared_priority_order() -> None: + recipients = [ + (RecipientType.STANDALONE_WEB_APP, "web-token"), + (RecipientType.CONSOLE, "console-token"), + (RecipientType.BACKSTAGE, "backstage-token"), + ] + + assert get_preferred_form_token(recipients) == "backstage-token" diff --git a/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py b/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py new file mode 100644 index 0000000000..ac4b087b91 --- /dev/null +++ b/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +from datetime import UTC, datetime +from types import SimpleNamespace + +from graphon.nodes.human_input.entities import FormDefinition, FormInput, UserAction +from graphon.nodes.human_input.enums import FormInputType +from models.human_input import RecipientType +from repositories.sqlalchemy_api_workflow_run_repository import _build_human_input_required_reason + + +def _build_form_model() -> SimpleNamespace: + expiration_time = datetime(2024, 1, 1, tzinfo=UTC) + definition = FormDefinition( + form_content="content", + inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")], + user_actions=[UserAction(id="approve", title="Approve")], + rendered_content="rendered", + expiration_time=expiration_time, + default_values={"name": "Alice"}, + node_title="Ask Name", + display_in_ui=True, + ) + return SimpleNamespace( + id="form-1", + node_id="node-1", + form_definition=definition.model_dump_json(), + expiration_time=expiration_time, + ) + + +def _build_reason_model() -> SimpleNamespace: + return SimpleNamespace(form_id="form-1", node_id="node-1") + + +def test_build_human_input_required_reason_prefers_standalone_web_app_token() -> None: + reason = _build_human_input_required_reason( + _build_reason_model(), + _build_form_model(), + [ + SimpleNamespace(recipient_type=RecipientType.BACKSTAGE, access_token="btok"), + SimpleNamespace(recipient_type=RecipientType.CONSOLE, access_token="ctok"), + SimpleNamespace(recipient_type=RecipientType.STANDALONE_WEB_APP, access_token="wtok"), + ], + ) + + assert reason.node_title == "Ask Name" + assert reason.resolved_default_values == {"name": "Alice"} + assert not hasattr(reason, "form_token") + + +def test_build_human_input_required_reason_falls_back_to_console_token() -> None: + reason = _build_human_input_required_reason( + _build_reason_model(), + _build_form_model(), + [ + SimpleNamespace(recipient_type=RecipientType.BACKSTAGE, access_token="btok"), + SimpleNamespace(recipient_type=RecipientType.CONSOLE, access_token="ctok"), + ], + ) + + assert reason.node_id == "node-1" + assert reason.actions[0].id == "approve" + assert not hasattr(reason, "form_token") diff --git a/api/tests/unit_tests/services/document_indexing_task_proxy.py b/api/tests/unit_tests/services/document_indexing_task_proxy.py deleted file mode 100644 index ff243b8dc3..0000000000 --- a/api/tests/unit_tests/services/document_indexing_task_proxy.py +++ /dev/null @@ -1,1291 +0,0 @@ -""" -Comprehensive unit tests for DocumentIndexingTaskProxy service. - -This module contains extensive unit tests for the DocumentIndexingTaskProxy class, -which is responsible for routing document indexing tasks to appropriate Celery queues -based on tenant billing configuration and managing tenant-isolated task queues. - -The DocumentIndexingTaskProxy handles: -- Task scheduling and queuing (direct vs tenant-isolated queues) -- Priority vs normal task routing based on billing plans -- Tenant isolation using TenantIsolatedTaskQueue -- Batch indexing operations with multiple document IDs -- Error handling and retry logic through queue management - -This test suite ensures: -- Correct task routing based on billing configuration -- Proper tenant isolation queue management -- Accurate batch operation handling -- Comprehensive error condition coverage -- Edge cases are properly handled - -================================================================================ -ARCHITECTURE OVERVIEW -================================================================================ - -The DocumentIndexingTaskProxy is a critical component in the document indexing -workflow. It acts as a proxy/router that determines which Celery queue to use -for document indexing tasks based on tenant billing configuration. - -1. Task Queue Routing: - - Direct Queue: Bypasses tenant isolation, used for self-hosted/enterprise - - Tenant Queue: Uses tenant isolation, queues tasks when another task is running - - Default Queue: Normal priority with tenant isolation (SANDBOX plan) - - Priority Queue: High priority with tenant isolation (TEAM/PRO plans) - - Priority Direct Queue: High priority without tenant isolation (billing disabled) - -2. Tenant Isolation: - - Uses TenantIsolatedTaskQueue to ensure only one indexing task runs per tenant - - When a task is running, new tasks are queued in Redis - - When a task completes, it pulls the next task from the queue - - Prevents resource contention and ensures fair task distribution - -3. Billing Configuration: - - SANDBOX plan: Uses default tenant queue (normal priority, tenant isolated) - - TEAM/PRO plans: Uses priority tenant queue (high priority, tenant isolated) - - Billing disabled: Uses priority direct queue (high priority, no isolation) - -4. Batch Operations: - - Supports indexing multiple documents in a single task - - DocumentTask entity serializes task information - - Tasks are queued with all document IDs for batch processing - -================================================================================ -TESTING STRATEGY -================================================================================ - -This test suite follows a comprehensive testing strategy that covers: - -1. Initialization and Configuration: - - Proxy initialization with various parameters - - TenantIsolatedTaskQueue initialization - - Features property caching - - Edge cases (empty document_ids, single document, large batches) - -2. Task Queue Routing: - - Direct queue routing (bypasses tenant isolation) - - Tenant queue routing with existing task key (pushes to waiting queue) - - Tenant queue routing without task key (sets flag and executes immediately) - - DocumentTask serialization and deserialization - - Task function delay() call with correct parameters - -3. Queue Type Selection: - - Default tenant queue routing (normal_document_indexing_task) - - Priority tenant queue routing (priority_document_indexing_task with isolation) - - Priority direct queue routing (priority_document_indexing_task without isolation) - -4. Dispatch Logic: - - Billing enabled + SANDBOX plan → default tenant queue - - Billing enabled + non-SANDBOX plan (TEAM, PRO, etc.) → priority tenant queue - - Billing disabled (self-hosted/enterprise) → priority direct queue - - All CloudPlan enum values handling - - Edge cases: None plan, empty plan string - -5. Tenant Isolation and Queue Management: - - Task key existence checking (get_task_key) - - Task waiting time setting (set_task_waiting_time) - - Task pushing to queue (push_tasks) - - Queue state transitions (idle → active → idle) - - Multiple concurrent task handling - -6. Batch Operations: - - Single document indexing - - Multiple document batch indexing - - Large batch handling - - Empty batch handling (edge case) - -7. Error Handling and Retry Logic: - - Task function delay() failure handling - - Queue operation failures (Redis errors) - - Feature service failures - - Invalid task data handling - - Retry mechanism through queue pull operations - -8. Integration Points: - - FeatureService integration (billing features, subscription plans) - - TenantIsolatedTaskQueue integration (Redis operations) - - Celery task integration (normal_document_indexing_task, priority_document_indexing_task) - - DocumentTask entity serialization - -================================================================================ -""" - -from unittest.mock import Mock, patch - -import pytest - -from core.entities.document_task import DocumentTask -from core.rag.pipeline.queue import TenantIsolatedTaskQueue -from enums.cloud_plan import CloudPlan -from services.document_indexing_proxy.document_indexing_task_proxy import DocumentIndexingTaskProxy - -# ============================================================================ -# Test Data Factory -# ============================================================================ - - -class DocumentIndexingTaskProxyTestDataFactory: - """ - Factory class for creating test data and mock objects for DocumentIndexingTaskProxy tests. - - This factory provides static methods to create mock objects for: - - FeatureService features with billing configuration - - TenantIsolatedTaskQueue mocks with various states - - DocumentIndexingTaskProxy instances with different configurations - - DocumentTask entities for testing serialization - - The factory methods help maintain consistency across tests and reduce - code duplication when setting up test scenarios. - """ - - @staticmethod - def create_mock_features(billing_enabled: bool = False, plan: CloudPlan = CloudPlan.SANDBOX) -> Mock: - """ - Create mock features with billing configuration. - - This method creates a mock FeatureService features object with - billing configuration that can be used to test different billing - scenarios in the DocumentIndexingTaskProxy. - - Args: - billing_enabled: Whether billing is enabled for the tenant - plan: The CloudPlan enum value for the subscription plan - - Returns: - Mock object configured as FeatureService features with billing info - """ - features = Mock() - - features.billing = Mock() - - features.billing.enabled = billing_enabled - - features.billing.subscription = Mock() - - features.billing.subscription.plan = plan - - return features - - @staticmethod - def create_mock_tenant_queue(has_task_key: bool = False) -> Mock: - """ - Create mock TenantIsolatedTaskQueue. - - This method creates a mock TenantIsolatedTaskQueue that can simulate - different queue states for testing tenant isolation logic. - - Args: - has_task_key: Whether the queue has an active task key (task running) - - Returns: - Mock object configured as TenantIsolatedTaskQueue - """ - queue = Mock(spec=TenantIsolatedTaskQueue) - - queue.get_task_key.return_value = "task_key" if has_task_key else None - - queue.push_tasks = Mock() - - queue.set_task_waiting_time = Mock() - - queue.delete_task_key = Mock() - - return queue - - @staticmethod - def create_document_task_proxy( - tenant_id: str = "tenant-123", dataset_id: str = "dataset-456", document_ids: list[str] | None = None - ) -> DocumentIndexingTaskProxy: - """ - Create DocumentIndexingTaskProxy instance for testing. - - This method creates a DocumentIndexingTaskProxy instance with default - or specified parameters for use in test cases. - - Args: - tenant_id: Tenant identifier for the proxy - dataset_id: Dataset identifier for the proxy - document_ids: List of document IDs to index (defaults to 3 documents) - - Returns: - DocumentIndexingTaskProxy instance configured for testing - """ - if document_ids is None: - document_ids = ["doc-1", "doc-2", "doc-3"] - - return DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) - - @staticmethod - def create_document_task( - tenant_id: str = "tenant-123", dataset_id: str = "dataset-456", document_ids: list[str] | None = None - ) -> DocumentTask: - """ - Create DocumentTask entity for testing. - - This method creates a DocumentTask entity that can be used to test - task serialization and deserialization logic. - - Args: - tenant_id: Tenant identifier for the task - dataset_id: Dataset identifier for the task - document_ids: List of document IDs to index (defaults to 3 documents) - - Returns: - DocumentTask entity configured for testing - """ - if document_ids is None: - document_ids = ["doc-1", "doc-2", "doc-3"] - - return DocumentTask(tenant_id=tenant_id, dataset_id=dataset_id, document_ids=document_ids) - - -# ============================================================================ -# Test Classes -# ============================================================================ - - -class TestDocumentIndexingTaskProxy: - """ - Comprehensive unit tests for DocumentIndexingTaskProxy class. - - This test class covers all methods and scenarios of the DocumentIndexingTaskProxy, - including initialization, task routing, queue management, dispatch logic, and - error handling. - """ - - # ======================================================================== - # Initialization Tests - # ======================================================================== - - def test_initialization(self): - """ - Test DocumentIndexingTaskProxy initialization. - - This test verifies that the proxy is correctly initialized with - the provided tenant_id, dataset_id, and document_ids, and that - the TenantIsolatedTaskQueue is properly configured. - """ - # Arrange - tenant_id = "tenant-123" - - dataset_id = "dataset-456" - - document_ids = ["doc-1", "doc-2", "doc-3"] - - # Act - proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) - - # Assert - assert proxy._tenant_id == tenant_id - - assert proxy._dataset_id == dataset_id - - assert proxy._document_ids == document_ids - - assert isinstance(proxy._tenant_isolated_task_queue, TenantIsolatedTaskQueue) - - assert proxy._tenant_isolated_task_queue._tenant_id == tenant_id - - assert proxy._tenant_isolated_task_queue._unique_key == "document_indexing" - - def test_initialization_with_empty_document_ids(self): - """ - Test initialization with empty document_ids list. - - This test verifies that the proxy can be initialized with an empty - document_ids list, which may occur in edge cases or error scenarios. - """ - # Arrange - tenant_id = "tenant-123" - - dataset_id = "dataset-456" - - document_ids = [] - - # Act - proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) - - # Assert - assert proxy._tenant_id == tenant_id - - assert proxy._dataset_id == dataset_id - - assert proxy._document_ids == document_ids - - assert len(proxy._document_ids) == 0 - - def test_initialization_with_single_document_id(self): - """ - Test initialization with single document_id. - - This test verifies that the proxy can be initialized with a single - document ID, which is a common use case for single document indexing. - """ - # Arrange - tenant_id = "tenant-123" - - dataset_id = "dataset-456" - - document_ids = ["doc-1"] - - # Act - proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) - - # Assert - assert proxy._tenant_id == tenant_id - - assert proxy._dataset_id == dataset_id - - assert proxy._document_ids == document_ids - - assert len(proxy._document_ids) == 1 - - def test_initialization_with_large_batch(self): - """ - Test initialization with large batch of document IDs. - - This test verifies that the proxy can handle large batches of - document IDs, which may occur in bulk indexing scenarios. - """ - # Arrange - tenant_id = "tenant-123" - - dataset_id = "dataset-456" - - document_ids = [f"doc-{i}" for i in range(100)] - - # Act - proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) - - # Assert - assert proxy._tenant_id == tenant_id - - assert proxy._dataset_id == dataset_id - - assert proxy._document_ids == document_ids - - assert len(proxy._document_ids) == 100 - - # ======================================================================== - # Features Property Tests - # ======================================================================== - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") - def test_features_property(self, mock_feature_service): - """ - Test cached_property features. - - This test verifies that the features property is correctly cached - and that FeatureService.get_features is called only once, even when - the property is accessed multiple times. - """ - # Arrange - mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features() - - mock_feature_service.get_features.return_value = mock_features - - proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() - - # Act - features1 = proxy.features - - features2 = proxy.features # Second call should use cached property - - # Assert - assert features1 == mock_features - - assert features2 == mock_features - - assert features1 is features2 # Should be the same instance due to caching - - mock_feature_service.get_features.assert_called_once_with("tenant-123") - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") - def test_features_property_with_different_tenants(self, mock_feature_service): - """ - Test features property with different tenant IDs. - - This test verifies that the features property correctly calls - FeatureService.get_features with the correct tenant_id for each - proxy instance. - """ - # Arrange - mock_features1 = DocumentIndexingTaskProxyTestDataFactory.create_mock_features() - - mock_features2 = DocumentIndexingTaskProxyTestDataFactory.create_mock_features() - - mock_feature_service.get_features.side_effect = [mock_features1, mock_features2] - - proxy1 = DocumentIndexingTaskProxy("tenant-1", "dataset-1", ["doc-1"]) - - proxy2 = DocumentIndexingTaskProxy("tenant-2", "dataset-2", ["doc-2"]) - - # Act - features1 = proxy1.features - - features2 = proxy2.features - - # Assert - assert features1 == mock_features1 - - assert features2 == mock_features2 - - mock_feature_service.get_features.assert_any_call("tenant-1") - - mock_feature_service.get_features.assert_any_call("tenant-2") - - # ======================================================================== - # Direct Queue Routing Tests - # ======================================================================== - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") - def test_send_to_direct_queue(self, mock_task): - """ - Test _send_to_direct_queue method. - - This test verifies that _send_to_direct_queue correctly calls - task_func.delay() with the correct parameters, bypassing tenant - isolation queue management. - """ - # Arrange - tenant_id = "tenant-direct-queue" - dataset_id = "dataset-direct-queue" - document_ids = ["doc-direct-1", "doc-direct-2"] - proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) - mock_task.delay = Mock() - - # Act - proxy._send_to_direct_queue(mock_task) - - # Assert - mock_task.delay.assert_called_once_with(tenant_id=tenant_id, dataset_id=dataset_id, document_ids=document_ids) - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task") - def test_send_to_direct_queue_with_priority_task(self, mock_task): - """ - Test _send_to_direct_queue with priority task function. - - This test verifies that _send_to_direct_queue works correctly - with priority_document_indexing_task as the task function. - """ - # Arrange - proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() - - mock_task.delay = Mock() - - # Act - proxy._send_to_direct_queue(mock_task) - - # Assert - mock_task.delay.assert_called_once_with( - tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"] - ) - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") - def test_send_to_direct_queue_with_single_document(self, mock_task): - """ - Test _send_to_direct_queue with single document ID. - - This test verifies that _send_to_direct_queue correctly handles - a single document ID in the document_ids list. - """ - # Arrange - proxy = DocumentIndexingTaskProxy("tenant-123", "dataset-456", ["doc-1"]) - - mock_task.delay = Mock() - - # Act - proxy._send_to_direct_queue(mock_task) - - # Assert - mock_task.delay.assert_called_once_with( - tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1"] - ) - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") - def test_send_to_direct_queue_with_empty_documents(self, mock_task): - """ - Test _send_to_direct_queue with empty document_ids list. - - This test verifies that _send_to_direct_queue correctly handles - an empty document_ids list, which may occur in edge cases. - """ - # Arrange - proxy = DocumentIndexingTaskProxy("tenant-123", "dataset-456", []) - - mock_task.delay = Mock() - - # Act - proxy._send_to_direct_queue(mock_task) - - # Assert - mock_task.delay.assert_called_once_with(tenant_id="tenant-123", dataset_id="dataset-456", document_ids=[]) - - # ======================================================================== - # Tenant Queue Routing Tests - # ======================================================================== - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") - def test_send_to_tenant_queue_with_existing_task_key(self, mock_task): - """ - Test _send_to_tenant_queue when task key exists. - - This test verifies that when a task key exists (indicating another - task is running), the new task is pushed to the waiting queue instead - of being executed immediately. - """ - # Arrange - proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() - - proxy._tenant_isolated_task_queue = DocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue( - has_task_key=True - ) - - mock_task.delay = Mock() - - # Act - proxy._send_to_tenant_queue(mock_task) - - # Assert - proxy._tenant_isolated_task_queue.push_tasks.assert_called_once() - - pushed_tasks = proxy._tenant_isolated_task_queue.push_tasks.call_args[0][0] - - assert len(pushed_tasks) == 1 - - expected_task_data = { - "tenant_id": "tenant-123", - "dataset_id": "dataset-456", - "document_ids": ["doc-1", "doc-2", "doc-3"], - } - assert pushed_tasks[0] == expected_task_data - - assert pushed_tasks[0]["document_ids"] == ["doc-1", "doc-2", "doc-3"] - - mock_task.delay.assert_not_called() - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") - def test_send_to_tenant_queue_without_task_key(self, mock_task): - """ - Test _send_to_tenant_queue when no task key exists. - - This test verifies that when no task key exists (indicating no task - is currently running), the task is executed immediately and the - task waiting time flag is set. - """ - # Arrange - proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() - - proxy._tenant_isolated_task_queue = DocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue( - has_task_key=False - ) - - mock_task.delay = Mock() - - # Act - proxy._send_to_tenant_queue(mock_task) - - # Assert - proxy._tenant_isolated_task_queue.set_task_waiting_time.assert_called_once() - - mock_task.delay.assert_called_once_with( - tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"] - ) - - proxy._tenant_isolated_task_queue.push_tasks.assert_not_called() - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task") - def test_send_to_tenant_queue_with_priority_task(self, mock_task): - """ - Test _send_to_tenant_queue with priority task function. - - This test verifies that _send_to_tenant_queue works correctly - with priority_document_indexing_task as the task function. - """ - # Arrange - proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() - - proxy._tenant_isolated_task_queue = DocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue( - has_task_key=False - ) - - mock_task.delay = Mock() - - # Act - proxy._send_to_tenant_queue(mock_task) - - # Assert - proxy._tenant_isolated_task_queue.set_task_waiting_time.assert_called_once() - - mock_task.delay.assert_called_once_with( - tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"] - ) - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") - def test_send_to_tenant_queue_document_task_serialization(self, mock_task): - """ - Test DocumentTask serialization in _send_to_tenant_queue. - - This test verifies that DocumentTask entities are correctly - serialized to dictionaries when pushing to the waiting queue. - """ - # Arrange - proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() - - proxy._tenant_isolated_task_queue = DocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue( - has_task_key=True - ) - - mock_task.delay = Mock() - - # Act - proxy._send_to_tenant_queue(mock_task) - - # Assert - pushed_tasks = proxy._tenant_isolated_task_queue.push_tasks.call_args[0][0] - - task_dict = pushed_tasks[0] - - # Verify the task can be deserialized back to DocumentTask - document_task = DocumentTask(**task_dict) - - assert document_task.tenant_id == "tenant-123" - - assert document_task.dataset_id == "dataset-456" - - assert document_task.document_ids == ["doc-1", "doc-2", "doc-3"] - - # ======================================================================== - # Queue Type Selection Tests - # ======================================================================== - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") - def test_send_to_default_tenant_queue(self, mock_task): - """ - Test _send_to_default_tenant_queue method. - - This test verifies that _send_to_default_tenant_queue correctly - calls _send_to_tenant_queue with normal_document_indexing_task. - """ - # Arrange - proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() - - proxy._send_to_tenant_queue = Mock() - - # Act - proxy._send_to_default_tenant_queue() - - # Assert - proxy._send_to_tenant_queue.assert_called_once_with(mock_task) - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task") - def test_send_to_priority_tenant_queue(self, mock_task): - """ - Test _send_to_priority_tenant_queue method. - - This test verifies that _send_to_priority_tenant_queue correctly - calls _send_to_tenant_queue with priority_document_indexing_task. - """ - # Arrange - proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() - - proxy._send_to_tenant_queue = Mock() - - # Act - proxy._send_to_priority_tenant_queue() - - # Assert - proxy._send_to_tenant_queue.assert_called_once_with(mock_task) - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task") - def test_send_to_priority_direct_queue(self, mock_task): - """ - Test _send_to_priority_direct_queue method. - - This test verifies that _send_to_priority_direct_queue correctly - calls _send_to_direct_queue with priority_document_indexing_task. - """ - # Arrange - proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() - - proxy._send_to_direct_queue = Mock() - - # Act - proxy._send_to_priority_direct_queue() - - # Assert - proxy._send_to_direct_queue.assert_called_once_with(mock_task) - - # ======================================================================== - # Dispatch Logic Tests - # ======================================================================== - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") - def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_feature_service): - """ - Test _dispatch method when billing is enabled with SANDBOX plan. - - This test verifies that when billing is enabled and the subscription - plan is SANDBOX, the dispatch method routes to the default tenant queue. - """ - # Arrange - mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features( - billing_enabled=True, plan=CloudPlan.SANDBOX - ) - - mock_feature_service.get_features.return_value = mock_features - - proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() - - proxy._send_to_default_tenant_queue = Mock() - - # Act - proxy._dispatch() - - # Assert - proxy._send_to_default_tenant_queue.assert_called_once() - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") - def test_dispatch_with_billing_enabled_team_plan(self, mock_feature_service): - """ - Test _dispatch method when billing is enabled with TEAM plan. - - This test verifies that when billing is enabled and the subscription - plan is TEAM, the dispatch method routes to the priority tenant queue. - """ - # Arrange - mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features( - billing_enabled=True, plan=CloudPlan.TEAM - ) - - mock_feature_service.get_features.return_value = mock_features - - proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() - - proxy._send_to_priority_tenant_queue = Mock() - - # Act - proxy._dispatch() - - # Assert - proxy._send_to_priority_tenant_queue.assert_called_once() - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") - def test_dispatch_with_billing_enabled_professional_plan(self, mock_feature_service): - """ - Test _dispatch method when billing is enabled with PROFESSIONAL plan. - - This test verifies that when billing is enabled and the subscription - plan is PROFESSIONAL, the dispatch method routes to the priority tenant queue. - """ - # Arrange - mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features( - billing_enabled=True, plan=CloudPlan.PROFESSIONAL - ) - - mock_feature_service.get_features.return_value = mock_features - - proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() - - proxy._send_to_priority_tenant_queue = Mock() - - # Act - proxy._dispatch() - - # Assert - proxy._send_to_priority_tenant_queue.assert_called_once() - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") - def test_dispatch_with_billing_disabled(self, mock_feature_service): - """ - Test _dispatch method when billing is disabled. - - This test verifies that when billing is disabled (e.g., self-hosted - or enterprise), the dispatch method routes to the priority direct queue. - """ - # Arrange - mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(billing_enabled=False) - - mock_feature_service.get_features.return_value = mock_features - - proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() - - proxy._send_to_priority_direct_queue = Mock() - - # Act - proxy._dispatch() - - # Assert - proxy._send_to_priority_direct_queue.assert_called_once() - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") - def test_dispatch_edge_case_empty_plan(self, mock_feature_service): - """ - Test _dispatch method with empty plan string. - - This test verifies that when billing is enabled but the plan is an - empty string, the dispatch method routes to the priority tenant queue - (treats it as a non-SANDBOX plan). - """ - # Arrange - mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(billing_enabled=True, plan="") - - mock_feature_service.get_features.return_value = mock_features - - proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() - - proxy._send_to_priority_tenant_queue = Mock() - - # Act - proxy._dispatch() - - # Assert - proxy._send_to_priority_tenant_queue.assert_called_once() - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") - def test_dispatch_edge_case_none_plan(self, mock_feature_service): - """ - Test _dispatch method with None plan. - - This test verifies that when billing is enabled but the plan is None, - the dispatch method routes to the priority tenant queue (treats it as - a non-SANDBOX plan). - """ - # Arrange - mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(billing_enabled=True, plan=None) - - mock_feature_service.get_features.return_value = mock_features - - proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() - - proxy._send_to_priority_tenant_queue = Mock() - - # Act - proxy._dispatch() - - # Assert - proxy._send_to_priority_tenant_queue.assert_called_once() - - # ======================================================================== - # Delay Method Tests - # ======================================================================== - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") - def test_delay_method(self, mock_feature_service): - """ - Test delay method integration. - - This test verifies that the delay method correctly calls _dispatch, - which is the public interface for scheduling document indexing tasks. - """ - # Arrange - mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features( - billing_enabled=True, plan=CloudPlan.SANDBOX - ) - - mock_feature_service.get_features.return_value = mock_features - - proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() - - proxy._send_to_default_tenant_queue = Mock() - - # Act - proxy.delay() - - # Assert - proxy._send_to_default_tenant_queue.assert_called_once() - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") - def test_delay_method_with_team_plan(self, mock_feature_service): - """ - Test delay method with TEAM plan. - - This test verifies that the delay method correctly routes to the - priority tenant queue when the subscription plan is TEAM. - """ - # Arrange - mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features( - billing_enabled=True, plan=CloudPlan.TEAM - ) - - mock_feature_service.get_features.return_value = mock_features - - proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() - - proxy._send_to_priority_tenant_queue = Mock() - - # Act - proxy.delay() - - # Assert - proxy._send_to_priority_tenant_queue.assert_called_once() - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") - def test_delay_method_with_billing_disabled(self, mock_feature_service): - """ - Test delay method with billing disabled. - - This test verifies that the delay method correctly routes to the - priority direct queue when billing is disabled. - """ - # Arrange - mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(billing_enabled=False) - - mock_feature_service.get_features.return_value = mock_features - - proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() - - proxy._send_to_priority_direct_queue = Mock() - - # Act - proxy.delay() - - # Assert - proxy._send_to_priority_direct_queue.assert_called_once() - - # ======================================================================== - # DocumentTask Entity Tests - # ======================================================================== - - def test_document_task_dataclass(self): - """ - Test DocumentTask dataclass. - - This test verifies that DocumentTask entities can be created and - accessed correctly, which is important for task serialization. - """ - # Arrange - tenant_id = "tenant-123" - - dataset_id = "dataset-456" - - document_ids = ["doc-1", "doc-2"] - - # Act - task = DocumentTask(tenant_id=tenant_id, dataset_id=dataset_id, document_ids=document_ids) - - # Assert - assert task.tenant_id == tenant_id - - assert task.dataset_id == dataset_id - - assert task.document_ids == document_ids - - def test_document_task_serialization(self): - """ - Test DocumentTask serialization to dictionary. - - This test verifies that DocumentTask entities can be correctly - serialized to dictionaries using asdict() for queue storage. - """ - # Arrange - from dataclasses import asdict - - task = DocumentIndexingTaskProxyTestDataFactory.create_document_task() - - # Act - task_dict = asdict(task) - - # Assert - assert task_dict["tenant_id"] == "tenant-123" - - assert task_dict["dataset_id"] == "dataset-456" - - assert task_dict["document_ids"] == ["doc-1", "doc-2", "doc-3"] - - def test_document_task_deserialization(self): - """ - Test DocumentTask deserialization from dictionary. - - This test verifies that DocumentTask entities can be correctly - deserialized from dictionaries when pulled from the queue. - """ - # Arrange - task_dict = { - "tenant_id": "tenant-123", - "dataset_id": "dataset-456", - "document_ids": ["doc-1", "doc-2", "doc-3"], - } - - # Act - task = DocumentTask(**task_dict) - - # Assert - assert task.tenant_id == "tenant-123" - - assert task.dataset_id == "dataset-456" - - assert task.document_ids == ["doc-1", "doc-2", "doc-3"] - - # ======================================================================== - # Batch Operations Tests - # ======================================================================== - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") - def test_batch_operation_with_multiple_documents(self, mock_task): - """ - Test batch operation with multiple documents. - - This test verifies that the proxy correctly handles batch operations - with multiple document IDs in a single task. - """ - # Arrange - document_ids = [f"doc-{i}" for i in range(10)] - - proxy = DocumentIndexingTaskProxy("tenant-123", "dataset-456", document_ids) - - mock_task.delay = Mock() - - # Act - proxy._send_to_direct_queue(mock_task) - - # Assert - mock_task.delay.assert_called_once_with( - tenant_id="tenant-123", dataset_id="dataset-456", document_ids=document_ids - ) - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") - def test_batch_operation_with_large_batch(self, mock_task): - """ - Test batch operation with large batch of documents. - - This test verifies that the proxy correctly handles large batches - of document IDs, which may occur in bulk indexing scenarios. - """ - # Arrange - document_ids = [f"doc-{i}" for i in range(100)] - - proxy = DocumentIndexingTaskProxy("tenant-123", "dataset-456", document_ids) - - mock_task.delay = Mock() - - # Act - proxy._send_to_direct_queue(mock_task) - - # Assert - mock_task.delay.assert_called_once_with( - tenant_id="tenant-123", dataset_id="dataset-456", document_ids=document_ids - ) - - assert len(mock_task.delay.call_args[1]["document_ids"]) == 100 - - # ======================================================================== - # Error Handling Tests - # ======================================================================== - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") - def test_send_to_direct_queue_task_delay_failure(self, mock_task): - """ - Test _send_to_direct_queue when task.delay() raises an exception. - - This test verifies that exceptions raised by task.delay() are - propagated correctly and not swallowed. - """ - # Arrange - proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() - - mock_task.delay.side_effect = Exception("Task delay failed") - - # Act & Assert - with pytest.raises(Exception, match="Task delay failed"): - proxy._send_to_direct_queue(mock_task) - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") - def test_send_to_tenant_queue_push_tasks_failure(self, mock_task): - """ - Test _send_to_tenant_queue when push_tasks raises an exception. - - This test verifies that exceptions raised by push_tasks are - propagated correctly when a task key exists. - """ - # Arrange - proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() - - mock_queue = DocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue(has_task_key=True) - - mock_queue.push_tasks.side_effect = Exception("Push tasks failed") - - proxy._tenant_isolated_task_queue = mock_queue - - # Act & Assert - with pytest.raises(Exception, match="Push tasks failed"): - proxy._send_to_tenant_queue(mock_task) - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") - def test_send_to_tenant_queue_set_waiting_time_failure(self, mock_task): - """ - Test _send_to_tenant_queue when set_task_waiting_time raises an exception. - - This test verifies that exceptions raised by set_task_waiting_time are - propagated correctly when no task key exists. - """ - # Arrange - proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() - - mock_queue = DocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue(has_task_key=False) - - mock_queue.set_task_waiting_time.side_effect = Exception("Set waiting time failed") - - proxy._tenant_isolated_task_queue = mock_queue - - # Act & Assert - with pytest.raises(Exception, match="Set waiting time failed"): - proxy._send_to_tenant_queue(mock_task) - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") - def test_dispatch_feature_service_failure(self, mock_feature_service): - """ - Test _dispatch when FeatureService.get_features raises an exception. - - This test verifies that exceptions raised by FeatureService.get_features - are propagated correctly during dispatch. - """ - # Arrange - mock_feature_service.get_features.side_effect = Exception("Feature service failed") - - proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() - - # Act & Assert - with pytest.raises(Exception, match="Feature service failed"): - proxy._dispatch() - - # ======================================================================== - # Integration Tests - # ======================================================================== - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") - @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") - def test_full_flow_sandbox_plan(self, mock_task, mock_feature_service): - """ - Test full flow for SANDBOX plan with tenant queue. - - This test verifies the complete flow from delay() call to task - scheduling for a SANDBOX plan tenant, including tenant isolation. - """ - # Arrange - mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features( - billing_enabled=True, plan=CloudPlan.SANDBOX - ) - - mock_feature_service.get_features.return_value = mock_features - - proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() - - proxy._tenant_isolated_task_queue = DocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue( - has_task_key=False - ) - - mock_task.delay = Mock() - - # Act - proxy.delay() - - # Assert - proxy._tenant_isolated_task_queue.set_task_waiting_time.assert_called_once() - - mock_task.delay.assert_called_once_with( - tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"] - ) - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") - @patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task") - def test_full_flow_team_plan(self, mock_task, mock_feature_service): - """ - Test full flow for TEAM plan with priority tenant queue. - - This test verifies the complete flow from delay() call to task - scheduling for a TEAM plan tenant, including priority routing. - """ - # Arrange - mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features( - billing_enabled=True, plan=CloudPlan.TEAM - ) - - mock_feature_service.get_features.return_value = mock_features - - proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() - - proxy._tenant_isolated_task_queue = DocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue( - has_task_key=False - ) - - mock_task.delay = Mock() - - # Act - proxy.delay() - - # Assert - proxy._tenant_isolated_task_queue.set_task_waiting_time.assert_called_once() - - mock_task.delay.assert_called_once_with( - tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"] - ) - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") - @patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task") - def test_full_flow_billing_disabled(self, mock_task, mock_feature_service): - """ - Test full flow for billing disabled (self-hosted/enterprise). - - This test verifies the complete flow from delay() call to task - scheduling when billing is disabled, using priority direct queue. - """ - # Arrange - mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(billing_enabled=False) - - mock_feature_service.get_features.return_value = mock_features - - proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() - - mock_task.delay = Mock() - - # Act - proxy.delay() - - # Assert - mock_task.delay.assert_called_once_with( - tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"] - ) - - @patch("services.document_indexing_task_proxy.FeatureService") - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") - def test_full_flow_with_existing_task_key(self, mock_task, mock_feature_service): - """ - Test full flow when task key exists (task queuing). - - This test verifies the complete flow when another task is already - running, ensuring the new task is queued correctly. - """ - # Arrange - mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features( - billing_enabled=True, plan=CloudPlan.SANDBOX - ) - - mock_feature_service.get_features.return_value = mock_features - - proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() - - proxy._tenant_isolated_task_queue = DocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue( - has_task_key=True - ) - - mock_task.delay = Mock() - - # Act - proxy.delay() - - # Assert - proxy._tenant_isolated_task_queue.push_tasks.assert_called_once() - - pushed_tasks = proxy._tenant_isolated_task_queue.push_tasks.call_args[0][0] - - expected_task_data = { - "tenant_id": "tenant-123", - "dataset_id": "dataset-456", - "document_ids": ["doc-1", "doc-2", "doc-3"], - } - assert pushed_tasks[0] == expected_task_data - - assert pushed_tasks[0]["document_ids"] == ["doc-1", "doc-2", "doc-3"] - - mock_task.delay.assert_not_called() diff --git a/api/tests/unit_tests/services/external_dataset_service.py b/api/tests/unit_tests/services/external_dataset_service.py deleted file mode 100644 index 83bae370eb..0000000000 --- a/api/tests/unit_tests/services/external_dataset_service.py +++ /dev/null @@ -1,925 +0,0 @@ -""" -Extensive unit tests for ``ExternalDatasetService``. - -This module focuses on the *external dataset service* surface area, which is responsible -for integrating with **external knowledge APIs** and wiring them into Dify datasets. - -The goal of this test suite is twofold: - -- Provide **high‑confidence regression coverage** for all public helpers on - ``ExternalDatasetService``. -- Serve as **executable documentation** for how external API integration is expected - to behave in different scenarios (happy paths, validation failures, and error codes). - -The file intentionally contains **rich comments and generous spacing** in order to make -each scenario easy to scan during reviews. -""" - -from __future__ import annotations - -from types import SimpleNamespace -from typing import Any, cast -from unittest.mock import MagicMock, Mock, patch - -import httpx -import pytest - -from constants import HIDDEN_VALUE -from models.dataset import Dataset, ExternalKnowledgeApis, ExternalKnowledgeBindings -from services.entities.external_knowledge_entities.external_knowledge_entities import ( - Authorization, - AuthorizationConfig, - ExternalKnowledgeApiSetting, -) -from services.errors.dataset import DatasetNameDuplicateError -from services.external_knowledge_service import ExternalDatasetService - - -class ExternalDatasetTestDataFactory: - """ - Factory helpers for building *lightweight* mocks for external knowledge tests. - - These helpers are intentionally small and explicit: - - - They avoid pulling in unnecessary fixtures. - - They reflect the minimal contract that the service under test cares about. - """ - - @staticmethod - def create_external_api( - api_id: str = "api-123", - tenant_id: str = "tenant-1", - name: str = "Test API", - description: str = "Description", - settings: dict[str, Any] | None = None, - ) -> ExternalKnowledgeApis: - """ - Create a concrete ``ExternalKnowledgeApis`` instance with minimal fields. - - Using the real SQLAlchemy model (instead of a pure Mock) makes it easier to - exercise ``settings_dict`` and other convenience properties if needed. - """ - - instance = ExternalKnowledgeApis( - tenant_id=tenant_id, - name=name, - description=description, - settings=None if settings is None else cast(str, pytest.approx), # type: ignore[assignment] - ) - - # Overwrite generated id for determinism in assertions. - instance.id = api_id - return instance - - @staticmethod - def create_dataset( - dataset_id: str = "ds-1", - tenant_id: str = "tenant-1", - name: str = "External Dataset", - provider: str = "external", - ) -> Dataset: - """ - Build a small ``Dataset`` instance representing an external dataset. - """ - - dataset = Dataset( - tenant_id=tenant_id, - name=name, - description="", - provider=provider, - created_by="user-1", - ) - dataset.id = dataset_id - return dataset - - @staticmethod - def create_external_binding( - tenant_id: str = "tenant-1", - dataset_id: str = "ds-1", - api_id: str = "api-1", - external_knowledge_id: str = "knowledge-1", - ) -> ExternalKnowledgeBindings: - """ - Small helper for a binding between dataset and external knowledge API. - """ - - binding = ExternalKnowledgeBindings( - tenant_id=tenant_id, - dataset_id=dataset_id, - external_knowledge_api_id=api_id, - external_knowledge_id=external_knowledge_id, - created_by="user-1", - ) - return binding - - -# --------------------------------------------------------------------------- -# get_external_knowledge_apis -# --------------------------------------------------------------------------- - - -class TestExternalDatasetServiceGetExternalKnowledgeApis: - """ - Tests for ``ExternalDatasetService.get_external_knowledge_apis``. - - These tests focus on: - - - Basic pagination wiring via ``db.paginate``. - - Optional search keyword behaviour. - """ - - @pytest.fixture - def mock_db_paginate(self): - """ - Patch ``db.paginate`` so we do not touch the real database layer. - """ - - with ( - patch("services.external_knowledge_service.db.paginate", autospec=True) as mock_paginate, - patch("services.external_knowledge_service.select", autospec=True), - ): - yield mock_paginate - - def test_get_external_knowledge_apis_basic_pagination(self, mock_db_paginate: MagicMock): - """ - It should return ``items`` and ``total`` coming from the paginate object. - """ - - # Arrange - tenant_id = "tenant-1" - page = 1 - per_page = 20 - - mock_items = [Mock(spec=ExternalKnowledgeApis), Mock(spec=ExternalKnowledgeApis)] - mock_pagination = SimpleNamespace(items=mock_items, total=42) - mock_db_paginate.return_value = mock_pagination - - # Act - items, total = ExternalDatasetService.get_external_knowledge_apis(page, per_page, tenant_id) - - # Assert - assert items is mock_items - assert total == 42 - - mock_db_paginate.assert_called_once() - call_kwargs = mock_db_paginate.call_args.kwargs - assert call_kwargs["page"] == page - assert call_kwargs["per_page"] == per_page - assert call_kwargs["max_per_page"] == 100 - assert call_kwargs["error_out"] is False - - def test_get_external_knowledge_apis_with_search_keyword(self, mock_db_paginate: MagicMock): - """ - When a search keyword is provided, the query should be adjusted - (we simply assert that paginate is still called and does not explode). - """ - - # Arrange - tenant_id = "tenant-1" - page = 2 - per_page = 10 - search = "foo" - - mock_pagination = SimpleNamespace(items=[], total=0) - mock_db_paginate.return_value = mock_pagination - - # Act - items, total = ExternalDatasetService.get_external_knowledge_apis(page, per_page, tenant_id, search=search) - - # Assert - assert items == [] - assert total == 0 - mock_db_paginate.assert_called_once() - - -# --------------------------------------------------------------------------- -# validate_api_list -# --------------------------------------------------------------------------- - - -class TestExternalDatasetServiceValidateApiList: - """ - Lightweight validation tests for ``validate_api_list``. - """ - - def test_validate_api_list_success(self): - """ - A minimal valid configuration (endpoint + api_key) should pass. - """ - - config = {"endpoint": "https://example.com", "api_key": "secret"} - - # Act & Assert – no exception expected - ExternalDatasetService.validate_api_list(config) - - @pytest.mark.parametrize( - ("config", "expected_message"), - [ - ({}, "api list is empty"), - ({"api_key": "k"}, "endpoint is required"), - ({"endpoint": "https://example.com"}, "api_key is required"), - ], - ) - def test_validate_api_list_failures(self, config: dict[str, Any], expected_message: str): - """ - Invalid configs should raise ``ValueError`` with a clear message. - """ - - with pytest.raises(ValueError, match=expected_message): - ExternalDatasetService.validate_api_list(config) - - -# --------------------------------------------------------------------------- -# create_external_knowledge_api & get/update/delete -# --------------------------------------------------------------------------- - - -class TestExternalDatasetServiceCrudExternalKnowledgeApi: - """ - CRUD tests for external knowledge API templates. - """ - - @pytest.fixture - def mock_db_session(self): - """ - Patch ``db.session`` for all CRUD tests in this class. - """ - - with patch("services.external_knowledge_service.db.session", autospec=True) as mock_session: - yield mock_session - - def test_create_external_knowledge_api_success(self, mock_db_session: MagicMock): - """ - ``create_external_knowledge_api`` should persist a new record - when settings are present and valid. - """ - - tenant_id = "tenant-1" - user_id = "user-1" - args = { - "name": "API", - "description": "desc", - "settings": {"endpoint": "https://api.example.com", "api_key": "secret"}, - } - - # We do not want to actually call the remote endpoint here, so we patch the validator. - with patch.object(ExternalDatasetService, "check_endpoint_and_api_key", autospec=True) as mock_check: - result = ExternalDatasetService.create_external_knowledge_api(tenant_id, user_id, args) - - assert isinstance(result, ExternalKnowledgeApis) - mock_check.assert_called_once_with(args["settings"]) - mock_db_session.add.assert_called_once() - mock_db_session.commit.assert_called_once() - - def test_create_external_knowledge_api_missing_settings_raises(self, mock_db_session: MagicMock): - """ - Missing ``settings`` should result in a ``ValueError``. - """ - - tenant_id = "tenant-1" - user_id = "user-1" - args = {"name": "API", "description": "desc"} - - with pytest.raises(ValueError, match="settings is required"): - ExternalDatasetService.create_external_knowledge_api(tenant_id, user_id, args) - - mock_db_session.add.assert_not_called() - mock_db_session.commit.assert_not_called() - - def test_get_external_knowledge_api_found(self, mock_db_session: MagicMock): - """ - ``get_external_knowledge_api`` should return the first matching record. - """ - - api = Mock(spec=ExternalKnowledgeApis) - mock_db_session.scalar.return_value = api - - result = ExternalDatasetService.get_external_knowledge_api("api-id", "tenant-id") - assert result is api - - def test_get_external_knowledge_api_not_found_raises(self, mock_db_session: MagicMock): - """ - When the record is absent, a ``ValueError`` is raised. - """ - - mock_db_session.scalar.return_value = None - - with pytest.raises(ValueError, match="api template not found"): - ExternalDatasetService.get_external_knowledge_api("missing-id", "tenant-id") - - def test_update_external_knowledge_api_success_with_hidden_api_key(self, mock_db_session: MagicMock): - """ - Updating an API should keep the existing API key when the special hidden - value placeholder is sent from the UI. - """ - - tenant_id = "tenant-1" - user_id = "user-1" - api_id = "api-1" - - existing_api = Mock(spec=ExternalKnowledgeApis) - existing_api.settings_dict = {"api_key": "stored-key"} - existing_api.settings = '{"api_key":"stored-key"}' - mock_db_session.scalar.return_value = existing_api - - args = { - "name": "New Name", - "description": "New Desc", - "settings": {"endpoint": "https://api.example.com", "api_key": HIDDEN_VALUE}, - } - - result = ExternalDatasetService.update_external_knowledge_api(tenant_id, user_id, api_id, args) - - assert result is existing_api - # The placeholder should be replaced with stored key. - assert args["settings"]["api_key"] == "stored-key" - mock_db_session.commit.assert_called_once() - - def test_update_external_knowledge_api_not_found_raises(self, mock_db_session: MagicMock): - """ - Updating a non‑existent API template should raise ``ValueError``. - """ - - mock_db_session.scalar.return_value = None - - with pytest.raises(ValueError, match="api template not found"): - ExternalDatasetService.update_external_knowledge_api( - tenant_id="tenant-1", - user_id="user-1", - external_knowledge_api_id="missing-id", - args={"name": "n", "description": "d", "settings": {}}, - ) - - def test_delete_external_knowledge_api_success(self, mock_db_session: MagicMock): - """ - ``delete_external_knowledge_api`` should delete and commit when found. - """ - - api = Mock(spec=ExternalKnowledgeApis) - mock_db_session.scalar.return_value = api - - ExternalDatasetService.delete_external_knowledge_api("tenant-1", "api-1") - - mock_db_session.delete.assert_called_once_with(api) - mock_db_session.commit.assert_called_once() - - def test_delete_external_knowledge_api_not_found_raises(self, mock_db_session: MagicMock): - """ - Deletion of a missing template should raise ``ValueError``. - """ - - mock_db_session.scalar.return_value = None - - with pytest.raises(ValueError, match="api template not found"): - ExternalDatasetService.delete_external_knowledge_api("tenant-1", "missing") - - -# --------------------------------------------------------------------------- -# external_knowledge_api_use_check & binding lookups -# --------------------------------------------------------------------------- - - -class TestExternalDatasetServiceUsageAndBindings: - """ - Tests for usage checks and dataset binding retrieval. - """ - - @pytest.fixture - def mock_db_session(self): - with patch("services.external_knowledge_service.db.session", autospec=True) as mock_session: - yield mock_session - - def test_external_knowledge_api_use_check_in_use(self, mock_db_session: MagicMock): - """ - When there are bindings, ``external_knowledge_api_use_check`` returns True and count. - """ - - mock_db_session.scalar.return_value = 3 - - in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1", "tenant-1") - - assert in_use is True - assert count == 3 - assert "tenant_id" in str(mock_db_session.scalar.call_args.args[0]) - - def test_external_knowledge_api_use_check_not_in_use(self, mock_db_session: MagicMock): - """ - Zero bindings should return ``(False, 0)``. - """ - - mock_db_session.scalar.return_value = 0 - - in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1", "tenant-1") - - assert in_use is False - assert count == 0 - - def test_get_external_knowledge_binding_with_dataset_id_found(self, mock_db_session: MagicMock): - """ - Binding lookup should return the first record when present. - """ - - binding = Mock(spec=ExternalKnowledgeBindings) - mock_db_session.scalar.return_value = binding - - result = ExternalDatasetService.get_external_knowledge_binding_with_dataset_id("tenant-1", "ds-1") - assert result is binding - - def test_get_external_knowledge_binding_with_dataset_id_not_found_raises(self, mock_db_session: MagicMock): - """ - Missing binding should result in a ``ValueError``. - """ - - mock_db_session.scalar.return_value = None - - with pytest.raises(ValueError, match="external knowledge binding not found"): - ExternalDatasetService.get_external_knowledge_binding_with_dataset_id("tenant-1", "ds-1") - - -# --------------------------------------------------------------------------- -# document_create_args_validate -# --------------------------------------------------------------------------- - - -class TestExternalDatasetServiceDocumentCreateArgsValidate: - """ - Tests for ``document_create_args_validate``. - """ - - @pytest.fixture - def mock_db_session(self): - with patch("services.external_knowledge_service.db.session", autospec=True) as mock_session: - yield mock_session - - def test_document_create_args_validate_success(self, mock_db_session: MagicMock): - """ - All required custom parameters present – validation should pass. - """ - - external_api = Mock(spec=ExternalKnowledgeApis) - external_api.settings = json_settings = ( - '[{"document_process_setting":[{"name":"foo","required":true},{"name":"bar","required":false}]}]' - ) - # Raw string; the service itself calls json.loads on it - mock_db_session.scalar.return_value = external_api - - process_parameter = {"foo": "value", "bar": "optional"} - - # Act & Assert – no exception - ExternalDatasetService.document_create_args_validate("tenant-1", "api-1", process_parameter) - - assert json_settings in external_api.settings # simple sanity check on our test data - - def test_document_create_args_validate_missing_template_raises(self, mock_db_session: MagicMock): - """ - When the referenced API template is missing, a ``ValueError`` is raised. - """ - - mock_db_session.scalar.return_value = None - - with pytest.raises(ValueError, match="api template not found"): - ExternalDatasetService.document_create_args_validate("tenant-1", "missing", {}) - - def test_document_create_args_validate_missing_required_parameter_raises(self, mock_db_session: MagicMock): - """ - Required document process parameters must be supplied. - """ - - external_api = Mock(spec=ExternalKnowledgeApis) - external_api.settings = ( - '[{"document_process_setting":[{"name":"foo","required":true},{"name":"bar","required":false}]}]' - ) - mock_db_session.scalar.return_value = external_api - - process_parameter = {"bar": "present"} # missing "foo" - - with pytest.raises(ValueError, match="foo is required"): - ExternalDatasetService.document_create_args_validate("tenant-1", "api-1", process_parameter) - - -# --------------------------------------------------------------------------- -# process_external_api -# --------------------------------------------------------------------------- - - -class TestExternalDatasetServiceProcessExternalApi: - """ - Tests focused on the HTTP request assembly and method mapping behaviour. - """ - - def test_process_external_api_valid_method_post(self): - """ - For a supported HTTP verb we should delegate to the correct ``ssrf_proxy`` function. - """ - - settings = ExternalKnowledgeApiSetting( - url="https://example.com/path", - request_method="POST", - headers={"X-Test": "1"}, - params={"foo": "bar"}, - ) - - fake_response = httpx.Response(200) - - with patch("services.external_knowledge_service.ssrf_proxy.post", autospec=True) as mock_post: - mock_post.return_value = fake_response - - result = ExternalDatasetService.process_external_api(settings, files=None) - - assert result is fake_response - mock_post.assert_called_once() - kwargs = mock_post.call_args.kwargs - assert kwargs["url"] == settings.url - assert kwargs["headers"] == settings.headers - assert kwargs["follow_redirects"] is True - assert "data" in kwargs - - def test_process_external_api_invalid_method_raises(self): - """ - An unsupported HTTP verb should raise ``InvalidHttpMethodError``. - """ - - settings = ExternalKnowledgeApiSetting( - url="https://example.com", - request_method="INVALID", - headers=None, - params={}, - ) - - from graphon.nodes.http_request.exc import InvalidHttpMethodError - - with pytest.raises(InvalidHttpMethodError): - ExternalDatasetService.process_external_api(settings, files=None) - - -# --------------------------------------------------------------------------- -# assembling_headers -# --------------------------------------------------------------------------- - - -class TestExternalDatasetServiceAssemblingHeaders: - """ - Tests for header assembly based on different authentication flavours. - """ - - def test_assembling_headers_bearer_token(self): - """ - For bearer auth we expect ``Authorization: Bearer `` by default. - """ - - auth = Authorization( - type="api-key", - config=AuthorizationConfig(type="bearer", api_key="secret", header=None), - ) - - headers = ExternalDatasetService.assembling_headers(auth) - - assert headers["Authorization"] == "Bearer secret" - - def test_assembling_headers_basic_token_with_custom_header(self): - """ - For basic auth we honour the configured header name. - """ - - auth = Authorization( - type="api-key", - config=AuthorizationConfig(type="basic", api_key="abc123", header="X-Auth"), - ) - - headers = ExternalDatasetService.assembling_headers(auth, headers={"Existing": "1"}) - - assert headers["Existing"] == "1" - assert headers["X-Auth"] == "Basic abc123" - - def test_assembling_headers_custom_type(self): - """ - Custom auth type should inject the raw API key. - """ - - auth = Authorization( - type="api-key", - config=AuthorizationConfig(type="custom", api_key="raw-key", header="X-API-KEY"), - ) - - headers = ExternalDatasetService.assembling_headers(auth, headers=None) - - assert headers["X-API-KEY"] == "raw-key" - - def test_assembling_headers_missing_config_raises(self): - """ - Missing config object should be rejected. - """ - - auth = Authorization(type="api-key", config=None) - - with pytest.raises(ValueError, match="authorization config is required"): - ExternalDatasetService.assembling_headers(auth) - - def test_assembling_headers_missing_api_key_raises(self): - """ - ``api_key`` is required when type is ``api-key``. - """ - - auth = Authorization( - type="api-key", - config=AuthorizationConfig(type="bearer", api_key=None, header="Authorization"), - ) - - with pytest.raises(ValueError, match="api_key is required"): - ExternalDatasetService.assembling_headers(auth) - - def test_assembling_headers_no_auth_type_leaves_headers_unchanged(self): - """ - For ``no-auth`` we should not modify the headers mapping. - """ - - auth = Authorization(type="no-auth", config=None) - - base_headers = {"X": "1"} - result = ExternalDatasetService.assembling_headers(auth, headers=base_headers) - - # A copy is returned, original is not mutated. - assert result == base_headers - assert result is not base_headers - - -# --------------------------------------------------------------------------- -# get_external_knowledge_api_settings -# --------------------------------------------------------------------------- - - -class TestExternalDatasetServiceGetExternalKnowledgeApiSettings: - """ - Simple shape test for ``get_external_knowledge_api_settings``. - """ - - def test_get_external_knowledge_api_settings(self): - settings_dict: dict[str, Any] = { - "url": "https://example.com/retrieval", - "request_method": "post", - "headers": {"Content-Type": "application/json"}, - "params": {"foo": "bar"}, - } - - result = ExternalDatasetService.get_external_knowledge_api_settings(settings_dict) - - assert isinstance(result, ExternalKnowledgeApiSetting) - assert result.url == settings_dict["url"] - assert result.request_method == settings_dict["request_method"] - assert result.headers == settings_dict["headers"] - assert result.params == settings_dict["params"] - - -# --------------------------------------------------------------------------- -# create_external_dataset -# --------------------------------------------------------------------------- - - -class TestExternalDatasetServiceCreateExternalDataset: - """ - Tests around creating the external dataset and its binding row. - """ - - @pytest.fixture - def mock_db_session(self): - with patch("services.external_knowledge_service.db.session", autospec=True) as mock_session: - yield mock_session - - def test_create_external_dataset_success(self, mock_db_session: MagicMock): - """ - A brand new dataset name with valid external knowledge references - should create both the dataset and its binding. - """ - - tenant_id = "tenant-1" - user_id = "user-1" - - args = { - "name": "My Dataset", - "description": "desc", - "external_knowledge_api_id": "api-1", - "external_knowledge_id": "knowledge-1", - "external_retrieval_model": {"top_k": 3}, - } - - # No existing dataset with same name. - mock_db_session.scalar.side_effect = [ - None, # duplicate‑name check - Mock(spec=ExternalKnowledgeApis), # external knowledge api - ] - - dataset = ExternalDatasetService.create_external_dataset(tenant_id, user_id, args) - - assert isinstance(dataset, Dataset) - assert dataset.provider == "external" - assert dataset.retrieval_model == args["external_retrieval_model"] - - assert mock_db_session.add.call_count >= 2 # dataset + binding - mock_db_session.flush.assert_called_once() - mock_db_session.commit.assert_called_once() - - def test_create_external_dataset_duplicate_name_raises(self, mock_db_session: MagicMock): - """ - When a dataset with the same name already exists, - ``DatasetNameDuplicateError`` is raised. - """ - - existing_dataset = Mock(spec=Dataset) - mock_db_session.scalar.return_value = existing_dataset - - args = { - "name": "Existing", - "external_knowledge_api_id": "api-1", - "external_knowledge_id": "knowledge-1", - } - - with pytest.raises(DatasetNameDuplicateError): - ExternalDatasetService.create_external_dataset("tenant-1", "user-1", args) - - mock_db_session.add.assert_not_called() - mock_db_session.commit.assert_not_called() - - def test_create_external_dataset_missing_api_template_raises(self, mock_db_session: MagicMock): - """ - If the referenced external knowledge API does not exist, a ``ValueError`` is raised. - """ - - # First call: duplicate name check – not found. - mock_db_session.scalar.side_effect = [ - None, - None, # external knowledge api lookup - ] - - args = { - "name": "Dataset", - "external_knowledge_api_id": "missing", - "external_knowledge_id": "knowledge-1", - } - - with pytest.raises(ValueError, match="api template not found"): - ExternalDatasetService.create_external_dataset("tenant-1", "user-1", args) - - def test_create_external_dataset_missing_required_ids_raise(self, mock_db_session: MagicMock): - """ - ``external_knowledge_id`` and ``external_knowledge_api_id`` are mandatory. - """ - - # duplicate name check — two calls to create_external_dataset, each does 2 scalar calls - mock_db_session.scalar.side_effect = [ - None, - Mock(spec=ExternalKnowledgeApis), - None, - Mock(spec=ExternalKnowledgeApis), - ] - - args_missing_knowledge_id = { - "name": "Dataset", - "external_knowledge_api_id": "api-1", - "external_knowledge_id": None, - } - - with pytest.raises(ValueError, match="external_knowledge_id is required"): - ExternalDatasetService.create_external_dataset("tenant-1", "user-1", args_missing_knowledge_id) - - args_missing_api_id = { - "name": "Dataset", - "external_knowledge_api_id": None, - "external_knowledge_id": "k-1", - } - - with pytest.raises(ValueError, match="external_knowledge_api_id is required"): - ExternalDatasetService.create_external_dataset("tenant-1", "user-1", args_missing_api_id) - - -# --------------------------------------------------------------------------- -# fetch_external_knowledge_retrieval -# --------------------------------------------------------------------------- - - -class TestExternalDatasetServiceFetchExternalKnowledgeRetrieval: - """ - Tests for ``fetch_external_knowledge_retrieval`` which orchestrates - external retrieval requests and normalises the response payload. - """ - - @pytest.fixture - def mock_db_session(self): - with patch("services.external_knowledge_service.db.session", autospec=True) as mock_session: - yield mock_session - - def test_fetch_external_knowledge_retrieval_success(self, mock_db_session: MagicMock): - """ - With a valid binding and API template, records from the external - service should be returned when the HTTP response is 200. - """ - - tenant_id = "tenant-1" - dataset_id = "ds-1" - query = "test query" - external_retrieval_parameters = {"top_k": 3, "score_threshold_enabled": True, "score_threshold": 0.5} - - binding = ExternalDatasetTestDataFactory.create_external_binding( - tenant_id=tenant_id, - dataset_id=dataset_id, - api_id="api-1", - external_knowledge_id="knowledge-1", - ) - - api = Mock(spec=ExternalKnowledgeApis) - api.settings = '{"endpoint":"https://example.com","api_key":"secret"}' - - # First query: binding; second query: api. - mock_db_session.scalar.side_effect = [ - binding, - api, - ] - - fake_records = [{"content": "doc", "score": 0.9}] - fake_response = Mock(spec=httpx.Response) - fake_response.status_code = 200 - fake_response.json.return_value = {"records": fake_records} - - metadata_condition = SimpleNamespace(model_dump=lambda: {"field": "value"}) - - with patch.object( - ExternalDatasetService, "process_external_api", return_value=fake_response, autospec=True - ) as mock_process: - result = ExternalDatasetService.fetch_external_knowledge_retrieval( - tenant_id=tenant_id, - dataset_id=dataset_id, - query=query, - external_retrieval_parameters=external_retrieval_parameters, - metadata_condition=metadata_condition, - ) - - assert result == fake_records - - mock_process.assert_called_once() - setting_arg = mock_process.call_args.args[0] - assert isinstance(setting_arg, ExternalKnowledgeApiSetting) - assert setting_arg.url.endswith("/retrieval") - - def test_fetch_external_knowledge_retrieval_binding_not_found_raises(self, mock_db_session: MagicMock): - """ - Missing binding should raise ``ValueError``. - """ - - mock_db_session.scalar.return_value = None - - with pytest.raises(ValueError, match="external knowledge binding not found"): - ExternalDatasetService.fetch_external_knowledge_retrieval( - tenant_id="tenant-1", - dataset_id="missing", - query="q", - external_retrieval_parameters={}, - metadata_condition=None, - ) - - def test_fetch_external_knowledge_retrieval_missing_api_template_raises(self, mock_db_session: MagicMock): - """ - When the API template is missing or has no settings, a ``ValueError`` is raised. - """ - - binding = ExternalDatasetTestDataFactory.create_external_binding() - mock_db_session.scalar.side_effect = [ - binding, - None, - ] - - with pytest.raises(ValueError, match="external api template not found"): - ExternalDatasetService.fetch_external_knowledge_retrieval( - tenant_id="tenant-1", - dataset_id="ds-1", - query="q", - external_retrieval_parameters={}, - metadata_condition=None, - ) - - def test_fetch_external_knowledge_retrieval_non_200_status_returns_empty_list(self, mock_db_session: MagicMock): - """ - Non‑200 responses should be treated as an empty result set. - """ - - binding = ExternalDatasetTestDataFactory.create_external_binding() - api = Mock(spec=ExternalKnowledgeApis) - api.settings = '{"endpoint":"https://example.com","api_key":"secret"}' - - mock_db_session.scalar.side_effect = [ - binding, - api, - ] - - fake_response = Mock(spec=httpx.Response) - fake_response.status_code = 500 - fake_response.json.return_value = {} - - with patch.object(ExternalDatasetService, "process_external_api", return_value=fake_response, autospec=True): - result = ExternalDatasetService.fetch_external_knowledge_retrieval( - tenant_id="tenant-1", - dataset_id="ds-1", - query="q", - external_retrieval_parameters={}, - metadata_condition=None, - ) - - assert result == [] diff --git a/api/tests/unit_tests/services/rag_pipeline/test_rag_pipeline_service.py b/api/tests/unit_tests/services/rag_pipeline/test_rag_pipeline_service.py index 327281d07f..efb79aadde 100644 --- a/api/tests/unit_tests/services/rag_pipeline/test_rag_pipeline_service.py +++ b/api/tests/unit_tests/services/rag_pipeline/test_rag_pipeline_service.py @@ -374,24 +374,14 @@ def test_publish_workflow_success(mocker, rag_pipeline_service) -> None: mock_db = mocker.Mock() mocker.patch("services.rag_pipeline.rag_pipeline.db", mock_db) mock_dataset_service_class = mocker.patch("services.dataset_service.DatasetService") - mock_dataset_service = mock_dataset_service_class.return_value - # 6. Mock session and its scalar/query methods + # 6. Mock session and dataset lookup mock_session = mocker.Mock() mock_session.scalar.return_value = draft_wf - # Mock dataset update query (needed even if service is mocked, as rag_pipeline fetches it first) dataset = mocker.Mock() dataset.retrieval_model_dict = {} - dataset_query = mocker.Mock() - dataset_query.where.return_value.first.return_value = dataset - - # Mock node execution copy - node_exec_query = mocker.Mock() - node_exec_query.where.return_value.all.return_value = [] - - # Mocked session query side effects - mock_session.query.side_effect = [node_exec_query, dataset_query] + pipeline.retrieve_dataset.return_value = dataset # 7. Run test result = rag_pipeline_service.publish_workflow(session=mock_session, pipeline=pipeline, account=account) @@ -1524,7 +1514,6 @@ def test_handle_node_run_result_marks_document_error_for_published_invoke(mocker ) document = SimpleNamespace(indexing_status="waiting", error=None) - query = mocker.Mock() mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=document) add_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.add") commit_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit") @@ -1595,7 +1584,6 @@ def test_publish_customized_pipeline_template_raises_for_missing_workflow_id(moc def test_get_pipeline_raises_when_dataset_missing(mocker, rag_pipeline_service) -> None: - query = mocker.Mock() mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=None) with pytest.raises(ValueError, match="Dataset not found"): @@ -1604,7 +1592,6 @@ def test_get_pipeline_raises_when_dataset_missing(mocker, rag_pipeline_service) def test_get_pipeline_raises_when_pipeline_missing(mocker, rag_pipeline_service) -> None: dataset = SimpleNamespace(pipeline_id="p1") - query = mocker.Mock() mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, None]) with pytest.raises(ValueError, match="Pipeline not found"): @@ -1644,7 +1631,6 @@ def test_get_pipeline_templates_builtin_en_us_no_fallback(mocker) -> None: def test_update_customized_pipeline_template_commits_when_name_empty(mocker) -> None: template = SimpleNamespace(name="old", description="old", icon={}, updated_by=None) - query = mocker.Mock() mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=template) commit = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit") mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1")) @@ -1871,7 +1857,6 @@ def test_run_free_workflow_node_delegates_to_handle_result(mocker, rag_pipeline_ def test_publish_customized_pipeline_template_raises_when_workflow_missing(mocker, rag_pipeline_service) -> None: pipeline = SimpleNamespace(id="p1", tenant_id="t1", workflow_id="wf-1") - query = mocker.Mock() mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", side_effect=[pipeline, None]) with pytest.raises(ValueError, match="Workflow not found"): @@ -1910,7 +1895,6 @@ def test_get_recommended_plugins_skips_manifest_when_missing(mocker, rag_pipelin def test_retry_error_document_raises_when_pipeline_missing(mocker, rag_pipeline_service) -> None: exec_log = SimpleNamespace(pipeline_id="p1") - query = mocker.Mock() mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=exec_log) mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=None) @@ -1923,7 +1907,6 @@ def test_retry_error_document_raises_when_pipeline_missing(mocker, rag_pipeline_ def test_retry_error_document_raises_when_workflow_missing(mocker, rag_pipeline_service) -> None: exec_log = SimpleNamespace(pipeline_id="p1") pipeline = SimpleNamespace(id="p1") - query = mocker.Mock() mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=exec_log) mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=pipeline) mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=None) @@ -1940,7 +1923,6 @@ def test_get_datasource_plugins_returns_empty_for_non_datasource_nodes(mocker, r workflow = SimpleNamespace( graph_dict={"nodes": [{"id": "n1", "data": {"type": "start"}}]}, rag_pipeline_variables=[] ) - query = mocker.Mock() mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline]) mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=workflow) @@ -2103,7 +2085,6 @@ def test_get_datasource_plugins_handles_empty_datasource_data_and_non_published( graph_dict={"nodes": [{"id": "n1", "data": {"type": "datasource", "datasource_parameters": {}}}]}, rag_pipeline_variables=[{"variable": "v1", "belong_to_node_id": "shared"}], ) - query = mocker.Mock() mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline]) mocker.patch.object(rag_pipeline_service, "get_draft_workflow", return_value=workflow) mocker.patch( @@ -2143,7 +2124,6 @@ def test_get_datasource_plugins_extracts_user_inputs_and_credentials(mocker, rag {"variable": "v3", "belong_to_node_id": "shared"}, ], ) - query = mocker.Mock() mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline]) mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=workflow) mocker.patch( @@ -2161,7 +2141,6 @@ def test_get_datasource_plugins_extracts_user_inputs_and_credentials(mocker, rag def test_get_pipeline_returns_pipeline_when_found(mocker, rag_pipeline_service) -> None: dataset = SimpleNamespace(pipeline_id="p1") pipeline = SimpleNamespace(id="p1") - query = mocker.Mock() mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline]) result = rag_pipeline_service.get_pipeline("t1", "d1") diff --git a/api/tests/unit_tests/services/segment_service.py b/api/tests/unit_tests/services/segment_service.py deleted file mode 100644 index f0a66a00d4..0000000000 --- a/api/tests/unit_tests/services/segment_service.py +++ /dev/null @@ -1,1115 +0,0 @@ -from unittest.mock import MagicMock, Mock, patch - -import pytest - -from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType -from models.account import Account -from models.dataset import ChildChunk, Dataset, Document, DocumentSegment -from models.enums import SegmentType -from services.dataset_service import SegmentService -from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateArgs -from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError - - -class SegmentTestDataFactory: - """Factory class for creating test data and mock objects for segment service tests.""" - - @staticmethod - def create_segment_mock( - segment_id: str = "segment-123", - document_id: str = "doc-123", - dataset_id: str = "dataset-123", - tenant_id: str = "tenant-123", - content: str = "Test segment content", - position: int = 1, - enabled: bool = True, - status: str = "completed", - word_count: int = 3, - tokens: int = 5, - **kwargs, - ) -> Mock: - """Create a mock segment with specified attributes.""" - segment = Mock(spec=DocumentSegment) - segment.id = segment_id - segment.document_id = document_id - segment.dataset_id = dataset_id - segment.tenant_id = tenant_id - segment.content = content - segment.position = position - segment.enabled = enabled - segment.status = status - segment.word_count = word_count - segment.tokens = tokens - segment.index_node_id = f"node-{segment_id}" - segment.index_node_hash = "hash-123" - segment.keywords = [] - segment.answer = None - segment.disabled_at = None - segment.disabled_by = None - segment.updated_by = None - segment.updated_at = None - segment.indexing_at = None - segment.completed_at = None - segment.error = None - for key, value in kwargs.items(): - setattr(segment, key, value) - return segment - - @staticmethod - def create_child_chunk_mock( - chunk_id: str = "chunk-123", - segment_id: str = "segment-123", - document_id: str = "doc-123", - dataset_id: str = "dataset-123", - tenant_id: str = "tenant-123", - content: str = "Test child chunk content", - position: int = 1, - word_count: int = 3, - **kwargs, - ) -> Mock: - """Create a mock child chunk with specified attributes.""" - chunk = Mock(spec=ChildChunk) - chunk.id = chunk_id - chunk.segment_id = segment_id - chunk.document_id = document_id - chunk.dataset_id = dataset_id - chunk.tenant_id = tenant_id - chunk.content = content - chunk.position = position - chunk.word_count = word_count - chunk.index_node_id = f"node-{chunk_id}" - chunk.index_node_hash = "hash-123" - chunk.type = SegmentType.AUTOMATIC - chunk.created_by = "user-123" - chunk.updated_by = None - chunk.updated_at = None - for key, value in kwargs.items(): - setattr(chunk, key, value) - return chunk - - @staticmethod - def create_document_mock( - document_id: str = "doc-123", - dataset_id: str = "dataset-123", - tenant_id: str = "tenant-123", - doc_form: str = IndexStructureType.PARAGRAPH_INDEX, - word_count: int = 100, - **kwargs, - ) -> Mock: - """Create a mock document with specified attributes.""" - document = Mock(spec=Document) - document.id = document_id - document.dataset_id = dataset_id - document.tenant_id = tenant_id - document.doc_form = doc_form - document.word_count = word_count - for key, value in kwargs.items(): - setattr(document, key, value) - return document - - @staticmethod - def create_dataset_mock( - dataset_id: str = "dataset-123", - tenant_id: str = "tenant-123", - indexing_technique: str = IndexTechniqueType.HIGH_QUALITY, - embedding_model: str = "text-embedding-ada-002", - embedding_model_provider: str = "openai", - **kwargs, - ) -> Mock: - """Create a mock dataset with specified attributes.""" - dataset = Mock(spec=Dataset) - dataset.id = dataset_id - dataset.tenant_id = tenant_id - dataset.indexing_technique = indexing_technique - dataset.embedding_model = embedding_model - dataset.embedding_model_provider = embedding_model_provider - for key, value in kwargs.items(): - setattr(dataset, key, value) - return dataset - - @staticmethod - def create_user_mock( - user_id: str = "user-789", - tenant_id: str = "tenant-123", - **kwargs, - ) -> Mock: - """Create a mock user with specified attributes.""" - user = Mock(spec=Account) - user.id = user_id - user.current_tenant_id = tenant_id - user.name = "Test User" - for key, value in kwargs.items(): - setattr(user, key, value) - return user - - -class TestSegmentServiceCreateSegment: - """Tests for SegmentService.create_segment method.""" - - @pytest.fixture - def mock_db_session(self): - """Mock database session.""" - with patch("services.dataset_service.db.session", autospec=True) as mock_db: - yield mock_db - - @pytest.fixture - def mock_current_user(self): - """Mock current_user.""" - user = SegmentTestDataFactory.create_user_mock() - with patch("services.dataset_service.current_user", user): - yield user - - def test_create_segment_success(self, mock_db_session, mock_current_user): - """Test successful creation of a segment.""" - # Arrange - document = SegmentTestDataFactory.create_document_mock(word_count=100) - dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) - args = {"content": "New segment content", "keywords": ["test", "segment"]} - - mock_query = MagicMock() - mock_query.where.return_value.scalar.return_value = None # No existing segments - mock_db_session.query.return_value = mock_query - - mock_segment = SegmentTestDataFactory.create_segment_mock() - mock_db_session.query.return_value.where.return_value.first.return_value = mock_segment - - with ( - patch("services.dataset_service.redis_client.lock", autospec=True) as mock_lock, - patch( - "services.dataset_service.VectorService.create_segments_vector", autospec=True - ) as mock_vector_service, - patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash, - patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now, - ): - mock_lock.return_value.__enter__ = Mock() - mock_lock.return_value.__exit__ = Mock(return_value=None) - mock_hash.return_value = "hash-123" - mock_now.return_value = "2024-01-01T00:00:00" - - # Act - result = SegmentService.create_segment(args, document, dataset) - - # Assert - assert mock_db_session.add.call_count == 2 - - created_segment = mock_db_session.add.call_args_list[0].args[0] - assert isinstance(created_segment, DocumentSegment) - assert created_segment.content == args["content"] - assert created_segment.word_count == len(args["content"]) - - mock_db_session.commit.assert_called_once() - - mock_vector_service.assert_called_once() - vector_call_args = mock_vector_service.call_args[0] - assert vector_call_args[0] == [args["keywords"]] - assert vector_call_args[1][0] == created_segment - assert vector_call_args[2] == dataset - assert vector_call_args[3] == document.doc_form - - assert result == mock_segment - - def test_create_segment_with_qa_model(self, mock_db_session, mock_current_user): - """Test creation of segment with QA model (requires answer).""" - # Arrange - document = SegmentTestDataFactory.create_document_mock(doc_form=IndexStructureType.QA_INDEX, word_count=100) - dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) - args = {"content": "What is AI?", "answer": "AI is Artificial Intelligence", "keywords": ["ai"]} - - mock_query = MagicMock() - mock_query.where.return_value.scalar.return_value = None - mock_db_session.query.return_value = mock_query - - mock_segment = SegmentTestDataFactory.create_segment_mock() - mock_db_session.query.return_value.where.return_value.first.return_value = mock_segment - - with ( - patch("services.dataset_service.redis_client.lock", autospec=True) as mock_lock, - patch( - "services.dataset_service.VectorService.create_segments_vector", autospec=True - ) as mock_vector_service, - patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash, - patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now, - ): - mock_lock.return_value.__enter__ = Mock() - mock_lock.return_value.__exit__ = Mock(return_value=None) - mock_hash.return_value = "hash-123" - mock_now.return_value = "2024-01-01T00:00:00" - - # Act - result = SegmentService.create_segment(args, document, dataset) - - # Assert - assert result == mock_segment - mock_db_session.add.assert_called() - mock_db_session.commit.assert_called() - - def test_create_segment_with_high_quality_indexing(self, mock_db_session, mock_current_user): - """Test creation of segment with high quality indexing technique.""" - # Arrange - document = SegmentTestDataFactory.create_document_mock(word_count=100) - dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY) - args = {"content": "New segment content", "keywords": ["test"]} - - mock_query = MagicMock() - mock_query.where.return_value.scalar.return_value = None - mock_db_session.query.return_value = mock_query - - mock_embedding_model = MagicMock() - mock_embedding_model.get_text_embedding_num_tokens.return_value = [10] - mock_model_manager = MagicMock() - mock_model_manager.get_model_instance.return_value = mock_embedding_model - - mock_segment = SegmentTestDataFactory.create_segment_mock() - mock_db_session.query.return_value.where.return_value.first.return_value = mock_segment - - with ( - patch("services.dataset_service.redis_client.lock", autospec=True) as mock_lock, - patch( - "services.dataset_service.VectorService.create_segments_vector", autospec=True - ) as mock_vector_service, - patch("services.dataset_service.ModelManager.for_tenant", autospec=True) as mock_model_manager_class, - patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash, - patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now, - ): - mock_lock.return_value.__enter__ = Mock() - mock_lock.return_value.__exit__ = Mock(return_value=None) - mock_model_manager_class.return_value = mock_model_manager - mock_hash.return_value = "hash-123" - mock_now.return_value = "2024-01-01T00:00:00" - - # Act - result = SegmentService.create_segment(args, document, dataset) - - # Assert - assert result == mock_segment - mock_model_manager.get_model_instance.assert_called_once() - mock_embedding_model.get_text_embedding_num_tokens.assert_called_once() - - def test_create_segment_vector_index_failure(self, mock_db_session, mock_current_user): - """Test segment creation when vector indexing fails.""" - # Arrange - document = SegmentTestDataFactory.create_document_mock(word_count=100) - dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) - args = {"content": "New segment content", "keywords": ["test"]} - - mock_query = MagicMock() - mock_query.where.return_value.scalar.return_value = None - mock_db_session.query.return_value = mock_query - - mock_segment = SegmentTestDataFactory.create_segment_mock(enabled=False, status="error") - mock_db_session.query.return_value.where.return_value.first.return_value = mock_segment - - with ( - patch("services.dataset_service.redis_client.lock", autospec=True) as mock_lock, - patch( - "services.dataset_service.VectorService.create_segments_vector", autospec=True - ) as mock_vector_service, - patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash, - patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now, - ): - mock_lock.return_value.__enter__ = Mock() - mock_lock.return_value.__exit__ = Mock(return_value=None) - mock_vector_service.side_effect = Exception("Vector indexing failed") - mock_hash.return_value = "hash-123" - mock_now.return_value = "2024-01-01T00:00:00" - - # Act - result = SegmentService.create_segment(args, document, dataset) - - # Assert - assert result == mock_segment - assert mock_db_session.commit.call_count == 2 # Once for creation, once for error update - - -class TestSegmentServiceUpdateSegment: - """Tests for SegmentService.update_segment method.""" - - @pytest.fixture - def mock_db_session(self): - """Mock database session.""" - with patch("services.dataset_service.db.session", autospec=True) as mock_db: - yield mock_db - - @pytest.fixture - def mock_current_user(self): - """Mock current_user.""" - user = SegmentTestDataFactory.create_user_mock() - with patch("services.dataset_service.current_user", user): - yield user - - def test_update_segment_content_success(self, mock_db_session, mock_current_user): - """Test successful update of segment content.""" - # Arrange - segment = SegmentTestDataFactory.create_segment_mock(enabled=True, word_count=10) - document = SegmentTestDataFactory.create_document_mock(word_count=100) - dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) - args = SegmentUpdateArgs(content="Updated content", keywords=["updated"]) - - mock_db_session.query.return_value.where.return_value.first.return_value = segment - - with ( - patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get, - patch("services.dataset_service.VectorService.update_segment_vector", autospec=True) as mock_vector_service, - patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash, - patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now, - ): - mock_redis_get.return_value = None # Not indexing - mock_hash.return_value = "new-hash" - mock_now.return_value = "2024-01-01T00:00:00" - - # Act - result = SegmentService.update_segment(args, segment, document, dataset) - - # Assert - assert result == segment - assert segment.content == "Updated content" - assert segment.keywords == ["updated"] - assert segment.word_count == len("Updated content") - assert document.word_count == 100 + (len("Updated content") - 10) - mock_db_session.add.assert_called() - mock_db_session.commit.assert_called() - - def test_update_segment_disable(self, mock_db_session, mock_current_user): - """Test disabling a segment.""" - # Arrange - segment = SegmentTestDataFactory.create_segment_mock(enabled=True) - document = SegmentTestDataFactory.create_document_mock() - dataset = SegmentTestDataFactory.create_dataset_mock() - args = SegmentUpdateArgs(enabled=False) - - with ( - patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get, - patch("services.dataset_service.redis_client.setex", autospec=True) as mock_redis_setex, - patch("services.dataset_service.disable_segment_from_index_task", autospec=True) as mock_task, - patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now, - ): - mock_redis_get.return_value = None - mock_now.return_value = "2024-01-01T00:00:00" - - # Act - result = SegmentService.update_segment(args, segment, document, dataset) - - # Assert - assert result == segment - assert segment.enabled is False - mock_db_session.add.assert_called() - mock_db_session.commit.assert_called() - mock_task.delay.assert_called_once() - - def test_update_segment_indexing_in_progress(self, mock_db_session, mock_current_user): - """Test update fails when segment is currently indexing.""" - # Arrange - segment = SegmentTestDataFactory.create_segment_mock(enabled=True) - document = SegmentTestDataFactory.create_document_mock() - dataset = SegmentTestDataFactory.create_dataset_mock() - args = SegmentUpdateArgs(content="Updated content") - - with patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get: - mock_redis_get.return_value = "1" # Indexing in progress - - # Act & Assert - with pytest.raises(ValueError, match="Segment is indexing"): - SegmentService.update_segment(args, segment, document, dataset) - - def test_update_segment_disabled_segment(self, mock_db_session, mock_current_user): - """Test update fails when segment is disabled.""" - # Arrange - segment = SegmentTestDataFactory.create_segment_mock(enabled=False) - document = SegmentTestDataFactory.create_document_mock() - dataset = SegmentTestDataFactory.create_dataset_mock() - args = SegmentUpdateArgs(content="Updated content") - - with patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get: - mock_redis_get.return_value = None - - # Act & Assert - with pytest.raises(ValueError, match="Can't update disabled segment"): - SegmentService.update_segment(args, segment, document, dataset) - - def test_update_segment_with_qa_model(self, mock_db_session, mock_current_user): - """Test update segment with QA model (includes answer).""" - # Arrange - segment = SegmentTestDataFactory.create_segment_mock(enabled=True, word_count=10) - document = SegmentTestDataFactory.create_document_mock(doc_form=IndexStructureType.QA_INDEX, word_count=100) - dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) - args = SegmentUpdateArgs(content="Updated question", answer="Updated answer", keywords=["qa"]) - - mock_db_session.query.return_value.where.return_value.first.return_value = segment - - with ( - patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get, - patch("services.dataset_service.VectorService.update_segment_vector", autospec=True) as mock_vector_service, - patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash, - patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now, - ): - mock_redis_get.return_value = None - mock_hash.return_value = "new-hash" - mock_now.return_value = "2024-01-01T00:00:00" - - # Act - result = SegmentService.update_segment(args, segment, document, dataset) - - # Assert - assert result == segment - assert segment.content == "Updated question" - assert segment.answer == "Updated answer" - assert segment.keywords == ["qa"] - new_word_count = len("Updated question") + len("Updated answer") - assert segment.word_count == new_word_count - assert document.word_count == 100 + (new_word_count - 10) - mock_db_session.commit.assert_called() - - -class TestSegmentServiceDeleteSegment: - """Tests for SegmentService.delete_segment method.""" - - @pytest.fixture - def mock_db_session(self): - """Mock database session.""" - with patch("services.dataset_service.db.session", autospec=True) as mock_db: - yield mock_db - - def test_delete_segment_success(self, mock_db_session): - """Test successful deletion of a segment.""" - # Arrange - segment = SegmentTestDataFactory.create_segment_mock(enabled=True, word_count=50) - document = SegmentTestDataFactory.create_document_mock(word_count=100) - dataset = SegmentTestDataFactory.create_dataset_mock() - - mock_scalars = MagicMock() - mock_scalars.all.return_value = [] - mock_db_session.scalars.return_value = mock_scalars - - with ( - patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get, - patch("services.dataset_service.redis_client.setex", autospec=True) as mock_redis_setex, - patch("services.dataset_service.delete_segment_from_index_task", autospec=True) as mock_task, - patch("services.dataset_service.select", autospec=True) as mock_select, - ): - mock_redis_get.return_value = None - mock_select.return_value.where.return_value = mock_select - - # Act - SegmentService.delete_segment(segment, document, dataset) - - # Assert - mock_db_session.delete.assert_called_once_with(segment) - mock_db_session.commit.assert_called_once() - mock_task.delay.assert_called_once() - - def test_delete_segment_disabled(self, mock_db_session): - """Test deletion of disabled segment (no index deletion).""" - # Arrange - segment = SegmentTestDataFactory.create_segment_mock(enabled=False, word_count=50) - document = SegmentTestDataFactory.create_document_mock(word_count=100) - dataset = SegmentTestDataFactory.create_dataset_mock() - - with ( - patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get, - patch("services.dataset_service.delete_segment_from_index_task", autospec=True) as mock_task, - ): - mock_redis_get.return_value = None - - # Act - SegmentService.delete_segment(segment, document, dataset) - - # Assert - mock_db_session.delete.assert_called_once_with(segment) - mock_db_session.commit.assert_called_once() - mock_task.delay.assert_not_called() - - def test_delete_segment_indexing_in_progress(self, mock_db_session): - """Test deletion fails when segment is currently being deleted.""" - # Arrange - segment = SegmentTestDataFactory.create_segment_mock(enabled=True) - document = SegmentTestDataFactory.create_document_mock() - dataset = SegmentTestDataFactory.create_dataset_mock() - - with patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get: - mock_redis_get.return_value = "1" # Deletion in progress - - # Act & Assert - with pytest.raises(ValueError, match="Segment is deleting"): - SegmentService.delete_segment(segment, document, dataset) - - -class TestSegmentServiceDeleteSegments: - """Tests for SegmentService.delete_segments method.""" - - @pytest.fixture - def mock_db_session(self): - """Mock database session.""" - with patch("services.dataset_service.db.session", autospec=True) as mock_db: - yield mock_db - - @pytest.fixture - def mock_current_user(self): - """Mock current_user.""" - user = SegmentTestDataFactory.create_user_mock() - with patch("services.dataset_service.current_user", user): - yield user - - def test_delete_segments_success(self, mock_db_session, mock_current_user): - """Test successful deletion of multiple segments.""" - # Arrange - segment_ids = ["segment-1", "segment-2"] - document = SegmentTestDataFactory.create_document_mock(word_count=200) - dataset = SegmentTestDataFactory.create_dataset_mock() - - segments_info = [ - ("node-1", "segment-1", 50), - ("node-2", "segment-2", 30), - ] - - mock_query = MagicMock() - mock_query.with_entities.return_value.where.return_value.all.return_value = segments_info - mock_db_session.query.return_value = mock_query - - mock_scalars = MagicMock() - mock_scalars.all.return_value = [] - mock_select = MagicMock() - mock_select.where.return_value = mock_select - mock_db_session.scalars.return_value = mock_scalars - - with ( - patch("services.dataset_service.delete_segment_from_index_task", autospec=True) as mock_task, - patch("services.dataset_service.select", autospec=True) as mock_select_func, - ): - mock_select_func.return_value = mock_select - - # Act - SegmentService.delete_segments(segment_ids, document, dataset) - - # Assert - mock_db_session.query.return_value.where.return_value.delete.assert_called_once() - mock_db_session.commit.assert_called_once() - mock_task.delay.assert_called_once() - - def test_delete_segments_empty_list(self, mock_db_session, mock_current_user): - """Test deletion with empty list (should return early).""" - # Arrange - document = SegmentTestDataFactory.create_document_mock() - dataset = SegmentTestDataFactory.create_dataset_mock() - - # Act - SegmentService.delete_segments([], document, dataset) - - # Assert - mock_db_session.query.assert_not_called() - - -class TestSegmentServiceUpdateSegmentsStatus: - """Tests for SegmentService.update_segments_status method.""" - - @pytest.fixture - def mock_db_session(self): - """Mock database session.""" - with patch("services.dataset_service.db.session", autospec=True) as mock_db: - yield mock_db - - @pytest.fixture - def mock_current_user(self): - """Mock current_user.""" - user = SegmentTestDataFactory.create_user_mock() - with patch("services.dataset_service.current_user", user): - yield user - - def test_update_segments_status_enable(self, mock_db_session, mock_current_user): - """Test enabling multiple segments.""" - # Arrange - segment_ids = ["segment-1", "segment-2"] - document = SegmentTestDataFactory.create_document_mock() - dataset = SegmentTestDataFactory.create_dataset_mock() - - segments = [ - SegmentTestDataFactory.create_segment_mock(segment_id="segment-1", enabled=False), - SegmentTestDataFactory.create_segment_mock(segment_id="segment-2", enabled=False), - ] - - mock_scalars = MagicMock() - mock_scalars.all.return_value = segments - mock_select = MagicMock() - mock_select.where.return_value = mock_select - mock_db_session.scalars.return_value = mock_scalars - - with ( - patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get, - patch("services.dataset_service.enable_segments_to_index_task", autospec=True) as mock_task, - patch("services.dataset_service.select", autospec=True) as mock_select_func, - ): - mock_redis_get.return_value = None - mock_select_func.return_value = mock_select - - # Act - SegmentService.update_segments_status(segment_ids, "enable", dataset, document) - - # Assert - assert all(seg.enabled is True for seg in segments) - mock_db_session.commit.assert_called_once() - mock_task.delay.assert_called_once() - - def test_update_segments_status_disable(self, mock_db_session, mock_current_user): - """Test disabling multiple segments.""" - # Arrange - segment_ids = ["segment-1", "segment-2"] - document = SegmentTestDataFactory.create_document_mock() - dataset = SegmentTestDataFactory.create_dataset_mock() - - segments = [ - SegmentTestDataFactory.create_segment_mock(segment_id="segment-1", enabled=True), - SegmentTestDataFactory.create_segment_mock(segment_id="segment-2", enabled=True), - ] - - mock_scalars = MagicMock() - mock_scalars.all.return_value = segments - mock_select = MagicMock() - mock_select.where.return_value = mock_select - mock_db_session.scalars.return_value = mock_scalars - - with ( - patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get, - patch("services.dataset_service.disable_segments_from_index_task", autospec=True) as mock_task, - patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now, - patch("services.dataset_service.select", autospec=True) as mock_select_func, - ): - mock_redis_get.return_value = None - mock_now.return_value = "2024-01-01T00:00:00" - mock_select_func.return_value = mock_select - - # Act - SegmentService.update_segments_status(segment_ids, "disable", dataset, document) - - # Assert - assert all(seg.enabled is False for seg in segments) - mock_db_session.commit.assert_called_once() - mock_task.delay.assert_called_once() - - def test_update_segments_status_empty_list(self, mock_db_session, mock_current_user): - """Test update with empty list (should return early).""" - # Arrange - document = SegmentTestDataFactory.create_document_mock() - dataset = SegmentTestDataFactory.create_dataset_mock() - - # Act - SegmentService.update_segments_status([], "enable", dataset, document) - - # Assert - mock_db_session.scalars.assert_not_called() - - -class TestSegmentServiceGetSegments: - """Tests for SegmentService.get_segments method.""" - - @pytest.fixture - def mock_db_session(self): - """Mock database session.""" - with patch("services.dataset_service.db.session", autospec=True) as mock_db: - yield mock_db - - @pytest.fixture - def mock_current_user(self): - """Mock current_user.""" - user = SegmentTestDataFactory.create_user_mock() - with patch("services.dataset_service.current_user", user): - yield user - - def test_get_segments_success(self, mock_db_session, mock_current_user): - """Test successful retrieval of segments.""" - # Arrange - document_id = "doc-123" - tenant_id = "tenant-123" - segments = [ - SegmentTestDataFactory.create_segment_mock(segment_id="segment-1"), - SegmentTestDataFactory.create_segment_mock(segment_id="segment-2"), - ] - - mock_paginate = MagicMock() - mock_paginate.items = segments - mock_paginate.total = 2 - mock_db_session.paginate.return_value = mock_paginate - - # Act - items, total = SegmentService.get_segments(document_id, tenant_id) - - # Assert - assert len(items) == 2 - assert total == 2 - mock_db_session.paginate.assert_called_once() - - def test_get_segments_with_status_filter(self, mock_db_session, mock_current_user): - """Test retrieval with status filter.""" - # Arrange - document_id = "doc-123" - tenant_id = "tenant-123" - status_list = ["completed", "error"] - - mock_paginate = MagicMock() - mock_paginate.items = [] - mock_paginate.total = 0 - mock_db_session.paginate.return_value = mock_paginate - - # Act - items, total = SegmentService.get_segments(document_id, tenant_id, status_list=status_list) - - # Assert - assert len(items) == 0 - assert total == 0 - - def test_get_segments_with_keyword(self, mock_db_session, mock_current_user): - """Test retrieval with keyword search.""" - # Arrange - document_id = "doc-123" - tenant_id = "tenant-123" - keyword = "test" - - mock_paginate = MagicMock() - mock_paginate.items = [SegmentTestDataFactory.create_segment_mock()] - mock_paginate.total = 1 - mock_db_session.paginate.return_value = mock_paginate - - # Act - items, total = SegmentService.get_segments(document_id, tenant_id, keyword=keyword) - - # Assert - assert len(items) == 1 - assert total == 1 - - -class TestSegmentServiceGetSegmentById: - """Tests for SegmentService.get_segment_by_id method.""" - - @pytest.fixture - def mock_db_session(self): - """Mock database session.""" - with patch("services.dataset_service.db.session", autospec=True) as mock_db: - yield mock_db - - def test_get_segment_by_id_success(self, mock_db_session): - """Test successful retrieval of segment by ID.""" - # Arrange - segment_id = "segment-123" - tenant_id = "tenant-123" - segment = SegmentTestDataFactory.create_segment_mock(segment_id=segment_id) - - mock_query = MagicMock() - mock_query.where.return_value.first.return_value = segment - mock_db_session.query.return_value = mock_query - - # Act - result = SegmentService.get_segment_by_id(segment_id, tenant_id) - - # Assert - assert result == segment - - def test_get_segment_by_id_not_found(self, mock_db_session): - """Test retrieval when segment is not found.""" - # Arrange - segment_id = "non-existent" - tenant_id = "tenant-123" - - mock_query = MagicMock() - mock_query.where.return_value.first.return_value = None - mock_db_session.query.return_value = mock_query - - # Act - result = SegmentService.get_segment_by_id(segment_id, tenant_id) - - # Assert - assert result is None - - -class TestSegmentServiceGetChildChunks: - """Tests for SegmentService.get_child_chunks method.""" - - @pytest.fixture - def mock_db_session(self): - """Mock database session.""" - with patch("services.dataset_service.db.session", autospec=True) as mock_db: - yield mock_db - - @pytest.fixture - def mock_current_user(self): - """Mock current_user.""" - user = SegmentTestDataFactory.create_user_mock() - with patch("services.dataset_service.current_user", user): - yield user - - def test_get_child_chunks_success(self, mock_db_session, mock_current_user): - """Test successful retrieval of child chunks.""" - # Arrange - segment_id = "segment-123" - document_id = "doc-123" - dataset_id = "dataset-123" - page = 1 - limit = 20 - - mock_paginate = MagicMock() - mock_paginate.items = [ - SegmentTestDataFactory.create_child_chunk_mock(chunk_id="chunk-1"), - SegmentTestDataFactory.create_child_chunk_mock(chunk_id="chunk-2"), - ] - mock_paginate.total = 2 - mock_db_session.paginate.return_value = mock_paginate - - # Act - result = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit) - - # Assert - assert result == mock_paginate - mock_db_session.paginate.assert_called_once() - - def test_get_child_chunks_with_keyword(self, mock_db_session, mock_current_user): - """Test retrieval with keyword search.""" - # Arrange - segment_id = "segment-123" - document_id = "doc-123" - dataset_id = "dataset-123" - page = 1 - limit = 20 - keyword = "test" - - mock_paginate = MagicMock() - mock_paginate.items = [] - mock_paginate.total = 0 - mock_db_session.paginate.return_value = mock_paginate - - # Act - result = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword=keyword) - - # Assert - assert result == mock_paginate - - -class TestSegmentServiceGetChildChunkById: - """Tests for SegmentService.get_child_chunk_by_id method.""" - - @pytest.fixture - def mock_db_session(self): - """Mock database session.""" - with patch("services.dataset_service.db.session", autospec=True) as mock_db: - yield mock_db - - def test_get_child_chunk_by_id_success(self, mock_db_session): - """Test successful retrieval of child chunk by ID.""" - # Arrange - chunk_id = "chunk-123" - tenant_id = "tenant-123" - chunk = SegmentTestDataFactory.create_child_chunk_mock(chunk_id=chunk_id) - - mock_query = MagicMock() - mock_query.where.return_value.first.return_value = chunk - mock_db_session.query.return_value = mock_query - - # Act - result = SegmentService.get_child_chunk_by_id(chunk_id, tenant_id) - - # Assert - assert result == chunk - - def test_get_child_chunk_by_id_not_found(self, mock_db_session): - """Test retrieval when child chunk is not found.""" - # Arrange - chunk_id = "non-existent" - tenant_id = "tenant-123" - - mock_query = MagicMock() - mock_query.where.return_value.first.return_value = None - mock_db_session.query.return_value = mock_query - - # Act - result = SegmentService.get_child_chunk_by_id(chunk_id, tenant_id) - - # Assert - assert result is None - - -class TestSegmentServiceCreateChildChunk: - """Tests for SegmentService.create_child_chunk method.""" - - @pytest.fixture - def mock_db_session(self): - """Mock database session.""" - with patch("services.dataset_service.db.session", autospec=True) as mock_db: - yield mock_db - - @pytest.fixture - def mock_current_user(self): - """Mock current_user.""" - user = SegmentTestDataFactory.create_user_mock() - with patch("services.dataset_service.current_user", user): - yield user - - def test_create_child_chunk_success(self, mock_db_session, mock_current_user): - """Test successful creation of a child chunk.""" - # Arrange - content = "New child chunk content" - segment = SegmentTestDataFactory.create_segment_mock() - document = SegmentTestDataFactory.create_document_mock() - dataset = SegmentTestDataFactory.create_dataset_mock() - - mock_query = MagicMock() - mock_query.where.return_value.scalar.return_value = None - mock_db_session.query.return_value = mock_query - - with ( - patch("services.dataset_service.redis_client.lock", autospec=True) as mock_lock, - patch( - "services.dataset_service.VectorService.create_child_chunk_vector", autospec=True - ) as mock_vector_service, - patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash, - ): - mock_lock.return_value.__enter__ = Mock() - mock_lock.return_value.__exit__ = Mock(return_value=None) - mock_hash.return_value = "hash-123" - - # Act - result = SegmentService.create_child_chunk(content, segment, document, dataset) - - # Assert - assert result is not None - mock_db_session.add.assert_called_once() - mock_db_session.commit.assert_called_once() - mock_vector_service.assert_called_once() - - def test_create_child_chunk_vector_index_failure(self, mock_db_session, mock_current_user): - """Test child chunk creation when vector indexing fails.""" - # Arrange - content = "New child chunk content" - segment = SegmentTestDataFactory.create_segment_mock() - document = SegmentTestDataFactory.create_document_mock() - dataset = SegmentTestDataFactory.create_dataset_mock() - - mock_query = MagicMock() - mock_query.where.return_value.scalar.return_value = None - mock_db_session.query.return_value = mock_query - - with ( - patch("services.dataset_service.redis_client.lock", autospec=True) as mock_lock, - patch( - "services.dataset_service.VectorService.create_child_chunk_vector", autospec=True - ) as mock_vector_service, - patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash, - ): - mock_lock.return_value.__enter__ = Mock() - mock_lock.return_value.__exit__ = Mock(return_value=None) - mock_vector_service.side_effect = Exception("Vector indexing failed") - mock_hash.return_value = "hash-123" - - # Act & Assert - with pytest.raises(ChildChunkIndexingError): - SegmentService.create_child_chunk(content, segment, document, dataset) - - mock_db_session.rollback.assert_called_once() - - -class TestSegmentServiceUpdateChildChunk: - """Tests for SegmentService.update_child_chunk method.""" - - @pytest.fixture - def mock_db_session(self): - """Mock database session.""" - with patch("services.dataset_service.db.session", autospec=True) as mock_db: - yield mock_db - - @pytest.fixture - def mock_current_user(self): - """Mock current_user.""" - user = SegmentTestDataFactory.create_user_mock() - with patch("services.dataset_service.current_user", user): - yield user - - def test_update_child_chunk_success(self, mock_db_session, mock_current_user): - """Test successful update of a child chunk.""" - # Arrange - content = "Updated child chunk content" - chunk = SegmentTestDataFactory.create_child_chunk_mock() - segment = SegmentTestDataFactory.create_segment_mock() - document = SegmentTestDataFactory.create_document_mock() - dataset = SegmentTestDataFactory.create_dataset_mock() - - with ( - patch( - "services.dataset_service.VectorService.update_child_chunk_vector", autospec=True - ) as mock_vector_service, - patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now, - ): - mock_now.return_value = "2024-01-01T00:00:00" - - # Act - result = SegmentService.update_child_chunk(content, chunk, segment, document, dataset) - - # Assert - assert result == chunk - assert chunk.content == content - assert chunk.word_count == len(content) - mock_db_session.add.assert_called_once_with(chunk) - mock_db_session.commit.assert_called_once() - mock_vector_service.assert_called_once() - - def test_update_child_chunk_vector_index_failure(self, mock_db_session, mock_current_user): - """Test child chunk update when vector indexing fails.""" - # Arrange - content = "Updated content" - chunk = SegmentTestDataFactory.create_child_chunk_mock() - segment = SegmentTestDataFactory.create_segment_mock() - document = SegmentTestDataFactory.create_document_mock() - dataset = SegmentTestDataFactory.create_dataset_mock() - - with ( - patch( - "services.dataset_service.VectorService.update_child_chunk_vector", autospec=True - ) as mock_vector_service, - patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now, - ): - mock_vector_service.side_effect = Exception("Vector indexing failed") - mock_now.return_value = "2024-01-01T00:00:00" - - # Act & Assert - with pytest.raises(ChildChunkIndexingError): - SegmentService.update_child_chunk(content, chunk, segment, document, dataset) - - mock_db_session.rollback.assert_called_once() - - -class TestSegmentServiceDeleteChildChunk: - """Tests for SegmentService.delete_child_chunk method.""" - - @pytest.fixture - def mock_db_session(self): - """Mock database session.""" - with patch("services.dataset_service.db.session", autospec=True) as mock_db: - yield mock_db - - def test_delete_child_chunk_success(self, mock_db_session): - """Test successful deletion of a child chunk.""" - # Arrange - chunk = SegmentTestDataFactory.create_child_chunk_mock() - dataset = SegmentTestDataFactory.create_dataset_mock() - - with patch( - "services.dataset_service.VectorService.delete_child_chunk_vector", autospec=True - ) as mock_vector_service: - # Act - SegmentService.delete_child_chunk(chunk, dataset) - - # Assert - mock_db_session.delete.assert_called_once_with(chunk) - mock_db_session.commit.assert_called_once() - mock_vector_service.assert_called_once_with(chunk, dataset) - - def test_delete_child_chunk_vector_index_failure(self, mock_db_session): - """Test child chunk deletion when vector indexing fails.""" - # Arrange - chunk = SegmentTestDataFactory.create_child_chunk_mock() - dataset = SegmentTestDataFactory.create_dataset_mock() - - with patch( - "services.dataset_service.VectorService.delete_child_chunk_vector", autospec=True - ) as mock_vector_service: - mock_vector_service.side_effect = Exception("Vector deletion failed") - - # Act & Assert - with pytest.raises(ChildChunkDeleteIndexError): - SegmentService.delete_child_chunk(chunk, dataset) - - mock_db_session.rollback.assert_called_once() diff --git a/api/tests/unit_tests/services/services_test_help.py b/api/tests/unit_tests/services/services_test_help.py deleted file mode 100644 index c6b962f7fc..0000000000 --- a/api/tests/unit_tests/services/services_test_help.py +++ /dev/null @@ -1,59 +0,0 @@ -from unittest.mock import MagicMock - - -class ServiceDbTestHelper: - """ - Helper class for service database query tests. - """ - - @staticmethod - def setup_db_query_filter_by_mock(mock_db, query_results): - """ - Smart database query mock that responds based on model type and query parameters. - - Args: - mock_db: Mock database session - query_results: Dict mapping (model_name, filter_key, filter_value) to return value - Example: {('Account', 'email', 'test@example.com'): mock_account} - """ - - def query_side_effect(model): - mock_query = MagicMock() - - def filter_by_side_effect(**kwargs): - mock_filter_result = MagicMock() - - def first_side_effect(): - # Find matching result based on model and filter parameters - for (model_name, filter_key, filter_value), result in query_results.items(): - if model.__name__ == model_name and filter_key in kwargs and kwargs[filter_key] == filter_value: - return result - return None - - mock_filter_result.first.side_effect = first_side_effect - - # Handle order_by calls for complex queries - def order_by_side_effect(*args, **kwargs): - mock_order_result = MagicMock() - - def order_first_side_effect(): - # Look for order_by results in the same query_results dict - for (model_name, filter_key, filter_value), result in query_results.items(): - if ( - model.__name__ == model_name - and filter_key == "order_by" - and filter_value == "first_available" - ): - return result - return None - - mock_order_result.first.side_effect = order_first_side_effect - return mock_order_result - - mock_filter_result.order_by.side_effect = order_by_side_effect - return mock_filter_result - - mock_query.filter_by.side_effect = filter_by_side_effect - return mock_query - - mock_db.session.query.side_effect = query_side_effect diff --git a/api/tests/unit_tests/services/test_account_service.py b/api/tests/unit_tests/services/test_account_service.py index c4f5f57153..e9d2f1481e 100644 --- a/api/tests/unit_tests/services/test_account_service.py +++ b/api/tests/unit_tests/services/test_account_service.py @@ -14,7 +14,6 @@ from services.errors.account import ( AccountRegisterError, CurrentPasswordIncorrectError, ) -from tests.unit_tests.services.services_test_help import ServiceDbTestHelper class TestAccountAssociatedDataFactory: @@ -149,7 +148,6 @@ class TestAccountService: # Setup basic session methods mock_session.add = MagicMock() mock_session.commit = MagicMock() - mock_session.query = MagicMock() yield mock_db @@ -1572,15 +1570,9 @@ class TestRegisterService: account_id="existing-user-456", email="existing@example.com", status="active" ) - # Mock database queries - query_results = { - ( - "TenantAccountJoin", - "tenant_id", - "tenant-456", - ): TestAccountAssociatedDataFactory.create_tenant_join_mock(), - } - ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results) + mock_db_dependencies[ + "db" + ].session.scalar.return_value = TestAccountAssociatedDataFactory.create_tenant_join_mock() # Mock TenantService methods with ( diff --git a/api/tests/unit_tests/services/test_app_generate_service.py b/api/tests/unit_tests/services/test_app_generate_service.py index c2b430c551..119a7adc45 100644 --- a/api/tests/unit_tests/services/test_app_generate_service.py +++ b/api/tests/unit_tests/services/test_app_generate_service.py @@ -327,7 +327,8 @@ class TestGenerate: streaming=False, ) assert result == {"result": "advanced-blocking"} - assert gen_spy.call_args.kwargs.get("streaming") is False + call_kwargs = gen_spy.call_args.kwargs + assert call_kwargs.get("streaming") is False retrieve_spy.assert_not_called() # -- ADVANCED_CHAT streaming -------------------------------------------- diff --git a/api/tests/unit_tests/services/test_dataset_service_segment.py b/api/tests/unit_tests/services/test_dataset_service_segment.py index d6c104708c..5cfef76719 100644 --- a/api/tests/unit_tests/services/test_dataset_service_segment.py +++ b/api/tests/unit_tests/services/test_dataset_service_segment.py @@ -714,7 +714,6 @@ class TestSegmentServiceMutations: patch("services.dataset_service.db") as mock_db, patch("services.dataset_service.delete_segment_from_index_task") as delete_task, ): - segments_query = MagicMock() # execute().all() for segments_info (multi-column) execute_result = MagicMock() execute_result.all.return_value = [ diff --git a/api/tests/unit_tests/services/test_datasource_provider_service.py b/api/tests/unit_tests/services/test_datasource_provider_service.py index d304e0ec44..c389c4a635 100644 --- a/api/tests/unit_tests/services/test_datasource_provider_service.py +++ b/api/tests/unit_tests/services/test_datasource_provider_service.py @@ -36,9 +36,7 @@ class TestDatasourceProviderService: @pytest.fixture def mock_db_session(self): """ - Robust, chainable query mock. - q returns itself for .filter_by(), .order_by(), .where() so any - SQLAlchemy chaining pattern works without multiple brittle sub-mocks. + Mock session with scalar/scalars defaults for current SQLAlchemy access paths. """ with ( patch("services.datasource_provider_service.Session") as mock_cls, @@ -46,20 +44,6 @@ class TestDatasourceProviderService: ): sess = MagicMock(spec=Session) - q = MagicMock() - sess.query.return_value = q - - # Self-returning chain — any method called on q returns q - q.filter_by.return_value = q - q.order_by.return_value = q - q.where.return_value = q - - # Default terminal values (tests override per-case) - q.first.return_value = None - q.all.return_value = [] - q.count.return_value = 0 - q.delete.return_value = 1 - # Default values for select()-style calls (tests override per-case) sess.scalar.return_value = None sess.scalars.return_value.all.return_value = [] diff --git a/api/tests/unit_tests/services/test_trigger_provider_service.py b/api/tests/unit_tests/services/test_trigger_provider_service.py index ebf1b36610..6eba60e5f1 100644 --- a/api/tests/unit_tests/services/test_trigger_provider_service.py +++ b/api/tests/unit_tests/services/test_trigger_provider_service.py @@ -694,7 +694,7 @@ def test_get_oauth_client_should_return_decrypted_system_client_when_verified( _mock_get_trigger_provider(mocker, provider_controller) mocker.patch("services.trigger.trigger_provider_service.PluginService.is_plugin_verified", return_value=True) mocker.patch( - "services.trigger.trigger_provider_service.decrypt_system_oauth_params", + "services.trigger.trigger_provider_service.decrypt_system_params", return_value={"client_id": "system"}, ) @@ -716,7 +716,7 @@ def test_get_oauth_client_should_raise_error_when_system_decryption_fails( _mock_get_trigger_provider(mocker, provider_controller) mocker.patch("services.trigger.trigger_provider_service.PluginService.is_plugin_verified", return_value=True) mocker.patch( - "services.trigger.trigger_provider_service.decrypt_system_oauth_params", + "services.trigger.trigger_provider_service.decrypt_system_params", side_effect=RuntimeError("bad data"), ) diff --git a/api/tests/unit_tests/services/test_webhook_service_additional.py b/api/tests/unit_tests/services/test_webhook_service_additional.py index 776cb5dc3f..491dd94842 100644 --- a/api/tests/unit_tests/services/test_webhook_service_additional.py +++ b/api/tests/unit_tests/services/test_webhook_service_additional.py @@ -17,23 +17,6 @@ from services.trigger import webhook_service as service_module from services.trigger.webhook_service import WebhookService -class _FakeQuery: - def __init__(self, result: Any) -> None: - self._result = result - - def where(self, *args: Any, **kwargs: Any) -> "_FakeQuery": - return self - - def filter(self, *args: Any, **kwargs: Any) -> "_FakeQuery": - return self - - def order_by(self, *args: Any, **kwargs: Any) -> "_FakeQuery": - return self - - def first(self) -> Any: - return self._result - - @pytest.fixture def flask_app() -> Flask: return Flask(__name__) diff --git a/api/tests/unit_tests/services/test_workflow_service.py b/api/tests/unit_tests/services/test_workflow_service.py index 0015e8b908..feafada59a 100644 --- a/api/tests/unit_tests/services/test_workflow_service.py +++ b/api/tests/unit_tests/services/test_workflow_service.py @@ -1649,8 +1649,6 @@ class TestWorkflowServiceCredentialValidation: """Missing BuiltinToolProvider → plugin requires no credentials → no error.""" # Arrange with patch("services.workflow_service.db") as mock_db: - mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = None - # Act + Assert (should NOT raise) service._check_default_tool_credential("tenant-1", "some-provider") @@ -1662,10 +1660,6 @@ class TestWorkflowServiceCredentialValidation: patch("services.workflow_service.db") as mock_db, patch("core.helper.credential_utils.check_credential_policy_compliance", side_effect=Exception("denied")), ): - mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = ( - mock_provider - ) - # Act + Assert with pytest.raises(ValueError, match="Failed to validate default credential"): service._check_default_tool_credential("tenant-1", "some-provider") diff --git a/api/tests/unit_tests/services/tools/test_builtin_tools_manage_service.py b/api/tests/unit_tests/services/tools/test_builtin_tools_manage_service.py index 79a2d30f57..ce0d94398d 100644 --- a/api/tests/unit_tests/services/tools/test_builtin_tools_manage_service.py +++ b/api/tests/unit_tests/services/tools/test_builtin_tools_manage_service.py @@ -280,7 +280,7 @@ class TestGetOauthClient: assert result == {"client_id": "id", "client_secret": "secret"} - @patch(f"{MODULE}.decrypt_system_oauth_params", return_value={"sys_key": "sys_val"}) + @patch(f"{MODULE}.decrypt_system_params", return_value={"sys_key": "sys_val"}) @patch(f"{MODULE}.PluginService") @patch(f"{MODULE}.create_provider_encrypter") @patch(f"{MODULE}.ToolManager") diff --git a/api/tests/unit_tests/services/vector_service.py b/api/tests/unit_tests/services/vector_service.py deleted file mode 100644 index ad80beb4e3..0000000000 --- a/api/tests/unit_tests/services/vector_service.py +++ /dev/null @@ -1,1793 +0,0 @@ -""" -Comprehensive unit tests for VectorService and Vector classes. - -This module contains extensive unit tests for the VectorService and Vector -classes, which are critical components in the RAG (Retrieval-Augmented Generation) -pipeline that handle vector database operations, collection management, embedding -storage and retrieval, and metadata filtering. - -The VectorService provides methods for: -- Creating vector embeddings for document segments -- Updating segment vector embeddings -- Generating child chunks for hierarchical indexing -- Managing child chunk vectors (create, update, delete) - -The Vector class provides methods for: -- Vector database operations (create, add, delete, search) -- Collection creation and management with Redis locking -- Embedding storage and retrieval -- Vector index operations (HNSW, L2 distance, etc.) -- Metadata filtering in vector space -- Support for multiple vector database backends - -This test suite ensures: -- Correct vector database operations -- Proper collection creation and management -- Accurate embedding storage and retrieval -- Comprehensive vector search functionality -- Metadata filtering and querying -- Error conditions are handled correctly -- Edge cases are properly validated - -================================================================================ -ARCHITECTURE OVERVIEW -================================================================================ - -The Vector service system is a critical component that bridges document -segments and vector databases, enabling semantic search and retrieval. - -1. VectorService: - - High-level service for managing vector operations on document segments - - Handles both regular segments and hierarchical (parent-child) indexing - - Integrates with IndexProcessor for document transformation - - Manages embedding model instances via ModelManager - -2. Vector Class: - - Wrapper around BaseVector implementations - - Handles embedding generation via ModelManager - - Supports multiple vector database backends (Chroma, Milvus, Qdrant, etc.) - - Manages collection creation with Redis locking for concurrency control - - Provides batch processing for large document sets - -3. BaseVector Abstract Class: - - Defines interface for vector database operations - - Implemented by various vector database backends - - Provides methods for CRUD operations on vectors - - Supports both vector similarity search and full-text search - -4. Collection Management: - - Uses Redis locks to prevent concurrent collection creation - - Caches collection existence status in Redis - - Supports collection deletion with cache invalidation - -5. Embedding Generation: - - Uses ModelManager to get embedding model instances - - Supports cached embeddings for performance - - Handles batch processing for large document sets - - Generates embeddings for both documents and queries - -================================================================================ -TESTING STRATEGY -================================================================================ - -This test suite follows a comprehensive testing strategy that covers: - -1. VectorService Methods: - - create_segments_vector: Regular and hierarchical indexing - - update_segment_vector: Vector and keyword index updates - - generate_child_chunks: Child chunk generation with full doc mode - - create_child_chunk_vector: Child chunk vector creation - - update_child_chunk_vector: Batch child chunk updates - - delete_child_chunk_vector: Child chunk deletion - -2. Vector Class Methods: - - Initialization with dataset and attributes - - Collection creation with Redis locking - - Embedding generation and batch processing - - Vector operations (create, add_texts, delete_by_ids, etc.) - - Search operations (by vector, by full text) - - Metadata filtering and querying - - Duplicate checking logic - - Vector factory selection - -3. Integration Points: - - ModelManager integration for embedding models - - IndexProcessor integration for document transformation - - Redis integration for locking and caching - - Database session management - - Vector database backend abstraction - -4. Error Handling: - - Invalid vector store configuration - - Missing embedding models - - Collection creation failures - - Search operation errors - - Metadata filtering errors - -5. Edge Cases: - - Empty document lists - - Missing metadata fields - - Duplicate document IDs - - Large batch processing - - Concurrent collection creation - -================================================================================ -""" - -from typing import Any -from unittest.mock import Mock, patch - -import pytest - -from core.rag.datasource.vdb.vector_base import BaseVector -from core.rag.datasource.vdb.vector_factory import Vector -from core.rag.datasource.vdb.vector_type import VectorType -from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType -from core.rag.models.document import Document -from models.dataset import ChildChunk, Dataset, DatasetDocument, DatasetProcessRule, DocumentSegment -from services.vector_service import VectorService - -# ============================================================================ -# Test Data Factory -# ============================================================================ - - -class VectorServiceTestDataFactory: - """ - Factory class for creating test data and mock objects for Vector service tests. - - This factory provides static methods to create mock objects for: - - Dataset instances with various configurations - - DocumentSegment instances - - ChildChunk instances - - Document instances (RAG documents) - - Embedding model instances - - Vector processor mocks - - Index processor mocks - - The factory methods help maintain consistency across tests and reduce - code duplication when setting up test scenarios. - """ - - @staticmethod - def create_dataset_mock( - dataset_id: str = "dataset-123", - tenant_id: str = "tenant-123", - doc_form: str = IndexStructureType.PARAGRAPH_INDEX, - indexing_technique: str = IndexTechniqueType.HIGH_QUALITY, - embedding_model_provider: str = "openai", - embedding_model: str = "text-embedding-ada-002", - index_struct_dict: dict[str, Any] | None = None, - **kwargs, - ) -> Mock: - """ - Create a mock Dataset with specified attributes. - - Args: - dataset_id: Unique identifier for the dataset - tenant_id: Tenant identifier - doc_form: Document form type - indexing_technique: Indexing technique (high_quality or economy) - embedding_model_provider: Embedding model provider - embedding_model: Embedding model name - index_struct_dict: Index structure dictionary - **kwargs: Additional attributes to set on the mock - - Returns: - Mock object configured as a Dataset instance - """ - dataset = Mock(spec=Dataset) - - dataset.id = dataset_id - - dataset.tenant_id = tenant_id - - dataset.doc_form = doc_form - - dataset.indexing_technique = indexing_technique - - dataset.embedding_model_provider = embedding_model_provider - - dataset.embedding_model = embedding_model - - dataset.index_struct_dict = index_struct_dict - - for key, value in kwargs.items(): - setattr(dataset, key, value) - - return dataset - - @staticmethod - def create_document_segment_mock( - segment_id: str = "segment-123", - document_id: str = "doc-123", - dataset_id: str = "dataset-123", - content: str = "Test segment content", - index_node_id: str = "node-123", - index_node_hash: str = "hash-123", - **kwargs, - ) -> Mock: - """ - Create a mock DocumentSegment with specified attributes. - - Args: - segment_id: Unique identifier for the segment - document_id: Parent document identifier - dataset_id: Dataset identifier - content: Segment content text - index_node_id: Index node identifier - index_node_hash: Index node hash - **kwargs: Additional attributes to set on the mock - - Returns: - Mock object configured as a DocumentSegment instance - """ - segment = Mock(spec=DocumentSegment) - - segment.id = segment_id - - segment.document_id = document_id - - segment.dataset_id = dataset_id - - segment.content = content - - segment.index_node_id = index_node_id - - segment.index_node_hash = index_node_hash - - for key, value in kwargs.items(): - setattr(segment, key, value) - - return segment - - @staticmethod - def create_child_chunk_mock( - chunk_id: str = "chunk-123", - segment_id: str = "segment-123", - document_id: str = "doc-123", - dataset_id: str = "dataset-123", - tenant_id: str = "tenant-123", - content: str = "Test child chunk content", - index_node_id: str = "node-chunk-123", - index_node_hash: str = "hash-chunk-123", - position: int = 1, - **kwargs, - ) -> Mock: - """ - Create a mock ChildChunk with specified attributes. - - Args: - chunk_id: Unique identifier for the child chunk - segment_id: Parent segment identifier - document_id: Parent document identifier - dataset_id: Dataset identifier - tenant_id: Tenant identifier - content: Child chunk content text - index_node_id: Index node identifier - index_node_hash: Index node hash - position: Position in parent segment - **kwargs: Additional attributes to set on the mock - - Returns: - Mock object configured as a ChildChunk instance - """ - chunk = Mock(spec=ChildChunk) - - chunk.id = chunk_id - - chunk.segment_id = segment_id - - chunk.document_id = document_id - - chunk.dataset_id = dataset_id - - chunk.tenant_id = tenant_id - - chunk.content = content - - chunk.index_node_id = index_node_id - - chunk.index_node_hash = index_node_hash - - chunk.position = position - - for key, value in kwargs.items(): - setattr(chunk, key, value) - - return chunk - - @staticmethod - def create_dataset_document_mock( - document_id: str = "doc-123", - dataset_id: str = "dataset-123", - tenant_id: str = "tenant-123", - dataset_process_rule_id: str = "rule-123", - doc_language: str = "en", - created_by: str = "user-123", - **kwargs, - ) -> Mock: - """ - Create a mock DatasetDocument with specified attributes. - - Args: - document_id: Unique identifier for the document - dataset_id: Dataset identifier - tenant_id: Tenant identifier - dataset_process_rule_id: Process rule identifier - doc_language: Document language - created_by: Creator user ID - **kwargs: Additional attributes to set on the mock - - Returns: - Mock object configured as a DatasetDocument instance - """ - document = Mock(spec=DatasetDocument) - - document.id = document_id - - document.dataset_id = dataset_id - - document.tenant_id = tenant_id - - document.dataset_process_rule_id = dataset_process_rule_id - - document.doc_language = doc_language - - document.created_by = created_by - - for key, value in kwargs.items(): - setattr(document, key, value) - - return document - - @staticmethod - def create_dataset_process_rule_mock( - rule_id: str = "rule-123", - **kwargs, - ) -> Mock: - """ - Create a mock DatasetProcessRule with specified attributes. - - Args: - rule_id: Unique identifier for the process rule - **kwargs: Additional attributes to set on the mock - - Returns: - Mock object configured as a DatasetProcessRule instance - """ - rule = Mock(spec=DatasetProcessRule) - - rule.id = rule_id - - rule.to_dict = Mock(return_value={"rules": {"parent_mode": "chunk"}}) - - for key, value in kwargs.items(): - setattr(rule, key, value) - - return rule - - @staticmethod - def create_rag_document_mock( - page_content: str = "Test document content", - doc_id: str = "doc-123", - doc_hash: str = "hash-123", - document_id: str = "doc-123", - dataset_id: str = "dataset-123", - **kwargs, - ) -> Document: - """ - Create a RAG Document with specified attributes. - - Args: - page_content: Document content text - doc_id: Document identifier in metadata - doc_hash: Document hash in metadata - document_id: Parent document ID in metadata - dataset_id: Dataset ID in metadata - **kwargs: Additional metadata fields - - Returns: - Document instance configured for testing - """ - metadata = { - "doc_id": doc_id, - "doc_hash": doc_hash, - "document_id": document_id, - "dataset_id": dataset_id, - } - - metadata.update(kwargs) - - return Document(page_content=page_content, metadata=metadata) - - @staticmethod - def create_embedding_model_instance_mock() -> Mock: - """ - Create a mock embedding model instance. - - Returns: - Mock object configured as an embedding model instance - """ - model_instance = Mock() - - model_instance.embed_documents = Mock(return_value=[[0.1] * 1536]) - - model_instance.embed_query = Mock(return_value=[0.1] * 1536) - - return model_instance - - @staticmethod - def create_vector_processor_mock() -> Mock: - """ - Create a mock vector processor (BaseVector implementation). - - Returns: - Mock object configured as a BaseVector instance - """ - processor = Mock(spec=BaseVector) - - processor.collection_name = "test_collection" - - processor.create = Mock() - - processor.add_texts = Mock() - - processor.text_exists = Mock(return_value=False) - - processor.delete_by_ids = Mock() - - processor.delete_by_metadata_field = Mock() - - processor.search_by_vector = Mock(return_value=[]) - - processor.search_by_full_text = Mock(return_value=[]) - - processor.delete = Mock() - - return processor - - @staticmethod - def create_index_processor_mock() -> Mock: - """ - Create a mock index processor. - - Returns: - Mock object configured as an index processor instance - """ - processor = Mock() - - processor.load = Mock() - - processor.clean = Mock() - - processor.transform = Mock(return_value=[]) - - return processor - - -# ============================================================================ -# Tests for VectorService -# ============================================================================ - - -class TestVectorService: - """ - Comprehensive unit tests for VectorService class. - - This test class covers all methods of the VectorService class, including - segment vector operations, child chunk operations, and integration with - various components like IndexProcessor and ModelManager. - """ - - # ======================================================================== - # Tests for create_segments_vector - # ======================================================================== - - @patch("services.vector_service.IndexProcessorFactory") - @patch("services.vector_service.db") - def test_create_segments_vector_regular_indexing(self, mock_db, mock_index_processor_factory): - """ - Test create_segments_vector with regular indexing (non-hierarchical). - - This test verifies that segments are correctly converted to RAG documents - and loaded into the index processor for regular indexing scenarios. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock( - doc_form=IndexStructureType.PARAGRAPH_INDEX, indexing_technique=IndexTechniqueType.HIGH_QUALITY - ) - - segment = VectorServiceTestDataFactory.create_document_segment_mock() - - keywords_list = [["keyword1", "keyword2"]] - - mock_index_processor = VectorServiceTestDataFactory.create_index_processor_mock() - - mock_index_processor_factory.return_value.init_index_processor.return_value = mock_index_processor - - # Act - VectorService.create_segments_vector(keywords_list, [segment], dataset, IndexStructureType.PARAGRAPH_INDEX) - - # Assert - mock_index_processor.load.assert_called_once() - - call_args = mock_index_processor.load.call_args - - assert call_args[0][0] == dataset - - assert len(call_args[0][1]) == 1 - - assert call_args[1]["with_keywords"] is True - - assert call_args[1]["keywords_list"] == keywords_list - - @patch("services.vector_service.VectorService.generate_child_chunks") - @patch("services.vector_service.ModelManager.for_tenant") - @patch("services.vector_service.db") - def test_create_segments_vector_parent_child_indexing( - self, mock_db, mock_model_manager, mock_generate_child_chunks - ): - """ - Test create_segments_vector with parent-child indexing. - - This test verifies that for hierarchical indexing, child chunks are - generated instead of regular segment indexing. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock( - doc_form="parent_child_model", indexing_technique=IndexTechniqueType.HIGH_QUALITY - ) - - segment = VectorServiceTestDataFactory.create_document_segment_mock() - - dataset_document = VectorServiceTestDataFactory.create_dataset_document_mock() - - processing_rule = VectorServiceTestDataFactory.create_dataset_process_rule_mock() - - mock_db.session.query.return_value.filter_by.return_value.first.return_value = dataset_document - - mock_db.session.query.return_value.where.return_value.first.return_value = processing_rule - - mock_embedding_model = VectorServiceTestDataFactory.create_embedding_model_instance_mock() - - mock_model_manager.return_value.get_model_instance.return_value = mock_embedding_model - - # Act - VectorService.create_segments_vector(None, [segment], dataset, "parent_child_model") - - # Assert - mock_generate_child_chunks.assert_called_once() - - @patch("services.vector_service.db") - def test_create_segments_vector_missing_document(self, mock_db): - """ - Test create_segments_vector when document is missing. - - This test verifies that when a document is not found, the segment - is skipped with a warning log. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock( - doc_form="parent_child_model", indexing_technique=IndexTechniqueType.HIGH_QUALITY - ) - - segment = VectorServiceTestDataFactory.create_document_segment_mock() - - mock_db.session.query.return_value.filter_by.return_value.first.return_value = None - - # Act - VectorService.create_segments_vector(None, [segment], dataset, "parent_child_model") - - # Assert - # Should not raise an error, just skip the segment - - @patch("services.vector_service.db") - def test_create_segments_vector_missing_processing_rule(self, mock_db): - """ - Test create_segments_vector when processing rule is missing. - - This test verifies that when a processing rule is not found, a - ValueError is raised. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock( - doc_form="parent_child_model", indexing_technique=IndexTechniqueType.HIGH_QUALITY - ) - - segment = VectorServiceTestDataFactory.create_document_segment_mock() - - dataset_document = VectorServiceTestDataFactory.create_dataset_document_mock() - - mock_db.session.query.return_value.filter_by.return_value.first.return_value = dataset_document - - mock_db.session.query.return_value.where.return_value.first.return_value = None - - # Act & Assert - with pytest.raises(ValueError, match="No processing rule found"): - VectorService.create_segments_vector(None, [segment], dataset, "parent_child_model") - - @patch("services.vector_service.db") - def test_create_segments_vector_economy_indexing_technique(self, mock_db): - """ - Test create_segments_vector with economy indexing technique. - - This test verifies that when indexing_technique is not high_quality, - a ValueError is raised for parent-child indexing. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock( - doc_form="parent_child_model", indexing_technique=IndexTechniqueType.ECONOMY - ) - - segment = VectorServiceTestDataFactory.create_document_segment_mock() - - dataset_document = VectorServiceTestDataFactory.create_dataset_document_mock() - - processing_rule = VectorServiceTestDataFactory.create_dataset_process_rule_mock() - - mock_db.session.query.return_value.filter_by.return_value.first.return_value = dataset_document - - mock_db.session.query.return_value.where.return_value.first.return_value = processing_rule - - # Act & Assert - with pytest.raises(ValueError, match="The knowledge base index technique is not high quality"): - VectorService.create_segments_vector(None, [segment], dataset, "parent_child_model") - - @patch("services.vector_service.IndexProcessorFactory") - @patch("services.vector_service.db") - def test_create_segments_vector_empty_documents(self, mock_db, mock_index_processor_factory): - """ - Test create_segments_vector with empty documents list. - - This test verifies that when no documents are created, the index - processor is not called. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock() - - mock_index_processor = VectorServiceTestDataFactory.create_index_processor_mock() - - mock_index_processor_factory.return_value.init_index_processor.return_value = mock_index_processor - - # Act - VectorService.create_segments_vector(None, [], dataset, IndexStructureType.PARAGRAPH_INDEX) - - # Assert - mock_index_processor.load.assert_not_called() - - # ======================================================================== - # Tests for update_segment_vector - # ======================================================================== - - @patch("services.vector_service.Vector") - @patch("services.vector_service.db") - def test_update_segment_vector_high_quality(self, mock_db, mock_vector_class): - """ - Test update_segment_vector with high_quality indexing technique. - - This test verifies that segments are correctly updated in the vector - store when using high_quality indexing. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY) - - segment = VectorServiceTestDataFactory.create_document_segment_mock() - - mock_vector = VectorServiceTestDataFactory.create_vector_processor_mock() - - mock_vector_class.return_value = mock_vector - - # Act - VectorService.update_segment_vector(None, segment, dataset) - - # Assert - mock_vector.delete_by_ids.assert_called_once_with([segment.index_node_id]) - - mock_vector.add_texts.assert_called_once() - - @patch("services.vector_service.Keyword") - @patch("services.vector_service.db") - def test_update_segment_vector_economy_with_keywords(self, mock_db, mock_keyword_class): - """ - Test update_segment_vector with economy indexing and keywords. - - This test verifies that segments are correctly updated in the keyword - index when using economy indexing with keywords. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) - - segment = VectorServiceTestDataFactory.create_document_segment_mock() - - keywords = ["keyword1", "keyword2"] - - mock_keyword = Mock() - - mock_keyword.delete_by_ids = Mock() - - mock_keyword.add_texts = Mock() - - mock_keyword_class.return_value = mock_keyword - - # Act - VectorService.update_segment_vector(keywords, segment, dataset) - - # Assert - mock_keyword.delete_by_ids.assert_called_once_with([segment.index_node_id]) - - mock_keyword.add_texts.assert_called_once() - - call_args = mock_keyword.add_texts.call_args - - assert call_args[1]["keywords_list"] == [keywords] - - @patch("services.vector_service.Keyword") - @patch("services.vector_service.db") - def test_update_segment_vector_economy_without_keywords(self, mock_db, mock_keyword_class): - """ - Test update_segment_vector with economy indexing without keywords. - - This test verifies that segments are correctly updated in the keyword - index when using economy indexing without keywords. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) - - segment = VectorServiceTestDataFactory.create_document_segment_mock() - - mock_keyword = Mock() - - mock_keyword.delete_by_ids = Mock() - - mock_keyword.add_texts = Mock() - - mock_keyword_class.return_value = mock_keyword - - # Act - VectorService.update_segment_vector(None, segment, dataset) - - # Assert - mock_keyword.delete_by_ids.assert_called_once_with([segment.index_node_id]) - - mock_keyword.add_texts.assert_called_once() - - call_args = mock_keyword.add_texts.call_args - - assert "keywords_list" not in call_args[1] or call_args[1].get("keywords_list") is None - - # ======================================================================== - # Tests for generate_child_chunks - # ======================================================================== - - @patch("services.vector_service.IndexProcessorFactory") - @patch("services.vector_service.db") - def test_generate_child_chunks_with_children(self, mock_db, mock_index_processor_factory): - """ - Test generate_child_chunks when children are generated. - - This test verifies that child chunks are correctly generated and - saved to the database when the index processor returns children. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock() - - segment = VectorServiceTestDataFactory.create_document_segment_mock() - - dataset_document = VectorServiceTestDataFactory.create_dataset_document_mock() - - processing_rule = VectorServiceTestDataFactory.create_dataset_process_rule_mock() - - embedding_model = VectorServiceTestDataFactory.create_embedding_model_instance_mock() - - child_document = VectorServiceTestDataFactory.create_rag_document_mock( - page_content="Child content", doc_id="child-node-123" - ) - - child_document.children = [child_document] - - mock_index_processor = VectorServiceTestDataFactory.create_index_processor_mock() - - mock_index_processor.transform.return_value = [child_document] - - mock_index_processor_factory.return_value.init_index_processor.return_value = mock_index_processor - - # Act - VectorService.generate_child_chunks(segment, dataset_document, dataset, embedding_model, processing_rule, False) - - # Assert - mock_index_processor.transform.assert_called_once() - - mock_index_processor.load.assert_called_once() - - mock_db.session.add.assert_called() - - mock_db.session.commit.assert_called_once() - - @patch("services.vector_service.IndexProcessorFactory") - @patch("services.vector_service.db") - def test_generate_child_chunks_regenerate(self, mock_db, mock_index_processor_factory): - """ - Test generate_child_chunks with regenerate=True. - - This test verifies that when regenerate is True, existing child chunks - are cleaned before generating new ones. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock() - - segment = VectorServiceTestDataFactory.create_document_segment_mock() - - dataset_document = VectorServiceTestDataFactory.create_dataset_document_mock() - - processing_rule = VectorServiceTestDataFactory.create_dataset_process_rule_mock() - - embedding_model = VectorServiceTestDataFactory.create_embedding_model_instance_mock() - - mock_index_processor = VectorServiceTestDataFactory.create_index_processor_mock() - - mock_index_processor.transform.return_value = [] - - mock_index_processor_factory.return_value.init_index_processor.return_value = mock_index_processor - - # Act - VectorService.generate_child_chunks(segment, dataset_document, dataset, embedding_model, processing_rule, True) - - # Assert - mock_index_processor.clean.assert_called_once() - - call_args = mock_index_processor.clean.call_args - - assert call_args[0][0] == dataset - - assert call_args[0][1] == [segment.index_node_id] - - assert call_args[1]["with_keywords"] is True - - assert call_args[1]["delete_child_chunks"] is True - - @patch("services.vector_service.IndexProcessorFactory") - @patch("services.vector_service.db") - def test_generate_child_chunks_no_children(self, mock_db, mock_index_processor_factory): - """ - Test generate_child_chunks when no children are generated. - - This test verifies that when the index processor returns no children, - no child chunks are saved to the database. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock() - - segment = VectorServiceTestDataFactory.create_document_segment_mock() - - dataset_document = VectorServiceTestDataFactory.create_dataset_document_mock() - - processing_rule = VectorServiceTestDataFactory.create_dataset_process_rule_mock() - - embedding_model = VectorServiceTestDataFactory.create_embedding_model_instance_mock() - - mock_index_processor = VectorServiceTestDataFactory.create_index_processor_mock() - - mock_index_processor.transform.return_value = [] - - mock_index_processor_factory.return_value.init_index_processor.return_value = mock_index_processor - - # Act - VectorService.generate_child_chunks(segment, dataset_document, dataset, embedding_model, processing_rule, False) - - # Assert - mock_index_processor.transform.assert_called_once() - - mock_index_processor.load.assert_not_called() - - mock_db.session.add.assert_not_called() - - # ======================================================================== - # Tests for create_child_chunk_vector - # ======================================================================== - - @patch("services.vector_service.Vector") - @patch("services.vector_service.db") - def test_create_child_chunk_vector_high_quality(self, mock_db, mock_vector_class): - """ - Test create_child_chunk_vector with high_quality indexing. - - This test verifies that child chunk vectors are correctly created - when using high_quality indexing. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY) - - child_chunk = VectorServiceTestDataFactory.create_child_chunk_mock() - - mock_vector = VectorServiceTestDataFactory.create_vector_processor_mock() - - mock_vector_class.return_value = mock_vector - - # Act - VectorService.create_child_chunk_vector(child_chunk, dataset) - - # Assert - mock_vector.add_texts.assert_called_once() - - call_args = mock_vector.add_texts.call_args - - assert call_args[1]["duplicate_check"] is True - - @patch("services.vector_service.Vector") - @patch("services.vector_service.db") - def test_create_child_chunk_vector_economy(self, mock_db, mock_vector_class): - """ - Test create_child_chunk_vector with economy indexing. - - This test verifies that child chunk vectors are not created when - using economy indexing. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) - - child_chunk = VectorServiceTestDataFactory.create_child_chunk_mock() - - mock_vector = VectorServiceTestDataFactory.create_vector_processor_mock() - - mock_vector_class.return_value = mock_vector - - # Act - VectorService.create_child_chunk_vector(child_chunk, dataset) - - # Assert - mock_vector.add_texts.assert_not_called() - - # ======================================================================== - # Tests for update_child_chunk_vector - # ======================================================================== - - @patch("services.vector_service.Vector") - @patch("services.vector_service.db") - def test_update_child_chunk_vector_with_all_operations(self, mock_db, mock_vector_class): - """ - Test update_child_chunk_vector with new, update, and delete operations. - - This test verifies that child chunk vectors are correctly updated - when there are new chunks, updated chunks, and deleted chunks. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY) - - new_chunk = VectorServiceTestDataFactory.create_child_chunk_mock(chunk_id="new-chunk-1") - - update_chunk = VectorServiceTestDataFactory.create_child_chunk_mock(chunk_id="update-chunk-1") - - delete_chunk = VectorServiceTestDataFactory.create_child_chunk_mock(chunk_id="delete-chunk-1") - - mock_vector = VectorServiceTestDataFactory.create_vector_processor_mock() - - mock_vector_class.return_value = mock_vector - - # Act - VectorService.update_child_chunk_vector([new_chunk], [update_chunk], [delete_chunk], dataset) - - # Assert - mock_vector.delete_by_ids.assert_called_once() - - delete_ids = mock_vector.delete_by_ids.call_args[0][0] - - assert update_chunk.index_node_id in delete_ids - - assert delete_chunk.index_node_id in delete_ids - - mock_vector.add_texts.assert_called_once() - - call_args = mock_vector.add_texts.call_args - - assert len(call_args[0][0]) == 2 # new_chunk + update_chunk - - assert call_args[1]["duplicate_check"] is True - - @patch("services.vector_service.Vector") - @patch("services.vector_service.db") - def test_update_child_chunk_vector_only_new(self, mock_db, mock_vector_class): - """ - Test update_child_chunk_vector with only new chunks. - - This test verifies that when only new chunks are provided, only - add_texts is called, not delete_by_ids. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY) - - new_chunk = VectorServiceTestDataFactory.create_child_chunk_mock() - - mock_vector = VectorServiceTestDataFactory.create_vector_processor_mock() - - mock_vector_class.return_value = mock_vector - - # Act - VectorService.update_child_chunk_vector([new_chunk], [], [], dataset) - - # Assert - mock_vector.delete_by_ids.assert_not_called() - - mock_vector.add_texts.assert_called_once() - - @patch("services.vector_service.Vector") - @patch("services.vector_service.db") - def test_update_child_chunk_vector_only_delete(self, mock_db, mock_vector_class): - """ - Test update_child_chunk_vector with only deleted chunks. - - This test verifies that when only deleted chunks are provided, only - delete_by_ids is called, not add_texts. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY) - - delete_chunk = VectorServiceTestDataFactory.create_child_chunk_mock() - - mock_vector = VectorServiceTestDataFactory.create_vector_processor_mock() - - mock_vector_class.return_value = mock_vector - - # Act - VectorService.update_child_chunk_vector([], [], [delete_chunk], dataset) - - # Assert - mock_vector.delete_by_ids.assert_called_once_with([delete_chunk.index_node_id]) - - mock_vector.add_texts.assert_not_called() - - @patch("services.vector_service.Vector") - @patch("services.vector_service.db") - def test_update_child_chunk_vector_economy(self, mock_db, mock_vector_class): - """ - Test update_child_chunk_vector with economy indexing. - - This test verifies that child chunk vectors are not updated when - using economy indexing. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) - - new_chunk = VectorServiceTestDataFactory.create_child_chunk_mock() - - mock_vector = VectorServiceTestDataFactory.create_vector_processor_mock() - - mock_vector_class.return_value = mock_vector - - # Act - VectorService.update_child_chunk_vector([new_chunk], [], [], dataset) - - # Assert - mock_vector.delete_by_ids.assert_not_called() - - mock_vector.add_texts.assert_not_called() - - # ======================================================================== - # Tests for delete_child_chunk_vector - # ======================================================================== - - @patch("services.vector_service.Vector") - @patch("services.vector_service.db") - def test_delete_child_chunk_vector_high_quality(self, mock_db, mock_vector_class): - """ - Test delete_child_chunk_vector with high_quality indexing. - - This test verifies that child chunk vectors are correctly deleted - when using high_quality indexing. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY) - - child_chunk = VectorServiceTestDataFactory.create_child_chunk_mock() - - mock_vector = VectorServiceTestDataFactory.create_vector_processor_mock() - - mock_vector_class.return_value = mock_vector - - # Act - VectorService.delete_child_chunk_vector(child_chunk, dataset) - - # Assert - mock_vector.delete_by_ids.assert_called_once_with([child_chunk.index_node_id]) - - @patch("services.vector_service.Vector") - @patch("services.vector_service.db") - def test_delete_child_chunk_vector_economy(self, mock_db, mock_vector_class): - """ - Test delete_child_chunk_vector with economy indexing. - - This test verifies that child chunk vectors are not deleted when - using economy indexing. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) - - child_chunk = VectorServiceTestDataFactory.create_child_chunk_mock() - - mock_vector = VectorServiceTestDataFactory.create_vector_processor_mock() - - mock_vector_class.return_value = mock_vector - - # Act - VectorService.delete_child_chunk_vector(child_chunk, dataset) - - # Assert - mock_vector.delete_by_ids.assert_not_called() - - -# ============================================================================ -# Tests for Vector Class -# ============================================================================ - - -class TestVector: - """ - Comprehensive unit tests for Vector class. - - This test class covers all methods of the Vector class, including - initialization, collection management, embedding operations, vector - database operations, and search functionality. - """ - - # ======================================================================== - # Tests for Vector Initialization - # ======================================================================== - - @patch("core.rag.datasource.vdb.vector_factory.Vector._init_vector") - @patch("core.rag.datasource.vdb.vector_factory.Vector._get_embeddings") - def test_vector_initialization_default_attributes(self, mock_get_embeddings, mock_init_vector): - """ - Test Vector initialization with default attributes. - - This test verifies that Vector is correctly initialized with default - attributes when none are provided. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock() - - mock_embeddings = Mock() - - mock_get_embeddings.return_value = mock_embeddings - - mock_vector_processor = VectorServiceTestDataFactory.create_vector_processor_mock() - - mock_init_vector.return_value = mock_vector_processor - - # Act - vector = Vector(dataset=dataset) - - # Assert - assert vector._dataset == dataset - - assert vector._attributes == ["doc_id", "dataset_id", "document_id", "doc_hash"] - - mock_get_embeddings.assert_called_once() - - mock_init_vector.assert_called_once() - - @patch("core.rag.datasource.vdb.vector_factory.Vector._init_vector") - @patch("core.rag.datasource.vdb.vector_factory.Vector._get_embeddings") - def test_vector_initialization_custom_attributes(self, mock_get_embeddings, mock_init_vector): - """ - Test Vector initialization with custom attributes. - - This test verifies that Vector is correctly initialized with custom - attributes when provided. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock() - - custom_attributes = ["custom_attr1", "custom_attr2"] - - mock_embeddings = Mock() - - mock_get_embeddings.return_value = mock_embeddings - - mock_vector_processor = VectorServiceTestDataFactory.create_vector_processor_mock() - - mock_init_vector.return_value = mock_vector_processor - - # Act - vector = Vector(dataset=dataset, attributes=custom_attributes) - - # Assert - assert vector._dataset == dataset - - assert vector._attributes == custom_attributes - - # ======================================================================== - # Tests for Vector.create - # ======================================================================== - - @patch("core.rag.datasource.vdb.vector_factory.Vector._init_vector") - @patch("core.rag.datasource.vdb.vector_factory.Vector._get_embeddings") - def test_vector_create_with_texts(self, mock_get_embeddings, mock_init_vector): - """ - Test Vector.create with texts list. - - This test verifies that documents are correctly embedded and created - in the vector store with batch processing. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock() - - documents = [ - VectorServiceTestDataFactory.create_rag_document_mock(page_content=f"Content {i}") for i in range(5) - ] - - mock_embeddings = Mock() - - mock_embeddings.embed_documents = Mock(return_value=[[0.1] * 1536] * 5) - - mock_get_embeddings.return_value = mock_embeddings - - mock_vector_processor = VectorServiceTestDataFactory.create_vector_processor_mock() - - mock_init_vector.return_value = mock_vector_processor - - vector = Vector(dataset=dataset) - - # Act - vector.create(texts=documents) - - # Assert - mock_embeddings.embed_documents.assert_called() - - mock_vector_processor.create.assert_called() - - @patch("core.rag.datasource.vdb.vector_factory.Vector._init_vector") - @patch("core.rag.datasource.vdb.vector_factory.Vector._get_embeddings") - def test_vector_create_empty_texts(self, mock_get_embeddings, mock_init_vector): - """ - Test Vector.create with empty texts list. - - This test verifies that when texts is None or empty, no operations - are performed. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock() - - mock_embeddings = Mock() - - mock_get_embeddings.return_value = mock_embeddings - - mock_vector_processor = VectorServiceTestDataFactory.create_vector_processor_mock() - - mock_init_vector.return_value = mock_vector_processor - - vector = Vector(dataset=dataset) - - # Act - vector.create(texts=None) - - # Assert - mock_embeddings.embed_documents.assert_not_called() - - mock_vector_processor.create.assert_not_called() - - @patch("core.rag.datasource.vdb.vector_factory.Vector._init_vector") - @patch("core.rag.datasource.vdb.vector_factory.Vector._get_embeddings") - def test_vector_create_large_batch(self, mock_get_embeddings, mock_init_vector): - """ - Test Vector.create with large batch of documents. - - This test verifies that large batches are correctly processed in - chunks of 1000 documents. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock() - - documents = [ - VectorServiceTestDataFactory.create_rag_document_mock(page_content=f"Content {i}") for i in range(2500) - ] - - mock_embeddings = Mock() - - mock_embeddings.embed_documents = Mock(return_value=[[0.1] * 1536] * 1000) - - mock_get_embeddings.return_value = mock_embeddings - - mock_vector_processor = VectorServiceTestDataFactory.create_vector_processor_mock() - - mock_init_vector.return_value = mock_vector_processor - - vector = Vector(dataset=dataset) - - # Act - vector.create(texts=documents) - - # Assert - # Should be called 3 times (1000, 1000, 500) - assert mock_embeddings.embed_documents.call_count == 3 - - assert mock_vector_processor.create.call_count == 3 - - # ======================================================================== - # Tests for Vector.add_texts - # ======================================================================== - - @patch("core.rag.datasource.vdb.vector_factory.Vector._init_vector") - @patch("core.rag.datasource.vdb.vector_factory.Vector._get_embeddings") - def test_vector_add_texts_without_duplicate_check(self, mock_get_embeddings, mock_init_vector): - """ - Test Vector.add_texts without duplicate check. - - This test verifies that documents are added without checking for - duplicates when duplicate_check is False. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock() - - documents = [VectorServiceTestDataFactory.create_rag_document_mock()] - - mock_embeddings = Mock() - - mock_embeddings.embed_documents = Mock(return_value=[[0.1] * 1536]) - - mock_get_embeddings.return_value = mock_embeddings - - mock_vector_processor = VectorServiceTestDataFactory.create_vector_processor_mock() - - mock_init_vector.return_value = mock_vector_processor - - vector = Vector(dataset=dataset) - - # Act - vector.add_texts(documents, duplicate_check=False) - - # Assert - mock_embeddings.embed_documents.assert_called_once() - - mock_vector_processor.create.assert_called_once() - - @patch("core.rag.datasource.vdb.vector_factory.Vector._init_vector") - @patch("core.rag.datasource.vdb.vector_factory.Vector._get_embeddings") - def test_vector_add_texts_with_duplicate_check(self, mock_get_embeddings, mock_init_vector): - """ - Test Vector.add_texts with duplicate check. - - This test verifies that duplicate documents are filtered out when - duplicate_check is True. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock() - - documents = [VectorServiceTestDataFactory.create_rag_document_mock(doc_id="doc-123")] - - mock_embeddings = Mock() - - mock_embeddings.embed_documents = Mock(return_value=[[0.1] * 1536]) - - mock_get_embeddings.return_value = mock_embeddings - - mock_vector_processor = VectorServiceTestDataFactory.create_vector_processor_mock() - - mock_vector_processor.text_exists = Mock(return_value=True) # Document exists - - mock_init_vector.return_value = mock_vector_processor - - vector = Vector(dataset=dataset) - - # Act - vector.add_texts(documents, duplicate_check=True) - - # Assert - mock_vector_processor.text_exists.assert_called_once_with("doc-123") - - mock_embeddings.embed_documents.assert_not_called() - - mock_vector_processor.create.assert_not_called() - - # ======================================================================== - # Tests for Vector.text_exists - # ======================================================================== - - @patch("core.rag.datasource.vdb.vector_factory.Vector._init_vector") - @patch("core.rag.datasource.vdb.vector_factory.Vector._get_embeddings") - def test_vector_text_exists_true(self, mock_get_embeddings, mock_init_vector): - """ - Test Vector.text_exists when text exists. - - This test verifies that text_exists correctly returns True when - a document exists in the vector store. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock() - - mock_embeddings = Mock() - - mock_get_embeddings.return_value = mock_embeddings - - mock_vector_processor = VectorServiceTestDataFactory.create_vector_processor_mock() - - mock_vector_processor.text_exists = Mock(return_value=True) - - mock_init_vector.return_value = mock_vector_processor - - vector = Vector(dataset=dataset) - - # Act - result = vector.text_exists("doc-123") - - # Assert - assert result is True - - mock_vector_processor.text_exists.assert_called_once_with("doc-123") - - @patch("core.rag.datasource.vdb.vector_factory.Vector._init_vector") - @patch("core.rag.datasource.vdb.vector_factory.Vector._get_embeddings") - def test_vector_text_exists_false(self, mock_get_embeddings, mock_init_vector): - """ - Test Vector.text_exists when text does not exist. - - This test verifies that text_exists correctly returns False when - a document does not exist in the vector store. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock() - - mock_embeddings = Mock() - - mock_get_embeddings.return_value = mock_embeddings - - mock_vector_processor = VectorServiceTestDataFactory.create_vector_processor_mock() - - mock_vector_processor.text_exists = Mock(return_value=False) - - mock_init_vector.return_value = mock_vector_processor - - vector = Vector(dataset=dataset) - - # Act - result = vector.text_exists("doc-123") - - # Assert - assert result is False - - mock_vector_processor.text_exists.assert_called_once_with("doc-123") - - # ======================================================================== - # Tests for Vector.delete_by_ids - # ======================================================================== - - @patch("core.rag.datasource.vdb.vector_factory.Vector._init_vector") - @patch("core.rag.datasource.vdb.vector_factory.Vector._get_embeddings") - def test_vector_delete_by_ids(self, mock_get_embeddings, mock_init_vector): - """ - Test Vector.delete_by_ids. - - This test verifies that documents are correctly deleted by their IDs. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock() - - mock_embeddings = Mock() - - mock_get_embeddings.return_value = mock_embeddings - - mock_vector_processor = VectorServiceTestDataFactory.create_vector_processor_mock() - - mock_init_vector.return_value = mock_vector_processor - - vector = Vector(dataset=dataset) - - ids = ["doc-1", "doc-2", "doc-3"] - - # Act - vector.delete_by_ids(ids) - - # Assert - mock_vector_processor.delete_by_ids.assert_called_once_with(ids) - - # ======================================================================== - # Tests for Vector.delete_by_metadata_field - # ======================================================================== - - @patch("core.rag.datasource.vdb.vector_factory.Vector._init_vector") - @patch("core.rag.datasource.vdb.vector_factory.Vector._get_embeddings") - def test_vector_delete_by_metadata_field(self, mock_get_embeddings, mock_init_vector): - """ - Test Vector.delete_by_metadata_field. - - This test verifies that documents are correctly deleted by metadata - field value. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock() - - mock_embeddings = Mock() - - mock_get_embeddings.return_value = mock_embeddings - - mock_vector_processor = VectorServiceTestDataFactory.create_vector_processor_mock() - - mock_init_vector.return_value = mock_vector_processor - - vector = Vector(dataset=dataset) - - # Act - vector.delete_by_metadata_field("dataset_id", "dataset-123") - - # Assert - mock_vector_processor.delete_by_metadata_field.assert_called_once_with("dataset_id", "dataset-123") - - # ======================================================================== - # Tests for Vector.search_by_vector - # ======================================================================== - - @patch("core.rag.datasource.vdb.vector_factory.Vector._init_vector") - @patch("core.rag.datasource.vdb.vector_factory.Vector._get_embeddings") - def test_vector_search_by_vector(self, mock_get_embeddings, mock_init_vector): - """ - Test Vector.search_by_vector. - - This test verifies that vector search correctly embeds the query - and searches the vector store. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock() - - query = "test query" - - query_vector = [0.1] * 1536 - - mock_embeddings = Mock() - - mock_embeddings.embed_query = Mock(return_value=query_vector) - - mock_get_embeddings.return_value = mock_embeddings - - mock_vector_processor = VectorServiceTestDataFactory.create_vector_processor_mock() - - mock_vector_processor.search_by_vector = Mock(return_value=[]) - - mock_init_vector.return_value = mock_vector_processor - - vector = Vector(dataset=dataset) - - # Act - result = vector.search_by_vector(query) - - # Assert - mock_embeddings.embed_query.assert_called_once_with(query) - - mock_vector_processor.search_by_vector.assert_called_once_with(query_vector) - - assert result == [] - - # ======================================================================== - # Tests for Vector.search_by_full_text - # ======================================================================== - - @patch("core.rag.datasource.vdb.vector_factory.Vector._init_vector") - @patch("core.rag.datasource.vdb.vector_factory.Vector._get_embeddings") - def test_vector_search_by_full_text(self, mock_get_embeddings, mock_init_vector): - """ - Test Vector.search_by_full_text. - - This test verifies that full-text search correctly searches the - vector store without embedding the query. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock() - - query = "test query" - - mock_embeddings = Mock() - - mock_get_embeddings.return_value = mock_embeddings - - mock_vector_processor = VectorServiceTestDataFactory.create_vector_processor_mock() - - mock_vector_processor.search_by_full_text = Mock(return_value=[]) - - mock_init_vector.return_value = mock_vector_processor - - vector = Vector(dataset=dataset) - - # Act - result = vector.search_by_full_text(query) - - # Assert - mock_vector_processor.search_by_full_text.assert_called_once_with(query) - - assert result == [] - - # ======================================================================== - # Tests for Vector.delete - # ======================================================================== - - @patch("core.rag.datasource.vdb.vector_factory.redis_client") - @patch("core.rag.datasource.vdb.vector_factory.Vector._init_vector") - @patch("core.rag.datasource.vdb.vector_factory.Vector._get_embeddings") - def test_vector_delete(self, mock_get_embeddings, mock_init_vector, mock_redis_client): - """ - Test Vector.delete. - - This test verifies that the collection is deleted and Redis cache - is cleared. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock() - - mock_embeddings = Mock() - - mock_get_embeddings.return_value = mock_embeddings - - mock_vector_processor = VectorServiceTestDataFactory.create_vector_processor_mock() - - mock_vector_processor.collection_name = "test_collection" - - mock_init_vector.return_value = mock_vector_processor - - vector = Vector(dataset=dataset) - - # Act - vector.delete() - - # Assert - mock_vector_processor.delete.assert_called_once() - - mock_redis_client.delete.assert_called_once_with("vector_indexing_test_collection") - - # ======================================================================== - # Tests for Vector.get_vector_factory - # ======================================================================== - - def test_vector_get_vector_factory_chroma(self): - """ - Test Vector.get_vector_factory for Chroma. - - This test verifies that the correct factory class is returned for - Chroma vector type. - """ - # Act - factory_class = Vector.get_vector_factory(VectorType.CHROMA) - - # Assert - assert factory_class is not None - - # Verify it's the correct factory by checking the module name - assert "chroma" in factory_class.__module__.lower() - - def test_vector_get_vector_factory_milvus(self): - """ - Test Vector.get_vector_factory for Milvus. - - This test verifies that the correct factory class is returned for - Milvus vector type. - """ - # Act - factory_class = Vector.get_vector_factory(VectorType.MILVUS) - - # Assert - assert factory_class is not None - - assert "milvus" in factory_class.__module__.lower() - - def test_vector_get_vector_factory_invalid_type(self): - """ - Test Vector.get_vector_factory with invalid vector type. - - This test verifies that a ValueError is raised when an invalid - vector type is provided. - """ - # Act & Assert - with pytest.raises(ValueError, match="Vector store .* is not supported"): - Vector.get_vector_factory("invalid_type") - - # ======================================================================== - # Tests for Vector._filter_duplicate_texts - # ======================================================================== - - @patch("core.rag.datasource.vdb.vector_factory.Vector._init_vector") - @patch("core.rag.datasource.vdb.vector_factory.Vector._get_embeddings") - def test_vector_filter_duplicate_texts(self, mock_get_embeddings, mock_init_vector): - """ - Test Vector._filter_duplicate_texts. - - This test verifies that duplicate documents are correctly filtered - based on doc_id in metadata. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock() - - mock_embeddings = Mock() - - mock_get_embeddings.return_value = mock_embeddings - - mock_vector_processor = VectorServiceTestDataFactory.create_vector_processor_mock() - - mock_vector_processor.text_exists = Mock(side_effect=[True, False]) # First exists, second doesn't - - mock_init_vector.return_value = mock_vector_processor - - vector = Vector(dataset=dataset) - - doc1 = VectorServiceTestDataFactory.create_rag_document_mock(doc_id="doc-1") - - doc2 = VectorServiceTestDataFactory.create_rag_document_mock(doc_id="doc-2") - - documents = [doc1, doc2] - - # Act - filtered = vector._filter_duplicate_texts(documents) - - # Assert - assert len(filtered) == 1 - - assert filtered[0].metadata["doc_id"] == "doc-2" - - @patch("core.rag.datasource.vdb.vector_factory.Vector._init_vector") - @patch("core.rag.datasource.vdb.vector_factory.Vector._get_embeddings") - def test_vector_filter_duplicate_texts_no_metadata(self, mock_get_embeddings, mock_init_vector): - """ - Test Vector._filter_duplicate_texts with documents without metadata. - - This test verifies that documents without metadata are not filtered. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock() - - mock_embeddings = Mock() - - mock_get_embeddings.return_value = mock_embeddings - - mock_vector_processor = VectorServiceTestDataFactory.create_vector_processor_mock() - - mock_init_vector.return_value = mock_vector_processor - - vector = Vector(dataset=dataset) - - doc1 = Document(page_content="Content 1", metadata=None) - - doc2 = Document(page_content="Content 2", metadata={}) - - documents = [doc1, doc2] - - # Act - filtered = vector._filter_duplicate_texts(documents) - - # Assert - assert len(filtered) == 2 - - # ======================================================================== - # Tests for Vector._get_embeddings - # ======================================================================== - - @patch("core.rag.datasource.vdb.vector_factory.CacheEmbedding") - @patch("core.rag.datasource.vdb.vector_factory.ModelManager.for_tenant") - @patch("core.rag.datasource.vdb.vector_factory.Vector._init_vector") - def test_vector_get_embeddings(self, mock_init_vector, mock_model_manager, mock_cache_embedding): - """ - Test Vector._get_embeddings. - - This test verifies that embeddings are correctly retrieved from - ModelManager and wrapped in CacheEmbedding. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock( - embedding_model_provider="openai", embedding_model="text-embedding-ada-002" - ) - - mock_embedding_model = VectorServiceTestDataFactory.create_embedding_model_instance_mock() - - mock_model_manager.return_value.get_model_instance.return_value = mock_embedding_model - - mock_cache_embedding_instance = Mock() - - mock_cache_embedding.return_value = mock_cache_embedding_instance - - mock_vector_processor = VectorServiceTestDataFactory.create_vector_processor_mock() - - mock_init_vector.return_value = mock_vector_processor - - # Act - vector = Vector(dataset=dataset) - - # Assert - mock_model_manager.return_value.get_model_instance.assert_called_once() - - mock_cache_embedding.assert_called_once_with(mock_embedding_model) - - assert vector._embeddings == mock_cache_embedding_instance diff --git a/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py b/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py index d570dce107..dfdbd9acd6 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py @@ -1,14 +1,20 @@ import json import queue -from collections.abc import Sequence +from collections.abc import Mapping, Sequence from dataclasses import dataclass from datetime import UTC, datetime +from itertools import cycle from threading import Event +from types import SimpleNamespace +from typing import Any, cast +from unittest.mock import MagicMock import pytest +from sqlalchemy.orm import Session, sessionmaker from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity +from core.app.entities.task_entities import StreamEvent from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext, _WorkflowGenerateEntityWrapper from graphon.entities.pause_reason import HumanInputRequired from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus @@ -18,11 +24,14 @@ from models.model import AppMode from models.workflow import WorkflowRun from repositories.api_workflow_node_execution_repository import WorkflowNodeExecutionSnapshot from repositories.entities.workflow_pause import WorkflowPauseEntity +from services import workflow_event_snapshot_service as service_module from services.workflow_event_snapshot_service import ( BufferState, MessageContext, _build_snapshot_events, + _is_terminal_event, _resolve_task_id, + build_workflow_event_stream, ) @@ -125,50 +134,6 @@ def _build_resumption_context(task_id: str) -> WorkflowResumptionContext: ) -def test_build_snapshot_events_includes_pause_event() -> None: - workflow_run = _build_workflow_run(WorkflowExecutionStatus.PAUSED) - snapshot = _build_snapshot(WorkflowNodeExecutionStatus.PAUSED) - resumption_context = _build_resumption_context("task-ctx") - pause_entity = _FakePauseEntity( - pause_id="pause-1", - workflow_run_id="run-1", - paused_at_value=datetime(2024, 1, 1, tzinfo=UTC), - pause_reasons=[ - HumanInputRequired( - form_id="form-1", - form_content="content", - node_id="node-1", - node_title="Human Input", - ) - ], - ) - - events = _build_snapshot_events( - workflow_run=workflow_run, - node_snapshots=[snapshot], - task_id="task-ctx", - message_context=None, - pause_entity=pause_entity, - resumption_context=resumption_context, - ) - - assert [event["event"] for event in events] == [ - "workflow_started", - "node_started", - "node_finished", - "workflow_paused", - ] - assert events[2]["data"]["status"] == WorkflowNodeExecutionStatus.PAUSED.value - pause_data = events[-1]["data"] - assert pause_data["paused_nodes"] == ["node-1"] - assert pause_data["outputs"] == {"result": "value"} - assert pause_data["status"] == WorkflowExecutionStatus.PAUSED.value - assert pause_data["created_at"] == int(workflow_run.created_at.timestamp()) - assert pause_data["elapsed_time"] == workflow_run.elapsed_time - assert pause_data["total_tokens"] == workflow_run.total_tokens - assert pause_data["total_steps"] == workflow_run.total_steps - - def test_build_snapshot_events_applies_message_context() -> None: workflow_run = _build_workflow_run(WorkflowExecutionStatus.RUNNING) snapshot = _build_snapshot(WorkflowNodeExecutionStatus.SUCCEEDED) @@ -222,3 +187,656 @@ def test_resolve_task_id_priority(context_task_id, buffered_task_id, expected) - buffer_state.task_id_ready.set() task_id = _resolve_task_id(resumption_context, buffer_state, "run-1", wait_timeout=0.0) assert task_id == expected + + +def _build_workflow_run_additional(status: WorkflowExecutionStatus = WorkflowExecutionStatus.RUNNING) -> WorkflowRun: + return WorkflowRun( + id="run-1", + tenant_id="tenant-1", + app_id="app-1", + workflow_id="workflow-1", + type="workflow", + triggered_from="app-run", + version="v1", + graph=None, + inputs=json.dumps({"query": "hello"}), + status=status, + outputs=json.dumps({}), + error=None, + elapsed_time=1.2, + total_tokens=5, + total_steps=2, + created_by_role=CreatorUserRole.END_USER, + created_by="user-1", + created_at=datetime(2024, 1, 1, tzinfo=UTC), + ) + + +def _build_resumption_context_additional(task_id: str) -> WorkflowResumptionContext: + app_config = WorkflowUIBasedAppConfig( + tenant_id="tenant-1", + app_id="app-1", + app_mode=AppMode.WORKFLOW, + workflow_id="workflow-1", + ) + generate_entity = WorkflowAppGenerateEntity( + task_id=task_id, + app_config=app_config, + inputs={}, + files=[], + user_id="user-1", + stream=True, + invoke_from=InvokeFrom.EXPLORE, + call_depth=0, + workflow_execution_id="run-1", + ) + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0) + runtime_state.outputs = {"answer": "ok"} + wrapper = _WorkflowGenerateEntityWrapper(entity=generate_entity) + return WorkflowResumptionContext( + generate_entity=wrapper, + serialized_graph_runtime_state=runtime_state.dumps(), + ) + + +class _SessionContext: + def __init__(self, session: Any) -> None: + self._session = session + + def __enter__(self) -> Any: + return self._session + + def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool: + return False + + +class _SessionMaker: + def __init__(self, session: Any) -> None: + self._session = session + + def __call__(self) -> _SessionContext: + return _SessionContext(self._session) + + +class _SubscriptionContext: + def __init__(self, subscription: Any) -> None: + self._subscription = subscription + + def __enter__(self) -> Any: + return self._subscription + + def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool: + return False + + +class _Topic: + def __init__(self, subscription: Any) -> None: + self._subscription = subscription + + def subscribe(self) -> _SubscriptionContext: + return _SubscriptionContext(self._subscription) + + +class _StaticSubscription: + def receive(self, timeout: int = 1) -> None: + return None + + +@dataclass(frozen=True) +class _PauseEntity(WorkflowPauseEntity): + state: bytes + + @property + def id(self) -> str: + return "pause-1" + + @property + def workflow_execution_id(self) -> str: + return "run-1" + + @property + def resumed_at(self) -> datetime | None: + return None + + @property + def paused_at(self) -> datetime: + return datetime(2024, 1, 1, tzinfo=UTC) + + def get_state(self) -> bytes: + return self.state + + def get_pause_reasons(self) -> list[Any]: + return [] + + +def test_get_message_context_should_return_none_when_no_message() -> None: + # Arrange + session = SimpleNamespace(scalar=MagicMock(return_value=None)) + session_maker = _SessionMaker(session) + + # Act + result = service_module._get_message_context(cast(sessionmaker[Session], session_maker), "run-1") + + # Assert + assert result is None + + +def test_get_message_context_should_default_created_at_to_zero_when_message_has_no_timestamp() -> None: + # Arrange + message = SimpleNamespace( + id="msg-1", + conversation_id="conv-1", + created_at=None, + answer="answer", + ) + session = SimpleNamespace(scalar=MagicMock(return_value=message)) + session_maker = _SessionMaker(session) + + # Act + result = service_module._get_message_context(cast(sessionmaker[Session], session_maker), "run-1") + + # Assert + assert result is not None + assert result.created_at == 0 + assert result.message_id == "msg-1" + assert result.conversation_id == "conv-1" + assert result.answer == "answer" + + +def test_load_resumption_context_should_return_none_when_pause_entity_missing() -> None: + # Arrange + + # Act + result = service_module._load_resumption_context(None) + + # Assert + assert result is None + + +def test_load_resumption_context_should_return_none_when_pause_entity_state_is_invalid() -> None: + # Arrange + pause_entity = _PauseEntity(state=b"not-a-valid-state") + + # Act + result = service_module._load_resumption_context(pause_entity) + + # Assert + assert result is None + + +def test_load_resumption_context_should_parse_valid_state_into_context() -> None: + # Arrange + context = _build_resumption_context_additional(task_id="task-ctx") + pause_entity = _PauseEntity(state=context.dumps().encode()) + + # Act + result = service_module._load_resumption_context(pause_entity) + + # Assert + assert result is not None + assert result.get_generate_entity().task_id == "task-ctx" + + +def test_resolve_task_id_should_return_workflow_run_id_when_buffer_state_is_missing() -> None: + # Arrange + + # Act + result = service_module._resolve_task_id( + resumption_context=None, + buffer_state=None, + workflow_run_id="run-1", + ) + + # Assert + assert result == "run-1" + + +@pytest.mark.parametrize( + ("payload", "expected"), + [ + (b'{"event":"node_started"}', {"event": "node_started"}), + (b"invalid-json", None), + (b"[]", None), + ], +) +def test_parse_event_message_should_parse_only_json_object( + payload: bytes, + expected: dict[str, Any] | None, +) -> None: + # Arrange + + # Act + result = service_module._parse_event_message(payload) + + # Assert + assert result == expected + + +def test_is_terminal_event_should_recognize_finished_and_optional_paused_events() -> None: + # Arrange + finished_event = {"event": StreamEvent.WORKFLOW_FINISHED.value} + paused_event = {"event": StreamEvent.WORKFLOW_PAUSED.value} + + # Act + is_finished = service_module._is_terminal_event(finished_event, close_on_pause=False) + paused_without_flag = service_module._is_terminal_event(paused_event, close_on_pause=False) + paused_with_flag = service_module._is_terminal_event(paused_event, close_on_pause=True) + + # Assert + assert is_finished is True + assert paused_without_flag is False + assert paused_with_flag is True + assert service_module._is_terminal_event(StreamEvent.PING.value, close_on_pause=True) is False + + +def test_apply_message_context_should_update_payload_when_context_exists() -> None: + # Arrange + payload: dict[str, Any] = {"event": "workflow_started"} + context = MessageContext(conversation_id="conv-1", message_id="msg-1", created_at=1700000000) + + # Act + service_module._apply_message_context(payload, context) + + # Assert + assert payload["conversation_id"] == "conv-1" + assert payload["message_id"] == "msg-1" + assert payload["created_at"] == 1700000000 + + +def test_start_buffering_should_capture_task_id_and_enqueue_event() -> None: + # Arrange + class Subscription: + def __init__(self) -> None: + self._calls = 0 + + def receive(self, timeout: int = 1) -> bytes | None: + self._calls += 1 + if self._calls == 1: + return b'{"event":"node_started","task_id":"task-1"}' + return None + + subscription = Subscription() + + # Act + buffer_state = service_module._start_buffering(subscription) + ready = buffer_state.task_id_ready.wait(timeout=1) + event = buffer_state.queue.get(timeout=1) + buffer_state.stop_event.set() + finished = buffer_state.done_event.wait(timeout=1) + + # Assert + assert ready is True + assert finished is True + assert buffer_state.task_id_hint == "task-1" + assert event["event"] == "node_started" + + +def test_start_buffering_should_drop_old_event_when_queue_is_full( + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + class QueueWithSingleFull: + def __init__(self) -> None: + self._first_put = True + self.items: list[dict[str, Any]] = [{"event": "old"}] + + def put_nowait(self, item: dict[str, Any]) -> None: + if self._first_put: + self._first_put = False + raise queue.Full + self.items.append(item) + + def get_nowait(self) -> dict[str, Any]: + if not self.items: + raise queue.Empty + return self.items.pop(0) + + def empty(self) -> bool: + return len(self.items) == 0 + + fake_queue = QueueWithSingleFull() + monkeypatch.setattr(service_module.queue, "Queue", lambda maxsize=2048: fake_queue) + + class Subscription: + def __init__(self) -> None: + self._calls = 0 + + def receive(self, timeout: int = 1) -> bytes | None: + self._calls += 1 + if self._calls == 1: + return b'{"event":"node_started","task_id":"task-2"}' + return None + + subscription = Subscription() + + # Act + buffer_state = service_module._start_buffering(subscription) + ready = buffer_state.task_id_ready.wait(timeout=1) + buffer_state.stop_event.set() + finished = buffer_state.done_event.wait(timeout=1) + + # Assert + assert ready is True + assert finished is True + assert fake_queue.items[-1]["task_id"] == "task-2" + + +def test_start_buffering_should_set_done_event_when_subscription_raises() -> None: + # Arrange + class Subscription: + def receive(self, timeout: int = 1) -> bytes | None: + raise RuntimeError("subscription failure") + + subscription = Subscription() + + # Act + buffer_state = service_module._start_buffering(subscription) + finished = buffer_state.done_event.wait(timeout=1) + + # Assert + assert finished is True + + +def test_build_workflow_event_stream_should_emit_ping_and_terminal_snapshot_event( + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + workflow_run = _build_workflow_run_additional(status=WorkflowExecutionStatus.RUNNING) + topic = _Topic(_StaticSubscription()) + workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock()) + node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[])) + factory = SimpleNamespace( + create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo), + create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo), + ) + monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory) + monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic)) + monkeypatch.setattr( + service_module, + "_get_message_context", + MagicMock(return_value=MessageContext("conv-1", "msg-1", 1700000000)), + ) + monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None)) + buffer_state = BufferState( + queue=queue.Queue(), + stop_event=Event(), + done_event=Event(), + task_id_ready=Event(), + task_id_hint="task-1", + ) + monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state)) + monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1")) + monkeypatch.setattr( + service_module, + "_build_snapshot_events", + MagicMock(return_value=[{"event": StreamEvent.WORKFLOW_FINISHED.value, "task_id": "task-1"}]), + ) + + # Act + events = list( + build_workflow_event_stream( + app_mode=AppMode.ADVANCED_CHAT, + workflow_run=workflow_run, + tenant_id="tenant-1", + app_id="app-1", + session_maker=MagicMock(), + ) + ) + + # Assert + assert events[0] == StreamEvent.PING.value + finished_event = cast(Mapping[str, Any], events[1]) + assert finished_event["event"] == StreamEvent.WORKFLOW_FINISHED.value + assert buffer_state.stop_event.is_set() is True + node_repo.get_execution_snapshots_by_workflow_run.assert_called_once() + called_kwargs = node_repo.get_execution_snapshots_by_workflow_run.call_args.kwargs + assert called_kwargs["workflow_run_id"] == "run-1" + + +def test_build_workflow_event_stream_should_emit_periodic_ping_and_stop_after_idle_timeout( + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + workflow_run = _build_workflow_run_additional(status=WorkflowExecutionStatus.RUNNING) + topic = _Topic(_StaticSubscription()) + workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock()) + node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[])) + factory = SimpleNamespace( + create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo), + create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo), + ) + monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory) + monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic)) + monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None)) + monkeypatch.setattr(service_module, "_build_snapshot_events", MagicMock(return_value=[])) + monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1")) + + class AlwaysEmptyQueue: + def empty(self) -> bool: + return False + + def get(self, timeout: int = 1) -> None: + raise queue.Empty + + buffer_state = BufferState( + queue=AlwaysEmptyQueue(), # type: ignore[arg-type] + stop_event=Event(), + done_event=Event(), + task_id_ready=Event(), + task_id_hint="task-1", + ) + monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state)) + time_values = cycle([0.0, 6.0, 21.0, 26.0]) + monkeypatch.setattr(service_module.time, "time", lambda: next(time_values)) + + # Act + events = list( + build_workflow_event_stream( + app_mode=AppMode.WORKFLOW, + workflow_run=workflow_run, + tenant_id="tenant-1", + app_id="app-1", + session_maker=MagicMock(), + idle_timeout=20.0, + ping_interval=5.0, + ) + ) + + # Assert + assert events == [StreamEvent.PING.value, StreamEvent.PING.value] + assert buffer_state.stop_event.is_set() is True + + +def test_build_workflow_event_stream_should_exit_when_buffer_done_and_empty( + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + workflow_run = _build_workflow_run_additional(status=WorkflowExecutionStatus.RUNNING) + topic = _Topic(_StaticSubscription()) + workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock()) + node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[])) + factory = SimpleNamespace( + create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo), + create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo), + ) + monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory) + monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic)) + monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None)) + monkeypatch.setattr(service_module, "_build_snapshot_events", MagicMock(return_value=[])) + monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1")) + buffer_state = BufferState( + queue=queue.Queue(), + stop_event=Event(), + done_event=Event(), + task_id_ready=Event(), + task_id_hint="task-1", + ) + buffer_state.done_event.set() + monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state)) + + # Act + events = list( + build_workflow_event_stream( + app_mode=AppMode.WORKFLOW, + workflow_run=workflow_run, + tenant_id="tenant-1", + app_id="app-1", + session_maker=MagicMock(), + ) + ) + + # Assert + assert events == [StreamEvent.PING.value] + assert buffer_state.stop_event.is_set() is True + + +def test_build_workflow_event_stream_should_continue_when_pause_loading_fails( + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + workflow_run = _build_workflow_run_additional(status=WorkflowExecutionStatus.PAUSED) + topic = _Topic(_StaticSubscription()) + workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock(side_effect=RuntimeError("boom"))) + node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[])) + factory = SimpleNamespace( + create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo), + create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo), + ) + monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory) + monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic)) + monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None)) + monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1")) + snapshot_builder = MagicMock(return_value=[{"event": StreamEvent.WORKFLOW_FINISHED.value}]) + monkeypatch.setattr(service_module, "_build_snapshot_events", snapshot_builder) + buffer_state = BufferState( + queue=queue.Queue(), + stop_event=Event(), + done_event=Event(), + task_id_ready=Event(), + task_id_hint="task-1", + ) + monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state)) + + # Act + events = list( + build_workflow_event_stream( + app_mode=AppMode.WORKFLOW, + workflow_run=workflow_run, + tenant_id="tenant-1", + app_id="app-1", + session_maker=MagicMock(), + ) + ) + + # Assert + assert events[0] == StreamEvent.PING.value + assert snapshot_builder.call_args.kwargs["pause_entity"] is None + + +def test_is_terminal_event_respects_close_on_pause_flag() -> None: + pause_event = {"event": "workflow_paused"} + finish_event = {"event": "workflow_finished"} + + assert _is_terminal_event(pause_event, close_on_pause=True) is True + assert _is_terminal_event(pause_event, close_on_pause=False) is False + assert _is_terminal_event(finish_event, close_on_pause=False) is True + + +def test_build_snapshot_events_preserves_public_form_token(monkeypatch: pytest.MonkeyPatch) -> None: + workflow_run = _build_workflow_run(WorkflowExecutionStatus.PAUSED) + snapshot = _build_snapshot(WorkflowNodeExecutionStatus.PAUSED) + resumption_context = _build_resumption_context("task-ctx") + monkeypatch.setattr( + service_module, "load_form_tokens_by_form_id", lambda form_ids, session=None, surface=None: {"form-1": "wtok"} + ) + session_maker = _SessionMaker( + SimpleNamespace( + execute=lambda _stmt: [("form-1", datetime(2024, 1, 1, tzinfo=UTC), '{"display_in_ui": true}')], + ) + ) + pause_entity = _FakePauseEntity( + pause_id="pause-1", + workflow_run_id="run-1", + paused_at_value=datetime(2024, 1, 1, tzinfo=UTC), + pause_reasons=[ + HumanInputRequired( + form_id="form-1", + form_content="content", + node_id="node-1", + node_title="Human Input", + form_token="wtok", + ) + ], + ) + + events = _build_snapshot_events( + workflow_run=workflow_run, + node_snapshots=[snapshot], + task_id="task-ctx", + message_context=None, + pause_entity=pause_entity, + resumption_context=resumption_context, + session_maker=cast(sessionmaker[Session], session_maker), + ) + + assert events[-2]["event"] == StreamEvent.HUMAN_INPUT_REQUIRED.value + assert events[-2]["data"]["form_token"] == "wtok" + assert events[-2]["data"]["expiration_time"] == int(datetime(2024, 1, 1, tzinfo=UTC).timestamp()) + pause_data = events[-1]["data"] + assert pause_data["reasons"][0]["form_token"] == "wtok" + assert pause_data["reasons"][0]["expiration_time"] == int(datetime(2024, 1, 1, tzinfo=UTC).timestamp()) + + +def test_build_workflow_event_stream_loads_pause_tokens_without_flask_app_context( + monkeypatch: pytest.MonkeyPatch, +) -> None: + workflow_run = _build_workflow_run_additional(status=WorkflowExecutionStatus.PAUSED) + topic = _Topic(_StaticSubscription()) + pause_entity = _FakePauseEntity( + pause_id="pause-1", + workflow_run_id="run-1", + paused_at_value=datetime(2024, 1, 1, tzinfo=UTC), + pause_reasons=[ + HumanInputRequired( + form_id="form-1", + form_content="content", + node_id="node-1", + node_title="Human Input", + ) + ], + ) + workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock(return_value=pause_entity)) + node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[])) + factory = SimpleNamespace( + create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo), + create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo), + ) + monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory) + monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic)) + monkeypatch.setattr( + service_module, "_load_resumption_context", MagicMock(return_value=_build_resumption_context("task-1")) + ) + monkeypatch.setattr( + service_module, "load_form_tokens_by_form_id", lambda form_ids, session=None, surface=None: {"form-1": "wtok"} + ) + + session = SimpleNamespace( + scalar=MagicMock(return_value=None), + execute=lambda _stmt: [("form-1", datetime(2024, 1, 1, tzinfo=UTC), '{"display_in_ui": true}')], + ) + session_maker = _SessionMaker(session) + + events = list( + build_workflow_event_stream( + app_mode=AppMode.WORKFLOW, + workflow_run=workflow_run, + tenant_id="tenant-1", + app_id="app-1", + session_maker=cast(sessionmaker[Session], session_maker), + ) + ) + + pause_event = cast(Mapping[str, Any], events[-1]) + assert pause_event["event"] == StreamEvent.WORKFLOW_PAUSED.value + assert pause_event["data"]["reasons"][0]["form_token"] == "wtok" + assert pause_event["data"]["reasons"][0]["expiration_time"] == int(datetime(2024, 1, 1, tzinfo=UTC).timestamp()) diff --git a/api/tests/unit_tests/tasks/test_dataset_indexing_task.py b/api/tests/unit_tests/tasks/test_dataset_indexing_task.py index 5dad58b8f1..b74079bd69 100644 --- a/api/tests/unit_tests/tasks/test_dataset_indexing_task.py +++ b/api/tests/unit_tests/tasks/test_dataset_indexing_task.py @@ -89,9 +89,6 @@ def mock_db_session(): session = MagicMock() session._shared_data = {"dataset": None, "documents": []} - # Keep a pointer so repeated Document.first() calls iterate across provided docs - session._doc_first_idx = 0 - def _get_entity(stmt) -> type | None: """Extract the mapped entity class from a SQLAlchemy select statement.""" try: @@ -1591,18 +1588,7 @@ class TestDocumentIndexingTaskSummaryFlow: need_summary=True, ) - dataset_query = MagicMock() - dataset_query.where.return_value = dataset_query - dataset_query.first.return_value = dataset - phase1_docs = [SimpleNamespace(id="doc-1"), SimpleNamespace(id="doc-2"), SimpleNamespace(id="doc-3")] - phase1_document_query = MagicMock() - phase1_document_query.where.return_value = phase1_document_query - phase1_document_query.all.return_value = phase1_docs - - summary_document_query = MagicMock() - summary_document_query.where.return_value = summary_document_query - summary_document_query.all.return_value = [doc_eligible, doc_skip_form, doc_skip_status] session1 = MagicMock() session2 = MagicMock() @@ -1657,18 +1643,6 @@ class TestDocumentIndexingTaskSummaryFlow: need_summary=True, ) - dataset_query = MagicMock() - dataset_query.where.return_value = dataset_query - dataset_query.first.return_value = dataset - - phase1_query = MagicMock() - phase1_query.where.return_value = phase1_query - phase1_query.all.return_value = [SimpleNamespace(id="doc-1")] - - summary_query = MagicMock() - summary_query.where.return_value = summary_query - summary_query.all.return_value = [doc_eligible] - session1 = MagicMock() session2 = MagicMock() session2.begin.return_value = nullcontext() diff --git a/api/tests/unit_tests/tasks/test_workflow_execute_task.py b/api/tests/unit_tests/tasks/test_workflow_execute_task.py index d3cf632b47..72508bef52 100644 --- a/api/tests/unit_tests/tasks/test_workflow_execute_task.py +++ b/api/tests/unit_tests/tasks/test_workflow_execute_task.py @@ -7,11 +7,17 @@ from unittest.mock import MagicMock import pytest -from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity from models.enums import CreatorUserRole from models.model import App, AppMode, Conversation from models.workflow import Workflow, WorkflowRun -from tasks.app_generate.workflow_execute_task import _publish_streaming_response, _resume_app_execution +from repositories.sqlalchemy_api_workflow_run_repository import _WorkflowRunError +from tasks.app_generate.workflow_execute_task import ( + _publish_streaming_response, + _resume_advanced_chat, + _resume_app_execution, + _resume_workflow, +) class _FakeSessionContext: @@ -38,12 +44,28 @@ def _build_advanced_chat_generate_entity(conversation_id: str | None) -> Advance ) +def _build_workflow_generate_entity(stream: bool) -> WorkflowAppGenerateEntity: + return WorkflowAppGenerateEntity( + task_id="task-id", + inputs={}, + files=[], + user_id="user-id", + stream=stream, + invoke_from=InvokeFrom.WEB_APP, + workflow_execution_id="workflow-run-id", + ) + + +def _single_event_generator(payload): + yield payload + + @pytest.fixture -def mock_topic(mocker) -> MagicMock: +def mock_topic(monkeypatch: pytest.MonkeyPatch) -> MagicMock: topic = MagicMock() - mocker.patch( + monkeypatch.setattr( "tasks.app_generate.workflow_execute_task.MessageBasedAppGenerator.get_response_topic", - return_value=topic, + lambda *_args, **_kwargs: topic, ) return topic @@ -67,31 +89,35 @@ def test_publish_streaming_response_coerces_string_uuid(mock_topic: MagicMock): mock_topic.publish.assert_called_once_with(json.dumps({"event": "bar"}).encode()) -def test_resume_app_execution_queries_message_by_conversation_and_workflow_run(mocker): +def test_resume_app_execution_queries_message_by_conversation_and_workflow_run(monkeypatch: pytest.MonkeyPatch): workflow_run_id = "run-id" conversation_id = "conversation-id" message = MagicMock() - mocker.patch("tasks.app_generate.workflow_execute_task.db", SimpleNamespace(engine=object())) + monkeypatch.setattr("tasks.app_generate.workflow_execute_task.db", SimpleNamespace(engine=object())) pause_entity = MagicMock() pause_entity.get_state.return_value = b"state" workflow_run_repo = MagicMock() workflow_run_repo.get_workflow_pause.return_value = pause_entity - mocker.patch( + monkeypatch.setattr( "tasks.app_generate.workflow_execute_task.DifyAPIRepositoryFactory.create_api_workflow_run_repository", - return_value=workflow_run_repo, + lambda *_args, **_kwargs: workflow_run_repo, ) generate_entity = _build_advanced_chat_generate_entity(conversation_id) resumption_context = MagicMock() resumption_context.serialized_graph_runtime_state = "{}" resumption_context.get_generate_entity.return_value = generate_entity - mocker.patch( - "tasks.app_generate.workflow_execute_task.WorkflowResumptionContext.loads", return_value=resumption_context + monkeypatch.setattr( + "tasks.app_generate.workflow_execute_task.WorkflowResumptionContext.loads", + lambda *_args, **_kwargs: resumption_context, + ) + monkeypatch.setattr( + "tasks.app_generate.workflow_execute_task.GraphRuntimeState.from_snapshot", + lambda *_args, **_kwargs: MagicMock(), ) - mocker.patch("tasks.app_generate.workflow_execute_task.GraphRuntimeState.from_snapshot", return_value=MagicMock()) workflow_run = SimpleNamespace( workflow_id="wf-id", @@ -120,10 +146,15 @@ def test_resume_app_execution_queries_message_by_conversation_and_workflow_run(m session.get.side_effect = _session_get session.scalar.return_value = message - mocker.patch("tasks.app_generate.workflow_execute_task.Session", return_value=_FakeSessionContext(session)) - mocker.patch("tasks.app_generate.workflow_execute_task._resolve_user_for_run", return_value=MagicMock()) - resume_advanced_chat = mocker.patch("tasks.app_generate.workflow_execute_task._resume_advanced_chat") - mocker.patch("tasks.app_generate.workflow_execute_task._resume_workflow") + monkeypatch.setattr( + "tasks.app_generate.workflow_execute_task.Session", lambda *_args, **_kwargs: _FakeSessionContext(session) + ) + monkeypatch.setattr( + "tasks.app_generate.workflow_execute_task._resolve_user_for_run", lambda *_args, **_kwargs: MagicMock() + ) + resume_advanced_chat = MagicMock() + monkeypatch.setattr("tasks.app_generate.workflow_execute_task._resume_advanced_chat", resume_advanced_chat) + monkeypatch.setattr("tasks.app_generate.workflow_execute_task._resume_workflow", MagicMock()) _resume_app_execution({"workflow_run_id": workflow_run_id}) @@ -144,29 +175,35 @@ def test_resume_app_execution_queries_message_by_conversation_and_workflow_run(m assert resume_advanced_chat.call_args.kwargs["message"] is message -def test_resume_app_execution_returns_early_when_advanced_chat_missing_conversation_id(mocker): +def test_resume_app_execution_returns_early_when_advanced_chat_missing_conversation_id( + monkeypatch: pytest.MonkeyPatch, +): workflow_run_id = "run-id" - mocker.patch("tasks.app_generate.workflow_execute_task.db", SimpleNamespace(engine=object())) + monkeypatch.setattr("tasks.app_generate.workflow_execute_task.db", SimpleNamespace(engine=object())) pause_entity = MagicMock() pause_entity.get_state.return_value = b"state" workflow_run_repo = MagicMock() workflow_run_repo.get_workflow_pause.return_value = pause_entity - mocker.patch( + monkeypatch.setattr( "tasks.app_generate.workflow_execute_task.DifyAPIRepositoryFactory.create_api_workflow_run_repository", - return_value=workflow_run_repo, + lambda *_args, **_kwargs: workflow_run_repo, ) generate_entity = _build_advanced_chat_generate_entity(conversation_id=None) resumption_context = MagicMock() resumption_context.serialized_graph_runtime_state = "{}" resumption_context.get_generate_entity.return_value = generate_entity - mocker.patch( - "tasks.app_generate.workflow_execute_task.WorkflowResumptionContext.loads", return_value=resumption_context + monkeypatch.setattr( + "tasks.app_generate.workflow_execute_task.WorkflowResumptionContext.loads", + lambda *_args, **_kwargs: resumption_context, + ) + monkeypatch.setattr( + "tasks.app_generate.workflow_execute_task.GraphRuntimeState.from_snapshot", + lambda *_args, **_kwargs: MagicMock(), ) - mocker.patch("tasks.app_generate.workflow_execute_task.GraphRuntimeState.from_snapshot", return_value=MagicMock()) workflow_run = SimpleNamespace( workflow_id="wf-id", @@ -191,12 +228,152 @@ def test_resume_app_execution_returns_early_when_advanced_chat_missing_conversat session.get.side_effect = _session_get - mocker.patch("tasks.app_generate.workflow_execute_task.Session", return_value=_FakeSessionContext(session)) - mocker.patch("tasks.app_generate.workflow_execute_task._resolve_user_for_run", return_value=MagicMock()) - resume_advanced_chat = mocker.patch("tasks.app_generate.workflow_execute_task._resume_advanced_chat") + monkeypatch.setattr( + "tasks.app_generate.workflow_execute_task.Session", lambda *_args, **_kwargs: _FakeSessionContext(session) + ) + monkeypatch.setattr( + "tasks.app_generate.workflow_execute_task._resolve_user_for_run", lambda *_args, **_kwargs: MagicMock() + ) + resume_advanced_chat = MagicMock() + monkeypatch.setattr("tasks.app_generate.workflow_execute_task._resume_advanced_chat", resume_advanced_chat) _resume_app_execution({"workflow_run_id": workflow_run_id}) session.scalar.assert_not_called() workflow_run_repo.resume_workflow_pause.assert_not_called() resume_advanced_chat.assert_not_called() + + +def test_resume_advanced_chat_publishes_events_for_originally_blocking_runs(monkeypatch: pytest.MonkeyPatch): + generate_entity = _build_advanced_chat_generate_entity(conversation_id="conversation-id") + generate_entity.stream = False + + generator_instance = MagicMock() + response_stream = _single_event_generator({"event": "message"}) + generator_instance.resume.return_value = response_stream + monkeypatch.setattr( + "tasks.app_generate.workflow_execute_task.AdvancedChatAppGenerator", + lambda: generator_instance, + ) + + publish_streaming_response = MagicMock() + monkeypatch.setattr( + "tasks.app_generate.workflow_execute_task._publish_streaming_response", publish_streaming_response + ) + monkeypatch.setattr( + "tasks.app_generate.workflow_execute_task.DifyCoreRepositoryFactory.create_workflow_execution_repository", + lambda **kwargs: MagicMock(), + ) + monkeypatch.setattr( + "tasks.app_generate.workflow_execute_task.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", + lambda **kwargs: MagicMock(), + ) + + _resume_advanced_chat( + app_model=SimpleNamespace(id="app-id"), + workflow=SimpleNamespace(created_by="workflow-owner"), + user=MagicMock(), + conversation=SimpleNamespace(id="conversation-id"), + message=MagicMock(), + generate_entity=generate_entity, + graph_runtime_state=MagicMock(), + session_factory=MagicMock(), + pause_state_config=MagicMock(), + workflow_run_id="workflow-run-id", + workflow_run=SimpleNamespace(triggered_from="app_run"), + ) + + resumed_entity = generator_instance.resume.call_args.kwargs["application_generate_entity"] + assert resumed_entity.stream is True + publish_streaming_response.assert_called_once_with(response_stream, "workflow-run-id", AppMode.ADVANCED_CHAT) + + +def test_resume_workflow_publishes_events_for_originally_blocking_runs(monkeypatch: pytest.MonkeyPatch): + generate_entity = _build_workflow_generate_entity(stream=False) + + generator_instance = MagicMock() + response_stream = _single_event_generator({"event": "workflow_finished"}) + generator_instance.resume.return_value = response_stream + monkeypatch.setattr( + "tasks.app_generate.workflow_execute_task.WorkflowAppGenerator", + lambda: generator_instance, + ) + + publish_streaming_response = MagicMock() + monkeypatch.setattr( + "tasks.app_generate.workflow_execute_task._publish_streaming_response", publish_streaming_response + ) + monkeypatch.setattr( + "tasks.app_generate.workflow_execute_task.DifyCoreRepositoryFactory.create_workflow_execution_repository", + lambda **kwargs: MagicMock(), + ) + monkeypatch.setattr( + "tasks.app_generate.workflow_execute_task.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", + lambda **kwargs: MagicMock(), + ) + workflow_run_repo = MagicMock() + pause_entity = MagicMock() + + _resume_workflow( + app_model=SimpleNamespace(id="app-id"), + workflow=SimpleNamespace(created_by="workflow-owner"), + user=MagicMock(), + generate_entity=generate_entity, + graph_runtime_state=MagicMock(), + session_factory=MagicMock(), + pause_state_config=MagicMock(), + workflow_run_id="workflow-run-id", + workflow_run=SimpleNamespace(triggered_from="app_run"), + workflow_run_repo=workflow_run_repo, + pause_entity=pause_entity, + ) + + resumed_entity = generator_instance.resume.call_args.kwargs["application_generate_entity"] + assert resumed_entity.stream is True + publish_streaming_response.assert_called_once_with(response_stream, "workflow-run-id", AppMode.WORKFLOW) + workflow_run_repo.delete_workflow_pause.assert_called_once_with(pause_entity) + + +def test_resume_workflow_ignores_missing_old_pause_after_repause(monkeypatch: pytest.MonkeyPatch): + generate_entity = _build_workflow_generate_entity(stream=False) + + generator_instance = MagicMock() + response_stream = _single_event_generator({"event": "workflow_paused"}) + generator_instance.resume.return_value = response_stream + monkeypatch.setattr( + "tasks.app_generate.workflow_execute_task.WorkflowAppGenerator", + lambda: generator_instance, + ) + + publish_streaming_response = MagicMock() + monkeypatch.setattr( + "tasks.app_generate.workflow_execute_task._publish_streaming_response", publish_streaming_response + ) + monkeypatch.setattr( + "tasks.app_generate.workflow_execute_task.DifyCoreRepositoryFactory.create_workflow_execution_repository", + lambda **kwargs: MagicMock(), + ) + monkeypatch.setattr( + "tasks.app_generate.workflow_execute_task.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", + lambda **kwargs: MagicMock(), + ) + workflow_run_repo = MagicMock() + workflow_run_repo.delete_workflow_pause.side_effect = _WorkflowRunError("WorkflowPause not found: old-pause") + pause_entity = MagicMock() + + _resume_workflow( + app_model=SimpleNamespace(id="app-id"), + workflow=SimpleNamespace(created_by="workflow-owner"), + user=MagicMock(), + generate_entity=generate_entity, + graph_runtime_state=MagicMock(), + session_factory=MagicMock(), + pause_state_config=MagicMock(), + workflow_run_id="workflow-run-id", + workflow_run=SimpleNamespace(triggered_from="app_run"), + workflow_run_repo=workflow_run_repo, + pause_entity=pause_entity, + ) + + publish_streaming_response.assert_called_once_with(response_stream, "workflow-run-id", AppMode.WORKFLOW) + workflow_run_repo.delete_workflow_pause.assert_called_once_with(pause_entity) diff --git a/api/tests/unit_tests/utils/encryption/test_system_encryption.py b/api/tests/unit_tests/utils/encryption/test_system_encryption.py new file mode 100644 index 0000000000..0435facfdb --- /dev/null +++ b/api/tests/unit_tests/utils/encryption/test_system_encryption.py @@ -0,0 +1,619 @@ +import base64 +import hashlib +from unittest.mock import patch + +import pytest +from Crypto.Cipher import AES +from Crypto.Random import get_random_bytes +from Crypto.Util.Padding import pad + +from core.tools.utils.system_encryption import ( + EncryptionError, + SystemEncrypter, + create_system_encrypter, + decrypt_system_params, + encrypt_system_params, + get_system_encrypter, +) + + +class TestSystemEncrypter: + """Test cases for SystemEncrypter class""" + + def test_init_with_secret_key(self): + """Test initialization with provided secret key""" + secret_key = "test_secret_key" + encrypter = SystemEncrypter(secret_key=secret_key) + expected_key = hashlib.sha256(secret_key.encode()).digest() + assert encrypter.key == expected_key + + def test_init_with_none_secret_key(self): + """Test initialization with None secret key falls back to config""" + with patch("core.tools.utils.system_encryption.dify_config") as mock_config: + mock_config.SECRET_KEY = "config_secret" + encrypter = SystemEncrypter(secret_key=None) + expected_key = hashlib.sha256(b"config_secret").digest() + assert encrypter.key == expected_key + + def test_init_with_empty_secret_key(self): + """Test initialization with empty secret key""" + encrypter = SystemEncrypter(secret_key="") + expected_key = hashlib.sha256(b"").digest() + assert encrypter.key == expected_key + + def test_init_without_secret_key_uses_config(self): + """Test initialization without secret key uses config""" + with patch("core.tools.utils.system_encryption.dify_config") as mock_config: + mock_config.SECRET_KEY = "default_secret" + encrypter = SystemEncrypter() + expected_key = hashlib.sha256(b"default_secret").digest() + assert encrypter.key == expected_key + + def test_encrypt_params_basic(self): + """Test basic parameters encryption""" + encrypter = SystemEncrypter("test_secret") + params = {"client_id": "test_id", "client_secret": "test_secret"} + + encrypted = encrypter.encrypt_params(params) + + assert isinstance(encrypted, str) + assert len(encrypted) > 0 + # Should be valid base64 + try: + base64.b64decode(encrypted) + except Exception: + pytest.fail("Encrypted result is not valid base64") + + def test_encrypt_params_empty_dict(self): + """Test encryption with empty dictionary""" + encrypter = SystemEncrypter("test_secret") + params = {} + + encrypted = encrypter.encrypt_params(params) + assert isinstance(encrypted, str) + assert len(encrypted) > 0 + + def test_encrypt_params_complex_data(self): + """Test encryption with complex data structures""" + encrypter = SystemEncrypter("test_secret") + params = { + "client_id": "test_id", + "client_secret": "test_secret", + "scopes": ["read", "write", "admin"], + "metadata": {"issuer": "test_issuer", "expires_in": 3600, "is_active": True}, + "numeric_value": 42, + "boolean_value": False, + "null_value": None, + } + + encrypted = encrypter.encrypt_params(params) + assert isinstance(encrypted, str) + assert len(encrypted) > 0 + + def test_encrypt_params_unicode_data(self): + """Test encryption with unicode data""" + encrypter = SystemEncrypter("test_secret") + params = {"client_id": "test_id", "client_secret": "test_secret", "description": "This is a test case 🚀"} + + encrypted = encrypter.encrypt_params(params) + assert isinstance(encrypted, str) + assert len(encrypted) > 0 + + def test_encrypt_params_large_data(self): + """Test encryption with large data""" + encrypter = SystemEncrypter("test_secret") + params = { + "client_id": "test_id", + "large_data": "x" * 10000, # 10KB of data + } + + encrypted = encrypter.encrypt_params(params) + assert isinstance(encrypted, str) + assert len(encrypted) > 0 + + def test_encrypt_params_invalid_input(self): + """Test encryption with invalid input types""" + encrypter = SystemEncrypter("test_secret") + + with pytest.raises(Exception): # noqa: B017 + encrypter.encrypt_params(None) + + with pytest.raises(Exception): # noqa: B017 + encrypter.encrypt_params("not_a_dict") + + def test_decrypt_params_basic(self): + """Test basic parameters decryption""" + encrypter = SystemEncrypter("test_secret") + original_params = {"client_id": "test_id", "client_secret": "test_secret"} + + encrypted = encrypter.encrypt_params(original_params) + decrypted = encrypter.decrypt_params(encrypted) + + assert decrypted == original_params + + def test_decrypt_params_empty_dict(self): + """Test decryption of empty dictionary""" + encrypter = SystemEncrypter("test_secret") + original_params = {} + + encrypted = encrypter.encrypt_params(original_params) + decrypted = encrypter.decrypt_params(encrypted) + + assert decrypted == original_params + + def test_decrypt_params_complex_data(self): + """Test decryption with complex data structures""" + encrypter = SystemEncrypter("test_secret") + original_params = { + "client_id": "test_id", + "client_secret": "test_secret", + "scopes": ["read", "write", "admin"], + "metadata": {"issuer": "test_issuer", "expires_in": 3600, "is_active": True}, + "numeric_value": 42, + "boolean_value": False, + "null_value": None, + } + + encrypted = encrypter.encrypt_params(original_params) + decrypted = encrypter.decrypt_params(encrypted) + + assert decrypted == original_params + + def test_decrypt_params_unicode_data(self): + """Test decryption with unicode data""" + encrypter = SystemEncrypter("test_secret") + original_params = { + "client_id": "test_id", + "client_secret": "test_secret", + "description": "This is a test case 🚀", + } + + encrypted = encrypter.encrypt_params(original_params) + decrypted = encrypter.decrypt_params(encrypted) + + assert decrypted == original_params + + def test_decrypt_params_large_data(self): + """Test decryption with large data""" + encrypter = SystemEncrypter("test_secret") + original_params = { + "client_id": "test_id", + "large_data": "x" * 10000, # 10KB of data + } + + encrypted = encrypter.encrypt_params(original_params) + decrypted = encrypter.decrypt_params(encrypted) + + assert decrypted == original_params + + def test_decrypt_params_invalid_base64(self): + """Test decryption with invalid base64 data""" + encrypter = SystemEncrypter("test_secret") + + with pytest.raises(EncryptionError): + encrypter.decrypt_params("invalid_base64!") + + def test_decrypt_params_empty_string(self): + """Test decryption with empty string""" + encrypter = SystemEncrypter("test_secret") + + with pytest.raises(ValueError) as exc_info: + encrypter.decrypt_params("") + + assert "encrypted_data cannot be empty" in str(exc_info.value) + + def test_decrypt_params_non_string_input(self): + """Test decryption with non-string input""" + encrypter = SystemEncrypter("test_secret") + + with pytest.raises(ValueError) as exc_info: + encrypter.decrypt_params(123) + + assert "encrypted_data must be a string" in str(exc_info.value) + + with pytest.raises(ValueError) as exc_info: + encrypter.decrypt_params(None) + + assert "encrypted_data must be a string" in str(exc_info.value) + + def test_decrypt_params_too_short_data(self): + """Test decryption with too short encrypted data""" + encrypter = SystemEncrypter("test_secret") + + # Create data that's too short (less than 32 bytes) + short_data = base64.b64encode(b"short").decode() + + with pytest.raises(EncryptionError) as exc_info: + encrypter.decrypt_params(short_data) + + assert "Invalid encrypted data format" in str(exc_info.value) + + def test_decrypt_params_corrupted_data(self): + """Test decryption with corrupted data""" + encrypter = SystemEncrypter("test_secret") + + # Create corrupted data (valid base64 but invalid encrypted content) + corrupted_data = base64.b64encode(b"x" * 48).decode() # 48 bytes of garbage + + with pytest.raises(EncryptionError): + encrypter.decrypt_params(corrupted_data) + + def test_decrypt_params_wrong_key(self): + """Test decryption with wrong key""" + encrypter1 = SystemEncrypter("secret1") + encrypter2 = SystemEncrypter("secret2") + + original_params = {"client_id": "test_id", "client_secret": "test_secret"} + encrypted = encrypter1.encrypt_params(original_params) + + with pytest.raises(EncryptionError): + encrypter2.decrypt_params(encrypted) + + def test_encryption_decryption_consistency(self): + """Test that encryption and decryption are consistent""" + encrypter = SystemEncrypter("test_secret") + + test_cases = [ + {}, + {"simple": "value"}, + {"client_id": "id", "client_secret": "secret"}, + {"complex": {"nested": {"deep": "value"}}}, + {"unicode": "test 🚀"}, + {"numbers": 42, "boolean": True, "null": None}, + {"array": [1, 2, 3, "four", {"five": 5}]}, + ] + + for original_params in test_cases: + encrypted = encrypter.encrypt_params(original_params) + decrypted = encrypter.decrypt_params(encrypted) + assert decrypted == original_params, f"Failed for case: {original_params}" + + def test_encryption_randomness(self): + """Test that encryption produces different results for same input""" + encrypter = SystemEncrypter("test_secret") + params = {"client_id": "test_id", "client_secret": "test_secret"} + + encrypted1 = encrypter.encrypt_params(params) + encrypted2 = encrypter.encrypt_params(params) + + # Should be different due to random IV + assert encrypted1 != encrypted2 + + # But should decrypt to same result + decrypted1 = encrypter.decrypt_params(encrypted1) + decrypted2 = encrypter.decrypt_params(encrypted2) + assert decrypted1 == decrypted2 == params + + def test_different_secret_keys_produce_different_results(self): + """Test that different secret keys produce different encrypted results""" + encrypter1 = SystemEncrypter("secret1") + encrypter2 = SystemEncrypter("secret2") + + params = {"client_id": "test_id", "client_secret": "test_secret"} + + encrypted1 = encrypter1.encrypt_params(params) + encrypted2 = encrypter2.encrypt_params(params) + + # Should produce different encrypted results + assert encrypted1 != encrypted2 + + # But each should decrypt correctly with its own key + decrypted1 = encrypter1.decrypt_params(encrypted1) + decrypted2 = encrypter2.decrypt_params(encrypted2) + assert decrypted1 == decrypted2 == params + + @patch("core.tools.utils.system_encryption.get_random_bytes") + def test_encrypt_params_crypto_error(self, mock_get_random_bytes): + """Test encryption when crypto operation fails""" + mock_get_random_bytes.side_effect = Exception("Crypto error") + + encrypter = SystemEncrypter("test_secret") + params = {"client_id": "test_id"} + + with pytest.raises(EncryptionError) as exc_info: + encrypter.encrypt_params(params) + + assert "Encryption failed" in str(exc_info.value) + + @patch("core.tools.utils.system_encryption.TypeAdapter") + def test_encrypt_params_serialization_error(self, mock_type_adapter): + """Test encryption when JSON serialization fails""" + mock_type_adapter.return_value.dump_json.side_effect = Exception("Serialization error") + + encrypter = SystemEncrypter("test_secret") + params = {"client_id": "test_id"} + + with pytest.raises(EncryptionError) as exc_info: + encrypter.encrypt_params(params) + + assert "Encryption failed" in str(exc_info.value) + + def test_decrypt_params_invalid_json(self): + """Test decryption with invalid JSON data""" + encrypter = SystemEncrypter("test_secret") + + # Create valid encrypted data but with invalid JSON content + iv = get_random_bytes(16) + cipher = AES.new(encrypter.key, AES.MODE_CBC, iv) + invalid_json = b"invalid json content" + padded_data = pad(invalid_json, AES.block_size) + encrypted_data = cipher.encrypt(padded_data) + combined = iv + encrypted_data + encoded = base64.b64encode(combined).decode() + + with pytest.raises(EncryptionError): + encrypter.decrypt_params(encoded) + + def test_key_derivation_consistency(self): + """Test that key derivation is consistent""" + secret_key = "test_secret" + encrypter1 = SystemEncrypter(secret_key) + encrypter2 = SystemEncrypter(secret_key) + + assert encrypter1.key == encrypter2.key + + # Keys should be 32 bytes (256 bits) + assert len(encrypter1.key) == 32 + + +class TestFactoryFunctions: + """Test cases for factory functions""" + + def test_create_system_encrypter_with_secret(self): + """Test factory function with secret key""" + secret_key = "test_secret" + encrypter = create_system_encrypter(secret_key) + + assert isinstance(encrypter, SystemEncrypter) + expected_key = hashlib.sha256(secret_key.encode()).digest() + assert encrypter.key == expected_key + + def test_create_system_encrypter_without_secret(self): + """Test factory function without secret key""" + with patch("core.tools.utils.system_encryption.dify_config") as mock_config: + mock_config.SECRET_KEY = "config_secret" + encrypter = create_system_encrypter() + + assert isinstance(encrypter, SystemEncrypter) + expected_key = hashlib.sha256(b"config_secret").digest() + assert encrypter.key == expected_key + + def test_create_system_encrypter_with_none_secret(self): + """Test factory function with None secret key""" + with patch("core.tools.utils.system_encryption.dify_config") as mock_config: + mock_config.SECRET_KEY = "config_secret" + encrypter = create_system_encrypter(None) + + assert isinstance(encrypter, SystemEncrypter) + expected_key = hashlib.sha256(b"config_secret").digest() + assert encrypter.key == expected_key + + +class TestGlobalEncrypterInstance: + """Test cases for global encrypter instance""" + + def test_get_system_encrypter_singleton(self): + """Test that get_system_encrypter returns singleton instance""" + # Clear the global instance first + import core.tools.utils.system_encryption + + core.tools.utils.system_encryption._encrypter = None + + encrypter1 = get_system_encrypter() + encrypter2 = get_system_encrypter() + + assert encrypter1 is encrypter2 + assert isinstance(encrypter1, SystemEncrypter) + + def test_get_system_encrypter_uses_config(self): + """Test that global encrypter uses config""" + # Clear the global instance first + import core.tools.utils.system_encryption + + core.tools.utils.system_encryption._encrypter = None + + with patch("core.tools.utils.system_encryption.dify_config") as mock_config: + mock_config.SECRET_KEY = "global_secret" + encrypter = get_system_encrypter() + + expected_key = hashlib.sha256(b"global_secret").digest() + assert encrypter.key == expected_key + + +class TestConvenienceFunctions: + """Test cases for convenience functions""" + + def test_encrypt_system_params(self): + """Test encrypt_system_params convenience function""" + params = {"client_id": "test_id", "client_secret": "test_secret"} + + encrypted = encrypt_system_params(params) + + assert isinstance(encrypted, str) + assert len(encrypted) > 0 + + def test_decrypt_system_params(self): + """Test decrypt_system_params convenience function""" + params = {"client_id": "test_id", "client_secret": "test_secret"} + + encrypted = encrypt_system_params(params) + decrypted = decrypt_system_params(encrypted) + + assert decrypted == params + + def test_convenience_functions_consistency(self): + """Test that convenience functions work consistently""" + test_cases = [ + {}, + {"simple": "value"}, + {"client_id": "id", "client_secret": "secret"}, + {"complex": {"nested": {"deep": "value"}}}, + {"unicode": "test 🚀"}, + {"numbers": 42, "boolean": True, "null": None}, + ] + + for original_params in test_cases: + encrypted = encrypt_system_params(original_params) + decrypted = decrypt_system_params(encrypted) + assert decrypted == original_params, f"Failed for case: {original_params}" + + def test_convenience_functions_with_errors(self): + """Test convenience functions with error conditions""" + # Test encryption with invalid input + with pytest.raises(Exception): # noqa: B017 + encrypt_system_params(None) + + # Test decryption with invalid input + with pytest.raises(ValueError): + decrypt_system_params("") + + with pytest.raises(ValueError): + decrypt_system_params(None) + + +class TestErrorHandling: + """Test cases for error handling""" + + def test_encryption_error_inheritance(self): + """Test that EncryptionError is a proper exception""" + error = EncryptionError("Test error") + assert isinstance(error, Exception) + assert str(error) == "Test error" + + def test_encryption_error_with_cause(self): + """Test EncryptionError with cause""" + original_error = ValueError("Original error") + error = EncryptionError("Wrapper error") + error.__cause__ = original_error + + assert isinstance(error, Exception) + assert str(error) == "Wrapper error" + assert error.__cause__ is original_error + + def test_error_messages_are_informative(self): + """Test that error messages are informative""" + encrypter = SystemEncrypter("test_secret") + + # Test empty string error + with pytest.raises(ValueError) as exc_info: + encrypter.decrypt_params("") + assert "encrypted_data cannot be empty" in str(exc_info.value) + + # Test non-string error + with pytest.raises(ValueError) as exc_info: + encrypter.decrypt_params(123) + assert "encrypted_data must be a string" in str(exc_info.value) + + # Test invalid format error + short_data = base64.b64encode(b"short").decode() + with pytest.raises(EncryptionError) as exc_info: + encrypter.decrypt_params(short_data) + assert "Invalid encrypted data format" in str(exc_info.value) + + +class TestEdgeCases: + """Test cases for edge cases and boundary conditions""" + + def test_very_long_secret_key(self): + """Test with very long secret key""" + long_secret = "x" * 10000 + encrypter = SystemEncrypter(long_secret) + + # Key should still be 32 bytes due to SHA-256 + assert len(encrypter.key) == 32 + + # Should still work normally + params = {"client_id": "test_id"} + encrypted = encrypter.encrypt_params(params) + decrypted = encrypter.decrypt_params(encrypted) + assert decrypted == params + + def test_special_characters_in_secret_key(self): + """Test with special characters in secret key""" + special_secret = "!@#$%^&*()_+-=[]{}|;':\",./<>?`~test🚀" + encrypter = SystemEncrypter(special_secret) + + params = {"client_id": "test_id"} + encrypted = encrypter.encrypt_params(params) + decrypted = encrypter.decrypt_params(encrypted) + assert decrypted == params + + def test_empty_values_in_params(self): + """Test with empty values in params""" + params = { + "client_id": "", + "client_secret": "", + "empty_dict": {}, + "empty_list": [], + "empty_string": "", + "zero": 0, + "false": False, + "none": None, + } + + encrypter = SystemEncrypter("test_secret") + encrypted = encrypter.encrypt_params(params) + decrypted = encrypter.decrypt_params(encrypted) + assert decrypted == params + + def test_deeply_nested_params(self): + """Test with deeply nested params""" + params = {"level1": {"level2": {"level3": {"level4": {"level5": {"deep_value": "found"}}}}}} + + encrypter = SystemEncrypter("test_secret") + encrypted = encrypter.encrypt_params(params) + decrypted = encrypter.decrypt_params(encrypted) + assert decrypted == params + + def test_params_with_all_json_types(self): + """Test with all JSON-supported data types""" + params = { + "string": "test_string", + "integer": 42, + "float": 3.14159, + "boolean_true": True, + "boolean_false": False, + "null_value": None, + "empty_string": "", + "array": [1, "two", 3.0, True, False, None], + "object": {"nested_string": "nested_value", "nested_number": 123, "nested_bool": True}, + } + + encrypter = SystemEncrypter("test_secret") + encrypted = encrypter.encrypt_params(params) + decrypted = encrypter.decrypt_params(encrypted) + assert decrypted == params + + +class TestPerformance: + """Test cases for performance considerations""" + + def test_large_params(self): + """Test with large params""" + large_value = "x" * 100000 # 100KB + params = {"client_id": "test_id", "large_data": large_value} + + encrypter = SystemEncrypter("test_secret") + encrypted = encrypter.encrypt_params(params) + decrypted = encrypter.decrypt_params(encrypted) + assert decrypted == params + + def test_many_fields_params(self): + """Test with many fields in params""" + params = {f"field_{i}": f"value_{i}" for i in range(1000)} + + encrypter = SystemEncrypter("test_secret") + encrypted = encrypter.encrypt_params(params) + decrypted = encrypter.decrypt_params(encrypted) + assert decrypted == params + + def test_repeated_encryption_decryption(self): + """Test repeated encryption and decryption operations""" + encrypter = SystemEncrypter("test_secret") + params = {"client_id": "test_id", "client_secret": "test_secret"} + + # Test multiple rounds of encryption/decryption + for i in range(100): + encrypted = encrypter.encrypt_params(params) + decrypted = encrypter.decrypt_params(encrypted) + assert decrypted == params diff --git a/api/tests/unit_tests/utils/oauth_encryption/test_system_oauth_encryption.py b/api/tests/unit_tests/utils/oauth_encryption/test_system_oauth_encryption.py deleted file mode 100644 index e2607f0fb1..0000000000 --- a/api/tests/unit_tests/utils/oauth_encryption/test_system_oauth_encryption.py +++ /dev/null @@ -1,619 +0,0 @@ -import base64 -import hashlib -from unittest.mock import patch - -import pytest -from Crypto.Cipher import AES -from Crypto.Random import get_random_bytes -from Crypto.Util.Padding import pad - -from core.tools.utils.system_oauth_encryption import ( - OAuthEncryptionError, - SystemOAuthEncrypter, - create_system_oauth_encrypter, - decrypt_system_oauth_params, - encrypt_system_oauth_params, - get_system_oauth_encrypter, -) - - -class TestSystemOAuthEncrypter: - """Test cases for SystemOAuthEncrypter class""" - - def test_init_with_secret_key(self): - """Test initialization with provided secret key""" - secret_key = "test_secret_key" - encrypter = SystemOAuthEncrypter(secret_key=secret_key) - expected_key = hashlib.sha256(secret_key.encode()).digest() - assert encrypter.key == expected_key - - def test_init_with_none_secret_key(self): - """Test initialization with None secret key falls back to config""" - with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config: - mock_config.SECRET_KEY = "config_secret" - encrypter = SystemOAuthEncrypter(secret_key=None) - expected_key = hashlib.sha256(b"config_secret").digest() - assert encrypter.key == expected_key - - def test_init_with_empty_secret_key(self): - """Test initialization with empty secret key""" - encrypter = SystemOAuthEncrypter(secret_key="") - expected_key = hashlib.sha256(b"").digest() - assert encrypter.key == expected_key - - def test_init_without_secret_key_uses_config(self): - """Test initialization without secret key uses config""" - with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config: - mock_config.SECRET_KEY = "default_secret" - encrypter = SystemOAuthEncrypter() - expected_key = hashlib.sha256(b"default_secret").digest() - assert encrypter.key == expected_key - - def test_encrypt_oauth_params_basic(self): - """Test basic OAuth parameters encryption""" - encrypter = SystemOAuthEncrypter("test_secret") - oauth_params = {"client_id": "test_id", "client_secret": "test_secret"} - - encrypted = encrypter.encrypt_oauth_params(oauth_params) - - assert isinstance(encrypted, str) - assert len(encrypted) > 0 - # Should be valid base64 - try: - base64.b64decode(encrypted) - except Exception: - pytest.fail("Encrypted result is not valid base64") - - def test_encrypt_oauth_params_empty_dict(self): - """Test encryption with empty dictionary""" - encrypter = SystemOAuthEncrypter("test_secret") - oauth_params = {} - - encrypted = encrypter.encrypt_oauth_params(oauth_params) - assert isinstance(encrypted, str) - assert len(encrypted) > 0 - - def test_encrypt_oauth_params_complex_data(self): - """Test encryption with complex data structures""" - encrypter = SystemOAuthEncrypter("test_secret") - oauth_params = { - "client_id": "test_id", - "client_secret": "test_secret", - "scopes": ["read", "write", "admin"], - "metadata": {"issuer": "test_issuer", "expires_in": 3600, "is_active": True}, - "numeric_value": 42, - "boolean_value": False, - "null_value": None, - } - - encrypted = encrypter.encrypt_oauth_params(oauth_params) - assert isinstance(encrypted, str) - assert len(encrypted) > 0 - - def test_encrypt_oauth_params_unicode_data(self): - """Test encryption with unicode data""" - encrypter = SystemOAuthEncrypter("test_secret") - oauth_params = {"client_id": "test_id", "client_secret": "test_secret", "description": "This is a test case 🚀"} - - encrypted = encrypter.encrypt_oauth_params(oauth_params) - assert isinstance(encrypted, str) - assert len(encrypted) > 0 - - def test_encrypt_oauth_params_large_data(self): - """Test encryption with large data""" - encrypter = SystemOAuthEncrypter("test_secret") - oauth_params = { - "client_id": "test_id", - "large_data": "x" * 10000, # 10KB of data - } - - encrypted = encrypter.encrypt_oauth_params(oauth_params) - assert isinstance(encrypted, str) - assert len(encrypted) > 0 - - def test_encrypt_oauth_params_invalid_input(self): - """Test encryption with invalid input types""" - encrypter = SystemOAuthEncrypter("test_secret") - - with pytest.raises(Exception): # noqa: B017 - encrypter.encrypt_oauth_params(None) - - with pytest.raises(Exception): # noqa: B017 - encrypter.encrypt_oauth_params("not_a_dict") - - def test_decrypt_oauth_params_basic(self): - """Test basic OAuth parameters decryption""" - encrypter = SystemOAuthEncrypter("test_secret") - original_params = {"client_id": "test_id", "client_secret": "test_secret"} - - encrypted = encrypter.encrypt_oauth_params(original_params) - decrypted = encrypter.decrypt_oauth_params(encrypted) - - assert decrypted == original_params - - def test_decrypt_oauth_params_empty_dict(self): - """Test decryption of empty dictionary""" - encrypter = SystemOAuthEncrypter("test_secret") - original_params = {} - - encrypted = encrypter.encrypt_oauth_params(original_params) - decrypted = encrypter.decrypt_oauth_params(encrypted) - - assert decrypted == original_params - - def test_decrypt_oauth_params_complex_data(self): - """Test decryption with complex data structures""" - encrypter = SystemOAuthEncrypter("test_secret") - original_params = { - "client_id": "test_id", - "client_secret": "test_secret", - "scopes": ["read", "write", "admin"], - "metadata": {"issuer": "test_issuer", "expires_in": 3600, "is_active": True}, - "numeric_value": 42, - "boolean_value": False, - "null_value": None, - } - - encrypted = encrypter.encrypt_oauth_params(original_params) - decrypted = encrypter.decrypt_oauth_params(encrypted) - - assert decrypted == original_params - - def test_decrypt_oauth_params_unicode_data(self): - """Test decryption with unicode data""" - encrypter = SystemOAuthEncrypter("test_secret") - original_params = { - "client_id": "test_id", - "client_secret": "test_secret", - "description": "This is a test case 🚀", - } - - encrypted = encrypter.encrypt_oauth_params(original_params) - decrypted = encrypter.decrypt_oauth_params(encrypted) - - assert decrypted == original_params - - def test_decrypt_oauth_params_large_data(self): - """Test decryption with large data""" - encrypter = SystemOAuthEncrypter("test_secret") - original_params = { - "client_id": "test_id", - "large_data": "x" * 10000, # 10KB of data - } - - encrypted = encrypter.encrypt_oauth_params(original_params) - decrypted = encrypter.decrypt_oauth_params(encrypted) - - assert decrypted == original_params - - def test_decrypt_oauth_params_invalid_base64(self): - """Test decryption with invalid base64 data""" - encrypter = SystemOAuthEncrypter("test_secret") - - with pytest.raises(OAuthEncryptionError): - encrypter.decrypt_oauth_params("invalid_base64!") - - def test_decrypt_oauth_params_empty_string(self): - """Test decryption with empty string""" - encrypter = SystemOAuthEncrypter("test_secret") - - with pytest.raises(ValueError) as exc_info: - encrypter.decrypt_oauth_params("") - - assert "encrypted_data cannot be empty" in str(exc_info.value) - - def test_decrypt_oauth_params_non_string_input(self): - """Test decryption with non-string input""" - encrypter = SystemOAuthEncrypter("test_secret") - - with pytest.raises(ValueError) as exc_info: - encrypter.decrypt_oauth_params(123) - - assert "encrypted_data must be a string" in str(exc_info.value) - - with pytest.raises(ValueError) as exc_info: - encrypter.decrypt_oauth_params(None) - - assert "encrypted_data must be a string" in str(exc_info.value) - - def test_decrypt_oauth_params_too_short_data(self): - """Test decryption with too short encrypted data""" - encrypter = SystemOAuthEncrypter("test_secret") - - # Create data that's too short (less than 32 bytes) - short_data = base64.b64encode(b"short").decode() - - with pytest.raises(OAuthEncryptionError) as exc_info: - encrypter.decrypt_oauth_params(short_data) - - assert "Invalid encrypted data format" in str(exc_info.value) - - def test_decrypt_oauth_params_corrupted_data(self): - """Test decryption with corrupted data""" - encrypter = SystemOAuthEncrypter("test_secret") - - # Create corrupted data (valid base64 but invalid encrypted content) - corrupted_data = base64.b64encode(b"x" * 48).decode() # 48 bytes of garbage - - with pytest.raises(OAuthEncryptionError): - encrypter.decrypt_oauth_params(corrupted_data) - - def test_decrypt_oauth_params_wrong_key(self): - """Test decryption with wrong key""" - encrypter1 = SystemOAuthEncrypter("secret1") - encrypter2 = SystemOAuthEncrypter("secret2") - - original_params = {"client_id": "test_id", "client_secret": "test_secret"} - encrypted = encrypter1.encrypt_oauth_params(original_params) - - with pytest.raises(OAuthEncryptionError): - encrypter2.decrypt_oauth_params(encrypted) - - def test_encryption_decryption_consistency(self): - """Test that encryption and decryption are consistent""" - encrypter = SystemOAuthEncrypter("test_secret") - - test_cases = [ - {}, - {"simple": "value"}, - {"client_id": "id", "client_secret": "secret"}, - {"complex": {"nested": {"deep": "value"}}}, - {"unicode": "test 🚀"}, - {"numbers": 42, "boolean": True, "null": None}, - {"array": [1, 2, 3, "four", {"five": 5}]}, - ] - - for original_params in test_cases: - encrypted = encrypter.encrypt_oauth_params(original_params) - decrypted = encrypter.decrypt_oauth_params(encrypted) - assert decrypted == original_params, f"Failed for case: {original_params}" - - def test_encryption_randomness(self): - """Test that encryption produces different results for same input""" - encrypter = SystemOAuthEncrypter("test_secret") - oauth_params = {"client_id": "test_id", "client_secret": "test_secret"} - - encrypted1 = encrypter.encrypt_oauth_params(oauth_params) - encrypted2 = encrypter.encrypt_oauth_params(oauth_params) - - # Should be different due to random IV - assert encrypted1 != encrypted2 - - # But should decrypt to same result - decrypted1 = encrypter.decrypt_oauth_params(encrypted1) - decrypted2 = encrypter.decrypt_oauth_params(encrypted2) - assert decrypted1 == decrypted2 == oauth_params - - def test_different_secret_keys_produce_different_results(self): - """Test that different secret keys produce different encrypted results""" - encrypter1 = SystemOAuthEncrypter("secret1") - encrypter2 = SystemOAuthEncrypter("secret2") - - oauth_params = {"client_id": "test_id", "client_secret": "test_secret"} - - encrypted1 = encrypter1.encrypt_oauth_params(oauth_params) - encrypted2 = encrypter2.encrypt_oauth_params(oauth_params) - - # Should produce different encrypted results - assert encrypted1 != encrypted2 - - # But each should decrypt correctly with its own key - decrypted1 = encrypter1.decrypt_oauth_params(encrypted1) - decrypted2 = encrypter2.decrypt_oauth_params(encrypted2) - assert decrypted1 == decrypted2 == oauth_params - - @patch("core.tools.utils.system_oauth_encryption.get_random_bytes") - def test_encrypt_oauth_params_crypto_error(self, mock_get_random_bytes): - """Test encryption when crypto operation fails""" - mock_get_random_bytes.side_effect = Exception("Crypto error") - - encrypter = SystemOAuthEncrypter("test_secret") - oauth_params = {"client_id": "test_id"} - - with pytest.raises(OAuthEncryptionError) as exc_info: - encrypter.encrypt_oauth_params(oauth_params) - - assert "Encryption failed" in str(exc_info.value) - - @patch("core.tools.utils.system_oauth_encryption.TypeAdapter") - def test_encrypt_oauth_params_serialization_error(self, mock_type_adapter): - """Test encryption when JSON serialization fails""" - mock_type_adapter.return_value.dump_json.side_effect = Exception("Serialization error") - - encrypter = SystemOAuthEncrypter("test_secret") - oauth_params = {"client_id": "test_id"} - - with pytest.raises(OAuthEncryptionError) as exc_info: - encrypter.encrypt_oauth_params(oauth_params) - - assert "Encryption failed" in str(exc_info.value) - - def test_decrypt_oauth_params_invalid_json(self): - """Test decryption with invalid JSON data""" - encrypter = SystemOAuthEncrypter("test_secret") - - # Create valid encrypted data but with invalid JSON content - iv = get_random_bytes(16) - cipher = AES.new(encrypter.key, AES.MODE_CBC, iv) - invalid_json = b"invalid json content" - padded_data = pad(invalid_json, AES.block_size) - encrypted_data = cipher.encrypt(padded_data) - combined = iv + encrypted_data - encoded = base64.b64encode(combined).decode() - - with pytest.raises(OAuthEncryptionError): - encrypter.decrypt_oauth_params(encoded) - - def test_key_derivation_consistency(self): - """Test that key derivation is consistent""" - secret_key = "test_secret" - encrypter1 = SystemOAuthEncrypter(secret_key) - encrypter2 = SystemOAuthEncrypter(secret_key) - - assert encrypter1.key == encrypter2.key - - # Keys should be 32 bytes (256 bits) - assert len(encrypter1.key) == 32 - - -class TestFactoryFunctions: - """Test cases for factory functions""" - - def test_create_system_oauth_encrypter_with_secret(self): - """Test factory function with secret key""" - secret_key = "test_secret" - encrypter = create_system_oauth_encrypter(secret_key) - - assert isinstance(encrypter, SystemOAuthEncrypter) - expected_key = hashlib.sha256(secret_key.encode()).digest() - assert encrypter.key == expected_key - - def test_create_system_oauth_encrypter_without_secret(self): - """Test factory function without secret key""" - with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config: - mock_config.SECRET_KEY = "config_secret" - encrypter = create_system_oauth_encrypter() - - assert isinstance(encrypter, SystemOAuthEncrypter) - expected_key = hashlib.sha256(b"config_secret").digest() - assert encrypter.key == expected_key - - def test_create_system_oauth_encrypter_with_none_secret(self): - """Test factory function with None secret key""" - with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config: - mock_config.SECRET_KEY = "config_secret" - encrypter = create_system_oauth_encrypter(None) - - assert isinstance(encrypter, SystemOAuthEncrypter) - expected_key = hashlib.sha256(b"config_secret").digest() - assert encrypter.key == expected_key - - -class TestGlobalEncrypterInstance: - """Test cases for global encrypter instance""" - - def test_get_system_oauth_encrypter_singleton(self): - """Test that get_system_oauth_encrypter returns singleton instance""" - # Clear the global instance first - import core.tools.utils.system_oauth_encryption - - core.tools.utils.system_oauth_encryption._oauth_encrypter = None - - encrypter1 = get_system_oauth_encrypter() - encrypter2 = get_system_oauth_encrypter() - - assert encrypter1 is encrypter2 - assert isinstance(encrypter1, SystemOAuthEncrypter) - - def test_get_system_oauth_encrypter_uses_config(self): - """Test that global encrypter uses config""" - # Clear the global instance first - import core.tools.utils.system_oauth_encryption - - core.tools.utils.system_oauth_encryption._oauth_encrypter = None - - with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config: - mock_config.SECRET_KEY = "global_secret" - encrypter = get_system_oauth_encrypter() - - expected_key = hashlib.sha256(b"global_secret").digest() - assert encrypter.key == expected_key - - -class TestConvenienceFunctions: - """Test cases for convenience functions""" - - def test_encrypt_system_oauth_params(self): - """Test encrypt_system_oauth_params convenience function""" - oauth_params = {"client_id": "test_id", "client_secret": "test_secret"} - - encrypted = encrypt_system_oauth_params(oauth_params) - - assert isinstance(encrypted, str) - assert len(encrypted) > 0 - - def test_decrypt_system_oauth_params(self): - """Test decrypt_system_oauth_params convenience function""" - oauth_params = {"client_id": "test_id", "client_secret": "test_secret"} - - encrypted = encrypt_system_oauth_params(oauth_params) - decrypted = decrypt_system_oauth_params(encrypted) - - assert decrypted == oauth_params - - def test_convenience_functions_consistency(self): - """Test that convenience functions work consistently""" - test_cases = [ - {}, - {"simple": "value"}, - {"client_id": "id", "client_secret": "secret"}, - {"complex": {"nested": {"deep": "value"}}}, - {"unicode": "test 🚀"}, - {"numbers": 42, "boolean": True, "null": None}, - ] - - for original_params in test_cases: - encrypted = encrypt_system_oauth_params(original_params) - decrypted = decrypt_system_oauth_params(encrypted) - assert decrypted == original_params, f"Failed for case: {original_params}" - - def test_convenience_functions_with_errors(self): - """Test convenience functions with error conditions""" - # Test encryption with invalid input - with pytest.raises(Exception): # noqa: B017 - encrypt_system_oauth_params(None) - - # Test decryption with invalid input - with pytest.raises(ValueError): - decrypt_system_oauth_params("") - - with pytest.raises(ValueError): - decrypt_system_oauth_params(None) - - -class TestErrorHandling: - """Test cases for error handling""" - - def test_oauth_encryption_error_inheritance(self): - """Test that OAuthEncryptionError is a proper exception""" - error = OAuthEncryptionError("Test error") - assert isinstance(error, Exception) - assert str(error) == "Test error" - - def test_oauth_encryption_error_with_cause(self): - """Test OAuthEncryptionError with cause""" - original_error = ValueError("Original error") - error = OAuthEncryptionError("Wrapper error") - error.__cause__ = original_error - - assert isinstance(error, Exception) - assert str(error) == "Wrapper error" - assert error.__cause__ is original_error - - def test_error_messages_are_informative(self): - """Test that error messages are informative""" - encrypter = SystemOAuthEncrypter("test_secret") - - # Test empty string error - with pytest.raises(ValueError) as exc_info: - encrypter.decrypt_oauth_params("") - assert "encrypted_data cannot be empty" in str(exc_info.value) - - # Test non-string error - with pytest.raises(ValueError) as exc_info: - encrypter.decrypt_oauth_params(123) - assert "encrypted_data must be a string" in str(exc_info.value) - - # Test invalid format error - short_data = base64.b64encode(b"short").decode() - with pytest.raises(OAuthEncryptionError) as exc_info: - encrypter.decrypt_oauth_params(short_data) - assert "Invalid encrypted data format" in str(exc_info.value) - - -class TestEdgeCases: - """Test cases for edge cases and boundary conditions""" - - def test_very_long_secret_key(self): - """Test with very long secret key""" - long_secret = "x" * 10000 - encrypter = SystemOAuthEncrypter(long_secret) - - # Key should still be 32 bytes due to SHA-256 - assert len(encrypter.key) == 32 - - # Should still work normally - oauth_params = {"client_id": "test_id"} - encrypted = encrypter.encrypt_oauth_params(oauth_params) - decrypted = encrypter.decrypt_oauth_params(encrypted) - assert decrypted == oauth_params - - def test_special_characters_in_secret_key(self): - """Test with special characters in secret key""" - special_secret = "!@#$%^&*()_+-=[]{}|;':\",./<>?`~test🚀" - encrypter = SystemOAuthEncrypter(special_secret) - - oauth_params = {"client_id": "test_id"} - encrypted = encrypter.encrypt_oauth_params(oauth_params) - decrypted = encrypter.decrypt_oauth_params(encrypted) - assert decrypted == oauth_params - - def test_empty_values_in_oauth_params(self): - """Test with empty values in oauth params""" - oauth_params = { - "client_id": "", - "client_secret": "", - "empty_dict": {}, - "empty_list": [], - "empty_string": "", - "zero": 0, - "false": False, - "none": None, - } - - encrypter = SystemOAuthEncrypter("test_secret") - encrypted = encrypter.encrypt_oauth_params(oauth_params) - decrypted = encrypter.decrypt_oauth_params(encrypted) - assert decrypted == oauth_params - - def test_deeply_nested_oauth_params(self): - """Test with deeply nested oauth params""" - oauth_params = {"level1": {"level2": {"level3": {"level4": {"level5": {"deep_value": "found"}}}}}} - - encrypter = SystemOAuthEncrypter("test_secret") - encrypted = encrypter.encrypt_oauth_params(oauth_params) - decrypted = encrypter.decrypt_oauth_params(encrypted) - assert decrypted == oauth_params - - def test_oauth_params_with_all_json_types(self): - """Test with all JSON-supported data types""" - oauth_params = { - "string": "test_string", - "integer": 42, - "float": 3.14159, - "boolean_true": True, - "boolean_false": False, - "null_value": None, - "empty_string": "", - "array": [1, "two", 3.0, True, False, None], - "object": {"nested_string": "nested_value", "nested_number": 123, "nested_bool": True}, - } - - encrypter = SystemOAuthEncrypter("test_secret") - encrypted = encrypter.encrypt_oauth_params(oauth_params) - decrypted = encrypter.decrypt_oauth_params(encrypted) - assert decrypted == oauth_params - - -class TestPerformance: - """Test cases for performance considerations""" - - def test_large_oauth_params(self): - """Test with large oauth params""" - large_value = "x" * 100000 # 100KB - oauth_params = {"client_id": "test_id", "large_data": large_value} - - encrypter = SystemOAuthEncrypter("test_secret") - encrypted = encrypter.encrypt_oauth_params(oauth_params) - decrypted = encrypter.decrypt_oauth_params(encrypted) - assert decrypted == oauth_params - - def test_many_fields_oauth_params(self): - """Test with many fields in oauth params""" - oauth_params = {f"field_{i}": f"value_{i}" for i in range(1000)} - - encrypter = SystemOAuthEncrypter("test_secret") - encrypted = encrypter.encrypt_oauth_params(oauth_params) - decrypted = encrypter.decrypt_oauth_params(encrypted) - assert decrypted == oauth_params - - def test_repeated_encryption_decryption(self): - """Test repeated encryption and decryption operations""" - encrypter = SystemOAuthEncrypter("test_secret") - oauth_params = {"client_id": "test_id", "client_secret": "test_secret"} - - # Test multiple rounds of encryption/decryption - for i in range(100): - encrypted = encrypter.encrypt_oauth_params(oauth_params) - decrypted = encrypter.decrypt_oauth_params(encrypted) - assert decrypted == oauth_params diff --git a/docker/.env.example b/docker/.env.example index ec7d572057..29741474fa 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -1467,6 +1467,11 @@ ENDPOINT_URL_TEMPLATE=http://localhost/e/{hook_id} MARKETPLACE_ENABLED=true MARKETPLACE_API_URL=https://marketplace.dify.ai +# Creators Platform configuration +CREATORS_PLATFORM_FEATURES_ENABLED=true +CREATORS_PLATFORM_API_URL=https://creators.dify.ai +CREATORS_PLATFORM_OAUTH_CLIENT_ID= + FORCE_VERIFYING_SIGNATURE=true ENFORCE_LANGGENIUS_PLUGIN_SIGNATURES=true diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index aaf099453a..60ba510f44 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -629,6 +629,9 @@ x-shared-env: &shared-api-worker-env ENDPOINT_URL_TEMPLATE: ${ENDPOINT_URL_TEMPLATE:-http://localhost/e/{hook_id}} MARKETPLACE_ENABLED: ${MARKETPLACE_ENABLED:-true} MARKETPLACE_API_URL: ${MARKETPLACE_API_URL:-https://marketplace.dify.ai} + CREATORS_PLATFORM_FEATURES_ENABLED: ${CREATORS_PLATFORM_FEATURES_ENABLED:-true} + CREATORS_PLATFORM_API_URL: ${CREATORS_PLATFORM_API_URL:-https://creators.dify.ai} + CREATORS_PLATFORM_OAUTH_CLIENT_ID: ${CREATORS_PLATFORM_OAUTH_CLIENT_ID:-} FORCE_VERIFYING_SIGNATURE: ${FORCE_VERIFYING_SIGNATURE:-true} ENFORCE_LANGGENIUS_PLUGIN_SIGNATURES: ${ENFORCE_LANGGENIUS_PLUGIN_SIGNATURES:-true} PLUGIN_STDIO_BUFFER_SIZE: ${PLUGIN_STDIO_BUFFER_SIZE:-1024} diff --git a/eslint-suppressions.json b/eslint-suppressions.json index 405ce77400..1bff82ac17 100644 --- a/eslint-suppressions.json +++ b/eslint-suppressions.json @@ -124,11 +124,6 @@ "count": 1 } }, - "web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-config-modal.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-panel.tsx": { "ts/no-explicit-any": { "count": 1 @@ -1080,11 +1075,6 @@ "count": 1 } }, - "web/app/components/base/emoji-picker/Inner.tsx": { - "react/set-state-in-effect": { - "count": 1 - } - }, "web/app/components/base/emoji-picker/index.tsx": { "no-restricted-imports": { "count": 1 @@ -3253,14 +3243,6 @@ "count": 2 } }, - "web/app/components/plugins/plugin-auth/authorize/api-key-modal.tsx": { - "no-restricted-imports": { - "count": 1 - }, - "ts/no-explicit-any": { - "count": 2 - } - }, "web/app/components/plugins/plugin-auth/authorize/index.tsx": { "no-restricted-imports": { "count": 1 @@ -4270,11 +4252,6 @@ "count": 1 } }, - "web/app/components/workflow/nodes/_base/components/error-handle/error-handle-on-panel.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/workflow/nodes/_base/components/error-handle/error-handle-type-selector.tsx": { "no-restricted-imports": { "count": 1 @@ -4293,16 +4270,6 @@ "count": 1 } }, - "web/app/components/workflow/nodes/_base/components/form-input-type-switch.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, - "web/app/components/workflow/nodes/_base/components/help-link.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/workflow/nodes/_base/components/input-support-select-var.tsx": { "no-restricted-imports": { "count": 1 @@ -4502,22 +4469,6 @@ "count": 1 } }, - "web/app/components/workflow/nodes/agent/components/model-bar.tsx": { - "no-restricted-imports": { - "count": 1 - }, - "ts/no-empty-object-type": { - "count": 1 - } - }, - "web/app/components/workflow/nodes/agent/components/tool-icon.tsx": { - "no-restricted-imports": { - "count": 1 - }, - "react/unsupported-syntax": { - "count": 1 - } - }, "web/app/components/workflow/nodes/agent/default.ts": { "ts/no-explicit-any": { "count": 3 @@ -4859,11 +4810,6 @@ "count": 1 } }, - "web/app/components/workflow/nodes/knowledge-base/components/retrieval-setting/top-k-and-score-threshold.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/workflow/nodes/knowledge-base/components/retrieval-setting/type.ts": { "ts/no-explicit-any": { "count": 2 @@ -4966,14 +4912,6 @@ "count": 1 } }, - "web/app/components/workflow/nodes/llm/components/json-schema-config-modal/code-editor.tsx": { - "no-restricted-imports": { - "count": 1 - }, - "ts/no-explicit-any": { - "count": 4 - } - }, "web/app/components/workflow/nodes/llm/components/json-schema-config-modal/index.tsx": { "no-restricted-imports": { "count": 1 @@ -5009,11 +4947,6 @@ "count": 2 } }, - "web/app/components/workflow/nodes/llm/components/json-schema-config-modal/visual-editor/edit-card/actions.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/workflow/nodes/llm/components/json-schema-config-modal/visual-editor/edit-card/auto-width-input.tsx": { "react/set-state-in-effect": { "count": 1 @@ -5235,11 +5168,6 @@ "count": 5 } }, - "web/app/components/workflow/nodes/tool/components/copy-id.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/workflow/nodes/tool/components/input-var-list.tsx": { "ts/no-explicit-any": { "count": 7 @@ -5405,11 +5333,6 @@ "count": 1 } }, - "web/app/components/workflow/note-node/note-editor/toolbar/command.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/workflow/note-node/note-editor/utils.ts": { "regexp/no-useless-quantifier": { "count": 1 diff --git a/packages/dify-ui/README.md b/packages/dify-ui/README.md index e9c762073d..cd9485c400 100644 --- a/packages/dify-ui/README.md +++ b/packages/dify-ui/README.md @@ -90,6 +90,22 @@ See `[web/docs/overlay-migration.md](../../web/docs/overlay-migration.md)` for t - `pnpm -C packages/dify-ui storybook` — Storybook on the default port. Each primitive has `index.stories.tsx`. - `pnpm -C packages/dify-ui type-check` — `tsgo --noEmit` for this package only. +### Disabling Animations In Tests + +Base UI can wait for `element.getAnimations()` to finish before it unmounts overlays, panels, and transition-driven components. Browser-based test runners can make that timing unstable, especially when tests assert final DOM state rather than animation behavior. + +Set the Base UI test flag in a Vitest setup file to skip those waits: + +```ts +( + globalThis as typeof globalThis & { + BASE_UI_ANIMATIONS_DISABLED: boolean + } +).BASE_UI_ANIMATIONS_DISABLED = true +``` + +`packages/dify-ui/vitest.setup.ts` already applies this for primitive tests. + See `[AGENTS.md](./AGENTS.md)` for: - Component authoring rules (one-component-per-folder, `cva` + `cn`, relative imports inside the package, subpath imports from consumers). diff --git a/packages/dify-ui/src/toast/__tests__/index.spec.tsx b/packages/dify-ui/src/toast/__tests__/index.spec.tsx index edbdacd203..51fccf70d8 100644 --- a/packages/dify-ui/src/toast/__tests__/index.spec.tsx +++ b/packages/dify-ui/src/toast/__tests__/index.spec.tsx @@ -3,19 +3,20 @@ import { toast, ToastHost } from '../index' const asHTMLElement = (element: HTMLElement | SVGElement) => element as HTMLElement -declare global { - // eslint-disable-next-line vars-on-top - var BASE_UI_ANIMATIONS_DISABLED: boolean | undefined +const dispatchToastMouseOver = (element: HTMLElement | SVGElement) => { + element.dispatchEvent(new MouseEvent('mouseover', { + bubbles: true, + })) +} + +const dispatchToastMouseOut = (element: HTMLElement | SVGElement) => { + element.dispatchEvent(new MouseEvent('mouseout', { + bubbles: true, + relatedTarget: document.body, + })) } describe('@langgenius/dify-ui/toast', () => { - beforeAll(() => { - // Base UI waits for `requestAnimationFrame` + `getAnimations().finished` - // before unmounting a toast. Fake timers can't reliably drive that path, - // so short-circuit it to keep auto-dismiss assertions deterministic in CI. - globalThis.BASE_UI_ANIMATIONS_DISABLED = true - }) - beforeEach(() => { vi.clearAllMocks() vi.useFakeTimers() @@ -28,10 +29,6 @@ describe('@langgenius/dify-ui/toast', () => { vi.useRealTimers() }) - afterAll(() => { - globalThis.BASE_UI_ANIMATIONS_DISABLED = undefined - }) - it('should render a success toast when called through the typed shortcut', async () => { const screen = await render() @@ -62,13 +59,13 @@ describe('@langgenius/dify-ui/toast', () => { expect(document.body.querySelectorAll('[role="dialog"]')).toHaveLength(3) expect(document.body.querySelectorAll('button[aria-label="Close notification"][aria-hidden="true"]')).toHaveLength(3) - screen.getByRole('region', { name: 'Notifications' }).element().dispatchEvent(new MouseEvent('mouseover', { - bubbles: true, - })) + const viewport = screen.getByRole('region', { name: 'Notifications' }).element() + dispatchToastMouseOver(viewport) await vi.waitFor(() => { expect(document.body.querySelector('button[aria-label="Close notification"][aria-hidden="true"]')).not.toBeInTheDocument() }) + dispatchToastMouseOut(viewport) }) it('should render a neutral toast when called directly', async () => { @@ -115,11 +112,11 @@ describe('@langgenius/dify-ui/toast', () => { onClose, }) - screen.getByRole('region', { name: 'Notifications' }).element().dispatchEvent(new MouseEvent('mouseover', { - bubbles: true, - })) + const viewport = screen.getByRole('region', { name: 'Notifications' }).element() + dispatchToastMouseOver(viewport) await expect.element(screen.getByRole('button', { name: 'Close notification' })).toBeInTheDocument() + dispatchToastMouseOut(viewport) asHTMLElement(screen.getByRole('button', { name: 'Close notification' }).element()).click() await vi.waitFor(() => { @@ -128,21 +125,6 @@ describe('@langgenius/dify-ui/toast', () => { expect(onClose).toHaveBeenCalledTimes(1) }) - it('should auto dismiss toasts with the Base UI default timeout', async () => { - const screen = await render() - - toast('Default timeout') - await expect.element(screen.getByText('Default timeout')).toBeInTheDocument() - - await vi.advanceTimersByTimeAsync(4999) - expect(document.body).toHaveTextContent('Default timeout') - - await vi.advanceTimersByTimeAsync(1) - await vi.waitFor(() => { - expect(document.body).not.toHaveTextContent('Default timeout') - }) - }) - it('should respect the host timeout configuration', async () => { const screen = await render() diff --git a/packages/dify-ui/vite.config.ts b/packages/dify-ui/vite.config.ts index 5f3533c706..f2a2d24e57 100644 --- a/packages/dify-ui/vite.config.ts +++ b/packages/dify-ui/vite.config.ts @@ -11,6 +11,7 @@ export default defineConfig({ }, test: { globals: true, + setupFiles: ['./vitest.setup.ts'], browser: { enabled: true, provider: playwright(), diff --git a/packages/dify-ui/vitest.setup.ts b/packages/dify-ui/vitest.setup.ts new file mode 100644 index 0000000000..285d6e7760 --- /dev/null +++ b/packages/dify-ui/vitest.setup.ts @@ -0,0 +1,5 @@ +( + globalThis as typeof globalThis & { + BASE_UI_ANIMATIONS_DISABLED: boolean + } +).BASE_UI_ANIMATIONS_DISABLED = true diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/__tests__/provider-config-modal.spec.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/__tests__/provider-config-modal.spec.tsx new file mode 100644 index 0000000000..f9e5ea28ee --- /dev/null +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/__tests__/provider-config-modal.spec.tsx @@ -0,0 +1,346 @@ +import type { AliyunConfig, ArizeConfig, DatabricksConfig, LangFuseConfig, LangSmithConfig, MLflowConfig, OpikConfig, PhoenixConfig, TencentConfig, WeaveConfig } from '../type' +import { toast } from '@langgenius/dify-ui/toast' +import { render, screen, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import { addTracingConfig, removeTracingConfig, updateTracingConfig } from '@/service/apps' +import ConfigBtn from '../config-button' +import ProviderConfigModal from '../provider-config-modal' +import { TracingProvider } from '../type' + +vi.mock('@/service/apps', () => ({ + addTracingConfig: vi.fn(), + removeTracingConfig: vi.fn(), + updateTracingConfig: vi.fn(), +})) + +vi.mock('@langgenius/dify-ui/toast', () => ({ + toast: vi.fn(), +})) + +type ProviderPayload = AliyunConfig | ArizeConfig | DatabricksConfig | LangFuseConfig | LangSmithConfig | MLflowConfig | OpikConfig | PhoenixConfig | TencentConfig | WeaveConfig + +const validConfigs = { + [TracingProvider.arize]: { + api_key: 'arize-api-key', + space_id: 'space-id', + project: 'arize-project', + endpoint: 'https://otlp.arize.com', + }, + [TracingProvider.phoenix]: { + api_key: 'phoenix-api-key', + project: 'phoenix-project', + endpoint: 'https://app.phoenix.arize.com', + }, + [TracingProvider.langSmith]: { + api_key: 'langsmith-api-key', + project: 'langsmith-project', + endpoint: 'https://api.smith.langchain.com', + }, + [TracingProvider.langfuse]: { + public_key: 'public-key', + secret_key: 'secret-key', + host: 'https://cloud.langfuse.com', + }, + [TracingProvider.opik]: { + api_key: 'opik-api-key', + project: 'opik-project', + workspace: 'default', + url: 'https://www.comet.com/opik/api/', + }, + [TracingProvider.weave]: { + api_key: 'weave-api-key', + entity: 'wandb-entity', + project: 'weave-project', + endpoint: 'https://trace.wandb.ai/', + host: 'https://api.wandb.ai', + }, + [TracingProvider.aliyun]: { + app_name: 'aliyun-app', + license_key: 'license-key', + endpoint: 'https://tracing.arms.aliyuncs.com', + }, + [TracingProvider.mlflow]: { + tracking_uri: 'http://localhost:5000', + experiment_id: 'experiment-id', + username: 'mlflow-user', + password: 'mlflow-password', + }, + [TracingProvider.databricks]: { + experiment_id: 'experiment-id', + host: 'https://workspace.cloud.databricks.com', + client_id: 'client-id', + client_secret: 'client-secret', + personal_access_token: 'personal-access-token', + }, + [TracingProvider.tencent]: { + token: 'tencent-token', + endpoint: 'https://your-region.cls.tencentcs.com', + service_name: 'dify_app', + }, +} satisfies Record + +const providerFieldLabels = [ + [TracingProvider.arize, ['API Key', 'Space ID', 'app.tracing.configProvider.project', 'Endpoint']], + [TracingProvider.phoenix, ['API Key', 'app.tracing.configProvider.project', 'Endpoint']], + [TracingProvider.langSmith, ['API Key', 'app.tracing.configProvider.project', 'Endpoint']], + [TracingProvider.langfuse, ['app.tracing.configProvider.secretKey', 'app.tracing.configProvider.publicKey', 'Host']], + [TracingProvider.opik, ['API Key', 'app.tracing.configProvider.project', 'Workspace', 'Url']], + [TracingProvider.weave, ['API Key', 'app.tracing.configProvider.project', 'Entity', 'Endpoint', 'Host']], + [TracingProvider.aliyun, ['License Key', 'Endpoint', 'App Name']], + [TracingProvider.mlflow, ['app.tracing.configProvider.trackingUri', 'app.tracing.configProvider.experimentId', 'app.tracing.configProvider.username', 'app.tracing.configProvider.password']], + [TracingProvider.databricks, ['app.tracing.configProvider.experimentId', 'app.tracing.configProvider.databricksHost', 'app.tracing.configProvider.clientId', 'app.tracing.configProvider.clientSecret', 'app.tracing.configProvider.personalAccessToken']], + [TracingProvider.tencent, ['Token', 'Endpoint', 'Service Name']], +] as const + +const invalidConfigCases: Array<{ + provider: TracingProvider + payload: ProviderPayload + missingField: string +}> = [ + { provider: TracingProvider.arize, payload: { ...validConfigs[TracingProvider.arize], api_key: '' }, missingField: 'API Key' }, + { provider: TracingProvider.arize, payload: { ...validConfigs[TracingProvider.arize], space_id: '' }, missingField: 'Space ID' }, + { provider: TracingProvider.arize, payload: { ...validConfigs[TracingProvider.arize], project: '' }, missingField: 'app.tracing.configProvider.project' }, + { provider: TracingProvider.phoenix, payload: { ...validConfigs[TracingProvider.phoenix], api_key: '' }, missingField: 'API Key' }, + { provider: TracingProvider.phoenix, payload: { ...validConfigs[TracingProvider.phoenix], project: '' }, missingField: 'app.tracing.configProvider.project' }, + { provider: TracingProvider.langSmith, payload: { ...validConfigs[TracingProvider.langSmith], api_key: '' }, missingField: 'API Key' }, + { provider: TracingProvider.langSmith, payload: { ...validConfigs[TracingProvider.langSmith], project: '' }, missingField: 'app.tracing.configProvider.project' }, + { provider: TracingProvider.langfuse, payload: { ...validConfigs[TracingProvider.langfuse], secret_key: '' }, missingField: 'app.tracing.configProvider.secretKey' }, + { provider: TracingProvider.langfuse, payload: { ...validConfigs[TracingProvider.langfuse], public_key: '' }, missingField: 'app.tracing.configProvider.publicKey' }, + { provider: TracingProvider.langfuse, payload: { ...validConfigs[TracingProvider.langfuse], host: '' }, missingField: 'Host' }, + { provider: TracingProvider.weave, payload: { ...validConfigs[TracingProvider.weave], api_key: '' }, missingField: 'API Key' }, + { provider: TracingProvider.weave, payload: { ...validConfigs[TracingProvider.weave], project: '' }, missingField: 'app.tracing.configProvider.project' }, + { provider: TracingProvider.aliyun, payload: { ...validConfigs[TracingProvider.aliyun], app_name: '' }, missingField: 'App Name' }, + { provider: TracingProvider.aliyun, payload: { ...validConfigs[TracingProvider.aliyun], license_key: '' }, missingField: 'License Key' }, + { provider: TracingProvider.aliyun, payload: { ...validConfigs[TracingProvider.aliyun], endpoint: '' }, missingField: 'Endpoint' }, + { provider: TracingProvider.mlflow, payload: { ...validConfigs[TracingProvider.mlflow], tracking_uri: '' }, missingField: 'Tracking URI' }, + { provider: TracingProvider.databricks, payload: { ...validConfigs[TracingProvider.databricks], experiment_id: '' }, missingField: 'Experiment ID' }, + { provider: TracingProvider.databricks, payload: { ...validConfigs[TracingProvider.databricks], host: '' }, missingField: 'Host' }, + { provider: TracingProvider.tencent, payload: { ...validConfigs[TracingProvider.tencent], token: '' }, missingField: 'Token' }, + { provider: TracingProvider.tencent, payload: { ...validConfigs[TracingProvider.tencent], endpoint: '' }, missingField: 'Endpoint' }, + { provider: TracingProvider.tencent, payload: { ...validConfigs[TracingProvider.tencent], service_name: '' }, missingField: 'Service Name' }, +] + +const renderConfigButton = () => { + return render( + + + , + ) +} + +const renderProviderConfigModal = ({ + type = TracingProvider.langfuse, + payload, +}: { + type?: TracingProvider + payload?: ProviderPayload | null +} = {}) => { + const callbacks = { + onCancel: vi.fn(), + onSaved: vi.fn(), + onChosen: vi.fn(), + onRemoved: vi.fn(), + } + + render( + , + ) + + return callbacks +} + +describe('ProviderConfigModal', () => { + beforeEach(() => { + vi.clearAllMocks() + vi.mocked(addTracingConfig).mockResolvedValue({ result: 'success' }) + vi.mocked(updateTracingConfig).mockResolvedValue({ result: 'success' }) + vi.mocked(removeTracingConfig).mockResolvedValue({ result: 'success' }) + }) + + describe('Nested Overlay Behavior', () => { + it('should keep the provider config modal open when clicking inside it', async () => { + const user = userEvent.setup() + renderConfigButton() + + await user.click(screen.getByRole('button', { name: 'Open tracing' })) + await waitFor(() => { + expect(screen.getByText('app.tracing.tracing')).toBeInTheDocument() + }) + + const configActions = screen.getAllByText('app.tracing.config') + expect(configActions.length).toBeGreaterThan(0) + await user.click(configActions[0]!) + await waitFor(() => { + expect(screen.getByText('app.tracing.configProvider.titleapp.tracing.langfuse.title')).toBeInTheDocument() + }) + expect(screen.getByRole('dialog')).toBeInTheDocument() + + await user.click(screen.getByPlaceholderText('https://cloud.langfuse.com')) + + expect(screen.getByText('app.tracing.tracing')).toBeInTheDocument() + expect(screen.getByText('app.tracing.configProvider.titleapp.tracing.langfuse.title')).toBeInTheDocument() + }) + }) + + describe('Rendering', () => { + it.each(providerFieldLabels)('should render %s fields when adding a provider', (provider, expectedLabels) => { + renderProviderConfigModal({ type: provider }) + + expect(screen.getByText(`app.tracing.configProvider.titleapp.tracing.${provider}.title`)).toBeInTheDocument() + expectedLabels.forEach((label) => { + expect(screen.getByText(label)).toBeInTheDocument() + }) + expect(screen.getByRole('button', { name: 'common.operation.saveAndEnable' })).toBeInTheDocument() + }) + }) + + describe('Saving', () => { + it('should add and choose the provider when saving a new config', async () => { + const user = userEvent.setup() + const callbacks = renderProviderConfigModal({ type: TracingProvider.langfuse }) + const textboxes = screen.getAllByRole('textbox') + + await user.type(textboxes[0]!, 'secret-key') + await user.type(textboxes[1]!, 'public-key') + await user.type(textboxes[2]!, 'https://cloud.langfuse.com') + await user.click(screen.getByRole('button', { name: 'common.operation.saveAndEnable' })) + + await waitFor(() => { + expect(addTracingConfig).toHaveBeenCalledWith({ + appId: 'app-id', + body: { + tracing_provider: TracingProvider.langfuse, + tracing_config: validConfigs[TracingProvider.langfuse], + }, + }) + }) + expect(callbacks.onSaved).toHaveBeenCalledWith(validConfigs[TracingProvider.langfuse]) + expect(callbacks.onChosen).toHaveBeenCalledWith(TracingProvider.langfuse) + expect(toast).toHaveBeenCalledWith('common.api.success', { type: 'success' }) + }) + + it.each(Object.values(TracingProvider))('should update valid %s config in edit mode', async (provider) => { + const user = userEvent.setup() + const callbacks = renderProviderConfigModal({ + type: provider, + payload: validConfigs[provider], + }) + + await user.click(screen.getByRole('button', { name: 'common.operation.save' })) + + await waitFor(() => { + expect(updateTracingConfig).toHaveBeenCalledWith({ + appId: 'app-id', + body: { + tracing_provider: provider, + tracing_config: validConfigs[provider], + }, + }) + }) + expect(callbacks.onSaved).toHaveBeenCalledWith(validConfigs[provider]) + expect(callbacks.onChosen).not.toHaveBeenCalled() + }) + + it.each(invalidConfigCases)('should reject $provider config when $missingField is missing', async ({ provider, payload, missingField }) => { + const user = userEvent.setup() + renderProviderConfigModal({ + type: provider, + payload, + }) + + await user.click(screen.getByRole('button', { name: 'common.operation.save' })) + + expect(updateTracingConfig).not.toHaveBeenCalled() + expect(toast).toHaveBeenCalledWith( + expect.stringContaining(missingField), + { type: 'error' }, + ) + }) + }) + + describe('Closing And Removing', () => { + it('should cancel when the cancel button is clicked', async () => { + const user = userEvent.setup() + const callbacks = renderProviderConfigModal() + + await user.click(screen.getByRole('button', { name: 'common.operation.cancel' })) + + expect(callbacks.onCancel).toHaveBeenCalledTimes(1) + }) + + it('should cancel when the dialog is closed with Escape', async () => { + const user = userEvent.setup() + const callbacks = renderProviderConfigModal() + + await user.keyboard('{Escape}') + + await waitFor(() => { + expect(callbacks.onCancel).toHaveBeenCalledTimes(1) + }) + }) + + it('should remove an existing provider after confirmation', async () => { + const user = userEvent.setup() + const callbacks = renderProviderConfigModal({ + type: TracingProvider.langfuse, + payload: validConfigs[TracingProvider.langfuse], + }) + + await user.click(screen.getByRole('button', { name: 'common.operation.remove' })) + expect(screen.getByText('app.tracing.configProvider.removeConfirmTitle:{"key":"app.tracing.langfuse.title"}')).toBeInTheDocument() + + await user.click(screen.getByRole('button', { name: 'common.operation.confirm' })) + + await waitFor(() => { + expect(removeTracingConfig).toHaveBeenCalledWith({ + appId: 'app-id', + provider: TracingProvider.langfuse, + }) + }) + expect(callbacks.onRemoved).toHaveBeenCalledTimes(1) + expect(toast).toHaveBeenCalledWith('common.api.remove', { type: 'success' }) + }) + + it('should return to the edit dialog when remove confirmation is canceled', async () => { + const user = userEvent.setup() + renderProviderConfigModal({ + type: TracingProvider.langfuse, + payload: validConfigs[TracingProvider.langfuse], + }) + + await user.click(screen.getByRole('button', { name: 'common.operation.remove' })) + await user.click(screen.getByRole('button', { name: 'common.operation.cancel' })) + + expect(removeTracingConfig).not.toHaveBeenCalled() + expect(screen.getByText('app.tracing.configProvider.titleapp.tracing.langfuse.title')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-config-modal.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-config-modal.tsx index 4f2497ad71..734b39bd41 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-config-modal.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-config-modal.tsx @@ -11,6 +11,10 @@ import { AlertDialogTitle, } from '@langgenius/dify-ui/alert-dialog' import { Button } from '@langgenius/dify-ui/button' +import { + Dialog, + DialogContent, +} from '@langgenius/dify-ui/dialog' import { toast } from '@langgenius/dify-ui/toast' import { useBoolean } from 'ahooks' import * as React from 'react' @@ -19,10 +23,6 @@ import { useTranslation } from 'react-i18next' import Divider from '@/app/components/base/divider' import { LinkExternal02 } from '@/app/components/base/icons/src/vender/line/general' import { Lock01 } from '@/app/components/base/icons/src/vender/solid/security' -import { - PortalToFollowElem, - PortalToFollowElemContent, -} from '@/app/components/base/portal-to-follow-elem' import { addTracingConfig, removeTracingConfig, updateTracingConfig } from '@/service/apps' import { docURL } from './config' import Field from './field' @@ -153,7 +153,11 @@ const ProviderConfigModal: FC = ({ return weaveConfigTemplate })()) - const [isShowRemoveConfirm, { + const [isConfigDialogOpen, { + set: setIsConfigDialogOpen, + }] = useBoolean(true) + const [isRemoveDialogOpen, { + set: setIsRemoveDialogOpen, setTrue: showRemoveConfirm, setFalse: hideRemoveConfirm, }] = useBoolean(false) @@ -291,13 +295,24 @@ const ProviderConfigModal: FC = ({ } }, [appId, checkValid, config, isAdd, isEdit, isSaving, onChosen, onSaved, t, type]) + // Defer onCancel to onOpenChangeComplete so the dialog's exit animation + // (scale/opacity transition) can finish before the parent unmounts this modal. + const handleConfigDialogOpenChangeComplete = useCallback((open: boolean) => { + if (!open) + onCancel() + }, [onCancel]) + return ( <> - {!isShowRemoveConfirm + {!isRemoveDialogOpen ? ( - - -
+ + +
@@ -650,7 +665,7 @@ const ProviderConfigModal: FC = ({ )} @@ -683,11 +698,11 @@ const ProviderConfigModal: FC = ({
- - + +
) : ( - !open && hideRemoveConfirm()}> +
diff --git a/web/app/account/oauth/authorize/page.tsx b/web/app/account/oauth/authorize/page.tsx index dd95dc04ba..55666db193 100644 --- a/web/app/account/oauth/authorize/page.tsx +++ b/web/app/account/oauth/authorize/page.tsx @@ -16,9 +16,9 @@ import { useEffect, useRef } from 'react' import { useTranslation } from 'react-i18next' import Loading from '@/app/components/base/loading' import { useLanguage } from '@/app/components/header/account-setting/model-provider-page/hooks' -import { setPostLoginRedirect } from '@/app/signin/utils/post-login-redirect' +import { setOAuthPendingRedirect } from '@/app/signin/utils/post-login-redirect' import { useRouter, useSearchParams } from '@/next/navigation' -import { isLegacyBase401, userProfileQueryOptions } from '@/service/use-common' +import { isLegacyBase401, useLogout, userProfileQueryOptions } from '@/service/use-common' import { useAuthorizeOAuthApp, useOAuthAppInfo } from '@/service/use-oauth' function buildReturnUrl(pathname: string, search: string) { @@ -73,14 +73,17 @@ export default function OAuthAuthorize() { const userProfile = userProfileResp?.profile const { data: authAppInfo, isLoading: isOAuthLoading, isError } = useOAuthAppInfo(client_id, redirect_uri) const { mutateAsync: authorize, isPending: authorizing } = useAuthorizeOAuthApp() + const { mutateAsync: logout } = useLogout() const hasNotifiedRef = useRef(false) const isLoading = isOAuthLoading || isProfileLoading - const onLoginSwitchClick = () => { + const onLoginSwitchClick = async () => { try { - const returnUrl = buildReturnUrl('/account/oauth/authorize', `?client_id=${encodeURIComponent(client_id)}&redirect_uri=${encodeURIComponent(redirect_uri)}`) - setPostLoginRedirect(returnUrl) - router.push('/signin') + const returnUrl = buildReturnUrl('/account/oauth/authorize', `?${searchParams.toString()}`) + setOAuthPendingRedirect(returnUrl) + if (isLoggedIn) + await logout() + router.push(`/signin?redirect_url=${encodeURIComponent(returnUrl)}`) } catch { router.push('/signin') diff --git a/web/app/components/app-initializer.tsx b/web/app/components/app-initializer.tsx index 2c50312590..3d2af1ce61 100644 --- a/web/app/components/app-initializer.tsx +++ b/web/app/components/app-initializer.tsx @@ -85,7 +85,7 @@ export const AppInitializer = ({ return } - const redirectUrl = resolvePostLoginRedirect() + const redirectUrl = resolvePostLoginRedirect(searchParams) if (redirectUrl) { location.replace(redirectUrl) return diff --git a/web/app/components/app-sidebar/index.tsx b/web/app/components/app-sidebar/index.tsx index 7b69c94b47..e0942a9706 100644 --- a/web/app/components/app-sidebar/index.tsx +++ b/web/app/components/app-sidebar/index.tsx @@ -17,6 +17,15 @@ import DatasetSidebarDropdown from './dataset-sidebar-dropdown' import NavLink from './nav-link' import ToggleButton from './toggle-button' +const isShortcutFromInputArea = (target: EventTarget | null) => { + if (!(target instanceof HTMLElement)) + return false + + return target.tagName === 'INPUT' + || target.tagName === 'TEXTAREA' + || target.isContentEditable +} + type IAppDetailNavProps = { iconType?: 'app' | 'dataset' navigation: Array<{ @@ -74,6 +83,9 @@ const AppDetailNav = ({ }, [appSidebarExpand, setAppSidebarExpand]) useKeyPress(`${getKeyboardKeyCodeBySystem('ctrl')}.b`, (e) => { + if (isShortcutFromInputArea(e.target)) + return + e.preventDefault() handleToggle() }, { exactMatch: true, useCapture: true }) diff --git a/web/app/components/app/app-publisher/__tests__/index.spec.tsx b/web/app/components/app/app-publisher/__tests__/index.spec.tsx index 9615acb309..cae41f7005 100644 --- a/web/app/components/app/app-publisher/__tests__/index.spec.tsx +++ b/web/app/components/app/app-publisher/__tests__/index.spec.tsx @@ -91,8 +91,11 @@ vi.mock('@/service/explore', () => ({ fetchInstalledAppList: (...args: unknown[]) => mockFetchInstalledAppList(...args), })) +const mockPublishToCreatorsPlatform = vi.fn() + vi.mock('@/service/apps', () => ({ fetchAppDetailDirect: (...args: unknown[]) => mockFetchAppDetailDirect(...args), + publishToCreatorsPlatform: (...args: unknown[]) => mockPublishToCreatorsPlatform(...args), })) vi.mock('@/service/use-apps', () => ({ @@ -469,6 +472,76 @@ describe('AppPublisher', () => { }) }) + it('should show marketplace button and open redirect URL on success', async () => { + mockPublishToCreatorsPlatform.mockResolvedValue({ redirect_url: 'https://marketplace.example.com/publish?code=abc' }) + const windowOpenSpy = vi.spyOn(window, 'open').mockImplementation(() => null) + + renderWithSystemFeatures( + , + { systemFeatures: { webapp_auth: { enabled: true }, enable_creators_platform: true } }, + ) + + fireEvent.click(screen.getByText('common.publish')) + fireEvent.click(screen.getByText('common.publishToMarketplace')) + + await waitFor(() => { + expect(mockPublishToCreatorsPlatform).toHaveBeenCalledWith({ appID: 'app-1' }) + expect(windowOpenSpy).toHaveBeenCalledWith('https://marketplace.example.com/publish?code=abc', '_blank') + }) + + windowOpenSpy.mockRestore() + }) + + it('should show toast error when publish to marketplace fails', async () => { + mockPublishToCreatorsPlatform.mockRejectedValue(new Error('network error')) + + renderWithSystemFeatures( + , + { systemFeatures: { webapp_auth: { enabled: true }, enable_creators_platform: true } }, + ) + + fireEvent.click(screen.getByText('common.publish')) + fireEvent.click(screen.getByText('common.publishToMarketplace')) + + await waitFor(() => { + expect(mockToastError).toHaveBeenCalledWith('common.publishToMarketplaceFailed') + }) + }) + + it('should disable marketplace button when not yet published', () => { + renderWithSystemFeatures( + , + { systemFeatures: { webapp_auth: { enabled: true }, enable_creators_platform: true } }, + ) + + fireEvent.click(screen.getByText('common.publish')) + const marketplaceButton = screen.getByText('common.publishToMarketplace').closest('a, button, div[role="button"]') as HTMLElement + expect(marketplaceButton).toBeInTheDocument() + // clicking should not call the API because publishedAt is undefined + fireEvent.click(screen.getByText('common.publishToMarketplace')) + expect(mockPublishToCreatorsPlatform).not.toHaveBeenCalled() + }) + + it('should hide marketplace button when enable_creators_platform is false', () => { + render( + , + ) + + fireEvent.click(screen.getByText('common.publish')) + expect(screen.queryByText('common.publishToMarketplace')).not.toBeInTheDocument() + }) + it('should keep access control open when app detail is unavailable during confirmation', async () => { mockAppDetail = null diff --git a/web/app/components/app/app-publisher/index.tsx b/web/app/components/app/app-publisher/index.tsx index 6ef09917e7..fc750971c1 100644 --- a/web/app/components/app/app-publisher/index.tsx +++ b/web/app/components/app/app-publisher/index.tsx @@ -29,7 +29,7 @@ import { useFormatTimeFromNow } from '@/hooks/use-format-time-from-now' import { useSnippetAndEvaluationPlanAccess } from '@/hooks/use-snippet-and-evaluation-plan-access' import { AccessMode } from '@/models/access-control' import { useAppWhiteListSubjects, useGetUserCanAccessApp } from '@/service/access-control' -import { fetchAppDetailDirect } from '@/service/apps' +import { fetchAppDetailDirect, publishToCreatorsPlatform } from '@/service/apps' import { fetchInstalledAppList } from '@/service/explore' import { systemFeaturesQueryOptions } from '@/service/system-features' import { useConvertWorkflowTypeMutation } from '@/service/use-apps' @@ -46,6 +46,7 @@ import { PublisherActionsSection, PublisherSummarySection, } from './sections' +import SuggestedAction from './suggested-action' import { getDisabledFunctionTooltip, getPublisherAppUrl, @@ -134,6 +135,7 @@ const AppPublisher = ({ const [evaluationWorkflowSwitchTargets, setEvaluationWorkflowSwitchTargets] = useState([]) const [embeddingModalOpen, setEmbeddingModalOpen] = useState(false) + const [publishingToMarketplace, setPublishingToMarketplace] = useState(false) const workflowStore = useContext(WorkflowContext) const appDetail = useAppStore(state => state.appDetail) @@ -385,6 +387,22 @@ const AppPublisher = ({ if (!nextOpen) setEvaluationWorkflowSwitchTargets([]) }, []) + const handlePublishToMarketplace = useCallback(async () => { + if (!appDetail?.id || publishingToMarketplace) + return + setPublishingToMarketplace(true) + try { + const res = await publishToCreatorsPlatform({ appID: appDetail.id }) + if (res.redirect_url) + window.open(res.redirect_url, '_blank') + } + catch { + toast.error(t('common.publishToMarketplaceFailed', { ns: 'workflow' })) + } + finally { + setPublishingToMarketplace(false) + } + }, [appDetail?.id, publishingToMarketplace, t]) useKeyPress(`${getKeyboardKeyCodeBySystem('ctrl')}.shift.p`, (e) => { e.preventDefault() @@ -509,6 +527,19 @@ const AppPublisher = ({ workflowToolAvailable={workflowToolAvailable} workflowToolMessage={workflowToolMessage} /> + {systemFeatures.enable_creators_platform && ( +
+ } + disabled={!publishedAt || publishingToMarketplace} + onClick={handlePublishToMarketplace} + > + {publishingToMarketplace + ? t('common.publishingToMarketplace', { ns: 'workflow' }) + : t('common.publishToMarketplace', { ns: 'workflow' })} + +
+ )} )}
diff --git a/web/app/components/app/create-from-dsl-modal/__tests__/index.spec.tsx b/web/app/components/app/create-from-dsl-modal/__tests__/index.spec.tsx index f56b815399..b3d007c2e7 100644 --- a/web/app/components/app/create-from-dsl-modal/__tests__/index.spec.tsx +++ b/web/app/components/app/create-from-dsl-modal/__tests__/index.spec.tsx @@ -137,7 +137,7 @@ describe('CreateFromDSLModal', () => { />, ) - expect(screen.getByText('importFromDSL'))!.toBeInTheDocument() + expect(screen.getByText('importApp'))!.toBeInTheDocument() await waitFor(() => { expect(screen.getByText('demo.yml'))!.toBeInTheDocument() @@ -161,7 +161,7 @@ describe('CreateFromDSLModal', () => { }) expect(screen.getByPlaceholderText('importFromDSLUrlPlaceholder'))!.toBeInTheDocument() - const closeTrigger = screen.getByText('importFromDSL').parentElement?.querySelector('.cursor-pointer.items-center') as HTMLElement + const closeTrigger = screen.getByText('importApp').parentElement?.querySelector('.cursor-pointer.items-center') as HTMLElement fireEvent.click(closeTrigger) expect(handleClose).toHaveBeenCalledTimes(1) }) diff --git a/web/app/components/app/create-from-dsl-modal/index.tsx b/web/app/components/app/create-from-dsl-modal/index.tsx index 4f99fe9027..bc5f352634 100644 --- a/web/app/components/app/create-from-dsl-modal/index.tsx +++ b/web/app/components/app/create-from-dsl-modal/index.tsx @@ -225,7 +225,7 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS onClose={noop} >
- {t('importFromDSL', { ns: 'app' })} + {t('importApp', { ns: 'app' })}
onClose()} diff --git a/web/app/components/apps/__tests__/index.spec.tsx b/web/app/components/apps/__tests__/index.spec.tsx index 2e0d1bcc84..94fa9f3484 100644 --- a/web/app/components/apps/__tests__/index.spec.tsx +++ b/web/app/components/apps/__tests__/index.spec.tsx @@ -7,9 +7,21 @@ import { useContextSelector } from 'use-context-selector' import AppListContext from '@/context/app-list-context' import { fetchAppDetail } from '@/service/explore' import { AppModeEnum } from '@/types/app' - import Apps from '../index' +vi.mock('@/next/dynamic', () => ({ + default: (loader: () => Promise<{ default: React.ComponentType }>) => { + const LazyComp = React.lazy(loader) + return function DynamicWrapper(props: Record) { + return React.createElement( + React.Suspense, + { fallback: null }, + React.createElement(LazyComp, props), + ) + } + }, +})) + let documentTitleCalls: string[] = [] let educationInitCalls: number = 0 const mockHandleImportDSL = vi.fn() @@ -65,6 +77,16 @@ vi.mock('@/hooks/use-import-dsl', () => ({ }), })) +const mockReplace = vi.fn() +let mockSearchParams = new URLSearchParams() + +vi.mock('@/next/navigation', () => ({ + useRouter: () => ({ + replace: mockReplace, + }), + useSearchParams: () => mockSearchParams, +})) + vi.mock('../list', () => { const MockList = () => { const setShowTryAppPanel = useContextSelector(AppListContext, ctx => ctx.setShowTryAppPanel) @@ -129,6 +151,16 @@ vi.mock('../../app/create-from-dsl-modal/dsl-confirm-modal', () => ({ ), })) +vi.mock('../import-from-marketplace-template-modal', () => ({ + default: ({ templateId, onClose, onConfirm }: { templateId: string, onClose: () => void, onConfirm: (dsl: string) => void }) => ( +
+ {templateId} + + +
+ ), +})) + vi.mock('@/service/explore', () => ({ fetchAppDetail: vi.fn(), })) @@ -161,6 +193,8 @@ describe('Apps', () => { vi.clearAllMocks() documentTitleCalls = [] educationInitCalls = 0 + mockSearchParams = new URLSearchParams() + mockReplace.mockClear() mockFetchAppDetail.mockResolvedValue({ id: 'template-1', name: 'Sample App', @@ -304,6 +338,66 @@ describe('Apps', () => { }) }) + describe('Marketplace Template', () => { + it('should render the template modal when template-id is in search params', async () => { + mockSearchParams = new URLSearchParams('template-id=tpl-42') + renderWithClient() + + expect(await screen.findByTestId('marketplace-template-modal')).toBeInTheDocument() + expect(screen.getByTestId('template-id')).toHaveTextContent('tpl-42') + }) + + it('should not render the template modal when no template-id is present', () => { + renderWithClient() + + expect(screen.queryByTestId('marketplace-template-modal')).not.toBeInTheDocument() + }) + + it('should close the template modal and remove template-id from URL', async () => { + mockSearchParams = new URLSearchParams('template-id=tpl-42') + renderWithClient() + + fireEvent.click(await screen.findByTestId('close-template')) + + expect(mockReplace).toHaveBeenCalledTimes(1) + const replaceArg = mockReplace.mock.calls[0]![0] as string + expect(replaceArg).not.toContain('template-id') + }) + + it('should import DSL from marketplace template on confirm', async () => { + mockHandleImportDSL.mockImplementation(async (_payload: unknown, options: { onSuccess?: () => void }) => { + options.onSuccess?.() + }) + mockSearchParams = new URLSearchParams('template-id=tpl-42') + renderWithClient() + + fireEvent.click(await screen.findByTestId('confirm-template')) + + await waitFor(() => { + expect(mockHandleImportDSL).toHaveBeenCalledWith( + { mode: 'yaml-content', yaml_content: 'yaml-dsl-content' }, + expect.objectContaining({ onSuccess: expect.any(Function) }), + ) + expect(mockReplace).toHaveBeenCalled() + }) + }) + + it('should show DSL confirm modal when marketplace import is pending', async () => { + mockHandleImportDSL.mockImplementation(async (_payload: unknown, options: { onPending?: () => void }) => { + options.onPending?.() + }) + mockSearchParams = new URLSearchParams('template-id=tpl-42') + renderWithClient() + + fireEvent.click(await screen.findByTestId('confirm-template')) + + await waitFor(() => { + expect(screen.getByTestId('dsl-confirm-modal')).toBeInTheDocument() + expect(mockReplace).toHaveBeenCalled() + }) + }) + }) + describe('Styling', () => { it('should have overflow-y-auto class', () => { const { container } = renderWithClient() diff --git a/web/app/components/apps/import-from-marketplace-template-modal.tsx b/web/app/components/apps/import-from-marketplace-template-modal.tsx new file mode 100644 index 0000000000..a6a3dee8e4 --- /dev/null +++ b/web/app/components/apps/import-from-marketplace-template-modal.tsx @@ -0,0 +1,182 @@ +'use client' + +import { Button } from '@langgenius/dify-ui/button' +import { Dialog, DialogContent } from '@langgenius/dify-ui/dialog' +import { toast } from '@langgenius/dify-ui/toast' +import { RiCloseLine } from '@remixicon/react' +import { useCallback, useMemo, useRef, useState } from 'react' +import { useTranslation } from 'react-i18next' +import { MARKETPLACE_API_PREFIX } from '@/config' +import { + fetchMarketplaceTemplateDSL, + useMarketplaceTemplateDetail, +} from '@/service/marketplace-templates' + +type ImportFromMarketplaceTemplateModalProps = { + templateId: string + onClose: () => void + onConfirm: (dslContent: string) => void +} + +const ImportFromMarketplaceTemplateModal = ({ + templateId, + onClose, + onConfirm, +}: ImportFromMarketplaceTemplateModalProps) => { + const { t } = useTranslation() + const { data, isLoading, isError } = useMarketplaceTemplateDetail(templateId) + const template = data?.data + const [importing, setImporting] = useState(false) + const isImportingRef = useRef(false) + + const CATEGORY_I18N_MAP: Record = useMemo(() => ({ + marketing: t('marketplace.template.category.marketing', { ns: 'app' }), + sales: t('marketplace.template.category.sales', { ns: 'app' }), + support: t('marketplace.template.category.support', { ns: 'app' }), + operations: t('marketplace.template.category.operations', { ns: 'app' }), + it: t('marketplace.template.category.it', { ns: 'app' }), + knowledge: t('marketplace.template.category.knowledge', { ns: 'app' }), + design: t('marketplace.template.category.design', { ns: 'app' }), + }), [t]) + + const translateCategory = useCallback((slug: string) => { + return CATEGORY_I18N_MAP[slug] ?? slug + }, [CATEGORY_I18N_MAP]) + + const handleConfirm = useCallback(async () => { + if (isImportingRef.current) + return + isImportingRef.current = true + setImporting(true) + try { + const dsl = await fetchMarketplaceTemplateDSL(templateId) + onConfirm(dsl) + } + catch { + toast.error(t('marketplace.template.importFailed', { ns: 'app' })) + } + finally { + setImporting(false) + isImportingRef.current = false + } + }, [templateId, onConfirm, t]) + + return ( + { + if (!open) + onClose() + }} + > + +
+ {t('marketplace.template.modalTitle', { ns: 'app' })} +
+ +
+
+ +
+ {isLoading && ( +
+
Loading...
+
+ )} + + {isError && ( +
+
+ {t('marketplace.template.fetchFailed', { ns: 'app' })} +
+
+ )} + + {template && ( +
+
+ {template.icon_file_key + ? ( + {template.template_name} + ) + : ( +
+ {template.icon || '📄'} +
+ )} +
+
{template.template_name}
+
+ + {t('marketplace.template.publishedBy', { ns: 'app' })} + {' '} + {template.publisher_unique_handle} + + · + + {t('marketplace.template.usageCount', { ns: 'app' })} + {' '} + {template.usage_count} + +
+
+
+ + {template.overview && ( +
+
+ {t('marketplace.template.overview', { ns: 'app' })} +
+
+ {template.overview} +
+
+ )} + + {template.categories.length > 0 && ( +
+ {template.categories.map(cat => ( + + {translateCategory(cat)} + + ))} +
+ )} +
+ )} +
+ +
+ + +
+
+
+ ) +} + +export default ImportFromMarketplaceTemplateModal diff --git a/web/app/components/apps/index.tsx b/web/app/components/apps/index.tsx index 9f23e42bb9..9623299336 100644 --- a/web/app/components/apps/index.tsx +++ b/web/app/components/apps/index.tsx @@ -9,6 +9,7 @@ import useDocumentTitle from '@/hooks/use-document-title' import { useImportDSL } from '@/hooks/use-import-dsl' import { DSLImportMode } from '@/models/app' import dynamic from '@/next/dynamic' +import { useRouter, useSearchParams } from '@/next/navigation' import { fetchAppDetail } from '@/service/explore' import { trackCreateApp } from '@/utils/create-app-tracking' import List from './list' @@ -22,11 +23,16 @@ type AppsProps = { const DSLConfirmModal = dynamic(() => import('../app/create-from-dsl-modal/dsl-confirm-modal'), { ssr: false }) const CreateAppModal = dynamic(() => import('../explore/create-app-modal'), { ssr: false }) const TryApp = dynamic(() => import('../explore/try-app'), { ssr: false }) +const ImportFromMarketplaceTemplateModal = dynamic(() => import('./import-from-marketplace-template-modal'), { ssr: false }) const Apps = ({ pageType = 'apps', }: AppsProps) => { const { t } = useTranslation() + const searchParams = useSearchParams() + const { replace } = useRouter() + const templateId = searchParams.get('template-id') + const templateDismissedRef = useRef(false) useDocumentTitle(pageType === 'apps' ? t('menus.apps', { ns: 'common' }) @@ -68,6 +74,14 @@ const Apps = ({ const [showDSLConfirmModal, setShowDSLConfirmModal] = useState(false) + const handleCloseTemplateModal = useCallback(() => { + templateDismissedRef.current = true + const params = new URLSearchParams(searchParams.toString()) + params.delete('template-id') + const query = params.toString() + replace(query ? `?${query}` : window.location.pathname, { scroll: false }) + }, [searchParams, replace]) + const { handleImportDSL, handleImportDSLConfirm, @@ -84,6 +98,22 @@ const Apps = ({ }) }, [handleImportDSLConfirm, onSuccess, trackCurrentCreateApp]) + const handleMarketplaceTemplateConfirm = useCallback(async (dslContent: string) => { + await handleImportDSL({ + mode: DSLImportMode.YAML_CONTENT, + yaml_content: dslContent, + }, { + onSuccess: () => { + handleCloseTemplateModal() + onSuccess() + }, + onPending: () => { + handleCloseTemplateModal() + setShowDSLConfirmModal(true) + }, + }) + }, [handleImportDSL, handleCloseTemplateModal, onSuccess]) + const onCreate: CreateAppModalProps['onConfirm'] = useCallback(async ({ name, icon_type, @@ -162,6 +192,14 @@ const Apps = ({ onHide={() => setIsShowCreateModal(false)} /> )} + + {templateId && !templateDismissedRef.current && ( + + )}
) diff --git a/web/app/components/base/app-icon-picker/__tests__/index.spec.tsx b/web/app/components/base/app-icon-picker/__tests__/index.spec.tsx index 07dd809f41..7f452e64e9 100644 --- a/web/app/components/base/app-icon-picker/__tests__/index.spec.tsx +++ b/web/app/components/base/app-icon-picker/__tests__/index.spec.tsx @@ -1,3 +1,4 @@ +import type { ComponentProps } from 'react' import type { Area } from 'react-easy-crop' import type { ImageFile } from '@/types/app' import { fireEvent, render, screen, waitFor } from '@testing-library/react' @@ -122,11 +123,11 @@ describe('AppIconPicker', () => { }) } - const renderPicker = () => { + const renderPicker = (props: Partial> = {}) => { const onSelect = vi.fn() const onClose = vi.fn() - const { container } = render() + const { container } = render() return { onSelect, onClose, container } } @@ -220,6 +221,20 @@ describe('AppIconPicker', () => { expect(onSelect).not.toHaveBeenCalled() }) + + it('should submit the initial emoji when provided', async () => { + const { onSelect } = renderPicker({ initialEmoji: { icon: 'rabbit', background: '#E4FBCC' } }) + + await userEvent.click(screen.getByText(/ok/i)) + + await waitFor(() => { + expect(onSelect).toHaveBeenCalledWith({ + type: 'emoji', + icon: 'rabbit', + background: '#E4FBCC', + }) + }) + }) }) describe('Image Upload', () => { diff --git a/web/app/components/base/app-icon-picker/index.tsx b/web/app/components/base/app-icon-picker/index.tsx index 77bc0cd434..64a88f16e1 100644 --- a/web/app/components/base/app-icon-picker/index.tsx +++ b/web/app/components/base/app-icon-picker/index.tsx @@ -34,12 +34,17 @@ export type AppIconSelection = AppIconEmojiSelection | AppIconImageSelection type AppIconPickerProps = { onSelect?: (payload: AppIconSelection) => void onClose?: () => void + initialEmoji?: { + icon: string + background?: string | null + } className?: string } const AppIconPicker: FC = ({ onSelect, onClose, + initialEmoji, className, }) => { const { t } = useTranslation() @@ -138,7 +143,14 @@ const AppIconPicker: FC = ({
)} - {activeTab === 'emoji' && } + {activeTab === 'emoji' && ( + + )} {activeTab === 'image' && } diff --git a/web/app/components/base/emoji-picker/Inner.tsx b/web/app/components/base/emoji-picker/Inner.tsx index e2595c5efb..36a98f7dd1 100644 --- a/web/app/components/base/emoji-picker/Inner.tsx +++ b/web/app/components/base/emoji-picker/Inner.tsx @@ -45,20 +45,21 @@ type IEmojiPickerInnerProps = { } const EmojiPickerInner: FC = ({ + emoji, + background, onSelect, className, }) => { const { categories } = data as EmojiMartData - const [selectedEmoji, setSelectedEmoji] = useState('') - const [selectedBackground, setSelectedBackground] = useState(backgroundColors[0]) - const [showStyleColors, setShowStyleColors] = useState(false) + const [selectedEmoji, setSelectedEmoji] = useState(emoji || '') + const [selectedBackground, setSelectedBackground] = useState(background || backgroundColors[0]) + const [showStyleColors, setShowStyleColors] = useState(!!emoji) const [searchedEmojis, setSearchedEmojis] = useState([]) const [isSearching, setIsSearching] = useState(false) React.useEffect(() => { if (selectedEmoji) { - setShowStyleColors(true) /* v8 ignore next 2 - @preserve */ if (selectedBackground) onSelect?.(selectedEmoji, selectedBackground) @@ -105,6 +106,7 @@ const EmojiPickerInner: FC = ({ className="inline-flex h-10 w-10 items-center justify-center rounded-lg" onClick={() => { setSelectedEmoji(emoji) + setShowStyleColors(true) }} >
@@ -130,6 +132,7 @@ const EmojiPickerInner: FC = ({ className="inline-flex h-10 w-10 items-center justify-center rounded-lg" onClick={() => { setSelectedEmoji(emoji) + setShowStyleColors(true) }} >
diff --git a/web/app/components/base/emoji-picker/__tests__/Inner.spec.tsx b/web/app/components/base/emoji-picker/__tests__/Inner.spec.tsx index f0cf3091d7..41683d7af3 100644 --- a/web/app/components/base/emoji-picker/__tests__/Inner.spec.tsx +++ b/web/app/components/base/emoji-picker/__tests__/Inner.spec.tsx @@ -45,6 +45,15 @@ describe('EmojiPickerInner', () => { expect(screen.getByText('food'))!.toBeInTheDocument() expect(screen.getByPlaceholderText('Search emojis...'))!.toBeInTheDocument() }) + + it('initializes selected emoji and background when provided', async () => { + render() + + expect(screen.getByText('Choose Style'))!.toBeInTheDocument() + await waitFor(() => { + expect(mockOnSelect).toHaveBeenCalledWith('rabbit', '#E4FBCC') + }) + }) }) describe('User Interactions', () => { diff --git a/web/app/components/base/infotip/index.tsx b/web/app/components/base/infotip/index.tsx index b97b499af3..ce818fe030 100644 --- a/web/app/components/base/infotip/index.tsx +++ b/web/app/components/base/infotip/index.tsx @@ -73,7 +73,7 @@ export function Infotip({ /> {children} diff --git a/web/app/components/base/prompt-editor/plugins/component-picker-block/__tests__/index.spec.tsx b/web/app/components/base/prompt-editor/plugins/component-picker-block/__tests__/index.spec.tsx index 3c734700a7..a09e25f6e9 100644 --- a/web/app/components/base/prompt-editor/plugins/component-picker-block/__tests__/index.spec.tsx +++ b/web/app/components/base/prompt-editor/plugins/component-picker-block/__tests__/index.spec.tsx @@ -599,6 +599,48 @@ describe('ComponentPicker (component-picker-block/index.tsx)', () => { }) }) + it('defaults to the first workflow variable and removes the full slash query when selecting by keyboard', async () => { + const captures: Captures = { editor: null, eventEmitter: null } + + const workflowVariableBlock = makeWorkflowVariableBlock({}, [ + makeWorkflowVarNode('node-1', 'Node 1', [ + makeWorkflowNodeVar('first_value', VarType.string), + makeWorkflowNodeVar('second_value', VarType.string), + ]), + ]) + + render(( + + )) + + const editor = await waitForEditor(captures) + const dispatchSpy = vi.spyOn(editor, 'dispatchCommand') + + await setEditorText(editor, '/e', true) + await flushNextTick() + + const firstItem = screen.getByText('first_value').closest('[data-selected]') + const secondItem = screen.getByText('second_value').closest('[data-selected]') + + expect(firstItem).toHaveAttribute('data-selected', 'true') + expect(secondItem).toHaveAttribute('data-selected', 'false') + + fireEvent.keyDown(document, { key: 'ArrowDown' }) + + expect(firstItem).toHaveAttribute('data-selected', 'false') + expect(secondItem).toHaveAttribute('data-selected', 'true') + + fireEvent.keyDown(document, { key: 'Enter' }) + + expect(dispatchSpy).toHaveBeenCalledWith(INSERT_WORKFLOW_VARIABLE_BLOCK_COMMAND, ['node-1', 'second_value']) + await waitFor(() => expect(readEditorText(editor)).not.toContain('/e')) + }) + it('skips removing the trigger when selection is null (needRemove is null) and still dispatches', async () => { const captures: Captures = { editor: null, eventEmitter: null } diff --git a/web/app/components/base/prompt-editor/plugins/component-picker-block/index.tsx b/web/app/components/base/prompt-editor/plugins/component-picker-block/index.tsx index 5e983ed09a..503af4077d 100644 --- a/web/app/components/base/prompt-editor/plugins/component-picker-block/index.tsx +++ b/web/app/components/base/prompt-editor/plugins/component-picker-block/index.tsx @@ -7,6 +7,7 @@ import type { ExternalToolBlockType, HistoryBlockType, LastRunBlockType, + MenuTextMatch, QueryBlockType, RequestURLBlockType, VariableBlockType, @@ -89,14 +90,14 @@ const ComponentPicker = ({ ], }) const [editor] = useLexicalComposerContext() - const triggerMatchRef = useRef(null) + const triggerMatchRef = useRef(null) const baseCheckForTriggerMatch = useBasicTypeaheadTriggerMatch(triggerString, { minLength: 0, maxLength: 75, }) const checkForTriggerMatch = useCallback((text: string, editor: LexicalEditor) => { const match = baseCheckForTriggerMatch(text, editor) - triggerMatchRef.current = match?.matchingString ?? null + triggerMatchRef.current = match return match }, [baseCheckForTriggerMatch]) @@ -183,7 +184,8 @@ const ComponentPicker = ({ const handleSelectWorkflowVariable = useCallback((variables: string[]) => { editor.update(() => { - const needRemove = $splitNodeContainingQuery(checkForTriggerMatch(triggerString, editor)!) + const currentTriggerMatch = triggerMatchRef.current ?? checkForTriggerMatch(triggerString, editor) + const needRemove = currentTriggerMatch ? $splitNodeContainingQuery(currentTriggerMatch) : null if (needRemove) needRemove.remove() }) @@ -214,7 +216,7 @@ const ComponentPicker = ({ anchorElementRef, { options, selectedIndex, selectOptionAndCleanUp, setHighlightedIndex }, ) => { - const effectiveQueryString = triggerMatchRef.current ?? queryString + const effectiveQueryString = triggerMatchRef.current?.matchingString ?? queryString if (blurHidden) return null diff --git a/web/app/components/develop/template/template_advanced_chat.en.mdx b/web/app/components/develop/template/template_advanced_chat.en.mdx index bdfe7a41c1..d9ee9bcc1e 100644 --- a/web/app/components/develop/template/template_advanced_chat.en.mdx +++ b/web/app/components/develop/template/template_advanced_chat.en.mdx @@ -191,6 +191,24 @@ Chat applications support session persistence, allowing previous chat history to - `total_price` (decimal) optional Total cost - `currency` (string) optional e.g. `USD` / `RMB` - `created_at` (timestamp) timestamp of start, e.g., 1705395332 + - `event: human_input_required` Workflow paused and requires Human-in-the-Loop input + - `task_id` (string) Task ID, used for request tracking + - `workflow_run_id` (string) Unique ID of workflow execution + - `event` (string) fixed to `human_input_required` + - `data` (object) detail + - `form_id` (string) Human input form ID + - `node_id` (string) Human input node ID + - `node_title` (string) Human input node title + - `form_content` (string) Rendered form content + - `inputs` (array[object]) Input field definitions + - `actions` (array[object]) User action buttons + - `id` (string) Action ID + - `title` (string) Button text + - `button_style` (string) Button style + - `display_in_ui` (bool) Whether this form should be shown in UI + - `form_token` (string) Token used by `/form/human_input/:form_token` APIs + - `resolved_default_values` (object) Runtime-resolved default values + - `expiration_time` (timestamp) Form expiration time (Unix seconds) - `event: workflow_finished` workflow execution ends, success or failure in different states in the same event - `task_id` (string) Task ID, used for request tracking and the below Stop Generate API - `workflow_run_id` (string) Unique ID of workflow execution @@ -254,6 +272,12 @@ Chat applications support session persistence, allowing previous chat history to }'`} /> ### Blocking Mode + Blocking mode can return a normal chat message or a paused workflow response. + + When advanced chat pauses for Human-in-the-Loop, `event` becomes `workflow_paused`. + The payload still includes `message_id`, `conversation_id`, `answer`, and `workflow_run_id`, and `data` adds `paused_nodes` plus `reasons`. + For `human_input_required`, each reason contains the `form_id` and its `expiration_time`. + ```json {{ title: 'Response' }} { @@ -296,6 +320,83 @@ Chat applications support session persistence, allowing previous chat history to } ``` + + ```json {{ title: 'Paused Response Example' }} + { + "event": "workflow_paused", + "task_id": "8a9cbfcf-e7e0-4b17-aeef-24de57a2659a", + "id": "31714374-88cb-485f-9fa4-e3ab2a9ed95e", + "message_id": "31714374-88cb-485f-9fa4-e3ab2a9ed95e", + "conversation_id": "098e19be-356a-435d-9ec3-a406f4f1a97a", + "mode": "advanced-chat", + "answer": "", + "metadata": { + "annotation_reply": null, + "retriever_resources": [], + "usage": null + }, + "created_at": 1776074715, + "workflow_run_id": "7a4d6509-8a65-4c7d-a4fd-cf081dcf169f", + "data": { + "id": "31714374-88cb-485f-9fa4-e3ab2a9ed95e", + "mode": "advanced-chat", + "conversation_id": "098e19be-356a-435d-9ec3-a406f4f1a97a", + "message_id": "31714374-88cb-485f-9fa4-e3ab2a9ed95e", + "workflow_run_id": "7a4d6509-8a65-4c7d-a4fd-cf081dcf169f", + "answer": "", + "metadata": { + "annotation_reply": null, + "retriever_resources": [], + "usage": null + }, + "created_at": 1776074715, + "paused_nodes": [ + "1775724080699" + ], + "reasons": [ + { + "form_id": "019d864d-6f55-752c-9f4c-feee67508d5b", + "form_content": "this is form 2:\n\n{{#$output.some_field_2#}}\n", + "inputs": [ + { + "type": "paragraph", + "output_variable_name": "some_field_2", + "default": { + "type": "constant", + "selector": [], + "value": "" + } + } + ], + "actions": [ + { + "id": "approve", + "title": "yes", + "button_style": "default" + }, + { + "id": "reject", + "title": "no", + "button_style": "default" + } + ], + "display_in_ui": true, + "node_id": "1775724080699", + "node_title": "Human Input 2", + "resolved_default_values": {}, + "form_token": "0dvwTdpTFXgCZmAo2FoiJ5", + "type": "human_input_required", + "expiration_time": 1776333914 + } + ], + "status": "paused", + "elapsed_time": 0.034081, + "total_tokens": 0, + "total_steps": 2 + } + } + ``` + ### Streaming Mode ```streaming {{ title: 'Response' }} @@ -314,6 +415,220 @@ Chat applications support session persistence, allowing previous chat history to data: {"event": "tts_message_end", "conversation_id": "23dd85f3-1a41-4ea0-b7a9-062734ccfaf9", "message_id": "a8bdc41c-13b2-4c18-bfd9-054b9803038c", "created_at": 1721205487, "task_id": "3bf8a0bb-e73b-4690-9e66-4e429bad8ee7", "audio": ""} ``` + Streaming mode can also pause for Human-in-the-Loop. In that case, the SSE stream emits `human_input_required` first and then `workflow_paused`. + + + ```streaming {{ title: 'Paused Streaming Response Example' }} + event: ping + + data: { + "event": "workflow_started", + "workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "task_id": "0399c5c2-181b-4493-a78e-1421914e8a25", + "data": { + "id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "workflow_id": "e46514f1-c008-41ff-94b0-4f33d4b97d36", + "inputs": { + "sys.files": [], + "sys.user_id": "abc-123", + "sys.app_id": "d1074979-f67e-4114-8691-e35878df9a89", + "sys.workflow_id": "e46514f1-c008-41ff-94b0-4f33d4b97d36", + "sys.workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345" + }, + "created_at": 1776129228, + "reason": "initial" + } + } + + data: { + "event": "node_started", + "workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "task_id": "0399c5c2-181b-4493-a78e-1421914e8a25", + "data": { + "id": "7d9bb041-5ecb-497f-a674-d8706eed0ab1", + "node_id": "1775717266623", + "node_type": "start", + "title": "User Input", + "index": 1, + "predecessor_node_id": null, + "inputs": null, + "inputs_truncated": false, + "created_at": 1776129228, + "extras": {}, + "iteration_id": null, + "loop_id": null, + "agent_strategy": null + } + } + + data: { + "event": "node_finished", + "workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "task_id": "0399c5c2-181b-4493-a78e-1421914e8a25", + "data": { + "id": "7d9bb041-5ecb-497f-a674-d8706eed0ab1", + "node_id": "1775717266623", + "node_type": "start", + "title": "User Input", + "index": 1, + "predecessor_node_id": null, + "inputs": { + "sys.files": [], + "sys.user_id": "abc-123", + "sys.app_id": "d1074979-f67e-4114-8691-e35878df9a89", + "sys.workflow_id": "e46514f1-c008-41ff-94b0-4f33d4b97d36", + "sys.workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "sys.timestamp": 1776129228 + }, + "inputs_truncated": false, + "process_data": {}, + "process_data_truncated": false, + "outputs": { + "sys.files": [], + "sys.user_id": "abc-123", + "sys.app_id": "d1074979-f67e-4114-8691-e35878df9a89", + "sys.workflow_id": "e46514f1-c008-41ff-94b0-4f33d4b97d36", + "sys.workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "sys.timestamp": 1776129228 + }, + "outputs_truncated": false, + "status": "succeeded", + "error": null, + "elapsed_time": 0.000097, + "execution_metadata": null, + "created_at": 1776129228, + "finished_at": 1776129228, + "files": [], + "iteration_id": null, + "loop_id": null + } + } + + data: { + "event": "node_started", + "workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "task_id": "0399c5c2-181b-4493-a78e-1421914e8a25", + "data": { + "id": "c09ff568-1d55-4f0d-9a07-512bcbfeb289", + "node_id": "1775717346519", + "node_type": "human-input", + "title": "Human Input", + "index": 1, + "predecessor_node_id": null, + "inputs": null, + "inputs_truncated": false, + "created_at": 1776129228, + "extras": {}, + "iteration_id": null, + "loop_id": null, + "agent_strategy": null + } + } + + data: { + "event": "human_input_required", + "workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "task_id": "0399c5c2-181b-4493-a78e-1421914e8a25", + "data": { + "form_id": "019d898d-3d80-7105-b920-9899ead4ff3e", + "node_id": "1775717346519", + "node_title": "Human Input", + "form_content": "this is form 1:\n{{#$output.some_field#}}\n", + "inputs": [ + { + "type": "paragraph", + "output_variable_name": "some_field", + "default": { + "type": "variable", + "selector": [ + "sys", + "workflow_run_id" + ], + "value": "" + } + } + ], + "actions": [ + { + "id": "approve", + "title": "YES", + "button_style": "default" + }, + { + "id": "reject", + "title": "NO", + "button_style": "default" + } + ], + "display_in_ui": true, + "form_token": "0Tb1nXYe4hzQUD706nHB4y", + "resolved_default_values": { + "some_field": "a4959eb4-c852-4e0c-ac7a-348233f7f345" + }, + "expiration_time": 1776388428 + } + } + + data: { + "event": "workflow_paused", + "workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "task_id": "0399c5c2-181b-4493-a78e-1421914e8a25", + "data": { + "workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "paused_nodes": [ + "1775717346519" + ], + "outputs": {}, + "reasons": [ + { + "form_id": "019d898d-3d80-7105-b920-9899ead4ff3e", + "form_content": "this is form 1:\n{{#$output.some_field#}}\n", + "inputs": [ + { + "type": "paragraph", + "output_variable_name": "some_field", + "default": { + "type": "variable", + "selector": [ + "sys", + "workflow_run_id" + ], + "value": "" + } + } + ], + "actions": [ + { + "id": "approve", + "title": "YES", + "button_style": "default" + }, + { + "id": "reject", + "title": "NO", + "button_style": "default" + } + ], + "display_in_ui": true, + "node_id": "1775717346519", + "node_title": "Human Input", + "resolved_default_values": { + "some_field": "a4959eb4-c852-4e0c-ac7a-348233f7f345" + }, + "form_token": "0Tb1nXYe4hzQUD706nHB4y", + "type": "human_input_required", + "expiration_time": 1776388428 + } + ], + "status": "paused", + "created_at": 1776129228, + "elapsed_time": 0.070478, + "total_tokens": 0, + "total_steps": 2 + } + } + ``` + @@ -578,6 +893,198 @@ Chat applications support session persistence, allowing previous chat history to --- + + + + Retrieve a pending Human-in-the-Loop form by `form_token`. + + Use this endpoint when streaming returns `human_input_required` with a `form_token`. + + ### Path + - `form_token` (string) Required, token returned by the pause event. + + ### Response + - `form_content` (string) Rendered form content (markdown/plain text) + - `inputs` (array[object]) Form input definitions + - `resolved_default_values` (object) Default values resolved to strings + - `user_actions` (array[object]) Action buttons + - `expiration_time` (timestamp) Form expiration time (Unix seconds) + + ### Errors + - 404, form not found or does not belong to current app + - 412, `human_input_form_submitted`, form already submitted + - 412, `human_input_form_expired`, form expired + + + + + + ```json {{ title: 'Response' }} + { + "form_content": "Please confirm the final answer: {{#$output.answer#}}", + "inputs": [ + { + "label": "Answer", + "type": "text-input", + "required": true, + "output_variable_name": "answer" + } + ], + "resolved_default_values": { + "answer": "Initial value" + }, + "user_actions": [ + { "id": "approve", "title": "Approve", "button_style": "primary" }, + { "id": "reject", "title": "Reject", "button_style": "warning" } + ], + "expiration_time": 1735689600 + } + ``` + + + + +--- + + + + + Submit a pending Human-in-the-Loop form. + + ### Path + - `form_token` (string) Required, token returned by the pause event. + + ### Request Body + - `inputs` (object) Required, key/value pairs for form fields. + - `action` (string) Required, selected action ID from `user_actions`. + - `user` (string) Required, end-user identifier. + + ### Response + Returns an empty object on success. + + ### Errors + - 400, `invalid_form_data`, submitted data does not match the form schema + - 404, form not found or does not belong to current app + - 412, `human_input_form_submitted`, form already submitted + - 412, `human_input_form_expired`, form expired + + + + + + ```json {{ title: 'Response' }} + {} + ``` + + + + +--- + + + + + Continue receiving workflow events after submitting a human input form. + + This endpoint returns `text/event-stream` and can be used to observe resumed execution until completion. + + ### Path + - `task_id` (string) Required, workflow run ID (`workflow_run_id`). + + ### Query + - `user` (string) Required, end-user identifier. + - `include_state_snapshot` (bool) Optional, set to `true` to replay from persisted state snapshot before continuing with live events. + - `continue_on_pause` (bool) Optional, set to `true` to keep the stream open across `workflow_paused` events until `workflow_finished`. + + ### Response + Server-Sent Events stream (`text/event-stream`). + Typical events include `workflow_paused`, `node_started`, `node_finished`, `human_input_form_filled`, `human_input_form_timeout`, and `workflow_finished`. + If the workflow has already finished when you call this endpoint, the server returns a single finished event immediately. + + + + + + ```streaming {{ title: 'Response' }} + event: ping + + data: {"event":"workflow_started","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","data":{"id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","workflow_id":"e46514f1-c008-41ff-94b0-4f33d4b97d36","inputs":{"sys.files":[],"sys.user_id":"abc-123","sys.app_id":"d1074979-f67e-4114-8691-e35878df9a89","sys.workflow_id":"e46514f1-c008-41ff-94b0-4f33d4b97d36","sys.workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","sys.timestamp":1776087863},"created_at":1776087863,"reason":"initial"}} + + data: {"event":"node_started","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","data":{"id":"b552d685-1119-4e6a-9a81-e91a23e5324b","node_id":"1775717266623","node_type":"start","title":"User Input","index":1,"predecessor_node_id":null,"inputs":null,"created_at":1776087863,"extras":{},"iteration_id":null,"loop_id":null}} + + data: {"event":"node_finished","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","data":{"id":"b552d685-1119-4e6a-9a81-e91a23e5324b","node_id":"1775717266623","node_type":"start","title":"User Input","index":1,"predecessor_node_id":null,"inputs":null,"process_data":null,"outputs":null,"status":"succeeded","error":null,"elapsed_time":0.00032,"execution_metadata":null,"created_at":1776087863,"finished_at":1776087863,"files":[],"iteration_id":null,"loop_id":null}} + + data: {"event":"node_started","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","data":{"id":"8d7e8e01-5159-4089-a4b6-3aa394992cc2","node_id":"1775717346519","node_type":"human-input","title":"Human Input","index":2,"predecessor_node_id":null,"inputs":null,"created_at":1776087863,"extras":{},"iteration_id":null,"loop_id":null}} + + data: {"event":"node_finished","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","data":{"id":"8d7e8e01-5159-4089-a4b6-3aa394992cc2","node_id":"1775717346519","node_type":"human-input","title":"Human Input","index":2,"predecessor_node_id":null,"inputs":null,"process_data":null,"outputs":null,"status":"paused","error":null,"elapsed_time":0.007381,"execution_metadata":null,"created_at":1776087863,"finished_at":1776087863,"files":[],"iteration_id":null,"loop_id":null}} + + data: {"event":"workflow_paused","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","data":{"workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","paused_nodes":["1775717346519"],"outputs":{},"reasons":[{"form_id":"019d8716-0fde-75da-8207-1458ccde76e5","form_content":"this is form 1:\n{{#$output.some_field#}}\n","inputs":[{"type":"paragraph","output_variable_name":"some_field","default":{"type":"variable","selector":["sys","workflow_run_id"],"value":""}}],"actions":[{"id":"approve","title":"YES","button_style":"default"},{"id":"reject","title":"NO","button_style":"default"}],"display_in_ui":true,"node_id":"1775717346519","node_title":"Human Input","resolved_default_values":{"some_field":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c"},"form_token":"n7hFG4ZDYdGcgZ5VDc7EGM","type":"human_input_required"}],"status":"paused","created_at":1776087863,"elapsed_time":0.0,"total_tokens":0,"total_steps":2}} + + data: {"event":"workflow_started","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","workflow_id":"e46514f1-c008-41ff-94b0-4f33d4b97d36","inputs":{"sys.files":[],"sys.user_id":"abc-123","sys.app_id":"d1074979-f67e-4114-8691-e35878df9a89","sys.workflow_id":"e46514f1-c008-41ff-94b0-4f33d4b97d36","sys.workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c"},"created_at":1776087877,"reason":"resumption"}} + + data: {"event":"node_started","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"id":"8d7e8e01-5159-4089-a4b6-3aa394992cc2","node_id":"1775717346519","node_type":"human-input","title":"Human Input","index":1,"predecessor_node_id":null,"inputs":null,"inputs_truncated":false,"created_at":1776087877,"extras":{},"iteration_id":null,"loop_id":null,"agent_strategy":null}} + + data: {"event":"human_input_form_filled","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"node_id":"1775717346519","node_title":"Human Input","rendered_content":"this is form 1:\nfield 1 filled!\n","action_id":"approve","action_text":"YES"}} + + data: {"event":"node_finished","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"id":"8d7e8e01-5159-4089-a4b6-3aa394992cc2","node_id":"1775717346519","node_type":"human-input","title":"Human Input","index":1,"predecessor_node_id":null,"inputs":{},"inputs_truncated":false,"process_data":{},"process_data_truncated":false,"outputs":{"some_field":"field 1 filled!","some_field_2":"from bruno with love","__action_id":"approve","__rendered_content":"this is form 1:\nfield 1 filled!\n"},"outputs_truncated":false,"status":"succeeded","error":null,"elapsed_time":0.004431,"execution_metadata":null,"created_at":1776087877,"finished_at":1776087877,"files":[],"iteration_id":null,"loop_id":null}} + + data: {"event":"node_started","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"id":"6d8fc3cb-19f7-440b-b83e-eed4e847a332","node_id":"1775717350710","node_type":"template-transform","title":"Template","index":1,"predecessor_node_id":null,"inputs":null,"inputs_truncated":false,"created_at":1776087877,"extras":{},"iteration_id":null,"loop_id":null,"agent_strategy":null}} + + data: {"event":"text_chunk","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"text":"field 1 filled!","from_variable_selector":["1775717350710","output"]}} + + data: {"event":"node_finished","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"id":"6d8fc3cb-19f7-440b-b83e-eed4e847a332","node_id":"1775717350710","node_type":"template-transform","title":"Template","index":1,"predecessor_node_id":null,"inputs":{"some_field":"field 1 filled!"},"inputs_truncated":false,"process_data":{},"process_data_truncated":false,"outputs":{"output":"field 1 filled!"},"outputs_truncated":false,"status":"succeeded","error":null,"elapsed_time":0.264614,"execution_metadata":null,"created_at":1776087877,"finished_at":1776087877,"files":[],"iteration_id":null,"loop_id":null}} + + data: {"event":"node_started","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"id":"e88dec7e-aa2c-41f7-8d73-032b749e23f5","node_id":"1775717354177","node_type":"end","title":"Output","index":1,"predecessor_node_id":null,"inputs":null,"inputs_truncated":false,"created_at":1776087877,"extras":{},"iteration_id":null,"loop_id":null,"agent_strategy":null}} + + data: {"event":"node_finished","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"id":"e88dec7e-aa2c-41f7-8d73-032b749e23f5","node_id":"1775717354177","node_type":"end","title":"Output","index":1,"predecessor_node_id":null,"inputs":{"output":"field 1 filled!"},"inputs_truncated":false,"process_data":{},"process_data_truncated":false,"outputs":{"output":"field 1 filled!"},"outputs_truncated":false,"status":"succeeded","error":null,"elapsed_time":0.00003,"execution_metadata":null,"created_at":1776087877,"finished_at":1776087877,"files":[],"iteration_id":null,"loop_id":null}} + + data: {"event":"workflow_finished","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","workflow_id":"e46514f1-c008-41ff-94b0-4f33d4b97d36","status":"succeeded","outputs":{"output":"field 1 filled!"},"error":null,"elapsed_time":0.364935,"total_tokens":0,"total_steps":5,"created_by":{"id":"7932d34c-dcf4-4fba-b770-f2a9de88c0a0","user":"abc-123"},"created_at":1776087877,"finished_at":1776087877,"exceptions_count":0,"files":[]}} + ``` + + + + +--- + ### ブロッキングモード + ブロッキングモードでは、通常のチャット応答、または一時停止したワークフロー応答のいずれかが返されます。 + + Advanced Chat が Human-in-the-Loop で一時停止すると、`event` は `workflow_paused` になります。 + それでもペイロードには `message_id`、`conversation_id`、`answer`、`workflow_run_id` が含まれ、`data` には `paused_nodes` と `reasons` が追加されます。 + `human_input_required` の各 reason には `form_id` と `expiration_time` が含まれます。 + ```json {{ title: '応答' }} { @@ -296,6 +320,83 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' } ``` + + ```json {{ title: '一時停止レスポンス例' }} + { + "event": "workflow_paused", + "task_id": "8a9cbfcf-e7e0-4b17-aeef-24de57a2659a", + "id": "31714374-88cb-485f-9fa4-e3ab2a9ed95e", + "message_id": "31714374-88cb-485f-9fa4-e3ab2a9ed95e", + "conversation_id": "098e19be-356a-435d-9ec3-a406f4f1a97a", + "mode": "advanced-chat", + "answer": "", + "metadata": { + "annotation_reply": null, + "retriever_resources": [], + "usage": null + }, + "created_at": 1776074715, + "workflow_run_id": "7a4d6509-8a65-4c7d-a4fd-cf081dcf169f", + "data": { + "id": "31714374-88cb-485f-9fa4-e3ab2a9ed95e", + "mode": "advanced-chat", + "conversation_id": "098e19be-356a-435d-9ec3-a406f4f1a97a", + "message_id": "31714374-88cb-485f-9fa4-e3ab2a9ed95e", + "workflow_run_id": "7a4d6509-8a65-4c7d-a4fd-cf081dcf169f", + "answer": "", + "metadata": { + "annotation_reply": null, + "retriever_resources": [], + "usage": null + }, + "created_at": 1776074715, + "paused_nodes": [ + "1775724080699" + ], + "reasons": [ + { + "form_id": "019d864d-6f55-752c-9f4c-feee67508d5b", + "form_content": "this is form 2:\n\n{{#$output.some_field_2#}}\n", + "inputs": [ + { + "type": "paragraph", + "output_variable_name": "some_field_2", + "default": { + "type": "constant", + "selector": [], + "value": "" + } + } + ], + "actions": [ + { + "id": "approve", + "title": "yes", + "button_style": "default" + }, + { + "id": "reject", + "title": "no", + "button_style": "default" + } + ], + "display_in_ui": true, + "node_id": "1775724080699", + "node_title": "Human Input 2", + "resolved_default_values": {}, + "form_token": "0dvwTdpTFXgCZmAo2FoiJ5", + "type": "human_input_required", + "expiration_time": 1776333914 + } + ], + "status": "paused", + "elapsed_time": 0.034081, + "total_tokens": 0, + "total_steps": 2 + } + } + ``` + ### ストリーミングモード ```streaming {{ title: '応答' }} @@ -314,6 +415,220 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' data: {"event": "tts_message_end", "conversation_id": "23dd85f3-1a41-4ea0-b7a9-062734ccfaf9", "message_id": "a8bdc41c-13b2-4c18-bfd9-054b9803038c", "created_at": 1721205487, "task_id": "3bf8a0bb-e73b-4690-9e66-4e429bad8ee7", "audio": ""} ``` + ストリーミングモードでも Human-in-the-Loop により一時停止する場合があります。その場合、SSE ストリームではまず `human_input_required` が送られ、その後に `workflow_paused` が送られます。 + + + ```streaming {{ title: '一時停止ストリーミングレスポンス例' }} + event: ping + + data: { + "event": "workflow_started", + "workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "task_id": "0399c5c2-181b-4493-a78e-1421914e8a25", + "data": { + "id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "workflow_id": "e46514f1-c008-41ff-94b0-4f33d4b97d36", + "inputs": { + "sys.files": [], + "sys.user_id": "abc-123", + "sys.app_id": "d1074979-f67e-4114-8691-e35878df9a89", + "sys.workflow_id": "e46514f1-c008-41ff-94b0-4f33d4b97d36", + "sys.workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345" + }, + "created_at": 1776129228, + "reason": "initial" + } + } + + data: { + "event": "node_started", + "workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "task_id": "0399c5c2-181b-4493-a78e-1421914e8a25", + "data": { + "id": "7d9bb041-5ecb-497f-a674-d8706eed0ab1", + "node_id": "1775717266623", + "node_type": "start", + "title": "User Input", + "index": 1, + "predecessor_node_id": null, + "inputs": null, + "inputs_truncated": false, + "created_at": 1776129228, + "extras": {}, + "iteration_id": null, + "loop_id": null, + "agent_strategy": null + } + } + + data: { + "event": "node_finished", + "workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "task_id": "0399c5c2-181b-4493-a78e-1421914e8a25", + "data": { + "id": "7d9bb041-5ecb-497f-a674-d8706eed0ab1", + "node_id": "1775717266623", + "node_type": "start", + "title": "User Input", + "index": 1, + "predecessor_node_id": null, + "inputs": { + "sys.files": [], + "sys.user_id": "abc-123", + "sys.app_id": "d1074979-f67e-4114-8691-e35878df9a89", + "sys.workflow_id": "e46514f1-c008-41ff-94b0-4f33d4b97d36", + "sys.workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "sys.timestamp": 1776129228 + }, + "inputs_truncated": false, + "process_data": {}, + "process_data_truncated": false, + "outputs": { + "sys.files": [], + "sys.user_id": "abc-123", + "sys.app_id": "d1074979-f67e-4114-8691-e35878df9a89", + "sys.workflow_id": "e46514f1-c008-41ff-94b0-4f33d4b97d36", + "sys.workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "sys.timestamp": 1776129228 + }, + "outputs_truncated": false, + "status": "succeeded", + "error": null, + "elapsed_time": 0.000097, + "execution_metadata": null, + "created_at": 1776129228, + "finished_at": 1776129228, + "files": [], + "iteration_id": null, + "loop_id": null + } + } + + data: { + "event": "node_started", + "workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "task_id": "0399c5c2-181b-4493-a78e-1421914e8a25", + "data": { + "id": "c09ff568-1d55-4f0d-9a07-512bcbfeb289", + "node_id": "1775717346519", + "node_type": "human-input", + "title": "Human Input", + "index": 1, + "predecessor_node_id": null, + "inputs": null, + "inputs_truncated": false, + "created_at": 1776129228, + "extras": {}, + "iteration_id": null, + "loop_id": null, + "agent_strategy": null + } + } + + data: { + "event": "human_input_required", + "workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "task_id": "0399c5c2-181b-4493-a78e-1421914e8a25", + "data": { + "form_id": "019d898d-3d80-7105-b920-9899ead4ff3e", + "node_id": "1775717346519", + "node_title": "Human Input", + "form_content": "this is form 1:\n{{#$output.some_field#}}\n", + "inputs": [ + { + "type": "paragraph", + "output_variable_name": "some_field", + "default": { + "type": "variable", + "selector": [ + "sys", + "workflow_run_id" + ], + "value": "" + } + } + ], + "actions": [ + { + "id": "approve", + "title": "YES", + "button_style": "default" + }, + { + "id": "reject", + "title": "NO", + "button_style": "default" + } + ], + "display_in_ui": true, + "form_token": "0Tb1nXYe4hzQUD706nHB4y", + "resolved_default_values": { + "some_field": "a4959eb4-c852-4e0c-ac7a-348233f7f345" + }, + "expiration_time": 1776388428 + } + } + + data: { + "event": "workflow_paused", + "workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "task_id": "0399c5c2-181b-4493-a78e-1421914e8a25", + "data": { + "workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "paused_nodes": [ + "1775717346519" + ], + "outputs": {}, + "reasons": [ + { + "form_id": "019d898d-3d80-7105-b920-9899ead4ff3e", + "form_content": "this is form 1:\n{{#$output.some_field#}}\n", + "inputs": [ + { + "type": "paragraph", + "output_variable_name": "some_field", + "default": { + "type": "variable", + "selector": [ + "sys", + "workflow_run_id" + ], + "value": "" + } + } + ], + "actions": [ + { + "id": "approve", + "title": "YES", + "button_style": "default" + }, + { + "id": "reject", + "title": "NO", + "button_style": "default" + } + ], + "display_in_ui": true, + "node_id": "1775717346519", + "node_title": "Human Input", + "resolved_default_values": { + "some_field": "a4959eb4-c852-4e0c-ac7a-348233f7f345" + }, + "form_token": "0Tb1nXYe4hzQUD706nHB4y", + "type": "human_input_required", + "expiration_time": 1776388428 + } + ], + "status": "paused", + "created_at": 1776129228, + "elapsed_time": 0.070478, + "total_tokens": 0, + "total_steps": 2 + } + } + ``` + @@ -579,6 +894,198 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' --- + + + + `form_token` から保留中の Human-in-the-Loop フォームを取得します。 + + ストリーミングイベントで `human_input_required`(`form_token` を含む)が返された際に使用します。 + + ### パス + - `form_token` (string) 必須、一時停止イベントで返されたフォームトークン + + ### 応答 + - `form_content` (string) レンダリング済みフォーム内容(markdown/plain text) + - `inputs` (array[object]) 入力項目定義 + - `resolved_default_values` (object) 解決済みデフォルト値(文字列) + - `user_actions` (array[object]) アクションボタン一覧 + - `expiration_time` (timestamp) フォーム有効期限(Unix 秒) + + ### エラー + - 404, フォームが存在しない、または現在のアプリに属していない + - 412, `human_input_form_submitted`, 既に送信済み + - 412, `human_input_form_expired`, 期限切れ + + + + + + ```json {{ title: '応答' }} + { + "form_content": "最終回答を確認してください: {{#$output.answer#}}", + "inputs": [ + { + "label": "回答", + "type": "text-input", + "required": true, + "output_variable_name": "answer" + } + ], + "resolved_default_values": { + "answer": "初期値" + }, + "user_actions": [ + { "id": "approve", "title": "承認", "button_style": "primary" }, + { "id": "reject", "title": "却下", "button_style": "warning" } + ], + "expiration_time": 1735689600 + } + ``` + + + + +--- + + + + + 保留中の Human-in-the-Loop フォームを送信します。 + + ### パス + - `form_token` (string) 必須、一時停止イベントで返されたフォームトークン + + ### リクエストボディ + - `inputs` (object) 必須、フォーム項目の key/value + - `action` (string) 必須、`user_actions` から選択したアクション ID + - `user` (string) 必須、エンドユーザー識別子 + + ### 応答 + 成功時は空オブジェクトを返します。 + + ### エラー + - 400, `invalid_form_data`, 送信データがフォームスキーマに一致しない + - 404, フォームが存在しない、または現在のアプリに属していない + - 412, `human_input_form_submitted`, 既に送信済み + - 412, `human_input_form_expired`, 期限切れ + + + + + + ```json {{ title: '応答' }} + {} + ``` + + + + +--- + + + + + Human Input フォーム送信後に、ワークフロー再開後のイベントを継続受信します。 + + このエンドポイントは `text/event-stream` を返し、完了までイベントを購読できます。 + + ### パス + - `task_id` (string) 必須、workflow 実行 ID(`workflow_run_id`) + + ### クエリ + - `user` (string) 必須、エンドユーザー識別子 + - `include_state_snapshot` (bool) 任意、`true` の場合は永続化済み状態スナップショットを先に再生してからリアルタイムイベントへ移行 + - `continue_on_pause` (bool) 任意、`true` にすると `workflow_paused` イベントをまたいでもストリームを維持し、`workflow_finished` で終了します + + ### 応答 + Server-Sent Events ストリーム(`text/event-stream`)。 + 主なイベントは `workflow_paused`、`node_started`、`node_finished`、`human_input_form_filled`、`human_input_form_timeout`、`workflow_finished` です。 + 呼び出し時点でワークフローがすでに完了している場合、このエンドポイントは完了イベントを 1 件だけ即座に返します。 + + + + + + ```streaming {{ title: '応答' }} + event: ping + + data: {"event":"workflow_started","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","data":{"id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","workflow_id":"e46514f1-c008-41ff-94b0-4f33d4b97d36","inputs":{"sys.files":[],"sys.user_id":"abc-123","sys.app_id":"d1074979-f67e-4114-8691-e35878df9a89","sys.workflow_id":"e46514f1-c008-41ff-94b0-4f33d4b97d36","sys.workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","sys.timestamp":1776087863},"created_at":1776087863,"reason":"initial"}} + + data: {"event":"node_started","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","data":{"id":"b552d685-1119-4e6a-9a81-e91a23e5324b","node_id":"1775717266623","node_type":"start","title":"User Input","index":1,"predecessor_node_id":null,"inputs":null,"created_at":1776087863,"extras":{},"iteration_id":null,"loop_id":null}} + + data: {"event":"node_finished","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","data":{"id":"b552d685-1119-4e6a-9a81-e91a23e5324b","node_id":"1775717266623","node_type":"start","title":"User Input","index":1,"predecessor_node_id":null,"inputs":null,"process_data":null,"outputs":null,"status":"succeeded","error":null,"elapsed_time":0.00032,"execution_metadata":null,"created_at":1776087863,"finished_at":1776087863,"files":[],"iteration_id":null,"loop_id":null}} + + data: {"event":"node_started","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","data":{"id":"8d7e8e01-5159-4089-a4b6-3aa394992cc2","node_id":"1775717346519","node_type":"human-input","title":"Human Input","index":2,"predecessor_node_id":null,"inputs":null,"created_at":1776087863,"extras":{},"iteration_id":null,"loop_id":null}} + + data: {"event":"node_finished","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","data":{"id":"8d7e8e01-5159-4089-a4b6-3aa394992cc2","node_id":"1775717346519","node_type":"human-input","title":"Human Input","index":2,"predecessor_node_id":null,"inputs":null,"process_data":null,"outputs":null,"status":"paused","error":null,"elapsed_time":0.007381,"execution_metadata":null,"created_at":1776087863,"finished_at":1776087863,"files":[],"iteration_id":null,"loop_id":null}} + + data: {"event":"workflow_paused","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","data":{"workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","paused_nodes":["1775717346519"],"outputs":{},"reasons":[{"form_id":"019d8716-0fde-75da-8207-1458ccde76e5","form_content":"this is form 1:\n{{#$output.some_field#}}\n","inputs":[{"type":"paragraph","output_variable_name":"some_field","default":{"type":"variable","selector":["sys","workflow_run_id"],"value":""}}],"actions":[{"id":"approve","title":"YES","button_style":"default"},{"id":"reject","title":"NO","button_style":"default"}],"display_in_ui":true,"node_id":"1775717346519","node_title":"Human Input","resolved_default_values":{"some_field":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c"},"form_token":"n7hFG4ZDYdGcgZ5VDc7EGM","type":"human_input_required"}],"status":"paused","created_at":1776087863,"elapsed_time":0.0,"total_tokens":0,"total_steps":2}} + + data: {"event":"workflow_started","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","workflow_id":"e46514f1-c008-41ff-94b0-4f33d4b97d36","inputs":{"sys.files":[],"sys.user_id":"abc-123","sys.app_id":"d1074979-f67e-4114-8691-e35878df9a89","sys.workflow_id":"e46514f1-c008-41ff-94b0-4f33d4b97d36","sys.workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c"},"created_at":1776087877,"reason":"resumption"}} + + data: {"event":"node_started","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"id":"8d7e8e01-5159-4089-a4b6-3aa394992cc2","node_id":"1775717346519","node_type":"human-input","title":"Human Input","index":1,"predecessor_node_id":null,"inputs":null,"inputs_truncated":false,"created_at":1776087877,"extras":{},"iteration_id":null,"loop_id":null,"agent_strategy":null}} + + data: {"event":"human_input_form_filled","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"node_id":"1775717346519","node_title":"Human Input","rendered_content":"this is form 1:\nfield 1 filled!\n","action_id":"approve","action_text":"YES"}} + + data: {"event":"node_finished","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"id":"8d7e8e01-5159-4089-a4b6-3aa394992cc2","node_id":"1775717346519","node_type":"human-input","title":"Human Input","index":1,"predecessor_node_id":null,"inputs":{},"inputs_truncated":false,"process_data":{},"process_data_truncated":false,"outputs":{"some_field":"field 1 filled!","some_field_2":"from bruno with love","__action_id":"approve","__rendered_content":"this is form 1:\nfield 1 filled!\n"},"outputs_truncated":false,"status":"succeeded","error":null,"elapsed_time":0.004431,"execution_metadata":null,"created_at":1776087877,"finished_at":1776087877,"files":[],"iteration_id":null,"loop_id":null}} + + data: {"event":"node_started","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"id":"6d8fc3cb-19f7-440b-b83e-eed4e847a332","node_id":"1775717350710","node_type":"template-transform","title":"Template","index":1,"predecessor_node_id":null,"inputs":null,"inputs_truncated":false,"created_at":1776087877,"extras":{},"iteration_id":null,"loop_id":null,"agent_strategy":null}} + + data: {"event":"text_chunk","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"text":"field 1 filled!","from_variable_selector":["1775717350710","output"]}} + + data: {"event":"node_finished","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"id":"6d8fc3cb-19f7-440b-b83e-eed4e847a332","node_id":"1775717350710","node_type":"template-transform","title":"Template","index":1,"predecessor_node_id":null,"inputs":{"some_field":"field 1 filled!"},"inputs_truncated":false,"process_data":{},"process_data_truncated":false,"outputs":{"output":"field 1 filled!"},"outputs_truncated":false,"status":"succeeded","error":null,"elapsed_time":0.264614,"execution_metadata":null,"created_at":1776087877,"finished_at":1776087877,"files":[],"iteration_id":null,"loop_id":null}} + + data: {"event":"node_started","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"id":"e88dec7e-aa2c-41f7-8d73-032b749e23f5","node_id":"1775717354177","node_type":"end","title":"Output","index":1,"predecessor_node_id":null,"inputs":null,"inputs_truncated":false,"created_at":1776087877,"extras":{},"iteration_id":null,"loop_id":null,"agent_strategy":null}} + + data: {"event":"node_finished","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"id":"e88dec7e-aa2c-41f7-8d73-032b749e23f5","node_id":"1775717354177","node_type":"end","title":"Output","index":1,"predecessor_node_id":null,"inputs":{"output":"field 1 filled!"},"inputs_truncated":false,"process_data":{},"process_data_truncated":false,"outputs":{"output":"field 1 filled!"},"outputs_truncated":false,"status":"succeeded","error":null,"elapsed_time":0.00003,"execution_metadata":null,"created_at":1776087877,"finished_at":1776087877,"files":[],"iteration_id":null,"loop_id":null}} + + data: {"event":"workflow_finished","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","workflow_id":"e46514f1-c008-41ff-94b0-4f33d4b97d36","status":"succeeded","outputs":{"output":"field 1 filled!"},"error":null,"elapsed_time":0.364935,"total_tokens":0,"total_steps":5,"created_by":{"id":"7932d34c-dcf4-4fba-b770-f2a9de88c0a0","user":"abc-123"},"created_at":1776087877,"finished_at":1776087877,"exceptions_count":0,"files":[]}} + ``` + + + + +--- + ### 阻塞模式 + 阻塞模式可能返回普通聊天响应,也可能返回暂停中的工作流响应。 + + 当 Advanced Chat 因 Human-in-the-Loop 暂停时,`event` 会变为 `workflow_paused`。 + 响应仍然包含 `message_id`、`conversation_id`、`answer` 和 `workflow_run_id`,并且 `data` 中会新增 `paused_nodes` 和 `reasons`。 + 对于 `human_input_required`,每个 reason 都会包含 `form_id` 和 `expiration_time`。 + ```json {{ title: 'Response' }} { @@ -295,6 +319,83 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' } ``` + + ```json {{ title: 'Paused Response Example' }} + { + "event": "workflow_paused", + "task_id": "8a9cbfcf-e7e0-4b17-aeef-24de57a2659a", + "id": "31714374-88cb-485f-9fa4-e3ab2a9ed95e", + "message_id": "31714374-88cb-485f-9fa4-e3ab2a9ed95e", + "conversation_id": "098e19be-356a-435d-9ec3-a406f4f1a97a", + "mode": "advanced-chat", + "answer": "", + "metadata": { + "annotation_reply": null, + "retriever_resources": [], + "usage": null + }, + "created_at": 1776074715, + "workflow_run_id": "7a4d6509-8a65-4c7d-a4fd-cf081dcf169f", + "data": { + "id": "31714374-88cb-485f-9fa4-e3ab2a9ed95e", + "mode": "advanced-chat", + "conversation_id": "098e19be-356a-435d-9ec3-a406f4f1a97a", + "message_id": "31714374-88cb-485f-9fa4-e3ab2a9ed95e", + "workflow_run_id": "7a4d6509-8a65-4c7d-a4fd-cf081dcf169f", + "answer": "", + "metadata": { + "annotation_reply": null, + "retriever_resources": [], + "usage": null + }, + "created_at": 1776074715, + "paused_nodes": [ + "1775724080699" + ], + "reasons": [ + { + "form_id": "019d864d-6f55-752c-9f4c-feee67508d5b", + "form_content": "this is form 2:\n\n{{#$output.some_field_2#}}\n", + "inputs": [ + { + "type": "paragraph", + "output_variable_name": "some_field_2", + "default": { + "type": "constant", + "selector": [], + "value": "" + } + } + ], + "actions": [ + { + "id": "approve", + "title": "yes", + "button_style": "default" + }, + { + "id": "reject", + "title": "no", + "button_style": "default" + } + ], + "display_in_ui": true, + "node_id": "1775724080699", + "node_title": "Human Input 2", + "resolved_default_values": {}, + "form_token": "0dvwTdpTFXgCZmAo2FoiJ5", + "type": "human_input_required", + "expiration_time": 1776333914 + } + ], + "status": "paused", + "elapsed_time": 0.034081, + "total_tokens": 0, + "total_steps": 2 + } + } + ``` + ### 流式模式 ```streaming {{ title: 'Response' }} @@ -313,6 +414,220 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' data: {"event": "tts_message_end", "conversation_id": "23dd85f3-1a41-4ea0-b7a9-062734ccfaf9", "message_id": "a8bdc41c-13b2-4c18-bfd9-054b9803038c", "created_at": 1721205487, "task_id": "3bf8a0bb-e73b-4690-9e66-4e429bad8ee7", "audio": ""} ``` + 流式模式同样可能因为 Human-in-the-Loop 而暂停。此时 SSE 流会先返回 `human_input_required`,随后返回 `workflow_paused`。 + + + ```streaming {{ title: 'Paused Streaming Response Example' }} + event: ping + + data: { + "event": "workflow_started", + "workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "task_id": "0399c5c2-181b-4493-a78e-1421914e8a25", + "data": { + "id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "workflow_id": "e46514f1-c008-41ff-94b0-4f33d4b97d36", + "inputs": { + "sys.files": [], + "sys.user_id": "abc-123", + "sys.app_id": "d1074979-f67e-4114-8691-e35878df9a89", + "sys.workflow_id": "e46514f1-c008-41ff-94b0-4f33d4b97d36", + "sys.workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345" + }, + "created_at": 1776129228, + "reason": "initial" + } + } + + data: { + "event": "node_started", + "workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "task_id": "0399c5c2-181b-4493-a78e-1421914e8a25", + "data": { + "id": "7d9bb041-5ecb-497f-a674-d8706eed0ab1", + "node_id": "1775717266623", + "node_type": "start", + "title": "User Input", + "index": 1, + "predecessor_node_id": null, + "inputs": null, + "inputs_truncated": false, + "created_at": 1776129228, + "extras": {}, + "iteration_id": null, + "loop_id": null, + "agent_strategy": null + } + } + + data: { + "event": "node_finished", + "workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "task_id": "0399c5c2-181b-4493-a78e-1421914e8a25", + "data": { + "id": "7d9bb041-5ecb-497f-a674-d8706eed0ab1", + "node_id": "1775717266623", + "node_type": "start", + "title": "User Input", + "index": 1, + "predecessor_node_id": null, + "inputs": { + "sys.files": [], + "sys.user_id": "abc-123", + "sys.app_id": "d1074979-f67e-4114-8691-e35878df9a89", + "sys.workflow_id": "e46514f1-c008-41ff-94b0-4f33d4b97d36", + "sys.workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "sys.timestamp": 1776129228 + }, + "inputs_truncated": false, + "process_data": {}, + "process_data_truncated": false, + "outputs": { + "sys.files": [], + "sys.user_id": "abc-123", + "sys.app_id": "d1074979-f67e-4114-8691-e35878df9a89", + "sys.workflow_id": "e46514f1-c008-41ff-94b0-4f33d4b97d36", + "sys.workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "sys.timestamp": 1776129228 + }, + "outputs_truncated": false, + "status": "succeeded", + "error": null, + "elapsed_time": 0.000097, + "execution_metadata": null, + "created_at": 1776129228, + "finished_at": 1776129228, + "files": [], + "iteration_id": null, + "loop_id": null + } + } + + data: { + "event": "node_started", + "workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "task_id": "0399c5c2-181b-4493-a78e-1421914e8a25", + "data": { + "id": "c09ff568-1d55-4f0d-9a07-512bcbfeb289", + "node_id": "1775717346519", + "node_type": "human-input", + "title": "Human Input", + "index": 1, + "predecessor_node_id": null, + "inputs": null, + "inputs_truncated": false, + "created_at": 1776129228, + "extras": {}, + "iteration_id": null, + "loop_id": null, + "agent_strategy": null + } + } + + data: { + "event": "human_input_required", + "workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "task_id": "0399c5c2-181b-4493-a78e-1421914e8a25", + "data": { + "form_id": "019d898d-3d80-7105-b920-9899ead4ff3e", + "node_id": "1775717346519", + "node_title": "Human Input", + "form_content": "this is form 1:\n{{#$output.some_field#}}\n", + "inputs": [ + { + "type": "paragraph", + "output_variable_name": "some_field", + "default": { + "type": "variable", + "selector": [ + "sys", + "workflow_run_id" + ], + "value": "" + } + } + ], + "actions": [ + { + "id": "approve", + "title": "YES", + "button_style": "default" + }, + { + "id": "reject", + "title": "NO", + "button_style": "default" + } + ], + "display_in_ui": true, + "form_token": "0Tb1nXYe4hzQUD706nHB4y", + "resolved_default_values": { + "some_field": "a4959eb4-c852-4e0c-ac7a-348233f7f345" + }, + "expiration_time": 1776388428 + } + } + + data: { + "event": "workflow_paused", + "workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "task_id": "0399c5c2-181b-4493-a78e-1421914e8a25", + "data": { + "workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "paused_nodes": [ + "1775717346519" + ], + "outputs": {}, + "reasons": [ + { + "form_id": "019d898d-3d80-7105-b920-9899ead4ff3e", + "form_content": "this is form 1:\n{{#$output.some_field#}}\n", + "inputs": [ + { + "type": "paragraph", + "output_variable_name": "some_field", + "default": { + "type": "variable", + "selector": [ + "sys", + "workflow_run_id" + ], + "value": "" + } + } + ], + "actions": [ + { + "id": "approve", + "title": "YES", + "button_style": "default" + }, + { + "id": "reject", + "title": "NO", + "button_style": "default" + } + ], + "display_in_ui": true, + "node_id": "1775717346519", + "node_title": "Human Input", + "resolved_default_values": { + "some_field": "a4959eb4-c852-4e0c-ac7a-348233f7f345" + }, + "form_token": "0Tb1nXYe4hzQUD706nHB4y", + "type": "human_input_required", + "expiration_time": 1776388428 + } + ], + "status": "paused", + "created_at": 1776129228, + "elapsed_time": 0.070478, + "total_tokens": 0, + "total_steps": 2 + } + } + ``` + @@ -572,6 +887,198 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' --- + + + + 通过 `form_token` 获取待处理的 Human-in-the-Loop 表单。 + + 当流式事件返回 `human_input_required`(包含 `form_token`)时,可调用此接口拉取表单详情。 + + ### Path + - `form_token` (string) 必填,暂停事件返回的表单 token + + ### Response + - `form_content` (string) 已渲染的表单内容(markdown/plain text) + - `inputs` (array[object]) 表单输入项定义 + - `resolved_default_values` (object) 已解析的默认值(字符串) + - `user_actions` (array[object]) 操作按钮列表 + - `expiration_time` (timestamp) 表单过期时间(Unix 秒) + + ### Errors + - 404,表单不存在或不属于当前应用 + - 412,`human_input_form_submitted`,表单已被提交 + - 412,`human_input_form_expired`,表单已过期 + + + + + + ```json {{ title: 'Response' }} + { + "form_content": "请确认最终结果:{{#$output.answer#}}", + "inputs": [ + { + "label": "答案", + "type": "text-input", + "required": true, + "output_variable_name": "answer" + } + ], + "resolved_default_values": { + "answer": "初始值" + }, + "user_actions": [ + { "id": "approve", "title": "通过", "button_style": "primary" }, + { "id": "reject", "title": "拒绝", "button_style": "warning" } + ], + "expiration_time": 1735689600 + } + ``` + + + + +--- + + + + + 提交待处理的 Human-in-the-Loop 表单。 + + ### Path + - `form_token` (string) 必填,暂停事件返回的表单 token + + ### Request Body + - `inputs` (object) 必填,表单字段的 key/value + - `action` (string) 必填,从 `user_actions` 中选择的动作 ID + - `user` (string) 必填,终端用户标识 + + ### Response + 成功时返回空对象。 + + ### Errors + - 400,`invalid_form_data`,提交数据与表单 schema 不匹配 + - 404,表单不存在或不属于当前应用 + - 412,`human_input_form_submitted`,表单已被提交 + - 412,`human_input_form_expired`,表单已过期 + + + + + + ```json {{ title: 'Response' }} + {} + ``` + + + + +--- + + + + + 在提交人工输入表单后,继续订阅工作流后续执行事件。 + + 返回 `text/event-stream`,可持续接收直到工作流结束。 + + ### Path + - `task_id` (string) 必填,workflow 运行 ID(`workflow_run_id`) + + ### Query + - `user` (string) 必填,终端用户标识 + - `include_state_snapshot` (bool) 可选,设为 `true` 时会先回放持久化状态快照,再继续实时事件 + - `continue_on_pause` (bool) 可选,设为 `true` 时,流会在 `workflow_paused` 事件之间保持连接,直到 `workflow_finished` 才结束 + + ### Response + Server-Sent Events 流(`text/event-stream`)。 + 常见事件包括 `workflow_paused`、`node_started`、`node_finished`、`human_input_form_filled`、`human_input_form_timeout`、`workflow_finished`。 + 如果调用该接口时工作流已经结束,服务端会立即返回单个完成事件。 + + + + + + ```streaming {{ title: 'Response' }} + event: ping + + data: {"event":"workflow_started","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","data":{"id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","workflow_id":"e46514f1-c008-41ff-94b0-4f33d4b97d36","inputs":{"sys.files":[],"sys.user_id":"abc-123","sys.app_id":"d1074979-f67e-4114-8691-e35878df9a89","sys.workflow_id":"e46514f1-c008-41ff-94b0-4f33d4b97d36","sys.workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","sys.timestamp":1776087863},"created_at":1776087863,"reason":"initial"}} + + data: {"event":"node_started","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","data":{"id":"b552d685-1119-4e6a-9a81-e91a23e5324b","node_id":"1775717266623","node_type":"start","title":"User Input","index":1,"predecessor_node_id":null,"inputs":null,"created_at":1776087863,"extras":{},"iteration_id":null,"loop_id":null}} + + data: {"event":"node_finished","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","data":{"id":"b552d685-1119-4e6a-9a81-e91a23e5324b","node_id":"1775717266623","node_type":"start","title":"User Input","index":1,"predecessor_node_id":null,"inputs":null,"process_data":null,"outputs":null,"status":"succeeded","error":null,"elapsed_time":0.00032,"execution_metadata":null,"created_at":1776087863,"finished_at":1776087863,"files":[],"iteration_id":null,"loop_id":null}} + + data: {"event":"node_started","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","data":{"id":"8d7e8e01-5159-4089-a4b6-3aa394992cc2","node_id":"1775717346519","node_type":"human-input","title":"Human Input","index":2,"predecessor_node_id":null,"inputs":null,"created_at":1776087863,"extras":{},"iteration_id":null,"loop_id":null}} + + data: {"event":"node_finished","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","data":{"id":"8d7e8e01-5159-4089-a4b6-3aa394992cc2","node_id":"1775717346519","node_type":"human-input","title":"Human Input","index":2,"predecessor_node_id":null,"inputs":null,"process_data":null,"outputs":null,"status":"paused","error":null,"elapsed_time":0.007381,"execution_metadata":null,"created_at":1776087863,"finished_at":1776087863,"files":[],"iteration_id":null,"loop_id":null}} + + data: {"event":"workflow_paused","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","data":{"workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","paused_nodes":["1775717346519"],"outputs":{},"reasons":[{"form_id":"019d8716-0fde-75da-8207-1458ccde76e5","form_content":"this is form 1:\n{{#$output.some_field#}}\n","inputs":[{"type":"paragraph","output_variable_name":"some_field","default":{"type":"variable","selector":["sys","workflow_run_id"],"value":""}}],"actions":[{"id":"approve","title":"YES","button_style":"default"},{"id":"reject","title":"NO","button_style":"default"}],"display_in_ui":true,"node_id":"1775717346519","node_title":"Human Input","resolved_default_values":{"some_field":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c"},"form_token":"n7hFG4ZDYdGcgZ5VDc7EGM","type":"human_input_required"}],"status":"paused","created_at":1776087863,"elapsed_time":0.0,"total_tokens":0,"total_steps":2}} + + data: {"event":"workflow_started","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","workflow_id":"e46514f1-c008-41ff-94b0-4f33d4b97d36","inputs":{"sys.files":[],"sys.user_id":"abc-123","sys.app_id":"d1074979-f67e-4114-8691-e35878df9a89","sys.workflow_id":"e46514f1-c008-41ff-94b0-4f33d4b97d36","sys.workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c"},"created_at":1776087877,"reason":"resumption"}} + + data: {"event":"node_started","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"id":"8d7e8e01-5159-4089-a4b6-3aa394992cc2","node_id":"1775717346519","node_type":"human-input","title":"Human Input","index":1,"predecessor_node_id":null,"inputs":null,"inputs_truncated":false,"created_at":1776087877,"extras":{},"iteration_id":null,"loop_id":null,"agent_strategy":null}} + + data: {"event":"human_input_form_filled","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"node_id":"1775717346519","node_title":"Human Input","rendered_content":"this is form 1:\nfield 1 filled!\n","action_id":"approve","action_text":"YES"}} + + data: {"event":"node_finished","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"id":"8d7e8e01-5159-4089-a4b6-3aa394992cc2","node_id":"1775717346519","node_type":"human-input","title":"Human Input","index":1,"predecessor_node_id":null,"inputs":{},"inputs_truncated":false,"process_data":{},"process_data_truncated":false,"outputs":{"some_field":"field 1 filled!","some_field_2":"from bruno with love","__action_id":"approve","__rendered_content":"this is form 1:\nfield 1 filled!\n"},"outputs_truncated":false,"status":"succeeded","error":null,"elapsed_time":0.004431,"execution_metadata":null,"created_at":1776087877,"finished_at":1776087877,"files":[],"iteration_id":null,"loop_id":null}} + + data: {"event":"node_started","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"id":"6d8fc3cb-19f7-440b-b83e-eed4e847a332","node_id":"1775717350710","node_type":"template-transform","title":"Template","index":1,"predecessor_node_id":null,"inputs":null,"inputs_truncated":false,"created_at":1776087877,"extras":{},"iteration_id":null,"loop_id":null,"agent_strategy":null}} + + data: {"event":"text_chunk","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"text":"field 1 filled!","from_variable_selector":["1775717350710","output"]}} + + data: {"event":"node_finished","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"id":"6d8fc3cb-19f7-440b-b83e-eed4e847a332","node_id":"1775717350710","node_type":"template-transform","title":"Template","index":1,"predecessor_node_id":null,"inputs":{"some_field":"field 1 filled!"},"inputs_truncated":false,"process_data":{},"process_data_truncated":false,"outputs":{"output":"field 1 filled!"},"outputs_truncated":false,"status":"succeeded","error":null,"elapsed_time":0.264614,"execution_metadata":null,"created_at":1776087877,"finished_at":1776087877,"files":[],"iteration_id":null,"loop_id":null}} + + data: {"event":"node_started","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"id":"e88dec7e-aa2c-41f7-8d73-032b749e23f5","node_id":"1775717354177","node_type":"end","title":"Output","index":1,"predecessor_node_id":null,"inputs":null,"inputs_truncated":false,"created_at":1776087877,"extras":{},"iteration_id":null,"loop_id":null,"agent_strategy":null}} + + data: {"event":"node_finished","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"id":"e88dec7e-aa2c-41f7-8d73-032b749e23f5","node_id":"1775717354177","node_type":"end","title":"Output","index":1,"predecessor_node_id":null,"inputs":{"output":"field 1 filled!"},"inputs_truncated":false,"process_data":{},"process_data_truncated":false,"outputs":{"output":"field 1 filled!"},"outputs_truncated":false,"status":"succeeded","error":null,"elapsed_time":0.00003,"execution_metadata":null,"created_at":1776087877,"finished_at":1776087877,"files":[],"iteration_id":null,"loop_id":null}} + + data: {"event":"workflow_finished","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","workflow_id":"e46514f1-c008-41ff-94b0-4f33d4b97d36","status":"succeeded","outputs":{"output":"field 1 filled!"},"error":null,"elapsed_time":0.364935,"total_tokens":0,"total_steps":5,"created_by":{"id":"7932d34c-dcf4-4fba-b770-f2a9de88c0a0","user":"abc-123"},"created_at":1776087877,"finished_at":1776087877,"exceptions_count":0,"files":[]}} + ``` + + + + +--- + ### Blocking Mode + Blocking mode can return either a completed workflow result or a paused workflow result. + + When execution pauses for Human-in-the-Loop, the response still includes `workflow_run_id` and `task_id`, but `data.status` becomes `paused`. + The paused payload also includes `paused_nodes` and `reasons`. For `human_input_required`, each reason contains the `form_id` and its `expiration_time`. + ```json {{ title: 'Response' }} { @@ -236,6 +259,70 @@ Workflow applications offers non-session support and is ideal for translation, a } ``` + + ```json {{ title: 'Paused Response Example' }} + { + "task_id": "3938b985-f4c6-4806-87b6-215e0aca9d81", + "workflow_run_id": "4a80f375-682b-49c5-b199-e950aac4968f", + "data": { + "id": "4a80f375-682b-49c5-b199-e950aac4968f", + "workflow_id": "e46514f1-c008-41ff-94b0-4f33d4b97d36", + "status": "paused", + "outputs": {}, + "error": null, + "elapsed_time": 0.035667, + "total_tokens": 0, + "total_steps": 2, + "created_at": 1776074783, + "finished_at": null, + "paused_nodes": [ + "1775717346519" + ], + "reasons": [ + { + "form_id": "019d864e-7a36-74a2-b94e-e5660c47f5a7", + "form_content": "this is form 1:\n{{#$output.some_field#}}\n", + "inputs": [ + { + "type": "paragraph", + "output_variable_name": "some_field", + "default": { + "type": "variable", + "selector": [ + "sys", + "workflow_run_id" + ], + "value": "" + } + } + ], + "actions": [ + { + "id": "approve", + "title": "YES", + "button_style": "default" + }, + { + "id": "reject", + "title": "NO", + "button_style": "default" + } + ], + "display_in_ui": true, + "node_id": "1775717346519", + "node_title": "Human Input", + "resolved_default_values": { + "some_field": "4a80f375-682b-49c5-b199-e950aac4968f" + }, + "form_token": "SZwvfmL47fTIsZynP2Jr9i", + "type": "human_input_required", + "expiration_time": 1776333983 + } + ] + } + } + ``` + ### Streaming Mode ```streaming {{ title: 'Response' }} @@ -247,6 +334,220 @@ Workflow applications offers non-session support and is ideal for translation, a data: {"event": "tts_message_end", "conversation_id": "23dd85f3-1a41-4ea0-b7a9-062734ccfaf9", "message_id": "a8bdc41c-13b2-4c18-bfd9-054b9803038c", "created_at": 1721205487, "task_id": "3bf8a0bb-e73b-4690-9e66-4e429bad8ee7", "audio": ""} ``` + Streaming mode can also pause for Human-in-the-Loop. In that case, the SSE stream emits `human_input_required` first and then `workflow_paused`. + + + ```streaming {{ title: 'Paused Streaming Response Example' }} + event: ping + + data: { + "event": "workflow_started", + "workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "task_id": "0399c5c2-181b-4493-a78e-1421914e8a25", + "data": { + "id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "workflow_id": "e46514f1-c008-41ff-94b0-4f33d4b97d36", + "inputs": { + "sys.files": [], + "sys.user_id": "abc-123", + "sys.app_id": "d1074979-f67e-4114-8691-e35878df9a89", + "sys.workflow_id": "e46514f1-c008-41ff-94b0-4f33d4b97d36", + "sys.workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345" + }, + "created_at": 1776129228, + "reason": "initial" + } + } + + data: { + "event": "node_started", + "workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "task_id": "0399c5c2-181b-4493-a78e-1421914e8a25", + "data": { + "id": "7d9bb041-5ecb-497f-a674-d8706eed0ab1", + "node_id": "1775717266623", + "node_type": "start", + "title": "User Input", + "index": 1, + "predecessor_node_id": null, + "inputs": null, + "inputs_truncated": false, + "created_at": 1776129228, + "extras": {}, + "iteration_id": null, + "loop_id": null, + "agent_strategy": null + } + } + + data: { + "event": "node_finished", + "workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "task_id": "0399c5c2-181b-4493-a78e-1421914e8a25", + "data": { + "id": "7d9bb041-5ecb-497f-a674-d8706eed0ab1", + "node_id": "1775717266623", + "node_type": "start", + "title": "User Input", + "index": 1, + "predecessor_node_id": null, + "inputs": { + "sys.files": [], + "sys.user_id": "abc-123", + "sys.app_id": "d1074979-f67e-4114-8691-e35878df9a89", + "sys.workflow_id": "e46514f1-c008-41ff-94b0-4f33d4b97d36", + "sys.workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "sys.timestamp": 1776129228 + }, + "inputs_truncated": false, + "process_data": {}, + "process_data_truncated": false, + "outputs": { + "sys.files": [], + "sys.user_id": "abc-123", + "sys.app_id": "d1074979-f67e-4114-8691-e35878df9a89", + "sys.workflow_id": "e46514f1-c008-41ff-94b0-4f33d4b97d36", + "sys.workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "sys.timestamp": 1776129228 + }, + "outputs_truncated": false, + "status": "succeeded", + "error": null, + "elapsed_time": 0.000097, + "execution_metadata": null, + "created_at": 1776129228, + "finished_at": 1776129228, + "files": [], + "iteration_id": null, + "loop_id": null + } + } + + data: { + "event": "node_started", + "workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "task_id": "0399c5c2-181b-4493-a78e-1421914e8a25", + "data": { + "id": "c09ff568-1d55-4f0d-9a07-512bcbfeb289", + "node_id": "1775717346519", + "node_type": "human-input", + "title": "Human Input", + "index": 1, + "predecessor_node_id": null, + "inputs": null, + "inputs_truncated": false, + "created_at": 1776129228, + "extras": {}, + "iteration_id": null, + "loop_id": null, + "agent_strategy": null + } + } + + data: { + "event": "human_input_required", + "workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "task_id": "0399c5c2-181b-4493-a78e-1421914e8a25", + "data": { + "form_id": "019d898d-3d80-7105-b920-9899ead4ff3e", + "node_id": "1775717346519", + "node_title": "Human Input", + "form_content": "this is form 1:\n{{#$output.some_field#}}\n", + "inputs": [ + { + "type": "paragraph", + "output_variable_name": "some_field", + "default": { + "type": "variable", + "selector": [ + "sys", + "workflow_run_id" + ], + "value": "" + } + } + ], + "actions": [ + { + "id": "approve", + "title": "YES", + "button_style": "default" + }, + { + "id": "reject", + "title": "NO", + "button_style": "default" + } + ], + "display_in_ui": true, + "form_token": "0Tb1nXYe4hzQUD706nHB4y", + "resolved_default_values": { + "some_field": "a4959eb4-c852-4e0c-ac7a-348233f7f345" + }, + "expiration_time": 1776388428 + } + } + + data: { + "event": "workflow_paused", + "workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "task_id": "0399c5c2-181b-4493-a78e-1421914e8a25", + "data": { + "workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "paused_nodes": [ + "1775717346519" + ], + "outputs": {}, + "reasons": [ + { + "form_id": "019d898d-3d80-7105-b920-9899ead4ff3e", + "form_content": "this is form 1:\n{{#$output.some_field#}}\n", + "inputs": [ + { + "type": "paragraph", + "output_variable_name": "some_field", + "default": { + "type": "variable", + "selector": [ + "sys", + "workflow_run_id" + ], + "value": "" + } + } + ], + "actions": [ + { + "id": "approve", + "title": "YES", + "button_style": "default" + }, + { + "id": "reject", + "title": "NO", + "button_style": "default" + } + ], + "display_in_ui": true, + "node_id": "1775717346519", + "node_title": "Human Input", + "resolved_default_values": { + "some_field": "a4959eb4-c852-4e0c-ac7a-348233f7f345" + }, + "form_token": "0Tb1nXYe4hzQUD706nHB4y", + "type": "human_input_required", + "expiration_time": 1776388428 + } + ], + "status": "paused", + "created_at": 1776129228, + "elapsed_time": 0.070478, + "total_tokens": 0, + "total_steps": 2 + } + } + ``` + ```json {{ title: 'File upload sample code' }} import requests @@ -457,6 +758,24 @@ Workflow applications offers non-session support and is ideal for translation, a - `total_price` (decimal) optional total cost - `currency` (string) optional currency, such as `USD` / `RMB` - `created_at` (timestamp) timestamp of start, e.g., 1705395332 + - `event: human_input_required` Workflow paused and requires Human-in-the-Loop input + - `task_id` (string) Task ID, used for request tracking + - `workflow_run_id` (string) Unique ID of workflow execution + - `event` (string) fixed to `human_input_required` + - `data` (object) detail + - `form_id` (string) Human input form ID + - `node_id` (string) Human input node ID + - `node_title` (string) Human input node title + - `form_content` (string) Rendered form content + - `inputs` (array[object]) Input field definitions + - `actions` (array[object]) User action buttons + - `id` (string) Action ID + - `title` (string) Button text + - `button_style` (string) Button style + - `display_in_ui` (bool) Whether this form should be shown in UI + - `form_token` (string) Token used by `/form/human_input/:form_token` APIs + - `resolved_default_values` (object) Runtime-resolved default values + - `expiration_time` (timestamp) Form expiration time (Unix seconds) - `event: workflow_finished` workflow execution finished, success and failure are different states in the same event - `task_id` (string) Task ID, used for request tracking and the below Stop Generate API - `workflow_run_id` (string) Unique ID of workflow execution @@ -666,6 +985,198 @@ Workflow applications offers non-session support and is ideal for translation, a --- + + + + Retrieve a pending Human-in-the-Loop form by `form_token`. + + Use this endpoint when a workflow pauses with `human_input_required` and returns a `form_token`. + + ### Path + - `form_token` (string) Required, token returned by the pause event. + + ### Response + - `form_content` (string) Rendered form content (markdown/plain text) + - `inputs` (array[object]) Form input definitions + - `resolved_default_values` (object) Default values resolved to strings + - `user_actions` (array[object]) Action buttons + - `expiration_time` (timestamp) Form expiration time (Unix seconds) + + ### Errors + - 404, form not found or does not belong to current app + - 412, `human_input_form_submitted`, form already submitted + - 412, `human_input_form_expired`, form expired + + + + + + ```json {{ title: 'Response' }} + { + "form_content": "Please confirm the final answer: {{#$output.answer#}}", + "inputs": [ + { + "label": "Answer", + "type": "text-input", + "required": true, + "output_variable_name": "answer" + } + ], + "resolved_default_values": { + "answer": "Initial value" + }, + "user_actions": [ + { "id": "approve", "title": "Approve", "button_style": "primary" }, + { "id": "reject", "title": "Reject", "button_style": "warning" } + ], + "expiration_time": 1735689600 + } + ``` + + + + +--- + + + + + Submit a pending Human-in-the-Loop form. + + ### Path + - `form_token` (string) Required, token returned by the pause event. + + ### Request Body + - `inputs` (object) Required, key/value pairs for form fields. + - `action` (string) Required, selected action ID from `user_actions`. + - `user` (string) Required, end-user identifier. + + ### Response + Returns an empty object on success. + + ### Errors + - 400, `invalid_form_data`, submitted data does not match the form schema + - 404, form not found or does not belong to current app + - 412, `human_input_form_submitted`, form already submitted + - 412, `human_input_form_expired`, form expired + + + + + + ```json {{ title: 'Response' }} + {} + ``` + + + + +--- + + + + + Continue receiving workflow events after submitting a human input form. + + This endpoint returns `text/event-stream` and can be used to observe the resumed run until completion. + + ### Path + - `task_id` (string) Required, workflow run ID (`workflow_run_id`). + + ### Query + - `user` (string) Required, end-user identifier. + - `include_state_snapshot` (bool) Optional, set to `true` to replay from persisted state snapshot before continuing with live events. + - `continue_on_pause` (bool) Optional, set to `true` to keep the stream open across `workflow_paused` events until `workflow_finished`. + + ### Response + Server-Sent Events stream (`text/event-stream`). + Typical events include `workflow_paused`, `node_started`, `node_finished`, `human_input_form_filled`, `human_input_form_timeout`, and `workflow_finished`. + If the workflow has already finished when you call this endpoint, the server returns a single finished event immediately. + + + + + + ```streaming {{ title: 'Response' }} + event: ping + + data: {"event":"workflow_started","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","data":{"id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","workflow_id":"e46514f1-c008-41ff-94b0-4f33d4b97d36","inputs":{"sys.files":[],"sys.user_id":"abc-123","sys.app_id":"d1074979-f67e-4114-8691-e35878df9a89","sys.workflow_id":"e46514f1-c008-41ff-94b0-4f33d4b97d36","sys.workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","sys.timestamp":1776087863},"created_at":1776087863,"reason":"initial"}} + + data: {"event":"node_started","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","data":{"id":"b552d685-1119-4e6a-9a81-e91a23e5324b","node_id":"1775717266623","node_type":"start","title":"User Input","index":1,"predecessor_node_id":null,"inputs":null,"created_at":1776087863,"extras":{},"iteration_id":null,"loop_id":null}} + + data: {"event":"node_finished","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","data":{"id":"b552d685-1119-4e6a-9a81-e91a23e5324b","node_id":"1775717266623","node_type":"start","title":"User Input","index":1,"predecessor_node_id":null,"inputs":null,"process_data":null,"outputs":null,"status":"succeeded","error":null,"elapsed_time":0.00032,"execution_metadata":null,"created_at":1776087863,"finished_at":1776087863,"files":[],"iteration_id":null,"loop_id":null}} + + data: {"event":"node_started","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","data":{"id":"8d7e8e01-5159-4089-a4b6-3aa394992cc2","node_id":"1775717346519","node_type":"human-input","title":"Human Input","index":2,"predecessor_node_id":null,"inputs":null,"created_at":1776087863,"extras":{},"iteration_id":null,"loop_id":null}} + + data: {"event":"node_finished","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","data":{"id":"8d7e8e01-5159-4089-a4b6-3aa394992cc2","node_id":"1775717346519","node_type":"human-input","title":"Human Input","index":2,"predecessor_node_id":null,"inputs":null,"process_data":null,"outputs":null,"status":"paused","error":null,"elapsed_time":0.007381,"execution_metadata":null,"created_at":1776087863,"finished_at":1776087863,"files":[],"iteration_id":null,"loop_id":null}} + + data: {"event":"workflow_paused","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","data":{"workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","paused_nodes":["1775717346519"],"outputs":{},"reasons":[{"form_id":"019d8716-0fde-75da-8207-1458ccde76e5","form_content":"this is form 1:\n{{#$output.some_field#}}\n","inputs":[{"type":"paragraph","output_variable_name":"some_field","default":{"type":"variable","selector":["sys","workflow_run_id"],"value":""}}],"actions":[{"id":"approve","title":"YES","button_style":"default"},{"id":"reject","title":"NO","button_style":"default"}],"display_in_ui":true,"node_id":"1775717346519","node_title":"Human Input","resolved_default_values":{"some_field":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c"},"form_token":"n7hFG4ZDYdGcgZ5VDc7EGM","type":"human_input_required"}],"status":"paused","created_at":1776087863,"elapsed_time":0.0,"total_tokens":0,"total_steps":2}} + + data: {"event":"workflow_started","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","workflow_id":"e46514f1-c008-41ff-94b0-4f33d4b97d36","inputs":{"sys.files":[],"sys.user_id":"abc-123","sys.app_id":"d1074979-f67e-4114-8691-e35878df9a89","sys.workflow_id":"e46514f1-c008-41ff-94b0-4f33d4b97d36","sys.workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c"},"created_at":1776087877,"reason":"resumption"}} + + data: {"event":"node_started","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"id":"8d7e8e01-5159-4089-a4b6-3aa394992cc2","node_id":"1775717346519","node_type":"human-input","title":"Human Input","index":1,"predecessor_node_id":null,"inputs":null,"inputs_truncated":false,"created_at":1776087877,"extras":{},"iteration_id":null,"loop_id":null,"agent_strategy":null}} + + data: {"event":"human_input_form_filled","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"node_id":"1775717346519","node_title":"Human Input","rendered_content":"this is form 1:\nfield 1 filled!\n","action_id":"approve","action_text":"YES"}} + + data: {"event":"node_finished","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"id":"8d7e8e01-5159-4089-a4b6-3aa394992cc2","node_id":"1775717346519","node_type":"human-input","title":"Human Input","index":1,"predecessor_node_id":null,"inputs":{},"inputs_truncated":false,"process_data":{},"process_data_truncated":false,"outputs":{"some_field":"field 1 filled!","some_field_2":"from bruno with love","__action_id":"approve","__rendered_content":"this is form 1:\nfield 1 filled!\n"},"outputs_truncated":false,"status":"succeeded","error":null,"elapsed_time":0.004431,"execution_metadata":null,"created_at":1776087877,"finished_at":1776087877,"files":[],"iteration_id":null,"loop_id":null}} + + data: {"event":"node_started","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"id":"6d8fc3cb-19f7-440b-b83e-eed4e847a332","node_id":"1775717350710","node_type":"template-transform","title":"Template","index":1,"predecessor_node_id":null,"inputs":null,"inputs_truncated":false,"created_at":1776087877,"extras":{},"iteration_id":null,"loop_id":null,"agent_strategy":null}} + + data: {"event":"text_chunk","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"text":"field 1 filled!","from_variable_selector":["1775717350710","output"]}} + + data: {"event":"node_finished","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"id":"6d8fc3cb-19f7-440b-b83e-eed4e847a332","node_id":"1775717350710","node_type":"template-transform","title":"Template","index":1,"predecessor_node_id":null,"inputs":{"some_field":"field 1 filled!"},"inputs_truncated":false,"process_data":{},"process_data_truncated":false,"outputs":{"output":"field 1 filled!"},"outputs_truncated":false,"status":"succeeded","error":null,"elapsed_time":0.264614,"execution_metadata":null,"created_at":1776087877,"finished_at":1776087877,"files":[],"iteration_id":null,"loop_id":null}} + + data: {"event":"node_started","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"id":"e88dec7e-aa2c-41f7-8d73-032b749e23f5","node_id":"1775717354177","node_type":"end","title":"Output","index":1,"predecessor_node_id":null,"inputs":null,"inputs_truncated":false,"created_at":1776087877,"extras":{},"iteration_id":null,"loop_id":null,"agent_strategy":null}} + + data: {"event":"node_finished","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"id":"e88dec7e-aa2c-41f7-8d73-032b749e23f5","node_id":"1775717354177","node_type":"end","title":"Output","index":1,"predecessor_node_id":null,"inputs":{"output":"field 1 filled!"},"inputs_truncated":false,"process_data":{},"process_data_truncated":false,"outputs":{"output":"field 1 filled!"},"outputs_truncated":false,"status":"succeeded","error":null,"elapsed_time":0.00003,"execution_metadata":null,"created_at":1776087877,"finished_at":1776087877,"files":[],"iteration_id":null,"loop_id":null}} + + data: {"event":"workflow_finished","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","workflow_id":"e46514f1-c008-41ff-94b0-4f33d4b97d36","status":"succeeded","outputs":{"output":"field 1 filled!"},"error":null,"elapsed_time":0.364935,"total_tokens":0,"total_steps":5,"created_by":{"id":"7932d34c-dcf4-4fba-b770-f2a9de88c0a0","user":"abc-123"},"created_at":1776087877,"finished_at":1776087877,"exceptions_count":0,"files":[]}} + ``` + + + + +--- + ### ブロッキングモード + ブロッキングモードでは、完了済みのワークフロー結果、または一時停止中のワークフロー結果のいずれかが返されます。 + + Human-in-the-Loop で実行が一時停止した場合も、レスポンスには `workflow_run_id` と `task_id` が含まれますが、`data.status` は `paused` になります。 + 一時停止レスポンスには `paused_nodes` と `reasons` も含まれます。`human_input_required` の各 reason には `form_id` と `expiration_time` が含まれます。 + ```json {{ title: '応答' }} { @@ -236,6 +259,70 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' } ``` + + ```json {{ title: '一時停止レスポンス例' }} + { + "task_id": "3938b985-f4c6-4806-87b6-215e0aca9d81", + "workflow_run_id": "4a80f375-682b-49c5-b199-e950aac4968f", + "data": { + "id": "4a80f375-682b-49c5-b199-e950aac4968f", + "workflow_id": "e46514f1-c008-41ff-94b0-4f33d4b97d36", + "status": "paused", + "outputs": {}, + "error": null, + "elapsed_time": 0.035667, + "total_tokens": 0, + "total_steps": 2, + "created_at": 1776074783, + "finished_at": null, + "paused_nodes": [ + "1775717346519" + ], + "reasons": [ + { + "form_id": "019d864e-7a36-74a2-b94e-e5660c47f5a7", + "form_content": "this is form 1:\n{{#$output.some_field#}}\n", + "inputs": [ + { + "type": "paragraph", + "output_variable_name": "some_field", + "default": { + "type": "variable", + "selector": [ + "sys", + "workflow_run_id" + ], + "value": "" + } + } + ], + "actions": [ + { + "id": "approve", + "title": "YES", + "button_style": "default" + }, + { + "id": "reject", + "title": "NO", + "button_style": "default" + } + ], + "display_in_ui": true, + "node_id": "1775717346519", + "node_title": "Human Input", + "resolved_default_values": { + "some_field": "4a80f375-682b-49c5-b199-e950aac4968f" + }, + "form_token": "SZwvfmL47fTIsZynP2Jr9i", + "type": "human_input_required", + "expiration_time": 1776333983 + } + ] + } + } + ``` + ### ストリーミングモード ```streaming {{ title: '応答' }} @@ -247,6 +334,220 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' data: {"event": "tts_message_end", "conversation_id": "23dd85f3-1a41-4ea0-b7a9-062734ccfaf9", "message_id": "a8bdc41c-13b2-4c18-bfd9-054b9803038c", "created_at": 1721205487, "task_id": "3bf8a0bb-e73b-4690-9e66-4e429bad8ee7", "audio": ""} ``` + ストリーミングモードでも Human-in-the-Loop により一時停止する場合があります。その場合、SSE ストリームではまず `human_input_required` が送られ、その後に `workflow_paused` が送られます。 + + + ```streaming {{ title: '一時停止ストリーミングレスポンス例' }} + event: ping + + data: { + "event": "workflow_started", + "workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "task_id": "0399c5c2-181b-4493-a78e-1421914e8a25", + "data": { + "id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "workflow_id": "e46514f1-c008-41ff-94b0-4f33d4b97d36", + "inputs": { + "sys.files": [], + "sys.user_id": "abc-123", + "sys.app_id": "d1074979-f67e-4114-8691-e35878df9a89", + "sys.workflow_id": "e46514f1-c008-41ff-94b0-4f33d4b97d36", + "sys.workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345" + }, + "created_at": 1776129228, + "reason": "initial" + } + } + + data: { + "event": "node_started", + "workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "task_id": "0399c5c2-181b-4493-a78e-1421914e8a25", + "data": { + "id": "7d9bb041-5ecb-497f-a674-d8706eed0ab1", + "node_id": "1775717266623", + "node_type": "start", + "title": "User Input", + "index": 1, + "predecessor_node_id": null, + "inputs": null, + "inputs_truncated": false, + "created_at": 1776129228, + "extras": {}, + "iteration_id": null, + "loop_id": null, + "agent_strategy": null + } + } + + data: { + "event": "node_finished", + "workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "task_id": "0399c5c2-181b-4493-a78e-1421914e8a25", + "data": { + "id": "7d9bb041-5ecb-497f-a674-d8706eed0ab1", + "node_id": "1775717266623", + "node_type": "start", + "title": "User Input", + "index": 1, + "predecessor_node_id": null, + "inputs": { + "sys.files": [], + "sys.user_id": "abc-123", + "sys.app_id": "d1074979-f67e-4114-8691-e35878df9a89", + "sys.workflow_id": "e46514f1-c008-41ff-94b0-4f33d4b97d36", + "sys.workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "sys.timestamp": 1776129228 + }, + "inputs_truncated": false, + "process_data": {}, + "process_data_truncated": false, + "outputs": { + "sys.files": [], + "sys.user_id": "abc-123", + "sys.app_id": "d1074979-f67e-4114-8691-e35878df9a89", + "sys.workflow_id": "e46514f1-c008-41ff-94b0-4f33d4b97d36", + "sys.workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "sys.timestamp": 1776129228 + }, + "outputs_truncated": false, + "status": "succeeded", + "error": null, + "elapsed_time": 0.000097, + "execution_metadata": null, + "created_at": 1776129228, + "finished_at": 1776129228, + "files": [], + "iteration_id": null, + "loop_id": null + } + } + + data: { + "event": "node_started", + "workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "task_id": "0399c5c2-181b-4493-a78e-1421914e8a25", + "data": { + "id": "c09ff568-1d55-4f0d-9a07-512bcbfeb289", + "node_id": "1775717346519", + "node_type": "human-input", + "title": "Human Input", + "index": 1, + "predecessor_node_id": null, + "inputs": null, + "inputs_truncated": false, + "created_at": 1776129228, + "extras": {}, + "iteration_id": null, + "loop_id": null, + "agent_strategy": null + } + } + + data: { + "event": "human_input_required", + "workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "task_id": "0399c5c2-181b-4493-a78e-1421914e8a25", + "data": { + "form_id": "019d898d-3d80-7105-b920-9899ead4ff3e", + "node_id": "1775717346519", + "node_title": "Human Input", + "form_content": "this is form 1:\n{{#$output.some_field#}}\n", + "inputs": [ + { + "type": "paragraph", + "output_variable_name": "some_field", + "default": { + "type": "variable", + "selector": [ + "sys", + "workflow_run_id" + ], + "value": "" + } + } + ], + "actions": [ + { + "id": "approve", + "title": "YES", + "button_style": "default" + }, + { + "id": "reject", + "title": "NO", + "button_style": "default" + } + ], + "display_in_ui": true, + "form_token": "0Tb1nXYe4hzQUD706nHB4y", + "resolved_default_values": { + "some_field": "a4959eb4-c852-4e0c-ac7a-348233f7f345" + }, + "expiration_time": 1776388428 + } + } + + data: { + "event": "workflow_paused", + "workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "task_id": "0399c5c2-181b-4493-a78e-1421914e8a25", + "data": { + "workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "paused_nodes": [ + "1775717346519" + ], + "outputs": {}, + "reasons": [ + { + "form_id": "019d898d-3d80-7105-b920-9899ead4ff3e", + "form_content": "this is form 1:\n{{#$output.some_field#}}\n", + "inputs": [ + { + "type": "paragraph", + "output_variable_name": "some_field", + "default": { + "type": "variable", + "selector": [ + "sys", + "workflow_run_id" + ], + "value": "" + } + } + ], + "actions": [ + { + "id": "approve", + "title": "YES", + "button_style": "default" + }, + { + "id": "reject", + "title": "NO", + "button_style": "default" + } + ], + "display_in_ui": true, + "node_id": "1775717346519", + "node_title": "Human Input", + "resolved_default_values": { + "some_field": "a4959eb4-c852-4e0c-ac7a-348233f7f345" + }, + "form_token": "0Tb1nXYe4hzQUD706nHB4y", + "type": "human_input_required", + "expiration_time": 1776388428 + } + ], + "status": "paused", + "created_at": 1776129228, + "elapsed_time": 0.070478, + "total_tokens": 0, + "total_steps": 2 + } + } + ``` + ```json {{ title: 'ファイルアップロードのサンプルコード' }} import requests @@ -452,6 +753,24 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' - `total_price` (decimal) オプション 総費用 - `currency` (string) オプション 通貨、例:`USD` / `RMB` - `created_at` (timestamp) 開始時間 + - `event: human_input_required` ワークフローが一時停止し、Human-in-the-Loop 入力が必要 + - `task_id` (string) タスク ID、リクエスト追跡に使用 + - `workflow_run_id` (string) ワークフロー実行 ID + - `event` (string) `human_input_required` に固定 + - `data` (object) 詳細内容 + - `form_id` (string) ヒューマン入力フォーム ID + - `node_id` (string) Human Input ノード ID + - `node_title` (string) Human Input ノードタイトル + - `form_content` (string) レンダリング済みフォーム内容 + - `inputs` (array[object]) フォーム入力項目の定義 + - `actions` (array[object]) ユーザーが選択できるアクションボタン + - `id` (string) アクション ID + - `title` (string) ボタンラベル + - `button_style` (string) ボタンスタイル + - `display_in_ui` (bool) UI にこのフォームを表示するかどうか + - `form_token` (string) `/form/human_input/:form_token` API で使用するトークン + - `resolved_default_values` (object) 実行時に解決されたデフォルト値 + - `expiration_time` (timestamp) フォームの有効期限(Unix 秒) - `event: workflow_finished` ワークフロー実行終了、成功と失敗は同じイベント内の異なる状態 - `task_id` (string) タスクID、リクエスト追跡と以下の停止応答インターフェースに使用 - `workflow_run_id` (string) ワークフロー実行ID @@ -661,6 +980,198 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' --- + + + + `form_token` から保留中の Human-in-the-Loop フォームを取得します。 + + Workflow が `human_input_required`(`form_token` を含む)で一時停止した際に使用します。 + + ### パス + - `form_token` (string) 必須、一時停止イベントで返されたフォームトークン + + ### 応答 + - `form_content` (string) レンダリング済みフォーム内容(markdown/plain text) + - `inputs` (array[object]) 入力項目定義 + - `resolved_default_values` (object) 解決済みデフォルト値(文字列) + - `user_actions` (array[object]) アクションボタン一覧 + - `expiration_time` (timestamp) フォーム有効期限(Unix 秒) + + ### エラー + - 404, フォームが存在しない、または現在のアプリに属していない + - 412, `human_input_form_submitted`, 既に送信済み + - 412, `human_input_form_expired`, 期限切れ + + + + + + ```json {{ title: '応答' }} + { + "form_content": "最終回答を確認してください: {{#$output.answer#}}", + "inputs": [ + { + "label": "回答", + "type": "text-input", + "required": true, + "output_variable_name": "answer" + } + ], + "resolved_default_values": { + "answer": "初期値" + }, + "user_actions": [ + { "id": "approve", "title": "承認", "button_style": "primary" }, + { "id": "reject", "title": "却下", "button_style": "warning" } + ], + "expiration_time": 1735689600 + } + ``` + + + + +--- + + + + + 保留中の Human-in-the-Loop フォームを送信します。 + + ### パス + - `form_token` (string) 必須、一時停止イベントで返されたフォームトークン + + ### リクエストボディ + - `inputs` (object) 必須、フォーム項目の key/value + - `action` (string) 必須、`user_actions` から選択したアクション ID + - `user` (string) 必須、エンドユーザー識別子 + + ### 応答 + 成功時は空オブジェクトを返します。 + + ### エラー + - 400, `invalid_form_data`, 送信データがフォームスキーマに一致しない + - 404, フォームが存在しない、または現在のアプリに属していない + - 412, `human_input_form_submitted`, 既に送信済み + - 412, `human_input_form_expired`, 期限切れ + + + + + + ```json {{ title: '応答' }} + {} + ``` + + + + +--- + + + + + Human Input フォーム送信後に、ワークフロー再開後のイベントを継続受信します。 + + このエンドポイントは `text/event-stream` を返し、完了までイベントを購読できます。 + + ### パス + - `task_id` (string) 必須、workflow 実行 ID(`workflow_run_id`) + + ### クエリ + - `user` (string) 必須、エンドユーザー識別子 + - `include_state_snapshot` (bool) 任意、`true` の場合は永続化済み状態スナップショットを先に再生してからリアルタイムイベントへ移行 + - `continue_on_pause` (bool) 任意、`true` にすると `workflow_paused` イベントをまたいでもストリームを維持し、`workflow_finished` で終了します + + ### 応答 + Server-Sent Events ストリーム(`text/event-stream`)。 + 主なイベントは `workflow_paused`、`node_started`、`node_finished`、`human_input_form_filled`、`human_input_form_timeout`、`workflow_finished` です。 + 呼び出し時点でワークフローがすでに完了している場合、このエンドポイントは完了イベントを 1 件だけ即座に返します。 + + + + + + ```streaming {{ title: '応答' }} + event: ping + + data: {"event":"workflow_started","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","data":{"id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","workflow_id":"e46514f1-c008-41ff-94b0-4f33d4b97d36","inputs":{"sys.files":[],"sys.user_id":"abc-123","sys.app_id":"d1074979-f67e-4114-8691-e35878df9a89","sys.workflow_id":"e46514f1-c008-41ff-94b0-4f33d4b97d36","sys.workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","sys.timestamp":1776087863},"created_at":1776087863,"reason":"initial"}} + + data: {"event":"node_started","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","data":{"id":"b552d685-1119-4e6a-9a81-e91a23e5324b","node_id":"1775717266623","node_type":"start","title":"User Input","index":1,"predecessor_node_id":null,"inputs":null,"created_at":1776087863,"extras":{},"iteration_id":null,"loop_id":null}} + + data: {"event":"node_finished","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","data":{"id":"b552d685-1119-4e6a-9a81-e91a23e5324b","node_id":"1775717266623","node_type":"start","title":"User Input","index":1,"predecessor_node_id":null,"inputs":null,"process_data":null,"outputs":null,"status":"succeeded","error":null,"elapsed_time":0.00032,"execution_metadata":null,"created_at":1776087863,"finished_at":1776087863,"files":[],"iteration_id":null,"loop_id":null}} + + data: {"event":"node_started","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","data":{"id":"8d7e8e01-5159-4089-a4b6-3aa394992cc2","node_id":"1775717346519","node_type":"human-input","title":"Human Input","index":2,"predecessor_node_id":null,"inputs":null,"created_at":1776087863,"extras":{},"iteration_id":null,"loop_id":null}} + + data: {"event":"node_finished","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","data":{"id":"8d7e8e01-5159-4089-a4b6-3aa394992cc2","node_id":"1775717346519","node_type":"human-input","title":"Human Input","index":2,"predecessor_node_id":null,"inputs":null,"process_data":null,"outputs":null,"status":"paused","error":null,"elapsed_time":0.007381,"execution_metadata":null,"created_at":1776087863,"finished_at":1776087863,"files":[],"iteration_id":null,"loop_id":null}} + + data: {"event":"workflow_paused","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","data":{"workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","paused_nodes":["1775717346519"],"outputs":{},"reasons":[{"form_id":"019d8716-0fde-75da-8207-1458ccde76e5","form_content":"this is form 1:\n{{#$output.some_field#}}\n","inputs":[{"type":"paragraph","output_variable_name":"some_field","default":{"type":"variable","selector":["sys","workflow_run_id"],"value":""}}],"actions":[{"id":"approve","title":"YES","button_style":"default"},{"id":"reject","title":"NO","button_style":"default"}],"display_in_ui":true,"node_id":"1775717346519","node_title":"Human Input","resolved_default_values":{"some_field":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c"},"form_token":"n7hFG4ZDYdGcgZ5VDc7EGM","type":"human_input_required"}],"status":"paused","created_at":1776087863,"elapsed_time":0.0,"total_tokens":0,"total_steps":2}} + + data: {"event":"workflow_started","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","workflow_id":"e46514f1-c008-41ff-94b0-4f33d4b97d36","inputs":{"sys.files":[],"sys.user_id":"abc-123","sys.app_id":"d1074979-f67e-4114-8691-e35878df9a89","sys.workflow_id":"e46514f1-c008-41ff-94b0-4f33d4b97d36","sys.workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c"},"created_at":1776087877,"reason":"resumption"}} + + data: {"event":"node_started","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"id":"8d7e8e01-5159-4089-a4b6-3aa394992cc2","node_id":"1775717346519","node_type":"human-input","title":"Human Input","index":1,"predecessor_node_id":null,"inputs":null,"inputs_truncated":false,"created_at":1776087877,"extras":{},"iteration_id":null,"loop_id":null,"agent_strategy":null}} + + data: {"event":"human_input_form_filled","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"node_id":"1775717346519","node_title":"Human Input","rendered_content":"this is form 1:\nfield 1 filled!\n","action_id":"approve","action_text":"YES"}} + + data: {"event":"node_finished","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"id":"8d7e8e01-5159-4089-a4b6-3aa394992cc2","node_id":"1775717346519","node_type":"human-input","title":"Human Input","index":1,"predecessor_node_id":null,"inputs":{},"inputs_truncated":false,"process_data":{},"process_data_truncated":false,"outputs":{"some_field":"field 1 filled!","some_field_2":"from bruno with love","__action_id":"approve","__rendered_content":"this is form 1:\nfield 1 filled!\n"},"outputs_truncated":false,"status":"succeeded","error":null,"elapsed_time":0.004431,"execution_metadata":null,"created_at":1776087877,"finished_at":1776087877,"files":[],"iteration_id":null,"loop_id":null}} + + data: {"event":"node_started","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"id":"6d8fc3cb-19f7-440b-b83e-eed4e847a332","node_id":"1775717350710","node_type":"template-transform","title":"Template","index":1,"predecessor_node_id":null,"inputs":null,"inputs_truncated":false,"created_at":1776087877,"extras":{},"iteration_id":null,"loop_id":null,"agent_strategy":null}} + + data: {"event":"text_chunk","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"text":"field 1 filled!","from_variable_selector":["1775717350710","output"]}} + + data: {"event":"node_finished","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"id":"6d8fc3cb-19f7-440b-b83e-eed4e847a332","node_id":"1775717350710","node_type":"template-transform","title":"Template","index":1,"predecessor_node_id":null,"inputs":{"some_field":"field 1 filled!"},"inputs_truncated":false,"process_data":{},"process_data_truncated":false,"outputs":{"output":"field 1 filled!"},"outputs_truncated":false,"status":"succeeded","error":null,"elapsed_time":0.264614,"execution_metadata":null,"created_at":1776087877,"finished_at":1776087877,"files":[],"iteration_id":null,"loop_id":null}} + + data: {"event":"node_started","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"id":"e88dec7e-aa2c-41f7-8d73-032b749e23f5","node_id":"1775717354177","node_type":"end","title":"Output","index":1,"predecessor_node_id":null,"inputs":null,"inputs_truncated":false,"created_at":1776087877,"extras":{},"iteration_id":null,"loop_id":null,"agent_strategy":null}} + + data: {"event":"node_finished","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"id":"e88dec7e-aa2c-41f7-8d73-032b749e23f5","node_id":"1775717354177","node_type":"end","title":"Output","index":1,"predecessor_node_id":null,"inputs":{"output":"field 1 filled!"},"inputs_truncated":false,"process_data":{},"process_data_truncated":false,"outputs":{"output":"field 1 filled!"},"outputs_truncated":false,"status":"succeeded","error":null,"elapsed_time":0.00003,"execution_metadata":null,"created_at":1776087877,"finished_at":1776087877,"files":[],"iteration_id":null,"loop_id":null}} + + data: {"event":"workflow_finished","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","workflow_id":"e46514f1-c008-41ff-94b0-4f33d4b97d36","status":"succeeded","outputs":{"output":"field 1 filled!"},"error":null,"elapsed_time":0.364935,"total_tokens":0,"total_steps":5,"created_by":{"id":"7932d34c-dcf4-4fba-b770-f2a9de88c0a0","user":"abc-123"},"created_at":1776087877,"finished_at":1776087877,"exceptions_count":0,"files":[]}} + ``` + + + + +--- + ### Blocking Mode + 阻塞模式可能返回已完成的工作流结果,也可能返回暂停中的工作流结果。 + + 当执行因 Human-in-the-Loop 暂停时,响应仍然会包含 `workflow_run_id` 和 `task_id`,但 `data.status` 会变为 `paused`。 + 暂停响应还会包含 `paused_nodes` 和 `reasons`。对于 `human_input_required`,每个 reason 都会包含 `form_id` 和 `expiration_time`。 + ```json {{ title: 'Response' }} { @@ -226,6 +249,70 @@ Workflow 应用无会话支持,适合用于翻译/文章写作/总结 AI 等 } ``` + + ```json {{ title: 'Paused Response Example' }} + { + "task_id": "3938b985-f4c6-4806-87b6-215e0aca9d81", + "workflow_run_id": "4a80f375-682b-49c5-b199-e950aac4968f", + "data": { + "id": "4a80f375-682b-49c5-b199-e950aac4968f", + "workflow_id": "e46514f1-c008-41ff-94b0-4f33d4b97d36", + "status": "paused", + "outputs": {}, + "error": null, + "elapsed_time": 0.035667, + "total_tokens": 0, + "total_steps": 2, + "created_at": 1776074783, + "finished_at": null, + "paused_nodes": [ + "1775717346519" + ], + "reasons": [ + { + "form_id": "019d864e-7a36-74a2-b94e-e5660c47f5a7", + "form_content": "this is form 1:\n{{#$output.some_field#}}\n", + "inputs": [ + { + "type": "paragraph", + "output_variable_name": "some_field", + "default": { + "type": "variable", + "selector": [ + "sys", + "workflow_run_id" + ], + "value": "" + } + } + ], + "actions": [ + { + "id": "approve", + "title": "YES", + "button_style": "default" + }, + { + "id": "reject", + "title": "NO", + "button_style": "default" + } + ], + "display_in_ui": true, + "node_id": "1775717346519", + "node_title": "Human Input", + "resolved_default_values": { + "some_field": "4a80f375-682b-49c5-b199-e950aac4968f" + }, + "form_token": "SZwvfmL47fTIsZynP2Jr9i", + "type": "human_input_required", + "expiration_time": 1776333983 + } + ] + } + } + ``` + ### Streaming Mode ```streaming {{ title: 'Response' }} @@ -237,6 +324,220 @@ Workflow 应用无会话支持,适合用于翻译/文章写作/总结 AI 等 data: {"event": "tts_message_end", "conversation_id": "23dd85f3-1a41-4ea0-b7a9-062734ccfaf9", "message_id": "a8bdc41c-13b2-4c18-bfd9-054b9803038c", "created_at": 1721205487, "task_id": "3bf8a0bb-e73b-4690-9e66-4e429bad8ee7", "audio": ""} ``` + 流式模式同样可能因为 Human-in-the-Loop 而暂停。此时 SSE 流会先返回 `human_input_required`,随后返回 `workflow_paused`。 + + + ```streaming {{ title: 'Paused Streaming Response Example' }} + event: ping + + data: { + "event": "workflow_started", + "workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "task_id": "0399c5c2-181b-4493-a78e-1421914e8a25", + "data": { + "id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "workflow_id": "e46514f1-c008-41ff-94b0-4f33d4b97d36", + "inputs": { + "sys.files": [], + "sys.user_id": "abc-123", + "sys.app_id": "d1074979-f67e-4114-8691-e35878df9a89", + "sys.workflow_id": "e46514f1-c008-41ff-94b0-4f33d4b97d36", + "sys.workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345" + }, + "created_at": 1776129228, + "reason": "initial" + } + } + + data: { + "event": "node_started", + "workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "task_id": "0399c5c2-181b-4493-a78e-1421914e8a25", + "data": { + "id": "7d9bb041-5ecb-497f-a674-d8706eed0ab1", + "node_id": "1775717266623", + "node_type": "start", + "title": "User Input", + "index": 1, + "predecessor_node_id": null, + "inputs": null, + "inputs_truncated": false, + "created_at": 1776129228, + "extras": {}, + "iteration_id": null, + "loop_id": null, + "agent_strategy": null + } + } + + data: { + "event": "node_finished", + "workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "task_id": "0399c5c2-181b-4493-a78e-1421914e8a25", + "data": { + "id": "7d9bb041-5ecb-497f-a674-d8706eed0ab1", + "node_id": "1775717266623", + "node_type": "start", + "title": "User Input", + "index": 1, + "predecessor_node_id": null, + "inputs": { + "sys.files": [], + "sys.user_id": "abc-123", + "sys.app_id": "d1074979-f67e-4114-8691-e35878df9a89", + "sys.workflow_id": "e46514f1-c008-41ff-94b0-4f33d4b97d36", + "sys.workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "sys.timestamp": 1776129228 + }, + "inputs_truncated": false, + "process_data": {}, + "process_data_truncated": false, + "outputs": { + "sys.files": [], + "sys.user_id": "abc-123", + "sys.app_id": "d1074979-f67e-4114-8691-e35878df9a89", + "sys.workflow_id": "e46514f1-c008-41ff-94b0-4f33d4b97d36", + "sys.workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "sys.timestamp": 1776129228 + }, + "outputs_truncated": false, + "status": "succeeded", + "error": null, + "elapsed_time": 0.000097, + "execution_metadata": null, + "created_at": 1776129228, + "finished_at": 1776129228, + "files": [], + "iteration_id": null, + "loop_id": null + } + } + + data: { + "event": "node_started", + "workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "task_id": "0399c5c2-181b-4493-a78e-1421914e8a25", + "data": { + "id": "c09ff568-1d55-4f0d-9a07-512bcbfeb289", + "node_id": "1775717346519", + "node_type": "human-input", + "title": "Human Input", + "index": 1, + "predecessor_node_id": null, + "inputs": null, + "inputs_truncated": false, + "created_at": 1776129228, + "extras": {}, + "iteration_id": null, + "loop_id": null, + "agent_strategy": null + } + } + + data: { + "event": "human_input_required", + "workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "task_id": "0399c5c2-181b-4493-a78e-1421914e8a25", + "data": { + "form_id": "019d898d-3d80-7105-b920-9899ead4ff3e", + "node_id": "1775717346519", + "node_title": "Human Input", + "form_content": "this is form 1:\n{{#$output.some_field#}}\n", + "inputs": [ + { + "type": "paragraph", + "output_variable_name": "some_field", + "default": { + "type": "variable", + "selector": [ + "sys", + "workflow_run_id" + ], + "value": "" + } + } + ], + "actions": [ + { + "id": "approve", + "title": "YES", + "button_style": "default" + }, + { + "id": "reject", + "title": "NO", + "button_style": "default" + } + ], + "display_in_ui": true, + "form_token": "0Tb1nXYe4hzQUD706nHB4y", + "resolved_default_values": { + "some_field": "a4959eb4-c852-4e0c-ac7a-348233f7f345" + }, + "expiration_time": 1776388428 + } + } + + data: { + "event": "workflow_paused", + "workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "task_id": "0399c5c2-181b-4493-a78e-1421914e8a25", + "data": { + "workflow_run_id": "a4959eb4-c852-4e0c-ac7a-348233f7f345", + "paused_nodes": [ + "1775717346519" + ], + "outputs": {}, + "reasons": [ + { + "form_id": "019d898d-3d80-7105-b920-9899ead4ff3e", + "form_content": "this is form 1:\n{{#$output.some_field#}}\n", + "inputs": [ + { + "type": "paragraph", + "output_variable_name": "some_field", + "default": { + "type": "variable", + "selector": [ + "sys", + "workflow_run_id" + ], + "value": "" + } + } + ], + "actions": [ + { + "id": "approve", + "title": "YES", + "button_style": "default" + }, + { + "id": "reject", + "title": "NO", + "button_style": "default" + } + ], + "display_in_ui": true, + "node_id": "1775717346519", + "node_title": "Human Input", + "resolved_default_values": { + "some_field": "a4959eb4-c852-4e0c-ac7a-348233f7f345" + }, + "form_token": "0Tb1nXYe4hzQUD706nHB4y", + "type": "human_input_required", + "expiration_time": 1776388428 + } + ], + "status": "paused", + "created_at": 1776129228, + "elapsed_time": 0.070478, + "total_tokens": 0, + "total_steps": 2 + } + } + ``` + ```json {{ title: 'File upload sample code' }} import requests @@ -445,6 +746,24 @@ Workflow 应用无会话支持,适合用于翻译/文章写作/总结 AI 等 - `total_price` (decimal) optional 总费用 - `currency` (string) optional 货币,如 `USD` / `RMB` - `created_at` (timestamp) 开始时间 + - `event: human_input_required` Workflow 已暂停,等待 Human-in-the-Loop 输入 + - `task_id` (string) 任务 ID,用于请求跟踪 + - `workflow_run_id` (string) workflow 执行 ID + - `event` (string) 固定为 `human_input_required` + - `data` (object) 详细内容 + - `form_id` (string) 人工输入表单 ID + - `node_id` (string) Human Input 节点 ID + - `node_title` (string) Human Input 节点标题 + - `form_content` (string) 渲染后的表单内容 + - `inputs` (array[object]) 表单输入项定义 + - `actions` (array[object]) 用户可选动作按钮 + - `id` (string) 动作 ID + - `title` (string) 按钮文案 + - `button_style` (string) 按钮样式 + - `display_in_ui` (bool) 是否需要在 UI 展示该表单 + - `form_token` (string) 用于 `/form/human_input/:form_token` 接口的令牌 + - `resolved_default_values` (object) 运行时解析后的默认值 + - `expiration_time` (timestamp) 表单过期时间(Unix 秒级时间戳) - `event: workflow_finished` workflow 执行结束,成功失败同一事件中不同状态 - `task_id` (string) 任务 ID,用于请求跟踪和下方的停止响应接口 - `workflow_run_id` (string) workflow 执行 ID @@ -654,6 +973,198 @@ Workflow 应用无会话支持,适合用于翻译/文章写作/总结 AI 等 --- + + + + 通过 `form_token` 获取待处理的 Human-in-the-Loop 表单。 + + 当 Workflow 在流式事件中返回 `human_input_required`(包含 `form_token`)时,可调用此接口拉取表单详情。 + + ### Path + - `form_token` (string) 必填,暂停事件返回的表单 token + + ### Response + - `form_content` (string) 已渲染的表单内容(markdown/plain text) + - `inputs` (array[object]) 表单输入项定义 + - `resolved_default_values` (object) 已解析的默认值(字符串) + - `user_actions` (array[object]) 操作按钮列表 + - `expiration_time` (timestamp) 表单过期时间(Unix 秒) + + ### Errors + - 404,表单不存在或不属于当前应用 + - 412,`human_input_form_submitted`,表单已被提交 + - 412,`human_input_form_expired`,表单已过期 + + + + + + ```json {{ title: 'Response' }} + { + "form_content": "请确认最终结果:{{#$output.answer#}}", + "inputs": [ + { + "label": "答案", + "type": "text-input", + "required": true, + "output_variable_name": "answer" + } + ], + "resolved_default_values": { + "answer": "初始值" + }, + "user_actions": [ + { "id": "approve", "title": "通过", "button_style": "primary" }, + { "id": "reject", "title": "拒绝", "button_style": "warning" } + ], + "expiration_time": 1735689600 + } + ``` + + + + +--- + + + + + 提交待处理的 Human-in-the-Loop 表单。 + + ### Path + - `form_token` (string) 必填,暂停事件返回的表单 token + + ### Request Body + - `inputs` (object) 必填,表单字段的 key/value + - `action` (string) 必填,从 `user_actions` 中选择的动作 ID + - `user` (string) 必填,终端用户标识 + + ### Response + 成功时返回空对象。 + + ### Errors + - 400,`invalid_form_data`,提交数据与表单 schema 不匹配 + - 404,表单不存在或不属于当前应用 + - 412,`human_input_form_submitted`,表单已被提交 + - 412,`human_input_form_expired`,表单已过期 + + + + + + ```json {{ title: 'Response' }} + {} + ``` + + + + +--- + + + + + 在提交人工输入表单后,继续订阅工作流后续执行事件。 + + 返回 `text/event-stream`,可持续接收直到工作流结束。 + + ### Path + - `task_id` (string) 必填,workflow 运行 ID(`workflow_run_id`) + + ### Query + - `user` (string) 必填,终端用户标识 + - `include_state_snapshot` (bool) 可选,设为 `true` 时会先回放持久化状态快照,再继续实时事件 + - `continue_on_pause` (bool) 可选,设为 `true` 时,流会在 `workflow_paused` 事件之间保持连接,直到 `workflow_finished` 才结束 + + ### Response + Server-Sent Events 流(`text/event-stream`)。 + 常见事件包括 `workflow_paused`、`node_started`、`node_finished`、`human_input_form_filled`、`human_input_form_timeout`、`workflow_finished`。 + 如果调用该接口时工作流已经结束,服务端会立即返回单个完成事件。 + + + + + + ```streaming {{ title: 'Response' }} + event: ping + + data: {"event":"workflow_started","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","data":{"id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","workflow_id":"e46514f1-c008-41ff-94b0-4f33d4b97d36","inputs":{"sys.files":[],"sys.user_id":"abc-123","sys.app_id":"d1074979-f67e-4114-8691-e35878df9a89","sys.workflow_id":"e46514f1-c008-41ff-94b0-4f33d4b97d36","sys.workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","sys.timestamp":1776087863},"created_at":1776087863,"reason":"initial"}} + + data: {"event":"node_started","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","data":{"id":"b552d685-1119-4e6a-9a81-e91a23e5324b","node_id":"1775717266623","node_type":"start","title":"User Input","index":1,"predecessor_node_id":null,"inputs":null,"created_at":1776087863,"extras":{},"iteration_id":null,"loop_id":null}} + + data: {"event":"node_finished","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","data":{"id":"b552d685-1119-4e6a-9a81-e91a23e5324b","node_id":"1775717266623","node_type":"start","title":"User Input","index":1,"predecessor_node_id":null,"inputs":null,"process_data":null,"outputs":null,"status":"succeeded","error":null,"elapsed_time":0.00032,"execution_metadata":null,"created_at":1776087863,"finished_at":1776087863,"files":[],"iteration_id":null,"loop_id":null}} + + data: {"event":"node_started","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","data":{"id":"8d7e8e01-5159-4089-a4b6-3aa394992cc2","node_id":"1775717346519","node_type":"human-input","title":"Human Input","index":2,"predecessor_node_id":null,"inputs":null,"created_at":1776087863,"extras":{},"iteration_id":null,"loop_id":null}} + + data: {"event":"node_finished","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","data":{"id":"8d7e8e01-5159-4089-a4b6-3aa394992cc2","node_id":"1775717346519","node_type":"human-input","title":"Human Input","index":2,"predecessor_node_id":null,"inputs":null,"process_data":null,"outputs":null,"status":"paused","error":null,"elapsed_time":0.007381,"execution_metadata":null,"created_at":1776087863,"finished_at":1776087863,"files":[],"iteration_id":null,"loop_id":null}} + + data: {"event":"workflow_paused","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","data":{"workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","paused_nodes":["1775717346519"],"outputs":{},"reasons":[{"form_id":"019d8716-0fde-75da-8207-1458ccde76e5","form_content":"this is form 1:\n{{#$output.some_field#}}\n","inputs":[{"type":"paragraph","output_variable_name":"some_field","default":{"type":"variable","selector":["sys","workflow_run_id"],"value":""}}],"actions":[{"id":"approve","title":"YES","button_style":"default"},{"id":"reject","title":"NO","button_style":"default"}],"display_in_ui":true,"node_id":"1775717346519","node_title":"Human Input","resolved_default_values":{"some_field":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c"},"form_token":"n7hFG4ZDYdGcgZ5VDc7EGM","type":"human_input_required"}],"status":"paused","created_at":1776087863,"elapsed_time":0.0,"total_tokens":0,"total_steps":2}} + + data: {"event":"workflow_started","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","workflow_id":"e46514f1-c008-41ff-94b0-4f33d4b97d36","inputs":{"sys.files":[],"sys.user_id":"abc-123","sys.app_id":"d1074979-f67e-4114-8691-e35878df9a89","sys.workflow_id":"e46514f1-c008-41ff-94b0-4f33d4b97d36","sys.workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c"},"created_at":1776087877,"reason":"resumption"}} + + data: {"event":"node_started","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"id":"8d7e8e01-5159-4089-a4b6-3aa394992cc2","node_id":"1775717346519","node_type":"human-input","title":"Human Input","index":1,"predecessor_node_id":null,"inputs":null,"inputs_truncated":false,"created_at":1776087877,"extras":{},"iteration_id":null,"loop_id":null,"agent_strategy":null}} + + data: {"event":"human_input_form_filled","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"node_id":"1775717346519","node_title":"Human Input","rendered_content":"this is form 1:\nfield 1 filled!\n","action_id":"approve","action_text":"YES"}} + + data: {"event":"node_finished","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"id":"8d7e8e01-5159-4089-a4b6-3aa394992cc2","node_id":"1775717346519","node_type":"human-input","title":"Human Input","index":1,"predecessor_node_id":null,"inputs":{},"inputs_truncated":false,"process_data":{},"process_data_truncated":false,"outputs":{"some_field":"field 1 filled!","some_field_2":"from bruno with love","__action_id":"approve","__rendered_content":"this is form 1:\nfield 1 filled!\n"},"outputs_truncated":false,"status":"succeeded","error":null,"elapsed_time":0.004431,"execution_metadata":null,"created_at":1776087877,"finished_at":1776087877,"files":[],"iteration_id":null,"loop_id":null}} + + data: {"event":"node_started","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"id":"6d8fc3cb-19f7-440b-b83e-eed4e847a332","node_id":"1775717350710","node_type":"template-transform","title":"Template","index":1,"predecessor_node_id":null,"inputs":null,"inputs_truncated":false,"created_at":1776087877,"extras":{},"iteration_id":null,"loop_id":null,"agent_strategy":null}} + + data: {"event":"text_chunk","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"text":"field 1 filled!","from_variable_selector":["1775717350710","output"]}} + + data: {"event":"node_finished","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"id":"6d8fc3cb-19f7-440b-b83e-eed4e847a332","node_id":"1775717350710","node_type":"template-transform","title":"Template","index":1,"predecessor_node_id":null,"inputs":{"some_field":"field 1 filled!"},"inputs_truncated":false,"process_data":{},"process_data_truncated":false,"outputs":{"output":"field 1 filled!"},"outputs_truncated":false,"status":"succeeded","error":null,"elapsed_time":0.264614,"execution_metadata":null,"created_at":1776087877,"finished_at":1776087877,"files":[],"iteration_id":null,"loop_id":null}} + + data: {"event":"node_started","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"id":"e88dec7e-aa2c-41f7-8d73-032b749e23f5","node_id":"1775717354177","node_type":"end","title":"Output","index":1,"predecessor_node_id":null,"inputs":null,"inputs_truncated":false,"created_at":1776087877,"extras":{},"iteration_id":null,"loop_id":null,"agent_strategy":null}} + + data: {"event":"node_finished","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"id":"e88dec7e-aa2c-41f7-8d73-032b749e23f5","node_id":"1775717354177","node_type":"end","title":"Output","index":1,"predecessor_node_id":null,"inputs":{"output":"field 1 filled!"},"inputs_truncated":false,"process_data":{},"process_data_truncated":false,"outputs":{"output":"field 1 filled!"},"outputs_truncated":false,"status":"succeeded","error":null,"elapsed_time":0.00003,"execution_metadata":null,"created_at":1776087877,"finished_at":1776087877,"files":[],"iteration_id":null,"loop_id":null}} + + data: {"event":"workflow_finished","workflow_run_id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","task_id":"1784c3dd-20eb-4919-bd5d-a8d800b74ada","data":{"id":"5d7ef348-e1c1-4f6d-bb9b-62cc2fb2ef3c","workflow_id":"e46514f1-c008-41ff-94b0-4f33d4b97d36","status":"succeeded","outputs":{"output":"field 1 filled!"},"error":null,"elapsed_time":0.364935,"total_tokens":0,"total_steps":5,"created_by":{"id":"7932d34c-dcf4-4fba-b770-f2a9de88c0a0","user":"abc-123"},"created_at":1776087877,"finished_at":1776087877,"exceptions_count":0,"files":[]}} + ``` + + + + +--- + { } }) - it('should reset emoji icon to initial props when picker is cancelled', async () => { + it('should allow changing only the background for the current emoji icon', async () => { vi.useFakeTimers() try { const { onConfirm } = await setup({ @@ -370,22 +370,14 @@ describe('CreateAppModal', () => { fireEvent.click(getAppIconTrigger()) - const categoryLabel = screen.getByText('people') - const emojiGrid = categoryLabel.nextElementSibling - const clickableEmojiWrapper = emojiGrid?.firstElementChild - if (!(clickableEmojiWrapper instanceof HTMLElement)) - throw new Error('Failed to locate emoji wrapper') - fireEvent.click(clickableEmojiWrapper) + const colorOption = Array.from(document.querySelectorAll('[style^="background:"]')) + .find(element => element.getAttribute('style')?.includes('#E4FBCC')) + if (!(colorOption instanceof HTMLElement) || !(colorOption.parentElement instanceof HTMLElement)) + throw new Error('Failed to locate background color option') + fireEvent.click(colorOption.parentElement) fireEvent.click(screen.getByRole('button', { name: 'app.iconPicker.ok' })) - expect(screen.queryByRole('button', { name: 'app.iconPicker.cancel' })).not.toBeInTheDocument() - - fireEvent.click(getAppIconTrigger()) - fireEvent.click(screen.getByRole('button', { name: 'app.iconPicker.cancel' })) - - expect(screen.queryByRole('button', { name: 'app.iconPicker.cancel' })).not.toBeInTheDocument() - fireEvent.click(screen.getByRole('button', { name: /common\.operation\.create/ })) await act(async () => { vi.advanceTimersByTime(300) @@ -396,7 +388,7 @@ describe('CreateAppModal', () => { expect(payload).toMatchObject({ icon_type: 'emoji', icon: '🤖', - icon_background: '#FFEAD5', + icon_background: '#E4FBCC', }) } finally { diff --git a/web/app/components/explore/create-app-modal/index.tsx b/web/app/components/explore/create-app-modal/index.tsx index ebe5a79a16..a7c9e06655 100644 --- a/web/app/components/explore/create-app-modal/index.tsx +++ b/web/app/components/explore/create-app-modal/index.tsx @@ -206,14 +206,14 @@ const CreateAppModal = ({ {showAppIconPicker && ( { setAppIcon(payload) setShowAppIconPicker(false) }} onClose={() => { - setAppIcon(appIconType === 'image' - ? { type: 'image' as const, url: appIconUrl, fileId: _appIcon } - : { type: 'emoji' as const, icon: _appIcon, background: appIconBackground }) setShowAppIconPicker(false) }} /> diff --git a/web/app/components/plugins/plugin-auth/authorize/__tests__/add-api-key-button.spec.tsx b/web/app/components/plugins/plugin-auth/authorize/__tests__/add-api-key-button.spec.tsx index 794f847168..7caef50516 100644 --- a/web/app/components/plugins/plugin-auth/authorize/__tests__/add-api-key-button.spec.tsx +++ b/web/app/components/plugins/plugin-auth/authorize/__tests__/add-api-key-button.spec.tsx @@ -5,11 +5,29 @@ import AddApiKeyButton from '../add-api-key-button' let _mockModalOpen = false vi.mock('../api-key-modal', () => ({ - default: ({ onClose, onUpdate }: { onClose: () => void, onUpdate?: () => void }) => { - _mockModalOpen = true + default: ({ + open = true, + onClose, + onOpenChange, + onUpdate, + }: { + open?: boolean + onClose: () => void + onOpenChange?: (open: boolean) => void + onUpdate?: () => void + }) => { + _mockModalOpen = open + if (!open) + return null + + const handleClose = () => { + onOpenChange?.(false) + onClose() + } + return (
- +
) diff --git a/web/app/components/plugins/plugin-auth/authorize/__tests__/api-key-modal.spec.tsx b/web/app/components/plugins/plugin-auth/authorize/__tests__/api-key-modal.spec.tsx index 2bfa94d2ed..41f1aa3718 100644 --- a/web/app/components/plugins/plugin-auth/authorize/__tests__/api-key-modal.spec.tsx +++ b/web/app/components/plugins/plugin-auth/authorize/__tests__/api-key-modal.spec.tsx @@ -1,5 +1,8 @@ import type { ApiKeyModalProps } from '../api-key-modal' +import type { FormSchema } from '@/app/components/base/form/types' +import { Popover, PopoverContent, PopoverTrigger } from '@langgenius/dify-ui/popover' import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' import * as React from 'react' import { beforeEach, describe, expect, it, vi } from 'vitest' import { AuthCategory } from '../../types' @@ -20,17 +23,27 @@ vi.mock('@langgenius/dify-ui/toast', () => ({ })) const mockAddPluginCredential = vi.fn().mockResolvedValue({}) const mockUpdatePluginCredential = vi.fn().mockResolvedValue({}) -const mockFormValues = { isCheckValidated: true, values: { __name__: 'My Key', api_key: 'sk-123' } } +const defaultCredentialSchemas = [ + { name: 'api_key', label: 'API Key', type: 'secret-input', required: true }, +] +type MockFormValues = { + isCheckValidated: boolean + values: Record +} + +const defaultFormValues: MockFormValues = { isCheckValidated: true, values: { __name__: 'My Key', api_key: 'sk-123' } } +let mockCredentialSchemas = defaultCredentialSchemas +let mockIsSchemaLoading = false +let mockFormValues = defaultFormValues +const mockAuthFormProps = vi.fn() vi.mock('../../hooks/use-credential', () => ({ useAddPluginCredentialHook: () => ({ mutateAsync: mockAddPluginCredential, }), useGetPluginCredentialSchemaHook: () => ({ - data: [ - { name: 'api_key', label: 'API Key', type: 'secret-input', required: true }, - ], - isLoading: false, + data: mockCredentialSchemas, + isLoading: mockIsSchemaLoading, }), useUpdatePluginCredentialHook: () => ({ mutateAsync: mockUpdatePluginCredential, @@ -49,36 +62,19 @@ vi.mock('@/app/components/base/encrypted-bottom', () => ({ EncryptedBottom: () =>
, })) -vi.mock('@/app/components/base/modal/modal', () => ({ - default: ({ children, title, onClose, onConfirm, onExtraButtonClick, showExtraButton, disabled }: { - children: React.ReactNode - title: string - onClose?: () => void - onCancel?: () => void - onConfirm?: () => void - onExtraButtonClick?: () => void - showExtraButton?: boolean - disabled?: boolean - [key: string]: unknown - }) => ( -
-
{title}
- {children} - - - {showExtraButton && } -
- ), -})) - -vi.mock('@/app/components/base/form/form-scenarios/auth', () => ({ - default: React.forwardRef((_props: Record, ref: React.Ref) => { +vi.mock('@/app/components/base/form/form-scenarios/auth', () => { + const MockAuthForm = ({ ref, ...props }: { ref?: React.Ref } & Record) => { + mockAuthFormProps(props) React.useImperativeHandle(ref, () => ({ getFormValues: () => mockFormValues, })) return
- }), -})) + } + + return { + default: MockAuthForm, + } +}) vi.mock('@/app/components/base/form/types', () => ({ FormTypeEnum: { textInput: 'text-input' }, @@ -89,11 +85,73 @@ const basePayload = { provider: 'test-provider', } +const PopoverModalHarness = ({ + ApiKeyModal, + onClose, + onPopoverClose, +}: { + ApiKeyModal: React.FC + onClose: () => void + onPopoverClose: () => void +}) => { + const [open, setOpen] = React.useState(true) + + return ( + { + setOpen(nextOpen) + if (!nextOpen) + onPopoverClose() + }} + > + Credentials} /> + +
+ +
+
+
+ ) +} + +const ControlledModalHarness = ({ + ApiKeyModal, + onClose, +}: { + ApiKeyModal: React.FC + onClose: () => void +}) => { + const [open, setOpen] = React.useState(true) + + return ( + <> +
{String(open)}
+ + + ) +} + describe('ApiKeyModal', () => { let ApiKeyModal: React.FC beforeEach(async () => { vi.clearAllMocks() + mockCredentialSchemas = defaultCredentialSchemas + mockIsSchemaLoading = false + mockFormValues = defaultFormValues + mockAddPluginCredential.mockResolvedValue({}) + mockUpdatePluginCredential.mockResolvedValue({}) const mod = await import('../api-key-modal') ApiKeyModal = mod.default }) @@ -110,6 +168,56 @@ describe('ApiKeyModal', () => { expect(screen.getByTestId('auth-form')).toBeInTheDocument() }) + it('should prefer formSchemas prop and apply schema defaults', () => { + const customSchemas: FormSchema[] = [ + { + name: 'custom_api_key', + label: 'Custom API Key', + type: 'secret-input' as FormSchema['type'], + required: true, + default: 'default-key', + }, + ] + + render() + + expect(mockAuthFormProps).toHaveBeenCalledWith(expect.objectContaining({ + formSchemas: expect.arrayContaining([ + expect.objectContaining({ name: 'custom_api_key' }), + ]), + defaultValues: expect.objectContaining({ + custom_api_key: 'default-key', + }), + })) + }) + + it('should not render auth form when credential schema is empty', () => { + mockCredentialSchemas = [] + + render() + + expect(screen.queryByTestId('auth-form')).not.toBeInTheDocument() + }) + + it('should not submit when form ref is unavailable', () => { + mockCredentialSchemas = [] + + render() + + fireEvent.click(screen.getByTestId('modal-confirm')) + + expect(mockAddPluginCredential).not.toHaveBeenCalled() + }) + + it('should disable actions while loading credential schema', () => { + mockIsSchemaLoading = true + + render() + + expect(screen.queryByTestId('auth-form')).not.toBeInTheDocument() + expect(screen.getByTestId('modal-confirm')).toBeDisabled() + }) + it('should show remove button when editValues is provided', () => { render() @@ -130,6 +238,18 @@ describe('ApiKeyModal', () => { expect(mockOnClose).toHaveBeenCalled() }) + it('should close through controlled open state when cancel is clicked', async () => { + const mockOnClose = vi.fn() + render() + + fireEvent.click(screen.getByRole('button', { name: 'common.operation.cancel' })) + + await waitFor(() => { + expect(screen.getByTestId('modal-open-state')).toHaveTextContent('false') + }) + expect(mockOnClose).toHaveBeenCalled() + }) + it('should call addPluginCredential on confirm in add mode', async () => { const mockOnClose = vi.fn() const mockOnUpdate = vi.fn() @@ -145,6 +265,50 @@ describe('ApiKeyModal', () => { }) }) + it('should use empty credential name when authorization name is blank in add mode', async () => { + mockFormValues = { isCheckValidated: true, values: { api_key: 'sk-123' } } + + render() + + fireEvent.click(screen.getByTestId('modal-confirm')) + + await waitFor(() => { + expect(mockAddPluginCredential).toHaveBeenCalledWith(expect.objectContaining({ + name: '', + })) + }) + }) + + it('should not submit when form validation fails', () => { + mockFormValues = { isCheckValidated: false, values: {} } + + render() + + fireEvent.click(screen.getByTestId('modal-confirm')) + + expect(mockAddPluginCredential).not.toHaveBeenCalled() + expect(mockUpdatePluginCredential).not.toHaveBeenCalled() + }) + + it('should ignore repeated confirm while an action is in progress', async () => { + let repeatedClickTriggered = false + mockAddPluginCredential.mockImplementationOnce(async () => { + if (!repeatedClickTriggered) { + repeatedClickTriggered = true + fireEvent.click(screen.getByTestId('modal-confirm')) + } + return {} + }) + + render() + + fireEvent.click(screen.getByTestId('modal-confirm')) + + await waitFor(() => { + expect(mockAddPluginCredential).toHaveBeenCalledTimes(1) + }) + }) + it('should call updatePluginCredential on confirm in edit mode', async () => { render() @@ -155,6 +319,20 @@ describe('ApiKeyModal', () => { }) }) + it('should use empty credential name when authorization name is blank in edit mode', async () => { + mockFormValues = { isCheckValidated: true, values: { api_key: 'updated', __credential_id__: 'cred-1' } } + + render() + + fireEvent.click(screen.getByTestId('modal-confirm')) + + await waitFor(() => { + expect(mockUpdatePluginCredential).toHaveBeenCalledWith(expect.objectContaining({ + name: '', + })) + }) + }) + it('should call onRemove when remove button clicked', () => { const mockOnRemove = vi.fn() render() @@ -163,6 +341,49 @@ describe('ApiKeyModal', () => { expect(mockOnRemove).toHaveBeenCalled() }) + it('should stay open when clicking inside the modal from a popover', async () => { + // Use userEvent instead of fireEvent to avoid CI flakiness: userEvent + // awaits React act() between pointer/mouse/click so base-ui's dialog + // popup ref is guaranteed committed before outside-click detection runs. + const user = userEvent.setup() + const mockOnClose = vi.fn() + const mockOnPopoverClose = vi.fn() + + render( + , + ) + + const form = await screen.findByTestId('auth-form') + + await user.click(form) + + expect(mockOnClose).not.toHaveBeenCalled() + expect(mockOnPopoverClose).not.toHaveBeenCalled() + expect(screen.getByTestId('modal')).toBeInTheDocument() + }) + + it('should close on backdrop click through controlled open state', async () => { + const mockOnClose = vi.fn() + render() + + const backdrop = document.querySelector('.bg-background-overlay') + if (!backdrop) + throw new Error('Expected dialog backdrop to render') + + fireEvent.pointerDown(backdrop) + fireEvent.mouseDown(backdrop) + fireEvent.click(backdrop) + + await waitFor(() => { + expect(screen.getByTestId('modal-open-state')).toHaveTextContent('false') + }) + expect(mockOnClose).toHaveBeenCalled() + }) + it('should render readme entrance when detail is provided', () => { const payload = { ...basePayload, detail: { name: 'Test' } as never } render() diff --git a/web/app/components/plugins/plugin-auth/authorize/add-api-key-button.tsx b/web/app/components/plugins/plugin-auth/authorize/add-api-key-button.tsx index 648a87dabc..38f3f85643 100644 --- a/web/app/components/plugins/plugin-auth/authorize/add-api-key-button.tsx +++ b/web/app/components/plugins/plugin-auth/authorize/add-api-key-button.tsx @@ -25,20 +25,26 @@ const AddApiKeyButton = ({ formSchemas = [], }: AddApiKeyButtonProps) => { const [isApiKeyModalOpen, setIsApiKeyModalOpen] = useState(false) + const [isApiKeyModalMounted, setIsApiKeyModalMounted] = useState(false) return ( <> { - isApiKeyModalOpen && ( + isApiKeyModalMounted && ( setIsApiKeyModalOpen(false)} onUpdate={onUpdate} diff --git a/web/app/components/plugins/plugin-auth/authorize/api-key-modal.tsx b/web/app/components/plugins/plugin-auth/authorize/api-key-modal.tsx index db513ecb6f..290621141c 100644 --- a/web/app/components/plugins/plugin-auth/authorize/api-key-modal.tsx +++ b/web/app/components/plugins/plugin-auth/authorize/api-key-modal.tsx @@ -3,6 +3,8 @@ import type { FormRefObject, FormSchema, } from '@/app/components/base/form/types' +import { Button } from '@langgenius/dify-ui/button' +import { Dialog, DialogCloseButton, DialogContent, DialogTitle } from '@langgenius/dify-ui/dialog' import { toast } from '@langgenius/dify-ui/toast' import { memo, @@ -16,7 +18,6 @@ import { EncryptedBottom } from '@/app/components/base/encrypted-bottom' import AuthForm from '@/app/components/base/form/form-scenarios/auth' import { FormTypeEnum } from '@/app/components/base/form/types' import Loading from '@/app/components/base/loading' -import Modal from '@/app/components/base/modal/modal' import { ReadmeEntrance } from '../../readme-panel/entrance' import { ReadmeShowType } from '../../readme-panel/store' import { @@ -28,8 +29,10 @@ import { CredentialTypeEnum } from '../types' export type ApiKeyModalProps = { pluginPayload: PluginPayload + open?: boolean + onOpenChange?: (open: boolean) => void onClose?: () => void - editValues?: Record + editValues?: Record onRemove?: () => void disabled?: boolean onUpdate?: () => void @@ -37,6 +40,8 @@ export type ApiKeyModalProps = { } const ApiKeyModal = ({ pluginPayload, + open = true, + onOpenChange, onClose, editValues, onRemove, @@ -73,7 +78,7 @@ const ApiKeyModal = ({ if (schema.default) acc[schema.name] = schema.default return acc - }, {} as Record) + }, {} as Record) const { mutateAsync: addPluginCredential } = useAddPluginCredentialHook(pluginPayload) const { mutateAsync: updatePluginCredential } = useUpdatePluginCredentialHook(pluginPayload) const formRef = useRef(null) @@ -114,53 +119,102 @@ const ApiKeyModal = ({ } toast.success(t('api.actionSuccess', { ns: 'common' })) + onOpenChange?.(false) onClose?.() onUpdate?.() } finally { handleSetDoingAction(false) } - }, [addPluginCredential, onClose, onUpdate, updatePluginCredential, t, editValues, handleSetDoingAction]) + }, [addPluginCredential, onClose, onOpenChange, onUpdate, updatePluginCredential, t, editValues, handleSetDoingAction]) + + const isDisabled = disabled || isLoading || doingAction + const handleOpenChange = useCallback((nextOpen: boolean) => { + onOpenChange?.(nextOpen) + if (!nextOpen) + onClose?.() + }, [onClose, onOpenChange]) return ( -
) - } - bottomSlot={} - onConfirm={handleConfirm} - showExtraButton={!!editValues} - onExtraButtonClick={onRemove} - disabled={disabled || isLoading || doingAction} - clickOutsideNotClose={true} - wrapperClassName="z-1002!" + - {pluginPayload.detail && ( - - )} - { - isLoading && ( -
- + +
+
+ + {t('auth.useApiAuth', { ns: 'plugin' })} + +
+ {t('auth.useApiAuthDesc', { ns: 'plugin' })} +
+
- ) - } - { - !isLoading && !!mergedData.length && ( - - ) - } - +
+ {pluginPayload.detail && ( + + )} + { + isLoading && ( +
+ +
+ ) + } + { + !isLoading && !!mergedData.length && ( + + ) + } +
+
+
+
+ {editValues && ( + <> + +
+ + )} + + +
+
+
+ +
+
+ +
) } diff --git a/web/app/components/plugins/plugin-auth/authorized/index.tsx b/web/app/components/plugins/plugin-auth/authorized/index.tsx index b8b34e33e0..774821b0c8 100644 --- a/web/app/components/plugins/plugin-auth/authorized/index.tsx +++ b/web/app/components/plugins/plugin-auth/authorized/index.tsx @@ -19,9 +19,6 @@ import { PopoverTrigger, } from '@langgenius/dify-ui/popover' import { toast } from '@langgenius/dify-ui/toast' -import { - RiArrowDownSLine, -} from '@remixicon/react' import { memo, useCallback, @@ -93,19 +90,19 @@ const Authorized = ({ }, [onOpenChange]) const oAuthCredentials = credentials.filter(credential => credential.credential_type === CredentialTypeEnum.OAUTH2) const apiKeyCredentials = credentials.filter(credential => credential.credential_type === CredentialTypeEnum.API_KEY) - const pendingOperationCredentialId = useRef(null) + const pendingOperationCredentialIdRef = useRef(null) const [deleteCredentialId, setDeleteCredentialId] = useState(null) const { mutateAsync: deletePluginCredential } = useDeletePluginCredentialHook(pluginPayload) const openConfirm = useCallback((credentialId?: string) => { setMergedIsOpen(false) if (credentialId) - pendingOperationCredentialId.current = credentialId + pendingOperationCredentialIdRef.current = credentialId - setDeleteCredentialId(pendingOperationCredentialId.current) + setDeleteCredentialId(pendingOperationCredentialIdRef.current) }, [setMergedIsOpen]) const closeConfirm = useCallback(() => { setDeleteCredentialId(null) - pendingOperationCredentialId.current = null + pendingOperationCredentialIdRef.current = null }, []) const [doingAction, setDoingAction] = useState(false) const doingActionRef = useRef(doingAction) @@ -116,30 +113,37 @@ const Authorized = ({ const handleConfirm = useCallback(async () => { if (doingActionRef.current) return - if (!pendingOperationCredentialId.current) { + if (!pendingOperationCredentialIdRef.current) { setDeleteCredentialId(null) return } try { handleSetDoingAction(true) - await deletePluginCredential({ credential_id: pendingOperationCredentialId.current }) + await deletePluginCredential({ credential_id: pendingOperationCredentialIdRef.current }) toast.success(t('api.actionSuccess', { ns: 'common' })) onUpdate?.() setDeleteCredentialId(null) - pendingOperationCredentialId.current = null + pendingOperationCredentialIdRef.current = null } finally { handleSetDoingAction(false) } }, [deletePluginCredential, onUpdate, t, handleSetDoingAction]) const [editValues, setEditValues] = useState | null>(null) + const [isApiKeyModalOpen, setIsApiKeyModalOpen] = useState(false) const handleEdit = useCallback((id: string, values: Record) => { setMergedIsOpen(false) - pendingOperationCredentialId.current = id + pendingOperationCredentialIdRef.current = id setEditValues(values) + setIsApiKeyModalOpen(true) }, [setMergedIsOpen]) + const handleApiKeyModalOpenChange = useCallback((open: boolean) => { + setIsApiKeyModalOpen(open) + if (!open) + pendingOperationCredentialIdRef.current = null + }, []) const handleRemove = useCallback(() => { - setDeleteCredentialId(pendingOperationCredentialId.current) + setDeleteCredentialId(pendingOperationCredentialIdRef.current) }, []) const { mutateAsync: setPluginDefaultCredential } = useSetPluginDefaultCredentialHook(pluginPayload) const handleSetDefault = useCallback(async (id: string) => { @@ -213,7 +217,7 @@ const Authorized = ({ ` (${unavailableCredentials.length} ${t('auth.unavailable', { ns: 'plugin' })})` ) } - + ) } @@ -356,12 +360,11 @@ const Authorized = ({ { !!editValues && ( { - setEditValues(null) - pendingOperationCredentialId.current = null - }} + onClose={() => handleApiKeyModalOpenChange(false)} onRemove={handleRemove} disabled={disabled || doingAction} onUpdate={onUpdate} diff --git a/web/app/components/workflow/__tests__/panel-contextmenu.spec.tsx b/web/app/components/workflow/__tests__/panel-contextmenu.spec.tsx index 7a02781c17..08ac245172 100644 --- a/web/app/components/workflow/__tests__/panel-contextmenu.spec.tsx +++ b/web/app/components/workflow/__tests__/panel-contextmenu.spec.tsx @@ -156,7 +156,7 @@ describe('PanelContextmenu', () => { fireEvent.click(screen.getByText('common.run')) fireEvent.click(screen.getByText('common.pasteHere')) fireEvent.click(screen.getByText('export')) - fireEvent.click(screen.getByText('common.importDSL')) + fireEvent.click(screen.getByText('importApp')) clickAwayHandler?.() expect(mockHandleAddNote).toHaveBeenCalledTimes(1) diff --git a/web/app/components/workflow/hooks/__tests__/use-nodes-interactions.spec.ts b/web/app/components/workflow/hooks/__tests__/use-nodes-interactions.spec.ts index 41d1fb39d9..a77606fefc 100644 --- a/web/app/components/workflow/hooks/__tests__/use-nodes-interactions.spec.ts +++ b/web/app/components/workflow/hooks/__tests__/use-nodes-interactions.spec.ts @@ -4,6 +4,7 @@ import { createEdge, createNode } from '../../__tests__/fixtures' import { resetReactFlowMockState, rfState } from '../../__tests__/reactflow-mock-state' import { renderWorkflowHook } from '../../__tests__/workflow-test-env' import { collaborationManager } from '../../collaboration/core/collaboration-manager' +import { CUSTOM_NOTE_NODE } from '../../note-node/constants' import { BlockEnum, ControlMode } from '../../types' import { useNodesInteractions } from '../use-nodes-interactions' @@ -317,6 +318,41 @@ describe('useNodesInteractions', () => { expect(rfState.setEdges).not.toHaveBeenCalled() }) + it('ignores note node selection when clicking a linked text target', () => { + currentNodes = [ + createNode({ + id: 'note-1', + type: CUSTOM_NOTE_NODE, + data: { + type: '' as unknown as BlockEnum, + title: 'Note', + desc: '', + selected: false, + }, + }), + ] + currentEdges = [] + rfState.nodes = currentNodes as unknown as typeof rfState.nodes + rfState.edges = currentEdges as unknown as typeof rfState.edges + + const { result } = renderWorkflowHook(() => useNodesInteractions(), { + historyStore: { + nodes: currentNodes, + edges: currentEdges, + }, + }) + + const link = document.createElement('a') + link.className = 'note-editor-theme_link' + + act(() => { + result.current.handleNodeClick({ target: link } as never, currentNodes[0] as Node) + }) + + expect(rfState.setNodes).not.toHaveBeenCalled() + expect(rfState.setEdges).not.toHaveBeenCalled() + }) + it('updates entering states on node enter and clears them on leave using collaborative workflow state', () => { currentNodes = [ createNode({ diff --git a/web/app/components/workflow/hooks/use-nodes-interactions.ts b/web/app/components/workflow/hooks/use-nodes-interactions.ts index 0ce86602fa..f885236ad9 100644 --- a/web/app/components/workflow/hooks/use-nodes-interactions.ts +++ b/web/app/components/workflow/hooks/use-nodes-interactions.ts @@ -137,6 +137,12 @@ const getUniquePastedNodeTitle = ( return titleCandidate } +const isNoteLinkClickTarget = (target: EventTarget | null, node: Node) => { + return node.type === CUSTOM_NOTE_NODE + && target instanceof HTMLElement + && !!target.closest('.note-editor-theme_link') +} + export const useNodesInteractions = () => { const { t } = useTranslation() const { data: appDslVersion } = useSuspenseQuery({ @@ -474,10 +480,12 @@ export const useNodesInteractions = () => { ) const handleNodeClick = useCallback( - (_, node) => { + (event, node) => { const { controlMode } = workflowStore.getState() if (controlMode === ControlMode.Comment) return + if (isNoteLinkClickTarget(event.target, node)) + return if (node.type === CUSTOM_ITERATION_START_NODE) return if (node.type === CUSTOM_LOOP_START_NODE) @@ -1704,7 +1712,7 @@ export const useNodesInteractions = () => { nodeId: node.id, }, }) - handleNodeSelect(node.id) + handleNodeSelect(node.id, true) }, [workflowStore, handleNodeSelect], ) diff --git a/web/app/components/workflow/nodes/_base/components/__tests__/form-input-item.branches.spec.tsx b/web/app/components/workflow/nodes/_base/components/__tests__/form-input-item.branches.spec.tsx index 7786dbec17..2e95473bb2 100644 --- a/web/app/components/workflow/nodes/_base/components/__tests__/form-input-item.branches.spec.tsx +++ b/web/app/components/workflow/nodes/_base/components/__tests__/form-input-item.branches.spec.tsx @@ -225,7 +225,7 @@ describe('FormInputItem branches', () => { }) expect(screen.getByText('alpha')).toBeInTheDocument() - fireEvent.click(screen.getByRole('button')) + fireEvent.click(screen.getByText('alpha').closest('button') as HTMLButtonElement) fireEvent.click(screen.getByText('beta')) expect(onChange).toHaveBeenCalledWith({ @@ -320,9 +320,9 @@ describe('FormInputItem branches', () => { }) await waitFor(() => { - expect(screen.getByRole('button')).not.toBeDisabled() + expect(screen.getByText('Select options').closest('button')).not.toBeDisabled() }) - fireEvent.click(screen.getByRole('button')) + fireEvent.click(screen.getByText('Select options').closest('button') as HTMLButtonElement) fireEvent.click(screen.getByText('trigger-option')) expect(onChange).toHaveBeenCalledWith({ diff --git a/web/app/components/workflow/nodes/_base/components/__tests__/node-handle.spec.tsx b/web/app/components/workflow/nodes/_base/components/__tests__/node-handle.spec.tsx index 772814f0f3..1f2c3b3aef 100644 --- a/web/app/components/workflow/nodes/_base/components/__tests__/node-handle.spec.tsx +++ b/web/app/components/workflow/nodes/_base/components/__tests__/node-handle.spec.tsx @@ -194,7 +194,9 @@ describe('node-handle', () => { fireEvent.click(addNodeButton) expect(addNodeButton).toHaveClass('opacity-100') - expect(addNodeButton).toHaveClass('pointer-events-auto') + // Trigger stays pointer-events-none so it never steals mousedown from + // the underlying React Flow handle (drag-to-connect must keep working). + expect(addNodeButton).toHaveClass('pointer-events-none') fireEvent.click(handle) @@ -236,7 +238,7 @@ describe('node-handle', () => { }) expect(getAddNodeButton()).toHaveClass('opacity-100') - expect(getAddNodeButton()).toHaveClass('pointer-events-auto') + expect(getAddNodeButton()).toHaveClass('pointer-events-none') }) it.each([ @@ -266,7 +268,7 @@ describe('node-handle', () => { fireEvent.click(addNodeButton) expect(addNodeButton).toHaveClass('opacity-100') - expect(addNodeButton).toHaveClass('pointer-events-auto') + expect(addNodeButton).toHaveClass('pointer-events-none') fireEvent.click(getSelectNodeButton()) @@ -295,7 +297,7 @@ describe('node-handle', () => { expect(addNodeButton).toHaveClass('custom-selector') expect(addNodeButton).toHaveClass('opacity-100') - expect(addNodeButton).toHaveClass('pointer-events-auto') + expect(addNodeButton).toHaveClass('pointer-events-none') }) it.each([ @@ -332,7 +334,7 @@ describe('node-handle', () => { const addNodeButton = getAddNodeButton() expect(addNodeButton).toHaveClass('opacity-100') - expect(addNodeButton).toHaveClass('pointer-events-auto') + expect(addNodeButton).toHaveClass('pointer-events-none') expect(mockSetShouldAutoOpenStartNodeSelector).toHaveBeenCalledWith(false) expect(mockSetHasSelectedStartNode).toHaveBeenCalledWith(false) }) diff --git a/web/app/components/workflow/nodes/_base/components/error-handle/error-handle-on-panel.tsx b/web/app/components/workflow/nodes/_base/components/error-handle/error-handle-on-panel.tsx index 10a167c504..a00e8e1adc 100644 --- a/web/app/components/workflow/nodes/_base/components/error-handle/error-handle-on-panel.tsx +++ b/web/app/components/workflow/nodes/_base/components/error-handle/error-handle-on-panel.tsx @@ -5,7 +5,7 @@ import type { } from '@/app/components/workflow/types' import { useCallback } from 'react' import { useTranslation } from 'react-i18next' -import Tooltip from '@/app/components/base/tooltip' +import { Infotip } from '@/app/components/base/infotip' import Collapse from '../collapse' import DefaultValue from './default-value' import ErrorHandleTypeSelector from './error-handle-type-selector' @@ -57,7 +57,9 @@ const ErrorHandle = ({
{t('nodes.common.errorHandle.title', { ns: 'workflow' })}
- + + {t('nodes.common.errorHandle.tip', { ns: 'workflow' })} + {collapseIcon}
= ({ onChange, }) => { const { t } = useTranslation() + const variableLabel = t('nodes.common.typeSwitch.variable', { ns: 'workflow' }) + const inputLabel = t('nodes.common.typeSwitch.input', { ns: 'workflow' }) + return (
- -
onChange(VarType.variable)} - > - -
-
- -
onChange(VarType.constant)} - > - -
-
+ {value === VarType.variable + ? ( + + ) + : ( + + onChange(VarType.variable)} + > + + + )} + /> + {variableLabel} + + )} + {value === VarType.constant + ? ( + + ) + : ( + + onChange(VarType.constant)} + > + + + )} + /> + {inputLabel} + + )}
) } diff --git a/web/app/components/workflow/nodes/_base/components/help-link.tsx b/web/app/components/workflow/nodes/_base/components/help-link.tsx index 30f95a12be..298f50738f 100644 --- a/web/app/components/workflow/nodes/_base/components/help-link.tsx +++ b/web/app/components/workflow/nodes/_base/components/help-link.tsx @@ -1,8 +1,7 @@ import type { BlockEnum } from '@/app/components/workflow/types' -import { RiBookOpenLine } from '@remixicon/react' +import { Tooltip, TooltipContent, TooltipTrigger } from '@langgenius/dify-ui/tooltip' import { memo } from 'react' import { useTranslation } from 'react-i18next' -import TooltipPlus from '@/app/components/base/tooltip' import { useNodeHelpLink } from '../hooks/use-node-help-link' type HelpLinkProps = { @@ -17,19 +16,25 @@ const HelpLink = ({ if (!link) return null - return ( - - - - - + const label = t('userProfile.helpCenter', { ns: 'common' }) + return ( + + + + + )} + /> + {label} + ) } diff --git a/web/app/components/workflow/nodes/_base/components/node-handle.tsx b/web/app/components/workflow/nodes/_base/components/node-handle.tsx index e84b09ac95..58597a1670 100644 --- a/web/app/components/workflow/nodes/_base/components/node-handle.tsx +++ b/web/app/components/workflow/nodes/_base/components/node-handle.tsx @@ -115,9 +115,9 @@ export const NodeTargetHandle = memo(({ triggerClassName={open => ` absolute left-0 top-0 opacity-0 pointer-events-none transition-opacity duration-150 ${nodeSelectorClassName} - group-hover:opacity-100 group-hover:pointer-events-auto - ${data.selected && 'opacity-100 pointer-events-auto'} - ${open && 'opacity-100 pointer-events-auto'} + group-hover:opacity-100 + ${data.selected && 'opacity-100'} + ${open && 'opacity-100'} `} availableBlocksTypes={availablePrevBlocks} /> @@ -233,9 +233,9 @@ export const NodeSourceHandle = memo(({ triggerClassName={open => ` absolute top-0 left-0 opacity-0 pointer-events-none transition-opacity duration-150 ${nodeSelectorClassName} - group-hover:opacity-100 group-hover:pointer-events-auto - ${data.selected && 'opacity-100 pointer-events-auto'} - ${open && 'opacity-100 pointer-events-auto'} + group-hover:opacity-100 + ${data.selected && 'opacity-100'} + ${open && 'opacity-100'} `} availableBlocksTypes={availableNextBlocks} /> diff --git a/web/app/components/workflow/nodes/_base/components/variable/__tests__/var-reference-vars.spec.tsx b/web/app/components/workflow/nodes/_base/components/variable/__tests__/var-reference-vars.spec.tsx index 372fcb3508..b8d1013db9 100644 --- a/web/app/components/workflow/nodes/_base/components/variable/__tests__/var-reference-vars.spec.tsx +++ b/web/app/components/workflow/nodes/_base/components/variable/__tests__/var-reference-vars.spec.tsx @@ -52,6 +52,42 @@ describe('VarReferenceVars', () => { expect(onClose).toHaveBeenCalledTimes(1) }) + it('should select the first visible variable by default and support arrow navigation in slash mode', () => { + const onChange = vi.fn() + + render( + , + ) + + const firstItem = screen.getByText('first_value').closest('[data-selected]') + const secondItem = screen.getByText('second_value').closest('[data-selected]') + + expect(firstItem).toHaveAttribute('data-selected', 'true') + expect(secondItem).toHaveAttribute('data-selected', 'false') + + fireEvent.keyDown(document, { key: 'ArrowDown' }) + + expect(firstItem).toHaveAttribute('data-selected', 'false') + expect(secondItem).toHaveAttribute('data-selected', 'true') + + fireEvent.keyDown(document, { key: 'Enter' }) + + expect(onChange).toHaveBeenCalledWith(['node-a', 'second_value'], expect.objectContaining({ + variable: 'second_value', + })) + }) + it('should call onChange when a variable item is chosen', () => { const onChange = vi.fn() @@ -172,6 +208,43 @@ describe('VarReferenceVars', () => { expect(onChange).toHaveBeenNthCalledWith(4, ['node-special', 'asset'], expect.objectContaining({ variable: 'asset' })) }) + it('should resolve selectors for special variables and file support from keyboard selection', () => { + const onChange = vi.fn() + + render( + , + ) + + fireEvent.keyDown(document, { key: 'Enter' }) + fireEvent.keyDown(document, { key: 'ArrowDown' }) + fireEvent.keyDown(document, { key: 'Enter' }) + fireEvent.keyDown(document, { key: 'ArrowDown' }) + fireEvent.keyDown(document, { key: 'Enter' }) + fireEvent.keyDown(document, { key: 'ArrowDown' }) + fireEvent.keyDown(document, { key: 'Enter' }) + + expect(onChange).toHaveBeenNthCalledWith(1, ['env', 'API_KEY'], expect.objectContaining({ variable: 'env.API_KEY' })) + expect(onChange).toHaveBeenNthCalledWith(2, ['conversation', 'user_name'], expect.objectContaining({ variable: 'conversation.user_name' })) + expect(onChange).toHaveBeenNthCalledWith(3, ['node-special', 'current'], expect.objectContaining({ variable: 'current' })) + expect(onChange).toHaveBeenNthCalledWith(4, ['node-special', 'asset'], expect.objectContaining({ variable: 'asset' })) + }) + it('should render object vars and select them by node path', () => { const onChange = vi.fn() @@ -251,4 +324,26 @@ describe('VarReferenceVars', () => { fireEvent.click(screen.getByText('asset')) expect(onChange).not.toHaveBeenCalled() }) + + it('should ignore file vars when file support is disabled during keyboard selection', () => { + const onChange = vi.fn() + + render( + , + ) + + fireEvent.keyDown(document, { key: 'Enter' }) + + expect(onChange).not.toHaveBeenCalled() + }) }) diff --git a/web/app/components/workflow/nodes/_base/components/variable/var-reference-vars.tsx b/web/app/components/workflow/nodes/_base/components/variable/var-reference-vars.tsx index 28ad104ed7..38fef9016d 100644 --- a/web/app/components/workflow/nodes/_base/components/variable/var-reference-vars.tsx +++ b/web/app/components/workflow/nodes/_base/components/variable/var-reference-vars.tsx @@ -12,11 +12,8 @@ import { import { useHover } from 'ahooks' import { noop } from 'es-toolkit/function' import * as React from 'react' -import { useEffect, useMemo, useRef, useState } from 'react' +import { useCallback, useEffect, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' -import { ChevronRight } from '@/app/components/base/icons/src/vender/line/arrows' -import { CodeAssistant, MagicEdit } from '@/app/components/base/icons/src/vender/line/general' -import { Variable02 } from '@/app/components/base/icons/src/vender/solid/development' import Input from '@/app/components/base/input' import PickerStructurePanel from '@/app/components/workflow/nodes/_base/components/variable/object-child-tree-panel/picker' import { VariableIconWithColor } from '@/app/components/workflow/nodes/_base/components/variable/variable-label' @@ -31,6 +28,42 @@ import { getVariableDisplayName, } from './var-reference-vars.helpers' +const VAR_SEARCH_INPUT_CLASS_NAME = 'var-search-input' + +const resolveValueSelector = ({ + itemData, + isFlat, + isSupportFileVar, + nodeId, + objPath, +}: { + itemData: Var + isFlat?: boolean + isSupportFileVar?: boolean + nodeId: string + objPath: string[] +}) => { + const isStructureOutput = itemData.type === VarType.object && (itemData.children as StructuredOutput)?.schema?.properties + const isFile = itemData.type === VarType.file && !isStructureOutput + const isSys = itemData.variable.startsWith('sys.') + const isEnv = itemData.variable.startsWith('env.') + const isChatVar = itemData.variable.startsWith('conversation.') + const isRagVariable = itemData.isRagVariable + + return getValueSelector({ + itemData, + isFlat, + isSupportFileVar, + isFile, + isSys, + isEnv, + isChatVar, + isRagVariable, + nodeId, + objPath, + }) +} + type ItemProps = { nodeId: string title: string @@ -47,6 +80,8 @@ type ItemProps = { zIndex?: number className?: string preferSchemaType?: boolean + isSelected?: boolean + onActivate?: () => void } const Item: FC = ({ @@ -64,11 +99,11 @@ const Item: FC = ({ zIndex, className, preferSchemaType, + isSelected, + onActivate, }) => { const isStructureOutput = itemData.type === VarType.object && (itemData.children as StructuredOutput)?.schema?.properties - const isFile = itemData.type === VarType.file && !isStructureOutput const isObj = ([VarType.object, VarType.file].includes(itemData.type) && itemData.children && (itemData.children as Var[]).length > 0) - const isSys = itemData.variable.startsWith('sys.') const isEnv = itemData.variable.startsWith('env.') const isChatVar = itemData.variable.startsWith('conversation.') const isRagVariable = itemData.isRagVariable @@ -76,15 +111,21 @@ const Item: FC = ({ if (!isFlat) return null const variable = itemData.variable - let Icon switch (variable) { case 'current': - Icon = isInCodeGeneratorInstructionEditor ? CodeAssistant : MagicEdit - return + return ( + + ) case 'error_message': - return + return default: - return + return } }, [isFlat, isInCodeGeneratorInstructionEditor, itemData.variable]) @@ -147,15 +188,10 @@ const Item: FC = ({ const handleChosen = (e: React.MouseEvent) => { e.stopPropagation() e.nativeEvent.stopImmediatePropagation() - const valueSelector = getValueSelector({ + const valueSelector = resolveValueSelector({ itemData, isFlat, isSupportFileVar, - isFile, - isSys, - isEnv, - isChatVar, - isRagVariable, nodeId, objPath, }) @@ -173,11 +209,13 @@ const Item: FC = ({ ref={itemRef} className={cn( (isObj || isStructureOutput) ? 'pr-1' : 'pr-[18px]', - isHovering && ((isObj || isStructureOutput) ? 'bg-components-panel-on-panel-item-bg-hover' : 'bg-state-base-hover'), + (isHovering || isSelected) && ((isObj || isStructureOutput) ? 'bg-components-panel-on-panel-item-bg-hover' : 'bg-state-base-hover'), 'relative flex h-6 w-full cursor-pointer items-center rounded-md pl-3', className, )} + data-selected={isSelected ? 'true' : 'false'} onClick={handleChosen} + onMouseEnter={onActivate} onMouseDown={(e) => { e.preventDefault() e.stopPropagation() @@ -210,7 +248,7 @@ const Item: FC = ({
{(preferSchemaType && itemData.schemaType) ? itemData.schemaType : itemData.type}
{ (isObj || isStructureOutput) && ( - + ) }
@@ -221,7 +259,7 @@ const Item: FC = ({ open={open} onOpenChange={noop} > - + = ({ }) => { const { t } = useTranslation() const [internalSearchValue, setInternalSearchValue] = useState('') + const listRef = useRef(null) const searchValue = searchText ?? internalSearchValue - - const handleKeyDown = (e: React.KeyboardEvent) => { - if (e.key === 'Escape') { - e.preventDefault() - onClose?.() - } - } - const filteredVars = useMemo(() => filterReferenceVars(vars, searchValue), [vars, searchValue]) + const selectableItems = useMemo(() => { + return filteredVars.flatMap(node => node.vars.map(item => ({ + nodeId: node.nodeId, + isFlat: node.isFlat, + itemData: item, + }))) + }, [filteredVars]) + const indexedFilteredVars = useMemo(() => { + let optionIndex = 0 + + return filteredVars.map(node => ({ + ...node, + vars: node.vars.map(variable => ({ + variable, + optionIndex: optionIndex++, + })), + })) + }, [filteredVars]) + const [selectedIndex, setSelectedIndex] = useState(-1) + const effectiveSelectedIndex = selectableItems.length ? Math.min(Math.max(selectedIndex, 0), selectableItems.length - 1) : -1 + + useEffect(() => { + const listElement = listRef.current + const selectedElement = listElement?.querySelector('[data-selected="true"]') as HTMLElement | null + if (!listElement || !selectedElement) + return + + const selectedTop = selectedElement.offsetTop + const selectedBottom = selectedTop + selectedElement.offsetHeight + const visibleTop = listElement.scrollTop + const visibleBottom = visibleTop + listElement.clientHeight + + if (selectedTop < visibleTop) + listElement.scrollTop = selectedTop + else if (selectedBottom > visibleBottom) + listElement.scrollTop = selectedBottom - listElement.clientHeight + }, [effectiveSelectedIndex]) + + const selectItem = useCallback((index: number) => { + const selectedItem = selectableItems[index] + if (!selectedItem) + return + + const { itemData, nodeId, isFlat } = selectedItem + const valueSelector = resolveValueSelector({ + itemData, + isFlat, + isSupportFileVar, + nodeId, + objPath: [], + }) + + if (valueSelector) + onChange(valueSelector, itemData) + }, [isSupportFileVar, onChange, selectableItems]) + + const handleKeyboardEvent = useCallback((event: Pick) => { + if (event.key === 'Escape') { + event.preventDefault() + onClose?.() + return + } + + if (!selectableItems.length) + return + + if (event.key === 'ArrowDown' || event.key === 'ArrowUp') { + event.preventDefault() + event.stopPropagation() + setSelectedIndex( + event.key === 'ArrowDown' + ? Math.min(effectiveSelectedIndex + 1, selectableItems.length - 1) + : Math.max(effectiveSelectedIndex - 1, 0), + ) + return + } + + if (event.key === 'Enter') { + event.preventDefault() + event.stopPropagation() + selectItem(effectiveSelectedIndex) + } + }, [effectiveSelectedIndex, onClose, selectableItems.length, selectItem]) + + const handleKeyDown = useCallback((e: React.KeyboardEvent) => { + handleKeyboardEvent(e) + }, [handleKeyboardEvent]) + + useEffect(() => { + if (!hideSearch) + return + + const handleDocumentKeyDown = (event: KeyboardEvent) => { + if (event.altKey || event.ctrlKey || event.metaKey) + return + if (!['ArrowDown', 'ArrowUp', 'Enter', 'Escape'].includes(event.key)) + return + + handleKeyboardEvent(event) + } + + document.addEventListener('keydown', handleDocumentKeyDown, true) + return () => document.removeEventListener('keydown', handleDocumentKeyDown, true) + }, [handleKeyboardEvent, hideSearch]) return ( <> { !hideSearch && ( <> -
e.stopPropagation()}> +
e.stopPropagation()}> = ({ {filteredVars.length > 0 ? ( -
- +
{ - filteredVars.map((item, i) => ( -
+ indexedFilteredVars.map((item, i) => ( +
{!item.isFlat && (
= ({ {item.title}
)} - {item.vars.map((v, j) => ( + {item.vars.map(({ variable, optionIndex }) => ( setSelectedIndex(optionIndex)} /> ))} - {item.isFlat && !filteredVars[i + 1]?.isFlat && !!filteredVars.find(item => !item.isFlat) && ( + {item.isFlat && !indexedFilteredVars[i + 1]?.isFlat && !!indexedFilteredVars.find(item => !item.isFlat) && (
{t('debug.lastOutput', { ns: 'workflow' })}
diff --git a/web/app/components/workflow/nodes/_base/node.tsx b/web/app/components/workflow/nodes/_base/node.tsx index ed83c58b6e..43f8e5773e 100644 --- a/web/app/components/workflow/nodes/_base/node.tsx +++ b/web/app/components/workflow/nodes/_base/node.tsx @@ -79,6 +79,7 @@ const BaseNode: FC = ({ const appId = useStore(s => s.appId) const { nodePanelPresence } = useCollaboration(appId as string) const controlMode = useStore(s => s.controlMode) + const isContextMenuTarget = useStore(s => s.nodeMenu?.nodeId === id) const currentUserPresence = useMemo(() => { const userId = userProfile?.id || '' @@ -123,7 +124,7 @@ const BaseNode: FC = ({ const { hasNodeInspectVars } = useInspectVarsCrud() const isLoading = data._runningStatus === NodeRunningStatus.Running || data._singleRunningStatus === NodeRunningStatus.Running const hasVarValue = hasNodeInspectVars(id) - const showSelectedBorder = Boolean(data.selected || data._isBundled || data._isEntering) + const showSelectedBorder = Boolean(data.selected || isContextMenuTarget || data._isBundled || data._isEntering) const { showRunningBorder, showSuccessBorder, diff --git a/web/app/components/workflow/nodes/agent/components/__tests__/model-bar.spec.tsx b/web/app/components/workflow/nodes/agent/components/__tests__/model-bar.spec.tsx index d85f54ed19..2127b48dca 100644 --- a/web/app/components/workflow/nodes/agent/components/__tests__/model-bar.spec.tsx +++ b/web/app/components/workflow/nodes/agent/components/__tests__/model-bar.spec.tsx @@ -1,5 +1,5 @@ import type { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' -import { fireEvent, render, screen } from '@testing-library/react' +import { render, screen } from '@testing-library/react' import { ModelBar } from '../model-bar' type ModelProviderItem = { @@ -52,11 +52,9 @@ describe('agent/model-bar', () => { const emptySelector = screen.getByText((_, element) => element?.textContent === 'no-model:0') - fireEvent.mouseEnter(emptySelector) - expect(emptySelector).toBeInTheDocument() expect(screen.getByText('indicator:red')).toBeInTheDocument() - expect(screen.getByText('workflow.nodes.agent.modelNotSelected')).toBeInTheDocument() + expect(screen.getByLabelText('workflow.nodes.agent.modelNotSelected')).toBeInTheDocument() }) it('should render the selected model without warning when it is installed', () => { @@ -69,10 +67,8 @@ describe('agent/model-bar', () => { it('should show a warning tooltip when the selected model is not installed', () => { render() - fireEvent.mouseEnter(screen.getByText('openai/gpt-4.1:1')) - expect(screen.getByText('openai/gpt-4.1:1')).toBeInTheDocument() expect(screen.getByText('indicator:red')).toBeInTheDocument() - expect(screen.getByText('workflow.nodes.agent.modelNotInstallTooltip')).toBeInTheDocument() + expect(screen.getByLabelText('workflow.nodes.agent.modelNotInstallTooltip')).toBeInTheDocument() }) }) diff --git a/web/app/components/workflow/nodes/agent/components/__tests__/tool-icon.spec.tsx b/web/app/components/workflow/nodes/agent/components/__tests__/tool-icon.spec.tsx index 30a12bb528..af61b43367 100644 --- a/web/app/components/workflow/nodes/agent/components/__tests__/tool-icon.spec.tsx +++ b/web/app/components/workflow/nodes/agent/components/__tests__/tool-icon.spec.tsx @@ -87,19 +87,17 @@ describe('agent/tool-icon', () => { const { rerender } = render() - fireEvent.mouseEnter(screen.getByText('app-icon:#fff:B')) expect(screen.getByText('indicator:yellow')).toBeInTheDocument() - expect(screen.getByText('workflow.nodes.agent.toolNotAuthorizedTooltip:{"tool":"tool-b"}')).toBeInTheDocument() + expect(screen.getByLabelText('workflow.nodes.agent.toolNotAuthorizedTooltip:{"tool":"tool-b"}')).toBeInTheDocument() mockWorkflowTools = [] mockMarketplaceIcon = 'https://example.com/market-tool.png' rerender() const marketplaceIcon = screen.getByRole('img', { name: 'tool icon' }) - fireEvent.mouseEnter(marketplaceIcon) expect(marketplaceIcon).toHaveAttribute('src', 'https://example.com/market-tool.png') expect(screen.getByText('indicator:red')).toBeInTheDocument() - expect(screen.getByText('workflow.nodes.agent.toolNotInstallTooltip:{"tool":"tool-c"}')).toBeInTheDocument() + expect(screen.getByLabelText('workflow.nodes.agent.toolNotInstallTooltip:{"tool":"tool-c"}')).toBeInTheDocument() }) it('should fall back to the group icon while tool data is still loading', () => { diff --git a/web/app/components/workflow/nodes/agent/components/model-bar.tsx b/web/app/components/workflow/nodes/agent/components/model-bar.tsx index 8e2f19d726..0ec0b943ef 100644 --- a/web/app/components/workflow/nodes/agent/components/model-bar.tsx +++ b/web/app/components/workflow/nodes/agent/components/model-bar.tsx @@ -1,7 +1,7 @@ import type { FC } from 'react' +import { Tooltip, TooltipContent, TooltipTrigger } from '@langgenius/dify-ui/tooltip' import { useMemo } from 'react' import { useTranslation } from 'react-i18next' -import Tooltip from '@/app/components/base/tooltip' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { useModelList } from '@/app/components/header/account-setting/model-provider-page/hooks' import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector' @@ -10,7 +10,10 @@ import Indicator from '@/app/components/header/indicator' type ModelBarProps = { provider: string model: string -} | {} +} | { + provider?: never + model?: never +} const useAllModel = () => { const { data: textGeneration } = useModelList(ModelTypeEnum.textGeneration) @@ -35,23 +38,27 @@ const useAllModel = () => { export const ModelBar: FC = (props) => { const { t } = useTranslation() const modelList = useAllModel() - if (!('provider' in props)) { + if (props.provider === undefined) { + const tooltip = t('nodes.agent.modelNotSelected', { ns: 'workflow' }) + return ( - -
- - -
+ + + + +
+ )} + /> + {tooltip} ) } @@ -59,23 +66,34 @@ export const ModelBar: FC = (props) => { provider => provider.provider === props.provider && provider.models.some(model => model.model === props.model), ) const showWarn = modelList && !modelInstalled - return modelList && ( - -
- - {showWarn && } -
+ if (!modelList) + return null + + const modelNotInstalledTooltip = t('nodes.agent.modelNotInstallTooltip', { ns: 'workflow' }) + const modelSelector = ( +
+ + {showWarn && } +
+ ) + + if (modelInstalled) + return modelSelector + + return ( + + + {modelNotInstalledTooltip} ) } diff --git a/web/app/components/workflow/nodes/agent/components/tool-icon.tsx b/web/app/components/workflow/nodes/agent/components/tool-icon.tsx index b545c2f370..3986dcf6a4 100644 --- a/web/app/components/workflow/nodes/agent/components/tool-icon.tsx +++ b/web/app/components/workflow/nodes/agent/components/tool-icon.tsx @@ -1,9 +1,10 @@ +import type { ReactNode } from 'react' import { cn } from '@langgenius/dify-ui/cn' +import { Tooltip, TooltipContent, TooltipTrigger } from '@langgenius/dify-ui/tooltip' import { memo, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import AppIcon from '@/app/components/base/app-icon' import { Group } from '@/app/components/base/icons/src/vender/other' -import Tooltip from '@/app/components/base/tooltip' import Indicator from '@/app/components/header/indicator' import { useAllBuiltInTools, useAllCustomTools, useAllMCPTools, useAllWorkflowTools } from '@/service/use-tools' import { getIconFromMarketPlace } from '@/utils/get-icon' @@ -62,44 +63,50 @@ export const ToolIcon = memo(({ providerName }: ToolIconProps) => { throw new Error('Unknown status') }, [name, notSuccess, status, t]) const [iconFetchError, setIconFetchError] = useState(false) - return ( - + + if (!iconFetchError && icon) { + if (typeof icon === 'string') { + iconContent = ( + tool icon setIconFetchError(true)} + /> + ) + } + else if (typeof icon === 'object') { + iconContent = ( + + ) + } + } + + const iconNode = ( +
-
-
- {(() => { - if (iconFetchError || !icon) - return - if (typeof icon === 'string') { - return ( - tool icon setIconFetchError(true)} - /> - ) - } - if (typeof icon === 'object') { - return ( - - ) - } - return - })()} -
- {indicator && } +
+ {iconContent}
+ {indicator && } +
+ ) + + if (!notSuccess || !tooltip) + return iconNode + + return ( + + + {tooltip} ) }) diff --git a/web/app/components/workflow/nodes/knowledge-base/components/retrieval-setting/top-k-and-score-threshold.tsx b/web/app/components/workflow/nodes/knowledge-base/components/retrieval-setting/top-k-and-score-threshold.tsx index 814b3cea6d..b81032242f 100644 --- a/web/app/components/workflow/nodes/knowledge-base/components/retrieval-setting/top-k-and-score-threshold.tsx +++ b/web/app/components/workflow/nodes/knowledge-base/components/retrieval-setting/top-k-and-score-threshold.tsx @@ -9,7 +9,7 @@ import { import { Switch } from '@langgenius/dify-ui/switch' import { memo, useCallback } from 'react' import { useTranslation } from 'react-i18next' -import Tooltip from '@/app/components/base/tooltip' +import { Infotip } from '@/app/components/base/infotip' import { env } from '@/env' export type TopKAndScoreThresholdProps = { @@ -59,10 +59,13 @@ const TopKAndScoreThreshold = ({
{t('datasetConfig.top_k', { ns: 'appDebug' })} - + + {t('datasetConfig.top_kTip', { ns: 'appDebug' })} +
{t('datasetConfig.score_threshold', { ns: 'appDebug' })}
- + + {t('datasetConfig.score_thresholdTip', { ns: 'appDebug' })} +
+type EditorOnMount = NonNullable['onMount']> +type MonacoEditor = Parameters[0] +type Monaco = Parameters[1] + const CodeEditor: FC = ({ value, onUpdate, @@ -36,8 +39,8 @@ const CodeEditor: FC = ({ }) => { const { t } = useTranslation() const { theme } = useTheme() - const monacoRef = useRef(null) - const editorRef = useRef(null) + const monacoRef = useRef(null) + const editorRef = useRef(null) const [isMounted, setIsMounted] = React.useState(false) const containerRef = useRef(null) @@ -50,7 +53,7 @@ const CodeEditor: FC = ({ } }, [theme]) - const handleEditorDidMount = useCallback((editor: any, monaco: any) => { + const handleEditorDidMount = useCallback((editor, monaco) => { editorRef.current = editor monacoRef.current = monaco @@ -83,7 +86,7 @@ const CodeEditor: FC = ({ }) monaco.editor.setTheme('light-theme') setIsMounted(true) - }, []) + }, [onBlur, onFocus]) const formatJsonContent = useCallback(() => { if (editorRef.current) @@ -122,24 +125,36 @@ const CodeEditor: FC = ({
{showFormatButton && ( - - + + + + + )} + /> + {t('operation.format', { ns: 'common' })} )} - - + + copy(value)} + > + + + )} + /> + {t('operation.copy', { ns: 'common' })}
diff --git a/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/visual-editor/edit-card/actions.tsx b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/visual-editor/edit-card/actions.tsx index 0afedab3d2..2f3b70aefc 100644 --- a/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/visual-editor/edit-card/actions.tsx +++ b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/visual-editor/edit-card/actions.tsx @@ -1,8 +1,7 @@ import type { FC } from 'react' -import { RiAddCircleLine, RiDeleteBinLine, RiEditLine } from '@remixicon/react' +import { Tooltip, TooltipContent, TooltipTrigger } from '@langgenius/dify-ui/tooltip' import * as React from 'react' import { useTranslation } from 'react-i18next' -import Tooltip from '@/app/components/base/tooltip' type ActionsProps = { disableAddBtn: boolean @@ -18,36 +17,59 @@ const Actions: FC = ({ onDelete, }) => { const { t } = useTranslation() + const addChildFieldLabel = t('nodes.llm.jsonSchema.addChildField', { ns: 'workflow' }) + const editLabel = t('operation.edit', { ns: 'common' }) + const removeLabel = t('operation.remove', { ns: 'common' }) return (
- - + + + + + )} + /> + {addChildFieldLabel} - - + + + + + )} + /> + {editLabel} - - + + + + + )} + /> + {removeLabel}
) diff --git a/web/app/components/workflow/nodes/tool/components/__tests__/copy-id.spec.tsx b/web/app/components/workflow/nodes/tool/components/__tests__/copy-id.spec.tsx index ee6791ca03..2fb7e66e24 100644 --- a/web/app/components/workflow/nodes/tool/components/__tests__/copy-id.spec.tsx +++ b/web/app/components/workflow/nodes/tool/components/__tests__/copy-id.spec.tsx @@ -20,27 +20,21 @@ describe('tool/copy-id', () => { it('should copy content and reset copied state when mouse leaves', () => { const { container } = render() - const trigger = screen.getByText('tool-123').parentElement as HTMLElement + const trigger = screen.getByRole('button', { name: 'appOverview.overview.appInfo.embedded.copy' }) const wrapper = container.querySelector('.inline-flex') as HTMLElement - act(() => { - fireEvent.mouseEnter(trigger) - }) - expect(screen.getByText('appOverview.overview.appInfo.embedded.copy')).toBeInTheDocument() - act(() => { fireEvent.click(trigger) vi.advanceTimersByTime(100) }) expect(copy).toHaveBeenCalledWith('tool-123') - expect(screen.getByText('appOverview.overview.appInfo.embedded.copied')).toBeInTheDocument() + expect(trigger).toHaveAccessibleName('appOverview.overview.appInfo.embedded.copied') act(() => { fireEvent.mouseLeave(wrapper) vi.advanceTimersByTime(100) - fireEvent.mouseEnter(trigger) }) - expect(screen.getByText('appOverview.overview.appInfo.embedded.copy')).toBeInTheDocument() + expect(trigger).toHaveAccessibleName('appOverview.overview.appInfo.embedded.copy') }) it('should stop click propagation from the outer wrapper', () => { diff --git a/web/app/components/workflow/nodes/tool/components/copy-id.tsx b/web/app/components/workflow/nodes/tool/components/copy-id.tsx index eaf3d1bec5..18a510caaf 100644 --- a/web/app/components/workflow/nodes/tool/components/copy-id.tsx +++ b/web/app/components/workflow/nodes/tool/components/copy-id.tsx @@ -1,11 +1,10 @@ 'use client' -import { RiFileCopyLine } from '@remixicon/react' +import { Tooltip, TooltipContent, TooltipTrigger } from '@langgenius/dify-ui/tooltip' import copy from 'copy-to-clipboard' import { debounce } from 'es-toolkit/compat' import * as React from 'react' import { useState } from 'react' import { useTranslation } from 'react-i18next' -import Tooltip from '@/app/components/base/tooltip' type Props = { content: string @@ -25,27 +24,33 @@ const CopyFeedbackNew = ({ content }: Props) => { const onMouseLeave = debounce(() => { setIsCopied(false) }, 100) + const tooltip = (isCopied + ? t(`${prefixEmbedded}.copied`, { ns: 'appOverview' }) + : t(`${prefixEmbedded}.copy`, { ns: 'appOverview' })) || '' return (
e.stopPropagation()} onMouseLeave={onMouseLeave}> - -
-
- {content} -
- -
+ + + + {content} + + + + )} + /> + + {tooltip} +
) diff --git a/web/app/components/workflow/note-node/__tests__/hooks.spec.tsx b/web/app/components/workflow/note-node/__tests__/hooks.spec.tsx index 9642d3d9bf..f31e550284 100644 --- a/web/app/components/workflow/note-node/__tests__/hooks.spec.tsx +++ b/web/app/components/workflow/note-node/__tests__/hooks.spec.tsx @@ -1,4 +1,5 @@ import { act, renderHook } from '@testing-library/react' +import { NOTE_SHOW_AUTHOR_STORAGE_KEY } from '../constants' import { useNote } from '../hooks' const mockHandleNodeDataUpdateWithSyncDraft = vi.hoisted(() => vi.fn()) @@ -19,6 +20,7 @@ vi.mock('../../hooks', () => ({ describe('useNote', () => { beforeEach(() => { vi.clearAllMocks() + localStorage.clear() }) it('updates theme and author visibility while saving note history entries', () => { @@ -39,6 +41,7 @@ describe('useNote', () => { }) expect(mockSaveStateToHistory).toHaveBeenNthCalledWith(1, 'note-change', { nodeId: 'note-1' }) expect(mockSaveStateToHistory).toHaveBeenNthCalledWith(2, 'note-change', { nodeId: 'note-1' }) + expect(localStorage.getItem(NOTE_SHOW_AUTHOR_STORAGE_KEY)).toBe('true') }) it('serializes non-empty editor state and clears empty editor state', () => { diff --git a/web/app/components/workflow/note-node/constants.ts b/web/app/components/workflow/note-node/constants.ts index b2fa223690..d4891977cc 100644 --- a/web/app/components/workflow/note-node/constants.ts +++ b/web/app/components/workflow/note-node/constants.ts @@ -1,6 +1,7 @@ import { NoteTheme } from './types' export const CUSTOM_NOTE_NODE = 'custom-note' +export const NOTE_SHOW_AUTHOR_STORAGE_KEY = 'workflow-note-show-author' export const THEME_MAP: Record = { [NoteTheme.blue]: { diff --git a/web/app/components/workflow/note-node/hooks.ts b/web/app/components/workflow/note-node/hooks.ts index 6924f31af5..6248e7670d 100644 --- a/web/app/components/workflow/note-node/hooks.ts +++ b/web/app/components/workflow/note-node/hooks.ts @@ -2,6 +2,7 @@ import type { EditorState } from 'lexical' import type { NoteTheme } from './types' import { useCallback } from 'react' import { useNodeDataUpdate, useWorkflowHistory, WorkflowHistoryEvent } from '../hooks' +import { NOTE_SHOW_AUTHOR_STORAGE_KEY } from './constants' export const useNote = (id: string) => { const { handleNodeDataUpdateWithSyncDraft } = useNodeDataUpdate() @@ -20,6 +21,7 @@ export const useNote = (id: string) => { }, [handleNodeDataUpdateWithSyncDraft, id]) const handleShowAuthorChange = useCallback((showAuthor: boolean) => { + localStorage.setItem(NOTE_SHOW_AUTHOR_STORAGE_KEY, String(showAuthor)) handleNodeDataUpdateWithSyncDraft({ id, data: { showAuthor } }) saveStateToHistory(WorkflowHistoryEvent.NoteChange, { nodeId: id }) }, [handleNodeDataUpdateWithSyncDraft, id, saveStateToHistory]) diff --git a/web/app/components/workflow/note-node/note-editor/__tests__/editor.spec.tsx b/web/app/components/workflow/note-node/note-editor/__tests__/editor.spec.tsx index 92df65d8f2..b675f57849 100644 --- a/web/app/components/workflow/note-node/note-editor/__tests__/editor.spec.tsx +++ b/web/app/components/workflow/note-node/note-editor/__tests__/editor.spec.tsx @@ -1,4 +1,7 @@ import type { EditorState, LexicalEditor } from 'lexical' +import { readFileSync } from 'node:fs' +import { resolve } from 'node:path' +import { $createLinkNode } from '@lexical/link' import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext' import { act, fireEvent, render, screen, waitFor } from '@testing-library/react' import { $createParagraphNode, $createTextNode, $getRoot } from 'lexical' @@ -7,6 +10,10 @@ import { NoteEditorContextProvider } from '../context' import Editor from '../editor' const emptyValue = JSON.stringify({ root: { children: [] } }) +const themeCss = readFileSync( + resolve(process.cwd(), 'app/components/workflow/note-node/note-editor/theme/theme.css'), + 'utf8', +) const EditorProbe = ({ onReady, @@ -52,6 +59,35 @@ describe('Editor', () => { expect(screen.getByText('Type note')).toBeInTheDocument() expect(screen.getByRole('textbox')).toBeInTheDocument() }) + + it('should render linked text with distinct link styling', async () => { + let editor: LexicalEditor | null = null + + renderEditor({}, instance => (editor = instance)) + + await waitFor(() => { + expect(editor).not.toBeNull() + }) + + act(() => { + editor!.update(() => { + const root = $getRoot() + root.clear() + const paragraph = $createParagraphNode() + const link = $createLinkNode('https://example.com/docs') + link.append($createTextNode('Linked docs')) + paragraph.append(link) + root.append(paragraph) + }, { discrete: true }) + }) + + const link = await screen.findByRole('link', { name: 'Linked docs' }) + + expect(link).toHaveClass('note-editor-theme_link') + expect(themeCss).toContain('.note-editor-theme_link') + expect(themeCss).toContain('font-weight: 500;') + expect(themeCss).toContain('text-decoration: underline;') + }) }) // Focus and blur should toggle workflow shortcuts while editing content. diff --git a/web/app/components/workflow/note-node/note-editor/plugins/link-editor-plugin/__tests__/component.spec.tsx b/web/app/components/workflow/note-node/note-editor/plugins/link-editor-plugin/__tests__/component.spec.tsx index b288421b60..2a278fc1e0 100644 --- a/web/app/components/workflow/note-node/note-editor/plugins/link-editor-plugin/__tests__/component.spec.tsx +++ b/web/app/components/workflow/note-node/note-editor/plugins/link-editor-plugin/__tests__/component.spec.tsx @@ -1,4 +1,4 @@ -import { fireEvent, render, screen } from '@testing-library/react' +import { fireEvent, render, screen, waitFor } from '@testing-library/react' import NoteEditorContext from '../../../context' import { createNoteEditorStore } from '../../../store' import LinkEditorComponent from '../component' @@ -18,6 +18,59 @@ describe('link editor component', () => { vi.clearAllMocks() }) + it('cancels a newly created empty link when pressing Escape', () => { + const store = createNoteEditorStore() + const anchor = document.createElement('button') + const portalRoot = document.createElement('div') + document.body.appendChild(anchor) + document.body.appendChild(portalRoot) + store.setState({ + linkAnchorElement: anchor, + linkOperatorShow: false, + selectedLinkUrl: '', + }) + + render( + + + , + ) + + fireEvent.keyDown(screen.getByRole('textbox'), { key: 'Escape' }) + + expect(mockHandleUnlink).toHaveBeenCalledTimes(1) + expect(mockHandleSaveLink).not.toHaveBeenCalled() + }) + + it('cancels a newly created empty link when clicking outside the editor', async () => { + const store = createNoteEditorStore() + const anchor = document.createElement('button') + const portalRoot = document.createElement('div') + document.body.appendChild(anchor) + document.body.appendChild(portalRoot) + store.setState({ + linkAnchorElement: anchor, + linkOperatorShow: false, + selectedLinkUrl: '', + }) + + render( + + + , + ) + + expect(screen.getByRole('textbox')).toBeInTheDocument() + fireEvent.mouseDown(document.body) + fireEvent.mouseUp(document.body) + fireEvent.click(document.body) + + await waitFor(() => { + expect(mockHandleUnlink).toHaveBeenCalledTimes(1) + }) + expect(mockHandleSaveLink).not.toHaveBeenCalled() + }) + it('renders the inline link editor and saves the edited url', () => { const store = createNoteEditorStore() const anchor = document.createElement('button') @@ -42,4 +95,27 @@ describe('link editor component', () => { expect(mockHandleSaveLink).toHaveBeenCalledWith('https://example.com') }) + + it('saves the edited url when pressing Enter', () => { + const store = createNoteEditorStore() + const anchor = document.createElement('button') + const portalRoot = document.createElement('div') + document.body.appendChild(anchor) + document.body.appendChild(portalRoot) + store.setState({ + linkAnchorElement: anchor, + linkOperatorShow: false, + selectedLinkUrl: 'https://example.com', + }) + + render( + + + , + ) + + fireEvent.keyDown(screen.getByRole('textbox'), { key: 'Enter' }) + + expect(mockHandleSaveLink).toHaveBeenCalledWith('https://example.com') + }) }) diff --git a/web/app/components/workflow/note-node/note-editor/plugins/link-editor-plugin/__tests__/hooks.spec.tsx b/web/app/components/workflow/note-node/note-editor/plugins/link-editor-plugin/__tests__/hooks.spec.tsx index 4272050fac..c0c6767b46 100644 --- a/web/app/components/workflow/note-node/note-editor/plugins/link-editor-plugin/__tests__/hooks.spec.tsx +++ b/web/app/components/workflow/note-node/note-editor/plugins/link-editor-plugin/__tests__/hooks.spec.tsx @@ -14,7 +14,7 @@ const { } = vi.hoisted(() => { const listeners: { update?: () => void - click?: (payload: { metaKey?: boolean, ctrlKey?: boolean }) => boolean + click?: (payload: { metaKey?: boolean, ctrlKey?: boolean, target?: EventTarget | null }) => boolean } = {} const editor = { @@ -36,6 +36,8 @@ const { selectedLinkUrl: '', setLinkAnchorElement: vi.fn(), setLinkOperatorShow: vi.fn(), + setSelectedLinkUrl: vi.fn(), + setSelectedIsLink: vi.fn(), }, mockListeners: listeners, } @@ -78,6 +80,8 @@ describe('link editor hooks', () => { mockStoreState.selectedLinkUrl = '' mockStoreState.setLinkAnchorElement = mockSetLinkAnchorElement mockStoreState.setLinkOperatorShow = mockSetLinkOperatorShow + mockStoreState.setSelectedLinkUrl = vi.fn() + mockStoreState.setSelectedIsLink = vi.fn() mockListeners.update = undefined mockListeners.click = undefined @@ -124,6 +128,26 @@ describe('link editor hooks', () => { expect(mockSetLinkOperatorShow).toHaveBeenCalledWith(false) }) + it('should show the link operator immediately when clicking a link target', () => { + const target = document.createElement('a') + target.className = 'note-editor-theme_link' + target.href = 'https://dify.ai/docs' + + renderHook(() => useOpenLink()) + + let handled = false + act(() => { + handled = mockListeners.click?.({ target }) ?? false + vi.runAllTimers() + }) + + expect(handled).toBe(true) + expect(mockStoreState.setSelectedLinkUrl).toHaveBeenCalledWith('https://dify.ai/docs') + expect(mockStoreState.setSelectedIsLink).toHaveBeenCalledWith(true) + expect(mockSetLinkAnchorElement).toHaveBeenCalledWith(target) + expect(mockSetLinkOperatorShow).toHaveBeenCalledWith(true) + }) + it('should open the selected link in a new tab on meta or ctrl click', () => { mockStoreState.selectedIsLink = true mockStoreState.selectedLinkUrl = 'https://dify.ai' diff --git a/web/app/components/workflow/note-node/note-editor/plugins/link-editor-plugin/component.tsx b/web/app/components/workflow/note-node/note-editor/plugins/link-editor-plugin/component.tsx index 754905cd46..7e5b6c586c 100644 --- a/web/app/components/workflow/note-node/note-editor/plugins/link-editor-plugin/component.tsx +++ b/web/app/components/workflow/note-node/note-editor/plugins/link-editor-plugin/component.tsx @@ -16,7 +16,9 @@ import { useClickAway } from 'ahooks' import { escape } from 'es-toolkit/string' import { memo, + useCallback, useEffect, + useRef, useState, } from 'react' import { useTranslation } from 'react-i18next' @@ -40,6 +42,7 @@ const LinkEditorComponent = ({ const setLinkAnchorElement = useStore(s => s.setLinkAnchorElement) const setLinkOperatorShow = useStore(s => s.setLinkOperatorShow) const [url, setUrl] = useState(selectedLinkUrl) + const floatingRef = useRef(null) const { refs, floatingStyles, elements } = useFloating({ placement: 'top', middleware: [ @@ -49,9 +52,19 @@ const LinkEditorComponent = ({ ], }) - useClickAway(() => { + const handleCancelLinkEdit = useCallback(() => { + if (!linkOperatorShow && !selectedLinkUrl) { + handleUnlink() + return + } + setLinkAnchorElement() - }, linkAnchorElement) + setLinkOperatorShow(false) + }, [handleUnlink, linkOperatorShow, selectedLinkUrl, setLinkAnchorElement, setLinkOperatorShow]) + + useClickAway(() => { + handleCancelLinkEdit() + }, [floatingRef, linkAnchorElement]) useEffect(() => { setUrl(selectedLinkUrl) @@ -74,7 +87,10 @@ const LinkEditorComponent = ({ linkOperatorShow && 'p-0.5 system-xs-medium text-text-tertiary shadow-sm', )} style={floatingStyles} - ref={refs.setFloating} + ref={(node) => { + refs.setFloating(node) + floatingRef.current = node + }} > { !linkOperatorShow && ( @@ -83,6 +99,21 @@ const LinkEditorComponent = ({ className="mr-0.5 h-6 w-[196px] appearance-none rounded-xs bg-transparent p-1 text-[13px] text-components-input-text-filled outline-hidden" value={url} onChange={e => setUrl(e.target.value)} + onKeyDown={(e) => { + if (e.key === 'Enter') { + e.preventDefault() + e.stopPropagation() + if (url) + handleSaveLink(url) + return + } + + if (e.key === 'Escape') { + e.preventDefault() + e.stopPropagation() + handleCancelLinkEdit() + } + }} placeholder={t('nodes.note.editor.enterUrl', { ns: 'workflow' }) || ''} autoFocus /> diff --git a/web/app/components/workflow/note-node/note-editor/plugins/link-editor-plugin/hooks.ts b/web/app/components/workflow/note-node/note-editor/plugins/link-editor-plugin/hooks.ts index 6debfa5f8b..361a74f4c6 100644 --- a/web/app/components/workflow/note-node/note-editor/plugins/link-editor-plugin/hooks.ts +++ b/web/app/components/workflow/note-node/note-editor/plugins/link-editor-plugin/hooks.ts @@ -9,6 +9,12 @@ import { useTranslation } from 'react-i18next' import { useNoteEditorStore } from '../../store' import { urlRegExp } from '../../utils' +const getClickedLinkElement = (target: EventTarget | null) => { + return target instanceof HTMLElement + ? target.closest('.note-editor-theme_link') as HTMLElement | null + : null +} + export const useOpenLink = () => { const [editor] = useLexicalComposerContext() const noteEditorStore = useNoteEditorStore() @@ -30,11 +36,34 @@ export const useOpenLink = () => { }) }), editor.registerCommand(CLICK_COMMAND, (payload) => { setTimeout(() => { - const { selectedLinkUrl, selectedIsLink, setLinkAnchorElement, setLinkOperatorShow } = noteEditorStore.getState() + const { + selectedLinkUrl, + selectedIsLink, + setLinkAnchorElement, + setLinkOperatorShow, + setSelectedLinkUrl, + setSelectedIsLink, + } = noteEditorStore.getState() + const clickedLinkElement = getClickedLinkElement(payload.target) + const clickedLinkUrl = clickedLinkElement?.getAttribute('href') || selectedLinkUrl + + if (clickedLinkElement && clickedLinkUrl) { + if (payload.metaKey || payload.ctrlKey) { + window.open(clickedLinkUrl, '_blank') + return + } + + setSelectedLinkUrl(clickedLinkUrl) + setSelectedIsLink(true) + setLinkAnchorElement(clickedLinkElement) + setLinkOperatorShow(true) + return + } + if (selectedIsLink) { if ((payload.metaKey || payload.ctrlKey) && selectedLinkUrl) { window.open(selectedLinkUrl, '_blank') - return true + return } setLinkAnchorElement(true) if (selectedLinkUrl) @@ -47,7 +76,7 @@ export const useOpenLink = () => { setLinkOperatorShow(false) } }) - return false + return !!getClickedLinkElement(payload.target) }, COMMAND_PRIORITY_LOW)) }, [editor, noteEditorStore]) } diff --git a/web/app/components/workflow/note-node/note-editor/store.ts b/web/app/components/workflow/note-node/note-editor/store.ts index 3507bb7c0c..4bba12b6f6 100644 --- a/web/app/components/workflow/note-node/note-editor/store.ts +++ b/web/app/components/workflow/note-node/note-editor/store.ts @@ -7,7 +7,7 @@ import NoteEditorContext from './context' type Shape = { linkAnchorElement: HTMLElement | null - setLinkAnchorElement: (open?: boolean) => void + setLinkAnchorElement: (open?: boolean | HTMLElement | null) => void linkOperatorShow: boolean setLinkOperatorShow: (linkOperatorShow: boolean) => void selectedIsBold: boolean @@ -28,6 +28,11 @@ export const createNoteEditorStore = () => { return createStore(set => ({ linkAnchorElement: null, setLinkAnchorElement: (open) => { + if (open instanceof HTMLElement) { + set(() => ({ linkAnchorElement: open })) + return + } + if (open) { setTimeout(() => { const nativeSelection = window.getSelection() diff --git a/web/app/components/workflow/note-node/note-editor/theme/index.ts b/web/app/components/workflow/note-node/note-editor/theme/index.ts index 5cb8dec37f..a815291d05 100644 --- a/web/app/components/workflow/note-node/note-editor/theme/index.ts +++ b/web/app/components/workflow/note-node/note-editor/theme/index.ts @@ -8,7 +8,7 @@ const theme: EditorThemeClasses = { ul: 'note-editor-theme_list-ul', listitem: 'note-editor-theme_list-li', }, - link: 'note-editor-theme_link', + link: 'note-editor-theme_link nodrag nopan nowheel', text: { italic: 'note-editor-theme_text-italic', strikethrough: 'note-editor-theme_text-strikethrough', diff --git a/web/app/components/workflow/note-node/note-editor/theme/theme.css b/web/app/components/workflow/note-node/note-editor/theme/theme.css index 77b745ca4a..6d22c58b1c 100644 --- a/web/app/components/workflow/note-node/note-editor/theme/theme.css +++ b/web/app/components/workflow/note-node/note-editor/theme/theme.css @@ -16,11 +16,18 @@ .note-editor-theme_link { cursor: pointer; - color: var(--text-text-selected); + color: var(--color-text-accent); + font-weight: 500; + text-decoration: underline; + text-decoration-thickness: 1px; + text-underline-offset: 2px; + text-decoration-color: color-mix(in srgb, var(--color-text-accent) 60%, transparent); + transition: color 0.15s ease, text-decoration-color 0.15s ease; } .note-editor-theme_link:hover { - text-decoration: underline; + color: var(--color-text-accent-secondary); + text-decoration-color: currentColor; } .note-editor-theme_text-strikethrough { diff --git a/web/app/components/workflow/note-node/note-editor/toolbar/command.tsx b/web/app/components/workflow/note-node/note-editor/toolbar/command.tsx index ab4bf8c7bb..9f9c02de33 100644 --- a/web/app/components/workflow/note-node/note-editor/toolbar/command.tsx +++ b/web/app/components/workflow/note-node/note-editor/toolbar/command.tsx @@ -1,17 +1,10 @@ import { cn } from '@langgenius/dify-ui/cn' -import { - RiBold, - RiItalic, - RiLink, - RiListUnordered, - RiStrikethrough, -} from '@remixicon/react' +import { Tooltip, TooltipContent, TooltipTrigger } from '@langgenius/dify-ui/tooltip' import { memo, useMemo, } from 'react' import { useTranslation } from 'react-i18next' -import Tooltip from '@/app/components/base/tooltip' import { useStore } from '../store' import { useCommand } from './hooks' @@ -32,15 +25,15 @@ const Command = ({ const icon = useMemo(() => { switch (type) { case 'bold': - return + return case 'italic': - return + return case 'strikethrough': - return + return case 'link': - return + return case 'bullet': - return + return } }, [type, selectedIsBold, selectedIsItalic, selectedIsStrikeThrough, selectedIsLink, selectedIsBullet]) @@ -60,22 +53,27 @@ const Command = ({ }, [type, t]) return ( - -
+ handleCommand(type)} + > + {icon} + )} - onClick={() => handleCommand(type)} - > - {icon} -
+ /> + {tip}
) } diff --git a/web/app/components/workflow/operator/hooks.ts b/web/app/components/workflow/operator/hooks.ts index 23248a89a3..dcdfa4f629 100644 --- a/web/app/components/workflow/operator/hooks.ts +++ b/web/app/components/workflow/operator/hooks.ts @@ -1,7 +1,10 @@ import type { NoteNodeType } from '../note-node/types' import { useCallback } from 'react' import { useAppContext } from '@/context/app-context' -import { CUSTOM_NOTE_NODE } from '../note-node/constants' +import { + CUSTOM_NOTE_NODE, + NOTE_SHOW_AUTHOR_STORAGE_KEY, +} from '../note-node/constants' import { NoteTheme } from '../note-node/types' import { useWorkflowStore } from '../store' import { generateNewNode } from '../utils' @@ -20,7 +23,7 @@ export const useOperator = () => { text: '', theme: NoteTheme.blue, author: userProfile?.name || '', - showAuthor: true, + showAuthor: localStorage.getItem(NOTE_SHOW_AUTHOR_STORAGE_KEY) !== 'false', width: 240, height: 88, _isCandidate: true, diff --git a/web/app/components/workflow/panel-contextmenu.tsx b/web/app/components/workflow/panel-contextmenu.tsx index ffe88d3dc9..4478839077 100644 --- a/web/app/components/workflow/panel-contextmenu.tsx +++ b/web/app/components/workflow/panel-contextmenu.tsx @@ -137,7 +137,7 @@ const PanelContextmenu = () => { className="flex h-8 cursor-pointer items-center justify-between rounded-lg px-3 text-sm text-text-secondary hover:bg-state-base-hover" onClick={() => setShowImportDSLModal(true)} > - {t('common.importDSL', { ns: 'workflow' })} + {t('importApp', { ns: 'app' })}
diff --git a/web/app/components/workflow/update-dsl-modal.tsx b/web/app/components/workflow/update-dsl-modal.tsx index cfa9c995eb..549dee487f 100644 --- a/web/app/components/workflow/update-dsl-modal.tsx +++ b/web/app/components/workflow/update-dsl-modal.tsx @@ -205,7 +205,7 @@ const UpdateDSLModal = ({ onClose={onCancel} >
-
{t('common.importDSL', { ns: 'workflow' })}
+
{t('importApp', { ns: 'app' })}
diff --git a/web/app/page.tsx b/web/app/page.tsx index 65f8827e01..a866fd4c39 100644 --- a/web/app/page.tsx +++ b/web/app/page.tsx @@ -1,18 +1,23 @@ -import Loading from '@/app/components/base/loading' -import Link from '@/next/link' +import { redirect } from '@/next/navigation' -const Home = async () => { - return ( -
+type HomePageProps = { + searchParams: Promise> +} -
- -
- 🚀 -
-
-
- ) +const Home = async ({ searchParams }: HomePageProps) => { + const resolvedSearchParams = await searchParams + const urlSearchParams = new URLSearchParams() + Object.entries(resolvedSearchParams).forEach(([key, value]) => { + if (value === undefined) + return + if (Array.isArray(value)) { + value.forEach(item => urlSearchParams.append(key, item)) + return + } + urlSearchParams.set(key, value) + }) + const queryString = urlSearchParams.toString() + redirect(queryString ? `/apps?${queryString}` : '/apps') } export default Home diff --git a/web/app/signin/check-code/page.tsx b/web/app/signin/check-code/page.tsx index fb52e0b5b7..42024c561b 100644 --- a/web/app/signin/check-code/page.tsx +++ b/web/app/signin/check-code/page.tsx @@ -51,7 +51,7 @@ export default function CheckCode() { router.replace(`/signin/invite-settings?${searchParams.toString()}`) } else { - const redirectUrl = resolvePostLoginRedirect() + const redirectUrl = resolvePostLoginRedirect(searchParams) router.replace(redirectUrl || '/apps') } } diff --git a/web/app/signin/components/mail-and-password-auth.tsx b/web/app/signin/components/mail-and-password-auth.tsx index 6feaf11426..30bc78666c 100644 --- a/web/app/signin/components/mail-and-password-auth.tsx +++ b/web/app/signin/components/mail-and-password-auth.tsx @@ -75,7 +75,7 @@ export default function MailAndPasswordAuth({ isInvite, isEmailSetup, allowRegis router.replace(`/signin/invite-settings?${searchParams.toString()}`) } else { - const redirectUrl = resolvePostLoginRedirect() + const redirectUrl = resolvePostLoginRedirect(searchParams) router.replace(redirectUrl || '/apps') } } diff --git a/web/app/signin/invite-settings/page.tsx b/web/app/signin/invite-settings/page.tsx index 7066ab041c..43ca96ab05 100644 --- a/web/app/signin/invite-settings/page.tsx +++ b/web/app/signin/invite-settings/page.tsx @@ -65,7 +65,7 @@ export default function InviteSettingsPage() { if (res.result === 'success') { // Tokens are now stored in cookies by the backend await setLocaleOnClient(language!, false) - const redirectUrl = resolvePostLoginRedirect() + const redirectUrl = resolvePostLoginRedirect(searchParams) router.replace(redirectUrl || '/apps') } } diff --git a/web/app/signin/normal-form.tsx b/web/app/signin/normal-form.tsx index 779aba5c9c..a32c7e9b3d 100644 --- a/web/app/signin/normal-form.tsx +++ b/web/app/signin/normal-form.tsx @@ -49,7 +49,7 @@ const NormalForm = () => { try { if (isLoggedIn) { setIsRedirecting(true) - const redirectUrl = resolvePostLoginRedirect() + const redirectUrl = resolvePostLoginRedirect(searchParams) router.replace(redirectUrl || '/apps') return } diff --git a/web/app/signin/utils/post-login-redirect.ts b/web/app/signin/utils/post-login-redirect.ts index a94fb2ad79..0015296a41 100644 --- a/web/app/signin/utils/post-login-redirect.ts +++ b/web/app/signin/utils/post-login-redirect.ts @@ -1,15 +1,63 @@ -let postLoginRedirect: string | null = null +import type { ReadonlyURLSearchParams } from '@/next/navigation' -export const setPostLoginRedirect = (value: string | null) => { - postLoginRedirect = value +const OAUTH_AUTHORIZE_PENDING_KEY = 'oauth_authorize_pending_redirect' +const REDIRECT_URL_KEY = 'redirect_url' + +type OAuthPendingRedirect = { + value?: string + expiry?: number } -export const resolvePostLoginRedirect = () => { - if (postLoginRedirect) { - const redirectUrl = postLoginRedirect - postLoginRedirect = null - return redirectUrl +const getCurrentUnixTimestamp = () => Math.floor(Date.now() / 1000) + +function removeOAuthPendingRedirect() { + try { + localStorage.removeItem(OAUTH_AUTHORIZE_PENDING_KEY) } - - return null + catch {} +} + +function getOAuthPendingRedirect(): string | null { + try { + const raw = localStorage.getItem(OAUTH_AUTHORIZE_PENDING_KEY) + if (!raw) + return null + removeOAuthPendingRedirect() + const item: OAuthPendingRedirect = JSON.parse(raw) + if (!item.value || typeof item.expiry !== 'number') + return null + return getCurrentUnixTimestamp() > item.expiry ? null : item.value + } + catch { + removeOAuthPendingRedirect() + return null + } +} + +export function setOAuthPendingRedirect(url: string, ttlSeconds: number = 300) { + try { + const item: OAuthPendingRedirect = { + value: url, + expiry: getCurrentUnixTimestamp() + ttlSeconds, + } + localStorage.setItem(OAUTH_AUTHORIZE_PENDING_KEY, JSON.stringify(item)) + } + catch {} +} + +export const resolvePostLoginRedirect = (searchParams?: ReadonlyURLSearchParams) => { + if (searchParams) { + const redirectUrl = searchParams.get(REDIRECT_URL_KEY) + if (redirectUrl) { + try { + removeOAuthPendingRedirect() + return decodeURIComponent(redirectUrl) + } + catch { + removeOAuthPendingRedirect() + return redirectUrl + } + } + } + return getOAuthPendingRedirect() } diff --git a/web/contract/marketplace.ts b/web/contract/marketplace.ts index 3573ba5c24..9f2475041e 100644 --- a/web/contract/marketplace.ts +++ b/web/contract/marketplace.ts @@ -1,5 +1,6 @@ import type { CollectionsAndPluginsSearchParams, MarketplaceCollection, PluginsSearchParams } from '@/app/components/plugins/marketplace/types' import type { Plugin, PluginsFromMarketplaceResponse } from '@/app/components/plugins/types' +import type { MarketplaceTemplate } from '@/types/marketplace-template' import { type } from '@orpc/contract' import { base } from './base' @@ -54,3 +55,15 @@ export const searchAdvancedContract = base body: Omit }>()) .output(type<{ data: PluginsFromMarketplaceResponse }>()) + +export const templateDetailContract = base + .route({ + path: '/templates/{templateId}', + method: 'GET', + }) + .input(type<{ + params: { + templateId: string + } + }>()) + .output(type<{ data: MarketplaceTemplate }>()) diff --git a/web/contract/router.ts b/web/contract/router.ts index f165729e54..e5610dc81c 100644 --- a/web/contract/router.ts +++ b/web/contract/router.ts @@ -94,12 +94,13 @@ import { workflowDraftUpdateFeaturesContract, } from './console/workflow' import { workflowCommentContracts } from './console/workflow-comment' -import { collectionPluginsContract, collectionsContract, searchAdvancedContract } from './marketplace' +import { collectionPluginsContract, collectionsContract, searchAdvancedContract, templateDetailContract } from './marketplace' export const marketplaceRouterContract = { collections: collectionsContract, collectionPlugins: collectionPluginsContract, searchAdvanced: searchAdvancedContract, + templateDetail: templateDetailContract, } export type MarketPlaceInputs = InferContractRouterInputs diff --git a/web/docs/overlay-migration.md b/web/docs/overlay-migration.md index 73c0f02d9d..cb020f9ab6 100644 --- a/web/docs/overlay-migration.md +++ b/web/docs/overlay-migration.md @@ -44,12 +44,6 @@ This document tracks the Dify-web migration away from legacy overlay APIs. ## Allowlist maintenance -- After each migration batch, run: - -```sh -pnpm -C web lint:fix --prune-suppressions -``` - - If a migrated file was in the allowlist, remove it from `web/eslint.constants.mjs` in the same PR. - Never increase allowlist scope to bypass new code. diff --git a/web/i18n/ar-TN/app.json b/web/i18n/ar-TN/app.json index 154758077b..7bc25ccf42 100644 --- a/web/i18n/ar-TN/app.json +++ b/web/i18n/ar-TN/app.json @@ -118,12 +118,29 @@ "iconPicker.emoji": "رموز تعبيرية", "iconPicker.image": "صورة", "iconPicker.ok": "موافق", + "importApp": "استيراد التطبيق", "importDSL": "استيراد ملف DSL", "importFromDSL": "استيراد من DSL", "importFromDSLFile": "من ملف DSL", "importFromDSLUrl": "من رابط", "importFromDSLUrlPlaceholder": "لصق رابط DSL هنا", "join": "انضم إلى المجتمع", + "marketplace.template.categories": "الفئات", + "marketplace.template.category.design": "التصميم", + "marketplace.template.category.it": "تكنولوجيا المعلومات", + "marketplace.template.category.knowledge": "المعرفة", + "marketplace.template.category.marketing": "التسويق", + "marketplace.template.category.operations": "العمليات", + "marketplace.template.category.sales": "المبيعات", + "marketplace.template.category.support": "الدعم", + "marketplace.template.fetchFailed": "فشل في جلب القالب", + "marketplace.template.importConfirm": "استيراد", + "marketplace.template.importFailed": "فشل في استيراد القالب", + "marketplace.template.modalTitle": "استيراد من Marketplace", + "marketplace.template.overview": "نظرة عامة", + "marketplace.template.publishedBy": "بواسطة", + "marketplace.template.usageCount": "الاستخدام", + "marketplace.template.viewOnMarketplace": "عرض على Marketplace", "maxActiveRequests": "أقصى عدد للطلبات المتزامنة", "maxActiveRequestsPlaceholder": "أدخل 0 لغير محدود", "maxActiveRequestsTip": "الحد الأقصى لعدد الطلبات النشطة المتزامنة لكل تطبيق (0 لغير محدود)", diff --git a/web/i18n/ar-TN/workflow.json b/web/i18n/ar-TN/workflow.json index 04a618fb3b..cc6c533ca1 100644 --- a/web/i18n/ar-TN/workflow.json +++ b/web/i18n/ar-TN/workflow.json @@ -229,9 +229,12 @@ "common.previewPlaceholder": "أدخل المحتوى في المربع أدناه لبدء تصحيح أخطاء Chatbot", "common.processData": "معالجة البيانات", "common.publish": "نشر", + "common.publishToMarketplace": "نشر على Marketplace", + "common.publishToMarketplaceFailed": "فشل النشر على Marketplace", "common.publishUpdate": "نشر التحديث", "common.published": "منشور", "common.publishedAt": "تم النشر في", + "common.publishingToMarketplace": "جارٍ النشر...", "common.redo": "إعادة", "common.restart": "إعادة تشغيل", "common.restore": "استعادة", diff --git a/web/i18n/de-DE/app.json b/web/i18n/de-DE/app.json index b316dcebce..c429e37802 100644 --- a/web/i18n/de-DE/app.json +++ b/web/i18n/de-DE/app.json @@ -118,12 +118,29 @@ "iconPicker.emoji": "Emoji", "iconPicker.image": "Bild", "iconPicker.ok": "OK", + "importApp": "App importieren", "importDSL": "DSL-Datei importieren", "importFromDSL": "Import von DSL", "importFromDSLFile": "Aus DSL-Datei", "importFromDSLUrl": "Von URL", "importFromDSLUrlPlaceholder": "DSL-Link hier einfügen", "join": "Treten Sie der Gemeinschaft bei", + "marketplace.template.categories": "Kategorien", + "marketplace.template.category.design": "Design", + "marketplace.template.category.it": "IT", + "marketplace.template.category.knowledge": "Wissen", + "marketplace.template.category.marketing": "Marketing", + "marketplace.template.category.operations": "Betrieb", + "marketplace.template.category.sales": "Vertrieb", + "marketplace.template.category.support": "Support", + "marketplace.template.fetchFailed": "Vorlage konnte nicht abgerufen werden", + "marketplace.template.importConfirm": "Importieren", + "marketplace.template.importFailed": "Vorlage konnte nicht importiert werden", + "marketplace.template.modalTitle": "Aus Marketplace importieren", + "marketplace.template.overview": "Übersicht", + "marketplace.template.publishedBy": "Von", + "marketplace.template.usageCount": "Nutzung", + "marketplace.template.viewOnMarketplace": "Im Marketplace ansehen", "maxActiveRequests": "Maximale gleichzeitige Anfragen", "maxActiveRequestsPlaceholder": "Geben Sie 0 für unbegrenzt ein", "maxActiveRequestsTip": "Maximale Anzahl gleichzeitiger aktiver Anfragen pro App (0 für unbegrenzt)", diff --git a/web/i18n/de-DE/workflow.json b/web/i18n/de-DE/workflow.json index fe50c09651..426c023259 100644 --- a/web/i18n/de-DE/workflow.json +++ b/web/i18n/de-DE/workflow.json @@ -229,9 +229,12 @@ "common.previewPlaceholder": "Geben Sie den Inhalt in das Feld unten ein, um das Debuggen des Chatbots zu starten", "common.processData": "Daten verarbeiten", "common.publish": "Veröffentlichen", + "common.publishToMarketplace": "Im Marketplace veröffentlichen", + "common.publishToMarketplaceFailed": "Veröffentlichung im Marketplace fehlgeschlagen", "common.publishUpdate": "Update veröffentlichen", "common.published": "Veröffentlicht", "common.publishedAt": "Veröffentlicht am", + "common.publishingToMarketplace": "Wird veröffentlicht...", "common.redo": "Wiederholen", "common.restart": "Neustarten", "common.restore": "Wiederherstellen", diff --git a/web/i18n/en-US/app.json b/web/i18n/en-US/app.json index 450608aa80..bb30ef44cf 100644 --- a/web/i18n/en-US/app.json +++ b/web/i18n/en-US/app.json @@ -118,12 +118,29 @@ "iconPicker.emoji": "Emoji", "iconPicker.image": "Image", "iconPicker.ok": "OK", + "importApp": "Import App", "importDSL": "Import DSL file", "importFromDSL": "Import from DSL", "importFromDSLFile": "From DSL file", "importFromDSLUrl": "From URL", "importFromDSLUrlPlaceholder": "Paste DSL link here", "join": "Join the community", + "marketplace.template.categories": "Categories", + "marketplace.template.category.design": "Design", + "marketplace.template.category.it": "IT", + "marketplace.template.category.knowledge": "Knowledge", + "marketplace.template.category.marketing": "Marketing", + "marketplace.template.category.operations": "Operations", + "marketplace.template.category.sales": "Sales", + "marketplace.template.category.support": "Support", + "marketplace.template.fetchFailed": "Failed to fetch template", + "marketplace.template.importConfirm": "Import", + "marketplace.template.importFailed": "Failed to import template", + "marketplace.template.modalTitle": "Import from Marketplace", + "marketplace.template.overview": "Overview", + "marketplace.template.publishedBy": "By", + "marketplace.template.usageCount": "Usage", + "marketplace.template.viewOnMarketplace": "View on Marketplace", "maxActiveRequests": "Max concurrent requests", "maxActiveRequestsPlaceholder": "Enter 0 for unlimited", "maxActiveRequestsTip": "Maximum number of concurrent active requests per app (0 for unlimited)", diff --git a/web/i18n/en-US/workflow.json b/web/i18n/en-US/workflow.json index a3e7623562..c8cbed8f1e 100644 --- a/web/i18n/en-US/workflow.json +++ b/web/i18n/en-US/workflow.json @@ -231,9 +231,12 @@ "common.publish": "Publish", "common.publishAsEvaluationWorkflow": "Publish as Evaluation Workflow", "common.publishAsStandardWorkflow": "Publish as Standard Workflow", + "common.publishToMarketplace": "Publish to Marketplace", + "common.publishToMarketplaceFailed": "Failed to publish to Marketplace", "common.publishUpdate": "Publish Update", "common.published": "Published", "common.publishedAt": "Published", + "common.publishingToMarketplace": "Publishing...", "common.redo": "Redo", "common.restart": "Restart", "common.restore": "Restore", diff --git a/web/i18n/es-ES/app.json b/web/i18n/es-ES/app.json index 251746db7f..5cc805c8f6 100644 --- a/web/i18n/es-ES/app.json +++ b/web/i18n/es-ES/app.json @@ -118,12 +118,29 @@ "iconPicker.emoji": "Emoji", "iconPicker.image": "Imagen", "iconPicker.ok": "OK", + "importApp": "Importar App", "importDSL": "Importar archivo DSL", "importFromDSL": "Importar desde DSL", "importFromDSLFile": "Desde el archivo DSL", "importFromDSLUrl": "URL de origen", "importFromDSLUrlPlaceholder": "Pegar enlace DSL aquí", "join": "Únete a la comunidad", + "marketplace.template.categories": "Categorías", + "marketplace.template.category.design": "Diseño", + "marketplace.template.category.it": "IT", + "marketplace.template.category.knowledge": "Conocimiento", + "marketplace.template.category.marketing": "Marketing", + "marketplace.template.category.operations": "Operaciones", + "marketplace.template.category.sales": "Ventas", + "marketplace.template.category.support": "Soporte", + "marketplace.template.fetchFailed": "Error al obtener la plantilla", + "marketplace.template.importConfirm": "Importar", + "marketplace.template.importFailed": "Error al importar la plantilla", + "marketplace.template.modalTitle": "Importar desde Marketplace", + "marketplace.template.overview": "Vista general", + "marketplace.template.publishedBy": "Por", + "marketplace.template.usageCount": "Uso", + "marketplace.template.viewOnMarketplace": "Ver en Marketplace", "maxActiveRequests": "Máximas solicitudes concurrentes", "maxActiveRequestsPlaceholder": "Introduce 0 para ilimitado", "maxActiveRequestsTip": "Número máximo de solicitudes activas concurrentes por aplicación (0 para ilimitado)", diff --git a/web/i18n/es-ES/workflow.json b/web/i18n/es-ES/workflow.json index 5da69241e7..c55ffdfc1e 100644 --- a/web/i18n/es-ES/workflow.json +++ b/web/i18n/es-ES/workflow.json @@ -229,9 +229,12 @@ "common.previewPlaceholder": "Ingrese contenido en el cuadro de abajo para comenzar a depurar el Chatbot", "common.processData": "Procesar datos", "common.publish": "Publicar", + "common.publishToMarketplace": "Publicar en Marketplace", + "common.publishToMarketplaceFailed": "Error al publicar en Marketplace", "common.publishUpdate": "Publicar actualización", "common.published": "Publicado", "common.publishedAt": "Publicado el", + "common.publishingToMarketplace": "Publicando...", "common.redo": "Rehacer", "common.restart": "Reiniciar", "common.restore": "Restaurar", diff --git a/web/i18n/fa-IR/app.json b/web/i18n/fa-IR/app.json index ed253fc569..3bdba44440 100644 --- a/web/i18n/fa-IR/app.json +++ b/web/i18n/fa-IR/app.json @@ -118,12 +118,29 @@ "iconPicker.emoji": "ایموجی", "iconPicker.image": "تصویر", "iconPicker.ok": "باشه", + "importApp": "وارد کردن برنامه", "importDSL": "وارد کردن فایل DSL", "importFromDSL": "وارد کردن از DSL", "importFromDSLFile": "از فایل DSL", "importFromDSLUrl": "از URL", "importFromDSLUrlPlaceholder": "لینک DSL را اینجا بچسبانید", "join": "پیوستن به جامعه", + "marketplace.template.categories": "دسته‌بندی‌ها", + "marketplace.template.category.design": "طراحی", + "marketplace.template.category.it": "فناوری اطلاعات", + "marketplace.template.category.knowledge": "دانش", + "marketplace.template.category.marketing": "بازاریابی", + "marketplace.template.category.operations": "عملیات", + "marketplace.template.category.sales": "فروش", + "marketplace.template.category.support": "پشتیبانی", + "marketplace.template.fetchFailed": "دریافت قالب ناموفق بود", + "marketplace.template.importConfirm": "وارد کردن", + "marketplace.template.importFailed": "وارد کردن قالب ناموفق بود", + "marketplace.template.modalTitle": "وارد کردن از Marketplace", + "marketplace.template.overview": "نمای کلی", + "marketplace.template.publishedBy": "توسط", + "marketplace.template.usageCount": "استفاده", + "marketplace.template.viewOnMarketplace": "مشاهده در Marketplace", "maxActiveRequests": "بیشترین درخواست‌های همزمان", "maxActiveRequestsPlaceholder": "برای نامحدود، 0 را وارد کنید", "maxActiveRequestsTip": "حداکثر تعداد درخواست‌های فعال همزمان در هر برنامه (0 برای نامحدود)", diff --git a/web/i18n/fa-IR/workflow.json b/web/i18n/fa-IR/workflow.json index 3210cf8919..c23c781a04 100644 --- a/web/i18n/fa-IR/workflow.json +++ b/web/i18n/fa-IR/workflow.json @@ -229,9 +229,12 @@ "common.previewPlaceholder": "محتوا را در کادر زیر وارد کنید تا اشکال‌زدایی چت‌بات آغاز شود", "common.processData": "پردازش داده‌ها", "common.publish": "انتشار", + "common.publishToMarketplace": "انتشار در Marketplace", + "common.publishToMarketplaceFailed": "انتشار در Marketplace ناموفق بود", "common.publishUpdate": "انتشار به‌روزرسانی", "common.published": "منتشر شده", "common.publishedAt": "منتشر شده در", + "common.publishingToMarketplace": "در حال انتشار...", "common.redo": "بازانجام", "common.restart": "راه‌اندازی مجدد", "common.restore": "بازیابی", diff --git a/web/i18n/fr-FR/app.json b/web/i18n/fr-FR/app.json index f6af8380bb..f90623ce18 100644 --- a/web/i18n/fr-FR/app.json +++ b/web/i18n/fr-FR/app.json @@ -118,12 +118,29 @@ "iconPicker.emoji": "Emoji", "iconPicker.image": "Image", "iconPicker.ok": "OK", + "importApp": "Importer l'App", "importDSL": "Importer le fichier DSL", "importFromDSL": "Importation à partir d'une DSL", "importFromDSLFile": "À partir d’un fichier DSL", "importFromDSLUrl": "À partir de l’URL", "importFromDSLUrlPlaceholder": "Collez le lien DSL ici", "join": "Rejoindre la communauté", + "marketplace.template.categories": "Catégories", + "marketplace.template.category.design": "Design", + "marketplace.template.category.it": "IT", + "marketplace.template.category.knowledge": "Connaissance", + "marketplace.template.category.marketing": "Marketing", + "marketplace.template.category.operations": "Opérations", + "marketplace.template.category.sales": "Ventes", + "marketplace.template.category.support": "Support", + "marketplace.template.fetchFailed": "Échec de la récupération du modèle", + "marketplace.template.importConfirm": "Importer", + "marketplace.template.importFailed": "Échec de l'importation du modèle", + "marketplace.template.modalTitle": "Importer depuis le Marketplace", + "marketplace.template.overview": "Aperçu", + "marketplace.template.publishedBy": "Par", + "marketplace.template.usageCount": "Utilisation", + "marketplace.template.viewOnMarketplace": "Voir sur le Marketplace", "maxActiveRequests": "Nombre maximal de requêtes simultanées", "maxActiveRequestsPlaceholder": "Entrez 0 pour illimité", "maxActiveRequestsTip": "Nombre maximum de requêtes actives concurrentes par application (0 pour illimité)", diff --git a/web/i18n/fr-FR/workflow.json b/web/i18n/fr-FR/workflow.json index da3e69dab3..727c3a91e6 100644 --- a/web/i18n/fr-FR/workflow.json +++ b/web/i18n/fr-FR/workflow.json @@ -229,9 +229,12 @@ "common.previewPlaceholder": "Entrez le contenu dans la boîte ci-dessous pour commencer à déboguer le Chatbot", "common.processData": "Traiter les données", "common.publish": "Publier", + "common.publishToMarketplace": "Publier sur le Marketplace", + "common.publishToMarketplaceFailed": "Échec de la publication sur le Marketplace", "common.publishUpdate": "Publier une mise à jour", "common.published": "Publié", "common.publishedAt": "Publié le", + "common.publishingToMarketplace": "Publication en cours...", "common.redo": "Réexécuter", "common.restart": "Redémarrer", "common.restore": "Restaurer", diff --git a/web/i18n/hi-IN/app.json b/web/i18n/hi-IN/app.json index 3705c4dec1..a7cc347820 100644 --- a/web/i18n/hi-IN/app.json +++ b/web/i18n/hi-IN/app.json @@ -118,12 +118,29 @@ "iconPicker.emoji": "इमोजी", "iconPicker.image": "छवि", "iconPicker.ok": "ठीक है", + "importApp": "ऐप आयात करें", "importDSL": "डीएसएल फ़ाइल आयात करें", "importFromDSL": "DSL से आयात करें", "importFromDSLFile": "डीएसएल फ़ाइल से", "importFromDSLUrl": "यूआरएल से", "importFromDSLUrlPlaceholder": "डीएसएल लिंक यहां पेस्ट करें", "join": "समुदाय में शामिल हों", + "marketplace.template.categories": "श्रेणियाँ", + "marketplace.template.category.design": "डिज़ाइन", + "marketplace.template.category.it": "आईटी", + "marketplace.template.category.knowledge": "ज्ञान", + "marketplace.template.category.marketing": "मार्केटिंग", + "marketplace.template.category.operations": "संचालन", + "marketplace.template.category.sales": "बिक्री", + "marketplace.template.category.support": "समर्थन", + "marketplace.template.fetchFailed": "टेम्पलेट प्राप्त करने में विफल", + "marketplace.template.importConfirm": "आयात करें", + "marketplace.template.importFailed": "टेम्पलेट आयात करने में विफल", + "marketplace.template.modalTitle": "Marketplace से आयात करें", + "marketplace.template.overview": "अवलोकन", + "marketplace.template.publishedBy": "द्वारा", + "marketplace.template.usageCount": "उपयोग", + "marketplace.template.viewOnMarketplace": "Marketplace पर देखें", "maxActiveRequests": "अधिकतम समवर्ती अनुरोध", "maxActiveRequestsPlaceholder": "असीमित के लिए 0 दर्ज करें", "maxActiveRequestsTip": "प्रति ऐप सक्रिय अनुरोधों की अधिकतम संख्या (असीमित के लिए 0)", diff --git a/web/i18n/hi-IN/workflow.json b/web/i18n/hi-IN/workflow.json index 20845af0b8..8b5ea73535 100644 --- a/web/i18n/hi-IN/workflow.json +++ b/web/i18n/hi-IN/workflow.json @@ -229,9 +229,12 @@ "common.previewPlaceholder": "चैटबॉट का डीबग शुरू करने के लिए नीचे दिए गए बॉक्स में सामग्री दर्ज करें", "common.processData": "डेटा प्रोसेस करें", "common.publish": "प्रकाशित करें", + "common.publishToMarketplace": "Marketplace पर प्रकाशित करें", + "common.publishToMarketplaceFailed": "Marketplace पर प्रकाशित करने में विफल", "common.publishUpdate": "अपडेट प्रकाशित करें", "common.published": "प्रकाशित", "common.publishedAt": "प्रकाशित", + "common.publishingToMarketplace": "प्रकाशित हो रहा है...", "common.redo": "फिर से करें", "common.restart": "पुनः आरंभ करें", "common.restore": "पुनर्स्थापित करें", diff --git a/web/i18n/id-ID/app.json b/web/i18n/id-ID/app.json index 23aadc9da6..c47dda1886 100644 --- a/web/i18n/id-ID/app.json +++ b/web/i18n/id-ID/app.json @@ -118,12 +118,29 @@ "iconPicker.emoji": "Emoji", "iconPicker.image": "Citra", "iconPicker.ok": "OK", + "importApp": "Impor Aplikasi", "importDSL": "Impor file DSL", "importFromDSL": "Impor dari DSL", "importFromDSLFile": "Dari file DSL", "importFromDSLUrl": "Dari URL", "importFromDSLUrlPlaceholder": "Tempel tautan DSL di sini", "join": "Bergabunglah dengan komunitas", + "marketplace.template.categories": "Kategori", + "marketplace.template.category.design": "Desain", + "marketplace.template.category.it": "IT", + "marketplace.template.category.knowledge": "Pengetahuan", + "marketplace.template.category.marketing": "Pemasaran", + "marketplace.template.category.operations": "Operasi", + "marketplace.template.category.sales": "Penjualan", + "marketplace.template.category.support": "Dukungan", + "marketplace.template.fetchFailed": "Gagal mengambil templat", + "marketplace.template.importConfirm": "Impor", + "marketplace.template.importFailed": "Gagal mengimpor templat", + "marketplace.template.modalTitle": "Impor dari Marketplace", + "marketplace.template.overview": "Ikhtisar", + "marketplace.template.publishedBy": "Oleh", + "marketplace.template.usageCount": "Penggunaan", + "marketplace.template.viewOnMarketplace": "Lihat di Marketplace", "maxActiveRequests": "Permintaan bersamaan maksimum", "maxActiveRequestsPlaceholder": "Masukkan 0 untuk tidak terbatas", "maxActiveRequestsTip": "Jumlah maksimum permintaan aktif bersamaan per aplikasi (0 untuk tidak terbatas)", diff --git a/web/i18n/id-ID/workflow.json b/web/i18n/id-ID/workflow.json index 2c32f25aab..058c15334b 100644 --- a/web/i18n/id-ID/workflow.json +++ b/web/i18n/id-ID/workflow.json @@ -229,9 +229,12 @@ "common.previewPlaceholder": "Masukkan konten di kotak di bawah ini untuk mulai men-debug Chatbot", "common.processData": "Proses Data", "common.publish": "Menerbitkan", + "common.publishToMarketplace": "Publikasikan ke Marketplace", + "common.publishToMarketplaceFailed": "Gagal mempublikasikan ke Marketplace", "common.publishUpdate": "Publikasikan Pembaruan", "common.published": "Diterbitkan", "common.publishedAt": "Diterbitkan", + "common.publishingToMarketplace": "Mempublikasikan...", "common.redo": "Ulangi", "common.restart": "Restart", "common.restore": "Mengembalikan", diff --git a/web/i18n/it-IT/app.json b/web/i18n/it-IT/app.json index e721ecf655..0719a49571 100644 --- a/web/i18n/it-IT/app.json +++ b/web/i18n/it-IT/app.json @@ -118,12 +118,29 @@ "iconPicker.emoji": "Emoji", "iconPicker.image": "Immagine", "iconPicker.ok": "OK", + "importApp": "Importa App", "importDSL": "Importa file DSL", "importFromDSL": "Importazione da DSL", "importFromDSLFile": "Da file DSL", "importFromDSLUrl": "Dall'URL", "importFromDSLUrlPlaceholder": "Incolla qui il link DSL", "join": "Unisciti alla comunità", + "marketplace.template.categories": "Categorie", + "marketplace.template.category.design": "Design", + "marketplace.template.category.it": "IT", + "marketplace.template.category.knowledge": "Conoscenza", + "marketplace.template.category.marketing": "Marketing", + "marketplace.template.category.operations": "Operazioni", + "marketplace.template.category.sales": "Vendite", + "marketplace.template.category.support": "Supporto", + "marketplace.template.fetchFailed": "Impossibile recuperare il modello", + "marketplace.template.importConfirm": "Importa", + "marketplace.template.importFailed": "Impossibile importare il modello", + "marketplace.template.modalTitle": "Importa dal Marketplace", + "marketplace.template.overview": "Panoramica", + "marketplace.template.publishedBy": "Di", + "marketplace.template.usageCount": "Utilizzo", + "marketplace.template.viewOnMarketplace": "Visualizza sul Marketplace", "maxActiveRequests": "Massimo numero di richieste concorrenti", "maxActiveRequestsPlaceholder": "Inserisci 0 per illimitato", "maxActiveRequestsTip": "Numero massimo di richieste attive concorrenti per app (0 per illimitato)", diff --git a/web/i18n/it-IT/workflow.json b/web/i18n/it-IT/workflow.json index 1c779d1365..fbd3041fb9 100644 --- a/web/i18n/it-IT/workflow.json +++ b/web/i18n/it-IT/workflow.json @@ -229,9 +229,12 @@ "common.previewPlaceholder": "Inserisci contenuto nella casella sottostante per avviare il debug del Chatbot", "common.processData": "Elabora Dati", "common.publish": "Pubblica", + "common.publishToMarketplace": "Pubblica sul Marketplace", + "common.publishToMarketplaceFailed": "Pubblicazione sul Marketplace non riuscita", "common.publishUpdate": "Pubblica aggiornamento", "common.published": "Pubblicato", "common.publishedAt": "Pubblicato", + "common.publishingToMarketplace": "Pubblicazione...", "common.redo": "Ripeti", "common.restart": "Riavvia", "common.restore": "Ripristina", diff --git a/web/i18n/ja-JP/app.json b/web/i18n/ja-JP/app.json index 925095d447..8ccaaababe 100644 --- a/web/i18n/ja-JP/app.json +++ b/web/i18n/ja-JP/app.json @@ -118,12 +118,29 @@ "iconPicker.emoji": "絵文字", "iconPicker.image": "画像", "iconPicker.ok": "OK", + "importApp": "アプリをインポート", "importDSL": "DSL ファイルをインポート", "importFromDSL": "DSL からインポート", "importFromDSLFile": "DSL ファイルから", "importFromDSLUrl": "URL から", "importFromDSLUrlPlaceholder": "DSL リンクをここに貼り付けます", "join": "コミュニティに参加する", + "marketplace.template.categories": "カテゴリ", + "marketplace.template.category.design": "デザイン", + "marketplace.template.category.it": "IT", + "marketplace.template.category.knowledge": "知識", + "marketplace.template.category.marketing": "マーケティング", + "marketplace.template.category.operations": "オペレーション", + "marketplace.template.category.sales": "セールス", + "marketplace.template.category.support": "サポート", + "marketplace.template.fetchFailed": "テンプレートの取得に失敗しました", + "marketplace.template.importConfirm": "インポート", + "marketplace.template.importFailed": "テンプレートのインポートに失敗しました", + "marketplace.template.modalTitle": "マーケットプレイスからインポート", + "marketplace.template.overview": "概要", + "marketplace.template.publishedBy": "提供者", + "marketplace.template.usageCount": "使用数", + "marketplace.template.viewOnMarketplace": "マーケットプレイスで見る", "maxActiveRequests": "最大同時リクエスト数", "maxActiveRequestsPlaceholder": "無制限のために0を入力してください", "maxActiveRequestsTip": "アプリごとの同時アクティブリクエストの最大数(無制限の場合は0)", diff --git a/web/i18n/ja-JP/workflow.json b/web/i18n/ja-JP/workflow.json index 1ee43c17cf..1154a5baba 100644 --- a/web/i18n/ja-JP/workflow.json +++ b/web/i18n/ja-JP/workflow.json @@ -229,9 +229,12 @@ "common.previewPlaceholder": "入力欄にテキストを入力してチャットボットのデバッグを開始", "common.processData": "データ処理", "common.publish": "公開する", + "common.publishToMarketplace": "マーケットプレイスに公開", + "common.publishToMarketplaceFailed": "マーケットプレイスへの公開に失敗しました", "common.publishUpdate": "更新を公開", "common.published": "公開済み", "common.publishedAt": "公開日時", + "common.publishingToMarketplace": "公開中...", "common.redo": "やり直し", "common.restart": "再起動", "common.restore": "復元", diff --git a/web/i18n/ko-KR/app.json b/web/i18n/ko-KR/app.json index 4f29da5f1e..b9dd592f03 100644 --- a/web/i18n/ko-KR/app.json +++ b/web/i18n/ko-KR/app.json @@ -118,12 +118,29 @@ "iconPicker.emoji": "이모지", "iconPicker.image": "이미지", "iconPicker.ok": "확인", + "importApp": "앱 가져오기", "importDSL": "DSL 파일 가져오기", "importFromDSL": "DSL 에서 가져오기", "importFromDSLFile": "DSL 파일에서", "importFromDSLUrl": "URL 에서", "importFromDSLUrlPlaceholder": "여기에 DSL 링크 붙여 넣기", "join": "커뮤니티에 참여하기", + "marketplace.template.categories": "카테고리", + "marketplace.template.category.design": "디자인", + "marketplace.template.category.it": "IT", + "marketplace.template.category.knowledge": "지식", + "marketplace.template.category.marketing": "마케팅", + "marketplace.template.category.operations": "운영", + "marketplace.template.category.sales": "영업", + "marketplace.template.category.support": "지원", + "marketplace.template.fetchFailed": "템플릿 가져오기 실패", + "marketplace.template.importConfirm": "가져오기", + "marketplace.template.importFailed": "템플릿 가져오기 실패", + "marketplace.template.modalTitle": "마켓플레이스에서 가져오기", + "marketplace.template.overview": "개요", + "marketplace.template.publishedBy": "제공:", + "marketplace.template.usageCount": "사용량", + "marketplace.template.viewOnMarketplace": "마켓플레이스에서 보기", "maxActiveRequests": "동시 최대 요청 수", "maxActiveRequestsPlaceholder": "무제한 사용을 원하시면 0을 입력하세요.", "maxActiveRequestsTip": "앱당 최대 동시 활성 요청 수(무제한은 0)", diff --git a/web/i18n/ko-KR/workflow.json b/web/i18n/ko-KR/workflow.json index b6291e4366..a80c34b294 100644 --- a/web/i18n/ko-KR/workflow.json +++ b/web/i18n/ko-KR/workflow.json @@ -229,9 +229,12 @@ "common.previewPlaceholder": "디버깅을 시작하려면 아래 상자에 내용을 입력하세요", "common.processData": "데이터 처리", "common.publish": "게시하기", + "common.publishToMarketplace": "마켓플레이스에 게시", + "common.publishToMarketplaceFailed": "마켓플레이스 게시 실패", "common.publishUpdate": "업데이트 게시", "common.published": "게시됨", "common.publishedAt": "발행일", + "common.publishingToMarketplace": "게시 중...", "common.redo": "다시 실행", "common.restart": "재시작", "common.restore": "복원", diff --git a/web/i18n/nl-NL/app.json b/web/i18n/nl-NL/app.json index 0ad608d53c..9bd50b5b92 100644 --- a/web/i18n/nl-NL/app.json +++ b/web/i18n/nl-NL/app.json @@ -118,12 +118,29 @@ "iconPicker.emoji": "Emoji", "iconPicker.image": "Image", "iconPicker.ok": "OK", + "importApp": "App importeren", "importDSL": "Import DSL file", "importFromDSL": "Import from DSL", "importFromDSLFile": "From DSL file", "importFromDSLUrl": "From URL", "importFromDSLUrlPlaceholder": "Paste DSL link here", "join": "Join the community", + "marketplace.template.categories": "Categorieën", + "marketplace.template.category.design": "Ontwerp", + "marketplace.template.category.it": "IT", + "marketplace.template.category.knowledge": "Kennis", + "marketplace.template.category.marketing": "Marketing", + "marketplace.template.category.operations": "Operaties", + "marketplace.template.category.sales": "Verkoop", + "marketplace.template.category.support": "Ondersteuning", + "marketplace.template.fetchFailed": "Template ophalen mislukt", + "marketplace.template.importConfirm": "Importeren", + "marketplace.template.importFailed": "Template importeren mislukt", + "marketplace.template.modalTitle": "Importeren vanuit Marketplace", + "marketplace.template.overview": "Overzicht", + "marketplace.template.publishedBy": "Door", + "marketplace.template.usageCount": "Gebruik", + "marketplace.template.viewOnMarketplace": "Bekijken op Marketplace", "maxActiveRequests": "Max concurrent requests", "maxActiveRequestsPlaceholder": "Enter 0 for unlimited", "maxActiveRequestsTip": "Maximum number of concurrent active requests per app (0 for unlimited)", diff --git a/web/i18n/nl-NL/workflow.json b/web/i18n/nl-NL/workflow.json index c3d5824ef7..c8e8753eb4 100644 --- a/web/i18n/nl-NL/workflow.json +++ b/web/i18n/nl-NL/workflow.json @@ -229,9 +229,12 @@ "common.previewPlaceholder": "Enter content in the box below to start debugging the Chatbot", "common.processData": "Process Data", "common.publish": "Publish", + "common.publishToMarketplace": "Publiceren op Marketplace", + "common.publishToMarketplaceFailed": "Publiceren op Marketplace mislukt", "common.publishUpdate": "Publish Update", "common.published": "Published", "common.publishedAt": "Published", + "common.publishingToMarketplace": "Publiceren...", "common.redo": "Redo", "common.restart": "Restart", "common.restore": "Restore", diff --git a/web/i18n/pl-PL/app.json b/web/i18n/pl-PL/app.json index a3ae06e3cd..0f6f5cd298 100644 --- a/web/i18n/pl-PL/app.json +++ b/web/i18n/pl-PL/app.json @@ -118,12 +118,29 @@ "iconPicker.emoji": "Emoji", "iconPicker.image": "Obraz", "iconPicker.ok": "OK", + "importApp": "Importuj aplikację", "importDSL": "Importuj plik DSL", "importFromDSL": "Importowanie z DSL", "importFromDSLFile": "Z pliku DSL", "importFromDSLUrl": "Z adresu URL", "importFromDSLUrlPlaceholder": "Wklej tutaj link DSL", "join": "Dołącz do społeczności", + "marketplace.template.categories": "Kategorie", + "marketplace.template.category.design": "Design", + "marketplace.template.category.it": "IT", + "marketplace.template.category.knowledge": "Wiedza", + "marketplace.template.category.marketing": "Marketing", + "marketplace.template.category.operations": "Operacje", + "marketplace.template.category.sales": "Sprzedaż", + "marketplace.template.category.support": "Wsparcie", + "marketplace.template.fetchFailed": "Nie udało się pobrać szablonu", + "marketplace.template.importConfirm": "Importuj", + "marketplace.template.importFailed": "Nie udało się zaimportować szablonu", + "marketplace.template.modalTitle": "Importuj z Marketplace", + "marketplace.template.overview": "Przegląd", + "marketplace.template.publishedBy": "Przez", + "marketplace.template.usageCount": "Użycie", + "marketplace.template.viewOnMarketplace": "Zobacz na Marketplace", "maxActiveRequests": "Maksymalne równoczesne żądania", "maxActiveRequestsPlaceholder": "Wprowadź 0, aby uzyskać nielimitowane", "maxActiveRequestsTip": "Maksymalna liczba jednoczesnych aktywnych żądań na aplikację (0 dla nieograniczonej)", diff --git a/web/i18n/pl-PL/workflow.json b/web/i18n/pl-PL/workflow.json index 6b0bda1ff8..805960a851 100644 --- a/web/i18n/pl-PL/workflow.json +++ b/web/i18n/pl-PL/workflow.json @@ -229,9 +229,12 @@ "common.previewPlaceholder": "Wprowadź treść w poniższym polu, aby rozpocząć debugowanie Chatbota", "common.processData": "Przetwórz dane", "common.publish": "Opublikuj", + "common.publishToMarketplace": "Publikuj na Marketplace", + "common.publishToMarketplaceFailed": "Nie udało się opublikować na Marketplace", "common.publishUpdate": "Opublikuj aktualizację", "common.published": "Opublikowane", "common.publishedAt": "Opublikowane", + "common.publishingToMarketplace": "Publikowanie...", "common.redo": "Ponów", "common.restart": "Uruchom ponownie", "common.restore": "Przywróć", diff --git a/web/i18n/pt-BR/app.json b/web/i18n/pt-BR/app.json index 43447c970c..3c59423e99 100644 --- a/web/i18n/pt-BR/app.json +++ b/web/i18n/pt-BR/app.json @@ -118,12 +118,29 @@ "iconPicker.emoji": "Emoji", "iconPicker.image": "Imagem", "iconPicker.ok": "OK", + "importApp": "Importar App", "importDSL": "Importar arquivo DSL", "importFromDSL": "Importar de DSL", "importFromDSLFile": "Do arquivo DSL", "importFromDSLUrl": "Do URL", "importFromDSLUrlPlaceholder": "Cole o link DSL aqui", "join": "Participe da comunidade", + "marketplace.template.categories": "Categorias", + "marketplace.template.category.design": "Design", + "marketplace.template.category.it": "TI", + "marketplace.template.category.knowledge": "Conhecimento", + "marketplace.template.category.marketing": "Marketing", + "marketplace.template.category.operations": "Operações", + "marketplace.template.category.sales": "Vendas", + "marketplace.template.category.support": "Suporte", + "marketplace.template.fetchFailed": "Falha ao buscar modelo", + "marketplace.template.importConfirm": "Importar", + "marketplace.template.importFailed": "Falha ao importar modelo", + "marketplace.template.modalTitle": "Importar do Marketplace", + "marketplace.template.overview": "Visão geral", + "marketplace.template.publishedBy": "Por", + "marketplace.template.usageCount": "Uso", + "marketplace.template.viewOnMarketplace": "Ver no Marketplace", "maxActiveRequests": "Máximo de solicitações simultâneas", "maxActiveRequestsPlaceholder": "Digite 0 para ilimitado", "maxActiveRequestsTip": "Número máximo de solicitações ativas simultâneas por aplicativo (0 para ilimitado)", diff --git a/web/i18n/pt-BR/workflow.json b/web/i18n/pt-BR/workflow.json index a8a7511100..de6c882e0c 100644 --- a/web/i18n/pt-BR/workflow.json +++ b/web/i18n/pt-BR/workflow.json @@ -229,9 +229,12 @@ "common.previewPlaceholder": "Digite o conteúdo na caixa abaixo para começar a depurar o Chatbot", "common.processData": "Processar dados", "common.publish": "Publicar", + "common.publishToMarketplace": "Publicar no Marketplace", + "common.publishToMarketplaceFailed": "Falha ao publicar no Marketplace", "common.publishUpdate": "Publicar Atualização", "common.published": "Publicado", "common.publishedAt": "Publicado em", + "common.publishingToMarketplace": "Publicando...", "common.redo": "Refazer", "common.restart": "Reiniciar", "common.restore": "Restaurar", diff --git a/web/i18n/ro-RO/app.json b/web/i18n/ro-RO/app.json index cfa0b8aedc..f93e4f10a0 100644 --- a/web/i18n/ro-RO/app.json +++ b/web/i18n/ro-RO/app.json @@ -118,12 +118,29 @@ "iconPicker.emoji": "Emoji", "iconPicker.image": "Imagine", "iconPicker.ok": "OK", + "importApp": "Importați aplicația", "importDSL": "Importă fișier DSL", "importFromDSL": "Import din DSL", "importFromDSLFile": "Din fișierul DSL", "importFromDSLUrl": "De la URL", "importFromDSLUrlPlaceholder": "Lipiți linkul DSL aici", "join": "Alătură-te comunității", + "marketplace.template.categories": "Categorii", + "marketplace.template.category.design": "Design", + "marketplace.template.category.it": "IT", + "marketplace.template.category.knowledge": "Cunoaștere", + "marketplace.template.category.marketing": "Marketing", + "marketplace.template.category.operations": "Operațiuni", + "marketplace.template.category.sales": "Vânzări", + "marketplace.template.category.support": "Suport", + "marketplace.template.fetchFailed": "Eroare la obținerea șablonului", + "marketplace.template.importConfirm": "Importați", + "marketplace.template.importFailed": "Eroare la importul șablonului", + "marketplace.template.modalTitle": "Importați din Marketplace", + "marketplace.template.overview": "Prezentare generală", + "marketplace.template.publishedBy": "De", + "marketplace.template.usageCount": "Utilizare", + "marketplace.template.viewOnMarketplace": "Vizualizați pe Marketplace", "maxActiveRequests": "Maxime cereri simultane", "maxActiveRequestsPlaceholder": "Introduceți 0 pentru nelimitat", "maxActiveRequestsTip": "Numărul maxim de cereri active concurente pe aplicație (0 pentru nelimitat)", diff --git a/web/i18n/ro-RO/workflow.json b/web/i18n/ro-RO/workflow.json index c15e8508ab..7b551294e8 100644 --- a/web/i18n/ro-RO/workflow.json +++ b/web/i18n/ro-RO/workflow.json @@ -229,9 +229,12 @@ "common.previewPlaceholder": "Introduceți conținutul în caseta de mai jos pentru a începe depanarea Chatbotului", "common.processData": "Procesează date", "common.publish": "Publică", + "common.publishToMarketplace": "Publicați pe Marketplace", + "common.publishToMarketplaceFailed": "Eroare la publicarea pe Marketplace", "common.publishUpdate": "Publicați actualizarea", "common.published": "Publicat", "common.publishedAt": "Publicat la", + "common.publishingToMarketplace": "Se publică...", "common.redo": "Refă", "common.restart": "Repornește", "common.restore": "Restaurează", diff --git a/web/i18n/ru-RU/app.json b/web/i18n/ru-RU/app.json index 7b53ea61fb..8a9327e81e 100644 --- a/web/i18n/ru-RU/app.json +++ b/web/i18n/ru-RU/app.json @@ -118,12 +118,29 @@ "iconPicker.emoji": "Эмодзи", "iconPicker.image": "Изображение", "iconPicker.ok": "ОК", + "importApp": "Импортировать приложение", "importDSL": "Импортировать файл DSL", "importFromDSL": "Импортировать из DSL", "importFromDSLFile": "Из файла DSL", "importFromDSLUrl": "Из URL", "importFromDSLUrlPlaceholder": "Вставьте ссылку DSL сюда", "join": "Присоединяйтесь к сообществу", + "marketplace.template.categories": "Категории", + "marketplace.template.category.design": "Дизайн", + "marketplace.template.category.it": "ИТ", + "marketplace.template.category.knowledge": "Знания", + "marketplace.template.category.marketing": "Маркетинг", + "marketplace.template.category.operations": "Операции", + "marketplace.template.category.sales": "Продажи", + "marketplace.template.category.support": "Поддержка", + "marketplace.template.fetchFailed": "Не удалось получить шаблон", + "marketplace.template.importConfirm": "Импортировать", + "marketplace.template.importFailed": "Не удалось импортировать шаблон", + "marketplace.template.modalTitle": "Импортировать из Marketplace", + "marketplace.template.overview": "Обзор", + "marketplace.template.publishedBy": "От", + "marketplace.template.usageCount": "Использование", + "marketplace.template.viewOnMarketplace": "Открыть в Marketplace", "maxActiveRequests": "Максимальное количество параллельных запросов", "maxActiveRequestsPlaceholder": "Введите 0 для неограниченного количества", "maxActiveRequestsTip": "Максимальное количество одновременно активных запросов на одно приложение (0 для неограниченного количества)", diff --git a/web/i18n/ru-RU/workflow.json b/web/i18n/ru-RU/workflow.json index 55622ec730..89d2657208 100644 --- a/web/i18n/ru-RU/workflow.json +++ b/web/i18n/ru-RU/workflow.json @@ -229,9 +229,12 @@ "common.previewPlaceholder": "Введите текст в поле ниже, чтобы начать отладку чат-бота", "common.processData": "Обработка данных", "common.publish": "Опубликовать", + "common.publishToMarketplace": "Опубликовать в Marketplace", + "common.publishToMarketplaceFailed": "Не удалось опубликовать в Marketplace", "common.publishUpdate": "Опубликовать обновление", "common.published": "Опубликовано", "common.publishedAt": "Опубликовано", + "common.publishingToMarketplace": "Публикация...", "common.redo": "Повторить", "common.restart": "Перезапустить", "common.restore": "Восстановить", diff --git a/web/i18n/sl-SI/app.json b/web/i18n/sl-SI/app.json index ce09d32059..a8a14d7488 100644 --- a/web/i18n/sl-SI/app.json +++ b/web/i18n/sl-SI/app.json @@ -118,12 +118,29 @@ "iconPicker.emoji": "Emoji", "iconPicker.image": "Slika", "iconPicker.ok": "V redu", + "importApp": "Uvozi aplikacijo", "importDSL": "Uvozi datoteko DSL", "importFromDSL": "Uvozi iz DSL", "importFromDSLFile": "Iz datoteke DSL", "importFromDSLUrl": "Iz URL-ja", "importFromDSLUrlPlaceholder": "Tukaj prilepi povezavo DSL", "join": "Pridruži se skupnosti", + "marketplace.template.categories": "Kategorije", + "marketplace.template.category.design": "Oblikovanje", + "marketplace.template.category.it": "IT", + "marketplace.template.category.knowledge": "Znanje", + "marketplace.template.category.marketing": "Trženje", + "marketplace.template.category.operations": "Operacije", + "marketplace.template.category.sales": "Prodaja", + "marketplace.template.category.support": "Podpora", + "marketplace.template.fetchFailed": "Pridobivanje predloge ni uspelo", + "marketplace.template.importConfirm": "Uvozi", + "marketplace.template.importFailed": "Uvoz predloge ni uspel", + "marketplace.template.modalTitle": "Uvozi iz Marketplace", + "marketplace.template.overview": "Pregled", + "marketplace.template.publishedBy": "Avtor", + "marketplace.template.usageCount": "Uporaba", + "marketplace.template.viewOnMarketplace": "Ogled na Marketplace", "maxActiveRequests": "Maksimalno število hkratnih zahtevkov", "maxActiveRequestsPlaceholder": "Vnesite 0 za neomejeno", "maxActiveRequestsTip": "Največje število hkrati aktivnih zahtevkov na aplikacijo (0 za neomejeno)", diff --git a/web/i18n/sl-SI/workflow.json b/web/i18n/sl-SI/workflow.json index 7ea8dedec4..a7c2914626 100644 --- a/web/i18n/sl-SI/workflow.json +++ b/web/i18n/sl-SI/workflow.json @@ -229,9 +229,12 @@ "common.previewPlaceholder": "Vnesite vsebino v spodnje polje, da začnete odpravljati napake v chatbotu", "common.processData": "Obdelava podatkov", "common.publish": "Objavi", + "common.publishToMarketplace": "Objavi na Marketplace", + "common.publishToMarketplaceFailed": "Objava na Marketplace ni uspela", "common.publishUpdate": "Objavi posodobitev", "common.published": "Objavljeno", "common.publishedAt": "Objavljeno", + "common.publishingToMarketplace": "Objavljanje...", "common.redo": "Ponovno naredi", "common.restart": "Znova zaženi", "common.restore": "Obnovi", diff --git a/web/i18n/th-TH/app.json b/web/i18n/th-TH/app.json index d59a5b8505..624d7b9ec9 100644 --- a/web/i18n/th-TH/app.json +++ b/web/i18n/th-TH/app.json @@ -118,12 +118,29 @@ "iconPicker.emoji": "อิโมจิ", "iconPicker.image": "ภาพ", "iconPicker.ok": "ตกลง, ได้", + "importApp": "นำเข้าแอป", "importDSL": "นําเข้าไฟล์ DSL", "importFromDSL": "นําเข้าจาก DSL", "importFromDSLFile": "จากไฟล์ DSL", "importFromDSLUrl": "จาก URL", "importFromDSLUrlPlaceholder": "วางลิงค์ DSL ที่นี่", "join": "เข้าร่วมชุมชนนักพัฒนา", + "marketplace.template.categories": "หมวดหมู่", + "marketplace.template.category.design": "การออกแบบ", + "marketplace.template.category.it": "ไอที", + "marketplace.template.category.knowledge": "ความรู้", + "marketplace.template.category.marketing": "การตลาด", + "marketplace.template.category.operations": "การดำเนินงาน", + "marketplace.template.category.sales": "การขาย", + "marketplace.template.category.support": "การสนับสนุน", + "marketplace.template.fetchFailed": "ดึงข้อมูลเทมเพลตล้มเหลว", + "marketplace.template.importConfirm": "นำเข้า", + "marketplace.template.importFailed": "นำเข้าเทมเพลตล้มเหลว", + "marketplace.template.modalTitle": "นำเข้าจาก Marketplace", + "marketplace.template.overview": "ภาพรวม", + "marketplace.template.publishedBy": "โดย", + "marketplace.template.usageCount": "การใช้งาน", + "marketplace.template.viewOnMarketplace": "ดูบน Marketplace", "maxActiveRequests": "จำนวนคำขอพร้อมกันสูงสุด", "maxActiveRequestsPlaceholder": "ใส่ 0 สำหรับไม่จำกัด", "maxActiveRequestsTip": "จำนวนการร้องขอที่ใช้งานพร้อมกันสูงสุดต่อแอป (0 หมายถึงไม่จำกัด)", diff --git a/web/i18n/th-TH/workflow.json b/web/i18n/th-TH/workflow.json index e1280cf438..d8a9b53f2a 100644 --- a/web/i18n/th-TH/workflow.json +++ b/web/i18n/th-TH/workflow.json @@ -229,9 +229,12 @@ "common.previewPlaceholder": "ป้อนเนื้อหาในช่องด้านล่างเพื่อเริ่มแก้ไขข้อบกพร่องของแชทบอท", "common.processData": "ประมวลผลข้อมูล", "common.publish": "ตีพิมพ์", + "common.publishToMarketplace": "เผยแพร่ไปยัง Marketplace", + "common.publishToMarketplaceFailed": "เผยแพร่ไปยัง Marketplace ล้มเหลว", "common.publishUpdate": "เผยแพร่การอัปเดต", "common.published": "เผย แพร่", "common.publishedAt": "เผย แพร่", + "common.publishingToMarketplace": "กำลังเผยแพร่...", "common.redo": "พร้อม", "common.restart": "เริ่มใหม่", "common.restore": "ซ่อมแซม", diff --git a/web/i18n/tr-TR/app.json b/web/i18n/tr-TR/app.json index 2978f7cffd..aa10f954e9 100644 --- a/web/i18n/tr-TR/app.json +++ b/web/i18n/tr-TR/app.json @@ -118,12 +118,29 @@ "iconPicker.emoji": "Emoji", "iconPicker.image": "Görsel", "iconPicker.ok": "Tamam", + "importApp": "Uygulamayı İçe Aktar", "importDSL": "DSL dosyasını içe aktar", "importFromDSL": "DSL içe aktar", "importFromDSLFile": "DSL dosyasından", "importFromDSLUrl": "URL'den", "importFromDSLUrlPlaceholder": "DSL bağlantısını buraya yapıştır", "join": "Topluluğa katıl", + "marketplace.template.categories": "Kategoriler", + "marketplace.template.category.design": "Tasarım", + "marketplace.template.category.it": "BT", + "marketplace.template.category.knowledge": "Bilgi", + "marketplace.template.category.marketing": "Pazarlama", + "marketplace.template.category.operations": "Operasyonlar", + "marketplace.template.category.sales": "Satış", + "marketplace.template.category.support": "Destek", + "marketplace.template.fetchFailed": "Şablon alınamadı", + "marketplace.template.importConfirm": "İçe Aktar", + "marketplace.template.importFailed": "Şablon içe aktarılamadı", + "marketplace.template.modalTitle": "Marketplace'den İçe Aktar", + "marketplace.template.overview": "Genel Bakış", + "marketplace.template.publishedBy": "Yayıncı", + "marketplace.template.usageCount": "Kullanım", + "marketplace.template.viewOnMarketplace": "Marketplace'de Görüntüle", "maxActiveRequests": "Maksimum eş zamanlı istekler", "maxActiveRequestsPlaceholder": "Sınırsız için 0 girin", "maxActiveRequestsTip": "Her uygulama için maksimum eşzamanlı aktif istek sayısı (sınırsız için 0)", diff --git a/web/i18n/tr-TR/workflow.json b/web/i18n/tr-TR/workflow.json index 54ee28cf1c..7cd69d7df1 100644 --- a/web/i18n/tr-TR/workflow.json +++ b/web/i18n/tr-TR/workflow.json @@ -229,9 +229,12 @@ "common.previewPlaceholder": "Sohbet Robotunu hata ayıklamak için aşağıdaki kutuya içerik girin", "common.processData": "Veriyi İşle", "common.publish": "Yayınla", + "common.publishToMarketplace": "Marketplace'de Yayınla", + "common.publishToMarketplaceFailed": "Marketplace'de Yayınlama Başarısız", "common.publishUpdate": "Güncellemeyi Yayınla", "common.published": "Yayınlandı", "common.publishedAt": "Yayınlandı", + "common.publishingToMarketplace": "Yayınlanıyor...", "common.redo": "Yinele", "common.restart": "Yeniden Başlat", "common.restore": "Geri Yükle", diff --git a/web/i18n/uk-UA/app.json b/web/i18n/uk-UA/app.json index f224f0c31f..f88e1e60f9 100644 --- a/web/i18n/uk-UA/app.json +++ b/web/i18n/uk-UA/app.json @@ -118,12 +118,29 @@ "iconPicker.emoji": "Емодзі", "iconPicker.image": "Зображення", "iconPicker.ok": "OK", + "importApp": "Імпортувати додаток", "importDSL": "Імпортувати файл DSL", "importFromDSL": "Імпорт з DSL", "importFromDSLFile": "З DSL-файлу", "importFromDSLUrl": "З URL", "importFromDSLUrlPlaceholder": "Вставте посилання на DSL тут", "join": "Приєднуйтесь до спільноти", + "marketplace.template.categories": "Категорії", + "marketplace.template.category.design": "Дизайн", + "marketplace.template.category.it": "ІТ", + "marketplace.template.category.knowledge": "Знання", + "marketplace.template.category.marketing": "Маркетинг", + "marketplace.template.category.operations": "Операції", + "marketplace.template.category.sales": "Продажі", + "marketplace.template.category.support": "Підтримка", + "marketplace.template.fetchFailed": "Не вдалося отримати шаблон", + "marketplace.template.importConfirm": "Імпортувати", + "marketplace.template.importFailed": "Не вдалося імпортувати шаблон", + "marketplace.template.modalTitle": "Імпортувати з Marketplace", + "marketplace.template.overview": "Огляд", + "marketplace.template.publishedBy": "Від", + "marketplace.template.usageCount": "Використання", + "marketplace.template.viewOnMarketplace": "Переглянути на Marketplace", "maxActiveRequests": "Максимальна кількість одночасних запитів", "maxActiveRequestsPlaceholder": "Введіть 0 для необмеженого", "maxActiveRequestsTip": "Максимальна кількість одночасних активних запитів на додаток (0 для необмеженої кількості)", diff --git a/web/i18n/uk-UA/workflow.json b/web/i18n/uk-UA/workflow.json index 94f869845e..44d527618e 100644 --- a/web/i18n/uk-UA/workflow.json +++ b/web/i18n/uk-UA/workflow.json @@ -229,9 +229,12 @@ "common.previewPlaceholder": "Введіть вміст у поле нижче, щоб розпочати налагодження чат-бота", "common.processData": "Обробити дані", "common.publish": "Опублікувати", + "common.publishToMarketplace": "Опублікувати на Marketplace", + "common.publishToMarketplaceFailed": "Не вдалося опублікувати на Marketplace", "common.publishUpdate": "Опублікувати оновлення", "common.published": "Опубліковано", "common.publishedAt": "Опубліковано о", + "common.publishingToMarketplace": "Публікація...", "common.redo": "Повторити", "common.restart": "Перезапустити", "common.restore": "Відновити", diff --git a/web/i18n/vi-VN/app.json b/web/i18n/vi-VN/app.json index 399d2dccf5..2be7906afb 100644 --- a/web/i18n/vi-VN/app.json +++ b/web/i18n/vi-VN/app.json @@ -118,12 +118,29 @@ "iconPicker.emoji": "Biểu tượng cảm xúc", "iconPicker.image": "Hình ảnh", "iconPicker.ok": "Đồng ý", + "importApp": "Nhập App", "importDSL": "Nhập tệp DSL", "importFromDSL": "Nhập từ DSL", "importFromDSLFile": "Từ tệp DSL", "importFromDSLUrl": "Từ URL", "importFromDSLUrlPlaceholder": "Dán liên kết DSL vào đây", "join": "Tham gia cộng đồng", + "marketplace.template.categories": "Danh mục", + "marketplace.template.category.design": "Thiết kế", + "marketplace.template.category.it": "CNTT", + "marketplace.template.category.knowledge": "Kiến thức", + "marketplace.template.category.marketing": "Marketing", + "marketplace.template.category.operations": "Vận hành", + "marketplace.template.category.sales": "Bán hàng", + "marketplace.template.category.support": "Hỗ trợ", + "marketplace.template.fetchFailed": "Không thể lấy mẫu", + "marketplace.template.importConfirm": "Nhập", + "marketplace.template.importFailed": "Không thể nhập mẫu", + "marketplace.template.modalTitle": "Nhập từ Marketplace", + "marketplace.template.overview": "Tổng quan", + "marketplace.template.publishedBy": "Bởi", + "marketplace.template.usageCount": "Lượt sử dụng", + "marketplace.template.viewOnMarketplace": "Xem trên Marketplace", "maxActiveRequests": "Số yêu cầu đồng thời tối đa", "maxActiveRequestsPlaceholder": "Nhập 0 để không giới hạn", "maxActiveRequestsTip": "Số yêu cầu hoạt động đồng thời tối đa cho mỗi ứng dụng (0 để không giới hạn)", diff --git a/web/i18n/vi-VN/workflow.json b/web/i18n/vi-VN/workflow.json index 377a794464..231c01bc82 100644 --- a/web/i18n/vi-VN/workflow.json +++ b/web/i18n/vi-VN/workflow.json @@ -229,9 +229,12 @@ "common.previewPlaceholder": "Nhập nội dung vào hộp bên dưới để bắt đầu gỡ lỗi Chatbot", "common.processData": "Xử lý dữ liệu", "common.publish": "Xuất bản", + "common.publishToMarketplace": "Xuất bản lên Marketplace", + "common.publishToMarketplaceFailed": "Xuất bản lên Marketplace thất bại", "common.publishUpdate": "Cập nhật xuất bản", "common.published": "Đã xuất bản", "common.publishedAt": "Đã xuất bản lúc", + "common.publishingToMarketplace": "Đang xuất bản...", "common.redo": "Làm lại", "common.restart": "Khởi động lại", "common.restore": "Khôi phục", diff --git a/web/i18n/zh-Hans/app.json b/web/i18n/zh-Hans/app.json index 278a1b782d..8f46a4433e 100644 --- a/web/i18n/zh-Hans/app.json +++ b/web/i18n/zh-Hans/app.json @@ -118,12 +118,29 @@ "iconPicker.emoji": "表情符号", "iconPicker.image": "图片", "iconPicker.ok": "确认", + "importApp": "导入应用", "importDSL": "导入 DSL 文件", "importFromDSL": "导入 DSL", "importFromDSLFile": "文件", "importFromDSLUrl": "URL", "importFromDSLUrlPlaceholder": "输入 DSL 文件的 URL", "join": "参与社区", + "marketplace.template.categories": "分类", + "marketplace.template.category.design": "设计", + "marketplace.template.category.it": "IT", + "marketplace.template.category.knowledge": "知识", + "marketplace.template.category.marketing": "营销", + "marketplace.template.category.operations": "运营", + "marketplace.template.category.sales": "销售", + "marketplace.template.category.support": "支持", + "marketplace.template.fetchFailed": "获取模板失败", + "marketplace.template.importConfirm": "导入", + "marketplace.template.importFailed": "导入模板失败", + "marketplace.template.modalTitle": "从市场导入", + "marketplace.template.overview": "概述", + "marketplace.template.publishedBy": "来自", + "marketplace.template.usageCount": "使用次数", + "marketplace.template.viewOnMarketplace": "在市场查看", "maxActiveRequests": "最大活跃请求数", "maxActiveRequestsPlaceholder": "0 表示不限制", "maxActiveRequestsTip": "当前应用的最大活跃请求数(0 表示不限制)", diff --git a/web/i18n/zh-Hans/workflow.json b/web/i18n/zh-Hans/workflow.json index 817bf580c9..f9b8e9d652 100644 --- a/web/i18n/zh-Hans/workflow.json +++ b/web/i18n/zh-Hans/workflow.json @@ -231,9 +231,12 @@ "common.publish": "发布", "common.publishAsEvaluationWorkflow": "发布为评测工作流", "common.publishAsStandardWorkflow": "发布为标准工作流", + "common.publishToMarketplace": "发布到市场", + "common.publishToMarketplaceFailed": "发布到市场失败", "common.publishUpdate": "发布更新", "common.published": "已发布", "common.publishedAt": "发布于", + "common.publishingToMarketplace": "发布中...", "common.redo": "重做", "common.restart": "重新开始", "common.restore": "恢复", diff --git a/web/i18n/zh-Hant/app.json b/web/i18n/zh-Hant/app.json index a7fbcfd65f..7c485b6520 100644 --- a/web/i18n/zh-Hant/app.json +++ b/web/i18n/zh-Hant/app.json @@ -118,12 +118,29 @@ "iconPicker.emoji": "表情符號", "iconPicker.image": "圖片", "iconPicker.ok": "確認", + "importApp": "匯入應用", "importDSL": "匯入 DSL 檔案", "importFromDSL": "從 DSL 導入", "importFromDSLFile": "從 DSL 檔", "importFromDSLUrl": "寄件者 URL", "importFromDSLUrlPlaceholder": "在此處貼上 DSL 連結", "join": "參與社群", + "marketplace.template.categories": "分類", + "marketplace.template.category.design": "設計", + "marketplace.template.category.it": "IT", + "marketplace.template.category.knowledge": "知識", + "marketplace.template.category.marketing": "行銷", + "marketplace.template.category.operations": "營運", + "marketplace.template.category.sales": "銷售", + "marketplace.template.category.support": "支援", + "marketplace.template.fetchFailed": "獲取模板失敗", + "marketplace.template.importConfirm": "匯入", + "marketplace.template.importFailed": "匯入模板失敗", + "marketplace.template.modalTitle": "從市場匯入", + "marketplace.template.overview": "概覽", + "marketplace.template.publishedBy": "由", + "marketplace.template.usageCount": "使用次數", + "marketplace.template.viewOnMarketplace": "在市場上查看", "maxActiveRequests": "同時最大請求數", "maxActiveRequestsPlaceholder": "輸入 0 以表示無限", "maxActiveRequestsTip": "每個應用程式可同時活躍請求的最大數量(0為無限制)", diff --git a/web/i18n/zh-Hant/workflow.json b/web/i18n/zh-Hant/workflow.json index 1e10badec0..9d296250db 100644 --- a/web/i18n/zh-Hant/workflow.json +++ b/web/i18n/zh-Hant/workflow.json @@ -229,9 +229,12 @@ "common.previewPlaceholder": "在下面的框中輸入內容開始測試聊天機器人", "common.processData": "資料處理", "common.publish": "發佈", + "common.publishToMarketplace": "發佈到市場", + "common.publishToMarketplaceFailed": "發佈到市場失敗", "common.publishUpdate": "發布更新", "common.published": "已發佈", "common.publishedAt": "發佈於", + "common.publishingToMarketplace": "發佈中...", "common.redo": "重做", "common.restart": "重新開始", "common.restore": "恢復", diff --git a/web/next/navigation.ts b/web/next/navigation.ts index ec7c112645..f8ff821d1f 100644 --- a/web/next/navigation.ts +++ b/web/next/navigation.ts @@ -1,4 +1,5 @@ export { + redirect, useParams, usePathname, useRouter, @@ -6,3 +7,4 @@ export { useSelectedLayoutSegment, useSelectedLayoutSegments, } from 'next/navigation' +export type { ReadonlyURLSearchParams } from 'next/navigation' diff --git a/web/service/__tests__/base.spec.ts b/web/service/__tests__/base.spec.ts new file mode 100644 index 0000000000..a4d1dcbfe7 --- /dev/null +++ b/web/service/__tests__/base.spec.ts @@ -0,0 +1,68 @@ +import { buildSigninUrlWithRedirect } from '../base' + +vi.mock('@/utils/var', () => ({ + basePath: '/app', + API_PREFIX: '/console/api', + PUBLIC_API_PREFIX: '/api', + IS_CE_EDITION: false, +})) + +describe('buildSigninUrlWithRedirect', () => { + const originalLocation = globalThis.location + + beforeEach(() => { + Object.defineProperty(globalThis, 'location', { + value: { + origin: 'https://example.com', + pathname: '/apps', + href: 'https://example.com/apps', + }, + writable: true, + configurable: true, + }) + }) + + afterEach(() => { + Object.defineProperty(globalThis, 'location', { + value: originalLocation, + writable: true, + configurable: true, + }) + }) + + it('should return plain signin URL for non-OAuth pages', () => { + const url = buildSigninUrlWithRedirect() + expect(url).toBe('https://example.com/app/signin') + }) + + it('should append redirect_url for OAuth authorize pages', () => { + const oauthHref = 'https://example.com/account/oauth/authorize?client_id=abc&state=xyz' + Object.defineProperty(globalThis, 'location', { + value: { + origin: 'https://example.com', + pathname: '/account/oauth/authorize', + href: oauthHref, + }, + writable: true, + configurable: true, + }) + + const url = buildSigninUrlWithRedirect() + expect(url).toBe(`https://example.com/app/signin?redirect_url=${encodeURIComponent(oauthHref)}`) + }) + + it('should not include redirect_url for other paths containing partial match', () => { + Object.defineProperty(globalThis, 'location', { + value: { + origin: 'https://example.com', + pathname: '/settings/oauth', + href: 'https://example.com/settings/oauth', + }, + writable: true, + configurable: true, + }) + + const url = buildSigninUrlWithRedirect() + expect(url).toBe('https://example.com/app/signin') + }) +}) diff --git a/web/service/apps.ts b/web/service/apps.ts index b6a5386fe0..d2c6593a34 100644 --- a/web/service/apps.ts +++ b/web/service/apps.ts @@ -192,3 +192,11 @@ export const updateTracingConfig = ({ appId, body }: { appId: string, body: Trac export const removeTracingConfig = ({ appId, provider }: { appId: string, provider: TracingProvider }): Promise => { return del(`/apps/${appId}/trace-config?tracing_provider=${provider}`) } + +type PublishToCreatorsPlatformResponse = { + redirect_url: string +} + +export const publishToCreatorsPlatform = ({ appID }: { appID: string }): Promise => { + return post(`apps/${appID}/publish-to-creators-platform`, { body: {} }) +} diff --git a/web/service/base.ts b/web/service/base.ts index 64d13ef59a..d1ef06c314 100644 --- a/web/service/base.ts +++ b/web/service/base.ts @@ -140,6 +140,20 @@ function jumpTo(url: string) { globalThis.location.href = url } +const OAUTH_AUTHORIZE_PATH = '/account/oauth/authorize' + +export const buildSigninUrlWithRedirect = (): string => { + const loginUrl = `${globalThis.location.origin}${basePath}/signin` + + // Only preserve redirect URL for OAuth authorize pages + if (globalThis.location.pathname.includes(OAUTH_AUTHORIZE_PATH)) { + const currentUrl = globalThis.location.href + return `${loginUrl}?redirect_url=${encodeURIComponent(currentUrl)}` + } + + return loginUrl +} + function unicodeToChar(text: string) { if (!text) return '' @@ -795,14 +809,14 @@ export const request = async(url: string, options = {}, otherOptions?: IOther if (refreshErr === null) return baseFetch(url, options, otherOptionsForBaseFetch) if (location.pathname !== `${basePath}/signin` || !IS_CE_EDITION) { - jumpTo(loginUrl) + jumpTo(buildSigninUrlWithRedirect()) return Promise.reject(err) } if (!silent) { toast.error(message) return Promise.reject(err) } - jumpTo(loginUrl) + jumpTo(buildSigninUrlWithRedirect()) return Promise.reject(err) } else { diff --git a/web/service/marketplace-templates.ts b/web/service/marketplace-templates.ts new file mode 100644 index 0000000000..d9ff7f314f --- /dev/null +++ b/web/service/marketplace-templates.ts @@ -0,0 +1,18 @@ +import { useQuery } from '@tanstack/react-query' +import { MARKETPLACE_API_PREFIX } from '@/config' +import { marketplaceQuery } from './client' + +export const useMarketplaceTemplateDetail = (templateId: string | null) => { + return useQuery({ + ...marketplaceQuery.templateDetail.queryOptions({ input: { params: { templateId: templateId ?? '' } } }), + enabled: !!templateId, + }) +} + +export const fetchMarketplaceTemplateDSL = async (templateId: string): Promise => { + const url = `${MARKETPLACE_API_PREFIX}/templates/${templateId}/dsl` + const response = await fetch(url) + if (!response.ok) + throw new Error(`Failed to fetch DSL: ${response.statusText}`) + return response.text() +} diff --git a/web/types/feature.ts b/web/types/feature.ts index 635221f2be..77d4045318 100644 --- a/web/types/feature.ts +++ b/web/types/feature.ts @@ -64,6 +64,7 @@ export type SystemFeatures = { allow_email_code_login: boolean allow_email_password_login: boolean } + enable_creators_platform: boolean enable_trial_app: boolean enable_explore_banner: boolean } @@ -108,6 +109,7 @@ export const defaultSystemFeatures: SystemFeatures = { allow_email_code_login: false, allow_email_password_login: false, }, + enable_creators_platform: false, enable_trial_app: false, enable_explore_banner: false, } diff --git a/web/types/marketplace-template.ts b/web/types/marketplace-template.ts new file mode 100644 index 0000000000..ac2b7cb2aa --- /dev/null +++ b/web/types/marketplace-template.ts @@ -0,0 +1,11 @@ +export type MarketplaceTemplate = { + id: string + template_name: string + overview: string + icon: string + icon_background: string + icon_file_key: string + publisher_unique_handle: string + usage_count: number + categories: string[] +}