From 875aea1c22e20131337cce446e43770c7bb83eb0 Mon Sep 17 00:00:00 2001 From: Harry Date: Wed, 30 Jul 2025 13:39:04 +0800 Subject: [PATCH 1/3] feat: datasource reauthentication --- .../datasets/rag_pipeline/datasource_auth.py | 35 ++++++++---- api/services/datasource_provider_service.py | 54 ++++++++++++++----- api/services/plugin/oauth_service.py | 10 +++- 3 files changed, 76 insertions(+), 23 deletions(-) diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py index 44954278d6..d67af182cd 100644 --- a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py +++ b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py @@ -32,6 +32,7 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource): if not current_user.is_editor: raise Forbidden() + credential_id = request.args.get("credential_id") datasource_provider_id = DatasourceProviderID(provider_id) provider_name = datasource_provider_id.provider_name plugin_id = datasource_provider_id.plugin_id @@ -43,7 +44,11 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource): raise ValueError(f"No OAuth Client Config for {provider_id}") context_id = OAuthProxyService.create_proxy_context( - user_id=current_user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name + user_id=current_user.id, + tenant_id=tenant_id, + plugin_id=plugin_id, + provider=provider_name, + credential_id=credential_id, ) oauth_handler = OAuthHandler() redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/datasource/callback" @@ -98,13 +103,24 @@ class DatasourceOAuthCallback(Resource): system_credentials=oauth_client_params, request=request, ) - datasource_provider_service.add_datasource_oauth_provider( - tenant_id=tenant_id, - provider_id=datasource_provider_id, - avatar_url=oauth_response.metadata.get("avatar_url") or None, - name=oauth_response.metadata.get("name") or None, - credentials=dict(oauth_response.credentials), - ) + credential_id = context.get("credential_id") + if credential_id: + datasource_provider_service.reauthorize_datasource_oauth_provider( + tenant_id=tenant_id, + provider_id=datasource_provider_id, + avatar_url=oauth_response.metadata.get("avatar_url") or None, + name=oauth_response.metadata.get("name") or None, + credentials=dict(oauth_response.credentials), + credential_id=context.get("credential_id"), + ) + else: + datasource_provider_service.add_datasource_oauth_provider( + tenant_id=tenant_id, + provider_id=datasource_provider_id, + avatar_url=oauth_response.metadata.get("avatar_url") or None, + name=oauth_response.metadata.get("name") or None, + credentials=dict(oauth_response.credentials), + ) return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback") @@ -208,7 +224,8 @@ class DatasourceAuthListApi(Resource): tenant_id=current_user.current_tenant_id ) return {"result": jsonable_encoder(datasources)}, 200 - + + class DatasourceHardCodeAuthListApi(Resource): @setup_required @login_required diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index 25966ed41a..dac406360f 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -82,19 +82,16 @@ class DatasourceProviderService: if key in credential_secret_variables: copy_credentials[key] = encrypter.decrypt_token(tenant_id, value) return copy_credentials - - def get_default_real_credential( - self, tenant_id: str, provider: str, plugin_id: str - ) -> dict[str, Any]: + + def get_default_real_credential(self, tenant_id: str, provider: str, plugin_id: str) -> dict[str, Any]: """ get default credential """ with Session(db.engine) as session: datasource_provider = ( - session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, - is_default=True, - provider=provider, - plugin_id=plugin_id).first() + session.query(DatasourceProvider) + .filter_by(tenant_id=tenant_id, is_default=True, provider=provider, plugin_id=plugin_id) + .first() ) if not datasource_provider: return {} @@ -357,6 +354,35 @@ class DatasourceProviderService: f"{credential_type.get_name()}", ) + def reauthorize_datasource_oauth_provider( + self, + name: str | None, + tenant_id: str, + provider_id: DatasourceProviderID, + avatar_url: str | None, + credentials: dict, + credential_id: str, + ) -> None: + """ + update datasource oauth provider + """ + with Session(db.engine) as session: + target_provider = session.query(DatasourceProvider).filter_by(id=credential_id, tenant_id=tenant_id).first() + if target_provider is None: + raise ValueError("provider not found") + + provider_credential_secret_variables = self.extract_secret_variables( + tenant_id=tenant_id, provider_id=f"{provider_id}", credential_type=CredentialType.OAUTH2 + ) + for key, value in credentials.items(): + if key in provider_credential_secret_variables: + credentials[key] = encrypter.encrypt_token(tenant_id, value) + + target_provider.encrypted_credentials = credentials + target_provider.avatar_url = avatar_url or target_provider.avatar_url + target_provider.name = name or target_provider.name + session.commit() + def add_datasource_oauth_provider( self, name: str | None, @@ -625,7 +651,7 @@ class DatasourceProviderService: } ) return datasource_credentials - + def get_hard_code_datasource_credentials(self, tenant_id: str) -> list[dict]: """ get hard code datasource credentials. @@ -637,14 +663,16 @@ class DatasourceProviderService: datasources = manager.fetch_installed_datasource_providers(tenant_id) datasource_credentials = [] for datasource in datasources: - if datasource.plugin_id in ["langgenius/firecrawl_datasource", "langgenius/notion_datasource", "langgenius/jina_datasource"]: + if datasource.plugin_id in [ + "langgenius/firecrawl_datasource", + "langgenius/notion_datasource", + "langgenius/jina_datasource", + ]: datasource_provider_id = DatasourceProviderID(f"{datasource.plugin_id}/{datasource.provider}") credentials = self.get_datasource_credentials( tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id ) - redirect_uri = ( - f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{datasource_provider_id}/datasource/callback" - ) + redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{datasource_provider_id}/datasource/callback" datasource_credentials.append( { "provider": datasource.provider, diff --git a/api/services/plugin/oauth_service.py b/api/services/plugin/oauth_service.py index b84dd0afc5..4a09e71504 100644 --- a/api/services/plugin/oauth_service.py +++ b/api/services/plugin/oauth_service.py @@ -11,7 +11,13 @@ class OAuthProxyService(BasePluginClient): __KEY_PREFIX__ = "oauth_proxy_context:" @staticmethod - def create_proxy_context(user_id: str, tenant_id: str, plugin_id: str, provider: str): + def create_proxy_context( + user_id: str, + tenant_id: str, + plugin_id: str, + provider: str, + credential_id: str | None = None, + ): """ Create a proxy context for an OAuth 2.0 authorization request. @@ -31,6 +37,8 @@ class OAuthProxyService(BasePluginClient): "tenant_id": tenant_id, "provider": provider, } + if credential_id: + data["credential_id"] = credential_id redis_client.setex( f"{OAuthProxyService.__KEY_PREFIX__}{context_id}", OAuthProxyService.__MAX_AGE__, From f37109ef39c342614400594182e0ac371db3e323 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Wed, 30 Jul 2025 14:34:38 +0800 Subject: [PATCH 2/3] transform document --- api/controllers/console/datasets/data_source.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index cfe9e7966f..077282d959 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -124,7 +124,7 @@ class DataSourceNotionListApi(Resource): credential = datasource_provider_service.get_real_credential_by_id( tenant_id=current_user.current_tenant_id, credential_id=credential_id, - provider="notion", + provider="notion_datasource", plugin_id="langgenius/notion_datasource", ) if not credential: @@ -155,8 +155,8 @@ class DataSourceNotionListApi(Resource): from core.datasource.datasource_manager import DatasourceManager datasource_runtime = DatasourceManager.get_datasource_runtime( - provider_id="langgenius/notion_datasource/notion", - datasource_name="notion", + provider_id="langgenius/notion_datasource/notion_datasource", + datasource_name="notion_datasource", tenant_id=current_user.current_tenant_id, datasource_type=DatasourceProviderType.ONLINE_DOCUMENT, ) @@ -209,7 +209,7 @@ class DataSourceNotionApi(Resource): credential = datasource_provider_service.get_real_credential_by_id( tenant_id=current_user.current_tenant_id, credential_id=credential_id, - provider="notion", + provider="notion_datasource", plugin_id="langgenius/notion_datasource", ) From 69738794bc4aa558bef7d916ed8ccf371971868d Mon Sep 17 00:00:00 2001 From: Joel Date: Wed, 30 Jul 2025 14:39:00 +0800 Subject: [PATCH 3/3] feat: support custom before run form --- .../_base/components/workflow-panel/index.tsx | 15 ++++++++++----- .../workflow-panel/last-run/use-last-run.ts | 13 +++++++++++++ .../nodes/data-source/before-run-form.tsx | 17 ++++++++++------- .../workflow/nodes/data-source/types.ts | 6 ++++++ 4 files changed, 39 insertions(+), 12 deletions(-) diff --git a/web/app/components/workflow/nodes/_base/components/workflow-panel/index.tsx b/web/app/components/workflow/nodes/_base/components/workflow-panel/index.tsx index a1566347b2..925a7df2d2 100644 --- a/web/app/components/workflow/nodes/_base/components/workflow-panel/index.tsx +++ b/web/app/components/workflow/nodes/_base/components/workflow-panel/index.tsx @@ -48,7 +48,6 @@ import { isSupportCustomRunForm, } from '@/app/components/workflow/utils' import Tooltip from '@/app/components/base/tooltip' -import type { CommonNodeType } from '@/app/components/workflow/types' import { BlockEnum, type Node, NodeRunningStatus } from '@/app/components/workflow/types' import { useStore as useAppStore } from '@/app/components/app/store' import { useStore } from '@/app/components/workflow/store' @@ -71,15 +70,16 @@ import { } from '@/app/components/plugins/plugin-auth' import { AuthCategory } from '@/app/components/plugins/plugin-auth' import { canFindTool } from '@/utils' -import type { DataSourceNodeType } from '@/app/components/workflow/nodes/data-source/types' +import type { CustomRunFormProps } from '@/app/components/workflow/nodes/data-source/types' import { DataSourceClassification } from '@/app/components/workflow/nodes/data-source/types' import { useModalContext } from '@/context/modal-context' import DataSourceBeforeRunForm from '@/app/components/workflow/nodes/data-source/before-run-form' -const getCustomRunForm = (nodeType: BlockEnum, payload: CommonNodeType): React.JSX.Element => { +const getCustomRunForm = (params: CustomRunFormProps): React.JSX.Element => { + const nodeType = params.payload.type switch (nodeType) { case BlockEnum.DataSource: - return + return default: return
Custom Run Form: {nodeType} not found
} @@ -227,6 +227,7 @@ const BasePanel: FC = ({ tabType, isRunAfterSingleRun, setTabType, + handleAfterCustomSingleRun, singleRunParams, nodeInfo, setRunInputData, @@ -306,7 +307,11 @@ const BasePanel: FC = ({ } if (isShowSingleRun) { - const form = getCustomRunForm(data.type, data) + const form = getCustomRunForm({ + payload: data, + onSuccess: handleAfterCustomSingleRun, + onCancel: hideSingleRun, + }) return (
= { [BlockEnum.LLM]: useLLMSingleRunFormParams, @@ -117,6 +118,7 @@ const useLastRun = ({ const isIterationNode = blockType === BlockEnum.Iteration const isLoopNode = blockType === BlockEnum.Loop const isAggregatorNode = blockType === BlockEnum.VariableAggregator + const isCustomRunNode = isSupportCustomRunForm(blockType) const { handleSyncWorkflowDraft } = useNodesSyncDraft() const { getData: getDataForCheckMore, @@ -299,10 +301,20 @@ const useLastRun = ({ }) } + const handleAfterCustomSingleRun = () => { + invalidLastRun() + setTabType(TabType.lastRun) + hideSingleRun() + } + const handleSingleRun = () => { const { isValid } = checkValid() if(!isValid) return + if(isCustomRunNode) { + showSingleRun() + return + } const vars = singleRunParams?.getDependentVars?.() // no need to input params if (isAggregatorNode ? checkAggregatorVarsSet(vars) : isAllVarsHasValue(vars)) { @@ -323,6 +335,7 @@ const useLastRun = ({ tabType, isRunAfterSingleRun, setTabType: handleTabClicked, + handleAfterCustomSingleRun, singleRunParams, nodeInfo, setRunInputData, diff --git a/web/app/components/workflow/nodes/data-source/before-run-form.tsx b/web/app/components/workflow/nodes/data-source/before-run-form.tsx index b3fa4f1d04..0875789021 100644 --- a/web/app/components/workflow/nodes/data-source/before-run-form.tsx +++ b/web/app/components/workflow/nodes/data-source/before-run-form.tsx @@ -1,18 +1,21 @@ 'use client' import type { FC } from 'react' import React from 'react' -import type { DataSourceNodeType } from './types' +import type { CustomRunFormProps, DataSourceNodeType } from './types' +import Button from '@/app/components/base/button' -type Props = { - payload: DataSourceNodeType -} - -const BeforeRunForm: FC = ({ +const BeforeRunForm: FC = ({ payload, + onSuccess, + onCancel, }) => { return (
- DataSource: {payload.datasource_name} + DataSource: {(payload as DataSourceNodeType).datasource_name} +
+ + +
) } diff --git a/web/app/components/workflow/nodes/data-source/types.ts b/web/app/components/workflow/nodes/data-source/types.ts index a13b5c0fd7..5d95c0f3c6 100644 --- a/web/app/components/workflow/nodes/data-source/types.ts +++ b/web/app/components/workflow/nodes/data-source/types.ts @@ -28,3 +28,9 @@ export type DataSourceNodeType = CommonNodeType & { datasource_parameters: ToolVarInputs datasource_configurations: Record } + +export type CustomRunFormProps = { + payload: CommonNodeType + onSuccess: () => void + onCancel: () => void +}