Merge branch 'main' into refactor/clean-unnecessary-none-batch2

This commit is contained in:
Asuka Minato 2026-06-18 17:29:09 +09:00 committed by GitHub
commit 96d1fa2917
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
547 changed files with 49023 additions and 5819 deletions

View File

@ -21,6 +21,7 @@ env:
DOCKERHUB_TOKEN: ${{ secrets.DOCKERHUB_TOKEN }}
DIFY_WEB_IMAGE_NAME: ${{ vars.DIFY_WEB_IMAGE_NAME || 'langgenius/dify-web' }}
DIFY_API_IMAGE_NAME: ${{ vars.DIFY_API_IMAGE_NAME || 'langgenius/dify-api' }}
DIFY_AGENT_IMAGE_NAME: ${{ vars.DIFY_AGENT_IMAGE_NAME || 'langgenius/dify-agent-backend' }}
jobs:
build:
@ -60,6 +61,20 @@ jobs:
file: "web/Dockerfile"
platform: linux/arm64
runs_on: depot-ubuntu-24.04-4
- service_name: "build-agent-amd64"
image_name_env: "DIFY_AGENT_IMAGE_NAME"
artifact_context: "agent"
build_context: "{{defaultContext}}"
file: "dify-agent/Dockerfile"
platform: linux/amd64
runs_on: depot-ubuntu-24.04-4
- service_name: "build-agent-arm64"
image_name_env: "DIFY_AGENT_IMAGE_NAME"
artifact_context: "agent"
build_context: "{{defaultContext}}"
file: "dify-agent/Dockerfile"
platform: linux/arm64
runs_on: depot-ubuntu-24.04-4
steps:
- name: Prepare
@ -122,6 +137,9 @@ jobs:
- service_name: "validate-web-amd64"
build_context: "{{defaultContext}}"
file: "web/Dockerfile"
- service_name: "validate-agent-amd64"
build_context: "{{defaultContext}}"
file: "dify-agent/Dockerfile"
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
@ -147,6 +165,9 @@ jobs:
- service_name: "merge-web-images"
image_name_env: "DIFY_WEB_IMAGE_NAME"
context: "web"
- service_name: "merge-agent-images"
image_name_env: "DIFY_AGENT_IMAGE_NAME"
context: "agent"
steps:
- name: Download digests
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1

View File

@ -1,7 +1,8 @@
from typing import Any, Literal
from copy import deepcopy
from typing import Any, Literal, override
from uuid import UUID
from pydantic import BaseModel, Field, model_validator
from pydantic import BaseModel, Field, GetJsonSchemaHandler, model_validator
from libs.helper import UUIDStrOrEmpty
@ -12,6 +13,45 @@ class ConversationRenamePayload(BaseModel):
name: str | None = None
auto_generate: bool = False
@classmethod
@override
def __get_pydantic_json_schema__(cls, core_schema: Any, handler: GetJsonSchemaHandler) -> dict[str, Any]:
schema = handler.resolve_ref_schema(handler(core_schema))
properties = schema.get("properties")
if not isinstance(properties, dict):
return schema
auto_generate_schema = deepcopy(properties.get("auto_generate", {"type": "boolean"}))
name_schema = deepcopy(properties.get("name", {"type": "string"}))
non_blank_name_schema: dict[str, Any] = {"pattern": r".*\S.*", "type": "string"}
if isinstance(name_schema, dict) and isinstance(name_schema.get("title"), str):
non_blank_name_schema["title"] = name_schema["title"]
auto_generate_true_schema = {**auto_generate_schema, "enum": [True]}
auto_generate_true_schema.pop("default", None)
return {
**schema,
"anyOf": [
{
"properties": {
"auto_generate": auto_generate_true_schema,
"name": name_schema,
},
"required": ["auto_generate"],
"type": "object",
},
{
"properties": {
"auto_generate": {**auto_generate_schema, "enum": [False]},
"name": non_blank_name_schema,
},
"required": ["name"],
"type": "object",
},
],
}
@model_validator(mode="after")
def validate_name_requirement(self):
if not self.auto_generate:
@ -101,4 +141,7 @@ class TextToAudioPayload(BaseModel):
message_id: str | None = Field(default=None, description="Message ID")
voice: str | None = Field(default=None, description="Voice to use for TTS")
text: str | None = Field(default=None, description="Text to convert to audio")
streaming: bool | None = Field(default=None, description="Enable streaming response")
streaming: bool | None = Field(
default=None,
description="Reserved for compatibility; TTS response streaming is determined by the provider output.",
)

View File

@ -42,10 +42,11 @@ def stringify_form_default_values(values: dict[str, object]) -> dict[str, str]:
"""Serialize default values into strings expected by human-input form clients."""
result: dict[str, str] = {}
for key, value in values.items():
if value is None:
result[key] = ""
elif isinstance(value, (dict, list)):
result[key] = json.dumps(value, ensure_ascii=False)
else:
result[key] = str(value)
match value:
case None:
result[key] = ""
case dict() | list():
result[key] = json.dumps(value, ensure_ascii=False)
case _:
result[key] = str(value)
return result

View File

@ -2,20 +2,28 @@ from uuid import UUID
from flask import abort, request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from pydantic import AliasChoices, BaseModel, Field, field_validator
from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.agent.app_helpers import resolve_agent_app_model
from controllers.console.app.app import (
AppDetailWithSite,
AppDetailWithSite as GenericAppDetailWithSite,
)
from controllers.console.app.app import (
AppListQuery,
AppPagination,
AppPartial,
CopyAppPayload,
UpdateAppPayload,
_normalize_app_list_query_args,
)
from controllers.console.app.app import (
AppPagination as GenericAppPagination,
)
from controllers.console.app.app import (
AppPartial as GenericAppPartial,
)
from controllers.console.app.app import (
UpdateAppPayload as GenericUpdateAppPayload,
)
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
@ -31,6 +39,8 @@ from fields.agent_fields import (
AgentConfigSnapshotListResponse,
AgentInviteOptionsResponse,
AgentLogListResponse,
AgentLogMessageListResponse,
AgentLogSourceListResponse,
AgentPublishedReferenceResponse,
AgentRosterListResponse,
AgentStatisticSummaryEnvelopeResponse,
@ -64,14 +74,33 @@ class AgentIdPath(BaseModel):
class AgentAppCreatePayload(BaseModel):
name: str = Field(..., min_length=1, description="Agent name")
description: str | None = Field(default=None, description="Agent description (max 400 chars)", max_length=400)
role: str = Field(default="", description="Agent role", max_length=255)
role: str = Field(..., min_length=1, description="Agent role", max_length=255)
icon_type: IconType | None = Field(default=None, description="Icon type")
icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color")
@field_validator("role")
@classmethod
def validate_role(cls, value: str) -> str:
role = value.strip()
if not role:
raise ValueError("Agent role is required.")
return role
class AgentAppUpdatePayload(UpdateAppPayload):
role: str | None = Field(default=None, description="Agent role", max_length=255)
# Keep agent-app roster DTOs agent-specific instead of reusing the shared
# /apps response/request models. The roster surface needs Agent-only fields such
# as `role`, while the generic console/apps contracts must stay unchanged.
class AgentAppUpdatePayload(GenericUpdateAppPayload):
role: str = Field(..., min_length=1, description="Agent role", max_length=255)
@field_validator("role")
@classmethod
def validate_role(cls, value: str) -> str:
role = value.strip()
if not role:
raise ValueError("Agent role is required.")
return role
class AgentAppPublishedReferenceResponse(BaseModel):
@ -82,19 +111,6 @@ class AgentAppPublishedReferenceResponse(BaseModel):
app_icon_background: str | None = None
class AgentAppPartial(AppPartial):
published_reference_count: int = 0
published_references: list[AgentAppPublishedReferenceResponse] = Field(default_factory=list)
class AgentAppPagination(BaseModel):
page: int
limit: int
total: int
has_more: bool
data: list[AgentAppPartial]
class AgentLogsQuery(BaseModel):
page: int = Field(default=1, ge=1, description="Page number")
limit: int = Field(default=20, ge=1, le=100, description="Page size")
@ -131,6 +147,26 @@ class AgentStatisticsQuery(BaseModel):
return value
class AgentAppPartial(GenericAppPartial):
app_id: str | None = None
role: str | None = None
active_config_is_published: bool = False
published_reference_count: int = 0
published_references: list[AgentAppPublishedReferenceResponse] = Field(default_factory=list)
class AgentAppDetailWithSite(GenericAppDetailWithSite):
app_id: str | None = None
role: str | None = None
active_config_is_published: bool = False
class AgentAppPagination(GenericAppPagination):
data: list[AgentAppPartial] = Field( # type: ignore[assignment] # pyrefly: ignore[bad-override-mutable-attribute]
validation_alias=AliasChoices("items", "data")
)
register_schema_models(
console_ns,
AgentAppCreatePayload,
@ -141,18 +177,20 @@ register_schema_models(
AgentStatisticsQuery,
AgentIdPath,
AppListQuery,
UpdateAppPayload,
RosterListQuery,
)
register_response_schema_models(
console_ns,
AppDetailWithSite,
AgentAppPagination,
AgentAppPublishedReferenceResponse,
AgentAppDetailWithSite,
AgentAppPartial,
AgentConfigSnapshotDetailResponse,
AgentConfigSnapshotListResponse,
AgentInviteOptionsResponse,
AgentLogListResponse,
AgentLogMessageListResponse,
AgentLogSourceListResponse,
AgentPublishedReferenceResponse,
AgentRosterListResponse,
AgentStatisticSummaryEnvelopeResponse,
@ -164,16 +202,25 @@ def _agent_roster_service() -> AgentRosterService:
def _serialize_agent_app_detail(app_model) -> dict:
"""Serialize an Agent App detail using roster-only DTOs.
`/agent` responses are roster-shaped rather than raw app-shaped: `id`
becomes the backing roster Agent id, `app_id` carries the underlying App
id, and `role` is injected from the backing roster Agent. Keeping that
remap in this serializer lets generated console/agent contracts expose the
roster persona fields without widening the shared /apps detail schema.
"""
app_model = AppService().get_app(app_model)
if FeatureService.get_system_features().webapp_auth.enabled:
app_setting = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=str(app_model.id))
app_model.access_mode = app_setting.access_mode # type: ignore[attr-defined]
roster_service = _agent_roster_service()
agent = roster_service.get_app_backing_agent(tenant_id=app_model.tenant_id, app_id=app_model.id)
payload = AgentAppDetailWithSite.model_validate(app_model, from_attributes=True).model_dump(mode="json")
agent = roster_service.get_app_backing_agent(tenant_id=app_model.tenant_id, app_id=str(app_model.id))
if not agent:
raise AgentNotFoundError()
payload = AppDetailWithSite.model_validate(app_model, from_attributes=True).model_dump(mode="json")
payload.pop("bound_agent_id", None)
payload["app_id"] = str(app_model.id)
payload["id"] = agent.id
@ -186,6 +233,14 @@ def _serialize_agent_app_detail(app_model) -> dict:
def _serialize_agent_app_pagination(app_pagination, *, tenant_id: str) -> dict:
"""Serialize Agent App lists with roster-shaped items.
Each item starts from the shared App list shape, then drops
`bound_agent_id`, rewrites `id` to the backing roster Agent id, stores the
original App id in `app_id`, and injects roster-only `role` when a backing
Agent is present.
"""
app_ids = [str(app.id) for app in app_pagination.items]
roster_service = _agent_roster_service()
agents_by_app_id = roster_service.load_app_backing_agents_by_app_id(
@ -200,7 +255,7 @@ def _serialize_agent_app_pagination(app_pagination, *, tenant_id: str) -> dict:
tenant_id=tenant_id,
agent_ids=[agent.id for agent in agents_by_app_id.values()],
)
payload = AppPagination.model_validate(app_pagination, from_attributes=True).model_dump(mode="json")
payload = AgentAppPagination.model_validate(app_pagination, from_attributes=True).model_dump(mode="json")
for item in payload["data"]:
app_id = item["id"]
item.pop("bound_agent_id", None)
@ -266,7 +321,7 @@ class AgentAppListApi(Resource):
status="normal",
)
app_pagination = AppService().get_paginate_apps(current_user.id, current_tenant_id, params)
app_pagination = AppService().get_paginate_apps(current_user.id, current_tenant_id, params, db.session)
if app_pagination is None:
empty = AgentAppPagination(page=args.page, limit=args.limit, total=0, has_more=False, data=[])
return empty.model_dump(mode="json")
@ -274,7 +329,7 @@ class AgentAppListApi(Resource):
return _serialize_agent_app_pagination(app_pagination, tenant_id=current_tenant_id)
@console_ns.expect(console_ns.models[AgentAppCreatePayload.__name__])
@console_ns.response(201, "Agent app created successfully", console_ns.models[AppDetailWithSite.__name__])
@console_ns.response(201, "Agent app created successfully", console_ns.models[AgentAppDetailWithSite.__name__])
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(400, "Invalid request parameters")
@setup_required
@ -302,7 +357,7 @@ class AgentAppListApi(Resource):
@console_ns.route("/agent/<uuid:agent_id>")
class AgentAppApi(Resource):
@console_ns.response(200, "Agent app detail", console_ns.models[AppDetailWithSite.__name__])
@console_ns.response(200, "Agent app detail", console_ns.models[AgentAppDetailWithSite.__name__])
@setup_required
@login_required
@account_initialization_required
@ -313,7 +368,7 @@ class AgentAppApi(Resource):
return _serialize_agent_app_detail(app_model)
@console_ns.expect(console_ns.models[AgentAppUpdatePayload.__name__])
@console_ns.response(200, "Agent app updated successfully", console_ns.models[AppDetailWithSite.__name__])
@console_ns.response(200, "Agent app updated successfully", console_ns.models[AgentAppDetailWithSite.__name__])
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(400, "Invalid request parameters")
@setup_required
@ -353,7 +408,7 @@ class AgentAppApi(Resource):
@console_ns.route("/agent/<uuid:agent_id>/copy")
class AgentAppCopyApi(Resource):
@console_ns.expect(console_ns.models[CopyAppPayload.__name__])
@console_ns.response(201, "Agent app copied successfully", console_ns.models[AppDetailWithSite.__name__])
@console_ns.response(201, "Agent app copied successfully", console_ns.models[AgentAppDetailWithSite.__name__])
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(400, "Invalid request parameters")
@setup_required
@ -416,6 +471,7 @@ class AgentLogsApi(Resource):
try:
payload = _agent_observability_service().list_logs(
app=app_model,
agent_id=str(agent_id),
params=AgentLogQueryParams(
page=query.page,
limit=query.limit,
@ -431,6 +487,53 @@ class AgentLogsApi(Resource):
return dump_response(AgentLogListResponse, payload)
@console_ns.route("/agent/<uuid:agent_id>/logs/<uuid:conversation_id>/messages")
class AgentLogMessagesApi(Resource):
@console_ns.doc(params=query_params_from_model(AgentLogsQuery))
@console_ns.response(200, "Agent log messages", console_ns.models[AgentLogMessageListResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@with_current_user
@with_current_tenant_id
def get(self, tenant_id: str, current_user: Account, agent_id: UUID, conversation_id: UUID):
app_model = _resolve_agent_app_model(tenant_id=tenant_id, agent_id=agent_id)
query = AgentLogsQuery.model_validate(request.args.to_dict(flat=True))
start, end = _parse_observability_time_range(query.start, query.end, current_user)
try:
payload = _agent_observability_service().list_log_messages(
app=app_model,
agent_id=str(agent_id),
conversation_id=str(conversation_id),
params=AgentLogQueryParams(
page=query.page,
limit=query.limit,
keyword=query.keyword,
status=query.status,
source=query.source,
start=start,
end=end,
),
)
except ValueError as exc:
abort(400, description=str(exc))
return dump_response(AgentLogMessageListResponse, payload)
@console_ns.route("/agent/<uuid:agent_id>/log-sources")
class AgentLogSourcesApi(Resource):
@console_ns.response(200, "Agent log sources", console_ns.models[AgentLogSourceListResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@with_current_user
@with_current_tenant_id
def get(self, tenant_id: str, current_user: Account, agent_id: UUID):
app_model = _resolve_agent_app_model(tenant_id=tenant_id, agent_id=agent_id)
payload = _agent_observability_service().list_log_sources(app=app_model, agent_id=str(agent_id))
return dump_response(AgentLogSourceListResponse, payload)
@console_ns.route("/agent/<uuid:agent_id>/statistics/summary")
class AgentStatisticsSummaryApi(Resource):
@console_ns.doc(params=query_params_from_model(AgentStatisticsQuery))
@ -452,6 +555,7 @@ class AgentStatisticsSummaryApi(Resource):
try:
payload = _agent_observability_service().get_statistics_summary(
app=app_model,
agent_id=str(agent_id),
params=AgentStatisticsQueryParams(source=query.source, start=start, end=end, timezone=timezone),
)
except ValueError as exc:

View File

@ -30,7 +30,7 @@ from models import Account
from models.agent_config_entities import AgentFileRefConfig, AgentSkillRefConfig
from models.model import App, AppMode, UploadFile
from services.agent.composer_service import AgentComposerService
from services.agent.skill_package_service import SkillManifest, SkillPackageError, SkillPackageService
from services.agent.skill_package_service import SkillManifest, SkillPackageError
from services.agent.skill_standardize_service import SkillStandardizeService
from services.agent.skill_tool_inference_service import (
SkillToolInferenceError,
@ -45,11 +45,18 @@ from services.agent_drive_service import (
normalize_drive_key,
)
from services.agent_service import AgentService
from services.file_service import FileService
logger = logging.getLogger(__name__)
_WORKFLOW_AGENT_DRIVE_APP_MODES = [AppMode.WORKFLOW, AppMode.ADVANCED_CHAT]
_AGENT_SKILL_UPLOAD_PARAMS = {
"file": {
"in": "formData",
"type": "file",
"required": True,
"description": "Skill package (.zip or .skill).",
}
}
class AgentLogQuery(BaseModel):
@ -125,11 +132,6 @@ class AgentSkillUploadResponse(ResponseModel):
manifest: SkillManifest
class AgentSkillStandardizeResponse(ResponseModel):
skill: AgentSkillRefConfig
manifest: SkillManifest
class AgentDriveFileResponse(ResponseModel):
name: str
drive_key: str
@ -156,7 +158,6 @@ register_response_schema_models(
AgentDriveFileCommitResponse,
AgentDriveFileResponse,
AgentLogResponse,
AgentSkillStandardizeResponse,
AgentSkillUploadResponse,
SkillToolInferenceResult,
)
@ -174,30 +175,9 @@ def _agent_not_bound() -> tuple[dict[str, str], int]:
return {"code": "agent_not_bound", "message": "no agent is bound for this app/node"}, 400
def _upload_skill_for_app(*, current_user: Account):
if "file" not in request.files:
return {"code": "no_file", "message": "no skill file uploaded"}, 400
if len(request.files) > 1:
return {"code": "too_many_files", "message": "only one skill file is allowed"}, 400
def _upload_skill_for_app(*, current_user: Account, app_model: App):
"""Upload one skill package and commit its normalized files into the agent drive."""
upload = request.files["file"]
content = upload.stream.read()
try:
manifest = SkillPackageService().validate_and_extract(content=content, filename=upload.filename or "")
except SkillPackageError as exc:
return {"code": exc.code, "message": exc.message}, exc.status_code
upload_file = FileService(db.engine).upload_file(
filename=upload.filename or "skill.zip",
content=content,
mimetype=upload.mimetype or "application/zip",
user=current_user,
)
skill_ref = manifest.to_skill_ref(file_id=upload_file.id)
return {"skill": skill_ref.model_dump(exclude_none=True), "manifest": manifest.model_dump()}, 201
def _standardize_skill_for_app(*, current_user: Account, app_model: App):
query = query_params_from_request(AgentDriveMutationQuery)
agent_id = _resolve_agent_id(app_model, query.node_id)
if not agent_id:
@ -382,51 +362,9 @@ class AgentLogApi(Resource):
@console_ns.route("/agent/<uuid:agent_id>/skills/upload")
class AgentSkillUploadByAgentApi(Resource):
@console_ns.doc("upload_agent_skill_by_agent")
@console_ns.doc(description="Upload + validate a Skill package for an Agent App")
@console_ns.doc(params={"agent_id": "Agent ID"})
@console_ns.response(201, "Skill validated", console_ns.models[AgentSkillUploadResponse.__name__])
@console_ns.response(400, "Invalid skill package")
@setup_required
@login_required
@account_initialization_required
@with_current_user
@with_current_tenant_id
def post(self, tenant_id: str, current_user: Account, agent_id: UUID):
resolve_agent_app_model(tenant_id=tenant_id, agent_id=agent_id)
return _upload_skill_for_app(current_user=current_user)
@console_ns.route("/apps/<uuid:app_id>/agent/skills/upload")
class AgentSkillUploadApi(Resource):
@console_ns.doc("upload_agent_skill")
@console_ns.doc(description="Upload + validate a Skill package (.zip/.skill) and extract its manifest")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(201, "Skill validated", console_ns.models[AgentSkillUploadResponse.__name__])
@console_ns.response(400, "Invalid skill package")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=_WORKFLOW_AGENT_DRIVE_APP_MODES)
@with_current_user
def post(self, current_user: Account, app_model: App):
"""Validate an uploaded Skill package and persist the archive.
Returns a validated skill ref (to bind into the Agent soul config on save)
plus its manifest. Standardizing into the agent drive is ENG-594.
"""
return _upload_skill_for_app(current_user=current_user)
@console_ns.route("/agent/<uuid:agent_id>/skills/standardize")
class AgentSkillStandardizeByAgentApi(Resource):
@console_ns.doc("standardize_agent_skill_by_agent")
@console_ns.doc(description="Validate + standardize a Skill into an Agent App drive")
@console_ns.doc(params={"agent_id": "Agent ID"})
@console_ns.response(
201,
"Skill standardized into drive",
console_ns.models[AgentSkillStandardizeResponse.__name__],
)
@console_ns.doc(description="Upload + standardize a Skill into an Agent App drive")
@console_ns.doc(consumes=["multipart/form-data"], params={"agent_id": "Agent ID", **_AGENT_SKILL_UPLOAD_PARAMS})
@console_ns.response(201, "Skill uploaded into drive", console_ns.models[AgentSkillUploadResponse.__name__])
@console_ns.response(400, "Invalid skill package or no bound agent")
@setup_required
@login_required
@ -435,19 +373,22 @@ class AgentSkillStandardizeByAgentApi(Resource):
@with_current_tenant_id
def post(self, tenant_id: str, current_user: Account, agent_id: UUID):
app_model = resolve_agent_app_model(tenant_id=tenant_id, agent_id=agent_id)
return _standardize_skill_for_app(current_user=current_user, app_model=app_model)
return _upload_skill_for_app(current_user=current_user, app_model=app_model)
@console_ns.route("/apps/<uuid:app_id>/agent/skills/standardize")
class AgentSkillStandardizeApi(Resource):
@console_ns.doc("standardize_agent_skill")
@console_ns.doc(description="Validate + standardize a Skill into the agent drive (ENG-594)")
@console_ns.doc(params={"app_id": "Application ID", **query_params_from_model(AgentDriveMutationQuery)})
@console_ns.response(
201,
"Skill standardized into drive",
console_ns.models[AgentSkillStandardizeResponse.__name__],
@console_ns.route("/apps/<uuid:app_id>/agent/skills/upload")
class AgentSkillUploadApi(Resource):
@console_ns.doc("upload_agent_skill")
@console_ns.doc(description="Upload + standardize a Skill into the agent drive")
@console_ns.doc(
consumes=["multipart/form-data"],
params={
"app_id": "Application ID",
**query_params_from_model(AgentDriveMutationQuery),
**_AGENT_SKILL_UPLOAD_PARAMS,
},
)
@console_ns.response(201, "Skill uploaded into drive", console_ns.models[AgentSkillUploadResponse.__name__])
@console_ns.response(400, "Invalid skill package or no bound agent")
@setup_required
@login_required
@ -455,8 +396,8 @@ class AgentSkillStandardizeApi(Resource):
@get_app_model(mode=_WORKFLOW_AGENT_DRIVE_APP_MODES)
@with_current_user
def post(self, current_user: Account, app_model: App):
"""Upload a Skill, validate it, and standardize it into the app agent's drive."""
return _standardize_skill_for_app(current_user=current_user, app_model=app_model)
"""Upload a Skill, validate it, and commit drive-backed skill files."""
return _upload_skill_for_app(current_user=current_user, app_model=app_model)
@console_ns.route("/agent/<uuid:agent_id>/files")

View File

@ -402,8 +402,6 @@ class AppPartial(ResponseModel):
bound_agent_id: str | None = None
# For Agent App responses exposed through /agent.
app_id: str | None = None
role: str | None = None
active_config_is_published: bool = False
is_starred: bool = False
@computed_field(return_type=str | None) # type: ignore
@ -457,8 +455,6 @@ class AppDetailWithSite(AppDetail):
bound_agent_id: str | None = None
# For Agent App responses exposed through /agent.
app_id: str | None = None
role: str | None = None
active_config_is_published: bool = False
@computed_field(return_type=str | None) # type: ignore
@property
@ -541,10 +537,7 @@ register_schema_models(
ModelConfig,
Site,
DeletedTool,
AppPartial,
AppDetail,
AppDetailWithSite,
AppPagination,
AppExportResponse,
Segmentation,
PreProcessingRule,
@ -564,6 +557,13 @@ register_schema_models(
LoadBalancingPayload,
)
register_response_schema_models(
console_ns,
AppPartial,
AppDetailWithSite,
AppPagination,
)
@console_ns.route("/apps")
class AppListApi(Resource):
@ -594,7 +594,7 @@ class AppListApi(Resource):
# get app list
app_service = AppService()
app_pagination = app_service.get_paginate_apps(current_user_id, current_tenant_id, params)
app_pagination = app_service.get_paginate_apps(current_user_id, current_tenant_id, params, db.session)
if not app_pagination:
empty = AppPagination(page=args.page, limit=args.limit, total=0, has_more=False, data=[])
return empty.model_dump(mode="json"), 200
@ -661,7 +661,7 @@ class StarredAppListApi(Resource):
is_created_by_me=args.is_created_by_me,
)
app_pagination = AppService().get_paginate_starred_apps(current_user_id, current_tenant_id, params)
app_pagination = AppService().get_paginate_starred_apps(current_user_id, current_tenant_id, params, db.session)
if not app_pagination:
empty = AppPagination(page=args.page, limit=args.limit, total=0, has_more=False, data=[])
return empty.model_dump(mode="json"), 200

View File

@ -1,6 +1,7 @@
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from configs import dify_config
from constants.languages import supported_language
@ -11,7 +12,8 @@ from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.helper import EmailStr, timezone
from models import AccountStatus
from services.account_service import RegisterService
from models.account import TenantAccountJoin, TenantAccountRole
from services.account_service import RegisterService, TenantService
from services.billing_service import BillingService
@ -25,18 +27,22 @@ class ActivatePayload(BaseModel):
workspace_id: str | None = Field(default=None)
email: EmailStr | None = Field(default=None)
token: str
name: str = Field(..., max_length=30)
interface_language: str = Field(...)
timezone: str = Field(...)
name: str | None = Field(default=None, max_length=30)
interface_language: str | None = Field(default=None)
timezone: str | None = Field(default=None)
@field_validator("interface_language")
@classmethod
def validate_lang(cls, value: str) -> str:
def validate_lang(cls, value: str | None) -> str | None:
if value is None:
return None
return supported_language(value)
@field_validator("timezone")
@classmethod
def validate_tz(cls, value: str) -> str:
def validate_tz(cls, value: str | None) -> str | None:
if value is None:
return None
return timezone(value)
@ -48,6 +54,8 @@ class ActivationCheckData(BaseModel):
workspace_name: str | None
workspace_id: str | None
email: str | None
account_status: str | None = None
requires_setup: bool | None = None
class ActivationCheckResponse(BaseModel):
@ -95,9 +103,20 @@ class ActivateCheckApi(Resource):
workspace_name = tenant.name if tenant else None
workspace_id = tenant.id if tenant else None
invitee_email = data.get("email") if data else None
account = invitation.get("account")
account_status = account.status if account else None
requires_setup = data.get("requires_setup")
if requires_setup is None:
requires_setup = account_status == AccountStatus.PENDING
return {
"is_valid": invitation is not None,
"data": {"workspace_name": workspace_name, "workspace_id": workspace_id, "email": invitee_email},
"data": {
"workspace_name": workspace_name,
"workspace_id": workspace_id,
"email": invitee_email,
"account_status": account_status,
"requires_setup": requires_setup,
},
}
else:
return {"is_valid": False}
@ -126,15 +145,45 @@ class ActivateApi(Resource):
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(account.email):
raise AccountInFreezeError()
tenant = invitation["tenant"]
raw_role = invitation["data"].get("role")
try:
role = TenantAccountRole(raw_role) if raw_role else TenantAccountRole.NORMAL
except ValueError:
role = TenantAccountRole.NORMAL
if not TenantAccountRole.is_non_owner_role(role):
role = TenantAccountRole.NORMAL
membership_id = db.session.scalar(
select(TenantAccountJoin.id).where(
TenantAccountJoin.tenant_id == tenant.id,
TenantAccountJoin.account_id == account.id,
)
)
requires_setup = invitation["data"].get("requires_setup")
if requires_setup is None:
requires_setup = account.status == AccountStatus.PENDING
setup_fields: tuple[str, str, str] | None = None
if requires_setup:
if not args.name or not args.interface_language or not args.timezone:
raise AlreadyActivateError()
setup_fields = (args.name, args.interface_language, args.timezone)
RegisterService.revoke_token(args.workspace_id, normalized_request_email, args.token)
account.name = args.name
if membership_id is None:
TenantService.create_tenant_member(tenant, account, str(role))
account.interface_language = args.interface_language
account.timezone = args.timezone
account.interface_theme = "light"
account.status = AccountStatus.ACTIVE
account.initialized_at = naive_utc_now()
db.session.commit()
if setup_fields:
account.name = setup_fields[0]
account.interface_language = setup_fields[1]
account.timezone = setup_fields[2]
account.interface_theme = "light"
account.status = AccountStatus.ACTIVE
account.initialized_at = naive_utc_now()
TenantService.switch_tenant(account, tenant.id)
return {"result": "success"}

View File

@ -409,6 +409,7 @@ class DatasetListApi(Resource):
datasets, total = DatasetService.get_datasets(
query.page,
query.limit,
db.session,
current_tenant_id,
current_user,
query.keyword,

View File

@ -122,7 +122,7 @@ class TagListApi(Resource):
raise Forbidden()
payload = TagBasePayload.model_validate(console_ns.payload or {})
tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=payload.type))
tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=payload.type), db.session)
response = TagResponse.model_validate(
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
@ -146,9 +146,9 @@ class TagUpdateDeleteApi(Resource):
raise Forbidden()
payload = TagUpdateRequestPayload.model_validate(console_ns.payload or {})
tag = TagService.update_tags(UpdateTagPayload(name=payload.name), tag_id_str)
tag = TagService.update_tags(UpdateTagPayload(name=payload.name), tag_id_str, db.session)
binding_count = TagService.get_tag_binding_count(tag_id_str)
binding_count = TagService.get_tag_binding_count(tag_id_str, db.session)
response = TagResponse.model_validate(
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
@ -164,7 +164,7 @@ class TagUpdateDeleteApi(Resource):
def delete(self, tag_id: UUID):
tag_id_str = str(tag_id)
TagService.delete_tag(tag_id_str)
TagService.delete_tag(tag_id_str, db.session)
return "", 204
@ -189,7 +189,8 @@ def _create_tag_bindings(current_user: Account) -> tuple[dict[str, str], int]:
tag_ids=payload.tag_ids,
target_id=payload.target_id,
type=payload.type,
)
),
db.session,
)
return {"result": "success"}, 200
@ -203,7 +204,8 @@ def _remove_tag_bindings(current_user: Account) -> tuple[dict[str, str], int]:
tag_ids=payload.tag_ids,
target_id=payload.target_id,
type=payload.type,
)
),
db.session,
)
return {"result": "success"}, 200

View File

@ -232,7 +232,11 @@ class MemberInviteEmailApi(Resource):
)
except AccountAlreadyInTenantError:
invitation_results.append(
{"status": "success", "email": invitee_email, "url": f"{console_web_url}/signin"}
{
"status": "already_member",
"email": invitee_email,
"message": "Account already in workspace.",
}
)
except Exception as e:
invitation_results.append({"status": "failed", "email": invitee_email, "message": str(e)})

View File

