feat: plugin OAuth with stateful

This commit is contained in:
Harry 2025-06-20 10:34:57 +08:00
parent 366ddb05ae
commit 12c20ec7f6
15 changed files with 809 additions and 72 deletions

View File

@ -38,4 +38,4 @@ else:
celery = app.extensions["celery"] celery = app.extensions["celery"]
if __name__ == "__main__": if __name__ == "__main__":
app.run(host="0.0.0.0", port=5001) app.run(host="0.0.0.0", port=5001,debug=True)

View File

@ -1,18 +1,27 @@
import io import io
from flask import send_file from flask import redirect, request, send_file
from flask_login import current_user from flask_login import current_user
from flask_restful import Resource, reqparse from flask_restful import (
Resource,
reqparse,
)
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from configs import dify_config from configs import dify_config
from controllers.console import api from controllers.console import api
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required from controllers.console.wraps import (
account_initialization_required,
enterprise_license_required,
setup_required,
)
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.oauth import OAuthHandler
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import alphanumeric, uuid_value from libs.helper import alphanumeric, uuid_value
from libs.login import login_required from libs.login import login_required
from services.plugin.oauth_service import OAuthProxyService
from services.tools.api_tools_manage_service import ApiToolManageService from services.tools.api_tools_manage_service import ApiToolManageService
from services.tools.builtin_tools_manage_service import BuiltinToolManageService from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from services.tools.tool_labels_service import ToolLabelsService from services.tools.tool_labels_service import ToolLabelsService
@ -108,17 +117,19 @@ class ToolBuiltinProviderUpdateApi(Resource):
tenant_id = user.current_tenant_id tenant_id = user.current_tenant_id
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
with Session(db.engine) as session: with Session(db.engine) as session:
result = BuiltinToolManageService.update_builtin_tool_provider( result = BuiltinToolManageService.update_builtin_tool_provider(
session=session,
user_id=user_id, user_id=user_id,
tenant_id=tenant_id, tenant_id=tenant_id,
provider_name=provider,
credentials=args["credentials"], credentials=args["credentials"],
credential_id=args["credential_id"],
name=args["name"]
) )
session.commit() session.commit()
return result return result
@ -555,9 +566,9 @@ class ToolBuiltinListApi(Resource):
[ [
provider.to_dict() provider.to_dict()
for provider in BuiltinToolManageService.list_builtin_tools( for provider in BuiltinToolManageService.list_builtin_tools(
user_id, user_id,
tenant_id, tenant_id,
) )
] ]
) )
@ -576,9 +587,9 @@ class ToolApiListApi(Resource):
[ [
provider.to_dict() provider.to_dict()
for provider in ApiToolManageService.list_api_tools( for provider in ApiToolManageService.list_api_tools(
user_id, user_id,
tenant_id, tenant_id,
) )
] ]
) )
@ -597,9 +608,9 @@ class ToolWorkflowListApi(Resource):
[ [
provider.to_dict() provider.to_dict()
for provider in WorkflowToolManageService.list_tenant_workflow_tools( for provider in WorkflowToolManageService.list_tenant_workflow_tools(
user_id, user_id,
tenant_id, tenant_id,
) )
] ]
) )
@ -613,6 +624,121 @@ class ToolLabelsApi(Resource):
return jsonable_encoder(ToolLabelsService.list_tool_labels()) return jsonable_encoder(ToolLabelsService.list_tool_labels())
class ToolPluginOAuthApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
parser = reqparse.RequestParser()
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
args = parser.parse_args()
provider = args["provider"]
plugin_id = args["plugin_id"]
# todo check permission
user = current_user
if not user.is_admin_or_owner:
raise Forbidden()
# check if user client is configured and enabled then using user client
# if user client is not configured then using system client
tenant_id = user.current_tenant_id
user_id = user.id
plugin_oauth_config = BuiltinToolManageService.get_builtin_tool_provider(
tenant_id=tenant_id,
user_id=user_id,
provider=provider,
plugin_id=plugin_id,
)
oauth_handler = OAuthHandler()
context_id = OAuthProxyService.create_proxy_context(user_id=current_user.id,
tenant_id=tenant_id,
plugin_id=plugin_id,
provider=provider)
# todo decrypt oauth params
oauth_params = plugin_oauth_config.oauth_params
oauth_params[
'redirect_uri'] = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/tool/callback?context_id={context_id}"
response = oauth_handler.get_authorization_url(
tenant_id,
user.id,
plugin_id,
provider,
system_credentials=oauth_params,
)
return response.model_dump()
class ToolOAuthCallback(Resource):
@setup_required
def get(self):
args = (reqparse
.RequestParser()
.add_argument("context_id", type=str, required=True, nullable=False, location="args")
.parse_args()
)
context_id = args["context_id"]
context = OAuthProxyService.use_proxy_context(context_id)
if context is None:
raise Forbidden("Invalid context_id")
user_id, tenant_id, plugin_id, provider = (
context.get("user_id"),
context.get("tenant_id"),
context.get("plugin_id"),
context.get("provider"),
)
oauth_handler = OAuthHandler()
plugin_oauth_config = BuiltinToolManageService.get_builtin_tool_provider(
tenant_id=tenant_id,
user_id=user_id,
provider=provider,
plugin_id=plugin_id,
)
oauth_params = plugin_oauth_config.oauth_params
oauth_params['redirect_uri'] = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/tool/callback?context_id={context_id}"
credentials = oauth_handler.get_credentials(
tenant_id,
user_id,
plugin_id,
provider,
system_credentials=oauth_params,
request=request,
)
if not credentials:
raise Exception("no credentials found for this plugin")
#TODO add credentials to database
return redirect(f"{dify_config.CONSOLE_WEB_URL}")
class ToolBuiltinProviderSetDefaultApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
parser = reqparse.RequestParser()
parser.add_argument("id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
return BuiltinToolManageService.set_default_provider(
tenant_id=current_user.current_tenant_id,
user_id=current_user.id,
provider=provider,
id=args["id"])
# tool oauth
api.add_resource(ToolPluginOAuthApi, "/oauth/plugin/tool")
api.add_resource(ToolOAuthCallback, "/oauth/plugin/tool/callback")
# tool provider # tool provider
api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers") api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers")
@ -621,6 +747,8 @@ api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-prov
api.add_resource(ToolBuiltinProviderInfoApi, "/workspaces/current/tool-provider/builtin/<path:provider>/info") api.add_resource(ToolBuiltinProviderInfoApi, "/workspaces/current/tool-provider/builtin/<path:provider>/info")
api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<path:provider>/delete") api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<path:provider>/delete")
api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<path:provider>/update") api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<path:provider>/update")
api.add_resource(ToolBuiltinProviderSetDefaultApi,
"/workspaces/current/tool-provider/builtin/<path:provider>/set-default")
api.add_resource( api.add_resource(
ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/credentials" ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/credentials"
) )

