mirror of
https://github.com/langgenius/dify.git
synced 2026-06-26 14:51:13 +08:00
Merge branch 'main' into refactor/clean-unnecessary-none-batch2
This commit is contained in:
commit
96d1fa2917
21
.github/workflows/build-push.yml
vendored
21
.github/workflows/build-push.yml
vendored
@ -21,6 +21,7 @@ env:
|
||||
DOCKERHUB_TOKEN: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
DIFY_WEB_IMAGE_NAME: ${{ vars.DIFY_WEB_IMAGE_NAME || 'langgenius/dify-web' }}
|
||||
DIFY_API_IMAGE_NAME: ${{ vars.DIFY_API_IMAGE_NAME || 'langgenius/dify-api' }}
|
||||
DIFY_AGENT_IMAGE_NAME: ${{ vars.DIFY_AGENT_IMAGE_NAME || 'langgenius/dify-agent-backend' }}
|
||||
|
||||
jobs:
|
||||
build:
|
||||
@ -60,6 +61,20 @@ jobs:
|
||||
file: "web/Dockerfile"
|
||||
platform: linux/arm64
|
||||
runs_on: depot-ubuntu-24.04-4
|
||||
- service_name: "build-agent-amd64"
|
||||
image_name_env: "DIFY_AGENT_IMAGE_NAME"
|
||||
artifact_context: "agent"
|
||||
build_context: "{{defaultContext}}"
|
||||
file: "dify-agent/Dockerfile"
|
||||
platform: linux/amd64
|
||||
runs_on: depot-ubuntu-24.04-4
|
||||
- service_name: "build-agent-arm64"
|
||||
image_name_env: "DIFY_AGENT_IMAGE_NAME"
|
||||
artifact_context: "agent"
|
||||
build_context: "{{defaultContext}}"
|
||||
file: "dify-agent/Dockerfile"
|
||||
platform: linux/arm64
|
||||
runs_on: depot-ubuntu-24.04-4
|
||||
|
||||
steps:
|
||||
- name: Prepare
|
||||
@ -122,6 +137,9 @@ jobs:
|
||||
- service_name: "validate-web-amd64"
|
||||
build_context: "{{defaultContext}}"
|
||||
file: "web/Dockerfile"
|
||||
- service_name: "validate-agent-amd64"
|
||||
build_context: "{{defaultContext}}"
|
||||
file: "dify-agent/Dockerfile"
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
|
||||
@ -147,6 +165,9 @@ jobs:
|
||||
- service_name: "merge-web-images"
|
||||
image_name_env: "DIFY_WEB_IMAGE_NAME"
|
||||
context: "web"
|
||||
- service_name: "merge-agent-images"
|
||||
image_name_env: "DIFY_AGENT_IMAGE_NAME"
|
||||
context: "agent"
|
||||
steps:
|
||||
- name: Download digests
|
||||
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
from typing import Any, Literal
|
||||
from copy import deepcopy
|
||||
from typing import Any, Literal, override
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from pydantic import BaseModel, Field, GetJsonSchemaHandler, model_validator
|
||||
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
|
||||
@ -12,6 +13,45 @@ class ConversationRenamePayload(BaseModel):
|
||||
name: str | None = None
|
||||
auto_generate: bool = False
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def __get_pydantic_json_schema__(cls, core_schema: Any, handler: GetJsonSchemaHandler) -> dict[str, Any]:
|
||||
schema = handler.resolve_ref_schema(handler(core_schema))
|
||||
properties = schema.get("properties")
|
||||
if not isinstance(properties, dict):
|
||||
return schema
|
||||
|
||||
auto_generate_schema = deepcopy(properties.get("auto_generate", {"type": "boolean"}))
|
||||
name_schema = deepcopy(properties.get("name", {"type": "string"}))
|
||||
non_blank_name_schema: dict[str, Any] = {"pattern": r".*\S.*", "type": "string"}
|
||||
if isinstance(name_schema, dict) and isinstance(name_schema.get("title"), str):
|
||||
non_blank_name_schema["title"] = name_schema["title"]
|
||||
|
||||
auto_generate_true_schema = {**auto_generate_schema, "enum": [True]}
|
||||
auto_generate_true_schema.pop("default", None)
|
||||
|
||||
return {
|
||||
**schema,
|
||||
"anyOf": [
|
||||
{
|
||||
"properties": {
|
||||
"auto_generate": auto_generate_true_schema,
|
||||
"name": name_schema,
|
||||
},
|
||||
"required": ["auto_generate"],
|
||||
"type": "object",
|
||||
},
|
||||
{
|
||||
"properties": {
|
||||
"auto_generate": {**auto_generate_schema, "enum": [False]},
|
||||
"name": non_blank_name_schema,
|
||||
},
|
||||
"required": ["name"],
|
||||
"type": "object",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_name_requirement(self):
|
||||
if not self.auto_generate:
|
||||
@ -101,4 +141,7 @@ class TextToAudioPayload(BaseModel):
|
||||
message_id: str | None = Field(default=None, description="Message ID")
|
||||
voice: str | None = Field(default=None, description="Voice to use for TTS")
|
||||
text: str | None = Field(default=None, description="Text to convert to audio")
|
||||
streaming: bool | None = Field(default=None, description="Enable streaming response")
|
||||
streaming: bool | None = Field(
|
||||
default=None,
|
||||
description="Reserved for compatibility; TTS response streaming is determined by the provider output.",
|
||||
)
|
||||
|
||||
@ -42,10 +42,11 @@ def stringify_form_default_values(values: dict[str, object]) -> dict[str, str]:
|
||||
"""Serialize default values into strings expected by human-input form clients."""
|
||||
result: dict[str, str] = {}
|
||||
for key, value in values.items():
|
||||
if value is None:
|
||||
result[key] = ""
|
||||
elif isinstance(value, (dict, list)):
|
||||
result[key] = json.dumps(value, ensure_ascii=False)
|
||||
else:
|
||||
result[key] = str(value)
|
||||
match value:
|
||||
case None:
|
||||
result[key] = ""
|
||||
case dict() | list():
|
||||
result[key] = json.dumps(value, ensure_ascii=False)
|
||||
case _:
|
||||
result[key] = str(value)
|
||||
return result
|
||||
|
||||
@ -2,20 +2,28 @@ from uuid import UUID
|
||||
|
||||
from flask import abort, request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import AliasChoices, BaseModel, Field, field_validator
|
||||
|
||||
from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.agent.app_helpers import resolve_agent_app_model
|
||||
from controllers.console.app.app import (
|
||||
AppDetailWithSite,
|
||||
AppDetailWithSite as GenericAppDetailWithSite,
|
||||
)
|
||||
from controllers.console.app.app import (
|
||||
AppListQuery,
|
||||
AppPagination,
|
||||
AppPartial,
|
||||
CopyAppPayload,
|
||||
UpdateAppPayload,
|
||||
_normalize_app_list_query_args,
|
||||
)
|
||||
from controllers.console.app.app import (
|
||||
AppPagination as GenericAppPagination,
|
||||
)
|
||||
from controllers.console.app.app import (
|
||||
AppPartial as GenericAppPartial,
|
||||
)
|
||||
from controllers.console.app.app import (
|
||||
UpdateAppPayload as GenericUpdateAppPayload,
|
||||
)
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_resource_check,
|
||||
@ -31,6 +39,8 @@ from fields.agent_fields import (
|
||||
AgentConfigSnapshotListResponse,
|
||||
AgentInviteOptionsResponse,
|
||||
AgentLogListResponse,
|
||||
AgentLogMessageListResponse,
|
||||
AgentLogSourceListResponse,
|
||||
AgentPublishedReferenceResponse,
|
||||
AgentRosterListResponse,
|
||||
AgentStatisticSummaryEnvelopeResponse,
|
||||
@ -64,14 +74,33 @@ class AgentIdPath(BaseModel):
|
||||
class AgentAppCreatePayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, description="Agent name")
|
||||
description: str | None = Field(default=None, description="Agent description (max 400 chars)", max_length=400)
|
||||
role: str = Field(default="", description="Agent role", max_length=255)
|
||||
role: str = Field(..., min_length=1, description="Agent role", max_length=255)
|
||||
icon_type: IconType | None = Field(default=None, description="Icon type")
|
||||
icon: str | None = Field(default=None, description="Icon")
|
||||
icon_background: str | None = Field(default=None, description="Icon background color")
|
||||
|
||||
@field_validator("role")
|
||||
@classmethod
|
||||
def validate_role(cls, value: str) -> str:
|
||||
role = value.strip()
|
||||
if not role:
|
||||
raise ValueError("Agent role is required.")
|
||||
return role
|
||||
|
||||
class AgentAppUpdatePayload(UpdateAppPayload):
|
||||
role: str | None = Field(default=None, description="Agent role", max_length=255)
|
||||
|
||||
# Keep agent-app roster DTOs agent-specific instead of reusing the shared
|
||||
# /apps response/request models. The roster surface needs Agent-only fields such
|
||||
# as `role`, while the generic console/apps contracts must stay unchanged.
|
||||
class AgentAppUpdatePayload(GenericUpdateAppPayload):
|
||||
role: str = Field(..., min_length=1, description="Agent role", max_length=255)
|
||||
|
||||
@field_validator("role")
|
||||
@classmethod
|
||||
def validate_role(cls, value: str) -> str:
|
||||
role = value.strip()
|
||||
if not role:
|
||||
raise ValueError("Agent role is required.")
|
||||
return role
|
||||
|
||||
|
||||
class AgentAppPublishedReferenceResponse(BaseModel):
|
||||
@ -82,19 +111,6 @@ class AgentAppPublishedReferenceResponse(BaseModel):
|
||||
app_icon_background: str | None = None
|
||||
|
||||
|
||||
class AgentAppPartial(AppPartial):
|
||||
published_reference_count: int = 0
|
||||
published_references: list[AgentAppPublishedReferenceResponse] = Field(default_factory=list)
|
||||
|
||||
|
||||
class AgentAppPagination(BaseModel):
|
||||
page: int
|
||||
limit: int
|
||||
total: int
|
||||
has_more: bool
|
||||
data: list[AgentAppPartial]
|
||||
|
||||
|
||||
class AgentLogsQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1, description="Page number")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Page size")
|
||||
@ -131,6 +147,26 @@ class AgentStatisticsQuery(BaseModel):
|
||||
return value
|
||||
|
||||
|
||||
class AgentAppPartial(GenericAppPartial):
|
||||
app_id: str | None = None
|
||||
role: str | None = None
|
||||
active_config_is_published: bool = False
|
||||
published_reference_count: int = 0
|
||||
published_references: list[AgentAppPublishedReferenceResponse] = Field(default_factory=list)
|
||||
|
||||
|
||||
class AgentAppDetailWithSite(GenericAppDetailWithSite):
|
||||
app_id: str | None = None
|
||||
role: str | None = None
|
||||
active_config_is_published: bool = False
|
||||
|
||||
|
||||
class AgentAppPagination(GenericAppPagination):
|
||||
data: list[AgentAppPartial] = Field( # type: ignore[assignment] # pyrefly: ignore[bad-override-mutable-attribute]
|
||||
validation_alias=AliasChoices("items", "data")
|
||||
)
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
AgentAppCreatePayload,
|
||||
@ -141,18 +177,20 @@ register_schema_models(
|
||||
AgentStatisticsQuery,
|
||||
AgentIdPath,
|
||||
AppListQuery,
|
||||
UpdateAppPayload,
|
||||
RosterListQuery,
|
||||
)
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
AppDetailWithSite,
|
||||
AgentAppPagination,
|
||||
AgentAppPublishedReferenceResponse,
|
||||
AgentAppDetailWithSite,
|
||||
AgentAppPartial,
|
||||
AgentConfigSnapshotDetailResponse,
|
||||
AgentConfigSnapshotListResponse,
|
||||
AgentInviteOptionsResponse,
|
||||
AgentLogListResponse,
|
||||
AgentLogMessageListResponse,
|
||||
AgentLogSourceListResponse,
|
||||
AgentPublishedReferenceResponse,
|
||||
AgentRosterListResponse,
|
||||
AgentStatisticSummaryEnvelopeResponse,
|
||||
@ -164,16 +202,25 @@ def _agent_roster_service() -> AgentRosterService:
|
||||
|
||||
|
||||
def _serialize_agent_app_detail(app_model) -> dict:
|
||||
"""Serialize an Agent App detail using roster-only DTOs.
|
||||
|
||||
`/agent` responses are roster-shaped rather than raw app-shaped: `id`
|
||||
becomes the backing roster Agent id, `app_id` carries the underlying App
|
||||
id, and `role` is injected from the backing roster Agent. Keeping that
|
||||
remap in this serializer lets generated console/agent contracts expose the
|
||||
roster persona fields without widening the shared /apps detail schema.
|
||||
"""
|
||||
|
||||
app_model = AppService().get_app(app_model)
|
||||
if FeatureService.get_system_features().webapp_auth.enabled:
|
||||
app_setting = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=str(app_model.id))
|
||||
app_model.access_mode = app_setting.access_mode # type: ignore[attr-defined]
|
||||
|
||||
roster_service = _agent_roster_service()
|
||||
agent = roster_service.get_app_backing_agent(tenant_id=app_model.tenant_id, app_id=app_model.id)
|
||||
payload = AgentAppDetailWithSite.model_validate(app_model, from_attributes=True).model_dump(mode="json")
|
||||
agent = roster_service.get_app_backing_agent(tenant_id=app_model.tenant_id, app_id=str(app_model.id))
|
||||
if not agent:
|
||||
raise AgentNotFoundError()
|
||||
payload = AppDetailWithSite.model_validate(app_model, from_attributes=True).model_dump(mode="json")
|
||||
payload.pop("bound_agent_id", None)
|
||||
payload["app_id"] = str(app_model.id)
|
||||
payload["id"] = agent.id
|
||||
@ -186,6 +233,14 @@ def _serialize_agent_app_detail(app_model) -> dict:
|
||||
|
||||
|
||||
def _serialize_agent_app_pagination(app_pagination, *, tenant_id: str) -> dict:
|
||||
"""Serialize Agent App lists with roster-shaped items.
|
||||
|
||||
Each item starts from the shared App list shape, then drops
|
||||
`bound_agent_id`, rewrites `id` to the backing roster Agent id, stores the
|
||||
original App id in `app_id`, and injects roster-only `role` when a backing
|
||||
Agent is present.
|
||||
"""
|
||||
|
||||
app_ids = [str(app.id) for app in app_pagination.items]
|
||||
roster_service = _agent_roster_service()
|
||||
agents_by_app_id = roster_service.load_app_backing_agents_by_app_id(
|
||||
@ -200,7 +255,7 @@ def _serialize_agent_app_pagination(app_pagination, *, tenant_id: str) -> dict:
|
||||
tenant_id=tenant_id,
|
||||
agent_ids=[agent.id for agent in agents_by_app_id.values()],
|
||||
)
|
||||
payload = AppPagination.model_validate(app_pagination, from_attributes=True).model_dump(mode="json")
|
||||
payload = AgentAppPagination.model_validate(app_pagination, from_attributes=True).model_dump(mode="json")
|
||||
for item in payload["data"]:
|
||||
app_id = item["id"]
|
||||
item.pop("bound_agent_id", None)
|
||||
@ -266,7 +321,7 @@ class AgentAppListApi(Resource):
|
||||
status="normal",
|
||||
)
|
||||
|
||||
app_pagination = AppService().get_paginate_apps(current_user.id, current_tenant_id, params)
|
||||
app_pagination = AppService().get_paginate_apps(current_user.id, current_tenant_id, params, db.session)
|
||||
if app_pagination is None:
|
||||
empty = AgentAppPagination(page=args.page, limit=args.limit, total=0, has_more=False, data=[])
|
||||
return empty.model_dump(mode="json")
|
||||
@ -274,7 +329,7 @@ class AgentAppListApi(Resource):
|
||||
return _serialize_agent_app_pagination(app_pagination, tenant_id=current_tenant_id)
|
||||
|
||||
@console_ns.expect(console_ns.models[AgentAppCreatePayload.__name__])
|
||||
@console_ns.response(201, "Agent app created successfully", console_ns.models[AppDetailWithSite.__name__])
|
||||
@console_ns.response(201, "Agent app created successfully", console_ns.models[AgentAppDetailWithSite.__name__])
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@setup_required
|
||||
@ -302,7 +357,7 @@ class AgentAppListApi(Resource):
|
||||
|
||||
@console_ns.route("/agent/<uuid:agent_id>")
|
||||
class AgentAppApi(Resource):
|
||||
@console_ns.response(200, "Agent app detail", console_ns.models[AppDetailWithSite.__name__])
|
||||
@console_ns.response(200, "Agent app detail", console_ns.models[AgentAppDetailWithSite.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -313,7 +368,7 @@ class AgentAppApi(Resource):
|
||||
return _serialize_agent_app_detail(app_model)
|
||||
|
||||
@console_ns.expect(console_ns.models[AgentAppUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Agent app updated successfully", console_ns.models[AppDetailWithSite.__name__])
|
||||
@console_ns.response(200, "Agent app updated successfully", console_ns.models[AgentAppDetailWithSite.__name__])
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@setup_required
|
||||
@ -353,7 +408,7 @@ class AgentAppApi(Resource):
|
||||
@console_ns.route("/agent/<uuid:agent_id>/copy")
|
||||
class AgentAppCopyApi(Resource):
|
||||
@console_ns.expect(console_ns.models[CopyAppPayload.__name__])
|
||||
@console_ns.response(201, "Agent app copied successfully", console_ns.models[AppDetailWithSite.__name__])
|
||||
@console_ns.response(201, "Agent app copied successfully", console_ns.models[AgentAppDetailWithSite.__name__])
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@setup_required
|
||||
@ -416,6 +471,7 @@ class AgentLogsApi(Resource):
|
||||
try:
|
||||
payload = _agent_observability_service().list_logs(
|
||||
app=app_model,
|
||||
agent_id=str(agent_id),
|
||||
params=AgentLogQueryParams(
|
||||
page=query.page,
|
||||
limit=query.limit,
|
||||
@ -431,6 +487,53 @@ class AgentLogsApi(Resource):
|
||||
return dump_response(AgentLogListResponse, payload)
|
||||
|
||||
|
||||
@console_ns.route("/agent/<uuid:agent_id>/logs/<uuid:conversation_id>/messages")
|
||||
class AgentLogMessagesApi(Resource):
|
||||
@console_ns.doc(params=query_params_from_model(AgentLogsQuery))
|
||||
@console_ns.response(200, "Agent log messages", console_ns.models[AgentLogMessageListResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, current_user: Account, agent_id: UUID, conversation_id: UUID):
|
||||
app_model = _resolve_agent_app_model(tenant_id=tenant_id, agent_id=agent_id)
|
||||
query = AgentLogsQuery.model_validate(request.args.to_dict(flat=True))
|
||||
start, end = _parse_observability_time_range(query.start, query.end, current_user)
|
||||
try:
|
||||
payload = _agent_observability_service().list_log_messages(
|
||||
app=app_model,
|
||||
agent_id=str(agent_id),
|
||||
conversation_id=str(conversation_id),
|
||||
params=AgentLogQueryParams(
|
||||
page=query.page,
|
||||
limit=query.limit,
|
||||
keyword=query.keyword,
|
||||
status=query.status,
|
||||
source=query.source,
|
||||
start=start,
|
||||
end=end,
|
||||
),
|
||||
)
|
||||
except ValueError as exc:
|
||||
abort(400, description=str(exc))
|
||||
return dump_response(AgentLogMessageListResponse, payload)
|
||||
|
||||
|
||||
@console_ns.route("/agent/<uuid:agent_id>/log-sources")
|
||||
class AgentLogSourcesApi(Resource):
|
||||
@console_ns.response(200, "Agent log sources", console_ns.models[AgentLogSourceListResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, current_user: Account, agent_id: UUID):
|
||||
app_model = _resolve_agent_app_model(tenant_id=tenant_id, agent_id=agent_id)
|
||||
payload = _agent_observability_service().list_log_sources(app=app_model, agent_id=str(agent_id))
|
||||
return dump_response(AgentLogSourceListResponse, payload)
|
||||
|
||||
|
||||
@console_ns.route("/agent/<uuid:agent_id>/statistics/summary")
|
||||
class AgentStatisticsSummaryApi(Resource):
|
||||
@console_ns.doc(params=query_params_from_model(AgentStatisticsQuery))
|
||||
@ -452,6 +555,7 @@ class AgentStatisticsSummaryApi(Resource):
|
||||
try:
|
||||
payload = _agent_observability_service().get_statistics_summary(
|
||||
app=app_model,
|
||||
agent_id=str(agent_id),
|
||||
params=AgentStatisticsQueryParams(source=query.source, start=start, end=end, timezone=timezone),
|
||||
)
|
||||
except ValueError as exc:
|
||||
|
||||
@ -30,7 +30,7 @@ from models import Account
|
||||
from models.agent_config_entities import AgentFileRefConfig, AgentSkillRefConfig
|
||||
from models.model import App, AppMode, UploadFile
|
||||
from services.agent.composer_service import AgentComposerService
|
||||
from services.agent.skill_package_service import SkillManifest, SkillPackageError, SkillPackageService
|
||||
from services.agent.skill_package_service import SkillManifest, SkillPackageError
|
||||
from services.agent.skill_standardize_service import SkillStandardizeService
|
||||
from services.agent.skill_tool_inference_service import (
|
||||
SkillToolInferenceError,
|
||||
@ -45,11 +45,18 @@ from services.agent_drive_service import (
|
||||
normalize_drive_key,
|
||||
)
|
||||
from services.agent_service import AgentService
|
||||
from services.file_service import FileService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_WORKFLOW_AGENT_DRIVE_APP_MODES = [AppMode.WORKFLOW, AppMode.ADVANCED_CHAT]
|
||||
_AGENT_SKILL_UPLOAD_PARAMS = {
|
||||
"file": {
|
||||
"in": "formData",
|
||||
"type": "file",
|
||||
"required": True,
|
||||
"description": "Skill package (.zip or .skill).",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class AgentLogQuery(BaseModel):
|
||||
@ -125,11 +132,6 @@ class AgentSkillUploadResponse(ResponseModel):
|
||||
manifest: SkillManifest
|
||||
|
||||
|
||||
class AgentSkillStandardizeResponse(ResponseModel):
|
||||
skill: AgentSkillRefConfig
|
||||
manifest: SkillManifest
|
||||
|
||||
|
||||
class AgentDriveFileResponse(ResponseModel):
|
||||
name: str
|
||||
drive_key: str
|
||||
@ -156,7 +158,6 @@ register_response_schema_models(
|
||||
AgentDriveFileCommitResponse,
|
||||
AgentDriveFileResponse,
|
||||
AgentLogResponse,
|
||||
AgentSkillStandardizeResponse,
|
||||
AgentSkillUploadResponse,
|
||||
SkillToolInferenceResult,
|
||||
)
|
||||
@ -174,30 +175,9 @@ def _agent_not_bound() -> tuple[dict[str, str], int]:
|
||||
return {"code": "agent_not_bound", "message": "no agent is bound for this app/node"}, 400
|
||||
|
||||
|
||||
def _upload_skill_for_app(*, current_user: Account):
|
||||
if "file" not in request.files:
|
||||
return {"code": "no_file", "message": "no skill file uploaded"}, 400
|
||||
if len(request.files) > 1:
|
||||
return {"code": "too_many_files", "message": "only one skill file is allowed"}, 400
|
||||
def _upload_skill_for_app(*, current_user: Account, app_model: App):
|
||||
"""Upload one skill package and commit its normalized files into the agent drive."""
|
||||
|
||||
upload = request.files["file"]
|
||||
content = upload.stream.read()
|
||||
try:
|
||||
manifest = SkillPackageService().validate_and_extract(content=content, filename=upload.filename or "")
|
||||
except SkillPackageError as exc:
|
||||
return {"code": exc.code, "message": exc.message}, exc.status_code
|
||||
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=upload.filename or "skill.zip",
|
||||
content=content,
|
||||
mimetype=upload.mimetype or "application/zip",
|
||||
user=current_user,
|
||||
)
|
||||
skill_ref = manifest.to_skill_ref(file_id=upload_file.id)
|
||||
return {"skill": skill_ref.model_dump(exclude_none=True), "manifest": manifest.model_dump()}, 201
|
||||
|
||||
|
||||
def _standardize_skill_for_app(*, current_user: Account, app_model: App):
|
||||
query = query_params_from_request(AgentDriveMutationQuery)
|
||||
agent_id = _resolve_agent_id(app_model, query.node_id)
|
||||
if not agent_id:
|
||||
@ -382,51 +362,9 @@ class AgentLogApi(Resource):
|
||||
@console_ns.route("/agent/<uuid:agent_id>/skills/upload")
|
||||
class AgentSkillUploadByAgentApi(Resource):
|
||||
@console_ns.doc("upload_agent_skill_by_agent")
|
||||
@console_ns.doc(description="Upload + validate a Skill package for an Agent App")
|
||||
@console_ns.doc(params={"agent_id": "Agent ID"})
|
||||
@console_ns.response(201, "Skill validated", console_ns.models[AgentSkillUploadResponse.__name__])
|
||||
@console_ns.response(400, "Invalid skill package")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, current_user: Account, agent_id: UUID):
|
||||
resolve_agent_app_model(tenant_id=tenant_id, agent_id=agent_id)
|
||||
return _upload_skill_for_app(current_user=current_user)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/agent/skills/upload")
|
||||
class AgentSkillUploadApi(Resource):
|
||||
@console_ns.doc("upload_agent_skill")
|
||||
@console_ns.doc(description="Upload + validate a Skill package (.zip/.skill) and extract its manifest")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(201, "Skill validated", console_ns.models[AgentSkillUploadResponse.__name__])
|
||||
@console_ns.response(400, "Invalid skill package")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=_WORKFLOW_AGENT_DRIVE_APP_MODES)
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
"""Validate an uploaded Skill package and persist the archive.
|
||||
|
||||
Returns a validated skill ref (to bind into the Agent soul config on save)
|
||||
plus its manifest. Standardizing into the agent drive is ENG-594.
|
||||
"""
|
||||
return _upload_skill_for_app(current_user=current_user)
|
||||
|
||||
|
||||
@console_ns.route("/agent/<uuid:agent_id>/skills/standardize")
|
||||
class AgentSkillStandardizeByAgentApi(Resource):
|
||||
@console_ns.doc("standardize_agent_skill_by_agent")
|
||||
@console_ns.doc(description="Validate + standardize a Skill into an Agent App drive")
|
||||
@console_ns.doc(params={"agent_id": "Agent ID"})
|
||||
@console_ns.response(
|
||||
201,
|
||||
"Skill standardized into drive",
|
||||
console_ns.models[AgentSkillStandardizeResponse.__name__],
|
||||
)
|
||||
@console_ns.doc(description="Upload + standardize a Skill into an Agent App drive")
|
||||
@console_ns.doc(consumes=["multipart/form-data"], params={"agent_id": "Agent ID", **_AGENT_SKILL_UPLOAD_PARAMS})
|
||||
@console_ns.response(201, "Skill uploaded into drive", console_ns.models[AgentSkillUploadResponse.__name__])
|
||||
@console_ns.response(400, "Invalid skill package or no bound agent")
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -435,19 +373,22 @@ class AgentSkillStandardizeByAgentApi(Resource):
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, current_user: Account, agent_id: UUID):
|
||||
app_model = resolve_agent_app_model(tenant_id=tenant_id, agent_id=agent_id)
|
||||
return _standardize_skill_for_app(current_user=current_user, app_model=app_model)
|
||||
return _upload_skill_for_app(current_user=current_user, app_model=app_model)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/agent/skills/standardize")
|
||||
class AgentSkillStandardizeApi(Resource):
|
||||
@console_ns.doc("standardize_agent_skill")
|
||||
@console_ns.doc(description="Validate + standardize a Skill into the agent drive (ENG-594)")
|
||||
@console_ns.doc(params={"app_id": "Application ID", **query_params_from_model(AgentDriveMutationQuery)})
|
||||
@console_ns.response(
|
||||
201,
|
||||
"Skill standardized into drive",
|
||||
console_ns.models[AgentSkillStandardizeResponse.__name__],
|
||||
@console_ns.route("/apps/<uuid:app_id>/agent/skills/upload")
|
||||
class AgentSkillUploadApi(Resource):
|
||||
@console_ns.doc("upload_agent_skill")
|
||||
@console_ns.doc(description="Upload + standardize a Skill into the agent drive")
|
||||
@console_ns.doc(
|
||||
consumes=["multipart/form-data"],
|
||||
params={
|
||||
"app_id": "Application ID",
|
||||
**query_params_from_model(AgentDriveMutationQuery),
|
||||
**_AGENT_SKILL_UPLOAD_PARAMS,
|
||||
},
|
||||
)
|
||||
@console_ns.response(201, "Skill uploaded into drive", console_ns.models[AgentSkillUploadResponse.__name__])
|
||||
@console_ns.response(400, "Invalid skill package or no bound agent")
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -455,8 +396,8 @@ class AgentSkillStandardizeApi(Resource):
|
||||
@get_app_model(mode=_WORKFLOW_AGENT_DRIVE_APP_MODES)
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
"""Upload a Skill, validate it, and standardize it into the app agent's drive."""
|
||||
return _standardize_skill_for_app(current_user=current_user, app_model=app_model)
|
||||
"""Upload a Skill, validate it, and commit drive-backed skill files."""
|
||||
return _upload_skill_for_app(current_user=current_user, app_model=app_model)
|
||||
|
||||
|
||||
@console_ns.route("/agent/<uuid:agent_id>/files")
|
||||
|
||||
@ -402,8 +402,6 @@ class AppPartial(ResponseModel):
|
||||
bound_agent_id: str | None = None
|
||||
# For Agent App responses exposed through /agent.
|
||||
app_id: str | None = None
|
||||
role: str | None = None
|
||||
active_config_is_published: bool = False
|
||||
is_starred: bool = False
|
||||
|
||||
@computed_field(return_type=str | None) # type: ignore
|
||||
@ -457,8 +455,6 @@ class AppDetailWithSite(AppDetail):
|
||||
bound_agent_id: str | None = None
|
||||
# For Agent App responses exposed through /agent.
|
||||
app_id: str | None = None
|
||||
role: str | None = None
|
||||
active_config_is_published: bool = False
|
||||
|
||||
@computed_field(return_type=str | None) # type: ignore
|
||||
@property
|
||||
@ -541,10 +537,7 @@ register_schema_models(
|
||||
ModelConfig,
|
||||
Site,
|
||||
DeletedTool,
|
||||
AppPartial,
|
||||
AppDetail,
|
||||
AppDetailWithSite,
|
||||
AppPagination,
|
||||
AppExportResponse,
|
||||
Segmentation,
|
||||
PreProcessingRule,
|
||||
@ -564,6 +557,13 @@ register_schema_models(
|
||||
LoadBalancingPayload,
|
||||
)
|
||||
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
AppPartial,
|
||||
AppDetailWithSite,
|
||||
AppPagination,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps")
|
||||
class AppListApi(Resource):
|
||||
@ -594,7 +594,7 @@ class AppListApi(Resource):
|
||||
|
||||
# get app list
|
||||
app_service = AppService()
|
||||
app_pagination = app_service.get_paginate_apps(current_user_id, current_tenant_id, params)
|
||||
app_pagination = app_service.get_paginate_apps(current_user_id, current_tenant_id, params, db.session)
|
||||
if not app_pagination:
|
||||
empty = AppPagination(page=args.page, limit=args.limit, total=0, has_more=False, data=[])
|
||||
return empty.model_dump(mode="json"), 200
|
||||
@ -661,7 +661,7 @@ class StarredAppListApi(Resource):
|
||||
is_created_by_me=args.is_created_by_me,
|
||||
)
|
||||
|
||||
app_pagination = AppService().get_paginate_starred_apps(current_user_id, current_tenant_id, params)
|
||||
app_pagination = AppService().get_paginate_starred_apps(current_user_id, current_tenant_id, params, db.session)
|
||||
if not app_pagination:
|
||||
empty = AppPagination(page=args.page, limit=args.limit, total=0, has_more=False, data=[])
|
||||
return empty.model_dump(mode="json"), 200
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
|
||||
from configs import dify_config
|
||||
from constants.languages import supported_language
|
||||
@ -11,7 +12,8 @@ from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import EmailStr, timezone
|
||||
from models import AccountStatus
|
||||
from services.account_service import RegisterService
|
||||
from models.account import TenantAccountJoin, TenantAccountRole
|
||||
from services.account_service import RegisterService, TenantService
|
||||
from services.billing_service import BillingService
|
||||
|
||||
|
||||
@ -25,18 +27,22 @@ class ActivatePayload(BaseModel):
|
||||
workspace_id: str | None = Field(default=None)
|
||||
email: EmailStr | None = Field(default=None)
|
||||
token: str
|
||||
name: str = Field(..., max_length=30)
|
||||
interface_language: str = Field(...)
|
||||
timezone: str = Field(...)
|
||||
name: str | None = Field(default=None, max_length=30)
|
||||
interface_language: str | None = Field(default=None)
|
||||
timezone: str | None = Field(default=None)
|
||||
|
||||
@field_validator("interface_language")
|
||||
@classmethod
|
||||
def validate_lang(cls, value: str) -> str:
|
||||
def validate_lang(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
return supported_language(value)
|
||||
|
||||
@field_validator("timezone")
|
||||
@classmethod
|
||||
def validate_tz(cls, value: str) -> str:
|
||||
def validate_tz(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
return timezone(value)
|
||||
|
||||
|
||||
@ -48,6 +54,8 @@ class ActivationCheckData(BaseModel):
|
||||
workspace_name: str | None
|
||||
workspace_id: str | None
|
||||
email: str | None
|
||||
account_status: str | None = None
|
||||
requires_setup: bool | None = None
|
||||
|
||||
|
||||
class ActivationCheckResponse(BaseModel):
|
||||
@ -95,9 +103,20 @@ class ActivateCheckApi(Resource):
|
||||
workspace_name = tenant.name if tenant else None
|
||||
workspace_id = tenant.id if tenant else None
|
||||
invitee_email = data.get("email") if data else None
|
||||
account = invitation.get("account")
|
||||
account_status = account.status if account else None
|
||||
requires_setup = data.get("requires_setup")
|
||||
if requires_setup is None:
|
||||
requires_setup = account_status == AccountStatus.PENDING
|
||||
return {
|
||||
"is_valid": invitation is not None,
|
||||
"data": {"workspace_name": workspace_name, "workspace_id": workspace_id, "email": invitee_email},
|
||||
"data": {
|
||||
"workspace_name": workspace_name,
|
||||
"workspace_id": workspace_id,
|
||||
"email": invitee_email,
|
||||
"account_status": account_status,
|
||||
"requires_setup": requires_setup,
|
||||
},
|
||||
}
|
||||
else:
|
||||
return {"is_valid": False}
|
||||
@ -126,15 +145,45 @@ class ActivateApi(Resource):
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(account.email):
|
||||
raise AccountInFreezeError()
|
||||
|
||||
tenant = invitation["tenant"]
|
||||
raw_role = invitation["data"].get("role")
|
||||
try:
|
||||
role = TenantAccountRole(raw_role) if raw_role else TenantAccountRole.NORMAL
|
||||
except ValueError:
|
||||
role = TenantAccountRole.NORMAL
|
||||
if not TenantAccountRole.is_non_owner_role(role):
|
||||
role = TenantAccountRole.NORMAL
|
||||
|
||||
membership_id = db.session.scalar(
|
||||
select(TenantAccountJoin.id).where(
|
||||
TenantAccountJoin.tenant_id == tenant.id,
|
||||
TenantAccountJoin.account_id == account.id,
|
||||
)
|
||||
)
|
||||
|
||||
requires_setup = invitation["data"].get("requires_setup")
|
||||
if requires_setup is None:
|
||||
requires_setup = account.status == AccountStatus.PENDING
|
||||
|
||||
setup_fields: tuple[str, str, str] | None = None
|
||||
if requires_setup:
|
||||
if not args.name or not args.interface_language or not args.timezone:
|
||||
raise AlreadyActivateError()
|
||||
setup_fields = (args.name, args.interface_language, args.timezone)
|
||||
|
||||
RegisterService.revoke_token(args.workspace_id, normalized_request_email, args.token)
|
||||
|
||||
account.name = args.name
|
||||
if membership_id is None:
|
||||
TenantService.create_tenant_member(tenant, account, str(role))
|
||||
|
||||
account.interface_language = args.interface_language
|
||||
account.timezone = args.timezone
|
||||
account.interface_theme = "light"
|
||||
account.status = AccountStatus.ACTIVE
|
||||
account.initialized_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
if setup_fields:
|
||||
account.name = setup_fields[0]
|
||||
account.interface_language = setup_fields[1]
|
||||
account.timezone = setup_fields[2]
|
||||
account.interface_theme = "light"
|
||||
account.status = AccountStatus.ACTIVE
|
||||
account.initialized_at = naive_utc_now()
|
||||
|
||||
TenantService.switch_tenant(account, tenant.id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@ -409,6 +409,7 @@ class DatasetListApi(Resource):
|
||||
datasets, total = DatasetService.get_datasets(
|
||||
query.page,
|
||||
query.limit,
|
||||
db.session,
|
||||
current_tenant_id,
|
||||
current_user,
|
||||
query.keyword,
|
||||
|
||||
@ -122,7 +122,7 @@ class TagListApi(Resource):
|
||||
raise Forbidden()
|
||||
|
||||
payload = TagBasePayload.model_validate(console_ns.payload or {})
|
||||
tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=payload.type))
|
||||
tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=payload.type), db.session)
|
||||
|
||||
response = TagResponse.model_validate(
|
||||
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
|
||||
@ -146,9 +146,9 @@ class TagUpdateDeleteApi(Resource):
|
||||
raise Forbidden()
|
||||
|
||||
payload = TagUpdateRequestPayload.model_validate(console_ns.payload or {})
|
||||
tag = TagService.update_tags(UpdateTagPayload(name=payload.name), tag_id_str)
|
||||
tag = TagService.update_tags(UpdateTagPayload(name=payload.name), tag_id_str, db.session)
|
||||
|
||||
binding_count = TagService.get_tag_binding_count(tag_id_str)
|
||||
binding_count = TagService.get_tag_binding_count(tag_id_str, db.session)
|
||||
|
||||
response = TagResponse.model_validate(
|
||||
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
|
||||
@ -164,7 +164,7 @@ class TagUpdateDeleteApi(Resource):
|
||||
def delete(self, tag_id: UUID):
|
||||
tag_id_str = str(tag_id)
|
||||
|
||||
TagService.delete_tag(tag_id_str)
|
||||
TagService.delete_tag(tag_id_str, db.session)
|
||||
|
||||
return "", 204
|
||||
|
||||
@ -189,7 +189,8 @@ def _create_tag_bindings(current_user: Account) -> tuple[dict[str, str], int]:
|
||||
tag_ids=payload.tag_ids,
|
||||
target_id=payload.target_id,
|
||||
type=payload.type,
|
||||
)
|
||||
),
|
||||
db.session,
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
@ -203,7 +204,8 @@ def _remove_tag_bindings(current_user: Account) -> tuple[dict[str, str], int]:
|
||||
tag_ids=payload.tag_ids,
|
||||
target_id=payload.target_id,
|
||||
type=payload.type,
|
||||
)
|
||||
),
|
||||
db.session,
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@ -232,7 +232,11 @@ class MemberInviteEmailApi(Resource):
|
||||
)
|
||||
except AccountAlreadyInTenantError:
|
||||
invitation_results.append(
|
||||
{"status": "success", "email": invitee_email, "url": f"{console_web_url}/signin"}
|
||||
{
|
||||
"status": "already_member",
|
||||
"email": invitee_email,
|
||||
"message": "Account already in workspace.",
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
invitation_results.append({"status": "failed", "email": invitee_email, "message": str(e)})
|
||||
|
||||
@ -126,6 +126,7 @@ class CustomizedSnippetsApi(Resource):
|
||||
snippet_service = _snippet_service()
|
||||
snippets, total, has_more = snippet_service.get_snippets(
|
||||
tenant_id=current_tenant_id,
|
||||
session=db.session,
|
||||
page=query.page,
|
||||
limit=query.limit,
|
||||
keyword=query.keyword,
|
||||
|
||||
@ -174,7 +174,7 @@ class AppListApi(Resource):
|
||||
|
||||
tag_ids: list[str] | None = None
|
||||
if query.tag:
|
||||
tags = TagService.get_tag_by_tag_name("app", workspace_id, query.tag)
|
||||
tags = TagService.get_tag_by_tag_name("app", workspace_id, query.tag, db.session)
|
||||
if not tags:
|
||||
return empty
|
||||
tag_ids = [tag.id for tag in tags]
|
||||
@ -191,7 +191,7 @@ class AppListApi(Resource):
|
||||
openapi_visible=True,
|
||||
)
|
||||
|
||||
pagination = AppService().get_paginate_apps(str(auth_data.account_id), workspace_id, params)
|
||||
pagination = AppService().get_paginate_apps(str(auth_data.account_id), workspace_id, params, db.session)
|
||||
if pagination is None:
|
||||
return empty
|
||||
|
||||
|
||||
@ -45,6 +45,13 @@ class AnnotationJobStatusResponse(ResponseModel):
|
||||
error_msg: str | None = None
|
||||
|
||||
|
||||
ANNOTATION_REPLY_ACTION_PARAM = {
|
||||
"description": "Action to perform: 'enable' or 'disable'",
|
||||
"enum": ["enable", "disable"],
|
||||
"type": "string",
|
||||
}
|
||||
|
||||
|
||||
register_schema_models(
|
||||
service_api_ns,
|
||||
AnnotationCreatePayload,
|
||||
@ -58,10 +65,22 @@ register_response_schema_models(service_api_ns, AnnotationJobStatusResponse)
|
||||
|
||||
@service_api_ns.route("/apps/annotation-reply/<string:action>")
|
||||
class AnnotationReplyActionApi(Resource):
|
||||
@service_api_ns.doc(
|
||||
summary="Configure Annotation Reply",
|
||||
description=(
|
||||
"Enables or disables the annotation reply feature. Requires embedding model configuration "
|
||||
"when enabling. Executes asynchronously — use [Get Annotation Reply Job "
|
||||
"Status](/api-reference/annotations/get-annotation-reply-job-status) to track progress."
|
||||
),
|
||||
tags=["Annotations"],
|
||||
responses={
|
||||
200: "Annotation reply settings task initiated.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.expect(service_api_ns.models[AnnotationReplyActionPayload.__name__])
|
||||
@service_api_ns.doc("annotation_reply_action")
|
||||
@service_api_ns.doc(description="Enable or disable annotation reply feature")
|
||||
@service_api_ns.doc(params={"action": "Action to perform: 'enable' or 'disable'"})
|
||||
@service_api_ns.doc(params={"action": ANNOTATION_REPLY_ACTION_PARAM})
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
200: "Action completed successfully",
|
||||
@ -92,6 +111,18 @@ class AnnotationReplyActionApi(Resource):
|
||||
|
||||
@service_api_ns.route("/apps/annotation-reply/<string:action>/status/<uuid:job_id>")
|
||||
class AnnotationReplyActionStatusApi(Resource):
|
||||
@service_api_ns.doc(
|
||||
summary="Get Annotation Reply Job Status",
|
||||
description=(
|
||||
"Retrieves the status of an asynchronous annotation reply configuration job started by "
|
||||
"[Configure Annotation Reply](/api-reference/annotations/configure-annotation-reply)."
|
||||
),
|
||||
tags=["Annotations"],
|
||||
responses={
|
||||
200: "Successfully retrieved task status.",
|
||||
400: "`invalid_param` : The specified job does not exist.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.doc("get_annotation_reply_action_status")
|
||||
@service_api_ns.doc(description="Get the status of an annotation reply action job")
|
||||
@service_api_ns.doc(params={"action": "Action type", "job_id": "Job ID"})
|
||||
@ -127,6 +158,14 @@ class AnnotationReplyActionStatusApi(Resource):
|
||||
|
||||
@service_api_ns.route("/apps/annotations")
|
||||
class AnnotationListApi(Resource):
|
||||
@service_api_ns.doc(
|
||||
summary="List Annotations",
|
||||
description="Retrieves a paginated list of annotations for the application. Supports keyword search filtering.",
|
||||
tags=["Annotations"],
|
||||
responses={
|
||||
200: "Successfully retrieved annotation list.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.doc("list_annotations")
|
||||
@service_api_ns.doc(description="List annotations for the application")
|
||||
@service_api_ns.doc(params=query_params_from_model(AnnotationListQuery))
|
||||
@ -159,6 +198,17 @@ class AnnotationListApi(Resource):
|
||||
)
|
||||
return response.model_dump(mode="json")
|
||||
|
||||
@service_api_ns.doc(
|
||||
summary="Create Annotation",
|
||||
description=(
|
||||
"Creates a new annotation. Annotations provide predefined question-answer pairs that the app "
|
||||
"can match and return directly instead of generating a response."
|
||||
),
|
||||
tags=["Annotations"],
|
||||
responses={
|
||||
201: "Annotation created successfully.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.expect(service_api_ns.models[AnnotationCreatePayload.__name__])
|
||||
@service_api_ns.doc("create_annotation")
|
||||
@service_api_ns.doc(description="Create a new annotation")
|
||||
@ -185,6 +235,16 @@ class AnnotationListApi(Resource):
|
||||
|
||||
@service_api_ns.route("/apps/annotations/<uuid:annotation_id>")
|
||||
class AnnotationUpdateDeleteApi(Resource):
|
||||
@service_api_ns.doc(
|
||||
summary="Update Annotation",
|
||||
description="Updates the question and answer of an existing annotation.",
|
||||
tags=["Annotations"],
|
||||
responses={
|
||||
200: "Annotation updated successfully.",
|
||||
403: "`forbidden` : Insufficient permissions to edit annotations.",
|
||||
404: "`not_found` : Annotation does not exist.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.expect(service_api_ns.models[AnnotationCreatePayload.__name__])
|
||||
@service_api_ns.doc("update_annotation")
|
||||
@service_api_ns.doc(description="Update an existing annotation")
|
||||
@ -212,6 +272,16 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
response = Annotation.model_validate(annotation, from_attributes=True)
|
||||
return response.model_dump(mode="json")
|
||||
|
||||
@service_api_ns.doc(
|
||||
summary="Delete Annotation",
|
||||
description="Deletes an annotation and its associated hit history.",
|
||||
tags=["Annotations"],
|
||||
responses={
|
||||
204: "Annotation deleted successfully.",
|
||||
403: "`forbidden` : Insufficient permissions to edit annotations.",
|
||||
404: "`not_found` : Annotation does not exist.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.doc("delete_annotation")
|
||||
@service_api_ns.doc(description="Delete an annotation")
|
||||
@service_api_ns.doc(params={"annotation_id": "Annotation ID"})
|
||||
|
||||
@ -33,6 +33,18 @@ register_response_schema_models(service_api_ns, Parameters, AppMetaResponse, App
|
||||
class AppParameterApi(Resource):
|
||||
"""Resource for app variables."""
|
||||
|
||||
@service_api_ns.doc(
|
||||
summary="Get App Parameters",
|
||||
description=(
|
||||
"Retrieve the application's input form configuration, including feature switches, input "
|
||||
"parameter names, types, and default values."
|
||||
),
|
||||
tags=["Applications"],
|
||||
responses={
|
||||
200: "Application parameters information.",
|
||||
400: "`app_unavailable` : App unavailable or misconfigured.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.doc("get_app_parameters")
|
||||
@service_api_ns.doc(description="Retrieve application input parameters and configuration")
|
||||
@service_api_ns.doc(
|
||||
@ -71,6 +83,14 @@ class AppParameterApi(Resource):
|
||||
|
||||
@service_api_ns.route("/meta")
|
||||
class AppMetaApi(Resource):
|
||||
@service_api_ns.doc(
|
||||
summary="Get App Meta",
|
||||
description="Retrieve metadata about this application, including tool icons and other configuration details.",
|
||||
tags=["Applications"],
|
||||
responses={
|
||||
200: "Successfully retrieved application meta information.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.doc("get_app_meta")
|
||||
@service_api_ns.doc(description="Get application metadata")
|
||||
@service_api_ns.doc(
|
||||
@ -92,6 +112,14 @@ class AppMetaApi(Resource):
|
||||
|
||||
@service_api_ns.route("/info")
|
||||
class AppInfoApi(Resource):
|
||||
@service_api_ns.doc(
|
||||
summary="Get App Info",
|
||||
description="Retrieve basic information about this application, including name, description, tags, and mode.",
|
||||
tags=["Applications"],
|
||||
responses={
|
||||
200: "Basic information of the application.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.doc("get_app_info")
|
||||
@service_api_ns.doc(description="Get basic application information")
|
||||
@service_api_ns.doc(
|
||||
|
||||
@ -20,6 +20,7 @@ from controllers.service_api.app.error import (
|
||||
ProviderQuotaExceededError,
|
||||
UnsupportedAudioTypeError,
|
||||
)
|
||||
from controllers.service_api.schema import binary_response, expect_with_user, multipart_file_params
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
@ -39,8 +40,31 @@ register_response_schema_models(service_api_ns, AudioBinaryResponse, AudioTransc
|
||||
|
||||
@service_api_ns.route("/audio-to-text")
|
||||
class AudioApi(Resource):
|
||||
@service_api_ns.doc(
|
||||
summary="Convert Audio to Text",
|
||||
description=(
|
||||
"Convert audio file to text. Supported MIME types: `audio/mp3`, `audio/mpga`, `audio/m4a`, "
|
||||
"`audio/wav`, and `audio/amr`. File size limit is `30 MB`."
|
||||
),
|
||||
tags=["TTS"],
|
||||
responses={
|
||||
200: "Successfully converted audio to text.",
|
||||
400: (
|
||||
"- `app_unavailable` : App unavailable or misconfigured.\n"
|
||||
"- `provider_not_support_speech_to_text` : Model provider does not support speech-to-text.\n"
|
||||
"- `provider_not_initialize` : No valid model provider credentials found.\n"
|
||||
"- `provider_quota_exceeded` : Model provider quota exhausted.\n"
|
||||
"- `model_currently_not_support` : Current model does not support this operation.\n"
|
||||
"- `completion_request_error` : Speech recognition request failed."
|
||||
),
|
||||
413: "`audio_too_large` : Audio file size exceeded the limit.",
|
||||
415: "`unsupported_audio_type` : Audio type is not allowed.",
|
||||
500: "`internal_server_error` : Internal server error.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.doc("audio_to_text")
|
||||
@service_api_ns.doc(description="Convert audio to text using speech-to-text")
|
||||
@service_api_ns.doc(consumes=["multipart/form-data"], params=multipart_file_params(include_user=True))
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
200: "Audio successfully transcribed",
|
||||
@ -99,7 +123,27 @@ register_schema_model(service_api_ns, TextToAudioPayload)
|
||||
|
||||
@service_api_ns.route("/text-to-audio")
|
||||
class TextApi(Resource):
|
||||
@service_api_ns.expect(service_api_ns.models[TextToAudioPayload.__name__])
|
||||
@service_api_ns.doc(
|
||||
summary="Convert Text to Audio",
|
||||
description="Convert text to speech.",
|
||||
tags=["TTS"],
|
||||
responses={
|
||||
200: (
|
||||
"Returns the generated audio. Generator responses are streamed by the service as `audio/mpeg`; "
|
||||
"otherwise the provider output is returned directly."
|
||||
),
|
||||
400: (
|
||||
"- `app_unavailable` : App unavailable or misconfigured.\n"
|
||||
"- `provider_not_initialize` : No valid model provider credentials found.\n"
|
||||
"- `provider_quota_exceeded` : Model provider quota exhausted.\n"
|
||||
"- `model_currently_not_support` : Current model does not support this operation.\n"
|
||||
"- `completion_request_error` : Text-to-speech request failed."
|
||||
),
|
||||
500: "`internal_server_error` : Internal server error.",
|
||||
},
|
||||
)
|
||||
@expect_with_user(service_api_ns, TextToAudioPayload)
|
||||
@binary_response(service_api_ns, "audio/mpeg")
|
||||
@service_api_ns.doc("text_to_audio")
|
||||
@service_api_ns.doc(description="Convert text to audio using text-to-speech")
|
||||
@service_api_ns.doc(
|
||||
@ -110,11 +154,7 @@ class TextApi(Resource):
|
||||
500: "Internal server error",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Text successfully converted to audio",
|
||||
service_api_ns.models[AudioBinaryResponse.__name__],
|
||||
)
|
||||
@service_api_ns.response(200, "Text successfully converted to audio")
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
|
||||
def post(self, app_model: App, end_user: EndUser):
|
||||
"""Convert text to audio using text-to-speech.
|
||||
|
||||
@ -20,6 +20,7 @@ from controllers.service_api.app.error import (
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.service_api.schema import expect_user_json, expect_with_user, json_or_event_stream_response
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@ -92,7 +93,33 @@ register_response_schema_models(service_api_ns, GeneratedAppResponse, SimpleResu
|
||||
|
||||
@service_api_ns.route("/completion-messages")
|
||||
class CompletionApi(Resource):
|
||||
@service_api_ns.expect(service_api_ns.models[CompletionRequestPayload.__name__])
|
||||
@service_api_ns.doc(
|
||||
summary="Send Completion Message",
|
||||
description="Send a request to the text generation application.",
|
||||
tags=["Completions"],
|
||||
responses={
|
||||
200: (
|
||||
"Successful response. The content type and structure depend on the `response_mode` parameter "
|
||||
"in the request.\n"
|
||||
"\n"
|
||||
"- If `response_mode` is `blocking`, returns `application/json` with a `CompletionResponse` "
|
||||
"object.\n"
|
||||
"- If `response_mode` is `streaming`, returns `text/event-stream` with a stream of "
|
||||
"`ChunkCompletionEvent` objects."
|
||||
),
|
||||
400: (
|
||||
"- `app_unavailable` : App unavailable or misconfigured.\n"
|
||||
"- `provider_not_initialize` : No valid model provider credentials found.\n"
|
||||
"- `provider_quota_exceeded` : Model provider quota exhausted.\n"
|
||||
"- `model_currently_not_support` : Current model unavailable.\n"
|
||||
"- `completion_request_error` : Text generation failed."
|
||||
),
|
||||
429: "`too_many_requests` : Too many concurrent requests for this app.",
|
||||
500: "`internal_server_error` : Internal server error.",
|
||||
},
|
||||
)
|
||||
@expect_with_user(service_api_ns, CompletionRequestPayload)
|
||||
@json_or_event_stream_response(service_api_ns)
|
||||
@service_api_ns.doc("create_completion")
|
||||
@service_api_ns.doc(description="Create a completion for the given prompt")
|
||||
@service_api_ns.doc(
|
||||
@ -168,6 +195,15 @@ class CompletionApi(Resource):
|
||||
|
||||
@service_api_ns.route("/completion-messages/<string:task_id>/stop")
|
||||
class CompletionStopApi(Resource):
|
||||
@service_api_ns.doc(
|
||||
summary="Stop Completion Message Generation",
|
||||
description="Stops a completion message generation task. Only supported in `streaming` mode.",
|
||||
tags=["Completions"],
|
||||
responses={
|
||||
400: "`app_unavailable` : App unavailable or misconfigured.",
|
||||
},
|
||||
)
|
||||
@expect_user_json(service_api_ns)
|
||||
@service_api_ns.doc("stop_completion")
|
||||
@service_api_ns.doc(description="Stop a running completion task")
|
||||
@service_api_ns.doc(params={"task_id": "The ID of the task to stop"})
|
||||
@ -197,7 +233,39 @@ class CompletionStopApi(Resource):
|
||||
|
||||
@service_api_ns.route("/chat-messages")
|
||||
class ChatApi(Resource):
|
||||
@service_api_ns.expect(service_api_ns.models[ChatRequestPayload.__name__])
|
||||
@service_api_ns.doc(
|
||||
summary="Send Chat Message",
|
||||
description="Send a request to the chat application.",
|
||||
tags=["Chats", "Chatflows"],
|
||||
responses={
|
||||
200: (
|
||||
"Successful response. The content type and structure depend on the `response_mode` parameter "
|
||||
"in the request.\n"
|
||||
"\n"
|
||||
"- If `response_mode` is `blocking`, returns `application/json` with a "
|
||||
"`ChatCompletionResponse` object.\n"
|
||||
"- If `response_mode` is `streaming`, returns `text/event-stream` with a stream of "
|
||||
"Server-Sent Events."
|
||||
),
|
||||
400: (
|
||||
"- `app_unavailable` : App unavailable or misconfigured.\n"
|
||||
"- `not_chat_app` : App mode does not match the API route.\n"
|
||||
"- `conversation_completed` : The conversation has ended.\n"
|
||||
"- `provider_not_initialize` : No valid model provider credentials found.\n"
|
||||
"- `provider_quota_exceeded` : Model provider quota exhausted.\n"
|
||||
"- `model_currently_not_support` : Current model unavailable.\n"
|
||||
"- `completion_request_error` : Text generation failed."
|
||||
),
|
||||
404: "`not_found` : Conversation does not exist.",
|
||||
429: (
|
||||
"- `too_many_requests` : Too many concurrent requests for this app.\n"
|
||||
"- `rate_limit_error` : The upstream model provider rate limit was exceeded."
|
||||
),
|
||||
500: "`internal_server_error` : Internal server error.",
|
||||
},
|
||||
)
|
||||
@expect_with_user(service_api_ns, ChatRequestPayload)
|
||||
@json_or_event_stream_response(service_api_ns)
|
||||
@service_api_ns.doc("create_chat_message")
|
||||
@service_api_ns.doc(description="Send a message in a chat conversation")
|
||||
@service_api_ns.doc(
|
||||
@ -276,6 +344,15 @@ class ChatApi(Resource):
|
||||
|
||||
@service_api_ns.route("/chat-messages/<string:task_id>/stop")
|
||||
class ChatStopApi(Resource):
|
||||
@service_api_ns.doc(
|
||||
summary="Stop Chat Message Generation",
|
||||
description="Stops a chat message generation task. Only supported in `streaming` mode.",
|
||||
tags=["Chats", "Chatflows"],
|
||||
responses={
|
||||
400: "`not_chat_app` : App mode does not match the API route.",
|
||||
},
|
||||
)
|
||||
@expect_user_json(service_api_ns)
|
||||
@service_api_ns.doc("stop_chat_message")
|
||||
@service_api_ns.doc(description="Stop a running chat message generation")
|
||||
@service_api_ns.doc(params={"task_id": "The ID of the task to stop"})
|
||||
|
||||
@ -13,6 +13,7 @@ from controllers.common.controller_schemas import ConversationRenamePayload
|
||||
from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import NotChatAppError
|
||||
from controllers.service_api.schema import expect_user_json, expect_with_user
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
@ -145,6 +146,16 @@ register_response_schema_models(
|
||||
|
||||
@service_api_ns.route("/conversations")
|
||||
class ConversationApi(Resource):
|
||||
@service_api_ns.doc(
|
||||
summary="List Conversations",
|
||||
description="Retrieve the conversation list for the current user, ordered by most recently active.",
|
||||
tags=["Conversations"],
|
||||
responses={
|
||||
200: "Successfully retrieved conversations list.",
|
||||
400: "`not_chat_app` : App mode does not match the API route.",
|
||||
404: "`not_found` : Last conversation does not exist (invalid `last_id`).",
|
||||
},
|
||||
)
|
||||
@service_api_ns.doc(params=query_params_from_model(ConversationListQuery))
|
||||
@service_api_ns.doc("list_conversations")
|
||||
@service_api_ns.doc(description="List all conversations for the current user")
|
||||
@ -197,6 +208,17 @@ class ConversationApi(Resource):
|
||||
|
||||
@service_api_ns.route("/conversations/<uuid:c_id>")
|
||||
class ConversationDetailApi(Resource):
|
||||
@service_api_ns.doc(
|
||||
summary="Delete Conversation",
|
||||
description="Delete a conversation.",
|
||||
tags=["Conversations"],
|
||||
responses={
|
||||
204: "Conversation deleted successfully.",
|
||||
400: "`not_chat_app` : App mode does not match the API route.",
|
||||
404: "`not_found` : Conversation does not exist.",
|
||||
},
|
||||
)
|
||||
@expect_user_json(service_api_ns)
|
||||
@service_api_ns.doc("delete_conversation")
|
||||
@service_api_ns.doc(description="Delete a specific conversation")
|
||||
@service_api_ns.doc(params={"c_id": "Conversation ID"})
|
||||
@ -225,7 +247,20 @@ class ConversationDetailApi(Resource):
|
||||
|
||||
@service_api_ns.route("/conversations/<uuid:c_id>/name")
|
||||
class ConversationRenameApi(Resource):
|
||||
@service_api_ns.expect(service_api_ns.models[ConversationRenamePayload.__name__])
|
||||
@service_api_ns.doc(
|
||||
summary="Rename Conversation",
|
||||
description=(
|
||||
"Rename a conversation or auto-generate a name. The conversation name is used for display on "
|
||||
"clients that support multiple conversations."
|
||||
),
|
||||
tags=["Conversations"],
|
||||
responses={
|
||||
200: "Conversation renamed successfully.",
|
||||
400: "`not_chat_app` : App mode does not match the API route.",
|
||||
404: "`not_found` : Conversation does not exist.",
|
||||
},
|
||||
)
|
||||
@expect_with_user(service_api_ns, ConversationRenamePayload)
|
||||
@service_api_ns.doc("rename_conversation")
|
||||
@service_api_ns.doc(description="Rename a conversation or auto-generate a name")
|
||||
@service_api_ns.doc(params={"c_id": "Conversation ID"})
|
||||
@ -267,6 +302,16 @@ class ConversationRenameApi(Resource):
|
||||
|
||||
@service_api_ns.route("/conversations/<uuid:c_id>/variables")
|
||||
class ConversationVariablesApi(Resource):
|
||||
@service_api_ns.doc(
|
||||
summary="List Conversation Variables",
|
||||
description="Retrieve variables from a specific conversation.",
|
||||
tags=["Conversations"],
|
||||
responses={
|
||||
200: "Successfully retrieved conversation variables.",
|
||||
400: "`not_chat_app` : App mode does not match the API route.",
|
||||
404: "`not_found` : Conversation does not exist.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.doc(params=query_params_from_model(ConversationVariablesQuery))
|
||||
@service_api_ns.doc("list_conversation_variables")
|
||||
@service_api_ns.doc(description="List all variables for a conversation")
|
||||
@ -312,7 +357,22 @@ class ConversationVariablesApi(Resource):
|
||||
|
||||
@service_api_ns.route("/conversations/<uuid:c_id>/variables/<uuid:variable_id>")
|
||||
class ConversationVariableDetailApi(Resource):
|
||||
@service_api_ns.expect(service_api_ns.models[ConversationVariableUpdatePayload.__name__])
|
||||
@service_api_ns.doc(
|
||||
summary="Update Conversation Variable",
|
||||
description="Update the value of a specific conversation variable. The value must match the expected type.",
|
||||
tags=["Conversations"],
|
||||
responses={
|
||||
200: "Variable updated successfully.",
|
||||
400: (
|
||||
"- `not_chat_app` : App mode does not match the API route.\n"
|
||||
"- `bad_request` : Variable value type mismatch."
|
||||
),
|
||||
404: (
|
||||
"- `not_found` : Conversation does not exist.\n- `not_found` : Conversation variable does not exist."
|
||||
),
|
||||
},
|
||||
)
|
||||
@expect_with_user(service_api_ns, ConversationVariableUpdatePayload)
|
||||
@service_api_ns.doc("update_conversation_variable")
|
||||
@service_api_ns.doc(description="Update a conversation variable's value")
|
||||
@service_api_ns.doc(params={"c_id": "Conversation ID", "variable_id": "Variable ID"})
|
||||
|
||||
@ -12,6 +12,7 @@ from controllers.common.errors import (
|
||||
)
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.schema import multipart_file_params
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from extensions.ext_database import db
|
||||
from fields.file_fields import FileResponse
|
||||
@ -23,8 +24,27 @@ register_schema_models(service_api_ns, FileResponse)
|
||||
|
||||
@service_api_ns.route("/files/upload")
|
||||
class FileApi(Resource):
|
||||
@service_api_ns.doc(
|
||||
summary="Upload File",
|
||||
description=(
|
||||
"Upload a file for use when sending messages, enabling multimodal understanding of images, "
|
||||
"documents, audio, and video. Uploaded files are for use by the current end-user only."
|
||||
),
|
||||
tags=["Files"],
|
||||
responses={
|
||||
201: "File uploaded successfully.",
|
||||
400: (
|
||||
"- `no_file_uploaded` : No file was provided in the request.\n"
|
||||
"- `too_many_files` : Only one file is allowed per request.\n"
|
||||
"- `filename_not_exists_error` : The uploaded file has no filename."
|
||||
),
|
||||
413: "`file_too_large` : File size exceeded.",
|
||||
415: "`unsupported_file_type` : File type not allowed.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.doc("upload_file")
|
||||
@service_api_ns.doc(description="Upload a file for use in conversations")
|
||||
@service_api_ns.doc(consumes=["multipart/form-data"], params=multipart_file_params(include_user=True))
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
201: "File uploaded successfully",
|
||||
|
||||
@ -15,6 +15,7 @@ from controllers.service_api.app.error import (
|
||||
FileAccessDeniedError,
|
||||
FileNotFoundError,
|
||||
)
|
||||
from controllers.service_api.schema import binary_response
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
@ -30,6 +31,26 @@ class FilePreviewQuery(BaseModel):
|
||||
register_schema_model(service_api_ns, FilePreviewQuery)
|
||||
register_response_schema_model(service_api_ns, BinaryFileResponse)
|
||||
|
||||
FILE_PREVIEW_RESPONSE_MEDIA_TYPES = [
|
||||
"application/octet-stream",
|
||||
"application/pdf",
|
||||
"audio/aac",
|
||||
"audio/flac",
|
||||
"audio/mp4",
|
||||
"audio/mpeg",
|
||||
"audio/ogg",
|
||||
"audio/wav",
|
||||
"audio/x-m4a",
|
||||
"image/gif",
|
||||
"image/jpeg",
|
||||
"image/png",
|
||||
"image/webp",
|
||||
"text/plain",
|
||||
"video/mp4",
|
||||
"video/quicktime",
|
||||
"video/webm",
|
||||
]
|
||||
|
||||
|
||||
@service_api_ns.route("/files/<uuid:file_id>/preview")
|
||||
class FilePreviewApi(Resource):
|
||||
@ -40,7 +61,26 @@ class FilePreviewApi(Resource):
|
||||
Files can only be accessed if they belong to messages within the requesting app's context.
|
||||
"""
|
||||
|
||||
@service_api_ns.doc(
|
||||
summary="Download File",
|
||||
description=(
|
||||
"Preview or download uploaded files previously uploaded via the [Upload "
|
||||
"File](/api-reference/files/upload-file) API. Files can only be accessed if they belong to "
|
||||
"messages within the requesting application."
|
||||
),
|
||||
tags=["Files"],
|
||||
responses={
|
||||
200: (
|
||||
"Returns the raw file content. The `Content-Type` header is set to the file's MIME type. If "
|
||||
"`as_attachment` is `true`, the file is returned as a download with `Content-Disposition: "
|
||||
"attachment`."
|
||||
),
|
||||
403: "`file_access_denied` : Access to the requested file is denied.",
|
||||
404: "`file_not_found` : The requested file was not found.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.doc(params=query_params_from_model(FilePreviewQuery))
|
||||
@binary_response(service_api_ns, FILE_PREVIEW_RESPONSE_MEDIA_TYPES)
|
||||
@service_api_ns.doc("preview_file")
|
||||
@service_api_ns.doc(description="Preview or download a file uploaded via Service API")
|
||||
@service_api_ns.doc(params={"file_id": "UUID of the file to preview"})
|
||||
@ -52,11 +92,7 @@ class FilePreviewApi(Resource):
|
||||
404: "File not found",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"File retrieved successfully",
|
||||
service_api_ns.models[BinaryFileResponse.__name__],
|
||||
)
|
||||
@service_api_ns.response(200, "File retrieved successfully")
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
|
||||
def get(self, app_model: App, end_user: EndUser, file_id: UUID):
|
||||
"""
|
||||
|
||||
@ -18,6 +18,7 @@ from werkzeug.exceptions import BadRequest, NotFound
|
||||
from controllers.common.human_input import HumanInputFormSubmitPayload, stringify_form_default_values
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.schema import expect_with_user
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
|
||||
from extensions.ext_database import db
|
||||
@ -72,6 +73,23 @@ def _ensure_form_is_allowed_for_service_api(form: Form) -> None:
|
||||
|
||||
@service_api_ns.route("/form/human_input/<string:form_token>")
|
||||
class WorkflowHumanInputFormApi(Resource):
|
||||
@service_api_ns.doc(
|
||||
summary="Get Human Input Form",
|
||||
description=(
|
||||
"Retrieve a paused Human Input form's contents using the `form_token` from a "
|
||||
"`human_input_required` event. Requires **WebApp** delivery."
|
||||
),
|
||||
tags=["Human Input"],
|
||||
responses={
|
||||
200: "Form contents retrieved successfully.",
|
||||
404: "`not_found` : Form not found.",
|
||||
412: (
|
||||
"- `human_input_form_submitted` : Form already submitted. Forms are one-shot; the first "
|
||||
"response wins regardless of which user submits it.\n"
|
||||
"- `human_input_form_expired` : The form's expiration time passed before submission arrived."
|
||||
),
|
||||
},
|
||||
)
|
||||
@service_api_ns.doc("get_human_input_form")
|
||||
@service_api_ns.doc(description="Get a paused human input form by token")
|
||||
@service_api_ns.doc(params={"form_token": "Human input form token"})
|
||||
@ -101,7 +119,29 @@ class WorkflowHumanInputFormApi(Resource):
|
||||
inputs = service.resolve_form_inputs(form)
|
||||
return _jsonify_form_definition(form, inputs=inputs)
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[HumanInputFormSubmitPayload.__name__])
|
||||
@service_api_ns.doc(
|
||||
summary="Submit Human Input Form",
|
||||
description=(
|
||||
"Submit the recipient's response to a paused Human Input form. The workflow resumes on "
|
||||
"acceptance; use [Stream Workflow Events](/api-reference/chatflows/stream-workflow-events) "
|
||||
"to follow subsequent events. Requires **WebApp** delivery."
|
||||
),
|
||||
tags=["Human Input"],
|
||||
responses={
|
||||
200: "Form submitted successfully. The response body is an empty object.",
|
||||
400: (
|
||||
"- `bad_request` : Form recipient type is invalid.\n"
|
||||
"- `invalid_form_data` : Submission failed validation against the form definition."
|
||||
),
|
||||
404: "`not_found` : Form not found.",
|
||||
412: (
|
||||
"- `human_input_form_submitted` : Form already submitted. Forms are one-shot; the first "
|
||||
"response wins regardless of which user submits it.\n"
|
||||
"- `human_input_form_expired` : The form's expiration time passed before submission arrived."
|
||||
),
|
||||
},
|
||||
)
|
||||
@expect_with_user(service_api_ns, HumanInputFormSubmitPayload)
|
||||
@service_api_ns.doc("submit_human_input_form")
|
||||
@service_api_ns.doc(description="Submit a paused human input form by token")
|
||||
@service_api_ns.doc(params={"form_token": "Human input form token"})
|
||||
|
||||
@ -12,6 +12,7 @@ from controllers.common.fields import SimpleResultStringListResponse
|
||||
from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import NotChatAppError
|
||||
from controllers.service_api.schema import expect_with_user
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from fields.base import ResponseModel
|
||||
@ -64,6 +65,19 @@ register_response_schema_models(
|
||||
|
||||
@service_api_ns.route("/messages")
|
||||
class MessageListApi(Resource):
|
||||
@service_api_ns.doc(
|
||||
summary="List Conversation Messages",
|
||||
description=(
|
||||
"Returns historical chat records in a scrolling load format, with the first page returning "
|
||||
"the latest `limit` messages, i.e., in reverse order."
|
||||
),
|
||||
tags=["Conversations"],
|
||||
responses={
|
||||
200: "Successfully retrieved conversation history.",
|
||||
400: "`not_chat_app` : App mode does not match the API route.",
|
||||
404: ("- `not_found` : Conversation does not exist.\n- `not_found` : First message does not exist."),
|
||||
},
|
||||
)
|
||||
@service_api_ns.doc(params=query_params_from_model(MessageListQuery))
|
||||
@service_api_ns.doc("list_messages")
|
||||
@service_api_ns.doc(description="List messages in a conversation")
|
||||
@ -112,7 +126,19 @@ class MessageListApi(Resource):
|
||||
|
||||
@service_api_ns.route("/messages/<uuid:message_id>/feedbacks")
|
||||
class MessageFeedbackApi(Resource):
|
||||
@service_api_ns.expect(service_api_ns.models[MessageFeedbackPayload.__name__])
|
||||
@service_api_ns.doc(
|
||||
summary="Submit Message Feedback",
|
||||
description=(
|
||||
"Submit feedback for a message. End users can rate messages as `like` or `dislike`, and "
|
||||
"optionally provide text feedback. Pass `null` for `rating` to revoke previously submitted "
|
||||
"feedback."
|
||||
),
|
||||
tags=["Feedback"],
|
||||
responses={
|
||||
404: "`not_found` : Message does not exist.",
|
||||
},
|
||||
)
|
||||
@expect_with_user(service_api_ns, MessageFeedbackPayload)
|
||||
@service_api_ns.response(200, "Feedback submitted successfully", service_api_ns.models[ResultResponse.__name__])
|
||||
@service_api_ns.doc("create_message_feedback")
|
||||
@service_api_ns.doc(description="Submit feedback for a message")
|
||||
@ -150,6 +176,17 @@ class MessageFeedbackApi(Resource):
|
||||
|
||||
@service_api_ns.route("/app/feedbacks")
|
||||
class AppGetFeedbacksApi(Resource):
|
||||
@service_api_ns.doc(
|
||||
summary="List App Feedbacks",
|
||||
description=(
|
||||
"Retrieve a paginated list of all feedback submitted for messages in this application, "
|
||||
"including both end-user and admin feedback."
|
||||
),
|
||||
tags=["Feedback"],
|
||||
responses={
|
||||
200: "A list of application feedbacks.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.doc(params=query_params_from_model(FeedbackListQuery))
|
||||
@service_api_ns.doc("get_app_feedbacks")
|
||||
@service_api_ns.doc(description="Get all feedbacks for the application")
|
||||
@ -177,6 +214,20 @@ class AppGetFeedbacksApi(Resource):
|
||||
|
||||
@service_api_ns.route("/messages/<uuid:message_id>/suggested")
|
||||
class MessageSuggestedApi(Resource):
|
||||
@service_api_ns.doc(
|
||||
summary="Get Next Suggested Questions",
|
||||
description="Get next questions suggestions for the current message.",
|
||||
tags=["Chats", "Chatflows"],
|
||||
responses={
|
||||
200: "Successfully retrieved suggested questions.",
|
||||
400: (
|
||||
"- `not_chat_app` : App mode does not match the API route.\n"
|
||||
"- `bad_request` : Suggested questions feature is disabled."
|
||||
),
|
||||
404: "`not_found` : Message does not exist.",
|
||||
500: "`internal_server_error` : Internal server error.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Suggested questions retrieved successfully",
|
||||
|
||||
@ -17,6 +17,18 @@ register_response_schema_models(service_api_ns, SiteResponse)
|
||||
class AppSiteApi(Resource):
|
||||
"""Resource for app sites."""
|
||||
|
||||
@service_api_ns.doc(
|
||||
summary="Get App WebApp Settings",
|
||||
description=(
|
||||
"Retrieve the WebApp settings of this application, including site configuration, theme, and "
|
||||
"customization options."
|
||||
),
|
||||
tags=["Applications"],
|
||||
responses={
|
||||
200: "WebApp settings of the application.",
|
||||
403: "`forbidden` : Site not found for this application or the workspace has been archived.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.doc("get_app_site")
|
||||
@service_api_ns.doc(description="Get application site configuration")
|
||||
@service_api_ns.doc(
|
||||
|
||||
@ -21,6 +21,11 @@ from controllers.service_api.app.error import (
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.service_api.schema import (
|
||||
expect_user_json,
|
||||
expect_with_user,
|
||||
json_or_event_stream_response,
|
||||
)
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
@ -177,14 +182,15 @@ register_response_schema_models(
|
||||
def _serialize_workflow_run(workflow_run: WorkflowRun) -> dict:
|
||||
status = _enum_value(workflow_run.status)
|
||||
raw_outputs = workflow_run.outputs_dict
|
||||
if status == WorkflowExecutionStatus.PAUSED.value or raw_outputs is None:
|
||||
outputs: dict = {}
|
||||
elif isinstance(raw_outputs, dict):
|
||||
outputs = raw_outputs
|
||||
elif isinstance(raw_outputs, Mapping):
|
||||
outputs = dict(raw_outputs)
|
||||
else:
|
||||
outputs = {}
|
||||
match raw_outputs:
|
||||
case _ if status == WorkflowExecutionStatus.PAUSED.value or raw_outputs is None:
|
||||
outputs: dict = {}
|
||||
case dict():
|
||||
outputs = raw_outputs
|
||||
case _ if isinstance(raw_outputs, Mapping):
|
||||
outputs = dict(raw_outputs)
|
||||
case _:
|
||||
outputs = {}
|
||||
return WorkflowRunResponse.model_validate(
|
||||
{
|
||||
"id": workflow_run.id,
|
||||
@ -208,6 +214,16 @@ def _serialize_workflow_log_pagination(pagination) -> dict:
|
||||
|
||||
@service_api_ns.route("/workflows/run/<string:workflow_run_id>")
|
||||
class WorkflowRunDetailApi(Resource):
|
||||
@service_api_ns.doc(
|
||||
summary="Get Workflow Run Detail",
|
||||
description="Retrieve the current execution results of a workflow task based on the workflow execution ID.",
|
||||
tags=["Chatflows", "Workflows"],
|
||||
responses={
|
||||
200: "Successfully retrieved workflow run details.",
|
||||
400: "`not_workflow_app` : App mode does not match the API route.",
|
||||
404: "`not_found` : Workflow run not found.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.doc("get_workflow_run_detail")
|
||||
@service_api_ns.doc(description="Get workflow run details")
|
||||
@service_api_ns.doc(params={"workflow_run_id": "Workflow run ID"})
|
||||
@ -249,7 +265,37 @@ class WorkflowRunDetailApi(Resource):
|
||||
|
||||
@service_api_ns.route("/workflows/run")
|
||||
class WorkflowRunApi(Resource):
|
||||
@service_api_ns.expect(service_api_ns.models[WorkflowRunPayload.__name__])
|
||||
@service_api_ns.doc(
|
||||
summary="Run Workflow",
|
||||
description="Execute a workflow. Cannot be executed without a published workflow.",
|
||||
tags=["Workflows"],
|
||||
responses={
|
||||
200: (
|
||||
"Successful response. The content type and structure depend on the `response_mode` parameter "
|
||||
"in the request.\n"
|
||||
"\n"
|
||||
"- If `response_mode` is `blocking`, returns `application/json` with a "
|
||||
"`WorkflowBlockingResponse` object.\n"
|
||||
"- If `response_mode` is `streaming`, returns `text/event-stream` with a stream of "
|
||||
"`ChunkWorkflowEvent` objects."
|
||||
),
|
||||
400: (
|
||||
"- `not_workflow_app` : App mode does not match the API route.\n"
|
||||
"- `provider_not_initialize` : No valid model provider credentials found.\n"
|
||||
"- `provider_quota_exceeded` : Model provider quota exhausted.\n"
|
||||
"- `model_currently_not_support` : Current model unavailable.\n"
|
||||
"- `completion_request_error` : Workflow execution request failed.\n"
|
||||
"- `invalid_param` : Invalid parameter value."
|
||||
),
|
||||
429: (
|
||||
"- `too_many_requests` : Too many concurrent requests for this app.\n"
|
||||
"- `rate_limit_error` : The upstream model provider rate limit was exceeded."
|
||||
),
|
||||
500: "`internal_server_error` : Internal server error.",
|
||||
},
|
||||
)
|
||||
@expect_with_user(service_api_ns, WorkflowRunPayload)
|
||||
@json_or_event_stream_response(service_api_ns)
|
||||
@service_api_ns.doc("run_workflow")
|
||||
@service_api_ns.doc(description="Execute a workflow")
|
||||
@service_api_ns.doc(
|
||||
@ -313,7 +359,42 @@ class WorkflowRunApi(Resource):
|
||||
|
||||
@service_api_ns.route("/workflows/<string:workflow_id>/run")
|
||||
class WorkflowRunByIdApi(Resource):
|
||||
@service_api_ns.expect(service_api_ns.models[WorkflowRunPayload.__name__])
|
||||
@service_api_ns.doc(
|
||||
summary="Run Workflow by ID",
|
||||
description=(
|
||||
"Execute a specific workflow version identified by its ID. Useful for running a particular "
|
||||
"published version of the workflow."
|
||||
),
|
||||
tags=["Workflows"],
|
||||
responses={
|
||||
200: (
|
||||
"Successful response. The content type and structure depend on the `response_mode` parameter "
|
||||
"in the request.\n"
|
||||
"\n"
|
||||
"- If `response_mode` is `blocking`, returns `application/json` with a "
|
||||
"`WorkflowBlockingResponse` object.\n"
|
||||
"- If `response_mode` is `streaming`, returns `text/event-stream` with a stream of "
|
||||
"`ChunkWorkflowEvent` objects."
|
||||
),
|
||||
400: (
|
||||
"- `not_workflow_app` : App mode does not match the API route.\n"
|
||||
"- `bad_request` : Workflow is a draft or has an invalid ID format.\n"
|
||||
"- `provider_not_initialize` : No valid model provider credentials found.\n"
|
||||
"- `provider_quota_exceeded` : Model provider quota exhausted.\n"
|
||||
"- `model_currently_not_support` : Current model unavailable.\n"
|
||||
"- `completion_request_error` : Workflow execution request failed.\n"
|
||||
"- `invalid_param` : Required parameter missing or invalid."
|
||||
),
|
||||
404: "`not_found` : Workflow not found.",
|
||||
429: (
|
||||
"- `too_many_requests` : Too many concurrent requests for this app.\n"
|
||||
"- `rate_limit_error` : The upstream model provider rate limit was exceeded."
|
||||
),
|
||||
500: "`internal_server_error` : Internal server error.",
|
||||
},
|
||||
)
|
||||
@expect_with_user(service_api_ns, WorkflowRunPayload)
|
||||
@json_or_event_stream_response(service_api_ns)
|
||||
@service_api_ns.doc("run_workflow_by_id")
|
||||
@service_api_ns.doc(description="Execute a specific workflow by ID")
|
||||
@service_api_ns.doc(params={"workflow_id": "Workflow ID to execute"})
|
||||
@ -387,6 +468,18 @@ class WorkflowRunByIdApi(Resource):
|
||||
|
||||
@service_api_ns.route("/workflows/tasks/<string:task_id>/stop")
|
||||
class WorkflowTaskStopApi(Resource):
|
||||
@service_api_ns.doc(
|
||||
summary="Stop Workflow Task",
|
||||
description="Stop a running workflow task. Only supported in `streaming` mode.",
|
||||
tags=["Workflows"],
|
||||
responses={
|
||||
400: (
|
||||
"- `not_workflow_app` : App mode does not match the API route.\n"
|
||||
"- `invalid_param` : Required parameter missing or invalid."
|
||||
),
|
||||
},
|
||||
)
|
||||
@expect_user_json(service_api_ns)
|
||||
@service_api_ns.doc("stop_workflow_task")
|
||||
@service_api_ns.doc(description="Stop a running workflow task")
|
||||
@service_api_ns.doc(params={"task_id": "Task ID to stop"})
|
||||
@ -417,6 +510,14 @@ class WorkflowTaskStopApi(Resource):
|
||||
|
||||
@service_api_ns.route("/workflows/logs")
|
||||
class WorkflowAppLogApi(Resource):
|
||||
@service_api_ns.doc(
|
||||
summary="List Workflow Logs",
|
||||
description="Retrieve paginated workflow execution logs with filtering options.",
|
||||
tags=["Chatflows", "Workflows"],
|
||||
responses={
|
||||
200: "Successfully retrieved workflow logs.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.doc(params=query_params_from_model(WorkflowLogQuery))
|
||||
@service_api_ns.doc("get_workflow_logs")
|
||||
@service_api_ns.doc(description="Get workflow execution logs")
|
||||
|
||||
@ -15,6 +15,7 @@ from controllers.common.fields import EventStreamResponse
|
||||
from controllers.common.schema import query_params_from_model, register_response_schema_model, register_schema_models
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import NotWorkflowAppError
|
||||
from controllers.service_api.schema import event_stream_response
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
@ -44,6 +45,24 @@ register_response_schema_model(service_api_ns, EventStreamResponse)
|
||||
class WorkflowEventsApi(Resource):
|
||||
"""Service API for getting workflow execution events after resume."""
|
||||
|
||||
@service_api_ns.doc(
|
||||
summary="Stream Workflow Events",
|
||||
description=(
|
||||
"Resume the Server-Sent Events stream for a workflow run after a pause or a dropped SSE "
|
||||
"connection. For runs that have already finished, the stream emits a single "
|
||||
"`workflow_finished` event and closes."
|
||||
),
|
||||
tags=["Chatflows", "Workflows"],
|
||||
responses={
|
||||
200: (
|
||||
"Server-Sent Events stream. Each event is delivered as `data: {JSON}\\n\\n`. Event payloads "
|
||||
"follow the same schemas as the original streaming response."
|
||||
),
|
||||
400: "`not_workflow_app` : Please check if your app mode matches the right API route.",
|
||||
404: "`not_found` : Workflow run not found.",
|
||||
},
|
||||
)
|
||||
@event_stream_response(service_api_ns)
|
||||
@service_api_ns.doc("get_workflow_events")
|
||||
@service_api_ns.doc(description="Get workflow execution events stream after resume")
|
||||
@service_api_ns.doc(params={"task_id": "Workflow run ID"})
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
from typing import Any, Literal
|
||||
from typing import Any, Literal, override
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, GetJsonSchemaHandler, RootModel, field_validator, model_validator
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
@ -79,6 +79,13 @@ class DocumentStatusPayload(BaseModel):
|
||||
document_ids: list[str] = Field(default_factory=list, description="Document IDs to update")
|
||||
|
||||
|
||||
DOCUMENT_STATUS_ACTION_PARAM = {
|
||||
"description": "Action to perform: 'enable', 'disable', 'archive', or 'un_archive'",
|
||||
"enum": ["enable", "disable", "archive", "un_archive"],
|
||||
"type": "string",
|
||||
}
|
||||
|
||||
|
||||
class TagNamePayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=50)
|
||||
|
||||
@ -114,6 +121,45 @@ class TagUnbindingPayload(BaseModel):
|
||||
tag_id: str | None = None
|
||||
target_id: str
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def __get_pydantic_json_schema__(cls, _core_schema: object, _handler: GetJsonSchemaHandler) -> dict[str, object]:
|
||||
tag_id_property = {
|
||||
"description": "Legacy single tag ID accepted by the Service API.",
|
||||
"type": "string",
|
||||
}
|
||||
tag_ids_property = {
|
||||
"description": "Tag IDs to unbind. Use this for new integrations.",
|
||||
"items": {"type": "string"},
|
||||
"minItems": 1,
|
||||
"type": "array",
|
||||
}
|
||||
target_id_property = {"title": "Target Id", "type": "string"}
|
||||
return {
|
||||
"anyOf": [
|
||||
{
|
||||
"properties": {
|
||||
"tag_id": tag_id_property,
|
||||
"tag_ids": tag_ids_property,
|
||||
"target_id": target_id_property,
|
||||
},
|
||||
"required": ["tag_id", "target_id"],
|
||||
"type": "object",
|
||||
},
|
||||
{
|
||||
"properties": {
|
||||
"tag_id": {**tag_id_property, "nullable": True},
|
||||
"tag_ids": tag_ids_property,
|
||||
"target_id": target_id_property,
|
||||
},
|
||||
"required": ["tag_ids", "target_id"],
|
||||
"type": "object",
|
||||
},
|
||||
],
|
||||
"description": "Accepts either the legacy tag_id payload or the normalized tag_ids payload.",
|
||||
"title": cls.__name__,
|
||||
}
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def normalize_legacy_tag_id(cls, data: object) -> object:
|
||||
@ -204,6 +250,14 @@ register_response_schema_models(
|
||||
class DatasetListApi(DatasetApiResource):
|
||||
"""Resource for datasets."""
|
||||
|
||||
@service_api_ns.doc(
|
||||
summary="List Knowledge Bases",
|
||||
description="Returns a paginated list of knowledge bases. Supports filtering by keyword and tags.",
|
||||
tags=["Knowledge Bases"],
|
||||
responses={
|
||||
200: "List of knowledge bases.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.doc("list_datasets")
|
||||
@service_api_ns.doc(description="List all datasets")
|
||||
@service_api_ns.doc(
|
||||
@ -262,6 +316,19 @@ class DatasetListApi(DatasetApiResource):
|
||||
}
|
||||
return dump_response(DatasetListResponse, response), 200
|
||||
|
||||
@service_api_ns.doc(
|
||||
summary="Create an Empty Knowledge Base",
|
||||
description=(
|
||||
"Create a new empty knowledge base. After creation, use [Create Document by "
|
||||
"Text](/api-reference/documents/create-document-by-text) or [Create Document by "
|
||||
"File](/api-reference/documents/create-document-by-file) to add documents."
|
||||
),
|
||||
tags=["Knowledge Bases"],
|
||||
responses={
|
||||
200: "Knowledge base created successfully.",
|
||||
409: "`dataset_name_duplicate` : The dataset name already exists. Please modify your dataset name.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.expect(service_api_ns.models[DatasetCreatePayload.__name__])
|
||||
@service_api_ns.doc("create_dataset")
|
||||
@service_api_ns.doc(description="Create a new dataset")
|
||||
@ -327,6 +394,19 @@ class DatasetListApi(DatasetApiResource):
|
||||
class DatasetApi(DatasetApiResource):
|
||||
"""Resource for dataset."""
|
||||
|
||||
@service_api_ns.doc(
|
||||
summary="Get Knowledge Base",
|
||||
description=(
|
||||
"Retrieve detailed information about a specific knowledge base, including its embedding "
|
||||
"model, retrieval configuration, and document statistics."
|
||||
),
|
||||
tags=["Knowledge Bases"],
|
||||
responses={
|
||||
200: "Knowledge base details.",
|
||||
403: "`forbidden` : Insufficient permissions to access this knowledge base.",
|
||||
404: "`not_found` : Dataset not found.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.doc("get_dataset")
|
||||
@service_api_ns.doc(description="Get a specific dataset by ID")
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@ -392,6 +472,19 @@ class DatasetApi(DatasetApiResource):
|
||||
200,
|
||||
)
|
||||
|
||||
@service_api_ns.doc(
|
||||
summary="Update Knowledge Base",
|
||||
description=(
|
||||
"Update the name, description, permissions, or retrieval settings of an existing knowledge "
|
||||
"base. Only the fields provided in the request body are updated."
|
||||
),
|
||||
tags=["Knowledge Bases"],
|
||||
responses={
|
||||
200: "Knowledge base updated successfully.",
|
||||
403: "`forbidden` : Insufficient permissions to access this knowledge base.",
|
||||
404: "`not_found` : Dataset not found.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.expect(service_api_ns.models[DatasetUpdatePayload.__name__])
|
||||
@service_api_ns.doc("update_dataset")
|
||||
@service_api_ns.doc(description="Update an existing dataset")
|
||||
@ -474,6 +567,22 @@ class DatasetApi(DatasetApiResource):
|
||||
|
||||
return DatasetDetailWithPartialMembersResponse.model_validate(result_data).model_dump(mode="json"), 200
|
||||
|
||||
@service_api_ns.doc(
|
||||
summary="Delete Knowledge Base",
|
||||
description=(
|
||||
"Permanently delete a knowledge base and all its documents. The knowledge base must not be "
|
||||
"in use by any application."
|
||||
),
|
||||
tags=["Knowledge Bases"],
|
||||
responses={
|
||||
204: "Success.",
|
||||
404: "`not_found` : Dataset not found.",
|
||||
409: (
|
||||
"`dataset_in_use` : The knowledge base is being used by some apps. Please remove it from the "
|
||||
"apps before deleting."
|
||||
),
|
||||
},
|
||||
)
|
||||
@service_api_ns.doc("delete_dataset")
|
||||
@service_api_ns.doc(description="Delete a dataset")
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@ -519,6 +628,17 @@ class DatasetApi(DatasetApiResource):
|
||||
class DocumentStatusApi(DatasetApiResource):
|
||||
"""Resource for batch document status operations."""
|
||||
|
||||
@service_api_ns.doc(
|
||||
summary="Update Document Status in Batch",
|
||||
description="Enable, disable, archive, or unarchive multiple documents at once.",
|
||||
tags=["Documents"],
|
||||
responses={
|
||||
200: "Documents updated successfully.",
|
||||
400: "`invalid_action` : Invalid action.",
|
||||
403: "`forbidden` : Insufficient permissions.",
|
||||
404: "`not_found` : Knowledge base not found.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Document status updated successfully",
|
||||
@ -529,7 +649,7 @@ class DocumentStatusApi(DatasetApiResource):
|
||||
@service_api_ns.doc(
|
||||
params={
|
||||
"dataset_id": "Dataset ID",
|
||||
"action": "Action to perform: 'enable', 'disable', 'archive', or 'un_archive'",
|
||||
"action": DOCUMENT_STATUS_ACTION_PARAM,
|
||||
}
|
||||
)
|
||||
@service_api_ns.doc(
|
||||
@ -591,6 +711,14 @@ class DocumentStatusApi(DatasetApiResource):
|
||||
|
||||
@service_api_ns.route("/datasets/tags")
|
||||
class DatasetTagsApi(DatasetApiResource):
|
||||
@service_api_ns.doc(
|
||||
summary="List Knowledge Tags",
|
||||
description="Returns the list of all knowledge base tags in the workspace.",
|
||||
tags=["Tags"],
|
||||
responses={
|
||||
200: "List of tags.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.doc("list_dataset_tags")
|
||||
@service_api_ns.doc(description="Get all knowledge type tags")
|
||||
@service_api_ns.doc(
|
||||
@ -612,6 +740,14 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
tags = TagService.get_tags(db.session(), "knowledge", cid)
|
||||
return dump_response(KnowledgeTagListResponse, tags), 200
|
||||
|
||||
@service_api_ns.doc(
|
||||
summary="Create Knowledge Tag",
|
||||
description="Create a new tag for organizing knowledge bases.",
|
||||
tags=["Tags"],
|
||||
responses={
|
||||
200: "Tag created successfully.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.expect(service_api_ns.models[TagCreatePayload.__name__])
|
||||
@service_api_ns.doc("create_dataset_tag")
|
||||
@service_api_ns.doc(description="Add a knowledge type tag")
|
||||
@ -634,7 +770,7 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
raise Forbidden()
|
||||
|
||||
payload = TagCreatePayload.model_validate(service_api_ns.payload or {})
|
||||
tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=TagType.KNOWLEDGE))
|
||||
tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=TagType.KNOWLEDGE), db.session)
|
||||
|
||||
response = dump_response(
|
||||
KnowledgeTagResponse,
|
||||
@ -642,6 +778,14 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
)
|
||||
return response, 200
|
||||
|
||||
@service_api_ns.doc(
|
||||
summary="Update Knowledge Tag",
|
||||
description="Rename an existing knowledge base tag.",
|
||||
tags=["Tags"],
|
||||
responses={
|
||||
200: "Tag updated successfully.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.expect(service_api_ns.models[TagUpdatePayload.__name__])
|
||||
@service_api_ns.doc("update_dataset_tag")
|
||||
@service_api_ns.doc(description="Update a knowledge type tag")
|
||||
@ -664,9 +808,9 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
|
||||
payload = TagUpdatePayload.model_validate(service_api_ns.payload or {})
|
||||
tag_id = payload.tag_id
|
||||
tag = TagService.update_tags(UpdateTagServicePayload(name=payload.name), tag_id)
|
||||
tag = TagService.update_tags(UpdateTagServicePayload(name=payload.name), tag_id, db.session)
|
||||
|
||||
binding_count = TagService.get_tag_binding_count(tag_id)
|
||||
binding_count = TagService.get_tag_binding_count(tag_id, db.session)
|
||||
|
||||
response = dump_response(
|
||||
KnowledgeTagResponse,
|
||||
@ -674,6 +818,14 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
)
|
||||
return response, 200
|
||||
|
||||
@service_api_ns.doc(
|
||||
summary="Delete Knowledge Tag",
|
||||
description="Permanently delete a knowledge base tag. Does not delete the knowledge bases that were tagged.",
|
||||
tags=["Tags"],
|
||||
responses={
|
||||
204: "Success.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.expect(service_api_ns.models[TagDeletePayload.__name__])
|
||||
@service_api_ns.doc("delete_dataset_tag")
|
||||
@service_api_ns.doc(description="Delete a knowledge type tag")
|
||||
@ -688,13 +840,21 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
def delete(self, _):
|
||||
"""Delete a knowledge type tag."""
|
||||
payload = TagDeletePayload.model_validate(service_api_ns.payload or {})
|
||||
TagService.delete_tag(payload.tag_id)
|
||||
TagService.delete_tag(payload.tag_id, db.session)
|
||||
|
||||
return "", 204
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/tags/binding")
|
||||
class DatasetTagBindingApi(DatasetApiResource):
|
||||
@service_api_ns.doc(
|
||||
summary="Create Tag Binding",
|
||||
description="Bind one or more tags to a knowledge base. A knowledge base can have multiple tags.",
|
||||
tags=["Tags"],
|
||||
responses={
|
||||
204: "Success.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.expect(service_api_ns.models[TagBindingPayload.__name__])
|
||||
@service_api_ns.doc("bind_dataset_tags")
|
||||
@service_api_ns.doc(description="Bind tags to a dataset")
|
||||
@ -713,7 +873,8 @@ class DatasetTagBindingApi(DatasetApiResource):
|
||||
|
||||
payload = TagBindingPayload.model_validate(service_api_ns.payload or {})
|
||||
TagService.save_tag_binding(
|
||||
TagBindingCreatePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=TagType.KNOWLEDGE)
|
||||
TagBindingCreatePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=TagType.KNOWLEDGE),
|
||||
db.session,
|
||||
)
|
||||
|
||||
return "", 204
|
||||
@ -721,6 +882,14 @@ class DatasetTagBindingApi(DatasetApiResource):
|
||||
|
||||
@service_api_ns.route("/datasets/tags/unbinding")
|
||||
class DatasetTagUnbindingApi(DatasetApiResource):
|
||||
@service_api_ns.doc(
|
||||
summary="Delete Tag Binding",
|
||||
description="Remove one or more tags from a knowledge base.",
|
||||
tags=["Tags"],
|
||||
responses={
|
||||
204: "Success.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.expect(service_api_ns.models[TagUnbindingPayload.__name__])
|
||||
@service_api_ns.doc("unbind_dataset_tags")
|
||||
@service_api_ns.doc(description="Unbind tags from a dataset")
|
||||
@ -739,7 +908,8 @@ class DatasetTagUnbindingApi(DatasetApiResource):
|
||||
|
||||
payload = TagUnbindingPayload.model_validate(service_api_ns.payload or {})
|
||||
TagService.delete_tag_binding(
|
||||
TagBindingDeletePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=TagType.KNOWLEDGE)
|
||||
TagBindingDeletePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=TagType.KNOWLEDGE),
|
||||
db.session,
|
||||
)
|
||||
|
||||
return "", 204
|
||||
@ -747,6 +917,14 @@ class DatasetTagUnbindingApi(DatasetApiResource):
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/tags")
|
||||
class DatasetTagsBindingStatusApi(DatasetApiResource):
|
||||
@service_api_ns.doc(
|
||||
summary="Get Knowledge Base Tags",
|
||||
description="Returns the list of tags bound to a specific knowledge base.",
|
||||
tags=["Tags"],
|
||||
responses={
|
||||
200: "Tags bound to the knowledge base.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.doc("get_dataset_tags_binding_status")
|
||||
@service_api_ns.doc(description="Get tags bound to a specific dataset")
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@ -766,6 +944,8 @@ class DatasetTagsBindingStatusApi(DatasetApiResource):
|
||||
dataset_id = kwargs.get("dataset_id")
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
tags = TagService.get_tags_by_target_id("knowledge", current_user.current_tenant_id, str(dataset_id))
|
||||
tags = TagService.get_tags_by_target_id(
|
||||
"knowledge", current_user.current_tenant_id, str(dataset_id), db.session
|
||||
)
|
||||
tags_list = [{"id": tag.id, "name": tag.name} for tag in tags]
|
||||
return dump_response(DatasetBoundTagListResponse, {"data": tags_list, "total": len(tags)}), 200
|
||||
|
||||
@ -8,11 +8,12 @@ deprecated in generated API docs so clients migrate toward the canonical paths.
|
||||
import json
|
||||
from collections.abc import Mapping
|
||||
from contextlib import ExitStack
|
||||
from typing import Any, Literal, Self
|
||||
from copy import deepcopy
|
||||
from typing import Any, Literal, Self, override
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request, send_file
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from pydantic import BaseModel, Field, GetJsonSchemaHandler, field_validator, model_validator
|
||||
from sqlalchemy import desc, func, select
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
@ -39,6 +40,7 @@ from controllers.service_api.dataset.error import (
|
||||
DocumentIndexingError,
|
||||
InvalidMetadataError,
|
||||
)
|
||||
from controllers.service_api.schema import binary_response
|
||||
from controllers.service_api.wraps import (
|
||||
DatasetApiResource,
|
||||
cloud_edition_billing_rate_limit_check,
|
||||
@ -104,6 +106,36 @@ class DocumentTextUpdate(BaseModel):
|
||||
raise ValueError("Invalid doc_form.")
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def __get_pydantic_json_schema__(cls, core_schema: Any, handler: GetJsonSchemaHandler) -> dict[str, Any]:
|
||||
schema = handler.resolve_ref_schema(handler(core_schema))
|
||||
properties = schema.get("properties")
|
||||
if not isinstance(properties, dict):
|
||||
return schema
|
||||
|
||||
text_branch_properties = deepcopy(properties)
|
||||
text_branch_properties["text"] = _non_null_property_schema(properties.get("text"))
|
||||
text_branch_properties["name"] = _non_null_property_schema(properties.get("name"))
|
||||
|
||||
no_text_branch_properties = deepcopy(properties)
|
||||
no_text_branch_properties["text"] = {"type": "null"}
|
||||
|
||||
return {
|
||||
**schema,
|
||||
"anyOf": [
|
||||
{
|
||||
"properties": text_branch_properties,
|
||||
"required": ["name", "text"],
|
||||
"type": "object",
|
||||
},
|
||||
{
|
||||
"properties": no_text_branch_properties,
|
||||
"type": "object",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_text_and_name(self) -> Self:
|
||||
if self.text is not None and self.name is None:
|
||||
@ -111,6 +143,24 @@ class DocumentTextUpdate(BaseModel):
|
||||
return self
|
||||
|
||||
|
||||
def _non_null_property_schema(property_schema: object) -> dict[str, Any]:
|
||||
if not isinstance(property_schema, dict):
|
||||
return {}
|
||||
|
||||
any_of = property_schema.get("anyOf")
|
||||
if isinstance(any_of, list):
|
||||
non_null_candidates = [
|
||||
candidate for candidate in any_of if isinstance(candidate, dict) and candidate.get("type") != "null"
|
||||
]
|
||||
if len(non_null_candidates) == 1:
|
||||
return {
|
||||
**{key: value for key, value in property_schema.items() if key != "anyOf"},
|
||||
**deepcopy(non_null_candidates[0]),
|
||||
}
|
||||
|
||||
return deepcopy(property_schema)
|
||||
|
||||
|
||||
class DocumentListQuery(BaseModel):
|
||||
page: int = Field(default=1, description="Page number")
|
||||
limit: int = Field(default=20, description="Number of items per page")
|
||||
@ -351,6 +401,24 @@ def _update_document_by_text(tenant_id: str, dataset_id: UUID, document_id: UUID
|
||||
class DocumentAddByTextApi(DatasetApiResource):
|
||||
"""Resource for the canonical text document creation route."""
|
||||
|
||||
@service_api_ns.doc(
|
||||
summary="Create Document by Text",
|
||||
description=(
|
||||
"Create a document from raw text content. The document is processed asynchronously — use the "
|
||||
"returned `batch` ID with [Get Document Indexing Status](/api-reference/documents/"
|
||||
"get-document-indexing-status) to track progress."
|
||||
),
|
||||
tags=["Documents"],
|
||||
responses={
|
||||
200: "Document created successfully.",
|
||||
400: (
|
||||
"- `provider_not_initialize` : No valid model provider credentials found. Please go to "
|
||||
"Settings -> Model Provider to complete your provider credentials.\n"
|
||||
"- `invalid_param` : Knowledge base does not exist. / indexing_technique is required. / "
|
||||
"Invalid doc_form (must be `text_model`, `hierarchical_model`, or `qa_model`)."
|
||||
),
|
||||
},
|
||||
)
|
||||
@service_api_ns.expect(service_api_ns.models[DocumentTextCreatePayload.__name__])
|
||||
@service_api_ns.doc("create_document_by_text")
|
||||
@service_api_ns.doc(description="Create a new document by providing text content")
|
||||
@ -409,6 +477,25 @@ class DeprecatedDocumentAddByTextApi(DatasetApiResource):
|
||||
class DocumentUpdateByTextApi(DatasetApiResource):
|
||||
"""Resource for the canonical text document update route."""
|
||||
|
||||
@service_api_ns.doc(
|
||||
summary="Update Document by Text",
|
||||
description=(
|
||||
"Update an existing document's text content, name, or processing configuration. Re-triggers "
|
||||
"indexing if content changes — use the returned `batch` ID with [Get Document Indexing "
|
||||
"Status](/api-reference/documents/get-document-indexing-status) to track progress."
|
||||
),
|
||||
tags=["Documents"],
|
||||
responses={
|
||||
200: "Document updated successfully.",
|
||||
400: (
|
||||
"- `provider_not_initialize` : No valid model provider credentials found. Please go to "
|
||||
"Settings -> Model Provider to complete your provider credentials.\n"
|
||||
"- `invalid_param` : Knowledge base does not exist, name is required when text is "
|
||||
"provided, or invalid doc_form (must be `text_model`, `hierarchical_model`, or "
|
||||
"`qa_model`)."
|
||||
),
|
||||
},
|
||||
)
|
||||
@service_api_ns.expect(service_api_ns.models[DocumentTextUpdate.__name__])
|
||||
@service_api_ns.doc("update_document_by_text")
|
||||
@service_api_ns.doc(description="Update an existing document by providing text content")
|
||||
@ -463,11 +550,42 @@ class DeprecatedDocumentUpdateByTextApi(DatasetApiResource):
|
||||
|
||||
@service_api_ns.route(
|
||||
"/datasets/<uuid:dataset_id>/document/create_by_file",
|
||||
"/datasets/<uuid:dataset_id>/document/create-by-file",
|
||||
doc={
|
||||
"post": {
|
||||
"deprecated": True,
|
||||
"description": (
|
||||
"Deprecated legacy alias for creating a new document by uploading a file. "
|
||||
"Use /datasets/{dataset_id}/document/create-by-file instead."
|
||||
),
|
||||
}
|
||||
},
|
||||
)
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/document/create-by-file")
|
||||
class DocumentAddByFileApi(DatasetApiResource):
|
||||
"""Resource for documents."""
|
||||
|
||||
@service_api_ns.doc(
|
||||
summary="Create Document by File",
|
||||
description=(
|
||||
"Create a document by uploading a file. Supports common document formats (PDF, TXT, DOCX, "
|
||||
"etc.). Processing is asynchronous — use the returned `batch` ID with [Get Document "
|
||||
"Indexing Status](/api-reference/documents/get-document-indexing-status) to track progress."
|
||||
),
|
||||
tags=["Documents"],
|
||||
responses={
|
||||
200: "Document created successfully.",
|
||||
400: (
|
||||
"- `no_file_uploaded` : Please upload your file.\n"
|
||||
"- `too_many_files` : Only one file is allowed.\n"
|
||||
"- `filename_not_exists_error` : The specified filename does not exist.\n"
|
||||
"- `provider_not_initialize` : No valid model provider credentials found. Please go to "
|
||||
"Settings -> Model Provider to complete your provider credentials.\n"
|
||||
"- `invalid_param` : Knowledge base does not exist, external datasets not supported, "
|
||||
"file too large, unsupported file type, missing required fields, or invalid doc_form "
|
||||
"(must be `text_model`, `hierarchical_model`, or `qa_model`)."
|
||||
),
|
||||
},
|
||||
)
|
||||
@service_api_ns.doc("create_document_by_file")
|
||||
@service_api_ns.doc(description="Create a new document by uploading a file")
|
||||
@service_api_ns.doc(consumes=["multipart/form-data"], params=DOCUMENT_CREATE_BY_FILE_PARAMS)
|
||||
@ -658,6 +776,27 @@ def _update_document_by_file(tenant_id: str, dataset_id: UUID, document_id: UUID
|
||||
class DeprecatedDocumentUpdateByFileApi(DatasetApiResource):
|
||||
"""Deprecated resource aliases for file document updates."""
|
||||
|
||||
@service_api_ns.doc(
|
||||
summary="Update Document by File",
|
||||
description=(
|
||||
"Update an existing document by uploading a new file. Re-triggers indexing — use the returned "
|
||||
"`batch` ID with [Get Document Indexing Status](/api-reference/documents/"
|
||||
"get-document-indexing-status) to track progress."
|
||||
),
|
||||
tags=["Documents"],
|
||||
responses={
|
||||
200: "Document updated successfully.",
|
||||
400: (
|
||||
"- `too_many_files` : Only one file is allowed.\n"
|
||||
"- `filename_not_exists_error` : The specified filename does not exist.\n"
|
||||
"- `provider_not_initialize` : No valid model provider credentials found. Please go to "
|
||||
"Settings -> Model Provider to complete your provider credentials.\n"
|
||||
"- `invalid_param` : Knowledge base does not exist, external datasets not supported, "
|
||||
"file too large, unsupported file type, or invalid doc_form (must be `text_model`, "
|
||||
"`hierarchical_model`, or `qa_model`)."
|
||||
),
|
||||
},
|
||||
)
|
||||
@service_api_ns.doc("update_document_by_file_deprecated")
|
||||
@service_api_ns.doc(deprecated=True)
|
||||
@service_api_ns.doc(
|
||||
@ -686,6 +825,18 @@ class DeprecatedDocumentUpdateByFileApi(DatasetApiResource):
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents")
|
||||
class DocumentListApi(DatasetApiResource):
|
||||
@service_api_ns.doc(
|
||||
summary="List Documents",
|
||||
description=(
|
||||
"Returns a paginated list of documents in the knowledge base. Supports filtering by keyword "
|
||||
"and indexing status."
|
||||
),
|
||||
tags=["Documents"],
|
||||
responses={
|
||||
200: "List of documents.",
|
||||
404: "`not_found` : Knowledge base not found.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.doc("list_documents")
|
||||
@service_api_ns.doc(description="List all documents in a dataset")
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID", **query_params_from_model(DocumentListQuery)})
|
||||
@ -746,6 +897,19 @@ class DocumentListApi(DatasetApiResource):
|
||||
class DocumentBatchDownloadZipApi(DatasetApiResource):
|
||||
"""Download multiple uploaded-file documents as a single ZIP archive."""
|
||||
|
||||
@service_api_ns.doc(
|
||||
summary="Download Documents as ZIP",
|
||||
description=(
|
||||
"Download multiple uploaded-file documents as a single ZIP archive. Accepts up to `100` document IDs."
|
||||
),
|
||||
tags=["Documents"],
|
||||
responses={
|
||||
200: "ZIP archive containing the requested documents.",
|
||||
403: "`forbidden` : Insufficient permissions.",
|
||||
404: "`not_found` : Document or dataset not found.",
|
||||
},
|
||||
)
|
||||
@binary_response(service_api_ns, "application/zip")
|
||||
@service_api_ns.expect(service_api_ns.models[DocumentBatchDownloadZipPayload.__name__])
|
||||
@service_api_ns.doc("download_documents_as_zip")
|
||||
@service_api_ns.doc(description="Download selected uploaded documents as a single ZIP archive")
|
||||
@ -758,11 +922,7 @@ class DocumentBatchDownloadZipApi(DatasetApiResource):
|
||||
404: "Document or dataset not found",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"ZIP archive generated successfully",
|
||||
service_api_ns.models[BinaryFileResponse.__name__],
|
||||
)
|
||||
@service_api_ns.response(200, "ZIP archive generated successfully")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id, dataset_id: UUID):
|
||||
payload = DocumentBatchDownloadZipPayload.model_validate(service_api_ns.payload or {})
|
||||
@ -789,6 +949,20 @@ class DocumentBatchDownloadZipApi(DatasetApiResource):
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<string:batch>/indexing-status")
|
||||
class DocumentIndexingStatusApi(DatasetApiResource):
|
||||
@service_api_ns.doc(
|
||||
summary="Get Document Indexing Status",
|
||||
description=(
|
||||
"Check the indexing progress of documents in a batch. Returns the current processing stage "
|
||||
"and chunk completion counts for each document. Poll this endpoint until `indexing_status` "
|
||||
"reaches `completed` or `error`. The status progresses through: `waiting` → `parsing` → "
|
||||
"`cleaning` → `splitting` → `indexing` → `completed`."
|
||||
),
|
||||
tags=["Documents"],
|
||||
responses={
|
||||
200: "Indexing status for documents in the batch.",
|
||||
404: "`not_found` : Knowledge base not found. / Documents not found.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.doc("get_document_indexing_status")
|
||||
@service_api_ns.doc(description="Get indexing status for documents in a batch")
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "batch": "Batch ID"})
|
||||
@ -861,6 +1035,16 @@ class DocumentIndexingStatusApi(DatasetApiResource):
|
||||
class DocumentDownloadApi(DatasetApiResource):
|
||||
"""Return a signed download URL for a document's original uploaded file."""
|
||||
|
||||
@service_api_ns.doc(
|
||||
summary="Download Document",
|
||||
description="Get a signed download URL for a document's original uploaded file.",
|
||||
tags=["Documents"],
|
||||
responses={
|
||||
200: "Download URL generated successfully.",
|
||||
403: "`forbidden` : No permission to access this document.",
|
||||
404: "`not_found` : Document not found.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.doc("get_document_download_url")
|
||||
@service_api_ns.doc(description="Get a signed download URL for a document's original uploaded file")
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
@ -895,6 +1079,24 @@ class DocumentDownloadApi(DatasetApiResource):
|
||||
class DocumentApi(DatasetApiResource):
|
||||
METADATA_CHOICES = {"all", "only", "without"}
|
||||
|
||||
@service_api_ns.doc(
|
||||
summary="Get Document",
|
||||
description=(
|
||||
"Retrieve detailed information about a specific document, including its indexing status, "
|
||||
"metadata, and processing statistics."
|
||||
),
|
||||
tags=["Documents"],
|
||||
responses={
|
||||
200: (
|
||||
"Document details. The response shape varies based on the `metadata` query parameter. When "
|
||||
"`metadata` is `only`, only `id`, `doc_type`, and `doc_metadata` are returned. When "
|
||||
"`metadata` is `without`, `doc_type` and `doc_metadata` are omitted."
|
||||
),
|
||||
400: "`invalid_metadata` : Invalid metadata value for the specified key.",
|
||||
403: "`forbidden` : No permission.",
|
||||
404: "`not_found` : Document not found.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.doc("get_document")
|
||||
@service_api_ns.doc(description="Get a specific document by ID")
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
@ -1036,6 +1238,17 @@ class DocumentApi(DatasetApiResource):
|
||||
"""Update document by file on the canonical document resource."""
|
||||
return _update_document_by_file(tenant_id=tenant_id, dataset_id=dataset_id, document_id=document_id)
|
||||
|
||||
@service_api_ns.doc(
|
||||
summary="Delete Document",
|
||||
description="Permanently delete a document and all its chunks from the knowledge base.",
|
||||
tags=["Documents"],
|
||||
responses={
|
||||
204: "Success.",
|
||||
400: "`document_indexing` : Cannot delete document during indexing.",
|
||||
403: "`archived_document_immutable` : The archived document is not editable.",
|
||||
404: "`not_found` : Document Not Exists.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.doc("delete_document")
|
||||
@service_api_ns.doc(description="Delete a document")
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
|
||||
@ -13,6 +13,32 @@ register_response_schema_models(service_api_ns, HitTestingResponse)
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/hit-testing", "/datasets/<uuid:dataset_id>/retrieve")
|
||||
class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
|
||||
@service_api_ns.doc(
|
||||
summary="Retrieve Chunks from a Knowledge Base / Test Retrieval",
|
||||
description=(
|
||||
"Performs a search query against a knowledge base to retrieve the most relevant chunks. This "
|
||||
"endpoint can be used for both production retrieval and test retrieval."
|
||||
),
|
||||
tags=["Knowledge Bases"],
|
||||
responses={
|
||||
200: "Retrieval results.",
|
||||
400: (
|
||||
"- `dataset_not_initialized` : The dataset is still being initialized or indexing. Please "
|
||||
"wait a moment.\n"
|
||||
"- `provider_not_initialize` : No valid model provider credentials found. Please go to "
|
||||
"Settings -> Model Provider to complete your provider credentials.\n"
|
||||
"- `provider_quota_exceeded` : Your quota for Dify Hosted OpenAI has been exhausted. Please "
|
||||
"go to Settings -> Model Provider to complete your own provider credentials.\n"
|
||||
"- `model_currently_not_support` : Dify Hosted OpenAI trial currently not support the GPT-4 "
|
||||
"model.\n"
|
||||
"- `completion_request_error` : Completion request failed.\n"
|
||||
"- `invalid_param` : Invalid parameter value."
|
||||
),
|
||||
403: "`forbidden` : Insufficient permissions.",
|
||||
404: "`not_found` : Knowledge base not found.",
|
||||
500: "`internal_server_error` : An internal error occurred during retrieval.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.doc("dataset_hit_testing")
|
||||
@service_api_ns.doc(description="Perform hit testing on a dataset")
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
|
||||
@ -24,6 +24,12 @@ from services.entities.knowledge_entities.knowledge_entities import (
|
||||
)
|
||||
from services.metadata_service import MetadataService
|
||||
|
||||
BUILT_IN_METADATA_ACTION_PARAM = {
|
||||
"description": "Action to perform: 'enable' or 'disable'",
|
||||
"enum": ["enable", "disable"],
|
||||
"type": "string",
|
||||
}
|
||||
|
||||
register_schema_model(service_api_ns, MetadataUpdatePayload)
|
||||
register_schema_models(
|
||||
service_api_ns,
|
||||
@ -43,6 +49,17 @@ register_response_schema_models(
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata")
|
||||
class DatasetMetadataCreateServiceApi(DatasetApiResource):
|
||||
@service_api_ns.doc(
|
||||
summary="Create Metadata Field",
|
||||
description=(
|
||||
"Create a custom metadata field for the knowledge base. Metadata fields can be used to "
|
||||
"annotate documents with structured information."
|
||||
),
|
||||
tags=["Metadata"],
|
||||
responses={
|
||||
201: "Metadata field created successfully.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.expect(service_api_ns.models[MetadataArgs.__name__])
|
||||
@service_api_ns.doc("create_dataset_metadata")
|
||||
@service_api_ns.doc(description="Create metadata for a dataset")
|
||||
@ -71,6 +88,17 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
|
||||
metadata = MetadataService.create_metadata(dataset_id_str, metadata_args)
|
||||
return dump_response(DatasetMetadataResponse, metadata), 201
|
||||
|
||||
@service_api_ns.doc(
|
||||
summary="List Metadata Fields",
|
||||
description=(
|
||||
"Returns the list of all metadata fields (both custom and built-in) for the knowledge base, "
|
||||
"along with the count of documents using each field."
|
||||
),
|
||||
tags=["Metadata"],
|
||||
responses={
|
||||
200: "Metadata fields for the knowledge base.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.doc("get_dataset_metadata")
|
||||
@service_api_ns.doc(description="Get all metadata for a dataset")
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@ -96,6 +124,14 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>")
|
||||
class DatasetMetadataServiceApi(DatasetApiResource):
|
||||
@service_api_ns.doc(
|
||||
summary="Update Metadata Field",
|
||||
description="Rename a custom metadata field.",
|
||||
tags=["Metadata"],
|
||||
responses={
|
||||
200: "Metadata field updated successfully.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.expect(service_api_ns.models[MetadataUpdatePayload.__name__])
|
||||
@service_api_ns.doc("update_dataset_metadata")
|
||||
@service_api_ns.doc(description="Update metadata name")
|
||||
@ -125,6 +161,17 @@ class DatasetMetadataServiceApi(DatasetApiResource):
|
||||
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, payload.name)
|
||||
return dump_response(DatasetMetadataResponse, metadata), 200
|
||||
|
||||
@service_api_ns.doc(
|
||||
summary="Delete Metadata Field",
|
||||
description=(
|
||||
"Permanently delete a custom metadata field. Documents using this field will lose their "
|
||||
"metadata values for it."
|
||||
),
|
||||
tags=["Metadata"],
|
||||
responses={
|
||||
204: "Success.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.doc("delete_dataset_metadata")
|
||||
@service_api_ns.doc(description="Delete metadata")
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "metadata_id": "Metadata ID"})
|
||||
@ -152,6 +199,16 @@ class DatasetMetadataServiceApi(DatasetApiResource):
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata/built-in")
|
||||
class DatasetMetadataBuiltInFieldServiceApi(DatasetApiResource):
|
||||
@service_api_ns.doc(
|
||||
summary="Get Built-in Metadata Fields",
|
||||
description=(
|
||||
"Returns the list of built-in metadata fields provided by the system (e.g., document type, source URL)."
|
||||
),
|
||||
tags=["Metadata"],
|
||||
responses={
|
||||
200: "Built-in metadata fields.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.doc("get_built_in_fields")
|
||||
@service_api_ns.doc(description="Get all built-in metadata fields")
|
||||
@service_api_ns.doc(
|
||||
@ -173,9 +230,17 @@ class DatasetMetadataBuiltInFieldServiceApi(DatasetApiResource):
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata/built-in/<string:action>")
|
||||
class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):
|
||||
@service_api_ns.doc(
|
||||
summary="Update Built-in Metadata Field",
|
||||
description="Enable or disable built-in metadata fields for the knowledge base.",
|
||||
tags=["Metadata"],
|
||||
responses={
|
||||
200: "Built-in metadata field toggled successfully.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.doc("toggle_built_in_field")
|
||||
@service_api_ns.doc(description="Enable or disable built-in metadata field")
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "action": "Action to perform: 'enable' or 'disable'"})
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "action": BUILT_IN_METADATA_ACTION_PARAM})
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
200: "Action completed successfully",
|
||||
@ -205,6 +270,17 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/metadata")
|
||||
class DocumentMetadataEditServiceApi(DatasetApiResource):
|
||||
@service_api_ns.doc(
|
||||
summary="Update Document Metadata in Batch",
|
||||
description=(
|
||||
"Update metadata values for multiple documents at once. Each document in the request "
|
||||
"receives the specified metadata key-value pairs."
|
||||
),
|
||||
tags=["Metadata"],
|
||||
responses={
|
||||
200: "Document metadata updated successfully.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.expect(service_api_ns.models[MetadataOperationData.__name__])
|
||||
@service_api_ns.doc("update_documents_metadata")
|
||||
@service_api_ns.doc(description="Update metadata for multiple documents")
|
||||
|
||||
@ -19,6 +19,11 @@ from controllers.common.schema import (
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.dataset.error import PipelineRunError
|
||||
from controllers.service_api.dataset.rag_pipeline.serializers import serialize_upload_file
|
||||
from controllers.service_api.schema import (
|
||||
event_stream_response,
|
||||
json_or_event_stream_response,
|
||||
multipart_file_params,
|
||||
)
|
||||
from controllers.service_api.wraps import DatasetApiResource
|
||||
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@ -95,6 +100,18 @@ register_response_schema_models(
|
||||
class DatasourcePluginsApi(DatasetApiResource):
|
||||
"""Resource for datasource plugins."""
|
||||
|
||||
@service_api_ns.doc(
|
||||
summary="List Datasource Plugins",
|
||||
description=(
|
||||
"List the datasource nodes configured in the knowledge pipeline. Each node includes the "
|
||||
"plugin it uses plus the metadata needed to run it."
|
||||
),
|
||||
tags=["Knowledge Pipeline"],
|
||||
responses={
|
||||
200: "List of datasource nodes configured in the pipeline.",
|
||||
404: "`not_found` : Dataset not found.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.doc(shortcut="list_rag_pipeline_datasource_plugins")
|
||||
@service_api_ns.doc(description="List all datasource plugins for a rag pipeline")
|
||||
@service_api_ns.doc(
|
||||
@ -137,6 +154,19 @@ class DatasourcePluginsApi(DatasetApiResource):
|
||||
class DatasourceNodeRunApi(DatasetApiResource):
|
||||
"""Resource for datasource node run."""
|
||||
|
||||
@service_api_ns.doc(
|
||||
summary="Run Datasource Node",
|
||||
description=(
|
||||
"Execute a single datasource node within the knowledge pipeline. Returns a streaming "
|
||||
"response with the node execution results."
|
||||
),
|
||||
tags=["Knowledge Pipeline"],
|
||||
responses={
|
||||
200: "Streaming response with node execution events.",
|
||||
404: "`not_found` : Dataset not found.",
|
||||
},
|
||||
)
|
||||
@event_stream_response(service_api_ns)
|
||||
@service_api_ns.doc(shortcut="pipeline_datasource_node_run")
|
||||
@service_api_ns.doc(description="Run a datasource node for a rag pipeline")
|
||||
@service_api_ns.doc(
|
||||
@ -195,6 +225,24 @@ class DatasourceNodeRunApi(DatasetApiResource):
|
||||
class PipelineRunApi(DatasetApiResource):
|
||||
"""Resource for datasource node run."""
|
||||
|
||||
@service_api_ns.doc(
|
||||
summary="Run Pipeline",
|
||||
description=(
|
||||
"Execute the full knowledge pipeline for a knowledge base. Supports both streaming and "
|
||||
"blocking response modes."
|
||||
),
|
||||
tags=["Knowledge Pipeline"],
|
||||
responses={
|
||||
200: (
|
||||
"Pipeline execution result. Format depends on `response_mode`: streaming returns a "
|
||||
"`text/event-stream`, blocking returns a JSON object."
|
||||
),
|
||||
403: "`forbidden` : Forbidden.",
|
||||
404: "`not_found` : Dataset not found.",
|
||||
500: "`pipeline_run_error` : Pipeline execution failed.",
|
||||
},
|
||||
)
|
||||
@json_or_event_stream_response(service_api_ns)
|
||||
@service_api_ns.doc(shortcut="pipeline_datasource_node_run")
|
||||
@service_api_ns.doc(description="Run a datasource node for a rag pipeline")
|
||||
@service_api_ns.doc(
|
||||
@ -248,8 +296,24 @@ class PipelineRunApi(DatasetApiResource):
|
||||
class KnowledgebasePipelineFileUploadApi(DatasetApiResource):
|
||||
"""Resource for uploading a file to a knowledgebase pipeline."""
|
||||
|
||||
@service_api_ns.doc(
|
||||
summary="Upload Pipeline File",
|
||||
description="Upload a file for use in a knowledge pipeline. Accepts a single file via `multipart/form-data`.",
|
||||
tags=["Knowledge Pipeline"],
|
||||
responses={
|
||||
201: "File uploaded successfully.",
|
||||
400: (
|
||||
"- `no_file_uploaded` : Please upload your file.\n"
|
||||
"- `filename_not_exists_error` : The specified filename does not exist.\n"
|
||||
"- `too_many_files` : Only one file is allowed."
|
||||
),
|
||||
413: "`file_too_large` : File size exceeded.",
|
||||
415: "`unsupported_file_type` : File type not allowed.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.doc(shortcut="knowledgebase_pipeline_file_upload")
|
||||
@service_api_ns.doc(description="Upload a file to a knowledgebase pipeline")
|
||||
@service_api_ns.doc(consumes=["multipart/form-data"], params=multipart_file_params(include_user=False))
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
201: "File uploaded successfully",
|
||||
|
||||
@ -128,6 +128,18 @@ register_response_schema_models(
|
||||
class SegmentApi(DatasetApiResource):
|
||||
"""Resource for segments."""
|
||||
|
||||
@service_api_ns.doc(
|
||||
summary="Create Chunks",
|
||||
description=(
|
||||
"Create one or more chunks within a document. Each chunk can include optional keywords and an "
|
||||
"answer field (for QA-mode documents)."
|
||||
),
|
||||
tags=["Chunks"],
|
||||
responses={
|
||||
200: "Chunks created successfully.",
|
||||
404: "`not_found` : Document is not completed or is disabled.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.expect(service_api_ns.models[SegmentCreatePayload.__name__])
|
||||
@service_api_ns.doc("create_segments")
|
||||
@service_api_ns.doc(description="Create segments in a document")
|
||||
@ -209,6 +221,14 @@ class SegmentApi(DatasetApiResource):
|
||||
}
|
||||
return dump_response(SegmentCreateListResponse, response), 200
|
||||
|
||||
@service_api_ns.doc(
|
||||
summary="List Chunks",
|
||||
description="Returns a paginated list of chunks within a document. Supports filtering by keyword and status.",
|
||||
tags=["Chunks"],
|
||||
responses={
|
||||
200: "List of chunks.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.doc("list_segments")
|
||||
@service_api_ns.doc(description="List segments in a document")
|
||||
@service_api_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT)
|
||||
@ -294,6 +314,14 @@ class SegmentApi(DatasetApiResource):
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>")
|
||||
class DatasetSegmentApi(DatasetApiResource):
|
||||
@service_api_ns.doc(
|
||||
summary="Delete Chunk",
|
||||
description="Permanently delete a chunk from the document.",
|
||||
tags=["Chunks"],
|
||||
responses={
|
||||
204: "Success.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.doc("delete_segment")
|
||||
@service_api_ns.doc(description="Delete a specific segment")
|
||||
@service_api_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_SEGMENT)
|
||||
@ -329,6 +357,14 @@ class DatasetSegmentApi(DatasetApiResource):
|
||||
SegmentService.delete_segment(segment, document, dataset)
|
||||
return "", 204
|
||||
|
||||
@service_api_ns.doc(
|
||||
summary="Update Chunk",
|
||||
description="Update a chunk's content, keywords, or answer. Re-triggers indexing for the modified chunk.",
|
||||
tags=["Chunks"],
|
||||
responses={
|
||||
200: "Chunk updated successfully.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.expect(service_api_ns.models[SegmentUpdatePayload.__name__])
|
||||
@service_api_ns.doc("update_segment")
|
||||
@service_api_ns.doc(description="Update a specific segment")
|
||||
@ -391,6 +427,17 @@ class DatasetSegmentApi(DatasetApiResource):
|
||||
}
|
||||
return dump_response(SegmentDetailResponse, response), 200
|
||||
|
||||
@service_api_ns.doc(
|
||||
summary="Get Chunk",
|
||||
description=(
|
||||
"Retrieve detailed information about a specific chunk, including its content, keywords, and "
|
||||
"indexing status."
|
||||
),
|
||||
tags=["Chunks"],
|
||||
responses={
|
||||
200: "Chunk details.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.doc("get_segment")
|
||||
@service_api_ns.doc(description="Get a specific segment by ID")
|
||||
@service_api_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_SEGMENT)
|
||||
@ -442,6 +489,15 @@ class DatasetSegmentApi(DatasetApiResource):
|
||||
class ChildChunkApi(DatasetApiResource):
|
||||
"""Resource for child chunks."""
|
||||
|
||||
@service_api_ns.doc(
|
||||
summary="Create Child Chunk",
|
||||
description="Create a child chunk under the specified segment.",
|
||||
tags=["Chunks"],
|
||||
responses={
|
||||
200: "Child chunk created successfully.",
|
||||
400: "`invalid_param` : Create child chunk index failed.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.expect(service_api_ns.models[ChildChunkCreatePayload.__name__])
|
||||
@service_api_ns.doc("create_child_chunk")
|
||||
@service_api_ns.doc(description="Create a new child chunk for a segment")
|
||||
@ -511,6 +567,14 @@ class ChildChunkApi(DatasetApiResource):
|
||||
|
||||
return dump_response(ChildChunkDetailResponse, {"data": child_chunk}), 200
|
||||
|
||||
@service_api_ns.doc(
|
||||
summary="List Child Chunks",
|
||||
description="Returns a paginated list of child chunks under a specific parent chunk.",
|
||||
tags=["Chunks"],
|
||||
responses={
|
||||
200: "List of child chunks.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.doc("list_child_chunks")
|
||||
@service_api_ns.doc(description="List child chunks for a segment")
|
||||
@service_api_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_PARENT_SEGMENT)
|
||||
@ -576,6 +640,15 @@ class ChildChunkApi(DatasetApiResource):
|
||||
class DatasetChildChunkApi(DatasetApiResource):
|
||||
"""Resource for updating child chunks."""
|
||||
|
||||
@service_api_ns.doc(
|
||||
summary="Delete Child Chunk",
|
||||
description="Permanently delete a child chunk from its parent chunk.",
|
||||
tags=["Chunks"],
|
||||
responses={
|
||||
204: "Success.",
|
||||
400: "`invalid_param` : Delete child chunk index failed.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.doc("delete_child_chunk")
|
||||
@service_api_ns.doc(description="Delete a specific child chunk")
|
||||
@service_api_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_CHILD_CHUNK)
|
||||
@ -634,6 +707,15 @@ class DatasetChildChunkApi(DatasetApiResource):
|
||||
|
||||
return "", 204
|
||||
|
||||
@service_api_ns.doc(
|
||||
summary="Update Child Chunk",
|
||||
description="Update the content of an existing child chunk.",
|
||||
tags=["Chunks"],
|
||||
responses={
|
||||
200: "Child chunk updated successfully.",
|
||||
400: "`invalid_param` : Update child chunk index failed.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.expect(service_api_ns.models[ChildChunkUpdatePayload.__name__])
|
||||
@service_api_ns.doc("update_child_chunk")
|
||||
@service_api_ns.doc(description="Update a specific child chunk")
|
||||
|
||||
@ -17,6 +17,18 @@ register_response_schema_models(service_api_ns, EndUserDetail)
|
||||
class EndUserApi(Resource):
|
||||
"""Resource for retrieving end user details by ID."""
|
||||
|
||||
@service_api_ns.doc(
|
||||
summary="Get End User Info",
|
||||
description=(
|
||||
"Retrieve an end user by ID. Useful when other APIs return an end-user ID (e.g., "
|
||||
"`created_by` from [Upload File](/api-reference/files/upload-file))."
|
||||
),
|
||||
tags=["End Users"],
|
||||
responses={
|
||||
200: "End user retrieved successfully.",
|
||||
404: "`end_user_not_found` : End user not found.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.doc("get_end_user")
|
||||
@service_api_ns.doc(description="Get an end user by ID")
|
||||
@service_api_ns.doc(
|
||||
|
||||
113
api/controllers/service_api/schema.py
Normal file
113
api/controllers/service_api/schema.py
Normal file
@ -0,0 +1,113 @@
|
||||
"""Service API OpenAPI documentation helpers.
|
||||
|
||||
These helpers keep documentation-only request shapes next to controller
|
||||
definitions without changing the Pydantic models used for runtime validation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from copy import deepcopy
|
||||
from typing import cast
|
||||
|
||||
from flask_restx import Namespace
|
||||
from pydantic import BaseModel
|
||||
|
||||
USER_PROPERTY_SCHEMA: dict[str, object] = {"description": "End user identifier", "type": "string"}
|
||||
USER_QUERY_PARAM: dict[str, object] = {"description": "End user identifier", "in": "query", "type": "string"}
|
||||
USER_FORM_PARAM: dict[str, object] = {"description": "End user identifier", "in": "formData", "type": "string"}
|
||||
FILE_FORM_PARAM: dict[str, object] = {"in": "formData", "required": True, "type": "file"}
|
||||
USER_FETCH_FROM_ATTR = "_dify_service_api_user_fetch_from"
|
||||
USER_REQUIRED_ATTR = "_dify_service_api_user_required"
|
||||
JSON_USER_FETCH_FROM = "JSON"
|
||||
|
||||
|
||||
def expect_with_user(namespace: Namespace, model: type[BaseModel]):
|
||||
"""Document a JSON request body as ``model`` plus Service API ``user``."""
|
||||
|
||||
source_model = namespace.models[model.__name__]
|
||||
model_name = f"{model.__name__}WithUser"
|
||||
|
||||
def decorator(view_func):
|
||||
required = _json_user_required(view_func)
|
||||
schema = cast(dict[str, object], deepcopy(source_model.__schema__))
|
||||
_add_user_property(schema, required=required)
|
||||
if model_name not in namespace.models:
|
||||
namespace.schema_model(model_name, schema)
|
||||
return namespace.expect(namespace.models[model_name], validate=False)(view_func)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def expect_user_json(namespace: Namespace):
|
||||
"""Document a JSON request body that only carries the Service API ``user``."""
|
||||
|
||||
def decorator(view_func):
|
||||
required = _json_user_required(view_func)
|
||||
schema: dict[str, object] = {"properties": {}, "title": "ServiceApiUserPayload", "type": "object"}
|
||||
_add_user_property(schema, required=required)
|
||||
model_name = "RequiredServiceApiUserPayload" if required else "OptionalServiceApiUserPayload"
|
||||
if model_name not in namespace.models:
|
||||
namespace.schema_model(model_name, schema)
|
||||
return namespace.expect(namespace.models[model_name], validate=False)(view_func)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def multipart_file_params(*, include_user: bool) -> dict[str, dict[str, object]]:
|
||||
params: dict[str, dict[str, object]] = {"file": FILE_FORM_PARAM}
|
||||
if include_user:
|
||||
params["user"] = USER_FORM_PARAM
|
||||
return deepcopy(params)
|
||||
|
||||
|
||||
def json_or_event_stream_response(namespace: Namespace):
|
||||
return namespace.doc(produces=["application/json", "text/event-stream"])
|
||||
|
||||
|
||||
def event_stream_response(namespace: Namespace):
|
||||
return namespace.doc(produces=["text/event-stream"])
|
||||
|
||||
|
||||
def binary_response(namespace: Namespace, media_type: str | Sequence[str]):
|
||||
media_types = [media_type] if isinstance(media_type, str) else list(media_type)
|
||||
return namespace.doc(produces=media_types)
|
||||
|
||||
|
||||
def _json_user_required(view_func) -> bool:
|
||||
fetch_from = getattr(view_func, USER_FETCH_FROM_ATTR, None)
|
||||
if fetch_from != JSON_USER_FETCH_FROM:
|
||||
raise ValueError("JSON user documentation must match validate_app_token(fetch_user_arg=WhereisUserArg.JSON)")
|
||||
|
||||
return bool(getattr(view_func, USER_REQUIRED_ATTR, False))
|
||||
|
||||
|
||||
def _add_user_property(schema: dict[str, object], *, required: bool) -> None:
|
||||
variants: list[dict[str, object]] = []
|
||||
for keyword in ("anyOf", "oneOf"):
|
||||
candidates = schema.get(keyword)
|
||||
if isinstance(candidates, list):
|
||||
variants.extend(candidate for candidate in candidates if isinstance(candidate, dict))
|
||||
|
||||
if variants:
|
||||
for variant in variants:
|
||||
_add_user_property_to_object_schema(variant, required=required)
|
||||
|
||||
_add_user_property_to_object_schema(schema, required=required)
|
||||
|
||||
|
||||
def _add_user_property_to_object_schema(schema: dict[str, object], *, required: bool) -> None:
|
||||
properties = schema.setdefault("properties", {})
|
||||
if isinstance(properties, dict):
|
||||
cast(dict[str, object], properties)["user"] = USER_PROPERTY_SCHEMA
|
||||
|
||||
if required:
|
||||
required_fields = schema.setdefault("required", [])
|
||||
if isinstance(required_fields, list) and "user" not in required_fields:
|
||||
required_fields.append("user")
|
||||
else:
|
||||
required_fields = schema.get("required")
|
||||
if isinstance(required_fields, list) and "user" in required_fields:
|
||||
required_fields.remove("user")
|
||||
if required_fields == []:
|
||||
schema.pop("required", None)
|
||||
@ -19,6 +19,17 @@ register_response_schema_models(service_api_ns, ProviderWithModelsListResponse)
|
||||
|
||||
@service_api_ns.route("/workspaces/current/models/model-types/<string:model_type>")
|
||||
class ModelProviderAvailableModelApi(Resource):
|
||||
@service_api_ns.doc(
|
||||
summary="Get Available Models",
|
||||
description=(
|
||||
"Retrieve the list of available models by type. Primarily used to query `text-embedding` and "
|
||||
"`rerank` models for knowledge base configuration."
|
||||
),
|
||||
tags=["Models"],
|
||||
responses={
|
||||
200: "Available models for the specified type.",
|
||||
},
|
||||
)
|
||||
@service_api_ns.doc("get_available_models")
|
||||
@service_api_ns.doc(description="Get available models by model type")
|
||||
@service_api_ns.doc(params={"model_type": "Type of model to retrieve"})
|
||||
|
||||
@ -4,16 +4,23 @@ import time
|
||||
from collections.abc import Callable
|
||||
from enum import StrEnum, auto
|
||||
from functools import wraps
|
||||
from typing import cast, overload
|
||||
from typing import Protocol, cast, overload
|
||||
|
||||
from flask import current_app, request
|
||||
from flask_login import user_logged_in
|
||||
from flask_restx import Resource
|
||||
from flask_restx.utils import merge
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.service_api.schema import (
|
||||
USER_FETCH_FROM_ATTR,
|
||||
USER_FORM_PARAM,
|
||||
USER_QUERY_PARAM,
|
||||
USER_REQUIRED_ATTR,
|
||||
)
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
@ -28,6 +35,12 @@ from services.feature_service import FeatureService
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _RestxDocumentedView(Protocol):
|
||||
"""Callable view object carrying Flask-RESTX documentation metadata."""
|
||||
|
||||
__apidoc__: dict[str, object]
|
||||
|
||||
|
||||
class WhereisUserArg(StrEnum):
|
||||
"""
|
||||
Enum for whereis_user_arg.
|
||||
@ -43,6 +56,35 @@ class FetchUserArg(BaseModel):
|
||||
required: bool = False
|
||||
|
||||
|
||||
APP_TOKEN_FORBIDDEN_RESPONSE = {
|
||||
403: "Forbidden - token scope, app, dataset, or workspace access denied",
|
||||
}
|
||||
|
||||
DATASET_TOKEN_AUTH_RESPONSES = {
|
||||
401: "Unauthorized - invalid API token",
|
||||
403: "Forbidden - dataset API access or workspace access denied",
|
||||
}
|
||||
|
||||
|
||||
def _document_app_token_contract(view_func: Callable[..., object], fetch_user_arg: FetchUserArg | None) -> None:
|
||||
doc: dict[str, object] = {"responses": APP_TOKEN_FORBIDDEN_RESPONSE}
|
||||
if fetch_user_arg is not None:
|
||||
setattr(view_func, USER_FETCH_FROM_ATTR, fetch_user_arg.fetch_from.name)
|
||||
setattr(view_func, USER_REQUIRED_ATTR, fetch_user_arg.required)
|
||||
match fetch_user_arg.fetch_from:
|
||||
case WhereisUserArg.QUERY:
|
||||
doc["params"] = {"user": {**USER_QUERY_PARAM, "required": fetch_user_arg.required}}
|
||||
case WhereisUserArg.FORM:
|
||||
doc["params"] = {"user": {**USER_FORM_PARAM, "required": fetch_user_arg.required}}
|
||||
case WhereisUserArg.JSON:
|
||||
pass
|
||||
|
||||
cast(_RestxDocumentedView, view_func).__apidoc__ = cast(
|
||||
dict[str, object],
|
||||
merge(getattr(view_func, "__apidoc__", {}), doc),
|
||||
)
|
||||
|
||||
|
||||
@overload
|
||||
def validate_app_token[**P, R](view: Callable[P, R]) -> Callable[P, R]: ...
|
||||
|
||||
@ -126,6 +168,7 @@ def validate_app_token[**P, R](
|
||||
|
||||
return view_func(*args, **kwargs)
|
||||
|
||||
_document_app_token_contract(decorated_view, fetch_user_arg)
|
||||
return decorated_view
|
||||
|
||||
if view is None:
|
||||
@ -343,6 +386,8 @@ def validate_and_get_api_token(scope: str | None = None):
|
||||
|
||||
|
||||
class DatasetApiResource(Resource):
|
||||
__apidoc__ = {"responses": DATASET_TOKEN_AUTH_RESPONSES}
|
||||
|
||||
method_decorators = [validate_dataset_token]
|
||||
|
||||
def get_dataset(self, dataset_id: str, tenant_id: str) -> Dataset:
|
||||
|
||||
@ -118,7 +118,7 @@ class BaseAgentRunner(AppRunner):
|
||||
features = model_schema.features if model_schema and model_schema.features else []
|
||||
self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features
|
||||
self.files = application_generate_entity.files if ModelFeature.VISION in features else []
|
||||
self.query: str | None = ""
|
||||
self.query: str = ""
|
||||
self._current_thoughts: list[PromptMessage] = []
|
||||
|
||||
def _repack_app_generate_entity(
|
||||
|
||||
@ -72,13 +72,48 @@ def publish_text_answer(
|
||||
both the backend-produced answer and short-circuited answers (moderation /
|
||||
annotation reply) share the exact same persistence + SSE path.
|
||||
"""
|
||||
publish_text_delta(
|
||||
queue_manager=queue_manager,
|
||||
model_name=model_name,
|
||||
delta=answer,
|
||||
user_query=user_query,
|
||||
)
|
||||
publish_message_end(
|
||||
queue_manager=queue_manager,
|
||||
model_name=model_name,
|
||||
answer=answer,
|
||||
user_query=user_query,
|
||||
)
|
||||
|
||||
|
||||
def publish_text_delta(
|
||||
*,
|
||||
queue_manager: AppQueueManager,
|
||||
model_name: str,
|
||||
delta: str,
|
||||
user_query: str | None = None,
|
||||
) -> None:
|
||||
"""Publish one assistant text delta through the EasyUI chat pipeline."""
|
||||
if not delta:
|
||||
return
|
||||
prompt_messages = _prompt_messages_from_query(user_query)
|
||||
chunk = LLMResultChunk(
|
||||
model=model_name,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=answer)),
|
||||
delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=delta)),
|
||||
)
|
||||
queue_manager.publish(QueueLLMChunkEvent(chunk=chunk), PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
|
||||
def publish_message_end(
|
||||
*,
|
||||
queue_manager: AppQueueManager,
|
||||
model_name: str,
|
||||
answer: str,
|
||||
user_query: str | None = None,
|
||||
) -> None:
|
||||
"""Publish the terminal assistant result without emitting another delta."""
|
||||
prompt_messages = _prompt_messages_from_query(user_query)
|
||||
queue_manager.publish(
|
||||
QueueMessageEndEvent(
|
||||
llm_result=LLMResult(
|
||||
@ -151,7 +186,12 @@ class AgentAppRunner:
|
||||
)
|
||||
|
||||
create_response = self._agent_backend_client.create_run(runtime.request)
|
||||
terminal = self._consume_stream(create_response.run_id, queue_manager=queue_manager)
|
||||
terminal, streamed_answer = self._consume_stream(
|
||||
create_response.run_id,
|
||||
queue_manager=queue_manager,
|
||||
model_name=model_name,
|
||||
query=query,
|
||||
)
|
||||
|
||||
if isinstance(terminal, AgentBackendDeferredToolCallInternalEvent):
|
||||
# ENG-635: the agent asked a human. End this turn with the question and
|
||||
@ -175,7 +215,13 @@ class AgentAppRunner:
|
||||
raise AgentBackendError(str(error))
|
||||
|
||||
answer = self._extract_answer(terminal.output)
|
||||
self._publish_answer(queue_manager=queue_manager, model_name=model_name, answer=answer, query=query)
|
||||
self._publish_terminal_answer(
|
||||
queue_manager=queue_manager,
|
||||
model_name=model_name,
|
||||
answer=answer,
|
||||
query=query,
|
||||
streamed_answer=streamed_answer,
|
||||
)
|
||||
self._save_session(
|
||||
scope=scope,
|
||||
backend_run_id=terminal.run_id,
|
||||
@ -272,8 +318,16 @@ class AgentAppRunner:
|
||||
parts.append(args.markdown)
|
||||
return "\n\n".join(parts)
|
||||
|
||||
def _consume_stream(self, run_id: str, *, queue_manager: AppQueueManager):
|
||||
def _consume_stream(
|
||||
self,
|
||||
run_id: str,
|
||||
*,
|
||||
queue_manager: AppQueueManager,
|
||||
model_name: str,
|
||||
query: str | None,
|
||||
):
|
||||
terminal = None
|
||||
streamed_answer_parts: list[str] = []
|
||||
for public_event in self._agent_backend_client.stream_events(run_id):
|
||||
if queue_manager.is_stopped():
|
||||
self._cancel_run(run_id)
|
||||
@ -286,16 +340,23 @@ class AgentAppRunner:
|
||||
AgentBackendInternalEventType.RUN_STARTED,
|
||||
AgentBackendInternalEventType.STREAM_EVENT,
|
||||
):
|
||||
# Stream deltas are accumulated by the backend into the
|
||||
# terminal output; token-level forwarding is an S3 refinement.
|
||||
if isinstance(internal_event, AgentBackendStreamInternalEvent):
|
||||
text_delta = self._extract_stream_text_delta(internal_event)
|
||||
if text_delta:
|
||||
streamed_answer_parts.append(text_delta)
|
||||
publish_text_delta(
|
||||
queue_manager=queue_manager,
|
||||
model_name=model_name,
|
||||
delta=text_delta,
|
||||
user_query=query,
|
||||
)
|
||||
continue
|
||||
continue
|
||||
terminal = internal_event
|
||||
break
|
||||
if terminal is not None:
|
||||
break
|
||||
return terminal
|
||||
return terminal, "".join(streamed_answer_parts)
|
||||
|
||||
def _cancel_run(self, run_id: str) -> None:
|
||||
try:
|
||||
@ -310,6 +371,35 @@ class AgentAppRunner:
|
||||
# task pipeline streams the chunk over SSE and persists the message.
|
||||
publish_text_answer(queue_manager=queue_manager, model_name=model_name, answer=answer, user_query=query)
|
||||
|
||||
def _publish_terminal_answer(
|
||||
self,
|
||||
*,
|
||||
queue_manager: AppQueueManager,
|
||||
model_name: str,
|
||||
answer: str,
|
||||
query: str | None,
|
||||
streamed_answer: str,
|
||||
) -> None:
|
||||
"""Finish a successful streamed turn without duplicating the final text."""
|
||||
if not streamed_answer:
|
||||
self._publish_answer(queue_manager=queue_manager, model_name=model_name, answer=answer, query=query)
|
||||
return
|
||||
|
||||
if answer.startswith(streamed_answer):
|
||||
publish_text_delta(
|
||||
queue_manager=queue_manager,
|
||||
model_name=model_name,
|
||||
delta=answer[len(streamed_answer) :],
|
||||
user_query=query,
|
||||
)
|
||||
elif answer != streamed_answer:
|
||||
logger.warning(
|
||||
"Agent App streamed answer does not match terminal output; "
|
||||
"using terminal output for message persistence."
|
||||
)
|
||||
|
||||
publish_message_end(queue_manager=queue_manager, model_name=model_name, answer=answer, user_query=query)
|
||||
|
||||
def _save_session(
|
||||
self,
|
||||
*,
|
||||
@ -357,5 +447,27 @@ class AgentAppRunner:
|
||||
return json.dumps(output, ensure_ascii=False)
|
||||
return json.dumps(output, ensure_ascii=False)
|
||||
|
||||
@staticmethod
|
||||
def _extract_stream_text_delta(event: AgentBackendStreamInternalEvent) -> str | None:
|
||||
data = event.data
|
||||
if not isinstance(data, dict):
|
||||
return None
|
||||
|
||||
__all__ = ["AgentAppRunner", "publish_text_answer"]
|
||||
if data.get("event_kind") == "part_delta":
|
||||
delta = data.get("delta")
|
||||
if isinstance(delta, dict) and delta.get("part_delta_kind") == "text":
|
||||
content_delta = delta.get("content_delta")
|
||||
if isinstance(content_delta, str):
|
||||
return content_delta
|
||||
|
||||
if data.get("event_kind") == "part_start":
|
||||
part = data.get("part")
|
||||
if isinstance(part, dict) and part.get("part_kind") == "text":
|
||||
content = part.get("content")
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
|
||||
return None
|
||||
|
||||
|
||||
__all__ = ["AgentAppRunner", "publish_message_end", "publish_text_answer", "publish_text_delta"]
|
||||
|
||||
@ -231,22 +231,23 @@ class AppRunner:
|
||||
:param tenant_id: tenant id for multimodal output
|
||||
:return:
|
||||
"""
|
||||
if not stream and isinstance(invoke_result, LLMResult):
|
||||
self._handle_invoke_result_direct(
|
||||
invoke_result=invoke_result,
|
||||
queue_manager=queue_manager,
|
||||
)
|
||||
elif stream and isinstance(invoke_result, Generator):
|
||||
self._handle_invoke_result_stream(
|
||||
invoke_result=invoke_result,
|
||||
queue_manager=queue_manager,
|
||||
agent=agent,
|
||||
message_id=message_id,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}")
|
||||
match invoke_result:
|
||||
case LLMResult() if not stream:
|
||||
self._handle_invoke_result_direct(
|
||||
invoke_result=invoke_result,
|
||||
queue_manager=queue_manager,
|
||||
)
|
||||
case _ if stream and isinstance(invoke_result, Generator):
|
||||
self._handle_invoke_result_stream(
|
||||
invoke_result=invoke_result,
|
||||
queue_manager=queue_manager,
|
||||
agent=agent,
|
||||
message_id=message_id,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
case _:
|
||||
raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}")
|
||||
|
||||
def _handle_invoke_result_direct(
|
||||
self,
|
||||
|
||||
@ -882,7 +882,7 @@ class WorkflowResponseConverter:
|
||||
return files
|
||||
|
||||
@classmethod
|
||||
def _get_file_var_from_value(cls, value: Union[dict, list]) -> Mapping[str, Any] | None:
|
||||
def _get_file_var_from_value(cls, value: object) -> Mapping[str, Any] | None:
|
||||
"""
|
||||
Get file var from value
|
||||
:param value: variable value
|
||||
@ -891,10 +891,11 @@ class WorkflowResponseConverter:
|
||||
if not value:
|
||||
return None
|
||||
|
||||
if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
|
||||
return value
|
||||
elif isinstance(value, File):
|
||||
return value.to_dict()
|
||||
match value:
|
||||
case dict() if value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
|
||||
return value
|
||||
case File():
|
||||
return value.to_dict()
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@ -241,7 +241,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
|
||||
created_by: Mapping[str, object] = Field(default_factory=dict)
|
||||
created_at: int
|
||||
finished_at: int | None
|
||||
exceptions_count: int | None = 0
|
||||
exceptions_count: int = 0
|
||||
files: Sequence[Mapping[str, Any]] | None = []
|
||||
|
||||
event: StreamEvent = StreamEvent.WORKFLOW_FINISHED
|
||||
|
||||
@ -144,15 +144,16 @@ def extract_parent_trace_context_from_args(args: Mapping[str, Any]) -> dict[str,
|
||||
Returns an empty dict if the context is missing or incomplete.
|
||||
"""
|
||||
parent_trace_context = args.get("parent_trace_context")
|
||||
if isinstance(parent_trace_context, ParentTraceContext):
|
||||
context = parent_trace_context
|
||||
elif isinstance(parent_trace_context, Mapping):
|
||||
try:
|
||||
context = ParentTraceContext.model_validate(parent_trace_context)
|
||||
except ValidationError:
|
||||
match parent_trace_context:
|
||||
case ParentTraceContext():
|
||||
context = parent_trace_context
|
||||
case Mapping():
|
||||
try:
|
||||
context = ParentTraceContext.model_validate(parent_trace_context)
|
||||
except ValidationError:
|
||||
return {}
|
||||
case _:
|
||||
return {}
|
||||
else:
|
||||
return {}
|
||||
|
||||
if context.parent_node_execution_id is None:
|
||||
return {}
|
||||
|
||||
@ -116,20 +116,21 @@ def cast_parameter_value(typ: StrEnum, value: Any, /):
|
||||
return value if isinstance(value, str) else str(value)
|
||||
|
||||
case PluginParameterType.BOOLEAN:
|
||||
if value is None:
|
||||
return False
|
||||
elif isinstance(value, str):
|
||||
# Allowed YAML boolean value strings: https://yaml.org/type/bool.html
|
||||
# and also '0' for False and '1' for True
|
||||
match value.lower():
|
||||
case "true" | "yes" | "y" | "1":
|
||||
return True
|
||||
case "false" | "no" | "n" | "0":
|
||||
return False
|
||||
case _:
|
||||
return bool(value)
|
||||
else:
|
||||
return value if isinstance(value, bool) else bool(value)
|
||||
match value:
|
||||
case None:
|
||||
return False
|
||||
case str():
|
||||
# Allowed YAML boolean value strings: https://yaml.org/type/bool.html
|
||||
# and also '0' for False and '1' for True
|
||||
match value.lower():
|
||||
case "true" | "yes" | "y" | "1":
|
||||
return True
|
||||
case "false" | "no" | "n" | "0":
|
||||
return False
|
||||
case _:
|
||||
return bool(value)
|
||||
case _:
|
||||
return value if isinstance(value, bool) else bool(value)
|
||||
|
||||
case PluginParameterType.NUMBER:
|
||||
match value:
|
||||
|
||||
@ -71,8 +71,8 @@ class RequestInvokeLLM(BaseRequestInvokeModel):
|
||||
mode: str
|
||||
completion_params: dict[str, Any] = Field(default_factory=dict)
|
||||
prompt_messages: list[PromptMessage] = Field(default_factory=list)
|
||||
tools: list[PromptMessageTool] = Field(default_factory=list[PromptMessageTool])
|
||||
stop: list[str] = Field(default_factory=list[str])
|
||||
tools: list[PromptMessageTool] | None = Field(default_factory=list[PromptMessageTool])
|
||||
stop: list[str] | None = Field(default_factory=list[str])
|
||||
stream: bool = False
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
@ -20,6 +20,7 @@ from core.plugin.impl.exc import (
|
||||
PluginDaemonNotFoundError,
|
||||
PluginDaemonUnauthorizedError,
|
||||
PluginInvokeError,
|
||||
PluginLLMPollingUnsupportedError,
|
||||
PluginNotFoundError,
|
||||
PluginPermissionDeniedError,
|
||||
PluginUniqueIdentifierError,
|
||||
@ -370,6 +371,10 @@ class BasePluginClient:
|
||||
raise TriggerInvokeError(error_object.get("message"))
|
||||
case EventIgnoreError.__name__:
|
||||
raise EventIgnoreError(description=error_object.get("message"))
|
||||
# NOTE: current plugin sdk / plugin daemon does not raise exception with
|
||||
# type `PluginLLMPollingUnsupportedError`.
|
||||
case PluginLLMPollingUnsupportedError.__name__:
|
||||
raise PluginLLMPollingUnsupportedError(description=error_object.get("message"))
|
||||
case _:
|
||||
raise PluginInvokeError(description=message)
|
||||
case PluginDaemonInternalServerError.__name__:
|
||||
|
||||
@ -5,6 +5,13 @@ from pydantic import TypeAdapter
|
||||
|
||||
from extensions.ext_logging import get_request_id
|
||||
|
||||
# NOTE: Avoid renaming exception classes in this file, since
|
||||
# the `_handle_plugin_daemon_error` in api/core/plugin/impl/base.py
|
||||
# build exception instances based on the class name.
|
||||
#
|
||||
# Renaming of exception classes could result in incorrect exception
|
||||
# being raised.
|
||||
|
||||
|
||||
class PluginDaemonError(Exception):
|
||||
"""Base class for all plugin daemon errors."""
|
||||
@ -75,6 +82,10 @@ class PluginInvokeError(PluginDaemonClientSideError, ValueError):
|
||||
)
|
||||
|
||||
|
||||
class PluginLLMPollingUnsupportedError(PluginInvokeError):
|
||||
"""Plugin-backed LLM polling is unavailable for the requested model."""
|
||||
|
||||
|
||||
class PluginUniqueIdentifierError(PluginDaemonClientSideError):
|
||||
description: str = "Unique Identifier Error"
|
||||
|
||||
|
||||
@ -13,13 +13,17 @@ from core.plugin.entities.plugin_daemon import (
|
||||
PluginVoicesResponse,
|
||||
)
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
from graphon.model_runtime.entities.llm_entities import LLMResultChunk
|
||||
from core.plugin.impl.exc import PluginInvokeError, PluginLLMPollingUnsupportedError
|
||||
from graphon.model_runtime.entities.llm_entities import LLMPollingResult, LLMResultChunk
|
||||
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
||||
from graphon.model_runtime.entities.model_entities import AIModelEntity
|
||||
from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
|
||||
from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
|
||||
_POLLING_UNSUPPORTED_INVOKE_ERROR_TYPES = frozenset((NotImplementedError.__name__,))
|
||||
_POLLING_UNSUPPORTED_ERROR_MESSAGE = "does not support polling"
|
||||
|
||||
|
||||
class PluginModelClient(BasePluginClient):
|
||||
@staticmethod
|
||||
@ -197,6 +201,103 @@ class PluginModelClient(BasePluginClient):
|
||||
except PluginDaemonInnerError as e:
|
||||
raise ValueError(e.message + str(e.code))
|
||||
|
||||
def start_llm_polling(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str | None,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict[str, Any] | None = None,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
json_schema: dict[str, Any] | None = None,
|
||||
) -> LLMPollingResult:
|
||||
"""Start an LLM polling request for plugin-backed long-running jobs."""
|
||||
try:
|
||||
return self._request_with_plugin_daemon_response(
|
||||
method="POST",
|
||||
path=f"plugin/{tenant_id}/dispatch/model/polling/start",
|
||||
type_=LLMPollingResult,
|
||||
data=jsonable_encoder(
|
||||
self._dispatch_payload(
|
||||
user_id=user_id,
|
||||
data={
|
||||
"provider": provider,
|
||||
"model_type": ModelType.LLM.value,
|
||||
"model": model,
|
||||
"credentials": credentials,
|
||||
"prompt_messages": prompt_messages,
|
||||
"model_parameters": model_parameters,
|
||||
"tools": tools,
|
||||
"stop": stop,
|
||||
"stream": False,
|
||||
"json_schema": json_schema,
|
||||
},
|
||||
)
|
||||
),
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
except PluginInvokeError as error:
|
||||
self._raise_typed_polling_unsupported_error(error)
|
||||
raise
|
||||
|
||||
def check_llm_polling(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str | None,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
plugin_state: dict[str, Any],
|
||||
) -> LLMPollingResult:
|
||||
"""Check the latest state for a plugin-backed LLM polling job."""
|
||||
try:
|
||||
return self._request_with_plugin_daemon_response(
|
||||
method="POST",
|
||||
path=f"plugin/{tenant_id}/dispatch/model/polling/check",
|
||||
type_=LLMPollingResult,
|
||||
data=jsonable_encoder(
|
||||
self._dispatch_payload(
|
||||
user_id=user_id,
|
||||
data={
|
||||
"provider": provider,
|
||||
"model_type": ModelType.LLM.value,
|
||||
"model": model,
|
||||
"credentials": credentials,
|
||||
"plugin_state": plugin_state,
|
||||
},
|
||||
)
|
||||
),
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
except PluginInvokeError as error:
|
||||
self._raise_typed_polling_unsupported_error(error)
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def _raise_typed_polling_unsupported_error(error: PluginInvokeError) -> None:
|
||||
"""Convert plugin polling capability failures into a dedicated Dify exception."""
|
||||
if error.get_error_type() == PluginLLMPollingUnsupportedError.__name__:
|
||||
raise PluginLLMPollingUnsupportedError(description=error.description) from error
|
||||
|
||||
if (
|
||||
error.get_error_type() in _POLLING_UNSUPPORTED_INVOKE_ERROR_TYPES
|
||||
# This is ugly, we should not rely on error messages while checking
|
||||
# error types.
|
||||
and _POLLING_UNSUPPORTED_ERROR_MESSAGE in error.get_error_message().lower()
|
||||
):
|
||||
raise PluginLLMPollingUnsupportedError(description=error.description) from error
|
||||
|
||||
def get_llm_num_tokens(
|
||||
self,
|
||||
tenant_id: str,
|
||||
|
||||
@ -6,6 +6,7 @@ from collections.abc import Generator, Iterable, Sequence
|
||||
from typing import IO, Any, Literal, cast, overload, override
|
||||
|
||||
from pydantic import ValidationError
|
||||
from pydantic.json_schema import JsonValue
|
||||
from redis import RedisError
|
||||
|
||||
from configs import dify_config
|
||||
@ -17,6 +18,7 @@ from core.plugin.impl.model import PluginModelClient
|
||||
from core.plugin.plugin_service import PluginService
|
||||
from extensions.ext_redis import redis_client
|
||||
from graphon.model_runtime.entities.llm_entities import (
|
||||
LLMPollingResult,
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkWithStructuredOutput,
|
||||
@ -430,6 +432,54 @@ class PluginModelRuntime(ModelRuntime):
|
||||
tools=list(tools) if tools else None,
|
||||
)
|
||||
|
||||
def start_llm_polling(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
model_parameters: dict[str, Any],
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
tools: Sequence[PromptMessageTool] | None,
|
||||
stop: Sequence[str] | None,
|
||||
json_schema: dict[str, Any] | None,
|
||||
) -> LLMPollingResult:
|
||||
"""Start a plugin-side polling job for long-running LLM invocations."""
|
||||
plugin_id, provider_name = self._split_provider(provider)
|
||||
return self.client.start_llm_polling(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=self.user_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=list(prompt_messages),
|
||||
model_parameters=model_parameters,
|
||||
tools=list(tools) if tools else None,
|
||||
stop=list(stop) if stop else None,
|
||||
json_schema=json_schema,
|
||||
)
|
||||
|
||||
def check_llm_polling(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
plugin_state: dict[str, JsonValue],
|
||||
) -> LLMPollingResult:
|
||||
"""Check the latest plugin-side polling state for an LLM invocation."""
|
||||
plugin_id, provider_name = self._split_provider(provider)
|
||||
return self.client.check_llm_polling(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=self.user_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
plugin_state=plugin_state,
|
||||
)
|
||||
|
||||
@override
|
||||
def invoke_text_embedding(
|
||||
self,
|
||||
|
||||
@ -304,22 +304,23 @@ def _has_dify_refs_recursive(schema: SchemaType) -> bool:
|
||||
Returns:
|
||||
True if any Dify $ref is found, False otherwise
|
||||
"""
|
||||
if isinstance(schema, dict):
|
||||
# Check if this dict has a $ref field
|
||||
ref_uri = schema.get("$ref")
|
||||
if ref_uri and _is_dify_schema_ref(ref_uri):
|
||||
return True
|
||||
|
||||
# Check nested values
|
||||
for value in schema.values():
|
||||
if _has_dify_refs_recursive(value):
|
||||
match schema:
|
||||
case dict():
|
||||
# Check if this dict has a $ref field
|
||||
ref_uri = schema.get("$ref")
|
||||
if ref_uri and _is_dify_schema_ref(ref_uri):
|
||||
return True
|
||||
|
||||
elif isinstance(schema, list):
|
||||
# Check each item in the list
|
||||
for item in schema:
|
||||
if _has_dify_refs_recursive(item):
|
||||
return True
|
||||
# Check nested values
|
||||
for value in schema.values():
|
||||
if _has_dify_refs_recursive(value):
|
||||
return True
|
||||
|
||||
case list():
|
||||
# Check each item in the list
|
||||
for item in schema:
|
||||
if _has_dify_refs_recursive(item):
|
||||
return True
|
||||
|
||||
# Primitive types don't contain refs
|
||||
return False
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from typing import Any, override
|
||||
from datetime import datetime, tzinfo
|
||||
from typing import Any, cast, override
|
||||
|
||||
import pytz # type: ignore[import-untyped]
|
||||
|
||||
@ -35,17 +35,26 @@ class LocaltimeToTimestampTool(BuiltinTool):
|
||||
|
||||
yield self.create_text_message(f"{timestamp}")
|
||||
|
||||
# TODO: this method's type is messy
|
||||
@staticmethod
|
||||
def localtime_to_timestamp(localtime: str, time_format: str, local_tz=None) -> int | None:
|
||||
def localtime_to_timestamp(localtime: str, time_format: str, local_tz: str | tzinfo | None = None) -> int | None:
|
||||
try:
|
||||
local_time = datetime.strptime(localtime, time_format)
|
||||
if local_tz is None:
|
||||
localtime = local_time.astimezone() # type: ignore
|
||||
elif isinstance(local_tz, str):
|
||||
local_tz = pytz.timezone(local_tz)
|
||||
localtime = local_tz.localize(local_time) # type: ignore
|
||||
timestamp = int(localtime.timestamp()) # type: ignore
|
||||
converted_localtime: datetime
|
||||
match local_tz:
|
||||
case None:
|
||||
converted_localtime = local_time.astimezone()
|
||||
case str() as timezone_name:
|
||||
timezone = pytz.timezone(timezone_name)
|
||||
converted_localtime = timezone.localize(local_time)
|
||||
case tzinfo():
|
||||
localize = getattr(local_tz, "localize", None)
|
||||
if callable(localize):
|
||||
converted_localtime = cast(datetime, localize(local_time))
|
||||
else:
|
||||
converted_localtime = local_time.replace(tzinfo=local_tz)
|
||||
case _:
|
||||
raise ValueError("local_tz must be None, a timezone name, or a tzinfo instance")
|
||||
timestamp = int(converted_localtime.timestamp())
|
||||
return timestamp
|
||||
except Exception as e:
|
||||
raise ToolInvokeError(str(e))
|
||||
|
||||
@ -122,13 +122,14 @@ class MCPTool(Tool):
|
||||
|
||||
def _process_json_content(self, content_json: Any) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""Process JSON content based on its type."""
|
||||
if isinstance(content_json, dict):
|
||||
yield self.create_json_message(content_json)
|
||||
elif isinstance(content_json, list):
|
||||
yield from self._process_json_list(content_json)
|
||||
else:
|
||||
# For primitive types (str, int, bool, etc.), convert to string
|
||||
yield self.create_text_message(str(content_json))
|
||||
match content_json:
|
||||
case dict():
|
||||
yield self.create_json_message(content_json)
|
||||
case list():
|
||||
yield from self._process_json_list(content_json)
|
||||
case _:
|
||||
# For primitive types (str, int, bool, etc.), convert to string
|
||||
yield self.create_text_message(str(content_json))
|
||||
|
||||
def _process_json_list(self, json_list: list) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""Process a list of JSON items."""
|
||||
@ -222,16 +223,17 @@ class MCPTool(Tool):
|
||||
|
||||
# Recursively search through nested structures
|
||||
for value in payload.values():
|
||||
if isinstance(value, Mapping):
|
||||
found = cls._extract_usage_dict(value)
|
||||
if found is not None:
|
||||
return found
|
||||
elif isinstance(value, list) and not isinstance(value, (str, bytes, bytearray)):
|
||||
for item in value:
|
||||
if isinstance(item, Mapping):
|
||||
found = cls._extract_usage_dict(item)
|
||||
if found is not None:
|
||||
return found
|
||||
match value:
|
||||
case _ if isinstance(value, Mapping):
|
||||
found = cls._extract_usage_dict(value)
|
||||
if found is not None:
|
||||
return found
|
||||
case list() if not isinstance(value, (str, bytes, bytearray)):
|
||||
for item in value:
|
||||
if isinstance(item, Mapping):
|
||||
found = cls._extract_usage_dict(item)
|
||||
if found is not None:
|
||||
return found
|
||||
return None
|
||||
|
||||
@override
|
||||
|
||||
@ -196,16 +196,17 @@ class WorkflowTool(Tool):
|
||||
return usage_candidate
|
||||
|
||||
for value in payload.values():
|
||||
if isinstance(value, Mapping):
|
||||
found = cls._extract_usage_dict(value)
|
||||
if found is not None:
|
||||
return found
|
||||
elif isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)):
|
||||
for item in value:
|
||||
if isinstance(item, Mapping):
|
||||
found = cls._extract_usage_dict(item)
|
||||
if found is not None:
|
||||
return found
|
||||
match value:
|
||||
case _ if isinstance(value, Mapping):
|
||||
found = cls._extract_usage_dict(value)
|
||||
if found is not None:
|
||||
return found
|
||||
case _ if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)):
|
||||
for item in value:
|
||||
if isinstance(item, Mapping):
|
||||
found = cls._extract_usage_dict(item)
|
||||
if found is not None:
|
||||
return found
|
||||
return None
|
||||
|
||||
@override
|
||||
@ -393,24 +394,25 @@ class WorkflowTool(Tool):
|
||||
files: list[File] = []
|
||||
result = {}
|
||||
for key, value in outputs.items():
|
||||
if isinstance(value, list):
|
||||
for item in value:
|
||||
if isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY:
|
||||
item = self._update_file_mapping(item)
|
||||
file = build_from_mapping(
|
||||
mapping=item,
|
||||
tenant_id=str(self.runtime.tenant_id),
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
files.append(file)
|
||||
elif isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
|
||||
value = self._update_file_mapping(value)
|
||||
file = build_from_mapping(
|
||||
mapping=value,
|
||||
tenant_id=str(self.runtime.tenant_id),
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
files.append(file)
|
||||
match value:
|
||||
case list():
|
||||
for item in value:
|
||||
if isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY:
|
||||
item = self._update_file_mapping(item)
|
||||
file = build_from_mapping(
|
||||
mapping=item,
|
||||
tenant_id=str(self.runtime.tenant_id),
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
files.append(file)
|
||||
case dict() if value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
|
||||
value = self._update_file_mapping(value)
|
||||
file = build_from_mapping(
|
||||
mapping=value,
|
||||
tenant_id=str(self.runtime.tenant_id),
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
files.append(file)
|
||||
|
||||
result[key] = value
|
||||
|
||||
|
||||
@ -47,7 +47,7 @@ class EventParameter(BaseModel):
|
||||
template: PluginParameterTemplate | None = Field(default=None, description="The template of the parameter")
|
||||
scope: str | None = None
|
||||
required: bool = False
|
||||
multiple: bool | None = Field(
|
||||
multiple: bool = Field(
|
||||
default=False,
|
||||
description="Whether the parameter is multiple select, only valid for select or dynamic-select type",
|
||||
)
|
||||
|
||||
@ -26,6 +26,7 @@ from core.workflow.node_runtime import (
|
||||
DifyFileReferenceFactory,
|
||||
DifyHumanInputNodeRuntime,
|
||||
DifyPreparedLLM,
|
||||
DifyPreparedPollingLLM,
|
||||
DifyPromptMessageSerializer,
|
||||
DifyRetrieverAttachmentLoader,
|
||||
DifyToolFileManager,
|
||||
@ -531,7 +532,11 @@ class DifyNodeFactory(NodeFactory):
|
||||
node_init_kwargs: dict[str, object] = {
|
||||
"credentials_provider": self._llm_credentials_provider,
|
||||
"model_factory": self._llm_model_factory,
|
||||
"model_instance": DifyPreparedLLM(model_instance) if wrap_model_instance else model_instance,
|
||||
"model_instance": (
|
||||
self._wrap_model_instance_for_node(node_data=validated_node_data, model_instance=model_instance)
|
||||
if wrap_model_instance
|
||||
else model_instance
|
||||
),
|
||||
"memory": self._build_memory_for_llm_node(
|
||||
node_data=validated_node_data,
|
||||
model_instance=model_instance,
|
||||
@ -555,6 +560,23 @@ class DifyNodeFactory(NodeFactory):
|
||||
node_init_kwargs["default_query_selector"] = system_variable_selector(SystemVariableKey.QUERY)
|
||||
return node_init_kwargs
|
||||
|
||||
@staticmethod
|
||||
def _wrap_model_instance_for_node(
|
||||
*,
|
||||
node_data: LLMCompatibleNodeData,
|
||||
model_instance: ModelInstance,
|
||||
) -> DifyPreparedLLM:
|
||||
# Only graphon's LLM node consumes the polling protocol. Keep classifier
|
||||
# and extractor nodes on the existing wrapper even if the same model
|
||||
# advertises polling support.
|
||||
if node_data.type == BuiltinNodeTypes.LLM and DifyNodeFactory._supports_plugin_llm_polling(model_instance):
|
||||
return DifyPreparedPollingLLM(model_instance)
|
||||
return DifyPreparedLLM(model_instance)
|
||||
|
||||
@staticmethod
|
||||
def _supports_plugin_llm_polling(model_instance: ModelInstance) -> bool:
|
||||
return model_instance.get_model_schema().support_polling
|
||||
|
||||
def _build_retriever_attachment_loader(self, node_data: LLMNodeData) -> DifyRetrieverAttachmentLoader:
|
||||
return DifyRetrieverAttachmentLoader(
|
||||
file_reference_factory=self._file_reference_factory,
|
||||
|
||||
@ -4,6 +4,7 @@ from collections.abc import Callable, Generator, Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast, overload, override
|
||||
|
||||
from pydantic import JsonValue
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@ -38,6 +39,7 @@ from factories import file_factory
|
||||
from graphon.file import File, FileTransferMethod, FileType
|
||||
from graphon.model_runtime.entities import LLMMode
|
||||
from graphon.model_runtime.entities.llm_entities import (
|
||||
LLMPollingResult,
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkWithStructuredOutput,
|
||||
@ -54,6 +56,7 @@ from graphon.nodes.human_input.entities import (
|
||||
HumanInputNodeData,
|
||||
)
|
||||
from graphon.nodes.llm.runtime_protocols import (
|
||||
LLMPollingCapableProtocol,
|
||||
LLMProtocol,
|
||||
PromptMessageSerializerProtocol,
|
||||
RetrieverAttachmentLoaderProtocol,
|
||||
@ -278,6 +281,58 @@ class DifyPreparedLLM(LLMProtocol):
|
||||
return isinstance(error, OutputParserError)
|
||||
|
||||
|
||||
class DifyPreparedPollingLLM(DifyPreparedLLM, LLMPollingCapableProtocol):
|
||||
"""Prepared workflow LLM adapter that exposes Graphon's polling protocol."""
|
||||
|
||||
def __init__(self, model_instance: ModelInstance) -> None:
|
||||
from core.plugin.impl.model_runtime import PluginModelRuntime
|
||||
|
||||
super().__init__(model_instance)
|
||||
model_type_instance = model_instance.model_type_instance
|
||||
if not isinstance(model_type_instance, LargeLanguageModel):
|
||||
raise TypeError("Polling wrapper requires a large-language-model instance.")
|
||||
|
||||
plugin_model_runtime = model_type_instance.model_runtime
|
||||
if not isinstance(plugin_model_runtime, PluginModelRuntime):
|
||||
raise TypeError("Polling wrapper requires a plugin-backed model runtime.")
|
||||
|
||||
self._plugin_model_runtime = plugin_model_runtime
|
||||
|
||||
@override
|
||||
def start_llm_polling(
|
||||
self,
|
||||
*,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
model_parameters: Mapping[str, Any],
|
||||
tools: Sequence[PromptMessageTool] | None,
|
||||
stop: Sequence[str] | None,
|
||||
json_schema: Mapping[str, Any] | None,
|
||||
) -> LLMPollingResult:
|
||||
return self._plugin_model_runtime.start_llm_polling(
|
||||
provider=self.provider,
|
||||
model=self.model_name,
|
||||
credentials=self._model_instance.credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=dict(model_parameters),
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
json_schema=dict(json_schema) if json_schema is not None else None,
|
||||
)
|
||||
|
||||
@override
|
||||
def check_llm_polling(
|
||||
self,
|
||||
*,
|
||||
plugin_state: Mapping[str, JsonValue],
|
||||
) -> LLMPollingResult:
|
||||
return self._plugin_model_runtime.check_llm_polling(
|
||||
provider=self.provider,
|
||||
model=self.model_name,
|
||||
credentials=self._model_instance.credentials,
|
||||
plugin_state=dict(plugin_state),
|
||||
)
|
||||
|
||||
|
||||
class DifyPromptMessageSerializer(PromptMessageSerializerProtocol):
|
||||
@override
|
||||
def serialize(
|
||||
|
||||
@ -297,25 +297,26 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
for cond in conditions.conditions or []:
|
||||
value = cond.value
|
||||
resolved_value: str | Sequence[str] | int | float | None
|
||||
if isinstance(value, str):
|
||||
segment_group = variable_pool.convert_template(value)
|
||||
if len(segment_group.value) == 1:
|
||||
resolved_value = _normalize_metadata_filter_scalar(segment_group.value[0].to_object())
|
||||
else:
|
||||
resolved_value = segment_group.text
|
||||
elif isinstance(value, Sequence) and all(isinstance(v, str) for v in value):
|
||||
resolved_values: list[str] = []
|
||||
for v in value:
|
||||
segment_group = variable_pool.convert_template(v)
|
||||
match value:
|
||||
case str():
|
||||
segment_group = variable_pool.convert_template(value)
|
||||
if len(segment_group.value) == 1:
|
||||
resolved_values.append(
|
||||
_normalize_metadata_filter_sequence_item(segment_group.value[0].to_object())
|
||||
)
|
||||
resolved_value = _normalize_metadata_filter_scalar(segment_group.value[0].to_object())
|
||||
else:
|
||||
resolved_values.append(segment_group.text)
|
||||
resolved_value = resolved_values
|
||||
else:
|
||||
resolved_value = value
|
||||
resolved_value = segment_group.text
|
||||
case _ if isinstance(value, Sequence) and all(isinstance(v, str) for v in value):
|
||||
resolved_values: list[str] = []
|
||||
for v in value:
|
||||
segment_group = variable_pool.convert_template(v)
|
||||
if len(segment_group.value) == 1:
|
||||
resolved_values.append(
|
||||
_normalize_metadata_filter_sequence_item(segment_group.value[0].to_object())
|
||||
)
|
||||
else:
|
||||
resolved_values.append(segment_group.text)
|
||||
resolved_value = resolved_values
|
||||
case _:
|
||||
resolved_value = value
|
||||
resolved_conditions.append(
|
||||
Condition(
|
||||
name=cond.name,
|
||||
|
||||
@ -167,6 +167,12 @@ def _patch_union_schema_markdown(markdown: str, spec_path: Path) -> str:
|
||||
return markdown
|
||||
|
||||
|
||||
def _strip_trailing_line_whitespace(markdown: str) -> str:
|
||||
"""Remove converter-emitted trailing whitespace without changing line structure."""
|
||||
|
||||
return "\n".join(line.rstrip(" \t") for line in markdown.split("\n"))
|
||||
|
||||
|
||||
def _convert_spec_to_markdown(spec_path: Path, markdown_path: Path) -> None:
|
||||
markdown_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with tempfile.TemporaryDirectory(prefix=f"{markdown_path.stem}-") as temp_dir:
|
||||
@ -201,6 +207,7 @@ def _convert_spec_to_markdown(spec_path: Path, markdown_path: Path) -> None:
|
||||
temp_markdown_path.read_text(encoding="utf-8"),
|
||||
spec_path,
|
||||
)
|
||||
converted_markdown = _strip_trailing_line_whitespace(converted_markdown)
|
||||
if not converted_markdown.strip():
|
||||
raise RuntimeError(f"swagger-markdown wrote an empty document for {markdown_path}")
|
||||
|
||||
|
||||
@ -104,11 +104,14 @@ def _field_signature(field: object) -> object:
|
||||
"description",
|
||||
"example",
|
||||
"max",
|
||||
"max_items",
|
||||
"min",
|
||||
"min_items",
|
||||
"nullable",
|
||||
"readonly",
|
||||
"required",
|
||||
"title",
|
||||
"unique",
|
||||
):
|
||||
if hasattr(field_instance, attr_name):
|
||||
signature[attr_name] = _jsonable_schema_value(getattr(field_instance, attr_name))
|
||||
@ -154,9 +157,9 @@ def create_spec_app() -> Flask:
|
||||
|
||||
apply_runtime_defaults()
|
||||
|
||||
from libs.flask_restx_compat import patch_swagger_for_inline_nested_dicts
|
||||
from libs.flask_restx_compat import install_swagger_compatibility
|
||||
|
||||
patch_swagger_for_inline_nested_dicts()
|
||||
install_swagger_compatibility()
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
|
||||
@ -354,11 +354,12 @@ def iter_method_nodes(method: MethodNode) -> Iterable[ast.AST]:
|
||||
|
||||
|
||||
def target_names(target: ast.AST) -> Iterable[str]:
|
||||
if isinstance(target, ast.Name):
|
||||
yield target.id
|
||||
elif isinstance(target, ast.Tuple | ast.List):
|
||||
for item in target.elts:
|
||||
yield from target_names(item)
|
||||
match target:
|
||||
case ast.Name():
|
||||
yield target.id
|
||||
case ast.Tuple() | ast.List():
|
||||
for item in target.elts:
|
||||
yield from target_names(item)
|
||||
|
||||
|
||||
def record_assignment(
|
||||
|
||||
@ -107,17 +107,58 @@ class AgentInviteOptionsResponse(ResponseModel):
|
||||
has_more: bool
|
||||
|
||||
|
||||
class AgentLogItemResponse(ResponseModel):
|
||||
class AgentLogSourceResponse(ResponseModel):
|
||||
id: str
|
||||
type: Literal["webapp", "workflow"]
|
||||
app_id: str
|
||||
app_name: str
|
||||
app_icon_type: str | None = None
|
||||
app_icon: str | None = None
|
||||
app_icon_background: str | None = None
|
||||
workflow_id: str | None = None
|
||||
workflow_version: str | None = None
|
||||
node_id: str | None = None
|
||||
|
||||
|
||||
class AgentLogSourceGroupResponse(ResponseModel):
|
||||
type: Literal["webapp", "workflow"]
|
||||
label: str
|
||||
sources: list[AgentLogSourceResponse] = Field(default_factory=list)
|
||||
|
||||
|
||||
class AgentLogSourceListResponse(ResponseModel):
|
||||
data: list[AgentLogSourceResponse]
|
||||
groups: list[AgentLogSourceGroupResponse]
|
||||
|
||||
|
||||
class AgentLogConversationItemResponse(ResponseModel):
|
||||
id: str
|
||||
conversation_id: str
|
||||
title: str | None = None
|
||||
end_user_id: str | None = None
|
||||
message_count: int
|
||||
user_rate: float | None = None
|
||||
operation_rate: float | None = None
|
||||
unread: bool
|
||||
source: AgentLogSourceResponse | None = None
|
||||
status: Literal["success", "failed", "paused"]
|
||||
created_at: int | None = None
|
||||
updated_at: int | None = None
|
||||
|
||||
@field_validator("created_at", "updated_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class AgentLogMessageItemResponse(ResponseModel):
|
||||
id: str
|
||||
message_id: str
|
||||
conversation_id: str
|
||||
conversation_name: str | None = None
|
||||
query: str
|
||||
answer: str
|
||||
status: str
|
||||
error: str | None = None
|
||||
source: str | None = None
|
||||
from_source: str | None = None
|
||||
from_end_user_id: str | None = None
|
||||
from_account_id: str | None = None
|
||||
message_tokens: int
|
||||
@ -136,7 +177,15 @@ class AgentLogItemResponse(ResponseModel):
|
||||
|
||||
|
||||
class AgentLogListResponse(ResponseModel):
|
||||
data: list[AgentLogItemResponse]
|
||||
data: list[AgentLogConversationItemResponse]
|
||||
page: int
|
||||
limit: int
|
||||
total: int
|
||||
has_more: bool
|
||||
|
||||
|
||||
class AgentLogMessageListResponse(ResponseModel):
|
||||
data: list[AgentLogMessageItemResponse]
|
||||
page: int
|
||||
limit: int
|
||||
total: int
|
||||
|
||||
@ -94,12 +94,13 @@ class RedisSubscriptionBase(Subscription):
|
||||
continue
|
||||
|
||||
channel_field = raw_message.get("channel")
|
||||
if isinstance(channel_field, bytes):
|
||||
channel_name = channel_field.decode("utf-8")
|
||||
elif isinstance(channel_field, str):
|
||||
channel_name = channel_field
|
||||
else:
|
||||
channel_name = str(channel_field)
|
||||
match channel_field:
|
||||
case bytes():
|
||||
channel_name = channel_field.decode("utf-8")
|
||||
case str():
|
||||
channel_name = channel_field
|
||||
case _:
|
||||
channel_name = str(channel_field)
|
||||
|
||||
if channel_name != self._topic:
|
||||
_logger.warning(
|
||||
|
||||
@ -88,22 +88,23 @@ class _RedisShardedSubscription(RedisSubscriptionBase):
|
||||
#
|
||||
# Since we have already filtered at the caller's site, we can safely set
|
||||
# `ignore_subscribe_messages=False`.
|
||||
if isinstance(self._client, RedisCluster):
|
||||
# NOTE(QuantumGhost): due to an issue in upstream code, calling `get_sharded_message` without
|
||||
# specifying the `target_node` argument would use busy-looping to wait
|
||||
# for incoming message, consuming excessive CPU quota.
|
||||
#
|
||||
# Here we specify the `target_node` to mitigate this problem.
|
||||
node = self._client.get_node_from_key(self._topic)
|
||||
return self._pubsub.get_sharded_message( # type: ignore[attr-defined]
|
||||
ignore_subscribe_messages=False,
|
||||
timeout=1,
|
||||
target_node=node,
|
||||
)
|
||||
elif isinstance(self._client, Redis):
|
||||
return self._pubsub.get_sharded_message(ignore_subscribe_messages=False, timeout=1) # type: ignore[attr-defined]
|
||||
else:
|
||||
raise AssertionError("client should be either Redis or RedisCluster.")
|
||||
match self._client:
|
||||
case RedisCluster():
|
||||
# NOTE(QuantumGhost): due to an issue in upstream code, calling `get_sharded_message` without
|
||||
# specifying the `target_node` argument would use busy-looping to wait
|
||||
# for incoming message, consuming excessive CPU quota.
|
||||
#
|
||||
# Here we specify the `target_node` to mitigate this problem.
|
||||
node = self._client.get_node_from_key(self._topic)
|
||||
return self._pubsub.get_sharded_message( # type: ignore[attr-defined]
|
||||
ignore_subscribe_messages=False,
|
||||
timeout=1,
|
||||
target_node=node,
|
||||
)
|
||||
case Redis():
|
||||
return self._pubsub.get_sharded_message(ignore_subscribe_messages=False, timeout=1) # type: ignore[attr-defined]
|
||||
case _:
|
||||
raise AssertionError("client should be either Redis or RedisCluster.")
|
||||
|
||||
@override
|
||||
def _get_message_type(self) -> str:
|
||||
|
||||
@ -138,10 +138,11 @@ class _StreamsSubscription(Subscription):
|
||||
if isinstance(fields, dict):
|
||||
data = fields.get(b"data")
|
||||
data_bytes: bytes | None = None
|
||||
if isinstance(data, str):
|
||||
data_bytes = data.encode()
|
||||
elif isinstance(data, (bytes, bytearray)):
|
||||
data_bytes = bytes(data)
|
||||
match data:
|
||||
case str():
|
||||
data_bytes = data.encode()
|
||||
case bytes() | bytearray():
|
||||
data_bytes = bytes(data)
|
||||
if data_bytes is not None:
|
||||
self._queue.put_nowait(data_bytes)
|
||||
last_id = entry_id
|
||||
|
||||
@ -9,7 +9,7 @@ from werkzeug.http import HTTP_STATUS_CODES
|
||||
|
||||
from configs import dify_config
|
||||
from core.errors.error import AppInvokeQuotaExceededError
|
||||
from libs.flask_restx_compat import patch_swagger_for_inline_nested_dicts
|
||||
from libs.flask_restx_compat import install_swagger_compatibility
|
||||
from libs.token import build_force_logout_cookie_headers
|
||||
|
||||
|
||||
@ -127,16 +127,22 @@ def register_external_error_handlers(api: Api, body_formatter: ErrorBodyFormatte
|
||||
class ExternalApi(Api):
|
||||
_authorizations = {
|
||||
"Bearer": {
|
||||
"type": "apiKey",
|
||||
"in": "header",
|
||||
"name": "Authorization",
|
||||
"description": "Type: Bearer {your-api-key}",
|
||||
"bearerFormat": "API_KEY",
|
||||
"description": "Use the Service API key as a Bearer token in the Authorization header.",
|
||||
"scheme": "bearer",
|
||||
"type": "http",
|
||||
}
|
||||
}
|
||||
|
||||
def __init__(self, app: Blueprint | Flask, *args, error_body_formatter: ErrorBodyFormatter | None = None, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
app: Blueprint | Flask,
|
||||
*args,
|
||||
error_body_formatter: ErrorBodyFormatter | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
self._error_body_formatter = error_body_formatter
|
||||
patch_swagger_for_inline_nested_dicts()
|
||||
install_swagger_compatibility()
|
||||
kwargs.setdefault("authorizations", self._authorizations)
|
||||
kwargs.setdefault("security", "Bearer")
|
||||
kwargs["add_specs"] = dify_config.SWAGGER_UI_ENABLED
|
||||
|
||||
@ -8,12 +8,14 @@ spec export fail or succeed in the same way.
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
from typing import TypeGuard
|
||||
from typing import TypeGuard, cast
|
||||
|
||||
from flask import current_app
|
||||
from flask_restx import fields
|
||||
from flask_restx import swagger as restx_swagger
|
||||
from flask_restx.model import Model, OrderedModel, instance
|
||||
from flask_restx.swagger import Swagger
|
||||
from flask_restx.utils import not_none
|
||||
|
||||
|
||||
def _is_inline_field_map(value: object) -> TypeGuard[dict[object, object]]:
|
||||
@ -98,20 +100,28 @@ def _inline_model_name(nested_fields: dict[object, object]) -> str:
|
||||
return f"_AnonymousInlineModel_{digest}"
|
||||
|
||||
|
||||
def patch_swagger_for_inline_nested_dicts() -> None:
|
||||
"""Allow OpenAPI generation to handle legacy inline Flask-RESTX field dicts.
|
||||
def install_swagger_compatibility() -> None:
|
||||
"""Install Dify's Flask-RESTX OpenAPI compatibility hooks.
|
||||
|
||||
Some existing controllers use raw field mappings in `fields.Nested({...})`
|
||||
or directly in `@namespace.response(...)`. Runtime marshalling accepts that,
|
||||
but Flask-RESTX registration expects a named model. Convert those
|
||||
anonymous mappings into temporary named models during docs generation.
|
||||
|
||||
Flask-RESTX also drops parameter descriptions from generated schemas and
|
||||
does not expose the Werkzeug `uuid` route converter as `format: uuid`.
|
||||
"""
|
||||
|
||||
if getattr(Swagger, "_dify_inline_nested_dict_patch", False):
|
||||
if getattr(Swagger, "_dify_swagger_compatibility_installed", False):
|
||||
return
|
||||
|
||||
original_register_model = Swagger.register_model
|
||||
original_register_field = Swagger.register_field
|
||||
original_extract_path_params = restx_swagger.extract_path_params
|
||||
original_schema_from_parameter = Swagger.schema_from_parameter
|
||||
original_description_for = Swagger.description_for
|
||||
original_serialize_operation = Swagger.serialize_operation
|
||||
original_parameters_and_request_body_for = Swagger.parameters_and_request_body_for
|
||||
original_as_dict = Swagger.as_dict
|
||||
|
||||
def get_or_create_inline_model(self: Swagger, nested_fields: dict[object, object]) -> object:
|
||||
@ -134,6 +144,65 @@ def patch_swagger_for_inline_nested_dicts() -> None:
|
||||
|
||||
original_register_field(self, field)
|
||||
|
||||
def schema_from_parameter_with_description(self: Swagger, param: dict[str, object]) -> dict[str, object]:
|
||||
schema = cast(dict[str, object], original_schema_from_parameter(self, param))
|
||||
description = param.get("description")
|
||||
if isinstance(description, str):
|
||||
schema["description"] = description
|
||||
return schema
|
||||
|
||||
def extract_path_params_with_uuid_format(path: str):
|
||||
params = original_extract_path_params(path)
|
||||
for converter, _arguments, variable in restx_swagger.parse_rule(path):
|
||||
if converter == "uuid" and variable in params:
|
||||
params[variable]["format"] = "uuid"
|
||||
return params
|
||||
|
||||
def description_for_with_explicit_summary(self: Swagger, doc: dict[str, object], method: str):
|
||||
method_doc = doc.get(method)
|
||||
if (
|
||||
isinstance(method_doc, dict)
|
||||
and isinstance(method_doc.get("summary"), str)
|
||||
and isinstance(method_doc.get("description"), str)
|
||||
):
|
||||
return method_doc["description"]
|
||||
return original_description_for(self, doc, method)
|
||||
|
||||
def serialize_operation_with_explicit_summary_tags(
|
||||
self: Swagger, doc: dict[str, object], method: str, inherited_request_body=None
|
||||
):
|
||||
operation = original_serialize_operation(self, doc, method, inherited_request_body)
|
||||
method_doc = doc.get(method)
|
||||
if not isinstance(method_doc, dict):
|
||||
return operation
|
||||
|
||||
summary = method_doc.get("summary")
|
||||
if isinstance(summary, str):
|
||||
operation["summary"] = summary
|
||||
|
||||
tags = method_doc.get("tags")
|
||||
if isinstance(tags, list) and all(isinstance(tag, str) for tag in tags):
|
||||
operation["tags"] = tags
|
||||
|
||||
return operation
|
||||
|
||||
def serialize_resource_with_explicit_operation_tags(self: Swagger, ns, resource, url, route_doc=None, **kwargs):
|
||||
doc = self.extract_resource_doc(resource, url, route_doc=route_doc)
|
||||
if doc is False:
|
||||
return None
|
||||
|
||||
path_params, path_request_body = original_parameters_and_request_body_for(self, doc)
|
||||
path: dict[str, object] = {"parameters": path_params or None}
|
||||
methods = [method.lower() for method in resource.methods or []]
|
||||
requested_methods = [method.lower() for method in kwargs.get("methods", [])]
|
||||
for method in methods:
|
||||
if doc[method] is False or requested_methods and method not in requested_methods:
|
||||
continue
|
||||
operation = self.serialize_operation(doc, method, path_request_body)
|
||||
operation.setdefault("tags", [ns.name])
|
||||
path[method] = operation
|
||||
return not_none(path)
|
||||
|
||||
def as_dict_with_inline_dict_support(self: Swagger):
|
||||
# Temporary set RESTX_INCLUDE_ALL_MODELS = false to prevent "length changed while iterating" error
|
||||
include_all_models = current_app.config.get("RESTX_INCLUDE_ALL_MODELS", False)
|
||||
@ -145,5 +214,10 @@ def patch_swagger_for_inline_nested_dicts() -> None:
|
||||
|
||||
Swagger.register_model = register_model_with_inline_dict_support
|
||||
Swagger.register_field = register_field_with_inline_dict_support
|
||||
restx_swagger.extract_path_params = extract_path_params_with_uuid_format
|
||||
Swagger.schema_from_parameter = schema_from_parameter_with_description
|
||||
Swagger.description_for = description_for_with_explicit_summary
|
||||
Swagger.serialize_operation = serialize_operation_with_explicit_summary_tags
|
||||
Swagger.serialize_resource = serialize_resource_with_explicit_operation_tags
|
||||
Swagger.as_dict = as_dict_with_inline_dict_support
|
||||
Swagger._dify_inline_nested_dict_patch = True
|
||||
Swagger._dify_swagger_compatibility_installed = True
|
||||
|
||||
@ -1174,34 +1174,32 @@ class Conversation(Base):
|
||||
|
||||
# Convert file mapping to File object
|
||||
for key, value in inputs.items():
|
||||
if (
|
||||
isinstance(value, dict)
|
||||
and cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY
|
||||
):
|
||||
value_dict = cast(dict[str, Any], value)
|
||||
inputs[key] = build_file_from_input_mapping(
|
||||
file_mapping=value_dict,
|
||||
tenant_resolver=tenant_resolver,
|
||||
)
|
||||
elif isinstance(value, list):
|
||||
value_list = value
|
||||
if all(
|
||||
isinstance(item, dict)
|
||||
and cast(dict[str, Any], item).get("dify_model_identity") == FILE_MODEL_IDENTITY
|
||||
for item in value_list
|
||||
):
|
||||
file_list: list[File] = []
|
||||
for item in value_list:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
item_dict = cast(dict[str, Any], item)
|
||||
file_list.append(
|
||||
build_file_from_input_mapping(
|
||||
file_mapping=item_dict,
|
||||
tenant_resolver=tenant_resolver,
|
||||
match value:
|
||||
case dict() if cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY:
|
||||
value_dict = cast(dict[str, Any], value)
|
||||
inputs[key] = build_file_from_input_mapping(
|
||||
file_mapping=value_dict,
|
||||
tenant_resolver=tenant_resolver,
|
||||
)
|
||||
case list():
|
||||
value_list = value
|
||||
if all(
|
||||
isinstance(item, dict)
|
||||
and cast(dict[str, Any], item).get("dify_model_identity") == FILE_MODEL_IDENTITY
|
||||
for item in value_list
|
||||
):
|
||||
file_list: list[File] = []
|
||||
for item in value_list:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
item_dict = cast(dict[str, Any], item)
|
||||
file_list.append(
|
||||
build_file_from_input_mapping(
|
||||
file_mapping=item_dict,
|
||||
tenant_resolver=tenant_resolver,
|
||||
)
|
||||
)
|
||||
)
|
||||
inputs[key] = file_list
|
||||
inputs[key] = file_list
|
||||
|
||||
return inputs
|
||||
|
||||
@ -1516,46 +1514,45 @@ class Message(Base):
|
||||
owner_tenant_id=cast(str | None, getattr(self, "_owner_tenant_id", None)),
|
||||
)
|
||||
for key, value in inputs.items():
|
||||
if (
|
||||
isinstance(value, dict)
|
||||
and cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY
|
||||
):
|
||||
value_dict = cast(dict[str, Any], value)
|
||||
inputs[key] = build_file_from_input_mapping(
|
||||
file_mapping=value_dict,
|
||||
tenant_resolver=tenant_resolver,
|
||||
)
|
||||
elif isinstance(value, list):
|
||||
value_list = value
|
||||
if all(
|
||||
isinstance(item, dict)
|
||||
and cast(dict[str, Any], item).get("dify_model_identity") == FILE_MODEL_IDENTITY
|
||||
for item in value_list
|
||||
):
|
||||
file_list: list[File] = []
|
||||
for item in value_list:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
item_dict = cast(dict[str, Any], item)
|
||||
file_list.append(
|
||||
build_file_from_input_mapping(
|
||||
file_mapping=item_dict,
|
||||
tenant_resolver=tenant_resolver,
|
||||
match value:
|
||||
case dict() if cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY:
|
||||
value_dict = cast(dict[str, Any], value)
|
||||
inputs[key] = build_file_from_input_mapping(
|
||||
file_mapping=value_dict,
|
||||
tenant_resolver=tenant_resolver,
|
||||
)
|
||||
case list():
|
||||
value_list = value
|
||||
if all(
|
||||
isinstance(item, dict)
|
||||
and cast(dict[str, Any], item).get("dify_model_identity") == FILE_MODEL_IDENTITY
|
||||
for item in value_list
|
||||
):
|
||||
file_list: list[File] = []
|
||||
for item in value_list:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
item_dict = cast(dict[str, Any], item)
|
||||
file_list.append(
|
||||
build_file_from_input_mapping(
|
||||
file_mapping=item_dict,
|
||||
tenant_resolver=tenant_resolver,
|
||||
)
|
||||
)
|
||||
)
|
||||
inputs[key] = file_list
|
||||
inputs[key] = file_list
|
||||
return inputs
|
||||
|
||||
@inputs.setter
|
||||
def inputs(self, value: Mapping[str, Any]):
|
||||
inputs = dict(value)
|
||||
for k, v in inputs.items():
|
||||
if isinstance(v, File):
|
||||
inputs[k] = v.model_dump()
|
||||
elif isinstance(v, list):
|
||||
v_list = v
|
||||
if all(isinstance(item, File) for item in v_list):
|
||||
inputs[k] = [item.model_dump() for item in v_list if isinstance(item, File)]
|
||||
match v:
|
||||
case File():
|
||||
inputs[k] = v.model_dump()
|
||||
case list():
|
||||
v_list = v
|
||||
if all(isinstance(item, File) for item in v_list):
|
||||
inputs[k] = [item.model_dump() for item in v_list if isinstance(item, File)]
|
||||
self._inputs = inputs
|
||||
|
||||
@property
|
||||
|
||||
@ -96,12 +96,13 @@ class JSONModelColumn[T: BaseModel](TypeDecorator[T | None]):
|
||||
def process_bind_param(self, value: T | dict[str, Any] | str | None, dialect: Dialect) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, self._model_class):
|
||||
model = value
|
||||
elif isinstance(value, str):
|
||||
model = self._model_class.model_validate_json(value)
|
||||
else:
|
||||
model = self._model_class.model_validate(value)
|
||||
match value:
|
||||
case _ if isinstance(value, self._model_class):
|
||||
model = value
|
||||
case str():
|
||||
model = self._model_class.model_validate_json(value)
|
||||
case _:
|
||||
model = self._model_class.model_validate(value)
|
||||
return json.dumps(model.model_dump(mode="json"), ensure_ascii=False, sort_keys=True, separators=(",", ":"))
|
||||
|
||||
@override
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -4,10 +4,9 @@ User-scoped programmatic API (bearer auth)
|
||||
## Version: 1.0
|
||||
|
||||
### Available authorizations
|
||||
#### Bearer (API Key Authentication)
|
||||
Type: Bearer {your-api-key}
|
||||
**Name:** Authorization
|
||||
**In:** header
|
||||
#### Bearer (HTTP, bearer)
|
||||
Use the Service API key as a Bearer token in the Authorization header.
|
||||
Bearer format: API_KEY
|
||||
|
||||
---
|
||||
## openapi
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -4,10 +4,9 @@ Public APIs for web applications including file uploads, chat interactions, and
|
||||
## Version: 1.0
|
||||
|
||||
### Available authorizations
|
||||
#### Bearer (API Key Authentication)
|
||||
Type: Bearer {your-api-key}
|
||||
**Name:** Authorization
|
||||
**In:** header
|
||||
#### Bearer (HTTP, bearer)
|
||||
Use the Service API key as a Bearer token in the Authorization header.
|
||||
Bearer format: API_KEY
|
||||
|
||||
---
|
||||
## web
|
||||
@ -140,7 +139,7 @@ Delete a specific conversation.
|
||||
|
||||
| Name | Located in | Description | Required | Schema |
|
||||
| ---- | ---------- | ----------- | -------- | ------ |
|
||||
| c_id | path | Conversation UUID | Yes | string |
|
||||
| c_id | path | Conversation UUID | Yes | string (uuid) |
|
||||
|
||||
#### Responses
|
||||
|
||||
@ -160,7 +159,7 @@ Rename a specific conversation with a custom name or auto-generate one.
|
||||
|
||||
| Name | Located in | Description | Required | Schema |
|
||||
| ---- | ---------- | ----------- | -------- | ------ |
|
||||
| c_id | path | Conversation UUID | Yes | string |
|
||||
| c_id | path | Conversation UUID | Yes | string (uuid) |
|
||||
| auto_generate | query | Auto-generate conversation name | No | boolean |
|
||||
| name | query | New conversation name | No | string |
|
||||
|
||||
@ -188,7 +187,7 @@ Pin a specific conversation to keep it at the top of the list.
|
||||
|
||||
| Name | Located in | Description | Required | Schema |
|
||||
| ---- | ---------- | ----------- | -------- | ------ |
|
||||
| c_id | path | Conversation UUID | Yes | string |
|
||||
| c_id | path | Conversation UUID | Yes | string (uuid) |
|
||||
|
||||
#### Responses
|
||||
|
||||
@ -208,7 +207,7 @@ Unpin a specific conversation to remove it from the top of the list.
|
||||
|
||||
| Name | Located in | Description | Required | Schema |
|
||||
| ---- | ---------- | ----------- | -------- | ------ |
|
||||
| c_id | path | Conversation UUID | Yes | string |
|
||||
| c_id | path | Conversation UUID | Yes | string (uuid) |
|
||||
|
||||
#### Responses
|
||||
|
||||
@ -494,7 +493,7 @@ Submit feedback (like/dislike) for a specific message.
|
||||
|
||||
| Name | Located in | Description | Required | Schema |
|
||||
| ---- | ---------- | ----------- | -------- | ------ |
|
||||
| message_id | path | Message UUID | Yes | string |
|
||||
| message_id | path | Message UUID | Yes | string (uuid) |
|
||||
| content | query | Feedback content | No | string |
|
||||
| rating | query | Feedback rating | No | string, <br>**Available values:** "dislike", "like" |
|
||||
|
||||
@ -523,7 +522,7 @@ Generate a new completion similar to an existing message (completion apps only).
|
||||
| Name | Located in | Description | Required | Schema |
|
||||
| ---- | ---------- | ----------- | -------- | ------ |
|
||||
| response_mode | query | Response mode | Yes | string, <br>**Available values:** "blocking", "streaming" |
|
||||
| message_id | path | | Yes | string |
|
||||
| message_id | path | | Yes | string (uuid) |
|
||||
|
||||
#### Responses
|
||||
|
||||
@ -543,7 +542,7 @@ Get suggested follow-up questions after a message (chat apps only).
|
||||
|
||||
| Name | Located in | Description | Required | Schema |
|
||||
| ---- | ---------- | ----------- | -------- | ------ |
|
||||
| message_id | path | Message UUID | Yes | string |
|
||||
| message_id | path | Message UUID | Yes | string (uuid) |
|
||||
|
||||
#### Responses
|
||||
|
||||
@ -731,7 +730,7 @@ Remove a message from saved messages.
|
||||
|
||||
| Name | Located in | Description | Required | Schema |
|
||||
| ---- | ---------- | ----------- | -------- | ------ |
|
||||
| message_id | path | Message UUID to delete | Yes | string |
|
||||
| message_id | path | Message UUID to delete | Yes | string (uuid) |
|
||||
|
||||
#### Responses
|
||||
|
||||
@ -1633,7 +1632,7 @@ Default configuration for form inputs.
|
||||
| Name | Type | Description | Required |
|
||||
| ---- | ---- | ----------- | -------- |
|
||||
| message_id | string | Message ID | No |
|
||||
| streaming | boolean | Enable streaming response | No |
|
||||
| streaming | boolean | Reserved for compatibility; TTS response streaming is determined by the provider output. | No |
|
||||
| text | string | Text to convert to audio | No |
|
||||
| voice | string | Voice to use for TTS | No |
|
||||
|
||||
|
||||
@ -1150,13 +1150,14 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
try:
|
||||
# Convert outputs to string based on type
|
||||
outputs_mime_type = OpenInferenceMimeTypeValues.TEXT.value
|
||||
if isinstance(trace_info.outputs, dict | list):
|
||||
outputs_str = safe_json_dumps(trace_info.outputs)
|
||||
outputs_mime_type = OpenInferenceMimeTypeValues.JSON.value
|
||||
elif isinstance(trace_info.outputs, str):
|
||||
outputs_str = trace_info.outputs
|
||||
else:
|
||||
outputs_str = str(trace_info.outputs)
|
||||
match trace_info.outputs:
|
||||
case dict() | list():
|
||||
outputs_str = safe_json_dumps(trace_info.outputs)
|
||||
outputs_mime_type = OpenInferenceMimeTypeValues.JSON.value
|
||||
case str():
|
||||
outputs_str = trace_info.outputs
|
||||
case _:
|
||||
outputs_str = str(trace_info.outputs)
|
||||
|
||||
llm_attributes: dict[str, Any] = {
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.LLM.value,
|
||||
@ -1553,25 +1554,26 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
set_attribute(f"{base_path}.{ToolCallAttributes.TOOL_CALL_ID}", call_id)
|
||||
|
||||
# Handle list of messages
|
||||
if isinstance(prompts, list):
|
||||
for message_index, message in enumerate(prompts):
|
||||
if not isinstance(message, dict):
|
||||
continue
|
||||
match prompts:
|
||||
case list():
|
||||
for message_index, message in enumerate(prompts):
|
||||
if not isinstance(message, dict):
|
||||
continue
|
||||
|
||||
role = message.get("role", "user")
|
||||
content = message.get("text") or message.get("content") or ""
|
||||
role = message.get("role", "user")
|
||||
content = message.get("text") or message.get("content") or ""
|
||||
|
||||
set_message_attribute(message_index, MessageAttributes.MESSAGE_ROLE, role)
|
||||
set_message_attribute(message_index, MessageAttributes.MESSAGE_CONTENT, content)
|
||||
set_message_attribute(message_index, MessageAttributes.MESSAGE_ROLE, role)
|
||||
set_message_attribute(message_index, MessageAttributes.MESSAGE_CONTENT, content)
|
||||
|
||||
tool_calls = message.get("tool_calls") or []
|
||||
if isinstance(tool_calls, list):
|
||||
for tool_index, tool_call in enumerate(tool_calls):
|
||||
set_tool_call_attributes(message_index, tool_index, tool_call)
|
||||
tool_calls = message.get("tool_calls") or []
|
||||
if isinstance(tool_calls, list):
|
||||
for tool_index, tool_call in enumerate(tool_calls):
|
||||
set_tool_call_attributes(message_index, tool_index, tool_call)
|
||||
|
||||
# Handle single dict or plain string prompt
|
||||
elif isinstance(prompts, (dict, str)):
|
||||
set_message_attribute(0, MessageAttributes.MESSAGE_CONTENT, prompts)
|
||||
set_message_attribute(0, MessageAttributes.MESSAGE_ROLE, "user")
|
||||
# Handle single dict or plain string prompt
|
||||
case dict() | str():
|
||||
set_message_attribute(0, MessageAttributes.MESSAGE_CONTENT, prompts)
|
||||
set_message_attribute(0, MessageAttributes.MESSAGE_ROLE, "user")
|
||||
|
||||
return attributes
|
||||
|
||||
@ -18,24 +18,25 @@ def validate_input_output(v, field_name):
|
||||
"""
|
||||
if v == {} or v is None:
|
||||
return v
|
||||
if isinstance(v, str):
|
||||
return [
|
||||
{
|
||||
"role": "assistant" if field_name == "output" else "user",
|
||||
"content": v,
|
||||
}
|
||||
]
|
||||
elif isinstance(v, list):
|
||||
if len(v) > 0 and isinstance(v[0], dict):
|
||||
v = replace_text_with_content(data=v)
|
||||
return v
|
||||
else:
|
||||
match v:
|
||||
case str():
|
||||
return [
|
||||
{
|
||||
"role": "assistant" if field_name == "output" else "user",
|
||||
"content": str(v),
|
||||
"content": v,
|
||||
}
|
||||
]
|
||||
case list():
|
||||
if len(v) > 0 and isinstance(v[0], dict):
|
||||
v = replace_text_with_content(data=v)
|
||||
return v
|
||||
else:
|
||||
return [
|
||||
{
|
||||
"role": "assistant" if field_name == "output" else "user",
|
||||
"content": str(v),
|
||||
}
|
||||
]
|
||||
|
||||
return v
|
||||
|
||||
|
||||
@ -64,40 +64,20 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
|
||||
"total_tokens": values.get("total_tokens", 0),
|
||||
}
|
||||
file_list = values.get("file_list", [])
|
||||
if isinstance(v, str):
|
||||
match field_name:
|
||||
case "inputs":
|
||||
return {
|
||||
"messages": {
|
||||
"role": "user",
|
||||
"content": v,
|
||||
"usage_metadata": usage_metadata,
|
||||
"file_list": file_list,
|
||||
},
|
||||
}
|
||||
case "outputs":
|
||||
return {
|
||||
"choices": {
|
||||
"role": "ai",
|
||||
"content": v,
|
||||
"usage_metadata": usage_metadata,
|
||||
"file_list": file_list,
|
||||
},
|
||||
}
|
||||
case _:
|
||||
pass
|
||||
elif isinstance(v, list):
|
||||
data = {}
|
||||
if len(v) > 0 and isinstance(v[0], dict):
|
||||
# rename text to content
|
||||
v = replace_text_with_content(data=v)
|
||||
match v:
|
||||
case str():
|
||||
match field_name:
|
||||
case "inputs":
|
||||
data = {
|
||||
"messages": v,
|
||||
return {
|
||||
"messages": {
|
||||
"role": "user",
|
||||
"content": v,
|
||||
"usage_metadata": usage_metadata,
|
||||
"file_list": file_list,
|
||||
},
|
||||
}
|
||||
case "outputs":
|
||||
data = {
|
||||
return {
|
||||
"choices": {
|
||||
"role": "ai",
|
||||
"content": v,
|
||||
@ -107,16 +87,37 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
|
||||
}
|
||||
case _:
|
||||
pass
|
||||
return data
|
||||
else:
|
||||
return {
|
||||
"choices": {
|
||||
"role": "ai" if field_name == "outputs" else "user",
|
||||
"content": str(v),
|
||||
"usage_metadata": usage_metadata,
|
||||
"file_list": file_list,
|
||||
},
|
||||
}
|
||||
case list():
|
||||
data = {}
|
||||
if len(v) > 0 and isinstance(v[0], dict):
|
||||
# rename text to content
|
||||
v = replace_text_with_content(data=v)
|
||||
match field_name:
|
||||
case "inputs":
|
||||
data = {
|
||||
"messages": v,
|
||||
}
|
||||
case "outputs":
|
||||
data = {
|
||||
"choices": {
|
||||
"role": "ai",
|
||||
"content": v,
|
||||
"usage_metadata": usage_metadata,
|
||||
"file_list": file_list,
|
||||
},
|
||||
}
|
||||
case _:
|
||||
pass
|
||||
return data
|
||||
else:
|
||||
return {
|
||||
"choices": {
|
||||
"role": "ai" if field_name == "outputs" else "user",
|
||||
"content": str(v),
|
||||
"usage_metadata": usage_metadata,
|
||||
"file_list": file_list,
|
||||
},
|
||||
}
|
||||
if isinstance(v, dict):
|
||||
v["usage_metadata"] = usage_metadata
|
||||
v["file_list"] = file_list
|
||||
|
||||
@ -40,41 +40,19 @@ class WeaveTraceModel(WeaveTokenUsage, WeaveMultiModel):
|
||||
"total_tokens": values.get("total_tokens", 0),
|
||||
}
|
||||
file_list = values.get("file_list", [])
|
||||
if isinstance(v, str):
|
||||
if field_name == "inputs":
|
||||
return {
|
||||
"messages": {
|
||||
"role": "user",
|
||||
"content": v,
|
||||
"usage_metadata": usage_metadata,
|
||||
"file_list": file_list,
|
||||
},
|
||||
}
|
||||
elif field_name == "outputs":
|
||||
return {
|
||||
"choices": {
|
||||
"role": "ai",
|
||||
"content": v,
|
||||
"usage_metadata": usage_metadata,
|
||||
"file_list": file_list,
|
||||
},
|
||||
}
|
||||
elif isinstance(v, list):
|
||||
data = {}
|
||||
if len(v) > 0 and isinstance(v[0], dict):
|
||||
# rename text to content
|
||||
v = replace_text_with_content(data=v)
|
||||
match v:
|
||||
case str():
|
||||
if field_name == "inputs":
|
||||
data = {
|
||||
"messages": [
|
||||
dict(msg, **{"usage_metadata": usage_metadata, "file_list": file_list}) # type: ignore
|
||||
for msg in v
|
||||
]
|
||||
if isinstance(v, list)
|
||||
else v,
|
||||
return {
|
||||
"messages": {
|
||||
"role": "user",
|
||||
"content": v,
|
||||
"usage_metadata": usage_metadata,
|
||||
"file_list": file_list,
|
||||
},
|
||||
}
|
||||
elif field_name == "outputs":
|
||||
data = {
|
||||
return {
|
||||
"choices": {
|
||||
"role": "ai",
|
||||
"content": v,
|
||||
@ -82,16 +60,39 @@ class WeaveTraceModel(WeaveTokenUsage, WeaveMultiModel):
|
||||
"file_list": file_list,
|
||||
},
|
||||
}
|
||||
return data
|
||||
else:
|
||||
return {
|
||||
"choices": {
|
||||
"role": "ai" if field_name == "outputs" else "user",
|
||||
"content": str(v),
|
||||
"usage_metadata": usage_metadata,
|
||||
"file_list": file_list,
|
||||
},
|
||||
}
|
||||
case list():
|
||||
data = {}
|
||||
if len(v) > 0 and isinstance(v[0], dict):
|
||||
# rename text to content
|
||||
v = replace_text_with_content(data=v)
|
||||
if field_name == "inputs":
|
||||
data = {
|
||||
"messages": [
|
||||
dict(msg, **{"usage_metadata": usage_metadata, "file_list": file_list}) # type: ignore
|
||||
for msg in v
|
||||
]
|
||||
if isinstance(v, list)
|
||||
else v,
|
||||
}
|
||||
elif field_name == "outputs":
|
||||
data = {
|
||||
"choices": {
|
||||
"role": "ai",
|
||||
"content": v,
|
||||
"usage_metadata": usage_metadata,
|
||||
"file_list": file_list,
|
||||
},
|
||||
}
|
||||
return data
|
||||
else:
|
||||
return {
|
||||
"choices": {
|
||||
"role": "ai" if field_name == "outputs" else "user",
|
||||
"content": str(v),
|
||||
"usage_metadata": usage_metadata,
|
||||
"file_list": file_list,
|
||||
},
|
||||
}
|
||||
if isinstance(v, dict):
|
||||
v["usage_metadata"] = usage_metadata
|
||||
v["file_list"] = file_list
|
||||
|
||||
@ -361,12 +361,13 @@ class ClickzettaVector(BaseVector):
|
||||
first_pass = json.loads(raw_metadata)
|
||||
|
||||
# Handle double-encoded JSON
|
||||
if isinstance(first_pass, str):
|
||||
metadata = parse_metadata_json(first_pass)
|
||||
elif isinstance(first_pass, dict):
|
||||
metadata = first_pass
|
||||
else:
|
||||
metadata = {}
|
||||
match first_pass:
|
||||
case str():
|
||||
metadata = parse_metadata_json(first_pass)
|
||||
case dict():
|
||||
metadata = first_pass
|
||||
case _:
|
||||
metadata = {}
|
||||
else:
|
||||
metadata = {}
|
||||
except (json.JSONDecodeError, ValueError, TypeError):
|
||||
@ -942,12 +943,13 @@ class ClickzettaVector(BaseVector):
|
||||
# First parse may yield a string (double-encoded JSON)
|
||||
first_pass = json.loads(row[2])
|
||||
|
||||
if isinstance(first_pass, str):
|
||||
metadata = parse_metadata_json(first_pass)
|
||||
elif isinstance(first_pass, dict):
|
||||
metadata = first_pass
|
||||
else:
|
||||
metadata = {}
|
||||
match first_pass:
|
||||
case str():
|
||||
metadata = parse_metadata_json(first_pass)
|
||||
case dict():
|
||||
metadata = first_pass
|
||||
case _:
|
||||
metadata = {}
|
||||
else:
|
||||
metadata = {}
|
||||
except (json.JSONDecodeError, ValueError, TypeError):
|
||||
|
||||
@ -98,14 +98,15 @@ def _extract_identifiers_and_literals(query) -> list[Any]:
|
||||
values: list[Any] = []
|
||||
if isinstance(query, psql.Composed):
|
||||
for part in query:
|
||||
if isinstance(part, psql.Identifier):
|
||||
values.append(("ident", part._obj[0] if part._obj else ""))
|
||||
elif isinstance(part, psql.Literal):
|
||||
values.append(("literal", part._obj))
|
||||
elif isinstance(part, psql.Composed):
|
||||
for sub in part:
|
||||
if isinstance(sub, psql.Literal):
|
||||
values.append(("literal", sub._obj))
|
||||
match part:
|
||||
case psql.Identifier():
|
||||
values.append(("ident", part._obj[0] if part._obj else ""))
|
||||
case psql.Literal():
|
||||
values.append(("literal", part._obj))
|
||||
case psql.Composed():
|
||||
for sub in part:
|
||||
if isinstance(sub, psql.Literal):
|
||||
values.append(("literal", sub._obj))
|
||||
return values
|
||||
|
||||
|
||||
|
||||
@ -44,7 +44,7 @@ dependencies = [
|
||||
"resend>=2.27.0,<3.0.0",
|
||||
# Emerging: newer and fast-moving, use compatible pins
|
||||
"fastopenapi[flask]==0.7.0",
|
||||
"graphon==0.5.1",
|
||||
"graphon==0.5.2",
|
||||
"httpx-sse==0.4.3",
|
||||
"json-repair==0.59.4",
|
||||
]
|
||||
|
||||
@ -957,21 +957,22 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
)
|
||||
pause_reason_models = []
|
||||
for reason in pause_reasons:
|
||||
if isinstance(reason, HumanInputRequired):
|
||||
# TODO(QuantumGhost): record node_id for `WorkflowPauseReason`
|
||||
pause_reason_model = WorkflowPauseReason(
|
||||
pause_id=pause_model.id,
|
||||
type_=reason.TYPE,
|
||||
form_id=reason.form_id,
|
||||
)
|
||||
elif isinstance(reason, SchedulingPause):
|
||||
pause_reason_model = WorkflowPauseReason(
|
||||
pause_id=pause_model.id,
|
||||
type_=reason.TYPE,
|
||||
message=reason.message,
|
||||
)
|
||||
else:
|
||||
raise AssertionError(f"unkown reason type: {type(reason)}")
|
||||
match reason:
|
||||
case HumanInputRequired():
|
||||
# TODO(QuantumGhost): record node_id for `WorkflowPauseReason`
|
||||
pause_reason_model = WorkflowPauseReason(
|
||||
pause_id=pause_model.id,
|
||||
type_=reason.TYPE,
|
||||
form_id=reason.form_id,
|
||||
)
|
||||
case SchedulingPause():
|
||||
pause_reason_model = WorkflowPauseReason(
|
||||
pause_id=pause_model.id,
|
||||
type_=reason.TYPE,
|
||||
message=reason.message,
|
||||
)
|
||||
case _:
|
||||
raise AssertionError(f"unknown reason type: {type(reason)}")
|
||||
|
||||
pause_reason_models.append(pause_reason_model)
|
||||
|
||||
|
||||
@ -5,7 +5,7 @@ import secrets
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from hashlib import sha256
|
||||
from typing import Any, TypedDict, cast
|
||||
from typing import Any, NotRequired, TypedDict, cast
|
||||
|
||||
from pydantic import BaseModel, TypeAdapter, ValidationError
|
||||
from sqlalchemy import Row, delete, func, select, update
|
||||
@ -18,6 +18,8 @@ class InvitationData(TypedDict):
|
||||
account_id: str
|
||||
email: str
|
||||
workspace_id: str
|
||||
role: NotRequired[str]
|
||||
requires_setup: NotRequired[bool]
|
||||
|
||||
|
||||
_invitation_adapter: TypeAdapter[InvitationData] = TypeAdapter(InvitationData)
|
||||
@ -1805,6 +1807,7 @@ class RegisterService:
|
||||
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email)
|
||||
|
||||
requires_setup = False
|
||||
if not account:
|
||||
TenantService.check_member_permission(tenant, inviter, None, "add")
|
||||
name = normalized_email.split("@")[0]
|
||||
@ -1819,6 +1822,7 @@ class RegisterService:
|
||||
# Create new tenant member for invited tenant
|
||||
TenantService.create_tenant_member(tenant, account, role)
|
||||
TenantService.switch_tenant(account, tenant.id)
|
||||
requires_setup = True
|
||||
else:
|
||||
TenantService.check_member_permission(tenant, inviter, account, "add")
|
||||
ta = db.session.scalar(
|
||||
@ -1826,15 +1830,16 @@ class RegisterService:
|
||||
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
|
||||
.limit(1)
|
||||
)
|
||||
requires_setup = account.status == AccountStatus.PENDING
|
||||
|
||||
if not ta:
|
||||
if not ta and account.status == AccountStatus.PENDING:
|
||||
TenantService.create_tenant_member(tenant, account, role)
|
||||
|
||||
# Support resend invitation email when the account is pending status
|
||||
if account.status != AccountStatus.PENDING:
|
||||
if ta and account.status != AccountStatus.PENDING:
|
||||
raise AccountAlreadyInTenantError("Account already in tenant.")
|
||||
|
||||
token = cls.generate_invite_token(tenant, account)
|
||||
token = cls.generate_invite_token(tenant, account, role, requires_setup=requires_setup)
|
||||
language = account.interface_language or "en-US"
|
||||
|
||||
# send email
|
||||
@ -1849,12 +1854,16 @@ class RegisterService:
|
||||
return token
|
||||
|
||||
@classmethod
|
||||
def generate_invite_token(cls, tenant: Tenant, account: Account) -> str:
|
||||
def generate_invite_token(
|
||||
cls, tenant: Tenant, account: Account, role: str = "normal", *, requires_setup: bool = False
|
||||
) -> str:
|
||||
token = str(uuid.uuid4())
|
||||
invitation_data = {
|
||||
"account_id": account.id,
|
||||
"email": account.email,
|
||||
"workspace_id": tenant.id,
|
||||
"role": str(role),
|
||||
"requires_setup": requires_setup,
|
||||
}
|
||||
expiry_hours = dify_config.INVITE_EXPIRY_HOURS
|
||||
redis_client.setex(cls._get_invitation_token_key(token), expiry_hours * 60 * 60, json.dumps(invitation_data))
|
||||
@ -1889,16 +1898,7 @@ class RegisterService:
|
||||
if not tenant:
|
||||
return None
|
||||
|
||||
tenant_account = db.session.execute(
|
||||
select(Account, TenantAccountJoin.role)
|
||||
.join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id)
|
||||
.where(Account.email == invitation_data["email"], TenantAccountJoin.tenant_id == tenant.id)
|
||||
).first()
|
||||
|
||||
if not tenant_account:
|
||||
return None
|
||||
|
||||
account = tenant_account[0]
|
||||
account = db.session.scalar(select(Account).where(Account.email == invitation_data["email"]).limit(1))
|
||||
if not account:
|
||||
return None
|
||||
|
||||
|
||||
@ -334,16 +334,17 @@ class ComposerConfigValidator:
|
||||
|
||||
@classmethod
|
||||
def _reject_plaintext_secrets(cls, value: Any, *, path: str) -> None:
|
||||
if isinstance(value, dict):
|
||||
for key, nested in value.items():
|
||||
normalized_key = key.lower().replace("-", "_")
|
||||
nested_path = f"{path}.{key}"
|
||||
if normalized_key in _PLAINTEXT_SECRET_KEYS and isinstance(nested, str) and nested:
|
||||
raise PlaintextSecretNotAllowedError(f"Plaintext secret is not allowed at {nested_path}")
|
||||
cls._reject_plaintext_secrets(nested, path=nested_path)
|
||||
elif isinstance(value, list):
|
||||
for index, nested in enumerate(value):
|
||||
cls._reject_plaintext_secrets(nested, path=f"{path}[{index}]")
|
||||
match value:
|
||||
case dict():
|
||||
for key, nested in value.items():
|
||||
normalized_key = key.lower().replace("-", "_")
|
||||
nested_path = f"{path}.{key}"
|
||||
if normalized_key in _PLAINTEXT_SECRET_KEYS and isinstance(nested, str) and nested:
|
||||
raise PlaintextSecretNotAllowedError(f"Plaintext secret is not allowed at {nested_path}")
|
||||
cls._reject_plaintext_secrets(nested, path=nested_path)
|
||||
case list():
|
||||
for index, nested in enumerate(value):
|
||||
cls._reject_plaintext_secrets(nested, path=f"{path}[{index}]")
|
||||
|
||||
@classmethod
|
||||
def _has_install_command(cls, entry: dict[str, Any]) -> bool:
|
||||
|
||||
@ -6,12 +6,15 @@ from decimal import Decimal
|
||||
from typing import Any
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import func, or_, select
|
||||
from sqlalchemy import and_, func, or_, select
|
||||
from sqlalchemy.orm import aliased
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from libs.helper import convert_datetime_to_date, escape_like_pattern, to_timestamp
|
||||
from models.agent import WorkflowAgentNodeBinding
|
||||
from models.enums import MessageStatus
|
||||
from models.model import App, Conversation, Message
|
||||
from models.workflow import WorkflowNodeExecutionModel, WorkflowRun
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@ -33,6 +36,16 @@ class AgentStatisticsQueryParams:
|
||||
timezone: str = "UTC"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AgentSourceFilter:
|
||||
kind: str
|
||||
app_id: str | None = None
|
||||
workflow_id: str | None = None
|
||||
workflow_version: str | None = None
|
||||
node_id: str | None = None
|
||||
invoke_from: InvokeFrom | None = None
|
||||
|
||||
|
||||
class AgentObservabilityService:
|
||||
_SOURCE_ALIASES: dict[str, InvokeFrom] = {
|
||||
"api": InvokeFrom.SERVICE_API,
|
||||
@ -66,6 +79,31 @@ class AgentObservabilityService:
|
||||
except KeyError as exc:
|
||||
raise ValueError(f"Unsupported source: {source}") from exc
|
||||
|
||||
@classmethod
|
||||
def resolve_source_filter(cls, source: str | None) -> AgentSourceFilter:
|
||||
if not source or source.strip().lower() == "all":
|
||||
return AgentSourceFilter(kind="all")
|
||||
normalized = source.strip()
|
||||
lowered = normalized.lower()
|
||||
if lowered == "webapp":
|
||||
return AgentSourceFilter(kind="webapp")
|
||||
if lowered.startswith("webapp:"):
|
||||
return AgentSourceFilter(kind="webapp", app_id=normalized.split(":", 1)[1] or None)
|
||||
if lowered == "workflow":
|
||||
return AgentSourceFilter(kind="workflow")
|
||||
if lowered.startswith("workflow:"):
|
||||
parts = normalized.split(":", 4)
|
||||
if len(parts) != 5 or not all(parts[1:]):
|
||||
raise ValueError(f"Unsupported source: {source}")
|
||||
return AgentSourceFilter(
|
||||
kind="workflow",
|
||||
app_id=parts[1],
|
||||
workflow_id=parts[2],
|
||||
workflow_version=parts[3],
|
||||
node_id=parts[4],
|
||||
)
|
||||
return AgentSourceFilter(kind="webapp", invoke_from=cls.resolve_source(source))
|
||||
|
||||
@staticmethod
|
||||
def _message_status(message: Message) -> str:
|
||||
if message.error or message.status == MessageStatus.ERROR:
|
||||
@ -104,19 +142,255 @@ class AgentObservabilityService:
|
||||
"updated_at": to_timestamp(message.updated_at),
|
||||
}
|
||||
|
||||
def list_logs(self, *, app: App, params: AgentLogQueryParams) -> dict[str, Any]:
|
||||
source = self.resolve_source(params.source)
|
||||
stmt = (
|
||||
select(Message, Conversation)
|
||||
.join(Conversation, Conversation.id == Message.conversation_id)
|
||||
.where(Message.app_id == app.id, Conversation.app_id == app.id)
|
||||
)
|
||||
stmt = self._apply_source_filter(stmt, source)
|
||||
def list_logs(self, *, app: App, agent_id: str, params: AgentLogQueryParams) -> dict[str, Any]:
|
||||
source_filter = self.resolve_source_filter(params.source)
|
||||
rows: list[dict[str, Any]] = []
|
||||
if source_filter.kind in {"all", "webapp"}:
|
||||
rows.extend(self._list_webapp_conversation_logs(app=app, params=params, source_filter=source_filter))
|
||||
if source_filter.kind in {"all", "workflow"}:
|
||||
rows.extend(
|
||||
self._list_workflow_conversation_logs(
|
||||
app=app,
|
||||
agent_id=agent_id,
|
||||
params=params,
|
||||
source_filter=source_filter,
|
||||
)
|
||||
)
|
||||
rows.sort(key=lambda row: (row["updated_at"] or 0, row["id"]), reverse=True)
|
||||
|
||||
if params.start:
|
||||
stmt = stmt.where(Message.created_at >= params.start)
|
||||
if params.end:
|
||||
stmt = stmt.where(Message.created_at < params.end)
|
||||
total = len(rows)
|
||||
start = (params.page - 1) * params.limit
|
||||
end = start + params.limit
|
||||
return {
|
||||
"data": rows[start:end],
|
||||
"page": params.page,
|
||||
"limit": params.limit,
|
||||
"total": total,
|
||||
"has_more": end < total,
|
||||
}
|
||||
|
||||
def list_log_messages(
|
||||
self, *, app: App, agent_id: str, conversation_id: str, params: AgentLogQueryParams
|
||||
) -> dict[str, Any]:
|
||||
source_filter = self.resolve_source_filter(params.source)
|
||||
rows: list[Message] = []
|
||||
if source_filter.kind in {"all", "webapp"}:
|
||||
rows.extend(
|
||||
self._list_webapp_messages(
|
||||
app=app,
|
||||
conversation_id=conversation_id,
|
||||
params=params,
|
||||
source_filter=source_filter,
|
||||
)
|
||||
)
|
||||
if source_filter.kind in {"all", "workflow"}:
|
||||
rows.extend(
|
||||
self._list_workflow_messages(
|
||||
app=app,
|
||||
agent_id=agent_id,
|
||||
conversation_id=conversation_id,
|
||||
params=params,
|
||||
source_filter=source_filter,
|
||||
)
|
||||
)
|
||||
|
||||
deduped = {message.id: message for message in rows}
|
||||
sorted_rows = sorted(deduped.values(), key=lambda message: (message.created_at, message.id), reverse=True)
|
||||
total = len(sorted_rows)
|
||||
start = (params.page - 1) * params.limit
|
||||
end = start + params.limit
|
||||
return {
|
||||
"data": [self.serialize_log_message(message) for message in sorted_rows[start:end]],
|
||||
"page": params.page,
|
||||
"limit": params.limit,
|
||||
"total": total,
|
||||
"has_more": end < total,
|
||||
}
|
||||
|
||||
def list_log_sources(self, *, app: App, agent_id: str) -> dict[str, Any]:
|
||||
webapp_source = self._serialize_webapp_source(app)
|
||||
workflow_sources = self._list_workflow_sources(app=app, agent_id=agent_id)
|
||||
return {
|
||||
"data": [webapp_source, *workflow_sources],
|
||||
"groups": [
|
||||
{"type": "webapp", "label": "WEBAPP", "sources": [webapp_source]},
|
||||
{"type": "workflow", "label": "WORKFLOW", "sources": workflow_sources},
|
||||
],
|
||||
}
|
||||
|
||||
def _list_webapp_conversation_logs(
|
||||
self, *, app: App, params: AgentLogQueryParams, source_filter: AgentSourceFilter
|
||||
) -> list[dict[str, Any]]:
|
||||
stmt = (
|
||||
select(
|
||||
Conversation,
|
||||
func.count(Message.id).label("message_count"),
|
||||
func.max(Message.created_at).label("created_at"),
|
||||
func.max(Message.updated_at).label("updated_at"),
|
||||
func.sum(sa.case((Message.status == MessageStatus.PAUSED, 1), else_=0)).label("paused_count"),
|
||||
func.sum(
|
||||
sa.case((or_(Message.error.is_not(None), Message.status == MessageStatus.ERROR), 1), else_=0)
|
||||
).label("failed_count"),
|
||||
)
|
||||
.join(Message, Message.conversation_id == Conversation.id)
|
||||
.where(Message.app_id == app.id, Conversation.app_id == app.id)
|
||||
.group_by(Conversation.id)
|
||||
)
|
||||
stmt = self._apply_observability_filters(stmt, params=params, source_filter=source_filter)
|
||||
rows = list(self._session.execute(stmt).all())
|
||||
return [
|
||||
self._serialize_conversation_log(
|
||||
conversation=row[0],
|
||||
message_count=row.message_count,
|
||||
paused_count=row.paused_count,
|
||||
failed_count=row.failed_count,
|
||||
source=self._serialize_webapp_source(app),
|
||||
created_at=row.created_at,
|
||||
updated_at=row.updated_at,
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
|
||||
def _list_workflow_conversation_logs(
|
||||
self, *, app: App, agent_id: str, params: AgentLogQueryParams, source_filter: AgentSourceFilter
|
||||
) -> list[dict[str, Any]]:
|
||||
workflow_app = aliased(App)
|
||||
stmt = (
|
||||
select(
|
||||
Conversation,
|
||||
workflow_app,
|
||||
WorkflowAgentNodeBinding.workflow_id,
|
||||
WorkflowAgentNodeBinding.workflow_version,
|
||||
WorkflowAgentNodeBinding.node_id,
|
||||
func.count(sa.distinct(Message.id)).label("message_count"),
|
||||
func.max(Message.created_at).label("created_at"),
|
||||
func.max(Message.updated_at).label("updated_at"),
|
||||
func.sum(sa.case((Message.status == MessageStatus.PAUSED, 1), else_=0)).label("paused_count"),
|
||||
func.sum(
|
||||
sa.case((or_(Message.error.is_not(None), Message.status == MessageStatus.ERROR), 1), else_=0)
|
||||
).label("failed_count"),
|
||||
)
|
||||
.select_from(Message)
|
||||
.join(Conversation, Conversation.id == Message.conversation_id)
|
||||
.join(WorkflowRun, WorkflowRun.id == Message.workflow_run_id)
|
||||
.join(
|
||||
WorkflowAgentNodeBinding,
|
||||
and_(
|
||||
WorkflowAgentNodeBinding.tenant_id == app.tenant_id,
|
||||
WorkflowAgentNodeBinding.agent_id == agent_id,
|
||||
WorkflowAgentNodeBinding.app_id == WorkflowRun.app_id,
|
||||
WorkflowAgentNodeBinding.workflow_id == WorkflowRun.workflow_id,
|
||||
WorkflowAgentNodeBinding.workflow_version == WorkflowRun.version,
|
||||
),
|
||||
)
|
||||
.join(
|
||||
WorkflowNodeExecutionModel,
|
||||
and_(
|
||||
WorkflowNodeExecutionModel.workflow_run_id == WorkflowRun.id,
|
||||
WorkflowNodeExecutionModel.node_id == WorkflowAgentNodeBinding.node_id,
|
||||
),
|
||||
)
|
||||
.join(workflow_app, workflow_app.id == WorkflowAgentNodeBinding.app_id)
|
||||
.where(Message.workflow_run_id.is_not(None), Conversation.app_id == WorkflowAgentNodeBinding.app_id)
|
||||
.group_by(
|
||||
Conversation.id,
|
||||
workflow_app.id,
|
||||
WorkflowAgentNodeBinding.workflow_id,
|
||||
WorkflowAgentNodeBinding.workflow_version,
|
||||
WorkflowAgentNodeBinding.node_id,
|
||||
)
|
||||
)
|
||||
stmt = self._apply_observability_filters(stmt, params=params, source_filter=source_filter)
|
||||
stmt = self._apply_workflow_source_filter(stmt, source_filter)
|
||||
rows = list(self._session.execute(stmt).all())
|
||||
return [
|
||||
self._serialize_conversation_log(
|
||||
conversation=row[0],
|
||||
message_count=row.message_count,
|
||||
paused_count=row.paused_count,
|
||||
failed_count=row.failed_count,
|
||||
source=self._serialize_workflow_source(
|
||||
app=row[1],
|
||||
workflow_id=row.workflow_id,
|
||||
workflow_version=row.workflow_version,
|
||||
node_id=row.node_id,
|
||||
),
|
||||
created_at=row.created_at,
|
||||
updated_at=row.updated_at,
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
|
||||
def _list_webapp_messages(
|
||||
self, *, app: App, conversation_id: str, params: AgentLogQueryParams, source_filter: AgentSourceFilter
|
||||
) -> list[Message]:
|
||||
stmt = select(Message).where(Message.app_id == app.id, Message.conversation_id == conversation_id)
|
||||
stmt = self._apply_message_filters(stmt, params=params, source_filter=source_filter)
|
||||
return list(self._session.scalars(stmt.order_by(Message.created_at.desc(), Message.id.desc())).all())
|
||||
|
||||
def _list_workflow_messages(
|
||||
self,
|
||||
*,
|
||||
app: App,
|
||||
agent_id: str,
|
||||
conversation_id: str,
|
||||
params: AgentLogQueryParams,
|
||||
source_filter: AgentSourceFilter,
|
||||
) -> list[Message]:
|
||||
stmt = (
|
||||
select(Message)
|
||||
.join(WorkflowRun, WorkflowRun.id == Message.workflow_run_id)
|
||||
.join(
|
||||
WorkflowAgentNodeBinding,
|
||||
and_(
|
||||
WorkflowAgentNodeBinding.tenant_id == app.tenant_id,
|
||||
WorkflowAgentNodeBinding.agent_id == agent_id,
|
||||
WorkflowAgentNodeBinding.app_id == WorkflowRun.app_id,
|
||||
WorkflowAgentNodeBinding.workflow_id == WorkflowRun.workflow_id,
|
||||
WorkflowAgentNodeBinding.workflow_version == WorkflowRun.version,
|
||||
),
|
||||
)
|
||||
.join(
|
||||
WorkflowNodeExecutionModel,
|
||||
and_(
|
||||
WorkflowNodeExecutionModel.workflow_run_id == WorkflowRun.id,
|
||||
WorkflowNodeExecutionModel.node_id == WorkflowAgentNodeBinding.node_id,
|
||||
),
|
||||
)
|
||||
.where(Message.conversation_id == conversation_id)
|
||||
)
|
||||
stmt = self._apply_message_filters(stmt, params=params, source_filter=source_filter)
|
||||
stmt = self._apply_workflow_source_filter(stmt, source_filter)
|
||||
return list(self._session.scalars(stmt.order_by(Message.created_at.desc(), Message.id.desc())).all())
|
||||
|
||||
def _list_workflow_sources(self, *, app: App, agent_id: str) -> list[dict[str, Any]]:
|
||||
workflow_app = aliased(App)
|
||||
stmt = (
|
||||
select(
|
||||
workflow_app,
|
||||
WorkflowAgentNodeBinding.workflow_id,
|
||||
WorkflowAgentNodeBinding.workflow_version,
|
||||
WorkflowAgentNodeBinding.node_id,
|
||||
)
|
||||
.join(workflow_app, workflow_app.id == WorkflowAgentNodeBinding.app_id)
|
||||
.where(WorkflowAgentNodeBinding.tenant_id == app.tenant_id, WorkflowAgentNodeBinding.agent_id == agent_id)
|
||||
.order_by(workflow_app.name.asc(), WorkflowAgentNodeBinding.node_id.asc())
|
||||
)
|
||||
rows = self._session.execute(stmt).all()
|
||||
deduped: dict[str, dict[str, Any]] = {}
|
||||
for row in rows:
|
||||
source = self._serialize_workflow_source(
|
||||
app=row[0],
|
||||
workflow_id=row.workflow_id,
|
||||
workflow_version=row.workflow_version,
|
||||
node_id=row.node_id,
|
||||
)
|
||||
deduped[source["id"]] = source
|
||||
return list(deduped.values())
|
||||
|
||||
@classmethod
|
||||
def _apply_observability_filters(cls, stmt, *, params: AgentLogQueryParams, source_filter: AgentSourceFilter):
|
||||
stmt = cls._apply_message_filters(stmt, params=params, source_filter=source_filter, include_keyword=False)
|
||||
if params.keyword:
|
||||
escaped_keyword = escape_like_pattern(params.keyword)
|
||||
pattern = f"%{escaped_keyword}%"
|
||||
@ -127,27 +401,41 @@ class AgentObservabilityService:
|
||||
Conversation.name.ilike(pattern, escape="\\"),
|
||||
)
|
||||
)
|
||||
if params.status:
|
||||
stmt = self._apply_status_filter(stmt, params.status)
|
||||
return stmt
|
||||
|
||||
total = self._session.scalar(select(func.count()).select_from(stmt.subquery())) or 0
|
||||
rows = list(
|
||||
self._session.execute(
|
||||
stmt.order_by(Message.created_at.desc(), Message.id.desc())
|
||||
.offset((params.page - 1) * params.limit)
|
||||
.limit(params.limit)
|
||||
).all()
|
||||
)
|
||||
data = []
|
||||
for message, conversation in rows:
|
||||
data.append(self.serialize_log_message(message, conversation))
|
||||
return {
|
||||
"data": data,
|
||||
"page": params.page,
|
||||
"limit": params.limit,
|
||||
"total": total,
|
||||
"has_more": params.page * params.limit < total,
|
||||
}
|
||||
@classmethod
|
||||
def _apply_message_filters(
|
||||
cls, stmt, *, params: AgentLogQueryParams, source_filter: AgentSourceFilter, include_keyword: bool = True
|
||||
):
|
||||
stmt = cls._apply_source_filter(stmt, source_filter.invoke_from)
|
||||
if params.start:
|
||||
stmt = stmt.where(Message.created_at >= params.start)
|
||||
if params.end:
|
||||
stmt = stmt.where(Message.created_at < params.end)
|
||||
if include_keyword and params.keyword:
|
||||
escaped_keyword = escape_like_pattern(params.keyword)
|
||||
pattern = f"%{escaped_keyword}%"
|
||||
stmt = stmt.where(
|
||||
or_(
|
||||
Message.query.ilike(pattern, escape="\\"),
|
||||
Message.answer.ilike(pattern, escape="\\"),
|
||||
)
|
||||
)
|
||||
if params.status:
|
||||
stmt = cls._apply_status_filter(stmt, params.status)
|
||||
return stmt
|
||||
|
||||
@staticmethod
|
||||
def _apply_workflow_source_filter(stmt, source_filter: AgentSourceFilter):
|
||||
if source_filter.app_id:
|
||||
stmt = stmt.where(WorkflowAgentNodeBinding.app_id == source_filter.app_id)
|
||||
if source_filter.workflow_id:
|
||||
stmt = stmt.where(WorkflowAgentNodeBinding.workflow_id == source_filter.workflow_id)
|
||||
if source_filter.workflow_version:
|
||||
stmt = stmt.where(WorkflowAgentNodeBinding.workflow_version == source_filter.workflow_version)
|
||||
if source_filter.node_id:
|
||||
stmt = stmt.where(WorkflowAgentNodeBinding.node_id == source_filter.node_id)
|
||||
return stmt
|
||||
|
||||
@classmethod
|
||||
def _apply_source_filter(cls, stmt, source: InvokeFrom | None):
|
||||
@ -166,22 +454,95 @@ class AgentObservabilityService:
|
||||
return stmt.where(Message.status == MessageStatus.PAUSED)
|
||||
raise ValueError(f"Unsupported status: {status}")
|
||||
|
||||
def get_statistics_summary(self, *, app: App, params: AgentStatisticsQueryParams) -> dict[str, Any]:
|
||||
source = self.resolve_source(params.source)
|
||||
rows = self._load_daily_statistics(app=app, params=params, source=source)
|
||||
@classmethod
|
||||
def _serialize_conversation_log(
|
||||
cls,
|
||||
*,
|
||||
conversation: Conversation,
|
||||
message_count: int,
|
||||
paused_count: int,
|
||||
failed_count: int,
|
||||
source: dict[str, Any],
|
||||
created_at: datetime | None,
|
||||
updated_at: datetime | None,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"id": conversation.id,
|
||||
"conversation_id": conversation.id,
|
||||
"title": conversation.name,
|
||||
"end_user_id": conversation.from_end_user_id,
|
||||
"message_count": int(message_count or 0),
|
||||
"user_rate": None,
|
||||
"operation_rate": None,
|
||||
"unread": conversation.read_at is None,
|
||||
"source": source,
|
||||
"status": cls._conversation_status(paused_count=paused_count, failed_count=failed_count),
|
||||
"created_at": to_timestamp(created_at or conversation.created_at),
|
||||
"updated_at": to_timestamp(updated_at or conversation.updated_at),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _conversation_status(*, paused_count: int, failed_count: int) -> str:
|
||||
if paused_count:
|
||||
return "paused"
|
||||
if failed_count:
|
||||
return "failed"
|
||||
return "success"
|
||||
|
||||
@staticmethod
|
||||
def _serialize_webapp_source(app: App) -> dict[str, Any]:
|
||||
icon_type = app.icon_type.value if app.icon_type else None
|
||||
return {
|
||||
"id": f"webapp:{app.id}",
|
||||
"type": "webapp",
|
||||
"app_id": app.id,
|
||||
"app_name": app.name,
|
||||
"app_icon_type": icon_type,
|
||||
"app_icon": app.icon,
|
||||
"app_icon_background": app.icon_background,
|
||||
"workflow_id": None,
|
||||
"workflow_version": None,
|
||||
"node_id": None,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _serialize_workflow_source(
|
||||
*,
|
||||
app: App,
|
||||
workflow_id: str,
|
||||
workflow_version: str,
|
||||
node_id: str,
|
||||
) -> dict[str, Any]:
|
||||
icon_type = app.icon_type.value if app.icon_type else None
|
||||
return {
|
||||
"id": f"workflow:{app.id}:{workflow_id}:{workflow_version}:{node_id}",
|
||||
"type": "workflow",
|
||||
"app_id": app.id,
|
||||
"app_name": app.name,
|
||||
"app_icon_type": icon_type,
|
||||
"app_icon": app.icon,
|
||||
"app_icon_background": app.icon_background,
|
||||
"workflow_id": workflow_id,
|
||||
"workflow_version": workflow_version,
|
||||
"node_id": node_id,
|
||||
}
|
||||
|
||||
def get_statistics_summary(self, *, app: App, agent_id: str, params: AgentStatisticsQueryParams) -> dict[str, Any]:
|
||||
source_filter = self.resolve_source_filter(params.source)
|
||||
rows = self._load_daily_statistics(app=app, agent_id=agent_id, params=params, source_filter=source_filter)
|
||||
charts = self._build_charts(rows)
|
||||
summary = self._build_summary(rows)
|
||||
return {
|
||||
"source": source.value if source else "all",
|
||||
"source": params.source or "all",
|
||||
"summary": summary,
|
||||
"charts": charts,
|
||||
}
|
||||
|
||||
def _load_daily_statistics(
|
||||
self, *, app: App, params: AgentStatisticsQueryParams, source: InvokeFrom | None
|
||||
self, *, app: App, agent_id: str, params: AgentStatisticsQueryParams, source_filter: AgentSourceFilter
|
||||
) -> list[dict[str, Any]]:
|
||||
converted_created_at = convert_datetime_to_date("m.created_at")
|
||||
source_condition = "AND m.invoke_from != :debugger" if source is None else "AND m.invoke_from = :source"
|
||||
message_scope = self._statistics_message_scope_sql(source_filter)
|
||||
sql_query = f"""SELECT
|
||||
{converted_created_at} AS date,
|
||||
COUNT(m.id) AS message_count,
|
||||
@ -197,15 +558,24 @@ FROM messages m
|
||||
LEFT JOIN message_feedbacks mf
|
||||
ON mf.message_id = m.id AND mf.rating = 'like'
|
||||
WHERE
|
||||
m.app_id = :app_id
|
||||
{source_condition}"""
|
||||
{message_scope}"""
|
||||
args: dict[str, Any] = {
|
||||
"tz": params.timezone,
|
||||
"app_id": app.id,
|
||||
"tenant_id": app.tenant_id,
|
||||
"agent_id": agent_id,
|
||||
"debugger": InvokeFrom.DEBUGGER,
|
||||
}
|
||||
if source is not None:
|
||||
args["source"] = source
|
||||
if source_filter.invoke_from is not None:
|
||||
args["source"] = source_filter.invoke_from
|
||||
if source_filter.app_id:
|
||||
args["source_app_id"] = source_filter.app_id
|
||||
if source_filter.workflow_id:
|
||||
args["workflow_id"] = source_filter.workflow_id
|
||||
if source_filter.workflow_version:
|
||||
args["workflow_version"] = source_filter.workflow_version
|
||||
if source_filter.node_id:
|
||||
args["node_id"] = source_filter.node_id
|
||||
if params.start:
|
||||
sql_query += " AND m.created_at >= :start"
|
||||
args["start"] = params.start
|
||||
@ -216,6 +586,45 @@ WHERE
|
||||
|
||||
return [dict(row._mapping) for row in self._session.execute(sa.text(sql_query), args).all()]
|
||||
|
||||
@staticmethod
|
||||
def _statistics_message_scope_sql(source_filter: AgentSourceFilter) -> str:
|
||||
app_scope = "m.app_id = :app_id"
|
||||
if source_filter.invoke_from is None:
|
||||
app_scope += " AND m.invoke_from != :debugger"
|
||||
else:
|
||||
app_scope += " AND m.invoke_from = :source"
|
||||
workflow_binding_filters = []
|
||||
if source_filter.app_id:
|
||||
workflow_binding_filters.append("wanb.app_id = :source_app_id")
|
||||
if source_filter.workflow_id:
|
||||
workflow_binding_filters.append("wanb.workflow_id = :workflow_id")
|
||||
if source_filter.workflow_version:
|
||||
workflow_binding_filters.append("wanb.workflow_version = :workflow_version")
|
||||
if source_filter.node_id:
|
||||
workflow_binding_filters.append("wanb.node_id = :node_id")
|
||||
extra_workflow_filters = f"AND {' AND '.join(workflow_binding_filters)}" if workflow_binding_filters else ""
|
||||
workflow_scope = f"""m.workflow_run_id IS NOT NULL
|
||||
AND EXISTS (
|
||||
SELECT 1
|
||||
FROM workflow_runs wr
|
||||
JOIN workflow_agent_node_bindings wanb
|
||||
ON wanb.tenant_id = :tenant_id
|
||||
AND wanb.agent_id = :agent_id
|
||||
AND wanb.app_id = wr.app_id
|
||||
AND wanb.workflow_id = wr.workflow_id
|
||||
AND wanb.workflow_version = wr.version
|
||||
{extra_workflow_filters}
|
||||
JOIN workflow_node_executions wne
|
||||
ON wne.workflow_run_id = wr.id
|
||||
AND wne.node_id = wanb.node_id
|
||||
WHERE wr.id = m.workflow_run_id
|
||||
)"""
|
||||
if source_filter.kind == "webapp":
|
||||
return app_scope
|
||||
if source_filter.kind == "workflow":
|
||||
return workflow_scope
|
||||
return f"(({app_scope}) OR ({workflow_scope}))"
|
||||
|
||||
@staticmethod
|
||||
def _build_charts(rows: list[dict[str, Any]]) -> dict[str, list[dict[str, Any]]]:
|
||||
messages = []
|
||||
|
||||
@ -10,7 +10,9 @@ to the agent drive (Agent Files §5.4 / §4):
|
||||
Both are stored as ``ToolFile`` records and bound via ``AgentDriveService.commit``
|
||||
with ``value_owned_by_drive=True`` (the drive owns their lifecycle). The returned
|
||||
skill ref records the stable drive paths + file ids (not just the raw upload id),
|
||||
so the Composer can reload the bound skill list.
|
||||
so the Composer can reload the bound skill list. The console ``/skills/upload``
|
||||
endpoints delegate to this service so "upload" now always means drive-backed skill
|
||||
normalization.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@ -34,7 +36,7 @@ def slugify_skill_name(name: str) -> str:
|
||||
|
||||
|
||||
class SkillStandardizeService:
|
||||
"""Validate + standardize a Skill package into a per-agent drive."""
|
||||
"""Validate + standardize a Skill package into a per-agent drive upload result."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@ -75,7 +75,7 @@ class CreateAppParams(BaseModel):
|
||||
class AppService:
|
||||
@staticmethod
|
||||
def _build_app_list_filters(
|
||||
user_id: str, tenant_id: str, params: AppListBaseParams
|
||||
user_id: str, tenant_id: str, params: AppListBaseParams, session: scoped_session
|
||||
) -> list[sa.ColumnElement[bool]]:
|
||||
filters = [App.tenant_id == tenant_id, App.is_universal == False]
|
||||
|
||||
@ -115,7 +115,7 @@ class AppService:
|
||||
escaped_name = escape_like_pattern(name)
|
||||
filters.append(App.name.ilike(f"%{escaped_name}%", escape="\\"))
|
||||
if params.tag_ids and len(params.tag_ids) > 0:
|
||||
target_ids = TagService.get_target_ids_by_tag_ids("app", tenant_id, params.tag_ids, match_all=True)
|
||||
target_ids = TagService.get_target_ids_by_tag_ids("app", tenant_id, params.tag_ids, session, match_all=True)
|
||||
if target_ids and len(target_ids) > 0:
|
||||
filters.append(App.id.in_(target_ids))
|
||||
else:
|
||||
@ -197,7 +197,9 @@ class AppService:
|
||||
).scalars()
|
||||
)
|
||||
|
||||
def get_paginate_apps(self, user_id: str, tenant_id: str, params: AppListParams) -> Pagination | None:
|
||||
def get_paginate_apps(
|
||||
self, user_id: str, tenant_id: str, params: AppListParams, session: scoped_session
|
||||
) -> Pagination | None:
|
||||
"""
|
||||
Get app list with pagination, filters, and explicit sort order.
|
||||
:param user_id: user id
|
||||
@ -205,7 +207,7 @@ class AppService:
|
||||
:param params: query parameters
|
||||
:return:
|
||||
"""
|
||||
filters = self._build_app_list_filters(user_id, tenant_id, params)
|
||||
filters = self._build_app_list_filters(user_id, tenant_id, params, session)
|
||||
if not filters:
|
||||
return None
|
||||
|
||||
@ -231,12 +233,12 @@ class AppService:
|
||||
return app_models
|
||||
|
||||
def get_paginate_starred_apps(
|
||||
self, user_id: str, tenant_id: str, params: StarredAppListParams
|
||||
self, user_id: str, tenant_id: str, params: StarredAppListParams, session: scoped_session
|
||||
) -> Pagination | None:
|
||||
"""
|
||||
Get apps starred by the current account with pagination, filters, and explicit sort order.
|
||||
"""
|
||||
filters = self._build_app_list_filters(user_id, tenant_id, params)
|
||||
filters = self._build_app_list_filters(user_id, tenant_id, params, session)
|
||||
if not filters:
|
||||
return None
|
||||
|
||||
@ -540,17 +542,21 @@ class AppService:
|
||||
*,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
role: str | None = None,
|
||||
icon_type: IconType | str | None = None,
|
||||
icon: str | None = None,
|
||||
icon_background: str | None = None,
|
||||
role: str | None = None,
|
||||
account_id: str | None = None,
|
||||
updated_at: datetime | None = None,
|
||||
) -> None:
|
||||
"""Keep the Roster identity aligned with its Agent App shell.
|
||||
|
||||
Agent Soul remains versioned through Composer. This helper only mirrors
|
||||
user-facing identity fields so Roster and Agent Console do not drift.
|
||||
user-facing identity fields, including the roster role/persona label,
|
||||
so Roster and Agent Console do not drift.
|
||||
|
||||
Role omission is intentional: ``role=None`` preserves the backing
|
||||
Agent's current role, while ``role=""`` explicitly clears it.
|
||||
"""
|
||||
agent = self._get_backing_agent_for_update(app)
|
||||
if agent is None:
|
||||
@ -560,14 +566,14 @@ class AppService:
|
||||
agent.name = name
|
||||
if description is not None:
|
||||
agent.description = description
|
||||
if role is not None:
|
||||
agent.role = role
|
||||
if icon_type is not None:
|
||||
agent.icon_type = self._to_agent_icon_type(icon_type)
|
||||
if icon is not None:
|
||||
agent.icon = icon
|
||||
if icon_background is not None:
|
||||
agent.icon_background = icon_background
|
||||
if role is not None:
|
||||
agent.role = role
|
||||
agent.updated_by = account_id
|
||||
if updated_at is not None:
|
||||
agent.updated_at = updated_at
|
||||
@ -599,10 +605,12 @@ class AppService:
|
||||
app,
|
||||
name=app.name,
|
||||
description=app.description,
|
||||
# Omitted role must stay omitted here: None means "preserve current
|
||||
# backing-agent role", while an empty string is an explicit clear.
|
||||
role=args.get("role"),
|
||||
icon_type=app.icon_type,
|
||||
icon=app.icon,
|
||||
icon_background=app.icon_background,
|
||||
role=args.get("role"),
|
||||
account_id=current_user.id,
|
||||
updated_at=app.updated_at,
|
||||
)
|
||||
|
||||
@ -14,12 +14,13 @@ class AttachmentService:
|
||||
_session_maker: sessionmaker
|
||||
|
||||
def __init__(self, session_factory: sessionmaker | Engine | None = None):
|
||||
if isinstance(session_factory, Engine):
|
||||
self._session_maker = sessionmaker(bind=session_factory)
|
||||
elif isinstance(session_factory, sessionmaker):
|
||||
self._session_maker = session_factory
|
||||
else:
|
||||
raise AssertionError("must be a sessionmaker or an Engine.")
|
||||
match session_factory:
|
||||
case Engine():
|
||||
self._session_maker = sessionmaker(bind=session_factory)
|
||||
case sessionmaker():
|
||||
self._session_maker = session_factory
|
||||
case _:
|
||||
raise AssertionError("must be a sessionmaker or an Engine.")
|
||||
|
||||
def get_file_base64(self, file_id: str) -> str:
|
||||
with self._session_maker(expire_on_commit=False) as session:
|
||||
|
||||
@ -13,7 +13,7 @@ import sqlalchemy as sa
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator
|
||||
from redis.exceptions import LockNotOwnedError
|
||||
from sqlalchemy import delete, exists, func, select, update
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.orm import Session, scoped_session, sessionmaker
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from configs import dify_config
|
||||
@ -235,7 +235,9 @@ class _EstimateArgs(BaseModel):
|
||||
|
||||
class DatasetService:
|
||||
@staticmethod
|
||||
def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False):
|
||||
def get_datasets(
|
||||
page, per_page, session: scoped_session, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False
|
||||
):
|
||||
query = select(Dataset).where(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc(), Dataset.id)
|
||||
|
||||
if user:
|
||||
@ -295,6 +297,7 @@ class DatasetService:
|
||||
"knowledge",
|
||||
tenant_id,
|
||||
tag_ids,
|
||||
session,
|
||||
match_all=True,
|
||||
)
|
||||
else:
|
||||
|
||||
@ -60,6 +60,7 @@ class ComposerSavePayload(BaseModel):
|
||||
|
||||
class RosterAgentCreatePayload(BaseModel):
|
||||
name: str = Field(min_length=1, max_length=255)
|
||||
mode: Literal["agent"] = "agent"
|
||||
description: str = ""
|
||||
role: str = Field(default="", max_length=255)
|
||||
icon_type: AgentIconType | None = None
|
||||
|
||||
@ -39,12 +39,13 @@ class FileService:
|
||||
_session_maker: sessionmaker[Session]
|
||||
|
||||
def __init__(self, session_factory: sessionmaker | Engine | None = None):
|
||||
if isinstance(session_factory, Engine):
|
||||
self._session_maker = sessionmaker(bind=session_factory)
|
||||
elif isinstance(session_factory, sessionmaker):
|
||||
self._session_maker = session_factory
|
||||
else:
|
||||
raise AssertionError("must be a sessionmaker or an Engine.")
|
||||
match session_factory:
|
||||
case Engine():
|
||||
self._session_maker = sessionmaker(bind=session_factory)
|
||||
case sessionmaker():
|
||||
self._session_maker = session_factory
|
||||
case _:
|
||||
raise AssertionError("must be a sessionmaker or an Engine.")
|
||||
|
||||
def upload_file(
|
||||
self,
|
||||
|
||||
@ -119,10 +119,11 @@ class HumanInputDeliveryTestService:
|
||||
|
||||
class EmailDeliveryTestHandler:
|
||||
def __init__(self, session_factory: sessionmaker | Engine | None = None) -> None:
|
||||
if session_factory is None:
|
||||
session_factory = sessionmaker(bind=db.engine)
|
||||
elif isinstance(session_factory, Engine):
|
||||
session_factory = sessionmaker(bind=session_factory)
|
||||
match session_factory:
|
||||
case None:
|
||||
session_factory = sessionmaker(bind=db.engine)
|
||||
case Engine():
|
||||
session_factory = sessionmaker(bind=session_factory)
|
||||
self._session_factory = session_factory
|
||||
|
||||
def supports(self, method: DeliveryChannelConfig) -> bool:
|
||||
@ -179,11 +180,12 @@ class EmailDeliveryTestHandler:
|
||||
emails: list[str] = []
|
||||
bound_reference_ids: list[str] = []
|
||||
for recipient in recipients.items:
|
||||
if isinstance(recipient, MemberRecipient):
|
||||
bound_reference_ids.append(recipient.reference_id)
|
||||
elif isinstance(recipient, ExternalRecipient):
|
||||
if recipient.email:
|
||||
emails.append(recipient.email)
|
||||
match recipient:
|
||||
case MemberRecipient():
|
||||
bound_reference_ids.append(recipient.reference_id)
|
||||
case ExternalRecipient():
|
||||
if recipient.email:
|
||||
emails.append(recipient.email)
|
||||
|
||||
if recipients.include_bound_group:
|
||||
bound_reference_ids = []
|
||||
|
||||
@ -6,7 +6,7 @@ from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.orm import Session, scoped_session, sessionmaker
|
||||
|
||||
from core.db import session_factory
|
||||
from core.workflow.node_factory import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
||||
@ -192,6 +192,7 @@ class SnippetService:
|
||||
self,
|
||||
*,
|
||||
tenant_id: str,
|
||||
session: scoped_session,
|
||||
page: int = 1,
|
||||
limit: int = 20,
|
||||
keyword: str | None = None,
|
||||
@ -229,20 +230,19 @@ class SnippetService:
|
||||
stmt = stmt.where(CustomizedSnippet.created_by.in_(creators))
|
||||
|
||||
if tag_ids:
|
||||
target_ids = TagService.get_target_ids_by_tag_ids("snippet", tenant_id, tag_ids, match_all=True)
|
||||
target_ids = TagService.get_target_ids_by_tag_ids("snippet", tenant_id, tag_ids, session, match_all=True)
|
||||
if target_ids:
|
||||
stmt = stmt.where(CustomizedSnippet.id.in_(target_ids))
|
||||
else:
|
||||
return [], 0, False
|
||||
|
||||
with self._session_scope() as session:
|
||||
# Get total count
|
||||
count_stmt = select(func.count()).select_from(stmt.subquery())
|
||||
total = session.scalar(count_stmt) or 0
|
||||
# Get total count
|
||||
count_stmt = select(func.count()).select_from(stmt.subquery())
|
||||
total = session.scalar(count_stmt) or 0
|
||||
|
||||
# Apply pagination
|
||||
stmt = stmt.limit(limit + 1).offset((page - 1) * limit)
|
||||
snippets = list(session.scalars(stmt).all())
|
||||
# Apply pagination
|
||||
stmt = stmt.limit(limit + 1).offset((page - 1) * limit)
|
||||
snippets = list(session.scalars(stmt).all())
|
||||
|
||||
has_more = len(snippets) > limit
|
||||
if has_more:
|
||||
|
||||
@ -6,10 +6,9 @@ from flask_login import current_user
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.engine import CursorResult
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, scoped_session
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset
|
||||
from models.enums import TagType
|
||||
from models.model import App, Tag, TagBinding
|
||||
@ -56,7 +55,7 @@ class TagService:
|
||||
|
||||
@staticmethod
|
||||
def get_target_ids_by_tag_ids(
|
||||
tag_type: str, current_tenant_id: str, tag_ids: list[str], *, match_all: bool = False
|
||||
tag_type: str, current_tenant_id: str, tag_ids: list[str], session: scoped_session, *, match_all: bool = False
|
||||
):
|
||||
"""
|
||||
Return target IDs bound to tags for the given tenant and resource type.
|
||||
@ -70,7 +69,7 @@ class TagService:
|
||||
return []
|
||||
# Deduplicate repeated query params so match_all counts each requested tag once.
|
||||
requested_tag_ids = list(dict.fromkeys(tag_ids))
|
||||
tags = db.session.scalars(
|
||||
tags = session.scalars(
|
||||
select(Tag.id).where(
|
||||
Tag.id.in_(requested_tag_ids),
|
||||
Tag.tenant_id == current_tenant_id,
|
||||
@ -86,13 +85,13 @@ class TagService:
|
||||
if match_all:
|
||||
if len(tag_ids) != len(requested_tag_ids):
|
||||
return []
|
||||
return db.session.scalars(
|
||||
return session.scalars(
|
||||
select(TagBinding.target_id)
|
||||
.where(TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id)
|
||||
.group_by(TagBinding.target_id)
|
||||
.having(func.count(sa.distinct(TagBinding.tag_id)) == len(tag_ids))
|
||||
).all()
|
||||
tag_bindings = db.session.scalars(
|
||||
tag_bindings = session.scalars(
|
||||
select(TagBinding.target_id).where(
|
||||
TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id
|
||||
)
|
||||
@ -100,11 +99,11 @@ class TagService:
|
||||
return tag_bindings
|
||||
|
||||
@staticmethod
|
||||
def get_tag_by_tag_name(tag_type: str, current_tenant_id: str, tag_name: str):
|
||||
def get_tag_by_tag_name(tag_type: str, current_tenant_id: str, tag_name: str, session: scoped_session):
|
||||
if not tag_type or not tag_name:
|
||||
return []
|
||||
tags = list(
|
||||
db.session.scalars(
|
||||
session.scalars(
|
||||
select(Tag).where(Tag.name == tag_name, Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
|
||||
).all()
|
||||
)
|
||||
@ -113,8 +112,8 @@ class TagService:
|
||||
return tags
|
||||
|
||||
@staticmethod
|
||||
def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str):
|
||||
tags = db.session.scalars(
|
||||
def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str, session: scoped_session):
|
||||
tags = session.scalars(
|
||||
select(Tag)
|
||||
.join(TagBinding, Tag.id == TagBinding.tag_id)
|
||||
.where(
|
||||
@ -128,8 +127,8 @@ class TagService:
|
||||
return tags or []
|
||||
|
||||
@staticmethod
|
||||
def save_tags(payload: SaveTagPayload) -> Tag:
|
||||
if TagService.get_tag_by_tag_name(payload.type, current_user.current_tenant_id, payload.name):
|
||||
def save_tags(payload: SaveTagPayload, session: scoped_session) -> Tag:
|
||||
if TagService.get_tag_by_tag_name(payload.type, current_user.current_tenant_id, payload.name, session):
|
||||
raise ValueError("Tag name already exists")
|
||||
tag = Tag(
|
||||
name=payload.name,
|
||||
@ -138,17 +137,17 @@ class TagService:
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
)
|
||||
tag.id = str(uuid.uuid4())
|
||||
db.session.add(tag)
|
||||
db.session.commit()
|
||||
session.add(tag)
|
||||
session.commit()
|
||||
return tag
|
||||
|
||||
@staticmethod
|
||||
def update_tags(payload: UpdateTagPayload, tag_id: str) -> Tag:
|
||||
tag = db.session.scalar(select(Tag).where(Tag.id == tag_id).limit(1))
|
||||
def update_tags(payload: UpdateTagPayload, tag_id: str, session: scoped_session) -> Tag:
|
||||
tag = session.scalar(select(Tag).where(Tag.id == tag_id).limit(1))
|
||||
if not tag:
|
||||
raise NotFound("Tag not found")
|
||||
if payload.name != tag.name:
|
||||
existing = db.session.scalar(
|
||||
existing = session.scalar(
|
||||
select(Tag)
|
||||
.where(
|
||||
Tag.name == payload.name,
|
||||
@ -161,31 +160,31 @@ class TagService:
|
||||
if existing:
|
||||
raise ValueError("Tag name already exists")
|
||||
tag.name = payload.name
|
||||
db.session.commit()
|
||||
session.commit()
|
||||
return tag
|
||||
|
||||
@staticmethod
|
||||
def get_tag_binding_count(tag_id: str) -> int:
|
||||
count = db.session.scalar(select(func.count(TagBinding.id)).where(TagBinding.tag_id == tag_id)) or 0
|
||||
def get_tag_binding_count(tag_id: str, session: scoped_session) -> int:
|
||||
count = session.scalar(select(func.count(TagBinding.id)).where(TagBinding.tag_id == tag_id)) or 0
|
||||
return count
|
||||
|
||||
@staticmethod
|
||||
def delete_tag(tag_id: str):
|
||||
tag = db.session.scalar(select(Tag).where(Tag.id == tag_id).limit(1))
|
||||
def delete_tag(tag_id: str, session: scoped_session):
|
||||
tag = session.scalar(select(Tag).where(Tag.id == tag_id).limit(1))
|
||||
if not tag:
|
||||
raise NotFound("Tag not found")
|
||||
db.session.delete(tag)
|
||||
session.delete(tag)
|
||||
# delete tag binding
|
||||
tag_bindings = db.session.scalars(select(TagBinding).where(TagBinding.tag_id == tag_id)).all()
|
||||
tag_bindings = session.scalars(select(TagBinding).where(TagBinding.tag_id == tag_id)).all()
|
||||
if tag_bindings:
|
||||
for tag_binding in tag_bindings:
|
||||
db.session.delete(tag_binding)
|
||||
db.session.commit()
|
||||
session.delete(tag_binding)
|
||||
session.commit()
|
||||
|
||||
@staticmethod
|
||||
def save_tag_binding(payload: TagBindingCreatePayload):
|
||||
TagService.check_target_exists(payload.type, payload.target_id)
|
||||
valid_tag_ids = db.session.scalars(
|
||||
def save_tag_binding(payload: TagBindingCreatePayload, session: scoped_session):
|
||||
TagService.check_target_exists(payload.type, payload.target_id, session)
|
||||
valid_tag_ids = session.scalars(
|
||||
select(Tag.id).where(
|
||||
Tag.id.in_(payload.tag_ids),
|
||||
Tag.tenant_id == current_user.current_tenant_id,
|
||||
@ -193,7 +192,7 @@ class TagService:
|
||||
)
|
||||
).all()
|
||||
for tag_id in valid_tag_ids:
|
||||
tag_binding = db.session.scalar(
|
||||
tag_binding = session.scalar(
|
||||
select(TagBinding)
|
||||
.where(TagBinding.tag_id == tag_id, TagBinding.target_id == payload.target_id)
|
||||
.limit(1)
|
||||
@ -206,15 +205,15 @@ class TagService:
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
created_by=current_user.id,
|
||||
)
|
||||
db.session.add(new_tag_binding)
|
||||
db.session.commit()
|
||||
session.add(new_tag_binding)
|
||||
session.commit()
|
||||
|
||||
@staticmethod
|
||||
def delete_tag_binding(payload: TagBindingDeletePayload):
|
||||
TagService.check_target_exists(payload.type, payload.target_id)
|
||||
def delete_tag_binding(payload: TagBindingDeletePayload, session: scoped_session):
|
||||
TagService.check_target_exists(payload.type, payload.target_id, session)
|
||||
result = cast(
|
||||
CursorResult,
|
||||
db.session.execute(
|
||||
session.execute(
|
||||
delete(TagBinding).where(
|
||||
TagBinding.target_id == payload.target_id,
|
||||
TagBinding.tag_id.in_(payload.tag_ids),
|
||||
@ -230,12 +229,12 @@ class TagService:
|
||||
)
|
||||
|
||||
if result.rowcount:
|
||||
db.session.commit()
|
||||
session.commit()
|
||||
|
||||
@staticmethod
|
||||
def check_target_exists(type: str, target_id: str):
|
||||
def check_target_exists(type: str, target_id: str, session: scoped_session):
|
||||
if type == "knowledge":
|
||||
dataset = db.session.scalar(
|
||||
dataset = session.scalar(
|
||||
select(Dataset)
|
||||
.where(Dataset.tenant_id == current_user.current_tenant_id, Dataset.id == target_id)
|
||||
.limit(1)
|
||||
@ -243,13 +242,13 @@ class TagService:
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found")
|
||||
elif type == "app":
|
||||
app = db.session.scalar(
|
||||
app = session.scalar(
|
||||
select(App).where(App.tenant_id == current_user.current_tenant_id, App.id == target_id).limit(1)
|
||||
)
|
||||
if not app:
|
||||
raise NotFound("App not found")
|
||||
elif type == "snippet":
|
||||
snippet = db.session.scalar(
|
||||
snippet = session.scalar(
|
||||
select(CustomizedSnippet)
|
||||
.where(CustomizedSnippet.tenant_id == current_user.current_tenant_id, CustomizedSnippet.id == target_id)
|
||||
.limit(1)
|
||||
|
||||
@ -125,10 +125,11 @@ class DraftVarLoader(VariableLoader):
|
||||
# can be safely accessed before any offloading logic is applied.
|
||||
for draft_var in draft_vars:
|
||||
value = draft_var.get_value()
|
||||
if isinstance(value, FileSegment):
|
||||
files.append(value.value)
|
||||
elif isinstance(value, ArrayFileSegment):
|
||||
files.extend(value.value)
|
||||
match value:
|
||||
case FileSegment():
|
||||
files.append(value.value)
|
||||
case ArrayFileSegment():
|
||||
files.extend(value.value)
|
||||
with Session(bind=self._engine) as session:
|
||||
storage_key_loader = StorageKeyLoader(
|
||||
session,
|
||||
|
||||
@ -34,10 +34,11 @@ class WorkflowRunService:
|
||||
|
||||
def __init__(self, session_factory: Engine | sessionmaker | None = None):
|
||||
"""Initialize WorkflowRunService with repository dependencies."""
|
||||
if session_factory is None:
|
||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
elif isinstance(session_factory, Engine):
|
||||
session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
|
||||
match session_factory:
|
||||
case None:
|
||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
case Engine():
|
||||
session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
|
||||
|
||||
self._session_factory = session_factory
|
||||
self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
|
||||
|
||||
@ -131,10 +131,11 @@ def fetch_table_rows(
|
||||
for row in rows:
|
||||
normalized = dict(row)
|
||||
for key, value in normalized.items():
|
||||
if isinstance(value, datetime):
|
||||
normalized[key] = value.isoformat()
|
||||
elif isinstance(value, uuid.UUID):
|
||||
normalized[key] = str(value)
|
||||
match value:
|
||||
case datetime():
|
||||
normalized[key] = value.isoformat()
|
||||
case uuid.UUID():
|
||||
normalized[key] = str(value)
|
||||
result.append(normalized)
|
||||
return result
|
||||
|
||||
|
||||
@ -16,7 +16,7 @@ since these test controller-level behavior.
|
||||
import uuid
|
||||
from contextlib import ExitStack
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import Mock, PropertyMock, patch
|
||||
from unittest.mock import ANY, Mock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
@ -1129,7 +1129,7 @@ class TestDatasetTagsApiPatch:
|
||||
assert status == 200
|
||||
assert response == {"id": "tag-1", "name": "Updated Tag", "type": "knowledge", "binding_count": "5"}
|
||||
mock_tag_svc.update_tags.assert_called_once()
|
||||
update_payload, tag_id = mock_tag_svc.update_tags.call_args.args
|
||||
update_payload, tag_id, session = mock_tag_svc.update_tags.call_args.args
|
||||
assert update_payload.name == "Updated Tag"
|
||||
assert tag_id == "tag-1"
|
||||
|
||||
@ -1184,7 +1184,7 @@ class TestDatasetTagsApiDelete:
|
||||
result = api.delete(_=None)
|
||||
|
||||
assert result == ("", 204)
|
||||
mock_tag_svc.delete_tag.assert_called_once_with("tag-1")
|
||||
mock_tag_svc.delete_tag.assert_called_once_with("tag-1", ANY)
|
||||
|
||||
@patch("libs.login.current_user")
|
||||
def test_delete_tag_forbidden(self, mock_current_user, app: Flask):
|
||||
@ -1233,7 +1233,7 @@ class TestDatasetTagsBindingStatusApi:
|
||||
assert status_code == 200
|
||||
assert response["data"] == [{"id": "tag_1", "name": "Test Tag"}]
|
||||
assert response["total"] == 1
|
||||
mock_tag_svc.get_tags_by_target_id.assert_called_once_with("knowledge", "tenant_123", "dataset_123")
|
||||
mock_tag_svc.get_tags_by_target_id.assert_called_once_with("knowledge", "tenant_123", "dataset_123", ANY)
|
||||
|
||||
|
||||
class TestDatasetTagBindingApiPost:
|
||||
@ -1266,7 +1266,8 @@ class TestDatasetTagBindingApiPost:
|
||||
from services.tag_service import TagBindingCreatePayload
|
||||
|
||||
mock_tag_svc.save_tag_binding.assert_called_once_with(
|
||||
TagBindingCreatePayload(tag_ids=["tag-1"], target_id="ds-1", type=TagType.KNOWLEDGE)
|
||||
TagBindingCreatePayload(tag_ids=["tag-1"], target_id="ds-1", type=TagType.KNOWLEDGE),
|
||||
ANY,
|
||||
)
|
||||
|
||||
@patch("controllers.service_api.dataset.dataset.current_user")
|
||||
@ -1317,7 +1318,8 @@ class TestDatasetTagUnbindingApiPost:
|
||||
from services.tag_service import TagBindingDeletePayload
|
||||
|
||||
mock_tag_svc.delete_tag_binding.assert_called_once_with(
|
||||
TagBindingDeletePayload(tag_ids=["tag-1"], target_id="ds-1", type=TagType.KNOWLEDGE)
|
||||
TagBindingDeletePayload(tag_ids=["tag-1"], target_id="ds-1", type=TagType.KNOWLEDGE),
|
||||
ANY,
|
||||
)
|
||||
|
||||
@patch("controllers.service_api.dataset.dataset.TagService")
|
||||
@ -1347,7 +1349,8 @@ class TestDatasetTagUnbindingApiPost:
|
||||
from services.tag_service import TagBindingDeletePayload
|
||||
|
||||
mock_tag_svc.delete_tag_binding.assert_called_once_with(
|
||||
TagBindingDeletePayload(tag_ids=["tag-1"], target_id="ds-1", type=TagType.KNOWLEDGE)
|
||||
TagBindingDeletePayload(tag_ids=["tag-1"], target_id="ds-1", type=TagType.KNOWLEDGE),
|
||||
ANY,
|
||||
)
|
||||
|
||||
@patch("controllers.service_api.dataset.dataset.current_user")
|
||||
|
||||
@ -2755,7 +2755,7 @@ class TestRegisterService:
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test inviting an existing member who is not in the tenant yet.
|
||||
Test inviting an existing active account who is not in the tenant yet.
|
||||
"""
|
||||
fake = Faker()
|
||||
tenant_name = fake.company()
|
||||
@ -2791,20 +2791,20 @@ class TestRegisterService:
|
||||
# Mock the email task
|
||||
with patch("services.account_service.send_invite_member_mail_task") as mock_send_mail:
|
||||
mock_send_mail.delay.return_value = None
|
||||
with pytest.raises(AccountAlreadyInTenantError, match="Account already in tenant."):
|
||||
# Execute invitation
|
||||
token = RegisterService.invite_new_member(
|
||||
tenant=tenant,
|
||||
email=existing_member_email,
|
||||
language=language,
|
||||
role="admin",
|
||||
inviter=inviter,
|
||||
)
|
||||
|
||||
# Verify email task was not called
|
||||
mock_send_mail.delay.assert_not_called()
|
||||
token = RegisterService.invite_new_member(
|
||||
tenant=tenant,
|
||||
email=existing_member_email,
|
||||
language=language,
|
||||
role="admin",
|
||||
inviter=inviter,
|
||||
)
|
||||
|
||||
# Verify tenant member was created for existing account
|
||||
assert token is not None
|
||||
assert len(token) > 0
|
||||
mock_send_mail.delay.assert_called_once()
|
||||
|
||||
# Existing active accounts must accept the invite before becoming workspace members.
|
||||
from models.account import TenantAccountJoin
|
||||
|
||||
tenant_join = (
|
||||
@ -2812,8 +2812,13 @@ class TestRegisterService:
|
||||
.filter_by(tenant_id=tenant.id, account_id=existing_account.id)
|
||||
.first()
|
||||
)
|
||||
assert tenant_join is not None
|
||||
assert tenant_join.role == "admin"
|
||||
assert tenant_join is None
|
||||
|
||||
invitation = RegisterService.get_invitation_if_token_valid(None, None, token)
|
||||
assert invitation is not None
|
||||
assert invitation["account"].id == existing_account.id
|
||||
assert invitation["data"]["role"] == "admin"
|
||||
assert invitation["data"]["requires_setup"] is False
|
||||
|
||||
def test_invite_new_member_existing_member(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
|
||||
@ -234,7 +234,7 @@ class TestAppService:
|
||||
# Get paginated apps
|
||||
params = AppListParams(page=1, limit=10, mode="chat")
|
||||
|
||||
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, params)
|
||||
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, params, db_session_with_containers)
|
||||
|
||||
# Verify pagination results
|
||||
assert paginated_apps is not None
|
||||
@ -295,7 +295,7 @@ class TestAppService:
|
||||
db_session_with_containers.commit()
|
||||
|
||||
last_modified_apps = app_service.get_paginate_apps(
|
||||
account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat")
|
||||
account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat"), db_session_with_containers
|
||||
)
|
||||
assert last_modified_apps is not None
|
||||
assert [app.name for app in last_modified_apps.items] == [
|
||||
@ -305,7 +305,10 @@ class TestAppService:
|
||||
]
|
||||
|
||||
recently_created_apps = app_service.get_paginate_apps(
|
||||
account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat", sort_by="recently_created")
|
||||
account.id,
|
||||
tenant.id,
|
||||
AppListParams(page=1, limit=10, mode="chat", sort_by="recently_created"),
|
||||
db_session_with_containers,
|
||||
)
|
||||
assert recently_created_apps is not None
|
||||
assert [app.name for app in recently_created_apps.items] == [
|
||||
@ -315,7 +318,10 @@ class TestAppService:
|
||||
]
|
||||
|
||||
earliest_created_apps = app_service.get_paginate_apps(
|
||||
account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat", sort_by="earliest_created")
|
||||
account.id,
|
||||
tenant.id,
|
||||
AppListParams(page=1, limit=10, mode="chat", sort_by="earliest_created"),
|
||||
db_session_with_containers,
|
||||
)
|
||||
assert earliest_created_apps is not None
|
||||
assert [app.name for app in earliest_created_apps.items] == [
|
||||
@ -366,7 +372,7 @@ class TestAppService:
|
||||
assert star_count == 1
|
||||
|
||||
paginated_apps = app_service.get_paginate_apps(
|
||||
account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat")
|
||||
account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat"), db_session_with_containers
|
||||
)
|
||||
assert paginated_apps is not None
|
||||
starred_by_app_id = {app.id: app.is_starred for app in paginated_apps.items}
|
||||
@ -377,7 +383,7 @@ class TestAppService:
|
||||
db_session_with_containers.commit()
|
||||
|
||||
paginated_apps = app_service.get_paginate_apps(
|
||||
account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat")
|
||||
account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat"), db_session_with_containers
|
||||
)
|
||||
assert paginated_apps is not None
|
||||
starred_by_app_id = {app.id: app.is_starred for app in paginated_apps.items}
|
||||
@ -442,7 +448,7 @@ class TestAppService:
|
||||
db_session_with_containers.commit()
|
||||
|
||||
last_modified_apps = app_service.get_paginate_starred_apps(
|
||||
account.id, tenant.id, StarredAppListParams(page=1, limit=10, mode="chat")
|
||||
account.id, tenant.id, StarredAppListParams(page=1, limit=10, mode="chat"), db_session_with_containers
|
||||
)
|
||||
assert last_modified_apps is not None
|
||||
assert [app.name for app in last_modified_apps.items] == [
|
||||
@ -457,6 +463,7 @@ class TestAppService:
|
||||
account.id,
|
||||
tenant.id,
|
||||
StarredAppListParams(page=1, limit=10, mode="chat", sort_by="recently_created"),
|
||||
db_session_with_containers,
|
||||
)
|
||||
assert recently_created_apps is not None
|
||||
assert [app.name for app in recently_created_apps.items] == [
|
||||
@ -469,6 +476,7 @@ class TestAppService:
|
||||
account.id,
|
||||
tenant.id,
|
||||
StarredAppListParams(page=1, limit=10, mode="chat", sort_by="earliest_created"),
|
||||
db_session_with_containers,
|
||||
)
|
||||
assert earliest_created_apps is not None
|
||||
assert [app.name for app in earliest_created_apps.items] == [
|
||||
@ -522,20 +530,25 @@ class TestAppService:
|
||||
completion_app = app_service.create_app(tenant.id, completion_app_params, account)
|
||||
|
||||
# Test filter by mode
|
||||
chat_apps = app_service.get_paginate_apps(account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat"))
|
||||
chat_apps = app_service.get_paginate_apps(
|
||||
account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat"), db_session_with_containers
|
||||
)
|
||||
assert len(chat_apps.items) == 1
|
||||
assert chat_apps.items[0].mode == "chat"
|
||||
|
||||
# Test filter by name
|
||||
filtered_apps = app_service.get_paginate_apps(
|
||||
account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat", name="Chat")
|
||||
account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat", name="Chat"), db_session_with_containers
|
||||
)
|
||||
assert len(filtered_apps.items) == 1
|
||||
assert "Chat" in filtered_apps.items[0].name
|
||||
|
||||
# Test filter by created_by_me
|
||||
my_apps = app_service.get_paginate_apps(
|
||||
account.id, tenant.id, AppListParams(page=1, limit=10, mode="completion", is_created_by_me=True)
|
||||
account.id,
|
||||
tenant.id,
|
||||
AppListParams(page=1, limit=10, mode="completion", is_created_by_me=True),
|
||||
db_session_with_containers,
|
||||
)
|
||||
assert len(my_apps.items) == 1
|
||||
|
||||
@ -588,6 +601,7 @@ class TestAppService:
|
||||
first_account.id,
|
||||
tenant.id,
|
||||
AppListParams(page=1, limit=10, mode="chat", creator_ids=[second_account.id]),
|
||||
db_session_with_containers,
|
||||
)
|
||||
|
||||
assert filtered_apps is not None
|
||||
@ -635,10 +649,12 @@ class TestAppService:
|
||||
# Test with tag filter
|
||||
params = AppListParams(page=1, limit=10, mode="chat", tag_ids=["tag1", "tag2"])
|
||||
|
||||
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, params)
|
||||
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, params, db_session_with_containers)
|
||||
|
||||
# Verify tag service was called
|
||||
mock_tag_service.assert_called_once_with("app", tenant.id, ["tag1", "tag2"], match_all=True)
|
||||
mock_tag_service.assert_called_once_with(
|
||||
"app", tenant.id, ["tag1", "tag2"], db_session_with_containers, match_all=True
|
||||
)
|
||||
|
||||
# Verify results
|
||||
assert paginated_apps is not None
|
||||
@ -651,7 +667,7 @@ class TestAppService:
|
||||
|
||||
params = AppListParams(page=1, limit=10, mode="chat", tag_ids=["nonexistent_tag"])
|
||||
|
||||
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, params)
|
||||
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, params, db_session_with_containers)
|
||||
|
||||
# Should return None when no apps match tag filter
|
||||
assert paginated_apps is None
|
||||
@ -1467,7 +1483,7 @@ class TestAppService:
|
||||
|
||||
# Test 1: Search with % character
|
||||
paginated_apps = app_service.get_paginate_apps(
|
||||
account.id, tenant.id, AppListParams(name="50%", mode="chat", page=1, limit=10)
|
||||
account.id, tenant.id, AppListParams(name="50%", mode="chat", page=1, limit=10), db_session_with_containers
|
||||
)
|
||||
assert paginated_apps is not None
|
||||
assert paginated_apps.total == 1
|
||||
@ -1476,7 +1492,10 @@ class TestAppService:
|
||||
|
||||
# Test 2: Search with _ character
|
||||
paginated_apps = app_service.get_paginate_apps(
|
||||
account.id, tenant.id, AppListParams(name="test_data", mode="chat", page=1, limit=10)
|
||||
account.id,
|
||||
tenant.id,
|
||||
AppListParams(name="test_data", mode="chat", page=1, limit=10),
|
||||
db_session_with_containers,
|
||||
)
|
||||
assert paginated_apps is not None
|
||||
assert paginated_apps.total == 1
|
||||
@ -1485,7 +1504,10 @@ class TestAppService:
|
||||
|
||||
# Test 3: Search with \ character
|
||||
paginated_apps = app_service.get_paginate_apps(
|
||||
account.id, tenant.id, AppListParams(name="path\\to\\app", mode="chat", page=1, limit=10)
|
||||
account.id,
|
||||
tenant.id,
|
||||
AppListParams(name="path\\to\\app", mode="chat", page=1, limit=10),
|
||||
db_session_with_containers,
|
||||
)
|
||||
assert paginated_apps is not None
|
||||
assert paginated_apps.total == 1
|
||||
@ -1494,7 +1516,7 @@ class TestAppService:
|
||||
|
||||
# Test 4: Search with % should NOT match 100% (verifies escaping works)
|
||||
paginated_apps = app_service.get_paginate_apps(
|
||||
account.id, tenant.id, AppListParams(name="50%", mode="chat", page=1, limit=10)
|
||||
account.id, tenant.id, AppListParams(name="50%", mode="chat", page=1, limit=10), db_session_with_containers
|
||||
)
|
||||
assert paginated_apps is not None
|
||||
assert paginated_apps.total == 1
|
||||
|
||||
@ -227,7 +227,7 @@ class TestDatasetServiceGetDatasets:
|
||||
)
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant.id)
|
||||
datasets, total = DatasetService.get_datasets(page, per_page, db_session_with_containers, tenant_id=tenant.id)
|
||||
|
||||
# Assert
|
||||
assert len(datasets) == 5
|
||||
@ -257,7 +257,9 @@ class TestDatasetServiceGetDatasets:
|
||||
)
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant.id, search=search)
|
||||
datasets, total = DatasetService.get_datasets(
|
||||
page, per_page, db_session_with_containers, tenant_id=tenant.id, search=search
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(datasets) == 1
|
||||
@ -301,7 +303,9 @@ class TestDatasetServiceGetDatasets:
|
||||
tag_ids = [tag_1.id, tag_2.id]
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant.id, tag_ids=tag_ids)
|
||||
datasets, total = DatasetService.get_datasets(
|
||||
page, per_page, db_session_with_containers, tenant_id=tenant.id, tag_ids=tag_ids
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(datasets) == 1
|
||||
@ -326,7 +330,9 @@ class TestDatasetServiceGetDatasets:
|
||||
)
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant.id, tag_ids=tag_ids)
|
||||
datasets, total = DatasetService.get_datasets(
|
||||
page, per_page, db_session_with_containers, tenant_id=tenant.id, tag_ids=tag_ids
|
||||
)
|
||||
|
||||
# Assert
|
||||
# When tag_ids is empty, tag filtering is skipped, so normal query results are returned
|
||||
@ -356,7 +362,9 @@ class TestDatasetServiceGetDatasets:
|
||||
)
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant.id, user=None)
|
||||
datasets, total = DatasetService.get_datasets(
|
||||
page, per_page, db_session_with_containers, tenant_id=tenant.id, user=None
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(datasets) == 1
|
||||
@ -384,6 +392,7 @@ class TestDatasetServiceGetDatasets:
|
||||
datasets, total = DatasetService.get_datasets(
|
||||
page=1,
|
||||
per_page=20,
|
||||
session=db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
user=owner,
|
||||
include_all=True,
|
||||
@ -408,7 +417,9 @@ class TestDatasetServiceGetDatasets:
|
||||
)
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant.id, user=user)
|
||||
datasets, total = DatasetService.get_datasets(
|
||||
page=1, per_page=20, session=db_session_with_containers, tenant_id=tenant.id, user=user
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(datasets) == 1
|
||||
@ -432,7 +443,9 @@ class TestDatasetServiceGetDatasets:
|
||||
)
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant.id, user=user)
|
||||
datasets, total = DatasetService.get_datasets(
|
||||
page=1, per_page=20, session=db_session_with_containers, tenant_id=tenant.id, user=user
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(datasets) == 1
|
||||
@ -459,7 +472,9 @@ class TestDatasetServiceGetDatasets:
|
||||
)
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant.id, user=user)
|
||||
datasets, total = DatasetService.get_datasets(
|
||||
page=1, per_page=20, session=db_session_with_containers, tenant_id=tenant.id, user=user
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(datasets) == 1
|
||||
@ -486,7 +501,9 @@ class TestDatasetServiceGetDatasets:
|
||||
)
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant.id, user=operator)
|
||||
datasets, total = DatasetService.get_datasets(
|
||||
page=1, per_page=20, session=db_session_with_containers, tenant_id=tenant.id, user=operator
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(datasets) == 1
|
||||
@ -509,7 +526,9 @@ class TestDatasetServiceGetDatasets:
|
||||
)
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant.id, user=operator)
|
||||
datasets, total = DatasetService.get_datasets(
|
||||
page=1, per_page=20, session=db_session_with_containers, tenant_id=tenant.id, user=operator
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert datasets == []
|
||||
|
||||
@ -449,7 +449,7 @@ class TestTagService:
|
||||
|
||||
# Act: Execute the method under test
|
||||
tag_ids = [tag.id for tag in tags]
|
||||
result = TagService.get_target_ids_by_tag_ids("knowledge", tenant.id, tag_ids)
|
||||
result = TagService.get_target_ids_by_tag_ids("knowledge", tenant.id, tag_ids, db_session_with_containers)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
@ -485,7 +485,7 @@ class TestTagService:
|
||||
)
|
||||
|
||||
# Act: Execute the method under test with empty tag IDs
|
||||
result = TagService.get_target_ids_by_tag_ids("knowledge", tenant.id, [])
|
||||
result = TagService.get_target_ids_by_tag_ids("knowledge", tenant.id, [], db_session_with_containers)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
@ -533,13 +533,19 @@ class TestTagService:
|
||||
|
||||
# Act: Execute the method under test
|
||||
tag_ids = [tag.id for tag in tags]
|
||||
result = TagService.get_target_ids_by_tag_ids("knowledge", tenant.id, tag_ids, match_all=True)
|
||||
result = TagService.get_target_ids_by_tag_ids(
|
||||
"knowledge", tenant.id, tag_ids, db_session_with_containers, match_all=True
|
||||
)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result == [dataset_with_all_tags.id]
|
||||
|
||||
missing_tag_result = TagService.get_target_ids_by_tag_ids(
|
||||
"knowledge", tenant.id, [tags[0].id, str(uuid.uuid4())], match_all=True
|
||||
"knowledge",
|
||||
tenant.id,
|
||||
[tags[0].id, str(uuid.uuid4())],
|
||||
db_session_with_containers,
|
||||
match_all=True,
|
||||
)
|
||||
assert missing_tag_result == []
|
||||
|
||||
@ -565,7 +571,9 @@ class TestTagService:
|
||||
non_existent_tag_ids = [str(uuid.uuid4()), str(uuid.uuid4())]
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = TagService.get_target_ids_by_tag_ids("knowledge", tenant.id, non_existent_tag_ids)
|
||||
result = TagService.get_target_ids_by_tag_ids(
|
||||
"knowledge", tenant.id, non_existent_tag_ids, db_session_with_containers
|
||||
)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
@ -599,7 +607,7 @@ class TestTagService:
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = TagService.get_tag_by_tag_name("app", tenant.id, "python_tag")
|
||||
result = TagService.get_tag_by_tag_name("app", tenant.id, "python_tag", db_session_with_containers)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
@ -625,7 +633,7 @@ class TestTagService:
|
||||
)
|
||||
|
||||
# Act: Execute the method under test with non-existent tag name
|
||||
result = TagService.get_tag_by_tag_name("knowledge", tenant.id, "nonexistent_tag")
|
||||
result = TagService.get_tag_by_tag_name("knowledge", tenant.id, "nonexistent_tag", db_session_with_containers)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
@ -650,8 +658,8 @@ class TestTagService:
|
||||
)
|
||||
|
||||
# Act: Execute the method under test with empty parameters
|
||||
result_empty_type = TagService.get_tag_by_tag_name("", tenant.id, "test_tag")
|
||||
result_empty_name = TagService.get_tag_by_tag_name("knowledge", tenant.id, "")
|
||||
result_empty_type = TagService.get_tag_by_tag_name("", tenant.id, "test_tag", db_session_with_containers)
|
||||
result_empty_name = TagService.get_tag_by_tag_name("knowledge", tenant.id, "", db_session_with_containers)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result_empty_type is not None
|
||||
@ -688,7 +696,7 @@ class TestTagService:
|
||||
)
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = TagService.get_tags_by_target_id("app", tenant.id, app.id)
|
||||
result = TagService.get_tags_by_target_id("app", tenant.id, app.id, db_session_with_containers)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
@ -720,7 +728,7 @@ class TestTagService:
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, tenant.id)
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = TagService.get_tags_by_target_id("app", tenant.id, app.id)
|
||||
result = TagService.get_tags_by_target_id("app", tenant.id, app.id, db_session_with_containers)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
@ -745,7 +753,7 @@ class TestTagService:
|
||||
tag_args = SaveTagPayload(name="test_tag_name", type="knowledge")
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = TagService.save_tags(tag_args)
|
||||
result = TagService.save_tags(tag_args, db_session_with_containers)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
@ -783,11 +791,11 @@ class TestTagService:
|
||||
|
||||
# Create first tag
|
||||
tag_args = SaveTagPayload(name="duplicate_tag", type="app")
|
||||
TagService.save_tags(tag_args)
|
||||
TagService.save_tags(tag_args, db_session_with_containers)
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
TagService.save_tags(tag_args)
|
||||
TagService.save_tags(tag_args, db_session_with_containers)
|
||||
assert "Tag name already exists" in str(exc_info.value)
|
||||
|
||||
def test_update_tags_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
@ -807,13 +815,13 @@ class TestTagService:
|
||||
|
||||
# Create a tag to update
|
||||
tag_args = SaveTagPayload(name="original_name", type="knowledge")
|
||||
tag = TagService.save_tags(tag_args)
|
||||
tag = TagService.save_tags(tag_args, db_session_with_containers)
|
||||
|
||||
# Update args
|
||||
update_args = UpdateTagPayload(name="updated_name")
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = TagService.update_tags(update_args, tag.id)
|
||||
result = TagService.update_tags(update_args, tag.id, db_session_with_containers)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
@ -854,7 +862,7 @@ class TestTagService:
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
with pytest.raises(NotFound) as exc_info:
|
||||
TagService.update_tags(update_args, non_existent_tag_id)
|
||||
TagService.update_tags(update_args, non_existent_tag_id, db_session_with_containers)
|
||||
assert "Tag not found" in str(exc_info.value)
|
||||
|
||||
def test_update_tags_duplicate_name_error(
|
||||
@ -875,17 +883,17 @@ class TestTagService:
|
||||
|
||||
# Create two tags
|
||||
tag1_args = SaveTagPayload(name="first_tag", type="app")
|
||||
tag1 = TagService.save_tags(tag1_args)
|
||||
tag1 = TagService.save_tags(tag1_args, db_session_with_containers)
|
||||
|
||||
tag2_args = SaveTagPayload(name="second_tag", type="app")
|
||||
tag2 = TagService.save_tags(tag2_args)
|
||||
tag2 = TagService.save_tags(tag2_args, db_session_with_containers)
|
||||
|
||||
# Try to update second tag with first tag's name
|
||||
update_args = UpdateTagPayload(name="first_tag")
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
TagService.update_tags(update_args, tag2.id)
|
||||
TagService.update_tags(update_args, tag2.id, db_session_with_containers)
|
||||
assert "Tag name already exists" in str(exc_info.value)
|
||||
|
||||
def test_get_tag_binding_count_success(
|
||||
@ -917,8 +925,8 @@ class TestTagService:
|
||||
)
|
||||
|
||||
# Act: Execute the method under test
|
||||
result_tag_with_bindings = TagService.get_tag_binding_count(tags[0].id)
|
||||
result_tag_without_bindings = TagService.get_tag_binding_count(tags[1].id)
|
||||
result_tag_with_bindings = TagService.get_tag_binding_count(tags[0].id, db_session_with_containers)
|
||||
result_tag_without_bindings = TagService.get_tag_binding_count(tags[1].id, db_session_with_containers)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result_tag_with_bindings == 1
|
||||
@ -946,7 +954,7 @@ class TestTagService:
|
||||
non_existent_tag_id = str(uuid.uuid4())
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = TagService.get_tag_binding_count(non_existent_tag_id)
|
||||
result = TagService.get_tag_binding_count(non_existent_tag_id, db_session_with_containers)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result == 0
|
||||
@ -986,7 +994,7 @@ class TestTagService:
|
||||
assert binding_before is not None
|
||||
|
||||
# Act: Execute the method under test
|
||||
TagService.delete_tag(tag.id)
|
||||
TagService.delete_tag(tag.id, db_session_with_containers)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
# Verify tag was deleted
|
||||
@ -1018,7 +1026,7 @@ class TestTagService:
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
with pytest.raises(NotFound) as exc_info:
|
||||
TagService.delete_tag(non_existent_tag_id)
|
||||
TagService.delete_tag(non_existent_tag_id, db_session_with_containers)
|
||||
assert "Tag not found" in str(exc_info.value)
|
||||
|
||||
def test_save_tag_binding_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
@ -1048,7 +1056,7 @@ class TestTagService:
|
||||
binding_payload = TagBindingCreatePayload(
|
||||
type="knowledge", target_id=dataset.id, tag_ids=[tag.id for tag in tags]
|
||||
)
|
||||
TagService.save_tag_binding(binding_payload)
|
||||
TagService.save_tag_binding(binding_payload, db_session_with_containers)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
|
||||
@ -1090,10 +1098,10 @@ class TestTagService:
|
||||
|
||||
# Create first binding
|
||||
binding_payload = TagBindingCreatePayload(type="app", target_id=app.id, tag_ids=[tag.id])
|
||||
TagService.save_tag_binding(binding_payload)
|
||||
TagService.save_tag_binding(binding_payload, db_session_with_containers)
|
||||
|
||||
# Act: Try to create duplicate binding
|
||||
TagService.save_tag_binding(binding_payload)
|
||||
TagService.save_tag_binding(binding_payload, db_session_with_containers)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
|
||||
@ -1173,7 +1181,7 @@ class TestTagService:
|
||||
delete_payload = TagBindingDeletePayload(
|
||||
type="knowledge", target_id=dataset.id, tag_ids=[tag.id for tag in tags]
|
||||
)
|
||||
TagService.delete_tag_binding(delete_payload)
|
||||
TagService.delete_tag_binding(delete_payload, db_session_with_containers)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
# Verify tag bindings were deleted
|
||||
@ -1209,7 +1217,7 @@ class TestTagService:
|
||||
|
||||
# Act: Try to delete non-existent binding
|
||||
delete_payload = TagBindingDeletePayload(type="app", target_id=app.id, tag_ids=[tag.id])
|
||||
TagService.delete_tag_binding(delete_payload)
|
||||
TagService.delete_tag_binding(delete_payload, db_session_with_containers)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
# No error should be raised, and database state should remain unchanged
|
||||
@ -1240,7 +1248,7 @@ class TestTagService:
|
||||
dataset = self._create_test_dataset(db_session_with_containers, mock_external_service_dependencies, tenant.id)
|
||||
|
||||
# Act: Execute the method under test
|
||||
TagService.check_target_exists("knowledge", dataset.id)
|
||||
TagService.check_target_exists("knowledge", dataset.id, db_session_with_containers)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
# No exception should be raised for existing dataset
|
||||
@ -1268,7 +1276,7 @@ class TestTagService:
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
with pytest.raises(NotFound) as exc_info:
|
||||
TagService.check_target_exists("knowledge", non_existent_dataset_id)
|
||||
TagService.check_target_exists("knowledge", non_existent_dataset_id, db_session_with_containers)
|
||||
assert "Dataset not found" in str(exc_info.value)
|
||||
|
||||
def test_check_target_exists_app_success(
|
||||
@ -1292,7 +1300,7 @@ class TestTagService:
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, tenant.id)
|
||||
|
||||
# Act: Execute the method under test
|
||||
TagService.check_target_exists("app", app.id)
|
||||
TagService.check_target_exists("app", app.id, db_session_with_containers)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
# No exception should be raised for existing app
|
||||
@ -1320,7 +1328,7 @@ class TestTagService:
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
with pytest.raises(NotFound) as exc_info:
|
||||
TagService.check_target_exists("app", non_existent_app_id)
|
||||
TagService.check_target_exists("app", non_existent_app_id, db_session_with_containers)
|
||||
assert "App not found" in str(exc_info.value)
|
||||
|
||||
def test_check_target_exists_invalid_type(
|
||||
@ -1346,5 +1354,5 @@ class TestTagService:
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
with pytest.raises(NotFound) as exc_info:
|
||||
TagService.check_target_exists("invalid_type", non_existent_target_id)
|
||||
TagService.check_target_exists("invalid_type", non_existent_target_id, db_session_with_containers)
|
||||
assert "Invalid binding type" in str(exc_info.value)
|
||||
|
||||
@ -266,6 +266,7 @@ def test_patch_union_schema_markdown_ignores_unrenderable_shapes(tmp_path):
|
||||
assert module._schema_ref_name(None) is None
|
||||
assert module._schema_markdown_type(None) == ""
|
||||
assert module._schema_markdown_type({"anyOf": [{"type": "null"}]}) == ""
|
||||
assert module._strip_trailing_line_whitespace("line \ncell\t \n") == "line\ncell\n"
|
||||
assert module._replace_schema_table_type("unchanged", "Definition", "field", "") == "unchanged"
|
||||
assert (
|
||||
module._replace_schema_table_type(
|
||||
@ -319,7 +320,10 @@ def test_convert_spec_to_markdown_patches_generated_union_tables(tmp_path, monke
|
||||
assert kwargs["check"] is False
|
||||
markdown_path = Path(args[args.index("-o") + 1])
|
||||
markdown_path.write_text(
|
||||
"""#### FormInputConfig
|
||||
"Intro line"
|
||||
+ " \n"
|
||||
+ """
|
||||
#### FormInputConfig
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
| ---- | ---- | ----------- | -------- |
|
||||
@ -340,5 +344,7 @@ def test_convert_spec_to_markdown_patches_generated_union_tables(tmp_path, monke
|
||||
module._convert_spec_to_markdown(spec_path, output_path)
|
||||
|
||||
converted = output_path.read_text(encoding="utf-8")
|
||||
assert "Intro line \n" not in converted
|
||||
assert "Intro line\n" in converted
|
||||
assert "| FormInputConfig | [ParagraphInputConfig](#paragraphinputconfig) | | |" in converted
|
||||
assert "| default | [StringSource](#stringsource) | | No |" in converted
|
||||
|
||||
@ -8,12 +8,13 @@ from pathlib import Path
|
||||
|
||||
def _walk_values(value):
|
||||
yield value
|
||||
if isinstance(value, dict):
|
||||
for child in value.values():
|
||||
yield from _walk_values(child)
|
||||
elif isinstance(value, list):
|
||||
for child in value:
|
||||
yield from _walk_values(child)
|
||||
match value:
|
||||
case dict():
|
||||
for child in value.values():
|
||||
yield from _walk_values(child)
|
||||
case list():
|
||||
for child in value:
|
||||
yield from _walk_values(child)
|
||||
|
||||
|
||||
def _load_generate_swagger_specs_module():
|
||||
@ -106,6 +107,39 @@ def test_generate_specs_writes_get_operations_without_request_bodies(tmp_path):
|
||||
assert all("requestBody" not in operation for operation in _get_operations(payload))
|
||||
|
||||
|
||||
def test_generate_specs_writes_service_api_reference_descriptions(tmp_path):
|
||||
module = _load_generate_swagger_specs_module()
|
||||
|
||||
written_paths = module.generate_specs(tmp_path)
|
||||
service_path = next(path for path in written_paths if path.name == "service-openapi.json")
|
||||
payload = json.loads(service_path.read_text(encoding="utf-8"))
|
||||
|
||||
chat_operation = payload["paths"]["/chat-messages"]["post"]
|
||||
assert chat_operation["summary"] == "Send Chat Message"
|
||||
assert chat_operation["description"] == "Send a request to the chat application."
|
||||
assert chat_operation["tags"] == ["Chatflows", "Chats"]
|
||||
|
||||
rename_operation = payload["paths"]["/conversations/{c_id}/name"]["post"]
|
||||
assert rename_operation["summary"] == "Rename Conversation"
|
||||
|
||||
|
||||
def test_standalone_inline_model_name_includes_list_constraints():
|
||||
module = _load_generate_swagger_specs_module()
|
||||
|
||||
from flask_restx import fields
|
||||
|
||||
cases = (
|
||||
({"min_items": 1}, {"min_items": 2}),
|
||||
({"max_items": 1}, {"max_items": 2}),
|
||||
({"unique": True}, {"unique": False}),
|
||||
)
|
||||
for first_kwargs, second_kwargs in cases:
|
||||
first_inline_model = {"items": fields.List(fields.String, **first_kwargs)}
|
||||
second_inline_model = {"items": fields.List(fields.String, **second_kwargs)}
|
||||
|
||||
assert module._inline_model_name(first_inline_model) != module._inline_model_name(second_inline_model)
|
||||
|
||||
|
||||
def test_generate_specs_is_idempotent(tmp_path):
|
||||
module = _load_generate_swagger_specs_module()
|
||||
|
||||
|
||||
@ -24,7 +24,9 @@ from controllers.console.agent.roster import (
|
||||
AgentAppCopyApi,
|
||||
AgentAppListApi,
|
||||
AgentInviteOptionsApi,
|
||||
AgentLogMessagesApi,
|
||||
AgentLogsApi,
|
||||
AgentLogSourcesApi,
|
||||
AgentRosterVersionDetailApi,
|
||||
AgentRosterVersionsApi,
|
||||
AgentStatisticsSummaryApi,
|
||||
@ -93,6 +95,7 @@ def _agent_app_composer_response() -> dict:
|
||||
def _app_detail_obj(**overrides):
|
||||
data = {
|
||||
"id": "app-1",
|
||||
"tenant_id": "tenant-1",
|
||||
"name": "Iris",
|
||||
"description": "Agent app",
|
||||
"mode_compatible_with_agent": "agent",
|
||||
@ -116,7 +119,6 @@ def _app_detail_obj(**overrides):
|
||||
"deleted_tools": [],
|
||||
"site": None,
|
||||
"bound_agent_id": "00000000-0000-0000-0000-000000000001",
|
||||
"tenant_id": "tenant-1",
|
||||
}
|
||||
data.update(overrides)
|
||||
return SimpleNamespace(**data)
|
||||
@ -153,6 +155,8 @@ def test_agent_v2_console_routes_are_agent_id_first() -> None:
|
||||
"/agent/<uuid:agent_id>/chat-messages/<uuid:message_id>/suggested-questions",
|
||||
"/agent/<uuid:agent_id>/messages/<uuid:message_id>",
|
||||
"/agent/<uuid:agent_id>/logs",
|
||||
"/agent/<uuid:agent_id>/logs/<uuid:conversation_id>/messages",
|
||||
"/agent/<uuid:agent_id>/log-sources",
|
||||
"/agent/<uuid:agent_id>/statistics/summary",
|
||||
"/agent/invite-options",
|
||||
):
|
||||
@ -187,7 +191,7 @@ def test_agent_app_list_and_create_use_agent_route(
|
||||
def get_app(self, app_obj: object) -> object:
|
||||
return app_obj
|
||||
|
||||
def get_paginate_apps(self, user_id: str, tenant_id: str, params) -> object:
|
||||
def get_paginate_apps(self, user_id: str, tenant_id: str, params, session) -> object:
|
||||
captured["list"] = {"user_id": user_id, "tenant_id": tenant_id, "params": params}
|
||||
return SimpleNamespace(
|
||||
page=1,
|
||||
@ -270,7 +274,13 @@ def test_agent_app_list_and_create_use_agent_route(
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/agent",
|
||||
json={"name": "Iris", "description": "Agent app", "icon_type": "emoji", "icon": "robot"},
|
||||
json={
|
||||
"name": "Iris",
|
||||
"description": "Agent app",
|
||||
"role": "Coordinator",
|
||||
"icon_type": "emoji",
|
||||
"icon": "robot",
|
||||
},
|
||||
):
|
||||
created, status = unwrap(AgentAppListApi.post)(AgentAppListApi(), "tenant-1", SimpleNamespace(id=account_id))
|
||||
|
||||
@ -283,6 +293,23 @@ def test_agent_app_list_and_create_use_agent_route(
|
||||
create_call = cast(dict[str, object], captured["create"])
|
||||
create_params = cast(Any, create_call["params"])
|
||||
assert create_params.mode == "agent"
|
||||
assert create_params.agent_role == "Coordinator"
|
||||
|
||||
|
||||
def test_agent_app_create_requires_role(app: Flask, account_id: str) -> None:
|
||||
with app.test_request_context(
|
||||
"/console/api/agent",
|
||||
json={"name": "Iris", "description": "Agent app", "icon_type": "emoji", "icon": "robot"},
|
||||
):
|
||||
with pytest.raises(ValueError, match="Field required"):
|
||||
unwrap(AgentAppListApi.post)(AgentAppListApi(), "tenant-1", SimpleNamespace(id=account_id))
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/agent",
|
||||
json={"name": "Iris", "description": "Agent app", "role": " ", "icon_type": "emoji", "icon": "robot"},
|
||||
):
|
||||
with pytest.raises(ValueError, match="Agent role is required"):
|
||||
unwrap(AgentAppListApi.post)(AgentAppListApi(), "tenant-1", SimpleNamespace(id=account_id))
|
||||
|
||||
|
||||
def test_agent_app_detail_update_delete_resolve_app_from_agent_id(
|
||||
@ -331,7 +358,7 @@ def test_agent_app_detail_update_delete_resolve_app_from_agent_id(
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/agent/00000000-0000-0000-0000-000000000001",
|
||||
json={"name": "Renamed", "description": "", "icon_type": "emoji", "icon": "R"},
|
||||
json={"name": "Renamed", "description": "", "role": "Reviewer", "icon_type": "emoji", "icon": "R"},
|
||||
):
|
||||
updated = unwrap(AgentAppApi.put)(AgentAppApi(), "tenant-1", agent_id)
|
||||
|
||||
@ -343,6 +370,7 @@ def test_agent_app_detail_update_delete_resolve_app_from_agent_id(
|
||||
assert "bound_agent_id" not in updated
|
||||
update_call = cast(dict[str, object], captured["update"])
|
||||
assert update_call["app"] is app_model
|
||||
assert cast(dict[str, object], update_call["args"])["role"] == "Reviewer"
|
||||
|
||||
deleted, status = unwrap(AgentAppApi.delete)(AgentAppApi(), "tenant-1", agent_id)
|
||||
assert (deleted, status) == ("", 204)
|
||||
@ -395,6 +423,45 @@ def test_agent_app_copy_uses_agent_id_and_returns_agent_detail(
|
||||
}
|
||||
|
||||
|
||||
def test_agent_app_update_rejects_empty_role(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
agent_id = "00000000-0000-0000-0000-000000000001"
|
||||
app_model = _app_detail_obj(id="app-1", bound_agent_id=agent_id)
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
monkeypatch.setattr(
|
||||
roster_controller.AgentRosterService,
|
||||
"get_agent_app_model",
|
||||
lambda _self, **kwargs: app_model,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
roster_controller.AgentRosterService,
|
||||
"get_app_backing_agent",
|
||||
lambda _self, **kwargs: SimpleNamespace(id=agent_id, role="", active_config_snapshot_id=None),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
roster_controller.FeatureService,
|
||||
"get_system_features",
|
||||
lambda: SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)),
|
||||
)
|
||||
|
||||
class FakeAppService:
|
||||
def get_app(self, app_obj: object) -> object:
|
||||
return app_obj
|
||||
|
||||
def update_app(self, app_obj: object, args: dict[str, object]) -> object:
|
||||
captured["update"] = {"app": app_obj, "args": args}
|
||||
return _app_detail_obj(id="app-1", name=args["name"], bound_agent_id=agent_id)
|
||||
|
||||
monkeypatch.setattr(roster_controller, "AppService", FakeAppService)
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/agent/00000000-0000-0000-0000-000000000001",
|
||||
json={"name": "Renamed", "description": "", "role": "", "icon_type": "emoji", "icon": "R"},
|
||||
):
|
||||
with pytest.raises(ValueError, match="String should have at least 1 character"):
|
||||
unwrap(AgentAppApi.put)(AgentAppApi(), "tenant-1", agent_id)
|
||||
|
||||
|
||||
def test_invite_options_get_parses_app_id(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
@ -461,21 +528,59 @@ def test_agent_observability_routes_resolve_app_from_agent_id(
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
class FakeObservabilityService:
|
||||
def list_logs(self, *, app, params):
|
||||
captured["logs"] = {"app": app, "params": params}
|
||||
def list_logs(self, *, app, agent_id, params):
|
||||
captured["logs"] = {"app": app, "agent_id": agent_id, "params": params}
|
||||
return {
|
||||
"data": [
|
||||
{
|
||||
"conversation_id": "conversation-1",
|
||||
"id": "conversation-1",
|
||||
"title": "Debug",
|
||||
"end_user_id": "end-user-1",
|
||||
"message_count": 2,
|
||||
"user_rate": None,
|
||||
"operation_rate": None,
|
||||
"unread": True,
|
||||
"source": {
|
||||
"id": "webapp:app-1",
|
||||
"type": "webapp",
|
||||
"app_id": "app-1",
|
||||
"app_name": "Iris",
|
||||
"app_icon_type": "emoji",
|
||||
"app_icon": "robot",
|
||||
"app_icon_background": "#fff",
|
||||
"workflow_id": None,
|
||||
"workflow_version": None,
|
||||
"node_id": None,
|
||||
},
|
||||
"status": "success",
|
||||
"created_at": 1,
|
||||
"updated_at": 2,
|
||||
}
|
||||
],
|
||||
"page": 2,
|
||||
"limit": 5,
|
||||
"total": 6,
|
||||
"has_more": False,
|
||||
}
|
||||
|
||||
def list_log_messages(self, *, app, agent_id, conversation_id, params):
|
||||
captured["messages"] = {
|
||||
"app": app,
|
||||
"agent_id": agent_id,
|
||||
"conversation_id": conversation_id,
|
||||
"params": params,
|
||||
}
|
||||
return {
|
||||
"data": [
|
||||
{
|
||||
"id": "message-1",
|
||||
"message_id": "message-1",
|
||||
"conversation_id": "conversation-1",
|
||||
"conversation_name": "Debug",
|
||||
"query": "hello",
|
||||
"answer": "hi",
|
||||
"status": "success",
|
||||
"error": None,
|
||||
"source": "explore",
|
||||
"from_source": "console",
|
||||
"from_end_user_id": None,
|
||||
"from_account_id": account_id,
|
||||
"message_tokens": 1,
|
||||
@ -488,14 +593,34 @@ def test_agent_observability_routes_resolve_app_from_agent_id(
|
||||
"updated_at": 2,
|
||||
}
|
||||
],
|
||||
"page": 2,
|
||||
"limit": 5,
|
||||
"total": 6,
|
||||
"page": 1,
|
||||
"limit": 20,
|
||||
"total": 1,
|
||||
"has_more": False,
|
||||
}
|
||||
|
||||
def get_statistics_summary(self, *, app, params):
|
||||
captured["statistics"] = {"app": app, "params": params}
|
||||
def list_log_sources(self, *, app, agent_id):
|
||||
captured["sources"] = {"app": app, "agent_id": agent_id}
|
||||
return {
|
||||
"data": [
|
||||
{
|
||||
"id": "webapp:app-1",
|
||||
"type": "webapp",
|
||||
"app_id": "app-1",
|
||||
"app_name": "Iris",
|
||||
"app_icon_type": "emoji",
|
||||
"app_icon": "robot",
|
||||
"app_icon_background": "#fff",
|
||||
"workflow_id": None,
|
||||
"workflow_version": None,
|
||||
"node_id": None,
|
||||
}
|
||||
],
|
||||
"groups": [{"type": "webapp", "label": "WEBAPP", "sources": []}],
|
||||
}
|
||||
|
||||
def get_statistics_summary(self, *, app, agent_id, params):
|
||||
captured["statistics"] = {"app": app, "agent_id": agent_id, "params": params}
|
||||
return {
|
||||
"source": "all",
|
||||
"summary": {
|
||||
@ -532,9 +657,11 @@ def test_agent_observability_routes_resolve_app_from_agent_id(
|
||||
):
|
||||
logs = unwrap(AgentLogsApi.get)(AgentLogsApi(), "tenant-1", account, agent_id)
|
||||
|
||||
assert logs["data"][0]["id"] == "message-1"
|
||||
assert logs["data"][0]["id"] == "conversation-1"
|
||||
assert logs["data"][0]["source"]["id"] == "webapp:app-1"
|
||||
logs_call = cast(dict[str, object], captured["logs"])
|
||||
assert logs_call["app"] is app_model
|
||||
assert logs_call["agent_id"] == agent_id
|
||||
logs_params = cast(Any, logs_call["params"])
|
||||
assert logs_params.page == 2
|
||||
assert logs_params.limit == 5
|
||||
@ -542,6 +669,31 @@ def test_agent_observability_routes_resolve_app_from_agent_id(
|
||||
assert logs_params.status == "success"
|
||||
assert logs_params.source == "console"
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/agent/00000000-0000-0000-0000-000000000001/logs/00000000-0000-0000-0000-000000000002/messages"
|
||||
):
|
||||
messages = unwrap(AgentLogMessagesApi.get)(
|
||||
AgentLogMessagesApi(),
|
||||
"tenant-1",
|
||||
account,
|
||||
agent_id,
|
||||
"00000000-0000-0000-0000-000000000002",
|
||||
)
|
||||
|
||||
assert messages["data"][0]["id"] == "message-1"
|
||||
messages_call = cast(dict[str, object], captured["messages"])
|
||||
assert messages_call["app"] is app_model
|
||||
assert messages_call["agent_id"] == agent_id
|
||||
assert messages_call["conversation_id"] == "00000000-0000-0000-0000-000000000002"
|
||||
|
||||
with app.test_request_context("/console/api/agent/00000000-0000-0000-0000-000000000001/log-sources"):
|
||||
sources = unwrap(AgentLogSourcesApi.get)(AgentLogSourcesApi(), "tenant-1", account, agent_id)
|
||||
|
||||
assert sources["data"][0]["id"] == "webapp:app-1"
|
||||
sources_call = cast(dict[str, object], captured["sources"])
|
||||
assert sources_call["app"] is app_model
|
||||
assert sources_call["agent_id"] == agent_id
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/agent/00000000-0000-0000-0000-000000000001/statistics/summary?source=api"
|
||||
):
|
||||
@ -550,6 +702,7 @@ def test_agent_observability_routes_resolve_app_from_agent_id(
|
||||
assert statistics["summary"]["total_messages"] == 1
|
||||
stats_call = cast(dict[str, object], captured["statistics"])
|
||||
assert stats_call["app"] is app_model
|
||||
assert stats_call["agent_id"] == agent_id
|
||||
stats_params = cast(Any, stats_call["params"])
|
||||
assert stats_params.source == "api"
|
||||
assert stats_params.timezone == "UTC"
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user