Merge branch 'main' into jzh

This commit is contained in:
JzoNg 2026-04-24 17:52:27 +08:00
commit 5263a65ed6
269 changed files with 11072 additions and 7352 deletions

View File

@ -659,6 +659,11 @@ INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y
MARKETPLACE_ENABLED=true
MARKETPLACE_API_URL=https://marketplace.dify.ai
# Creators Platform configuration
CREATORS_PLATFORM_FEATURES_ENABLED=true
CREATORS_PLATFORM_API_URL=https://creators.dify.ai
CREATORS_PLATFORM_OAUTH_CLIENT_ID=
# Endpoint configuration
ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id}

View File

@ -11,7 +11,7 @@ from configs import dify_config
from core.helper import encrypter
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.plugin import PluginInstaller
from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params
from core.tools.utils.system_encryption import encrypt_system_params
from extensions.ext_database import db
from models import Tenant
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
@ -44,7 +44,7 @@ def setup_system_tool_oauth_client(provider, client_params):
click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
oauth_client_params = encrypt_system_oauth_params(client_params_dict)
oauth_client_params = encrypt_system_params(client_params_dict)
click.echo(click.style("Client params encrypted successfully.", fg="green"))
except Exception as e:
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
@ -94,7 +94,7 @@ def setup_system_trigger_oauth_client(provider, client_params):
click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
oauth_client_params = encrypt_system_oauth_params(client_params_dict)
oauth_client_params = encrypt_system_params(client_params_dict)
click.echo(click.style("Client params encrypted successfully.", fg="green"))
except Exception as e:
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))

View File

@ -287,6 +287,27 @@ class MarketplaceConfig(BaseSettings):
)
class CreatorsPlatformConfig(BaseSettings):
"""
Configuration for Creators Platform integration
"""
CREATORS_PLATFORM_FEATURES_ENABLED: bool = Field(
description="Enable or disable Creators Platform features",
default=True,
)
CREATORS_PLATFORM_API_URL: HttpUrl = Field(
description="Creators Platform API URL",
default=HttpUrl("https://creators.dify.ai"),
)
CREATORS_PLATFORM_OAUTH_CLIENT_ID: str = Field(
description="OAuth client ID for Creators Platform integration",
default="",
)
class EndpointConfig(BaseSettings):
"""
Configuration for various application endpoints and URLs
@ -1379,6 +1400,7 @@ class FeatureConfig(
AuthConfig, # Changed from OAuthConfig to AuthConfig
BillingConfig,
CodeExecutionSandboxConfig,
CreatorsPlatformConfig,
TriggerConfig,
AsyncWorkflowConfig,
PluginConfig,

View File

@ -0,0 +1,6 @@
from pydantic import BaseModel, JsonValue
class HumanInputFormSubmitPayload(BaseModel):
inputs: dict[str, JsonValue]
action: str

View File

@ -692,6 +692,32 @@ class AppExportApi(Resource):
return payload.model_dump(mode="json")
@console_ns.route("/apps/<uuid:app_id>/publish-to-creators-platform")
class AppPublishToCreatorsPlatformApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=None)
@edit_permission_required
def post(self, app_model):
"""Publish app to Creators Platform"""
from configs import dify_config
from core.helper.creators import get_redirect_url, upload_dsl
if not dify_config.CREATORS_PLATFORM_FEATURES_ENABLED:
return {"error": "Creators Platform features are not enabled"}, 403
current_user, _ = current_account_with_tenant()
dsl_content = AppDslService.export_dsl(app_model=app_model, include_secret=False)
dsl_bytes = dsl_content.encode("utf-8")
claim_code = upload_dsl(dsl_bytes)
redirect_url = get_redirect_url(str(current_user.id), claim_code)
return {"redirect_url": redirect_url}
@console_ns.route("/apps/<uuid:app_id>/name")
class AppNameApi(Resource):
@console_ns.doc("check_app_name")

View File

@ -8,10 +8,10 @@ from collections.abc import Generator
from flask import Response, jsonify, request
from flask_restx import Resource
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from controllers.common.human_input import HumanInputFormSubmitPayload
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError
@ -20,11 +20,11 @@ from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.apps.message_generator import MessageGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
from extensions.ext_database import db
from libs.login import current_account_with_tenant, login_required
from models import App
from models.enums import CreatorUserRole
from models.human_input import RecipientType
from models.model import AppMode
from models.workflow import WorkflowRun
from repositories.factory import DifyAPIRepositoryFactory
@ -34,11 +34,6 @@ from services.workflow_event_snapshot_service import build_workflow_event_stream
logger = logging.getLogger(__name__)
class HumanInputFormSubmitPayload(BaseModel):
inputs: dict
action: str
def _jsonify_form_definition(form: Form) -> Response:
payload = form.get_definition().model_dump()
payload["expiration_time"] = int(form.expiration_time.timestamp())
@ -56,6 +51,11 @@ class ConsoleHumanInputFormApi(Resource):
if form.tenant_id != current_tenant_id:
raise NotFoundError("App not found")
@staticmethod
def _ensure_console_recipient_type(form: Form) -> None:
if not is_recipient_type_allowed_for_surface(form.recipient_type, HumanInputSurface.CONSOLE):
raise NotFoundError("form not found")
@setup_required
@login_required
@account_initialization_required
@ -99,10 +99,8 @@ class ConsoleHumanInputFormApi(Resource):
raise NotFoundError(f"form not found, token={form_token}")
self._ensure_console_access(form)
self._ensure_console_recipient_type(form)
recipient_type = form.recipient_type
if recipient_type not in {RecipientType.CONSOLE, RecipientType.BACKSTAGE}:
raise NotFoundError(f"form not found, token={form_token}")
# The type checker is not smart enought to validate the following invariant.
# So we need to assert it manually.
assert recipient_type is not None, "recipient_type cannot be None here."

View File

@ -37,6 +37,11 @@ class TagBindingRemovePayload(BaseModel):
type: TagType = Field(description="Tag type")
class TagBindingItemDeletePayload(BaseModel):
target_id: str = Field(description="Target ID to unbind tag from")
type: TagType = Field(description="Tag type")
class TagListQueryParam(BaseModel):
type: Literal["knowledge", "app", ""] = Field("", description="Tag type filter")
keyword: str | None = Field(None, description="Search keyword")
@ -70,6 +75,7 @@ register_schema_models(
TagBasePayload,
TagBindingPayload,
TagBindingRemovePayload,
TagBindingItemDeletePayload,
TagListQueryParam,
TagResponse,
)
@ -152,41 +158,107 @@ class TagUpdateDeleteApi(Resource):
return "", 204
@console_ns.route("/tag-bindings/create")
class TagBindingCreateApi(Resource):
def _require_tag_binding_edit_permission() -> None:
"""
Ensure the current account can edit tag bindings.
Tag binding operations are allowed for users who can edit resources (app/dataset) within the current tenant.
"""
current_user, _ = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
def _create_tag_bindings() -> tuple[dict[str, str], int]:
_require_tag_binding_edit_permission()
payload = TagBindingPayload.model_validate(console_ns.payload or {})
TagService.save_tag_binding(
TagBindingCreatePayload(
tag_ids=payload.tag_ids,
target_id=payload.target_id,
type=payload.type,
)
)
return {"result": "success"}, 200
def _remove_tag_binding() -> tuple[dict[str, str], int]:
_require_tag_binding_edit_permission()
payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
TagService.delete_tag_binding(
TagBindingDeletePayload(
tag_id=payload.tag_id,
target_id=payload.target_id,
type=payload.type,
)
)
return {"result": "success"}, 200
@console_ns.route("/tag-bindings")
class TagBindingCollectionApi(Resource):
"""Canonical collection resource for tag binding creation."""
@console_ns.doc("create_tag_binding")
@console_ns.expect(console_ns.models[TagBindingPayload.__name__])
@setup_required
@login_required
@account_initialization_required
def post(self):
current_user, _ = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
return _create_tag_bindings()
payload = TagBindingPayload.model_validate(console_ns.payload or {})
TagService.save_tag_binding(
TagBindingCreatePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=payload.type)
@console_ns.route("/tag-bindings/<uuid:id>")
class TagBindingItemApi(Resource):
"""Canonical item resource for tag binding deletion."""
@console_ns.doc("delete_tag_binding")
@console_ns.doc(params={"id": "Tag ID"})
@console_ns.expect(console_ns.models[TagBindingItemDeletePayload.__name__])
@setup_required
@login_required
@account_initialization_required
def delete(self, id):
_require_tag_binding_edit_permission()
payload = TagBindingItemDeletePayload.model_validate(console_ns.payload or {})
TagService.delete_tag_binding(
TagBindingDeletePayload(
tag_id=str(id),
target_id=payload.target_id,
type=payload.type,
)
)
return {"result": "success"}, 200
@console_ns.route("/tag-bindings/create")
class DeprecatedTagBindingCreateApi(Resource):
"""Deprecated verb-based alias for tag binding creation."""
@console_ns.doc("create_tag_binding_deprecated")
@console_ns.doc(deprecated=True)
@console_ns.doc(description="Deprecated legacy alias. Use POST /tag-bindings instead.")
@console_ns.expect(console_ns.models[TagBindingPayload.__name__])
@setup_required
@login_required
@account_initialization_required
def post(self):
return _create_tag_bindings()
@console_ns.route("/tag-bindings/remove")
class TagBindingDeleteApi(Resource):
class DeprecatedTagBindingRemoveApi(Resource):
"""Deprecated verb-based alias for tag binding deletion."""
@console_ns.doc("delete_tag_binding_deprecated")
@console_ns.doc(deprecated=True)
@console_ns.doc(description="Deprecated legacy alias. Use DELETE /tag-bindings/{id} instead.")
@console_ns.expect(console_ns.models[TagBindingRemovePayload.__name__])
@setup_required
@login_required
@account_initialization_required
def post(self):
current_user, _ = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
TagService.delete_tag_binding(
TagBindingDeletePayload(tag_id=payload.tag_id, target_id=payload.target_id, type=payload.type)
)
return {"result": "success"}, 200
return _remove_tag_binding()

View File

@ -23,9 +23,11 @@ from .app import (
conversation,
file,
file_preview,
human_input_form,
message,
site,
workflow,
workflow_events,
)
from .dataset import (
dataset,
@ -50,6 +52,7 @@ __all__ = [
"file",
"file_preview",
"hit_testing",
"human_input_form",
"index",
"message",
"metadata",
@ -58,6 +61,7 @@ __all__ = [
"segment",
"site",
"workflow",
"workflow_events",
]
api.add_namespace(service_api_ns)

View File

@ -0,0 +1,137 @@
"""
Service API human input form endpoints.
This module exposes app-token authenticated APIs for fetching and submitting
paused human input forms in workflow/chatflow runs.
"""
import json
import logging
from datetime import datetime
from flask import Response
from flask_restx import Resource
from werkzeug.exceptions import BadRequest, NotFound
from controllers.common.human_input import HumanInputFormSubmitPayload
from controllers.common.schema import register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
from extensions.ext_database import db
from models.model import App, EndUser
from services.human_input_service import Form, FormNotFoundError, HumanInputService
logger = logging.getLogger(__name__)
register_schema_models(service_api_ns, HumanInputFormSubmitPayload)
def _stringify_default_values(values: dict[str, object]) -> dict[str, str]:
result: dict[str, str] = {}
for key, value in values.items():
if value is None:
result[key] = ""
elif isinstance(value, (dict, list)):
result[key] = json.dumps(value, ensure_ascii=False)
else:
result[key] = str(value)
return result
def _to_timestamp(value: datetime) -> int:
return int(value.timestamp())
def _jsonify_form_definition(form: Form) -> Response:
definition_payload = form.get_definition().model_dump()
payload = {
"form_content": definition_payload["rendered_content"],
"inputs": definition_payload["inputs"],
"resolved_default_values": _stringify_default_values(definition_payload["default_values"]),
"user_actions": definition_payload["user_actions"],
"expiration_time": _to_timestamp(form.expiration_time),
}
return Response(json.dumps(payload, ensure_ascii=False), mimetype="application/json")
def _ensure_form_belongs_to_app(form: Form, app_model: App) -> None:
if form.app_id != app_model.id or form.tenant_id != app_model.tenant_id:
raise NotFound("Form not found")
def _ensure_form_is_allowed_for_service_api(form: Form) -> None:
# Keep app-token callers scoped to the public web-form surface; internal HITL
# routes must continue to flow through console-only authentication.
if not is_recipient_type_allowed_for_surface(form.recipient_type, HumanInputSurface.SERVICE_API):
raise NotFound("Form not found")
@service_api_ns.route("/form/human_input/<string:form_token>")
class WorkflowHumanInputFormApi(Resource):
@service_api_ns.doc("get_human_input_form")
@service_api_ns.doc(description="Get a paused human input form by token")
@service_api_ns.doc(params={"form_token": "Human input form token"})
@service_api_ns.doc(
responses={
200: "Form retrieved successfully",
401: "Unauthorized - invalid API token",
404: "Form not found",
412: "Form already submitted or expired",
}
)
@validate_app_token
def get(self, app_model: App, form_token: str):
service = HumanInputService(db.engine)
form = service.get_form_by_token(form_token)
if form is None:
raise NotFound("Form not found")
_ensure_form_belongs_to_app(form, app_model)
_ensure_form_is_allowed_for_service_api(form)
service.ensure_form_active(form)
return _jsonify_form_definition(form)
@service_api_ns.expect(service_api_ns.models[HumanInputFormSubmitPayload.__name__])
@service_api_ns.doc("submit_human_input_form")
@service_api_ns.doc(description="Submit a paused human input form by token")
@service_api_ns.doc(params={"form_token": "Human input form token"})
@service_api_ns.doc(
responses={
200: "Form submitted successfully",
400: "Bad request - invalid submission data",
401: "Unauthorized - invalid API token",
404: "Form not found",
412: "Form already submitted or expired",
}
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser, form_token: str):
payload = HumanInputFormSubmitPayload.model_validate(service_api_ns.payload or {})
service = HumanInputService(db.engine)
form = service.get_form_by_token(form_token)
if form is None:
raise NotFound("Form not found")
_ensure_form_belongs_to_app(form, app_model)
_ensure_form_is_allowed_for_service_api(form)
recipient_type = form.recipient_type
if recipient_type is None:
logger.warning("Recipient type is None for form, form_id=%s", form.id)
raise BadRequest("Form recipient type is invalid")
try:
service.submit_form_by_token(
recipient_type=recipient_type,
form_token=form_token,
selected_action_id=payload.action,
form_data=payload.inputs,
submission_end_user_id=end_user.id,
)
except FormNotFoundError:
raise NotFound("Form not found")
return {}, 200

View File

@ -0,0 +1,142 @@
"""
Service API workflow resume event stream endpoints.
"""
import json
from collections.abc import Generator
from flask import Response, request
from flask_restx import Resource
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import NotFound
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import NotWorkflowAppError
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.apps.message_generator import MessageGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.task_entities import StreamEvent
from core.workflow.human_input_policy import HumanInputSurface
from extensions.ext_database import db
from models.enums import CreatorUserRole
from models.model import App, AppMode, EndUser
from repositories.factory import DifyAPIRepositoryFactory
from services.workflow_event_snapshot_service import build_workflow_event_stream
@service_api_ns.route("/workflow/<string:task_id>/events")
class WorkflowEventsApi(Resource):
"""Service API for getting workflow execution events after resume."""
@service_api_ns.doc("get_workflow_events")
@service_api_ns.doc(description="Get workflow execution events stream after resume")
@service_api_ns.doc(
params={
"task_id": "Workflow run ID",
"user": "End user identifier (query param)",
"include_state_snapshot": (
"Whether to replay from persisted state snapshot, "
'specify `"true"` to include a status snapshot of executed nodes'
),
"continue_on_pause": (
"Whether to keep the stream open across workflow_paused events,"
'specify `"true"` to keep the stream open for `workflow_paused` events.'
),
}
)
@service_api_ns.doc(
responses={
200: "SSE event stream",
401: "Unauthorized - invalid API token",
404: "Workflow run not found",
}
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY, required=True))
def get(self, app_model: App, end_user: EndUser, task_id: str):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.WORKFLOW, AppMode.ADVANCED_CHAT}:
raise NotWorkflowAppError()
session_maker = sessionmaker(db.engine)
repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
workflow_run = repo.get_workflow_run_by_id_and_tenant_id(
tenant_id=app_model.tenant_id,
run_id=task_id,
)
if workflow_run is None:
raise NotFound("Workflow run not found")
if workflow_run.app_id != app_model.id:
raise NotFound("Workflow run not found")
if workflow_run.created_by_role != CreatorUserRole.END_USER:
raise NotFound("Workflow run not found")
if workflow_run.created_by != end_user.id:
raise NotFound("Workflow run not found")
workflow_run_entity = workflow_run
if workflow_run_entity.finished_at is not None:
response = WorkflowResponseConverter.workflow_run_result_to_finish_response(
task_id=workflow_run_entity.id,
workflow_run=workflow_run_entity,
creator_user=end_user,
)
payload = response.model_dump(mode="json")
payload["event"] = response.event.value
def _generate_finished_events() -> Generator[str, None, None]:
yield f"data: {json.dumps(payload)}\n\n"
event_generator = _generate_finished_events
else:
msg_generator = MessageGenerator()
generator: BaseAppGenerator
if app_mode == AppMode.ADVANCED_CHAT:
generator = AdvancedChatAppGenerator()
elif app_mode == AppMode.WORKFLOW:
generator = WorkflowAppGenerator()
else:
raise NotWorkflowAppError()
include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true"
continue_on_pause = request.args.get("continue_on_pause", "false").lower() == "true"
terminal_events: list[StreamEvent] | None = [] if continue_on_pause else None
def _generate_stream_events():
if include_state_snapshot:
return generator.convert_to_event_stream(
build_workflow_event_stream(
app_mode=app_mode,
workflow_run=workflow_run_entity,
tenant_id=app_model.tenant_id,
app_id=app_model.id,
session_maker=session_maker,
human_input_surface=HumanInputSurface.SERVICE_API,
close_on_pause=not continue_on_pause,
)
)
return generator.convert_to_event_stream(
msg_generator.retrieve_events(
app_mode,
workflow_run_entity.id,
terminal_events=terminal_events,
),
)
event_generator = _generate_stream_events
return Response(
event_generator(),
mimetype="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
},
)

View File

@ -9,11 +9,11 @@ from typing import Any, NotRequired, TypedDict
from flask import Response, request
from flask_restx import Resource
from pydantic import BaseModel
from sqlalchemy import select
from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.common.human_input import HumanInputFormSubmitPayload
from controllers.web import web_ns
from controllers.web.error import NotFoundError, WebFormRateLimitExceededError
from controllers.web.site import serialize_app_site_payload
@ -26,11 +26,6 @@ from services.human_input_service import Form, FormNotFoundError, HumanInputServ
logger = logging.getLogger(__name__)
class HumanInputFormSubmitPayload(BaseModel):
inputs: dict
action: str
_FORM_SUBMIT_RATE_LIMITER = RateLimiter(
prefix="web_form_submit_rate_limit",
max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,

View File

@ -34,7 +34,11 @@ from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
from core.app.entities.task_entities import (
AdvancedChatPausedBlockingResponse,
ChatbotAppBlockingResponse,
ChatbotAppStreamResponse,
)
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
from core.helper.trace_id_helper import extract_external_trace_id_from_args
from core.ops.ops_trace_manager import TraceQueueManager
@ -655,7 +659,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
user: Account | EndUser,
draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
) -> ChatbotAppBlockingResponse | Generator[ChatbotAppStreamResponse, None, None]:
) -> (
ChatbotAppBlockingResponse
| AdvancedChatPausedBlockingResponse
| Generator[ChatbotAppStreamResponse, None, None]
):
"""
Handle response.
:param application_generate_entity: application generate entity

View File