View File

@ -1,3 +1,4 @@
import binascii
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any from typing import Any
@ -16,7 +17,7 @@ class OAuthHandler(BasePluginClient):
provider: str, provider: str,
system_credentials: Mapping[str, Any], system_credentials: Mapping[str, Any],
) -> PluginOAuthAuthorizationUrlResponse: ) -> PluginOAuthAuthorizationUrlResponse:
return self._request_with_plugin_daemon_response( response = self._request_with_plugin_daemon_response_stream(
"POST", "POST",
f"plugin/{tenant_id}/dispatch/oauth/get_authorization_url", f"plugin/{tenant_id}/dispatch/oauth/get_authorization_url",
PluginOAuthAuthorizationUrlResponse, PluginOAuthAuthorizationUrlResponse,
@ -32,6 +33,10 @@ class OAuthHandler(BasePluginClient):
"Content-Type": "application/json", "Content-Type": "application/json",
}, },
) )
for resp in response:
return resp
raise ValueError("No response received from plugin daemon for authorization URL request.")
def get_credentials( def get_credentials(
self, self,
@ -49,7 +54,7 @@ class OAuthHandler(BasePluginClient):
# encode request to raw http request # encode request to raw http request
raw_request_bytes = self._convert_request_to_raw_data(request) raw_request_bytes = self._convert_request_to_raw_data(request)
return self._request_with_plugin_daemon_response( response = self._request_with_plugin_daemon_response_stream(
"POST", "POST",
f"plugin/{tenant_id}/dispatch/oauth/get_credentials", f"plugin/{tenant_id}/dispatch/oauth/get_credentials",
PluginOAuthCredentialsResponse, PluginOAuthCredentialsResponse,
@ -58,7 +63,8 @@ class OAuthHandler(BasePluginClient):
"data": { "data": {
"provider": provider, "provider": provider,
"system_credentials": system_credentials, "system_credentials": system_credentials,
"raw_request_bytes": raw_request_bytes, # for json serialization
"raw_http_request": binascii.hexlify(raw_request_bytes).decode(),
}, },
}, },
headers={ headers={
@ -66,6 +72,10 @@ class OAuthHandler(BasePluginClient):
"Content-Type": "application/json", "Content-Type": "application/json",
}, },
) )
for resp in response:
return resp
raise ValueError("No response received from plugin daemon for authorization URL request.")
def _convert_request_to_raw_data(self, request: Request) -> bytes: def _convert_request_to_raw_data(self, request: Request) -> bytes:
""" """
@ -79,7 +89,7 @@ class OAuthHandler(BasePluginClient):
""" """
# Start with the request line # Start with the request line
method = request.method method = request.method
path = request.path path = request.full_path
protocol = request.headers.get("HTTP_VERSION", "HTTP/1.1") protocol = request.headers.get("HTTP_VERSION", "HTTP/1.1")
raw_data = f"{method} {path} {protocol}\r\n".encode() raw_data = f"{method} {path} {protocol}\r\n".encode()

