diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py index 96cb3f5602..c78b36c3b9 100644 --- a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py +++ b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py @@ -24,7 +24,13 @@ class DatasourcePluginOauthApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, provider, plugin_id): + def get(self): + parser = reqparse.RequestParser() + parser.add_argument("provider", type=str, required=True, nullable=False, location="args") + parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args") + args = parser.parse_args() + provider = args["provider"] + plugin_id = args["plugin_id"] # Check user role first if not current_user.is_editor: raise Forbidden() @@ -35,7 +41,7 @@ class DatasourcePluginOauthApi(Resource): if not plugin_oauth_config: raise NotFound() oauth_handler = OAuthHandler() - redirect_url = f"{dify_config.CONSOLE_WEB_URL}/oauth/datasource/provider/{provider}/plugin/{plugin_id}/callback" + redirect_url = f"{dify_config.CONSOLE_WEB_URL}/oauth/datasource/callback?provider={provider}&plugin_id={plugin_id}" system_credentials = plugin_oauth_config.system_credentials if system_credentials: system_credentials["redirect_url"] = redirect_url @@ -49,7 +55,13 @@ class DatasourceOauthCallback(Resource): @setup_required @login_required @account_initialization_required - def get(self, provider, plugin_id): + def get(self): + parser = reqparse.RequestParser() + parser.add_argument("provider", type=str, required=True, nullable=False, location="args") + parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args") + args = parser.parse_args() + provider = args["provider"] + plugin_id = args["plugin_id"] oauth_handler = OAuthHandler() plugin_oauth_config = ( db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first() @@ -76,11 +88,13 @@ class DatasourceAuth(Resource): @setup_required @login_required @account_initialization_required - def post(self, provider, plugin_id): + def post(self): if not current_user.is_editor: raise Forbidden() parser = reqparse.RequestParser() + parser.add_argument("provider", type=str, required=True, nullable=False, location="json") + parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="json") parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() @@ -89,8 +103,8 @@ class DatasourceAuth(Resource): try: datasource_provider_service.datasource_provider_credentials_validate( tenant_id=current_user.current_tenant_id, - provider=provider, - plugin_id=plugin_id, + provider=args["provider"], + plugin_id=args["plugin_id"], credentials=args["credentials"], ) except CredentialsValidateFailedError as ex: @@ -101,10 +115,16 @@ class DatasourceAuth(Resource): @setup_required @login_required @account_initialization_required - def get(self, provider, plugin_id): + def get(self): + parser = reqparse.RequestParser() + parser.add_argument("provider", type=str, required=True, nullable=False, location="args") + parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args") + args = parser.parse_args() datasource_provider_service = DatasourceProviderService() datasources = datasource_provider_service.get_datasource_credentials( - tenant_id=current_user.current_tenant_id, provider=provider, plugin_id=plugin_id + tenant_id=current_user.current_tenant_id, + provider=args["provider"], + plugin_id=args["plugin_id"] ) return {"result": datasources}, 200 @@ -113,12 +133,18 @@ class DatasourceAuthDeleteApi(Resource): @setup_required @login_required @account_initialization_required - def delete(self, provider, plugin_id): + def delete(self): + parser = reqparse.RequestParser() + parser.add_argument("provider", type=str, required=True, nullable=False, location="args") + parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args") + args = parser.parse_args() if not current_user.is_editor: raise Forbidden() datasource_provider_service = DatasourceProviderService() datasource_provider_service.remove_datasource_credentials( - tenant_id=current_user.current_tenant_id, provider=provider, plugin_id=plugin_id + tenant_id=current_user.current_tenant_id, + provider=args["provider"], + plugin_id=args["plugin_id"] ) return {"result": "success"}, 200 @@ -126,13 +152,13 @@ class DatasourceAuthDeleteApi(Resource): # Import Rag Pipeline api.add_resource( DatasourcePluginOauthApi, - "/oauth/datasource/provider//plugin/", + "/oauth/plugin/datasource", ) api.add_resource( DatasourceOauthCallback, - "/oauth/datasource/provider//plugin//callback", + "/oauth/plugin/datasource/callback", ) api.add_resource( DatasourceAuth, - "/auth/datasource/provider//plugin/", + "/auth/plugin/datasource", ) diff --git a/api/core/datasource/datasource_manager.py b/api/core/datasource/datasource_manager.py index 46b36d8349..838fee5b96 100644 --- a/api/core/datasource/datasource_manager.py +++ b/api/core/datasource/datasource_manager.py @@ -24,7 +24,7 @@ class DatasourceManager: @classmethod def get_datasource_plugin_provider( - cls, provider: str, tenant_id: str, datasource_type: DatasourceProviderType + cls, provider_id: str, tenant_id: str, datasource_type: DatasourceProviderType ) -> DatasourcePluginProviderController: """ get the datasource plugin provider @@ -38,13 +38,13 @@ class DatasourceManager: with contexts.datasource_plugin_providers_lock.get(): datasource_plugin_providers = contexts.datasource_plugin_providers.get() - if provider in datasource_plugin_providers: - return datasource_plugin_providers[provider] + if provider_id in datasource_plugin_providers: + return datasource_plugin_providers[provider_id] manager = PluginDatasourceManager() - provider_entity = manager.fetch_datasource_provider(tenant_id, provider) + provider_entity = manager.fetch_datasource_provider(tenant_id, provider_id) if not provider_entity: - raise DatasourceProviderNotFoundError(f"plugin provider {provider} not found") + raise DatasourceProviderNotFoundError(f"plugin provider {provider_id} not found") match datasource_type: case DatasourceProviderType.ONLINE_DOCUMENT: @@ -71,7 +71,7 @@ class DatasourceManager: case _: raise ValueError(f"Unsupported datasource type: {datasource_type}") - datasource_plugin_providers[provider] = controller + datasource_plugin_providers[provider_id] = controller return controller diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py index ea357d85b2..f469b51224 100644 --- a/api/core/plugin/impl/datasource.py +++ b/api/core/plugin/impl/datasource.py @@ -40,16 +40,25 @@ class PluginDatasourceManager(BasePluginClient): ) local_file_datasource_provider = PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider()) - return [local_file_datasource_provider] + response + all_response = [local_file_datasource_provider] + response - def fetch_datasource_provider(self, tenant_id: str, provider: str) -> PluginDatasourceProviderEntity: + for provider in all_response: + provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}" + + # override the provider name for each tool to plugin_id/provider_name + for tool in provider.declaration.datasources: + tool.identity.provider = provider.declaration.identity.name + + return all_response + + def fetch_datasource_provider(self, tenant_id: str, provider_id: str) -> PluginDatasourceProviderEntity: """ Fetch datasource provider for the given tenant and plugin. """ - if provider == "langgenius/file/file": + if provider_id == "langgenius/file/file": return PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider()) - tool_provider_id = ToolProviderID(provider) + tool_provider_id = ToolProviderID(provider_id) def transformer(json_response: dict[str, Any]) -> dict: data = json_response.get("data") @@ -225,13 +234,13 @@ class PluginDatasourceManager(BasePluginClient): def _get_local_file_datasource_provider(self) -> dict[str, Any]: return { "id": "langgenius/file/file", - "plugin_id": "langgenius/file/file", - "provider": "langgenius", + "plugin_id": "langgenius/file", + "provider": "file", "plugin_unique_identifier": "langgenius/file:0.0.1@dify", "declaration": { "identity": { "author": "langgenius", - "name": "langgenius/file/file", + "name": "file", "label": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"}, "icon": "https://cloud.dify.ai/console/api/workspaces/current/plugin/icon?tenant_id=945b4365-9d99-48c1-8c47-90593fe8b9c9&filename=13d9312f6b1352d3939b90a5257de58ff3cd619d5be4f5b266ff0298935ac328.svg", "description": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"}, @@ -243,7 +252,7 @@ class PluginDatasourceManager(BasePluginClient): "identity": { "author": "langgenius", "name": "upload-file", - "provider": "langgenius", + "provider": "file", "label": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"}, }, "parameters": [], diff --git a/api/models/oauth.py b/api/models/oauth.py index 2fb34f0ac9..d823bcae16 100644 --- a/api/models/oauth.py +++ b/api/models/oauth.py @@ -25,12 +25,12 @@ class DatasourceProvider(Base): __tablename__ = "datasource_providers" __table_args__ = ( db.PrimaryKeyConstraint("id", name="datasource_provider_pkey"), - db.UniqueConstraint("plugin_id", "provider", name="datasource_provider_plugin_id_provider_idx"), + db.UniqueConstraint("plugin_id", "provider", "auth_type", name="datasource_provider_auth_type_provider_idx"), ) id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) - plugin_id: Mapped[str] = db.Column(db.TEXT, nullable=False) provider: Mapped[str] = db.Column(db.String(255), nullable=False) + plugin_id: Mapped[str] = db.Column(db.TEXT, nullable=False) auth_type: Mapped[str] = db.Column(db.String(255), nullable=False) encrypted_credentials: Mapped[dict] = db.Column(JSONB, nullable=False) created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now) diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index ccafc5555c..71edec760f 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -38,11 +38,14 @@ class DatasourceProviderService: # Get all provider configurations of the current workspace datasource_provider = ( db.session.query(DatasourceProvider) - .filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id) + .filter_by(tenant_id=tenant_id, plugin_id=plugin_id, auth_type="api_key") .first() ) - provider_credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, provider=provider) + provider_credential_secret_variables = self.extract_secret_variables( + tenant_id=tenant_id, + provider_id=f"{plugin_id}/{provider}" + ) if not datasource_provider: for key, value in credentials.items(): if key in provider_credential_secret_variables: @@ -73,14 +76,16 @@ class DatasourceProviderService: else: raise CredentialsValidateFailedError() - def extract_secret_variables(self, tenant_id: str, provider: str) -> list[str]: + def extract_secret_variables(self, tenant_id: str, provider_id: str) -> list[str]: """ Extract secret input form variables. :param credential_form_schemas: :return: """ - datasource_provider = self.provider_manager.fetch_datasource_provider(tenant_id=tenant_id, provider=provider) + datasource_provider = self.provider_manager.fetch_datasource_provider(tenant_id=tenant_id, + provider_id=provider_id + ) credential_form_schemas = datasource_provider.declaration.credentials_schema secret_input_form_variables = [] for credential_form_schema in credential_form_schemas: @@ -94,8 +99,7 @@ class DatasourceProviderService: get datasource credentials. :param tenant_id: workspace id - :param provider: provider name - :param plugin_id: plugin id + :param provider_id: provider id :return: """ # Get all provider configurations of the current workspace @@ -114,7 +118,7 @@ class DatasourceProviderService: for datasource_provider in datasource_providers: encrypted_credentials = datasource_provider.encrypted_credentials # Get provider credential secret variables - credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, provider=provider) + credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, provider_id=provider) # Obfuscate provider credentials copy_credentials = encrypted_credentials.copy()