This commit is contained in:
chariri 2026-06-25 18:39:29 +00:00 committed by GitHub
commit c1262bcc1c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
26 changed files with 3426 additions and 1173 deletions

View File

@ -167,12 +167,16 @@ register_schema_models(
ChatMessagesQuery,
MessageFeedbackPayload,
FeedbackExportQuery,
)
register_response_schema_models(
console_ns,
AnnotationCountResponse,
SuggestedQuestionsResponse,
MessageDetailResponse,
MessageInfiniteScrollPaginationResponse,
SimpleResultResponse,
TextFileResponse,
)
register_response_schema_models(console_ns, SimpleResultResponse, TextFileResponse)
@console_ns.route("/apps/<uuid:app_id>/chat-messages")

File diff suppressed because it is too large Load Diff

View File

@ -9,14 +9,21 @@ from werkzeug.exceptions import BadRequest, Forbidden
from configs import dify_config
from controllers.common.errors import NotFoundError
from controllers.common.fields import BinaryFileResponse, RedirectResponse, SimpleResultResponse
from controllers.common.fields import SimpleResultResponse
from controllers.common.schema import register_response_schema_models, register_schema_models
from core.entities.provider_entities import ProviderConfig
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.oauth import OAuthHandler
from core.trigger.entities.entities import SubscriptionBuilderUpdater
from core.trigger.entities.api_entities import (
SubscriptionBuilderApiEntity,
TriggerProviderApiEntity,
TriggerProviderSubscriptionApiEntity,
)
from core.trigger.entities.entities import RequestLog, SubscriptionBuilderUpdater
from core.trigger.trigger_manager import TriggerManager
from extensions.ext_database import db
from graphon.model_runtime.utils.encoders import jsonable_encoder
from fields.base import ResponseModel
from libs.helper import dump_response
from libs.login import login_required
from models.account import Account
from models.provider_ids import TriggerProviderID
@ -51,9 +58,9 @@ class TriggerSubscriptionBuilderVerifyPayload(BaseModel):
class TriggerSubscriptionBuilderUpdatePayload(BaseModel):
name: str | None = None
parameters: dict[str, Any] | None = Field(default=None)
properties: dict[str, Any] | None = Field(default=None)
credentials: dict[str, Any] | None = Field(default=None)
parameters: dict[str, Any] | None = None
properties: dict[str, Any] | None = None
credentials: dict[str, Any] | None = None
@model_validator(mode="after")
def check_at_least_one_field(self):
@ -63,28 +70,48 @@ class TriggerSubscriptionBuilderUpdatePayload(BaseModel):
class TriggerOAuthClientPayload(BaseModel):
client_params: dict[str, Any] | None = Field(default=None)
client_params: dict[str, Any] | None = None
enabled: bool | None = None
class TriggerOAuthAuthorizeResponse(BaseModel):
class TriggerProviderListResponse(RootModel[list[TriggerProviderApiEntity]]):
pass
class TriggerProviderSubscriptionListResponse(RootModel[list[TriggerProviderSubscriptionApiEntity]]):
pass
class TriggerSubscriptionBuilderCreateResponse(ResponseModel):
subscription_builder: SubscriptionBuilderApiEntity
class TriggerVerificationResponse(ResponseModel):
verified: bool
class TriggerSubscriptionBuilderLogsResponse(ResponseModel):
logs: list[RequestLog]
class TriggerOAuthAuthorizeResponse(ResponseModel):
authorization_url: str
subscription_builder_id: str
subscription_builder: Any
subscription_builder: SubscriptionBuilderApiEntity
class TriggerOAuthClientResponse(BaseModel):
class TriggerOAuthClientResponse(ResponseModel):
configured: bool
system_configured: bool
custom_configured: bool
oauth_client_schema: Any
oauth_client_schema: list[ProviderConfig]
custom_enabled: bool
redirect_uri: str
params: dict[str, Any]
params: dict[str, Any] = Field(default_factory=dict)
class TriggerProviderOpaqueResponse(RootModel[Any]):
root: Any
class TriggerProviderErrorResponse(ResponseModel):
error: str
register_schema_models(
@ -96,18 +123,24 @@ register_schema_models(
)
register_response_schema_models(
console_ns,
BinaryFileResponse,
RedirectResponse,
SimpleResultResponse,
TriggerOAuthAuthorizeResponse,
TriggerOAuthClientResponse,
TriggerProviderOpaqueResponse,
TriggerProviderApiEntity,
TriggerProviderErrorResponse,
TriggerProviderListResponse,
TriggerProviderSubscriptionListResponse,
TriggerSubscriptionBuilderCreateResponse,
TriggerSubscriptionBuilderLogsResponse,
SubscriptionBuilderApiEntity,
TriggerVerificationResponse,
)
@console_ns.route("/workspaces/current/trigger-provider/<path:provider>/icon")
class TriggerProviderIconApi(Resource):
@console_ns.response(200, "Success", console_ns.models[BinaryFileResponse.__name__])
# response-contract:ignore binary trigger provider icon
@console_ns.response(200, "Trigger provider icon")
@setup_required
@login_required
@account_initialization_required
@ -118,31 +151,45 @@ class TriggerProviderIconApi(Resource):
@console_ns.route("/workspaces/current/triggers")
class TriggerProviderListApi(Resource):
@console_ns.response(200, "Success", console_ns.models[TriggerProviderOpaqueResponse.__name__])
@console_ns.response(
200,
"Trigger providers retrieved successfully",
console_ns.models[TriggerProviderListResponse.__name__],
)
@setup_required
@login_required
@account_initialization_required
@with_current_tenant_id
def get(self, tenant_id: str):
"""List all trigger providers for the current tenant"""
return jsonable_encoder(TriggerProviderService.list_trigger_providers(tenant_id))
return dump_response(TriggerProviderListResponse, TriggerProviderService.list_trigger_providers(tenant_id))
@console_ns.route("/workspaces/current/trigger-provider/<path:provider>/info")
class TriggerProviderInfoApi(Resource):
@console_ns.response(200, "Success", console_ns.models[TriggerProviderOpaqueResponse.__name__])
@console_ns.response(
200,
"Trigger provider retrieved successfully",
console_ns.models[TriggerProviderApiEntity.__name__],
)
@setup_required
@login_required
@account_initialization_required
@with_current_tenant_id
def get(self, tenant_id: str, provider: str):
"""Get info for a trigger provider"""
return jsonable_encoder(TriggerProviderService.get_trigger_provider(tenant_id, TriggerProviderID(provider)))
provider_entity = TriggerProviderService.get_trigger_provider(tenant_id, TriggerProviderID(provider))
return provider_entity.model_dump(mode="json")
@console_ns.route("/workspaces/current/trigger-provider/<path:provider>/subscriptions/list")
class TriggerSubscriptionListApi(Resource):
@console_ns.response(200, "Success", console_ns.models[TriggerProviderOpaqueResponse.__name__])
@console_ns.response(
200,
"Trigger subscriptions retrieved successfully",
console_ns.models[TriggerProviderSubscriptionListResponse.__name__],
)
@console_ns.response(404, "Trigger provider not found", console_ns.models[TriggerProviderErrorResponse.__name__])
@setup_required
@login_required
@edit_permission_required
@ -152,16 +199,18 @@ class TriggerSubscriptionListApi(Resource):
@with_current_tenant_id
def get(self, tenant_id: str, user: Account, provider: str):
"""List all trigger subscriptions for the current tenant's provider"""
try:
return jsonable_encoder(
return dump_response(
TriggerProviderSubscriptionListResponse,
TriggerProviderService.list_trigger_provider_subscriptions(
tenant_id=tenant_id,
provider_id=TriggerProviderID(provider),
user=user,
)
),
)
except ValueError as e:
return jsonable_encoder({"error": str(e)}), 404
return TriggerProviderErrorResponse(error=str(e)).model_dump(mode="json"), 404
except Exception as e:
logger.exception("Error listing trigger providers", exc_info=e)
raise
@ -172,7 +221,11 @@ class TriggerSubscriptionListApi(Resource):
)
class TriggerSubscriptionBuilderCreateApi(Resource):
@console_ns.expect(console_ns.models[TriggerSubscriptionBuilderCreatePayload.__name__])
@console_ns.response(200, "Success", console_ns.models[TriggerProviderOpaqueResponse.__name__])
@console_ns.response(
200,
"Trigger subscription builder created successfully",
console_ns.models[TriggerSubscriptionBuilderCreateResponse.__name__],
)
@setup_required
@login_required
@edit_permission_required
@ -182,6 +235,7 @@ class TriggerSubscriptionBuilderCreateApi(Resource):
@with_current_tenant_id
def post(self, tenant_id: str, user: Account, provider: str):
"""Add a new subscription instance for a trigger provider"""
payload = TriggerSubscriptionBuilderCreatePayload.model_validate(console_ns.payload or {})
try:
@ -192,7 +246,9 @@ class TriggerSubscriptionBuilderCreateApi(Resource):
provider_id=TriggerProviderID(provider),
credential_type=credential_type,
)
return jsonable_encoder({"subscription_builder": subscription_builder})
return TriggerSubscriptionBuilderCreateResponse(subscription_builder=subscription_builder).model_dump(
mode="json"
)
except Exception as e:
logger.exception("Error adding provider credential", exc_info=e)
raise
@ -202,7 +258,11 @@ class TriggerSubscriptionBuilderCreateApi(Resource):
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/<path:subscription_builder_id>",
)
class TriggerSubscriptionBuilderGetApi(Resource):
@console_ns.response(200, "Success", console_ns.models[TriggerProviderOpaqueResponse.__name__])
@console_ns.response(
200,
"Trigger subscription builder retrieved successfully",
console_ns.models[SubscriptionBuilderApiEntity.__name__],
)
@setup_required
@login_required
@edit_permission_required
@ -210,9 +270,8 @@ class TriggerSubscriptionBuilderGetApi(Resource):
@account_initialization_required
def get(self, provider: str, subscription_builder_id: str):
"""Get a subscription instance for a trigger provider"""
return jsonable_encoder(
TriggerSubscriptionBuilderService.get_subscription_builder_by_id(subscription_builder_id)
)
subscription_builder = TriggerSubscriptionBuilderService.get_subscription_builder_by_id(subscription_builder_id)
return subscription_builder.model_dump(mode="json")
@console_ns.route(
@ -220,7 +279,11 @@ class TriggerSubscriptionBuilderGetApi(Resource):
)
class TriggerSubscriptionBuilderVerifyApi(Resource):
@console_ns.expect(console_ns.models[TriggerSubscriptionBuilderVerifyPayload.__name__])
@console_ns.response(200, "Success", console_ns.models[TriggerProviderOpaqueResponse.__name__])
@console_ns.response(
200,
"Trigger subscription builder verified successfully",
console_ns.models[TriggerVerificationResponse.__name__],
)
@setup_required
@login_required
@edit_permission_required
@ -230,11 +293,12 @@ class TriggerSubscriptionBuilderVerifyApi(Resource):
@with_current_tenant_id
def post(self, tenant_id: str, user: Account, provider: str, subscription_builder_id: str):
"""Verify and update a subscription instance for a trigger provider"""
payload = TriggerSubscriptionBuilderVerifyPayload.model_validate(console_ns.payload or {})
try:
# Use atomic update_and_verify to prevent race conditions
return TriggerSubscriptionBuilderService.update_and_verify_builder(
result = TriggerSubscriptionBuilderService.update_and_verify_builder(
tenant_id=tenant_id,
user_id=user.id,
provider_id=TriggerProviderID(provider),
@ -243,6 +307,7 @@ class TriggerSubscriptionBuilderVerifyApi(Resource):
credentials=payload.credentials,
),
)
return dump_response(TriggerVerificationResponse, result)
except Exception as e:
logger.exception("Error verifying provider credential", exc_info=e)
raise ValueError(str(e)) from e
@ -253,7 +318,11 @@ class TriggerSubscriptionBuilderVerifyApi(Resource):
)
class TriggerSubscriptionBuilderUpdateApi(Resource):
@console_ns.expect(console_ns.models[TriggerSubscriptionBuilderUpdatePayload.__name__])
@console_ns.response(200, "Success", console_ns.models[TriggerProviderOpaqueResponse.__name__])
@console_ns.response(
200,
"Trigger subscription builder updated successfully",
console_ns.models[SubscriptionBuilderApiEntity.__name__],
)
@setup_required
@login_required
@edit_permission_required
@ -262,21 +331,20 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
@with_current_tenant_id
def post(self, tenant_id: str, provider: str, subscription_builder_id: str):
"""Update a subscription instance for a trigger provider"""
payload = TriggerSubscriptionBuilderUpdatePayload.model_validate(console_ns.payload or {})
try:
return jsonable_encoder(
TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
tenant_id=tenant_id,
provider_id=TriggerProviderID(provider),
subscription_builder_id=subscription_builder_id,
subscription_builder_updater=SubscriptionBuilderUpdater(
name=payload.name,
parameters=payload.parameters,
properties=payload.properties,
credentials=payload.credentials,
),
)
)
return TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
tenant_id=tenant_id,
provider_id=TriggerProviderID(provider),
subscription_builder_id=subscription_builder_id,
subscription_builder_updater=SubscriptionBuilderUpdater(
name=payload.name,
parameters=payload.parameters,
properties=payload.properties,
credentials=payload.credentials,
),
).model_dump(mode="json")
except Exception as e:
logger.exception("Error updating provider credential", exc_info=e)
raise
@ -286,7 +354,11 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/logs/<path:subscription_builder_id>",
)
class TriggerSubscriptionBuilderLogsApi(Resource):
@console_ns.response(200, "Success", console_ns.models[TriggerProviderOpaqueResponse.__name__])
@console_ns.response(
200,
"Trigger subscription builder logs retrieved successfully",
console_ns.models[TriggerSubscriptionBuilderLogsResponse.__name__],
)
@setup_required
@login_required
@edit_permission_required
@ -294,9 +366,10 @@ class TriggerSubscriptionBuilderLogsApi(Resource):
@account_initialization_required
def get(self, provider: str, subscription_builder_id: str):
"""Get the request logs for a subscription instance for a trigger provider"""
try:
logs = TriggerSubscriptionBuilderService.list_logs(subscription_builder_id)
return jsonable_encoder({"logs": [log.model_dump(mode="json") for log in logs]})
return dump_response(TriggerSubscriptionBuilderLogsResponse, {"logs": logs})
except Exception as e:
logger.exception("Error getting request logs for subscription builder", exc_info=e)
raise
@ -307,7 +380,9 @@ class TriggerSubscriptionBuilderLogsApi(Resource):
)
class TriggerSubscriptionBuilderBuildApi(Resource):
@console_ns.expect(console_ns.models[TriggerSubscriptionBuilderUpdatePayload.__name__])
@console_ns.response(200, "Success", console_ns.models[TriggerProviderOpaqueResponse.__name__])
@console_ns.response(
200, "Trigger subscription builder built successfully", console_ns.models[SimpleResultResponse.__name__]
)
@setup_required
@login_required
@edit_permission_required
@ -331,7 +406,7 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
properties=payload.properties,
),
)
return 200
return SimpleResultResponse(result="success").model_dump(mode="json")
except Exception as e:
logger.exception("Error building provider credential", exc_info=e)
raise ValueError(str(e)) from e
@ -342,7 +417,9 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
)
class TriggerSubscriptionUpdateApi(Resource):
@console_ns.expect(console_ns.models[TriggerSubscriptionBuilderUpdatePayload.__name__])
@console_ns.response(200, "Success", console_ns.models[TriggerProviderOpaqueResponse.__name__])
@console_ns.response(
200, "Trigger subscription updated successfully", console_ns.models[SimpleResultResponse.__name__]
)
@setup_required
@login_required
@edit_permission_required
@ -351,6 +428,7 @@ class TriggerSubscriptionUpdateApi(Resource):
@with_current_tenant_id
def post(self, tenant_id: str, subscription_id: str):
"""Update a subscription instance"""
request = TriggerSubscriptionBuilderUpdatePayload.model_validate(console_ns.payload or {})
subscription = TriggerProviderService.get_subscription_by_id(
@ -376,7 +454,7 @@ class TriggerSubscriptionUpdateApi(Resource):
name=request.name,
properties=request.properties,
)
return 200
return SimpleResultResponse(result="success").model_dump(mode="json")
# For the rest cases(API_KEY, OAUTH2)
# we need to call third party provider(e.g. GitHub) to rebuild the subscription
@ -388,7 +466,7 @@ class TriggerSubscriptionUpdateApi(Resource):
credentials=request.credentials or subscription.credentials,
parameters=request.parameters or subscription.parameters,
)
return 200
return SimpleResultResponse(result="success").model_dump(mode="json")
except ValueError as e:
raise BadRequest(str(e))
except Exception as e:
@ -409,6 +487,7 @@ class TriggerSubscriptionDeleteApi(Resource):
@with_current_tenant_id
def post(self, tenant_id: str, subscription_id: str):
"""Delete a subscription instance"""
try:
with sessionmaker(db.engine).begin() as session:
# Delete trigger provider subscription
@ -423,7 +502,7 @@ class TriggerSubscriptionDeleteApi(Resource):
tenant_id=tenant_id,
subscription_id=subscription_id,
)
return {"result": "success"}
return SimpleResultResponse(result="success").model_dump(mode="json")
except ValueError as e:
raise BadRequest(str(e))
except Exception as e:
@ -433,9 +512,10 @@ class TriggerSubscriptionDeleteApi(Resource):
@console_ns.route("/workspaces/current/trigger-provider/<path:provider>/subscriptions/oauth/authorize")
class TriggerOAuthAuthorizeApi(Resource):
# response-contract:ignore cookie-bearing Flask response
@console_ns.response(
200,
"Authorization URL retrieved successfully",
"Trigger OAuth authorization URL generated successfully",
console_ns.models[TriggerOAuthAuthorizeResponse.__name__],
)
@setup_required
@ -445,10 +525,12 @@ class TriggerOAuthAuthorizeApi(Resource):
@with_current_tenant_id
def get(self, tenant_id: str, user: Account, provider: str):
"""Initiate OAuth authorization flow for a trigger provider"""
try:
provider_id = TriggerProviderID(provider)
plugin_id = provider_id.plugin_id
provider_name = provider_id.provider_name
tenant_id = tenant_id
# Get OAuth client configuration
oauth_client_params = TriggerProviderService.get_oauth_client(
@ -492,15 +574,12 @@ class TriggerOAuthAuthorizeApi(Resource):
system_credentials=oauth_client_params,
)
# Create response with cookie
response = make_response(
jsonable_encoder(
{
"authorization_url": authorization_url_response.authorization_url,
"subscription_builder_id": subscription_builder.id,
"subscription_builder": subscription_builder,
}
)
TriggerOAuthAuthorizeResponse(
authorization_url=authorization_url_response.authorization_url,
subscription_builder_id=subscription_builder.id,
subscription_builder=subscription_builder,
).model_dump(mode="json")
)
response.set_cookie(
"context_id",
@ -519,11 +598,8 @@ class TriggerOAuthAuthorizeApi(Resource):
@console_ns.route("/oauth/plugin/<path:provider>/trigger/callback")
class TriggerOAuthCallbackApi(Resource):
@console_ns.response(
302,
"Redirect to console OAuth callback page",
console_ns.models[RedirectResponse.__name__],
)
# response-contract:ignore redirect response
@console_ns.response(302, "Redirect to OAuth callback page")
@setup_required
def get(self, provider: str):
"""Handle OAuth callback for trigger provider"""
@ -589,7 +665,11 @@ class TriggerOAuthCallbackApi(Resource):
@console_ns.route("/workspaces/current/trigger-provider/<path:provider>/oauth/client")
class TriggerOAuthClientManageApi(Resource):
@console_ns.response(200, "Success", console_ns.models[TriggerOAuthClientResponse.__name__])
@console_ns.response(
200,
"Trigger OAuth client retrieved successfully",
console_ns.models[TriggerOAuthClientResponse.__name__],
)
@setup_required
@login_required
@is_admin_or_owner_required
@ -598,6 +678,7 @@ class TriggerOAuthClientManageApi(Resource):
@with_current_tenant_id
def get(self, tenant_id: str, provider: str):
"""Get OAuth client configuration for a provider"""
try:
provider_id = TriggerProviderID(provider)
@ -618,24 +699,24 @@ class TriggerOAuthClientManageApi(Resource):
)
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/trigger/callback"
return jsonable_encoder(
{
"configured": bool(custom_params or system_client_exists),
"system_configured": system_client_exists,
"custom_configured": bool(custom_params),
"oauth_client_schema": provider_controller.get_oauth_client_schema(),
"custom_enabled": is_custom_enabled,
"redirect_uri": redirect_uri,
"params": custom_params or {},
}
)
return TriggerOAuthClientResponse(
configured=bool(custom_params or system_client_exists),
system_configured=system_client_exists,
custom_configured=bool(custom_params),
oauth_client_schema=provider_controller.get_oauth_client_schema(),
custom_enabled=is_custom_enabled,
redirect_uri=redirect_uri,
params=dict(custom_params),
).model_dump(mode="json")
except Exception as e:
logger.exception("Error getting OAuth client", exc_info=e)
raise
@console_ns.expect(console_ns.models[TriggerOAuthClientPayload.__name__])
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
@console_ns.response(
200, "Trigger OAuth client saved successfully", console_ns.models[SimpleResultResponse.__name__]
)
@setup_required
@login_required
@is_admin_or_owner_required
@ -644,16 +725,18 @@ class TriggerOAuthClientManageApi(Resource):
@with_current_tenant_id
def post(self, tenant_id: str, provider: str):
"""Configure custom OAuth client for a provider"""
payload = TriggerOAuthClientPayload.model_validate(console_ns.payload or {})
try:
provider_id = TriggerProviderID(provider)
return TriggerProviderService.save_custom_oauth_client_params(
result = TriggerProviderService.save_custom_oauth_client_params(
tenant_id=tenant_id,
provider_id=provider_id,
client_params=payload.client_params,
enabled=payload.enabled,
)
return dump_response(SimpleResultResponse, result)
except ValueError as e:
raise BadRequest(str(e))
@ -661,22 +744,26 @@ class TriggerOAuthClientManageApi(Resource):
logger.exception("Error configuring OAuth client", exc_info=e)
raise
@console_ns.response(
200, "Trigger OAuth client deleted successfully", console_ns.models[SimpleResultResponse.__name__]
)
@setup_required
@login_required
@is_admin_or_owner_required
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_PREFERENCES, resource_required=False)
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
@with_current_tenant_id
def delete(self, tenant_id: str, provider: str):
"""Remove custom OAuth client configuration"""
try:
provider_id = TriggerProviderID(provider)
return TriggerProviderService.delete_custom_oauth_client_params(
result = TriggerProviderService.delete_custom_oauth_client_params(
tenant_id=tenant_id,
provider_id=provider_id,
)
return dump_response(SimpleResultResponse, result)
except ValueError as e:
raise BadRequest(str(e))
except Exception as e:
@ -689,7 +776,11 @@ class TriggerOAuthClientManageApi(Resource):
)
class TriggerSubscriptionVerifyApi(Resource):
@console_ns.expect(console_ns.models[TriggerSubscriptionBuilderVerifyPayload.__name__])
@console_ns.response(200, "Success", console_ns.models[TriggerProviderOpaqueResponse.__name__])
@console_ns.response(
200,
"Trigger subscription verified successfully",
console_ns.models[TriggerVerificationResponse.__name__],
)
@setup_required
@login_required
@edit_permission_required
@ -699,6 +790,7 @@ class TriggerSubscriptionVerifyApi(Resource):
@with_current_tenant_id
def post(self, tenant_id: str, user: Account, provider: str, subscription_id: str):
"""Verify credentials for an existing subscription (edit mode only)"""
verify_request = TriggerSubscriptionBuilderVerifyPayload.model_validate(console_ns.payload or {})
try:
@ -709,7 +801,7 @@ class TriggerSubscriptionVerifyApi(Resource):
subscription_id=subscription_id,
credentials=verify_request.credentials,
)
return result
return dump_response(TriggerVerificationResponse, result)
except ValueError as e:
logger.warning("Credential verification failed", exc_info=e)
raise BadRequest(str(e)) from e