@ -126,6 +126,7 @@ class CustomizedSnippetsApi(Resource):
snippet_service = _snippet_service()
snippets, total, has_more = snippet_service.get_snippets(
tenant_id=current_tenant_id,
session=db.session,
page=query.page,
limit=query.limit,
keyword=query.keyword,

View File

@ -174,7 +174,7 @@ class AppListApi(Resource):
tag_ids: list[str] | None = None
if query.tag:
tags = TagService.get_tag_by_tag_name("app", workspace_id, query.tag)
tags = TagService.get_tag_by_tag_name("app", workspace_id, query.tag, db.session)
if not tags:
return empty
tag_ids = [tag.id for tag in tags]
@ -191,7 +191,7 @@ class AppListApi(Resource):
openapi_visible=True,
)
pagination = AppService().get_paginate_apps(str(auth_data.account_id), workspace_id, params)
pagination = AppService().get_paginate_apps(str(auth_data.account_id), workspace_id, params, db.session)
if pagination is None:
return empty

View File

@ -45,6 +45,13 @@ class AnnotationJobStatusResponse(ResponseModel):
error_msg: str | None = None
ANNOTATION_REPLY_ACTION_PARAM = {
"description": "Action to perform: 'enable' or 'disable'",
"enum": ["enable", "disable"],
"type": "string",
}
register_schema_models(
service_api_ns,
AnnotationCreatePayload,
@ -58,10 +65,22 @@ register_response_schema_models(service_api_ns, AnnotationJobStatusResponse)
@service_api_ns.route("/apps/annotation-reply/<string:action>")
class AnnotationReplyActionApi(Resource):
@service_api_ns.doc(
summary="Configure Annotation Reply",
description=(
"Enables or disables the annotation reply feature. Requires embedding model configuration "
"when enabling. Executes asynchronously — use [Get Annotation Reply Job "
"Status](/api-reference/annotations/get-annotation-reply-job-status) to track progress."
),
tags=["Annotations"],
responses={
200: "Annotation reply settings task initiated.",
},
)
@service_api_ns.expect(service_api_ns.models[AnnotationReplyActionPayload.__name__])
@service_api_ns.doc("annotation_reply_action")
@service_api_ns.doc(description="Enable or disable annotation reply feature")
@service_api_ns.doc(params={"action": "Action to perform: 'enable' or 'disable'"})
@service_api_ns.doc(params={"action": ANNOTATION_REPLY_ACTION_PARAM})
@service_api_ns.doc(
responses={
200: "Action completed successfully",
@ -92,6 +111,18 @@ class AnnotationReplyActionApi(Resource):
@service_api_ns.route("/apps/annotation-reply/<string:action>/status/<uuid:job_id>")
class AnnotationReplyActionStatusApi(Resource):
@service_api_ns.doc(
summary="Get Annotation Reply Job Status",
description=(
"Retrieves the status of an asynchronous annotation reply configuration job started by "
"[Configure Annotation Reply](/api-reference/annotations/configure-annotation-reply)."
),
tags=["Annotations"],
responses={
200: "Successfully retrieved task status.",
400: "`invalid_param` : The specified job does not exist.",
},
)
@service_api_ns.doc("get_annotation_reply_action_status")
@service_api_ns.doc(description="Get the status of an annotation reply action job")
@service_api_ns.doc(params={"action": "Action type", "job_id": "Job ID"})
@ -127,6 +158,14 @@ class AnnotationReplyActionStatusApi(Resource):
@service_api_ns.route("/apps/annotations")
class AnnotationListApi(Resource):
@service_api_ns.doc(
summary="List Annotations",
description="Retrieves a paginated list of annotations for the application. Supports keyword search filtering.",
tags=["Annotations"],
responses={
200: "Successfully retrieved annotation list.",
},
)
@service_api_ns.doc("list_annotations")
@service_api_ns.doc(description="List annotations for the application")
@service_api_ns.doc(params=query_params_from_model(AnnotationListQuery))
@ -159,6 +198,17 @@ class AnnotationListApi(Resource):
)
return response.model_dump(mode="json")
@service_api_ns.doc(
summary="Create Annotation",
description=(
"Creates a new annotation. Annotations provide predefined question-answer pairs that the app "
"can match and return directly instead of generating a response."
),
tags=["Annotations"],
responses={
201: "Annotation created successfully.",
},
)
@service_api_ns.expect(service_api_ns.models[AnnotationCreatePayload.__name__])
@service_api_ns.doc("create_annotation")
@service_api_ns.doc(description="Create a new annotation")
@ -185,6 +235,16 @@ class AnnotationListApi(Resource):
@service_api_ns.route("/apps/annotations/<uuid:annotation_id>")
class AnnotationUpdateDeleteApi(Resource):
@service_api_ns.doc(
summary="Update Annotation",
description="Updates the question and answer of an existing annotation.",
tags=["Annotations"],
responses={
200: "Annotation updated successfully.",
403: "`forbidden` : Insufficient permissions to edit annotations.",
404: "`not_found` : Annotation does not exist.",
},
)
@service_api_ns.expect(service_api_ns.models[AnnotationCreatePayload.__name__])
@service_api_ns.doc("update_annotation")
@service_api_ns.doc(description="Update an existing annotation")
@ -212,6 +272,16 @@ class AnnotationUpdateDeleteApi(Resource):
response = Annotation.model_validate(annotation, from_attributes=True)
return response.model_dump(mode="json")
@service_api_ns.doc(
summary="Delete Annotation",
description="Deletes an annotation and its associated hit history.",
tags=["Annotations"],
responses={
204: "Annotation deleted successfully.",
403: "`forbidden` : Insufficient permissions to edit annotations.",
404: "`not_found` : Annotation does not exist.",
},
)
@service_api_ns.doc("delete_annotation")
@service_api_ns.doc(description="Delete an annotation")
@service_api_ns.doc(params={"annotation_id": "Annotation ID"})

View File

@ -33,6 +33,18 @@ register_response_schema_models(service_api_ns, Parameters, AppMetaResponse, App
class AppParameterApi(Resource):
"""Resource for app variables."""
@service_api_ns.doc(
summary="Get App Parameters",
description=(
"Retrieve the application's input form configuration, including feature switches, input "
"parameter names, types, and default values."
),
tags=["Applications"],
responses={
200: "Application parameters information.",
400: "`app_unavailable` : App unavailable or misconfigured.",
},
)
@service_api_ns.doc("get_app_parameters")
@service_api_ns.doc(description="Retrieve application input parameters and configuration")
@service_api_ns.doc(
@ -71,6 +83,14 @@ class AppParameterApi(Resource):
@service_api_ns.route("/meta")
class AppMetaApi(Resource):
@service_api_ns.doc(
summary="Get App Meta",
description="Retrieve metadata about this application, including tool icons and other configuration details.",
tags=["Applications"],
responses={
200: "Successfully retrieved application meta information.",
},
)
@service_api_ns.doc("get_app_meta")
@service_api_ns.doc(description="Get application metadata")
@service_api_ns.doc(
@ -92,6 +112,14 @@ class AppMetaApi(Resource):
@service_api_ns.route("/info")
class AppInfoApi(Resource):
@service_api_ns.doc(
summary="Get App Info",
description="Retrieve basic information about this application, including name, description, tags, and mode.",
tags=["Applications"],
responses={
200: "Basic information of the application.",
},
)
@service_api_ns.doc("get_app_info")
@service_api_ns.doc(description="Get basic application information")
@service_api_ns.doc(

View File

@ -20,6 +20,7 @@ from controllers.service_api.app.error import (
ProviderQuotaExceededError,
UnsupportedAudioTypeError,
)
from controllers.service_api.schema import binary_response, expect_with_user, multipart_file_params
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from graphon.model_runtime.errors.invoke import InvokeError
@ -39,8 +40,31 @@ register_response_schema_models(service_api_ns, AudioBinaryResponse, AudioTransc
@service_api_ns.route("/audio-to-text")
class AudioApi(Resource):
@service_api_ns.doc(
summary="Convert Audio to Text",
description=(
"Convert audio file to text. Supported MIME types: `audio/mp3`, `audio/mpga`, `audio/m4a`, "
"`audio/wav`, and `audio/amr`. File size limit is `30 MB`."
),
tags=["TTS"],
responses={
200: "Successfully converted audio to text.",
400: (
"- `app_unavailable` : App unavailable or misconfigured.\n"
"- `provider_not_support_speech_to_text` : Model provider does not support speech-to-text.\n"
"- `provider_not_initialize` : No valid model provider credentials found.\n"
"- `provider_quota_exceeded` : Model provider quota exhausted.\n"
"- `model_currently_not_support` : Current model does not support this operation.\n"
"- `completion_request_error` : Speech recognition request failed."
),
413: "`audio_too_large` : Audio file size exceeded the limit.",
415: "`unsupported_audio_type` : Audio type is not allowed.",
500: "`internal_server_error` : Internal server error.",
},
)
@service_api_ns.doc("audio_to_text")
@service_api_ns.doc(description="Convert audio to text using speech-to-text")
@service_api_ns.doc(consumes=["multipart/form-data"], params=multipart_file_params(include_user=True))
@service_api_ns.doc(
responses={
200: "Audio successfully transcribed",
@ -99,7 +123,27 @@ register_schema_model(service_api_ns, TextToAudioPayload)
@service_api_ns.route("/text-to-audio")
class TextApi(Resource):
@service_api_ns.expect(service_api_ns.models[TextToAudioPayload.__name__])
@service_api_ns.doc(
summary="Convert Text to Audio",
description="Convert text to speech.",
tags=["TTS"],
responses={
200: (
"Returns the generated audio. Generator responses are streamed by the service as `audio/mpeg`; "
"otherwise the provider output is returned directly."
),
400: (
"- `app_unavailable` : App unavailable or misconfigured.\n"
"- `provider_not_initialize` : No valid model provider credentials found.\n"
"- `provider_quota_exceeded` : Model provider quota exhausted.\n"
"- `model_currently_not_support` : Current model does not support this operation.\n"
"- `completion_request_error` : Text-to-speech request failed."
),
500: "`internal_server_error` : Internal server error.",
},
)
@expect_with_user(service_api_ns, TextToAudioPayload)
@binary_response(service_api_ns, "audio/mpeg")
@service_api_ns.doc("text_to_audio")
@service_api_ns.doc(description="Convert text to audio using text-to-speech")
@service_api_ns.doc(
@ -110,11 +154,7 @@ class TextApi(Resource):
500: "Internal server error",
}
)
@service_api_ns.response(
200,
"Text successfully converted to audio",
service_api_ns.models[AudioBinaryResponse.__name__],
)
@service_api_ns.response(200, "Text successfully converted to audio")
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
def post(self, app_model: App, end_user: EndUser):
"""Convert text to audio using text-to-speech.

View File

@ -20,6 +20,7 @@ from controllers.service_api.app.error import (
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from controllers.service_api.schema import expect_user_json, expect_with_user, json_or_event_stream_response
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.entities.app_invoke_entities import InvokeFrom
@ -92,7 +93,33 @@ register_response_schema_models(service_api_ns, GeneratedAppResponse, SimpleResu
@service_api_ns.route("/completion-messages")
class CompletionApi(Resource):
@service_api_ns.expect(service_api_ns.models[CompletionRequestPayload.__name__])
@service_api_ns.doc(
summary="Send Completion Message",
description="Send a request to the text generation application.",
tags=["Completions"],
responses={
200: (
"Successful response. The content type and structure depend on the `response_mode` parameter "
"in the request.\n"
"\n"
"- If `response_mode` is `blocking`, returns `application/json` with a `CompletionResponse` "
"object.\n"
"- If `response_mode` is `streaming`, returns `text/event-stream` with a stream of "
"`ChunkCompletionEvent` objects."
),
400: (
"- `app_unavailable` : App unavailable or misconfigured.\n"
"- `provider_not_initialize` : No valid model provider credentials found.\n"
"- `provider_quota_exceeded` : Model provider quota exhausted.\n"
"- `model_currently_not_support` : Current model unavailable.\n"
"- `completion_request_error` : Text generation failed."
),
429: "`too_many_requests` : Too many concurrent requests for this app.",
500: "`internal_server_error` : Internal server error.",
},
)
@expect_with_user(service_api_ns, CompletionRequestPayload)
@json_or_event_stream_response(service_api_ns)
@service_api_ns.doc("create_completion")
@service_api_ns.doc(description="Create a completion for the given prompt")
@service_api_ns.doc(
@ -168,6 +195,15 @@ class CompletionApi(Resource):
@service_api_ns.route("/completion-messages/<string:task_id>/stop")
class CompletionStopApi(Resource):
@service_api_ns.doc(
summary="Stop Completion Message Generation",
description="Stops a completion message generation task. Only supported in `streaming` mode.",
tags=["Completions"],
responses={
400: "`app_unavailable` : App unavailable or misconfigured.",
},
)
@expect_user_json(service_api_ns)
@service_api_ns.doc("stop_completion")
@service_api_ns.doc(description="Stop a running completion task")
@service_api_ns.doc(params={"task_id": "The ID of the task to stop"})
@ -197,7 +233,39 @@ class CompletionStopApi(Resource):
@service_api_ns.route("/chat-messages")
class ChatApi(Resource):
@service_api_ns.expect(service_api_ns.models[ChatRequestPayload.__name__])
@service_api_ns.doc(
summary="Send Chat Message",
description="Send a request to the chat application.",
tags=["Chats", "Chatflows"],
responses={
200: (
"Successful response. The content type and structure depend on the `response_mode` parameter "
"in the request.\n"
"\n"
"- If `response_mode` is `blocking`, returns `application/json` with a "
"`ChatCompletionResponse` object.\n"
"- If `response_mode` is `streaming`, returns `text/event-stream` with a stream of "
"Server-Sent Events."
),
400: (
"- `app_unavailable` : App unavailable or misconfigured.\n"
"- `not_chat_app` : App mode does not match the API route.\n"
"- `conversation_completed` : The conversation has ended.\n"
"- `provider_not_initialize` : No valid model provider credentials found.\n"
"- `provider_quota_exceeded` : Model provider quota exhausted.\n"
"- `model_currently_not_support` : Current model unavailable.\n"
"- `completion_request_error` : Text generation failed."
),
404: "`not_found` : Conversation does not exist.",
429: (
"- `too_many_requests` : Too many concurrent requests for this app.\n"
"- `rate_limit_error` : The upstream model provider rate limit was exceeded."
),
500: "`internal_server_error` : Internal server error.",
},
)
@expect_with_user(service_api_ns, ChatRequestPayload)
@json_or_event_stream_response(service_api_ns)
@service_api_ns.doc("create_chat_message")
@service_api_ns.doc(description="Send a message in a chat conversation")
@service_api_ns.doc(
@ -276,6 +344,15 @@ class ChatApi(Resource):
@service_api_ns.route("/chat-messages/<string:task_id>/stop")
class ChatStopApi(Resource):
@service_api_ns.doc(
summary="Stop Chat Message Generation",
description="Stops a chat message generation task. Only supported in `streaming` mode.",
tags=["Chats", "Chatflows"],
responses={
400: "`not_chat_app` : App mode does not match the API route.",
},
)
@expect_user_json(service_api_ns)
@service_api_ns.doc("stop_chat_message")
@service_api_ns.doc(description="Stop a running chat message generation")
@service_api_ns.doc(params={"task_id": "The ID of the task to stop"})

View File

@ -13,6 +13,7 @@ from controllers.common.controller_schemas import ConversationRenamePayload
from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.schema import expect_user_json, expect_with_user
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
@ -145,6 +146,16 @@ register_response_schema_models(
@service_api_ns.route("/conversations")
class ConversationApi(Resource):
@service_api_ns.doc(
summary="List Conversations",
description="Retrieve the conversation list for the current user, ordered by most recently active.",
tags=["Conversations"],
responses={
200: "Successfully retrieved conversations list.",
400: "`not_chat_app` : App mode does not match the API route.",
404: "`not_found` : Last conversation does not exist (invalid `last_id`).",
},
)
@service_api_ns.doc(params=query_params_from_model(ConversationListQuery))
@service_api_ns.doc("list_conversations")
@service_api_ns.doc(description="List all conversations for the current user")
@ -197,6 +208,17 @@ class ConversationApi(Resource):
@service_api_ns.route("/conversations/<uuid:c_id>")
class ConversationDetailApi(Resource):
@service_api_ns.doc(
summary="Delete Conversation",
description="Delete a conversation.",
tags=["Conversations"],
responses={
204: "Conversation deleted successfully.",
400: "`not_chat_app` : App mode does not match the API route.",
404: "`not_found` : Conversation does not exist.",
},
)
@expect_user_json(service_api_ns)
@service_api_ns.doc("delete_conversation")
@service_api_ns.doc(description="Delete a specific conversation")
@service_api_ns.doc(params={"c_id": "Conversation ID"})
@ -225,7 +247,20 @@ class ConversationDetailApi(Resource):
@service_api_ns.route("/conversations/<uuid:c_id>/name")
class ConversationRenameApi(Resource):
@service_api_ns.expect(service_api_ns.models[ConversationRenamePayload.__name__])
@service_api_ns.doc(
summary="Rename Conversation",
description=(
"Rename a conversation or auto-generate a name. The conversation name is used for display on "
"clients that support multiple conversations."
),
tags=["Conversations"],
responses={
200: "Conversation renamed successfully.",
400: "`not_chat_app` : App mode does not match the API route.",
404: "`not_found` : Conversation does not exist.",
},
)
@expect_with_user(service_api_ns, ConversationRenamePayload)
@service_api_ns.doc("rename_conversation")
@service_api_ns.doc(description="Rename a conversation or auto-generate a name")
@service_api_ns.doc(params={"c_id": "Conversation ID"})
@ -267,6 +302,16 @@ class ConversationRenameApi(Resource):
@service_api_ns.route("/conversations/<uuid:c_id>/variables")
class ConversationVariablesApi(Resource):
@service_api_ns.doc(
summary="List Conversation Variables",
description="Retrieve variables from a specific conversation.",
tags=["Conversations"],
responses={
200: "Successfully retrieved conversation variables.",
400: "`not_chat_app` : App mode does not match the API route.",
404: "`not_found` : Conversation does not exist.",
},
)
@service_api_ns.doc(params=query_params_from_model(ConversationVariablesQuery))
@service_api_ns.doc("list_conversation_variables")
@service_api_ns.doc(description="List all variables for a conversation")
@ -312,7 +357,22 @@ class ConversationVariablesApi(Resource):
@service_api_ns.route("/conversations/<uuid:c_id>/variables/<uuid:variable_id>")
class ConversationVariableDetailApi(Resource):
@service_api_ns.expect(service_api_ns.models[ConversationVariableUpdatePayload.__name__])
@service_api_ns.doc(
summary="Update Conversation Variable",
description="Update the value of a specific conversation variable. The value must match the expected type.",
tags=["Conversations"],
responses={
200: "Variable updated successfully.",
400: (
"- `not_chat_app` : App mode does not match the API route.\n"
"- `bad_request` : Variable value type mismatch."
),
404: (
"- `not_found` : Conversation does not exist.\n- `not_found` : Conversation variable does not exist."
),
},
)
@expect_with_user(service_api_ns, ConversationVariableUpdatePayload)
@service_api_ns.doc("update_conversation_variable")
@service_api_ns.doc(description="Update a conversation variable's value")
@service_api_ns.doc(params={"c_id": "Conversation ID", "variable_id": "Variable ID"})

View File

@ -12,6 +12,7 @@ from controllers.common.errors import (
)
from controllers.common.schema import register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.schema import multipart_file_params
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from extensions.ext_database import db
from fields.file_fields import FileResponse
@ -23,8 +24,27 @@ register_schema_models(service_api_ns, FileResponse)
@service_api_ns.route("/files/upload")
class FileApi(Resource):
@service_api_ns.doc(
summary="Upload File",
description=(
"Upload a file for use when sending messages, enabling multimodal understanding of images, "
"documents, audio, and video. Uploaded files are for use by the current end-user only."
),
tags=["Files"],
responses={
201: "File uploaded successfully.",
400: (
"- `no_file_uploaded` : No file was provided in the request.\n"
"- `too_many_files` : Only one file is allowed per request.\n"
"- `filename_not_exists_error` : The uploaded file has no filename."
),
413: "`file_too_large` : File size exceeded.",
415: "`unsupported_file_type` : File type not allowed.",
},
)
@service_api_ns.doc("upload_file")
@service_api_ns.doc(description="Upload a file for use in conversations")
@service_api_ns.doc(consumes=["multipart/form-data"], params=multipart_file_params(include_user=True))
@service_api_ns.doc(
responses={
201: "File uploaded successfully",

View File

@ -15,6 +15,7 @@ from controllers.service_api.app.error import (
FileAccessDeniedError,
FileNotFoundError,
)
from controllers.service_api.schema import binary_response
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from extensions.ext_database import db
from extensions.ext_storage import storage
@ -30,6 +31,26 @@ class FilePreviewQuery(BaseModel):
register_schema_model(service_api_ns, FilePreviewQuery)
register_response_schema_model(service_api_ns, BinaryFileResponse)
FILE_PREVIEW_RESPONSE_MEDIA_TYPES = [
"application/octet-stream",
"application/pdf",
"audio/aac",
"audio/flac",
"audio/mp4",
"audio/mpeg",
"audio/ogg",
"audio/wav",
"audio/x-m4a",
"image/gif",
"image/jpeg",
"image/png",
"image/webp",
"text/plain",
"video/mp4",
"video/quicktime",
"video/webm",
]
@service_api_ns.route("/files/<uuid:file_id>/preview")
class FilePreviewApi(Resource):
@ -40,7 +61,26 @@ class FilePreviewApi(Resource):
Files can only be accessed if they belong to messages within the requesting app's context.
"""
@service_api_ns.doc(
summary="Download File",
description=(
"Preview or download uploaded files previously uploaded via the [Upload "
"File](/api-reference/files/upload-file) API. Files can only be accessed if they belong to "
"messages within the requesting application."
),
tags=["Files"],
responses={
200: (
"Returns the raw file content. The `Content-Type` header is set to the file's MIME type. If "
"`as_attachment` is `true`, the file is returned as a download with `Content-Disposition: "
"attachment`."
),
403: "`file_access_denied` : Access to the requested file is denied.",
404: "`file_not_found` : The requested file was not found.",
},
)
@service_api_ns.doc(params=query_params_from_model(FilePreviewQuery))
@binary_response(service_api_ns, FILE_PREVIEW_RESPONSE_MEDIA_TYPES)
@service_api_ns.doc("preview_file")
@service_api_ns.doc(description="Preview or download a file uploaded via Service API")
@service_api_ns.doc(params={"file_id": "UUID of the file to preview"})
@ -52,11 +92,7 @@ class FilePreviewApi(Resource):
404: "File not found",
}
)
@service_api_ns.response(
200,
"File retrieved successfully",
service_api_ns.models[BinaryFileResponse.__name__],
)
@service_api_ns.response(200, "File retrieved successfully")
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
def get(self, app_model: App, end_user: EndUser, file_id: UUID):
"""

View File

@ -18,6 +18,7 @@ from werkzeug.exceptions import BadRequest, NotFound
from controllers.common.human_input import HumanInputFormSubmitPayload, stringify_form_default_values
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.schema import expect_with_user
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
from extensions.ext_database import db
@ -72,6 +73,23 @@ def _ensure_form_is_allowed_for_service_api(form: Form) -> None:
@service_api_ns.route("/form/human_input/<string:form_token>")
class WorkflowHumanInputFormApi(Resource):
@service_api_ns.doc(
summary="Get Human Input Form",
description=(
"Retrieve a paused Human Input form's contents using the `form_token` from a "
"`human_input_required` event. Requires **WebApp** delivery."
),
tags=["Human Input"],
responses={
200: "Form contents retrieved successfully.",
404: "`not_found` : Form not found.",
412: (
"- `human_input_form_submitted` : Form already submitted. Forms are one-shot; the first "
"response wins regardless of which user submits it.\n"
"- `human_input_form_expired` : The form's expiration time passed before submission arrived."
),
},
)
@service_api_ns.doc("get_human_input_form")
@service_api_ns.doc(description="Get a paused human input form by token")
@service_api_ns.doc(params={"form_token": "Human input form token"})
@ -101,7 +119,29 @@ class WorkflowHumanInputFormApi(Resource):
inputs = service.resolve_form_inputs(form)
return _jsonify_form_definition(form, inputs=inputs)
@service_api_ns.expect(service_api_ns.models[HumanInputFormSubmitPayload.__name__])
@service_api_ns.doc(
summary="Submit Human Input Form",
description=(
"Submit the recipient's response to a paused Human Input form. The workflow resumes on "
"acceptance; use [Stream Workflow Events](/api-reference/chatflows/stream-workflow-events) "
"to follow subsequent events. Requires **WebApp** delivery."
),
tags=["Human Input"],
responses={
200: "Form submitted successfully. The response body is an empty object.",
400: (
"- `bad_request` : Form recipient type is invalid.\n"
"- `invalid_form_data` : Submission failed validation against the form definition."
),
404: "`not_found` : Form not found.",
412: (
"- `human_input_form_submitted` : Form already submitted. Forms are one-shot; the first "
"response wins regardless of which user submits it.\n"
"- `human_input_form_expired` : The form's expiration time passed before submission arrived."
),
},
)
@expect_with_user(service_api_ns, HumanInputFormSubmitPayload)
@service_api_ns.doc("submit_human_input_form")
@service_api_ns.doc(description="Submit a paused human input form by token")
@service_api_ns.doc(params={"form_token": "Human input form token"})

View File

@ -12,6 +12,7 @@ from controllers.common.fields import SimpleResultStringListResponse
from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.schema import expect_with_user
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.app.entities.app_invoke_entities import InvokeFrom
from fields.base import ResponseModel
@ -64,6 +65,19 @@ register_response_schema_models(
@service_api_ns.route("/messages")
class MessageListApi(Resource):
@service_api_ns.doc(
summary="List Conversation Messages",
description=(
"Returns historical chat records in a scrolling load format, with the first page returning "
"the latest `limit` messages, i.e., in reverse order."
),
tags=["Conversations"],
responses={
200: "Successfully retrieved conversation history.",
400: "`not_chat_app` : App mode does not match the API route.",
404: ("- `not_found` : Conversation does not exist.\n- `not_found` : First message does not exist."),
},
)
@service_api_ns.doc(params=query_params_from_model(MessageListQuery))
@service_api_ns.doc("list_messages")
@service_api_ns.doc(description="List messages in a conversation")
@ -112,7 +126,19 @@ class MessageListApi(Resource):
@service_api_ns.route("/messages/<uuid:message_id>/feedbacks")
class MessageFeedbackApi(Resource):
@service_api_ns.expect(service_api_ns.models[MessageFeedbackPayload.__name__])
@service_api_ns.doc(
summary="Submit Message Feedback",
description=(
"Submit feedback for a message. End users can rate messages as `like` or `dislike`, and "
"optionally provide text feedback. Pass `null` for `rating` to revoke previously submitted "
"feedback."
),
tags=["Feedback"],
responses={
404: "`not_found` : Message does not exist.",
},
)
@expect_with_user(service_api_ns, MessageFeedbackPayload)
@service_api_ns.response(200, "Feedback submitted successfully", service_api_ns.models[ResultResponse.__name__])
@service_api_ns.doc("create_message_feedback")
@service_api_ns.doc(description="Submit feedback for a message")
@ -150,6 +176,17 @@ class MessageFeedbackApi(Resource):
@service_api_ns.route("/app/feedbacks")
class AppGetFeedbacksApi(Resource):
@service_api_ns.doc(
summary="List App Feedbacks",
description=(
"Retrieve a paginated list of all feedback submitted for messages in this application, "
"including both end-user and admin feedback."
),
tags=["Feedback"],
responses={
200: "A list of application feedbacks.",
},
)
@service_api_ns.doc(params=query_params_from_model(FeedbackListQuery))
@service_api_ns.doc("get_app_feedbacks")
@service_api_ns.doc(description="Get all feedbacks for the application")
@ -177,6 +214,20 @@ class AppGetFeedbacksApi(Resource):
@service_api_ns.route("/messages/<uuid:message_id>/suggested")
class MessageSuggestedApi(Resource):
@service_api_ns.doc(
summary="Get Next Suggested Questions",
description="Get next questions suggestions for the current message.",
tags=["Chats", "Chatflows"],
responses={
200: "Successfully retrieved suggested questions.",
400: (
"- `not_chat_app` : App mode does not match the API route.\n"
"- `bad_request` : Suggested questions feature is disabled."
),
404: "`not_found` : Message does not exist.",
500: "`internal_server_error` : Internal server error.",
},
)
@service_api_ns.response(
200,
"Suggested questions retrieved successfully",

View File

@ -17,6 +17,18 @@ register_response_schema_models(service_api_ns, SiteResponse)
class AppSiteApi(Resource):
"""Resource for app sites."""
@service_api_ns.doc(
summary="Get App WebApp Settings",
description=(
"Retrieve the WebApp settings of this application, including site configuration, theme, and "
"customization options."
),
tags=["Applications"],
responses={
200: "WebApp settings of the application.",
403: "`forbidden` : Site not found for this application or the workspace has been archived.",
},
)
@service_api_ns.doc("get_app_site")
@service_api_ns.doc(description="Get application site configuration")
@service_api_ns.doc(

View File

@ -21,6 +21,11 @@ from controllers.service_api.app.error import (
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from controllers.service_api.schema import (
expect_user_json,
expect_with_user,
json_or_event_stream_response,
)
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.apps.base_app_queue_manager import AppQueueManager
@ -177,14 +182,15 @@ register_response_schema_models(
def _serialize_workflow_run(workflow_run: WorkflowRun) -> dict:
status = _enum_value(workflow_run.status)
raw_outputs = workflow_run.outputs_dict
if status == WorkflowExecutionStatus.PAUSED.value or raw_outputs is None:
outputs: dict = {}
elif isinstance(raw_outputs, dict):
outputs = raw_outputs
elif isinstance(raw_outputs, Mapping):
outputs = dict(raw_outputs)
else:
outputs = {}
match raw_outputs:
case _ if status == WorkflowExecutionStatus.PAUSED.value or raw_outputs is None:
outputs: dict = {}
case dict():
outputs = raw_outputs
case _ if isinstance(raw_outputs, Mapping):
outputs = dict(raw_outputs)
case _:
outputs = {}
return WorkflowRunResponse.model_validate(
{
"id": workflow_run.id,
@ -208,6 +214,16 @@ def _serialize_workflow_log_pagination(pagination) -> dict:
@service_api_ns.route("/workflows/run/<string:workflow_run_id>")
class WorkflowRunDetailApi(Resource):
@service_api_ns.doc(
summary="Get Workflow Run Detail",
description="Retrieve the current execution results of a workflow task based on the workflow execution ID.",
tags=["Chatflows", "Workflows"],
responses={
200: "Successfully retrieved workflow run details.",
400: "`not_workflow_app` : App mode does not match the API route.",
404: "`not_found` : Workflow run not found.",
},
)
@service_api_ns.doc("get_workflow_run_detail")
@service_api_ns.doc(description="Get workflow run details")
@service_api_ns.doc(params={"workflow_run_id": "Workflow run ID"})
@ -249,7 +265,37 @@ class WorkflowRunDetailApi(Resource):
@service_api_ns.route("/workflows/run")
class WorkflowRunApi(Resource):
@service_api_ns.expect(service_api_ns.models[WorkflowRunPayload.__name__])
@service_api_ns.doc(
summary="Run Workflow",
description="Execute a workflow. Cannot be executed without a published workflow.",
tags=["Workflows"],
responses={
200: (
"Successful response. The content type and structure depend on the `response_mode` parameter "
"in the request.\n"
"\n"
"- If `response_mode` is `blocking`, returns `application/json` with a "
"`WorkflowBlockingResponse` object.\n"
"- If `response_mode` is `streaming`, returns `text/event-stream` with a stream of "
"`ChunkWorkflowEvent` objects."
),
400: (
"- `not_workflow_app` : App mode does not match the API route.\n"
"- `provider_not_initialize` : No valid model provider credentials found.\n"
"- `provider_quota_exceeded` : Model provider quota exhausted.\n"
"- `model_currently_not_support` : Current model unavailable.\n"
"- `completion_request_error` : Workflow execution request failed.\n"
"- `invalid_param` : Invalid parameter value."
),
429: (
"- `too_many_requests` : Too many concurrent requests for this app.\n"
"- `rate_limit_error` : The upstream model provider rate limit was exceeded."
),
500: "`internal_server_error` : Internal server error.",
},
)
@expect_with_user(service_api_ns, WorkflowRunPayload)
@json_or_event_stream_response(service_api_ns)
@service_api_ns.doc("run_workflow")
@service_api_ns.doc(description="Execute a workflow")
@service_api_ns.doc(
@ -313,7 +359,42 @@ class WorkflowRunApi(Resource):
@service_api_ns.route("/workflows/<string:workflow_id>/run")
class WorkflowRunByIdApi(Resource):
@service_api_ns.expect(service_api_ns.models[WorkflowRunPayload.__name__])
@service_api_ns.doc(
summary="Run Workflow by ID",
description=(
"Execute a specific workflow version identified by its ID. Useful for running a particular "
"published version of the workflow."
),
tags=["Workflows"],
responses={
200: (
"Successful response. The content type and structure depend on the `response_mode` parameter "
"in the request.\n"
"\n"
"- If `response_mode` is `blocking`, returns `application/json` with a "
"`WorkflowBlockingResponse` object.\n"
"- If `response_mode` is `streaming`, returns `text/event-stream` with a stream of "
"`ChunkWorkflowEvent` objects."
),
400: (
"- `not_workflow_app` : App mode does not match the API route.\n"
"- `bad_request` : Workflow is a draft or has an invalid ID format.\n"
"- `provider_not_initialize` : No valid model provider credentials found.\n"
"- `provider_quota_exceeded` : Model provider quota exhausted.\n"
"- `model_currently_not_support` : Current model unavailable.\n"
"- `completion_request_error` : Workflow execution request failed.\n"
"- `invalid_param` : Required parameter missing or invalid."
),
404: "`not_found` : Workflow not found.",
429: (
"- `too_many_requests` : Too many concurrent requests for this app.\n"
"- `rate_limit_error` : The upstream model provider rate limit was exceeded."
),
500: "`internal_server_error` : Internal server error.",
},
)
@expect_with_user(service_api_ns, WorkflowRunPayload)
@json_or_event_stream_response(service_api_ns)
@service_api_ns.doc("run_workflow_by_id")
@service_api_ns.doc(description="Execute a specific workflow by ID")
@service_api_ns.doc(params={"workflow_id": "Workflow ID to execute"})
@ -387,6 +468,18 @@ class WorkflowRunByIdApi(Resource):
@service_api_ns.route("/workflows/tasks/<string:task_id>/stop")
class WorkflowTaskStopApi(Resource):
@service_api_ns.doc(
summary="Stop Workflow Task",
description="Stop a running workflow task. Only supported in `streaming` mode.",
tags=["Workflows"],
responses={
400: (
"- `not_workflow_app` : App mode does not match the API route.\n"
"- `invalid_param` : Required parameter missing or invalid."
),
},
)
@expect_user_json(service_api_ns)
@service_api_ns.doc("stop_workflow_task")
@service_api_ns.doc(description="Stop a running workflow task")
@service_api_ns.doc(params={"task_id": "Task ID to stop"})
@ -417,6 +510,14 @@ class WorkflowTaskStopApi(Resource):
@service_api_ns.route("/workflows/logs")
class WorkflowAppLogApi(Resource):
@service_api_ns.doc(
summary="List Workflow Logs",
description="Retrieve paginated workflow execution logs with filtering options.",
tags=["Chatflows", "Workflows"],
responses={
200: "Successfully retrieved workflow logs.",
},
)
@service_api_ns.doc(params=query_params_from_model(WorkflowLogQuery))
@service_api_ns.doc("get_workflow_logs")
@service_api_ns.doc(description="Get workflow execution logs")

View File

@ -15,6 +15,7 @@ from controllers.common.fields import EventStreamResponse
from controllers.common.schema import query_params_from_model, register_response_schema_model, register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import NotWorkflowAppError
from controllers.service_api.schema import event_stream_response
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.apps.base_app_generator import BaseAppGenerator
@ -44,6 +45,24 @@ register_response_schema_model(service_api_ns, EventStreamResponse)
class WorkflowEventsApi(Resource):
"""Service API for getting workflow execution events after resume."""
@service_api_ns.doc(
summary="Stream Workflow Events",
description=(
"Resume the Server-Sent Events stream for a workflow run after a pause or a dropped SSE "
"connection. For runs that have already finished, the stream emits a single "
"`workflow_finished` event and closes."
),
tags=["Chatflows", "Workflows"],
responses={
200: (
"Server-Sent Events stream. Each event is delivered as `data: {JSON}\\n\\n`. Event payloads "
"follow the same schemas as the original streaming response."
),
400: "`not_workflow_app` : Please check if your app mode matches the right API route.",
404: "`not_found` : Workflow run not found.",
},
)
@event_stream_response(service_api_ns)
@service_api_ns.doc("get_workflow_events")
@service_api_ns.doc(description="Get workflow execution events stream after resume")
@service_api_ns.doc(params={"task_id": "Workflow run ID"})

View File

@ -1,8 +1,8 @@
from typing import Any, Literal
from typing import Any, Literal, override
from uuid import UUID
from flask import request
from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator
from pydantic import BaseModel, ConfigDict, Field, GetJsonSchemaHandler, RootModel, field_validator, model_validator
from werkzeug.exceptions import Forbidden, NotFound
import services
@ -79,6 +79,13 @@ class DocumentStatusPayload(BaseModel):
document_ids: list[str] = Field(default_factory=list, description="Document IDs to update")
DOCUMENT_STATUS_ACTION_PARAM = {
"description": "Action to perform: 'enable', 'disable', 'archive', or 'un_archive'",
"enum": ["enable", "disable", "archive", "un_archive"],
"type": "string",
}
class TagNamePayload(BaseModel):
name: str = Field(..., min_length=1, max_length=50)
@ -114,6 +121,45 @@ class TagUnbindingPayload(BaseModel):
tag_id: str | None = None
target_id: str
@classmethod
@override
def __get_pydantic_json_schema__(cls, _core_schema: object, _handler: GetJsonSchemaHandler) -> dict[str, object]:
tag_id_property = {
"description": "Legacy single tag ID accepted by the Service API.",
"type": "string",
}
tag_ids_property = {
"description": "Tag IDs to unbind. Use this for new integrations.",
"items": {"type": "string"},
"minItems": 1,
"type": "array",
}
target_id_property = {"title": "Target Id", "type": "string"}
return {
"anyOf": [
{
"properties": {
"tag_id": tag_id_property,
"tag_ids": tag_ids_property,
"target_id": target_id_property,
},
"required": ["tag_id", "target_id"],
"type": "object",
},
{
"properties": {
"tag_id": {**tag_id_property, "nullable": True},
"tag_ids": tag_ids_property,
"target_id": target_id_property,
},
"required": ["tag_ids", "target_id"],
"type": "object",
},
],
"description": "Accepts either the legacy tag_id payload or the normalized tag_ids payload.",
"title": cls.__name__,
}
@model_validator(mode="before")
@classmethod
def normalize_legacy_tag_id(cls, data: object) -> object:
@ -204,6 +250,14 @@ register_response_schema_models(
class DatasetListApi(DatasetApiResource):
"""Resource for datasets."""
@service_api_ns.doc(
summary="List Knowledge Bases",
description="Returns a paginated list of knowledge bases. Supports filtering by keyword and tags.",
tags=["Knowledge Bases"],
responses={
200: "List of knowledge bases.",
},
)
@service_api_ns.doc("list_datasets")
@service_api_ns.doc(description="List all datasets")
@service_api_ns.doc(
@ -262,6 +316,19 @@ class DatasetListApi(DatasetApiResource):
}
return dump_response(DatasetListResponse, response), 200
@service_api_ns.doc(
summary="Create an Empty Knowledge Base",
description=(
"Create a new empty knowledge base. After creation, use [Create Document by "
"Text](/api-reference/documents/create-document-by-text) or [Create Document by "
"File](/api-reference/documents/create-document-by-file) to add documents."
),
tags=["Knowledge Bases"],
responses={
200: "Knowledge base created successfully.",
409: "`dataset_name_duplicate` : The dataset name already exists. Please modify your dataset name.",
},
)
@service_api_ns.expect(service_api_ns.models[DatasetCreatePayload.__name__])
@service_api_ns.doc("create_dataset")
@service_api_ns.doc(description="Create a new dataset")
@ -327,6 +394,19 @@ class DatasetListApi(DatasetApiResource):
class DatasetApi(DatasetApiResource):
"""Resource for dataset."""
@service_api_ns.doc(
summary="Get Knowledge Base",
description=(
"Retrieve detailed information about a specific knowledge base, including its embedding "
"model, retrieval configuration, and document statistics."
),
tags=["Knowledge Bases"],
responses={
200: "Knowledge base details.",
403: "`forbidden` : Insufficient permissions to access this knowledge base.",
404: "`not_found` : Dataset not found.",
},
)
@service_api_ns.doc("get_dataset")
@service_api_ns.doc(description="Get a specific dataset by ID")
@service_api_ns.doc(params={"dataset_id": "Dataset ID"})
@ -392,6 +472,19 @@ class DatasetApi(DatasetApiResource):
200,
)
@service_api_ns.doc(
summary="Update Knowledge Base",
description=(
"Update the name, description, permissions, or retrieval settings of an existing knowledge "
"base. Only the fields provided in the request body are updated."
),
tags=["Knowledge Bases"],
responses={
200: "Knowledge base updated successfully.",
403: "`forbidden` : Insufficient permissions to access this knowledge base.",
404: "`not_found` : Dataset not found.",
},
)
@service_api_ns.expect(service_api_ns.models[DatasetUpdatePayload.__name__])
@service_api_ns.doc("update_dataset")
@service_api_ns.doc(description="Update an existing dataset")
@ -474,6 +567,22 @@ class DatasetApi(DatasetApiResource):
return DatasetDetailWithPartialMembersResponse.model_validate(result_data).model_dump(mode="json"), 200
@service_api_ns.doc(
summary="Delete Knowledge Base",
description=(
"Permanently delete a knowledge base and all its documents. The knowledge base must not be "
"in use by any application."
),
tags=["Knowledge Bases"],
responses={
204: "Success.",
404: "`not_found` : Dataset not found.",
409: (
"`dataset_in_use` : The knowledge base is being used by some apps. Please remove it from the "
"apps before deleting."
),
},
)
@service_api_ns.doc("delete_dataset")
@service_api_ns.doc(description="Delete a dataset")
@service_api_ns.doc(params={"dataset_id": "Dataset ID"})
@ -519,6 +628,17 @@ class DatasetApi(DatasetApiResource):
class DocumentStatusApi(DatasetApiResource):
"""Resource for batch document status operations."""
@service_api_ns.doc(
summary="Update Document Status in Batch",
description="Enable, disable, archive, or unarchive multiple documents at once.",
tags=["Documents"],
responses={
200: "Documents updated successfully.",
400: "`invalid_action` : Invalid action.",
403: "`forbidden` : Insufficient permissions.",
404: "`not_found` : Knowledge base not found.",
},
)
@service_api_ns.response(
200,
"Document status updated successfully",
@ -529,7 +649,7 @@ class DocumentStatusApi(DatasetApiResource):
@service_api_ns.doc(
params={
"dataset_id": "Dataset ID",
"action": "Action to perform: 'enable', 'disable', 'archive', or 'un_archive'",
"action": DOCUMENT_STATUS_ACTION_PARAM,
}
)
@service_api_ns.doc(
@ -591,6 +711,14 @@ class DocumentStatusApi(DatasetApiResource):
@service_api_ns.route("/datasets/tags")
class DatasetTagsApi(DatasetApiResource):
@service_api_ns.doc(
summary="List Knowledge Tags",
description="Returns the list of all knowledge base tags in the workspace.",
tags=["Tags"],
responses={
200: "List of tags.",
},
)
@service_api_ns.doc("list_dataset_tags")
@service_api_ns.doc(description="Get all knowledge type tags")
@service_api_ns.doc(
@ -612,6 +740,14 @@ class DatasetTagsApi(DatasetApiResource):
tags = TagService.get_tags(db.session(), "knowledge", cid)
return dump_response(KnowledgeTagListResponse, tags), 200
@service_api_ns.doc(
summary="Create Knowledge Tag",
description="Create a new tag for organizing knowledge bases.",
tags=["Tags"],
responses={
200: "Tag created successfully.",
},
)
@service_api_ns.expect(service_api_ns.models[TagCreatePayload.__name__])
@service_api_ns.doc("create_dataset_tag")
@service_api_ns.doc(description="Add a knowledge type tag")
@ -634,7 +770,7 @@ class DatasetTagsApi(DatasetApiResource):
raise Forbidden()
payload = TagCreatePayload.model_validate(service_api_ns.payload or {})
tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=TagType.KNOWLEDGE))
tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=TagType.KNOWLEDGE), db.session)
response = dump_response(
KnowledgeTagResponse,
@ -642,6 +778,14 @@ class DatasetTagsApi(DatasetApiResource):
)
return response, 200
@service_api_ns.doc(
summary="Update Knowledge Tag",
description="Rename an existing knowledge base tag.",
tags=["Tags"],
responses={
200: "Tag updated successfully.",
},
)
@service_api_ns.expect(service_api_ns.models[TagUpdatePayload.__name__])
@service_api_ns.doc("update_dataset_tag")
@service_api_ns.doc(description="Update a knowledge type tag")
@ -664,9 +808,9 @@ class DatasetTagsApi(DatasetApiResource):
payload = TagUpdatePayload.model_validate(service_api_ns.payload or {})
tag_id = payload.tag_id
tag = TagService.update_tags(UpdateTagServicePayload(name=payload.name), tag_id)
tag = TagService.update_tags(UpdateTagServicePayload(name=payload.name), tag_id, db.session)
binding_count = TagService.get_tag_binding_count(tag_id)
binding_count = TagService.get_tag_binding_count(tag_id, db.session)
response = dump_response(
KnowledgeTagResponse,
@ -674,6 +818,14 @@ class DatasetTagsApi(DatasetApiResource):
)
return response, 200
@service_api_ns.doc(
summary="Delete Knowledge Tag",
description="Permanently delete a knowledge base tag. Does not delete the knowledge bases that were tagged.",
tags=["Tags"],
responses={
204: "Success.",
},
)
@service_api_ns.expect(service_api_ns.models[TagDeletePayload.__name__])
@service_api_ns.doc("delete_dataset_tag")
@service_api_ns.doc(description="Delete a knowledge type tag")
@ -688,13 +840,21 @@ class DatasetTagsApi(DatasetApiResource):
def delete(self, _):
"""Delete a knowledge type tag."""
payload = TagDeletePayload.model_validate(service_api_ns.payload or {})
TagService.delete_tag(payload.tag_id)
TagService.delete_tag(payload.tag_id, db.session)
return "", 204
@service_api_ns.route("/datasets/tags/binding")
class DatasetTagBindingApi(DatasetApiResource):
@service_api_ns.doc(
summary="Create Tag Binding",
description="Bind one or more tags to a knowledge base. A knowledge base can have multiple tags.",
tags=["Tags"],
responses={
204: "Success.",
},
)
@service_api_ns.expect(service_api_ns.models[TagBindingPayload.__name__])
@service_api_ns.doc("bind_dataset_tags")
@service_api_ns.doc(description="Bind tags to a dataset")
@ -713,7 +873,8 @@ class DatasetTagBindingApi(DatasetApiResource):
payload = TagBindingPayload.model_validate(service_api_ns.payload or {})
TagService.save_tag_binding(
TagBindingCreatePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=TagType.KNOWLEDGE)
TagBindingCreatePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=TagType.KNOWLEDGE),
db.session,
)
return "", 204
@ -721,6 +882,14 @@ class DatasetTagBindingApi(DatasetApiResource):
@service_api_ns.route("/datasets/tags/unbinding")
class DatasetTagUnbindingApi(DatasetApiResource):
@service_api_ns.doc(
summary="Delete Tag Binding",
description="Remove one or more tags from a knowledge base.",
tags=["Tags"],
responses={
204: "Success.",
},
)
@service_api_ns.expect(service_api_ns.models[TagUnbindingPayload.__name__])
@service_api_ns.doc("unbind_dataset_tags")
@service_api_ns.doc(description="Unbind tags from a dataset")
@ -739,7 +908,8 @@ class DatasetTagUnbindingApi(DatasetApiResource):
payload = TagUnbindingPayload.model_validate(service_api_ns.payload or {})
TagService.delete_tag_binding(
TagBindingDeletePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=TagType.KNOWLEDGE)
TagBindingDeletePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=TagType.KNOWLEDGE),
db.session,
)
return "", 204
@ -747,6 +917,14 @@ class DatasetTagUnbindingApi(DatasetApiResource):
@service_api_ns.route("/datasets/<uuid:dataset_id>/tags")
class DatasetTagsBindingStatusApi(DatasetApiResource):
@service_api_ns.doc(
summary="Get Knowledge Base Tags",
description="Returns the list of tags bound to a specific knowledge base.",
tags=["Tags"],
responses={
200: "Tags bound to the knowledge base.",
},
)
@service_api_ns.doc("get_dataset_tags_binding_status")
@service_api_ns.doc(description="Get tags bound to a specific dataset")
@service_api_ns.doc(params={"dataset_id": "Dataset ID"})
@ -766,6 +944,8 @@ class DatasetTagsBindingStatusApi(DatasetApiResource):
dataset_id = kwargs.get("dataset_id")
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
tags = TagService.get_tags_by_target_id("knowledge", current_user.current_tenant_id, str(dataset_id))
tags = TagService.get_tags_by_target_id(
"knowledge", current_user.current_tenant_id, str(dataset_id), db.session
)
tags_list = [{"id": tag.id, "name": tag.name} for tag in tags]
return dump_response(DatasetBoundTagListResponse, {"data": tags_list, "total": len(tags)}), 200

View File

@ -8,11 +8,12 @@ deprecated in generated API docs so clients migrate toward the canonical paths.
import json
from collections.abc import Mapping
from contextlib import ExitStack
from typing import Any, Literal, Self
from copy import deepcopy
from typing import Any, Literal, Self, override
from uuid import UUID
from flask import request, send_file
from pydantic import BaseModel, Field, field_validator, model_validator
from pydantic import BaseModel, Field, GetJsonSchemaHandler, field_validator, model_validator
from sqlalchemy import desc, func, select
from werkzeug.exceptions import Forbidden, NotFound
@ -39,6 +40,7 @@ from controllers.service_api.dataset.error import (
DocumentIndexingError,
InvalidMetadataError,
)
from controllers.service_api.schema import binary_response
from controllers.service_api.wraps import (
DatasetApiResource,
cloud_edition_billing_rate_limit_check,
@ -104,6 +106,36 @@ class DocumentTextUpdate(BaseModel):
raise ValueError("Invalid doc_form.")
return value
@classmethod
@override
def __get_pydantic_json_schema__(cls, core_schema: Any, handler: GetJsonSchemaHandler) -> dict[str, Any]:
schema = handler.resolve_ref_schema(handler(core_schema))
properties = schema.get("properties")
if not isinstance(properties, dict):
return schema
text_branch_properties = deepcopy(properties)
text_branch_properties["text"] = _non_null_property_schema(properties.get("text"))
text_branch_properties["name"] = _non_null_property_schema(properties.get("name"))
no_text_branch_properties = deepcopy(properties)
no_text_branch_properties["text"] = {"type": "null"}
return {
**schema,
"anyOf": [
{
"properties": text_branch_properties,
"required": ["name", "text"],
"type": "object",
},
{
"properties": no_text_branch_properties,
"type": "object",
},
],
}
@model_validator(mode="after")
def check_text_and_name(self) -> Self:
if self.text is not None and self.name is None:
@ -111,6 +143,24 @@ class DocumentTextUpdate(BaseModel):
return self
def _non_null_property_schema(property_schema: object) -> dict[str, Any]:
if not isinstance(property_schema, dict):
return {}
any_of = property_schema.get("anyOf")
if isinstance(any_of, list):
non_null_candidates = [
candidate for candidate in any_of if isinstance(candidate, dict) and candidate.get("type") != "null"
]
if len(non_null_candidates) == 1:
return {
**{key: value for key, value in property_schema.items() if key != "anyOf"},
**deepcopy(non_null_candidates[0]),
}
return deepcopy(property_schema)
class DocumentListQuery(BaseModel):
page: int = Field(default=1, description="Page number")
limit: int = Field(default=20, description="Number of items per page")
@ -351,6 +401,24 @@ def _update_document_by_text(tenant_id: str, dataset_id: UUID, document_id: UUID
class DocumentAddByTextApi(DatasetApiResource):
"""Resource for the canonical text document creation route."""
@service_api_ns.doc(
summary="Create Document by Text",
description=(
"Create a document from raw text content. The document is processed asynchronously — use the "
"returned `batch` ID with [Get Document Indexing Status](/api-reference/documents/"
"get-document-indexing-status) to track progress."
),
tags=["Documents"],
responses={
200: "Document created successfully.",
400: (
"- `provider_not_initialize` : No valid model provider credentials found. Please go to "
"Settings -> Model Provider to complete your provider credentials.\n"
"- `invalid_param` : Knowledge base does not exist. / indexing_technique is required. / "
"Invalid doc_form (must be `text_model`, `hierarchical_model`, or `qa_model`)."
),
},
)
@service_api_ns.expect(service_api_ns.models[DocumentTextCreatePayload.__name__])
@service_api_ns.doc("create_document_by_text")
@service_api_ns.doc(description="Create a new document by providing text content")
@ -409,6 +477,25 @@ class DeprecatedDocumentAddByTextApi(DatasetApiResource):
class DocumentUpdateByTextApi(DatasetApiResource):
"""Resource for the canonical text document update route."""
@service_api_ns.doc(
summary="Update Document by Text",
description=(
"Update an existing document's text content, name, or processing configuration. Re-triggers "
"indexing if content changes — use the returned `batch` ID with [Get Document Indexing "
"Status](/api-reference/documents/get-document-indexing-status) to track progress."
),
tags=["Documents"],
responses={
200: "Document updated successfully.",
400: (
"- `provider_not_initialize` : No valid model provider credentials found. Please go to "
"Settings -> Model Provider to complete your provider credentials.\n"
"- `invalid_param` : Knowledge base does not exist, name is required when text is "
"provided, or invalid doc_form (must be `text_model`, `hierarchical_model`, or "
"`qa_model`)."
),
},
)
@service_api_ns.expect(service_api_ns.models[DocumentTextUpdate.__name__])
@service_api_ns.doc("update_document_by_text")
@service_api_ns.doc(description="Update an existing document by providing text content")
@ -463,11 +550,42 @@ class DeprecatedDocumentUpdateByTextApi(DatasetApiResource):
@service_api_ns.route(
"/datasets/<uuid:dataset_id>/document/create_by_file",
"/datasets/<uuid:dataset_id>/document/create-by-file",
doc={
"post": {
"deprecated": True,
"description": (
"Deprecated legacy alias for creating a new document by uploading a file. "
"Use /datasets/{dataset_id}/document/create-by-file instead."
),
}
},
)
@service_api_ns.route("/datasets/<uuid:dataset_id>/document/create-by-file")
class DocumentAddByFileApi(DatasetApiResource):
"""Resource for documents."""
@service_api_ns.doc(
summary="Create Document by File",
description=(
"Create a document by uploading a file. Supports common document formats (PDF, TXT, DOCX, "
"etc.). Processing is asynchronous — use the returned `batch` ID with [Get Document "
"Indexing Status](/api-reference/documents/get-document-indexing-status) to track progress."
),
tags=["Documents"],
responses={
200: "Document created successfully.",
400: (
"- `no_file_uploaded` : Please upload your file.\n"
"- `too_many_files` : Only one file is allowed.\n"
"- `filename_not_exists_error` : The specified filename does not exist.\n"
"- `provider_not_initialize` : No valid model provider credentials found. Please go to "
"Settings -> Model Provider to complete your provider credentials.\n"
"- `invalid_param` : Knowledge base does not exist, external datasets not supported, "
"file too large, unsupported file type, missing required fields, or invalid doc_form "
"(must be `text_model`, `hierarchical_model`, or `qa_model`)."
),
},
)
@service_api_ns.doc("create_document_by_file")
@service_api_ns.doc(description="Create a new document by uploading a file")
@service_api_ns.doc(consumes=["multipart/form-data"], params=DOCUMENT_CREATE_BY_FILE_PARAMS)
@ -658,6 +776,27 @@ def _update_document_by_file(tenant_id: str, dataset_id: UUID, document_id: UUID
class DeprecatedDocumentUpdateByFileApi(DatasetApiResource):
"""Deprecated resource aliases for file document updates."""
@service_api_ns.doc(
summary="Update Document by File",
description=(
"Update an existing document by uploading a new file. Re-triggers indexing — use the returned "
"`batch` ID with [Get Document Indexing Status](/api-reference/documents/"
"get-document-indexing-status) to track progress."
),
tags=["Documents"],
responses={
200: "Document updated successfully.",
400: (
"- `too_many_files` : Only one file is allowed.\n"
"- `filename_not_exists_error` : The specified filename does not exist.\n"
"- `provider_not_initialize` : No valid model provider credentials found. Please go to "
"Settings -> Model Provider to complete your provider credentials.\n"
"- `invalid_param` : Knowledge base does not exist, external datasets not supported, "
"file too large, unsupported file type, or invalid doc_form (must be `text_model`, "
"`hierarchical_model`, or `qa_model`)."
),
},
)
@service_api_ns.doc("update_document_by_file_deprecated")
@service_api_ns.doc(deprecated=True)
@service_api_ns.doc(
@ -686,6 +825,18 @@ class DeprecatedDocumentUpdateByFileApi(DatasetApiResource):
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents")
class DocumentListApi(DatasetApiResource):
@service_api_ns.doc(
summary="List Documents",
description=(
"Returns a paginated list of documents in the knowledge base. Supports filtering by keyword "
"and indexing status."
),
tags=["Documents"],
responses={
200: "List of documents.",
404: "`not_found` : Knowledge base not found.",
},
)
@service_api_ns.doc("list_documents")
@service_api_ns.doc(description="List all documents in a dataset")
@service_api_ns.doc(params={"dataset_id": "Dataset ID", **query_params_from_model(DocumentListQuery)})
@ -746,6 +897,19 @@ class DocumentListApi(DatasetApiResource):
class DocumentBatchDownloadZipApi(DatasetApiResource):
"""Download multiple uploaded-file documents as a single ZIP archive."""
@service_api_ns.doc(
summary="Download Documents as ZIP",
description=(
"Download multiple uploaded-file documents as a single ZIP archive. Accepts up to `100` document IDs."
),
tags=["Documents"],
responses={
200: "ZIP archive containing the requested documents.",
403: "`forbidden` : Insufficient permissions.",
404: "`not_found` : Document or dataset not found.",
},
)
@binary_response(service_api_ns, "application/zip")
@service_api_ns.expect(service_api_ns.models[DocumentBatchDownloadZipPayload.__name__])
@service_api_ns.doc("download_documents_as_zip")
@service_api_ns.doc(description="Download selected uploaded documents as a single ZIP archive")
@ -758,11 +922,7 @@ class DocumentBatchDownloadZipApi(DatasetApiResource):
404: "Document or dataset not found",
}
)
@service_api_ns.response(
200,
"ZIP archive generated successfully",
service_api_ns.models[BinaryFileResponse.__name__],
)
@service_api_ns.response(200, "ZIP archive generated successfully")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id: UUID):
payload = DocumentBatchDownloadZipPayload.model_validate(service_api_ns.payload or {})
@ -789,6 +949,20 @@ class DocumentBatchDownloadZipApi(DatasetApiResource):
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<string:batch>/indexing-status")
class DocumentIndexingStatusApi(DatasetApiResource):
@service_api_ns.doc(
summary="Get Document Indexing Status",
description=(
"Check the indexing progress of documents in a batch. Returns the current processing stage "
"and chunk completion counts for each document. Poll this endpoint until `indexing_status` "
"reaches `completed` or `error`. The status progresses through: `waiting` → `parsing` → "
"`cleaning` → `splitting` → `indexing` → `completed`."
),
tags=["Documents"],
responses={
200: "Indexing status for documents in the batch.",
404: "`not_found` : Knowledge base not found. / Documents not found.",
},
)
@service_api_ns.doc("get_document_indexing_status")
@service_api_ns.doc(description="Get indexing status for documents in a batch")
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "batch": "Batch ID"})
@ -861,6 +1035,16 @@ class DocumentIndexingStatusApi(DatasetApiResource):
class DocumentDownloadApi(DatasetApiResource):
"""Return a signed download URL for a document's original uploaded file."""
@service_api_ns.doc(
summary="Download Document",
description="Get a signed download URL for a document's original uploaded file.",
tags=["Documents"],
responses={
200: "Download URL generated successfully.",
403: "`forbidden` : No permission to access this document.",
404: "`not_found` : Document not found.",
},
)
@service_api_ns.doc("get_document_download_url")
@service_api_ns.doc(description="Get a signed download URL for a document's original uploaded file")
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@ -895,6 +1079,24 @@ class DocumentDownloadApi(DatasetApiResource):
class DocumentApi(DatasetApiResource):
METADATA_CHOICES = {"all", "only", "without"}
@service_api_ns.doc(
summary="Get Document",
description=(
"Retrieve detailed information about a specific document, including its indexing status, "
"metadata, and processing statistics."
),
tags=["Documents"],
responses={
200: (
"Document details. The response shape varies based on the `metadata` query parameter. When "
"`metadata` is `only`, only `id`, `doc_type`, and `doc_metadata` are returned. When "
"`metadata` is `without`, `doc_type` and `doc_metadata` are omitted."
),
400: "`invalid_metadata` : Invalid metadata value for the specified key.",
403: "`forbidden` : No permission.",
404: "`not_found` : Document not found.",
},
)
@service_api_ns.doc("get_document")
@service_api_ns.doc(description="Get a specific document by ID")
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@ -1036,6 +1238,17 @@ class DocumentApi(DatasetApiResource):
"""Update document by file on the canonical document resource."""
return _update_document_by_file(tenant_id=tenant_id, dataset_id=dataset_id, document_id=document_id)
@service_api_ns.doc(
summary="Delete Document",
description="Permanently delete a document and all its chunks from the knowledge base.",
tags=["Documents"],
responses={
204: "Success.",
400: "`document_indexing` : Cannot delete document during indexing.",
403: "`archived_document_immutable` : The archived document is not editable.",
404: "`not_found` : Document Not Exists.",
},
)
@service_api_ns.doc("delete_document")
@service_api_ns.doc(description="Delete a document")
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})

View File

@ -13,6 +13,32 @@ register_response_schema_models(service_api_ns, HitTestingResponse)
@service_api_ns.route("/datasets/<uuid:dataset_id>/hit-testing", "/datasets/<uuid:dataset_id>/retrieve")
class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
@service_api_ns.doc(
summary="Retrieve Chunks from a Knowledge Base / Test Retrieval",
description=(
"Performs a search query against a knowledge base to retrieve the most relevant chunks. This "
"endpoint can be used for both production retrieval and test retrieval."
),
tags=["Knowledge Bases"],
responses={
200: "Retrieval results.",
400: (
"- `dataset_not_initialized` : The dataset is still being initialized or indexing. Please "
"wait a moment.\n"
"- `provider_not_initialize` : No valid model provider credentials found. Please go to "
"Settings -> Model Provider to complete your provider credentials.\n"
"- `provider_quota_exceeded` : Your quota for Dify Hosted OpenAI has been exhausted. Please "
"go to Settings -> Model Provider to complete your own provider credentials.\n"
"- `model_currently_not_support` : Dify Hosted OpenAI trial currently not support the GPT-4 "
"model.\n"
"- `completion_request_error` : Completion request failed.\n"
"- `invalid_param` : Invalid parameter value."
),
403: "`forbidden` : Insufficient permissions.",
404: "`not_found` : Knowledge base not found.",
500: "`internal_server_error` : An internal error occurred during retrieval.",
},
)
@service_api_ns.doc("dataset_hit_testing")
@service_api_ns.doc(description="Perform hit testing on a dataset")
@service_api_ns.doc(params={"dataset_id": "Dataset ID"})

View File

@ -24,6 +24,12 @@ from services.entities.knowledge_entities.knowledge_entities import (
)
from services.metadata_service import MetadataService
BUILT_IN_METADATA_ACTION_PARAM = {
"description": "Action to perform: 'enable' or 'disable'",
"enum": ["enable", "disable"],
"type": "string",
}
register_schema_model(service_api_ns, MetadataUpdatePayload)
register_schema_models(
service_api_ns,
@ -43,6 +49,17 @@ register_response_schema_models(
@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata")
class DatasetMetadataCreateServiceApi(DatasetApiResource):
@service_api_ns.doc(
summary="Create Metadata Field",
description=(
"Create a custom metadata field for the knowledge base. Metadata fields can be used to "
"annotate documents with structured information."
),
tags=["Metadata"],
responses={
201: "Metadata field created successfully.",
},
)
@service_api_ns.expect(service_api_ns.models[MetadataArgs.__name__])
@service_api_ns.doc("create_dataset_metadata")
@service_api_ns.doc(description="Create metadata for a dataset")
@ -71,6 +88,17 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
metadata = MetadataService.create_metadata(dataset_id_str, metadata_args)
return dump_response(DatasetMetadataResponse, metadata), 201
@service_api_ns.doc(
summary="List Metadata Fields",
description=(
"Returns the list of all metadata fields (both custom and built-in) for the knowledge base, "
"along with the count of documents using each field."
),
tags=["Metadata"],
responses={
200: "Metadata fields for the knowledge base.",
},
)
@service_api_ns.doc("get_dataset_metadata")
@service_api_ns.doc(description="Get all metadata for a dataset")
@service_api_ns.doc(params={"dataset_id": "Dataset ID"})
@ -96,6 +124,14 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>")
class DatasetMetadataServiceApi(DatasetApiResource):
@service_api_ns.doc(
summary="Update Metadata Field",
description="Rename a custom metadata field.",
tags=["Metadata"],
responses={
200: "Metadata field updated successfully.",
},
)
@service_api_ns.expect(service_api_ns.models[MetadataUpdatePayload.__name__])
@service_api_ns.doc("update_dataset_metadata")
@service_api_ns.doc(description="Update metadata name")
@ -125,6 +161,17 @@ class DatasetMetadataServiceApi(DatasetApiResource):
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, payload.name)
return dump_response(DatasetMetadataResponse, metadata), 200
@service_api_ns.doc(
summary="Delete Metadata Field",
description=(
"Permanently delete a custom metadata field. Documents using this field will lose their "
"metadata values for it."
),
tags=["Metadata"],
responses={
204: "Success.",
},
)
@service_api_ns.doc("delete_dataset_metadata")
@service_api_ns.doc(description="Delete metadata")
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "metadata_id": "Metadata ID"})
@ -152,6 +199,16 @@ class DatasetMetadataServiceApi(DatasetApiResource):
@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata/built-in")
class DatasetMetadataBuiltInFieldServiceApi(DatasetApiResource):
@service_api_ns.doc(
summary="Get Built-in Metadata Fields",
description=(
"Returns the list of built-in metadata fields provided by the system (e.g., document type, source URL)."
),
tags=["Metadata"],
responses={
200: "Built-in metadata fields.",
},
)
@service_api_ns.doc("get_built_in_fields")
@service_api_ns.doc(description="Get all built-in metadata fields")
@service_api_ns.doc(
@ -173,9 +230,17 @@ class DatasetMetadataBuiltInFieldServiceApi(DatasetApiResource):
@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata/built-in/<string:action>")
class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):
@service_api_ns.doc(
summary="Update Built-in Metadata Field",
description="Enable or disable built-in metadata fields for the knowledge base.",
tags=["Metadata"],
responses={
200: "Built-in metadata field toggled successfully.",
},
)
@service_api_ns.doc("toggle_built_in_field")
@service_api_ns.doc(description="Enable or disable built-in metadata field")
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "action": "Action to perform: 'enable' or 'disable'"})
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "action": BUILT_IN_METADATA_ACTION_PARAM})
@service_api_ns.doc(
responses={
200: "Action completed successfully",
@ -205,6 +270,17 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/metadata")
class DocumentMetadataEditServiceApi(DatasetApiResource):
@service_api_ns.doc(
summary="Update Document Metadata in Batch",
description=(
"Update metadata values for multiple documents at once. Each document in the request "
"receives the specified metadata key-value pairs."
),
tags=["Metadata"],
responses={
200: "Document metadata updated successfully.",
},
)
@service_api_ns.expect(service_api_ns.models[MetadataOperationData.__name__])
@service_api_ns.doc("update_documents_metadata")
@service_api_ns.doc(description="Update metadata for multiple documents")

View File

@ -19,6 +19,11 @@ from controllers.common.schema import (
from controllers.service_api import service_api_ns
from controllers.service_api.dataset.error import PipelineRunError
from controllers.service_api.dataset.rag_pipeline.serializers import serialize_upload_file
from controllers.service_api.schema import (
event_stream_response,
json_or_event_stream_response,
multipart_file_params,
)
from controllers.service_api.wraps import DatasetApiResource
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
@ -95,6 +100,18 @@ register_response_schema_models(
class DatasourcePluginsApi(DatasetApiResource):
"""Resource for datasource plugins."""
@service_api_ns.doc(
summary="List Datasource Plugins",
description=(
"List the datasource nodes configured in the knowledge pipeline. Each node includes the "
"plugin it uses plus the metadata needed to run it."
),
tags=["Knowledge Pipeline"],
responses={
200: "List of datasource nodes configured in the pipeline.",
404: "`not_found` : Dataset not found.",
},
)
@service_api_ns.doc(shortcut="list_rag_pipeline_datasource_plugins")
@service_api_ns.doc(description="List all datasource plugins for a rag pipeline")
@service_api_ns.doc(
@ -137,6 +154,19 @@ class DatasourcePluginsApi(DatasetApiResource):
class DatasourceNodeRunApi(DatasetApiResource):
"""Resource for datasource node run."""
@service_api_ns.doc(
summary="Run Datasource Node",
description=(
"Execute a single datasource node within the knowledge pipeline. Returns a streaming "
"response with the node execution results."
),
tags=["Knowledge Pipeline"],
responses={
200: "Streaming response with node execution events.",
404: "`not_found` : Dataset not found.",
},
)
@event_stream_response(service_api_ns)
@service_api_ns.doc(shortcut="pipeline_datasource_node_run")
@service_api_ns.doc(description="Run a datasource node for a rag pipeline")
@service_api_ns.doc(
@ -195,6 +225,24 @@ class DatasourceNodeRunApi(DatasetApiResource):
class PipelineRunApi(DatasetApiResource):
"""Resource for datasource node run."""
@service_api_ns.doc(
summary="Run Pipeline",
description=(
"Execute the full knowledge pipeline for a knowledge base. Supports both streaming and "
"blocking response modes."
),
tags=["Knowledge Pipeline"],
responses={
200: (
"Pipeline execution result. Format depends on `response_mode`: streaming returns a "
"`text/event-stream`, blocking returns a JSON object."
),
403: "`forbidden` : Forbidden.",
404: "`not_found` : Dataset not found.",
500: "`pipeline_run_error` : Pipeline execution failed.",
},
)
@json_or_event_stream_response(service_api_ns)
@service_api_ns.doc(shortcut="pipeline_datasource_node_run")
@service_api_ns.doc(description="Run a datasource node for a rag pipeline")
@service_api_ns.doc(
@ -248,8 +296,24 @@ class PipelineRunApi(DatasetApiResource):
class KnowledgebasePipelineFileUploadApi(DatasetApiResource):
"""Resource for uploading a file to a knowledgebase pipeline."""
@service_api_ns.doc(
summary="Upload Pipeline File",
description="Upload a file for use in a knowledge pipeline. Accepts a single file via `multipart/form-data`.",
tags=["Knowledge Pipeline"],
responses={
201: "File uploaded successfully.",
400: (
"- `no_file_uploaded` : Please upload your file.\n"
"- `filename_not_exists_error` : The specified filename does not exist.\n"
"- `too_many_files` : Only one file is allowed."
),
413: "`file_too_large` : File size exceeded.",
415: "`unsupported_file_type` : File type not allowed.",
},
)
@service_api_ns.doc(shortcut="knowledgebase_pipeline_file_upload")
@service_api_ns.doc(description="Upload a file to a knowledgebase pipeline")
@service_api_ns.doc(consumes=["multipart/form-data"], params=multipart_file_params(include_user=False))
@service_api_ns.doc(
responses={
201: "File uploaded successfully",

View File

@ -128,6 +128,18 @@ register_response_schema_models(
class SegmentApi(DatasetApiResource):
"""Resource for segments."""
@service_api_ns.doc(
summary="Create Chunks",
description=(
"Create one or more chunks within a document. Each chunk can include optional keywords and an "
"answer field (for QA-mode documents)."
),
tags=["Chunks"],
responses={
200: "Chunks created successfully.",
404: "`not_found` : Document is not completed or is disabled.",
},
)
@service_api_ns.expect(service_api_ns.models[SegmentCreatePayload.__name__])
@service_api_ns.doc("create_segments")
@service_api_ns.doc(description="Create segments in a document")
@ -209,6 +221,14 @@ class SegmentApi(DatasetApiResource):
}
return dump_response(SegmentCreateListResponse, response), 200
@service_api_ns.doc(
summary="List Chunks",
description="Returns a paginated list of chunks within a document. Supports filtering by keyword and status.",
tags=["Chunks"],
responses={
200: "List of chunks.",
},
)
@service_api_ns.doc("list_segments")
@service_api_ns.doc(description="List segments in a document")
@service_api_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT)
@ -294,6 +314,14 @@ class SegmentApi(DatasetApiResource):
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>")
class DatasetSegmentApi(DatasetApiResource):
@service_api_ns.doc(
summary="Delete Chunk",
description="Permanently delete a chunk from the document.",
tags=["Chunks"],
responses={
204: "Success.",
},
)
@service_api_ns.doc("delete_segment")
@service_api_ns.doc(description="Delete a specific segment")
@service_api_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_SEGMENT)
@ -329,6 +357,14 @@ class DatasetSegmentApi(DatasetApiResource):
SegmentService.delete_segment(segment, document, dataset)
return "", 204
@service_api_ns.doc(
summary="Update Chunk",
description="Update a chunk's content, keywords, or answer. Re-triggers indexing for the modified chunk.",
tags=["Chunks"],
responses={
200: "Chunk updated successfully.",
},
)
@service_api_ns.expect(service_api_ns.models[SegmentUpdatePayload.__name__])
@service_api_ns.doc("update_segment")
@service_api_ns.doc(description="Update a specific segment")
@ -391,6 +427,17 @@ class DatasetSegmentApi(DatasetApiResource):
}
return dump_response(SegmentDetailResponse, response), 200
@service_api_ns.doc(
summary="Get Chunk",
description=(
"Retrieve detailed information about a specific chunk, including its content, keywords, and "
"indexing status."
),
tags=["Chunks"],
responses={
200: "Chunk details.",
},
)
@service_api_ns.doc("get_segment")
@service_api_ns.doc(description="Get a specific segment by ID")
@service_api_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_SEGMENT)
@ -442,6 +489,15 @@ class DatasetSegmentApi(DatasetApiResource):
class ChildChunkApi(DatasetApiResource):
"""Resource for child chunks."""
@service_api_ns.doc(
summary="Create Child Chunk",
description="Create a child chunk under the specified segment.",
tags=["Chunks"],
responses={
200: "Child chunk created successfully.",
400: "`invalid_param` : Create child chunk index failed.",
},
)
@service_api_ns.expect(service_api_ns.models[ChildChunkCreatePayload.__name__])
@service_api_ns.doc("create_child_chunk")
@service_api_ns.doc(description="Create a new child chunk for a segment")
@ -511,6 +567,14 @@ class ChildChunkApi(DatasetApiResource):
return dump_response(ChildChunkDetailResponse, {"data": child_chunk}), 200
@service_api_ns.doc(
summary="List Child Chunks",
description="Returns a paginated list of child chunks under a specific parent chunk.",
tags=["Chunks"],
responses={
200: "List of child chunks.",
},
)
@service_api_ns.doc("list_child_chunks")
@service_api_ns.doc(description="List child chunks for a segment")
@service_api_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_PARENT_SEGMENT)
@ -576,6 +640,15 @@ class ChildChunkApi(DatasetApiResource):
class DatasetChildChunkApi(DatasetApiResource):
"""Resource for updating child chunks."""
@service_api_ns.doc(
summary="Delete Child Chunk",
description="Permanently delete a child chunk from its parent chunk.",
tags=["Chunks"],
responses={
204: "Success.",
400: "`invalid_param` : Delete child chunk index failed.",
},
)
@service_api_ns.doc("delete_child_chunk")
@service_api_ns.doc(description="Delete a specific child chunk")
@service_api_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_CHILD_CHUNK)
@ -634,6 +707,15 @@ class DatasetChildChunkApi(DatasetApiResource):
return "", 204
@service_api_ns.doc(
summary="Update Child Chunk",
description="Update the content of an existing child chunk.",
tags=["Chunks"],
responses={
200: "Child chunk updated successfully.",
400: "`invalid_param` : Update child chunk index failed.",
},
)
@service_api_ns.expect(service_api_ns.models[ChildChunkUpdatePayload.__name__])
@service_api_ns.doc("update_child_chunk")
@service_api_ns.doc(description="Update a specific child chunk")

View File

@ -17,6 +17,18 @@ register_response_schema_models(service_api_ns, EndUserDetail)
class EndUserApi(Resource):
"""Resource for retrieving end user details by ID."""
@service_api_ns.doc(
summary="Get End User Info",
description=(
"Retrieve an end user by ID. Useful when other APIs return an end-user ID (e.g., "
"`created_by` from [Upload File](/api-reference/files/upload-file))."
),
tags=["End Users"],
responses={
200: "End user retrieved successfully.",
404: "`end_user_not_found` : End user not found.",
},
)
@service_api_ns.doc("get_end_user")
@service_api_ns.doc(description="Get an end user by ID")
@service_api_ns.doc(

View File

@ -0,0 +1,113 @@
"""Service API OpenAPI documentation helpers.
These helpers keep documentation-only request shapes next to controller
definitions without changing the Pydantic models used for runtime validation.
"""
from __future__ import annotations
from collections.abc import Sequence
from copy import deepcopy
from typing import cast
from flask_restx import Namespace
from pydantic import BaseModel
USER_PROPERTY_SCHEMA: dict[str, object] = {"description": "End user identifier", "type": "string"}
USER_QUERY_PARAM: dict[str, object] = {"description": "End user identifier", "in": "query", "type": "string"}
USER_FORM_PARAM: dict[str, object] = {"description": "End user identifier", "in": "formData", "type": "string"}
FILE_FORM_PARAM: dict[str, object] = {"in": "formData", "required": True, "type": "file"}
USER_FETCH_FROM_ATTR = "_dify_service_api_user_fetch_from"
USER_REQUIRED_ATTR = "_dify_service_api_user_required"
JSON_USER_FETCH_FROM = "JSON"
def expect_with_user(namespace: Namespace, model: type[BaseModel]):
"""Document a JSON request body as ``model`` plus Service API ``user``."""
source_model = namespace.models[model.__name__]
model_name = f"{model.__name__}WithUser"
def decorator(view_func):
required = _json_user_required(view_func)
schema = cast(dict[str, object], deepcopy(source_model.__schema__))
_add_user_property(schema, required=required)
if model_name not in namespace.models:
namespace.schema_model(model_name, schema)
return namespace.expect(namespace.models[model_name], validate=False)(view_func)
return decorator
def expect_user_json(namespace: Namespace):
"""Document a JSON request body that only carries the Service API ``user``."""
def decorator(view_func):
required = _json_user_required(view_func)
schema: dict[str, object] = {"properties": {}, "title": "ServiceApiUserPayload", "type": "object"}
_add_user_property(schema, required=required)
model_name = "RequiredServiceApiUserPayload" if required else "OptionalServiceApiUserPayload"
if model_name not in namespace.models:
namespace.schema_model(model_name, schema)
return namespace.expect(namespace.models[model_name], validate=False)(view_func)
return decorator
def multipart_file_params(*, include_user: bool) -> dict[str, dict[str, object]]:
params: dict[str, dict[str, object]] = {"file": FILE_FORM_PARAM}
if include_user:
params["user"] = USER_FORM_PARAM
return deepcopy(params)
def json_or_event_stream_response(namespace: Namespace):
return namespace.doc(produces=["application/json", "text/event-stream"])
def event_stream_response(namespace: Namespace):
return namespace.doc(produces=["text/event-stream"])
def binary_response(namespace: Namespace, media_type: str | Sequence[str]):
media_types = [media_type] if isinstance(media_type, str) else list(media_type)
return namespace.doc(produces=media_types)
def _json_user_required(view_func) -> bool:
fetch_from = getattr(view_func, USER_FETCH_FROM_ATTR, None)
if fetch_from != JSON_USER_FETCH_FROM:
raise ValueError("JSON user documentation must match validate_app_token(fetch_user_arg=WhereisUserArg.JSON)")
return bool(getattr(view_func, USER_REQUIRED_ATTR, False))
def _add_user_property(schema: dict[str, object], *, required: bool) -> None:
variants: list[dict[str, object]] = []
for keyword in ("anyOf", "oneOf"):
candidates = schema.get(keyword)
if isinstance(candidates, list):
variants.extend(candidate for candidate in candidates if isinstance(candidate, dict))
if variants:
for variant in variants:
_add_user_property_to_object_schema(variant, required=required)
_add_user_property_to_object_schema(schema, required=required)
def _add_user_property_to_object_schema(schema: dict[str, object], *, required: bool) -> None:
properties = schema.setdefault("properties", {})
if isinstance(properties, dict):
cast(dict[str, object], properties)["user"] = USER_PROPERTY_SCHEMA
if required:
required_fields = schema.setdefault("required", [])
if isinstance(required_fields, list) and "user" not in required_fields:
required_fields.append("user")
else:
required_fields = schema.get("required")
if isinstance(required_fields, list) and "user" in required_fields:
required_fields.remove("user")
if required_fields == []:
schema.pop("required", None)

View File

@ -19,6 +19,17 @@ register_response_schema_models(service_api_ns, ProviderWithModelsListResponse)
@service_api_ns.route("/workspaces/current/models/model-types/<string:model_type>")
class ModelProviderAvailableModelApi(Resource):
@service_api_ns.doc(
summary="Get Available Models",
description=(
"Retrieve the list of available models by type. Primarily used to query `text-embedding` and "
"`rerank` models for knowledge base configuration."
),
tags=["Models"],
responses={
200: "Available models for the specified type.",
},
)
@service_api_ns.doc("get_available_models")
@service_api_ns.doc(description="Get available models by model type")
@service_api_ns.doc(params={"model_type": "Type of model to retrieve"})

View File

@ -4,16 +4,23 @@ import time
from collections.abc import Callable
from enum import StrEnum, auto
from functools import wraps
from typing import cast, overload
from typing import Protocol, cast, overload
from flask import current_app, request
from flask_login import user_logged_in
from flask_restx import Resource
from flask_restx.utils import merge
from pydantic import BaseModel
from sqlalchemy import select
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
from configs import dify_config
from controllers.service_api.schema import (
USER_FETCH_FROM_ATTR,
USER_FORM_PARAM,
USER_QUERY_PARAM,
USER_REQUIRED_ATTR,
)
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_redis import redis_client
@ -28,6 +35,12 @@ from services.feature_service import FeatureService
logger = logging.getLogger(__name__)
class _RestxDocumentedView(Protocol):
"""Callable view object carrying Flask-RESTX documentation metadata."""
__apidoc__: dict[str, object]
class WhereisUserArg(StrEnum):
"""
Enum for whereis_user_arg.
@ -43,6 +56,35 @@ class FetchUserArg(BaseModel):
required: bool = False
APP_TOKEN_FORBIDDEN_RESPONSE = {
403: "Forbidden - token scope, app, dataset, or workspace access denied",
}
DATASET_TOKEN_AUTH_RESPONSES = {
401: "Unauthorized - invalid API token",
403: "Forbidden - dataset API access or workspace access denied",
}
def _document_app_token_contract(view_func: Callable[..., object], fetch_user_arg: FetchUserArg | None) -> None:
doc: dict[str, object] = {"responses": APP_TOKEN_FORBIDDEN_RESPONSE}
if fetch_user_arg is not None:
setattr(view_func, USER_FETCH_FROM_ATTR, fetch_user_arg.fetch_from.name)
setattr(view_func, USER_REQUIRED_ATTR, fetch_user_arg.required)
match fetch_user_arg.fetch_from:
case WhereisUserArg.QUERY:
doc["params"] = {"user": {**USER_QUERY_PARAM, "required": fetch_user_arg.required}}
case WhereisUserArg.FORM:
doc["params"] = {"user": {**USER_FORM_PARAM, "required": fetch_user_arg.required}}
case WhereisUserArg.JSON:
pass
cast(_RestxDocumentedView, view_func).__apidoc__ = cast(
dict[str, object],
merge(getattr(view_func, "__apidoc__", {}), doc),
)
@overload
def validate_app_token[**P, R](view: Callable[P, R]) -> Callable[P, R]: ...
@ -126,6 +168,7 @@ def validate_app_token[**P, R](
return view_func(*args, **kwargs)
_document_app_token_contract(decorated_view, fetch_user_arg)
return decorated_view
if view is None:
@ -343,6 +386,8 @@ def validate_and_get_api_token(scope: str | None = None):
class DatasetApiResource(Resource):
__apidoc__ = {"responses": DATASET_TOKEN_AUTH_RESPONSES}
method_decorators = [validate_dataset_token]
def get_dataset(self, dataset_id: str, tenant_id: str) -> Dataset:

View File

@ -118,7 +118,7 @@ class BaseAgentRunner(AppRunner):
features = model_schema.features if model_schema and model_schema.features else []
self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features
self.files = application_generate_entity.files if ModelFeature.VISION in features else []
self.query: str | None = ""
self.query: str = ""
self._current_thoughts: list[PromptMessage] = []
def _repack_app_generate_entity(

View File

@ -72,13 +72,48 @@ def publish_text_answer(
both the backend-produced answer and short-circuited answers (moderation /
annotation reply) share the exact same persistence + SSE path.
"""
publish_text_delta(
queue_manager=queue_manager,
model_name=model_name,
delta=answer,
user_query=user_query,
)
publish_message_end(
queue_manager=queue_manager,
model_name=model_name,
answer=answer,
user_query=user_query,
)
def publish_text_delta(
*,
queue_manager: AppQueueManager,
model_name: str,
delta: str,
user_query: str | None = None,
) -> None:
"""Publish one assistant text delta through the EasyUI chat pipeline."""
if not delta:
return
prompt_messages = _prompt_messages_from_query(user_query)
chunk = LLMResultChunk(
model=model_name,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=answer)),
delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=delta)),
)
queue_manager.publish(QueueLLMChunkEvent(chunk=chunk), PublishFrom.APPLICATION_MANAGER)
def publish_message_end(
*,
queue_manager: AppQueueManager,
model_name: str,
answer: str,
user_query: str | None = None,
) -> None:
"""Publish the terminal assistant result without emitting another delta."""
prompt_messages = _prompt_messages_from_query(user_query)
queue_manager.publish(
QueueMessageEndEvent(
llm_result=LLMResult(
@ -151,7 +186,12 @@ class AgentAppRunner:
)
create_response = self._agent_backend_client.create_run(runtime.request)
terminal = self._consume_stream(create_response.run_id, queue_manager=queue_manager)
terminal, streamed_answer = self._consume_stream(
create_response.run_id,
queue_manager=queue_manager,
model_name=model_name,
query=query,
)
if isinstance(terminal, AgentBackendDeferredToolCallInternalEvent):
# ENG-635: the agent asked a human. End this turn with the question and
@ -175,7 +215,13 @@ class AgentAppRunner:
raise AgentBackendError(str(error))
answer = self._extract_answer(terminal.output)
self._publish_answer(queue_manager=queue_manager, model_name=model_name, answer=answer, query=query)
self._publish_terminal_answer(
queue_manager=queue_manager,
model_name=model_name,
answer=answer,
query=query,
streamed_answer=streamed_answer,
)
self._save_session(
scope=scope,
backend_run_id=terminal.run_id,
@ -272,8 +318,16 @@ class AgentAppRunner:
parts.append(args.markdown)
return "\n\n".join(parts)
def _consume_stream(self, run_id: str, *, queue_manager: AppQueueManager):
def _consume_stream(
self,
run_id: str,
*,
queue_manager: AppQueueManager,
model_name: str,
query: str | None,
):
terminal = None
streamed_answer_parts: list[str] = []
for public_event in self._agent_backend_client.stream_events(run_id):
if queue_manager.is_stopped():
self._cancel_run(run_id)
@ -286,16 +340,23 @@ class AgentAppRunner:
AgentBackendInternalEventType.RUN_STARTED,
AgentBackendInternalEventType.STREAM_EVENT,
):
# Stream deltas are accumulated by the backend into the
# terminal output; token-level forwarding is an S3 refinement.
if isinstance(internal_event, AgentBackendStreamInternalEvent):
text_delta = self._extract_stream_text_delta(internal_event)
if text_delta:
streamed_answer_parts.append(text_delta)
publish_text_delta(
queue_manager=queue_manager,
model_name=model_name,
delta=text_delta,
user_query=query,
)
continue
continue
terminal = internal_event
break
if terminal is not None:
break
return terminal
return terminal, "".join(streamed_answer_parts)
def _cancel_run(self, run_id: str) -> None:
try:
@ -310,6 +371,35 @@ class AgentAppRunner:
# task pipeline streams the chunk over SSE and persists the message.
publish_text_answer(queue_manager=queue_manager, model_name=model_name, answer=answer, user_query=query)
def _publish_terminal_answer(
self,
*,
queue_manager: AppQueueManager,
model_name: str,
answer: str,
query: str | None,
streamed_answer: str,
) -> None:
"""Finish a successful streamed turn without duplicating the final text."""
if not streamed_answer:
self._publish_answer(queue_manager=queue_manager, model_name=model_name, answer=answer, query=query)
return
if answer.startswith(streamed_answer):
publish_text_delta(
queue_manager=queue_manager,
model_name=model_name,
delta=answer[len(streamed_answer) :],
user_query=query,
)
elif answer != streamed_answer:
logger.warning(
"Agent App streamed answer does not match terminal output; "
"using terminal output for message persistence."
)
publish_message_end(queue_manager=queue_manager, model_name=model_name, answer=answer, user_query=query)
def _save_session(
self,
*,
@ -357,5 +447,27 @@ class AgentAppRunner:
return json.dumps(output, ensure_ascii=False)
return json.dumps(output, ensure_ascii=False)
@staticmethod
def _extract_stream_text_delta(event: AgentBackendStreamInternalEvent) -> str | None:
data = event.data
if not isinstance(data, dict):
return None
__all__ = ["AgentAppRunner", "publish_text_answer"]
if data.get("event_kind") == "part_delta":
delta = data.get("delta")
if isinstance(delta, dict) and delta.get("part_delta_kind") == "text":
content_delta = delta.get("content_delta")
if isinstance(content_delta, str):
return content_delta
if data.get("event_kind") == "part_start":
part = data.get("part")
if isinstance(part, dict) and part.get("part_kind") == "text":
content = part.get("content")
if isinstance(content, str):
return content
return None
__all__ = ["AgentAppRunner", "publish_message_end", "publish_text_answer", "publish_text_delta"]

View File

@ -231,22 +231,23 @@ class AppRunner:
:param tenant_id: tenant id for multimodal output
:return:
"""
if not stream and isinstance(invoke_result, LLMResult):
self._handle_invoke_result_direct(
invoke_result=invoke_result,
queue_manager=queue_manager,
)
elif stream and isinstance(invoke_result, Generator):
self._handle_invoke_result_stream(
invoke_result=invoke_result,
queue_manager=queue_manager,
agent=agent,
message_id=message_id,
user_id=user_id,
tenant_id=tenant_id,
)
else:
raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}")
match invoke_result:
case LLMResult() if not stream:
self._handle_invoke_result_direct(
invoke_result=invoke_result,
queue_manager=queue_manager,
)
case _ if stream and isinstance(invoke_result, Generator):
self._handle_invoke_result_stream(
invoke_result=invoke_result,
queue_manager=queue_manager,
agent=agent,
message_id=message_id,
user_id=user_id,
tenant_id=tenant_id,
)
case _:
raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}")
def _handle_invoke_result_direct(
self,

View File

@ -882,7 +882,7 @@ class WorkflowResponseConverter:
return files
@classmethod
def _get_file_var_from_value(cls, value: Union[dict, list]) -> Mapping[str, Any] | None:
def _get_file_var_from_value(cls, value: object) -> Mapping[str, Any] | None:
"""
Get file var from value
:param value: variable value
@ -891,10 +891,11 @@ class WorkflowResponseConverter:
if not value:
return None
if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
return value
elif isinstance(value, File):
return value.to_dict()
match value:
case dict() if value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
return value
case File():
return value.to_dict()
return None

View File

@ -241,7 +241,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
created_by: Mapping[str, object] = Field(default_factory=dict)
created_at: int
finished_at: int | None
exceptions_count: int | None = 0
exceptions_count: int = 0
files: Sequence[Mapping[str, Any]] | None = []
event: StreamEvent = StreamEvent.WORKFLOW_FINISHED

View File

@ -144,15 +144,16 @@ def extract_parent_trace_context_from_args(args: Mapping[str, Any]) -> dict[str,
Returns an empty dict if the context is missing or incomplete.
"""
parent_trace_context = args.get("parent_trace_context")
if isinstance(parent_trace_context, ParentTraceContext):
context = parent_trace_context
elif isinstance(parent_trace_context, Mapping):
try:
context = ParentTraceContext.model_validate(parent_trace_context)
except ValidationError:
match parent_trace_context:
case ParentTraceContext():
context = parent_trace_context
case Mapping():
try:
context = ParentTraceContext.model_validate(parent_trace_context)
except ValidationError:
return {}
case _:
return {}
else:
return {}
if context.parent_node_execution_id is None:
return {}

View File

@ -116,20 +116,21 @@ def cast_parameter_value(typ: StrEnum, value: Any, /):
return value if isinstance(value, str) else str(value)
case PluginParameterType.BOOLEAN:
if value is None:
return False
elif isinstance(value, str):
# Allowed YAML boolean value strings: https://yaml.org/type/bool.html
# and also '0' for False and '1' for True
match value.lower():
case "true" | "yes" | "y" | "1":
return True
case "false" | "no" | "n" | "0":
return False
case _:
return bool(value)
else:
return value if isinstance(value, bool) else bool(value)
match value:
case None:
return False
case str():
# Allowed YAML boolean value strings: https://yaml.org/type/bool.html
# and also '0' for False and '1' for True
match value.lower():
case "true" | "yes" | "y" | "1":
return True
case "false" | "no" | "n" | "0":
return False
case _:
return bool(value)
case _:
return value if isinstance(value, bool) else bool(value)
case PluginParameterType.NUMBER:
match value:

View File

@ -71,8 +71,8 @@ class RequestInvokeLLM(BaseRequestInvokeModel):
mode: str
completion_params: dict[str, Any] = Field(default_factory=dict)
prompt_messages: list[PromptMessage] = Field(default_factory=list)
tools: list[PromptMessageTool] = Field(default_factory=list[PromptMessageTool])
stop: list[str] = Field(default_factory=list[str])
tools: list[PromptMessageTool] | None = Field(default_factory=list[PromptMessageTool])
stop: list[str] | None = Field(default_factory=list[str])
stream: bool = False
model_config = ConfigDict(protected_namespaces=())

View File

@ -20,6 +20,7 @@ from core.plugin.impl.exc import (
PluginDaemonNotFoundError,
PluginDaemonUnauthorizedError,
PluginInvokeError,
PluginLLMPollingUnsupportedError,
PluginNotFoundError,
PluginPermissionDeniedError,
PluginUniqueIdentifierError,
@ -370,6 +371,10 @@ class BasePluginClient:
raise TriggerInvokeError(error_object.get("message"))
case EventIgnoreError.__name__:
raise EventIgnoreError(description=error_object.get("message"))
# NOTE: current plugin sdk / plugin daemon does not raise exception with
# type `PluginLLMPollingUnsupportedError`.
case PluginLLMPollingUnsupportedError.__name__:
raise PluginLLMPollingUnsupportedError(description=error_object.get("message"))
case _:
raise PluginInvokeError(description=message)
case PluginDaemonInternalServerError.__name__:

View File

@ -5,6 +5,13 @@ from pydantic import TypeAdapter
from extensions.ext_logging import get_request_id
# NOTE: Avoid renaming exception classes in this file, since
# the `_handle_plugin_daemon_error` in api/core/plugin/impl/base.py
# build exception instances based on the class name.
#
# Renaming of exception classes could result in incorrect exception
# being raised.
class PluginDaemonError(Exception):
"""Base class for all plugin daemon errors."""
@ -75,6 +82,10 @@ class PluginInvokeError(PluginDaemonClientSideError, ValueError):
)
class PluginLLMPollingUnsupportedError(PluginInvokeError):
"""Plugin-backed LLM polling is unavailable for the requested model."""
class PluginUniqueIdentifierError(PluginDaemonClientSideError):
description: str = "Unique Identifier Error"

View File

@ -13,13 +13,17 @@ from core.plugin.entities.plugin_daemon import (
PluginVoicesResponse,
)
from core.plugin.impl.base import BasePluginClient
from graphon.model_runtime.entities.llm_entities import LLMResultChunk
from core.plugin.impl.exc import PluginInvokeError, PluginLLMPollingUnsupportedError
from graphon.model_runtime.entities.llm_entities import LLMPollingResult, LLMResultChunk
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from graphon.model_runtime.entities.model_entities import AIModelEntity
from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType
from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult
from graphon.model_runtime.utils.encoders import jsonable_encoder
_POLLING_UNSUPPORTED_INVOKE_ERROR_TYPES = frozenset((NotImplementedError.__name__,))
_POLLING_UNSUPPORTED_ERROR_MESSAGE = "does not support polling"
class PluginModelClient(BasePluginClient):
@staticmethod
@ -197,6 +201,103 @@ class PluginModelClient(BasePluginClient):
except PluginDaemonInnerError as e:
raise ValueError(e.message + str(e.code))
def start_llm_polling(
self,
tenant_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model: str,
credentials: dict[str, Any],
prompt_messages: list[PromptMessage],
model_parameters: dict[str, Any] | None = None,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
json_schema: dict[str, Any] | None = None,
) -> LLMPollingResult:
"""Start an LLM polling request for plugin-backed long-running jobs."""
try:
return self._request_with_plugin_daemon_response(
method="POST",
path=f"plugin/{tenant_id}/dispatch/model/polling/start",
type_=LLMPollingResult,
data=jsonable_encoder(
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": ModelType.LLM.value,
"model": model,
"credentials": credentials,
"prompt_messages": prompt_messages,
"model_parameters": model_parameters,
"tools": tools,
"stop": stop,
"stream": False,
"json_schema": json_schema,
},
)
),
headers={
"X-Plugin-ID": plugin_id,
"Content-Type": "application/json",
},
)
except PluginInvokeError as error:
self._raise_typed_polling_unsupported_error(error)
raise
def check_llm_polling(
self,
tenant_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model: str,
credentials: dict[str, Any],
plugin_state: dict[str, Any],
) -> LLMPollingResult:
"""Check the latest state for a plugin-backed LLM polling job."""
try:
return self._request_with_plugin_daemon_response(
method="POST",
path=f"plugin/{tenant_id}/dispatch/model/polling/check",
type_=LLMPollingResult,
data=jsonable_encoder(
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": ModelType.LLM.value,
"model": model,
"credentials": credentials,
"plugin_state": plugin_state,
},
)
),
headers={
"X-Plugin-ID": plugin_id,
"Content-Type": "application/json",
},
)
except PluginInvokeError as error:
self._raise_typed_polling_unsupported_error(error)
raise
@staticmethod
def _raise_typed_polling_unsupported_error(error: PluginInvokeError) -> None:
"""Convert plugin polling capability failures into a dedicated Dify exception."""
if error.get_error_type() == PluginLLMPollingUnsupportedError.__name__:
raise PluginLLMPollingUnsupportedError(description=error.description) from error
if (
error.get_error_type() in _POLLING_UNSUPPORTED_INVOKE_ERROR_TYPES
# This is ugly, we should not rely on error messages while checking
# error types.
and _POLLING_UNSUPPORTED_ERROR_MESSAGE in error.get_error_message().lower()
):
raise PluginLLMPollingUnsupportedError(description=error.description) from error
def get_llm_num_tokens(
self,
tenant_id: str,

View File

@ -6,6 +6,7 @@ from collections.abc import Generator, Iterable, Sequence
from typing import IO, Any, Literal, cast, overload, override
from pydantic import ValidationError
from pydantic.json_schema import JsonValue
from redis import RedisError
from configs import dify_config
@ -17,6 +18,7 @@ from core.plugin.impl.model import PluginModelClient
from core.plugin.plugin_service import PluginService
from extensions.ext_redis import redis_client
from graphon.model_runtime.entities.llm_entities import (
LLMPollingResult,
LLMResult,
LLMResultChunk,
LLMResultChunkWithStructuredOutput,
@ -430,6 +432,54 @@ class PluginModelRuntime(ModelRuntime):
tools=list(tools) if tools else None,
)
def start_llm_polling(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
model_parameters: dict[str, Any],
prompt_messages: Sequence[PromptMessage],
tools: Sequence[PromptMessageTool] | None,
stop: Sequence[str] | None,
json_schema: dict[str, Any] | None,
) -> LLMPollingResult:
"""Start a plugin-side polling job for long-running LLM invocations."""
plugin_id, provider_name = self._split_provider(provider)
return self.client.start_llm_polling(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
prompt_messages=list(prompt_messages),
model_parameters=model_parameters,
tools=list(tools) if tools else None,
stop=list(stop) if stop else None,
json_schema=json_schema,
)
def check_llm_polling(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
plugin_state: dict[str, JsonValue],
) -> LLMPollingResult:
"""Check the latest plugin-side polling state for an LLM invocation."""
plugin_id, provider_name = self._split_provider(provider)
return self.client.check_llm_polling(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
plugin_state=plugin_state,
)
@override
def invoke_text_embedding(
self,

View File

@ -304,22 +304,23 @@ def _has_dify_refs_recursive(schema: SchemaType) -> bool:
Returns:
True if any Dify $ref is found, False otherwise
"""
if isinstance(schema, dict):
# Check if this dict has a $ref field
ref_uri = schema.get("$ref")
if ref_uri and _is_dify_schema_ref(ref_uri):
return True
# Check nested values
for value in schema.values():
if _has_dify_refs_recursive(value):
match schema:
case dict():
# Check if this dict has a $ref field
ref_uri = schema.get("$ref")
if ref_uri and _is_dify_schema_ref(ref_uri):
return True
elif isinstance(schema, list):
# Check each item in the list
for item in schema:
if _has_dify_refs_recursive(item):
return True
# Check nested values
for value in schema.values():
if _has_dify_refs_recursive(value):
return True
case list():
# Check each item in the list
for item in schema:
if _has_dify_refs_recursive(item):
return True
# Primitive types don't contain refs
return False

View File

@ -1,6 +1,6 @@
from collections.abc import Generator
from datetime import datetime
from typing import Any, override
from datetime import datetime, tzinfo
from typing import Any, cast, override
import pytz # type: ignore[import-untyped]
@ -35,17 +35,26 @@ class LocaltimeToTimestampTool(BuiltinTool):
yield self.create_text_message(f"{timestamp}")
# TODO: this method's type is messy
@staticmethod
def localtime_to_timestamp(localtime: str, time_format: str, local_tz=None) -> int | None:
def localtime_to_timestamp(localtime: str, time_format: str, local_tz: str | tzinfo | None = None) -> int | None:
try:
local_time = datetime.strptime(localtime, time_format)
if local_tz is None:
localtime = local_time.astimezone() # type: ignore
elif isinstance(local_tz, str):
local_tz = pytz.timezone(local_tz)
localtime = local_tz.localize(local_time) # type: ignore
timestamp = int(localtime.timestamp()) # type: ignore
converted_localtime: datetime
match local_tz:
case None:
converted_localtime = local_time.astimezone()
case str() as timezone_name:
timezone = pytz.timezone(timezone_name)
converted_localtime = timezone.localize(local_time)
case tzinfo():
localize = getattr(local_tz, "localize", None)
if callable(localize):
converted_localtime = cast(datetime, localize(local_time))
else:
converted_localtime = local_time.replace(tzinfo=local_tz)
case _:
raise ValueError("local_tz must be None, a timezone name, or a tzinfo instance")
timestamp = int(converted_localtime.timestamp())
return timestamp
except Exception as e:
raise ToolInvokeError(str(e))

View File

@ -122,13 +122,14 @@ class MCPTool(Tool):
def _process_json_content(self, content_json: Any) -> Generator[ToolInvokeMessage, None, None]:
"""Process JSON content based on its type."""
if isinstance(content_json, dict):
yield self.create_json_message(content_json)
elif isinstance(content_json, list):
yield from self._process_json_list(content_json)
else:
# For primitive types (str, int, bool, etc.), convert to string
yield self.create_text_message(str(content_json))
match content_json:
case dict():
yield self.create_json_message(content_json)
case list():
yield from self._process_json_list(content_json)
case _:
# For primitive types (str, int, bool, etc.), convert to string
yield self.create_text_message(str(content_json))
def _process_json_list(self, json_list: list) -> Generator[ToolInvokeMessage, None, None]:
"""Process a list of JSON items."""
@ -222,16 +223,17 @@ class MCPTool(Tool):
# Recursively search through nested structures
for value in payload.values():
if isinstance(value, Mapping):
found = cls._extract_usage_dict(value)
if found is not None:
return found
elif isinstance(value, list) and not isinstance(value, (str, bytes, bytearray)):
for item in value:
if isinstance(item, Mapping):
found = cls._extract_usage_dict(item)
if found is not None:
return found
match value:
case _ if isinstance(value, Mapping):
found = cls._extract_usage_dict(value)
if found is not None:
return found
case list() if not isinstance(value, (str, bytes, bytearray)):
for item in value:
if isinstance(item, Mapping):
found = cls._extract_usage_dict(item)
if found is not None:
return found
return None
@override

View File

@ -196,16 +196,17 @@ class WorkflowTool(Tool):
return usage_candidate
for value in payload.values():
if isinstance(value, Mapping):
found = cls._extract_usage_dict(value)
if found is not None:
return found
elif isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)):
for item in value:
if isinstance(item, Mapping):
found = cls._extract_usage_dict(item)
if found is not None:
return found
match value:
case _ if isinstance(value, Mapping):
found = cls._extract_usage_dict(value)
if found is not None:
return found
case _ if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)):
for item in value:
if isinstance(item, Mapping):
found = cls._extract_usage_dict(item)
if found is not None:
return found
return None
@override
@ -393,24 +394,25 @@ class WorkflowTool(Tool):
files: list[File] = []
result = {}
for key, value in outputs.items():
if isinstance(value, list):
for item in value:
if isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY:
item = self._update_file_mapping(item)
file = build_from_mapping(
mapping=item,
tenant_id=str(self.runtime.tenant_id),
access_controller=_file_access_controller,
)
files.append(file)
elif isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
value = self._update_file_mapping(value)
file = build_from_mapping(
mapping=value,
tenant_id=str(self.runtime.tenant_id),
access_controller=_file_access_controller,
)
files.append(file)
match value:
case list():
for item in value:
if isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY:
item = self._update_file_mapping(item)
file = build_from_mapping(
mapping=item,
tenant_id=str(self.runtime.tenant_id),
access_controller=_file_access_controller,
)
files.append(file)
case dict() if value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
value = self._update_file_mapping(value)
file = build_from_mapping(
mapping=value,
tenant_id=str(self.runtime.tenant_id),
access_controller=_file_access_controller,
)
files.append(file)
result[key] = value