View File

@ -0,0 +1,67 @@
import secrets
import urllib.parse
from collections.abc import Mapping
from typing import Any
import requests
from dify_plugin import ToolProvider
from dify_plugin.errors.tool import ToolProviderCredentialValidationError
from werkzeug import Request
class GithubProvider(ToolProvider):
_AUTH_URL = "https://github.com/login/oauth/authorize"
_TOKEN_URL = "https://github.com/login/oauth/access_token"
_API_USER_URL = "https://api.github.com/user"
def _oauth_get_authorization_url(self, system_credentials: Mapping[str, Any]) -> str:
"""
Generate the authorization URL for the Github OAuth.
"""
state = secrets.token_urlsafe(16)
params = {
"client_id": system_credentials["client_id"],
"redirect_uri": system_credentials["redirect_uri"],
"scope": system_credentials.get("scope", "read:user"),
"state": state,
# Optionally: allow_signup, login, etc.
}
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
def _oauth_get_credentials(self, system_credentials: Mapping[str, Any], request: Request) -> Mapping[str, Any]:
"""
Exchange code for access_token.
"""
code = request.args.get("code")
state = request.args.get("state")
if not code:
raise ValueError("No code provided")
# Optionally: validate state here
data = {
"client_id": system_credentials["client_id"],
"client_secret": system_credentials["client_secret"],
"code": code,
"redirect_uri": system_credentials["redirect_uri"],
}
headers = {"Accept": "application/json"}
response = requests.post(self._TOKEN_URL, data=data, headers=headers, timeout=10)
response_json = response.json()
access_token = response_json.get("access_token")
if not access_token:
raise ValueError(f"Error in GitHub OAuth: {response_json}")
return {"access_token": access_token}
def _validate_credentials(self, credentials: dict) -> None:
try:
if "access_token" not in credentials or not credentials.get("access_token"):
raise ToolProviderCredentialValidationError("GitHub API Access Token is required.")
headers = {
"Authorization": f"Bearer {credentials['access_token']}",
"Accept": "application/vnd.github+json",
}
response = requests.get(self._API_USER_URL, headers=headers, timeout=10)
if response.status_code != 200:
raise ToolProviderCredentialValidationError(response.json().get("message"))
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))

View File

@ -1,3 +1,4 @@
import os
from datetime import timedelta from datetime import timedelta
import pytz import pytz
@ -24,12 +25,17 @@ def init_app(app: DifyApp) -> Celery:
}, },
} }
flask_debugging = os.environ.get("FLASK_DEBUG", "0").lower() in {"true", "1", "yes"}
celery_app = Celery( celery_app = Celery(
app.name, app.name,
task_cls=FlaskTask, task_cls=FlaskTask,
broker=dify_config.CELERY_BROKER_URL, broker=dify_config.CELERY_BROKER_URL,
backend=dify_config.CELERY_BACKEND, backend=dify_config.CELERY_BACKEND,
task_ignore_result=True, task_ignore_result=True,
task_always_eager=flask_debugging,
task_eager_propagates=flask_debugging,
) )
# Add SSL options to the Celery configuration # Add SSL options to the Celery configuration