View File

@ -74,8 +74,6 @@ class ToolProviderApiEntity(BaseModel):
for parameter in tool.get("parameters"):
if parameter.get("type") == ToolParameter.ToolParameterType.SYSTEM_FILES:
parameter["type"] = "files"
if parameter.get("input_schema") is None:
parameter.pop("input_schema", None)
# -------------
optional_fields = self.optional_field("server_url", self.server_url)
match self.type:

View File

@ -16,30 +16,16 @@ from yarl import URL
import contexts
from configs import dify_config
from core.entities import PluginCredentialType
from core.helper.provider_cache import ToolProviderCredentialsCache
from core.plugin.impl.tool import PluginToolManager
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.mcp_tool.provider import MCPToolProviderController
from core.tools.mcp_tool.tool import MCPTool
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.plugin_tool.tool import PluginTool
from core.tools.utils.uuid_utils import is_valid_uuid
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from extensions.ext_database import db
from graphon.runtime import VariablePool
from models.provider_ids import ToolProviderID
from services.tools.mcp_tools_manage_service import MCPToolManageService
if TYPE_CHECKING:
pass
from core.agent.entities import AgentToolEntity
from core.app.entities.app_invoke_entities import InvokeFrom
from core.entities import PluginCredentialType
from core.helper.module_import_helper import load_single_subclass_from_source
from core.helper.position_helper import is_filtered
from core.helper.provider_cache import ToolProviderCredentialsCache
from core.plugin.impl.tool import PluginToolManager
from core.tools.__base.tool import Tool
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
from core.tools.builtin_tool.tool import BuiltinTool
@ -56,12 +42,21 @@ from core.tools.entities.tool_entities import (
emoji_icon_adapter,
)
from core.tools.errors import ToolProviderNotFoundError
from core.tools.mcp_tool.provider import MCPToolProviderController
from core.tools.mcp_tool.tool import MCPTool
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.plugin_tool.tool import PluginTool
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.configuration import ToolParameterConfigurationManager
from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
from core.tools.utils.uuid_utils import is_valid_uuid
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.tools.workflow_as_tool.tool import WorkflowTool
from graphon.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from graphon.runtime import VariablePool
from models.provider_ids import ToolProviderID
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
from services.tools.mcp_tools_manage_service import MCPToolManageService
from services.tools.tools_transform_service import ToolTransformService
if TYPE_CHECKING:
@ -921,23 +916,19 @@ class ToolManager:
# add tool labels
labels = ToolLabelManager.get_tool_labels(controller)
schema_type = provider_obj.schema_type
return cast(
dict,
jsonable_encoder(
{
"schema_type": provider_obj.schema_type,
"schema": provider_obj.schema,
"tools": provider_obj.tools,
"icon": icon,
"description": provider_obj.description,
"credentials": masked_credentials,
"privacy_policy": provider_obj.privacy_policy,
"custom_disclaimer": provider_obj.custom_disclaimer,
"labels": labels,
}
),
)
return {
"schema_type": getattr(schema_type, "value", schema_type),
"schema": provider_obj.schema,
"tools": [tool.model_dump(mode="json") for tool in provider_obj.tools],
"icon": icon,
"description": provider_obj.description,
"credentials": masked_credentials,
"privacy_policy": provider_obj.privacy_policy,
"custom_disclaimer": provider_obj.custom_disclaimer,
"labels": labels,
}
@classmethod
def generate_builtin_tool_icon_url(cls, provider_id: str) -> str:

File diff suppressed because it is too large Load Diff

View File

