diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py index 16b32856cb..01b3372e24 100644 --- a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py +++ b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py @@ -81,12 +81,13 @@ class DatasourceOAuthCallback(Resource): user_id, tenant_id = context.get("user_id"), context.get("tenant_id") datasource_provider_id = DatasourceProviderID(provider_id) - provider_name = datasource_provider_id.provider_name plugin_id = datasource_provider_id.plugin_id - plugin_oauth_config = ( - db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider_name, plugin_id=plugin_id).first() + datasource_provider_service = DatasourceProviderService() + oauth_client_params = datasource_provider_service.get_oauth_client( + tenant_id=tenant_id, + datasource_provider_id=datasource_provider_id, ) - if not plugin_oauth_config: + if not oauth_client_params: raise NotFound() redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/datasource/callback" oauth_handler = OAuthHandler() @@ -96,10 +97,9 @@ class DatasourceOAuthCallback(Resource): plugin_id=plugin_id, provider=datasource_provider_id.provider_name, redirect_uri=redirect_uri, - system_credentials=plugin_oauth_config.system_credentials, + system_credentials=oauth_client_params, request=request, ) - datasource_provider_service = DatasourceProviderService() datasource_provider_service.add_datasource_oauth_provider( tenant_id=tenant_id, provider_id=datasource_provider_id, @@ -205,8 +205,28 @@ class DatasourceAuthListApi(Resource): ) return {"result": datasources}, 200 +class DatasourceAuthOauthCustomClient(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self, provider_id: str): + if not current_user.is_editor: + raise Forbidden() + parser = reqparse.RequestParser() + parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json") + parser.add_argument("enabled", type=bool, required=False, nullable=True, location="json") + args = parser.parse_args() + datasource_provider_id = DatasourceProviderID(provider_id) + datasource_provider_service = DatasourceProviderService() + + datasource_provider_service.setup_oauth_custom_client_params( + tenant_id=current_user.current_tenant_id, + datasource_provider_id=datasource_provider_id, + client_params=args.get("client_params", {}), + enabled=args.get("enabled", False), + ) + return {"result": "success"}, 200 -# Import Rag Pipeline api.add_resource( DatasourcePluginOAuthAuthorizationUrl, "/oauth/plugin//datasource/get-authorization-url", @@ -229,3 +249,8 @@ api.add_resource( DatasourceAuthListApi, "/auth/plugin/datasource/list", ) + +api.add_resource( + DatasourceAuthOauthCustomClient, + "/auth/plugin/datasource//custom-client", +) diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py index e1c14df4e8..10bdce7349 100644 --- a/api/core/plugin/impl/datasource.py +++ b/api/core/plugin/impl/datasource.py @@ -361,4 +361,4 @@ class PluginDatasourceManager(BasePluginClient): } ], }, - } + } \ No newline at end of file diff --git a/api/core/tools/utils/encryption.py b/api/core/tools/utils/encryption.py index 5fdfd3b9d1..7f4113bb77 100644 --- a/api/core/tools/utils/encryption.py +++ b/api/core/tools/utils/encryption.py @@ -124,11 +124,15 @@ class ProviderConfigEncrypter: return data -def create_provider_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache): +def create_provider_encrypter( + tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache +) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]: return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache -def create_tool_provider_encrypter(tenant_id: str, controller: ToolProviderController): +def create_tool_provider_encrypter( + tenant_id: str, controller: ToolProviderController +) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]: cache = SingletonProviderCredentialsCache( tenant_id=tenant_id, provider_type=controller.provider_type.value, diff --git a/api/core/workflow/system_variable.py b/api/core/workflow/system_variable.py index 0d3af3195f..cd3388de7e 100644 --- a/api/core/workflow/system_variable.py +++ b/api/core/workflow/system_variable.py @@ -1,5 +1,5 @@ -from collections.abc import Sequence -from typing import Any, Mapping +from collections.abc import Mapping, Sequence +from typing import Any from pydantic import AliasChoices, BaseModel, ConfigDict, Field, model_validator diff --git a/api/migrations/versions/2025_07_21_1220-d3c68680d3ba_add_pipeline_info_14.py b/api/migrations/versions/2025_07_21_1220-d3c68680d3ba_add_pipeline_info_14.py new file mode 100644 index 0000000000..1fec5c0703 --- /dev/null +++ b/api/migrations/versions/2025_07_21_1220-d3c68680d3ba_add_pipeline_info_14.py @@ -0,0 +1,40 @@ +"""add_pipeline_info_14 + +Revision ID: d3c68680d3ba +Revises: fcb46171d891 +Create Date: 2025-07-21 12:20:29.582951 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'd3c68680d3ba' +down_revision = 'fcb46171d891' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('datasource_oauth_tenant_params', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('provider', sa.String(length=255), nullable=False), + sa.Column('plugin_id', sa.String(length=255), nullable=False), + sa.Column('client_params', postgresql.JSONB(astext_type=sa.Text()), nullable=False), + sa.Column('enabled', sa.Boolean(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.PrimaryKeyConstraint('id', name='datasource_oauth_tenant_config_pkey'), + sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', name='datasource_oauth_tenant_config_unique') + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('datasource_oauth_tenant_params') + # ### end Alembic commands ### diff --git a/api/models/oauth.py b/api/models/oauth.py index 3fe59bb650..65bcb5c0a3 100644 --- a/api/models/oauth.py +++ b/api/models/oauth.py @@ -35,7 +35,25 @@ class DatasourceProvider(Base): plugin_id: Mapped[str] = db.Column(db.String(255), nullable=False) auth_type: Mapped[str] = db.Column(db.String(255), nullable=False) encrypted_credentials: Mapped[dict] = db.Column(JSONB, nullable=False) - avatar_url: Mapped[str] = db.Column(db.String(255), nullable=True) + avatar_url: Mapped[str] = db.Column(db.String(255), nullable=True, default="default") + + created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now) + updated_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now) + + +class DatasourceOauthTenantParamConfig(Base): + __tablename__ = "datasource_oauth_tenant_params" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="datasource_oauth_tenant_config_pkey"), + db.UniqueConstraint("tenant_id", "plugin_id", "provider", name="datasource_oauth_tenant_config_unique"), + ) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id = db.Column(StringUUID, nullable=False) + provider: Mapped[str] = db.Column(db.String(255), nullable=False) + plugin_id: Mapped[str] = db.Column(db.String(255), nullable=False) + client_params: Mapped[dict] = db.Column(JSONB, nullable=False, default={}) + enabled: Mapped[bool] = db.Column(db.Boolean, nullable=False, default=False) created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now) updated_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now) diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index 87735563f7..ea4c00da2b 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -1,19 +1,22 @@ import logging +from typing import Any from flask_login import current_user from sqlalchemy.orm import Session -from constants import HIDDEN_VALUE +from constants import HIDDEN_VALUE, UNKNOWN_VALUE from core.helper import encrypter from core.helper.name_generator import generate_incremental_name +from core.helper.provider_cache import NoOpProviderCredentialCache from core.model_runtime.entities.provider_entities import FormType from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.plugin.entities.plugin import DatasourceProviderID from core.plugin.impl.datasource import PluginDatasourceManager from core.tools.entities.tool_entities import CredentialType +from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter from extensions.ext_database import db from extensions.ext_redis import redis_client -from models.oauth import DatasourceProvider +from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider logger = logging.getLogger(__name__) @@ -26,6 +29,165 @@ class DatasourceProviderService: def __init__(self) -> None: self.provider_manager = PluginDatasourceManager() + def setup_oauth_custom_client_params( + self, + tenant_id: str, + datasource_provider_id: DatasourceProviderID, + client_params: dict | None, + enabled: bool | None, + ): + """ + setup oauth custom client params + """ + if client_params is None and enabled is None: + return + provider_controller = PluginDatasourceManager() + datasource_provider = provider_controller.fetch_datasource_provider( + tenant_id=tenant_id, provider_id=str(datasource_provider_id) + ) + if not datasource_provider.declaration.oauth_schema: + raise ValueError("Datasource provider oauth schema not found") + with Session(db.engine) as session: + tenant_oauth_client_params = ( + session.query(DatasourceOauthTenantParamConfig) + .filter_by( + tenant_id=tenant_id, + provider=datasource_provider_id.provider_name, + plugin_id=datasource_provider_id.plugin_id, + ) + .first() + ) + + if not tenant_oauth_client_params: + tenant_oauth_client_params = DatasourceOauthTenantParamConfig( + tenant_id=tenant_id, + provider=datasource_provider_id.provider_name, + plugin_id=datasource_provider_id.plugin_id, + client_params={}, + enabled=False, + ) + session.add(tenant_oauth_client_params) + + if client_params is not None: + client_schema = datasource_provider.declaration.oauth_schema.client_schema + encrypter, _ = create_provider_encrypter( + tenant_id=tenant_id, + config=[x.to_basic_provider_config() for x in client_schema], + cache=NoOpProviderCredentialCache(), + ) + original_params = ( + encrypter.decrypt(tenant_oauth_client_params.client_params) if tenant_oauth_client_params else {} + ) + new_params: dict = { + key: value if value != HIDDEN_VALUE else original_params.get(key, UNKNOWN_VALUE) + for key, value in client_params.items() + } + tenant_oauth_client_params.client_params = encrypter.encrypt(new_params) + + if enabled is not None: + tenant_oauth_client_params.enabled = enabled + session.commit() + + def is_system_oauth_params_exist(self, datasource_provider_id: DatasourceProviderID) -> bool: + """ + check if system oauth params exist + """ + with Session(db.engine).no_autoflush as session: + return ( + session.query(DatasourceOauthParamConfig) + .filter_by(provider=datasource_provider_id.provider_name, plugin_id=datasource_provider_id.plugin_id) + .first() + is not None + ) + + def is_tenant_oauth_params_enabled(self, tenant_id: str, datasource_provider_id: DatasourceProviderID) -> bool: + """ + check if tenant oauth params is enabled + """ + with Session(db.engine).no_autoflush as session: + return ( + session.query(DatasourceOauthTenantParamConfig) + .filter_by( + tenant_id=tenant_id, + provider=datasource_provider_id.provider_name, + plugin_id=datasource_provider_id.plugin_id, + enabled=True, + ) + .count() + > 0 + ) + + def get_tenant_oauth_client( + self, tenant_id: str, datasource_provider_id: DatasourceProviderID + ) -> dict[str, Any] | None: + """ + get tenant oauth client + """ + with Session(db.engine).no_autoflush as session: + tenant_oauth_client_params = ( + session.query(DatasourceOauthTenantParamConfig) + .filter_by( + tenant_id=tenant_id, + provider=datasource_provider_id.provider_name, + plugin_id=datasource_provider_id.plugin_id, + ) + .first() + ) + if tenant_oauth_client_params: + encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id) + return encrypter.decrypt(tenant_oauth_client_params.client_params) + return None + + def get_oauth_encrypter( + self, tenant_id: str, datasource_provider_id: DatasourceProviderID + ) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]: + """ + get oauth encrypter + """ + datasource_provider = self.provider_manager.fetch_datasource_provider( + tenant_id=tenant_id, provider_id=str(datasource_provider_id) + ) + if not datasource_provider.declaration.oauth_schema: + raise ValueError("Datasource provider oauth schema not found") + + client_schema = datasource_provider.declaration.oauth_schema.client_schema + return create_provider_encrypter( + tenant_id=tenant_id, + config=[x.to_basic_provider_config() for x in client_schema], + cache=NoOpProviderCredentialCache(), + ) + + def get_oauth_client(self, tenant_id: str, datasource_provider_id: DatasourceProviderID) -> dict[str, Any] | None: + """ + get oauth client + """ + provider = datasource_provider_id.provider_name + plugin_id = datasource_provider_id.plugin_id + with Session(db.engine).no_autoflush as session: + # get tenant oauth client params + tenant_oauth_client_params = ( + session.query(DatasourceOauthTenantParamConfig) + .filter_by( + tenant_id=tenant_id, + provider=provider, + plugin_id=plugin_id, + enabled=True, + ) + .first() + ) + if tenant_oauth_client_params: + encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id) + return encrypter.decrypt(tenant_oauth_client_params.client_params) + + # fallback to system oauth client params + oauth_client_params = ( + session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first() + ) + if oauth_client_params: + return oauth_client_params.system_credentials + + raise ValueError(f"Please configure oauth client params(system/tenant) for {plugin_id}/{provider}") + @staticmethod def generate_next_datasource_provider_name( session: Session, tenant_id: str, provider_id: DatasourceProviderID, credential_type: CredentialType @@ -69,24 +231,29 @@ class DatasourceProviderService: credential_type=credential_type, ) else: - if session.query(DatasourceProvider).filter_by( - tenant_id=tenant_id, - name=db_provider_name, - provider=provider_id.provider_name, - plugin_id=provider_id.plugin_id, - auth_type=credential_type.value, - ).count() > 0: + if ( + session.query(DatasourceProvider) + .filter_by( + tenant_id=tenant_id, + name=db_provider_name, + provider=provider_id.provider_name, + plugin_id=provider_id.plugin_id, + auth_type=credential_type.value, + ) + .count() + > 0 + ): db_provider_name = generate_incremental_name( - [ - provider.name - for provider in session.query(DatasourceProvider).filter_by( - tenant_id=tenant_id, - provider=provider_id.provider_name, - plugin_id=provider_id.plugin_id, - ) - ], - db_provider_name, - ) + [ + provider.name + for provider in session.query(DatasourceProvider).filter_by( + tenant_id=tenant_id, + provider=provider_id.provider_name, + plugin_id=provider_id.plugin_id, + ) + ], + db_provider_name, + ) provider_credential_secret_variables = self.extract_secret_variables( tenant_id=tenant_id, provider_id=f"{provider_id}" @@ -103,7 +270,7 @@ class DatasourceProviderService: plugin_id=provider_id.plugin_id, auth_type=credential_type.value, encrypted_credentials=credentials, - avatar_url=avatar_url, + avatar_url=avatar_url or "default", ) session.add(datasource_provider) session.commit() @@ -222,6 +389,7 @@ class DatasourceProviderService: "credential": copy_credentials, "type": datasource_provider.auth_type, "name": datasource_provider.name, + "avatar_url": datasource_provider.avatar_url, "id": datasource_provider.id, } ) @@ -239,6 +407,7 @@ class DatasourceProviderService: datasources = manager.fetch_installed_datasource_providers(tenant_id) datasource_credentials = [] for datasource in datasources: + datasource_provider_id = DatasourceProviderID(f"{datasource.plugin_id}/{datasource.provider}") credentials = self.get_datasource_credentials( tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id ) @@ -302,6 +471,11 @@ class DatasourceProviderService: } for credential in datasource.declaration.oauth_schema.credentials_schema or [] ], + "oauth_custom_client_params": self.get_tenant_oauth_client(tenant_id, datasource_provider_id), + "is_oauth_custom_client_enabled": self.is_tenant_oauth_params_enabled( + tenant_id, datasource_provider_id + ), + "is_system_oauth_params_exists": self.is_system_oauth_params_exist(datasource_provider_id), } if datasource.declaration.oauth_schema else None,