View File

@ -0,0 +1,66 @@
"""add tool oauth credentials
Revision ID: 99310d2c25a6
Revises: 4474872b0ee6
Create Date: 2025-06-18 15:06:15.261915
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '99310d2c25a6'
down_revision = '4474872b0ee6'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('tool_oauth_system_clients',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('plugin_id', models.types.StringUUID(), nullable=False),
sa.Column('provider', sa.String(length=255), nullable=False),
sa.Column('encrypted_oauth_params', sa.Text(), nullable=False),
sa.PrimaryKeyConstraint('id', name='tool_oauth_system_client_pkey'),
sa.UniqueConstraint('plugin_id', 'provider', name='tool_oauth_system_client_plugin_id_provider_idx')
)
op.create_table('tool_oauth_user_clients',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('plugin_id', models.types.StringUUID(), nullable=False),
sa.Column('provider', sa.String(length=255), nullable=False),
sa.Column('encrypted_oauth_params', sa.Text(), nullable=False),
sa.PrimaryKeyConstraint('id', name='tool_oauth_user_client_pkey'),
sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', name='unique_tool_oauth_user_client')
)
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
batch_op.add_column(sa.Column('default', sa.Boolean(), server_default=sa.text('false'), nullable=False))
batch_op.alter_column('credential_type',
existing_type=sa.VARCHAR(length=255),
type_=sa.String(length=32),
existing_nullable=False,
existing_server_default=sa.text("'api_key'::character varying"))
batch_op.drop_constraint(batch_op.f('unique_builtin_tool_provider'), type_='unique')
batch_op.create_unique_constraint('unique_builtin_tool_provider', ['tenant_id', 'provider', 'credential_type'])
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
batch_op.drop_constraint('unique_builtin_tool_provider', type_='unique')
batch_op.create_unique_constraint(batch_op.f('unique_builtin_tool_provider'), ['tenant_id', 'provider'])
batch_op.alter_column('credential_type',
existing_type=sa.String(length=32),
type_=sa.VARCHAR(length=255),
existing_nullable=False,
existing_server_default=sa.text("'api_key'::character varying"))
batch_op.drop_column('default')
op.drop_table('tool_oauth_user_clients')
op.drop_table('tool_oauth_system_clients')
# ### end Alembic commands ###

View File

@ -0,0 +1,39 @@
"""multiple credential
Revision ID: 222376193a49
Revises: 99310d2c25a6
Create Date: 2025-06-19 11:33:46.400455
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '222376193a49'
down_revision = '99310d2c25a6'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
batch_op.drop_constraint(batch_op.f('unique_builtin_tool_provider'), type_='unique')
with op.batch_alter_table('tool_oauth_user_clients', schema=None) as batch_op:
batch_op.add_column(sa.Column('owner_type', sa.Text(), nullable=False))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tool_oauth_user_clients', schema=None) as batch_op:
batch_op.drop_column('owner_type')
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
batch_op.create_unique_constraint(batch_op.f('unique_builtin_tool_provider'), ['tenant_id', 'provider', 'credential_type'])
# ### end Alembic commands ###

View File

@ -0,0 +1,33 @@
"""multiple credential
Revision ID: a9306e69af07
Revises: 222376193a49
Create Date: 2025-06-19 13:53:41.554159
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'a9306e69af07'
down_revision = '222376193a49'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
batch_op.create_unique_constraint('unique_builtin_tool_provider', ['provider', 'tenant_id', 'default'])
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
batch_op.drop_constraint('unique_builtin_tool_provider', type_='unique')
# ### end Alembic commands ###

View File

@ -0,0 +1,33 @@
"""multiple credential
Revision ID: 6835b906335f
Revises: e315d2a83984
Create Date: 2025-06-19 13:59:58.107955
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '6835b906335f'
down_revision = 'e315d2a83984'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
batch_op.drop_constraint(batch_op.f('unique_builtin_tool_provider'), type_='unique')
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
batch_op.create_unique_constraint(batch_op.f('unique_builtin_tool_provider'), ['provider', 'tenant_id', 'default'])
# ### end Alembic commands ###

View File

