feat: plugin OAuth with stateful

This commit is contained in:
Harry 2025-06-20 10:34:57 +08:00
parent 366ddb05ae
commit 12c20ec7f6
15 changed files with 809 additions and 72 deletions

View File

@ -38,4 +38,4 @@ else:
celery = app.extensions["celery"]
if __name__ == "__main__":
app.run(host="0.0.0.0", port=5001)
app.run(host="0.0.0.0", port=5001,debug=True)

View File

@ -1,18 +1,27 @@
import io
from flask import send_file
from flask import redirect, request, send_file
from flask_login import current_user
from flask_restful import Resource, reqparse
from flask_restful import (
Resource,
reqparse,
)
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.console import api
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
enterprise_license_required,
setup_required,
)
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.oauth import OAuthHandler
from extensions.ext_database import db
from libs.helper import alphanumeric, uuid_value
from libs.login import login_required
from services.plugin.oauth_service import OAuthProxyService
from services.tools.api_tools_manage_service import ApiToolManageService
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from services.tools.tool_labels_service import ToolLabelsService
@ -108,17 +117,19 @@ class ToolBuiltinProviderUpdateApi(Resource):
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
with Session(db.engine) as session:
result = BuiltinToolManageService.update_builtin_tool_provider(
session=session,
user_id=user_id,
tenant_id=tenant_id,
provider_name=provider,
credentials=args["credentials"],
credential_id=args["credential_id"],
name=args["name"]
)
session.commit()
return result
@ -555,9 +566,9 @@ class ToolBuiltinListApi(Resource):
[
provider.to_dict()
for provider in BuiltinToolManageService.list_builtin_tools(
user_id,
tenant_id,
)
user_id,
tenant_id,
)
]
)
@ -576,9 +587,9 @@ class ToolApiListApi(Resource):
[
provider.to_dict()
for provider in ApiToolManageService.list_api_tools(
user_id,
tenant_id,
)
user_id,
tenant_id,
)
]
)
@ -597,9 +608,9 @@ class ToolWorkflowListApi(Resource):
[
provider.to_dict()
for provider in WorkflowToolManageService.list_tenant_workflow_tools(
user_id,
tenant_id,
)
user_id,
tenant_id,
)
]
)
@ -613,6 +624,121 @@ class ToolLabelsApi(Resource):
return jsonable_encoder(ToolLabelsService.list_tool_labels())
class ToolPluginOAuthApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
parser = reqparse.RequestParser()
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
args = parser.parse_args()
provider = args["provider"]
plugin_id = args["plugin_id"]
# todo check permission
user = current_user
if not user.is_admin_or_owner:
raise Forbidden()
# check if user client is configured and enabled then using user client
# if user client is not configured then using system client
tenant_id = user.current_tenant_id
user_id = user.id
plugin_oauth_config = BuiltinToolManageService.get_builtin_tool_provider(
tenant_id=tenant_id,
user_id=user_id,
provider=provider,
plugin_id=plugin_id,
)
oauth_handler = OAuthHandler()
context_id = OAuthProxyService.create_proxy_context(user_id=current_user.id,
tenant_id=tenant_id,
plugin_id=plugin_id,
provider=provider)
# todo decrypt oauth params
oauth_params = plugin_oauth_config.oauth_params
oauth_params[
'redirect_uri'] = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/tool/callback?context_id={context_id}"
response = oauth_handler.get_authorization_url(
tenant_id,
user.id,
plugin_id,
provider,
system_credentials=oauth_params,
)
return response.model_dump()
class ToolOAuthCallback(Resource):
@setup_required
def get(self):
args = (reqparse
.RequestParser()
.add_argument("context_id", type=str, required=True, nullable=False, location="args")
.parse_args()
)
context_id = args["context_id"]
context = OAuthProxyService.use_proxy_context(context_id)
if context is None:
raise Forbidden("Invalid context_id")
user_id, tenant_id, plugin_id, provider = (
context.get("user_id"),
context.get("tenant_id"),
context.get("plugin_id"),
context.get("provider"),
)
oauth_handler = OAuthHandler()
plugin_oauth_config = BuiltinToolManageService.get_builtin_tool_provider(
tenant_id=tenant_id,
user_id=user_id,
provider=provider,
plugin_id=plugin_id,
)
oauth_params = plugin_oauth_config.oauth_params
oauth_params['redirect_uri'] = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/tool/callback?context_id={context_id}"
credentials = oauth_handler.get_credentials(
tenant_id,
user_id,
plugin_id,
provider,
system_credentials=oauth_params,
request=request,
)
if not credentials:
raise Exception("no credentials found for this plugin")
#TODO add credentials to database
return redirect(f"{dify_config.CONSOLE_WEB_URL}")
class ToolBuiltinProviderSetDefaultApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
parser = reqparse.RequestParser()
parser.add_argument("id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
return BuiltinToolManageService.set_default_provider(
tenant_id=current_user.current_tenant_id,
user_id=current_user.id,
provider=provider,
id=args["id"])
# tool oauth
api.add_resource(ToolPluginOAuthApi, "/oauth/plugin/tool")
api.add_resource(ToolOAuthCallback, "/oauth/plugin/tool/callback")
# tool provider
api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers")
@ -621,6 +747,8 @@ api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-prov
api.add_resource(ToolBuiltinProviderInfoApi, "/workspaces/current/tool-provider/builtin/<path:provider>/info")
api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<path:provider>/delete")
api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<path:provider>/update")
api.add_resource(ToolBuiltinProviderSetDefaultApi,
"/workspaces/current/tool-provider/builtin/<path:provider>/set-default")
api.add_resource(
ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/credentials"
)

View File

@ -1,3 +1,4 @@
import binascii
from collections.abc import Mapping
from typing import Any
@ -16,7 +17,7 @@ class OAuthHandler(BasePluginClient):
provider: str,
system_credentials: Mapping[str, Any],
) -> PluginOAuthAuthorizationUrlResponse:
return self._request_with_plugin_daemon_response(
response = self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/oauth/get_authorization_url",
PluginOAuthAuthorizationUrlResponse,
@ -32,6 +33,10 @@ class OAuthHandler(BasePluginClient):
"Content-Type": "application/json",
},
)
for resp in response:
return resp
raise ValueError("No response received from plugin daemon for authorization URL request.")
def get_credentials(
self,
@ -49,7 +54,7 @@ class OAuthHandler(BasePluginClient):
# encode request to raw http request
raw_request_bytes = self._convert_request_to_raw_data(request)
return self._request_with_plugin_daemon_response(
response = self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/oauth/get_credentials",
PluginOAuthCredentialsResponse,
@ -58,7 +63,8 @@ class OAuthHandler(BasePluginClient):
"data": {
"provider": provider,
"system_credentials": system_credentials,
"raw_request_bytes": raw_request_bytes,
# for json serialization
"raw_http_request": binascii.hexlify(raw_request_bytes).decode(),
},
},
headers={
@ -66,6 +72,10 @@ class OAuthHandler(BasePluginClient):
"Content-Type": "application/json",
},
)
for resp in response:
return resp
raise ValueError("No response received from plugin daemon for authorization URL request.")
def _convert_request_to_raw_data(self, request: Request) -> bytes:
"""
@ -79,7 +89,7 @@ class OAuthHandler(BasePluginClient):
"""
# Start with the request line
method = request.method
path = request.path
path = request.full_path
protocol = request.headers.get("HTTP_VERSION", "HTTP/1.1")
raw_data = f"{method} {path} {protocol}\r\n".encode()

