mirror of https://github.com/langgenius/dify.git
fix style check
parent 73d4bb596a
commit efce1b04e0

@@ -1440,12 +1440,12 @@ def transform_datasource_credentials():
     notion_credentials = db.session.query(DataSourceOauthBinding).filter_by(provider="notion").all()
     if notion_credentials:
         notion_credentials_tenant_mapping: dict[str, list[DataSourceOauthBinding]] = {}
-        for credential in notion_credentials:
-            tenant_id = credential.tenant_id
+        for notion_credential in notion_credentials:
+            tenant_id = notion_credential.tenant_id
             if tenant_id not in notion_credentials_tenant_mapping:
                 notion_credentials_tenant_mapping[tenant_id] = []
-            notion_credentials_tenant_mapping[tenant_id].append(credential)
-        for tenant_id, credentials in notion_credentials_tenant_mapping.items():
+            notion_credentials_tenant_mapping[tenant_id].append(notion_credential)
+        for tenant_id, notion_tenant_credentials in notion_credentials_tenant_mapping.items():
             # check notion plugin is installed
             installed_plugins = installer_manager.list_plugins(tenant_id)
             installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
@@ -1454,12 +1454,12 @@ def transform_datasource_credentials():
                 # install notion plugin
                 PluginService.install_from_marketplace_pkg(tenant_id, [notion_plugin_unique_identifier])
             auth_count = 0
-            for credential in credentials:
+            for notion_tenant_credential in notion_tenant_credentials:
                 auth_count += 1
                 # get credential oauth params
-                access_token = credential.access_token
+                access_token = notion_tenant_credential.access_token
                 # notion info
-                notion_info = credential.source_info
+                notion_info = notion_tenant_credential.source_info
                 workspace_id = notion_info.get("workspace_id")
                 workspace_name = notion_info.get("workspace_name")
                 workspace_icon = notion_info.get("workspace_icon")
@@ -1487,12 +1487,12 @@ def transform_datasource_credentials():
     firecrawl_credentials = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="firecrawl").all()
     if firecrawl_credentials:
         firecrawl_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {}
-        for credential in firecrawl_credentials:
-            tenant_id = credential.tenant_id
+        for firecrawl_credential in firecrawl_credentials:
+            tenant_id = firecrawl_credential.tenant_id
             if tenant_id not in firecrawl_credentials_tenant_mapping:
                 firecrawl_credentials_tenant_mapping[tenant_id] = []
-            firecrawl_credentials_tenant_mapping[tenant_id].append(credential)
-        for tenant_id, credentials in firecrawl_credentials_tenant_mapping.items():
+            firecrawl_credentials_tenant_mapping[tenant_id].append(firecrawl_credential)
+        for tenant_id, firecrawl_tenant_credentials in firecrawl_credentials_tenant_mapping.items():
             # check firecrawl plugin is installed
             installed_plugins = installer_manager.list_plugins(tenant_id)
             installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
@@ -1502,10 +1502,10 @@ def transform_datasource_credentials():
                 PluginService.install_from_marketplace_pkg(tenant_id, [firecrawl_plugin_unique_identifier])

             auth_count = 0
-            for credential in credentials:
+            for firecrawl_tenant_credential in firecrawl_tenant_credentials:
                 auth_count += 1
                 # get credential api key
-                credentials_json = json.loads(credential.credentials)
+                credentials_json = json.loads(firecrawl_tenant_credential.credentials)
                 api_key = credentials_json.get("config", {}).get("api_key")
                 base_url = credentials_json.get("config", {}).get("base_url")
                 new_credentials = {
@@ -1530,12 +1530,12 @@ def transform_datasource_credentials():
     jina_credentials = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="jinareader").all()
     if jina_credentials:
         jina_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {}
-        for credential in jina_credentials:
-            tenant_id = credential.tenant_id
+        for jina_credential in jina_credentials:
+            tenant_id = jina_credential.tenant_id
             if tenant_id not in jina_credentials_tenant_mapping:
                 jina_credentials_tenant_mapping[tenant_id] = []