@ -0,0 +1,33 @@
"""multiple credential
Revision ID: e315d2a83984
Revises: a9306e69af07
Create Date: 2025-06-19 13:59:13.860523
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'e315d2a83984'
down_revision = 'a9306e69af07'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
batch_op.drop_constraint(batch_op.f('unique_api_tool_provider'), type_='unique')
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
batch_op.create_unique_constraint(batch_op.f('unique_api_tool_provider'), ['name', 'tenant_id'])
# ### end Alembic commands ###

View File

@ -0,0 +1,53 @@
"""multiple credential
Revision ID: 110e30078dd3
Revises: 6835b906335f
Create Date: 2025-06-19 15:11:42.688478
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '110e30078dd3'
down_revision = '6835b906335f'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tool_oauth_system_clients', schema=None) as batch_op:
batch_op.alter_column('plugin_id',
existing_type=sa.UUID(),
type_=sa.String(length=512),
existing_nullable=False)
with op.batch_alter_table('tool_oauth_user_clients', schema=None) as batch_op:
batch_op.add_column(sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False))
batch_op.alter_column('plugin_id',
existing_type=sa.UUID(),
type_=sa.String(length=512),
existing_nullable=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tool_oauth_user_clients', schema=None) as batch_op:
batch_op.alter_column('plugin_id',
existing_type=sa.String(length=512),
type_=sa.UUID(),
existing_nullable=False)
batch_op.drop_column('enabled')
with op.batch_alter_table('tool_oauth_system_clients', schema=None) as batch_op:
batch_op.alter_column('plugin_id',
existing_type=sa.String(length=512),
type_=sa.UUID(),
existing_nullable=False)
# ### end Alembic commands ###

View File

@ -1,3 +1,4 @@
import enum
import json import json
from datetime import datetime from datetime import datetime
from typing import Any, cast from typing import Any, cast
@ -17,6 +18,65 @@ from .model import Account, App, Tenant
from .types import StringUUID from .types import StringUUID
class ToolProviderCredentialType(enum.StrEnum):
API_KEY = "api_key",
OAUTH2 = "oauth2",
def is_editable(self):
return self == ToolProviderCredentialType.API_KEY
@classmethod
def get_credential_type(cls, credential_type: str) -> "ToolProviderCredentialType":
if credential_type == "api_key":
return cls.API_KEY
elif credential_type == "oauth2":
return cls.OAUTH2
else:
raise ValueError(f"Invalid credential type: {credential_type}")
# system level tool oauth client params (client_id, client_secret, etc.)
class ToolOAuthSystemClient(Base):
__tablename__ = "tool_oauth_system_clients"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tool_oauth_system_client_pkey"),
db.UniqueConstraint("plugin_id", "provider", name="tool_oauth_system_client_plugin_id_provider_idx"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
plugin_id: Mapped[str] = mapped_column(db.String(512), nullable=False)
provider: Mapped[str] = mapped_column(db.String(255), nullable=False)
# owner type, e.g., "system", "user"
# oauth params of the tool provider
encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False)
@property
def oauth_params(self) -> dict:
return cast(dict, json.loads(self.encrypted_oauth_params))
# user level tool oauth client params (client_id, client_secret, etc.)
class ToolOAuthUserClient(Base):
__tablename__ = "tool_oauth_user_clients"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tool_oauth_user_client_pkey"),
db.UniqueConstraint("tenant_id", "plugin_id", "provider", name="unique_tool_oauth_user_client"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
# tenant id
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
plugin_id: Mapped[str] = mapped_column(db.String(512), nullable=False)
provider: Mapped[str] = mapped_column(db.String(255), nullable=False)
owner_type: Mapped[str] = mapped_column(db.Text, nullable=False)
enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true"))
# oauth params of the tool provider
encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False)
@property
def oauth_params(self) -> dict:
return cast(dict, json.loads(self.encrypted_oauth_params))
class BuiltinToolProvider(Base): class BuiltinToolProvider(Base):
""" """
This table stores the tool provider information for built-in tools for each tenant. This table stores the tool provider information for built-in tools for each tenant.
@ -25,12 +85,11 @@ class BuiltinToolProvider(Base):
__tablename__ = "tool_builtin_providers" __tablename__ = "tool_builtin_providers"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="tool_builtin_provider_pkey"), db.PrimaryKeyConstraint("id", name="tool_builtin_provider_pkey"),
# one tenant can only have one tool provider with the same name
db.UniqueConstraint("tenant_id", "provider", name="unique_builtin_tool_provider"),
) )
# id of the tool provider # id of the tool provider
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
name: Mapped[str] = mapped_column(db.String(256), nullable=False)
# id of the tenant # id of the tenant
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=True) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=True)
# who created this tool provider # who created this tool provider
@ -45,6 +104,11 @@ class BuiltinToolProvider(Base):
updated_at: Mapped[datetime] = mapped_column( updated_at: Mapped[datetime] = mapped_column(
db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
) )
default: Mapped[bool] = mapped_column(
db.Boolean, nullable=False, server_default=db.text("false")
)
# credential type, e.g., "api_key", "oauth2"
credential_type: Mapped[str] = mapped_column(db.String(32), nullable=False, server_default=db.text("'api_key'::character varying"))
@property @property
def credentials(self) -> dict: def credentials(self) -> dict:
@ -59,7 +123,6 @@ class ApiToolProvider(Base):
__tablename__ = "tool_api_providers" __tablename__ = "tool_api_providers"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="tool_api_provider_pkey"), db.PrimaryKeyConstraint("id", name="tool_api_provider_pkey"),
db.UniqueConstraint("name", "tenant_id", name="unique_api_tool_provider"),
) )
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))

