mirror of https://github.com/langgenius/dify.git
Merge branch 'feat/r2' into deploy/rag-dev
# Conflicts: # api/services/rag_pipeline/rag_pipeline.py
This commit is contained in:
commit
3fce6f2581
|
|
@ -1,3 +1,4 @@
|
|||
from core.datasource.__base import datasource_provider
|
||||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
|
||||
|
|
@ -43,6 +44,7 @@ class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderCon
|
|||
if not datasource_entity:
|
||||
raise ValueError(f"Datasource with name {datasource_name} not found")
|
||||
|
||||
|
||||
return WebsiteCrawlDatasourcePlugin(
|
||||
entity=datasource_entity,
|
||||
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
|
||||
|
|
|
|||
|
|
@ -120,6 +120,47 @@ class DatasourceProviderService:
|
|||
|
||||
return copy_credentials_list
|
||||
|
||||
def get_real_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> list[dict]:
|
||||
"""
|
||||
get datasource credentials.
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param provider_id: provider id
|
||||
:return:
|
||||
"""
|
||||
# Get all provider configurations of the current workspace
|
||||
datasource_providers: list[DatasourceProvider] = (
|
||||
db.session.query(DatasourceProvider)
|
||||
.filter(
|
||||
DatasourceProvider.tenant_id == tenant_id,
|
||||
DatasourceProvider.provider == provider,
|
||||
DatasourceProvider.plugin_id == plugin_id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
if not datasource_providers:
|
||||
return []
|
||||
copy_credentials_list = []
|
||||
for datasource_provider in datasource_providers:
|
||||
encrypted_credentials = datasource_provider.encrypted_credentials
|
||||
# Get provider credential secret variables
|
||||
credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}")
|
||||
|
||||
# Obfuscate provider credentials
|
||||
copy_credentials = encrypted_credentials.copy()
|
||||
for key, value in copy_credentials.items():
|
||||
if key in credential_secret_variables:
|
||||
copy_credentials[key] = encrypter.decrypt_token(tenant_id, value)
|
||||
copy_credentials_list.append(
|
||||
{
|
||||
"credentials": copy_credentials,
|
||||
"type": datasource_provider.auth_type,
|
||||
}
|
||||
)
|
||||
|
||||
return copy_credentials_list
|
||||
|
||||
|
||||
def update_datasource_credentials(self, tenant_id: str, auth_id: str, provider: str, plugin_id: str, credentials: dict) -> None:
|
||||
"""
|
||||
update datasource credentials.
|
||||
|
|
|
|||
|
|
@ -43,6 +43,7 @@ from models.account import Account
|
|||
from models.dataset import Pipeline, PipelineCustomizedTemplate # type: ignore
|
||||
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
|
||||
from models.model import EndUser
|
||||
from models.oauth import DatasourceProvider
|
||||
from models.workflow import (
|
||||
Workflow,
|
||||
WorkflowNodeExecutionTriggeredFrom,
|
||||
|
|
@ -50,6 +51,7 @@ from models.workflow import (
|
|||
WorkflowType,
|
||||
)
|
||||
from services.dataset_service import DatasetService
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import (
|
||||
KnowledgeConfiguration,
|
||||
PipelineTemplateInfoEntity,
|
||||
|
|
@ -442,6 +444,7 @@ class RagPipelineService:
|
|||
for key, value in datasource_parameters.items():
|
||||
if not user_inputs.get(key):
|
||||
user_inputs[key] = value["value"]
|
||||
|
||||
from core.datasource.datasource_manager import DatasourceManager
|
||||
|
||||
datasource_runtime = DatasourceManager.get_datasource_runtime(
|
||||
|
|
@ -450,6 +453,14 @@ class RagPipelineService:
|
|||
tenant_id=pipeline.tenant_id,
|
||||
datasource_type=DatasourceProviderType(datasource_type),
|
||||
)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credentials = datasource_provider_service.get_real_datasource_credentials(
|
||||
tenant_id=pipeline.tenant_id,
|
||||
provider=datasource_node_data.get('provider_name'),
|
||||
plugin_id=datasource_node_data.get('plugin_id'),
|
||||
)
|
||||
if credentials:
|
||||
datasource_runtime.runtime.credentials = credentials[0].get("credentials")
|
||||
match datasource_type:
|
||||
case DatasourceProviderType.ONLINE_DOCUMENT:
|
||||
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
|
||||
|
|
|
|||
Loading…
Reference in New Issue