View File

@ -47,7 +47,7 @@ class EventParameter(BaseModel):
template: PluginParameterTemplate | None = Field(default=None, description="The template of the parameter")
scope: str | None = None
required: bool = False
multiple: bool | None = Field(
multiple: bool = Field(
default=False,
description="Whether the parameter is multiple select, only valid for select or dynamic-select type",
)

View File

@ -26,6 +26,7 @@ from core.workflow.node_runtime import (
DifyFileReferenceFactory,
DifyHumanInputNodeRuntime,
DifyPreparedLLM,
DifyPreparedPollingLLM,
DifyPromptMessageSerializer,
DifyRetrieverAttachmentLoader,
DifyToolFileManager,
@ -531,7 +532,11 @@ class DifyNodeFactory(NodeFactory):
node_init_kwargs: dict[str, object] = {
"credentials_provider": self._llm_credentials_provider,
"model_factory": self._llm_model_factory,
"model_instance": DifyPreparedLLM(model_instance) if wrap_model_instance else model_instance,
"model_instance": (
self._wrap_model_instance_for_node(node_data=validated_node_data, model_instance=model_instance)
if wrap_model_instance
else model_instance
),
"memory": self._build_memory_for_llm_node(
node_data=validated_node_data,
model_instance=model_instance,
@ -555,6 +560,23 @@ class DifyNodeFactory(NodeFactory):
node_init_kwargs["default_query_selector"] = system_variable_selector(SystemVariableKey.QUERY)
return node_init_kwargs
@staticmethod
def _wrap_model_instance_for_node(
*,
node_data: LLMCompatibleNodeData,
model_instance: ModelInstance,
) -> DifyPreparedLLM:
# Only graphon's LLM node consumes the polling protocol. Keep classifier
# and extractor nodes on the existing wrapper even if the same model
# advertises polling support.
if node_data.type == BuiltinNodeTypes.LLM and DifyNodeFactory._supports_plugin_llm_polling(model_instance):
return DifyPreparedPollingLLM(model_instance)
return DifyPreparedLLM(model_instance)
@staticmethod
def _supports_plugin_llm_polling(model_instance: ModelInstance) -> bool:
return model_instance.get_model_schema().support_polling
def _build_retriever_attachment_loader(self, node_data: LLMNodeData) -> DifyRetrieverAttachmentLoader:
return DifyRetrieverAttachmentLoader(
file_reference_factory=self._file_reference_factory,

View File

@ -4,6 +4,7 @@ from collections.abc import Callable, Generator, Mapping, Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal, cast, overload, override
from pydantic import JsonValue
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -38,6 +39,7 @@ from factories import file_factory
from graphon.file import File, FileTransferMethod, FileType
from graphon.model_runtime.entities import LLMMode
from graphon.model_runtime.entities.llm_entities import (
LLMPollingResult,
LLMResult,
LLMResultChunk,
LLMResultChunkWithStructuredOutput,
@ -54,6 +56,7 @@ from graphon.nodes.human_input.entities import (
HumanInputNodeData,
)
from graphon.nodes.llm.runtime_protocols import (
LLMPollingCapableProtocol,
LLMProtocol,
PromptMessageSerializerProtocol,
RetrieverAttachmentLoaderProtocol,
@ -278,6 +281,58 @@ class DifyPreparedLLM(LLMProtocol):
return isinstance(error, OutputParserError)
class DifyPreparedPollingLLM(DifyPreparedLLM, LLMPollingCapableProtocol):
"""Prepared workflow LLM adapter that exposes Graphon's polling protocol."""
def __init__(self, model_instance: ModelInstance) -> None:
from core.plugin.impl.model_runtime import PluginModelRuntime
super().__init__(model_instance)
model_type_instance = model_instance.model_type_instance
if not isinstance(model_type_instance, LargeLanguageModel):
raise TypeError("Polling wrapper requires a large-language-model instance.")
plugin_model_runtime = model_type_instance.model_runtime
if not isinstance(plugin_model_runtime, PluginModelRuntime):
raise TypeError("Polling wrapper requires a plugin-backed model runtime.")
self._plugin_model_runtime = plugin_model_runtime
@override
def start_llm_polling(
self,
*,
prompt_messages: Sequence[PromptMessage],
model_parameters: Mapping[str, Any],
tools: Sequence[PromptMessageTool] | None,
stop: Sequence[str] | None,
json_schema: Mapping[str, Any] | None,
) -> LLMPollingResult:
return self._plugin_model_runtime.start_llm_polling(
provider=self.provider,
model=self.model_name,
credentials=self._model_instance.credentials,
prompt_messages=prompt_messages,
model_parameters=dict(model_parameters),
tools=tools,
stop=stop,
json_schema=dict(json_schema) if json_schema is not None else None,
)
@override
def check_llm_polling(
self,
*,
plugin_state: Mapping[str, JsonValue],
) -> LLMPollingResult:
return self._plugin_model_runtime.check_llm_polling(
provider=self.provider,
model=self.model_name,
credentials=self._model_instance.credentials,
plugin_state=dict(plugin_state),
)
class DifyPromptMessageSerializer(PromptMessageSerializerProtocol):
@override
def serialize(

View File

@ -297,25 +297,26 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
for cond in conditions.conditions or []:
value = cond.value
resolved_value: str | Sequence[str] | int | float | None
if isinstance(value, str):
segment_group = variable_pool.convert_template(value)
if len(segment_group.value) == 1:
resolved_value = _normalize_metadata_filter_scalar(segment_group.value[0].to_object())
else:
resolved_value = segment_group.text
elif isinstance(value, Sequence) and all(isinstance(v, str) for v in value):
resolved_values: list[str] = []
for v in value:
segment_group = variable_pool.convert_template(v)
match value:
case str():
segment_group = variable_pool.convert_template(value)
if len(segment_group.value) == 1:
resolved_values.append(
_normalize_metadata_filter_sequence_item(segment_group.value[0].to_object())
)
resolved_value = _normalize_metadata_filter_scalar(segment_group.value[0].to_object())
else:
resolved_values.append(segment_group.text)
resolved_value = resolved_values
else:
resolved_value = value
resolved_value = segment_group.text
case _ if isinstance(value, Sequence) and all(isinstance(v, str) for v in value):
resolved_values: list[str] = []
for v in value:
segment_group = variable_pool.convert_template(v)
if len(segment_group.value) == 1:
resolved_values.append(
_normalize_metadata_filter_sequence_item(segment_group.value[0].to_object())
)
else:
resolved_values.append(segment_group.text)
resolved_value = resolved_values
case _:
resolved_value = value
resolved_conditions.append(
Condition(
name=cond.name,

View File

@ -167,6 +167,12 @@ def _patch_union_schema_markdown(markdown: str, spec_path: Path) -> str:
return markdown
def _strip_trailing_line_whitespace(markdown: str) -> str:
"""Remove converter-emitted trailing whitespace without changing line structure."""
return "\n".join(line.rstrip(" \t") for line in markdown.split("\n"))
def _convert_spec_to_markdown(spec_path: Path, markdown_path: Path) -> None:
markdown_path.parent.mkdir(parents=True, exist_ok=True)
with tempfile.TemporaryDirectory(prefix=f"{markdown_path.stem}-") as temp_dir:
@ -201,6 +207,7 @@ def _convert_spec_to_markdown(spec_path: Path, markdown_path: Path) -> None:
temp_markdown_path.read_text(encoding="utf-8"),
spec_path,
)
converted_markdown = _strip_trailing_line_whitespace(converted_markdown)
if not converted_markdown.strip():
raise RuntimeError(f"swagger-markdown wrote an empty document for {markdown_path}")

View File

@ -104,11 +104,14 @@ def _field_signature(field: object) -> object:
"description",
"example",
"max",
"max_items",
"min",
"min_items",
"nullable",
"readonly",
"required",
"title",
"unique",
):
if hasattr(field_instance, attr_name):
signature[attr_name] = _jsonable_schema_value(getattr(field_instance, attr_name))
@ -154,9 +157,9 @@ def create_spec_app() -> Flask:
apply_runtime_defaults()
from libs.flask_restx_compat import patch_swagger_for_inline_nested_dicts
from libs.flask_restx_compat import install_swagger_compatibility
patch_swagger_for_inline_nested_dicts()
install_swagger_compatibility()
app = Flask(__name__)

View File

@ -354,11 +354,12 @@ def iter_method_nodes(method: MethodNode) -> Iterable[ast.AST]:
def target_names(target: ast.AST) -> Iterable[str]:
if isinstance(target, ast.Name):
yield target.id
elif isinstance(target, ast.Tuple | ast.List):
for item in target.elts:
yield from target_names(item)
match target:
case ast.Name():
yield target.id
case ast.Tuple() | ast.List():
for item in target.elts:
yield from target_names(item)
def record_assignment(

View File

@ -107,17 +107,58 @@ class AgentInviteOptionsResponse(ResponseModel):
has_more: bool
class AgentLogItemResponse(ResponseModel):
class AgentLogSourceResponse(ResponseModel):
id: str
type: Literal["webapp", "workflow"]
app_id: str
app_name: str
app_icon_type: str | None = None
app_icon: str | None = None
app_icon_background: str | None = None
workflow_id: str | None = None
workflow_version: str | None = None
node_id: str | None = None
class AgentLogSourceGroupResponse(ResponseModel):
type: Literal["webapp", "workflow"]
label: str
sources: list[AgentLogSourceResponse] = Field(default_factory=list)
class AgentLogSourceListResponse(ResponseModel):
data: list[AgentLogSourceResponse]
groups: list[AgentLogSourceGroupResponse]
class AgentLogConversationItemResponse(ResponseModel):
id: str
conversation_id: str
title: str | None = None
end_user_id: str | None = None
message_count: int
user_rate: float | None = None
operation_rate: float | None = None
unread: bool
source: AgentLogSourceResponse | None = None
status: Literal["success", "failed", "paused"]
created_at: int | None = None
updated_at: int | None = None
@field_validator("created_at", "updated_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return to_timestamp(value)
class AgentLogMessageItemResponse(ResponseModel):
id: str
message_id: str
conversation_id: str
conversation_name: str | None = None
query: str
answer: str
status: str
error: str | None = None
source: str | None = None
from_source: str | None = None
from_end_user_id: str | None = None
from_account_id: str | None = None
message_tokens: int
@ -136,7 +177,15 @@ class AgentLogItemResponse(ResponseModel):
class AgentLogListResponse(ResponseModel):
data: list[AgentLogItemResponse]
data: list[AgentLogConversationItemResponse]
page: int
limit: int
total: int
has_more: bool
class AgentLogMessageListResponse(ResponseModel):
data: list[AgentLogMessageItemResponse]
page: int
limit: int
total: int

View File

@ -94,12 +94,13 @@ class RedisSubscriptionBase(Subscription):
continue
channel_field = raw_message.get("channel")
if isinstance(channel_field, bytes):
channel_name = channel_field.decode("utf-8")
elif isinstance(channel_field, str):
channel_name = channel_field
else:
channel_name = str(channel_field)
match channel_field:
case bytes():
channel_name = channel_field.decode("utf-8")
case str():
channel_name = channel_field
case _:
channel_name = str(channel_field)
if channel_name != self._topic:
_logger.warning(

View File

@ -88,22 +88,23 @@ class _RedisShardedSubscription(RedisSubscriptionBase):
#
# Since we have already filtered at the caller's site, we can safely set
# `ignore_subscribe_messages=False`.
if isinstance(self._client, RedisCluster):
# NOTE(QuantumGhost): due to an issue in upstream code, calling `get_sharded_message` without
# specifying the `target_node` argument would use busy-looping to wait
# for incoming message, consuming excessive CPU quota.
#
# Here we specify the `target_node` to mitigate this problem.
node = self._client.get_node_from_key(self._topic)
return self._pubsub.get_sharded_message( # type: ignore[attr-defined]
ignore_subscribe_messages=False,
timeout=1,
target_node=node,
)
elif isinstance(self._client, Redis):
return self._pubsub.get_sharded_message(ignore_subscribe_messages=False, timeout=1) # type: ignore[attr-defined]
else:
raise AssertionError("client should be either Redis or RedisCluster.")
match self._client:
case RedisCluster():
# NOTE(QuantumGhost): due to an issue in upstream code, calling `get_sharded_message` without
# specifying the `target_node` argument would use busy-looping to wait
# for incoming message, consuming excessive CPU quota.
#
# Here we specify the `target_node` to mitigate this problem.
node = self._client.get_node_from_key(self._topic)
return self._pubsub.get_sharded_message( # type: ignore[attr-defined]
ignore_subscribe_messages=False,
timeout=1,
target_node=node,
)
case Redis():
return self._pubsub.get_sharded_message(ignore_subscribe_messages=False, timeout=1) # type: ignore[attr-defined]
case _:
raise AssertionError("client should be either Redis or RedisCluster.")
@override
def _get_message_type(self) -> str:

View File

@ -138,10 +138,11 @@ class _StreamsSubscription(Subscription):
if isinstance(fields, dict):
data = fields.get(b"data")
data_bytes: bytes | None = None
if isinstance(data, str):
data_bytes = data.encode()
elif isinstance(data, (bytes, bytearray)):
data_bytes = bytes(data)
match data:
case str():
data_bytes = data.encode()
case bytes() | bytearray():
data_bytes = bytes(data)
if data_bytes is not None:
self._queue.put_nowait(data_bytes)
last_id = entry_id

View File

@ -9,7 +9,7 @@ from werkzeug.http import HTTP_STATUS_CODES
from configs import dify_config
from core.errors.error import AppInvokeQuotaExceededError
from libs.flask_restx_compat import patch_swagger_for_inline_nested_dicts
from libs.flask_restx_compat import install_swagger_compatibility
from libs.token import build_force_logout_cookie_headers
@ -127,16 +127,22 @@ def register_external_error_handlers(api: Api, body_formatter: ErrorBodyFormatte
class ExternalApi(Api):
_authorizations = {
"Bearer": {
"type": "apiKey",
"in": "header",
"name": "Authorization",
"description": "Type: Bearer {your-api-key}",
"bearerFormat": "API_KEY",
"description": "Use the Service API key as a Bearer token in the Authorization header.",
"scheme": "bearer",
"type": "http",
}
}
def __init__(self, app: Blueprint | Flask, *args, error_body_formatter: ErrorBodyFormatter | None = None, **kwargs):
def __init__(
self,
app: Blueprint | Flask,
*args,
error_body_formatter: ErrorBodyFormatter | None = None,
**kwargs,
):
self._error_body_formatter = error_body_formatter
patch_swagger_for_inline_nested_dicts()
install_swagger_compatibility()
kwargs.setdefault("authorizations", self._authorizations)
kwargs.setdefault("security", "Bearer")
kwargs["add_specs"] = dify_config.SWAGGER_UI_ENABLED

View File

@ -8,12 +8,14 @@ spec export fail or succeed in the same way.
import hashlib
import json
from typing import TypeGuard
from typing import TypeGuard, cast
from flask import current_app
from flask_restx import fields
from flask_restx import swagger as restx_swagger
from flask_restx.model import Model, OrderedModel, instance
from flask_restx.swagger import Swagger
from flask_restx.utils import not_none
def _is_inline_field_map(value: object) -> TypeGuard[dict[object, object]]:
@ -98,20 +100,28 @@ def _inline_model_name(nested_fields: dict[object, object]) -> str:
return f"_AnonymousInlineModel_{digest}"
def patch_swagger_for_inline_nested_dicts() -> None:
"""Allow OpenAPI generation to handle legacy inline Flask-RESTX field dicts.
def install_swagger_compatibility() -> None:
"""Install Dify's Flask-RESTX OpenAPI compatibility hooks.
Some existing controllers use raw field mappings in `fields.Nested({...})`
or directly in `@namespace.response(...)`. Runtime marshalling accepts that,
but Flask-RESTX registration expects a named model. Convert those
anonymous mappings into temporary named models during docs generation.
Flask-RESTX also drops parameter descriptions from generated schemas and
does not expose the Werkzeug `uuid` route converter as `format: uuid`.
"""
if getattr(Swagger, "_dify_inline_nested_dict_patch", False):
if getattr(Swagger, "_dify_swagger_compatibility_installed", False):
return
original_register_model = Swagger.register_model
original_register_field = Swagger.register_field
original_extract_path_params = restx_swagger.extract_path_params
original_schema_from_parameter = Swagger.schema_from_parameter
original_description_for = Swagger.description_for
original_serialize_operation = Swagger.serialize_operation
original_parameters_and_request_body_for = Swagger.parameters_and_request_body_for
original_as_dict = Swagger.as_dict
def get_or_create_inline_model(self: Swagger, nested_fields: dict[object, object]) -> object:
@ -134,6 +144,65 @@ def patch_swagger_for_inline_nested_dicts() -> None:
original_register_field(self, field)
def schema_from_parameter_with_description(self: Swagger, param: dict[str, object]) -> dict[str, object]:
schema = cast(dict[str, object], original_schema_from_parameter(self, param))
description = param.get("description")
if isinstance(description, str):
schema["description"] = description
return schema
def extract_path_params_with_uuid_format(path: str):
params = original_extract_path_params(path)
for converter, _arguments, variable in restx_swagger.parse_rule(path):
if converter == "uuid" and variable in params:
params[variable]["format"] = "uuid"
return params
def description_for_with_explicit_summary(self: Swagger, doc: dict[str, object], method: str):
method_doc = doc.get(method)
if (
isinstance(method_doc, dict)
and isinstance(method_doc.get("summary"), str)
and isinstance(method_doc.get("description"), str)
):
return method_doc["description"]
return original_description_for(self, doc, method)
def serialize_operation_with_explicit_summary_tags(
self: Swagger, doc: dict[str, object], method: str, inherited_request_body=None
):
operation = original_serialize_operation(self, doc, method, inherited_request_body)
method_doc = doc.get(method)
if not isinstance(method_doc, dict):
return operation
summary = method_doc.get("summary")
if isinstance(summary, str):
operation["summary"] = summary
tags = method_doc.get("tags")
if isinstance(tags, list) and all(isinstance(tag, str) for tag in tags):
operation["tags"] = tags
return operation
def serialize_resource_with_explicit_operation_tags(self: Swagger, ns, resource, url, route_doc=None, **kwargs):
doc = self.extract_resource_doc(resource, url, route_doc=route_doc)
if doc is False:
return None
path_params, path_request_body = original_parameters_and_request_body_for(self, doc)
path: dict[str, object] = {"parameters": path_params or None}
methods = [method.lower() for method in resource.methods or []]
requested_methods = [method.lower() for method in kwargs.get("methods", [])]
for method in methods:
if doc[method] is False or requested_methods and method not in requested_methods:
continue
operation = self.serialize_operation(doc, method, path_request_body)
operation.setdefault("tags", [ns.name])
path[method] = operation
return not_none(path)
def as_dict_with_inline_dict_support(self: Swagger):
# Temporary set RESTX_INCLUDE_ALL_MODELS = false to prevent "length changed while iterating" error
include_all_models = current_app.config.get("RESTX_INCLUDE_ALL_MODELS", False)
@ -145,5 +214,10 @@ def patch_swagger_for_inline_nested_dicts() -> None:
Swagger.register_model = register_model_with_inline_dict_support
Swagger.register_field = register_field_with_inline_dict_support
restx_swagger.extract_path_params = extract_path_params_with_uuid_format
Swagger.schema_from_parameter = schema_from_parameter_with_description
Swagger.description_for = description_for_with_explicit_summary
Swagger.serialize_operation = serialize_operation_with_explicit_summary_tags
Swagger.serialize_resource = serialize_resource_with_explicit_operation_tags
Swagger.as_dict = as_dict_with_inline_dict_support
Swagger._dify_inline_nested_dict_patch = True
Swagger._dify_swagger_compatibility_installed = True

View File

@ -1174,34 +1174,32 @@ class Conversation(Base):
# Convert file mapping to File object
for key, value in inputs.items():
if (
isinstance(value, dict)
and cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY
):
value_dict = cast(dict[str, Any], value)
inputs[key] = build_file_from_input_mapping(
file_mapping=value_dict,
tenant_resolver=tenant_resolver,
)
elif isinstance(value, list):
value_list = value
if all(
isinstance(item, dict)
and cast(dict[str, Any], item).get("dify_model_identity") == FILE_MODEL_IDENTITY
for item in value_list
):
file_list: list[File] = []
for item in value_list:
if not isinstance(item, dict):
continue
item_dict = cast(dict[str, Any], item)
file_list.append(
build_file_from_input_mapping(
file_mapping=item_dict,
tenant_resolver=tenant_resolver,
match value:
case dict() if cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY:
value_dict = cast(dict[str, Any], value)
inputs[key] = build_file_from_input_mapping(
file_mapping=value_dict,
tenant_resolver=tenant_resolver,
)
case list():
value_list = value
if all(
isinstance(item, dict)
and cast(dict[str, Any], item).get("dify_model_identity") == FILE_MODEL_IDENTITY
for item in value_list
):
file_list: list[File] = []
for item in value_list:
if not isinstance(item, dict):
continue
item_dict = cast(dict[str, Any], item)
file_list.append(
build_file_from_input_mapping(
file_mapping=item_dict,
tenant_resolver=tenant_resolver,
)
)
)
inputs[key] = file_list
inputs[key] = file_list
return inputs
@ -1516,46 +1514,45 @@ class Message(Base):
owner_tenant_id=cast(str | None, getattr(self, "_owner_tenant_id", None)),
)
for key, value in inputs.items():
if (
isinstance(value, dict)
and cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY
):
value_dict = cast(dict[str, Any], value)
inputs[key] = build_file_from_input_mapping(
file_mapping=value_dict,
tenant_resolver=tenant_resolver,
)
elif isinstance(value, list):
value_list = value
if all(
isinstance(item, dict)
and cast(dict[str, Any], item).get("dify_model_identity") == FILE_MODEL_IDENTITY
for item in value_list
):
file_list: list[File] = []
for item in value_list:
if not isinstance(item, dict):
continue
item_dict = cast(dict[str, Any], item)
file_list.append(
build_file_from_input_mapping(
file_mapping=item_dict,
tenant_resolver=tenant_resolver,
match value:
case dict() if cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY:
value_dict = cast(dict[str, Any], value)
inputs[key] = build_file_from_input_mapping(
file_mapping=value_dict,
tenant_resolver=tenant_resolver,
)
case list():
value_list = value
if all(
isinstance(item, dict)
and cast(dict[str, Any], item).get("dify_model_identity") == FILE_MODEL_IDENTITY
for item in value_list
):
file_list: list[File] = []
for item in value_list:
if not isinstance(item, dict):
continue
item_dict = cast(dict[str, Any], item)
file_list.append(
build_file_from_input_mapping(
file_mapping=item_dict,
tenant_resolver=tenant_resolver,
)
)
)
inputs[key] = file_list
inputs[key] = file_list
return inputs
@inputs.setter
def inputs(self, value: Mapping[str, Any]):
inputs = dict(value)
for k, v in inputs.items():
if isinstance(v, File):
inputs[k] = v.model_dump()
elif isinstance(v, list):
v_list = v
if all(isinstance(item, File) for item in v_list):
inputs[k] = [item.model_dump() for item in v_list if isinstance(item, File)]
match v:
case File():
inputs[k] = v.model_dump()
case list():
v_list = v
if all(isinstance(item, File) for item in v_list):
inputs[k] = [item.model_dump() for item in v_list if isinstance(item, File)]
self._inputs = inputs
@property

View File

@ -96,12 +96,13 @@ class JSONModelColumn[T: BaseModel](TypeDecorator[T | None]):
def process_bind_param(self, value: T | dict[str, Any] | str | None, dialect: Dialect) -> str | None:
if value is None:
return None
if isinstance(value, self._model_class):
model = value
elif isinstance(value, str):
model = self._model_class.model_validate_json(value)
else:
model = self._model_class.model_validate(value)
match value:
case _ if isinstance(value, self._model_class):
model = value
case str():
model = self._model_class.model_validate_json(value)
case _:
model = self._model_class.model_validate(value)
return json.dumps(model.model_dump(mode="json"), ensure_ascii=False, sort_keys=True, separators=(",", ":"))
@override

File diff suppressed because it is too large Load Diff

View File

@ -4,10 +4,9 @@ User-scoped programmatic API (bearer auth)
## Version: 1.0
### Available authorizations
#### Bearer (API Key Authentication)
Type: Bearer {your-api-key}
**Name:** Authorization
**In:** header
#### Bearer (HTTP, bearer)
Use the Service API key as a Bearer token in the Authorization header.
Bearer format: API_KEY
---
## openapi

File diff suppressed because it is too large Load Diff

View File

@ -4,10 +4,9 @@ Public APIs for web applications including file uploads, chat interactions, and
## Version: 1.0
### Available authorizations
#### Bearer (API Key Authentication)
Type: Bearer {your-api-key}
**Name:** Authorization
**In:** header
#### Bearer (HTTP, bearer)
Use the Service API key as a Bearer token in the Authorization header.
Bearer format: API_KEY
---
## web
@ -140,7 +139,7 @@ Delete a specific conversation.
| Name | Located in | Description | Required | Schema |
| ---- | ---------- | ----------- | -------- | ------ |
| c_id | path | Conversation UUID | Yes | string |
| c_id | path | Conversation UUID | Yes | string (uuid) |
#### Responses
@ -160,7 +159,7 @@ Rename a specific conversation with a custom name or auto-generate one.
| Name | Located in | Description | Required | Schema |
| ---- | ---------- | ----------- | -------- | ------ |
| c_id | path | Conversation UUID | Yes | string |
| c_id | path | Conversation UUID | Yes | string (uuid) |
| auto_generate | query | Auto-generate conversation name | No | boolean |
| name | query | New conversation name | No | string |
@ -188,7 +187,7 @@ Pin a specific conversation to keep it at the top of the list.
| Name | Located in | Description | Required | Schema |
| ---- | ---------- | ----------- | -------- | ------ |
| c_id | path | Conversation UUID | Yes | string |
| c_id | path | Conversation UUID | Yes | string (uuid) |
#### Responses
@ -208,7 +207,7 @@ Unpin a specific conversation to remove it from the top of the list.
| Name | Located in | Description | Required | Schema |
| ---- | ---------- | ----------- | -------- | ------ |
| c_id | path | Conversation UUID | Yes | string |
| c_id | path | Conversation UUID | Yes | string (uuid) |
#### Responses
@ -494,7 +493,7 @@ Submit feedback (like/dislike) for a specific message.
| Name | Located in | Description | Required | Schema |
| ---- | ---------- | ----------- | -------- | ------ |
| message_id | path | Message UUID | Yes | string |
| message_id | path | Message UUID | Yes | string (uuid) |
| content | query | Feedback content | No | string |
| rating | query | Feedback rating | No | string, <br>**Available values:** "dislike", "like" |
@ -523,7 +522,7 @@ Generate a new completion similar to an existing message (completion apps only).
| Name | Located in | Description | Required | Schema |
| ---- | ---------- | ----------- | -------- | ------ |
| response_mode | query | Response mode | Yes | string, <br>**Available values:** "blocking", "streaming" |
| message_id | path | | Yes | string |
| message_id | path | | Yes | string (uuid) |
#### Responses
@ -543,7 +542,7 @@ Get suggested follow-up questions after a message (chat apps only).
| Name | Located in | Description | Required | Schema |
| ---- | ---------- | ----------- | -------- | ------ |
| message_id | path | Message UUID | Yes | string |
| message_id | path | Message UUID | Yes | string (uuid) |
#### Responses
@ -731,7 +730,7 @@ Remove a message from saved messages.
| Name | Located in | Description | Required | Schema |
| ---- | ---------- | ----------- | -------- | ------ |
| message_id | path | Message UUID to delete | Yes | string |
| message_id | path | Message UUID to delete | Yes | string (uuid) |
#### Responses
@ -1633,7 +1632,7 @@ Default configuration for form inputs.
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| message_id | string | Message ID | No |
| streaming | boolean | Enable streaming response | No |
| streaming | boolean | Reserved for compatibility; TTS response streaming is determined by the provider output. | No |
| text | string | Text to convert to audio | No |
| voice | string | Voice to use for TTS | No |

View File

@ -1150,13 +1150,14 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
try:
# Convert outputs to string based on type
outputs_mime_type = OpenInferenceMimeTypeValues.TEXT.value
if isinstance(trace_info.outputs, dict | list):
outputs_str = safe_json_dumps(trace_info.outputs)
outputs_mime_type = OpenInferenceMimeTypeValues.JSON.value
elif isinstance(trace_info.outputs, str):
outputs_str = trace_info.outputs
else:
outputs_str = str(trace_info.outputs)
match trace_info.outputs:
case dict() | list():
outputs_str = safe_json_dumps(trace_info.outputs)
outputs_mime_type = OpenInferenceMimeTypeValues.JSON.value
case str():
outputs_str = trace_info.outputs
case _:
outputs_str = str(trace_info.outputs)
llm_attributes: dict[str, Any] = {
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.LLM.value,
@ -1553,25 +1554,26 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
set_attribute(f"{base_path}.{ToolCallAttributes.TOOL_CALL_ID}", call_id)
# Handle list of messages
if isinstance(prompts, list):
for message_index, message in enumerate(prompts):
if not isinstance(message, dict):
continue
match prompts:
case list():
for message_index, message in enumerate(prompts):
if not isinstance(message, dict):
continue
role = message.get("role", "user")
content = message.get("text") or message.get("content") or ""
role = message.get("role", "user")
content = message.get("text") or message.get("content") or ""
set_message_attribute(message_index, MessageAttributes.MESSAGE_ROLE, role)
set_message_attribute(message_index, MessageAttributes.MESSAGE_CONTENT, content)
set_message_attribute(message_index, MessageAttributes.MESSAGE_ROLE, role)
set_message_attribute(message_index, MessageAttributes.MESSAGE_CONTENT, content)
tool_calls = message.get("tool_calls") or []
if isinstance(tool_calls, list):
for tool_index, tool_call in enumerate(tool_calls):
set_tool_call_attributes(message_index, tool_index, tool_call)
tool_calls = message.get("tool_calls") or []
if isinstance(tool_calls, list):
for tool_index, tool_call in enumerate(tool_calls):
set_tool_call_attributes(message_index, tool_index, tool_call)
# Handle single dict or plain string prompt
elif isinstance(prompts, (dict, str)):
set_message_attribute(0, MessageAttributes.MESSAGE_CONTENT, prompts)
set_message_attribute(0, MessageAttributes.MESSAGE_ROLE, "user")
# Handle single dict or plain string prompt
case dict() | str():
set_message_attribute(0, MessageAttributes.MESSAGE_CONTENT, prompts)
set_message_attribute(0, MessageAttributes.MESSAGE_ROLE, "user")
return attributes

View File

@ -18,24 +18,25 @@ def validate_input_output(v, field_name):
"""
if v == {} or v is None:
return v
if isinstance(v, str):
return [
{
"role": "assistant" if field_name == "output" else "user",
"content": v,
}
]
elif isinstance(v, list):
if len(v) > 0 and isinstance(v[0], dict):
v = replace_text_with_content(data=v)
return v
else:
match v:
case str():
return [
{
"role": "assistant" if field_name == "output" else "user",
"content": str(v),
"content": v,
}
]
case list():
if len(v) > 0 and isinstance(v[0], dict):
v = replace_text_with_content(data=v)
return v
else:
return [
{
"role": "assistant" if field_name == "output" else "user",
"content": str(v),
}
]
return v

View File

@ -64,40 +64,20 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
"total_tokens": values.get("total_tokens", 0),
}
file_list = values.get("file_list", [])
if isinstance(v, str):
match field_name:
case "inputs":
return {
"messages": {
"role": "user",
"content": v,
"usage_metadata": usage_metadata,
"file_list": file_list,
},
}
case "outputs":
return {
"choices": {
"role": "ai",
"content": v,
"usage_metadata": usage_metadata,
"file_list": file_list,
},
}
case _:
pass
elif isinstance(v, list):
data = {}
if len(v) > 0 and isinstance(v[0], dict):
# rename text to content
v = replace_text_with_content(data=v)
match v:
case str():
match field_name:
case "inputs":
data = {
"messages": v,
return {
"messages": {
"role": "user",
"content": v,
"usage_metadata": usage_metadata,
"file_list": file_list,
},
}
case "outputs":
data = {
return {
"choices": {
"role": "ai",
"content": v,
@ -107,16 +87,37 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
}
case _:
pass
return data
else:
return {
"choices": {
"role": "ai" if field_name == "outputs" else "user",
"content": str(v),
"usage_metadata": usage_metadata,
"file_list": file_list,
},
}
case list():
data = {}
if len(v) > 0 and isinstance(v[0], dict):
# rename text to content
v = replace_text_with_content(data=v)
match field_name:
case "inputs":
data = {
"messages": v,
}
case "outputs":
data = {
"choices": {
"role": "ai",
"content": v,
"usage_metadata": usage_metadata,
"file_list": file_list,
},
}
case _:
pass
return data
else:
return {
"choices": {
"role": "ai" if field_name == "outputs" else "user",
"content": str(v),
"usage_metadata": usage_metadata,
"file_list": file_list,
},
}
if isinstance(v, dict):
v["usage_metadata"] = usage_metadata
v["file_list"] = file_list

View File

@ -40,41 +40,19 @@ class WeaveTraceModel(WeaveTokenUsage, WeaveMultiModel):
"total_tokens": values.get("total_tokens", 0),
}
file_list = values.get("file_list", [])
if isinstance(v, str):
if field_name == "inputs":
return {
"messages": {
"role": "user",
"content": v,
"usage_metadata": usage_metadata,
"file_list": file_list,
},
}
elif field_name == "outputs":
return {
"choices": {
"role": "ai",
"content": v,
"usage_metadata": usage_metadata,
"file_list": file_list,
},
}
elif isinstance(v, list):
data = {}
if len(v) > 0 and isinstance(v[0], dict):
# rename text to content
v = replace_text_with_content(data=v)
match v:
case str():
if field_name == "inputs":
data = {
"messages": [
dict(msg, **{"usage_metadata": usage_metadata, "file_list": file_list}) # type: ignore
for msg in v
]
if isinstance(v, list)
else v,
return {
"messages": {
"role": "user",
"content": v,
"usage_metadata": usage_metadata,
"file_list": file_list,
},
}
elif field_name == "outputs":
data = {
return {
"choices": {
"role": "ai",
"content": v,
@ -82,16 +60,39 @@ class WeaveTraceModel(WeaveTokenUsage, WeaveMultiModel):
"file_list": file_list,
},
}
return data
else:
return {
"choices": {
"role": "ai" if field_name == "outputs" else "user",
"content": str(v),
"usage_metadata": usage_metadata,
"file_list": file_list,
},
}
case list():
data = {}
if len(v) > 0 and isinstance(v[0], dict):
# rename text to content
v = replace_text_with_content(data=v)
if field_name == "inputs":
data = {
"messages": [
dict(msg, **{"usage_metadata": usage_metadata, "file_list": file_list}) # type: ignore
for msg in v
]
if isinstance(v, list)
else v,
}
elif field_name == "outputs":
data = {
"choices": {
"role": "ai",
"content": v,
"usage_metadata": usage_metadata,
"file_list": file_list,
},
}
return data
else:
return {
"choices": {
"role": "ai" if field_name == "outputs" else "user",
"content": str(v),
"usage_metadata": usage_metadata,
"file_list": file_list,
},
}
if isinstance(v, dict):
v["usage_metadata"] = usage_metadata
v["file_list"] = file_list

View File

@ -361,12 +361,13 @@ class ClickzettaVector(BaseVector):
first_pass = json.loads(raw_metadata)
# Handle double-encoded JSON
if isinstance(first_pass, str):
metadata = parse_metadata_json(first_pass)
elif isinstance(first_pass, dict):
metadata = first_pass
else:
metadata = {}
match first_pass:
case str():
metadata = parse_metadata_json(first_pass)
case dict():
metadata = first_pass
case _:
metadata = {}
else:
metadata = {}
except (json.JSONDecodeError, ValueError, TypeError):
@ -942,12 +943,13 @@ class ClickzettaVector(BaseVector):
# First parse may yield a string (double-encoded JSON)
first_pass = json.loads(row[2])
if isinstance(first_pass, str):
metadata = parse_metadata_json(first_pass)
elif isinstance(first_pass, dict):
metadata = first_pass
else:
metadata = {}
match first_pass:
case str():
metadata = parse_metadata_json(first_pass)
case dict():
metadata = first_pass
case _:
metadata = {}
else:
metadata = {}
except (json.JSONDecodeError, ValueError, TypeError):

View File

@ -98,14 +98,15 @@ def _extract_identifiers_and_literals(query) -> list[Any]:
values: list[Any] = []
if isinstance(query, psql.Composed):
for part in query:
if isinstance(part, psql.Identifier):
values.append(("ident", part._obj[0] if part._obj else ""))
elif isinstance(part, psql.Literal):
values.append(("literal", part._obj))
elif isinstance(part, psql.Composed):
for sub in part:
if isinstance(sub, psql.Literal):
values.append(("literal", sub._obj))
match part:
case psql.Identifier():
values.append(("ident", part._obj[0] if part._obj else ""))
case psql.Literal():
values.append(("literal", part._obj))
case psql.Composed():
for sub in part:
if isinstance(sub, psql.Literal):
values.append(("literal", sub._obj))
return values

View File

@ -44,7 +44,7 @@ dependencies = [
"resend>=2.27.0,<3.0.0",
# Emerging: newer and fast-moving, use compatible pins
"fastopenapi[flask]==0.7.0",
"graphon==0.5.1",
"graphon==0.5.2",
"httpx-sse==0.4.3",
"json-repair==0.59.4",
]

View File

@ -957,21 +957,22 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
)
pause_reason_models = []
for reason in pause_reasons:
if isinstance(reason, HumanInputRequired):
# TODO(QuantumGhost): record node_id for `WorkflowPauseReason`
pause_reason_model = WorkflowPauseReason(
pause_id=pause_model.id,
type_=reason.TYPE,
form_id=reason.form_id,
)
elif isinstance(reason, SchedulingPause):
pause_reason_model = WorkflowPauseReason(
pause_id=pause_model.id,
type_=reason.TYPE,
message=reason.message,
)
else:
raise AssertionError(f"unkown reason type: {type(reason)}")
match reason:
case HumanInputRequired():
# TODO(QuantumGhost): record node_id for `WorkflowPauseReason`
pause_reason_model = WorkflowPauseReason(
pause_id=pause_model.id,
type_=reason.TYPE,
form_id=reason.form_id,
)
case SchedulingPause():
pause_reason_model = WorkflowPauseReason(
pause_id=pause_model.id,
type_=reason.TYPE,
message=reason.message,
)
case _:
raise AssertionError(f"unknown reason type: {type(reason)}")
pause_reason_models.append(pause_reason_model)

View File

@ -5,7 +5,7 @@ import secrets
import uuid
from datetime import UTC, datetime, timedelta
from hashlib import sha256
from typing import Any, TypedDict, cast
from typing import Any, NotRequired, TypedDict, cast
from pydantic import BaseModel, TypeAdapter, ValidationError
from sqlalchemy import Row, delete, func, select, update
@ -18,6 +18,8 @@ class InvitationData(TypedDict):
account_id: str
email: str
workspace_id: str
role: NotRequired[str]
requires_setup: NotRequired[bool]
_invitation_adapter: TypeAdapter[InvitationData] = TypeAdapter(InvitationData)
@ -1805,6 +1807,7 @@ class RegisterService:
account = AccountService.get_account_by_email_with_case_fallback(email)
requires_setup = False
if not account:
TenantService.check_member_permission(tenant, inviter, None, "add")
name = normalized_email.split("@")[0]
@ -1819,6 +1822,7 @@ class RegisterService:
# Create new tenant member for invited tenant
TenantService.create_tenant_member(tenant, account, role)
TenantService.switch_tenant(account, tenant.id)
requires_setup = True
else:
TenantService.check_member_permission(tenant, inviter, account, "add")
ta = db.session.scalar(
@ -1826,15 +1830,16 @@ class RegisterService:
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
.limit(1)
)
requires_setup = account.status == AccountStatus.PENDING
if not ta:
if not ta and account.status == AccountStatus.PENDING:
TenantService.create_tenant_member(tenant, account, role)
# Support resend invitation email when the account is pending status
if account.status != AccountStatus.PENDING:
if ta and account.status != AccountStatus.PENDING:
raise AccountAlreadyInTenantError("Account already in tenant.")
token = cls.generate_invite_token(tenant, account)
token = cls.generate_invite_token(tenant, account, role, requires_setup=requires_setup)
language = account.interface_language or "en-US"
# send email
@ -1849,12 +1854,16 @@ class RegisterService:
return token
@classmethod
def generate_invite_token(cls, tenant: Tenant, account: Account) -> str:
def generate_invite_token(
cls, tenant: Tenant, account: Account, role: str = "normal", *, requires_setup: bool = False
) -> str:
token = str(uuid.uuid4())
invitation_data = {
"account_id": account.id,
"email": account.email,
"workspace_id": tenant.id,
"role": str(role),
"requires_setup": requires_setup,
}
expiry_hours = dify_config.INVITE_EXPIRY_HOURS
redis_client.setex(cls._get_invitation_token_key(token), expiry_hours * 60 * 60, json.dumps(invitation_data))
@ -1889,16 +1898,7 @@ class RegisterService:
if not tenant:
return None
tenant_account = db.session.execute(
select(Account, TenantAccountJoin.role)
.join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id)
.where(Account.email == invitation_data["email"], TenantAccountJoin.tenant_id == tenant.id)
).first()
if not tenant_account:
return None
account = tenant_account[0]
account = db.session.scalar(select(Account).where(Account.email == invitation_data["email"]).limit(1))
if not account:
return None

View File

@ -334,16 +334,17 @@ class ComposerConfigValidator:
@classmethod
def _reject_plaintext_secrets(cls, value: Any, *, path: str) -> None:
if isinstance(value, dict):
for key, nested in value.items():
normalized_key = key.lower().replace("-", "_")
nested_path = f"{path}.{key}"
if normalized_key in _PLAINTEXT_SECRET_KEYS and isinstance(nested, str) and nested:
raise PlaintextSecretNotAllowedError(f"Plaintext secret is not allowed at {nested_path}")
cls._reject_plaintext_secrets(nested, path=nested_path)
elif isinstance(value, list):
for index, nested in enumerate(value):
cls._reject_plaintext_secrets(nested, path=f"{path}[{index}]")
match value:
case dict():
for key, nested in value.items():
normalized_key = key.lower().replace("-", "_")
nested_path = f"{path}.{key}"
if normalized_key in _PLAINTEXT_SECRET_KEYS and isinstance(nested, str) and nested:
raise PlaintextSecretNotAllowedError(f"Plaintext secret is not allowed at {nested_path}")
cls._reject_plaintext_secrets(nested, path=nested_path)
case list():
for index, nested in enumerate(value):
cls._reject_plaintext_secrets(nested, path=f"{path}[{index}]")
@classmethod
def _has_install_command(cls, entry: dict[str, Any]) -> bool:

View File

@ -6,12 +6,15 @@ from decimal import Decimal
from typing import Any
import sqlalchemy as sa
from sqlalchemy import func, or_, select
from sqlalchemy import and_, func, or_, select
from sqlalchemy.orm import aliased
from core.app.entities.app_invoke_entities import InvokeFrom
from libs.helper import convert_datetime_to_date, escape_like_pattern, to_timestamp
from models.agent import WorkflowAgentNodeBinding
from models.enums import MessageStatus
from models.model import App, Conversation, Message
from models.workflow import WorkflowNodeExecutionModel, WorkflowRun
@dataclass(frozen=True)
@ -33,6 +36,16 @@ class AgentStatisticsQueryParams:
timezone: str = "UTC"
@dataclass(frozen=True)
class AgentSourceFilter:
kind: str
app_id: str | None = None
workflow_id: str | None = None
workflow_version: str | None = None
node_id: str | None = None
invoke_from: InvokeFrom | None = None
class AgentObservabilityService:
_SOURCE_ALIASES: dict[str, InvokeFrom] = {
"api": InvokeFrom.SERVICE_API,
@ -66,6 +79,31 @@ class AgentObservabilityService:
except KeyError as exc:
raise ValueError(f"Unsupported source: {source}") from exc
@classmethod
def resolve_source_filter(cls, source: str | None) -> AgentSourceFilter:
if not source or source.strip().lower() == "all":
return AgentSourceFilter(kind="all")
normalized = source.strip()
lowered = normalized.lower()
if lowered == "webapp":
return AgentSourceFilter(kind="webapp")
if lowered.startswith("webapp:"):
return AgentSourceFilter(kind="webapp", app_id=normalized.split(":", 1)[1] or None)
if lowered == "workflow":
return AgentSourceFilter(kind="workflow")
if lowered.startswith("workflow:"):
parts = normalized.split(":", 4)
if len(parts) != 5 or not all(parts[1:]):
raise ValueError(f"Unsupported source: {source}")
return AgentSourceFilter(
kind="workflow",
app_id=parts[1],
workflow_id=parts[2],
workflow_version=parts[3],
node_id=parts[4],
)
return AgentSourceFilter(kind="webapp", invoke_from=cls.resolve_source(source))
@staticmethod
def _message_status(message: Message) -> str:
if message.error or message.status == MessageStatus.ERROR:
@ -104,19 +142,255 @@ class AgentObservabilityService:
"updated_at": to_timestamp(message.updated_at),
}
def list_logs(self, *, app: App, params: AgentLogQueryParams) -> dict[str, Any]:
source = self.resolve_source(params.source)
stmt = (
select(Message, Conversation)
.join(Conversation, Conversation.id == Message.conversation_id)
.where(Message.app_id == app.id, Conversation.app_id == app.id)
)
stmt = self._apply_source_filter(stmt, source)
def list_logs(self, *, app: App, agent_id: str, params: AgentLogQueryParams) -> dict[str, Any]:
source_filter = self.resolve_source_filter(params.source)
rows: list[dict[str, Any]] = []
if source_filter.kind in {"all", "webapp"}:
rows.extend(self._list_webapp_conversation_logs(app=app, params=params, source_filter=source_filter))
if source_filter.kind in {"all", "workflow"}:
rows.extend(
self._list_workflow_conversation_logs(
app=app,
agent_id=agent_id,
params=params,
source_filter=source_filter,
)
)
rows.sort(key=lambda row: (row["updated_at"] or 0, row["id"]), reverse=True)
if params.start:
stmt = stmt.where(Message.created_at >= params.start)
if params.end:
stmt = stmt.where(Message.created_at < params.end)
total = len(rows)
start = (params.page - 1) * params.limit
end = start + params.limit
return {
"data": rows[start:end],
"page": params.page,
"limit": params.limit,
"total": total,
"has_more": end < total,
}
def list_log_messages(
self, *, app: App, agent_id: str, conversation_id: str, params: AgentLogQueryParams
) -> dict[str, Any]:
source_filter = self.resolve_source_filter(params.source)
rows: list[Message] = []
if source_filter.kind in {"all", "webapp"}:
rows.extend(
self._list_webapp_messages(
app=app,
conversation_id=conversation_id,
params=params,
source_filter=source_filter,
)
)
if source_filter.kind in {"all", "workflow"}:
rows.extend(
self._list_workflow_messages(
app=app,
agent_id=agent_id,
conversation_id=conversation_id,
params=params,
source_filter=source_filter,
)
)
deduped = {message.id: message for message in rows}
sorted_rows = sorted(deduped.values(), key=lambda message: (message.created_at, message.id), reverse=True)
total = len(sorted_rows)
start = (params.page - 1) * params.limit
end = start + params.limit
return {
"data": [self.serialize_log_message(message) for message in sorted_rows[start:end]],
"page": params.page,
"limit": params.limit,
"total": total,
"has_more": end < total,
}
def list_log_sources(self, *, app: App, agent_id: str) -> dict[str, Any]:
webapp_source = self._serialize_webapp_source(app)
workflow_sources = self._list_workflow_sources(app=app, agent_id=agent_id)
return {
"data": [webapp_source, *workflow_sources],
"groups": [
{"type": "webapp", "label": "WEBAPP", "sources": [webapp_source]},
{"type": "workflow", "label": "WORKFLOW", "sources": workflow_sources},
],
}
def _list_webapp_conversation_logs(
self, *, app: App, params: AgentLogQueryParams, source_filter: AgentSourceFilter
) -> list[dict[str, Any]]:
stmt = (
select(
Conversation,
func.count(Message.id).label("message_count"),
func.max(Message.created_at).label("created_at"),
func.max(Message.updated_at).label("updated_at"),
func.sum(sa.case((Message.status == MessageStatus.PAUSED, 1), else_=0)).label("paused_count"),
func.sum(
sa.case((or_(Message.error.is_not(None), Message.status == MessageStatus.ERROR), 1), else_=0)
).label("failed_count"),
)
.join(Message, Message.conversation_id == Conversation.id)
.where(Message.app_id == app.id, Conversation.app_id == app.id)
.group_by(Conversation.id)
)
stmt = self._apply_observability_filters(stmt, params=params, source_filter=source_filter)
rows = list(self._session.execute(stmt).all())
return [
self._serialize_conversation_log(
conversation=row[0],
message_count=row.message_count,
paused_count=row.paused_count,
failed_count=row.failed_count,
source=self._serialize_webapp_source(app),
created_at=row.created_at,
updated_at=row.updated_at,
)
for row in rows
]
def _list_workflow_conversation_logs(
self, *, app: App, agent_id: str, params: AgentLogQueryParams, source_filter: AgentSourceFilter
) -> list[dict[str, Any]]:
workflow_app = aliased(App)
stmt = (
select(
Conversation,
workflow_app,
WorkflowAgentNodeBinding.workflow_id,
WorkflowAgentNodeBinding.workflow_version,
WorkflowAgentNodeBinding.node_id,
func.count(sa.distinct(Message.id)).label("message_count"),
func.max(Message.created_at).label("created_at"),
func.max(Message.updated_at).label("updated_at"),
func.sum(sa.case((Message.status == MessageStatus.PAUSED, 1), else_=0)).label("paused_count"),
func.sum(
sa.case((or_(Message.error.is_not(None), Message.status == MessageStatus.ERROR), 1), else_=0)
).label("failed_count"),
)
.select_from(Message)
.join(Conversation, Conversation.id == Message.conversation_id)
.join(WorkflowRun, WorkflowRun.id == Message.workflow_run_id)
.join(
WorkflowAgentNodeBinding,
and_(
WorkflowAgentNodeBinding.tenant_id == app.tenant_id,
WorkflowAgentNodeBinding.agent_id == agent_id,
WorkflowAgentNodeBinding.app_id == WorkflowRun.app_id,
WorkflowAgentNodeBinding.workflow_id == WorkflowRun.workflow_id,
WorkflowAgentNodeBinding.workflow_version == WorkflowRun.version,
),
)
.join(
WorkflowNodeExecutionModel,
and_(
WorkflowNodeExecutionModel.workflow_run_id == WorkflowRun.id,
WorkflowNodeExecutionModel.node_id == WorkflowAgentNodeBinding.node_id,
),
)
.join(workflow_app, workflow_app.id == WorkflowAgentNodeBinding.app_id)
.where(Message.workflow_run_id.is_not(None), Conversation.app_id == WorkflowAgentNodeBinding.app_id)
.group_by(
Conversation.id,
workflow_app.id,
WorkflowAgentNodeBinding.workflow_id,
WorkflowAgentNodeBinding.workflow_version,
WorkflowAgentNodeBinding.node_id,
)
)
stmt = self._apply_observability_filters(stmt, params=params, source_filter=source_filter)
stmt = self._apply_workflow_source_filter(stmt, source_filter)
rows = list(self._session.execute(stmt).all())
return [
self._serialize_conversation_log(
conversation=row[0],
message_count=row.message_count,
paused_count=row.paused_count,
failed_count=row.failed_count,
source=self._serialize_workflow_source(
app=row[1],
workflow_id=row.workflow_id,
workflow_version=row.workflow_version,
node_id=row.node_id,
),
created_at=row.created_at,
updated_at=row.updated_at,
)
for row in rows
]
def _list_webapp_messages(
self, *, app: App, conversation_id: str, params: AgentLogQueryParams, source_filter: AgentSourceFilter
) -> list[Message]:
stmt = select(Message).where(Message.app_id == app.id, Message.conversation_id == conversation_id)
stmt = self._apply_message_filters(stmt, params=params, source_filter=source_filter)
return list(self._session.scalars(stmt.order_by(Message.created_at.desc(), Message.id.desc())).all())
def _list_workflow_messages(
self,
*,
app: App,
agent_id: str,
conversation_id: str,
params: AgentLogQueryParams,
source_filter: AgentSourceFilter,
) -> list[Message]:
stmt = (
select(Message)
.join(WorkflowRun, WorkflowRun.id == Message.workflow_run_id)
.join(
WorkflowAgentNodeBinding,
and_(
WorkflowAgentNodeBinding.tenant_id == app.tenant_id,
WorkflowAgentNodeBinding.agent_id == agent_id,
WorkflowAgentNodeBinding.app_id == WorkflowRun.app_id,
WorkflowAgentNodeBinding.workflow_id == WorkflowRun.workflow_id,
WorkflowAgentNodeBinding.workflow_version == WorkflowRun.version,
),
)
.join(
WorkflowNodeExecutionModel,
and_(
WorkflowNodeExecutionModel.workflow_run_id == WorkflowRun.id,
WorkflowNodeExecutionModel.node_id == WorkflowAgentNodeBinding.node_id,
),
)
.where(Message.conversation_id == conversation_id)
)
stmt = self._apply_message_filters(stmt, params=params, source_filter=source_filter)
stmt = self._apply_workflow_source_filter(stmt, source_filter)
return list(self._session.scalars(stmt.order_by(Message.created_at.desc(), Message.id.desc())).all())
def _list_workflow_sources(self, *, app: App, agent_id: str) -> list[dict[str, Any]]:
workflow_app = aliased(App)
stmt = (
select(
workflow_app,
WorkflowAgentNodeBinding.workflow_id,
WorkflowAgentNodeBinding.workflow_version,
WorkflowAgentNodeBinding.node_id,
)
.join(workflow_app, workflow_app.id == WorkflowAgentNodeBinding.app_id)
.where(WorkflowAgentNodeBinding.tenant_id == app.tenant_id, WorkflowAgentNodeBinding.agent_id == agent_id)
.order_by(workflow_app.name.asc(), WorkflowAgentNodeBinding.node_id.asc())
)
rows = self._session.execute(stmt).all()
deduped: dict[str, dict[str, Any]] = {}
for row in rows:
source = self._serialize_workflow_source(
app=row[0],
workflow_id=row.workflow_id,
workflow_version=row.workflow_version,
node_id=row.node_id,
)
deduped[source["id"]] = source
return list(deduped.values())
@classmethod
def _apply_observability_filters(cls, stmt, *, params: AgentLogQueryParams, source_filter: AgentSourceFilter):
stmt = cls._apply_message_filters(stmt, params=params, source_filter=source_filter, include_keyword=False)
if params.keyword:
escaped_keyword = escape_like_pattern(params.keyword)
pattern = f"%{escaped_keyword}%"
@ -127,27 +401,41 @@ class AgentObservabilityService:
Conversation.name.ilike(pattern, escape="\\"),
)
)
if params.status:
stmt = self._apply_status_filter(stmt, params.status)
return stmt
total = self._session.scalar(select(func.count()).select_from(stmt.subquery())) or 0
rows = list(
self._session.execute(
stmt.order_by(Message.created_at.desc(), Message.id.desc())
.offset((params.page - 1) * params.limit)
.limit(params.limit)
).all()
)
data = []
for message, conversation in rows:
data.append(self.serialize_log_message(message, conversation))
return {
"data": data,
"page": params.page,
"limit": params.limit,
"total": total,
"has_more": params.page * params.limit < total,
}
@classmethod
def _apply_message_filters(
cls, stmt, *, params: AgentLogQueryParams, source_filter: AgentSourceFilter, include_keyword: bool = True
):
stmt = cls._apply_source_filter(stmt, source_filter.invoke_from)
if params.start:
stmt = stmt.where(Message.created_at >= params.start)
if params.end:
stmt = stmt.where(Message.created_at < params.end)
if include_keyword and params.keyword:
escaped_keyword = escape_like_pattern(params.keyword)
pattern = f"%{escaped_keyword}%"
stmt = stmt.where(
or_(
Message.query.ilike(pattern, escape="\\"),
Message.answer.ilike(pattern, escape="\\"),
)
)
if params.status:
stmt = cls._apply_status_filter(stmt, params.status)
return stmt
@staticmethod
def _apply_workflow_source_filter(stmt, source_filter: AgentSourceFilter):
if source_filter.app_id:
stmt = stmt.where(WorkflowAgentNodeBinding.app_id == source_filter.app_id)
if source_filter.workflow_id:
stmt = stmt.where(WorkflowAgentNodeBinding.workflow_id == source_filter.workflow_id)
if source_filter.workflow_version:
stmt = stmt.where(WorkflowAgentNodeBinding.workflow_version == source_filter.workflow_version)
if source_filter.node_id:
stmt = stmt.where(WorkflowAgentNodeBinding.node_id == source_filter.node_id)
return stmt
@classmethod
def _apply_source_filter(cls, stmt, source: InvokeFrom | None):
@ -166,22 +454,95 @@ class AgentObservabilityService:
return stmt.where(Message.status == MessageStatus.PAUSED)
raise ValueError(f"Unsupported status: {status}")
def get_statistics_summary(self, *, app: App, params: AgentStatisticsQueryParams) -> dict[str, Any]:
source = self.resolve_source(params.source)
rows = self._load_daily_statistics(app=app, params=params, source=source)
@classmethod
def _serialize_conversation_log(
cls,
*,
conversation: Conversation,
message_count: int,
paused_count: int,
failed_count: int,
source: dict[str, Any],
created_at: datetime | None,
updated_at: datetime | None,
) -> dict[str, Any]:
return {
"id": conversation.id,
"conversation_id": conversation.id,
"title": conversation.name,
"end_user_id": conversation.from_end_user_id,
"message_count": int(message_count or 0),
"user_rate": None,
"operation_rate": None,
"unread": conversation.read_at is None,
"source": source,
"status": cls._conversation_status(paused_count=paused_count, failed_count=failed_count),
"created_at": to_timestamp(created_at or conversation.created_at),
"updated_at": to_timestamp(updated_at or conversation.updated_at),
}
@staticmethod
def _conversation_status(*, paused_count: int, failed_count: int) -> str:
if paused_count:
return "paused"
if failed_count:
return "failed"
return "success"
@staticmethod
def _serialize_webapp_source(app: App) -> dict[str, Any]:
icon_type = app.icon_type.value if app.icon_type else None
return {
"id": f"webapp:{app.id}",
"type": "webapp",
"app_id": app.id,
"app_name": app.name,
"app_icon_type": icon_type,
"app_icon": app.icon,
"app_icon_background": app.icon_background,
"workflow_id": None,
"workflow_version": None,
"node_id": None,
}
@staticmethod
def _serialize_workflow_source(
*,
app: App,
workflow_id: str,
workflow_version: str,
node_id: str,
) -> dict[str, Any]:
icon_type = app.icon_type.value if app.icon_type else None
return {
"id": f"workflow:{app.id}:{workflow_id}:{workflow_version}:{node_id}",
"type": "workflow",
"app_id": app.id,
"app_name": app.name,
"app_icon_type": icon_type,
"app_icon": app.icon,
"app_icon_background": app.icon_background,
"workflow_id": workflow_id,
"workflow_version": workflow_version,
"node_id": node_id,
}
def get_statistics_summary(self, *, app: App, agent_id: str, params: AgentStatisticsQueryParams) -> dict[str, Any]:
source_filter = self.resolve_source_filter(params.source)
rows = self._load_daily_statistics(app=app, agent_id=agent_id, params=params, source_filter=source_filter)
charts = self._build_charts(rows)
summary = self._build_summary(rows)
return {
"source": source.value if source else "all",
"source": params.source or "all",
"summary": summary,
"charts": charts,
}
def _load_daily_statistics(
self, *, app: App, params: AgentStatisticsQueryParams, source: InvokeFrom | None
self, *, app: App, agent_id: str, params: AgentStatisticsQueryParams, source_filter: AgentSourceFilter
) -> list[dict[str, Any]]:
converted_created_at = convert_datetime_to_date("m.created_at")
source_condition = "AND m.invoke_from != :debugger" if source is None else "AND m.invoke_from = :source"
message_scope = self._statistics_message_scope_sql(source_filter)
sql_query = f"""SELECT
{converted_created_at} AS date,
COUNT(m.id) AS message_count,
@ -197,15 +558,24 @@ FROM messages m
LEFT JOIN message_feedbacks mf
ON mf.message_id = m.id AND mf.rating = 'like'
WHERE
m.app_id = :app_id
{source_condition}"""
{message_scope}"""
args: dict[str, Any] = {
"tz": params.timezone,
"app_id": app.id,
"tenant_id": app.tenant_id,
"agent_id": agent_id,
"debugger": InvokeFrom.DEBUGGER,
}
if source is not None:
args["source"] = source
if source_filter.invoke_from is not None:
args["source"] = source_filter.invoke_from
if source_filter.app_id:
args["source_app_id"] = source_filter.app_id
if source_filter.workflow_id:
args["workflow_id"] = source_filter.workflow_id
if source_filter.workflow_version:
args["workflow_version"] = source_filter.workflow_version
if source_filter.node_id:
args["node_id"] = source_filter.node_id
if params.start:
sql_query += " AND m.created_at >= :start"
args["start"] = params.start
@ -216,6 +586,45 @@ WHERE
return [dict(row._mapping) for row in self._session.execute(sa.text(sql_query), args).all()]
@staticmethod
def _statistics_message_scope_sql(source_filter: AgentSourceFilter) -> str:
app_scope = "m.app_id = :app_id"
if source_filter.invoke_from is None:
app_scope += " AND m.invoke_from != :debugger"
else:
app_scope += " AND m.invoke_from = :source"
workflow_binding_filters = []
if source_filter.app_id:
workflow_binding_filters.append("wanb.app_id = :source_app_id")
if source_filter.workflow_id:
workflow_binding_filters.append("wanb.workflow_id = :workflow_id")
if source_filter.workflow_version:
workflow_binding_filters.append("wanb.workflow_version = :workflow_version")
if source_filter.node_id:
workflow_binding_filters.append("wanb.node_id = :node_id")
extra_workflow_filters = f"AND {' AND '.join(workflow_binding_filters)}" if workflow_binding_filters else ""
workflow_scope = f"""m.workflow_run_id IS NOT NULL
AND EXISTS (
SELECT 1
FROM workflow_runs wr
JOIN workflow_agent_node_bindings wanb
ON wanb.tenant_id = :tenant_id
AND wanb.agent_id = :agent_id
AND wanb.app_id = wr.app_id
AND wanb.workflow_id = wr.workflow_id
AND wanb.workflow_version = wr.version
{extra_workflow_filters}
JOIN workflow_node_executions wne
ON wne.workflow_run_id = wr.id
AND wne.node_id = wanb.node_id
WHERE wr.id = m.workflow_run_id
)"""
if source_filter.kind == "webapp":
return app_scope
if source_filter.kind == "workflow":
return workflow_scope
return f"(({app_scope}) OR ({workflow_scope}))"
@staticmethod
def _build_charts(rows: list[dict[str, Any]]) -> dict[str, list[dict[str, Any]]]:
messages = []

View File

@ -10,7 +10,9 @@ to the agent drive (Agent Files §5.4 / §4):
Both are stored as ``ToolFile`` records and bound via ``AgentDriveService.commit``
with ``value_owned_by_drive=True`` (the drive owns their lifecycle). The returned
skill ref records the stable drive paths + file ids (not just the raw upload id),
so the Composer can reload the bound skill list.
so the Composer can reload the bound skill list. The console ``/skills/upload``
endpoints delegate to this service so "upload" now always means drive-backed skill
normalization.
"""
from __future__ import annotations
@ -34,7 +36,7 @@ def slugify_skill_name(name: str) -> str:
class SkillStandardizeService:
"""Validate + standardize a Skill package into a per-agent drive."""
"""Validate + standardize a Skill package into a per-agent drive upload result."""
def __init__(
self,

View File

@ -75,7 +75,7 @@ class CreateAppParams(BaseModel):
class AppService:
@staticmethod
def _build_app_list_filters(
user_id: str, tenant_id: str, params: AppListBaseParams
user_id: str, tenant_id: str, params: AppListBaseParams, session: scoped_session
) -> list[sa.ColumnElement[bool]]:
filters = [App.tenant_id == tenant_id, App.is_universal == False]
@ -115,7 +115,7 @@ class AppService:
escaped_name = escape_like_pattern(name)
filters.append(App.name.ilike(f"%{escaped_name}%", escape="\\"))
if params.tag_ids and len(params.tag_ids) > 0:
target_ids = TagService.get_target_ids_by_tag_ids("app", tenant_id, params.tag_ids, match_all=True)
target_ids = TagService.get_target_ids_by_tag_ids("app", tenant_id, params.tag_ids, session, match_all=True)
if target_ids and len(target_ids) > 0:
filters.append(App.id.in_(target_ids))
else:
@ -197,7 +197,9 @@ class AppService:
).scalars()
)
def get_paginate_apps(self, user_id: str, tenant_id: str, params: AppListParams) -> Pagination | None:
def get_paginate_apps(
self, user_id: str, tenant_id: str, params: AppListParams, session: scoped_session
) -> Pagination | None:
"""
Get app list with pagination, filters, and explicit sort order.
:param user_id: user id
@ -205,7 +207,7 @@ class AppService:
:param params: query parameters
:return:
"""
filters = self._build_app_list_filters(user_id, tenant_id, params)
filters = self._build_app_list_filters(user_id, tenant_id, params, session)
if not filters:
return None
@ -231,12 +233,12 @@ class AppService:
return app_models
def get_paginate_starred_apps(
self, user_id: str, tenant_id: str, params: StarredAppListParams
self, user_id: str, tenant_id: str, params: StarredAppListParams, session: scoped_session
) -> Pagination | None:
"""
Get apps starred by the current account with pagination, filters, and explicit sort order.
"""
filters = self._build_app_list_filters(user_id, tenant_id, params)
filters = self._build_app_list_filters(user_id, tenant_id, params, session)
if not filters:
return None
@ -540,17 +542,21 @@ class AppService:
*,
name: str | None = None,
description: str | None = None,
role: str | None = None,
icon_type: IconType | str | None = None,
icon: str | None = None,
icon_background: str | None = None,
role: str | None = None,
account_id: str | None = None,
updated_at: datetime | None = None,
) -> None:
"""Keep the Roster identity aligned with its Agent App shell.
Agent Soul remains versioned through Composer. This helper only mirrors
user-facing identity fields so Roster and Agent Console do not drift.
user-facing identity fields, including the roster role/persona label,
so Roster and Agent Console do not drift.
Role omission is intentional: ``role=None`` preserves the backing
Agent's current role, while ``role=""`` explicitly clears it.
"""
agent = self._get_backing_agent_for_update(app)
if agent is None:
@ -560,14 +566,14 @@ class AppService:
agent.name = name
if description is not None:
agent.description = description
if role is not None:
agent.role = role
if icon_type is not None:
agent.icon_type = self._to_agent_icon_type(icon_type)
if icon is not None:
agent.icon = icon
if icon_background is not None:
agent.icon_background = icon_background
if role is not None:
agent.role = role
agent.updated_by = account_id
if updated_at is not None:
agent.updated_at = updated_at
@ -599,10 +605,12 @@ class AppService:
app,
name=app.name,
description=app.description,
# Omitted role must stay omitted here: None means "preserve current
# backing-agent role", while an empty string is an explicit clear.
role=args.get("role"),
icon_type=app.icon_type,
icon=app.icon,
icon_background=app.icon_background,
role=args.get("role"),
account_id=current_user.id,
updated_at=app.updated_at,
)

View File

@ -14,12 +14,13 @@ class AttachmentService:
_session_maker: sessionmaker
def __init__(self, session_factory: sessionmaker | Engine | None = None):
if isinstance(session_factory, Engine):
self._session_maker = sessionmaker(bind=session_factory)
elif isinstance(session_factory, sessionmaker):
self._session_maker = session_factory
else:
raise AssertionError("must be a sessionmaker or an Engine.")
match session_factory:
case Engine():
self._session_maker = sessionmaker(bind=session_factory)
case sessionmaker():
self._session_maker = session_factory
case _:
raise AssertionError("must be a sessionmaker or an Engine.")
def get_file_base64(self, file_id: str) -> str:
with self._session_maker(expire_on_commit=False) as session:

View File

@ -13,7 +13,7 @@ import sqlalchemy as sa
from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator
from redis.exceptions import LockNotOwnedError
from sqlalchemy import delete, exists, func, select, update
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.orm import Session, scoped_session, sessionmaker
from werkzeug.exceptions import Forbidden, NotFound
from configs import dify_config
@ -235,7 +235,9 @@ class _EstimateArgs(BaseModel):
class DatasetService:
@staticmethod
def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False):
def get_datasets(
page, per_page, session: scoped_session, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False
):
query = select(Dataset).where(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc(), Dataset.id)
if user:
@ -295,6 +297,7 @@ class DatasetService:
"knowledge",
tenant_id,
tag_ids,
session,
match_all=True,
)
else:

View File

@ -60,6 +60,7 @@ class ComposerSavePayload(BaseModel):
class RosterAgentCreatePayload(BaseModel):
name: str = Field(min_length=1, max_length=255)
mode: Literal["agent"] = "agent"
description: str = ""
role: str = Field(default="", max_length=255)
icon_type: AgentIconType | None = None

View File

@ -39,12 +39,13 @@ class FileService:
_session_maker: sessionmaker[Session]
def __init__(self, session_factory: sessionmaker | Engine | None = None):
if isinstance(session_factory, Engine):
self._session_maker = sessionmaker(bind=session_factory)
elif isinstance(session_factory, sessionmaker):
self._session_maker = session_factory
else:
raise AssertionError("must be a sessionmaker or an Engine.")
match session_factory:
case Engine():
self._session_maker = sessionmaker(bind=session_factory)
case sessionmaker():
self._session_maker = session_factory
case _:
raise AssertionError("must be a sessionmaker or an Engine.")
def upload_file(
self,

View File

@ -119,10 +119,11 @@ class HumanInputDeliveryTestService:
class EmailDeliveryTestHandler:
def __init__(self, session_factory: sessionmaker | Engine | None = None) -> None:
if session_factory is None:
session_factory = sessionmaker(bind=db.engine)
elif isinstance(session_factory, Engine):
session_factory = sessionmaker(bind=session_factory)
match session_factory:
case None:
session_factory = sessionmaker(bind=db.engine)
case Engine():
session_factory = sessionmaker(bind=session_factory)
self._session_factory = session_factory
def supports(self, method: DeliveryChannelConfig) -> bool:
@ -179,11 +180,12 @@ class EmailDeliveryTestHandler:
emails: list[str] = []
bound_reference_ids: list[str] = []
for recipient in recipients.items:
if isinstance(recipient, MemberRecipient):
bound_reference_ids.append(recipient.reference_id)
elif isinstance(recipient, ExternalRecipient):
if recipient.email:
emails.append(recipient.email)
match recipient:
case MemberRecipient():
bound_reference_ids.append(recipient.reference_id)
case ExternalRecipient():
if recipient.email:
emails.append(recipient.email)
if recipients.include_bound_group:
bound_reference_ids = []

View File

@ -6,7 +6,7 @@ from datetime import UTC, datetime
from typing import Any
from sqlalchemy import delete, func, select
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.orm import Session, scoped_session, sessionmaker
from core.db import session_factory
from core.workflow.node_factory import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
@ -192,6 +192,7 @@ class SnippetService:
self,
*,
tenant_id: str,
session: scoped_session,
page: int = 1,
limit: int = 20,
keyword: str | None = None,
@ -229,20 +230,19 @@ class SnippetService:
stmt = stmt.where(CustomizedSnippet.created_by.in_(creators))
if tag_ids:
target_ids = TagService.get_target_ids_by_tag_ids("snippet", tenant_id, tag_ids, match_all=True)
target_ids = TagService.get_target_ids_by_tag_ids("snippet", tenant_id, tag_ids, session, match_all=True)
if target_ids:
stmt = stmt.where(CustomizedSnippet.id.in_(target_ids))
else:
return [], 0, False
with self._session_scope() as session:
# Get total count
count_stmt = select(func.count()).select_from(stmt.subquery())
total = session.scalar(count_stmt) or 0
# Get total count
count_stmt = select(func.count()).select_from(stmt.subquery())
total = session.scalar(count_stmt) or 0
# Apply pagination
stmt = stmt.limit(limit + 1).offset((page - 1) * limit)
snippets = list(session.scalars(stmt).all())
# Apply pagination
stmt = stmt.limit(limit + 1).offset((page - 1) * limit)
snippets = list(session.scalars(stmt).all())
has_more = len(snippets) > limit
if has_more:

View File

@ -6,10 +6,9 @@ from flask_login import current_user
from pydantic import BaseModel, Field
from sqlalchemy import delete, func, select
from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, scoped_session
from werkzeug.exceptions import NotFound
from extensions.ext_database import db
from models.dataset import Dataset
from models.enums import TagType
from models.model import App, Tag, TagBinding
@ -56,7 +55,7 @@ class TagService:
@staticmethod
def get_target_ids_by_tag_ids(
tag_type: str, current_tenant_id: str, tag_ids: list[str], *, match_all: bool = False
tag_type: str, current_tenant_id: str, tag_ids: list[str], session: scoped_session, *, match_all: bool = False
):
"""
Return target IDs bound to tags for the given tenant and resource type.
@ -70,7 +69,7 @@ class TagService:
return []
# Deduplicate repeated query params so match_all counts each requested tag once.
requested_tag_ids = list(dict.fromkeys(tag_ids))
tags = db.session.scalars(
tags = session.scalars(
select(Tag.id).where(
Tag.id.in_(requested_tag_ids),
Tag.tenant_id == current_tenant_id,
@ -86,13 +85,13 @@ class TagService:
if match_all:
if len(tag_ids) != len(requested_tag_ids):
return []
return db.session.scalars(
return session.scalars(
select(TagBinding.target_id)
.where(TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id)
.group_by(TagBinding.target_id)
.having(func.count(sa.distinct(TagBinding.tag_id)) == len(tag_ids))
).all()
tag_bindings = db.session.scalars(
tag_bindings = session.scalars(
select(TagBinding.target_id).where(
TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id
)
@ -100,11 +99,11 @@ class TagService:
return tag_bindings
@staticmethod
def get_tag_by_tag_name(tag_type: str, current_tenant_id: str, tag_name: str):
def get_tag_by_tag_name(tag_type: str, current_tenant_id: str, tag_name: str, session: scoped_session):
if not tag_type or not tag_name:
return []
tags = list(
db.session.scalars(
session.scalars(
select(Tag).where(Tag.name == tag_name, Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
).all()
)
@ -113,8 +112,8 @@ class TagService:
return tags
@staticmethod
def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str):
tags = db.session.scalars(
def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str, session: scoped_session):
tags = session.scalars(
select(Tag)
.join(TagBinding, Tag.id == TagBinding.tag_id)
.where(
@ -128,8 +127,8 @@ class TagService:
return tags or []
@staticmethod
def save_tags(payload: SaveTagPayload) -> Tag:
if TagService.get_tag_by_tag_name(payload.type, current_user.current_tenant_id, payload.name):
def save_tags(payload: SaveTagPayload, session: scoped_session) -> Tag:
if TagService.get_tag_by_tag_name(payload.type, current_user.current_tenant_id, payload.name, session):
raise ValueError("Tag name already exists")
tag = Tag(
name=payload.name,
@ -138,17 +137,17 @@ class TagService:
tenant_id=current_user.current_tenant_id,
)
tag.id = str(uuid.uuid4())
db.session.add(tag)
db.session.commit()
session.add(tag)
session.commit()
return tag
@staticmethod
def update_tags(payload: UpdateTagPayload, tag_id: str) -> Tag:
tag = db.session.scalar(select(Tag).where(Tag.id == tag_id).limit(1))
def update_tags(payload: UpdateTagPayload, tag_id: str, session: scoped_session) -> Tag:
tag = session.scalar(select(Tag).where(Tag.id == tag_id).limit(1))
if not tag:
raise NotFound("Tag not found")
if payload.name != tag.name:
existing = db.session.scalar(
existing = session.scalar(
select(Tag)
.where(
Tag.name == payload.name,
@ -161,31 +160,31 @@ class TagService:
if existing:
raise ValueError("Tag name already exists")
tag.name = payload.name
db.session.commit()
session.commit()
return tag
@staticmethod
def get_tag_binding_count(tag_id: str) -> int:
count = db.session.scalar(select(func.count(TagBinding.id)).where(TagBinding.tag_id == tag_id)) or 0
def get_tag_binding_count(tag_id: str, session: scoped_session) -> int:
count = session.scalar(select(func.count(TagBinding.id)).where(TagBinding.tag_id == tag_id)) or 0
return count
@staticmethod
def delete_tag(tag_id: str):
tag = db.session.scalar(select(Tag).where(Tag.id == tag_id).limit(1))
def delete_tag(tag_id: str, session: scoped_session):
tag = session.scalar(select(Tag).where(Tag.id == tag_id).limit(1))
if not tag:
raise NotFound("Tag not found")
db.session.delete(tag)
session.delete(tag)
# delete tag binding
tag_bindings = db.session.scalars(select(TagBinding).where(TagBinding.tag_id == tag_id)).all()
tag_bindings = session.scalars(select(TagBinding).where(TagBinding.tag_id == tag_id)).all()
if tag_bindings:
for tag_binding in tag_bindings:
db.session.delete(tag_binding)
db.session.commit()
session.delete(tag_binding)
session.commit()
@staticmethod
def save_tag_binding(payload: TagBindingCreatePayload):
TagService.check_target_exists(payload.type, payload.target_id)
valid_tag_ids = db.session.scalars(
def save_tag_binding(payload: TagBindingCreatePayload, session: scoped_session):
TagService.check_target_exists(payload.type, payload.target_id, session)
valid_tag_ids = session.scalars(
select(Tag.id).where(
Tag.id.in_(payload.tag_ids),
Tag.tenant_id == current_user.current_tenant_id,
@ -193,7 +192,7 @@ class TagService:
)
).all()
for tag_id in valid_tag_ids:
tag_binding = db.session.scalar(
tag_binding = session.scalar(
select(TagBinding)
.where(TagBinding.tag_id == tag_id, TagBinding.target_id == payload.target_id)
.limit(1)
@ -206,15 +205,15 @@ class TagService:
tenant_id=current_user.current_tenant_id,
created_by=current_user.id,
)
db.session.add(new_tag_binding)
db.session.commit()
session.add(new_tag_binding)
session.commit()
@staticmethod
def delete_tag_binding(payload: TagBindingDeletePayload):
TagService.check_target_exists(payload.type, payload.target_id)
def delete_tag_binding(payload: TagBindingDeletePayload, session: scoped_session):
TagService.check_target_exists(payload.type, payload.target_id, session)
result = cast(
CursorResult,
db.session.execute(
session.execute(
delete(TagBinding).where(
TagBinding.target_id == payload.target_id,
TagBinding.tag_id.in_(payload.tag_ids),
@ -230,12 +229,12 @@ class TagService:
)
if result.rowcount:
db.session.commit()
session.commit()
@staticmethod
def check_target_exists(type: str, target_id: str):
def check_target_exists(type: str, target_id: str, session: scoped_session):
if type == "knowledge":
dataset = db.session.scalar(
dataset = session.scalar(
select(Dataset)
.where(Dataset.tenant_id == current_user.current_tenant_id, Dataset.id == target_id)
.limit(1)
@ -243,13 +242,13 @@ class TagService:
if not dataset:
raise NotFound("Dataset not found")
elif type == "app":
app = db.session.scalar(
app = session.scalar(
select(App).where(App.tenant_id == current_user.current_tenant_id, App.id == target_id).limit(1)
)
if not app:
raise NotFound("App not found")
elif type == "snippet":
snippet = db.session.scalar(
snippet = session.scalar(
select(CustomizedSnippet)
.where(CustomizedSnippet.tenant_id == current_user.current_tenant_id, CustomizedSnippet.id == target_id)
.limit(1)

View File

@ -125,10 +125,11 @@ class DraftVarLoader(VariableLoader):
# can be safely accessed before any offloading logic is applied.
for draft_var in draft_vars:
value = draft_var.get_value()
if isinstance(value, FileSegment):
files.append(value.value)
elif isinstance(value, ArrayFileSegment):
files.extend(value.value)
match value:
case FileSegment():
files.append(value.value)
case ArrayFileSegment():
files.extend(value.value)
with Session(bind=self._engine) as session:
storage_key_loader = StorageKeyLoader(
session,

View File

@ -34,10 +34,11 @@ class WorkflowRunService:
def __init__(self, session_factory: Engine | sessionmaker | None = None):
"""Initialize WorkflowRunService with repository dependencies."""
if session_factory is None:
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
elif isinstance(session_factory, Engine):
session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
match session_factory:
case None:
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
case Engine():
session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
self._session_factory = session_factory
self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(

View File

@ -131,10 +131,11 @@ def fetch_table_rows(
for row in rows:
normalized = dict(row)
for key, value in normalized.items():
if isinstance(value, datetime):
normalized[key] = value.isoformat()
elif isinstance(value, uuid.UUID):
normalized[key] = str(value)
match value:
case datetime():
normalized[key] = value.isoformat()
case uuid.UUID():
normalized[key] = str(value)
result.append(normalized)
return result

View File

@ -16,7 +16,7 @@ since these test controller-level behavior.
import uuid
from contextlib import ExitStack
from datetime import UTC, datetime
from unittest.mock import Mock, PropertyMock, patch
from unittest.mock import ANY, Mock, PropertyMock, patch
import pytest
from flask import Flask
@ -1129,7 +1129,7 @@ class TestDatasetTagsApiPatch:
assert status == 200
assert response == {"id": "tag-1", "name": "Updated Tag", "type": "knowledge", "binding_count": "5"}
mock_tag_svc.update_tags.assert_called_once()
update_payload, tag_id = mock_tag_svc.update_tags.call_args.args
update_payload, tag_id, session = mock_tag_svc.update_tags.call_args.args
assert update_payload.name == "Updated Tag"
assert tag_id == "tag-1"
@ -1184,7 +1184,7 @@ class TestDatasetTagsApiDelete:
result = api.delete(_=None)
assert result == ("", 204)
mock_tag_svc.delete_tag.assert_called_once_with("tag-1")
mock_tag_svc.delete_tag.assert_called_once_with("tag-1", ANY)
@patch("libs.login.current_user")
def test_delete_tag_forbidden(self, mock_current_user, app: Flask):
@ -1233,7 +1233,7 @@ class TestDatasetTagsBindingStatusApi:
assert status_code == 200
assert response["data"] == [{"id": "tag_1", "name": "Test Tag"}]
assert response["total"] == 1
mock_tag_svc.get_tags_by_target_id.assert_called_once_with("knowledge", "tenant_123", "dataset_123")
mock_tag_svc.get_tags_by_target_id.assert_called_once_with("knowledge", "tenant_123", "dataset_123", ANY)
class TestDatasetTagBindingApiPost:
@ -1266,7 +1266,8 @@ class TestDatasetTagBindingApiPost:
from services.tag_service import TagBindingCreatePayload
mock_tag_svc.save_tag_binding.assert_called_once_with(
TagBindingCreatePayload(tag_ids=["tag-1"], target_id="ds-1", type=TagType.KNOWLEDGE)
TagBindingCreatePayload(tag_ids=["tag-1"], target_id="ds-1", type=TagType.KNOWLEDGE),
ANY,
)
@patch("controllers.service_api.dataset.dataset.current_user")
@ -1317,7 +1318,8 @@ class TestDatasetTagUnbindingApiPost:
from services.tag_service import TagBindingDeletePayload
mock_tag_svc.delete_tag_binding.assert_called_once_with(
TagBindingDeletePayload(tag_ids=["tag-1"], target_id="ds-1", type=TagType.KNOWLEDGE)
TagBindingDeletePayload(tag_ids=["tag-1"], target_id="ds-1", type=TagType.KNOWLEDGE),
ANY,
)
@patch("controllers.service_api.dataset.dataset.TagService")
@ -1347,7 +1349,8 @@ class TestDatasetTagUnbindingApiPost:
from services.tag_service import TagBindingDeletePayload
mock_tag_svc.delete_tag_binding.assert_called_once_with(
TagBindingDeletePayload(tag_ids=["tag-1"], target_id="ds-1", type=TagType.KNOWLEDGE)
TagBindingDeletePayload(tag_ids=["tag-1"], target_id="ds-1", type=TagType.KNOWLEDGE),
ANY,
)
@patch("controllers.service_api.dataset.dataset.current_user")

View File

@ -2755,7 +2755,7 @@ class TestRegisterService:
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test inviting an existing member who is not in the tenant yet.
Test inviting an existing active account who is not in the tenant yet.
"""
fake = Faker()
tenant_name = fake.company()
@ -2791,20 +2791,20 @@ class TestRegisterService:
# Mock the email task
with patch("services.account_service.send_invite_member_mail_task") as mock_send_mail:
mock_send_mail.delay.return_value = None
with pytest.raises(AccountAlreadyInTenantError, match="Account already in tenant."):
# Execute invitation
token = RegisterService.invite_new_member(
tenant=tenant,
email=existing_member_email,
language=language,
role="admin",
inviter=inviter,
)
# Verify email task was not called
mock_send_mail.delay.assert_not_called()
token = RegisterService.invite_new_member(
tenant=tenant,
email=existing_member_email,
language=language,
role="admin",
inviter=inviter,
)
# Verify tenant member was created for existing account
assert token is not None
assert len(token) > 0
mock_send_mail.delay.assert_called_once()
# Existing active accounts must accept the invite before becoming workspace members.
from models.account import TenantAccountJoin
tenant_join = (
@ -2812,8 +2812,13 @@ class TestRegisterService:
.filter_by(tenant_id=tenant.id, account_id=existing_account.id)
.first()
)
assert tenant_join is not None
assert tenant_join.role == "admin"
assert tenant_join is None
invitation = RegisterService.get_invitation_if_token_valid(None, None, token)
assert invitation is not None
assert invitation["account"].id == existing_account.id
assert invitation["data"]["role"] == "admin"
assert invitation["data"]["requires_setup"] is False
def test_invite_new_member_existing_member(
self, db_session_with_containers: Session, mock_external_service_dependencies

View File

@ -234,7 +234,7 @@ class TestAppService:
# Get paginated apps
params = AppListParams(page=1, limit=10, mode="chat")
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, params)
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, params, db_session_with_containers)
# Verify pagination results
assert paginated_apps is not None
@ -295,7 +295,7 @@ class TestAppService:
db_session_with_containers.commit()
last_modified_apps = app_service.get_paginate_apps(
account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat")
account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat"), db_session_with_containers
)
assert last_modified_apps is not None
assert [app.name for app in last_modified_apps.items] == [
@ -305,7 +305,10 @@ class TestAppService:
]
recently_created_apps = app_service.get_paginate_apps(
account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat", sort_by="recently_created")
account.id,
tenant.id,
AppListParams(page=1, limit=10, mode="chat", sort_by="recently_created"),
db_session_with_containers,
)
assert recently_created_apps is not None
assert [app.name for app in recently_created_apps.items] == [
@ -315,7 +318,10 @@ class TestAppService:
]
earliest_created_apps = app_service.get_paginate_apps(
account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat", sort_by="earliest_created")
account.id,
tenant.id,
AppListParams(page=1, limit=10, mode="chat", sort_by="earliest_created"),
db_session_with_containers,
)
assert earliest_created_apps is not None
assert [app.name for app in earliest_created_apps.items] == [
@ -366,7 +372,7 @@ class TestAppService:
assert star_count == 1
paginated_apps = app_service.get_paginate_apps(
account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat")
account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat"), db_session_with_containers
)
assert paginated_apps is not None
starred_by_app_id = {app.id: app.is_starred for app in paginated_apps.items}
@ -377,7 +383,7 @@ class TestAppService:
db_session_with_containers.commit()
paginated_apps = app_service.get_paginate_apps(
account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat")
account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat"), db_session_with_containers
)
assert paginated_apps is not None
starred_by_app_id = {app.id: app.is_starred for app in paginated_apps.items}
@ -442,7 +448,7 @@ class TestAppService:
db_session_with_containers.commit()
last_modified_apps = app_service.get_paginate_starred_apps(
account.id, tenant.id, StarredAppListParams(page=1, limit=10, mode="chat")
account.id, tenant.id, StarredAppListParams(page=1, limit=10, mode="chat"), db_session_with_containers
)
assert last_modified_apps is not None
assert [app.name for app in last_modified_apps.items] == [
@ -457,6 +463,7 @@ class TestAppService:
account.id,
tenant.id,
StarredAppListParams(page=1, limit=10, mode="chat", sort_by="recently_created"),
db_session_with_containers,
)
assert recently_created_apps is not None
assert [app.name for app in recently_created_apps.items] == [
@ -469,6 +476,7 @@ class TestAppService:
account.id,
tenant.id,
StarredAppListParams(page=1, limit=10, mode="chat", sort_by="earliest_created"),
db_session_with_containers,
)
assert earliest_created_apps is not None
assert [app.name for app in earliest_created_apps.items] == [
@ -522,20 +530,25 @@ class TestAppService:
completion_app = app_service.create_app(tenant.id, completion_app_params, account)
# Test filter by mode
chat_apps = app_service.get_paginate_apps(account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat"))
chat_apps = app_service.get_paginate_apps(
account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat"), db_session_with_containers
)
assert len(chat_apps.items) == 1
assert chat_apps.items[0].mode == "chat"
# Test filter by name
filtered_apps = app_service.get_paginate_apps(
account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat", name="Chat")
account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat", name="Chat"), db_session_with_containers
)
assert len(filtered_apps.items) == 1
assert "Chat" in filtered_apps.items[0].name
# Test filter by created_by_me
my_apps = app_service.get_paginate_apps(
account.id, tenant.id, AppListParams(page=1, limit=10, mode="completion", is_created_by_me=True)
account.id,
tenant.id,
AppListParams(page=1, limit=10, mode="completion", is_created_by_me=True),
db_session_with_containers,
)
assert len(my_apps.items) == 1
@ -588,6 +601,7 @@ class TestAppService:
first_account.id,
tenant.id,
AppListParams(page=1, limit=10, mode="chat", creator_ids=[second_account.id]),
db_session_with_containers,
)
assert filtered_apps is not None
@ -635,10 +649,12 @@ class TestAppService:
# Test with tag filter
params = AppListParams(page=1, limit=10, mode="chat", tag_ids=["tag1", "tag2"])
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, params)
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, params, db_session_with_containers)
# Verify tag service was called
mock_tag_service.assert_called_once_with("app", tenant.id, ["tag1", "tag2"], match_all=True)
mock_tag_service.assert_called_once_with(
"app", tenant.id, ["tag1", "tag2"], db_session_with_containers, match_all=True
)
# Verify results
assert paginated_apps is not None
@ -651,7 +667,7 @@ class TestAppService:
params = AppListParams(page=1, limit=10, mode="chat", tag_ids=["nonexistent_tag"])
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, params)
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, params, db_session_with_containers)
# Should return None when no apps match tag filter
assert paginated_apps is None
@ -1467,7 +1483,7 @@ class TestAppService:
# Test 1: Search with % character
paginated_apps = app_service.get_paginate_apps(
account.id, tenant.id, AppListParams(name="50%", mode="chat", page=1, limit=10)
account.id, tenant.id, AppListParams(name="50%", mode="chat", page=1, limit=10), db_session_with_containers
)
assert paginated_apps is not None
assert paginated_apps.total == 1
@ -1476,7 +1492,10 @@ class TestAppService:
# Test 2: Search with _ character
paginated_apps = app_service.get_paginate_apps(
account.id, tenant.id, AppListParams(name="test_data", mode="chat", page=1, limit=10)
account.id,
tenant.id,
AppListParams(name="test_data", mode="chat", page=1, limit=10),
db_session_with_containers,
)
assert paginated_apps is not None
assert paginated_apps.total == 1
@ -1485,7 +1504,10 @@ class TestAppService:
# Test 3: Search with \ character
paginated_apps = app_service.get_paginate_apps(
account.id, tenant.id, AppListParams(name="path\\to\\app", mode="chat", page=1, limit=10)
account.id,
tenant.id,
AppListParams(name="path\\to\\app", mode="chat", page=1, limit=10),
db_session_with_containers,
)
assert paginated_apps is not None
assert paginated_apps.total == 1
@ -1494,7 +1516,7 @@ class TestAppService:
# Test 4: Search with % should NOT match 100% (verifies escaping works)
paginated_apps = app_service.get_paginate_apps(
account.id, tenant.id, AppListParams(name="50%", mode="chat", page=1, limit=10)
account.id, tenant.id, AppListParams(name="50%", mode="chat", page=1, limit=10), db_session_with_containers
)
assert paginated_apps is not None
assert paginated_apps.total == 1

View File

@ -227,7 +227,7 @@ class TestDatasetServiceGetDatasets:
)
# Act
datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant.id)
datasets, total = DatasetService.get_datasets(page, per_page, db_session_with_containers, tenant_id=tenant.id)
# Assert
assert len(datasets) == 5
@ -257,7 +257,9 @@ class TestDatasetServiceGetDatasets:
)
# Act
datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant.id, search=search)
datasets, total = DatasetService.get_datasets(
page, per_page, db_session_with_containers, tenant_id=tenant.id, search=search
)
# Assert
assert len(datasets) == 1
@ -301,7 +303,9 @@ class TestDatasetServiceGetDatasets:
tag_ids = [tag_1.id, tag_2.id]
# Act
datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant.id, tag_ids=tag_ids)
datasets, total = DatasetService.get_datasets(
page, per_page, db_session_with_containers, tenant_id=tenant.id, tag_ids=tag_ids
)
# Assert
assert len(datasets) == 1
@ -326,7 +330,9 @@ class TestDatasetServiceGetDatasets:
)
# Act
datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant.id, tag_ids=tag_ids)
datasets, total = DatasetService.get_datasets(
page, per_page, db_session_with_containers, tenant_id=tenant.id, tag_ids=tag_ids
)
# Assert
# When tag_ids is empty, tag filtering is skipped, so normal query results are returned
@ -356,7 +362,9 @@ class TestDatasetServiceGetDatasets:
)
# Act
datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant.id, user=None)
datasets, total = DatasetService.get_datasets(
page, per_page, db_session_with_containers, tenant_id=tenant.id, user=None
)
# Assert
assert len(datasets) == 1
@ -384,6 +392,7 @@ class TestDatasetServiceGetDatasets:
datasets, total = DatasetService.get_datasets(
page=1,
per_page=20,
session=db_session_with_containers,
tenant_id=tenant.id,
user=owner,
include_all=True,
@ -408,7 +417,9 @@ class TestDatasetServiceGetDatasets:
)
# Act
datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant.id, user=user)
datasets, total = DatasetService.get_datasets(
page=1, per_page=20, session=db_session_with_containers, tenant_id=tenant.id, user=user
)
# Assert
assert len(datasets) == 1
@ -432,7 +443,9 @@ class TestDatasetServiceGetDatasets:
)
# Act
datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant.id, user=user)
datasets, total = DatasetService.get_datasets(
page=1, per_page=20, session=db_session_with_containers, tenant_id=tenant.id, user=user
)
# Assert
assert len(datasets) == 1
@ -459,7 +472,9 @@ class TestDatasetServiceGetDatasets:
)
# Act
datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant.id, user=user)
datasets, total = DatasetService.get_datasets(
page=1, per_page=20, session=db_session_with_containers, tenant_id=tenant.id, user=user
)
# Assert
assert len(datasets) == 1
@ -486,7 +501,9 @@ class TestDatasetServiceGetDatasets:
)
# Act
datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant.id, user=operator)
datasets, total = DatasetService.get_datasets(
page=1, per_page=20, session=db_session_with_containers, tenant_id=tenant.id, user=operator
)
# Assert
assert len(datasets) == 1
@ -509,7 +526,9 @@ class TestDatasetServiceGetDatasets:
)
# Act
datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant.id, user=operator)
datasets, total = DatasetService.get_datasets(
page=1, per_page=20, session=db_session_with_containers, tenant_id=tenant.id, user=operator
)
# Assert
assert datasets == []