@ -3,7 +3,7 @@ from typing import Any, cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AppBlockingResponse,
AdvancedChatPausedBlockingResponse,
AppStreamResponse,
ChatbotAppBlockingResponse,
ChatbotAppStreamResponse,
@ -12,22 +12,40 @@ from core.app.entities.task_entities import (
NodeFinishStreamResponse,
NodeStartStreamResponse,
PingStreamResponse,
StreamEvent,
)
class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = ChatbotAppBlockingResponse
class AdvancedChatAppGenerateResponseConverter(
AppGenerateResponseConverter[ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse]
):
@classmethod
def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
def convert_blocking_full_response(
cls, blocking_response: ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse
) -> dict[str, Any]:
"""
Convert blocking full response.
:param blocking_response: blocking response
:return:
"""
blocking_response = cast(ChatbotAppBlockingResponse, blocking_response)
if isinstance(blocking_response, AdvancedChatPausedBlockingResponse):
paused_data = blocking_response.data.model_dump(mode="json")
return {
"event": StreamEvent.WORKFLOW_PAUSED.value,
"task_id": blocking_response.task_id,
"id": blocking_response.data.id,
"message_id": blocking_response.data.message_id,
"conversation_id": blocking_response.data.conversation_id,
"mode": blocking_response.data.mode,
"answer": blocking_response.data.answer,
"metadata": blocking_response.data.metadata,
"created_at": blocking_response.data.created_at,
"workflow_run_id": blocking_response.data.workflow_run_id,
"data": paused_data,
}
response = {
"event": "message",
"event": StreamEvent.MESSAGE.value,
"task_id": blocking_response.task_id,
"id": blocking_response.data.id,
"message_id": blocking_response.data.message_id,
@ -41,7 +59,9 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
return response
@classmethod
def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
def convert_blocking_simple_response(
cls, blocking_response: ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse
) -> dict[str, Any]:
"""
Convert blocking simple response.
:param blocking_response: blocking response
@ -50,7 +70,8 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
response = cls.convert_blocking_full_response(blocking_response)
metadata = response.get("metadata", {})
response["metadata"] = cls._get_simple_metadata(metadata)
if isinstance(metadata, dict):
response["metadata"] = cls._get_simple_metadata(metadata)
return response

View File

@ -53,14 +53,18 @@ from core.app.entities.queue_entities import (
WorkflowQueueMessage,
)
from core.app.entities.task_entities import (
AdvancedChatPausedBlockingResponse,
ChatbotAppBlockingResponse,
ChatbotAppStreamResponse,
ErrorStreamResponse,
HumanInputRequiredPauseReasonPayload,
HumanInputRequiredResponse,
MessageAudioEndStreamResponse,
MessageAudioStreamResponse,
MessageEndStreamResponse,
PingStreamResponse,
StreamResponse,
WorkflowPauseStreamResponse,
WorkflowTaskState,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
@ -210,7 +214,13 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
if message.status == MessageStatus.PAUSED and message.answer:
self._task_state.answer = message.answer
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
def process(
self,
) -> Union[
ChatbotAppBlockingResponse,
AdvancedChatPausedBlockingResponse,
Generator[ChatbotAppStreamResponse, None, None],
]:
"""
Process generate task pipeline.
:return:
@ -226,14 +236,39 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
else:
return self._to_blocking_response(generator)
def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> ChatbotAppBlockingResponse:
def _to_blocking_response(
self, generator: Generator[StreamResponse, None, None]
) -> Union[ChatbotAppBlockingResponse, AdvancedChatPausedBlockingResponse]:
"""
Process blocking response.
:return:
"""
human_input_responses: list[HumanInputRequiredResponse] = []
for stream_response in generator:
if isinstance(stream_response, ErrorStreamResponse):
raise stream_response.err
elif isinstance(stream_response, HumanInputRequiredResponse):
human_input_responses.append(stream_response)
elif isinstance(stream_response, WorkflowPauseStreamResponse):
return AdvancedChatPausedBlockingResponse(
task_id=stream_response.task_id,
data=AdvancedChatPausedBlockingResponse.Data(
id=self._message_id,
mode=self._conversation_mode,
conversation_id=self._conversation_id,
message_id=self._message_id,
workflow_run_id=stream_response.data.workflow_run_id,
answer=self._task_state.answer,
metadata=self._message_end_to_stream_response().metadata,
created_at=self._message_created_at,
paused_nodes=stream_response.data.paused_nodes,
reasons=stream_response.data.reasons,
status=stream_response.data.status,
elapsed_time=stream_response.data.elapsed_time,
total_tokens=stream_response.data.total_tokens,
total_steps=stream_response.data.total_steps,
),
)
elif isinstance(stream_response, MessageEndStreamResponse):
extras = {}
if stream_response.metadata:
@ -254,8 +289,41 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
else:
continue
if human_input_responses:
return self._build_paused_blocking_response_from_human_input(human_input_responses)
raise ValueError("queue listening stopped unexpectedly.")
def _build_paused_blocking_response_from_human_input(
self, human_input_responses: list[HumanInputRequiredResponse]
) -> AdvancedChatPausedBlockingResponse:
runtime_state = self._resolve_graph_runtime_state()
paused_nodes = list(dict.fromkeys(response.data.node_id for response in human_input_responses))
reasons = [
HumanInputRequiredPauseReasonPayload.from_response_data(response.data).model_dump(mode="json")
for response in human_input_responses
]
return AdvancedChatPausedBlockingResponse(
task_id=self._application_generate_entity.task_id,
data=AdvancedChatPausedBlockingResponse.Data(
id=self._message_id,
mode=self._conversation_mode,
conversation_id=self._conversation_id,
message_id=self._message_id,
workflow_run_id=human_input_responses[-1].workflow_run_id,
answer=self._task_state.answer,
metadata=self._message_end_to_stream_response().metadata,
created_at=self._message_created_at,
paused_nodes=paused_nodes,
reasons=reasons,
status=WorkflowExecutionStatus.PAUSED,
elapsed_time=time.perf_counter() - self._base_task_pipeline.start_at,
total_tokens=runtime_state.total_tokens,
total_steps=runtime_state.node_run_steps,
),
)
def _to_stream_response(
self, generator: Generator[StreamResponse, None, None]
) -> Generator[ChatbotAppStreamResponse, Any, None]:

View File

@ -1,6 +1,8 @@
from collections.abc import Generator
from typing import Any, cast
from pydantic import JsonValue
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AppStreamResponse,
@ -12,11 +14,9 @@ from core.app.entities.task_entities import (
)
class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = ChatbotAppBlockingResponse
class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBlockingResponse]):
@classmethod
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse):
"""
Convert blocking full response.
:param blocking_response: blocking response
@ -37,7 +37,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
return response
@classmethod
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse):
"""
Convert blocking simple response.
:param blocking_response: blocking response
@ -70,7 +70,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping"
continue
response_chunk = {
response_chunk: dict[str, JsonValue] = {
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,
@ -101,7 +101,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping"
continue
response_chunk = {
response_chunk: dict[str, JsonValue] = {
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,

View File

@ -1,7 +1,9 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Generator, Mapping
from typing import Any, Union
from typing import Any, Union, cast
from pydantic import JsonValue
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse
@ -11,8 +13,10 @@ from graphon.model_runtime.errors.invoke import InvokeError
logger = logging.getLogger(__name__)
class AppGenerateResponseConverter(ABC):
_blocking_response_type: type[AppBlockingResponse]
class AppGenerateResponseConverter[TBlockingResponse: AppBlockingResponse](ABC):
@classmethod
def _cast_blocking_response(cls, response: AppBlockingResponse) -> TBlockingResponse:
return cast(TBlockingResponse, response)
@classmethod
def convert(
@ -20,7 +24,7 @@ class AppGenerateResponseConverter(ABC):
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_full_response(response)
return cls.convert_blocking_full_response(cls._cast_blocking_response(response))
else:
def _generate_full_response() -> Generator[dict[str, Any] | str, Any, None]:
@ -29,7 +33,7 @@ class AppGenerateResponseConverter(ABC):
return _generate_full_response()
else:
if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_simple_response(response)
return cls.convert_blocking_simple_response(cls._cast_blocking_response(response))
else:
def _generate_simple_response() -> Generator[dict[str, Any] | str, Any, None]:
@ -39,12 +43,12 @@ class AppGenerateResponseConverter(ABC):
@classmethod
@abstractmethod
def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
def convert_blocking_full_response(cls, blocking_response: TBlockingResponse) -> dict[str, Any]:
raise NotImplementedError
@classmethod
@abstractmethod
def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
def convert_blocking_simple_response(cls, blocking_response: TBlockingResponse) -> dict[str, Any]:
raise NotImplementedError
@classmethod
@ -106,13 +110,13 @@ class AppGenerateResponseConverter(ABC):
return metadata
@classmethod
def _error_to_stream_response(cls, e: Exception) -> dict[str, Any]:
def _error_to_stream_response(cls, e: Exception) -> dict[str, JsonValue]:
"""
Error to stream response.
:param e: exception
:return:
"""
error_responses: dict[type[Exception], dict[str, Any]] = {
error_responses: dict[type[Exception], dict[str, JsonValue]] = {
ValueError: {"code": "invalid_param", "status": 400},
ProviderTokenNotInitError: {"code": "provider_not_initialize", "status": 400},
QuotaExceededError: {
@ -126,7 +130,7 @@ class AppGenerateResponseConverter(ABC):
}
# Determine the response based on the type of exception
data: dict[str, Any] | None = None
data: dict[str, JsonValue] | None = None
for k, v in error_responses.items():
if isinstance(e, k):
data = v

View File

@ -1,6 +1,8 @@
from collections.abc import Generator
from typing import Any, cast
from pydantic import JsonValue
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AppStreamResponse,
@ -12,11 +14,9 @@ from core.app.entities.task_entities import (
)
class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = ChatbotAppBlockingResponse
class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBlockingResponse]):
@classmethod
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse):
"""
Convert blocking full response.
:param blocking_response: blocking response
@ -37,7 +37,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
return response
@classmethod
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse):
"""
Convert blocking simple response.
:param blocking_response: blocking response
@ -70,7 +70,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping"
continue
response_chunk = {
response_chunk: dict[str, JsonValue] = {
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,
@ -101,7 +101,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping"
continue
response_chunk = {
response_chunk: dict[str, JsonValue] = {
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,

View File

@ -52,6 +52,7 @@ from core.tools.tool_manager import ToolManager
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
from core.trigger.trigger_manager import TriggerManager
from core.workflow.human_input_forms import load_form_tokens_by_form_id
from core.workflow.human_input_policy import HumanInputSurface, enrich_human_input_pause_reasons
from core.workflow.system_variables import SystemVariableKey, system_variables_to_mapping
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
@ -336,7 +337,26 @@ class WorkflowResponseConverter:
except (TypeError, json.JSONDecodeError):
definition_payload = {}
display_in_ui_by_form_id[str(form_id)] = bool(definition_payload.get("display_in_ui"))
form_token_by_form_id = load_form_tokens_by_form_id(human_input_form_ids, session=session)
form_token_by_form_id = load_form_tokens_by_form_id(
human_input_form_ids,
session=session,
surface=(
HumanInputSurface.SERVICE_API
if self._application_generate_entity.invoke_from == InvokeFrom.SERVICE_API
else None
),
)
# Reconnect paths must preserve the same pause-reason contract as live streams;
# otherwise clients see schema drift after resume.
pause_reasons = enrich_human_input_pause_reasons(
pause_reasons,
form_tokens_by_form_id=form_token_by_form_id,
expiration_times_by_form_id={
form_id: int(expiration_time.timestamp())
for form_id, expiration_time in expiration_times_by_form_id.items()
},
)
responses: list[StreamResponse] = []

View File

@ -1,6 +1,8 @@
from collections.abc import Generator
from typing import Any, cast
from pydantic import JsonValue
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AppStreamResponse,
@ -12,17 +14,15 @@ from core.app.entities.task_entities import (
)
class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = CompletionAppBlockingResponse
class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[CompletionAppBlockingResponse]):
@classmethod
def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse): # type: ignore[override]
def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse):
"""
Convert blocking full response.
:param blocking_response: blocking response
:return:
"""
response = {
response: dict[str, Any] = {
"event": "message",
"task_id": blocking_response.task_id,
"id": blocking_response.data.id,
@ -36,7 +36,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
return response
@classmethod
def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse): # type: ignore[override]
def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse):
"""
Convert blocking simple response.
:param blocking_response: blocking response
@ -69,7 +69,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping"
continue
response_chunk = {
response_chunk: dict[str, JsonValue] = {
"event": sub_stream_response.event.value,
"message_id": chunk.message_id,
"created_at": chunk.created_at,
@ -99,7 +99,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping"
continue
response_chunk = {
response_chunk: dict[str, JsonValue] = {
"event": sub_stream_response.event.value,
"message_id": chunk.message_id,
"created_at": chunk.created_at,

View File

@ -1,6 +1,7 @@
from collections.abc import Callable, Generator, Mapping
from collections.abc import Callable, Generator, Iterable, Mapping
from core.app.apps.streaming_utils import stream_topic_events
from core.app.entities.task_entities import StreamEvent
from extensions.ext_redis import get_pubsub_broadcast_channel
from libs.broadcast_channel.channel import Topic
from models.model import AppMode
@ -26,6 +27,7 @@ class MessageGenerator:
idle_timeout=300,
ping_interval: float = 10.0,
on_subscribe: Callable[[], None] | None = None,
terminal_events: Iterable[str | StreamEvent] | None = None,
) -> Generator[Mapping | str, None, None]:
topic = cls.get_response_topic(app_mode, workflow_run_id)
return stream_topic_events(
@ -33,4 +35,5 @@ class MessageGenerator:
idle_timeout=idle_timeout,
ping_interval=ping_interval,
on_subscribe=on_subscribe,
terminal_events=terminal_events,
)

View File

@ -13,11 +13,9 @@ from core.app.entities.task_entities import (
)
class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = WorkflowAppBlockingResponse
class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter[WorkflowAppBlockingResponse]):
@classmethod
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]: # type: ignore[override]
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, object]:
"""
Convert blocking full response.
:param blocking_response: blocking response
@ -26,7 +24,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
return dict(blocking_response.model_dump())
@classmethod
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]: # type: ignore[override]
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, object]:
"""
Convert blocking simple response.
:param blocking_response: blocking response

View File

@ -27,7 +27,11 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.app.entities.task_entities import (
WorkflowAppBlockingResponse,
WorkflowAppPausedBlockingResponse,
WorkflowAppStreamResponse,
)
from core.datasource.entities.datasource_entities import (
DatasourceProviderType,
OnlineDriveBrowseFilesRequest,
@ -627,7 +631,11 @@ class PipelineGenerator(BaseAppGenerator):
user: Account | EndUser,
draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
) -> WorkflowAppBlockingResponse | Generator[WorkflowAppStreamResponse, None, None]:
) -> (
WorkflowAppBlockingResponse
| WorkflowAppPausedBlockingResponse
| Generator[WorkflowAppStreamResponse, None, None]
):
"""
Handle response.
:param application_generate_entity: application generate entity

View File

@ -59,7 +59,7 @@ def stream_topic_events(
def _normalize_terminal_events(terminal_events: Iterable[str | StreamEvent] | None) -> set[str]:
if not terminal_events:
if terminal_events is None:
return {StreamEvent.WORKFLOW_FINISHED.value, StreamEvent.WORKFLOW_PAUSED.value}
values: set[str] = set()
for item in terminal_events:

View File

@ -25,7 +25,11 @@ from core.app.apps.workflow.app_runner import WorkflowAppRunner
from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.app.entities.task_entities import (
WorkflowAppBlockingResponse,
WorkflowAppPausedBlockingResponse,
WorkflowAppStreamResponse,
)
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
from core.db.session_factory import session_factory
from core.helper.trace_id_helper import extract_external_trace_id_from_args
@ -612,7 +616,11 @@ class WorkflowAppGenerator(BaseAppGenerator):
user: Account | EndUser,
draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
) -> WorkflowAppBlockingResponse | Generator[WorkflowAppStreamResponse, None, None]:
) -> (
WorkflowAppBlockingResponse
| WorkflowAppPausedBlockingResponse
| Generator[WorkflowAppStreamResponse, None, None]
):
"""
Handle response.
:param application_generate_entity: application generate entity

View File

@ -9,24 +9,29 @@ from core.app.entities.task_entities import (
NodeStartStreamResponse,
PingStreamResponse,
WorkflowAppBlockingResponse,
WorkflowAppPausedBlockingResponse,
WorkflowAppStreamResponse,
)
class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = WorkflowAppBlockingResponse
class WorkflowAppGenerateResponseConverter(
AppGenerateResponseConverter[WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse]
):
@classmethod
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override]
def convert_blocking_full_response(
cls, blocking_response: WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse
) -> dict[str, Any]:
"""
Convert blocking full response.
:param blocking_response: blocking response
:return:
"""
return blocking_response.model_dump()
return dict(blocking_response.model_dump())
@classmethod
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override]
def convert_blocking_simple_response(
cls, blocking_response: WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse
) -> dict[str, Any]:
"""
Convert blocking simple response.
:param blocking_response: blocking response

View File

@ -42,12 +42,15 @@ from core.app.entities.queue_entities import (
)
from core.app.entities.task_entities import (
ErrorStreamResponse,
HumanInputRequiredPauseReasonPayload,
HumanInputRequiredResponse,
MessageAudioEndStreamResponse,
MessageAudioStreamResponse,
PingStreamResponse,
StreamResponse,
TextChunkStreamResponse,
WorkflowAppBlockingResponse,
WorkflowAppPausedBlockingResponse,
WorkflowAppStreamResponse,
WorkflowFinishStreamResponse,
WorkflowPauseStreamResponse,
@ -118,7 +121,11 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
)
self._graph_runtime_state: GraphRuntimeState | None = self._base_task_pipeline.queue_manager.graph_runtime_state
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
def process(
self,
) -> Union[
WorkflowAppBlockingResponse, WorkflowAppPausedBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]
]:
"""
Process generate task pipeline.
:return:
@ -129,19 +136,24 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
else:
return self._to_blocking_response(generator)
def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> WorkflowAppBlockingResponse:
def _to_blocking_response(
self, generator: Generator[StreamResponse, None, None]
) -> Union[WorkflowAppBlockingResponse, WorkflowAppPausedBlockingResponse]:
"""
To blocking response.
:return:
"""
human_input_responses: list[HumanInputRequiredResponse] = []
for stream_response in generator:
if isinstance(stream_response, ErrorStreamResponse):
raise stream_response.err
elif isinstance(stream_response, HumanInputRequiredResponse):
human_input_responses.append(stream_response)
elif isinstance(stream_response, WorkflowPauseStreamResponse):
response = WorkflowAppBlockingResponse(
return WorkflowAppPausedBlockingResponse(
task_id=self._application_generate_entity.task_id,
workflow_run_id=stream_response.data.workflow_run_id,
data=WorkflowAppBlockingResponse.Data(
data=WorkflowAppPausedBlockingResponse.Data(
id=stream_response.data.workflow_run_id,
workflow_id=self._workflow.id,
status=stream_response.data.status,
@ -152,12 +164,13 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
total_steps=stream_response.data.total_steps,
created_at=stream_response.data.created_at,
finished_at=None,
paused_nodes=stream_response.data.paused_nodes,
reasons=stream_response.data.reasons,
),
)
return response
elif isinstance(stream_response, WorkflowFinishStreamResponse):
response = WorkflowAppBlockingResponse(
return WorkflowAppBlockingResponse(
task_id=self._application_generate_entity.task_id,
workflow_run_id=stream_response.data.id,
data=WorkflowAppBlockingResponse.Data(
@ -174,12 +187,44 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
),
)
return response
else:
continue
if human_input_responses:
return self._build_paused_blocking_response_from_human_input(human_input_responses)
raise ValueError("queue listening stopped unexpectedly.")
def _build_paused_blocking_response_from_human_input(
self, human_input_responses: list[HumanInputRequiredResponse]
) -> WorkflowAppPausedBlockingResponse:
runtime_state = self._resolve_graph_runtime_state()
paused_nodes = list(dict.fromkeys(response.data.node_id for response in human_input_responses))
created_at = int(runtime_state.start_at)
reasons = [
HumanInputRequiredPauseReasonPayload.from_response_data(response.data).model_dump(mode="json")
for response in human_input_responses
]
return WorkflowAppPausedBlockingResponse(
task_id=self._application_generate_entity.task_id,
workflow_run_id=human_input_responses[-1].workflow_run_id,
data=WorkflowAppPausedBlockingResponse.Data(
id=human_input_responses[-1].workflow_run_id,
workflow_id=self._workflow.id,
status=WorkflowExecutionStatus.PAUSED,
outputs={},
error=None,
elapsed_time=time.perf_counter() - self._base_task_pipeline.start_at,
total_tokens=runtime_state.total_tokens,
total_steps=runtime_state.node_run_steps,
created_at=created_at,
finished_at=None,
paused_nodes=paused_nodes,
reasons=reasons,
),
)
def _to_stream_response(
self, generator: Generator[StreamResponse, None, None]
) -> Generator[WorkflowAppStreamResponse, None, None]:

View File

@ -1,12 +1,13 @@
from collections.abc import Mapping, Sequence
from enum import StrEnum
from typing import Any
from typing import Any, Literal
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, JsonValue
from core.app.entities.agent_strategy import AgentStrategyInfo
from core.rag.entities import RetrievalSourceMetadata
from graphon.entities import WorkflowStartReason
from graphon.entities.pause_reason import PauseReasonType
from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from graphon.nodes.human_input.entities import FormInput, UserAction
@ -295,6 +296,40 @@ class HumanInputRequiredResponse(StreamResponse):
data: Data
class HumanInputRequiredPauseReasonPayload(BaseModel):
"""
Public pause-reason payload used by blocking responses when only
``human_input_required`` events are available.
"""
TYPE: Literal[PauseReasonType.HUMAN_INPUT_REQUIRED] = PauseReasonType.HUMAN_INPUT_REQUIRED
form_id: str
node_id: str
node_title: str
form_content: str
inputs: Sequence[FormInput] = Field(default_factory=list)
actions: Sequence[UserAction] = Field(default_factory=list)
display_in_ui: bool = False
form_token: str | None = None
resolved_default_values: Mapping[str, Any] = Field(default_factory=dict)
expiration_time: int
@classmethod
def from_response_data(cls, data: HumanInputRequiredResponse.Data) -> "HumanInputRequiredPauseReasonPayload":
return cls(
form_id=data.form_id,
node_id=data.node_id,
node_title=data.node_title,
form_content=data.form_content,
inputs=data.inputs,
actions=data.actions,
display_in_ui=data.display_in_ui,
form_token=data.form_token,
resolved_default_values=data.resolved_default_values,
expiration_time=data.expiration_time,
)
class HumanInputFormFilledResponse(StreamResponse):
class Data(BaseModel):
"""
@ -355,7 +390,7 @@ class NodeStartStreamResponse(StreamResponse):
workflow_run_id: str
data: Data
def to_ignore_detail_dict(self):
def to_ignore_detail_dict(self) -> dict[str, JsonValue]:
return {
"event": self.event.value,
"task_id": self.task_id,
@ -412,7 +447,7 @@ class NodeFinishStreamResponse(StreamResponse):
workflow_run_id: str
data: Data
def to_ignore_detail_dict(self):
def to_ignore_detail_dict(self) -> dict[str, JsonValue]:
return {
"event": self.event.value,
"task_id": self.task_id,
@ -774,6 +809,34 @@ class ChatbotAppBlockingResponse(AppBlockingResponse):
data: Data
class AdvancedChatPausedBlockingResponse(AppBlockingResponse):
"""
ChatbotAppPausedBlockingResponse entity
"""
class Data(BaseModel):
"""
Data entity
"""
id: str
mode: str
conversation_id: str
message_id: str
workflow_run_id: str
answer: str
metadata: Mapping[str, object] = Field(default_factory=dict)
created_at: int
paused_nodes: Sequence[str] = Field(default_factory=list)
reasons: Sequence[Mapping[str, Any]] = Field(default_factory=list[Mapping[str, Any]])
status: WorkflowExecutionStatus
elapsed_time: float
total_tokens: int
total_steps: int
data: Data
class CompletionAppBlockingResponse(AppBlockingResponse):
"""
CompletionAppBlockingResponse entity
@ -819,6 +882,33 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
data: Data
class WorkflowAppPausedBlockingResponse(AppBlockingResponse):
"""
WorkflowAppPausedBlockingResponse entity
"""
class Data(BaseModel):
"""
Data entity
"""
id: str
workflow_id: str
status: WorkflowExecutionStatus
outputs: Mapping[str, Any] | None = None
error: str | None = None
elapsed_time: float
total_tokens: int
total_steps: int
created_at: int
finished_at: int | None
paused_nodes: Sequence[str] = Field(default_factory=list)
reasons: Sequence[Mapping[str, Any]] = Field(default_factory=list)
workflow_run_id: str
data: Data
class AgentLogStreamResponse(StreamResponse):
"""
AgentLogStreamResponse entity

View File

@ -1,5 +1,6 @@
from __future__ import annotations
from copy import deepcopy
from typing import Any
from core.app.entities.app_invoke_entities import DifyRunContext, ModelConfigWithCredentialsEntity
@ -14,8 +15,21 @@ from graphon.nodes.llm.protocols import CredentialsProvider
class DifyCredentialsProvider:
"""Resolves and returns LLM credentials for a given provider and model.
Fetched credentials are stored in :attr:`credentials_cache` and reused for
subsequent ``fetch`` calls for the same ``(provider_name, model_name)``.
Because of that cache, a single instance can return stale credentials after
the tenant or provider configuration changes (e.g. API key rotation).
Do **not** keep one instance for the lifetime of a process or across
unrelated invocations. Create a new provider per request, workflow run, or
other bounded scope where up-to-date credentials matter.
"""
tenant_id: str
provider_manager: ProviderManager
credentials_cache: dict[tuple[str, str], dict[str, Any]]
def __init__(
self,
@ -30,8 +44,12 @@ class DifyCredentialsProvider:
user_id=run_context.user_id,
)
self.provider_manager = provider_manager
self.credentials_cache = {}
def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
if (provider_name, model_name) in self.credentials_cache:
return deepcopy(self.credentials_cache[(provider_name, model_name)])
provider_configurations = self.provider_manager.get_configurations(self.tenant_id)
provider_configuration = provider_configurations.get(provider_name)
if not provider_configuration:
@ -46,6 +64,7 @@ class DifyCredentialsProvider:
if credentials is None:
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
self.credentials_cache[(provider_name, model_name)] = deepcopy(credentials)
return credentials
@ -65,7 +84,8 @@ class DifyModelFactory:
provider_manager=create_plugin_provider_manager(
tenant_id=run_context.tenant_id,
user_id=run_context.user_id,
)
),
enable_credentials_cache=True,
)
self.model_manager = model_manager
@ -84,7 +104,7 @@ def build_dify_model_access(run_context: DifyRunContext) -> tuple[CredentialsPro
tenant_id=run_context.tenant_id,
user_id=run_context.user_id,
)
model_manager = ModelManager(provider_manager=provider_manager)
model_manager = ModelManager(provider_manager=provider_manager, enable_credentials_cache=True)
return (
DifyCredentialsProvider(run_context=run_context, provider_manager=provider_manager),

View File

@ -0,0 +1,41 @@
"""
Helper module for Creators Platform integration.
Provides functionality to upload DSL files to the Creators Platform
and generate redirect URLs with OAuth authorization codes.
"""
import logging
from urllib.parse import urlencode
import httpx
from yarl import URL
from configs import dify_config
logger = logging.getLogger(__name__)
creators_platform_api_url = URL(str(dify_config.CREATORS_PLATFORM_API_URL))
def upload_dsl(dsl_file_bytes: bytes, filename: str = "template.yaml") -> str:
url = str(creators_platform_api_url / "api/v1/templates/anonymous-upload")
response = httpx.post(url, files={"file": (filename, dsl_file_bytes)}, timeout=30)
response.raise_for_status()
data = response.json()
claim_code = data.get("data", {}).get("claim_code")
if not claim_code:
raise ValueError("Creators Platform did not return a valid claim_code")
return claim_code
def get_redirect_url(user_account_id: str, claim_code: str) -> str:
base_url = str(dify_config.CREATORS_PLATFORM_API_URL).rstrip("/")
params: dict[str, str] = {"dsl_claim_code": claim_code}
client_id = str(dify_config.CREATORS_PLATFORM_OAUTH_CLIENT_ID or "")
if client_id:
from services.oauth_server import OAuthServerService
oauth_code = OAuthServerService.sign_oauth_authorization_code(client_id, user_account_id)
params["oauth_code"] = oauth_code
return f"{base_url}?{urlencode(params)}"

View File

@ -13,8 +13,6 @@ from core.llm_generator.output_parser.rule_config_generator import RuleConfigGen
from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
from core.llm_generator.prompts import (
CONVERSATION_TITLE_PROMPT,
DEFAULT_SUGGESTED_QUESTIONS_MAX_TOKENS,
DEFAULT_SUGGESTED_QUESTIONS_TEMPERATURE,
GENERATOR_QA_PROMPT,
JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE,
LLM_MODIFY_CODE_SYSTEM,
@ -217,8 +215,8 @@ class LLMGenerator:
else:
# Default-model generation keeps the built-in suggested-questions tuning.
model_parameters = {
"max_tokens": DEFAULT_SUGGESTED_QUESTIONS_MAX_TOKENS,
"temperature": DEFAULT_SUGGESTED_QUESTIONS_TEMPERATURE,
"max_tokens": 2560,
"temperature": 0.0,
}
stop = []

View File

@ -10,7 +10,14 @@ logger = logging.getLogger(__name__)
class SuggestedQuestionsAfterAnswerOutputParser:
def __init__(self, instruction_prompt: str | None = None) -> None:
self._instruction_prompt = instruction_prompt or DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
self._instruction_prompt = self._build_instruction_prompt(instruction_prompt)
@staticmethod
def _build_instruction_prompt(instruction_prompt: str | None) -> str:
if not instruction_prompt or not instruction_prompt.strip():
return DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
return f'{instruction_prompt}\nYou must output a JSON array like ["question1", "question2", "question3"].'
def get_format_instructions(self) -> str:
return self._instruction_prompt

View File

@ -104,9 +104,6 @@ DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
'["question1","question2","question3"]\n'
)
DEFAULT_SUGGESTED_QUESTIONS_MAX_TOKENS = 256
DEFAULT_SUGGESTED_QUESTIONS_TEMPERATURE = 0.0
GENERATOR_QA_PROMPT = (
"<Task> The user will send a long text. Generate a Question and Answer pairs only using the knowledge"
" in the long text. Please think step by step."

View File

@ -1,5 +1,6 @@
import logging
from collections.abc import Callable, Generator, Iterable, Mapping, Sequence
from copy import deepcopy
from typing import IO, Any, Literal, Optional, ParamSpec, TypeVar, Union, cast, overload
from configs import dify_config
@ -36,11 +37,13 @@ class ModelInstance:
Model instance class.
"""
def __init__(self, provider_model_bundle: ProviderModelBundle, model: str):
def __init__(self, provider_model_bundle: ProviderModelBundle, model: str, credentials: dict | None = None) -> None:
self.provider_model_bundle = provider_model_bundle
self.model_name = model
self.provider = provider_model_bundle.configuration.provider.provider
self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
if credentials is None:
credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
self.credentials = credentials
# Runtime LLM invocation fields.
self.parameters: Mapping[str, Any] = {}
self.stop: Sequence[str] = ()
@ -434,8 +437,30 @@ class ModelInstance:
class ModelManager:
def __init__(self, provider_manager: ProviderManager):
"""Resolves :class:`ModelInstance` objects for a tenant and provider.
When ``enable_credentials_cache`` is ``True``, resolved credentials for each
``(tenant_id, provider, model_type, model)`` are stored in
``_credentials_cache`` and reused. That can return **stale** credentials after
API keys or provider settings change, so a manager constructed with
``enable_credentials_cache=True`` should not be kept for the lifetime of a
process or shared across unrelated work. Prefer a new manager per request,
workflow run, or similar bounded scope.
The default is ``enable_credentials_cache=False``; in that mode the internal
credential cache is not populated, and each ``get_model_instance`` call
loads credentials from the current provider configuration.
"""
def __init__(
self,
provider_manager: ProviderManager,
*,
enable_credentials_cache: bool = False,
) -> None:
self._provider_manager = provider_manager
self._credentials_cache: dict[tuple[str, str, str, str], Any] = {}
self._enable_credentials_cache = enable_credentials_cache
@classmethod
def for_tenant(cls, tenant_id: str, user_id: str | None = None) -> "ModelManager":
@ -463,8 +488,19 @@ class ModelManager:
tenant_id=tenant_id, provider=provider, model_type=model_type
)
model_instance = ModelInstance(provider_model_bundle, model)
return model_instance
cred_cache_key = (tenant_id, provider, model_type.value, model)
if cred_cache_key in self._credentials_cache:
return ModelInstance(
provider_model_bundle,
model,
deepcopy(self._credentials_cache[cred_cache_key]),
)
ret = ModelInstance(provider_model_bundle, model)
if self._enable_credentials_cache:
self._credentials_cache[cred_cache_key] = deepcopy(ret.credentials)
return ret
def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str | None, str | None]:
"""

View File

@ -156,7 +156,8 @@ class Jieba(BaseKeyword):
if dataset_keyword_table:
keyword_table_dict = dataset_keyword_table.keyword_table_dict
if keyword_table_dict:
return dict(keyword_table_dict["__data__"]["table"])
data: Any = keyword_table_dict["__data__"]
return dict(data["table"])
else:
keyword_data_source_type = dify_config.KEYWORD_DATA_SOURCE_TYPE
dataset_keyword_table = DatasetKeywordTable(

View File

@ -109,7 +109,7 @@ class JiebaKeywordTableHandler:
"""Extract keywords with JIEBA tfidf."""
keywords = self._tfidf.extract_tags(
sentence=text,
topK=max_keywords_per_chunk,
topK=max_keywords_per_chunk or 10,
)
# jieba.analyse.extract_tags returns an untyped list when withFlag is False by default.
keywords = cast(list[str], keywords)

View File

@ -31,7 +31,7 @@ class FunctionCallMultiDatasetRouter:
result: LLMResult = model_instance.invoke_llm( # pyright: ignore[reportCallIssue, reportArgumentType]
prompt_messages=prompt_messages,
tools=dataset_tools,
stream=False,
stream=False, # pyright: ignore[reportArgumentType]
model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500},
)
usage = result.usage or LLMUsage.empty_usage()

View File

@ -14,23 +14,23 @@ from configs import dify_config
logger = logging.getLogger(__name__)
class OAuthEncryptionError(Exception):
"""OAuth encryption/decryption specific error"""
class EncryptionError(Exception):
"""Encryption/decryption specific error"""
pass
class SystemOAuthEncrypter:
class SystemEncrypter:
"""
A simple OAuth parameters encrypter using AES-CBC encryption.
A simple parameters encrypter using AES-CBC encryption.
This class provides methods to encrypt and decrypt OAuth parameters
This class provides methods to encrypt and decrypt parameters
using AES-CBC mode with a key derived from the application's SECRET_KEY.
"""
def __init__(self, secret_key: str | None = None):
"""
Initialize the OAuth encrypter.
Initialize the encrypter.
Args:
secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
@ -43,19 +43,19 @@ class SystemOAuthEncrypter:
# Generate a fixed 256-bit key using SHA-256
self.key = hashlib.sha256(secret_key.encode()).digest()
def encrypt_oauth_params(self, oauth_params: Mapping[str, Any]) -> str:
def encrypt_params(self, params: Mapping[str, Any]) -> str:
"""
Encrypt OAuth parameters.
Encrypt parameters.
Args:
oauth_params: OAuth parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"}
params: Parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"}
Returns:
Base64-encoded encrypted string
Raises:
OAuthEncryptionError: If encryption fails
ValueError: If oauth_params is invalid
EncryptionError: If encryption fails
ValueError: If params is invalid
"""
try:
@ -66,7 +66,7 @@ class SystemOAuthEncrypter:
cipher = AES.new(self.key, AES.MODE_CBC, iv)
# Encrypt data
padded_data = pad(TypeAdapter(dict).dump_json(dict(oauth_params)), AES.block_size)
padded_data = pad(TypeAdapter(dict).dump_json(dict(params)), AES.block_size)
encrypted_data = cipher.encrypt(padded_data)
# Combine IV and encrypted data
@ -76,20 +76,20 @@ class SystemOAuthEncrypter:
return base64.b64encode(combined).decode()
except Exception as e:
raise OAuthEncryptionError(f"Encryption failed: {str(e)}") from e
raise EncryptionError(f"Encryption failed: {str(e)}") from e
def decrypt_oauth_params(self, encrypted_data: str) -> Mapping[str, Any]:
def decrypt_params(self, encrypted_data: str) -> Mapping[str, Any]:
"""
Decrypt OAuth parameters.
Decrypt parameters.
Args:
encrypted_data: Base64-encoded encrypted string
Returns:
Decrypted OAuth parameters dictionary
Decrypted parameters dictionary
Raises:
OAuthEncryptionError: If decryption fails
EncryptionError: If decryption fails
ValueError: If encrypted_data is invalid
"""
if not isinstance(encrypted_data, str):
@ -118,70 +118,70 @@ class SystemOAuthEncrypter:
unpadded_data = unpad(decrypted_data, AES.block_size)
# Parse JSON
oauth_params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data)
params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data)
if not isinstance(oauth_params, dict):
if not isinstance(params, dict):
raise ValueError("Decrypted data is not a valid dictionary")
return oauth_params
return params
except Exception as e:
raise OAuthEncryptionError(f"Decryption failed: {str(e)}") from e
raise EncryptionError(f"Decryption failed: {str(e)}") from e
# Factory function for creating encrypter instances
def create_system_oauth_encrypter(secret_key: str | None = None) -> SystemOAuthEncrypter:
def create_system_encrypter(secret_key: str | None = None) -> SystemEncrypter:
"""
Create an OAuth encrypter instance.
Create an encrypter instance.
Args:
secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
Returns:
SystemOAuthEncrypter instance
SystemEncrypter instance
"""
return SystemOAuthEncrypter(secret_key=secret_key)
return SystemEncrypter(secret_key=secret_key)
# Global encrypter instance (for backward compatibility)
_oauth_encrypter: SystemOAuthEncrypter | None = None
_encrypter: SystemEncrypter | None = None
def get_system_oauth_encrypter() -> SystemOAuthEncrypter:
def get_system_encrypter() -> SystemEncrypter:
"""
Get the global OAuth encrypter instance.
Get the global encrypter instance.
Returns:
SystemOAuthEncrypter instance
SystemEncrypter instance
"""
global _oauth_encrypter
if _oauth_encrypter is None:
_oauth_encrypter = SystemOAuthEncrypter()
return _oauth_encrypter
global _encrypter
if _encrypter is None:
_encrypter = SystemEncrypter()
return _encrypter
# Convenience functions for backward compatibility
def encrypt_system_oauth_params(oauth_params: Mapping[str, Any]) -> str:
def encrypt_system_params(params: Mapping[str, Any]) -> str:
"""
Encrypt OAuth parameters using the global encrypter.
Encrypt parameters using the global encrypter.
Args:
oauth_params: OAuth parameters dictionary
params: Parameters dictionary
Returns:
Base64-encoded encrypted string
"""
return get_system_oauth_encrypter().encrypt_oauth_params(oauth_params)
return get_system_encrypter().encrypt_params(params)
def decrypt_system_oauth_params(encrypted_data: str) -> Mapping[str, Any]:
def decrypt_system_params(encrypted_data: str) -> Mapping[str, Any]:
"""
Decrypt OAuth parameters using the global encrypter.
Decrypt parameters using the global encrypter.
Args:
encrypted_data: Base64-encoded encrypted string
Returns:
Decrypted OAuth parameters dictionary
Decrypted parameters dictionary
"""
return get_system_oauth_encrypter().decrypt_oauth_params(encrypted_data)
return get_system_encrypter().decrypt_params(encrypted_data)

View File

@ -12,20 +12,16 @@ from collections.abc import Sequence
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.workflow.human_input_policy import HumanInputSurface, get_preferred_form_token
from extensions.ext_database import db
from models.human_input import HumanInputFormRecipient, RecipientType
_FORM_TOKEN_PRIORITY = {
RecipientType.BACKSTAGE: 0,
RecipientType.CONSOLE: 1,
RecipientType.STANDALONE_WEB_APP: 2,
}
def load_form_tokens_by_form_id(
form_ids: Sequence[str],
*,
session: Session | None = None,
surface: HumanInputSurface | None = None,
) -> dict[str, str]:
"""Load the preferred access token for each human input form."""
unique_form_ids = list(dict.fromkeys(form_ids))
@ -33,23 +29,43 @@ def load_form_tokens_by_form_id(
return {}
if session is not None:
return _load_form_tokens_by_form_id(session, unique_form_ids)
return _load_form_tokens_by_form_id(session, unique_form_ids, surface=surface)
with Session(bind=db.engine, expire_on_commit=False) as new_session:
return _load_form_tokens_by_form_id(new_session, unique_form_ids)
return _load_form_tokens_by_form_id(new_session, unique_form_ids, surface=surface)
def _load_form_tokens_by_form_id(session: Session, form_ids: Sequence[str]) -> dict[str, str]:
tokens_by_form_id: dict[str, tuple[int, str]] = {}
def _load_form_tokens_by_form_id(
session: Session,
form_ids: Sequence[str],
*,
surface: HumanInputSurface | None = None,
) -> dict[str, str]:
recipients_by_form_id: dict[str, list[tuple[RecipientType, str]]] = {}
stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids))
for recipient in session.scalars(stmt):
priority = _FORM_TOKEN_PRIORITY.get(recipient.recipient_type)
if priority is None or not recipient.access_token:
if not recipient.access_token:
continue
recipients_by_form_id.setdefault(recipient.form_id, []).append(
(recipient.recipient_type, recipient.access_token)
)
candidate = (priority, recipient.access_token)
current = tokens_by_form_id.get(recipient.form_id)
if current is None or candidate[0] < current[0]:
tokens_by_form_id[recipient.form_id] = candidate
tokens_by_form_id: dict[str, str] = {}
for form_id, recipients in recipients_by_form_id.items():
token = _get_surface_form_token(recipients, surface=surface)
if token is not None:
tokens_by_form_id[form_id] = token
return tokens_by_form_id
return {form_id: token for form_id, (_, token) in tokens_by_form_id.items()}
def _get_surface_form_token(
recipients: Sequence[tuple[RecipientType, str]],
*,
surface: HumanInputSurface | None,
) -> str | None:
if surface == HumanInputSurface.SERVICE_API:
for recipient_type, token in recipients:
if recipient_type == RecipientType.STANDALONE_WEB_APP and token:
return token
return get_preferred_form_token(recipients)

View File

@ -0,0 +1,73 @@
from __future__ import annotations
from collections.abc import Mapping, Sequence
from enum import StrEnum
from typing import Any
from graphon.entities.pause_reason import PauseReasonType
from models.human_input import RecipientType
class HumanInputSurface(StrEnum):
SERVICE_API = "service_api"
CONSOLE = "console"
# Service API is intentionally narrower than other surfaces: app-token callers
# should only be able to act on end-user web forms, not internal console flows.
_ALLOWED_RECIPIENT_TYPES_BY_SURFACE: dict[HumanInputSurface, frozenset[RecipientType]] = {
HumanInputSurface.SERVICE_API: frozenset({RecipientType.STANDALONE_WEB_APP}),
HumanInputSurface.CONSOLE: frozenset({RecipientType.CONSOLE, RecipientType.BACKSTAGE}),
}
# A single HITL form can have multiple recipient records; this shared priority
# keeps every API surface consistent about which resume token to expose.
_RECIPIENT_TOKEN_PRIORITY: dict[RecipientType, int] = {
RecipientType.BACKSTAGE: 0,
RecipientType.CONSOLE: 1,
RecipientType.STANDALONE_WEB_APP: 2,
}
def is_recipient_type_allowed_for_surface(
recipient_type: RecipientType | None,
surface: HumanInputSurface,
) -> bool:
if recipient_type is None:
return False
return recipient_type in _ALLOWED_RECIPIENT_TYPES_BY_SURFACE[surface]
def get_preferred_form_token(
recipients: Sequence[tuple[RecipientType, str]],
) -> str | None:
chosen_token: str | None = None
chosen_priority: int | None = None
for recipient_type, token in recipients:
priority = _RECIPIENT_TOKEN_PRIORITY.get(recipient_type)
if priority is None or not token:
continue
if chosen_priority is None or priority < chosen_priority:
chosen_priority = priority
chosen_token = token
return chosen_token
def enrich_human_input_pause_reasons(
reasons: Sequence[Mapping[str, Any]],
*,
form_tokens_by_form_id: Mapping[str, str],
expiration_times_by_form_id: Mapping[str, int],
) -> list[dict[str, Any]]:
enriched: list[dict[str, Any]] = []
for reason in reasons:
updated = dict(reason)
if updated.get("TYPE") == PauseReasonType.HUMAN_INPUT_REQUIRED:
form_id = updated.get("form_id")
if isinstance(form_id, str):
updated["form_token"] = form_tokens_by_form_id.get(form_id)
expiration_time = expiration_times_by_form_id.get(form_id)
if expiration_time is not None:
updated["expiration_time"] = expiration_time
enriched.append(updated)
return enriched

View File

@ -225,8 +225,10 @@ class TestSpanBuilder:
span = builder.build_span(span_data)
assert isinstance(span, ReadableSpan)
assert span.name == "test-span"
assert span.context is not None
assert span.context.trace_id == 123
assert span.context.span_id == 456
assert span.parent is not None
assert span.parent.span_id == 789
assert span.resource == resource
assert span.attributes == {"attr1": "val1"}

View File

@ -64,12 +64,13 @@ class TestSpanData:
def test_span_data_missing_required_fields(self):
with pytest.raises(ValidationError):
SpanData(
trace_id=123,
# span_id missing
name="test_span",
start_time=1000,
end_time=2000,
SpanData.model_validate(
{
"trace_id": 123,
"name": "test_span",
"start_time": 1000,
"end_time": 2000,
}
)
def test_span_data_arbitrary_types_allowed(self):

View File

@ -2,12 +2,14 @@ from __future__ import annotations
from datetime import UTC, datetime
from types import SimpleNamespace
from typing import cast
from unittest.mock import MagicMock
import dify_trace_aliyun.aliyun_trace as aliyun_trace_module
import pytest
from dify_trace_aliyun.aliyun_trace import AliyunDataTrace
from dify_trace_aliyun.config import AliyunConfig
from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData, TraceMetadata
from dify_trace_aliyun.entities.semconv import (
GEN_AI_COMPLETION,
GEN_AI_INPUT_MESSAGE,
@ -44,7 +46,7 @@ class RecordingTraceClient:
self.endpoint = endpoint
self.added_spans: list[object] = []
def add_span(self, span) -> None:
def add_span(self, span: object) -> None:
self.added_spans.append(span)
def api_check(self) -> bool:
@ -63,11 +65,35 @@ def _make_link(trace_id: int = 1, span_id: int = 2) -> Link:
trace_id=trace_id,
span_id=span_id,
is_remote=False,
trace_flags=TraceFlags.SAMPLED,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
)
return Link(context)
def _make_trace_metadata(
trace_id: int = 1,
workflow_span_id: int = 2,
session_id: str = "s",
user_id: str = "u",
links: list[Link] | None = None,
) -> TraceMetadata:
return TraceMetadata(
trace_id=trace_id,
workflow_span_id=workflow_span_id,
session_id=session_id,
user_id=user_id,
links=[] if links is None else links,
)
def _recording_trace_client(trace_instance: AliyunDataTrace) -> RecordingTraceClient:
return cast(RecordingTraceClient, trace_instance.trace_client)
def _recorded_span_data(trace_instance: AliyunDataTrace) -> list[SpanData]:
return cast(list[SpanData], _recording_trace_client(trace_instance).added_spans)
def _make_workflow_trace_info(**overrides) -> WorkflowTraceInfo:
defaults = {
"workflow_id": "workflow-id",
@ -263,20 +289,20 @@ def test_workflow_trace_adds_workflow_and_node_spans(trace_instance: AliyunDataT
trace_instance.workflow_trace(trace_info)
add_workflow_span.assert_called_once()
passed_trace_metadata = add_workflow_span.call_args.args[1]
passed_trace_metadata = cast(TraceMetadata, add_workflow_span.call_args.args[1])
assert passed_trace_metadata.trace_id == 111
assert passed_trace_metadata.workflow_span_id == 222
assert passed_trace_metadata.session_id == "c"
assert passed_trace_metadata.user_id == "u"
assert passed_trace_metadata.links == []
assert trace_instance.trace_client.added_spans == ["span-1", "span-2"]
assert _recording_trace_client(trace_instance).added_spans == ["span-1", "span-2"]
def test_message_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace):
trace_info = _make_message_trace_info(message_data=None)
trace_instance.message_trace(trace_info)
assert trace_instance.trace_client.added_spans == []
assert _recording_trace_client(trace_instance).added_spans == []
def test_message_trace_creates_message_and_llm_spans(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
@ -302,8 +328,9 @@ def test_message_trace_creates_message_and_llm_spans(trace_instance: AliyunDataT
)
trace_instance.message_trace(trace_info)
assert len(trace_instance.trace_client.added_spans) == 2
message_span, llm_span = trace_instance.trace_client.added_spans
spans = _recorded_span_data(trace_instance)
assert len(spans) == 2
message_span, llm_span = spans
assert message_span.name == "message"
assert message_span.trace_id == 10
@ -324,7 +351,7 @@ def test_message_trace_creates_message_and_llm_spans(trace_instance: AliyunDataT
def test_dataset_retrieval_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace):
trace_info = _make_dataset_retrieval_trace_info(message_data=None)
trace_instance.dataset_retrieval_trace(trace_info)
assert trace_instance.trace_client.added_spans == []
assert _recording_trace_client(trace_instance).added_spans == []
def test_dataset_retrieval_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
@ -338,8 +365,9 @@ def test_dataset_retrieval_trace_creates_span(trace_instance: AliyunDataTrace, m
monkeypatch.setattr(aliyun_trace_module, "extract_retrieval_documents", lambda _: [{"doc": "d"}])
trace_instance.dataset_retrieval_trace(_make_dataset_retrieval_trace_info(inputs="query"))
assert len(trace_instance.trace_client.added_spans) == 1
span = trace_instance.trace_client.added_spans[0]
spans = _recorded_span_data(trace_instance)
assert len(spans) == 1
span = spans[0]
assert span.name == "dataset_retrieval"
assert span.attributes[RETRIEVAL_QUERY] == "query"
assert span.attributes[RETRIEVAL_DOCUMENT] == '[{"doc": "d"}]'
@ -348,7 +376,7 @@ def test_dataset_retrieval_trace_creates_span(trace_instance: AliyunDataTrace, m
def test_tool_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace):
trace_info = _make_tool_trace_info(message_data=None)
trace_instance.tool_trace(trace_info)
assert trace_instance.trace_client.added_spans == []
assert _recording_trace_client(trace_instance).added_spans == []
def test_tool_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
@ -371,8 +399,9 @@ def test_tool_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: p
)
)
assert len(trace_instance.trace_client.added_spans) == 1
span = trace_instance.trace_client.added_spans[0]
spans = _recorded_span_data(trace_instance)
assert len(spans) == 1
span = spans[0]
assert span.name == "my-tool"
assert span.status == status
assert span.attributes[TOOL_NAME] == "my-tool"
@ -409,7 +438,7 @@ def test_get_workflow_node_executions_builds_repo_and_fetches(
def test_build_workflow_node_span_routes_llm_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
node_execution = MagicMock(spec=WorkflowNodeExecution)
trace_info = _make_workflow_trace_info()
trace_metadata = MagicMock()
trace_metadata = _make_trace_metadata()
monkeypatch.setattr(trace_instance, "build_workflow_llm_span", MagicMock(return_value="llm"))
@ -422,7 +451,7 @@ def test_build_workflow_node_span_routes_knowledge_retrieval_type(
):
node_execution = MagicMock(spec=WorkflowNodeExecution)
trace_info = _make_workflow_trace_info()
trace_metadata = MagicMock()
trace_metadata = _make_trace_metadata()
monkeypatch.setattr(trace_instance, "build_workflow_retrieval_span", MagicMock(return_value="retrieval"))
@ -433,7 +462,7 @@ def test_build_workflow_node_span_routes_knowledge_retrieval_type(
def test_build_workflow_node_span_routes_tool_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
node_execution = MagicMock(spec=WorkflowNodeExecution)
trace_info = _make_workflow_trace_info()
trace_metadata = MagicMock()
trace_metadata = _make_trace_metadata()
monkeypatch.setattr(trace_instance, "build_workflow_tool_span", MagicMock(return_value="tool"))
@ -444,7 +473,7 @@ def test_build_workflow_node_span_routes_tool_type(trace_instance: AliyunDataTra
def test_build_workflow_node_span_routes_code_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
node_execution = MagicMock(spec=WorkflowNodeExecution)
trace_info = _make_workflow_trace_info()
trace_metadata = MagicMock()
trace_metadata = _make_trace_metadata()
monkeypatch.setattr(trace_instance, "build_workflow_task_span", MagicMock(return_value="task"))
@ -457,7 +486,7 @@ def test_build_workflow_node_span_handles_errors(
):
node_execution = MagicMock(spec=WorkflowNodeExecution)
trace_info = _make_workflow_trace_info()
trace_metadata = MagicMock()
trace_metadata = _make_trace_metadata()
monkeypatch.setattr(trace_instance, "build_workflow_task_span", MagicMock(side_effect=RuntimeError("boom")))
node_execution.node_type = BuiltinNodeTypes.CODE
@ -472,7 +501,7 @@ def test_build_workflow_task_span(trace_instance: AliyunDataTrace, monkeypatch:
status = Status(StatusCode.OK)
monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status)
trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[])
trace_metadata = _make_trace_metadata()
node_execution = MagicMock(spec=WorkflowNodeExecution)
node_execution.id = "node-id"
node_execution.title = "title"
@ -494,7 +523,7 @@ def test_build_workflow_tool_span(trace_instance: AliyunDataTrace, monkeypatch:
status = Status(StatusCode.OK)
monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status)
trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[_make_link()])
trace_metadata = _make_trace_metadata(links=[_make_link()])
node_execution = MagicMock(spec=WorkflowNodeExecution)
node_execution.id = "node-id"
node_execution.title = "my-tool"
@ -527,7 +556,7 @@ def test_build_workflow_retrieval_span(trace_instance: AliyunDataTrace, monkeypa
aliyun_trace_module, "format_retrieval_documents", lambda docs: [{"formatted": True}] if docs else []
)
trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[])
trace_metadata = _make_trace_metadata()
node_execution = MagicMock(spec=WorkflowNodeExecution)
node_execution.id = "node-id"
node_execution.title = "retrieval"
@ -556,7 +585,7 @@ def test_build_workflow_llm_span(trace_instance: AliyunDataTrace, monkeypatch: p
monkeypatch.setattr(aliyun_trace_module, "format_input_messages", lambda _: "in")
monkeypatch.setattr(aliyun_trace_module, "format_output_messages", lambda _: "out")
trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[])
trace_metadata = _make_trace_metadata()
node_execution = MagicMock(spec=WorkflowNodeExecution)
node_execution.id = "node-id"
node_execution.title = "llm"
@ -594,7 +623,7 @@ def test_add_workflow_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.
status = Status(StatusCode.OK)
monkeypatch.setattr(aliyun_trace_module, "create_status_from_error", lambda _: status)
trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[])
trace_metadata = _make_trace_metadata()
# CASE 1: With message_id
trace_info = _make_workflow_trace_info(
@ -602,9 +631,11 @@ def test_add_workflow_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.
)
trace_instance.add_workflow_span(trace_info, trace_metadata)
assert len(trace_instance.trace_client.added_spans) == 2
message_span = trace_instance.trace_client.added_spans[0]
workflow_span = trace_instance.trace_client.added_spans[1]
client = _recording_trace_client(trace_instance)
spans = _recorded_span_data(trace_instance)
assert len(spans) == 2
message_span = spans[0]
workflow_span = spans[1]
assert message_span.name == "message"
assert message_span.span_kind == SpanKind.SERVER
@ -614,13 +645,14 @@ def test_add_workflow_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.
assert workflow_span.span_kind == SpanKind.INTERNAL
assert workflow_span.parent_span_id == 20
trace_instance.trace_client.added_spans.clear()
client.added_spans.clear()
# CASE 2: Without message_id
trace_info_no_msg = _make_workflow_trace_info(message_id=None)
trace_instance.add_workflow_span(trace_info_no_msg, trace_metadata)
assert len(trace_instance.trace_client.added_spans) == 1
span = trace_instance.trace_client.added_spans[0]
spans = _recorded_span_data(trace_instance)
assert len(spans) == 1
span = spans[0]
assert span.name == "workflow"
assert span.span_kind == SpanKind.SERVER
assert span.parent_span_id is None
@ -641,7 +673,8 @@ def test_suggested_question_trace(trace_instance: AliyunDataTrace, monkeypatch:
trace_info = _make_suggested_question_trace_info(suggested_question=["how?"])
trace_instance.suggested_question_trace(trace_info)
assert len(trace_instance.trace_client.added_spans) == 1
span = trace_instance.trace_client.added_spans[0]
spans = _recorded_span_data(trace_instance)
assert len(spans) == 1
span = spans[0]
assert span.name == "suggested_question"
assert span.attributes[GEN_AI_COMPLETION] == '["how?"]'

View File

@ -1,4 +1,6 @@
import json
from collections.abc import Mapping
from typing import Any, cast
from unittest.mock import MagicMock
from dify_trace_aliyun.entities.semconv import (
@ -170,7 +172,7 @@ def test_create_common_span_attributes():
def test_format_retrieval_documents():
# Not a list
assert format_retrieval_documents("not a list") == []
assert format_retrieval_documents(cast(list[object], "not a list")) == []
# Valid list
docs = [
@ -211,7 +213,7 @@ def test_format_retrieval_documents():
def test_format_input_messages():
# Not a dict
assert format_input_messages(None) == serialize_json_data([])
assert format_input_messages(cast(Mapping[str, Any], None)) == serialize_json_data([])
# No prompts
assert format_input_messages({}) == serialize_json_data([])
@ -244,7 +246,7 @@ def test_format_input_messages():
def test_format_output_messages():
# Not a dict
assert format_output_messages(None) == serialize_json_data([])
assert format_output_messages(cast(Mapping[str, Any], None)) == serialize_json_data([])
# No text
assert format_output_messages({"finish_reason": "stop"}) == serialize_json_data([])

View File

@ -25,13 +25,13 @@ class TestAliyunConfig:
def test_missing_required_fields(self):
"""Test that required fields are enforced"""
with pytest.raises(ValidationError):
AliyunConfig()
AliyunConfig.model_validate({})
with pytest.raises(ValidationError):
AliyunConfig(license_key="test_license")
AliyunConfig.model_validate({"license_key": "test_license"})
with pytest.raises(ValidationError):
AliyunConfig(endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")
AliyunConfig.model_validate({"endpoint": "https://tracing-analysis-dc-hz.aliyuncs.com"})
def test_app_name_validation_empty(self):
"""Test app_name validation with empty value"""

View File

@ -1,4 +1,5 @@
from datetime import UTC, datetime, timedelta
from typing import cast
from unittest.mock import MagicMock, patch
import pytest
@ -129,7 +130,7 @@ def test_set_span_status():
return "SilentErrorRepr"
span.reset_mock()
set_span_status(span, SilentError())
set_span_status(span, cast(Exception | str | None, SilentError()))
assert span.add_event.call_args[1]["attributes"][OTELSpanAttributes.EXCEPTION_MESSAGE] == "SilentErrorRepr"

View File

@ -28,13 +28,13 @@ class TestLangfuseConfig:
def test_missing_required_fields(self):
"""Test that required fields are enforced"""
with pytest.raises(ValidationError):
LangfuseConfig()
LangfuseConfig.model_validate({})
with pytest.raises(ValidationError):
LangfuseConfig(public_key="public")
LangfuseConfig.model_validate({"public_key": "public"})
with pytest.raises(ValidationError):
LangfuseConfig(secret_key="secret")
LangfuseConfig.model_validate({"secret_key": "secret"})
def test_host_validation_empty(self):
"""Test host validation with empty value"""

View File

@ -2,6 +2,7 @@
from datetime import datetime, timedelta
from types import SimpleNamespace
from typing import cast
from unittest.mock import MagicMock, patch
from dify_trace_langfuse.config import LangfuseConfig
@ -134,4 +135,4 @@ class TestLangFuseDataTraceCompletionStartTime:
assert trace._get_completion_start_time(start_time, None) is None
assert trace._get_completion_start_time(start_time, -1) is None
assert trace._get_completion_start_time(start_time, "invalid") is None
assert trace._get_completion_start_time(start_time, cast(float | int | None, "invalid")) is None

View File

@ -21,13 +21,13 @@ class TestLangSmithConfig:
def test_missing_required_fields(self):
"""Test that required fields are enforced"""
with pytest.raises(ValidationError):
LangSmithConfig()
LangSmithConfig.model_validate({})
with pytest.raises(ValidationError):
LangSmithConfig(api_key="key")
LangSmithConfig.model_validate({"api_key": "key"})
with pytest.raises(ValidationError):
LangSmithConfig(project="project")
LangSmithConfig.model_validate({"project": "project"})
def test_endpoint_validation_https_only(self):
"""Test endpoint validation only allows HTTPS"""

View File

@ -599,7 +599,6 @@ class TestMessageTrace:
span = MagicMock()
mock_tracing["start"].return_value = span
mock_tracing["set"].return_value = "token"
mock_db.session.query.return_value.where.return_value.first.return_value = None
trace_instance.message_trace(_make_message_trace_info())
mock_tracing["start"].assert_called_once()
@ -609,7 +608,6 @@ class TestMessageTrace:
span = MagicMock()
mock_tracing["start"].return_value = span
mock_tracing["set"].return_value = "token"
mock_db.session.query.return_value.where.return_value.first.return_value = None
trace_info = _make_message_trace_info(error="something broke")
trace_instance.message_trace(trace_info)
@ -620,7 +618,6 @@ class TestMessageTrace:
span = MagicMock()
mock_tracing["start"].return_value = span
mock_tracing["set"].return_value = "token"
mock_db.session.query.return_value.where.return_value.first.return_value = None
monkeypatch.setenv("FILES_URL", "http://files.test")
file_data = SimpleNamespace(url="path/to/file.png")
@ -638,7 +635,6 @@ class TestMessageTrace:
span = MagicMock()
mock_tracing["start"].return_value = span
mock_tracing["set"].return_value = "token"
mock_db.session.query.return_value.where.return_value.first.return_value = None
trace_info = _make_message_trace_info(file_list=None, message_file_data=None)
trace_instance.message_trace(trace_info)
@ -651,7 +647,6 @@ class TestMessageTrace:
end_user = MagicMock()
end_user.session_id = "session-xyz"
mock_db.session.query.return_value.where.return_value.first.return_value = end_user
trace_info = _make_message_trace_info(
metadata={"from_end_user_id": "eu-1", "conversation_id": "c1"},
@ -664,7 +659,6 @@ class TestMessageTrace:
span = MagicMock()
mock_tracing["start"].return_value = span
mock_tracing["set"].return_value = "token"
mock_db.session.query.return_value.where.return_value.first.return_value = None
trace_info = _make_message_trace_info(
metadata={"from_account_id": "acc-1"},

View File

@ -12,6 +12,7 @@ from __future__ import annotations
import uuid
from datetime import datetime
from typing import cast
from unittest.mock import MagicMock, patch
from dify_trace_opik.opik_trace import OpikDataTrace, _seed_to_uuid4, prepare_opik_uuid
@ -69,6 +70,14 @@ def _make_opik_trace_instance() -> OpikDataTrace:
return instance
def _add_trace_mock(instance: OpikDataTrace) -> MagicMock:
return cast(MagicMock, instance.add_trace)
def _add_span_mock(instance: OpikDataTrace) -> MagicMock:
return cast(MagicMock, instance.add_span)
# ---------------------------------------------------------------------------
# _seed_to_uuid4
# ---------------------------------------------------------------------------
@ -155,21 +164,21 @@ class TestWorkflowTraceWithoutMessageId:
def test_root_span_is_created(self):
trace_info = _make_workflow_trace_info(message_id=None)
instance = self._run(trace_info)
assert instance.add_span.called
assert _add_span_mock(instance).called
def test_root_span_id_matches_expected(self):
trace_info = _make_workflow_trace_info(message_id=None)
instance = self._run(trace_info)
expected = self._expected_root_span_id(trace_info)
root_span_kwargs = instance.add_span.call_args_list[0][0][0]
root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0]
assert root_span_kwargs["id"] == expected
def test_root_span_has_no_parent(self):
trace_info = _make_workflow_trace_info(message_id=None)
instance = self._run(trace_info)
root_span_kwargs = instance.add_span.call_args_list[0][0][0]
root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0]
assert root_span_kwargs["parent_span_id"] is None
def test_trace_name_is_workflow_trace(self):
@ -177,21 +186,21 @@ class TestWorkflowTraceWithoutMessageId:
trace_info = _make_workflow_trace_info(message_id=None)
instance = self._run(trace_info)
trace_kwargs = instance.add_trace.call_args_list[0][0][0]
trace_kwargs = _add_trace_mock(instance).call_args_list[0][0][0]
assert trace_kwargs["name"] == TraceTaskName.WORKFLOW_TRACE
def test_root_span_name_is_workflow_trace(self):
trace_info = _make_workflow_trace_info(message_id=None)
instance = self._run(trace_info)
root_span_kwargs = instance.add_span.call_args_list[0][0][0]
root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0]
assert root_span_kwargs["name"] == TraceTaskName.WORKFLOW_TRACE
def test_root_span_has_workflow_tag(self):
trace_info = _make_workflow_trace_info(message_id=None)
instance = self._run(trace_info)
root_span_kwargs = instance.add_span.call_args_list[0][0][0]
root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0]
assert "workflow" in root_span_kwargs["tags"]
def test_node_execution_spans_are_parented_to_root(self):
@ -214,8 +223,9 @@ class TestWorkflowTraceWithoutMessageId:
instance = self._run(trace_info, node_executions=[node_exec])
# call_args_list[0] = root span, [1] = node execution span
assert instance.add_span.call_count == 2
node_span_kwargs = instance.add_span.call_args_list[1][0][0]
add_span = _add_span_mock(instance)
assert add_span.call_count == 2
node_span_kwargs = add_span.call_args_list[1][0][0]
assert node_span_kwargs["parent_span_id"] == expected_root_span_id
def test_node_span_not_parented_to_workflow_app_log_id(self):
@ -240,7 +250,7 @@ class TestWorkflowTraceWithoutMessageId:
instance = self._run(trace_info, node_executions=[node_exec])
old_parent_id = prepare_opik_uuid(trace_info.start_time, trace_info.workflow_app_log_id)
node_span_kwargs = instance.add_span.call_args_list[1][0][0]
node_span_kwargs = _add_span_mock(instance).call_args_list[1][0][0]
assert node_span_kwargs["parent_span_id"] != old_parent_id
def test_root_span_id_differs_from_trace_id(self):
@ -283,7 +293,7 @@ class TestWorkflowTraceWithMessageId:
trace_info = _make_workflow_trace_info(message_id=self._MESSAGE_ID)
instance = self._run(trace_info)
trace_kwargs = instance.add_trace.call_args_list[0][0][0]
trace_kwargs = _add_trace_mock(instance).call_args_list[0][0][0]
assert trace_kwargs["name"] == TraceTaskName.MESSAGE_TRACE
def test_root_span_uses_workflow_run_id_directly(self):
@ -292,7 +302,7 @@ class TestWorkflowTraceWithMessageId:
instance = self._run(trace_info)
expected_root_span_id = prepare_opik_uuid(trace_info.start_time, trace_info.workflow_run_id)
root_span_kwargs = instance.add_span.call_args_list[0][0][0]
root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0]
assert root_span_kwargs["id"] == expected_root_span_id
def test_root_span_id_differs_from_no_message_id_case(self):
@ -326,5 +336,5 @@ class TestWorkflowTraceWithMessageId:
instance = self._run(trace_info, node_executions=[node_exec])
node_span_kwargs = instance.add_span.call_args_list[1][0][0]
node_span_kwargs = _add_span_mock(instance).call_args_list[1][0][0]
assert node_span_kwargs["parent_span_id"] == expected_root_span_id

View File

@ -5,6 +5,7 @@ from __future__ import annotations
import sys
import types
from types import SimpleNamespace
from typing import Any, TypedDict, cast
from unittest.mock import MagicMock
import pytest
@ -12,7 +13,7 @@ from dify_trace_tencent import client as client_module
from dify_trace_tencent.client import TencentTraceClient, _get_opentelemetry_sdk_version
from dify_trace_tencent.entities.tencent_trace_entity import SpanData
from opentelemetry.sdk.trace import Event
from opentelemetry.trace import Status, StatusCode
from opentelemetry.trace import SpanContext, Status, StatusCode, TraceFlags
metric_reader_instances: list[DummyMetricReader] = []
meter_provider_instances: list[DummyMeterProvider] = []
@ -80,6 +81,16 @@ class DummyJsonMetricExporterNoTemporality:
self.kwargs = kwargs
class PatchedCoreComponents(TypedDict):
span_exporter: MagicMock
span_processor: MagicMock
tracer: MagicMock
span: MagicMock
tracer_provider: MagicMock
logger: MagicMock
trace_api: Any
def _add_stub_modules(monkeypatch: pytest.MonkeyPatch) -> None:
"""Drop fake metric modules into sys.modules so the client imports resolve."""
@ -118,7 +129,7 @@ def stub_metric_modules(monkeypatch: pytest.MonkeyPatch) -> None:
@pytest.fixture(autouse=True)
def patch_core_components(monkeypatch: pytest.MonkeyPatch) -> dict[str, object]:
def patch_core_components(monkeypatch: pytest.MonkeyPatch) -> PatchedCoreComponents:
span_exporter = MagicMock(name="span_exporter")
monkeypatch.setattr(client_module, "OTLPSpanExporter", MagicMock(return_value=span_exporter))
@ -168,6 +179,15 @@ def patch_core_components(monkeypatch: pytest.MonkeyPatch) -> dict[str, object]:
}
def _make_span_context(trace_id: int = 1, span_id: int = 2) -> SpanContext:
return SpanContext(
trace_id=trace_id,
span_id=span_id,
is_remote=False,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
)
def _build_client() -> TencentTraceClient:
return TencentTraceClient(
service_name="service",
@ -208,7 +228,7 @@ def test_resolve_grpc_target_parsable_variants(endpoint: str, expected: tuple[st
def test_resolve_grpc_target_handles_errors() -> None:
assert TencentTraceClient._resolve_grpc_target(123) == ("localhost:4317", True, "localhost", 4317)
assert TencentTraceClient._resolve_grpc_target(cast(str, 123)) == ("localhost:4317", True, "localhost", 4317)
@pytest.mark.parametrize(
@ -248,7 +268,7 @@ def test_record_methods_skip_when_histogram_missing() -> None:
client.record_trace_duration(0.5)
def test_record_llm_duration_handles_exceptions(patch_core_components: dict[str, object]) -> None:
def test_record_llm_duration_handles_exceptions(patch_core_components: PatchedCoreComponents) -> None:
client = _build_client()
client.hist_llm_duration = MagicMock(name="hist_llm_duration")
client.hist_llm_duration.record.side_effect = RuntimeError("boom")
@ -258,10 +278,11 @@ def test_record_llm_duration_handles_exceptions(patch_core_components: dict[str,
logger.debug.assert_called()
def test_create_and_export_span_sets_attributes(patch_core_components: dict[str, object]) -> None:
def test_create_and_export_span_sets_attributes(patch_core_components: PatchedCoreComponents) -> None:
client = _build_client()
span = patch_core_components["span"]
span.get_span_context.return_value = "ctx"
ctx = _make_span_context(span_id=2)
span.get_span_context.return_value = ctx
data = SpanData(
trace_id=1,
@ -280,14 +301,15 @@ def test_create_and_export_span_sets_attributes(patch_core_components: dict[str,
span.add_event.assert_called_once()
span.set_status.assert_called_once()
span.end.assert_called_once_with(end_time=20)
assert client.span_contexts[2] == "ctx"
assert client.span_contexts[2] == ctx
def test_create_and_export_span_uses_parent_context(patch_core_components: dict[str, object]) -> None:
def test_create_and_export_span_uses_parent_context(patch_core_components: PatchedCoreComponents) -> None:
client = _build_client()
client.span_contexts[10] = "existing"
existing_context = _make_span_context(span_id=10)
client.span_contexts[10] = existing_context
span = patch_core_components["span"]
span.get_span_context.return_value = "child"
span.get_span_context.return_value = _make_span_context(span_id=11)
data = SpanData(
trace_id=1,
@ -302,14 +324,14 @@ def test_create_and_export_span_uses_parent_context(patch_core_components: dict[
client._create_and_export_span(data)
trace_api = patch_core_components["trace_api"]
trace_api.NonRecordingSpan.assert_called_once_with("existing")
trace_api.NonRecordingSpan.assert_called_once_with(existing_context)
trace_api.set_span_in_context.assert_called_once()
def test_create_and_export_span_exception_logs_error(patch_core_components: dict[str, object]) -> None:
def test_create_and_export_span_exception_logs_error(patch_core_components: PatchedCoreComponents) -> None:
client = _build_client()
span = patch_core_components["span"]
span.get_span_context.return_value = "ctx"
span.get_span_context.return_value = _make_span_context(span_id=2)
client.tracer.start_span.side_effect = RuntimeError("boom")
client._create_and_export_span(
@ -385,7 +407,7 @@ def test_get_project_url() -> None:
assert client.get_project_url() == "https://console.cloud.tencent.com/apm"
def test_shutdown_flushes_all_components(patch_core_components: dict[str, object]) -> None:
def test_shutdown_flushes_all_components(patch_core_components: PatchedCoreComponents) -> None:
client = _build_client()
span_processor = patch_core_components["span_processor"]
tracer_provider = patch_core_components["tracer_provider"]
@ -401,10 +423,11 @@ def test_shutdown_flushes_all_components(patch_core_components: dict[str, object
metric_reader.shutdown.assert_called_once()
def test_shutdown_logs_when_meter_provider_fails(patch_core_components: dict[str, object]) -> None:
def test_shutdown_logs_when_meter_provider_fails(patch_core_components: PatchedCoreComponents) -> None:
client = _build_client()
meter_provider = meter_provider_instances[-1]
meter_provider.shutdown.side_effect = RuntimeError("boom")
assert client.metric_reader is not None
client.metric_reader.shutdown.side_effect = RuntimeError("boom")
client.shutdown()
@ -433,7 +456,7 @@ def test_metrics_initialization_failure_sets_histogram_attributes(monkeypatch: p
assert client.metric_reader is None
def test_add_span_logs_exception(monkeypatch: pytest.MonkeyPatch, patch_core_components: dict[str, object]) -> None:
def test_add_span_logs_exception(monkeypatch: pytest.MonkeyPatch, patch_core_components: PatchedCoreComponents) -> None:
client = _build_client()
monkeypatch.setattr(client, "_create_and_export_span", MagicMock(side_effect=RuntimeError("boom")))
@ -454,10 +477,10 @@ def test_add_span_logs_exception(monkeypatch: pytest.MonkeyPatch, patch_core_com
logger.exception.assert_called_once()
def test_create_and_export_span_converts_attribute_types(patch_core_components: dict[str, object]) -> None:
def test_create_and_export_span_converts_attribute_types(patch_core_components: PatchedCoreComponents) -> None:
client = _build_client()
span = patch_core_components["span"]
span.get_span_context.return_value = "ctx"
span.get_span_context.return_value = _make_span_context(span_id=2)
data = SpanData.model_construct(
trace_id=1,
@ -485,7 +508,7 @@ def test_record_llm_duration_converts_attributes() -> None:
hist_mock = MagicMock(name="hist_llm_duration")
client.hist_llm_duration = hist_mock
client.record_llm_duration(0.3, {"foo": object(), "bar": 2})
client.record_llm_duration(0.3, cast(dict[str, str], {"foo": object(), "bar": 2}))
_, attrs = hist_mock.record.call_args.args
assert isinstance(attrs["foo"], str)
assert attrs["bar"] == 2
@ -496,7 +519,7 @@ def test_record_trace_duration_converts_attributes() -> None:
hist_mock = MagicMock(name="hist_trace_duration")
client.hist_trace_duration = hist_mock
client.record_trace_duration(1.0, {"meta": object(), "ok": True})
client.record_trace_duration(1.0, cast(dict[str, str], {"meta": object(), "ok": True}))
_, attrs = hist_mock.record.call_args.args
assert isinstance(attrs["meta"], str)
assert attrs["ok"] is True
@ -512,7 +535,7 @@ def test_record_trace_duration_converts_attributes() -> None:
],
)
def test_record_methods_handle_exceptions(
method: str, attr_name: str, args: tuple[object, ...], patch_core_components: dict[str, object]
method: str, attr_name: str, args: tuple[object, ...], patch_core_components: PatchedCoreComponents
) -> None:
client = _build_client()
hist_mock = MagicMock(name=attr_name)
@ -527,35 +550,38 @@ def test_record_methods_handle_exceptions(
def test_metrics_initializes_grpc_metric_exporter() -> None:
client = _build_client()
metric_reader = metric_reader_instances[-1]
exporter = cast(DummyGrpcMetricExporter, metric_reader.exporter)
assert isinstance(metric_reader.exporter, DummyGrpcMetricExporter)
assert isinstance(exporter, DummyGrpcMetricExporter)
assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000
assert metric_reader.exporter.kwargs["endpoint"] == "trace.example.com:4317"
assert metric_reader.exporter.kwargs["insecure"] is False
assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token"
assert exporter.kwargs["endpoint"] == "trace.example.com:4317"
assert exporter.kwargs["insecure"] is False
assert cast(dict[str, dict[str, str]], exporter.kwargs)["headers"]["authorization"] == "Bearer token"
def test_metrics_initializes_http_protobuf_metric_exporter(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/protobuf")
client = _build_client()
metric_reader = metric_reader_instances[-1]
exporter = cast(DummyHttpMetricExporter, metric_reader.exporter)
assert isinstance(metric_reader.exporter, DummyHttpMetricExporter)
assert isinstance(exporter, DummyHttpMetricExporter)
assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000
assert metric_reader.exporter.kwargs["endpoint"] == client.endpoint
assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token"
assert exporter.kwargs["endpoint"] == client.endpoint
assert cast(dict[str, dict[str, str]], exporter.kwargs)["headers"]["authorization"] == "Bearer token"
def test_metrics_initializes_http_json_metric_exporter(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/json")
client = _build_client()
metric_reader = metric_reader_instances[-1]
exporter = cast(DummyJsonMetricExporter, metric_reader.exporter)
assert isinstance(metric_reader.exporter, DummyJsonMetricExporter)
assert isinstance(exporter, DummyJsonMetricExporter)
assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000
assert metric_reader.exporter.kwargs["endpoint"] == client.endpoint
assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token"
assert "preferred_temporality" in metric_reader.exporter.kwargs
assert exporter.kwargs["endpoint"] == client.endpoint
assert cast(dict[str, dict[str, str]], exporter.kwargs)["headers"]["authorization"] == "Bearer token"
assert "preferred_temporality" in exporter.kwargs
def test_metrics_http_json_metric_exporter_falls_back_without_temporality(monkeypatch: pytest.MonkeyPatch) -> None:
@ -564,9 +590,10 @@ def test_metrics_http_json_metric_exporter_falls_back_without_temporality(monkey
monkeypatch.setattr(exporter_module, "OTLPMetricExporter", DummyJsonMetricExporterNoTemporality)
_ = _build_client()
metric_reader = metric_reader_instances[-1]
exporter = cast(DummyJsonMetricExporterNoTemporality, metric_reader.exporter)
assert isinstance(metric_reader.exporter, DummyJsonMetricExporterNoTemporality)
assert "preferred_temporality" not in metric_reader.exporter.kwargs
assert isinstance(exporter, DummyJsonMetricExporterNoTemporality)
assert "preferred_temporality" not in exporter.kwargs
def test_metrics_http_json_uses_http_fallback_when_no_json_exporter(monkeypatch: pytest.MonkeyPatch) -> None:

View File

@ -31,13 +31,13 @@ class TestWeaveConfig:
def test_missing_required_fields(self):
"""Test that required fields are enforced"""
with pytest.raises(ValidationError):
WeaveConfig()
WeaveConfig.model_validate({})
with pytest.raises(ValidationError):
WeaveConfig(api_key="key")
WeaveConfig.model_validate({"api_key": "key"})
with pytest.raises(ValidationError):
WeaveConfig(project="project")
WeaveConfig.model_validate({"project": "project"})
def test_endpoint_validation_https_only(self):
"""Test endpoint validation only allows HTTPS"""

View File

@ -59,7 +59,7 @@ class CouchbaseVector(BaseVector):
auth = PasswordAuthenticator(config.user, config.password)
options = ClusterOptions(auth)
self._cluster = Cluster(config.connection_string, options)
self._cluster = Cluster(config.connection_string, options) # pyright: ignore[reportArgumentType]
self._bucket = self._cluster.bucket(config.bucket_name)
self._scope = self._bucket.scope(config.scope_name)
self._bucket_name = config.bucket_name
@ -306,7 +306,7 @@ class CouchbaseVector(BaseVector):
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 4)
try:
CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query))
CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query)) # pyright: ignore[reportCallIssue]
search_iter = self._scope.search(
self._collection_name + "_search", CBrequest, SearchOptions(limit=top_k, fields=["*"])
)

View File

@ -1,6 +1,6 @@
import json
import logging
from typing import Any, TypedDict
from typing import Any, TypedDict, cast
from packaging import version
from pydantic import BaseModel, model_validator
@ -92,7 +92,7 @@ class MilvusVector(BaseVector):
def _load_collection_fields(self, fields: list[str] | None = None):
if fields is None:
# Load collection fields from remote server
collection_info = self._client.describe_collection(self._collection_name)
collection_info = cast(dict[str, Any], self._client.describe_collection(self._collection_name))
fields = [field["name"] for field in collection_info["fields"]]
# Since primary field is auto-id, no need to track it
self._fields = [f for f in fields if f != Field.PRIMARY_KEY]
@ -106,7 +106,8 @@ class MilvusVector(BaseVector):
return False
try:
milvus_version = self._client.get_server_version()
milvus_version_raw = self._client.get_server_version()
milvus_version = milvus_version_raw if isinstance(milvus_version_raw, str) else str(milvus_version_raw)
# Check if it's Zilliz Cloud - it supports full-text search with Milvus 2.5 compatibility
if "Zilliz Cloud" in milvus_version:
return True

View File

@ -3,7 +3,7 @@ import json
import logging
import re
import uuid
from typing import Any
from typing import Any, TypedDict
import jieba.posseg as pseg # type: ignore
import numpy
@ -25,6 +25,18 @@ logger = logging.getLogger(__name__)
oracledb.defaults.fetch_lobs = False
class _OraclePoolParams(TypedDict, total=False):
user: str
password: str
dsn: str
min: int
max: int
increment: int
config_dir: str | None
wallet_location: str | None
wallet_password: str | None
class OracleVectorConfig(BaseModel):
user: str
password: str
@ -127,22 +139,18 @@ class OracleVector(BaseVector):
return connection
def _create_connection_pool(self, config: OracleVectorConfig):
pool_params = {
"user": config.user,
"password": config.password,
"dsn": config.dsn,
"min": 1,
"max": 5,
"increment": 1,
}
pool_params = _OraclePoolParams(
user=config.user,
password=config.password,
dsn=config.dsn,
min=1,
max=5,
increment=1,
)
if config.is_autonomous:
pool_params.update(
{
"config_dir": config.config_dir,
"wallet_location": config.wallet_location,
"wallet_password": config.wallet_password,
}
)
pool_params["config_dir"] = config.config_dir
pool_params["wallet_location"] = config.wallet_location
pool_params["wallet_password"] = config.wallet_password
return oracledb.create_pool(**pool_params)
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):

View File

@ -42,7 +42,7 @@ from libs.helper import convert_datetime_to_date
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.time_parser import get_time_threshold
from models.enums import WorkflowRunTriggeredFrom
from models.human_input import HumanInputForm
from models.human_input import HumanInputForm, HumanInputFormRecipient
from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun
from repositories.api_workflow_run_repository import APIWorkflowRunRepository, RunsWithRelatedCountsDict
from repositories.entities.workflow_pause import WorkflowPauseEntity
@ -63,6 +63,7 @@ class _WorkflowRunError(Exception):
def _build_human_input_required_reason(
reason_model: WorkflowPauseReason,
form_model: HumanInputForm | None,
recipients: Sequence[HumanInputFormRecipient] = (),
) -> HumanInputRequired:
form_content = ""
inputs = []
@ -89,7 +90,7 @@ def _build_human_input_required_reason(
resolved_default_values = dict(definition.default_values)
node_title = definition.node_title or node_title
return HumanInputRequired(
reason = HumanInputRequired(
form_id=form_id,
form_content=form_content,
inputs=inputs,
@ -98,6 +99,7 @@ def _build_human_input_required_reason(
node_title=node_title,
resolved_default_values=resolved_default_values,
)
return reason
class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
@ -804,12 +806,23 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
form_stmt = select(HumanInputForm).where(HumanInputForm.id.in_(form_ids))
for form in session.scalars(form_stmt).all():
form_models[form.id] = form
recipients_by_form_id: dict[str, list[HumanInputFormRecipient]] = {}
if form_ids:
recipient_stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids))
for recipient in session.scalars(recipient_stmt).all():
recipients_by_form_id.setdefault(recipient.form_id, []).append(recipient)
pause_reasons: list[PauseReason] = []
for reason in pause_reason_models:
if reason.type_ == PauseReasonType.HUMAN_INPUT_REQUIRED:
form_model = form_models.get(reason.form_id)
pause_reasons.append(_build_human_input_required_reason(reason, form_model))
pause_reasons.append(
_build_human_input_required_reason(
reason,
form_model,
recipients_by_form_id.get(reason.form_id, ()),
)
)
else:
pause_reasons.append(reason.to_entity())
return pause_reasons

View File

@ -162,6 +162,7 @@ class AppGenerateService:
invoke_from=invoke_from,
streaming=True,
call_depth=0,
workflow_run_id=str(uuid.uuid4()),
)
payload_json = payload.model_dump_json()
@ -183,6 +184,10 @@ class AppGenerateService:
else:
# Blocking mode: run synchronously and return JSON instead of SSE
# Keep behaviour consistent with WORKFLOW blocking branch.
pause_config = PauseStateLayerConfig(
session_factory=session_factory.get_session_maker(),
state_owner_user_id=workflow.created_by,
)
advanced_generator = AdvancedChatAppGenerator()
return rate_limit.generate(
advanced_generator.convert_to_event_stream(
@ -194,6 +199,7 @@ class AppGenerateService:
invoke_from=invoke_from,
workflow_run_id=str(uuid.uuid4()),
streaming=False,
pause_state_config=pause_config,
)
),
request_id=request_id,

View File

@ -5,6 +5,7 @@ import uuid
from datetime import datetime
from typing import TYPE_CHECKING
from cachetools.func import ttl_cache
from pydantic import BaseModel, ConfigDict, Field, model_validator
from configs import dify_config
@ -99,6 +100,7 @@ def try_join_default_workspace(account_id: str) -> None:
class EnterpriseService:
@classmethod
@ttl_cache(ttl=5)
def get_info(cls):
return EnterpriseRequest.send_request("GET", "/info")

View File

@ -177,6 +177,7 @@ class SystemFeatureModel(BaseModel):
enable_change_email: bool = True
plugin_manager: PluginManagerModel = PluginManagerModel()
trial_models: list[str] = []
enable_creators_platform: bool = False
enable_trial_app: bool = False
enable_explore_banner: bool = False
@ -241,6 +242,9 @@ class FeatureService:
if dify_config.MARKETPLACE_ENABLED:
system_features.enable_marketplace = True
if dify_config.CREATORS_PLATFORM_FEATURES_ENABLED:
system_features.enable_creators_platform = True
return system_features
@classmethod

View File

@ -26,7 +26,7 @@ from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.tool_manager import ToolManager
from core.tools.utils.encryption import create_provider_encrypter
from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
from core.tools.utils.system_encryption import decrypt_system_params
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.provider_ids import ToolProviderID
@ -521,7 +521,7 @@ class BuiltinToolManageService:
)
if system_client:
try:
oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params)
oauth_params = decrypt_system_params(system_client.encrypted_oauth_params)
except Exception as e:
raise ValueError(f"Error decrypting system oauth params: {e}")

View File

@ -14,7 +14,7 @@ from core.helper.provider_cache import NoOpProviderCredentialCache
from core.helper.provider_encryption import ProviderConfigEncrypter, create_provider_encrypter
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.oauth import OAuthHandler
from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
from core.tools.utils.system_encryption import decrypt_system_params
from core.trigger.entities.api_entities import (
TriggerProviderApiEntity,
TriggerProviderSubscriptionApiEntity,
@ -635,7 +635,7 @@ class TriggerProviderService:
if system_client:
try:
oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params)
oauth_params = decrypt_system_params(system_client.encrypted_oauth_params)
except Exception as e:
raise ValueError(f"Error decrypting system oauth params: {e}")

View File

@ -14,6 +14,7 @@ from sqlalchemy.orm import Session, sessionmaker
from core.app.apps.message_generator import MessageGenerator
from core.app.entities.task_entities import (
HumanInputRequiredResponse,
MessageReplaceStreamResponse,
NodeFinishStreamResponse,
NodeStartStreamResponse,
@ -22,10 +23,14 @@ from core.app.entities.task_entities import (
WorkflowStartStreamResponse,
)
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext
from core.workflow.human_input_forms import load_form_tokens_by_form_id
from core.workflow.human_input_policy import HumanInputSurface, enrich_human_input_pause_reasons
from graphon.entities import WorkflowStartReason
from graphon.entities.pause_reason import PauseReasonType
from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus
from graphon.runtime import GraphRuntimeState
from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter
from models.human_input import HumanInputForm
from models.model import AppMode, Message
from models.workflow import WorkflowNodeExecutionTriggeredFrom, WorkflowRun
from repositories.api_workflow_node_execution_repository import WorkflowNodeExecutionSnapshot
@ -59,8 +64,10 @@ def build_workflow_event_stream(
tenant_id: str,
app_id: str,
session_maker: sessionmaker[Session],
human_input_surface: HumanInputSurface | None = None,
idle_timeout: float = 300,
ping_interval: float = 10.0,
close_on_pause: bool = True,
) -> Generator[Mapping[str, Any] | str, None, None]:
topic = MessageGenerator.get_response_topic(app_mode, workflow_run.id)
workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
@ -115,13 +122,15 @@ def build_workflow_event_stream(
message_context=message_context,
pause_entity=pause_entity,
resumption_context=resumption_context,
session_maker=session_maker,
human_input_surface=human_input_surface,
)
for event in snapshot_events:
last_msg_time = time.time()
last_ping_time = last_msg_time
yield event
if _is_terminal_event(event, include_paused=True):
if _is_terminal_event(event, close_on_pause=close_on_pause):
return
while True:
@ -146,7 +155,7 @@ def build_workflow_event_stream(
last_msg_time = time.time()
last_ping_time = last_msg_time
yield event
if _is_terminal_event(event, include_paused=True):
if _is_terminal_event(event, close_on_pause=close_on_pause):
return
finally:
buffer_state.stop_event.set()
@ -207,6 +216,8 @@ def _build_snapshot_events(
message_context: MessageContext | None,
pause_entity: WorkflowPauseEntity | None,
resumption_context: WorkflowResumptionContext | None,
session_maker: sessionmaker[Session] | None = None,
human_input_surface: HumanInputSurface | None = None,
) -> list[Mapping[str, Any]]:
events: list[Mapping[str, Any]] = []
@ -241,12 +252,24 @@ def _build_snapshot_events(
events.append(node_finished)
if workflow_run.status == WorkflowExecutionStatus.PAUSED and pause_entity is not None:
for human_input_event in _build_human_input_required_events(
workflow_run_id=workflow_run.id,
task_id=task_id,
pause_entity=pause_entity,
session_maker=session_maker,
human_input_surface=human_input_surface,
):
_apply_message_context(human_input_event, message_context)
events.append(human_input_event)
pause_event = _build_pause_event(
workflow_run=workflow_run,
workflow_run_id=workflow_run.id,
task_id=task_id,
pause_entity=pause_entity,
resumption_context=resumption_context,
session_maker=session_maker,
human_input_surface=human_input_surface,
)
if pause_event is not None:
_apply_message_context(pause_event, message_context)
@ -314,6 +337,97 @@ def _build_node_started_event(
return response.to_ignore_detail_dict()
def _build_human_input_required_events(
*,
workflow_run_id: str,
task_id: str,
pause_entity: WorkflowPauseEntity,
session_maker: sessionmaker[Session] | None,
human_input_surface: HumanInputSurface | None,
) -> list[dict[str, Any]]:
reasons = [reason.model_dump(mode="json") for reason in pause_entity.get_pause_reasons()]
human_input_form_ids = [
form_id
for reason in reasons
if reason.get("TYPE") == PauseReasonType.HUMAN_INPUT_REQUIRED
for form_id in [reason.get("form_id")]
if isinstance(form_id, str)
]
expiration_times_by_form_id: dict[str, int] = {}
display_in_ui_by_form_id: dict[str, bool] = {}
form_tokens_by_form_id: dict[str, str] = {}
if human_input_form_ids and session_maker is not None:
stmt = select(HumanInputForm.id, HumanInputForm.expiration_time, HumanInputForm.form_definition).where(
HumanInputForm.id.in_(human_input_form_ids)
)
with session_maker() as session:
for form_id, expiration_time, form_definition in session.execute(stmt):
expiration_times_by_form_id[str(form_id)] = int(expiration_time.timestamp())
try:
definition_payload = json.loads(form_definition) if form_definition else {}
except (TypeError, json.JSONDecodeError):
definition_payload = {}
display_in_ui_by_form_id[str(form_id)] = bool(definition_payload.get("display_in_ui"))
form_tokens_by_form_id = load_form_tokens_by_form_id(
human_input_form_ids,
session=session,
surface=human_input_surface,
)
events: list[dict[str, Any]] = []
for reason in reasons:
if reason.get("TYPE") != PauseReasonType.HUMAN_INPUT_REQUIRED:
continue
form_id_raw = reason.get("form_id")
node_id_raw = reason.get("node_id")
node_title_raw = reason.get("node_title")
form_content_raw = reason.get("form_content")
if not isinstance(form_id_raw, str):
continue
if not isinstance(node_id_raw, str):
continue
if not isinstance(node_title_raw, str):
continue
if not isinstance(form_content_raw, str):
continue
form_id = form_id_raw
node_id = node_id_raw
node_title = node_title_raw
form_content = form_content_raw
inputs = reason.get("inputs")
actions = reason.get("actions")
resolved_default_values = reason.get("resolved_default_values")
expiration_time = expiration_times_by_form_id.get(form_id)
if expiration_time is None:
continue
response = HumanInputRequiredResponse(
task_id=task_id,
workflow_run_id=workflow_run_id,
data=HumanInputRequiredResponse.Data(
form_id=form_id,
node_id=node_id,
node_title=node_title,
form_content=form_content,
inputs=inputs if isinstance(inputs, list) else [],
actions=actions if isinstance(actions, list) else [],
display_in_ui=display_in_ui_by_form_id.get(form_id, False),
form_token=form_tokens_by_form_id.get(form_id),
resolved_default_values=(resolved_default_values if isinstance(resolved_default_values, dict) else {}),
expiration_time=expiration_time,
),
)
payload = response.model_dump(mode="json")
payload["event"] = response.event.value
events.append(payload)
return events
def _build_node_finished_event(
*,
workflow_run_id: str,
@ -356,6 +470,8 @@ def _build_pause_event(
task_id: str,
pause_entity: WorkflowPauseEntity,
resumption_context: WorkflowResumptionContext | None,
session_maker: sessionmaker[Session] | None,
human_input_surface: HumanInputSurface | None = None,
) -> dict[str, Any] | None:
paused_nodes: list[str] = []
outputs: dict[str, Any] = {}
@ -365,6 +481,36 @@ def _build_pause_event(
outputs = dict(WorkflowRuntimeTypeConverter().to_json_encodable(state.outputs or {}))
reasons = [reason.model_dump(mode="json") for reason in pause_entity.get_pause_reasons()]
human_input_form_ids = [
form_id
for reason in reasons
if reason.get("TYPE") == PauseReasonType.HUMAN_INPUT_REQUIRED
for form_id in [reason.get("form_id")]
if isinstance(form_id, str)
]
form_tokens_by_form_id: dict[str, str] = {}
expiration_times_by_form_id: dict[str, int] = {}
if human_input_form_ids and session_maker is not None:
with session_maker() as session:
form_tokens_by_form_id = load_form_tokens_by_form_id(
human_input_form_ids,
session=session,
surface=human_input_surface,
)
stmt = select(HumanInputForm.id, HumanInputForm.expiration_time).where(
HumanInputForm.id.in_(human_input_form_ids)
)
for row in session.execute(stmt):
form_id, expiration_time, *_rest = row
expiration_times_by_form_id[str(form_id)] = int(expiration_time.timestamp())
# Reconnect paths must preserve the same pause-reason contract as live streams;
# otherwise clients see schema drift after resume.
reasons = enrich_human_input_pause_reasons(
reasons,
form_tokens_by_form_id=form_tokens_by_form_id,
expiration_times_by_form_id=expiration_times_by_form_id,
)
response = WorkflowPauseStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run_id,
@ -449,12 +595,19 @@ def _parse_event_message(message: bytes) -> Mapping[str, Any] | None:
return event
def _is_terminal_event(event: Mapping[str, Any] | str, include_paused=False) -> bool:
def _is_terminal_event(
event: Mapping[str, Any] | str,
close_on_pause: bool = True,
*,
include_paused: bool | None = None,
) -> bool:
if include_paused is not None:
close_on_pause = include_paused
if not isinstance(event, Mapping):
return False
event_type = event.get("event")
if event_type == StreamEvent.WORKFLOW_FINISHED.value:
return True
if include_paused:
if close_on_pause:
return event_type == StreamEvent.WORKFLOW_PAUSED.value
return False

View File

@ -399,6 +399,8 @@ def _resume_advanced_chat(
workflow_run_id: str,
workflow_run: WorkflowRun,
) -> None:
resumed_generate_entity = generate_entity.model_copy(update={"stream": True})
try:
triggered_from = WorkflowRunTriggeredFrom(workflow_run.triggered_from)
except ValueError:
@ -426,7 +428,7 @@ def _resume_advanced_chat(
user=user,
conversation=conversation,
message=message,
application_generate_entity=generate_entity,
application_generate_entity=resumed_generate_entity,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
graph_runtime_state=graph_runtime_state,
@ -436,9 +438,8 @@ def _resume_advanced_chat(
logger.exception("Failed to resume chatflow execution for workflow run %s", workflow_run_id)
raise
if generate_entity.stream:
assert isinstance(response, Generator)
_publish_streaming_response(response, workflow_run_id, AppMode.ADVANCED_CHAT)
assert isinstance(response, Generator)
_publish_streaming_response(response, workflow_run_id, AppMode.ADVANCED_CHAT)
def _resume_workflow(
@ -455,6 +456,8 @@ def _resume_workflow(
workflow_run_repo,
pause_entity,
) -> None:
resumed_generate_entity = generate_entity.model_copy(update={"stream": True})
try:
triggered_from = WorkflowRunTriggeredFrom(workflow_run.triggered_from)
except ValueError:
@ -480,7 +483,7 @@ def _resume_workflow(
app_model=app_model,
workflow=workflow,
user=user,
application_generate_entity=generate_entity,
application_generate_entity=resumed_generate_entity,
graph_runtime_state=graph_runtime_state,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
@ -490,11 +493,18 @@ def _resume_workflow(
logger.exception("Failed to resume workflow execution for workflow run %s", workflow_run_id)
raise
if generate_entity.stream:
assert isinstance(response, Generator)
_publish_streaming_response(response, workflow_run_id, AppMode.WORKFLOW)
assert isinstance(response, Generator)
_publish_streaming_response(response, workflow_run_id, AppMode.WORKFLOW)
workflow_run_repo.delete_workflow_pause(pause_entity)
try:
workflow_run_repo.delete_workflow_pause(pause_entity)
except Exception as exc:
if exc.__class__.__name__ != "_WorkflowRunError" or "WorkflowPause not found" not in str(exc):
raise
logger.info(
"Skipped deleting workflow pause %s after resume because it was already replaced or removed",
pause_entity.id,
)
@shared_task(queue=WORKFLOW_BASED_APP_EXECUTION_QUEUE, name="resume_app_execution")

View File

@ -171,35 +171,13 @@ class TestChatMessageApiPermissions:
parent_message_id=None,
)
class MockQuery:
def __init__(self, model):
self.model = model
def where(self, *args, **kwargs):
return self
def first(self):
if getattr(self.model, "__name__", "") == "Conversation":
return mock_conversation
return None
def order_by(self, *args, **kwargs):
return self
def limit(self, *_):
return self
def all(self):
if getattr(self.model, "__name__", "") == "Message":
return [mock_message]
return []
mock_session = mock.Mock()
mock_session.query.side_effect = MockQuery
mock_session.scalar.return_value = False
mock_session.scalar.return_value = mock_conversation
mock_session.scalars.return_value.all.return_value = [mock_message]
monkeypatch.setattr(message_api, "db", SimpleNamespace(session=mock_session))
monkeypatch.setattr(message_api, "current_user", mock_account)
monkeypatch.setattr(message_api, "attach_message_extra_contents", mock.Mock())
class DummyPagination:
def __init__(self, data, limit, has_more):

View File

@ -24,7 +24,6 @@ def _patch_wraps():
patch("controllers.console.wraps.dify_config", dify_settings),
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features),
):
mock_db.session.query.return_value.first.return_value = MagicMock()
yield

View File

@ -2,6 +2,7 @@
from __future__ import annotations
import secrets
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from unittest.mock import Mock
@ -11,6 +12,7 @@ import pytest
from sqlalchemy import Engine, delete, select
from sqlalchemy.orm import Session, sessionmaker
from core.workflow.human_input_adapter import DeliveryMethodType
from extensions.ext_storage import storage
from graphon.entities import WorkflowExecution
from graphon.entities.pause_reason import HumanInputRequired, PauseReasonType
@ -20,9 +22,11 @@ from graphon.nodes.human_input.enums import FormInputType, HumanInputFormStatus
from libs.datetime_utils import naive_utc_now
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
from models.human_input import (
BackstageRecipientPayload,
HumanInputDelivery,
HumanInputForm,
HumanInputFormRecipient,
RecipientType,
)
from models.workflow import WorkflowAppLog, WorkflowAppLogCreatedFrom, WorkflowPause, WorkflowPauseReason, WorkflowRun
from repositories.entities.workflow_pause import WorkflowPauseEntity
@ -628,12 +632,12 @@ class TestPrivateWorkflowPauseEntity:
class TestBuildHumanInputRequiredReason:
"""Integration tests for _build_human_input_required_reason using real DB models."""
def test_builds_reason_from_form_definition(
def test_prefers_standalone_web_app_token_when_available(
self,
db_session_with_containers: Session,
test_scope: _TestScope,
) -> None:
"""Build the graph pause reason from the stored form definition."""
"""Use the public standalone web-app token for service API payloads."""
expiration_time = naive_utc_now()
form_definition = FormDefinition(
@ -660,6 +664,40 @@ class TestBuildHumanInputRequiredReason:
db_session_with_containers.add(form_model)
db_session_with_containers.flush()
delivery = HumanInputDelivery(
form_id=form_model.id,
delivery_method_type=DeliveryMethodType.WEBAPP,
channel_payload="{}",
)
db_session_with_containers.add(delivery)
db_session_with_containers.flush()
backstage_access_token = secrets.token_urlsafe(8)
backstage_recipient = HumanInputFormRecipient(
form_id=form_model.id,
delivery_id=delivery.id,
recipient_type=RecipientType.BACKSTAGE,
recipient_payload=BackstageRecipientPayload().model_dump_json(),
access_token=backstage_access_token,
)
console_access_token = secrets.token_urlsafe(8)
console_recipient = HumanInputFormRecipient(
form_id=form_model.id,
delivery_id=delivery.id,
recipient_type=RecipientType.CONSOLE,
recipient_payload="{}",
access_token=console_access_token,
)
web_app_access_token = secrets.token_urlsafe(8)
web_app_recipient = HumanInputFormRecipient(
form_id=form_model.id,
delivery_id=delivery.id,
recipient_type=RecipientType.STANDALONE_WEB_APP,
recipient_payload="{}",
access_token=web_app_access_token,
)
db_session_with_containers.add_all([backstage_recipient, console_recipient, web_app_recipient])
db_session_with_containers.flush()
# Create a pause so the reason has a valid pause_id
workflow_run = _create_workflow_run(
db_session_with_containers,
@ -688,8 +726,15 @@ class TestBuildHumanInputRequiredReason:
# Refresh to ensure we have DB-round-tripped objects
db_session_with_containers.refresh(form_model)
db_session_with_containers.refresh(reason_model)
db_session_with_containers.refresh(backstage_recipient)
db_session_with_containers.refresh(console_recipient)
db_session_with_containers.refresh(web_app_recipient)
reason = _build_human_input_required_reason(reason_model, form_model)
reason = _build_human_input_required_reason(
reason_model,
form_model,
[backstage_recipient, console_recipient, web_app_recipient],
)
assert isinstance(reason, HumanInputRequired)
assert reason.node_title == "Ask Name"
@ -697,3 +742,92 @@ class TestBuildHumanInputRequiredReason:
assert reason.inputs[0].output_variable_name == "name"
assert reason.actions[0].id == "approve"
assert reason.resolved_default_values == {"name": "Alice"}
assert not hasattr(reason, "form_token")
def test_falls_back_to_console_token_when_web_app_token_missing(
self,
db_session_with_containers: Session,
test_scope: _TestScope,
) -> None:
"""Use the console token only when no standalone web-app token exists."""
expiration_time = naive_utc_now()
form_definition = FormDefinition(
form_content="content",
inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")],
user_actions=[UserAction(id="approve", title="Approve")],
rendered_content="rendered",
expiration_time=expiration_time,
default_values={"name": "Alice"},
node_title="Ask Name",
display_in_ui=True,
)
form_model = HumanInputForm(
tenant_id=test_scope.tenant_id,
app_id=test_scope.app_id,
workflow_run_id=str(uuid4()),
node_id="node-1",
form_definition=form_definition.model_dump_json(),
rendered_content="rendered",
status=HumanInputFormStatus.WAITING,
expiration_time=expiration_time,
)
db_session_with_containers.add(form_model)
db_session_with_containers.flush()
delivery = HumanInputDelivery(
form_id=form_model.id,
delivery_method_type=DeliveryMethodType.WEBAPP,
channel_payload="{}",
)
db_session_with_containers.add(delivery)
db_session_with_containers.flush()
backstage_access_token = secrets.token_urlsafe(8)
backstage_recipient = HumanInputFormRecipient(
form_id=form_model.id,
delivery_id=delivery.id,
recipient_type=RecipientType.BACKSTAGE,
recipient_payload=BackstageRecipientPayload().model_dump_json(),
access_token=backstage_access_token,
)
console_access_token = secrets.token_urlsafe(8)
console_recipient = HumanInputFormRecipient(
form_id=form_model.id,
delivery_id=delivery.id,
recipient_type=RecipientType.CONSOLE,
recipient_payload="{}",
access_token=console_access_token,
)
db_session_with_containers.add_all([backstage_recipient, console_recipient])
db_session_with_containers.flush()
workflow_run = _create_workflow_run(
db_session_with_containers,
test_scope,
status=WorkflowExecutionStatus.RUNNING,
)
pause = WorkflowPause(
workflow_id=test_scope.workflow_id,
workflow_run_id=workflow_run.id,
state_object_key=f"workflow-state-{uuid4()}.json",
)
db_session_with_containers.add(pause)
db_session_with_containers.flush()
test_scope.state_keys.add(pause.state_object_key)
reason_model = WorkflowPauseReason(
pause_id=pause.id,
type_=PauseReasonType.HUMAN_INPUT_REQUIRED,
form_id=form_model.id,
node_id="node-1",
message="",
)
db_session_with_containers.add(reason_model)
db_session_with_containers.commit()
reason = _build_human_input_required_reason(reason_model, form_model, [backstage_recipient, console_recipient])
assert isinstance(reason, HumanInputRequired)
assert not hasattr(reason, "form_token")

View File

@ -13,6 +13,12 @@ from models.model import App, Conversation, Message
from services.feedback_service import FeedbackService
def _execute_result(rows):
result = mock.Mock()
result.all.return_value = rows
return result
class TestFeedbackService:
"""Test FeedbackService methods."""
@ -81,25 +87,17 @@ class TestFeedbackService:
def test_export_feedbacks_csv_format(self, mock_db_session, sample_data):
"""Test exporting feedback data in CSV format."""
# Setup mock query result
mock_query = mock.Mock()
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = [
(
sample_data["user_feedback"],
sample_data["message"],
sample_data["conversation"],
sample_data["app"],
sample_data["user_feedback"].from_account,
)
]
mock_db_session.execute.return_value = mock_query
mock_db_session.execute.return_value = _execute_result(
[
(
sample_data["user_feedback"],
sample_data["message"],
sample_data["conversation"],
sample_data["app"],
sample_data["user_feedback"].from_account,
)
]
)
# Test CSV export
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv")
@ -120,25 +118,17 @@ class TestFeedbackService:
def test_export_feedbacks_json_format(self, mock_db_session, sample_data):
"""Test exporting feedback data in JSON format."""
# Setup mock query result
mock_query = mock.Mock()
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = [
(
sample_data["admin_feedback"],
sample_data["message"],
sample_data["conversation"],
sample_data["app"],
sample_data["admin_feedback"].from_account,
)
]
mock_db_session.execute.return_value = mock_query
mock_db_session.execute.return_value = _execute_result(
[
(
sample_data["admin_feedback"],
sample_data["message"],
sample_data["conversation"],
sample_data["app"],
sample_data["admin_feedback"].from_account,
)
]
)
# Test JSON export
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json")
@ -157,25 +147,17 @@ class TestFeedbackService:
def test_export_feedbacks_with_filters(self, mock_db_session, sample_data):
"""Test exporting feedback with various filters."""
# Setup mock query result
mock_query = mock.Mock()
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = [
(
sample_data["admin_feedback"],
sample_data["message"],
sample_data["conversation"],
sample_data["app"],
sample_data["admin_feedback"].from_account,
)
]
mock_db_session.execute.return_value = mock_query
mock_db_session.execute.return_value = _execute_result(
[
(
sample_data["admin_feedback"],
sample_data["message"],
sample_data["conversation"],
sample_data["app"],
sample_data["admin_feedback"].from_account,
)
]
)
# Test with filters
result = FeedbackService.export_feedbacks(
@ -193,17 +175,7 @@ class TestFeedbackService:
def test_export_feedbacks_no_data(self, mock_db_session, sample_data):
"""Test exporting feedback when no data exists."""
# Setup mock query result with no data
mock_query = mock.Mock()
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = []
mock_db_session.execute.return_value = mock_query
mock_db_session.execute.return_value = _execute_result([])
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv")
@ -251,24 +223,17 @@ class TestFeedbackService:
created_at=datetime(2024, 1, 1, 10, 0, 0),
)
# Setup mock query result
mock_query = mock.Mock()
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = [
(
sample_data["user_feedback"],
long_message,
sample_data["conversation"],
sample_data["app"],
sample_data["user_feedback"].from_account,
)
]
mock_db_session.execute.return_value = mock_query
mock_db_session.execute.return_value = _execute_result(
[
(
sample_data["user_feedback"],
long_message,
sample_data["conversation"],
sample_data["app"],
sample_data["user_feedback"].from_account,
)
]
)
# Test export
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json")
@ -309,24 +274,17 @@ class TestFeedbackService:
created_at=datetime(2024, 1, 1, 10, 0, 0),
)
# Setup mock query result
mock_query = mock.Mock()
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = [
(
chinese_feedback,
chinese_message,
sample_data["conversation"],
sample_data["app"],
None, # No account for user feedback
)
]
mock_db_session.execute.return_value = mock_query
mock_db_session.execute.return_value = _execute_result(
[
(
chinese_feedback,
chinese_message,
sample_data["conversation"],
sample_data["app"],
None,
)
]
)
# Test export
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv")
@ -339,32 +297,24 @@ class TestFeedbackService:
def test_export_feedbacks_emoji_ratings(self, mock_db_session, sample_data):
"""Test that rating emojis are properly formatted in export."""
# Setup mock query result with both like and dislike feedback
mock_query = mock.Mock()
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = [
(
sample_data["user_feedback"],
sample_data["message"],
sample_data["conversation"],
sample_data["app"],
sample_data["user_feedback"].from_account,
),
(
sample_data["admin_feedback"],
sample_data["message"],
sample_data["conversation"],
sample_data["app"],
sample_data["admin_feedback"].from_account,
),
]
mock_db_session.execute.return_value = mock_query
mock_db_session.execute.return_value = _execute_result(
[
(
sample_data["user_feedback"],
sample_data["message"],
sample_data["conversation"],
sample_data["app"],
sample_data["user_feedback"].from_account,
),
(
sample_data["admin_feedback"],
sample_data["message"],
sample_data["conversation"],
sample_data["app"],
sample_data["admin_feedback"].from_account,
),
]
)
# Test export
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json")

View File

@ -121,33 +121,32 @@ def _configure_session_factory(_unit_test_engine):
configure_session_factory(_unit_test_engine, expire_on_commit=False)
def setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account):
def setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_owner):
"""
Helper to set up the mock DB execute chain for tenant/account authentication.
Helper to stub the tenant-owner execute result for service API app authentication.
This configures the mock to return (tenant, account) for the
db.session.execute(select(...).join().join().where()).one_or_none()
query used by validate_app_token decorator.
The validate_app_token decorator currently resolves the active tenant owner
via db.session.execute(select(Tenant, Account)...).one_or_none().
Args:
mock_db: The mocked db object
mock_tenant: Mock tenant object to return
mock_account: Mock account object to return
mock_owner: Mock owner object to return from the execute result
"""
mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_account)
mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_owner)
def setup_mock_dataset_tenant_query(mock_db, mock_tenant, mock_ta):
def setup_mock_dataset_owner_execute_result(mock_db, mock_tenant, mock_tenant_account_join):
"""
Helper to set up the mock DB execute chain for dataset tenant authentication.
Helper to stub the tenant-owner execute result for dataset token authentication.
This configures the mock to return (tenant, tenant_account) for the
db.session.execute(select(...).where().where().where().where()).one_or_none()
query used by validate_dataset_token decorator.
The validate_dataset_token decorator currently resolves the owner mapping via
db.session.execute(select(Tenant, TenantAccountJoin)...).one_or_none(), and
then loads the Account separately via db.session.get(...).
Args:
mock_db: The mocked db object
mock_tenant: Mock tenant object to return
mock_ta: Mock tenant account object to return
mock_tenant_account_join: Mock tenant-account join object to return
"""
mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_ta)
mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_tenant_account_join)

View File

@ -208,8 +208,6 @@ class TestAnnotationImportServiceValidation:
file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv")
mock_db_session.query.return_value.where.return_value.first.return_value = mock_app
with patch("services.annotation_service.current_account_with_tenant") as mock_auth:
mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id")
@ -230,8 +228,6 @@ class TestAnnotationImportServiceValidation:
file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv")
mock_db_session.query.return_value.where.return_value.first.return_value = mock_app
with patch("services.annotation_service.current_account_with_tenant") as mock_auth:
mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id")
@ -248,8 +244,6 @@ class TestAnnotationImportServiceValidation:
csv_content = 'invalid,csv,format\nwith,unbalanced,quotes,and"stuff'
file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv")
mock_db_session.query.return_value.where.return_value.first.return_value = mock_app
with (
patch("services.annotation_service.current_account_with_tenant") as mock_auth,
patch("services.annotation_service.pd.read_csv", side_effect=ParserError("malformed CSV")),
@ -269,8 +263,6 @@ class TestAnnotationImportServiceValidation:
file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv")
mock_db_session.query.return_value.where.return_value.first.return_value = mock_app
with patch("services.annotation_service.current_account_with_tenant") as mock_auth:
mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id")

View File

@ -43,7 +43,6 @@ class TestAuthenticationSecurity:
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Invalid email or password.")
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
mock_features.return_value.is_allow_register = True
# Act
@ -76,7 +75,6 @@ class TestAuthenticationSecurity:
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Wrong password")
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
# Act
with self.app.test_request_context(
@ -109,7 +107,6 @@ class TestAuthenticationSecurity:
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Invalid email or password.")
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
mock_features.return_value.is_allow_register = False
# Act
@ -135,7 +132,6 @@ class TestAuthenticationSecurity:
def test_reset_password_with_existing_account(self, mock_send_email, mock_get_user, mock_features, mock_db):
"""Test that reset password returns success with token for existing accounts."""
# Mock the setup check
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
# Test with existing account
mock_get_user.return_value = MagicMock(email="existing@example.com")

View File

@ -65,7 +65,6 @@ class TestEmailCodeLoginSendEmailApi:
- IP rate limiting is checked
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_ip_limit.return_value = False
mock_get_user.return_value = mock_account
mock_send_email.return_value = "email_token_123"
@ -98,7 +97,6 @@ class TestEmailCodeLoginSendEmailApi:
- Registration is allowed by system features
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_ip_limit.return_value = False
mock_get_user.return_value = None
mock_get_features.return_value.is_allow_register = True
@ -130,7 +128,6 @@ class TestEmailCodeLoginSendEmailApi:
- Registration is blocked by system features
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_ip_limit.return_value = False
mock_get_user.return_value = None
mock_get_features.return_value.is_allow_register = False
@ -152,7 +149,6 @@ class TestEmailCodeLoginSendEmailApi:
- Prevents spam and abuse
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_ip_limit.return_value = True
# Act & Assert
@ -172,7 +168,6 @@ class TestEmailCodeLoginSendEmailApi:
- AccountInFreezeError is raised for frozen accounts
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_ip_limit.return_value = False
mock_get_user.side_effect = AccountRegisterError("Account frozen")
@ -213,7 +208,6 @@ class TestEmailCodeLoginSendEmailApi:
- Defaults to en-US when not specified
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_ip_limit.return_value = False
mock_get_user.return_value = mock_account
mock_send_email.return_value = "token"
@ -286,7 +280,6 @@ class TestEmailCodeLoginApi:
- User is logged in with token pair
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
mock_get_user.return_value = mock_account
mock_get_tenants.return_value = [MagicMock()]
@ -335,7 +328,6 @@ class TestEmailCodeLoginApi:
- User is logged in after account creation
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "newuser@example.com", "code": "123456"}
mock_get_user.return_value = None
mock_create_account.return_value = mock_account
@ -369,7 +361,6 @@ class TestEmailCodeLoginApi:
- InvalidTokenError is raised for invalid/expired tokens
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = None
# Act & Assert
@ -392,7 +383,6 @@ class TestEmailCodeLoginApi:
- InvalidEmailError is raised when email doesn't match token
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "original@example.com", "code": "123456"}
# Act & Assert
@ -415,7 +405,6 @@ class TestEmailCodeLoginApi:
- EmailCodeError is raised for wrong verification code
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
# Act & Assert
@ -453,7 +442,6 @@ class TestEmailCodeLoginApi:
- User is added as owner of new workspace
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
mock_get_user.return_value = mock_account
mock_get_tenants.return_value = []
@ -496,7 +484,6 @@ class TestEmailCodeLoginApi:
- WorkspacesLimitExceeded is raised when limit reached
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
mock_get_user.return_value = mock_account
mock_get_tenants.return_value = []
@ -538,7 +525,6 @@ class TestEmailCodeLoginApi:
- NotAllowedCreateWorkspace is raised when creation disabled
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
mock_get_user.return_value = mock_account
mock_get_tenants.return_value = []

View File

@ -110,7 +110,6 @@ class TestLoginApi:
- Rate limit is reset after successful login
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.return_value = mock_account
@ -162,7 +161,6 @@ class TestLoginApi:
- Authentication proceeds with invitation token
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = {"data": {"email": "test@example.com"}}
mock_authenticate.return_value = mock_account
@ -199,7 +197,6 @@ class TestLoginApi:
- EmailPasswordLoginLimitError is raised when limit exceeded
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = True
mock_get_invitation.return_value = None
@ -228,7 +225,6 @@ class TestLoginApi:
- AccountInFreezeError is raised for frozen accounts
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_frozen.return_value = True
# Act & Assert
@ -268,7 +264,6 @@ class TestLoginApi:
- Generic error message prevents user enumeration
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.side_effect = AccountPasswordError("Invalid password")
@ -305,7 +300,6 @@ class TestLoginApi:
- Login is prevented even with valid credentials
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.side_effect = AccountLoginError("Account is banned")
@ -351,7 +345,6 @@ class TestLoginApi:
- User cannot login without an assigned workspace
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.return_value = mock_account
@ -383,7 +376,6 @@ class TestLoginApi:
- Security check prevents invitation token abuse
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = {"data": {"email": "invited@example.com"}}
@ -425,7 +417,6 @@ class TestLoginApi:
mock_token_pair,
):
"""Test that login retries with lowercase email when uppercase lookup fails."""
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.side_effect = [AccountPasswordError("Invalid"), mock_account]
@ -459,7 +450,6 @@ class TestLoginApi:
mock_db,
app,
):
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_token_data.return_value = {"email": "User@Example.com", "code": "123456"}
mock_get_account.side_effect = Unauthorized("Account is banned.")
@ -513,7 +503,6 @@ class TestLogoutApi:
- Success response is returned
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_current_account.return_value = (mock_account, MagicMock())
# Act
@ -539,7 +528,6 @@ class TestLogoutApi:
- Success response is returned
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
# Create a mock anonymous user that will pass isinstance check
anonymous_user = MagicMock()
mock_flask_login.AnonymousUserMixin = type("AnonymousUserMixin", (), {})

View File

@ -46,7 +46,6 @@ class TestPartnerTenants:
patch("libs.login.dify_config.LOGIN_DISABLED", False),
patch("libs.login.check_csrf_token") as mock_csrf,
):
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
mock_csrf.return_value = None
yield {"db": mock_db, "csrf": mock_csrf}

View File

@ -8,8 +8,10 @@ from werkzeug.exceptions import Forbidden
import controllers.console.tag.tags as module
from controllers.console import console_ns
from controllers.console.tag.tags import (
TagBindingCreateApi,
TagBindingDeleteApi,
DeprecatedTagBindingCreateApi,
DeprecatedTagBindingRemoveApi,
TagBindingCollectionApi,
TagBindingItemApi,
TagListApi,
TagUpdateDeleteApi,
)
@ -205,9 +207,9 @@ class TestTagUpdateDeleteApi:
assert status == 204
class TestTagBindingCreateApi:
class TestTagBindingCollectionApi:
def test_create_success(self, app, admin_user, payload_patch):
api = TagBindingCreateApi()
api = TagBindingCollectionApi()
method = unwrap(api.post)
payload = {
@ -232,7 +234,7 @@ class TestTagBindingCreateApi:
assert result["result"] == "success"
def test_create_forbidden(self, app, readonly_user, payload_patch):
api = TagBindingCreateApi()
api = TagBindingCollectionApi()
method = unwrap(api.post)
with app.test_request_context("/", json={}):
@ -247,9 +249,78 @@ class TestTagBindingCreateApi:
method(api)
class TestTagBindingDeleteApi:
class TestDeprecatedTagBindingCreateApi:
def test_create_success(self, app, admin_user, payload_patch):
api = DeprecatedTagBindingCreateApi()
method = unwrap(api.post)
payload = {
"tag_ids": ["tag-1"],
"target_id": "target-1",
"type": "knowledge",
}
with app.test_request_context("/", json=payload):
with (
patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(admin_user, None),
),
payload_patch(payload),
patch("controllers.console.tag.tags.TagService.save_tag_binding") as save_mock,
):
result, status = method(api)
save_mock.assert_called_once()
assert status == 200
assert result["result"] == "success"
class TestTagBindingItemApi:
def test_delete_success(self, app, admin_user, payload_patch):
api = TagBindingItemApi()
method = unwrap(api.delete)
payload = {
"target_id": "target-1",
"type": "knowledge",
}
with app.test_request_context("/", json=payload):
with (
patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(admin_user, None),
),
payload_patch(payload),
patch("controllers.console.tag.tags.TagService.delete_tag_binding") as delete_mock,
):
result, status = method(api, "tag-1")
delete_mock.assert_called_once()
delete_payload = delete_mock.call_args.args[0]
assert delete_payload.tag_id == "tag-1"
assert delete_payload.target_id == "target-1"
assert delete_payload.type == TagType.KNOWLEDGE
assert status == 200
assert result["result"] == "success"
def test_delete_forbidden(self, app, readonly_user):
api = TagBindingItemApi()
method = unwrap(api.delete)
with app.test_request_context("/"):
with patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(readonly_user, None),
):
with pytest.raises(Forbidden):
method(api, "tag-1")
class TestDeprecatedTagBindingRemoveApi:
def test_remove_success(self, app, admin_user, payload_patch):
api = TagBindingDeleteApi()
api = DeprecatedTagBindingRemoveApi()
method = unwrap(api.post)
payload = {
@ -274,7 +345,7 @@ class TestTagBindingDeleteApi:
assert result["result"] == "success"
def test_remove_forbidden(self, app, readonly_user, payload_patch):
api = TagBindingDeleteApi()
api = DeprecatedTagBindingRemoveApi()
method = unwrap(api.post)
with app.test_request_context("/", json={}):
@ -297,3 +368,35 @@ class TestTagResponseModel:
assert payload["type"] == "knowledge"
assert payload["binding_count"] == "1"
class TestTagBindingRouteMetadata:
def test_legacy_write_routes_are_marked_deprecated(self):
assert DeprecatedTagBindingCreateApi.post.__apidoc__["deprecated"] is True
assert DeprecatedTagBindingRemoveApi.post.__apidoc__["deprecated"] is True
assert TagBindingCollectionApi.post.__apidoc__.get("deprecated") is not True
assert TagBindingItemApi.delete.__apidoc__.get("deprecated") is not True
def test_write_routes_have_stable_operation_ids(self):
assert TagBindingCollectionApi.post.__apidoc__["id"] == "create_tag_binding"
assert TagBindingItemApi.delete.__apidoc__["id"] == "delete_tag_binding"
assert DeprecatedTagBindingCreateApi.post.__apidoc__["id"] == "create_tag_binding_deprecated"
assert DeprecatedTagBindingRemoveApi.post.__apidoc__["id"] == "delete_tag_binding_deprecated"
def test_canonical_and_legacy_write_routes_are_registered(self):
route_map = {
resource.__name__: urls
for resource, urls, _route_doc, _kwargs in console_ns.resources
if resource.__name__
in {
"TagBindingCollectionApi",
"TagBindingItemApi",
"DeprecatedTagBindingCreateApi",
"DeprecatedTagBindingRemoveApi",
}
}
assert route_map["TagBindingCollectionApi"] == ("/tag-bindings",)
assert route_map["TagBindingItemApi"] == ("/tag-bindings/<uuid:id>",)
assert route_map["DeprecatedTagBindingCreateApi"] == ("/tag-bindings/create",)
assert route_map["DeprecatedTagBindingRemoveApi"] == ("/tag-bindings/remove",)

View File

@ -122,6 +122,35 @@ def test_post_form_invalid_recipient_type(app, monkeypatch: pytest.MonkeyPatch)
handler(api, form_token="token")
def test_post_form_rejects_webapp_recipient_type(app, monkeypatch: pytest.MonkeyPatch) -> None:
form = SimpleNamespace(tenant_id="tenant-1", recipient_type=RecipientType.STANDALONE_WEB_APP)
class _ServiceStub:
def __init__(self, *_args, **_kwargs):
pass
def get_form_by_token(self, _token):
return form
monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub)
monkeypatch.setattr(
"controllers.console.human_input_form.current_account_with_tenant",
lambda: (SimpleNamespace(id="user-1"), "tenant-1"),
)
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
api = ConsoleHumanInputFormApi()
handler = _unwrap(api.post)
with app.test_request_context(
"/console/api/form/human_input/token",
method="POST",
json={"inputs": {"content": "ok"}, "action": "approve"},
):
with pytest.raises(NotFoundError):
handler(api, form_token="token")
def test_post_form_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
submit_mock = Mock()
form = SimpleNamespace(tenant_id="tenant-1", recipient_type=RecipientType.CONSOLE)

View File

@ -24,10 +24,6 @@ def app():
return app
def _mock_wraps_db(mock_db):
mock_db.session.query.return_value.first.return_value = MagicMock()
def _build_account(email: str, account_id: str = "acc", tenant: object | None = None) -> Account:
tenant_obj = tenant if tenant is not None else SimpleNamespace(id="tenant-id")
account = Account(name=account_id, email=email)
@ -64,7 +60,6 @@ class TestChangeEmailSend:
mock_db,
app,
):
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_account = _build_account("current@example.com", "acc1")
mock_current_account.return_value = (mock_account, None)
@ -117,7 +112,6 @@ class TestChangeEmailSend:
"""GHSA-4q3w-q5mc-45rq: a phase-1 token must not unlock the new-email send step."""
from controllers.console.auth.error import InvalidTokenError
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_account = _build_account("current@example.com", "acc1")
mock_current_account.return_value = (mock_account, None)
@ -163,7 +157,6 @@ class TestChangeEmailValidity:
mock_db,
app,
):
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_account = _build_account("user@example.com", "acc2")
mock_current_account.return_value = (mock_account, None)
@ -223,7 +216,6 @@ class TestChangeEmailValidity:
mock_db,
app,
):
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_current_account.return_value = (_build_account("old@example.com", "acc"), None)
mock_is_rate_limit.return_value = False
@ -280,7 +272,6 @@ class TestChangeEmailValidity:
"""A token whose phase marker is a string but not a known transition must be rejected."""
from controllers.console.auth.error import InvalidTokenError
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_current_account.return_value = (_build_account("old@example.com", "acc"), None)
mock_is_rate_limit.return_value = False
@ -330,7 +321,6 @@ class TestChangeEmailValidity:
"""A token minted without a phase marker (e.g. a hand-crafted token) must not validate."""
from controllers.console.auth.error import InvalidTokenError
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_current_account.return_value = (_build_account("old@example.com", "acc"), None)
mock_is_rate_limit.return_value = False
@ -378,7 +368,6 @@ class TestChangeEmailReset:
mock_db,
app,
):
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
current_user = _build_account("old@example.com", "acc3")
mock_current_account.return_value = (current_user, None)
@ -434,7 +423,6 @@ class TestChangeEmailReset:
"""GHSA-4q3w-q5mc-45rq PoC: phase-1 token must not be usable against /reset."""
from controllers.console.auth.error import InvalidTokenError
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
current_user = _build_account("old@example.com", "acc3")
mock_current_account.return_value = (current_user, None)
@ -488,7 +476,6 @@ class TestChangeEmailReset:
"""A verified token for address A must not be replayed to change to address B."""
from controllers.console.auth.error import InvalidTokenError
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
current_user = _build_account("old@example.com", "acc3")
mock_current_account.return_value = (current_user, None)
@ -561,7 +548,6 @@ class TestAccountDeletionFeedback:
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.BillingService.update_account_deletion_feedback")
def test_should_normalize_feedback_email(self, mock_update, mock_db, app):
_mock_wraps_db(mock_db)
with app.test_request_context(
"/account/delete/feedback",
method="POST",
@ -578,7 +564,6 @@ class TestCheckEmailUnique:
@patch("controllers.console.workspace.account.AccountService.check_email_unique")
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
def test_should_normalize_email(self, mock_is_freeze, mock_check_unique, mock_db, app):
_mock_wraps_db(mock_db)
mock_is_freeze.return_value = False
mock_check_unique.return_value = True

View File

@ -1,5 +1,5 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from unittest.mock import patch
import pytest
from flask import Flask, g
@ -16,10 +16,6 @@ def app():
return flask_app
def _mock_wraps_db(mock_db):
mock_db.session.query.return_value.first.return_value = MagicMock()
def _build_feature_flags():
placeholder_quota = SimpleNamespace(limit=0, size=0)
workspace_members = SimpleNamespace(is_available=lambda count: True)
@ -49,7 +45,6 @@ class TestMemberInviteEmailApi:
mock_get_features,
app,
):
_mock_wraps_db(mock_db)
mock_get_features.return_value = _build_feature_flags()
mock_invite_member.return_value = "token-abc"

View File

@ -310,7 +310,6 @@ class TestSystemSetup:
def test_should_allow_when_setup_complete(self, mock_db):
"""Test that requests are allowed when setup is complete"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock() # Setup exists
@setup_required
def admin_view():

View File

@ -22,7 +22,7 @@ _WRAPS_MODULE: ModuleType | None = None
@contextmanager
def _mock_db():
mock_session = SimpleNamespace(query=lambda *args, **kwargs: SimpleNamespace(first=lambda: True))
mock_session = SimpleNamespace(scalar=lambda *args, **kwargs: True)
with patch("extensions.ext_database.db.session", mock_session):
yield

View File

@ -12,7 +12,7 @@ from controllers.service_api.app.app import AppInfoApi, AppMetaApi, AppParameter
from controllers.service_api.app.error import AppUnavailableError
from models.account import TenantStatus
from models.model import App, AppMode
from tests.unit_tests.conftest import setup_mock_tenant_account_query
from tests.unit_tests.conftest import setup_mock_tenant_owner_execute_result
class TestAppParameterApi:
@ -74,7 +74,7 @@ class TestAppParameterApi:
# Mock tenant owner info for login
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
# Act
with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}):
@ -120,7 +120,7 @@ class TestAppParameterApi:
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
# Act
with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}):
@ -161,7 +161,7 @@ class TestAppParameterApi:
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
# Act & Assert
with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}):
@ -200,7 +200,7 @@ class TestAppParameterApi:
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
# Act & Assert
with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}):
@ -263,7 +263,7 @@ class TestAppMetaApi:
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
# Act
with app.test_request_context("/meta", method="GET", headers={"Authorization": "Bearer test_token"}):
@ -331,7 +331,7 @@ class TestAppInfoApi:
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
# Act
with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}):
@ -388,7 +388,7 @@ class TestAppInfoApi:
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
# Act
with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}):
@ -434,7 +434,7 @@ class TestAppInfoApi:
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
# Act
with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}):
@ -486,7 +486,7 @@ class TestAppInfoApi:
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
# Act
with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}):

