From adc39f7b0db283e514d930be9a29c0345eb5a906 Mon Sep 17 00:00:00 2001 From: Harry Date: Fri, 11 Jul 2025 16:28:40 +0800 Subject: [PATCH] feat(oauth): enhance OAuth client management and validation --- ...2025_07_04_1705-71f5020c6470_tool_oauth.py | 1 + api/models/tools.py | 18 +++++---- .../tools/builtin_tools_manage_service.py | 40 ++++++++++++++++--- api/services/tools/mcp_tools_mange_service.py | 7 ++-- 4 files changed, 50 insertions(+), 16 deletions(-) diff --git a/api/migrations/versions/2025_07_04_1705-71f5020c6470_tool_oauth.py b/api/migrations/versions/2025_07_04_1705-71f5020c6470_tool_oauth.py index 32cc08ab1a..ad73563246 100644 --- a/api/migrations/versions/2025_07_04_1705-71f5020c6470_tool_oauth.py +++ b/api/migrations/versions/2025_07_04_1705-71f5020c6470_tool_oauth.py @@ -44,6 +44,7 @@ def upgrade(): batch_op.add_column(sa.Column('credential_type', sa.String(length=32), server_default=sa.text("'api-key'::character varying"), nullable=False)) batch_op.drop_constraint(batch_op.f('unique_builtin_tool_provider'), type_='unique') batch_op.create_unique_constraint(batch_op.f('unique_builtin_tool_provider'), ['tenant_id', 'provider', 'name']) + # ### end Alembic commands ### diff --git a/api/models/tools.py b/api/models/tools.py index 05a4920a9c..34bc97d006 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -109,7 +109,10 @@ class ApiToolProvider(Base): """ __tablename__ = "tool_api_providers" - __table_args__ = (db.PrimaryKeyConstraint("id", name="tool_api_provider_pkey"),) + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="tool_api_provider_pkey"), + db.UniqueConstraint("name", "tenant_id", name="unique_api_tool_provider"), + ) id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) # name of the api provider @@ -326,18 +329,17 @@ class MCPToolProvider(Base): @property def decrypted_credentials(self) -> dict: + from core.helper.provider_cache import NoOpProviderCredentialCache from core.tools.mcp_tool.provider import MCPToolProviderController - from core.tools.utils.configuration import ProviderConfigEncrypter + from core.tools.utils.encryption import create_provider_encrypter provider_controller = MCPToolProviderController._from_db(self) - tool_configuration = ProviderConfigEncrypter( + return create_provider_encrypter( tenant_id=self.tenant_id, - config=list(provider_controller.get_credentials_schema()), - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.provider_id, - ) - return tool_configuration.decrypt(self.credentials, use_cache=False) + config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], + cache=NoOpProviderCredentialCache(), + )[0].decrypt(self.credentials) class ToolModelInvoke(Base): diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index fea74ba492..66157fb6b6 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -43,12 +43,22 @@ class BuiltinToolManageService: get builtin tool provider oauth client schema """ provider = ToolManager.get_builtin_provider(provider_name, tenant_id) - return { + + is_oauth_custom_client_enabled = BuiltinToolManageService.is_oauth_custom_client_enabled( + tenant_id, provider_name + ) + is_system_oauth_params_exists = BuiltinToolManageService.is_oauth_system_client_exists(provider_name) + result = { "schema": provider.get_oauth_client_schema(), - "is_oauth_custom_client_enabled": BuiltinToolManageService.is_oauth_custom_client_enabled( - tenant_id, provider_name - ), + "is_oauth_custom_client_enabled": is_oauth_custom_client_enabled, + "is_system_oauth_params_exists": is_system_oauth_params_exists, } + if is_oauth_custom_client_enabled: + result["client_params"] = BuiltinToolManageService.get_oauth_client(tenant_id, provider_name) + result["redirect_uri"] = ( + f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_name}/tool/callback" + ) + return result @staticmethod def list_builtin_tool_provider_tools(tenant_id: str, provider: str) -> list[ToolApiEntity]: @@ -415,6 +425,20 @@ class BuiltinToolManageService: session.commit() return {"result": "success"} + @staticmethod + def is_oauth_system_client_exists(provider_name: str) -> bool: + """ + check if oauth system client exists + """ + tool_provider = ToolProviderID(provider_name) + with Session(db.engine).no_autoflush as session: + system_client: ToolOAuthSystemClient | None = ( + session.query(ToolOAuthSystemClient) + .filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name) + .first() + ) + return system_client is not None + @staticmethod def is_oauth_custom_client_enabled(tenant_id: str, provider: str) -> bool: """ @@ -685,4 +709,10 @@ class BuiltinToolManageService: config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()], cache=NoOpProviderCredentialCache(), ) - return encrypter.mask_tool_credentials(encrypter.decrypt(custom_oauth_client_params.oauth_params)) + + return { + "oauth_params": encrypter.mask_tool_credentials( + encrypter.decrypt(custom_oauth_client_params.oauth_params) + ), + "enabled": custom_oauth_client_params.enabled, + } diff --git a/api/services/tools/mcp_tools_mange_service.py b/api/services/tools/mcp_tools_mange_service.py index 7c23abda4b..fda6da5983 100644 --- a/api/services/tools/mcp_tools_mange_service.py +++ b/api/services/tools/mcp_tools_mange_service.py @@ -7,13 +7,14 @@ from sqlalchemy import or_ from sqlalchemy.exc import IntegrityError from core.helper import encrypter +from core.helper.provider_cache import NoOpProviderCredentialCache from core.mcp.error import MCPAuthError, MCPError from core.mcp.mcp_client import MCPClient from core.tools.entities.api_entities import ToolProviderApiEntity from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderType from core.tools.mcp_tool.provider import MCPToolProviderController -from core.tools.utils.configuration import ProviderConfigEncrypter +from core.tools.utils.encryption import ProviderConfigEncrypter from extensions.ext_database import db from models.tools import MCPToolProvider from services.tools.tools_transform_service import ToolTransformService @@ -69,6 +70,7 @@ class MCPToolManageService: MCPToolProvider.server_url_hash == server_url_hash, MCPToolProvider.server_identifier == server_identifier, ), + MCPToolProvider.tenant_id == tenant_id, ) .first() ) @@ -197,8 +199,7 @@ class MCPToolManageService: tool_configuration = ProviderConfigEncrypter( tenant_id=mcp_provider.tenant_id, config=list(provider_controller.get_credentials_schema()), - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.provider_id, + provider_config_cache=NoOpProviderCredentialCache(), ) credentials = tool_configuration.encrypt(credentials) mcp_provider.updated_at = datetime.now()