@ -1,8 +1,9 @@
import json
import logging
from typing import Any, TypedDict, cast
from typing import Any, Literal, TypedDict, cast
from httpx import get
from pydantic import TypeAdapter
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
@ -22,7 +23,6 @@ from core.tools.tool_manager import ToolManager
from core.tools.utils.encryption import create_tool_provider_encrypter
from core.tools.utils.parser import ApiBasedToolSchemaParser
from extensions.ext_database import db
from graphon.model_runtime.utils.encoders import jsonable_encoder
from models.tools import ApiToolProvider
from services.tools.tools_transform_service import ToolTransformService
@ -36,6 +36,27 @@ class ApiSchemaParseResult(TypedDict):
warning: dict[str, str]
class ApiToolPreviewResult(TypedDict, total=False):
result: str
error: str
class RemoteSchemaResult(TypedDict):
schema: str
class SimpleSuccessResult(TypedDict):
result: Literal["success"]
def _dump_api_tool_bundles(tool_bundles: list[ApiToolBundle]) -> list[dict[str, Any]]:
return cast(list[dict[str, Any]], TypeAdapter(list[ApiToolBundle]).dump_python(tool_bundles, mode="json"))
def _dump_provider_configs(configs: list[ProviderConfig]) -> list[dict[str, Any]]:
return cast(list[dict[str, Any]], TypeAdapter(list[ProviderConfig]).dump_python(configs, mode="json"))
class ApiToolManageService:
@staticmethod
def parser_api_schema(schema: str) -> ApiSchemaParseResult:
@ -80,14 +101,12 @@ class ApiToolManageService:
return cast(
ApiSchemaParseResult,
jsonable_encoder(
{
"schema_type": schema_type,
"parameters_schema": tool_bundles,
"credentials_schema": credentials_schema,
"warning": warnings,
}
),
{
"schema_type": schema_type.value,
"parameters_schema": _dump_api_tool_bundles(tool_bundles),
"credentials_schema": _dump_provider_configs(credentials_schema),
"warning": warnings,
},
)
except Exception as e:
raise ValueError(f"invalid schema: {str(e)}")
@ -118,7 +137,7 @@ class ApiToolManageService:
privacy_policy: str,
custom_disclaimer: str,
labels: list[str],
) -> dict[str, Any]:
) -> SimpleSuccessResult:
"""
Create a new API tool provider.
@ -169,7 +188,7 @@ class ApiToolManageService:
schema=schema,
description=extra_info.get("description", ""),
schema_type_str=schema_type,
tools_str=json.dumps(jsonable_encoder(tool_bundles)),
tools_str=json.dumps(_dump_api_tool_bundles(tool_bundles)),
credentials_str="{}",
privacy_policy=privacy_policy,
custom_disclaimer=custom_disclaimer,
@ -201,7 +220,7 @@ class ApiToolManageService:
return {"result": "success"}
@staticmethod
def get_api_tool_provider_remote_schema(user_id: str, tenant_id: str, url: str):
def get_api_tool_provider_remote_schema(user_id: str, tenant_id: str, url: str) -> RemoteSchemaResult:
"""
get api tool provider remote schema
"""
@ -276,7 +295,7 @@ class ApiToolManageService:
privacy_policy: str | None,
custom_disclaimer: str,
labels: list[str],
) -> dict[str, Any]:
) -> SimpleSuccessResult:
"""
Update an existing API tool provider.
@ -322,7 +341,7 @@ class ApiToolManageService:
provider.schema = schema
provider.description = extra_info.get("description", "")
provider.schema_type_str = schema_type
provider.tools_str = json.dumps(jsonable_encoder(tool_bundles))
provider.tools_str = json.dumps(_dump_api_tool_bundles(tool_bundles))
provider.privacy_policy = privacy_policy
provider.custom_disclaimer = custom_disclaimer
@ -365,7 +384,7 @@ class ApiToolManageService:
return {"result": "success"}
@staticmethod
def delete_api_tool_provider(user_id: str, tenant_id: str, provider_name: str):
def delete_api_tool_provider(user_id: str, tenant_id: str, provider_name: str) -> SimpleSuccessResult:
"""
Delete an API tool provider.
@ -413,9 +432,9 @@ class ApiToolManageService:
tool_name: str,
credentials: dict[str, Any],
parameters: dict[str, Any],
schema_type: ApiProviderSchemaType,
schema_type: ApiProviderSchemaType | str,
schema: str,
) -> dict[str, Any]:
) -> ApiToolPreviewResult:
"""
Test an API tool before adding the API tool provider.
@ -464,7 +483,7 @@ class ApiToolManageService:
schema=schema,
description="",
schema_type_str=ApiProviderSchemaType.OPENAPI,
tools_str=json.dumps(jsonable_encoder(tool_bundles)),
tools_str=json.dumps(_dump_api_tool_bundles(tool_bundles)),
credentials_str=json.dumps(credentials),
)

View File

@ -44,6 +44,7 @@ from controllers.console.workspace.tool_providers import (
ToolWorkflowProviderUpdateApi,
is_valid_url,
)
from core.tools.entities.api_entities import ToolProviderApiEntity as CoreToolProviderApiEntity
from models.account import Account, TenantAccountRole
from services.tools.mcp_tools_manage_service import ReconnectResult
from tests.test_containers_integration_tests.controllers.console.helpers import (
@ -60,6 +61,148 @@ def empty_list() -> list[object]:
return []
def emoji_icon() -> dict[str, str]:
return {"content": "tool", "background": "#252525"}
def i18n(text: str) -> dict[str, str]:
return {"en_US": text}
def tool_payload(name: str = "ping") -> dict[str, object]:
return {
"author": "langgenius",
"name": name,
"label": i18n(name.title()),
"description": i18n(f"{name} description"),
"parameters": [],
"labels": ["utilities"],
"output_schema": {},
}
def provider_payload(
*,
provider_id: str = "provider-1",
name: str = "provider",
provider_type: str = "builtin",
tools: list[dict[str, object]] | None = None,
) -> dict[str, object]:
return {
"id": provider_id,
"author": "langgenius",
"name": name,
"description": i18n(f"{name} description"),
"icon": emoji_icon(),
"icon_dark": emoji_icon(),
"label": i18n(name.title()),
"type": provider_type,
"masked_credentials": {"api_key": "[__HIDDEN__]"},
"original_credentials": {"api_key": "sk-secret"},
"is_team_authorization": False,
"allow_delete": True,
"plugin_id": "langgenius/provider",
"plugin_unique_identifier": "langgenius/provider:1.0.0",
"tools": tools or [tool_payload()],
"labels": ["utilities"],
"server_url": "",
"updated_at": 1710000000,
"server_identifier": "",
"masked_headers": None,
"original_headers": None,
"authentication": None,
"is_dynamic_registration": True,
"configuration": None,
"identity_mode": "off",
"workflow_app_id": None,
}
def provider_entity(
*,
provider_id: str = "provider-1",
name: str = "provider",
provider_type: str = "builtin",
tools: list[dict[str, object]] | None = None,
) -> CoreToolProviderApiEntity:
return CoreToolProviderApiEntity.model_validate(
provider_payload(provider_id=provider_id, name=name, provider_type=provider_type, tools=tools)
)
def credential_payload() -> dict[str, object]:
return {
"id": "credential-1",
"name": "Default credential",
"provider": "provider",
"credential_type": "api-key",
"is_default": True,
"credentials": {"api_key": "masked"},
"visibility": "all_team_members",
"created_by": "user-1",
"partial_member_list": [],
"from_other_member": False,
}
def provider_config_payload() -> dict[str, object]:
return {"type": "secret-input", "name": "api_key", "required": True}
def api_tool_bundle_payload() -> dict[str, object]:
return {
"server_url": "https://api.example.com",
"method": "get",
"summary": "Ping",
"operation_id": "ping",
"parameters": [],
"author": "langgenius",
"icon": None,
"openapi": {"operationId": "ping"},
"output_schema": {},
}
def api_provider_detail_payload() -> dict[str, object]:
return {
"schema_type": "openapi",
"schema": "{}",
"tools": [api_tool_bundle_payload()],
"icon": emoji_icon(),
"description": "API provider",
"credentials": {},
"privacy_policy": "",
"custom_disclaimer": "",
"labels": ["utilities"],
}
def credential_info_payload() -> dict[str, object]:
return {
"supported_credential_types": ["api-key", "oauth2"],
"is_oauth_custom_client_enabled": False,
"credentials": [credential_payload()],
}
def oauth_client_schema_payload() -> dict[str, object]:
return {
"schema": [provider_config_payload()],
"is_oauth_custom_client_enabled": False,
"is_system_oauth_params_exists": True,
"client_params": {"client_id": "masked"},
"redirect_uri": "https://console.example.com/oauth/callback",
}
def tool_label_payload() -> dict[str, object]:
return {
"name": "utilities",
"label": i18n("Utilities"),
"icon": "wrench",
}
@pytest.fixture
def _mock_cache() -> None:
return
@ -127,7 +270,7 @@ def test_create_mcp_provider_populates_tools(
with (
patch(
"services.tools.tools_transform_service.ToolTransformService.mcp_provider_to_user_provider",
return_value={"id": "provider-1", "tools": [{"name": "ping"}]},
return_value=provider_entity(provider_id="provider-1", provider_type="mcp", tools=[tool_payload()]),
autospec=True,
),
):
@ -138,13 +281,15 @@ def test_create_mcp_provider_populates_tools(
content_type="application/json",
)
# Assert
assert resp.status_code == 200
body = resp.get_json()
assert body.get("id") == "provider-1"
# 若 transform 后包含 tools 字段,确保非空
assert isinstance(body.get("tools"), list)
assert body["tools"]
# Assert
assert resp.status_code == 200
body = resp.get_json()
assert body.get("id") == "provider-1"
assert body["team_credentials"] == {"api_key": "[__HIDDEN__]"}
assert "masked_credentials" not in body
assert "original_credentials" not in body
assert isinstance(body.get("tools"), list)
assert body["tools"]
class TestUtils:
@ -170,10 +315,16 @@ class TestToolProviderListApi:
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.ToolCommonService.list_tool_providers",
return_value=["p1"],
return_value=[provider_entity(provider_id="p1").to_dict()],
),
):
assert method(api, "t1", make_account(id="u1")) == ["p1"]
result = method(api, "t1", make_account(id="u1"))
assert result[0]["id"] == "p1"
assert result[0]["team_credentials"] == {"api_key": "[__HIDDEN__]"}
assert "masked_credentials" not in result[0]
assert "original_credentials" not in result[0]
assert result[0]["tools"][0]["name"] == "ping"
class TestBuiltinProviderApis:
@ -189,10 +340,10 @@ class TestBuiltinProviderApis:
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.list_builtin_tool_provider_tools",
return_value=[{"a": 1}],
return_value=[tool_payload()],
),
):
assert method(api, "t1", "provider") == [{"a": 1}]
assert method(api, "t1", "provider")[0]["name"] == "ping"
def test_info(self, app: Flask) -> None:
api = ToolBuiltinProviderInfoApi()
@ -202,10 +353,15 @@ class TestBuiltinProviderApis:
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_info",
return_value={"x": 1},
return_value=provider_entity(),
),
):
assert method(api, "t1", "provider") == {"x": 1}
result = method(api, "t1", "provider")
assert result["id"] == "provider-1"
assert result["team_credentials"] == {"api_key": "[__HIDDEN__]"}
assert "masked_credentials" not in result
assert "original_credentials" not in result
def test_delete(self, app: Flask) -> None:
api = ToolBuiltinProviderDeleteApi()
@ -240,10 +396,10 @@ class TestBuiltinProviderApis:
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.add_builtin_tool_provider",
return_value={"id": 1},
return_value={"result": "success"},
),
):
assert method(api, "t", make_account(), "provider")["id"] == 1
assert method(api, "t", make_account(), "provider")["result"] == "success"
def test_update(self, app: Flask) -> None:
api = ToolBuiltinProviderUpdateApi()
@ -255,10 +411,10 @@ class TestBuiltinProviderApis:
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.update_builtin_tool_provider",
return_value={"ok": True},
return_value={"result": "success"},
),
):
assert method(api, "t", make_account(), "provider")["ok"]
assert method(api, "t", make_account(), "provider")["result"] == "success"
def test_get_credentials(self, app: Flask) -> None:
api = ToolBuiltinProviderGetCredentialsApi()
@ -268,10 +424,10 @@ class TestBuiltinProviderApis:
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_credentials",
return_value={"k": "v"},
return_value=[credential_payload()],
),
):
assert method(api, "t", make_account(id="user-1"), "provider") == {"k": "v"}
assert method(api, "t", make_account(id="user-1"), "provider")[0]["id"] == "credential-1"
def test_icon(self, app: Flask) -> None:
api = ToolBuiltinProviderIconApi()
@ -295,10 +451,10 @@ class TestBuiltinProviderApis:
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.list_builtin_provider_credentials_schema",
return_value={"schema": {}},
return_value=[provider_config_payload()],
),
):
assert method(api, "t", "provider", "oauth2") == {"schema": {}}
assert method(api, "t", "provider", "oauth2")[0]["name"] == "api_key"
def test_set_default_credential(self, app: Flask) -> None:
api = ToolBuiltinProviderSetDefaultApi()
@ -308,10 +464,10 @@ class TestBuiltinProviderApis:
app.test_request_context("/", json={"id": "c1"}),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.set_default_provider",
return_value={"ok": True},
return_value={"result": "success"},
),
):
assert method(api, "t", "provider")["ok"]
assert method(api, "t", "provider")["result"] == "success"
def test_get_credential_info(self, app: Flask) -> None:
api = ToolBuiltinProviderGetCredentialInfoApi()
@ -321,10 +477,10 @@ class TestBuiltinProviderApis:
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_credential_info",
return_value={"info": "x"},
return_value=credential_info_payload(),
),
):
assert method(api, "t", make_account(), "provider") == {"info": "x"}
assert method(api, "t", make_account(), "provider")["credentials"][0]["id"] == "credential-1"
def test_get_oauth_client_schema(self, app: Flask) -> None:
api = ToolBuiltinProviderGetOauthClientSchemaApi()
@ -334,10 +490,10 @@ class TestBuiltinProviderApis:
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema",
return_value={"schema": {}},
return_value=oauth_client_schema_payload(),
),
):
assert method(api, "t", "provider") == {"schema": {}}
assert method(api, "t", "provider")["schema"][0]["name"] == "api_key"
class TestApiProviderApis:
@ -354,30 +510,34 @@ class TestApiProviderApis:
"schema_type": "openapi",
"schema": "{}",
"provider": "p",
"icon": empty_mapping(),
"icon": emoji_icon(),
}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.tool_providers.ApiToolManageService.create_api_tool_provider",
return_value={"id": 1},
),
return_value={"result": "success"},
) as create_api_tool_provider,
):
assert method(api, "t", make_account())["id"] == 1
assert method(api, "t", make_account()) == {"result": "success"}
create_api_tool_provider.assert_called_once()
assert create_api_tool_provider.call_args.args[3] == emoji_icon()
def test_remote_schema(self, app: Flask) -> None:
api = ToolApiProviderGetRemoteSchemaApi()
method = unwrap(api.get)
openapi_schema = '{"openapi":"3.0.0","info":{"title":"Demo API","version":"1.0.0"},"paths":{}}'
with (
app.test_request_context("/?url=http://x.com"),
patch(
"controllers.console.workspace.tool_providers.ApiToolManageService.get_api_tool_provider_remote_schema",
return_value={"schema": "x"},
return_value={"schema": openapi_schema},
),
):
assert method(api, "t", make_account())["schema"] == "x"
assert method(api, "t", make_account()) == {"schema": openapi_schema}
def test_list_tools(self, app: Flask) -> None:
api = ToolApiProviderListToolsApi()
@ -387,10 +547,10 @@ class TestApiProviderApis:
app.test_request_context("/?provider=p"),
patch(
"controllers.console.workspace.tool_providers.ApiToolManageService.list_api_tool_provider_tools",
return_value=[{"tool": 1}],
return_value=[tool_payload("api_ping")],
),
):
assert method(api, "t", make_account()) == [{"tool": 1}]
assert method(api, "t", make_account())[0]["name"] == "api_ping"
def test_update(self, app: Flask) -> None:
api = ToolApiProviderUpdateApi()
@ -402,7 +562,7 @@ class TestApiProviderApis:
"schema": "{}",
"provider": "p",
"original_provider": "o",
"icon": empty_mapping(),
"icon": emoji_icon(),
"privacy_policy": "",
"custom_disclaimer": "",
}
@ -411,10 +571,13 @@ class TestApiProviderApis:
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.tool_providers.ApiToolManageService.update_api_tool_provider",
return_value={"ok": True},
),
return_value={"result": "success"},
) as update_api_tool_provider,
):
assert method(api, "t", make_account())["ok"]
assert method(api, "t", make_account()) == {"result": "success"}
update_api_tool_provider.assert_called_once()
assert update_api_tool_provider.call_args.args[4] == emoji_icon()
def test_delete(self, app: Flask) -> None:
api = ToolApiProviderDeleteApi()
@ -437,10 +600,10 @@ class TestApiProviderApis:
app.test_request_context("/?provider=p"),
patch(
"controllers.console.workspace.tool_providers.ApiToolManageService.get_api_tool_provider",
return_value={"x": 1},
return_value=api_provider_detail_payload(),
),
):
assert method(api, "t", make_account()) == {"x": 1}
assert method(api, "t", make_account())["schema"] == "{}"
class TestWorkflowApis:
@ -457,7 +620,7 @@ class TestWorkflowApis:
"name": "n",
"label": "l",
"description": "d",
"icon": empty_mapping(),
"icon": emoji_icon(),
"parameters": empty_list(),
}
@ -465,10 +628,13 @@ class TestWorkflowApis:
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.tool_providers.WorkflowToolManageService.create_workflow_tool",
return_value={"id": 1},
),
return_value={"result": "success"},
) as create_workflow_tool,
):
assert method(api, "t", make_account())["id"] == 1
assert method(api, "t", make_account()) == {"result": "success"}
create_workflow_tool.assert_called_once()
assert create_workflow_tool.call_args.kwargs["icon"] == emoji_icon()
def test_update_invalid(self, app: Flask) -> None:
api = ToolWorkflowProviderUpdateApi()
@ -479,18 +645,21 @@ class TestWorkflowApis:
"name": "Tool",
"label": "Tool Label",
"description": "A tool",
"icon": empty_mapping(),
"icon": emoji_icon(),
}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.tool_providers.WorkflowToolManageService.update_workflow_tool",
return_value={"ok": True},
),
return_value={"result": "success"},
) as update_workflow_tool,
):
result = method(api, "t", make_account())
assert result["ok"]
assert result == {"result": "success"}
update_workflow_tool.assert_called_once()
assert update_workflow_tool.call_args.args[5] == emoji_icon()
def test_delete(self, app: Flask) -> None:
api = ToolWorkflowProviderDeleteApi()
@ -500,10 +669,10 @@ class TestWorkflowApis:
app.test_request_context("/", json={"workflow_tool_id": "123e4567-e89b-12d3-a456-426614174000"}),
patch(
"controllers.console.workspace.tool_providers.WorkflowToolManageService.delete_workflow_tool",
return_value={"ok": True},
return_value={"result": "success"},
),
):
assert method(api, "t", make_account())["ok"]
assert method(api, "t", make_account())["result"] == "success"
def test_get_error(self, app: Flask) -> None:
api = ToolWorkflowProviderGetApi()
@ -525,49 +694,40 @@ class TestLists:
api = ToolBuiltinListApi()
method = unwrap(api.get)
m = MagicMock()
m.to_dict.return_value = {"x": 1}
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.list_builtin_tools",
return_value=[m],
return_value=[provider_entity(provider_id="builtin-1")],
),
):
assert method(api, "t", make_account()) == [{"x": 1}]
assert method(api, "t", make_account())[0]["id"] == "builtin-1"
def test_api_list(self, app: Flask) -> None:
api = ToolApiListApi()
method = unwrap(api.get)
m = MagicMock()
m.to_dict.return_value = {"x": 1}
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.ApiToolManageService.list_api_tools",
return_value=[m],
return_value=[provider_entity(provider_id="api-1", provider_type="api")],
),
):
assert method(api, "t") == [{"x": 1}]
assert method(api, "t")[0]["id"] == "api-1"
def test_workflow_list(self, app: Flask) -> None:
api = ToolWorkflowListApi()
method = unwrap(api.get)
m = MagicMock()
m.to_dict.return_value = {"x": 1}
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.WorkflowToolManageService.list_tenant_workflow_tools",
return_value=[m],
return_value=[provider_entity(provider_id="workflow-1", provider_type="workflow")],
),
):
assert method(api, "t", make_account()) == [{"x": 1}]
assert method(api, "t", make_account())[0]["id"] == "workflow-1"
class TestLabels:
@ -583,10 +743,10 @@ class TestLabels:
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.ToolLabelsService.list_tool_labels",
return_value=["l1"],
return_value=[tool_label_payload()],
),
):
assert method(api) == ["l1"]
assert method(api)[0]["name"] == "utilities"
class TestOAuth:
@ -630,10 +790,10 @@ class TestOAuthCustomClient:
app.test_request_context("/", json={"client_params": {"a": 1}}),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.save_custom_oauth_client_params",
return_value={"ok": True},
return_value={"result": "success"},
),
):
assert method(api, "t", "provider")["ok"]
assert method(api, "t", "provider") == {"result": "success"}
def test_get_custom_client(self, app: Flask) -> None:
api = ToolOAuthCustomClient()
@ -656,7 +816,7 @@ class TestOAuthCustomClient:
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.delete_custom_oauth_client_params",
return_value={"ok": True},
return_value={"result": "success"},
),
):
assert method(api, "t", "provider")["ok"]
assert method(api, "t", "provider") == {"result": "success"}