View File

@ -0,0 +1,707 @@
"""Dedicated tests for HITL behavior exposed through the Service API."""
from __future__ import annotations
import json
import sys
from collections.abc import Sequence
from dataclasses import dataclass
from datetime import UTC, datetime
from types import SimpleNamespace
from unittest.mock import ANY, MagicMock, Mock
import pytest
import services.app_generate_service as ags_module
from controllers.service_api.app.workflow_events import WorkflowEventsApi
from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig
from core.app.apps.common import workflow_response_converter
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import QueueWorkflowPausedEvent
from core.app.entities.task_entities import (
AdvancedChatPausedBlockingResponse,
HumanInputRequiredResponse,
WorkflowAppPausedBlockingResponse,
WorkflowPauseStreamResponse,
)
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext, _WorkflowGenerateEntityWrapper
from core.workflow.human_input_policy import HumanInputSurface
from core.workflow.system_variables import build_system_variables
from graphon.entities import WorkflowStartReason
from graphon.entities.pause_reason import HumanInputRequired, PauseReasonType
from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus
from graphon.nodes.human_input.entities import FormInput, UserAction
from graphon.nodes.human_input.enums import FormInputType
from graphon.runtime import GraphRuntimeState, VariablePool
from models.account import Account
from models.enums import CreatorUserRole
from models.model import AppMode
from models.workflow import WorkflowRun
from repositories.api_workflow_node_execution_repository import WorkflowNodeExecutionSnapshot
from repositories.entities.workflow_pause import WorkflowPauseEntity
from services.app_generate_service import AppGenerateService
from services.workflow_event_snapshot_service import _build_snapshot_events
from tests.unit_tests.controllers.service_api.conftest import _unwrap
class _DummyRateLimit:
@staticmethod
def gen_request_key() -> str:
return "dummy-request-id"
def __init__(self, client_id: str, max_active_requests: int) -> None:
self.client_id = client_id
self.max_active_requests = max_active_requests
def enter(self, request_id: str | None = None) -> str:
return request_id or "dummy-request-id"
def exit(self, request_id: str) -> None:
return None
def generate(self, generator, request_id: str):
return generator
def _mock_repo_for_run(monkeypatch: pytest.MonkeyPatch, workflow_run):
workflow_events_module = sys.modules["controllers.service_api.app.workflow_events"]
repo = SimpleNamespace(get_workflow_run_by_id_and_tenant_id=lambda **_kwargs: workflow_run)
monkeypatch.setattr(
workflow_events_module.DifyAPIRepositoryFactory,
"create_api_workflow_run_repository",
lambda *_args, **_kwargs: repo,
)
monkeypatch.setattr(workflow_events_module, "db", SimpleNamespace(engine=object()))
return workflow_events_module
def _build_service_api_pause_converter() -> WorkflowResponseConverter:
application_generate_entity = SimpleNamespace(
inputs={},
files=[],
invoke_from=InvokeFrom.SERVICE_API,
app_config=SimpleNamespace(app_id="app-id", tenant_id="tenant-id"),
)
system_variables = build_system_variables(
user_id="user",
app_id="app-id",
workflow_id="workflow-id",
workflow_execution_id="run-id",
)
user = MagicMock(spec=Account)
user.id = "account-id"
user.name = "Tester"
user.email = "tester@example.com"
return WorkflowResponseConverter(
application_generate_entity=application_generate_entity,
user=user,
system_variables=system_variables,
)
def _build_advanced_chat_paused_blocking_response() -> AdvancedChatPausedBlockingResponse:
data = AdvancedChatPausedBlockingResponse.Data(
id="msg-1",
mode="chat",
conversation_id="c1",
message_id="m1",
workflow_run_id="run-1",
answer="partial",
metadata={"usage": {"total_tokens": 1}},
created_at=1,
paused_nodes=["node-1"],
reasons=[
{
"type": PauseReasonType.HUMAN_INPUT_REQUIRED,
"form_id": "form-1",
"expiration_time": 100,
}
],
status=WorkflowExecutionStatus.PAUSED,
elapsed_time=0.1,
total_tokens=0,
total_steps=0,
)
return AdvancedChatPausedBlockingResponse(task_id="t1", data=data)
def _build_workflow_paused_blocking_response() -> WorkflowAppPausedBlockingResponse:
return WorkflowAppPausedBlockingResponse(
task_id="t1",
workflow_run_id="r1",
data=WorkflowAppPausedBlockingResponse.Data(
id="r1",
workflow_id="wf-1",
status=WorkflowExecutionStatus.PAUSED,
outputs={},
error=None,
elapsed_time=0.5,
total_tokens=0,
total_steps=2,
created_at=1,
finished_at=None,
paused_nodes=["node-1"],
reasons=[{"TYPE": "human_input_required", "form_id": "form-1", "expiration_time": 100}],
),
)
@dataclass(frozen=True)
class _FakePauseEntity(WorkflowPauseEntity):
pause_id: str
workflow_run_id: str
paused_at_value: datetime
pause_reasons: Sequence[HumanInputRequired]
@property
def id(self) -> str:
return self.pause_id
@property
def workflow_execution_id(self) -> str:
return self.workflow_run_id
def get_state(self) -> bytes:
raise AssertionError("state is not required for snapshot tests")
@property
def resumed_at(self) -> datetime | None:
return None
@property
def paused_at(self) -> datetime:
return self.paused_at_value
def get_pause_reasons(self) -> Sequence[HumanInputRequired]:
return self.pause_reasons
def _build_workflow_run(status: WorkflowExecutionStatus) -> WorkflowRun:
return WorkflowRun(
id="run-1",
tenant_id="tenant-1",
app_id="app-1",
workflow_id="workflow-1",
type="workflow",
triggered_from="app-run",
version="v1",
graph=None,
inputs=json.dumps({"input": "value"}),
status=status,
outputs=json.dumps({}),
error=None,
elapsed_time=0.0,
total_tokens=0,
total_steps=0,
created_by_role=CreatorUserRole.END_USER,
created_by="user-1",
created_at=datetime(2024, 1, 1, tzinfo=UTC),
)
def _build_snapshot(status: WorkflowNodeExecutionStatus) -> WorkflowNodeExecutionSnapshot:
created_at = datetime(2024, 1, 1, tzinfo=UTC)
finished_at = datetime(2024, 1, 1, 0, 0, 5, tzinfo=UTC)
return WorkflowNodeExecutionSnapshot(
execution_id="exec-1",
node_id="node-1",
node_type="human-input",
title="Human Input",
index=1,
status=status.value,
elapsed_time=0.5,
created_at=created_at,
finished_at=finished_at,
iteration_id=None,
loop_id=None,
)
def _build_resumption_context(task_id: str) -> WorkflowResumptionContext:
app_config = WorkflowUIBasedAppConfig(
tenant_id="tenant-1",
app_id="app-1",
app_mode=AppMode.WORKFLOW,
workflow_id="workflow-1",
)
generate_entity = WorkflowAppGenerateEntity(
task_id=task_id,
app_config=app_config,
inputs={},
files=[],
user_id="user-1",
stream=True,
invoke_from=InvokeFrom.EXPLORE,
call_depth=0,
workflow_execution_id="run-1",
)
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0)
runtime_state.register_paused_node("node-1")
runtime_state.outputs = {"result": "value"}
wrapper = _WorkflowGenerateEntityWrapper(entity=generate_entity)
return WorkflowResumptionContext(
generate_entity=wrapper,
serialized_graph_runtime_state=runtime_state.dumps(),
)
class TestHitlServiceApi:
# Service API event-stream continuation
def test_workflow_events_continue_on_pause_keeps_stream_open(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
workflow_run = SimpleNamespace(
id="run-1",
app_id="app-1",
created_by_role=CreatorUserRole.END_USER,
created_by="end-user-1",
finished_at=None,
)
workflow_events_module = _mock_repo_for_run(monkeypatch, workflow_run=workflow_run)
msg_generator = Mock()
msg_generator.retrieve_events.return_value = ["raw-event"]
workflow_generator = Mock()
workflow_generator.convert_to_event_stream.return_value = iter(["data: streamed\n\n"])
monkeypatch.setattr(workflow_events_module, "MessageGenerator", lambda: msg_generator)
monkeypatch.setattr(workflow_events_module, "WorkflowAppGenerator", lambda: workflow_generator)
api = WorkflowEventsApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
end_user = SimpleNamespace(id="end-user-1")
with app.test_request_context("/workflow/run-1/events?user=u1&continue_on_pause=true", method="GET"):
response = handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
assert response.get_data(as_text=True) == "data: streamed\n\n"
msg_generator.retrieve_events.assert_called_once_with(
AppMode.WORKFLOW,
"run-1",
terminal_events=[],
)
workflow_generator.convert_to_event_stream.assert_called_once_with(["raw-event"])
def test_workflow_events_snapshot_continue_on_pause_keeps_pause_open(
self, app, monkeypatch: pytest.MonkeyPatch
) -> None:
workflow_run = SimpleNamespace(
id="run-1",
app_id="app-1",
created_by_role=CreatorUserRole.END_USER,
created_by="end-user-1",
finished_at=None,
)
workflow_events_module = _mock_repo_for_run(monkeypatch, workflow_run=workflow_run)
msg_generator = Mock()
workflow_generator = Mock()
workflow_generator.convert_to_event_stream.return_value = iter(["data: snapshot\n\n"])
snapshot_builder = Mock(return_value=["snapshot-events"])
monkeypatch.setattr(workflow_events_module, "MessageGenerator", lambda: msg_generator)
monkeypatch.setattr(workflow_events_module, "WorkflowAppGenerator", lambda: workflow_generator)
monkeypatch.setattr(workflow_events_module, "build_workflow_event_stream", snapshot_builder)
api = WorkflowEventsApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
end_user = SimpleNamespace(id="end-user-1")
with app.test_request_context(
"/workflow/run-1/events?user=u1&include_state_snapshot=true&continue_on_pause=true",
method="GET",
):
response = handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
assert response.get_data(as_text=True) == "data: snapshot\n\n"
msg_generator.retrieve_events.assert_not_called()
snapshot_builder.assert_called_once_with(
app_mode=AppMode.WORKFLOW,
workflow_run=workflow_run,
tenant_id="tenant-1",
app_id="app-1",
session_maker=ANY,
human_input_surface=HumanInputSurface.SERVICE_API,
close_on_pause=False,
)
workflow_generator.convert_to_event_stream.assert_called_once_with(["snapshot-events"])
def test_advanced_chat_blocking_injects_pause_state_config(self, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", False)
monkeypatch.setattr(ags_module, "RateLimit", _DummyRateLimit)
workflow = MagicMock()
workflow.created_by = "owner-id"
monkeypatch.setattr(AppGenerateService, "_get_workflow", lambda *args, **kwargs: workflow)
monkeypatch.setattr(ags_module.session_factory, "get_session_maker", lambda: "session-maker")
generator_instance = MagicMock()
generator_instance.generate.return_value = {"result": "advanced-blocking"}
generator_instance.convert_to_event_stream.side_effect = lambda payload: payload
monkeypatch.setattr(ags_module, "AdvancedChatAppGenerator", lambda: generator_instance)
app_model = MagicMock()
app_model.mode = AppMode.ADVANCED_CHAT
app_model.id = "app-id"
app_model.tenant_id = "tenant-id"
app_model.max_active_requests = 0
app_model.is_agent = False
user = MagicMock()
user.id = "user-id"
result = AppGenerateService.generate(
app_model=app_model,
user=user,
args={"workflow_id": None, "query": "hi", "inputs": {}},
invoke_from=InvokeFrom.SERVICE_API,
streaming=False,
)
assert result == {"result": "advanced-blocking"}
call_kwargs = generator_instance.generate.call_args.kwargs
assert call_kwargs["streaming"] is False
assert call_kwargs["pause_state_config"] is not None
assert call_kwargs["pause_state_config"].session_factory == "session-maker"
assert call_kwargs["pause_state_config"].state_owner_user_id == "owner-id"
# Blocking payload contract
def test_advanced_chat_blocking_pause_payload_contract(self) -> None:
from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter
response = AdvancedChatAppGenerateResponseConverter.convert_blocking_full_response(
_build_advanced_chat_paused_blocking_response()
)
assert response["event"] == "workflow_paused"
assert response["workflow_run_id"] == "run-1"
assert response["answer"] == "partial"
assert response["data"]["reasons"][0]["type"] == PauseReasonType.HUMAN_INPUT_REQUIRED
assert response["data"]["reasons"][0]["expiration_time"] == 100
assert "human_input_forms" not in response["data"]
def test_workflow_blocking_pause_payload_contract(self) -> None:
from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
response = WorkflowAppGenerateResponseConverter.convert_blocking_full_response(
_build_workflow_paused_blocking_response()
)
assert response["workflow_run_id"] == "r1"
assert response["data"]["status"] == WorkflowExecutionStatus.PAUSED
assert response["data"]["paused_nodes"] == ["node-1"]
assert response["data"]["reasons"] == [
{"TYPE": "human_input_required", "form_id": "form-1", "expiration_time": 100}
]
assert "human_input_forms" not in response["data"]
def test_advanced_chat_blocking_pipeline_pause_payload_contract(self) -> None:
from core.app.app_config.entities import AppAdditionalFeatures
from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline
from models.enums import MessageStatus
from models.model import EndUser
app_config = WorkflowUIBasedAppConfig(
tenant_id="tenant",
app_id="app",
app_mode=AppMode.ADVANCED_CHAT,
additional_features=AppAdditionalFeatures(),
variables=[],
workflow_id="workflow-id",
)
application_generate_entity = AdvancedChatAppGenerateEntity.model_construct(
task_id="task",
app_config=app_config,
inputs={},
query="hello",
files=[],
user_id="user",
stream=False,
invoke_from=InvokeFrom.WEB_APP,
extras={},
trace_manager=None,
workflow_run_id="run-id",
)
pipeline = AdvancedChatAppGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
workflow=SimpleNamespace(id="workflow-id", tenant_id="tenant", features_dict={}),
queue_manager=SimpleNamespace(invoke_from=InvokeFrom.WEB_APP, graph_runtime_state=None),
conversation=SimpleNamespace(id="conv-id", mode=AppMode.ADVANCED_CHAT),
message=SimpleNamespace(
id="message-id",
query="hello",
created_at=datetime.utcnow(),
status=MessageStatus.NORMAL,
answer="",
),
user=EndUser(tenant_id="tenant", type="session", name="tester", session_id="session"),
stream=False,
dialogue_count=1,
draft_var_saver_factory=lambda **kwargs: None,
)
pipeline._task_state.answer = "partial answer"
pipeline._workflow_run_id = "run-id"
def _gen():
yield HumanInputRequiredResponse(
task_id="task",
workflow_run_id="run-id",
data=HumanInputRequiredResponse.Data(
form_id="form-1",
node_id="node-1",
node_title="Approval",
form_content="Need approval",
inputs=[],
actions=[UserAction(id="approve", title="Approve")],
display_in_ui=True,
form_token="token-1",
resolved_default_values={},
expiration_time=123,
),
)
yield WorkflowPauseStreamResponse(
task_id="task",
workflow_run_id="run-id",
data=WorkflowPauseStreamResponse.Data(
workflow_run_id="run-id",
paused_nodes=["node-1"],
outputs={},
reasons=[
{
"type": PauseReasonType.HUMAN_INPUT_REQUIRED,
"form_id": "form-1",
"node_id": "node-1",
"expiration_time": 123,
},
],
status="paused",
created_at=1,
elapsed_time=0.1,
total_tokens=0,
total_steps=0,
),
)
response = pipeline._to_blocking_response(_gen())
assert isinstance(response, AdvancedChatPausedBlockingResponse)
assert response.data.answer == "partial answer"
assert response.data.workflow_run_id == "run-id"
assert response.data.reasons[0]["form_id"] == "form-1"
assert response.data.reasons[0]["expiration_time"] == 123
def test_workflow_blocking_pipeline_pause_payload_contract(self, monkeypatch: pytest.MonkeyPatch) -> None:
from core.app.apps.workflow import generate_task_pipeline as workflow_pipeline_module
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
app_config = WorkflowUIBasedAppConfig(
tenant_id="tenant",
app_id="app",
app_mode=AppMode.WORKFLOW,
additional_features=AppAdditionalFeatures(),
variables=[],
workflow_id="workflow-id",
)
application_generate_entity = WorkflowAppGenerateEntity.model_construct(
task_id="task",
app_config=app_config,
inputs={},
files=[],
user_id="user",
stream=False,
invoke_from=InvokeFrom.WEB_APP,
trace_manager=None,
workflow_execution_id="run-id",
extras={},
call_depth=0,
)
pipeline = WorkflowAppGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
workflow=SimpleNamespace(id="workflow-id", tenant_id="tenant", features_dict={}),
queue_manager=SimpleNamespace(invoke_from=InvokeFrom.WEB_APP, graph_runtime_state=None),
user=SimpleNamespace(id="user", session_id="session"),
stream=False,
draft_var_saver_factory=lambda **kwargs: None,
)
monkeypatch.setattr(workflow_pipeline_module.time, "time", lambda: 1700000000)
def _gen():
yield HumanInputRequiredResponse(
task_id="task",
workflow_run_id="run",
data=HumanInputRequiredResponse.Data(
form_id="form-1",
node_id="node-1",
node_title="Human Input",
form_content="content",
expiration_time=1,
),
)
yield WorkflowPauseStreamResponse(
task_id="task",
workflow_run_id="run",
data=WorkflowPauseStreamResponse.Data(
workflow_run_id="run",
status=WorkflowExecutionStatus.PAUSED,
outputs={},
paused_nodes=["node-1"],
reasons=[{"TYPE": "human_input_required", "form_id": "form-1", "expiration_time": 1}],
created_at=1,
elapsed_time=0.1,
total_tokens=0,
total_steps=0,
),
)
response = pipeline._to_blocking_response(_gen())
assert isinstance(response, WorkflowAppPausedBlockingResponse)
assert response.data.status == WorkflowExecutionStatus.PAUSED
assert response.data.paused_nodes == ["node-1"]
assert response.data.reasons == [{"TYPE": "human_input_required", "form_id": "form-1", "expiration_time": 1}]
def test_service_api_pause_event_serializes_hitl_reason(self, monkeypatch: pytest.MonkeyPatch) -> None:
converter = _build_service_api_pause_converter()
converter.workflow_start_to_stream_response(
task_id="task",
workflow_run_id="run-id",
workflow_id="workflow-id",
reason=WorkflowStartReason.INITIAL,
)
expiration_time = datetime(2024, 1, 1, tzinfo=UTC)
class _FakeSession:
def execute(self, _stmt):
return [("form-1", expiration_time, '{"display_in_ui": true}')]
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
monkeypatch.setattr(workflow_response_converter, "Session", lambda **_: _FakeSession())
monkeypatch.setattr(workflow_response_converter, "db", SimpleNamespace(engine=object()))
monkeypatch.setattr(
workflow_response_converter,
"load_form_tokens_by_form_id",
lambda form_ids, session=None, surface=None: {"form-1": "token"},
)
reason = HumanInputRequired(
form_id="form-1",
form_content="Rendered",
inputs=[
FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="field", default=None),
],
actions=[UserAction(id="approve", title="Approve")],
display_in_ui=True,
node_id="node-id",
node_title="Human Step",
form_token="token",
)
queue_event = QueueWorkflowPausedEvent(
reasons=[reason],
outputs={"answer": "value"},
paused_nodes=["node-id"],
)
runtime_state = SimpleNamespace(total_tokens=0, node_run_steps=0)
responses = converter.workflow_pause_to_stream_response(
event=queue_event,
task_id="task",
graph_runtime_state=runtime_state,
)
assert isinstance(responses[-1], WorkflowPauseStreamResponse)
pause_resp = responses[-1]
assert pause_resp.workflow_run_id == "run-id"
assert pause_resp.data.paused_nodes == ["node-id"]
assert pause_resp.data.outputs == {}
assert pause_resp.data.reasons[0]["TYPE"] == "human_input_required"
assert pause_resp.data.reasons[0]["form_id"] == "form-1"
assert pause_resp.data.reasons[0]["form_token"] == "token"
assert pause_resp.data.reasons[0]["expiration_time"] == int(expiration_time.timestamp())
assert isinstance(responses[0], HumanInputRequiredResponse)
hi_resp = responses[0]
assert hi_resp.data.form_id == "form-1"
assert hi_resp.data.node_id == "node-id"
assert hi_resp.data.node_title == "Human Step"
assert hi_resp.data.inputs[0].output_variable_name == "field"
assert hi_resp.data.actions[0].id == "approve"
assert hi_resp.data.display_in_ui is True
assert hi_resp.data.form_token == "token"
assert hi_resp.data.expiration_time == int(expiration_time.timestamp())
# Snapshot payload contract
def test_snapshot_events_include_pause_payload_contract(self, monkeypatch: pytest.MonkeyPatch) -> None:
workflow_run = _build_workflow_run(WorkflowExecutionStatus.PAUSED)
snapshot = _build_snapshot(WorkflowNodeExecutionStatus.PAUSED)
resumption_context = _build_resumption_context("task-ctx")
monkeypatch.setattr(
"services.workflow_event_snapshot_service.load_form_tokens_by_form_id",
lambda form_ids, session=None, surface=None: {"form-1": "wtok"},
)
class _SessionContext:
def __init__(self, session):
self._session = session
def __enter__(self):
return self._session
def __exit__(self, exc_type, exc, tb):
return False
def session_maker() -> _SessionContext:
return _SessionContext(
SimpleNamespace(
execute=lambda _stmt: [("form-1", datetime(2024, 1, 1, tzinfo=UTC), '{"display_in_ui": true}')],
)
)
pause_entity = _FakePauseEntity(
pause_id="pause-1",
workflow_run_id="run-1",
paused_at_value=datetime(2024, 1, 1, tzinfo=UTC),
pause_reasons=[
HumanInputRequired(
form_id="form-1",
form_content="content",
node_id="node-1",
node_title="Human Input",
form_token="wtok",
)
],
)
events = _build_snapshot_events(
workflow_run=workflow_run,
node_snapshots=[snapshot],
task_id="task-ctx",
message_context=None,
pause_entity=pause_entity,
resumption_context=resumption_context,
session_maker=session_maker,
)
assert [event["event"] for event in events] == [
"workflow_started",
"node_started",
"node_finished",
"human_input_required",
"workflow_paused",
]
assert events[2]["data"]["status"] == WorkflowNodeExecutionStatus.PAUSED.value
assert events[3]["data"]["form_token"] == "wtok"
assert events[3]["data"]["expiration_time"] == int(datetime(2024, 1, 1, tzinfo=UTC).timestamp())
pause_data = events[-1]["data"]
assert pause_data["paused_nodes"] == ["node-1"]
assert pause_data["outputs"] == {"result": "value"}
assert pause_data["reasons"][0]["TYPE"] == "human_input_required"
assert pause_data["reasons"][0]["form_token"] == "wtok"
assert pause_data["reasons"][0]["expiration_time"] == int(datetime(2024, 1, 1, tzinfo=UTC).timestamp())
assert pause_data["status"] == WorkflowExecutionStatus.PAUSED.value
assert pause_data["created_at"] == int(workflow_run.created_at.timestamp())
assert pause_data["elapsed_time"] == workflow_run.elapsed_time
assert pause_data["total_tokens"] == workflow_run.total_tokens
assert pause_data["total_steps"] == workflow_run.total_steps