-            jina_credentials_tenant_mapping[tenant_id].append(credential)
-        for tenant_id, credentials in jina_credentials_tenant_mapping.items():
+            jina_credentials_tenant_mapping[tenant_id].append(jina_credential)
+        for tenant_id, jina_tenant_credentials in jina_credentials_tenant_mapping.items():
             # check jina plugin is installed
             installed_plugins = installer_manager.list_plugins(tenant_id)
             installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
@@ -1546,10 +1546,10 @@ def transform_datasource_credentials():
                 PluginService.install_from_marketplace_pkg(tenant_id, [jina_plugin_unique_identifier])

             auth_count = 0
-            for credential in credentials:
+            for jina_tenant_credential in jina_tenant_credentials:
                 auth_count += 1
                 # get credential api key
-                credentials_json = json.loads(credential.credentials)
+                credentials_json = json.loads(jina_tenant_credential.credentials)
                 api_key = credentials_json.get("config", {}).get("api_key")
                 new_credentials = {
                     "integration_secret": api_key,

@@ -124,6 +124,9 @@ class DocumentAddByTextApi(DatasetApiResource):
                 args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
             )

+        if not current_user:
+            raise ValueError("current_user is required")
+
         upload_file = FileService(db.engine).upload_text(
             text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id
         )
@@ -204,6 +207,8 @@ class DocumentUpdateByTextApi(DatasetApiResource):
         name = args.get("name")
         if text is None or name is None:
             raise ValueError("Both text and name must be strings.")
+        if not current_user:
+            raise ValueError("current_user is required")
         upload_file = FileService(db.engine).upload_text(
             text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id
         )
@@ -308,6 +313,8 @@ class DocumentAddByFileApi(DatasetApiResource):

         if not isinstance(current_user, EndUser):
             raise ValueError("Invalid user account")
