feat(oauth): merge tool oauth and remove sequence number branches

This commit is contained in:
Harry 2025-06-25 14:51:55 +08:00
parent 1a2dfd950e
commit ce4cc54cc9
5 changed files with 32 additions and 12 deletions

View File

@ -676,14 +676,9 @@ class ToolPluginOAuthApi(Resource):
if not user.is_admin_or_owner:
raise Forbidden()
# check if user client is configured and enabled then using user client
# if user client is not configured then using system client
tenant_id = user.current_tenant_id
user_id = user.id
plugin_oauth_config = BuiltinToolManageService.get_builtin_tool_provider(
plugin_oauth_config = BuiltinToolManageService.get_builtin_tool_oauth_client(
tenant_id=tenant_id,
user_id=user_id,
provider=provider,
plugin_id=plugin_id,
)
@ -727,9 +722,8 @@ class ToolOAuthCallback(Resource):
context.get("provider"),
)
oauth_handler = OAuthHandler()
plugin_oauth_config = BuiltinToolManageService.get_builtin_tool_provider(
plugin_oauth_config = BuiltinToolManageService.get_builtin_tool_oauth_client(
tenant_id=tenant_id,
user_id=user_id,
provider=provider,
plugin_id=plugin_id,
)

View File

@ -579,7 +579,7 @@ class ToolManager:
if "builtin" in filters:
def get_builtin_providers(tenant_id):
# according to multi credentials, select the one with is_default=True first, then created_at oldest
# according to multi credentials, select the one with is_default=True first, then created_at oldest
# for compatibility with old version
sql = """
SELECT DISTINCT ON (tenant_id, provider) id

View File

@ -64,4 +64,4 @@ class GithubProvider(ToolProvider):
if response.status_code != 200:
raise ToolProviderCredentialValidationError(response.json().get("message"))
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))
raise ToolProviderCredentialValidationError(str(e))

View File

@ -0,0 +1,25 @@
"""merge tool oauth and remove sequence number branches
Revision ID: 46d46b3f389c
Revises: 0ab65e1cc7fa, 71f5020c6470
Create Date: 2025-06-25 11:01:55.215896
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '46d46b3f389c'
down_revision = ('0ab65e1cc7fa', '71f5020c6470')
branch_labels = None
depends_on = None
def upgrade():
pass
def downgrade():
pass

View File

@ -299,7 +299,7 @@ class BuiltinToolManageService:
db.session.delete(provider_obj)
db.session.commit()
# delete cache
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
tool_configuration = ProviderConfigEncrypter(
@ -334,7 +334,7 @@ class BuiltinToolManageService:
return {"result": "success"}
@staticmethod
def get_builtin_tool_provider(tenant_id: str, user_id: str, provider: str, plugin_id: str):
def get_builtin_tool_oauth_client(tenant_id: str, provider: str, plugin_id: str):
"""
get builtin tool provider
"""
@ -450,6 +450,7 @@ class BuiltinToolManageService:
1.if the default provider exists, return the default provider
2.if the default provider does not exist, return the oldest provider
"""
def _query(provider_filters: list[ColumnExpressionArgument[bool]]):
return (
db.session.query(BuiltinToolProvider)