View File

@ -2,6 +2,7 @@
from __future__ import annotations
from datetime import datetime
from inspect import unwrap
from unittest.mock import MagicMock, patch
@ -29,6 +30,8 @@ from controllers.console.workspace.trigger_providers import (
TriggerSubscriptionVerifyApi,
)
from core.plugin.entities.plugin_daemon import CredentialType
from core.trigger.entities.api_entities import SubscriptionBuilderApiEntity, TriggerProviderApiEntity
from core.trigger.entities.entities import RequestLog
from models.account import Account
@ -38,6 +41,47 @@ def mock_user() -> Account:
return user
def trigger_provider() -> TriggerProviderApiEntity:
return TriggerProviderApiEntity(
author="Dify",
name="github",
label={"en_US": "GitHub"},
description={"en_US": "GitHub trigger provider"},
icon="icon.svg",
icon_dark=None,
tags=["code"],
plugin_id="plugin",
plugin_unique_identifier="plugin:github",
supported_creation_methods=[],
subscription_constructor=None,
subscription_schema=[],
events=[],
)
def subscription_builder() -> SubscriptionBuilderApiEntity:
return SubscriptionBuilderApiEntity(
id="b1",
name="Builder",
provider="github",
endpoint="b1",
parameters={"repo": "dify"},
properties={"branch": "main"},
credentials={"token": "secret"},
credential_type=CredentialType.UNAUTHORIZED,
)
def request_log() -> RequestLog:
return RequestLog(
id="log1",
endpoint="/hooks/b1",
request={"headers": {}, "body": {"event": "push"}},
response={"status": 200, "body": {"ok": True}},
created_at=datetime(2024, 1, 1),
)
class TestTriggerProviderApis:
@pytest.fixture
def app(self, flask_app_with_containers: Flask) -> Flask:
@ -77,10 +121,10 @@ class TestTriggerProviderApis:
app.test_request_context("/"),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_trigger_provider",
return_value={"id": "p1"},
return_value=trigger_provider(),
),
):
assert method(api, "t1", "github") == {"id": "p1"}
assert method(api, "t1", "github")["name"] == "github"
class TestTriggerSubscriptionListApi:
@ -129,11 +173,11 @@ class TestTriggerSubscriptionBuilderApis:
app.test_request_context("/", json={"credential_type": "UNAUTHORIZED"}),
patch(
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.create_trigger_subscription_builder",
return_value={"id": "b1"},
return_value=subscription_builder(),
),
):
result = method(api, "t1", mock_user(), "github")
assert "subscription_builder" in result
assert result["subscription_builder"]["id"] == "b1"
def test_get_builder(self, app: Flask) -> None:
api = TriggerSubscriptionBuilderGetApi()
@ -143,10 +187,10 @@ class TestTriggerSubscriptionBuilderApis:
app.test_request_context("/"),
patch(
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.get_subscription_builder_by_id",
return_value={"id": "b1"},
return_value=subscription_builder(),
),
):
assert method(api, "github", "b1") == {"id": "b1"}
assert method(api, "github", "b1")["id"] == "b1"
def test_verify_builder(self, app: Flask) -> None:
api = TriggerSubscriptionBuilderVerifyApi()
@ -156,10 +200,10 @@ class TestTriggerSubscriptionBuilderApis:
app.test_request_context("/", json={"credentials": {"a": 1}}),
patch(
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_and_verify_builder",
return_value={"ok": True},
return_value={"verified": True},
),
):
assert method(api, "t1", mock_user(), "github", "b1") == {"ok": True}
assert method(api, "t1", mock_user(), "github", "b1") == {"verified": True}
def test_verify_builder_error(self, app: Flask) -> None:
api = TriggerSubscriptionBuilderVerifyApi()
@ -183,26 +227,24 @@ class TestTriggerSubscriptionBuilderApis:
app.test_request_context("/", json={"name": "n"}),
patch(
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_trigger_subscription_builder",
return_value={"id": "b1"},
return_value=subscription_builder(),
),
):
assert method(api, "t1", "github", "b1") == {"id": "b1"}
assert method(api, "t1", "github", "b1")["id"] == "b1"
def test_logs(self, app: Flask) -> None:
api = TriggerSubscriptionBuilderLogsApi()
method = unwrap(api.get)
log = MagicMock()
log.model_dump.return_value = {"a": 1}
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.list_logs",
return_value=[log],
return_value=[request_log()],
),
):
assert "logs" in method(api, "github", "b1")
result = method(api, "github", "b1")
assert result["logs"][0]["id"] == "log1"
def test_build(self, app: Flask) -> None:
api = TriggerSubscriptionBuilderBuildApi()
@ -215,7 +257,7 @@ class TestTriggerSubscriptionBuilderApis:
return_value=None,
),
):
assert method(api, "t1", mock_user(), "github", "b1") == 200
assert method(api, "t1", mock_user(), "github", "b1") == {"result": "success"}
class TestTriggerSubscriptionCrud:
@ -239,7 +281,7 @@ class TestTriggerSubscriptionCrud:
),
patch("controllers.console.workspace.trigger_providers.TriggerProviderService.update_trigger_subscription"),
):
assert method(api, "t1", "s1") == 200
assert method(api, "t1", "s1") == {"result": "success"}
def test_update_not_found(self, app: Flask) -> None:
api = TriggerSubscriptionUpdateApi()
@ -275,7 +317,7 @@ class TestTriggerSubscriptionCrud:
"controllers.console.workspace.trigger_providers.TriggerProviderService.rebuild_trigger_subscription"
),
):
assert method(api, "t1", "s1") == 200
assert method(api, "t1", "s1") == {"result": "success"}
def test_delete_subscription(self, app: Flask) -> None:
api = TriggerSubscriptionDeleteApi()
@ -336,7 +378,7 @@ class TestTriggerOAuthApis:
),
patch(
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.create_trigger_subscription_builder",
return_value=MagicMock(id="b1"),
return_value=subscription_builder(),
),
patch(
"controllers.console.workspace.trigger_providers.OAuthProxyService.create_proxy_context",
@ -480,7 +522,7 @@ class TestTriggerOAuthClientManageApi:
),
patch(
"controllers.console.workspace.trigger_providers.TriggerManager.get_trigger_provider",
return_value=MagicMock(get_oauth_client_schema=lambda: {}),
return_value=MagicMock(get_oauth_client_schema=lambda: []),
),
):
result = method(api, "t1", "github")
@ -494,10 +536,10 @@ class TestTriggerOAuthClientManageApi:
app.test_request_context("/", json={"enabled": True}),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.save_custom_oauth_client_params",
return_value={"ok": True},
return_value={"result": "success"},
),
):
assert method(api, "t1", "github") == {"ok": True}
assert method(api, "t1", "github") == {"result": "success"}
def test_delete_client(self, app: Flask) -> None:
api = TriggerOAuthClientManageApi()
@ -507,10 +549,10 @@ class TestTriggerOAuthClientManageApi:
app.test_request_context("/"),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.delete_custom_oauth_client_params",
return_value={"ok": True},
return_value={"result": "success"},
),
):
assert method(api, "t1", "github") == {"ok": True}
assert method(api, "t1", "github") == {"result": "success"}
def test_oauth_client_post_value_error(self, app: Flask) -> None:
api = TriggerOAuthClientManageApi()
@ -540,10 +582,10 @@ class TestTriggerSubscriptionVerifyApi:
app.test_request_context("/", json={"credentials": {}}),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.verify_subscription_credentials",
return_value={"ok": True},
return_value={"verified": True},
),
):
assert method(api, "t1", mock_user(), "github", "s1") == {"ok": True}
assert method(api, "t1", mock_user(), "github", "s1") == {"verified": True}
@pytest.mark.parametrize("raised_exception", [ValueError("bad"), Exception("boom")])
def test_verify_errors(self, app: Flask, raised_exception: Exception) -> None:

View File