View File

@ -0,0 +1,67 @@
import secrets
import urllib.parse
from collections.abc import Mapping
from typing import Any
import requests
from dify_plugin import ToolProvider
from dify_plugin.errors.tool import ToolProviderCredentialValidationError
from werkzeug import Request
class GithubProvider(ToolProvider):
_AUTH_URL = "https://github.com/login/oauth/authorize"
_TOKEN_URL = "https://github.com/login/oauth/access_token"
_API_USER_URL = "https://api.github.com/user"
def _oauth_get_authorization_url(self, system_credentials: Mapping[str, Any]) -> str:
"""
Generate the authorization URL for the Github OAuth.
"""
state = secrets.token_urlsafe(16)
params = {
"client_id": system_credentials["client_id"],
"redirect_uri": system_credentials["redirect_uri"],
"scope": system_credentials.get("scope", "read:user"),
"state": state,
# Optionally: allow_signup, login, etc.
}
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
def _oauth_get_credentials(self, system_credentials: Mapping[str, Any], request: Request) -> Mapping[str, Any]:
"""
Exchange code for access_token.
"""
code = request.args.get("code")
state = request.args.get("state")
if not code:
raise ValueError("No code provided")
# Optionally: validate state here
data = {
"client_id": system_credentials["client_id"],
"client_secret": system_credentials["client_secret"],
"code": code,
"redirect_uri": system_credentials["redirect_uri"],
}
headers = {"Accept": "application/json"}
response = requests.post(self._TOKEN_URL, data=data, headers=headers, timeout=10)
response_json = response.json()
access_token = response_json.get("access_token")
if not access_token:
raise ValueError(f"Error in GitHub OAuth: {response_json}")
return {"access_token": access_token}
def _validate_credentials(self, credentials: dict) -> None:
try:
if "access_token" not in credentials or not credentials.get("access_token"):
raise ToolProviderCredentialValidationError("GitHub API Access Token is required.")
headers = {
"Authorization": f"Bearer {credentials['access_token']}",
"Accept": "application/vnd.github+json",
}
response = requests.get(self._API_USER_URL, headers=headers, timeout=10)
if response.status_code != 200:
raise ToolProviderCredentialValidationError(response.json().get("message"))
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))