View File

@ -1,7 +1,62 @@
import json
import uuid
from core.plugin.impl.base import BasePluginClient from core.plugin.impl.base import BasePluginClient
from extensions.ext_redis import redis_client
class OAuthService(BasePluginClient): class OAuthProxyService(BasePluginClient):
@classmethod # Default max age for proxy context parameter in seconds
def get_authorization_url(cls, tenant_id: str, user_id: str, provider_name: str) -> str: __MAX_AGE__ = 5 * 60 # 5 minutes
return "1234567890"
@staticmethod
def create_proxy_context(user_id, tenant_id, plugin_id, provider):
"""
Create a proxy context for an OAuth 2.0 authorization request.
This parameter is a crucial security measure to prevent Cross-Site Request
Forgery (CSRF) attacks. It works by generating a unique nonce and storing it
in a distributed cache (Redis) along with the user's session context.
The returned nonce should be included as the 'proxy_context' parameter in the
authorization URL. Upon callback, the `retrieve_proxy_context` method
is used to verify the state, ensuring the request's integrity and authenticity,
and mitigating replay attacks.
"""
seconds, microseconds = redis_client.time()
context_id = str(uuid.uuid4())
data = {
"user_id": user_id,
"plugin_id": plugin_id,
"tenant_id": tenant_id,
"provider": provider,
# encode redis time to avoid distribution time skew
"timestamp": seconds,
}
# ignore nonce collision
redis_client.setex(
f"oauth_proxy_context:{context_id}",
OAuthProxyService.__MAX_AGE__,
json.dumps(data),
)
return context_id
@staticmethod
def use_proxy_context(context_id, max_age=__MAX_AGE__):
"""
Validate the proxy context parameter.
This checks if the context_id is valid and not expired.
"""
if not context_id:
raise ValueError("context_id is required")
# get data from redis
data = redis_client.getdel(f"oauth_proxy_context:{context_id}")
if not data:
raise ValueError("context_id is invalid")
# check if data is expired
seconds, microseconds = redis_client.time()
state = json.loads(data)
if state.get("timestamp") < seconds - max_age:
raise ValueError("context_id is expired")
return state

View File

