diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index ad878fc266..519b3f8c54 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -54,6 +54,7 @@ from .app import ( completion, conversation, conversation_variables, + enduser_auth, generator, mcp_server, message, @@ -162,6 +163,7 @@ __all__ = [ "datasource_content_preview", "email_register", "endpoint", + "enduser_auth", "extension", "external", "feature", diff --git a/api/controllers/console/app/enduser_auth.py b/api/controllers/console/app/enduser_auth.py new file mode 100644 index 0000000000..a8bd4829b1 --- /dev/null +++ b/api/controllers/console/app/enduser_auth.py @@ -0,0 +1,230 @@ +""" +End-user authentication API controllers. + +Provides API endpoints for managing end-user credentials for tool authentication. +""" + +from flask import request +from flask_restx import Resource +from werkzeug.exceptions import BadRequest + +from controllers.console import console_ns +from controllers.console.app.wraps import get_app_model +from controllers.console.wraps import account_initialization_required, setup_required +from core.model_runtime.utils.encoders import jsonable_encoder +from libs.login import current_account_with_tenant, login_required +from models import AppMode +from services.tools.app_auth_requirement_service import AppAuthRequirementService +from services.tools.enduser_auth_service import EndUserAuthService +from services.tools.enduser_oauth_service import EndUserOAuthService + + +@console_ns.route("/apps//auth/providers") +class AppAuthProvidersApi(Resource): + """ + Get list of authentication providers required for an app. + + Returns providers that require end-user authentication based on app configuration. + """ + + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.COMPLETION, AppMode.WORKFLOW]) + def get(self, app_model): + """Get authentication providers required for the app.""" + _, tenant_id = current_account_with_tenant() + + providers = AppAuthRequirementService.get_required_providers( + tenant_id=tenant_id, + app_id=str(app_model.id), + ) + + return jsonable_encoder(providers) + + +@console_ns.route("/apps//auth/providers//credentials") +class AppAuthProviderCredentialsApi(Resource): + """ + Manage end-user credentials for a specific provider. + + Allows listing, creating, and deleting end-user credentials. + """ + + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.COMPLETION, AppMode.WORKFLOW]) + def get(self, app_model, provider_id: str): + """List end-user's credentials for this provider.""" + user, tenant_id = current_account_with_tenant() + + # For console API, use the current account user as end_user_id + # In production, this would be the actual end-user ID from the chat/completion request + end_user_id = str(user.id) + + credentials = EndUserAuthService.list_credentials( + tenant_id=tenant_id, + end_user_id=end_user_id, + provider_id=provider_id, + ) + + return jsonable_encoder(credentials) + + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.COMPLETION, AppMode.WORKFLOW]) + def post(self, app_model, provider_id: str): + """Create a new credential (API key only).""" + user, tenant_id = current_account_with_tenant() + end_user_id = str(user.id) + + payload = request.get_json() + if not payload: + raise BadRequest("Request body is required") + + credential_type = payload.get("credential_type") + credentials = payload.get("credentials") + + if not credential_type or not credentials: + raise BadRequest("credential_type and credentials are required") + + if credential_type != "api-key": + raise BadRequest( + "Only 'api-key' credential type can be created via this endpoint. " + "Use OAuth flow for OAuth credentials." + ) + + credential = EndUserAuthService.create_api_key_credential( + tenant_id=tenant_id, + end_user_id=end_user_id, + provider_id=provider_id, + credentials=credentials, + ) + + return jsonable_encoder(credential) + + +@console_ns.route("/apps//auth/providers//credentials/") +class AppAuthProviderCredentialApi(Resource): + """ + Manage a specific end-user credential. + + Allows getting, updating, or deleting a credential. + """ + + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.COMPLETION, AppMode.WORKFLOW]) + def delete(self, app_model, provider_id: str, credential_id: str): + """Delete a credential.""" + user, tenant_id = current_account_with_tenant() + end_user_id = str(user.id) + + EndUserAuthService.delete_credential( + tenant_id=tenant_id, + end_user_id=end_user_id, + provider_id=provider_id, + credential_id=credential_id, + ) + + return {"result": "success"} + + +@console_ns.route("/apps//auth/oauth//authorization-url") +class AppAuthOAuthAuthorizationUrlApi(Resource): + """ + Get OAuth authorization URL for end-user authentication. + + Returns the URL where the user should be redirected to authenticate with the provider. + """ + + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.COMPLETION, AppMode.WORKFLOW]) + def get(self, app_model, provider_id: str): + """Get OAuth authorization URL.""" + user, tenant_id = current_account_with_tenant() + end_user_id = str(user.id) + + result = EndUserOAuthService.get_authorization_url( + end_user_id=end_user_id, + tenant_id=tenant_id, + app_id=str(app_model.id), + provider=provider_id, + ) + + # Set OAuth context cookie for callback + response = jsonable_encoder({ + "authorization_url": result["authorization_url"], + }) + + # Store context_id in response for frontend to set as cookie + response["context_id"] = result["context_id"] + + return response + + +@console_ns.route("/apps//auth/oauth//callback") +class AppAuthOAuthCallbackApi(Resource): + """ + Handle OAuth callback for end-user authentication. + + This endpoint is called by the OAuth provider after user authorization. + """ + + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.COMPLETION, AppMode.WORKFLOW]) + def get(self, app_model, provider_id: str): + """Handle OAuth callback and store credentials.""" + # Get OAuth context ID from cookie + context_id = request.cookies.get("oauth_context_id") + + if not context_id: + raise BadRequest("Missing OAuth context") + + # Get OAuth error if any + error = request.args.get("error") + error_description = request.args.get("error_description") + + if error: + raise BadRequest(f"OAuth error: {error} - {error_description}") + + # Handle callback and create credential + result = EndUserOAuthService.handle_oauth_callback( + context_id=context_id, + request=request, + ) + + return jsonable_encoder(result) + + +@console_ns.route("/apps//auth/providers//credentials//refresh") +class AppAuthProviderRefreshApi(Resource): + """ + Manually refresh OAuth token for a credential. + + This endpoint allows refreshing an expired OAuth token. + """ + + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.COMPLETION, AppMode.WORKFLOW]) + def post(self, app_model, provider_id: str, credential_id: str): + """Refresh OAuth token.""" + user, tenant_id = current_account_with_tenant() + end_user_id = str(user.id) + + result = EndUserOAuthService.refresh_oauth_token( + credential_id=credential_id, + end_user_id=end_user_id, + tenant_id=tenant_id, + ) + + return jsonable_encoder(result) diff --git a/api/core/agent/entities.py b/api/core/agent/entities.py index 220feced1d..3acc4cd29e 100644 --- a/api/core/agent/entities.py +++ b/api/core/agent/entities.py @@ -3,7 +3,7 @@ from typing import Any, Union from pydantic import BaseModel, Field -from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType +from core.tools.entities.tool_entities import ToolAuthType, ToolInvokeMessage, ToolProviderType class AgentToolEntity(BaseModel): @@ -17,6 +17,7 @@ class AgentToolEntity(BaseModel): tool_parameters: dict[str, Any] = Field(default_factory=dict) plugin_unique_identifier: str | None = None credential_id: str | None = None + auth_type: ToolAuthType = ToolAuthType.WORKSPACE class AgentPromptEntity(BaseModel): diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index 218ffafd55..22fe656379 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -124,6 +124,7 @@ class ToolProviderCredentialApiEntity(BaseModel): default=False, description="Whether the credential is the default credential for the provider in the workspace" ) credentials: Mapping[str, object] = Field(description="The credentials of the provider", default_factory=dict) + expires_at: int = Field(default=-1, description="Unix timestamp when credential expires (-1 for no expiry)") class ToolProviderCredentialInfoApiEntity(BaseModel): diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index f8213d9fd7..a09dd58a53 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -49,6 +49,7 @@ from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProvider from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ( ApiProviderAuthType, + ToolAuthType, ToolInvokeFrom, ToolParameter, ToolProviderType, @@ -58,7 +59,7 @@ from core.tools.tool_label_manager import ToolLabelManager from core.tools.utils.configuration import ToolParameterConfigurationManager from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter from core.tools.workflow_as_tool.tool import WorkflowTool -from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider +from models.tools import ApiToolProvider, BuiltinToolProvider, EndUserAuthenticationProvider, WorkflowToolProvider from services.tools.tools_transform_service import ToolTransformService if TYPE_CHECKING: @@ -78,6 +79,54 @@ class ToolManager: _builtin_providers_loaded = False _builtin_tools_labels: dict[str, Union[I18nObject, None]] = {} + @classmethod + def _refresh_oauth_credentials( + cls, + tenant_id: str, + provider_id: str, + user_id: str, + decrypted_credentials: Mapping[str, Any], + ) -> tuple[dict[str, Any], int]: + """ + Refresh OAuth credentials for a provider. + + This is a helper method to centralize the OAuth token refresh logic + used by both end-user and workspace authentication flows. + + :param tenant_id: the tenant id + :param provider_id: the provider id + :param user_id: the user id (end_user_id or workspace user_id) + :param decrypted_credentials: the current decrypted credentials + + :return: tuple of (refreshed credentials dict, expires_at timestamp) + """ + from core.plugin.impl.oauth import OAuthHandler + + # Local import to avoid circular dependency at module level + # This import is necessary but creates a cycle: tool_manager -> builtin_tools_manage_service -> tool_manager + # TODO: Break the circular dependency by refactoring service layer + from services.tools.builtin_tools_manage_service import BuiltinToolManageService + + # Parse provider ID and build OAuth configuration + tool_provider = ToolProviderID(provider_id) + provider_name = tool_provider.provider_name + redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/tool/callback" + system_credentials = BuiltinToolManageService.get_oauth_client(tenant_id, provider_id) + + # Refresh the credentials using OAuth handler + oauth_handler = OAuthHandler() + refreshed_credentials = oauth_handler.refresh_credentials( + tenant_id=tenant_id, + user_id=user_id, + plugin_id=tool_provider.plugin_id, + provider=provider_name, + redirect_uri=redirect_uri, + system_credentials=system_credentials or {}, + credentials=decrypted_credentials, + ) + + return refreshed_credentials.credentials, refreshed_credentials.expires_at + @classmethod def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController: """ @@ -165,6 +214,8 @@ class ToolManager: invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT, credential_id: str | None = None, + auth_type: ToolAuthType = ToolAuthType.WORKSPACE, + end_user_id: str | None = None, ) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool, MCPTool]: """ get the tool runtime @@ -176,6 +227,8 @@ class ToolManager: :param invoke_from: invoke from :param tool_invoke_from: the tool invoke from :param credential_id: the credential id + :param auth_type: the authentication type (workspace or end_user) + :param end_user_id: the end user id (required when auth_type is END_USER) :return: the tool """ @@ -200,6 +253,75 @@ class ToolManager: ) ), ) + + # Handle end-user authentication + if auth_type == ToolAuthType.END_USER: + if not end_user_id: + raise ToolProviderNotFoundError("end_user_id is required for END_USER auth_type") + + # Query end-user credentials + enduser_provider = ( + db.session.query(EndUserAuthenticationProvider) + .where( + EndUserAuthenticationProvider.tenant_id == tenant_id, + EndUserAuthenticationProvider.end_user_id == end_user_id, + EndUserAuthenticationProvider.provider == provider_id, + ) + .order_by(EndUserAuthenticationProvider.created_at.asc()) + .first() + ) + + if enduser_provider is None: + raise ToolProviderNotFoundError( + f"No end-user credentials found for provider {provider_id}" + ) + + # Decrypt end-user credentials + encrypter, cache = create_provider_encrypter( + tenant_id=tenant_id, + config=[ + x.to_basic_provider_config() + for x in provider_controller.get_credentials_schema_by_type(enduser_provider.credential_type) + ], + cache=ToolProviderCredentialsCache( + tenant_id=tenant_id, provider=provider_id, credential_id=enduser_provider.id + ), + ) + + decrypted_credentials: Mapping[str, Any] = encrypter.decrypt(enduser_provider.credentials) + + # Handle OAuth token refresh for end-users if expired + if enduser_provider.expires_at != -1 and (enduser_provider.expires_at - 60) < int(time.time()): + # Refresh credentials using the centralized helper method + refreshed_credentials, expires_at = cls._refresh_oauth_credentials( + tenant_id=tenant_id, + provider_id=provider_id, + user_id=end_user_id, + decrypted_credentials=decrypted_credentials, + ) + + # Update the provider with refreshed credentials + enduser_provider.encrypted_credentials = json.dumps(encrypter.encrypt(refreshed_credentials)) + enduser_provider.expires_at = expires_at + db.session.commit() + decrypted_credentials = refreshed_credentials + cache.delete() + + return cast( + BuiltinTool, + builtin_tool.fork_tool_runtime( + runtime=ToolRuntime( + tenant_id=tenant_id, + credentials=dict(decrypted_credentials), + credential_type=CredentialType.of(enduser_provider.credential_type), + runtime_parameters={}, + invoke_from=invoke_from, + tool_invoke_from=tool_invoke_from, + ) + ), + ) + + # Handle workspace authentication (existing logic) builtin_provider = None if isinstance(provider_controller, PluginToolProviderController): provider_id_entity = ToolProviderID(provider_id) @@ -270,34 +392,19 @@ class ToolManager: # check if the credentials is expired if builtin_provider.expires_at != -1 and (builtin_provider.expires_at - 60) < int(time.time()): - # TODO: circular import - from core.plugin.impl.oauth import OAuthHandler - from services.tools.builtin_tools_manage_service import BuiltinToolManageService - - # refresh the credentials - tool_provider = ToolProviderID(provider_id) - provider_name = tool_provider.provider_name - redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/tool/callback" - system_credentials = BuiltinToolManageService.get_oauth_client(tenant_id, provider_id) - - oauth_handler = OAuthHandler() - # refresh the credentials - refreshed_credentials = oauth_handler.refresh_credentials( + # Refresh credentials using the centralized helper method + refreshed_credentials, expires_at = cls._refresh_oauth_credentials( tenant_id=tenant_id, + provider_id=provider_id, user_id=builtin_provider.user_id, - plugin_id=tool_provider.plugin_id, - provider=provider_name, - redirect_uri=redirect_uri, - system_credentials=system_credentials or {}, - credentials=decrypted_credentials, + decrypted_credentials=decrypted_credentials, ) - # update the credentials - builtin_provider.encrypted_credentials = json.dumps( - encrypter.encrypt(refreshed_credentials.credentials) - ) - builtin_provider.expires_at = refreshed_credentials.expires_at + + # Update the provider with refreshed credentials + builtin_provider.encrypted_credentials = json.dumps(encrypter.encrypt(refreshed_credentials)) + builtin_provider.expires_at = expires_at db.session.commit() - decrypted_credentials = refreshed_credentials.credentials + decrypted_credentials = refreshed_credentials cache.delete() return cast( @@ -368,6 +475,7 @@ class ToolManager: agent_tool: AgentToolEntity, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, variable_pool: Optional["VariablePool"] = None, + end_user_id: str | None = None, ) -> Tool: """ get the agent tool runtime @@ -380,6 +488,8 @@ class ToolManager: invoke_from=invoke_from, tool_invoke_from=ToolInvokeFrom.AGENT, credential_id=agent_tool.credential_id, + auth_type=agent_tool.auth_type, + end_user_id=end_user_id, ) runtime_parameters = {} parameters = tool_entity.get_merged_runtime_parameters() @@ -410,6 +520,7 @@ class ToolManager: workflow_tool: "ToolEntity", invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, variable_pool: Optional["VariablePool"] = None, + end_user_id: str | None = None, ) -> Tool: """ get the workflow tool runtime @@ -423,6 +534,8 @@ class ToolManager: invoke_from=invoke_from, tool_invoke_from=ToolInvokeFrom.WORKFLOW, credential_id=workflow_tool.credential_id, + auth_type=workflow_tool.auth_type, + end_user_id=end_user_id, ) parameters = tool_runtime.get_merged_runtime_parameters() diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 2e7ec757b4..77ddc67bf2 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -66,6 +66,7 @@ class ToolNode(Node[ToolNodeData]): # get tool runtime try: from core.tools.tool_manager import ToolManager + from models.enums import UserFrom # This is an issue that caused problems before. # Logically, we shouldn't use the node_data.version field for judgment @@ -74,8 +75,20 @@ class ToolNode(Node[ToolNodeData]): variable_pool: VariablePool | None = None if self.node_data.version != "1" or self.node_data.tool_node_version is not None: variable_pool = self.graph_runtime_state.variable_pool + + # Determine end_user_id based on user_from + end_user_id = None + if self.user_from == UserFrom.END_USER: + end_user_id = self.user_id + tool_runtime = ToolManager.get_workflow_tool_runtime( - self.tenant_id, self.app_id, self._node_id, self.node_data, self.invoke_from, variable_pool + self.tenant_id, + self.app_id, + self._node_id, + self.node_data, + self.invoke_from, + variable_pool, + end_user_id, ) except ToolNodeError as e: yield StreamCompletedEvent( diff --git a/api/migrations/versions/2025_12_14_1418-3134f4e0620d_merge_heads.py b/api/migrations/versions/2025_12_14_1418-3134f4e0620d_merge_heads.py new file mode 100644 index 0000000000..bbcfa003f0 --- /dev/null +++ b/api/migrations/versions/2025_12_14_1418-3134f4e0620d_merge_heads.py @@ -0,0 +1,25 @@ +"""merge_heads + +Revision ID: 3134f4e0620d +Revises: d57accd375ae, a7b4e8f2c9d1 +Create Date: 2025-12-14 14:18:19.393720 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '3134f4e0620d' +down_revision = ('d57accd375ae', 'a7b4e8f2c9d1') +branch_labels = None +depends_on = None + + +def upgrade(): + pass + + +def downgrade(): + pass diff --git a/api/services/tools/app_auth_requirement_service.py b/api/services/tools/app_auth_requirement_service.py new file mode 100644 index 0000000000..abee998f58 --- /dev/null +++ b/api/services/tools/app_auth_requirement_service.py @@ -0,0 +1,227 @@ +import logging +from typing import Any + +from sqlalchemy.orm import Session + +from core.tools.tool_manager import ToolManager +from extensions.ext_database import db +from models.workflow import Workflow +from services.tools.enduser_auth_service import EndUserAuthService + +logger = logging.getLogger(__name__) + + +class AppAuthRequirementService: + """ + Service for analyzing authentication requirements in apps. + Examines workflow DSL to identify which providers need end-user authentication. + """ + + @staticmethod + def get_tool_auth_requirements( + app_id: str, + tenant_id: str, + provider_type: str | None = None, + ) -> list[dict[str, Any]]: + """ + Get all authentication requirements for tools in an app. + + :param app_id: The application ID + :param tenant_id: The tenant ID + :param provider_type: Optional filter by provider type (e.g., "tool") + :return: List of provider requirements + """ + try: + # Get latest published workflow for the app + with Session(db.engine, autoflush=False) as session: + workflow = ( + session.query(Workflow) + .filter_by(app_id=app_id, tenant_id=tenant_id) + .order_by(Workflow.created_at.desc()) + .first() + ) + + if not workflow: + return [] + + # Parse workflow graph to find tool nodes + graph = workflow.graph_dict + if not graph or "nodes" not in graph: + return [] + + providers = [] + seen_providers = set() + + # Iterate through workflow nodes + for node in graph.get("nodes", []): + node_data = node.get("data", {}) + node_type = node_data.get("type") + + # Check if it's a tool node + if node_type == "tool": + provider_id = node_data.get("provider_id") + provider_name = node_data.get("provider_name") + tool_name = node_data.get("tool_name") + + if not provider_id: + continue + + # Avoid duplicates + if provider_id in seen_providers: + continue + + seen_providers.add(provider_id) + + # Get provider controller to check authentication requirements + try: + provider_controller = ToolManager.get_builtin_provider(provider_id, tenant_id) + + # Check if provider needs credentials + if not provider_controller.need_credentials: + continue + + # Get supported credential types + supported_types = provider_controller.get_supported_credential_types() + + # Determine required credential type (prefer OAuth if supported) + required_type = None + if supported_types: + # Prefer OAuth2, then API key + from core.plugin.entities.plugin_daemon import CredentialType + + if CredentialType.OAUTH2 in supported_types: + required_type = "oauth2" + elif CredentialType.API_KEY in supported_types: + required_type = "api-key" + else: + required_type = supported_types[0].value + + providers.append( + { + "provider_id": provider_id, + "provider_name": provider_name or provider_id, + "supported_credential_types": [ct.value for ct in supported_types], + "required_credential_type": required_type, + "provider_type": "tool", + "feature_context": { + "node_ids": [node.get("id")], + "tool_names": [tool_name] if tool_name else [], + }, + } + ) + except Exception as e: + logger.warning("Error getting provider info for %s: %s", provider_id, e) + continue + + # Filter by provider_type if specified + if provider_type: + providers = [p for p in providers if p.get("provider_type") == provider_type] + + return providers + except Exception: + logger.exception("Error getting auth requirements for app %s", app_id) + return [] + + @staticmethod + def get_required_providers( + tenant_id: str, + app_id: str, + ) -> list[dict[str, Any]]: + """ + Get list of providers that require end-user authentication for an app. + Simplified version of get_tool_auth_requirements for API use. + + :param tenant_id: The tenant ID + :param app_id: The application ID + :return: List of provider information dictionaries + """ + requirements = AppAuthRequirementService.get_tool_auth_requirements( + app_id=app_id, + tenant_id=tenant_id, + ) + + # Transform to simpler format for API response + return [ + { + "provider_id": req["provider_id"], + "provider_name": req["provider_name"], + "credential_type": req["required_credential_type"], + "is_required": True, + "oauth_config": None if req["required_credential_type"] != "oauth2" else { + "supported_types": req["supported_credential_types"], + }, + } + for req in requirements + ] + + @staticmethod + def get_auth_status( + app_id: str, + end_user_id: str, + tenant_id: str, + ) -> dict[str, Any]: + """ + Get overall authentication status for an app and end user. + Shows which providers are authenticated and which need authentication. + + :param app_id: The application ID + :param end_user_id: The end user ID + :param tenant_id: The tenant ID + :return: Dict with authentication status for all providers + """ + try: + # Get required providers for this app + required_providers = AppAuthRequirementService.get_tool_auth_requirements(app_id, tenant_id) + + # Check authentication status for each provider + provider_statuses = [] + for provider_info in required_providers: + provider_id = provider_info["provider_id"] + + # Get credentials for this provider + credentials = EndUserAuthService.list_credentials(tenant_id, end_user_id, provider_id) + + # Build status + provider_status = { + "provider_id": provider_id, + "provider_name": provider_info["provider_name"], + "provider_type": provider_info["provider_type"], + "authenticated": len(credentials) > 0, + "credentials": [ + { + "credential_id": cred.id, + "name": cred.name, + "type": cred.credential_type.value, + "is_default": cred.is_default, + "expires_at": cred.expires_at, + } + for cred in credentials + ], + } + + provider_statuses.append(provider_status) + + return {"providers": provider_statuses} + except Exception: + logger.exception("Error getting auth status for app %s", app_id) + return {"providers": []} + + @staticmethod + def is_provider_authenticated( + provider_id: str, + end_user_id: str, + tenant_id: str, + ) -> bool: + """ + Check if a specific provider is authenticated for an end user. + + :param provider_id: The provider identifier + :param end_user_id: The end user ID + :param tenant_id: The tenant ID + :return: True if authenticated, False otherwise + """ + try: + credentials = EndUserAuthService.list_credentials(tenant_id, end_user_id, provider_id) + return len(credentials) > 0 + except Exception: + return False diff --git a/api/services/tools/enduser_auth_service.py b/api/services/tools/enduser_auth_service.py new file mode 100644 index 0000000000..bf05918a0e --- /dev/null +++ b/api/services/tools/enduser_auth_service.py @@ -0,0 +1,558 @@ +import json +import logging + +from sqlalchemy import exists, select +from sqlalchemy.orm import Session + +from constants import HIDDEN_VALUE, UNKNOWN_VALUE +from core.helper.name_generator import generate_incremental_name +from core.helper.provider_cache import NoOpProviderCredentialCache +from core.plugin.entities.plugin_daemon import CredentialType +from core.tools.entities.api_entities import ToolProviderCredentialApiEntity +from core.tools.tool_manager import ToolManager +from core.tools.utils.encryption import create_provider_encrypter +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.tools import EndUserAuthenticationProvider +from services.tools.tools_transform_service import ToolTransformService + +logger = logging.getLogger(__name__) + + +class EndUserAuthService: + """ + Service for managing end-user authentication credentials. + Follows similar patterns to BuiltinToolManageService but for end users. + """ + + __MAX_CREDENTIALS_PER_PROVIDER__ = 100 + + @staticmethod + def list_credentials( + tenant_id: str, end_user_id: str, provider_id: str + ) -> list[ToolProviderCredentialApiEntity]: + """ + List all credentials for a specific provider and end user. + + :param end_user_id: The end user ID + :param tenant_id: The tenant ID + :param provider_id: The provider identifier + :return: List of credential entities + """ + with Session(db.engine, autoflush=False) as session: + credentials = ( + session.query(EndUserAuthenticationProvider) + .filter_by(end_user_id=end_user_id, tenant_id=tenant_id, provider=provider_id) + .order_by(EndUserAuthenticationProvider.created_at.asc()) + .all() + ) + + if not credentials: + return [] + + # Get provider controller to access credential schema + provider_controller = ToolManager.get_builtin_provider(provider_id, tenant_id) + + result: list[ToolProviderCredentialApiEntity] = [] + for credential in credentials: + try: + # Create encrypter for masking credentials + encrypter, _ = EndUserAuthService._create_encrypter( + tenant_id, provider_controller, credential.credential_type + ) + + # Decrypt and mask credentials + decrypted = encrypter.decrypt(credential.credentials) + masked_credentials = encrypter.mask_plugin_credentials(decrypted) + + # Convert to API entity + credential_entity = ToolTransformService.convert_enduser_provider_to_credential_entity( + provider=credential, + credentials=dict(masked_credentials), + ) + result.append(credential_entity) + except Exception: + logger.exception("Error processing credential %s", credential.id) + continue + + return result + + @staticmethod + def get_credential( + credential_id: str, end_user_id: str, tenant_id: str, mask_credentials: bool = True + ) -> ToolProviderCredentialApiEntity | None: + """ + Get a specific credential by ID. + + :param credential_id: The credential ID + :param end_user_id: The end user ID + :param tenant_id: The tenant ID + :param mask_credentials: Whether to mask secret fields + :return: Credential entity or None + """ + with Session(db.engine, autoflush=False) as session: + credential = ( + session.query(EndUserAuthenticationProvider) + .filter_by(id=credential_id, end_user_id=end_user_id, tenant_id=tenant_id) + .first() + ) + + if not credential: + return None + + # Get provider controller + provider_controller = ToolManager.get_builtin_provider(credential.provider, tenant_id) + + # Create encrypter + encrypter, _ = EndUserAuthService._create_encrypter( + tenant_id, provider_controller, credential.credential_type + ) + + # Decrypt credentials + decrypted = encrypter.decrypt(credential.credentials) + + # Mask if requested + if mask_credentials: + decrypted = encrypter.mask_plugin_credentials(decrypted) + + # Convert to API entity + return ToolTransformService.convert_enduser_provider_to_credential_entity( + provider=credential, + credentials=dict(decrypted), + ) + + @staticmethod + def create_api_key_credential( + tenant_id: str, + end_user_id: str, + provider_id: str, + credentials: dict, + name: str | None = None, + ) -> ToolProviderCredentialApiEntity: + """ + Create a new API key credential for an end user. + + :param tenant_id: The tenant ID + :param end_user_id: The end user ID + :param provider_id: The provider identifier + :param credentials: The credential data + :param name: Optional custom name + :return: Created credential entity + """ + with Session(db.engine) as session: + try: + lock = f"enduser_credential_create_lock:{end_user_id}_{provider_id}" + with redis_client.lock(lock, timeout=20): + # Get provider controller + provider_controller = ToolManager.get_builtin_provider(provider_id, tenant_id) + if not provider_controller.need_credentials: + raise ValueError(f"Provider {provider_id} does not need credentials") + + # Check credential count + credential_count = ( + session.query(EndUserAuthenticationProvider) + .filter_by(end_user_id=end_user_id, tenant_id=tenant_id, provider=provider_id) + .count() + ) + + if credential_count >= EndUserAuthService.__MAX_CREDENTIALS_PER_PROVIDER__: + raise ValueError( + f"Maximum number of credentials ({EndUserAuthService.__MAX_CREDENTIALS_PER_PROVIDER__}) " + f"reached for provider {provider_id}" + ) + + # Validate credentials + credential_type = CredentialType.API_KEY + if CredentialType.of(credential_type).is_validate_allowed(): + provider_controller.validate_credentials(end_user_id, credentials) + + # Generate name if not provided + if name is None or name == "": + name = EndUserAuthService._generate_credential_name( + session=session, + end_user_id=end_user_id, + tenant_id=tenant_id, + provider=provider_id, + credential_type=credential_type, + ) + else: + # Validate name length + if len(name) > 30: + raise ValueError("Credential name must be 30 characters or less") + + # Check if name is already used + if session.scalar( + select( + exists().where( + EndUserAuthenticationProvider.end_user_id == end_user_id, + EndUserAuthenticationProvider.tenant_id == tenant_id, + EndUserAuthenticationProvider.provider == provider_id, + EndUserAuthenticationProvider.name == name, + ) + ) + ): + raise ValueError(f"The credential name '{name}' is already used") + + # Create encrypter + encrypter, _ = EndUserAuthService._create_encrypter( + tenant_id, provider_controller, credential_type + ) + + # Create credential record + db_credential = EndUserAuthenticationProvider( + tenant_id=tenant_id, + end_user_id=end_user_id, + provider=provider_id, + encrypted_credentials=json.dumps(encrypter.encrypt(credentials)), + credential_type=credential_type, + name=name, + expires_at=-1, # API keys don't expire + ) + + session.add(db_credential) + session.commit() + session.refresh(db_credential) + + # Return masked credentials + masked_credentials = encrypter.mask_plugin_credentials(credentials) + return ToolTransformService.convert_enduser_provider_to_credential_entity( + provider=db_credential, + credentials=dict(masked_credentials), + ) + except Exception as e: + session.rollback() + logger.exception("Error creating API key credential") + raise ValueError(str(e)) + + @staticmethod + def create_oauth_credential( + end_user_id: str, + tenant_id: str, + provider: str, + credentials: dict, + expires_at: int = -1, + name: str | None = None, + ) -> EndUserAuthenticationProvider: + """ + Create a new OAuth credential for an end user. + Used internally by OAuth callback handler. + + :param end_user_id: The end user ID + :param tenant_id: The tenant ID + :param provider: The provider identifier + :param credentials: The OAuth credentials (access_token, refresh_token, etc.) + :param expires_at: Unix timestamp when token expires (-1 for no expiry) + :param name: Optional custom name + :return: Created credential record + """ + with Session(db.engine) as session: + try: + lock = f"enduser_credential_create_lock:{end_user_id}_{provider}" + with redis_client.lock(lock, timeout=20): + # Get provider controller + provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) + + # Check credential count + credential_count = ( + session.query(EndUserAuthenticationProvider) + .filter_by(end_user_id=end_user_id, tenant_id=tenant_id, provider=provider) + .count() + ) + + if credential_count >= EndUserAuthService.__MAX_CREDENTIALS_PER_PROVIDER__: + raise ValueError( + f"Maximum number of credentials ({EndUserAuthService.__MAX_CREDENTIALS_PER_PROVIDER__}) " + f"reached for provider {provider}" + ) + + # Generate name if not provided + credential_type = CredentialType.OAUTH2 + if name is None or name == "": + name = EndUserAuthService._generate_credential_name( + session=session, + end_user_id=end_user_id, + tenant_id=tenant_id, + provider=provider, + credential_type=credential_type, + ) + + # Create encrypter + encrypter, _ = EndUserAuthService._create_encrypter( + tenant_id, provider_controller, credential_type + ) + + # Create credential record + db_credential = EndUserAuthenticationProvider( + tenant_id=tenant_id, + end_user_id=end_user_id, + provider=provider, + encrypted_credentials=json.dumps(encrypter.encrypt(credentials)), + credential_type=credential_type, + name=name, + expires_at=expires_at, + ) + + session.add(db_credential) + session.commit() + session.refresh(db_credential) + + return db_credential + except Exception as e: + session.rollback() + logger.exception("Error creating OAuth credential") + raise ValueError(str(e)) + + @staticmethod + def update_credential( + credential_id: str, + end_user_id: str, + tenant_id: str, + credentials: dict | None = None, + name: str | None = None, + ) -> ToolProviderCredentialApiEntity: + """ + Update an existing credential (API key only). + + :param credential_id: The credential ID + :param end_user_id: The end user ID + :param tenant_id: The tenant ID + :param credentials: Updated credentials (optional) + :param name: Updated name (optional) + :return: Updated credential entity + """ + with Session(db.engine) as session: + try: + # Get credential + db_credential = ( + session.query(EndUserAuthenticationProvider) + .filter_by(id=credential_id, end_user_id=end_user_id, tenant_id=tenant_id) + .first() + ) + + if not db_credential: + raise ValueError(f"Credential {credential_id} not found") + + # Only API key credentials can be updated + if not CredentialType.of(db_credential.credential_type).is_editable(): + raise ValueError("Only API key credentials can be updated via this endpoint") + + # At least one field must be provided + if credentials is None and name is None: + raise ValueError("At least one field (credentials or name) must be provided") + + # Get provider controller + provider_controller = ToolManager.get_builtin_provider(db_credential.provider, tenant_id) + + # Create encrypter + encrypter, _ = EndUserAuthService._create_encrypter( + tenant_id, provider_controller, db_credential.credential_type + ) + + # Update credentials if provided + if credentials: + # Decrypt original credentials + original_credentials = encrypter.decrypt(db_credential.credentials) + + # Merge with new credentials, keeping hidden values + new_credentials: dict = { + key: value if value != HIDDEN_VALUE else original_credentials.get(key, UNKNOWN_VALUE) + for key, value in credentials.items() + } + + # Validate new credentials + if CredentialType.of(db_credential.credential_type).is_validate_allowed(): + provider_controller.validate_credentials(end_user_id, new_credentials) + + # Encrypt and save + db_credential.encrypted_credentials = json.dumps(encrypter.encrypt(new_credentials)) + + # Update name if provided + if name and name != db_credential.name: + # Validate name length + if len(name) > 30: + raise ValueError("Credential name must be 30 characters or less") + + # Check if name is already used + if session.scalar( + select( + exists().where( + EndUserAuthenticationProvider.end_user_id == end_user_id, + EndUserAuthenticationProvider.tenant_id == tenant_id, + EndUserAuthenticationProvider.provider == db_credential.provider, + EndUserAuthenticationProvider.name == name, + ) + ) + ): + raise ValueError(f"The credential name '{name}' is already used") + + db_credential.name = name + + session.commit() + session.refresh(db_credential) + + # Return masked credentials + decrypted = encrypter.decrypt(db_credential.credentials) + masked_credentials = encrypter.mask_plugin_credentials(decrypted) + + return ToolTransformService.convert_enduser_provider_to_credential_entity( + provider=db_credential, + credentials=dict(masked_credentials), + ) + except Exception as e: + session.rollback() + logger.exception("Error updating credential") + raise ValueError(str(e)) + + @staticmethod + def delete_credential( + tenant_id: str, end_user_id: str, provider_id: str, credential_id: str + ) -> bool: + """ + Delete a credential. + + :param credential_id: The credential ID + :param end_user_id: The end user ID + :param tenant_id: The tenant ID + :return: True if deleted successfully + """ + with Session(db.engine) as session: + credential = ( + session.query(EndUserAuthenticationProvider) + .filter_by(id=credential_id, end_user_id=end_user_id, tenant_id=tenant_id) + .first() + ) + + if not credential: + raise ValueError(f"Credential {credential_id} not found") + + session.delete(credential) + session.commit() + return True + + @staticmethod + def refresh_oauth_token( + credential_id: str, end_user_id: str, tenant_id: str, refreshed_credentials: dict, expires_at: int + ) -> EndUserAuthenticationProvider: + """ + Update OAuth credentials after token refresh. + + :param credential_id: The credential ID + :param end_user_id: The end user ID + :param tenant_id: The tenant ID + :param refreshed_credentials: New credentials from OAuth refresh + :param expires_at: New expiration timestamp + :return: Updated credential record + """ + with Session(db.engine) as session: + try: + credential = ( + session.query(EndUserAuthenticationProvider) + .filter_by(id=credential_id, end_user_id=end_user_id, tenant_id=tenant_id) + .first() + ) + + if not credential: + raise ValueError(f"Credential {credential_id} not found") + + if credential.credential_type != CredentialType.OAUTH2: + raise ValueError("Only OAuth credentials can be refreshed") + + # Get provider controller + provider_controller = ToolManager.get_builtin_provider(credential.provider, tenant_id) + + # Create encrypter + encrypter, _ = EndUserAuthService._create_encrypter( + tenant_id, provider_controller, credential.credential_type + ) + + # Encrypt and save new credentials + credential.encrypted_credentials = json.dumps(encrypter.encrypt(refreshed_credentials)) + credential.expires_at = expires_at + + session.commit() + session.refresh(credential) + + return credential + except Exception as e: + session.rollback() + logger.exception("Error refreshing OAuth token") + raise ValueError(str(e)) + + @staticmethod + def get_default_credential( + end_user_id: str, tenant_id: str, provider: str + ) -> EndUserAuthenticationProvider | None: + """ + Get the default (oldest) credential for a provider. + + :param end_user_id: The end user ID + :param tenant_id: The tenant ID + :param provider: The provider identifier + :return: Credential record or None + """ + with Session(db.engine, autoflush=False) as session: + return ( + session.query(EndUserAuthenticationProvider) + .filter_by(end_user_id=end_user_id, tenant_id=tenant_id, provider=provider) + .order_by(EndUserAuthenticationProvider.created_at.asc()) + .first() + ) + + @staticmethod + def _generate_credential_name( + session: Session, + end_user_id: str, + tenant_id: str, + provider: str, + credential_type: CredentialType, + ) -> str: + """ + Generate a unique credential name. + + :param session: Database session + :param end_user_id: The end user ID + :param tenant_id: The tenant ID + :param provider: The provider identifier + :param credential_type: The credential type + :return: Generated name (e.g., "API KEY 1", "AUTH 1") + """ + existing_credentials = ( + session.query(EndUserAuthenticationProvider) + .filter_by( + end_user_id=end_user_id, + tenant_id=tenant_id, + provider=provider, + credential_type=credential_type, + ) + .order_by(EndUserAuthenticationProvider.created_at.desc()) + .all() + ) + + return generate_incremental_name( + [credential.name for credential in existing_credentials], + f"{credential_type.get_name()}", + ) + + @staticmethod + def _create_encrypter( + tenant_id: str, provider_controller, credential_type: CredentialType | str + ) -> tuple: + """ + Create an encrypter for credential encryption/decryption. + + :param tenant_id: The tenant ID + :param provider_controller: The provider controller + :param credential_type: The credential type + :return: Tuple of (encrypter, cache) + """ + if isinstance(credential_type, str): + credential_type = CredentialType.of(credential_type) + + return create_provider_encrypter( + tenant_id=tenant_id, + config=[ + x.to_basic_provider_config() + for x in provider_controller.get_credentials_schema_by_type(credential_type) + ], + cache=NoOpProviderCredentialCache(), + ) diff --git a/api/services/tools/enduser_oauth_service.py b/api/services/tools/enduser_oauth_service.py new file mode 100644 index 0000000000..1dabde90d7 --- /dev/null +++ b/api/services/tools/enduser_oauth_service.py @@ -0,0 +1,264 @@ +import logging +from typing import Any + +from werkzeug import Request + +from configs import dify_config +from core.plugin.entities.plugin_daemon import CredentialType +from core.plugin.impl.oauth import OAuthHandler +from core.tools.tool_manager import ToolManager +from models.provider_ids import ToolProviderID +from services.plugin.oauth_service import OAuthProxyService +from services.tools.builtin_tools_manage_service import BuiltinToolManageService +from services.tools.enduser_auth_service import EndUserAuthService + +logger = logging.getLogger(__name__) + + +class EndUserOAuthService: + """ + Service for managing end-user OAuth authentication flows. + Reuses existing OAuthProxyService and OAuthHandler infrastructure. + """ + + @staticmethod + def get_authorization_url( + end_user_id: str, + tenant_id: str, + app_id: str, + provider: str, + ) -> dict[str, str]: + """ + Initiate OAuth authorization flow for an end user. + + :param end_user_id: The end user ID + :param tenant_id: The tenant ID + :param app_id: The application ID + :param provider: The provider identifier + :return: Dict with authorization_url + """ + try: + # Get OAuth client configuration (reuse workspace-level logic) + oauth_client = BuiltinToolManageService.get_oauth_client(tenant_id, provider) + if not oauth_client: + raise ValueError(f"OAuth client not configured for provider {provider}") + + # Get provider controller + provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) + tool_provider_id = ToolProviderID(provider) + + # Create OAuth context with end-user specific data + context_id = OAuthProxyService.create_proxy_context( + user_id=end_user_id, # Using end_user_id as user_id + tenant_id=tenant_id, + plugin_id=tool_provider_id.plugin_id, + provider=tool_provider_id.provider_name, + extra_data={ + "app_id": app_id, + "provider_type": "tool", # For now, only tools support end-user auth + }, + ) + + # Use the same redirect URI as workspace OAuth to reuse the same OAuth client + redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{tool_provider_id}/tool/callback" + + # Get authorization URL from OAuth handler + oauth_handler = OAuthHandler() + response = oauth_handler.get_authorization_url( + tenant_id=tenant_id, + user_id=end_user_id, + plugin_id=tool_provider_id.plugin_id, + provider=tool_provider_id.provider_name, + redirect_uri=redirect_uri, + system_credentials=oauth_client, + ) + + return { + "authorization_url": response.authorization_url, + "context_id": context_id, # Return for setting cookie + } + except Exception as e: + logger.exception("Error getting authorization URL for end user") + raise ValueError(f"Failed to initiate OAuth flow: {str(e)}") + + @staticmethod + def handle_oauth_callback( + context_id: str, + request: Request, + ) -> dict[str, Any]: + """ + Handle OAuth callback and create credential. + + :param context_id: The OAuth context ID from cookie + :param request: The callback request with authorization code + :return: Dict with credential information + """ + try: + # Validate and retrieve context + context = OAuthProxyService.use_proxy_context(context_id) + + # Extract context data + end_user_id = context.get("user_id") # user_id is actually end_user_id + tenant_id = context.get("tenant_id") + app_id = context.get("app_id") + plugin_id = context.get("plugin_id") + provider = context.get("provider") + + if not all([end_user_id, tenant_id, app_id, plugin_id, provider]): + raise ValueError("Invalid OAuth context: missing required fields") + + # Reconstruct full provider ID + full_provider = f"{plugin_id}/{provider}" if plugin_id != "langgenius" else provider + + # Get OAuth client configuration + oauth_client = BuiltinToolManageService.get_oauth_client(tenant_id, full_provider) + if not oauth_client: + raise ValueError(f"OAuth client not configured for provider {full_provider}") + + # Use the same redirect URI as workspace OAuth (must match authorization request) + redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{full_provider}/tool/callback" + + # Exchange authorization code for credentials + oauth_handler = OAuthHandler() + credentials_response = oauth_handler.get_credentials( + tenant_id=tenant_id, + user_id=end_user_id, + plugin_id=plugin_id, + provider=provider, + redirect_uri=redirect_uri, + system_credentials=oauth_client, + request=request, + ) + + # Calculate expiration timestamp + expires_at = -1 + if credentials_response.expires_in and credentials_response.expires_in > 0: + import time + + expires_at = int(time.time()) + credentials_response.expires_in + + # Create credential in database + credential = EndUserAuthService.create_oauth_credential( + end_user_id=end_user_id, + tenant_id=tenant_id, + provider=full_provider, + credentials=credentials_response.credentials, + expires_at=expires_at, + ) + + return { + "success": True, + "credential_id": credential.id, + "provider": full_provider, + "app_id": app_id, + } + except Exception as e: + logger.exception("Error handling OAuth callback for end user") + raise ValueError(f"Failed to complete OAuth flow: {str(e)}") + + @staticmethod + def refresh_oauth_token( + credential_id: str, + end_user_id: str, + tenant_id: str, + ) -> dict[str, Any]: + """ + Refresh an expired OAuth token. + + :param credential_id: The credential ID + :param end_user_id: The end user ID + :param tenant_id: The tenant ID + :return: Dict with refresh status + """ + try: + # Get existing credential + credential = EndUserAuthService.get_credential( + credential_id=credential_id, + end_user_id=end_user_id, + tenant_id=tenant_id, + mask_credentials=False, # Need full credentials for refresh + ) + + if not credential: + raise ValueError(f"Credential {credential_id} not found") + + if credential.credential_type != CredentialType.OAUTH2: + raise ValueError("Only OAuth credentials can be refreshed") + + # Get OAuth client configuration + oauth_client = BuiltinToolManageService.get_oauth_client(tenant_id, credential.provider) + if not oauth_client: + raise ValueError(f"OAuth client not configured for provider {credential.provider}") + + # Get provider info + tool_provider_id = ToolProviderID(credential.provider) + # Use the same redirect URI as workspace OAuth to reuse the same OAuth client + redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{credential.provider}/tool/callback" + + # Refresh credentials via OAuth handler + oauth_handler = OAuthHandler() + refreshed_response = oauth_handler.refresh_credentials( + tenant_id=tenant_id, + user_id=end_user_id, + plugin_id=tool_provider_id.plugin_id, + provider=tool_provider_id.provider_name, + redirect_uri=redirect_uri, + system_credentials=oauth_client, + credentials=credential.credentials, + ) + + # Calculate new expiration timestamp + expires_at = -1 + if refreshed_response.expires_in and refreshed_response.expires_in > 0: + import time + + expires_at = int(time.time()) + refreshed_response.expires_in + + # Update credential in database + updated_credential = EndUserAuthService.refresh_oauth_token( + credential_id=credential_id, + end_user_id=end_user_id, + tenant_id=tenant_id, + refreshed_credentials=refreshed_response.credentials, + expires_at=expires_at, + ) + + return { + "success": True, + "credential_id": updated_credential.id, + "expires_at": expires_at, + "refreshed_at": int(updated_credential.updated_at.timestamp()), + } + except Exception: + logger.exception("Error refreshing OAuth token for end user") + return { + "success": False, + "error": "Failed to refresh token", + } + + @staticmethod + def get_oauth_client_info(tenant_id: str, provider: str) -> dict[str, Any]: + """ + Get OAuth client information for a provider. + Used to check if OAuth is available and configured. + + :param tenant_id: The tenant ID + :param provider: The provider identifier + :return: Dict with OAuth client info + """ + try: + # Check if OAuth client exists (either system or custom) + oauth_client = BuiltinToolManageService.get_oauth_client(tenant_id, provider) + + return { + "configured": oauth_client is not None, + "system_configured": BuiltinToolManageService.is_oauth_system_client_exists(provider), + "custom_configured": BuiltinToolManageService.is_oauth_custom_client_enabled(tenant_id, provider), + } + except Exception: + logger.exception("Error getting OAuth client info") + return { + "configured": False, + "system_configured": False, + "custom_configured": False, + } diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index e323b3cda9..3e3edfedb4 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -426,6 +426,27 @@ class ToolTransformService: credentials=credentials, ) + @staticmethod + def convert_enduser_provider_to_credential_entity( + provider, credentials: dict + ) -> ToolProviderCredentialApiEntity: + """ + Convert EndUserAuthenticationProvider to ToolProviderCredentialApiEntity. + + :param provider: EndUserAuthenticationProvider instance + :param credentials: Decrypted/masked credentials dict + :return: ToolProviderCredentialApiEntity + """ + return ToolProviderCredentialApiEntity( + id=provider.id, + name=provider.name, + provider=provider.provider, + credential_type=CredentialType.of(provider.credential_type), + is_default=False, # End-user credentials don't have default flag (use oldest) + credentials=credentials, + expires_at=provider.expires_at, + ) + @staticmethod def convert_mcp_schema_to_parameter(schema: dict[str, Any]) -> list["ToolParameter"]: """ diff --git a/api/tests/unit_tests/core/tools/test_enduser_tool_auth.py b/api/tests/unit_tests/core/tools/test_enduser_tool_auth.py new file mode 100644 index 0000000000..3d322b3f76 --- /dev/null +++ b/api/tests/unit_tests/core/tools/test_enduser_tool_auth.py @@ -0,0 +1,486 @@ +""" +Unit tests for end-user tool authentication. + +Tests the integration of end-user authentication with tool runtime resolution. +""" + +import time +from unittest.mock import MagicMock, patch + +import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.tools.entities.tool_entities import ToolAuthType, ToolInvokeFrom, ToolProviderType +from core.tools.errors import ToolProviderNotFoundError +from core.tools.tool_manager import ToolManager + + +@pytest.fixture +def mock_db_session(): + """Mock database session.""" + with patch("core.tools.tool_manager.db") as mock_db: + yield mock_db.session + + +@pytest.fixture +def mock_provider_controller(): + """Mock builtin provider controller.""" + controller = MagicMock() + controller.need_credentials = True + controller.get_tool = MagicMock(return_value=MagicMock()) + controller.get_credentials_schema_by_type = MagicMock(return_value=[]) + return controller + + +class TestEndUserToolAuthentication: + """Test suite for end-user tool authentication.""" + + def test_end_user_auth_requires_end_user_id(self, mock_db_session, mock_provider_controller): + """ + Test that END_USER auth_type requires end_user_id parameter. + + When auth_type is END_USER but end_user_id is None, should raise error. + """ + with patch.object(ToolManager, "get_builtin_provider", return_value=mock_provider_controller): + with pytest.raises(ToolProviderNotFoundError, match="end_user_id is required"): + ToolManager.get_tool_runtime( + provider_type=ToolProviderType.BUILT_IN, + provider_id="test_provider", + tool_name="test_tool", + tenant_id="test_tenant", + invoke_from=InvokeFrom.SERVICE_API, + tool_invoke_from=ToolInvokeFrom.WORKFLOW, + auth_type=ToolAuthType.END_USER, + end_user_id=None, # Missing! + ) + + def test_end_user_auth_missing_credentials(self, mock_db_session, mock_provider_controller): + """ + Test that error is raised when end-user has no credentials for provider. + + When auth_type is END_USER but no credentials exist, should raise error. + """ + # Mock no credentials found + mock_db_session.query.return_value.where.return_value.order_by.return_value.first.return_value = None + + with patch.object(ToolManager, "get_builtin_provider", return_value=mock_provider_controller): + with pytest.raises(ToolProviderNotFoundError, match="No end-user credentials found"): + ToolManager.get_tool_runtime( + provider_type=ToolProviderType.BUILT_IN, + provider_id="test_provider", + tool_name="test_tool", + tenant_id="test_tenant", + invoke_from=InvokeFrom.SERVICE_API, + tool_invoke_from=ToolInvokeFrom.WORKFLOW, + auth_type=ToolAuthType.END_USER, + end_user_id="end_user_123", + ) + + def test_end_user_auth_with_credentials(self, mock_db_session, mock_provider_controller): + """ + Test successful end-user credential resolution. + + When auth_type is END_USER and credentials exist, should return tool runtime. + """ + # Mock end-user provider + mock_enduser_provider = MagicMock() + mock_enduser_provider.id = "cred_123" + mock_enduser_provider.credential_type = "api-key" + mock_enduser_provider.credentials = '{"api_key": "encrypted"}' + mock_enduser_provider.expires_at = -1 # No expiry + + mock_db_session.query.return_value.where.return_value.order_by.return_value.first.return_value = ( + mock_enduser_provider + ) + + # Mock encrypter + mock_encrypter = MagicMock() + mock_encrypter.decrypt.return_value = {"api_key": "decrypted_key"} + mock_cache = MagicMock() + + with ( + patch.object(ToolManager, "get_builtin_provider", return_value=mock_provider_controller), + patch("core.tools.tool_manager.create_provider_encrypter", return_value=(mock_encrypter, mock_cache)), + ): + tool_runtime = ToolManager.get_tool_runtime( + provider_type=ToolProviderType.BUILT_IN, + provider_id="test_provider", + tool_name="test_tool", + tenant_id="test_tenant", + invoke_from=InvokeFrom.SERVICE_API, + tool_invoke_from=ToolInvokeFrom.WORKFLOW, + auth_type=ToolAuthType.END_USER, + end_user_id="end_user_123", + ) + + # Verify tool runtime was created + assert tool_runtime is not None + # Verify encrypter was called with decrypted credentials + mock_encrypter.decrypt.assert_called_once() + + def test_workspace_auth_backward_compatibility(self, mock_db_session, mock_provider_controller): + """ + Test that workspace authentication still works (backward compatibility). + + When auth_type is WORKSPACE (default), should use workspace credentials. + """ + # Mock workspace provider + mock_workspace_provider = MagicMock() + mock_workspace_provider.id = "workspace_cred_123" + mock_workspace_provider.credential_type = "api-key" + mock_workspace_provider.credentials = '{"api_key": "workspace_encrypted"}' + mock_workspace_provider.expires_at = -1 + + mock_db_session.query.return_value.where.return_value.order_by.return_value.first.return_value = ( + mock_workspace_provider + ) + + # Mock encrypter + mock_encrypter = MagicMock() + mock_encrypter.decrypt.return_value = {"api_key": "workspace_decrypted"} + mock_cache = MagicMock() + + with ( + patch.object(ToolManager, "get_builtin_provider", return_value=mock_provider_controller), + patch("core.tools.tool_manager.create_provider_encrypter", return_value=(mock_encrypter, mock_cache)), + patch("core.helper.credential_utils.check_credential_policy_compliance"), + ): + tool_runtime = ToolManager.get_tool_runtime( + provider_type=ToolProviderType.BUILT_IN, + provider_id="test_provider", + tool_name="test_tool", + tenant_id="test_tenant", + invoke_from=InvokeFrom.SERVICE_API, + tool_invoke_from=ToolInvokeFrom.WORKFLOW, + auth_type=ToolAuthType.WORKSPACE, # Workspace auth + end_user_id=None, # Not needed for workspace auth + ) + + # Verify tool runtime was created + assert tool_runtime is not None + + def test_workflow_tool_runtime_passes_end_user_id(self, mock_db_session, mock_provider_controller): + """ + Test that get_workflow_tool_runtime correctly passes end_user_id to get_tool_runtime. + """ + from core.workflow.nodes.tool.entities import ToolEntity + + # Create a mock ToolEntity with END_USER auth_type + workflow_tool = MagicMock(spec=ToolEntity) + workflow_tool.provider_type = ToolProviderType.BUILT_IN + workflow_tool.provider_id = "test_provider" + workflow_tool.tool_name = "test_tool" + workflow_tool.credential_id = None + workflow_tool.auth_type = ToolAuthType.END_USER + workflow_tool.tool_configurations = {} + + # Mock end-user credentials + mock_enduser_provider = MagicMock() + mock_enduser_provider.id = "cred_123" + mock_enduser_provider.credential_type = "api-key" + mock_enduser_provider.credentials = '{"api_key": "encrypted"}' + mock_enduser_provider.expires_at = -1 + + mock_db_session.query.return_value.where.return_value.order_by.return_value.first.return_value = ( + mock_enduser_provider + ) + + mock_encrypter = MagicMock() + mock_encrypter.decrypt.return_value = {"api_key": "decrypted"} + mock_cache = MagicMock() + + with ( + patch.object(ToolManager, "get_builtin_provider", return_value=mock_provider_controller), + patch("core.tools.tool_manager.create_provider_encrypter", return_value=(mock_encrypter, mock_cache)), + ): + tool_runtime = ToolManager.get_workflow_tool_runtime( + tenant_id="test_tenant", + app_id="test_app", + node_id="test_node", + workflow_tool=workflow_tool, + invoke_from=InvokeFrom.SERVICE_API, + variable_pool=None, + end_user_id="end_user_123", # Pass end_user_id + ) + + # Verify tool runtime was created + assert tool_runtime is not None + + +class TestOAuthTokenRefresh: + """Test suite for OAuth token refresh functionality.""" + + def test_enduser_oauth_token_refresh_when_expired(self, mock_db_session, mock_provider_controller): + """ + Test that end-user OAuth tokens are automatically refreshed when expired. + + When an OAuth token is expired (expires_at < current_time + 60s buffer), + the system should automatically refresh it before using. + """ + # Mock end-user provider with expired OAuth token + mock_enduser_provider = MagicMock() + mock_enduser_provider.id = "cred_123" + mock_enduser_provider.credential_type = "oauth2" + mock_enduser_provider.credentials = '{"access_token": "old_token", "refresh_token": "refresh"}' + # Set expiry to past (token expired) + mock_enduser_provider.expires_at = int(time.time()) - 100 + mock_enduser_provider.encrypted_credentials = '{"access_token": "old_token", "refresh_token": "refresh"}' + + mock_db_session.query.return_value.where.return_value.order_by.return_value.first.return_value = ( + mock_enduser_provider + ) + + # Mock encrypter + mock_encrypter = MagicMock() + mock_encrypter.decrypt.return_value = {"access_token": "old_token", "refresh_token": "refresh"} + mock_encrypter.encrypt.return_value = {"access_token": "new_token", "refresh_token": "refresh"} + mock_cache = MagicMock() + + # Mock OAuth refresh response + mock_refreshed_credentials = MagicMock() + mock_refreshed_credentials.credentials = {"access_token": "new_token", "refresh_token": "refresh"} + mock_refreshed_credentials.expires_at = int(time.time()) + 3600 # New expiry 1 hour from now + + with ( + patch.object(ToolManager, "get_builtin_provider", return_value=mock_provider_controller), + patch("core.tools.tool_manager.create_provider_encrypter", return_value=(mock_encrypter, mock_cache)), + patch.object( + ToolManager, + "_refresh_oauth_credentials", + return_value=( + mock_refreshed_credentials.credentials, + mock_refreshed_credentials.expires_at, + ), + ) as mock_refresh, + ): + tool_runtime = ToolManager.get_tool_runtime( + provider_type=ToolProviderType.BUILT_IN, + provider_id="github", + tool_name="test_tool", + tenant_id="test_tenant", + invoke_from=InvokeFrom.SERVICE_API, + tool_invoke_from=ToolInvokeFrom.WORKFLOW, + auth_type=ToolAuthType.END_USER, + end_user_id="end_user_123", + ) + + # Verify refresh was called + mock_refresh.assert_called_once_with( + tenant_id="test_tenant", + provider_id="github", + user_id="end_user_123", + decrypted_credentials={"access_token": "old_token", "refresh_token": "refresh"}, + ) + + # Verify provider was updated with new credentials + assert mock_enduser_provider.expires_at == mock_refreshed_credentials.expires_at + mock_db_session.commit.assert_called_once() + mock_cache.delete.assert_called_once() + + # Verify tool runtime was created + assert tool_runtime is not None + + def test_enduser_oauth_token_not_refreshed_when_valid(self, mock_db_session, mock_provider_controller): + """ + Test that valid OAuth tokens are NOT refreshed. + + When an OAuth token is still valid (expires_at > current_time + 60s buffer), + the system should use it without refreshing. + """ + # Mock end-user provider with valid OAuth token + mock_enduser_provider = MagicMock() + mock_enduser_provider.id = "cred_123" + mock_enduser_provider.credential_type = "oauth2" + mock_enduser_provider.credentials = '{"access_token": "valid_token", "refresh_token": "refresh"}' + # Set expiry to future (token still valid with buffer) + mock_enduser_provider.expires_at = int(time.time()) + 3600 # Valid for 1 hour + + mock_db_session.query.return_value.where.return_value.order_by.return_value.first.return_value = ( + mock_enduser_provider + ) + + # Mock encrypter + mock_encrypter = MagicMock() + mock_encrypter.decrypt.return_value = {"access_token": "valid_token", "refresh_token": "refresh"} + mock_cache = MagicMock() + + with ( + patch.object(ToolManager, "get_builtin_provider", return_value=mock_provider_controller), + patch("core.tools.tool_manager.create_provider_encrypter", return_value=(mock_encrypter, mock_cache)), + patch.object(ToolManager, "_refresh_oauth_credentials") as mock_refresh, + ): + tool_runtime = ToolManager.get_tool_runtime( + provider_type=ToolProviderType.BUILT_IN, + provider_id="github", + tool_name="test_tool", + tenant_id="test_tenant", + invoke_from=InvokeFrom.SERVICE_API, + tool_invoke_from=ToolInvokeFrom.WORKFLOW, + auth_type=ToolAuthType.END_USER, + end_user_id="end_user_123", + ) + + # Verify refresh was NOT called (token still valid) + mock_refresh.assert_not_called() + + # Verify tool runtime was created with original credentials + assert tool_runtime is not None + + def test_workspace_oauth_token_refresh_when_expired(self, mock_db_session, mock_provider_controller): + """ + Test that workspace OAuth tokens are automatically refreshed when expired. + + This ensures the refactored _refresh_oauth_credentials helper works + for both end-user and workspace authentication flows. + """ + # Mock workspace provider with expired OAuth token + mock_workspace_provider = MagicMock() + mock_workspace_provider.id = "workspace_cred_123" + mock_workspace_provider.user_id = "workspace_user_456" + mock_workspace_provider.credential_type = "oauth2" + mock_workspace_provider.credentials = '{"access_token": "old_workspace_token"}' + mock_workspace_provider.expires_at = int(time.time()) - 100 # Expired + mock_workspace_provider.encrypted_credentials = '{"access_token": "old_workspace_token"}' + + mock_db_session.query.return_value.where.return_value.order_by.return_value.first.return_value = ( + mock_workspace_provider + ) + + # Mock encrypter + mock_encrypter = MagicMock() + mock_encrypter.decrypt.return_value = {"access_token": "old_workspace_token"} + mock_encrypter.encrypt.return_value = {"access_token": "new_workspace_token"} + mock_cache = MagicMock() + + # Mock OAuth refresh response + refreshed_creds = {"access_token": "new_workspace_token"} + new_expires_at = int(time.time()) + 3600 + + with ( + patch.object(ToolManager, "get_builtin_provider", return_value=mock_provider_controller), + patch("core.tools.tool_manager.create_provider_encrypter", return_value=(mock_encrypter, mock_cache)), + patch("core.helper.credential_utils.check_credential_policy_compliance"), + patch.object( + ToolManager, "_refresh_oauth_credentials", return_value=(refreshed_creds, new_expires_at) + ) as mock_refresh, + ): + tool_runtime = ToolManager.get_tool_runtime( + provider_type=ToolProviderType.BUILT_IN, + provider_id="github", + tool_name="test_tool", + tenant_id="test_tenant", + invoke_from=InvokeFrom.SERVICE_API, + tool_invoke_from=ToolInvokeFrom.WORKFLOW, + auth_type=ToolAuthType.WORKSPACE, + ) + + # Verify refresh was called with workspace user_id + mock_refresh.assert_called_once_with( + tenant_id="test_tenant", + provider_id="github", + user_id="workspace_user_456", + decrypted_credentials={"access_token": "old_workspace_token"}, + ) + + # Verify provider was updated + assert mock_workspace_provider.expires_at == new_expires_at + mock_db_session.commit.assert_called_once() + mock_cache.delete.assert_called_once() + + # Verify tool runtime was created + assert tool_runtime is not None + + def test_oauth_token_no_refresh_for_non_oauth_credentials(self, mock_db_session, mock_provider_controller): + """ + Test that non-OAuth credentials (API keys) are never refreshed. + + API keys with expires_at = -1 should not trigger refresh logic. + """ + # Mock end-user provider with API key (no expiry) + mock_enduser_provider = MagicMock() + mock_enduser_provider.id = "cred_123" + mock_enduser_provider.credential_type = "api-key" + mock_enduser_provider.credentials = '{"api_key": "sk-1234567890"}' + mock_enduser_provider.expires_at = -1 # API keys don't expire + + mock_db_session.query.return_value.where.return_value.order_by.return_value.first.return_value = ( + mock_enduser_provider + ) + + # Mock encrypter + mock_encrypter = MagicMock() + mock_encrypter.decrypt.return_value = {"api_key": "sk-1234567890"} + mock_cache = MagicMock() + + with ( + patch.object(ToolManager, "get_builtin_provider", return_value=mock_provider_controller), + patch("core.tools.tool_manager.create_provider_encrypter", return_value=(mock_encrypter, mock_cache)), + patch.object(ToolManager, "_refresh_oauth_credentials") as mock_refresh, + ): + tool_runtime = ToolManager.get_tool_runtime( + provider_type=ToolProviderType.BUILT_IN, + provider_id="openai", + tool_name="test_tool", + tenant_id="test_tenant", + invoke_from=InvokeFrom.SERVICE_API, + tool_invoke_from=ToolInvokeFrom.WORKFLOW, + auth_type=ToolAuthType.END_USER, + end_user_id="end_user_123", + ) + + # Verify refresh was NOT called (API key doesn't need refresh) + mock_refresh.assert_not_called() + + # Verify tool runtime was created + assert tool_runtime is not None + + def test_refresh_oauth_credentials_helper_method(self): + """ + Test the _refresh_oauth_credentials helper method directly. + + This tests the centralized OAuth refresh logic that is used by both + end-user and workspace authentication flows. + """ + # Mock dependencies + mock_oauth_handler = MagicMock() + mock_refreshed = MagicMock() + mock_refreshed.credentials = {"access_token": "new_token", "refresh_token": "new_refresh"} + mock_refreshed.expires_at = int(time.time()) + 7200 + mock_oauth_handler.refresh_credentials.return_value = mock_refreshed + + with ( + # Patch OAuthHandler where it's imported (inside the method) + patch("core.plugin.impl.oauth.OAuthHandler", return_value=mock_oauth_handler), + patch("core.tools.tool_manager.ToolProviderID") as mock_provider_id, + patch( + "services.tools.builtin_tools_manage_service.BuiltinToolManageService.get_oauth_client", + return_value={"client_id": "test"}, + ), + patch("core.tools.tool_manager.dify_config.CONSOLE_API_URL", "http://localhost:5001"), + ): + # Setup provider ID mock + mock_provider_id.return_value.provider_name = "github" + mock_provider_id.return_value.plugin_id = "builtin" + + # Call the helper method + credentials, expires_at = ToolManager._refresh_oauth_credentials( + tenant_id="test_tenant", + provider_id="langgenius/github/github", + user_id="user_123", + decrypted_credentials={"access_token": "old_token", "refresh_token": "old_refresh"}, + ) + + # Verify OAuth handler was called correctly + mock_oauth_handler.refresh_credentials.assert_called_once_with( + tenant_id="test_tenant", + user_id="user_123", + plugin_id="builtin", + provider="github", + redirect_uri="http://localhost:5001/console/api/oauth/plugin/langgenius/github/github/tool/callback", + system_credentials={"client_id": "test"}, + credentials={"access_token": "old_token", "refresh_token": "old_refresh"}, + ) + + # Verify returned values + assert credentials == {"access_token": "new_token", "refresh_token": "new_refresh"} + assert expires_at == mock_refreshed.expires_at diff --git a/web/app/components/plugins/plugin-detail-panel/tool-selector/index.tsx b/web/app/components/plugins/plugin-detail-panel/tool-selector/index.tsx index 839ac7de51..ce166f6e97 100644 --- a/web/app/components/plugins/plugin-detail-panel/tool-selector/index.tsx +++ b/web/app/components/plugins/plugin-detail-panel/tool-selector/index.tsx @@ -214,6 +214,7 @@ const ToolSelector: FC = ({ onSelect({ ...value, use_end_user_credentials: enabled, + auth_type: enabled ? 'end_user' : 'workspace', } as any) } const handleEndUserCredentialTypeChange = (type: string) => { diff --git a/web/app/components/workflow/nodes/tool/types.ts b/web/app/components/workflow/nodes/tool/types.ts index da3b7f7b31..05ee59a10b 100644 --- a/web/app/components/workflow/nodes/tool/types.ts +++ b/web/app/components/workflow/nodes/tool/types.ts @@ -24,4 +24,5 @@ export type ToolNodeType = CommonNodeType & { provider_icon?: Collection['icon'] provider_icon_dark?: Collection['icon_dark'] plugin_unique_identifier?: string + auth_type?: 'workspace' | 'end_user' }