mirror of
https://github.com/langgenius/dify.git
synced 2026-06-26 06:41:10 +08:00
refactor(api): migrate workspace tool endpoints to BaseModel
This commit is contained in:
parent
bb921bcc45
commit
b31200c872
File diff suppressed because it is too large
Load Diff
@ -9,14 +9,21 @@ from werkzeug.exceptions import BadRequest, Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.common.errors import NotFoundError
|
||||
from controllers.common.fields import BinaryFileResponse, RedirectResponse, SimpleResultResponse
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from core.trigger.entities.entities import SubscriptionBuilderUpdater
|
||||
from core.trigger.entities.api_entities import (
|
||||
SubscriptionBuilderApiEntity,
|
||||
TriggerProviderApiEntity,
|
||||
TriggerProviderSubscriptionApiEntity,
|
||||
)
|
||||
from core.trigger.entities.entities import RequestLog, SubscriptionBuilderUpdater
|
||||
from core.trigger.trigger_manager import TriggerManager
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import dump_response
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from models.provider_ids import TriggerProviderID
|
||||
@ -51,9 +58,9 @@ class TriggerSubscriptionBuilderVerifyPayload(BaseModel):
|
||||
|
||||
class TriggerSubscriptionBuilderUpdatePayload(BaseModel):
|
||||
name: str | None = None
|
||||
parameters: dict[str, Any] | None = Field(default=None)
|
||||
properties: dict[str, Any] | None = Field(default=None)
|
||||
credentials: dict[str, Any] | None = Field(default=None)
|
||||
parameters: dict[str, Any] | None = None
|
||||
properties: dict[str, Any] | None = None
|
||||
credentials: dict[str, Any] | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_at_least_one_field(self):
|
||||
@ -63,28 +70,48 @@ class TriggerSubscriptionBuilderUpdatePayload(BaseModel):
|
||||
|
||||
|
||||
class TriggerOAuthClientPayload(BaseModel):
|
||||
client_params: dict[str, Any] | None = Field(default=None)
|
||||
client_params: dict[str, Any] | None = None
|
||||
enabled: bool | None = None
|
||||
|
||||
|
||||
class TriggerOAuthAuthorizeResponse(BaseModel):
|
||||
class TriggerProviderListResponse(RootModel[list[TriggerProviderApiEntity]]):
|
||||
pass
|
||||
|
||||
|
||||
class TriggerProviderSubscriptionListResponse(RootModel[list[TriggerProviderSubscriptionApiEntity]]):
|
||||
pass
|
||||
|
||||
|
||||
class TriggerSubscriptionBuilderCreateResponse(ResponseModel):
|
||||
subscription_builder: SubscriptionBuilderApiEntity
|
||||
|
||||
|
||||
class TriggerVerificationResponse(ResponseModel):
|
||||
verified: bool
|
||||
|
||||
|
||||
class TriggerSubscriptionBuilderLogsResponse(ResponseModel):
|
||||
logs: list[RequestLog]
|
||||
|
||||
|
||||
class TriggerOAuthAuthorizeResponse(ResponseModel):
|
||||
authorization_url: str
|
||||
subscription_builder_id: str
|
||||
subscription_builder: Any
|
||||
subscription_builder: SubscriptionBuilderApiEntity
|
||||
|
||||
|
||||
class TriggerOAuthClientResponse(BaseModel):
|
||||
class TriggerOAuthClientResponse(ResponseModel):
|
||||
configured: bool
|
||||
system_configured: bool
|
||||
custom_configured: bool
|
||||
oauth_client_schema: Any
|
||||
oauth_client_schema: list[ProviderConfig]
|
||||
custom_enabled: bool
|
||||
redirect_uri: str
|
||||
params: dict[str, Any]
|
||||
params: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class TriggerProviderOpaqueResponse(RootModel[Any]):
|
||||
root: Any
|
||||
class TriggerProviderErrorResponse(ResponseModel):
|
||||
error: str
|
||||
|
||||
|
||||
register_schema_models(
|
||||
@ -96,18 +123,24 @@ register_schema_models(
|
||||
)
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
BinaryFileResponse,
|
||||
RedirectResponse,
|
||||
SimpleResultResponse,
|
||||
TriggerOAuthAuthorizeResponse,
|
||||
TriggerOAuthClientResponse,
|
||||
TriggerProviderOpaqueResponse,
|
||||
TriggerProviderApiEntity,
|
||||
TriggerProviderErrorResponse,
|
||||
TriggerProviderListResponse,
|
||||
TriggerProviderSubscriptionListResponse,
|
||||
TriggerSubscriptionBuilderCreateResponse,
|
||||
TriggerSubscriptionBuilderLogsResponse,
|
||||
SubscriptionBuilderApiEntity,
|
||||
TriggerVerificationResponse,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/trigger-provider/<path:provider>/icon")
|
||||
class TriggerProviderIconApi(Resource):
|
||||
@console_ns.response(200, "Success", console_ns.models[BinaryFileResponse.__name__])
|
||||
# response-contract:ignore binary trigger provider icon
|
||||
@console_ns.response(200, "Trigger provider icon")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -118,31 +151,45 @@ class TriggerProviderIconApi(Resource):
|
||||
|
||||
@console_ns.route("/workspaces/current/triggers")
|
||||
class TriggerProviderListApi(Resource):
|
||||
@console_ns.response(200, "Success", console_ns.models[TriggerProviderOpaqueResponse.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Trigger providers retrieved successfully",
|
||||
console_ns.models[TriggerProviderListResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str):
|
||||
"""List all trigger providers for the current tenant"""
|
||||
return jsonable_encoder(TriggerProviderService.list_trigger_providers(tenant_id))
|
||||
return dump_response(TriggerProviderListResponse, TriggerProviderService.list_trigger_providers(tenant_id))
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/trigger-provider/<path:provider>/info")
|
||||
class TriggerProviderInfoApi(Resource):
|
||||
@console_ns.response(200, "Success", console_ns.models[TriggerProviderOpaqueResponse.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Trigger provider retrieved successfully",
|
||||
console_ns.models[TriggerProviderApiEntity.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, provider: str):
|
||||
"""Get info for a trigger provider"""
|
||||
return jsonable_encoder(TriggerProviderService.get_trigger_provider(tenant_id, TriggerProviderID(provider)))
|
||||
provider_entity = TriggerProviderService.get_trigger_provider(tenant_id, TriggerProviderID(provider))
|
||||
return provider_entity.model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/trigger-provider/<path:provider>/subscriptions/list")
|
||||
class TriggerSubscriptionListApi(Resource):
|
||||
@console_ns.response(200, "Success", console_ns.models[TriggerProviderOpaqueResponse.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Trigger subscriptions retrieved successfully",
|
||||
console_ns.models[TriggerProviderSubscriptionListResponse.__name__],
|
||||
)
|
||||
@console_ns.response(404, "Trigger provider not found", console_ns.models[TriggerProviderErrorResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@ -152,16 +199,18 @@ class TriggerSubscriptionListApi(Resource):
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, user: Account, provider: str):
|
||||
"""List all trigger subscriptions for the current tenant's provider"""
|
||||
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
return dump_response(
|
||||
TriggerProviderSubscriptionListResponse,
|
||||
TriggerProviderService.list_trigger_provider_subscriptions(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=TriggerProviderID(provider),
|
||||
user=user,
|
||||
)
|
||||
),
|
||||
)
|
||||
except ValueError as e:
|
||||
return jsonable_encoder({"error": str(e)}), 404
|
||||
return TriggerProviderErrorResponse(error=str(e)).model_dump(mode="json"), 404
|
||||
except Exception as e:
|
||||
logger.exception("Error listing trigger providers", exc_info=e)
|
||||
raise
|
||||
@ -172,7 +221,11 @@ class TriggerSubscriptionListApi(Resource):
|
||||
)
|
||||
class TriggerSubscriptionBuilderCreateApi(Resource):
|
||||
@console_ns.expect(console_ns.models[TriggerSubscriptionBuilderCreatePayload.__name__])
|
||||
@console_ns.response(200, "Success", console_ns.models[TriggerProviderOpaqueResponse.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Trigger subscription builder created successfully",
|
||||
console_ns.models[TriggerSubscriptionBuilderCreateResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@ -182,6 +235,7 @@ class TriggerSubscriptionBuilderCreateApi(Resource):
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user: Account, provider: str):
|
||||
"""Add a new subscription instance for a trigger provider"""
|
||||
|
||||
payload = TriggerSubscriptionBuilderCreatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
@ -192,7 +246,9 @@ class TriggerSubscriptionBuilderCreateApi(Resource):
|
||||
provider_id=TriggerProviderID(provider),
|
||||
credential_type=credential_type,
|
||||
)
|
||||
return jsonable_encoder({"subscription_builder": subscription_builder})
|
||||
return TriggerSubscriptionBuilderCreateResponse(subscription_builder=subscription_builder).model_dump(
|
||||
mode="json"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Error adding provider credential", exc_info=e)
|
||||
raise
|
||||
@ -202,7 +258,11 @@ class TriggerSubscriptionBuilderCreateApi(Resource):
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/<path:subscription_builder_id>",
|
||||
)
|
||||
class TriggerSubscriptionBuilderGetApi(Resource):
|
||||
@console_ns.response(200, "Success", console_ns.models[TriggerProviderOpaqueResponse.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Trigger subscription builder retrieved successfully",
|
||||
console_ns.models[SubscriptionBuilderApiEntity.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@ -210,9 +270,8 @@ class TriggerSubscriptionBuilderGetApi(Resource):
|
||||
@account_initialization_required
|
||||
def get(self, provider: str, subscription_builder_id: str):
|
||||
"""Get a subscription instance for a trigger provider"""
|
||||
return jsonable_encoder(
|
||||
TriggerSubscriptionBuilderService.get_subscription_builder_by_id(subscription_builder_id)
|
||||
)
|
||||
subscription_builder = TriggerSubscriptionBuilderService.get_subscription_builder_by_id(subscription_builder_id)
|
||||
return subscription_builder.model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
@ -220,7 +279,11 @@ class TriggerSubscriptionBuilderGetApi(Resource):
|
||||
)
|
||||
class TriggerSubscriptionBuilderVerifyApi(Resource):
|
||||
@console_ns.expect(console_ns.models[TriggerSubscriptionBuilderVerifyPayload.__name__])
|
||||
@console_ns.response(200, "Success", console_ns.models[TriggerProviderOpaqueResponse.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Trigger subscription builder verified successfully",
|
||||
console_ns.models[TriggerVerificationResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@ -230,11 +293,12 @@ class TriggerSubscriptionBuilderVerifyApi(Resource):
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user: Account, provider: str, subscription_builder_id: str):
|
||||
"""Verify and update a subscription instance for a trigger provider"""
|
||||
|
||||
payload = TriggerSubscriptionBuilderVerifyPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
# Use atomic update_and_verify to prevent race conditions
|
||||
return TriggerSubscriptionBuilderService.update_and_verify_builder(
|
||||
result = TriggerSubscriptionBuilderService.update_and_verify_builder(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
provider_id=TriggerProviderID(provider),
|
||||
@ -243,6 +307,7 @@ class TriggerSubscriptionBuilderVerifyApi(Resource):
|
||||
credentials=payload.credentials,
|
||||
),
|
||||
)
|
||||
return dump_response(TriggerVerificationResponse, result)
|
||||
except Exception as e:
|
||||
logger.exception("Error verifying provider credential", exc_info=e)
|
||||
raise ValueError(str(e)) from e
|
||||
@ -253,7 +318,11 @@ class TriggerSubscriptionBuilderVerifyApi(Resource):
|
||||
)
|
||||
class TriggerSubscriptionBuilderUpdateApi(Resource):
|
||||
@console_ns.expect(console_ns.models[TriggerSubscriptionBuilderUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Success", console_ns.models[TriggerProviderOpaqueResponse.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Trigger subscription builder updated successfully",
|
||||
console_ns.models[SubscriptionBuilderApiEntity.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@ -262,21 +331,20 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, provider: str, subscription_builder_id: str):
|
||||
"""Update a subscription instance for a trigger provider"""
|
||||
|
||||
payload = TriggerSubscriptionBuilderUpdatePayload.model_validate(console_ns.payload or {})
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=TriggerProviderID(provider),
|
||||
subscription_builder_id=subscription_builder_id,
|
||||
subscription_builder_updater=SubscriptionBuilderUpdater(
|
||||
name=payload.name,
|
||||
parameters=payload.parameters,
|
||||
properties=payload.properties,
|
||||
credentials=payload.credentials,
|
||||
),
|
||||
)
|
||||
)
|
||||
return TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=TriggerProviderID(provider),
|
||||
subscription_builder_id=subscription_builder_id,
|
||||
subscription_builder_updater=SubscriptionBuilderUpdater(
|
||||
name=payload.name,
|
||||
parameters=payload.parameters,
|
||||
properties=payload.properties,
|
||||
credentials=payload.credentials,
|
||||
),
|
||||
).model_dump(mode="json")
|
||||
except Exception as e:
|
||||
logger.exception("Error updating provider credential", exc_info=e)
|
||||
raise
|
||||
@ -286,7 +354,11 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/logs/<path:subscription_builder_id>",
|
||||
)
|
||||
class TriggerSubscriptionBuilderLogsApi(Resource):
|
||||
@console_ns.response(200, "Success", console_ns.models[TriggerProviderOpaqueResponse.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Trigger subscription builder logs retrieved successfully",
|
||||
console_ns.models[TriggerSubscriptionBuilderLogsResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@ -294,9 +366,10 @@ class TriggerSubscriptionBuilderLogsApi(Resource):
|
||||
@account_initialization_required
|
||||
def get(self, provider: str, subscription_builder_id: str):
|
||||
"""Get the request logs for a subscription instance for a trigger provider"""
|
||||
|
||||
try:
|
||||
logs = TriggerSubscriptionBuilderService.list_logs(subscription_builder_id)
|
||||
return jsonable_encoder({"logs": [log.model_dump(mode="json") for log in logs]})
|
||||
return dump_response(TriggerSubscriptionBuilderLogsResponse, {"logs": logs})
|
||||
except Exception as e:
|
||||
logger.exception("Error getting request logs for subscription builder", exc_info=e)
|
||||
raise
|
||||
@ -307,7 +380,9 @@ class TriggerSubscriptionBuilderLogsApi(Resource):
|
||||
)
|
||||
class TriggerSubscriptionBuilderBuildApi(Resource):
|
||||
@console_ns.expect(console_ns.models[TriggerSubscriptionBuilderUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Success", console_ns.models[TriggerProviderOpaqueResponse.__name__])
|
||||
@console_ns.response(
|
||||
200, "Trigger subscription builder built successfully", console_ns.models[SimpleResultResponse.__name__]
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@ -331,7 +406,7 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
|
||||
properties=payload.properties,
|
||||
),
|
||||
)
|
||||
return 200
|
||||
return SimpleResultResponse(result="success").model_dump(mode="json")
|
||||
except Exception as e:
|
||||
logger.exception("Error building provider credential", exc_info=e)
|
||||
raise ValueError(str(e)) from e
|
||||
@ -342,7 +417,9 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
|
||||
)
|
||||
class TriggerSubscriptionUpdateApi(Resource):
|
||||
@console_ns.expect(console_ns.models[TriggerSubscriptionBuilderUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Success", console_ns.models[TriggerProviderOpaqueResponse.__name__])
|
||||
@console_ns.response(
|
||||
200, "Trigger subscription updated successfully", console_ns.models[SimpleResultResponse.__name__]
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@ -351,6 +428,7 @@ class TriggerSubscriptionUpdateApi(Resource):
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, subscription_id: str):
|
||||
"""Update a subscription instance"""
|
||||
|
||||
request = TriggerSubscriptionBuilderUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
subscription = TriggerProviderService.get_subscription_by_id(
|
||||
@ -376,7 +454,7 @@ class TriggerSubscriptionUpdateApi(Resource):
|
||||
name=request.name,
|
||||
properties=request.properties,
|
||||
)
|
||||
return 200
|
||||
return SimpleResultResponse(result="success").model_dump(mode="json")
|
||||
|
||||
# For the rest cases(API_KEY, OAUTH2)
|
||||
# we need to call third party provider(e.g. GitHub) to rebuild the subscription
|
||||
@ -388,7 +466,7 @@ class TriggerSubscriptionUpdateApi(Resource):
|
||||
credentials=request.credentials or subscription.credentials,
|
||||
parameters=request.parameters or subscription.parameters,
|
||||
)
|
||||
return 200
|
||||
return SimpleResultResponse(result="success").model_dump(mode="json")
|
||||
except ValueError as e:
|
||||
raise BadRequest(str(e))
|
||||
except Exception as e:
|
||||
@ -409,6 +487,7 @@ class TriggerSubscriptionDeleteApi(Resource):
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, subscription_id: str):
|
||||
"""Delete a subscription instance"""
|
||||
|
||||
try:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
# Delete trigger provider subscription
|
||||
@ -423,7 +502,7 @@ class TriggerSubscriptionDeleteApi(Resource):
|
||||
tenant_id=tenant_id,
|
||||
subscription_id=subscription_id,
|
||||
)
|
||||
return {"result": "success"}
|
||||
return SimpleResultResponse(result="success").model_dump(mode="json")
|
||||
except ValueError as e:
|
||||
raise BadRequest(str(e))
|
||||
except Exception as e:
|
||||
@ -433,9 +512,10 @@ class TriggerSubscriptionDeleteApi(Resource):
|
||||
|
||||
@console_ns.route("/workspaces/current/trigger-provider/<path:provider>/subscriptions/oauth/authorize")
|
||||
class TriggerOAuthAuthorizeApi(Resource):
|
||||
# response-contract:ignore cookie-bearing Flask response
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Authorization URL retrieved successfully",
|
||||
"Trigger OAuth authorization URL generated successfully",
|
||||
console_ns.models[TriggerOAuthAuthorizeResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@ -445,10 +525,12 @@ class TriggerOAuthAuthorizeApi(Resource):
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, user: Account, provider: str):
|
||||
"""Initiate OAuth authorization flow for a trigger provider"""
|
||||
|
||||
try:
|
||||
provider_id = TriggerProviderID(provider)
|
||||
plugin_id = provider_id.plugin_id
|
||||
provider_name = provider_id.provider_name
|
||||
tenant_id = tenant_id
|
||||
|
||||
# Get OAuth client configuration
|
||||
oauth_client_params = TriggerProviderService.get_oauth_client(
|
||||
@ -492,15 +574,12 @@ class TriggerOAuthAuthorizeApi(Resource):
|
||||
system_credentials=oauth_client_params,
|
||||
)
|
||||
|
||||
# Create response with cookie
|
||||
response = make_response(
|
||||
jsonable_encoder(
|
||||
{
|
||||
"authorization_url": authorization_url_response.authorization_url,
|
||||
"subscription_builder_id": subscription_builder.id,
|
||||
"subscription_builder": subscription_builder,
|
||||
}
|
||||
)
|
||||
TriggerOAuthAuthorizeResponse(
|
||||
authorization_url=authorization_url_response.authorization_url,
|
||||
subscription_builder_id=subscription_builder.id,
|
||||
subscription_builder=subscription_builder,
|
||||
).model_dump(mode="json")
|
||||
)
|
||||
response.set_cookie(
|
||||
"context_id",
|
||||
@ -519,11 +598,8 @@ class TriggerOAuthAuthorizeApi(Resource):
|
||||
|
||||
@console_ns.route("/oauth/plugin/<path:provider>/trigger/callback")
|
||||
class TriggerOAuthCallbackApi(Resource):
|
||||
@console_ns.response(
|
||||
302,
|
||||
"Redirect to console OAuth callback page",
|
||||
console_ns.models[RedirectResponse.__name__],
|
||||
)
|
||||
# response-contract:ignore redirect response
|
||||
@console_ns.response(302, "Redirect to OAuth callback page")
|
||||
@setup_required
|
||||
def get(self, provider: str):
|
||||
"""Handle OAuth callback for trigger provider"""
|
||||
@ -589,7 +665,11 @@ class TriggerOAuthCallbackApi(Resource):
|
||||
|
||||
@console_ns.route("/workspaces/current/trigger-provider/<path:provider>/oauth/client")
|
||||
class TriggerOAuthClientManageApi(Resource):
|
||||
@console_ns.response(200, "Success", console_ns.models[TriggerOAuthClientResponse.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Trigger OAuth client retrieved successfully",
|
||||
console_ns.models[TriggerOAuthClientResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@ -598,6 +678,7 @@ class TriggerOAuthClientManageApi(Resource):
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, provider: str):
|
||||
"""Get OAuth client configuration for a provider"""
|
||||
|
||||
try:
|
||||
provider_id = TriggerProviderID(provider)
|
||||
|
||||
@ -618,24 +699,24 @@ class TriggerOAuthClientManageApi(Resource):
|
||||
)
|
||||
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/trigger/callback"
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"configured": bool(custom_params or system_client_exists),
|
||||
"system_configured": system_client_exists,
|
||||
"custom_configured": bool(custom_params),
|
||||
"oauth_client_schema": provider_controller.get_oauth_client_schema(),
|
||||
"custom_enabled": is_custom_enabled,
|
||||
"redirect_uri": redirect_uri,
|
||||
"params": custom_params or {},
|
||||
}
|
||||
)
|
||||
return TriggerOAuthClientResponse(
|
||||
configured=bool(custom_params or system_client_exists),
|
||||
system_configured=system_client_exists,
|
||||
custom_configured=bool(custom_params),
|
||||
oauth_client_schema=provider_controller.get_oauth_client_schema(),
|
||||
custom_enabled=is_custom_enabled,
|
||||
redirect_uri=redirect_uri,
|
||||
params=dict(custom_params),
|
||||
).model_dump(mode="json")
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error getting OAuth client", exc_info=e)
|
||||
raise
|
||||
|
||||
@console_ns.expect(console_ns.models[TriggerOAuthClientPayload.__name__])
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
@console_ns.response(
|
||||
200, "Trigger OAuth client saved successfully", console_ns.models[SimpleResultResponse.__name__]
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@ -644,16 +725,18 @@ class TriggerOAuthClientManageApi(Resource):
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, provider: str):
|
||||
"""Configure custom OAuth client for a provider"""
|
||||
|
||||
payload = TriggerOAuthClientPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
provider_id = TriggerProviderID(provider)
|
||||
return TriggerProviderService.save_custom_oauth_client_params(
|
||||
result = TriggerProviderService.save_custom_oauth_client_params(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
client_params=payload.client_params,
|
||||
enabled=payload.enabled,
|
||||
)
|
||||
return dump_response(SimpleResultResponse, result)
|
||||
|
||||
except ValueError as e:
|
||||
raise BadRequest(str(e))
|
||||
@ -661,22 +744,26 @@ class TriggerOAuthClientManageApi(Resource):
|
||||
logger.exception("Error configuring OAuth client", exc_info=e)
|
||||
raise
|
||||
|
||||
@console_ns.response(
|
||||
200, "Trigger OAuth client deleted successfully", console_ns.models[SimpleResultResponse.__name__]
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_PREFERENCES, resource_required=False)
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
@with_current_tenant_id
|
||||
def delete(self, tenant_id: str, provider: str):
|
||||
"""Remove custom OAuth client configuration"""
|
||||
|
||||
try:
|
||||
provider_id = TriggerProviderID(provider)
|
||||
|
||||
return TriggerProviderService.delete_custom_oauth_client_params(
|
||||
result = TriggerProviderService.delete_custom_oauth_client_params(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
return dump_response(SimpleResultResponse, result)
|
||||
except ValueError as e:
|
||||
raise BadRequest(str(e))
|
||||
except Exception as e:
|
||||
@ -689,7 +776,11 @@ class TriggerOAuthClientManageApi(Resource):
|
||||
)
|
||||
class TriggerSubscriptionVerifyApi(Resource):
|
||||
@console_ns.expect(console_ns.models[TriggerSubscriptionBuilderVerifyPayload.__name__])
|
||||
@console_ns.response(200, "Success", console_ns.models[TriggerProviderOpaqueResponse.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Trigger subscription verified successfully",
|
||||
console_ns.models[TriggerVerificationResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@ -699,6 +790,7 @@ class TriggerSubscriptionVerifyApi(Resource):
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user: Account, provider: str, subscription_id: str):
|
||||
"""Verify credentials for an existing subscription (edit mode only)"""
|
||||
|
||||
verify_request = TriggerSubscriptionBuilderVerifyPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
@ -709,7 +801,7 @@ class TriggerSubscriptionVerifyApi(Resource):
|
||||
subscription_id=subscription_id,
|
||||
credentials=verify_request.credentials,
|
||||
)
|
||||
return result
|
||||
return dump_response(TriggerVerificationResponse, result)
|
||||
except ValueError as e:
|
||||
logger.warning("Credential verification failed", exc_info=e)
|
||||
raise BadRequest(str(e)) from e
|
||||
|
||||
@ -74,8 +74,6 @@ class ToolProviderApiEntity(BaseModel):
|
||||
for parameter in tool.get("parameters"):
|
||||
if parameter.get("type") == ToolParameter.ToolParameterType.SYSTEM_FILES:
|
||||
parameter["type"] = "files"
|
||||
if parameter.get("input_schema") is None:
|
||||
parameter.pop("input_schema", None)
|
||||
# -------------
|
||||
optional_fields = self.optional_field("server_url", self.server_url)
|
||||
match self.type:
|
||||
|
||||
@ -16,30 +16,16 @@ from yarl import URL
|
||||
|
||||
import contexts
|
||||
from configs import dify_config
|
||||
from core.entities import PluginCredentialType
|
||||
from core.helper.provider_cache import ToolProviderCredentialsCache
|
||||
from core.plugin.impl.tool import PluginToolManager
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.mcp_tool.provider import MCPToolProviderController
|
||||
from core.tools.mcp_tool.tool import MCPTool
|
||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||
from core.tools.plugin_tool.tool import PluginTool
|
||||
from core.tools.utils.uuid_utils import is_valid_uuid
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
from extensions.ext_database import db
|
||||
from graphon.runtime import VariablePool
|
||||
from models.provider_ids import ToolProviderID
|
||||
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.entities import PluginCredentialType
|
||||
from core.helper.module_import_helper import load_single_subclass_from_source
|
||||
from core.helper.position_helper import is_filtered
|
||||
from core.helper.provider_cache import ToolProviderCredentialsCache
|
||||
from core.plugin.impl.tool import PluginToolManager
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
@ -56,12 +42,21 @@ from core.tools.entities.tool_entities import (
|
||||
emoji_icon_adapter,
|
||||
)
|
||||
from core.tools.errors import ToolProviderNotFoundError
|
||||
from core.tools.mcp_tool.provider import MCPToolProviderController
|
||||
from core.tools.mcp_tool.tool import MCPTool
|
||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||
from core.tools.plugin_tool.tool import PluginTool
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
||||
from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
|
||||
from core.tools.utils.uuid_utils import is_valid_uuid
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from extensions.ext_database import db
|
||||
from graphon.runtime import VariablePool
|
||||
from models.provider_ids import ToolProviderID
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
|
||||
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -921,23 +916,19 @@ class ToolManager:
|
||||
|
||||
# add tool labels
|
||||
labels = ToolLabelManager.get_tool_labels(controller)
|
||||
schema_type = provider_obj.schema_type
|
||||
|
||||
return cast(
|
||||
dict,
|
||||
jsonable_encoder(
|
||||
{
|
||||
"schema_type": provider_obj.schema_type,
|
||||
"schema": provider_obj.schema,
|
||||
"tools": provider_obj.tools,
|
||||
"icon": icon,
|
||||
"description": provider_obj.description,
|
||||
"credentials": masked_credentials,
|
||||
"privacy_policy": provider_obj.privacy_policy,
|
||||
"custom_disclaimer": provider_obj.custom_disclaimer,
|
||||
"labels": labels,
|
||||
}
|
||||
),
|
||||
)
|
||||
return {
|
||||
"schema_type": getattr(schema_type, "value", schema_type),
|
||||
"schema": provider_obj.schema,
|
||||
"tools": [tool.model_dump(mode="json") for tool in provider_obj.tools],
|
||||
"icon": icon,
|
||||
"description": provider_obj.description,
|
||||
"credentials": masked_credentials,
|
||||
"privacy_policy": provider_obj.privacy_policy,
|
||||
"custom_disclaimer": provider_obj.custom_disclaimer,
|
||||
"labels": labels,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def generate_builtin_tool_icon_url(cls, provider_id: str) -> str:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,8 +1,9 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, TypedDict, cast
|
||||
from typing import Any, Literal, TypedDict, cast
|
||||
|
||||
from httpx import get
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
@ -22,7 +23,6 @@ from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.encryption import create_tool_provider_encrypter
|
||||
from core.tools.utils.parser import ApiBasedToolSchemaParser
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from models.tools import ApiToolProvider
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
@ -36,6 +36,27 @@ class ApiSchemaParseResult(TypedDict):
|
||||
warning: dict[str, str]
|
||||
|
||||
|
||||
class ApiToolPreviewResult(TypedDict, total=False):
|
||||
result: str
|
||||
error: str
|
||||
|
||||
|
||||
class RemoteSchemaResult(TypedDict):
|
||||
schema: str
|
||||
|
||||
|
||||
class SimpleSuccessResult(TypedDict):
|
||||
result: Literal["success"]
|
||||
|
||||
|
||||
def _dump_api_tool_bundles(tool_bundles: list[ApiToolBundle]) -> list[dict[str, Any]]:
|
||||
return cast(list[dict[str, Any]], TypeAdapter(list[ApiToolBundle]).dump_python(tool_bundles, mode="json"))
|
||||
|
||||
|
||||
def _dump_provider_configs(configs: list[ProviderConfig]) -> list[dict[str, Any]]:
|
||||
return cast(list[dict[str, Any]], TypeAdapter(list[ProviderConfig]).dump_python(configs, mode="json"))
|
||||
|
||||
|
||||
class ApiToolManageService:
|
||||
@staticmethod
|
||||
def parser_api_schema(schema: str) -> ApiSchemaParseResult:
|
||||
@ -80,14 +101,12 @@ class ApiToolManageService:
|
||||
|
||||
return cast(
|
||||
ApiSchemaParseResult,
|
||||
jsonable_encoder(
|
||||
{
|
||||
"schema_type": schema_type,
|
||||
"parameters_schema": tool_bundles,
|
||||
"credentials_schema": credentials_schema,
|
||||
"warning": warnings,
|
||||
}
|
||||
),
|
||||
{
|
||||
"schema_type": schema_type.value,
|
||||
"parameters_schema": _dump_api_tool_bundles(tool_bundles),
|
||||
"credentials_schema": _dump_provider_configs(credentials_schema),
|
||||
"warning": warnings,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(f"invalid schema: {str(e)}")
|
||||
@ -118,7 +137,7 @@ class ApiToolManageService:
|
||||
privacy_policy: str,
|
||||
custom_disclaimer: str,
|
||||
labels: list[str],
|
||||
) -> dict[str, Any]:
|
||||
) -> SimpleSuccessResult:
|
||||
"""
|
||||
Create a new API tool provider.
|
||||
|
||||
@ -169,7 +188,7 @@ class ApiToolManageService:
|
||||
schema=schema,
|
||||
description=extra_info.get("description", ""),
|
||||
schema_type_str=schema_type,
|
||||
tools_str=json.dumps(jsonable_encoder(tool_bundles)),
|
||||
tools_str=json.dumps(_dump_api_tool_bundles(tool_bundles)),
|
||||
credentials_str="{}",
|
||||
privacy_policy=privacy_policy,
|
||||
custom_disclaimer=custom_disclaimer,
|
||||
@ -201,7 +220,7 @@ class ApiToolManageService:
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
def get_api_tool_provider_remote_schema(user_id: str, tenant_id: str, url: str):
|
||||
def get_api_tool_provider_remote_schema(user_id: str, tenant_id: str, url: str) -> RemoteSchemaResult:
|
||||
"""
|
||||
get api tool provider remote schema
|
||||
"""
|
||||
@ -276,7 +295,7 @@ class ApiToolManageService:
|
||||
privacy_policy: str | None,
|
||||
custom_disclaimer: str,
|
||||
labels: list[str],
|
||||
) -> dict[str, Any]:
|
||||
) -> SimpleSuccessResult:
|
||||
"""
|
||||
Update an existing API tool provider.
|
||||
|
||||
@ -322,7 +341,7 @@ class ApiToolManageService:
|
||||
provider.schema = schema
|
||||
provider.description = extra_info.get("description", "")
|
||||
provider.schema_type_str = schema_type
|
||||
provider.tools_str = json.dumps(jsonable_encoder(tool_bundles))
|
||||
provider.tools_str = json.dumps(_dump_api_tool_bundles(tool_bundles))
|
||||
provider.privacy_policy = privacy_policy
|
||||
provider.custom_disclaimer = custom_disclaimer
|
||||
|
||||
@ -365,7 +384,7 @@ class ApiToolManageService:
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
def delete_api_tool_provider(user_id: str, tenant_id: str, provider_name: str):
|
||||
def delete_api_tool_provider(user_id: str, tenant_id: str, provider_name: str) -> SimpleSuccessResult:
|
||||
"""
|
||||
Delete an API tool provider.
|
||||
|
||||
@ -413,9 +432,9 @@ class ApiToolManageService:
|
||||
tool_name: str,
|
||||
credentials: dict[str, Any],
|
||||
parameters: dict[str, Any],
|
||||
schema_type: ApiProviderSchemaType,
|
||||
schema_type: ApiProviderSchemaType | str,
|
||||
schema: str,
|
||||
) -> dict[str, Any]:
|
||||
) -> ApiToolPreviewResult:
|
||||
"""
|
||||
Test an API tool before adding the API tool provider.
|
||||
|
||||
@ -464,7 +483,7 @@ class ApiToolManageService:
|
||||
schema=schema,
|
||||
description="",
|
||||
schema_type_str=ApiProviderSchemaType.OPENAPI,
|
||||
tools_str=json.dumps(jsonable_encoder(tool_bundles)),
|
||||
tools_str=json.dumps(_dump_api_tool_bundles(tool_bundles)),
|
||||
credentials_str=json.dumps(credentials),
|
||||
)
|
||||
|
||||
|
||||
@ -44,6 +44,7 @@ from controllers.console.workspace.tool_providers import (
|
||||
ToolWorkflowProviderUpdateApi,
|
||||
is_valid_url,
|
||||
)
|
||||
from core.tools.entities.api_entities import ToolProviderApiEntity as CoreToolProviderApiEntity
|
||||
from models.account import Account, TenantAccountRole
|
||||
from services.tools.mcp_tools_manage_service import ReconnectResult
|
||||
from tests.test_containers_integration_tests.controllers.console.helpers import (
|
||||
@ -60,6 +61,148 @@ def empty_list() -> list[object]:
|
||||
return []
|
||||
|
||||
|
||||
def emoji_icon() -> dict[str, str]:
|
||||
return {"content": "tool", "background": "#252525"}
|
||||
|
||||
|
||||
def i18n(text: str) -> dict[str, str]:
|
||||
return {"en_US": text}
|
||||
|
||||
|
||||
def tool_payload(name: str = "ping") -> dict[str, object]:
|
||||
return {
|
||||
"author": "langgenius",
|
||||
"name": name,
|
||||
"label": i18n(name.title()),
|
||||
"description": i18n(f"{name} description"),
|
||||
"parameters": [],
|
||||
"labels": ["utilities"],
|
||||
"output_schema": {},
|
||||
}
|
||||
|
||||
|
||||
def provider_payload(
|
||||
*,
|
||||
provider_id: str = "provider-1",
|
||||
name: str = "provider",
|
||||
provider_type: str = "builtin",
|
||||
tools: list[dict[str, object]] | None = None,
|
||||
) -> dict[str, object]:
|
||||
return {
|
||||
"id": provider_id,
|
||||
"author": "langgenius",
|
||||
"name": name,
|
||||
"description": i18n(f"{name} description"),
|
||||
"icon": emoji_icon(),
|
||||
"icon_dark": emoji_icon(),
|
||||
"label": i18n(name.title()),
|
||||
"type": provider_type,
|
||||
"masked_credentials": {"api_key": "[__HIDDEN__]"},
|
||||
"original_credentials": {"api_key": "sk-secret"},
|
||||
"is_team_authorization": False,
|
||||
"allow_delete": True,
|
||||
"plugin_id": "langgenius/provider",
|
||||
"plugin_unique_identifier": "langgenius/provider:1.0.0",
|
||||
"tools": tools or [tool_payload()],
|
||||
"labels": ["utilities"],
|
||||
"server_url": "",
|
||||
"updated_at": 1710000000,
|
||||
"server_identifier": "",
|
||||
"masked_headers": None,
|
||||
"original_headers": None,
|
||||
"authentication": None,
|
||||
"is_dynamic_registration": True,
|
||||
"configuration": None,
|
||||
"identity_mode": "off",
|
||||
"workflow_app_id": None,
|
||||
}
|
||||
|
||||
|
||||
def provider_entity(
|
||||
*,
|
||||
provider_id: str = "provider-1",
|
||||
name: str = "provider",
|
||||
provider_type: str = "builtin",
|
||||
tools: list[dict[str, object]] | None = None,
|
||||
) -> CoreToolProviderApiEntity:
|
||||
return CoreToolProviderApiEntity.model_validate(
|
||||
provider_payload(provider_id=provider_id, name=name, provider_type=provider_type, tools=tools)
|
||||
)
|
||||
|
||||
|
||||
def credential_payload() -> dict[str, object]:
|
||||
return {
|
||||
"id": "credential-1",
|
||||
"name": "Default credential",
|
||||
"provider": "provider",
|
||||
"credential_type": "api-key",
|
||||
"is_default": True,
|
||||
"credentials": {"api_key": "masked"},
|
||||
"visibility": "all_team_members",
|
||||
"created_by": "user-1",
|
||||
"partial_member_list": [],
|
||||
"from_other_member": False,
|
||||
}
|
||||
|
||||
|
||||
def provider_config_payload() -> dict[str, object]:
|
||||
return {"type": "secret-input", "name": "api_key", "required": True}
|
||||
|
||||
|
||||
def api_tool_bundle_payload() -> dict[str, object]:
|
||||
return {
|
||||
"server_url": "https://api.example.com",
|
||||
"method": "get",
|
||||
"summary": "Ping",
|
||||
"operation_id": "ping",
|
||||
"parameters": [],
|
||||
"author": "langgenius",
|
||||
"icon": None,
|
||||
"openapi": {"operationId": "ping"},
|
||||
"output_schema": {},
|
||||
}
|
||||
|
||||
|
||||
def api_provider_detail_payload() -> dict[str, object]:
|
||||
return {
|
||||
"schema_type": "openapi",
|
||||
"schema": "{}",
|
||||
"tools": [api_tool_bundle_payload()],
|
||||
"icon": emoji_icon(),
|
||||
"description": "API provider",
|
||||
"credentials": {},
|
||||
"privacy_policy": "",
|
||||
"custom_disclaimer": "",
|
||||
"labels": ["utilities"],
|
||||
}
|
||||
|
||||
|
||||
def credential_info_payload() -> dict[str, object]:
|
||||
return {
|
||||
"supported_credential_types": ["api-key", "oauth2"],
|
||||
"is_oauth_custom_client_enabled": False,
|
||||
"credentials": [credential_payload()],
|
||||
}
|
||||
|
||||
|
||||
def oauth_client_schema_payload() -> dict[str, object]:
|
||||
return {
|
||||
"schema": [provider_config_payload()],
|
||||
"is_oauth_custom_client_enabled": False,
|
||||
"is_system_oauth_params_exists": True,
|
||||
"client_params": {"client_id": "masked"},
|
||||
"redirect_uri": "https://console.example.com/oauth/callback",
|
||||
}
|
||||
|
||||
|
||||
def tool_label_payload() -> dict[str, object]:
|
||||
return {
|
||||
"name": "utilities",
|
||||
"label": i18n("Utilities"),
|
||||
"icon": "wrench",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def _mock_cache() -> None:
|
||||
return
|
||||
@ -127,7 +270,7 @@ def test_create_mcp_provider_populates_tools(
|
||||
with (
|
||||
patch(
|
||||
"services.tools.tools_transform_service.ToolTransformService.mcp_provider_to_user_provider",
|
||||
return_value={"id": "provider-1", "tools": [{"name": "ping"}]},
|
||||
return_value=provider_entity(provider_id="provider-1", provider_type="mcp", tools=[tool_payload()]),
|
||||
autospec=True,
|
||||
),
|
||||
):
|
||||
@ -138,13 +281,15 @@ def test_create_mcp_provider_populates_tools(
|
||||
content_type="application/json",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert resp.status_code == 200
|
||||
body = resp.get_json()
|
||||
assert body.get("id") == "provider-1"
|
||||
# 若 transform 后包含 tools 字段,确保非空
|
||||
assert isinstance(body.get("tools"), list)
|
||||
assert body["tools"]
|
||||
# Assert
|
||||
assert resp.status_code == 200
|
||||
body = resp.get_json()
|
||||
assert body.get("id") == "provider-1"
|
||||
assert body["team_credentials"] == {"api_key": "[__HIDDEN__]"}
|
||||
assert "masked_credentials" not in body
|
||||
assert "original_credentials" not in body
|
||||
assert isinstance(body.get("tools"), list)
|
||||
assert body["tools"]
|
||||
|
||||
|
||||
class TestUtils:
|
||||
@ -170,10 +315,16 @@ class TestToolProviderListApi:
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.ToolCommonService.list_tool_providers",
|
||||
return_value=["p1"],
|
||||
return_value=[provider_entity(provider_id="p1").to_dict()],
|
||||
),
|
||||
):
|
||||
assert method(api, "t1", make_account(id="u1")) == ["p1"]
|
||||
result = method(api, "t1", make_account(id="u1"))
|
||||
|
||||
assert result[0]["id"] == "p1"
|
||||
assert result[0]["team_credentials"] == {"api_key": "[__HIDDEN__]"}
|
||||
assert "masked_credentials" not in result[0]
|
||||
assert "original_credentials" not in result[0]
|
||||
assert result[0]["tools"][0]["name"] == "ping"
|
||||
|
||||
|
||||
class TestBuiltinProviderApis:
|
||||
@ -189,10 +340,10 @@ class TestBuiltinProviderApis:
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.list_builtin_tool_provider_tools",
|
||||
return_value=[{"a": 1}],
|
||||
return_value=[tool_payload()],
|
||||
),
|
||||
):
|
||||
assert method(api, "t1", "provider") == [{"a": 1}]
|
||||
assert method(api, "t1", "provider")[0]["name"] == "ping"
|
||||
|
||||
def test_info(self, app: Flask) -> None:
|
||||
api = ToolBuiltinProviderInfoApi()
|
||||
@ -202,10 +353,15 @@ class TestBuiltinProviderApis:
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_info",
|
||||
return_value={"x": 1},
|
||||
return_value=provider_entity(),
|
||||
),
|
||||
):
|
||||
assert method(api, "t1", "provider") == {"x": 1}
|
||||
result = method(api, "t1", "provider")
|
||||
|
||||
assert result["id"] == "provider-1"
|
||||
assert result["team_credentials"] == {"api_key": "[__HIDDEN__]"}
|
||||
assert "masked_credentials" not in result
|
||||
assert "original_credentials" not in result
|
||||
|
||||
def test_delete(self, app: Flask) -> None:
|
||||
api = ToolBuiltinProviderDeleteApi()
|
||||
@ -240,10 +396,10 @@ class TestBuiltinProviderApis:
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.add_builtin_tool_provider",
|
||||
return_value={"id": 1},
|
||||
return_value={"result": "success"},
|
||||
),
|
||||
):
|
||||
assert method(api, "t", make_account(), "provider")["id"] == 1
|
||||
assert method(api, "t", make_account(), "provider")["result"] == "success"
|
||||
|
||||
def test_update(self, app: Flask) -> None:
|
||||
api = ToolBuiltinProviderUpdateApi()
|
||||
@ -255,10 +411,10 @@ class TestBuiltinProviderApis:
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.update_builtin_tool_provider",
|
||||
return_value={"ok": True},
|
||||
return_value={"result": "success"},
|
||||
),
|
||||
):
|
||||
assert method(api, "t", make_account(), "provider")["ok"]
|
||||
assert method(api, "t", make_account(), "provider")["result"] == "success"
|
||||
|
||||
def test_get_credentials(self, app: Flask) -> None:
|
||||
api = ToolBuiltinProviderGetCredentialsApi()
|
||||
@ -268,10 +424,10 @@ class TestBuiltinProviderApis:
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_credentials",
|
||||
return_value={"k": "v"},
|
||||
return_value=[credential_payload()],
|
||||
),
|
||||
):
|
||||
assert method(api, "t", make_account(id="user-1"), "provider") == {"k": "v"}
|
||||
assert method(api, "t", make_account(id="user-1"), "provider")[0]["id"] == "credential-1"
|
||||
|
||||
def test_icon(self, app: Flask) -> None:
|
||||
api = ToolBuiltinProviderIconApi()
|
||||
@ -295,10 +451,10 @@ class TestBuiltinProviderApis:
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.list_builtin_provider_credentials_schema",
|
||||
return_value={"schema": {}},
|
||||
return_value=[provider_config_payload()],
|
||||
),
|
||||
):
|
||||
assert method(api, "t", "provider", "oauth2") == {"schema": {}}
|
||||
assert method(api, "t", "provider", "oauth2")[0]["name"] == "api_key"
|
||||
|
||||
def test_set_default_credential(self, app: Flask) -> None:
|
||||
api = ToolBuiltinProviderSetDefaultApi()
|
||||
@ -308,10 +464,10 @@ class TestBuiltinProviderApis:
|
||||
app.test_request_context("/", json={"id": "c1"}),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.set_default_provider",
|
||||
return_value={"ok": True},
|
||||
return_value={"result": "success"},
|
||||
),
|
||||
):
|
||||
assert method(api, "t", "provider")["ok"]
|
||||
assert method(api, "t", "provider")["result"] == "success"
|
||||
|
||||
def test_get_credential_info(self, app: Flask) -> None:
|
||||
api = ToolBuiltinProviderGetCredentialInfoApi()
|
||||
@ -321,10 +477,10 @@ class TestBuiltinProviderApis:
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_credential_info",
|
||||
return_value={"info": "x"},
|
||||
return_value=credential_info_payload(),
|
||||
),
|
||||
):
|
||||
assert method(api, "t", make_account(), "provider") == {"info": "x"}
|
||||
assert method(api, "t", make_account(), "provider")["credentials"][0]["id"] == "credential-1"
|
||||
|
||||
def test_get_oauth_client_schema(self, app: Flask) -> None:
|
||||
api = ToolBuiltinProviderGetOauthClientSchemaApi()
|
||||
@ -334,10 +490,10 @@ class TestBuiltinProviderApis:
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema",
|
||||
return_value={"schema": {}},
|
||||
return_value=oauth_client_schema_payload(),
|
||||
),
|
||||
):
|
||||
assert method(api, "t", "provider") == {"schema": {}}
|
||||
assert method(api, "t", "provider")["schema"][0]["name"] == "api_key"
|
||||
|
||||
|
||||
class TestApiProviderApis:
|
||||
@ -354,30 +510,34 @@ class TestApiProviderApis:
|
||||
"schema_type": "openapi",
|
||||
"schema": "{}",
|
||||
"provider": "p",
|
||||
"icon": empty_mapping(),
|
||||
"icon": emoji_icon(),
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.ApiToolManageService.create_api_tool_provider",
|
||||
return_value={"id": 1},
|
||||
),
|
||||
return_value={"result": "success"},
|
||||
) as create_api_tool_provider,
|
||||
):
|
||||
assert method(api, "t", make_account())["id"] == 1
|
||||
assert method(api, "t", make_account()) == {"result": "success"}
|
||||
|
||||
create_api_tool_provider.assert_called_once()
|
||||
assert create_api_tool_provider.call_args.args[3] == emoji_icon()
|
||||
|
||||
def test_remote_schema(self, app: Flask) -> None:
|
||||
api = ToolApiProviderGetRemoteSchemaApi()
|
||||
method = unwrap(api.get)
|
||||
openapi_schema = '{"openapi":"3.0.0","info":{"title":"Demo API","version":"1.0.0"},"paths":{}}'
|
||||
|
||||
with (
|
||||
app.test_request_context("/?url=http://x.com"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.ApiToolManageService.get_api_tool_provider_remote_schema",
|
||||
return_value={"schema": "x"},
|
||||
return_value={"schema": openapi_schema},
|
||||
),
|
||||
):
|
||||
assert method(api, "t", make_account())["schema"] == "x"
|
||||
assert method(api, "t", make_account()) == {"schema": openapi_schema}
|
||||
|
||||
def test_list_tools(self, app: Flask) -> None:
|
||||
api = ToolApiProviderListToolsApi()
|
||||
@ -387,10 +547,10 @@ class TestApiProviderApis:
|
||||
app.test_request_context("/?provider=p"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.ApiToolManageService.list_api_tool_provider_tools",
|
||||
return_value=[{"tool": 1}],
|
||||
return_value=[tool_payload("api_ping")],
|
||||
),
|
||||
):
|
||||
assert method(api, "t", make_account()) == [{"tool": 1}]
|
||||
assert method(api, "t", make_account())[0]["name"] == "api_ping"
|
||||
|
||||
def test_update(self, app: Flask) -> None:
|
||||
api = ToolApiProviderUpdateApi()
|
||||
@ -402,7 +562,7 @@ class TestApiProviderApis:
|
||||
"schema": "{}",
|
||||
"provider": "p",
|
||||
"original_provider": "o",
|
||||
"icon": empty_mapping(),
|
||||
"icon": emoji_icon(),
|
||||
"privacy_policy": "",
|
||||
"custom_disclaimer": "",
|
||||
}
|
||||
@ -411,10 +571,13 @@ class TestApiProviderApis:
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.ApiToolManageService.update_api_tool_provider",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
return_value={"result": "success"},
|
||||
) as update_api_tool_provider,
|
||||
):
|
||||
assert method(api, "t", make_account())["ok"]
|
||||
assert method(api, "t", make_account()) == {"result": "success"}
|
||||
|
||||
update_api_tool_provider.assert_called_once()
|
||||
assert update_api_tool_provider.call_args.args[4] == emoji_icon()
|
||||
|
||||
def test_delete(self, app: Flask) -> None:
|
||||
api = ToolApiProviderDeleteApi()
|
||||
@ -437,10 +600,10 @@ class TestApiProviderApis:
|
||||
app.test_request_context("/?provider=p"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.ApiToolManageService.get_api_tool_provider",
|
||||
return_value={"x": 1},
|
||||
return_value=api_provider_detail_payload(),
|
||||
),
|
||||
):
|
||||
assert method(api, "t", make_account()) == {"x": 1}
|
||||
assert method(api, "t", make_account())["schema"] == "{}"
|
||||
|
||||
|
||||
class TestWorkflowApis:
|
||||
@ -457,7 +620,7 @@ class TestWorkflowApis:
|
||||
"name": "n",
|
||||
"label": "l",
|
||||
"description": "d",
|
||||
"icon": empty_mapping(),
|
||||
"icon": emoji_icon(),
|
||||
"parameters": empty_list(),
|
||||
}
|
||||
|
||||
@ -465,10 +628,13 @@ class TestWorkflowApis:
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.WorkflowToolManageService.create_workflow_tool",
|
||||
return_value={"id": 1},
|
||||
),
|
||||
return_value={"result": "success"},
|
||||
) as create_workflow_tool,
|
||||
):
|
||||
assert method(api, "t", make_account())["id"] == 1
|
||||
assert method(api, "t", make_account()) == {"result": "success"}
|
||||
|
||||
create_workflow_tool.assert_called_once()
|
||||
assert create_workflow_tool.call_args.kwargs["icon"] == emoji_icon()
|
||||
|
||||
def test_update_invalid(self, app: Flask) -> None:
|
||||
api = ToolWorkflowProviderUpdateApi()
|
||||
@ -479,18 +645,21 @@ class TestWorkflowApis:
|
||||
"name": "Tool",
|
||||
"label": "Tool Label",
|
||||
"description": "A tool",
|
||||
"icon": empty_mapping(),
|
||||
"icon": emoji_icon(),
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.WorkflowToolManageService.update_workflow_tool",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
return_value={"result": "success"},
|
||||
) as update_workflow_tool,
|
||||
):
|
||||
result = method(api, "t", make_account())
|
||||
assert result["ok"]
|
||||
assert result == {"result": "success"}
|
||||
|
||||
update_workflow_tool.assert_called_once()
|
||||
assert update_workflow_tool.call_args.args[5] == emoji_icon()
|
||||
|
||||
def test_delete(self, app: Flask) -> None:
|
||||
api = ToolWorkflowProviderDeleteApi()
|
||||
@ -500,10 +669,10 @@ class TestWorkflowApis:
|
||||
app.test_request_context("/", json={"workflow_tool_id": "123e4567-e89b-12d3-a456-426614174000"}),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.WorkflowToolManageService.delete_workflow_tool",
|
||||
return_value={"ok": True},
|
||||
return_value={"result": "success"},
|
||||
),
|
||||
):
|
||||
assert method(api, "t", make_account())["ok"]
|
||||
assert method(api, "t", make_account())["result"] == "success"
|
||||
|
||||
def test_get_error(self, app: Flask) -> None:
|
||||
api = ToolWorkflowProviderGetApi()
|
||||
@ -525,49 +694,40 @@ class TestLists:
|
||||
api = ToolBuiltinListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
m = MagicMock()
|
||||
m.to_dict.return_value = {"x": 1}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.list_builtin_tools",
|
||||
return_value=[m],
|
||||
return_value=[provider_entity(provider_id="builtin-1")],
|
||||
),
|
||||
):
|
||||
assert method(api, "t", make_account()) == [{"x": 1}]
|
||||
assert method(api, "t", make_account())[0]["id"] == "builtin-1"
|
||||
|
||||
def test_api_list(self, app: Flask) -> None:
|
||||
api = ToolApiListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
m = MagicMock()
|
||||
m.to_dict.return_value = {"x": 1}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.ApiToolManageService.list_api_tools",
|
||||
return_value=[m],
|
||||
return_value=[provider_entity(provider_id="api-1", provider_type="api")],
|
||||
),
|
||||
):
|
||||
assert method(api, "t") == [{"x": 1}]
|
||||
assert method(api, "t")[0]["id"] == "api-1"
|
||||
|
||||
def test_workflow_list(self, app: Flask) -> None:
|
||||
api = ToolWorkflowListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
m = MagicMock()
|
||||
m.to_dict.return_value = {"x": 1}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.WorkflowToolManageService.list_tenant_workflow_tools",
|
||||
return_value=[m],
|
||||
return_value=[provider_entity(provider_id="workflow-1", provider_type="workflow")],
|
||||
),
|
||||
):
|
||||
assert method(api, "t", make_account()) == [{"x": 1}]
|
||||
assert method(api, "t", make_account())[0]["id"] == "workflow-1"
|
||||
|
||||
|
||||
class TestLabels:
|
||||
@ -583,10 +743,10 @@ class TestLabels:
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.ToolLabelsService.list_tool_labels",
|
||||
return_value=["l1"],
|
||||
return_value=[tool_label_payload()],
|
||||
),
|
||||
):
|
||||
assert method(api) == ["l1"]
|
||||
assert method(api)[0]["name"] == "utilities"
|
||||
|
||||
|
||||
class TestOAuth:
|
||||
@ -630,10 +790,10 @@ class TestOAuthCustomClient:
|
||||
app.test_request_context("/", json={"client_params": {"a": 1}}),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.save_custom_oauth_client_params",
|
||||
return_value={"ok": True},
|
||||
return_value={"result": "success"},
|
||||
),
|
||||
):
|
||||
assert method(api, "t", "provider")["ok"]
|
||||
assert method(api, "t", "provider") == {"result": "success"}
|
||||
|
||||
def test_get_custom_client(self, app: Flask) -> None:
|
||||
api = ToolOAuthCustomClient()
|
||||
@ -656,7 +816,7 @@ class TestOAuthCustomClient:
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.delete_custom_oauth_client_params",
|
||||
return_value={"ok": True},
|
||||
return_value={"result": "success"},
|
||||
),
|
||||
):
|
||||
assert method(api, "t", "provider")["ok"]
|
||||
assert method(api, "t", "provider") == {"result": "success"}
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from inspect import unwrap
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
@ -29,6 +30,8 @@ from controllers.console.workspace.trigger_providers import (
|
||||
TriggerSubscriptionVerifyApi,
|
||||
)
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.trigger.entities.api_entities import SubscriptionBuilderApiEntity, TriggerProviderApiEntity
|
||||
from core.trigger.entities.entities import RequestLog
|
||||
from models.account import Account
|
||||
|
||||
|
||||
@ -38,6 +41,47 @@ def mock_user() -> Account:
|
||||
return user
|
||||
|
||||
|
||||
def trigger_provider() -> TriggerProviderApiEntity:
|
||||
return TriggerProviderApiEntity(
|
||||
author="Dify",
|
||||
name="github",
|
||||
label={"en_US": "GitHub"},
|
||||
description={"en_US": "GitHub trigger provider"},
|
||||
icon="icon.svg",
|
||||
icon_dark=None,
|
||||
tags=["code"],
|
||||
plugin_id="plugin",
|
||||
plugin_unique_identifier="plugin:github",
|
||||
supported_creation_methods=[],
|
||||
subscription_constructor=None,
|
||||
subscription_schema=[],
|
||||
events=[],
|
||||
)
|
||||
|
||||
|
||||
def subscription_builder() -> SubscriptionBuilderApiEntity:
|
||||
return SubscriptionBuilderApiEntity(
|
||||
id="b1",
|
||||
name="Builder",
|
||||
provider="github",
|
||||
endpoint="b1",
|
||||
parameters={"repo": "dify"},
|
||||
properties={"branch": "main"},
|
||||
credentials={"token": "secret"},
|
||||
credential_type=CredentialType.UNAUTHORIZED,
|
||||
)
|
||||
|
||||
|
||||
def request_log() -> RequestLog:
|
||||
return RequestLog(
|
||||
id="log1",
|
||||
endpoint="/hooks/b1",
|
||||
request={"headers": {}, "body": {"event": "push"}},
|
||||
response={"status": 200, "body": {"ok": True}},
|
||||
created_at=datetime(2024, 1, 1),
|
||||
)
|
||||
|
||||
|
||||
class TestTriggerProviderApis:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers: Flask) -> Flask:
|
||||
@ -77,10 +121,10 @@ class TestTriggerProviderApis:
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_trigger_provider",
|
||||
return_value={"id": "p1"},
|
||||
return_value=trigger_provider(),
|
||||
),
|
||||
):
|
||||
assert method(api, "t1", "github") == {"id": "p1"}
|
||||
assert method(api, "t1", "github")["name"] == "github"
|
||||
|
||||
|
||||
class TestTriggerSubscriptionListApi:
|
||||
@ -129,11 +173,11 @@ class TestTriggerSubscriptionBuilderApis:
|
||||
app.test_request_context("/", json={"credential_type": "UNAUTHORIZED"}),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.create_trigger_subscription_builder",
|
||||
return_value={"id": "b1"},
|
||||
return_value=subscription_builder(),
|
||||
),
|
||||
):
|
||||
result = method(api, "t1", mock_user(), "github")
|
||||
assert "subscription_builder" in result
|
||||
assert result["subscription_builder"]["id"] == "b1"
|
||||
|
||||
def test_get_builder(self, app: Flask) -> None:
|
||||
api = TriggerSubscriptionBuilderGetApi()
|
||||
@ -143,10 +187,10 @@ class TestTriggerSubscriptionBuilderApis:
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.get_subscription_builder_by_id",
|
||||
return_value={"id": "b1"},
|
||||
return_value=subscription_builder(),
|
||||
),
|
||||
):
|
||||
assert method(api, "github", "b1") == {"id": "b1"}
|
||||
assert method(api, "github", "b1")["id"] == "b1"
|
||||
|
||||
def test_verify_builder(self, app: Flask) -> None:
|
||||
api = TriggerSubscriptionBuilderVerifyApi()
|
||||
@ -156,10 +200,10 @@ class TestTriggerSubscriptionBuilderApis:
|
||||
app.test_request_context("/", json={"credentials": {"a": 1}}),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_and_verify_builder",
|
||||
return_value={"ok": True},
|
||||
return_value={"verified": True},
|
||||
),
|
||||
):
|
||||
assert method(api, "t1", mock_user(), "github", "b1") == {"ok": True}
|
||||
assert method(api, "t1", mock_user(), "github", "b1") == {"verified": True}
|
||||
|
||||
def test_verify_builder_error(self, app: Flask) -> None:
|
||||
api = TriggerSubscriptionBuilderVerifyApi()
|
||||
@ -183,26 +227,24 @@ class TestTriggerSubscriptionBuilderApis:
|
||||
app.test_request_context("/", json={"name": "n"}),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_trigger_subscription_builder",
|
||||
return_value={"id": "b1"},
|
||||
return_value=subscription_builder(),
|
||||
),
|
||||
):
|
||||
assert method(api, "t1", "github", "b1") == {"id": "b1"}
|
||||
assert method(api, "t1", "github", "b1")["id"] == "b1"
|
||||
|
||||
def test_logs(self, app: Flask) -> None:
|
||||
api = TriggerSubscriptionBuilderLogsApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
log = MagicMock()
|
||||
log.model_dump.return_value = {"a": 1}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.list_logs",
|
||||
return_value=[log],
|
||||
return_value=[request_log()],
|
||||
),
|
||||
):
|
||||
assert "logs" in method(api, "github", "b1")
|
||||
result = method(api, "github", "b1")
|
||||
assert result["logs"][0]["id"] == "log1"
|
||||
|
||||
def test_build(self, app: Flask) -> None:
|
||||
api = TriggerSubscriptionBuilderBuildApi()
|
||||
@ -215,7 +257,7 @@ class TestTriggerSubscriptionBuilderApis:
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
assert method(api, "t1", mock_user(), "github", "b1") == 200
|
||||
assert method(api, "t1", mock_user(), "github", "b1") == {"result": "success"}
|
||||
|
||||
|
||||
class TestTriggerSubscriptionCrud:
|
||||
@ -239,7 +281,7 @@ class TestTriggerSubscriptionCrud:
|
||||
),
|
||||
patch("controllers.console.workspace.trigger_providers.TriggerProviderService.update_trigger_subscription"),
|
||||
):
|
||||
assert method(api, "t1", "s1") == 200
|
||||
assert method(api, "t1", "s1") == {"result": "success"}
|
||||
|
||||
def test_update_not_found(self, app: Flask) -> None:
|
||||
api = TriggerSubscriptionUpdateApi()
|
||||
@ -275,7 +317,7 @@ class TestTriggerSubscriptionCrud:
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.rebuild_trigger_subscription"
|
||||
),
|
||||
):
|
||||
assert method(api, "t1", "s1") == 200
|
||||
assert method(api, "t1", "s1") == {"result": "success"}
|
||||
|
||||
def test_delete_subscription(self, app: Flask) -> None:
|
||||
api = TriggerSubscriptionDeleteApi()
|
||||
@ -336,7 +378,7 @@ class TestTriggerOAuthApis:
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.create_trigger_subscription_builder",
|
||||
return_value=MagicMock(id="b1"),
|
||||
return_value=subscription_builder(),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.OAuthProxyService.create_proxy_context",
|
||||
@ -480,7 +522,7 @@ class TestTriggerOAuthClientManageApi:
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerManager.get_trigger_provider",
|
||||
return_value=MagicMock(get_oauth_client_schema=lambda: {}),
|
||||
return_value=MagicMock(get_oauth_client_schema=lambda: []),
|
||||
),
|
||||
):
|
||||
result = method(api, "t1", "github")
|
||||
@ -494,10 +536,10 @@ class TestTriggerOAuthClientManageApi:
|
||||
app.test_request_context("/", json={"enabled": True}),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.save_custom_oauth_client_params",
|
||||
return_value={"ok": True},
|
||||
return_value={"result": "success"},
|
||||
),
|
||||
):
|
||||
assert method(api, "t1", "github") == {"ok": True}
|
||||
assert method(api, "t1", "github") == {"result": "success"}
|
||||
|
||||
def test_delete_client(self, app: Flask) -> None:
|
||||
api = TriggerOAuthClientManageApi()
|
||||
@ -507,10 +549,10 @@ class TestTriggerOAuthClientManageApi:
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.delete_custom_oauth_client_params",
|
||||
return_value={"ok": True},
|
||||
return_value={"result": "success"},
|
||||
),
|
||||
):
|
||||
assert method(api, "t1", "github") == {"ok": True}
|
||||
assert method(api, "t1", "github") == {"result": "success"}
|
||||
|
||||
def test_oauth_client_post_value_error(self, app: Flask) -> None:
|
||||
api = TriggerOAuthClientManageApi()
|
||||
@ -540,10 +582,10 @@ class TestTriggerSubscriptionVerifyApi:
|
||||
app.test_request_context("/", json={"credentials": {}}),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.verify_subscription_credentials",
|
||||
return_value={"ok": True},
|
||||
return_value={"verified": True},
|
||||
),
|
||||
):
|
||||
assert method(api, "t1", mock_user(), "github", "s1") == {"ok": True}
|
||||
assert method(api, "t1", mock_user(), "github", "s1") == {"verified": True}
|
||||
|
||||
@pytest.mark.parametrize("raised_exception", [ValueError("bad"), Exception("boom")])
|
||||
def test_verify_errors(self, app: Flask, raised_exception: Exception) -> None:
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import inspect
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
from collections.abc import Iterator
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
@ -10,16 +11,18 @@ from sqlalchemy.orm import Session
|
||||
from core.tools.entities.tool_entities import ApiProviderSchemaType
|
||||
from core.tools.errors import ApiToolProviderNotFoundError
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from models import Account, Tenant
|
||||
from models import Account, AccountStatus, Tenant, TenantStatus
|
||||
from models.tools import ApiToolProvider
|
||||
from services.tools.api_tools_manage_service import ApiToolManageService
|
||||
|
||||
MockDependencies = dict[str, MagicMock]
|
||||
|
||||
|
||||
class TestApiToolManageService:
|
||||
"""Integration tests for ApiToolManageService using testcontainers."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_external_service_dependencies(self):
|
||||
def mock_external_service_dependencies(self) -> Iterator[MockDependencies]:
|
||||
"""Mock setup for external service dependencies."""
|
||||
with (
|
||||
patch("services.tools.api_tools_manage_service.ToolLabelManager") as mock_tool_label_manager,
|
||||
@ -39,7 +42,9 @@ class TestApiToolManageService:
|
||||
"provider_controller": mock_provider_controller,
|
||||
}
|
||||
|
||||
def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
def _create_test_account_and_tenant(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies: MockDependencies
|
||||
) -> tuple[Account, Tenant]:
|
||||
"""
|
||||
Helper method to create a test account and tenant for testing.
|
||||
|
||||
@ -57,7 +62,7 @@ class TestApiToolManageService:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
status=AccountStatus.ACTIVE,
|
||||
)
|
||||
|
||||
db_session_with_containers.add(account)
|
||||
@ -66,7 +71,7 @@ class TestApiToolManageService:
|
||||
# Create tenant for the account
|
||||
tenant = Tenant(
|
||||
name=fake.company(),
|
||||
status="normal",
|
||||
status=TenantStatus.NORMAL,
|
||||
)
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
@ -88,7 +93,7 @@ class TestApiToolManageService:
|
||||
|
||||
return account, tenant
|
||||
|
||||
def _create_test_openapi_schema(self):
|
||||
def _create_test_openapi_schema(self) -> str:
|
||||
"""Helper method to create a test OpenAPI schema."""
|
||||
return """
|
||||
{
|
||||
@ -121,8 +126,11 @@ class TestApiToolManageService:
|
||||
"""
|
||||
|
||||
def test_parser_api_schema_success(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
self,
|
||||
flask_req_ctx_with_containers: object,
|
||||
db_session_with_containers: Session,
|
||||
mock_external_service_dependencies: MockDependencies,
|
||||
) -> None:
|
||||
"""
|
||||
Test successful parsing of API schema.
|
||||
|
||||
@ -148,6 +156,8 @@ class TestApiToolManageService:
|
||||
# Verify credentials schema structure
|
||||
credentials_schema = result["credentials_schema"]
|
||||
assert len(credentials_schema) == 3
|
||||
assert all(isinstance(field, dict) for field in credentials_schema)
|
||||
assert all(isinstance(tool, dict) for tool in result["parameters_schema"])
|
||||
|
||||
# Check auth_type field
|
||||
auth_type_field = next(field for field in credentials_schema if field["name"] == "auth_type")
|
||||
@ -166,8 +176,11 @@ class TestApiToolManageService:
|
||||
assert api_key_value_field["default"] == ""
|
||||
|
||||
def test_parser_api_schema_invalid_schema(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
self,
|
||||
flask_req_ctx_with_containers: object,
|
||||
db_session_with_containers: Session,
|
||||
mock_external_service_dependencies: MockDependencies,
|
||||
) -> None:
|
||||
"""
|
||||
Test parsing of invalid API schema.
|
||||
|
||||
@ -186,8 +199,11 @@ class TestApiToolManageService:
|
||||
assert "invalid schema" in str(exc_info.value)
|
||||
|
||||
def test_parser_api_schema_malformed_json(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
self,
|
||||
flask_req_ctx_with_containers: object,
|
||||
db_session_with_containers: Session,
|
||||
mock_external_service_dependencies: MockDependencies,
|
||||
) -> None:
|
||||
"""
|
||||
Test parsing of malformed JSON schema.
|
||||
|
||||
@ -206,8 +222,11 @@ class TestApiToolManageService:
|
||||
assert "invalid schema" in str(exc_info.value)
|
||||
|
||||
def test_convert_schema_to_tool_bundles_success(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
self,
|
||||
flask_req_ctx_with_containers: object,
|
||||
db_session_with_containers: Session,
|
||||
mock_external_service_dependencies: MockDependencies,
|
||||
) -> None:
|
||||
"""
|
||||
Test successful conversion of schema to tool bundles.
|
||||
|
||||
@ -236,8 +255,11 @@ class TestApiToolManageService:
|
||||
assert tool_bundle.operation_id == "testOperation"
|
||||
|
||||
def test_convert_schema_to_tool_bundles_with_extra_info(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
self,
|
||||
flask_req_ctx_with_containers: object,
|
||||
db_session_with_containers: Session,
|
||||
mock_external_service_dependencies: MockDependencies,
|
||||
) -> None:
|
||||
"""
|
||||
Test successful conversion of schema to tool bundles with extra info.
|
||||
|
||||
@ -262,8 +284,11 @@ class TestApiToolManageService:
|
||||
assert isinstance(schema_type, str)
|
||||
|
||||
def test_convert_schema_to_tool_bundles_invalid_schema(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
self,
|
||||
flask_req_ctx_with_containers: object,
|
||||
db_session_with_containers: Session,
|
||||
mock_external_service_dependencies: MockDependencies,
|
||||
) -> None:
|
||||
"""
|
||||
Test conversion of invalid schema to tool bundles.
|
||||
|
||||
@ -282,8 +307,11 @@ class TestApiToolManageService:
|
||||
assert "invalid schema" in str(exc_info.value)
|
||||
|
||||
def test_create_api_tool_provider_success(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
self,
|
||||
flask_req_ctx_with_containers: object,
|
||||
db_session_with_containers: Session,
|
||||
mock_external_service_dependencies: MockDependencies,
|
||||
) -> None:
|
||||
"""
|
||||
Test successful creation of API tool provider.
|
||||
|
||||
@ -301,7 +329,7 @@ class TestApiToolManageService:
|
||||
)
|
||||
|
||||
provider_name = fake.company()
|
||||
icon = {"type": "emoji", "value": "🔧"}
|
||||
icon = {"content": "🔧", "background": "#FFF"}
|
||||
credentials = {"auth_type": "none", "api_key_header": "X-API-Key", "api_key_value": ""}
|
||||
schema_type = ApiProviderSchemaType.OPENAPI
|
||||
schema = self._create_test_openapi_schema()
|
||||
@ -341,6 +369,7 @@ class TestApiToolManageService:
|
||||
assert provider.schema_type_str == schema_type
|
||||
assert provider.privacy_policy == privacy_policy
|
||||
assert provider.custom_disclaimer == custom_disclaimer
|
||||
assert json.loads(provider.icon) == icon
|
||||
|
||||
# Verify mock interactions
|
||||
mock_external_service_dependencies["tool_label_manager"].update_tool_labels.assert_called_once()
|
||||
@ -349,8 +378,11 @@ class TestApiToolManageService:
|
||||
mock_external_service_dependencies["provider_controller"].load_bundled_tools.assert_called_once()
|
||||
|
||||
def test_create_api_tool_provider_duplicate_name(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
self,
|
||||
flask_req_ctx_with_containers: object,
|
||||
db_session_with_containers: Session,
|
||||
mock_external_service_dependencies: MockDependencies,
|
||||
) -> None:
|
||||
"""
|
||||
Test creation of API tool provider with duplicate name.
|
||||
|
||||
@ -366,7 +398,7 @@ class TestApiToolManageService:
|
||||
)
|
||||
|
||||
provider_name = fake.company()
|
||||
icon = {"type": "emoji", "value": "🔧"}
|
||||
icon = {"content": "🔧", "background": "#FFF"}
|
||||
credentials = {"auth_type": "none"}
|
||||
schema_type = ApiProviderSchemaType.OPENAPI
|
||||
schema = self._create_test_openapi_schema()
|
||||
@ -406,8 +438,11 @@ class TestApiToolManageService:
|
||||
assert f"provider {provider_name} already exists" in str(exc_info.value)
|
||||
|
||||
def test_create_api_tool_provider_invalid_schema_type(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
self,
|
||||
flask_req_ctx_with_containers: object,
|
||||
db_session_with_containers: Session,
|
||||
mock_external_service_dependencies: MockDependencies,
|
||||
) -> None:
|
||||
"""
|
||||
Test creation of API tool provider with invalid schema type.
|
||||
|
||||
@ -423,7 +458,7 @@ class TestApiToolManageService:
|
||||
)
|
||||
|
||||
provider_name = fake.company()
|
||||
icon = {"type": "emoji", "value": "🔧"}
|
||||
icon = {"content": "🔧", "background": "#FFF"}
|
||||
credentials = {"auth_type": "none"}
|
||||
schema_type = "invalid_type"
|
||||
schema = self._create_test_openapi_schema()
|
||||
@ -438,8 +473,11 @@ class TestApiToolManageService:
|
||||
assert "validation error" in str(exc_info.value)
|
||||
|
||||
def test_create_api_tool_provider_missing_auth_type(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
self,
|
||||
flask_req_ctx_with_containers: object,
|
||||
db_session_with_containers: Session,
|
||||
mock_external_service_dependencies: MockDependencies,
|
||||
) -> None:
|
||||
"""
|
||||
Test creation of API tool provider with missing auth type.
|
||||
|
||||
@ -455,7 +493,7 @@ class TestApiToolManageService:
|
||||
)
|
||||
|
||||
provider_name = fake.company()
|
||||
icon = {"type": "emoji", "value": "🔧"}
|
||||
icon = {"content": "🔧", "background": "#FFF"}
|
||||
credentials = {} # Missing auth_type
|
||||
schema_type = ApiProviderSchemaType.OPENAPI
|
||||
schema = self._create_test_openapi_schema()
|
||||
@ -481,8 +519,11 @@ class TestApiToolManageService:
|
||||
assert "auth_type is required" in str(exc_info.value)
|
||||
|
||||
def test_create_api_tool_provider_with_api_key_auth(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
self,
|
||||
flask_req_ctx_with_containers: object,
|
||||
db_session_with_containers: Session,
|
||||
mock_external_service_dependencies: MockDependencies,
|
||||
) -> None:
|
||||
"""
|
||||
Test successful creation of API tool provider with API key authentication.
|
||||
|
||||
@ -498,7 +539,7 @@ class TestApiToolManageService:
|
||||
)
|
||||
|
||||
provider_name = fake.company()
|
||||
icon = {"type": "emoji", "value": "🔑"}
|
||||
icon = {"content": "🔑", "background": "#FFF"}
|
||||
credentials = {"auth_type": "api_key", "api_key_header": "X-API-Key", "api_key_value": fake.uuid4()}
|
||||
schema_type = ApiProviderSchemaType.OPENAPI
|
||||
schema = self._create_test_openapi_schema()
|
||||
@ -542,8 +583,11 @@ class TestApiToolManageService:
|
||||
mock_external_service_dependencies["provider_controller"].from_db.assert_called_once()
|
||||
|
||||
def test_delete_api_tool_provider_success(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
self,
|
||||
flask_req_ctx_with_containers: object,
|
||||
db_session_with_containers: Session,
|
||||
mock_external_service_dependencies: MockDependencies,
|
||||
) -> None:
|
||||
"""Test successful deletion of an API tool provider."""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
@ -583,8 +627,8 @@ class TestApiToolManageService:
|
||||
assert deleted is None
|
||||
|
||||
def test_delete_api_tool_provider_not_found(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies: MockDependencies
|
||||
) -> None:
|
||||
"""Test deletion raises ValueError when provider not found."""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
@ -595,14 +639,15 @@ class TestApiToolManageService:
|
||||
ApiToolManageService.delete_api_tool_provider(account.id, tenant.id, "nonexistent")
|
||||
|
||||
def test_update_api_tool_provider_success(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
self,
|
||||
flask_req_ctx_with_containers: object,
|
||||
db_session_with_containers: Session,
|
||||
mock_external_service_dependencies: MockDependencies,
|
||||
) -> None:
|
||||
fake = Faker()
|
||||
|
||||
# Firmware fix for cache.delete() in update flow
|
||||
mock_encrypter = mock_external_service_dependencies["encrypter"]
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
mock_cache = MagicMock()
|
||||
mock_cache.delete.return_value = None
|
||||
mock_encrypter.return_value = (mock_encrypter, mock_cache)
|
||||
@ -620,7 +665,7 @@ class TestApiToolManageService:
|
||||
user_id=account.id,
|
||||
tenant_id=tenant.id,
|
||||
provider_name=original_name,
|
||||
icon={"type": "emoji", "value": "🔧"},
|
||||
icon={"content": "🔧", "background": "#FFF"},
|
||||
credentials={"auth_type": "none"},
|
||||
schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema=self._create_test_openapi_schema(),
|
||||
@ -646,7 +691,7 @@ class TestApiToolManageService:
|
||||
provider_name=new_name,
|
||||
original_provider=original_name,
|
||||
# new icon - changed 2
|
||||
icon={"type": "emoji", "value": "🚀"},
|
||||
icon={"content": "🚀", "background": "#FFF"},
|
||||
credentials={"auth_type": "none"},
|
||||
_schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema=self._create_test_openapi_schema(),
|
||||
@ -677,9 +722,7 @@ class TestApiToolManageService:
|
||||
# - changed 1
|
||||
assert updated_provider.name == new_name
|
||||
# - changed 2
|
||||
icon_data = json.loads(updated_provider.icon)
|
||||
assert icon_data["type"] == "emoji"
|
||||
assert icon_data["value"] == "🚀"
|
||||
assert json.loads(updated_provider.icon) == {"content": "🚀", "background": "#FFF"}
|
||||
# - changed 3
|
||||
assert updated_provider.privacy_policy == "https://new-policy.com"
|
||||
# - changed 4
|
||||
@ -712,8 +755,11 @@ class TestApiToolManageService:
|
||||
)
|
||||
|
||||
def test_update_api_tool_provider_not_found(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
self,
|
||||
flask_req_ctx_with_containers: object,
|
||||
db_session_with_containers: Session,
|
||||
mock_external_service_dependencies: MockDependencies,
|
||||
) -> None:
|
||||
"""
|
||||
Test update raises ValueError when original provider not found.
|
||||
|
||||
@ -733,7 +779,7 @@ class TestApiToolManageService:
|
||||
user_id=account.id,
|
||||
tenant_id=tenant.id,
|
||||
provider_name=existing_provider_name,
|
||||
icon={"type": "emoji", "value": "🔧"},
|
||||
icon={"content": "🔧", "background": "#FFF"},
|
||||
credentials={"auth_type": "none"},
|
||||
schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema=self._create_test_openapi_schema(),
|
||||
@ -756,7 +802,7 @@ class TestApiToolManageService:
|
||||
tenant_id=tenant.id,
|
||||
provider_name=target_new_name,
|
||||
original_provider=missing_original_name,
|
||||
icon={"type": "emoji", "value": "🚀"},
|
||||
icon={"content": "🚀", "background": "#FFF"},
|
||||
credentials={"auth_type": "none"},
|
||||
_schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema=self._create_test_openapi_schema(),
|
||||
@ -793,8 +839,11 @@ class TestApiToolManageService:
|
||||
mock_external_service_dependencies["provider_controller"].from_db.assert_not_called()
|
||||
|
||||
def test_update_api_tool_provider_missing_auth_type(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
self,
|
||||
flask_req_ctx_with_containers: object,
|
||||
db_session_with_containers: Session,
|
||||
mock_external_service_dependencies: MockDependencies,
|
||||
) -> None:
|
||||
"""Test update raises ValueError when auth_type is missing from credentials."""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
@ -822,7 +871,7 @@ class TestApiToolManageService:
|
||||
tenant_id=tenant.id,
|
||||
provider_name=provider_name,
|
||||
original_provider=provider_name,
|
||||
icon={},
|
||||
icon={"content": "🔧", "background": "#FFF"},
|
||||
credentials={},
|
||||
_schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema=schema,
|
||||
@ -832,8 +881,8 @@ class TestApiToolManageService:
|
||||
)
|
||||
|
||||
def test_list_api_tool_provider_tools_not_found(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies: MockDependencies
|
||||
) -> None:
|
||||
"""Test listing tools raises ValueError when provider not found."""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
@ -844,8 +893,8 @@ class TestApiToolManageService:
|
||||
ApiToolManageService.list_api_tool_provider_tools(account.id, tenant.id, "nonexistent")
|
||||
|
||||
def test_test_api_tool_preview_invalid_schema_type(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies: MockDependencies
|
||||
) -> None:
|
||||
"""Test preview raises ValueError for invalid schema type."""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
|
||||
@ -5,7 +5,6 @@ from __future__ import annotations
|
||||
import builtins
|
||||
import importlib
|
||||
from contextlib import ExitStack, contextmanager
|
||||
from inspect import unwrap
|
||||
from types import ModuleType, SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
@ -13,6 +12,9 @@ import pytest
|
||||
from flask import Flask
|
||||
from flask.views import MethodView
|
||||
|
||||
from core.tools.entities.api_entities import ToolProviderApiEntity as CoreToolProviderApiEntity
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolParameter
|
||||
from models import Account
|
||||
from models.account import TenantAccountRole
|
||||
|
||||
@ -21,6 +23,7 @@ if not hasattr(builtins, "MethodView"):
|
||||
|
||||
|
||||
_CONTROLLER_MODULE: ModuleType | None = None
|
||||
_WRAPS_MODULE: ModuleType | None = None
|
||||
|
||||
|
||||
@contextmanager
|
||||
@ -69,10 +72,11 @@ def controller_module(monkeypatch: pytest.MonkeyPatch):
|
||||
_CONTROLLER_MODULE = importlib.import_module(module_name)
|
||||
|
||||
module = _CONTROLLER_MODULE
|
||||
monkeypatch.setattr(module, "jsonable_encoder", lambda payload: payload)
|
||||
|
||||
# Ensure decorators that consult deployment edition do not reach the database.
|
||||
global _WRAPS_MODULE
|
||||
wraps_module = importlib.import_module("controllers.console.wraps")
|
||||
_WRAPS_MODULE = wraps_module
|
||||
monkeypatch.setattr(module.dify_config, "EDITION", "CLOUD")
|
||||
monkeypatch.setattr(wraps_module.dify_config, "EDITION", "CLOUD")
|
||||
|
||||
@ -88,19 +92,194 @@ def _mock_account(user_id: str = "user-123") -> Account:
|
||||
return user
|
||||
|
||||
|
||||
def _set_current_account(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
controller_module: ModuleType,
|
||||
user: Account,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
def _getter():
|
||||
return user, tenant_id
|
||||
|
||||
monkeypatch.setattr(controller_module, "current_account_with_tenant", _getter, raising=False)
|
||||
if _WRAPS_MODULE is not None:
|
||||
monkeypatch.setattr(_WRAPS_MODULE, "current_account_with_tenant", _getter)
|
||||
|
||||
login_module = importlib.import_module("libs.login")
|
||||
monkeypatch.setattr(login_module, "_get_user", lambda: user)
|
||||
|
||||
|
||||
def _i18n(text: str) -> dict[str, str]:
|
||||
return {"en_US": text, "zh_Hans": text, "pt_BR": text, "ja_JP": text}
|
||||
|
||||
|
||||
def _tool_response(controller_module: ModuleType, name: str = "tool-a") -> tuple[dict, dict]:
|
||||
expected = {
|
||||
"author": "Dify",
|
||||
"name": name,
|
||||
"label": _i18n(name),
|
||||
"description": _i18n(f"{name} description"),
|
||||
"parameters": [],
|
||||
"labels": [],
|
||||
"output_schema": {},
|
||||
}
|
||||
tool = controller_module.ToolApiEntity.model_validate(expected)
|
||||
return tool.model_dump(mode="json"), expected
|
||||
|
||||
|
||||
def _provider_entity_response(
|
||||
controller_module: ModuleType, name: str = "provider", provider_type: str = "builtin"
|
||||
) -> tuple[CoreToolProviderApiEntity, dict]:
|
||||
service_payload = {
|
||||
"id": f"{name}-id",
|
||||
"author": "Dify",
|
||||
"name": name,
|
||||
"description": _i18n(f"{name} description"),
|
||||
"icon": "tool.svg",
|
||||
"icon_dark": "",
|
||||
"label": _i18n(name),
|
||||
"type": provider_type,
|
||||
"masked_credentials": {"api_key": "[__HIDDEN__]"},
|
||||
"original_credentials": {"api_key": "sk-secret"},
|
||||
"is_team_authorization": False,
|
||||
"allow_delete": True,
|
||||
"plugin_id": "",
|
||||
"plugin_unique_identifier": "",
|
||||
"tools": [],
|
||||
"labels": [],
|
||||
"server_url": "",
|
||||
"updated_at": 1,
|
||||
"server_identifier": "",
|
||||
"masked_headers": None,
|
||||
"original_headers": None,
|
||||
"authentication": None,
|
||||
"is_dynamic_registration": True,
|
||||
"configuration": None,
|
||||
"identity_mode": "off",
|
||||
"workflow_app_id": None,
|
||||
}
|
||||
provider = CoreToolProviderApiEntity.model_validate(service_payload)
|
||||
return provider, provider.to_dict()
|
||||
|
||||
|
||||
def _provider_list_item(
|
||||
controller_module: ModuleType, name: str = "provider", provider_type: str = "builtin"
|
||||
) -> tuple[dict, dict]:
|
||||
service_payload = {
|
||||
"id": f"{name}-id",
|
||||
"author": "Dify",
|
||||
"name": name,
|
||||
"description": _i18n(f"{name} description"),
|
||||
"icon": "tool.svg",
|
||||
"icon_dark": "",
|
||||
"label": _i18n(name),
|
||||
"type": provider_type,
|
||||
"team_credentials": {"api_key": "[__HIDDEN__]"},
|
||||
"is_team_authorization": False,
|
||||
"allow_delete": True,
|
||||
"plugin_id": "",
|
||||
"plugin_unique_identifier": "",
|
||||
"tools": [],
|
||||
"labels": [],
|
||||
}
|
||||
expected = {
|
||||
**service_payload,
|
||||
}
|
||||
provider = controller_module.ToolProviderApiEntityResponse.model_validate(expected)
|
||||
return service_payload, provider.model_dump(mode="json", exclude_unset=True)
|
||||
|
||||
|
||||
def _credential_response(controller_module: ModuleType, credential_id: str = "cred-1") -> tuple[dict, dict]:
|
||||
expected = {
|
||||
"id": credential_id,
|
||||
"name": "Credential",
|
||||
"provider": "demo",
|
||||
"credential_type": controller_module.CredentialType.API_KEY,
|
||||
"is_default": False,
|
||||
"credentials": {},
|
||||
"visibility": "all_team_members",
|
||||
"created_by": "",
|
||||
"partial_member_list": [],
|
||||
"from_other_member": False,
|
||||
}
|
||||
credential = controller_module.ToolProviderCredentialApiEntity.model_validate(expected)
|
||||
return credential.model_dump(mode="json"), credential.model_dump(mode="json")
|
||||
|
||||
|
||||
def _provider_config_response(controller_module: ModuleType) -> tuple[dict, dict]:
|
||||
expected = {
|
||||
"type": "secret-input",
|
||||
"name": "api_key",
|
||||
"scope": None,
|
||||
"required": False,
|
||||
"default": None,
|
||||
"options": None,
|
||||
"multiple": False,
|
||||
"label": None,
|
||||
"help": None,
|
||||
"url": None,
|
||||
"placeholder": None,
|
||||
}
|
||||
config = controller_module.ProviderConfig.model_validate(expected)
|
||||
return config.model_dump(mode="json"), expected
|
||||
|
||||
|
||||
def _api_provider_detail_response(controller_module: ModuleType) -> tuple[dict, dict]:
|
||||
expected = {
|
||||
"schema_type": "openapi",
|
||||
"schema": "{}",
|
||||
"tools": [],
|
||||
"icon": {"background": "#252525", "content": "tool"},
|
||||
"description": "provider description",
|
||||
"credentials": {"auth_type": "none"},
|
||||
"privacy_policy": "",
|
||||
"custom_disclaimer": "",
|
||||
"labels": [],
|
||||
}
|
||||
detail = controller_module.ApiProviderDetailResponse.model_validate(expected)
|
||||
return detail.model_dump(mode="json", by_alias=True), expected
|
||||
|
||||
|
||||
def _workflow_detail_response(controller_module: ModuleType) -> tuple[dict, dict]:
|
||||
tool_payload, tool_expected = _tool_response(controller_module, "workflow-tool")
|
||||
expected = {
|
||||
"name": "workflow-tool",
|
||||
"label": "Workflow Tool",
|
||||
"workflow_tool_id": "00000000-0000-0000-0000-000000000001",
|
||||
"workflow_app_id": "00000000-0000-0000-0000-000000000002",
|
||||
"icon": {"background": "#252525", "content": "tool"},
|
||||
"description": "description",
|
||||
"parameters": [],
|
||||
"output_schema": {},
|
||||
"tool": tool_expected,
|
||||
"synced": True,
|
||||
"privacy_policy": "",
|
||||
}
|
||||
service_payload = {**expected, "tool": tool_payload}
|
||||
detail = controller_module.WorkflowToolDetailResponse.model_validate(service_payload)
|
||||
return detail.model_dump(mode="json"), expected
|
||||
|
||||
|
||||
def _tool_label_response(controller_module: ModuleType, name: str = "search") -> tuple[dict, dict]:
|
||||
expected = {"name": name, "label": _i18n(name), "icon": "search"}
|
||||
label = controller_module.ToolLabel.model_validate(expected)
|
||||
return label.model_dump(mode="json"), expected
|
||||
|
||||
|
||||
def test_tool_provider_list_calls_service_with_query(
|
||||
app: Flask, controller_module: ModuleType, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
user = _mock_account()
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-456")
|
||||
|
||||
service_mock = MagicMock(return_value=[{"provider": "builtin"}])
|
||||
service_payload, expected_response = _provider_list_item(controller_module, "builtin", "builtin")
|
||||
service_mock = MagicMock(return_value=[service_payload])
|
||||
monkeypatch.setattr(controller_module.ToolCommonService, "list_tool_providers", service_mock)
|
||||
|
||||
with app.test_request_context("/workspaces/current/tool-providers?type=builtin"):
|
||||
api = controller_module.ToolProviderListApi()
|
||||
response = unwrap(api.get)(api, "tenant-456", user)
|
||||
response = controller_module.ToolProviderListApi().get()
|
||||
|
||||
assert response == [{"provider": "builtin"}]
|
||||
assert response == [expected_response]
|
||||
service_mock.assert_called_once_with(user.id, "tenant-456", "builtin")
|
||||
|
||||
|
||||
@ -108,8 +287,9 @@ def test_builtin_provider_add_passes_payload(
|
||||
app: Flask, controller_module: ModuleType, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
user = _mock_account()
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-456")
|
||||
|
||||
service_mock = MagicMock(return_value={"status": "ok"})
|
||||
service_mock = MagicMock(return_value={"result": "success"})
|
||||
monkeypatch.setattr(controller_module.BuiltinToolManageService, "add_builtin_tool_provider", service_mock)
|
||||
|
||||
payload = {
|
||||
@ -123,10 +303,9 @@ def test_builtin_provider_add_passes_payload(
|
||||
method="POST",
|
||||
json=payload,
|
||||
):
|
||||
api = controller_module.ToolBuiltinProviderAddApi()
|
||||
response = unwrap(api.post)(api, "tenant-456", user, provider="openai")
|
||||
response = controller_module.ToolBuiltinProviderAddApi().post(provider="openai")
|
||||
|
||||
assert response == {"status": "ok"}
|
||||
assert response == {"result": "success"}
|
||||
service_mock.assert_called_once_with(
|
||||
user_id="user-123",
|
||||
tenant_id="tenant-456",
|
||||
@ -140,38 +319,88 @@ def test_builtin_provider_add_passes_payload(
|
||||
|
||||
def test_builtin_provider_tools_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account("user-tenant-789")
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-789")
|
||||
|
||||
service_mock = MagicMock(return_value=[{"name": "tool-a"}])
|
||||
service_payload, expected_response = _tool_response(controller_module, "tool-a")
|
||||
service_mock = MagicMock(return_value=[service_payload])
|
||||
monkeypatch.setattr(controller_module.BuiltinToolManageService, "list_builtin_tool_provider_tools", service_mock)
|
||||
monkeypatch.setattr(controller_module, "jsonable_encoder", lambda payload: payload)
|
||||
|
||||
with app.test_request_context(
|
||||
"/workspaces/current/tool-provider/builtin/my-provider/tools",
|
||||
method="GET",
|
||||
):
|
||||
api = controller_module.ToolBuiltinProviderListToolsApi()
|
||||
response = unwrap(api.get)(api, "tenant-789", provider="my-provider")
|
||||
response = controller_module.ToolBuiltinProviderListToolsApi().get(provider="my-provider")
|
||||
|
||||
assert response == [{"name": "tool-a"}]
|
||||
assert response == [expected_response]
|
||||
service_mock.assert_called_once_with("tenant-789", "my-provider")
|
||||
|
||||
|
||||
def test_builtin_provider_info_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account("user-tenant-9")
|
||||
service_mock = MagicMock(return_value={"info": True})
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-9")
|
||||
service_payload, expected_response = _provider_entity_response(controller_module, "demo", "builtin")
|
||||
service_mock = MagicMock(return_value=service_payload)
|
||||
monkeypatch.setattr(controller_module.BuiltinToolManageService, "get_builtin_tool_provider_info", service_mock)
|
||||
|
||||
with app.test_request_context("/info", method="GET"):
|
||||
api = controller_module.ToolBuiltinProviderInfoApi()
|
||||
resp = unwrap(api.get)(api, "tenant-9", provider="demo")
|
||||
resp = controller_module.ToolBuiltinProviderInfoApi().get(provider="demo")
|
||||
|
||||
assert resp == {"info": True}
|
||||
assert resp == expected_response
|
||||
service_mock.assert_called_once_with("tenant-9", "demo")
|
||||
|
||||
|
||||
def test_builtin_provider_info_uses_core_to_dict_tool_projection(
|
||||
app: Flask, controller_module: ModuleType, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
user = _mock_account("user-tenant-9")
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-9")
|
||||
tool_parameter = ToolParameter(
|
||||
name="system_files",
|
||||
label=I18nObject(en_US="System Files", zh_Hans="System Files"),
|
||||
type=ToolParameter.ToolParameterType.SYSTEM_FILES,
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
input_schema=None,
|
||||
)
|
||||
tool = controller_module.ToolApiEntity(
|
||||
author="Dify",
|
||||
name="demo-tool",
|
||||
label=I18nObject(en_US="Demo Tool", zh_Hans="Demo Tool"),
|
||||
description=I18nObject(en_US="Demo Tool description", zh_Hans="Demo Tool description"),
|
||||
parameters=[tool_parameter],
|
||||
labels=[],
|
||||
output_schema={},
|
||||
)
|
||||
provider = CoreToolProviderApiEntity(
|
||||
id="demo-id",
|
||||
author="Dify",
|
||||
name="demo",
|
||||
description=I18nObject(en_US="demo description", zh_Hans="demo description"),
|
||||
icon="tool.svg",
|
||||
label=I18nObject(en_US="demo", zh_Hans="demo"),
|
||||
type=controller_module.ToolProviderType.BUILT_IN,
|
||||
masked_credentials={"api_key": "[__HIDDEN__]"},
|
||||
original_credentials={"api_key": "sk-secret"},
|
||||
tools=[tool],
|
||||
)
|
||||
service_mock = MagicMock(return_value=provider)
|
||||
monkeypatch.setattr(controller_module.BuiltinToolManageService, "get_builtin_tool_provider_info", service_mock)
|
||||
|
||||
with app.test_request_context("/info", method="GET"):
|
||||
resp = controller_module.ToolBuiltinProviderInfoApi().get(provider="demo")
|
||||
|
||||
parameter = resp["tools"][0]["parameters"][0]
|
||||
assert parameter["type"] == "files"
|
||||
assert parameter["input_schema"] is None
|
||||
assert resp["team_credentials"] == {"api_key": "[__HIDDEN__]"}
|
||||
assert "masked_credentials" not in resp
|
||||
assert "original_credentials" not in resp
|
||||
|
||||
|
||||
def test_builtin_provider_credentials_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account("user-tenant-cred")
|
||||
service_mock = MagicMock(return_value=[{"cred": 1}])
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-cred")
|
||||
service_payload, expected_response = _credential_response(controller_module)
|
||||
service_mock = MagicMock(return_value=[service_payload])
|
||||
monkeypatch.setattr(
|
||||
controller_module.BuiltinToolManageService,
|
||||
"get_builtin_tool_provider_credentials",
|
||||
@ -179,10 +408,9 @@ def test_builtin_provider_credentials_get(app: Flask, controller_module, monkeyp
|
||||
)
|
||||
|
||||
with app.test_request_context("/creds", method="GET"):
|
||||
api = controller_module.ToolBuiltinProviderGetCredentialsApi()
|
||||
resp = unwrap(api.get)(api, "tenant-cred", user, provider="demo")
|
||||
resp = controller_module.ToolBuiltinProviderGetCredentialsApi().get(provider="demo")
|
||||
|
||||
assert resp == [{"cred": 1}]
|
||||
assert resp == [expected_response]
|
||||
service_mock.assert_called_once_with(
|
||||
tenant_id="tenant-cred",
|
||||
provider_name="demo",
|
||||
@ -193,46 +421,51 @@ def test_builtin_provider_credentials_get(app: Flask, controller_module, monkeyp
|
||||
|
||||
def test_api_provider_remote_schema_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account()
|
||||
service_mock = MagicMock(return_value={"schema": "ok"})
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-10")
|
||||
openapi_schema = '{"openapi":"3.0.0","info":{"title":"Demo API","version":"1.0.0"},"paths":{}}'
|
||||
service_mock = MagicMock(return_value={"schema": openapi_schema})
|
||||
monkeypatch.setattr(controller_module.ApiToolManageService, "get_api_tool_provider_remote_schema", service_mock)
|
||||
|
||||
with app.test_request_context("/remote?url=https://example.com/"):
|
||||
api = controller_module.ToolApiProviderGetRemoteSchemaApi()
|
||||
resp = unwrap(api.get)(api, "tenant-10", user)
|
||||
resp = controller_module.ToolApiProviderGetRemoteSchemaApi().get()
|
||||
|
||||
assert resp == {"schema": "ok"}
|
||||
assert resp == {"schema": openapi_schema}
|
||||
service_mock.assert_called_once_with(user.id, "tenant-10", "https://example.com/")
|
||||
|
||||
|
||||
def test_api_provider_list_tools_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account()
|
||||
service_mock = MagicMock(return_value=[{"tool": "t"}])
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-11")
|
||||
service_payload, expected_response = _tool_response(controller_module, "t")
|
||||
service_mock = MagicMock(return_value=[service_payload])
|
||||
monkeypatch.setattr(controller_module.ApiToolManageService, "list_api_tool_provider_tools", service_mock)
|
||||
|
||||
with app.test_request_context("/tools?provider=foo"):
|
||||
api = controller_module.ToolApiProviderListToolsApi()
|
||||
resp = unwrap(api.get)(api, "tenant-11", user)
|
||||
resp = controller_module.ToolApiProviderListToolsApi().get()
|
||||
|
||||
assert resp == [{"tool": "t"}]
|
||||
assert resp == [expected_response]
|
||||
service_mock.assert_called_once_with(user.id, "tenant-11", "foo")
|
||||
|
||||
|
||||
def test_api_provider_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account()
|
||||
service_mock = MagicMock(return_value={"provider": "foo"})
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-12")
|
||||
service_payload, expected_response = _api_provider_detail_response(controller_module)
|
||||
service_mock = MagicMock(return_value=service_payload)
|
||||
monkeypatch.setattr(controller_module.ApiToolManageService, "get_api_tool_provider", service_mock)
|
||||
|
||||
with app.test_request_context("/get?provider=foo"):
|
||||
api = controller_module.ToolApiProviderGetApi()
|
||||
resp = unwrap(api.get)(api, "tenant-12", user)
|
||||
resp = controller_module.ToolApiProviderGetApi().get()
|
||||
|
||||
assert resp == {"provider": "foo"}
|
||||
assert resp == expected_response
|
||||
service_mock.assert_called_once_with(user.id, "tenant-12", "foo")
|
||||
|
||||
|
||||
def test_builtin_provider_credentials_schema_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account("user-tenant-13")
|
||||
service_mock = MagicMock(return_value={"schema": True})
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-13")
|
||||
service_payload, expected_response = _provider_config_response(controller_module)
|
||||
service_mock = MagicMock(return_value=[service_payload])
|
||||
monkeypatch.setattr(
|
||||
controller_module.BuiltinToolManageService,
|
||||
"list_builtin_provider_credentials_schema",
|
||||
@ -240,16 +473,19 @@ def test_builtin_provider_credentials_schema_get(app: Flask, controller_module,
|
||||
)
|
||||
|
||||
with app.test_request_context("/schema", method="GET"):
|
||||
api = controller_module.ToolBuiltinProviderCredentialsSchemaApi()
|
||||
resp = unwrap(api.get)(api, "tenant-13", provider="demo", credential_type="api-key")
|
||||
resp = controller_module.ToolBuiltinProviderCredentialsSchemaApi().get(
|
||||
provider="demo", credential_type="api-key"
|
||||
)
|
||||
|
||||
assert resp == {"schema": True}
|
||||
assert resp == [expected_response]
|
||||
service_mock.assert_called_once()
|
||||
|
||||
|
||||
def test_workflow_provider_get_by_tool(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account()
|
||||
tool_service = MagicMock(return_value={"wf": 1})
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-wf")
|
||||
service_payload, expected_response = _workflow_detail_response(controller_module)
|
||||
tool_service = MagicMock(return_value=service_payload)
|
||||
monkeypatch.setattr(
|
||||
controller_module.WorkflowToolManageService,
|
||||
"get_workflow_tool_by_tool_id",
|
||||
@ -258,16 +494,17 @@ def test_workflow_provider_get_by_tool(app: Flask, controller_module, monkeypatc
|
||||
|
||||
tool_id = "00000000-0000-0000-0000-000000000001"
|
||||
with app.test_request_context(f"/workflow?workflow_tool_id={tool_id}"):
|
||||
api = controller_module.ToolWorkflowProviderGetApi()
|
||||
resp = unwrap(api.get)(api, "tenant-wf", user)
|
||||
resp = controller_module.ToolWorkflowProviderGetApi().get()
|
||||
|
||||
assert resp == {"wf": 1}
|
||||
assert resp == expected_response
|
||||
tool_service.assert_called_once_with(user.id, "tenant-wf", tool_id)
|
||||
|
||||
|
||||
def test_workflow_provider_get_by_app(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account()
|
||||
service_mock = MagicMock(return_value={"app": 1})
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-wf2")
|
||||
service_payload, expected_response = _workflow_detail_response(controller_module)
|
||||
service_mock = MagicMock(return_value=service_payload)
|
||||
monkeypatch.setattr(
|
||||
controller_module.WorkflowToolManageService,
|
||||
"get_workflow_tool_by_app_id",
|
||||
@ -276,31 +513,32 @@ def test_workflow_provider_get_by_app(app: Flask, controller_module, monkeypatch
|
||||
|
||||
app_id = "00000000-0000-0000-0000-000000000002"
|
||||
with app.test_request_context(f"/workflow?workflow_app_id={app_id}"):
|
||||
api = controller_module.ToolWorkflowProviderGetApi()
|
||||
resp = unwrap(api.get)(api, "tenant-wf2", user)
|
||||
resp = controller_module.ToolWorkflowProviderGetApi().get()
|
||||
|
||||
assert resp == {"app": 1}
|
||||
assert resp == expected_response
|
||||
service_mock.assert_called_once_with(user.id, "tenant-wf2", app_id)
|
||||
|
||||
|
||||
def test_workflow_provider_list_tools(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account()
|
||||
service_mock = MagicMock(return_value=[{"id": 1}])
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-wf3")
|
||||
service_payload, expected_response = _tool_response(controller_module, "workflow-tool")
|
||||
service_mock = MagicMock(return_value=[service_payload])
|
||||
monkeypatch.setattr(controller_module.WorkflowToolManageService, "list_single_workflow_tools", service_mock)
|
||||
|
||||
tool_id = "00000000-0000-0000-0000-000000000003"
|
||||
with app.test_request_context(f"/workflow/tools?workflow_tool_id={tool_id}"):
|
||||
api = controller_module.ToolWorkflowProviderListToolApi()
|
||||
resp = unwrap(api.get)(api, "tenant-wf3", user)
|
||||
resp = controller_module.ToolWorkflowProviderListToolApi().get()
|
||||
|
||||
assert resp == [{"id": 1}]
|
||||
assert resp == [expected_response]
|
||||
service_mock.assert_called_once_with(user.id, "tenant-wf3", tool_id)
|
||||
|
||||
|
||||
def test_builtin_tools_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account()
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-bt")
|
||||
|
||||
provider = SimpleNamespace(to_dict=lambda: {"name": "builtin"})
|
||||
provider, expected_response = _provider_entity_response(controller_module, "builtin", "builtin")
|
||||
monkeypatch.setattr(
|
||||
controller_module.BuiltinToolManageService,
|
||||
"list_builtin_tools",
|
||||
@ -308,16 +546,16 @@ def test_builtin_tools_list(app: Flask, controller_module, monkeypatch: pytest.M
|
||||
)
|
||||
|
||||
with app.test_request_context("/tools/builtin"):
|
||||
api = controller_module.ToolBuiltinListApi()
|
||||
resp = unwrap(api.get)(api, "tenant-bt", user)
|
||||
resp = controller_module.ToolBuiltinListApi().get()
|
||||
|
||||
assert resp == [{"name": "builtin"}]
|
||||
assert resp == [expected_response]
|
||||
|
||||
|
||||
def test_api_tools_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account("user-tenant-api")
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-api")
|
||||
|
||||
provider = SimpleNamespace(to_dict=lambda: {"name": "api"})
|
||||
provider, expected_response = _provider_entity_response(controller_module, "api", "api")
|
||||
monkeypatch.setattr(
|
||||
controller_module.ApiToolManageService,
|
||||
"list_api_tools",
|
||||
@ -325,16 +563,16 @@ def test_api_tools_list(app: Flask, controller_module, monkeypatch: pytest.Monke
|
||||
)
|
||||
|
||||
with app.test_request_context("/tools/api"):
|
||||
api = controller_module.ToolApiListApi()
|
||||
resp = unwrap(api.get)(api, "tenant-api")
|
||||
resp = controller_module.ToolApiListApi().get()
|
||||
|
||||
assert resp == [{"name": "api"}]
|
||||
assert resp == [expected_response]
|
||||
|
||||
|
||||
def test_workflow_tools_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account()
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-wf4")
|
||||
|
||||
provider = SimpleNamespace(to_dict=lambda: {"name": "wf"})
|
||||
provider, expected_response = _provider_entity_response(controller_module, "wf", "workflow")
|
||||
monkeypatch.setattr(
|
||||
controller_module.WorkflowToolManageService,
|
||||
"list_tenant_workflow_tools",
|
||||
@ -342,20 +580,21 @@ def test_workflow_tools_list(app: Flask, controller_module, monkeypatch: pytest.
|
||||
)
|
||||
|
||||
with app.test_request_context("/tools/workflow"):
|
||||
api = controller_module.ToolWorkflowListApi()
|
||||
resp = unwrap(api.get)(api, "tenant-wf4", user)
|
||||
resp = controller_module.ToolWorkflowListApi().get()
|
||||
|
||||
assert resp == [{"name": "wf"}]
|
||||
assert resp == [expected_response]
|
||||
|
||||
|
||||
def test_tool_labels_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(controller_module.ToolLabelsService, "list_tool_labels", lambda: ["a", "b"])
|
||||
user = _mock_account("user-label")
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-labels")
|
||||
service_payload, expected_response = _tool_label_response(controller_module, "a")
|
||||
monkeypatch.setattr(controller_module.ToolLabelsService, "list_tool_labels", lambda: [service_payload])
|
||||
|
||||
with app.test_request_context("/tool-labels"):
|
||||
api = controller_module.ToolLabelsApi()
|
||||
resp = unwrap(api.get)(api)
|
||||
resp = controller_module.ToolLabelsApi().get()
|
||||
|
||||
assert resp == ["a", "b"]
|
||||
assert resp == [expected_response]
|
||||
|
||||
|
||||
# --- _resolve_identity_mode: gating + None-resolution (PR #36839 review) ---
|
||||
|
||||
@ -1009,8 +1009,7 @@ class TestExternalDatasetApi:
|
||||
# 4. Provide default values when parameters are missing
|
||||
# 5. Raise BadRequest exceptions when validation fails
|
||||
#
|
||||
# Response formatting is handled by Flask-RESTX's marshal_with decorator
|
||||
# or marshal function, which:
|
||||
# Response formatting is handled by controller response schemas, which:
|
||||
#
|
||||
# 1. Formats response data according to defined models
|
||||
# 2. Handles nested objects and lists
|
||||
|
||||
@ -1984,7 +1984,14 @@ export type ModelConfigPartial = {
|
||||
|
||||
export type LlmMode = 'chat' | 'completion'
|
||||
|
||||
export type Type = 'github' | 'marketplace' | 'package'
|
||||
export type Type
|
||||
= | 'app-selector'
|
||||
| 'array[tools]'
|
||||
| 'boolean'
|
||||
| 'model-selector'
|
||||
| 'secret-input'
|
||||
| 'select'
|
||||
| 'text-input'
|
||||
|
||||
export type Github = {
|
||||
github_plugin_unique_identifier: string
|
||||
|
||||
@ -2169,7 +2169,15 @@ export const zConversationMessageDetail = z.object({
|
||||
/**
|
||||
* Type
|
||||
*/
|
||||
export const zType = z.enum(['github', 'marketplace', 'package'])
|
||||
export const zType = z.enum([
|
||||
'app-selector',
|
||||
'array[tools]',
|
||||
'boolean',
|
||||
'model-selector',
|
||||
'secret-input',
|
||||
'select',
|
||||
'text-input',
|
||||
])
|
||||
|
||||
/**
|
||||
* Github
|
||||
|
||||
@ -145,7 +145,7 @@ export const zGetOauthPluginByProviderToolAuthorizationUrlPath = z.object({
|
||||
})
|
||||
|
||||
/**
|
||||
* Authorization URL retrieved successfully
|
||||
* Tool OAuth authorization URL generated successfully
|
||||
*/
|
||||
export const zGetOauthPluginByProviderToolAuthorizationUrlResponse
|
||||
= zPluginOAuthAuthorizationUrlResponse
|
||||
|
||||
@ -515,7 +515,14 @@ export type DatasetWeightedScoreResponse = {
|
||||
weight_type?: string | null
|
||||
}
|
||||
|
||||
export type Type = 'github' | 'marketplace' | 'package'
|
||||
export type Type
|
||||
= | 'app-selector'
|
||||
| 'array[tools]'
|
||||
| 'boolean'
|
||||
| 'model-selector'
|
||||
| 'secret-input'
|
||||
| 'select'
|
||||
| 'text-input'
|
||||
|
||||
export type Github = {
|
||||
github_plugin_unique_identifier: string
|
||||
|
||||
@ -547,7 +547,15 @@ export const zDatasetRerankingModelResponse = z.object({
|
||||
/**
|
||||
* Type
|
||||
*/
|
||||
export const zType = z.enum(['github', 'marketplace', 'package'])
|
||||
export const zType = z.enum([
|
||||
'app-selector',
|
||||
'array[tools]',
|
||||
'boolean',
|
||||
'model-selector',
|
||||
'secret-input',
|
||||
'select',
|
||||
'text-input',
|
||||
])
|
||||
|
||||
/**
|
||||
* Github
|
||||
|
||||
@ -157,14 +157,10 @@ import {
|
||||
zGetWorkspacesCurrentToolProviderBuiltinByProviderCredentialsPath,
|
||||
zGetWorkspacesCurrentToolProviderBuiltinByProviderCredentialsQuery,
|
||||
zGetWorkspacesCurrentToolProviderBuiltinByProviderCredentialsResponse,
|
||||
zGetWorkspacesCurrentToolProviderBuiltinByProviderIconPath,
|
||||
zGetWorkspacesCurrentToolProviderBuiltinByProviderIconResponse,
|
||||
zGetWorkspacesCurrentToolProviderBuiltinByProviderInfoPath,
|
||||
zGetWorkspacesCurrentToolProviderBuiltinByProviderInfoResponse,
|
||||
zGetWorkspacesCurrentToolProviderBuiltinByProviderOauthClientSchemaPath,
|
||||
zGetWorkspacesCurrentToolProviderBuiltinByProviderOauthClientSchemaResponse,
|
||||
zGetWorkspacesCurrentToolProviderBuiltinByProviderOauthCustomClientPath,
|
||||
zGetWorkspacesCurrentToolProviderBuiltinByProviderOauthCustomClientResponse,
|
||||
zGetWorkspacesCurrentToolProviderBuiltinByProviderToolsPath,
|
||||
zGetWorkspacesCurrentToolProviderBuiltinByProviderToolsResponse,
|
||||
zGetWorkspacesCurrentToolProviderMcpToolsByProviderIdPath,
|
||||
@ -181,8 +177,6 @@ import {
|
||||
zGetWorkspacesCurrentToolsBuiltinResponse,
|
||||
zGetWorkspacesCurrentToolsMcpResponse,
|
||||
zGetWorkspacesCurrentToolsWorkflowResponse,
|
||||
zGetWorkspacesCurrentTriggerProviderByProviderIconPath,
|
||||
zGetWorkspacesCurrentTriggerProviderByProviderIconResponse,
|
||||
zGetWorkspacesCurrentTriggerProviderByProviderInfoPath,
|
||||
zGetWorkspacesCurrentTriggerProviderByProviderInfoResponse,
|
||||
zGetWorkspacesCurrentTriggerProviderByProviderOauthClientPath,
|
||||
@ -3212,21 +3206,6 @@ export const delete14 = {
|
||||
}
|
||||
|
||||
export const get67 = oc
|
||||
.route({
|
||||
inputStructure: 'detailed',
|
||||
method: 'GET',
|
||||
operationId: 'getWorkspacesCurrentToolProviderBuiltinByProviderIcon',
|
||||
path: '/workspaces/current/tool-provider/builtin/{provider}/icon',
|
||||
tags: ['console'],
|
||||
})
|
||||
.input(z.object({ params: zGetWorkspacesCurrentToolProviderBuiltinByProviderIconPath }))
|
||||
.output(zGetWorkspacesCurrentToolProviderBuiltinByProviderIconResponse)
|
||||
|
||||
export const icon2 = {
|
||||
get: get67,
|
||||
}
|
||||
|
||||
export const get68 = oc
|
||||
.route({
|
||||
inputStructure: 'detailed',
|
||||
method: 'GET',
|
||||
@ -3238,10 +3217,10 @@ export const get68 = oc
|
||||
.output(zGetWorkspacesCurrentToolProviderBuiltinByProviderInfoResponse)
|
||||
|
||||
export const info2 = {
|
||||
get: get68,
|
||||
get: get67,
|
||||
}
|
||||
|
||||
export const get69 = oc
|
||||
export const get68 = oc
|
||||
.route({
|
||||
inputStructure: 'detailed',
|
||||
method: 'GET',
|
||||
@ -3255,7 +3234,7 @@ export const get69 = oc
|
||||
.output(zGetWorkspacesCurrentToolProviderBuiltinByProviderOauthClientSchemaResponse)
|
||||
|
||||
export const clientSchema = {
|
||||
get: get69,
|
||||
get: get68,
|
||||
}
|
||||
|
||||
export const delete15 = oc
|
||||
@ -3273,19 +3252,6 @@ export const delete15 = oc
|
||||
)
|
||||
.output(zDeleteWorkspacesCurrentToolProviderBuiltinByProviderOauthCustomClientResponse)
|
||||
|
||||
export const get70 = oc
|
||||
.route({
|
||||
inputStructure: 'detailed',
|
||||
method: 'GET',
|
||||
operationId: 'getWorkspacesCurrentToolProviderBuiltinByProviderOauthCustomClient',
|
||||
path: '/workspaces/current/tool-provider/builtin/{provider}/oauth/custom-client',
|
||||
tags: ['console'],
|
||||
})
|
||||
.input(
|
||||
z.object({ params: zGetWorkspacesCurrentToolProviderBuiltinByProviderOauthCustomClientPath }),
|
||||
)
|
||||
.output(zGetWorkspacesCurrentToolProviderBuiltinByProviderOauthCustomClientResponse)
|
||||
|
||||
export const post56 = oc
|
||||
.route({
|
||||
inputStructure: 'detailed',
|
||||
@ -3304,7 +3270,6 @@ export const post56 = oc
|
||||
|
||||
export const customClient = {
|
||||
delete: delete15,
|
||||
get: get70,
|
||||
post: post56,
|
||||
}
|
||||
|
||||
@ -3313,7 +3278,7 @@ export const oauth = {
|
||||
customClient,
|
||||
}
|
||||
|
||||
export const get71 = oc
|
||||
export const get69 = oc
|
||||
.route({
|
||||
inputStructure: 'detailed',
|
||||
method: 'GET',
|
||||
@ -3325,7 +3290,7 @@ export const get71 = oc
|
||||
.output(zGetWorkspacesCurrentToolProviderBuiltinByProviderToolsResponse)
|
||||
|
||||
export const tools2 = {
|
||||
get: get71,
|
||||
get: get69,
|
||||
}
|
||||
|
||||
export const post57 = oc
|
||||
@ -3354,7 +3319,6 @@ export const byProvider2 = {
|
||||
credentials: credentials3,
|
||||
defaultCredential,
|
||||
delete: delete14,
|
||||
icon: icon2,
|
||||
info: info2,
|
||||
oauth,
|
||||
tools: tools2,
|
||||
@ -3380,7 +3344,7 @@ export const auth = {
|
||||
post: post58,
|
||||
}
|
||||
|
||||
export const get72 = oc
|
||||
export const get70 = oc
|
||||
.route({
|
||||
inputStructure: 'detailed',
|
||||
method: 'GET',
|
||||
@ -3392,14 +3356,14 @@ export const get72 = oc
|
||||
.output(zGetWorkspacesCurrentToolProviderMcpToolsByProviderIdResponse)
|
||||
|
||||
export const byProviderId = {
|
||||
get: get72,
|
||||
get: get70,
|
||||
}
|
||||
|
||||
export const tools3 = {
|
||||
byProviderId,
|
||||
}
|
||||
|
||||
export const get73 = oc
|
||||
export const get71 = oc
|
||||
.route({
|
||||
inputStructure: 'detailed',
|
||||
method: 'GET',
|
||||
@ -3411,7 +3375,7 @@ export const get73 = oc
|
||||
.output(zGetWorkspacesCurrentToolProviderMcpUpdateByProviderIdResponse)
|
||||
|
||||
export const byProviderId2 = {
|
||||
get: get73,
|
||||
get: get71,
|
||||
}
|
||||
|
||||
export const update4 = {
|
||||
@ -3490,7 +3454,7 @@ export const delete17 = {
|
||||
post: post61,
|
||||
}
|
||||
|
||||
export const get74 = oc
|
||||
export const get72 = oc
|
||||
.route({
|
||||
inputStructure: 'detailed',
|
||||
method: 'GET',
|
||||
@ -3501,11 +3465,11 @@ export const get74 = oc
|
||||
.input(z.object({ query: zGetWorkspacesCurrentToolProviderWorkflowGetQuery.optional() }))
|
||||
.output(zGetWorkspacesCurrentToolProviderWorkflowGetResponse)
|
||||
|
||||
export const get75 = {
|
||||
get: get74,
|
||||
export const get73 = {
|
||||
get: get72,
|
||||
}
|
||||
|
||||
export const get76 = oc
|
||||
export const get74 = oc
|
||||
.route({
|
||||
inputStructure: 'detailed',
|
||||
method: 'GET',
|
||||
@ -3517,7 +3481,7 @@ export const get76 = oc
|
||||
.output(zGetWorkspacesCurrentToolProviderWorkflowToolsResponse)
|
||||
|
||||
export const tools4 = {
|
||||
get: get76,
|
||||
get: get74,
|
||||
}
|
||||
|
||||
export const post62 = oc
|
||||
@ -3538,7 +3502,7 @@ export const update5 = {
|
||||
export const workflow = {
|
||||
create: create2,
|
||||
delete: delete17,
|
||||
get: get75,
|
||||
get: get73,
|
||||
tools: tools4,
|
||||
update: update5,
|
||||
}
|
||||
@ -3550,7 +3514,7 @@ export const toolProvider = {
|
||||
workflow,
|
||||
}
|
||||
|
||||
export const get77 = oc
|
||||
export const get75 = oc
|
||||
.route({
|
||||
inputStructure: 'detailed',
|
||||
method: 'GET',
|
||||
@ -3562,10 +3526,10 @@ export const get77 = oc
|
||||
.output(zGetWorkspacesCurrentToolProvidersResponse)
|
||||
|
||||
export const toolProviders = {
|
||||
get: get77,
|
||||
get: get75,
|
||||
}
|
||||
|
||||
export const get78 = oc
|
||||
export const get76 = oc
|
||||
.route({
|
||||
inputStructure: 'detailed',
|
||||
method: 'GET',
|
||||
@ -3576,10 +3540,10 @@ export const get78 = oc
|
||||
.output(zGetWorkspacesCurrentToolsApiResponse)
|
||||
|
||||
export const api2 = {
|
||||
get: get78,
|
||||
get: get76,
|
||||
}
|
||||
|
||||
export const get79 = oc
|
||||
export const get77 = oc
|
||||
.route({
|
||||
inputStructure: 'detailed',
|
||||
method: 'GET',
|
||||
@ -3590,10 +3554,10 @@ export const get79 = oc
|
||||
.output(zGetWorkspacesCurrentToolsBuiltinResponse)
|
||||
|
||||
export const builtin2 = {
|
||||
get: get79,
|
||||
get: get77,
|
||||
}
|
||||
|
||||
export const get80 = oc
|
||||
export const get78 = oc
|
||||
.route({
|
||||
inputStructure: 'detailed',
|
||||
method: 'GET',
|
||||
@ -3604,10 +3568,10 @@ export const get80 = oc
|
||||
.output(zGetWorkspacesCurrentToolsMcpResponse)
|
||||
|
||||
export const mcp2 = {
|
||||
get: get80,
|
||||
get: get78,
|
||||
}
|
||||
|
||||
export const get81 = oc
|
||||
export const get79 = oc
|
||||
.route({
|
||||
inputStructure: 'detailed',
|
||||
method: 'GET',
|
||||
@ -3618,7 +3582,7 @@ export const get81 = oc
|
||||
.output(zGetWorkspacesCurrentToolsWorkflowResponse)
|
||||
|
||||
export const workflow2 = {
|
||||
get: get81,
|
||||
get: get79,
|
||||
}
|
||||
|
||||
export const tools5 = {
|
||||
@ -3628,25 +3592,10 @@ export const tools5 = {
|
||||
workflow: workflow2,
|
||||
}
|
||||
|
||||
export const get82 = oc
|
||||
.route({
|
||||
inputStructure: 'detailed',
|
||||
method: 'GET',
|
||||
operationId: 'getWorkspacesCurrentTriggerProviderByProviderIcon',
|
||||
path: '/workspaces/current/trigger-provider/{provider}/icon',
|
||||
tags: ['console'],
|
||||
})
|
||||
.input(z.object({ params: zGetWorkspacesCurrentTriggerProviderByProviderIconPath }))
|
||||
.output(zGetWorkspacesCurrentTriggerProviderByProviderIconResponse)
|
||||
|
||||
export const icon3 = {
|
||||
get: get82,
|
||||
}
|
||||
|
||||
/**
|
||||
* Get info for a trigger provider
|
||||
*/
|
||||
export const get83 = oc
|
||||
export const get80 = oc
|
||||
.route({
|
||||
inputStructure: 'detailed',
|
||||
method: 'GET',
|
||||
@ -3659,7 +3608,7 @@ export const get83 = oc
|
||||
.output(zGetWorkspacesCurrentTriggerProviderByProviderInfoResponse)
|
||||
|
||||
export const info3 = {
|
||||
get: get83,
|
||||
get: get80,
|
||||
}
|
||||
|
||||
/**
|
||||
@ -3680,7 +3629,7 @@ export const delete18 = oc
|
||||
/**
|
||||
* Get OAuth client configuration for a provider
|
||||
*/
|
||||
export const get84 = oc
|
||||
export const get81 = oc
|
||||
.route({
|
||||
inputStructure: 'detailed',
|
||||
method: 'GET',
|
||||
@ -3714,7 +3663,7 @@ export const post63 = oc
|
||||
|
||||
export const client = {
|
||||
delete: delete18,
|
||||
get: get84,
|
||||
get: get81,
|
||||
post: post63,
|
||||
}
|
||||
|
||||
@ -3781,7 +3730,7 @@ export const create3 = {
|
||||
/**
|
||||
* Get the request logs for a subscription instance for a trigger provider
|
||||
*/
|
||||
export const get85 = oc
|
||||
export const get82 = oc
|
||||
.route({
|
||||
inputStructure: 'detailed',
|
||||
method: 'GET',
|
||||
@ -3802,7 +3751,7 @@ export const get85 = oc
|
||||
)
|
||||
|
||||
export const bySubscriptionBuilderId2 = {
|
||||
get: get85,
|
||||
get: get82,
|
||||
}
|
||||
|
||||
export const logs = {
|
||||
@ -3876,7 +3825,7 @@ export const verifyAndUpdate = {
|
||||
/**
|
||||
* Get a subscription instance for a trigger provider
|
||||
*/
|
||||
export const get86 = oc
|
||||
export const get83 = oc
|
||||
.route({
|
||||
inputStructure: 'detailed',
|
||||
method: 'GET',
|
||||
@ -3897,7 +3846,7 @@ export const get86 = oc
|
||||
)
|
||||
|
||||
export const bySubscriptionBuilderId5 = {
|
||||
get: get86,
|
||||
get: get83,
|
||||
}
|
||||
|
||||
export const builder = {
|
||||
@ -3912,7 +3861,7 @@ export const builder = {
|
||||
/**
|
||||
* List all trigger subscriptions for the current tenant's provider
|
||||
*/
|
||||
export const get87 = oc
|
||||
export const get84 = oc
|
||||
.route({
|
||||
inputStructure: 'detailed',
|
||||
method: 'GET',
|
||||
@ -3925,13 +3874,13 @@ export const get87 = oc
|
||||
.output(zGetWorkspacesCurrentTriggerProviderByProviderSubscriptionsListResponse)
|
||||
|
||||
export const list4 = {
|
||||
get: get87,
|
||||
get: get84,
|
||||
}
|
||||
|
||||
/**
|
||||
* Initiate OAuth authorization flow for a trigger provider
|
||||
*/
|
||||
export const get88 = oc
|
||||
export const get85 = oc
|
||||
.route({
|
||||
inputStructure: 'detailed',
|
||||
method: 'GET',
|
||||
@ -3948,7 +3897,7 @@ export const get88 = oc
|
||||
.output(zGetWorkspacesCurrentTriggerProviderByProviderSubscriptionsOauthAuthorizeResponse)
|
||||
|
||||
export const authorize = {
|
||||
get: get88,
|
||||
get: get85,
|
||||
}
|
||||
|
||||
export const oauth3 = {
|
||||
@ -3995,7 +3944,6 @@ export const subscriptions = {
|
||||
}
|
||||
|
||||
export const byProvider3 = {
|
||||
icon: icon3,
|
||||
info: info3,
|
||||
oauth: oauth2,
|
||||
subscriptions,
|
||||
@ -4065,7 +4013,7 @@ export const triggerProvider = {
|
||||
/**
|
||||
* List all trigger providers for the current tenant
|
||||
*/
|
||||
export const get89 = oc
|
||||
export const get86 = oc
|
||||
.route({
|
||||
inputStructure: 'detailed',
|
||||
method: 'GET',
|
||||
@ -4077,7 +4025,7 @@ export const get89 = oc
|
||||
.output(zGetWorkspacesCurrentTriggersResponse)
|
||||
|
||||
export const triggers = {
|
||||
get: get89,
|
||||
get: get86,
|
||||
}
|
||||
|
||||
export const post71 = oc
|
||||
@ -4177,7 +4125,7 @@ export const switch3 = {
|
||||
post: post75,
|
||||
}
|
||||
|
||||
export const get90 = oc
|
||||
export const get87 = oc
|
||||
.route({
|
||||
inputStructure: 'detailed',
|
||||
method: 'GET',
|
||||
@ -4189,7 +4137,7 @@ export const get90 = oc
|
||||
.output(zGetWorkspacesByTenantIdModelProvidersByProviderByIconTypeByLangResponse)
|
||||
|
||||
export const byLang = {
|
||||
get: get90,
|
||||
get: get87,
|
||||
}
|
||||
|
||||
export const byIconType = {
|
||||
@ -4208,7 +4156,7 @@ export const byTenantId = {
|
||||
modelProviders: modelProviders2,
|
||||
}
|
||||
|
||||
export const get91 = oc
|
||||
export const get88 = oc
|
||||
.route({
|
||||
inputStructure: 'detailed',
|
||||
method: 'GET',
|
||||
@ -4219,7 +4167,7 @@ export const get91 = oc
|
||||
.output(zGetWorkspacesResponse)
|
||||
|
||||
export const workspaces = {
|
||||
get: get91,
|
||||
get: get88,
|
||||
current,
|
||||
customConfig,
|
||||
info: info4,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -10,13 +10,21 @@ type SwaggerSchema = JsonObject & {
|
||||
$ref?: string
|
||||
}
|
||||
|
||||
type OpenApiMediaType = JsonObject & {
|
||||
schema?: SwaggerSchema
|
||||
}
|
||||
|
||||
type OpenApiResponse = JsonObject & {
|
||||
content?: Record<string, OpenApiMediaType>
|
||||
}
|
||||
|
||||
type OpenApiComponents = JsonObject & {
|
||||
schemas?: Record<string, SwaggerSchema>
|
||||
}
|
||||
|
||||
type SwaggerOperation = JsonObject & {
|
||||
operationId?: string
|
||||
responses?: Record<string, unknown>
|
||||
responses?: Record<string, OpenApiResponse>
|
||||
}
|
||||
|
||||
type SwaggerDocument = JsonObject & {
|
||||
@ -52,6 +60,17 @@ const currentDir = path.dirname(fileURLToPath(import.meta.url))
|
||||
const apiOpenApiDir = path.resolve(currentDir, 'openapi')
|
||||
|
||||
const operationMethods = new Set(['delete', 'get', 'patch', 'post', 'put'])
|
||||
const pydanticDecimalStringPattern = '^(?!^[-+.]*$)[+-]?0*\\d*\\.?\\d*$'
|
||||
const codegenSafeDecimalStringPattern = '^(?![-+.]*$)[+-]?0*\\d*\\.?\\d*$'
|
||||
|
||||
const opaqueJsonContent = (): Record<string, OpenApiMediaType> => ({
|
||||
'application/json': {
|
||||
schema: {
|
||||
additionalProperties: true,
|
||||
type: 'object',
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
const apiSpecs: ApiSpec[] = [
|
||||
{ filename: 'console-openapi.json', name: 'console' },
|
||||
@ -182,6 +201,46 @@ const addOperationIds = (document: SwaggerDocument) => {
|
||||
}
|
||||
}
|
||||
|
||||
const isOpaqueContractResponse = (response: OpenApiResponse) => {
|
||||
const content = response.content
|
||||
if (!isObject(content))
|
||||
return false
|
||||
|
||||
return Object.entries(content).some(([mediaType, media]) => {
|
||||
if (!isObject(media))
|
||||
return false
|
||||
|
||||
return (mediaType === 'application/json' || mediaType === 'text/event-stream') && !('schema' in media)
|
||||
})
|
||||
}
|
||||
|
||||
const hasOpaqueContractSuccessResponse = (operation: SwaggerOperation) => {
|
||||
return Object.entries(operation.responses ?? {}).some(([status, response]) => {
|
||||
return /^2\d\d$/.test(status) && isObject(response) && isOpaqueContractResponse(response)
|
||||
})
|
||||
}
|
||||
|
||||
const normalizeOpaqueContractResponses = (document: SwaggerDocument) => {
|
||||
// Some backend endpoints has no schema (e.g. external) and will trap heyapi here
|
||||
// So we forge an opaque schema here
|
||||
for (const pathItem of Object.values(document.paths ?? {})) {
|
||||
for (const [method, operation] of Object.entries(pathItem)) {
|
||||
if (!operationMethods.has(method) || !isObject(operation))
|
||||
continue
|
||||
|
||||
const swaggerOperation = operation as SwaggerOperation
|
||||
if (!hasOpaqueContractSuccessResponse(swaggerOperation))
|
||||
continue
|
||||
|
||||
Object.values(swaggerOperation.responses ?? {})
|
||||
.filter(response => isObject(response) && isOpaqueContractResponse(response))
|
||||
.forEach((response) => {
|
||||
response.content = opaqueJsonContent()
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const hasSuccessResponse = (operation: SwaggerOperation) => {
|
||||
return Object.entries(operation.responses ?? {}).some(([status, response]) => {
|
||||
if (!/^2\d\d$/.test(status))
|
||||
@ -215,6 +274,7 @@ const filterContractOperations = (document: SwaggerDocument) => {
|
||||
}
|
||||
|
||||
const normalizeApiSwagger = (document: SwaggerDocument) => {
|
||||
normalizeOpaqueContractResponses(document)
|
||||
filterContractOperations(document)
|
||||
addOperationIds(document)
|
||||
|
||||
@ -380,10 +440,20 @@ const createApiConfig = (job: ApiJob): UserConfig => ({
|
||||
'name': 'zod',
|
||||
'~resolvers': {
|
||||
string: (ctx) => {
|
||||
if (ctx.schema.format !== 'binary')
|
||||
return undefined
|
||||
if (ctx.schema.format === 'binary')
|
||||
return $(ctx.symbols.z).attr('custom').call().generic($.type.or($.type('Blob'), $.type('File')))
|
||||
|
||||
return $(ctx.symbols.z).attr('custom').call().generic($.type.or($.type('Blob'), $.type('File')))
|
||||
if (ctx.schema.pattern === pydanticDecimalStringPattern) {
|
||||
// the pydantic generated regex will emit error like
|
||||
// regexp/no-useless-assertions, so patch the regex here
|
||||
return $(ctx.symbols.z)
|
||||
.attr('string')
|
||||
.call()
|
||||
.attr('regex')
|
||||
.call($.regexp(codegenSafeDecimalStringPattern))
|
||||
}
|
||||
|
||||
return undefined
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
Loading…
Reference in New Issue
Block a user