@ -2,8 +2,6 @@ import json
import logging import logging
from pathlib import Path from pathlib import Path
from sqlalchemy.orm import Session
from configs import dify_config from configs import dify_config
from core.helper.position_helper import is_filtered from core.helper.position_helper import is_filtered
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
@ -16,7 +14,7 @@ from core.tools.tool_label_manager import ToolLabelManager
from core.tools.tool_manager import ToolManager from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ProviderConfigEncrypter from core.tools.utils.configuration import ProviderConfigEncrypter
from extensions.ext_database import db from extensions.ext_database import db
from models.tools import BuiltinToolProvider from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthUserClient, ToolProviderCredentialType
from services.tools.tools_transform_service import ToolTransformService from services.tools.tools_transform_service import ToolTransformService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -109,63 +107,69 @@ class BuiltinToolManageService:
@staticmethod @staticmethod
def update_builtin_tool_provider( def update_builtin_tool_provider(
session: Session, user_id: str, tenant_id: str, provider_name: str, credentials: dict user_id: str, tenant_id: str, provider_name:str, credentials: dict, credential_id: str, name: str | None = None
): ):
""" """
update builtin tool provider update builtin tool provider
""" """
# get if the provider exists # get if the provider exists
provider = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id) provider = BuiltinToolManageService._fetch_builtin_provider_by_id(tenant_id, credential_id)
if provider is None:
raise ValueError(f"you have not added provider {provider_name}")
if not ToolProviderCredentialType.get_credential_type(provider.credential_type).is_editable():
raise ValueError(f"you cannot update oauth2 provider {provider_name} credentials")
try: try:
# get provider # exclude oauth2 provider
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id) if provider.credential_type != ToolProviderCredentialType.OAUTH2.value:
if not provider_controller.need_credentials: provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
raise ValueError(f"provider {provider_name} does not need credentials") if not provider_controller.need_credentials:
tool_configuration = ProviderConfigEncrypter( raise ValueError(f"provider {provider_name} does not need credentials")
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)
# get original credentials if exists tool_configuration = ProviderConfigEncrypter(
if provider is not None: tenant_id=tenant_id,
original_credentials = tool_configuration.decrypt(provider.credentials) config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
masked_credentials = tool_configuration.mask_tool_credentials(original_credentials) provider_type=provider_controller.provider_type.value,
# check if the credential has changed, save the original credential provider_identity=provider_controller.entity.identity.name,
for name, value in credentials.items(): )
if name in masked_credentials and value == masked_credentials[name]:
credentials[name] = original_credentials[name] # Decrypt and restore original credentials for masked values
# validate credentials credentials = BuiltinToolManageService._dec
provider_controller.validate_credentials(user_id, credentials) rypt_and_restore_credentials(
# encrypt credentials provider_controller, tool_configuration, provider, credentials
credentials = tool_configuration.encrypt(credentials) )
# Encrypt and save the credentials
BuiltinToolManageService._encrypt_and_save_credentials(
provider_controller, tool_configuration, provider, credentials, user_id
)
# update name if provided
if name is not None and provider.name != name:
provider.name = name
db.session.commit()
except ( except (
PluginDaemonClientSideError, PluginDaemonClientSideError,
ToolProviderNotFoundError, ToolProviderNotFoundError,
ToolNotFoundError, ToolNotFoundError,
ToolProviderCredentialValidationError, ToolProviderCredentialValidationError,
) as e: ) as e:
raise ValueError(str(e)) raise ValueError(str(e))
if provider is None: return {"result": "success"}
# create provider
provider = BuiltinToolProvider(
tenant_id=tenant_id,
user_id=user_id,
provider=provider_name,
encrypted_credentials=json.dumps(credentials),
)
db.session.add(provider) @staticmethod
else: def add_builtin_tool_provider(
provider.encrypted_credentials = json.dumps(credentials) user_id: str, tenant_id: str, provider_name: str, credentials: dict, name: str | None = None
):
"""
add builtin tool provider
"""
# delete cache
tool_configuration.delete_tool_credentials_cache()
db.session.commit()
return {"result": "success"} return {"result": "success"}
@staticmethod @staticmethod
@ -214,6 +218,78 @@ class BuiltinToolManageService:
return {"result": "success"} return {"result": "success"}
@staticmethod
def set_default_provider(tenant_id: str, user_id: str, provider: str, id: str):
"""
set default provider
"""
# get provider
target_provider = db.session.query(BuiltinToolProvider).filter_by(id=id).first()
if target_provider is None:
raise ValueError("provider not found")
# clear default provider
db.session.query(BuiltinToolProvider).filter_by(
tenant_id=tenant_id,
user_id=user_id,
provider=provider,
default=True
).update({"default": False})
# set new default provider
target_provider.default = True
db.session.commit()
return {"result": "success"}
@staticmethod
def fetch_default_provider(tenant_id: str, user_id: str, provider_name: str):
"""
fetch default provider
if there is no explicitly set default provider, return the oldest provider as default
"""
# 1. check if default provider exists
default_provider = db.session.query(BuiltinToolProvider).filter_by(
tenant_id=tenant_id,
user_id=user_id,
provider=provider_name,
default=True
).first()
if default_provider:
return default_provider
# 2. if no default provider, set the oldest provider as default
oldest_provider = (db.session.query(BuiltinToolProvider)
.filter_by(tenant_id=tenant_id, user_id=user_id, provider=provider_name)
.order_by(BuiltinToolProvider.created_at)
.first()
)
if oldest_provider:
return oldest_provider
raise ValueError(f"no default provider found for {provider_name}")
@staticmethod
def get_builtin_tool_provider(tenant_id: str, user_id: str, provider: str, plugin_id: str):
"""
get builtin tool provider
"""
user_client = db.session.query(ToolOAuthUserClient).filter_by(
tenant_id=tenant_id,
provider=provider,
plugin_id=plugin_id,
enabled=True,
).first()
if user_client:
plugin_oauth_config = user_client
else:
plugin_oauth_config = db.session.query(ToolOAuthSystemClient).filter_by(provider=provider).first()
if plugin_oauth_config:
return plugin_oauth_config
raise ValueError("no oauth available config found for this plugin")
@staticmethod @staticmethod
def get_builtin_tool_provider_icon(provider: str): def get_builtin_tool_provider_icon(provider: str):
""" """
@ -286,6 +362,15 @@ class BuiltinToolManageService:
return BuiltinToolProviderSort.sort(result) return BuiltinToolProviderSort.sort(result)
@staticmethod
def _fetch_builtin_provider_by_id(tenant_id: str, credential_id: str) -> BuiltinToolProvider | None:
provider = (db.session.query(BuiltinToolProvider)
.filter(BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.id == credential_id,
)
.first())
return provider
@staticmethod @staticmethod
def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None: def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None:
try: try:
@ -327,3 +412,42 @@ class BuiltinToolManageService:
) )
.first() .first()
) )
@staticmethod
def _decrypt_and_restore_credentials(provider_controller, tool_configuration, provider, credentials):
"""
Decrypt original credentials and restore masked values from the input credentials
:param provider_controller: the provider controller
:param tool_configuration: the tool configuration encrypter
:param provider: the provider object from database
:param credentials: the input credentials from user
:return: the processed credentials with original values restored
"""
original_credentials = tool_configuration.decrypt(provider.credentials)
masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
# check if the credential has changed, save the original credential
for name, value in credentials.items():
if name in masked_credentials and value == masked_credentials[name]: # type: ignore
credentials[name] = original_credentials[name] # type: ignore
return credentials
@staticmethod
def _encrypt_and_save_credentials(provider_controller, tool_configuration, provider, credentials, user_id):
"""
Validate and encrypt credentials, then save to database
:param provider_controller: the provider controller
:param tool_configuration: the tool configuration encrypter
:param provider: the provider object from database
:param credentials: the credentials to encrypt and save
:param user_id: the user id for validation
"""
# validate credentials
provider_controller.validate_credentials(user_id, credentials)
# encrypt credentials
encrypted_credentials = tool_configuration.encrypt(credentials)
provider.encrypted_credentials = json.dumps(encrypted_credentials)
tool_configuration.delete_tool_credentials_cache()