View File

@ -0,0 +1,184 @@
"""Unit tests for Service API human input form endpoints."""
from __future__ import annotations
import json
import sys
from datetime import UTC, datetime
from types import SimpleNamespace
from unittest.mock import Mock
import pytest
from werkzeug.exceptions import NotFound
from controllers.service_api.app.human_input_form import WorkflowHumanInputFormApi
from models.human_input import RecipientType
from tests.unit_tests.controllers.service_api.conftest import _unwrap
class TestWorkflowHumanInputFormApi:
def test_get_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
definition = SimpleNamespace(
model_dump=lambda: {
"rendered_content": "Rendered form content",
"inputs": [{"output_variable_name": "name"}],
"default_values": {"name": "Alice", "age": 30, "meta": {"k": "v"}},
"user_actions": [{"id": "approve", "title": "Approve"}],
}
)
form = SimpleNamespace(
app_id="app-1",
tenant_id="tenant-1",
recipient_type=RecipientType.STANDALONE_WEB_APP,
expiration_time=datetime(2099, 1, 1, tzinfo=UTC),
get_definition=lambda: definition,
)
service_mock = Mock()
service_mock.get_form_by_token.return_value = form
workflow_module = sys.modules["controllers.service_api.app.human_input_form"]
monkeypatch.setattr(workflow_module, "HumanInputService", lambda _engine: service_mock)
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
api = WorkflowHumanInputFormApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
with app.test_request_context("/form/human_input/token-1", method="GET"):
response = handler(api, app_model=app_model, form_token="token-1")
payload = json.loads(response.get_data(as_text=True))
assert payload == {
"form_content": "Rendered form content",
"inputs": [{"output_variable_name": "name"}],
"resolved_default_values": {"name": "Alice", "age": "30", "meta": '{"k": "v"}'},
"user_actions": [{"id": "approve", "title": "Approve"}],
"expiration_time": int(form.expiration_time.timestamp()),
}
service_mock.get_form_by_token.assert_called_once_with("token-1")
service_mock.ensure_form_active.assert_called_once_with(form)
def test_get_form_not_in_app(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
form = SimpleNamespace(
app_id="another-app",
tenant_id="tenant-1",
expiration_time=datetime(2099, 1, 1, tzinfo=UTC),
)
service_mock = Mock()
service_mock.get_form_by_token.return_value = form
workflow_module = sys.modules["controllers.service_api.app.human_input_form"]
monkeypatch.setattr(workflow_module, "HumanInputService", lambda _engine: service_mock)
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
api = WorkflowHumanInputFormApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
with app.test_request_context("/form/human_input/token-1", method="GET"):
with pytest.raises(NotFound):
handler(api, app_model=app_model, form_token="token-1")
@pytest.mark.parametrize(
"recipient_type",
[
RecipientType.CONSOLE,
RecipientType.BACKSTAGE,
RecipientType.EMAIL_MEMBER,
RecipientType.EMAIL_EXTERNAL,
],
)
def test_get_rejects_non_service_api_recipient_types(
self, app, monkeypatch: pytest.MonkeyPatch, recipient_type: RecipientType
) -> None:
form = SimpleNamespace(
app_id="app-1",
tenant_id="tenant-1",
recipient_type=recipient_type,
expiration_time=datetime(2099, 1, 1, tzinfo=UTC),
)
service_mock = Mock()
service_mock.get_form_by_token.return_value = form
workflow_module = sys.modules["controllers.service_api.app.human_input_form"]
monkeypatch.setattr(workflow_module, "HumanInputService", lambda _engine: service_mock)
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
api = WorkflowHumanInputFormApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
with app.test_request_context("/form/human_input/token-1", method="GET"):
with pytest.raises(NotFound):
handler(api, app_model=app_model, form_token="token-1")
service_mock.ensure_form_active.assert_not_called()
def test_post_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
form = SimpleNamespace(
app_id="app-1",
tenant_id="tenant-1",
recipient_type=RecipientType.STANDALONE_WEB_APP,
)
service_mock = Mock()
service_mock.get_form_by_token.return_value = form
workflow_module = sys.modules["controllers.service_api.app.human_input_form"]
monkeypatch.setattr(workflow_module, "HumanInputService", lambda _engine: service_mock)
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
api = WorkflowHumanInputFormApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
end_user = SimpleNamespace(id="end-user-1")
with app.test_request_context(
"/form/human_input/token-1",
method="POST",
json={"inputs": {"name": "Alice"}, "action": "approve", "user": "external-1"},
):
response, status = handler(api, app_model=app_model, end_user=end_user, form_token="token-1")
assert response == {}
assert status == 200
service_mock.submit_form_by_token.assert_called_once_with(
recipient_type=RecipientType.STANDALONE_WEB_APP,
form_token="token-1",
selected_action_id="approve",
form_data={"name": "Alice"},
submission_end_user_id="end-user-1",
)
@pytest.mark.parametrize(
"recipient_type",
[
RecipientType.CONSOLE,
RecipientType.BACKSTAGE,
RecipientType.EMAIL_MEMBER,
RecipientType.EMAIL_EXTERNAL,
],
)
def test_post_rejects_non_service_api_recipient_types(
self, app, monkeypatch: pytest.MonkeyPatch, recipient_type: RecipientType
) -> None:
form = SimpleNamespace(
app_id="app-1",
tenant_id="tenant-1",
recipient_type=recipient_type,
)
service_mock = Mock()
service_mock.get_form_by_token.return_value = form
workflow_module = sys.modules["controllers.service_api.app.human_input_form"]
monkeypatch.setattr(workflow_module, "HumanInputService", lambda _engine: service_mock)
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
api = WorkflowHumanInputFormApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
end_user = SimpleNamespace(id="end-user-1")
with app.test_request_context(
"/form/human_input/token-1",
method="POST",
json={"inputs": {"name": "Alice"}, "action": "approve", "user": "external-1"},
):
with pytest.raises(NotFound):
handler(api, app_model=app_model, end_user=end_user, form_token="token-1")
service_mock.submit_form_by_token.assert_not_called()

View File

@ -0,0 +1,166 @@
"""Unit tests for Service API workflow event stream endpoints."""
from __future__ import annotations
import json
import sys
from datetime import UTC, datetime
from types import SimpleNamespace
from unittest.mock import Mock
import pytest
from werkzeug.exceptions import NotFound
from controllers.service_api.app.error import NotWorkflowAppError
from controllers.service_api.app.workflow_events import WorkflowEventsApi
from models.enums import CreatorUserRole
from models.model import AppMode
from tests.unit_tests.controllers.service_api.conftest import _unwrap
def _mock_repo_for_run(monkeypatch: pytest.MonkeyPatch, workflow_run):
workflow_events_module = sys.modules["controllers.service_api.app.workflow_events"]
repo = SimpleNamespace(get_workflow_run_by_id_and_tenant_id=lambda **_kwargs: workflow_run)
monkeypatch.setattr(
workflow_events_module.DifyAPIRepositoryFactory,
"create_api_workflow_run_repository",
lambda *_args, **_kwargs: repo,
)
monkeypatch.setattr(workflow_events_module, "db", SimpleNamespace(engine=object()))
return workflow_events_module
class TestWorkflowEventsApi:
def test_wrong_app_mode(self, app) -> None:
api = WorkflowEventsApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace(id="end-user-1")
with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"):
with pytest.raises(NotWorkflowAppError):
handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
def test_workflow_run_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
_mock_repo_for_run(monkeypatch, workflow_run=None)
api = WorkflowEventsApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
end_user = SimpleNamespace(id="end-user-1")
with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"):
with pytest.raises(NotFound):
handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
def test_workflow_run_permission_denied(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
workflow_run = SimpleNamespace(
id="run-1",
app_id="app-1",
created_by_role=CreatorUserRole.ACCOUNT,
created_by="another-user",
finished_at=None,
)
_mock_repo_for_run(monkeypatch, workflow_run=workflow_run)
api = WorkflowEventsApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
end_user = SimpleNamespace(id="end-user-1")
with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"):
with pytest.raises(NotFound):
handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
def test_finished_run_returns_sse(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
workflow_run = SimpleNamespace(
id="run-1",
app_id="app-1",
created_by_role=CreatorUserRole.END_USER,
created_by="end-user-1",
finished_at=datetime(2099, 1, 1, tzinfo=UTC),
)
workflow_events_module = _mock_repo_for_run(monkeypatch, workflow_run=workflow_run)
monkeypatch.setattr(
workflow_events_module.WorkflowResponseConverter,
"workflow_run_result_to_finish_response",
lambda **_kwargs: SimpleNamespace(
model_dump=lambda mode="json": {"task_id": "run-1", "status": "succeeded"},
event=SimpleNamespace(value="workflow_finished"),
),
)
api = WorkflowEventsApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
end_user = SimpleNamespace(id="end-user-1")
with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"):
response = handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
assert response.mimetype == "text/event-stream"
body = response.get_data(as_text=True).strip()
assert body.startswith("data: ")
payload = json.loads(body[len("data: ") :])
assert payload["task_id"] == "run-1"
assert payload["event"] == "workflow_finished"
def test_running_run_streams_events(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
workflow_run = SimpleNamespace(
id="run-1",
app_id="app-1",
created_by_role=CreatorUserRole.END_USER,
created_by="end-user-1",
finished_at=None,
)
workflow_events_module = _mock_repo_for_run(monkeypatch, workflow_run=workflow_run)
msg_generator = Mock()
msg_generator.retrieve_events.return_value = ["raw-event"]
workflow_generator = Mock()
workflow_generator.convert_to_event_stream.return_value = iter(["data: streamed\n\n"])
monkeypatch.setattr(workflow_events_module, "MessageGenerator", lambda: msg_generator)
monkeypatch.setattr(workflow_events_module, "WorkflowAppGenerator", lambda: workflow_generator)
api = WorkflowEventsApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
end_user = SimpleNamespace(id="end-user-1")
with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"):
response = handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
assert response.get_data(as_text=True) == "data: streamed\n\n"
msg_generator.retrieve_events.assert_called_once_with(
AppMode.WORKFLOW,
"run-1",
terminal_events=None,
)
workflow_generator.convert_to_event_stream.assert_called_once_with(["raw-event"])
def test_running_run_with_snapshot(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
workflow_run = SimpleNamespace(
id="run-1",
app_id="app-1",
created_by_role=CreatorUserRole.END_USER,
created_by="end-user-1",
finished_at=None,
)
workflow_events_module = _mock_repo_for_run(monkeypatch, workflow_run=workflow_run)
msg_generator = Mock()
workflow_generator = Mock()
workflow_generator.convert_to_event_stream.return_value = iter(["data: snapshot\n\n"])
snapshot_builder = Mock(return_value=["snapshot-events"])
monkeypatch.setattr(workflow_events_module, "MessageGenerator", lambda: msg_generator)
monkeypatch.setattr(workflow_events_module, "WorkflowAppGenerator", lambda: workflow_generator)
monkeypatch.setattr(workflow_events_module, "build_workflow_event_stream", snapshot_builder)
api = WorkflowEventsApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
end_user = SimpleNamespace(id="end-user-1")
with app.test_request_context("/workflow/run-1/events?user=u1&include_state_snapshot=true", method="GET"):
response = handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
assert response.get_data(as_text=True) == "data: snapshot\n\n"
msg_generator.retrieve_events.assert_not_called()
snapshot_builder.assert_called_once()
workflow_generator.convert_to_event_stream.assert_called_once_with(["snapshot-events"])

View File

@ -15,7 +15,10 @@ from flask import Flask
from core.rag.index_processor.constant.index_type import IndexStructureType
from models.account import TenantStatus
from models.model import App, AppMode, EndUser
from tests.unit_tests.conftest import setup_mock_tenant_account_query
from tests.unit_tests.conftest import (
setup_mock_dataset_owner_execute_result,
setup_mock_tenant_owner_execute_result,
)
@pytest.fixture
@ -123,9 +126,7 @@ class AuthenticationMocker:
mock_db.session.get.side_effect = [mock_app, mock_tenant]
if mock_account:
mock_ta = Mock()
mock_ta.account_id = mock_account.id
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_ta)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
@staticmethod
def setup_dataset_auth(mock_db, mock_tenant, mock_account):
@ -133,8 +134,7 @@ class AuthenticationMocker:
mock_ta = Mock()
mock_ta.account_id = mock_account.id
mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_ta)
setup_mock_dataset_owner_execute_result(mock_db, mock_tenant, mock_ta)
mock_db.session.get.return_value = mock_account

View File

@ -701,8 +701,8 @@ class TestDocumentApiDelete:
``delete`` is wrapped by ``@cloud_edition_billing_rate_limit_check`` which
internally calls ``validate_and_get_api_token``. To bypass the decorator
we call the original function via ``__wrapped__`` (preserved by
``functools.wraps``). ``delete`` queries the dataset via
``db.session.query(Dataset)`` directly, so we patch ``db`` at the
``functools.wraps``). ``delete`` loads the dataset via
``db.session.scalar(select(Dataset)...)``, so we patch ``db`` at the
controller module.
"""

View File

@ -24,8 +24,8 @@ from enums.cloud_plan import CloudPlan
from models.account import TenantStatus
from models.model import ApiToken
from tests.unit_tests.conftest import (
setup_mock_dataset_tenant_query,
setup_mock_tenant_account_query,
setup_mock_dataset_owner_execute_result,
setup_mock_tenant_owner_execute_result,
)
@ -141,14 +141,11 @@ class TestValidateAppToken:
mock_account = Mock()
mock_account.id = str(uuid.uuid4())
mock_ta = Mock()
mock_ta.account_id = mock_account.id
# Use side_effect to return app first, then tenant via session.get()
mock_db.session.get.side_effect = [mock_app, mock_tenant]
# Mock the tenant owner query (execute(select(...)).one_or_none())
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_ta)
# Mock the tenant owner execute result (execute(select(...)).one_or_none())
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
@validate_app_token
def protected_view(app_model):
@ -471,7 +468,7 @@ class TestValidateDatasetToken:
mock_account.current_tenant = mock_tenant
# Mock the tenant account join query (execute(select(...)).one_or_none())
setup_mock_dataset_tenant_query(mock_db, mock_tenant, mock_ta)
setup_mock_dataset_owner_execute_result(mock_db, mock_tenant, mock_ta)
# Mock the account lookup via session.get()
mock_db.session.get.return_value = mock_account

View File

@ -22,18 +22,16 @@ class FakeSession:
def __init__(self, mapping: dict[str, Any] | None = None):
self._mapping: dict[str, Any] = mapping or {}
self._model_name: str | None = None
def query(self, model: type) -> FakeSession:
self._model_name = model.__name__
return self
def get(self, model: type, _ident: object) -> Any:
return self._mapping.get(model.__name__)
def where(self, *_args: object, **_kwargs: object) -> FakeSession:
return self
def first(self) -> Any:
assert self._model_name is not None
return self._mapping.get(self._model_name)
def scalar(self, stmt: Any) -> Any:
try:
model = stmt.column_descriptions[0]["entity"]
except (AttributeError, IndexError, KeyError, TypeError):
return None
return self._mapping.get(model.__name__)
class FakeDB:

View File

@ -36,18 +36,6 @@ class _FakeSession:
def __init__(self, mapping: dict[str, Any]):
self._mapping = mapping
self._model_name: str | None = None
def query(self, model):
self._model_name = model.__name__
return self
def where(self, *args, **kwargs):
return self
def first(self):
assert self._model_name is not None
return self._mapping.get(self._model_name)
def get(self, model, ident):
return self._mapping.get(model.__name__)

View File

@ -34,7 +34,6 @@ def _patch_wraps():
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features),
patch("controllers.web.login.dify_config", web_dify),
):
mock_db.session.query.return_value.first.return_value = MagicMock()
yield

View File

@ -154,7 +154,6 @@ class TestAdvancedChatAppRunnerConversationVariables:
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
mock_session_class.return_value.__enter__.return_value = MagicMock()
mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists
mock_db.engine = MagicMock()
# Mock GraphRuntimeState to accept the variable pool
@ -301,7 +300,6 @@ class TestAdvancedChatAppRunnerConversationVariables:
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
mock_session_class.return_value.__enter__.return_value = MagicMock()
mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists
mock_db.engine = MagicMock()
# Mock ConversationVariable.from_variable to return mock objects
@ -453,7 +451,6 @@ class TestAdvancedChatAppRunnerConversationVariables:
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
mock_session_class.return_value.__enter__.return_value = MagicMock()
mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists
mock_db.engine = MagicMock()
# Mock GraphRuntimeState to accept the variable pool

View File

@ -1,7 +1,10 @@
from collections.abc import Generator
import pytest
from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter
from core.app.entities.task_entities import (
AdvancedChatPausedBlockingResponse,
ChatbotAppBlockingResponse,
ChatbotAppStreamResponse,
ErrorStreamResponse,
@ -10,7 +13,8 @@ from core.app.entities.task_entities import (
NodeStartStreamResponse,
PingStreamResponse,
)
from graphon.enums import WorkflowNodeExecutionStatus
from graphon.entities.pause_reason import PauseReasonType
from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus
class TestAdvancedChatGenerateResponseConverter:
@ -28,6 +32,37 @@ class TestAdvancedChatGenerateResponseConverter:
response = AdvancedChatAppGenerateResponseConverter.convert_blocking_simple_response(blocking)
assert "usage" not in response["metadata"]
def test_blocking_full_response_derives_pause_data_from_model_dump(self, monkeypatch: pytest.MonkeyPatch):
data = AdvancedChatPausedBlockingResponse.Data(
id="msg-1",
mode="chat",
conversation_id="c1",
message_id="m1",
workflow_run_id="run-1",
answer="partial",
metadata={"usage": {"total_tokens": 1}},
created_at=1,
paused_nodes=["node-1"],
reasons=[{"type": PauseReasonType.HUMAN_INPUT_REQUIRED, "form_id": "form-1"}],
status=WorkflowExecutionStatus.PAUSED,
elapsed_time=0.1,
total_tokens=0,
total_steps=0,
)
original_model_dump = type(data).model_dump
def _model_dump_with_future_field(self, *args, **kwargs):
payload = original_model_dump(self, *args, **kwargs)
payload["future_field"] = "future-value"
return payload
monkeypatch.setattr(type(data), "model_dump", _model_dump_with_future_field)
blocking = AdvancedChatPausedBlockingResponse(task_id="t1", data=data)
response = AdvancedChatAppGenerateResponseConverter.convert_blocking_full_response(blocking)
assert response["data"]["future_field"] == "future-value"
def test_stream_simple_response_includes_node_events(self):
node_start = NodeStartStreamResponse(
task_id="t1",

View File

@ -39,15 +39,19 @@ from core.app.entities.queue_entities import (
QueueWorkflowSucceededEvent,
)
from core.app.entities.task_entities import (
AdvancedChatPausedBlockingResponse,
AnnotationReply,
AnnotationReplyAccount,
HumanInputRequiredResponse,
MessageAudioStreamResponse,
MessageEndStreamResponse,
PingStreamResponse,
)
from core.base.tts.app_generator_tts_publisher import AudioTrunk
from core.workflow.system_variables import build_system_variables
from graphon.entities.pause_reason import PauseReasonType
from graphon.enums import BuiltinNodeTypes
from graphon.nodes.human_input.entities import UserAction
from graphon.runtime import GraphRuntimeState, VariablePool
from libs.datetime_utils import naive_utc_now
from models.enums import MessageStatus
@ -123,6 +127,57 @@ class TestAdvancedChatGenerateTaskPipeline:
assert response.data.answer == "done"
assert response.data.metadata == {"k": "v"}
def test_to_blocking_response_falls_back_to_human_input_required_when_pause_event_missing(self):
pipeline = _make_pipeline()
pipeline._task_state.answer = "partial answer"
pipeline._workflow_run_id = "run-id"
pipeline._graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
start_at=0.0,
total_tokens=7,
node_run_steps=3,
)
def _gen():
yield HumanInputRequiredResponse(
task_id="task",
workflow_run_id="run-id",
data=HumanInputRequiredResponse.Data(
form_id="form-1",
node_id="node-1",
node_title="Approval",
form_content="Need approval",
inputs=[],
actions=[UserAction(id="approve", title="Approve")],
display_in_ui=True,
form_token="token-1",
resolved_default_values={},
expiration_time=123,
),
)
response = pipeline._to_blocking_response(_gen())
assert isinstance(response, AdvancedChatPausedBlockingResponse)
assert response.data.workflow_run_id == "run-id"
assert response.data.status == "paused"
assert response.data.paused_nodes == ["node-1"]
assert response.data.reasons == [
{
"TYPE": PauseReasonType.HUMAN_INPUT_REQUIRED,
"form_id": "form-1",
"node_id": "node-1",
"node_title": "Approval",
"form_content": "Need approval",
"inputs": [],
"actions": [{"id": "approve", "title": "Approve", "button_style": "default"}],
"display_in_ui": True,
"form_token": "token-1",
"resolved_default_values": {},
"expiration_time": 123,
}
]
def test_handle_text_chunk_event_updates_state(self):
pipeline = _make_pipeline()
pipeline._message_cycle_manager = SimpleNamespace(

View File

@ -375,7 +375,7 @@ def test_generate_success_returns_converted(generator, mocker):
workflow = MagicMock(id="wf", tenant_id="tenant", app_id="pipe", graph_dict={})
session = MagicMock()
session.query.return_value.where.return_value.first.return_value = workflow
session.get.return_value = workflow
mocker.patch.object(module.db, "session", session)
queue_manager = MagicMock()

View File

@ -132,11 +132,8 @@ def test_run_pipeline_not_found(mocker):
app_generate_entity.single_iteration_run = None
app_generate_entity.single_loop_run = None
query = MagicMock()
query.where.return_value.first.return_value = None
session = MagicMock()
session.query.return_value = query
session.get.side_effect = [None, None]
mocker.patch.object(module.db, "session", session)
runner = PipelineRunner(
@ -157,11 +154,9 @@ def test_run_workflow_not_initialized(mocker):
app_generate_entity = _build_app_generate_entity()
pipeline = MagicMock(id="pipe")
query_pipeline = MagicMock()
query_pipeline.where.return_value.first.return_value = pipeline
session = MagicMock()
session.query.return_value = query_pipeline
session.get.side_effect = [None, pipeline]
mocker.patch.object(module.db, "session", session)
runner = PipelineRunner(

View File

@ -0,0 +1,102 @@
from __future__ import annotations
from collections.abc import Generator
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.task_entities import (
AppStreamResponse,
PingStreamResponse,
WorkflowAppBlockingResponse,
WorkflowAppStreamResponse,
)
from graphon.enums import WorkflowExecutionStatus
class _DummyConverter(AppGenerateResponseConverter[WorkflowAppBlockingResponse]):
blocking_full_calls: list[WorkflowAppBlockingResponse] = []
blocking_simple_calls: list[WorkflowAppBlockingResponse] = []
stream_full_calls: list[Generator[AppStreamResponse, None, None]] = []
stream_simple_calls: list[Generator[AppStreamResponse, None, None]] = []
@classmethod
def reset(cls) -> None:
cls.blocking_full_calls = []
cls.blocking_simple_calls = []
cls.stream_full_calls = []
cls.stream_simple_calls = []
@classmethod
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, object]:
cls.blocking_full_calls.append(blocking_response)
return {"kind": "blocking-full", "task_id": blocking_response.task_id}
@classmethod
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, object]:
cls.blocking_simple_calls.append(blocking_response)
return {"kind": "blocking-simple", "task_id": blocking_response.task_id}
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
cls.stream_full_calls.append(stream_response)
yield {"kind": "stream-full"}
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
cls.stream_simple_calls.append(stream_response)
yield {"kind": "stream-simple"}
def _build_blocking_response() -> WorkflowAppBlockingResponse:
return WorkflowAppBlockingResponse(
task_id="task-1",
workflow_run_id="run-1",
data=WorkflowAppBlockingResponse.Data(
id="run-1",
workflow_id="workflow-1",
status=WorkflowExecutionStatus.SUCCEEDED,
outputs={"ok": True},
error=None,
elapsed_time=0.1,
total_tokens=0,
total_steps=1,
created_at=1,
finished_at=2,
),
)
def _build_stream_response() -> Generator[AppStreamResponse, None, None]:
yield WorkflowAppStreamResponse(
workflow_run_id="run-1",
stream_response=PingStreamResponse(task_id="task-1"),
)
def test_convert_routes_blocking_response_by_invoke_from() -> None:
_DummyConverter.reset()
blocking_response = _build_blocking_response()
full_result = _DummyConverter.convert(blocking_response, InvokeFrom.SERVICE_API)
simple_result = _DummyConverter.convert(blocking_response, InvokeFrom.WEB_APP)
assert full_result == {"kind": "blocking-full", "task_id": "task-1"}
assert simple_result == {"kind": "blocking-simple", "task_id": "task-1"}
assert _DummyConverter.blocking_full_calls == [blocking_response]
assert _DummyConverter.blocking_simple_calls == [blocking_response]
def test_convert_routes_stream_response_by_invoke_from() -> None:
_DummyConverter.reset()
full_result = list(_DummyConverter.convert(_build_stream_response(), InvokeFrom.SERVICE_API))
simple_result = list(_DummyConverter.convert(_build_stream_response(), InvokeFrom.WEB_APP))
assert full_result == [{"kind": "stream-full"}]
assert simple_result == [{"kind": "stream-simple"}]
assert len(_DummyConverter.stream_full_calls) == 1
assert len(_DummyConverter.stream_simple_calls) == 1

View File

@ -1,6 +1,7 @@
from unittest.mock import Mock, patch
from core.app.apps.message_generator import MessageGenerator
from core.app.entities.task_entities import StreamEvent
from models.model import AppMode
@ -23,7 +24,21 @@ class TestMessageGenerator:
"core.app.apps.message_generator.stream_topic_events", return_value=iter([{"event": "ping"}])
) as mock_stream,
):
events = list(MessageGenerator.retrieve_events(AppMode.WORKFLOW, "run-1", idle_timeout=1, ping_interval=2))
events = list(
MessageGenerator.retrieve_events(
AppMode.WORKFLOW,
"run-1",
idle_timeout=1,
ping_interval=2,
terminal_events=[StreamEvent.WORKFLOW_FINISHED.value],
)
)
assert events == [{"event": "ping"}]
mock_stream.assert_called_once()
mock_stream.assert_called_once_with(
topic="topic",
idle_timeout=1,
ping_interval=2,
on_subscribe=None,
terminal_events=[StreamEvent.WORKFLOW_FINISHED.value],
)

View File

@ -88,6 +88,10 @@ def test_normalize_terminal_events_defaults():
}
def test_normalize_terminal_events_empty_values():
assert _normalize_terminal_events([]) == set({})
def test_stream_topic_events_emits_ping_and_idle_timeout(monkeypatch):
topic = FakeTopic()
times = [1000.0, 1000.0, 1001.0, 1001.0, 1002.0]
@ -106,3 +110,21 @@ def test_stream_topic_events_emits_ping_and_idle_timeout(monkeypatch):
assert next(generator) == StreamEvent.PING.value
# next receive yields None -> ping interval triggers
assert next(generator) == StreamEvent.PING.value
def test_stream_topic_events_can_continue_past_pause():
topic = FakeTopic()
topic.publish(json.dumps({"event": StreamEvent.WORKFLOW_PAUSED.value}).encode())
topic.publish(json.dumps({"event": StreamEvent.WORKFLOW_FINISHED.value}).encode())
generator = stream_topic_events(
topic=topic,
idle_timeout=1.0,
terminal_events=[StreamEvent.WORKFLOW_FINISHED.value],
)
assert next(generator) == StreamEvent.PING.value
assert next(generator)["event"] == StreamEvent.WORKFLOW_PAUSED.value
assert next(generator)["event"] == StreamEvent.WORKFLOW_FINISHED.value
with pytest.raises(StopIteration):
next(generator)

View File

@ -36,11 +36,12 @@ from core.app.entities.queue_entities import (
)
from core.app.entities.task_entities import (
ErrorStreamResponse,
HumanInputRequiredResponse,
MessageAudioEndStreamResponse,
MessageAudioStreamResponse,
PingStreamResponse,
WorkflowAppPausedBlockingResponse,
WorkflowFinishStreamResponse,
WorkflowPauseStreamResponse,
WorkflowStartStreamResponse,
)
from core.base.tts.app_generator_tts_publisher import AudioTrunk
@ -91,27 +92,50 @@ def _make_pipeline():
class TestWorkflowGenerateTaskPipeline:
def test_to_blocking_response_handles_pause(self):
def test_to_blocking_response_falls_back_to_human_input_required_when_pause_event_missing(self):
pipeline = _make_pipeline()
pipeline._graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
start_at=0.0,
total_tokens=5,
node_run_steps=2,
)
def _gen():
yield WorkflowPauseStreamResponse(
yield HumanInputRequiredResponse(
task_id="task",
workflow_run_id="run",
data=WorkflowPauseStreamResponse.Data(
workflow_run_id="run",
status=WorkflowExecutionStatus.PAUSED,
outputs={},
created_at=1,
elapsed_time=0.1,
total_tokens=0,
total_steps=0,
workflow_run_id="run-id",
data=HumanInputRequiredResponse.Data(
form_id="form-1",
node_id="node-1",
node_title="Human Input",
form_content="content",
expiration_time=1,
),
)
response = pipeline._to_blocking_response(_gen())
assert isinstance(response, WorkflowAppPausedBlockingResponse)
assert response.workflow_run_id == "run-id"
assert response.data.status == WorkflowExecutionStatus.PAUSED
assert response.data.created_at == 0
assert response.data.paused_nodes == ["node-1"]
assert response.data.reasons == [
{
"TYPE": "human_input_required",
"form_id": "form-1",
"node_id": "node-1",
"node_title": "Human Input",
"form_content": "content",
"inputs": [],
"actions": [],
"display_in_ui": False,
"form_token": None,
"resolved_default_values": {},
"expiration_time": 1,
}
]
def test_to_blocking_response_handles_finish(self):
pipeline = _make_pipeline()

View File

@ -775,9 +775,6 @@ class TestNotionExtractorLastEditedTime:
"last_edited_time": "2024-11-27T18:00:00.000Z",
}
mock_request.return_value = mock_response
mock_query = Mock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
# Act
extractor_page.update_last_edited_time(mock_document_model)
@ -863,9 +860,6 @@ class TestNotionExtractorIntegration:
}
mock_request.side_effect = [last_edited_response, block_response]
mock_query = Mock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
# Act
documents = extractor.extract()
@ -919,10 +913,6 @@ class TestNotionExtractorIntegration:
}
mock_post.return_value = database_response
mock_query = Mock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
# Act
documents = extractor.extract()

View File

@ -0,0 +1,106 @@
"""Tests for the Creators Platform helper module."""
from unittest.mock import MagicMock, patch
import httpx
import pytest
from yarl import URL
@pytest.fixture(autouse=True)
def _patch_creators_url(monkeypatch):
"""Patch the module-level creators_platform_api_url for all tests."""
monkeypatch.setattr(
"core.helper.creators.creators_platform_api_url",
URL("https://creators.example.com"),
)
class TestUploadDSL:
@patch("core.helper.creators.httpx.post")
def test_returns_claim_code(self, mock_post):
mock_response = MagicMock(spec=httpx.Response)
mock_response.json.return_value = {"data": {"claim_code": "abc123"}}
mock_response.raise_for_status = MagicMock()
mock_post.return_value = mock_response
from core.helper.creators import upload_dsl
result = upload_dsl(b"app: demo", "demo.yaml")
assert result == "abc123"
mock_post.assert_called_once()
call_kwargs = mock_post.call_args
assert "anonymous-upload" in call_kwargs.args[0]
assert call_kwargs.kwargs["timeout"] == 30
@patch("core.helper.creators.httpx.post")
def test_raises_on_missing_claim_code(self, mock_post):
mock_response = MagicMock(spec=httpx.Response)
mock_response.json.return_value = {"data": {}}
mock_response.raise_for_status = MagicMock()
mock_post.return_value = mock_response
from core.helper.creators import upload_dsl
with pytest.raises(ValueError, match="claim_code"):
upload_dsl(b"app: demo")
@patch("core.helper.creators.httpx.post")
def test_raises_on_http_error(self, mock_post):
mock_response = MagicMock(spec=httpx.Response)
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
"Server Error",
request=MagicMock(),
response=MagicMock(),
)
mock_post.return_value = mock_response
from core.helper.creators import upload_dsl
with pytest.raises(httpx.HTTPStatusError):
upload_dsl(b"app: demo")
class TestGetRedirectUrl:
@patch("core.helper.creators.dify_config")
def test_without_oauth_client_id(self, mock_config):
mock_config.CREATORS_PLATFORM_API_URL = "https://creators.example.com"
mock_config.CREATORS_PLATFORM_OAUTH_CLIENT_ID = ""
from core.helper.creators import get_redirect_url
url = get_redirect_url("user-1", "claim-abc")
assert "dsl_claim_code=claim-abc" in url
assert "oauth_code" not in url
assert url.startswith("https://creators.example.com")
@patch("core.helper.creators.dify_config")
def test_with_oauth_client_id(self, mock_config):
mock_config.CREATORS_PLATFORM_API_URL = "https://creators.example.com"
mock_config.CREATORS_PLATFORM_OAUTH_CLIENT_ID = "client-xyz"
with patch(
"services.oauth_server.OAuthServerService.sign_oauth_authorization_code",
return_value="oauth-code-123",
) as mock_sign:
from core.helper.creators import get_redirect_url
url = get_redirect_url("user-1", "claim-abc")
mock_sign.assert_called_once_with("client-xyz", "user-1")
assert "dsl_claim_code=claim-abc" in url
assert "oauth_code=oauth-code-123" in url
@patch("core.helper.creators.dify_config")
def test_strips_trailing_slash(self, mock_config):
mock_config.CREATORS_PLATFORM_API_URL = "https://creators.example.com/"
mock_config.CREATORS_PLATFORM_OAUTH_CLIENT_ID = ""
from core.helper.creators import get_redirect_url
url = get_redirect_url("user-1", "claim-abc")
assert url.startswith("https://creators.example.com?")
assert "creators.example.com/?" not in url

Some files were not shown because too many files have changed in this diff Show More