mirror of https://github.com/langgenius/dify.git
feat(trigger): enhance subscription builder management and update API
- Introduced `SubscriptionBuilderUpdater` class to streamline updates to subscription builders, encapsulating properties like name, parameters, and credentials. - Refactored API endpoints to utilize the new updater class, improving code clarity and maintainability. - Adjusted OAuth handling to create and update subscription builders more effectively, ensuring proper credential management. This change enhances the overall functionality and organization of the trigger subscription builder API.
This commit is contained in:
parent
a799b54b9e
commit
eb95c5cd07
|
|
@ -12,6 +12,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
|
|||
from core.plugin.entities.plugin import TriggerProviderID
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from core.trigger.entities.entities import SubscriptionBuilderUpdater
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account
|
||||
|
|
@ -71,22 +72,16 @@ class TriggerSubscriptionBuilderCreateApi(Resource):
|
|||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument("credentials", type=dict, required=False, nullable=False, location="json")
|
||||
parser.add_argument("credential_type", type=str, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
credentials = args.get("credentials", {})
|
||||
credential_type = CredentialType.API_KEY if credentials else CredentialType.UNAUTHORIZED
|
||||
credential_type = CredentialType.of(args.get("credential_type") or CredentialType.UNAUTHORIZED.value)
|
||||
subscription_builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder(
|
||||
tenant_id=user.current_tenant_id,
|
||||
user_id=user.id,
|
||||
provider_id=TriggerProviderID(provider),
|
||||
name=args.get("name", None),
|
||||
credentials=credentials,
|
||||
credential_type=credential_type,
|
||||
credential_expires_at=-1,
|
||||
expires_at=-1,
|
||||
)
|
||||
return jsonable_encoder({"subscription_builder": subscription_builder})
|
||||
except ValueError as e:
|
||||
|
|
@ -108,7 +103,20 @@ class TriggerSubscriptionBuilderVerifyApi(Resource):
|
|||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
# The credentials of the subscription builder
|
||||
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
|
||||
tenant_id=user.current_tenant_id,
|
||||
provider_id=TriggerProviderID(provider),
|
||||
subscription_builder_id=subscription_builder_id,
|
||||
subscription_builder_updater=SubscriptionBuilderUpdater(
|
||||
credentials=args.get("credentials", None),
|
||||
),
|
||||
)
|
||||
TriggerSubscriptionBuilderService.verify_trigger_subscription_builder(
|
||||
tenant_id=user.current_tenant_id,
|
||||
user_id=user.id,
|
||||
|
|
@ -147,10 +155,12 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
|
|||
tenant_id=user.current_tenant_id,
|
||||
provider_id=TriggerProviderID(provider),
|
||||
subscription_builder_id=subscription_builder_id,
|
||||
name=args.get("name", None),
|
||||
parameters=args.get("parameters", None),
|
||||
properties=args.get("properties", None),
|
||||
credentials=args.get("credentials", None),
|
||||
subscription_builder_updater=SubscriptionBuilderUpdater(
|
||||
name=args.get("name", None),
|
||||
parameters=args.get("parameters", None),
|
||||
properties=args.get("properties", None),
|
||||
credentials=args.get("credentials", None),
|
||||
),
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
|
|
@ -188,7 +198,27 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
|
|||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
# The name of the subscription builder
|
||||
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
|
||||
# The parameters of the subscription builder
|
||||
parser.add_argument("parameters", type=dict, required=False, nullable=True, location="json")
|
||||
# The properties of the subscription builder
|
||||
parser.add_argument("properties", type=dict, required=False, nullable=True, location="json")
|
||||
# The credentials of the subscription builder
|
||||
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
try:
|
||||
TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
|
||||
tenant_id=user.current_tenant_id,
|
||||
provider_id=TriggerProviderID(provider),
|
||||
subscription_builder_id=subscription_builder_id,
|
||||
subscription_builder_updater=SubscriptionBuilderUpdater(
|
||||
name=args.get("name", None),
|
||||
parameters=args.get("parameters", None),
|
||||
properties=args.get("properties", None),
|
||||
),
|
||||
)
|
||||
TriggerSubscriptionBuilderService.build_trigger_subscription_builder(
|
||||
tenant_id=user.current_tenant_id,
|
||||
user_id=user.id,
|
||||
|
|
@ -263,6 +293,14 @@ class TriggerOAuthAuthorizeApi(Resource):
|
|||
if oauth_client_params is None:
|
||||
raise Forbidden("No OAuth client configuration found for this trigger provider")
|
||||
|
||||
# Create subscription builder
|
||||
subscription_builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
provider_id=provider_id,
|
||||
credential_type=CredentialType.OAUTH2,
|
||||
)
|
||||
|
||||
# Create OAuth handler and proxy context
|
||||
oauth_handler = OAuthHandler()
|
||||
context_id = OAuthProxyService.create_proxy_context(
|
||||
|
|
@ -270,6 +308,9 @@ class TriggerOAuthAuthorizeApi(Resource):
|
|||
tenant_id=tenant_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
extra_data={
|
||||
"subscription_builder_id": subscription_builder.id,
|
||||
},
|
||||
)
|
||||
|
||||
# Build redirect URI for callback
|
||||
|
|
@ -284,24 +325,13 @@ class TriggerOAuthAuthorizeApi(Resource):
|
|||
redirect_uri=redirect_uri,
|
||||
system_credentials=oauth_client_params,
|
||||
)
|
||||
# Create subscription builder
|
||||
subscription_builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
provider_id=provider_id,
|
||||
credentials={},
|
||||
credential_type=CredentialType.OAUTH2,
|
||||
credential_expires_at=0,
|
||||
expires_at=0,
|
||||
name=f"{provider_name} OAuth Authentication",
|
||||
)
|
||||
|
||||
# Create response with cookie
|
||||
response = make_response(
|
||||
jsonable_encoder(
|
||||
{
|
||||
"authorization_url": authorization_url_response.authorization_url,
|
||||
"subscription_builder": subscription_builder,
|
||||
"subscription_builder_id": subscription_builder.id,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
|
@ -339,6 +369,7 @@ class TriggerOAuthCallbackApi(Resource):
|
|||
provider_name = provider_id.provider_name
|
||||
user_id = context.get("user_id")
|
||||
tenant_id = context.get("tenant_id")
|
||||
subscription_builder_id = context.get("subscription_builder_id")
|
||||
|
||||
# Get OAuth client configuration
|
||||
oauth_client_params = TriggerProviderService.get_oauth_client(
|
||||
|
|
@ -369,19 +400,18 @@ class TriggerOAuthCallbackApi(Resource):
|
|||
if not credentials:
|
||||
raise Exception("Failed to get OAuth credentials")
|
||||
|
||||
# Save OAuth credentials to database
|
||||
subscription_builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder(
|
||||
# Update subscription builder
|
||||
TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
provider_id=provider_id,
|
||||
credentials=credentials,
|
||||
credential_type=CredentialType.OAUTH2,
|
||||
credential_expires_at=expires_at,
|
||||
expires_at=expires_at,
|
||||
name=f"{provider_name} OAuth Authentication",
|
||||
subscription_builder_id=subscription_builder_id,
|
||||
subscription_builder_updater=SubscriptionBuilderUpdater(
|
||||
credentials=credentials,
|
||||
credential_expires_at=expires_at,
|
||||
),
|
||||
)
|
||||
# Redirect to OAuth callback page
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback?subscription_id={subscription_builder.id}")
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
|
||||
|
||||
|
||||
class TriggerOAuthClientManageApi(Resource):
|
||||
|
|
|
|||
|
|
@ -217,6 +217,35 @@ class SubscriptionBuilder(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class SubscriptionBuilderUpdater(BaseModel):
|
||||
name: str | None = Field(default=None, description="The name of the subscription builder")
|
||||
parameters: Mapping[str, Any] | None = Field(default=None, description="The parameters of the subscription builder")
|
||||
properties: Mapping[str, Any] | None = Field(default=None, description="The properties of the subscription builder")
|
||||
credentials: Mapping[str, str] | None = Field(
|
||||
default=None, description="The credentials of the subscription builder"
|
||||
)
|
||||
credential_type: str | None = Field(default=None, description="The credential type of the subscription builder")
|
||||
credential_expires_at: int | None = Field(
|
||||
default=None, description="The credential expires at of the subscription builder"
|
||||
)
|
||||
expires_at: int | None = Field(default=None, description="The expires at of the subscription builder")
|
||||
|
||||
def update(self, subscription_builder: SubscriptionBuilder) -> None:
|
||||
if self.name:
|
||||
subscription_builder.name = self.name
|
||||
if self.parameters:
|
||||
subscription_builder.parameters = self.parameters
|
||||
if self.properties:
|
||||
subscription_builder.properties = self.properties
|
||||
if self.credentials:
|
||||
subscription_builder.credentials = self.credentials
|
||||
if self.credential_type:
|
||||
subscription_builder.credential_type = self.credential_type
|
||||
if self.credential_expires_at:
|
||||
subscription_builder.credential_expires_at = self.credential_expires_at
|
||||
if self.expires_at:
|
||||
subscription_builder.expires_at = self.expires_at
|
||||
|
||||
# Export all entities
|
||||
__all__ = [
|
||||
"OAuthSchema",
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ class OAuthProxyService(BasePluginClient):
|
|||
__KEY_PREFIX__ = "oauth_proxy_context:"
|
||||
|
||||
@staticmethod
|
||||
def create_proxy_context(user_id: str, tenant_id: str, plugin_id: str, provider: str):
|
||||
def create_proxy_context(user_id: str, tenant_id: str, plugin_id: str, provider: str, extra_data: dict = {}):
|
||||
"""
|
||||
Create a proxy context for an OAuth 2.0 authorization request.
|
||||
|
||||
|
|
@ -26,6 +26,7 @@ class OAuthProxyService(BasePluginClient):
|
|||
"""
|
||||
context_id = str(uuid.uuid4())
|
||||
data = {
|
||||
**extra_data,
|
||||
"user_id": user_id,
|
||||
"plugin_id": plugin_id,
|
||||
"tenant_id": tenant_id,
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from core.trigger.entities.api_entities import SubscriptionBuilderApiEntity
|
|||
from core.trigger.entities.entities import (
|
||||
RequestLog,
|
||||
SubscriptionBuilder,
|
||||
SubscriptionBuilderUpdater,
|
||||
)
|
||||
from core.trigger.provider import PluginTriggerProviderController
|
||||
from core.trigger.trigger_manager import TriggerManager
|
||||
|
|
@ -64,21 +65,18 @@ class TriggerSubscriptionBuilderService:
|
|||
return {"verified": bool(subscription_builder.credentials)}
|
||||
|
||||
if subscription_builder.credential_type == CredentialType.API_KEY:
|
||||
credentials_to_validate = subscription_builder.credentials
|
||||
try:
|
||||
provider_controller.validate_credentials(user_id, subscription_builder.credentials)
|
||||
return {"verified": True}
|
||||
provider_controller.validate_credentials(user_id, credentials_to_validate)
|
||||
except ToolProviderCredentialValidationError as e:
|
||||
raise ValueError(f"Invalid credentials: {e}")
|
||||
return {"verified": True}
|
||||
|
||||
return {"verified": True}
|
||||
|
||||
@classmethod
|
||||
def build_trigger_subscription_builder(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
provider_id: TriggerProviderID,
|
||||
subscription_builder_id: str,
|
||||
cls, tenant_id: str, user_id: str, provider_id: TriggerProviderID, subscription_builder_id: str
|
||||
) -> None:
|
||||
"""Build a trigger subscription builder"""
|
||||
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
|
||||
|
|
@ -143,11 +141,7 @@ class TriggerSubscriptionBuilderService:
|
|||
tenant_id: str,
|
||||
user_id: str,
|
||||
provider_id: TriggerProviderID,
|
||||
credentials: Mapping[str, str],
|
||||
credential_type: CredentialType,
|
||||
credential_expires_at: int,
|
||||
expires_at: int,
|
||||
name: str | None,
|
||||
) -> SubscriptionBuilder:
|
||||
"""
|
||||
Add a new trigger subscription validation.
|
||||
|
|
@ -160,17 +154,17 @@ class TriggerSubscriptionBuilderService:
|
|||
subscription_id = str(uuid.uuid4())
|
||||
subscription_builder = SubscriptionBuilder(
|
||||
id=subscription_id,
|
||||
name=name or "",
|
||||
name=None,
|
||||
endpoint_id=subscription_id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
provider_id=str(provider_id),
|
||||
parameters=subscription_schema.get_default_parameters(),
|
||||
properties=subscription_schema.get_default_properties(),
|
||||
credentials=credentials,
|
||||
credentials={},
|
||||
credential_type=credential_type,
|
||||
credential_expires_at=credential_expires_at,
|
||||
expires_at=expires_at,
|
||||
credential_expires_at=-1,
|
||||
expires_at=-1,
|
||||
)
|
||||
cache_key = cls.encode_cache_key(subscription_id)
|
||||
redis_client.setex(
|
||||
|
|
@ -184,10 +178,7 @@ class TriggerSubscriptionBuilderService:
|
|||
tenant_id: str,
|
||||
provider_id: TriggerProviderID,
|
||||
subscription_builder_id: str,
|
||||
name: str | None,
|
||||
parameters: Mapping[str, Any] | None,
|
||||
properties: Mapping[str, Any] | None,
|
||||
credentials: Mapping[str, str] | None,
|
||||
subscription_builder_updater: SubscriptionBuilderUpdater,
|
||||
) -> SubscriptionBuilderApiEntity:
|
||||
"""
|
||||
Update a trigger subscription validation.
|
||||
|
|
@ -198,23 +189,16 @@ class TriggerSubscriptionBuilderService:
|
|||
raise ValueError(f"Provider {provider_id} not found")
|
||||
|
||||
cache_key = cls.encode_cache_key(subscription_id)
|
||||
subscription_builder = cls.get_subscription_builder(subscription_id)
|
||||
if not subscription_builder or subscription_builder.tenant_id != tenant_id:
|
||||
subscription_builder_cache = cls.get_subscription_builder(subscription_builder_id)
|
||||
if not subscription_builder_cache or subscription_builder_cache.tenant_id != tenant_id:
|
||||
raise ValueError(f"Subscription {subscription_id} expired or not found")
|
||||
|
||||
if name:
|
||||
subscription_builder.name = name
|
||||
if parameters:
|
||||
subscription_builder.parameters = parameters
|
||||
if properties:
|
||||
subscription_builder.properties = properties
|
||||
if credentials:
|
||||
subscription_builder.credentials = credentials
|
||||
subscription_builder_updater.update(subscription_builder_cache)
|
||||
|
||||
redis_client.setex(
|
||||
cache_key, cls.__VALIDATION_REQUEST_CACHE_EXPIRE_MS__, subscription_builder.model_dump_json()
|
||||
cache_key, cls.__VALIDATION_REQUEST_CACHE_EXPIRE_MS__, subscription_builder_cache.model_dump_json()
|
||||
)
|
||||
return cls.builder_to_api_entity(controller=provider_controller, entity=subscription_builder)
|
||||
return cls.builder_to_api_entity(controller=provider_controller, entity=subscription_builder_cache)
|
||||
|
||||
@classmethod
|
||||
def builder_to_api_entity(
|
||||
|
|
|
|||
Loading…
Reference in New Issue