27
api/tool_oauth.http Normal file
View File

@ -0,0 +1,27 @@
@accessToken=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyX2lkIjoiYjM4Y2Y5N2MtODNiYS00MWI3LWEyZjMtMzZlOTgzZjE4YmQ5IiwiZXhwIjoxNzUwNDE3NDI0LCJpc3MiOiJTRUxGX0hPU1RFRCIsInN1YiI6IkNvbnNvbGUgQVBJIFBhc3Nwb3J0In0.pPCkISnSmnu3hOCyEVTIJoNeWxtx7E9LNy0cDQUy__Q
# set default credential
POST /console/api/workspaces/current/tool-provider/builtin/langgenius/github/github/set-default
Host: 127.0.0.1:5001
Content-Type: application/json
Authorization: Bearer {{accessToken}}
{
"id": "55fb78d2-0ce6-4496-9488-3b8d9f40818f"
}
###
# get oauth url
GET /console/api/oauth/plugin/tool?plugin_id=c58a1845-f3a4-4d93-b749-a71e9998b702/github&provider=github
Host: 127.0.0.1:5001
Authorization: Bearer {{accessToken}}
###
# get oauth token
GET /console/api/oauth/plugin/tool/callback?state=734072c2-d8ed-4b0b-8ed8-4efd69d15a4f&code=e2d68a6216a3b7d70d2f&state=NQCjFkMKtf32XCMHc8KBdw
Host: 127.0.0.1:5001
Authorization: Bearer {{accessToken}}