View File

@ -1,3 +1,4 @@
import os
from datetime import timedelta
import pytz
@ -24,12 +25,17 @@ def init_app(app: DifyApp) -> Celery:
},
}
flask_debugging = os.environ.get("FLASK_DEBUG", "0").lower() in {"true", "1", "yes"}
celery_app = Celery(
app.name,
task_cls=FlaskTask,
broker=dify_config.CELERY_BROKER_URL,
backend=dify_config.CELERY_BACKEND,
task_ignore_result=True,
task_always_eager=flask_debugging,
task_eager_propagates=flask_debugging,
)
# Add SSL options to the Celery configuration

View File

@ -0,0 +1,66 @@
"""add tool oauth credentials
Revision ID: 99310d2c25a6
Revises: 4474872b0ee6
Create Date: 2025-06-18 15:06:15.261915
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '99310d2c25a6'
down_revision = '4474872b0ee6'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('tool_oauth_system_clients',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('plugin_id', models.types.StringUUID(), nullable=False),
sa.Column('provider', sa.String(length=255), nullable=False),
sa.Column('encrypted_oauth_params', sa.Text(), nullable=False),
sa.PrimaryKeyConstraint('id', name='tool_oauth_system_client_pkey'),
sa.UniqueConstraint('plugin_id', 'provider', name='tool_oauth_system_client_plugin_id_provider_idx')
)
op.create_table('tool_oauth_user_clients',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('plugin_id', models.types.StringUUID(), nullable=False),
sa.Column('provider', sa.String(length=255), nullable=False),
sa.Column('encrypted_oauth_params', sa.Text(), nullable=False),
sa.PrimaryKeyConstraint('id', name='tool_oauth_user_client_pkey'),
sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', name='unique_tool_oauth_user_client')
)
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
batch_op.add_column(sa.Column('default', sa.Boolean(), server_default=sa.text('false'), nullable=False))
batch_op.alter_column('credential_type',
existing_type=sa.VARCHAR(length=255),
type_=sa.String(length=32),
existing_nullable=False,
existing_server_default=sa.text("'api_key'::character varying"))
batch_op.drop_constraint(batch_op.f('unique_builtin_tool_provider'), type_='unique')
batch_op.create_unique_constraint('unique_builtin_tool_provider', ['tenant_id', 'provider', 'credential_type'])
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
batch_op.drop_constraint('unique_builtin_tool_provider', type_='unique')
batch_op.create_unique_constraint(batch_op.f('unique_builtin_tool_provider'), ['tenant_id', 'provider'])
batch_op.alter_column('credential_type',
existing_type=sa.String(length=32),
type_=sa.VARCHAR(length=255),
existing_nullable=False,
existing_server_default=sa.text("'api_key'::character varying"))
batch_op.drop_column('default')
op.drop_table('tool_oauth_user_clients')
op.drop_table('tool_oauth_system_clients')
# ### end Alembic commands ###

View File

@ -0,0 +1,39 @@
"""multiple credential
Revision ID: 222376193a49
Revises: 99310d2c25a6
Create Date: 2025-06-19 11:33:46.400455
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '222376193a49'
down_revision = '99310d2c25a6'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
batch_op.drop_constraint(batch_op.f('unique_builtin_tool_provider'), type_='unique')
with op.batch_alter_table('tool_oauth_user_clients', schema=None) as batch_op:
batch_op.add_column(sa.Column('owner_type', sa.Text(), nullable=False))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tool_oauth_user_clients', schema=None) as batch_op:
batch_op.drop_column('owner_type')
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
batch_op.create_unique_constraint(batch_op.f('unique_builtin_tool_provider'), ['tenant_id', 'provider', 'credential_type'])
# ### end Alembic commands ###