View File

@ -449,7 +449,7 @@ class TestTagService:
# Act: Execute the method under test
tag_ids = [tag.id for tag in tags]
result = TagService.get_target_ids_by_tag_ids("knowledge", tenant.id, tag_ids)
result = TagService.get_target_ids_by_tag_ids("knowledge", tenant.id, tag_ids, db_session_with_containers)
# Assert: Verify the expected outcomes
assert result is not None
@ -485,7 +485,7 @@ class TestTagService:
)
# Act: Execute the method under test with empty tag IDs
result = TagService.get_target_ids_by_tag_ids("knowledge", tenant.id, [])
result = TagService.get_target_ids_by_tag_ids("knowledge", tenant.id, [], db_session_with_containers)
# Assert: Verify the expected outcomes
assert result is not None
@ -533,13 +533,19 @@ class TestTagService:
# Act: Execute the method under test
tag_ids = [tag.id for tag in tags]
result = TagService.get_target_ids_by_tag_ids("knowledge", tenant.id, tag_ids, match_all=True)
result = TagService.get_target_ids_by_tag_ids(
"knowledge", tenant.id, tag_ids, db_session_with_containers, match_all=True
)
# Assert: Verify the expected outcomes
assert result == [dataset_with_all_tags.id]
missing_tag_result = TagService.get_target_ids_by_tag_ids(
"knowledge", tenant.id, [tags[0].id, str(uuid.uuid4())], match_all=True
"knowledge",
tenant.id,
[tags[0].id, str(uuid.uuid4())],
db_session_with_containers,
match_all=True,
)
assert missing_tag_result == []
@ -565,7 +571,9 @@ class TestTagService:
non_existent_tag_ids = [str(uuid.uuid4()), str(uuid.uuid4())]
# Act: Execute the method under test
result = TagService.get_target_ids_by_tag_ids("knowledge", tenant.id, non_existent_tag_ids)
result = TagService.get_target_ids_by_tag_ids(
"knowledge", tenant.id, non_existent_tag_ids, db_session_with_containers
)
# Assert: Verify the expected outcomes
assert result is not None
@ -599,7 +607,7 @@ class TestTagService:
db_session_with_containers.commit()
# Act: Execute the method under test
result = TagService.get_tag_by_tag_name("app", tenant.id, "python_tag")
result = TagService.get_tag_by_tag_name("app", tenant.id, "python_tag", db_session_with_containers)
# Assert: Verify the expected outcomes
assert result is not None
@ -625,7 +633,7 @@ class TestTagService:
)
# Act: Execute the method under test with non-existent tag name
result = TagService.get_tag_by_tag_name("knowledge", tenant.id, "nonexistent_tag")
result = TagService.get_tag_by_tag_name("knowledge", tenant.id, "nonexistent_tag", db_session_with_containers)
# Assert: Verify the expected outcomes
assert result is not None
@ -650,8 +658,8 @@ class TestTagService:
)
# Act: Execute the method under test with empty parameters
result_empty_type = TagService.get_tag_by_tag_name("", tenant.id, "test_tag")
result_empty_name = TagService.get_tag_by_tag_name("knowledge", tenant.id, "")
result_empty_type = TagService.get_tag_by_tag_name("", tenant.id, "test_tag", db_session_with_containers)
result_empty_name = TagService.get_tag_by_tag_name("knowledge", tenant.id, "", db_session_with_containers)
# Assert: Verify the expected outcomes
assert result_empty_type is not None
@ -688,7 +696,7 @@ class TestTagService:
)
# Act: Execute the method under test
result = TagService.get_tags_by_target_id("app", tenant.id, app.id)
result = TagService.get_tags_by_target_id("app", tenant.id, app.id, db_session_with_containers)
# Assert: Verify the expected outcomes
assert result is not None
@ -720,7 +728,7 @@ class TestTagService:
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, tenant.id)
# Act: Execute the method under test
result = TagService.get_tags_by_target_id("app", tenant.id, app.id)
result = TagService.get_tags_by_target_id("app", tenant.id, app.id, db_session_with_containers)
# Assert: Verify the expected outcomes
assert result is not None
@ -745,7 +753,7 @@ class TestTagService:
tag_args = SaveTagPayload(name="test_tag_name", type="knowledge")
# Act: Execute the method under test
result = TagService.save_tags(tag_args)
result = TagService.save_tags(tag_args, db_session_with_containers)
# Assert: Verify the expected outcomes
assert result is not None
@ -783,11 +791,11 @@ class TestTagService:
# Create first tag
tag_args = SaveTagPayload(name="duplicate_tag", type="app")
TagService.save_tags(tag_args)
TagService.save_tags(tag_args, db_session_with_containers)
# Act & Assert: Verify proper error handling
with pytest.raises(ValueError) as exc_info:
TagService.save_tags(tag_args)
TagService.save_tags(tag_args, db_session_with_containers)
assert "Tag name already exists" in str(exc_info.value)
def test_update_tags_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
@ -807,13 +815,13 @@ class TestTagService:
# Create a tag to update
tag_args = SaveTagPayload(name="original_name", type="knowledge")
tag = TagService.save_tags(tag_args)
tag = TagService.save_tags(tag_args, db_session_with_containers)
# Update args
update_args = UpdateTagPayload(name="updated_name")
# Act: Execute the method under test
result = TagService.update_tags(update_args, tag.id)
result = TagService.update_tags(update_args, tag.id, db_session_with_containers)
# Assert: Verify the expected outcomes
assert result is not None
@ -854,7 +862,7 @@ class TestTagService:
# Act & Assert: Verify proper error handling
with pytest.raises(NotFound) as exc_info:
TagService.update_tags(update_args, non_existent_tag_id)
TagService.update_tags(update_args, non_existent_tag_id, db_session_with_containers)
assert "Tag not found" in str(exc_info.value)
def test_update_tags_duplicate_name_error(
@ -875,17 +883,17 @@ class TestTagService:
# Create two tags
tag1_args = SaveTagPayload(name="first_tag", type="app")
tag1 = TagService.save_tags(tag1_args)
tag1 = TagService.save_tags(tag1_args, db_session_with_containers)
tag2_args = SaveTagPayload(name="second_tag", type="app")
tag2 = TagService.save_tags(tag2_args)
tag2 = TagService.save_tags(tag2_args, db_session_with_containers)
# Try to update second tag with first tag's name
update_args = UpdateTagPayload(name="first_tag")
# Act & Assert: Verify proper error handling
with pytest.raises(ValueError) as exc_info:
TagService.update_tags(update_args, tag2.id)
TagService.update_tags(update_args, tag2.id, db_session_with_containers)
assert "Tag name already exists" in str(exc_info.value)
def test_get_tag_binding_count_success(
@ -917,8 +925,8 @@ class TestTagService:
)
# Act: Execute the method under test
result_tag_with_bindings = TagService.get_tag_binding_count(tags[0].id)
result_tag_without_bindings = TagService.get_tag_binding_count(tags[1].id)
result_tag_with_bindings = TagService.get_tag_binding_count(tags[0].id, db_session_with_containers)
result_tag_without_bindings = TagService.get_tag_binding_count(tags[1].id, db_session_with_containers)
# Assert: Verify the expected outcomes
assert result_tag_with_bindings == 1
@ -946,7 +954,7 @@ class TestTagService:
non_existent_tag_id = str(uuid.uuid4())
# Act: Execute the method under test
result = TagService.get_tag_binding_count(non_existent_tag_id)
result = TagService.get_tag_binding_count(non_existent_tag_id, db_session_with_containers)
# Assert: Verify the expected outcomes
assert result == 0
@ -986,7 +994,7 @@ class TestTagService:
assert binding_before is not None
# Act: Execute the method under test
TagService.delete_tag(tag.id)
TagService.delete_tag(tag.id, db_session_with_containers)
# Assert: Verify the expected outcomes
# Verify tag was deleted
@ -1018,7 +1026,7 @@ class TestTagService:
# Act & Assert: Verify proper error handling
with pytest.raises(NotFound) as exc_info:
TagService.delete_tag(non_existent_tag_id)
TagService.delete_tag(non_existent_tag_id, db_session_with_containers)
assert "Tag not found" in str(exc_info.value)
def test_save_tag_binding_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
@ -1048,7 +1056,7 @@ class TestTagService:
binding_payload = TagBindingCreatePayload(
type="knowledge", target_id=dataset.id, tag_ids=[tag.id for tag in tags]
)
TagService.save_tag_binding(binding_payload)
TagService.save_tag_binding(binding_payload, db_session_with_containers)
# Assert: Verify the expected outcomes
@ -1090,10 +1098,10 @@ class TestTagService:
# Create first binding
binding_payload = TagBindingCreatePayload(type="app", target_id=app.id, tag_ids=[tag.id])
TagService.save_tag_binding(binding_payload)
TagService.save_tag_binding(binding_payload, db_session_with_containers)
# Act: Try to create duplicate binding
TagService.save_tag_binding(binding_payload)
TagService.save_tag_binding(binding_payload, db_session_with_containers)
# Assert: Verify the expected outcomes
@ -1173,7 +1181,7 @@ class TestTagService:
delete_payload = TagBindingDeletePayload(
type="knowledge", target_id=dataset.id, tag_ids=[tag.id for tag in tags]
)
TagService.delete_tag_binding(delete_payload)
TagService.delete_tag_binding(delete_payload, db_session_with_containers)
# Assert: Verify the expected outcomes
# Verify tag bindings were deleted
@ -1209,7 +1217,7 @@ class TestTagService:
# Act: Try to delete non-existent binding
delete_payload = TagBindingDeletePayload(type="app", target_id=app.id, tag_ids=[tag.id])
TagService.delete_tag_binding(delete_payload)
TagService.delete_tag_binding(delete_payload, db_session_with_containers)
# Assert: Verify the expected outcomes
# No error should be raised, and database state should remain unchanged
@ -1240,7 +1248,7 @@ class TestTagService:
dataset = self._create_test_dataset(db_session_with_containers, mock_external_service_dependencies, tenant.id)
# Act: Execute the method under test
TagService.check_target_exists("knowledge", dataset.id)
TagService.check_target_exists("knowledge", dataset.id, db_session_with_containers)
# Assert: Verify the expected outcomes
# No exception should be raised for existing dataset
@ -1268,7 +1276,7 @@ class TestTagService:
# Act & Assert: Verify proper error handling
with pytest.raises(NotFound) as exc_info:
TagService.check_target_exists("knowledge", non_existent_dataset_id)
TagService.check_target_exists("knowledge", non_existent_dataset_id, db_session_with_containers)
assert "Dataset not found" in str(exc_info.value)
def test_check_target_exists_app_success(
@ -1292,7 +1300,7 @@ class TestTagService:
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, tenant.id)
# Act: Execute the method under test
TagService.check_target_exists("app", app.id)
TagService.check_target_exists("app", app.id, db_session_with_containers)
# Assert: Verify the expected outcomes
# No exception should be raised for existing app
@ -1320,7 +1328,7 @@ class TestTagService:
# Act & Assert: Verify proper error handling
with pytest.raises(NotFound) as exc_info:
TagService.check_target_exists("app", non_existent_app_id)
TagService.check_target_exists("app", non_existent_app_id, db_session_with_containers)
assert "App not found" in str(exc_info.value)
def test_check_target_exists_invalid_type(
@ -1346,5 +1354,5 @@ class TestTagService:
# Act & Assert: Verify proper error handling
with pytest.raises(NotFound) as exc_info:
TagService.check_target_exists("invalid_type", non_existent_target_id)
TagService.check_target_exists("invalid_type", non_existent_target_id, db_session_with_containers)
assert "Invalid binding type" in str(exc_info.value)

View File

@ -266,6 +266,7 @@ def test_patch_union_schema_markdown_ignores_unrenderable_shapes(tmp_path):
assert module._schema_ref_name(None) is None
assert module._schema_markdown_type(None) == ""
assert module._schema_markdown_type({"anyOf": [{"type": "null"}]}) == ""
assert module._strip_trailing_line_whitespace("line \ncell\t \n") == "line\ncell\n"
assert module._replace_schema_table_type("unchanged", "Definition", "field", "") == "unchanged"
assert (
module._replace_schema_table_type(
@ -319,7 +320,10 @@ def test_convert_spec_to_markdown_patches_generated_union_tables(tmp_path, monke
assert kwargs["check"] is False
markdown_path = Path(args[args.index("-o") + 1])
markdown_path.write_text(
"""#### FormInputConfig
"Intro line"
+ " \n"
+ """
#### FormInputConfig
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
@ -340,5 +344,7 @@ def test_convert_spec_to_markdown_patches_generated_union_tables(tmp_path, monke
module._convert_spec_to_markdown(spec_path, output_path)
converted = output_path.read_text(encoding="utf-8")
assert "Intro line \n" not in converted
assert "Intro line\n" in converted
assert "| FormInputConfig | [ParagraphInputConfig](#paragraphinputconfig) | | |" in converted
assert "| default | [StringSource](#stringsource) | | No |" in converted

View File

@ -8,12 +8,13 @@ from pathlib import Path
def _walk_values(value):
yield value
if isinstance(value, dict):
for child in value.values():
yield from _walk_values(child)
elif isinstance(value, list):
for child in value:
yield from _walk_values(child)
match value:
case dict():
for child in value.values():
yield from _walk_values(child)
case list():
for child in value:
yield from _walk_values(child)
def _load_generate_swagger_specs_module():
@ -106,6 +107,39 @@ def test_generate_specs_writes_get_operations_without_request_bodies(tmp_path):
assert all("requestBody" not in operation for operation in _get_operations(payload))
def test_generate_specs_writes_service_api_reference_descriptions(tmp_path):
module = _load_generate_swagger_specs_module()
written_paths = module.generate_specs(tmp_path)
service_path = next(path for path in written_paths if path.name == "service-openapi.json")
payload = json.loads(service_path.read_text(encoding="utf-8"))
chat_operation = payload["paths"]["/chat-messages"]["post"]
assert chat_operation["summary"] == "Send Chat Message"
assert chat_operation["description"] == "Send a request to the chat application."
assert chat_operation["tags"] == ["Chatflows", "Chats"]
rename_operation = payload["paths"]["/conversations/{c_id}/name"]["post"]
assert rename_operation["summary"] == "Rename Conversation"
def test_standalone_inline_model_name_includes_list_constraints():
module = _load_generate_swagger_specs_module()
from flask_restx import fields
cases = (
({"min_items": 1}, {"min_items": 2}),
({"max_items": 1}, {"max_items": 2}),
({"unique": True}, {"unique": False}),
)
for first_kwargs, second_kwargs in cases:
first_inline_model = {"items": fields.List(fields.String, **first_kwargs)}
second_inline_model = {"items": fields.List(fields.String, **second_kwargs)}
assert module._inline_model_name(first_inline_model) != module._inline_model_name(second_inline_model)
def test_generate_specs_is_idempotent(tmp_path):
module = _load_generate_swagger_specs_module()

View File

@ -24,7 +24,9 @@ from controllers.console.agent.roster import (
AgentAppCopyApi,
AgentAppListApi,
AgentInviteOptionsApi,
AgentLogMessagesApi,
AgentLogsApi,
AgentLogSourcesApi,
AgentRosterVersionDetailApi,
AgentRosterVersionsApi,
AgentStatisticsSummaryApi,
@ -93,6 +95,7 @@ def _agent_app_composer_response() -> dict:
def _app_detail_obj(**overrides):
data = {
"id": "app-1",
"tenant_id": "tenant-1",
"name": "Iris",
"description": "Agent app",
"mode_compatible_with_agent": "agent",
@ -116,7 +119,6 @@ def _app_detail_obj(**overrides):
"deleted_tools": [],
"site": None,
"bound_agent_id": "00000000-0000-0000-0000-000000000001",
"tenant_id": "tenant-1",
}
data.update(overrides)
return SimpleNamespace(**data)
@ -153,6 +155,8 @@ def test_agent_v2_console_routes_are_agent_id_first() -> None:
"/agent/<uuid:agent_id>/chat-messages/<uuid:message_id>/suggested-questions",
"/agent/<uuid:agent_id>/messages/<uuid:message_id>",
"/agent/<uuid:agent_id>/logs",
"/agent/<uuid:agent_id>/logs/<uuid:conversation_id>/messages",
"/agent/<uuid:agent_id>/log-sources",
"/agent/<uuid:agent_id>/statistics/summary",
"/agent/invite-options",
):
@ -187,7 +191,7 @@ def test_agent_app_list_and_create_use_agent_route(
def get_app(self, app_obj: object) -> object:
return app_obj
def get_paginate_apps(self, user_id: str, tenant_id: str, params) -> object:
def get_paginate_apps(self, user_id: str, tenant_id: str, params, session) -> object:
captured["list"] = {"user_id": user_id, "tenant_id": tenant_id, "params": params}
return SimpleNamespace(
page=1,
@ -270,7 +274,13 @@ def test_agent_app_list_and_create_use_agent_route(
with app.test_request_context(
"/console/api/agent",
json={"name": "Iris", "description": "Agent app", "icon_type": "emoji", "icon": "robot"},
json={
"name": "Iris",
"description": "Agent app",
"role": "Coordinator",
"icon_type": "emoji",
"icon": "robot",
},
):
created, status = unwrap(AgentAppListApi.post)(AgentAppListApi(), "tenant-1", SimpleNamespace(id=account_id))
@ -283,6 +293,23 @@ def test_agent_app_list_and_create_use_agent_route(
create_call = cast(dict[str, object], captured["create"])
create_params = cast(Any, create_call["params"])
assert create_params.mode == "agent"
assert create_params.agent_role == "Coordinator"
def test_agent_app_create_requires_role(app: Flask, account_id: str) -> None:
with app.test_request_context(
"/console/api/agent",
json={"name": "Iris", "description": "Agent app", "icon_type": "emoji", "icon": "robot"},
):
with pytest.raises(ValueError, match="Field required"):
unwrap(AgentAppListApi.post)(AgentAppListApi(), "tenant-1", SimpleNamespace(id=account_id))
with app.test_request_context(
"/console/api/agent",
json={"name": "Iris", "description": "Agent app", "role": " ", "icon_type": "emoji", "icon": "robot"},
):
with pytest.raises(ValueError, match="Agent role is required"):
unwrap(AgentAppListApi.post)(AgentAppListApi(), "tenant-1", SimpleNamespace(id=account_id))
def test_agent_app_detail_update_delete_resolve_app_from_agent_id(
@ -331,7 +358,7 @@ def test_agent_app_detail_update_delete_resolve_app_from_agent_id(
with app.test_request_context(
"/console/api/agent/00000000-0000-0000-0000-000000000001",
json={"name": "Renamed", "description": "", "icon_type": "emoji", "icon": "R"},
json={"name": "Renamed", "description": "", "role": "Reviewer", "icon_type": "emoji", "icon": "R"},
):
updated = unwrap(AgentAppApi.put)(AgentAppApi(), "tenant-1", agent_id)
@ -343,6 +370,7 @@ def test_agent_app_detail_update_delete_resolve_app_from_agent_id(
assert "bound_agent_id" not in updated
update_call = cast(dict[str, object], captured["update"])
assert update_call["app"] is app_model
assert cast(dict[str, object], update_call["args"])["role"] == "Reviewer"
deleted, status = unwrap(AgentAppApi.delete)(AgentAppApi(), "tenant-1", agent_id)
assert (deleted, status) == ("", 204)
@ -395,6 +423,45 @@ def test_agent_app_copy_uses_agent_id_and_returns_agent_detail(
}
def test_agent_app_update_rejects_empty_role(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
agent_id = "00000000-0000-0000-0000-000000000001"
app_model = _app_detail_obj(id="app-1", bound_agent_id=agent_id)
captured: dict[str, object] = {}
monkeypatch.setattr(
roster_controller.AgentRosterService,
"get_agent_app_model",
lambda _self, **kwargs: app_model,
)
monkeypatch.setattr(
roster_controller.AgentRosterService,
"get_app_backing_agent",
lambda _self, **kwargs: SimpleNamespace(id=agent_id, role="", active_config_snapshot_id=None),
)
monkeypatch.setattr(
roster_controller.FeatureService,
"get_system_features",
lambda: SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)),
)
class FakeAppService:
def get_app(self, app_obj: object) -> object:
return app_obj
def update_app(self, app_obj: object, args: dict[str, object]) -> object:
captured["update"] = {"app": app_obj, "args": args}
return _app_detail_obj(id="app-1", name=args["name"], bound_agent_id=agent_id)
monkeypatch.setattr(roster_controller, "AppService", FakeAppService)
with app.test_request_context(
"/console/api/agent/00000000-0000-0000-0000-000000000001",
json={"name": "Renamed", "description": "", "role": "", "icon_type": "emoji", "icon": "R"},
):
with pytest.raises(ValueError, match="String should have at least 1 character"):
unwrap(AgentAppApi.put)(AgentAppApi(), "tenant-1", agent_id)
def test_invite_options_get_parses_app_id(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
captured: dict[str, object] = {}
@ -461,21 +528,59 @@ def test_agent_observability_routes_resolve_app_from_agent_id(
captured: dict[str, object] = {}
class FakeObservabilityService:
def list_logs(self, *, app, params):
captured["logs"] = {"app": app, "params": params}
def list_logs(self, *, app, agent_id, params):
captured["logs"] = {"app": app, "agent_id": agent_id, "params": params}
return {
"data": [
{
"conversation_id": "conversation-1",
"id": "conversation-1",
"title": "Debug",
"end_user_id": "end-user-1",
"message_count": 2,
"user_rate": None,
"operation_rate": None,
"unread": True,
"source": {
"id": "webapp:app-1",
"type": "webapp",
"app_id": "app-1",
"app_name": "Iris",
"app_icon_type": "emoji",
"app_icon": "robot",
"app_icon_background": "#fff",
"workflow_id": None,
"workflow_version": None,
"node_id": None,
},
"status": "success",
"created_at": 1,
"updated_at": 2,
}
],
"page": 2,
"limit": 5,
"total": 6,
"has_more": False,
}
def list_log_messages(self, *, app, agent_id, conversation_id, params):
captured["messages"] = {
"app": app,
"agent_id": agent_id,
"conversation_id": conversation_id,
"params": params,
}
return {
"data": [
{
"id": "message-1",
"message_id": "message-1",
"conversation_id": "conversation-1",
"conversation_name": "Debug",
"query": "hello",
"answer": "hi",
"status": "success",
"error": None,
"source": "explore",
"from_source": "console",
"from_end_user_id": None,
"from_account_id": account_id,
"message_tokens": 1,
@ -488,14 +593,34 @@ def test_agent_observability_routes_resolve_app_from_agent_id(
"updated_at": 2,
}
],
"page": 2,
"limit": 5,
"total": 6,
"page": 1,
"limit": 20,
"total": 1,
"has_more": False,
}
def get_statistics_summary(self, *, app, params):
captured["statistics"] = {"app": app, "params": params}
def list_log_sources(self, *, app, agent_id):
captured["sources"] = {"app": app, "agent_id": agent_id}
return {
"data": [
{
"id": "webapp:app-1",
"type": "webapp",
"app_id": "app-1",
"app_name": "Iris",
"app_icon_type": "emoji",
"app_icon": "robot",
"app_icon_background": "#fff",
"workflow_id": None,
"workflow_version": None,
"node_id": None,
}
],
"groups": [{"type": "webapp", "label": "WEBAPP", "sources": []}],
}
def get_statistics_summary(self, *, app, agent_id, params):
captured["statistics"] = {"app": app, "agent_id": agent_id, "params": params}
return {
"source": "all",
"summary": {
@ -532,9 +657,11 @@ def test_agent_observability_routes_resolve_app_from_agent_id(
):
logs = unwrap(AgentLogsApi.get)(AgentLogsApi(), "tenant-1", account, agent_id)
assert logs["data"][0]["id"] == "message-1"
assert logs["data"][0]["id"] == "conversation-1"
assert logs["data"][0]["source"]["id"] == "webapp:app-1"
logs_call = cast(dict[str, object], captured["logs"])
assert logs_call["app"] is app_model
assert logs_call["agent_id"] == agent_id
logs_params = cast(Any, logs_call["params"])
assert logs_params.page == 2
assert logs_params.limit == 5
@ -542,6 +669,31 @@ def test_agent_observability_routes_resolve_app_from_agent_id(
assert logs_params.status == "success"
assert logs_params.source == "console"
with app.test_request_context(
"/console/api/agent/00000000-0000-0000-0000-000000000001/logs/00000000-0000-0000-0000-000000000002/messages"
):
messages = unwrap(AgentLogMessagesApi.get)(
AgentLogMessagesApi(),
"tenant-1",
account,
agent_id,
"00000000-0000-0000-0000-000000000002",
)
assert messages["data"][0]["id"] == "message-1"
messages_call = cast(dict[str, object], captured["messages"])
assert messages_call["app"] is app_model
assert messages_call["agent_id"] == agent_id
assert messages_call["conversation_id"] == "00000000-0000-0000-0000-000000000002"
with app.test_request_context("/console/api/agent/00000000-0000-0000-0000-000000000001/log-sources"):
sources = unwrap(AgentLogSourcesApi.get)(AgentLogSourcesApi(), "tenant-1", account, agent_id)
assert sources["data"][0]["id"] == "webapp:app-1"
sources_call = cast(dict[str, object], captured["sources"])
assert sources_call["app"] is app_model
assert sources_call["agent_id"] == agent_id
with app.test_request_context(
"/console/api/agent/00000000-0000-0000-0000-000000000001/statistics/summary?source=api"
):
@ -550,6 +702,7 @@ def test_agent_observability_routes_resolve_app_from_agent_id(
assert statistics["summary"]["total_messages"] == 1
stats_call = cast(dict[str, object], captured["statistics"])
assert stats_call["app"] is app_model
assert stats_call["agent_id"] == agent_id
stats_params = cast(Any, stats_call["params"])
assert stats_params.source == "api"
assert stats_params.timezone == "UTC"

Some files were not shown because too many files have changed in this diff Show More