diff --git a/api/app.py b/api/app.py index 4f393f6c20..11decffe96 100644 --- a/api/app.py +++ b/api/app.py @@ -38,4 +38,4 @@ else: celery = app.extensions["celery"] if __name__ == "__main__": - app.run(host="0.0.0.0", port=5001) + app.run(host="0.0.0.0", port=5001,debug=True) diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 2b1379bfb2..e3285a16c7 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -1,18 +1,27 @@ import io -from flask import send_file +from flask import redirect, request, send_file from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_restful import ( + Resource, + reqparse, +) from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden from configs import dify_config from controllers.console import api -from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required +from controllers.console.wraps import ( + account_initialization_required, + enterprise_license_required, + setup_required, +) from core.model_runtime.utils.encoders import jsonable_encoder +from core.plugin.impl.oauth import OAuthHandler from extensions.ext_database import db from libs.helper import alphanumeric, uuid_value from libs.login import login_required +from services.plugin.oauth_service import OAuthProxyService from services.tools.api_tools_manage_service import ApiToolManageService from services.tools.builtin_tools_manage_service import BuiltinToolManageService from services.tools.tool_labels_service import ToolLabelsService @@ -108,17 +117,19 @@ class ToolBuiltinProviderUpdateApi(Resource): tenant_id = user.current_tenant_id parser = reqparse.RequestParser() + parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") + parser.add_argument("name", type=str, required=True, nullable=False, location="json") args = parser.parse_args() with Session(db.engine) as session: result = BuiltinToolManageService.update_builtin_tool_provider( - session=session, user_id=user_id, tenant_id=tenant_id, - provider_name=provider, credentials=args["credentials"], + credential_id=args["credential_id"], + name=args["name"] ) session.commit() return result @@ -555,9 +566,9 @@ class ToolBuiltinListApi(Resource): [ provider.to_dict() for provider in BuiltinToolManageService.list_builtin_tools( - user_id, - tenant_id, - ) + user_id, + tenant_id, + ) ] ) @@ -576,9 +587,9 @@ class ToolApiListApi(Resource): [ provider.to_dict() for provider in ApiToolManageService.list_api_tools( - user_id, - tenant_id, - ) + user_id, + tenant_id, + ) ] ) @@ -597,9 +608,9 @@ class ToolWorkflowListApi(Resource): [ provider.to_dict() for provider in WorkflowToolManageService.list_tenant_workflow_tools( - user_id, - tenant_id, - ) + user_id, + tenant_id, + ) ] ) @@ -613,6 +624,121 @@ class ToolLabelsApi(Resource): return jsonable_encoder(ToolLabelsService.list_tool_labels()) +class ToolPluginOAuthApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self): + parser = reqparse.RequestParser() + parser.add_argument("provider", type=str, required=True, nullable=False, location="args") + parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args") + args = parser.parse_args() + provider = args["provider"] + plugin_id = args["plugin_id"] + + # todo check permission + user = current_user + + if not user.is_admin_or_owner: + raise Forbidden() + + # check if user client is configured and enabled then using user client + # if user client is not configured then using system client + tenant_id = user.current_tenant_id + user_id = user.id + + plugin_oauth_config = BuiltinToolManageService.get_builtin_tool_provider( + tenant_id=tenant_id, + user_id=user_id, + provider=provider, + plugin_id=plugin_id, + ) + + oauth_handler = OAuthHandler() + context_id = OAuthProxyService.create_proxy_context(user_id=current_user.id, + tenant_id=tenant_id, + plugin_id=plugin_id, + provider=provider) + # todo decrypt oauth params + oauth_params = plugin_oauth_config.oauth_params + oauth_params[ + 'redirect_uri'] = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/tool/callback?context_id={context_id}" + + response = oauth_handler.get_authorization_url( + tenant_id, + user.id, + plugin_id, + provider, + system_credentials=oauth_params, + ) + return response.model_dump() + + +class ToolOAuthCallback(Resource): + + @setup_required + def get(self): + args = (reqparse + .RequestParser() + .add_argument("context_id", type=str, required=True, nullable=False, location="args") + .parse_args() + ) + context_id = args["context_id"] + context = OAuthProxyService.use_proxy_context(context_id) + if context is None: + raise Forbidden("Invalid context_id") + + user_id, tenant_id, plugin_id, provider = ( + context.get("user_id"), + context.get("tenant_id"), + context.get("plugin_id"), + context.get("provider"), + ) + oauth_handler = OAuthHandler() + plugin_oauth_config = BuiltinToolManageService.get_builtin_tool_provider( + tenant_id=tenant_id, + user_id=user_id, + provider=provider, + plugin_id=plugin_id, + ) + oauth_params = plugin_oauth_config.oauth_params + oauth_params['redirect_uri'] = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/tool/callback?context_id={context_id}" + + credentials = oauth_handler.get_credentials( + tenant_id, + user_id, + plugin_id, + provider, + system_credentials=oauth_params, + request=request, + ) + + if not credentials: + raise Exception("no credentials found for this plugin") + + #TODO add credentials to database + return redirect(f"{dify_config.CONSOLE_WEB_URL}") + + +class ToolBuiltinProviderSetDefaultApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self, provider): + parser = reqparse.RequestParser() + parser.add_argument("id", type=str, required=True, nullable=False, location="json") + args = parser.parse_args() + return BuiltinToolManageService.set_default_provider( + tenant_id=current_user.current_tenant_id, + user_id=current_user.id, + provider=provider, + id=args["id"]) + + +# tool oauth +api.add_resource(ToolPluginOAuthApi, "/oauth/plugin/tool") +api.add_resource(ToolOAuthCallback, "/oauth/plugin/tool/callback") + # tool provider api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers") @@ -621,6 +747,8 @@ api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-prov api.add_resource(ToolBuiltinProviderInfoApi, "/workspaces/current/tool-provider/builtin//info") api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin//delete") api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin//update") +api.add_resource(ToolBuiltinProviderSetDefaultApi, + "/workspaces/current/tool-provider/builtin//set-default") api.add_resource( ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin//credentials" ) diff --git a/api/core/plugin/impl/oauth.py b/api/core/plugin/impl/oauth.py index 91774984c8..13873b6ba8 100644 --- a/api/core/plugin/impl/oauth.py +++ b/api/core/plugin/impl/oauth.py @@ -1,3 +1,4 @@ +import binascii from collections.abc import Mapping from typing import Any @@ -16,7 +17,7 @@ class OAuthHandler(BasePluginClient): provider: str, system_credentials: Mapping[str, Any], ) -> PluginOAuthAuthorizationUrlResponse: - return self._request_with_plugin_daemon_response( + response = self._request_with_plugin_daemon_response_stream( "POST", f"plugin/{tenant_id}/dispatch/oauth/get_authorization_url", PluginOAuthAuthorizationUrlResponse, @@ -32,6 +33,10 @@ class OAuthHandler(BasePluginClient): "Content-Type": "application/json", }, ) + for resp in response: + return resp + raise ValueError("No response received from plugin daemon for authorization URL request.") + def get_credentials( self, @@ -49,7 +54,7 @@ class OAuthHandler(BasePluginClient): # encode request to raw http request raw_request_bytes = self._convert_request_to_raw_data(request) - return self._request_with_plugin_daemon_response( + response = self._request_with_plugin_daemon_response_stream( "POST", f"plugin/{tenant_id}/dispatch/oauth/get_credentials", PluginOAuthCredentialsResponse, @@ -58,7 +63,8 @@ class OAuthHandler(BasePluginClient): "data": { "provider": provider, "system_credentials": system_credentials, - "raw_request_bytes": raw_request_bytes, + # for json serialization + "raw_http_request": binascii.hexlify(raw_request_bytes).decode(), }, }, headers={ @@ -66,6 +72,10 @@ class OAuthHandler(BasePluginClient): "Content-Type": "application/json", }, ) + for resp in response: + return resp + raise ValueError("No response received from plugin daemon for authorization URL request.") + def _convert_request_to_raw_data(self, request: Request) -> bytes: """ @@ -79,7 +89,7 @@ class OAuthHandler(BasePluginClient): """ # Start with the request line method = request.method - path = request.path + path = request.full_path protocol = request.headers.get("HTTP_VERSION", "HTTP/1.1") raw_data = f"{method} {path} {protocol}\r\n".encode() diff --git a/api/dify-plugin-sdks/python/examples/github/provider/github.py b/api/dify-plugin-sdks/python/examples/github/provider/github.py new file mode 100644 index 0000000000..36f2f85910 --- /dev/null +++ b/api/dify-plugin-sdks/python/examples/github/provider/github.py @@ -0,0 +1,67 @@ +import secrets +import urllib.parse +from collections.abc import Mapping +from typing import Any + +import requests +from dify_plugin import ToolProvider +from dify_plugin.errors.tool import ToolProviderCredentialValidationError +from werkzeug import Request + + +class GithubProvider(ToolProvider): + _AUTH_URL = "https://github.com/login/oauth/authorize" + _TOKEN_URL = "https://github.com/login/oauth/access_token" + _API_USER_URL = "https://api.github.com/user" + + def _oauth_get_authorization_url(self, system_credentials: Mapping[str, Any]) -> str: + """ + Generate the authorization URL for the Github OAuth. + """ + state = secrets.token_urlsafe(16) + params = { + "client_id": system_credentials["client_id"], + "redirect_uri": system_credentials["redirect_uri"], + "scope": system_credentials.get("scope", "read:user"), + "state": state, + # Optionally: allow_signup, login, etc. + } + return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}" + + def _oauth_get_credentials(self, system_credentials: Mapping[str, Any], request: Request) -> Mapping[str, Any]: + """ + Exchange code for access_token. + """ + code = request.args.get("code") + state = request.args.get("state") + if not code: + raise ValueError("No code provided") + # Optionally: validate state here + + data = { + "client_id": system_credentials["client_id"], + "client_secret": system_credentials["client_secret"], + "code": code, + "redirect_uri": system_credentials["redirect_uri"], + } + headers = {"Accept": "application/json"} + response = requests.post(self._TOKEN_URL, data=data, headers=headers, timeout=10) + response_json = response.json() + access_token = response_json.get("access_token") + if not access_token: + raise ValueError(f"Error in GitHub OAuth: {response_json}") + return {"access_token": access_token} + + def _validate_credentials(self, credentials: dict) -> None: + try: + if "access_token" not in credentials or not credentials.get("access_token"): + raise ToolProviderCredentialValidationError("GitHub API Access Token is required.") + headers = { + "Authorization": f"Bearer {credentials['access_token']}", + "Accept": "application/vnd.github+json", + } + response = requests.get(self._API_USER_URL, headers=headers, timeout=10) + if response.status_code != 200: + raise ToolProviderCredentialValidationError(response.json().get("message")) + except Exception as e: + raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index a837552007..14ec4ebae0 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -1,3 +1,4 @@ +import os from datetime import timedelta import pytz @@ -24,12 +25,17 @@ def init_app(app: DifyApp) -> Celery: }, } + + flask_debugging = os.environ.get("FLASK_DEBUG", "0").lower() in {"true", "1", "yes"} + celery_app = Celery( app.name, task_cls=FlaskTask, broker=dify_config.CELERY_BROKER_URL, backend=dify_config.CELERY_BACKEND, task_ignore_result=True, + task_always_eager=flask_debugging, + task_eager_propagates=flask_debugging, ) # Add SSL options to the Celery configuration diff --git a/api/migrations/versions/2025_06_18_1506-99310d2c25a6_add_tool_oauth_credentials.py b/api/migrations/versions/2025_06_18_1506-99310d2c25a6_add_tool_oauth_credentials.py new file mode 100644 index 0000000000..95e74571d5 --- /dev/null +++ b/api/migrations/versions/2025_06_18_1506-99310d2c25a6_add_tool_oauth_credentials.py @@ -0,0 +1,66 @@ +"""add tool oauth credentials + +Revision ID: 99310d2c25a6 +Revises: 4474872b0ee6 +Create Date: 2025-06-18 15:06:15.261915 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '99310d2c25a6' +down_revision = '4474872b0ee6' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('tool_oauth_system_clients', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('plugin_id', models.types.StringUUID(), nullable=False), + sa.Column('provider', sa.String(length=255), nullable=False), + sa.Column('encrypted_oauth_params', sa.Text(), nullable=False), + sa.PrimaryKeyConstraint('id', name='tool_oauth_system_client_pkey'), + sa.UniqueConstraint('plugin_id', 'provider', name='tool_oauth_system_client_plugin_id_provider_idx') + ) + op.create_table('tool_oauth_user_clients', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('plugin_id', models.types.StringUUID(), nullable=False), + sa.Column('provider', sa.String(length=255), nullable=False), + sa.Column('encrypted_oauth_params', sa.Text(), nullable=False), + sa.PrimaryKeyConstraint('id', name='tool_oauth_user_client_pkey'), + sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', name='unique_tool_oauth_user_client') + ) + with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op: + batch_op.add_column(sa.Column('default', sa.Boolean(), server_default=sa.text('false'), nullable=False)) + batch_op.alter_column('credential_type', + existing_type=sa.VARCHAR(length=255), + type_=sa.String(length=32), + existing_nullable=False, + existing_server_default=sa.text("'api_key'::character varying")) + batch_op.drop_constraint(batch_op.f('unique_builtin_tool_provider'), type_='unique') + batch_op.create_unique_constraint('unique_builtin_tool_provider', ['tenant_id', 'provider', 'credential_type']) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op: + batch_op.drop_constraint('unique_builtin_tool_provider', type_='unique') + batch_op.create_unique_constraint(batch_op.f('unique_builtin_tool_provider'), ['tenant_id', 'provider']) + batch_op.alter_column('credential_type', + existing_type=sa.String(length=32), + type_=sa.VARCHAR(length=255), + existing_nullable=False, + existing_server_default=sa.text("'api_key'::character varying")) + batch_op.drop_column('default') + + op.drop_table('tool_oauth_user_clients') + op.drop_table('tool_oauth_system_clients') + # ### end Alembic commands ### diff --git a/api/migrations/versions/2025_06_19_1133-222376193a49_multiple_credential.py b/api/migrations/versions/2025_06_19_1133-222376193a49_multiple_credential.py new file mode 100644 index 0000000000..82e812cb3d --- /dev/null +++ b/api/migrations/versions/2025_06_19_1133-222376193a49_multiple_credential.py @@ -0,0 +1,39 @@ +"""multiple credential + +Revision ID: 222376193a49 +Revises: 99310d2c25a6 +Create Date: 2025-06-19 11:33:46.400455 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '222376193a49' +down_revision = '99310d2c25a6' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op: + batch_op.drop_constraint(batch_op.f('unique_builtin_tool_provider'), type_='unique') + + with op.batch_alter_table('tool_oauth_user_clients', schema=None) as batch_op: + batch_op.add_column(sa.Column('owner_type', sa.Text(), nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_oauth_user_clients', schema=None) as batch_op: + batch_op.drop_column('owner_type') + + with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op: + batch_op.create_unique_constraint(batch_op.f('unique_builtin_tool_provider'), ['tenant_id', 'provider', 'credential_type']) + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2025_06_19_1353-a9306e69af07_multiple_credential.py b/api/migrations/versions/2025_06_19_1353-a9306e69af07_multiple_credential.py new file mode 100644 index 0000000000..216661550a --- /dev/null +++ b/api/migrations/versions/2025_06_19_1353-a9306e69af07_multiple_credential.py @@ -0,0 +1,33 @@ +"""multiple credential + +Revision ID: a9306e69af07 +Revises: 222376193a49 +Create Date: 2025-06-19 13:53:41.554159 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'a9306e69af07' +down_revision = '222376193a49' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op: + batch_op.create_unique_constraint('unique_builtin_tool_provider', ['provider', 'tenant_id', 'default']) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op: + batch_op.drop_constraint('unique_builtin_tool_provider', type_='unique') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2025_06_19_1359-6835b906335f_multiple_credential.py b/api/migrations/versions/2025_06_19_1359-6835b906335f_multiple_credential.py new file mode 100644 index 0000000000..d90e0d178e --- /dev/null +++ b/api/migrations/versions/2025_06_19_1359-6835b906335f_multiple_credential.py @@ -0,0 +1,33 @@ +"""multiple credential + +Revision ID: 6835b906335f +Revises: e315d2a83984 +Create Date: 2025-06-19 13:59:58.107955 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '6835b906335f' +down_revision = 'e315d2a83984' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op: + batch_op.drop_constraint(batch_op.f('unique_builtin_tool_provider'), type_='unique') + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op: + batch_op.create_unique_constraint(batch_op.f('unique_builtin_tool_provider'), ['provider', 'tenant_id', 'default']) + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2025_06_19_1359-e315d2a83984_multiple_credential.py b/api/migrations/versions/2025_06_19_1359-e315d2a83984_multiple_credential.py new file mode 100644 index 0000000000..2f0caeaf0d --- /dev/null +++ b/api/migrations/versions/2025_06_19_1359-e315d2a83984_multiple_credential.py @@ -0,0 +1,33 @@ +"""multiple credential + +Revision ID: e315d2a83984 +Revises: a9306e69af07 +Create Date: 2025-06-19 13:59:13.860523 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'e315d2a83984' +down_revision = 'a9306e69af07' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: + batch_op.drop_constraint(batch_op.f('unique_api_tool_provider'), type_='unique') + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: + batch_op.create_unique_constraint(batch_op.f('unique_api_tool_provider'), ['name', 'tenant_id']) + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2025_06_19_1511-110e30078dd3_multiple_credential.py b/api/migrations/versions/2025_06_19_1511-110e30078dd3_multiple_credential.py new file mode 100644 index 0000000000..84a5461a4d --- /dev/null +++ b/api/migrations/versions/2025_06_19_1511-110e30078dd3_multiple_credential.py @@ -0,0 +1,53 @@ +"""multiple credential + +Revision ID: 110e30078dd3 +Revises: 6835b906335f +Create Date: 2025-06-19 15:11:42.688478 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '110e30078dd3' +down_revision = '6835b906335f' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_oauth_system_clients', schema=None) as batch_op: + batch_op.alter_column('plugin_id', + existing_type=sa.UUID(), + type_=sa.String(length=512), + existing_nullable=False) + + with op.batch_alter_table('tool_oauth_user_clients', schema=None) as batch_op: + batch_op.add_column(sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False)) + batch_op.alter_column('plugin_id', + existing_type=sa.UUID(), + type_=sa.String(length=512), + existing_nullable=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_oauth_user_clients', schema=None) as batch_op: + batch_op.alter_column('plugin_id', + existing_type=sa.String(length=512), + type_=sa.UUID(), + existing_nullable=False) + batch_op.drop_column('enabled') + + with op.batch_alter_table('tool_oauth_system_clients', schema=None) as batch_op: + batch_op.alter_column('plugin_id', + existing_type=sa.String(length=512), + type_=sa.UUID(), + existing_nullable=False) + + # ### end Alembic commands ### diff --git a/api/models/tools.py b/api/models/tools.py index 03fbc3acb1..4b493e7596 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -1,3 +1,4 @@ +import enum import json from datetime import datetime from typing import Any, cast @@ -17,6 +18,65 @@ from .model import Account, App, Tenant from .types import StringUUID +class ToolProviderCredentialType(enum.StrEnum): + API_KEY = "api_key", + OAUTH2 = "oauth2", + + def is_editable(self): + return self == ToolProviderCredentialType.API_KEY + + @classmethod + def get_credential_type(cls, credential_type: str) -> "ToolProviderCredentialType": + if credential_type == "api_key": + return cls.API_KEY + elif credential_type == "oauth2": + return cls.OAUTH2 + else: + raise ValueError(f"Invalid credential type: {credential_type}") + +# system level tool oauth client params (client_id, client_secret, etc.) +class ToolOAuthSystemClient(Base): + __tablename__ = "tool_oauth_system_clients" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="tool_oauth_system_client_pkey"), + db.UniqueConstraint("plugin_id", "provider", name="tool_oauth_system_client_plugin_id_provider_idx"), + ) + + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + plugin_id: Mapped[str] = mapped_column(db.String(512), nullable=False) + provider: Mapped[str] = mapped_column(db.String(255), nullable=False) + # owner type, e.g., "system", "user" + + # oauth params of the tool provider + encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False) + + @property + def oauth_params(self) -> dict: + return cast(dict, json.loads(self.encrypted_oauth_params)) + + +# user level tool oauth client params (client_id, client_secret, etc.) +class ToolOAuthUserClient(Base): + __tablename__ = "tool_oauth_user_clients" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="tool_oauth_user_client_pkey"), + db.UniqueConstraint("tenant_id", "plugin_id", "provider", name="unique_tool_oauth_user_client"), + ) + + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + # tenant id + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + plugin_id: Mapped[str] = mapped_column(db.String(512), nullable=False) + provider: Mapped[str] = mapped_column(db.String(255), nullable=False) + owner_type: Mapped[str] = mapped_column(db.Text, nullable=False) + enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) + # oauth params of the tool provider + encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False) + + @property + def oauth_params(self) -> dict: + return cast(dict, json.loads(self.encrypted_oauth_params)) + class BuiltinToolProvider(Base): """ This table stores the tool provider information for built-in tools for each tenant. @@ -25,12 +85,11 @@ class BuiltinToolProvider(Base): __tablename__ = "tool_builtin_providers" __table_args__ = ( db.PrimaryKeyConstraint("id", name="tool_builtin_provider_pkey"), - # one tenant can only have one tool provider with the same name - db.UniqueConstraint("tenant_id", "provider", name="unique_builtin_tool_provider"), ) # id of the tool provider id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + name: Mapped[str] = mapped_column(db.String(256), nullable=False) # id of the tenant tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=True) # who created this tool provider @@ -45,6 +104,11 @@ class BuiltinToolProvider(Base): updated_at: Mapped[datetime] = mapped_column( db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") ) + default: Mapped[bool] = mapped_column( + db.Boolean, nullable=False, server_default=db.text("false") + ) + # credential type, e.g., "api_key", "oauth2" + credential_type: Mapped[str] = mapped_column(db.String(32), nullable=False, server_default=db.text("'api_key'::character varying")) @property def credentials(self) -> dict: @@ -59,7 +123,6 @@ class ApiToolProvider(Base): __tablename__ = "tool_api_providers" __table_args__ = ( db.PrimaryKeyConstraint("id", name="tool_api_provider_pkey"), - db.UniqueConstraint("name", "tenant_id", name="unique_api_tool_provider"), ) id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) diff --git a/api/services/plugin/oauth_service.py b/api/services/plugin/oauth_service.py index 461247419b..dcc14a8fad 100644 --- a/api/services/plugin/oauth_service.py +++ b/api/services/plugin/oauth_service.py @@ -1,7 +1,62 @@ +import json +import uuid + from core.plugin.impl.base import BasePluginClient +from extensions.ext_redis import redis_client -class OAuthService(BasePluginClient): - @classmethod - def get_authorization_url(cls, tenant_id: str, user_id: str, provider_name: str) -> str: - return "1234567890" +class OAuthProxyService(BasePluginClient): + # Default max age for proxy context parameter in seconds + __MAX_AGE__ = 5 * 60 # 5 minutes + + @staticmethod + def create_proxy_context(user_id, tenant_id, plugin_id, provider): + """ + Create a proxy context for an OAuth 2.0 authorization request. + + This parameter is a crucial security measure to prevent Cross-Site Request + Forgery (CSRF) attacks. It works by generating a unique nonce and storing it + in a distributed cache (Redis) along with the user's session context. + + The returned nonce should be included as the 'proxy_context' parameter in the + authorization URL. Upon callback, the `retrieve_proxy_context` method + is used to verify the state, ensuring the request's integrity and authenticity, + and mitigating replay attacks. + """ + seconds, microseconds = redis_client.time() + context_id = str(uuid.uuid4()) + data = { + "user_id": user_id, + "plugin_id": plugin_id, + "tenant_id": tenant_id, + "provider": provider, + # encode redis time to avoid distribution time skew + "timestamp": seconds, + } + # ignore nonce collision + redis_client.setex( + f"oauth_proxy_context:{context_id}", + OAuthProxyService.__MAX_AGE__, + json.dumps(data), + ) + return context_id + + + @staticmethod + def use_proxy_context(context_id, max_age=__MAX_AGE__): + """ + Validate the proxy context parameter. + This checks if the context_id is valid and not expired. + """ + if not context_id: + raise ValueError("context_id is required") + # get data from redis + data = redis_client.getdel(f"oauth_proxy_context:{context_id}") + if not data: + raise ValueError("context_id is invalid") + # check if data is expired + seconds, microseconds = redis_client.time() + state = json.loads(data) + if state.get("timestamp") < seconds - max_age: + raise ValueError("context_id is expired") + return state diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 3ccd14415d..25d927f9f9 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -2,8 +2,6 @@ import json import logging from pathlib import Path -from sqlalchemy.orm import Session - from configs import dify_config from core.helper.position_helper import is_filtered from core.model_runtime.utils.encoders import jsonable_encoder @@ -16,7 +14,7 @@ from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ProviderConfigEncrypter from extensions.ext_database import db -from models.tools import BuiltinToolProvider +from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthUserClient, ToolProviderCredentialType from services.tools.tools_transform_service import ToolTransformService logger = logging.getLogger(__name__) @@ -109,63 +107,69 @@ class BuiltinToolManageService: @staticmethod def update_builtin_tool_provider( - session: Session, user_id: str, tenant_id: str, provider_name: str, credentials: dict + user_id: str, tenant_id: str, provider_name:str, credentials: dict, credential_id: str, name: str | None = None ): """ update builtin tool provider """ # get if the provider exists - provider = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id) + provider = BuiltinToolManageService._fetch_builtin_provider_by_id(tenant_id, credential_id) + + if provider is None: + raise ValueError(f"you have not added provider {provider_name}") + + if not ToolProviderCredentialType.get_credential_type(provider.credential_type).is_editable(): + raise ValueError(f"you cannot update oauth2 provider {provider_name} credentials") try: - # get provider - provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id) - if not provider_controller.need_credentials: - raise ValueError(f"provider {provider_name} does not need credentials") - tool_configuration = ProviderConfigEncrypter( - tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, - ) + # exclude oauth2 provider + if provider.credential_type != ToolProviderCredentialType.OAUTH2.value: + provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id) + if not provider_controller.need_credentials: + raise ValueError(f"provider {provider_name} does not need credentials") - # get original credentials if exists - if provider is not None: - original_credentials = tool_configuration.decrypt(provider.credentials) - masked_credentials = tool_configuration.mask_tool_credentials(original_credentials) - # check if the credential has changed, save the original credential - for name, value in credentials.items(): - if name in masked_credentials and value == masked_credentials[name]: - credentials[name] = original_credentials[name] - # validate credentials - provider_controller.validate_credentials(user_id, credentials) - # encrypt credentials - credentials = tool_configuration.encrypt(credentials) + tool_configuration = ProviderConfigEncrypter( + tenant_id=tenant_id, + config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], + provider_type=provider_controller.provider_type.value, + provider_identity=provider_controller.entity.identity.name, + ) + + # Decrypt and restore original credentials for masked values + credentials = BuiltinToolManageService._dec + rypt_and_restore_credentials( + provider_controller, tool_configuration, provider, credentials + ) + + # Encrypt and save the credentials + BuiltinToolManageService._encrypt_and_save_credentials( + provider_controller, tool_configuration, provider, credentials, user_id + ) + + # update name if provided + if name is not None and provider.name != name: + provider.name = name + + db.session.commit() except ( - PluginDaemonClientSideError, - ToolProviderNotFoundError, - ToolNotFoundError, - ToolProviderCredentialValidationError, + PluginDaemonClientSideError, + ToolProviderNotFoundError, + ToolNotFoundError, + ToolProviderCredentialValidationError, ) as e: raise ValueError(str(e)) - if provider is None: - # create provider - provider = BuiltinToolProvider( - tenant_id=tenant_id, - user_id=user_id, - provider=provider_name, - encrypted_credentials=json.dumps(credentials), - ) + return {"result": "success"} - db.session.add(provider) - else: - provider.encrypted_credentials = json.dumps(credentials) + @staticmethod + def add_builtin_tool_provider( + user_id: str, tenant_id: str, provider_name: str, credentials: dict, name: str | None = None + ): + """ + add builtin tool provider + """ + - # delete cache - tool_configuration.delete_tool_credentials_cache() - - db.session.commit() return {"result": "success"} @staticmethod @@ -214,6 +218,78 @@ class BuiltinToolManageService: return {"result": "success"} + @staticmethod + def set_default_provider(tenant_id: str, user_id: str, provider: str, id: str): + """ + set default provider + """ + # get provider + target_provider = db.session.query(BuiltinToolProvider).filter_by(id=id).first() + if target_provider is None: + raise ValueError("provider not found") + + # clear default provider + db.session.query(BuiltinToolProvider).filter_by( + tenant_id=tenant_id, + user_id=user_id, + provider=provider, + default=True + ).update({"default": False}) + + # set new default provider + target_provider.default = True + db.session.commit() + return {"result": "success"} + + @staticmethod + def fetch_default_provider(tenant_id: str, user_id: str, provider_name: str): + """ + fetch default provider + if there is no explicitly set default provider, return the oldest provider as default + """ + # 1. check if default provider exists + default_provider = db.session.query(BuiltinToolProvider).filter_by( + tenant_id=tenant_id, + user_id=user_id, + provider=provider_name, + default=True + ).first() + if default_provider: + return default_provider + + # 2. if no default provider, set the oldest provider as default + oldest_provider = (db.session.query(BuiltinToolProvider) + .filter_by(tenant_id=tenant_id, user_id=user_id, provider=provider_name) + .order_by(BuiltinToolProvider.created_at) + .first() + ) + if oldest_provider: + return oldest_provider + + raise ValueError(f"no default provider found for {provider_name}") + + @staticmethod + def get_builtin_tool_provider(tenant_id: str, user_id: str, provider: str, plugin_id: str): + """ + get builtin tool provider + """ + user_client = db.session.query(ToolOAuthUserClient).filter_by( + tenant_id=tenant_id, + provider=provider, + plugin_id=plugin_id, + enabled=True, + ).first() + + if user_client: + plugin_oauth_config = user_client + else: + plugin_oauth_config = db.session.query(ToolOAuthSystemClient).filter_by(provider=provider).first() + + if plugin_oauth_config: + return plugin_oauth_config + + raise ValueError("no oauth available config found for this plugin") + @staticmethod def get_builtin_tool_provider_icon(provider: str): """ @@ -286,6 +362,15 @@ class BuiltinToolManageService: return BuiltinToolProviderSort.sort(result) + @staticmethod + def _fetch_builtin_provider_by_id(tenant_id: str, credential_id: str) -> BuiltinToolProvider | None: + provider = (db.session.query(BuiltinToolProvider) + .filter(BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.id == credential_id, + ) + .first()) + return provider + @staticmethod def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None: try: @@ -327,3 +412,42 @@ class BuiltinToolManageService: ) .first() ) + + @staticmethod + def _decrypt_and_restore_credentials(provider_controller, tool_configuration, provider, credentials): + """ + Decrypt original credentials and restore masked values from the input credentials + + :param provider_controller: the provider controller + :param tool_configuration: the tool configuration encrypter + :param provider: the provider object from database + :param credentials: the input credentials from user + :return: the processed credentials with original values restored + """ + original_credentials = tool_configuration.decrypt(provider.credentials) + masked_credentials = tool_configuration.mask_tool_credentials(original_credentials) + + # check if the credential has changed, save the original credential + for name, value in credentials.items(): + if name in masked_credentials and value == masked_credentials[name]: # type: ignore + credentials[name] = original_credentials[name] # type: ignore + + return credentials + + @staticmethod + def _encrypt_and_save_credentials(provider_controller, tool_configuration, provider, credentials, user_id): + """ + Validate and encrypt credentials, then save to database + + :param provider_controller: the provider controller + :param tool_configuration: the tool configuration encrypter + :param provider: the provider object from database + :param credentials: the credentials to encrypt and save + :param user_id: the user id for validation + """ + # validate credentials + provider_controller.validate_credentials(user_id, credentials) + # encrypt credentials + encrypted_credentials = tool_configuration.encrypt(credentials) + provider.encrypted_credentials = json.dumps(encrypted_credentials) + tool_configuration.delete_tool_credentials_cache() diff --git a/api/tool_oauth.http b/api/tool_oauth.http new file mode 100644 index 0000000000..9915472d03 --- /dev/null +++ b/api/tool_oauth.http @@ -0,0 +1,27 @@ + +@accessToken=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyX2lkIjoiYjM4Y2Y5N2MtODNiYS00MWI3LWEyZjMtMzZlOTgzZjE4YmQ5IiwiZXhwIjoxNzUwNDE3NDI0LCJpc3MiOiJTRUxGX0hPU1RFRCIsInN1YiI6IkNvbnNvbGUgQVBJIFBhc3Nwb3J0In0.pPCkISnSmnu3hOCyEVTIJoNeWxtx7E9LNy0cDQUy__Q + + + +# set default credential +POST /console/api/workspaces/current/tool-provider/builtin/langgenius/github/github/set-default +Host: 127.0.0.1:5001 +Content-Type: application/json +Authorization: Bearer {{accessToken}} + +{ + "id": "55fb78d2-0ce6-4496-9488-3b8d9f40818f" +} +### + +# get oauth url +GET /console/api/oauth/plugin/tool?plugin_id=c58a1845-f3a4-4d93-b749-a71e9998b702/github&provider=github +Host: 127.0.0.1:5001 +Authorization: Bearer {{accessToken}} + +### + +# get oauth token +GET /console/api/oauth/plugin/tool/callback?state=734072c2-d8ed-4b0b-8ed8-4efd69d15a4f&code=e2d68a6216a3b7d70d2f&state=NQCjFkMKtf32XCMHc8KBdw +Host: 127.0.0.1:5001 +Authorization: Bearer {{accessToken}}