View File

@ -0,0 +1,33 @@
"""multiple credential
Revision ID: a9306e69af07
Revises: 222376193a49
Create Date: 2025-06-19 13:53:41.554159
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'a9306e69af07'
down_revision = '222376193a49'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
batch_op.create_unique_constraint('unique_builtin_tool_provider', ['provider', 'tenant_id', 'default'])
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
batch_op.drop_constraint('unique_builtin_tool_provider', type_='unique')
# ### end Alembic commands ###

View File

@ -0,0 +1,33 @@
"""multiple credential
Revision ID: 6835b906335f
Revises: e315d2a83984
Create Date: 2025-06-19 13:59:58.107955
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '6835b906335f'
down_revision = 'e315d2a83984'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
batch_op.drop_constraint(batch_op.f('unique_builtin_tool_provider'), type_='unique')
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
batch_op.create_unique_constraint(batch_op.f('unique_builtin_tool_provider'), ['provider', 'tenant_id', 'default'])
# ### end Alembic commands ###

View File

@ -0,0 +1,33 @@
"""multiple credential
Revision ID: e315d2a83984
Revises: a9306e69af07
Create Date: 2025-06-19 13:59:13.860523
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'e315d2a83984'
down_revision = 'a9306e69af07'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
batch_op.drop_constraint(batch_op.f('unique_api_tool_provider'), type_='unique')
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
batch_op.create_unique_constraint(batch_op.f('unique_api_tool_provider'), ['name', 'tenant_id'])
# ### end Alembic commands ###

View File

@ -0,0 +1,53 @@
"""multiple credential
Revision ID: 110e30078dd3
Revises: 6835b906335f
Create Date: 2025-06-19 15:11:42.688478
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '110e30078dd3'
down_revision = '6835b906335f'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tool_oauth_system_clients', schema=None) as batch_op:
batch_op.alter_column('plugin_id',
existing_type=sa.UUID(),
type_=sa.String(length=512),
existing_nullable=False)
with op.batch_alter_table('tool_oauth_user_clients', schema=None) as batch_op:
batch_op.add_column(sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False))
batch_op.alter_column('plugin_id',
existing_type=sa.UUID(),
type_=sa.String(length=512),
existing_nullable=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tool_oauth_user_clients', schema=None) as batch_op:
batch_op.alter_column('plugin_id',
existing_type=sa.String(length=512),
type_=sa.UUID(),
existing_nullable=False)
batch_op.drop_column('enabled')
with op.batch_alter_table('tool_oauth_system_clients', schema=None) as batch_op:
batch_op.alter_column('plugin_id',
existing_type=sa.String(length=512),
type_=sa.UUID(),
existing_nullable=False)
# ### end Alembic commands ###

View File

@ -1,3 +1,4 @@
import enum
import json
from datetime import datetime
from typing import Any, cast
@ -17,6 +18,65 @@ from .model import Account, App, Tenant
from .types import StringUUID
class ToolProviderCredentialType(enum.StrEnum):
API_KEY = "api_key",
OAUTH2 = "oauth2",
def is_editable(self):
return self == ToolProviderCredentialType.API_KEY
@classmethod
def get_credential_type(cls, credential_type: str) -> "ToolProviderCredentialType":
if credential_type == "api_key":
return cls.API_KEY
elif credential_type == "oauth2":
return cls.OAUTH2
else:
raise ValueError(f"Invalid credential type: {credential_type}")
# system level tool oauth client params (client_id, client_secret, etc.)
class ToolOAuthSystemClient(Base):
__tablename__ = "tool_oauth_system_clients"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tool_oauth_system_client_pkey"),
db.UniqueConstraint("plugin_id", "provider", name="tool_oauth_system_client_plugin_id_provider_idx"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
plugin_id: Mapped[str] = mapped_column(db.String(512), nullable=False)
provider: Mapped[str] = mapped_column(db.String(255), nullable=False)
# owner type, e.g., "system", "user"
# oauth params of the tool provider
encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False)
@property
def oauth_params(self) -> dict:
return cast(dict, json.loads(self.encrypted_oauth_params))
# user level tool oauth client params (client_id, client_secret, etc.)
class ToolOAuthUserClient(Base):
__tablename__ = "tool_oauth_user_clients"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tool_oauth_user_client_pkey"),
db.UniqueConstraint("tenant_id", "plugin_id", "provider", name="unique_tool_oauth_user_client"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
# tenant id
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
plugin_id: Mapped[str] = mapped_column(db.String(512), nullable=False)
provider: Mapped[str] = mapped_column(db.String(255), nullable=False)
owner_type: Mapped[str] = mapped_column(db.Text, nullable=False)
enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true"))
# oauth params of the tool provider
encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False)
@property
def oauth_params(self) -> dict:
return cast(dict, json.loads(self.encrypted_oauth_params))
class BuiltinToolProvider(Base):
"""
This table stores the tool provider information for built-in tools for each tenant.
@ -25,12 +85,11 @@ class BuiltinToolProvider(Base):
__tablename__ = "tool_builtin_providers"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tool_builtin_provider_pkey"),
# one tenant can only have one tool provider with the same name
db.UniqueConstraint("tenant_id", "provider", name="unique_builtin_tool_provider"),
)
# id of the tool provider
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
name: Mapped[str] = mapped_column(db.String(256), nullable=False)
# id of the tenant
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=True)
# who created this tool provider
@ -45,6 +104,11 @@ class BuiltinToolProvider(Base):
updated_at: Mapped[datetime] = mapped_column(
db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
)
default: Mapped[bool] = mapped_column(
db.Boolean, nullable=False, server_default=db.text("false")
)
# credential type, e.g., "api_key", "oauth2"
credential_type: Mapped[str] = mapped_column(db.String(32), nullable=False, server_default=db.text("'api_key'::character varying"))
@property
def credentials(self) -> dict:
@ -59,7 +123,6 @@ class ApiToolProvider(Base):
__tablename__ = "tool_api_providers"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tool_api_provider_pkey"),
db.UniqueConstraint("name", "tenant_id", name="unique_api_tool_provider"),
)
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))

View File

@ -1,7 +1,62 @@
import json
import uuid
from core.plugin.impl.base import BasePluginClient
from extensions.ext_redis import redis_client
class OAuthService(BasePluginClient):
@classmethod
def get_authorization_url(cls, tenant_id: str, user_id: str, provider_name: str) -> str:
return "1234567890"
class OAuthProxyService(BasePluginClient):
# Default max age for proxy context parameter in seconds
__MAX_AGE__ = 5 * 60 # 5 minutes
@staticmethod
def create_proxy_context(user_id, tenant_id, plugin_id, provider):
"""
Create a proxy context for an OAuth 2.0 authorization request.
This parameter is a crucial security measure to prevent Cross-Site Request
Forgery (CSRF) attacks. It works by generating a unique nonce and storing it
in a distributed cache (Redis) along with the user's session context.
The returned nonce should be included as the 'proxy_context' parameter in the
authorization URL. Upon callback, the `retrieve_proxy_context` method
is used to verify the state, ensuring the request's integrity and authenticity,
and mitigating replay attacks.
"""
seconds, microseconds = redis_client.time()
context_id = str(uuid.uuid4())
data = {
"user_id": user_id,
"plugin_id": plugin_id,
"tenant_id": tenant_id,
"provider": provider,
# encode redis time to avoid distribution time skew
"timestamp": seconds,
}
# ignore nonce collision
redis_client.setex(
f"oauth_proxy_context:{context_id}",
OAuthProxyService.__MAX_AGE__,
json.dumps(data),
)
return context_id
@staticmethod
def use_proxy_context(context_id, max_age=__MAX_AGE__):
"""
Validate the proxy context parameter.
This checks if the context_id is valid and not expired.
"""
if not context_id:
raise ValueError("context_id is required")
# get data from redis
data = redis_client.getdel(f"oauth_proxy_context:{context_id}")
if not data:
raise ValueError("context_id is invalid")
# check if data is expired
seconds, microseconds = redis_client.time()
state = json.loads(data)
if state.get("timestamp") < seconds - max_age:
raise ValueError("context_id is expired")
return state

View File

@ -2,8 +2,6 @@ import json
import logging
from pathlib import Path
from sqlalchemy.orm import Session
from configs import dify_config
from core.helper.position_helper import is_filtered
from core.model_runtime.utils.encoders import jsonable_encoder
@ -16,7 +14,7 @@ from core.tools.tool_label_manager import ToolLabelManager
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ProviderConfigEncrypter
from extensions.ext_database import db
from models.tools import BuiltinToolProvider
from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthUserClient, ToolProviderCredentialType
from services.tools.tools_transform_service import ToolTransformService
logger = logging.getLogger(__name__)
@ -109,63 +107,69 @@ class BuiltinToolManageService:
@staticmethod
def update_builtin_tool_provider(
session: Session, user_id: str, tenant_id: str, provider_name: str, credentials: dict
user_id: str, tenant_id: str, provider_name:str, credentials: dict, credential_id: str, name: str | None = None
):
"""
update builtin tool provider
"""
# get if the provider exists
provider = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id)
provider = BuiltinToolManageService._fetch_builtin_provider_by_id(tenant_id, credential_id)
if provider is None:
raise ValueError(f"you have not added provider {provider_name}")
if not ToolProviderCredentialType.get_credential_type(provider.credential_type).is_editable():
raise ValueError(f"you cannot update oauth2 provider {provider_name} credentials")
try:
# get provider
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
if not provider_controller.need_credentials:
raise ValueError(f"provider {provider_name} does not need credentials")
tool_configuration = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)
# exclude oauth2 provider
if provider.credential_type != ToolProviderCredentialType.OAUTH2.value:
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
if not provider_controller.need_credentials:
raise ValueError(f"provider {provider_name} does not need credentials")
# get original credentials if exists
if provider is not None:
original_credentials = tool_configuration.decrypt(provider.credentials)
masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
# check if the credential has changed, save the original credential
for name, value in credentials.items():
if name in masked_credentials and value == masked_credentials[name]:
credentials[name] = original_credentials[name]
# validate credentials
provider_controller.validate_credentials(user_id, credentials)
# encrypt credentials
credentials = tool_configuration.encrypt(credentials)
tool_configuration = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)
# Decrypt and restore original credentials for masked values
credentials = BuiltinToolManageService._dec
rypt_and_restore_credentials(
provider_controller, tool_configuration, provider, credentials
)
# Encrypt and save the credentials
BuiltinToolManageService._encrypt_and_save_credentials(
provider_controller, tool_configuration, provider, credentials, user_id
)
# update name if provided
if name is not None and provider.name != name:
provider.name = name
db.session.commit()
except (
PluginDaemonClientSideError,
ToolProviderNotFoundError,
ToolNotFoundError,
ToolProviderCredentialValidationError,
PluginDaemonClientSideError,
ToolProviderNotFoundError,
ToolNotFoundError,
ToolProviderCredentialValidationError,
) as e:
raise ValueError(str(e))
if provider is None:
# create provider
provider = BuiltinToolProvider(
tenant_id=tenant_id,
user_id=user_id,
provider=provider_name,
encrypted_credentials=json.dumps(credentials),
)
return {"result": "success"}
db.session.add(provider)
else:
provider.encrypted_credentials = json.dumps(credentials)
@staticmethod
def add_builtin_tool_provider(
user_id: str, tenant_id: str, provider_name: str, credentials: dict, name: str | None = None
):
"""
add builtin tool provider
"""
# delete cache
tool_configuration.delete_tool_credentials_cache()
db.session.commit()
return {"result": "success"}
@staticmethod
@ -214,6 +218,78 @@ class BuiltinToolManageService:
return {"result": "success"}
@staticmethod
def set_default_provider(tenant_id: str, user_id: str, provider: str, id: str):
"""
set default provider
"""
# get provider
target_provider = db.session.query(BuiltinToolProvider).filter_by(id=id).first()
if target_provider is None:
raise ValueError("provider not found")
# clear default provider
db.session.query(BuiltinToolProvider).filter_by(
tenant_id=tenant_id,
user_id=user_id,
provider=provider,
default=True
).update({"default": False})
# set new default provider
target_provider.default = True
db.session.commit()
return {"result": "success"}
@staticmethod
def fetch_default_provider(tenant_id: str, user_id: str, provider_name: str):
"""
fetch default provider
if there is no explicitly set default provider, return the oldest provider as default
"""
# 1. check if default provider exists
default_provider = db.session.query(BuiltinToolProvider).filter_by(
tenant_id=tenant_id,
user_id=user_id,
provider=provider_name,
default=True
).first()
if default_provider:
return default_provider
# 2. if no default provider, set the oldest provider as default
oldest_provider = (db.session.query(BuiltinToolProvider)
.filter_by(tenant_id=tenant_id, user_id=user_id, provider=provider_name)
.order_by(BuiltinToolProvider.created_at)
.first()
)
if oldest_provider:
return oldest_provider
raise ValueError(f"no default provider found for {provider_name}")
@staticmethod
def get_builtin_tool_provider(tenant_id: str, user_id: str, provider: str, plugin_id: str):
"""
get builtin tool provider
"""
user_client = db.session.query(ToolOAuthUserClient).filter_by(
tenant_id=tenant_id,
provider=provider,
plugin_id=plugin_id,
enabled=True,
).first()
if user_client:
plugin_oauth_config = user_client
else:
plugin_oauth_config = db.session.query(ToolOAuthSystemClient).filter_by(provider=provider).first()
if plugin_oauth_config:
return plugin_oauth_config
raise ValueError("no oauth available config found for this plugin")
@staticmethod
def get_builtin_tool_provider_icon(provider: str):
"""
@ -286,6 +362,15 @@ class BuiltinToolManageService:
return BuiltinToolProviderSort.sort(result)
@staticmethod
def _fetch_builtin_provider_by_id(tenant_id: str, credential_id: str) -> BuiltinToolProvider | None:
provider = (db.session.query(BuiltinToolProvider)
.filter(BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.id == credential_id,
)
.first())
return provider
@staticmethod
def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None:
try:
@ -327,3 +412,42 @@ class BuiltinToolManageService:
)
.first()
)
@staticmethod
def _decrypt_and_restore_credentials(provider_controller, tool_configuration, provider, credentials):
"""
Decrypt original credentials and restore masked values from the input credentials
:param provider_controller: the provider controller
:param tool_configuration: the tool configuration encrypter
:param provider: the provider object from database
:param credentials: the input credentials from user
:return: the processed credentials with original values restored
"""
original_credentials = tool_configuration.decrypt(provider.credentials)
masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
# check if the credential has changed, save the original credential
for name, value in credentials.items():
if name in masked_credentials and value == masked_credentials[name]: # type: ignore
credentials[name] = original_credentials[name] # type: ignore
return credentials
@staticmethod
def _encrypt_and_save_credentials(provider_controller, tool_configuration, provider, credentials, user_id):
"""
Validate and encrypt credentials, then save to database
:param provider_controller: the provider controller
:param tool_configuration: the tool configuration encrypter
:param provider: the provider object from database
:param credentials: the credentials to encrypt and save
:param user_id: the user id for validation
"""
# validate credentials
provider_controller.validate_credentials(user_id, credentials)
# encrypt credentials
encrypted_credentials = tool_configuration.encrypt(credentials)
provider.encrypted_credentials = json.dumps(encrypted_credentials)
tool_configuration.delete_tool_credentials_cache()

