import json
import logging
import re
from collections.abc import Mapping
from typing import Any, Optional

from sqlalchemy import desc
from sqlalchemy.orm import Session

from configs import dify_config
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.helper.provider_encryption import create_provider_encrypter
from core.plugin.entities.plugin import TriggerProviderID
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.oauth import OAuthHandler
from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
from core.trigger.entities.api_entities import TriggerProviderApiEntity, TriggerProviderCredentialApiEntity
from core.trigger.trigger_manager import TriggerManager
from core.trigger.utils.encryption import (
    create_trigger_provider_encrypter_for_credential,
    create_trigger_provider_oauth_encrypter,
)
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.trigger import TriggerOAuthSystemClient, TriggerOAuthTenantClient, TriggerProvider
from services.plugin.plugin_service import PluginService

logger = logging.getLogger(__name__)


class TriggerProviderService:
    """Service for managing trigger providers and credentials"""

    # Upper bound on stored credential instances per (tenant, provider) pair;
    # enforced by add_trigger_provider.
    __MAX_TRIGGER_PROVIDER_COUNT__ = 100

    @classmethod
    def list_trigger_providers(cls, tenant_id: str) -> list[TriggerProviderApiEntity]:
        """
        List all trigger providers for the current tenant.

        :param tenant_id: Tenant ID
        :return: API entities for every trigger provider TriggerManager reports
        """
        return [provider.to_api_entity() for provider in TriggerManager.list_all_trigger_providers(tenant_id)]

    @classmethod
    def list_trigger_provider_credentials(
        cls, tenant_id: str, provider_id: TriggerProviderID
    ) -> list[TriggerProviderCredentialApiEntity]:
        """
        List all stored credential instances of one trigger provider for a tenant.

        Rows are returned newest first (ordered by ``created_at`` descending), and
        each entity's ``credentials`` field is replaced with its decrypted value.

        :param tenant_id: Tenant ID
        :param provider_id: Provider identifier
        :return: Credential API entities with decrypted credentials
        """
        with Session(db.engine, autoflush=False) as session:
            credentials_db = (
                session.query(TriggerProvider)
                .filter_by(tenant_id=tenant_id, provider_id=str(provider_id))
                .order_by(desc(TriggerProvider.created_at))
                .all()
            )
            credentials: list[TriggerProviderCredentialApiEntity] = [
                credential.to_api_entity() for credential in credentials_db
            ]

        provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
        for credential in credentials:
            encrypter, _ = create_trigger_provider_encrypter_for_credential(
                tenant_id=tenant_id,
                controller=provider_controller,
                credential=credential,
            )
            credential.credentials = encrypter.decrypt(credential.credentials)
        return credentials

    @classmethod
    def add_trigger_provider(
        cls,
        tenant_id: str,
        user_id: str,
        provider_id: TriggerProviderID,
        credential_type: CredentialType,
        credentials: dict,
        name: Optional[str] = None,
        expires_at: int = -1,
    ) -> dict:
        """
        Add a new trigger provider with credentials.
        Supports multiple credential instances per provider.

        A Redis lock keyed on tenant and provider serializes concurrent creations,
        so the count limit and name-uniqueness checks are not raced.

        :param tenant_id: Tenant ID
        :param user_id: ID of the user creating the credential
        :param provider_id: Provider identifier (e.g., "plugin_id/provider_name")
        :param credential_type: Type of credential (oauth or api_key)
        :param credentials: Credential data to encrypt and store
        :param name: Optional name for this credential instance; auto-generated if empty
        :param expires_at: OAuth token expiration timestamp
        :return: Success response containing the new record's id
        :raises ValueError: on any failure (limit reached, duplicate name, or an
            underlying error, which is chained as the cause)
        """
        try:
            # Store and query with the canonical string form, matching the
            # str(provider_id) lookups used elsewhere in this service.
            provider_id_str = str(provider_id)
            provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
            with Session(db.engine) as session:
                # Use distributed lock to prevent race conditions
                lock_key = f"trigger_provider_create_lock:{tenant_id}_{provider_id}"
                with redis_client.lock(lock_key, timeout=20):
                    # Check provider count limit
                    provider_count = (
                        session.query(TriggerProvider)
                        .filter_by(tenant_id=tenant_id, provider_id=provider_id_str)
                        .count()
                    )
                    if provider_count >= cls.__MAX_TRIGGER_PROVIDER_COUNT__:
                        raise ValueError(
                            f"Maximum number of providers ({cls.__MAX_TRIGGER_PROVIDER_COUNT__}) "
                            f"reached for {provider_id}"
                        )

                    # Generate name if not provided
                    if not name:
                        name = cls._generate_provider_name(
                            session=session,
                            tenant_id=tenant_id,
                            provider_id=provider_id,
                            credential_type=credential_type,
                        )
                    else:
                        # Check if name already exists
                        existing = (
                            session.query(TriggerProvider)
                            .filter_by(tenant_id=tenant_id, provider_id=provider_id_str, name=name)
                            .first()
                        )
                        if existing:
                            raise ValueError(f"Credential name '{name}' already exists for this provider")

                    encrypter, _ = create_provider_encrypter(
                        tenant_id=tenant_id,
                        config=provider_controller.get_credential_schema_config(credential_type),
                        cache=NoOpProviderCredentialCache(),
                    )

                    # Create provider record
                    db_provider = TriggerProvider(
                        tenant_id=tenant_id,
                        user_id=user_id,
                        provider_id=provider_id_str,
                        credential_type=credential_type.value,
                        encrypted_credentials=json.dumps(encrypter.encrypt(credentials)),
                        name=name,
                        expires_at=expires_at,
                    )
                    session.add(db_provider)
                    session.commit()

                    return {"result": "success", "id": str(db_provider.id)}

        except Exception as e:
            logger.exception("Failed to add trigger provider")
            raise ValueError(str(e)) from e

    @classmethod
    def update_trigger_provider(
        cls,
        tenant_id: str,
        credential_id: str,
        credentials: Optional[dict] = None,
        name: Optional[str] = None,
    ) -> dict:
        """
        Update an existing trigger provider's credentials or name.

        Credential values equal to HIDDEN_VALUE keep the currently stored value
        (or UNKNOWN_VALUE if that key was not stored before). The credential
        cache is cleared after the credentials are re-encrypted.

        :param tenant_id: Tenant ID
        :param credential_id: Credential instance ID
        :param credentials: New credentials (optional)
        :param name: New name (optional)
        :return: Success response
        :raises ValueError: if the credential is not found, the new name is
            taken, or any update step fails (the session is rolled back)
        """
        with Session(db.engine) as session:
            db_provider = session.query(TriggerProvider).filter_by(tenant_id=tenant_id, id=credential_id).first()

            if not db_provider:
                raise ValueError(f"Trigger provider credential {credential_id} not found")

            try:
                provider_controller = TriggerManager.get_trigger_provider(
                    tenant_id, TriggerProviderID(db_provider.provider_id)
                )
                if credentials:
                    encrypter, cache = create_trigger_provider_encrypter_for_credential(
                        tenant_id=tenant_id,
                        controller=provider_controller,
                        credential=db_provider,
                    )
                    # Handle hidden values
                    original_credentials = encrypter.decrypt(db_provider.credentials)
                    new_credentials = {
                        key: value if value != HIDDEN_VALUE else original_credentials.get(key, UNKNOWN_VALUE)
                        for key, value in credentials.items()
                    }

                    db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(new_credentials))
                    cache.delete()

                # Update name if provided
                if name and name != db_provider.name:
                    # Check if name already exists
                    existing = (
                        session.query(TriggerProvider)
                        .filter_by(tenant_id=tenant_id, provider_id=db_provider.provider_id, name=name)
                        .filter(TriggerProvider.id != credential_id)
                        .first()
                    )
                    if existing:
                        raise ValueError(f"Credential name '{name}' already exists")

                    db_provider.name = name

                session.commit()
                return {"result": "success"}

            except Exception as e:
                session.rollback()
                raise ValueError(str(e)) from e

    @classmethod
    def delete_trigger_provider(cls, tenant_id: str, credential_id: str) -> dict:
        """
        Delete a trigger provider credential.

        The credential's cache entry is cleared after the row is deleted and committed.

        :param tenant_id: Tenant ID
        :param credential_id: Credential instance ID
        :return: Success response
        :raises ValueError: if the credential is not found
        """
        with Session(db.engine) as session:
            db_provider = session.query(TriggerProvider).filter_by(tenant_id=tenant_id, id=credential_id).first()

            if not db_provider:
                raise ValueError(f"Trigger provider credential {credential_id} not found")

            provider_controller = TriggerManager.get_trigger_provider(
                tenant_id, TriggerProviderID(db_provider.provider_id)
            )
            # Clear cache
            _, cache = create_trigger_provider_encrypter_for_credential(
                tenant_id=tenant_id,
                controller=provider_controller,
                credential=db_provider,
            )

            session.delete(db_provider)
            session.commit()

            cache.delete()
            return {"result": "success"}

    @classmethod
    def refresh_oauth_token(
        cls,
        tenant_id: str,
        credential_id: str,
    ) -> dict:
        """
        Refresh OAuth token for a trigger provider.

        Decrypts the stored credentials, asks OAuthHandler to refresh them with the
        client configuration from get_oauth_client, then stores the re-encrypted
        result and new expiry and clears the credential cache.

        :param tenant_id: Tenant ID
        :param credential_id: Credential instance ID
        :return: New token info (``result`` and ``expires_at``)
        :raises ValueError: if the credential is not found or is not OAuth2
        """
        with Session(db.engine) as session:
            db_provider = session.query(TriggerProvider).filter_by(tenant_id=tenant_id, id=credential_id).first()

            if not db_provider:
                raise ValueError(f"Trigger provider credential {credential_id} not found")

            if db_provider.credential_type != CredentialType.OAUTH2.value:
                raise ValueError("Only OAuth credentials can be refreshed")

            provider_id = TriggerProviderID(db_provider.provider_id)
            provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
            # Create encrypter
            encrypter, cache = create_trigger_provider_encrypter_for_credential(
                tenant_id=tenant_id,
                controller=provider_controller,
                credential=db_provider,
            )

            # Decrypt current credentials
            current_credentials = encrypter.decrypt(db_provider.credentials)

            # Get OAuth client configuration
            redirect_uri = (
                f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{db_provider.provider_id}/trigger/callback"
            )
            system_credentials = cls.get_oauth_client(tenant_id, provider_id)

            # Refresh token
            oauth_handler = OAuthHandler()
            refreshed_credentials = oauth_handler.refresh_credentials(
                tenant_id=tenant_id,
                user_id=db_provider.user_id,
                plugin_id=provider_id.plugin_id,
                provider=provider_id.provider_name,
                redirect_uri=redirect_uri,
                system_credentials=system_credentials or {},
                credentials=current_credentials,
            )

            # Update credentials
            db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(dict(refreshed_credentials.credentials)))
            db_provider.expires_at = refreshed_credentials.expires_at
            session.commit()

            # Clear cache
            cache.delete()

            return {
                "result": "success",
                "expires_at": refreshed_credentials.expires_at,
            }

    @classmethod
    def get_oauth_client(cls, tenant_id: str, provider_id: TriggerProviderID) -> Optional[Mapping[str, Any]]:
        """
        Get OAuth client configuration for a provider.

        Resolution order: an enabled tenant-level client wins; otherwise, only for
        verified plugins, a system-level client is used; otherwise None.

        :param tenant_id: Tenant ID
        :param provider_id: Provider identifier
        :return: OAuth client configuration or None
        :raises ValueError: if the system-level params cannot be decrypted
            (the decryption error is chained as the cause)
        """
        provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
        with Session(db.engine, autoflush=False) as session:
            tenant_client: TriggerOAuthTenantClient | None = (
                session.query(TriggerOAuthTenantClient)
                .filter_by(
                    tenant_id=tenant_id,
                    provider=provider_id.provider_name,
                    plugin_id=provider_id.plugin_id,
                    enabled=True,
                )
                .first()
            )

            oauth_params: Mapping[str, Any] | None = None
            if tenant_client:
                encrypter, _ = create_trigger_provider_oauth_encrypter(tenant_id, provider_controller)
                oauth_params = encrypter.decrypt(tenant_client.oauth_params)
                return oauth_params

            # System-level clients are only offered to verified plugins.
            is_verified = PluginService.is_plugin_verified(tenant_id, provider_id.plugin_id)
            if not is_verified:
                return oauth_params

            # Check for system-level OAuth client
            system_client: TriggerOAuthSystemClient | None = (
                session.query(TriggerOAuthSystemClient)
                .filter_by(plugin_id=provider_id.plugin_id, provider=provider_id.provider_name)
                .first()
            )

            if system_client:
                try:
                    oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params)
                except Exception as e:
                    raise ValueError(f"Error decrypting system oauth params: {e}") from e

            return oauth_params

    @classmethod
    def save_custom_oauth_client_params(
        cls,
        tenant_id: str,
        provider_id: TriggerProviderID,
        client_params: Optional[dict] = None,
        enabled: Optional[bool] = None,
    ) -> dict:
        """
        Save or update custom OAuth client parameters for a trigger provider.

        Returns immediately without touching the database when both
        ``client_params`` and ``enabled`` are None. A tenant client row is created
        if none exists. Param values equal to HIDDEN_VALUE keep the stored value
        (or UNKNOWN_VALUE if that key was not stored before).

        :param tenant_id: Tenant ID
        :param provider_id: Provider identifier
        :param client_params: OAuth client parameters (client_id, client_secret, etc.)
        :param enabled: Enable/disable the custom OAuth client
        :return: Success response
        """
        if client_params is None and enabled is None:
            return {"result": "success"}

        # Get provider controller to access schema
        provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)

        with Session(db.engine) as session:
            # Find existing custom client params
            custom_client = (
                session.query(TriggerOAuthTenantClient)
                .filter_by(
                    tenant_id=tenant_id,
                    plugin_id=provider_id.plugin_id,
                    provider=provider_id.provider_name,
                )
                .first()
            )

            # Create new record if doesn't exist
            if custom_client is None:
                custom_client = TriggerOAuthTenantClient(
                    tenant_id=tenant_id,
                    plugin_id=provider_id.plugin_id,
                    provider=provider_id.provider_name,
                )
                session.add(custom_client)

            # Update client params if provided
            if client_params is not None:
                encrypter, _ = create_provider_encrypter(
                    tenant_id=tenant_id,
                    config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
                    cache=NoOpProviderCredentialCache(),
                )

                # Handle hidden values
                original_params = encrypter.decrypt(custom_client.oauth_params)
                new_params: dict = {
                    key: value if value != HIDDEN_VALUE else original_params.get(key, UNKNOWN_VALUE)
                    for key, value in client_params.items()
                }
                custom_client.encrypted_oauth_params = json.dumps(encrypter.encrypt(new_params))

            # Update enabled status if provided
            if enabled is not None:
                custom_client.enabled = enabled

            session.commit()

        return {"result": "success"}

    @classmethod
    def get_custom_oauth_client_params(cls, tenant_id: str, provider_id: TriggerProviderID) -> dict:
        """
        Get custom OAuth client parameters for a trigger provider.

        :param tenant_id: Tenant ID
        :param provider_id: Provider identifier
        :return: Masked OAuth client parameters, or ``{}`` if no tenant client exists
        """
        with Session(db.engine) as session:
            custom_client = (
                session.query(TriggerOAuthTenantClient)
                .filter_by(
                    tenant_id=tenant_id,
                    plugin_id=provider_id.plugin_id,
                    provider=provider_id.provider_name,
                )
                .first()
            )

            if custom_client is None:
                return {}

            # Get provider controller to access schema
            provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)

            # Create encrypter to decrypt and mask values
            encrypter, _ = create_provider_encrypter(
                tenant_id=tenant_id,
                config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
                cache=NoOpProviderCredentialCache(),
            )

            return encrypter.mask_tool_credentials(encrypter.decrypt(custom_client.oauth_params))

    @classmethod
    def delete_custom_oauth_client_params(cls, tenant_id: str, provider_id: TriggerProviderID) -> dict:
        """
        Delete custom OAuth client parameters for a trigger provider.

        Removes every tenant client row matching tenant, plugin and provider.

        :param tenant_id: Tenant ID
        :param provider_id: Provider identifier
        :return: Success response
        """
        with Session(db.engine) as session:
            session.query(TriggerOAuthTenantClient).filter_by(
                tenant_id=tenant_id,
                provider=provider_id.provider_name,
                plugin_id=provider_id.plugin_id,
            ).delete()
            session.commit()

        return {"result": "success"}

    @classmethod
    def is_oauth_custom_client_enabled(cls, tenant_id: str, provider_id: TriggerProviderID) -> bool:
        """
        Check if custom OAuth client is enabled for a trigger provider.

        :param tenant_id: Tenant ID
        :param provider_id: Provider identifier
        :return: True if an enabled tenant client row exists, False otherwise
        """
        with Session(db.engine, autoflush=False) as session:
            custom_client = (
                session.query(TriggerOAuthTenantClient)
                .filter_by(
                    tenant_id=tenant_id,
                    plugin_id=provider_id.plugin_id,
                    provider=provider_id.provider_name,
                    enabled=True,
                )
                .first()
            )
            return custom_client is not None

    @classmethod
    def _generate_provider_name(
        cls,
        session: Session,
        tenant_id: str,
        provider_id: TriggerProviderID,
        credential_type: CredentialType,
    ) -> str:
        """
        Generate a unique name for a provider credential instance.

        Looks at existing names of the form "<base name> <N>" for the same tenant,
        provider and credential type and returns "<base name> <max N + 1>"
        ("<base name> 1" when none match). Any error is logged with its traceback
        and falls back to "<base name> 1".

        :param session: Database session
        :param tenant_id: Tenant ID
        :param provider_id: Provider identifier
        :param credential_type: Credential type
        :return: Generated name
        """
        try:
            db_providers = (
                session.query(TriggerProvider)
                .filter_by(
                    tenant_id=tenant_id,
                    provider_id=str(provider_id),
                    credential_type=credential_type.value,
                )
                .order_by(desc(TriggerProvider.created_at))
                .all()
            )

            # Get base name
            base_name = credential_type.get_name()

            # Find existing numbered names
            pattern = rf"^{re.escape(base_name)}\s+(\d+)$"
            numbers = []

            for db_provider in db_providers:
                if db_provider.name:
                    match = re.match(pattern, db_provider.name.strip())
                    if match:
                        numbers.append(int(match.group(1)))

            # Generate next number
            if not numbers:
                return f"{base_name} 1"

            max_number = max(numbers)
            return f"{base_name} {max_number + 1}"
        except Exception:
            # Best-effort naming: keep the fallback, but record why it happened.
            logger.warning("Error generating provider name", exc_info=True)
            return f"{credential_type.get_name()} 1"