Mirror of https://github.com/langgenius/dify.git (synced 2026-05-06 18:27:19 +08:00)

Commit 5263a65ed6: Merge branch 'main' into jzh
@@ -659,6 +659,11 @@ INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y
MARKETPLACE_ENABLED=true
MARKETPLACE_API_URL=https://marketplace.dify.ai

# Creators Platform configuration
CREATORS_PLATFORM_FEATURES_ENABLED=true
CREATORS_PLATFORM_API_URL=https://creators.dify.ai
CREATORS_PLATFORM_OAUTH_CLIENT_ID=

# Endpoint configuration
ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id}
@@ -11,7 +11,7 @@ from configs import dify_config
from core.helper import encrypter
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.plugin import PluginInstaller
from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params
from core.tools.utils.system_encryption import encrypt_system_params
from extensions.ext_database import db
from models import Tenant
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
@@ -44,7 +44,7 @@ def setup_system_tool_oauth_client(provider, client_params):

        click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
        click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
        oauth_client_params = encrypt_system_oauth_params(client_params_dict)
        oauth_client_params = encrypt_system_params(client_params_dict)
        click.echo(click.style("Client params encrypted successfully.", fg="green"))
    except Exception as e:
        click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
@@ -94,7 +94,7 @@ def setup_system_trigger_oauth_client(provider, client_params):

        click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
        click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
        oauth_client_params = encrypt_system_oauth_params(client_params_dict)
        oauth_client_params = encrypt_system_params(client_params_dict)
        click.echo(click.style("Client params encrypted successfully.", fg="green"))
    except Exception as e:
        click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
@@ -287,6 +287,27 @@ class MarketplaceConfig(BaseSettings):
    )


class CreatorsPlatformConfig(BaseSettings):
    """
    Configuration for Creators Platform integration
    """

    CREATORS_PLATFORM_FEATURES_ENABLED: bool = Field(
        description="Enable or disable Creators Platform features",
        default=True,
    )

    CREATORS_PLATFORM_API_URL: HttpUrl = Field(
        description="Creators Platform API URL",
        default=HttpUrl("https://creators.dify.ai"),
    )

    CREATORS_PLATFORM_OAUTH_CLIENT_ID: str = Field(
        description="OAuth client ID for Creators Platform integration",
        default="",
    )


class EndpointConfig(BaseSettings):
    """
    Configuration for various application endpoints and URLs
@@ -1379,6 +1400,7 @@ class FeatureConfig(
    AuthConfig,  # Changed from OAuthConfig to AuthConfig
    BillingConfig,
    CodeExecutionSandboxConfig,
    CreatorsPlatformConfig,
    TriggerConfig,
    AsyncWorkflowConfig,
    PluginConfig,
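A minimal, self-contained sketch (not part of the diff) of the pydantic-settings pattern the new CreatorsPlatformConfig follows; the field names mirror the diff, while the composition and usage below are illustrative assumptions.

    from pydantic import Field, HttpUrl
    from pydantic_settings import BaseSettings

    class CreatorsPlatformConfig(BaseSettings):
        CREATORS_PLATFORM_FEATURES_ENABLED: bool = Field(default=True)
        CREATORS_PLATFORM_API_URL: HttpUrl = Field(default=HttpUrl("https://creators.dify.ai"))
        CREATORS_PLATFORM_OAUTH_CLIENT_ID: str = Field(default="")

    # In the real FeatureConfig the mixin sits alongside BillingConfig, TriggerConfig, etc.
    class FeatureConfig(CreatorsPlatformConfig):
        pass

    config = FeatureConfig()  # overridden by CREATORS_PLATFORM_* environment variables if set
    print(config.CREATORS_PLATFORM_FEATURES_ENABLED, config.CREATORS_PLATFORM_API_URL)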
api/controllers/common/human_input.py (new file, 6 lines)
@@ -0,0 +1,6 @@
from pydantic import BaseModel, JsonValue


class HumanInputFormSubmitPayload(BaseModel):
    inputs: dict[str, JsonValue]
    action: str
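A hedged usage sketch of the shared payload model; the example body is illustrative, not taken from the diff.

    from pydantic import BaseModel, JsonValue

    class HumanInputFormSubmitPayload(BaseModel):  # same shape as the shared model above
        inputs: dict[str, JsonValue]
        action: str

    payload = HumanInputFormSubmitPayload.model_validate(
        {"inputs": {"approved": True, "comment": "ship it"}, "action": "approve"}
    )
    print(payload.action, payload.inputs["approved"])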
@@ -692,6 +692,32 @@ class AppExportApi(Resource):
        return payload.model_dump(mode="json")


@console_ns.route("/apps/<uuid:app_id>/publish-to-creators-platform")
class AppPublishToCreatorsPlatformApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @get_app_model(mode=None)
    @edit_permission_required
    def post(self, app_model):
        """Publish app to Creators Platform"""
        from configs import dify_config
        from core.helper.creators import get_redirect_url, upload_dsl

        if not dify_config.CREATORS_PLATFORM_FEATURES_ENABLED:
            return {"error": "Creators Platform features are not enabled"}, 403

        current_user, _ = current_account_with_tenant()

        dsl_content = AppDslService.export_dsl(app_model=app_model, include_secret=False)
        dsl_bytes = dsl_content.encode("utf-8")

        claim_code = upload_dsl(dsl_bytes)
        redirect_url = get_redirect_url(str(current_user.id), claim_code)

        return {"redirect_url": redirect_url}


@console_ns.route("/apps/<uuid:app_id>/name")
class AppNameApi(Resource):
    @console_ns.doc("check_app_name")
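A hedged sketch of driving the new console endpoint from a script; the base URL, app id, and bearer token are placeholders, and the console UI is the expected caller.

    import requests

    resp = requests.post(
        "http://localhost:5001/console/api/apps/<app_id>/publish-to-creators-platform",
        headers={"Authorization": "Bearer <console-access-token>"},
        timeout=30,
    )
    resp.raise_for_status()
    print(resp.json()["redirect_url"])  # the browser would follow this URL to claim the app on the Creators Platform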
@@ -8,10 +8,10 @@ from collections.abc import Generator

from flask import Response, jsonify, request
from flask_restx import Resource
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker

from controllers.common.human_input import HumanInputFormSubmitPayload
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError
@@ -20,11 +20,11 @@ from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.apps.message_generator import MessageGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
from extensions.ext_database import db
from libs.login import current_account_with_tenant, login_required
from models import App
from models.enums import CreatorUserRole
from models.human_input import RecipientType
from models.model import AppMode
from models.workflow import WorkflowRun
from repositories.factory import DifyAPIRepositoryFactory
@@ -34,11 +34,6 @@ from services.workflow_event_snapshot_service import build_workflow_event_stream
logger = logging.getLogger(__name__)


class HumanInputFormSubmitPayload(BaseModel):
    inputs: dict
    action: str


def _jsonify_form_definition(form: Form) -> Response:
    payload = form.get_definition().model_dump()
    payload["expiration_time"] = int(form.expiration_time.timestamp())
@@ -56,6 +51,11 @@ class ConsoleHumanInputFormApi(Resource):
        if form.tenant_id != current_tenant_id:
            raise NotFoundError("App not found")

    @staticmethod
    def _ensure_console_recipient_type(form: Form) -> None:
        if not is_recipient_type_allowed_for_surface(form.recipient_type, HumanInputSurface.CONSOLE):
            raise NotFoundError("form not found")

    @setup_required
    @login_required
    @account_initialization_required
@@ -99,10 +99,8 @@ class ConsoleHumanInputFormApi(Resource):
            raise NotFoundError(f"form not found, token={form_token}")

        self._ensure_console_access(form)

        self._ensure_console_recipient_type(form)
        recipient_type = form.recipient_type
        if recipient_type not in {RecipientType.CONSOLE, RecipientType.BACKSTAGE}:
            raise NotFoundError(f"form not found, token={form_token}")
        # The type checker is not smart enough to validate the following invariant.
        # So we need to assert it manually.
        assert recipient_type is not None, "recipient_type cannot be None here."
@@ -37,6 +37,11 @@ class TagBindingRemovePayload(BaseModel):
    type: TagType = Field(description="Tag type")


class TagBindingItemDeletePayload(BaseModel):
    target_id: str = Field(description="Target ID to unbind tag from")
    type: TagType = Field(description="Tag type")


class TagListQueryParam(BaseModel):
    type: Literal["knowledge", "app", ""] = Field("", description="Tag type filter")
    keyword: str | None = Field(None, description="Search keyword")
@@ -70,6 +75,7 @@ register_schema_models(
    TagBasePayload,
    TagBindingPayload,
    TagBindingRemovePayload,
    TagBindingItemDeletePayload,
    TagListQueryParam,
    TagResponse,
)
@@ -152,41 +158,107 @@ class TagUpdateDeleteApi(Resource):
        return "", 204


@console_ns.route("/tag-bindings/create")
class TagBindingCreateApi(Resource):
def _require_tag_binding_edit_permission() -> None:
    """
    Ensure the current account can edit tag bindings.

    Tag binding operations are allowed for users who can edit resources (app/dataset) within the current tenant.
    """
    current_user, _ = current_account_with_tenant()
    # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
    if not (current_user.has_edit_permission or current_user.is_dataset_editor):
        raise Forbidden()


def _create_tag_bindings() -> tuple[dict[str, str], int]:
    _require_tag_binding_edit_permission()

    payload = TagBindingPayload.model_validate(console_ns.payload or {})
    TagService.save_tag_binding(
        TagBindingCreatePayload(
            tag_ids=payload.tag_ids,
            target_id=payload.target_id,
            type=payload.type,
        )
    )
    return {"result": "success"}, 200


def _remove_tag_binding() -> tuple[dict[str, str], int]:
    _require_tag_binding_edit_permission()

    payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
    TagService.delete_tag_binding(
        TagBindingDeletePayload(
            tag_id=payload.tag_id,
            target_id=payload.target_id,
            type=payload.type,
        )
    )
    return {"result": "success"}, 200


@console_ns.route("/tag-bindings")
class TagBindingCollectionApi(Resource):
    """Canonical collection resource for tag binding creation."""

    @console_ns.doc("create_tag_binding")
    @console_ns.expect(console_ns.models[TagBindingPayload.__name__])
    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        current_user, _ = current_account_with_tenant()
        # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
        if not (current_user.has_edit_permission or current_user.is_dataset_editor):
            raise Forbidden()
        return _create_tag_bindings()

        payload = TagBindingPayload.model_validate(console_ns.payload or {})
        TagService.save_tag_binding(
            TagBindingCreatePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=payload.type)

@console_ns.route("/tag-bindings/<uuid:id>")
class TagBindingItemApi(Resource):
    """Canonical item resource for tag binding deletion."""

    @console_ns.doc("delete_tag_binding")
    @console_ns.doc(params={"id": "Tag ID"})
    @console_ns.expect(console_ns.models[TagBindingItemDeletePayload.__name__])
    @setup_required
    @login_required
    @account_initialization_required
    def delete(self, id):
        _require_tag_binding_edit_permission()
        payload = TagBindingItemDeletePayload.model_validate(console_ns.payload or {})
        TagService.delete_tag_binding(
            TagBindingDeletePayload(
                tag_id=str(id),
                target_id=payload.target_id,
                type=payload.type,
            )
        )

        return {"result": "success"}, 200


@console_ns.route("/tag-bindings/create")
class DeprecatedTagBindingCreateApi(Resource):
    """Deprecated verb-based alias for tag binding creation."""

    @console_ns.doc("create_tag_binding_deprecated")
    @console_ns.doc(deprecated=True)
    @console_ns.doc(description="Deprecated legacy alias. Use POST /tag-bindings instead.")
    @console_ns.expect(console_ns.models[TagBindingPayload.__name__])
    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        return _create_tag_bindings()


@console_ns.route("/tag-bindings/remove")
class TagBindingDeleteApi(Resource):
class DeprecatedTagBindingRemoveApi(Resource):
    """Deprecated verb-based alias for tag binding deletion."""

    @console_ns.doc("delete_tag_binding_deprecated")
    @console_ns.doc(deprecated=True)
    @console_ns.doc(description="Deprecated legacy alias. Use DELETE /tag-bindings/{id} instead.")
    @console_ns.expect(console_ns.models[TagBindingRemovePayload.__name__])
    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        current_user, _ = current_account_with_tenant()
        # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
        if not (current_user.has_edit_permission or current_user.is_dataset_editor):
            raise Forbidden()

        payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
        TagService.delete_tag_binding(
            TagBindingDeletePayload(tag_id=payload.tag_id, target_id=payload.target_id, type=payload.type)
        )

        return {"result": "success"}, 200
        return _remove_tag_binding()
@@ -23,9 +23,11 @@ from .app import (
    conversation,
    file,
    file_preview,
    human_input_form,
    message,
    site,
    workflow,
    workflow_events,
)
from .dataset import (
    dataset,
@@ -50,6 +52,7 @@ __all__ = [
    "file",
    "file_preview",
    "hit_testing",
    "human_input_form",
    "index",
    "message",
    "metadata",
@@ -58,6 +61,7 @@ __all__ = [
    "segment",
    "site",
    "workflow",
    "workflow_events",
]

api.add_namespace(service_api_ns)
api/controllers/service_api/app/human_input_form.py (new file, 137 lines)
@@ -0,0 +1,137 @@
"""
Service API human input form endpoints.

This module exposes app-token authenticated APIs for fetching and submitting
paused human input forms in workflow/chatflow runs.
"""

import json
import logging
from datetime import datetime

from flask import Response
from flask_restx import Resource
from werkzeug.exceptions import BadRequest, NotFound

from controllers.common.human_input import HumanInputFormSubmitPayload
from controllers.common.schema import register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
from extensions.ext_database import db
from models.model import App, EndUser
from services.human_input_service import Form, FormNotFoundError, HumanInputService

logger = logging.getLogger(__name__)


register_schema_models(service_api_ns, HumanInputFormSubmitPayload)


def _stringify_default_values(values: dict[str, object]) -> dict[str, str]:
    result: dict[str, str] = {}
    for key, value in values.items():
        if value is None:
            result[key] = ""
        elif isinstance(value, (dict, list)):
            result[key] = json.dumps(value, ensure_ascii=False)
        else:
            result[key] = str(value)
    return result


def _to_timestamp(value: datetime) -> int:
    return int(value.timestamp())


def _jsonify_form_definition(form: Form) -> Response:
    definition_payload = form.get_definition().model_dump()
    payload = {
        "form_content": definition_payload["rendered_content"],
        "inputs": definition_payload["inputs"],
        "resolved_default_values": _stringify_default_values(definition_payload["default_values"]),
        "user_actions": definition_payload["user_actions"],
        "expiration_time": _to_timestamp(form.expiration_time),
    }
    return Response(json.dumps(payload, ensure_ascii=False), mimetype="application/json")


def _ensure_form_belongs_to_app(form: Form, app_model: App) -> None:
    if form.app_id != app_model.id or form.tenant_id != app_model.tenant_id:
        raise NotFound("Form not found")


def _ensure_form_is_allowed_for_service_api(form: Form) -> None:
    # Keep app-token callers scoped to the public web-form surface; internal HITL
    # routes must continue to flow through console-only authentication.
    if not is_recipient_type_allowed_for_surface(form.recipient_type, HumanInputSurface.SERVICE_API):
        raise NotFound("Form not found")


@service_api_ns.route("/form/human_input/<string:form_token>")
class WorkflowHumanInputFormApi(Resource):
    @service_api_ns.doc("get_human_input_form")
    @service_api_ns.doc(description="Get a paused human input form by token")
    @service_api_ns.doc(params={"form_token": "Human input form token"})
    @service_api_ns.doc(
        responses={
            200: "Form retrieved successfully",
            401: "Unauthorized - invalid API token",
            404: "Form not found",
            412: "Form already submitted or expired",
        }
    )
    @validate_app_token
    def get(self, app_model: App, form_token: str):
        service = HumanInputService(db.engine)
        form = service.get_form_by_token(form_token)
        if form is None:
            raise NotFound("Form not found")

        _ensure_form_belongs_to_app(form, app_model)
        _ensure_form_is_allowed_for_service_api(form)
        service.ensure_form_active(form)
        return _jsonify_form_definition(form)

    @service_api_ns.expect(service_api_ns.models[HumanInputFormSubmitPayload.__name__])
    @service_api_ns.doc("submit_human_input_form")
    @service_api_ns.doc(description="Submit a paused human input form by token")
    @service_api_ns.doc(params={"form_token": "Human input form token"})
    @service_api_ns.doc(
        responses={
            200: "Form submitted successfully",
            400: "Bad request - invalid submission data",
            401: "Unauthorized - invalid API token",
            404: "Form not found",
            412: "Form already submitted or expired",
        }
    )
    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
    def post(self, app_model: App, end_user: EndUser, form_token: str):
        payload = HumanInputFormSubmitPayload.model_validate(service_api_ns.payload or {})

        service = HumanInputService(db.engine)
        form = service.get_form_by_token(form_token)
        if form is None:
            raise NotFound("Form not found")

        _ensure_form_belongs_to_app(form, app_model)
        _ensure_form_is_allowed_for_service_api(form)

        recipient_type = form.recipient_type
        if recipient_type is None:
            logger.warning("Recipient type is None for form, form_id=%s", form.id)
            raise BadRequest("Form recipient type is invalid")

        try:
            service.submit_form_by_token(
                recipient_type=recipient_type,
                form_token=form_token,
                selected_action_id=payload.action,
                form_data=payload.inputs,
                submission_end_user_id=end_user.id,
            )
        except FormNotFoundError:
            raise NotFound("Form not found")

        return {}, 200
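A hedged client sketch for the two endpoints above; the /v1 prefix, app API key, form token, and the "user" field expected by the JSON-based validate_app_token are assumptions, and the submit body mirrors HumanInputFormSubmitPayload.

    import requests

    BASE = "http://localhost:5001/v1"
    HEADERS = {"Authorization": "Bearer app-<your-app-api-key>"}

    form = requests.get(f"{BASE}/form/human_input/<form_token>", headers=HEADERS, timeout=10).json()
    print(form["form_content"], form["user_actions"], form["expiration_time"])

    requests.post(
        f"{BASE}/form/human_input/<form_token>",
        headers=HEADERS,
        json={"inputs": {"approved": True}, "action": "approve", "user": "end-user-123"},
        timeout=10,
    ).raise_for_status()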
api/controllers/service_api/app/workflow_events.py (new file, 142 lines)
@@ -0,0 +1,142 @@
"""
Service API workflow resume event stream endpoints.
"""

import json
from collections.abc import Generator

from flask import Response, request
from flask_restx import Resource
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import NotFound

from controllers.service_api import service_api_ns
from controllers.service_api.app.error import NotWorkflowAppError
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.apps.message_generator import MessageGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.task_entities import StreamEvent
from core.workflow.human_input_policy import HumanInputSurface
from extensions.ext_database import db
from models.enums import CreatorUserRole
from models.model import App, AppMode, EndUser
from repositories.factory import DifyAPIRepositoryFactory
from services.workflow_event_snapshot_service import build_workflow_event_stream


@service_api_ns.route("/workflow/<string:task_id>/events")
class WorkflowEventsApi(Resource):
    """Service API for getting workflow execution events after resume."""

    @service_api_ns.doc("get_workflow_events")
    @service_api_ns.doc(description="Get workflow execution events stream after resume")
    @service_api_ns.doc(
        params={
            "task_id": "Workflow run ID",
            "user": "End user identifier (query param)",
            "include_state_snapshot": (
                "Whether to replay from persisted state snapshot, "
                'specify `"true"` to include a status snapshot of executed nodes'
            ),
            "continue_on_pause": (
                "Whether to keep the stream open across workflow_paused events,"
                'specify `"true"` to keep the stream open for `workflow_paused` events.'
            ),
        }
    )
    @service_api_ns.doc(
        responses={
            200: "SSE event stream",
            401: "Unauthorized - invalid API token",
            404: "Workflow run not found",
        }
    )
    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY, required=True))
    def get(self, app_model: App, end_user: EndUser, task_id: str):
        app_mode = AppMode.value_of(app_model.mode)
        if app_mode not in {AppMode.WORKFLOW, AppMode.ADVANCED_CHAT}:
            raise NotWorkflowAppError()

        session_maker = sessionmaker(db.engine)
        repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
        workflow_run = repo.get_workflow_run_by_id_and_tenant_id(
            tenant_id=app_model.tenant_id,
            run_id=task_id,
        )

        if workflow_run is None:
            raise NotFound("Workflow run not found")

        if workflow_run.app_id != app_model.id:
            raise NotFound("Workflow run not found")

        if workflow_run.created_by_role != CreatorUserRole.END_USER:
            raise NotFound("Workflow run not found")

        if workflow_run.created_by != end_user.id:
            raise NotFound("Workflow run not found")

        workflow_run_entity = workflow_run

        if workflow_run_entity.finished_at is not None:
            response = WorkflowResponseConverter.workflow_run_result_to_finish_response(
                task_id=workflow_run_entity.id,
                workflow_run=workflow_run_entity,
                creator_user=end_user,
            )

            payload = response.model_dump(mode="json")
            payload["event"] = response.event.value

            def _generate_finished_events() -> Generator[str, None, None]:
                yield f"data: {json.dumps(payload)}\n\n"

            event_generator = _generate_finished_events
        else:
            msg_generator = MessageGenerator()
            generator: BaseAppGenerator
            if app_mode == AppMode.ADVANCED_CHAT:
                generator = AdvancedChatAppGenerator()
            elif app_mode == AppMode.WORKFLOW:
                generator = WorkflowAppGenerator()
            else:
                raise NotWorkflowAppError()

            include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true"
            continue_on_pause = request.args.get("continue_on_pause", "false").lower() == "true"
            terminal_events: list[StreamEvent] | None = [] if continue_on_pause else None

            def _generate_stream_events():
                if include_state_snapshot:
                    return generator.convert_to_event_stream(
                        build_workflow_event_stream(
                            app_mode=app_mode,
                            workflow_run=workflow_run_entity,
                            tenant_id=app_model.tenant_id,
                            app_id=app_model.id,
                            session_maker=session_maker,
                            human_input_surface=HumanInputSurface.SERVICE_API,
                            close_on_pause=not continue_on_pause,
                        )
                    )
                return generator.convert_to_event_stream(
                    msg_generator.retrieve_events(
                        app_mode,
                        workflow_run_entity.id,
                        terminal_events=terminal_events,
                    ),
                )

            event_generator = _generate_stream_events

        return Response(
            event_generator(),
            mimetype="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
            },
        )
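A hedged sketch of consuming the SSE stream above with plain requests; the run id, API key, and query flags are placeholders, and each "data:" line is assumed to carry one event as JSON.

    import json
    import requests

    url = "http://localhost:5001/v1/workflow/<workflow_run_id>/events"
    params = {"user": "end-user-123", "include_state_snapshot": "true", "continue_on_pause": "true"}
    headers = {"Authorization": "Bearer app-<your-app-api-key>"}

    with requests.get(url, params=params, headers=headers, stream=True, timeout=300) as resp:
        for raw in resp.iter_lines(decode_unicode=True):
            if raw and raw.startswith("data: "):
                event = json.loads(raw[len("data: "):])
                print(event.get("event"))  # e.g. workflow_paused, node_finished, workflow_finished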
@@ -9,11 +9,11 @@ from typing import Any, NotRequired, TypedDict

from flask import Response, request
from flask_restx import Resource
from pydantic import BaseModel
from sqlalchemy import select
from werkzeug.exceptions import Forbidden

from configs import dify_config
from controllers.common.human_input import HumanInputFormSubmitPayload
from controllers.web import web_ns
from controllers.web.error import NotFoundError, WebFormRateLimitExceededError
from controllers.web.site import serialize_app_site_payload
@@ -26,11 +26,6 @@ from services.human_input_service import Form, FormNotFoundError, HumanInputServ
logger = logging.getLogger(__name__)


class HumanInputFormSubmitPayload(BaseModel):
    inputs: dict
    action: str


_FORM_SUBMIT_RATE_LIMITER = RateLimiter(
    prefix="web_form_submit_rate_limit",
    max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,
@@ -34,7 +34,11 @@ from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
from core.app.entities.task_entities import (
    AdvancedChatPausedBlockingResponse,
    ChatbotAppBlockingResponse,
    ChatbotAppStreamResponse,
)
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
from core.helper.trace_id_helper import extract_external_trace_id_from_args
from core.ops.ops_trace_manager import TraceQueueManager
@@ -655,7 +659,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
        user: Account | EndUser,
        draft_var_saver_factory: DraftVariableSaverFactory,
        stream: bool = False,
    ) -> ChatbotAppBlockingResponse | Generator[ChatbotAppStreamResponse, None, None]:
    ) -> (
        ChatbotAppBlockingResponse
        | AdvancedChatPausedBlockingResponse
        | Generator[ChatbotAppStreamResponse, None, None]
    ):
        """
        Handle response.
        :param application_generate_entity: application generate entity
@@ -3,7 +3,7 @@ from typing import Any, cast

from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
    AppBlockingResponse,
    AdvancedChatPausedBlockingResponse,
    AppStreamResponse,
    ChatbotAppBlockingResponse,
    ChatbotAppStreamResponse,
@@ -12,22 +12,40 @@ from core.app.entities.task_entities import (
    NodeFinishStreamResponse,
    NodeStartStreamResponse,
    PingStreamResponse,
    StreamEvent,
)


class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
    _blocking_response_type = ChatbotAppBlockingResponse

class AdvancedChatAppGenerateResponseConverter(
    AppGenerateResponseConverter[ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse]
):
    @classmethod
    def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
    def convert_blocking_full_response(
        cls, blocking_response: ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse
    ) -> dict[str, Any]:
        """
        Convert blocking full response.
        :param blocking_response: blocking response
        :return:
        """
        blocking_response = cast(ChatbotAppBlockingResponse, blocking_response)
        if isinstance(blocking_response, AdvancedChatPausedBlockingResponse):
            paused_data = blocking_response.data.model_dump(mode="json")
            return {
                "event": StreamEvent.WORKFLOW_PAUSED.value,
                "task_id": blocking_response.task_id,
                "id": blocking_response.data.id,
                "message_id": blocking_response.data.message_id,
                "conversation_id": blocking_response.data.conversation_id,
                "mode": blocking_response.data.mode,
                "answer": blocking_response.data.answer,
                "metadata": blocking_response.data.metadata,
                "created_at": blocking_response.data.created_at,
                "workflow_run_id": blocking_response.data.workflow_run_id,
                "data": paused_data,
            }

        response = {
            "event": "message",
            "event": StreamEvent.MESSAGE.value,
            "task_id": blocking_response.task_id,
            "id": blocking_response.data.id,
            "message_id": blocking_response.data.message_id,
@@ -41,7 +59,9 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
        return response

    @classmethod
    def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
    def convert_blocking_simple_response(
        cls, blocking_response: ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse
    ) -> dict[str, Any]:
        """
        Convert blocking simple response.
        :param blocking_response: blocking response
@@ -50,7 +70,8 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
        response = cls.convert_blocking_full_response(blocking_response)

        metadata = response.get("metadata", {})
        response["metadata"] = cls._get_simple_metadata(metadata)
        if isinstance(metadata, dict):
            response["metadata"] = cls._get_simple_metadata(metadata)

        return response
@@ -53,14 +53,18 @@ from core.app.entities.queue_entities import (
    WorkflowQueueMessage,
)
from core.app.entities.task_entities import (
    AdvancedChatPausedBlockingResponse,
    ChatbotAppBlockingResponse,
    ChatbotAppStreamResponse,
    ErrorStreamResponse,
    HumanInputRequiredPauseReasonPayload,
    HumanInputRequiredResponse,
    MessageAudioEndStreamResponse,
    MessageAudioStreamResponse,
    MessageEndStreamResponse,
    PingStreamResponse,
    StreamResponse,
    WorkflowPauseStreamResponse,
    WorkflowTaskState,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
@@ -210,7 +214,13 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
        if message.status == MessageStatus.PAUSED and message.answer:
            self._task_state.answer = message.answer

    def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
    def process(
        self,
    ) -> Union[
        ChatbotAppBlockingResponse,
        AdvancedChatPausedBlockingResponse,
        Generator[ChatbotAppStreamResponse, None, None],
    ]:
        """
        Process generate task pipeline.
        :return:
@@ -226,14 +236,39 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
        else:
            return self._to_blocking_response(generator)

    def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> ChatbotAppBlockingResponse:
    def _to_blocking_response(
        self, generator: Generator[StreamResponse, None, None]
    ) -> Union[ChatbotAppBlockingResponse, AdvancedChatPausedBlockingResponse]:
        """
        Process blocking response.
        :return:
        """
        human_input_responses: list[HumanInputRequiredResponse] = []
        for stream_response in generator:
            if isinstance(stream_response, ErrorStreamResponse):
                raise stream_response.err
            elif isinstance(stream_response, HumanInputRequiredResponse):
                human_input_responses.append(stream_response)
            elif isinstance(stream_response, WorkflowPauseStreamResponse):
                return AdvancedChatPausedBlockingResponse(
                    task_id=stream_response.task_id,
                    data=AdvancedChatPausedBlockingResponse.Data(
                        id=self._message_id,
                        mode=self._conversation_mode,
                        conversation_id=self._conversation_id,
                        message_id=self._message_id,
                        workflow_run_id=stream_response.data.workflow_run_id,
                        answer=self._task_state.answer,
                        metadata=self._message_end_to_stream_response().metadata,
                        created_at=self._message_created_at,
                        paused_nodes=stream_response.data.paused_nodes,
                        reasons=stream_response.data.reasons,
                        status=stream_response.data.status,
                        elapsed_time=stream_response.data.elapsed_time,
                        total_tokens=stream_response.data.total_tokens,
                        total_steps=stream_response.data.total_steps,
                    ),
                )
            elif isinstance(stream_response, MessageEndStreamResponse):
                extras = {}
                if stream_response.metadata:
@@ -254,8 +289,41 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
            else:
                continue

        if human_input_responses:
            return self._build_paused_blocking_response_from_human_input(human_input_responses)

        raise ValueError("queue listening stopped unexpectedly.")

    def _build_paused_blocking_response_from_human_input(
        self, human_input_responses: list[HumanInputRequiredResponse]
    ) -> AdvancedChatPausedBlockingResponse:
        runtime_state = self._resolve_graph_runtime_state()
        paused_nodes = list(dict.fromkeys(response.data.node_id for response in human_input_responses))
        reasons = [
            HumanInputRequiredPauseReasonPayload.from_response_data(response.data).model_dump(mode="json")
            for response in human_input_responses
        ]

        return AdvancedChatPausedBlockingResponse(
            task_id=self._application_generate_entity.task_id,
            data=AdvancedChatPausedBlockingResponse.Data(
                id=self._message_id,
                mode=self._conversation_mode,
                conversation_id=self._conversation_id,
                message_id=self._message_id,
                workflow_run_id=human_input_responses[-1].workflow_run_id,
                answer=self._task_state.answer,
                metadata=self._message_end_to_stream_response().metadata,
                created_at=self._message_created_at,
                paused_nodes=paused_nodes,
                reasons=reasons,
                status=WorkflowExecutionStatus.PAUSED,
                elapsed_time=time.perf_counter() - self._base_task_pipeline.start_at,
                total_tokens=runtime_state.total_tokens,
                total_steps=runtime_state.node_run_steps,
            ),
        )

    def _to_stream_response(
        self, generator: Generator[StreamResponse, None, None]
    ) -> Generator[ChatbotAppStreamResponse, Any, None]:
@@ -1,6 +1,8 @@
from collections.abc import Generator
from typing import Any, cast

from pydantic import JsonValue

from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
    AppStreamResponse,
@@ -12,11 +14,9 @@ from core.app.entities.task_entities import (
)


class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
    _blocking_response_type = ChatbotAppBlockingResponse

class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBlockingResponse]):
    @classmethod
    def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse):  # type: ignore[override]
    def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse):
        """
        Convert blocking full response.
        :param blocking_response: blocking response
@@ -37,7 +37,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
        return response

    @classmethod
    def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse):  # type: ignore[override]
    def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse):
        """
        Convert blocking simple response.
        :param blocking_response: blocking response
@@ -70,7 +70,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
                yield "ping"
                continue

            response_chunk = {
            response_chunk: dict[str, JsonValue] = {
                "event": sub_stream_response.event.value,
                "conversation_id": chunk.conversation_id,
                "message_id": chunk.message_id,
@@ -101,7 +101,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
                yield "ping"
                continue

            response_chunk = {
            response_chunk: dict[str, JsonValue] = {
                "event": sub_stream_response.event.value,
                "conversation_id": chunk.conversation_id,
                "message_id": chunk.message_id,
@@ -1,7 +1,9 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Generator, Mapping
from typing import Any, Union
from typing import Any, Union, cast

from pydantic import JsonValue

from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse
@@ -11,8 +13,10 @@ from graphon.model_runtime.errors.invoke import InvokeError
logger = logging.getLogger(__name__)


class AppGenerateResponseConverter(ABC):
    _blocking_response_type: type[AppBlockingResponse]
class AppGenerateResponseConverter[TBlockingResponse: AppBlockingResponse](ABC):
    @classmethod
    def _cast_blocking_response(cls, response: AppBlockingResponse) -> TBlockingResponse:
        return cast(TBlockingResponse, response)

    @classmethod
    def convert(
@@ -20,7 +24,7 @@ class AppGenerateResponseConverter(ABC):
    ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
        if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
            if isinstance(response, AppBlockingResponse):
                return cls.convert_blocking_full_response(response)
                return cls.convert_blocking_full_response(cls._cast_blocking_response(response))
            else:

                def _generate_full_response() -> Generator[dict[str, Any] | str, Any, None]:
@@ -29,7 +33,7 @@ class AppGenerateResponseConverter(ABC):
                return _generate_full_response()
        else:
            if isinstance(response, AppBlockingResponse):
                return cls.convert_blocking_simple_response(response)
                return cls.convert_blocking_simple_response(cls._cast_blocking_response(response))
            else:

                def _generate_simple_response() -> Generator[dict[str, Any] | str, Any, None]:
@@ -39,12 +43,12 @@ class AppGenerateResponseConverter(ABC):

    @classmethod
    @abstractmethod
    def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
    def convert_blocking_full_response(cls, blocking_response: TBlockingResponse) -> dict[str, Any]:
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
    def convert_blocking_simple_response(cls, blocking_response: TBlockingResponse) -> dict[str, Any]:
        raise NotImplementedError

    @classmethod
@@ -106,13 +110,13 @@ class AppGenerateResponseConverter(ABC):
        return metadata

    @classmethod
    def _error_to_stream_response(cls, e: Exception) -> dict[str, Any]:
    def _error_to_stream_response(cls, e: Exception) -> dict[str, JsonValue]:
        """
        Error to stream response.
        :param e: exception
        :return:
        """
        error_responses: dict[type[Exception], dict[str, Any]] = {
        error_responses: dict[type[Exception], dict[str, JsonValue]] = {
            ValueError: {"code": "invalid_param", "status": 400},
            ProviderTokenNotInitError: {"code": "provider_not_initialize", "status": 400},
            QuotaExceededError: {
@@ -126,7 +130,7 @@ class AppGenerateResponseConverter(ABC):
        }

        # Determine the response based on the type of exception
        data: dict[str, Any] | None = None
        data: dict[str, JsonValue] | None = None
        for k, v in error_responses.items():
            if isinstance(e, k):
                data = v
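A standalone sketch (illustrative names, not from the diff) of the pattern introduced above: the `_blocking_response_type` class attribute is replaced by a PEP 695 generic base class (Python 3.12+), and the cast helper only narrows the static type.

    from abc import ABC, abstractmethod
    from typing import cast

    class BlockingResponse: ...

    class ChatBlockingResponse(BlockingResponse):
        answer: str = "hi"

    class Converter[T: BlockingResponse](ABC):
        @classmethod
        def _cast(cls, response: BlockingResponse) -> T:
            return cast(T, response)  # runtime no-op; tells the type checker which subclass to expect

        @classmethod
        @abstractmethod
        def convert(cls, response: T) -> dict[str, object]: ...

    class ChatConverter(Converter[ChatBlockingResponse]):
        @classmethod
        def convert(cls, response: ChatBlockingResponse) -> dict[str, object]:
            return {"answer": response.answer}

    resp = ChatConverter._cast(ChatBlockingResponse())
    print(ChatConverter.convert(resp))  # {'answer': 'hi'}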
@@ -1,6 +1,8 @@
from collections.abc import Generator
from typing import Any, cast

from pydantic import JsonValue

from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
    AppStreamResponse,
@@ -12,11 +14,9 @@ from core.app.entities.task_entities import (
)


class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
    _blocking_response_type = ChatbotAppBlockingResponse

class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBlockingResponse]):
    @classmethod
    def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse):  # type: ignore[override]
    def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse):
        """
        Convert blocking full response.
        :param blocking_response: blocking response
@@ -37,7 +37,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
        return response

    @classmethod
    def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse):  # type: ignore[override]
    def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse):
        """
        Convert blocking simple response.
        :param blocking_response: blocking response
@@ -70,7 +70,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
                yield "ping"
                continue

            response_chunk = {
            response_chunk: dict[str, JsonValue] = {
                "event": sub_stream_response.event.value,
                "conversation_id": chunk.conversation_id,
                "message_id": chunk.message_id,
@@ -101,7 +101,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
                yield "ping"
                continue

            response_chunk = {
            response_chunk: dict[str, JsonValue] = {
                "event": sub_stream_response.event.value,
                "conversation_id": chunk.conversation_id,
                "message_id": chunk.message_id,
@@ -52,6 +52,7 @@ from core.tools.tool_manager import ToolManager
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
from core.trigger.trigger_manager import TriggerManager
from core.workflow.human_input_forms import load_form_tokens_by_form_id
from core.workflow.human_input_policy import HumanInputSurface, enrich_human_input_pause_reasons
from core.workflow.system_variables import SystemVariableKey, system_variables_to_mapping
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
@@ -336,7 +337,26 @@ class WorkflowResponseConverter:
                except (TypeError, json.JSONDecodeError):
                    definition_payload = {}
                display_in_ui_by_form_id[str(form_id)] = bool(definition_payload.get("display_in_ui"))
            form_token_by_form_id = load_form_tokens_by_form_id(human_input_form_ids, session=session)
            form_token_by_form_id = load_form_tokens_by_form_id(
                human_input_form_ids,
                session=session,
                surface=(
                    HumanInputSurface.SERVICE_API
                    if self._application_generate_entity.invoke_from == InvokeFrom.SERVICE_API
                    else None
                ),
            )

            # Reconnect paths must preserve the same pause-reason contract as live streams;
            # otherwise clients see schema drift after resume.
            pause_reasons = enrich_human_input_pause_reasons(
                pause_reasons,
                form_tokens_by_form_id=form_token_by_form_id,
                expiration_times_by_form_id={
                    form_id: int(expiration_time.timestamp())
                    for form_id, expiration_time in expiration_times_by_form_id.items()
                },
            )

        responses: list[StreamResponse] = []
@@ -1,6 +1,8 @@
from collections.abc import Generator
from typing import Any, cast

from pydantic import JsonValue

from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
    AppStreamResponse,
@@ -12,17 +14,15 @@ from core.app.entities.task_entities import (
)


class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
    _blocking_response_type = CompletionAppBlockingResponse

class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[CompletionAppBlockingResponse]):
    @classmethod
    def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse):  # type: ignore[override]
    def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse):
        """
        Convert blocking full response.
        :param blocking_response: blocking response
        :return:
        """
        response = {
        response: dict[str, Any] = {
            "event": "message",
            "task_id": blocking_response.task_id,
            "id": blocking_response.data.id,
@@ -36,7 +36,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
        return response

    @classmethod
    def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse):  # type: ignore[override]
    def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse):
        """
        Convert blocking simple response.
        :param blocking_response: blocking response
@@ -69,7 +69,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
                yield "ping"
                continue

            response_chunk = {
            response_chunk: dict[str, JsonValue] = {
                "event": sub_stream_response.event.value,
                "message_id": chunk.message_id,
                "created_at": chunk.created_at,
@@ -99,7 +99,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
                yield "ping"
                continue

            response_chunk = {
            response_chunk: dict[str, JsonValue] = {
                "event": sub_stream_response.event.value,
                "message_id": chunk.message_id,
                "created_at": chunk.created_at,
@@ -1,6 +1,7 @@
from collections.abc import Callable, Generator, Mapping
from collections.abc import Callable, Generator, Iterable, Mapping

from core.app.apps.streaming_utils import stream_topic_events
from core.app.entities.task_entities import StreamEvent
from extensions.ext_redis import get_pubsub_broadcast_channel
from libs.broadcast_channel.channel import Topic
from models.model import AppMode
@@ -26,6 +27,7 @@ class MessageGenerator:
        idle_timeout=300,
        ping_interval: float = 10.0,
        on_subscribe: Callable[[], None] | None = None,
        terminal_events: Iterable[str | StreamEvent] | None = None,
    ) -> Generator[Mapping | str, None, None]:
        topic = cls.get_response_topic(app_mode, workflow_run_id)
        return stream_topic_events(
@@ -33,4 +35,5 @@ class MessageGenerator:
            idle_timeout=idle_timeout,
            ping_interval=ping_interval,
            on_subscribe=on_subscribe,
            terminal_events=terminal_events,
        )
@@ -13,11 +13,9 @@ from core.app.entities.task_entities import (
)


class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
    _blocking_response_type = WorkflowAppBlockingResponse

class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter[WorkflowAppBlockingResponse]):
    @classmethod
    def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]:  # type: ignore[override]
    def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, object]:
        """
        Convert blocking full response.
        :param blocking_response: blocking response
@@ -26,7 +24,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
        return dict(blocking_response.model_dump())

    @classmethod
    def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]:  # type: ignore[override]
    def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, object]:
        """
        Convert blocking simple response.
        :param blocking_response: blocking response
@@ -27,7 +27,11 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.app.entities.task_entities import (
    WorkflowAppBlockingResponse,
    WorkflowAppPausedBlockingResponse,
    WorkflowAppStreamResponse,
)
from core.datasource.entities.datasource_entities import (
    DatasourceProviderType,
    OnlineDriveBrowseFilesRequest,
@@ -627,7 +631,11 @@ class PipelineGenerator(BaseAppGenerator):
        user: Account | EndUser,
        draft_var_saver_factory: DraftVariableSaverFactory,
        stream: bool = False,
    ) -> WorkflowAppBlockingResponse | Generator[WorkflowAppStreamResponse, None, None]:
    ) -> (
        WorkflowAppBlockingResponse
        | WorkflowAppPausedBlockingResponse
        | Generator[WorkflowAppStreamResponse, None, None]
    ):
        """
        Handle response.
        :param application_generate_entity: application generate entity
@@ -59,7 +59,7 @@ def stream_topic_events(


def _normalize_terminal_events(terminal_events: Iterable[str | StreamEvent] | None) -> set[str]:
    if not terminal_events:
    if terminal_events is None:
        return {StreamEvent.WORKFLOW_FINISHED.value, StreamEvent.WORKFLOW_PAUSED.value}
    values: set[str] = set()
    for item in terminal_events:
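A hedged illustration of why the guard changed from `if not terminal_events` to `if terminal_events is None`: an explicitly passed empty list should now mean "nothing is terminal, keep the stream open", while only an omitted argument falls back to the defaults.

    DEFAULTS = {"workflow_finished", "workflow_paused"}

    def normalize(terminal_events: list[str] | None) -> set[str]:
        if terminal_events is None:   # only the "not provided" case uses the defaults
            return set(DEFAULTS)
        return set(terminal_events)   # [] stays empty, so pauses no longer close the stream

    print(normalize(None))  # {'workflow_finished', 'workflow_paused'}
    print(normalize([]))    # set()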
@@ -25,7 +25,11 @@ from core.app.apps.workflow.app_runner import WorkflowAppRunner
from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.app.entities.task_entities import (
    WorkflowAppBlockingResponse,
    WorkflowAppPausedBlockingResponse,
    WorkflowAppStreamResponse,
)
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
from core.db.session_factory import session_factory
from core.helper.trace_id_helper import extract_external_trace_id_from_args
@@ -612,7 +616,11 @@ class WorkflowAppGenerator(BaseAppGenerator):
        user: Account | EndUser,
        draft_var_saver_factory: DraftVariableSaverFactory,
        stream: bool = False,
    ) -> WorkflowAppBlockingResponse | Generator[WorkflowAppStreamResponse, None, None]:
    ) -> (
        WorkflowAppBlockingResponse
        | WorkflowAppPausedBlockingResponse
        | Generator[WorkflowAppStreamResponse, None, None]
    ):
        """
        Handle response.
        :param application_generate_entity: application generate entity
@@ -9,24 +9,29 @@ from core.app.entities.task_entities import (
    NodeStartStreamResponse,
    PingStreamResponse,
    WorkflowAppBlockingResponse,
    WorkflowAppPausedBlockingResponse,
    WorkflowAppStreamResponse,
)


class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
    _blocking_response_type = WorkflowAppBlockingResponse

class WorkflowAppGenerateResponseConverter(
    AppGenerateResponseConverter[WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse]
):
    @classmethod
    def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse):  # type: ignore[override]
    def convert_blocking_full_response(
        cls, blocking_response: WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse
    ) -> dict[str, Any]:
        """
        Convert blocking full response.
        :param blocking_response: blocking response
        :return:
        """
        return blocking_response.model_dump()
        return dict(blocking_response.model_dump())

    @classmethod
    def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse):  # type: ignore[override]
    def convert_blocking_simple_response(
        cls, blocking_response: WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse
    ) -> dict[str, Any]:
        """
        Convert blocking simple response.
        :param blocking_response: blocking response
@@ -42,12 +42,15 @@ from core.app.entities.queue_entities import (
)
from core.app.entities.task_entities import (
    ErrorStreamResponse,
    HumanInputRequiredPauseReasonPayload,
    HumanInputRequiredResponse,
    MessageAudioEndStreamResponse,
    MessageAudioStreamResponse,
    PingStreamResponse,
    StreamResponse,
    TextChunkStreamResponse,
    WorkflowAppBlockingResponse,
    WorkflowAppPausedBlockingResponse,
    WorkflowAppStreamResponse,
    WorkflowFinishStreamResponse,
    WorkflowPauseStreamResponse,
@@ -118,7 +121,11 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
        )
        self._graph_runtime_state: GraphRuntimeState | None = self._base_task_pipeline.queue_manager.graph_runtime_state

    def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
    def process(
        self,
    ) -> Union[
        WorkflowAppBlockingResponse, WorkflowAppPausedBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]
    ]:
        """
        Process generate task pipeline.
        :return:
@@ -129,19 +136,24 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
        else:
            return self._to_blocking_response(generator)

    def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> WorkflowAppBlockingResponse:
    def _to_blocking_response(
        self, generator: Generator[StreamResponse, None, None]
    ) -> Union[WorkflowAppBlockingResponse, WorkflowAppPausedBlockingResponse]:
        """
        To blocking response.
        :return:
        """
        human_input_responses: list[HumanInputRequiredResponse] = []
        for stream_response in generator:
            if isinstance(stream_response, ErrorStreamResponse):
                raise stream_response.err
            elif isinstance(stream_response, HumanInputRequiredResponse):
                human_input_responses.append(stream_response)
            elif isinstance(stream_response, WorkflowPauseStreamResponse):
                response = WorkflowAppBlockingResponse(
                return WorkflowAppPausedBlockingResponse(
                    task_id=self._application_generate_entity.task_id,
                    workflow_run_id=stream_response.data.workflow_run_id,
                    data=WorkflowAppBlockingResponse.Data(
                    data=WorkflowAppPausedBlockingResponse.Data(
                        id=stream_response.data.workflow_run_id,
                        workflow_id=self._workflow.id,
                        status=stream_response.data.status,
@@ -152,12 +164,13 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
                        total_steps=stream_response.data.total_steps,
                        created_at=stream_response.data.created_at,
                        finished_at=None,
                        paused_nodes=stream_response.data.paused_nodes,
                        reasons=stream_response.data.reasons,
                    ),
                )

                return response
            elif isinstance(stream_response, WorkflowFinishStreamResponse):
                response = WorkflowAppBlockingResponse(
                return WorkflowAppBlockingResponse(
                    task_id=self._application_generate_entity.task_id,
                    workflow_run_id=stream_response.data.id,
                    data=WorkflowAppBlockingResponse.Data(
@@ -174,12 +187,44 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
                    ),
                )

                return response
            else:
                continue

        if human_input_responses:
            return self._build_paused_blocking_response_from_human_input(human_input_responses)

        raise ValueError("queue listening stopped unexpectedly.")

    def _build_paused_blocking_response_from_human_input(
        self, human_input_responses: list[HumanInputRequiredResponse]
    ) -> WorkflowAppPausedBlockingResponse:
        runtime_state = self._resolve_graph_runtime_state()
        paused_nodes = list(dict.fromkeys(response.data.node_id for response in human_input_responses))
        created_at = int(runtime_state.start_at)
        reasons = [
            HumanInputRequiredPauseReasonPayload.from_response_data(response.data).model_dump(mode="json")
            for response in human_input_responses
        ]

        return WorkflowAppPausedBlockingResponse(
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run_id=human_input_responses[-1].workflow_run_id,
|
||||
data=WorkflowAppPausedBlockingResponse.Data(
|
||||
id=human_input_responses[-1].workflow_run_id,
|
||||
workflow_id=self._workflow.id,
|
||||
status=WorkflowExecutionStatus.PAUSED,
|
||||
outputs={},
|
||||
error=None,
|
||||
elapsed_time=time.perf_counter() - self._base_task_pipeline.start_at,
|
||||
total_tokens=runtime_state.total_tokens,
|
||||
total_steps=runtime_state.node_run_steps,
|
||||
created_at=created_at,
|
||||
finished_at=None,
|
||||
paused_nodes=paused_nodes,
|
||||
reasons=reasons,
|
||||
),
|
||||
)
|
||||
|
||||
def _to_stream_response(
|
||||
self, generator: Generator[StreamResponse, None, None]
|
||||
) -> Generator[WorkflowAppStreamResponse, None, None]:
|
||||
|
||||
@ -1,12 +1,13 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field, JsonValue
|
||||
|
||||
from core.app.entities.agent_strategy import AgentStrategyInfo
|
||||
from core.rag.entities import RetrievalSourceMetadata
|
||||
from graphon.entities import WorkflowStartReason
|
||||
from graphon.entities.pause_reason import PauseReasonType
|
||||
from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from graphon.nodes.human_input.entities import FormInput, UserAction
|
||||
@ -295,6 +296,40 @@ class HumanInputRequiredResponse(StreamResponse):
|
||||
data: Data
|
||||
|
||||
|
||||
class HumanInputRequiredPauseReasonPayload(BaseModel):
|
||||
"""
|
||||
Public pause-reason payload used by blocking responses when only
|
||||
``human_input_required`` events are available.
|
||||
"""
|
||||
|
||||
TYPE: Literal[PauseReasonType.HUMAN_INPUT_REQUIRED] = PauseReasonType.HUMAN_INPUT_REQUIRED
|
||||
form_id: str
|
||||
node_id: str
|
||||
node_title: str
|
||||
form_content: str
|
||||
inputs: Sequence[FormInput] = Field(default_factory=list)
|
||||
actions: Sequence[UserAction] = Field(default_factory=list)
|
||||
display_in_ui: bool = False
|
||||
form_token: str | None = None
|
||||
resolved_default_values: Mapping[str, Any] = Field(default_factory=dict)
|
||||
expiration_time: int
|
||||
|
||||
@classmethod
|
||||
def from_response_data(cls, data: HumanInputRequiredResponse.Data) -> "HumanInputRequiredPauseReasonPayload":
|
||||
return cls(
|
||||
form_id=data.form_id,
|
||||
node_id=data.node_id,
|
||||
node_title=data.node_title,
|
||||
form_content=data.form_content,
|
||||
inputs=data.inputs,
|
||||
actions=data.actions,
|
||||
display_in_ui=data.display_in_ui,
|
||||
form_token=data.form_token,
|
||||
resolved_default_values=data.resolved_default_values,
|
||||
expiration_time=data.expiration_time,
|
||||
)
|
||||
|
||||
|
||||
class HumanInputFormFilledResponse(StreamResponse):
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
@ -355,7 +390,7 @@ class NodeStartStreamResponse(StreamResponse):
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
def to_ignore_detail_dict(self):
|
||||
def to_ignore_detail_dict(self) -> dict[str, JsonValue]:
|
||||
return {
|
||||
"event": self.event.value,
|
||||
"task_id": self.task_id,
|
||||
@ -412,7 +447,7 @@ class NodeFinishStreamResponse(StreamResponse):
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
def to_ignore_detail_dict(self):
|
||||
def to_ignore_detail_dict(self) -> dict[str, JsonValue]:
|
||||
return {
|
||||
"event": self.event.value,
|
||||
"task_id": self.task_id,
|
||||
@ -774,6 +809,34 @@ class ChatbotAppBlockingResponse(AppBlockingResponse):
|
||||
data: Data
|
||||
|
||||
|
||||
class AdvancedChatPausedBlockingResponse(AppBlockingResponse):
|
||||
"""
|
||||
ChatbotAppPausedBlockingResponse entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
"""
|
||||
|
||||
id: str
|
||||
mode: str
|
||||
conversation_id: str
|
||||
message_id: str
|
||||
workflow_run_id: str
|
||||
answer: str
|
||||
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||
created_at: int
|
||||
paused_nodes: Sequence[str] = Field(default_factory=list)
|
||||
reasons: Sequence[Mapping[str, Any]] = Field(default_factory=list[Mapping[str, Any]])
|
||||
status: WorkflowExecutionStatus
|
||||
elapsed_time: float
|
||||
total_tokens: int
|
||||
total_steps: int
|
||||
|
||||
data: Data
|
||||
|
||||
|
||||
class CompletionAppBlockingResponse(AppBlockingResponse):
|
||||
"""
|
||||
CompletionAppBlockingResponse entity
|
||||
@ -819,6 +882,33 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
|
||||
data: Data
|
||||
|
||||
|
||||
class WorkflowAppPausedBlockingResponse(AppBlockingResponse):
|
||||
"""
|
||||
WorkflowAppPausedBlockingResponse entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
"""
|
||||
|
||||
id: str
|
||||
workflow_id: str
|
||||
status: WorkflowExecutionStatus
|
||||
outputs: Mapping[str, Any] | None = None
|
||||
error: str | None = None
|
||||
elapsed_time: float
|
||||
total_tokens: int
|
||||
total_steps: int
|
||||
created_at: int
|
||||
finished_at: int | None
|
||||
paused_nodes: Sequence[str] = Field(default_factory=list)
|
||||
reasons: Sequence[Mapping[str, Any]] = Field(default_factory=list)
|
||||
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
|
||||
class AgentLogStreamResponse(StreamResponse):
|
||||
"""
|
||||
AgentLogStreamResponse entity
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
from core.app.entities.app_invoke_entities import DifyRunContext, ModelConfigWithCredentialsEntity
|
||||
@ -14,8 +15,21 @@ from graphon.nodes.llm.protocols import CredentialsProvider
|
||||
|
||||
|
||||
class DifyCredentialsProvider:
|
||||
"""Resolves and returns LLM credentials for a given provider and model.
|
||||
|
||||
Fetched credentials are stored in :attr:`credentials_cache` and reused for
|
||||
subsequent ``fetch`` calls for the same ``(provider_name, model_name)``.
|
||||
Because of that cache, a single instance can return stale credentials after
|
||||
the tenant or provider configuration changes (e.g. API key rotation).
|
||||
|
||||
Do **not** keep one instance for the lifetime of a process or across
|
||||
unrelated invocations. Create a new provider per request, workflow run, or
|
||||
other bounded scope where up-to-date credentials matter.
|
||||
"""
|
||||
|
||||
tenant_id: str
|
||||
provider_manager: ProviderManager
|
||||
credentials_cache: dict[tuple[str, str], dict[str, Any]]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -30,8 +44,12 @@ class DifyCredentialsProvider:
|
||||
user_id=run_context.user_id,
|
||||
)
|
||||
self.provider_manager = provider_manager
|
||||
self.credentials_cache = {}
|
||||
|
||||
def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
|
||||
if (provider_name, model_name) in self.credentials_cache:
|
||||
return deepcopy(self.credentials_cache[(provider_name, model_name)])
|
||||
|
||||
provider_configurations = self.provider_manager.get_configurations(self.tenant_id)
|
||||
provider_configuration = provider_configurations.get(provider_name)
|
||||
if not provider_configuration:
|
||||
@ -46,6 +64,7 @@ class DifyCredentialsProvider:
|
||||
if credentials is None:
|
||||
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
|
||||
|
||||
self.credentials_cache[(provider_name, model_name)] = deepcopy(credentials)
|
||||
return credentials
|
||||
|
||||
|
||||
@ -65,7 +84,8 @@ class DifyModelFactory:
|
||||
provider_manager=create_plugin_provider_manager(
|
||||
tenant_id=run_context.tenant_id,
|
||||
user_id=run_context.user_id,
|
||||
)
|
||||
),
|
||||
enable_credentials_cache=True,
|
||||
)
|
||||
self.model_manager = model_manager
|
||||
|
||||
@ -84,7 +104,7 @@ def build_dify_model_access(run_context: DifyRunContext) -> tuple[CredentialsPro
|
||||
tenant_id=run_context.tenant_id,
|
||||
user_id=run_context.user_id,
|
||||
)
|
||||
model_manager = ModelManager(provider_manager=provider_manager)
|
||||
model_manager = ModelManager(provider_manager=provider_manager, enable_credentials_cache=True)
|
||||
|
||||
return (
|
||||
DifyCredentialsProvider(run_context=run_context, provider_manager=provider_manager),
|
||||
|
||||
41
api/core/helper/creators.py
Normal file
41
api/core/helper/creators.py
Normal file
@ -0,0 +1,41 @@
|
||||
"""
|
||||
Helper module for Creators Platform integration.
|
||||
|
||||
Provides functionality to upload DSL files to the Creators Platform
|
||||
and generate redirect URLs with OAuth authorization codes.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import httpx
|
||||
from yarl import URL
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
creators_platform_api_url = URL(str(dify_config.CREATORS_PLATFORM_API_URL))
|
||||
|
||||
|
||||
def upload_dsl(dsl_file_bytes: bytes, filename: str = "template.yaml") -> str:
|
||||
url = str(creators_platform_api_url / "api/v1/templates/anonymous-upload")
|
||||
response = httpx.post(url, files={"file": (filename, dsl_file_bytes)}, timeout=30)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
claim_code = data.get("data", {}).get("claim_code")
|
||||
if not claim_code:
|
||||
raise ValueError("Creators Platform did not return a valid claim_code")
|
||||
return claim_code
|
||||
|
||||
|
||||
def get_redirect_url(user_account_id: str, claim_code: str) -> str:
|
||||
base_url = str(dify_config.CREATORS_PLATFORM_API_URL).rstrip("/")
|
||||
params: dict[str, str] = {"dsl_claim_code": claim_code}
|
||||
client_id = str(dify_config.CREATORS_PLATFORM_OAUTH_CLIENT_ID or "")
|
||||
if client_id:
|
||||
from services.oauth_server import OAuthServerService
|
||||
|
||||
oauth_code = OAuthServerService.sign_oauth_authorization_code(client_id, user_account_id)
|
||||
params["oauth_code"] = oauth_code
|
||||
return f"{base_url}?{urlencode(params)}"
|
||||
@ -13,8 +13,6 @@ from core.llm_generator.output_parser.rule_config_generator import RuleConfigGen
|
||||
from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
|
||||
from core.llm_generator.prompts import (
|
||||
CONVERSATION_TITLE_PROMPT,
|
||||
DEFAULT_SUGGESTED_QUESTIONS_MAX_TOKENS,
|
||||
DEFAULT_SUGGESTED_QUESTIONS_TEMPERATURE,
|
||||
GENERATOR_QA_PROMPT,
|
||||
JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE,
|
||||
LLM_MODIFY_CODE_SYSTEM,
|
||||
@ -217,8 +215,8 @@ class LLMGenerator:
|
||||
else:
|
||||
# Default-model generation keeps the built-in suggested-questions tuning.
|
||||
model_parameters = {
|
||||
"max_tokens": DEFAULT_SUGGESTED_QUESTIONS_MAX_TOKENS,
|
||||
"temperature": DEFAULT_SUGGESTED_QUESTIONS_TEMPERATURE,
|
||||
"max_tokens": 2560,
|
||||
"temperature": 0.0,
|
||||
}
|
||||
stop = []
|
||||
|
||||
|
||||
@ -10,7 +10,14 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class SuggestedQuestionsAfterAnswerOutputParser:
|
||||
def __init__(self, instruction_prompt: str | None = None) -> None:
|
||||
self._instruction_prompt = instruction_prompt or DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
|
||||
self._instruction_prompt = self._build_instruction_prompt(instruction_prompt)
|
||||
|
||||
@staticmethod
|
||||
def _build_instruction_prompt(instruction_prompt: str | None) -> str:
|
||||
if not instruction_prompt or not instruction_prompt.strip():
|
||||
return DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
|
||||
|
||||
return f'{instruction_prompt}\nYou must output a JSON array like ["question1", "question2", "question3"].'
|
||||
|
||||
def get_format_instructions(self) -> str:
|
||||
return self._instruction_prompt
|
||||
|
||||
@ -104,9 +104,6 @@ DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
|
||||
'["question1","question2","question3"]\n'
|
||||
)
|
||||
|
||||
DEFAULT_SUGGESTED_QUESTIONS_MAX_TOKENS = 256
|
||||
DEFAULT_SUGGESTED_QUESTIONS_TEMPERATURE = 0.0
|
||||
|
||||
GENERATOR_QA_PROMPT = (
|
||||
"<Task> The user will send a long text. Generate a Question and Answer pairs only using the knowledge"
|
||||
" in the long text. Please think step by step."
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import logging
|
||||
from collections.abc import Callable, Generator, Iterable, Mapping, Sequence
|
||||
from copy import deepcopy
|
||||
from typing import IO, Any, Literal, Optional, ParamSpec, TypeVar, Union, cast, overload
|
||||
|
||||
from configs import dify_config
|
||||
@ -36,11 +37,13 @@ class ModelInstance:
|
||||
Model instance class.
|
||||
"""
|
||||
|
||||
def __init__(self, provider_model_bundle: ProviderModelBundle, model: str):
|
||||
def __init__(self, provider_model_bundle: ProviderModelBundle, model: str, credentials: dict | None = None) -> None:
|
||||
self.provider_model_bundle = provider_model_bundle
|
||||
self.model_name = model
|
||||
self.provider = provider_model_bundle.configuration.provider.provider
|
||||
self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
|
||||
if credentials is None:
|
||||
credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
|
||||
self.credentials = credentials
|
||||
# Runtime LLM invocation fields.
|
||||
self.parameters: Mapping[str, Any] = {}
|
||||
self.stop: Sequence[str] = ()
|
||||
@ -434,8 +437,30 @@ class ModelInstance:
|
||||
|
||||
|
||||
class ModelManager:
|
||||
def __init__(self, provider_manager: ProviderManager):
|
||||
"""Resolves :class:`ModelInstance` objects for a tenant and provider.
|
||||
|
||||
When ``enable_credentials_cache`` is ``True``, resolved credentials for each
|
||||
``(tenant_id, provider, model_type, model)`` are stored in
|
||||
``_credentials_cache`` and reused. That can return **stale** credentials after
|
||||
API keys or provider settings change, so a manager constructed with
|
||||
``enable_credentials_cache=True`` should not be kept for the lifetime of a
|
||||
process or shared across unrelated work. Prefer a new manager per request,
|
||||
workflow run, or similar bounded scope.
|
||||
|
||||
The default is ``enable_credentials_cache=False``; in that mode the internal
|
||||
credential cache is not populated, and each ``get_model_instance`` call
|
||||
loads credentials from the current provider configuration.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
provider_manager: ProviderManager,
|
||||
*,
|
||||
enable_credentials_cache: bool = False,
|
||||
) -> None:
|
||||
self._provider_manager = provider_manager
|
||||
self._credentials_cache: dict[tuple[str, str, str, str], Any] = {}
|
||||
self._enable_credentials_cache = enable_credentials_cache
|
||||
|
||||
@classmethod
|
||||
def for_tenant(cls, tenant_id: str, user_id: str | None = None) -> "ModelManager":
|
||||
@ -463,8 +488,19 @@ class ModelManager:
|
||||
tenant_id=tenant_id, provider=provider, model_type=model_type
|
||||
)
|
||||
|
||||
model_instance = ModelInstance(provider_model_bundle, model)
|
||||
return model_instance
|
||||
cred_cache_key = (tenant_id, provider, model_type.value, model)
|
||||
|
||||
if cred_cache_key in self._credentials_cache:
|
||||
return ModelInstance(
|
||||
provider_model_bundle,
|
||||
model,
|
||||
deepcopy(self._credentials_cache[cred_cache_key]),
|
||||
)
|
||||
|
||||
ret = ModelInstance(provider_model_bundle, model)
|
||||
if self._enable_credentials_cache:
|
||||
self._credentials_cache[cred_cache_key] = deepcopy(ret.credentials)
|
||||
return ret
|
||||
|
||||
def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str | None, str | None]:
|
||||
"""
|
||||
|
||||
@ -156,7 +156,8 @@ class Jieba(BaseKeyword):
|
||||
if dataset_keyword_table:
|
||||
keyword_table_dict = dataset_keyword_table.keyword_table_dict
|
||||
if keyword_table_dict:
|
||||
return dict(keyword_table_dict["__data__"]["table"])
|
||||
data: Any = keyword_table_dict["__data__"]
|
||||
return dict(data["table"])
|
||||
else:
|
||||
keyword_data_source_type = dify_config.KEYWORD_DATA_SOURCE_TYPE
|
||||
dataset_keyword_table = DatasetKeywordTable(
|
||||
|
||||
@ -109,7 +109,7 @@ class JiebaKeywordTableHandler:
|
||||
"""Extract keywords with JIEBA tfidf."""
|
||||
keywords = self._tfidf.extract_tags(
|
||||
sentence=text,
|
||||
topK=max_keywords_per_chunk,
|
||||
topK=max_keywords_per_chunk or 10,
|
||||
)
|
||||
# jieba.analyse.extract_tags returns an untyped list when withFlag is False by default.
|
||||
keywords = cast(list[str], keywords)
|
||||
|
||||
@ -31,7 +31,7 @@ class FunctionCallMultiDatasetRouter:
|
||||
result: LLMResult = model_instance.invoke_llm( # pyright: ignore[reportCallIssue, reportArgumentType]
|
||||
prompt_messages=prompt_messages,
|
||||
tools=dataset_tools,
|
||||
stream=False,
|
||||
stream=False, # pyright: ignore[reportArgumentType]
|
||||
model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500},
|
||||
)
|
||||
usage = result.usage or LLMUsage.empty_usage()
|
||||
|
||||
@ -14,23 +14,23 @@ from configs import dify_config
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OAuthEncryptionError(Exception):
|
||||
"""OAuth encryption/decryption specific error"""
|
||||
class EncryptionError(Exception):
|
||||
"""Encryption/decryption specific error"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class SystemOAuthEncrypter:
|
||||
class SystemEncrypter:
|
||||
"""
|
||||
A simple OAuth parameters encrypter using AES-CBC encryption.
|
||||
A simple parameters encrypter using AES-CBC encryption.
|
||||
|
||||
This class provides methods to encrypt and decrypt OAuth parameters
|
||||
This class provides methods to encrypt and decrypt parameters
|
||||
using AES-CBC mode with a key derived from the application's SECRET_KEY.
|
||||
"""
|
||||
|
||||
def __init__(self, secret_key: str | None = None):
|
||||
"""
|
||||
Initialize the OAuth encrypter.
|
||||
Initialize the encrypter.
|
||||
|
||||
Args:
|
||||
secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
|
||||
@ -43,19 +43,19 @@ class SystemOAuthEncrypter:
|
||||
# Generate a fixed 256-bit key using SHA-256
|
||||
self.key = hashlib.sha256(secret_key.encode()).digest()
|
||||
|
||||
def encrypt_oauth_params(self, oauth_params: Mapping[str, Any]) -> str:
|
||||
def encrypt_params(self, params: Mapping[str, Any]) -> str:
|
||||
"""
|
||||
Encrypt OAuth parameters.
|
||||
Encrypt parameters.
|
||||
|
||||
Args:
|
||||
oauth_params: OAuth parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"}
|
||||
params: Parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"}
|
||||
|
||||
Returns:
|
||||
Base64-encoded encrypted string
|
||||
|
||||
Raises:
|
||||
OAuthEncryptionError: If encryption fails
|
||||
ValueError: If oauth_params is invalid
|
||||
EncryptionError: If encryption fails
|
||||
ValueError: If params is invalid
|
||||
"""
|
||||
|
||||
try:
|
||||
@ -66,7 +66,7 @@ class SystemOAuthEncrypter:
|
||||
cipher = AES.new(self.key, AES.MODE_CBC, iv)
|
||||
|
||||
# Encrypt data
|
||||
padded_data = pad(TypeAdapter(dict).dump_json(dict(oauth_params)), AES.block_size)
|
||||
padded_data = pad(TypeAdapter(dict).dump_json(dict(params)), AES.block_size)
|
||||
encrypted_data = cipher.encrypt(padded_data)
|
||||
|
||||
# Combine IV and encrypted data
|
||||
@ -76,20 +76,20 @@ class SystemOAuthEncrypter:
|
||||
return base64.b64encode(combined).decode()
|
||||
|
||||
except Exception as e:
|
||||
raise OAuthEncryptionError(f"Encryption failed: {str(e)}") from e
|
||||
raise EncryptionError(f"Encryption failed: {str(e)}") from e
|
||||
|
||||
def decrypt_oauth_params(self, encrypted_data: str) -> Mapping[str, Any]:
|
||||
def decrypt_params(self, encrypted_data: str) -> Mapping[str, Any]:
|
||||
"""
|
||||
Decrypt OAuth parameters.
|
||||
Decrypt parameters.
|
||||
|
||||
Args:
|
||||
encrypted_data: Base64-encoded encrypted string
|
||||
|
||||
Returns:
|
||||
Decrypted OAuth parameters dictionary
|
||||
Decrypted parameters dictionary
|
||||
|
||||
Raises:
|
||||
OAuthEncryptionError: If decryption fails
|
||||
EncryptionError: If decryption fails
|
||||
ValueError: If encrypted_data is invalid
|
||||
"""
|
||||
if not isinstance(encrypted_data, str):
|
||||
@ -118,70 +118,70 @@ class SystemOAuthEncrypter:
|
||||
unpadded_data = unpad(decrypted_data, AES.block_size)
|
||||
|
||||
# Parse JSON
|
||||
oauth_params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data)
|
||||
params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data)
|
||||
|
||||
if not isinstance(oauth_params, dict):
|
||||
if not isinstance(params, dict):
|
||||
raise ValueError("Decrypted data is not a valid dictionary")
|
||||
|
||||
return oauth_params
|
||||
return params
|
||||
|
||||
except Exception as e:
|
||||
raise OAuthEncryptionError(f"Decryption failed: {str(e)}") from e
|
||||
raise EncryptionError(f"Decryption failed: {str(e)}") from e
|
||||
|
||||
|
||||
# Factory function for creating encrypter instances
|
||||
def create_system_oauth_encrypter(secret_key: str | None = None) -> SystemOAuthEncrypter:
|
||||
def create_system_encrypter(secret_key: str | None = None) -> SystemEncrypter:
|
||||
"""
|
||||
Create an OAuth encrypter instance.
|
||||
Create an encrypter instance.
|
||||
|
||||
Args:
|
||||
secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
|
||||
|
||||
Returns:
|
||||
SystemOAuthEncrypter instance
|
||||
SystemEncrypter instance
|
||||
"""
|
||||
return SystemOAuthEncrypter(secret_key=secret_key)
|
||||
return SystemEncrypter(secret_key=secret_key)
|
||||
|
||||
|
||||
# Global encrypter instance (for backward compatibility)
|
||||
_oauth_encrypter: SystemOAuthEncrypter | None = None
|
||||
_encrypter: SystemEncrypter | None = None
|
||||
|
||||
|
||||
def get_system_oauth_encrypter() -> SystemOAuthEncrypter:
|
||||
def get_system_encrypter() -> SystemEncrypter:
|
||||
"""
|
||||
Get the global OAuth encrypter instance.
|
||||
Get the global encrypter instance.
|
||||
|
||||
Returns:
|
||||
SystemOAuthEncrypter instance
|
||||
SystemEncrypter instance
|
||||
"""
|
||||
global _oauth_encrypter
|
||||
if _oauth_encrypter is None:
|
||||
_oauth_encrypter = SystemOAuthEncrypter()
|
||||
return _oauth_encrypter
|
||||
global _encrypter
|
||||
if _encrypter is None:
|
||||
_encrypter = SystemEncrypter()
|
||||
return _encrypter
|
||||
|
||||
|
||||
# Convenience functions for backward compatibility
|
||||
def encrypt_system_oauth_params(oauth_params: Mapping[str, Any]) -> str:
|
||||
def encrypt_system_params(params: Mapping[str, Any]) -> str:
|
||||
"""
|
||||
Encrypt OAuth parameters using the global encrypter.
|
||||
Encrypt parameters using the global encrypter.
|
||||
|
||||
Args:
|
||||
oauth_params: OAuth parameters dictionary
|
||||
params: Parameters dictionary
|
||||
|
||||
Returns:
|
||||
Base64-encoded encrypted string
|
||||
"""
|
||||
return get_system_oauth_encrypter().encrypt_oauth_params(oauth_params)
|
||||
return get_system_encrypter().encrypt_params(params)
|
||||
|
||||
|
||||
def decrypt_system_oauth_params(encrypted_data: str) -> Mapping[str, Any]:
|
||||
def decrypt_system_params(encrypted_data: str) -> Mapping[str, Any]:
|
||||
"""
|
||||
Decrypt OAuth parameters using the global encrypter.
|
||||
Decrypt parameters using the global encrypter.
|
||||
|
||||
Args:
|
||||
encrypted_data: Base64-encoded encrypted string
|
||||
|
||||
Returns:
|
||||
Decrypted OAuth parameters dictionary
|
||||
Decrypted parameters dictionary
|
||||
"""
|
||||
return get_system_oauth_encrypter().decrypt_oauth_params(encrypted_data)
|
||||
return get_system_encrypter().decrypt_params(encrypted_data)
|
||||
@ -12,20 +12,16 @@ from collections.abc import Sequence
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.workflow.human_input_policy import HumanInputSurface, get_preferred_form_token
|
||||
from extensions.ext_database import db
|
||||
from models.human_input import HumanInputFormRecipient, RecipientType
|
||||
|
||||
_FORM_TOKEN_PRIORITY = {
|
||||
RecipientType.BACKSTAGE: 0,
|
||||
RecipientType.CONSOLE: 1,
|
||||
RecipientType.STANDALONE_WEB_APP: 2,
|
||||
}
|
||||
|
||||
|
||||
def load_form_tokens_by_form_id(
|
||||
form_ids: Sequence[str],
|
||||
*,
|
||||
session: Session | None = None,
|
||||
surface: HumanInputSurface | None = None,
|
||||
) -> dict[str, str]:
|
||||
"""Load the preferred access token for each human input form."""
|
||||
unique_form_ids = list(dict.fromkeys(form_ids))
|
||||
@ -33,23 +29,43 @@ def load_form_tokens_by_form_id(
|
||||
return {}
|
||||
|
||||
if session is not None:
|
||||
return _load_form_tokens_by_form_id(session, unique_form_ids)
|
||||
return _load_form_tokens_by_form_id(session, unique_form_ids, surface=surface)
|
||||
|
||||
with Session(bind=db.engine, expire_on_commit=False) as new_session:
|
||||
return _load_form_tokens_by_form_id(new_session, unique_form_ids)
|
||||
return _load_form_tokens_by_form_id(new_session, unique_form_ids, surface=surface)
|
||||
|
||||
|
||||
def _load_form_tokens_by_form_id(session: Session, form_ids: Sequence[str]) -> dict[str, str]:
|
||||
tokens_by_form_id: dict[str, tuple[int, str]] = {}
|
||||
def _load_form_tokens_by_form_id(
|
||||
session: Session,
|
||||
form_ids: Sequence[str],
|
||||
*,
|
||||
surface: HumanInputSurface | None = None,
|
||||
) -> dict[str, str]:
|
||||
recipients_by_form_id: dict[str, list[tuple[RecipientType, str]]] = {}
|
||||
stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids))
|
||||
for recipient in session.scalars(stmt):
|
||||
priority = _FORM_TOKEN_PRIORITY.get(recipient.recipient_type)
|
||||
if priority is None or not recipient.access_token:
|
||||
if not recipient.access_token:
|
||||
continue
|
||||
recipients_by_form_id.setdefault(recipient.form_id, []).append(
|
||||
(recipient.recipient_type, recipient.access_token)
|
||||
)
|
||||
|
||||
candidate = (priority, recipient.access_token)
|
||||
current = tokens_by_form_id.get(recipient.form_id)
|
||||
if current is None or candidate[0] < current[0]:
|
||||
tokens_by_form_id[recipient.form_id] = candidate
|
||||
tokens_by_form_id: dict[str, str] = {}
|
||||
for form_id, recipients in recipients_by_form_id.items():
|
||||
token = _get_surface_form_token(recipients, surface=surface)
|
||||
if token is not None:
|
||||
tokens_by_form_id[form_id] = token
|
||||
return tokens_by_form_id
|
||||
|
||||
return {form_id: token for form_id, (_, token) in tokens_by_form_id.items()}
|
||||
|
||||
def _get_surface_form_token(
|
||||
recipients: Sequence[tuple[RecipientType, str]],
|
||||
*,
|
||||
surface: HumanInputSurface | None,
|
||||
) -> str | None:
|
||||
if surface == HumanInputSurface.SERVICE_API:
|
||||
for recipient_type, token in recipients:
|
||||
if recipient_type == RecipientType.STANDALONE_WEB_APP and token:
|
||||
return token
|
||||
|
||||
return get_preferred_form_token(recipients)
|
||||
|
||||
73
api/core/workflow/human_input_policy.py
Normal file
73
api/core/workflow/human_input_policy.py
Normal file
@ -0,0 +1,73 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping, Sequence
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from graphon.entities.pause_reason import PauseReasonType
|
||||
from models.human_input import RecipientType
|
||||
|
||||
|
||||
class HumanInputSurface(StrEnum):
|
||||
SERVICE_API = "service_api"
|
||||
CONSOLE = "console"
|
||||
|
||||
|
||||
# Service API is intentionally narrower than other surfaces: app-token callers
|
||||
# should only be able to act on end-user web forms, not internal console flows.
|
||||
_ALLOWED_RECIPIENT_TYPES_BY_SURFACE: dict[HumanInputSurface, frozenset[RecipientType]] = {
|
||||
HumanInputSurface.SERVICE_API: frozenset({RecipientType.STANDALONE_WEB_APP}),
|
||||
HumanInputSurface.CONSOLE: frozenset({RecipientType.CONSOLE, RecipientType.BACKSTAGE}),
|
||||
}
|
||||
|
||||
# A single HITL form can have multiple recipient records; this shared priority
|
||||
# keeps every API surface consistent about which resume token to expose.
|
||||
_RECIPIENT_TOKEN_PRIORITY: dict[RecipientType, int] = {
|
||||
RecipientType.BACKSTAGE: 0,
|
||||
RecipientType.CONSOLE: 1,
|
||||
RecipientType.STANDALONE_WEB_APP: 2,
|
||||
}
|
||||
|
||||
|
||||
def is_recipient_type_allowed_for_surface(
|
||||
recipient_type: RecipientType | None,
|
||||
surface: HumanInputSurface,
|
||||
) -> bool:
|
||||
if recipient_type is None:
|
||||
return False
|
||||
return recipient_type in _ALLOWED_RECIPIENT_TYPES_BY_SURFACE[surface]
|
||||
|
||||
|
||||
def get_preferred_form_token(
|
||||
recipients: Sequence[tuple[RecipientType, str]],
|
||||
) -> str | None:
|
||||
chosen_token: str | None = None
|
||||
chosen_priority: int | None = None
|
||||
for recipient_type, token in recipients:
|
||||
priority = _RECIPIENT_TOKEN_PRIORITY.get(recipient_type)
|
||||
if priority is None or not token:
|
||||
continue
|
||||
if chosen_priority is None or priority < chosen_priority:
|
||||
chosen_priority = priority
|
||||
chosen_token = token
|
||||
return chosen_token
|
||||
|
||||
|
||||
def enrich_human_input_pause_reasons(
|
||||
reasons: Sequence[Mapping[str, Any]],
|
||||
*,
|
||||
form_tokens_by_form_id: Mapping[str, str],
|
||||
expiration_times_by_form_id: Mapping[str, int],
|
||||
) -> list[dict[str, Any]]:
|
||||
enriched: list[dict[str, Any]] = []
|
||||
for reason in reasons:
|
||||
updated = dict(reason)
|
||||
if updated.get("TYPE") == PauseReasonType.HUMAN_INPUT_REQUIRED:
|
||||
form_id = updated.get("form_id")
|
||||
if isinstance(form_id, str):
|
||||
updated["form_token"] = form_tokens_by_form_id.get(form_id)
|
||||
expiration_time = expiration_times_by_form_id.get(form_id)
|
||||
if expiration_time is not None:
|
||||
updated["expiration_time"] = expiration_time
|
||||
enriched.append(updated)
|
||||
return enriched
|
||||
@ -225,8 +225,10 @@ class TestSpanBuilder:
|
||||
span = builder.build_span(span_data)
|
||||
assert isinstance(span, ReadableSpan)
|
||||
assert span.name == "test-span"
|
||||
assert span.context is not None
|
||||
assert span.context.trace_id == 123
|
||||
assert span.context.span_id == 456
|
||||
assert span.parent is not None
|
||||
assert span.parent.span_id == 789
|
||||
assert span.resource == resource
|
||||
assert span.attributes == {"attr1": "val1"}
|
||||
|
||||
@ -64,12 +64,13 @@ class TestSpanData:
|
||||
|
||||
def test_span_data_missing_required_fields(self):
|
||||
with pytest.raises(ValidationError):
|
||||
SpanData(
|
||||
trace_id=123,
|
||||
# span_id missing
|
||||
name="test_span",
|
||||
start_time=1000,
|
||||
end_time=2000,
|
||||
SpanData.model_validate(
|
||||
{
|
||||
"trace_id": 123,
|
||||
"name": "test_span",
|
||||
"start_time": 1000,
|
||||
"end_time": 2000,
|
||||
}
|
||||
)
|
||||
|
||||
def test_span_data_arbitrary_types_allowed(self):
|
||||
|
||||
@ -2,12 +2,14 @@ from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import dify_trace_aliyun.aliyun_trace as aliyun_trace_module
|
||||
import pytest
|
||||
from dify_trace_aliyun.aliyun_trace import AliyunDataTrace
|
||||
from dify_trace_aliyun.config import AliyunConfig
|
||||
from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData, TraceMetadata
|
||||
from dify_trace_aliyun.entities.semconv import (
|
||||
GEN_AI_COMPLETION,
|
||||
GEN_AI_INPUT_MESSAGE,
|
||||
@ -44,7 +46,7 @@ class RecordingTraceClient:
|
||||
self.endpoint = endpoint
|
||||
self.added_spans: list[object] = []
|
||||
|
||||
def add_span(self, span) -> None:
|
||||
def add_span(self, span: object) -> None:
|
||||
self.added_spans.append(span)
|
||||
|
||||
def api_check(self) -> bool:
|
||||
@ -63,11 +65,35 @@ def _make_link(trace_id: int = 1, span_id: int = 2) -> Link:
|
||||
trace_id=trace_id,
|
||||
span_id=span_id,
|
||||
is_remote=False,
|
||||
trace_flags=TraceFlags.SAMPLED,
|
||||
trace_flags=TraceFlags(TraceFlags.SAMPLED),
|
||||
)
|
||||
return Link(context)
|
||||
|
||||
|
||||
def _make_trace_metadata(
|
||||
trace_id: int = 1,
|
||||
workflow_span_id: int = 2,
|
||||
session_id: str = "s",
|
||||
user_id: str = "u",
|
||||
links: list[Link] | None = None,
|
||||
) -> TraceMetadata:
|
||||
return TraceMetadata(
|
||||
trace_id=trace_id,
|
||||
workflow_span_id=workflow_span_id,
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
links=[] if links is None else links,
|
||||
)
|
||||
|
||||
|
||||
def _recording_trace_client(trace_instance: AliyunDataTrace) -> RecordingTraceClient:
|
||||
return cast(RecordingTraceClient, trace_instance.trace_client)
|
||||
|
||||
|
||||
def _recorded_span_data(trace_instance: AliyunDataTrace) -> list[SpanData]:
|
||||
return cast(list[SpanData], _recording_trace_client(trace_instance).added_spans)
|
||||
|
||||
|
||||
def _make_workflow_trace_info(**overrides) -> WorkflowTraceInfo:
|
||||
defaults = {
|
||||
"workflow_id": "workflow-id",
|
||||
@ -263,20 +289,20 @@ def test_workflow_trace_adds_workflow_and_node_spans(trace_instance: AliyunDataT
|
||||
trace_instance.workflow_trace(trace_info)
|
||||
|
||||
add_workflow_span.assert_called_once()
|
||||
passed_trace_metadata = add_workflow_span.call_args.args[1]
|
||||
passed_trace_metadata = cast(TraceMetadata, add_workflow_span.call_args.args[1])
|
||||
assert passed_trace_metadata.trace_id == 111
|
||||
assert passed_trace_metadata.workflow_span_id == 222
|
||||
assert passed_trace_metadata.session_id == "c"
|
||||
assert passed_trace_metadata.user_id == "u"
|
||||
assert passed_trace_metadata.links == []
|
||||
|
||||
assert trace_instance.trace_client.added_spans == ["span-1", "span-2"]
|
||||
assert _recording_trace_client(trace_instance).added_spans == ["span-1", "span-2"]
|
||||
|
||||
|
||||
def test_message_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace):
|
||||
trace_info = _make_message_trace_info(message_data=None)
|
||||
trace_instance.message_trace(trace_info)
|
||||
assert trace_instance.trace_client.added_spans == []
|
||||
assert _recording_trace_client(trace_instance).added_spans == []
|
||||
|
||||
|
||||
def test_message_trace_creates_message_and_llm_spans(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
|
||||
@ -302,8 +328,9 @@ def test_message_trace_creates_message_and_llm_spans(trace_instance: AliyunDataT
|
||||
)
|
||||
trace_instance.message_trace(trace_info)
|
||||
|
||||
assert len(trace_instance.trace_client.added_spans) == 2
|
||||
message_span, llm_span = trace_instance.trace_client.added_spans
|
||||
spans = _recorded_span_data(trace_instance)
|
||||
assert len(spans) == 2
|
||||
message_span, llm_span = spans
|
||||
|
||||
assert message_span.name == "message"
|
||||
assert message_span.trace_id == 10
|
||||
@ -324,7 +351,7 @@ def test_message_trace_creates_message_and_llm_spans(trace_instance: AliyunDataT
|
||||
def test_dataset_retrieval_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace):
|
||||
trace_info = _make_dataset_retrieval_trace_info(message_data=None)
|
||||
trace_instance.dataset_retrieval_trace(trace_info)
|
||||
assert trace_instance.trace_client.added_spans == []
|
||||
assert _recording_trace_client(trace_instance).added_spans == []
|
||||
|
||||
|
||||
def test_dataset_retrieval_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
|
||||
@ -338,8 +365,9 @@ def test_dataset_retrieval_trace_creates_span(trace_instance: AliyunDataTrace, m
|
||||
monkeypatch.setattr(aliyun_trace_module, "extract_retrieval_documents", lambda _: [{"doc": "d"}])
|
||||
|
||||
trace_instance.dataset_retrieval_trace(_make_dataset_retrieval_trace_info(inputs="query"))
|
||||
assert len(trace_instance.trace_client.added_spans) == 1
|
||||
span = trace_instance.trace_client.added_spans[0]
|
||||
spans = _recorded_span_data(trace_instance)
|
||||
assert len(spans) == 1
|
||||
span = spans[0]
|
||||
assert span.name == "dataset_retrieval"
|
||||
assert span.attributes[RETRIEVAL_QUERY] == "query"
|
||||
assert span.attributes[RETRIEVAL_DOCUMENT] == '[{"doc": "d"}]'
|
||||
@ -348,7 +376,7 @@ def test_dataset_retrieval_trace_creates_span(trace_instance: AliyunDataTrace, m
|
||||
def test_tool_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace):
|
||||
trace_info = _make_tool_trace_info(message_data=None)
|
||||
trace_instance.tool_trace(trace_info)
|
||||
assert trace_instance.trace_client.added_spans == []
|
||||
assert _recording_trace_client(trace_instance).added_spans == []
|
||||
|
||||
|
||||
def test_tool_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
|
||||
@ -371,8 +399,9 @@ def test_tool_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: p
|
||||
)
|
||||
)
|
||||
|
||||
assert len(trace_instance.trace_client.added_spans) == 1
|
||||
span = trace_instance.trace_client.added_spans[0]
|
||||
spans = _recorded_span_data(trace_instance)
|
||||
assert len(spans) == 1
|
||||
span = spans[0]
|
||||
assert span.name == "my-tool"
|
||||
assert span.status == status
|
||||
assert span.attributes[TOOL_NAME] == "my-tool"
|
||||
@ -409,7 +438,7 @@ def test_get_workflow_node_executions_builds_repo_and_fetches(
|
||||
def test_build_workflow_node_span_routes_llm_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
|
||||
node_execution = MagicMock(spec=WorkflowNodeExecution)
|
||||
trace_info = _make_workflow_trace_info()
|
||||
trace_metadata = MagicMock()
|
||||
trace_metadata = _make_trace_metadata()
|
||||
|
||||
monkeypatch.setattr(trace_instance, "build_workflow_llm_span", MagicMock(return_value="llm"))
|
||||
|
||||
@ -422,7 +451,7 @@ def test_build_workflow_node_span_routes_knowledge_retrieval_type(
|
||||
):
|
||||
node_execution = MagicMock(spec=WorkflowNodeExecution)
|
||||
trace_info = _make_workflow_trace_info()
|
||||
trace_metadata = MagicMock()
|
||||
trace_metadata = _make_trace_metadata()
|
||||
|
||||
monkeypatch.setattr(trace_instance, "build_workflow_retrieval_span", MagicMock(return_value="retrieval"))
|
||||
|
||||
@ -433,7 +462,7 @@ def test_build_workflow_node_span_routes_knowledge_retrieval_type(
|
||||
def test_build_workflow_node_span_routes_tool_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
|
||||
node_execution = MagicMock(spec=WorkflowNodeExecution)
|
||||
trace_info = _make_workflow_trace_info()
|
||||
trace_metadata = MagicMock()
|
||||
trace_metadata = _make_trace_metadata()
|
||||
|
||||
monkeypatch.setattr(trace_instance, "build_workflow_tool_span", MagicMock(return_value="tool"))
|
||||
|
||||
@ -444,7 +473,7 @@ def test_build_workflow_node_span_routes_tool_type(trace_instance: AliyunDataTra
|
||||
def test_build_workflow_node_span_routes_code_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
|
||||
node_execution = MagicMock(spec=WorkflowNodeExecution)
|
||||
trace_info = _make_workflow_trace_info()
|
||||
trace_metadata = MagicMock()
|
||||
trace_metadata = _make_trace_metadata()
|
||||
|
||||
monkeypatch.setattr(trace_instance, "build_workflow_task_span", MagicMock(return_value="task"))
|
||||
|
||||
@ -457,7 +486,7 @@ def test_build_workflow_node_span_handles_errors(
|
||||
):
|
||||
node_execution = MagicMock(spec=WorkflowNodeExecution)
|
||||
trace_info = _make_workflow_trace_info()
|
||||
trace_metadata = MagicMock()
|
||||
trace_metadata = _make_trace_metadata()
|
||||
|
||||
monkeypatch.setattr(trace_instance, "build_workflow_task_span", MagicMock(side_effect=RuntimeError("boom")))
|
||||
node_execution.node_type = BuiltinNodeTypes.CODE
|
||||
@ -472,7 +501,7 @@ def test_build_workflow_task_span(trace_instance: AliyunDataTrace, monkeypatch:
|
||||
status = Status(StatusCode.OK)
|
||||
monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status)
|
||||
|
||||
trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[])
|
||||
trace_metadata = _make_trace_metadata()
|
||||
node_execution = MagicMock(spec=WorkflowNodeExecution)
|
||||
node_execution.id = "node-id"
|
||||
node_execution.title = "title"
|
||||
@ -494,7 +523,7 @@ def test_build_workflow_tool_span(trace_instance: AliyunDataTrace, monkeypatch:
|
||||
status = Status(StatusCode.OK)
|
||||
monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status)
|
||||
|
||||
trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[_make_link()])
|
||||
trace_metadata = _make_trace_metadata(links=[_make_link()])
|
||||
node_execution = MagicMock(spec=WorkflowNodeExecution)
|
||||
node_execution.id = "node-id"
|
||||
node_execution.title = "my-tool"
|
||||
@ -527,7 +556,7 @@ def test_build_workflow_retrieval_span(trace_instance: AliyunDataTrace, monkeypa
|
||||
aliyun_trace_module, "format_retrieval_documents", lambda docs: [{"formatted": True}] if docs else []
|
||||
)
|
||||
|
||||
trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[])
|
||||
trace_metadata = _make_trace_metadata()
|
||||
node_execution = MagicMock(spec=WorkflowNodeExecution)
|
||||
node_execution.id = "node-id"
|
||||
node_execution.title = "retrieval"
|
||||
@ -556,7 +585,7 @@ def test_build_workflow_llm_span(trace_instance: AliyunDataTrace, monkeypatch: p
|
||||
monkeypatch.setattr(aliyun_trace_module, "format_input_messages", lambda _: "in")
|
||||
monkeypatch.setattr(aliyun_trace_module, "format_output_messages", lambda _: "out")
|
||||
|
||||
trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[])
|
||||
trace_metadata = _make_trace_metadata()
|
||||
node_execution = MagicMock(spec=WorkflowNodeExecution)
|
||||
node_execution.id = "node-id"
|
||||
node_execution.title = "llm"
|
||||
@ -594,7 +623,7 @@ def test_add_workflow_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.
|
||||
status = Status(StatusCode.OK)
|
||||
monkeypatch.setattr(aliyun_trace_module, "create_status_from_error", lambda _: status)
|
||||
|
||||
trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[])
|
||||
trace_metadata = _make_trace_metadata()
|
||||
|
||||
# CASE 1: With message_id
|
||||
trace_info = _make_workflow_trace_info(
|
||||
@ -602,9 +631,11 @@ def test_add_workflow_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.
|
||||
)
|
||||
trace_instance.add_workflow_span(trace_info, trace_metadata)
|
||||
|
||||
assert len(trace_instance.trace_client.added_spans) == 2
|
||||
message_span = trace_instance.trace_client.added_spans[0]
|
||||
workflow_span = trace_instance.trace_client.added_spans[1]
|
||||
client = _recording_trace_client(trace_instance)
|
||||
spans = _recorded_span_data(trace_instance)
|
||||
assert len(spans) == 2
|
||||
message_span = spans[0]
|
||||
workflow_span = spans[1]
|
||||
|
||||
assert message_span.name == "message"
|
||||
assert message_span.span_kind == SpanKind.SERVER
|
||||
@ -614,13 +645,14 @@ def test_add_workflow_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.
|
||||
assert workflow_span.span_kind == SpanKind.INTERNAL
|
||||
assert workflow_span.parent_span_id == 20
|
||||
|
||||
trace_instance.trace_client.added_spans.clear()
|
||||
client.added_spans.clear()
|
||||
|
||||
# CASE 2: Without message_id
|
||||
trace_info_no_msg = _make_workflow_trace_info(message_id=None)
|
||||
trace_instance.add_workflow_span(trace_info_no_msg, trace_metadata)
|
||||
assert len(trace_instance.trace_client.added_spans) == 1
|
||||
span = trace_instance.trace_client.added_spans[0]
|
||||
spans = _recorded_span_data(trace_instance)
|
||||
assert len(spans) == 1
|
||||
span = spans[0]
|
||||
assert span.name == "workflow"
|
||||
assert span.span_kind == SpanKind.SERVER
|
||||
assert span.parent_span_id is None
|
||||
@ -641,7 +673,8 @@ def test_suggested_question_trace(trace_instance: AliyunDataTrace, monkeypatch:
|
||||
trace_info = _make_suggested_question_trace_info(suggested_question=["how?"])
|
||||
trace_instance.suggested_question_trace(trace_info)
|
||||
|
||||
assert len(trace_instance.trace_client.added_spans) == 1
|
||||
span = trace_instance.trace_client.added_spans[0]
|
||||
spans = _recorded_span_data(trace_instance)
|
||||
assert len(spans) == 1
|
||||
span = spans[0]
|
||||
assert span.name == "suggested_question"
|
||||
assert span.attributes[GEN_AI_COMPLETION] == '["how?"]'
|
||||
|
||||
@ -1,4 +1,6 @@
|
||||
import json
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from dify_trace_aliyun.entities.semconv import (
|
||||
@ -170,7 +172,7 @@ def test_create_common_span_attributes():
|
||||
|
||||
def test_format_retrieval_documents():
|
||||
# Not a list
|
||||
assert format_retrieval_documents("not a list") == []
|
||||
assert format_retrieval_documents(cast(list[object], "not a list")) == []
|
||||
|
||||
# Valid list
|
||||
docs = [
|
||||
@ -211,7 +213,7 @@ def test_format_retrieval_documents():
|
||||
|
||||
def test_format_input_messages():
|
||||
# Not a dict
|
||||
assert format_input_messages(None) == serialize_json_data([])
|
||||
assert format_input_messages(cast(Mapping[str, Any], None)) == serialize_json_data([])
|
||||
|
||||
# No prompts
|
||||
assert format_input_messages({}) == serialize_json_data([])
|
||||
@ -244,7 +246,7 @@ def test_format_input_messages():
|
||||
|
||||
def test_format_output_messages():
|
||||
# Not a dict
|
||||
assert format_output_messages(None) == serialize_json_data([])
|
||||
assert format_output_messages(cast(Mapping[str, Any], None)) == serialize_json_data([])
|
||||
|
||||
# No text
|
||||
assert format_output_messages({"finish_reason": "stop"}) == serialize_json_data([])
|
||||
|
||||
@ -25,13 +25,13 @@ class TestAliyunConfig:
|
||||
def test_missing_required_fields(self):
|
||||
"""Test that required fields are enforced"""
|
||||
with pytest.raises(ValidationError):
|
||||
AliyunConfig()
|
||||
AliyunConfig.model_validate({})
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
AliyunConfig(license_key="test_license")
|
||||
AliyunConfig.model_validate({"license_key": "test_license"})
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
AliyunConfig(endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")
|
||||
AliyunConfig.model_validate({"endpoint": "https://tracing-analysis-dc-hz.aliyuncs.com"})
|
||||
|
||||
def test_app_name_validation_empty(self):
|
||||
"""Test app_name validation with empty value"""
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@ -129,7 +130,7 @@ def test_set_span_status():
|
||||
return "SilentErrorRepr"
|
||||
|
||||
span.reset_mock()
|
||||
set_span_status(span, SilentError())
|
||||
set_span_status(span, cast(Exception | str | None, SilentError()))
|
||||
assert span.add_event.call_args[1]["attributes"][OTELSpanAttributes.EXCEPTION_MESSAGE] == "SilentErrorRepr"
|
||||
|
||||
|
||||
|
||||
@ -28,13 +28,13 @@ class TestLangfuseConfig:
|
||||
def test_missing_required_fields(self):
|
||||
"""Test that required fields are enforced"""
|
||||
with pytest.raises(ValidationError):
|
||||
LangfuseConfig()
|
||||
LangfuseConfig.model_validate({})
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
LangfuseConfig(public_key="public")
|
||||
LangfuseConfig.model_validate({"public_key": "public"})
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
LangfuseConfig(secret_key="secret")
|
||||
LangfuseConfig.model_validate({"secret_key": "secret"})
|
||||
|
||||
def test_host_validation_empty(self):
|
||||
"""Test host validation with empty value"""
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from types import SimpleNamespace
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from dify_trace_langfuse.config import LangfuseConfig
|
||||
@ -134,4 +135,4 @@ class TestLangFuseDataTraceCompletionStartTime:
|
||||
|
||||
assert trace._get_completion_start_time(start_time, None) is None
|
||||
assert trace._get_completion_start_time(start_time, -1) is None
|
||||
assert trace._get_completion_start_time(start_time, "invalid") is None
|
||||
assert trace._get_completion_start_time(start_time, cast(float | int | None, "invalid")) is None
|
||||
|
||||
@ -21,13 +21,13 @@ class TestLangSmithConfig:
|
||||
def test_missing_required_fields(self):
|
||||
"""Test that required fields are enforced"""
|
||||
with pytest.raises(ValidationError):
|
||||
LangSmithConfig()
|
||||
LangSmithConfig.model_validate({})
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
LangSmithConfig(api_key="key")
|
||||
LangSmithConfig.model_validate({"api_key": "key"})
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
LangSmithConfig(project="project")
|
||||
LangSmithConfig.model_validate({"project": "project"})
|
||||
|
||||
def test_endpoint_validation_https_only(self):
|
||||
"""Test endpoint validation only allows HTTPS"""
|
||||
|
||||
@ -599,7 +599,6 @@ class TestMessageTrace:
|
||||
span = MagicMock()
|
||||
mock_tracing["start"].return_value = span
|
||||
mock_tracing["set"].return_value = "token"
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
trace_instance.message_trace(_make_message_trace_info())
|
||||
mock_tracing["start"].assert_called_once()
|
||||
@ -609,7 +608,6 @@ class TestMessageTrace:
|
||||
span = MagicMock()
|
||||
mock_tracing["start"].return_value = span
|
||||
mock_tracing["set"].return_value = "token"
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
trace_info = _make_message_trace_info(error="something broke")
|
||||
trace_instance.message_trace(trace_info)
|
||||
@ -620,7 +618,6 @@ class TestMessageTrace:
|
||||
span = MagicMock()
|
||||
mock_tracing["start"].return_value = span
|
||||
mock_tracing["set"].return_value = "token"
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
monkeypatch.setenv("FILES_URL", "http://files.test")
|
||||
|
||||
file_data = SimpleNamespace(url="path/to/file.png")
|
||||
@ -638,7 +635,6 @@ class TestMessageTrace:
|
||||
span = MagicMock()
|
||||
mock_tracing["start"].return_value = span
|
||||
mock_tracing["set"].return_value = "token"
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
trace_info = _make_message_trace_info(file_list=None, message_file_data=None)
|
||||
trace_instance.message_trace(trace_info)
|
||||
@ -651,7 +647,6 @@ class TestMessageTrace:
|
||||
|
||||
end_user = MagicMock()
|
||||
end_user.session_id = "session-xyz"
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = end_user
|
||||
|
||||
trace_info = _make_message_trace_info(
|
||||
metadata={"from_end_user_id": "eu-1", "conversation_id": "c1"},
|
||||
@ -664,7 +659,6 @@ class TestMessageTrace:
|
||||
span = MagicMock()
|
||||
mock_tracing["start"].return_value = span
|
||||
mock_tracing["set"].return_value = "token"
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
trace_info = _make_message_trace_info(
|
||||
metadata={"from_account_id": "acc-1"},
|
||||
|
||||
@ -12,6 +12,7 @@ from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from dify_trace_opik.opik_trace import OpikDataTrace, _seed_to_uuid4, prepare_opik_uuid
|
||||
@ -69,6 +70,14 @@ def _make_opik_trace_instance() -> OpikDataTrace:
|
||||
return instance
|
||||
|
||||
|
||||
def _add_trace_mock(instance: OpikDataTrace) -> MagicMock:
|
||||
return cast(MagicMock, instance.add_trace)
|
||||
|
||||
|
||||
def _add_span_mock(instance: OpikDataTrace) -> MagicMock:
|
||||
return cast(MagicMock, instance.add_span)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _seed_to_uuid4
|
||||
# ---------------------------------------------------------------------------
|
||||
@ -155,21 +164,21 @@ class TestWorkflowTraceWithoutMessageId:
|
||||
def test_root_span_is_created(self):
|
||||
trace_info = _make_workflow_trace_info(message_id=None)
|
||||
instance = self._run(trace_info)
|
||||
assert instance.add_span.called
|
||||
assert _add_span_mock(instance).called
|
||||
|
||||
def test_root_span_id_matches_expected(self):
|
||||
trace_info = _make_workflow_trace_info(message_id=None)
|
||||
instance = self._run(trace_info)
|
||||
|
||||
expected = self._expected_root_span_id(trace_info)
|
||||
root_span_kwargs = instance.add_span.call_args_list[0][0][0]
|
||||
root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0]
|
||||
assert root_span_kwargs["id"] == expected
|
||||
|
||||
def test_root_span_has_no_parent(self):
|
||||
trace_info = _make_workflow_trace_info(message_id=None)
|
||||
instance = self._run(trace_info)
|
||||
|
||||
root_span_kwargs = instance.add_span.call_args_list[0][0][0]
|
||||
root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0]
|
||||
assert root_span_kwargs["parent_span_id"] is None
|
||||
|
||||
def test_trace_name_is_workflow_trace(self):
|
||||
@ -177,21 +186,21 @@ class TestWorkflowTraceWithoutMessageId:
|
||||
trace_info = _make_workflow_trace_info(message_id=None)
|
||||
instance = self._run(trace_info)
|
||||
|
||||
trace_kwargs = instance.add_trace.call_args_list[0][0][0]
|
||||
trace_kwargs = _add_trace_mock(instance).call_args_list[0][0][0]
|
||||
assert trace_kwargs["name"] == TraceTaskName.WORKFLOW_TRACE
|
||||
|
||||
def test_root_span_name_is_workflow_trace(self):
|
||||
trace_info = _make_workflow_trace_info(message_id=None)
|
||||
instance = self._run(trace_info)
|
||||
|
||||
root_span_kwargs = instance.add_span.call_args_list[0][0][0]
|
||||
root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0]
|
||||
assert root_span_kwargs["name"] == TraceTaskName.WORKFLOW_TRACE
|
||||
|
||||
def test_root_span_has_workflow_tag(self):
|
||||
trace_info = _make_workflow_trace_info(message_id=None)
|
||||
instance = self._run(trace_info)
|
||||
|
||||
root_span_kwargs = instance.add_span.call_args_list[0][0][0]
|
||||
root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0]
|
||||
assert "workflow" in root_span_kwargs["tags"]
|
||||
|
||||
def test_node_execution_spans_are_parented_to_root(self):
|
||||
@ -214,8 +223,9 @@ class TestWorkflowTraceWithoutMessageId:
|
||||
instance = self._run(trace_info, node_executions=[node_exec])
|
||||
|
||||
# call_args_list[0] = root span, [1] = node execution span
|
||||
assert instance.add_span.call_count == 2
|
||||
node_span_kwargs = instance.add_span.call_args_list[1][0][0]
|
||||
add_span = _add_span_mock(instance)
|
||||
assert add_span.call_count == 2
|
||||
node_span_kwargs = add_span.call_args_list[1][0][0]
|
||||
assert node_span_kwargs["parent_span_id"] == expected_root_span_id
|
||||
|
||||
def test_node_span_not_parented_to_workflow_app_log_id(self):
|
||||
@ -240,7 +250,7 @@ class TestWorkflowTraceWithoutMessageId:
|
||||
instance = self._run(trace_info, node_executions=[node_exec])
|
||||
|
||||
old_parent_id = prepare_opik_uuid(trace_info.start_time, trace_info.workflow_app_log_id)
|
||||
node_span_kwargs = instance.add_span.call_args_list[1][0][0]
|
||||
node_span_kwargs = _add_span_mock(instance).call_args_list[1][0][0]
|
||||
assert node_span_kwargs["parent_span_id"] != old_parent_id
|
||||
|
||||
def test_root_span_id_differs_from_trace_id(self):
|
||||
@ -283,7 +293,7 @@ class TestWorkflowTraceWithMessageId:
|
||||
trace_info = _make_workflow_trace_info(message_id=self._MESSAGE_ID)
|
||||
instance = self._run(trace_info)
|
||||
|
||||
trace_kwargs = instance.add_trace.call_args_list[0][0][0]
|
||||
trace_kwargs = _add_trace_mock(instance).call_args_list[0][0][0]
|
||||
assert trace_kwargs["name"] == TraceTaskName.MESSAGE_TRACE
|
||||
|
||||
def test_root_span_uses_workflow_run_id_directly(self):
|
||||
@ -292,7 +302,7 @@ class TestWorkflowTraceWithMessageId:
|
||||
instance = self._run(trace_info)
|
||||
|
||||
expected_root_span_id = prepare_opik_uuid(trace_info.start_time, trace_info.workflow_run_id)
|
||||
root_span_kwargs = instance.add_span.call_args_list[0][0][0]
|
||||
root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0]
|
||||
assert root_span_kwargs["id"] == expected_root_span_id
|
||||
|
||||
def test_root_span_id_differs_from_no_message_id_case(self):
|
||||
@ -326,5 +336,5 @@ class TestWorkflowTraceWithMessageId:
|
||||
|
||||
instance = self._run(trace_info, node_executions=[node_exec])
|
||||
|
||||
node_span_kwargs = instance.add_span.call_args_list[1][0][0]
|
||||
node_span_kwargs = _add_span_mock(instance).call_args_list[1][0][0]
|
||||
assert node_span_kwargs["parent_span_id"] == expected_root_span_id
|
||||
|
||||
@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
import sys
|
||||
import types
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, TypedDict, cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
@ -12,7 +13,7 @@ from dify_trace_tencent import client as client_module
|
||||
from dify_trace_tencent.client import TencentTraceClient, _get_opentelemetry_sdk_version
|
||||
from dify_trace_tencent.entities.tencent_trace_entity import SpanData
|
||||
from opentelemetry.sdk.trace import Event
|
||||
from opentelemetry.trace import Status, StatusCode
|
||||
from opentelemetry.trace import SpanContext, Status, StatusCode, TraceFlags
|
||||
|
||||
metric_reader_instances: list[DummyMetricReader] = []
|
||||
meter_provider_instances: list[DummyMeterProvider] = []
|
||||
@ -80,6 +81,16 @@ class DummyJsonMetricExporterNoTemporality:
|
||||
self.kwargs = kwargs
|
||||
|
||||
|
||||
class PatchedCoreComponents(TypedDict):
|
||||
span_exporter: MagicMock
|
||||
span_processor: MagicMock
|
||||
tracer: MagicMock
|
||||
span: MagicMock
|
||||
tracer_provider: MagicMock
|
||||
logger: MagicMock
|
||||
trace_api: Any
|
||||
|
||||
|
||||
def _add_stub_modules(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Drop fake metric modules into sys.modules so the client imports resolve."""
|
||||
|
||||
@ -118,7 +129,7 @@ def stub_metric_modules(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_core_components(monkeypatch: pytest.MonkeyPatch) -> dict[str, object]:
|
||||
def patch_core_components(monkeypatch: pytest.MonkeyPatch) -> PatchedCoreComponents:
|
||||
span_exporter = MagicMock(name="span_exporter")
|
||||
monkeypatch.setattr(client_module, "OTLPSpanExporter", MagicMock(return_value=span_exporter))
|
||||
|
||||
@ -168,6 +179,15 @@ def patch_core_components(monkeypatch: pytest.MonkeyPatch) -> dict[str, object]:
|
||||
}
|
||||
|
||||
|
||||
def _make_span_context(trace_id: int = 1, span_id: int = 2) -> SpanContext:
|
||||
return SpanContext(
|
||||
trace_id=trace_id,
|
||||
span_id=span_id,
|
||||
is_remote=False,
|
||||
trace_flags=TraceFlags(TraceFlags.SAMPLED),
|
||||
)
|
||||
|
||||
|
||||
def _build_client() -> TencentTraceClient:
|
||||
return TencentTraceClient(
|
||||
service_name="service",
|
||||
@ -208,7 +228,7 @@ def test_resolve_grpc_target_parsable_variants(endpoint: str, expected: tuple[st
|
||||
|
||||
|
||||
def test_resolve_grpc_target_handles_errors() -> None:
|
||||
assert TencentTraceClient._resolve_grpc_target(123) == ("localhost:4317", True, "localhost", 4317)
|
||||
assert TencentTraceClient._resolve_grpc_target(cast(str, 123)) == ("localhost:4317", True, "localhost", 4317)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -248,7 +268,7 @@ def test_record_methods_skip_when_histogram_missing() -> None:
|
||||
client.record_trace_duration(0.5)
|
||||
|
||||
|
||||
def test_record_llm_duration_handles_exceptions(patch_core_components: dict[str, object]) -> None:
|
||||
def test_record_llm_duration_handles_exceptions(patch_core_components: PatchedCoreComponents) -> None:
|
||||
client = _build_client()
|
||||
client.hist_llm_duration = MagicMock(name="hist_llm_duration")
|
||||
client.hist_llm_duration.record.side_effect = RuntimeError("boom")
|
||||
@ -258,10 +278,11 @@ def test_record_llm_duration_handles_exceptions(patch_core_components: dict[str,
|
||||
logger.debug.assert_called()
|
||||
|
||||
|
||||
def test_create_and_export_span_sets_attributes(patch_core_components: dict[str, object]) -> None:
|
||||
def test_create_and_export_span_sets_attributes(patch_core_components: PatchedCoreComponents) -> None:
|
||||
client = _build_client()
|
||||
span = patch_core_components["span"]
|
||||
span.get_span_context.return_value = "ctx"
|
||||
ctx = _make_span_context(span_id=2)
|
||||
span.get_span_context.return_value = ctx
|
||||
|
||||
data = SpanData(
|
||||
trace_id=1,
|
||||
@ -280,14 +301,15 @@ def test_create_and_export_span_sets_attributes(patch_core_components: dict[str,
|
||||
span.add_event.assert_called_once()
|
||||
span.set_status.assert_called_once()
|
||||
span.end.assert_called_once_with(end_time=20)
|
||||
assert client.span_contexts[2] == "ctx"
|
||||
assert client.span_contexts[2] == ctx
|
||||
|
||||
|
||||
def test_create_and_export_span_uses_parent_context(patch_core_components: dict[str, object]) -> None:
|
||||
def test_create_and_export_span_uses_parent_context(patch_core_components: PatchedCoreComponents) -> None:
|
||||
client = _build_client()
|
||||
client.span_contexts[10] = "existing"
|
||||
existing_context = _make_span_context(span_id=10)
|
||||
client.span_contexts[10] = existing_context
|
||||
span = patch_core_components["span"]
|
||||
span.get_span_context.return_value = "child"
|
||||
span.get_span_context.return_value = _make_span_context(span_id=11)
|
||||
|
||||
data = SpanData(
|
||||
trace_id=1,
|
||||
@ -302,14 +324,14 @@ def test_create_and_export_span_uses_parent_context(patch_core_components: dict[
|
||||
|
||||
client._create_and_export_span(data)
|
||||
trace_api = patch_core_components["trace_api"]
|
||||
trace_api.NonRecordingSpan.assert_called_once_with("existing")
|
||||
trace_api.NonRecordingSpan.assert_called_once_with(existing_context)
|
||||
trace_api.set_span_in_context.assert_called_once()
|
||||
|
||||
|
||||
def test_create_and_export_span_exception_logs_error(patch_core_components: dict[str, object]) -> None:
|
||||
def test_create_and_export_span_exception_logs_error(patch_core_components: PatchedCoreComponents) -> None:
|
||||
client = _build_client()
|
||||
span = patch_core_components["span"]
|
||||
span.get_span_context.return_value = "ctx"
|
||||
span.get_span_context.return_value = _make_span_context(span_id=2)
|
||||
client.tracer.start_span.side_effect = RuntimeError("boom")
|
||||
|
||||
client._create_and_export_span(
|
||||
@ -385,7 +407,7 @@ def test_get_project_url() -> None:
|
||||
assert client.get_project_url() == "https://console.cloud.tencent.com/apm"
|
||||
|
||||
|
||||
def test_shutdown_flushes_all_components(patch_core_components: dict[str, object]) -> None:
|
||||
def test_shutdown_flushes_all_components(patch_core_components: PatchedCoreComponents) -> None:
|
||||
client = _build_client()
|
||||
span_processor = patch_core_components["span_processor"]
|
||||
tracer_provider = patch_core_components["tracer_provider"]
|
||||
@ -401,10 +423,11 @@ def test_shutdown_flushes_all_components(patch_core_components: dict[str, object
|
||||
metric_reader.shutdown.assert_called_once()
|
||||
|
||||
|
||||
def test_shutdown_logs_when_meter_provider_fails(patch_core_components: dict[str, object]) -> None:
|
||||
def test_shutdown_logs_when_meter_provider_fails(patch_core_components: PatchedCoreComponents) -> None:
|
||||
client = _build_client()
|
||||
meter_provider = meter_provider_instances[-1]
|
||||
meter_provider.shutdown.side_effect = RuntimeError("boom")
|
||||
assert client.metric_reader is not None
|
||||
client.metric_reader.shutdown.side_effect = RuntimeError("boom")
|
||||
|
||||
client.shutdown()
|
||||
@ -433,7 +456,7 @@ def test_metrics_initialization_failure_sets_histogram_attributes(monkeypatch: p
|
||||
assert client.metric_reader is None
|
||||
|
||||
|
||||
def test_add_span_logs_exception(monkeypatch: pytest.MonkeyPatch, patch_core_components: dict[str, object]) -> None:
|
||||
def test_add_span_logs_exception(monkeypatch: pytest.MonkeyPatch, patch_core_components: PatchedCoreComponents) -> None:
|
||||
client = _build_client()
|
||||
monkeypatch.setattr(client, "_create_and_export_span", MagicMock(side_effect=RuntimeError("boom")))
|
||||
|
||||
@ -454,10 +477,10 @@ def test_add_span_logs_exception(monkeypatch: pytest.MonkeyPatch, patch_core_com
|
||||
logger.exception.assert_called_once()
|
||||
|
||||
|
||||
def test_create_and_export_span_converts_attribute_types(patch_core_components: dict[str, object]) -> None:
|
||||
def test_create_and_export_span_converts_attribute_types(patch_core_components: PatchedCoreComponents) -> None:
|
||||
client = _build_client()
|
||||
span = patch_core_components["span"]
|
||||
span.get_span_context.return_value = "ctx"
|
||||
span.get_span_context.return_value = _make_span_context(span_id=2)
|
||||
|
||||
data = SpanData.model_construct(
|
||||
trace_id=1,
|
||||
@ -485,7 +508,7 @@ def test_record_llm_duration_converts_attributes() -> None:
|
||||
hist_mock = MagicMock(name="hist_llm_duration")
|
||||
client.hist_llm_duration = hist_mock
|
||||
|
||||
client.record_llm_duration(0.3, {"foo": object(), "bar": 2})
|
||||
client.record_llm_duration(0.3, cast(dict[str, str], {"foo": object(), "bar": 2}))
|
||||
_, attrs = hist_mock.record.call_args.args
|
||||
assert isinstance(attrs["foo"], str)
|
||||
assert attrs["bar"] == 2
|
||||
@ -496,7 +519,7 @@ def test_record_trace_duration_converts_attributes() -> None:
|
||||
hist_mock = MagicMock(name="hist_trace_duration")
|
||||
client.hist_trace_duration = hist_mock
|
||||
|
||||
client.record_trace_duration(1.0, {"meta": object(), "ok": True})
|
||||
client.record_trace_duration(1.0, cast(dict[str, str], {"meta": object(), "ok": True}))
|
||||
_, attrs = hist_mock.record.call_args.args
|
||||
assert isinstance(attrs["meta"], str)
|
||||
assert attrs["ok"] is True
|
||||
@ -512,7 +535,7 @@ def test_record_trace_duration_converts_attributes() -> None:
|
||||
],
|
||||
)
|
||||
def test_record_methods_handle_exceptions(
|
||||
method: str, attr_name: str, args: tuple[object, ...], patch_core_components: dict[str, object]
|
||||
method: str, attr_name: str, args: tuple[object, ...], patch_core_components: PatchedCoreComponents
|
||||
) -> None:
|
||||
client = _build_client()
|
||||
hist_mock = MagicMock(name=attr_name)
|
||||
@ -527,35 +550,38 @@ def test_record_methods_handle_exceptions(
|
||||
def test_metrics_initializes_grpc_metric_exporter() -> None:
|
||||
client = _build_client()
|
||||
metric_reader = metric_reader_instances[-1]
|
||||
exporter = cast(DummyGrpcMetricExporter, metric_reader.exporter)
|
||||
|
||||
assert isinstance(metric_reader.exporter, DummyGrpcMetricExporter)
|
||||
assert isinstance(exporter, DummyGrpcMetricExporter)
|
||||
assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000
|
||||
assert metric_reader.exporter.kwargs["endpoint"] == "trace.example.com:4317"
|
||||
assert metric_reader.exporter.kwargs["insecure"] is False
|
||||
assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token"
|
||||
assert exporter.kwargs["endpoint"] == "trace.example.com:4317"
|
||||
assert exporter.kwargs["insecure"] is False
|
||||
assert cast(dict[str, dict[str, str]], exporter.kwargs)["headers"]["authorization"] == "Bearer token"
|
||||
|
||||
|
||||
def test_metrics_initializes_http_protobuf_metric_exporter(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/protobuf")
|
||||
client = _build_client()
|
||||
metric_reader = metric_reader_instances[-1]
|
||||
exporter = cast(DummyHttpMetricExporter, metric_reader.exporter)
|
||||
|
||||
assert isinstance(metric_reader.exporter, DummyHttpMetricExporter)
|
||||
assert isinstance(exporter, DummyHttpMetricExporter)
|
||||
assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000
|
||||
assert metric_reader.exporter.kwargs["endpoint"] == client.endpoint
|
||||
assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token"
|
||||
assert exporter.kwargs["endpoint"] == client.endpoint
|
||||
assert cast(dict[str, dict[str, str]], exporter.kwargs)["headers"]["authorization"] == "Bearer token"
|
||||
|
||||
|
||||
def test_metrics_initializes_http_json_metric_exporter(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/json")
|
||||
client = _build_client()
|
||||
metric_reader = metric_reader_instances[-1]
|
||||
exporter = cast(DummyJsonMetricExporter, metric_reader.exporter)
|
||||
|
||||
assert isinstance(metric_reader.exporter, DummyJsonMetricExporter)
|
||||
assert isinstance(exporter, DummyJsonMetricExporter)
|
||||
assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000
|
||||
assert metric_reader.exporter.kwargs["endpoint"] == client.endpoint
|
||||
assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token"
|
||||
assert "preferred_temporality" in metric_reader.exporter.kwargs
|
||||
assert exporter.kwargs["endpoint"] == client.endpoint
|
||||
assert cast(dict[str, dict[str, str]], exporter.kwargs)["headers"]["authorization"] == "Bearer token"
|
||||
assert "preferred_temporality" in exporter.kwargs
|
||||
|
||||
|
||||
def test_metrics_http_json_metric_exporter_falls_back_without_temporality(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
@ -564,9 +590,10 @@ def test_metrics_http_json_metric_exporter_falls_back_without_temporality(monkey
|
||||
monkeypatch.setattr(exporter_module, "OTLPMetricExporter", DummyJsonMetricExporterNoTemporality)
|
||||
_ = _build_client()
|
||||
metric_reader = metric_reader_instances[-1]
|
||||
exporter = cast(DummyJsonMetricExporterNoTemporality, metric_reader.exporter)
|
||||
|
||||
assert isinstance(metric_reader.exporter, DummyJsonMetricExporterNoTemporality)
|
||||
assert "preferred_temporality" not in metric_reader.exporter.kwargs
|
||||
assert isinstance(exporter, DummyJsonMetricExporterNoTemporality)
|
||||
assert "preferred_temporality" not in exporter.kwargs
|
||||
|
||||
|
||||
def test_metrics_http_json_uses_http_fallback_when_no_json_exporter(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
|
||||
@@ -31,13 +31,13 @@ class TestWeaveConfig:
def test_missing_required_fields(self):
"""Test that required fields are enforced"""
with pytest.raises(ValidationError):
WeaveConfig()
WeaveConfig.model_validate({})

with pytest.raises(ValidationError):
WeaveConfig(api_key="key")
WeaveConfig.model_validate({"api_key": "key"})

with pytest.raises(ValidationError):
WeaveConfig(project="project")
WeaveConfig.model_validate({"project": "project"})

def test_endpoint_validation_https_only(self):
"""Test endpoint validation only allows HTTPS"""

@@ -59,7 +59,7 @@ class CouchbaseVector(BaseVector):

auth = PasswordAuthenticator(config.user, config.password)
options = ClusterOptions(auth)
self._cluster = Cluster(config.connection_string, options)
self._cluster = Cluster(config.connection_string, options)  # pyright: ignore[reportArgumentType]
self._bucket = self._cluster.bucket(config.bucket_name)
self._scope = self._bucket.scope(config.scope_name)
self._bucket_name = config.bucket_name
@@ -306,7 +306,7 @@ class CouchbaseVector(BaseVector):
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 4)
try:
CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query))
CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query))  # pyright: ignore[reportCallIssue]
search_iter = self._scope.search(
self._collection_name + "_search", CBrequest, SearchOptions(limit=top_k, fields=["*"])
)

@@ -1,6 +1,6 @@
import json
import logging
from typing import Any, TypedDict
from typing import Any, TypedDict, cast

from packaging import version
from pydantic import BaseModel, model_validator
@@ -92,7 +92,7 @@ class MilvusVector(BaseVector):
def _load_collection_fields(self, fields: list[str] | None = None):
if fields is None:
# Load collection fields from remote server
collection_info = self._client.describe_collection(self._collection_name)
collection_info = cast(dict[str, Any], self._client.describe_collection(self._collection_name))
fields = [field["name"] for field in collection_info["fields"]]
# Since primary field is auto-id, no need to track it
self._fields = [f for f in fields if f != Field.PRIMARY_KEY]
@@ -106,7 +106,8 @@ class MilvusVector(BaseVector):
return False

try:
milvus_version = self._client.get_server_version()
milvus_version_raw = self._client.get_server_version()
milvus_version = milvus_version_raw if isinstance(milvus_version_raw, str) else str(milvus_version_raw)
# Check if it's Zilliz Cloud - it supports full-text search with Milvus 2.5 compatibility
if "Zilliz Cloud" in milvus_version:
return True

@ -3,7 +3,7 @@ import json
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from typing import Any
|
||||
from typing import Any, TypedDict
|
||||
|
||||
import jieba.posseg as pseg # type: ignore
|
||||
import numpy
|
||||
@ -25,6 +25,18 @@ logger = logging.getLogger(__name__)
|
||||
oracledb.defaults.fetch_lobs = False
|
||||
|
||||
|
||||
class _OraclePoolParams(TypedDict, total=False):
|
||||
user: str
|
||||
password: str
|
||||
dsn: str
|
||||
min: int
|
||||
max: int
|
||||
increment: int
|
||||
config_dir: str | None
|
||||
wallet_location: str | None
|
||||
wallet_password: str | None
|
||||
|
||||
|
||||
class OracleVectorConfig(BaseModel):
|
||||
user: str
|
||||
password: str
|
||||
@ -127,22 +139,18 @@ class OracleVector(BaseVector):
|
||||
return connection
|
||||
|
||||
def _create_connection_pool(self, config: OracleVectorConfig):
|
||||
pool_params = {
|
||||
"user": config.user,
|
||||
"password": config.password,
|
||||
"dsn": config.dsn,
|
||||
"min": 1,
|
||||
"max": 5,
|
||||
"increment": 1,
|
||||
}
|
||||
pool_params = _OraclePoolParams(
|
||||
user=config.user,
|
||||
password=config.password,
|
||||
dsn=config.dsn,
|
||||
min=1,
|
||||
max=5,
|
||||
increment=1,
|
||||
)
|
||||
if config.is_autonomous:
|
||||
pool_params.update(
|
||||
{
|
||||
"config_dir": config.config_dir,
|
||||
"wallet_location": config.wallet_location,
|
||||
"wallet_password": config.wallet_password,
|
||||
}
|
||||
)
|
||||
pool_params["config_dir"] = config.config_dir
|
||||
pool_params["wallet_location"] = config.wallet_location
|
||||
pool_params["wallet_password"] = config.wallet_password
|
||||
return oracledb.create_pool(**pool_params)
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
|
||||
@ -42,7 +42,7 @@ from libs.helper import convert_datetime_to_date
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from libs.time_parser import get_time_threshold
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.human_input import HumanInputForm
|
||||
from models.human_input import HumanInputForm, HumanInputFormRecipient
|
||||
from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun
|
||||
from repositories.api_workflow_run_repository import APIWorkflowRunRepository, RunsWithRelatedCountsDict
|
||||
from repositories.entities.workflow_pause import WorkflowPauseEntity
|
||||
@ -63,6 +63,7 @@ class _WorkflowRunError(Exception):
|
||||
def _build_human_input_required_reason(
|
||||
reason_model: WorkflowPauseReason,
|
||||
form_model: HumanInputForm | None,
|
||||
recipients: Sequence[HumanInputFormRecipient] = (),
|
||||
) -> HumanInputRequired:
|
||||
form_content = ""
|
||||
inputs = []
|
||||
@ -89,7 +90,7 @@ def _build_human_input_required_reason(
|
||||
resolved_default_values = dict(definition.default_values)
|
||||
node_title = definition.node_title or node_title
|
||||
|
||||
return HumanInputRequired(
|
||||
reason = HumanInputRequired(
|
||||
form_id=form_id,
|
||||
form_content=form_content,
|
||||
inputs=inputs,
|
||||
@ -98,6 +99,7 @@ def _build_human_input_required_reason(
|
||||
node_title=node_title,
|
||||
resolved_default_values=resolved_default_values,
|
||||
)
|
||||
return reason
|
||||
|
||||
|
||||
class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
@ -804,12 +806,23 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
form_stmt = select(HumanInputForm).where(HumanInputForm.id.in_(form_ids))
|
||||
for form in session.scalars(form_stmt).all():
|
||||
form_models[form.id] = form
|
||||
recipients_by_form_id: dict[str, list[HumanInputFormRecipient]] = {}
|
||||
if form_ids:
|
||||
recipient_stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids))
|
||||
for recipient in session.scalars(recipient_stmt).all():
|
||||
recipients_by_form_id.setdefault(recipient.form_id, []).append(recipient)
|
||||
|
||||
pause_reasons: list[PauseReason] = []
|
||||
for reason in pause_reason_models:
|
||||
if reason.type_ == PauseReasonType.HUMAN_INPUT_REQUIRED:
|
||||
form_model = form_models.get(reason.form_id)
|
||||
pause_reasons.append(_build_human_input_required_reason(reason, form_model))
|
||||
pause_reasons.append(
|
||||
_build_human_input_required_reason(
|
||||
reason,
|
||||
form_model,
|
||||
recipients_by_form_id.get(reason.form_id, ()),
|
||||
)
|
||||
)
|
||||
else:
|
||||
pause_reasons.append(reason.to_entity())
|
||||
return pause_reasons
|
||||
|
||||
@ -162,6 +162,7 @@ class AppGenerateService:
|
||||
invoke_from=invoke_from,
|
||||
streaming=True,
|
||||
call_depth=0,
|
||||
workflow_run_id=str(uuid.uuid4()),
|
||||
)
|
||||
payload_json = payload.model_dump_json()
|
||||
|
||||
@ -183,6 +184,10 @@ class AppGenerateService:
|
||||
else:
|
||||
# Blocking mode: run synchronously and return JSON instead of SSE
|
||||
# Keep behaviour consistent with WORKFLOW blocking branch.
|
||||
pause_config = PauseStateLayerConfig(
|
||||
session_factory=session_factory.get_session_maker(),
|
||||
state_owner_user_id=workflow.created_by,
|
||||
)
|
||||
advanced_generator = AdvancedChatAppGenerator()
|
||||
return rate_limit.generate(
|
||||
advanced_generator.convert_to_event_stream(
|
||||
@ -194,6 +199,7 @@ class AppGenerateService:
|
||||
invoke_from=invoke_from,
|
||||
workflow_run_id=str(uuid.uuid4()),
|
||||
streaming=False,
|
||||
pause_state_config=pause_config,
|
||||
)
|
||||
),
|
||||
request_id=request_id,
|
||||
|
||||
@ -5,6 +5,7 @@ import uuid
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from cachetools.func import ttl_cache
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
from configs import dify_config
|
||||
@ -99,6 +100,7 @@ def try_join_default_workspace(account_id: str) -> None:
|
||||
|
||||
class EnterpriseService:
|
||||
@classmethod
|
||||
@ttl_cache(ttl=5)
|
||||
def get_info(cls):
|
||||
return EnterpriseRequest.send_request("GET", "/info")
|
||||
|
||||
|
||||
@@ -177,6 +177,7 @@ class SystemFeatureModel(BaseModel):
enable_change_email: bool = True
plugin_manager: PluginManagerModel = PluginManagerModel()
trial_models: list[str] = []
enable_creators_platform: bool = False
enable_trial_app: bool = False
enable_explore_banner: bool = False

@@ -241,6 +242,9 @@ class FeatureService:
if dify_config.MARKETPLACE_ENABLED:
system_features.enable_marketplace = True

if dify_config.CREATORS_PLATFORM_FEATURES_ENABLED:
system_features.enable_creators_platform = True

return system_features

@classmethod

@@ -26,7 +26,7 @@ from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.tool_manager import ToolManager
from core.tools.utils.encryption import create_provider_encrypter
from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
from core.tools.utils.system_encryption import decrypt_system_params
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.provider_ids import ToolProviderID
@@ -521,7 +521,7 @@ class BuiltinToolManageService:
)
if system_client:
try:
oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params)
oauth_params = decrypt_system_params(system_client.encrypted_oauth_params)
except Exception as e:
raise ValueError(f"Error decrypting system oauth params: {e}")


@@ -14,7 +14,7 @@ from core.helper.provider_cache import NoOpProviderCredentialCache
from core.helper.provider_encryption import ProviderConfigEncrypter, create_provider_encrypter
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.oauth import OAuthHandler
from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
from core.tools.utils.system_encryption import decrypt_system_params
from core.trigger.entities.api_entities import (
TriggerProviderApiEntity,
TriggerProviderSubscriptionApiEntity,
@@ -635,7 +635,7 @@ class TriggerProviderService:

if system_client:
try:
oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params)
oauth_params = decrypt_system_params(system_client.encrypted_oauth_params)
except Exception as e:
raise ValueError(f"Error decrypting system oauth params: {e}")

@ -14,6 +14,7 @@ from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.app.apps.message_generator import MessageGenerator
|
||||
from core.app.entities.task_entities import (
|
||||
HumanInputRequiredResponse,
|
||||
MessageReplaceStreamResponse,
|
||||
NodeFinishStreamResponse,
|
||||
NodeStartStreamResponse,
|
||||
@ -22,10 +23,14 @@ from core.app.entities.task_entities import (
|
||||
WorkflowStartStreamResponse,
|
||||
)
|
||||
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext
|
||||
from core.workflow.human_input_forms import load_form_tokens_by_form_id
|
||||
from core.workflow.human_input_policy import HumanInputSurface, enrich_human_input_pause_reasons
|
||||
from graphon.entities import WorkflowStartReason
|
||||
from graphon.entities.pause_reason import PauseReasonType
|
||||
from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus
|
||||
from graphon.runtime import GraphRuntimeState
|
||||
from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
||||
from models.human_input import HumanInputForm
|
||||
from models.model import AppMode, Message
|
||||
from models.workflow import WorkflowNodeExecutionTriggeredFrom, WorkflowRun
|
||||
from repositories.api_workflow_node_execution_repository import WorkflowNodeExecutionSnapshot
|
||||
@ -59,8 +64,10 @@ def build_workflow_event_stream(
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
session_maker: sessionmaker[Session],
|
||||
human_input_surface: HumanInputSurface | None = None,
|
||||
idle_timeout: float = 300,
|
||||
ping_interval: float = 10.0,
|
||||
close_on_pause: bool = True,
|
||||
) -> Generator[Mapping[str, Any] | str, None, None]:
|
||||
topic = MessageGenerator.get_response_topic(app_mode, workflow_run.id)
|
||||
workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
@ -115,13 +122,15 @@ def build_workflow_event_stream(
|
||||
message_context=message_context,
|
||||
pause_entity=pause_entity,
|
||||
resumption_context=resumption_context,
|
||||
session_maker=session_maker,
|
||||
human_input_surface=human_input_surface,
|
||||
)
|
||||
|
||||
for event in snapshot_events:
|
||||
last_msg_time = time.time()
|
||||
last_ping_time = last_msg_time
|
||||
yield event
|
||||
if _is_terminal_event(event, include_paused=True):
|
||||
if _is_terminal_event(event, close_on_pause=close_on_pause):
|
||||
return
|
||||
|
||||
while True:
|
||||
@ -146,7 +155,7 @@ def build_workflow_event_stream(
|
||||
last_msg_time = time.time()
|
||||
last_ping_time = last_msg_time
|
||||
yield event
|
||||
if _is_terminal_event(event, include_paused=True):
|
||||
if _is_terminal_event(event, close_on_pause=close_on_pause):
|
||||
return
|
||||
finally:
|
||||
buffer_state.stop_event.set()
|
||||
@ -207,6 +216,8 @@ def _build_snapshot_events(
|
||||
message_context: MessageContext | None,
|
||||
pause_entity: WorkflowPauseEntity | None,
|
||||
resumption_context: WorkflowResumptionContext | None,
|
||||
session_maker: sessionmaker[Session] | None = None,
|
||||
human_input_surface: HumanInputSurface | None = None,
|
||||
) -> list[Mapping[str, Any]]:
|
||||
events: list[Mapping[str, Any]] = []
|
||||
|
||||
@ -241,12 +252,24 @@ def _build_snapshot_events(
|
||||
events.append(node_finished)
|
||||
|
||||
if workflow_run.status == WorkflowExecutionStatus.PAUSED and pause_entity is not None:
|
||||
for human_input_event in _build_human_input_required_events(
|
||||
workflow_run_id=workflow_run.id,
|
||||
task_id=task_id,
|
||||
pause_entity=pause_entity,
|
||||
session_maker=session_maker,
|
||||
human_input_surface=human_input_surface,
|
||||
):
|
||||
_apply_message_context(human_input_event, message_context)
|
||||
events.append(human_input_event)
|
||||
|
||||
pause_event = _build_pause_event(
|
||||
workflow_run=workflow_run,
|
||||
workflow_run_id=workflow_run.id,
|
||||
task_id=task_id,
|
||||
pause_entity=pause_entity,
|
||||
resumption_context=resumption_context,
|
||||
session_maker=session_maker,
|
||||
human_input_surface=human_input_surface,
|
||||
)
|
||||
if pause_event is not None:
|
||||
_apply_message_context(pause_event, message_context)
|
||||
@ -314,6 +337,97 @@ def _build_node_started_event(
|
||||
return response.to_ignore_detail_dict()
|
||||
|
||||
|
||||
def _build_human_input_required_events(
|
||||
*,
|
||||
workflow_run_id: str,
|
||||
task_id: str,
|
||||
pause_entity: WorkflowPauseEntity,
|
||||
session_maker: sessionmaker[Session] | None,
|
||||
human_input_surface: HumanInputSurface | None,
|
||||
) -> list[dict[str, Any]]:
|
||||
reasons = [reason.model_dump(mode="json") for reason in pause_entity.get_pause_reasons()]
|
||||
human_input_form_ids = [
|
||||
form_id
|
||||
for reason in reasons
|
||||
if reason.get("TYPE") == PauseReasonType.HUMAN_INPUT_REQUIRED
|
||||
for form_id in [reason.get("form_id")]
|
||||
if isinstance(form_id, str)
|
||||
]
|
||||
|
||||
expiration_times_by_form_id: dict[str, int] = {}
|
||||
display_in_ui_by_form_id: dict[str, bool] = {}
|
||||
form_tokens_by_form_id: dict[str, str] = {}
|
||||
if human_input_form_ids and session_maker is not None:
|
||||
stmt = select(HumanInputForm.id, HumanInputForm.expiration_time, HumanInputForm.form_definition).where(
|
||||
HumanInputForm.id.in_(human_input_form_ids)
|
||||
)
|
||||
with session_maker() as session:
|
||||
for form_id, expiration_time, form_definition in session.execute(stmt):
|
||||
expiration_times_by_form_id[str(form_id)] = int(expiration_time.timestamp())
|
||||
try:
|
||||
definition_payload = json.loads(form_definition) if form_definition else {}
|
||||
except (TypeError, json.JSONDecodeError):
|
||||
definition_payload = {}
|
||||
display_in_ui_by_form_id[str(form_id)] = bool(definition_payload.get("display_in_ui"))
|
||||
form_tokens_by_form_id = load_form_tokens_by_form_id(
|
||||
human_input_form_ids,
|
||||
session=session,
|
||||
surface=human_input_surface,
|
||||
)
|
||||
|
||||
events: list[dict[str, Any]] = []
|
||||
for reason in reasons:
|
||||
if reason.get("TYPE") != PauseReasonType.HUMAN_INPUT_REQUIRED:
|
||||
continue
|
||||
|
||||
form_id_raw = reason.get("form_id")
|
||||
node_id_raw = reason.get("node_id")
|
||||
node_title_raw = reason.get("node_title")
|
||||
form_content_raw = reason.get("form_content")
|
||||
if not isinstance(form_id_raw, str):
|
||||
continue
|
||||
if not isinstance(node_id_raw, str):
|
||||
continue
|
||||
if not isinstance(node_title_raw, str):
|
||||
continue
|
||||
if not isinstance(form_content_raw, str):
|
||||
continue
|
||||
form_id = form_id_raw
|
||||
node_id = node_id_raw
|
||||
node_title = node_title_raw
|
||||
form_content = form_content_raw
|
||||
|
||||
inputs = reason.get("inputs")
|
||||
actions = reason.get("actions")
|
||||
resolved_default_values = reason.get("resolved_default_values")
|
||||
|
||||
expiration_time = expiration_times_by_form_id.get(form_id)
|
||||
if expiration_time is None:
|
||||
continue
|
||||
|
||||
response = HumanInputRequiredResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
data=HumanInputRequiredResponse.Data(
|
||||
form_id=form_id,
|
||||
node_id=node_id,
|
||||
node_title=node_title,
|
||||
form_content=form_content,
|
||||
inputs=inputs if isinstance(inputs, list) else [],
|
||||
actions=actions if isinstance(actions, list) else [],
|
||||
display_in_ui=display_in_ui_by_form_id.get(form_id, False),
|
||||
form_token=form_tokens_by_form_id.get(form_id),
|
||||
resolved_default_values=(resolved_default_values if isinstance(resolved_default_values, dict) else {}),
|
||||
expiration_time=expiration_time,
|
||||
),
|
||||
)
|
||||
payload = response.model_dump(mode="json")
|
||||
payload["event"] = response.event.value
|
||||
events.append(payload)
|
||||
|
||||
return events
|
||||
|
||||
|
||||
def _build_node_finished_event(
|
||||
*,
|
||||
workflow_run_id: str,
|
||||
@ -356,6 +470,8 @@ def _build_pause_event(
|
||||
task_id: str,
|
||||
pause_entity: WorkflowPauseEntity,
|
||||
resumption_context: WorkflowResumptionContext | None,
|
||||
session_maker: sessionmaker[Session] | None,
|
||||
human_input_surface: HumanInputSurface | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
paused_nodes: list[str] = []
|
||||
outputs: dict[str, Any] = {}
|
||||
@ -365,6 +481,36 @@ def _build_pause_event(
|
||||
outputs = dict(WorkflowRuntimeTypeConverter().to_json_encodable(state.outputs or {}))
|
||||
|
||||
reasons = [reason.model_dump(mode="json") for reason in pause_entity.get_pause_reasons()]
|
||||
human_input_form_ids = [
|
||||
form_id
|
||||
for reason in reasons
|
||||
if reason.get("TYPE") == PauseReasonType.HUMAN_INPUT_REQUIRED
|
||||
for form_id in [reason.get("form_id")]
|
||||
if isinstance(form_id, str)
|
||||
]
|
||||
form_tokens_by_form_id: dict[str, str] = {}
|
||||
expiration_times_by_form_id: dict[str, int] = {}
|
||||
if human_input_form_ids and session_maker is not None:
|
||||
with session_maker() as session:
|
||||
form_tokens_by_form_id = load_form_tokens_by_form_id(
|
||||
human_input_form_ids,
|
||||
session=session,
|
||||
surface=human_input_surface,
|
||||
)
|
||||
stmt = select(HumanInputForm.id, HumanInputForm.expiration_time).where(
|
||||
HumanInputForm.id.in_(human_input_form_ids)
|
||||
)
|
||||
for row in session.execute(stmt):
|
||||
form_id, expiration_time, *_rest = row
|
||||
expiration_times_by_form_id[str(form_id)] = int(expiration_time.timestamp())
|
||||
# Reconnect paths must preserve the same pause-reason contract as live streams;
|
||||
# otherwise clients see schema drift after resume.
|
||||
reasons = enrich_human_input_pause_reasons(
|
||||
reasons,
|
||||
form_tokens_by_form_id=form_tokens_by_form_id,
|
||||
expiration_times_by_form_id=expiration_times_by_form_id,
|
||||
)
|
||||
|
||||
response = WorkflowPauseStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
@@ -449,12 +595,19 @@ def _parse_event_message(message: bytes) -> Mapping[str, Any] | None:
return event


def _is_terminal_event(event: Mapping[str, Any] | str, include_paused=False) -> bool:
def _is_terminal_event(
event: Mapping[str, Any] | str,
close_on_pause: bool = True,
*,
include_paused: bool | None = None,
) -> bool:
if include_paused is not None:
close_on_pause = include_paused
if not isinstance(event, Mapping):
return False
event_type = event.get("event")
if event_type == StreamEvent.WORKFLOW_FINISHED.value:
return True
if include_paused:
if close_on_pause:
return event_type == StreamEvent.WORKFLOW_PAUSED.value
return False

@ -399,6 +399,8 @@ def _resume_advanced_chat(
|
||||
workflow_run_id: str,
|
||||
workflow_run: WorkflowRun,
|
||||
) -> None:
|
||||
resumed_generate_entity = generate_entity.model_copy(update={"stream": True})
|
||||
|
||||
try:
|
||||
triggered_from = WorkflowRunTriggeredFrom(workflow_run.triggered_from)
|
||||
except ValueError:
|
||||
@ -426,7 +428,7 @@ def _resume_advanced_chat(
|
||||
user=user,
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
application_generate_entity=generate_entity,
|
||||
application_generate_entity=resumed_generate_entity,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
@ -436,9 +438,8 @@ def _resume_advanced_chat(
|
||||
logger.exception("Failed to resume chatflow execution for workflow run %s", workflow_run_id)
|
||||
raise
|
||||
|
||||
if generate_entity.stream:
|
||||
assert isinstance(response, Generator)
|
||||
_publish_streaming_response(response, workflow_run_id, AppMode.ADVANCED_CHAT)
|
||||
assert isinstance(response, Generator)
|
||||
_publish_streaming_response(response, workflow_run_id, AppMode.ADVANCED_CHAT)
|
||||
|
||||
|
||||
def _resume_workflow(
|
||||
@ -455,6 +456,8 @@ def _resume_workflow(
|
||||
workflow_run_repo,
|
||||
pause_entity,
|
||||
) -> None:
|
||||
resumed_generate_entity = generate_entity.model_copy(update={"stream": True})
|
||||
|
||||
try:
|
||||
triggered_from = WorkflowRunTriggeredFrom(workflow_run.triggered_from)
|
||||
except ValueError:
|
||||
@ -480,7 +483,7 @@ def _resume_workflow(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
application_generate_entity=generate_entity,
|
||||
application_generate_entity=resumed_generate_entity,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
@ -490,11 +493,18 @@ def _resume_workflow(
|
||||
logger.exception("Failed to resume workflow execution for workflow run %s", workflow_run_id)
|
||||
raise
|
||||
|
||||
if generate_entity.stream:
|
||||
assert isinstance(response, Generator)
|
||||
_publish_streaming_response(response, workflow_run_id, AppMode.WORKFLOW)
|
||||
assert isinstance(response, Generator)
|
||||
_publish_streaming_response(response, workflow_run_id, AppMode.WORKFLOW)
|
||||
|
||||
workflow_run_repo.delete_workflow_pause(pause_entity)
|
||||
try:
|
||||
workflow_run_repo.delete_workflow_pause(pause_entity)
|
||||
except Exception as exc:
|
||||
if exc.__class__.__name__ != "_WorkflowRunError" or "WorkflowPause not found" not in str(exc):
|
||||
raise
|
||||
logger.info(
|
||||
"Skipped deleting workflow pause %s after resume because it was already replaced or removed",
|
||||
pause_entity.id,
|
||||
)
|
||||
|
||||
|
||||
@shared_task(queue=WORKFLOW_BASED_APP_EXECUTION_QUEUE, name="resume_app_execution")
|
||||
|
||||
@ -171,35 +171,13 @@ class TestChatMessageApiPermissions:
|
||||
parent_message_id=None,
|
||||
)
|
||||
|
||||
class MockQuery:
|
||||
def __init__(self, model):
|
||||
self.model = model
|
||||
|
||||
def where(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def first(self):
|
||||
if getattr(self.model, "__name__", "") == "Conversation":
|
||||
return mock_conversation
|
||||
return None
|
||||
|
||||
def order_by(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def limit(self, *_):
|
||||
return self
|
||||
|
||||
def all(self):
|
||||
if getattr(self.model, "__name__", "") == "Message":
|
||||
return [mock_message]
|
||||
return []
|
||||
|
||||
mock_session = mock.Mock()
|
||||
mock_session.query.side_effect = MockQuery
|
||||
mock_session.scalar.return_value = False
|
||||
mock_session.scalar.return_value = mock_conversation
|
||||
mock_session.scalars.return_value.all.return_value = [mock_message]
|
||||
|
||||
monkeypatch.setattr(message_api, "db", SimpleNamespace(session=mock_session))
|
||||
monkeypatch.setattr(message_api, "current_user", mock_account)
|
||||
monkeypatch.setattr(message_api, "attach_message_extra_contents", mock.Mock())
|
||||
|
||||
class DummyPagination:
|
||||
def __init__(self, data, limit, has_more):
|
||||
|
||||
@@ -24,7 +24,6 @@ def _patch_wraps():
patch("controllers.console.wraps.dify_config", dify_settings),
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features),
):
mock_db.session.query.return_value.first.return_value = MagicMock()
yield



@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import secrets
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import Mock
|
||||
@ -11,6 +12,7 @@ import pytest
|
||||
from sqlalchemy import Engine, delete, select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.workflow.human_input_adapter import DeliveryMethodType
|
||||
from extensions.ext_storage import storage
|
||||
from graphon.entities import WorkflowExecution
|
||||
from graphon.entities.pause_reason import HumanInputRequired, PauseReasonType
|
||||
@ -20,9 +22,11 @@ from graphon.nodes.human_input.enums import FormInputType, HumanInputFormStatus
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
|
||||
from models.human_input import (
|
||||
BackstageRecipientPayload,
|
||||
HumanInputDelivery,
|
||||
HumanInputForm,
|
||||
HumanInputFormRecipient,
|
||||
RecipientType,
|
||||
)
|
||||
from models.workflow import WorkflowAppLog, WorkflowAppLogCreatedFrom, WorkflowPause, WorkflowPauseReason, WorkflowRun
|
||||
from repositories.entities.workflow_pause import WorkflowPauseEntity
|
||||
@ -628,12 +632,12 @@ class TestPrivateWorkflowPauseEntity:
|
||||
class TestBuildHumanInputRequiredReason:
|
||||
"""Integration tests for _build_human_input_required_reason using real DB models."""
|
||||
|
||||
def test_builds_reason_from_form_definition(
|
||||
def test_prefers_standalone_web_app_token_when_available(
|
||||
self,
|
||||
db_session_with_containers: Session,
|
||||
test_scope: _TestScope,
|
||||
) -> None:
|
||||
"""Build the graph pause reason from the stored form definition."""
|
||||
"""Use the public standalone web-app token for service API payloads."""
|
||||
|
||||
expiration_time = naive_utc_now()
|
||||
form_definition = FormDefinition(
|
||||
@ -660,6 +664,40 @@ class TestBuildHumanInputRequiredReason:
|
||||
db_session_with_containers.add(form_model)
|
||||
db_session_with_containers.flush()
|
||||
|
||||
delivery = HumanInputDelivery(
|
||||
form_id=form_model.id,
|
||||
delivery_method_type=DeliveryMethodType.WEBAPP,
|
||||
channel_payload="{}",
|
||||
)
|
||||
db_session_with_containers.add(delivery)
|
||||
db_session_with_containers.flush()
|
||||
|
||||
backstage_access_token = secrets.token_urlsafe(8)
|
||||
backstage_recipient = HumanInputFormRecipient(
|
||||
form_id=form_model.id,
|
||||
delivery_id=delivery.id,
|
||||
recipient_type=RecipientType.BACKSTAGE,
|
||||
recipient_payload=BackstageRecipientPayload().model_dump_json(),
|
||||
access_token=backstage_access_token,
|
||||
)
|
||||
console_access_token = secrets.token_urlsafe(8)
|
||||
console_recipient = HumanInputFormRecipient(
|
||||
form_id=form_model.id,
|
||||
delivery_id=delivery.id,
|
||||
recipient_type=RecipientType.CONSOLE,
|
||||
recipient_payload="{}",
|
||||
access_token=console_access_token,
|
||||
)
|
||||
web_app_access_token = secrets.token_urlsafe(8)
|
||||
web_app_recipient = HumanInputFormRecipient(
|
||||
form_id=form_model.id,
|
||||
delivery_id=delivery.id,
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
recipient_payload="{}",
|
||||
access_token=web_app_access_token,
|
||||
)
|
||||
db_session_with_containers.add_all([backstage_recipient, console_recipient, web_app_recipient])
|
||||
db_session_with_containers.flush()
|
||||
# Create a pause so the reason has a valid pause_id
|
||||
workflow_run = _create_workflow_run(
|
||||
db_session_with_containers,
|
||||
@ -688,8 +726,15 @@ class TestBuildHumanInputRequiredReason:
|
||||
# Refresh to ensure we have DB-round-tripped objects
|
||||
db_session_with_containers.refresh(form_model)
|
||||
db_session_with_containers.refresh(reason_model)
|
||||
db_session_with_containers.refresh(backstage_recipient)
|
||||
db_session_with_containers.refresh(console_recipient)
|
||||
db_session_with_containers.refresh(web_app_recipient)
|
||||
|
||||
reason = _build_human_input_required_reason(reason_model, form_model)
|
||||
reason = _build_human_input_required_reason(
|
||||
reason_model,
|
||||
form_model,
|
||||
[backstage_recipient, console_recipient, web_app_recipient],
|
||||
)
|
||||
|
||||
assert isinstance(reason, HumanInputRequired)
|
||||
assert reason.node_title == "Ask Name"
|
||||
@ -697,3 +742,92 @@ class TestBuildHumanInputRequiredReason:
|
||||
assert reason.inputs[0].output_variable_name == "name"
|
||||
assert reason.actions[0].id == "approve"
|
||||
assert reason.resolved_default_values == {"name": "Alice"}
|
||||
assert not hasattr(reason, "form_token")
|
||||
|
||||
def test_falls_back_to_console_token_when_web_app_token_missing(
|
||||
self,
|
||||
db_session_with_containers: Session,
|
||||
test_scope: _TestScope,
|
||||
) -> None:
|
||||
"""Use the console token only when no standalone web-app token exists."""
|
||||
|
||||
expiration_time = naive_utc_now()
|
||||
form_definition = FormDefinition(
|
||||
form_content="content",
|
||||
inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")],
|
||||
user_actions=[UserAction(id="approve", title="Approve")],
|
||||
rendered_content="rendered",
|
||||
expiration_time=expiration_time,
|
||||
default_values={"name": "Alice"},
|
||||
node_title="Ask Name",
|
||||
display_in_ui=True,
|
||||
)
|
||||
|
||||
form_model = HumanInputForm(
|
||||
tenant_id=test_scope.tenant_id,
|
||||
app_id=test_scope.app_id,
|
||||
workflow_run_id=str(uuid4()),
|
||||
node_id="node-1",
|
||||
form_definition=form_definition.model_dump_json(),
|
||||
rendered_content="rendered",
|
||||
status=HumanInputFormStatus.WAITING,
|
||||
expiration_time=expiration_time,
|
||||
)
|
||||
db_session_with_containers.add(form_model)
|
||||
db_session_with_containers.flush()
|
||||
|
||||
delivery = HumanInputDelivery(
|
||||
form_id=form_model.id,
|
||||
delivery_method_type=DeliveryMethodType.WEBAPP,
|
||||
channel_payload="{}",
|
||||
)
|
||||
db_session_with_containers.add(delivery)
|
||||
db_session_with_containers.flush()
|
||||
|
||||
backstage_access_token = secrets.token_urlsafe(8)
|
||||
backstage_recipient = HumanInputFormRecipient(
|
||||
form_id=form_model.id,
|
||||
delivery_id=delivery.id,
|
||||
recipient_type=RecipientType.BACKSTAGE,
|
||||
recipient_payload=BackstageRecipientPayload().model_dump_json(),
|
||||
access_token=backstage_access_token,
|
||||
)
|
||||
console_access_token = secrets.token_urlsafe(8)
|
||||
console_recipient = HumanInputFormRecipient(
|
||||
form_id=form_model.id,
|
||||
delivery_id=delivery.id,
|
||||
recipient_type=RecipientType.CONSOLE,
|
||||
recipient_payload="{}",
|
||||
access_token=console_access_token,
|
||||
)
|
||||
db_session_with_containers.add_all([backstage_recipient, console_recipient])
|
||||
db_session_with_containers.flush()
|
||||
|
||||
workflow_run = _create_workflow_run(
|
||||
db_session_with_containers,
|
||||
test_scope,
|
||||
status=WorkflowExecutionStatus.RUNNING,
|
||||
)
|
||||
pause = WorkflowPause(
|
||||
workflow_id=test_scope.workflow_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_object_key=f"workflow-state-{uuid4()}.json",
|
||||
)
|
||||
db_session_with_containers.add(pause)
|
||||
db_session_with_containers.flush()
|
||||
test_scope.state_keys.add(pause.state_object_key)
|
||||
|
||||
reason_model = WorkflowPauseReason(
|
||||
pause_id=pause.id,
|
||||
type_=PauseReasonType.HUMAN_INPUT_REQUIRED,
|
||||
form_id=form_model.id,
|
||||
node_id="node-1",
|
||||
message="",
|
||||
)
|
||||
db_session_with_containers.add(reason_model)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
reason = _build_human_input_required_reason(reason_model, form_model, [backstage_recipient, console_recipient])
|
||||
|
||||
assert isinstance(reason, HumanInputRequired)
|
||||
assert not hasattr(reason, "form_token")
|
||||
|
||||
@@ -13,6 +13,12 @@ from models.model import App, Conversation, Message
from services.feedback_service import FeedbackService


def _execute_result(rows):
result = mock.Mock()
result.all.return_value = rows
return result


class TestFeedbackService:
"""Test FeedbackService methods."""

@ -81,25 +87,17 @@ class TestFeedbackService:
|
||||
|
||||
def test_export_feedbacks_csv_format(self, mock_db_session, sample_data):
|
||||
"""Test exporting feedback data in CSV format."""
|
||||
|
||||
# Setup mock query result
|
||||
mock_query = mock.Mock()
|
||||
mock_query.join.return_value = mock_query
|
||||
mock_query.outerjoin.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.all.return_value = [
|
||||
(
|
||||
sample_data["user_feedback"],
|
||||
sample_data["message"],
|
||||
sample_data["conversation"],
|
||||
sample_data["app"],
|
||||
sample_data["user_feedback"].from_account,
|
||||
)
|
||||
]
|
||||
|
||||
mock_db_session.execute.return_value = mock_query
|
||||
mock_db_session.execute.return_value = _execute_result(
|
||||
[
|
||||
(
|
||||
sample_data["user_feedback"],
|
||||
sample_data["message"],
|
||||
sample_data["conversation"],
|
||||
sample_data["app"],
|
||||
sample_data["user_feedback"].from_account,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Test CSV export
|
||||
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv")
|
||||
@ -120,25 +118,17 @@ class TestFeedbackService:
|
||||
|
||||
def test_export_feedbacks_json_format(self, mock_db_session, sample_data):
|
||||
"""Test exporting feedback data in JSON format."""
|
||||
|
||||
# Setup mock query result
|
||||
mock_query = mock.Mock()
|
||||
mock_query.join.return_value = mock_query
|
||||
mock_query.outerjoin.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.all.return_value = [
|
||||
(
|
||||
sample_data["admin_feedback"],
|
||||
sample_data["message"],
|
||||
sample_data["conversation"],
|
||||
sample_data["app"],
|
||||
sample_data["admin_feedback"].from_account,
|
||||
)
|
||||
]
|
||||
|
||||
mock_db_session.execute.return_value = mock_query
|
||||
mock_db_session.execute.return_value = _execute_result(
|
||||
[
|
||||
(
|
||||
sample_data["admin_feedback"],
|
||||
sample_data["message"],
|
||||
sample_data["conversation"],
|
||||
sample_data["app"],
|
||||
sample_data["admin_feedback"].from_account,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Test JSON export
|
||||
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json")
|
||||
@ -157,25 +147,17 @@ class TestFeedbackService:
|
||||
|
||||
def test_export_feedbacks_with_filters(self, mock_db_session, sample_data):
|
||||
"""Test exporting feedback with various filters."""
|
||||
|
||||
# Setup mock query result
|
||||
mock_query = mock.Mock()
|
||||
mock_query.join.return_value = mock_query
|
||||
mock_query.outerjoin.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.all.return_value = [
|
||||
(
|
||||
sample_data["admin_feedback"],
|
||||
sample_data["message"],
|
||||
sample_data["conversation"],
|
||||
sample_data["app"],
|
||||
sample_data["admin_feedback"].from_account,
|
||||
)
|
||||
]
|
||||
|
||||
mock_db_session.execute.return_value = mock_query
|
||||
mock_db_session.execute.return_value = _execute_result(
|
||||
[
|
||||
(
|
||||
sample_data["admin_feedback"],
|
||||
sample_data["message"],
|
||||
sample_data["conversation"],
|
||||
sample_data["app"],
|
||||
sample_data["admin_feedback"].from_account,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Test with filters
|
||||
result = FeedbackService.export_feedbacks(
|
||||
@ -193,17 +175,7 @@ class TestFeedbackService:

    def test_export_feedbacks_no_data(self, mock_db_session, sample_data):
        """Test exporting feedback when no data exists."""

        # Setup mock query result with no data
        mock_query = mock.Mock()
        mock_query.join.return_value = mock_query
        mock_query.outerjoin.return_value = mock_query
        mock_query.where.return_value = mock_query
        mock_query.filter.return_value = mock_query
        mock_query.order_by.return_value = mock_query
        mock_query.all.return_value = []

        mock_db_session.execute.return_value = mock_query
        mock_db_session.execute.return_value = _execute_result([])

        result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv")

@ -251,24 +223,17 @@ class TestFeedbackService:
|
||||
created_at=datetime(2024, 1, 1, 10, 0, 0),
|
||||
)
|
||||
|
||||
# Setup mock query result
|
||||
mock_query = mock.Mock()
|
||||
mock_query.join.return_value = mock_query
|
||||
mock_query.outerjoin.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.all.return_value = [
|
||||
(
|
||||
sample_data["user_feedback"],
|
||||
long_message,
|
||||
sample_data["conversation"],
|
||||
sample_data["app"],
|
||||
sample_data["user_feedback"].from_account,
|
||||
)
|
||||
]
|
||||
|
||||
mock_db_session.execute.return_value = mock_query
|
||||
mock_db_session.execute.return_value = _execute_result(
|
||||
[
|
||||
(
|
||||
sample_data["user_feedback"],
|
||||
long_message,
|
||||
sample_data["conversation"],
|
||||
sample_data["app"],
|
||||
sample_data["user_feedback"].from_account,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Test export
|
||||
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json")
|
||||
@ -309,24 +274,17 @@ class TestFeedbackService:
|
||||
created_at=datetime(2024, 1, 1, 10, 0, 0),
|
||||
)
|
||||
|
||||
# Setup mock query result
|
||||
mock_query = mock.Mock()
|
||||
mock_query.join.return_value = mock_query
|
||||
mock_query.outerjoin.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.all.return_value = [
|
||||
(
|
||||
chinese_feedback,
|
||||
chinese_message,
|
||||
sample_data["conversation"],
|
||||
sample_data["app"],
|
||||
None, # No account for user feedback
|
||||
)
|
||||
]
|
||||
|
||||
mock_db_session.execute.return_value = mock_query
|
||||
mock_db_session.execute.return_value = _execute_result(
|
||||
[
|
||||
(
|
||||
chinese_feedback,
|
||||
chinese_message,
|
||||
sample_data["conversation"],
|
||||
sample_data["app"],
|
||||
None,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Test export
|
||||
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv")
|
||||
@ -339,32 +297,24 @@ class TestFeedbackService:
|
||||
|
||||
def test_export_feedbacks_emoji_ratings(self, mock_db_session, sample_data):
|
||||
"""Test that rating emojis are properly formatted in export."""
|
||||
|
||||
# Setup mock query result with both like and dislike feedback
|
||||
mock_query = mock.Mock()
|
||||
mock_query.join.return_value = mock_query
|
||||
mock_query.outerjoin.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.all.return_value = [
|
||||
(
|
||||
sample_data["user_feedback"],
|
||||
sample_data["message"],
|
||||
sample_data["conversation"],
|
||||
sample_data["app"],
|
||||
sample_data["user_feedback"].from_account,
|
||||
),
|
||||
(
|
||||
sample_data["admin_feedback"],
|
||||
sample_data["message"],
|
||||
sample_data["conversation"],
|
||||
sample_data["app"],
|
||||
sample_data["admin_feedback"].from_account,
|
||||
),
|
||||
]
|
||||
|
||||
mock_db_session.execute.return_value = mock_query
|
||||
mock_db_session.execute.return_value = _execute_result(
|
||||
[
|
||||
(
|
||||
sample_data["user_feedback"],
|
||||
sample_data["message"],
|
||||
sample_data["conversation"],
|
||||
sample_data["app"],
|
||||
sample_data["user_feedback"].from_account,
|
||||
),
|
||||
(
|
||||
sample_data["admin_feedback"],
|
||||
sample_data["message"],
|
||||
sample_data["conversation"],
|
||||
sample_data["app"],
|
||||
sample_data["admin_feedback"].from_account,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
# Test export
|
||||
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json")
|
||||
|
||||
@ -121,33 +121,32 @@ def _configure_session_factory(_unit_test_engine):
    configure_session_factory(_unit_test_engine, expire_on_commit=False)


def setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account):
def setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_owner):
    """
    Helper to set up the mock DB execute chain for tenant/account authentication.
    Helper to stub the tenant-owner execute result for service API app authentication.

    This configures the mock to return (tenant, account) for the
    db.session.execute(select(...).join().join().where()).one_or_none()
    query used by validate_app_token decorator.
    The validate_app_token decorator currently resolves the active tenant owner
    via db.session.execute(select(Tenant, Account)...).one_or_none().

    Args:
        mock_db: The mocked db object
        mock_tenant: Mock tenant object to return
        mock_account: Mock account object to return
        mock_owner: Mock owner object to return from the execute result
    """
    mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_account)
    mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_owner)


def setup_mock_dataset_tenant_query(mock_db, mock_tenant, mock_ta):
def setup_mock_dataset_owner_execute_result(mock_db, mock_tenant, mock_tenant_account_join):
    """
    Helper to set up the mock DB execute chain for dataset tenant authentication.
    Helper to stub the tenant-owner execute result for dataset token authentication.

    This configures the mock to return (tenant, tenant_account) for the
    db.session.execute(select(...).where().where().where().where()).one_or_none()
    query used by validate_dataset_token decorator.
    The validate_dataset_token decorator currently resolves the owner mapping via
    db.session.execute(select(Tenant, TenantAccountJoin)...).one_or_none(), and
    then loads the Account separately via db.session.get(...).

    Args:
        mock_db: The mocked db object
        mock_tenant: Mock tenant object to return
        mock_ta: Mock tenant account object to return
        mock_tenant_account_join: Mock tenant-account join object to return
    """
    mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_ta)
    mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_tenant_account_join)

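A short usage sketch for the renamed conftest helper (the patch target controllers.service_api.wraps.db and the test body are assumptions for illustration; the real decorator wiring lives in the service API wraps module):

# Hedged example: pre-seed the (Tenant, Account) tuple that validate_app_token
# unpacks from db.session.execute(...).one_or_none().
from unittest.mock import Mock, patch

from tests.unit_tests.conftest import setup_mock_tenant_owner_execute_result


def test_owner_resolution_sketch():
    mock_tenant = Mock(status="normal")
    mock_owner = Mock()
    with patch("controllers.service_api.wraps.db") as mock_db:  # assumed patch target
        setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_owner)
        resolved = mock_db.session.execute("select ...").one_or_none()
        assert resolved == (mock_tenant, mock_owner)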
@ -208,8 +208,6 @@ class TestAnnotationImportServiceValidation:
|
||||
|
||||
file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv")
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_app
|
||||
|
||||
with patch("services.annotation_service.current_account_with_tenant") as mock_auth:
|
||||
mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id")
|
||||
|
||||
@ -230,8 +228,6 @@ class TestAnnotationImportServiceValidation:
|
||||
|
||||
file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv")
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_app
|
||||
|
||||
with patch("services.annotation_service.current_account_with_tenant") as mock_auth:
|
||||
mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id")
|
||||
|
||||
@ -248,8 +244,6 @@ class TestAnnotationImportServiceValidation:
|
||||
csv_content = 'invalid,csv,format\nwith,unbalanced,quotes,and"stuff'
|
||||
file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv")
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_app
|
||||
|
||||
with (
|
||||
patch("services.annotation_service.current_account_with_tenant") as mock_auth,
|
||||
patch("services.annotation_service.pd.read_csv", side_effect=ParserError("malformed CSV")),
|
||||
@ -269,8 +263,6 @@ class TestAnnotationImportServiceValidation:
|
||||
|
||||
file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv")
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_app
|
||||
|
||||
with patch("services.annotation_service.current_account_with_tenant") as mock_auth:
|
||||
mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id")
|
||||
|
||||
|
||||
@ -43,7 +43,6 @@ class TestAuthenticationSecurity:
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_invitation.return_value = None
|
||||
mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Invalid email or password.")
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
|
||||
mock_features.return_value.is_allow_register = True
|
||||
|
||||
# Act
|
||||
@ -76,7 +75,6 @@ class TestAuthenticationSecurity:
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_invitation.return_value = None
|
||||
mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Wrong password")
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
|
||||
|
||||
# Act
|
||||
with self.app.test_request_context(
|
||||
@ -109,7 +107,6 @@ class TestAuthenticationSecurity:
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_invitation.return_value = None
|
||||
mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Invalid email or password.")
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
|
||||
mock_features.return_value.is_allow_register = False
|
||||
|
||||
# Act
|
||||
@ -135,7 +132,6 @@ class TestAuthenticationSecurity:
|
||||
def test_reset_password_with_existing_account(self, mock_send_email, mock_get_user, mock_features, mock_db):
|
||||
"""Test that reset password returns success with token for existing accounts."""
|
||||
# Mock the setup check
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
|
||||
|
||||
# Test with existing account
|
||||
mock_get_user.return_value = MagicMock(email="existing@example.com")
|
||||
|
||||
@ -65,7 +65,6 @@ class TestEmailCodeLoginSendEmailApi:
|
||||
- IP rate limiting is checked
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_ip_limit.return_value = False
|
||||
mock_get_user.return_value = mock_account
|
||||
mock_send_email.return_value = "email_token_123"
|
||||
@ -98,7 +97,6 @@ class TestEmailCodeLoginSendEmailApi:
|
||||
- Registration is allowed by system features
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_ip_limit.return_value = False
|
||||
mock_get_user.return_value = None
|
||||
mock_get_features.return_value.is_allow_register = True
|
||||
@ -130,7 +128,6 @@ class TestEmailCodeLoginSendEmailApi:
|
||||
- Registration is blocked by system features
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_ip_limit.return_value = False
|
||||
mock_get_user.return_value = None
|
||||
mock_get_features.return_value.is_allow_register = False
|
||||
@ -152,7 +149,6 @@ class TestEmailCodeLoginSendEmailApi:
|
||||
- Prevents spam and abuse
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_ip_limit.return_value = True
|
||||
|
||||
# Act & Assert
|
||||
@ -172,7 +168,6 @@ class TestEmailCodeLoginSendEmailApi:
|
||||
- AccountInFreezeError is raised for frozen accounts
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_ip_limit.return_value = False
|
||||
mock_get_user.side_effect = AccountRegisterError("Account frozen")
|
||||
|
||||
@ -213,7 +208,6 @@ class TestEmailCodeLoginSendEmailApi:
|
||||
- Defaults to en-US when not specified
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_ip_limit.return_value = False
|
||||
mock_get_user.return_value = mock_account
|
||||
mock_send_email.return_value = "token"
|
||||
@ -286,7 +280,6 @@ class TestEmailCodeLoginApi:
|
||||
- User is logged in with token pair
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
|
||||
mock_get_user.return_value = mock_account
|
||||
mock_get_tenants.return_value = [MagicMock()]
|
||||
@ -335,7 +328,6 @@ class TestEmailCodeLoginApi:
|
||||
- User is logged in after account creation
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_data.return_value = {"email": "newuser@example.com", "code": "123456"}
|
||||
mock_get_user.return_value = None
|
||||
mock_create_account.return_value = mock_account
|
||||
@ -369,7 +361,6 @@ class TestEmailCodeLoginApi:
|
||||
- InvalidTokenError is raised for invalid/expired tokens
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_data.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
@ -392,7 +383,6 @@ class TestEmailCodeLoginApi:
|
||||
- InvalidEmailError is raised when email doesn't match token
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_data.return_value = {"email": "original@example.com", "code": "123456"}
|
||||
|
||||
# Act & Assert
|
||||
@ -415,7 +405,6 @@ class TestEmailCodeLoginApi:
|
||||
- EmailCodeError is raised for wrong verification code
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
|
||||
|
||||
# Act & Assert
|
||||
@ -453,7 +442,6 @@ class TestEmailCodeLoginApi:
|
||||
- User is added as owner of new workspace
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
|
||||
mock_get_user.return_value = mock_account
|
||||
mock_get_tenants.return_value = []
|
||||
@ -496,7 +484,6 @@ class TestEmailCodeLoginApi:
|
||||
- WorkspacesLimitExceeded is raised when limit reached
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
|
||||
mock_get_user.return_value = mock_account
|
||||
mock_get_tenants.return_value = []
|
||||
@ -538,7 +525,6 @@ class TestEmailCodeLoginApi:
|
||||
- NotAllowedCreateWorkspace is raised when creation disabled
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
|
||||
mock_get_user.return_value = mock_account
|
||||
mock_get_tenants.return_value = []
|
||||
|
||||
@ -110,7 +110,6 @@ class TestLoginApi:
|
||||
- Rate limit is reset after successful login
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_invitation.return_value = None
|
||||
mock_authenticate.return_value = mock_account
|
||||
@ -162,7 +161,6 @@ class TestLoginApi:
|
||||
- Authentication proceeds with invitation token
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_invitation.return_value = {"data": {"email": "test@example.com"}}
|
||||
mock_authenticate.return_value = mock_account
|
||||
@ -199,7 +197,6 @@ class TestLoginApi:
|
||||
- EmailPasswordLoginLimitError is raised when limit exceeded
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = True
|
||||
mock_get_invitation.return_value = None
|
||||
|
||||
@ -228,7 +225,6 @@ class TestLoginApi:
|
||||
- AccountInFreezeError is raised for frozen accounts
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_frozen.return_value = True
|
||||
|
||||
# Act & Assert
|
||||
@ -268,7 +264,6 @@ class TestLoginApi:
|
||||
- Generic error message prevents user enumeration
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_invitation.return_value = None
|
||||
mock_authenticate.side_effect = AccountPasswordError("Invalid password")
|
||||
@ -305,7 +300,6 @@ class TestLoginApi:
|
||||
- Login is prevented even with valid credentials
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_invitation.return_value = None
|
||||
mock_authenticate.side_effect = AccountLoginError("Account is banned")
|
||||
@ -351,7 +345,6 @@ class TestLoginApi:
|
||||
- User cannot login without an assigned workspace
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_invitation.return_value = None
|
||||
mock_authenticate.return_value = mock_account
|
||||
@ -383,7 +376,6 @@ class TestLoginApi:
|
||||
- Security check prevents invitation token abuse
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_invitation.return_value = {"data": {"email": "invited@example.com"}}
|
||||
|
||||
@ -425,7 +417,6 @@ class TestLoginApi:
|
||||
mock_token_pair,
|
||||
):
|
||||
"""Test that login retries with lowercase email when uppercase lookup fails."""
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_invitation.return_value = None
|
||||
mock_authenticate.side_effect = [AccountPasswordError("Invalid"), mock_account]
|
||||
@ -459,7 +450,6 @@ class TestLoginApi:
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_token_data.return_value = {"email": "User@Example.com", "code": "123456"}
|
||||
mock_get_account.side_effect = Unauthorized("Account is banned.")
|
||||
|
||||
@ -513,7 +503,6 @@ class TestLogoutApi:
|
||||
- Success response is returned
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_current_account.return_value = (mock_account, MagicMock())
|
||||
|
||||
# Act
|
||||
@ -539,7 +528,6 @@ class TestLogoutApi:
|
||||
- Success response is returned
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
# Create a mock anonymous user that will pass isinstance check
|
||||
anonymous_user = MagicMock()
|
||||
mock_flask_login.AnonymousUserMixin = type("AnonymousUserMixin", (), {})
|
||||
|
||||
@ -46,7 +46,6 @@ class TestPartnerTenants:
|
||||
patch("libs.login.dify_config.LOGIN_DISABLED", False),
|
||||
patch("libs.login.check_csrf_token") as mock_csrf,
|
||||
):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
|
||||
mock_csrf.return_value = None
|
||||
yield {"db": mock_db, "csrf": mock_csrf}
|
||||
|
||||
|
||||
@ -8,8 +8,10 @@ from werkzeug.exceptions import Forbidden
import controllers.console.tag.tags as module
from controllers.console import console_ns
from controllers.console.tag.tags import (
    TagBindingCreateApi,
    TagBindingDeleteApi,
    DeprecatedTagBindingCreateApi,
    DeprecatedTagBindingRemoveApi,
    TagBindingCollectionApi,
    TagBindingItemApi,
    TagListApi,
    TagUpdateDeleteApi,
)
@ -205,9 +207,9 @@ class TestTagUpdateDeleteApi:
        assert status == 204


class TestTagBindingCreateApi:
class TestTagBindingCollectionApi:
    def test_create_success(self, app, admin_user, payload_patch):
        api = TagBindingCreateApi()
        api = TagBindingCollectionApi()
        method = unwrap(api.post)

        payload = {
@ -232,7 +234,7 @@ class TestTagBindingCreateApi:
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_create_forbidden(self, app, readonly_user, payload_patch):
|
||||
api = TagBindingCreateApi()
|
||||
api = TagBindingCollectionApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with app.test_request_context("/", json={}):
|
||||
@ -247,9 +249,78 @@ class TestTagBindingCreateApi:
|
||||
method(api)
|
||||
|
||||
|
||||
class TestTagBindingDeleteApi:
|
||||
class TestDeprecatedTagBindingCreateApi:
|
||||
def test_create_success(self, app, admin_user, payload_patch):
|
||||
api = DeprecatedTagBindingCreateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"tag_ids": ["tag-1"],
|
||||
"target_id": "target-1",
|
||||
"type": "knowledge",
|
||||
}
|
||||
|
||||
with app.test_request_context("/", json=payload):
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.tag.tags.current_account_with_tenant",
|
||||
return_value=(admin_user, None),
|
||||
),
|
||||
payload_patch(payload),
|
||||
patch("controllers.console.tag.tags.TagService.save_tag_binding") as save_mock,
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
save_mock.assert_called_once()
|
||||
assert status == 200
|
||||
assert result["result"] == "success"
|
||||
|
||||
|
||||
class TestTagBindingItemApi:
|
||||
def test_delete_success(self, app, admin_user, payload_patch):
|
||||
api = TagBindingItemApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
payload = {
|
||||
"target_id": "target-1",
|
||||
"type": "knowledge",
|
||||
}
|
||||
|
||||
with app.test_request_context("/", json=payload):
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.tag.tags.current_account_with_tenant",
|
||||
return_value=(admin_user, None),
|
||||
),
|
||||
payload_patch(payload),
|
||||
patch("controllers.console.tag.tags.TagService.delete_tag_binding") as delete_mock,
|
||||
):
|
||||
result, status = method(api, "tag-1")
|
||||
|
||||
delete_mock.assert_called_once()
|
||||
delete_payload = delete_mock.call_args.args[0]
|
||||
assert delete_payload.tag_id == "tag-1"
|
||||
assert delete_payload.target_id == "target-1"
|
||||
assert delete_payload.type == TagType.KNOWLEDGE
|
||||
assert status == 200
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_delete_forbidden(self, app, readonly_user):
|
||||
api = TagBindingItemApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
with app.test_request_context("/"):
|
||||
with patch(
|
||||
"controllers.console.tag.tags.current_account_with_tenant",
|
||||
return_value=(readonly_user, None),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, "tag-1")
|
||||
|
||||
|
||||
class TestDeprecatedTagBindingRemoveApi:
|
||||
def test_remove_success(self, app, admin_user, payload_patch):
|
||||
api = TagBindingDeleteApi()
|
||||
api = DeprecatedTagBindingRemoveApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
@ -274,7 +345,7 @@ class TestTagBindingDeleteApi:
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_remove_forbidden(self, app, readonly_user, payload_patch):
|
||||
api = TagBindingDeleteApi()
|
||||
api = DeprecatedTagBindingRemoveApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with app.test_request_context("/", json={}):
|
||||
@ -297,3 +368,35 @@ class TestTagResponseModel:
|
||||
|
||||
assert payload["type"] == "knowledge"
|
||||
assert payload["binding_count"] == "1"
|
||||
|
||||
|
||||
class TestTagBindingRouteMetadata:
|
||||
def test_legacy_write_routes_are_marked_deprecated(self):
|
||||
assert DeprecatedTagBindingCreateApi.post.__apidoc__["deprecated"] is True
|
||||
assert DeprecatedTagBindingRemoveApi.post.__apidoc__["deprecated"] is True
|
||||
assert TagBindingCollectionApi.post.__apidoc__.get("deprecated") is not True
|
||||
assert TagBindingItemApi.delete.__apidoc__.get("deprecated") is not True
|
||||
|
||||
def test_write_routes_have_stable_operation_ids(self):
|
||||
assert TagBindingCollectionApi.post.__apidoc__["id"] == "create_tag_binding"
|
||||
assert TagBindingItemApi.delete.__apidoc__["id"] == "delete_tag_binding"
|
||||
assert DeprecatedTagBindingCreateApi.post.__apidoc__["id"] == "create_tag_binding_deprecated"
|
||||
assert DeprecatedTagBindingRemoveApi.post.__apidoc__["id"] == "delete_tag_binding_deprecated"
|
||||
|
||||
def test_canonical_and_legacy_write_routes_are_registered(self):
|
||||
route_map = {
|
||||
resource.__name__: urls
|
||||
for resource, urls, _route_doc, _kwargs in console_ns.resources
|
||||
if resource.__name__
|
||||
in {
|
||||
"TagBindingCollectionApi",
|
||||
"TagBindingItemApi",
|
||||
"DeprecatedTagBindingCreateApi",
|
||||
"DeprecatedTagBindingRemoveApi",
|
||||
}
|
||||
}
|
||||
|
||||
assert route_map["TagBindingCollectionApi"] == ("/tag-bindings",)
|
||||
assert route_map["TagBindingItemApi"] == ("/tag-bindings/<uuid:id>",)
|
||||
assert route_map["DeprecatedTagBindingCreateApi"] == ("/tag-bindings/create",)
|
||||
assert route_map["DeprecatedTagBindingRemoveApi"] == ("/tag-bindings/remove",)
|
||||
|
||||
@ -122,6 +122,35 @@ def test_post_form_invalid_recipient_type(app, monkeypatch: pytest.MonkeyPatch)
|
||||
handler(api, form_token="token")
|
||||
|
||||
|
||||
def test_post_form_rejects_webapp_recipient_type(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
form = SimpleNamespace(tenant_id="tenant-1", recipient_type=RecipientType.STANDALONE_WEB_APP)
|
||||
|
||||
class _ServiceStub:
|
||||
def __init__(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def get_form_by_token(self, _token):
|
||||
return form
|
||||
|
||||
monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub)
|
||||
monkeypatch.setattr(
|
||||
"controllers.console.human_input_form.current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(id="user-1"), "tenant-1"),
|
||||
)
|
||||
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = ConsoleHumanInputFormApi()
|
||||
handler = _unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/form/human_input/token",
|
||||
method="POST",
|
||||
json={"inputs": {"content": "ok"}, "action": "approve"},
|
||||
):
|
||||
with pytest.raises(NotFoundError):
|
||||
handler(api, form_token="token")
|
||||
|
||||
|
||||
def test_post_form_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
submit_mock = Mock()
|
||||
form = SimpleNamespace(tenant_id="tenant-1", recipient_type=RecipientType.CONSOLE)
|
||||
|
||||
@ -24,10 +24,6 @@ def app():
|
||||
return app
|
||||
|
||||
|
||||
def _mock_wraps_db(mock_db):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
|
||||
|
||||
def _build_account(email: str, account_id: str = "acc", tenant: object | None = None) -> Account:
|
||||
tenant_obj = tenant if tenant is not None else SimpleNamespace(id="tenant-id")
|
||||
account = Account(name=account_id, email=email)
|
||||
@ -64,7 +60,6 @@ class TestChangeEmailSend:
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
mock_account = _build_account("current@example.com", "acc1")
|
||||
mock_current_account.return_value = (mock_account, None)
|
||||
@ -117,7 +112,6 @@ class TestChangeEmailSend:
|
||||
"""GHSA-4q3w-q5mc-45rq: a phase-1 token must not unlock the new-email send step."""
|
||||
from controllers.console.auth.error import InvalidTokenError
|
||||
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
mock_account = _build_account("current@example.com", "acc1")
|
||||
mock_current_account.return_value = (mock_account, None)
|
||||
@ -163,7 +157,6 @@ class TestChangeEmailValidity:
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
mock_account = _build_account("user@example.com", "acc2")
|
||||
mock_current_account.return_value = (mock_account, None)
|
||||
@ -223,7 +216,6 @@ class TestChangeEmailValidity:
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
mock_current_account.return_value = (_build_account("old@example.com", "acc"), None)
|
||||
mock_is_rate_limit.return_value = False
|
||||
@ -280,7 +272,6 @@ class TestChangeEmailValidity:
|
||||
"""A token whose phase marker is a string but not a known transition must be rejected."""
|
||||
from controllers.console.auth.error import InvalidTokenError
|
||||
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
mock_current_account.return_value = (_build_account("old@example.com", "acc"), None)
|
||||
mock_is_rate_limit.return_value = False
|
||||
@ -330,7 +321,6 @@ class TestChangeEmailValidity:
|
||||
"""A token minted without a phase marker (e.g. a hand-crafted token) must not validate."""
|
||||
from controllers.console.auth.error import InvalidTokenError
|
||||
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
mock_current_account.return_value = (_build_account("old@example.com", "acc"), None)
|
||||
mock_is_rate_limit.return_value = False
|
||||
@ -378,7 +368,6 @@ class TestChangeEmailReset:
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
current_user = _build_account("old@example.com", "acc3")
|
||||
mock_current_account.return_value = (current_user, None)
|
||||
@ -434,7 +423,6 @@ class TestChangeEmailReset:
|
||||
"""GHSA-4q3w-q5mc-45rq PoC: phase-1 token must not be usable against /reset."""
|
||||
from controllers.console.auth.error import InvalidTokenError
|
||||
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
current_user = _build_account("old@example.com", "acc3")
|
||||
mock_current_account.return_value = (current_user, None)
|
||||
@ -488,7 +476,6 @@ class TestChangeEmailReset:
|
||||
"""A verified token for address A must not be replayed to change to address B."""
|
||||
from controllers.console.auth.error import InvalidTokenError
|
||||
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
current_user = _build_account("old@example.com", "acc3")
|
||||
mock_current_account.return_value = (current_user, None)
|
||||
@ -561,7 +548,6 @@ class TestAccountDeletionFeedback:
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.BillingService.update_account_deletion_feedback")
|
||||
def test_should_normalize_feedback_email(self, mock_update, mock_db, app):
|
||||
_mock_wraps_db(mock_db)
|
||||
with app.test_request_context(
|
||||
"/account/delete/feedback",
|
||||
method="POST",
|
||||
@ -578,7 +564,6 @@ class TestCheckEmailUnique:
|
||||
@patch("controllers.console.workspace.account.AccountService.check_email_unique")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
|
||||
def test_should_normalize_email(self, mock_is_freeze, mock_check_unique, mock_db, app):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_is_freeze.return_value = False
|
||||
mock_check_unique.return_value = True
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask, g
|
||||
@ -16,10 +16,6 @@ def app():
|
||||
return flask_app
|
||||
|
||||
|
||||
def _mock_wraps_db(mock_db):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
|
||||
|
||||
def _build_feature_flags():
|
||||
placeholder_quota = SimpleNamespace(limit=0, size=0)
|
||||
workspace_members = SimpleNamespace(is_available=lambda count: True)
|
||||
@ -49,7 +45,6 @@ class TestMemberInviteEmailApi:
|
||||
mock_get_features,
|
||||
app,
|
||||
):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_get_features.return_value = _build_feature_flags()
|
||||
mock_invite_member.return_value = "token-abc"
|
||||
|
||||
|
||||
@ -310,7 +310,6 @@ class TestSystemSetup:
|
||||
def test_should_allow_when_setup_complete(self, mock_db):
|
||||
"""Test that requests are allowed when setup is complete"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock() # Setup exists
|
||||
|
||||
@setup_required
|
||||
def admin_view():
|
||||
|
||||
@ -22,7 +22,7 @@ _WRAPS_MODULE: ModuleType | None = None

@contextmanager
def _mock_db():
    mock_session = SimpleNamespace(query=lambda *args, **kwargs: SimpleNamespace(first=lambda: True))
    mock_session = SimpleNamespace(scalar=lambda *args, **kwargs: True)
    with patch("extensions.ext_database.db.session", mock_session):
        yield


@ -12,7 +12,7 @@ from controllers.service_api.app.app import AppInfoApi, AppMetaApi, AppParameter
from controllers.service_api.app.error import AppUnavailableError
from models.account import TenantStatus
from models.model import App, AppMode
from tests.unit_tests.conftest import setup_mock_tenant_account_query
from tests.unit_tests.conftest import setup_mock_tenant_owner_execute_result


class TestAppParameterApi:
@ -74,7 +74,7 @@ class TestAppParameterApi:
|
||||
# Mock tenant owner info for login
|
||||
mock_account = Mock()
|
||||
mock_account.current_tenant = mock_tenant
|
||||
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
|
||||
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
@ -120,7 +120,7 @@ class TestAppParameterApi:
|
||||
|
||||
mock_account = Mock()
|
||||
mock_account.current_tenant = mock_tenant
|
||||
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
|
||||
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
@ -161,7 +161,7 @@ class TestAppParameterApi:
|
||||
|
||||
mock_account = Mock()
|
||||
mock_account.current_tenant = mock_tenant
|
||||
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
|
||||
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
@ -200,7 +200,7 @@ class TestAppParameterApi:
|
||||
|
||||
mock_account = Mock()
|
||||
mock_account.current_tenant = mock_tenant
|
||||
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
|
||||
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
@ -263,7 +263,7 @@ class TestAppMetaApi:
|
||||
|
||||
mock_account = Mock()
|
||||
mock_account.current_tenant = mock_tenant
|
||||
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
|
||||
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/meta", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
@ -331,7 +331,7 @@ class TestAppInfoApi:
|
||||
|
||||
mock_account = Mock()
|
||||
mock_account.current_tenant = mock_tenant
|
||||
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
|
||||
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
@ -388,7 +388,7 @@ class TestAppInfoApi:
|
||||
|
||||
mock_account = Mock()
|
||||
mock_account.current_tenant = mock_tenant
|
||||
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
|
||||
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
@ -434,7 +434,7 @@ class TestAppInfoApi:
|
||||
|
||||
mock_account = Mock()
|
||||
mock_account.current_tenant = mock_tenant
|
||||
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
|
||||
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
@ -486,7 +486,7 @@ class TestAppInfoApi:
|
||||
|
||||
mock_account = Mock()
|
||||
mock_account.current_tenant = mock_tenant
|
||||
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
|
||||
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
|
||||
@ -0,0 +1,707 @@
|
||||
"""Dedicated tests for HITL behavior exposed through the Service API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import ANY, MagicMock, Mock
|
||||
|
||||
import pytest
|
||||
|
||||
import services.app_generate_service as ags_module
|
||||
from controllers.service_api.app.workflow_events import WorkflowEventsApi
|
||||
from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig
|
||||
from core.app.apps.common import workflow_response_converter
|
||||
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.entities.queue_entities import QueueWorkflowPausedEvent
|
||||
from core.app.entities.task_entities import (
|
||||
AdvancedChatPausedBlockingResponse,
|
||||
HumanInputRequiredResponse,
|
||||
WorkflowAppPausedBlockingResponse,
|
||||
WorkflowPauseStreamResponse,
|
||||
)
|
||||
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext, _WorkflowGenerateEntityWrapper
|
||||
from core.workflow.human_input_policy import HumanInputSurface
|
||||
from core.workflow.system_variables import build_system_variables
|
||||
from graphon.entities import WorkflowStartReason
|
||||
from graphon.entities.pause_reason import HumanInputRequired, PauseReasonType
|
||||
from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus
|
||||
from graphon.nodes.human_input.entities import FormInput, UserAction
|
||||
from graphon.nodes.human_input.enums import FormInputType
|
||||
from graphon.runtime import GraphRuntimeState, VariablePool
|
||||
from models.account import Account
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import AppMode
|
||||
from models.workflow import WorkflowRun
|
||||
from repositories.api_workflow_node_execution_repository import WorkflowNodeExecutionSnapshot
|
||||
from repositories.entities.workflow_pause import WorkflowPauseEntity
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.workflow_event_snapshot_service import _build_snapshot_events
|
||||
from tests.unit_tests.controllers.service_api.conftest import _unwrap
|
||||
|
||||
|
||||
class _DummyRateLimit:
|
||||
@staticmethod
|
||||
def gen_request_key() -> str:
|
||||
return "dummy-request-id"
|
||||
|
||||
def __init__(self, client_id: str, max_active_requests: int) -> None:
|
||||
self.client_id = client_id
|
||||
self.max_active_requests = max_active_requests
|
||||
|
||||
def enter(self, request_id: str | None = None) -> str:
|
||||
return request_id or "dummy-request-id"
|
||||
|
||||
def exit(self, request_id: str) -> None:
|
||||
return None
|
||||
|
||||
def generate(self, generator, request_id: str):
|
||||
return generator
|
||||
|
||||
|
||||
def _mock_repo_for_run(monkeypatch: pytest.MonkeyPatch, workflow_run):
|
||||
workflow_events_module = sys.modules["controllers.service_api.app.workflow_events"]
|
||||
repo = SimpleNamespace(get_workflow_run_by_id_and_tenant_id=lambda **_kwargs: workflow_run)
|
||||
monkeypatch.setattr(
|
||||
workflow_events_module.DifyAPIRepositoryFactory,
|
||||
"create_api_workflow_run_repository",
|
||||
lambda *_args, **_kwargs: repo,
|
||||
)
|
||||
monkeypatch.setattr(workflow_events_module, "db", SimpleNamespace(engine=object()))
|
||||
return workflow_events_module
|
||||
|
||||
|
||||
def _build_service_api_pause_converter() -> WorkflowResponseConverter:
|
||||
application_generate_entity = SimpleNamespace(
|
||||
inputs={},
|
||||
files=[],
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
app_config=SimpleNamespace(app_id="app-id", tenant_id="tenant-id"),
|
||||
)
|
||||
system_variables = build_system_variables(
|
||||
user_id="user",
|
||||
app_id="app-id",
|
||||
workflow_id="workflow-id",
|
||||
workflow_execution_id="run-id",
|
||||
)
|
||||
user = MagicMock(spec=Account)
|
||||
user.id = "account-id"
|
||||
user.name = "Tester"
|
||||
user.email = "tester@example.com"
|
||||
return WorkflowResponseConverter(
|
||||
application_generate_entity=application_generate_entity,
|
||||
user=user,
|
||||
system_variables=system_variables,
|
||||
)
|
||||
|
||||
|
||||
def _build_advanced_chat_paused_blocking_response() -> AdvancedChatPausedBlockingResponse:
|
||||
data = AdvancedChatPausedBlockingResponse.Data(
|
||||
id="msg-1",
|
||||
mode="chat",
|
||||
conversation_id="c1",
|
||||
message_id="m1",
|
||||
workflow_run_id="run-1",
|
||||
answer="partial",
|
||||
metadata={"usage": {"total_tokens": 1}},
|
||||
created_at=1,
|
||||
paused_nodes=["node-1"],
|
||||
reasons=[
|
||||
{
|
||||
"type": PauseReasonType.HUMAN_INPUT_REQUIRED,
|
||||
"form_id": "form-1",
|
||||
"expiration_time": 100,
|
||||
}
|
||||
],
|
||||
status=WorkflowExecutionStatus.PAUSED,
|
||||
elapsed_time=0.1,
|
||||
total_tokens=0,
|
||||
total_steps=0,
|
||||
)
|
||||
return AdvancedChatPausedBlockingResponse(task_id="t1", data=data)
|
||||
|
||||
|
||||
def _build_workflow_paused_blocking_response() -> WorkflowAppPausedBlockingResponse:
|
||||
return WorkflowAppPausedBlockingResponse(
|
||||
task_id="t1",
|
||||
workflow_run_id="r1",
|
||||
data=WorkflowAppPausedBlockingResponse.Data(
|
||||
id="r1",
|
||||
workflow_id="wf-1",
|
||||
status=WorkflowExecutionStatus.PAUSED,
|
||||
outputs={},
|
||||
error=None,
|
||||
elapsed_time=0.5,
|
||||
total_tokens=0,
|
||||
total_steps=2,
|
||||
created_at=1,
|
||||
finished_at=None,
|
||||
paused_nodes=["node-1"],
|
||||
reasons=[{"TYPE": "human_input_required", "form_id": "form-1", "expiration_time": 100}],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _FakePauseEntity(WorkflowPauseEntity):
|
||||
pause_id: str
|
||||
workflow_run_id: str
|
||||
paused_at_value: datetime
|
||||
pause_reasons: Sequence[HumanInputRequired]
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return self.pause_id
|
||||
|
||||
@property
|
||||
def workflow_execution_id(self) -> str:
|
||||
return self.workflow_run_id
|
||||
|
||||
def get_state(self) -> bytes:
|
||||
raise AssertionError("state is not required for snapshot tests")
|
||||
|
||||
@property
|
||||
def resumed_at(self) -> datetime | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def paused_at(self) -> datetime:
|
||||
return self.paused_at_value
|
||||
|
||||
def get_pause_reasons(self) -> Sequence[HumanInputRequired]:
|
||||
return self.pause_reasons
|
||||
|
||||
|
||||
def _build_workflow_run(status: WorkflowExecutionStatus) -> WorkflowRun:
|
||||
return WorkflowRun(
|
||||
id="run-1",
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
workflow_id="workflow-1",
|
||||
type="workflow",
|
||||
triggered_from="app-run",
|
||||
version="v1",
|
||||
graph=None,
|
||||
inputs=json.dumps({"input": "value"}),
|
||||
status=status,
|
||||
outputs=json.dumps({}),
|
||||
error=None,
|
||||
elapsed_time=0.0,
|
||||
total_tokens=0,
|
||||
total_steps=0,
|
||||
created_by_role=CreatorUserRole.END_USER,
|
||||
created_by="user-1",
|
||||
created_at=datetime(2024, 1, 1, tzinfo=UTC),
|
||||
)
|
||||
|
||||
|
||||
def _build_snapshot(status: WorkflowNodeExecutionStatus) -> WorkflowNodeExecutionSnapshot:
|
||||
created_at = datetime(2024, 1, 1, tzinfo=UTC)
|
||||
finished_at = datetime(2024, 1, 1, 0, 0, 5, tzinfo=UTC)
|
||||
return WorkflowNodeExecutionSnapshot(
|
||||
execution_id="exec-1",
|
||||
node_id="node-1",
|
||||
node_type="human-input",
|
||||
title="Human Input",
|
||||
index=1,
|
||||
status=status.value,
|
||||
elapsed_time=0.5,
|
||||
created_at=created_at,
|
||||
finished_at=finished_at,
|
||||
iteration_id=None,
|
||||
loop_id=None,
|
||||
)
|
||||
|
||||
|
||||
def _build_resumption_context(task_id: str) -> WorkflowResumptionContext:
|
||||
app_config = WorkflowUIBasedAppConfig(
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
app_mode=AppMode.WORKFLOW,
|
||||
workflow_id="workflow-1",
|
||||
)
|
||||
generate_entity = WorkflowAppGenerateEntity(
|
||||
task_id=task_id,
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
files=[],
|
||||
user_id="user-1",
|
||||
stream=True,
|
||||
invoke_from=InvokeFrom.EXPLORE,
|
||||
call_depth=0,
|
||||
workflow_execution_id="run-1",
|
||||
)
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0)
|
||||
runtime_state.register_paused_node("node-1")
|
||||
runtime_state.outputs = {"result": "value"}
|
||||
wrapper = _WorkflowGenerateEntityWrapper(entity=generate_entity)
|
||||
return WorkflowResumptionContext(
|
||||
generate_entity=wrapper,
|
||||
serialized_graph_runtime_state=runtime_state.dumps(),
|
||||
)
|
||||
|
||||
|
||||
class TestHitlServiceApi:
|
||||
# Service API event-stream continuation
|
||||
def test_workflow_events_continue_on_pause_keeps_stream_open(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
workflow_run = SimpleNamespace(
|
||||
id="run-1",
|
||||
app_id="app-1",
|
||||
created_by_role=CreatorUserRole.END_USER,
|
||||
created_by="end-user-1",
|
||||
finished_at=None,
|
||||
)
|
||||
workflow_events_module = _mock_repo_for_run(monkeypatch, workflow_run=workflow_run)
|
||||
msg_generator = Mock()
|
||||
msg_generator.retrieve_events.return_value = ["raw-event"]
|
||||
workflow_generator = Mock()
|
||||
workflow_generator.convert_to_event_stream.return_value = iter(["data: streamed\n\n"])
|
||||
monkeypatch.setattr(workflow_events_module, "MessageGenerator", lambda: msg_generator)
|
||||
monkeypatch.setattr(workflow_events_module, "WorkflowAppGenerator", lambda: workflow_generator)
|
||||
|
||||
api = WorkflowEventsApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
|
||||
end_user = SimpleNamespace(id="end-user-1")
|
||||
|
||||
with app.test_request_context("/workflow/run-1/events?user=u1&continue_on_pause=true", method="GET"):
|
||||
response = handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
|
||||
|
||||
assert response.get_data(as_text=True) == "data: streamed\n\n"
|
||||
msg_generator.retrieve_events.assert_called_once_with(
|
||||
AppMode.WORKFLOW,
|
||||
"run-1",
|
||||
terminal_events=[],
|
||||
)
|
||||
workflow_generator.convert_to_event_stream.assert_called_once_with(["raw-event"])
|
||||
|
||||
def test_workflow_events_snapshot_continue_on_pause_keeps_pause_open(
|
||||
self, app, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
workflow_run = SimpleNamespace(
|
||||
id="run-1",
|
||||
app_id="app-1",
|
||||
created_by_role=CreatorUserRole.END_USER,
|
||||
created_by="end-user-1",
|
||||
finished_at=None,
|
||||
)
|
||||
workflow_events_module = _mock_repo_for_run(monkeypatch, workflow_run=workflow_run)
|
||||
msg_generator = Mock()
|
||||
workflow_generator = Mock()
|
||||
workflow_generator.convert_to_event_stream.return_value = iter(["data: snapshot\n\n"])
|
||||
snapshot_builder = Mock(return_value=["snapshot-events"])
|
||||
monkeypatch.setattr(workflow_events_module, "MessageGenerator", lambda: msg_generator)
|
||||
monkeypatch.setattr(workflow_events_module, "WorkflowAppGenerator", lambda: workflow_generator)
|
||||
monkeypatch.setattr(workflow_events_module, "build_workflow_event_stream", snapshot_builder)
|
||||
|
||||
api = WorkflowEventsApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
|
||||
end_user = SimpleNamespace(id="end-user-1")
|
||||
|
||||
with app.test_request_context(
|
||||
"/workflow/run-1/events?user=u1&include_state_snapshot=true&continue_on_pause=true",
|
||||
method="GET",
|
||||
):
|
||||
response = handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
|
||||
|
||||
assert response.get_data(as_text=True) == "data: snapshot\n\n"
|
||||
msg_generator.retrieve_events.assert_not_called()
|
||||
snapshot_builder.assert_called_once_with(
|
||||
app_mode=AppMode.WORKFLOW,
|
||||
workflow_run=workflow_run,
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
session_maker=ANY,
|
||||
human_input_surface=HumanInputSurface.SERVICE_API,
|
||||
close_on_pause=False,
|
||||
)
|
||||
workflow_generator.convert_to_event_stream.assert_called_once_with(["snapshot-events"])
|
||||
|
||||
def test_advanced_chat_blocking_injects_pause_state_config(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", False)
|
||||
monkeypatch.setattr(ags_module, "RateLimit", _DummyRateLimit)
|
||||
|
||||
workflow = MagicMock()
|
||||
workflow.created_by = "owner-id"
|
||||
monkeypatch.setattr(AppGenerateService, "_get_workflow", lambda *args, **kwargs: workflow)
|
||||
monkeypatch.setattr(ags_module.session_factory, "get_session_maker", lambda: "session-maker")
|
||||
|
||||
generator_instance = MagicMock()
|
||||
generator_instance.generate.return_value = {"result": "advanced-blocking"}
|
||||
generator_instance.convert_to_event_stream.side_effect = lambda payload: payload
|
||||
monkeypatch.setattr(ags_module, "AdvancedChatAppGenerator", lambda: generator_instance)
|
||||
|
||||
app_model = MagicMock()
|
||||
app_model.mode = AppMode.ADVANCED_CHAT
|
||||
app_model.id = "app-id"
|
||||
app_model.tenant_id = "tenant-id"
|
||||
app_model.max_active_requests = 0
|
||||
app_model.is_agent = False
|
||||
|
        user = MagicMock()
        user.id = "user-id"

        result = AppGenerateService.generate(
            app_model=app_model,
            user=user,
            args={"workflow_id": None, "query": "hi", "inputs": {}},
            invoke_from=InvokeFrom.SERVICE_API,
            streaming=False,
        )

        assert result == {"result": "advanced-blocking"}
        call_kwargs = generator_instance.generate.call_args.kwargs
        assert call_kwargs["streaming"] is False
        assert call_kwargs["pause_state_config"] is not None
        assert call_kwargs["pause_state_config"].session_factory == "session-maker"
        assert call_kwargs["pause_state_config"].state_owner_user_id == "owner-id"

    # Blocking payload contract
    def test_advanced_chat_blocking_pause_payload_contract(self) -> None:
        from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter

        response = AdvancedChatAppGenerateResponseConverter.convert_blocking_full_response(
            _build_advanced_chat_paused_blocking_response()
        )

        assert response["event"] == "workflow_paused"
        assert response["workflow_run_id"] == "run-1"
        assert response["answer"] == "partial"
        assert response["data"]["reasons"][0]["type"] == PauseReasonType.HUMAN_INPUT_REQUIRED
        assert response["data"]["reasons"][0]["expiration_time"] == 100
        assert "human_input_forms" not in response["data"]

    def test_workflow_blocking_pause_payload_contract(self) -> None:
        from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter

        response = WorkflowAppGenerateResponseConverter.convert_blocking_full_response(
            _build_workflow_paused_blocking_response()
        )

        assert response["workflow_run_id"] == "r1"
        assert response["data"]["status"] == WorkflowExecutionStatus.PAUSED
        assert response["data"]["paused_nodes"] == ["node-1"]
        assert response["data"]["reasons"] == [
            {"TYPE": "human_input_required", "form_id": "form-1", "expiration_time": 100}
        ]
        assert "human_input_forms" not in response["data"]

    def test_advanced_chat_blocking_pipeline_pause_payload_contract(self) -> None:
        from core.app.app_config.entities import AppAdditionalFeatures
        from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline
        from models.enums import MessageStatus
        from models.model import EndUser

        app_config = WorkflowUIBasedAppConfig(
            tenant_id="tenant",
            app_id="app",
            app_mode=AppMode.ADVANCED_CHAT,
            additional_features=AppAdditionalFeatures(),
            variables=[],
            workflow_id="workflow-id",
        )
        application_generate_entity = AdvancedChatAppGenerateEntity.model_construct(
            task_id="task",
            app_config=app_config,
            inputs={},
            query="hello",
            files=[],
            user_id="user",
            stream=False,
            invoke_from=InvokeFrom.WEB_APP,
            extras={},
            trace_manager=None,
            workflow_run_id="run-id",
        )
        pipeline = AdvancedChatAppGenerateTaskPipeline(
            application_generate_entity=application_generate_entity,
            workflow=SimpleNamespace(id="workflow-id", tenant_id="tenant", features_dict={}),
            queue_manager=SimpleNamespace(invoke_from=InvokeFrom.WEB_APP, graph_runtime_state=None),
            conversation=SimpleNamespace(id="conv-id", mode=AppMode.ADVANCED_CHAT),
            message=SimpleNamespace(
                id="message-id",
                query="hello",
                created_at=datetime.utcnow(),
                status=MessageStatus.NORMAL,
                answer="",
            ),
            user=EndUser(tenant_id="tenant", type="session", name="tester", session_id="session"),
            stream=False,
            dialogue_count=1,
            draft_var_saver_factory=lambda **kwargs: None,
        )
        pipeline._task_state.answer = "partial answer"
        pipeline._workflow_run_id = "run-id"

        def _gen():
            yield HumanInputRequiredResponse(
                task_id="task",
                workflow_run_id="run-id",
                data=HumanInputRequiredResponse.Data(
                    form_id="form-1",
                    node_id="node-1",
                    node_title="Approval",
                    form_content="Need approval",
                    inputs=[],
                    actions=[UserAction(id="approve", title="Approve")],
                    display_in_ui=True,
                    form_token="token-1",
                    resolved_default_values={},
                    expiration_time=123,
                ),
            )
            yield WorkflowPauseStreamResponse(
                task_id="task",
                workflow_run_id="run-id",
                data=WorkflowPauseStreamResponse.Data(
                    workflow_run_id="run-id",
                    paused_nodes=["node-1"],
                    outputs={},
                    reasons=[
                        {
                            "type": PauseReasonType.HUMAN_INPUT_REQUIRED,
                            "form_id": "form-1",
                            "node_id": "node-1",
                            "expiration_time": 123,
                        },
                    ],
                    status="paused",
                    created_at=1,
                    elapsed_time=0.1,
                    total_tokens=0,
                    total_steps=0,
                ),
            )

        response = pipeline._to_blocking_response(_gen())

        assert isinstance(response, AdvancedChatPausedBlockingResponse)
        assert response.data.answer == "partial answer"
        assert response.data.workflow_run_id == "run-id"
        assert response.data.reasons[0]["form_id"] == "form-1"
        assert response.data.reasons[0]["expiration_time"] == 123

    def test_workflow_blocking_pipeline_pause_payload_contract(self, monkeypatch: pytest.MonkeyPatch) -> None:
        from core.app.apps.workflow import generate_task_pipeline as workflow_pipeline_module
        from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline

        app_config = WorkflowUIBasedAppConfig(
            tenant_id="tenant",
            app_id="app",
            app_mode=AppMode.WORKFLOW,
            additional_features=AppAdditionalFeatures(),
            variables=[],
            workflow_id="workflow-id",
        )
        application_generate_entity = WorkflowAppGenerateEntity.model_construct(
            task_id="task",
            app_config=app_config,
            inputs={},
            files=[],
            user_id="user",
            stream=False,
            invoke_from=InvokeFrom.WEB_APP,
            trace_manager=None,
            workflow_execution_id="run-id",
            extras={},
            call_depth=0,
        )
        pipeline = WorkflowAppGenerateTaskPipeline(
            application_generate_entity=application_generate_entity,
            workflow=SimpleNamespace(id="workflow-id", tenant_id="tenant", features_dict={}),
            queue_manager=SimpleNamespace(invoke_from=InvokeFrom.WEB_APP, graph_runtime_state=None),
            user=SimpleNamespace(id="user", session_id="session"),
            stream=False,
            draft_var_saver_factory=lambda **kwargs: None,
        )
        monkeypatch.setattr(workflow_pipeline_module.time, "time", lambda: 1700000000)

        def _gen():
            yield HumanInputRequiredResponse(
                task_id="task",
                workflow_run_id="run",
                data=HumanInputRequiredResponse.Data(
                    form_id="form-1",
                    node_id="node-1",
                    node_title="Human Input",
                    form_content="content",
                    expiration_time=1,
                ),
            )
            yield WorkflowPauseStreamResponse(
                task_id="task",
                workflow_run_id="run",
                data=WorkflowPauseStreamResponse.Data(
                    workflow_run_id="run",
                    status=WorkflowExecutionStatus.PAUSED,
                    outputs={},
                    paused_nodes=["node-1"],
                    reasons=[{"TYPE": "human_input_required", "form_id": "form-1", "expiration_time": 1}],
                    created_at=1,
                    elapsed_time=0.1,
                    total_tokens=0,
                    total_steps=0,
                ),
            )

        response = pipeline._to_blocking_response(_gen())

        assert isinstance(response, WorkflowAppPausedBlockingResponse)
        assert response.data.status == WorkflowExecutionStatus.PAUSED
        assert response.data.paused_nodes == ["node-1"]
        assert response.data.reasons == [{"TYPE": "human_input_required", "form_id": "form-1", "expiration_time": 1}]

    def test_service_api_pause_event_serializes_hitl_reason(self, monkeypatch: pytest.MonkeyPatch) -> None:
        converter = _build_service_api_pause_converter()
        converter.workflow_start_to_stream_response(
            task_id="task",
            workflow_run_id="run-id",
            workflow_id="workflow-id",
            reason=WorkflowStartReason.INITIAL,
        )

        expiration_time = datetime(2024, 1, 1, tzinfo=UTC)

        class _FakeSession:
            def execute(self, _stmt):
                return [("form-1", expiration_time, '{"display_in_ui": true}')]

            def __enter__(self):
                return self

            def __exit__(self, exc_type, exc, tb):
                return False

        monkeypatch.setattr(workflow_response_converter, "Session", lambda **_: _FakeSession())
        monkeypatch.setattr(workflow_response_converter, "db", SimpleNamespace(engine=object()))
        monkeypatch.setattr(
            workflow_response_converter,
            "load_form_tokens_by_form_id",
            lambda form_ids, session=None, surface=None: {"form-1": "token"},
        )

        reason = HumanInputRequired(
            form_id="form-1",
            form_content="Rendered",
            inputs=[
                FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="field", default=None),
            ],
            actions=[UserAction(id="approve", title="Approve")],
            display_in_ui=True,
            node_id="node-id",
            node_title="Human Step",
            form_token="token",
        )
        queue_event = QueueWorkflowPausedEvent(
            reasons=[reason],
            outputs={"answer": "value"},
            paused_nodes=["node-id"],
        )

        runtime_state = SimpleNamespace(total_tokens=0, node_run_steps=0)
        responses = converter.workflow_pause_to_stream_response(
            event=queue_event,
            task_id="task",
            graph_runtime_state=runtime_state,
        )

        assert isinstance(responses[-1], WorkflowPauseStreamResponse)
        pause_resp = responses[-1]
        assert pause_resp.workflow_run_id == "run-id"
        assert pause_resp.data.paused_nodes == ["node-id"]
        assert pause_resp.data.outputs == {}
        assert pause_resp.data.reasons[0]["TYPE"] == "human_input_required"
        assert pause_resp.data.reasons[0]["form_id"] == "form-1"
        assert pause_resp.data.reasons[0]["form_token"] == "token"
        assert pause_resp.data.reasons[0]["expiration_time"] == int(expiration_time.timestamp())

        assert isinstance(responses[0], HumanInputRequiredResponse)
        hi_resp = responses[0]
        assert hi_resp.data.form_id == "form-1"
        assert hi_resp.data.node_id == "node-id"
        assert hi_resp.data.node_title == "Human Step"
        assert hi_resp.data.inputs[0].output_variable_name == "field"
        assert hi_resp.data.actions[0].id == "approve"
        assert hi_resp.data.display_in_ui is True
        assert hi_resp.data.form_token == "token"
        assert hi_resp.data.expiration_time == int(expiration_time.timestamp())

    # Snapshot payload contract
    def test_snapshot_events_include_pause_payload_contract(self, monkeypatch: pytest.MonkeyPatch) -> None:
        workflow_run = _build_workflow_run(WorkflowExecutionStatus.PAUSED)
        snapshot = _build_snapshot(WorkflowNodeExecutionStatus.PAUSED)
        resumption_context = _build_resumption_context("task-ctx")
        monkeypatch.setattr(
            "services.workflow_event_snapshot_service.load_form_tokens_by_form_id",
            lambda form_ids, session=None, surface=None: {"form-1": "wtok"},
        )

        class _SessionContext:
            def __init__(self, session):
                self._session = session

            def __enter__(self):
                return self._session

            def __exit__(self, exc_type, exc, tb):
                return False

        def session_maker() -> _SessionContext:
            return _SessionContext(
                SimpleNamespace(
                    execute=lambda _stmt: [("form-1", datetime(2024, 1, 1, tzinfo=UTC), '{"display_in_ui": true}')],
                )
            )

        pause_entity = _FakePauseEntity(
            pause_id="pause-1",
            workflow_run_id="run-1",
            paused_at_value=datetime(2024, 1, 1, tzinfo=UTC),
            pause_reasons=[
                HumanInputRequired(
                    form_id="form-1",
                    form_content="content",
                    node_id="node-1",
                    node_title="Human Input",
                    form_token="wtok",
                )
            ],
        )

        events = _build_snapshot_events(
            workflow_run=workflow_run,
            node_snapshots=[snapshot],
            task_id="task-ctx",
            message_context=None,
            pause_entity=pause_entity,
            resumption_context=resumption_context,
            session_maker=session_maker,
        )

        assert [event["event"] for event in events] == [
            "workflow_started",
            "node_started",
            "node_finished",
            "human_input_required",
            "workflow_paused",
        ]
        assert events[2]["data"]["status"] == WorkflowNodeExecutionStatus.PAUSED.value
        assert events[3]["data"]["form_token"] == "wtok"
        assert events[3]["data"]["expiration_time"] == int(datetime(2024, 1, 1, tzinfo=UTC).timestamp())
        pause_data = events[-1]["data"]
        assert pause_data["paused_nodes"] == ["node-1"]
        assert pause_data["outputs"] == {"result": "value"}
        assert pause_data["reasons"][0]["TYPE"] == "human_input_required"
        assert pause_data["reasons"][0]["form_token"] == "wtok"
        assert pause_data["reasons"][0]["expiration_time"] == int(datetime(2024, 1, 1, tzinfo=UTC).timestamp())
        assert pause_data["status"] == WorkflowExecutionStatus.PAUSED.value
        assert pause_data["created_at"] == int(workflow_run.created_at.timestamp())
        assert pause_data["elapsed_time"] == workflow_run.elapsed_time
        assert pause_data["total_tokens"] == workflow_run.total_tokens
        assert pause_data["total_steps"] == workflow_run.total_steps
|
||||
@ -0,0 +1,184 @@
|
||||
"""Unit tests for Service API human input form endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.service_api.app.human_input_form import WorkflowHumanInputFormApi
|
||||
from models.human_input import RecipientType
|
||||
from tests.unit_tests.controllers.service_api.conftest import _unwrap
|
||||
|
||||
|
||||
class TestWorkflowHumanInputFormApi:
|
||||
def test_get_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
definition = SimpleNamespace(
|
||||
model_dump=lambda: {
|
||||
"rendered_content": "Rendered form content",
|
||||
"inputs": [{"output_variable_name": "name"}],
|
||||
"default_values": {"name": "Alice", "age": 30, "meta": {"k": "v"}},
|
||||
"user_actions": [{"id": "approve", "title": "Approve"}],
|
||||
}
|
||||
)
|
||||
form = SimpleNamespace(
|
||||
app_id="app-1",
|
||||
tenant_id="tenant-1",
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
expiration_time=datetime(2099, 1, 1, tzinfo=UTC),
|
||||
get_definition=lambda: definition,
|
||||
)
|
||||
service_mock = Mock()
|
||||
service_mock.get_form_by_token.return_value = form
|
||||
workflow_module = sys.modules["controllers.service_api.app.human_input_form"]
|
||||
monkeypatch.setattr(workflow_module, "HumanInputService", lambda _engine: service_mock)
|
||||
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = WorkflowHumanInputFormApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
|
||||
|
||||
with app.test_request_context("/form/human_input/token-1", method="GET"):
|
||||
response = handler(api, app_model=app_model, form_token="token-1")
|
||||
|
||||
payload = json.loads(response.get_data(as_text=True))
|
||||
assert payload == {
|
||||
"form_content": "Rendered form content",
|
||||
"inputs": [{"output_variable_name": "name"}],
|
||||
"resolved_default_values": {"name": "Alice", "age": "30", "meta": '{"k": "v"}'},
|
||||
"user_actions": [{"id": "approve", "title": "Approve"}],
|
||||
"expiration_time": int(form.expiration_time.timestamp()),
|
||||
}
|
||||
service_mock.get_form_by_token.assert_called_once_with("token-1")
|
||||
service_mock.ensure_form_active.assert_called_once_with(form)
|
||||
|
||||
def test_get_form_not_in_app(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
form = SimpleNamespace(
|
||||
app_id="another-app",
|
||||
tenant_id="tenant-1",
|
||||
expiration_time=datetime(2099, 1, 1, tzinfo=UTC),
|
||||
)
|
||||
service_mock = Mock()
|
||||
service_mock.get_form_by_token.return_value = form
|
||||
workflow_module = sys.modules["controllers.service_api.app.human_input_form"]
|
||||
monkeypatch.setattr(workflow_module, "HumanInputService", lambda _engine: service_mock)
|
||||
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = WorkflowHumanInputFormApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
|
||||
|
||||
with app.test_request_context("/form/human_input/token-1", method="GET"):
|
||||
with pytest.raises(NotFound):
|
||||
handler(api, app_model=app_model, form_token="token-1")
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"recipient_type",
|
||||
[
|
||||
RecipientType.CONSOLE,
|
||||
RecipientType.BACKSTAGE,
|
||||
RecipientType.EMAIL_MEMBER,
|
||||
RecipientType.EMAIL_EXTERNAL,
|
||||
],
|
||||
)
|
||||
def test_get_rejects_non_service_api_recipient_types(
|
||||
self, app, monkeypatch: pytest.MonkeyPatch, recipient_type: RecipientType
|
||||
) -> None:
|
||||
form = SimpleNamespace(
|
||||
app_id="app-1",
|
||||
tenant_id="tenant-1",
|
||||
recipient_type=recipient_type,
|
||||
expiration_time=datetime(2099, 1, 1, tzinfo=UTC),
|
||||
)
|
||||
service_mock = Mock()
|
||||
service_mock.get_form_by_token.return_value = form
|
||||
workflow_module = sys.modules["controllers.service_api.app.human_input_form"]
|
||||
monkeypatch.setattr(workflow_module, "HumanInputService", lambda _engine: service_mock)
|
||||
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = WorkflowHumanInputFormApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
|
||||
|
||||
with app.test_request_context("/form/human_input/token-1", method="GET"):
|
||||
with pytest.raises(NotFound):
|
||||
handler(api, app_model=app_model, form_token="token-1")
|
||||
|
||||
service_mock.ensure_form_active.assert_not_called()
|
||||
|
||||
def test_post_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
form = SimpleNamespace(
|
||||
app_id="app-1",
|
||||
tenant_id="tenant-1",
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
)
|
||||
service_mock = Mock()
|
||||
service_mock.get_form_by_token.return_value = form
|
||||
workflow_module = sys.modules["controllers.service_api.app.human_input_form"]
|
||||
monkeypatch.setattr(workflow_module, "HumanInputService", lambda _engine: service_mock)
|
||||
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = WorkflowHumanInputFormApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
|
||||
end_user = SimpleNamespace(id="end-user-1")
|
||||
|
||||
with app.test_request_context(
|
||||
"/form/human_input/token-1",
|
||||
method="POST",
|
||||
json={"inputs": {"name": "Alice"}, "action": "approve", "user": "external-1"},
|
||||
):
|
||||
response, status = handler(api, app_model=app_model, end_user=end_user, form_token="token-1")
|
||||
|
||||
assert response == {}
|
||||
assert status == 200
|
||||
service_mock.submit_form_by_token.assert_called_once_with(
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
form_token="token-1",
|
||||
selected_action_id="approve",
|
||||
form_data={"name": "Alice"},
|
||||
submission_end_user_id="end-user-1",
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"recipient_type",
|
||||
[
|
||||
RecipientType.CONSOLE,
|
||||
RecipientType.BACKSTAGE,
|
||||
RecipientType.EMAIL_MEMBER,
|
||||
RecipientType.EMAIL_EXTERNAL,
|
||||
],
|
||||
)
|
||||
def test_post_rejects_non_service_api_recipient_types(
|
||||
self, app, monkeypatch: pytest.MonkeyPatch, recipient_type: RecipientType
|
||||
) -> None:
|
||||
form = SimpleNamespace(
|
||||
app_id="app-1",
|
||||
tenant_id="tenant-1",
|
||||
recipient_type=recipient_type,
|
||||
)
|
||||
service_mock = Mock()
|
||||
service_mock.get_form_by_token.return_value = form
|
||||
workflow_module = sys.modules["controllers.service_api.app.human_input_form"]
|
||||
monkeypatch.setattr(workflow_module, "HumanInputService", lambda _engine: service_mock)
|
||||
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = WorkflowHumanInputFormApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
|
||||
end_user = SimpleNamespace(id="end-user-1")
|
||||
|
||||
with app.test_request_context(
|
||||
"/form/human_input/token-1",
|
||||
method="POST",
|
||||
json={"inputs": {"name": "Alice"}, "action": "approve", "user": "external-1"},
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
handler(api, app_model=app_model, end_user=end_user, form_token="token-1")
|
||||
|
||||
service_mock.submit_form_by_token.assert_not_called()
|
||||
@ -0,0 +1,166 @@
|
||||
"""Unit tests for Service API workflow event stream endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.service_api.app.error import NotWorkflowAppError
|
||||
from controllers.service_api.app.workflow_events import WorkflowEventsApi
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import AppMode
|
||||
from tests.unit_tests.controllers.service_api.conftest import _unwrap
|
||||
|
||||
|
||||
def _mock_repo_for_run(monkeypatch: pytest.MonkeyPatch, workflow_run):
|
||||
workflow_events_module = sys.modules["controllers.service_api.app.workflow_events"]
|
||||
repo = SimpleNamespace(get_workflow_run_by_id_and_tenant_id=lambda **_kwargs: workflow_run)
|
||||
monkeypatch.setattr(
|
||||
workflow_events_module.DifyAPIRepositoryFactory,
|
||||
"create_api_workflow_run_repository",
|
||||
lambda *_args, **_kwargs: repo,
|
||||
)
|
||||
monkeypatch.setattr(workflow_events_module, "db", SimpleNamespace(engine=object()))
|
||||
return workflow_events_module
|
||||
|
||||
|
||||
class TestWorkflowEventsApi:
|
||||
def test_wrong_app_mode(self, app) -> None:
|
||||
api = WorkflowEventsApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
end_user = SimpleNamespace(id="end-user-1")
|
||||
|
||||
with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"):
|
||||
with pytest.raises(NotWorkflowAppError):
|
||||
handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
|
||||
|
||||
def test_workflow_run_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
_mock_repo_for_run(monkeypatch, workflow_run=None)
|
||||
api = WorkflowEventsApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
|
||||
end_user = SimpleNamespace(id="end-user-1")
|
||||
|
||||
with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"):
|
||||
with pytest.raises(NotFound):
|
||||
handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
|
||||
|
||||
def test_workflow_run_permission_denied(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
workflow_run = SimpleNamespace(
|
||||
id="run-1",
|
||||
app_id="app-1",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by="another-user",
|
||||
finished_at=None,
|
||||
)
|
||||
_mock_repo_for_run(monkeypatch, workflow_run=workflow_run)
|
||||
api = WorkflowEventsApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
|
||||
end_user = SimpleNamespace(id="end-user-1")
|
||||
|
||||
with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"):
|
||||
with pytest.raises(NotFound):
|
||||
handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
|
||||
|
||||
def test_finished_run_returns_sse(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
workflow_run = SimpleNamespace(
|
||||
id="run-1",
|
||||
app_id="app-1",
|
||||
created_by_role=CreatorUserRole.END_USER,
|
||||
created_by="end-user-1",
|
||||
finished_at=datetime(2099, 1, 1, tzinfo=UTC),
|
||||
)
|
||||
workflow_events_module = _mock_repo_for_run(monkeypatch, workflow_run=workflow_run)
|
||||
monkeypatch.setattr(
|
||||
workflow_events_module.WorkflowResponseConverter,
|
||||
"workflow_run_result_to_finish_response",
|
||||
lambda **_kwargs: SimpleNamespace(
|
||||
model_dump=lambda mode="json": {"task_id": "run-1", "status": "succeeded"},
|
||||
event=SimpleNamespace(value="workflow_finished"),
|
||||
),
|
||||
)
|
||||
|
||||
api = WorkflowEventsApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
|
||||
end_user = SimpleNamespace(id="end-user-1")
|
||||
|
||||
with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"):
|
||||
response = handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
|
||||
|
||||
assert response.mimetype == "text/event-stream"
|
||||
body = response.get_data(as_text=True).strip()
|
||||
assert body.startswith("data: ")
|
||||
payload = json.loads(body[len("data: ") :])
|
||||
assert payload["task_id"] == "run-1"
|
||||
assert payload["event"] == "workflow_finished"
|
||||
|
||||
def test_running_run_streams_events(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
workflow_run = SimpleNamespace(
|
||||
id="run-1",
|
||||
app_id="app-1",
|
||||
created_by_role=CreatorUserRole.END_USER,
|
||||
created_by="end-user-1",
|
||||
finished_at=None,
|
||||
)
|
||||
workflow_events_module = _mock_repo_for_run(monkeypatch, workflow_run=workflow_run)
|
||||
msg_generator = Mock()
|
||||
msg_generator.retrieve_events.return_value = ["raw-event"]
|
||||
workflow_generator = Mock()
|
||||
workflow_generator.convert_to_event_stream.return_value = iter(["data: streamed\n\n"])
|
||||
monkeypatch.setattr(workflow_events_module, "MessageGenerator", lambda: msg_generator)
|
||||
monkeypatch.setattr(workflow_events_module, "WorkflowAppGenerator", lambda: workflow_generator)
|
||||
|
||||
api = WorkflowEventsApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
|
||||
end_user = SimpleNamespace(id="end-user-1")
|
||||
|
||||
with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"):
|
||||
response = handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
|
||||
|
||||
assert response.get_data(as_text=True) == "data: streamed\n\n"
|
||||
msg_generator.retrieve_events.assert_called_once_with(
|
||||
AppMode.WORKFLOW,
|
||||
"run-1",
|
||||
terminal_events=None,
|
||||
)
|
||||
workflow_generator.convert_to_event_stream.assert_called_once_with(["raw-event"])
|
||||
|
||||
def test_running_run_with_snapshot(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
workflow_run = SimpleNamespace(
|
||||
id="run-1",
|
||||
app_id="app-1",
|
||||
created_by_role=CreatorUserRole.END_USER,
|
||||
created_by="end-user-1",
|
||||
finished_at=None,
|
||||
)
|
||||
workflow_events_module = _mock_repo_for_run(monkeypatch, workflow_run=workflow_run)
|
||||
msg_generator = Mock()
|
||||
workflow_generator = Mock()
|
||||
workflow_generator.convert_to_event_stream.return_value = iter(["data: snapshot\n\n"])
|
||||
snapshot_builder = Mock(return_value=["snapshot-events"])
|
||||
monkeypatch.setattr(workflow_events_module, "MessageGenerator", lambda: msg_generator)
|
||||
monkeypatch.setattr(workflow_events_module, "WorkflowAppGenerator", lambda: workflow_generator)
|
||||
monkeypatch.setattr(workflow_events_module, "build_workflow_event_stream", snapshot_builder)
|
||||
|
||||
api = WorkflowEventsApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
|
||||
end_user = SimpleNamespace(id="end-user-1")
|
||||
|
||||
with app.test_request_context("/workflow/run-1/events?user=u1&include_state_snapshot=true", method="GET"):
|
||||
response = handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
|
||||
|
||||
assert response.get_data(as_text=True) == "data: snapshot\n\n"
|
||||
msg_generator.retrieve_events.assert_not_called()
|
||||
snapshot_builder.assert_called_once()
|
||||
workflow_generator.convert_to_event_stream.assert_called_once_with(["snapshot-events"])
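These endpoint tests assert the standard Server-Sent Events framing: each event is a "data: <json>" line terminated by a blank line, served with the "text/event-stream" mimetype. A minimal Flask sketch of that framing (a generic illustration only, not the controller under test; route and payload names are made up):

    import json
    from flask import Flask, Response

    app = Flask(__name__)

    @app.get("/events")
    def events() -> Response:
        def generate():
            # One SSE frame: "data: " prefix, JSON payload, blank-line terminator.
            yield f"data: {json.dumps({'event': 'workflow_finished', 'task_id': 'run-1'})}\n\n"

        return Response(generate(), mimetype="text/event-stream")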
|
||||
@ -15,7 +15,10 @@ from flask import Flask
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
from models.account import TenantStatus
from models.model import App, AppMode, EndUser
from tests.unit_tests.conftest import setup_mock_tenant_account_query
from tests.unit_tests.conftest import (
    setup_mock_dataset_owner_execute_result,
    setup_mock_tenant_owner_execute_result,
)


@pytest.fixture
|
||||
@ -123,9 +126,7 @@ class AuthenticationMocker:
|
||||
        mock_db.session.get.side_effect = [mock_app, mock_tenant]

        if mock_account:
            mock_ta = Mock()
            mock_ta.account_id = mock_account.id
            setup_mock_tenant_account_query(mock_db, mock_tenant, mock_ta)
            setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)

    @staticmethod
    def setup_dataset_auth(mock_db, mock_tenant, mock_account):
|
||||
@ -133,8 +134,7 @@ class AuthenticationMocker:
|
||||
        mock_ta = Mock()
        mock_ta.account_id = mock_account.id

        mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_ta)

        setup_mock_dataset_owner_execute_result(mock_db, mock_tenant, mock_ta)
        mock_db.session.get.return_value = mock_account
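The ``session.get`` mocks used in these fixtures follow standard ``unittest.mock`` semantics: ``return_value`` answers every call with the same object, while a ``side_effect`` list answers successive calls with successive items. A small self-contained illustration (object names here are made up):

    from unittest.mock import MagicMock

    db = MagicMock()
    db.session.get.return_value = "account-object"  # same result for every call
    assert db.session.get("Account", "id-1") == "account-object"

    db.session.get.side_effect = ["app-object", "tenant-object"]  # returned in call order
    assert db.session.get("App", "id-1") == "app-object"
    assert db.session.get("Tenant", "id-2") == "tenant-object"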
|
||||
|
||||
|
||||
|
||||
@ -701,8 +701,8 @@ class TestDocumentApiDelete:
|
||||
    ``delete`` is wrapped by ``@cloud_edition_billing_rate_limit_check`` which
    internally calls ``validate_and_get_api_token``. To bypass the decorator
    we call the original function via ``__wrapped__`` (preserved by
    ``functools.wraps``). ``delete`` queries the dataset via
    ``db.session.query(Dataset)`` directly, so we patch ``db`` at the
    ``functools.wraps``). ``delete`` loads the dataset via
    ``db.session.scalar(select(Dataset)...)``, so we patch ``db`` at the
    controller module.
    """
|
||||
|
||||
|
||||
@ -24,8 +24,8 @@ from enums.cloud_plan import CloudPlan
|
||||
from models.account import TenantStatus
from models.model import ApiToken
from tests.unit_tests.conftest import (
    setup_mock_dataset_tenant_query,
    setup_mock_tenant_account_query,
    setup_mock_dataset_owner_execute_result,
    setup_mock_tenant_owner_execute_result,
)
|
||||
|
||||
|
||||
@ -141,14 +141,11 @@ class TestValidateAppToken:
|
||||
        mock_account = Mock()
        mock_account.id = str(uuid.uuid4())

        mock_ta = Mock()
        mock_ta.account_id = mock_account.id

        # Use side_effect to return app first, then tenant via session.get()
        mock_db.session.get.side_effect = [mock_app, mock_tenant]

        # Mock the tenant owner query (execute(select(...)).one_or_none())
        setup_mock_tenant_account_query(mock_db, mock_tenant, mock_ta)
        # Mock the tenant owner execute result (execute(select(...)).one_or_none())
        setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)

        @validate_app_token
        def protected_view(app_model):
|
||||
@ -471,7 +468,7 @@ class TestValidateDatasetToken:
|
||||
        mock_account.current_tenant = mock_tenant

        # Mock the tenant account join query (execute(select(...)).one_or_none())
        setup_mock_dataset_tenant_query(mock_db, mock_tenant, mock_ta)
        setup_mock_dataset_owner_execute_result(mock_db, mock_tenant, mock_ta)

        # Mock the account lookup via session.get()
        mock_db.session.get.return_value = mock_account
|
||||
|
||||
@ -22,18 +22,16 @@ class FakeSession:
|
||||
|
||||
def __init__(self, mapping: dict[str, Any] | None = None):
|
||||
self._mapping: dict[str, Any] = mapping or {}
|
||||
self._model_name: str | None = None
|
||||
|
||||
def query(self, model: type) -> FakeSession:
|
||||
self._model_name = model.__name__
|
||||
return self
|
||||
def get(self, model: type, _ident: object) -> Any:
|
||||
return self._mapping.get(model.__name__)
|
||||
|
||||
def where(self, *_args: object, **_kwargs: object) -> FakeSession:
|
||||
return self
|
||||
|
||||
def first(self) -> Any:
|
||||
assert self._model_name is not None
|
||||
return self._mapping.get(self._model_name)
|
||||
def scalar(self, stmt: Any) -> Any:
|
||||
try:
|
||||
model = stmt.column_descriptions[0]["entity"]
|
||||
except (AttributeError, IndexError, KeyError, TypeError):
|
||||
return None
|
||||
return self._mapping.get(model.__name__)
|
||||
|
||||
|
||||
class FakeDB:
|
||||
|
||||
@ -36,18 +36,6 @@ class _FakeSession:
|
||||
|
||||
def __init__(self, mapping: dict[str, Any]):
|
||||
self._mapping = mapping
|
||||
self._model_name: str | None = None
|
||||
|
||||
def query(self, model):
|
||||
self._model_name = model.__name__
|
||||
return self
|
||||
|
||||
def where(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def first(self):
|
||||
assert self._model_name is not None
|
||||
return self._mapping.get(self._model_name)
|
||||
|
||||
def get(self, model, ident):
|
||||
return self._mapping.get(model.__name__)
|
||||
|
||||
@ -34,7 +34,6 @@ def _patch_wraps():
|
||||
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features),
|
||||
patch("controllers.web.login.dify_config", web_dify),
|
||||
):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
yield
|
||||
|
||||
|
||||
|
||||
@ -154,7 +154,6 @@ class TestAdvancedChatAppRunnerConversationVariables:
|
||||
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_session_class.return_value.__enter__.return_value = MagicMock()
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists
|
||||
mock_db.engine = MagicMock()
|
||||
|
||||
# Mock GraphRuntimeState to accept the variable pool
|
||||
@ -301,7 +300,6 @@ class TestAdvancedChatAppRunnerConversationVariables:
|
||||
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_session_class.return_value.__enter__.return_value = MagicMock()
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists
|
||||
mock_db.engine = MagicMock()
|
||||
|
||||
# Mock ConversationVariable.from_variable to return mock objects
|
||||
@ -453,7 +451,6 @@ class TestAdvancedChatAppRunnerConversationVariables:
|
||||
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_session_class.return_value.__enter__.return_value = MagicMock()
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists
|
||||
mock_db.engine = MagicMock()
|
||||
|
||||
# Mock GraphRuntimeState to accept the variable pool
|
||||
|
||||
@ -1,7 +1,10 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
AdvancedChatPausedBlockingResponse,
|
||||
ChatbotAppBlockingResponse,
|
||||
ChatbotAppStreamResponse,
|
||||
ErrorStreamResponse,
|
||||
@ -10,7 +13,8 @@ from core.app.entities.task_entities import (
|
||||
NodeStartStreamResponse,
|
||||
PingStreamResponse,
|
||||
)
|
||||
from graphon.enums import WorkflowNodeExecutionStatus
|
||||
from graphon.entities.pause_reason import PauseReasonType
|
||||
from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class TestAdvancedChatGenerateResponseConverter:
|
||||
@ -28,6 +32,37 @@ class TestAdvancedChatGenerateResponseConverter:
|
||||
response = AdvancedChatAppGenerateResponseConverter.convert_blocking_simple_response(blocking)
|
||||
assert "usage" not in response["metadata"]
|
||||
|
||||
def test_blocking_full_response_derives_pause_data_from_model_dump(self, monkeypatch: pytest.MonkeyPatch):
|
||||
data = AdvancedChatPausedBlockingResponse.Data(
|
||||
id="msg-1",
|
||||
mode="chat",
|
||||
conversation_id="c1",
|
||||
message_id="m1",
|
||||
workflow_run_id="run-1",
|
||||
answer="partial",
|
||||
metadata={"usage": {"total_tokens": 1}},
|
||||
created_at=1,
|
||||
paused_nodes=["node-1"],
|
||||
reasons=[{"type": PauseReasonType.HUMAN_INPUT_REQUIRED, "form_id": "form-1"}],
|
||||
status=WorkflowExecutionStatus.PAUSED,
|
||||
elapsed_time=0.1,
|
||||
total_tokens=0,
|
||||
total_steps=0,
|
||||
)
|
||||
original_model_dump = type(data).model_dump
|
||||
|
||||
def _model_dump_with_future_field(self, *args, **kwargs):
|
||||
payload = original_model_dump(self, *args, **kwargs)
|
||||
payload["future_field"] = "future-value"
|
||||
return payload
|
||||
|
||||
monkeypatch.setattr(type(data), "model_dump", _model_dump_with_future_field)
|
||||
blocking = AdvancedChatPausedBlockingResponse(task_id="t1", data=data)
|
||||
|
||||
response = AdvancedChatAppGenerateResponseConverter.convert_blocking_full_response(blocking)
|
||||
|
||||
assert response["data"]["future_field"] == "future-value"
|
||||
|
||||
def test_stream_simple_response_includes_node_events(self):
|
||||
node_start = NodeStartStreamResponse(
|
||||
task_id="t1",
|
||||
|
||||
@ -39,15 +39,19 @@ from core.app.entities.queue_entities import (
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
AdvancedChatPausedBlockingResponse,
|
||||
AnnotationReply,
|
||||
AnnotationReplyAccount,
|
||||
HumanInputRequiredResponse,
|
||||
MessageAudioStreamResponse,
|
||||
MessageEndStreamResponse,
|
||||
PingStreamResponse,
|
||||
)
|
||||
from core.base.tts.app_generator_tts_publisher import AudioTrunk
|
||||
from core.workflow.system_variables import build_system_variables
|
||||
from graphon.entities.pause_reason import PauseReasonType
|
||||
from graphon.enums import BuiltinNodeTypes
|
||||
from graphon.nodes.human_input.entities import UserAction
|
||||
from graphon.runtime import GraphRuntimeState, VariablePool
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.enums import MessageStatus
|
||||
@ -123,6 +127,57 @@ class TestAdvancedChatGenerateTaskPipeline:
|
||||
assert response.data.answer == "done"
|
||||
assert response.data.metadata == {"k": "v"}
|
||||
|
||||
def test_to_blocking_response_falls_back_to_human_input_required_when_pause_event_missing(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._task_state.answer = "partial answer"
|
||||
pipeline._workflow_run_id = "run-id"
|
||||
pipeline._graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
|
||||
start_at=0.0,
|
||||
total_tokens=7,
|
||||
node_run_steps=3,
|
||||
)
|
||||
|
||||
def _gen():
|
||||
yield HumanInputRequiredResponse(
|
||||
task_id="task",
|
||||
workflow_run_id="run-id",
|
||||
data=HumanInputRequiredResponse.Data(
|
||||
form_id="form-1",
|
||||
node_id="node-1",
|
||||
node_title="Approval",
|
||||
form_content="Need approval",
|
||||
inputs=[],
|
||||
actions=[UserAction(id="approve", title="Approve")],
|
||||
display_in_ui=True,
|
||||
form_token="token-1",
|
||||
resolved_default_values={},
|
||||
expiration_time=123,
|
||||
),
|
||||
)
|
||||
|
||||
response = pipeline._to_blocking_response(_gen())
|
||||
|
||||
assert isinstance(response, AdvancedChatPausedBlockingResponse)
|
||||
assert response.data.workflow_run_id == "run-id"
|
||||
assert response.data.status == "paused"
|
||||
assert response.data.paused_nodes == ["node-1"]
|
||||
assert response.data.reasons == [
|
||||
{
|
||||
"TYPE": PauseReasonType.HUMAN_INPUT_REQUIRED,
|
||||
"form_id": "form-1",
|
||||
"node_id": "node-1",
|
||||
"node_title": "Approval",
|
||||
"form_content": "Need approval",
|
||||
"inputs": [],
|
||||
"actions": [{"id": "approve", "title": "Approve", "button_style": "default"}],
|
||||
"display_in_ui": True,
|
||||
"form_token": "token-1",
|
||||
"resolved_default_values": {},
|
||||
"expiration_time": 123,
|
||||
}
|
||||
]
|
||||
|
||||
def test_handle_text_chunk_event_updates_state(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._message_cycle_manager = SimpleNamespace(
|
||||
|
||||
@ -375,7 +375,7 @@ def test_generate_success_returns_converted(generator, mocker):
|
||||
|
||||
workflow = MagicMock(id="wf", tenant_id="tenant", app_id="pipe", graph_dict={})
|
||||
session = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = workflow
|
||||
session.get.return_value = workflow
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
queue_manager = MagicMock()
|
||||
|
||||
@ -132,11 +132,8 @@ def test_run_pipeline_not_found(mocker):
|
||||
app_generate_entity.single_iteration_run = None
|
||||
app_generate_entity.single_loop_run = None
|
||||
|
||||
query = MagicMock()
|
||||
query.where.return_value.first.return_value = None
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value = query
|
||||
session.get.side_effect = [None, None]
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
runner = PipelineRunner(
|
||||
@ -157,11 +154,9 @@ def test_run_workflow_not_initialized(mocker):
|
||||
app_generate_entity = _build_app_generate_entity()
|
||||
|
||||
pipeline = MagicMock(id="pipe")
|
||||
query_pipeline = MagicMock()
|
||||
query_pipeline.where.return_value.first.return_value = pipeline
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value = query_pipeline
|
||||
session.get.side_effect = [None, pipeline]
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
runner = PipelineRunner(
|
||||
|
||||
@ -0,0 +1,102 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.task_entities import (
|
||||
AppStreamResponse,
|
||||
PingStreamResponse,
|
||||
WorkflowAppBlockingResponse,
|
||||
WorkflowAppStreamResponse,
|
||||
)
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
|
||||
|
||||
class _DummyConverter(AppGenerateResponseConverter[WorkflowAppBlockingResponse]):
|
||||
blocking_full_calls: list[WorkflowAppBlockingResponse] = []
|
||||
blocking_simple_calls: list[WorkflowAppBlockingResponse] = []
|
||||
stream_full_calls: list[Generator[AppStreamResponse, None, None]] = []
|
||||
stream_simple_calls: list[Generator[AppStreamResponse, None, None]] = []
|
||||
|
||||
@classmethod
|
||||
def reset(cls) -> None:
|
||||
cls.blocking_full_calls = []
|
||||
cls.blocking_simple_calls = []
|
||||
cls.stream_full_calls = []
|
||||
cls.stream_simple_calls = []
|
||||
|
||||
@classmethod
|
||||
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, object]:
|
||||
cls.blocking_full_calls.append(blocking_response)
|
||||
return {"kind": "blocking-full", "task_id": blocking_response.task_id}
|
||||
|
||||
@classmethod
|
||||
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, object]:
|
||||
cls.blocking_simple_calls.append(blocking_response)
|
||||
return {"kind": "blocking-simple", "task_id": blocking_response.task_id}
|
||||
|
||||
@classmethod
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
cls.stream_full_calls.append(stream_response)
|
||||
yield {"kind": "stream-full"}
|
||||
|
||||
@classmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
cls.stream_simple_calls.append(stream_response)
|
||||
yield {"kind": "stream-simple"}
|
||||
|
||||
|
||||
def _build_blocking_response() -> WorkflowAppBlockingResponse:
|
||||
return WorkflowAppBlockingResponse(
|
||||
task_id="task-1",
|
||||
workflow_run_id="run-1",
|
||||
data=WorkflowAppBlockingResponse.Data(
|
||||
id="run-1",
|
||||
workflow_id="workflow-1",
|
||||
status=WorkflowExecutionStatus.SUCCEEDED,
|
||||
outputs={"ok": True},
|
||||
error=None,
|
||||
elapsed_time=0.1,
|
||||
total_tokens=0,
|
||||
total_steps=1,
|
||||
created_at=1,
|
||||
finished_at=2,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _build_stream_response() -> Generator[AppStreamResponse, None, None]:
|
||||
yield WorkflowAppStreamResponse(
|
||||
workflow_run_id="run-1",
|
||||
stream_response=PingStreamResponse(task_id="task-1"),
|
||||
)
|
||||
|
||||
|
||||
def test_convert_routes_blocking_response_by_invoke_from() -> None:
|
||||
_DummyConverter.reset()
|
||||
blocking_response = _build_blocking_response()
|
||||
|
||||
full_result = _DummyConverter.convert(blocking_response, InvokeFrom.SERVICE_API)
|
||||
simple_result = _DummyConverter.convert(blocking_response, InvokeFrom.WEB_APP)
|
||||
|
||||
assert full_result == {"kind": "blocking-full", "task_id": "task-1"}
|
||||
assert simple_result == {"kind": "blocking-simple", "task_id": "task-1"}
|
||||
assert _DummyConverter.blocking_full_calls == [blocking_response]
|
||||
assert _DummyConverter.blocking_simple_calls == [blocking_response]
|
||||
|
||||
|
||||
def test_convert_routes_stream_response_by_invoke_from() -> None:
|
||||
_DummyConverter.reset()
|
||||
|
||||
full_result = list(_DummyConverter.convert(_build_stream_response(), InvokeFrom.SERVICE_API))
|
||||
simple_result = list(_DummyConverter.convert(_build_stream_response(), InvokeFrom.WEB_APP))
|
||||
|
||||
assert full_result == [{"kind": "stream-full"}]
|
||||
assert simple_result == [{"kind": "stream-simple"}]
|
||||
assert len(_DummyConverter.stream_full_calls) == 1
|
||||
assert len(_DummyConverter.stream_simple_calls) == 1
|
||||
@ -1,6 +1,7 @@
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from core.app.apps.message_generator import MessageGenerator
|
||||
from core.app.entities.task_entities import StreamEvent
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
@ -23,7 +24,21 @@ class TestMessageGenerator:
|
||||
"core.app.apps.message_generator.stream_topic_events", return_value=iter([{"event": "ping"}])
|
||||
) as mock_stream,
|
||||
):
|
||||
events = list(MessageGenerator.retrieve_events(AppMode.WORKFLOW, "run-1", idle_timeout=1, ping_interval=2))
|
||||
events = list(
|
||||
MessageGenerator.retrieve_events(
|
||||
AppMode.WORKFLOW,
|
||||
"run-1",
|
||||
idle_timeout=1,
|
||||
ping_interval=2,
|
||||
terminal_events=[StreamEvent.WORKFLOW_FINISHED.value],
|
||||
)
|
||||
)
|
||||
|
||||
assert events == [{"event": "ping"}]
|
||||
mock_stream.assert_called_once()
|
||||
mock_stream.assert_called_once_with(
|
||||
topic="topic",
|
||||
idle_timeout=1,
|
||||
ping_interval=2,
|
||||
on_subscribe=None,
|
||||
terminal_events=[StreamEvent.WORKFLOW_FINISHED.value],
|
||||
)
|
||||
|
||||
@ -88,6 +88,10 @@ def test_normalize_terminal_events_defaults():
|
||||
}
|
||||
|
||||
|
||||
def test_normalize_terminal_events_empty_values():
|
||||
assert _normalize_terminal_events([]) == set({})
|
||||
|
||||
|
||||
def test_stream_topic_events_emits_ping_and_idle_timeout(monkeypatch):
|
||||
topic = FakeTopic()
|
||||
times = [1000.0, 1000.0, 1001.0, 1001.0, 1002.0]
|
||||
@ -106,3 +110,21 @@ def test_stream_topic_events_emits_ping_and_idle_timeout(monkeypatch):
|
||||
assert next(generator) == StreamEvent.PING.value
|
||||
# next receive yields None -> ping interval triggers
|
||||
assert next(generator) == StreamEvent.PING.value
|
||||
|
||||
|
||||
def test_stream_topic_events_can_continue_past_pause():
|
||||
topic = FakeTopic()
|
||||
topic.publish(json.dumps({"event": StreamEvent.WORKFLOW_PAUSED.value}).encode())
|
||||
topic.publish(json.dumps({"event": StreamEvent.WORKFLOW_FINISHED.value}).encode())
|
||||
|
||||
generator = stream_topic_events(
|
||||
topic=topic,
|
||||
idle_timeout=1.0,
|
||||
terminal_events=[StreamEvent.WORKFLOW_FINISHED.value],
|
||||
)
|
||||
|
||||
assert next(generator) == StreamEvent.PING.value
|
||||
assert next(generator)["event"] == StreamEvent.WORKFLOW_PAUSED.value
|
||||
assert next(generator)["event"] == StreamEvent.WORKFLOW_FINISHED.value
|
||||
with pytest.raises(StopIteration):
|
||||
next(generator)
|
||||
|
||||
@ -36,11 +36,12 @@ from core.app.entities.queue_entities import (
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
ErrorStreamResponse,
|
||||
HumanInputRequiredResponse,
|
||||
MessageAudioEndStreamResponse,
|
||||
MessageAudioStreamResponse,
|
||||
PingStreamResponse,
|
||||
WorkflowAppPausedBlockingResponse,
|
||||
WorkflowFinishStreamResponse,
|
||||
WorkflowPauseStreamResponse,
|
||||
WorkflowStartStreamResponse,
|
||||
)
|
||||
from core.base.tts.app_generator_tts_publisher import AudioTrunk
|
||||
@ -91,27 +92,50 @@ def _make_pipeline():
|
||||
|
||||
|
||||
class TestWorkflowGenerateTaskPipeline:
|
||||
def test_to_blocking_response_handles_pause(self):
|
||||
def test_to_blocking_response_falls_back_to_human_input_required_when_pause_event_missing(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
|
||||
start_at=0.0,
|
||||
total_tokens=5,
|
||||
node_run_steps=2,
|
||||
)
|
||||
|
||||
def _gen():
|
||||
yield WorkflowPauseStreamResponse(
|
||||
yield HumanInputRequiredResponse(
|
||||
task_id="task",
|
||||
workflow_run_id="run",
|
||||
data=WorkflowPauseStreamResponse.Data(
|
||||
workflow_run_id="run",
|
||||
status=WorkflowExecutionStatus.PAUSED,
|
||||
outputs={},
|
||||
created_at=1,
|
||||
elapsed_time=0.1,
|
||||
total_tokens=0,
|
||||
total_steps=0,
|
||||
workflow_run_id="run-id",
|
||||
data=HumanInputRequiredResponse.Data(
|
||||
form_id="form-1",
|
||||
node_id="node-1",
|
||||
node_title="Human Input",
|
||||
form_content="content",
|
||||
expiration_time=1,
|
||||
),
|
||||
)
|
||||
|
||||
response = pipeline._to_blocking_response(_gen())
|
||||
|
||||
assert isinstance(response, WorkflowAppPausedBlockingResponse)
|
||||
assert response.workflow_run_id == "run-id"
|
||||
assert response.data.status == WorkflowExecutionStatus.PAUSED
|
||||
assert response.data.created_at == 0
|
||||
assert response.data.paused_nodes == ["node-1"]
|
||||
assert response.data.reasons == [
|
||||
{
|
||||
"TYPE": "human_input_required",
|
||||
"form_id": "form-1",
|
||||
"node_id": "node-1",
|
||||
"node_title": "Human Input",
|
||||
"form_content": "content",
|
||||
"inputs": [],
|
||||
"actions": [],
|
||||
"display_in_ui": False,
|
||||
"form_token": None,
|
||||
"resolved_default_values": {},
|
||||
"expiration_time": 1,
|
||||
}
|
||||
]
|
||||
|
||||
def test_to_blocking_response_handles_finish(self):
|
||||
pipeline = _make_pipeline()
|
||||
|
||||
@ -775,9 +775,6 @@ class TestNotionExtractorLastEditedTime:
|
||||
"last_edited_time": "2024-11-27T18:00:00.000Z",
|
||||
}
|
||||
mock_request.return_value = mock_response
|
||||
mock_query = Mock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
|
||||
# Act
|
||||
extractor_page.update_last_edited_time(mock_document_model)
|
||||
@ -863,9 +860,6 @@ class TestNotionExtractorIntegration:
|
||||
}
|
||||
|
||||
mock_request.side_effect = [last_edited_response, block_response]
|
||||
mock_query = Mock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
|
||||
# Act
|
||||
documents = extractor.extract()
|
||||
@ -919,10 +913,6 @@ class TestNotionExtractorIntegration:
|
||||
}
|
||||
mock_post.return_value = database_response
|
||||
|
||||
mock_query = Mock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
|
||||
# Act
|
||||
documents = extractor.extract()
|
||||
|
||||
|
||||
106
api/tests/unit_tests/core/helper/test_creators.py
Normal file
106
api/tests/unit_tests/core/helper/test_creators.py
Normal file
@ -0,0 +1,106 @@
|
||||
"""Tests for the Creators Platform helper module."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from yarl import URL
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _patch_creators_url(monkeypatch):
|
||||
"""Patch the module-level creators_platform_api_url for all tests."""
|
||||
monkeypatch.setattr(
|
||||
"core.helper.creators.creators_platform_api_url",
|
||||
URL("https://creators.example.com"),
|
||||
)
|
||||
|
||||
|
||||
class TestUploadDSL:
|
||||
@patch("core.helper.creators.httpx.post")
|
||||
def test_returns_claim_code(self, mock_post):
|
||||
mock_response = MagicMock(spec=httpx.Response)
|
||||
mock_response.json.return_value = {"data": {"claim_code": "abc123"}}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
from core.helper.creators import upload_dsl
|
||||
|
||||
result = upload_dsl(b"app: demo", "demo.yaml")
|
||||
|
||||
assert result == "abc123"
|
||||
mock_post.assert_called_once()
|
||||
call_kwargs = mock_post.call_args
|
||||
assert "anonymous-upload" in call_kwargs.args[0]
|
||||
assert call_kwargs.kwargs["timeout"] == 30
|
||||
|
||||
@patch("core.helper.creators.httpx.post")
|
||||
def test_raises_on_missing_claim_code(self, mock_post):
|
||||
mock_response = MagicMock(spec=httpx.Response)
|
||||
mock_response.json.return_value = {"data": {}}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
from core.helper.creators import upload_dsl
|
||||
|
||||
with pytest.raises(ValueError, match="claim_code"):
|
||||
upload_dsl(b"app: demo")
|
||||
|
||||
@patch("core.helper.creators.httpx.post")
|
||||
def test_raises_on_http_error(self, mock_post):
|
||||
mock_response = MagicMock(spec=httpx.Response)
|
||||
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
|
||||
"Server Error",
|
||||
request=MagicMock(),
|
||||
response=MagicMock(),
|
||||
)
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
from core.helper.creators import upload_dsl
|
||||
|
||||
with pytest.raises(httpx.HTTPStatusError):
|
||||
upload_dsl(b"app: demo")
|
||||
|
||||
|
||||
class TestGetRedirectUrl:
|
||||
@patch("core.helper.creators.dify_config")
|
||||
def test_without_oauth_client_id(self, mock_config):
|
||||
mock_config.CREATORS_PLATFORM_API_URL = "https://creators.example.com"
|
||||
mock_config.CREATORS_PLATFORM_OAUTH_CLIENT_ID = ""
|
||||
|
||||
from core.helper.creators import get_redirect_url
|
||||
|
||||
url = get_redirect_url("user-1", "claim-abc")
|
||||
|
||||
assert "dsl_claim_code=claim-abc" in url
|
||||
assert "oauth_code" not in url
|
||||
assert url.startswith("https://creators.example.com")
|
||||
|
||||
@patch("core.helper.creators.dify_config")
|
||||
def test_with_oauth_client_id(self, mock_config):
|
||||
mock_config.CREATORS_PLATFORM_API_URL = "https://creators.example.com"
|
||||
mock_config.CREATORS_PLATFORM_OAUTH_CLIENT_ID = "client-xyz"
|
||||
|
||||
with patch(
|
||||
"services.oauth_server.OAuthServerService.sign_oauth_authorization_code",
|
||||
return_value="oauth-code-123",
|
||||
) as mock_sign:
|
||||
from core.helper.creators import get_redirect_url
|
||||
|
||||
url = get_redirect_url("user-1", "claim-abc")
|
||||
|
||||
mock_sign.assert_called_once_with("client-xyz", "user-1")
|
||||
assert "dsl_claim_code=claim-abc" in url
|
||||
assert "oauth_code=oauth-code-123" in url
|
||||
|
||||
@patch("core.helper.creators.dify_config")
|
||||
def test_strips_trailing_slash(self, mock_config):
|
||||
mock_config.CREATORS_PLATFORM_API_URL = "https://creators.example.com/"
|
||||
mock_config.CREATORS_PLATFORM_OAUTH_CLIENT_ID = ""
|
||||
|
||||
from core.helper.creators import get_redirect_url
|
||||
|
||||
url = get_redirect_url("user-1", "claim-abc")
|
||||
|
||||
assert url.startswith("https://creators.example.com?")
|
||||
assert "creators.example.com/?" not in url
|
||||