diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py
index cfe9e7966f..077282d959 100644
--- a/api/controllers/console/datasets/data_source.py
+++ b/api/controllers/console/datasets/data_source.py
@@ -124,7 +124,7 @@ class DataSourceNotionListApi(Resource):
credential = datasource_provider_service.get_real_credential_by_id(
tenant_id=current_user.current_tenant_id,
credential_id=credential_id,
- provider="notion",
+ provider="notion_datasource",
plugin_id="langgenius/notion_datasource",
)
if not credential:
@@ -155,8 +155,8 @@ class DataSourceNotionListApi(Resource):
from core.datasource.datasource_manager import DatasourceManager
datasource_runtime = DatasourceManager.get_datasource_runtime(
- provider_id="langgenius/notion_datasource/notion",
- datasource_name="notion",
+ provider_id="langgenius/notion_datasource/notion_datasource",
+ datasource_name="notion_datasource",
tenant_id=current_user.current_tenant_id,
datasource_type=DatasourceProviderType.ONLINE_DOCUMENT,
)
@@ -209,7 +209,7 @@ class DataSourceNotionApi(Resource):
credential = datasource_provider_service.get_real_credential_by_id(
tenant_id=current_user.current_tenant_id,
credential_id=credential_id,
- provider="notion",
+ provider="notion_datasource",
plugin_id="langgenius/notion_datasource",
)
diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py
index 44954278d6..d67af182cd 100644
--- a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py
+++ b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py
@@ -32,6 +32,7 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource):
if not current_user.is_editor:
raise Forbidden()
+ credential_id = request.args.get("credential_id")
datasource_provider_id = DatasourceProviderID(provider_id)
provider_name = datasource_provider_id.provider_name
plugin_id = datasource_provider_id.plugin_id
@@ -43,7 +44,11 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource):
raise ValueError(f"No OAuth Client Config for {provider_id}")
context_id = OAuthProxyService.create_proxy_context(
- user_id=current_user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name
+ user_id=current_user.id,
+ tenant_id=tenant_id,
+ plugin_id=plugin_id,
+ provider=provider_name,
+ credential_id=credential_id,
)
oauth_handler = OAuthHandler()
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/datasource/callback"
@@ -98,13 +103,24 @@ class DatasourceOAuthCallback(Resource):
system_credentials=oauth_client_params,
request=request,
)
- datasource_provider_service.add_datasource_oauth_provider(
- tenant_id=tenant_id,
- provider_id=datasource_provider_id,
- avatar_url=oauth_response.metadata.get("avatar_url") or None,
- name=oauth_response.metadata.get("name") or None,
- credentials=dict(oauth_response.credentials),
- )
+ credential_id = context.get("credential_id")
+ if credential_id:
+ datasource_provider_service.reauthorize_datasource_oauth_provider(
+ tenant_id=tenant_id,
+ provider_id=datasource_provider_id,
+ avatar_url=oauth_response.metadata.get("avatar_url") or None,
+ name=oauth_response.metadata.get("name") or None,
+ credentials=dict(oauth_response.credentials),
+ credential_id=context.get("credential_id"),
+ )
+ else:
+ datasource_provider_service.add_datasource_oauth_provider(
+ tenant_id=tenant_id,
+ provider_id=datasource_provider_id,
+ avatar_url=oauth_response.metadata.get("avatar_url") or None,
+ name=oauth_response.metadata.get("name") or None,
+ credentials=dict(oauth_response.credentials),
+ )
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
@@ -208,7 +224,8 @@ class DatasourceAuthListApi(Resource):
tenant_id=current_user.current_tenant_id
)
return {"result": jsonable_encoder(datasources)}, 200
-
+
+
class DatasourceHardCodeAuthListApi(Resource):
@setup_required
@login_required
diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py
index 25966ed41a..dac406360f 100644
--- a/api/services/datasource_provider_service.py
+++ b/api/services/datasource_provider_service.py
@@ -82,19 +82,16 @@ class DatasourceProviderService:
if key in credential_secret_variables:
copy_credentials[key] = encrypter.decrypt_token(tenant_id, value)
return copy_credentials
-
- def get_default_real_credential(
- self, tenant_id: str, provider: str, plugin_id: str
- ) -> dict[str, Any]:
+
+ def get_default_real_credential(self, tenant_id: str, provider: str, plugin_id: str) -> dict[str, Any]:
"""
get default credential
"""
with Session(db.engine) as session:
datasource_provider = (
- session.query(DatasourceProvider).filter_by(tenant_id=tenant_id,
- is_default=True,
- provider=provider,
- plugin_id=plugin_id).first()
+ session.query(DatasourceProvider)
+ .filter_by(tenant_id=tenant_id, is_default=True, provider=provider, plugin_id=plugin_id)
+ .first()
)
if not datasource_provider:
return {}
@@ -357,6 +354,35 @@ class DatasourceProviderService:
f"{credential_type.get_name()}",
)
+ def reauthorize_datasource_oauth_provider(
+ self,
+ name: str | None,
+ tenant_id: str,
+ provider_id: DatasourceProviderID,
+ avatar_url: str | None,
+ credentials: dict,
+ credential_id: str,
+ ) -> None:
+ """
+ update datasource oauth provider
+ """
+ with Session(db.engine) as session:
+ target_provider = session.query(DatasourceProvider).filter_by(id=credential_id, tenant_id=tenant_id).first()
+ if target_provider is None:
+ raise ValueError("provider not found")
+
+ provider_credential_secret_variables = self.extract_secret_variables(
+ tenant_id=tenant_id, provider_id=f"{provider_id}", credential_type=CredentialType.OAUTH2
+ )
+ for key, value in credentials.items():
+ if key in provider_credential_secret_variables:
+ credentials[key] = encrypter.encrypt_token(tenant_id, value)
+
+ target_provider.encrypted_credentials = credentials
+ target_provider.avatar_url = avatar_url or target_provider.avatar_url
+ target_provider.name = name or target_provider.name
+ session.commit()
+
def add_datasource_oauth_provider(
self,
name: str | None,
@@ -625,7 +651,7 @@ class DatasourceProviderService:
}
)
return datasource_credentials
-
+
def get_hard_code_datasource_credentials(self, tenant_id: str) -> list[dict]:
"""
get hard code datasource credentials.
@@ -637,14 +663,16 @@ class DatasourceProviderService:
datasources = manager.fetch_installed_datasource_providers(tenant_id)
datasource_credentials = []
for datasource in datasources:
- if datasource.plugin_id in ["langgenius/firecrawl_datasource", "langgenius/notion_datasource", "langgenius/jina_datasource"]:
+ if datasource.plugin_id in [
+ "langgenius/firecrawl_datasource",
+ "langgenius/notion_datasource",
+ "langgenius/jina_datasource",
+ ]:
datasource_provider_id = DatasourceProviderID(f"{datasource.plugin_id}/{datasource.provider}")
credentials = self.get_datasource_credentials(
tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id
)
- redirect_uri = (
- f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{datasource_provider_id}/datasource/callback"
- )
+ redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{datasource_provider_id}/datasource/callback"
datasource_credentials.append(
{
"provider": datasource.provider,
diff --git a/api/services/plugin/oauth_service.py b/api/services/plugin/oauth_service.py
index b84dd0afc5..4a09e71504 100644
--- a/api/services/plugin/oauth_service.py
+++ b/api/services/plugin/oauth_service.py
@@ -11,7 +11,13 @@ class OAuthProxyService(BasePluginClient):
__KEY_PREFIX__ = "oauth_proxy_context:"
@staticmethod
- def create_proxy_context(user_id: str, tenant_id: str, plugin_id: str, provider: str):
+ def create_proxy_context(
+ user_id: str,
+ tenant_id: str,
+ plugin_id: str,
+ provider: str,
+ credential_id: str | None = None,
+ ):
"""
Create a proxy context for an OAuth 2.0 authorization request.
@@ -31,6 +37,8 @@ class OAuthProxyService(BasePluginClient):
"tenant_id": tenant_id,
"provider": provider,
}
+ if credential_id:
+ data["credential_id"] = credential_id
redis_client.setex(
f"{OAuthProxyService.__KEY_PREFIX__}{context_id}",
OAuthProxyService.__MAX_AGE__,
diff --git a/web/app/components/workflow/nodes/_base/components/workflow-panel/index.tsx b/web/app/components/workflow/nodes/_base/components/workflow-panel/index.tsx
index a1566347b2..925a7df2d2 100644
--- a/web/app/components/workflow/nodes/_base/components/workflow-panel/index.tsx
+++ b/web/app/components/workflow/nodes/_base/components/workflow-panel/index.tsx
@@ -48,7 +48,6 @@ import {
isSupportCustomRunForm,
} from '@/app/components/workflow/utils'
import Tooltip from '@/app/components/base/tooltip'
-import type { CommonNodeType } from '@/app/components/workflow/types'
import { BlockEnum, type Node, NodeRunningStatus } from '@/app/components/workflow/types'
import { useStore as useAppStore } from '@/app/components/app/store'
import { useStore } from '@/app/components/workflow/store'
@@ -71,15 +70,16 @@ import {
} from '@/app/components/plugins/plugin-auth'
import { AuthCategory } from '@/app/components/plugins/plugin-auth'
import { canFindTool } from '@/utils'
-import type { DataSourceNodeType } from '@/app/components/workflow/nodes/data-source/types'
+import type { CustomRunFormProps } from '@/app/components/workflow/nodes/data-source/types'
import { DataSourceClassification } from '@/app/components/workflow/nodes/data-source/types'
import { useModalContext } from '@/context/modal-context'
import DataSourceBeforeRunForm from '@/app/components/workflow/nodes/data-source/before-run-form'
-const getCustomRunForm = (nodeType: BlockEnum, payload: CommonNodeType): React.JSX.Element => {
+const getCustomRunForm = (params: CustomRunFormProps): React.JSX.Element => {
+ const nodeType = params.payload.type
switch (nodeType) {
case BlockEnum.DataSource:
- return