@ -1,6 +1,7 @@
import inspect
import json
from unittest.mock import patch
from collections.abc import Iterator
from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
@ -10,16 +11,18 @@ from sqlalchemy.orm import Session
from core.tools.entities.tool_entities import ApiProviderSchemaType
from core.tools.errors import ApiToolProviderNotFoundError
from core.tools.tool_label_manager import ToolLabelManager
from models import Account, Tenant
from models import Account, AccountStatus, Tenant, TenantStatus
from models.tools import ApiToolProvider
from services.tools.api_tools_manage_service import ApiToolManageService
MockDependencies = dict[str, MagicMock]
class TestApiToolManageService:
"""Integration tests for ApiToolManageService using testcontainers."""
@pytest.fixture
def mock_external_service_dependencies(self):
def mock_external_service_dependencies(self) -> Iterator[MockDependencies]:
"""Mock setup for external service dependencies."""
with (
patch("services.tools.api_tools_manage_service.ToolLabelManager") as mock_tool_label_manager,
@ -39,7 +42,9 @@ class TestApiToolManageService:
"provider_controller": mock_provider_controller,
}
def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies):
def _create_test_account_and_tenant(
self, db_session_with_containers: Session, mock_external_service_dependencies: MockDependencies
) -> tuple[Account, Tenant]:
"""
Helper method to create a test account and tenant for testing.
@ -57,7 +62,7 @@ class TestApiToolManageService:
email=fake.email(),
name=fake.name(),
interface_language="en-US",
status="active",
status=AccountStatus.ACTIVE,
)
db_session_with_containers.add(account)
@ -66,7 +71,7 @@ class TestApiToolManageService:
# Create tenant for the account
tenant = Tenant(
name=fake.company(),
status="normal",
status=TenantStatus.NORMAL,
)
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
@ -88,7 +93,7 @@ class TestApiToolManageService:
return account, tenant
def _create_test_openapi_schema(self):
def _create_test_openapi_schema(self) -> str:
"""Helper method to create a test OpenAPI schema."""
return """
{
@ -121,8 +126,11 @@ class TestApiToolManageService:
"""
def test_parser_api_schema_success(
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
):
self,
flask_req_ctx_with_containers: object,
db_session_with_containers: Session,
mock_external_service_dependencies: MockDependencies,
) -> None:
"""
Test successful parsing of API schema.
@ -148,6 +156,8 @@ class TestApiToolManageService:
# Verify credentials schema structure
credentials_schema = result["credentials_schema"]
assert len(credentials_schema) == 3
assert all(isinstance(field, dict) for field in credentials_schema)
assert all(isinstance(tool, dict) for tool in result["parameters_schema"])
# Check auth_type field
auth_type_field = next(field for field in credentials_schema if field["name"] == "auth_type")
@ -166,8 +176,11 @@ class TestApiToolManageService:
assert api_key_value_field["default"] == ""
def test_parser_api_schema_invalid_schema(
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
):
self,
flask_req_ctx_with_containers: object,
db_session_with_containers: Session,
mock_external_service_dependencies: MockDependencies,
) -> None:
"""
Test parsing of invalid API schema.
@ -186,8 +199,11 @@ class TestApiToolManageService:
assert "invalid schema" in str(exc_info.value)
def test_parser_api_schema_malformed_json(
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
):
self,
flask_req_ctx_with_containers: object,
db_session_with_containers: Session,
mock_external_service_dependencies: MockDependencies,
) -> None:
"""
Test parsing of malformed JSON schema.
@ -206,8 +222,11 @@ class TestApiToolManageService:
assert "invalid schema" in str(exc_info.value)
def test_convert_schema_to_tool_bundles_success(
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
):
self,
flask_req_ctx_with_containers: object,
db_session_with_containers: Session,
mock_external_service_dependencies: MockDependencies,
) -> None:
"""
Test successful conversion of schema to tool bundles.
@ -236,8 +255,11 @@ class TestApiToolManageService:
assert tool_bundle.operation_id == "testOperation"
def test_convert_schema_to_tool_bundles_with_extra_info(
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
):
self,
flask_req_ctx_with_containers: object,
db_session_with_containers: Session,
mock_external_service_dependencies: MockDependencies,
) -> None:
"""
Test successful conversion of schema to tool bundles with extra info.
@ -262,8 +284,11 @@ class TestApiToolManageService:
assert isinstance(schema_type, str)
def test_convert_schema_to_tool_bundles_invalid_schema(
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
):
self,
flask_req_ctx_with_containers: object,
db_session_with_containers: Session,
mock_external_service_dependencies: MockDependencies,
) -> None:
"""
Test conversion of invalid schema to tool bundles.
@ -282,8 +307,11 @@ class TestApiToolManageService:
assert "invalid schema" in str(exc_info.value)
def test_create_api_tool_provider_success(
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
):
self,
flask_req_ctx_with_containers: object,
db_session_with_containers: Session,
mock_external_service_dependencies: MockDependencies,
) -> None:
"""
Test successful creation of API tool provider.
@ -301,7 +329,7 @@ class TestApiToolManageService:
)
provider_name = fake.company()
icon = {"type": "emoji", "value": "🔧"}
icon = {"content": "🔧", "background": "#FFF"}
credentials = {"auth_type": "none", "api_key_header": "X-API-Key", "api_key_value": ""}
schema_type = ApiProviderSchemaType.OPENAPI
schema = self._create_test_openapi_schema()
@ -341,6 +369,7 @@ class TestApiToolManageService:
assert provider.schema_type_str == schema_type
assert provider.privacy_policy == privacy_policy
assert provider.custom_disclaimer == custom_disclaimer
assert json.loads(provider.icon) == icon
# Verify mock interactions
mock_external_service_dependencies["tool_label_manager"].update_tool_labels.assert_called_once()
@ -349,8 +378,11 @@ class TestApiToolManageService:
mock_external_service_dependencies["provider_controller"].load_bundled_tools.assert_called_once()
def test_create_api_tool_provider_duplicate_name(
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
):
self,
flask_req_ctx_with_containers: object,
db_session_with_containers: Session,
mock_external_service_dependencies: MockDependencies,
) -> None:
"""
Test creation of API tool provider with duplicate name.
@ -366,7 +398,7 @@ class TestApiToolManageService:
)
provider_name = fake.company()
icon = {"type": "emoji", "value": "🔧"}
icon = {"content": "🔧", "background": "#FFF"}
credentials = {"auth_type": "none"}
schema_type = ApiProviderSchemaType.OPENAPI
schema = self._create_test_openapi_schema()
@ -406,8 +438,11 @@ class TestApiToolManageService:
assert f"provider {provider_name} already exists" in str(exc_info.value)
def test_create_api_tool_provider_invalid_schema_type(
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
):
self,
flask_req_ctx_with_containers: object,
db_session_with_containers: Session,
mock_external_service_dependencies: MockDependencies,
) -> None:
"""
Test creation of API tool provider with invalid schema type.
@ -423,7 +458,7 @@ class TestApiToolManageService:
)
provider_name = fake.company()
icon = {"type": "emoji", "value": "🔧"}
icon = {"content": "🔧", "background": "#FFF"}
credentials = {"auth_type": "none"}
schema_type = "invalid_type"
schema = self._create_test_openapi_schema()
@ -438,8 +473,11 @@ class TestApiToolManageService:
assert "validation error" in str(exc_info.value)
def test_create_api_tool_provider_missing_auth_type(
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
):
self,
flask_req_ctx_with_containers: object,
db_session_with_containers: Session,
mock_external_service_dependencies: MockDependencies,
) -> None:
"""
Test creation of API tool provider with missing auth type.
@ -455,7 +493,7 @@ class TestApiToolManageService:
)
provider_name = fake.company()
icon = {"type": "emoji", "value": "🔧"}
icon = {"content": "🔧", "background": "#FFF"}
credentials = {} # Missing auth_type
schema_type = ApiProviderSchemaType.OPENAPI
schema = self._create_test_openapi_schema()
@ -481,8 +519,11 @@ class TestApiToolManageService:
assert "auth_type is required" in str(exc_info.value)
def test_create_api_tool_provider_with_api_key_auth(
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
):
self,
flask_req_ctx_with_containers: object,
db_session_with_containers: Session,
mock_external_service_dependencies: MockDependencies,
) -> None:
"""
Test successful creation of API tool provider with API key authentication.
@ -498,7 +539,7 @@ class TestApiToolManageService:
)
provider_name = fake.company()
icon = {"type": "emoji", "value": "🔑"}
icon = {"content": "🔑", "background": "#FFF"}
credentials = {"auth_type": "api_key", "api_key_header": "X-API-Key", "api_key_value": fake.uuid4()}
schema_type = ApiProviderSchemaType.OPENAPI
schema = self._create_test_openapi_schema()
@ -542,8 +583,11 @@ class TestApiToolManageService:
mock_external_service_dependencies["provider_controller"].from_db.assert_called_once()
def test_delete_api_tool_provider_success(
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
):
self,
flask_req_ctx_with_containers: object,
db_session_with_containers: Session,
mock_external_service_dependencies: MockDependencies,
) -> None:
"""Test successful deletion of an API tool provider."""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
@ -583,8 +627,8 @@ class TestApiToolManageService:
assert deleted is None
def test_delete_api_tool_provider_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
self, db_session_with_containers: Session, mock_external_service_dependencies: MockDependencies
) -> None:
"""Test deletion raises ValueError when provider not found."""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
@ -595,14 +639,15 @@ class TestApiToolManageService:
ApiToolManageService.delete_api_tool_provider(account.id, tenant.id, "nonexistent")
def test_update_api_tool_provider_success(
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
):
self,
flask_req_ctx_with_containers: object,
db_session_with_containers: Session,
mock_external_service_dependencies: MockDependencies,
) -> None:
fake = Faker()
# Firmware fix for cache.delete() in update flow
mock_encrypter = mock_external_service_dependencies["encrypter"]
from unittest.mock import MagicMock
mock_cache = MagicMock()
mock_cache.delete.return_value = None
mock_encrypter.return_value = (mock_encrypter, mock_cache)
@ -620,7 +665,7 @@ class TestApiToolManageService:
user_id=account.id,
tenant_id=tenant.id,
provider_name=original_name,
icon={"type": "emoji", "value": "🔧"},
icon={"content": "🔧", "background": "#FFF"},
credentials={"auth_type": "none"},
schema_type=ApiProviderSchemaType.OPENAPI,
schema=self._create_test_openapi_schema(),
@ -646,7 +691,7 @@ class TestApiToolManageService:
provider_name=new_name,
original_provider=original_name,
# new icon - changed 2
icon={"type": "emoji", "value": "🚀"},
icon={"content": "🚀", "background": "#FFF"},
credentials={"auth_type": "none"},
_schema_type=ApiProviderSchemaType.OPENAPI,
schema=self._create_test_openapi_schema(),
@ -677,9 +722,7 @@ class TestApiToolManageService:
# - changed 1
assert updated_provider.name == new_name
# - changed 2
icon_data = json.loads(updated_provider.icon)
assert icon_data["type"] == "emoji"
assert icon_data["value"] == "🚀"
assert json.loads(updated_provider.icon) == {"content": "🚀", "background": "#FFF"}
# - changed 3
assert updated_provider.privacy_policy == "https://new-policy.com"
# - changed 4
@ -712,8 +755,11 @@ class TestApiToolManageService:
)
def test_update_api_tool_provider_not_found(
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
):
self,
flask_req_ctx_with_containers: object,
db_session_with_containers: Session,
mock_external_service_dependencies: MockDependencies,
) -> None:
"""
Test update raises ValueError when original provider not found.
@ -733,7 +779,7 @@ class TestApiToolManageService:
user_id=account.id,
tenant_id=tenant.id,
provider_name=existing_provider_name,
icon={"type": "emoji", "value": "🔧"},
icon={"content": "🔧", "background": "#FFF"},
credentials={"auth_type": "none"},
schema_type=ApiProviderSchemaType.OPENAPI,
schema=self._create_test_openapi_schema(),
@ -756,7 +802,7 @@ class TestApiToolManageService:
tenant_id=tenant.id,
provider_name=target_new_name,
original_provider=missing_original_name,
icon={"type": "emoji", "value": "🚀"},
icon={"content": "🚀", "background": "#FFF"},
credentials={"auth_type": "none"},
_schema_type=ApiProviderSchemaType.OPENAPI,
schema=self._create_test_openapi_schema(),
@ -793,8 +839,11 @@ class TestApiToolManageService:
mock_external_service_dependencies["provider_controller"].from_db.assert_not_called()
def test_update_api_tool_provider_missing_auth_type(
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
):
self,
flask_req_ctx_with_containers: object,
db_session_with_containers: Session,
mock_external_service_dependencies: MockDependencies,
) -> None:
"""Test update raises ValueError when auth_type is missing from credentials."""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
@ -822,7 +871,7 @@ class TestApiToolManageService:
tenant_id=tenant.id,
provider_name=provider_name,
original_provider=provider_name,
icon={},
icon={"content": "🔧", "background": "#FFF"},
credentials={},
_schema_type=ApiProviderSchemaType.OPENAPI,
schema=schema,
@ -832,8 +881,8 @@ class TestApiToolManageService:
)
def test_list_api_tool_provider_tools_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
self, db_session_with_containers: Session, mock_external_service_dependencies: MockDependencies
) -> None:
"""Test listing tools raises ValueError when provider not found."""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
@ -844,8 +893,8 @@ class TestApiToolManageService:
ApiToolManageService.list_api_tool_provider_tools(account.id, tenant.id, "nonexistent")
def test_test_api_tool_preview_invalid_schema_type(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
self, db_session_with_containers: Session, mock_external_service_dependencies: MockDependencies
) -> None:
"""Test preview raises ValueError for invalid schema type."""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(

View File

@ -5,7 +5,6 @@ from __future__ import annotations
import builtins
import importlib
from contextlib import ExitStack, contextmanager
from inspect import unwrap
from types import ModuleType, SimpleNamespace
from unittest.mock import MagicMock, patch
@ -13,6 +12,9 @@ import pytest
from flask import Flask
from flask.views import MethodView
from core.tools.entities.api_entities import ToolProviderApiEntity as CoreToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolParameter
from models import Account
from models.account import TenantAccountRole
@ -21,6 +23,7 @@ if not hasattr(builtins, "MethodView"):
_CONTROLLER_MODULE: ModuleType | None = None
_WRAPS_MODULE: ModuleType | None = None
@contextmanager
@ -69,10 +72,11 @@ def controller_module(monkeypatch: pytest.MonkeyPatch):
_CONTROLLER_MODULE = importlib.import_module(module_name)
module = _CONTROLLER_MODULE
monkeypatch.setattr(module, "jsonable_encoder", lambda payload: payload)
# Ensure decorators that consult deployment edition do not reach the database.
global _WRAPS_MODULE
wraps_module = importlib.import_module("controllers.console.wraps")
_WRAPS_MODULE = wraps_module
monkeypatch.setattr(module.dify_config, "EDITION", "CLOUD")
monkeypatch.setattr(wraps_module.dify_config, "EDITION", "CLOUD")
@ -88,19 +92,194 @@ def _mock_account(user_id: str = "user-123") -> Account:
return user
def _set_current_account(
monkeypatch: pytest.MonkeyPatch,
controller_module: ModuleType,
user: Account,
tenant_id: str,
) -> None:
def _getter():
return user, tenant_id
monkeypatch.setattr(controller_module, "current_account_with_tenant", _getter, raising=False)
if _WRAPS_MODULE is not None:
monkeypatch.setattr(_WRAPS_MODULE, "current_account_with_tenant", _getter)
login_module = importlib.import_module("libs.login")
monkeypatch.setattr(login_module, "_get_user", lambda: user)
def _i18n(text: str) -> dict[str, str]:
return {"en_US": text, "zh_Hans": text, "pt_BR": text, "ja_JP": text}
def _tool_response(controller_module: ModuleType, name: str = "tool-a") -> tuple[dict, dict]:
expected = {
"author": "Dify",
"name": name,
"label": _i18n(name),
"description": _i18n(f"{name} description"),
"parameters": [],
"labels": [],
"output_schema": {},
}
tool = controller_module.ToolApiEntity.model_validate(expected)
return tool.model_dump(mode="json"), expected
def _provider_entity_response(
controller_module: ModuleType, name: str = "provider", provider_type: str = "builtin"
) -> tuple[CoreToolProviderApiEntity, dict]:
service_payload = {
"id": f"{name}-id",
"author": "Dify",
"name": name,
"description": _i18n(f"{name} description"),
"icon": "tool.svg",
"icon_dark": "",
"label": _i18n(name),
"type": provider_type,
"masked_credentials": {"api_key": "[__HIDDEN__]"},
"original_credentials": {"api_key": "sk-secret"},
"is_team_authorization": False,
"allow_delete": True,
"plugin_id": "",
"plugin_unique_identifier": "",
"tools": [],
"labels": [],
"server_url": "",
"updated_at": 1,
"server_identifier": "",
"masked_headers": None,
"original_headers": None,
"authentication": None,
"is_dynamic_registration": True,
"configuration": None,
"identity_mode": "off",
"workflow_app_id": None,
}
provider = CoreToolProviderApiEntity.model_validate(service_payload)
return provider, provider.to_dict()
def _provider_list_item(
controller_module: ModuleType, name: str = "provider", provider_type: str = "builtin"
) -> tuple[dict, dict]:
service_payload = {
"id": f"{name}-id",
"author": "Dify",
"name": name,
"description": _i18n(f"{name} description"),
"icon": "tool.svg",
"icon_dark": "",
"label": _i18n(name),
"type": provider_type,
"team_credentials": {"api_key": "[__HIDDEN__]"},
"is_team_authorization": False,
"allow_delete": True,
"plugin_id": "",
"plugin_unique_identifier": "",
"tools": [],
"labels": [],
}
expected = {
**service_payload,
}
provider = controller_module.ToolProviderApiEntityResponse.model_validate(expected)
return service_payload, provider.model_dump(mode="json", exclude_unset=True)
def _credential_response(controller_module: ModuleType, credential_id: str = "cred-1") -> tuple[dict, dict]:
expected = {
"id": credential_id,
"name": "Credential",
"provider": "demo",
"credential_type": controller_module.CredentialType.API_KEY,
"is_default": False,
"credentials": {},
"visibility": "all_team_members",
"created_by": "",
"partial_member_list": [],
"from_other_member": False,
}
credential = controller_module.ToolProviderCredentialApiEntity.model_validate(expected)
return credential.model_dump(mode="json"), credential.model_dump(mode="json")
def _provider_config_response(controller_module: ModuleType) -> tuple[dict, dict]:
expected = {
"type": "secret-input",
"name": "api_key",
"scope": None,
"required": False,
"default": None,
"options": None,
"multiple": False,
"label": None,
"help": None,
"url": None,
"placeholder": None,
}
config = controller_module.ProviderConfig.model_validate(expected)
return config.model_dump(mode="json"), expected
def _api_provider_detail_response(controller_module: ModuleType) -> tuple[dict, dict]:
expected = {
"schema_type": "openapi",
"schema": "{}",
"tools": [],
"icon": {"background": "#252525", "content": "tool"},
"description": "provider description",
"credentials": {"auth_type": "none"},
"privacy_policy": "",
"custom_disclaimer": "",
"labels": [],
}
detail = controller_module.ApiProviderDetailResponse.model_validate(expected)
return detail.model_dump(mode="json", by_alias=True), expected
def _workflow_detail_response(controller_module: ModuleType) -> tuple[dict, dict]:
tool_payload, tool_expected = _tool_response(controller_module, "workflow-tool")
expected = {
"name": "workflow-tool",
"label": "Workflow Tool",
"workflow_tool_id": "00000000-0000-0000-0000-000000000001",
"workflow_app_id": "00000000-0000-0000-0000-000000000002",
"icon": {"background": "#252525", "content": "tool"},
"description": "description",
"parameters": [],
"output_schema": {},
"tool": tool_expected,
"synced": True,
"privacy_policy": "",
}
service_payload = {**expected, "tool": tool_payload}
detail = controller_module.WorkflowToolDetailResponse.model_validate(service_payload)
return detail.model_dump(mode="json"), expected
def _tool_label_response(controller_module: ModuleType, name: str = "search") -> tuple[dict, dict]:
expected = {"name": name, "label": _i18n(name), "icon": "search"}
label = controller_module.ToolLabel.model_validate(expected)
return label.model_dump(mode="json"), expected
def test_tool_provider_list_calls_service_with_query(
app: Flask, controller_module: ModuleType, monkeypatch: pytest.MonkeyPatch
):
user = _mock_account()
_set_current_account(monkeypatch, controller_module, user, "tenant-456")
service_mock = MagicMock(return_value=[{"provider": "builtin"}])
service_payload, expected_response = _provider_list_item(controller_module, "builtin", "builtin")
service_mock = MagicMock(return_value=[service_payload])
monkeypatch.setattr(controller_module.ToolCommonService, "list_tool_providers", service_mock)
with app.test_request_context("/workspaces/current/tool-providers?type=builtin"):
api = controller_module.ToolProviderListApi()
response = unwrap(api.get)(api, "tenant-456", user)
response = controller_module.ToolProviderListApi().get()
assert response == [{"provider": "builtin"}]
assert response == [expected_response]
service_mock.assert_called_once_with(user.id, "tenant-456", "builtin")
@ -108,8 +287,9 @@ def test_builtin_provider_add_passes_payload(
app: Flask, controller_module: ModuleType, monkeypatch: pytest.MonkeyPatch
):
user = _mock_account()
_set_current_account(monkeypatch, controller_module, user, "tenant-456")
service_mock = MagicMock(return_value={"status": "ok"})
service_mock = MagicMock(return_value={"result": "success"})
monkeypatch.setattr(controller_module.BuiltinToolManageService, "add_builtin_tool_provider", service_mock)
payload = {
@ -123,10 +303,9 @@ def test_builtin_provider_add_passes_payload(
method="POST",
json=payload,
):
api = controller_module.ToolBuiltinProviderAddApi()
response = unwrap(api.post)(api, "tenant-456", user, provider="openai")
response = controller_module.ToolBuiltinProviderAddApi().post(provider="openai")
assert response == {"status": "ok"}
assert response == {"result": "success"}
service_mock.assert_called_once_with(
user_id="user-123",
tenant_id="tenant-456",
@ -140,38 +319,88 @@ def test_builtin_provider_add_passes_payload(
def test_builtin_provider_tools_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
user = _mock_account("user-tenant-789")
_set_current_account(monkeypatch, controller_module, user, "tenant-789")
service_mock = MagicMock(return_value=[{"name": "tool-a"}])
service_payload, expected_response = _tool_response(controller_module, "tool-a")
service_mock = MagicMock(return_value=[service_payload])
monkeypatch.setattr(controller_module.BuiltinToolManageService, "list_builtin_tool_provider_tools", service_mock)
monkeypatch.setattr(controller_module, "jsonable_encoder", lambda payload: payload)
with app.test_request_context(
"/workspaces/current/tool-provider/builtin/my-provider/tools",
method="GET",
):
api = controller_module.ToolBuiltinProviderListToolsApi()
response = unwrap(api.get)(api, "tenant-789", provider="my-provider")
response = controller_module.ToolBuiltinProviderListToolsApi().get(provider="my-provider")
assert response == [{"name": "tool-a"}]
assert response == [expected_response]
service_mock.assert_called_once_with("tenant-789", "my-provider")
def test_builtin_provider_info_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
user = _mock_account("user-tenant-9")
service_mock = MagicMock(return_value={"info": True})
_set_current_account(monkeypatch, controller_module, user, "tenant-9")
service_payload, expected_response = _provider_entity_response(controller_module, "demo", "builtin")
service_mock = MagicMock(return_value=service_payload)
monkeypatch.setattr(controller_module.BuiltinToolManageService, "get_builtin_tool_provider_info", service_mock)
with app.test_request_context("/info", method="GET"):
api = controller_module.ToolBuiltinProviderInfoApi()
resp = unwrap(api.get)(api, "tenant-9", provider="demo")
resp = controller_module.ToolBuiltinProviderInfoApi().get(provider="demo")
assert resp == {"info": True}
assert resp == expected_response
service_mock.assert_called_once_with("tenant-9", "demo")
def test_builtin_provider_info_uses_core_to_dict_tool_projection(
app: Flask, controller_module: ModuleType, monkeypatch: pytest.MonkeyPatch
):
user = _mock_account("user-tenant-9")
_set_current_account(monkeypatch, controller_module, user, "tenant-9")
tool_parameter = ToolParameter(
name="system_files",
label=I18nObject(en_US="System Files", zh_Hans="System Files"),
type=ToolParameter.ToolParameterType.SYSTEM_FILES,
form=ToolParameter.ToolParameterForm.LLM,
input_schema=None,
)
tool = controller_module.ToolApiEntity(
author="Dify",
name="demo-tool",
label=I18nObject(en_US="Demo Tool", zh_Hans="Demo Tool"),
description=I18nObject(en_US="Demo Tool description", zh_Hans="Demo Tool description"),
parameters=[tool_parameter],
labels=[],
output_schema={},
)
provider = CoreToolProviderApiEntity(
id="demo-id",
author="Dify",
name="demo",
description=I18nObject(en_US="demo description", zh_Hans="demo description"),
icon="tool.svg",
label=I18nObject(en_US="demo", zh_Hans="demo"),
type=controller_module.ToolProviderType.BUILT_IN,
masked_credentials={"api_key": "[__HIDDEN__]"},
original_credentials={"api_key": "sk-secret"},
tools=[tool],
)
service_mock = MagicMock(return_value=provider)
monkeypatch.setattr(controller_module.BuiltinToolManageService, "get_builtin_tool_provider_info", service_mock)
with app.test_request_context("/info", method="GET"):
resp = controller_module.ToolBuiltinProviderInfoApi().get(provider="demo")
parameter = resp["tools"][0]["parameters"][0]
assert parameter["type"] == "files"
assert parameter["input_schema"] is None
assert resp["team_credentials"] == {"api_key": "[__HIDDEN__]"}
assert "masked_credentials" not in resp
assert "original_credentials" not in resp
def test_builtin_provider_credentials_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
user = _mock_account("user-tenant-cred")
service_mock = MagicMock(return_value=[{"cred": 1}])
_set_current_account(monkeypatch, controller_module, user, "tenant-cred")
service_payload, expected_response = _credential_response(controller_module)
service_mock = MagicMock(return_value=[service_payload])
monkeypatch.setattr(
controller_module.BuiltinToolManageService,
"get_builtin_tool_provider_credentials",
@ -179,10 +408,9 @@ def test_builtin_provider_credentials_get(app: Flask, controller_module, monkeyp
)
with app.test_request_context("/creds", method="GET"):
api = controller_module.ToolBuiltinProviderGetCredentialsApi()
resp = unwrap(api.get)(api, "tenant-cred", user, provider="demo")
resp = controller_module.ToolBuiltinProviderGetCredentialsApi().get(provider="demo")
assert resp == [{"cred": 1}]
assert resp == [expected_response]
service_mock.assert_called_once_with(
tenant_id="tenant-cred",
provider_name="demo",
@ -193,46 +421,51 @@ def test_builtin_provider_credentials_get(app: Flask, controller_module, monkeyp
def test_api_provider_remote_schema_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
user = _mock_account()
service_mock = MagicMock(return_value={"schema": "ok"})
_set_current_account(monkeypatch, controller_module, user, "tenant-10")
openapi_schema = '{"openapi":"3.0.0","info":{"title":"Demo API","version":"1.0.0"},"paths":{}}'
service_mock = MagicMock(return_value={"schema": openapi_schema})
monkeypatch.setattr(controller_module.ApiToolManageService, "get_api_tool_provider_remote_schema", service_mock)
with app.test_request_context("/remote?url=https://example.com/"):
api = controller_module.ToolApiProviderGetRemoteSchemaApi()
resp = unwrap(api.get)(api, "tenant-10", user)
resp = controller_module.ToolApiProviderGetRemoteSchemaApi().get()
assert resp == {"schema": "ok"}
assert resp == {"schema": openapi_schema}
service_mock.assert_called_once_with(user.id, "tenant-10", "https://example.com/")
def test_api_provider_list_tools_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
user = _mock_account()
service_mock = MagicMock(return_value=[{"tool": "t"}])
_set_current_account(monkeypatch, controller_module, user, "tenant-11")
service_payload, expected_response = _tool_response(controller_module, "t")
service_mock = MagicMock(return_value=[service_payload])
monkeypatch.setattr(controller_module.ApiToolManageService, "list_api_tool_provider_tools", service_mock)
with app.test_request_context("/tools?provider=foo"):
api = controller_module.ToolApiProviderListToolsApi()
resp = unwrap(api.get)(api, "tenant-11", user)
resp = controller_module.ToolApiProviderListToolsApi().get()
assert resp == [{"tool": "t"}]
assert resp == [expected_response]
service_mock.assert_called_once_with(user.id, "tenant-11", "foo")
def test_api_provider_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
user = _mock_account()
service_mock = MagicMock(return_value={"provider": "foo"})
_set_current_account(monkeypatch, controller_module, user, "tenant-12")
service_payload, expected_response = _api_provider_detail_response(controller_module)
service_mock = MagicMock(return_value=service_payload)
monkeypatch.setattr(controller_module.ApiToolManageService, "get_api_tool_provider", service_mock)
with app.test_request_context("/get?provider=foo"):
api = controller_module.ToolApiProviderGetApi()
resp = unwrap(api.get)(api, "tenant-12", user)
resp = controller_module.ToolApiProviderGetApi().get()
assert resp == {"provider": "foo"}
assert resp == expected_response
service_mock.assert_called_once_with(user.id, "tenant-12", "foo")
def test_builtin_provider_credentials_schema_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
user = _mock_account("user-tenant-13")
service_mock = MagicMock(return_value={"schema": True})
_set_current_account(monkeypatch, controller_module, user, "tenant-13")
service_payload, expected_response = _provider_config_response(controller_module)
service_mock = MagicMock(return_value=[service_payload])
monkeypatch.setattr(
controller_module.BuiltinToolManageService,
"list_builtin_provider_credentials_schema",
@ -240,16 +473,19 @@ def test_builtin_provider_credentials_schema_get(app: Flask, controller_module,
)
with app.test_request_context("/schema", method="GET"):
api = controller_module.ToolBuiltinProviderCredentialsSchemaApi()
resp = unwrap(api.get)(api, "tenant-13", provider="demo", credential_type="api-key")
resp = controller_module.ToolBuiltinProviderCredentialsSchemaApi().get(
provider="demo", credential_type="api-key"
)
assert resp == {"schema": True}
assert resp == [expected_response]
service_mock.assert_called_once()
def test_workflow_provider_get_by_tool(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
user = _mock_account()
tool_service = MagicMock(return_value={"wf": 1})
_set_current_account(monkeypatch, controller_module, user, "tenant-wf")
service_payload, expected_response = _workflow_detail_response(controller_module)
tool_service = MagicMock(return_value=service_payload)
monkeypatch.setattr(
controller_module.WorkflowToolManageService,
"get_workflow_tool_by_tool_id",
@ -258,16 +494,17 @@ def test_workflow_provider_get_by_tool(app: Flask, controller_module, monkeypatc
tool_id = "00000000-0000-0000-0000-000000000001"
with app.test_request_context(f"/workflow?workflow_tool_id={tool_id}"):
api = controller_module.ToolWorkflowProviderGetApi()
resp = unwrap(api.get)(api, "tenant-wf", user)
resp = controller_module.ToolWorkflowProviderGetApi().get()
assert resp == {"wf": 1}
assert resp == expected_response
tool_service.assert_called_once_with(user.id, "tenant-wf", tool_id)
def test_workflow_provider_get_by_app(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
user = _mock_account()
service_mock = MagicMock(return_value={"app": 1})
_set_current_account(monkeypatch, controller_module, user, "tenant-wf2")
service_payload, expected_response = _workflow_detail_response(controller_module)
service_mock = MagicMock(return_value=service_payload)
monkeypatch.setattr(
controller_module.WorkflowToolManageService,
"get_workflow_tool_by_app_id",
@ -276,31 +513,32 @@ def test_workflow_provider_get_by_app(app: Flask, controller_module, monkeypatch
app_id = "00000000-0000-0000-0000-000000000002"
with app.test_request_context(f"/workflow?workflow_app_id={app_id}"):
api = controller_module.ToolWorkflowProviderGetApi()
resp = unwrap(api.get)(api, "tenant-wf2", user)
resp = controller_module.ToolWorkflowProviderGetApi().get()
assert resp == {"app": 1}
assert resp == expected_response
service_mock.assert_called_once_with(user.id, "tenant-wf2", app_id)
def test_workflow_provider_list_tools(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
user = _mock_account()
service_mock = MagicMock(return_value=[{"id": 1}])
_set_current_account(monkeypatch, controller_module, user, "tenant-wf3")
service_payload, expected_response = _tool_response(controller_module, "workflow-tool")
service_mock = MagicMock(return_value=[service_payload])
monkeypatch.setattr(controller_module.WorkflowToolManageService, "list_single_workflow_tools", service_mock)
tool_id = "00000000-0000-0000-0000-000000000003"
with app.test_request_context(f"/workflow/tools?workflow_tool_id={tool_id}"):
api = controller_module.ToolWorkflowProviderListToolApi()
resp = unwrap(api.get)(api, "tenant-wf3", user)
resp = controller_module.ToolWorkflowProviderListToolApi().get()
assert resp == [{"id": 1}]
assert resp == [expected_response]
service_mock.assert_called_once_with(user.id, "tenant-wf3", tool_id)
def test_builtin_tools_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
user = _mock_account()
_set_current_account(monkeypatch, controller_module, user, "tenant-bt")
provider = SimpleNamespace(to_dict=lambda: {"name": "builtin"})
provider, expected_response = _provider_entity_response(controller_module, "builtin", "builtin")
monkeypatch.setattr(
controller_module.BuiltinToolManageService,
"list_builtin_tools",
@ -308,16 +546,16 @@ def test_builtin_tools_list(app: Flask, controller_module, monkeypatch: pytest.M
)
with app.test_request_context("/tools/builtin"):
api = controller_module.ToolBuiltinListApi()
resp = unwrap(api.get)(api, "tenant-bt", user)
resp = controller_module.ToolBuiltinListApi().get()
assert resp == [{"name": "builtin"}]
assert resp == [expected_response]
def test_api_tools_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
user = _mock_account("user-tenant-api")
_set_current_account(monkeypatch, controller_module, user, "tenant-api")
provider = SimpleNamespace(to_dict=lambda: {"name": "api"})
provider, expected_response = _provider_entity_response(controller_module, "api", "api")
monkeypatch.setattr(
controller_module.ApiToolManageService,
"list_api_tools",
@ -325,16 +563,16 @@ def test_api_tools_list(app: Flask, controller_module, monkeypatch: pytest.Monke
)
with app.test_request_context("/tools/api"):
api = controller_module.ToolApiListApi()
resp = unwrap(api.get)(api, "tenant-api")
resp = controller_module.ToolApiListApi().get()
assert resp == [{"name": "api"}]
assert resp == [expected_response]
def test_workflow_tools_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
user = _mock_account()
_set_current_account(monkeypatch, controller_module, user, "tenant-wf4")
provider = SimpleNamespace(to_dict=lambda: {"name": "wf"})
provider, expected_response = _provider_entity_response(controller_module, "wf", "workflow")
monkeypatch.setattr(
controller_module.WorkflowToolManageService,
"list_tenant_workflow_tools",
@ -342,20 +580,21 @@ def test_workflow_tools_list(app: Flask, controller_module, monkeypatch: pytest.
)
with app.test_request_context("/tools/workflow"):
api = controller_module.ToolWorkflowListApi()
resp = unwrap(api.get)(api, "tenant-wf4", user)
resp = controller_module.ToolWorkflowListApi().get()
assert resp == [{"name": "wf"}]
assert resp == [expected_response]
def test_tool_labels_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(controller_module.ToolLabelsService, "list_tool_labels", lambda: ["a", "b"])
user = _mock_account("user-label")
_set_current_account(monkeypatch, controller_module, user, "tenant-labels")
service_payload, expected_response = _tool_label_response(controller_module, "a")
monkeypatch.setattr(controller_module.ToolLabelsService, "list_tool_labels", lambda: [service_payload])
with app.test_request_context("/tool-labels"):
api = controller_module.ToolLabelsApi()
resp = unwrap(api.get)(api)
resp = controller_module.ToolLabelsApi().get()
assert resp == ["a", "b"]
assert resp == [expected_response]
# --- _resolve_identity_mode: gating + None-resolution (PR #36839 review) ---

View File

@ -1009,8 +1009,7 @@ class TestExternalDatasetApi:
# 4. Provide default values when parameters are missing
# 5. Raise BadRequest exceptions when validation fails
#
# Response formatting is handled by Flask-RESTX's marshal_with decorator
# or marshal function, which:
# Response formatting is handled by controller response schemas, which:
#
# 1. Formats response data according to defined models
# 2. Handles nested objects and lists

View File

@ -269,6 +269,7 @@ export type MessageDetailResponse = {
agent_thoughts?: Array<AgentThought>
annotation?: ConversationAnnotation | null
annotation_hit_history?: ConversationAnnotationHitHistory | null
answer: string
answer_tokens?: number | null
conversation_id: string
created_at?: number | null
@ -284,12 +285,11 @@ export type MessageDetailResponse = {
}
message?: JsonValue | null
message_files?: Array<MessageFile>
message_metadata_dict?: JsonValue | null
message_tokens?: number | null
metadata?: JsonValue | null
parent_message_id?: string | null
provider_response_latency?: number | null
query: string
re_sign_file_url_answer: string
status: string
workflow_run_id?: string | null
}
@ -723,7 +723,6 @@ export type AgentThought = {
created_at?: number | null
files: Array<string>
id: string
message_chain_id?: string | null
message_id: string
observation?: string | null
position: number
@ -743,8 +742,8 @@ export type ConversationAnnotation = {
export type ConversationAnnotationHitHistory = {
annotation_create_account?: SimpleAccount | null
annotation_id: string
created_at?: number | null
id: string
}
export type HumanInputContent = {

View File

@ -570,7 +570,6 @@ export const zAgentThought = z.object({
created_at: z.int().nullish(),
files: z.array(z.string()),
id: z.string(),
message_chain_id: z.string().nullish(),
message_id: z.string(),
observation: z.string().nullish(),
position: z.int(),
@ -1056,8 +1055,8 @@ export const zConversationAnnotation = z.object({
*/
export const zConversationAnnotationHitHistory = z.object({
annotation_create_account: zSimpleAccount.nullish(),
annotation_id: z.string(),
created_at: z.int().nullish(),
id: z.string(),
})
/**
@ -2035,6 +2034,7 @@ export const zMessageDetailResponse = z.object({
agent_thoughts: z.array(zAgentThought).optional(),
annotation: zConversationAnnotation.nullish(),
annotation_hit_history: zConversationAnnotationHitHistory.nullish(),
answer: z.string(),
answer_tokens: z.int().nullish(),
conversation_id: z.string(),
created_at: z.int().nullish(),
@ -2048,12 +2048,11 @@ export const zMessageDetailResponse = z.object({
inputs: z.record(z.string(), zJsonValue),
message: zJsonValue.nullish(),
message_files: z.array(zMessageFile).optional(),
message_metadata_dict: zJsonValue.nullish(),
message_tokens: z.int().nullish(),
metadata: zJsonValue.nullish(),
parent_message_id: z.string().nullish(),
provider_response_latency: z.number().nullish(),
query: z.string(),
re_sign_file_url_answer: z.string(),
status: z.string(),
workflow_run_id: z.string().nullish(),
})

View File

@ -472,6 +472,7 @@ export type MessageDetailResponse = {
agent_thoughts?: Array<AgentThought>
annotation?: ConversationAnnotation | null
annotation_hit_history?: ConversationAnnotationHitHistory | null
answer: string
answer_tokens?: number | null
conversation_id: string
created_at?: number | null
@ -487,12 +488,11 @@ export type MessageDetailResponse = {
}
message?: JsonValue | null
message_files?: Array<MessageFile>
message_metadata_dict?: JsonValue | null
message_tokens?: number | null
metadata?: JsonValue | null
parent_message_id?: string | null
provider_response_latency?: number | null
query: string
re_sign_file_url_answer: string
status: string
workflow_run_id?: string | null
}
@ -1498,7 +1498,6 @@ export type AgentThought = {
created_at?: number | null
files: Array<string>
id: string
message_chain_id?: string | null
message_id: string
observation?: string | null
position: number
@ -1518,8 +1517,8 @@ export type ConversationAnnotation = {
export type ConversationAnnotationHitHistory = {
annotation_create_account?: SimpleAccount | null
annotation_id: string
created_at?: number | null
id: string
}
export type HumanInputContent = {
@ -1984,7 +1983,14 @@ export type ModelConfigPartial = {
export type LlmMode = 'chat' | 'completion'
export type Type = 'github' | 'marketplace' | 'package'
export type Type
= | 'app-selector'
| 'array[tools]'
| 'boolean'
| 'model-selector'
| 'secret-input'
| 'select'
| 'text-input'
export type Github = {
github_plugin_unique_identifier: string

View File

@ -1150,7 +1150,6 @@ export const zAgentThought = z.object({
created_at: z.int().nullish(),
files: z.array(z.string()),
id: z.string(),
message_chain_id: z.string().nullish(),
message_id: z.string(),
observation: z.string().nullish(),
position: z.int(),
@ -1371,8 +1370,8 @@ export const zConversationAnnotation = z.object({
*/
export const zConversationAnnotationHitHistory = z.object({
annotation_create_account: zSimpleAccount.nullish(),
annotation_id: z.string(),
created_at: z.int().nullish(),
id: z.string(),
})
/**
@ -2169,7 +2168,15 @@ export const zConversationMessageDetail = z.object({
/**
* Type
*/
export const zType = z.enum(['github', 'marketplace', 'package'])
export const zType = z.enum([
'app-selector',
'array[tools]',
'boolean',
'model-selector',
'secret-input',
'select',
'text-input',
])
/**
* Github
@ -3455,6 +3462,7 @@ export const zMessageDetailResponse = z.object({
agent_thoughts: z.array(zAgentThought).optional(),
annotation: zConversationAnnotation.nullish(),
annotation_hit_history: zConversationAnnotationHitHistory.nullish(),
answer: z.string(),
answer_tokens: z.int().nullish(),
conversation_id: z.string(),
created_at: z.int().nullish(),
@ -3468,12 +3476,11 @@ export const zMessageDetailResponse = z.object({
inputs: z.record(z.string(), zJsonValue),
message: zJsonValue.nullish(),
message_files: z.array(zMessageFile).optional(),
message_metadata_dict: zJsonValue.nullish(),
message_tokens: z.int().nullish(),
metadata: zJsonValue.nullish(),
parent_message_id: z.string().nullish(),
provider_response_latency: z.number().nullish(),
query: z.string(),
re_sign_file_url_answer: z.string(),
status: z.string(),
workflow_run_id: z.string().nullish(),
})

View File

@ -246,7 +246,6 @@ export type AgentThought = {
created_at?: number | null
files: Array<string>
id: string
message_chain_id?: string | null
message_id: string
observation?: string | null
position: number

View File

@ -266,7 +266,6 @@ export const zAgentThought = z.object({
created_at: z.int().nullish(),
files: z.array(z.string()),
id: z.string(),
message_chain_id: z.string().nullish(),
message_id: z.string(),
observation: z.string().nullish(),
position: z.int(),

View File

@ -145,7 +145,7 @@ export const zGetOauthPluginByProviderToolAuthorizationUrlPath = z.object({
})
/**
* Authorization URL retrieved successfully
* Tool OAuth authorization URL generated successfully
*/
export const zGetOauthPluginByProviderToolAuthorizationUrlResponse
= zPluginOAuthAuthorizationUrlResponse

View File

@ -515,7 +515,14 @@ export type DatasetWeightedScoreResponse = {
weight_type?: string | null
}
export type Type = 'github' | 'marketplace' | 'package'
export type Type
= | 'app-selector'
| 'array[tools]'
| 'boolean'
| 'model-selector'
| 'secret-input'
| 'select'
| 'text-input'
export type Github = {
github_plugin_unique_identifier: string

View File

@ -547,7 +547,15 @@ export const zDatasetRerankingModelResponse = z.object({
/**
* Type
*/
export const zType = z.enum(['github', 'marketplace', 'package'])
export const zType = z.enum([
'app-selector',
'array[tools]',
'boolean',
'model-selector',
'secret-input',
'select',
'text-input',
])
/**
* Github

View File

@ -157,14 +157,10 @@ import {
zGetWorkspacesCurrentToolProviderBuiltinByProviderCredentialsPath,
zGetWorkspacesCurrentToolProviderBuiltinByProviderCredentialsQuery,
zGetWorkspacesCurrentToolProviderBuiltinByProviderCredentialsResponse,
zGetWorkspacesCurrentToolProviderBuiltinByProviderIconPath,
zGetWorkspacesCurrentToolProviderBuiltinByProviderIconResponse,
zGetWorkspacesCurrentToolProviderBuiltinByProviderInfoPath,
zGetWorkspacesCurrentToolProviderBuiltinByProviderInfoResponse,
zGetWorkspacesCurrentToolProviderBuiltinByProviderOauthClientSchemaPath,
zGetWorkspacesCurrentToolProviderBuiltinByProviderOauthClientSchemaResponse,
zGetWorkspacesCurrentToolProviderBuiltinByProviderOauthCustomClientPath,
zGetWorkspacesCurrentToolProviderBuiltinByProviderOauthCustomClientResponse,
zGetWorkspacesCurrentToolProviderBuiltinByProviderToolsPath,
zGetWorkspacesCurrentToolProviderBuiltinByProviderToolsResponse,
zGetWorkspacesCurrentToolProviderMcpToolsByProviderIdPath,
@ -181,8 +177,6 @@ import {
zGetWorkspacesCurrentToolsBuiltinResponse,
zGetWorkspacesCurrentToolsMcpResponse,
zGetWorkspacesCurrentToolsWorkflowResponse,
zGetWorkspacesCurrentTriggerProviderByProviderIconPath,
zGetWorkspacesCurrentTriggerProviderByProviderIconResponse,
zGetWorkspacesCurrentTriggerProviderByProviderInfoPath,
zGetWorkspacesCurrentTriggerProviderByProviderInfoResponse,
zGetWorkspacesCurrentTriggerProviderByProviderOauthClientPath,
@ -3212,21 +3206,6 @@ export const delete14 = {
}
export const get67 = oc
.route({
inputStructure: 'detailed',
method: 'GET',
operationId: 'getWorkspacesCurrentToolProviderBuiltinByProviderIcon',
path: '/workspaces/current/tool-provider/builtin/{provider}/icon',
tags: ['console'],
})
.input(z.object({ params: zGetWorkspacesCurrentToolProviderBuiltinByProviderIconPath }))
.output(zGetWorkspacesCurrentToolProviderBuiltinByProviderIconResponse)
export const icon2 = {
get: get67,
}
export const get68 = oc
.route({
inputStructure: 'detailed',
method: 'GET',
@ -3238,10 +3217,10 @@ export const get68 = oc
.output(zGetWorkspacesCurrentToolProviderBuiltinByProviderInfoResponse)
export const info2 = {
get: get68,
get: get67,
}
export const get69 = oc
export const get68 = oc
.route({
inputStructure: 'detailed',
method: 'GET',
@ -3255,7 +3234,7 @@ export const get69 = oc
.output(zGetWorkspacesCurrentToolProviderBuiltinByProviderOauthClientSchemaResponse)
export const clientSchema = {
get: get69,
get: get68,
}
export const delete15 = oc
@ -3273,19 +3252,6 @@ export const delete15 = oc
)
.output(zDeleteWorkspacesCurrentToolProviderBuiltinByProviderOauthCustomClientResponse)
export const get70 = oc
.route({
inputStructure: 'detailed',
method: 'GET',
operationId: 'getWorkspacesCurrentToolProviderBuiltinByProviderOauthCustomClient',
path: '/workspaces/current/tool-provider/builtin/{provider}/oauth/custom-client',
tags: ['console'],
})
.input(
z.object({ params: zGetWorkspacesCurrentToolProviderBuiltinByProviderOauthCustomClientPath }),
)
.output(zGetWorkspacesCurrentToolProviderBuiltinByProviderOauthCustomClientResponse)
export const post56 = oc
.route({
inputStructure: 'detailed',
@ -3304,7 +3270,6 @@ export const post56 = oc
export const customClient = {
delete: delete15,
get: get70,
post: post56,
}
@ -3313,7 +3278,7 @@ export const oauth = {
customClient,
}
export const get71 = oc
export const get69 = oc
.route({
inputStructure: 'detailed',
method: 'GET',
@ -3325,7 +3290,7 @@ export const get71 = oc
.output(zGetWorkspacesCurrentToolProviderBuiltinByProviderToolsResponse)
export const tools2 = {
get: get71,
get: get69,
}
export const post57 = oc
@ -3354,7 +3319,6 @@ export const byProvider2 = {
credentials: credentials3,
defaultCredential,
delete: delete14,
icon: icon2,
info: info2,
oauth,
tools: tools2,
@ -3380,7 +3344,7 @@ export const auth = {
post: post58,
}
export const get72 = oc
export const get70 = oc
.route({
inputStructure: 'detailed',
method: 'GET',
@ -3392,14 +3356,14 @@ export const get72 = oc
.output(zGetWorkspacesCurrentToolProviderMcpToolsByProviderIdResponse)
export const byProviderId = {
get: get72,
get: get70,
}
export const tools3 = {
byProviderId,
}
export const get73 = oc
export const get71 = oc
.route({
inputStructure: 'detailed',
method: 'GET',
@ -3411,7 +3375,7 @@ export const get73 = oc
.output(zGetWorkspacesCurrentToolProviderMcpUpdateByProviderIdResponse)
export const byProviderId2 = {
get: get73,
get: get71,
}
export const update4 = {
@ -3490,7 +3454,7 @@ export const delete17 = {
post: post61,
}
export const get74 = oc
export const get72 = oc
.route({
inputStructure: 'detailed',
method: 'GET',
@ -3501,11 +3465,11 @@ export const get74 = oc
.input(z.object({ query: zGetWorkspacesCurrentToolProviderWorkflowGetQuery.optional() }))
.output(zGetWorkspacesCurrentToolProviderWorkflowGetResponse)
export const get75 = {
get: get74,
export const get73 = {
get: get72,
}
export const get76 = oc
export const get74 = oc
.route({
inputStructure: 'detailed',
method: 'GET',
@ -3517,7 +3481,7 @@ export const get76 = oc
.output(zGetWorkspacesCurrentToolProviderWorkflowToolsResponse)
export const tools4 = {
get: get76,
get: get74,
}
export const post62 = oc
@ -3538,7 +3502,7 @@ export const update5 = {
export const workflow = {
create: create2,
delete: delete17,
get: get75,
get: get73,
tools: tools4,
update: update5,
}
@ -3550,7 +3514,7 @@ export const toolProvider = {
workflow,
}
export const get77 = oc
export const get75 = oc
.route({
inputStructure: 'detailed',
method: 'GET',
@ -3562,10 +3526,10 @@ export const get77 = oc
.output(zGetWorkspacesCurrentToolProvidersResponse)
export const toolProviders = {
get: get77,
get: get75,
}
export const get78 = oc
export const get76 = oc
.route({
inputStructure: 'detailed',
method: 'GET',
@ -3576,10 +3540,10 @@ export const get78 = oc
.output(zGetWorkspacesCurrentToolsApiResponse)
export const api2 = {
get: get78,
get: get76,
}
export const get79 = oc
export const get77 = oc
.route({
inputStructure: 'detailed',
method: 'GET',
@ -3590,10 +3554,10 @@ export const get79 = oc
.output(zGetWorkspacesCurrentToolsBuiltinResponse)
export const builtin2 = {
get: get79,
get: get77,
}
export const get80 = oc
export const get78 = oc
.route({
inputStructure: 'detailed',
method: 'GET',
@ -3604,10 +3568,10 @@ export const get80 = oc
.output(zGetWorkspacesCurrentToolsMcpResponse)
export const mcp2 = {
get: get80,
get: get78,
}
export const get81 = oc
export const get79 = oc
.route({
inputStructure: 'detailed',
method: 'GET',
@ -3618,7 +3582,7 @@ export const get81 = oc
.output(zGetWorkspacesCurrentToolsWorkflowResponse)
export const workflow2 = {
get: get81,
get: get79,
}
export const tools5 = {
@ -3628,25 +3592,10 @@ export const tools5 = {
workflow: workflow2,
}
export const get82 = oc
.route({
inputStructure: 'detailed',
method: 'GET',
operationId: 'getWorkspacesCurrentTriggerProviderByProviderIcon',
path: '/workspaces/current/trigger-provider/{provider}/icon',
tags: ['console'],
})
.input(z.object({ params: zGetWorkspacesCurrentTriggerProviderByProviderIconPath }))
.output(zGetWorkspacesCurrentTriggerProviderByProviderIconResponse)
export const icon3 = {
get: get82,
}
/**
* Get info for a trigger provider
*/
export const get83 = oc
export const get80 = oc
.route({
inputStructure: 'detailed',
method: 'GET',
@ -3659,7 +3608,7 @@ export const get83 = oc
.output(zGetWorkspacesCurrentTriggerProviderByProviderInfoResponse)
export const info3 = {
get: get83,
get: get80,
}
/**
@ -3680,7 +3629,7 @@ export const delete18 = oc
/**
* Get OAuth client configuration for a provider
*/
export const get84 = oc
export const get81 = oc
.route({
inputStructure: 'detailed',
method: 'GET',
@ -3714,7 +3663,7 @@ export const post63 = oc
export const client = {
delete: delete18,
get: get84,
get: get81,
post: post63,
}
@ -3781,7 +3730,7 @@ export const create3 = {
/**
* Get the request logs for a subscription instance for a trigger provider
*/
export const get85 = oc
export const get82 = oc
.route({
inputStructure: 'detailed',
method: 'GET',
@ -3802,7 +3751,7 @@ export const get85 = oc
)
export const bySubscriptionBuilderId2 = {
get: get85,
get: get82,
}
export const logs = {
@ -3876,7 +3825,7 @@ export const verifyAndUpdate = {
/**
* Get a subscription instance for a trigger provider
*/
export const get86 = oc
export const get83 = oc
.route({
inputStructure: 'detailed',
method: 'GET',
@ -3897,7 +3846,7 @@ export const get86 = oc
)
export const bySubscriptionBuilderId5 = {
get: get86,
get: get83,
}
export const builder = {
@ -3912,7 +3861,7 @@ export const builder = {
/**
* List all trigger subscriptions for the current tenant's provider
*/
export const get87 = oc
export const get84 = oc
.route({
inputStructure: 'detailed',
method: 'GET',
@ -3925,13 +3874,13 @@ export const get87 = oc
.output(zGetWorkspacesCurrentTriggerProviderByProviderSubscriptionsListResponse)
export const list4 = {
get: get87,
get: get84,
}
/**
* Initiate OAuth authorization flow for a trigger provider
*/
export const get88 = oc
export const get85 = oc
.route({
inputStructure: 'detailed',
method: 'GET',
@ -3948,7 +3897,7 @@ export const get88 = oc
.output(zGetWorkspacesCurrentTriggerProviderByProviderSubscriptionsOauthAuthorizeResponse)
export const authorize = {
get: get88,
get: get85,
}
export const oauth3 = {
@ -3995,7 +3944,6 @@ export const subscriptions = {
}
export const byProvider3 = {
icon: icon3,
info: info3,
oauth: oauth2,
subscriptions,
@ -4065,7 +4013,7 @@ export const triggerProvider = {
/**
* List all trigger providers for the current tenant
*/
export const get89 = oc
export const get86 = oc
.route({
inputStructure: 'detailed',
method: 'GET',
@ -4077,7 +4025,7 @@ export const get89 = oc
.output(zGetWorkspacesCurrentTriggersResponse)
export const triggers = {
get: get89,
get: get86,
}
export const post71 = oc
@ -4177,7 +4125,7 @@ export const switch3 = {
post: post75,
}
export const get90 = oc
export const get87 = oc
.route({
inputStructure: 'detailed',
method: 'GET',
@ -4189,7 +4137,7 @@ export const get90 = oc
.output(zGetWorkspacesByTenantIdModelProvidersByProviderByIconTypeByLangResponse)
export const byLang = {
get: get90,
get: get87,
}
export const byIconType = {
@ -4208,7 +4156,7 @@ export const byTenantId = {
modelProviders: modelProviders2,
}
export const get91 = oc
export const get88 = oc
.route({
inputStructure: 'detailed',
method: 'GET',
@ -4219,7 +4167,7 @@ export const get91 = oc
.output(zGetWorkspacesResponse)
export const workspaces = {
get: get91,
get: get88,
current,
customConfig,
info: info4,

View File

@ -10,13 +10,21 @@ type SwaggerSchema = JsonObject & {
$ref?: string
}
type OpenApiMediaType = JsonObject & {
schema?: SwaggerSchema
}
type OpenApiResponse = JsonObject & {
content?: Record<string, OpenApiMediaType>
}
type OpenApiComponents = JsonObject & {
schemas?: Record<string, SwaggerSchema>
}
type SwaggerOperation = JsonObject & {
operationId?: string
responses?: Record<string, unknown>
responses?: Record<string, OpenApiResponse>
}
type SwaggerDocument = JsonObject & {
@ -52,6 +60,17 @@ const currentDir = path.dirname(fileURLToPath(import.meta.url))
const apiOpenApiDir = path.resolve(currentDir, 'openapi')
const operationMethods = new Set(['delete', 'get', 'patch', 'post', 'put'])
const pydanticDecimalStringPattern = '^(?!^[-+.]*$)[+-]?0*\\d*\\.?\\d*$'
const codegenSafeDecimalStringPattern = '^(?![-+.]*$)[+-]?0*\\d*\\.?\\d*$'
const opaqueJsonContent = (): Record<string, OpenApiMediaType> => ({
'application/json': {
schema: {
additionalProperties: true,
type: 'object',
},
},
})
const apiSpecs: ApiSpec[] = [
{ filename: 'console-openapi.json', name: 'console' },
@ -182,6 +201,46 @@ const addOperationIds = (document: SwaggerDocument) => {
}
}
const isOpaqueContractResponse = (response: OpenApiResponse) => {
const content = response.content
if (!isObject(content))
return false
return Object.entries(content).some(([mediaType, media]) => {
if (!isObject(media))
return false
return (mediaType === 'application/json' || mediaType === 'text/event-stream') && !('schema' in media)
})
}
const hasOpaqueContractSuccessResponse = (operation: SwaggerOperation) => {
return Object.entries(operation.responses ?? {}).some(([status, response]) => {
return /^2\d\d$/.test(status) && isObject(response) && isOpaqueContractResponse(response)
})
}
const normalizeOpaqueContractResponses = (document: SwaggerDocument) => {
// Some backend endpoints has no schema (e.g. external) and will trap heyapi here
// So we forge an opaque schema here
for (const pathItem of Object.values(document.paths ?? {})) {
for (const [method, operation] of Object.entries(pathItem)) {
if (!operationMethods.has(method) || !isObject(operation))
continue
const swaggerOperation = operation as SwaggerOperation
if (!hasOpaqueContractSuccessResponse(swaggerOperation))
continue
Object.values(swaggerOperation.responses ?? {})
.filter(response => isObject(response) && isOpaqueContractResponse(response))
.forEach((response) => {
response.content = opaqueJsonContent()
})
}
}
}
const hasSuccessResponse = (operation: SwaggerOperation) => {
return Object.entries(operation.responses ?? {}).some(([status, response]) => {
if (!/^2\d\d$/.test(status))
@ -215,6 +274,7 @@ const filterContractOperations = (document: SwaggerDocument) => {
}
const normalizeApiSwagger = (document: SwaggerDocument) => {
normalizeOpaqueContractResponses(document)
filterContractOperations(document)
addOperationIds(document)
@ -380,10 +440,20 @@ const createApiConfig = (job: ApiJob): UserConfig => ({
'name': 'zod',
'~resolvers': {
string: (ctx) => {
if (ctx.schema.format !== 'binary')
return undefined
if (ctx.schema.format === 'binary')
return $(ctx.symbols.z).attr('custom').call().generic($.type.or($.type('Blob'), $.type('File')))
return $(ctx.symbols.z).attr('custom').call().generic($.type.or($.type('Blob'), $.type('File')))
if (ctx.schema.pattern === pydanticDecimalStringPattern) {
// the pydantic generated regex will emit error like
// regexp/no-useless-assertions, so patch the regex here
return $(ctx.symbols.z)
.attr('string')
.call()
.attr('regex')
.call($.regexp(codegenSafeDecimalStringPattern))
}
return undefined
},
},
},

View File

@ -217,14 +217,8 @@ const toFeedback = (feedback: NonNullable<MessageDetailResponse['feedbacks']>[nu
}
}
type AgentDebugMessageWithLegacyAnswer = MessageDetailResponse & {
answer?: string | null
}
const getAgentDebugMessageAnswer = (message: MessageDetailResponse) => {
const legacyAnswer = (message as AgentDebugMessageWithLegacyAnswer).answer
return message.re_sign_file_url_answer ?? legacyAnswer ?? ''
return message.answer ?? ''
}
function getFormattedAgentDebugChatTree(messages: MessageDetailResponse[]): ChatItemInTree[] {