diff --git a/api/controllers/console/workspace/trigger_providers.py b/api/controllers/console/workspace/trigger_providers.py index a26f5231ae..9a40017d0d 100644 --- a/api/controllers/console/workspace/trigger_providers.py +++ b/api/controllers/console/workspace/trigger_providers.py @@ -12,6 +12,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin import TriggerProviderID from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.impl.oauth import OAuthHandler +from core.trigger.entities.entities import SubscriptionBuilderUpdater from extensions.ext_database import db from libs.login import current_user, login_required from models.account import Account @@ -71,22 +72,16 @@ class TriggerSubscriptionBuilderCreateApi(Resource): raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument("name", type=str, required=False, nullable=True, location="json") - parser.add_argument("credentials", type=dict, required=False, nullable=False, location="json") + parser.add_argument("credential_type", type=str, required=False, nullable=True, location="json") args = parser.parse_args() try: - credentials = args.get("credentials", {}) - credential_type = CredentialType.API_KEY if credentials else CredentialType.UNAUTHORIZED + credential_type = CredentialType.of(args.get("credential_type") or CredentialType.UNAUTHORIZED.value) subscription_builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder( tenant_id=user.current_tenant_id, user_id=user.id, provider_id=TriggerProviderID(provider), - name=args.get("name", None), - credentials=credentials, credential_type=credential_type, - credential_expires_at=-1, - expires_at=-1, ) return jsonable_encoder({"subscription_builder": subscription_builder}) except ValueError as e: @@ -108,7 +103,20 @@ class TriggerSubscriptionBuilderVerifyApi(Resource): if not user.is_admin_or_owner: raise Forbidden() + parser = reqparse.RequestParser() + # The credentials of the subscription builder + parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json") + args = parser.parse_args() + try: + TriggerSubscriptionBuilderService.update_trigger_subscription_builder( + tenant_id=user.current_tenant_id, + provider_id=TriggerProviderID(provider), + subscription_builder_id=subscription_builder_id, + subscription_builder_updater=SubscriptionBuilderUpdater( + credentials=args.get("credentials", None), + ), + ) TriggerSubscriptionBuilderService.verify_trigger_subscription_builder( tenant_id=user.current_tenant_id, user_id=user.id, @@ -147,10 +155,12 @@ class TriggerSubscriptionBuilderUpdateApi(Resource): tenant_id=user.current_tenant_id, provider_id=TriggerProviderID(provider), subscription_builder_id=subscription_builder_id, - name=args.get("name", None), - parameters=args.get("parameters", None), - properties=args.get("properties", None), - credentials=args.get("credentials", None), + subscription_builder_updater=SubscriptionBuilderUpdater( + name=args.get("name", None), + parameters=args.get("parameters", None), + properties=args.get("properties", None), + credentials=args.get("credentials", None), + ), ) ) except Exception as e: @@ -188,7 +198,27 @@ class TriggerSubscriptionBuilderBuildApi(Resource): if not user.is_admin_or_owner: raise Forbidden() + parser = reqparse.RequestParser() + # The name of the subscription builder + parser.add_argument("name", type=str, required=False, nullable=True, location="json") + # The parameters of the subscription builder + parser.add_argument("parameters", type=dict, required=False, nullable=True, location="json") + # The properties of the subscription builder + parser.add_argument("properties", type=dict, required=False, nullable=True, location="json") + # The credentials of the subscription builder + parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json") + args = parser.parse_args() try: + TriggerSubscriptionBuilderService.update_trigger_subscription_builder( + tenant_id=user.current_tenant_id, + provider_id=TriggerProviderID(provider), + subscription_builder_id=subscription_builder_id, + subscription_builder_updater=SubscriptionBuilderUpdater( + name=args.get("name", None), + parameters=args.get("parameters", None), + properties=args.get("properties", None), + ), + ) TriggerSubscriptionBuilderService.build_trigger_subscription_builder( tenant_id=user.current_tenant_id, user_id=user.id, @@ -263,6 +293,14 @@ class TriggerOAuthAuthorizeApi(Resource): if oauth_client_params is None: raise Forbidden("No OAuth client configuration found for this trigger provider") + # Create subscription builder + subscription_builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder( + tenant_id=tenant_id, + user_id=user.id, + provider_id=provider_id, + credential_type=CredentialType.OAUTH2, + ) + # Create OAuth handler and proxy context oauth_handler = OAuthHandler() context_id = OAuthProxyService.create_proxy_context( @@ -270,6 +308,9 @@ class TriggerOAuthAuthorizeApi(Resource): tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name, + extra_data={ + "subscription_builder_id": subscription_builder.id, + }, ) # Build redirect URI for callback @@ -284,24 +325,13 @@ class TriggerOAuthAuthorizeApi(Resource): redirect_uri=redirect_uri, system_credentials=oauth_client_params, ) - # Create subscription builder - subscription_builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder( - tenant_id=tenant_id, - user_id=user.id, - provider_id=provider_id, - credentials={}, - credential_type=CredentialType.OAUTH2, - credential_expires_at=0, - expires_at=0, - name=f"{provider_name} OAuth Authentication", - ) # Create response with cookie response = make_response( jsonable_encoder( { "authorization_url": authorization_url_response.authorization_url, - "subscription_builder": subscription_builder, + "subscription_builder_id": subscription_builder.id, } ) ) @@ -339,6 +369,7 @@ class TriggerOAuthCallbackApi(Resource): provider_name = provider_id.provider_name user_id = context.get("user_id") tenant_id = context.get("tenant_id") + subscription_builder_id = context.get("subscription_builder_id") # Get OAuth client configuration oauth_client_params = TriggerProviderService.get_oauth_client( @@ -369,19 +400,18 @@ class TriggerOAuthCallbackApi(Resource): if not credentials: raise Exception("Failed to get OAuth credentials") - # Save OAuth credentials to database - subscription_builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder( + # Update subscription builder + TriggerSubscriptionBuilderService.update_trigger_subscription_builder( tenant_id=tenant_id, - user_id=user_id, provider_id=provider_id, - credentials=credentials, - credential_type=CredentialType.OAUTH2, - credential_expires_at=expires_at, - expires_at=expires_at, - name=f"{provider_name} OAuth Authentication", + subscription_builder_id=subscription_builder_id, + subscription_builder_updater=SubscriptionBuilderUpdater( + credentials=credentials, + credential_expires_at=expires_at, + ), ) # Redirect to OAuth callback page - return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback?subscription_id={subscription_builder.id}") + return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback") class TriggerOAuthClientManageApi(Resource): diff --git a/api/core/trigger/entities/entities.py b/api/core/trigger/entities/entities.py index 84c37ccc72..82b1aa29e2 100644 --- a/api/core/trigger/entities/entities.py +++ b/api/core/trigger/entities/entities.py @@ -217,6 +217,35 @@ class SubscriptionBuilder(BaseModel): ) +class SubscriptionBuilderUpdater(BaseModel): + name: str | None = Field(default=None, description="The name of the subscription builder") + parameters: Mapping[str, Any] | None = Field(default=None, description="The parameters of the subscription builder") + properties: Mapping[str, Any] | None = Field(default=None, description="The properties of the subscription builder") + credentials: Mapping[str, str] | None = Field( + default=None, description="The credentials of the subscription builder" + ) + credential_type: str | None = Field(default=None, description="The credential type of the subscription builder") + credential_expires_at: int | None = Field( + default=None, description="The credential expires at of the subscription builder" + ) + expires_at: int | None = Field(default=None, description="The expires at of the subscription builder") + + def update(self, subscription_builder: SubscriptionBuilder) -> None: + if self.name: + subscription_builder.name = self.name + if self.parameters: + subscription_builder.parameters = self.parameters + if self.properties: + subscription_builder.properties = self.properties + if self.credentials: + subscription_builder.credentials = self.credentials + if self.credential_type: + subscription_builder.credential_type = self.credential_type + if self.credential_expires_at: + subscription_builder.credential_expires_at = self.credential_expires_at + if self.expires_at: + subscription_builder.expires_at = self.expires_at + # Export all entities __all__ = [ "OAuthSchema", diff --git a/api/services/plugin/oauth_service.py b/api/services/plugin/oauth_service.py index 055fbb8138..635297a099 100644 --- a/api/services/plugin/oauth_service.py +++ b/api/services/plugin/oauth_service.py @@ -11,7 +11,7 @@ class OAuthProxyService(BasePluginClient): __KEY_PREFIX__ = "oauth_proxy_context:" @staticmethod - def create_proxy_context(user_id: str, tenant_id: str, plugin_id: str, provider: str): + def create_proxy_context(user_id: str, tenant_id: str, plugin_id: str, provider: str, extra_data: dict = {}): """ Create a proxy context for an OAuth 2.0 authorization request. @@ -26,6 +26,7 @@ class OAuthProxyService(BasePluginClient): """ context_id = str(uuid.uuid4()) data = { + **extra_data, "user_id": user_id, "plugin_id": plugin_id, "tenant_id": tenant_id, diff --git a/api/services/trigger/trigger_subscription_builder_service.py b/api/services/trigger/trigger_subscription_builder_service.py index bdf4d37253..1bce16d95c 100644 --- a/api/services/trigger/trigger_subscription_builder_service.py +++ b/api/services/trigger/trigger_subscription_builder_service.py @@ -14,6 +14,7 @@ from core.trigger.entities.api_entities import SubscriptionBuilderApiEntity from core.trigger.entities.entities import ( RequestLog, SubscriptionBuilder, + SubscriptionBuilderUpdater, ) from core.trigger.provider import PluginTriggerProviderController from core.trigger.trigger_manager import TriggerManager @@ -64,21 +65,18 @@ class TriggerSubscriptionBuilderService: return {"verified": bool(subscription_builder.credentials)} if subscription_builder.credential_type == CredentialType.API_KEY: + credentials_to_validate = subscription_builder.credentials try: - provider_controller.validate_credentials(user_id, subscription_builder.credentials) - return {"verified": True} + provider_controller.validate_credentials(user_id, credentials_to_validate) except ToolProviderCredentialValidationError as e: raise ValueError(f"Invalid credentials: {e}") + return {"verified": True} return {"verified": True} @classmethod def build_trigger_subscription_builder( - cls, - tenant_id: str, - user_id: str, - provider_id: TriggerProviderID, - subscription_builder_id: str, + cls, tenant_id: str, user_id: str, provider_id: TriggerProviderID, subscription_builder_id: str ) -> None: """Build a trigger subscription builder""" provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id) @@ -143,11 +141,7 @@ class TriggerSubscriptionBuilderService: tenant_id: str, user_id: str, provider_id: TriggerProviderID, - credentials: Mapping[str, str], credential_type: CredentialType, - credential_expires_at: int, - expires_at: int, - name: str | None, ) -> SubscriptionBuilder: """ Add a new trigger subscription validation. @@ -160,17 +154,17 @@ class TriggerSubscriptionBuilderService: subscription_id = str(uuid.uuid4()) subscription_builder = SubscriptionBuilder( id=subscription_id, - name=name or "", + name=None, endpoint_id=subscription_id, tenant_id=tenant_id, user_id=user_id, provider_id=str(provider_id), parameters=subscription_schema.get_default_parameters(), properties=subscription_schema.get_default_properties(), - credentials=credentials, + credentials={}, credential_type=credential_type, - credential_expires_at=credential_expires_at, - expires_at=expires_at, + credential_expires_at=-1, + expires_at=-1, ) cache_key = cls.encode_cache_key(subscription_id) redis_client.setex( @@ -184,10 +178,7 @@ class TriggerSubscriptionBuilderService: tenant_id: str, provider_id: TriggerProviderID, subscription_builder_id: str, - name: str | None, - parameters: Mapping[str, Any] | None, - properties: Mapping[str, Any] | None, - credentials: Mapping[str, str] | None, + subscription_builder_updater: SubscriptionBuilderUpdater, ) -> SubscriptionBuilderApiEntity: """ Update a trigger subscription validation. @@ -198,23 +189,16 @@ class TriggerSubscriptionBuilderService: raise ValueError(f"Provider {provider_id} not found") cache_key = cls.encode_cache_key(subscription_id) - subscription_builder = cls.get_subscription_builder(subscription_id) - if not subscription_builder or subscription_builder.tenant_id != tenant_id: + subscription_builder_cache = cls.get_subscription_builder(subscription_builder_id) + if not subscription_builder_cache or subscription_builder_cache.tenant_id != tenant_id: raise ValueError(f"Subscription {subscription_id} expired or not found") - if name: - subscription_builder.name = name - if parameters: - subscription_builder.parameters = parameters - if properties: - subscription_builder.properties = properties - if credentials: - subscription_builder.credentials = credentials + subscription_builder_updater.update(subscription_builder_cache) redis_client.setex( - cache_key, cls.__VALIDATION_REQUEST_CACHE_EXPIRE_MS__, subscription_builder.model_dump_json() + cache_key, cls.__VALIDATION_REQUEST_CACHE_EXPIRE_MS__, subscription_builder_cache.model_dump_json() ) - return cls.builder_to_api_entity(controller=provider_controller, entity=subscription_builder) + return cls.builder_to_api_entity(controller=provider_controller, entity=subscription_builder_cache) @classmethod def builder_to_api_entity(