+        if not current_user:
+            raise ValueError("current_user is required")
         upload_file = FileService(db.engine).upload_file(
             filename=file.filename,
             content=file.read(),
@@ -396,8 +403,12 @@ class DocumentUpdateByFileApi(DatasetApiResource):
         if not file.filename:
             raise FilenameNotExistsError

+        if not current_user:
+            raise ValueError("current_user is required")
+
         if not isinstance(current_user, EndUser):
             raise ValueError("Invalid user account")
+
         try:
             upload_file = FileService(db.engine).upload_file(
                 filename=file.filename,
@@ -577,7 +588,7 @@ class DocumentApi(DatasetApiResource):
             response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details}
         elif metadata == "without":
             dataset_process_rules = DatasetService.get_process_rules(dataset_id)
-            document_process_rules = document.dataset_process_rule.to_dict()
+            document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
             data_source_info = document.data_source_detail_dict
             response = {
                 "id": document.id,
@@ -610,7 +621,7 @@ class DocumentApi(DatasetApiResource):
             }
         else:
             dataset_process_rules = DatasetService.get_process_rules(dataset_id)
-            document_process_rules = document.dataset_process_rule.to_dict()
+            document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
             data_source_info = document.data_source_detail_dict
             response = {
                 "id": document.id,

@@ -214,7 +214,10 @@ class KnowledgebasePipelineFileUploadApi(DatasetApiResource):
             raise UnsupportedFileTypeError()

         if not file.filename:
-            raise FilenameNotExistsError
+            raise FilenameNotExistsError
+
+        if not current_user:
+            raise ValueError("Invalid user account")

         try:
             upload_file = FileService(db.engine).upload_file(

@@ -23,7 +23,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
         :param blocking_response: blocking response
         :return:
         """
-        return dict(blocking_response.to_dict())
+        return dict(blocking_response.model_dump())

     @classmethod
     def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict:  # type: ignore[override]

@@ -46,7 +46,7 @@ class DatasourceManager:
         provider_entity = manager.fetch_datasource_provider(tenant_id, provider_id)
         if not provider_entity:
             raise DatasourceProviderNotFoundError(f"plugin provider {provider_id} not found")
-
+        controller = None
         match datasource_type:
             case DatasourceProviderType.ONLINE_DOCUMENT:
                 controller = OnlineDocumentDatasourcePluginProviderController(

@@ -62,7 +62,7 @@ class DatasourceProviderApiEntity(BaseModel):
             "description": self.description.to_dict(),
             "icon": self.icon,
             "label": self.label.to_dict(),
-            "type": self.type.value,
+            "type": self.type,
             "team_credentials": self.masked_credentials,
             "is_team_authorization": self.is_team_authorization,
             "allow_delete": self.allow_delete,

@@ -6,6 +6,7 @@ from typing import Optional
 from core.datasource.entities.datasource_entities import DatasourceMessage
 from core.file import File, FileTransferMethod, FileType
 from core.tools.tool_file_manager import ToolFileManager
+from models.tools import ToolFile

 logger = logging.getLogger(__name__)

@@ -32,20 +33,20 @@ class DatasourceFileMessageTransformer:
                 try:
                     assert isinstance(message.message, DatasourceMessage.TextMessage)
                     tool_file_manager = ToolFileManager()
-                    file = tool_file_manager.create_file_by_url(
+                    tool_file: ToolFile | None = tool_file_manager.create_file_by_url(
                         user_id=user_id,
                         tenant_id=tenant_id,
                         file_url=message.message.text,
                         conversation_id=conversation_id,
                     )
+                    if tool_file:
+                        url = f"/files/datasources/{tool_file.id}{guess_extension(tool_file.mimetype) or '.png'}"

-                    url = f"/files/datasources/{file.id}{guess_extension(file.mimetype) or '.png'}"
-
-                    yield DatasourceMessage(
-                        type=DatasourceMessage.MessageType.IMAGE_LINK,
-                        message=DatasourceMessage.TextMessage(text=url),
-                        meta=message.meta.copy() if message.meta is not None else {},
-                    )
+                        yield DatasourceMessage(
+                            type=DatasourceMessage.MessageType.IMAGE_LINK,
+                            message=DatasourceMessage.TextMessage(text=url),
+                            meta=message.meta.copy() if message.meta is not None else {},
+                        )
                 except Exception as e:
                     yield DatasourceMessage(
                         type=DatasourceMessage.MessageType.TEXT,
@@ -72,7 +73,7 @@ class DatasourceFileMessageTransformer:
                 # FIXME: should do a type check here.
                 assert isinstance(message.message.blob, bytes)
                 tool_file_manager = ToolFileManager()
-                file = tool_file_manager.create_file_by_raw(
+                blob_tool_file: ToolFile | None = tool_file_manager.create_file_by_raw(
                     user_id=user_id,
                     tenant_id=tenant_id,
                     conversation_id=conversation_id,
@@ -80,25 +81,27 @@ class DatasourceFileMessageTransformer:
                     mimetype=mimetype,
                     filename=filename,
                 )
-
-                url = cls.get_datasource_file_url(datasource_file_id=file.id, extension=guess_extension(file.mimetype))
-
-                # check if file is image
-                if "image" in mimetype:
-                    yield DatasourceMessage(
-                        type=DatasourceMessage.MessageType.IMAGE_LINK,
-                        message=DatasourceMessage.TextMessage(text=url),
-                        meta=meta.copy() if meta is not None else {},
-                    )
-                else:
-                    yield DatasourceMessage(
-                        type=DatasourceMessage.MessageType.BINARY_LINK,
-                        message=DatasourceMessage.TextMessage(text=url),
-                        meta=meta.copy() if meta is not None else {},
+                if blob_tool_file:
+                    url = cls.get_datasource_file_url(
+                        datasource_file_id=blob_tool_file.id, extension=guess_extension(blob_tool_file.mimetype)
+                    )
+
+                    # check if file is image
+                    if "image" in mimetype:
+                        yield DatasourceMessage(
+                            type=DatasourceMessage.MessageType.IMAGE_LINK,
+                            message=DatasourceMessage.TextMessage(text=url),
+                            meta=meta.copy() if meta is not None else {},
+                        )
+                    else:
+                        yield DatasourceMessage(
+                            type=DatasourceMessage.MessageType.BINARY_LINK,
+                            message=DatasourceMessage.TextMessage(text=url),
+                            meta=meta.copy() if meta is not None else {},
                         )
             elif message.type == DatasourceMessage.MessageType.FILE:
                 meta = message.meta or {}
-                file = meta.get("file", None)
+                file: Optional[File] = meta.get("file")
                 if isinstance(file, File):
                     if file.transfer_method == FileTransferMethod.TOOL_FILE:
                         assert file.related_id is not None

@@ -67,7 +67,7 @@ class VariablePool(BaseModel):
             self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var)
         # Add rag pipeline variables to the variable pool
         if self.rag_pipeline_variables:
-            rag_pipeline_variables_map = defaultdict(dict)
+            rag_pipeline_variables_map: defaultdict[Any, dict[Any, Any]] = defaultdict(dict)
             for rag_var in self.rag_pipeline_variables:
                 node_id = rag_var.variable.belong_to_node_id
                 key = rag_var.variable.variable

@@ -19,7 +19,7 @@ from core.file.enums import FileTransferMethod, FileType
 from core.plugin.impl.exc import PluginDaemonClientSideError
 from core.variables.segments import ArrayAnySegment
 from core.variables.variables import ArrayAnyVariable
-from core.workflow.entities.variable_pool import VariablePool, VariableValue
+from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
 from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, SystemVariableKey
 from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
@@ -87,29 +87,18 @@ class DatasourceNode(Node):
             raise DatasourceNodeError("Invalid datasource info format")
         datasource_info: dict[str, Any] = datasource_info_value
         # get datasource runtime
-        try:
-            from core.datasource.datasource_manager import DatasourceManager
+        from core.datasource.datasource_manager import DatasourceManager

-            if datasource_type is None:
-                raise DatasourceNodeError("Datasource type is not set")
+        if datasource_type is None:
+            raise DatasourceNodeError("Datasource type is not set")

-            datasource_runtime = DatasourceManager.get_datasource_runtime(
-                provider_id=f"{node_data.plugin_id}/{node_data.provider_name}",
-                datasource_name=node_data.datasource_name or "",
-                tenant_id=self.tenant_id,
-                datasource_type=DatasourceProviderType.value_of(datasource_type),
-            )
-            datasource_info["icon"] = datasource_runtime.get_icon_url(self.tenant_id)
-        except DatasourceNodeError as e:
-            yield StreamCompletedEvent(
-                node_run_result=NodeRunResult(
-                    status=WorkflowNodeExecutionStatus.FAILED,
-                    inputs={},
-                    metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
-                    error=f"Failed to get datasource runtime: {str(e)}",
-                    error_type=type(e).__name__,
-                )
-            )
+        datasource_runtime = DatasourceManager.get_datasource_runtime(
+            provider_id=f"{node_data.plugin_id}/{node_data.provider_name}",
+            datasource_name=node_data.datasource_name or "",
+            tenant_id=self.tenant_id,
+            datasource_type=DatasourceProviderType.value_of(datasource_type),
+        )
+        datasource_info["icon"] = datasource_runtime.get_icon_url(self.tenant_id)

         parameters_for_log = datasource_info

@@ -282,27 +271,6 @@ class DatasourceNode(Node):
         assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
         return list(variable.value) if variable else []

-    def _append_variables_recursively(
-        self, variable_pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue
-    ):
-        """
-        Append variables recursively
-        :param node_id: node id
-        :param variable_key_list: variable key list
-        :param variable_value: variable value
-        :return:
-        """
-        variable_pool.add([node_id] + [".".join(variable_key_list)], variable_value)
-
-        # if variable_value is a dict, then recursively append variables
-        if isinstance(variable_value, dict):
-            for key, value in variable_value.items():
-                # construct new key list
-                new_key_list = variable_key_list + [key]
-                self._append_variables_recursively(
-                    variable_pool=variable_pool, node_id=node_id, variable_key_list=new_key_list, variable_value=value
-                )
-
     @classmethod
     def _extract_variable_selector_to_variable_mapping(
         cls,
@@ -423,13 +391,6 @@ class DatasourceNode(Node):
                     )
                 elif message.type == DatasourceMessage.MessageType.JSON:
                     assert isinstance(message.message, DatasourceMessage.JsonMessage)
-                    if self.node_type == NodeType.AGENT:
-                        msg_metadata = message.message.json_object.pop("execution_metadata", {})
-                        agent_execution_metadata = {
-                            key: value
-                            for key, value in msg_metadata.items()
-                            if key in WorkflowNodeExecutionMetadataKey.__members__.values()
-                        }
                     json.append(message.message.json_object)
                 elif message.type == DatasourceMessage.MessageType.LINK:
                     assert isinstance(message.message, DatasourceMessage.TextMessage)

@@ -7,7 +7,7 @@ from pydantic import BaseModel
 class DatasourceNodeRunApiEntity(BaseModel):
     pipeline_id: str
     node_id: str
-    inputs: Mapping[str, Any]
+    inputs: dict[str, Any]
     datasource_type: str
     credential_id: Optional[str] = None
     is_published: bool

@@ -580,10 +580,10 @@ class RagPipelineService:
            )
            yield start_event.model_dump()
            try:
-               for message in online_document_result:
+               for online_document_message in online_document_result:
                    end_time = time.time()
                    online_document_event = DatasourceCompletedEvent(
-                       data=message.result, time_consuming=round(end_time - start_time, 2)
+                       data=online_document_message.result, time_consuming=round(end_time - start_time, 2)
                    )
                    yield online_document_event.model_dump()
            except Exception as e:
@@ -609,10 +609,10 @@ class RagPipelineService:
                completed=0,
            )
            yield start_event.model_dump()
-           for message in online_drive_result:
+           for online_drive_message in online_drive_result:
                end_time = time.time()
                online_drive_event = DatasourceCompletedEvent(
-                   data=message.result,
+                   data=online_drive_message.result,
                    time_consuming=round(end_time - start_time, 2),
                    total=None,
                    completed=None,
@@ -629,19 +629,19 @@ class RagPipelineService:
            )
            start_time = time.time()
            try:
-               for message in website_crawl_result:
+               for website_crawl_message in website_crawl_result:
                    end_time = time.time()
-                   if message.result.status == "completed":
+                   if website_crawl_message.result.status == "completed":
                        crawl_event = DatasourceCompletedEvent(
-                           data=message.result.web_info_list or [],
-                           total=message.result.total,
-                           completed=message.result.completed,
+                           data=website_crawl_message.result.web_info_list or [],
+                           total=website_crawl_message.result.total,
+                           completed=website_crawl_message.result.completed,
                            time_consuming=round(end_time - start_time, 2),
                        )
                    else:
                        crawl_event = DatasourceProcessingEvent(
-                           total=message.result.total,
-                           completed=message.result.completed,
+                           total=website_crawl_message.result.total,
+                           completed=website_crawl_message.result.completed,
                        )
                    yield crawl_event.model_dump()
            except Exception as e:
@@ -723,12 +723,12 @@ class RagPipelineService:
            )
            try:
                variables: dict[str, Any] = {}
-               for message in online_document_result:
-                   if message.type == DatasourceMessage.MessageType.VARIABLE:
-                       assert isinstance(message.message, DatasourceMessage.VariableMessage)
-                       variable_name = message.message.variable_name
-                       variable_value = message.message.variable_value
-                       if message.message.stream:
+               for online_document_message in online_document_result:
+                   if online_document_message.type == DatasourceMessage.MessageType.VARIABLE:
+                       assert isinstance(online_document_message.message, DatasourceMessage.VariableMessage)
+                       variable_name = online_document_message.message.variable_name
+                       variable_value = online_document_message.message.variable_value
+                       if online_document_message.message.stream:
                            if not isinstance(variable_value, str):
                                raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
                            if variable_name not in variables:
@@ -793,8 +793,9 @@ class RagPipelineService:
            for event in generator:
                if isinstance(event, (NodeRunSucceededEvent, NodeRunFailedEvent)):
                    node_run_result = event.node_run_result
-                   # sign output files
-                   node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) or {}
+                   if node_run_result:
+                       # sign output files
+                       node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) or {}
                    break

            if not node_run_result:
@@ -1358,3 +1359,99 @@ class RagPipelineService:
            workflow_thread_pool_id=None,
            is_retry=True,
        )
+
+    def get_datasource_plugins(self, tenant_id: str, dataset_id: str, is_published: bool) -> list[dict]:
+        """
+        Get datasource plugins
+        """
+        dataset: Dataset | None = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+        if not dataset:
+            raise ValueError("Dataset not found")
+        pipeline: Pipeline | None = db.session.query(Pipeline).filter(Pipeline.id == dataset.pipeline_id).first()
+        if not pipeline:
+            raise ValueError("Pipeline not found")
+
+        workflow: Workflow | None = None
+        if is_published:
+            workflow = self.get_published_workflow(pipeline=pipeline)
+        else:
+            workflow = self.get_draft_workflow(pipeline=pipeline)
+        if not pipeline or not workflow:
+            raise ValueError("Pipeline or workflow not found")
+
+        datasource_nodes = workflow.graph_dict.get("nodes", [])
+        datasource_plugins = []
+        for datasource_node in datasource_nodes:
+            if datasource_node.get("type") == "datasource":
+                datasource_node_data = datasource_node.get("data", {})
+                if not datasource_node_data:
+                    continue
+
+                variables = workflow.rag_pipeline_variables
+                if variables:
+                    variables_map = {item["variable"]: item for item in variables}
+                else:
+                    variables_map = {}
+
+                datasource_parameters = datasource_node_data.get("datasource_parameters", {})
+                user_input_variables_keys = []
+                user_input_variables = []
+
+                for _, value in datasource_parameters.items():
+                    if value.get("value") and isinstance(value.get("value"), str):
+                        pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"
+                        match = re.match(pattern, value["value"])
+                        if match:
+                            full_path = match.group(1)
+                            last_part = full_path.split(".")[-1]
+                            user_input_variables_keys.append(last_part)
+                    elif value.get("value") and isinstance(value.get("value"), list):
+                        last_part = value.get("value")[-1]
+                        user_input_variables_keys.append(last_part)
+                for key, value in variables_map.items():
+                    if key in user_input_variables_keys:
+                        user_input_variables.append(value)
+
+                # get credentials
+                datasource_provider_service: DatasourceProviderService = DatasourceProviderService()
+                credentials: list[dict[Any, Any]] = datasource_provider_service.list_datasource_credentials(
+                    tenant_id=tenant_id,
+                    provider=datasource_node_data.get("provider_name"),
+                    plugin_id=datasource_node_data.get("plugin_id"),
+                )
+                credential_info_list: list[Any] = []
+                for credential in credentials:
+                    credential_info_list.append(
+                        {
+                            "id": credential.get("id"),
+                            "name": credential.get("name"),
+                            "type": credential.get("type"),
+                            "is_default": credential.get("is_default"),
+                        }
+                    )
+
+                datasource_plugins.append(
+                    {
+                        "node_id": datasource_node.get("id"),
+                        "plugin_id": datasource_node_data.get("plugin_id"),
+                        "provider_name": datasource_node_data.get("provider_name"),
+                        "datasource_type": datasource_node_data.get("provider_type"),
+                        "title": datasource_node_data.get("title"),
+                        "user_input_variables": user_input_variables,
+                        "credentials": credential_info_list,
+                    }
+                )
+
+        return datasource_plugins
+
+    def get_pipeline(self, tenant_id: str, dataset_id: str) -> Pipeline:
+        """
+        Get pipeline
+        """
+        dataset: Dataset | None = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+        if not dataset:
+            raise ValueError("Dataset not found")
+        pipeline: Pipeline | None = db.session.query(Pipeline).filter(Pipeline.id == dataset.pipeline_id).first()
+        if not pipeline:
+            raise ValueError("Pipeline not found")
+        return pipeline

@@ -99,12 +99,6 @@ class ToolTransformService:
            provider.declaration.identity.icon = ToolTransformService.get_plugin_icon_url(
                tenant_id=tenant_id, filename=provider.declaration.identity.icon
            )
-       else:
-           provider.declaration.identity.icon = ToolTransformService.get_tool_provider_icon_url(
-               provider_type=provider.type.value,
-               provider_name=provider.name,
-               icon=provider.declaration.identity.icon,
-           )

    @classmethod
    def builtin_provider_to_user_provider(