fix style check (#25854)

commit 1631f9438d
QuantumGhost authored on 2025-09-17 22:37:17 +08:00 (committed by GitHub)
12 changed files with 195 additions and 126 deletions

View File

@@ -1440,12 +1440,12 @@ def transform_datasource_credentials():
     notion_credentials = db.session.query(DataSourceOauthBinding).filter_by(provider="notion").all()
     if notion_credentials:
         notion_credentials_tenant_mapping: dict[str, list[DataSourceOauthBinding]] = {}
-        for credential in notion_credentials:
-            tenant_id = credential.tenant_id
+        for notion_credential in notion_credentials:
+            tenant_id = notion_credential.tenant_id
             if tenant_id not in notion_credentials_tenant_mapping:
                 notion_credentials_tenant_mapping[tenant_id] = []
-            notion_credentials_tenant_mapping[tenant_id].append(credential)
-        for tenant_id, credentials in notion_credentials_tenant_mapping.items():
+            notion_credentials_tenant_mapping[tenant_id].append(notion_credential)
+        for tenant_id, notion_tenant_credentials in notion_credentials_tenant_mapping.items():
             # check notion plugin is installed
             installed_plugins = installer_manager.list_plugins(tenant_id)
             installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
@@ -1454,12 +1454,12 @@ def transform_datasource_credentials():
                 # install notion plugin
                 PluginService.install_from_marketplace_pkg(tenant_id, [notion_plugin_unique_identifier])
             auth_count = 0
-            for credential in credentials:
+            for notion_tenant_credential in notion_tenant_credentials:
                 auth_count += 1
                 # get credential oauth params
-                access_token = credential.access_token
+                access_token = notion_tenant_credential.access_token
                 # notion info
-                notion_info = credential.source_info
+                notion_info = notion_tenant_credential.source_info
                 workspace_id = notion_info.get("workspace_id")
                 workspace_name = notion_info.get("workspace_name")
                 workspace_icon = notion_info.get("workspace_icon")
@@ -1487,12 +1487,12 @@ def transform_datasource_credentials():
     firecrawl_credentials = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="firecrawl").all()
     if firecrawl_credentials:
         firecrawl_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {}
-        for credential in firecrawl_credentials:
-            tenant_id = credential.tenant_id
+        for firecrawl_credential in firecrawl_credentials:
+            tenant_id = firecrawl_credential.tenant_id
             if tenant_id not in firecrawl_credentials_tenant_mapping:
                 firecrawl_credentials_tenant_mapping[tenant_id] = []
-            firecrawl_credentials_tenant_mapping[tenant_id].append(credential)
-        for tenant_id, credentials in firecrawl_credentials_tenant_mapping.items():
+            firecrawl_credentials_tenant_mapping[tenant_id].append(firecrawl_credential)
+        for tenant_id, firecrawl_tenant_credentials in firecrawl_credentials_tenant_mapping.items():
             # check firecrawl plugin is installed
             installed_plugins = installer_manager.list_plugins(tenant_id)
             installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
@@ -1502,10 +1502,10 @@ def transform_datasource_credentials():
                 PluginService.install_from_marketplace_pkg(tenant_id, [firecrawl_plugin_unique_identifier])
             auth_count = 0
-            for credential in credentials:
+            for firecrawl_tenant_credential in firecrawl_tenant_credentials:
                 auth_count += 1
                 # get credential api key
-                credentials_json = json.loads(credential.credentials)
+                credentials_json = json.loads(firecrawl_tenant_credential.credentials)
                 api_key = credentials_json.get("config", {}).get("api_key")
                 base_url = credentials_json.get("config", {}).get("base_url")
                 new_credentials = {
@@ -1530,12 +1530,12 @@ def transform_datasource_credentials():
     jina_credentials = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="jinareader").all()
     if jina_credentials:
         jina_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {}
-        for credential in jina_credentials:
-            tenant_id = credential.tenant_id
+        for jina_credential in jina_credentials:
+            tenant_id = jina_credential.tenant_id
             if tenant_id not in jina_credentials_tenant_mapping:
                 jina_credentials_tenant_mapping[tenant_id] = []
-            jina_credentials_tenant_mapping[tenant_id].append(credential)
-        for tenant_id, credentials in jina_credentials_tenant_mapping.items():
+            jina_credentials_tenant_mapping[tenant_id].append(jina_credential)
+        for tenant_id, jina_tenant_credentials in jina_credentials_tenant_mapping.items():
             # check jina plugin is installed
             installed_plugins = installer_manager.list_plugins(tenant_id)
             installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
@@ -1546,10 +1546,10 @@ def transform_datasource_credentials():
                 PluginService.install_from_marketplace_pkg(tenant_id, [jina_plugin_unique_identifier])
             auth_count = 0
-            for credential in credentials:
+            for jina_tenant_credential in jina_tenant_credentials:
                 auth_count += 1
                 # get credential api key
-                credentials_json = json.loads(credential.credentials)
+                credentials_json = json.loads(jina_tenant_credential.credentials)
                 api_key = credentials_json.get("config", {}).get("api_key")
                 new_credentials = {
                     "integration_secret": api_key,

View File

@@ -124,6 +124,9 @@ class DocumentAddByTextApi(DatasetApiResource):
                 args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
             )
+        if not current_user:
+            raise ValueError("current_user is required")
         upload_file = FileService(db.engine).upload_text(
             text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id
         )
@@ -204,6 +207,8 @@ class DocumentUpdateByTextApi(DatasetApiResource):
         name = args.get("name")
         if text is None or name is None:
             raise ValueError("Both text and name must be strings.")
+        if not current_user:
+            raise ValueError("current_user is required")
         upload_file = FileService(db.engine).upload_text(
             text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id
         )
@@ -308,6 +313,8 @@ class DocumentAddByFileApi(DatasetApiResource):
         if not isinstance(current_user, EndUser):
             raise ValueError("Invalid user account")
+        if not current_user:
+            raise ValueError("current_user is required")
         upload_file = FileService(db.engine).upload_file(
             filename=file.filename,
             content=file.read(),
@@ -396,8 +403,12 @@ class DocumentUpdateByFileApi(DatasetApiResource):
         if not file.filename:
             raise FilenameNotExistsError
+        if not current_user:
+            raise ValueError("current_user is required")
+        if not isinstance(current_user, EndUser):
+            raise ValueError("Invalid user account")
         try:
             upload_file = FileService(db.engine).upload_file(
                 filename=file.filename,
@@ -577,7 +588,7 @@ class DocumentApi(DatasetApiResource):
             response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details}
         elif metadata == "without":
             dataset_process_rules = DatasetService.get_process_rules(dataset_id)
-            document_process_rules = document.dataset_process_rule.to_dict()
+            document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
             data_source_info = document.data_source_detail_dict
             response = {
                 "id": document.id,
@@ -610,7 +621,7 @@
             }
         else:
             dataset_process_rules = DatasetService.get_process_rules(dataset_id)
-            document_process_rules = document.dataset_process_rule.to_dict()
+            document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
             data_source_info = document.data_source_detail_dict
             response = {
                 "id": document.id,

View File

@@ -215,6 +215,9 @@ class KnowledgebasePipelineFileUploadApi(DatasetApiResource):
         if not file.filename:
             raise FilenameNotExistsError
+        if not current_user:
+            raise ValueError("Invalid user account")
         try:
             upload_file = FileService(db.engine).upload_file(

View File

@@ -23,7 +23,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
         :param blocking_response: blocking response
         :return:
         """
-        return dict(blocking_response.to_dict())
+        return dict(blocking_response.model_dump())
 
     @classmethod
     def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict:  # type: ignore[override]
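
Note: assuming WorkflowAppBlockingResponse is a Pydantic v2 BaseModel, model_dump() is the built-in serializer, while to_dict() only exists if it was hand-written, which is presumably why the checker rejected it. Hedged sketch with a stand-in model:

    from pydantic import BaseModel

    class BlockingResponse(BaseModel):  # stand-in for WorkflowAppBlockingResponse
        task_id: str
        data: dict

    resp = BlockingResponse(task_id="t-1", data={"answer": "ok"})
    print(dict(resp.model_dump()))  # {'task_id': 't-1', 'data': {'answer': 'ok'}}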

View File

@@ -46,7 +46,7 @@ class DatasourceManager:
         provider_entity = manager.fetch_datasource_provider(tenant_id, provider_id)
         if not provider_entity:
             raise DatasourceProviderNotFoundError(f"plugin provider {provider_id} not found")
-        controller = None
         match datasource_type:
             case DatasourceProviderType.ONLINE_DOCUMENT:
                 controller = OnlineDocumentDatasourcePluginProviderController(

View File

@@ -62,7 +62,7 @@ class DatasourceProviderApiEntity(BaseModel):
             "description": self.description.to_dict(),
             "icon": self.icon,
             "label": self.label.to_dict(),
-            "type": self.type.value,
+            "type": self.type,
             "team_credentials": self.masked_credentials,
             "is_team_authorization": self.is_team_authorization,
             "allow_delete": self.allow_delete,

View File

@@ -5,6 +5,7 @@ from mimetypes import guess_extension, guess_type
 from core.datasource.entities.datasource_entities import DatasourceMessage
 from core.file import File, FileTransferMethod, FileType
 from core.tools.tool_file_manager import ToolFileManager
+from models.tools import ToolFile
 
 logger = logging.getLogger(__name__)
 
@@ -31,20 +32,20 @@ class DatasourceFileMessageTransformer:
                 try:
                     assert isinstance(message.message, DatasourceMessage.TextMessage)
                     tool_file_manager = ToolFileManager()
-                    file = tool_file_manager.create_file_by_url(
+                    tool_file: ToolFile | None = tool_file_manager.create_file_by_url(
                         user_id=user_id,
                         tenant_id=tenant_id,
                         file_url=message.message.text,
                         conversation_id=conversation_id,
                     )
-                    url = f"/files/datasources/{file.id}{guess_extension(file.mimetype) or '.png'}"
-                    yield DatasourceMessage(
-                        type=DatasourceMessage.MessageType.IMAGE_LINK,
-                        message=DatasourceMessage.TextMessage(text=url),
-                        meta=message.meta.copy() if message.meta is not None else {},
-                    )
+                    if tool_file:
+                        url = f"/files/datasources/{tool_file.id}{guess_extension(tool_file.mimetype) or '.png'}"
+                        yield DatasourceMessage(
+                            type=DatasourceMessage.MessageType.IMAGE_LINK,
+                            message=DatasourceMessage.TextMessage(text=url),
+                            meta=message.meta.copy() if message.meta is not None else {},
+                        )
                 except Exception as e:
                     yield DatasourceMessage(
                         type=DatasourceMessage.MessageType.TEXT,
@@ -71,7 +72,7 @@ class DatasourceFileMessageTransformer:
                 # FIXME: should do a type check here.
                 assert isinstance(message.message.blob, bytes)
                 tool_file_manager = ToolFileManager()
-                file = tool_file_manager.create_file_by_raw(
+                blob_tool_file: ToolFile | None = tool_file_manager.create_file_by_raw(
                     user_id=user_id,
                     tenant_id=tenant_id,
                     conversation_id=conversation_id,
@@ -79,25 +80,27 @@
                     mimetype=mimetype,
                     filename=filename,
                 )
-                url = cls.get_datasource_file_url(datasource_file_id=file.id, extension=guess_extension(file.mimetype))
-                # check if file is image
-                if "image" in mimetype:
-                    yield DatasourceMessage(
-                        type=DatasourceMessage.MessageType.IMAGE_LINK,
-                        message=DatasourceMessage.TextMessage(text=url),
-                        meta=meta.copy() if meta is not None else {},
-                    )
-                else:
-                    yield DatasourceMessage(
-                        type=DatasourceMessage.MessageType.BINARY_LINK,
-                        message=DatasourceMessage.TextMessage(text=url),
-                        meta=meta.copy() if meta is not None else {},
-                    )
+                if blob_tool_file:
+                    url = cls.get_datasource_file_url(
+                        datasource_file_id=blob_tool_file.id, extension=guess_extension(blob_tool_file.mimetype)
+                    )
+                    # check if file is image
+                    if "image" in mimetype:
+                        yield DatasourceMessage(
+                            type=DatasourceMessage.MessageType.IMAGE_LINK,
+                            message=DatasourceMessage.TextMessage(text=url),
+                            meta=meta.copy() if meta is not None else {},
+                        )
+                    else:
+                        yield DatasourceMessage(
+                            type=DatasourceMessage.MessageType.BINARY_LINK,
+                            message=DatasourceMessage.TextMessage(text=url),
+                            meta=meta.copy() if meta is not None else {},
+                        )
             elif message.type == DatasourceMessage.MessageType.FILE:
                 meta = message.meta or {}
-                file = meta.get("file", None)
+                file: Optional[File] = meta.get("file")
                 if isinstance(file, File):
                     if file.transfer_method == FileTransferMethod.TOOL_FILE:
                         assert file.related_id is not None
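
Note: the pattern above is plain Optional narrowing: annotate the ToolFileManager result as ToolFile | None, then only touch .id / .mimetype inside an "if tool_file:" block. A minimal sketch with illustrative stub types:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class ToolFileStub:          # stand-in for models.tools.ToolFile
        id: str
        mimetype: str

    def create_file(url: str) -> Optional[ToolFileStub]:
        # may return None on failure, mirroring the ToolFile | None annotation above
        return ToolFileStub(id="f-1", mimetype="image/png") if url else None

    tool_file = create_file("https://example.com/a.png")
    if tool_file:  # narrows Optional[ToolFileStub] to ToolFileStub
        print(f"/files/datasources/{tool_file.id}")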

View File

@@ -67,7 +67,7 @@ class VariablePool(BaseModel):
             self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var)
         # Add rag pipeline variables to the variable pool
         if self.rag_pipeline_variables:
-            rag_pipeline_variables_map = defaultdict(dict)
+            rag_pipeline_variables_map: defaultdict[Any, dict[Any, Any]] = defaultdict(dict)
             for rag_var in self.rag_pipeline_variables:
                 node_id = rag_var.variable.belong_to_node_id
                 key = rag_var.variable.variable
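
Note: the added annotation is the usual fix for "need type annotation" on a bare defaultdict(dict): the value type can be inferred from the factory, but the key type cannot. Self-contained example:

    from collections import defaultdict
    from typing import Any

    # rag_pipeline_variables_map = defaultdict(dict)  # strict checkers: need type annotation
    rag_pipeline_variables_map: defaultdict[Any, dict[Any, Any]] = defaultdict(dict)

    rag_pipeline_variables_map["node_1"]["query"] = "hello"  # nested dict created on first access
    print(rag_pipeline_variables_map["node_1"])              # {'query': 'hello'}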

View File

@@ -19,7 +19,7 @@ from core.file.enums import FileTransferMethod, FileType
 from core.plugin.impl.exc import PluginDaemonClientSideError
 from core.variables.segments import ArrayAnySegment
 from core.variables.variables import ArrayAnyVariable
-from core.workflow.entities.variable_pool import VariablePool, VariableValue
+from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
 from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, SystemVariableKey
 from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
@@ -87,29 +87,18 @@ class DatasourceNode(Node):
             raise DatasourceNodeError("Invalid datasource info format")
         datasource_info: dict[str, Any] = datasource_info_value
         # get datasource runtime
-        try:
-            from core.datasource.datasource_manager import DatasourceManager
-
-            if datasource_type is None:
-                raise DatasourceNodeError("Datasource type is not set")
-
-            datasource_runtime = DatasourceManager.get_datasource_runtime(
-                provider_id=f"{node_data.plugin_id}/{node_data.provider_name}",
-                datasource_name=node_data.datasource_name or "",
-                tenant_id=self.tenant_id,
-                datasource_type=DatasourceProviderType.value_of(datasource_type),
-            )
-            datasource_info["icon"] = datasource_runtime.get_icon_url(self.tenant_id)
-        except DatasourceNodeError as e:
-            yield StreamCompletedEvent(
-                node_run_result=NodeRunResult(
-                    status=WorkflowNodeExecutionStatus.FAILED,
-                    inputs={},
-                    metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
-                    error=f"Failed to get datasource runtime: {str(e)}",
-                    error_type=type(e).__name__,
-                )
-            )
+        from core.datasource.datasource_manager import DatasourceManager
+
+        if datasource_type is None:
+            raise DatasourceNodeError("Datasource type is not set")
+
+        datasource_runtime = DatasourceManager.get_datasource_runtime(
+            provider_id=f"{node_data.plugin_id}/{node_data.provider_name}",
+            datasource_name=node_data.datasource_name or "",
+            tenant_id=self.tenant_id,
+            datasource_type=DatasourceProviderType.value_of(datasource_type),
+        )
+        datasource_info["icon"] = datasource_runtime.get_icon_url(self.tenant_id)
 
         parameters_for_log = datasource_info
@@ -282,27 +271,6 @@
         assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
         return list(variable.value) if variable else []
 
-    def _append_variables_recursively(
-        self, variable_pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue
-    ):
-        """
-        Append variables recursively
-        :param node_id: node id
-        :param variable_key_list: variable key list
-        :param variable_value: variable value
-        :return:
-        """
-        variable_pool.add([node_id] + [".".join(variable_key_list)], variable_value)
-        # if variable_value is a dict, then recursively append variables
-        if isinstance(variable_value, dict):
-            for key, value in variable_value.items():
-                # construct new key list
-                new_key_list = variable_key_list + [key]
-                self._append_variables_recursively(
-                    variable_pool=variable_pool, node_id=node_id, variable_key_list=new_key_list, variable_value=value
-                )
-
     @classmethod
     def _extract_variable_selector_to_variable_mapping(
         cls,
@@ -423,13 +391,6 @@
                     )
                 elif message.type == DatasourceMessage.MessageType.JSON:
                     assert isinstance(message.message, DatasourceMessage.JsonMessage)
-                    if self.node_type == NodeType.AGENT:
-                        msg_metadata = message.message.json_object.pop("execution_metadata", {})
-                        agent_execution_metadata = {
-                            key: value
-                            for key, value in msg_metadata.items()
-                            if key in WorkflowNodeExecutionMetadataKey.__members__.values()
-                        }
                     json.append(message.message.json_object)
                 elif message.type == DatasourceMessage.MessageType.LINK:
                     assert isinstance(message.message, DatasourceMessage.TextMessage)

View File

@@ -7,7 +7,7 @@ from pydantic import BaseModel
 class DatasourceNodeRunApiEntity(BaseModel):
     pipeline_id: str
     node_id: str
-    inputs: Mapping[str, Any]
+    inputs: dict[str, Any]
     datasource_type: str
     credential_id: str | None = None
     is_published: bool
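
Note: one plausible reason for widening inputs from Mapping to dict is that a Mapping-typed value cannot be handed to code that expects (or mutates) a concrete dict. Hedged sketch; run_node is an illustrative placeholder, not a project API:

    from typing import Any, Mapping

    def run_node(inputs: dict[str, Any]) -> None:
        inputs["normalized"] = True  # mutation is only legal on a real dict

    payload_mapping: Mapping[str, Any] = {"query": "hi"}
    # run_node(payload_mapping)  # rejected by the checker: Mapping is not a dict
    payload_dict: dict[str, Any] = {"query": "hi"}
    run_node(payload_dict)       # fine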

View File

@@ -578,10 +578,10 @@ class RagPipelineService:
            )
            yield start_event.model_dump()
            try:
-                for message in online_document_result:
+                for online_document_message in online_document_result:
                    end_time = time.time()
                    online_document_event = DatasourceCompletedEvent(
-                        data=message.result, time_consuming=round(end_time - start_time, 2)
+                        data=online_document_message.result, time_consuming=round(end_time - start_time, 2)
                    )
                    yield online_document_event.model_dump()
            except Exception as e:
@@ -607,10 +607,10 @@
                completed=0,
            )
            yield start_event.model_dump()
-            for message in online_drive_result:
+            for online_drive_message in online_drive_result:
                end_time = time.time()
                online_drive_event = DatasourceCompletedEvent(
-                    data=message.result,
+                    data=online_drive_message.result,
                    time_consuming=round(end_time - start_time, 2),
                    total=None,
                    completed=None,
@@ -627,19 +627,19 @@
                )
            start_time = time.time()
            try:
-                for message in website_crawl_result:
+                for website_crawl_message in website_crawl_result:
                    end_time = time.time()
-                    if message.result.status == "completed":
+                    if website_crawl_message.result.status == "completed":
                        crawl_event = DatasourceCompletedEvent(
-                            data=message.result.web_info_list or [],
-                            total=message.result.total,
-                            completed=message.result.completed,
+                            data=website_crawl_message.result.web_info_list or [],
+                            total=website_crawl_message.result.total,
+                            completed=website_crawl_message.result.completed,
                            time_consuming=round(end_time - start_time, 2),
                        )
                    else:
                        crawl_event = DatasourceProcessingEvent(
-                            total=message.result.total,
-                            completed=message.result.completed,
+                            total=website_crawl_message.result.total,
+                            completed=website_crawl_message.result.completed,
                        )
                    yield crawl_event.model_dump()
            except Exception as e:
@@ -721,12 +721,12 @@
            )
            try:
                variables: dict[str, Any] = {}
-                for message in online_document_result:
-                    if message.type == DatasourceMessage.MessageType.VARIABLE:
-                        assert isinstance(message.message, DatasourceMessage.VariableMessage)
-                        variable_name = message.message.variable_name
-                        variable_value = message.message.variable_value
-                        if message.message.stream:
+                for online_document_message in online_document_result:
+                    if online_document_message.type == DatasourceMessage.MessageType.VARIABLE:
+                        assert isinstance(online_document_message.message, DatasourceMessage.VariableMessage)
+                        variable_name = online_document_message.message.variable_name
+                        variable_value = online_document_message.message.variable_value
+                        if online_document_message.message.stream:
                            if not isinstance(variable_value, str):
                                raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
                            if variable_name not in variables:
@@ -791,8 +791,9 @@
            for event in generator:
                if isinstance(event, (NodeRunSucceededEvent, NodeRunFailedEvent)):
                    node_run_result = event.node_run_result
-                    # sign output files
-                    node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) or {}
+                    if node_run_result:
+                        # sign output files
+                        node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) or {}
                    break
 
            if not node_run_result:
@@ -1356,3 +1357,99 @@
            workflow_thread_pool_id=None,
            is_retry=True,
        )
+
+    def get_datasource_plugins(self, tenant_id: str, dataset_id: str, is_published: bool) -> list[dict]:
+        """
+        Get datasource plugins
+        """
+        dataset: Dataset | None = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+        if not dataset:
+            raise ValueError("Dataset not found")
+        pipeline: Pipeline | None = db.session.query(Pipeline).filter(Pipeline.id == dataset.pipeline_id).first()
+        if not pipeline:
+            raise ValueError("Pipeline not found")
+        workflow: Workflow | None = None
+        if is_published:
+            workflow = self.get_published_workflow(pipeline=pipeline)
+        else:
+            workflow = self.get_draft_workflow(pipeline=pipeline)
+        if not pipeline or not workflow:
+            raise ValueError("Pipeline or workflow not found")
+        datasource_nodes = workflow.graph_dict.get("nodes", [])
+        datasource_plugins = []
+        for datasource_node in datasource_nodes:
+            if datasource_node.get("type") == "datasource":
+                datasource_node_data = datasource_node.get("data", {})
+                if not datasource_node_data:
+                    continue
+                variables = workflow.rag_pipeline_variables
+                if variables:
+                    variables_map = {item["variable"]: item for item in variables}
+                else:
+                    variables_map = {}
+                datasource_parameters = datasource_node_data.get("datasource_parameters", {})
+                user_input_variables_keys = []
+                user_input_variables = []
+                for _, value in datasource_parameters.items():
+                    if value.get("value") and isinstance(value.get("value"), str):
+                        pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"
+                        match = re.match(pattern, value["value"])
+                        if match:
+                            full_path = match.group(1)
+                            last_part = full_path.split(".")[-1]
+                            user_input_variables_keys.append(last_part)
+                    elif value.get("value") and isinstance(value.get("value"), list):
+                        last_part = value.get("value")[-1]
+                        user_input_variables_keys.append(last_part)
+                for key, value in variables_map.items():
+                    if key in user_input_variables_keys:
+                        user_input_variables.append(value)
+                # get credentials
+                datasource_provider_service: DatasourceProviderService = DatasourceProviderService()
+                credentials: list[dict[Any, Any]] = datasource_provider_service.list_datasource_credentials(
+                    tenant_id=tenant_id,
+                    provider=datasource_node_data.get("provider_name"),
+                    plugin_id=datasource_node_data.get("plugin_id"),
+                )
+                credential_info_list: list[Any] = []
+                for credential in credentials:
+                    credential_info_list.append(
+                        {
+                            "id": credential.get("id"),
+                            "name": credential.get("name"),
+                            "type": credential.get("type"),
+                            "is_default": credential.get("is_default"),
+                        }
+                    )
+                datasource_plugins.append(
+                    {
+                        "node_id": datasource_node.get("id"),
+                        "plugin_id": datasource_node_data.get("plugin_id"),
+                        "provider_name": datasource_node_data.get("provider_name"),
+                        "datasource_type": datasource_node_data.get("provider_type"),
+                        "title": datasource_node_data.get("title"),
+                        "user_input_variables": user_input_variables,
+                        "credentials": credential_info_list,
+                    }
+                )
+        return datasource_plugins
+
+    def get_pipeline(self, tenant_id: str, dataset_id: str) -> Pipeline:
+        """
+        Get pipeline
+        """
+        dataset: Dataset | None = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+        if not dataset:
+            raise ValueError("Dataset not found")
+        pipeline: Pipeline | None = db.session.query(Pipeline).filter(Pipeline.id == dataset.pipeline_id).first()
+        if not pipeline:
+            raise ValueError("Pipeline not found")
+        return pipeline
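
Note: a standalone usage sketch of the variable-reference pattern used in get_datasource_plugins above; it matches placeholders of the form {{#node_id.variable#}} and keeps only the last path segment as the user-input key:

    import re

    pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"

    value = "{{#start_node.file_url#}}"  # illustrative placeholder value
    match = re.match(pattern, value)
    if match:
        full_path = match.group(1)            # "start_node.file_url"
        last_part = full_path.split(".")[-1]  # "file_url"
        print(last_part)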

View File

@@ -99,12 +99,6 @@ class ToolTransformService:
             provider.declaration.identity.icon = ToolTransformService.get_plugin_icon_url(
                 tenant_id=tenant_id, filename=provider.declaration.identity.icon
             )
-        else:
-            provider.declaration.identity.icon = ToolTransformService.get_tool_provider_icon_url(
-                provider_type=provider.type.value,
-                provider_name=provider.name,
-                icon=provider.declaration.identity.icon,
-            )
 
     @classmethod
     def builtin_provider_to_user_provider(