feat(oauth): enhance OAuth client management and validation

This commit is contained in:
Harry 2025-07-11 16:28:40 +08:00
parent fb9e4a4227
commit adc39f7b0d
4 changed files with 50 additions and 16 deletions

View File

@ -44,6 +44,7 @@ def upgrade():
batch_op.add_column(sa.Column('credential_type', sa.String(length=32), server_default=sa.text("'api-key'::character varying"), nullable=False))
batch_op.drop_constraint(batch_op.f('unique_builtin_tool_provider'), type_='unique')
batch_op.create_unique_constraint(batch_op.f('unique_builtin_tool_provider'), ['tenant_id', 'provider', 'name'])
# ### end Alembic commands ###

View File

@ -109,7 +109,10 @@ class ApiToolProvider(Base):
"""
__tablename__ = "tool_api_providers"
__table_args__ = (db.PrimaryKeyConstraint("id", name="tool_api_provider_pkey"),)
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tool_api_provider_pkey"),
db.UniqueConstraint("name", "tenant_id", name="unique_api_tool_provider"),
)
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
# name of the api provider
@ -326,18 +329,17 @@ class MCPToolProvider(Base):
@property
def decrypted_credentials(self) -> dict:
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.tools.mcp_tool.provider import MCPToolProviderController
from core.tools.utils.configuration import ProviderConfigEncrypter
from core.tools.utils.encryption import create_provider_encrypter
provider_controller = MCPToolProviderController._from_db(self)
tool_configuration = ProviderConfigEncrypter(
return create_provider_encrypter(
tenant_id=self.tenant_id,
config=list(provider_controller.get_credentials_schema()),
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.provider_id,
)
return tool_configuration.decrypt(self.credentials, use_cache=False)
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
cache=NoOpProviderCredentialCache(),
)[0].decrypt(self.credentials)
class ToolModelInvoke(Base):

View File

@ -43,12 +43,22 @@ class BuiltinToolManageService:
get builtin tool provider oauth client schema
"""
provider = ToolManager.get_builtin_provider(provider_name, tenant_id)
return {
is_oauth_custom_client_enabled = BuiltinToolManageService.is_oauth_custom_client_enabled(
tenant_id, provider_name
)
is_system_oauth_params_exists = BuiltinToolManageService.is_oauth_system_client_exists(provider_name)
result = {
"schema": provider.get_oauth_client_schema(),
"is_oauth_custom_client_enabled": BuiltinToolManageService.is_oauth_custom_client_enabled(
tenant_id, provider_name
),
"is_oauth_custom_client_enabled": is_oauth_custom_client_enabled,
"is_system_oauth_params_exists": is_system_oauth_params_exists,
}
if is_oauth_custom_client_enabled:
result["client_params"] = BuiltinToolManageService.get_oauth_client(tenant_id, provider_name)
result["redirect_uri"] = (
f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_name}/tool/callback"
)
return result
@staticmethod
def list_builtin_tool_provider_tools(tenant_id: str, provider: str) -> list[ToolApiEntity]:
@ -415,6 +425,20 @@ class BuiltinToolManageService:
session.commit()
return {"result": "success"}
@staticmethod
def is_oauth_system_client_exists(provider_name: str) -> bool:
"""
check if oauth system client exists
"""
tool_provider = ToolProviderID(provider_name)
with Session(db.engine).no_autoflush as session:
system_client: ToolOAuthSystemClient | None = (
session.query(ToolOAuthSystemClient)
.filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name)
.first()
)
return system_client is not None
@staticmethod
def is_oauth_custom_client_enabled(tenant_id: str, provider: str) -> bool:
"""
@ -685,4 +709,10 @@ class BuiltinToolManageService:
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
cache=NoOpProviderCredentialCache(),
)
return encrypter.mask_tool_credentials(encrypter.decrypt(custom_oauth_client_params.oauth_params))
return {
"oauth_params": encrypter.mask_tool_credentials(
encrypter.decrypt(custom_oauth_client_params.oauth_params)
),
"enabled": custom_oauth_client_params.enabled,
}

View File

@ -7,13 +7,14 @@ from sqlalchemy import or_
from sqlalchemy.exc import IntegrityError
from core.helper import encrypter
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.mcp.error import MCPAuthError, MCPError
from core.mcp.mcp_client import MCPClient
from core.tools.entities.api_entities import ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.mcp_tool.provider import MCPToolProviderController
from core.tools.utils.configuration import ProviderConfigEncrypter
from core.tools.utils.encryption import ProviderConfigEncrypter
from extensions.ext_database import db
from models.tools import MCPToolProvider
from services.tools.tools_transform_service import ToolTransformService
@ -69,6 +70,7 @@ class MCPToolManageService:
MCPToolProvider.server_url_hash == server_url_hash,
MCPToolProvider.server_identifier == server_identifier,
),
MCPToolProvider.tenant_id == tenant_id,
)
.first()
)
@ -197,8 +199,7 @@ class MCPToolManageService:
tool_configuration = ProviderConfigEncrypter(
tenant_id=mcp_provider.tenant_id,
config=list(provider_controller.get_credentials_schema()),
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.provider_id,
provider_config_cache=NoOpProviderCredentialCache(),
)
credentials = tool_configuration.encrypt(credentials)
mcp_provider.updated_at = datetime.now()