27
api/tool_oauth.http Normal file
View File

@ -0,0 +1,27 @@
@accessToken=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyX2lkIjoiYjM4Y2Y5N2MtODNiYS00MWI3LWEyZjMtMzZlOTgzZjE4YmQ5IiwiZXhwIjoxNzUwNDE3NDI0LCJpc3MiOiJTRUxGX0hPU1RFRCIsInN1YiI6IkNvbnNvbGUgQVBJIFBhc3Nwb3J0In0.pPCkISnSmnu3hOCyEVTIJoNeWxtx7E9LNy0cDQUy__Q
# set default credential
POST /console/api/workspaces/current/tool-provider/builtin/langgenius/github/github/set-default
Host: 127.0.0.1:5001
Content-Type: application/json
Authorization: Bearer {{accessToken}}
{
"id": "55fb78d2-0ce6-4496-9488-3b8d9f40818f"
}
###
# get oauth url
GET /console/api/oauth/plugin/tool?plugin_id=c58a1845-f3a4-4d93-b749-a71e9998b702/github&provider=github
Host: 127.0.0.1:5001
Authorization: Bearer {{accessToken}}
###
# get oauth token
GET /console/api/oauth/plugin/tool/callback?state=734072c2-d8ed-4b0b-8ed8-4efd69d15a4f&code=e2d68a6216a3b7d70d2f&state=NQCjFkMKtf32XCMHc8KBdw
Host: 127.0.0.1:5001
Authorization: Bearer {{accessToken}}