feat(trigger): enhance subscription builder management and update API

- Introduced `SubscriptionBuilderUpdater` class to streamline updates to subscription builders, encapsulating properties like name, parameters, and credentials.
- Refactored API endpoints to utilize the new updater class, improving code clarity and maintainability.
- Adjusted OAuth handling to create and update subscription builders more effectively, ensuring proper credential management.

This change enhances the overall functionality and organization of the trigger subscription builder API.
This commit is contained in:
Harry 2025-09-08 15:09:47 +08:00
parent a799b54b9e
commit eb95c5cd07
4 changed files with 109 additions and 65 deletions

View File

@ -12,6 +12,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin import TriggerProviderID
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.oauth import OAuthHandler
from core.trigger.entities.entities import SubscriptionBuilderUpdater
from extensions.ext_database import db
from libs.login import current_user, login_required
from models.account import Account
@ -71,22 +72,16 @@ class TriggerSubscriptionBuilderCreateApi(Resource):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
parser.add_argument("credentials", type=dict, required=False, nullable=False, location="json")
parser.add_argument("credential_type", type=str, required=False, nullable=True, location="json")
args = parser.parse_args()
try:
credentials = args.get("credentials", {})
credential_type = CredentialType.API_KEY if credentials else CredentialType.UNAUTHORIZED
credential_type = CredentialType.of(args.get("credential_type") or CredentialType.UNAUTHORIZED.value)
subscription_builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder(
tenant_id=user.current_tenant_id,
user_id=user.id,
provider_id=TriggerProviderID(provider),
name=args.get("name", None),
credentials=credentials,
credential_type=credential_type,
credential_expires_at=-1,
expires_at=-1,
)
return jsonable_encoder({"subscription_builder": subscription_builder})
except ValueError as e:
@ -108,7 +103,20 @@ class TriggerSubscriptionBuilderVerifyApi(Resource):
if not user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
# The credentials of the subscription builder
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
args = parser.parse_args()
try:
TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
tenant_id=user.current_tenant_id,
provider_id=TriggerProviderID(provider),
subscription_builder_id=subscription_builder_id,
subscription_builder_updater=SubscriptionBuilderUpdater(
credentials=args.get("credentials", None),
),
)
TriggerSubscriptionBuilderService.verify_trigger_subscription_builder(
tenant_id=user.current_tenant_id,
user_id=user.id,
@ -147,10 +155,12 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
tenant_id=user.current_tenant_id,
provider_id=TriggerProviderID(provider),
subscription_builder_id=subscription_builder_id,
name=args.get("name", None),
parameters=args.get("parameters", None),
properties=args.get("properties", None),
credentials=args.get("credentials", None),
subscription_builder_updater=SubscriptionBuilderUpdater(
name=args.get("name", None),
parameters=args.get("parameters", None),
properties=args.get("properties", None),
credentials=args.get("credentials", None),
),
)
)
except Exception as e:
@ -188,7 +198,27 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
if not user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
# The name of the subscription builder
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
# The parameters of the subscription builder
parser.add_argument("parameters", type=dict, required=False, nullable=True, location="json")
# The properties of the subscription builder
parser.add_argument("properties", type=dict, required=False, nullable=True, location="json")
# The credentials of the subscription builder
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
args = parser.parse_args()
try:
TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
tenant_id=user.current_tenant_id,
provider_id=TriggerProviderID(provider),
subscription_builder_id=subscription_builder_id,
subscription_builder_updater=SubscriptionBuilderUpdater(
name=args.get("name", None),
parameters=args.get("parameters", None),
properties=args.get("properties", None),
),
)
TriggerSubscriptionBuilderService.build_trigger_subscription_builder(
tenant_id=user.current_tenant_id,
user_id=user.id,
@ -263,6 +293,14 @@ class TriggerOAuthAuthorizeApi(Resource):
if oauth_client_params is None:
raise Forbidden("No OAuth client configuration found for this trigger provider")
# Create subscription builder
subscription_builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder(
tenant_id=tenant_id,
user_id=user.id,
provider_id=provider_id,
credential_type=CredentialType.OAUTH2,
)
# Create OAuth handler and proxy context
oauth_handler = OAuthHandler()
context_id = OAuthProxyService.create_proxy_context(
@ -270,6 +308,9 @@ class TriggerOAuthAuthorizeApi(Resource):
tenant_id=tenant_id,
plugin_id=plugin_id,
provider=provider_name,
extra_data={
"subscription_builder_id": subscription_builder.id,
},
)
# Build redirect URI for callback
@ -284,24 +325,13 @@ class TriggerOAuthAuthorizeApi(Resource):
redirect_uri=redirect_uri,
system_credentials=oauth_client_params,
)
# Create subscription builder
subscription_builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder(
tenant_id=tenant_id,
user_id=user.id,
provider_id=provider_id,
credentials={},
credential_type=CredentialType.OAUTH2,
credential_expires_at=0,
expires_at=0,
name=f"{provider_name} OAuth Authentication",
)
# Create response with cookie
response = make_response(
jsonable_encoder(
{
"authorization_url": authorization_url_response.authorization_url,
"subscription_builder": subscription_builder,
"subscription_builder_id": subscription_builder.id,
}
)
)
@ -339,6 +369,7 @@ class TriggerOAuthCallbackApi(Resource):
provider_name = provider_id.provider_name
user_id = context.get("user_id")
tenant_id = context.get("tenant_id")
subscription_builder_id = context.get("subscription_builder_id")
# Get OAuth client configuration
oauth_client_params = TriggerProviderService.get_oauth_client(
@ -369,19 +400,18 @@ class TriggerOAuthCallbackApi(Resource):
if not credentials:
raise Exception("Failed to get OAuth credentials")
# Save OAuth credentials to database
subscription_builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder(
# Update subscription builder
TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
tenant_id=tenant_id,
user_id=user_id,
provider_id=provider_id,
credentials=credentials,
credential_type=CredentialType.OAUTH2,
credential_expires_at=expires_at,
expires_at=expires_at,
name=f"{provider_name} OAuth Authentication",
subscription_builder_id=subscription_builder_id,
subscription_builder_updater=SubscriptionBuilderUpdater(
credentials=credentials,
credential_expires_at=expires_at,
),
)
# Redirect to OAuth callback page
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback?subscription_id={subscription_builder.id}")
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
class TriggerOAuthClientManageApi(Resource):

View File

@ -217,6 +217,35 @@ class SubscriptionBuilder(BaseModel):
)
class SubscriptionBuilderUpdater(BaseModel):
name: str | None = Field(default=None, description="The name of the subscription builder")
parameters: Mapping[str, Any] | None = Field(default=None, description="The parameters of the subscription builder")
properties: Mapping[str, Any] | None = Field(default=None, description="The properties of the subscription builder")
credentials: Mapping[str, str] | None = Field(
default=None, description="The credentials of the subscription builder"
)
credential_type: str | None = Field(default=None, description="The credential type of the subscription builder")
credential_expires_at: int | None = Field(
default=None, description="The credential expires at of the subscription builder"
)
expires_at: int | None = Field(default=None, description="The expires at of the subscription builder")
def update(self, subscription_builder: SubscriptionBuilder) -> None:
if self.name:
subscription_builder.name = self.name
if self.parameters:
subscription_builder.parameters = self.parameters
if self.properties:
subscription_builder.properties = self.properties
if self.credentials:
subscription_builder.credentials = self.credentials
if self.credential_type:
subscription_builder.credential_type = self.credential_type
if self.credential_expires_at:
subscription_builder.credential_expires_at = self.credential_expires_at
if self.expires_at:
subscription_builder.expires_at = self.expires_at
# Export all entities
__all__ = [
"OAuthSchema",

View File

@ -11,7 +11,7 @@ class OAuthProxyService(BasePluginClient):
__KEY_PREFIX__ = "oauth_proxy_context:"
@staticmethod
def create_proxy_context(user_id: str, tenant_id: str, plugin_id: str, provider: str):
def create_proxy_context(user_id: str, tenant_id: str, plugin_id: str, provider: str, extra_data: dict = {}):
"""
Create a proxy context for an OAuth 2.0 authorization request.
@ -26,6 +26,7 @@ class OAuthProxyService(BasePluginClient):
"""
context_id = str(uuid.uuid4())
data = {
**extra_data,
"user_id": user_id,
"plugin_id": plugin_id,
"tenant_id": tenant_id,

View File

@ -14,6 +14,7 @@ from core.trigger.entities.api_entities import SubscriptionBuilderApiEntity
from core.trigger.entities.entities import (
RequestLog,
SubscriptionBuilder,
SubscriptionBuilderUpdater,
)
from core.trigger.provider import PluginTriggerProviderController
from core.trigger.trigger_manager import TriggerManager
@ -64,21 +65,18 @@ class TriggerSubscriptionBuilderService:
return {"verified": bool(subscription_builder.credentials)}
if subscription_builder.credential_type == CredentialType.API_KEY:
credentials_to_validate = subscription_builder.credentials
try:
provider_controller.validate_credentials(user_id, subscription_builder.credentials)
return {"verified": True}
provider_controller.validate_credentials(user_id, credentials_to_validate)
except ToolProviderCredentialValidationError as e:
raise ValueError(f"Invalid credentials: {e}")
return {"verified": True}
return {"verified": True}
@classmethod
def build_trigger_subscription_builder(
cls,
tenant_id: str,
user_id: str,
provider_id: TriggerProviderID,
subscription_builder_id: str,
cls, tenant_id: str, user_id: str, provider_id: TriggerProviderID, subscription_builder_id: str
) -> None:
"""Build a trigger subscription builder"""
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
@ -143,11 +141,7 @@ class TriggerSubscriptionBuilderService:
tenant_id: str,
user_id: str,
provider_id: TriggerProviderID,
credentials: Mapping[str, str],
credential_type: CredentialType,
credential_expires_at: int,
expires_at: int,
name: str | None,
) -> SubscriptionBuilder:
"""
Add a new trigger subscription validation.
@ -160,17 +154,17 @@ class TriggerSubscriptionBuilderService:
subscription_id = str(uuid.uuid4())
subscription_builder = SubscriptionBuilder(
id=subscription_id,
name=name or "",
name=None,
endpoint_id=subscription_id,
tenant_id=tenant_id,
user_id=user_id,
provider_id=str(provider_id),
parameters=subscription_schema.get_default_parameters(),
properties=subscription_schema.get_default_properties(),
credentials=credentials,
credentials={},
credential_type=credential_type,
credential_expires_at=credential_expires_at,
expires_at=expires_at,
credential_expires_at=-1,
expires_at=-1,
)
cache_key = cls.encode_cache_key(subscription_id)
redis_client.setex(
@ -184,10 +178,7 @@ class TriggerSubscriptionBuilderService:
tenant_id: str,
provider_id: TriggerProviderID,
subscription_builder_id: str,
name: str | None,
parameters: Mapping[str, Any] | None,
properties: Mapping[str, Any] | None,
credentials: Mapping[str, str] | None,
subscription_builder_updater: SubscriptionBuilderUpdater,
) -> SubscriptionBuilderApiEntity:
"""
Update a trigger subscription validation.
@ -198,23 +189,16 @@ class TriggerSubscriptionBuilderService:
raise ValueError(f"Provider {provider_id} not found")
cache_key = cls.encode_cache_key(subscription_id)
subscription_builder = cls.get_subscription_builder(subscription_id)
if not subscription_builder or subscription_builder.tenant_id != tenant_id:
subscription_builder_cache = cls.get_subscription_builder(subscription_builder_id)
if not subscription_builder_cache or subscription_builder_cache.tenant_id != tenant_id:
raise ValueError(f"Subscription {subscription_id} expired or not found")
if name:
subscription_builder.name = name
if parameters:
subscription_builder.parameters = parameters
if properties:
subscription_builder.properties = properties
if credentials:
subscription_builder.credentials = credentials
subscription_builder_updater.update(subscription_builder_cache)
redis_client.setex(
cache_key, cls.__VALIDATION_REQUEST_CACHE_EXPIRE_MS__, subscription_builder.model_dump_json()
cache_key, cls.__VALIDATION_REQUEST_CACHE_EXPIRE_MS__, subscription_builder_cache.model_dump_json()
)
return cls.builder_to_api_entity(controller=provider_controller, entity=subscription_builder)
return cls.builder_to_api_entity(controller=provider_controller, entity=subscription_builder_cache)
@classmethod
def builder_to_api_entity(